diff --git a/.github/workflows/build_linux_jax_wheels.yml b/.github/workflows/build_linux_jax_wheels.yml index e3642232e..77b9e80fa 100644 --- a/.github/workflows/build_linux_jax_wheels.yml +++ b/.github/workflows/build_linux_jax_wheels.yml @@ -23,6 +23,11 @@ on: tar_url: description: URL to TheRock tarball to build against type: string + jax_ref: + description: JAX ref to checkout + required: false + type: string + default: rocm-jaxlib-v0.7.1 workflow_dispatch: inputs: amdgpu_family: @@ -57,6 +62,10 @@ on: tar_url: description: URL to TheRock tarball to build against type: string + jax_ref: + description: JAX ref to checkout + type: string + default: "rocm-jaxlib-v0.7.1" permissions: id-token: write @@ -64,9 +73,6 @@ permissions: jobs: build_jax_wheels: - strategy: - matrix: - jax_ref: [rocm-jaxlib-v0.7.1] name: Build Linux JAX Wheels | ${{ inputs.amdgpu_family }} | Python ${{ inputs.python_version }} runs-on: ${{ github.repository_owner == 'ROCm' && 'azure-linux-scale-rocm' || 'ubuntu-24.04' }} env: @@ -81,7 +87,7 @@ jobs: with: path: jax repository: rocm/rocm-jax - ref: ${{ matrix.jax_ref }} + ref: ${{ inputs.jax_ref }} - name: Configure Git Identity run: | @@ -102,7 +108,7 @@ jobs: python3 build/ci_build \ --compiler=clang \ --python-versions="${{ inputs.python_versions }}" \ - --rocm-version="${ROCM_VERSION:0:5}" \ + --rocm-version="${ROCM_VERSION}" \ --therock-path="${{ inputs.tar_url }}" \ dist_wheels