Skip to content

Commit

Permalink
Use libtpu releases rather than libtpu-nightly for jax[tpu].
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 688632409
  • Loading branch information
hawkinsp authored and Google-ML-Automation committed Oct 22, 2024
1 parent 1c6b0a9 commit e4f3f8f
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 6 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/cloud-tpu-ci-nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,23 +50,23 @@ jobs:
pip install -U -r build/collect-profile-requirements.txt
- name: Install JAX
run: |
pip uninstall -y jax jaxlib libtpu-nightly
pip uninstall -y jax jaxlib libtpu
if [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then
pip install .[tpu] \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then
pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
pip install --pre libtpu-nightly \
pip install --pre libtpu \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install requests
elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then
pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
# TODO(phawkins): switch to libtpu, when the oldest release we support is a libtpu release.
pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install requests
else
echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}"
exit 1
Expand Down
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* `jax.lax.FftType` was introduced as a public name for the enum of FFT
operations. The semi-public API `jax.lib.xla_client.FftType` has been
deprecated.
* TPU: JAX now installs TPU support from the `libtpu` package rather than
`libtpu-nightly`. For the next few releases JAX will pin an empty version of
`libtpu-nightly` as well as `libtpu` to ease the transition; that dependency
will be removed in Q1 2025.

* Deprecations:
* The semi-public API `jax.lib.xla_client.PaddingType` has been deprecated.
Expand Down
2 changes: 1 addition & 1 deletion docs/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ pip install -U --pre jax jaxlib -f https://storage.googleapis.com/jax-releases/j
- Google Cloud TPU:

```bash
pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install -U --pre jax jaxlib libtpu requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```

- NVIDIA GPU (CUDA 12):
Expand Down
8 changes: 6 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
_current_jaxlib_version = '0.4.34'
# The following should be updated after each new jaxlib release.
_latest_jaxlib_version_on_pypi = '0.4.34'
_libtpu_version = '0.1.dev20241002'

_libtpu_version = '0.0.2'
_libtpu_nightly_terminal_version = '0.1.dev20241010+nightly.cleanup'

def load_version_module(pkg_path):
spec = importlib.util.spec_from_file_location(
Expand Down Expand Up @@ -76,7 +78,9 @@ def load_version_module(pkg_path):
# $ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
'tpu': [
f'jaxlib>={_current_jaxlib_version},<={_jax_version}',
f'libtpu-nightly=={_libtpu_version}',
# TODO(phawkins): remove the libtpu-nightly dependency in Q1 2025.
f'libtpu-nightly=={_libtpu_nightly_terminal_version}',
f'libtpu=={_libtpu_version}',
'requests', # necessary for jax.distributed.initialize
],

Expand Down

0 comments on commit e4f3f8f

Please sign in to comment.