ml_dtypes
ml_dtypes is a stand-alone implementation of several NumPy dtype extensions used in machine learning libraries, including:
- bfloat16: an alternative to the standard float16 format
- float8_*: several experimental 8-bit floating point representations, including:
  - float8_e4m3b11fnuz
  - float8_e4m3fn
  - float8_e4m3fnuz
  - float8_e5m2
  - float8_e5m2fnuz
- int4 and uint4: low precision integer types.
See below for specifications of these number formats.
Installation
The ml_dtypes package is tested with Python versions 3.9-3.12 and can be installed with the following command:
pip install ml_dtypes
To test your installation, you can run the following:
pip install absl-py pytest
pytest --pyargs ml_dtypes
To build from source, clone the repository and run:
git submodule init
git submodule update
pip install .
Example Usage
>>> from ml_dtypes import bfloat16
>>> import numpy as np
>>> np.zeros(4, dtype=bfloat16)
array([0, 0, 0, 0], dtype=bfloat16)
Importing ml_dtypes also registers the data types with NumPy, so that they may be referred to by their string name:
>>> np.dtype('bfloat16')
dtype(bfloat16)
>>> np.dtype('float8_e5m2')
dtype(float8_e5m2)
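As a quick check of this registration, the following sketch (assuming NumPy and ml_dtypes are installed) compares string-named dtypes with the ml_dtypes scalar types and then uses the names in ordinary NumPy calls. The variable names are illustrative only.

```python
import numpy as np
import ml_dtypes  # importing the package registers its dtypes with NumPy

# String names resolve to the same dtypes as the ml_dtypes scalar types.
assert np.dtype("bfloat16") == np.dtype(ml_dtypes.bfloat16)
assert np.dtype("float8_e5m2") == np.dtype(ml_dtypes.float8_e5m2)

# The names can be used wherever NumPy accepts a dtype argument.
x = np.array([1.0, 2.5, -3.0], dtype=np.float32)
y = x.astype("bfloat16")  # 1.0, 2.5 and -3.0 are exactly representable in bfloat16
z = np.zeros(3, dtype="float8_e4m3fn")
```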
Specifications of implemented floating point formats
bfloat16
A bfloat16 number is a single-precision float truncated at 16 bits: it keeps the sign bit, the full 8-bit exponent and the upper 7 of the 23 mantissa bits.
Exponent: 8, Mantissa: 7, exponent bias: 127. IEEE 754, with NaN and inf.
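Because bfloat16 is the upper half of a float32, its bit layout can be illustrated with nothing but Python's struct module. This is a minimal sketch: the two helper names are hypothetical, not ml_dtypes functions, and the encoder truncates only to mirror the description above (library conversions may round to nearest instead).

```python
import struct

def bf16_bits_to_float(bits: int) -> float:
    """Decode a 16-bit bfloat16 pattern by using it as the upper half of a float32."""
    return struct.unpack("<f", struct.pack("<I", (bits & 0xFFFF) << 16))[0]

def float_to_bf16_bits_truncate(x: float) -> int:
    """Keep the upper 16 bits of x's float32 pattern (sign, 8-bit exponent, top 7 mantissa bits).

    This truncates; it does not round to nearest.
    """
    (f32_bits,) = struct.unpack("<I", struct.pack("<f", x))
    return f32_bits >> 16

bits = float_to_bf16_bits_truncate(3.14159265)
sign, exponent, mantissa = bits >> 15, (bits >> 7) & 0xFF, bits & 0x7F
print(hex(bits), sign, exponent, mantissa)  # 0x4049 0 128 73
print(bf16_bits_to_float(bits))             # 3.140625
```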
float8_e4m3b11fnuz
Exponent: 4, Mantissa: 3, bias: 11.
Extended range: no inf, NaN represented by 0b1000'0000.
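Reading the fields as in the other float8 formats (an implicit leading 1 when the exponent field E is nonzero, subnormals when it is 0; an assumption, since this entry does not state it), the representable range works out as below. With no infinity, the all-ones exponent encodes ordinary finite values, and the only non-finite code is the NaN 0b1000'0000 (S = 1, E = 0, M = 0).

```latex
\begin{aligned}
x &= (-1)^{S}\, 2^{E-11}\left(1 + \tfrac{M}{8}\right), && 1 \le E \le 15,\\
x &= (-1)^{S}\, 2^{-10}\,\tfrac{M}{8}, && E = 0,\\
x_{\max} &= 2^{15-11}\left(1 + \tfrac{7}{8}\right) = 16 \cdot \tfrac{15}{8} = 30,\\
x_{\min,\text{normal}} &= 2^{1-11} = 2^{-10},\\
x_{\min,\text{subnormal}} &= 2^{-10} \cdot \tfrac{1}{8} = 2^{-13}.
\end{aligned}
```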
float8_e4m3fn
Exponent: 4, Mantissa: 3, bias: 7.
Extended range: no inf, NaN represented by 0bS111'1111.
The fn suffix is for consistency with the corresponding LLVM/MLIR type, signaling that this type is not consistent with IEEE-754. The f indicates that it has finite values only. The n indicates that it includes NaNs, but only at the outer range.
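To make the fn conventions concrete, here is a minimal stdlib-only sketch that decodes one float8_e4m3fn byte into a Python float. It follows the layout stated above (bias 7, no infinity, NaN only at 0bS111'1111) and additionally assumes subnormals when the exponent field is 0, which this section does not spell out. The name decode_e4m3fn is hypothetical, not part of the ml_dtypes API.

```python
import math

def decode_e4m3fn(code: int) -> float:
    """Decode one float8_e4m3fn byte (1 sign, 4 exponent, 3 mantissa bits, bias 7)."""
    sign = -1.0 if code & 0x80 else 1.0
    exponent = (code >> 3) & 0xF
    mantissa = code & 0x7
    # "n": only S.1111.111 is NaN; "f": there is no infinity encoding.
    if exponent == 0xF and mantissa == 0x7:
        return math.nan
    if exponent == 0:
        # Subnormal (assumed): no implicit leading 1, exponent fixed at 1 - bias.
        return sign * math.ldexp(mantissa / 8, 1 - 7)
    return sign * math.ldexp(1 + mantissa / 8, exponent - 7)

print(decode_e4m3fn(0x38))              # 1.0
print(decode_e4m3fn(0x7E))              # 448.0, the largest finite value
print(decode_e4m3fn(0x80))              # -0.0 (signed zero exists in this format)
print(math.isnan(decode_e4m3fn(0xFF)))  # True
```

Since S.1111.000 through S.1111.110 remain finite, this format reaches 448, whereas an IEEE-style layout with the same widths would reserve the whole top exponent for inf and NaN.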
float8_e4m3fnuz
8-bit floating point with a 3-bit mantissa.
An 8-bit floating point type with 1 sign bit, 4 exponent bits and 3 mantissa bits. The suffix fnuz is consistent with LLVM/MLIR naming and is derived from the differences to IEEE floating point conventions: F is for "finite" (no infinities), N for special NaN encoding, and UZ for unsigned zero.
This type has the following characteristics:
- bit encoding: S1E4M3 - 0bSEEEEMMM
- exponent bias: 8
- infinities: Not supported
- NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s - 0b10000000
- denormals when exponent is 0
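The characteristics listed above translate directly into a decoder. The following is a minimal stdlib-only sketch, not ml_dtypes code; decode_fnuz and decode_e4m3fnuz are hypothetical helper names, and the generic version is parameterized by field widths and bias so the fnuz rules are stated once.

```python
import math

def decode_fnuz(code: int, ebits: int, mbits: int, bias: int) -> float:
    """Decode an 8-bit 'fnuz' float: no infinities, 0x80 is the only NaN, no -0."""
    if code == 0x80:
        return math.nan  # sign bit set, exponent and mantissa bits all zero
    sign = -1.0 if code & 0x80 else 1.0
    exponent = (code >> mbits) & ((1 << ebits) - 1)
    mantissa = code & ((1 << mbits) - 1)
    scale = 1 << mbits
    if exponent == 0:  # denormal: no implicit leading 1
        return sign * math.ldexp(mantissa / scale, 1 - bias)
    return sign * math.ldexp(1 + mantissa / scale, exponent - bias)

def decode_e4m3fnuz(code: int) -> float:
    return decode_fnuz(code, ebits=4, mbits=3, bias=8)

print(decode_e4m3fnuz(0x40))              # 1.0
print(decode_e4m3fnuz(0x7F))              # 240.0, the largest finite value
print(decode_e4m3fnuz(0x01))              # 0.0009765625 (2**-10), the smallest denormal
print(math.isnan(decode_e4m3fnuz(0x80)))  # True
```

Because 0x80 is spent on NaN, there is no negative zero, and the remaining 255 codes all decode to distinct finite values. Called with ebits=4, mbits=3, bias=11 or with ebits=5, mbits=2, bias=16, the same function follows the float8_e4m3b11fnuz and float8_e5m2fnuz parameters given in this document, assuming those formats also use denormals at exponent 0.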
float8_e5m2
Exponent: 5, Mantissa: 2, bias: 15. IEEE 754, with NaN and inf.
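Because float8_e5m2 has the same sign bit, 5-bit exponent and bias of 15 as IEEE 754 half precision (binary16), an e5m2 byte is exactly the upper byte of a float16 pattern. A minimal stdlib-only sketch that decodes it through the struct module's half-precision "e" format; decode_e5m2 is a hypothetical helper name:

```python
import math
import struct

def decode_e5m2(code: int) -> float:
    """Decode a float8_e5m2 byte by widening it to an IEEE binary16 pattern.

    The 2 mantissa bits become the top 2 of float16's 10 mantissa bits;
    the remaining 8 bits are zero.
    """
    half_bits = (code & 0xFF) << 8
    return struct.unpack("<e", struct.pack("<H", half_bits))[0]

print(decode_e5m2(0x3C))              # 1.0
print(decode_e5m2(0x7B))              # 57344.0, the largest finite value
print(decode_e5m2(0x7C))              # inf
print(math.isnan(decode_e5m2(0x7D)))  # True: exponent all ones, nonzero mantissa
```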
float8_e5m2fnuz
8-bit floating point with a 2-bit mantissa.
An 8-bit floating point type with 1 sign bit, 5 exponent bits and 2 mantissa bits. The suffix fnuz is consistent with LLVM/MLIR naming and is derived from the differences to IEEE floating point conventions: F is for "finite" (no infinities), N for special NaN encoding, and UZ for unsigned zero.
This type has the following characteristics:
- bit encoding: S1E5M2 - 0bSEEEEEMM
- exponent bias: 16
- infinities: Not supported
- NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s - 0b10000000
- denormals when exponent is 0
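From the characteristics above (bias 16, no infinities, denormals when the exponent field E is 0), the representable range can be worked out directly. The code with S = 1, E = 0, M = 0 is the NaN; every other code is finite, including those with the all-ones exponent.

```latex
\begin{aligned}
x &= (-1)^{S}\, 2^{E-16}\left(1 + \tfrac{M}{4}\right), && 1 \le E \le 31,\\
x &= (-1)^{S}\, 2^{-15}\,\tfrac{M}{4}, && E = 0,\\
x_{\max} &= 2^{31-16}\left(1 + \tfrac{3}{4}\right) = 32768 \cdot \tfrac{7}{4} = 57344,\\
x_{\min,\text{normal}} &= 2^{1-16} = 2^{-15},\\
x_{\min,\text{subnormal}} &= 2^{-15} \cdot \tfrac{1}{4} = 2^{-17}.
\end{aligned}
```

For comparison, float8_e5m2 (bias 15, top exponent reserved for inf and NaN) reaches the same maximum, 2^(30-15) · 7/4 = 57344, while its smallest subnormal is 2^(-14) · 1/4 = 2^(-16); the fnuz variant extends the range one binade further toward zero.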
int4 and uint4
4-bit integer types, where each element is represented unpacked (i.e., padded up to a byte in memory).
NumPy does not support types smaller than a single byte. For example, the distance between adjacent elements in an array (.strides) is expressed in bytes. Relaxing this restriction would be a considerable engineering project.
The int4 and uint4 types therefore use an unpacked representation, where each element of the array is padded up to a byte in memory. The lower four bits of each byte contain the representation of the number, whereas the upper four bits are ignored.
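A minimal stdlib-only sketch of how one unpacked element could be read back from its byte, following the description above: only the low nibble matters. For int4 it assumes the nibble holds a 4-bit two's-complement value (range -8 to 7), which this section does not state explicitly. Both function names are hypothetical helpers, not ml_dtypes functions.

```python
def decode_uint4(byte: int) -> int:
    """Value of one unpacked uint4 element: the low nibble; the upper 4 bits are ignored."""
    return byte & 0x0F

def decode_int4(byte: int) -> int:
    """Value of one unpacked int4 element, assuming 4-bit two's complement (-8..7)."""
    nibble = byte & 0x0F
    return nibble - 16 if nibble & 0x08 else nibble

# One byte per element; the high nibbles here are deliberately arbitrary.
raw = bytes([0x03, 0xFE, 0xA8, 0x57])
print([decode_uint4(b) for b in raw])  # [3, 14, 8, 7]
print([decode_int4(b) for b in raw])   # [3, -2, -8, 7]
```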
Quirks of low-precision arithmetic
If you're exploring the use of low-precision dtypes in your code, you should be careful to anticipate when the precision loss might lead to surprising results. One example is the behavior of aggregations like sum; consider this bfloat16 summation in NumPy (run with version 1.24.2):
>>> from ml_dtypes import bfloat16
>>> import numpy as np
>>> rng = np.random.default_rng(seed=0)
>>> vals = rng.uniform(size=10000).astype(bfloat16)
>>> vals.sum()
256
The true sum should be close to 5000, but NumPy returns exactly 256: this is because bfloat16 does not have the precision to increment 256 by values of 1 or less:
>>> bfloat16(256) + bfloat16(1)
256
After 256, the next representable value in bfloat16 is 258:
>>> np.nextafter(bfloat16(256), bfloat16(np.inf))
258
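The gap of 2 follows from the 8 significant bits of bfloat16 (1 implicit plus 7 stored): within the binade [2^k, 2^(k+1)), consecutive values are 2^(k-7) apart, so at 256 = 2^8 the spacing is 2. A minimal stdlib-only sketch; the helper name bf16_spacing is hypothetical and only covers positive normal values.

```python
import math

def bf16_spacing(x: float) -> float:
    """Gap between consecutive bfloat16 values in the binade containing x (x > 0, normal)."""
    _, e = math.frexp(x)  # x == m * 2**e with 0.5 <= m < 1, so x lies in [2**(e-1), 2**e)
    return math.ldexp(1.0, (e - 1) - 7)

print(bf16_spacing(1.0))    # 0.0078125 (2**-7)
print(bf16_spacing(255.0))  # 1.0
print(bf16_spacing(256.0))  # 2.0
```

Adding 1 to 256 therefore lands on 257, exactly halfway between 256 and 258. Assuming round-to-nearest-even (consistent with the output above), the tie resolves to 256, whose stored mantissa bits are all zero.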
For better results you can specify that the accumulation should happen in a higher-precision type like float32:
>>> vals.sum(dtype='float32').astype(bfloat16)
4992
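The difference between the two strategies can be reproduced without NumPy. The sketch below is a stdlib-only emulation, not ml_dtypes itself: round_to_bf16 is a hypothetical helper that rounds a Python float to 8 significant bits with ties to even (ignoring subnormals and overflow, which this example never reaches), and Python's random module stands in for NumPy's generator, so the wide result will not match the 4992 above exactly.

```python
import math
import random

def round_to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value, ties to even (normal range only)."""
    if x == 0.0 or not math.isfinite(x):
        return x
    m, e = math.frexp(x)  # x == m * 2**e with 0.5 <= |m| < 1
    # Keep 8 significant bits; Python's round() breaks ties to even.
    return math.ldexp(round(m * 2**8), e - 8)

rng = random.Random(0)
vals = [round_to_bf16(rng.random()) for _ in range(10_000)]

# Naive: round every partial sum back to bfloat16, as a bfloat16 accumulator would.
naive = 0.0
for v in vals:
    naive = round_to_bf16(naive + v)

# Wide: accumulate in double precision and round only the final result.
wide = round_to_bf16(math.fsum(vals))

print(naive)  # 256.0: once the sum reaches 256, adding a value <= 1 rounds back down
print(wide)   # close to 5000, a multiple of 32 (the bfloat16 spacing in [4096, 8192))
```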
In contrast to NumPy, projects like JAX which support low-precision arithmetic more natively will often do these kinds of higher-precision accumulations automatically:
>>> import jax.numpy as jnp
>>> jnp.array(vals).sum()
Array(4992, dtype=bfloat16)
License
This is not an officially supported Google product.
The ml_dtypes source code is licensed under the Apache 2.0 license (see LICENSE). Pre-compiled wheels are built with the EIGEN project, which is released under the MPL 2.0 license (see LICENSE.eigen).