
Lowering: update fp64 dimension checks to match Triton support #35654

Merged: copybara-service[bot] merged 1 commit into jax-ml:main from mwichro:extend_triton_dot_PR on Mar 9, 2026

Conversation

@mwichro mwichro (Contributor) commented Mar 5, 2026

Indirectly fixes #35529.
The problem reported there occurred during a small tensor contraction. This PR updates the Triton lowering for dot_general so that MMA instructions are used: the previous assertions were too conservative, since Triton supports smaller MMAs. It also fixes a bug I found in jax/_src/pallas/primitives.py.

Related Triton issue: triton-lang/triton#7310

Details

  1. DType degradation: Fixed a bug in jax/_src/pallas/primitives.py where pl.dot naively defaulted the output accumulator to float32 for all non-integer types, even when the inputs were float64. This led to Triton compilation errors, since no hardware layout exists for a mixed-precision f64 x f64 -> f32 MMA (see the sketch after this list).
  2. Hardware layout constraints: Relaxed the artificial dim >= 16 check in jax/_src/pallas/triton/lowering.py, replacing it with a hardware-informed check for float64. While float32/float16 can handle smaller blocks via the FMA fallback, float64 on NVIDIA MMAv2 requires at least M>=16, N>=8, K>=16 per warp tile to avoid compiler segfaults (a guard sketch follows the Changes list below).
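
To make the accumulator fix concrete, here is a minimal sketch of the inference logic; the helper name is hypothetical, and the actual change lives in pl.dot in jax/_src/pallas/primitives.py:

import jax.numpy as jnp

def _infer_acc_dtype(lhs_dtype, rhs_dtype):  # hypothetical helper
  # Integer inputs keep an int32 accumulator (unchanged behavior).
  if jnp.issubdtype(lhs_dtype, jnp.integer):
    return jnp.int32
  # float64 inputs must accumulate in float64: Triton has no hardware
  # layout for a mixed-precision f64 x f64 -> f32 MMA.
  if lhs_dtype == jnp.float64 or rhs_dtype == jnp.float64:
    return jnp.float64
  # All other float inputs (f16/bf16/f32) default to a float32 accumulator.
  return jnp.float32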

Changes:

  • jax/_src/pallas/primitives.py: Updated dot to preserve float64 as the preferred element type when inputs are float64.
  • jax/_src/pallas/triton/lowering.py:
    • Removed the blanket min(shape) < 16 restriction.
    • Added specific guards for float64 to match NVIDIA MMAv2 layout requirements (M>=16, N>=8, K>=16).
    • Added detailed comments explaining the relationship between PTX instructions and JAX/Triton register counts.
  • tests/pallas/triton_pallas_test.py: Added a regression test test_dot_fp64_valid_dimensions to verify correct float64 behavior on supported shapes.
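
A minimal sketch of the hardware-informed guard described above (the name and error message are illustrative, not the exact code in jax/_src/pallas/triton/lowering.py):

def _check_fp64_dot_dims(m, n, k):  # hypothetical helper
  # NVIDIA MMAv2 f64 tiles need at least M>=16, N>=8, K>=16 per warp
  # tile; smaller f64 shapes can segfault the Triton compiler
  # (triton-lang/triton#7310), so fail early with a clear error.
  if m < 16 or n < 8 or k < 16:
    raise ValueError(
        "float64 dot requires M >= 16, N >= 8, K >= 16; got "
        f"M={m}, N={n}, K={k}")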

Summary:

  • All tests passed on A100 GPU with JAX_ENABLE_X64=1.
  • Verified that invalid float64 shapes now raise a descriptive ValueError instead of segfaulting the compiler.
  • Verified that float32 and float16 can now use shapes smaller than 16 (e.g., 8x8x8).
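
For reference, a kernel along these lines exercises the fixed fp64 path (a sketch based on the description above, not the actual regression test; it requires an NVIDIA GPU and x64 mode):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

jax.config.update("jax_enable_x64", True)

def dot_kernel(x_ref, y_ref, o_ref):
  o_ref[...] = pl.dot(x_ref[...], y_ref[...])

x = jnp.ones((16, 16), jnp.float64)
y = jnp.ones((16, 16), jnp.float64)
out = pl.pallas_call(
    dot_kernel,
    out_shape=jax.ShapeDtypeStruct((16, 16), jnp.float64),
)(x, y)
# M=N=K=16 satisfies the MMAv2 f64 minimum (M>=16, N>=8, K>=16); an
# 8x8x8 float64 dot should now raise a ValueError instead of
# crashing the compiler.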

@google-cla google-cla Bot commented Mar 5, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@gemini-code-assist gemini-code-assist Bot commented

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances JAX's Pallas Triton lowering for dot_general operations, specifically improving float64 support and flexibility for matrix multiplications. It addresses a critical bug where float64 inputs were incorrectly cast to float32 in pl.dot and refines dimension checks to better match Triton's hardware capabilities, allowing smaller block sizes for float32/float16 while enforcing necessary constraints for float64 to ensure stability and prevent compiler failures.

Highlights

  • DType Degradation Fix: Corrected jax/_src/pallas/primitives.py to prevent float64 inputs from degrading to float32 in pl.dot output accumulators, resolving Triton compilation errors.
  • Relaxed Dimension Constraints: Updated jax/_src/pallas/triton/lowering.py by removing an overly conservative min(shape) < 16 restriction and implementing hardware-specific float64 dimension checks (M>=16, N>=8, K>=16) to align with NVIDIA MMAv2 requirements and prevent compiler segfaults.
  • New Regression Test: Added test_dot_fp64_valid_dimensions in tests/pallas/triton_pallas_test.py to validate correct float64 behavior for supported shapes.


Changelog
  • jax/_src/pallas/primitives.py
    • Modified dot function to correctly infer float64 as the output accumulator type when inputs are float64, preventing unintended float32 degradation.
  • jax/_src/pallas/triton/lowering.py
    • Removed the general min(shape) < 16 restriction for dot operations.
    • Introduced specific dimension checks for float64 inputs (M>=16, N>=8, K>=16) to align with NVIDIA MMAv2 requirements.
    • Added detailed comments explaining the rationale behind the float64 dimension constraints, linking them to PTX instructions and Triton's register usage.
    • Updated acc_dtype inference to include float64 as a valid accumulator type.
  • tests/pallas/triton_pallas_test.py
    • Added test_dot_fp64_valid_dimensions to verify float64 dot operations with valid dimensions, ensuring correct behavior and preventing regressions.

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request correctly addresses issues with fp64 dot products in Triton lowering. The changes include fixing a dtype degradation bug in pallas.primitives.dot, correcting a shape validation check, and replacing a restrictive dimension check with a more precise, hardware-aware one for float64 operations. The code is clear, well-commented, and includes a regression test for the valid dimension case. I've added one suggestion to also include a test case for invalid dimensions to ensure the new error-raising logic is covered.

Note: Security Review did not run due to the size of the PR.

Comment thread: tests/pallas/triton_pallas_test.py

Commits: …enforce minimum dimensions for float64 dot operation; test: add validation for float64 dot dimensions
@mwichro mwichro force-pushed the extend_triton_dot_PR branch from e735695 to 35424bd on March 5, 2026 at 23:26
@mwichro mwichro (Contributor, Author) commented Mar 5, 2026

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request addresses issues with fp64 dot product lowering in Triton. It correctly sets the accumulator dtype to float64 when appropriate, preventing dtype degradation. It also replaces a conservative dimension check with a more precise, hardware-informed check for fp64 operations, while also fixing a separate bug in a 2D shape validation check. The changes are accompanied by thorough regression tests. My review includes a suggestion to improve the maintainability of the new tests by making the hardware capability checks more explicit and less dependent on test method names.

Note: Security Review did not run due to the size of the PR.

Comment on lines +50 to +58:

is_sm80_test = any(
    getattr(self, "_testMethodName", "").startswith(prefix)
    for prefix in (
        "test_dot_f32_small_dimensions",
        "test_dot_fp64_valid_dimensions",
        "test_dot_fp64_invalid_dimensions",
    )
)
min_compute = "8.0" if is_sm80_test else "9.0"
@gemini-code-assist (Contributor) commented, severity: medium

This method of checking test names in setUp to determine compute capability requirements is a bit fragile and can make test maintenance harder. If test methods are renamed, this logic will break silently.

A more robust approach would be to move the capability check into the test methods themselves, or use a decorator. For example, you could define a helper method:

def _require_compute_capability(self, min_version_str):
  if jtu.test_device_matches(['cuda']) and not jtu.is_cuda_compute_capability_at_least(min_version_str):
    self.skipTest(f"Requires CUDA compute capability >= {min_version_str}")

And then call it at the beginning of each relevant test:

def test_dot_f32_small_dimensions(self):
  self._require_compute_capability("8.0")
  # ...

This would make the requirements for each test more explicit and avoid coupling the test logic to test names.
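
For completeness, the decorator variant mentioned above could look like this (a sketch reusing the jtu helpers from the suggestion; it is not part of the PR):

import functools

def require_compute_capability(min_version_str):
  def decorator(test_fn):
    @functools.wraps(test_fn)
    def wrapper(self, *args, **kwargs):
      # Skip on CUDA devices older than the required compute capability.
      if jtu.test_device_matches(["cuda"]) and not (
          jtu.is_cuda_compute_capability_at_least(min_version_str)):
        self.skipTest(
            f"Requires CUDA compute capability >= {min_version_str}")
      return test_fn(self, *args, **kwargs)
    return wrapper
  return decorator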

@mwichro mwichro (Contributor, Author) commented Mar 6, 2026

@gemini-code-assist While I agree that this would be a better approach, doing it properly would mean restructuring all the tests in that file the same way, which goes well beyond the scope of this PR. I would rather do it in a separate PR.

@mwichro mwichro (Contributor, Author) commented Mar 6, 2026

The PR is ready for (human) review. I would appreciate a run of the full test suite; I could only verify that it works on an A100.

@superbobry superbobry requested review from chr1sj0nes and rdyro March 8, 2026 15:46
@rdyro rdyro (Collaborator) left a comment

This looks good to me, let's run the tests, thanks!

@google-ml-butler google-ml-butler Bot added kokoro:force-run pull ready Ready for copybara import and testing labels Mar 8, 2026
@mwichro mwichro (Contributor, Author) commented Mar 8, 2026

The fp64 MMA limits in this PR could be lower, but a bug in Triton causes m8n8k4.row.col.f64.f64.f64.f64 to throw a segmentation fault. I will reproduce it and file an issue with Triton, and maybe even try to fix it.

I don't understand why the ROCm tests failed; the failures do not look related to the changes in this PR.

@copybara-service copybara-service Bot merged commit fb6abfe into jax-ml:main Mar 9, 2026
29 of 30 checks passed

Labels

kokoro:force-run; pull ready (Ready for copybara import and testing)

Development

Successfully merging this pull request may close this issue:

Severe (5-10x) performance regression in JAX/Pallas Triton kernel (JAX 0.6.2 vs 0.8.0+)
2 participants