Division by self not always "1.0" on JAX GPU, but consistently gives "1.0" on JAX CPU. #24807

Open
mattlevine22 opened this issue Nov 8, 2024 · 1 comment
Labels
bug Something isn't working


@mattlevine22

Description

Division by self (x / x) does not always give "1.0" on JAX GPU, but seems to consistently give "1.0" on JAX CPU. This is observed for JAX arrays, but not for Python floats.

Question: Is this expected in JAX? If so, should JAX users build in robustness against GPU/CPU floating-point differences (including the specific case where x / x ≠ 1)?
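One common way to build that robustness, sketched here under the assumption that bit-exact equality is not required, is to compare against 1.0 with a tolerance scaled to the dtype's machine epsilon instead of using `==`. The helper `is_close_to_one` and the default of 4 epsilons are hypothetical illustrations, not a JAX recommendation.

```python
import jax.numpy as jnp


def is_close_to_one(x, ulps=4):
    """Hypothetical helper: True where x lies within `ulps` machine epsilons of 1.0."""
    x = jnp.asarray(x)
    eps = jnp.finfo(x.dtype).eps  # gap between 1.0 and the next larger value of x's dtype
    return jnp.abs(x - 1.0) <= ulps * eps


# float32 results of z / z taken from the GPU output in this report.
gpu_results = jnp.array([1.0000001192092896, 0.9999999403953552, 1.0], dtype=jnp.float32)

print(gpu_results == 1.0)            # exact equality rejects the first two entries
print(is_close_to_one(gpu_results))  # the tolerance check accepts all three
```

The built-in `jnp.isclose` / `jnp.allclose` with explicit `rtol` and `atol` arguments serve the same purpose when a relative tolerance is preferred.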

Code to reproduce the issue

import jax
import jax.numpy as jnp

# JAX device check
print("************* Checking JAX device *************")

print("Running on jax device:{}".format(jax.devices()))
print("Running on jax device platform:{}".format(jax.devices()[0].platform))
print("***********************************************")


def self_div(x):
    return x / x

z_list = [0.0526315718889236, 0.987654321, 1.0, 1e-13, 1e-14, 1e-15]
for z in z_list:
    # Compute z/z for the Python float and print it with 16 decimal places
    print(f"z/z = {self_div(z):.16f} for z = {z} of type {type(z)}")

    # Now convert z to a JAX array and repeat the computation
    z = jnp.array(z)
    print(f"z/z = {self_div(z):.16f} for z = {z} of type {type(z)}")
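A quick way to see why the float and array cases differ is to check the dtypes involved: Python floats are IEEE float64, while `jnp.array(z)` produces float32 unless the `jax_enable_x64` flag is set. The sketch below also uses NumPy's float32 division on the host CPU as a reference, assuming (as on typical x86-64 hardware) that it is IEEE 754 correctly rounded. Under correct rounding, x / x for finite nonzero x is exactly 1.0, so any other result points to a division that is not correctly rounded.

```python
import jax.numpy as jnp
import numpy as np

z_list = [0.0526315718889236, 0.987654321, 1.0, 1e-13, 1e-14, 1e-15]

for z in z_list:
    z_jax = jnp.array(z)
    # Python floats are float64; JAX arrays default to float32 unless jax_enable_x64 is set.
    print(f"{type(z).__name__} -> {z_jax.dtype}")

    # Reference: NumPy float32 division on the host CPU (assumed correctly rounded).
    z32 = np.float32(z)
    ref = z32 / z32
    got = np.asarray(z_jax / z_jax)
    print(f"z = {z!r}: numpy float32 z/z = {float(ref)!r}, jax z/z = {float(got)!r}, match = {bool(ref == got)}")
```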

CPU environment

jax==0.4.35
jaxlib==0.4.35
ml_dtypes==0.5.0
numpy==2.1.3
opt_einsum==3.4.0
scipy==1.14.1

Running in the CPU environment yields:

************* Checking JAX device *************
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
Running on jax device:[CpuDevice(id=0)]
Running on jax device platform:cpu
***********************************************
z/z = 1.0000000000000000 for z = 0.0526315718889236 of type <class 'float'>
z/z = 1.0000000000000000 for z = 0.052631571888923645 of type <class 'jaxlib.xla_extension.ArrayImpl'>
z/z = 1.0000000000000000 for z = 0.987654321 of type <class 'float'>
z/z = 1.0000000000000000 for z = 0.9876543283462524 of type <class 'jaxlib.xla_extension.ArrayImpl'>
z/z = 1.0000000000000000 for z = 1.0 of type <class 'float'>
z/z = 1.0000000000000000 for z = 1.0 of type <class 'jaxlib.xla_extension.ArrayImpl'>
z/z = 1.0000000000000000 for z = 1e-13 of type <class 'float'>
z/z = 1.0000000000000000 for z = 9.9999998245167e-14 of type <class 'jaxlib.xla_extension.ArrayImpl'>
z/z = 1.0000000000000000 for z = 1e-14 of type <class 'float'>
z/z = 1.0000000000000000 for z = 9.9999998245167e-15 of type <class 'jaxlib.xla_extension.ArrayImpl'>
z/z = 1.0000000000000000 for z = 1e-15 of type <class 'float'>
z/z = 1.0000000000000000 for z = 1.0000000036274937e-15 of type <class 'jaxlib.xla_extension.ArrayImpl'>
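The array rows above print values such as 0.9876543283462524 rather than 0.987654321 because the array stores z as float32. The sketch below reproduces that rounding with NumPy, assuming JAX's default conversion (x64 disabled) rounds to the nearest float32 as NumPy does, and shows that the relative error stays within float32's unit roundoff of 2^-24.

```python
import numpy as np

# Inputs from the reproduction script; Python floats are float64.
z_list = [0.0526315718889236, 0.987654321, 1.0, 1e-13, 1e-14, 1e-15]

for z in z_list:
    z32 = np.float32(z)      # round to the nearest float32 (assumed to match jnp.array(z))
    widened = float(z32)     # widen back to a Python float so all digits print
    rel_err = abs(widened - z) / z
    print(f"{z!r} -> {widened!r} (relative rounding error {rel_err:.1e})")
```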

GPU environment

jax==0.4.35
jax-cuda12-pjrt==0.4.35
jax-cuda12-plugin==0.4.35
jaxlib==0.4.34
ml_dtypes==0.5.0
numpy==2.1.3
nvidia-cublas-cu12==12.6.3.3
nvidia-cuda-cupti-cu12==12.6.80
nvidia-cuda-nvcc-cu12==12.6.77
nvidia-cuda-runtime-cu12==12.6.77
nvidia-cudnn-cu12==9.5.1.17
nvidia-cufft-cu12==11.3.0.4
nvidia-cusolver-cu12==11.7.1.2
nvidia-cusparse-cu12==12.5.4.2
nvidia-nccl-cu12==2.23.4
nvidia-nvjitlink-cu12==12.6.77
opt_einsum==3.4.0
scipy==1.14.1

Running in the GPU environment yields:

************* Checking JAX device *************
Running on jax device:[CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2), CudaDevice(id=3)]
Running on jax device platform:gpu
***********************************************
z/z = 1.0000000000000000 for z = 0.0526315718889236 of type <class 'float'>
z/z = 1.0000001192092896 for z = 0.052631571888923645 of type <class 'jaxlib.xla_extension.ArrayImpl'>
z/z = 1.0000000000000000 for z = 0.987654321 of type <class 'float'>
z/z = 0.9999999403953552 for z = 0.9876543283462524 of type <class 'jaxlib.xla_extension.ArrayImpl'>
z/z = 1.0000000000000000 for z = 1.0 of type <class 'float'>
z/z = 1.0000000000000000 for z = 1.0 of type <class 'jaxlib.xla_extension.ArrayImpl'>
z/z = 1.0000000000000000 for z = 1e-13 of type <class 'float'>
z/z = 1.0000001192092896 for z = 9.9999998245167e-14 of type <class 'jaxlib.xla_extension.ArrayImpl'>
z/z = 1.0000000000000000 for z = 1e-14 of type <class 'float'>
z/z = 1.0000000000000000 for z = 9.9999998245167e-15 of type <class 'jaxlib.xla_extension.ArrayImpl'>
z/z = 1.0000000000000000 for z = 1e-15 of type <class 'float'>
z/z = 1.0000000000000000 for z = 1.0000000036274937e-15 of type <class 'jaxlib.xla_extension.ArrayImpl'>
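A further check one could run (a suggestion, not something tried in this thread) is to repeat the experiment in 64-bit mode, where `jnp.array(z)` keeps the float64 value, alongside explicit float32 arrays. Running it on the GPU would show whether the discrepancy is specific to float32 division. The sketch uses the `jax_enable_x64` configuration option, which should be set at startup before any arrays are created; no result for the GPU is claimed here.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit types; set this at startup, before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

z_list = [0.0526315718889236, 0.987654321, 1.0, 1e-13, 1e-14, 1e-15]

for z in z_list:
    z64 = jnp.array(z)                     # float64 now that x64 is enabled
    z32 = jnp.array(z, dtype=jnp.float32)  # explicit float32 for comparison
    print(
        f"z = {z!r}: float64 z/z = {float(z64 / z64):.16f}, "
        f"float32 z/z = {float(z32 / z32):.16f} "
        f"on {jax.devices()[0].platform}"
    )
```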

System info (python version, jaxlib version, accelerator, etc.)

System info for the GPU environment

Python 3.10.15 (main, Oct  3 2024, 07:27:34) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax; jax.print_environment_info()
jax:    0.4.35
jaxlib: 0.4.34
numpy:  2.1.3
python: 3.10.15 (main, Oct  3 2024, 07:27:34) [GCC 11.2.0]
device info: NVIDIA RTX A5000-4, 4 local devices"
process_count: 1
platform: uname_result(system='Linux', node='1sum-701-c10-vector01', release='5.15.0-124-generic', version='#134-Ubuntu SMP Fri Sep 27 20:20:17 UTC 2024', machine='x86_64')


$ nvidia-smi
Fri Nov  8 15:57:58 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.120                Driver Version: 550.120        CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA RTX A5000               Off |   00000000:01:00.0 Off |                  Off |
| 30%   31C    P2             24W /  230W |     220MiB /  24564MiB |      2%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA RTX A5000               Off |   00000000:2C:00.0 Off |                  Off |
| 30%   40C    P2             26W /  230W |     220MiB /  24564MiB |      1%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA RTX A5000               Off |   00000000:41:00.0 Off |                  Off |
| 30%   39C    P2             28W /  230W |     220MiB /  24564MiB |      2%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA RTX A5000               Off |   00000000:61:00.0 Off |                  Off |
| 30%   37C    P2             28W /  230W |     220MiB /  24564MiB |      2%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      1887      G   /usr/lib/xorg/Xorg                              4MiB |
|    0   N/A  N/A    274456      C   python                                        204MiB |
|    1   N/A  N/A      1887      G   /usr/lib/xorg/Xorg                              4MiB |
|    1   N/A  N/A    274456      C   python                                        204MiB |
|    2   N/A  N/A      1887      G   /usr/lib/xorg/Xorg                              4MiB |
|    2   N/A  N/A    274456      C   python                                        204MiB |
|    3   N/A  N/A      1887      G   /usr/lib/xorg/Xorg                              4MiB |
|    3   N/A  N/A    274456      C   python                                        204MiB |
+-----------------------------------------------------------------------------------------+

System info for the CPU environment

Python 3.10.15 (main, Oct  3 2024, 07:27:34) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax; jax.print_environment_info()
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
jax:    0.4.35
jaxlib: 0.4.35
numpy:  2.1.3
python: 3.10.15 (main, Oct  3 2024, 07:27:34) [GCC 11.2.0]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='1sum-701-c10-vector01', release='5.15.0-124-generic', version='#134-Ubuntu SMP Fri Sep 27 20:20:17 UTC 2024', machine='x86_64')


$ nvidia-smi
Fri Nov  8 15:58:50 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.120                Driver Version: 550.120        CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA RTX A5000               Off |   00000000:01:00.0 Off |                  Off |
| 30%   31C    P8             20W /  230W |      10MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA RTX A5000               Off |   00000000:2C:00.0 Off |                  Off |
| 30%   41C    P8             22W /  230W |      10MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA RTX A5000               Off |   00000000:41:00.0 Off |                  Off |
| 30%   40C    P8             21W /  230W |      10MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA RTX A5000               Off |   00000000:61:00.0 Off |                  Off |
| 30%   37C    P8             21W /  230W |      10MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      1887      G   /usr/lib/xorg/Xorg                              4MiB |
|    1   N/A  N/A      1887      G   /usr/lib/xorg/Xorg                              4MiB |
|    2   N/A  N/A      1887      G   /usr/lib/xorg/Xorg                              4MiB |
|    3   N/A  N/A      1887      G   /usr/lib/xorg/Xorg                              4MiB |
+-----------------------------------------------------------------------------------------+
mattlevine22 added the bug (Something isn't working) label Nov 8, 2024
@jakevdp
Collaborator

jakevdp commented Nov 8, 2024

Hi - thanks for the report! Though this may be a bit surprising, I think it's consistent with the expected precision of floating point operations: namely, the outputs in each case are within 1 ULP of the true result at float32 precision.
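The 1 ULP claim can be checked directly against the reported GPU values. The sketch below compares them with the float32 neighbours of 1.0 obtained from NumPy's `np.nextafter`. Note that the gap just below 1.0 (2^-24) is half the gap just above it (2^-23), which is why the two deviations differ in size while both being 1 ULP.

```python
import numpy as np

one = np.float32(1.0)
above = np.nextafter(one, np.float32(2.0))  # smallest float32 greater than 1.0
below = np.nextafter(one, np.float32(0.0))  # largest float32 less than 1.0

# Results of z / z reported on the GPU, printed there with 16 decimal places.
reported = [1.0000001192092896, 0.9999999403953552]

print(float(above) == 1 + 2**-23, float(below) == 1 - 2**-24)
for r in reported:
    r32 = np.float32(r)
    print(f"{r:.16f}: one ULP above 1.0 -> {r32 == above}, one ULP below 1.0 -> {r32 == below}")
```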

I think @hawkinsp may be able to say more.
