
AttributeError: module 'jaxlib.xla_extension' has no attribute 'PmapFunction' #10717

Closed
PRISHIta123 opened this issue May 15, 2022 · 8 comments · Fixed by #10718
Labels
bug Something isn't working


@PRISHIta123

PRISHIta123 commented May 15, 2022

I am getting the above error after installing the latest version of jax and trying to import it for some computations:

```
!pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
```

```
/usr/local/lib/python3.7/dist-packages/jax/_src/lax/linalg.py in <module>()
    977       lu_pivots_to_permutation_p,
    978       partial(_lu_pivots_to_permutation_gpu_lowering,
--> 979               gpu_linalg.cuda_lu_pivots_to_permutation),
    980       platform='cuda')
    981   mlir.register_lowering(

AttributeError: module 'jaxlib.cuda_linalg' has no attribute 'cuda_lu_pivots_to_permutation'
```

It would be helpful if someone could look into this as soon as possible. This error first appeared when I ran my code within the past hour; the same code worked without problems earlier.
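The traceback shows jax's Python code looking up `cuda_lu_pivots_to_permutation` on the `jaxlib.cuda_linalg` extension module, and the installed jaxlib build not defining it. That is a version mismatch between the two packages. A minimal sketch for probing for this kind of gap before running real work, assuming a hypothetical helper named `missing_attributes` (not part of jax or jaxlib):

```python
import importlib
from types import ModuleType
from typing import Iterable, List


def missing_attributes(module_name: str, names: Iterable[str]) -> List[str]:
    """Return the subset of `names` that the given module does not define.

    Hypothetical helper. Raises ImportError (including ModuleNotFoundError)
    if the module itself cannot be imported.
    """
    module: ModuleType = importlib.import_module(module_name)
    return [name for name in names if not hasattr(module, name)]


if __name__ == "__main__":
    # Module and attribute names are taken from the traceback in this issue;
    # the import only succeeds where a CUDA build of jaxlib is installed.
    try:
        gaps = missing_attributes("jaxlib.cuda_linalg", ["cuda_lu_pivots_to_permutation"])
        print("missing:", gaps or "none")
    except ImportError as exc:
        print("could not import jaxlib.cuda_linalg:", exc)
```

An empty list means the installed jaxlib exposes every name checked; any entry in the list is a symbol the installed jax expects but the installed jaxlib lacks.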

@PRISHIta123 PRISHIta123 added the bug Something isn't working label May 15, 2022
@PRISHIta123
Author

@yashk2810 please check your latest commit

@yashk2810
Collaborator

Which version were you running with previously?

Any latest commit would not affect you since you are installing an already built version.
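Answering "which version were you running" requires knowing what is actually installed. A minimal sketch that reports the installed `jax` and `jaxlib` distribution versions using the standard-library `importlib.metadata` (Python 3.8+). The traceback shows Python 3.7, where the `importlib_metadata` backport from PyPI is assumed instead. The helper name `installed_versions` is hypothetical:

```python
from typing import Dict, Optional, Sequence

try:
    from importlib.metadata import PackageNotFoundError, version  # Python 3.8+
except ImportError:  # Python 3.7: assumes `pip install importlib-metadata`
    from importlib_metadata import PackageNotFoundError, version  # type: ignore


def installed_versions(packages: Sequence[str] = ("jax", "jaxlib")) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version string, or None if absent."""
    result: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```

Reading the metadata this way does not import jax itself, so it still works when `import jax` fails with an error like the one above.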

@PRISHIta123
Author

PRISHIta123 commented May 15, 2022

> Which version were you running with previously?
>
> Any latest commit would not affect you since you are installing an already built version.

This fetches the latest version, right?

```
!pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
```

How do I specify the previous version while installing?

@PRISHIta123
Author

PRISHIta123 commented May 15, 2022

I downgraded using this command, which I found on Stack Overflow:

```
!pip install --upgrade jax==0.3.10 jaxlib==0.3.10+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_releases.html
```
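The downgrade command pins both packages. `jaxlib` carries a local version label (`+cuda11.cudnn82`) that selects the CUDA/cuDNN build from the find-links page. A minimal sketch that assembles the same command as an argument list for `subprocess`, assuming a hypothetical helper `pinned_jax_install_args`. jax and jaxlib version numbers are taken separately because they are not guaranteed to match:

```python
import shlex
import sys
from typing import List

# Find-links URL copied from the install commands in this thread.
JAX_RELEASES_URL = "https://storage.googleapis.com/jax-releases/jax_releases.html"


def pinned_jax_install_args(jax_version: str, jaxlib_version: str, local_tag: str) -> List[str]:
    """Build a `python -m pip install` argument list pinning jax and a jaxlib build.

    Hypothetical helper. `local_tag` is the part after '+' in the jaxlib
    version, e.g. "cuda11.cudnn82" as in the downgrade command above.
    """
    return [
        sys.executable, "-m", "pip", "install", "--upgrade",
        f"jax=={jax_version}",
        f"jaxlib=={jaxlib_version}+{local_tag}",
        "-f", JAX_RELEASES_URL,
    ]


if __name__ == "__main__":
    # Only prints the command; pass the list to subprocess.check_call to run it.
    args = pinned_jax_install_args("0.3.10", "0.3.10", "cuda11.cudnn82")
    print(" ".join(shlex.quote(a) for a in args))
```

Using `sys.executable -m pip` installs into the interpreter that is running the script, which avoids accidentally upgrading a different Python environment.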

It would be helpful if this were mentioned in the README. With only the general installation command provided there, it is difficult to revert to an older version when a new release comes out and some functionality suddenly stops working.

@yashk2810
Collaborator

We'll release again with the fix! Sorry for the breakage.

@PRISHIta123
Author

No worries!

@yashk2810
Collaborator

We released jax 0.3.12: https://pypi.org/project/jax/0.3.12/ with the fix.
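Per the comment above, the fix ships in jax 0.3.12. A minimal sketch for checking whether an installed version is at least that release, assuming hypothetical helpers `release_tuple` and `has_fix`. The parsing is deliberately simplified (it drops any `+local` suffix and stops at the first non-numeric component); `packaging.version` handles full PEP 440 versions:

```python
from typing import List, Tuple


def release_tuple(version_string: str) -> Tuple[int, ...]:
    """Parse the numeric release part of a version such as '0.3.10+cuda11.cudnn82'."""
    public = version_string.split("+", 1)[0]  # drop the local version label
    parts: List[int] = []
    for piece in public.split("."):
        if not piece.isdigit():  # stop at pre-release or other non-numeric parts
            break
        parts.append(int(piece))
    return tuple(parts)


def has_fix(installed: str, fixed_in: str = "0.3.12") -> bool:
    """True if `installed` is at or beyond the release that contains the fix."""
    return release_tuple(installed) >= release_tuple(fixed_in)
```

For example, `has_fix(jax.__version__)` after `import jax` reports whether the running jax already includes the fix.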

@PRISHIta123
Author

Ok, thanks for addressing this.
