This looks like an issue with your JAX installation. Can you try upgrading to the most recent version with: pip install -U "jax[cuda12]"
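After upgrading, one way to confirm the GPU install works is to compute the same determinant on the CPU and on the GPU and compare the results. The sketch below assumes nothing about the original reproducer: the 2x2 matrix is an arbitrary example, and `det_on_each_backend` is a hypothetical helper. It uses only `jax.devices`, `jax.device_put` and `jnp.linalg.det`, and it skips the GPU if JAX cannot see one.

```python
import numpy as np
import jax
import jax.numpy as jnp


def det_on_each_backend(a: np.ndarray) -> dict:
    """Hypothetical helper: compute det(a) on the CPU and, if present, the first GPU."""
    results = {}
    for platform in ("cpu", "gpu"):
        try:
            device = jax.devices(platform)[0]
        except RuntimeError:
            # JAX raises RuntimeError when no device of this platform is available.
            continue
        # Committing the input to a device makes the computation run on that device.
        x = jax.device_put(a, device)
        results[platform] = float(jnp.linalg.det(x))
    return results


if __name__ == "__main__":
    print("jax", jax.__version__, "devices:", jax.devices())
    # Arbitrary example matrix (not from the original report); det = 4*3 - 3*6 = -6.
    a = np.array([[4.0, 3.0], [6.0, 3.0]], dtype=np.float32)
    print(det_on_each_backend(a))  # every entry should be close to -6.0
```

If the installation is still broken, the GPU call should fail with the same error as before, while the CPU entry stays correct.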
Calling jnp.linalg.det(array) raises an error when it runs on a CUDA device. When the same code runs on the CPU instead, the determinant is computed without any error.
The following code reproduces the issue:
JAX and jaxlib version: 0.4.28
CUDA driver version: 12.4
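The version details above, plus the devices JAX actually detects, can be gathered programmatically for a bug report. A minimal sketch: `environment_summary` is a hypothetical helper, while `jax.print_environment_info()` is JAX's built-in report, which also includes `nvidia-smi` output when it is available.

```python
import jax
import jaxlib.version


def environment_summary() -> dict:
    """Hypothetical helper: collect the details a CUDA-related bug report usually needs."""
    return {
        "jax": jax.__version__,
        "jaxlib": jaxlib.version.__version__,
        "default_backend": jax.default_backend(),  # e.g. "cpu" or "gpu"
        "devices": [str(d) for d in jax.devices()],
    }


if __name__ == "__main__":
    for key, value in environment_summary().items():
        print(f"{key}: {value}")
    # JAX's own, more complete environment report.
    jax.print_environment_info()
```

If `default_backend` reports `cpu` on a machine with a GPU, JAX is not picking up its CUDA plugin, which points to the installation rather than to `jnp.linalg.det` itself.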