diff --git a/.github/workflows/build_linux_jax_wheels.yml b/.github/workflows/build_linux_jax_wheels.yml index fbf483f276c..782e3acfa93 100644 --- a/.github/workflows/build_linux_jax_wheels.yml +++ b/.github/workflows/build_linux_jax_wheels.yml @@ -173,7 +173,7 @@ jobs: run: | ls -lah pushd rocm-jax - [[ "${{ inputs.jax_ref }}" == *"0.8.2"* ]] && SOURCE_ARG="--jax-source-dir=${{ github.workspace }}/jax" || SOURCE_ARG="" + [[ "${{ inputs.jax_ref }}" != *"0.8.0"* ]] && SOURCE_ARG="--jax-source-dir=${{ github.workspace }}/jax" || SOURCE_ARG="" python3 build/ci_build \ --compiler=clang \ --python-versions="${{ inputs.python_version }}" \ @@ -256,6 +256,8 @@ jobs: jax_ref: ${{ needs.build_jax_wheels.outputs.jax_ref }} jax_version: ${{ needs.build_jax_wheels.outputs.jax_version }} jaxlib_version: ${{ needs.build_jax_wheels.outputs.jaxlib_version }} + jax_plugin_version: ${{ needs.build_jax_wheels.outputs.jax_plugin_version }} + jax_pjrt_version: ${{ needs.build_jax_wheels.outputs.jax_pjrt_version }} test_runs_on: ${{ needs.generate_target_to_run.outputs.test_runs_on }} upload_jax_wheels: diff --git a/.github/workflows/release_portable_linux_jax_wheels.yml b/.github/workflows/release_portable_linux_jax_wheels.yml index 00000076b24..0a54aafc69b 100644 --- a/.github/workflows/release_portable_linux_jax_wheels.yml +++ b/.github/workflows/release_portable_linux_jax_wheels.yml @@ -102,7 +102,7 @@ jobs: fail-fast: false matrix: python_version: ["3.11", "3.12", "3.13", "3.14"] - jax_ref: ["rocm-jaxlib-v0.8.0", "rocm-jaxlib-v0.8.2"] + jax_ref: ["rocm-jaxlib-v0.8.0", "rocm-jaxlib-v0.8.2", "rocm-jaxlib-v0.9.0"] uses: ./.github/workflows/build_linux_jax_wheels.yml with: diff --git a/.github/workflows/test_linux_jax_wheels.yml b/.github/workflows/test_linux_jax_wheels.yml index f235b322260..2db5a5002bf 100644 --- a/.github/workflows/test_linux_jax_wheels.yml +++ b/.github/workflows/test_linux_jax_wheels.yml @@ -41,7 +41,15 @@ on: required: false type: string jaxlib_version: - description: "Full jaxlib version including rocm suffix (e.g. 0.8.0+rocm7.12.0). Extracted from built wheels by write_jax_versions.py." + description: "jaxlib wheel version (e.g. 0.9.0+rocm7 or 0.8.0+rocm7.12.0). Extracted from built wheels by write_jax_versions.py." + required: false + type: string + jax_plugin_version: + description: "jax_rocm7_plugin wheel version. Extracted from built wheels by write_jax_versions.py." + required: false + type: string + jax_pjrt_version: + description: "jax_rocm7_pjrt wheel version. Extracted from built wheels by write_jax_versions.py." required: false type: string ref: @@ -103,7 +111,15 @@ on: required: false type: string jaxlib_version: - description: "Full jaxlib version with rocm suffix (e.g. 0.8.0+rocm7.12.0). Leave empty to auto-compute from rocm_version." + description: "jaxlib wheel version (e.g. 0.9.0+rocm7 or 0.8.0+rocm7.12.0). Leave empty to auto-compute from rocm_version." + required: false + type: string + jax_plugin_version: + description: "jax_rocm7_plugin wheel version. Leave empty to use jaxlib_version." + required: false + type: string + jax_pjrt_version: + description: "jax_rocm7_pjrt wheel version. Leave empty to use jaxlib_version." required: false type: string test_runs_on: @@ -206,6 +222,8 @@ jobs: echo "Using versions passed from build workflow" echo "JAXLIB_VERSION=${{ inputs.jaxlib_version }}" >> "$GITHUB_ENV" echo "JAX_VERSION=${{ inputs.jax_version }}" >> "$GITHUB_ENV" + echo "JAX_PLUGIN_VERSION=${{ inputs.jax_plugin_version }}" >> "$GITHUB_ENV" + echo "JAX_PJRT_VERSION=${{ inputs.jax_pjrt_version }}" >> "$GITHUB_ENV" else echo "Computing versions from rocm_version and rocm-jax requirements" pip install packaging @@ -218,12 +236,14 @@ jobs: run: | echo "Installing JAX packages:" echo " JAXLIB_VERSION=${JAXLIB_VERSION}" + echo " JAX_PLUGIN_VERSION=${JAX_PLUGIN_VERSION}" + echo " JAX_PJRT_VERSION=${JAX_PJRT_VERSION}" echo " JAX_VERSION=${JAX_VERSION}" # Install jaxlib/plugin/pjrt from the GPU-family index; install jax from PyPI to match the version pip install --index-url "${{ env.WHEEL_INDEX_URL }}" \ "jaxlib==${JAXLIB_VERSION}" \ - "jax_rocm7_plugin==${JAXLIB_VERSION}" \ - "jax_rocm7_pjrt==${JAXLIB_VERSION}" + "jax_rocm7_plugin==${JAX_PLUGIN_VERSION}" \ + "jax_rocm7_pjrt==${JAX_PJRT_VERSION}" pip install jax==${JAX_VERSION} - name: Run JAX tests