jnp.linalg.solve() not precise on TPU #13224

Closed
xinglong-li opened this issue Nov 13, 2022 · 1 comment
Labels: bug (Something isn't working), TPU (Issues related to compilation for or performance on Google TPUs)

Comments

xinglong-li commented Nov 13, 2022

Description

I'm implementing my JAX model on TPU, and the model involves a step that computes the inverse of a matrix. I use jnp.linalg.solve() and found that the result is not precise enough. Here is an example that reproduces the issue:

import jax.numpy as jnp

X = jnp.array([[280., 361., 145.],
               [361., 657., 247.],
               [145., 247., 99.]])

print(X @ jnp.linalg.solve(X, jnp.eye(3)))

and the output is

[[ 0.9968567  -0.0453186  -0.02929688]
 [-0.01763916  0.9393616  -0.078125  ]
 [-0.00104141 -0.01418686  0.93688965]]

which is relatively 'far' from the identity matrix (and not symmetric), making the subsequent steps of my model incorrect.
Given that the matrix to be inverted is neither large nor extreme, I would expect its inverse to be computed accurately. (It is accurate when executed on CPU.)
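
One way to compare against the CPU result from within the same TPU runtime is to pin the arrays to a CPU device. A rough sketch (it assumes a CPU backend is available alongside the TPU, which is normally the case):

import jax
import jax.numpy as jnp

X = jnp.array([[280., 361., 145.],
               [361., 657., 247.],
               [145., 247., 99.]])

# Assumption: a CPU backend is present next to the TPU backend.
cpu = jax.devices('cpu')[0]
X_cpu = jax.device_put(X, cpu)
I_cpu = jax.device_put(jnp.eye(3), cpu)

# Both the solve and the matmul now run on the CPU backend.
print(X_cpu @ jnp.linalg.solve(X_cpu, I_cpu))
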
@murphyk

What jax/jaxlib version are you using?

jax v0.3.24

Which accelerator(s) are you using?

TPU

Additional system info

Ubuntu 20.04.5 LTS

NVIDIA GPU info

No response

hawkinsp (Collaborator) commented Nov 14, 2022

Thanks for the question.

It's actually not the solve that is imprecise. It's the @ (matrix multiplication) that is imprecise on TPU!

One way to demonstrate this is to perform the matrix multiplication using classic NumPy on CPU:

>>> import numpy as np
>>> np.asarray(X) @ np.asarray(jnp.linalg.solve(X, jnp.eye(3)))

array([[ 9.9999976e-01, -9.5367432e-07,  1.9073486e-06],
       [-4.7683716e-07,  1.0000000e+00, -3.8146973e-06],
       [ 0.0000000e+00,  4.7683716e-07,  9.9999809e-01]], dtype=float32)

i.e., the result of the solve is actually pretty good.

By default, the TPU performs float32 matrix multiplications by truncating the inputs to bfloat16. Something similar happens on modern GPUs (Ampere and newer), which perform "TF32" math; it is not exactly the same, but it is similar in spirit.
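
To see roughly what that truncation does to this example, here is a sketch that emulates it by rounding the inputs through bfloat16 before a full-precision multiply (an approximation for illustration only; the real TPU matmul path accumulates in float32 and is more involved than a plain cast):

import jax.numpy as jnp

X = jnp.array([[280., 361., 145.],
               [361., 657., 247.],
               [145., 247., 99.]])
X_inv = jnp.linalg.solve(X, jnp.eye(3))

def truncate(a):
    # Round to bfloat16 and back to float32: emulates the input truncation only.
    return a.astype(jnp.bfloat16).astype(jnp.float32)

# Run on CPU (or with full-precision matmuls), this shows errors of roughly
# the same magnitude as the TPU result above.
print(truncate(X) @ truncate(X_inv))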

To get higher precision on TPU, you can opt into a higher precision matrix multiplication (https://jax.readthedocs.io/en/latest/_autosummary/jax.default_matmul_precision.html?highlight=matmul_precision).

For example:

import jax
import jax.numpy as jnp

jax.config.update('jax_default_matmul_precision', 'float32')

X = jnp.array([[280., 361., 145.],
               [361., 657., 247.],
               [145., 247., 99.]])

print(X @ jnp.linalg.solve(X, jnp.eye(3)))

prints:

[[ 9.9999988e-01 -2.1979213e-07  1.7136335e-06]
 [-9.1269612e-08  1.0000005e+00 -7.4505806e-08]
 [-2.5611371e-09  3.1106174e-07  9.9999952e-01]]
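
The same thing can also be done in a more scoped way, either with the context-manager form of jax.default_matmul_precision or per operation via the precision argument. A short sketch along the same lines:

import jax
import jax.numpy as jnp

X = jnp.array([[280., 361., 145.],
               [361., 657., 247.],
               [145., 247., 99.]])
X_inv = jnp.linalg.solve(X, jnp.eye(3))

# Scoped: only matmuls inside this block use full float32 precision.
with jax.default_matmul_precision('float32'):
    print(X @ X_inv)

# Per operation: request the highest available precision for this one matmul.
print(jnp.matmul(X, X_inv, precision=jax.lax.Precision.HIGHEST))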

There's an open feature request for changing the default (#7010).

I hope that helps!

hawkinsp added the TPU label on Nov 14, 2022