Fix CUTLASS FP8 GEMM correctness issue on SM120/SM121 for shapes where N is not divisible by ScaleGranularityN (#2261)
Conversation
> **Note:** CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

**Walkthrough:** Adds padding logic for SM120/SM121 CUTLASS groupwise FP8 NT GEMM operations, padding the N and K dimensions to 128-element boundaries to support arbitrary input shapes. Removes the prior `k_dim >= 128` guard for SM120/SM121.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~22 minutes
Pre-merge checks: ❌ 1 failed (warning) · ✅ 2 passed
**Summary of Changes**

Hello @yongwww, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request addresses a correctness issue within the CUTLASS FP8 GEMM kernel, specifically impacting SM120/SM121 GPUs when processing input shapes where the N dimension is not a multiple of 128. The core of the solution is a padding and slicing strategy for input and output tensors. This satisfies the hardware alignment requirements of the blockwise GEMM kernel, extending FP8 GEMM to a wider range of tensor dimensions without compromising accuracy.
/bot run
Code Review
This pull request effectively addresses a correctness issue with FP8 GEMM on SM120/SM121 architectures for specific input shapes by introducing padding for the 'N' and 'K' dimensions. The implementation is sound, handling both 2D and 3D tensors correctly, and includes necessary updates to tests, which now cover the previously failing cases. I have one suggestion to refactor a small piece of duplicated code to improve maintainability, but overall, the changes are solid and well-executed.
```python
if a.dim() == 2:
    a_padded = a
    if needs_k_padding:
        a_padded = torch.nn.functional.pad(
            a_padded.contiguous(), (0, k_padded - k_dim)
        )
    b_col_major_padded = torch.zeros(
        (n_padded, k_padded),
        dtype=b_col_major.dtype,
        device=b_col_major.device,
    )
    b_col_major_padded[:n_dim, :k_dim].copy_(b_col_major)
else:
    a_padded = a
    if needs_k_padding:
        a_padded = torch.nn.functional.pad(
            a_padded.contiguous(), (0, k_padded - k_dim)
        )

    b_underlying_padded = torch.zeros(
        (batch_size, n_padded, k_padded),
        dtype=b_col_major.dtype,
        device=b_col_major.device,
    )
    b_col_major_padded = b_underlying_padded.transpose(-2, -1)
    b_col_major_padded[:, :k_dim, :n_dim].copy_(b_col_major)
```
There's some code duplication in how `a_padded` is handled for the 2D and 3D cases. You can hoist the padding logic for `a` out of the `if a.dim() == 2:` block to avoid repetition and improve maintainability:
```python
a_padded = a
if needs_k_padding:
    a_padded = torch.nn.functional.pad(
        a_padded.contiguous(), (0, k_padded - k_dim)
    )
if a.dim() == 2:
    b_col_major_padded = torch.zeros(
        (n_padded, k_padded),
        dtype=b_col_major.dtype,
        device=b_col_major.device,
    )
    b_col_major_padded[:n_dim, :k_dim].copy_(b_col_major)
else:
    b_underlying_padded = torch.zeros(
        (batch_size, n_padded, k_padded),
        dtype=b_col_major.dtype,
        device=b_col_major.device,
    )
    b_col_major_padded = b_underlying_padded.transpose(-2, -1)
    b_col_major_padded[:, :k_dim, :n_dim].copy_(b_col_major)
```
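For reference, the round-up that both branches rely on for `n_padded`/`k_padded` can be sketched as follows (the helper name `round_up_128` is illustrative; the PR's actual helper may be named differently):

```python
def round_up_128(x: int, multiple: int = 128) -> int:
    """Round x up to the nearest multiple of `multiple` (128 matches the
    SM120 CUTLASS blockwise-scaling alignment requirement)."""
    return ((x + multiple - 1) // multiple) * multiple

# The Nemotron-Nano-v3 shape from the PR description:
n_padded = round_up_128(10304)  # -> 10368
k_padded = round_up_128(2688)   # -> 2688 (already 21 * 128, no padding needed)
```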
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/gemm/gemm_base.py (1)

**3415: Remove debug print statement.** A debug print statement (`print("GOT HERE")`) has been left in the production code. It should be removed before merging, as it will pollute logs in production environments.

🔎 Proposed fix:

```diff
- print("GOT HERE")
  m_grouped_fp8_gemm_nt_contiguous(
      (a, a_scale), (b, b_scale), out, m_indices, scale_granularity_mnk
  )
```
🧹 Nitpick comments (1)

flashinfer/gemm/gemm_base.py (1)

**266-293: Consider optimizing the contiguous call and documenting padding overhead.** The padding logic is correct, but there are a few considerations:

- Lines 271 & 282: the `.contiguous()` calls may be redundant if the tensors are already contiguous. Consider checking `a.is_contiguous()` first to avoid unnecessary copies.
- Zero tensor allocation: creating zero-padded tensors (lines 273-278, 286-292) can be memory-intensive when padding adds significant dimensions. For example, if `n=10304` is padded to `n_padded=10368`, this adds ~0.6% overhead, but if `n=129` is padded to `n_padded=256`, it nearly doubles the memory and compute.

While this approach is correct and necessary for the hardware constraints, consider adding a debug log or warning when padding overhead exceeds a threshold (e.g., >20% increase in a dimension).
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)

- flashinfer/gemm/gemm_base.py
- tests/gemm/test_bmm_fp8.py
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
flashinfer/gemm/gemm_base.py
⏰ Context from checks skipped due to timeout of 90000ms (1): GitHub Check: Deploy Docs
🔇 Additional comments (7)
flashinfer/gemm/gemm_base.py (6)

**248-260: LGTM! Clear padding logic for SM120 hardware constraints.** The helper function and padding computation correctly align N and K to 128-element boundaries, as required by the SM120 CUTLASS blockwise scaling kernel.
**294-308: LGTM! Efficient output padding strategy.** The code creates a padded output tensor only when N-dimension padding is needed, avoiding unnecessary allocations in the common case.
**309-334: LGTM! Scale tensor expansion correctly uses padded dimensions.** The scale computations properly use `k_padded` and `n_padded` to match the dimensions the kernel will operate on, ensuring correct scaling behavior for the padded problem.
**336-349: LGTM! Kernel invocation uses padded tensors correctly.** The call to `gemm_fp8_nt_groupwise` appropriately passes the padded tensors (`a_padded`, `b_col_major_padded`, `out_padded`) and scale parameters, ensuring the kernel operates on properly aligned dimensions.
**350-356: Result slicing is correct; be aware of copy overhead.** The slicing logic correctly restores the original output dimensions when padding was applied. The `copy_` operation is necessary to write results back into the user-provided output tensor, though it adds some overhead; this is an acceptable trade-off for correctness and API compatibility.
**2337-2364: LGTM! Heuristic correctly reflects padding support.** Removing the `k_dim >= 128` guard for SM120/SM121 is correct given the padding implementation. The comment clearly documents that padding now enables support for all K values. This aligns with the PR objective of handling shapes like `n=10304` (not divisible by 128).

tests/gemm/test_bmm_fp8.py (1)
**11-13: Excellent test coverage expansion for padding validation.** The expanded parameter ranges effectively test the padding implementation:

- `n=80`: tests N-dimension padding (80 → 128, 60% overhead)
- `n=10304`: tests the Nemotron-Nano-v3 case mentioned in the PR (10304 → 10368, minimal overhead)
- `k=64`: tests K-dimension padding for small K (64 → 128, 100% overhead)
- `k=2688`: tests an already-aligned K dimension (2688 % 128 = 0)

This combination exercises both the padding path and the fast path when no padding is needed, providing good validation of the fix.
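The overhead figures quoted in the review comments can be checked directly (ceil-to-128 is the padding rule described in the PR; `overhead` here is just the relative size increase):

```python
def overhead(x: int, pad: int = 128) -> float:
    """Relative size increase when x is padded up to a multiple of `pad`."""
    padded = -(-x // pad) * pad  # ceiling division, then scale back up
    return (padded - x) / x

print(f"n=80:    {overhead(80):.1%}")     # 60.0%
print(f"n=10304: {overhead(10304):.1%}")  # 0.6%
print(f"k=64:    {overhead(64):.1%}")     # 100.0%
print(f"k=2688:  {overhead(2688):.1%}")   # 0.0%
```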
yzh119
left a comment
While I'm concerned about the performance of padding, at least it fixes the functionality issue.
Thanks for working on this PR.
## 📌 Description

Saw some [test failures](https://gitlab-master.nvidia.com/dl/flashinfer/flashinfer-ci/-/jobs/247866505) on Blackwell boards after #2261; all the failed assertions are related to the large value 10304. Use `.float()` to help reduce precision loss during the `cosine_similarity` (`dot(x, y) / (||x|| * ||y||)`) check.

```
FAILED tests/gemm/test_bmm_fp8.py::test_bmm_fp8[True-cutlass-res_dtype1-mat2_dtype0-input_dtype0-256-10304-128-16] - AssertionError: assert tensor(0., device='cuda:0') > 0.99
2025-12-24T07:00:08.299846Z 01O FAILED tests/gemm/test_bmm_fp8.py::test_bmm_fp8[False-cudnn-res_dtype1-mat2_dtype0-input_dtype1-256-10304-128-16] - AssertionError: assert tensor(0., device='cuda:0') > 0.99
...
# the failure occurs for all backends (cutlass, cudnn, etc.)
```

cc: @zihaoye @bkryu

## 🚀 Pull Request Checklist

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used my preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

### 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

- **Tests**: Improved test accuracy by ensuring tensor comparisons use floating-point precision for cosine similarity calculations.
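The precision issue behind those failures is that computing cosine similarity in a low-precision dtype can lose the signal entirely over a large reduction (N = 10304 here). A NumPy sketch of the upcast-first pattern (float16 stands in for the low-precision tensors; this is an illustration of the idea, not the test code):

```python
import numpy as np

def cosine_similarity(x: np.ndarray, y: np.ndarray) -> float:
    # Upcast to float32 before the reduction, mirroring the .float() fix.
    x32, y32 = x.astype(np.float32), y.astype(np.float32)
    return float(x32 @ y32 / (np.linalg.norm(x32) * np.linalg.norm(y32)))

x = np.random.default_rng(0).standard_normal(10304).astype(np.float16)
assert cosine_similarity(x, x) > 0.99  # the threshold the failing tests assert
```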
📌 Description
The SM120 CUTLASS blockwise gemm kernel requires dimensions like N to be multiples of 128 due to hardware constraints (https://github.com/NVIDIA/cutlass/blob/3f4c086d09bd1dc55defb955862f333893bbb28b/include/cutlass/gemm/collective/sm120_mma_tma_blockwise_scaling.hpp#L345C5-L346).
We hit the shape `a: torch.Size([1, 1, 2688]), b: torch.Size([1, 2688, 10304]), scale_a: torch.Size([]), scale_b: torch.Size([]), out: torch.Size([1, 1, 10304]), workspace_buffer: torch.Size([33554432])` from Nemotron-Nano-v3, where 10304 is not a multiple of 128, so the CUTLASS GEMM did not work properly on it. This PR adds a pad-and-slice step to make it work.

🔍 Related Issues
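The pad-and-slice idea is easy to verify numerically: zero-padding N and K of both operands leaves the valid region of the product unchanged, so slicing the padded output recovers the exact un-padded result. A NumPy reference sketch (not the CUDA path; FP8 quantization and scaling are omitted):

```python
import numpy as np

def gemm_nt_pad_slice(a: np.ndarray, b_t: np.ndarray, pad: int = 128) -> np.ndarray:
    """Compute a @ b_t.T by padding N and K up to `pad` boundaries,
    running the aligned problem, then slicing the result back."""
    m, k = a.shape
    n, _ = b_t.shape  # b_t is stored N x K (the "NT" layout)
    k_p = -(-k // pad) * pad
    n_p = -(-n // pad) * pad
    a_pad = np.zeros((m, k_p), a.dtype)
    a_pad[:, :k] = a
    b_pad = np.zeros((n_p, k_p), b_t.dtype)
    b_pad[:n, :k] = b_t
    out_pad = a_pad @ b_pad.T  # every dimension is now a multiple of 128
    return out_pad[:, :n]      # slice away the zero columns

rng = np.random.default_rng(0)
a = rng.standard_normal((1, 2688))        # M=1, K=2688 (already aligned)
b_t = rng.standard_normal((10304, 2688))  # N=10304 (padded to 10368)
assert np.allclose(gemm_nt_pad_slice(a, b_t), a @ b_t.T)
```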
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks

- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed.
- All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
Release Notes
**Bug Fixes**: Fixed FP8 GEMM correctness on SM120/SM121 for shapes where N is not divisible by 128, by padding N and K to 128-element boundaries.

**Tests**: Expanded test dimension coverage to exercise the new padding paths.