From e4f3f8f064228c0739b3333a9e78ac715cb34273 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 22 Oct 2024 11:45:58 -0700 Subject: [PATCH] Use libtpu releases rather than libtpu-nightly for jax[tpu]. PiperOrigin-RevId: 688632409 --- .github/workflows/cloud-tpu-ci-nightly.yml | 6 +++--- CHANGELOG.md | 4 ++++ docs/installation.md | 2 +- setup.py | 8 ++++++-- 4 files changed, 14 insertions(+), 6 deletions(-) diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml index d617178254a4..4bff1e87e7f3 100644 --- a/.github/workflows/cloud-tpu-ci-nightly.yml +++ b/.github/workflows/cloud-tpu-ci-nightly.yml @@ -50,23 +50,23 @@ jobs: pip install -U -r build/collect-profile-requirements.txt - name: Install JAX run: | - pip uninstall -y jax jaxlib libtpu-nightly + pip uninstall -y jax jaxlib libtpu if [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then pip install .[tpu] \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html - pip install --pre libtpu-nightly \ + pip install --pre libtpu \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html pip install requests elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then pip install --pre . -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html + # TODO(phawkins): switch to libtpu, when the oldest release we support is a libtpu release. pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html pip install requests - else echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}" exit 1 diff --git a/CHANGELOG.md b/CHANGELOG.md index 7fb6c1bd0fde..d9da9a2bdc71 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * `jax.lax.FftType` was introduced as a public name for the enum of FFT operations. The semi-public API `jax.lib.xla_client.FftType` has been deprecated. + * TPU: JAX now installs TPU support from the `libtpu` package rather than + `libtpu-nightly`. For the next few releases JAX will pin an empty version of + `libtpu-nightly` as well as `libtpu` to ease the transition; that dependency + will be removed in Q1 2025. * Deprecations: * The semi-public API `jax.lib.xla_client.PaddingType` has been deprecated. diff --git a/docs/installation.md b/docs/installation.md index 0c40e3dfc881..5b8893628d85 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -282,7 +282,7 @@ pip install -U --pre jax jaxlib -f https://storage.googleapis.com/jax-releases/j - Google Cloud TPU: ```bash -pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html +pip install -U --pre jax jaxlib libtpu requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html ``` - NVIDIA GPU (CUDA 12): diff --git a/setup.py b/setup.py index 9cbcd0d950e8..9ce3626fec75 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,9 @@ _current_jaxlib_version = '0.4.34' # The following should be updated after each new jaxlib release. _latest_jaxlib_version_on_pypi = '0.4.34' -_libtpu_version = '0.1.dev20241002' + +_libtpu_version = '0.0.2' +_libtpu_nightly_terminal_version = '0.1.dev20241010+nightly.cleanup' def load_version_module(pkg_path): spec = importlib.util.spec_from_file_location( @@ -76,7 +78,9 @@ def load_version_module(pkg_path): # $ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 'tpu': [ f'jaxlib>={_current_jaxlib_version},<={_jax_version}', - f'libtpu-nightly=={_libtpu_version}', + # TODO(phawkins): remove the libtpu-nightly dependency in Q1 2025. + f'libtpu-nightly=={_libtpu_nightly_terminal_version}', + f'libtpu=={_libtpu_version}', 'requests', # necessary for jax.distributed.initialize ],