jax and jaxlib versions #20
Comments
Hi, we require jaxlib 0.1.69 in order to use CUDA unified memory when running long sequences. If you don't need that, you can probably run with 0.1.68, but the older version might be related to the illegal address error you are seeing. How long was the sequence you were trying to run? Some of the other open issues about CUDA versions might also help. |
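To see which side of that version boundary an environment falls on, it can help to compare the installed jaxlib version against 0.1.69. The sketch below is only an illustration: the helper names `parse_release`, `meets_minimum`, and `installed_version` are hypothetical, and it assumes version strings of the form `0.1.68+cuda110`, where the part after `+` is a local build tag that is ignored for the comparison.

```python
from importlib import metadata


def parse_release(version):
    """Turn '0.1.68+cuda110' into (0, 1, 68), dropping the local build tag."""
    release = version.split("+", 1)[0]
    return tuple(int(part) for part in release.split("."))


def meets_minimum(installed, minimum="0.1.69"):
    """Return True if the installed release is at least the minimum release."""
    return parse_release(installed) >= parse_release(minimum)


def installed_version(package):
    """Return the installed version string of a package, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    for name in ("jax", "jaxlib"):
        print(name, installed_version(name))
    jaxlib_version = installed_version("jaxlib")
    if jaxlib_version is not None and not meets_minimum(jaxlib_version):
        print("jaxlib is older than 0.1.69")
```

Running this inside both the Docker image and the conda environment gives a quick way to confirm whether the two really use the same jax and jaxlib builds.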
I was benchmarking with the CASP14 targets. T1026 (172 residues) raised the issue. For my system information,
|
That's a very small protein, so I'm surprised it's an issue. What GPU are you using? Is it possible to try using the Dockerfile? You could try disabling unified memory by commenting out these two lines in your script, if you have them: |
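The two lines referred to above were not preserved in this thread. As an assumption only: AlphaFold's Docker launcher enables unified memory by setting the environment variables `TF_FORCE_UNIFIED_MEMORY` and `XLA_PYTHON_CLIENT_MEM_FRACTION`, so a non-Docker script that copied those settings could disable the feature by leaving them unset. A minimal sketch of that idea, with a hypothetical helper name `without_unified_memory`:

```python
import os

# Assumption: these are the variables that switch on unified memory;
# the original two lines are not shown in the thread.
UNIFIED_MEMORY_VARS = ("TF_FORCE_UNIFIED_MEMORY", "XLA_PYTHON_CLIENT_MEM_FRACTION")


def without_unified_memory(env):
    """Return a copy of env with the unified-memory variables removed."""
    return {key: value for key, value in env.items() if key not in UNIFIED_MEMORY_VARS}


if __name__ == "__main__":
    clean_env = without_unified_memory(os.environ)
    # Pass clean_env to subprocess.run(..., env=clean_env) when launching inference,
    # so the settings are absent before JAX initializes in the child process.
    print(sorted(set(os.environ) - set(clean_env)))
```

Building a cleaned copy of the environment, rather than editing `os.environ` in place, keeps the change limited to the inference process that is launched with it.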
I tested with a Quadro RTX 6000 and an RTX 2080 Ti. There was no issue with the (4)... So there may be some differences between my non-Docker setup and the original Docker version... (I thought I had built my custom non-Docker setup with exactly the same library versions...) I will try it again. |
@huhlim Did you solve your CUDA_ERROR_ILLEGAL_ADDRESS problems? I just ran ~100 proteins from an internal sample, and this cropped up for me in some cases. As I investigate, it would be helpful if you could follow up here with anything you learned and/or how you resolved your problem. (I am using Docker on an A100.) |
@christroat I could not fully resolve the issue. When I turned off the jax.jit compilation of the models (in the initialization of the RunModel class in alphafold/model/model.py), the error became less frequent but did not go away. I have not had the issue on my Docker system, so I guess my problem is related to our cluster setup... Unfortunately, I gave up on tackling the issue. |
TL;DR: Is it okay to use jaxlib 0.1.68+cuda110 instead of 0.1.69+cuda110?
I tried to write a script and build a conda environment that does not use Docker. When I used the same versions of jax and jaxlib defined in docker/Dockerfile, I ran into issues at inference time. The scripts worked fine for model_{1,3,4} but raised CUDA_ERROR_ILLEGAL_ADDRESS errors for model_{2,5}. I have no idea why this happened...
So I tested many variants of the environment and found that jax=0.2.17 (probably the same version as the original) and jaxlib=0.1.68+cuda110 (the version that gets installed with the command
pip3 install jax[cuda110] -f https://storage.googleapis.com/jax-releases/jax_releases.html
) run smoothly without Docker in my custom conda environment.