diff --git a/.github/workflows/build_jax_dockerfile.yml b/.github/workflows/build_jax_dockerfile.yml deleted file mode 100644 index fd9e078431f..00000000000 --- a/.github/workflows/build_jax_dockerfile.yml +++ /dev/null @@ -1,79 +0,0 @@ -name: Build Linux JAX Docker Images - -on: - workflow_call: - inputs: - amdgpu_family: - required: true - type: string - python_version: - required: true - type: string - release_type: - description: The type of release to build ("nightly", "prerelease", or "dev") - required: true - type: string - s3_subdir: - description: S3 subdirectory, not including the GPU-family - required: true - type: string - rocm_version: - description: ROCm version to install - type: string - tar_url: - description: URL to TheRock tarball to build against - type: string - workflow_dispatch: - inputs: - amdgpu_family: - required: true - type: string - python_version: - required: true - type: string - release_type: - description: The type of release to build ("nightly", "prerelease", or "dev") - required: true - type: string - default: "dev" - s3_subdir: - description: S3 subdirectory, not including the GPU-family - type: string - default: "v2" - rocm_version: - description: ROCm version to install - type: string - tar_url: - description: URL to TheRock tarball to build against - type: string - -permissions: - id-token: write - contents: read - -jobs: - build_jax_wheels: - strategy: - matrix: - jax_ref: [master] - name: Build | ${{ inputs.amdgpu_family }} | Python ${{ inputs.python_version }} - runs-on: ${{ github.repository_owner == 'ROCm' && 'azure-linux-scale-rocm' || 'ubuntu-24.04' }} - env: - PACKAGE_DIST_DIR: ${{ github.workspace }}/jax_rocm_plugin/wheelhouse - S3_BUCKET_PY: "therock-${{ inputs.release_type }}-python" - steps: - - name: Checkout TheRock - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 - - - name: Checkout JAX - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 - with: - path: jax - repository: rocm/rocm-jax - ref: ${{ matrix.jax_ref }} - - - name: Configure Git Identity - run: | - git config --global user.name "therockbot" - git config --global user.email "therockbot@amd.com" - # TODO: Pull down JAX plugin wheels into the wheelhouse directory and run the image build script from rocm-jax diff --git a/.github/workflows/build_linux_jax_wheels.yml b/.github/workflows/build_linux_jax_wheels.yml index 6d044a2a578..ee9442df32c 100644 --- a/.github/workflows/build_linux_jax_wheels.yml +++ b/.github/workflows/build_linux_jax_wheels.yml @@ -23,6 +23,21 @@ on: tar_url: description: URL to TheRock tarball to build against type: string + cloudfront_url: + description: CloudFront URL pointing to Python index + required: true + type: string + cloudfront_staging_url: + description: CloudFront base URL pointing to staging Python index + required: true + type: string + repository: + description: "Repository to checkout. Defaults to `ROCm/TheRock`." + type: string + default: "ROCm/TheRock" + ref: + description: "Branch, tag or SHA to checkout. Defaults to the reference or SHA that triggered the workflow." + type: string workflow_dispatch: inputs: amdgpu_family: @@ -57,6 +72,18 @@ on: tar_url: description: URL to TheRock tarball to build against type: string + cloudfront_url: + description: CloudFront base URL pointing to Python index + type: string + default: "https://d25kgig7rdsyks.cloudfront.net/v2" + cloudfront_staging_url: + description: CloudFront base URL pointing to staging Python index + type: string + default: "https://d25kgig7rdsyks.cloudfront.net/v2-staging" + jax_ref: + description: rocm-jax repository ref/branch to check out + type: string + default: rocm-jaxlib-v0.8.0 permissions: id-token: write @@ -130,3 +157,43 @@ jobs: source .venv/bin/activate pip3 install boto3 packaging python3 ./build_tools/third_party/s3_management/manage.py ${{ inputs.s3_subdir }}/${{ inputs.amdgpu_family }} + + generate_target_to_run: + name: Generate target_to_run + runs-on: ubuntu-24.04 + outputs: + test_runs_on: ${{ steps.configure.outputs.test-runs-on }} + bypass_tests_for_releases: ${{ steps.configure.outputs.bypass_tests_for_releases }} + steps: + - name: Checking out repository + uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 + with: + repository: ${{ inputs.repository || github.repository }} + ref: ${{ inputs.ref || '' }} + + - name: Generating target to run + id: configure + env: + TARGET: ${{ inputs.amdgpu_family }} + PLATFORM: "linux" + run: python ./build_tools/github_actions/configure_target_run.py + + test_jax_wheels: + name: Test JAX wheels | ${{ inputs.amdgpu_family }} | ${{ needs.generate_target_to_run.outputs.test_runs_on }} + needs: [build_jax_wheels, generate_target_to_run] + permissions: + contents: read + packages: read + uses: ./.github/workflows/test_linux_jax_wheels.yml + with: + amdgpu_family: ${{ inputs.amdgpu_family }} + release_type: ${{ inputs.release_type }} + s3_subdir: ${{ inputs.s3_subdir }} + package_index_url: ${{ inputs.cloudfront_url }} + rocm_version: ${{ inputs.rocm_version }} + tar_url: ${{ inputs.tar_url }} + python_version: ${{ inputs.python_version }} + repository: ${{ inputs.repository || github.repository }} + ref: ${{ inputs.ref || '' }} + jax_ref: ${{ inputs.jax_ref }} + test_runs_on: ${{ needs.generate_target_to_run.outputs.test_runs_on }} diff --git a/.github/workflows/test_jax_dockerfile.yml b/.github/workflows/test_jax_dockerfile.yml index c63b82ca5e5..d459224343f 100644 --- a/.github/workflows/test_jax_dockerfile.yml +++ b/.github/workflows/test_jax_dockerfile.yml @@ -16,6 +16,7 @@ on: required: false type: string jax_plugin_branch: + required: true description: JAX plugin branch to checkout type: string default: "rocm-jaxlib-v0.6.0" @@ -36,7 +37,7 @@ on: jax_plugin_branch: description: JAX plugin branch to checkout to use for test scripts type: string - default: "rocm-jaxlib-v0.6.0" + default: "rocm-jaxlib-v0.8.0" permissions: contents: read diff --git a/.github/workflows/test_linux_jax_wheels.yml b/.github/workflows/test_linux_jax_wheels.yml new file mode 100644 index 00000000000..35fd1a517bd --- /dev/null +++ b/.github/workflows/test_linux_jax_wheels.yml @@ -0,0 +1,203 @@ +name: Test Linux JAX Wheels + +on: + workflow_call: + inputs: + amdgpu_family: + required: true + type: string + release_type: + required: true + type: string + s3_subdir: + required: true + type: string + package_index_url: + description: Base CloudFront URL for the Python package index + required: true + type: string + rocm_version: + description: ROCm version (optional, informational) + required: false + type: string + tar_url: + description: URL to TheRock tarball to configure ROCm + required: true + type: string + python_version: + description: Python version(s) to test (e.g., "3.12") + required: true + type: string + repository: + description: "Repository to checkout. Otherwise, defaults to `github.repository`." + type: string + jax_ref: + description: rocm-jax repository ref/branch to check out + required: false + type: string + ref: + description: "Branch, tag or SHA to checkout. Defaults to the reference or SHA that triggered the workflow." + type: string + test_runs_on: + required: true + type: string + + workflow_dispatch: + inputs: + amdgpu_family: + type: choice + options: + - gfx101X-dgpu + - gfx103X-dgpu + - gfx110X-all + - gfx1150 + - gfx1151 + - gfx120X-all + - gfx90X-dcgpu + - gfx94X-dcgpu + - gfx950-dcgpu + default: gfx94X-dcgpu + release_type: + description: The type of release ("nightly" or "dev") + required: true + type: string + default: dev + s3_subdir: + description: S3 subdirectory, not including the GPU-family + required: true + type: string + default: v2 + package_index_url: + description: Base CloudFront URL for the Python package index + required: true + type: string + default: https://rocm.nightlies.amd.com/v2-staging/ + rocm_version: + description: ROCm version + required: false + type: string + tar_url: + description: URL to TheRock tarball to configure ROCm + required: true + type: string + python_version: + description: Python version(s) to test (e.g., "3.12") + required: true + type: string + default: "3.12" + jax_ref: + description: rocm-jax repository ref/branch to check out + required: false + type: string + test_runs_on: + description: Runner label to use. The selected runner should have a GPU supported by amdgpu_family + required: true + type: string + default: "linux-mi325-1gpu-ossci-rocm-frac" + ref: + description: "Branch, tag or SHA to checkout. Defaults to the reference or SHA that triggered the workflow." + type: string + +permissions: + contents: read + packages: read + +jobs: + test_jax_wheels: + name: Test JAX Wheels | ${{ inputs.amdgpu_family }} + runs-on: ${{ inputs.test_runs_on }} + container: + image: ghcr.io/rocm/no_rocm_image_ubuntu24_04@sha256:405945a40deaff9db90b9839c0f41d4cba4a383c1a7459b28627047bf6302a26 + options: >- + --device /dev/kfd + --device /dev/dri + --group-add render + --group-add video + --user root + --env-file /etc/podinfo/gha-gpu-isolation-settings + defaults: + run: + shell: bash + env: + VIRTUAL_ENV: ${{ github.workspace }}/.venv + AMDGPU_FAMILY: ${{ inputs.amdgpu_family }} + THEROCK_TAR_URL: ${{ inputs.tar_url }} + PYTHON_VERSION: ${{ inputs.python_version }} + WHEEL_INDEX_URL: ${{ inputs.package_index_url }}/${{ inputs.amdgpu_family }} + + steps: + - name: Checkout + uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 + with: + repository: ${{ inputs.repository || github.repository }} + ref: ${{ inputs.ref || '' }} + + - name: Checkout rocm-jax (plugin + build scripts) + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + with: + path: jax + repository: rocm/rocm-jax + ref: ${{ inputs.jax_ref }} + + - name: Checkout JAX extended tests repo + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + with: + repository: rocm/jax + ref: ${{ inputs.jax_ref }} + path: jax/jax_tests + + - name: Set up Python + uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 + with: + python-version: ${{ inputs.python_version }} + check-latest: true + + - name: System deps, venv configure + run: | + python3 -m venv "${VIRTUAL_ENV}" + echo "PATH=${VIRTUAL_ENV}/bin:${PATH}" >> "$GITHUB_ENV" + python3 build_tools/setup_venv.py "${VIRTUAL_ENV}" --activate-in-future-github-actions-steps + + - name: Install base JAX test requirements + run: | + # This script sets up the venv and activates it across steps; keep it consistent + pip install -r external-builds/jax/requirements-jax.txt + + - name: Configure ROCm from TheRock tarball + env: + ROCM_VERSION: ${{ inputs.rocm_version }} + AMDGPU_FAMILY: ${{ inputs.amdgpu_family }} + run: | + DEST="/opt/rocm-${{ inputs.rocm_version }}" + # Install directly from TheRock release buckets (nightly/dev) using the provided version + python build_tools/install_rocm_from_artifacts.py \ + --release "${{ inputs.rocm_version }}" \ + --artifact-group "${{ inputs.amdgpu_family }}" \ + --output-dir "${DEST}" + + - name: Extract JAX version and set to GITHUB_ENV + run: | + # Extract JAX version from requirements.txt (e.g., "jax==0.8.0") + # Remove all whitespace from requirements.txt to simplify parsing + # Search for lines starting with "jax==" or "jaxlib==" followed by version (excluding comments) + # Extract the version number by splitting on '=' and taking the 3rd field + # [^#]+ matches one or more characters that are NOT '#', ensuring we stop before any inline comments + JAX_VERSION=$(tr -d ' ' < jax/build/requirements.txt \ + | grep -E '^(jax|jaxlib)==[^#]+' | head -n1 | cut -d'=' -f3) + echo "JAX_VERSION=$JAX_VERSION" >> "$GITHUB_ENV" + + - name: Install JAX wheels from package index + run: | + # Install jaxlib/plugin/pjrt from the GPU-family index; install jax from PyPI to match the version + pip install --index-url "${{ env.WHEEL_INDEX_URL }}" \ + "jaxlib==${JAX_VERSION}+rocm${{ inputs.rocm_version }}" \ + "jax-rocm7-plugin==${JAX_VERSION}+rocm${{ inputs.rocm_version }}" \ + "jax-rocm7-pjrt==${JAX_VERSION}+rocm${{ inputs.rocm_version }}" + pip install --extra-index-url https://pypi.org/simple "jax==${JAX_VERSION}" + + - name: Run JAX tests + run: | + pytest jax/jax_tests/tests/multi_device_test.py -q --log-cli-level=INFO + pytest jax/jax_tests/tests/core_test.py -q --log-cli-level=INFO + pytest jax/jax_tests/tests/util_test.py -q --log-cli-level=INFO + pytest jax/jax_tests/tests/scipy_stats_test.py -q --log-cli-level=INFO diff --git a/external-builds/jax/requirements-jax.txt b/external-builds/jax/requirements-jax.txt new file mode 100644 index 00000000000..59774301fca --- /dev/null +++ b/external-builds/jax/requirements-jax.txt @@ -0,0 +1,20 @@ +boto3==1.41.4 +numpy==2.3.5 +build==1.3.0 +wheel==0.45.1 +six==1.17.0 +auditwheel==6.5.0 +scipy==1.16.3 +pytest==9.0.1 +pytest-html==4.1.1 +pytest_html_merger==0.1.0 +pytest-reportlog==1.0.0 +pytest-rerunfailures==16.1 +pytest-json-report==1.5.0 +cloudpickle==3.1.2 +portpicker==1.6.0 +matplotlib==3.10.7 +absl-py==2.3.1 +flatbuffers==25.9.23 +hypothesis==6.148.2 +ml_dtypes==0.5.4