jax wheel tests#2247

Merged
kiran-thumma merged 40 commits into main from users/kithumma/jax-test
Dec 4, 2025

Conversation

@kiran-thumma
Contributor

@kiran-thumma kiran-thumma commented Nov 21, 2025

Motivation

  • Add an automated GPU validation test that verifies built JAX wheel (.whl) files are functional.
  • We currently lack automated wheel-level GPU validation. This test will catch regressions in packaging, compatibility issues, missing binaries, or other wheel-level failures before wheels are promoted.
  • Provide faster feedback and reduce the risk of shipping broken GPU wheels.

Technical Details

  • New workflow: test_linux_jax_wheels.yml
  • Test types: core, utils, scipy stats, and multi-device tests
  • Purpose: Install built JAX wheels and run tests.
  • Diagram:
Screenshot 2025-11-23 140535

Test Plan

  • Local Verification
  • CI Verification

Test Result

Submission Checklist

@kiran-thumma
Contributor Author

@kiran-thumma kiran-thumma marked this pull request as ready for review November 25, 2025 17:28
Comment thread external-builds/jax/requirements-jax.txt
Comment thread .github/workflows/test_jax_dockerfile.yml
Comment thread .github/workflows/test_linux_jax_wheels.yml Outdated
Comment thread .github/workflows/test_linux_jax_wheels.yml Outdated
Comment thread .github/workflows/test_linux_jax_wheels.yml Outdated
Comment thread .github/workflows/test_linux_jax_wheels.yml Outdated
Comment thread .github/workflows/test_linux_jax_wheels.yml Outdated
Comment thread .github/workflows/test_linux_jax_wheels.yml Outdated
Comment thread .github/workflows/test_linux_jax_wheels.yml Outdated
Comment on lines +165 to +168
# Extract JAX version from requirements.txt (e.g., "jax==0.8.0")
JAX_VERSION=$(tr -d ' ' < jax/build/requirements.txt \
| grep -E '^(jax|jaxlib)==[^#]+' | head -n1 | cut -d'=' -f3)
echo "JAX_VERSION=$JAX_VERSION" >> "$GITHUB_ENV"
Contributor

could we instead pull it from inputs.jax_ref ?

Contributor Author

No. The exact version of each can only be retrieved from requirements.txt. Since the plan is to support both master and a release branch, we will need the jax version from requirements.txt.

Contributor

could you please add a bit more context to this?

you build jax with version x.y. but inputs.jax_ref is only used to get the jax checkout to fetch the correct tests?
how would the version look differently between master and test?
how do you then test a new jax version?

it would also be good not to use bash but some python code to extract the version. does something like

import jax
print(jax.__version__) 

work?

Contributor

agree with Laura, please add docs for this or a comment

as i was reading this, i am confused on why (also, there is complex bash too)

Contributor Author

Sure, I will add comments. And the snippet below won't work, as we need the value prior to installing jax:
import jax
print(jax.__version__)
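
For reference, the bash pipeline discussed in this thread can be approximated in Python without importing jax, which fits the constraint that the version is needed before installation. This is a minimal sketch, not the code that was merged: `extract_jax_version` is a hypothetical helper name, and the regular expression assumes pins of the form `jax==X` or `jaxlib==X` at the start of a line.

```python
import re

# Matches a pinned "jax==X" or "jaxlib==X" requirement at the start of a line.
# The version stops at whitespace, a comment marker, an environment marker,
# or a line-continuation backslash.
_PIN = re.compile(r"^(jax|jaxlib)==([^\s#;\\]+)")


def extract_jax_version(requirements_text):
    """Return the first pinned jax/jaxlib version, or None if there is none."""
    for line in requirements_text.splitlines():
        # Mirror the `tr -d ' '` step of the bash version: drop all spaces.
        match = _PIN.match(line.replace(" ", ""))
        if match:
            return match.group(2)
    return None


# Illustrative input only; the real file would be jax/build/requirements.txt.
sample = "numpy==2.1.0\njax == 0.8.0  # pinned\njaxlib==0.8.0\n"
print(extract_jax_version(sample))
```

In a workflow step, the returned value could be appended to the file named by the `GITHUB_ENV` environment variable, as the bash version does.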

Member

Yes let's rework this step (can be in a follow-up). There is too much bash magic here independent of the workflow inputs. We probably want to support building multiple jax versions (at least "stable" and "latest"), so we can't just use the version in a pinned requirements file.

This was referenced Nov 26, 2025
@kiran-thumma
Contributor Author

Contributor

@HereThereBeDragons HereThereBeDragons left a comment

thanks so far for removing the docker image!

see below some questions and comments i have

Comment thread .github/workflows/build_linux_jax_wheels.yml
Comment thread .github/workflows/build_linux_jax_wheels.yml
Comment thread .github/workflows/build_linux_jax_wheels.yml Outdated
Comment thread .github/workflows/build_linux_jax_wheels.yml
Comment thread .github/workflows/test_linux_jax_wheels.yml
Comment thread .github/workflows/test_linux_jax_wheels.yml Outdated
Comment thread .github/workflows/test_linux_jax_wheels.yml Outdated
Comment thread .github/workflows/test_linux_jax_wheels.yml
Comment thread .github/workflows/test_linux_jax_wheels.yml Outdated
Comment thread .github/workflows/test_jax_dockerfile.yml
Comment thread .github/workflows/build_linux_jax_wheels.yml
Comment thread .github/workflows/build_linux_jax_wheels.yml
Comment thread .github/workflows/test_linux_jax_wheels.yml
Comment thread .github/workflows/test_linux_jax_wheels.yml Outdated
Comment thread .github/workflows/test_linux_jax_wheels.yml
Comment thread .github/workflows/test_linux_jax_wheels.yml
Comment thread .github/workflows/test_linux_jax_wheels.yml Outdated
Comment thread external-builds/jax/requirements-jax.txt Outdated
@ScottTodd
Member

Please add a title for this PR, not "Users/kithumma/jax test". Also tag relevant issues tracking the work, like #1985.

@kiran-thumma kiran-thumma changed the title from "Users/kithumma/jax test" to "jax wheel tests" Nov 28, 2025
Contributor

@geomin12 geomin12 left a comment

looking good! a few clarifications are needed for my understanding and some housekeeping items, but aside from that, looks good

Comment thread .github/workflows/build_linux_jax_wheels.yml
Comment thread .github/workflows/build_linux_jax_wheels.yml
Comment thread .github/workflows/test_linux_jax_wheels.yml
Comment thread .github/workflows/test_linux_jax_wheels.yml
Comment thread .github/workflows/test_linux_jax_wheels.yml Outdated
Comment thread external-builds/jax/requirements-jax.txt
@kiran-thumma kiran-thumma requested a review from geomin12 December 1, 2025 22:23
Contributor

@geomin12 geomin12 left a comment

looking good! just some concerns about the inputs.ref as well as missing inputs in the test file

Comment thread .github/workflows/build_linux_jax_wheels.yml
Comment thread .github/workflows/build_linux_jax_wheels.yml
Comment thread .github/workflows/build_linux_jax_wheels.yml
Comment thread .github/workflows/test_linux_jax_wheels.yml
Comment thread .github/workflows/test_linux_jax_wheels.yml
Comment thread .github/workflows/test_linux_jax_wheels.yml
@kiran-thumma
Contributor Author

@geomin12 I created issue 2383 to refactor all the workflows that use inputs.repository, inputs.ref and github.repository. Please add your suggestions there on what you need instead of this logic.

@kiran-thumma kiran-thumma requested a review from geomin12 December 2, 2025 16:50
@geomin12
Contributor

geomin12 commented Dec 2, 2025

@geomin12 I created issue 2383 to refactor all the workflows that use inputs.repository, inputs.ref and github.repository. Please add your suggestions there on what you need instead of this logic.

that's fine to do cleanup in another PR, but my concern is this code won't work based on what I am seeing

  1. My first concern is the workflow_call repository and ref inputs. in the "test" workflow (that depends on build), ref is passed but repository is not passed. this will cause an error if a ref is passed but it isn't known which repository it belongs to
  2. {{ inputs.ref || '' }} will also cause an error. what if inputs.ref is empty? an empty string ref is invalid and will also error
  3. repository and ref will default to ROCm/TheRock and whichever branch it is on during a workflow_call, which makes those two inputs unnecessary (i guess, why add unnecessary inputs?)

points 1 and 2 also contradict point 3, which is why I'm curious why we add those two inputs to workflow_call when they will always default (no point in extra unnecessary logic)
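
The three points above all concern how missing or empty repository and ref inputs get resolved. The fallback being discussed can be made explicit with a small sketch. Everything here is hypothetical and for illustration only: `resolve_checkout`, `DEFAULT_REPOSITORY`, and the `"main"` default ref are assumed names and values, not taken from the workflow files, and the sketch makes no claim about how GitHub Actions itself treats an empty ref.

```python
# Assumed default; the thread names ROCm/TheRock as the repository that
# workflow_call would fall back to.
DEFAULT_REPOSITORY = "ROCm/TheRock"


def resolve_checkout(repository="", ref="", default_ref="main"):
    """Return (repository, ref), treating empty strings as 'not provided'.

    default_ref stands in for whichever branch the calling workflow is on;
    "main" is only a placeholder value for this sketch.
    """
    resolved_repository = repository or DEFAULT_REPOSITORY
    resolved_ref = ref or default_ref
    return resolved_repository, resolved_ref


# A ref passed without a repository falls back to the default repository.
print(resolve_checkout(ref="users/kithumma/jax-test"))
```

Resolving both values in one place means a ref is never used without a known repository, and an empty ref never reaches the checkout step.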

Contributor

@geomin12 geomin12 left a comment

looks good, i just see one bug then should be good to go

Comment thread .github/workflows/build_linux_jax_wheels.yml
Comment thread .github/workflows/test_linux_jax_wheels.yml
Comment thread .github/workflows/test_linux_jax_wheels.yml
Comment thread external-builds/jax/requirements-jax.txt Outdated
Comment on lines +190 to +195
- name: Run JAX tests
  run: |
    pytest jax/jax_tests/tests/multi_device_test.py -q --log-cli-level=INFO
    pytest jax/jax_tests/tests/core_test.py -q --log-cli-level=INFO
    pytest jax/jax_tests/tests/util_test.py -q --log-cli-level=INFO
    pytest jax/jax_tests/tests/scipy_stats_test.py -q --log-cli-level=INFO
Member

This will exit on the first test failure, when we probably want to keep going and report all failures together. Fine to improve in a follow-up.
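
One way to keep going after a failure, as suggested above, is to run each test file in its own process and collect the exit codes before deciding the step's result. The following is a minimal sketch under that assumption; `build_commands`, `run_all`, and `main` are hypothetical names for a helper script and are not part of this PR.

```python
import subprocess
import sys

# Test files taken from the workflow step under review.
TEST_FILES = [
    "jax/jax_tests/tests/multi_device_test.py",
    "jax/jax_tests/tests/core_test.py",
    "jax/jax_tests/tests/util_test.py",
    "jax/jax_tests/tests/scipy_stats_test.py",
]


def build_commands(test_files):
    """Build one pytest invocation per test file, using the step's flags."""
    return [
        [sys.executable, "-m", "pytest", path, "-q", "--log-cli-level=INFO"]
        for path in test_files
    ]


def run_all(commands):
    """Run every command, even after a failure; return the ones that failed."""
    failed = []
    for command in commands:
        # check=False: a non-zero exit code is recorded instead of raising.
        result = subprocess.run(command, check=False)
        if result.returncode != 0:
            failed.append(command)
    return failed


def main():
    """Run all test files and return a process exit code (0 means all passed)."""
    failed = run_all(build_commands(TEST_FILES))
    for command in failed:
        print("FAILED:", " ".join(command))
    return 1 if failed else 0
```

A workflow step could then end with `sys.exit(main())`, so the step fails only after every file has run. Passing all four paths to a single pytest invocation would be another option, since pytest accepts multiple paths and reports their failures in one summary.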

Comment thread .github/workflows/test_linux_jax_wheels.yml Outdated
@kiran-thumma kiran-thumma force-pushed the users/kithumma/jax-test branch from 8e96f55 to 537526c Compare December 3, 2025 00:50
Contributor

@geomin12 geomin12 left a comment

looks good, just one fix! this will cause errors in test jax file

Comment thread .github/workflows/test_jax_dockerfile.yml
@kiran-thumma kiran-thumma requested a review from geomin12 December 3, 2025 20:22
Contributor

@geomin12 geomin12 left a comment

sweet, thanks for the updates! you can also close that other issue as the repository and ref are needed since this will be a part of TheRock releases

Member

@ScottTodd ScottTodd left a comment

A few small comments but otherwise LGTM

Comment thread .github/workflows/build_linux_jax_wheels.yml Outdated
Comment thread .github/workflows/test_jax_dockerfile.yml Outdated
Comment thread .github/workflows/test_linux_jax_wheels.yml
@kiran-thumma
Contributor Author

A few small comments but otherwise LGTM

Updated all as suggested.

@kiran-thumma kiran-thumma dismissed HereThereBeDragons’s stale review December 4, 2025 17:33

Got approval from 2/3 reviewers

@kiran-thumma kiran-thumma merged commit d53002e into main Dec 4, 2025
9 checks passed
@github-project-automation github-project-automation Bot moved this from TODO to Done in TheRock Triage Dec 4, 2025
@kiran-thumma kiran-thumma deleted the users/kithumma/jax-test branch December 4, 2025 17:33
rponnuru5 pushed a commit that referenced this pull request Dec 9, 2025
## Motivation
- Add an automated GPU validation test that verifies built JAX wheel
(.whl) files are functional.
- We currently lack automated wheel-level GPU validation. This test will
catch regressions in packaging, compatibility issues, missing binaries,
or other wheel-level failures before wheels are promoted.
- Provide faster feedback and reduce the risk of shipping broken GPU wheels.

## Technical Details
- New workflow: test_linux_jax_wheels.yml
- Test types: core, utils, scipy stats, and multi-device tests
- Purpose: Install built JAX wheels and run tests.
- Diagram:
<img width="170" height="359" alt="Screenshot 2025-11-23 140535"
src="https://github.com/user-attachments/assets/f7b757f3-7013-4253-ae5b-a3a12934e6b6"
/>

## Test Plan

- Local Verification
- CI Verification

## Test Result
- Workflow run:
https://github.com/ROCm/TheRock/actions/runs/19603110752/job/56139806408

## Submission Checklist
- [x] All tests pass on at least one GPU configuration in CI
- [x] Workflow has a manual dispatch input for testing arbitrary wheel
URLs

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
Labels

None yet

Projects

Status: Done

4 participants