Training PyTorch models with ONNX Runtime
Train PyTorch models with ONNX Runtime
PyTorch/ORT is a Python package that uses ONNX Runtime to accelerate PyTorch model training.
Pre-requisites
You need a machine with at least one NVIDIA GPU to run PyTorch/ORT.
You can install and run PyTorch/ORT in your local environment or with Docker. If you are using Docker, the following base image is suitable:
nvidia/cuda:11.1.1-cudnn8-devel-ubuntu18.04
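As a quick pre-flight check, the sketch below asks `nvidia-smi` (shipped with the NVIDIA driver) to list the GPUs it can see. It assumes `nvidia-smi` is on your `PATH`; the helper name `list_nvidia_gpus` is hypothetical and is not part of PyTorch/ORT.

```python
import shutil
import subprocess
from typing import List


def list_nvidia_gpus() -> List[str]:
    """Return the GPU lines printed by `nvidia-smi -L`, or [] if none are found."""
    exe = shutil.which("nvidia-smi")
    if exe is None:
        return []  # NVIDIA driver utilities are not on PATH
    try:
        result = subprocess.run(
            [exe, "-L"], capture_output=True, text=True, check=True, timeout=10
        )
    except (subprocess.SubprocessError, OSError):
        return []  # nvidia-smi failed, e.g. no driver is loaded
    # Each visible device is reported on its own line starting with "GPU <index>:".
    return [line.strip() for line in result.stdout.splitlines() if line.startswith("GPU ")]


if __name__ == "__main__":
    gpus = list_nvidia_gpus()
    if gpus:
        print("\n".join(gpus))
    else:
        print("No NVIDIA GPU detected; PyTorch/ORT needs at least one.")
```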
Install
- Install CUDA
- Install cuDNN
- Install PyTorch/ORT and dependencies:
pip install onnx ninja
pip install --pre torch -f https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html
pip install --pre onnxruntime-training -f https://onnxruntimepackages.z14.web.core.windows.net/onnxruntime_nightly.html
pip install torch-ort
To install the release package of onnxruntime-training instead of the nightly build:
pip install onnxruntime-training
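To confirm which of the packages above ended up installed, you can query their metadata with the standard library's `importlib.metadata` (Python 3.8+). This is a sketch; the helper name `installed_versions` is hypothetical, and the distribution names are taken from the pip commands above.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

# Distribution names as used in the pip commands above.
PACKAGES = ("onnx", "ninja", "torch", "onnxruntime-training", "torch-ort")


def installed_versions(names: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is missing."""
    found: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name:22} {ver or 'NOT INSTALLED'}")
```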
Test your installation
- Clone this repo
git clone git@github.com:pytorch/ort.git
- Install extra dependencies
pip install wget pandas scikit-learn transformers
- Run the training script
python ./ort/tests/bert_for_sequence_classification.py
Add PyTorch/ORT to your PyTorch training script
import onnxruntime
from torch_ort import ORTModule

# `model` is the torch.nn.Module your script already defines
model = ORTModule(model)
# PyTorch training script follows
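For context, here is a self-contained sketch of a complete (toy) training loop with the wrapper in place. The model, data, and hyperparameters are invented for illustration, and the function returns `None` when `torch`, `torch-ort`, or a CUDA GPU is unavailable, so it only trains on a machine that meets the prerequisites above.

```python
try:
    import torch
    from torch import nn
    from torch_ort import ORTModule
except ImportError:  # torch or torch-ort is not installed
    torch = None


def train_toy_model(steps: int = 3):
    """Run a few SGD steps on a toy regressor wrapped in ORTModule.

    Returns the list of per-step losses, or None when torch, torch-ort,
    or a CUDA GPU is unavailable.
    """
    if torch is None or not torch.cuda.is_available():
        return None
    device = torch.device("cuda")
    torch.manual_seed(0)

    # Hypothetical toy model and random data, for illustration only.
    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1)).to(device)
    model = ORTModule(model)  # the only PyTorch/ORT-specific line
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()
    inputs = torch.randn(64, 16, device=device)
    targets = torch.randn(64, 1, device=device)

    losses = []
    for _ in range(steps):
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses


if __name__ == "__main__":
    print(train_toy_model())
```

Everything after the `ORTModule(model)` line is ordinary PyTorch, which is the point of the wrapper: the optimizer, loss, and loop are unchanged.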
Versioning
CUDA
The PyTorch/ORT package was built with CUDA 11.1. If you have a different version of CUDA installed, you should install the CUDA 11.1 toolkit.
This is a limitation that will be removed in the next release.
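A minimal sketch for checking the CUDA version your PyTorch build reports (`torch.version.cuda`, which is `None` for CPU-only builds) against 11.1, assuming only major and minor versions matter. It does not inspect the system toolkit; `nvcc --version` is the usual way to check that. The helper names are hypothetical.

```python
from typing import Optional, Tuple

REQUIRED_CUDA = "11.1"  # version the PyTorch/ORT package was built with


def cuda_major_minor(version: Optional[str]) -> Optional[Tuple[int, int]]:
    """Parse a string like '11.1' or '11.1.105' into (major, minor)."""
    if not version:
        return None
    parts = version.split(".")
    try:
        return int(parts[0]), int(parts[1])
    except (IndexError, ValueError):
        return None  # malformed version string


def cuda_matches(installed: Optional[str], required: str = REQUIRED_CUDA) -> bool:
    """True when the installed CUDA major.minor equals the required one."""
    found = cuda_major_minor(installed)
    return found is not None and found == cuda_major_minor(required)


if __name__ == "__main__":
    try:
        import torch
        built_with = torch.version.cuda  # e.g. "11.1", or None for CPU-only builds
    except ImportError:
        built_with = None
    print(f"PyTorch CUDA: {built_with}; matches {REQUIRED_CUDA}: {cuda_matches(built_with)}")
```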
File details
Details for the file torch_ort-0.0.10.dev20210419-py3-none-any.whl.
File metadata
- Download URL: torch_ort-0.0.10.dev20210419-py3-none-any.whl
- Upload date:
- Size: 3.1 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.4.1 importlib_metadata/4.0.0 pkginfo/1.7.0 requests/2.22.0 requests-toolbelt/0.9.1 tqdm/4.60.0 CPython/3.8.5
File hashes
Algorithm | Hash digest
---|---
SHA256 | fda80edbb1bd9e495966d5c570e45232c956b4900eb31bd35c4b3bd2eebcc42d
MD5 | 78a735393dde13568de6fd7db89de33d
BLAKE2b-256 | c1badb190e5beee462e55972012315da199dfd3320017b27ec1381c29a723f61
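To check a downloaded wheel against the digests listed above, you can compute all three hashes in one pass with `hashlib`. This is a sketch; the helper name `file_digests` is hypothetical (pip's own `pip hash` command also computes the SHA256).

```python
import hashlib
import sys
from typing import Dict

EXPECTED_SHA256 = "fda80edbb1bd9e495966d5c570e45232c956b4900eb31bd35c4b3bd2eebcc42d"


def file_digests(path: str, chunk_size: int = 1 << 16) -> Dict[str, str]:
    """Compute SHA256, MD5 and BLAKE2b-256 hex digests of a file in one pass."""
    hashers = {
        "SHA256": hashlib.sha256(),
        "MD5": hashlib.md5(),
        "BLAKE2b-256": hashlib.blake2b(digest_size=32),  # 32 bytes = 256 bits
    }
    with open(path, "rb") as f:
        # Read in chunks so large files need not fit in memory.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            for h in hashers.values():
                h.update(chunk)
    return {name: h.hexdigest() for name, h in hashers.items()}


if __name__ == "__main__" and len(sys.argv) > 1:
    # Usage: python check_wheel.py torch_ort-0.0.10.dev20210419-py3-none-any.whl
    digests = file_digests(sys.argv[1])
    print("SHA256 OK" if digests["SHA256"] == EXPECTED_SHA256 else "SHA256 MISMATCH")
```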