
MindTorch is a toolkit for running PyTorch models on Ascend.

Project description

Introduction

MindTorch is a MindSpore-based tool that adapts the PyTorch interface, designed to let PyTorch code run efficiently on Ascend without changing the habits of existing PyTorch users.

(Figure: MindTorch architecture)

Install

MindTorch has several prerequisites that must be installed first, including MindSpore, Pillow (PIL), and NumPy.

# for the latest stable version
pip install mindtorch

# for latest release candidate
pip install --upgrade --pre mindtorch

Alternatively, you can install the latest development version directly from OpenI:

pip3 install git+https://openi.pcl.ac.cn/OpenI/MSAdapter.git

User guide

For data processing and model building, MindTorch is used in the same way as PyTorch; only the model-training part of the code needs to be adapted, as the following example shows.

  1. Data processing (only the import package changes)

from mindtorch.torch.utils.data import DataLoader
from mindtorch.torchvision import datasets, transforms
from mindtorch.torchvision.transforms import InterpolationMode

transform = transforms.Compose([transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.2435, 0.2616])
                               ])
train_images = datasets.CIFAR10('./', train=True, download=True, transform=transform)
train_data = DataLoader(train_images, batch_size=128, shuffle=True, num_workers=2, drop_last=True)
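Note that the pipeline above resizes each CIFAR-10 image to 224×224 with 3 RGB channels, so a model that flattens its input sees a fixed feature count. A quick plain-Python sanity check (no framework needed; the numbers come from the transform above):

```python
# Feature count after the transform pipeline: 3 RGB channels, 224x224 resize.
channels, height, width = 3, 224, 224
flattened_features = channels * height * width
print(flattened_features)  # 150528
```

This is the value the first fully connected layer of a flatten-then-MLP model must accept as `in_features`.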
  2. Model construction (only the import package changes)

from mindtorch.torch.nn import Module, Linear, Flatten

class MLP(Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.flatten = Flatten()
        self.line1 = Linear(in_features=3 * 224 * 224, out_features=64)  # flattened 3x224x224 input
        self.line2 = Linear(in_features=64, out_features=128, bias=False)
        self.line3 = Linear(in_features=128, out_features=10)

    def forward(self, inputs):
        x = self.flatten(inputs)
        x = self.line1(x)
        x = self.line2(x)
        x = self.line3(x)
        return x
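As a quick sanity check on the model, the parameter count of a fully connected layer is `in_features * out_features` plus one bias term per output when `bias=True`. A small helper (illustrative only, not a MindTorch API) applied to the two fixed-size layers above:

```python
def linear_params(in_features, out_features, bias=True):
    """Parameter count of a fully connected layer: weight matrix plus optional bias."""
    return in_features * out_features + (out_features if bias else 0)

# line2 above: Linear(64 -> 128, bias=False)
print(linear_params(64, 128, bias=False))  # 8192
# line3 above: Linear(128 -> 10), bias enabled by default
print(linear_params(128, 10))              # 1290
```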

  3. Model training (custom training)

import mindtorch.torch as torch
import mindtorch.torch.nn as nn
import mindspore as ms

net = MLP()
net.train()
epochs = 500
criterion = nn.CrossEntropyLoss()
optimizer = ms.nn.SGD(net.trainable_params(), learning_rate=0.01, momentum=0.9, weight_decay=0.0005)

# Define the training process
loss_net = ms.nn.WithLossCell(net, criterion)
train_net = ms.nn.TrainOneStepCell(loss_net, optimizer)

for i in range(epochs):
    for X, y in train_data:
        res = train_net(X, y)
        print("epoch:{}, loss:{:.6f}".format(i, res.asnumpy()))
# Save model
ms.save_checkpoint(net, "save_path.ckpt")
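The two wrappers used above follow a simple composition pattern: `WithLossCell` fuses the network and criterion into one loss-producing callable, and `TrainOneStepCell` pairs that with an optimizer update. A framework-free sketch of the idea (the class and parameter names here are illustrative stand-ins, not MindSpore APIs):

```python
class WithLoss:
    """Compose a network and a criterion into a single loss-producing callable."""
    def __init__(self, net, criterion):
        self.net, self.criterion = net, criterion

    def __call__(self, x, y):
        return self.criterion(self.net(x), y)

class TrainOneStep:
    """Run one training step: compute the loss, then apply an update."""
    def __init__(self, loss_fn, update):
        self.loss_fn, self.update = loss_fn, update

    def __call__(self, x, y):
        loss = self.loss_fn(x, y)  # forward pass through net + criterion
        self.update(loss)          # backward pass + optimizer step in the real cell
        return loss

# Toy usage: the "network" doubles its input, the "loss" is absolute error.
step = TrainOneStep(WithLoss(lambda x: 2 * x, lambda pred, target: abs(pred - target)),
                    update=lambda loss: None)
print(step(3, 7))  # 1
```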

License

MindTorch is released under the Apache 2.0 license.

Project details


Download files

Download the file for your platform.

Source Distribution

mindtorch-0.2.0.tar.gz (682.5 kB)

Uploaded Source

Built Distribution

mindtorch-0.2.0-py3-none-any.whl (891.5 kB)

Uploaded Python 3

File details

Details for the file mindtorch-0.2.0.tar.gz.

File metadata

  • Download URL: mindtorch-0.2.0.tar.gz
  • Upload date:
  • Size: 682.5 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: Python-urllib/3.6

File hashes

Hashes for mindtorch-0.2.0.tar.gz:

  • SHA256: 10b73985a7bc0bf5e1783df63186468053840e2bfb79c04a144758270e79e57f
  • MD5: 02354f116da080b8da1e27fad0b7ece2
  • BLAKE2b-256: 14c973226420cd73854841be830628acf0da0b03bb49b9ddffb5eb7f5dab83dd


File details

Details for the file mindtorch-0.2.0-py3-none-any.whl.

File metadata

  • Download URL: mindtorch-0.2.0-py3-none-any.whl
  • Upload date:
  • Size: 891.5 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: Python-urllib/3.6

File hashes

Hashes for mindtorch-0.2.0-py3-none-any.whl:

  • SHA256: f371a33a5094864b5956584280b29880e596f62760f41130b1ce704df756460f
  • MD5: 21c44e9c8f903e90e45b6c02632ff8b9
  • BLAKE2b-256: 3641830c27b99df7b2d22a2ab59c7c134d04627838e9538d8a4b4139f818e7f0

