figure out a better run constraint for jax #197
This happened to me when I was inadvertently using an old Python (3.5, I think). The jax version was way ahead of the jaxlib version, but I couldn't update to a newer jaxlib. The install completed; the problem only showed up at runtime. When I set up another environment with Python 3.10 and started over, the problem went away.
We're having this problem in the CI pipelines of pymc-devs/pymc. It's very annoying, because we don't really want to pin exact jax/jaxlib versions.
I think we should have jax pinned to jaxlib at the same version, or maybe allow it to be at most 2 minor versions ahead. Any thoughts, @conda-forge/jax @conda-forge/jaxlib?
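A minimal sketch of the proposed "same version or at most 2 minor versions ahead" rule, assuming plain numeric `major.minor.patch` strings (pre-releases such as `rc` tags are not handled) and reading "minor" as the second component. The function names and the `max_minor_ahead` parameter are hypothetical, not part of any conda-forge tooling:

```python
def parse_version(v):
    """Split a plain numeric version like '0.3.10' into (0, 3, 10)."""
    return tuple(int(part) for part in v.split("."))


def jax_within_window(jax_version, jaxlib_version, max_minor_ahead=2):
    """Hypothetical check: jax must not be older than jaxlib, must share its
    major version, and may be at most `max_minor_ahead` minor versions ahead."""
    jax = parse_version(jax_version)
    jaxlib = parse_version(jaxlib_version)
    if jax < jaxlib:
        return False  # keeps the existing >= lower bound
    if jax[0] != jaxlib[0]:
        return False  # any major-version jump is rejected
    return jax[1] - jaxlib[1] <= max_minor_ahead


print(jax_within_window("0.3.10", "0.3.10"))  # True: same version
print(jax_within_window("0.5.0", "0.3.10"))   # True: two minor versions ahead
print(jax_within_window("0.6.0", "0.3.10"))   # False: three minor versions ahead
```

Setting `max_minor_ahead=0` reduces this to a same-`major.minor` pin. Comparing integer tuples rather than raw strings keeps `0.3.10` correctly ordered after `0.3.9`.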
Upstream JAX maintainer here: there's probably no harm in pinning the most recent jaxlib as a dependency of each jax release. It's possible we'll soon do that for the pip packages too. The main reason we haven't done so yet is that there are multiple kinds of … However, since you already have a method for handling CUDA/non-CUDA variants in the conda-forge build, I would imagine you can just add a hard constraint right now.
Thank you!
We usually track the pin in the jax feedstock. A PR that should hopefully catch problems going forward was just merged; see conda-forge/jax-feedstock#130. @beckermr, let me know if that PR addresses your concerns or if we should come up with a better solution.
That's great. It requires manual updates, though, which sucks.
But it should (in theory) fail if not updated, so it's easier to spot... maybe?
I think we could put jaxlib in host and then pin the run requirement to be greater than or equal to that version. That would work like a compiler, where you always have to be at or after the version used for the build. If we see errors, maybe we revisit.
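The compiler-style rule above can be modeled in a few lines. This is only a sketch of the logic, not conda-build's actual mechanism (in a real recipe it would be expressed through the feedstock's `host`/`run` requirements). The helpers `run_pin_from_host` and `satisfies` are hypothetical, and versions are assumed to be plain numeric `X.Y.Z` strings:

```python
def parse_version(v):
    """Split a plain numeric version like '0.3.10' into (0, 3, 10)."""
    return tuple(int(part) for part in v.split("."))


def run_pin_from_host(host_jaxlib_version):
    """Hypothetical helper: turn the jaxlib version resolved in `host` at build
    time into a lower-bound run requirement, the way a compiler's runtime
    library must be at least as new as the one used for the build."""
    return "jaxlib >=" + host_jaxlib_version


def satisfies(installed_version, requirement):
    """Check an installed jaxlib version against a 'jaxlib >=X.Y.Z' string.

    Only the single lower-bound form produced by run_pin_from_host is handled."""
    _name, bound = requirement.split()
    if not bound.startswith(">="):
        raise ValueError("only '>=' requirements are supported: " + requirement)
    return parse_version(installed_version) >= parse_version(bound[2:])


req = run_pin_from_host("0.3.10")
print(req)                       # jaxlib >=0.3.10
print(satisfies("0.3.14", req))  # True: newer than the build-time jaxlib
print(satisfies("0.3.7", req))   # False: older than the build-time jaxlib
```

Unlike a manually maintained pin, the lower bound here follows whatever jaxlib was present at build time, so it cannot silently go stale.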
Right now the jax run constraint is simply `>={{ version }}`. I've gotten reports of installation issues when jax gets too far ahead of jaxlib. Maybe we should pin to `x.x` or something?
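To see what an `x.x` pin would change, here is a small comparison against the current open-ended `>=` constraint, assuming `x.x` means "same major.minor and no older" (as with conda-build's `max_pin='x.x'`) and plain numeric versions. The function names are hypothetical:

```python
def parse_version(v):
    """Split a plain numeric version like '0.3.10' into (0, 3, 10)."""
    return tuple(int(part) for part in v.split("."))


def allowed_by_lower_bound(jax_version, jaxlib_version):
    """Current behaviour: jax >= jaxlib's version, with no upper bound."""
    return parse_version(jax_version) >= parse_version(jaxlib_version)


def allowed_by_xx_pin(jax_version, jaxlib_version):
    """Sketch of an x.x pin: jax must be >= jaxlib and share its major.minor."""
    jax = parse_version(jax_version)
    jaxlib = parse_version(jaxlib_version)
    return jax >= jaxlib and jax[:2] == jaxlib[:2]


for candidate in ["0.3.10", "0.3.25", "0.4.1"]:
    print(candidate,
          allowed_by_lower_bound(candidate, "0.3.10"),
          allowed_by_xx_pin(candidate, "0.3.10"))
# 0.3.10 True True
# 0.3.25 True True
# 0.4.1 True False
```

The open-ended bound lets the solver pair a much newer jax with an old jaxlib, which matches the installation reports above; the `x.x` pin rejects that pairing at solve time instead.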