jax wheel tests#2247
Conversation
| # Extract JAX version from requirements.txt (e.g., "jax==0.8.0") | ||
| JAX_VERSION=$(tr -d ' ' < jax/build/requirements.txt \ | ||
| | grep -E '^(jax|jaxlib)==[^#]+' | head -n1 | cut -d'=' -f3) | ||
| echo "JAX_VERSION=$JAX_VERSION" >> "$GITHUB_ENV" |
There was a problem hiding this comment.
could we instead pull it from inputs.jax_ref ?
There was a problem hiding this comment.
No, The exact version of each can only be retrieved from the requirements.txt. As the plan to support both master and a release branch. We will need the jax version from requirements.txt
There was a problem hiding this comment.
could you please add a bit more context to this?
you build jax with version x.y. but inputs.jax_ref is only used to get the jax checkout to fetch the correct tests?
how would the version look differently between master and test?
how do you then test a new jax version?
it would be also good not to use bash but some python code to exract the version. does something like
import jax
print(jax.__version__)
work?
There was a problem hiding this comment.
agree with Laura, please add docs for this or a comment
as i was reading this, i am confused on why (also, there is complex bash too)
There was a problem hiding this comment.
Sure, I will add comments. And wont work as we need the value prior to installing jax.
import jax
print(jax.__version__)
There was a problem hiding this comment.
Yes let's rework this step (can be in a follow-up). There is too much bash magic here independent of the workflow inputs. We probably want to support building multiple jax versions (at least "stable" and "latest"), so we can't just use the version in a pinned requirements file.
HereThereBeDragons
left a comment
There was a problem hiding this comment.
thanks so far for removing the docker image!
see below some questions and comments i have
| # Extract JAX version from requirements.txt (e.g., "jax==0.8.0") | ||
| JAX_VERSION=$(tr -d ' ' < jax/build/requirements.txt \ | ||
| | grep -E '^(jax|jaxlib)==[^#]+' | head -n1 | cut -d'=' -f3) | ||
| echo "JAX_VERSION=$JAX_VERSION" >> "$GITHUB_ENV" |
There was a problem hiding this comment.
could you please add a bit more context to this?
you build jax with version x.y. but inputs.jax_ref is only used to get the jax checkout to fetch the correct tests?
how would the version look differently between master and test?
how do you then test a new jax version?
it would be also good not to use bash but some python code to exract the version. does something like
import jax
print(jax.__version__)
work?
| # Extract JAX version from requirements.txt (e.g., "jax==0.8.0") | ||
| JAX_VERSION=$(tr -d ' ' < jax/build/requirements.txt \ | ||
| | grep -E '^(jax|jaxlib)==[^#]+' | head -n1 | cut -d'=' -f3) | ||
| echo "JAX_VERSION=$JAX_VERSION" >> "$GITHUB_ENV" |
There was a problem hiding this comment.
agree with Laura, please add docs for this or a comment
as i was reading this, i am confused on why (also, there is complex bash too)
|
Please add a title for this PR, not "Users/kithumma/jax test". Also tag relevant issues tracking the work, like #1985. |
geomin12
left a comment
There was a problem hiding this comment.
looking good! a few clarifications are needed for my understanding and some housekeeping items, but aside from that, looks good
geomin12
left a comment
There was a problem hiding this comment.
looking good! just some concerns about the inputs.ref as well as missing inputs in the test file
|
@geomin12 Created an issue 2383 which is to refactor all the workflows with |
that's fine to do cleanup in another PR, but my concern is this code won't work based on what I am seeing
points 1 and 2 also contradict with point 3, which is why I'm curious on why we add those two inputs to |
geomin12
left a comment
There was a problem hiding this comment.
looks good, i just see one bug then should be good to go
| - name: Run JAX tests | ||
| run: | | ||
| pytest jax/jax_tests/tests/multi_device_test.py -q --log-cli-level=INFO | ||
| pytest jax/jax_tests/tests/core_test.py -q --log-cli-level=INFO | ||
| pytest jax/jax_tests/tests/util_test.py -q --log-cli-level=INFO | ||
| pytest jax/jax_tests/tests/scipy_stats_test.py -q --log-cli-level=INFO |
There was a problem hiding this comment.
This will exit on the first test failure, when we probably want to keep going and report all failures together. Fine to improve in a follow-up.
| # Extract JAX version from requirements.txt (e.g., "jax==0.8.0") | ||
| JAX_VERSION=$(tr -d ' ' < jax/build/requirements.txt \ | ||
| | grep -E '^(jax|jaxlib)==[^#]+' | head -n1 | cut -d'=' -f3) | ||
| echo "JAX_VERSION=$JAX_VERSION" >> "$GITHUB_ENV" |
There was a problem hiding this comment.
Yes let's rework this step (can be in a follow-up). There is too much bash magic here independent of the workflow inputs. We probably want to support building multiple jax versions (at least "stable" and "latest"), so we can't just use the version in a pinned requirements file.
8e96f55 to
537526c
Compare
geomin12
left a comment
There was a problem hiding this comment.
sweet, thanks for the updates! you can also close that other issue as the repository and ref is needed since this will be a part of TheRock releases
ScottTodd
left a comment
There was a problem hiding this comment.
A few small comments but otherwise LGTM
Updated all as suggested. |
Got approval from 2/3 reviewers
## Motivation - Add an automated GPU validation test that verifies built JAX wheel (.whl) files are functional. - We currently lack automated wheel-level GPU validation. This test will catch regressions in packaging, compatibility issues, missing binaries, or other wheel-level failures before wheels are promoted. - Faster feedback, reduces risk of shipping broken GPU wheels. <!-- Explain the purpose of this PR and the goals it aims to achieve. --> ## Technical Details - New workflow: test_linux_jax_wheels.yml - Test type: Core , utils, scipy stats and multi device tests - Purpose: Install built JAX wheels and run tests. - Diagram: <img width="170" height="359" alt="Screenshot 2025-11-23 140535" src="https://github.com/user-attachments/assets/f7b757f3-7013-4253-ae5b-a3a12934e6b6" /> <!-- Explain the changes along with any relevant GitHub links. --> ## Test Plan - Local Verification - CI Verification <!-- Explain any relevant testing done to verify this PR. --> ## Test Result - Workflow run: Workflow run: https://github.com/ROCm/TheRock/actions/runs/19603110752/job/56139806408 <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [x] All tests pass on at least one GPU configuration in CI - [x] Workflow has a manual dispatch input for testing arbitrary wheel URLs - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
Motivation
Technical Details
Test Plan
Test Result
Submission Checklist
All tests pass on at least one GPU configuration in CI
Workflow has a manual dispatch input for testing arbitrary wheel URLs
Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.