Transformer acceleration library
Transformer Engine
Latest News
[03/2024] Turbocharged Training: Optimizing the Databricks Mosaic AI stack with FP8
[03/2024] FP8 Training Support in SageMaker Model Parallelism Library
[12/2023] New NVIDIA NeMo Framework Features and NVIDIA H200
[11/2023] Inflection-2: The Next Step Up
[11/2023] Unleashing The Power Of Transformers With NVIDIA Transformer Engine
[09/2023] Transformer Engine added to AWS DL Container for PyTorch Training
[06/2023] Breaking MLPerf Training Records with NVIDIA H100 GPUs
[04/2023] Benchmarking Large Language Models on NVIDIA H100 GPUs with CoreWeave (Part 1)
What is Transformer Engine?
Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper GPUs, to provide better performance with lower memory utilization in both training and inference. TE provides a collection of highly optimized building blocks for popular Transformer architectures and an automatic mixed precision-like API that can be used seamlessly with your framework-specific code. TE also includes a framework-agnostic C++ API that can be integrated with other deep learning libraries to enable FP8 support for Transformers.
As the number of parameters in Transformer models continues to grow, training and inference for architectures such as BERT, GPT and T5 become very memory- and compute-intensive. Most deep learning frameworks train with FP32 by default. This is not essential, however, to achieve full accuracy for many deep learning models. Mixed-precision training, which combines single precision (FP32) with a lower-precision format (e.g. FP16) when training a model, yields significant speedups with minimal differences in accuracy compared to FP32 training. The Hopper GPU architecture introduced FP8 precision, which offers improved performance over FP16 with no degradation in accuracy. Although all major deep learning frameworks support FP16, FP8 support is not available natively in frameworks today.
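FP8 is not a single format: the two encodings in common use are E4M3 (4 exponent bits, 3 mantissa bits) and E5M2 (5 exponent bits, 2 mantissa bits), which are also the names used by `recipe.Format.E4M3` in TE, while `recipe.Format.HYBRID` uses E4M3 in the forward pass and E5M2 in the backward pass. The pure-Python sketch below (function names are hypothetical, not part of TE) derives each format's range from its bit layout, assuming the OCP convention where E4M3 has no infinities and only the all-ones pattern is NaN. It yields a maximum of 448 for E4M3 and 57344 for E5M2, which is why E4M3 is favored for precision and E5M2 for the wider dynamic range of gradients.

```python
def fp8_bias(exp_bits: int) -> int:
    # Standard exponent bias: 2**(exp_bits - 1) - 1 (7 for E4M3, 15 for E5M2).
    return 2 ** (exp_bits - 1) - 1


def fp8_max_finite(exp_bits: int, man_bits: int, ieee_like: bool) -> float:
    """Largest finite magnitude of a 1-sign/exp_bits/man_bits FP8 encoding."""
    bias = fp8_bias(exp_bits)
    if ieee_like:
        # E5M2 follows IEEE 754: the all-ones exponent encodes Inf/NaN,
        # so the top usable exponent field is 2**exp_bits - 2.
        top_exp = (2 ** exp_bits - 2) - bias
        top_significand = 2.0 - 2.0 ** -man_bits
    else:
        # E4M3 (assumed OCP variant) reserves only S.1111.111 for NaN and has
        # no Inf, so the all-ones exponent is usable except with mantissa 111.
        top_exp = (2 ** exp_bits - 1) - bias
        top_significand = 2.0 - 2.0 * 2.0 ** -man_bits
    return top_significand * 2.0 ** top_exp


def fp8_min_normal(exp_bits: int) -> float:
    # Smallest positive normal value uses exponent field 1.
    return 2.0 ** (1 - fp8_bias(exp_bits))


for name, e, m, ieee in (("E4M3", 4, 3, False), ("E5M2", 5, 2, True)):
    # Relative spacing between neighbouring values is 2**-m near 1.0.
    print(f"{name}: max={fp8_max_finite(e, m, ieee)}, "
          f"min normal={fp8_min_normal(e)}, relative step={2.0 ** -m}")
```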
TE addresses the problem of FP8 support by providing APIs that integrate with popular Large Language Model (LLM) libraries. It provides a Python API consisting of modules for easily building a Transformer layer, as well as a framework-agnostic C++ library that includes the structs and kernels needed for FP8 support. Modules provided by TE internally maintain scaling factors and other values needed for FP8 training, greatly simplifying mixed precision training for users.
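The scaling factors mentioned above exist because FP8's narrow range (448 at most for E4M3) would overflow or underflow many real tensors. The following is a minimal, hypothetical sketch of the *delayed* scaling idea behind `recipe.DelayedScaling`: the scale used at a given step is derived from absolute maxima (amax) recorded in earlier steps, with `margin` providing extra headroom. The class name, the history length, and the exact formula `(fp8_max / amax) / 2**margin` are assumptions for illustration; this is not TE's implementation.

```python
from collections import deque

E4M3_MAX = 448.0  # largest finite E4M3 value


class DelayedScalingSketch:
    """Hypothetical, simplified model of delayed scaling; not TE's code."""

    def __init__(self, fp8_max=E4M3_MAX, margin=0, amax_history_len=16):
        self.fp8_max = fp8_max
        self.margin = margin
        # Only the most recent `amax_history_len` observations are kept.
        self.amax_history = deque(maxlen=amax_history_len)
        self.scale = 1.0  # scale applied to the *current* step's tensor

    def record_amax(self, values):
        """Store this step's absolute maximum for use in later steps."""
        self.amax_history.append(max(abs(v) for v in values))

    def update_scale(self):
        """Recompute the scale from the amax history (assumed formula)."""
        amax = max(self.amax_history, default=0.0)
        if amax > 0.0:
            # Map amax onto fp8_max, then back off by 2**margin for headroom.
            self.scale = (self.fp8_max / amax) / 2 ** self.margin
        return self.scale


# Usage: each step is scaled with a factor derived from previous steps only.
scaler = DelayedScalingSketch(margin=0, amax_history_len=4)
scaled_batches = []
for batch in ([0.5, -2.0], [1.0, 4.0], [3.0, -1.0]):
    scale = scaler.scale  # chosen before seeing `batch`
    scaled_batches.append([v * scale for v in batch])
    scaler.record_amax(batch)
    scaler.update_scale()
```

The second batch shows the cost of the delay: its amax (4.0) exceeds everything in the history (2.0), so scaling by 224 produces 896, outside the E4M3 range. A non-zero `margin` trades some precision for headroom against such jumps.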
Highlights
Easy-to-use modules for building Transformer layers with FP8 support
Optimizations (e.g. fused kernels) for Transformer models
Support for FP8 on NVIDIA Hopper and NVIDIA Ada GPUs
Support for optimizations across all precisions (FP16, BF16) on NVIDIA Ampere GPU architecture generations and later
Examples
PyTorch
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
# Set dimensions.
in_features = 768
out_features = 3072
hidden_size = 2048
# Initialize model and inputs.
model = te.Linear(in_features, out_features, bias=True)
inp = torch.randn(hidden_size, in_features, device="cuda")
# Create an FP8 recipe. Note: All input args are optional.
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3)
# Enable autocasting for the forward pass
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = model(inp)
# Compute the loss and run the backward pass outside the autocast context.
loss = out.sum()
loss.backward()
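Inside `fp8_autocast`, TE modules run their matrix multiplications on FP8 inputs produced with per-tensor scaling factors. The pure-Python sketch below simulates only the numerical effect of that cast, with no TE or PyTorch calls: it scales a value, rounds it to the nearest E4M3 value (ties to even), and unscales it. The function names are hypothetical, and saturating at ±448 on overflow is an assumption of this sketch; TE's real casting happens in CUDA kernels.

```python
import math

E4M3_MAX = 448.0
E4M3_MIN_EXP = -6   # exponent of the smallest normal E4M3 value
E4M3_MAN_BITS = 3


def round_to_e4m3(x: float) -> float:
    """Round x to the nearest E4M3 value (ties to even), saturating at +/-448."""
    if x == 0.0 or math.isnan(x):
        return x
    sign = -1.0 if x < 0 else 1.0
    a = abs(x)
    if a >= E4M3_MAX:
        return sign * E4M3_MAX
    _, e = math.frexp(a)             # a = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, E4M3_MIN_EXP)   # clamp so tiny values use the subnormal grid
    step = 2.0 ** (exp - E4M3_MAN_BITS)  # spacing of representable values here
    q = round(a / step) * step       # Python's round() is round-half-to-even
    return sign * min(q, E4M3_MAX)


def fake_fp8_roundtrip(values, scale):
    """Scale, round to E4M3, then unscale: models precision lost in the cast."""
    return [round_to_e4m3(v * scale) / scale for v in values]


tiny = 2.0 ** -12
print(fake_fp8_roundtrip([tiny], scale=1.0))        # underflows to zero
print(fake_fp8_roundtrip([tiny], scale=2.0 ** 10))  # survives once scaled up
```

Without a scale, 2**-12 falls below half of E4M3's smallest subnormal and is flushed to zero; multiplying by 2**10 first moves it into the representable range, which is the job the recipe's scaling factors perform.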
JAX
Flax
import flax
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.common import recipe
BATCH = 32
SEQLEN = 128
HIDDEN = 1024
# Initialize RNG and inputs.
rng = jax.random.PRNGKey(0)
init_rng, data_rng = jax.random.split(rng)
inp = jax.random.normal(data_rng, [BATCH, SEQLEN, HIDDEN], jnp.float32)
# Create an FP8 recipe. Note: All input args are optional.
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)
# Enable autocasting for the forward pass
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    model = te_flax.DenseGeneral(features=HIDDEN)
    def loss_fn(params, other_vars, inp):
        out = model.apply({'params': params, **other_vars}, inp)
        return jnp.mean(out)
    # Initialize models.
    variables = model.init(init_rng, inp)
    other_variables, params = flax.core.pop(variables, 'params')
    # Construct the forward and backward function
    fwd_bwd_fn = jax.value_and_grad(loss_fn, argnums=(0, 1))
    for _ in range(10):
        loss, (param_grads, other_grads) = fwd_bwd_fn(params, other_variables, inp)
Installation
Pre-requisites
Linux x86_64
CUDA 12.0+ for Hopper and CUDA 12.1+ for Ada
NVIDIA Driver supporting CUDA 12.0 or later
cuDNN 8.1 or later
For fused attention, CUDA 12.1 or later, NVIDIA Driver supporting CUDA 12.1 or later, and cuDNN 8.9 or later.
Docker
The quickest way to get started with Transformer Engine is by using Docker images on the NVIDIA GPU Cloud (NGC) Catalog. For example, to use the NGC PyTorch container interactively:
docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:23.10-py3
Here, 23.10 is the container version in YY.MM form; for example, 23.10 is the October 2023 release.
pip
To install the latest stable version of Transformer Engine,
pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable NVTE_FRAMEWORK to a comma-separated list (e.g. NVTE_FRAMEWORK=jax,pytorch,paddle).
Alternatively, the package can be installed directly from PyPI, for example:
pip install transformer_engine[pytorch]
To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch,paddle]). Transformer Engine ships wheels for the core library as well as the PaddlePaddle extensions. Source distributions are shipped for the JAX and PyTorch extensions.
From source
Compiling with FlashAttention-2
Transformer Engine release v0.11.0 adds support for FlashAttention-2 in PyTorch for improved performance.
FlashAttention-2 compilation is known to be resource-intensive and to require a large amount of RAM, which may lead to out-of-memory errors during the installation of Transformer Engine. Setting MAX_JOBS=1 in the environment limits the build to a single parallel compilation job and may circumvent the issue.
Note that NGC PyTorch 23.08+ containers include FlashAttention-2.
Breaking Changes
v1.7: Padding mask definition for PyTorch
To unify the definition and usage of the attention mask across all three frameworks supported by Transformer Engine, the meaning of True in the PyTorch padding mask has changed from including the corresponding position in attention to excluding it. Since v1.7, all attention mask types follow the same definition: True masks out the corresponding position, and False includes that position in the attention calculation.
For example:
# for a batch of 3 sequences where `a`s, `b`s and `c`s are the useful tokens
# and `0`s are the padding tokens,
[a, a, a, 0, 0,
b, b, 0, 0, 0,
c, c, c, c, 0]
# the padding mask for this batch before v1.7 is,
[ True, True, True, False, False,
True, True, False, False, False,
True, True, True, True, False]
# and for v1.7 onwards it should be,
[False, False, False, True, True,
False, False, True, True, True,
False, False, False, False, True]
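The two conventions are exact logical negations of each other, so migrating an existing mask only requires inverting it. The sketch below uses plain Python lists for illustration, and the helper names are hypothetical: it builds a v1.7-style padding mask from per-sequence lengths and converts a pre-v1.7 mask. With a PyTorch boolean tensor the same conversion is `~old_mask`; the tensor shape TE's attention modules expect is not covered here.

```python
def padding_mask_v17(seqlens, max_len):
    """v1.7+ convention: True = masked out (padding), False = attended."""
    return [[pos >= n for pos in range(max_len)] for n in seqlens]


def legacy_to_v17(mask):
    """Convert a pre-v1.7 padding mask (True = keep) by logical negation."""
    return [[not m for m in row] for row in mask]


# The batch from the example above: valid lengths 3, 2 and 4, padded to 5.
new_mask = padding_mask_v17([3, 2, 4], max_len=5)
old_mask = [[True, True, True, False, False],
            [True, True, False, False, False],
            [True, True, True, True, False]]
assert legacy_to_v17(old_mask) == new_mask
```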
FP8 Convergence
FP8 has been tested extensively across different model architectures and configurations, and we found no significant difference between FP8 and BF16 training loss curves. FP8 has also been validated for accuracy on downstream LLM tasks (e.g. LAMBADA and WikiText). Below are examples of models tested for convergence across different frameworks.
Model | Framework | Source |
---|---|---|
T5-770M | JAX/T5x | |
MPT-1.3B | Mosaic Composer | |
GPT-5B | JAX/Paxml | https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/pax#h100-results |
GPT-5B | NeMo Framework | Available on request |
LLama2-7B | Alibaba Pai | |
T5-11B | JAX/T5x | Available on request |
MPT-13B | Mosaic Composer | https://www.databricks.com/blog/turbocharged-training-optimizing-databricks-mosaic-ai-stack-fp8 |
GPT-22B | NeMo Framework | Available on request |
LLama2-70B | Alibaba Pai | |
GPT-175B | JAX/Paxml | https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/pax#h100-results |
Integrations
Transformer Engine has been integrated with popular LLM frameworks such as:
Hugging Face Nanotron - Coming soon!
Colossal-AI - Coming soon!
PeriFlow - Coming soon!
GPT-NeoX - Coming soon!
Contributing
We welcome contributions to Transformer Engine! To contribute to Transformer Engine and make pull requests, follow the guidelines outlined in the CONTRIBUTING.rst guide.