-
Which jax version do you have? Try the latest jax and jaxlib 0.4.31; this was a regression introduced in 0.4.30 but fixed in 0.4.31.
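To confirm which versions are actually inside the container image, here is a minimal sketch of a check. It uses only the standard library, so it works even when `jax` itself fails to import; the function name is my own:

```python
# Report the jax/jaxlib versions pip actually installed, without
# importing jax itself (so this works even if the install is broken).
from importlib import metadata

def installed_versions(pkgs=("jax", "jaxlib")):
    """Map package name -> installed version string, or None if absent."""
    out = {}
    for pkg in pkgs:
        try:
            out[pkg] = metadata.version(pkg)
        except metadata.PackageNotFoundError:
            out[pkg] = None
    return out

if __name__ == "__main__":
    print(installed_versions())
```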
-
Hi!
I'm a PhD student very happily using JAX for my smaller ML experiments, locally or on limited private compute, and I recently got access to an HPC cluster that I want to use for longer-running or larger-scale experiments. This cluster relies on its users to build Docker containers including any dependencies. I made a Dockerfile which essentially just installs `jax[cuda12]` inside the `python:3.12.5-slim` image, so CUDA inside the container is installed through pip from JAX's extras.

The resulting image works fine locally with `docker run --rm --gpus all <my_image_tag> <my_script>`: it sees the GPUs and runs without problems. But it does not work at all when deployed on the HPC cluster, which is very curious: JAX does not see the GPU assigned to the container by the cluster's job runner and defaults to the CPU.

The thing is, I'm not sure how to proceed debugging this issue, because JAX produces no logs at all. Locally, when I've gotten the JAX install wrong in the past, I've seen things like `No GPU/TPU found. Defaulting to CPU`, usually following some errors that are informative about the part that's going wrong (cuDNN not found, CUDA not found, etc.). Inside the Docker container, I can't get such logs to appear. I've tried `TF_CPP_MAX_VLOG_LEVEL=2 TF_CPP_MIN_LOG_LEVEL=0` from another discussion; those flags do get some XLA logs to appear, but they're not very informative, mostly logs about libraries successfully loading dynamically. I've also tried setting the Python logging level to debug with `logging.basicConfig(filename='example.log', encoding='utf-8', level=logging.DEBUG)`, and that does get some jaxlib logs to appear. For reference, these are the logs I see with `TF_CPP_MAX_VLOG_LEVEL=2` and `TF_CPP_MIN_LOG_LEVEL=0`:
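One more thing I've been considering, beyond the verbose XLA flags, is forcing JAX to fail loudly instead of silently falling back to CPU. `JAX_PLATFORMS` is documented JAX behavior; the script below is my own sketch, assuming `jax[cuda12]` is importable in the container:

```python
# Minimal in-container check of what JAX can see. With JAX_PLATFORMS=cuda,
# JAX raises an error if the CUDA backend fails to initialize, instead of
# silently defaulting to CPU, so the real failure reason surfaces.
import os

os.environ.setdefault("JAX_PLATFORMS", "cuda")

def visible_backends():
    """Return JAX's device list, or the initialization error message."""
    try:
        import jax
        return [f"{d.platform}:{d.id}" for d in jax.devices()]
    except Exception as exc:  # e.g. RuntimeError when no CUDA backend loads
        return f"backend init failed: {exc}"

if __name__ == "__main__":
    print(visible_backends())
```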
I'm really not sure what else to do to debug why JAX isn't seeing the GPU from within the container. I can SSH into the container on the cluster and run an interactive Python shell as well as any other command, e.g. `nvidia-smi`, which does show the GPU and its driver (560, which supports CUDA 12.6).

Do you have any suggestions for things I should try to figure out what's going on? I'm thinking parts of the system's CUDA must somehow be interfering with the container's pip-installed CUDA, and for some reason JAX isn't loading CUDA / cuDNN from `site-packages`, but I've no idea how to find out what JAX is actually trying to load or what exactly is failing.
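For the `site-packages` theory specifically, here is a sketch of two checks that might narrow things down: listing which `nvidia-*` wheels pip actually installed, and trying to load the driver library `libcuda.so.1`, which must come from the host (injected by the container runtime), not from pip. Function names are my own:

```python
# Two checks for the "pip CUDA vs. system CUDA" theory: enumerate the
# nvidia-* wheels visible to this Python, and dlopen the host driver
# library that the CUDA runtime ultimately needs.
import ctypes
from importlib import metadata

def pip_nvidia_packages():
    """Names/versions of nvidia-* distributions in site-packages."""
    return sorted(
        f"{name}=={d.version}"
        for d in metadata.distributions()
        if (name := d.metadata["Name"]) and name.startswith("nvidia-")
    )

def driver_loadable():
    """True if libcuda.so.1 (host driver) loads; else the loader error.

    If this fails inside the container, the cluster's container runtime
    is not exposing the driver, regardless of what pip installed.
    """
    try:
        ctypes.CDLL("libcuda.so.1")
        return True
    except OSError as exc:
        return str(exc)

if __name__ == "__main__":
    print(pip_nvidia_packages())
    print(driver_loadable())
```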