jax and jaxlib versions #20

Closed
huhlim opened this issue Jul 17, 2021 · 6 comments

Comments

@huhlim

huhlim commented Jul 17, 2021

TL;DR: is it okay to use jaxlib 0.1.68+cuda110 instead of 0.1.69+cuda110?

I have been trying to write a script and build a conda environment that does not use Docker. When I used the same jax and jaxlib versions defined in docker/Dockerfile, I ran into issues at inference time: the scripts worked fine for model_{1,3,4} but raised CUDA_ERROR_ILLEGAL_ADDRESS errors for model_{2,5}. I have no idea why this happened.
So I tested many variants of the environment and found that jax==0.2.17 (probably the same version as the original) and jaxlib==0.1.68+cuda110 (the version installed by pip3 install jax[cuda110] -f https://storage.googleapis.com/jax-releases/jax_releases.html) run smoothly without Docker in my custom conda environment.

@tfgg
Collaborator

tfgg commented Jul 19, 2021

Hi, we require jaxlib version 0.1.69 to be able to use CUDA unified memory for running long sequences. If you don't need this you can probably run with 0.1.68, but that might be related to the illegal address error you're seeing. How long was the sequence you were trying to run?

Some of the other open issues about CUDA versions might also be of help.
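As a quick sanity check, it can also help to confirm which jax/jaxlib versions and which devices your environment actually picks up; a minimal snippet using plain JAX calls (nothing AlphaFold-specific):

```python
import jax
import jaxlib

# Report the versions actually installed in the environment,
# plus the accelerators JAX can see.
print("jax:", jax.__version__)
print("jaxlib:", jaxlib.__version__)
print("devices:", jax.devices())
```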

@huhlim
Author

huhlim commented Jul 19, 2021

I was benchmarking with the CASP14 targets. T1026 (172 residues) raised the issue.
I realized that some of the targets still hit CUDA_ERROR_ILLEGAL_ADDRESS, even though I used jax==0.2.17 and jaxlib==0.1.68+cuda110. Those targets ran fine on CPUs.

My system information:

  • NVIDIA driver: 450.36.06
  • CUDA version: 11.0
  • jax: 0.2.17
  • jaxlib: 0.1.68+cuda110
  • tensorflow: 2.5.0
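For the CPU comparison mentioned above, here is a minimal sketch of forcing JAX onto the CPU; this relies on the standard JAX_PLATFORM_NAME environment variable and is not AlphaFold-specific:

```python
import os

# Must be set before jax is imported, otherwise the GPU backend may already be initialized.
os.environ["JAX_PLATFORM_NAME"] = "cpu"

import jax
print(jax.devices())  # should now list only CPU devices
```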

@tfgg
Collaborator

tfgg commented Jul 19, 2021

That's a very small protein, so I'm surprised it's an issue. What GPU are you using? Is it possible to try using the Dockerfile?

You could try disabling unified memory by commenting out these two lines in your script, if you have them:
https://github.com/deepmind/alphafold/blob/main/docker/run_docker.py#L171-L172
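For reference, those two lines presumably set the unified-memory environment variables; a minimal sketch of the equivalent outside Docker (variable names assumed from the run_docker.py link above):

```python
import os

# Enable CUDA unified memory so XLA can oversubscribe GPU memory for long sequences;
# these must be set before JAX initializes its GPU backend.
os.environ["TF_FORCE_UNIFIED_MEMORY"] = "1"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "4.0"

# To disable unified memory, leave both variables unset
# (i.e. comment out the two assignments above).
```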

@huhlim
Author

huhlim commented Jul 19, 2021

I tested with Quadro RTX 6000 and RTX 2080Ti.
I have tested the following configurations:
(1) jaxlib==0.1.68+cuda110, jax==0.2.17, cudatoolkit=11.0.3 for my custom non-Docker version
(2) jaxlib==0.1.69+cuda110, jax==0.2.17, cudatoolkit=11.0.3 for my custom non-Docker version
(3) (1) or (2) + commenting out the two lines for the unified memory
(4) the same as (2), but with a docker container (the original one)

There was no issue with (4)... so there may be some difference between my non-Docker setup and the original Docker one (I thought I had built my custom non-Docker environment with exactly the same library versions). I will try it again.

@chrisroat

@huhlim Did you solve your CUDA_ERROR_ILLEGAL_ADDRESS problems? I just ran ~100 proteins from an internal sample, and this cropped up for me in some cases. As I investigate, it would be helpful if you could follow up here with anything you learned and/or how you resolved your problem. (I am using Docker on an A100.)

@huhlim
Author

huhlim commented Aug 4, 2021

@chrisroat I could not fully resolve the issue. When I turned off the jax.jit compilation of the models (in the initialization of the RunModel class in alphafold/model/model.py), it reduced the chance of the error but did not eliminate it. I have not had the issue with my Docker setup, so I guess my problem is related to our cluster setup... Unfortunately, I gave up on tackling the issue.
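For context, a minimal, self-contained sketch of the kind of jit toggle I mean; the forward function below is only a placeholder, not the actual AlphaFold network from alphafold/model/model.py:

```python
import jax
import jax.numpy as jnp

# Placeholder forward pass standing in for the real model's apply function.
def forward_fn(x):
  return jnp.tanh(x).sum()

use_jit = False  # False = skip XLA compilation, as in the experiment described above
apply_fn = jax.jit(forward_fn) if use_jit else forward_fn

print(apply_fn(jnp.ones((4,))))
```

(jax.disable_jit() used as a context manager is another way to run everything un-jitted without editing the model code.)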
