-
Notifications
You must be signed in to change notification settings - Fork 245
add jax core test #2035
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add jax core test #2035
Changes from all commits
2bcfbe1
8ea5d5d
c133ed4
f610229
1d62f84
d453f80
cb198c5
9122dd8
39b53d1
377fccf
ffcd958
eea77fc
c66d9d4
9996cab
9cc5d1a
cb02054
f97d78b
cc2c42a
652556c
deeeb4f
96d09f1
8da0444
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,167 @@ | ||||||||||||||||
| name: Test JAX Wheels | ||||||||||||||||
| on: | ||||||||||||||||
| workflow_call: | ||||||||||||||||
| inputs: | ||||||||||||||||
| amdgpu_family: | ||||||||||||||||
| description: GPU family to test | ||||||||||||||||
| required: true | ||||||||||||||||
| type: string | ||||||||||||||||
| release_type: | ||||||||||||||||
| description: The type of release | ||||||||||||||||
| required: true | ||||||||||||||||
| type: string | ||||||||||||||||
| rocm_version: | ||||||||||||||||
| description: ROCm version to use in Docker | ||||||||||||||||
| required: true | ||||||||||||||||
| type: string | ||||||||||||||||
| tar_url: | ||||||||||||||||
| description: URL to TheRock tarball to build against | ||||||||||||||||
| required: true | ||||||||||||||||
| type: string | ||||||||||||||||
| test_runs_on: | ||||||||||||||||
| required: true | ||||||||||||||||
| type: string | ||||||||||||||||
| default: "linux-mi325-1gpu-ossci-rocm" | ||||||||||||||||
| python_versions: | ||||||||||||||||
| description: Python version string used to compute CP tag | ||||||||||||||||
| required: true | ||||||||||||||||
| type: string | ||||||||||||||||
| jax_ref: | ||||||||||||||||
| description: Branch/tag/sha of rocm/rocm-jax to build | ||||||||||||||||
| required: false | ||||||||||||||||
| type: string | ||||||||||||||||
| default: "rocm-jaxlib-v0.7.1" | ||||||||||||||||
| package_index_url: | ||||||||||||||||
| required: true | ||||||||||||||||
| type: string | ||||||||||||||||
| jax_whl_list: | ||||||||||||||||
| description: list of built jax wheel filenames to fetch & test | ||||||||||||||||
| required: true | ||||||||||||||||
| type: string | ||||||||||||||||
|
|
||||||||||||||||
| workflow_dispatch: | ||||||||||||||||
| inputs: | ||||||||||||||||
| amdgpu_family: | ||||||||||||||||
| type: choice | ||||||||||||||||
| options: | ||||||||||||||||
| - gfx110X-dgpu | ||||||||||||||||
| - gfx1151 | ||||||||||||||||
| - gfx120X-all | ||||||||||||||||
| - gfx94X-dcgpu | ||||||||||||||||
| - gfx950-dcgpu | ||||||||||||||||
| default: gfx94X-dcgpu | ||||||||||||||||
| release_type: | ||||||||||||||||
| type: choice | ||||||||||||||||
| options: [dev, nightly, prerelease] | ||||||||||||||||
| default: dev | ||||||||||||||||
| rocm_version: | ||||||||||||||||
| type: string | ||||||||||||||||
| required: true | ||||||||||||||||
| tar_url: | ||||||||||||||||
| type: string | ||||||||||||||||
| required: true | ||||||||||||||||
| test_runs_on: | ||||||||||||||||
| description: Runner label to use. The selected runner should have a GPU supported by amdgpu_family | ||||||||||||||||
| required: true | ||||||||||||||||
| type: string | ||||||||||||||||
| default: "linux-mi325-1gpu-ossci-rocm" | ||||||||||||||||
| python_versions: | ||||||||||||||||
| type: string | ||||||||||||||||
| required: true | ||||||||||||||||
| default: "3.12" | ||||||||||||||||
| jax_ref: | ||||||||||||||||
| type: string | ||||||||||||||||
| default: "rocm-jaxlib-v0.7.1" | ||||||||||||||||
| package_index_url: | ||||||||||||||||
| description: Base Python package index URL to test, typically nightly/dev URL with a "v2" or "v2-staging" subdir (without a GPU family subdir) | ||||||||||||||||
| required: true | ||||||||||||||||
| type: string | ||||||||||||||||
| default: "https://rocm.nightlies.amd.com/v2" | ||||||||||||||||
| jax_whl_list: | ||||||||||||||||
| description: list of built jax wheel filenames to fetch & test | ||||||||||||||||
| required: true | ||||||||||||||||
| type: string | ||||||||||||||||
|
|
||||||||||||||||
| permissions: | ||||||||||||||||
| id-token: write | ||||||||||||||||
| contents: read | ||||||||||||||||
| packages: none | ||||||||||||||||
|
|
||||||||||||||||
| jobs: | ||||||||||||||||
| test_jax_wheels: | ||||||||||||||||
| name: Test | ${{ inputs.amdgpu_family }} | ||||||||||||||||
| runs-on: ${{ inputs.test_runs_on }} | ||||||||||||||||
| env: | ||||||||||||||||
| WHEELHOUSE_DIR: ${{ github.workspace }}/jax/wheelhouse | ||||||||||||||||
| steps: | ||||||||||||||||
| - name: Checkout TheRock | ||||||||||||||||
| uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 | ||||||||||||||||
| with: | ||||||||||||||||
| repository: ${{ inputs.repository || github.repository }} | ||||||||||||||||
| ref: ${{ inputs.ref || '' }} | ||||||||||||||||
|
|
||||||||||||||||
| - name: Checkout rocm-jax (plugin + build scripts) | ||||||||||||||||
| uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 | ||||||||||||||||
| with: | ||||||||||||||||
| path: jax | ||||||||||||||||
| repository: rocm/rocm-jax | ||||||||||||||||
| ref: ${{ inputs.jax_ref }} | ||||||||||||||||
|
|
||||||||||||||||
| - name: Setup Python | ||||||||||||||||
| uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 | ||||||||||||||||
| with: | ||||||||||||||||
| python-version: "3.12" | ||||||||||||||||
|
|
||||||||||||||||
| - name: Prepare wheelhouse and download JAX wheels | ||||||||||||||||
| shell: bash | ||||||||||||||||
| working-directory: jax | ||||||||||||||||
| run: | | ||||||||||||||||
| python3 -m pip install --upgrade pip | ||||||||||||||||
| python3 -m pip install requests | ||||||||||||||||
| python3 ../build_tools/fetch_wheels.py \ | ||||||||||||||||
| --cloudfront_url ${{ inputs.package_index_url }} \ | ||||||||||||||||
| --amdgpu_family ${{ inputs.amdgpu_family }} \ | ||||||||||||||||
| --dir "wheelhouse" \ | ||||||||||||||||
| --list_whls '${{ inputs.jax_whl_list }}' | ||||||||||||||||
|
Comment on lines
+115
to
+125
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is this actually doing with the downloaded wheels? Is the Can this use TheRock/.github/workflows/test_pytorch_wheels.yml Lines 112 to 118 in 2e9e168
Here are some logs of that script running in the workflow: https://github.com/ROCm/TheRock/actions/runs/19228713043/job/54971284725#step:7:22
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, we cannot use a setup_venv.py. As we have to change the docker build steps all that if we have to use a venv. This is the same way just a wheelhouse directory.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is Docker a requirement? I still haven't seen a clear answer to that. What are the user install instructions like? Are they more than
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, it does has other requriements to be configured and not the whole code is in theROCk as we are using this workflow as a wrapper to the rocm-jax and upstream code. I have update the script as per suggestion to use requests instead of wget.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @ScottTodd It's not a requirement for running tests. This is just the way that we currently do it in our CI because some Ubuntu docker images are part of our deliverables. We do everything through the
There isn't a step in our CI that installs the wheels (and requirements like ROCm or pytest) with pip and then runs the tests. You can do this, of course, but we just don't have an easy command for it in |
||||||||||||||||
|
|
||||||||||||||||
| - name: Compute ROCM_VERSION_SHORT | ||||||||||||||||
| working-directory: jax | ||||||||||||||||
| env: | ||||||||||||||||
| ROCM_VERSION: ${{ inputs.rocm_version }} | ||||||||||||||||
| run: | | ||||||||||||||||
| # Extract major.minor.patch from ROCM_VERSION | ||||||||||||||||
| # sed command captures first three numeric components separated by dots and reconstructs them as major.minor.patch | ||||||||||||||||
| ROCM_VERSION_SHORT="$(echo "${ROCM_VERSION}" | sed -E 's/^([0-9]+)\.([0-9]+)\.([0-9]+).*/\1.\2.\3/')" | ||||||||||||||||
| echo "ROCM_VERSION_SHORT=${ROCM_VERSION_SHORT}" >> "$GITHUB_ENV" | ||||||||||||||||
| echo "Using semantic ROCm version: ${ROCM_VERSION_SHORT}" | ||||||||||||||||
|
Comment on lines
+127
to
+136
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This code is suspicious and I don't think it should be needed. At a minimum there should be a comment here explaining what problem this code is solving.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If I'm passing the full rocmversion docker fails to add that as a tag. Added more comments.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is no need to compute a tag. Rather spin up a docker, pip install the jax wheels, run tests and that's it. No need to build a docker container.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Either way its the same. Goal is to test the wheels. The tests currently which jax team is supporting are these tests.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looks like JAX has these docs, which do suggest running in a container: https://github.com/ROCm/rocm-jax/blob/master/BUILDING.md#3-running-tests, but I'm not seeing a technical explanation there for why tests are difficult to run outside of a container. We are strongly discouraging Docker and container usage in new packaging in TheRock and this is the time to investigate those details. Again, and I'm getting frustrated repeating myself - this type of work needs to start with a design discussion, before we get into details during code review. We've wasted multiple weeks now going back and forth without seeing a plan written down. At a minimum provide links to pages like these so we can discuss the design points - that is your job as the PR author and not something reviewers should need to research on their own:
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. They're not terribly difficult, it just requires the right dependencies. Building the image and then running with that is just what we do in our current CI setup. See my other comment. I wholeheartedly agree that doing some design work and getting clear on how TheRock expects frameworks to build and test is much needed, and I'd be happy to discuss it somewhere that's not a PR comment thread. |
||||||||||||||||
|
|
||||||||||||||||
| - name: Build JAX Docker image | ||||||||||||||||
| working-directory: jax | ||||||||||||||||
| env: | ||||||||||||||||
| ROCM_VERSION: ${{ inputs.rocm_version }} | ||||||||||||||||
| run: | | ||||||||||||||||
| python3 build/ci_build \ | ||||||||||||||||
| --rocm-version="${ROCM_VERSION_SHORT}" \ | ||||||||||||||||
| --therock-path="${{ inputs.tar_url }}" \ | ||||||||||||||||
| build_dockers \ | ||||||||||||||||
| -f ubu24 | ||||||||||||||||
| # Assign JAX_IMAGE to the expected image name produced by the build script | ||||||||||||||||
| JAX_IMAGE="jax-ubu24.rocm$(echo "${ROCM_VERSION_SHORT}" | tr -d '.'):latest" | ||||||||||||||||
| echo "JAX_IMAGE=${JAX_IMAGE}" >> "$GITHUB_ENV" | ||||||||||||||||
| echo "Built JAX Docker image: ${JAX_IMAGE}" | ||||||||||||||||
|
Comment on lines
+138
to
+151
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is a Docker image needed here? If keeping the docker image, a tag other than
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The core_tests and other GPU tests are ran using the docker image. Even if we test the docker image to latest or not the latest, it should be any difference as its running inside a pod and will be decommissioned after this workflow is completed running.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why does this need to use a custom Docker image? I strongly suggest to follow the style used in testing torch wheels in TheRock. There a Docker image is only used to guarantee having a ROCm-free environment. Packages to test should than be installed via pip. See https://github.com/ROCm/TheRock/blob/main/.github/workflows/test_pytorch_wheels.yml
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is not a custom docker image. Its the same docker file which jax upstream supports and the tests which we are running. The tests currently are running on a docker image based. I can have a new issue which is to enhance it the same way as torch. Most of the tests support only docker image based even the performance and accuracy tests. for JAX-MAX training. Even for the pytorch we have to build the docker image to support performance training pytorch which is PRIMUS. |
||||||||||||||||
|
|
||||||||||||||||
| - name: Checkout JAX tests repo (for extended tests) | ||||||||||||||||
| uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 | ||||||||||||||||
| with: | ||||||||||||||||
| repository: rocm/jax | ||||||||||||||||
| ref: ${{ inputs.jax_ref }} | ||||||||||||||||
| path: jax/jax_tests | ||||||||||||||||
|
|
||||||||||||||||
| - name: Run unit tests | ||||||||||||||||
| working-directory: jax | ||||||||||||||||
| env: | ||||||||||||||||
| ROCM_VERSION: ${{ inputs.rocm_version }} | ||||||||||||||||
| run: | | ||||||||||||||||
| python3 build/ci_build test \ | ||||||||||||||||
| "${JAX_IMAGE}" \ | ||||||||||||||||
| --test-cmd "pytest jax_tests/tests/core_test.py" | ||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think this script should be needed, see my other comment. If we do keep this script, some high level comments:
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We will need this script and will update as per your comments. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,109 @@ | ||
| #!/usr/bin/env python3 | ||
| """ | ||
| Download .whl files from a CloudFront URL amdgpu_family into a specified directory. | ||
|
|
||
| USAGE: | ||
| python download_from_cloudfront.py \ | ||
| --cloudfront_url ${PACKAGE_INDEX_URL} \ | ||
| --amdgpu_family ${amdgpu_family} \ | ||
| --dir ${WHEELHOUSE_DIR} \ | ||
| --list_whls '["jax_pjrt.whl","jax_plugin.whl","jaxlib.whl"]' | ||
|
|
||
| EXAMPLE: | ||
| python download_from_cloudfront.py \ | ||
| --cloudfront_url https://rocm.nightlies.amd.com/v2-staging \ | ||
| --amdgpu_family gfx94X-dcgpu \ | ||
| --dir wheelhouse \ | ||
| --list_whls $jax_whl_list | ||
|
|
||
| This will: | ||
| - Create wheelhouse if it does not exist. | ||
| - Download: | ||
| https://rocm.nightlies.amd.com/v2-staging/gfx94X-dcgpu/jax_pjrt.whl | ||
| https://rocm.nightlies.amd.com/v2-staging/gfx94X-dcgpu/jax_plugin.whl | ||
| https://rocm.nightlies.amd.com/v2-staging/gfx94X-dcgpu/jaxlib.whl | ||
| - Save all files into wheelhouse | ||
| """ | ||
|
|
||
| import argparse | ||
| import ast | ||
| import logging | ||
| from pathlib import Path | ||
| from urllib.parse import quote, urljoin | ||
| import shutil | ||
| import requests | ||
| from github_actions.github_actions_utils import gha_append_step_summary | ||
|
|
||
| LOG = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| def download_one(session, url: str, dest: Path, bufsize: int = 64 * 1024, timeout: int = 30): | ||
|
|
||
| # Stream-download url to dest using a temporary .part file. | ||
| dest.parent.mkdir(parents=True, exist_ok=True) | ||
| tmp = dest.with_suffix(dest.suffix + ".part") | ||
|
|
||
| with session.get(url, stream=True, timeout=timeout) as r: | ||
| r.raise_for_status() | ||
| r.raw.decode_content = True | ||
| with tmp.open("wb") as fh: | ||
| shutil.copyfileobj(r.raw, fh, length=bufsize) | ||
|
|
||
| tmp.replace(dest) | ||
|
|
||
|
|
||
| def main(): | ||
| logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") | ||
|
|
||
| p = argparse.ArgumentParser("Simple wheel downloader") | ||
| p.add_argument("--cloudfront_url", required=True) | ||
| p.add_argument("--amdgpu_family", required=True) | ||
| p.add_argument("--dir", required=True, dest="wheelhouse_dir") | ||
| p.add_argument( | ||
| "--list_whls", | ||
| required=True, | ||
| help='Python list literal of filenames, e.g. \'["jax_pjrt.whl","jax_plugin.whl","jaxlib.whl"]\'', | ||
| ) | ||
| args = p.parse_args() | ||
|
|
||
| try: | ||
| whl_list = ast.literal_eval(args.list_whls) | ||
| if not isinstance(whl_list, list): | ||
| raise ValueError | ||
| except Exception: | ||
| LOG.error("Invalid --list_whls; provide a Python list literal like '[\"jax_pjrt.whl\",\"jax_plugin.whl\",\"jaxlib.whl\"]'") | ||
| raise SystemExit(1) | ||
|
|
||
| base = args.cloudfront_url.rstrip("/") + "/" | ||
| family = args.amdgpu_family.strip("/") | ||
| out_dir = Path(args.wheelhouse_dir) | ||
|
|
||
| session = requests.Session() | ||
|
|
||
| failures = [] | ||
| for name in whl_list: | ||
| encoded = quote(name, safe="") | ||
| url = urljoin(base, f"{family}/{encoded}") | ||
| dest = out_dir / name | ||
|
|
||
| LOG.info("Downloading %s -> %s", url, dest) | ||
| try: | ||
| download_one(session, url, dest) | ||
| msg = f"Downloaded {name} -> {out_dir}" | ||
| LOG.info(msg) | ||
| gha_append_step_summary(msg) | ||
| except Exception as e: | ||
| msg = f"Failed {name}: {e}" | ||
| LOG.error(msg) | ||
| gha_append_step_summary(msg) | ||
| failures.append(name) | ||
|
|
||
| if failures: | ||
| LOG.error("Failed downloads: %s", ", ".join(failures)) | ||
| raise SystemExit(1) | ||
|
|
||
| LOG.info("All downloads complete.") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please follow the style of https://github.com/ROCm/TheRock/blob/main/build_tools/github_actions/write_torch_versions.py here with a
build_tools/github_actions/write_jax_versions.pyUsage:
TheRock/.github/workflows/build_portable_linux_pytorch_wheels.yml
Lines 177 to 188 in 2e9e168
TheRock/.github/workflows/build_portable_linux_pytorch_wheels.yml
Lines 117 to 122 in 2e9e168
TheRock/.github/workflows/test_pytorch_wheels.yml
Lines 80 to 82 in 2e9e168
TheRock/.github/workflows/test_pytorch_wheels.yml
Lines 112 to 118 in 2e9e168
Note that the versions of the files are what is computed and passed through the workflows, not their filenames. Then
pipcommands use the versions to download or install wheels with those versions (as users would do outside of github actions workflows)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will think of adding a logging or summary later as part of enhancements as it already has the logs.