cudaErrorSymbolNotFound : named symbol not found #24749

Open
amanjitsk opened this issue Nov 6, 2024 · 3 comments
Labels
bug Something isn't working

Comments

@amanjitsk

Description

I am encountering this problem on a fresh install of JAX.

I just followed the instructions to install JAX (pip, bundled with CUDA) as

pip install -U "jax[cuda12]"

but I cannot run any code whatsoever. I searched for the specific error to see if there was a similar issue, but there doesn't seem to be one.

I run a simple command like python -c "import jax; import jax.numpy as jnp; print(jnp.linspace(0, 1, 10))", and see the following error:

E1106 17:22:25.991066    6564 cuda_dnn.cc:502] There was an error before creating cudnn handle (500): cudaErrorSymbolNotFound : named symbol not found
E1106 17:22:25.993042    6564 cuda_dnn.cc:502] There was an error before creating cudnn handle (500): cudaErrorSymbolNotFound : named symbol not found
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/fs01/home/<user>/Projects/test/.venv/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 6677, in linspace
    return _linspace(start, stop, num, endpoint, retstep, dtype, axis, device=device)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

Appreciate any pointers!

System info (python version, jaxlib version, accelerator, etc.)

Command: python -c "import jax; jax.print_environment_info(); import jax.numpy as jnp; print(jnp.linspace(0, 1, 10))"

jax:    0.4.35
jaxlib: 0.4.34
numpy:  2.1.3
python: 3.10.12 (main, Jul 19 2023, 10:44:52) [GCC 7.5.0]
device info: Tesla T4-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='gpu067', release='5.4.0-131-generic', version='#147~18.04.1-Ubuntu SMP Sat Oct 15 13:10:18 UTC 2022', machine='x86_64')


$ nvidia-smi
Wed Nov  6 17:22:25 2024
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            On   | 00000000:07:00.0 Off |                    0 |
| N/A   36C    P0    26W /  70W |    105MiB / 15360MiB |      2%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      6564      C   ...cts/test/.venv/bin/python      102MiB |
+-----------------------------------------------------------------------------+

E1106 17:22:25.991066    6564 cuda_dnn.cc:502] There was an error before creating cudnn handle (500): cudaErrorSymbolNotFound : named symbol not found
E1106 17:22:25.993042    6564 cuda_dnn.cc:502] There was an error before creating cudnn handle (500): cudaErrorSymbolNotFound : named symbol not found
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/fs01/home/amanjitsk/Projects/test/.venv/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 6677, in linspace
    return _linspace(start, stop, num, endpoint, retstep, dtype, axis, device=device)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
@amanjitsk amanjitsk added the bug Something isn't working label Nov 6, 2024
@hawkinsp
Collaborator

hawkinsp commented Nov 7, 2024

Do you have another copy of CUDA or CUDNN on your system? My guess is we're finding the wrong one, perhaps.
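
One rough way to check (just a sketch; the grep patterns below only target the usual library names) is to ask the dynamic loader what it can see, and then look at which files the process actually maps when the CUDA backend tries to initialize:

# Is a system-wide cuDNN / CUDA runtime visible to the dynamic loader?
ldconfig -p | grep -E 'libcudnn|libcudart'

# Which cuDNN / CUDA runtime files does this process actually map?
# (jnp.zeros(1) just forces backend initialization; the try/except keeps the
# script alive even if initialization fails with the error above)
python - <<'EOF' 2>/dev/null | grep -oE '/[^ ]*lib(cudnn|cudart)[^ ]*' | sort -u
import jax.numpy as jnp
try:
    jnp.zeros(1)
except Exception:
    pass
print(open("/proc/self/maps").read())
EOF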

@amanjitsk
Author

I guess it's possible, in that I wouldn't rule it out, but it's unlikely. I do not load any CUDA or CUDNN manually (using any environment variables). For example,

>>> env | grep CUD
CUDA_CACHE_PATH=/h/amanjitsk/.cache/nv
CUDA_VISIBLE_DEVICES=0

And checking pip as well (this was a fresh install in a venv with only JAX installed as pip install "jax[cuda12]"):

>>> pip list | grep -E 'nvidia|jax'
jax                      0.4.35
jax-cuda12-pjrt          0.4.35
jax-cuda12-plugin        0.4.35
jaxlib                   0.4.34
nvidia-cublas-cu12       12.6.3.3
nvidia-cuda-cupti-cu12   12.6.80
nvidia-cuda-nvcc-cu12    12.6.77
nvidia-cuda-runtime-cu12 12.6.77
nvidia-cudnn-cu12        9.5.0.50
nvidia-cufft-cu12        11.3.0.4
nvidia-cusolver-cu12     11.7.1.2
nvidia-cusparse-cu12     12.5.4.2
nvidia-nccl-cu12         2.23.4
nvidia-nvjitlink-cu12    12.6.77

What other ways are there to check whether another local CUDA installation is interfering? (I highly doubt this is the case, as variables like CUDA_HOME and CUDNN_PATH are unset.)
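
For reference, here is roughly what I could check for a system-wide CUDA/cuDNN install (just a sketch; the paths below are only common default locations, not anything JAX specifically reads):

# Common system-wide CUDA install locations and whether a toolkit is on PATH
ls -d /usr/local/cuda* 2>/dev/null
command -v nvcc && nvcc --version

# Library search path and any cuDNN copies installed outside the venv
echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH"
find /usr/lib /usr/lib/x86_64-linux-gnu -maxdepth 1 -name 'libcudnn*' 2>/dev/null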

@amanjitsk amanjitsk reopened this Nov 7, 2024
@amanjitsk
Author

amanjitsk commented Nov 11, 2024

This error goes away when I do a fresh install of "jax[cuda12]==0.4.33", but starting at version 0.4.34 the error remains. I don't think it's using another CUDA/CUDNN; rather, it seems that CUDA/CUDNN detection changed in jax starting at 0.4.34.

Here is the full sequence of commands I run (this makes me think something is going wrong with CUDA detection in the jax-cuda12-* plugin packages starting at version 0.4.34):

>>> poetry run pip list | grep -E 'nvidia|jax'

jax                      0.4.34
jax-cuda12-pjrt          0.4.34
jax-cuda12-plugin        0.4.34
jaxlib                   0.4.34
nvidia-cublas-cu12       12.6.3.3
nvidia-cuda-cupti-cu12   12.6.80
nvidia-cuda-nvcc-cu12    12.6.77
nvidia-cuda-runtime-cu12 12.6.77
nvidia-cudnn-cu12        9.5.0.50
nvidia-cufft-cu12        11.3.0.4
nvidia-cusolver-cu12     11.7.1.2
nvidia-cusparse-cu12     12.5.4.2
nvidia-nccl-cu12         2.23.4
nvidia-nvjitlink-cu12    12.6.77
%

>>> poetry run python -c "import jax; jax.print_environment_info(); import jax.numpy as jnp; print(jnp.linspace(0, 1, 10))"

jax:    0.4.34
jaxlib: 0.4.34
numpy:  2.1.3
python: 3.10.12 (main, Jul 19 2023, 10:44:52) [GCC 7.5.0]
jax.devices (1 total, 1 local): [CudaDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='gpu065', release='5.4.0-131-generic', version='#147~18.04.1-Ubuntu SMP Sat Oct 15 13:10:18 UTC 2022', machine='x86_64')


$ nvidia-smi
Mon Nov 11 16:25:35 2024
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            On   | 00000000:87:00.0 Off |                    0 |
| N/A   36C    P0    27W /  70W |    105MiB / 15360MiB |      2%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A     17939      C   ...cts/test/.venv/bin/python      102MiB |
+-----------------------------------------------------------------------------+

E1111 16:25:35.293750   17939 cuda_dnn.cc:502] There was an error before creating cudnn handle (500): cudaErrorSymbolNotFound : named symbol not found
E1111 16:25:35.295681   17939 cuda_dnn.cc:502] There was an error before creating cudnn handle (500): cudaErrorSymbolNotFound : named symbol not found
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/fs01/home/amanjitsk/Projects/test/.venv/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 6199, in linspace
    return _linspace(start, stop, num, endpoint, retstep, dtype, axis, device=device)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
%

>>> poetry run pip install -U "jax[cuda12]==0.4.33"

Collecting jax==0.4.33 (from jax[cuda12]==0.4.33)
  Using cached jax-0.4.33-py3-none-any.whl.metadata (22 kB)
Collecting jaxlib<=0.4.33,>=0.4.33 (from jax==0.4.33->jax[cuda12]==0.4.33)
  Using cached jaxlib-0.4.33-cp310-cp310-manylinux2014_x86_64.whl.metadata (983 bytes)
Requirement already satisfied: ml-dtypes>=0.2.0 in ./.venv/lib/python3.10/site-packages (from jax==0.4.33->jax[cuda12]==0.4.33) (0.5.0)
Requirement already satisfied: numpy>=1.24 in ./.venv/lib/python3.10/site-packages (from jax==0.4.33->jax[cuda12]==0.4.33) (2.1.3)
Requirement already satisfied: opt-einsum in ./.venv/lib/python3.10/site-packages (from jax==0.4.33->jax[cuda12]==0.4.33) (3.4.0)
Requirement already satisfied: scipy>=1.10 in ./.venv/lib/python3.10/site-packages (from jax==0.4.33->jax[cuda12]==0.4.33) (1.14.1)
Collecting jax-cuda12-plugin<=0.4.33,>=0.4.33 (from jax-cuda12-plugin[with_cuda]<=0.4.33,>=0.4.33; extra == "cuda12"->jax[cuda12]==0.4.33)
  Using cached jax_cuda12_plugin-0.4.33-cp310-cp310-manylinux2014_x86_64.whl.metadata (1.2 kB)
Collecting jax-cuda12-pjrt==0.4.33 (from jax-cuda12-plugin<=0.4.33,>=0.4.33->jax-cuda12-plugin[with_cuda]<=0.4.33,>=0.4.33; extra == "cuda12"->jax[cuda12]==0.4.33)
  Using cached jax_cuda12_pjrt-0.4.33-py3-none-manylinux2014_x86_64.whl.metadata (349 bytes)
Requirement already satisfied: nvidia-cublas-cu12>=12.1.3.1 in ./.venv/lib/python3.10/site-packages (from jax-cuda12-plugin[with_cuda]<=0.4.33,>=0.4.33; extra == "cuda12"->jax[cuda12]==0.4.33) (12.6.3.3)
Requirement already satisfied: nvidia-cuda-cupti-cu12>=12.1.105 in ./.venv/lib/python3.10/site-packages (from jax-cuda12-plugin[with_cuda]<=0.4.33,>=0.4.33; extra == "cuda12"->jax[cuda12]==0.4.33) (12.6.80)
Requirement already satisfied: nvidia-cuda-nvcc-cu12>=12.1.105 in ./.venv/lib/python3.10/site-packages (from jax-cuda12-plugin[with_cuda]<=0.4.33,>=0.4.33; extra == "cuda12"->jax[cuda12]==0.4.33) (12.6.77)
Requirement already satisfied: nvidia-cuda-runtime-cu12>=12.1.105 in ./.venv/lib/python3.10/site-packages (from jax-cuda12-plugin[with_cuda]<=0.4.33,>=0.4.33; extra == "cuda12"->jax[cuda12]==0.4.33) (12.6.77)
Requirement already satisfied: nvidia-cudnn-cu12<10.0,>=9.1 in ./.venv/lib/python3.10/site-packages (from jax-cuda12-plugin[with_cuda]<=0.4.33,>=0.4.33; extra == "cuda12"->jax[cuda12]==0.4.33) (9.5.0.50)
Requirement already satisfied: nvidia-cufft-cu12>=11.0.2.54 in ./.venv/lib/python3.10/site-packages (from jax-cuda12-plugin[with_cuda]<=0.4.33,>=0.4.33; extra == "cuda12"->jax[cuda12]==0.4.33) (11.3.0.4)
Requirement already satisfied: nvidia-cusolver-cu12>=11.4.5.107 in ./.venv/lib/python3.10/site-packages (from jax-cuda12-plugin[with_cuda]<=0.4.33,>=0.4.33; extra == "cuda12"->jax[cuda12]==0.4.33) (11.7.1.2)
Requirement already satisfied: nvidia-cusparse-cu12>=12.1.0.106 in ./.venv/lib/python3.10/site-packages (from jax-cuda12-plugin[with_cuda]<=0.4.33,>=0.4.33; extra == "cuda12"->jax[cuda12]==0.4.33) (12.5.4.2)
Requirement already satisfied: nvidia-nccl-cu12>=2.18.1 in ./.venv/lib/python3.10/site-packages (from jax-cuda12-plugin[with_cuda]<=0.4.33,>=0.4.33; extra == "cuda12"->jax[cuda12]==0.4.33) (2.23.4)
Requirement already satisfied: nvidia-nvjitlink-cu12>=12.1.105 in ./.venv/lib/python3.10/site-packages (from jax-cuda12-plugin[with_cuda]<=0.4.33,>=0.4.33; extra == "cuda12"->jax[cuda12]==0.4.33) (12.6.77)
Using cached jax-0.4.33-py3-none-any.whl (2.1 MB)
Using cached jaxlib-0.4.33-cp310-cp310-manylinux2014_x86_64.whl (85.0 MB)
Using cached jax_cuda12_plugin-0.4.33-cp310-cp310-manylinux2014_x86_64.whl (14.9 MB)
Using cached jax_cuda12_pjrt-0.4.33-py3-none-manylinux2014_x86_64.whl (99.7 MB)
Installing collected packages: jax-cuda12-pjrt, jax-cuda12-plugin, jaxlib, jax
  Attempting uninstall: jax-cuda12-pjrt
    Found existing installation: jax-cuda12-pjrt 0.4.34
    Uninstalling jax-cuda12-pjrt-0.4.34:
      Successfully uninstalled jax-cuda12-pjrt-0.4.34
  Attempting uninstall: jax-cuda12-plugin
    Found existing installation: jax-cuda12-plugin 0.4.34
    Uninstalling jax-cuda12-plugin-0.4.34:
      Successfully uninstalled jax-cuda12-plugin-0.4.34
  Attempting uninstall: jaxlib
    Found existing installation: jaxlib 0.4.34
    Uninstalling jaxlib-0.4.34:
      Successfully uninstalled jaxlib-0.4.34
  Attempting uninstall: jax
    Found existing installation: jax 0.4.34
    Uninstalling jax-0.4.34:
      Successfully uninstalled jax-0.4.34
Successfully installed jax-0.4.33 jax-cuda12-pjrt-0.4.33 jax-cuda12-plugin-0.4.33 jaxlib-0.4.33
%

>>> poetry run python -c "import jax; jax.print_environment_info(); import jax.numpy as jnp; print(jnp.linspace(0, 1, 10))"

jax:    0.4.33
jaxlib: 0.4.33
numpy:  2.1.3
python: 3.10.12 (main, Jul 19 2023, 10:44:52) [GCC 7.5.0]
jax.devices (1 total, 1 local): [CudaDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='gpu065', release='5.4.0-131-generic', version='#147~18.04.1-Ubuntu SMP Sat Oct 15 13:10:18 UTC 2022', machine='x86_64')


$ nvidia-smi
Mon Nov 11 16:26:03 2024
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            On   | 00000000:87:00.0 Off |                    0 |
| N/A   36C    P0    27W /  70W |    105MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A     18467      C   ...cts/test/.venv/bin/python      102MiB |
+-----------------------------------------------------------------------------+

2024-11-11 16:26:03.535447: W external/xla/xla/service/gpu/nvptx_compiler.cc:893] The NVIDIA driver's CUDA version is 12.0 which is older than the PTX compiler version 12.6.77. Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
[0.         0.11111111 0.22222222 0.33333334 0.44444445 0.5555556
 0.6666667  0.7777778  0.8888889  1.        ]
%

>>> poetry run pip list | grep -E 'nvidia|jax'

jax                      0.4.33
jax-cuda12-pjrt          0.4.33
jax-cuda12-plugin        0.4.33
jaxlib                   0.4.33
nvidia-cublas-cu12       12.6.3.3
nvidia-cuda-cupti-cu12   12.6.80
nvidia-cuda-nvcc-cu12    12.6.77
nvidia-cuda-runtime-cu12 12.6.77
nvidia-cudnn-cu12        9.5.0.50
nvidia-cufft-cu12        11.3.0.4
nvidia-cusolver-cu12     11.7.1.2
nvidia-cusparse-cu12     12.5.4.2
nvidia-nccl-cu12         2.23.4
nvidia-nvjitlink-cu12    12.6.77

EDIT: It seems I can fix the problem by manually installing jax-cuda12-pjrt==0.4.33 alongside the latest jax==0.4.35, and this works. However, pip complains about the inconsistent dependencies between jax-cuda12-plugin==0.4.34 and jax-cuda12-pjrt==0.4.33. I don't know why this happens, but I have noticed it only happens with the NVIDIA driver version above (525.105.17); it does not happen on newer drivers like 535.171.04.
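
For anyone hitting the same thing, this is roughly the sequence for the workaround above (a sketch; pip will warn about the jax-cuda12-plugin / jax-cuda12-pjrt version mismatch):

# Install the latest jax with the CUDA 12 wheels, then pin the PJRT package back to 0.4.33
pip install -U "jax[cuda12]==0.4.35"
pip install "jax-cuda12-pjrt==0.4.33"
# pip reports a dependency conflict between jax-cuda12-plugin and jax-cuda12-pjrt, but the error above goes away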
