What wheel for jax should I use? #14104
-
The "pip installation: GPU (CUDA)" section on the main page says you should build the package from source.
-
CUDA 11.0 is a fairly old version that is no longer supported by JAX, even if you build from source. I believe the last published release compatible with CUDA 11.0 was jaxlib 0.1.71, which pairs with jax v0.2.19 (version numbers for compatible jax and jaxlib releases are aligned now, but previously were not). You can find the installation instructions for JAX of that era at https://github.com/google/jax/tree/jaxlib-v0.1.71#pip-installation-gpu-cuda; something like this would work:
That said, this version of jax is quite old, so you may run into other problems. You'll have a better experience if you can update CUDA and use a more recent JAX release.
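If you can't change system packages, one rough way to confirm that your environment actually has the jax/jaxlib pair mentioned above is to read the installed package metadata. Below is a minimal Python sketch, assuming Python 3.8+ (for `importlib.metadata`). It only compares package versions and does not check CUDA or cuDNN. `EXPECTED`, `installed`, and `check_pair` are hypothetical helpers, not part of JAX, and stripping everything after `+` assumes GPU wheels may carry a local version suffix.

```python
from importlib.metadata import PackageNotFoundError, version  # stdlib since Python 3.8

# The jax/jaxlib pair named above as the last release line with CUDA 11.0
# support. Before version numbers were aligned, the two packages had to be
# pinned separately.
EXPECTED = {"jax": "0.2.19", "jaxlib": "0.1.71"}


def installed(pkg):
    """Return the installed version string of pkg, or None if it is missing."""
    try:
        return version(pkg)
    except PackageNotFoundError:
        return None


def check_pair(versions):
    """Compare a {package: version} mapping against EXPECTED.

    Returns a list of human-readable problems; an empty list means the
    versions match the expected pair.
    """
    problems = []
    for pkg, want in EXPECTED.items():
        have = versions.get(pkg)
        if have is None:
            problems.append(f"{pkg} is not installed")
        # Assumption: GPU wheels may carry a local version suffix after "+",
        # so only the public part before "+" is compared.
        elif have.split("+")[0] != want:
            problems.append(f"{pkg} {have} installed, expected {want}")
    return problems


if __name__ == "__main__":
    current = {pkg: installed(pkg) for pkg in EXPECTED}
    print(current, check_pair(current) or "OK")
```

Running the file prints the detected versions followed by either `OK` or a list of mismatches, which is a quick sanity check before debugging GPU issues further.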
-
My CUDA version is 11.0, my cuDNN version is 8.4.1, and my Python version is 3.8. Given those versions, I cannot find a JAX release from here that suits my configuration. I was wondering whether I could use a release built for cuDNN 8.6: would it be backward compatible and still work with 8.4?
I cannot upgrade or downgrade CUDA or cuDNN because I don't have admin access.