Skip to content

A stand-alone implementation of several NumPy dtype extensions used in machine learning.

License

Notifications You must be signed in to change notification settings

jakeh-gc/ml_dtypes

 
 

Repository files navigation

ml_dtypes

Unittests Wheel Build PyPI version

This is not an officially supported Google product.

ml_dtypes is a stand-alone implementation of several NumPy dtype extensions used in machine learning libraries, including:

  • bfloat16: an alternative to the standard float16 format
  • float8_*: several experimental 8-bit floating point representations including:
    • float8_e4m3b11
    • float8_e4m3fn
    • float8_e5m2

Installation

The ml_dtypes package is tested with Python versions 3.8-3.11, and can be installed with the following command:

pip install ml_dtypes

To test your installation, you can run the following:

pip install absl-py pytest
pytest --pyargs ml_dtypes

To build from source, clone the repository and run:

git submodule init
git submodule update
pip install .

Example Usage

>>> from ml_dtypes import bfloat16
>>> import numpy as np
>>> np.zeros(4, dtype=bfloat16)
array([0, 0, 0, 0], dtype=bfloat16)

Importing ml_dtypes also registers the data types with numpy, so that they may be referred to by their string name:

>>> np.dtype('bfloat16')
dtype(bfloat16)
>>> np.dtype('float8_e5m2')
dtype(float8_e5m2)

License

The ml_dtypes source code is licensed under the Apache 2.0 license (see LICENSE). Pre-compiled wheels are built with the EIGEN project, which is released under the MPL 2.0 license (see LICENSE.eigen).

About

A stand-alone implementation of several NumPy dtype extensions used in machine learning.

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages

  • C++ 69.8%
  • Python 30.2%