Installing JAX across MacOS and Linux. #6956

Open
tbenthompson opened this issue Nov 3, 2022 · 14 comments
Labels
kind/bug Something isn't working as expected status/triage This issue needs to be triaged

Comments

@tbenthompson

The markers to specify separate versions/sources for different platforms are not working as I would've expected.

jaxlib

With jaxlib, I was able to resolve the problem via a hacky workaround.

The simple version did not work:

jaxlib = [
    {version = "^0.3.22", platform = "darwin"},
    {version = "^0.3.22+cuda11.cudnn82", platform = "linux"}
]

With this dependency spec, poetry v1.2.2 will install the +cuda version of the package regardless of the platform.

This seems to be an example of the issue discussed in #6710 and fixed in python-poetry/poetry-core#497

I worked around the issue with a spec like this:

jaxlib = [
    {version = "0.3.20", platform = "darwin", source="pypi"},
    {version = "^0.3.22+cuda11.cudnn82", platform = "linux"}
]

Since the version numbers are now different, the darwin install works correctly. The source="pypi" was also necessary.

jax

With jax, I'm unable to get anything to work in a cross-platform way. As one particular example, I pinned versions to try to get something like the jaxlib hack to work:

jax = [
    {extras = ["cpu"], version = "0.3.21", platform="darwin", source="pypi"},
    {extras = ["cuda"], version = "0.3.22", platform="linux"}
]

In this case, the resolver seems to decide that I want to depend on both versions:

Because jax (0.3.22) depends on jaxlib (0.3.22+cuda11.cudnn82)
 and jax (0.3.21) depends on jaxlib (0.3.20), jax (0.3.22) is incompatible with jax (0.3.21).
So, because confirm depends on both jax (0.3.21) and jax (0.3.22), version solving failed.

details

I previously posted #6955, and in a comment there I thought I had fixed this issue, but that was just from looking at the poetry.lock file. Actual testing showed I was wrong.

Versions:

  • Poetry version: 1.2.2
  • Python version: 3.10.6
  • OS version and name: Ubuntu 22.04
  • pyproject.toml:
[tool.isort]
profile = "black"

[tool.poetry]
name = "confirm"
version = "0.1.0"
description = ""
authors = ["Your Name <[email protected]>"]

[tool.poetry.dependencies]
python = "~3.10"
numpy = "^1.23.4"
scipy = "^1.9.3"
typer = "^0.6.1"
jax = [
    {extras = ["cpu"], version = "0.3.21", platform="darwin", source="pypi"},
    {extras = ["cuda"], version = "0.3.22", platform="linux"}
]

[tool.poetry.group.test.dependencies]
pytest = "^7.2.0"
pre-commit = "^2.20.0"

[[tool.poetry.source]]
name = "jax"
url = "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
default = false
secondary = false

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
@dimbleby
Contributor

dimbleby commented Nov 3, 2022

seems like another duplicate of #5506

@tbenthompson
Author

I'm not the expert here but I think there are issues here not captured by #5506. I'm sure fixing #5506 would improve the situation here. But, the example pyproject.toml I shared at the end of the issue causes failures but only mentions the jax package and has no explicit reference to jaxlib.

@dimbleby
Contributor

dimbleby commented Nov 3, 2022

I'd rate it highly likely that these are in fact duplicates

However so far as I know no-one has really dug into #5506: until they do and the issues are more clearly understood, perhaps we won't be absolutely certain.

@jorenham

I have a workaround for exactly this :)
https://github.com/jorenham/jax_pep503

@tbenthompson
Author

tbenthompson commented Mar 1, 2023

I have a workaround for exactly this :)
https://github.com/jorenham/jax_pep503

Thanks!

For other folks coming by, what jorenham has done is super useful for having a PyPI-compatible JAX package repository, but it doesn't solve the issue here because there aren't non-CUDA/Mac builds of jaxlib available in his package repo. I'm glad to see progress here!

UPDATE: This should be fixed as part of the bug discussed here: jorenham/jax_pep503#3

@tbenthompson
Author

tbenthompson commented Mar 2, 2023

Oh well, I was hopeful for @jorenham's solution, but it still turns out to be thwarted by problems with poetry 1.3.2 dependency handling. Perhaps the simplest reproducing example is the following:

[tool.poetry]
name = "example"
version = "0.1.0"
description = ""
authors = ['tbenthompson']

[tool.poetry.dependencies]
python = "^3.10"
jaxlib = [
    { version = "0.4.4+cuda11.cudnn86", platform = "linux", source="jorenham/jax_pep503" },
    { version = "0.4.4", platform = "darwin", source = "jorenham/jax_pep503" }
]

[[tool.poetry.source]]
name = "jorenham/jax_pep503"
url = "https://jorenham.github.io/jax_pep503/"
secondary = true

which fails when run on Mac with:

❯ poetry --version
Poetry (version 1.3.2)

❯ poetry lock
Updating dependencies
Resolving dependencies... (0.5s)

Writing lock file

❯ poetry install
Installing dependencies from lock file

Package operations: 2 installs, 0 updates, 0 removals

  • Installing scipy (1.9.3)
  • Installing jaxlib (0.4.4+cuda11.cudnn86): Failed

  RuntimeError

  Unable to find installation candidates for jaxlib (0.4.4+cuda11.cudnn86)

In case having a poetry.lock file is also useful, here is the relevant subset (I left out the numpy/scipy portion):

[[package]]
name = "jaxlib"
version = "0.4.4+cuda11.cudnn86"
description = "XLA library for JAX"
category = "main"
optional = false
python-versions = ">=3.8"
files = [
    {file = "jaxlib-0.4.4+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:2248ce9c99bdc58a7ca634047bc90bddff227a58224f0a9f7fddc736f9dcec52"},
    {file = "jaxlib-0.4.4+cuda11.cudnn86-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:2801b02db6d801b1c1abaebb9d5dbc10d483b6834dbf3270f126efdbc5e834fd"},
    {file = "jaxlib-0.4.4+cuda11.cudnn86-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:b2294328da65de22ba4228f735b387151e4d4bfb840214a37eec5a203a9ce5c1"},
    {file = "jaxlib-0.4.4+cuda11.cudnn86-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:7cc659c7fe8110917715ae0458d3f7939080d21ac94e696ac5d5015a65cfb7be"},
]

[package.dependencies]
numpy = ">=1.20"
scipy = ">=1.5"

[package.source]
type = "legacy"
url = "https://jorenham.github.io/jax_pep503"
reference = "jorenham/jax_pep503"

As you can see above, only wheels with the 0.4.4+cuda11.cudnn86 version are being included in poetry.lock.
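
For anyone who wants to check this mechanically, here is a rough sketch that reads the platform tag off each locked wheel filename. It is not part of Poetry: `platform_tags` is a hypothetical helper, and it assumes the usual wheel naming scheme in which the platform tag is the last dash-separated field before `.whl`.

```python
# Hypothetical helper (not part of Poetry): collect the platform tags of the
# wheels listed in the poetry.lock excerpt above.
LOCKED_WHEELS = [
    "jaxlib-0.4.4+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl",
    "jaxlib-0.4.4+cuda11.cudnn86-cp311-cp311-manylinux2014_x86_64.whl",
    "jaxlib-0.4.4+cuda11.cudnn86-cp38-cp38-manylinux2014_x86_64.whl",
    "jaxlib-0.4.4+cuda11.cudnn86-cp39-cp39-manylinux2014_x86_64.whl",
]


def platform_tags(filenames):
    """Return the set of platform tags found in a list of wheel filenames."""
    tags = set()
    for filename in filenames:
        # Drop the ".whl" extension, then take the last "-"-separated field.
        stem = filename.rsplit(".whl", 1)[0]
        tags.add(stem.rsplit("-", 1)[1])
    return tags


tags = platform_tags(LOCKED_WHEELS)
# True only if at least one locked wheel targets macOS.
has_macos_wheel = any(tag.startswith("macosx") for tag in tags)
print(tags, has_macos_wheel)
```

Every locked file carries a manylinux tag, so there is nothing in the lock file that an install on Mac could pick up.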

@radoering
Member

radoering commented Mar 2, 2023

What if you explicitly exclude the versions you don't want, i.e. replacing version = "0.4.4" with version = "==0.4.4,!=0.4.4+cuda11.cudnn86"?

@tbenthompson
Author

What if you explicitly exclude the versions you don't want, i.e. replacing version = "0.4.4" with version = "==0.4.4,!=0.4.4+cuda11.cudnn86"?

Nice idea, thanks! The resulting poetry lock is unchanged. But, it does make the poetry install error on Mac more sensible:

❯ poetry install 
Installing dependencies from lock file

Because example depends on jaxlib (0.4.4,!=0.4.4+cuda11.cudnn86) which doesn't match any versions, version solving failed.
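
One plausible reading of this error, sketched in Python below: if the install step only considers the versions recorded in poetry.lock, the exclusion leaves nothing to choose from. This is an illustration under that assumption, not Poetry's actual code; `public_part` and `allowed` are hypothetical names, and the constraint is hard-coded rather than parsed.

```python
def public_part(version):
    """Strip the local version label (the part after '+'), if any."""
    return version.split("+", 1)[0]


def allowed(candidate):
    """Hand-written check for '==0.4.4,!=0.4.4+cuda11.cudnn86'."""
    # "==0.4.4" has no local label, so only the public part is compared.
    matches_equal = public_part(candidate) == "0.4.4"
    # "!=0.4.4+cuda11.cudnn86" rules out that exact version.
    not_excluded = candidate != "0.4.4+cuda11.cudnn86"
    return matches_equal and not_excluded


# Assumption: the lock file above contains only the CUDA build.
locked_versions = ["0.4.4+cuda11.cudnn86"]
from_lock = [v for v in locked_versions if allowed(v)]

# For comparison: what would match if the plain build had also been locked.
both_builds = ["0.4.4", "0.4.4+cuda11.cudnn86"]
from_both = [v for v in both_builds if allowed(v)]

print(from_lock, from_both)
```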

@jorenham

jorenham commented Mar 2, 2023

I'm guessing here, but could it be that the version resolving code only considers the repository in which the first version requirement is found? I can imagine that there might be an implicit "one package, one repo" assumption.

E.g. with "==0.4.4,!=0.4.4+cuda11.cudnn86" it first resolves 0.4.4 to be in the main pypi repo, then looks for 0.4.4+cuda11.cudnn86 in the main repo, fails to find it there, but it doesn't consider the secondary repo.

@dimbleby
Contributor

comments from #6956 (comment) onwards are quite different from the original report

(and much simpler! for those the issue is that so far as satisfying requirements goes, 0.4.4+cuda11.cudnn86 is a perfectly good solution for both version = "0.4.4+cuda11.cudnn86" and also version = "0.4.4". So poetry is within its rights to choose only that. The underlying problem is that local versions just aren't a good way to express this sort of thing: jax are burdening the version with more meaning than it is capable of carrying, and the cracks are showing)
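
The matching rule behind this can be sketched in a few lines of Python. This is a deliberately simplified illustration of the PEP 440 rule that an exact-version specifier without a local label ignores the local label of candidates. It is not Poetry's resolver, `split_local` and `matches_exact` are hypothetical names, and it skips the normalization and zero-padding that real implementations perform.

```python
def split_local(version):
    """Split 'public+local' into (public, local); local is '' if absent."""
    public, _, local = version.partition("+")
    return public, local


def matches_exact(specified, candidate):
    """Simplified PEP 440 '==' matching with local version labels."""
    spec_public, spec_local = split_local(specified)
    cand_public, cand_local = split_local(candidate)
    if spec_public != cand_public:
        return False
    if not spec_local:
        # No local label in the specifier: the candidate's label is ignored.
        return True
    # The specifier names a local label: it has to match exactly.
    return spec_local == cand_local


constraints = ["0.4.4+cuda11.cudnn86", "0.4.4"]
candidates = ["0.4.4", "0.4.4+cuda11.cudnn86"]

# Candidates that satisfy every constraint at once.
satisfies_all = [
    c for c in candidates if all(matches_exact(s, c) for s in constraints)
]
print(satisfies_all)
```

Only the CUDA build satisfies both constraints, which is why a solver can legitimately settle on that single version.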

@tbenthompson
Author

@dimbleby Out of curiosity, how should JAX handle this CUDA/no-CUDA situation in its releases? I am happy to file an issue with JAX.

@tbenthompson
Author

It also seems feasible for Poetry to add something like a “precise” flag that ensures that version = "0.4.4+cuda11.cudnn86" is not used to satisfy a version = "0.4.4", precise = True dependency. Perhaps something like that already exists?

@dimbleby
Contributor

#7256 asking for "arbitrary equality" is close. However this too would be a stretch of the intended use: per the specs, arbitrary equality is supposed to be for handling projects with non-compliant versions - not for handling projects with versions that are compliant, but which they are overloading with extra semantics.

the CUDA / non-CUDA thing seems to be difficult for several projects; see the never-ending #6409 for another example. It may be that jax / torch etc. could do better if they thought hard and came up with some suitable scheme using extras. Or the best thing might be to cut the Gordian knot: publish separate packages jax-cpu and jax-cuda or whatever. Or conceivably a new PEP is needed: if they really have a situation that doesn't fit well with current standards, they should engage with the Python Packaging Authority to improve the standards.

@NeilGirdhar
Contributor

@tbenthompson If you do file an issue, could you link it here so that I can watch it? Also, it may be worth linking jax-ml/jax#5410 so that both issues can be handled at the same time?
