misc: support checks for gemm #2214

Merged
yzh119 merged 3 commits into flashinfer-ai:main from jimmyzho:decorate-gemm
Dec 17, 2025

Conversation

@jimmyzho
Contributor

@jimmyzho jimmyzho commented Dec 13, 2025

📌 Description

Continuation of !2000

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added FP8 groupwise and block-scaled GEMM operations, MXFP8/MXFP4 mixed-precision groupwise paths, and DeepGEMM groupwise/batch variants with backend routing.
  • Improvements

    • Stronger runtime pre-checks and capability gating before execution.
    • Replaced implicit asserts with explicit input validation and clearer error messages for more reliable usage and easier debugging.
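The second improvement can be illustrated with a minimal sketch. The function names and the multiple-of-128 rule below are hypothetical, chosen only for illustration; they are not taken from the PR. The point is that an `assert` is skipped entirely when Python runs with `-O`, whereas an explicit `raise` always runs and can report the offending value.

```python
def check_k_with_assert(k: int) -> None:
    # Implicit style: skipped when Python runs with -O,
    # and the failure carries no message.
    assert k % 128 == 0


def check_k_with_raise(k: int) -> bool:
    # Explicit style: always runs and reports the offending value.
    # The multiple-of-128 rule is a made-up constraint for this sketch.
    if k % 128 != 0:
        raise ValueError(f"k must be a multiple of 128, but got k={k}")
    return True


try:
    check_k_with_raise(100)
    message = None
except ValueError as exc:
    message = str(exc)
print(message)
```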

@coderabbitai
Contributor

coderabbitai bot commented Dec 13, 2025

Walkthrough

Adds dedicated pre-check validator functions and backend_requirement gating for FP8 GEMM groupwise/blockscaled/DeepGEMM paths; replaces inline asserts with explicit ValueError raises in DeepGEMM contiguous path; introduces multiple new public GEMM entry points and backend-specific requirement checks.
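The gating pattern described in the walkthrough can be sketched roughly as follows. This is not FlashInfer's `backend_requirement`; `requirement_sketch`, the toy backends `backend_x` and `backend_y`, and their rules are all hypothetical. Only the general shape (a common check, optional per-backend checks, and a `skip_check` escape hatch) is taken from the discussion in this PR.

```python
import functools
import inspect
from typing import Callable, Dict, Optional


def requirement_sketch(
    backend_checks: Dict[str, Callable[..., bool]],
    common_check: Optional[Callable[..., bool]] = None,
):
    """Hypothetical pre-check decorator; not FlashInfer's implementation."""

    def decorator(fn):
        sig = inspect.signature(fn)

        @functools.wraps(fn)
        def wrapper(*args, skip_check: bool = False, **kwargs):
            if not skip_check:
                # Resolve positional arguments and defaults into named parameters.
                bound = sig.bind(*args, **kwargs)
                bound.apply_defaults()
                params = dict(bound.arguments)
                if common_check is not None:
                    common_check(**params)
                if backend_checks:
                    backend = params.get("backend")
                    if backend not in backend_checks:
                        raise ValueError(f"Unsupported backend: {backend!r}")
                    backend_checks[backend](**params)
            return fn(*args, **kwargs)

        return wrapper

    return decorator


def _common_check(k: int, backend: str) -> bool:
    if k <= 0:
        raise ValueError(f"k must be positive, but got k={k}")
    return True


def _backend_y_check(k: int, backend: str) -> bool:
    # Made-up backend-specific rule for the sketch.
    if k % 2 != 0:
        raise ValueError(f"backend_y needs an even k, but got k={k}")
    return True


@requirement_sketch(
    {"backend_x": lambda k, backend: True, "backend_y": _backend_y_check},
    common_check=_common_check,
)
def toy_gemm(k: int, backend: str = "backend_x") -> str:
    return f"ran {backend} with k={k}"
```

With this shape, the function body stays free of validation code, and a caller that has already validated its inputs can pass `skip_check=True` to bypass the checks.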

Changes

| Cohort / File(s) | Summary |
|---|---|
| DeepGEMM pre-checks & entry-point updates (flashinfer/deep_gemm.py) | Added _check_group_deepgemm_fp8_nt_contiguous_problem_size and _check_m_grouped_fp8_gemm_nt_masked_problem_size; wired them as common_check via @backend_requirement on m_grouped_fp8_gemm_nt_contiguous and m_grouped_fp8_gemm_nt_masked; replaced in-function assert checks with explicit ValueError raises; extended imports (supported_compute_capability, backend_requirement). |
| GEMM base: FP8 groupwise / blockscaled APIs & validators (flashinfer/gemm/gemm_base.py) | Added backend requirement functions (_cutlass_gemm_fp8_nt_groupwise_requirement, _trtllm_gemm_fp8_nt_groupwise_requirement) and multiple problem-size pre-checks (_check_gemm_fp8_nt_groupwise_problem_size, _check_gemm_fp8_nt_blockscaled_problem_size, _check_group_gemm_mxfp8_mxfp4_nt_groupwise_problem_size, _check_group_deepgemm_fp8_nt_groupwise_problem_size, _check_batch_deepgemm_fp8_nt_groupwise); introduced public APIs (gemm_fp8_nt_groupwise, gemm_fp8_nt_blockscaled, group_gemm_fp8_nt_groupwise, group_gemm_mxfp8_mxfp4_nt_groupwise, group_deepgemm_fp8_nt_groupwise, batch_deepgemm_fp8_nt_groupwise) wired via @backend_requirement; removed many inline shape/dtype/SM assertions in favor of pre-check delegation. |

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Pay extra attention to correctness and completeness of newly added validator logic.
  • Verify backend_requirement decorator wiring and that common_check functions cover previously inlined asserts.
  • Review dtype/shape handling and error messages across new public entry points.

Suggested reviewers

  • aleozlx
  • cyx-6
  • nvmbreughe

Poem

🐰 I nibbled through shapes and backend gates,
Replaced loud asserts with tidy checks and plates.
Grouped GEMMs hum in ordered rows,
Validators guard where the data flows.
A hop, a tweak, and off it goes—🥕

Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 2 inconclusive)
| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 31.82% which is insufficient. The required threshold is 80.00%. | You can run @coderabbitai generate docstrings to improve docstring coverage. |
| Title check | ❓ Inconclusive | The title 'misc: support checks for gemm' is vague and generic, using non-descriptive terms that don't clearly convey the specific changes made beyond a general notion of 'misc' support. | Consider a more specific title that describes the main change, such as 'Add pre-check validation decorators for FP8 GEMM operations' or 'Implement backend requirement checks for GEMM functions'. |
| Description check | ❓ Inconclusive | The description is mostly the PR template with minimal content; only the Related Issues section mentions continuation of PR #2000, but lacks substantive explanation of what this PR accomplishes. | Expand the Description section to explain the changes, such as which GEMM operations now have checks, what these checks validate, and why they are needed. The Description section should provide more detail beyond just the template. |

@gemini-code-assist
Contributor

Summary of Changes

Hello @jimmyzho, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly refactors the input validation and compute capability checks for several General Matrix Multiply (GEMM) operations within the flashinfer library. By introducing a decorator-driven approach with supported_compute_capability and backend_requirement, along with specialized _check_problem_size helper functions, the changes centralize and standardize how inputs are validated and how hardware compatibility is enforced. This leads to a cleaner codebase, more robust error handling, and a more modular design for future extensions.

Highlights

  • Centralized Validation Logic: Introduced a new decorator-based system (@supported_compute_capability, @backend_requirement) for input validation and compute capability checks across various General Matrix Multiply (GEMM) functions.
  • Improved Error Handling: Replaced direct assert statements with ValueError exceptions within dedicated _check_..._problem_size helper functions, providing more informative error messages and better control flow.
  • Code Refactoring: Extracted validation logic from the main GEMM function bodies into separate, reusable check functions, enhancing code readability and maintainability by separating concerns.
  • Backend-Specific Requirements: Added specific requirement checks for different backends (e.g., Cutlass, TRTLLM) within the new validation framework, allowing for tailored validation based on the chosen backend.

@jimmyzho
Contributor Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !195 has been created, and the CI pipeline #40127035 is currently running. I'll report back once the pipeline job completes.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request refactors several GEMM functions to extract validation logic into separate check functions, which are then used with a new @backend_requirement decorator. This is a good pattern for improving code structure and reusability. My review includes one critical issue related to incorrect shape validation and a missing check for positive dimensions. I've also pointed out a couple of instances of duplicated logic for out_dtype resolution that could be refactored to improve maintainability.

assert b.dtype == torch.float8_e4m3fn
assert d.dtype == torch.bfloat16
assert m_indices.dtype == torch.int32
if m != m_ or k != k_ or n != n_ or m__ != m_ or num_groups != m__:
Contributor

critical

There appears to be a bug in the shape validation logic. The condition num_groups != m__ is incorrect, as num_groups (from b.shape[0]) is not necessarily equal to m__ (from m_indices.numel(), which is m). This check was not present in the original assert statement and seems to have been added by mistake.

Additionally, the check for positive dimensions (n > 0, k > 0, num_groups > 0) from the original assert has been removed and should be restored.

I suggest correcting the shape check and reintroducing the positive dimension checks. For example:

    if m != m_ or k != k_ or n != n_ or m__ != m_:
        raise ValueError(
            f"Shape mismatch. m = {m}, m_ = {m_}, k = {k}, k_ = {k_}, n = {n}, n_ = {n_}, m__ = {m__}"
        )
    if n <= 0 or k <= 0 or num_groups <= 0:
        raise ValueError(
            f"n, k, and num_groups must be positive, but got n={n}, k={k}, num_groups={num_groups}"
        )
Suggested change
-    if m != m_ or k != k_ or n != n_ or m__ != m_ or num_groups != m__:
+    if m != m_ or k != k_ or n != n_ or m__ != m_:
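The suggested check can be exercised without tensors. The sketch below uses plain shape tuples and a hypothetical function name; the unpacking order of `a`, `b`, and `d` is an assumption based on the variable names in the snippet, and the example sizes are arbitrary. It shows that a valid problem with fewer groups than rows passes once the `num_groups != m__` term is gone.

```python
def check_contiguous_shapes(a_shape, b_shape, d_shape, num_indices):
    """Sketch of the corrected check, using plain tuples instead of tensors."""
    m, k = a_shape
    num_groups, n, k_ = b_shape
    m_, n_ = d_shape
    m__ = num_indices  # stands in for m_indices.numel()
    if m != m_ or k != k_ or n != n_ or m__ != m_:
        raise ValueError(
            f"Shape mismatch. m = {m}, m_ = {m_}, k = {k}, k_ = {k_}, "
            f"n = {n}, n_ = {n_}, m__ = {m__}"
        )
    if n <= 0 or k <= 0 or num_groups <= 0:
        raise ValueError(
            f"n, k, and num_groups must be positive, but got "
            f"n={n}, k={k}, num_groups={num_groups}"
        )
    return True


# 8 groups sharing 256 rows: valid, although num_groups != m__.
ok = check_contiguous_shapes((256, 512), (8, 128, 512), (256, 128), 256)
```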

Collaborator

The num_groups != m__ looks confusing to me as well, @jimmyzho would you mind double checking? I don't see it in existing codebase.

Comment on lines +2817 to +2822
    if out is None:
        if out_dtype is None:
            out_dtype = torch.bfloat16
    else:
        if out_dtype is None:
            out_dtype = out.dtype
Contributor

medium

The logic for determining out_dtype is duplicated in this check function and in the main group_gemm_fp8_nt_groupwise function (lines 2928-2933). This duplication can lead to maintenance issues.

To improve maintainability, consider refactoring this logic into a private helper function that can be called from both places. This would ensure a single source of truth for out_dtype resolution.
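A shared helper along the lines suggested here might look like the sketch below. The name `resolve_out_dtype` is hypothetical, and the example uses strings and a `SimpleNamespace` in place of torch dtypes and tensors so that it runs without a GPU stack; the branching mirrors the quoted snippet.

```python
from types import SimpleNamespace
from typing import Any, Optional


def resolve_out_dtype(out: Optional[Any], out_dtype: Optional[Any], default: Any) -> Any:
    """Hypothetical helper: single source of truth for out_dtype resolution."""
    if out_dtype is not None:
        return out_dtype  # an explicit request always wins
    if out is None:
        return default  # nothing to infer from, fall back to the default
    return out.dtype  # infer from the preallocated output


# Stand-in for a preallocated output tensor.
fake_out = SimpleNamespace(dtype="float16")
print(resolve_out_dtype(None, None, "bfloat16"))
print(resolve_out_dtype(fake_out, None, "bfloat16"))
```

Both the pre-check and the main function could then call the same helper, so the default would only need to change in one place.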

Comment on lines +3016 to +3021
    if out is None:
        if out_dtype is None:
            out_dtype = torch.bfloat16
    else:
        if out_dtype is None:
            out_dtype = out.dtype
Contributor

medium

Similar to a previous comment, the logic for determining out_dtype is duplicated here and in the main group_gemm_mxfp8_mxfp4_nt_groupwise function (lines 3142-3147).

Refactoring this into a shared private helper function would reduce code duplication and improve maintainability.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (3)
flashinfer/deep_gemm.py (3)

1416-1437: LGTM - decorator wiring is correct.

The @backend_requirement decorator with empty backend_checks and common_check is correctly wired to enforce validation before execution.

Minor: The unpacked k_ at line 1437 is unused. Consider using _ as a placeholder:

-    num_groups, n, k_ = b.shape
+    num_groups, n, _ = b.shape

1489-1490: Consider using underscore for unused unpacked variables.

The sfa and sfb tensors are unpacked but not used in the validation. Using underscore makes the intent clearer:

-    a, sfa = a_fp8
-    b, sfb = b_fp8
+    a, _ = a_fp8
+    b, _ = b_fp8

1548-1549: Redundant assertions after pre-check validation.

These assertions duplicate the validation in _check_m_grouped_fp8_gemm_nt_masked_problem_size. They are technically redundant when the decorator runs, but serve as defense-in-depth when skip_check=True is passed.

Consider removing them if redundancy is undesirable, or adding a comment explaining they guard against skip_check=True usage:

+    # Guard for skip_check=True scenarios
     assert major_a == major_b == MajorTypeAB.KMajor
     assert masked_m.is_contiguous()
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1ac4e1d and 00992c9.

📒 Files selected for processing (2)
  • flashinfer/deep_gemm.py (4 hunks)
  • flashinfer/gemm/gemm_base.py (7 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
flashinfer/deep_gemm.py (1)
flashinfer/utils.py (4)
  • ceil_div (622-633)
  • round_up (636-638)
  • supported_compute_capability (819-899)
  • backend_requirement (902-1184)
flashinfer/gemm/gemm_base.py (2)
flashinfer/utils.py (4)
  • supported_compute_capability (819-899)
  • backend_requirement (902-1184)
  • is_sm120a_supported (551-553)
  • is_sm121a_supported (556-558)
flashinfer/deep_gemm.py (2)
  • _check_group_deepgemm_fp8_nt_contiguous_problem_size (1367-1413)
  • _check_m_grouped_fp8_gemm_nt_masked_problem_size (1468-1526)
🪛 Ruff (0.14.8)
flashinfer/deep_gemm.py

1372-1372: Unused function argument: recipe (ARG001)
1373-1373: Unused function argument: compiled_dims (ARG001)
1379-1379: Avoid specifying long messages outside the exception class (TRY003)
1381-1381: Avoid specifying long messages outside the exception class (TRY003)
1384-1386: Avoid specifying long messages outside the exception class (TRY003)
1397-1399: Avoid specifying long messages outside the exception class (TRY003)
1401-1401: Avoid specifying long messages outside the exception class (TRY003)
1403-1403: Avoid specifying long messages outside the exception class (TRY003)
1405-1405: Avoid specifying long messages outside the exception class (TRY003)
1407-1407: Avoid specifying long messages outside the exception class (TRY003)
1411-1411: Avoid specifying long messages outside the exception class (TRY003)
1437-1437: Unpacked variable k_ is never used (RUF059)
1474-1474: Unused function argument: recipe (ARG001)
1475-1475: Unused function argument: compiled_dims (ARG001)
1480-1480: Avoid specifying long messages outside the exception class (TRY003)
1482-1482: Avoid specifying long messages outside the exception class (TRY003)
1485-1487: Avoid specifying long messages outside the exception class (TRY003)
1489-1489: Unpacked variable sfa is never used. Prefix it with an underscore or any other dummy variable pattern (RUF059)
1490-1490: Unpacked variable sfb is never used. Prefix it with an underscore or any other dummy variable pattern (RUF059)
1502-1504: Avoid specifying long messages outside the exception class (TRY003)
1506-1508: Avoid specifying long messages outside the exception class (TRY003)
1510-1512: Avoid specifying long messages outside the exception class (TRY003)
1514-1514: Avoid specifying long messages outside the exception class (TRY003)
1516-1516: Avoid specifying long messages outside the exception class (TRY003)
1518-1518: Avoid specifying long messages outside the exception class (TRY003)
1520-1520: Avoid specifying long messages outside the exception class (TRY003)
1524-1524: Avoid specifying long messages outside the exception class (TRY003)

flashinfer/gemm/gemm_base.py

2415-2415: Unused function argument: a (ARG001)
2416-2416: Unused function argument: b (ARG001)
2417-2417: Unused function argument: a_scale (ARG001)
2418-2418: Unused function argument: b_scale (ARG001)
2420-2420: Unused function argument: mma_sm (ARG001)
2421-2421: Unused function argument: scale_granularity_mnk (ARG001)
2422-2422: Unused function argument: out (ARG001)
2423-2423: Unused function argument: out_dtype (ARG001)
2424-2424: Unused function argument: backend (ARG001)
2427-2427: Avoid specifying long messages outside the exception class (TRY003)
2435-2435: Unused function argument: b (ARG001)
2436-2436: Unused function argument: a_scale (ARG001)
2437-2437: Unused function argument: b_scale (ARG001)
2438-2438: Unused function argument: scale_major_mode (ARG001)
2439-2439: Unused function argument: mma_sm (ARG001)
2441-2441: Unused function argument: out (ARG001)
2442-2442: Unused function argument: out_dtype (ARG001)
2443-2443: Unused function argument: backend (ARG001)
2446-2446: Avoid specifying long messages outside the exception class (TRY003)
2448-2448: Avoid specifying long messages outside the exception class (TRY003)
2456-2456: Unused function argument: a_scale (ARG001)
2457-2457: Unused function argument: b_scale (ARG001)
2458-2458: Unused function argument: scale_major_mode (ARG001)
2459-2459: Unused function argument: mma_sm (ARG001)
2460-2460: Unused function argument: scale_granularity_mnk (ARG001)
2461-2461: Unused function argument: out (ARG001)
2463-2463: Unused function argument: backend (ARG001)
2466-2466: Avoid specifying long messages outside the exception class (TRY003)
2469-2471: Avoid specifying long messages outside the exception class (TRY003)
2790-2790: Unused function argument: scale_granularity_mnk (ARG001)
2797-2797: Avoid specifying long messages outside the exception class (TRY003)
2799-2799: Avoid specifying long messages outside the exception class (TRY003)
2801-2801: Avoid specifying long messages outside the exception class (TRY003)
2803-2803: Avoid specifying long messages outside the exception class (TRY003)
2805-2805: Avoid specifying long messages outside the exception class (TRY003)
2807-2809: Avoid specifying long messages outside the exception class (TRY003)
2811-2811: Avoid specifying long messages outside the exception class (TRY003)
2824-2826: Avoid specifying long messages outside the exception class (TRY003)
2828-2830: Avoid specifying long messages outside the exception class (TRY003)
2835-2835: Avoid specifying long messages outside the exception class (TRY003)
2837-2837: Avoid specifying long messages outside the exception class (TRY003)
2839-2839: Avoid specifying long messages outside the exception class (TRY003)
2845-2847: Avoid specifying long messages outside the exception class (TRY003)
2993-2995: Avoid specifying long messages outside the exception class (TRY003)
2997-2997: Avoid specifying long messages outside the exception class (TRY003)
2999-2999: Avoid specifying long messages outside the exception class (TRY003)
3001-3001: Avoid specifying long messages outside the exception class (TRY003)
3003-3003: Avoid specifying long messages outside the exception class (TRY003)
3005-3005: Avoid specifying long messages outside the exception class (TRY003)
3007-3007: Avoid specifying long messages outside the exception class (TRY003)
3009-3009: Avoid specifying long messages outside the exception class (TRY003)
3011-3011: Avoid specifying long messages outside the exception class (TRY003)
3013-3013: Avoid specifying long messages outside the exception class (TRY003)
3024-3026: Avoid specifying long messages outside the exception class (TRY003)
3030-3032: Avoid specifying long messages outside the exception class (TRY003)
3039-3041: Avoid specifying long messages outside the exception class (TRY003)
3046-3046: Avoid specifying long messages outside the exception class (TRY003)
3048-3048: Avoid specifying long messages outside the exception class (TRY003)
3053-3053: Avoid specifying long messages outside the exception class (TRY003)
3055-3055: Avoid specifying long messages outside the exception class (TRY003)
3218-3218: Unused function argument: out_dtype (ARG001)
3369-3369: Unused function argument: out_dtype (ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (16)
flashinfer/deep_gemm.py (1)

48-53: LGTM!

The new imports for supported_compute_capability and backend_requirement are correctly added and align with the decorator usage in the file.

flashinfer/gemm/gemm_base.py (15)

2413-2429: LGTM!

The CUTLASS backend requirement correctly validates that scale_major_mode is provided, which is required for CUTLASS FP8 groupwise GEMM.


2432-2450: LGTM!

The TRT-LLM backend requirements correctly enforce the (1, 128, 128) scale granularity and minimum k-dimension of 256.
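The two constraints named here can be restated as a small standalone check. This is a sketch with a hypothetical function name, operating on a plain shape tuple; it assumes `a` is laid out as (m, k) and is not the PR's `_trtllm_gemm_fp8_nt_groupwise_requirement`.

```python
from typing import Sequence, Tuple


def trtllm_requirement_sketch(
    a_shape: Tuple[int, int], scale_granularity_mnk: Sequence[int]
) -> bool:
    """Hypothetical restatement of the two constraints named in the review."""
    _, k = a_shape  # assumes `a` is laid out as (m, k)
    if tuple(scale_granularity_mnk) != (1, 128, 128):
        raise ValueError(
            "scale_granularity_mnk must be (1, 128, 128), "
            f"but got {tuple(scale_granularity_mnk)}"
        )
    if k < 256:
        raise ValueError(f"k must be at least 256, but got k={k}")
    return True
```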


2453-2475: LGTM!

The common check correctly validates tensor dimensions and output dtype for FP8 NT groupwise GEMM.


2478-2484: LGTM!

The @backend_requirement decorator is correctly wired with both backend-specific checks and a common check.


2750-2753: Decorator wiring is correct.

The @backend_requirement with empty backend_checks and common_check correctly delegates to the pre-check function. Note: Fix the function call issues in _check_gemm_fp8_nt_blockscaled_problem_size as noted above.


2783-2849: LGTM - comprehensive validation with clear constraints.

The pre-check function provides thorough validation for group GEMM FP8 operations. The SM120/121 restriction for num_groups > 1 is properly documented with a runtime error.

Note: The commented assertion at line 2813 (a.shape[0] == m_indptr[-1].item()) trades correctness checking for performance. Consider documenting this tradeoff in the function's docstring.


2852-2855: LGTM!

The decorator is correctly wired with the pre-check function.


2974-2974: LGTM!

The explicit return out statement ensures the function properly returns the output tensor.


2977-3057: LGTM - thorough MXFP8/MXFP4 validation.

The pre-check function correctly validates:

  • Input dtypes (float8 for a, uint8 for packed FP4 b and scales)
  • Tile size constraints for the kernel
  • Shape alignment requirements

The validation logic properly handles the FP4 packing (k dimension is doubled from uint8 shape).
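The doubling follows from two 4-bit values sharing one uint8. The sketch below packs and unpacks 4-bit codes in pure Python to make the relationship concrete; the function names are hypothetical, and the nibble order (low nibble first) is an assumption that may not match the kernel's actual layout.

```python
from typing import List


def pack_fp4_codes(codes: List[int]) -> bytes:
    """Pack pairs of 4-bit codes into bytes (low nibble first, an assumed order)."""
    if len(codes) % 2 != 0:
        raise ValueError(f"need an even number of codes, but got {len(codes)}")
    if any(not 0 <= c <= 15 for c in codes):
        raise ValueError("every code must fit in 4 bits")
    return bytes((hi << 4) | lo for lo, hi in zip(codes[0::2], codes[1::2]))


def unpack_fp4_codes(packed: bytes) -> List[int]:
    """Inverse of pack_fp4_codes: each byte yields two 4-bit codes."""
    codes = []
    for byte in packed:
        codes.append(byte & 0x0F)
        codes.append(byte >> 4)
    return codes


codes = [1, 2, 3, 4]
packed = pack_fp4_codes(codes)
logical_k = 2 * len(packed)  # logical K is twice the stored uint8 length
```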


3060-3063: LGTM!

The decorator is correctly wired.


3141-3154: LGTM!

The out_dtype determination logic is correctly maintained in the function body for tensor allocation, while the pre-check handles validation.


3209-3226: LGTM - correctly delegates to deep_gemm validation.

The function correctly packs the individual tensors into tuples (a, a_scale) and (b, b_scale) to match the a_fp8 and b_fp8 format expected by _check_group_deepgemm_fp8_nt_contiguous_problem_size.


3229-3232: LGTM!

The decorator is correctly wired.


3359-3375: LGTM - correctly delegates to deep_gemm validation.

The function properly delegates to _check_m_grouped_fp8_gemm_nt_masked_problem_size with correctly packed tensor tuples.


3378-3381: LGTM!

The decorator is correctly wired.

Comment on lines +1396 to +1399
    if m != m_ or k != k_ or n != n_ or m__ != m_ or num_groups != m__:
        raise ValueError(
            f"Shape mismatch. m = {m}, m_ = {m_}, k = {k}, k_ = {k_}, n = {n}, n_ = {n_}, m__ = {m__}"
        )
Contributor

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Find usages of m_grouped_fp8_gemm_nt_contiguous
rg -n "m_grouped_fp8_gemm_nt_contiguous\(" flashinfer/deep_gemm.py --type py -A 5

Repository: flashinfer-ai/flashinfer

Length of output: 328


🏁 Script executed:

#!/bin/bash
# Find the function definition and its docstring
rg -n -B 5 "def m_grouped_fp8_gemm_nt_contiguous" flashinfer/deep_gemm.py --type py -A 20

Repository: flashinfer-ai/flashinfer

Length of output: 3255


🏁 Script executed:

#!/bin/bash
# Search for test files that might show expected usage of m_indices
fd -e py "test" | head -20 | xargs rg -l "m_grouped_fp8_gemm_nt_contiguous" 2>/dev/null

Repository: flashinfer-ai/flashinfer

Length of output: 52


🏁 Script executed:

#!/bin/bash
# Look for examples or documentation about m_indices
rg -n "m_indices" flashinfer/deep_gemm.py --type py -B 2 -A 2 | head -100

Repository: flashinfer-ai/flashinfer

Length of output: 1992


🏁 Script executed:

#!/bin/bash
# Look for docstring of m_grouped_fp8_gemm_nt_contiguous
sed -n '1420,1450p' flashinfer/deep_gemm.py

Repository: flashinfer-ai/flashinfer

Length of output: 1038


🏁 Script executed:

#!/bin/bash
# Search for test files that use this function
find . -name "*.py" -path "*/test*" -o -name "*test*.py" | xargs rg -l "m_grouped_fp8_gemm_nt_contiguous" 2>/dev/null | head -5

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

#!/bin/bash
# Look for any comments or docstrings about m_indices semantics in the check function
sed -n '1366,1413p' flashinfer/deep_gemm.py

Repository: flashinfer-ai/flashinfer

Length of output: 1928


🏁 Script executed:

#!/bin/bash
# Search for any tests or examples of m_grouped_fp8_gemm_nt_contiguous
rg -n "m_grouped_fp8_gemm_nt_contiguous" --type py -B 3 -A 3 | grep -A 10 -B 10 "m_indices"

Repository: flashinfer-ai/flashinfer

Length of output: 2163


🏁 Script executed:

#!/bin/bash
# Look for kernel generation code that might explain m_indices
sed -n '1030,1139p' flashinfer/deep_gemm.py | grep -A 5 -B 5 "GROUPED_LAYOUT"

Repository: flashinfer-ai/flashinfer

Length of output: 422


🏁 Script executed:

#!/bin/bash
# Check if there are any comments in the kernel generator about group layout
rg -n "GROUPED_LAYOUT" flashinfer/ --type py -B 5 -A 5

Repository: flashinfer-ai/flashinfer

Length of output: 2058


🏁 Script executed:

#!/bin/bash
# Look at the full public API function that calls m_grouped_fp8_gemm_nt_contiguous
sed -n '3320,3360p' flashinfer/gemm/gemm_base.py

Repository: flashinfer-ai/flashinfer

Length of output: 1769


🏁 Script executed:

#!/bin/bash
# Search for documentation or examples of how m_indices should be constructed
rg -n "m_indices" flashinfer/gemm/gemm_base.py -B 5 -A 5

Repository: flashinfer-ai/flashinfer

Length of output: 3943


🏁 Script executed:

#!/bin/bash
# Check if there's a "masked" variant that might clarify the semantics
rg -n "m_grouped_fp8_gemm_nt_masked" flashinfer/deep_gemm.py --type py -A 20 | head -60

Repository: flashinfer-ai/flashinfer

Length of output: 1728


The shape check condition num_groups != m__ is incorrect and contradicts the documented API.

The condition enforces num_groups == m_indices.numel() == m, which requires the number of groups to equal the total number of rows. However, the public API documentation (gemm_base.py lines 3279-3283) and example (lines 3326-3329) clearly define m_indices as a per-row group assignment tensor where each element specifies which group (0 to num_groups-1) that row belongs to. The example shows num_groups=8 with m=256, so m_indices has 256 elements with values in [0, 8).

The check should be removed or replaced with validation that m_indices values are within the valid range [0, num_groups).
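A range validation of this kind could be sketched as below, using a plain Python sequence in place of the `m_indices` tensor. The function name is hypothetical. On a real device tensor, inspecting values would likely require copying data to the host, so whether such a check is worth its cost is a separate decision that this sketch does not settle.

```python
from typing import Sequence


def check_m_indices(m_indices: Sequence[int], m: int, num_groups: int) -> bool:
    """Hypothetical check: one group id per row, each within [0, num_groups)."""
    if len(m_indices) != m:
        raise ValueError(
            f"m_indices must have {m} entries, but got {len(m_indices)}"
        )
    for row, group in enumerate(m_indices):
        if not 0 <= group < num_groups:
            raise ValueError(
                f"m_indices[{row}] = {group} is outside [0, {num_groups})"
            )
    return True


# Same sizes as the example cited above: 256 rows spread over 8 groups.
m, num_groups = 256, 8
m_indices = [row // 32 for row in range(m)]
valid = check_m_indices(m_indices, m, num_groups)
```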

🧰 Tools
🪛 Ruff (0.14.8)

1397-1399: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In flashinfer/deep_gemm.py around lines 1396 to 1399, remove the incorrect check
comparing num_groups to m__ (which enforces num_groups == number of rows) and
instead validate the m_indices tensor values: ensure m_indices has integer
dtype, has length equal to m (if not already checked earlier), and that every
element satisfies 0 <= m_indices[i] < num_groups; replace the rejected equality
test with this range-check and raise a ValueError with a clear message if any
index is out of bounds.

yzh119 approved these changes Dec 15, 2025
@jimmyzho
Contributor Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !195 has been updated with latest changes, and the CI pipeline #40251567 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (1)
flashinfer/gemm/gemm_base.py (1)

2726-2748: Fix critical parameter mismatch in function calls.

Both calls to _check_gemm_fp8_nt_groupwise_problem_size (line 2726) and _cutlass_gemm_fp8_nt_groupwise_requirement (line 2738) are missing the scale_granularity_mnk parameter. This causes positional arguments to be incorrectly bound:

  • out → scale_granularity_mnk
  • out_dtype → out
  • backend="cutlass" → out_dtype

This will result in a TypeError or incorrect validation at runtime.
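The shift can be reproduced with `inspect.signature` on a stand-in function. The sketch below assumes the parameter order `a, b, a_scale, b_scale, scale_major_mode, mma_sm, scale_granularity_mnk, out, out_dtype, backend` with no defaults; that order and the absence of defaults are assumptions, and `check_sketch` is hypothetical rather than the PR's function.

```python
import inspect


def check_sketch(a, b, a_scale, b_scale, scale_major_mode, mma_sm,
                 scale_granularity_mnk, out, out_dtype, backend):
    """Hypothetical stand-in with the assumed parameter order."""
    return True


sig = inspect.signature(check_sketch)

# Call shape from the review: scale_granularity_mnk is omitted.
args = ("a", "b", "a_scale", "b_scale", "MN", 1, "OUT", "OUT_DTYPE")
kwargs = {"backend": "cutlass"}

# bind_partial shows where each positional value lands without raising.
shifted = sig.bind_partial(*args, **kwargs).arguments
print(shifted["scale_granularity_mnk"])  # holds the value meant for `out`
print(shifted["out"])                    # holds the value meant for `out_dtype`

# A full bind fails because nothing is left to fill `out_dtype`.
try:
    sig.bind(*args, **kwargs)
    bind_error = None
except TypeError as exc:
    bind_error = str(exc)
print(bind_error)
```

Under these assumptions the call fails with a TypeError; if the trailing parameters had defaults, the shifted values would instead be validated silently against the wrong parameters.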

Apply this diff to fix both calls:

     _check_gemm_fp8_nt_groupwise_problem_size(
         a,
         b,
         a_scale,
         b_scale,
         scale_major_mode,
         mma_sm,
+        (128, 128, 128),  # scale_granularity_mnk, passed positionally
         out,
         out_dtype,
         backend="cutlass",
     )

     _cutlass_gemm_fp8_nt_groupwise_requirement(
         a,
         b,
         a_scale,
         b_scale,
         scale_major_mode,
         mma_sm,
+        (128, 128, 128),  # scale_granularity_mnk, passed positionally
         out,
         out_dtype,
         backend="cutlass",
     )
🧹 Nitpick comments (1)
flashinfer/deep_gemm.py (1)

1548-1549: Remove redundant assertions.

These assertions duplicate checks already performed by the _check_m_grouped_fp8_gemm_nt_masked_problem_size pre-check function (lines 1479-1482 and 1484-1487). Since the @backend_requirement decorator ensures the pre-check runs before this function, these assertions are unnecessary.

Apply this diff to remove the redundant checks:

-    assert major_a == major_b == MajorTypeAB.KMajor
-    assert masked_m.is_contiguous()
-
     a, sfa = a_fp8
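
The redundancy argument can be sketched with a simplified decorator. `with_precheck` below is a hypothetical helper written for illustration only; it is not the `backend_requirement` API from flashinfer/utils.py, and the string and boolean arguments stand in for the real enum and tensor.

```python
import functools


def with_precheck(check):
    """Hypothetical decorator: run `check` on the arguments before the function."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            check(*args, **kwargs)  # raises ValueError on invalid input
            return fn(*args, **kwargs)
        return wrapper
    return decorator


def _check_masked(major_a, major_b, masked_m_is_contiguous):
    # Stand-ins for the K-major and contiguity checks named in the comment.
    if not (major_a == major_b == "K"):
        raise ValueError("major_a and major_b must both be K-major")
    if not masked_m_is_contiguous:
        raise ValueError("masked_m must be contiguous")


@with_precheck(_check_masked)
def masked_gemm(major_a, major_b, masked_m_is_contiguous):
    # The pre-check has already run, so repeating the asserts here adds nothing.
    return "ok"


print(masked_gemm("K", "K", True))
```

Because the wrapper always calls the checker first, an invalid input never reaches the function body, which is why the in-body assertions can be dropped.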
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 00992c9 and 5de3a5c.

📒 Files selected for processing (2)
  • flashinfer/deep_gemm.py (4 hunks)
  • flashinfer/gemm/gemm_base.py (7 hunks)
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • flashinfer/deep_gemm.py
  • flashinfer/gemm/gemm_base.py
🧬 Code graph analysis (1)
flashinfer/deep_gemm.py (1)
flashinfer/utils.py (4)
  • ceil_div (622-633)
  • round_up (636-638)
  • supported_compute_capability (819-899)
  • backend_requirement (902-1184)
🪛 Ruff (0.14.8)
flashinfer/deep_gemm.py

1372-1372: Unused function argument: recipe

(ARG001)


1373-1373: Unused function argument: compiled_dims

(ARG001)


1379-1379: Avoid specifying long messages outside the exception class

(TRY003)


1381-1381: Avoid specifying long messages outside the exception class

(TRY003)


1384-1386: Avoid specifying long messages outside the exception class

(TRY003)


1397-1399: Avoid specifying long messages outside the exception class

(TRY003)


1401-1401: Avoid specifying long messages outside the exception class

(TRY003)


1403-1403: Avoid specifying long messages outside the exception class

(TRY003)


1405-1405: Avoid specifying long messages outside the exception class

(TRY003)


1407-1407: Avoid specifying long messages outside the exception class

(TRY003)


1411-1411: Avoid specifying long messages outside the exception class

(TRY003)


1437-1437: Unpacked variable k_ is never used

(RUF059)


1474-1474: Unused function argument: recipe

(ARG001)


1475-1475: Unused function argument: compiled_dims

(ARG001)


1480-1480: Avoid specifying long messages outside the exception class

(TRY003)


1482-1482: Avoid specifying long messages outside the exception class

(TRY003)


1485-1487: Avoid specifying long messages outside the exception class

(TRY003)


1489-1489: Unpacked variable sfa is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


1490-1490: Unpacked variable sfb is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


1502-1504: Avoid specifying long messages outside the exception class

(TRY003)


1506-1508: Avoid specifying long messages outside the exception class

(TRY003)


1510-1512: Avoid specifying long messages outside the exception class

(TRY003)


1514-1514: Avoid specifying long messages outside the exception class

(TRY003)


1516-1516: Avoid specifying long messages outside the exception class

(TRY003)


1518-1518: Avoid specifying long messages outside the exception class

(TRY003)


1520-1520: Avoid specifying long messages outside the exception class

(TRY003)


1524-1524: Avoid specifying long messages outside the exception class

(TRY003)

flashinfer/gemm/gemm_base.py

2415-2415: Unused function argument: a

(ARG001)


2416-2416: Unused function argument: b

(ARG001)


2417-2417: Unused function argument: a_scale

(ARG001)


2418-2418: Unused function argument: b_scale

(ARG001)


2420-2420: Unused function argument: mma_sm

(ARG001)


2421-2421: Unused function argument: scale_granularity_mnk

(ARG001)


2422-2422: Unused function argument: out

(ARG001)


2423-2423: Unused function argument: out_dtype

(ARG001)


2424-2424: Unused function argument: backend

(ARG001)


2427-2427: Avoid specifying long messages outside the exception class

(TRY003)


2435-2435: Unused function argument: b

(ARG001)


2436-2436: Unused function argument: a_scale

(ARG001)


2437-2437: Unused function argument: b_scale

(ARG001)


2438-2438: Unused function argument: scale_major_mode

(ARG001)


2439-2439: Unused function argument: mma_sm

(ARG001)


2441-2441: Unused function argument: out

(ARG001)


2442-2442: Unused function argument: out_dtype

(ARG001)


2443-2443: Unused function argument: backend

(ARG001)


2446-2446: Avoid specifying long messages outside the exception class

(TRY003)


2448-2448: Avoid specifying long messages outside the exception class

(TRY003)


2456-2456: Unused function argument: a_scale

(ARG001)


2457-2457: Unused function argument: b_scale

(ARG001)


2458-2458: Unused function argument: scale_major_mode

(ARG001)


2459-2459: Unused function argument: mma_sm

(ARG001)


2460-2460: Unused function argument: scale_granularity_mnk

(ARG001)


2463-2463: Unused function argument: backend

(ARG001)


2466-2466: Avoid specifying long messages outside the exception class

(TRY003)


2469-2471: Avoid specifying long messages outside the exception class

(TRY003)


2793-2793: Unused function argument: scale_granularity_mnk

(ARG001)


2800-2800: Avoid specifying long messages outside the exception class

(TRY003)


2802-2802: Avoid specifying long messages outside the exception class

(TRY003)


2804-2804: Avoid specifying long messages outside the exception class

(TRY003)


2806-2806: Avoid specifying long messages outside the exception class

(TRY003)


2808-2808: Avoid specifying long messages outside the exception class

(TRY003)


2810-2812: Avoid specifying long messages outside the exception class

(TRY003)


2814-2814: Avoid specifying long messages outside the exception class

(TRY003)


2827-2829: Avoid specifying long messages outside the exception class

(TRY003)


2831-2833: Avoid specifying long messages outside the exception class

(TRY003)


2838-2838: Avoid specifying long messages outside the exception class

(TRY003)


2840-2840: Avoid specifying long messages outside the exception class

(TRY003)


2842-2842: Avoid specifying long messages outside the exception class

(TRY003)


2848-2850: Avoid specifying long messages outside the exception class

(TRY003)


2996-2998: Avoid specifying long messages outside the exception class

(TRY003)


3000-3000: Avoid specifying long messages outside the exception class

(TRY003)


3002-3002: Avoid specifying long messages outside the exception class

(TRY003)


3004-3004: Avoid specifying long messages outside the exception class

(TRY003)


3006-3006: Avoid specifying long messages outside the exception class

(TRY003)


3008-3008: Avoid specifying long messages outside the exception class

(TRY003)


3010-3010: Avoid specifying long messages outside the exception class

(TRY003)


3012-3012: Avoid specifying long messages outside the exception class

(TRY003)


3014-3014: Avoid specifying long messages outside the exception class

(TRY003)


3016-3016: Avoid specifying long messages outside the exception class

(TRY003)


3027-3029: Avoid specifying long messages outside the exception class

(TRY003)


3033-3035: Avoid specifying long messages outside the exception class

(TRY003)


3042-3044: Avoid specifying long messages outside the exception class

(TRY003)


3049-3049: Avoid specifying long messages outside the exception class

(TRY003)


3051-3051: Avoid specifying long messages outside the exception class

(TRY003)


3056-3056: Avoid specifying long messages outside the exception class

(TRY003)


3058-3058: Avoid specifying long messages outside the exception class

(TRY003)


3221-3221: Unused function argument: out_dtype

(ARG001)


3372-3372: Unused function argument: out_dtype

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (11)
flashinfer/deep_gemm.py (3)

48-53: LGTM!

The added imports are correctly used in the new pre-check functions and backend requirement decorators below.


1366-1413: LGTM!

The shape validation logic is correct. The check at line 1396 properly verifies that m_indices.numel() equals m, ensuring the index tensor has the right number of elements. The unused parameters (recipe, compiled_dims) are part of the function signature to maintain API consistency with the backend_requirement decorator.


1416-1464: LGTM!

The backend_requirement decorator is correctly configured with an empty backend_checks dict (single implicit backend) and the common_check properly validates inputs before execution.

flashinfer/gemm/gemm_base.py (8)

2413-2450: LGTM!

The requirement functions correctly validate backend-specific constraints:

  • CUTLASS requires scale_major_mode to be specified
  • TRTLLM requires specific scale granularity and minimum inner dimension

The unused parameters are necessary to match the function signature expected by the backend_requirement decorator.
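
As a rough sketch of what such per-backend requirement functions might look like: the function names and reduced signatures below are hypothetical, the granularity value `(1, 128, 128)` is an assumption (the review only says it is fixed), and the minimum inner dimension of 256 is taken from the later review note on the TRTLLM check.

```python
ASSUMED_TRTLLM_GRANULARITY = (1, 128, 128)  # assumption, not confirmed by the review
MIN_K_TRTLLM = 256  # minimum K dimension mentioned in the review


def cutlass_requirement(scale_major_mode):
    """Hypothetical check: CUTLASS needs scale_major_mode to be given."""
    if scale_major_mode is None:
        raise ValueError("scale_major_mode is required for the CUTLASS backend")
    return True


def trtllm_requirement(scale_granularity_mnk, k):
    """Hypothetical check: TRTLLM needs a fixed granularity and a large enough K."""
    if scale_granularity_mnk != ASSUMED_TRTLLM_GRANULARITY:
        raise ValueError("unsupported scale_granularity_mnk for the TRTLLM backend")
    if k < MIN_K_TRTLLM:
        raise ValueError("K must be at least 256 for the TRTLLM backend")
    return True


print(cutlass_requirement("MN"))
print(trtllm_requirement((1, 128, 128), 512))
```

The real checkers accept the full GEMM argument list so that the decorator can forward the call unchanged, which is why Ruff reports the parameters they do not read.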


2453-2478: LGTM!

The common check function correctly validates tensor shapes and output dtype. The unused parameters are required to match the function signature for the backend_requirement decorator.


2481-2627: LGTM!

The function correctly implements FP8 NT groupwise GEMM with proper backend routing between CUTLASS and TRTLLM. The backend_requirement decorator ensures all validations are performed before execution.


2753-2783: LGTM!

The blockscaled variant correctly wraps gemm_fp8_nt_groupwise with a fixed scale granularity of (128, 128, 128). The implementation is clean and maintains API consistency.
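
The wrapping pattern can be shown in a few lines. Both functions here are hypothetical stand-ins with reduced signatures; only the fixed granularity of (128, 128, 128) comes from the review.

```python
def gemm_groupwise(a, b, scale_granularity_mnk):
    """Hypothetical stand-in that records the granularity it was called with."""
    return {"a": a, "b": b, "scale_granularity_mnk": scale_granularity_mnk}


def gemm_blockscaled(a, b):
    """Block-scaled variant: the groupwise path with the granularity pinned."""
    return gemm_groupwise(a, b, scale_granularity_mnk=(128, 128, 128))


result = gemm_blockscaled("a", "b")
print(result["scale_granularity_mnk"])
```

Passing the granularity by keyword in the wrapper keeps the call robust if the underlying parameter order changes.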


2787-2977: LGTM!

The group GEMM FP8 NT groupwise implementation is well-structured with proper validation and backend routing. The function correctly handles output tensor allocation and delegates to the appropriate SM-specific implementation.

Note: There's minor code duplication in the out_dtype determination logic between the check function (lines 2820-2825) and the main function (lines 2931-2936), but this is acceptable for clarity in the "Chill" review mode.


2980-3176: LGTM!

The MXFP8/MXFP4 groupwise GEMM implementation is comprehensive with thorough validation of tensor shapes, dtypes, and tile parameters. The function properly handles the packed FP4 format and delegates to the SM100 backend.

Note: Similar to the FP8 variant, there's minor code duplication in out_dtype determination (lines 3019-3024 vs. 3145-3150), but this is acceptable.


3212-3359: LGTM!

The grouped DeepGEMM function correctly delegates validation to the deep_gemm module and provides a clean API for grouped matrix multiplication with FP8 data types. The documentation is comprehensive and includes helpful examples.


3362-3513: LGTM!

The batch DeepGEMM function properly handles batch-wise FP8 matrix multiplication with masking. The validation delegation pattern is consistent with the grouped variant, and the documentation clearly explains the masking behavior.

@flashinfer-bot
Collaborator

[CANCELING] Pipeline #40251567: canceled

@jimmyzho
Contributor Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !195 has been updated with latest changes, and the CI pipeline #40334478 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (2)
flashinfer/gemm/gemm_base.py (2)

2822-2835: Refactor duplicated out_dtype resolution logic.

The logic for determining out_dtype (lines 2822-2835) is duplicated in the main group_gemm_fp8_nt_groupwise function (lines 2933-2938). This was also flagged in a previous review.

Consider extracting this into a helper function to maintain a single source of truth.

Apply this refactor:

+def _resolve_out_dtype(out: Optional[torch.Tensor], out_dtype: Optional[torch.dtype], default_dtype: torch.dtype = torch.bfloat16) -> torch.dtype:
+    """Resolve output dtype from out tensor or explicit dtype."""
+    if out is None:
+        return out_dtype if out_dtype is not None else default_dtype
+    else:
+        return out_dtype if out_dtype is not None else out.dtype
+
 @supported_compute_capability([100, 103, 120, 121])
 def _check_group_gemm_fp8_nt_groupwise_problem_size(
     ...
 ):
     ...
-    if out is None:
-        if out_dtype is None:
-            out_dtype = torch.bfloat16
-    else:
-        if out_dtype is None:
-            out_dtype = out.dtype
+    out_dtype = _resolve_out_dtype(out, out_dtype, torch.bfloat16)
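
For readers who want to see the resolution rules in isolation, here is a torch-free sketch of the same logic. `FakeTensor` and the string dtypes are stand-ins introduced for illustration; the precedence mirrors the suggested helper: an explicit `out_dtype` wins, then the dtype of `out`, then the default.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeTensor:
    """Stand-in for a tensor; only the dtype attribute matters here."""
    dtype: str


def resolve_out_dtype(out: Optional[FakeTensor],
                      out_dtype: Optional[str],
                      default_dtype: str = "bfloat16") -> str:
    """Resolve the output dtype from an explicit dtype, the out tensor, or a default."""
    if out_dtype is not None:
        return out_dtype
    return default_dtype if out is None else out.dtype


print(resolve_out_dtype(None, None))
print(resolve_out_dtype(FakeTensor("float16"), None))
print(resolve_out_dtype(FakeTensor("float16"), "bfloat16"))
```

Calling one helper from both the checker and the main function keeps the two code paths from drifting apart.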

3021-3026: Refactor duplicated out_dtype resolution logic.

Similar to the FP8 groupwise checker, this function duplicates the out_dtype resolution logic found in the main group_gemm_mxfp8_mxfp4_nt_groupwise function (lines 3147-3152). This was flagged in a previous review.

Consider using the same helper function suggested in the earlier comment to eliminate duplication across all checkers.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5de3a5c and 2e7fda1.

📒 Files selected for processing (1)
  • flashinfer/gemm/gemm_base.py (7 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • flashinfer/gemm/gemm_base.py
🧬 Code graph analysis (1)
flashinfer/gemm/gemm_base.py (4)
flashinfer/utils.py (2)
  • supported_compute_capability (819-899)
  • backend_requirement (902-1184)
csrc/tvm_ffi_utils.h (1)
  • Tensor (304-306)
include/flashinfer/attention/blackwell/common/pow_2.hpp (1)
  • b (52-52)
flashinfer/deep_gemm.py (3)
  • _check_group_deepgemm_fp8_nt_contiguous_problem_size (1367-1413)
  • m_grouped_fp8_gemm_nt_contiguous (1420-1464)
  • _check_m_grouped_fp8_gemm_nt_masked_problem_size (1468-1526)
🪛 Ruff (0.14.8)
flashinfer/gemm/gemm_base.py

2415-2415: Unused function argument: a

(ARG001)


2416-2416: Unused function argument: b

(ARG001)


2417-2417: Unused function argument: a_scale

(ARG001)


2418-2418: Unused function argument: b_scale

(ARG001)


2420-2420: Unused function argument: mma_sm

(ARG001)


2421-2421: Unused function argument: scale_granularity_mnk

(ARG001)


2422-2422: Unused function argument: out

(ARG001)


2423-2423: Unused function argument: out_dtype

(ARG001)


2424-2424: Unused function argument: backend

(ARG001)


2427-2427: Avoid specifying long messages outside the exception class

(TRY003)


2435-2435: Unused function argument: b

(ARG001)


2436-2436: Unused function argument: a_scale

(ARG001)


2437-2437: Unused function argument: b_scale

(ARG001)


2438-2438: Unused function argument: scale_major_mode

(ARG001)


2439-2439: Unused function argument: mma_sm

(ARG001)


2441-2441: Unused function argument: out

(ARG001)


2442-2442: Unused function argument: out_dtype

(ARG001)


2443-2443: Unused function argument: backend

(ARG001)


2446-2446: Avoid specifying long messages outside the exception class

(TRY003)


2448-2448: Avoid specifying long messages outside the exception class

(TRY003)


2456-2456: Unused function argument: a_scale

(ARG001)


2457-2457: Unused function argument: b_scale

(ARG001)


2458-2458: Unused function argument: scale_major_mode

(ARG001)


2459-2459: Unused function argument: mma_sm

(ARG001)


2460-2460: Unused function argument: scale_granularity_mnk

(ARG001)


2463-2463: Unused function argument: backend

(ARG001)


2466-2466: Avoid specifying long messages outside the exception class

(TRY003)


2469-2471: Avoid specifying long messages outside the exception class

(TRY003)


2795-2795: Unused function argument: scale_granularity_mnk

(ARG001)


2802-2802: Avoid specifying long messages outside the exception class

(TRY003)


2804-2804: Avoid specifying long messages outside the exception class

(TRY003)


2806-2806: Avoid specifying long messages outside the exception class

(TRY003)


2808-2808: Avoid specifying long messages outside the exception class

(TRY003)


2810-2810: Avoid specifying long messages outside the exception class

(TRY003)


2812-2814: Avoid specifying long messages outside the exception class

(TRY003)


2816-2816: Avoid specifying long messages outside the exception class

(TRY003)


2829-2831: Avoid specifying long messages outside the exception class

(TRY003)


2833-2835: Avoid specifying long messages outside the exception class

(TRY003)


2840-2840: Avoid specifying long messages outside the exception class

(TRY003)


2842-2842: Avoid specifying long messages outside the exception class

(TRY003)


2844-2844: Avoid specifying long messages outside the exception class

(TRY003)


2850-2852: Avoid specifying long messages outside the exception class

(TRY003)


2998-3000: Avoid specifying long messages outside the exception class

(TRY003)


3002-3002: Avoid specifying long messages outside the exception class

(TRY003)


3004-3004: Avoid specifying long messages outside the exception class

(TRY003)


3006-3006: Avoid specifying long messages outside the exception class

(TRY003)


3008-3008: Avoid specifying long messages outside the exception class

(TRY003)


3010-3010: Avoid specifying long messages outside the exception class

(TRY003)


3012-3012: Avoid specifying long messages outside the exception class

(TRY003)


3014-3014: Avoid specifying long messages outside the exception class

(TRY003)


3016-3016: Avoid specifying long messages outside the exception class

(TRY003)


3018-3018: Avoid specifying long messages outside the exception class

(TRY003)


3029-3031: Avoid specifying long messages outside the exception class

(TRY003)


3035-3037: Avoid specifying long messages outside the exception class

(TRY003)


3044-3046: Avoid specifying long messages outside the exception class

(TRY003)


3051-3051: Avoid specifying long messages outside the exception class

(TRY003)


3053-3053: Avoid specifying long messages outside the exception class

(TRY003)


3058-3058: Avoid specifying long messages outside the exception class

(TRY003)


3060-3060: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (11)
flashinfer/gemm/gemm_base.py (11)

2413-2429: LGTM!

The CUTLASS backend requirement check is minimal and focused, validating only that scale_major_mode is provided. The unused parameter warnings from static analysis are expected for requirement checker functions.


2432-2450: LGTM!

The TRTLLM backend requirement check appropriately enforces stricter constraints: fixed scale granularity and minimum K dimension of 256.


2453-2478: LGTM!

The common problem size check validates essential shape constraints and output dtype requirements for FP8 GEMM operations.


2715-2752: LGTM!

The blockscaled problem size check correctly delegates to the groupwise checker with the fixed scale granularity of (128, 128, 128). The missing scale_granularity_mnk argument issue from past reviews has been resolved.


2848-2852: Good safeguard for known SM120/121 limitation.

The explicit check prevents using group_gemm_fp8_nt_groupwise with multiple groups on SM120/121 where correctness issues exist.
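
A guard of this kind might look like the sketch below. The function name, its parameters, and the error text are hypothetical; only the SM120/121 restriction on multiple groups comes from the review.

```python
def check_num_groups(compute_capability: int, num_groups: int) -> None:
    """Hypothetical guard: reject multi-group calls on SM120/121."""
    if compute_capability in (120, 121) and num_groups > 1:
        raise ValueError(
            "multiple groups are not supported on SM120/121 for this operation"
        )


check_num_groups(100, 4)  # accepted: other architecture
check_num_groups(120, 1)  # accepted: single group
try:
    check_num_groups(121, 2)
    rejected = False
except ValueError:
    rejected = True
print(rejected)
```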


2857-2860: LGTM!

The @backend_requirement decorator with empty backend_checks and only a common_check is the correct pattern for validation-only usage without backend selection.


3065-3068: LGTM!

Consistent use of the validation-only pattern with @backend_requirement.


3214-3235: LGTM!

The checker appropriately delegates to the DeepGEMM module's validation function while adapting the API format.


3238-3241: LGTM!

Consistent validation-only pattern for DeepGEMM operations.


3368-3390: LGTM!

The batch DeepGEMM checker follows the same delegation pattern, appropriately adapting the API to the DeepGEMM module's validation function.


3393-3396: LGTM!

Final consistent use of the validation-only pattern with @backend_requirement.

out_dtype = out_dtype or torch.bfloat16
out = torch.empty(a.shape[0], b.shape[1], dtype=out_dtype, device=a.device)

print("GOT HERE")
Contributor

⚠️ Potential issue | 🔴 Critical

Remove debug print statement.

This debug print statement should not be committed to the codebase as it will pollute production output.

Apply this diff:

-    print("GOT HERE")
     m_grouped_fp8_gemm_nt_contiguous(
         (a, a_scale), (b, b_scale), out, m_indices, scale_granularity_mnk
     )
🤖 Prompt for AI Agents
In flashinfer/gemm/gemm_base.py around line 3360 there is a leftover debug print
statement ("GOT HERE") that should be removed; delete that print call so no
debug output is emitted from the production code and ensure no other stray debug
prints remain nearby.
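
If a trace message is still wanted at that point, one common alternative to a bare print is the standard `logging` module, which stays silent unless the caller enables debug output. This is a generic sketch, not a change proposed in the PR; the logger name and `run_kernel` are hypothetical.

```python
import logging

logger = logging.getLogger("gemm_example")  # hypothetical logger name


def run_kernel(launch):
    """Call `launch`, emitting a debug record instead of printing."""
    logger.debug("launching grouped GEMM")  # hidden unless DEBUG is enabled
    return launch()


value = run_kernel(lambda: 42)
print(value)
```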

@yzh119 yzh119 merged commit c7752ab into flashinfer-ai:main Dec 17, 2025
4 checks passed
@coderabbitai coderabbitai bot mentioned this pull request Dec 24, 2025
5 tasks
BingooYang pushed a commit to BingooYang/flashinfer that referenced this pull request Mar 13, 2026