ml_dtypes
ml_dtypes is a stand-alone implementation of several NumPy dtype extensions used in machine learning libraries, including:
- bfloat16: an alternative to the standard float16 format
- float8_*: several experimental 8-bit floating point representations including:
  - float8_e4m3b11fnuz
  - float8_e4m3fn
  - float8_e4m3fnuz
  - float8_e5m2
  - float8_e5m2fnuz
- int4 and uint4: low precision integer types.
See below for specifications of these number formats.
Installation
The ml_dtypes package is tested with Python versions 3.9-3.12, and can be installed with the following command:
pip install ml_dtypes
To test your installation, you can run the following:
pip install absl-py pytest
pytest --pyargs ml_dtypes
To build from source, clone the repository and run:
git submodule init
git submodule update
pip install .
Example Usage
>>> from ml_dtypes import bfloat16
>>> import numpy as np
>>> np.zeros(4, dtype=bfloat16)
array([0, 0, 0, 0], dtype=bfloat16)
Importing ml_dtypes also registers the data types with NumPy, so that they may be referred to by their string name:
>>> np.dtype('bfloat16')
dtype(bfloat16)
>>> np.dtype('float8_e5m2')
dtype(float8_e5m2)
Specifications of implemented floating point formats
bfloat16
A bfloat16 number is a single-precision float truncated at 16 bits.
Exponent: 8, Mantissa: 7, exponent bias: 127. IEEE 754, with NaN and inf.
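The "truncated single-precision" layout can be illustrated with the standard library alone. This is only a sketch and not how ml_dtypes implements the type: the helper names are hypothetical, and the encoder simply drops the low 16 bits (rounding toward zero) instead of rounding to nearest.

```python
import struct

def float32_bits(x: float) -> int:
    """Return the 32-bit IEEE 754 single-precision pattern of x."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

def truncate_to_bfloat16_bits(x: float) -> int:
    """Keep the top 16 bits of the float32 pattern: 1 sign, 8 exponent, 7 mantissa."""
    return float32_bits(x) >> 16

def bfloat16_bits_to_float(bits: int) -> float:
    """Widen a 16-bit bfloat16 pattern to a float by zero-filling the low 16 bits."""
    return struct.unpack("<f", struct.pack("<I", (bits & 0xFFFF) << 16))[0]

bits = truncate_to_bfloat16_bits(3.14159)
sign = bits >> 15
exponent = (bits >> 7) & 0xFF   # 8 exponent bits, bias 127
mantissa = bits & 0x7F          # 7 mantissa bits
print(hex(bits), sign, exponent - 127, mantissa)
print(bfloat16_bits_to_float(bits))  # 3.140625
```

Because the exponent field is the same width as in float32, the dynamic range is preserved and only mantissa precision is lost.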
float8_e4m3b11fnuz
Exponent: 4, Mantissa: 3, bias: 11.
Extended range: no inf, NaN represented by 0b1000'0000.
float8_e4m3fn
Exponent: 4, Mantissa: 3, bias: 7.
Extended range: no inf, NaN represented by 0bS111'1111.
The fn suffix is for consistency with the corresponding LLVM/MLIR type, signaling that this type is not consistent with IEEE-754. The f indicates that it has finite values only. The n indicates that it includes NaNs, but only at the outer range.
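To make the encoding concrete, here is a minimal pure-Python decoder for float8_e4m3fn bit patterns. It is a sketch, not the library's implementation: the function name is hypothetical, and it assumes IEEE-style subnormals when the exponent field is 0.

```python
import math

def decode_float8_e4m3fn(byte: int) -> float:
    """Decode one float8_e4m3fn bit pattern (0..255) into a Python float."""
    sign = -1.0 if byte & 0x80 else 1.0
    exponent = (byte >> 3) & 0xF   # 4 exponent bits, bias 7
    mantissa = byte & 0x7          # 3 mantissa bits
    if exponent == 0xF and mantissa == 0x7:
        return math.nan            # 0bS111'1111 is the only NaN pattern per sign
    if exponent == 0:
        # Subnormal (assumed): no implicit leading 1, exponent fixed at 1 - bias.
        return sign * (mantissa / 8) * 2.0 ** (1 - 7)
    # All other patterns, including exponent 0b1111, are ordinary finite values.
    return sign * (1 + mantissa / 8) * 2.0 ** (exponent - 7)

print(decode_float8_e4m3fn(0b0_0111_000))  # 1.0
print(decode_float8_e4m3fn(0b0_1111_110))  # 448.0, the largest finite value
```

Since the top exponent code is not reserved for infinities, the format gains one extra binade of range compared with an IEEE-style layout of the same width.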
float8_e4m3fnuz
8-bit floating point with 3 bit mantissa.
An 8-bit floating point type with 1 sign bit, 4 exponent bits and 3 mantissa bits. The suffix fnuz is consistent with LLVM/MLIR naming and is derived from the differences to IEEE floating point conventions. F is for "finite" (no infinities), N for a special NaN encoding, UZ for unsigned zero.
This type has the following characteristics:
- bit encoding: S1E4M3 - 0bSEEEEMMM
- exponent bias: 8
- infinities: Not supported
- NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s - 0b10000000
- denormals when exponent is 0
float8_e5m2
Exponent: 5, Mantissa: 2, bias: 15. IEEE 754, with NaN and inf.
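Because float8_e5m2 uses the same exponent width and bias as IEEE half precision (5 bits, bias 15), one of its bit patterns can be read as the high byte of a float16 whose low mantissa bits are zero. The sketch below uses that observation with the standard struct module; the function name is hypothetical and this is not necessarily how ml_dtypes performs the conversion.

```python
import math
import struct

def decode_float8_e5m2(byte: int) -> float:
    """Decode a float8_e5m2 pattern by placing it in the high byte of a float16."""
    # Little-endian float16: the low byte (the 8 dropped mantissa bits) comes first.
    return struct.unpack("<e", bytes([0, byte & 0xFF]))[0]

print(decode_float8_e5m2(0x3C))  # 1.0
print(decode_float8_e5m2(0x7B))  # 57344.0, the largest finite value
print(decode_float8_e5m2(0x7C))  # inf
print(decode_float8_e5m2(0x7F))  # nan
```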
float8_e5m2fnuz
8-bit floating point with 2 bit mantissa.
An 8-bit floating point type with 1 sign bit, 5 exponent bits and 2 mantissa bits. The suffix fnuz is consistent with LLVM/MLIR naming and is derived from the differences to IEEE floating point conventions. F is for "finite" (no infinities), N for a special NaN encoding, UZ for unsigned zero.
This type has the following characteristics:
- bit encoding: S1E5M2 - 0bSEEEEEMM
- exponent bias: 16
- infinities: Not supported
- NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s - 0b10000000
- denormals when exponent is 0
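The three fnuz formats differ only in how the seven non-sign bits are split and in the exponent bias, so a single parameterized decoder covers all of them. The following is a sketch built from the characteristics listed above; the function and dictionary names are hypothetical, not part of the ml_dtypes API.

```python
import math

def decode_fnuz(byte: int, exp_bits: int, man_bits: int, bias: int) -> float:
    """Decode an 8-bit fnuz pattern: no infinities, one NaN (0x80), no negative zero."""
    if byte == 0x80:
        return math.nan                      # sign bit set, all other bits zero
    sign = -1.0 if byte & 0x80 else 1.0
    exponent = (byte >> man_bits) & ((1 << exp_bits) - 1)
    mantissa = byte & ((1 << man_bits) - 1)
    fraction = mantissa / (1 << man_bits)
    if exponent == 0:
        return sign * fraction * 2.0 ** (1 - bias)   # denormal
    return sign * (1 + fraction) * 2.0 ** (exponent - bias)

# (exponent bits, mantissa bits, bias) as given in the specifications above
FORMATS = {
    "float8_e4m3fnuz": (4, 3, 8),
    "float8_e5m2fnuz": (5, 2, 16),
    "float8_e4m3b11fnuz": (4, 3, 11),
}

for name, params in FORMATS.items():
    # 0x7F is the largest finite pattern: 240.0, 57344.0 and 30.0 respectively
    print(name, decode_fnuz(0x7F, *params))
```

The pattern that would be negative zero in an IEEE-style format is reused as the single NaN, which is why every other bit pattern decodes to a finite number.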
int4 and uint4
4-bit integer types, where each element is represented unpacked (i.e., padded up to a byte in memory).
NumPy does not support types smaller than a single byte. For example, the distance between adjacent elements in an array (.strides) is expressed in bytes. Relaxing this restriction would be a considerable engineering project. The int4 and uint4 types therefore use an unpacked representation, where each element of the array is padded up to a byte in memory. The lower four bits of each byte contain the representation of the number, whereas the upper four bits are ignored.
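The unpacked layout can be sketched in plain Python. The helper names below are hypothetical, and the signed decoding assumes the usual two's-complement interpretation of the low four bits; the packed layout at the end is shown only for contrast and is not what these types use.

```python
def int4_from_byte(byte: int) -> int:
    """Read a signed 4-bit value from the low nibble, ignoring the upper four bits."""
    nibble = byte & 0x0F
    return nibble - 16 if nibble & 0x08 else nibble   # sign-extend bit 3

def uint4_from_byte(byte: int) -> int:
    """Read an unsigned 4-bit value from the low nibble."""
    return byte & 0x0F

values = [-8, -1, 0, 7]
unpacked = bytes(v & 0x0F for v in values)        # one byte per element
print(len(unpacked))                              # 4
print([int4_from_byte(b) for b in unpacked])      # [-8, -1, 0, 7]

# For comparison only: a packed layout would hold two elements per byte.
packed = bytes(
    unpacked[i] | (unpacked[i + 1] << 4) for i in range(0, len(unpacked), 2)
)
print(len(packed))                                # 2
```

The unpacked form costs twice the memory of a packed one, but every element stays individually addressable with byte strides.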
Quirks of low-precision Arithmetic
If you're exploring the use of low-precision dtypes in your code, you should be careful to anticipate when the precision loss might lead to surprising results. One example is the behavior of aggregations like sum; consider this bfloat16 summation in NumPy (run with version 1.24.2):
>>> from ml_dtypes import bfloat16
>>> import numpy as np
>>> rng = np.random.default_rng(seed=0)
>>> vals = rng.uniform(size=10000).astype(bfloat16)
>>> vals.sum()
256
The true sum should be close to 5000, but NumPy returns exactly 256: this is because bfloat16 does not have the precision to increment 256 by values less than 1:
>>> bfloat16(256) + bfloat16(1)
256
After 256, the next representable value in bfloat16 is 258:
>>> np.nextafter(bfloat16(256), bfloat16(np.inf))
258
For better results you can specify that the accumulation should happen in a higher-precision type like float32:
>>> vals.sum(dtype='float32').astype(bfloat16)
4992
In contrast to NumPy, projects like JAX which support low-precision arithmetic more natively will often do these kinds of higher-precision accumulations automatically:
>>> import jax.numpy as jnp
>>> jnp.array(vals).sum()
Array(4992, dtype=bfloat16)
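The stalling effect can be reproduced without NumPy by simulating bfloat16 rounding in pure Python. This is a sketch under stated assumptions: round_to_bfloat16 is a hypothetical helper (round-to-nearest-even, finite inputs only), the random values differ from the NumPy example above, and the naive left-to-right loop may not match the summation order NumPy uses internally.

```python
import random
import struct

def round_to_bfloat16(x: float) -> float:
    """Round a finite float to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)   # round-to-nearest-even on the low 16 bits
    bits &= 0xFFFF0000                    # keep sign, 8 exponent bits, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

rng = random.Random(0)
vals = [round_to_bfloat16(rng.random()) for _ in range(10000)]

# Accumulate in bfloat16: every partial sum is rounded back to bfloat16.
low = 0.0
for v in vals:
    low = round_to_bfloat16(low + v)

# Accumulate in higher precision and round only once at the end.
high = round_to_bfloat16(sum(vals))

print(low)    # 256.0: once the sum reaches 256, adding values <= 1 rounds back to 256
print(high)   # close to 5000
```

Rounding once at the end keeps the intermediate sums exact enough that only the final result carries bfloat16 error.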
License
This is not an officially supported Google product.
The ml_dtypes source code is licensed under the Apache 2.0 license (see LICENSE). Pre-compiled wheels are built with the EIGEN project, which is released under the MPL 2.0 license (see LICENSE.eigen).