Wrong binary accuracy with Jax #20178
Comments
The result is correct if I cast the second parameter of …
Hi @eli-osherovich - When updating the state (`met.update_state(x>0.5, res>0.5)`), `x>0.5` and `res>0.5` are boolean arrays, but the `BinaryAccuracy` metric accepts only numerical values (floats or integers). Running the same code with the TensorFlow backend raises an error.
So the JAX backend should raise the same error when boolean inputs are passed to the `BinaryAccuracy` metric. You can open a new issue in the JAX repo to add that error message.
We could consider casting the values to …
Hi @fchollet - I will raise a PR that casts the values to `floatx()` in `update_state()`.
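The proposed fix amounts to casting both inputs to a float dtype before the metric thresholds and compares them. Below is a minimal NumPy sketch of that computation. It is not the Keras implementation: the `binary_accuracy` helper is hypothetical, and `dtype="float32"` stands in for `floatx()`, assuming the default float type.

```python
import numpy as np


def binary_accuracy(y_true, y_pred, threshold=0.5, dtype="float32"):
    """Hypothetical reference computation of binary accuracy.

    Both inputs (including boolean arrays) are cast to a float dtype first,
    mirroring the proposed cast to floatx() inside update_state().
    """
    y_true = np.asarray(y_true).astype(dtype)  # True/False -> 1.0/0.0
    y_pred = np.asarray(y_pred).astype(dtype)
    # Threshold predictions to 0.0/1.0, then compare element-wise with labels.
    y_pred_bin = (y_pred > threshold).astype(dtype)
    return float(np.mean(y_true == y_pred_bin))


if __name__ == "__main__":
    x = np.random.default_rng(0).random(1000)
    # Identical boolean targets and predictions should give exactly 1.0.
    print(binary_accuracy(x > 0.5, x > 0.5))
```

On the user side, a similar workaround in Keras 3 would be to cast the boolean arrays explicitly, e.g. with `keras.ops.cast(x > 0.5, keras.backend.floatx())`, before calling `update_state`.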
I am getting very strange results from the …
Consider the code below:
I would expect a result of 1 on every run. Instead I get a seemingly random value (close to 0.5).
Package versions (tf, keras, jax, np):