ml_dtypes
ml_dtypes is a stand-alone implementation of several NumPy dtype extensions used in machine learning libraries, including:
- bfloat16: an alternative to the standard float16 format
- float8_*: several experimental 8-bit floating point representations, including:
  - float8_e4m3b11fnuz
  - float8_e4m3fn
  - float8_e4m3fnuz
  - float8_e5m2
  - float8_e5m2fnuz
- int4 and uint4: low precision integer types.
See below for specifications of these number formats.
Installation
The ml_dtypes package is tested with Python versions 3.9-3.12, and can be installed with the following command:
pip install ml_dtypes
To test your installation, you can run the following:
pip install absl-py pytest
pytest --pyargs ml_dtypes
To build from source, clone the repository and run:
git submodule init
git submodule update
pip install .
Example Usage
>>> from ml_dtypes import bfloat16
>>> import numpy as np
>>> np.zeros(4, dtype=bfloat16)
array([0, 0, 0, 0], dtype=bfloat16)
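A slightly larger sketch of the same idea, for illustration only: elementwise arithmetic on a bfloat16 array stays in bfloat16, and `astype` widens it when a later step needs more precision. The values are arbitrary small integers, chosen because they are exactly representable in bfloat16.

```python
import numpy as np
from ml_dtypes import bfloat16

# Small integers are exactly representable in bfloat16.
a = np.array([1.0, 2.0, 3.0], dtype=bfloat16)

# Elementwise arithmetic keeps the bfloat16 dtype.
b = a * bfloat16(2)
print(b.dtype)

# Widen to float32 before a precision-sensitive computation.
c = b.astype(np.float32)
print(c.tolist())
```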
Importing ml_dtypes also registers the data types with NumPy, so that they may be referred to by their string name:
>>> np.dtype('bfloat16')
dtype(bfloat16)
>>> np.dtype('float8_e5m2')
dtype(float8_e5m2)
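Because the names are registered, they can also be passed anywhere NumPy accepts a dtype string, for example to `astype`. A minimal sketch with arbitrary, exactly representable values; the `itemsize` check just confirms the storage width of each format.

```python
import numpy as np
import ml_dtypes  # noqa: F401  (the import registers the dtype names with NumPy)

x = np.array([0.5, 1.0, 2.0], dtype=np.float32)

# Cast using the registered string names instead of the Python types.
as_bf16 = x.astype('bfloat16')
as_e5m2 = x.astype('float8_e5m2')

# bfloat16 uses 2 bytes per element, float8_e5m2 uses 1.
print(as_bf16.itemsize, as_e5m2.itemsize)
```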
Specifications of implemented floating point formats
bfloat16
A bfloat16 number is a single-precision float truncated at 16 bits.
Exponent: 8, Mantissa: 7, exponent bias: 127. IEEE 754, with NaN and inf.
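The "truncated float32" description can be checked at the bit level: for values that fit in 7 explicit mantissa bits, the bfloat16 encoding equals the upper 16 bits of the float32 encoding. This sketch deliberately picks exactly representable values, because for other values the cast rounds (we assume round-to-nearest-even) rather than simply dropping the low bits, so the bit patterns could then differ.

```python
import numpy as np
from ml_dtypes import bfloat16

# Exactly representable in bfloat16 (at most 7 explicit mantissa bits),
# so the cast below involves no rounding.
x32 = np.array([1.0, -2.5, 3.140625], dtype=np.float32)

# Upper 16 bits of each float32: sign, 8 exponent bits, top 7 mantissa bits.
top_half = (x32.view(np.uint32) >> 16).astype(np.uint16)

# Raw bit pattern of the same values stored as bfloat16.
bf16_bits = x32.astype(bfloat16).view(np.uint16)

print(np.array_equal(top_half, bf16_bits))
```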
float8_e4m3b11fnuz
Exponent: 4, Mantissa: 3, bias: 11.
Extended range: no inf, NaN represented by 0b1000'0000.
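One way to make this spec concrete is a small reference decoder compared against the library over all 256 byte codes. This is a sketch, not ml_dtypes' implementation: `decode_fnuz` is a hypothetical helper written for illustration, and it assumes the usual IEEE-style layout in which an exponent field of 0 encodes subnormals scaled by 2**(1 - bias). Viewing raw bytes as `float8_e4m3b11fnuz` lets NumPy decode each code for comparison.

```python
import math

import numpy as np
import ml_dtypes


def decode_fnuz(byte, exp_bits=4, man_bits=3, bias=11):
    """Hypothetical reference decoder for one 8-bit 'fnuz' float code."""
    if byte == 0x80:  # the single NaN encoding; there is no negative zero
        return math.nan
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> man_bits) & ((1 << exp_bits) - 1)
    frac = (byte & ((1 << man_bits) - 1)) / (1 << man_bits)
    if exp == 0:  # subnormal: no implicit leading 1
        return sign * frac * 2.0 ** (1 - bias)
    # No exponent value is reserved for infinity, so all-ones is a normal number.
    return sign * (1.0 + frac) * 2.0 ** (exp - bias)


codes = np.arange(256, dtype=np.uint8)
ours = np.array([decode_fnuz(int(c)) for c in codes], dtype=np.float32)
theirs = codes.view(ml_dtypes.float8_e4m3b11fnuz).astype(np.float32)

print(np.array_equal(ours, theirs, equal_nan=True))
print(np.nanmax(ours))  # largest finite value: 2**(15 - 11) * 1.875
```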
float8_e4m3fn
Exponent: 4, Mantissa: 3, bias: 7.
Extended range: no inf, NaN represented by 0bS111'1111.
The fn suffix is for consistency with the corresponding LLVM/MLIR type, signaling that this type does not follow IEEE-754. The f indicates that it has finite values only. The n indicates that it includes NaNs, but only at the outer range.
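The "outer range" NaN can be seen by reinterpreting raw bytes as `float8_e4m3fn`. In this sketch, only the all-ones codes 0x7F and 0xFF decode to NaN, so an all-ones exponent with a smaller mantissa is still an ordinary finite number (there is no infinity), and 0x7E is the largest finite value, 2**(15 - 7) * 1.75 = 448.

```python
import numpy as np
import ml_dtypes

f8 = ml_dtypes.float8_e4m3fn


def decode(*codes):
    """Reinterpret raw bytes as float8_e4m3fn and widen to float32."""
    return np.array(codes, dtype=np.uint8).view(f8).astype(np.float32)


nan_codes = decode(0x7F, 0xFF)  # 0bS111'1111: the only NaN encodings
max_code = decode(0x7E)         # 0b0111'1110: largest finite value
big = decode(0x78)              # 0b0111'1000: all-ones exponent, still finite

print(np.isnan(nan_codes).all(), max_code[0], big[0])
```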
float8_e4m3fnuz
8-bit floating point with 3-bit mantissa.
An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits mantissa. The suffix fnuz is consistent with LLVM/MLIR naming and is derived from the differences to IEEE floating point conventions. F is for "finite" (no infinities), N for a special NaN encoding, and UZ for unsigned zero.
This type has the following characteristics:
- bit encoding: S1E4M3 (0bSEEEEMMM)
- exponent bias: 8
- infinities: Not supported
- NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s (0b10000000)
- denormals when exponent is 0
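The same byte can mean different things in float8_e4m3fn and float8_e4m3fnuz, which is easy to overlook when moving data between them. This sketch decodes two codes under both types. It assumes float8_e4m3fn keeps an ordinary signed zero at 0x80, which its spec implies because its NaN codes are 0bS111'1111.

```python
import numpy as np
import ml_dtypes

codes = np.array([0x38, 0x80], dtype=np.uint8)

fn = codes.view(ml_dtypes.float8_e4m3fn).astype(np.float32)
fnuz = codes.view(ml_dtypes.float8_e4m3fnuz).astype(np.float32)

# 0x38 = 0b0011'1000: exponent field 7, mantissa 0.
#   bias 7 (e4m3fn)   -> 2**0  = 1.0
#   bias 8 (e4m3fnuz) -> 2**-1 = 0.5
# 0x80 = 0b1000'0000: negative zero in e4m3fn, the NaN code in e4m3fnuz.
print(fn.tolist(), fnuz.tolist())
```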
float8_e5m2
Exponent: 5, Mantissa: 2, bias: 15. IEEE 754, with NaN and inf.
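float8_e5m2 has the same sign and exponent layout as IEEE float16 (5 exponent bits, bias 15), so for values that fit in 2 mantissa bits its byte should equal the upper byte of the float16 encoding, including the infinity code 0x7C. This sketch checks that with a few exactly representable values; the cast goes through float32, which is exact for these inputs.

```python
import numpy as np
from ml_dtypes import float8_e5m2

# Exactly representable in float8_e5m2 (at most 2 explicit mantissa bits), plus inf.
x16 = np.array([1.0, -2.0, 0.75, 57344.0, np.inf], dtype=np.float16)

# Upper byte of each float16: sign, 5 exponent bits, top 2 mantissa bits.
upper = (x16.view(np.uint16) >> 8).astype(np.uint8)

# The same values cast to float8_e5m2, reinterpreted as raw bytes.
e5m2_bits = x16.astype(np.float32).astype(float8_e5m2).view(np.uint8)

print(np.array_equal(upper, e5m2_bits))
```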
float8_e5m2fnuz
8-bit floating point with 2-bit mantissa.
An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits mantissa. The suffix fnuz is consistent with LLVM/MLIR naming and is derived from the differences to IEEE floating point conventions. F is for "finite" (no infinities), N for a special NaN encoding, and UZ for unsigned zero.
This type has the following characteristics:
- bit encoding: S1E5M2 (0bSEEEEEMM)
- exponent bias: 16
- infinities: Not supported
- NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s (0b10000000)
- denormals when exponent is 0
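To compare the float8 formats side by side, ml_dtypes provides a NumPy-style `finfo` for its types. This sketch prints the largest finite value and machine epsilon of each; the expected maxima follow from the specifications above (for example 2**(15 - 8) * 1.875 = 240 for float8_e4m3fnuz), and the loop and dictionary are illustration only.

```python
import numpy as np
import ml_dtypes

formats = [
    ml_dtypes.float8_e4m3b11fnuz,
    ml_dtypes.float8_e4m3fn,
    ml_dtypes.float8_e4m3fnuz,
    ml_dtypes.float8_e5m2,
    ml_dtypes.float8_e5m2fnuz,
]

largest = {}
for t in formats:
    info = ml_dtypes.finfo(t)  # finfo-style limits for the custom dtype
    largest[t] = float(info.max)
    print(f"{np.dtype(t).name:20s} max={float(info.max):g} eps={float(info.eps):g}")
```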
int4 and uint4
4-bit integer types, where each element is represented unpacked (i.e., padded up to a byte in memory).
NumPy does not support types smaller than a single byte. For example, the distance between adjacent elements in an array (.strides) is expressed in bytes. Relaxing this restriction would be a considerable engineering project.
The int4 and uint4 types therefore use an unpacked representation, where each element of the array is padded up to a byte in memory. The lower four bits of each byte contain the representation of the number, whereas the upper four bits are ignored.
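The unpacked layout can be observed directly. This sketch assumes casts between int4/uint4 and NumPy's 8-bit integer types are available. It shows that 16 int4 values occupy 16 bytes, that values round-trip through int8/uint8, and (relying on the description above) that the low nibble of each byte holds the 4-bit two's-complement code. It only masks the low nibble, because the upper four bits are ignored and should not be relied on.

```python
import numpy as np
from ml_dtypes import int4, uint4

signed = np.arange(-8, 8, dtype=np.int8).astype(int4)    # full int4 range
unsigned = np.arange(16, dtype=np.uint8).astype(uint4)   # full uint4 range

# One element per byte: 16 elements occupy 16 bytes, not 8.
print(signed.itemsize, signed.nbytes)

# Round-tripping through a wider integer type recovers the values.
print(signed.astype(np.int8).tolist())
print(unsigned.astype(np.uint8).tolist())

# Low nibble of each byte: the 4-bit two's-complement code of the value.
low_nibbles = signed.view(np.uint8) & 0x0F
print(low_nibbles.tolist())
```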
Quirks of low-precision arithmetic
If you're exploring the use of low-precision dtypes in your code, you should be careful to anticipate when the precision loss might lead to surprising results. One example is the behavior of aggregations like sum; consider this bfloat16 summation in NumPy (run with NumPy 1.24.2):
>>> from ml_dtypes import bfloat16
>>> import numpy as np
>>> rng = np.random.default_rng(seed=0)
>>> vals = rng.uniform(size=10000).astype(bfloat16)
>>> vals.sum()
256
The true sum should be close to 5000, but NumPy returns exactly 256: this is because bfloat16 does not have the precision to increment 256 by values of 1 or less:
>>> bfloat16(256) + bfloat16(1)
256
After 256, the next representable value in bfloat16 is 258:
>>> np.nextafter(bfloat16(256), bfloat16(np.inf))
258
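The same stall can be reproduced deterministically without random data. In this sketch a running bfloat16 total of 0.5 increments stops growing at 128: from 128 onward the spacing between adjacent bfloat16 values is 1, so 128 + 0.5 lands exactly halfway and rounds back to 128 (we assume round-half-to-even, consistent with 256 + 1 giving 256 above). A float32 accumulator reaches the exact total of 500.

```python
import numpy as np
from ml_dtypes import bfloat16

step = 0.5
n = 1000  # exact total: 500.0

acc_bf16 = bfloat16(0)
acc_f32 = np.float32(0)
for _ in range(n):
    acc_bf16 = acc_bf16 + bfloat16(step)  # each result is rounded back to bfloat16
    acc_f32 = acc_f32 + np.float32(step)  # float32 has ample precision at this scale

print(float(acc_bf16), float(acc_f32))
```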
For better results you can specify that the accumulation should happen in a higher-precision type like float32:
>>> vals.sum(dtype='float32').astype(bfloat16)
4992
In contrast to NumPy, projects like JAX, which support low-precision arithmetic more natively, will often perform these kinds of higher-precision accumulations automatically:
>>> import jax.numpy as jnp
>>> jnp.array(vals).sum()
Array(4992, dtype=bfloat16)
License
This is not an officially supported Google product.
The ml_dtypes source code is licensed under the Apache 2.0 license (see LICENSE). Pre-compiled wheels are built with the EIGEN project, which is released under the MPL 2.0 license (see LICENSE.eigen).