Resolving jaxlib conflicts with CUDA 11.2 #22615
joshdorrington asked this question in Q&A
Newer versions of JAX do not support CUDA 11.2. See, for example, the JAX changelog, which notes that support for CUDA 11.4 and below was dropped in JAX v0.4.8. Unfortunately, I think your options are either to update CUDA on your cluster or to limit yourself to older versions of JAX and other packages.
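The cutoff described in this reply can be written as a quick check. This is a minimal sketch that encodes only that one rule (CUDA 11.4 and below unsupported from v0.4.8 on); it assumes the same cutoff applies to the jaxlib version number, and the function names are hypothetical. A `False` result only means this rule does not exclude the combination, not that the combination is known to work.

```python
from __future__ import annotations

import re


def parse_version(text: str) -> tuple[int, ...]:
    """Convert a version string such as '0.4.27' or '11.2' to a tuple of ints.

    Only the leading dotted-numeric part is kept, so a local suffix like
    '+cuda11' is ignored.
    """
    match = re.match(r"\d+(?:\.\d+)*", text.strip())
    if match is None:
        raise ValueError(f"not a version string: {text!r}")
    return tuple(int(part) for part in match.group(0).split("."))


# Cutoff taken from the reply: JAX v0.4.8 dropped support for CUDA 11.4 and below.
# Assumption: the same cutoff applies to the jaxlib version number.
DROPPED_IN = parse_version("0.4.8")
MAX_DROPPED_CUDA = parse_version("11.4")


def ruled_out(jaxlib_version: str, cuda_version: str) -> bool:
    """True if this single changelog rule excludes the combination.

    False does NOT guarantee compatibility; it only means this rule is silent.
    """
    jaxlib = parse_version(jaxlib_version)
    cuda = parse_version(cuda_version)[:2]  # compare major.minor only
    return jaxlib >= DROPPED_IN and cuda <= MAX_DROPPED_CUDA


print(ruled_out("0.4.27", "11.2"))  # True: excluded by the rule
print(ruled_out("0.4.7", "11.2"))   # False: this rule alone does not exclude it
```

Comparing integer tuples rather than raw strings avoids the classic pitfall where `"0.4.27" < "0.4.8"` is true as a string comparison.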
Hi,
I am trying to run neuralgcm on my local cluster, which is unfortunately limited to CUDA 11.2, and it's been suggested that I ask here for help resolving package conflicts, if that's even possible.
The main issue is that neuralgcm requires Python>=3.10 and seemingly jaxlib>=0.4.27, but when I look here:
https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
I can't find a wheel that supports Python>=3.10, jaxlib>=0.4.27, and CUDA<=11.2.
Am I missing any tricks here?
Details about my GPUs are below. My attempt at building a working Python environment is attached, as is the code I'm trying to run. With the current min.txt, it doesn't get past the jax import. In previous attempts, it reached the end of the script but then failed due to CUDA and cuDNN incompatibilities.
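Because the current environment fails at the jax import, one way to see what actually got installed is to read package metadata, which does not require importing jax. A minimal stdlib-only sketch; the helper names are hypothetical, and the distribution names queried (`jax`, `jaxlib`, `neuralgcm`) are assumptions, so CUDA-enabled builds installed under other names would need to be added to the list.

```python
from __future__ import annotations

import platform
from importlib import metadata


def installed_version(dist_name: str) -> str | None:
    """Installed version of a distribution, or None if it is not installed.

    Reads package metadata only, so it works even when `import jax` fails.
    """
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def environment_report(dist_names=("jax", "jaxlib", "neuralgcm")) -> dict[str, str | None]:
    """Python version plus the installed version of each named distribution."""
    report: dict[str, str | None] = {"python": platform.python_version()}
    for name in dist_names:
        report[name] = installed_version(name)
    return report


for key, value in environment_report().items():
    print(f"{key}: {value if value is not None else 'not installed'}")
```

Pasting this output alongside the attached files makes it easier to compare the installed versions with the wheels listed on the release page.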
Output of nvidia-smi in a terminal:
NVIDIA-SMI 460.106.00 Driver Version: 460.106.00 CUDA Version: 11.2
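To check candidate wheels against the cluster without copying numbers by hand, the versions in this header can be read programmatically. A minimal sketch; the helper names are hypothetical and the regular expressions assume the header layout shown above. Note that the "CUDA Version" reported by nvidia-smi is the newest CUDA version the installed driver supports, not necessarily the CUDA toolkit installed on the system.

```python
from __future__ import annotations

import re
import shutil
import subprocess

# Header line from the nvidia-smi output in the question.
SAMPLE = "NVIDIA-SMI 460.106.00 Driver Version: 460.106.00 CUDA Version: 11.2"


def parse_smi_header(text: str) -> dict[str, str]:
    """Extract driver and CUDA versions from nvidia-smi output text."""
    fields: dict[str, str] = {}
    for key, label in (("driver", "Driver Version"), ("cuda", "CUDA Version")):
        match = re.search(label + r":\s*([\d.]+)", text)
        if match:
            fields[key] = match.group(1)
    return fields


def read_smi_output() -> str | None:
    """Run nvidia-smi if it is on PATH; return its stdout, else None."""
    if shutil.which("nvidia-smi") is None:
        return None
    result = subprocess.run(["nvidia-smi"], capture_output=True, text=True, check=False)
    return result.stdout


print(parse_smi_header(SAMPLE))  # {'driver': '460.106.00', 'cuda': '11.2'}
```

On the cluster itself, `parse_smi_header(read_smi_output() or "")` applies the same parsing to the live output; an empty dict means the header was not found.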
environment_configuration.txt
min.txt