ml_dtypes
ml_dtypes is a stand-alone implementation of several NumPy dtype extensions used in machine learning libraries, including:
- bfloat16: an alternative to the standard float16 format
- float8_*: several experimental 8-bit floating point representations including:
  - float8_e4m3b11fnuz
  - float8_e4m3fn
  - float8_e4m3fnuz
  - float8_e5m2
  - float8_e5m2fnuz
- int4 and uint4: low precision integer types.
See below for specifications of these number formats.
Installation
The ml_dtypes package is tested with Python versions 3.9-3.12, and can be installed with the following command:
pip install ml_dtypes
To test your installation, you can run the following:
pip install absl-py pytest
pytest --pyargs ml_dtypes
To build from source, clone the repository and run:
git submodule init
git submodule update
pip install .
Example Usage
>>> from ml_dtypes import bfloat16
>>> import numpy as np
>>> np.zeros(4, dtype=bfloat16)
array([0, 0, 0, 0], dtype=bfloat16)
Importing ml_dtypes also registers the data types with NumPy, so that they may be referred to by their string name:
>>> np.dtype('bfloat16')
dtype(bfloat16)
>>> np.dtype('float8_e5m2')
dtype(float8_e5m2)
Specifications of implemented floating point formats
bfloat16
A bfloat16 number is a single-precision float truncated at 16 bits.
Exponent: 8, Mantissa: 7, exponent bias: 127. IEEE 754, with NaN and inf.
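To make the relationship with float32 concrete, here is a minimal pure-Python sketch (standard library only, not the ml_dtypes API) that maps a Python float to a bfloat16 bit pattern and back. The helper names are hypothetical, and the round-to-nearest-even narrowing step is an assumption; plain truncation would simply drop the low 16 bits.

```python
import math
import struct


def float32_bits(x: float) -> int:
    """Return the 32-bit IEEE 754 pattern of x (x is first rounded to float32)."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return bits


def to_bfloat16_bits(x: float) -> int:
    """Hypothetical helper: narrow x to a 16-bit bfloat16 pattern, ties to even."""
    if math.isnan(x):
        return 0x7FC0  # a quiet NaN pattern
    bits = float32_bits(x)
    # Assumed rounding mode: add 0x7FFF plus the lowest bit that will be kept,
    # so values exactly halfway between two bfloat16 values round to the even one.
    bits += 0x7FFF + ((bits >> 16) & 1)
    # Keep the sign, the 8 exponent bits and the top 7 mantissa bits.
    return bits >> 16


def from_bfloat16_bits(b: int) -> float:
    """Widen a bfloat16 pattern to float32 by appending 16 zero mantissa bits."""
    (value,) = struct.unpack("<f", struct.pack("<I", b << 16))
    return value
```

With these helpers, 257.0 narrows to the pattern 0x4380, which widens back to 256.0; this is consistent with the 256 + 1 behavior described in the low-precision arithmetic section.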
float8_e4m3b11fnuz
Exponent: 4, Mantissa: 3, bias: 11.
Extended range: no inf, NaN represented by 0b1000'0000.
float8_e4m3fn
Exponent: 4, Mantissa: 3, bias: 7.
Extended range: no inf, NaN represented by 0bS111'1111.
The fn suffix is for consistency with the corresponding LLVM/MLIR type, signaling that this type is not consistent with IEEE 754. The f indicates that it has finite values only. The n indicates that it includes NaNs, but only at the outer range.
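The following sketch decodes a single float8_e4m3fn byte into a Python float using the parameters above (bias 7, NaN only at 0bS111'1111, no infinities). It is a hypothetical stand-alone helper for illustration, not part of ml_dtypes, and it assumes that exponent 0 encodes subnormals, as in IEEE-style formats.

```python
import math


def decode_float8_e4m3fn(byte: int) -> float:
    """Hypothetical helper: decode 1 sign, 4 exponent, 3 mantissa bits (bias 7)."""
    if not 0 <= byte <= 0xFF:
        raise ValueError("expected a value in [0, 255]")
    sign = -1.0 if byte & 0x80 else 1.0
    exponent = (byte >> 3) & 0xF
    mantissa = byte & 0x7
    # Only S.1111.111 is NaN; there is no infinity encoding.
    if exponent == 0xF and mantissa == 0x7:
        return math.nan
    if exponent == 0:
        # Assumed subnormal: no implicit leading 1, fixed exponent 1 - bias.
        return sign * (mantissa / 8) * 2.0 ** (1 - 7)
    # Normal value, including the all-ones exponent with mantissa != 0b111.
    return sign * (1 + mantissa / 8) * 2.0 ** (exponent - 7)
```

Because the all-ones exponent is reused for finite values, 0x78 decodes to 256.0 rather than infinity, and the largest finite value is 448.0 (0x7E). This reuse is the "extended range" mentioned above.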
float8_e4m3fnuz
8-bit floating point with a 3-bit mantissa.
An 8-bit floating point type with 1 sign bit, 4 exponent bits and 3 mantissa bits. The suffix fnuz is consistent with LLVM/MLIR naming and is derived from the differences from IEEE floating point conventions: F is for "finite" (no infinities), N for a special NaN encoding, and UZ for unsigned zero.
This type has the following characteristics:
- bit encoding: S1E4M3 - 0bSEEEEMMM
- exponent bias: 8
- infinities: Not supported
- NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s - 0b10000000
- denormals when exponent is 0
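The fnuz formats differ only in field widths and bias, so one hypothetical helper can decode all of them: exponent_bits=4, mantissa_bits=3, bias=8 for float8_e4m3fnuz; the same widths with bias=11 for float8_e4m3b11fnuz; and exponent_bits=5, mantissa_bits=2, bias=16 for float8_e5m2fnuz. This is a minimal sketch of the encoding described in the lists, not the library's implementation.

```python
import math


def decode_fnuz(byte: int, exponent_bits: int, mantissa_bits: int, bias: int) -> float:
    """Hypothetical helper: decode an 8-bit fnuz float (no inf, single NaN, unsigned zero)."""
    if 1 + exponent_bits + mantissa_bits != 8:
        raise ValueError("sign + exponent + mantissa must total 8 bits")
    if not 0 <= byte <= 0xFF:
        raise ValueError("expected a value in [0, 255]")
    # The only NaN: sign bit set, all exponent and mantissa bits zero.
    # Reusing this pattern is also why there is no negative zero.
    if byte == 0x80:
        return math.nan
    sign = -1.0 if byte & 0x80 else 1.0
    exponent = (byte >> mantissa_bits) & ((1 << exponent_bits) - 1)
    fraction = (byte & ((1 << mantissa_bits) - 1)) / (1 << mantissa_bits)
    if exponent == 0:
        # Denormal: no implicit leading 1, fixed exponent of 1 - bias.
        return sign * fraction * 2.0 ** (1 - bias)
    # No infinities: the all-ones exponent encodes ordinary finite values.
    return sign * (1.0 + fraction) * 2.0 ** (exponent - bias)
```

Since 0b10000000 is the only NaN, every other byte decodes to a finite number; the largest values come out as 240.0 for float8_e4m3fnuz, 30.0 for float8_e4m3b11fnuz and 57344.0 for float8_e5m2fnuz.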
float8_e5m2
Exponent: 5, Mantissa: 2, bias: 15. IEEE 754, with NaN and inf.
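float8_e5m2 has the same sign bit, 5-bit exponent and bias 15 as IEEE float16, keeping only the top 2 of float16's 10 mantissa bits. Relying on that correspondence, a byte can be decoded by placing it in the upper half of a float16 bit pattern and letting Python's struct module (which supports half precision through the "e" format) do the rest. The helper name is hypothetical.

```python
import struct


def decode_float8_e5m2(byte: int) -> float:
    """Hypothetical helper: decode float8_e5m2 as the high byte of an IEEE float16."""
    if not 0 <= byte <= 0xFF:
        raise ValueError("expected a value in [0, 255]")
    # Shift into the high byte; the 8 low float16 mantissa bits become zero.
    (value,) = struct.unpack("<e", struct.pack("<H", byte << 8))
    return value
```

Here 0x7B gives the largest finite value, 57344.0, while 0x7C and 0x7D give infinity and NaN, exactly as the corresponding float16 patterns do.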
float8_e5m2fnuz
8-bit floating point with a 2-bit mantissa.
An 8-bit floating point type with 1 sign bit, 5 exponent bits and 2 mantissa bits. The suffix fnuz is consistent with LLVM/MLIR naming and is derived from the differences from IEEE floating point conventions: F is for "finite" (no infinities), N for a special NaN encoding, and UZ for unsigned zero.
This type has the following characteristics:
- bit encoding: S1E5M2 - 0bSEEEEEMM
- exponent bias: 16
- infinities: Not supported
- NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s - 0b10000000
- denormals when exponent is 0
int4 and uint4
4-bit integer types, where each element is represented unpacked (i.e., padded up to a byte in memory).
NumPy does not support types smaller than a single byte. For example, the distance between adjacent elements in an array (.strides) is expressed in bytes. Relaxing this restriction would be a considerable engineering project.
The int4 and uint4 types therefore use an unpacked representation, where each element of the array is padded up to a byte in memory. The lower four bits of each byte contain the representation of the number, whereas the upper four bits are ignored.
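A minimal sketch of this unpacked layout in plain Python, assuming a two's-complement encoding for int4 (the text above does not spell this out). The helper names are hypothetical; they only show how a value occupies the low nibble of its byte and how the upper four bits are ignored when reading it back.

```python
def encode_int4(value: int) -> int:
    """Hypothetical helper: store a signed 4-bit value in the low nibble (two's complement)."""
    if not -8 <= value <= 7:
        raise ValueError("int4 range is [-8, 7]")
    return value & 0x0F


def encode_uint4(value: int) -> int:
    """Hypothetical helper: store an unsigned 4-bit value in the low nibble."""
    if not 0 <= value <= 15:
        raise ValueError("uint4 range is [0, 15]")
    return value


def decode_int4(byte: int) -> int:
    """Read the low nibble as a signed value; the upper four bits are ignored."""
    nibble = byte & 0x0F
    return nibble - 16 if nibble >= 8 else nibble


def decode_uint4(byte: int) -> int:
    """Read the low nibble as an unsigned value; the upper four bits are ignored."""
    return byte & 0x0F
```

Each element still costs a full byte, so an array of 1000 int4 values occupies 1000 bytes in this scheme.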
Quirks of low-precision arithmetic
If you're exploring the use of low-precision dtypes in your code, you should be careful to anticipate when the precision loss might lead to surprising results. One example is the behavior of aggregations like sum; consider this bfloat16 summation in NumPy (run with version 1.24.2):
>>> from ml_dtypes import bfloat16
>>> import numpy as np
>>> rng = np.random.default_rng(seed=0)
>>> vals = rng.uniform(size=10000).astype(bfloat16)
>>> vals.sum()
256
The true sum should be close to 5000, but NumPy returns exactly 256: this is because bfloat16 does not have the precision to increment 256 by values less than 1:
>>> bfloat16(256) + bfloat16(1)
256
After 256, the next representable value in bfloat16 is 258:
>>> np.nextafter(bfloat16(256), bfloat16(np.inf))
258
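The gap of 2 follows from bfloat16 having 8 significant bits (7 stored plus the implicit leading bit): within a binade [2**e, 2**(e+1)), adjacent values are 2**(e-7) apart. A small standard-library sketch of that formula, using a hypothetical helper that only handles normal, finite, nonzero values:

```python
import math


def bfloat16_spacing(x: float) -> float:
    """Hypothetical helper: gap between adjacent bfloat16 values in the binade containing x."""
    if x == 0.0 or not math.isfinite(x):
        raise ValueError("expected a finite, nonzero value")
    # frexp gives x = m * 2**e with 0.5 <= |m| < 1, so |x| lies in [2**(e-1), 2**e).
    _, e = math.frexp(x)
    # With 8 significant bits, the step inside that binade is 2**(e-1) / 2**7.
    return math.ldexp(1.0, e - 8)
```

At 256 the gap is 2, so adding any value below 1 cannot reach the midpoint 257 and the result rounds back to 256. Near 5000 the gap is 32, which is why the float32-accumulated result of 4992 is a multiple of 32.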
For better results you can specify that the accumulation should happen in a higher-precision type like float32:
>>> vals.sum(dtype='float32').astype(bfloat16)
4992
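The stalling behavior can be reproduced without NumPy. The sketch below assumes a simple left-to-right accumulation that rounds the running total to bfloat16 after every addition (NumPy's actual reduction order may differ), and compares it with accumulating in full precision and rounding once at the end. Python's random module stands in for NumPy's generator, so the high-precision total need not match 4992. All function names are hypothetical.

```python
import math
import random


def round_to_bfloat16(x: float) -> float:
    """Round to the nearest bfloat16 value, ties to even (overflow and subnormals ignored)."""
    if x == 0.0 or not math.isfinite(x):
        return x
    m, e = math.frexp(x)  # x = m * 2**e with 0.5 <= |m| < 1
    # Keep 8 significant bits; Python's round() on a float rounds half to even.
    return math.ldexp(round(math.ldexp(m, 8)), e - 8)


def sum_in_bfloat16(values):
    """Left-to-right sum that rounds the running total to bfloat16 after each addition."""
    total = 0.0
    for v in values:
        total = round_to_bfloat16(total + round_to_bfloat16(v))
    return total


def sum_in_high_precision(values):
    """Accumulate the bfloat16-rounded inputs exactly, then round once at the end."""
    return round_to_bfloat16(math.fsum(round_to_bfloat16(v) for v in values))


rng = random.Random(0)
vals = [rng.random() for _ in range(10000)]
print(sum_in_bfloat16(vals))        # stalls at 256.0
print(sum_in_high_precision(vals))  # close to 5000
```

Once the running total reaches 256, every increment is smaller than half the spacing of 2, so the bfloat16 accumulator can never move past 256 regardless of the seed.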
In contrast to NumPy, projects like JAX, which support low-precision arithmetic more natively, will often do these kinds of higher-precision accumulations automatically:
>>> import jax.numpy as jnp
>>> jnp.array(vals).sum()
Array(4992, dtype=bfloat16)
License
This is not an officially supported Google product.
The ml_dtypes source code is licensed under the Apache 2.0 license (see LICENSE). Pre-compiled wheels are built with the EIGEN project, which is released under the MPL 2.0 license (see LICENSE.eigen).