
Numerical discrepancy when calculating the Jacobian on TPU vs CPU #11069

Closed

schainmariano opened this issue Jun 11, 2022 · 2 comments

@schainmariano

Calculating a Jacobian returns substantially different answers on CPU vs. TPU.

In the example below, we compute the Jacobian, sum its rows, and compare the result to the gradient of the sum (computed with jax.grad). On CPU the two answers are almost identical; on TPU, however, there is a large discrepancy.

import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
n = 1000
m = 10000

A = jax.random.normal(rng, (n, m))


def f(x):
  return jnp.dot(A, x)

# Jacobian of f: an (n, m) matrix, computed by reverse-mode autodiff.
g = jax.jit(jax.jacrev(f))
# Gradient of sum(f(x)); mathematically equal to the Jacobian summed over its rows.
g_sum = jax.jit(jax.grad(lambda x: jnp.sum(f(x))))

x0 = jax.random.normal(rng, (m,))

print("max numerical error = ", jnp.max(jnp.abs(jnp.sum(g(x0), axis=0) - g_sum(x0))))

When running this on TPU we get

>>> max numerical error =  0.21506119

and on CPU we get

>>> max numerical error =  0.0001449585
@sharadmv
Collaborator

sharadmv commented Jun 12, 2022

I think this is a matmul precision issue.

TPUs do bf16 matmuls by default. You can use the jax.default_matmul_precision context manager/flag to set the default precision to 'float32', which should bring the results much closer to CPU (at the cost of slower matmuls).

# Matmuls traced inside this context run at float32 precision on TPU
# (x and y here stand for any two arrays being multiplied).
with jax.default_matmul_precision('float32'):
  z = jnp.dot(x, y)
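
The same knob is also available as a config flag; a minimal sketch of the global form, assuming you want the setting process-wide rather than scoped:

# Global equivalent of the context manager: affects every matmul traced
# after this point in the process.
jax.config.update('jax_default_matmul_precision', 'float32')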

Alternatively you can pass a precision directly into jax.lax.dot_general
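
For instance, a sketch of the per-call form, using the same illustrative x and y; jnp.dot forwards its precision argument to lax.dot_general:

# Per-call precision via jnp.dot, which forwards `precision` to lax.dot_general.
z = jnp.dot(x, y, precision=jax.lax.Precision.HIGHEST)

# Or call lax.dot_general directly; these dimension numbers contract the
# last axis of x against the first axis of y, with no batch axes.
z = jax.lax.dot_general(
    x, y,
    dimension_numbers=(((x.ndim - 1,), (0,)), ((), ())),
    precision=jax.lax.Precision.HIGHEST)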

@hawkinsp
Collaborator

Correct. This is almost certainly related to the default matmul precision on TPU.
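
For reference, a sketch (an editorial addition, not from the thread) applying the context manager to the repro above; the first calls, and hence the jit traces, happen inside the context, so the float32 setting takes effect:

# Re-trace the repro under float32 matmul precision; the TPU discrepancy
# should shrink to roughly the CPU level.
with jax.default_matmul_precision('float32'):
  g32 = jax.jit(jax.jacrev(f))
  g32_sum = jax.jit(jax.grad(lambda x: jnp.sum(f(x))))
  err = jnp.max(jnp.abs(jnp.sum(g32(x0), axis=0) - g32_sum(x0)))
print("max numerical error = ", err)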

Closing as a duplicate of #7010
