

Train PyTorch models with ONNX Runtime

PyTorch/ORT is a Python package that uses ONNX Runtime to accelerate PyTorch model training.

Prerequisites

You need a machine with at least one NVIDIA or AMD GPU to run PyTorch/ORT.

You can install and run PyTorch/ORT in your local environment or with Docker. If you are using Docker, the following base images are suitable for NVIDIA and AMD GPUs, respectively: nvidia/cuda:11.1.1-cudnn8-devel-ubuntu18.04 and rocm/pytorch:rocm4.1.1_ubuntu18.04_py3.6_pytorch.
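Before choosing between a local install and one of the Docker base images, it can help to confirm which GPU vendor's tooling is present. The sketch below is a rough heuristic, assuming that `nvidia-smi` (NVIDIA) or `rocm-smi` (AMD) is on `PATH` when the corresponding driver stack is installed. The helper names are hypothetical and not part of PyTorch/ORT; the image names are the two listed above.

```python
import shutil
from typing import List, Optional

# Docker base images named in the text above, keyed by GPU vendor.
BASE_IMAGES = {
    "nvidia": "nvidia/cuda:11.1.1-cudnn8-devel-ubuntu18.04",
    "amd": "rocm/pytorch:rocm4.1.1_ubuntu18.04_py3.6_pytorch",
}

# Assumption: finding the vendor's CLI tool on PATH is treated as a hint that
# its driver stack is installed. This is a heuristic, not a hardware probe.
VENDOR_TOOLS = {"nvidia": "nvidia-smi", "amd": "rocm-smi"}


def detect_gpu_vendors() -> List[str]:
    """Return the vendors whose CLI tool is found on PATH (possibly empty)."""
    return [vendor for vendor, tool in VENDOR_TOOLS.items() if shutil.which(tool)]


def suggest_base_image(vendor: str) -> Optional[str]:
    """Return the Docker base image listed for a vendor, or None if unknown."""
    return BASE_IMAGES.get(vendor.lower())


if __name__ == "__main__":
    vendors = detect_gpu_vendors()
    if not vendors:
        print("Neither nvidia-smi nor rocm-smi found; PyTorch/ORT needs an NVIDIA or AMD GPU.")
    for vendor in vendors:
        print(vendor, "->", suggest_base_image(vendor))
```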

Install for NVIDIA GPUs

  1. Install CUDA

  2. Install cuDNN

  3. Install PyTorch/ORT and dependencies

NVIDIA CUDA version 11.1

  • pip install onnx ninja
  • pip install --pre torch -f https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html
  • pip install --pre onnxruntime-training -f https://onnxruntimepackages.z14.web.core.windows.net/onnxruntime_nightly_cu111.html
  • pip install torch-ort

NVIDIA CUDA version 10.2

  • pip install onnx ninja
  • pip install --pre torch -f https://download.pytorch.org/whl/nightly/cu102/torch_nightly.html
  • pip install --pre onnxruntime-training -f https://onnxruntimepackages.z14.web.core.windows.net/onnxruntime_nightly_cu102.html
  • pip install torch-ort

Install for AMD GPUs

  1. Install ROCm 4.1 base package (instructions)

  2. Install ROCm 4.1 libraries (instructions)

  3. Install ROCm 4.1 RCCL (instructions)

  4. Install PyTorch/ORT and dependencies

AMD ROCm version 4.1

  • pip install onnx ninja
  • pip install --pre torch -f https://download.pytorch.org/whl/nightly/rocm4.1/torch_nightly.html
  • pip install --pre onnxruntime-training -f https://onnxruntimepackages.z14.web.core.windows.net/onnxruntime_nightly_rocm41.html
  • pip install torch-ort

To install the release package of onnxruntime-training instead of the nightly build:

  • pip install onnxruntime-training
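The three nightly recipes above follow one pattern and differ only in the wheel-index tag (note that PyTorch spells the ROCm tag `rocm4.1` while ONNX Runtime uses `rocm41`). The hypothetical helper below rebuilds the pip commands from that pattern using the URLs listed on this page; the platform keys `cuda11.1`, `cuda10.2` and `rocm4.1` are made up for this sketch. With `nightly=False` it swaps in the release `onnxruntime-training` package and leaves the other commands unchanged, which is an assumption about how the release package combines with the rest of the recipe.

```python
from typing import List

# Wheel-index tags copied from the install lists: (PyTorch tag, ONNX Runtime tag).
# The platform keys themselves are hypothetical names chosen for this sketch.
PLATFORM_TAGS = {
    "cuda11.1": ("cu111", "cu111"),
    "cuda10.2": ("cu102", "cu102"),
    "rocm4.1": ("rocm4.1", "rocm41"),
}

TORCH_INDEX = "https://download.pytorch.org/whl/nightly/{tag}/torch_nightly.html"
ORT_INDEX = "https://onnxruntimepackages.z14.web.core.windows.net/onnxruntime_nightly_{tag}.html"


def install_commands(platform: str, nightly: bool = True) -> List[str]:
    """Return the pip commands for one platform, in the order given on this page."""
    if platform not in PLATFORM_TAGS:
        raise ValueError(f"unsupported platform {platform!r}; expected one of {sorted(PLATFORM_TAGS)}")
    torch_tag, ort_tag = PLATFORM_TAGS[platform]
    commands = [
        "pip install onnx ninja",
        f"pip install --pre torch -f {TORCH_INDEX.format(tag=torch_tag)}",
    ]
    if nightly:
        commands.append(f"pip install --pre onnxruntime-training -f {ORT_INDEX.format(tag=ort_tag)}")
    else:
        # Release package of onnxruntime-training instead of the nightly build.
        commands.append("pip install onnxruntime-training")
    commands.append("pip install torch-ort")
    return commands


if __name__ == "__main__":
    for command in install_commands("cuda11.1"):
        print(command)
```

The helper only prints the commands; run them yourself in the environment you intend to train in.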

Test your installation

  1. Clone this repo:
     • git clone git@github.com:pytorch/ort.git
  2. Install extra dependencies:
     • pip install wget pandas scikit-learn transformers
  3. Run the training script:
     • python ./ort/tests/bert_for_sequence_classification.py
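Before running the test script, a quick way to confirm that the extra dependencies (and the packages installed earlier) are importable is `importlib.util.find_spec`, which locates a module without importing it. This is a minimal sketch; the module list is an assumption based on the packages named on this page (`onnxruntime-training` is assumed to provide the `onnxruntime` module, and scikit-learn is imported as `sklearn`), and `check_imports` is a hypothetical helper.

```python
import importlib.util
from typing import Dict, Iterable

# Import names (not pip package names) for the packages used on this page.
EXTRA_DEPS = ["wget", "pandas", "sklearn", "transformers", "torch", "torch_ort", "onnxruntime"]


def check_imports(modules: Iterable[str]) -> Dict[str, bool]:
    """Map each top-level module name to whether Python can locate it."""
    return {name: importlib.util.find_spec(name) is not None for name in modules}


if __name__ == "__main__":
    missing = [name for name, found in check_imports(EXTRA_DEPS).items() if not found]
    print("missing:", missing or "none")
```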

Add PyTorch/ORT to your PyTorch training script

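```python
import torch
from torch_ort import ORTModule

model = torch.nn.Linear(10, 2)  # placeholder: use your own torch.nn.Module here
model = ORTModule(model)  # wrap the model so ONNX Runtime runs its forward and backward passes
# PyTorch training script follows
```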

Versioning

CUDA

The PyTorch/ORT package was built with CUDA 11.1. If you have a different version of CUDA installed, you should install the CUDA 11.1 toolkit.

This is a limitation that will be removed in the next release.
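One way to check which CUDA toolkit is installed is to read the output of `nvcc --version`, which prints a line such as `Cuda compilation tools, release 11.1, V11.1.105`. The sketch below assumes `nvcc` from the toolkit is on `PATH` and that its output keeps this `release X.Y` format; `parse_nvcc_release` and `installed_cuda_version` are hypothetical helpers.

```python
import re
import shutil
import subprocess
from typing import Optional

REQUIRED_CUDA = "11.1"  # CUDA version the package was built with, per the text above


def parse_nvcc_release(output: str) -> Optional[str]:
    """Extract 'X.Y' from `nvcc --version` output, or None if no release is found."""
    match = re.search(r"release (\d+\.\d+)", output)
    return match.group(1) if match else None


def installed_cuda_version() -> Optional[str]:
    """Run nvcc if it is on PATH and return its release, else None."""
    nvcc = shutil.which("nvcc")
    if nvcc is None:
        return None
    result = subprocess.run([nvcc, "--version"], capture_output=True, text=True, check=False)
    return parse_nvcc_release(result.stdout)


if __name__ == "__main__":
    version = installed_cuda_version()
    if version != REQUIRED_CUDA:
        print(f"Found CUDA toolkit {version}; PyTorch/ORT expects {REQUIRED_CUDA}.")
```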


Built Distribution

torch_ort-0.0.10.dev20210504-py3-none-any.whl (3.4 kB)

Uploaded Python 3

File details

Details for the file torch_ort-0.0.10.dev20210504-py3-none-any.whl.

File metadata

  • Download URL: torch_ort-0.0.10.dev20210504-py3-none-any.whl
  • Upload date:
  • Size: 3.4 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.4.1 importlib_metadata/4.0.1 pkginfo/1.7.0 requests/2.22.0 requests-toolbelt/0.9.1 tqdm/4.60.0 CPython/3.8.5

File hashes

Hashes for torch_ort-0.0.10.dev20210504-py3-none-any.whl
  • SHA256: 9226e5b446781d5ae028a52d8a292dd82ea8c41baeb3952aec302e9af0b4c43e
  • MD5: 4a3994fe036585ea6746584cba34bb09
  • BLAKE2b-256: 07054cbe171129e1da7e99d28eb535489e8430c8209190ce27b58a53dc3a48ac
