Errors about cuDNN
#19598
Replies: 1 comment 2 replies
-
Thanks for the question – this likely has to do with jaxlib being built for a different architecture than your GPU chip. See #15361 for a similar report. Can you paste the output of the |
Beta Was this translation helpful? Give feedback.
2 replies
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
-
Below is my jaxlib version. I ran into the following problems when running my code:
I0131 17:03:55.608776 47229006633088 xla_bridge.py:450] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Interpreter Host CUDA
I0131 17:03:55.610385 47229006633088 xla_bridge.py:450] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
I0131 17:03:55.610747 47229006633088 xla_bridge.py:450] Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
2024-01-31 17:03:56.336248: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2024-01-31 17:03:56.336349: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:443] Memory usage: 33672790016 bytes free, 34079637504 bytes total.
Traceback (most recent call last):
File "/scratch/work/guoq2/my_fisor/launcher/examples/train_offline.py", line 123, in
app.run(main)
File "/home/guoq2/.conda/envs/FISOR/lib/python3.9/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/home/guoq2/.conda/envs/FISOR/lib/python3.9/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/scratch/work/guoq2/my_fisor/launcher/examples/train_offline.py", line 119, in main
call_main(parameters)
File "/scratch/work/guoq2/my_fisor/launcher/examples/train_offline.py", line 62, in call_main
agent = globals()[model_cls].create(
File "/scratch/work/guoq2/my_fisor/./jaxrl5/agents/fisor/fisor.py", line 125, in create
rng = jax.random.PRNGKey(seed)
File "/home/guoq2/.conda/envs/FISOR/lib/python3.9/site-packages/jax/_src/random.py", line 137, in PRNGKey
key = prng.seed_with_impl(impl, seed)
File "/home/guoq2/.conda/envs/FISOR/lib/python3.9/site-packages/jax/_src/prng.py", line 320, in seed_with_impl
return random_seed(seed, impl=impl)
File "/home/guoq2/.conda/envs/FISOR/lib/python3.9/site-packages/jax/_src/prng.py", line 732, in random_seed
return random_seed_p.bind(seeds_arr, impl=impl)
File "/home/guoq2/.conda/envs/FISOR/lib/python3.9/site-packages/jax/_src/prng.py", line 744, in random_seed_impl
base_arr = random_seed_impl_base(seeds, impl=impl)
File "/home/guoq2/.conda/envs/FISOR/lib/python3.9/site-packages/jax/_src/prng.py", line 749, in random_seed_impl_base
return seed(seeds)
File "/home/guoq2/.conda/envs/FISOR/lib/python3.9/site-packages/jax/_src/prng.py", line 978, in threefry_seed
return _threefry_seed(seed)
File "/home/guoq2/.conda/envs/FISOR/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/guoq2/.conda/envs/FISOR/lib/python3.9/site-packages/jax/_src/pjit.py", line 208, in cache_miss
outs, out_flat, out_tree, args_flat = _python_pjit_helper(
File "/home/guoq2/.conda/envs/FISOR/lib/python3.9/site-packages/jax/_src/pjit.py", line 155, in _python_pjit_helper
out_flat = pjit_p.bind(*args_flat, **params)
File "/home/guoq2/.conda/envs/FISOR/lib/python3.9/site-packages/jax/_src/core.py", line 2633, in bind
return self.bind_with_trace(top_trace, args, params)
File "/home/guoq2/.conda/envs/FISOR/lib/python3.9/site-packages/jax/_src/core.py", line 383, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/guoq2/.conda/envs/FISOR/lib/python3.9/site-packages/jax/_src/core.py", line 790, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/guoq2/.conda/envs/FISOR/lib/python3.9/site-packages/jax/_src/pjit.py", line 1085, in _pjit_call_impl
compiled = _pjit_lower(
File "/home/guoq2/.conda/envs/FISOR/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 2313, in compile
executable = UnloadedMeshExecutable.from_hlo(
File "/home/guoq2/.conda/envs/FISOR/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 2633, in from_hlo
xla_executable, compile_options = _cached_compilation(
File "/home/guoq2/.conda/envs/FISOR/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 2551, in _cached_compilation
xla_executable = dispatch.compile_or_get_cached(
File "/home/guoq2/.conda/envs/FISOR/lib/python3.9/site-packages/jax/_src/dispatch.py", line 494, in compile_or_get_cached
return backend_compile(backend, computation, compile_options,
File "/home/guoq2/.conda/envs/FISOR/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/guoq2/.conda/envs/FISOR/lib/python3.9/site-packages/jax/_src/dispatch.py", line 462, in backend_compile
return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/scratch/work/guoq2/my_fisor/launcher/examples/train_offline.py", line 123, in
app.run(main)
File "/home/guoq2/.conda/envs/FISOR/lib/python3.9/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/home/guoq2/.conda/envs/FISOR/lib/python3.9/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/scratch/work/guoq2/my_fisor/launcher/examples/train_offline.py", line 119, in main
call_main(parameters)
File "/scratch/work/guoq2/my_fisor/launcher/examples/train_offline.py", line 62, in call_main
agent = globals()[model_cls].create(
File "/scratch/work/guoq2/my_fisor/./jaxrl5/agents/fisor/fisor.py", line 125, in create
rng = jax.random.PRNGKey(seed)
File "/home/guoq2/.conda/envs/FISOR/lib/python3.9/site-packages/jax/_src/random.py", line 137, in PRNGKey
key = prng.seed_with_impl(impl, seed)
File "/home/guoq2/.conda/envs/FISOR/lib/python3.9/site-packages/jax/_src/prng.py", line 320, in seed_with_impl
return random_seed(seed, impl=impl)
File "/home/guoq2/.conda/envs/FISOR/lib/python3.9/site-packages/jax/_src/prng.py", line 732, in random_seed
return random_seed_p.bind(seeds_arr, impl=impl)
File "/home/guoq2/.conda/envs/FISOR/lib/python3.9/site-packages/jax/_src/core.py", line 380, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/home/guoq2/.conda/envs/FISOR/lib/python3.9/site-packages/jax/_src/core.py", line 383, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/guoq2/.conda/envs/FISOR/lib/python3.9/site-packages/jax/_src/core.py", line 790, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/guoq2/.conda/envs/FISOR/lib/python3.9/site-packages/jax/_src/prng.py", line 744, in random_seed_impl
base_arr = random_seed_impl_base(seeds, impl=impl)
File "/home/guoq2/.conda/envs/FISOR/lib/python3.9/site-packages/jax/_src/prng.py", line 749, in random_seed_impl_base
return seed(seeds)
File "/home/guoq2/.conda/envs/FISOR/lib/python3.9/site-packages/jax/_src/prng.py", line 978, in threefry_seed
return _threefry_seed(seed)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
Can I get some help?
Also, the commands used to install JAX are:
pip install jax==0.4.9
pip install jaxlib==0.4.9+cuda12.cudnn88 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Beta Was this translation helpful? Give feedback.
All reactions