Lowering: update fp64 dimension checks to match Triton support (#35654)
Conversation
Code Review
This pull request correctly addresses issues with fp64 dot products in Triton lowering. The changes include fixing a dtype degradation bug in pallas.primitives.dot, correcting a shape validation check, and replacing a restrictive dimension check with a more precise, hardware-aware one for float64 operations. The code is clear, well-commented, and includes a regression test for the valid dimension case. I've added one suggestion to also include a test case for invalid dimensions to ensure the new error-raising logic is covered.
Note: Security Review did not run due to the size of the PR.
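For reference, the suggested invalid-dimensions test could look roughly like this. This is a sketch only: it assumes the test file's existing imports (`jax`, `jax.numpy as jnp`, `pallas as pl`), and the raised error type and message are assumptions, not the PR's actual code.

```python
def test_dot_fp64_invalid_dimensions(self):
    # 8x8 blocks fall below the float64 MMAv2 minimum (M>=16, N>=8, K>=16),
    # so lowering should now raise a descriptive error instead of segfaulting.
    def kernel(x_ref, y_ref, o_ref):
        o_ref[...] = pl.dot(x_ref[...], y_ref[...])

    x = jnp.ones((8, 8), jnp.float64)
    with self.assertRaisesRegex(Exception, "float64"):  # message is assumed
        pl.pallas_call(
            kernel,
            out_shape=jax.ShapeDtypeStruct((8, 8), jnp.float64),
        )(x, x)
```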
Force-pushed from e735695 to 35424bd: …enforce minimum dimensions for float64 dot operation; test: add validation for float64 dot dimensions.
/gemini review
Code Review
This pull request addresses issues with fp64 dot product lowering in Triton. It correctly sets the accumulator dtype to float64 when appropriate, preventing dtype degradation. It also replaces a conservative dimension check with a more precise, hardware-informed check for fp64 operations, while also fixing a separate bug in a 2D shape validation check. The changes are accompanied by thorough regression tests. My review includes a suggestion to improve the maintainability of the new tests by making the hardware capability checks more explicit and less dependent on test method names.
Note: Security Review did not run due to the size of the PR.
```python
is_sm80_test = any(
    getattr(self, "_testMethodName", "").startswith(prefix)
    for prefix in (
        "test_dot_f32_small_dimensions",
        "test_dot_fp64_valid_dimensions",
        "test_dot_fp64_invalid_dimensions",
    )
)
min_compute = "8.0" if is_sm80_test else "9.0"
```
This method of checking test names in setUp to determine compute capability requirements is a bit fragile and can make test maintenance harder. If test methods are renamed, this logic will break silently.
A more robust approach would be to move the capability check into the test methods themselves, or use a decorator. For example, you could define a helper method:
```python
def _require_compute_capability(self, min_version_str):
    if jtu.test_device_matches(["cuda"]) and not jtu.is_cuda_compute_capability_at_least(min_version_str):
        self.skipTest(f"Requires CUDA compute capability >= {min_version_str}")
```

And then call it at the beginning of each relevant test:
```python
def test_dot_f32_small_dimensions(self):
    self._require_compute_capability("8.0")
    # ...
```

This would make the requirements for each test more explicit and avoid coupling the test logic to test names.
@gemini-code-assist While I agree that this would be a better approach, doing it properly would require restructuring all the tests in that file this way, which goes well beyond the scope of this PR. I would rather do it in a separate PR.
The PR is ready for (human) review. I would appreciate a run of the full test suite; I could only verify that it works on an A100.
rdyro left a comment
This looks good to me, let's run the tests, thanks!
The limits for fp64 MMA in this PR could be lower, but there is a bug in Triton that causes compiler segfaults below these sizes (triton-lang/triton#7310). I don't get why the ROCm tests failed; it does not look related to the changes in this PR.
Indirectly fixes #35529
The problem reported there was occurring during a small tensor contraction.
This PR updates the Triton lowering for `dot_general` so that MMA instructions are used. The previous assertions were too conservative -- Triton supports smaller MMAs. I have also found a bug in `jax/_src/pallas/primitives.py`, which is fixed now. Related Triton bug: triton-lang/triton#7310
Details
- Fixed a bug in `jax/_src/pallas/primitives.py` where `pl.dot` would naively default the output accumulator to `float32` for non-integer types, even when inputs were `float64`. This led to compilation errors in Triton, as the hardware layouts for mixed-precision `f64 x f64 -> f32` were not correctly mapped.
- Removed the overly conservative `dim >= 16` check in `jax/_src/pallas/triton/lowering.py` and replaced it with a hardware-informed check for `float64`. While `float32`/`float16` can handle smaller blocks via FMA fallback, `float64` on NVIDIA MMAv2 requires a minimum of `M >= 16, N >= 8, K >= 16` per warp tile to avoid compiler segfaults. (A sketch of both fixes follows this list.)
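For concreteness, a minimal sketch of the two fixes described above; the function names and the exact error message are illustrative assumptions, not the PR's actual code:

```python
import jax.numpy as jnp

def _accumulator_dtype(lhs_dtype, rhs_dtype):
    # Before the fix, non-integer inputs always accumulated in float32;
    # the fix keeps float64 when both inputs are float64.
    if lhs_dtype == jnp.float64 and rhs_dtype == jnp.float64:
        return jnp.float64
    if jnp.issubdtype(lhs_dtype, jnp.integer):
        return jnp.int32
    return jnp.float32

def _check_fp64_mma_dims(m, n, k):
    # NVIDIA MMAv2 needs M>=16, N>=8, K>=16 per warp tile for float64;
    # smaller shapes segfault the Triton compiler, so raise instead.
    if m < 16 or n < 8 or k < 16:
        raise ValueError(
            "float64 dot requires M>=16, N>=8, K>=16 on NVIDIA MMAv2; "
            f"got M={m}, N={n}, K={k}"
        )
```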
Changes:
- `jax/_src/pallas/primitives.py`: Updated `dot` to preserve `float64` as the preferred element type when inputs are `float64`.
- `jax/_src/pallas/triton/lowering.py`: Removed the blanket `min(shape) < 16` restriction; added a dedicated `float64` check to match NVIDIA MMAv2 layout requirements (`M >= 16, N >= 8, K >= 16`).
- `tests/pallas/triton_pallas_test.py`: Added a regression test `test_dot_fp64_valid_dimensions` to verify correct `float64` behavior on supported shapes.
Summary:
- fp64 dot products in Pallas Triton kernels now lower correctly under `JAX_ENABLE_X64=1`.
- Invalid `float64` shapes now raise a descriptive `ValueError` instead of segfaulting the compiler.
- `float32` and `float16` can now use shapes smaller than 16 (e.g., 8x8x8).
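As a usage illustration, a minimal kernel of the kind this PR enables; this is a sketch, not the PR's regression test, with shapes chosen to satisfy the fp64 minimums:

```python
import jax
jax.config.update("jax_enable_x64", True)  # equivalent to JAX_ENABLE_X64=1

import jax.numpy as jnp
from jax.experimental import pallas as pl

def dot_kernel(x_ref, y_ref, o_ref):
    # pl.dot now keeps float64 inputs in a float64 accumulator.
    o_ref[...] = pl.dot(x_ref[...], y_ref[...])

x = jnp.ones((16, 16), jnp.float64)  # 16x16x16 satisfies M>=16, N>=8, K>=16
y = jnp.ones((16, 16), jnp.float64)
out = pl.pallas_call(
    dot_kernel,
    out_shape=jax.ShapeDtypeStruct((16, 16), jnp.float64),
)(x, y)
```

Shapes below the fp64 minimum (e.g. 8x8) would instead raise the new descriptive `ValueError`.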