MindTorch is a toolkit for running PyTorch models on Ascend.
Introduction
MindTorch is a MindSpore tool that adapts the PyTorch interface. It is designed to let PyTorch code run efficiently on Ascend without requiring PyTorch users to change their habits.
Install
MindTorch has several prerequisites that must be installed first, including MindSpore, PIL (Pillow), and NumPy.
```bash
# latest stable version
pip install mindtorch
# latest release candidate (pre-release)
pip install --upgrade --pre mindtorch
```
Alternatively, you can install the latest development version directly from the OpenI repository:
```bash
pip3 install git+https://openi.pcl.ac.cn/OpenI/MSAdapter.git
```
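To confirm that the prerequisites listed above are importable after installation, a small standard-library check such as the following can be used. It is only a sketch: `check_prerequisites` is a hypothetical helper, and it assumes the import names `mindspore`, `PIL` (provided by Pillow), `numpy` and `mindtorch`. It only checks that each module can be found, not that its version is compatible with MindTorch.

```python
import importlib.util

# Assumed import names of the prerequisites; PIL is provided by the Pillow distribution.
REQUIRED_MODULES = ("mindspore", "PIL", "numpy", "mindtorch")


def check_prerequisites(modules=REQUIRED_MODULES):
    """Map each module name to True if Python can locate it, else False."""
    # find_spec looks the module up without importing it, so MindSpore is not initialised.
    return {name: importlib.util.find_spec(name) is not None for name in modules}


status = check_prerequisites()
for name, found in status.items():
    print(f"{name:10s} {'OK' if found else 'MISSING'}")
```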
User guide
For data processing and model building, MindTorch can be used in the same way as PyTorch, while the model training part of the code needs to be customized, as shown in the following example.
1. Data processing (modify import packages only)
```python
from mindtorch.torch.utils.data import DataLoader
from mindtorch.torchvision import datasets, transforms
# InterpolationMode is assumed to live in the same module as in torchvision
from mindtorch.torchvision.transforms import InterpolationMode

# Resize to 224x224, convert to tensors, and normalize with CIFAR-10 channel statistics
transform = transforms.Compose([
    transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.2435, 0.2616]),
])
train_images = datasets.CIFAR10('./', train=True, download=True, transform=transform)
train_data = DataLoader(train_images, batch_size=128, shuffle=True, num_workers=2, drop_last=True)
```
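As an illustration of what the transform pipeline computes per image, here is a NumPy sketch of `ToTensor` followed by `Normalize`, written against torchvision's documented semantics (MindTorch is assumed to behave the same way). It leaves out `Resize` and the bicubic interpolation. `to_tensor_and_normalize` is a hypothetical helper; the mean and std values are the ones from the example. With `batch_size=128` and `drop_last=True`, each batch from `train_data` should then have shape (128, 3, 224, 224).

```python
import numpy as np

# Per-channel statistics copied from the Normalize call above (CIFAR-10, RGB order).
MEAN = np.array([0.4914, 0.4822, 0.4465], dtype=np.float32)
STD = np.array([0.247, 0.2435, 0.2616], dtype=np.float32)


def to_tensor_and_normalize(image_hwc_uint8, mean=MEAN, std=STD):
    """Mimic ToTensor() followed by Normalize(mean, std) on one image.

    image_hwc_uint8: array of shape (H, W, C) with values in [0, 255].
    Returns a float32 array of shape (C, H, W).
    """
    # ToTensor: scale to [0, 1] and move channels first (HWC -> CHW).
    x = image_hwc_uint8.astype(np.float32) / 255.0
    x = np.transpose(x, (2, 0, 1))
    # Normalize: subtract the channel mean and divide by the channel std.
    return (x - mean[:, None, None]) / std[:, None, None]


# A uniform mid-grey 224x224 RGB image (the size produced by Resize above).
img = np.full((224, 224, 3), 128, dtype=np.uint8)
out = to_tensor_and_normalize(img)
print(out.shape)     # (3, 224, 224)
print(out[:, 0, 0])  # one normalized value per channel
```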
2. Model construction (modify import packages only)
```python
from mindtorch.torch.nn import Module, Linear, Flatten

class MLP(Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.flatten = Flatten()
        # in_features must equal the flattened input size: 3 x 224 x 224 for the images above
        self.line1 = Linear(in_features=3 * 224 * 224, out_features=64)
        self.line2 = Linear(in_features=64, out_features=128, bias=False)
        self.line3 = Linear(in_features=128, out_features=10)

    def forward(self, inputs):
        x = self.flatten(inputs)
        x = self.line1(x)
        x = self.line2(x)
        x = self.line3(x)
        return x
```
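The forward pass of `MLP` is a flatten followed by three affine maps. The NumPy sketch below reproduces that computation with random weights so that shapes and parameter counts can be checked without MindSpore or an Ascend device. It assumes `Linear` follows the PyTorch convention y = x W^T + b with W of shape (out_features, in_features), and that `Flatten` keeps the batch dimension; `init_mlp`, `mlp_forward` and `count_params` are hypothetical helpers. For the 3 x 224 x 224 images produced by the data pipeline, the flattened input has 150528 features, which is the value the first layer's `in_features` has to match.

```python
import numpy as np


def linear(x, weight, bias=None):
    # Same convention as torch.nn.Linear: y = x @ W^T + b, with W of shape (out, in).
    y = x @ weight.T
    return y if bias is None else y + bias


def init_mlp(in_features, rng):
    """Random weights with the layer sizes of the MLP class (hypothetical initialisation)."""
    shapes = [(64, in_features, True), (128, 64, False), (10, 128, True)]
    params = []
    for out_f, in_f, has_bias in shapes:
        w = rng.standard_normal((out_f, in_f)).astype(np.float32) * 0.01
        b = np.zeros(out_f, dtype=np.float32) if has_bias else None
        params.append((w, b))
    return params


def mlp_forward(inputs, params):
    # Flatten keeps the batch dimension: (N, C, H, W) -> (N, C*H*W).
    x = inputs.reshape(inputs.shape[0], -1)
    for w, b in params:
        x = linear(x, w, b)
    return x


def count_params(params):
    return sum(w.size + (0 if b is None else b.size) for w, b in params)


rng = np.random.default_rng(0)
batch = rng.standard_normal((4, 3, 224, 224)).astype(np.float32)
in_features = 3 * 224 * 224             # 150528 for 224x224 RGB images
params = init_mlp(in_features, rng)
print(mlp_forward(batch, params).shape)  # (4, 10)
print(count_params(params))              # 9643338
```

Because there are no activation functions between the layers, the three `Linear` layers compose into a single affine map; a non-linearity such as ReLU between them is the usual choice when a genuinely non-linear MLP is wanted.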
3. Model training (custom training)
```python
import mindtorch.torch as torch
import mindtorch.torch.nn as nn
import mindspore as ms

# MLP and train_data are defined in the previous steps
net = MLP()
net.train()
epochs = 500
criterion = nn.CrossEntropyLoss()
optimizer = ms.nn.SGD(net.trainable_params(), learning_rate=0.01, momentum=0.9, weight_decay=0.0005)

# Define the training process: wrap the network with the loss, then with the optimizer
loss_net = ms.nn.WithLossCell(net, criterion)
train_net = ms.nn.TrainOneStepCell(loss_net, optimizer)

for i in range(epochs):
    for X, y in train_data:
        # One call runs the forward pass, backpropagation and the parameter update
        res = train_net(X, y)
    # Print the loss of the last batch in this epoch
    print("epoch:{}, loss:{:.6f}".format(i, res.asnumpy()))

# Save model
ms.save_checkpoint(net, "save_path.ckpt")
```
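For readers coming from PyTorch: `ms.nn.WithLossCell` wraps the network together with the loss function, and `ms.nn.TrainOneStepCell` wraps that with the optimizer, so each call computes the loss, its gradients, and one optimizer update. The NumPy sketch below spells out those stages for a simple linear classifier instead of `MLP`, so the gradients can be written by hand. It uses the standard SGD-with-momentum rule (v = momentum * v + g, p = p - lr * v) with weight decay added to the gradient; `ms.nn.SGD` also offers dampening and Nesterov options that are not modelled here, and its details may differ. `cross_entropy`, `SGDMomentum` and `train_one_step` are hypothetical names.

```python
import numpy as np


def cross_entropy(logits, labels):
    """Mean softmax cross-entropy and its gradient with respect to the logits."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    probs = np.exp(shifted)
    probs /= probs.sum(axis=1, keepdims=True)
    n = logits.shape[0]
    loss = -np.log(probs[np.arange(n), labels]).mean()
    grad = probs.copy()
    grad[np.arange(n), labels] -= 1.0
    return loss, grad / n


class SGDMomentum:
    """SGD with momentum and L2 weight decay (no dampening, no Nesterov)."""

    def __init__(self, lr=0.01, momentum=0.9, weight_decay=0.0005):
        self.lr, self.momentum, self.weight_decay = lr, momentum, weight_decay
        self.velocity = {}

    def step(self, params, grads):
        for name, p in params.items():
            g = grads[name] + self.weight_decay * p  # weight decay folded into the gradient
            v = self.momentum * self.velocity.get(name, np.zeros_like(p)) + g
            self.velocity[name] = v
            p -= self.lr * v                         # in-place parameter update


def train_one_step(params, opt, x, y):
    """Forward pass, loss, backward pass and one optimizer update; returns the loss."""
    logits = x @ params["w"].T + params["b"]
    loss, dlogits = cross_entropy(logits, y)
    grads = {"w": dlogits.T @ x, "b": dlogits.sum(axis=0)}
    opt.step(params, grads)
    return loss


rng = np.random.default_rng(0)
x = rng.standard_normal((64, 5))
y = (x[:, 0] > 0).astype(int)  # toy two-class labels
params = {"w": np.zeros((2, 5)), "b": np.zeros(2)}
opt = SGDMomentum()
losses = [train_one_step(params, opt, x, y) for _ in range(200)]
print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```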
License
MindTorch is released under the Apache 2.0 license.
Download files
Source distribution: mindtorch-0.2.0.tar.gz
Built distribution: mindtorch-0.2.0-py3-none-any.whl
File details
Details for the file mindtorch-0.2.0.tar.gz.
File metadata
- Download URL: mindtorch-0.2.0.tar.gz
- Size: 682.5 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: Python-urllib/3.6
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | 10b73985a7bc0bf5e1783df63186468053840e2bfb79c04a144758270e79e57f |
| MD5 | 02354f116da080b8da1e27fad0b7ece2 |
| BLAKE2b-256 | 14c973226420cd73854841be830628acf0da0b03bb49b9ddffb5eb7f5dab83dd |
File details
Details for the file mindtorch-0.2.0-py3-none-any.whl.
File metadata
- Download URL: mindtorch-0.2.0-py3-none-any.whl
- Size: 891.5 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: Python-urllib/3.6
File hashes
| Algorithm | Hash digest |
|---|---|
| SHA256 | f371a33a5094864b5956584280b29880e596f62760f41130b1ce704df756460f |
| MD5 | 21c44e9c8f903e90e45b6c02632ff8b9 |
| BLAKE2b-256 | 3641830c27b99df7b2d22a2ab59c7c134d04627838e9538d8a4b4139f818e7f0 |
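A downloaded file can be checked against the published digests with Python's standard `hashlib` module. The sketch below computes all three digests in one pass; BLAKE2b-256 is taken to mean BLAKE2b with a 32-byte digest. `file_digests` and `verify` are hypothetical helpers, and the wheel path in the commented example assumes the file was saved to the current directory. For installs, pip's hash-checking mode (a `--hash=sha256:...` entry in a requirements file) is an alternative.

```python
import hashlib


def file_digests(path, chunk_size=1 << 20):
    """Compute the SHA256, MD5 and BLAKE2b-256 hex digests of a file."""
    hashers = {
        "SHA256": hashlib.sha256(),
        "MD5": hashlib.md5(),
        # BLAKE2b-256: BLAKE2b with a 32-byte (256-bit) digest.
        "BLAKE2b-256": hashlib.blake2b(digest_size=32),
    }
    with open(path, "rb") as f:
        # Read in chunks so large files do not need to fit in memory.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            for h in hashers.values():
                h.update(chunk)
    return {name: h.hexdigest() for name, h in hashers.items()}


def verify(path, expected_sha256):
    """Return True if the file's SHA256 digest matches the published value."""
    return file_digests(path)["SHA256"] == expected_sha256.lower()


# Published SHA256 digest of the wheel, copied from the table above.
WHEEL_SHA256 = "f371a33a5094864b5956584280b29880e596f62760f41130b1ce704df756460f"
# Example, assuming the wheel is in the current directory:
# print(verify("mindtorch-0.2.0-py3-none-any.whl", WHEEL_SHA256))
```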