Description
Division by self is not always 1.0 on the JAX GPU backend, but seems to consistently give 1.0 on the JAX CPU backend. This is observed for JAX arrays, but not for Python floats.
Question: is this expected in JAX? If so, is the expectation that a JAX user should build robustness against GPU/CPU floating-point differences in general (as well as against the specific x / x ≠ 1 case)?
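For reference, the usual way to build in that kind of robustness is to compare against 1.0 within a tolerance rather than for exact equality. A minimal sketch, assuming a relative tolerance of 1e-6 is acceptable for the use case (the tolerance value is illustrative, not from this report):

import jax.numpy as jnp

x = jnp.asarray(0.0526315718889236)  # float32 by default (x64 disabled)
ratio = x / x
# Exact equality can differ between CPU and GPU backends; a relative
# tolerance of 1e-6 comfortably covers a 1-ULP float32 error (~1.2e-7).
print(jnp.allclose(ratio, 1.0, rtol=1e-6))  # True on both CPU and GPU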
Code for observing the issue
import jax
import jax.numpy as jnp

# JAX device check
print("************* Checking JAX device *************")
print("Running on jax device:{}".format(jax.devices()))
print("Running on jax device platform:{}".format(jax.devices()[0].platform))
print("***********************************************")

def self_div(x):
    return x / x

z_list = [0.0526315718889236, 0.987654321, 1.0, 1e-13, 1e-14, 1e-15]

for z in z_list:
    # divide z by itself as a Python float and print with 16 decimal places
    print(f"z/z = {self_div(z):.16f} for z = {z} of type {type(z)}")
    # repeat with z converted to a JAX array (float32 by default)
    z = jnp.array(z)
    print(f"z/z = {self_div(z):.16f} for z = {z} of type {type(z)}")
CPU environment
Running in the CPU environment yields:
************* Checking JAX device *************
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
Running on jax device:[CpuDevice(id=0)]
Running on jax device platform:cpu
***********************************************
z/z = 1.0000000000000000 for z = 0.0526315718889236 of type <class 'float'>
z/z = 1.0000000000000000 for z = 0.052631571888923645 of type <class 'jaxlib.xla_extension.ArrayImpl'>
z/z = 1.0000000000000000 for z = 0.987654321 of type <class 'float'>
z/z = 1.0000000000000000 for z = 0.9876543283462524 of type <class 'jaxlib.xla_extension.ArrayImpl'>
z/z = 1.0000000000000000 for z = 1.0 of type <class 'float'>
z/z = 1.0000000000000000 for z = 1.0 of type <class 'jaxlib.xla_extension.ArrayImpl'>
z/z = 1.0000000000000000 for z = 1e-13 of type <class 'float'>
z/z = 1.0000000000000000 for z = 9.9999998245167e-14 of type <class 'jaxlib.xla_extension.ArrayImpl'>
z/z = 1.0000000000000000 for z = 1e-14 of type <class 'float'>
z/z = 1.0000000000000000 for z = 9.9999998245167e-15 of type <class 'jaxlib.xla_extension.ArrayImpl'>
z/z = 1.0000000000000000 for z = 1e-15 of type <class 'float'>
z/z = 1.0000000000000000 for z = 1.0000000036274937e-15 of type <class 'jaxlib.xla_extension.ArrayImpl'>
GPU environment
Running in the GPU environment yields:
************* Checking JAX device *************
Running on jax device:[CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2), CudaDevice(id=3)]
Running on jax device platform:gpu
***********************************************
z/z = 1.0000000000000000 for z = 0.0526315718889236 of type <class 'float'>
z/z = 1.0000001192092896 for z = 0.052631571888923645 of type <class 'jaxlib.xla_extension.ArrayImpl'>
z/z = 1.0000000000000000 for z = 0.987654321 of type <class 'float'>
z/z = 0.9999999403953552 for z = 0.9876543283462524 of type <class 'jaxlib.xla_extension.ArrayImpl'>
z/z = 1.0000000000000000 for z = 1.0 of type <class 'float'>
z/z = 1.0000000000000000 for z = 1.0 of type <class 'jaxlib.xla_extension.ArrayImpl'>
z/z = 1.0000000000000000 for z = 1e-13 of type <class 'float'>
z/z = 1.0000001192092896 for z = 9.9999998245167e-14 of type <class 'jaxlib.xla_extension.ArrayImpl'>
z/z = 1.0000000000000000 for z = 1e-14 of type <class 'float'>
z/z = 1.0000000000000000 for z = 9.9999998245167e-15 of type <class 'jaxlib.xla_extension.ArrayImpl'>
z/z = 1.0000000000000000 for z = 1e-15 of type <class 'float'>
z/z = 1.0000000000000000 for z = 1.0000000036274937e-15 of type <class 'jaxlib.xla_extension.ArrayImpl'>
System info (python version, jaxlib version, accelerator, etc.)
System Info on GPU environment
Python 3.10.15 (main, Oct 3 2024, 07:27:34) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax; jax.print_environment_info()
jax: 0.4.35
jaxlib: 0.4.34
numpy: 2.1.3
python: 3.10.15 (main, Oct 3 2024, 07:27:34) [GCC 11.2.0]
device info: NVIDIA RTX A5000-4, 4 local devices
process_count: 1
platform: uname_result(system='Linux', node='1sum-701-c10-vector01', release='5.15.0-124-generic', version='#134-Ubuntu SMP Fri Sep 27 20:20:17 UTC 2024', machine='x86_64')
$ nvidia-smi
Fri Nov 8 15:57:58 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.120 Driver Version: 550.120 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA RTX A5000 Off | 00000000:01:00.0 Off | Off |
| 30% 31C P2 24W / 230W | 220MiB / 24564MiB | 2% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 1 NVIDIA RTX A5000 Off | 00000000:2C:00.0 Off | Off |
| 30% 40C P2 26W / 230W | 220MiB / 24564MiB | 1% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 2 NVIDIA RTX A5000 Off | 00000000:41:00.0 Off | Off |
| 30% 39C P2 28W / 230W | 220MiB / 24564MiB | 2% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 3 NVIDIA RTX A5000 Off | 00000000:61:00.0 Off | Off |
| 30% 37C P2 28W / 230W | 220MiB / 24564MiB | 2% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 1887 G /usr/lib/xorg/Xorg 4MiB |
| 0 N/A N/A 274456 C python 204MiB |
| 1 N/A N/A 1887 G /usr/lib/xorg/Xorg 4MiB |
| 1 N/A N/A 274456 C python 204MiB |
| 2 N/A N/A 1887 G /usr/lib/xorg/Xorg 4MiB |
| 2 N/A N/A 274456 C python 204MiB |
| 3 N/A N/A 1887 G /usr/lib/xorg/Xorg 4MiB |
| 3 N/A N/A 274456 C python 204MiB |
+-----------------------------------------------------------------------------------------+
System Info on CPU environment
Python 3.10.15 (main, Oct 3 2024, 07:27:34) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax; jax.print_environment_info()
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
jax: 0.4.35
jaxlib: 0.4.35
numpy: 2.1.3
python: 3.10.15 (main, Oct 3 2024, 07:27:34) [GCC 11.2.0]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='1sum-701-c10-vector01', release='5.15.0-124-generic', version='#134-Ubuntu SMP Fri Sep 27 20:20:17 UTC 2024', machine='x86_64')
$ nvidia-smi
Fri Nov 8 15:58:50 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.120 Driver Version: 550.120 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA RTX A5000 Off | 00000000:01:00.0 Off | Off |
| 30% 31C P8 20W / 230W | 10MiB / 24564MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 1 NVIDIA RTX A5000 Off | 00000000:2C:00.0 Off | Off |
| 30% 41C P8 22W / 230W | 10MiB / 24564MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 2 NVIDIA RTX A5000 Off | 00000000:41:00.0 Off | Off |
| 30% 40C P8 21W / 230W | 10MiB / 24564MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 3 NVIDIA RTX A5000 Off | 00000000:61:00.0 Off | Off |
| 30% 37C P8 21W / 230W | 10MiB / 24564MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 1887 G /usr/lib/xorg/Xorg 4MiB |
| 1 N/A N/A 1887 G /usr/lib/xorg/Xorg 4MiB |
| 2 N/A N/A 1887 G /usr/lib/xorg/Xorg 4MiB |
| 3 N/A N/A 1887 G /usr/lib/xorg/Xorg 4MiB |
+-----------------------------------------------------------------------------------------+
Hi, thanks for the report! Though this may be a bit surprising, I think it's consistent with the expected precision of floating-point operations: namely, the outputs in each case are within 1 ULP (unit in the last place) of the true result at float32 precision.
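To make the 1 ULP claim concrete: the two anomalous GPU values are exactly the nearest representable float32 neighbors of 1.0, one step above and one step below. A quick check with NumPy (a sketch, not part of the original reply):

import numpy as np

one = np.float32(1.0)
# Nearest representable float32 values above and below 1.0:
up = np.nextafter(one, np.float32(2.0))
down = np.nextafter(one, np.float32(0.0))
print(f"{float(up):.16f}")    # 1.0000001192092896, the GPU z/z for z = 0.0526... and 1e-13
print(f"{float(down):.16f}")  # 0.9999999403953552, the GPU z/z for z = 0.987654321

So each GPU result is off from 1.0 by at most one representable float32 step, i.e. within 1 ULP.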