
XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. #16901

Closed
TInaWangxue opened this issue on Jul 31, 2023 · 2 comments
Labels: bug (Something isn't working)

@TInaWangxue commented on Jul 31, 2023:

Description

The output:

```
2023-07-31 01:53:45.016563: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:427] Loaded runtime CuDNN library: 8.5.0 but source was compiled with: 8.6.0. CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library. If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.
```

```
XlaRuntimeError Traceback (most recent call last)
Cell In[4], line 29
26 model = trainer.make_model(nmask)
28 lr_fn, opt = trainer.make_optimizer(steps_per_epoch=len(train_dl))
---> 29 state = trainer.create_train_state(jax.random.PRNGKey(0), model, opt)
30 state = checkpoints.restore_checkpoint(ckpt.parent, state)

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/random.py:137, in PRNGKey(seed)
134 if np.ndim(seed):
135 raise TypeError("PRNGKey accepts a scalar seed, but was given an array of"
136 f"shape {np.shape(seed)} != (). Use jax.vmap for batching")
--> 137 key = prng.seed_with_impl(impl, seed)
138 return _return_prng_keys(True, key)

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/prng.py:320, in seed_with_impl(impl, seed)
319 def seed_with_impl(impl: PRNGImpl, seed: Union[int, Array]) -> PRNGKeyArrayImpl:
--> 320 return random_seed(seed, impl=impl)

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/prng.py:734, in random_seed(seeds, impl)
732 else:
733 seeds_arr = jnp.asarray(seeds)
--> 734 return random_seed_p.bind(seeds_arr, impl=impl)

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/core.py:380, in Primitive.bind(self, *args, **params)
377 def bind(self, *args, **params):
378 assert (not config.jax_enable_checks or
379 all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 380 return self.bind_with_trace(find_top_trace(args), args, params)

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/core.py:383, in Primitive.bind_with_trace(self, trace, args, params)
382 def bind_with_trace(self, trace, args, params):
--> 383 out = trace.process_primitive(self, map(trace.full_raise, args), params)
384 return map(full_lower, out) if self.multiple_results else full_lower(out)

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/core.py:790, in EvalTrace.process_primitive(self, primitive, tracers, params)
789 def process_primitive(self, primitive, tracers, params):
--> 790 return primitive.impl(*tracers, **params)

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/prng.py:746, in random_seed_impl(seeds, impl)
744 @random_seed_p.def_impl
745 def random_seed_impl(seeds, *, impl):
--> 746 base_arr = random_seed_impl_base(seeds, impl=impl)
747 return PRNGKeyArrayImpl(impl, base_arr)

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/prng.py:751, in random_seed_impl_base(seeds, impl)
749 def random_seed_impl_base(seeds, *, impl):
750 seed = iterated_vmap_unary(seeds.ndim, impl.seed)
--> 751 return seed(seeds)

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/prng.py:980, in threefry_seed(seed)
968 def threefry_seed(seed: typing.Array) -> typing.Array:
969 """Create a single raw threefry PRNG key from an integer seed.
970
971 Args:
(...)
978 first padding out with zeros).
979 """
--> 980 return _threefry_seed(seed)

[... skipping hidden 12 frame]

File /mnt/data/miniconda/envs/energy_transformer_117/lib/python3.11/site-packages/jax/_src/dispatch.py:463, in backend_compile(backend, module, options, host_callbacks)
458 return backend.compile(built_c, compile_options=options,
459 host_callbacks=host_callbacks)
460 # Some backends don't have host_callbacks option yet
461 # TODO(sharadmv): remove this fallback when all backends allow compile
462 # to take in host_callbacks
--> 463 return backend.compile(built_c, compile_options=options)

XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
```
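A quick way to confirm which jax/jaxlib/CUDA/cuDNN combination the process actually loads (`jax.print_environment_info()` exists on jax 0.4.1 and later; shown here as a diagnostic sketch):

```sh
python -c "import jax; jax.print_environment_info()"
```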

What jax/jaxlib version are you using?

jax 0.4.10, jaxlib 0.4.10+cuda11.cudnn86

Which accelerator(s) are you using?

GPU

Additional system info

Python 3.11.4, Ubuntu 22.04, CUDA 11.7, cuDNN 8.6
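To double-check which cuDNN the dynamic loader actually resolves at runtime (the header path below is a common default, not guaranteed on every install):

```sh
# which libcudnn the dynamic loader sees
ldconfig -p | grep libcudnn

# cuDNN version from the headers, if installed in the usual location
grep -A 2 '#define CUDNN_MAJOR' /usr/include/cudnn_version.h
```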

NVIDIA GPU info

```
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 520.61.05 Driver Version: 520.61.05 CUDA Version: 11.8 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA GeForce ... On | 00000000:18:00.0 Off | N/A |
| 30% 33C P8 22W / 350W | 258MiB / 12288MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 1 NVIDIA GeForce ... On | 00000000:3B:00.0 Off | N/A |
| 30% 30C P8 6W / 350W | 8MiB / 12288MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 2 NVIDIA GeForce ... On | 00000000:86:00.0 Off | N/A |
| 30% 34C P8 23W / 350W | 8MiB / 12288MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 3 NVIDIA GeForce ... On | 00000000:AF:00.0 Off | N/A |
| 30% 30C P8 10W / 350W | 8MiB / 12288MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| 0 N/A N/A 2565 G /usr/lib/xorg/Xorg 4MiB |
| 0 N/A N/A 2861 C+G ...ome-remote-desktop-daemon 249MiB |
| 1 N/A N/A 2565 G /usr/lib/xorg/Xorg 4MiB |
| 2 N/A N/A 2565 G /usr/lib/xorg/Xorg 4MiB |
| 3 N/A N/A 2565 G /usr/lib/xorg/Xorg 4MiB |
+-----------------------------------------------------------------------------+
```

TInaWangxue added the bug label on Jul 31, 2023
@hawkinsp (Collaborator) commented:
The error message says what's wrong:

```
Loaded runtime CuDNN library: 8.5.0 but source was compiled with: 8.6.0.
```

You installed a version of jax that needs cuDNN 8.6, but cuDNN 8.5 was found at runtime. I suggest reinstalling with the cuda11_pip or cuda12_pip extras, in a fresh virtual environment; a sketch follows below.
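For example (a minimal sketch assuming the CUDA 11 route; the extra name and release index are the pip-based install documented for jax 0.4.x, and the environment name is illustrative):

```sh
# start from a clean virtual environment so no stale cuDNN is picked up
python -m venv jax-cuda11
source jax-cuda11/bin/activate

# reinstall jax with pip-managed CUDA 11 wheels, which pull in a matching cuDNN
pip install --upgrade "jax[cuda11_pip]" \
  -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```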

Hope that helps!

@TInaWangxue (Author) commented:

@hawkinsp Thank you so much!
