ml_dtypes
ml_dtypes is a stand-alone implementation of several NumPy dtype extensions used in machine learning libraries, including:
- bfloat16: an alternative to the standard float16 format
- float8_*: several experimental 8-bit floating point representations, including:
  - float8_e4m3b11fnuz
  - float8_e4m3fn
  - float8_e4m3fnuz
  - float8_e5m2
  - float8_e5m2fnuz
- int4 and uint4: low-precision integer types
See below for specifications of these number formats.
Installation
The ml_dtypes package is tested with Python versions 3.9-3.12 and can be installed with the following command:
pip install ml_dtypes
To test your installation, you can run the following:
pip install absl-py pytest
pytest --pyargs ml_dtypes
To build from source, clone the repository and run:
git submodule init
git submodule update
pip install .
Example Usage
>>> from ml_dtypes import bfloat16
>>> import numpy as np
>>> np.zeros(4, dtype=bfloat16)
array([0, 0, 0, 0], dtype=bfloat16)
Importing ml_dtypes also registers the data types with NumPy, so that they may be referred to by their string name:
>>> np.dtype('bfloat16')
dtype(bfloat16)
>>> np.dtype('float8_e5m2')
dtype(float8_e5m2)
Specifications of implemented floating point formats
bfloat16
A bfloat16 number is a single-precision (float32) number truncated to its upper 16 bits.
Exponent: 8, Mantissa: 7, exponent bias: 127. IEEE 754, with NaN and inf.
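Because the layout is the upper half of a float32 (1 sign bit, 8 exponent bits, 7 mantissa bits), a bfloat16 bit pattern can be decoded by appending 16 zero bits. The sketch below uses only Python's standard struct module; the helper names are hypothetical, and the encoder truncates rather than rounds, so it may differ in the last bit from the conversion ml_dtypes itself performs.

```python
import struct


def float32_bits(x: float) -> int:
    """Return the IEEE 754 single-precision bit pattern of x as an unsigned int."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bfloat16_bits_truncate(x: float) -> int:
    """Keep the upper 16 bits (sign, 8 exponent, 7 mantissa) of the float32 pattern.

    Note: this truncates; a rounding conversion may pick the next value up.
    """
    return float32_bits(x) >> 16


def bfloat16_bits_to_float(bits: int) -> float:
    """Decode a 16-bit bfloat16 pattern by zero-filling the low 16 float32 bits."""
    return struct.unpack("<f", struct.pack("<I", (bits & 0xFFFF) << 16))[0]


print(hex(bfloat16_bits_truncate(1.0)))                          # 0x3f80
print(bfloat16_bits_to_float(bfloat16_bits_truncate(3.14159)))   # 3.140625
print(bfloat16_bits_to_float(0x7F80))                            # inf
```

The same exponent field as float32 is what gives bfloat16 its float32-like range; only the precision (7 stored mantissa bits) is reduced.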
float8_e4m3b11fnuz
Exponent: 4, Mantissa: 3, bias: 11.
Extended range: no inf, NaN represented by 0b1000'0000.
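As a worked example of what these parameters imply, assuming IEEE-style subnormals (exponent field E = 0 uses effective exponent 1 - bias) and noting that, with no infinities, every exponent code is available for finite values:

```latex
% S: sign bit, E: 4-bit exponent field, M: 3-bit mantissa field, bias = 11
\[
x = \begin{cases}
  (-1)^S \, 2^{E-11} \left(1 + \frac{M}{8}\right), & 1 \le E \le 15,\\[4pt]
  (-1)^S \, 2^{1-11} \, \frac{M}{8}, & E = 0,\ (S, M) \ne (1, 0).
\end{cases}
\]
\[
x_{\max} = 2^{15-11}\left(1 + \tfrac{7}{8}\right) = 30, \qquad
x_{\min}^{\text{normal}} = 2^{-10}, \qquad
x_{\min}^{\text{subnormal}} = 2^{-10} \cdot \tfrac{1}{8} = 2^{-13}.
\]
```

The excluded code (S, E, M) = (1, 0, 0) is 0b1000'0000, the single NaN, so there is no negative zero.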
float8_e4m3fn
Exponent: 4, Mantissa: 3, bias: 7.
Extended range: no inf, NaN represented by 0bS111'1111.
The fn suffix is for consistency with the corresponding LLVM/MLIR type, signaling that this type is not consistent with IEEE 754. The f indicates that only finite values are supported, and the n indicates that NaNs are included, but only at the outer range.
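A minimal decoder sketch for float8_e4m3fn, reading "NaN represented by 0bS111'1111" as the only NaN codes and assuming IEEE-style subnormals when the exponent field is 0. The function name is hypothetical and not part of the ml_dtypes API.

```python
import math


def decode_float8_e4m3fn(byte: int) -> float:
    """Decode a float8_e4m3fn pattern (1 sign, 4 exponent, 3 mantissa bits, bias 7)."""
    sign = -1.0 if byte & 0x80 else 1.0
    exponent = (byte >> 3) & 0x0F
    mantissa = byte & 0x07
    if exponent == 0x0F and mantissa == 0x07:
        return math.nan  # 0bS111'1111: the only NaN encodings; there are no infinities
    if exponent == 0:
        # Subnormal or signed zero: no implicit leading 1, exponent 1 - bias
        return sign * math.ldexp(mantissa / 8, 1 - 7)
    return sign * math.ldexp(1 + mantissa / 8, exponent - 7)


finite = [v for v in map(decode_float8_e4m3fn, range(256)) if not math.isnan(v)]
print(max(finite))                                      # 448.0
print(math.copysign(1.0, decode_float8_e4m3fn(0x80)))   # -1.0 (negative zero exists)
```

Reserving only the all-ones mantissa at the top exponent for NaN is what lets the remaining top-exponent codes encode finite values, pushing the maximum to 448.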
float8_e4m3fnuz
8-bit floating point with a 3-bit mantissa.
An 8-bit floating point type with 1 sign bit, 4 exponent bits and 3 mantissa bits. The suffix fnuz is consistent with LLVM/MLIR naming and is derived from the differences from IEEE floating point conventions: F is for "finite" (no infinities), N for a special NaN encoding, and UZ for unsigned zero.
This type has the following characteristics:
- bit encoding: S1E4M3 (0bSEEEEMMM)
- exponent bias: 8
- infinities: not supported
- NaNs: supported, with the sign bit set to 1 and all exponent and mantissa bits set to 0 (0b10000000)
- denormals when the exponent is 0
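The characteristics above can be turned into a small decoder. This is a sketch under the stated conventions (single NaN at 0b10000000, no infinities, denormals with exponent 1 - bias); the function decode_fnuz and its parameters are hypothetical. With parameters (4, 3, 11) or (5, 2, 16) it should, under the same assumptions, also describe float8_e4m3b11fnuz and float8_e5m2fnuz.

```python
import math


def decode_fnuz(byte: int, exp_bits: int, man_bits: int, bias: int) -> float:
    """Decode an 8-bit 'fnuz' float: no infinities, one NaN, unsigned zero."""
    if byte == 0x80:
        return math.nan  # the sign-bit-only pattern is NaN, not -0
    sign = -1.0 if byte & 0x80 else 1.0
    exponent = (byte >> man_bits) & ((1 << exp_bits) - 1)
    mantissa = byte & ((1 << man_bits) - 1)
    scale = 1 << man_bits
    if exponent == 0:
        return sign * math.ldexp(mantissa / scale, 1 - bias)  # denormal
    return sign * math.ldexp(1 + mantissa / scale, exponent - bias)


# float8_e4m3fnuz: 4 exponent bits, 3 mantissa bits, bias 8
values = [decode_fnuz(b, 4, 3, 8) for b in range(256)]
finite = [v for v in values if not math.isnan(v)]
print(max(finite), min(finite))       # 240.0 -240.0
print(sum(v == 0 for v in finite))    # 1: there is only one zero
```

Giving up negative zero frees exactly one code for NaN, so the top exponent stays fully available for finite values.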
float8_e5m2
Exponent: 5, Mantissa: 2, bias: 15. IEEE 754, with NaN and inf.
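Since float8_e5m2 follows IEEE 754 with the same 5-bit exponent and bias 15 as IEEE binary16 (float16), a pattern can be decoded by treating it as the high byte of a float16. The sketch relies on Python's struct "e" (half-precision) format; the helper name is hypothetical.

```python
import struct


def decode_float8_e5m2(byte: int) -> float:
    """Decode float8_e5m2 as the high byte of an IEEE 754 binary16 value.

    Both formats use a 5-bit exponent with bias 15, so zero-filling the
    8 extra float16 mantissa bits yields the same number.
    """
    return struct.unpack("<e", struct.pack("<H", (byte & 0xFF) << 8))[0]


print(decode_float8_e5m2(0x3C))  # 1.0
print(decode_float8_e5m2(0x7B))  # 57344.0 (largest finite value)
print(decode_float8_e5m2(0x7C))  # inf
print(decode_float8_e5m2(0x7E))  # nan
```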
float8_e5m2fnuz
8-bit floating point with a 2-bit mantissa.
An 8-bit floating point type with 1 sign bit, 5 exponent bits and 2 mantissa bits. The suffix fnuz is consistent with LLVM/MLIR naming and is derived from the differences from IEEE floating point conventions: F is for "finite" (no infinities), N for a special NaN encoding, and UZ for unsigned zero.
This type has the following characteristics:
- bit encoding: S1E5M2 (0bSEEEEEMM)
- exponent bias: 16
- infinities: not supported
- NaNs: supported, with the sign bit set to 1 and all exponent and mantissa bits set to 0 (0b10000000)
- denormals when the exponent is 0
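Comparing this type with float8_e5m2, assuming IEEE-style subnormals (exponent 1 - bias) for both: the larger bias halves the value of every shared code, while reclaiming the E = 31 exponent for finite values keeps the same maximum and extends the range one binade further toward zero.

```latex
% Shared layout: S (1 bit), E (5 bits), M (2 bits)
\[
\text{float8\_e5m2 (bias 15, } E = 31 \text{ reserved):}\quad
x_{\max} = 2^{30-15}\left(1 + \tfrac{3}{4}\right) = 57344, \qquad
x_{\min}^{\text{subnormal}} = 2^{1-15} \cdot \tfrac{1}{4} = 2^{-16}
\]
\[
\text{float8\_e5m2fnuz (bias 16, } E = 31 \text{ finite):}\quad
x_{\max} = 2^{31-16}\left(1 + \tfrac{3}{4}\right) = 57344, \qquad
x_{\min}^{\text{subnormal}} = 2^{1-16} \cdot \tfrac{1}{4} = 2^{-17}
\]
```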
int4 and uint4
4-bit integer types, where each element is represented unpacked (i.e., padded up to a byte in memory).
NumPy does not support types smaller than a single byte. For example, the distance between adjacent elements in an array (.strides) is expressed in bytes. Relaxing this restriction would be a considerable engineering project.
The int4 and uint4 types therefore use an unpacked representation, where each element of the array is padded up to a byte in memory. The lower four bits of each byte contain the representation of the number, whereas the upper four bits are ignored.
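A minimal sketch of reading and writing one unpacked element, assuming two's-complement encoding for int4 (range -8 to 7); the text above specifies only that the low nibble holds the value. All function names are hypothetical.

```python
def decode_uint4(byte: int) -> int:
    """Read a uint4 from one padded byte: only the low nibble matters."""
    return byte & 0x0F


def decode_int4(byte: int) -> int:
    """Read an int4 (assumed two's complement, -8..7) from the low nibble."""
    nibble = byte & 0x0F
    return nibble - 16 if nibble & 0x08 else nibble


def encode_int4(value: int) -> int:
    """Store an int4 in the low nibble of a byte, leaving the upper four bits zero."""
    if not -8 <= value <= 7:
        raise ValueError("int4 values must lie in [-8, 7]")
    return value & 0x0F


print(decode_int4(0xF7), decode_int4(0x07))  # 7 7: the upper nibble is ignored
print(decode_int4(encode_int4(-3)))          # -3
print(decode_uint4(0xAF))                    # 15
```

The cost of this layout is memory: each 4-bit value occupies a full byte, so arrays are twice as large as a packed encoding would be.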
Quirks of low-precision arithmetic
If you're exploring the use of low-precision dtypes in your code, you should be careful to anticipate when the precision loss might lead to surprising results. One example is the behavior of aggregations like sum; consider this bfloat16 summation in NumPy (run with version 1.24.2):
>>> from ml_dtypes import bfloat16
>>> import numpy as np
>>> rng = np.random.default_rng(seed=0)
>>> vals = rng.uniform(size=10000).astype(bfloat16)
>>> vals.sum()
256
The true sum should be close to 5000, but NumPy returns exactly 256: this is because bfloat16 does not have the precision to increment 256 by values of 1 or less:
>>> bfloat16(256) + bfloat16(1)
256
After 256, the next representable value in bfloat16 is 258:
>>> np.nextafter(bfloat16(256), bfloat16(np.inf))
258
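The gap can be computed directly: bfloat16 keeps 8 significant bits (1 implicit + 7 stored), so within each binade [2**k, 2**(k+1)) adjacent values are 2**(k - 7) apart. A sketch for positive normal values, with a hypothetical helper name:

```python
import math


def bfloat16_spacing(x: float) -> float:
    """Gap between adjacent bfloat16 values in the binade containing x (x > 0, normal)."""
    _, e = math.frexp(x)  # x = m * 2**e with 0.5 <= m < 1, so the binade is 2**(e - 1)
    return math.ldexp(1.0, e - 8)


print(bfloat16_spacing(1.0))    # 0.0078125 (2**-7)
print(bfloat16_spacing(255.0))  # 1.0
print(bfloat16_spacing(256.0))  # 2.0
```

At 256 the spacing is 2, so any increment below 1 rounds back down, and an increment of exactly 1 lands on the midpoint 257, which round-to-nearest-even resolves to 256.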
For better results, you can specify that the accumulation should happen in a higher-precision type like float32:
>>> vals.sum(dtype='float32').astype(bfloat16)
4992
In contrast to NumPy, projects like JAX, which support low-precision arithmetic more natively, will often perform these kinds of higher-precision accumulations automatically:
>>> import jax.numpy as jnp
>>> jnp.array(vals).sum()
Array(4992, dtype=bfloat16)
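The difference between the two accumulation strategies can be reproduced without NumPy by simulating bfloat16 rounding in pure Python. This is a sketch, not the library's implementation: round_to_bfloat16 is a hypothetical helper that ignores overflow, subnormals, inf and NaN, and Python's random module draws different samples than np.random.default_rng, so only the qualitative result is comparable.

```python
import math
import random


def round_to_bfloat16(x: float) -> float:
    """Round x to the nearest bfloat16 value, ties to even (normal range only)."""
    if x == 0.0:
        return x
    m, e = math.frexp(x)  # x = m * 2**e, 0.5 <= |m| < 1
    # Keep 8 significant bits; Python's round() on floats is round-half-to-even.
    return math.ldexp(round(m * 256), e - 8)


rng = random.Random(0)
vals = [round_to_bfloat16(rng.random()) for _ in range(10_000)]

# Naive: every partial sum is rounded back to bfloat16.
naive = 0.0
for v in vals:
    naive = round_to_bfloat16(naive + v)

# Wider accumulator: sum in double precision, round to bfloat16 once at the end.
wide = round_to_bfloat16(math.fsum(vals))

print(naive)  # 256.0: at 256 the spacing is 2.0, so increments of at most 1.0 round away
print(wide)   # close to 5000 (a multiple of 32, the spacing in that range)
```

Rounding once at the end confines the error to a single bfloat16 rounding step, instead of losing each small increment as it is added.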
License
This is not an officially supported Google product.
The ml_dtypes source code is licensed under the Apache 2.0 license (see LICENSE). Pre-compiled wheels are built with the EIGEN project, which is released under the MPL 2.0 license (see LICENSE.eigen).