Training PyTorch models with ONNX Runtime
Accelerate PyTorch models with ONNX Runtime
ONNX Runtime for PyTorch accelerates PyTorch model training using ONNX Runtime.
It is available via the torch-ort Python package.
This repository contains the source code for the package as well as instructions for running the package and samples demonstrating how to do so.
Prerequisites
You need a machine with at least one NVIDIA or AMD GPU to run ONNX Runtime for PyTorch.
You can install and run torch-ort in your local environment, or with Docker.
Run in a Python environment
Default dependencies
By default, torch-ort depends on PyTorch 1.8.1, ONNX Runtime 1.8 and CUDA 10.2.
- Install CUDA 10.2
- Install cuDNN 7.6
- Install torch-ort and dependencies
pip install ninja
pip install torch-ort
- Run the post-installation script for ORTModule
python -m torch_ort.configure
Explicitly install for NVIDIA CUDA 10.2
- Install CUDA 10.2
- Install cuDNN 7.6
- Install torch-ort and dependencies
pip install ninja
pip install torch==1.8.1
pip install --pre onnxruntime-training -f https://onnxruntimepackages.z14.web.core.windows.net/onnxruntime_stable_cu102.html
(or, to use the nightly build: pip install --pre onnxruntime-training -f https://onnxruntimepackages.z14.web.core.windows.net/onnxruntime_nightly_cu102.html)
pip install torch-ort
python -m torch_ort.configure
Explicitly install for NVIDIA CUDA 11.1
- Install CUDA 11.1
- Install cuDNN 8.0
- Install torch-ort and dependencies
pip install ninja
pip install torch==1.8.1
pip install --pre onnxruntime-training -f https://onnxruntimepackages.z14.web.core.windows.net/onnxruntime_stable_cu111.html
(or, to use the nightly build: pip install --pre onnxruntime-training -f https://onnxruntimepackages.z14.web.core.windows.net/onnxruntime_nightly_cu111.html)
pip install torch-ort
python -m torch_ort.configure
Explicitly install for AMD ROCm 4.2
- Install the ROCm 4.2 base package (see the ROCm installation instructions)
- Install the ROCm 4.2 libraries (see the ROCm installation instructions)
- Install ROCm 4.2 RCCL (see the ROCm installation instructions)
- Install torch-ort and dependencies
pip install ninja
pip install --pre torch -f https://download.pytorch.org/whl/nightly/rocm4.2/torch_nightly.html
pip install --pre onnxruntime-training -f https://onnxruntimepackages.z14.web.core.windows.net/onnxruntime_stable_rocm42.html
(or, to use the nightly build: pip install --pre onnxruntime-training -f https://onnxruntimepackages.z14.web.core.windows.net/onnxruntime_nightly_rocm42.html)
pip install torch-ort
python -m torch_ort.configure
Use torch-ort from nightly build
To use torch-ort from the nightly build, replace
pip install torch-ort
with
pip install -U --pre torch-ort -f https://onnxruntimepackages.z14.web.core.windows.net/torch_ort_nightly.html
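The explicit install steps above differ in only two commands: the PyTorch install and the onnxruntime-training find-links page. The helper below is a minimal Python sketch that assembles the command list for one configuration. The function and dictionary names are hypothetical. The commands and URLs are copied from this page, and the platform keys (cu102, cu111, rocm42) come from the URL suffixes. The helper only builds strings and does not install anything. It also does not cover the default `pip install torch-ort` path.

```python
from typing import List

# Host that serves the find-links pages used in the install steps above.
ORT_PACKAGES = "https://onnxruntimepackages.z14.web.core.windows.net"

# PyTorch install command per configuration, copied from the steps above.
# The keys are hypothetical labels taken from the find-links URL suffixes.
TORCH_COMMANDS = {
    "cu102": "pip install torch==1.8.1",
    "cu111": "pip install torch==1.8.1",
    "rocm42": "pip install --pre torch -f https://download.pytorch.org/whl/nightly/rocm4.2/torch_nightly.html",
}


def install_commands(platform: str, nightly_ort: bool = False,
                     nightly_torch_ort: bool = False) -> List[str]:
    """Return the install commands for one configuration (hypothetical helper)."""
    if platform not in TORCH_COMMANDS:
        raise ValueError(
            f"unknown platform {platform!r}; expected one of {sorted(TORCH_COMMANDS)}"
        )
    # Stable or nightly onnxruntime-training build for this platform.
    channel = "nightly" if nightly_ort else "stable"
    ort_training = (
        "pip install --pre onnxruntime-training "
        f"-f {ORT_PACKAGES}/onnxruntime_{channel}_{platform}.html"
    )
    # Released torch-ort, or the nightly torch-ort replacement command.
    torch_ort = (
        f"pip install -U --pre torch-ort -f {ORT_PACKAGES}/torch_ort_nightly.html"
        if nightly_torch_ort
        else "pip install torch-ort"
    )
    return [
        "pip install ninja",
        TORCH_COMMANDS[platform],
        ort_training,
        torch_ort,
        "python -m torch_ort.configure",  # post-installation step for ORTModule
    ]


if __name__ == "__main__":
    for command in install_commands("cu111"):
        print(command)
```

Printing the list rather than running it keeps the sketch side-effect free. You can paste the output into a shell or a Dockerfile `RUN` step.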
Run using Docker
On NVIDIA CUDA 11.1
The docker directory contains dockerfiles for the NVIDIA CUDA 11.1 configuration.
- Build the docker image
docker build -f Dockerfile.ort-cu111-cudnn8-devel-ubuntu18.04 -t ort.cu111 .
- Run the docker container using the image you have just built
docker run -it --gpus all --name my-experiments ort.cu111:latest /bin/bash
On AMD ROCm 4.2
The docker directory contains dockerfiles for the AMD ROCm 4.2 configuration.
- Build the docker image
docker build -f Dockerfile.ort-rocm4.2-pytorch1.8.1-ubuntu18.04 -t ort.rocm42 .
- Run the docker container using the image you have just built
docker run -it --rm \
    --privileged \
    --device=/dev/kfd \
    --device=/dev/dri \
    --group-add video \
    --cap-add=SYS_PTRACE \
    --security-opt seccomp=unconfined \
    --name my-experiments \
    ort.rocm42:latest /bin/bash
Test your installation
- Clone this repo
git clone git@github.com:pytorch/ort.git
- Install extra dependencies
pip install wget pandas scikit-learn transformers
- Run the training script
python ./ort/tests/bert_for_sequence_classification.py
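Before running the BERT script, a quicker check can confirm that the packages actually resolved. The Python sketch below assumes Python 3.8 or later (for `importlib.metadata`). The helper names are hypothetical. Finding `nvidia-smi` or `rocm-smi` on PATH only hints that a GPU driver stack is installed. It does not guarantee that ONNX Runtime can use the GPU.

```python
import importlib.metadata
import shutil
from typing import Dict, Optional, Tuple

# Distributions installed by the steps on this page, spelled with underscores
# to match the .dist-info directories that pip creates for these wheels.
DISTRIBUTIONS: Tuple[str, ...] = ("torch", "onnxruntime_training", "torch_ort")

# GPU management tools whose presence suggests an NVIDIA or AMD driver install.
GPU_TOOLS: Tuple[str, ...] = ("nvidia-smi", "rocm-smi")


def installed_versions(names=DISTRIBUTIONS) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            versions[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def gpu_tools_on_path(tools=GPU_TOOLS) -> Dict[str, bool]:
    """Report which GPU management tools can be found on PATH."""
    return {tool: shutil.which(tool) is not None for tool in tools}


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
    for tool, found in gpu_tools_on_path().items():
        print(f"{tool}: {'found' if found else 'not found'}")
```

A `None` version for any of the three distributions means the corresponding pip step above did not complete. Rerun that step before trying the training script.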
Add ONNX Runtime for PyTorch to your PyTorch training script
from torch_ort import ORTModule
# `model` is your existing torch.nn.Module; ORTModule wraps it for training
model = ORTModule(model)
# PyTorch training script follows
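The sketch below shows the wrapper inside a complete, if tiny, training loop. The toy model, random data, hyperparameters, and the helper names `maybe_wrap_with_ort` and `train` are hypothetical. Only `ORTModule` and its import path come from this page. When torch-ort is not installed, the sketch keeps the unwrapped model so the same script still runs with plain PyTorch. That fallback is an illustrative choice, not torch-ort behavior.

```python
import importlib.util


def maybe_wrap_with_ort(model):
    """Wrap `model` in torch_ort.ORTModule when torch-ort is installed (hypothetical helper)."""
    if importlib.util.find_spec("torch_ort") is None:
        return model  # torch-ort not installed: keep the plain PyTorch module
    from torch_ort import ORTModule
    return ORTModule(model)


def train(num_steps: int = 5) -> float:
    """Run a few SGD steps on a toy classifier with random data; return the last loss."""
    if num_steps < 1:
        raise ValueError("num_steps must be at least 1")
    import torch  # imported lazily so maybe_wrap_with_ort works without PyTorch

    # This page says torch-ort needs an NVIDIA or AMD GPU; ROCm builds of PyTorch
    # also report the GPU through torch.cuda. CPU is only for the plain fallback.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.manual_seed(0)
    model = torch.nn.Sequential(
        torch.nn.Linear(16, 32),
        torch.nn.ReLU(),
        torch.nn.Linear(32, 2),
    ).to(device)
    model = maybe_wrap_with_ort(model)  # the one-line change from the snippet above

    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = torch.nn.CrossEntropyLoss()
    inputs = torch.randn(64, 16, device=device)
    targets = torch.randint(0, 2, (64,), device=device)

    for _ in range(num_steps):
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()
    return loss.item()
```

Call `train()` from a script or REPL on a machine with PyTorch installed. Add torch-ort for the accelerated path. `ORTModule` is itself a `torch.nn.Module`, so the optimizer, loss, and backward pass remain ordinary PyTorch code.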
License
This project has an MIT license, as found in the LICENSE file.
File details
Details for the file torch_ort-1.8.1-py3-none-any.whl.
File metadata
- Download URL: torch_ort-1.8.1-py3-none-any.whl
- Size: 7.6 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.3.0 pkginfo/1.7.0 requests/2.25.1 setuptools/57.0.0 requests-toolbelt/0.9.1 tqdm/4.48.1 CPython/3.6.9
File hashes
Algorithm | Hash digest
---|---
SHA256 | 26c895932a864423501487f917ed5f3b33a6a8e46c3b9a8daf9203980f257ebf
MD5 | 2adfcd5f15e56e4c6d0d01c1dc0c6573
BLAKE2b-256 | 6c00072cf878d8fc82ec79edc86ce8f506bfca69c12e42ba1b7b92f4c0b9517f
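The table lists the digests published for the 1.8.1 wheel. Below is a minimal Python sketch for checking a downloaded copy against the SHA256 value. It assumes you already have the wheel locally, for example after `pip download torch-ort==1.8.1 --no-deps`. The helper names are hypothetical. pip can also enforce digests itself in hash-checking mode (`--require-hashes` with `--hash=sha256:...` entries in a requirements file).

```python
import hashlib

# SHA256 digest listed in the table above for torch_ort-1.8.1-py3-none-any.whl.
EXPECTED_SHA256 = "26c895932a864423501487f917ed5f3b33a6a8e46c3b9a8daf9203980f257ebf"


def sha256_of(path: str, chunk_size: int = 1 << 16) -> str:
    """Hex SHA256 of a file, read in chunks so large files need little memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_wheel(path: str, expected: str = EXPECTED_SHA256) -> bool:
    """True if the file at `path` matches the expected SHA256 digest (case-insensitive)."""
    return sha256_of(path) == expected.strip().lower()
```

For example, `verify_wheel("torch_ort-1.8.1-py3-none-any.whl")` should return True for an intact download of the wheel listed above.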