bfloat16 @ bfloat16 -> float32? #235

Open
njzjz opened this issue Nov 13, 2024 · 1 comment

njzjz commented Nov 13, 2024

I see different behaviors in NumPy and JAX.

In NumPy,

```python
>>> import numpy as np
>>> import ml_dtypes
>>> a = np.ones((4, 4), dtype=ml_dtypes.bfloat16)
>>> a @ a
array([[4., 4., 4., 4.],
       [4., 4., 4., 4.],
       [4., 4., 4., 4.],
       [4., 4., 4., 4.]], dtype=float32)
```

The data type of the output is float32.

In JAX,

```python
>>> import jax.numpy as jnp
>>> b = jnp.asarray(a)
>>> b @ b
Array([[4, 4, 4, 4],
       [4, 4, 4, 4],
       [4, 4, 4, 4],
       [4, 4, 4, 4]], dtype=bfloat16)
```

The data type of the output is bfloat16.

Could you clarify which output dtype is intended here? I couldn't find any documentation about it, so I am unsure whether this is a bug or a feature.

jakevdp (Collaborator) commented Nov 13, 2024

The ml_dtypes dot implementation outputs bfloat16:

```cpp
// Dot product of two strided vectors of length n whose elements are the
// custom float type T (e.g. bfloat16).
template <typename T>
void NPyCustomFloat_DotFunc(void* ip1, npy_intp is1, void* ip2, npy_intp is2,
                            void* op, npy_intp n, void* arr) {
  char* c1 = reinterpret_cast<char*>(ip1);
  char* c2 = reinterpret_cast<char*>(ip2);
  float acc = 0.0f;  // products are accumulated in float32
  for (npy_intp i = 0; i < n; ++i) {
    T* const b1 = reinterpret_cast<T*>(c1);
    T* const b2 = reinterpret_cast<T*>(c2);
    acc += static_cast<float>(*b1) * static_cast<float>(*b2);
    c1 += is1;  // is1 and is2 are strides in bytes
    c2 += is2;
  }
  T* out = reinterpret_cast<T*>(op);
  *out = static_cast<T>(acc);  // the result is stored back as T, not float
}
```

This is what NumPy calls into for `np.matmul`, so something on NumPy's side appears to be doing the implicit conversion to float32. That said, `np.dot` does output bfloat16:

```python
In [1]: import numpy as np

In [2]: import ml_dtypes

In [3]: x = np.ones((4, 4), dtype=ml_dtypes.bfloat16)

In [4]: np.dot(x, x)
Out[4]:
array([[4, 4, 4, 4],
       [4, 4, 4, 4],
       [4, 4, 4, 4],
       [4, 4, 4, 4]], dtype=bfloat16)
```
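The quoted `NPyCustomFloat_DotFunc` accumulates in `float` and converts to `T` only once, at the end. The following pure-Python sketch mirrors that loop and contrasts it with rounding to bfloat16 after every step. It uses `struct` to emulate float32 and bfloat16 rounding. The helper names `f32`, `bf16`, `custom_float_dot` and `rounded_each_step_dot` are hypothetical, and the sketch assumes round-half-to-even for the final cast:

```python
import math
import struct


def f32(x: float) -> float:
    """Round a Python float (a C double) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def bf16(x: float) -> float:
    """Round a float32 value to bfloat16, ties to even.

    bfloat16 keeps the upper 16 bits of the float32 bit pattern, so we round
    away the lower 16 bits. Sketch only: NaN is passed through, and values
    outside the float32 range are not handled.
    """
    if math.isnan(x):
        return x
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 16) & 1  # lowest bit that survives truncation
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def custom_float_dot(x, y):
    """Mirror the quoted dot function: float32 accumulator, one final rounding."""
    acc = 0.0
    for a, b in zip(x, y):
        # acc += static_cast<float>(*b1) * static_cast<float>(*b2);
        acc = f32(acc + f32(f32(a) * f32(b)))
    # *out = static_cast<T>(acc);
    return bf16(acc)


def rounded_each_step_dot(x, y):
    """Contrast: round the product and the running sum to bfloat16 every step."""
    acc = 0.0
    for a, b in zip(x, y):
        acc = bf16(f32(acc + bf16(f32(a * b))))
    return acc


v = [16.0, 1.0, 1.0]  # every entry is exactly representable in bfloat16
print(custom_float_dot(v, v))       # 258.0: 256 + 1 + 1 is exact in float32 and in bfloat16
print(rounded_each_step_dot(v, v))  # 256.0: 257 is a tie and rounds to even (256) each time
```

Accumulating in float32 keeps the intermediate sum 257 exact, so only the final result is rounded; with per-step rounding the two `+ 1` contributions are lost.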

In JAX, the output dtype will always match the input dtype unless otherwise specified (for example, via the `preferred_element_type` argument of `jnp.matmul` or `jnp.dot`).
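A sketch of making the result dtype explicit on both sides, assuming `jax` and `ml_dtypes` are installed (the variable names are only illustrative): `jnp.matmul` accepts a `preferred_element_type` argument, and in NumPy one can cast to float32 before the product and round back to bfloat16 afterwards, rather than depending on whether `@` or `np.dot` is used:

```python
import jax.numpy as jnp
import ml_dtypes
import numpy as np

a = np.ones((4, 4), dtype=ml_dtypes.bfloat16)
b = jnp.asarray(a)

# JAX: by default the product keeps the input dtype (bfloat16 here) ...
default_out = b @ b
# ... and preferred_element_type requests a float32 result instead.
wide_out = jnp.matmul(b, b, preferred_element_type=jnp.float32)

# NumPy: spell out the dtypes rather than relying on which code path runs.
np_f32 = a.astype(np.float32) @ a.astype(np.float32)  # float32 result
np_bf16 = np_f32.astype(ml_dtypes.bfloat16)           # rounded back to bfloat16

print(default_out.dtype, wide_out.dtype, np_f32.dtype, np_bf16.dtype)
```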

jakevdp self-assigned this Nov 13, 2024