Make GPU-JAX work when CUDA is not available #14208
Unanswered
carlosgmartin asked this question in Q&A
Have you tried setting the platform? Maybe that avoids the folder check.

```python
import jax
jax.config.update("jax_platform_name", "cpu")
```

You would have to do this at the very tippy-top of your code, so that no other JAX computation happens before it.
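You could set the `JAX_PLATFORMS` environment variable to `cpu` before starting Python, or in a script:

```python
import os
os.environ['JAX_PLATFORMS'] = 'cpu'  # must be set before jax is imported
import jax
```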
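For the mixed Slurm cluster in the question, the same variable can be set conditionally, which should let a single GPU-JAX install run on both kinds of nodes. This is a minimal sketch under stated assumptions: the helper names are hypothetical (not part of JAX), and the GPU test (looking for NVIDIA's `nvidia-smi` tool on `PATH`) is only a heuristic that may need adjusting for your nodes.

```python
import os
import shutil
from typing import Callable, MutableMapping, Optional


def gpu_probably_available(
    which: Callable[[str], Optional[str]] = shutil.which,
) -> bool:
    """Hypothetical heuristic: treat a node as a GPU node if the NVIDIA
    `nvidia-smi` tool is on PATH. This is an assumption, not a JAX API."""
    return which("nvidia-smi") is not None


def configure_jax_platforms(
    env: MutableMapping[str, str] = os.environ,
    gpu_available: Optional[bool] = None,
) -> Optional[str]:
    """Set JAX_PLATFORMS=cpu on nodes that do not look like GPU nodes.

    Must be called before `import jax`. A JAX_PLATFORMS value already set
    by the user or the job script is left untouched. Returns the resulting
    value, or None when JAX is left to pick its default platform.
    """
    if "JAX_PLATFORMS" in env:
        # Respect an explicit choice made elsewhere (shell, sbatch script).
        return env["JAX_PLATFORMS"]
    if gpu_available is None:
        gpu_available = gpu_probably_available()
    if not gpu_available:
        # No GPU detected: restrict JAX to the CPU backend.
        env["JAX_PLATFORMS"] = "cpu"
    return env.get("JAX_PLATFORMS")
```

Usage: call `configure_jax_platforms()` as the first statement of the entry-point script and only then `import jax`. Passing `env` and `gpu_available` explicitly keeps the decision logic testable without touching the real environment.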
I have a Slurm cluster with both GPU and non-GPU nodes.
CPU-JAX (obtained via

```
python3 -m pip install jax
```

) works on both types of nodes, but does not use the GPUs on the GPU nodes.

GPU-JAX (obtained via

```
python3 -m pip install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

) works on the GPU nodes (and uses their GPUs), but fails on the non-GPU nodes with the following error:

This is because the CUDA folder does not exist on the non-GPU nodes. What's the recommended way to address this? Can GPU-JAX be configured to ignore the above error and proceed with only the CPUs when CUDA is not available?
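To confirm this diagnosis on a given node, a quick pre-flight check for a CUDA directory can help. This is a minimal sketch: the exact directory named in the error is not shown above, so the candidate locations (the `CUDA_HOME` and `CUDA_PATH` environment variables, which are common conventions rather than JAX settings, and the default NVIDIA prefix `/usr/local/cuda`) and the `find_cuda_dir` name are assumptions.

```python
import os
from typing import Iterable, Mapping, Optional


def find_cuda_dir(
    env: Mapping[str, str] = os.environ,
    extra_candidates: Iterable[str] = ("/usr/local/cuda",),
) -> Optional[str]:
    """Return the first candidate CUDA directory that exists, or None.

    Candidate locations are assumptions: CUDA_HOME, then CUDA_PATH, then
    any extra paths (by default the conventional /usr/local/cuda).
    """
    candidates = [env.get("CUDA_HOME"), env.get("CUDA_PATH"), *extra_candidates]
    for path in candidates:
        # Skip unset variables and paths that are not existing directories.
        if path and os.path.isdir(path):
            return path
    return None
```

On a node where this returns `None`, CUDA is most likely absent, which matches the non-GPU nodes described here.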