62 changes: 60 additions & 2 deletions .github/workflows/build_linux_jax_wheels.yml
@@ -23,6 +23,17 @@ on:
tar_url:
description: URL to TheRock tarball to build against
type: string
cloudfront_url:
description: CloudFront URL pointing to Python index
required: true
type: string
cloudfront_staging_url:
description: CloudFront base URL pointing to staging Python index
required: true
type: string
ref:
description: "Branch, tag or SHA to checkout. Defaults to the reference or SHA that triggered the workflow."
type: string
workflow_dispatch:
inputs:
amdgpu_family:
@@ -57,6 +68,21 @@ on:
tar_url:
description: URL to TheRock tarball to build against
type: string
cloudfront_url:
description: CloudFront URL pointing to Python index
required: true
type: string
cloudfront_staging_url:
description: CloudFront base URL pointing to staging Python index
required: true
type: string
ref:
description: "Branch, tag or SHA to checkout. Defaults to the reference or SHA that triggered the workflow."
type: string
jax_ref:
description: Branch/tag/sha of rocm/rocm-jax to build
type: string
default: "rocm-jaxlib-v0.7.1"

permissions:
id-token: write
@@ -72,16 +98,21 @@ jobs:
env:
PACKAGE_DIST_DIR: ${{ github.workspace }}/jax/jax_rocm_plugin/wheelhouse
S3_BUCKET_PY: "therock-${{ inputs.release_type }}-python"
outputs:
jax_whl_list: ${{ steps.jax_wheel_list.outputs.jax_whl_list }}
steps:
- name: Checkout TheRock
- name: Checkout
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
repository: ${{ inputs.repository || github.repository }}
ref: ${{ inputs.ref || '' }}

- name: Checkout JAX
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
path: jax
repository: rocm/rocm-jax
ref: ${{ matrix.jax_ref }}
ref: ${{ inputs.jax_ref || matrix.jax_ref }}

- name: Configure Git Identity
run: |
@@ -117,6 +148,13 @@ jobs:
aws-region: us-east-2
role-to-assume: arn:aws:iam::692859939525:role/therock-${{ inputs.release_type }}-releases

- name: jax wheel list
id: jax_wheel_list
run: |
export jax_whl_list="$(python3 ./build_tools/mapping_built_jax_wheels.py --dir "${{ env.PACKAGE_DIST_DIR }}")"
echo "jax_whl_list=${jax_whl_list}" >> "$GITHUB_OUTPUT"
echo "${jax_whl_list}" >> "$GITHUB_STEP_SUMMARY"
Comment on lines +151 to +156
Member

Please follow the style of https://github.com/ROCm/TheRock/blob/main/build_tools/github_actions/write_torch_versions.py here with a build_tools/github_actions/write_jax_versions.py

Usage:

  • - name: Build PyTorch Wheels
    id: build-pytorch-wheels
    run: |
    echo "Building PyTorch wheels for ${{ inputs.amdgpu_family }}"
    ./external-builds/pytorch/build_prod_wheels.py \
    build \
    --install-rocm \
    --pip-cache-dir /tmp/pipcache \
    --index-url "${{ inputs.cloudfront_url }}/${{ inputs.amdgpu_family }}/" \
    --clean \
    --output-dir ${{ env.PACKAGE_DIST_DIR }} ${{ env.optional_build_prod_arguments }}
    python ./build_tools/github_actions/write_torch_versions.py --dist-dir ${{ env.PACKAGE_DIST_DIR }}
  • outputs:
    cp_version: ${{ env.cp_version }}
    torch_version: ${{ steps.build-pytorch-wheels.outputs.torch_version }}
    torchaudio_version: ${{ steps.build-pytorch-wheels.outputs.torchaudio_version }}
    torchvision_version: ${{ steps.build-pytorch-wheels.outputs.torchvision_version }}
    triton_version: ${{ steps.build-pytorch-wheels.outputs.triton_version }}
  • (and other parts of that file)
  • env:
    VENV_DIR: ${{ github.workspace }}/.venv
    TORCH_VERSION: ${{ inputs.torch_version }}
  • - name: Set up virtual environment
    run: |
    python build_tools/setup_venv.py ${VENV_DIR} \
    --packages torch==${TORCH_VERSION} \
    --index-url ${{ inputs.package_index_url }} \
    --index-subdir ${{ inputs.amdgpu_family }} \
    --activate-in-future-github-actions-steps

Note that what gets computed and passed through the workflows is the versions of the files, not their filenames. The pip commands then use those versions to download or install the matching wheels (as users would do outside of GitHub Actions workflows).
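
A minimal sketch of what a `write_jax_versions.py` along these lines might look like, assuming it mirrors the idea described above: read the wheels in the dist directory, extract each package's version from its filename, and emit `<name>_version=<version>` lines for `$GITHUB_OUTPUT`. The function names and the output key format are assumptions for illustration and are not taken from `write_torch_versions.py`.

```python
from pathlib import Path


def parse_wheel_versions(dist_dir: Path) -> dict[str, str]:
    """Map distribution name -> version for every wheel found in dist_dir."""
    versions = {}
    for wheel in sorted(dist_dir.glob("*.whl")):
        # Wheel filenames have the form
        # {distribution}-{version}(-{build})?-{python}-{abi}-{platform}.whl
        # and neither the distribution nor the version contains a hyphen.
        name, version = wheel.name.split("-")[:2]
        versions[name.lower()] = version
    return versions


def write_versions(versions: dict[str, str], output_path: Path) -> None:
    """Append one `<name>_version=<version>` line per package (hypothetical key format)."""
    with output_path.open("a", encoding="utf-8") as fh:
        for name, version in versions.items():
            fh.write(f"{name}_version={version}\n")
```

In a workflow step, `output_path` would be the file named by the `GITHUB_OUTPUT` environment variable, so that later jobs can read something like `steps.<id>.outputs.jaxlib_version` instead of a list of filenames.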

Contributor Author

I will consider adding logging or a summary later as an enhancement, since the logs are already available.


- name: Upload wheels to S3
if: ${{ github.repository_owner == 'ROCm' }}
run: |
@@ -130,3 +168,23 @@ jobs:
source .venv/bin/activate
pip3 install boto3 packaging
python3 ./build_tools/third_party/s3_management/manage.py ${{ inputs.s3_subdir }}/${{ inputs.amdgpu_family }}

test_jax_wheels:
name: Test JAX wheels | ${{ inputs.amdgpu_family }}
needs: [build_jax_wheels]
if: ${{ github.repository_owner == 'ROCm' }}
permissions:
id-token: write
contents: read
packages: write
uses: ./.github/workflows/test_linux_jax_wheels.yml
with:
amdgpu_family: ${{ inputs.amdgpu_family }}
release_type: ${{ inputs.release_type }}
package_index_url: ${{ inputs.cloudfront_staging_url }}
rocm_version: ${{ inputs.rocm_version }}
tar_url: ${{ inputs.tar_url }}
python_versions: ${{ inputs.python_versions }}
jax_ref: rocm-jaxlib-v0.7.1
jax_whl_list: ${{ needs.build_jax_wheels.outputs.jax_whl_list }}
test_runs_on: "linux-mi325-1gpu-ossci-rocm"
6 changes: 5 additions & 1 deletion .github/workflows/release_portable_linux_packages.yml
@@ -303,8 +303,12 @@ jobs:
"python_versions": "3.12",
"release_type": "${{ env.RELEASE_TYPE }}",
"s3_subdir": "${{ env.S3_STAGING_SUBDIR }}",
"s3_staging_subdir": "${{ env.S3_STAGING_SUBDIR }}",
"rocm_version": "${{ needs.setup_metadata.outputs.version }}",
"tar_url": "${{ steps.url-encode-tar.outputs.tar_url }}"
"tar_url": "${{ steps.url-encode-tar.outputs.tar_url }}",
"cloudfront_url": "${{ needs.setup_metadata.outputs.cloudfront_url }}",
"cloudfront_staging_url": "${{ needs.setup_metadata.outputs.cloudfront_staging_url }}",
"ref": "${{ inputs.ref || '' }}"
}

- name: Save cache
167 changes: 167 additions & 0 deletions .github/workflows/test_linux_jax_wheels.yml
@@ -0,0 +1,167 @@
name: Test JAX Wheels
on:
workflow_call:
inputs:
amdgpu_family:
description: GPU family to test
required: true
type: string
release_type:
description: The type of release
required: true
type: string
rocm_version:
description: ROCm version to use in Docker
required: true
type: string
tar_url:
description: URL to TheRock tarball to build against
required: true
type: string
test_runs_on:
required: true
type: string
default: "linux-mi325-1gpu-ossci-rocm"
python_versions:
description: Python version string used to compute CP tag
required: true
type: string
jax_ref:
description: Branch/tag/sha of rocm/rocm-jax to build
required: false
type: string
default: "rocm-jaxlib-v0.7.1"
package_index_url:
required: true
type: string
jax_whl_list:
description: list of built jax wheel filenames to fetch & test
required: true
type: string

workflow_dispatch:
inputs:
amdgpu_family:
type: choice
options:
- gfx110X-dgpu
- gfx1151
- gfx120X-all
- gfx94X-dcgpu
- gfx950-dcgpu
default: gfx94X-dcgpu
release_type:
type: choice
options: [dev, nightly, prerelease]
default: dev
rocm_version:
type: string
required: true
tar_url:
type: string
required: true
test_runs_on:
description: Runner label to use. The selected runner should have a GPU supported by amdgpu_family
required: true
type: string
default: "linux-mi325-1gpu-ossci-rocm"
python_versions:
type: string
required: true
default: "3.12"
jax_ref:
type: string
default: "rocm-jaxlib-v0.7.1"
package_index_url:
description: Base Python package index URL to test, typically nightly/dev URL with a "v2" or "v2-staging" subdir (without a GPU family subdir)
required: true
type: string
default: "https://rocm.nightlies.amd.com/v2"
jax_whl_list:
description: list of built jax wheel filenames to fetch & test
required: true
type: string

permissions:
id-token: write
contents: read
packages: none

jobs:
test_jax_wheels:
name: Test | ${{ inputs.amdgpu_family }}
runs-on: ${{ inputs.test_runs_on }}
env:
WHEELHOUSE_DIR: ${{ github.workspace }}/jax/wheelhouse
steps:
- name: Checkout TheRock
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
repository: ${{ inputs.repository || github.repository }}
ref: ${{ inputs.ref || '' }}

- name: Checkout rocm-jax (plugin + build scripts)
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
path: jax
repository: rocm/rocm-jax
ref: ${{ inputs.jax_ref }}

- name: Setup Python
uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
with:
python-version: "3.12"

- name: Prepare wheelhouse and download JAX wheels
shell: bash
working-directory: jax
run: |
python3 -m pip install --upgrade pip
python3 -m pip install requests
python3 ../build_tools/fetch_wheels.py \
--cloudfront_url ${{ inputs.package_index_url }} \
--amdgpu_family ${{ inputs.amdgpu_family }} \
--dir "wheelhouse" \
--list_whls '${{ inputs.jax_whl_list }}'
Comment on lines +115 to +125
Member

What is this actually doing with the downloaded wheels? Is the python3 build/ci_build script using this "wheelhouse" directory?

Can this use setup_venv.py as we do already in the pytorch test workflow?

- name: Set up virtual environment
run: |
python build_tools/setup_venv.py ${VENV_DIR} \
--packages torch==${TORCH_VERSION} \
--index-url ${{ inputs.package_index_url }} \
--index-subdir ${{ inputs.amdgpu_family }} \
--activate-in-future-github-actions-steps

Here are some logs of that script running in the workflow: https://github.com/ROCm/TheRock/actions/runs/19228713043/job/54971284725#step:7:22

Contributor Author

No, we cannot use setup_venv.py: switching to a venv would require changing all of the Docker build steps. This approach does the same thing, just with a wheelhouse directory.

Member

Why is Docker a requirement? I still haven't seen a clear answer to that. What are the user install instructions like? Are they more than pip install jax --index-url=...? Tests should match what users will do, and we can't expect users to download files manually then build a dockerfile.

Contributor Author

Yes, it does have other requirements that need to be configured, and not all of the code is in TheRock, since we are using this workflow as a wrapper around rocm-jax and the upstream code. I have updated the script as suggested to use requests instead of wget.

Contributor

@ScottTodd It's not a requirement for running tests. This is just the way that we currently do it in our CI because some Ubuntu docker images are part of our deliverables. We do everything through the build/ci_build script, and the script assumes that you want to do the build in three stages:

  1. Build the wheels
  2. Build the docker images
  3. Test using the docker images

There isn't a step in our CI that installs the wheels (and requirements like ROCm or pytest) with pip and then runs the tests. You can do this, of course, but we just don't have an easy command for it in build/ci_build because we're taking care of that in step 3 of our CI. We could probably add a command that will do that to build/ci_build though, or you could try and do that right in the workflow.
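
As a rough illustration of the "install with pip, then run the tests" flow discussed in this thread, the sketch below only builds the two command lines; it does not run them. The index layout (base URL plus GPU family subdirectory) follows the `--index-url` usage quoted earlier in this review, and the test path is the one used by this workflow. The helper names and the `jaxlib==0.7.1` pin are assumptions for illustration.

```python
import shlex


def pip_install_command(package_index_url: str, amdgpu_family: str,
                        requirements: list[str]) -> list[str]:
    """Build a pip command that installs from the per-family package index."""
    index_url = f"{package_index_url.rstrip('/')}/{amdgpu_family.strip('/')}/"
    return ["python", "-m", "pip", "install", "--index-url", index_url, *requirements]


def pytest_command(test_path: str = "jax_tests/tests/core_test.py") -> list[str]:
    """Build the test command; the default path is the one used in this workflow."""
    return ["python", "-m", "pytest", test_path]


if __name__ == "__main__":
    install = pip_install_command(
        "https://rocm.nightlies.amd.com/v2-staging",
        "gfx94X-dcgpu",
        ["jaxlib==0.7.1"],  # hypothetical pin, for illustration only
    )
    print(shlex.join(install))
    print(shlex.join(pytest_command()))
```

Whether a plain `pip install` brings in everything the JAX tests need (ROCm libraries, pytest, and so on) is exactly the open question in this thread; the sketch makes no claim about that.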


- name: Compute ROCM_VERSION_SHORT
working-directory: jax
env:
ROCM_VERSION: ${{ inputs.rocm_version }}
run: |
# Extract major.minor.patch from ROCM_VERSION
# sed command captures first three numeric components separated by dots and reconstructs them as major.minor.patch
ROCM_VERSION_SHORT="$(echo "${ROCM_VERSION}" | sed -E 's/^([0-9]+)\.([0-9]+)\.([0-9]+).*/\1.\2.\3/')"
echo "ROCM_VERSION_SHORT=${ROCM_VERSION_SHORT}" >> "$GITHUB_ENV"
echo "Using semantic ROCm version: ${ROCM_VERSION_SHORT}"
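
For readers unfamiliar with the `sed` expression above, here is an equivalent sketch in Python. It keeps the first three dot-separated numeric components and drops any suffix; like `sed`, it returns the input unchanged when the pattern does not match. The sample version strings in the usage lines are made up for illustration.

```python
import re

# Same pattern as the sed expression: three numeric components, then anything.
_MAJOR_MINOR_PATCH = re.compile(r"^([0-9]+)\.([0-9]+)\.([0-9]+).*")


def rocm_version_short(rocm_version: str) -> str:
    """Reduce a full ROCm version string to major.minor.patch."""
    # When the pattern does not match, sub() leaves the input as-is,
    # which mirrors what `sed -E 's/.../.../'` does.
    return _MAJOR_MINOR_PATCH.sub(r"\1.\2.\3", rocm_version, count=1)


if __name__ == "__main__":
    print(rocm_version_short("7.10.0a20251112"))  # hypothetical nightly version
    print(rocm_version_short("7.10"))             # no patch component: unchanged
```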
Comment on lines +127 to +136
Member

This code is suspicious and I don't think it should be needed. At a minimum there should be a comment here explaining what problem this code is solving.

Contributor Author

If I pass the full ROCm version, Docker fails to accept it as a tag. I have added more comments.

Member

There is no need to compute a tag. Instead, spin up a Docker container, pip install the JAX wheels, run the tests, and that's it. There is no need to build a Docker image.

Contributor Author

Either way it's the same. The goal is to test the wheels, and these are the tests that the JAX team currently supports.

Member

Looks like JAX has these docs, which do suggest running in a container: https://github.com/ROCm/rocm-jax/blob/master/BUILDING.md#3-running-tests, but I'm not seeing a technical explanation there for why tests are difficult to run outside of a container. We are strongly discouraging Docker and container usage in new packaging in TheRock and this is the time to investigate those details.

Again, and I'm getting frustrated repeating myself - this type of work needs to start with a design discussion, before we get into details during code review. We've wasted multiple weeks now going back and forth without seeing a plan written down. At a minimum provide links to pages like these so we can discuss the design points - that is your job as the PR author and not something reviewers should need to research on their own:

Contributor

They're not terribly difficult, it just requires the right dependencies. Building the image and then running with that is just what we do in our current CI setup. See my other comment. I wholeheartedly agree that doing some design work and getting clear on how TheRock expects frameworks to build and test is much needed, and I'd be happy to discuss it somewhere that's not a PR comment thread.


- name: Build JAX Docker image
working-directory: jax
env:
ROCM_VERSION: ${{ inputs.rocm_version }}
run: |
python3 build/ci_build \
--rocm-version="${ROCM_VERSION_SHORT}" \
--therock-path="${{ inputs.tar_url }}" \
build_dockers \
-f ubu24
# Assign JAX_IMAGE to the expected image name produced by the build script
JAX_IMAGE="jax-ubu24.rocm$(echo "${ROCM_VERSION_SHORT}" | tr -d '.'):latest"
echo "JAX_IMAGE=${JAX_IMAGE}" >> "$GITHUB_ENV"
echo "Built JAX Docker image: ${JAX_IMAGE}"
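
The image name computed in the step above can be sketched in Python as follows. The `jax-ubu24.rocm<digits>:latest` pattern is taken from the shell line in this step; the `distro` and `tag` parameters are hypothetical knobs added for illustration, and nothing here shows that `build/ci_build` can produce a tag other than `latest`.

```python
def jax_image_name(rocm_version_short: str, distro: str = "ubu24", tag: str = "latest") -> str:
    """Return the image name the workflow expects the build script to produce."""
    # Equivalent to `tr -d '.'` in the workflow: drop the dots from the version.
    digits = rocm_version_short.replace(".", "")
    return f"jax-{distro}.rocm{digits}:{tag}"


if __name__ == "__main__":
    print(jax_image_name("7.10.0"))
```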
Comment on lines +138 to +151
Member

Why is a Docker image needed here?

If keeping the docker image, a tag other than :latest is probably called for (I assume this doesn't get uploaded anywhere but still safer to not have a test workflow directly set the tag to "latest")

Contributor Author

The core_tests and other GPU tests are run using the Docker image. Whether or not the image is tagged latest, it should not make any difference, since it runs inside a pod that is decommissioned once this workflow finishes.

Member

Why does this need to use a custom Docker image? I strongly suggest following the style used for testing torch wheels in TheRock. There, a Docker image is only used to guarantee a ROCm-free environment. Packages to test should then be installed via pip. See https://github.com/ROCm/TheRock/blob/main/.github/workflows/test_pytorch_wheels.yml

Contributor Author

This is not a custom Docker image. It's the same Dockerfile that upstream JAX supports, along with the tests that we are running. The tests currently run on a Docker image. I can open a new issue to enhance this in the same way as torch. Most of the tests only support running from a Docker image, including the performance and accuracy tests for JAX-MAX training. Even for PyTorch, we have to build a Docker image to support PyTorch performance training, which is PRIMUS.


- name: Checkout JAX tests repo (for extended tests)
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
repository: rocm/jax
ref: ${{ inputs.jax_ref }}
path: jax/jax_tests

- name: Run unit tests
working-directory: jax
env:
ROCM_VERSION: ${{ inputs.rocm_version }}
run: |
python3 build/ci_build test \
"${JAX_IMAGE}" \
--test-cmd "pytest jax_tests/tests/core_test.py"
109 changes: 109 additions & 0 deletions build_tools/fetch_wheels.py
Member

I don't think this script should be needed, see my other comment.

If we do keep this script, some high level comments:

Contributor Author

We will need this script and will update it as per your comments.

@@ -0,0 +1,109 @@
#!/usr/bin/env python3
"""
Download .whl files from the amdgpu_family subdirectory of a CloudFront URL into a specified directory.

USAGE:
    python build_tools/fetch_wheels.py \
        --cloudfront_url ${PACKAGE_INDEX_URL} \
        --amdgpu_family ${amdgpu_family} \
        --dir ${WHEELHOUSE_DIR} \
        --list_whls '["jax_pjrt.whl","jax_plugin.whl","jaxlib.whl"]'

EXAMPLE:
    python build_tools/fetch_wheels.py \
        --cloudfront_url https://rocm.nightlies.amd.com/v2-staging \
        --amdgpu_family gfx94X-dcgpu \
        --dir wheelhouse \
        --list_whls "$jax_whl_list"

This will:
  - Create wheelhouse if it does not exist.
  - Download:
      https://rocm.nightlies.amd.com/v2-staging/gfx94X-dcgpu/jax_pjrt.whl
      https://rocm.nightlies.amd.com/v2-staging/gfx94X-dcgpu/jax_plugin.whl
      https://rocm.nightlies.amd.com/v2-staging/gfx94X-dcgpu/jaxlib.whl
  - Save all files into wheelhouse.
"""

import argparse
import ast
import logging
from pathlib import Path
from urllib.parse import quote, urljoin
import shutil
import requests
from github_actions.github_actions_utils import gha_append_step_summary

LOG = logging.getLogger(__name__)


def download_one(session, url: str, dest: Path, bufsize: int = 64 * 1024, timeout: int = 30):
    # Stream-download url to dest using a temporary .part file.
    dest.parent.mkdir(parents=True, exist_ok=True)
    tmp = dest.with_suffix(dest.suffix + ".part")

    with session.get(url, stream=True, timeout=timeout) as r:
        r.raise_for_status()
        r.raw.decode_content = True
        with tmp.open("wb") as fh:
            shutil.copyfileobj(r.raw, fh, length=bufsize)

    tmp.replace(dest)


def main():
    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

    p = argparse.ArgumentParser("Simple wheel downloader")
    p.add_argument("--cloudfront_url", required=True)
    p.add_argument("--amdgpu_family", required=True)
    p.add_argument("--dir", required=True, dest="wheelhouse_dir")
    p.add_argument(
        "--list_whls",
        required=True,
        help='Python list literal of filenames, e.g. \'["jax_pjrt.whl","jax_plugin.whl","jaxlib.whl"]\'',
    )
    args = p.parse_args()

    try:
        whl_list = ast.literal_eval(args.list_whls)
        if not isinstance(whl_list, list):
            raise ValueError
    except Exception:
        LOG.error("Invalid --list_whls; provide a Python list literal like '[\"jax_pjrt.whl\",\"jax_plugin.whl\",\"jaxlib.whl\"]'")
        raise SystemExit(1)

    base = args.cloudfront_url.rstrip("/") + "/"
    family = args.amdgpu_family.strip("/")
    out_dir = Path(args.wheelhouse_dir)

    session = requests.Session()

    failures = []
    for name in whl_list:
        encoded = quote(name, safe="")
        url = urljoin(base, f"{family}/{encoded}")
        dest = out_dir / name

        LOG.info("Downloading %s -> %s", url, dest)
        try:
            download_one(session, url, dest)
            msg = f"Downloaded {name} -> {out_dir}"
            LOG.info(msg)
            gha_append_step_summary(msg)
        except Exception as e:
            msg = f"Failed {name}: {e}"
            LOG.error(msg)
            gha_append_step_summary(msg)
            failures.append(name)

    if failures:
        LOG.error("Failed downloads: %s", ", ".join(failures))
        raise SystemExit(1)

    LOG.info("All downloads complete.")


if __name__ == "__main__":
    main()
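
The URL construction in `main()` has two details that are easy to miss, isolated in the standalone sketch below: the base URL is normalized to end with a slash before `urljoin`, and each filename is percent-encoded so that a `+` in a local version segment survives. The `wheel_url` helper and the sample wheel filename are assumptions for illustration; only `quote` and `urljoin` come from the script itself.

```python
from urllib.parse import quote, urljoin


def wheel_url(cloudfront_url: str, amdgpu_family: str, filename: str) -> str:
    """Build the download URL the same way the script's main() does."""
    base = cloudfront_url.rstrip("/") + "/"
    family = amdgpu_family.strip("/")
    return urljoin(base, f"{family}/{quote(filename, safe='')}")


if __name__ == "__main__":
    # Hypothetical wheel filename with a local version segment.
    name = "jaxlib-0.7.1+rocm7-cp312-cp312-linux_x86_64.whl"
    print(wheel_url("https://rocm.nightlies.amd.com/v2-staging", "gfx94X-dcgpu", name))

    # Without the trailing slash, urljoin replaces the last path segment of the
    # base, so "v2-staging" would be dropped from the result.
    print(urljoin("https://rocm.nightlies.amd.com/v2-staging", "gfx94X-dcgpu/x.whl"))
```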