
MindTorch is a toolkit for running PyTorch models on Ascend.


Introduction
=============
MindTorch is a MindSpore tool that adapts the PyTorch interface. It is designed to let PyTorch code run efficiently on Ascend without requiring existing PyTorch users to change their habits.

|MindTorch-architecture|

Install
=======

MindTorch requires a few prerequisites to be installed first, including MindSpore, PIL, and NumPy. Then install MindTorch itself with pip:

.. code:: bash

   # for the latest stable version
   pip install mindtorch

   # for the latest release candidate
   pip install --upgrade --pre mindtorch

Alternatively, to get the latest or development version, install directly from the OpenI repository:

.. code:: bash

   pip3 install git+https://openi.pcl.ac.cn/OpenI/MSAdapter.git
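To confirm from Python which of the prerequisites listed above are importable, a small standard-library check is enough. This is a sketch, not part of MindTorch: `missing_prerequisites` is a hypothetical helper, and the import names `mindspore`, `PIL` (installed via the Pillow package) and `numpy` are assumed to be the ones your environment uses.

```python
import importlib.util

# Import names of the prerequisites listed above (PIL is provided by the Pillow package).
PREREQUISITES = ("mindspore", "PIL", "numpy")


def missing_prerequisites(modules=PREREQUISITES):
    """Return the top-level modules from `modules` that cannot be found on sys.path."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_prerequisites()
    if missing:
        print("Install these first:", ", ".join(missing))
    else:
        print("All prerequisites found.")
```

`find_spec` only locates the modules without importing them, so the check is cheap; it does not verify version compatibility between MindSpore and MindTorch.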

User guide
===========
For data processing and model construction, MindTorch is used the same way as PyTorch: only the import statements change. The training code, however, must be customized with MindSpore APIs, as shown in the following example.

1. Data processing (only the imports change)

.. code:: python

   from mindtorch.torch.utils.data import DataLoader
   from mindtorch.torchvision import datasets, transforms
   # InterpolationMode is assumed to be exposed here, as in torchvision.transforms
   from mindtorch.torchvision.transforms import InterpolationMode

   # Resize CIFAR-10 images to 224 x 224, convert them to tensors and normalize each channel
   transform = transforms.Compose([
       transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
       transforms.ToTensor(),
       transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.2435, 0.2616]),
   ])
   train_images = datasets.CIFAR10('./', train=True, download=True, transform=transform)
   train_data = DataLoader(train_images, batch_size=128, shuffle=True, num_workers=2, drop_last=True)
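The numbers in this pipeline can be checked by hand. In torchvision, `ToTensor` scales 8-bit pixel values to [0, 1] and `Normalize` computes `(x - mean) / std` per channel; we assume MindTorch keeps that behavior. With `drop_last=True`, the 50,000-image CIFAR-10 training split yields `50000 // 128` full batches per epoch. The helper names below are illustrative only and do not exist in MindTorch.

```python
# Per-channel statistics passed to transforms.Normalize in the example above.
MEAN = (0.4914, 0.4822, 0.4465)
STD = (0.247, 0.2435, 0.2616)


def to_tensor_value(pixel_u8):
    """Mimic ToTensor on one 8-bit channel value: map 0..255 to 0.0..1.0."""
    return pixel_u8 / 255.0


def normalize_rgb(rgb):
    """Mimic Normalize on one RGB pixel already scaled to [0, 1]."""
    return tuple((x - m) / s for x, m, s in zip(rgb, MEAN, STD))


def batches_per_epoch(num_samples, batch_size, drop_last=True):
    """Number of batches a DataLoader yields for one pass over the data."""
    full, rest = divmod(num_samples, batch_size)
    return full if drop_last or rest == 0 else full + 1


if __name__ == "__main__":
    white = tuple(to_tensor_value(255) for _ in range(3))
    print(normalize_rgb(white))            # a pure white pixel after normalization
    print(batches_per_epoch(50_000, 128))  # 390 full batches; the last 80 images are dropped
```

A pixel equal to the channel means maps to zero, which is the point of these statistics: they are the approximate per-channel mean and standard deviation of CIFAR-10.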

2. Model construction (only the imports change)

.. code:: python

   from mindtorch.torch.nn import Module, Linear, Flatten

   class MLP(Module):
       def __init__(self):
           super(MLP, self).__init__()
           self.flatten = Flatten()
           # Input size must match the flattened 3 x 224 x 224 images from step 1
           self.line1 = Linear(in_features=3 * 224 * 224, out_features=64)
           self.line2 = Linear(in_features=64, out_features=128, bias=False)
           # One output per CIFAR-10 class
           self.line3 = Linear(in_features=128, out_features=10)

       def forward(self, inputs):
           x = self.flatten(inputs)
           x = self.line1(x)
           x = self.line2(x)
           x = self.line3(x)
           return x
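The first `Linear` must accept the flattened image size, and each later layer must accept the previous layer's output. The sketch below checks that bookkeeping for the `MLP`, assuming the 3 x 224 x 224 inputs produced by the data pipeline in step 1, and counts parameters: a `Linear` layer with a bias has `in * out + out` parameters, and one with `bias=False` has `in * out`. The `LAYERS` table and helper functions are illustrative, written by hand rather than read from a MindTorch model.

```python
# Shape of one image after Resize((224, 224)) and ToTensor(): channels, height, width.
INPUT_SHAPE = (3, 224, 224)

# (in_features, out_features, has_bias) for line1, line2, line3 of the MLP.
LAYERS = [
    (3 * 224 * 224, 64, True),
    (64, 128, False),
    (128, 10, True),
]


def flattened_size(shape):
    """Flatten() keeps the batch dimension and multiplies the remaining ones."""
    size = 1
    for dim in shape:
        size *= dim
    return size


def check_and_count(input_shape, layers):
    """Raise if consecutive sizes disagree; otherwise return the total parameter count."""
    features = flattened_size(input_shape)
    total = 0
    for in_f, out_f, has_bias in layers:
        if in_f != features:
            raise ValueError(f"expected in_features={features}, got {in_f}")
        total += in_f * out_f + (out_f if has_bias else 0)
        features = out_f
    return total


if __name__ == "__main__":
    print(flattened_size(INPUT_SHAPE))          # 150528
    print(check_and_count(INPUT_SHAPE, LAYERS))  # 9643338
```

Almost all of the parameters sit in the first layer, which is why the input resolution dominates the size of this model.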

3. Model training (custom training with MindSpore)

.. code:: python

   import mindtorch.torch as torch
   import mindtorch.torch.nn as nn
   import mindspore as ms

   # MLP and train_data are defined in steps 1 and 2
   net = MLP()
   net.train()
   epochs = 500
   criterion = nn.CrossEntropyLoss()
   optimizer = ms.nn.SGD(net.trainable_params(), learning_rate=0.01, momentum=0.9, weight_decay=0.0005)

   # Wrap the network with its loss, then with a cell that runs one optimization step per call
   loss_net = ms.nn.WithLossCell(net, criterion)
   train_net = ms.nn.TrainOneStepCell(loss_net, optimizer)

   for i in range(epochs):
       for X, y in train_data:
           res = train_net(X, y)  # forward, backward and parameter update; returns the loss
           print("epoch:{}, loss:{:.6f}".format(i, res.asnumpy()))
   # Save the trained parameters as a MindSpore checkpoint
   ms.save_checkpoint(net, "save_path.ckpt")
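Conceptually, each `train_net(X, y)` call runs the network and the loss (the part `WithLossCell` wraps), computes gradients, and applies one optimizer update (the part `TrainOneStepCell` adds), returning the loss from the forward pass. The pure-Python sketch below mirrors that structure on a one-parameter toy model `y_hat = w * x` with squared-error loss and a hand-derived gradient. `ToyWithLoss` and `ToyTrainOneStep` are hypothetical stand-ins, not MindSpore APIs, and the update rule (L2 weight decay added to the gradient, then classic momentum) is our reading of `ms.nn.SGD` with its default `nesterov=False` and `dampening=0`.

```python
class ToyWithLoss:
    """Toy stand-in for WithLossCell: model y_hat = w * x followed by squared error."""

    def __init__(self, w):
        self.w = w  # the single trainable parameter

    def __call__(self, x, y):
        return (self.w * x - y) ** 2

    def grad(self, x, y):
        # Hand-derived gradient: d/dw (w*x - y)^2 = 2 * (w*x - y) * x
        return 2.0 * (self.w * x - y) * x


class ToyTrainOneStep:
    """Toy stand-in for TrainOneStepCell: loss, gradient, one momentum-SGD update."""

    def __init__(self, loss_net, lr=0.01, momentum=0.9, weight_decay=0.0005):
        self.loss_net = loss_net
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.velocity = 0.0  # momentum buffer, starts at zero

    def __call__(self, x, y):
        loss = self.loss_net(x, y)  # forward pass, computed before the update
        g = self.loss_net.grad(x, y) + self.weight_decay * self.loss_net.w  # add L2 weight decay
        self.velocity = self.momentum * self.velocity + g
        self.loss_net.w -= self.lr * self.velocity
        return loss


if __name__ == "__main__":
    train_step = ToyTrainOneStep(ToyWithLoss(w=0.0))
    for step in range(5):
        print("step:{}, loss:{:.6f}".format(step, train_step(2.0, 6.0)))
```

The hyperparameters match the `ms.nn.SGD` call above; in the real training loop MindSpore derives the gradients automatically instead of using a hand-written `grad` method.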


License
=======

MindTorch is released under the Apache 2.0 license.

.. |MindTorch-architecture| image:: https://openi.pcl.ac.cn/laich/pose_data/raw/branch/master/MSA_F.png



Download files

Download the file for your platform.

Source Distribution

mindtorch-0.2.1.tar.gz (696.1 kB)

Built Distribution

mindtorch-0.2.1-py2.py3-none-any.whl (884.9 kB)

File details

Details for the file mindtorch-0.2.1.tar.gz.

File metadata

- Download URL: mindtorch-0.2.1.tar.gz
- Size: 696.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.9.18

File hashes

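Hashes for mindtorch-0.2.1.tar.gz:

=========== ================================================================
Algorithm   Hash digest
=========== ================================================================
SHA256      0fe2b2a1f8cd5bb2b63b4a2c6305d3351fcc7e8117943accfda942ac13825312
MD5         cde34626eb1ea0051cc0a803bf3c16b1
BLAKE2b-256 f620fe965db4b0e26de0c9e45c3dcba8bdb004eb307789c83953ed6b60dd4ac9
=========== ================================================================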
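To check a downloaded archive against the SHA256 digests listed here, hash it locally and compare. A minimal sketch using Python's standard `hashlib`; `sha256_of` and `matches` are illustrative helper names, and the file path is whatever you downloaded.

```python
import hashlib
import sys


def sha256_of(path, chunk_size=1 << 20):
    """Hash the file in 1 MiB chunks so large archives need not fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches(path, expected_hex):
    """Compare with a published digest, ignoring case and surrounding whitespace."""
    return sha256_of(path) == expected_hex.strip().lower()


if __name__ == "__main__" and len(sys.argv) == 3:
    # usage: python check_hash.py mindtorch-0.2.1.tar.gz <SHA256 digest from the table>
    print("OK" if matches(sys.argv[1], sys.argv[2]) else "MISMATCH")
```

pip can enforce the same check at install time through its hash-checking mode, using `--hash=sha256:...` entries in a requirements file.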


File details

Details for the file mindtorch-0.2.1-py2.py3-none-any.whl.

File metadata

- Download URL: mindtorch-0.2.1-py2.py3-none-any.whl
- Size: 884.9 kB
- Tags: Python 2, Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/4.0.2 CPython/3.9.18

File hashes

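Hashes for mindtorch-0.2.1-py2.py3-none-any.whl:

=========== ================================================================
Algorithm   Hash digest
=========== ================================================================
SHA256      fa2cc91d574f45d0f959e28433f00dc42b6a5d5860ba65660e91b7a2ebe24f75
MD5         6580728462a68f5ccdaa3ea3588a6f76
BLAKE2b-256 724164f9d74944f613d9a823ecbcc24fd3542bb64114c9e51ff07cfe0fc5de5c
=========== ================================================================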

