I see different output dtypes from NumPy and JAX for the same operation.

In NumPy, the data type of the output is `float32`.

In JAX, the data type of the output is `bfloat16`.

I need clarification about this behavior. I couldn't find any documentation about it, so I am unsure whether it is a bug or a feature.

This is what NumPy calls into for `np.matmul`, so something on NumPy's side appears to be doing this implicit conversion. That said, `np.dot` outputs `bfloat16`.
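Below is a minimal sketch for reproducing the comparison. It assumes the operation is a matrix multiplication of `bfloat16` inputs (the original snippets are not shown here) and that `ml_dtypes`, the package JAX uses for its `bfloat16` type, is installed. The shapes are arbitrary. The sketch prints the result dtype of `np.matmul`, `np.dot`, and `jnp.matmul`. The NumPy results may depend on the installed NumPy and `ml_dtypes` versions, so the sketch does not assert any particular NumPy dtype. The last call uses JAX's `preferred_element_type` argument to request a `float32` result explicitly.

```python
import numpy as np
import jax.numpy as jnp
import ml_dtypes  # assumed installed; supplies the bfloat16 dtype shared by NumPy and JAX

# Two small bfloat16 matrices (shapes chosen arbitrarily for illustration).
a = np.ones((2, 3), dtype=ml_dtypes.bfloat16)
b = np.ones((3, 4), dtype=ml_dtypes.bfloat16)

# NumPy side: matmul and dot may not agree on the output dtype.
np_matmul = np.matmul(a, b)
np_dot = np.dot(a, b)

# JAX side: matmul keeps the (promoted) input dtype by default.
jnp_matmul = jnp.matmul(jnp.asarray(a), jnp.asarray(b))

# JAX side: ask for a float32 result explicitly.
jnp_f32 = jnp.matmul(jnp.asarray(a), jnp.asarray(b),
                     preferred_element_type=jnp.float32)

print("np.matmul :", np_matmul.dtype)
print("np.dot    :", np_dot.dtype)
print("jnp.matmul:", jnp_matmul.dtype)
print("jnp.matmul(preferred_element_type=float32):", jnp_f32.dtype)
```

If the two libraries need to agree, one option is to make the dtype explicit on one side: pass `preferred_element_type` in JAX, or cast the NumPy result with `.astype(ml_dtypes.bfloat16)`.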