Different results on different GPUs #28

Closed
jaeminoh opened this issue Mar 20, 2024 · 4 comments

@jaeminoh

Hi f0uriest,

I encountered an issue where interpolation results vary across machines.

I used a 1d interpolator with the monotonic method, allowing extrap=True.

test machines: [CPU, RTX Titan, RTX 4090].
reference machine: CPU with double precision (x64).

The table below presents the relative $L^1$ error, abs(a - b).sum() / abs(b).sum():

| precision | CPU | RTX Titan | RTX 4090 |
|---|---|---|---|
| x32 | 5.87719e-08 | 5.89367e-08 | 1.78212e-04 |
| x64 | reference | 4.16375e-17 | 4.16375e-17 |
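For readers who want to recompute the metric used in the table, here is a minimal NumPy sketch of the relative $L^1$ error. The function name `relative_l1_error` and the sample data are hypothetical and are not taken from the original test; they only illustrate the formula quoted above.

```python
import numpy as np


def relative_l1_error(a, b):
    """Relative L1 error of `a` against the reference `b`."""
    # Accumulate in float64 so the metric itself does not add rounding noise.
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    return np.abs(a - b).sum() / np.abs(b).sum()


# Hypothetical data: a float64 reference and a float32 copy of it.
reference = np.sin(np.linspace(0.1, 1.0, 1000))
single = reference.astype(np.float32)

err = relative_l1_error(single, reference)
print(f"relative L1 error: {err:.3e}")
```

Pure float32 rounding of the inputs gives an error below roughly 6e-08 here, the same order as the x32 entries for the CPU and RTX Titan.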

Since I used the same (xq, xp, yp) on every machine, the errors within each row should coincide.

However, as you can see, interpolation on RTX 4090 with single precision produced quite an inaccurate result.

Do you have any ideas on this?

@f0uriest
Owner

f0uriest commented Mar 22, 2024

Do you know if this is specific to interpax? It's likely it's a more general JAX issue (or really a CUDA/XLA issue) that things get compiled differently for different hardware, see google/jax#20371 and google/jax#10674 (comment)

Also, is the error uniformly bad for all points being interpolated, or is it localized in some way?

@jaeminoh
Author

Hi! Thank you for the reply.

I believe it is a general JAX issue, since I could not find any machine-specific implementation in interpax.
But I don't know where to start fixing it 😅

I have attached two images showing the relative pointwise error abs(a - b) / abs(b).

[Image: 4090_x32]

This is for the 4090 with single precision,

[Image: 4090_x64]

and this is for the 4090 with double precision.
Numbers on the axes are just indices.

Along the left vertical edge of the figures, xq is monotonically increasing (from 0 to 1).
So I would say that the error is uniformly bad.
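One way to check "uniform versus localized" numerically rather than by eye is to compare the median and the maximum of the pointwise error. The sketch below is only an illustration with synthetic data: the helper names `pointwise_relative_error` and `summarize`, and the arrays `b`, `a_uniform`, and `a_local`, are hypothetical and do not come from the original experiment.

```python
import numpy as np


def pointwise_relative_error(a, b):
    """Pointwise relative error abs(a - b) / abs(b); assumes b has no zeros."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    return np.abs(a - b) / np.abs(b)


def summarize(err):
    """Median, maximum, and location of the maximum of an error array."""
    return {
        "median": float(np.median(err)),
        "max": float(np.max(err)),
        "argmax": int(np.argmax(err)),
    }


# Synthetic reference and two kinds of perturbation (hypothetical data).
b = np.linspace(1.0, 2.0, 1000)
a_uniform = b * (1.0 + 1e-4)   # every point is off by the same relative amount
a_local = b.copy()
a_local[500] *= 1.01           # a single bad point

uniform = summarize(pointwise_relative_error(a_uniform, b))
local = summarize(pointwise_relative_error(a_local, b))
print(uniform)
print(local)
```

A max close to the median suggests a uniform loss of precision, while a max far above the median points to a localized problem at the reported index.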

In fact, my query points xq were loaded from Excel files using pandas.read_excel, so this could have been the cause.
So I switched my query points xq to numpy.linspace(0, 1, 1000), and observed the same issue again.
On the Titan the relative $L^1$ error is $\approx 10^{-6}$, but on the 4090 it is $\approx 10^{-2}$, where the baseline is the CPU result with x64 arithmetic.

@f0uriest
Owner

Can you share some code/data that reproduces the issue? I don't have access to either of those GPUs, but I can try some others and see if it's a more general issue.

@jaeminoh
Author

jaeminoh commented Apr 4, 2024

Hi, I think I found the cause.

When I ran the test with
NVIDIA_TF32_OVERRIDE=0,
I got the correct result:

[Image: overpotential_rtx_4090]
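The variable can be set in the shell when launching the script, or from Python as sketched below. This is a minimal sketch, assuming the variable has to be in the environment before the CUDA libraries are initialized, which is why it is set before JAX is imported. The guarded import is only there so the snippet also runs on a machine without JAX.

```python
import os

# Assumption: the CUDA libraries read this variable when they initialize,
# so it is set before importing jax.
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"

try:
    import jax  # imported only after the environment variable is in place
except ImportError:
    jax = None  # lets the sketch run without JAX installed

print("NVIDIA_TF32_OVERRIDE =", os.environ["NVIDIA_TF32_OVERRIDE"])
```

A JAX-level alternative that may be worth trying is the `jax_default_matmul_precision` configuration option (or the `jax.default_matmul_precision` context manager), although the thread does not establish whether it resolves this particular case.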

It might be related to this issue (default TF32 overriding of JAX):
google/jax#7010 (comment)
patrick-kidger/diffrax#213 (comment)
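To give a feeling for why TF32 could produce errors of this size: TF32 keeps the 8-bit exponent of float32 but only 10 explicit mantissa bits instead of 23. The NumPy sketch below emulates that loss by clearing the low 13 mantissa bits of float32 values. It is a rough illustration only: the helper `round_to_tf32` is hypothetical, it truncates instead of reproducing the hardware's actual rounding, and it says nothing about which operation inside the interpolation is affected.

```python
import numpy as np


def round_to_tf32(x):
    """Emulate TF32 input precision by truncating float32 to 10 mantissa bits."""
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    mask = np.uint32(0xFFFFE000)  # clears the low 13 of the 23 mantissa bits
    return (bits & mask).view(np.float32)


# Hypothetical sample values in the same range as the query points.
x = np.linspace(0.1, 1.0, 1000, dtype=np.float32)
x_tf32 = round_to_tf32(x)

rel = np.abs(x_tf32.astype(np.float64) - x.astype(np.float64)) / np.abs(x)
print(f"max relative error:  {rel.max():.3e}")
print(f"mean relative error: {rel.mean():.3e}")
```

The truncation error is bounded by $2^{-10} \approx 9.8 \times 10^{-4}$, far above float32 rounding, which is at least consistent in order of magnitude with the x32 error of 1.78212e-04 reported for the RTX 4090.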

@jaeminoh jaeminoh closed this as completed Apr 4, 2024