
Fix CUTLASS FP8 gemm correctness issue on SM120/SM121 for shapes where N is not divisible by ScaleGranularityN.#2261

Merged
yzh119 merged 4 commits into flashinfer-ai:main from yongwww:sm120_f8_gemm_fix
Dec 24, 2025

Conversation

@yongwww
Member

@yongwww yongwww commented Dec 23, 2025

📌 Description

The SM120 CUTLASS blockwise GEMM kernel requires dimensions such as N to be multiples of 128 due to hardware constraints (https://github.com/NVIDIA/cutlass/blob/3f4c086d09bd1dc55defb955862f333893bbb28b/include/cutlass/gemm/collective/sm120_mma_tma_blockwise_scaling.hpp#L345C5-L346).

We encountered the shape `a: torch.Size([1, 1, 2688]), b: torch.Size([1, 2688, 10304]), scale_a: torch.Size([]), scale_b: torch.Size([]), out: torch.Size([1, 1, 10304]), workspace_buffer: torch.Size([33554432])` from Nemotron-Nano-v3. Since 10304 is not a multiple of 128, the CUTLASS GEMM does not handle it properly. This PR adds padding and slicing to make it work.
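The fix can be illustrated with a standalone sketch in plain PyTorch (not the PR's actual code — `round_up` is a hypothetical stand-in for the PR's `_pad_to_multiple` helper): pad N and K up to 128-element boundaries with zeros, run the aligned problem, then slice the result back to the original N.

```python
import torch
import torch.nn.functional as F

def round_up(x: int, multiple: int = 128) -> int:
    """Round x up to the next multiple of `multiple` (hypothetical helper)."""
    return ((x + multiple - 1) // multiple) * multiple

# Problematic shape from Nemotron-Nano-v3: n = 10304 is not a multiple of 128.
m, k, n = 1, 2688, 10304
k_padded = round_up(k)  # 2688 (already aligned)
n_padded = round_up(n)  # 10368

torch.manual_seed(0)
a = torch.randn(m, k)
b = torch.randn(n, k)  # NT layout: B is (n, k)

# Zero-pad the last (K) dim of A, and both N and K dims of B.
a_padded = F.pad(a, (0, k_padded - k))
b_padded = F.pad(b, (0, k_padded - k, 0, n_padded - n))

# Run the aligned problem, then slice the result back to the original N.
out_padded = a_padded @ b_padded.T
out = out_padded[:, :n]
assert out.shape == (m, n)
```

Zero padding leaves the first `n` output columns numerically unaffected: the padded K entries contribute exact zeros, and the padded N rows only produce extra output columns that are sliced away.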

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

Release Notes

  • Bug Fixes

    • FP8 matrix operations on SM120/SM121 GPUs now support arbitrary input dimensions, removing the previous K dimension minimum requirement and enabling broader use cases.
  • Tests

    • Expanded test coverage for FP8 matrix operations with additional parameter combinations and improved hardware compatibility validation.

✏️ Tip: You can customize this high-level summary in your review settings.

@coderabbitai
Contributor

coderabbitai bot commented Dec 23, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds padding logic for SM120/SM121 CUTLASS groupwise FP8 NT GEMM operations, padding N and K dimensions to 128-element boundaries to support arbitrary dimensions. Removes prior k_dim >= 128 guard by unconditionally enabling the padded kernel path. Test coverage expanded with additional parameter combinations and SM120/SM121 conditional xfail removed.

Changes

Cohort / File(s) Change Summary
Core FP8 GEMM padding implementation
flashinfer/gemm/gemm_base.py
Introduces _pad_to_multiple helper to pad inputs A and B to align N and K to 128 boundaries. Updates gemm_fp8_nt_groupwise forward pass to use padded tensors and slice output back to original dimensions. Removes conditional k_dim >= 128 guard in SM120/SM121 kernel selection. Adjusts scale expansion shape calculations to use padded dimensions in 2D and BMM cases. Modifies _heuristic_func_bmm_fp8 to unconditionally add cutlass_sm12x for SM120/SM121 detection.
Test parameter expansion and guard removal
tests/gemm/test_bmm_fp8.py
Expands test parameter grids for m, n, k dimensions from two to three values per dimension, increasing test combinations. Removes prior conditional xfail block that gated SM120/SM121 CUTLASS path, allowing tests to run for these hardware configurations.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~22 minutes

Possibly related PRs

Suggested reviewers

  • bkryu
  • nvmbreughe
  • djmmoss
  • cyx-6
  • wenscarl
  • yzh119

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Docstring Coverage — ⚠️ Warning: Docstring coverage is 16.67%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (2 passed)
Title check ✅ Passed — The title accurately describes the main change: fixing a CUTLASS FP8 GEMM correctness issue on SM120/SM121 for shapes where N is not a multiple of 128 (ScaleGranularityN).
Description check ✅ Passed — The PR description provides clear context for the fix: SM120 CUTLASS kernel constraints, a concrete example shape from Nemotron-Nano-v3, and the padding/slicing solution.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment


@gemini-code-assist
Contributor

Summary of Changes

Hello @yongwww, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers get up to speed quickly!

This pull request addresses a critical correctness issue within the CUTLASS FP8 GEMM kernel, specifically impacting SM120/SM121 GPUs when processing input shapes where the N dimension is not a multiple of 128. The core of the solution involves implementing a robust padding and slicing strategy for input and output tensors. This ensures that the underlying hardware requirements for blockwise GEMM operations are met, thereby extending the applicability and reliability of FP8 GEMM to a wider range of tensor dimensions without compromising accuracy.

Highlights

  • Correctness Fix for CUTLASS FP8 GEMM: Resolved a correctness issue in the CUTLASS FP8 GEMM kernel on SM120/SM121 architectures where the N dimension of input tensors was not divisible by ScaleGranularityN (128), leading to incorrect results.
  • Padding and Slicing Implementation: Introduced a mechanism to pad the input tensors (A and B) and the output tensor to ensure their N and K dimensions are multiples of 128, satisfying hardware constraints. After the GEMM operation, the result is sliced back to the original N dimension.
  • Relaxed K-dimension Constraint: Removed the k_dim >= 128 constraint for the cutlass_sm12x backend selection in _heuristic_func_bmm_fp8, as the new padding logic now handles smaller K values gracefully.
  • Expanded Test Coverage: Updated test_bmm_fp8 to include problematic N and K dimensions (e.g., N=10304, K=2688) and removed the pytest.xfail marker for SM120/121 CUTLASS bmm_fp8, confirming the fix.


@yongwww
Member Author

yongwww commented Dec 23, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !214 has been created, and the CI pipeline #40704245 is currently running. I'll report back once the pipeline job completes.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request effectively addresses a correctness issue with FP8 GEMM on SM120/SM121 architectures for specific input shapes by introducing padding for the 'N' and 'K' dimensions. The implementation is sound, handling both 2D and 3D tensors correctly, and includes necessary updates to tests, which now cover the previously failing cases. I have one suggestion to refactor a small piece of duplicated code to improve maintainability, but overall, the changes are solid and well-executed.

Comment on lines +267 to +292
```python
if a.dim() == 2:
    a_padded = a
    if needs_k_padding:
        a_padded = torch.nn.functional.pad(
            a_padded.contiguous(), (0, k_padded - k_dim)
        )
    b_col_major_padded = torch.zeros(
        (n_padded, k_padded),
        dtype=b_col_major.dtype,
        device=b_col_major.device,
    )
    b_col_major_padded[:n_dim, :k_dim].copy_(b_col_major)
else:
    a_padded = a
    if needs_k_padding:
        a_padded = torch.nn.functional.pad(
            a_padded.contiguous(), (0, k_padded - k_dim)
        )

    b_underlying_padded = torch.zeros(
        (batch_size, n_padded, k_padded),
        dtype=b_col_major.dtype,
        device=b_col_major.device,
    )
    b_col_major_padded = b_underlying_padded.transpose(-2, -1)
    b_col_major_padded[:, :k_dim, :n_dim].copy_(b_col_major)
```
Contributor


medium

There's some code duplication in how a_padded is handled for 2D and 3D cases. You can hoist the padding logic for a out of the if a.dim() == 2: block to avoid repetition and improve maintainability.

```python
a_padded = a
if needs_k_padding:
    a_padded = torch.nn.functional.pad(
        a_padded.contiguous(), (0, k_padded - k_dim)
    )

if a.dim() == 2:
    b_col_major_padded = torch.zeros(
        (n_padded, k_padded),
        dtype=b_col_major.dtype,
        device=b_col_major.device,
    )
    b_col_major_padded[:n_dim, :k_dim].copy_(b_col_major)
else:
    b_underlying_padded = torch.zeros(
        (batch_size, n_padded, k_padded),
        dtype=b_col_major.dtype,
        device=b_col_major.device,
    )
    b_col_major_padded = b_underlying_padded.transpose(-2, -1)
    b_col_major_padded[:, :k_dim, :n_dim].copy_(b_col_major)
```

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/gemm/gemm_base.py (1)

3415-3415: Remove debug print statement.

A debug print statement (print("GOT HERE")) has been left in the production code. This should be removed before merging as it will pollute logs in production environments.

🔎 Proposed fix
```diff
-    print("GOT HERE")
     m_grouped_fp8_gemm_nt_contiguous(
         (a, a_scale), (b, b_scale), out, m_indices, scale_granularity_mnk
     )
```
🧹 Nitpick comments (1)
flashinfer/gemm/gemm_base.py (1)

266-293: Consider optimizing the contiguous call and documenting padding overhead.

The padding logic is correct, but there are a few considerations:

  1. Lines 271 and 282: The .contiguous() calls may be redundant if the tensors are already contiguous. Consider checking a.is_contiguous() first to avoid unnecessary copies.

  2. Zero tensor allocation: Creating zero-padded tensors (lines 273-278, 286-292) can be memory-intensive when padding adds significant dimensions. For example, if n=10304 is padded to n_padded=10368, this adds ~0.6% overhead, but if n=129 is padded to n_padded=256, this nearly doubles the memory and compute.

While this approach is correct and necessary for hardware constraints, consider adding a debug log or warning when padding overhead exceeds a threshold (e.g., >20% increase in dimensions).
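One possible shape of that guarded fast path (a hypothetical helper illustrating the reviewer's suggestion, not code from the PR):

```python
import torch

def maybe_pad_k(a: torch.Tensor, k_padded: int) -> torch.Tensor:
    """Pad the last (K) dim of `a` up to k_padded with zeros.
    Fast path: return the input unchanged when already aligned,
    and call .contiguous() only on non-contiguous tensors."""
    k = a.shape[-1]
    if k == k_padded:
        return a  # no copy, no allocation
    if not a.is_contiguous():
        a = a.contiguous()
    return torch.nn.functional.pad(a, (0, k_padded - k))

a = torch.randn(4, 96)
assert maybe_pad_k(a, 128).shape == (4, 128)
assert maybe_pad_k(a, 96) is a  # aligned input passes through untouched
```

The `is` check in the last line verifies the zero-copy fast path: when K is already a multiple of the granularity, no padding work is performed at all.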

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 25de38e and d81cc5d.

📒 Files selected for processing (2)
  • flashinfer/gemm/gemm_base.py
  • tests/gemm/test_bmm_fp8.py
🧰 Additional context used
🧠 Learnings
📓 Common learnings
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • flashinfer/gemm/gemm_base.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (7)
flashinfer/gemm/gemm_base.py (6)

248-260: LGTM! Clear padding logic for SM120 hardware constraints.

The helper function and padding computation correctly align N and K to 128-element boundaries as required by the SM120 CUTLASS blockwise scaling kernel.


294-308: LGTM! Efficient output padding strategy.

The code correctly creates a padded output tensor only when N-dimension padding is needed, avoiding unnecessary allocations for the common case.


309-334: LGTM! Scale tensor expansion correctly uses padded dimensions.

The scale computations properly use k_padded and n_padded to match the dimensions that the kernel will operate on, ensuring correct scaling behavior for the padded problem.


336-349: LGTM! Kernel invocation uses padded tensors correctly.

The call to gemm_fp8_nt_groupwise appropriately passes the padded tensors (a_padded, b_col_major_padded, out_padded) and scale parameters, ensuring the kernel operates on properly aligned dimensions.


350-356: Result slicing is correct; be aware of copy overhead.

The slicing logic correctly restores the original output dimensions when padding was applied. The copy_ operation is necessary to write results back to the user-provided output tensor, though it does add some overhead. This is an acceptable trade-off for correctness and API compatibility.


2337-2364: LGTM! Heuristic correctly reflects padding support.

Removing the k_dim >= 128 guard for SM120/SM121 is correct given the padding implementation. The comment clearly documents that padding now enables support for all K values. This aligns with the PR objective to handle shapes like k=2688 (not divisible by 128).
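The relaxed backend selection can be sketched as follows (hypothetical function and return values chosen for illustration; the real logic lives in `_heuristic_func_bmm_fp8` in `flashinfer/gemm/gemm_base.py`):

```python
def select_backends(sm_version: tuple, k_dim: int) -> list:
    """Hypothetical sketch of the relaxed heuristic: with padding in place,
    the CUTLASS SM12x path is offered on SM120/SM121 for any K."""
    backends = ["cublas"]
    if sm_version in ((12, 0), (12, 1)):
        # Previously gated on k_dim >= 128; padding now handles small K.
        backends.append("cutlass_sm12x")
    return backends

assert "cutlass_sm12x" in select_backends((12, 0), 64)
assert select_backends((9, 0), 4096) == ["cublas"]
```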

tests/gemm/test_bmm_fp8.py (1)

11-13: Excellent test coverage expansion for padding validation.

The expanded parameter ranges effectively test the padding implementation:

  • n=80: Tests N-dimension padding (80 → 128, ~60% overhead)
  • n=10304: Tests the specific Nemotron-Nano-v3 case mentioned in the PR (10304 → 10368, minimal overhead)
  • k=64: Tests K-dimension padding for small K (64 → 128, 100% overhead)
  • k=2688: Tests an already-aligned K dimension (2688 % 128 = 0)

This combination will exercise both the padding path and the fast path when no padding is needed, providing good validation of the fix.

Collaborator

@bkryu bkryu left a comment


Thanks @yongwww the PR makes sense to me.

Let's wait for the unit test results to come back.

Collaborator

@yzh119 yzh119 left a comment


While I'm concerned about the performance of padding, at least it fixes the functionality issue.

Thanks for working on this PR.

@yzh119 yzh119 merged commit f3e036d into flashinfer-ai:main Dec 24, 2025
4 checks passed
@yongwww yongwww deleted the sm120_f8_gemm_fix branch December 24, 2025 21:28
bkryu pushed a commit that referenced this pull request Dec 25, 2025
<!-- .github/pull_request_template.md -->

## 📌 Description

saw some [test
failures](https://gitlab-master.nvidia.com/dl/flashinfer/flashinfer-ci/-/jobs/247866505)
on Blackwell boards after #2261, all the failed assertions are related
to the large value 10304.

Use `.float()` to reduce precision loss in the `cosine_similarity`
(`dot(x, y) / (||x|| * ||y||)`) check.

```
FAILED tests/gemm/test_bmm_fp8.py::test_bmm_fp8[True-cutlass-res_dtype1-mat2_dtype0-input_dtype0-256-10304-128-16] - AssertionError: assert tensor(0., device='cuda:0') > 0.99
2025-12-24T07:00:08.299846Z 01O FAILED tests/gemm/test_bmm_fp8.py::test_bmm_fp8[False-cudnn-res_dtype1-mat2_dtype0-input_dtype1-256-10304-128-16] - AssertionError: assert tensor(0., device='cuda:0') > 0.99
... # the failure occurs for all backend (cutlass, cudnn, etc)
```
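The upcasting pattern looks like this (a standalone sketch, not the test's actual code; the vector length mirrors the failing n=10304 case, and the noise level is an arbitrary choice for illustration):

```python
import torch
import torch.nn.functional as F

# Accumulating dot products and norms over ~10k bf16 elements loses
# precision; upcasting to fp32 before cosine_similarity avoids that.
torch.manual_seed(0)
x = torch.randn(10304, dtype=torch.bfloat16)
y = x + 0.01 * torch.randn(10304, dtype=torch.bfloat16)  # near-identical vector

sim = F.cosine_similarity(x.float(), y.float(), dim=0)
assert sim.item() > 0.99  # similarity check now passes reliably
```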

cc: @zihaoye @bkryu 


## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Tests**
* Improved test accuracy by ensuring tensor comparisons use
floating-point precision for cosine similarity calculations.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
