Calculating a Jacobian returns substantially different answers on CPU vs. TPU.
In the example below, we calculate the Jacobian of a function g, sum its rows, and compare the result to the gradient of the sum (computed with jax.grad). On CPU the two answers are almost identical; on TPU, however, there is a large discrepancy.
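The original repro script was not preserved, so here is a minimal sketch of the comparison being described. The function g, the matrix A, and the shapes are illustrative assumptions; any matmul-heavy function should show the same effect on TPU.

```python
import jax
import jax.numpy as jnp

# Hypothetical setup: a random matrix and a matmul-heavy function,
# since matmuls are what run in bfloat16 by default on TPU.
A = jax.random.normal(jax.random.PRNGKey(0), (64, 64))

def g(x):
    return jnp.tanh(A @ x)

x = jax.random.normal(jax.random.PRNGKey(1), (64,))

# Sum of the Jacobian's rows: J[i, j] = dg_i/dx_j, summed over i.
jac_row_sum = jax.jacobian(g)(x).sum(axis=0)

# The same quantity computed directly as the gradient of the summed output.
grad_of_sum = jax.grad(lambda x: g(x).sum())(x)

# On CPU these agree to roughly float32 precision; on TPU the
# default bfloat16 matmuls can make the difference much larger.
print(jnp.max(jnp.abs(jac_row_sum - grad_of_sum)))
```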
When running g on TPU we get one set of values, and on CPU we get substantially different ones.

TPUs perform bfloat16 matrix multiplications by default. You can use the jax.default_matmul_precision context manager (or the corresponding flag) to set the default precision to 'float32', which should give results much closer to the CPU ones, at the cost of running slower.
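For concreteness, a short sketch of the suggested fix, reusing the hypothetical g and x from the snippet above:

```python
import jax

# Force float32 matmul precision within this scope only.
with jax.default_matmul_precision('float32'):
    jac_row_sum = jax.jacobian(g)(x).sum(axis=0)
    grad_of_sum = jax.grad(lambda x: g(x).sum())(x)

# Alternatively, set the default globally via the config flag:
# jax.config.update('jax_default_matmul_precision', 'float32')
```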