ModuleNotFoundError: No module named 'jax.extend' related to #209, #210 #224

Closed
ZubuNoShoshinsha opened this issue Apr 1, 2024 · 9 comments


@ZubuNoShoshinsha

Hello,

My question is related to #209 and #210
My environment is...

WSL2
OS: Ubuntu 22.04.4
GCC: 11.4.0
CUDA: 12.1
GPU: RTX 4090
LocalColabFold Ver. 1.5.5

As instructed in #209, I checked whether the GPU was recognized, and it was not.
So, I downgraded jax and jaxlib to
jax 0.4.7
jaxlib 0.4.7+cuda11.cudnn82
as instructed in #209.
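
For reference, a downgrade like this can be done with pip inside the bundled environment; the wheel index below is an assumption on my part, based on the usual JAX CUDA release page:

$ /path/to/your/localcolabfold/colabfold-conda/bin/python3.10 -m pip install "jax==0.4.7" "jaxlib==0.4.7+cuda11.cudnn82" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html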

And then I checked again using
$ /path/to/your/localcolabfold/colabfold-conda/bin/python3.10

import jax
print(jax.local_devices()[0].platform)

and "gpu" was returned.

Then I ran localcolabfold, but it stopped with the error message below:

2024-04-01 15:14:35,452 Running colabfold 1.5.5 (61df3b853140ca79dbdf64349824beb14364ebfd)
2024-04-01 15:14:36,006 Running on GPU
Traceback (most recent call last):
  File "/mnt/d/Alphafold/localcolabfold/colabfold-conda/bin/colabfold_batch", line 8, in <module>
    sys.exit(main())
  File "/mnt/d/AlphaFold/localcolabfold/colabfold-conda/lib/python3.10/site-packages/colabfold/batch.py", line 2037, in main
    run(
  File "/mnt/d/AlphaFold/localcolabfold/colabfold-conda/lib/python3.10/site-packages/colabfold/batch.py", line 1292, in run
    from colabfold.alphafold.models import load_models_and_params
  File "/mnt/d/AlphaFold/localcolabfold/colabfold-conda/lib/python3.10/site-packages/colabfold/alphafold/models.py", line 4, in <module>
    import haiku
  File "/mnt/d/AlphaFold/localcolabfold/colabfold-conda/lib/python3.10/site-packages/haiku/__init__.py", line 20, in <module>
    from haiku import experimental
  File "/mnt/d/AlphaFold/localcolabfold/colabfold-conda/lib/python3.10/site-packages/haiku/experimental/__init__.py", line 34, in <module>
    from haiku._src.dot import abstract_to_dot
  File "/mnt/d/AlphaFold/localcolabfold/colabfold-conda/lib/python3.10/site-packages/haiku/_src/dot.py", line 29, in <module>
    from jax.extend import linear_util as lu
ModuleNotFoundError: No module named 'jax.extend'

It would be helpful if there were any instructions for solving this issue.

@YoshitakaMo
Owner

I suspect that the issue lies in the version of the dm-haiku module being 0.0.11 or later. In my environment:

$ localcolabfold/colabfold-conda/bin/python3.10 -m pip list

jax                          0.4.23
jaxlib                       0.4.23+cuda11.cudnn86
chex                         0.1.85
dm-haiku                     0.0.10

If CUDA 12.1 is installed, these versions should be fine.

Please set your dm-haiku to version 0.0.10. Otherwise, you may encounter the error ModuleNotFoundError: No module named 'jax.extend'.
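
For example (an illustrative command; the path is a placeholder for your installation):

$ /path/to/your/localcolabfold/colabfold-conda/bin/python3.10 -m pip install dm-haiku==0.0.10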

@ZubuNoShoshinsha
Author

Thank you for your suggestion (I just noticed your response).
My dm-haiku was actually 0.0.12, so I downgraded it to 0.0.10 as you suggested.

I then ran localcolabfold 1.5.5.
My environment is now:
jax 0.4.7
jaxlib 0.4.7+cuda11.cudnn82
chex 0.1.82
dm-haiku 0.0.10

No more " ModuleNotFoundError: No module named 'jax.extend' ", but now new message showed up and the program stopped.

" Could not predict ProteinA. Not Enough GPU memory? FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details. "

Are there any further suggestions to solve this?
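
This error often points to a jaxlib/cuDNN version mismatch rather than a lack of memory, so one thing worth checking is which cuDNN wheel is installed in the bundled environment, for example:

$ /path/to/your/localcolabfold/colabfold-conda/bin/python3.10 -m pip list | grep -i cudnn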

@ZubuNoShoshinsha
Author

My nvidia-smi result is attached as a screenshot.

@YoshitakaMo
Copy link
Owner

If you are using WSL2, did you turn on the settings shown in https://github.com/YoshitakaMo/localcolabfold?tab=readme-ov-file#for-wsl2-in-windows ?
Unfortunately, I can't figure out the cause because I don't have a WSL2 environment.

@ZubuNoShoshinsha
Author

Yes, I did.
I restarted WSL2 and tried again, but it didn't work.

Thank you though.

@ZubuNoShoshinsha
Author

I wonder about one thing:
when I downgraded dm-haiku, pip said
"colabfold 1.5.5 requires dm-haiku<0.0.13,>=0.0.12, but you have dm-haiku 0.0.10 which is incompatible."
Is it still fine to run colabfold with this version?

@ZubuNoShoshinsha
Author

Finally, I might have found the solution.

I downgraded nvidia-cudnn-cu11 from 9.0.0.312 to 8.5.0.96 with this command:

pip install --upgrade nvidia-cudnn-cu11==8.5.0.96
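
A quick way to confirm that jax can now initialize the cuDNN backend (a minimal sketch, assuming the same bundled Python; the small convolution is only there to force cuDNN to load):

import jax
import jax.numpy as jnp

# A convolution runs through cuDNN on GPU, so this fails fast if the
# installed cuDNN wheel does not match what jaxlib was built against.
x = jnp.ones((1, 1, 8, 8))   # NCHW input
k = jnp.ones((1, 1, 3, 3))   # OIHW kernel
y = jax.lax.conv(x, k, window_strides=(1, 1), padding="SAME")
print(jax.local_devices()[0].platform, y.shape)   # expected: gpu (1, 1, 8, 8)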

I ran localcolabfold again and it processed very smoothly on the GPU.

I was astonished.

Thank you.

@Vinaysukhesh98

jax                          0.4.23
jaxlib                       0.4.23+cuda11.cudnn86
chex                         0.1.85
dm-haiku                     0.0.10

Requirement already satisfied: torch==1.13.1 in /usr/local/lib/python3.10/dist-packages (1.13.1)
Requirement already satisfied: transformers==4.24.0 in /usr/local/lib/python3.10/dist-packages (4.24.0)
Collecting diffusers==0.3.0
Using cached diffusers-0.3.0-py3-none-any.whl (153 kB)
Collecting jax==0.4.23
Downloading jax-0.4.23-py3-none-any.whl (1.7 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.7/1.7 MB 6.3 MB/s eta 0:00:00
ERROR: Could not find a version that satisfies the requirement jaxlib==0.4.23+cuda11.cudnn86 (from versions: 0.4.6, 0.4.7, 0.4.9, 0.4.10, 0.4.11, 0.4.12, 0.4.13, 0.4.14, 0.4.16, 0.4.17, 0.4.18, 0.4.19, 0.4.20, 0.4.21, 0.4.22, 0.4.23, 0.4.24, 0.4.25, 0.4.26, 0.4.27, 0.4.28)
ERROR: No matching distribution found for jaxlib==0.4.23+cuda11.cudnn86

@YoshitakaMo
Owner

I updated the installer and updater script for Linux two days ago, as JAX 0.4.23 no longer seems suitable for CUDA 12 and cuDNN 9. Please update your CUDA to 12.4 and cuDNN to 9, and use the latest updater script.
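
For example (assuming the updater script is still named update_linux.sh and takes the installation directory as its argument; please check the repository README if the invocation has changed):

$ wget https://raw.githubusercontent.com/YoshitakaMo/localcolabfold/main/update_linux.sh
$ bash update_linux.sh /path/to/your/localcolabfold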
