loongson/pypi/: ml-dtypes-0.5.3 metadata and description


ml_dtypes is a stand-alone implementation of several NumPy dtype extensions used in machine learning.

author_email ml_dtypes authors <ml_dtypes@google.com>
classifiers
  • Programming Language :: Python :: 3
  • Programming Language :: Python :: 3 :: Only
  • Intended Audience :: Science/Research
description_content_type text/markdown
project_urls
  • homepage, https://github.com/jax-ml/ml_dtypes
  • repository, https://github.com/jax-ml/ml_dtypes
requires_dist
  • numpy>=1.21
  • numpy>=1.21.2; python_version >= "3.10"
  • numpy>=1.23.3; python_version >= "3.11"
  • numpy>=1.26.0; python_version >= "3.12"
  • numpy>=2.1.0; python_version >= "3.13"
  • absl-py; extra == "dev"
  • pytest; extra == "dev"
  • pytest-xdist; extra == "dev"
  • pylint>=2.6.0; extra == "dev"
  • pyink; extra == "dev"
requires_python >=3.9
| File | Size | Type | Python |
|---|---|---|---|
| ml_dtypes-0.5.3-cp310-cp310-manylinux_2_36_loongarch64.manylinux_2_38_loongarch64.whl | 4 MB | Python Wheel | 3.10 |
| ml_dtypes-0.5.3-cp310-cp310-musllinux_1_2_loongarch64.whl | 5 MB | Python Wheel | 3.10 |
| ml_dtypes-0.5.3-cp311-cp311-manylinux_2_36_loongarch64.manylinux_2_38_loongarch64.whl | 4 MB | Python Wheel | 3.11 |
| ml_dtypes-0.5.3-cp311-cp311-musllinux_1_2_loongarch64.whl | 5 MB | Python Wheel | 3.11 |
| ml_dtypes-0.5.3-cp312-cp312-manylinux_2_36_loongarch64.manylinux_2_38_loongarch64.whl | 4 MB | Python Wheel | 3.12 |
| ml_dtypes-0.5.3-cp312-cp312-musllinux_1_2_loongarch64.whl | 5 MB | Python Wheel | 3.12 |
| ml_dtypes-0.5.3-cp313-cp313-manylinux_2_36_loongarch64.manylinux_2_38_loongarch64.whl | 4 MB | Python Wheel | 3.13 |
| ml_dtypes-0.5.3-cp313-cp313-musllinux_1_2_loongarch64.whl | 5 MB | Python Wheel | 3.13 |
| ml_dtypes-0.5.3-cp313-cp313t-manylinux_2_36_loongarch64.manylinux_2_38_loongarch64.whl | 4 MB | Python Wheel | 3.13 |
| ml_dtypes-0.5.3-cp313-cp313t-musllinux_1_2_loongarch64.whl | 5 MB | Python Wheel | 3.13 |
| ml_dtypes-0.5.3-cp39-cp39-manylinux_2_36_loongarch64.manylinux_2_38_loongarch64.whl | 4 MB | Python Wheel | 3.9 |
| ml_dtypes-0.5.3-cp39-cp39-musllinux_1_2_loongarch64.whl | 5 MB | Python Wheel | 3.9 |

ml_dtypes


ml_dtypes is a stand-alone implementation of several NumPy dtype extensions used in machine learning libraries, including:

  • bfloat16
  • float4_e2m1fn
  • float6_e2m3fn, float6_e3m2fn
  • float8_e3m4, float8_e4m3, float8_e4m3b11fnuz, float8_e4m3fn, float8_e4m3fnuz, float8_e5m2, float8_e5m2fnuz, float8_e8m0fnu
  • int2, int4, uint2, uint4

See below for specifications of these number formats.

Installation

The ml_dtypes package is tested with Python versions 3.9-3.12, and can be installed with the following command:

pip install ml_dtypes

To test your installation, you can run the following:

pip install absl-py pytest
pytest --pyargs ml_dtypes

To build from source, clone the repository and run:

git submodule init
git submodule update
pip install .

Example Usage

>>> from ml_dtypes import bfloat16
>>> import numpy as np
>>> np.zeros(4, dtype=bfloat16)
array([0, 0, 0, 0], dtype=bfloat16)

Importing ml_dtypes also registers the data types with numpy, so that they may be referred to by their string name:

>>> np.dtype('bfloat16')
dtype(bfloat16)
>>> np.dtype('float8_e5m2')
dtype(float8_e5m2)

Specifications of implemented floating point formats

bfloat16

A bfloat16 number is a single-precision (float32) value truncated to its upper 16 bits.

Exponent: 8, Mantissa: 7, exponent bias: 127. IEEE 754, with NaN and inf.
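The truncation can be sketched in plain Python: take the IEEE-754 float32 bit pattern and keep its upper 16 bits. This is a minimal sketch (the helper name is ours, not part of the ml_dtypes API), and real conversions round to nearest-even rather than truncating:

```python
import struct

def float32_to_bfloat16_bits(x: float) -> int:
    """Truncate a float32 bit pattern to its upper 16 bits (bfloat16)."""
    (bits,) = struct.unpack(">I", struct.pack(">f", x))
    return bits >> 16

code = float32_to_bfloat16_bits(1.0)
sign, exponent, mantissa = code >> 15, (code >> 7) & 0xFF, code & 0x7F
# For 1.0: sign = 0, exponent field = 127 (the bias), mantissa = 0
```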

float4_e2m1fn

Exponent: 2, Mantissa: 1, bias: 1.

Extended range: no inf, no NaN.

Microscaling format, 4 bits (encoding: 0bSEEM) using byte storage (higher 4 bits are unused). NaN representation is undefined.

Possible absolute values: [0, 0.5, 1, 1.5, 2, 3, 4, 6]
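The eight magnitudes above follow directly from the 0bSEEM layout with bias 1; a small decoder (a hypothetical helper for illustration, not part of the ml_dtypes API) reproduces them:

```python
def decode_e2m1(code: int) -> float:
    """Decode a 4-bit 0bSEEM float4_e2m1fn code to a Python float."""
    sign = -1.0 if (code >> 3) & 1 else 1.0
    exp = (code >> 1) & 0b11
    mant = code & 0b1
    if exp == 0:
        # Subnormal: mantissa/2 scaled by 2**(1 - bias), with bias = 1
        return sign * (mant / 2) * 2.0 ** (1 - 1)
    return sign * (1 + mant / 2) * 2.0 ** (exp - 1)

magnitudes = sorted({abs(decode_e2m1(c)) for c in range(16)})
# [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
```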

float6_e2m3fn

Exponent: 2, Mantissa: 3, bias: 1.

Extended range: no inf, no NaN.

Microscaling format, 6 bits (encoding: 0bSEEMMM) using byte storage (higher 2 bits are unused). NaN representation is undefined.

Possible values range: [-7.5; 7.5]

float6_e3m2fn

Exponent: 3, Mantissa: 2, bias: 3.

Extended range: no inf, no NaN.

Microscaling format, 6 bits (encoding: 0bSEEEMM) using byte storage (higher 2 bits are unused). NaN representation is undefined.

Possible values range: [-28; 28]
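Both 6-bit formats share the same decode rule and differ only in field widths and bias; a generic decoder (hypothetical, for illustration) confirms the stated maxima of 7.5 and 28:

```python
def decode_fn(code: int, exp_bits: int, mant_bits: int, bias: int) -> float:
    """Decode a finite-only (fn) code: 1 sign bit, then exponent, then mantissa."""
    sign = -1.0 if (code >> (exp_bits + mant_bits)) & 1 else 1.0
    exp = (code >> mant_bits) & ((1 << exp_bits) - 1)
    mant = code & ((1 << mant_bits) - 1)
    scale = 1 << mant_bits
    if exp == 0:  # subnormal
        return sign * (mant / scale) * 2.0 ** (1 - bias)
    return sign * (1 + mant / scale) * 2.0 ** (exp - bias)

# The largest positive code, 0b011111, gives each format's maximum:
e2m3_max = decode_fn(0b011111, exp_bits=2, mant_bits=3, bias=1)  # 7.5
e3m2_max = decode_fn(0b011111, exp_bits=3, mant_bits=2, bias=3)  # 28.0
```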

float8_e3m4

Exponent: 3, Mantissa: 4, bias: 3. IEEE 754, with NaN and inf.

float8_e4m3

Exponent: 4, Mantissa: 3, bias: 7. IEEE 754, with NaN and inf.

float8_e4m3b11fnuz

Exponent: 4, Mantissa: 3, bias: 11.

Extended range: no inf, NaN represented by 0b1000'0000.

float8_e4m3fn

Exponent: 4, Mantissa: 3, bias: 7.

Extended range: no inf, NaN represented by 0bS111'1111.

The fn suffix is for consistency with the corresponding LLVM/MLIR type, signaling that this type is not IEEE-754 conformant. The f indicates finite values only (no inf); the n indicates it includes NaNs, but only at the outer range.
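Because the all-ones code is reserved for NaN, the largest finite e4m3fn value uses exponent field 0b1111 with mantissa field 0b110; a quick check of the arithmetic:

```python
bias = 7
# Exponent field 0b1111 = 15; largest non-NaN mantissa field 0b110 = 6 of 8
max_finite = (1 + 6 / 8) * 2.0 ** (15 - bias)
# 1.75 * 256 = 448.0, the float8_e4m3fn maximum
```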

float8_e4m3fnuz

8-bit floating point with a 3-bit mantissa.

An 8-bit floating point type with 1 sign bit, 4 exponent bits, and 3 mantissa bits. The suffix fnuz is consistent with LLVM/MLIR naming and is derived from the differences to IEEE floating point conventions. F is for "finite" (no infinities), N for a special NaN encoding, UZ for unsigned zero.

This type has the following characteristics:

  • Exponent: 4, Mantissa: 3, bias: 8
  • Extended range: no inf, NaN represented by 0b1000'0000
  • No negative zero (the 0b1000'0000 encoding is NaN)

float8_e5m2

Exponent: 5, Mantissa: 2, bias: 15. IEEE 754, with NaN and inf.

float8_e5m2fnuz

8-bit floating point with a 2-bit mantissa.

An 8-bit floating point type with 1 sign bit, 5 exponent bits, and 2 mantissa bits. The suffix fnuz is consistent with LLVM/MLIR naming and is derived from the differences to IEEE floating point conventions. F is for "finite" (no infinities), N for a special NaN encoding, UZ for unsigned zero.

This type has the following characteristics:

  • Exponent: 5, Mantissa: 2, bias: 16
  • Extended range: no inf, NaN represented by 0b1000'0000
  • No negative zero (the 0b1000'0000 encoding is NaN)

float8_e8m0fnu

OpenCompute MX scale format E8M0, which has the following properties:

  • Unsigned format: no sign bit
  • Exponent: 8, Mantissa: 0, bias: 127
  • No zero and no inf values
  • NaN represented by 0xFF

int2, int4, uint2 and uint4

2- and 4-bit integer types, where each element is represented unpacked (i.e., padded up to a byte in memory).

NumPy does not support types smaller than a single byte: for example, the distance between adjacent elements in an array (.strides) is expressed as an integer number of bytes. Relaxing this restriction would be a considerable engineering project. These types therefore use an unpacked representation, where each element of the array is padded up to a byte in memory. The lower two or four bits of each byte contain the representation of the number, whereas the remaining upper bits are ignored.
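The low-bits convention can be illustrated with a hypothetical NumPy helper (not part of the ml_dtypes API) that reads the low 4 bits of each storage byte as a signed int4 and ignores the upper bits:

```python
import numpy as np

def read_unpacked_int4(storage: np.ndarray) -> np.ndarray:
    """Interpret the low 4 bits of each byte as a signed 4-bit integer."""
    low = storage.astype(np.int8) & 0x0F      # keep low nibble, drop ignored bits
    return np.where(low >= 8, low - 16, low)  # sign-extend codes >= 0b1000

vals = read_unpacked_int4(np.array([0x07, 0x0F, 0xA8], dtype=np.uint8))
# [7, -1, -8]: the upper nibble of 0xA8 is ignored
```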

Quirks of low-precision arithmetic

If you're exploring the use of low-precision dtypes in your code, you should be careful to anticipate when the precision loss might lead to surprising results. One example is the behavior of aggregations like sum; consider this bfloat16 summation in NumPy (run with version 1.24.2):

>>> from ml_dtypes import bfloat16
>>> import numpy as np
>>> rng = np.random.default_rng(seed=0)
>>> vals = rng.uniform(size=10000).astype(bfloat16)
>>> vals.sum()
256

The true sum should be close to 5000, but NumPy returns exactly 256: once the partial sum reaches 256, bfloat16 no longer has the precision to register an increment of 1:

>>> bfloat16(256) + bfloat16(1)
256

After 256, the next representable value in bfloat16 is 258:

>>> np.nextafter(bfloat16(256), bfloat16(np.inf))
258
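The spacing follows directly from the 7-bit mantissa: near 256 = 2^8, one unit in the last place (ulp) is 2^(8-7) = 2, so increments of 1 are lost to rounding:

```python
mantissa_bits = 7       # bfloat16
value_exponent = 8      # 256 == 2**8
ulp = 2.0 ** (value_exponent - mantissa_bits)
# ulp == 2.0: 256 + 1 rounds back to 256, and the next representable value is 258
```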

For better results you can specify that the accumulation should happen in a higher-precision type like float32:

>>> vals.sum(dtype='float32').astype(bfloat16)
4992

In contrast to NumPy, projects like JAX which support low-precision arithmetic more natively will often do these kinds of higher-precision accumulations automatically:

>>> import jax.numpy as jnp
>>> jnp.array(vals).sum()
Array(4992, dtype=bfloat16)

License

This is not an officially supported Google product.

The ml_dtypes source code is licensed under the Apache 2.0 license (see LICENSE). Pre-compiled wheels are built with the Eigen project, which is released under the MPL 2.0 license (see LICENSE.eigen).