From 797fb614401c0052cba2ff8659388a530392d7ec Mon Sep 17 00:00:00 2001 From: kithumma Date: Fri, 21 Nov 2025 06:57:38 +0000 Subject: [PATCH 01/40] jax test --- .github/workflows/build_linux_jax_wheels.yml | 19 ++ .github/workflows/test_jax_wheels.yml | 194 +++++++++++++++++++ build_tools/install_rocm_tar.py | 57 ++++++ external-builds/jax/requirements-jax.txt | 19 ++ 4 files changed, 289 insertions(+) create mode 100644 .github/workflows/test_jax_wheels.yml create mode 100644 build_tools/install_rocm_tar.py create mode 100644 external-builds/jax/requirements-jax.txt diff --git a/.github/workflows/build_linux_jax_wheels.yml b/.github/workflows/build_linux_jax_wheels.yml index 6d044a2a578..a6c7756331b 100644 --- a/.github/workflows/build_linux_jax_wheels.yml +++ b/.github/workflows/build_linux_jax_wheels.yml @@ -130,3 +130,22 @@ jobs: source .venv/bin/activate pip3 install boto3 packaging python3 ./build_tools/third_party/s3_management/manage.py ${{ inputs.s3_subdir }}/${{ inputs.amdgpu_family }} + + test_jax_wheels: + name: Test JAX wheels | ${{ inputs.amdgpu_family }} + needs: [build_jax_wheels] + if: ${{ github.repository_owner == 'ROCm' }} + permissions: + id-token: write + contents: read + packages: write + uses: ./.github/workflows/test_linux_jax_wheels.yml + with: + amdgpu_family: ${{ inputs.amdgpu_family }} + release_type: ${{ inputs.release_type }} + s3_subdir: ${{ inputs.s3_subdir }} + package_index_url: ${{ inputs.cloudfront_staging_url }} + rocm_version: ${{ inputs.rocm_version }} + tar_url: ${{ inputs.tar_url }} + python_versions: ${{ inputs.python_versions }} + jax_ref: master diff --git a/.github/workflows/test_jax_wheels.yml b/.github/workflows/test_jax_wheels.yml new file mode 100644 index 00000000000..b8d16fe5846 --- /dev/null +++ b/.github/workflows/test_jax_wheels.yml @@ -0,0 +1,194 @@ +name: Test Linux JAX Wheels + +on: + workflow_call: + inputs: + amdgpu_family: + required: true + type: string + release_type: + required: true + type: string + s3_subdir: + required: true + type: string + package_index_url: + description: Base CloudFront URL for the Python package index + required: true + type: string + rocm_version: + description: ROCm version (optional, informational) + required: false + type: string + tar_url: + description: URL to TheRock tarball to configure ROCm + required: true + type: string + python_versions: + description: Python version(s) to test (e.g., "3.12") + required: true + type: string + jax_ref: + description: rocm-jax repository ref/branch to check out + required: false + type: string + default: master + jax_test_branch: + description: rocm/jax tests ref/branch to run + required: false + type: string + default: main + test_runs_on: + required: true + type: string + ref: + description: "Branch, tag or SHA to checkout. Defaults to the reference or SHA that triggered the workflow." + type: string + workflow_dispatch: + inputs: + amdgpu_family: + type: choice + options: + - gfx110X-dgpu + - gfx1151 + - gfx120X-all + - gfx94X-dcgpu + - gfx950-dcgpu + default: gfx94X-dcgpu + release_type: + description: The type of release ("nightly" or "dev") + required: true + type: string + default: dev + s3_subdir: + description: S3 subdirectory, not including the GPU-family + required: true + type: string + default: v2 + package_index_url: + description: Base CloudFront URL for the Python package index + required: true + type: string + default: https://d25kgig7rdsyks.cloudfront.net/v2-staging + rocm_version: + description: ROCm version (optional, informational) + required: false + type: string + tar_url: + description: URL to TheRock tarball to configure ROCm + required: true + type: string + python_versions: + description: Python version(s) to test (e.g., "3.12") + required: true + type: string + default: "3.12" + jax_ref: + description: rocm-jax repository ref/branch to check out + required: false + type: string + default: master + jax_test_branch: + description: google/jax tests ref/branch to run + required: false + type: string + default: main + test_runs_on: + description: Runner label to use. The selected runner should have a GPU supported by amdgpu_family + required: true + type: string + default: "linux-mi325-1gpu-ossci-rocm" + +permissions: + contents: read + id-token: write + packages: write + +jobs: + test_jax_wheels: + name: Test JAX Wheels | ${{ inputs.amdgpu_family }} | Python ${{ inputs.python_versions }} + runs-on: ${{ inputs.test_runs_on }} + container: + image: ${{ contains(inputs.test_runs_on, 'linux') && 'ghcr.io/rocm/no_rocm_image_ubuntu24_04@sha256:405945a40deaff9db90b9839c0f41d4cba4a383c1a7459b28627047bf6302a26' || null }} + options: >- + --device /dev/kfd + --device /dev/dri + --group-add render + --group-add video + + env: + VIRTUAL_ENV: /home/tester/.venv + PIP_PROGRESS_BAR: off + PIP_DISABLE_PIP_VERSION_CHECK: 1 + THEROCK_TAR_URL: ${{ inputs.tar_url }} + PYTHON_VERSION: ${{ inputs.python_versions }} + WHEEL_INDEX_URL: ${{ inputs.package_index_url }}/${{ inputs.s3_subdir }}/${{ inputs.amdgpu_family }} + + steps: + - name: Checkout + uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 + with: + repository: ${{ inputs.repository || github.repository }} + ref: ${{ inputs.ref || '' }} + + - name: Checkout rocm-jax (plugin + build scripts) + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + with: + path: jax + repository: rocm/rocm-jax + ref: ${{ inputs.jax_ref }} + + - name: Checkout JAX tests repo (for extended tests) + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + with: + repository: rocm/jax + ref: ${{ inputs.jax_test_branch }} + path: jax/jax_tests + + - name: Set up Python + uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 + with: + python-version: ${{ inputs.python_versions }} + + - name: Ensure PATH includes venv bin + run: | + echo "PATH=${{ env.VIRTUAL_ENV }}/bin:${PATH}" >> "$GITHUB_ENV" + echo "PIP_PROGRESS_BAR=off" >> "$GITHUB_ENV" + echo "PIP_DISABLE_PIP_VERSION_CHECK=1" >> "$GITHUB_ENV" + + - name: System deps, venv, and base jax requirements install + run: | + python3 setup_venv.py /home/tester/.venv --activate-in-future-github-actions-steps + pip install -r external-builds/jax/requirements-jax.txt + + - name: Configure ROCm from TheRock tarball + env: + THEROCK_TAR_URL: ${{ env.THEROCK_TAR_URL }} + run: | + python3 build_tools/install_rocm_tar.py + + - name: Extract JAX version and set to GITHUB_ENV + run: | + JAX_VERSION=$(tr -d ' ' < rocm-jax/build/requirements.txt \ + | grep -E '^(jax|jaxlib)==[^#]+' | head -n1 | cut -d'=' -f3) + echo "JAX_VERSION=$JAX_VERSION" >> "$GITHUB_ENV" + + - name: Install JAX wheels from package index + run: | + # Install jaxlib/plugin/pjrt from the GPU-family index; install jax from PyPI to match the version + pip install --index-url "${WHEEL_INDEX_URL}" \ + "jaxlib==${JAX_VERSION}+rocm${{ inputs.rocm_version }}" \ + "jax-rocm7-plugin==${JAX_VERSION}+rocm${{ inputs.rocm_version }}" \ + "jax-rocm7-pjrt==${JAX_VERSION}+rocm${{ inputs.rocm_version }}" + pip install --extra-index-url https://pypi.org/simple "jax==${JAX_VERSION}" + + python -c "import jax; print('JAX version:', jax.__version__)" + python -c "import jaxlib; print('jaxlib version:', jaxlib.__version__)" + + - name: Run JAX tests + working-directory: jax-test + run: | + pytest jax/jax_tests/tests/multi_device_test.py -q + pytest jax/jax_tests/tests/core_test.py -q + pytest jax/jax_tests/tests/util_test.py -q + pytest jax/jax_tests/tests/scipy_stats_test.py -q diff --git a/build_tools/install_rocm_tar.py b/build_tools/install_rocm_tar.py new file mode 100644 index 00000000000..e9d7b5c7226 --- /dev/null +++ b/build_tools/install_rocm_tar.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +import os +import re +import shutil +import subprocess +import sys +from pathlib import Path + +def run(cmd, cwd=None): + print(f"+ {cmd}") + subprocess.check_call(cmd, shell=True, cwd=cwd) + +def main(): + therock_tar_url = os.environ.get("THEROCK_TAR_URL", "").strip() + if not therock_tar_url: + print("THEROCK_TAR_URL not provided") + sys.exit(1) + + workdir = Path.cwd() / "therock-tarball" + install_dir = workdir / "install" + workdir.mkdir(exist_ok=True) + install_dir.mkdir(exist_ok=True) + print(f"Working in {workdir}") + + # Download tarball + run(f'wget -q "{therock_tar_url}"', cwd=str(workdir)) + + # Find tarball + tars = list(workdir.glob("*.tar.gz")) + if not tars: + print("No .tar.gz downloaded") + sys.exit(1) + tarball = tars[0].name + print(f"Found tarball: {tarball}") + + # Extract version from filename + m = re.search(r'(\d+\.\d+\.\w+\d+)', tarball) + if not m: + print("Could not extract ROCm version from tarball name") + sys.exit(1) + version = m.group(1) + print(f"Parsed ROCm version: {version}") + + # Extract tarball + run(f'tar -xf "{tarball}" -C install', cwd=str(workdir)) + + # Move into /opt/rocm- and create symlinks + dest = Path(f"/opt/rocm-{version}") + run(f'sudo mkdir -p "{dest}"') + run(f'sudo mv "{install_dir}"/* "{dest}"') + run(f'sudo ln -sfn "{dest}" /opt/rocm') + run(f'sudo ln -sfn /opt/rocm /etc/alternatives/rocm') + + print("ROCm installation configured at /opt/rocm with alternatives link") + +if __name__ == "__main__": + sys.exit(main()) diff --git a/external-builds/jax/requirements-jax.txt b/external-builds/jax/requirements-jax.txt new file mode 100644 index 00000000000..0a52666f3b2 --- /dev/null +++ b/external-builds/jax/requirements-jax.txt @@ -0,0 +1,19 @@ +numpy<2 +build +wheel +six +auditwheel +scipy +pytest +pytest-html +pytest_html_merger +pytest-reportlog +pytest-rerunfailures +pytest-json-report +cloudpickle +portpicker +matplotlib +absl-py +flatbuffers +hypothesis +ml_dtypes>=0.5.0 From e958aeb67765bf36900040ff4e40c110e02c9bd3 Mon Sep 17 00:00:00 2001 From: kithumma Date: Fri, 21 Nov 2025 06:58:03 +0000 Subject: [PATCH 02/40] jax test --- .github/workflows/test_jax_wheels.yml | 6 +++--- build_tools/install_rocm_tar.py | 7 +++++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/.github/workflows/test_jax_wheels.yml b/.github/workflows/test_jax_wheels.yml index b8d16fe5846..98c528a6b0a 100644 --- a/.github/workflows/test_jax_wheels.yml +++ b/.github/workflows/test_jax_wheels.yml @@ -112,8 +112,8 @@ jobs: image: ${{ contains(inputs.test_runs_on, 'linux') && 'ghcr.io/rocm/no_rocm_image_ubuntu24_04@sha256:405945a40deaff9db90b9839c0f41d4cba4a383c1a7459b28627047bf6302a26' || null }} options: >- --device /dev/kfd - --device /dev/dri - --group-add render + --device /dev/dri + --group-add render --group-add video env: @@ -144,7 +144,7 @@ jobs: repository: rocm/jax ref: ${{ inputs.jax_test_branch }} path: jax/jax_tests - + - name: Set up Python uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 with: diff --git a/build_tools/install_rocm_tar.py b/build_tools/install_rocm_tar.py index e9d7b5c7226..97f684f5c0f 100644 --- a/build_tools/install_rocm_tar.py +++ b/build_tools/install_rocm_tar.py @@ -6,10 +6,12 @@ import sys from pathlib import Path + def run(cmd, cwd=None): print(f"+ {cmd}") subprocess.check_call(cmd, shell=True, cwd=cwd) + def main(): therock_tar_url = os.environ.get("THEROCK_TAR_URL", "").strip() if not therock_tar_url: @@ -34,7 +36,7 @@ def main(): print(f"Found tarball: {tarball}") # Extract version from filename - m = re.search(r'(\d+\.\d+\.\w+\d+)', tarball) + m = re.search(r"(\d+\.\d+\.\w+\d+)", tarball) if not m: print("Could not extract ROCm version from tarball name") sys.exit(1) @@ -49,9 +51,10 @@ def main(): run(f'sudo mkdir -p "{dest}"') run(f'sudo mv "{install_dir}"/* "{dest}"') run(f'sudo ln -sfn "{dest}" /opt/rocm') - run(f'sudo ln -sfn /opt/rocm /etc/alternatives/rocm') + run(f"sudo ln -sfn /opt/rocm /etc/alternatives/rocm") print("ROCm installation configured at /opt/rocm with alternatives link") + if __name__ == "__main__": sys.exit(main()) From ca43e0721c8717cd05b6ce51e270658bd8a13097 Mon Sep 17 00:00:00 2001 From: kithumma Date: Fri, 21 Nov 2025 07:01:30 +0000 Subject: [PATCH 03/40] file name update --- .../{test_jax_wheels.yml => test_linux_jax_wheels.yml} | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) rename .github/workflows/{test_jax_wheels.yml => test_linux_jax_wheels.yml} (98%) diff --git a/.github/workflows/test_jax_wheels.yml b/.github/workflows/test_linux_jax_wheels.yml similarity index 98% rename from .github/workflows/test_jax_wheels.yml rename to .github/workflows/test_linux_jax_wheels.yml index 98c528a6b0a..4eb3a27376b 100644 --- a/.github/workflows/test_jax_wheels.yml +++ b/.github/workflows/test_linux_jax_wheels.yml @@ -69,9 +69,9 @@ on: description: Base CloudFront URL for the Python package index required: true type: string - default: https://d25kgig7rdsyks.cloudfront.net/v2-staging + default: https://rocm.nightlies.amd.com/v2-staging/ rocm_version: - description: ROCm version (optional, informational) + description: ROCm version required: false type: string tar_url: From 869903956aeb7e9fadafa1fdbd485ce210c6ad70 Mon Sep 17 00:00:00 2001 From: kithumma Date: Fri, 21 Nov 2025 07:14:11 +0000 Subject: [PATCH 04/40] update workflow dispatch --- .github/workflows/build_linux_jax_wheels.yml | 28 +++++++++++++++++++- .github/workflows/test_jax_dockerfile.yml | 2 +- .github/workflows/test_linux_jax_wheels.yml | 10 +++---- 3 files changed, 31 insertions(+), 9 deletions(-) diff --git a/.github/workflows/build_linux_jax_wheels.yml b/.github/workflows/build_linux_jax_wheels.yml index a6c7756331b..3f242af9137 100644 --- a/.github/workflows/build_linux_jax_wheels.yml +++ b/.github/workflows/build_linux_jax_wheels.yml @@ -23,6 +23,20 @@ on: tar_url: description: URL to TheRock tarball to build against type: string + cloudfront_url: + description: CloudFront URL pointing to Python index + required: true + type: string + cloudfront_staging_url: + description: CloudFront base URL pointing to staging Python index + required: true + type: string + repository: + description: "Repository to checkout. Otherwise, defaults to `github.repository`." + type: string + ref: + description: "Branch, tag or SHA to checkout. Defaults to the reference or SHA that triggered the workflow." + type: string workflow_dispatch: inputs: amdgpu_family: @@ -57,6 +71,18 @@ on: tar_url: description: URL to TheRock tarball to build against type: string + cloudfront_url: + description: CloudFront base URL pointing to Python index + type: string + default: "https://d25kgig7rdsyks.cloudfront.net/v2" + cloudfront_staging_url: + description: CloudFront base URL pointing to staging Python index + type: string + default: "https://d25kgig7rdsyks.cloudfront.net/v2-staging" + jax_ref: + description: rocm-jax repository ref/branch to check out + type: string + default: rocm-jaxlib-v0.8.0 permissions: id-token: write @@ -81,7 +107,7 @@ jobs: with: path: jax repository: rocm/rocm-jax - ref: ${{ matrix.jax_ref }} + ref: ${{ inputs.jax_ref || matrix.jax_ref }} - name: Configure Git Identity run: | diff --git a/.github/workflows/test_jax_dockerfile.yml b/.github/workflows/test_jax_dockerfile.yml index c63b82ca5e5..866f9119d5c 100644 --- a/.github/workflows/test_jax_dockerfile.yml +++ b/.github/workflows/test_jax_dockerfile.yml @@ -36,7 +36,7 @@ on: jax_plugin_branch: description: JAX plugin branch to checkout to use for test scripts type: string - default: "rocm-jaxlib-v0.6.0" + default: "rocm-jaxlib-v0.8.0" permissions: contents: read diff --git a/.github/workflows/test_linux_jax_wheels.yml b/.github/workflows/test_linux_jax_wheels.yml index 4eb3a27376b..f4c1bdea139 100644 --- a/.github/workflows/test_linux_jax_wheels.yml +++ b/.github/workflows/test_linux_jax_wheels.yml @@ -32,18 +32,14 @@ on: description: rocm-jax repository ref/branch to check out required: false type: string - default: master - jax_test_branch: - description: rocm/jax tests ref/branch to run - required: false - type: string - default: main + default: rocm-jaxlib-v0.8.0 test_runs_on: required: true type: string ref: description: "Branch, tag or SHA to checkout. Defaults to the reference or SHA that triggered the workflow." type: string + workflow_dispatch: inputs: amdgpu_family: @@ -142,7 +138,7 @@ jobs: uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: repository: rocm/jax - ref: ${{ inputs.jax_test_branch }} + ref: ${{ inputs.jax_ref }} path: jax/jax_tests - name: Set up Python From 846b75a86831b52f3c15bbae29b666e9484aa4af Mon Sep 17 00:00:00 2001 From: kithumma Date: Fri, 21 Nov 2025 07:20:11 +0000 Subject: [PATCH 05/40] update workflow dispatch --- .github/workflows/build_linux_jax_wheels.yml | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build_linux_jax_wheels.yml b/.github/workflows/build_linux_jax_wheels.yml index 3f242af9137..0b1f9d77f1d 100644 --- a/.github/workflows/build_linux_jax_wheels.yml +++ b/.github/workflows/build_linux_jax_wheels.yml @@ -156,11 +156,18 @@ jobs: source .venv/bin/activate pip3 install boto3 packaging python3 ./build_tools/third_party/s3_management/manage.py ${{ inputs.s3_subdir }}/${{ inputs.amdgpu_family }} + + - name: Generating target to run + id: configure + env: + TARGET: ${{ inputs.amdgpu_family }} + PLATFORM: "linux" + run: python ./build_tools/github_actions/configure_target_run.py test_jax_wheels: - name: Test JAX wheels | ${{ inputs.amdgpu_family }} + name: Test JAX wheels | ${{ inputs.amdgpu_family }} | ${{ needs.generate_target_to_run.outputs.test_runs_on }} needs: [build_jax_wheels] - if: ${{ github.repository_owner == 'ROCm' }} + if: ${{ needs.generate_target_to_run.outputs.test_runs_on != '' }} permissions: id-token: write contents: read @@ -168,10 +175,12 @@ jobs: uses: ./.github/workflows/test_linux_jax_wheels.yml with: amdgpu_family: ${{ inputs.amdgpu_family }} + test_runs_on: ${{ needs.generate_target_to_run.outputs.test_runs_on }} release_type: ${{ inputs.release_type }} s3_subdir: ${{ inputs.s3_subdir }} package_index_url: ${{ inputs.cloudfront_staging_url }} rocm_version: ${{ inputs.rocm_version }} tar_url: ${{ inputs.tar_url }} python_versions: ${{ inputs.python_versions }} - jax_ref: master + ref: ${{ inputs.ref || '' }} + jax_ref: ${{ inputs.jax_ref }} From cb7e340fb9a762dba0ac327587b1f8b61044f4a7 Mon Sep 17 00:00:00 2001 From: kithumma Date: Wed, 3 Dec 2025 00:43:47 +0000 Subject: [PATCH 06/40] fixing merge conflicts --- .github/workflows/build_linux_jax_wheels.yml | 2 +- .github/workflows/test_linux_jax_wheels.yml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build_linux_jax_wheels.yml b/.github/workflows/build_linux_jax_wheels.yml index 0b1f9d77f1d..e2c398947ef 100644 --- a/.github/workflows/build_linux_jax_wheels.yml +++ b/.github/workflows/build_linux_jax_wheels.yml @@ -181,6 +181,6 @@ jobs: package_index_url: ${{ inputs.cloudfront_staging_url }} rocm_version: ${{ inputs.rocm_version }} tar_url: ${{ inputs.tar_url }} - python_versions: ${{ inputs.python_versions }} + python_version: ${{ inputs.python_version }} ref: ${{ inputs.ref || '' }} jax_ref: ${{ inputs.jax_ref }} diff --git a/.github/workflows/test_linux_jax_wheels.yml b/.github/workflows/test_linux_jax_wheels.yml index f4c1bdea139..8777f4a24ea 100644 --- a/.github/workflows/test_linux_jax_wheels.yml +++ b/.github/workflows/test_linux_jax_wheels.yml @@ -132,13 +132,13 @@ jobs: with: path: jax repository: rocm/rocm-jax - ref: ${{ inputs.jax_ref }} + ref: rocm-jaxlib-v0.8.0 - name: Checkout JAX tests repo (for extended tests) uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: repository: rocm/jax - ref: ${{ inputs.jax_ref }} + ref: rocm-jaxlib-v0.8.0 path: jax/jax_tests - name: Set up Python From 78e751f233f4dcf9503596319f9c17901eca1e1d Mon Sep 17 00:00:00 2001 From: kithumma Date: Fri, 21 Nov 2025 10:21:03 +0000 Subject: [PATCH 07/40] update workflow paths --- .../test_linux_jax_wheels-latest.yml | 202 ++++++++++++++++++ .github/workflows/test_linux_jax_wheels.yml | 18 +- 2 files changed, 212 insertions(+), 8 deletions(-) create mode 100644 .github/workflows/test_linux_jax_wheels-latest.yml diff --git a/.github/workflows/test_linux_jax_wheels-latest.yml b/.github/workflows/test_linux_jax_wheels-latest.yml new file mode 100644 index 00000000000..429179bd3f5 --- /dev/null +++ b/.github/workflows/test_linux_jax_wheels-latest.yml @@ -0,0 +1,202 @@ +name: Test Linux JAX Wheels + +on: + workflow_call: + inputs: + amdgpu_family: + required: true + type: string + release_type: + required: true + type: string + s3_subdir: + required: true + type: string + package_index_url: + description: Base CloudFront URL for the Python package index + required: true + type: string + rocm_version: + description: ROCm version (optional, informational) + required: false + type: string + tar_url: + description: URL to TheRock tarball to configure ROCm + required: true + type: string + python_versions: + description: Python version(s) to test (e.g., "3.12") + required: true + type: string + jax_ref: + description: rocm-jax repository ref/branch to check out + required: false + type: string + default: rocm-jaxlib-v0.8.0 + jax_test_branch: + description: google/rocm jax tests ref/branch to run + required: false + type: string + default: main + test_runs_on: + required: true + type: string + ref: + description: Branch, tag or SHA to checkout. Defaults to the reference or SHA that triggered the workflow. + type: string + + workflow_dispatch: + inputs: + amdgpu_family: + type: choice + options: + - gfx110X-dgpu + - gfx1151 + - gfx120X-all + - gfx94X-dcgpu + - gfx950-dcgpu + default: gfx94X-dcgpu + release_type: + description: The type of release ("nightly" or "dev") + required: true + type: string + default: dev + s3_subdir: + description: S3 subdirectory, not including the GPU-family + required: true + type: string + default: v2 + package_index_url: + description: Base CloudFront URL for the Python package index + required: true + type: string + default: https://rocm.nightlies.amd.com/v2-staging/ + rocm_version: + description: ROCm version + required: false + type: string + tar_url: + description: URL to TheRock tarball to configure ROCm + required: true + type: string + python_versions: + description: Python version(s) to test (e.g., "3.12") + required: true + type: string + default: "3.12" + jax_ref: + description: rocm-jax repository ref/branch to check out + required: false + type: string + default: master + jax_test_branch: + description: rocm/jax tests ref/branch to run + required: false + type: string + default: main + test_runs_on: + description: Runner label to use. The selected runner should have a GPU supported by amdgpu_family + required: true + type: string + default: "linux-mi325-1gpu-ossci-rocm" + +permissions: + contents: read + id-token: write + packages: write + +jobs: + test_jax_wheels: + name: Test JAX Wheels | ${{ inputs.amdgpu_family }} | Python ${{ inputs.python_versions }} + runs-on: ${{ inputs.test_runs_on }} + + container: + image: ${{ contains(inputs.test_runs_on, 'linux') && 'ghcr.io/rocm/no_rocm_image_ubuntu24_04@sha256:405945a40deaff9db90b9839c0f41d4cba4a383c1a7459b28627047bf6302a26' || null }} + options: >- + --user root + --device /dev/kfd + --device /dev/dri + --group-add render + --group-add video + + env: + VIRTUAL_ENV: /home/tester/.venv + PIP_PROGRESS_BAR: off + PIP_DISABLE_PIP_VERSION_CHECK: 1 + THEROCK_TAR_URL: ${{ inputs.tar_url }} + PYTHON_VERSION: ${{ inputs.python_versions }} + WHEEL_INDEX_URL: ${{ inputs.package_index_url }}/${{ inputs.s3_subdir }}/${{ inputs.amdgpu_family }} + + steps: + - name: Main Checkout (GitHub Actions path) + uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd + with: + repository: ${{ inputs.repository || github.repository }} + ref: ${{ inputs.ref || '' }} + + - name: Checkout rocm-jax (GitHub Actions path) + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 + with: + path: jax + repository: rocm/rocm-jax + ref: ${{ inputs.jax_ref }} + + # JAX tests checkout (rocm/jax) + - name: Checkout JAX tests + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 + with: + repository: rocm/jax + ref: ${{ inputs.jax_test_branch || 'rocm-jaxlib-v0.8.0' }} + path: jax/jax_tests + + - name: Set up Python + uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c + with: + python-version: ${{ inputs.python_versions }} + + - name: Ensure PATH includes venv bin and pip env flags + run: | + echo "PATH=${VIRTUAL_ENV}/bin:${PATH}" >> "$GITHUB_ENV" + echo "PIP_PROGRESS_BAR=off" >> "$GITHUB_ENV" + echo "PIP_DISABLE_PIP_VERSION_CHECK=1" >> "$GITHUB_ENV" + + - name: System deps, venv, and base jax requirements install + shell: bash + run: | + python3 build_tools/setup_venv.py "${VIRTUAL_ENV}" --activate-in-future-github-actions-steps + pip install -r external-builds/jax/requirements-jax.txt + + - name: Configure ROCm from TheRock tarball + env: + THEROCK_TAR_URL: ${{ env.THEROCK_TAR_URL }} + run: | + python3 build_tools/install_rocm_tar.py + + - name: Extract JAX version and export + run: | + set -euxo pipefail + JAX_VERSION=$(tr -d ' ' < jax/build/requirements.txt \ + | grep -E '^(jax|jaxlib)==[^#]+' | head -n1 | cut -d'=' -f3) + echo "JAX_VERSION=$JAX_VERSION" >> "$GITHUB_ENV" + + - name: Install JAX wheels from package index + env: + WHEEL_INDEX_URL: ${{ env.WHEEL_INDEX_URL }} + run: | + set -euxo pipefail + pip install --index-url "${WHEEL_INDEX_URL}" \ + "jaxlib==${JAX_VERSION}+rocm${{ inputs.rocm_version }}" \ + "jax-rocm7-plugin==${JAX_VERSION}+rocm${{ inputs.rocm_version }}" \ + "jax-rocm7-pjrt==${JAX_VERSION}+rocm${{ inputs.rocm_version }}" + pip install --extra-index-url https://pypi.org/simple "jax==${JAX_VERSION}" + python -c "import jax; print('JAX version:', jax.__version__)" + python -c "import jaxlib; print('jaxlib version:', jaxlib.__version__)" + + - name: Run JAX tests + working-directory: . + run: | + set -euxo pipefail + pytest jax/jax_tests/tests/multi_device_test.py -q + pytest jax/jax_tests/tests/core_test.py -q + pytest jax/jax_tests/tests/util_test.py -q + pytest jax/jax_tests/tests/scipy_stats_test.py -q diff --git a/.github/workflows/test_linux_jax_wheels.yml b/.github/workflows/test_linux_jax_wheels.yml index 8777f4a24ea..b7e69887134 100644 --- a/.github/workflows/test_linux_jax_wheels.yml +++ b/.github/workflows/test_linux_jax_wheels.yml @@ -85,7 +85,7 @@ on: type: string default: master jax_test_branch: - description: google/jax tests ref/branch to run + description: rocm/jax tests ref/branch to run required: false type: string default: main @@ -107,6 +107,7 @@ jobs: container: image: ${{ contains(inputs.test_runs_on, 'linux') && 'ghcr.io/rocm/no_rocm_image_ubuntu24_04@sha256:405945a40deaff9db90b9839c0f41d4cba4a383c1a7459b28627047bf6302a26' || null }} options: >- + --user root --device /dev/kfd --device /dev/dri --group-add render @@ -118,7 +119,7 @@ jobs: PIP_DISABLE_PIP_VERSION_CHECK: 1 THEROCK_TAR_URL: ${{ inputs.tar_url }} PYTHON_VERSION: ${{ inputs.python_versions }} - WHEEL_INDEX_URL: ${{ inputs.package_index_url }}/${{ inputs.s3_subdir }}/${{ inputs.amdgpu_family }} + WHEEL_INDEX_URL: ${{ inputs.package_index_url }}/${{ inputs.amdgpu_family }} steps: - name: Checkout @@ -153,15 +154,16 @@ jobs: echo "PIP_DISABLE_PIP_VERSION_CHECK=1" >> "$GITHUB_ENV" - name: System deps, venv, and base jax requirements install + shell: bash run: | - python3 setup_venv.py /home/tester/.venv --activate-in-future-github-actions-steps - pip install -r external-builds/jax/requirements-jax.txt + python3 build_tools/setup_venv.py "${{ env.VIRTUAL_ENV }}" --activate-in-future-github-actions-steps + pip install -r external-builds/jax/requirements-jax.txt - name: Configure ROCm from TheRock tarball env: - THEROCK_TAR_URL: ${{ env.THEROCK_TAR_URL }} + THEROCK_TAR_URL: ${{ inputs.tar_url }} run: | - python3 build_tools/install_rocm_tar.py + python3 build_tools/install_rocm_tar.py "${{ inputs.tar_url }}" - name: Extract JAX version and set to GITHUB_ENV run: | @@ -172,7 +174,7 @@ jobs: - name: Install JAX wheels from package index run: | # Install jaxlib/plugin/pjrt from the GPU-family index; install jax from PyPI to match the version - pip install --index-url "${WHEEL_INDEX_URL}" \ + pip install --index-url "${{ env.WHEEL_INDEX_URL }}" \ "jaxlib==${JAX_VERSION}+rocm${{ inputs.rocm_version }}" \ "jax-rocm7-plugin==${JAX_VERSION}+rocm${{ inputs.rocm_version }}" \ "jax-rocm7-pjrt==${JAX_VERSION}+rocm${{ inputs.rocm_version }}" @@ -182,7 +184,7 @@ jobs: python -c "import jaxlib; print('jaxlib version:', jaxlib.__version__)" - name: Run JAX tests - working-directory: jax-test + working-directory: . run: | pytest jax/jax_tests/tests/multi_device_test.py -q pytest jax/jax_tests/tests/core_test.py -q From 78c09818f15c35ac08f52bf2cdc07fc2973db5c8 Mon Sep 17 00:00:00 2001 From: kithumma Date: Fri, 21 Nov 2025 10:21:26 +0000 Subject: [PATCH 08/40] update workflow paths --- .../test_linux_jax_wheels-latest.yml | 202 ------------------ 1 file changed, 202 deletions(-) delete mode 100644 .github/workflows/test_linux_jax_wheels-latest.yml diff --git a/.github/workflows/test_linux_jax_wheels-latest.yml b/.github/workflows/test_linux_jax_wheels-latest.yml deleted file mode 100644 index 429179bd3f5..00000000000 --- a/.github/workflows/test_linux_jax_wheels-latest.yml +++ /dev/null @@ -1,202 +0,0 @@ -name: Test Linux JAX Wheels - -on: - workflow_call: - inputs: - amdgpu_family: - required: true - type: string - release_type: - required: true - type: string - s3_subdir: - required: true - type: string - package_index_url: - description: Base CloudFront URL for the Python package index - required: true - type: string - rocm_version: - description: ROCm version (optional, informational) - required: false - type: string - tar_url: - description: URL to TheRock tarball to configure ROCm - required: true - type: string - python_versions: - description: Python version(s) to test (e.g., "3.12") - required: true - type: string - jax_ref: - description: rocm-jax repository ref/branch to check out - required: false - type: string - default: rocm-jaxlib-v0.8.0 - jax_test_branch: - description: google/rocm jax tests ref/branch to run - required: false - type: string - default: main - test_runs_on: - required: true - type: string - ref: - description: Branch, tag or SHA to checkout. Defaults to the reference or SHA that triggered the workflow. - type: string - - workflow_dispatch: - inputs: - amdgpu_family: - type: choice - options: - - gfx110X-dgpu - - gfx1151 - - gfx120X-all - - gfx94X-dcgpu - - gfx950-dcgpu - default: gfx94X-dcgpu - release_type: - description: The type of release ("nightly" or "dev") - required: true - type: string - default: dev - s3_subdir: - description: S3 subdirectory, not including the GPU-family - required: true - type: string - default: v2 - package_index_url: - description: Base CloudFront URL for the Python package index - required: true - type: string - default: https://rocm.nightlies.amd.com/v2-staging/ - rocm_version: - description: ROCm version - required: false - type: string - tar_url: - description: URL to TheRock tarball to configure ROCm - required: true - type: string - python_versions: - description: Python version(s) to test (e.g., "3.12") - required: true - type: string - default: "3.12" - jax_ref: - description: rocm-jax repository ref/branch to check out - required: false - type: string - default: master - jax_test_branch: - description: rocm/jax tests ref/branch to run - required: false - type: string - default: main - test_runs_on: - description: Runner label to use. The selected runner should have a GPU supported by amdgpu_family - required: true - type: string - default: "linux-mi325-1gpu-ossci-rocm" - -permissions: - contents: read - id-token: write - packages: write - -jobs: - test_jax_wheels: - name: Test JAX Wheels | ${{ inputs.amdgpu_family }} | Python ${{ inputs.python_versions }} - runs-on: ${{ inputs.test_runs_on }} - - container: - image: ${{ contains(inputs.test_runs_on, 'linux') && 'ghcr.io/rocm/no_rocm_image_ubuntu24_04@sha256:405945a40deaff9db90b9839c0f41d4cba4a383c1a7459b28627047bf6302a26' || null }} - options: >- - --user root - --device /dev/kfd - --device /dev/dri - --group-add render - --group-add video - - env: - VIRTUAL_ENV: /home/tester/.venv - PIP_PROGRESS_BAR: off - PIP_DISABLE_PIP_VERSION_CHECK: 1 - THEROCK_TAR_URL: ${{ inputs.tar_url }} - PYTHON_VERSION: ${{ inputs.python_versions }} - WHEEL_INDEX_URL: ${{ inputs.package_index_url }}/${{ inputs.s3_subdir }}/${{ inputs.amdgpu_family }} - - steps: - - name: Main Checkout (GitHub Actions path) - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd - with: - repository: ${{ inputs.repository || github.repository }} - ref: ${{ inputs.ref || '' }} - - - name: Checkout rocm-jax (GitHub Actions path) - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 - with: - path: jax - repository: rocm/rocm-jax - ref: ${{ inputs.jax_ref }} - - # JAX tests checkout (rocm/jax) - - name: Checkout JAX tests - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 - with: - repository: rocm/jax - ref: ${{ inputs.jax_test_branch || 'rocm-jaxlib-v0.8.0' }} - path: jax/jax_tests - - - name: Set up Python - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c - with: - python-version: ${{ inputs.python_versions }} - - - name: Ensure PATH includes venv bin and pip env flags - run: | - echo "PATH=${VIRTUAL_ENV}/bin:${PATH}" >> "$GITHUB_ENV" - echo "PIP_PROGRESS_BAR=off" >> "$GITHUB_ENV" - echo "PIP_DISABLE_PIP_VERSION_CHECK=1" >> "$GITHUB_ENV" - - - name: System deps, venv, and base jax requirements install - shell: bash - run: | - python3 build_tools/setup_venv.py "${VIRTUAL_ENV}" --activate-in-future-github-actions-steps - pip install -r external-builds/jax/requirements-jax.txt - - - name: Configure ROCm from TheRock tarball - env: - THEROCK_TAR_URL: ${{ env.THEROCK_TAR_URL }} - run: | - python3 build_tools/install_rocm_tar.py - - - name: Extract JAX version and export - run: | - set -euxo pipefail - JAX_VERSION=$(tr -d ' ' < jax/build/requirements.txt \ - | grep -E '^(jax|jaxlib)==[^#]+' | head -n1 | cut -d'=' -f3) - echo "JAX_VERSION=$JAX_VERSION" >> "$GITHUB_ENV" - - - name: Install JAX wheels from package index - env: - WHEEL_INDEX_URL: ${{ env.WHEEL_INDEX_URL }} - run: | - set -euxo pipefail - pip install --index-url "${WHEEL_INDEX_URL}" \ - "jaxlib==${JAX_VERSION}+rocm${{ inputs.rocm_version }}" \ - "jax-rocm7-plugin==${JAX_VERSION}+rocm${{ inputs.rocm_version }}" \ - "jax-rocm7-pjrt==${JAX_VERSION}+rocm${{ inputs.rocm_version }}" - pip install --extra-index-url https://pypi.org/simple "jax==${JAX_VERSION}" - python -c "import jax; print('JAX version:', jax.__version__)" - python -c "import jaxlib; print('jaxlib version:', jaxlib.__version__)" - - - name: Run JAX tests - working-directory: . - run: | - set -euxo pipefail - pytest jax/jax_tests/tests/multi_device_test.py -q - pytest jax/jax_tests/tests/core_test.py -q - pytest jax/jax_tests/tests/util_test.py -q - pytest jax/jax_tests/tests/scipy_stats_test.py -q From 20ef184a456d2e3b6af8dea5269f43c486585602 Mon Sep 17 00:00:00 2001 From: kithumma Date: Fri, 21 Nov 2025 10:26:56 +0000 Subject: [PATCH 09/40] update workflow paths --- .github/workflows/build_linux_jax_wheels.yml | 13 ++----------- .github/workflows/test_linux_jax_wheels.yml | 12 ++---------- 2 files changed, 4 insertions(+), 21 deletions(-) diff --git a/.github/workflows/build_linux_jax_wheels.yml b/.github/workflows/build_linux_jax_wheels.yml index e2c398947ef..f15a2fa13c4 100644 --- a/.github/workflows/build_linux_jax_wheels.yml +++ b/.github/workflows/build_linux_jax_wheels.yml @@ -156,26 +156,17 @@ jobs: source .venv/bin/activate pip3 install boto3 packaging python3 ./build_tools/third_party/s3_management/manage.py ${{ inputs.s3_subdir }}/${{ inputs.amdgpu_family }} - - - name: Generating target to run - id: configure - env: - TARGET: ${{ inputs.amdgpu_family }} - PLATFORM: "linux" - run: python ./build_tools/github_actions/configure_target_run.py test_jax_wheels: - name: Test JAX wheels | ${{ inputs.amdgpu_family }} | ${{ needs.generate_target_to_run.outputs.test_runs_on }} + name: Test JAX wheels | ${{ inputs.amdgpu_family }} | linux-mi325-1gpu-ossci-rocm needs: [build_jax_wheels] - if: ${{ needs.generate_target_to_run.outputs.test_runs_on != '' }} permissions: id-token: write contents: read - packages: write + packages: read uses: ./.github/workflows/test_linux_jax_wheels.yml with: amdgpu_family: ${{ inputs.amdgpu_family }} - test_runs_on: ${{ needs.generate_target_to_run.outputs.test_runs_on }} release_type: ${{ inputs.release_type }} s3_subdir: ${{ inputs.s3_subdir }} package_index_url: ${{ inputs.cloudfront_staging_url }} diff --git a/.github/workflows/test_linux_jax_wheels.yml b/.github/workflows/test_linux_jax_wheels.yml index b7e69887134..5ce689ba7aa 100644 --- a/.github/workflows/test_linux_jax_wheels.yml +++ b/.github/workflows/test_linux_jax_wheels.yml @@ -33,9 +33,6 @@ on: required: false type: string default: rocm-jaxlib-v0.8.0 - test_runs_on: - required: true - type: string ref: description: "Branch, tag or SHA to checkout. Defaults to the reference or SHA that triggered the workflow." type: string @@ -89,11 +86,6 @@ on: required: false type: string default: main - test_runs_on: - description: Runner label to use. The selected runner should have a GPU supported by amdgpu_family - required: true - type: string - default: "linux-mi325-1gpu-ossci-rocm" permissions: contents: read @@ -103,9 +95,9 @@ permissions: jobs: test_jax_wheels: name: Test JAX Wheels | ${{ inputs.amdgpu_family }} | Python ${{ inputs.python_versions }} - runs-on: ${{ inputs.test_runs_on }} + runs-on: linux-mi325-1gpu-ossci-rocm container: - image: ${{ contains(inputs.test_runs_on, 'linux') && 'ghcr.io/rocm/no_rocm_image_ubuntu24_04@sha256:405945a40deaff9db90b9839c0f41d4cba4a383c1a7459b28627047bf6302a26' || null }} + image: ghcr.io/rocm/no_rocm_image_ubuntu24_04@sha256:405945a40deaff9db90b9839c0f41d4cba4a383c1a7459b28627047bf6302a26 options: >- --user root --device /dev/kfd From 055a75ae4d6d4da88513e4fa1084b8e5c8da1ae8 Mon Sep 17 00:00:00 2001 From: kithumma Date: Fri, 21 Nov 2025 10:28:13 +0000 Subject: [PATCH 10/40] update workflow paths --- .github/workflows/build_linux_jax_wheels.yml | 1 - .github/workflows/test_linux_jax_wheels.yml | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/.github/workflows/build_linux_jax_wheels.yml b/.github/workflows/build_linux_jax_wheels.yml index f15a2fa13c4..e54d46682b8 100644 --- a/.github/workflows/build_linux_jax_wheels.yml +++ b/.github/workflows/build_linux_jax_wheels.yml @@ -161,7 +161,6 @@ jobs: name: Test JAX wheels | ${{ inputs.amdgpu_family }} | linux-mi325-1gpu-ossci-rocm needs: [build_jax_wheels] permissions: - id-token: write contents: read packages: read uses: ./.github/workflows/test_linux_jax_wheels.yml diff --git a/.github/workflows/test_linux_jax_wheels.yml b/.github/workflows/test_linux_jax_wheels.yml index 5ce689ba7aa..71cf05d17db 100644 --- a/.github/workflows/test_linux_jax_wheels.yml +++ b/.github/workflows/test_linux_jax_wheels.yml @@ -89,8 +89,7 @@ on: permissions: contents: read - id-token: write - packages: write + packages: read jobs: test_jax_wheels: From a1fcaa67e1c818769c656c255e298d38dc1992f5 Mon Sep 17 00:00:00 2001 From: kithumma Date: Fri, 21 Nov 2025 16:25:03 +0000 Subject: [PATCH 11/40] update workflow paths --- .github/workflows/test_linux_jax_wheels.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_linux_jax_wheels.yml b/.github/workflows/test_linux_jax_wheels.yml index 71cf05d17db..76590f4a2ad 100644 --- a/.github/workflows/test_linux_jax_wheels.yml +++ b/.github/workflows/test_linux_jax_wheels.yml @@ -94,7 +94,7 @@ permissions: jobs: test_jax_wheels: name: Test JAX Wheels | ${{ inputs.amdgpu_family }} | Python ${{ inputs.python_versions }} - runs-on: linux-mi325-1gpu-ossci-rocm + runs-on: linux-mi325-1gpu-ossci-rocm-test container: image: ghcr.io/rocm/no_rocm_image_ubuntu24_04@sha256:405945a40deaff9db90b9839c0f41d4cba4a383c1a7459b28627047bf6302a26 options: >- From 7fc47fbd5a636b595a1b97827e8703981d98e678 Mon Sep 17 00:00:00 2001 From: kithumma Date: Fri, 21 Nov 2025 18:15:43 +0000 Subject: [PATCH 12/40] update JAX_VERSION path --- .github/workflows/test_linux_jax_wheels.yml | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test_linux_jax_wheels.yml b/.github/workflows/test_linux_jax_wheels.yml index 76590f4a2ad..8b964b7ef05 100644 --- a/.github/workflows/test_linux_jax_wheels.yml +++ b/.github/workflows/test_linux_jax_wheels.yml @@ -97,12 +97,11 @@ jobs: runs-on: linux-mi325-1gpu-ossci-rocm-test container: image: ghcr.io/rocm/no_rocm_image_ubuntu24_04@sha256:405945a40deaff9db90b9839c0f41d4cba4a383c1a7459b28627047bf6302a26 - options: >- - --user root - --device /dev/kfd + options: --device /dev/kfd --device /dev/dri --group-add render --group-add video + --user root # Running as root, by recommendation of GitHub: https://docs.github.com/en/actions/reference/workflows-and-actions/dockerfile-support#user env: VIRTUAL_ENV: /home/tester/.venv @@ -158,7 +157,7 @@ jobs: - name: Extract JAX version and set to GITHUB_ENV run: | - JAX_VERSION=$(tr -d ' ' < rocm-jax/build/requirements.txt \ + JAX_VERSION=$(tr -d ' ' < jax/build/requirements.txt \ | grep -E '^(jax|jaxlib)==[^#]+' | head -n1 | cut -d'=' -f3) echo "JAX_VERSION=$JAX_VERSION" >> "$GITHUB_ENV" From ce1525c07642453c0536c3b6efa9b7fffbfd28c9 Mon Sep 17 00:00:00 2001 From: Kiran Thumma <167153338+kiran-thumma@users.noreply.github.com> Date: Fri, 21 Nov 2025 15:25:26 -0600 Subject: [PATCH 13/40] Update build_linux_jax_wheels.yml Update package url --- .github/workflows/build_linux_jax_wheels.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build_linux_jax_wheels.yml b/.github/workflows/build_linux_jax_wheels.yml index e54d46682b8..ebd3ae61b52 100644 --- a/.github/workflows/build_linux_jax_wheels.yml +++ b/.github/workflows/build_linux_jax_wheels.yml @@ -168,7 +168,7 @@ jobs: amdgpu_family: ${{ inputs.amdgpu_family }} release_type: ${{ inputs.release_type }} s3_subdir: ${{ inputs.s3_subdir }} - package_index_url: ${{ inputs.cloudfront_staging_url }} + package_index_url: ${{ inputs.cloudfront_url }} rocm_version: ${{ inputs.rocm_version }} tar_url: ${{ inputs.tar_url }} python_version: ${{ inputs.python_version }} From 20d905815d6669b21759c66cae3b2c41da2a531f Mon Sep 17 00:00:00 2001 From: kithumma Date: Fri, 21 Nov 2025 23:27:21 +0000 Subject: [PATCH 14/40] update scipy>=1.13 version --- .github/workflows/test_linux_jax_wheels.yml | 19 ++++--------------- external-builds/jax/requirements-jax.txt | 2 +- 2 files changed, 5 insertions(+), 16 deletions(-) diff --git a/.github/workflows/test_linux_jax_wheels.yml b/.github/workflows/test_linux_jax_wheels.yml index 8b964b7ef05..18b401f2a62 100644 --- a/.github/workflows/test_linux_jax_wheels.yml +++ b/.github/workflows/test_linux_jax_wheels.yml @@ -80,12 +80,7 @@ on: description: rocm-jax repository ref/branch to check out required: false type: string - default: master - jax_test_branch: - description: rocm/jax tests ref/branch to run - required: false - type: string - default: main + default: rocm-jaxlib-v0.8.0 permissions: contents: read @@ -104,7 +99,7 @@ jobs: --user root # Running as root, by recommendation of GitHub: https://docs.github.com/en/actions/reference/workflows-and-actions/dockerfile-support#user env: - VIRTUAL_ENV: /home/tester/.venv + VIRTUAL_ENV: ${{ github.workspace }}/.venv PIP_PROGRESS_BAR: off PIP_DISABLE_PIP_VERSION_CHECK: 1 THEROCK_TAR_URL: ${{ inputs.tar_url }} @@ -123,13 +118,13 @@ jobs: with: path: jax repository: rocm/rocm-jax - ref: rocm-jaxlib-v0.8.0 + ref: ${{ inputs.jax_ref }} - name: Checkout JAX tests repo (for extended tests) uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: repository: rocm/jax - ref: rocm-jaxlib-v0.8.0 + ref: ${{ inputs.jax_ref }} path: jax/jax_tests - name: Set up Python @@ -137,12 +132,6 @@ jobs: with: python-version: ${{ inputs.python_versions }} - - name: Ensure PATH includes venv bin - run: | - echo "PATH=${{ env.VIRTUAL_ENV }}/bin:${PATH}" >> "$GITHUB_ENV" - echo "PIP_PROGRESS_BAR=off" >> "$GITHUB_ENV" - echo "PIP_DISABLE_PIP_VERSION_CHECK=1" >> "$GITHUB_ENV" - - name: System deps, venv, and base jax requirements install shell: bash run: | diff --git a/external-builds/jax/requirements-jax.txt b/external-builds/jax/requirements-jax.txt index 0a52666f3b2..ceac7dd8ddc 100644 --- a/external-builds/jax/requirements-jax.txt +++ b/external-builds/jax/requirements-jax.txt @@ -3,7 +3,7 @@ build wheel six auditwheel -scipy +scipy>=1.13 pytest pytest-html pytest_html_merger From dfb2a8a7d322feffd18f3168b154bf554762a8a6 Mon Sep 17 00:00:00 2001 From: kithumma Date: Sat, 22 Nov 2025 06:19:41 +0000 Subject: [PATCH 15/40] update scipy version --- .github/workflows/test_linux_jax_wheels.yml | 4 +--- external-builds/jax/constraints.txt | 2 ++ 2 files changed, 3 insertions(+), 3 deletions(-) create mode 100644 external-builds/jax/constraints.txt diff --git a/.github/workflows/test_linux_jax_wheels.yml b/.github/workflows/test_linux_jax_wheels.yml index 18b401f2a62..963fa35e673 100644 --- a/.github/workflows/test_linux_jax_wheels.yml +++ b/.github/workflows/test_linux_jax_wheels.yml @@ -154,14 +154,12 @@ jobs: run: | # Install jaxlib/plugin/pjrt from the GPU-family index; install jax from PyPI to match the version pip install --index-url "${{ env.WHEEL_INDEX_URL }}" \ + -c external-builds/jax/constraints.txt \ "jaxlib==${JAX_VERSION}+rocm${{ inputs.rocm_version }}" \ "jax-rocm7-plugin==${JAX_VERSION}+rocm${{ inputs.rocm_version }}" \ "jax-rocm7-pjrt==${JAX_VERSION}+rocm${{ inputs.rocm_version }}" pip install --extra-index-url https://pypi.org/simple "jax==${JAX_VERSION}" - python -c "import jax; print('JAX version:', jax.__version__)" - python -c "import jaxlib; print('jaxlib version:', jaxlib.__version__)" - - name: Run JAX tests working-directory: . run: | diff --git a/external-builds/jax/constraints.txt b/external-builds/jax/constraints.txt new file mode 100644 index 00000000000..cbae5e62b2b --- /dev/null +++ b/external-builds/jax/constraints.txt @@ -0,0 +1,2 @@ +numpy==1.26.4 +scipy==1.16.3 From 81db5b9682efdc661e83655dbd1ca1ae92ad31ff Mon Sep 17 00:00:00 2001 From: kithumma Date: Sat, 22 Nov 2025 23:29:59 +0000 Subject: [PATCH 16/40] adjust python installations --- .github/workflows/test_linux_jax_wheels.yml | 29 +++++++++++---------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/.github/workflows/test_linux_jax_wheels.yml b/.github/workflows/test_linux_jax_wheels.yml index 963fa35e673..7b7f26be503 100644 --- a/.github/workflows/test_linux_jax_wheels.yml +++ b/.github/workflows/test_linux_jax_wheels.yml @@ -92,11 +92,12 @@ jobs: runs-on: linux-mi325-1gpu-ossci-rocm-test container: image: ghcr.io/rocm/no_rocm_image_ubuntu24_04@sha256:405945a40deaff9db90b9839c0f41d4cba4a383c1a7459b28627047bf6302a26 - options: --device /dev/kfd + options: >- + --device /dev/kfd --device /dev/dri --group-add render --group-add video - --user root # Running as root, by recommendation of GitHub: https://docs.github.com/en/actions/reference/workflows-and-actions/dockerfile-support#user + --user root env: VIRTUAL_ENV: ${{ github.workspace }}/.venv @@ -127,22 +128,24 @@ jobs: ref: ${{ inputs.jax_ref }} path: jax/jax_tests - - name: Set up Python - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 - with: - python-version: ${{ inputs.python_versions }} - - name: System deps, venv, and base jax requirements install shell: bash run: | - python3 build_tools/setup_venv.py "${{ env.VIRTUAL_ENV }}" --activate-in-future-github-actions-steps - pip install -r external-builds/jax/requirements-jax.txt + apt-get update && apt-get install -y --no-install-recommends python3.12-venv python3-pip ca-certificates git + python3 -m venv "${VIRTUAL_ENV}" + echo "PATH=${VIRTUAL_ENV}/bin:${PATH}" >> "$GITHUB_ENV" + "${VIRTUAL_ENV}/bin/python" -m pip install --upgrade pip setuptools wheel + python3 build_tools/setup_venv.py "${VIRTUAL_ENV}" --activate-in-future-github-actions-steps + + - name: Install base JAX test requirements + shell: bash + run: | + # This script sets up the venv and activates it across steps; keep it consistent + pip install -r external-builds/jax/requirements-jax.txt - name: Configure ROCm from TheRock tarball - env: - THEROCK_TAR_URL: ${{ inputs.tar_url }} run: | - python3 build_tools/install_rocm_tar.py "${{ inputs.tar_url }}" + python build_tools/install_rocm_tar.py "${{ inputs.tar_url }}" - name: Extract JAX version and set to GITHUB_ENV run: | @@ -154,14 +157,12 @@ jobs: run: | # Install jaxlib/plugin/pjrt from the GPU-family index; install jax from PyPI to match the version pip install --index-url "${{ env.WHEEL_INDEX_URL }}" \ - -c external-builds/jax/constraints.txt \ "jaxlib==${JAX_VERSION}+rocm${{ inputs.rocm_version }}" \ "jax-rocm7-plugin==${JAX_VERSION}+rocm${{ inputs.rocm_version }}" \ "jax-rocm7-pjrt==${JAX_VERSION}+rocm${{ inputs.rocm_version }}" pip install --extra-index-url https://pypi.org/simple "jax==${JAX_VERSION}" - name: Run JAX tests - working-directory: . run: | pytest jax/jax_tests/tests/multi_device_test.py -q pytest jax/jax_tests/tests/core_test.py -q From 9025b44f56846646a3b13d2140fbd43b9daa883e Mon Sep 17 00:00:00 2001 From: kithumma Date: Mon, 24 Nov 2025 06:03:01 +0000 Subject: [PATCH 17/40] update install scripts --- .github/workflows/test_linux_jax_wheels.yml | 9 ++- build_tools/install_rocm_tar.py | 89 +++++++++++++-------- 2 files changed, 59 insertions(+), 39 deletions(-) diff --git a/.github/workflows/test_linux_jax_wheels.yml b/.github/workflows/test_linux_jax_wheels.yml index 7b7f26be503..eb08d602d27 100644 --- a/.github/workflows/test_linux_jax_wheels.yml +++ b/.github/workflows/test_linux_jax_wheels.yml @@ -128,10 +128,10 @@ jobs: ref: ${{ inputs.jax_ref }} path: jax/jax_tests - - name: System deps, venv, and base jax requirements install + - name: System deps, venv configure shell: bash run: | - apt-get update && apt-get install -y --no-install-recommends python3.12-venv python3-pip ca-certificates git + apt-get update && apt-get install -y --no-install-recommends python${{ inputs.python_versions }}-venv python3-pip ca-certificates git python3 -m venv "${VIRTUAL_ENV}" echo "PATH=${VIRTUAL_ENV}/bin:${PATH}" >> "$GITHUB_ENV" "${VIRTUAL_ENV}/bin/python" -m pip install --upgrade pip setuptools wheel @@ -145,12 +145,13 @@ jobs: - name: Configure ROCm from TheRock tarball run: | - python build_tools/install_rocm_tar.py "${{ inputs.tar_url }}" + python build_tools/install_rocm_tar.py "${{ inputs.tar_url }}" "${{ inputs.rocm_version }}" - name: Extract JAX version and set to GITHUB_ENV run: | + # Extract JAX version from requirements.txt (e.g., "jax==0.8.0") JAX_VERSION=$(tr -d ' ' < jax/build/requirements.txt \ - | grep -E '^(jax|jaxlib)==[^#]+' | head -n1 | cut -d'=' -f3) + | grep -E '^(jax|jaxlib)==[^#]+' | head -n1 | cut -d'=' -f3) echo "JAX_VERSION=$JAX_VERSION" >> "$GITHUB_ENV" - name: Install JAX wheels from package index diff --git a/build_tools/install_rocm_tar.py b/build_tools/install_rocm_tar.py index 97f684f5c0f..742a1107556 100644 --- a/build_tools/install_rocm_tar.py +++ b/build_tools/install_rocm_tar.py @@ -1,60 +1,79 @@ #!/usr/bin/env python3 +""" +Minimal ROCm tarball installer using requests with simple logging (message only) + +Usage: + python3 install_rocm_tar.py + +Example: + python3 install_rocm_tar.py \ + "https://therock-nightly-tarball.s3.amazonaws.com/therock-dist-linux-gfx94X-dcgpu-7.10.0a20251109.tar.gz" \ + "7.10.0a20251109" +""" + +import sys import os -import re -import shutil +import logging import subprocess -import sys from pathlib import Path +from urllib.parse import urlparse +import requests + +# Log message +logging.basicConfig(level=logging.INFO, format="%(message)s") +logger = logging.getLogger(__name__) def run(cmd, cwd=None): - print(f"+ {cmd}") + logger.info("+ %s", cmd) subprocess.check_call(cmd, shell=True, cwd=cwd) +def download_tarball(url, dest_dir) -> Path: + logger.info("Downloading: %s", url) + dest_dir = Path(dest_dir) + dest_dir.mkdir(parents=True, exist_ok=True) + + filename = os.path.basename(urlparse(url).path) or "rocm.tar.gz" + outfile = dest_dir / filename + tmpfile = outfile.with_suffix(outfile.suffix + ".part") + + with requests.get(url, stream=True) as r: + r.raise_for_status() + with open(tmpfile, "wb") as f: + for chunk in r.iter_content(chunk_size=1024 * 1024): + if chunk: + f.write(chunk) + + tmpfile.rename(outfile) + logger.info("Downloaded to %s", outfile) + return outfile + + def main(): - therock_tar_url = os.environ.get("THEROCK_TAR_URL", "").strip() - if not therock_tar_url: - print("THEROCK_TAR_URL not provided") - sys.exit(1) + # Expect exactly two args: tar_url and rocm_version + tar_url = sys.argv[1] + rocm_version = sys.argv[2] workdir = Path.cwd() / "therock-tarball" install_dir = workdir / "install" workdir.mkdir(exist_ok=True) install_dir.mkdir(exist_ok=True) - print(f"Working in {workdir}") - - # Download tarball - run(f'wget -q "{therock_tar_url}"', cwd=str(workdir)) - - # Find tarball - tars = list(workdir.glob("*.tar.gz")) - if not tars: - print("No .tar.gz downloaded") - sys.exit(1) - tarball = tars[0].name - print(f"Found tarball: {tarball}") - - # Extract version from filename - m = re.search(r"(\d+\.\d+\.\w+\d+)", tarball) - if not m: - print("Could not extract ROCm version from tarball name") - sys.exit(1) - version = m.group(1) - print(f"Parsed ROCm version: {version}") + logger.info("Working in %s", workdir) + + tar_path = download_tarball(tar_url, workdir) # Extract tarball - run(f'tar -xf "{tarball}" -C install', cwd=str(workdir)) + run(f'tar -xf "{tar_path.name}" -C install', cwd=str(workdir)) - # Move into /opt/rocm- and create symlinks - dest = Path(f"/opt/rocm-{version}") + # Install to /opt/rocm- and create symlinks + dest = Path(f"/opt/rocm-{rocm_version}") run(f'sudo mkdir -p "{dest}"') run(f'sudo mv "{install_dir}"/* "{dest}"') run(f'sudo ln -sfn "{dest}" /opt/rocm') - run(f"sudo ln -sfn /opt/rocm /etc/alternatives/rocm") - - print("ROCm installation configured at /opt/rocm with alternatives link") + run(f'sudo ln -sfn /opt/rocm /etc/alternatives/rocm') + logger.info("ROCm installation configured at /opt/rocm-%s", rocm_version) if __name__ == "__main__": - sys.exit(main()) + main() From 7f34313c0d884b15c92d293a781769fad81ded61 Mon Sep 17 00:00:00 2001 From: kithumma Date: Mon, 24 Nov 2025 21:43:57 +0000 Subject: [PATCH 18/40] update install rocm --- .github/workflows/test_linux_jax_wheels.yml | 15 +++- build_tools/install_rocm_tar.py | 79 --------------------- 2 files changed, 14 insertions(+), 80 deletions(-) delete mode 100644 build_tools/install_rocm_tar.py diff --git a/.github/workflows/test_linux_jax_wheels.yml b/.github/workflows/test_linux_jax_wheels.yml index eb08d602d27..0824620ec3b 100644 --- a/.github/workflows/test_linux_jax_wheels.yml +++ b/.github/workflows/test_linux_jax_wheels.yml @@ -144,8 +144,21 @@ jobs: pip install -r external-builds/jax/requirements-jax.txt - name: Configure ROCm from TheRock tarball + shell: bash + env: + ROCM_VERSION: ${{ inputs.rocm_version }} + AMDGPU_FAMILY: ${{ inputs.amdgpu_family }} run: | - python build_tools/install_rocm_tar.py "${{ inputs.tar_url }}" "${{ inputs.rocm_version }}" + python -m pip install -q --upgrade pip boto3 + DEST="/opt/rocm-${{ inputs.rocm_version }}" + # Install directly from TheRock release buckets (nightly/dev) using the provided version + python build_tools/install_rocm_from_artifacts.py \ + --release "${{ inputs.rocm_version }}" \ + --artifact-group "${{ inputs.amdgpu_family }}" \ + --output-dir "${DEST}" + # Create standard symlinks + ln -sfn "${DEST}" /opt/rocm + ln -sfn /opt/rocm /etc/alternatives/rocm - name: Extract JAX version and set to GITHUB_ENV run: | diff --git a/build_tools/install_rocm_tar.py b/build_tools/install_rocm_tar.py deleted file mode 100644 index 742a1107556..00000000000 --- a/build_tools/install_rocm_tar.py +++ /dev/null @@ -1,79 +0,0 @@ -#!/usr/bin/env python3 -""" -Minimal ROCm tarball installer using requests with simple logging (message only) - -Usage: - python3 install_rocm_tar.py - -Example: - python3 install_rocm_tar.py \ - "https://therock-nightly-tarball.s3.amazonaws.com/therock-dist-linux-gfx94X-dcgpu-7.10.0a20251109.tar.gz" \ - "7.10.0a20251109" -""" - -import sys -import os -import logging -import subprocess -from pathlib import Path -from urllib.parse import urlparse -import requests - -# Log message -logging.basicConfig(level=logging.INFO, format="%(message)s") -logger = logging.getLogger(__name__) - - -def run(cmd, cwd=None): - logger.info("+ %s", cmd) - subprocess.check_call(cmd, shell=True, cwd=cwd) - - -def download_tarball(url, dest_dir) -> Path: - logger.info("Downloading: %s", url) - dest_dir = Path(dest_dir) - dest_dir.mkdir(parents=True, exist_ok=True) - - filename = os.path.basename(urlparse(url).path) or "rocm.tar.gz" - outfile = dest_dir / filename - tmpfile = outfile.with_suffix(outfile.suffix + ".part") - - with requests.get(url, stream=True) as r: - r.raise_for_status() - with open(tmpfile, "wb") as f: - for chunk in r.iter_content(chunk_size=1024 * 1024): - if chunk: - f.write(chunk) - - tmpfile.rename(outfile) - logger.info("Downloaded to %s", outfile) - return outfile - - -def main(): - # Expect exactly two args: tar_url and rocm_version - tar_url = sys.argv[1] - rocm_version = sys.argv[2] - - workdir = Path.cwd() / "therock-tarball" - install_dir = workdir / "install" - workdir.mkdir(exist_ok=True) - install_dir.mkdir(exist_ok=True) - logger.info("Working in %s", workdir) - - tar_path = download_tarball(tar_url, workdir) - - # Extract tarball - run(f'tar -xf "{tar_path.name}" -C install', cwd=str(workdir)) - - # Install to /opt/rocm- and create symlinks - dest = Path(f"/opt/rocm-{rocm_version}") - run(f'sudo mkdir -p "{dest}"') - run(f'sudo mv "{install_dir}"/* "{dest}"') - run(f'sudo ln -sfn "{dest}" /opt/rocm') - run(f'sudo ln -sfn /opt/rocm /etc/alternatives/rocm') - - logger.info("ROCm installation configured at /opt/rocm-%s", rocm_version) - -if __name__ == "__main__": - main() From bd3e346bbba3fe6eccd08357117a1dd0324a011f Mon Sep 17 00:00:00 2001 From: kithumma Date: Mon, 24 Nov 2025 22:21:08 +0000 Subject: [PATCH 19/40] cleanup PR --- external-builds/jax/constraints.txt | 2 -- 1 file changed, 2 deletions(-) delete mode 100644 external-builds/jax/constraints.txt diff --git a/external-builds/jax/constraints.txt b/external-builds/jax/constraints.txt deleted file mode 100644 index cbae5e62b2b..00000000000 --- a/external-builds/jax/constraints.txt +++ /dev/null @@ -1,2 +0,0 @@ -numpy==1.26.4 -scipy==1.16.3 From 8e9f9d7b1bda5b303b16abd4bea89812aac66f85 Mon Sep 17 00:00:00 2001 From: kithumma Date: Tue, 25 Nov 2025 16:15:33 +0000 Subject: [PATCH 20/40] update runs-on --- .github/workflows/test_linux_jax_wheels.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test_linux_jax_wheels.yml b/.github/workflows/test_linux_jax_wheels.yml index 0824620ec3b..df5642411ef 100644 --- a/.github/workflows/test_linux_jax_wheels.yml +++ b/.github/workflows/test_linux_jax_wheels.yml @@ -89,7 +89,7 @@ permissions: jobs: test_jax_wheels: name: Test JAX Wheels | ${{ inputs.amdgpu_family }} | Python ${{ inputs.python_versions }} - runs-on: linux-mi325-1gpu-ossci-rocm-test + runs-on: linux-mi325-1gpu-ossci-rocm container: image: ghcr.io/rocm/no_rocm_image_ubuntu24_04@sha256:405945a40deaff9db90b9839c0f41d4cba4a383c1a7459b28627047bf6302a26 options: >- From e84b5296f8c60d3d1e1547568456777a26d7a6ff Mon Sep 17 00:00:00 2001 From: kithumma Date: Wed, 26 Nov 2025 01:40:30 +0000 Subject: [PATCH 21/40] update setup-python and fixes as per comments --- .github/workflows/build_linux_jax_wheels.yml | 23 +++++++++++++- .github/workflows/test_linux_jax_wheels.yml | 28 +++++++++++------ external-builds/jax/requirements-jax.txt | 33 ++++++++++---------- 3 files changed, 58 insertions(+), 26 deletions(-) diff --git a/.github/workflows/build_linux_jax_wheels.yml b/.github/workflows/build_linux_jax_wheels.yml index ebd3ae61b52..293b2ebb0d8 100644 --- a/.github/workflows/build_linux_jax_wheels.yml +++ b/.github/workflows/build_linux_jax_wheels.yml @@ -156,9 +156,29 @@ jobs: source .venv/bin/activate pip3 install boto3 packaging python3 ./build_tools/third_party/s3_management/manage.py ${{ inputs.s3_subdir }}/${{ inputs.amdgpu_family }} + + generate_target_to_run: + name: Generate target_to_run + runs-on: ubuntu-24.04 + outputs: + test_runs_on: ${{ steps.configure.outputs.test-runs-on }} + bypass_tests_for_releases: ${{ steps.configure.outputs.bypass_tests_for_releases }} + steps: + - name: Checking out repository + uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 + with: + repository: ${{ inputs.repository || github.repository }} + ref: ${{ inputs.ref || '' }} + + - name: Generating target to run + id: configure + env: + TARGET: ${{ inputs.amdgpu_family }} + PLATFORM: "linux" + run: python ./build_tools/github_actions/configure_target_run.py test_jax_wheels: - name: Test JAX wheels | ${{ inputs.amdgpu_family }} | linux-mi325-1gpu-ossci-rocm + name: Test JAX wheels | ${{ inputs.amdgpu_family }} | ${{ needs.generate_target_to_run.outputs.test_runs_on }} needs: [build_jax_wheels] permissions: contents: read @@ -174,3 +194,4 @@ jobs: python_version: ${{ inputs.python_version }} ref: ${{ inputs.ref || '' }} jax_ref: ${{ inputs.jax_ref }} + test_runs_on: ${{ needs.generate_target_to_run.outputs.test_runs_on }} diff --git a/.github/workflows/test_linux_jax_wheels.yml b/.github/workflows/test_linux_jax_wheels.yml index df5642411ef..263082de9a2 100644 --- a/.github/workflows/test_linux_jax_wheels.yml +++ b/.github/workflows/test_linux_jax_wheels.yml @@ -36,6 +36,9 @@ on: ref: description: "Branch, tag or SHA to checkout. Defaults to the reference or SHA that triggered the workflow." type: string + test_runs_on: + required: true + type: string workflow_dispatch: inputs: @@ -81,6 +84,11 @@ on: required: false type: string default: rocm-jaxlib-v0.8.0 + test_runs_on: + description: Runner label to use. The selected runner should have a GPU supported by amdgpu_family + required: true + type: string + default: "linux-mi325-1gpu-ossci-rocm-frac" permissions: contents: read @@ -89,7 +97,7 @@ permissions: jobs: test_jax_wheels: name: Test JAX Wheels | ${{ inputs.amdgpu_family }} | Python ${{ inputs.python_versions }} - runs-on: linux-mi325-1gpu-ossci-rocm + runs-on: ${{ inputs.test_runs_on }} container: image: ghcr.io/rocm/no_rocm_image_ubuntu24_04@sha256:405945a40deaff9db90b9839c0f41d4cba4a383c1a7459b28627047bf6302a26 options: >- @@ -98,11 +106,11 @@ jobs: --group-add render --group-add video --user root - + defaults: + run: + shell: bash env: VIRTUAL_ENV: ${{ github.workspace }}/.venv - PIP_PROGRESS_BAR: off - PIP_DISABLE_PIP_VERSION_CHECK: 1 THEROCK_TAR_URL: ${{ inputs.tar_url }} PYTHON_VERSION: ${{ inputs.python_versions }} WHEEL_INDEX_URL: ${{ inputs.package_index_url }}/${{ inputs.amdgpu_family }} @@ -127,29 +135,31 @@ jobs: repository: rocm/jax ref: ${{ inputs.jax_ref }} path: jax/jax_tests + + - name: Set up Python + uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 + with: + python-version: ${{ inputs.python_versions }} + check-latest: true - name: System deps, venv configure - shell: bash run: | - apt-get update && apt-get install -y --no-install-recommends python${{ inputs.python_versions }}-venv python3-pip ca-certificates git python3 -m venv "${VIRTUAL_ENV}" echo "PATH=${VIRTUAL_ENV}/bin:${PATH}" >> "$GITHUB_ENV" "${VIRTUAL_ENV}/bin/python" -m pip install --upgrade pip setuptools wheel python3 build_tools/setup_venv.py "${VIRTUAL_ENV}" --activate-in-future-github-actions-steps - name: Install base JAX test requirements - shell: bash run: | # This script sets up the venv and activates it across steps; keep it consistent pip install -r external-builds/jax/requirements-jax.txt - name: Configure ROCm from TheRock tarball - shell: bash env: ROCM_VERSION: ${{ inputs.rocm_version }} AMDGPU_FAMILY: ${{ inputs.amdgpu_family }} run: | - python -m pip install -q --upgrade pip boto3 + python -m pip install -q --upgrade pip DEST="/opt/rocm-${{ inputs.rocm_version }}" # Install directly from TheRock release buckets (nightly/dev) using the provided version python build_tools/install_rocm_from_artifacts.py \ diff --git a/external-builds/jax/requirements-jax.txt b/external-builds/jax/requirements-jax.txt index ceac7dd8ddc..065666382a7 100644 --- a/external-builds/jax/requirements-jax.txt +++ b/external-builds/jax/requirements-jax.txt @@ -1,19 +1,20 @@ +boto3>=1.41.4 numpy<2 -build -wheel -six -auditwheel +build>=1.3.0 +wheel>=0.45.1 +six>=1.17.0 +auditwheel>=6.5.0 scipy>=1.13 -pytest -pytest-html -pytest_html_merger -pytest-reportlog -pytest-rerunfailures -pytest-json-report -cloudpickle -portpicker -matplotlib -absl-py -flatbuffers -hypothesis +pytest>=9.0.1 +pytest-html>=4.1.1 +pytest_html_merger>=0.1.0 +pytest-reportlog>=1.0.0 +pytest-rerunfailures>=16.1 +pytest-json-report>=1.5.0 +cloudpickle>=3.1.2 +portpicker>=1.6.0 +matplotlib>=3.10.7 +absl-py>=2.3.1 +flatbuffers>=25.9.23 +hypothesis>=6.148.2 ml_dtypes>=0.5.0 From 7b5f44e569e462d1816cb844b7cbb65b3c584135 Mon Sep 17 00:00:00 2001 From: kithumma Date: Wed, 3 Dec 2025 00:45:42 +0000 Subject: [PATCH 22/40] fixing merge conflicts --- .github/workflows/build_jax_dockerfile.yml | 79 ---------------------- 1 file changed, 79 deletions(-) delete mode 100644 .github/workflows/build_jax_dockerfile.yml diff --git a/.github/workflows/build_jax_dockerfile.yml b/.github/workflows/build_jax_dockerfile.yml deleted file mode 100644 index fd9e078431f..00000000000 --- a/.github/workflows/build_jax_dockerfile.yml +++ /dev/null @@ -1,79 +0,0 @@ -name: Build Linux JAX Docker Images - -on: - workflow_call: - inputs: - amdgpu_family: - required: true - type: string - python_version: - required: true - type: string - release_type: - description: The type of release to build ("nightly", "prerelease", or "dev") - required: true - type: string - s3_subdir: - description: S3 subdirectory, not including the GPU-family - required: true - type: string - rocm_version: - description: ROCm version to install - type: string - tar_url: - description: URL to TheRock tarball to build against - type: string - workflow_dispatch: - inputs: - amdgpu_family: - required: true - type: string - python_version: - required: true - type: string - release_type: - description: The type of release to build ("nightly", "prerelease", or "dev") - required: true - type: string - default: "dev" - s3_subdir: - description: S3 subdirectory, not including the GPU-family - type: string - default: "v2" - rocm_version: - description: ROCm version to install - type: string - tar_url: - description: URL to TheRock tarball to build against - type: string - -permissions: - id-token: write - contents: read - -jobs: - build_jax_wheels: - strategy: - matrix: - jax_ref: [master] - name: Build | ${{ inputs.amdgpu_family }} | Python ${{ inputs.python_version }} - runs-on: ${{ github.repository_owner == 'ROCm' && 'azure-linux-scale-rocm' || 'ubuntu-24.04' }} - env: - PACKAGE_DIST_DIR: ${{ github.workspace }}/jax_rocm_plugin/wheelhouse - S3_BUCKET_PY: "therock-${{ inputs.release_type }}-python" - steps: - - name: Checkout TheRock - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 - - - name: Checkout JAX - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 - with: - path: jax - repository: rocm/rocm-jax - ref: ${{ matrix.jax_ref }} - - - name: Configure Git Identity - run: | - git config --global user.name "therockbot" - git config --global user.email "therockbot@amd.com" - # TODO: Pull down JAX plugin wheels into the wheelhouse directory and run the image build script from rocm-jax From 4d0f1e305e41256aefc404d14ee3b32b5a34212b Mon Sep 17 00:00:00 2001 From: kithumma Date: Wed, 3 Dec 2025 00:46:57 +0000 Subject: [PATCH 23/40] fixing merge conflicts --- .github/workflows/test_linux_jax_wheels.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test_linux_jax_wheels.yml b/.github/workflows/test_linux_jax_wheels.yml index 263082de9a2..2ec47c27347 100644 --- a/.github/workflows/test_linux_jax_wheels.yml +++ b/.github/workflows/test_linux_jax_wheels.yml @@ -96,10 +96,10 @@ permissions: jobs: test_jax_wheels: - name: Test JAX Wheels | ${{ inputs.amdgpu_family }} | Python ${{ inputs.python_versions }} + name: Test JAX Wheels | ${{ inputs.amdgpu_family }} runs-on: ${{ inputs.test_runs_on }} container: - image: ghcr.io/rocm/no_rocm_image_ubuntu24_04@sha256:405945a40deaff9db90b9839c0f41d4cba4a383c1a7459b28627047bf6302a26 + image: ${{ contains(inputs.test_runs_on, 'linux') && 'ghcr.io/rocm/no_rocm_image_ubuntu24_04@sha256:405945a40deaff9db90b9839c0f41d4cba4a383c1a7459b28627047bf6302a26' || null }} options: >- --device /dev/kfd --device /dev/dri @@ -111,6 +111,7 @@ jobs: shell: bash env: VIRTUAL_ENV: ${{ github.workspace }}/.venv + AMDGPU_FAMILY: ${{ inputs.amdgpu_family }} THEROCK_TAR_URL: ${{ inputs.tar_url }} PYTHON_VERSION: ${{ inputs.python_versions }} WHEEL_INDEX_URL: ${{ inputs.package_index_url }}/${{ inputs.amdgpu_family }} From 22077f5fe3901c753b7791887e0cfecc3204e839 Mon Sep 17 00:00:00 2001 From: kithumma Date: Wed, 26 Nov 2025 07:15:50 +0000 Subject: [PATCH 24/40] test_runs_on update --- .github/workflows/build_linux_jax_wheels.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build_linux_jax_wheels.yml b/.github/workflows/build_linux_jax_wheels.yml index 293b2ebb0d8..9bcd852cdb1 100644 --- a/.github/workflows/build_linux_jax_wheels.yml +++ b/.github/workflows/build_linux_jax_wheels.yml @@ -93,7 +93,7 @@ jobs: strategy: matrix: jax_ref: [rocm-jaxlib-v0.8.0] - name: Build Linux JAX Wheels | ${{ inputs.amdgpu_family }} | Python ${{ inputs.python_version }} + name: Build Linux JAX Wheels | ${{ inputs.amdgpu_family }} | Python ${{ inputs.python_versions }} runs-on: ${{ github.repository_owner == 'ROCm' && 'azure-linux-scale-rocm' || 'ubuntu-24.04' }} env: PACKAGE_DIST_DIR: ${{ github.workspace }}/jax/jax_rocm_plugin/wheelhouse @@ -179,7 +179,7 @@ jobs: test_jax_wheels: name: Test JAX wheels | ${{ inputs.amdgpu_family }} | ${{ needs.generate_target_to_run.outputs.test_runs_on }} - needs: [build_jax_wheels] + needs: [build_jax_wheels, generate_target_to_run] permissions: contents: read packages: read From f341cb1e02222ce9427aa8c280161e8e8002e9f8 Mon Sep 17 00:00:00 2001 From: kithumma Date: Wed, 26 Nov 2025 09:57:19 +0000 Subject: [PATCH 25/40] update --- .github/workflows/build_linux_jax_wheels.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build_linux_jax_wheels.yml b/.github/workflows/build_linux_jax_wheels.yml index 9bcd852cdb1..890ccfa0ddd 100644 --- a/.github/workflows/build_linux_jax_wheels.yml +++ b/.github/workflows/build_linux_jax_wheels.yml @@ -93,7 +93,7 @@ jobs: strategy: matrix: jax_ref: [rocm-jaxlib-v0.8.0] - name: Build Linux JAX Wheels | ${{ inputs.amdgpu_family }} | Python ${{ inputs.python_versions }} + name: Build Linux JAX Wheels | ${{ inputs.amdgpu_family }} runs-on: ${{ github.repository_owner == 'ROCm' && 'azure-linux-scale-rocm' || 'ubuntu-24.04' }} env: PACKAGE_DIST_DIR: ${{ github.workspace }}/jax/jax_rocm_plugin/wheelhouse From e1d7d04d2ae729b0f0a553d90c4b82bedf3a2f63 Mon Sep 17 00:00:00 2001 From: kithumma Date: Wed, 26 Nov 2025 09:58:24 +0000 Subject: [PATCH 26/40] update --- .github/workflows/build_linux_jax_wheels.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build_linux_jax_wheels.yml b/.github/workflows/build_linux_jax_wheels.yml index 890ccfa0ddd..3581ab98310 100644 --- a/.github/workflows/build_linux_jax_wheels.yml +++ b/.github/workflows/build_linux_jax_wheels.yml @@ -93,7 +93,7 @@ jobs: strategy: matrix: jax_ref: [rocm-jaxlib-v0.8.0] - name: Build Linux JAX Wheels | ${{ inputs.amdgpu_family }} + name: Build Linux JAX Wheels | ${{ inputs.amdgpu_family }} | Python ${{ inputs.python_version }} runs-on: ${{ github.repository_owner == 'ROCm' && 'azure-linux-scale-rocm' || 'ubuntu-24.04' }} env: PACKAGE_DIST_DIR: ${{ github.workspace }}/jax/jax_rocm_plugin/wheelhouse From 3a4f914da27fe21227e3fd0b0efb13e01fb155cb Mon Sep 17 00:00:00 2001 From: kithumma Date: Wed, 3 Dec 2025 00:48:12 +0000 Subject: [PATCH 27/40] fixing merge conflicts --- .github/workflows/test_jax_dockerfile.yml | 23 +++++++++++++++++++++ .github/workflows/test_linux_jax_wheels.yml | 10 ++++----- 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/.github/workflows/test_jax_dockerfile.yml b/.github/workflows/test_jax_dockerfile.yml index 866f9119d5c..edbaa84381f 100644 --- a/.github/workflows/test_jax_dockerfile.yml +++ b/.github/workflows/test_jax_dockerfile.yml @@ -6,7 +6,11 @@ on: test_runs_on: required: true type: string +<<<<<<< HEAD default: "linux-mi325-1gpu-ossci-rocm-frac" +======= + default: "linux-mi325-1gpu-ossci-rocm" +>>>>>>> 541e0fc1 (update) image_name: required: true description: JAX docker image to run tests with @@ -16,6 +20,10 @@ on: required: false type: string jax_plugin_branch: +<<<<<<< HEAD +======= + required: true +>>>>>>> 541e0fc1 (update) description: JAX plugin branch to checkout type: string default: "rocm-jaxlib-v0.6.0" @@ -34,20 +42,35 @@ on: required: false type: string jax_plugin_branch: +<<<<<<< HEAD description: JAX plugin branch to checkout to use for test scripts type: string default: "rocm-jaxlib-v0.8.0" +======= + required: true + description: JAX plugin branch to checkout to use for test scripts + type: string + default: "rocm-jaxlib-v0.6.0" +>>>>>>> 541e0fc1 (update) permissions: contents: read jobs: test_wheels: +<<<<<<< HEAD name: Test runs-on: ${{ inputs.test_runs_on }} steps: - name: Checkout uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 +======= + name: Test | ${{ inputs.amdgpu_family }} + runs-on: ${{ inputs.test_runs_on }} + steps: + - name: Checkout + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 +>>>>>>> 541e0fc1 (update) with: repo: rocm/rocm-jax # TODO: Add steps for creating the JAX docker image with an install of TheRock and then running JAX tests on the container diff --git a/.github/workflows/test_linux_jax_wheels.yml b/.github/workflows/test_linux_jax_wheels.yml index 2ec47c27347..1645fff689b 100644 --- a/.github/workflows/test_linux_jax_wheels.yml +++ b/.github/workflows/test_linux_jax_wheels.yml @@ -83,7 +83,6 @@ on: description: rocm-jax repository ref/branch to check out required: false type: string - default: rocm-jaxlib-v0.8.0 test_runs_on: description: Runner label to use. The selected runner should have a GPU supported by amdgpu_family required: true @@ -106,6 +105,7 @@ jobs: --group-add render --group-add video --user root + --env-file /etc/podinfo/gha-gpu-isolation-settings defaults: run: shell: bash @@ -189,7 +189,7 @@ jobs: - name: Run JAX tests run: | - pytest jax/jax_tests/tests/multi_device_test.py -q - pytest jax/jax_tests/tests/core_test.py -q - pytest jax/jax_tests/tests/util_test.py -q - pytest jax/jax_tests/tests/scipy_stats_test.py -q + pytest jax/jax_tests/tests/multi_device_test.py -q --log-cli-level=INFO + pytest jax/jax_tests/tests/core_test.py -q --log-cli-level=INFO + pytest jax/jax_tests/tests/util_test.py -q --log-cli-level=INFO + pytest jax/jax_tests/tests/scipy_stats_test.py -q --log-cli-level=INFO From 744f1e3af88c81c12db98057a948132774c4e909 Mon Sep 17 00:00:00 2001 From: kithumma Date: Wed, 26 Nov 2025 10:39:22 +0000 Subject: [PATCH 28/40] update --- .github/workflows/test_linux_jax_wheels.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/test_linux_jax_wheels.yml b/.github/workflows/test_linux_jax_wheels.yml index 1645fff689b..8b50bd2e72e 100644 --- a/.github/workflows/test_linux_jax_wheels.yml +++ b/.github/workflows/test_linux_jax_wheels.yml @@ -32,7 +32,6 @@ on: description: rocm-jax repository ref/branch to check out required: false type: string - default: rocm-jaxlib-v0.8.0 ref: description: "Branch, tag or SHA to checkout. Defaults to the reference or SHA that triggered the workflow." type: string From 50cf5b9ced32f546d8d67551ce8b73e3a4263dcf Mon Sep 17 00:00:00 2001 From: kithumma Date: Wed, 26 Nov 2025 10:48:48 +0000 Subject: [PATCH 29/40] remove pip upgrade and add in python setup --- .github/workflows/test_linux_jax_wheels.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/test_linux_jax_wheels.yml b/.github/workflows/test_linux_jax_wheels.yml index 8b50bd2e72e..ef36fc0fc25 100644 --- a/.github/workflows/test_linux_jax_wheels.yml +++ b/.github/workflows/test_linux_jax_wheels.yml @@ -159,7 +159,6 @@ jobs: ROCM_VERSION: ${{ inputs.rocm_version }} AMDGPU_FAMILY: ${{ inputs.amdgpu_family }} run: | - python -m pip install -q --upgrade pip DEST="/opt/rocm-${{ inputs.rocm_version }}" # Install directly from TheRock release buckets (nightly/dev) using the provided version python build_tools/install_rocm_from_artifacts.py \ From 5e820c97c62b18ed145f9e15699e9032d619140d Mon Sep 17 00:00:00 2001 From: kithumma Date: Wed, 26 Nov 2025 10:50:27 +0000 Subject: [PATCH 30/40] remove symlinks --- .github/workflows/test_linux_jax_wheels.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.github/workflows/test_linux_jax_wheels.yml b/.github/workflows/test_linux_jax_wheels.yml index ef36fc0fc25..3b421bfc5c1 100644 --- a/.github/workflows/test_linux_jax_wheels.yml +++ b/.github/workflows/test_linux_jax_wheels.yml @@ -165,9 +165,6 @@ jobs: --release "${{ inputs.rocm_version }}" \ --artifact-group "${{ inputs.amdgpu_family }}" \ --output-dir "${DEST}" - # Create standard symlinks - ln -sfn "${DEST}" /opt/rocm - ln -sfn /opt/rocm /etc/alternatives/rocm - name: Extract JAX version and set to GITHUB_ENV run: | From 9987b725eafa6e16e17830f55787648b9a1c5778 Mon Sep 17 00:00:00 2001 From: kithumma Date: Fri, 28 Nov 2025 19:04:13 +0000 Subject: [PATCH 31/40] udpates --- .github/workflows/build_linux_jax_wheels.yml | 4 ++-- .github/workflows/test_linux_jax_wheels.yml | 16 +++++++++------- external-builds/jax/requirements-jax.txt | 2 +- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/.github/workflows/build_linux_jax_wheels.yml b/.github/workflows/build_linux_jax_wheels.yml index 3581ab98310..c358b60341e 100644 --- a/.github/workflows/build_linux_jax_wheels.yml +++ b/.github/workflows/build_linux_jax_wheels.yml @@ -32,7 +32,7 @@ on: required: true type: string repository: - description: "Repository to checkout. Otherwise, defaults to `github.repository`." + description: "Repository to checkout. Defaults to `ROCm/TheRock`." type: string ref: description: "Branch, tag or SHA to checkout. Defaults to the reference or SHA that triggered the workflow." @@ -156,7 +156,7 @@ jobs: source .venv/bin/activate pip3 install boto3 packaging python3 ./build_tools/third_party/s3_management/manage.py ${{ inputs.s3_subdir }}/${{ inputs.amdgpu_family }} - + generate_target_to_run: name: Generate target_to_run runs-on: ubuntu-24.04 diff --git a/.github/workflows/test_linux_jax_wheels.yml b/.github/workflows/test_linux_jax_wheels.yml index 3b421bfc5c1..edf5b437e9b 100644 --- a/.github/workflows/test_linux_jax_wheels.yml +++ b/.github/workflows/test_linux_jax_wheels.yml @@ -86,7 +86,6 @@ on: description: Runner label to use. The selected runner should have a GPU supported by amdgpu_family required: true type: string - default: "linux-mi325-1gpu-ossci-rocm-frac" permissions: contents: read @@ -97,7 +96,7 @@ jobs: name: Test JAX Wheels | ${{ inputs.amdgpu_family }} runs-on: ${{ inputs.test_runs_on }} container: - image: ${{ contains(inputs.test_runs_on, 'linux') && 'ghcr.io/rocm/no_rocm_image_ubuntu24_04@sha256:405945a40deaff9db90b9839c0f41d4cba4a383c1a7459b28627047bf6302a26' || null }} + image: ghcr.io/rocm/no_rocm_image_ubuntu24_04@sha256:405945a40deaff9db90b9839c0f41d4cba4a383c1a7459b28627047bf6302a26 options: >- --device /dev/kfd --device /dev/dri @@ -135,7 +134,7 @@ jobs: repository: rocm/jax ref: ${{ inputs.jax_ref }} path: jax/jax_tests - + - name: Set up Python uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 with: @@ -144,10 +143,9 @@ jobs: - name: System deps, venv configure run: | - python3 -m venv "${VIRTUAL_ENV}" - echo "PATH=${VIRTUAL_ENV}/bin:${PATH}" >> "$GITHUB_ENV" - "${VIRTUAL_ENV}/bin/python" -m pip install --upgrade pip setuptools wheel - python3 build_tools/setup_venv.py "${VIRTUAL_ENV}" --activate-in-future-github-actions-steps + python3 -m venv "${VIRTUAL_ENV}" + echo "PATH=${VIRTUAL_ENV}/bin:${PATH}" >> "$GITHUB_ENV" + python3 build_tools/setup_venv.py "${VIRTUAL_ENV}" --activate-in-future-github-actions-steps - name: Install base JAX test requirements run: | @@ -169,6 +167,10 @@ jobs: - name: Extract JAX version and set to GITHUB_ENV run: | # Extract JAX version from requirements.txt (e.g., "jax==0.8.0") + # Remove all whitespace from requirements.txt to simplify parsing + # Search for lines starting with "jax==" or "jaxlib==" followed by version (excluding comments) + # Extract the version number by splitting on '=' and taking the 3rd field + # [^#]+ matches one or more characters that are NOT '#', ensuring we stop before any inline comments JAX_VERSION=$(tr -d ' ' < jax/build/requirements.txt \ | grep -E '^(jax|jaxlib)==[^#]+' | head -n1 | cut -d'=' -f3) echo "JAX_VERSION=$JAX_VERSION" >> "$GITHUB_ENV" diff --git a/external-builds/jax/requirements-jax.txt b/external-builds/jax/requirements-jax.txt index 065666382a7..955265473e8 100644 --- a/external-builds/jax/requirements-jax.txt +++ b/external-builds/jax/requirements-jax.txt @@ -1,4 +1,4 @@ -boto3>=1.41.4 +boto3==1.41.4 numpy<2 build>=1.3.0 wheel>=0.45.1 From 9c15b8c8068390e7cffcbf1d1f185ef7bcfc6b03 Mon Sep 17 00:00:00 2001 From: kithumma Date: Mon, 1 Dec 2025 20:28:31 +0000 Subject: [PATCH 32/40] updated version to exat version --- .github/workflows/test_linux_jax_wheels.yml | 5 ++- external-builds/jax/requirements-jax.txt | 36 ++++++++++----------- 2 files changed, 22 insertions(+), 19 deletions(-) diff --git a/.github/workflows/test_linux_jax_wheels.yml b/.github/workflows/test_linux_jax_wheels.yml index edf5b437e9b..67208afa99f 100644 --- a/.github/workflows/test_linux_jax_wheels.yml +++ b/.github/workflows/test_linux_jax_wheels.yml @@ -86,6 +86,9 @@ on: description: Runner label to use. The selected runner should have a GPU supported by amdgpu_family required: true type: string + ref: + description: "Branch, tag or SHA to checkout. Defaults to the reference or SHA that triggered the workflow." + type: string permissions: contents: read @@ -128,7 +131,7 @@ jobs: repository: rocm/rocm-jax ref: ${{ inputs.jax_ref }} - - name: Checkout JAX tests repo (for extended tests) + - name: Checkout JAX extended tests repo uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: repository: rocm/jax diff --git a/external-builds/jax/requirements-jax.txt b/external-builds/jax/requirements-jax.txt index 955265473e8..9697d9d9237 100644 --- a/external-builds/jax/requirements-jax.txt +++ b/external-builds/jax/requirements-jax.txt @@ -1,20 +1,20 @@ boto3==1.41.4 numpy<2 -build>=1.3.0 -wheel>=0.45.1 -six>=1.17.0 -auditwheel>=6.5.0 -scipy>=1.13 -pytest>=9.0.1 -pytest-html>=4.1.1 -pytest_html_merger>=0.1.0 -pytest-reportlog>=1.0.0 -pytest-rerunfailures>=16.1 -pytest-json-report>=1.5.0 -cloudpickle>=3.1.2 -portpicker>=1.6.0 -matplotlib>=3.10.7 -absl-py>=2.3.1 -flatbuffers>=25.9.23 -hypothesis>=6.148.2 -ml_dtypes>=0.5.0 +build==1.3.0 +wheel==0.45.1 +six==1.17.0 +auditwheel==6.5.0 +scipy==1.13 +pytest==9.0.1 +pytest-html==4.1.1 +pytest_html_merger==0.1.0 +pytest-reportlog==1.0.0 +pytest-rerunfailures==16.1 +pytest-json-report==1.5.0 +cloudpickle==3.1.2 +portpicker==1.6.0 +matplotlib==3.10.7 +absl-py==2.3.1 +flatbuffers==25.9.23 +hypothesis==6.148.2 +ml_dtypes==0.5.0 From 41c4d9b05d5d429c1522819397e1f28b9825aeb2 Mon Sep 17 00:00:00 2001 From: kithumma Date: Mon, 1 Dec 2025 22:22:32 +0000 Subject: [PATCH 33/40] updated numpy version --- external-builds/jax/requirements-jax.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/external-builds/jax/requirements-jax.txt b/external-builds/jax/requirements-jax.txt index 9697d9d9237..6313249d8cb 100644 --- a/external-builds/jax/requirements-jax.txt +++ b/external-builds/jax/requirements-jax.txt @@ -1,5 +1,5 @@ boto3==1.41.4 -numpy<2 +numpy==1.26.4 build==1.3.0 wheel==0.45.1 six==1.17.0 From 537526c409adc2fb2e822afe525fba51d0d76eee Mon Sep 17 00:00:00 2001 From: kithumma Date: Wed, 3 Dec 2025 00:49:23 +0000 Subject: [PATCH 34/40] fixing merge conflicts --- .github/workflows/build_linux_jax_wheels.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/build_linux_jax_wheels.yml b/.github/workflows/build_linux_jax_wheels.yml index c358b60341e..cc8f20609f0 100644 --- a/.github/workflows/build_linux_jax_wheels.yml +++ b/.github/workflows/build_linux_jax_wheels.yml @@ -34,6 +34,7 @@ on: repository: description: "Repository to checkout. Defaults to `ROCm/TheRock`." type: string + default: "ROCm/TheRock" ref: description: "Branch, tag or SHA to checkout. Defaults to the reference or SHA that triggered the workflow." type: string @@ -192,6 +193,7 @@ jobs: rocm_version: ${{ inputs.rocm_version }} tar_url: ${{ inputs.tar_url }} python_version: ${{ inputs.python_version }} + repository: ${{ inputs.repository || github.repository }} ref: ${{ inputs.ref || '' }} jax_ref: ${{ inputs.jax_ref }} test_runs_on: ${{ needs.generate_target_to_run.outputs.test_runs_on }} From 0dc456642c6d75657767503bd233cd9aaef1b73e Mon Sep 17 00:00:00 2001 From: kithumma Date: Wed, 3 Dec 2025 01:24:24 +0000 Subject: [PATCH 35/40] update python_version parameter --- .github/workflows/test_linux_jax_wheels.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test_linux_jax_wheels.yml b/.github/workflows/test_linux_jax_wheels.yml index 67208afa99f..9ea8f820613 100644 --- a/.github/workflows/test_linux_jax_wheels.yml +++ b/.github/workflows/test_linux_jax_wheels.yml @@ -24,7 +24,7 @@ on: description: URL to TheRock tarball to configure ROCm required: true type: string - python_versions: + python_version: description: Python version(s) to test (e.g., "3.12") required: true type: string @@ -73,7 +73,7 @@ on: description: URL to TheRock tarball to configure ROCm required: true type: string - python_versions: + python_version: description: Python version(s) to test (e.g., "3.12") required: true type: string @@ -114,7 +114,7 @@ jobs: VIRTUAL_ENV: ${{ github.workspace }}/.venv AMDGPU_FAMILY: ${{ inputs.amdgpu_family }} THEROCK_TAR_URL: ${{ inputs.tar_url }} - PYTHON_VERSION: ${{ inputs.python_versions }} + PYTHON_VERSION: ${{ inputs.python_version }} WHEEL_INDEX_URL: ${{ inputs.package_index_url }}/${{ inputs.amdgpu_family }} steps: @@ -141,7 +141,7 @@ jobs: - name: Set up Python uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 with: - python-version: ${{ inputs.python_versions }} + python-version: ${{ inputs.python_version }} check-latest: true - name: System deps, venv configure From 5d5f40546e8594c4d4ff68eecc2c8bc8206d3e07 Mon Sep 17 00:00:00 2001 From: kithumma Date: Wed, 3 Dec 2025 06:38:46 +0000 Subject: [PATCH 36/40] update repository workflow_call --- .github/workflows/test_linux_jax_wheels.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/test_linux_jax_wheels.yml b/.github/workflows/test_linux_jax_wheels.yml index 9ea8f820613..fa9518a2227 100644 --- a/.github/workflows/test_linux_jax_wheels.yml +++ b/.github/workflows/test_linux_jax_wheels.yml @@ -28,6 +28,9 @@ on: description: Python version(s) to test (e.g., "3.12") required: true type: string + repository: + description: "Repository to checkout. Otherwise, defaults to `github.repository`." + type: string jax_ref: description: rocm-jax repository ref/branch to check out required: false From ce3e0e9519382f22411710706f4c104ec3d1aa82 Mon Sep 17 00:00:00 2001 From: kithumma Date: Wed, 3 Dec 2025 07:53:00 +0000 Subject: [PATCH 37/40] udpate version scipy, ml_dtypes, numpy --- external-builds/jax/requirements-jax.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/external-builds/jax/requirements-jax.txt b/external-builds/jax/requirements-jax.txt index 6313249d8cb..711dbd82757 100644 --- a/external-builds/jax/requirements-jax.txt +++ b/external-builds/jax/requirements-jax.txt @@ -1,10 +1,10 @@ boto3==1.41.4 -numpy==1.26.4 +numpy>=2.0.0 build==1.3.0 wheel==0.45.1 six==1.17.0 auditwheel==6.5.0 -scipy==1.13 +scipy>=1.13 pytest==9.0.1 pytest-html==4.1.1 pytest_html_merger==0.1.0 @@ -17,4 +17,4 @@ matplotlib==3.10.7 absl-py==2.3.1 flatbuffers==25.9.23 hypothesis==6.148.2 -ml_dtypes==0.5.0 +ml_dtypes>=0.5.0 From 8fa37f856cc8c40cfe8b457d77485f1eb274ba8a Mon Sep 17 00:00:00 2001 From: kithumma Date: Wed, 3 Dec 2025 08:10:56 +0000 Subject: [PATCH 38/40] update amdgpu_family --- .github/workflows/test_linux_jax_wheels.yml | 6 +++++- external-builds/jax/requirements-jax.txt | 6 +++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test_linux_jax_wheels.yml b/.github/workflows/test_linux_jax_wheels.yml index fa9518a2227..549580f8c59 100644 --- a/.github/workflows/test_linux_jax_wheels.yml +++ b/.github/workflows/test_linux_jax_wheels.yml @@ -47,9 +47,13 @@ on: amdgpu_family: type: choice options: - - gfx110X-dgpu + - gfx101X-dgpu + - gfx103X-dgpu + - gfx110X-all + - gfx1150 - gfx1151 - gfx120X-all + - gfx90X-dcgpu - gfx94X-dcgpu - gfx950-dcgpu default: gfx94X-dcgpu diff --git a/external-builds/jax/requirements-jax.txt b/external-builds/jax/requirements-jax.txt index 711dbd82757..59774301fca 100644 --- a/external-builds/jax/requirements-jax.txt +++ b/external-builds/jax/requirements-jax.txt @@ -1,10 +1,10 @@ boto3==1.41.4 -numpy>=2.0.0 +numpy==2.3.5 build==1.3.0 wheel==0.45.1 six==1.17.0 auditwheel==6.5.0 -scipy>=1.13 +scipy==1.16.3 pytest==9.0.1 pytest-html==4.1.1 pytest_html_merger==0.1.0 @@ -17,4 +17,4 @@ matplotlib==3.10.7 absl-py==2.3.1 flatbuffers==25.9.23 hypothesis==6.148.2 -ml_dtypes>=0.5.0 +ml_dtypes==0.5.4 From df44c8c77e1dcb7f2bd2984b856c24b04fc2b155 Mon Sep 17 00:00:00 2001 From: kithumma Date: Wed, 3 Dec 2025 20:21:59 +0000 Subject: [PATCH 39/40] update merge conflicts --- .github/workflows/test_jax_dockerfile.yml | 22 --------------------- .github/workflows/test_linux_jax_wheels.yml | 2 +- 2 files changed, 1 insertion(+), 23 deletions(-) diff --git a/.github/workflows/test_jax_dockerfile.yml b/.github/workflows/test_jax_dockerfile.yml index edbaa84381f..716e47293bb 100644 --- a/.github/workflows/test_jax_dockerfile.yml +++ b/.github/workflows/test_jax_dockerfile.yml @@ -6,11 +6,7 @@ on: test_runs_on: required: true type: string -<<<<<<< HEAD - default: "linux-mi325-1gpu-ossci-rocm-frac" -======= default: "linux-mi325-1gpu-ossci-rocm" ->>>>>>> 541e0fc1 (update) image_name: required: true description: JAX docker image to run tests with @@ -20,10 +16,7 @@ on: required: false type: string jax_plugin_branch: -<<<<<<< HEAD -======= required: true ->>>>>>> 541e0fc1 (update) description: JAX plugin branch to checkout type: string default: "rocm-jaxlib-v0.6.0" @@ -42,35 +35,20 @@ on: required: false type: string jax_plugin_branch: -<<<<<<< HEAD description: JAX plugin branch to checkout to use for test scripts type: string default: "rocm-jaxlib-v0.8.0" -======= - required: true - description: JAX plugin branch to checkout to use for test scripts - type: string - default: "rocm-jaxlib-v0.6.0" ->>>>>>> 541e0fc1 (update) permissions: contents: read jobs: test_wheels: -<<<<<<< HEAD name: Test runs-on: ${{ inputs.test_runs_on }} steps: - name: Checkout uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 -======= - name: Test | ${{ inputs.amdgpu_family }} - runs-on: ${{ inputs.test_runs_on }} - steps: - - name: Checkout - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 ->>>>>>> 541e0fc1 (update) with: repo: rocm/rocm-jax # TODO: Add steps for creating the JAX docker image with an install of TheRock and then running JAX tests on the container diff --git a/.github/workflows/test_linux_jax_wheels.yml b/.github/workflows/test_linux_jax_wheels.yml index 549580f8c59..36564d7c0dc 100644 --- a/.github/workflows/test_linux_jax_wheels.yml +++ b/.github/workflows/test_linux_jax_wheels.yml @@ -138,7 +138,7 @@ jobs: repository: rocm/rocm-jax ref: ${{ inputs.jax_ref }} - - name: Checkout JAX extended tests repo + - name: Checkout JAX extended tests repo uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: repository: rocm/jax From 814582bbcbc806ae2d60777951693f676006df1b Mon Sep 17 00:00:00 2001 From: kithumma Date: Thu, 4 Dec 2025 17:16:46 +0000 Subject: [PATCH 40/40] updates as per comments --- .github/workflows/build_linux_jax_wheels.yml | 2 +- .github/workflows/test_jax_dockerfile.yml | 2 +- .github/workflows/test_linux_jax_wheels.yml | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build_linux_jax_wheels.yml b/.github/workflows/build_linux_jax_wheels.yml index cc8f20609f0..ee9442df32c 100644 --- a/.github/workflows/build_linux_jax_wheels.yml +++ b/.github/workflows/build_linux_jax_wheels.yml @@ -108,7 +108,7 @@ jobs: with: path: jax repository: rocm/rocm-jax - ref: ${{ inputs.jax_ref || matrix.jax_ref }} + ref: ${{ matrix.jax_ref }} - name: Configure Git Identity run: | diff --git a/.github/workflows/test_jax_dockerfile.yml b/.github/workflows/test_jax_dockerfile.yml index 716e47293bb..d459224343f 100644 --- a/.github/workflows/test_jax_dockerfile.yml +++ b/.github/workflows/test_jax_dockerfile.yml @@ -6,7 +6,7 @@ on: test_runs_on: required: true type: string - default: "linux-mi325-1gpu-ossci-rocm" + default: "linux-mi325-1gpu-ossci-rocm-frac" image_name: required: true description: JAX docker image to run tests with diff --git a/.github/workflows/test_linux_jax_wheels.yml b/.github/workflows/test_linux_jax_wheels.yml index 36564d7c0dc..35fd1a517bd 100644 --- a/.github/workflows/test_linux_jax_wheels.yml +++ b/.github/workflows/test_linux_jax_wheels.yml @@ -93,6 +93,7 @@ on: description: Runner label to use. The selected runner should have a GPU supported by amdgpu_family required: true type: string + default: "linux-mi325-1gpu-ossci-rocm-frac" ref: description: "Branch, tag or SHA to checkout. Defaults to the reference or SHA that triggered the workflow." type: string