
Support for MXFP4 and NVFP4 group GEMMs on GeForce and Spark #2738

Merged
aleozlx merged 13 commits into flashinfer-ai:main from depaulmillz:geforce_and_spark
Mar 28, 2026
Conversation

@depaulmillz
Contributor

@depaulmillz depaulmillz commented Mar 10, 2026

📌 Description

This PR adds functional support for CUTLASS MXFP4 and NVFP4 group GEMMs on Blackwell GeForce and DGX Spark. MXFP4 group GEMMs are implemented to match the existing interface, and a new interface is added for NVFP4. The NVFP4 interface aims to match the baseline GEMM interface, including support for alpha scaling.
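For reviewers skimming the new interface, here is a minimal shape sketch (not code from this PR) of the operands implied by the review discussion: A is packed uint8 of shape (cum_m, k // 2) because FP4 values pack two per byte, and alpha is a per-group float32 vector. The (num_groups, n, k // 2) layout assumed for B is my inference from the NT grouped-GEMM convention, not stated in the PR.

```python
# Sketch (not from the PR): expected operand shapes for the new NVFP4
# grouped-GEMM interface, inferred from the review discussion.
# FP4 values are packed two per byte, so the K extent halves in uint8 storage.

def nvfp4_operand_shapes(m_indptr: list[int], n: int, k: int):
    """Return expected (a, b, alpha) shapes for a grouped NVFP4 NT GEMM.

    m_indptr: prefix sums of per-group M sizes, length num_groups + 1.
    """
    assert k % 2 == 0, "k must be even so FP4 values pack two per byte"
    num_groups = len(m_indptr) - 1
    cum_m = m_indptr[-1]                # total rows across all groups
    a_shape = (cum_m, k // 2)           # packed uint8 A operand
    b_shape = (num_groups, n, k // 2)   # assumed packed uint8 B, one matrix per group
    alpha_shape = (num_groups,)         # float32 per-group output scale
    return a_shape, b_shape, alpha_shape
```

For example, two groups with M sizes 4 and 10 - 4 = 6 and k = 64 give a packed A of shape (10, 32).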

The PR also unguards GDC (Grid-Dependent Control) on CUTLASS kernels for functional correctness.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features
    • Added NVFP4 and MXFP4 group-wise scaled GEMM support for Blackwell (SM12x) with runtime capability routing, new public NVFP4 group-GEMM API, and FFI exports for SM12x kernels.
  • Benchmarks
    • Added NVFP4 groupwise benchmark that searches tile configs and reports best TFLOPs; MXFP4 benchmark now adapts its search space based on runtime SM12x support.
  • Tests
    • Added FP4 groupwise tests and expanded MXFP4 test gating/parameterization to include SM12x.

@coderabbitai
Contributor

coderabbitai bot commented Mar 10, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough


Adds SM12x (SM120/121) group-wise GEMM support (NVFP4 and MXFP4): new CUTLASS kernels, headers, Jinja instantiations, C++ FFI bindings, Python API + validation/dispatch, tests, and benchmarks; updates JIT build flags and runtime capability–dependent tile/dtype selection.

Changes

Cohort / File(s) Summary
Benchmarks
benchmarks/bench_groupwise_grouped_gemm_mxfp4_blackwell.py, benchmarks/bench_groupwise_grouped_gemm_nvfp4_blackwell_geforce.py
MXFP4 benchmark now queries SM12x support to narrow tile/MMA grids; added NVFP4 benchmark script and selection/TFLOP reporting.
Python GEMM API & Exports
flashinfer/gemm/gemm_base.py, flashinfer/gemm/__init__.py
Added group_gemm_nvfp4_nt_groupwise API, problem-size validators, SM12x-aware tiling/MMA routing, and exported NVFP4 symbol; MXFP4 aliasing adjusted for compatibility.
C++ Bindings / FFI
csrc/group_gemm_sm120_binding.cu
Added SM120 entry declarations and TVM FFI exports for NVFP4 and MXFP4 groupwise entry points.
NVFP4 Kernel Implementation
csrc/group_gemm_nvfp4_groupwise_sm120.cu, include/flashinfer/gemm/group_gemm_nvfp4_groupwise_sm120.cuh, csrc/group_gemm_nvfp4_groupwise_sm120_kernel_inst.jinja
New NVFP4 SM120 CUTLASS path: dtype dispatch, tile dispatch, per-group arg builders, alpha handling, workspace usage, and Jinja instantiations.
MXFP4 Kernel Implementation
csrc/group_gemm_mxfp4_groupwise_sm120.cu, include/flashinfer/gemm/group_gemm_mxfp4_groupwise_sm120.cuh, csrc/group_gemm_mxfp4_groupwise_sm120_kernel_inst.jinja
New MXFP4 SM120 CUTLASS path: per-group argument kernel, dtype dispatch, instantiation macro and wrapper entrypoints.
Headers / Templates / Macros
include/flashinfer/gemm/*.cuh, include/flashinfer/gemm/fp4_gemm_template_sm120.h, include/flashinfer/gemm/group_gemm_fp8_groupwise_sm120.cuh
Added per-group arg preparation kernels, instantiation macros, and adjusted CUTLASS launch flag (PDL disabled) in SM120 paths.
JIT / Build Flags
flashinfer/jit/gemm/core.py
Enabled GDC/PDL-related nvcc flags in generators and added SM120 groupwise kernel source generation to build lists.
DLPack / TVM utils
csrc/tvm_ffi_utils.h
Added guarded FP8 E4M3 scaling-factor dispatch macros (new SF dispatch path gated by flags/CUDA version).
Tests
tests/gemm/test_group_gemm_fp4.py, tests/gemm/test_groupwise_scaled_gemm_mxfp4.py
Added NVFP4 FP4 groupwise tests and adjusted MXFP4 tests to gate SM12x and restrict parameter grids when SM12x is present.
Kernel Instantiations (Jinja)
csrc/*_kernel_inst.jinja
Added Jinja templates instantiating SM120 NVFP4/MXFP4 groupwise kernel variants.

Sequence Diagram(s)

sequenceDiagram
    participant Py as Python Caller
    participant API as gemm_base.py
    participant FFI as group_gemm_sm120_binding.cu
    participant Dispatch as DLPack & tile dispatch
    participant CUTLASS as CUTLASS templated kernel
    participant GPU as CUDA SM120

    Py->>API: group_gemm_nvfp4_nt_groupwise(a,b,a_scale,b_scale,...)
    API->>API: validate shapes/dtypes, check is_sm12x_supported
    alt SM12x supported
        API->>FFI: call CutlassGroupGemm...SM120(...)
    else
        API->>FFI: call fallback SM100 entry
    end
    FFI->>Dispatch: dispatch by DLPack dtypes and tile sizes
    Dispatch->>Dispatch: is_valid_config checks (dtype/tile)
    alt valid config
        Dispatch->>CUTLASS: invoke templated Cutlass...GroupGEMM<...>
        CUTLASS->>GPU: launch grouped GEMM kernels
        GPU-->>CUTLASS: complete
    else invalid
        Dispatch-->>FFI: return error
    end
    FFI-->>API: return result tensor
    API-->>Py: deliver output
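The routing step in the diagram, combined with the review request below to fail loudly on unsupported devices rather than return an untouched output buffer, can be sketched as follows (illustrative Python only; the function and module names are assumptions, not the actual flashinfer source):

```python
# Sketch: capability-based routing for the grouped MXFP4 GEMM entry point,
# raising on unsupported devices instead of silently returning.

def route_group_gemm(device: str, is_sm12x: bool, is_sm100a: bool) -> str:
    """Pick the kernel module for a device; raise if neither path applies."""
    if is_sm12x:
        return "sm120"   # Blackwell GeForce / DGX Spark path
    if is_sm100a:
        return "sm100"   # SM100a fallback path
    raise RuntimeError(
        f"group_gemm_mxfp4_nt_groupwise was not launched: device {device} "
        "supports neither SM12x nor SM100a kernels"
    )
```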

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

Suggested labels

ready, op: gemm

Suggested reviewers

  • aleozlx
  • yongwww
  • yzh119
  • cyx-6
  • nvmbreughe
  • jiahanc
  • jimmyzho
  • bkryu

Poem

🐰 Hop hop, new kernels spring and play,
SM120 wakes to speed the day.
FP4 sparks and MX rows gleam,
Grouped GEMM hops into the stream.
✨🥕

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 23.81%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.
✅ Passed checks (2 passed)
  • Title check: ✅ Passed. The title clearly summarizes the main change, adding MXFP4 and NVFP4 group GEMM support on GeForce and Spark platforms, which aligns with the primary objective.
  • Description check: ✅ Passed. The description addresses the template requirements by explaining what the PR does (adds functional support for CUTLASS MXFP4 and NVFP4 group GEMMs) and confirms pre-commit checks and tests were completed, though it omits explicit test result details and specific test names.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers get up to speed quickly.

This pull request significantly expands FlashInfer's capabilities by introducing support for MXFP4 and NVFP4 group GEMMs on NVIDIA's latest Blackwell GeForce and DGX Spark architectures. It provides new, optimized kernels for these floating-point formats, ensuring efficient computation on modern hardware. Additionally, the change addresses functional correctness by adjusting GDC settings within CUTLASS kernels, enhancing the robustness of the library.

Highlights

  • MXFP4 Group GEMM Support for SM12x: Implemented functional support for CUTLASS MXFP4 group GEMMs on NVIDIA Blackwell GeForce and DGX Spark architectures (SM12x), matching the existing interface.
  • NVFP4 Group GEMM Introduction for SM12x: Added a new interface and functional support for NVFP4 group GEMMs on Blackwell GeForce and DGX Spark, including support for alpha scaling to align with baseline GEMM interfaces.
  • CUTLASS GDC Unguarding: Unguarded Grid-Dependent Control (GDC) on CUTLASS kernels for improved functional correctness across various architectures (SM90, SM100, SM120).
  • New Benchmarks and Tests: Introduced new benchmark scripts for NVFP4 group GEMMs and updated existing MXFP4 benchmarks to reflect SM12x support. Comprehensive unit tests were added for the new NVFP4 functionality.


Changelog
  • benchmarks/bench_groupwise_grouped_gemm_mxfp4_blackwell.py
    • Updated copyright year to 2026.
    • Modified benchmark configuration to conditionally set mma_sm_list, tile_m_list, tile_n_list, tile_k_list, and swap_ab_list based on SM12x support.
  • benchmarks/bench_groupwise_grouped_gemm_nvfp4_blackwell_geforce.py
    • Added new benchmark script for group_gemm_nvfp4_nt_groupwise on Blackwell architectures, testing various group sizes, M, N, and K dimensions.
  • csrc/group_gemm_mxfp4_groupwise_sm120.cu
    • Added new CUDA source file implementing MXFP4 group GEMM for SM120 architecture, including dispatch macros for tile sizes and data types.
  • csrc/group_gemm_mxfp4_groupwise_sm120_kernel_inst.jinja
    • Added Jinja template for instantiating MXFP4 group GEMM kernels for SM120.
  • csrc/group_gemm_nvfp4_groupwise_sm120.cu
    • Added new CUDA source file implementing NVFP4 group GEMM for SM120 architecture, supporting alpha scaling and dispatch macros for tile sizes and data types.
  • csrc/group_gemm_nvfp4_groupwise_sm120_kernel_inst.jinja
    • Added Jinja template for instantiating NVFP4 group GEMM kernels for SM120.
  • csrc/group_gemm_sm120_binding.cu
    • Added FFI declarations and exports for CutlassGroupGemmNVFP4GroupwiseScaledSM120 and CutlassGroupGemmMXFP4GroupwiseScaledSM120.
  • flashinfer/gemm/__init__.py
    • Imported group_gemm_nvfp4_nt_groupwise from gemm_base.
    • Added group_gemm_nvfp4_nt_groupwise to the __all__ export list.
  • flashinfer/gemm/gemm_base.py
    • Updated docstring for group_gemm_mxfp8_mxfp4_nt_groupwise to include Blackwell GeForce and DGX Spark support.
    • Modified group_gemm_mxfp8_mxfp4_nt_groupwise to conditionally dispatch to SM120 module if is_sm12x_supported is true.
    • Added new Python API group_gemm_nvfp4_nt_groupwise with problem size checks and calls to the SM120 module, supporting alpha scaling.
  • flashinfer/jit/gemm/core.py
    • Added CUTLASS_ENABLE_GDC_FOR_SM100=1 and CUTLASS_ENABLE_GDC_FOR_SM90=1 flags to various gen_gemm_smXXX_module functions for CUTLASS kernels.
    • Added logic to generate kernel instantiations for MXFP4 and NVFP4 group GEMMs for SM120 architecture.
  • include/flashinfer/gemm/group_gemm_mxfp4_groupwise_sm120.cuh
    • Added new CUDA header file defining the MXFP4 group GEMM kernel for SM120, including argument computation and instantiation macros.
  • include/flashinfer/gemm/group_gemm_nvfp4_groupwise_sm120.cuh
    • Added new CUDA header file defining the NVFP4 group GEMM kernel for SM120, including argument computation, alpha scaling, and instantiation macros.
  • tests/gemm/test_group_gemm_fp4.py
    • Added new unit tests for group_gemm_nvfp4_nt_groupwise, including reference implementation and quantization helper.
  • tests/gemm/test_groupwise_scaled_gemm_mxfp4.py
    • Updated test_mxfp8_mxfp4_groupwise_group_gemm to include SM12x compute capability and adjusted test parameters accordingly.
Activity
  • The pull request introduces new CUDA kernels and Python APIs for MXFP4 and NVFP4 group GEMMs on SM12x architectures.
  • New benchmark and test files have been added to validate the correctness and performance of the new functionalities.
  • Existing benchmark and test files have been updated to incorporate the new SM12x support for MXFP4 GEMMs.
  • Build configurations for JIT compilation have been modified to enable GDC for CUTLASS kernels across relevant SM versions.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for MXFP4 and NVFP4 group GEMMs on new NVIDIA architectures (Blackwell GeForce, DGX Spark). However, a security audit identified two high-severity integer overflow vulnerabilities in the CUDA kernels responsible for computing group-wise scaling arguments. These overflows occur during the calculation of scale factor offsets, leading to out-of-bounds memory accesses on the GPU. It is crucial to address these by using 64-bit integers for these calculations to ensure memory safety. Additionally, there are inconsistencies in the Python API docstrings, an opportunity to reduce code duplication in the CUDA headers, and a critical bug in a new test file that needs to be fixed.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 8

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@csrc/group_gemm_mxfp4_groupwise_sm120.cu`:
- Around line 42-50: The DISPATCH_TILE_K macro currently only handles tile_k ==
128 causing a failure for tile_k == 256; add a branch for tile_k == 256 that
defines constexpr int TILE_K = 256 and invokes the same lambda path (i.e.,
mirror the 128 case), and also add the corresponding kernel instantiation for
TILE_K=256 in the kernel template
csrc/group_gemm_mxfp4_groupwise_sm120_kernel_inst.jinja so the launcher and
kernel template both support tile_k=256 (refer to DISPATCH_TILE_K and the kernel
instantiation entries in group_gemm_mxfp4_groupwise_sm120_kernel_inst.jinja).

In `@csrc/group_gemm_nvfp4_groupwise_sm120.cu`:
- Around line 56-72: The fallback macro name is misspelled: when
FLASHINFER_ENABLE_FP8_E4M3 is off you define _DISPATCH_SF_CASE_FP8_E4M3 but the
switch uses _DISPATCH_SF_CASE_FP8_UE4M3, causing an undefined token; fix by
renaming the fallback definition to _DISPATCH_SF_CASE_FP8_UE4M3 with the same
parameter list (c_type, ...) and body (empty) so the
DISPATCH_DLPACK_DTYPE_TO_CTYPE_SF_UE4M3 switch (which calls encode_dlpack_dtype)
compiles correctly in the no-FP8-E4M3 build.

In `@flashinfer/gemm/gemm_base.py`:
- Around line 5122-5194: In _check_group_gemm_nvfp4_nt_groupwise_problem_size
validate alpha before FFI use: if alpha is not None and alpha.numel() > 0,
ensure alpha.dtype == torch.float32, alpha.is_contiguous() is True, alpha.device
is cpu (or explicitly state expected device if FFI requires GPU), and
alpha.numel() is either 1 or equals num_groups (computed from
m_indptr.shape[0]-1); raise descriptive ValueError mentioning alpha, num_groups,
and the expected dtype/shape/device when any check fails so the SM120 launcher
won't receive an invalid float* pointer.
- Around line 5077-5113: The current code can return an untouched output buffer
when neither is_sm12x_supported(a.device) nor is_sm100a_supported(a.device)
matches; update the control flow after those conditionals to handle the
unsupported case by raising a clear exception (e.g., RuntimeError) instead of
returning out silently. Specifically, in the enclosing function in gemm_base.py,
after the two if/elif blocks for is_sm12x_supported and is_sm100a_supported, add
an else branch that raises an error describing the device and that
group_gemm_mxfp4_nt_groupwise was not launched (reference is_sm12x_supported,
is_sm100a_supported, get_gemm_sm120_module().group_gemm_mxfp4_nt_groupwise and
get_gemm_sm100_module().group_gemm_mxfp4_nt_groupwise to locate the logic).
Ensure the message includes the device identifier (a.device) and any key
parameters (n, k, tile sizes) to aid debugging.
- Around line 5202-5247: Update the docstring for group_gemm_nvfp4_nt_groupwise
to match the actual implementation: change the parameter "a" to indicate it is
packed uint8 (torch.uint8) with shape (cum_m, k // 2) instead of float8
(torch.float8_...), remove any mention of float8 for "a"; update the "tile_n"
description to state only tile_n=128 is supported (remove 64,192,256 options);
keep/confirm other shape/type descriptions (b: torch.uint8, a_scale, b_scale)
but adjust any dependent shape text if it assumed unpacked k; ensure the types
for "a" and "tile_n" in the parameter list match the implementation in
group_gemm_nvfp4_nt_groupwise.

In `@tests/gemm/test_group_gemm_fp4.py`:
- Line 25: The test currently imports get_compute_capability but must be gated
with the runtime check is_sm12x_supported(); import is_sm12x_supported from
flashinfer.utils and skip the test when it returns False (e.g., with pytest.skip
or a pytest.mark.skipif using is_sm12x_supported()), ensuring the test module or
specific test functions for group_gemm_fp4 are not run on SM12x machines lacking
the required runtime/toolchain; update any other similar checks in this file
(references to get_compute_capability) to use is_sm12x_supported() as well.
- Around line 45-49: The test fails because _quantize_nvfp4_group_inputs
declares a third parameter m_indptr that is unused while callers (e.g., the call
on Line 105) pass only two args; remove the unused m_indptr parameter from the
function signature of _quantize_nvfp4_group_inputs (and update its type
annotation to accept only a_float: torch.Tensor and b_float: torch.Tensor),
delete any references to m_indptr inside the function body, and ensure the
returned tuple typing remains correct (tuple[torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor]) so existing callers continue to work.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: f6a05ef6-35ce-4ed2-8b92-c2cd91358b8b

📥 Commits

Reviewing files that changed from the base of the PR and between fe06b91 and 508682e.

📒 Files selected for processing (14)
  • benchmarks/bench_groupwise_grouped_gemm_mxfp4_blackwell.py
  • benchmarks/bench_groupwise_grouped_gemm_nvfp4_blackwell_geforce.py
  • csrc/group_gemm_mxfp4_groupwise_sm120.cu
  • csrc/group_gemm_mxfp4_groupwise_sm120_kernel_inst.jinja
  • csrc/group_gemm_nvfp4_groupwise_sm120.cu
  • csrc/group_gemm_nvfp4_groupwise_sm120_kernel_inst.jinja
  • csrc/group_gemm_sm120_binding.cu
  • flashinfer/gemm/__init__.py
  • flashinfer/gemm/gemm_base.py
  • flashinfer/jit/gemm/core.py
  • include/flashinfer/gemm/group_gemm_mxfp4_groupwise_sm120.cuh
  • include/flashinfer/gemm/group_gemm_nvfp4_groupwise_sm120.cuh
  • tests/gemm/test_group_gemm_fp4.py
  • tests/gemm/test_groupwise_scaled_gemm_mxfp4.py


@coderabbitai coderabbitai bot left a comment


♻️ Duplicate comments (1)
flashinfer/gemm/gemm_base.py (1)

5159-5188: ⚠️ Potential issue | 🟠 Major

Tighten alpha validation before forwarding its raw pointer.

This only checks dtype and shape[0]. A 0-D/2-D tensor, a non-contiguous view, or a tensor on the wrong device still reaches group_gemm_nvfp4_nt_groupwise() and is passed straight through to csrc/group_gemm_nvfp4_groupwise_sm120.cu, Lines 104-131, as float*, which will either misread the scales or hand the kernel a host pointer.

💡 Suggested guard
     if alpha is not None and alpha.dtype != torch.float32:
         raise ValueError(
             f"alpha must be a float32 tensor or None, but got {alpha.dtype}"
         )
+    if alpha is not None and alpha.device != a.device:
+        raise ValueError(f"alpha must be on {a.device}, but got {alpha.device}")
+    if alpha is not None and not alpha.is_contiguous():
+        raise ValueError("alpha must be contiguous")
+    if alpha is not None and alpha.ndim != 1:
+        raise ValueError(f"alpha must be 1D, but got shape {tuple(alpha.shape)}")
 ...
-    if alpha is not None and alpha.shape[0] != num_groups:
+    if alpha is not None and alpha.numel() not in (0, num_groups):
         raise ValueError(
-            f"alpha.shape[0] must equal num_groups, but got alpha.shape[0]={alpha.shape[0]}, num_groups={num_groups}"
+            f"alpha must be empty or have shape ({num_groups},), but got {tuple(alpha.shape)}"
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 5159 - 5188, The alpha validation
is incomplete: before forwarding alpha to group_gemm_nvfp4_nt_groupwise() (and
ultimately the CUDA kernel), ensure alpha is a 1-D float32 tensor located on the
correct device and contiguous (or explicitly make it so). Concretely, in the
block that already checks dtype and shape[0] (referencing alpha and num_groups),
add checks that alpha.dim() == 1 and alpha.is_contiguous() and alpha.device ==
m_indptr.device (or, if you prefer to accept non-contiguous/wrong-device
tensors, convert them: alpha =
alpha.to(m_indptr.device).contiguous().to(torch.float32)); raise a ValueError
with a clear message if dim or device are wrong, or perform the conversion so
the raw float* passed into group_gemm_nvfp4_nt_groupwise() is always a
contiguous device tensor.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 5159-5188: The alpha validation is incomplete: before forwarding
alpha to group_gemm_nvfp4_nt_groupwise() (and ultimately the CUDA kernel),
ensure alpha is a 1-D float32 tensor located on the correct device and
contiguous (or explicitly make it so). Concretely, in the block that already
checks dtype and shape[0] (referencing alpha and num_groups), add checks that
alpha.dim() == 1 and alpha.is_contiguous() and alpha.device == m_indptr.device
(or, if you prefer to accept non-contiguous/wrong-device tensors, convert them:
alpha = alpha.to(m_indptr.device).contiguous().to(torch.float32)); raise a
ValueError with a clear message if dim or device are wrong, or perform the
conversion so the raw float* passed into group_gemm_nvfp4_nt_groupwise() is
always a contiguous device tensor.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 6d38bade-b5ba-497a-b4fa-9132f0e7df6f

📥 Commits

Reviewing files that changed from the base of the PR and between 8f5fb16 and 9774f3a.

📒 Files selected for processing (4)
  • csrc/group_gemm_nvfp4_groupwise_sm120.cu
  • csrc/tvm_ffi_utils.h
  • flashinfer/gemm/gemm_base.py
  • tests/gemm/test_group_gemm_fp4.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/gemm/test_group_gemm_fp4.py


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@csrc/group_gemm_nvfp4_groupwise_sm120.cu`:
- Around line 101-102: The device guard and stream are sourced from different
tensors (device_guard uses float_workspace_buffer.device() while stream is
obtained via get_stream(A)), which can activate the wrong device; make them
consistent by using the same tensor as the source—e.g., initialize
ffi::CUDADeviceGuard with A.device().device_id (or alternatively get the stream
from float_workspace_buffer) so that device_guard and stream use the same
device.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 17dbe637-5a81-41cc-9ab6-75820d421630

📥 Commits

Reviewing files that changed from the base of the PR and between 9774f3a and f3ac8d3.

📒 Files selected for processing (2)
  • csrc/group_gemm_nvfp4_groupwise_sm120.cu
  • csrc/tvm_ffi_utils.h

@johnnynunez
Contributor

thank you @depaulmillz

dePaul Miller and others added 6 commits March 10, 2026 09:26

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

♻️ Duplicate comments (2)
csrc/group_gemm_nvfp4_groupwise_sm120.cu (1)

101-102: ⚠️ Potential issue | 🟡 Minor

Device guard and stream may use different tensor sources.

Line 101 creates a device guard from float_workspace_buffer.device(), but line 102 retrieves the stream from A.device(). If these tensors reside on different devices (e.g., during multi-GPU operations), this could cause incorrect execution context.

Consider using the same tensor for both:

Suggested fix
-  ffi::CUDADeviceGuard device_guard(float_workspace_buffer.device().device_id);
-  auto stream = get_stream(A.device());
+  ffi::CUDADeviceGuard device_guard(A.device().device_id);
+  auto stream = get_stream(A.device());
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/group_gemm_nvfp4_groupwise_sm120.cu` around lines 101 - 102, The code
uses ffi::CUDADeviceGuard constructed from float_workspace_buffer.device() but
calls get_stream(A.device()), which can mismatch devices; change to use the same
tensor/device for both operations (e.g., construct ffi::CUDADeviceGuard with
A.device() and call get_stream(A.device()), or vice versa) so the device guard
and stream source (float_workspace_buffer or A) are consistent; update the usage
of ffi::CUDADeviceGuard and get_stream to reference the same tensor (A or
float_workspace_buffer) throughout.
flashinfer/gemm/gemm_base.py (1)

5159-5217: ⚠️ Potential issue | 🟠 Major

Reject non-flat or cross-device buffers before handing them to FFI.

A CPU/non-contiguous/wrong-shaped alpha still passes here, and out.device is never checked. The SM120 NVFP4 path later treats both tensors as raw device buffers, so a bad user tensor becomes an invalid pointer or scrambled per-group scales instead of a clean Python error.

🛠️ Proposed fix
     if alpha is not None and alpha.dtype != torch.float32:
         raise ValueError(
             f"alpha must be a float32 tensor or None, but got {alpha.dtype}"
         )
+    if alpha is not None and alpha.device != a.device:
+        raise ValueError(f"alpha must be on {a.device}, but got {alpha.device}")
+    if alpha is not None and not alpha.is_contiguous():
+        raise ValueError("alpha must be contiguous")
@@
     num_groups = m_indptr.shape[0] - 1

-    if alpha is not None and alpha.shape[0] != num_groups:
+    if alpha is not None and alpha.shape != (num_groups,):
         raise ValueError(
-            f"alpha.shape[0] must equal num_groups, but got alpha.shape[0]={alpha.shape[0]}, num_groups={num_groups}"
+            f"alpha must have shape ({num_groups},), but got {tuple(alpha.shape)}"
         )
@@
     out_shape = (a.shape[0], n)
     if out is not None:
         if out.shape != out_shape:
             raise ValueError(f"out.shape must be {out_shape}, but got {out.shape}")
+        if out.device != a.device:
+            raise ValueError(f"out must be on {a.device}, but got {out.device}")
         if out.dtype != out_dtype:
             raise ValueError(f"out.dtype must be {out_dtype}, but got {out.dtype}")

Run this to confirm the current path still forwards alpha as a raw pointer without extra normalization:

#!/bin/bash
set -euo pipefail

echo "=== Python-side NVFP4 validation ==="
sed -n '5158,5217p' flashinfer/gemm/gemm_base.py

echo
echo "=== SM120 binding signature ==="
sed -n '26,38p' csrc/group_gemm_sm120_binding.cu

echo
echo "=== CUTLASS epilogue alpha pointer wiring ==="
sed -n '247,252p' include/flashinfer/gemm/group_gemm_nvfp4_groupwise_sm120.cuh
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 5159 - 5217, Reject non-flat or
cross-device buffers before FFI: validate that alpha (if not None) is a 1D,
contiguous torch.Tensor with dtype torch.float32, alpha.shape[0] == num_groups
and alpha.device matches the device used for computation (same device as b/a);
validate out (if provided) is on the same device as a/b, is contiguous, has
shape (a.shape[0], n) and dtype out_dtype; raise clear ValueErrors for
non-tensor, non-contiguous, wrong-dtype, wrong-dim, or cross-device cases so raw
pointers passed to the SM120 NVFP4 path are always flat, correctly-typed, and
device-local.
🧹 Nitpick comments (3)
csrc/group_gemm_mxfp4_groupwise_sm120_kernel_inst.jinja (1)

53-54: Remove extraneous semicolons after namespace closing braces.

The semicolons after the closing braces are unnecessary and unconventional in C++.

Suggested fix
-};  // namespace group_gemm
-};  // namespace flashinfer
+}  // namespace group_gemm
+}  // namespace flashinfer
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/group_gemm_mxfp4_groupwise_sm120_kernel_inst.jinja` around lines 53 -
54, Remove the extraneous semicolons following the closing namespace braces for
the namespaces group_gemm and flashinfer: locate the closing braces for
namespace group_gemm and namespace flashinfer and delete the trailing ';'
characters so the namespace endings read simply "}" without semicolons.
csrc/group_gemm_nvfp4_groupwise_sm120_kernel_inst.jinja (1)

53-54: Remove extraneous semicolons after namespace closing braces.

Same as the MXFP4 template - the semicolons after closing braces are unnecessary.

Suggested fix
-};  // namespace group_gemm
-};  // namespace flashinfer
+}  // namespace group_gemm
+}  // namespace flashinfer
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/group_gemm_nvfp4_groupwise_sm120_kernel_inst.jinja` around lines 53 -
54, the file ends its namespace blocks with extraneous semicolons; remove the
trailing semicolons after the closing braces for the namespaces 'group_gemm' and
'flashinfer' so the two lines "};  // namespace group_gemm" and "};  //
namespace flashinfer" become "}  // namespace group_gemm" and "}  // namespace
flashinfer", with the trailing comments preserved.
benchmarks/bench_groupwise_grouped_gemm_nvfp4_blackwell_geforce.py (1)

74-89: Lambda captures loop variables via late binding - safe here but fragile.

The lambda passed to bench_gpu_time closes over tile_m, tile_n, and tile_k, and Python resolves those names at call time rather than at definition time. This works correctly here because the lambda is executed immediately within the same loop iteration, but it is a pattern that can cause subtle bugs if the code is refactored.

Consider using default argument binding to capture by value:

Suggested fix
         measurements = bench_gpu_time(
-            lambda: flashinfer.gemm.group_gemm_nvfp4_nt_groupwise(
+            lambda tile_m=tile_m, tile_n=tile_n, tile_k=tile_k: flashinfer.gemm.group_gemm_nvfp4_nt_groupwise(
                 a,
                 b,
                 a_scale,
                 b_scale,
                 segment_offsets,
                 out=out,
                 tile_m=tile_m,
                 tile_n=tile_n,
                 tile_k=tile_k,
             ),
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_groupwise_grouped_gemm_nvfp4_blackwell_geforce.py` around
lines 74 - 89, The lambda passed into bench_gpu_time closes over loop variables
tile_m, tile_n, tile_k by reference which is fragile; update the call so the
lambda captures these values by value (e.g., use default-argument binding:
lambda tile_m=tile_m, tile_n=tile_n, tile_k=tile_k:
flashinfer.gemm.group_gemm_nvfp4_nt_groupwise(a, b, a_scale, b_scale,
segment_offsets, out=out, tile_m=tile_m, tile_n=tile_n, tile_k=tile_k)) or use
functools.partial to bind the parameters before passing to bench_gpu_time to
ensure stable behavior.
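The late-binding behavior this comment warns about is easy to reproduce in isolation; this self-contained snippet contrasts plain closures with default-argument binding:

```python
# Plain lambdas in a loop all resolve `tile` when called, so they
# observe the final loop value; default-argument binding freezes the
# value at definition time instead.
late = [lambda: tile for tile in (64, 128, 256)]
bound = [lambda tile=tile: tile for tile in (64, 128, 256)]

print([f() for f in late])   # → [256, 256, 256]
print([f() for f in bound])  # → [64, 128, 256]
```

The benchmark only avoids the pitfall because the lambda is invoked before the loop advances; default-argument binding (or `functools.partial`) makes the code robust to reordering.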
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 4928-4948: The public docs for
group_gemm_mxfp8_mxfp4_nt_groupwise() are out of sync with runtime checks
(is_sm12x_supported) which now restrict mma_sm, tile_n, and tile_k ranges;
update the documentation for group_gemm_mxfp8_mxfp4_nt_groupwise to list the
exact allowed values used in the code (for SM12x: mma_sm == 1, tile_m == 128,
tile_n == 128, tile_k == 128; otherwise: mma_sm in {1,2}, tile_m == 128, tile_n
in {64,128,192,256}, tile_k in {128,256}) so users aren’t misled and will avoid
ValueError at runtime.

In `@include/flashinfer/gemm/group_gemm_mxfp4_groupwise_sm120.cuh`:
- Around line 221-223: The KernelHardwareInfo instance sets hw_info.device_id =
0 which breaks multi-GPU setups; update the code that constructs
cutlass::KernelHardwareInfo (hw_info) to use the actual CUDA device instead of
hardcoding 0—either query the current device with cudaGetDevice() (or equivalent
helper) and assign that to hw_info.device_id, or change the calling signature to
accept and forward a device_id parameter so sm_count lookup uses the correct
device; ensure hw_info.sm_count still uses sm_count but comes from the chosen
device context.

In `@include/flashinfer/gemm/group_gemm_nvfp4_groupwise_sm120.cuh`:
- Around line 175-203: The code fails when num_groups == 0 because num_threads =
std::min(num_groups, 1024) becomes 0 and num_blocks divides by zero; add an
early-return guard that checks num_groups == 0 before computing
num_threads/num_blocks (and ideally before heavy allocations) to short-circuit
the grouped GEMM path. Locate the allocation/launch setup in
group_gemm_nvfp4_groupwise_sm120 (or the surrounding helper that uses
AlignedAllocator and calls allocator.aligned_alloc) and insert a simple check
for num_groups == 0 that returns success/does nothing so num_threads and
num_blocks are never computed or used. Ensure the guard references num_groups,
num_threads, and num_blocks to prevent the divide-by-zero.
- Around line 221-225: Remove the thread_local caching of sm_count and the
hardcoded hw_info.device_id = 0; instead, at launch time obtain the current
device and fresh SM count: call cudaGetDevice(&hw_info.device_id) and then set
hw_info.sm_count =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id).
Update the code around the sm_count variable and hw_info initialization
(symbols: sm_count, hw_info,
cutlass::KernelHardwareInfo::query_device_multiprocessor_count) in the affected
group_gemm_*_sm120.cuh and the other listed group_gemm files so each launch
queries the current device rather than using a thread-local cached value.
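The allowed-value matrix in the doc-sync comment above is easy to misread in prose; a minimal sketch, assuming a hypothetical `is_sm12x` flag and a free-standing `validate_tiles` helper (neither is the library's actual API), makes the gating explicit:

```python
def validate_tiles(is_sm12x, mma_sm, tile_m, tile_n, tile_k):
    # Allowed sets are taken from the review text: SM12x is restricted to a
    # single config, while other targets accept a wider tile_n/tile_k range.
    if is_sm12x:
        ok = mma_sm == 1 and (tile_m, tile_n, tile_k) == (128, 128, 128)
    else:
        ok = (
            mma_sm in (1, 2)
            and tile_m == 128
            and tile_n in (64, 128, 192, 256)
            and tile_k in (128, 256)
        )
    if not ok:
        raise ValueError(
            f"unsupported tile config: mma_sm={mma_sm}, "
            f"tile=({tile_m}, {tile_n}, {tile_k})"
        )
```

Keeping the docstring tables in sync with a table-driven check like this is one way to prevent the drift the comment flags.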
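The divide-by-zero described in the num_groups == 0 comment comes from the launch-shape arithmetic; a small Python model of that arithmetic (a hypothetical `plan_launch`, mirroring `num_threads = min(num_groups, 1024)` and a ceil-div for the block count) shows why an early-return guard short-circuits it:

```python
def plan_launch(num_groups):
    # Guard first: with zero groups there is nothing to launch, and the
    # ceil-div below would otherwise divide by num_threads == 0.
    if num_groups == 0:
        return None
    num_threads = min(num_groups, 1024)
    num_blocks = (num_groups + num_threads - 1) // num_threads
    return num_threads, num_blocks
```

In the CUDA path the same guard would sit before the workspace allocations and the kernel launch, returning success without doing any work.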

---

Duplicate comments:
In `@csrc/group_gemm_nvfp4_groupwise_sm120.cu`:
- Around line 101-102: The code uses ffi::CUDADeviceGuard constructed from
float_workspace_buffer.device() but calls get_stream(A.device()), which can
mismatch devices; change to use the same tensor/device for both operations
(e.g., construct ffi::CUDADeviceGuard with A.device() and call
get_stream(A.device()), or vice versa) so the device guard and stream source
(float_workspace_buffer or A) are consistent; update the usage of
ffi::CUDADeviceGuard and get_stream to reference the same tensor (A or
float_workspace_buffer) throughout.

In `@flashinfer/gemm/gemm_base.py`:
- Around line 5159-5217: Reject non-flat or cross-device buffers before FFI:
validate that alpha (if not None) is a 1D, contiguous torch.Tensor with dtype
torch.float32, alpha.shape[0] == num_groups and alpha.device matches the device
used for computation (same device as b/a); validate out (if provided) is on the
same device as a/b, is contiguous, has shape (a.shape[0], n) and dtype
out_dtype; raise clear ValueErrors for non-tensor, non-contiguous, wrong-dtype,
wrong-dim, or cross-device cases so raw pointers passed to the SM120 NVFP4 path
are always flat, correctly-typed, and device-local.

---

Nitpick comments:
In `@benchmarks/bench_groupwise_grouped_gemm_nvfp4_blackwell_geforce.py`:
- Around line 74-89: The lambda passed into bench_gpu_time closes over loop
variables tile_m, tile_n, tile_k by reference which is fragile; update the call
so the lambda captures these values by value (e.g., use default-argument
binding: lambda tile_m=tile_m, tile_n=tile_n, tile_k=tile_k:
flashinfer.gemm.group_gemm_nvfp4_nt_groupwise(a, b, a_scale, b_scale,
segment_offsets, out=out, tile_m=tile_m, tile_n=tile_n, tile_k=tile_k)) or use
functools.partial to bind the parameters before passing to bench_gpu_time to
ensure stable behavior.

In `@csrc/group_gemm_mxfp4_groupwise_sm120_kernel_inst.jinja`:
- Around line 53-54: Remove the extraneous semicolons following the closing
namespace braces for the namespaces group_gemm and flashinfer: locate the
closing braces for namespace group_gemm and namespace flashinfer and delete the
trailing ';' characters so the namespace endings read simply "}" without
semicolons.

In `@csrc/group_gemm_nvfp4_groupwise_sm120_kernel_inst.jinja`:
- Around line 53-54: The file ends its namespace blocks with extraneous
semicolons; remove the trailing semicolons after the closing braces for the
namespaces 'group_gemm' and 'flashinfer' so the two lines "};  // namespace
group_gemm" and "};  // namespace flashinfer" become "}  // namespace
group_gemm" and "}  // namespace flashinfer", with the trailing comments
preserved.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 5536fc04-9019-4d3c-ab8b-a3a8e0318072

📥 Commits

Reviewing files that changed from the base of the PR and between f3ac8d3 and 7182f84.

📒 Files selected for processing (15)
  • benchmarks/bench_groupwise_grouped_gemm_mxfp4_blackwell.py
  • benchmarks/bench_groupwise_grouped_gemm_nvfp4_blackwell_geforce.py
  • csrc/group_gemm_mxfp4_groupwise_sm120.cu
  • csrc/group_gemm_mxfp4_groupwise_sm120_kernel_inst.jinja
  • csrc/group_gemm_nvfp4_groupwise_sm120.cu
  • csrc/group_gemm_nvfp4_groupwise_sm120_kernel_inst.jinja
  • csrc/group_gemm_sm120_binding.cu
  • csrc/tvm_ffi_utils.h
  • flashinfer/gemm/__init__.py
  • flashinfer/gemm/gemm_base.py
  • flashinfer/jit/gemm/core.py
  • include/flashinfer/gemm/group_gemm_mxfp4_groupwise_sm120.cuh
  • include/flashinfer/gemm/group_gemm_nvfp4_groupwise_sm120.cuh
  • tests/gemm/test_group_gemm_fp4.py
  • tests/gemm/test_groupwise_scaled_gemm_mxfp4.py
🚧 Files skipped from review as they are similar to previous changes (4)
  • flashinfer/gemm/__init__.py
  • tests/gemm/test_group_gemm_fp4.py
  • benchmarks/bench_groupwise_grouped_gemm_mxfp4_blackwell.py
  • csrc/tvm_ffi_utils.h

@yongwww yongwww added the run-ci label Mar 11, 2026
@yongwww
Member

yongwww commented Mar 11, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !401 has been created, and the CI pipeline #45897092 is currently running. I'll report back once the pipeline job completes.

@depaulmillz
Contributor Author

@coderabbitai review

@coderabbitai
Contributor

coderabbitai bot commented Mar 12, 2026

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

Contributor

@coderabbitai coderabbitai bot left a comment


♻️ Duplicate comments (1)
tests/gemm/test_group_gemm_fp4.py (1)

45-49: ⚠️ Potential issue | 🟡 Minor

Fix return type annotation to match actual return value.

The function returns 5 tensors (a_fp4, b_fp4, a_scale_padded, b_scale, alpha) but the annotation specifies only 4.

💡 Suggested fix
 def _quantize_nvfp4_group_inputs(
     a_float: torch.Tensor,
     b_float: torch.Tensor,
     m_indptr: torch.Tensor,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gemm/test_group_gemm_fp4.py` around lines 45 - 49, The return type
annotation for _quantize_nvfp4_group_inputs is incorrect: the function actually
returns five tensors (a_fp4, b_fp4, a_scale_padded, b_scale, alpha) but the
signature declares only four; update the function's return annotation to
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] (or
an equivalent 5-tuple type) so it matches the actual returned values and helps
type-checkers and readers find the correct symbols (a_fp4, b_fp4,
a_scale_padded, b_scale, alpha).
🧹 Nitpick comments (3)
csrc/group_gemm_mxfp4_groupwise_sm120_kernel_inst.jinja (1)

53-54: Remove trailing semicolons after namespace closing braces.

The semicolons after the closing braces are valid C++ (empty statements) but unconventional. Standard style omits them.

💡 Suggested fix
-};  // namespace group_gemm
-};  // namespace flashinfer
+}  // namespace group_gemm
+}  // namespace flashinfer
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/group_gemm_mxfp4_groupwise_sm120_kernel_inst.jinja` around lines 53 -
54, Remove the unnecessary trailing semicolons after the namespace closing
braces: locate the closing braces for namespace group_gemm and namespace
flashinfer in the template (symbols "namespace group_gemm" and "namespace
flashinfer") and delete the semicolons that follow the closing '}' characters so
the file ends with plain closing braces rather than '};'.
benchmarks/bench_groupwise_grouped_gemm_nvfp4_blackwell_geforce.py (2)

74-86: Lambda captures loop variables via late binding.

The lambda on lines 76-86 closes over tile_m, tile_n, and tile_k, which Python resolves at call time rather than at definition time. While this works correctly here because bench_gpu_time executes the lambda immediately, consider using default arguments to bind the values explicitly for robustness.

💡 Suggested fix
         measurements = bench_gpu_time(
-            lambda: flashinfer.gemm.group_gemm_nvfp4_nt_groupwise(
+            lambda tm=tile_m, tn=tile_n, tk=tile_k: flashinfer.gemm.group_gemm_nvfp4_nt_groupwise(
                 a,
                 b,
                 a_scale,
                 b_scale,
                 segment_offsets,
                 out=out,
-                tile_m=tile_m,
-                tile_n=tile_n,
-                tile_k=tile_k,
+                tile_m=tm,
+                tile_n=tn,
+                tile_k=tk,
             ),
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_groupwise_grouped_gemm_nvfp4_blackwell_geforce.py` around
lines 74 - 86, The lambda passed to bench_gpu_time captures loop variables
tile_m, tile_n, tile_k by reference which can lead to late-binding bugs; change
the call so the lambda binds current loop values as defaults (e.g., lambda
tile_m=tile_m, tile_n=tile_n, tile_k=tile_k:
flashinfer.gemm.group_gemm_nvfp4_nt_groupwise(...)) when invoking
flashinfer.gemm.group_gemm_nvfp4_nt_groupwise inside bench_gpu_time to ensure
the correct tile parameters are used.

26-103: Consider adding a GPU capability check before benchmarking.

The MXFP4 benchmark (bench_groupwise_grouped_gemm_mxfp4_blackwell.py) includes a runtime capability check. Adding a similar check here would prevent confusing errors when running on unsupported GPUs.

💡 Suggested addition
 def bench_groupwise_grouped_gemm_nvfp4_blackwell(group_size, m, n, k, out_dtype):
+    from flashinfer.utils import get_compute_capability
+    compute_capability = get_compute_capability(torch.device("cuda"))
+    if compute_capability[0] not in [12]:
+        print("group_gemm_nvfp4_nt_groupwise is only supported on SM120/SM121 GPUs.")
+        return
     torch.random.manual_seed(0)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_groupwise_grouped_gemm_nvfp4_blackwell_geforce.py` around
lines 26 - 103, The benchmark function
bench_groupwise_grouped_gemm_nvfp4_blackwell should guard against running on
unsupported GPUs: before seeding/random tensors or calling
flashinfer.gemm.group_gemm_nvfp4_nt_groupwise, add a runtime capability check
(same style as in bench_groupwise_grouped_gemm_mxfp4_blackwell.py) that queries
CUDA device properties (compute capability or a provided flashinfer capability
check) and early-returns or prints a skip message if NVFP4/Blackwell features
are not available; place this check at the top of
bench_groupwise_grouped_gemm_nvfp4_blackwell so the rest of the setup (tensor
allocation, a_scale/b_scale, and the benchmarking loop) is skipped on
incompatible hardware.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@tests/gemm/test_group_gemm_fp4.py`:
- Around line 45-49: The return type annotation for _quantize_nvfp4_group_inputs
is incorrect: the function actually returns five tensors (a_fp4, b_fp4,
a_scale_padded, b_scale, alpha) but the signature declares only four; update the
function's return annotation to tuple[torch.Tensor, torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor] (or an equivalent 5-tuple type) so it matches the
actual returned values and helps type-checkers and readers find the correct
symbols (a_fp4, b_fp4, a_scale_padded, b_scale, alpha).

---

Nitpick comments:
In `@benchmarks/bench_groupwise_grouped_gemm_nvfp4_blackwell_geforce.py`:
- Around line 74-86: The lambda passed to bench_gpu_time captures loop variables
tile_m, tile_n, tile_k by reference which can lead to late-binding bugs; change
the call so the lambda binds current loop values as defaults (e.g., lambda
tile_m=tile_m, tile_n=tile_n, tile_k=tile_k:
flashinfer.gemm.group_gemm_nvfp4_nt_groupwise(...)) when invoking
flashinfer.gemm.group_gemm_nvfp4_nt_groupwise inside bench_gpu_time to ensure
the correct tile parameters are used.
- Around line 26-103: The benchmark function
bench_groupwise_grouped_gemm_nvfp4_blackwell should guard against running on
unsupported GPUs: before seeding/random tensors or calling
flashinfer.gemm.group_gemm_nvfp4_nt_groupwise, add a runtime capability check
(same style as in bench_groupwise_grouped_gemm_mxfp4_blackwell.py) that queries
CUDA device properties (compute capability or a provided flashinfer capability
check) and early-returns or prints a skip message if NVFP4/Blackwell features
are not available; place this check at the top of
bench_groupwise_grouped_gemm_nvfp4_blackwell so the rest of the setup (tensor
allocation, a_scale/b_scale, and the benchmarking loop) is skipped on
incompatible hardware.

In `@csrc/group_gemm_mxfp4_groupwise_sm120_kernel_inst.jinja`:
- Around line 53-54: Remove the unnecessary trailing semicolons after the
namespace closing braces: locate the closing braces for namespace group_gemm and
namespace flashinfer in the template (symbols "namespace group_gemm" and
"namespace flashinfer") and delete the semicolons that follow the closing '}'
characters so the file ends with plain closing braces rather than '};'.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 8b2f1c76-621c-4242-96f4-4c93b3ec55c3

📥 Commits

Reviewing files that changed from the base of the PR and between f3ac8d3 and 6659ad0.

📒 Files selected for processing (17)
  • benchmarks/bench_groupwise_grouped_gemm_mxfp4_blackwell.py
  • benchmarks/bench_groupwise_grouped_gemm_nvfp4_blackwell_geforce.py
  • csrc/group_gemm_mxfp4_groupwise_sm120.cu
  • csrc/group_gemm_mxfp4_groupwise_sm120_kernel_inst.jinja
  • csrc/group_gemm_nvfp4_groupwise_sm120.cu
  • csrc/group_gemm_nvfp4_groupwise_sm120_kernel_inst.jinja
  • csrc/group_gemm_sm120_binding.cu
  • csrc/tvm_ffi_utils.h
  • flashinfer/gemm/__init__.py
  • flashinfer/gemm/gemm_base.py
  • flashinfer/jit/gemm/core.py
  • include/flashinfer/gemm/fp4_gemm_template_sm120.h
  • include/flashinfer/gemm/group_gemm_fp8_groupwise_sm120.cuh
  • include/flashinfer/gemm/group_gemm_mxfp4_groupwise_sm120.cuh
  • include/flashinfer/gemm/group_gemm_nvfp4_groupwise_sm120.cuh
  • tests/gemm/test_group_gemm_fp4.py
  • tests/gemm/test_groupwise_scaled_gemm_mxfp4.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • csrc/group_gemm_nvfp4_groupwise_sm120_kernel_inst.jinja
  • csrc/group_gemm_sm120_binding.cu

@depaulmillz
Contributor Author

@coderabbitai resume

@aleozlx aleozlx self-assigned this Mar 17, 2026
@kahyunnam
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !401 has been updated with latest changes, and the CI pipeline #46400558 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #46400558: 0/20 passed

@aleozlx
Collaborator

aleozlx commented Mar 20, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !401 has been created, and the CI pipeline #46568840 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #46568840: 14/20 passed

@aleozlx aleozlx enabled auto-merge (squash) March 25, 2026 16:32
@aleozlx aleozlx added run-ci and removed run-ci labels Mar 25, 2026
@johnnynunez
Contributor

@aleozlx pinging again: this solves a lot of pain for users related to NVFP4

@eugr

eugr commented Mar 27, 2026

@depaulmillz - Looks like it needs some changes to be merged cleanly. EDIT: disregard, tried to apply as a patch. Merges without any issues, now building on DGX Spark...

@aleozlx
Collaborator

aleozlx commented Mar 28, 2026

thanks for the ping. restarting CI

@aleozlx aleozlx merged commit 904fa8c into flashinfer-ai:main Mar 28, 2026
28 of 29 checks passed


7 participants