Question: WARNING: no GPU detected, will be using CPU #210

fengdelu opened this issue Jan 24, 2024 · 30 comments

Comments

@fengdelu

[QQ screenshot 20240125001555]
Hello, I'm having problems running it: I get a “WARNING: no GPU detected, will be using CPU” error. How can I fix this?
[QQ screenshot 20240125001758]

@YoshitakaMo
Owner

Unfortunately, this problem did not occur in my environment. Please check the following points.

  1. Does nvidia-smi work properly? This is mandatory.
  2. Reboot the machine.
  3. Install using the latest install_colabbatch_linux.sh in another directory.
  4. Check whether jax can recognize the device:
$ /path/to/your/localcolabfold/colabfold-conda/bin/python3.10
Python 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:36:39) [GCC 12.3.0] on linux
>>> import jax
>>> print(jax.local_devices()[0].platform)
gpu

I expect "gpu" to be printed here.
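The check in step 4 can also be saved as a small script and rerun after each change to the environment. A minimal sketch, assuming it is run with the colabfold-conda interpreter (for example `/path/to/your/localcolabfold/colabfold-conda/bin/python3.10 check_gpu.py`, where `check_gpu.py` and the helper name `first_device_platform` are hypothetical); it only uses `jax.local_devices()` as in the session above:

```python
import importlib.util


def first_device_platform() -> str:
    """Return the platform of the first local JAX device, e.g. "gpu" or "cpu".

    Returns "jax-not-installed" if this interpreter cannot import JAX at all.
    """
    if importlib.util.find_spec("jax") is None:
        return "jax-not-installed"
    import jax  # imported lazily so the helper still works without JAX

    return jax.local_devices()[0].platform


if __name__ == "__main__":
    # Prints "gpu" when JAX's CUDA backend initialized; "cpu" means it fell back.
    print(first_device_platform())
```

If it prints "cpu", JAX usually logs the reason ("CUDA backend failed to initialize: ...") just before, which is the message to look at.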

@dsclassen

I'm running into the same problem. I've performed a fresh installation following the instructions.
nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Tue_Mar__8_18:18:20_PST_2022
Cuda compilation tools, release 11.6, V11.6.124
Build cuda_11.6.r11.6/compiler.31057947_0

gcc --version

gcc (GCC) 10.3.1 20210422 (Red Hat 10.3.1-1)

nvidia-smi

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100 80GB PCIe          Off | 00000000:01:00.0 Off |                    0 |
| N/A   35C    P0              46W / 300W |    138MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA A100 80GB PCIe          Off | 00000000:41:00.0 Off |                    0 |
| N/A   36C    P0              45W / 300W |      4MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

But I'm still getting the warning that no GPU was found:

2024-01-29 12:50:21,031 Running colabfold 1.5.5 (a00ce1bcc477491d7693e3816d21ea3fc2cf40fd)

WARNING: You are welcome to use the default MSA server, however keep in mind that it's a
limited shared resource only capable of processing a few thousand MSAs per day. Please
submit jobs only from a single IP address. We reserve the right to limit access to the
server case-by-case when usage exceeds fair use. If you require more MSAs: You can 
precompute all MSAs with `colabfold_search` or host your own API and pass it to `--host-url`

2024-01-29 12:50:21,235 WARNING: no GPU detected, will be using CPU
2024-01-29 12:50:21,575 Found 5 citations for tools or databases
2024-01-29 12:50:21,575 Query 1/1: T1050_A7LXT1__Bacteroides_Ovatus__779_residues_ (length 779)

It seems JAX is the problem:

localcolabfold/colabfold-conda/bin/python3.10
Python 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:36:39) [GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> print(jax.local_devices()[0].platform)
CUDA backend failed to initialize: Found cuDNN version 8401, but JAX was built against version 8600, which is newer. The copy of cuDNN that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
cpu

I'm not sure how to update JAX within the localcolabfold conda env.
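The two numbers in that error are cuDNN integer version codes. For cuDNN 8.x the code is major * 1000 + minor * 100 + patch, so 8401 is cuDNN 8.4.1 and 8600 is cuDNN 8.6.0: the cuDNN picked up at runtime is older than the one jaxlib was built against. A small sketch of that decoding; the helper names are hypothetical, and the formula applies to cuDNN 8 only (cuDNN 9 uses a different encoding):

```python
from typing import Tuple


def decode_cudnn8_version(code: int) -> Tuple[int, int, int]:
    """Split a cuDNN 8.x version code (major*1000 + minor*100 + patch)."""
    major, rest = divmod(code, 1000)
    minor, patch = divmod(rest, 100)
    return major, minor, patch


def runtime_is_new_enough(found: int, built_against: int) -> bool:
    """True if the cuDNN found at runtime is at least the version JAX was built against."""
    return decode_cudnn8_version(found) >= decode_cudnn8_version(built_against)


# Numbers taken from the "CUDA backend failed to initialize" message above.
found, built = 8401, 8600
print(decode_cudnn8_version(found), decode_cudnn8_version(built))
print(runtime_is_new_enough(found, built))
```

Here the comparison is False, which matches the error: something on the library search path supplied cuDNN 8.4.1 where at least 8.6.0 was required.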

@dsclassen

I've figured this out. Following a hint on the JAX documentation site, I unset LD_LIBRARY_PATH, and now the GPUs are recognized.
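Presumably LD_LIBRARY_PATH pointed at an older system cuDNN (such as the 8.4.1 copy reported above), which the dynamic loader searched before the copy JAX expected. In a shell the fix is simply `unset LD_LIBRARY_PATH`. If ColabFold is launched from a Python wrapper instead, the variable can be dropped for the child process only. A sketch, assuming `colabfold_batch` is on PATH; `env_without_ld_library_path`, `run_colabfold`, and the paths passed to it are hypothetical:

```python
import os
import subprocess


def env_without_ld_library_path(base=None) -> dict:
    """Return a copy of the environment with LD_LIBRARY_PATH removed.

    The parent process's own environment (or the mapping passed in) is not modified.
    """
    env = dict(os.environ if base is None else base)
    env.pop("LD_LIBRARY_PATH", None)
    return env


def run_colabfold(input_path: str, output_dir: str) -> int:
    """Run colabfold_batch without LD_LIBRARY_PATH and return its exit code."""
    result = subprocess.run(
        ["colabfold_batch", input_path, output_dir],
        env=env_without_ld_library_path(),
        check=False,
    )
    return result.returncode
```

For example, `run_colabfold("input.fasta", "results/")`. Scoping the change to the child process avoids breaking other tools in the same session that may still rely on LD_LIBRARY_PATH.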

The next problem I had was this error:

2024-01-29 13:49:54,867 Running on GPU
2024-01-29 13:49:55,214 Found 5 citations for tools or databases
2024-01-29 13:49:55,214 Query 1/1: T1050_A7LXT1__Bacteroides_Ovatus__779_residues_ (length 779)
2024-01-29 13:49:57,823 Setting max_seq=512, max_extra_seq=5120
2024-01-29 13:49:57,870 Could not predict T1050_A7LXT1__Bacteroides_Ovatus__779_residues_. Not Enough GPU memory? INTERNAL: XLA requires ptxas version 11.8 or higher
2024-01-29 13:49:57,870 Done

I was able to solve this problem by installing the cuda-toolkit-11-8 package on our Rocky Linux 8.9 system:

dnf install cuda-toolkit-11-8

Now colabfold_batch is working as expected.
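XLA's ptxas requirement can be checked before starting a prediction by parsing the `release X.Y` field that `ptxas --version` prints (the same format as the `nvcc --version` output earlier in this thread). A sketch, assuming the CUDA toolkit's `ptxas` is on PATH; the 11.8 minimum comes from the error above, and the function names are hypothetical:

```python
import re
import shutil
import subprocess
from typing import Optional, Tuple

REQUIRED = (11, 8)  # minimum ptxas version from the XLA error message above


def parse_release(version_text: str) -> Optional[Tuple[int, int]]:
    """Extract (major, minor) from text like 'Cuda compilation tools, release 11.6, V11.6.124'."""
    match = re.search(r"release (\d+)\.(\d+)", version_text)
    return (int(match.group(1)), int(match.group(2))) if match else None


def ptxas_release() -> Optional[Tuple[int, int]]:
    """Run `ptxas --version` if ptxas is on PATH and return its release, else None."""
    exe = shutil.which("ptxas")
    if exe is None:
        return None
    out = subprocess.run([exe, "--version"], capture_output=True, text=True).stdout
    return parse_release(out)


if __name__ == "__main__":
    release = ptxas_release()
    print("ptxas release:", release)
    print("new enough for XLA:", release is not None and release >= REQUIRED)
```

Tuples compare element by element, so (11, 6) < (11, 8) <= (12, 4) behaves like a version comparison here.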

@A-Talavera

Hi, I was having the same error "WARNING: no GPU detected, will be using CPU".

I upgraded the NVIDIA driver to 545.23.08 with CUDA version 12.3 and checked this with both "nvidia-smi" and "nvcc --version". Nonetheless, every time I ran "colabfold_batch" I got the same warning. Following @YoshitakaMo's instructions, I saw that JAX could not load CUDA. Despite having the right version of CUDA installed, JAX still complained that the CUDA version was 11.7 rather than 11.8.

I could fix this issue by going into the installation script (install_colabbatch_linux.sh) and changing line 25:

"$COLABFOLDDIR/colabfold-conda/bin/pip" install --upgrade "jax[cuda11_pip]==0.4.23"
to
"$COLABFOLDDIR/colabfold-conda/bin/pip" install --upgrade "jax[cuda12_pip]==0.4.23"

Basically, instead of installing the CUDA 11 version of JAX, I installed the version for CUDA 12. After that, the localcolabfold environment has the right CUDA installation, and colabfold_batch runs as nicely and smoothly as before.

Best regards,
Ariel
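The same one-line edit to install_colabbatch_linux.sh can be applied programmatically before running the installer. A sketch, assuming the script still contains a `jax[cuda11_pip]==...` requirement like the line quoted above (newer versions of the script may differ); the function name is hypothetical:

```python
import re


def switch_jax_extra(script_text: str, cuda_major: int) -> str:
    """Rewrite every jax[cudaNN_pip] extra in an install script to the given CUDA major version."""
    return re.sub(r"jax\[cuda\d+_pip\]", f"jax[cuda{cuda_major}_pip]", script_text)


line = '"$COLABFOLDDIR/colabfold-conda/bin/pip" install --upgrade "jax[cuda11_pip]==0.4.23"'
print(switch_jax_extra(line, 12))
```

To patch the real file, read it with `pathlib.Path("install_colabbatch_linux.sh").read_text()`, pass the text through `switch_jax_extra`, and write it back. Later comments in this thread report pip warnings that these `cuda*_pip` extras are not provided, so this only applies to the pinned 0.4.23 install line.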

@broomsday

broomsday commented Apr 4, 2024

BTW, I had this same error recently after a fresh install where both nvidia-smi and nvcc --version confirmed that I had a GPU available with CUDA 11.8.

In the end, the issue was the TensorFlow installation, which seems to default to CUDA 12.

The solution for me was to enter the conda environment after running the install script and run pip install "tensorflow[and-cuda]==2.14", which forces the last TensorFlow version compatible with CUDA 11.8.

I used the script below, based on the Linux install script, to do this; your mileage may vary.

COLABFOLDDIR="/AlphaFold/localcolabfold/"
source "${COLABFOLDDIR}/conda/etc/profile.d/conda.sh"
export PATH="${COLABFOLDDIR}/conda/condabin:${PATH}"
conda activate "$COLABFOLDDIR/colabfold-conda"
"$COLABFOLDDIR/colabfold-conda/bin/pip" install "tensorflow[and-cuda]==2.14"

@HobbitBaba

I have an AMD GPU in my workstation. Will this GPU work with AlphaFold?

@YoshitakaMo
Owner

No, AMD GPUs are not supported currently.

@frenko

frenko commented Apr 30, 2024

I still have the same problem.
I have a machine with Ubuntu 22.04 and an RTX 3080 Ti, and I installed the drivers and CUDA via the .run installer from the NVIDIA site.

The output of nvcc:

fred@jade:~$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Mar_28_02:18:24_PDT_2024
Cuda compilation tools, release 12.4, V12.4.131
Build cuda_12.4.r12.4/compiler.34097967_0

The output of nvidia-smi:

Tue Apr 30 12:06:52 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+

If I launch python3.10 from localcolabfold and import jax, I get this:

CUDA backend failed to initialize: Unable to load cuDNN. Is it installed? (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

@crshin

crshin commented May 3, 2024

I'm having the same problem.

2024-05-03 06:24:18,051 WARNING: no GPU detected, will be using CPU

I'm using exactly the same nvcc version you mentioned.
To solve the problem, I uninstalled and reinstalled it, but the same problem occurs.
How can I solve this?

These are my system's GPU details
(Ubuntu 22.04.2 LTS and RTX 4090):
$nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Sep_21_10:33:58_PDT_2022
Cuda compilation tools, release 11.8, V11.8.89
Build cuda_11.8.r11.8/compiler.31833905_0

$gcc --version
gcc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Copyright (C) 2021 Free Software Foundation, Inc.
This is free software; see the source for copying conditions. There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

$nvidia-smi
Fri May 3 06:31:22 2024
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 520.61.05    Driver Version: 520.61.05    CUDA Version: 11.8     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA Graphics...  On   | 00000000:65:00.0 Off |                  Off |
|  0%   44C    P8    23W / 450W |      1MiB / 24564MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

python -c "import jax; print(f'Jax backend: {jax.default_backend()}')"
CUDA backend failed to initialize: Unable to load cuDNN. Is it installed? (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Jax backend: cpu

@YoshitakaMo
Owner

@frenko @crshin Could you tell me your versions of nvidia-cudnn and jaxlib? In my environment, for example:

$ /path/to/your/localcolabfold/colabfold-conda/bin/python3.10 -m pip list | grep nvidia-cudnn
nvidia-cudnn-cu11            9.1.0.70
$ colabfold-conda/bin/python3.10 -m pip list | grep jax
jax                          0.4.23
jaxlib                       0.4.23+cuda11.cudnn86

I suspect a mismatch between the installed cuDNN and the cuDNN version that jaxlib was built against (the +cudaXX.cudnnXX suffix) causes the error.
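The cuDNN version a jaxlib wheel was built against is encoded in its local version tag (`+cuda11.cudnn86` means CUDA 11 and cuDNN 8.6), so it can be compared with the installed nvidia-cudnn-* wheel without grepping `pip list`. A sketch using only the standard library, assuming it runs inside colabfold-conda; the helper names are hypothetical. A differing major version is only a hint, since the combination shown above (cudnn86 with nvidia-cudnn-cu11 9.1.0.70) works on this machine:

```python
import re
from importlib import metadata
from typing import Dict, Optional, Tuple

PACKAGES = ("jax", "jaxlib", "nvidia-cudnn-cu11", "nvidia-cudnn-cu12")


def installed_versions(names=PACKAGES) -> Dict[str, Optional[str]]:
    """Installed version of each distribution, or None if it is not installed."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def parse_jaxlib_build(version: str) -> Optional[Tuple[int, Tuple[int, int]]]:
    """'0.4.23+cuda11.cudnn86' -> (11, (8, 6)); None if the version has no CUDA tag."""
    match = re.search(r"\+cuda(\d+)\.cudnn(\d)(\d+)", version)
    if match is None:
        return None
    return int(match.group(1)), (int(match.group(2)), int(match.group(3)))


if __name__ == "__main__":
    found = installed_versions()
    print(found)
    build = parse_jaxlib_build(found["jaxlib"] or "")
    if build is not None:
        cuda_major, (cudnn_major, cudnn_minor) = build
        wheel = found.get(f"nvidia-cudnn-cu{cuda_major}")
        print(f"jaxlib built for CUDA {cuda_major}, cuDNN {cudnn_major}.{cudnn_minor}; "
              f"nvidia-cudnn-cu{cuda_major}: {wheel}")
        if wheel is not None and int(wheel.split(".")[0]) != cudnn_major:
            print("cuDNN major versions differ; worth checking if JAX cannot load cuDNN")
```

jaxlib releases that ship CUDA support as separate plugin packages (such as the plain `jaxlib 0.4.26` reported later in this thread) carry no such tag, in which case the parser returns None.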

@frenko

frenko commented May 6, 2024

From my localcolabfold:
./localcolabfold/colabfold-conda/bin/python3.10 -m pip list | grep nvidia-cudnn
nvidia-cudnn-cu12 9.1.0.70
and
./localcolabfold/colabfold-conda/bin/python3.10 -m pip list | grep jax
jax 0.4.23
jaxlib 0.4.23+cuda12.cudnn89

CUDA 12 is installed on my system, so, as suggested by some users, I replaced jax[cuda11_pip]==0.4.23 with jax[cuda12_pip]==0.4.23 in the localcolabfold installation script.

@YoshitakaMo
Owner

@frenko On my system I installed cuda-12.3, and nvidia-cudnn-cu11 9.1.0.70 with jaxlib 0.4.23+cuda11.cudnn86 works properly. Did jax[cuda11_pip] not work in your case?

@crshin

crshin commented May 6, 2024

Sorry for the late reply.

From my localcolabfold:
./localcolabfold/colabfold-conda/bin$ python3.10 -m pip list | grep nvidia-cudnn
nvidia-cudnn-cu11 9.1.0.70
./localcolabfold/colabfold-conda/bin$ python3.10 -m pip list | grep jax
jax 0.4.23
jaxlib 0.4.23+cuda11.cudnn86

In my case, I'm using CUDA 11.

@frenko

frenko commented May 6, 2024

Quoting @YoshitakaMo: "@frenko In my system, I installed cuda-12.3, but nvidia-cudnn-cu11 9.1.0.70 and jaxlib 0.4.23+cuda11.cudnn86 work properly. But did it not work with jax[cuda11_pip] in your case?"

I solved it with the following command:
./localcolabfold/colabfold-conda/bin/python3.10 -m pip install jax[cuda12_pip] -U
Evidently both jax and jaxlib were updated, and now localcolabfold sees the GPU and uses it.
I'm not in the lab now; I will update this comment tomorrow with more details.

@YoshitakaMo
Owner

@crshin Could you test the solution of @frenko for your case?

@crshin

crshin commented May 7, 2024

In my case, that doesn't work.
Here is my log:

(base) /home/($user)/data/SW/localcolabfold/localcolabfold/colabfold-conda/bin$ python3.10 -m pip install jax[cuda11_pip] -U
Requirement already satisfied: jax[cuda11_pip] in ./localcolabfold/colabfold-conda/lib/python3.10/site-packages (0.4.23)
Collecting jax[cuda11_pip]
  Using cached jax-0.4.26-py3-none-any.whl.metadata (23 kB)
WARNING: jax 0.4.26 does not provide the extra 'cuda11-pip'
Requirement already satisfied: ml-dtypes>=0.2.0 in /home/($user)/data/SW/localcolabfold/localcolabfold/colabfold-conda/lib/python3.10/site-packages (from jax[cuda11_pip]) (0.3.2)
Requirement already satisfied: numpy>=1.22 in /home/($user)/data/SW/localcolabfold/localcolabfold/colabfold-conda/lib/python3.10/site-packages (from jax[cuda11_pip]) (1.26.4)
Requirement already satisfied: opt-einsum in /home/($user)/.local/lib/python3.10/site-packages (from jax[cuda11_pip]) (3.3.0)
Requirement already satisfied: scipy>=1.9 in /home/($user)/.local/lib/python3.10/site-packages (from jax[cuda11_pip]) (1.10.1)
Using cached jax-0.4.26-py3-none-any.whl (1.9 MB)
Installing collected packages: jax
  Attempting uninstall: jax
    Found existing installation: jax 0.4.23
    Uninstalling jax-0.4.23:
      Successfully uninstalled jax-0.4.23
Successfully installed jax-0.4.26
$colabfold_batch --templates --amber (input) (output) 
2024-05-07 07:53:49,959 Running colabfold 1.5.5 (07644a8bbfbc00c96e7d897d96fb3b11d974b766)
2024-05-07 07:53:50,053 WARNING: no GPU detected, will be using CPU
2024-05-07 07:53:50,547 Found 9 citations for tools or databases
2024-05-07 07:53:50,547 Query 1/1: pdbA_B (length 298)
COMPLETE: 100%|█████████████████████████████████████████████████████████████████████████████████████| 300/300 [elapsed: 00:01 remaining: 00:00]
2024-05-07 07:53:58,141 Sequence 0 found templates: ['1m4u_L', '2r52_B', '6oml_Y', '5vt2_B', '2r53_A', '1lxi_A', '4n1d_A', '7zjf_B', '7zjf_A', '6z3g_A', '3qb4_C', '3qb4_A', '6z3j_A', '2h64_A', '4uhy_A', '1reu_A', '4ui0_A', '2h62_B', '4mid_A', '3bk3_B']

.
.
.
(I killed it with ctrl+z)

@frenko

frenko commented May 7, 2024

As promised yesterday, here is an update on my situation. I updated the jax packages with pip:
./localcolabfold/colabfold-conda/bin/python3.10 -m pip install -U jax[cuda12]
Running the following command now gives:

./localcolabfold/colabfold-conda/bin/python3.10 -m pip list | grep jax
jax 0.4.26
jax-cuda12-pjrt 0.4.26
jax-cuda12-plugin 0.4.26
jaxlib 0.4.26

For cuDNN:
nvidia-cudnn-cu12 8.9.7.29
And now, if I import jax in Python, I can see the GPUs:

./localcolabfold/colabfold-conda/bin/python3.10
Python 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> print(jax.local_devices()[0].platform)
gpu
>>> 

@f-meireles

Hi everyone,

I was having the same problem with the GPU not being detected, and after updating JAX as @frenko did, I got it to work. My issue now seems to be related to GPU memory:

colabfold_batch . resultsrec3/
2024-05-07 15:24:43,218 Running colabfold 1.5.5 (57b220e028610ba7331ebe1ef9c2d0419992469a)

WARNING: You are welcome to use the default MSA server, however keep in mind that it's a
limited shared resource only capable of processing a few thousand MSAs per day. Please
submit jobs only from a single IP address. We reserve the right to limit access to the
server case-by-case when usage exceeds fair use. If you require more MSAs: You can
precompute all MSAs with `colabfold_search` or host your own API and pass it to --host-url

2024-05-07 15:24:43,278 Running on GPU
2024-05-07 15:24:43,535 Found 5 citations for tools or databases
2024-05-07 15:24:43,535 Query 1/1: D-OXA-23-prot (length 273)
COMPLETE: 100%|██████████████████| 150/150 [elapsed: 00:03 remaining: 00:00]
2024-05-07 15:24:47,456 Setting max_seq=512, max_extra_seq=5120
2024-05-07 15:24:47,470 Could not predict D-OXA-23-prot. Not Enough GPU memory? FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
2024-05-07 15:24:47,470 Done

As for the versions of CUDA, JAX, and cudnn:

./localcolabfold/colabfold-conda/bin/python3.10 -m pip list | grep jax
jax                          0.4.26
jax-cuda12-pjrt              0.4.26
jax-cuda12-plugin            0.4.26
jaxlib                       0.4.26+cuda12.cudnn89
./localcolabfold/colabfold-conda/bin/python3.10 -m pip list | grep cudnn
jaxlib                       0.4.26+cuda12.cudnn89
nvidia-cudnn-cu12            8.9.2.26
nvidia-smi
Tue May  7 15:37:01 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4070        Off |   00000000:01:00.0 Off |                  N/A |
|  0%   42C    P8             15W /  200W |       8MiB /  12282MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      2070      G   /usr/lib/xorg/Xorg                              4MiB |
+-----------------------------------------------------------------------------------------+
nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Mar_28_02:18:24_PDT_2024
Cuda compilation tools, release 12.4, V12.4.131
Build cuda_12.4.r12.4/compiler.34097967_0

Thanks a lot in advance and let me know if I should provide any more info.

@YoshitakaMo
Owner

First, check your CUDA version using /usr/local/cuda/bin/nvcc --version, not the one displayed in the header of nvidia-smi (that value is the highest CUDA version the driver supports, not the installed toolkit).
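The two numbers answer different questions, and they can disagree: the A100 machine earlier in this thread reports `CUDA Version: 12.3` in nvidia-smi while its nvcc reports release 11.6. A sketch that pulls both values out of captured command output; the helper names are hypothetical, and the sample strings are copied from that earlier comment:

```python
import re
from typing import Optional, Tuple

Version = Tuple[int, int]


def _major_minor(pattern: str, text: str) -> Optional[Version]:
    """Return (major, minor) from the first regex match, or None."""
    match = re.search(pattern, text)
    return (int(match.group(1)), int(match.group(2))) if match else None


def driver_cuda_version(nvidia_smi_text: str) -> Optional[Version]:
    """Highest CUDA version the driver supports, from the nvidia-smi header."""
    return _major_minor(r"CUDA Version:\s*(\d+)\.(\d+)", nvidia_smi_text)


def toolkit_cuda_version(nvcc_text: str) -> Optional[Version]:
    """Installed CUDA toolkit version, from `nvcc --version` output."""
    return _major_minor(r"release (\d+)\.(\d+)", nvcc_text)


smi = "| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |"
nvcc = "Cuda compilation tools, release 11.6, V11.6.124"
driver, toolkit = driver_cuda_version(smi), toolkit_cuda_version(nvcc)
print(f"driver supports up to CUDA {driver}; installed toolkit: {toolkit}")
```

To use it on a live system, feed it the `stdout` of `subprocess.run(["nvidia-smi"], ...)` and of `/usr/local/cuda/bin/nvcc --version`.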

@crshin I recommend jax and jaxlib 0.4.23 with GPU support, not the latest release:

./localcolabfold/colabfold-conda/bin/python3.10 -m pip install "jax[cuda11_pip]==0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# or use this one for cuda 12
# ./localcolabfold/colabfold-conda/bin/python3.10 -m pip install "jax[cuda12_pip]==0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
./localcolabfold/colabfold-conda/bin/python3.10 -m pip install jax==0.4.23

@f-meireles
It's likely due to a mismatch between your CUDA version and the installed jaxlib. See jax-ml/jax#15361 (comment) or jax-ml/jax#15361 (comment). I recommend version 0.4.23.

@crshin

crshin commented May 8, 2024

To start fresh, I uninstalled the existing localcolabfold installation and reinstalled it.
Then I noticed the same problem occurred again,
so I installed jax and jaxlib==0.4.23 as you suggested:

./localcolabfold/colabfold-conda/bin/python3.10 -m pip install "jax[cuda11_pip]==0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Here are my current nvcc, gcc, and nvidia-smi versions:

/usr/local/cuda/bin/nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Sep_21_10:33:58_PDT_2022
Cuda compilation tools, release 11.8, V11.8.89
Build cuda_11.8.r11.8/compiler.31833905_0
$gcc --version
gcc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Copyright (C) 2021 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
$nvidia-smi
Wed May  8 02:15:50 2024
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 520.61.05    Driver Version: 520.61.05    CUDA Version: 11.8     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA Graphics...  On   | 00000000:65:00.0 Off |                  Off |
|  0%   46C    P8    23W / 450W |      1MiB / 24564MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
./localcolabfold/colabfold-conda/bin$ python3.10 -m pip list | grep jax
jax                          0.4.23
jaxlib                       0.4.23+cuda11.cudnn86
./localcolabfold/colabfold-conda/bin$ python3.10 -m pip list | grep cudnn
jaxlib                       0.4.23+cuda11.cudnn86
nvidia-cudnn-cu11            9.1.0.70

But I found some points that I suspect cause the problem.
The first is below:

python3.10
Python 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> print(jax.local_devices()[0].platform)
CUDA backend failed to initialize: Unable to load cuDNN. Is it installed? (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
cpu

The second is that this warning kept popping up during the installation of "jax[cuda11_pip]==0.4.23":

./localcolabfold/colabfold-conda/bin/python3.10 -m pip install "jax[cuda11_pip]==0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
"WARNING: jax 0.4.23 does not provide the extra 'cuda11-pip'"

I'm looking for a solution to this in my own way, but it's not easy. Do you have any guesses or solutions?

@YoshitakaMo
Owner

I've noticed that the "cuda11_pip"/"cuda12_pip" installation method is no longer recommended: https://jax.readthedocs.io/en/latest/installation.html
I've freshly installed my localcolabfold on Ubuntu 22.04 with the updated install_colabbatch_linux.sh. It worked properly under both cuda/11.8 and cuda/12.4 (I switched the version using sudo update-alternatives --config cuda).

$ ./localcolabfold/colabfold-conda/bin/python3.10 -m pip list | grep jax
jax                          0.4.23
jax-cuda12-pjrt              0.4.23
jax-cuda12-plugin            0.4.23
jaxlib                       0.4.23+cuda12.cudnn89
$ ./localcolabfold/colabfold-conda/bin/python3.10 -m pip list | grep cudnn
jaxlib                       0.4.23+cuda12.cudnn89
nvidia-cudnn-cu12            9.1.0.70

@crshin How about updating CUDA and the NVIDIA driver to the latest versions? That might fix the problem. I'm using CUDA 12.4 and NVIDIA-SMI 550.54.15.

@f-meireles

@YoshitakaMo Thanks for the reply. Even after using the latest version of the install_colabbatch_linux.sh I still get the same error:

2024-05-08 10:59:07,266 Running colabfold 1.5.5 (57b220e028610ba7331ebe1ef9c2d0419992469a)

WARNING: You are welcome to use the default MSA server, however keep in mind that it's a
limited shared resource only capable of processing a few thousand MSAs per day. Please
submit jobs only from a single IP address. We reserve the right to limit access to the
server case-by-case when usage exceeds fair use. If you require more MSAs: You can 
precompute all MSAs with `colabfold_search` or host your own API and pass it to `--host-url`

2024-05-08 10:59:07,325 Running on GPU
2024-05-08 10:59:07,583 Found 5 citations for tools or databases
2024-05-08 10:59:07,583 Query 1/1: D-OXA-23-prot (length 273)
2024-05-08 10:59:08,002 Setting max_seq=512, max_extra_seq=5120
2024-05-08 10:59:08,015 Could not predict D-OXA-23-prot. Not Enough GPU memory? FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
2024-05-08 10:59:08,015 Done

As for the nvcc version:

/usr/local/cuda/bin/nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Mar_28_02:18:24_PDT_2024
Cuda compilation tools, release 12.4, V12.4.131
Build cuda_12.4.r12.4/compiler.34097967_0

Which should be compatible with my jax version:

/home/fteixeir/anaconda3_2024/envs/localcolabfold/colabfold-conda/bin/pip list |grep jax
jax                          0.4.23
jax-cuda12-pjrt              0.4.23
jax-cuda12-plugin            0.4.23
jaxlib                       0.4.23+cuda12.cudnn89

Any idea what else could be the issue? Thanks again!

@YoshitakaMo
Owner

@f-meireles Which is your OS, Windows (WSL2) or native Ubuntu 22.04?
Also, have you tried /home/fteixeir/anaconda3_2024/envs/localcolabfold/colabfold-conda/bin/python3.10 -m pip install --upgrade nvidia-cudnn-cu11==8.5.0.96 (See: #228 or #224)

@f-meireles

@YoshitakaMo It is native Ubuntu 22.04.

I just tried the solution you suggested, but it didn't work either. I'll keep trying things here.

@crshin

crshin commented May 10, 2024

I updated CUDA from 11.8 to 12.4 as you advised
and reinstalled with $bash install_colabbatch_linux.sh.
Now my nvcc, jax, cuDNN, and nvidia-smi versions are as follows:

$/usr/local/cuda/bin/nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Tue_Feb_27_16:19:38_PST_2024
Cuda compilation tools, release 12.4, V12.4.99
Build cuda_12.4.r12.4/compiler.33961263_0
$python3.10 -m pip list | grep jax
jax                          0.4.23
jaxlib                       0.4.23+cuda12.cudnn89
$python3.10 -m pip list | grep cudnn
jaxlib                       0.4.23+cuda12.cudnn89
nvidia-cudnn-cu11            8.5.0.96
nvidia-cudnn-cu12            9.1.0.70
$nvidia-smi
Fri May 10 07:20:32 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.14              Driver Version: 550.54.14      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+

In this state, the same problem occurred as for other users:

$python3.10
Python 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> print(jax.local_devices()[0].platform)
CUDA backend failed to initialize: Unable to load cuDNN. Is it installed? (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
cpu
>>> exit()

So, as in @frenko's comment, I installed jax-cuda12-pjrt and jax-cuda12-plugin, here pinned to 0.4.23:

$python3.10 -m pip install --upgrade jax-cuda12-pjrt==0.4.23
$python3.10 -m pip install --upgrade jax-cuda12-plugin==0.4.23
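Pinning the plugin packages to 0.4.23 keeps them in step with jax 0.4.23 and jaxlib 0.4.23+cuda12.cudnn89; the CUDA plugin is expected to match the jaxlib version. A sketch that lists any jax-related package whose base version (ignoring a `+cuda...` local tag) differs from jax's; the package list and function names are assumptions for illustration, and versions are read with `importlib.metadata`:

```python
from importlib import metadata
from typing import Dict, List, Optional

JAX_PACKAGES = ("jax", "jaxlib", "jax-cuda12-pjrt", "jax-cuda12-plugin")


def base_version(version: str) -> str:
    """Drop a local version tag: '0.4.23+cuda12.cudnn89' -> '0.4.23'."""
    return version.split("+", 1)[0]


def mismatched(versions: Dict[str, Optional[str]]) -> List[str]:
    """Names of installed packages whose base version differs from jax's."""
    reference = versions.get("jax")
    if reference is None:
        return []
    ref = base_version(reference)
    return [name for name, v in versions.items()
            if v is not None and base_version(v) != ref]


def installed(names=JAX_PACKAGES) -> Dict[str, Optional[str]]:
    """Installed version of each package, or None if it is missing."""
    result: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            result[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    print(mismatched(installed()) or "jax packages are consistent")
```

Comparing against jax's version rather than jaxlib's is a design choice here; either works as the reference as long as all four packages are checked.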

After that, my localcolabfold successfully detected the GPU:

$python3.10
Python 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> print(jax.local_devices()[0].platform)
gpu

However, I'm now experiencing the same problem as @f-meireles:

2024-05-10 06:23:00,300 Running colabfold 1.5.5 (57b220e028610ba7331ebe1ef9c2d0419992469a)
2024-05-10 06:23:00,635 Running on GPU
2024-05-10 06:23:01,187 Found 9 citations for tools or databases
2024-05-10 06:23:01,188 Query 1/1: pdb_A (length 108)
2024-05-10 06:23:04,173 Sequence 0 found templates: ['1m4u_L', '2r52_B', '6oml_Y', '5vt2_B', '2r53_A', '1lxi_A', '4n1d_A', '7zjf_B', '7zjf_A', '6z3g_A', '3qb4_C', '3qb4_A', '6z3j_A', '2h64_A', '4uhy_A', '1reu_A', '4ui0_A', '2h62_B', '4mid_A', '3bk3_B']
2024-05-10 06:23:04,966 Setting max_seq=512, max_extra_seq=5120
2024-05-10 06:23:04,993 Could not predict pdb_A. Not Enough GPU memory? FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
2024-05-10 06:23:04,993 Done

I tried the suggested fix, $python3.10 -m pip install --upgrade nvidia-cudnn-cu11==8.5.0.9,
but it didn't work in my case either.

@jdmontenegro

Quoting @frenko:

As mentioned yesterday I update my situation. After updating the jax packages with pip: ./localcolabfold/colabfold-conda/bin/python3.10 -m pip install -U jax[cuda12] if I run the following command

./localcolabfold/colabfold-conda/bin/python3.10 -m pip list | grep jax
jax 0.4.26
jax-cuda12-pjrt 0.4.26
jax-cuda12-plugin 0.4.26
jaxlib 0.4.26

and for cudnn: nvidia-cudnn-cu12 8.9.7.29 and now if I import jax in python I can see GPUs:

./localcolabfold/colabfold-conda/bin/python3.10
Python 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> print(jax.local_devices()[0].platform)
gpu
>>> 

I was running into the same issue, and this solution worked for me: just updating to nvidia-cudnn-cu12 and jax[cuda12] as recommended in the reply above. I'm on Oracle Linux 9.4 with Tesla 4 GPUs.

@YoshitakaMo
Owner

@crshin Your problem is more likely to be a JAX/CUDA issue rather than localcolabfold itself... It may be more helpful to ask in a JAX issue forum: jax-ml/jax#15361

@fatpmeireles

@YoshitakaMo @crshin I was finally able to make it work! I had to get ColabFold directly from the original repository. In a new, empty conda env with Python 3.10 I did:

pip install --upgrade "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install "colabfold[alphafold] @ git+https://github.com/sokrypton/ColabFold"

And now it works normally. Unfortunately, I didn't figure out what was wrong, but I hope this works for you as well.

@crshin

crshin commented May 14, 2024

That works on my computer too!! @YoshitakaMo @fatpmeireles
Thank you so much.
I also don't know what's wrong, but anyway, it works well now.
I think we had the same problem.

@johnnytam100

Quoting @fatpmeireles:

@YoshitakaMo @crshin I was finally able to make it work! I had to get ColabFold directly from the original repository. In an new, empty conda env with Python3.10 I did:

pip install --upgrade "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install "colabfold[alphafold] @ git+https://github.com/sokrypton/ColabFold"

And now it works normally. Unfortunately, I didn't figure out what was wrong, but I hope this works for you as well.

I met the same problem, and this workaround just worked! Thanks!
Specifically, I did:

conda create -n localcolabfold python=3.10
conda activate localcolabfold
pip install --upgrade "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install "colabfold[alphafold] @ git+https://github.com/sokrypton/ColabFold"
