misc: support checks for gemm#2214
Walkthrough
Adds dedicated pre-check validator functions and backend_requirement gating for FP8 GEMM groupwise/blockscaled/DeepGEMM paths; replaces inline asserts with explicit ValueError raises in DeepGEMM contiguous path; introduces multiple new public GEMM entry points and backend-specific requirement checks.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning, 2 inconclusive)
Summary of Changes
Hello @jimmyzho, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request significantly refactors the input validation and compute capability checks for several General Matrix Multiply (GEMM) operations within the
Highlights
/bot run
Code Review
This pull request refactors several GEMM functions to extract validation logic into separate check functions, which are then used with a new @backend_requirement decorator. This is a good pattern for improving code structure and reusability. My review includes one critical issue related to incorrect shape validation and a missing check for positive dimensions. I've also pointed out a couple of instances of duplicated logic for out_dtype resolution that could be refactored to improve maintainability.
flashinfer/deep_gemm.py
Outdated
    assert b.dtype == torch.float8_e4m3fn
    assert d.dtype == torch.bfloat16
    assert m_indices.dtype == torch.int32
    if m != m_ or k != k_ or n != n_ or m__ != m_ or num_groups != m__:
There appears to be a bug in the shape validation logic. The condition num_groups != m__ is incorrect, as num_groups (from b.shape[0]) is not necessarily equal to m__ (from m_indices.numel(), which is m). This check was not present in the original assert statement and seems to have been added by mistake.
Additionally, the check for positive dimensions (n > 0, k > 0, num_groups > 0) from the original assert has been removed and should be restored.
I suggest correcting the shape check and reintroducing the positive dimension checks. For example:
    if m != m_ or k != k_ or n != n_ or m__ != m_:
        raise ValueError(
            f"Shape mismatch. m = {m}, m_ = {m_}, k = {k}, k_ = {k_}, n = {n}, n_ = {n_}, m__ = {m__}"
        )
    if n <= 0 or k <= 0 or num_groups <= 0:
        raise ValueError(
            f"n, k, and num_groups must be positive, but got n={n}, k={k}, num_groups={num_groups}"
        )
Suggested change:
    - if m != m_ or k != k_ or n != n_ or m__ != m_ or num_groups != m__:
    + if m != m_ or k != k_ or n != n_ or m__ != m_:
The num_groups != m__ check looks confusing to me as well. @jimmyzho, would you mind double-checking? I don't see it in the existing codebase.
    if out is None:
        if out_dtype is None:
            out_dtype = torch.bfloat16
    else:
        if out_dtype is None:
            out_dtype = out.dtype
The logic for determining out_dtype is duplicated in this check function and in the main group_gemm_fp8_nt_groupwise function (lines 2928-2933). This duplication can lead to maintenance issues.
To improve maintainability, consider refactoring this logic into a private helper function that can be called from both places. This would ensure a single source of truth for out_dtype resolution.
    if out is None:
        if out_dtype is None:
            out_dtype = torch.bfloat16
    else:
        if out_dtype is None:
            out_dtype = out.dtype
Actionable comments posted: 3
🧹 Nitpick comments (3)
flashinfer/deep_gemm.py (3)
1416-1437: LGTM - decorator wiring is correct.
The @backend_requirement decorator with empty backend_checks and common_check is correctly wired to enforce validation before execution.
Minor: The unpacked k_ at line 1437 is unused. Consider using _ as a placeholder:
    - num_groups, n, k_ = b.shape
    + num_groups, n, _ = b.shape
1489-1490: Consider using underscore for unused unpacked variables.
The sfa and sfb tensors are unpacked but not used in the validation. Using underscore makes the intent clearer:
    - a, sfa = a_fp8
    - b, sfb = b_fp8
    + a, _ = a_fp8
    + b, _ = b_fp8
1548-1549: Redundant assertions after pre-check validation.
These assertions duplicate the validation in _check_m_grouped_fp8_gemm_nt_masked_problem_size. They are technically redundant when the decorator runs, but serve as defense-in-depth when skip_check=True is passed.
Consider removing them if redundancy is undesirable, or adding a comment explaining they guard against skip_check=True usage:
    + # Guard for skip_check=True scenarios
      assert major_a == major_b == MajorTypeAB.KMajor
      assert masked_m.is_contiguous()
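To make the wiring discussed in these nitpicks concrete, here is a toy decorator that runs a common check plus an optional per-backend check before the wrapped function, and bypasses both when `skip_check=True`. It is a sketch under stated assumptions, not flashinfer's `backend_requirement`: the decorator name `requirement`, the way `skip_check` is consumed, and the toy functions are all invented for illustration.

```python
import functools
from typing import Callable, Dict, Optional


def requirement(backend_checks: Dict[str, Callable[..., bool]],
                common_check: Optional[Callable[..., bool]] = None):
    """Toy stand-in for a backend-gating decorator (not flashinfer's code)."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, skip_check: bool = False, **kwargs):
            if not skip_check:
                # The common check runs first, for every backend.
                if common_check is not None:
                    common_check(*args, **kwargs)
                # This sketch only inspects a keyword `backend` argument.
                backend = kwargs.get("backend")
                if backend in backend_checks:
                    backend_checks[backend](*args, **kwargs)
            return func(*args, **kwargs)

        return wrapper

    return decorator


def _check_problem_size(m: int, k: int, backend: str = "cutlass") -> bool:
    if m <= 0 or k <= 0:
        raise ValueError(f"m and k must be positive, got m={m}, k={k}")
    return True


def _trtllm_requirement(m: int, k: int, backend: str = "trtllm") -> bool:
    if k < 256:
        raise ValueError(f"k must be >= 256 for this backend, got k={k}")
    return True


@requirement({"trtllm": _trtllm_requirement}, common_check=_check_problem_size)
def toy_gemm(m: int, k: int, backend: str = "cutlass") -> str:
    return f"ran {backend} with m={m}, k={k}"
```

With this shape, `toy_gemm(4, 128, backend="trtllm")` raises a `ValueError` from the backend check, while the same call with `skip_check=True` reaches the function body unchecked. That second path is the case where in-body assertions would still act as a guard.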
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
flashinfer/deep_gemm.py (4 hunks)
flashinfer/gemm/gemm_base.py (7 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
flashinfer/deep_gemm.py (1)
flashinfer/utils.py (4)
ceil_div (622-633), round_up (636-638), supported_compute_capability (819-899), backend_requirement (902-1184)
flashinfer/gemm/gemm_base.py (2)
flashinfer/utils.py (4)
supported_compute_capability (819-899), backend_requirement (902-1184), is_sm120a_supported (551-553), is_sm121a_supported (556-558)
flashinfer/deep_gemm.py (2)
_check_group_deepgemm_fp8_nt_contiguous_problem_size (1367-1413), _check_m_grouped_fp8_gemm_nt_masked_problem_size (1468-1526)
🪛 Ruff (0.14.8)
flashinfer/deep_gemm.py
1372-1372: Unused function argument: recipe
(ARG001)
1373-1373: Unused function argument: compiled_dims
(ARG001)
1379-1379: Avoid specifying long messages outside the exception class
(TRY003)
1381-1381: Avoid specifying long messages outside the exception class
(TRY003)
1384-1386: Avoid specifying long messages outside the exception class
(TRY003)
1397-1399: Avoid specifying long messages outside the exception class
(TRY003)
1401-1401: Avoid specifying long messages outside the exception class
(TRY003)
1403-1403: Avoid specifying long messages outside the exception class
(TRY003)
1405-1405: Avoid specifying long messages outside the exception class
(TRY003)
1407-1407: Avoid specifying long messages outside the exception class
(TRY003)
1411-1411: Avoid specifying long messages outside the exception class
(TRY003)
1437-1437: Unpacked variable k_ is never used
(RUF059)
1474-1474: Unused function argument: recipe
(ARG001)
1475-1475: Unused function argument: compiled_dims
(ARG001)
1480-1480: Avoid specifying long messages outside the exception class
(TRY003)
1482-1482: Avoid specifying long messages outside the exception class
(TRY003)
1485-1487: Avoid specifying long messages outside the exception class
(TRY003)
1489-1489: Unpacked variable sfa is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
1490-1490: Unpacked variable sfb is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
1502-1504: Avoid specifying long messages outside the exception class
(TRY003)
1506-1508: Avoid specifying long messages outside the exception class
(TRY003)
1510-1512: Avoid specifying long messages outside the exception class
(TRY003)
1514-1514: Avoid specifying long messages outside the exception class
(TRY003)
1516-1516: Avoid specifying long messages outside the exception class
(TRY003)
1518-1518: Avoid specifying long messages outside the exception class
(TRY003)
1520-1520: Avoid specifying long messages outside the exception class
(TRY003)
1524-1524: Avoid specifying long messages outside the exception class
(TRY003)
flashinfer/gemm/gemm_base.py
2415-2415: Unused function argument: a
(ARG001)
2416-2416: Unused function argument: b
(ARG001)
2417-2417: Unused function argument: a_scale
(ARG001)
2418-2418: Unused function argument: b_scale
(ARG001)
2420-2420: Unused function argument: mma_sm
(ARG001)
2421-2421: Unused function argument: scale_granularity_mnk
(ARG001)
2422-2422: Unused function argument: out
(ARG001)
2423-2423: Unused function argument: out_dtype
(ARG001)
2424-2424: Unused function argument: backend
(ARG001)
2427-2427: Avoid specifying long messages outside the exception class
(TRY003)
2435-2435: Unused function argument: b
(ARG001)
2436-2436: Unused function argument: a_scale
(ARG001)
2437-2437: Unused function argument: b_scale
(ARG001)
2438-2438: Unused function argument: scale_major_mode
(ARG001)
2439-2439: Unused function argument: mma_sm
(ARG001)
2441-2441: Unused function argument: out
(ARG001)
2442-2442: Unused function argument: out_dtype
(ARG001)
2443-2443: Unused function argument: backend
(ARG001)
2446-2446: Avoid specifying long messages outside the exception class
(TRY003)
2448-2448: Avoid specifying long messages outside the exception class
(TRY003)
2456-2456: Unused function argument: a_scale
(ARG001)
2457-2457: Unused function argument: b_scale
(ARG001)
2458-2458: Unused function argument: scale_major_mode
(ARG001)
2459-2459: Unused function argument: mma_sm
(ARG001)
2460-2460: Unused function argument: scale_granularity_mnk
(ARG001)
2461-2461: Unused function argument: out
(ARG001)
2463-2463: Unused function argument: backend
(ARG001)
2466-2466: Avoid specifying long messages outside the exception class
(TRY003)
2469-2471: Avoid specifying long messages outside the exception class
(TRY003)
2790-2790: Unused function argument: scale_granularity_mnk
(ARG001)
2797-2797: Avoid specifying long messages outside the exception class
(TRY003)
2799-2799: Avoid specifying long messages outside the exception class
(TRY003)
2801-2801: Avoid specifying long messages outside the exception class
(TRY003)
2803-2803: Avoid specifying long messages outside the exception class
(TRY003)
2805-2805: Avoid specifying long messages outside the exception class
(TRY003)
2807-2809: Avoid specifying long messages outside the exception class
(TRY003)
2811-2811: Avoid specifying long messages outside the exception class
(TRY003)
2824-2826: Avoid specifying long messages outside the exception class
(TRY003)
2828-2830: Avoid specifying long messages outside the exception class
(TRY003)
2835-2835: Avoid specifying long messages outside the exception class
(TRY003)
2837-2837: Avoid specifying long messages outside the exception class
(TRY003)
2839-2839: Avoid specifying long messages outside the exception class
(TRY003)
2845-2847: Avoid specifying long messages outside the exception class
(TRY003)
2993-2995: Avoid specifying long messages outside the exception class
(TRY003)
2997-2997: Avoid specifying long messages outside the exception class
(TRY003)
2999-2999: Avoid specifying long messages outside the exception class
(TRY003)
3001-3001: Avoid specifying long messages outside the exception class
(TRY003)
3003-3003: Avoid specifying long messages outside the exception class
(TRY003)
3005-3005: Avoid specifying long messages outside the exception class
(TRY003)
3007-3007: Avoid specifying long messages outside the exception class
(TRY003)
3009-3009: Avoid specifying long messages outside the exception class
(TRY003)
3011-3011: Avoid specifying long messages outside the exception class
(TRY003)
3013-3013: Avoid specifying long messages outside the exception class
(TRY003)
3024-3026: Avoid specifying long messages outside the exception class
(TRY003)
3030-3032: Avoid specifying long messages outside the exception class
(TRY003)
3039-3041: Avoid specifying long messages outside the exception class
(TRY003)
3046-3046: Avoid specifying long messages outside the exception class
(TRY003)
3048-3048: Avoid specifying long messages outside the exception class
(TRY003)
3053-3053: Avoid specifying long messages outside the exception class
(TRY003)
3055-3055: Avoid specifying long messages outside the exception class
(TRY003)
3218-3218: Unused function argument: out_dtype
(ARG001)
3369-3369: Unused function argument: out_dtype
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (16)
flashinfer/deep_gemm.py (1)
48-53: LGTM!
The new imports for supported_compute_capability and backend_requirement are correctly added and align with the decorator usage in the file.
flashinfer/gemm/gemm_base.py (15)
2413-2429: LGTM!
The CUTLASS backend requirement correctly validates that scale_major_mode is provided, which is required for CUTLASS FP8 groupwise GEMM.
2432-2450: LGTM!
The TRT-LLM backend requirements correctly enforce the (1, 128, 128) scale granularity and minimum k-dimension of 256.
2453-2475: LGTM!
The common check correctly validates tensor dimensions and output dtype for FP8 NT groupwise GEMM.
2478-2484: LGTM!
The @backend_requirement decorator is correctly wired with both backend-specific checks and a common check.
2750-2753: Decorator wiring is correct.
The @backend_requirement with empty backend_checks and common_check correctly delegates to the pre-check function. Note: Fix the function call issues in _check_gemm_fp8_nt_blockscaled_problem_size as noted above.
2783-2849: LGTM - comprehensive validation with clear constraints.
The pre-check function provides thorough validation for group GEMM FP8 operations. The SM120/121 restriction for num_groups > 1 is properly documented with a runtime error.
Note: The commented assertion at line 2813 (a.shape[0] == m_indptr[-1].item()) trades correctness checking for performance. Consider documenting this tradeoff in the function's docstring.
2852-2855: LGTM!
The decorator is correctly wired with the pre-check function.
2974-2974: LGTM!
The explicit return out statement ensures the function properly returns the output tensor.
2977-3057: LGTM - thorough MXFP8/MXFP4 validation.
The pre-check function correctly validates:
- Input dtypes (float8 for a, uint8 for packed FP4 b and scales)
- Tile size constraints for the kernel
- Shape alignment requirements
The validation logic properly handles the FP4 packing (k dimension is doubled from uint8 shape).
3060-3063: LGTM!
The decorator is correctly wired.
3141-3154: LGTM!
The out_dtype determination logic is correctly maintained in the function body for tensor allocation, while the pre-check handles validation.
3209-3226: LGTM - correctly delegates to deep_gemm validation.
The function correctly packs the individual tensors into tuples (a, a_scale) and (b, b_scale) to match the a_fp8 and b_fp8 format expected by _check_group_deepgemm_fp8_nt_contiguous_problem_size.
3229-3232: LGTM!
The decorator is correctly wired.
3359-3375: LGTM - correctly delegates to deep_gemm validation.
The function properly delegates to _check_m_grouped_fp8_gemm_nt_masked_problem_size with correctly packed tensor tuples.
3378-3381: LGTM!
The decorator is correctly wired.
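The FP4 packing remark in the 2977-3057 comment above (k is doubled from the uint8 shape) can be illustrated with a small shape-only sketch. The function names are hypothetical, and it assumes k is the last dimension of both shapes; the actual layout in flashinfer may differ.

```python
from typing import Tuple


def logical_k_from_packed(b_packed_shape: Tuple[int, ...]) -> int:
    """Two 4-bit values share one uint8, so the logical k is twice the stored size."""
    return b_packed_shape[-1] * 2


def check_k_matches(a_shape: Tuple[int, ...], b_packed_shape: Tuple[int, ...]) -> None:
    """Raise if a's k does not match the k implied by the packed b (assumed layout)."""
    k_a = a_shape[-1]
    k_b = logical_k_from_packed(b_packed_shape)
    if k_a != k_b:
        raise ValueError(f"k mismatch: a has k={k_a}, packed b implies k={k_b}")


# a is (m, k) = (4, 128); b stores 64 bytes along k, i.e. 128 FP4 values.
check_k_matches((4, 128), (8, 64))
```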
flashinfer/deep_gemm.py
Outdated
    if m != m_ or k != k_ or n != n_ or m__ != m_ or num_groups != m__:
        raise ValueError(
            f"Shape mismatch. m = {m}, m_ = {m_}, k = {k}, k_ = {k_}, n = {n}, n_ = {n_}, m__ = {m__}"
        )
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Find usages of m_grouped_fp8_gemm_nt_contiguous
rg -n "m_grouped_fp8_gemm_nt_contiguous\(" flashinfer/deep_gemm.py --type py -A 5
Repository: flashinfer-ai/flashinfer
Length of output: 328
🏁 Script executed:
#!/bin/bash
# Find the function definition and its docstring
rg -n -B 5 "def m_grouped_fp8_gemm_nt_contiguous" flashinfer/deep_gemm.py --type py -A 20
Repository: flashinfer-ai/flashinfer
Length of output: 3255
🏁 Script executed:
#!/bin/bash
# Search for test files that might show expected usage of m_indices
fd -e py "test" | head -20 | xargs rg -l "m_grouped_fp8_gemm_nt_contiguous" 2>/dev/null
Repository: flashinfer-ai/flashinfer
Length of output: 52
🏁 Script executed:
#!/bin/bash
# Look for examples or documentation about m_indices
rg -n "m_indices" flashinfer/deep_gemm.py --type py -B 2 -A 2 | head -100
Repository: flashinfer-ai/flashinfer
Length of output: 1992
🏁 Script executed:
#!/bin/bash
# Look for docstring of m_grouped_fp8_gemm_nt_contiguous
sed -n '1420,1450p' flashinfer/deep_gemm.py
Repository: flashinfer-ai/flashinfer
Length of output: 1038
🏁 Script executed:
#!/bin/bash
# Search for test files that use this function
find . -name "*.py" -path "*/test*" -o -name "*test*.py" | xargs rg -l "m_grouped_fp8_gemm_nt_contiguous" 2>/dev/null | head -5
Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
#!/bin/bash
# Look for any comments or docstrings about m_indices semantics in the check function
sed -n '1366,1413p' flashinfer/deep_gemm.py
Repository: flashinfer-ai/flashinfer
Length of output: 1928
🏁 Script executed:
#!/bin/bash
# Search for any tests or examples of m_grouped_fp8_gemm_nt_contiguous
rg -n "m_grouped_fp8_gemm_nt_contiguous" --type py -B 3 -A 3 | grep -A 10 -B 10 "m_indices"
Repository: flashinfer-ai/flashinfer
Length of output: 2163
🏁 Script executed:
#!/bin/bash
# Look for kernel generation code that might explain m_indices
sed -n '1030,1139p' flashinfer/deep_gemm.py | grep -A 5 -B 5 "GROUPED_LAYOUT"
Repository: flashinfer-ai/flashinfer
Length of output: 422
🏁 Script executed:
#!/bin/bash
# Check if there are any comments in the kernel generator about group layout
rg -n "GROUPED_LAYOUT" flashinfer/ --type py -B 5 -A 5
Repository: flashinfer-ai/flashinfer
Length of output: 2058
🏁 Script executed:
#!/bin/bash
# Look at the full public API function that calls m_grouped_fp8_gemm_nt_contiguous
sed -n '3320,3360p' flashinfer/gemm/gemm_base.py
Repository: flashinfer-ai/flashinfer
Length of output: 1769
🏁 Script executed:
#!/bin/bash
# Search for documentation or examples of how m_indices should be constructed
rg -n "m_indices" flashinfer/gemm/gemm_base.py -B 5 -A 5
Repository: flashinfer-ai/flashinfer
Length of output: 3943
🏁 Script executed:
#!/bin/bash
# Check if there's a "masked" variant that might clarify the semantics
rg -n "m_grouped_fp8_gemm_nt_masked" flashinfer/deep_gemm.py --type py -A 20 | head -60
Repository: flashinfer-ai/flashinfer
Length of output: 1728
The shape check condition num_groups != m__ is incorrect and contradicts the documented API.
The condition enforces num_groups == m_indices.numel() == m, which requires the number of groups to equal the total number of rows. However, the public API documentation (gemm_base.py lines 3279-3283) and example (lines 3326-3329) clearly define m_indices as a per-row group assignment tensor where each element specifies which group (0 to num_groups-1) that row belongs to. The example shows num_groups=8 with m=256, so m_indices has 256 elements with values in [0, 8).
The check should be removed or replaced with validation that m_indices values are within the valid range [0, num_groups).
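The replacement validation proposed above can be sketched in plain Python. The function name `check_m_indices` is hypothetical, and a list of ints stands in for the `m_indices` tensor so the snippet runs without PyTorch. On a GPU tensor, reading the values back for such a check would force a host-device synchronization, so whether to run it by default is a separate tradeoff.

```python
from typing import Sequence


def check_m_indices(m_indices: Sequence[int], m: int, num_groups: int) -> None:
    """Hypothetical check: one group id per row, each within [0, num_groups)."""
    if len(m_indices) != m:
        raise ValueError(
            f"m_indices must have one entry per row: got {len(m_indices)}, expected {m}"
        )
    for row, group in enumerate(m_indices):
        if not 0 <= group < num_groups:
            raise ValueError(
                f"m_indices[{row}] = {group} is outside [0, {num_groups})"
            )


# Mirrors the documented example sizes: 256 rows assigned to 8 groups,
# here with 32 consecutive rows per group (the grouping itself is an assumption).
m, num_groups = 256, 8
m_indices = [row // 32 for row in range(m)]
check_m_indices(m_indices, m, num_groups)
```

Note that `num_groups` (8) and `m` (256) differ here, which is exactly the case the `num_groups != m__` condition would wrongly reject.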
🧰 Tools
🪛 Ruff (0.14.8)
1397-1399: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In flashinfer/deep_gemm.py around lines 1396 to 1399, remove the incorrect check
comparing num_groups to m__ (which enforces num_groups == number of rows) and
instead validate the m_indices tensor values: ensure m_indices has integer
dtype, has length equal to m (if not already checked earlier), and that every
element satisfies 0 <= m_indices[i] < num_groups; replace the rejected equality
test with this range-check and raise a ValueError with a clear message if any
index is out of bounds.
/bot run
Actionable comments posted: 0
♻️ Duplicate comments (1)
flashinfer/gemm/gemm_base.py (1)
2726-2748: Fix critical parameter mismatch in function calls.
Both calls to _check_gemm_fp8_nt_groupwise_problem_size (line 2726) and _cutlass_gemm_fp8_nt_groupwise_requirement (line 2738) are missing the scale_granularity_mnk parameter. This causes positional arguments to be incorrectly bound:
- out → scale_granularity_mnk
- out_dtype → out
- out_dtype itself receives no value (backend="cutlass" is passed by keyword, so it does not shift)
This will result in a TypeError or incorrect validation at runtime.
Apply this diff to fix both calls:
      _check_gemm_fp8_nt_groupwise_problem_size(
          a,
          b,
          a_scale,
          b_scale,
          scale_major_mode,
          mma_sm,
    +     (128, 128, 128),  # scale_granularity_mnk, positional because positional arguments follow
          out,
          out_dtype,
          backend="cutlass",
      )
      _cutlass_gemm_fp8_nt_groupwise_requirement(
          a,
          b,
          a_scale,
          b_scale,
          scale_major_mode,
          mma_sm,
    +     (128, 128, 128),  # scale_granularity_mnk, positional because positional arguments follow
          out,
          out_dtype,
          backend="cutlass",
      )
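The binding shift described in this comment can be reproduced with the standard library alone. The `check` function below is a hypothetical stand-in that only mirrors the parameter order named in the comment; it is not the flashinfer function, and its parameters have no defaults, which is an assumption.

```python
import inspect


def check(a, b, a_scale, b_scale, scale_major_mode, mma_sm,
          scale_granularity_mnk, out, out_dtype, backend):
    """Hypothetical signature mirroring the parameter order discussed above."""
    return None


sig = inspect.signature(check)

# A call that forgets scale_granularity_mnk: later positionals shift left by one.
call_args = ("a", "b", "a_scale", "b_scale", "MN", 1, "out", "out_dtype")

# bind_partial tolerates the missing argument, so the shifted binding is visible.
shifted = dict(sig.bind_partial(*call_args, backend="cutlass").arguments)

# bind enforces the full signature, like a real call would.
try:
    sig.bind(*call_args, backend="cutlass")
    error = None
except TypeError as exc:
    error = str(exc)
```

Here `shifted["scale_granularity_mnk"]` holds the value meant for `out`, `shifted["out"]` holds the value meant for `out_dtype`, and `out_dtype` is unbound. If the real parameters have defaults, the call would not raise and the misbinding would pass silently.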
🧹 Nitpick comments (1)
flashinfer/deep_gemm.py (1)
1548-1549: Remove redundant assertions.
These assertions duplicate checks already performed by the _check_m_grouped_fp8_gemm_nt_masked_problem_size pre-check function (lines 1479-1482 and 1484-1487). Since the @backend_requirement decorator ensures the pre-check runs before this function, these assertions are unnecessary.
Apply this diff to remove the redundant checks:
    - assert major_a == major_b == MajorTypeAB.KMajor
    - assert masked_m.is_contiguous()
    -
      a, sfa = a_fp8
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
flashinfer/deep_gemm.py
flashinfer/gemm/gemm_base.py
🧬 Code graph analysis (1)
flashinfer/deep_gemm.py (1)
flashinfer/utils.py (4)
ceil_div (622-633), round_up (636-638), supported_compute_capability (819-899), backend_requirement (902-1184)
🪛 Ruff (0.14.8)
flashinfer/deep_gemm.py
1372-1372: Unused function argument: recipe
(ARG001)
1373-1373: Unused function argument: compiled_dims
(ARG001)
1379-1379: Avoid specifying long messages outside the exception class
(TRY003)
1381-1381: Avoid specifying long messages outside the exception class
(TRY003)
1384-1386: Avoid specifying long messages outside the exception class
(TRY003)
1397-1399: Avoid specifying long messages outside the exception class
(TRY003)
1401-1401: Avoid specifying long messages outside the exception class
(TRY003)
1403-1403: Avoid specifying long messages outside the exception class
(TRY003)
1405-1405: Avoid specifying long messages outside the exception class
(TRY003)
1407-1407: Avoid specifying long messages outside the exception class
(TRY003)
1411-1411: Avoid specifying long messages outside the exception class
(TRY003)
1437-1437: Unpacked variable k_ is never used
(RUF059)
1474-1474: Unused function argument: recipe
(ARG001)
1475-1475: Unused function argument: compiled_dims
(ARG001)
1480-1480: Avoid specifying long messages outside the exception class
(TRY003)
1482-1482: Avoid specifying long messages outside the exception class
(TRY003)
1485-1487: Avoid specifying long messages outside the exception class
(TRY003)
1489-1489: Unpacked variable sfa is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
1490-1490: Unpacked variable sfb is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
1502-1504: Avoid specifying long messages outside the exception class
(TRY003)
1506-1508: Avoid specifying long messages outside the exception class
(TRY003)
1510-1512: Avoid specifying long messages outside the exception class
(TRY003)
1514-1514: Avoid specifying long messages outside the exception class
(TRY003)
1516-1516: Avoid specifying long messages outside the exception class
(TRY003)
1518-1518: Avoid specifying long messages outside the exception class
(TRY003)
1520-1520: Avoid specifying long messages outside the exception class
(TRY003)
1524-1524: Avoid specifying long messages outside the exception class
(TRY003)
flashinfer/gemm/gemm_base.py
2415-2415: Unused function argument: a
(ARG001)
2416-2416: Unused function argument: b
(ARG001)
2417-2417: Unused function argument: a_scale
(ARG001)
2418-2418: Unused function argument: b_scale
(ARG001)
2420-2420: Unused function argument: mma_sm
(ARG001)
2421-2421: Unused function argument: scale_granularity_mnk
(ARG001)
2422-2422: Unused function argument: out
(ARG001)
2423-2423: Unused function argument: out_dtype
(ARG001)
2424-2424: Unused function argument: backend
(ARG001)
2427-2427: Avoid specifying long messages outside the exception class
(TRY003)
2435-2435: Unused function argument: b
(ARG001)
2436-2436: Unused function argument: a_scale
(ARG001)
2437-2437: Unused function argument: b_scale
(ARG001)
2438-2438: Unused function argument: scale_major_mode
(ARG001)
2439-2439: Unused function argument: mma_sm
(ARG001)
2441-2441: Unused function argument: out
(ARG001)
2442-2442: Unused function argument: out_dtype
(ARG001)
2443-2443: Unused function argument: backend
(ARG001)
2446-2446: Avoid specifying long messages outside the exception class
(TRY003)
2448-2448: Avoid specifying long messages outside the exception class
(TRY003)
2456-2456: Unused function argument: a_scale
(ARG001)
2457-2457: Unused function argument: b_scale
(ARG001)
2458-2458: Unused function argument: scale_major_mode
(ARG001)
2459-2459: Unused function argument: mma_sm
(ARG001)
2460-2460: Unused function argument: scale_granularity_mnk
(ARG001)
2463-2463: Unused function argument: backend
(ARG001)
2466-2466: Avoid specifying long messages outside the exception class
(TRY003)
2469-2471: Avoid specifying long messages outside the exception class
(TRY003)
2793-2793: Unused function argument: scale_granularity_mnk
(ARG001)
2800-2800: Avoid specifying long messages outside the exception class
(TRY003)
2802-2802: Avoid specifying long messages outside the exception class
(TRY003)
2804-2804: Avoid specifying long messages outside the exception class
(TRY003)
2806-2806: Avoid specifying long messages outside the exception class
(TRY003)
2808-2808: Avoid specifying long messages outside the exception class
(TRY003)
2810-2812: Avoid specifying long messages outside the exception class
(TRY003)
2814-2814: Avoid specifying long messages outside the exception class
(TRY003)
2827-2829: Avoid specifying long messages outside the exception class
(TRY003)
2831-2833: Avoid specifying long messages outside the exception class
(TRY003)
2838-2838: Avoid specifying long messages outside the exception class
(TRY003)
2840-2840: Avoid specifying long messages outside the exception class
(TRY003)
2842-2842: Avoid specifying long messages outside the exception class
(TRY003)
2848-2850: Avoid specifying long messages outside the exception class
(TRY003)
2996-2998: Avoid specifying long messages outside the exception class
(TRY003)
3000-3000: Avoid specifying long messages outside the exception class
(TRY003)
3002-3002: Avoid specifying long messages outside the exception class
(TRY003)
3004-3004: Avoid specifying long messages outside the exception class
(TRY003)
3006-3006: Avoid specifying long messages outside the exception class
(TRY003)
3008-3008: Avoid specifying long messages outside the exception class
(TRY003)
3010-3010: Avoid specifying long messages outside the exception class
(TRY003)
3012-3012: Avoid specifying long messages outside the exception class
(TRY003)
3014-3014: Avoid specifying long messages outside the exception class
(TRY003)
3016-3016: Avoid specifying long messages outside the exception class
(TRY003)
3027-3029: Avoid specifying long messages outside the exception class
(TRY003)
3033-3035: Avoid specifying long messages outside the exception class
(TRY003)
3042-3044: Avoid specifying long messages outside the exception class
(TRY003)
3049-3049: Avoid specifying long messages outside the exception class
(TRY003)
3051-3051: Avoid specifying long messages outside the exception class
(TRY003)
3056-3056: Avoid specifying long messages outside the exception class
(TRY003)
3058-3058: Avoid specifying long messages outside the exception class
(TRY003)
3221-3221: Unused function argument: out_dtype
(ARG001)
3372-3372: Unused function argument: out_dtype
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (11)
flashinfer/deep_gemm.py (3)
48-53: LGTM!The added imports are correctly used in the new pre-check functions and backend requirement decorators below.
1366-1413: LGTM!The shape validation logic is correct. The check at line 1396 properly verifies that
m_indices.numel()equalsm, ensuring the index tensor has the right number of elements. The unused parameters (recipe,compiled_dims) are part of the function signature to maintain API consistency with the backend_requirement decorator.
1416-1464: LGTM!The backend_requirement decorator is correctly configured with an empty backend_checks dict (single implicit backend) and the common_check properly validates inputs before execution.
flashinfer/gemm/gemm_base.py (8)
2413-2450: LGTM! The requirement functions correctly validate backend-specific constraints:
- CUTLASS requires scale_major_mode to be specified
- TRTLLM requires specific scale granularity and minimum inner dimension
The unused parameters are necessary to match the function signature expected by the backend_requirement decorator.
2453-2478: LGTM! The common check function correctly validates tensor shapes and output dtype. The unused parameters are required to match the function signature for the backend_requirement decorator.
2481-2627: LGTM! The function correctly implements FP8 NT groupwise GEMM with proper backend routing between CUTLASS and TRTLLM. The backend_requirement decorator ensures all validations are performed before execution.
2753-2783: LGTM! The blockscaled variant correctly wraps gemm_fp8_nt_groupwise with a fixed scale granularity of (128, 128, 128). The implementation is clean and maintains API consistency.
2787-2977: LGTM! The group GEMM FP8 NT groupwise implementation is well-structured with proper validation and backend routing. The function correctly handles output tensor allocation and delegates to the appropriate SM-specific implementation.
Note: There's minor code duplication in the out_dtype determination logic between the check function (lines 2820-2825) and the main function (lines 2931-2936), but this is acceptable for clarity in the "Chill" review mode.
2980-3176: LGTM! The MXFP8/MXFP4 groupwise GEMM implementation is comprehensive with thorough validation of tensor shapes, dtypes, and tile parameters. The function properly handles the packed FP4 format and delegates to the SM100 backend.
Note: Similar to the FP8 variant, there's minor code duplication in out_dtype determination (lines 3019-3024 vs. 3145-3150), but this is acceptable.
3212-3359: LGTM! The grouped DeepGEMM function correctly delegates validation to the deep_gemm module and provides a clean API for grouped matrix multiplication with FP8 data types. The documentation is comprehensive and includes helpful examples.
3362-3513: LGTM! The batch DeepGEMM function properly handles batch-wise FP8 matrix multiplication with masking. The validation delegation pattern is consistent with the grouped variant, and the documentation clearly explains the masking behavior.
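The blockscaled wrapper described in the comments above (a groupwise call with the scale granularity pinned to (128, 128, 128)) can be sketched roughly as follows. This is an illustration only: `check_groupwise` and `check_blockscaled` are hypothetical stand-ins that operate on plain shape tuples, and the shape rule shown is an assumption about an NT layout, not the PR's actual checker.

```python
from typing import Tuple

# Fixed granularity named in the review comment for the blockscaled variant.
BLOCK_GRANULARITY: Tuple[int, int, int] = (128, 128, 128)


def check_groupwise(
    a_shape: Tuple[int, int],
    b_shape: Tuple[int, int],
    scale_granularity_mnk: Tuple[int, int, int],
) -> bool:
    """Hypothetical stand-in for a groupwise problem-size checker."""
    # Assumed NT layout: a is (m, k) and b is (n, k), so inner dimensions must match.
    if a_shape[1] != b_shape[1]:
        raise ValueError(f"inner dimension mismatch: {a_shape[1]} != {b_shape[1]}")
    if len(scale_granularity_mnk) != 3:
        raise ValueError("scale_granularity_mnk must have three entries")
    return True


def check_blockscaled(a_shape: Tuple[int, int], b_shape: Tuple[int, int]) -> bool:
    """Delegate to the groupwise checker with the granularity pinned."""
    return check_groupwise(a_shape, b_shape, BLOCK_GRANULARITY)
```

Delegating this way keeps a single copy of the shape rules, so the blockscaled path cannot drift from the groupwise one.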
[CANCELING] Pipeline #40251567: canceled
/bot run
Actionable comments posted: 1
♻️ Duplicate comments (2)
flashinfer/gemm/gemm_base.py (2)
2822-2835: Refactor duplicated out_dtype resolution logic.
The logic for determining out_dtype (lines 2822-2835) is duplicated in the main group_gemm_fp8_nt_groupwise function (lines 2933-2938). This was also flagged in a previous review.
Consider extracting this into a helper function to maintain a single source of truth.
Apply this refactor:
+def _resolve_out_dtype(out: Optional[torch.Tensor], out_dtype: Optional[torch.dtype], default_dtype: torch.dtype = torch.bfloat16) -> torch.dtype:
+    """Resolve output dtype from out tensor or explicit dtype."""
+    if out is None:
+        return out_dtype if out_dtype is not None else default_dtype
+    else:
+        return out_dtype if out_dtype is not None else out.dtype
+
 @supported_compute_capability([100, 103, 120, 121])
 def _check_group_gemm_fp8_nt_groupwise_problem_size(
     ...
 ):
     ...
-    if out is None:
-        if out_dtype is None:
-            out_dtype = torch.bfloat16
-    else:
-        if out_dtype is None:
-            out_dtype = out.dtype
+    out_dtype = _resolve_out_dtype(out, out_dtype, torch.bfloat16)
3021-3026: Refactor duplicated out_dtype resolution logic.
Similar to the FP8 groupwise checker, this function duplicates the out_dtype resolution logic found in the main group_gemm_mxfp8_mxfp4_nt_groupwise function (lines 3147-3152). This was flagged in a previous review.
Consider using the same helper function suggested in the earlier comment to eliminate duplication across all checkers.
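The precedence the suggested helper encodes (explicit out_dtype first, then the dtype of a provided out tensor, then a default) can be exercised without PyTorch. The sketch below is an assumption-laden stand-in: `resolve_out_dtype` is a hypothetical name, dtypes are represented as plain strings, and `SimpleNamespace` plays the role of a tensor with a `dtype` attribute.

```python
from types import SimpleNamespace
from typing import Any, Optional


def resolve_out_dtype(out: Optional[Any], out_dtype: Optional[Any], default_dtype: Any) -> Any:
    """Pick the output dtype: explicit out_dtype wins, then out.dtype, then the default."""
    if out_dtype is not None:
        return out_dtype
    if out is not None:
        return out.dtype
    return default_dtype


# Stand-in for a preallocated output tensor; only the .dtype attribute is used.
fake_out = SimpleNamespace(dtype="float16")

no_out_no_dtype = resolve_out_dtype(None, None, "bfloat16")      # falls back to the default
out_only = resolve_out_dtype(fake_out, None, "bfloat16")         # inherits out.dtype
explicit_dtype = resolve_out_dtype(fake_out, "float32", "bfloat16")  # explicit value wins
```

Calling one helper from both the checker and the main function means the two code paths cannot disagree about the resolved dtype, which is the concern the review raises.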
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
flashinfer/gemm/gemm_base.py(7 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
flashinfer/gemm/gemm_base.py
🧬 Code graph analysis (1)
flashinfer/gemm/gemm_base.py (4)
flashinfer/utils.py (2)
supported_compute_capability(819-899)
backend_requirement(902-1184)
csrc/tvm_ffi_utils.h (1)
Tensor(304-306)
include/flashinfer/attention/blackwell/common/pow_2.hpp (1)
b(52-52)
flashinfer/deep_gemm.py (3)
_check_group_deepgemm_fp8_nt_contiguous_problem_size(1367-1413)
m_grouped_fp8_gemm_nt_contiguous(1420-1464)
_check_m_grouped_fp8_gemm_nt_masked_problem_size(1468-1526)
🪛 Ruff (0.14.8)
flashinfer/gemm/gemm_base.py
2415-2415: Unused function argument: a
(ARG001)
2416-2416: Unused function argument: b
(ARG001)
2417-2417: Unused function argument: a_scale
(ARG001)
2418-2418: Unused function argument: b_scale
(ARG001)
2420-2420: Unused function argument: mma_sm
(ARG001)
2421-2421: Unused function argument: scale_granularity_mnk
(ARG001)
2422-2422: Unused function argument: out
(ARG001)
2423-2423: Unused function argument: out_dtype
(ARG001)
2424-2424: Unused function argument: backend
(ARG001)
2427-2427: Avoid specifying long messages outside the exception class
(TRY003)
2435-2435: Unused function argument: b
(ARG001)
2436-2436: Unused function argument: a_scale
(ARG001)
2437-2437: Unused function argument: b_scale
(ARG001)
2438-2438: Unused function argument: scale_major_mode
(ARG001)
2439-2439: Unused function argument: mma_sm
(ARG001)
2441-2441: Unused function argument: out
(ARG001)
2442-2442: Unused function argument: out_dtype
(ARG001)
2443-2443: Unused function argument: backend
(ARG001)
2446-2446: Avoid specifying long messages outside the exception class
(TRY003)
2448-2448: Avoid specifying long messages outside the exception class
(TRY003)
2456-2456: Unused function argument: a_scale
(ARG001)
2457-2457: Unused function argument: b_scale
(ARG001)
2458-2458: Unused function argument: scale_major_mode
(ARG001)
2459-2459: Unused function argument: mma_sm
(ARG001)
2460-2460: Unused function argument: scale_granularity_mnk
(ARG001)
2463-2463: Unused function argument: backend
(ARG001)
2466-2466: Avoid specifying long messages outside the exception class
(TRY003)
2469-2471: Avoid specifying long messages outside the exception class
(TRY003)
2795-2795: Unused function argument: scale_granularity_mnk
(ARG001)
2802-2802: Avoid specifying long messages outside the exception class
(TRY003)
2804-2804: Avoid specifying long messages outside the exception class
(TRY003)
2806-2806: Avoid specifying long messages outside the exception class
(TRY003)
2808-2808: Avoid specifying long messages outside the exception class
(TRY003)
2810-2810: Avoid specifying long messages outside the exception class
(TRY003)
2812-2814: Avoid specifying long messages outside the exception class
(TRY003)
2816-2816: Avoid specifying long messages outside the exception class
(TRY003)
2829-2831: Avoid specifying long messages outside the exception class
(TRY003)
2833-2835: Avoid specifying long messages outside the exception class
(TRY003)
2840-2840: Avoid specifying long messages outside the exception class
(TRY003)
2842-2842: Avoid specifying long messages outside the exception class
(TRY003)
2844-2844: Avoid specifying long messages outside the exception class
(TRY003)
2850-2852: Avoid specifying long messages outside the exception class
(TRY003)
2998-3000: Avoid specifying long messages outside the exception class
(TRY003)
3002-3002: Avoid specifying long messages outside the exception class
(TRY003)
3004-3004: Avoid specifying long messages outside the exception class
(TRY003)
3006-3006: Avoid specifying long messages outside the exception class
(TRY003)
3008-3008: Avoid specifying long messages outside the exception class
(TRY003)
3010-3010: Avoid specifying long messages outside the exception class
(TRY003)
3012-3012: Avoid specifying long messages outside the exception class
(TRY003)
3014-3014: Avoid specifying long messages outside the exception class
(TRY003)
3016-3016: Avoid specifying long messages outside the exception class
(TRY003)
3018-3018: Avoid specifying long messages outside the exception class
(TRY003)
3029-3031: Avoid specifying long messages outside the exception class
(TRY003)
3035-3037: Avoid specifying long messages outside the exception class
(TRY003)
3044-3046: Avoid specifying long messages outside the exception class
(TRY003)
3051-3051: Avoid specifying long messages outside the exception class
(TRY003)
3053-3053: Avoid specifying long messages outside the exception class
(TRY003)
3058-3058: Avoid specifying long messages outside the exception class
(TRY003)
3060-3060: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (11)
flashinfer/gemm/gemm_base.py (11)
2413-2429: LGTM! The CUTLASS backend requirement check is minimal and focused, validating only that scale_major_mode is provided. The unused parameter warnings from static analysis are expected for requirement checker functions.
2432-2450: LGTM! The TRTLLM backend requirement check appropriately enforces stricter constraints: fixed scale granularity and minimum K dimension of 256.
2453-2478: LGTM! The common problem size check validates essential shape constraints and output dtype requirements for FP8 GEMM operations.
2715-2752: LGTM! The blockscaled problem size check correctly delegates to the groupwise checker with the fixed scale granularity of (128, 128, 128). The missing scale_granularity_mnk argument issue from past reviews has been resolved.
2848-2852: Good safeguard for known SM120/121 limitation. The explicit check prevents using group_gemm_fp8_nt_groupwise with multiple groups on SM120/121 where correctness issues exist.
2857-2860: LGTM! The @backend_requirement decorator with empty backend_checks and only a common_check is the correct pattern for validation-only usage without backend selection.
3065-3068: LGTM! Consistent use of the validation-only pattern with @backend_requirement.
3214-3235: LGTM! The checker appropriately delegates to the DeepGEMM module's validation function while adapting the API format.
3238-3241: LGTM! Consistent validation-only pattern for DeepGEMM operations.
3368-3390: LGTM! The batch DeepGEMM checker follows the same delegation pattern, appropriately adapting the API to the DeepGEMM module's validation function.
3393-3396: LGTM! Final consistent use of the validation-only pattern with @backend_requirement.
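The validation-only pattern the comments above refer to (an empty backend_checks mapping plus a single common_check) can be illustrated with a toy decorator. This is not FlashInfer's `backend_requirement` implementation; `backend_requirement_sketch`, `_check_positive`, and `area` are hypothetical names, and the real decorator may behave differently.

```python
import functools
from typing import Callable, Dict, Optional


def backend_requirement_sketch(
    backend_checks: Dict[str, Callable[..., bool]],
    common_check: Optional[Callable[..., bool]] = None,
):
    """Hypothetical stand-in: run validators before the wrapped function executes."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # The checker receives the same arguments as the wrapped function,
            # which is why checker signatures mirror it even for unused parameters.
            if common_check is not None:
                common_check(*args, **kwargs)  # expected to raise on invalid input
            backend = kwargs.get("backend")
            if backend in backend_checks:
                backend_checks[backend](*args, **kwargs)
            return func(*args, **kwargs)

        return wrapper

    return decorator


def _check_positive(m: int, n: int) -> bool:
    if m <= 0 or n <= 0:
        raise ValueError("m and n must be positive")
    return True


# Validation-only usage: no per-backend checks, just the common check.
@backend_requirement_sketch({}, common_check=_check_positive)
def area(m: int, n: int) -> int:
    return m * n
```

In this sketch the checker is called with exactly the wrapped function's arguments, which also suggests why Ruff reports ARG001 for checker parameters that exist only to keep the signatures aligned.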
out_dtype = out_dtype or torch.bfloat16
out = torch.empty(a.shape[0], b.shape[1], dtype=out_dtype, device=a.device)

print("GOT HERE")
Remove debug print statement.
This debug print statement should not be committed to the codebase as it will pollute production output.
Apply this diff:
-    print("GOT HERE")
     m_grouped_fp8_gemm_nt_contiguous(
         (a, a_scale), (b, b_scale), out, m_indices, scale_granularity_mnk
     )
🤖 Prompt for AI Agents
In flashinfer/gemm/gemm_base.py around line 3360 there is a leftover debug print
statement ("GOT HERE") that should be removed; delete that print call so no
debug output is emitted from the production code and ensure no other stray debug
prints remain nearby.
📌 Description
Continuation of !2000
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
pre-commit by running pip install pre-commit (or used your preferred method).
pre-commit install.
pre-commit run --all-files and fixed any reported issues.
🧪 Tests
unittest, etc.).
Reviewer Notes
Summary by CodeRabbit
New Features
Improvements