Support for MXFP4 and NVFP4 group GEMMs on GeForce and Spark #2738

aleozlx merged 13 commits into flashinfer-ai:main from
Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the settings. Use the following commands to manage reviews:
📝 Walkthrough

Adds SM12x (SM120/121) group-wise GEMM support (NVFP4 and MXFP4): new CUTLASS kernels, headers, Jinja instantiations, C++ FFI bindings, Python API + validation/dispatch, tests, and benchmarks; updates JIT build flags and runtime capability-dependent tile/dtype selection.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Py as Python Caller
    participant API as gemm_base.py
    participant FFI as group_gemm_sm120_binding.cu
    participant Dispatch as DLPack & tile dispatch
    participant CUTLASS as CUTLASS templated kernel
    participant GPU as CUDA SM120
    Py->>API: group_gemm_nvfp4_nt_groupwise(a,b,a_scale,b_scale,...)
    API->>API: validate shapes/dtypes, check is_sm12x_supported
    alt SM12x supported
        API->>FFI: call CutlassGroupGemm...SM120(...)
    else
        API->>FFI: call fallback SM100 entry
    end
    FFI->>Dispatch: dispatch by DLPack dtypes and tile sizes
    Dispatch->>Dispatch: is_valid_config checks (dtype/tile)
    alt valid config
        Dispatch->>CUTLASS: invoke templated Cutlass...GroupGEMM<...>
        CUTLASS->>GPU: launch grouped GEMM kernels
        GPU-->>CUTLASS: complete
    else invalid
        Dispatch-->>FFI: return error
    end
    FFI-->>API: return result tensor
    API-->>Py: deliver output
```
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly expands FlashInfer's capabilities by introducing support for MXFP4 and NVFP4 group GEMMs on NVIDIA's latest Blackwell GeForce and DGX Spark architectures. It provides new, optimized kernels for these floating-point formats, ensuring efficient computation on modern hardware. Additionally, the change addresses functional correctness by adjusting GDC settings within CUTLASS kernels, enhancing the robustness of the library.

Highlights
Changelog
Activity
Code Review
This pull request adds support for MXFP4 and NVFP4 group GEMMs on new NVIDIA architectures (Blackwell GeForce, DGX Spark). However, a security audit identified two high-severity integer overflow vulnerabilities in the CUDA kernels responsible for computing group-wise scaling arguments. These overflows occur during the calculation of scale factor offsets, leading to out-of-bounds memory accesses on the GPU. It is crucial to address these by using 64-bit integers for these calculations to ensure memory safety. Additionally, there are inconsistencies in the Python API docstrings, an opportunity to reduce code duplication in the CUDA headers, and a critical bug in a new test file that needs to be fixed.
Actionable comments posted: 8
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@csrc/group_gemm_mxfp4_groupwise_sm120.cu`:
- Around line 42-50: The DISPATCH_TILE_K macro currently only handles tile_k ==
128 causing a failure for tile_k == 256; add a branch for tile_k == 256 that
defines constexpr int TILE_K = 256 and invokes the same lambda path (i.e.,
mirror the 128 case), and also add the corresponding kernel instantiation for
TILE_K=256 in the kernel template
csrc/group_gemm_mxfp4_groupwise_sm120_kernel_inst.jinja so the launcher and
kernel template both support tile_k=256 (refer to DISPATCH_TILE_K and the kernel
instantiation entries in group_gemm_mxfp4_groupwise_sm120_kernel_inst.jinja).
In `@csrc/group_gemm_nvfp4_groupwise_sm120.cu`:
- Around line 56-72: The fallback macro name is misspelled: when
FLASHINFER_ENABLE_FP8_E4M3 is off you define _DISPATCH_SF_CASE_FP8_E4M3 but the
switch uses _DISPATCH_SF_CASE_FP8_UE4M3, causing an undefined token; fix by
renaming the fallback definition to _DISPATCH_SF_CASE_FP8_UE4M3 with the same
parameter list (c_type, ...) and body (empty) so the
DISPATCH_DLPACK_DTYPE_TO_CTYPE_SF_UE4M3 switch (which calls encode_dlpack_dtype)
compiles correctly in the no-FP8-E4M3 build.
In `@flashinfer/gemm/gemm_base.py`:
- Around line 5122-5194: In _check_group_gemm_nvfp4_nt_groupwise_problem_size
validate alpha before FFI use: if alpha is not None and alpha.numel() > 0,
ensure alpha.dtype == torch.float32, alpha.is_contiguous() is True, alpha.device
is cpu (or explicitly state expected device if FFI requires GPU), and
alpha.numel() is either 1 or equals num_groups (computed from
m_indptr.shape[0]-1); raise descriptive ValueError mentioning alpha, num_groups,
and the expected dtype/shape/device when any check fails so the SM120 launcher
won't receive an invalid float* pointer.
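A minimal sketch of those checks (plain Python; a small stand-in class replaces `torch.Tensor` so the logic is self-contained, with attribute names mirroring the torch API):

```python
from dataclasses import dataclass


@dataclass
class FakeTensor:
    # Stand-in for the torch.Tensor attributes the validation needs.
    dtype: str
    shape: tuple
    device: str
    contiguous: bool = True

    def numel(self):
        n = 1
        for d in self.shape:
            n *= d
        return n

    def is_contiguous(self):
        return self.contiguous


def check_alpha(alpha, num_groups, device="cuda:0"):
    """Validate alpha before its raw pointer is handed to the FFI layer."""
    if alpha is None or alpha.numel() == 0:
        return
    if alpha.dtype != "float32":
        raise ValueError(f"alpha must be float32, got {alpha.dtype}")
    if len(alpha.shape) != 1 or not alpha.is_contiguous():
        raise ValueError("alpha must be a contiguous 1-D tensor")
    if alpha.device != device:
        raise ValueError(f"alpha must be on {device}, got {alpha.device}")
    if alpha.numel() not in (1, num_groups):
        raise ValueError(
            f"alpha.numel() must be 1 or num_groups={num_groups}, "
            f"got {alpha.numel()}"
        )


check_alpha(None, 4)  # None is allowed
check_alpha(FakeTensor("float32", (4,), "cuda:0"), 4)  # per-group scales
try:
    check_alpha(FakeTensor("float32", (3,), "cuda:0"), 4)
except ValueError as exc:
    print(exc)
```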
- Around line 5077-5113: The current code can return an untouched output buffer
when neither is_sm12x_supported(a.device) nor is_sm100a_supported(a.device)
matches; update the control flow after those conditionals to handle the
unsupported case by raising a clear exception (e.g., RuntimeError) instead of
returning out silently. Specifically, in the enclosing function in gemm_base.py,
after the two if/elif blocks for is_sm12x_supported and is_sm100a_supported, add
an else branch that raises an error describing the device and that
group_gemm_mxfp4_nt_groupwise was not launched (reference is_sm12x_supported,
is_sm100a_supported, get_gemm_sm120_module().group_gemm_mxfp4_nt_groupwise and
get_gemm_sm100_module().group_gemm_mxfp4_nt_groupwise to locate the logic).
Ensure the message includes the device identifier (a.device) and any key
parameters (n, k, tile sizes) to aid debugging.
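The suggested control flow, sketched below; the capability predicates are stubbed as booleans, standing in for `is_sm12x_supported(a.device)` and `is_sm100a_supported(a.device)` in the real code:

```python
def launch_group_gemm_mxfp4(sm12x_ok, sm100a_ok, device="cuda:0", n=4096, k=4096):
    # Dispatch mirrors gemm_base.py: SM12x path first, then the SM100a fallback.
    if sm12x_ok:
        return "sm120 kernel launched"
    elif sm100a_ok:
        return "sm100 kernel launched"
    # The fix: fail loudly instead of returning an untouched output buffer.
    raise RuntimeError(
        f"group_gemm_mxfp4_nt_groupwise was not launched: device {device} "
        f"supports neither SM12x nor SM100a (n={n}, k={k})"
    )


print(launch_group_gemm_mxfp4(True, False))
try:
    launch_group_gemm_mxfp4(False, False)
except RuntimeError as exc:
    print(exc)
```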
- Around line 5202-5247: Update the docstring for group_gemm_nvfp4_nt_groupwise
to match the actual implementation: change the parameter "a" to indicate it is
packed uint8 (torch.uint8) with shape (cum_m, k // 2) instead of float8
(torch.float8_...), remove any mention of float8 for "a"; update the "tile_n"
description to state only tile_n=128 is supported (remove 64,192,256 options);
keep/confirm other shape/type descriptions (b: torch.uint8, a_scale, b_scale)
but adjust any dependent shape text if it assumed unpacked k; ensure the types
for "a" and "tile_n" in the parameter list match the implementation in
group_gemm_nvfp4_nt_groupwise.
In `@tests/gemm/test_group_gemm_fp4.py`:
- Line 25: The test currently imports get_compute_capability but must be gated
with the runtime check is_sm12x_supported(); import is_sm12x_supported from
flashinfer.utils and skip the test when it returns False (e.g., with pytest.skip
or a pytest.mark.skipif using is_sm12x_supported()), ensuring the test module or
specific test functions for group_gemm_fp4 are not run on SM12x machines lacking
the required runtime/toolchain; update any other similar checks in this file
(references to get_compute_capability) to use is_sm12x_supported() as well.
- Around line 45-49: The test fails because _quantize_nvfp4_group_inputs
declares a third parameter m_indptr that is unused while callers (e.g., the call
on Line 105) pass only two args; remove the unused m_indptr parameter from the
function signature of _quantize_nvfp4_group_inputs (and update its type
annotation to accept only a_float: torch.Tensor and b_float: torch.Tensor),
delete any references to m_indptr inside the function body, and ensure the
returned tuple typing remains correct (tuple[torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor]) so existing callers continue to work.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: f6a05ef6-35ce-4ed2-8b92-c2cd91358b8b
📒 Files selected for processing (14)
- benchmarks/bench_groupwise_grouped_gemm_mxfp4_blackwell.py
- benchmarks/bench_groupwise_grouped_gemm_nvfp4_blackwell_geforce.py
- csrc/group_gemm_mxfp4_groupwise_sm120.cu
- csrc/group_gemm_mxfp4_groupwise_sm120_kernel_inst.jinja
- csrc/group_gemm_nvfp4_groupwise_sm120.cu
- csrc/group_gemm_nvfp4_groupwise_sm120_kernel_inst.jinja
- csrc/group_gemm_sm120_binding.cu
- flashinfer/gemm/__init__.py
- flashinfer/gemm/gemm_base.py
- flashinfer/jit/gemm/core.py
- include/flashinfer/gemm/group_gemm_mxfp4_groupwise_sm120.cuh
- include/flashinfer/gemm/group_gemm_nvfp4_groupwise_sm120.cuh
- tests/gemm/test_group_gemm_fp4.py
- tests/gemm/test_groupwise_scaled_gemm_mxfp4.py
♻️ Duplicate comments (1)
flashinfer/gemm/gemm_base.py (1)
5159-5188: ⚠️ Potential issue | 🟠 Major

Tighten `alpha` validation before forwarding its raw pointer. This only checks `dtype` and `shape[0]`. A 0-D/2-D tensor, a non-contiguous view, or a tensor on the wrong device still reaches `group_gemm_nvfp4_nt_groupwise()` and is passed straight through to `csrc/group_gemm_nvfp4_groupwise_sm120.cu`, lines 104-131, as `float*`, which will either misread the scales or hand the kernel a host pointer.

💡 Suggested guard

```diff
 if alpha is not None and alpha.dtype != torch.float32:
     raise ValueError(
         f"alpha must be a float32 tensor or None, but got {alpha.dtype}"
     )
+if alpha is not None and alpha.device != a.device:
+    raise ValueError(f"alpha must be on {a.device}, but got {alpha.device}")
+if alpha is not None and not alpha.is_contiguous():
+    raise ValueError("alpha must be contiguous")
+if alpha is not None and alpha.ndim != 1:
+    raise ValueError(f"alpha must be 1D, but got shape {tuple(alpha.shape)}")
 ...
-if alpha is not None and alpha.shape[0] != num_groups:
+if alpha is not None and alpha.numel() not in (0, num_groups):
     raise ValueError(
-        f"alpha.shape[0] must equal num_groups, but got alpha.shape[0]={alpha.shape[0]}, num_groups={num_groups}"
+        f"alpha must be empty or have shape ({num_groups},), but got {tuple(alpha.shape)}"
     )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` around lines 5159 - 5188, The alpha validation is incomplete: before forwarding alpha to group_gemm_nvfp4_nt_groupwise() (and ultimately the CUDA kernel), ensure alpha is a 1-D float32 tensor located on the correct device and contiguous (or explicitly make it so). Concretely, in the block that already checks dtype and shape[0] (referencing alpha and num_groups), add checks that alpha.dim() == 1 and alpha.is_contiguous() and alpha.device == m_indptr.device (or, if you prefer to accept non-contiguous/wrong-device tensors, convert them: alpha = alpha.to(m_indptr.device).contiguous().to(torch.float32)); raise a ValueError with a clear message if dim or device are wrong, or perform the conversion so the raw float* passed into group_gemm_nvfp4_nt_groupwise() is always a contiguous device tensor.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 5159-5188: The alpha validation is incomplete: before forwarding
alpha to group_gemm_nvfp4_nt_groupwise() (and ultimately the CUDA kernel),
ensure alpha is a 1-D float32 tensor located on the correct device and
contiguous (or explicitly make it so). Concretely, in the block that already
checks dtype and shape[0] (referencing alpha and num_groups), add checks that
alpha.dim() == 1 and alpha.is_contiguous() and alpha.device == m_indptr.device
(or, if you prefer to accept non-contiguous/wrong-device tensors, convert them:
alpha = alpha.to(m_indptr.device).contiguous().to(torch.float32)); raise a
ValueError with a clear message if dim or device are wrong, or perform the
conversion so the raw float* passed into group_gemm_nvfp4_nt_groupwise() is
always a contiguous device tensor.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 6d38bade-b5ba-497a-b4fa-9132f0e7df6f
📒 Files selected for processing (4)
- csrc/group_gemm_nvfp4_groupwise_sm120.cu
- csrc/tvm_ffi_utils.h
- flashinfer/gemm/gemm_base.py
- tests/gemm/test_group_gemm_fp4.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/gemm/test_group_gemm_fp4.py
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@csrc/group_gemm_nvfp4_groupwise_sm120.cu`:
- Around line 101-102: The device guard and stream are sourced from different
tensors (device_guard uses float_workspace_buffer.device() while stream is
obtained via get_stream(A)), which can activate the wrong device; make them
consistent by using the same tensor as the source—e.g., initialize
ffi::CUDADeviceGuard with A.device().device_id (or alternatively get the stream
from float_workspace_buffer) so that device_guard and stream use the same
device.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 17dbe637-5a81-41cc-9ab6-75820d421630
📒 Files selected for processing (2)
- csrc/group_gemm_nvfp4_groupwise_sm120.cu
- csrc/tvm_ffi_utils.h
thank you @depaulmillz
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Force-pushed f3ac8d3 to 7182f84
Actionable comments posted: 4
♻️ Duplicate comments (2)
csrc/group_gemm_nvfp4_groupwise_sm120.cu (1)
101-102: ⚠️ Potential issue | 🟡 Minor

Device guard and stream may use different tensor sources. Line 101 creates a device guard from `float_workspace_buffer.device()`, but line 102 retrieves the stream from `A.device()`. If these tensors reside on different devices (e.g., during multi-GPU operations), this could cause incorrect execution context. Consider using the same tensor for both:

Suggested fix

```diff
-  ffi::CUDADeviceGuard device_guard(float_workspace_buffer.device().device_id);
-  auto stream = get_stream(A.device());
+  ffi::CUDADeviceGuard device_guard(A.device().device_id);
+  auto stream = get_stream(A.device());
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/group_gemm_nvfp4_groupwise_sm120.cu` around lines 101 - 102, The code uses ffi::CUDADeviceGuard constructed from float_workspace_buffer.device() but calls get_stream(A.device()), which can mismatch devices; change to use the same tensor/device for both operations (e.g., construct ffi::CUDADeviceGuard with A.device() and call get_stream(A.device()), or vice versa) so the device guard and stream source (float_workspace_buffer or A) are consistent; update the usage of ffi::CUDADeviceGuard and get_stream to reference the same tensor (A or float_workspace_buffer) throughout.

flashinfer/gemm/gemm_base.py (1)
5159-5217: ⚠️ Potential issue | 🟠 Major

Reject non-flat or cross-device buffers before handing them to FFI. A CPU/non-contiguous/wrong-shaped `alpha` still passes here, and `out.device` is never checked. The SM120 NVFP4 path later treats both tensors as raw device buffers, so a bad user tensor becomes an invalid pointer or scrambled per-group scales instead of a clean Python error.

🛠️ Proposed fix

```diff
 if alpha is not None and alpha.dtype != torch.float32:
     raise ValueError(
         f"alpha must be a float32 tensor or None, but got {alpha.dtype}"
     )
+if alpha is not None and alpha.device != a.device:
+    raise ValueError(f"alpha must be on {a.device}, but got {alpha.device}")
+if alpha is not None and not alpha.is_contiguous():
+    raise ValueError("alpha must be contiguous")
 @@
 num_groups = m_indptr.shape[0] - 1
-if alpha is not None and alpha.shape[0] != num_groups:
+if alpha is not None and alpha.shape != (num_groups,):
     raise ValueError(
-        f"alpha.shape[0] must equal num_groups, but got alpha.shape[0]={alpha.shape[0]}, num_groups={num_groups}"
+        f"alpha must have shape ({num_groups},), but got {tuple(alpha.shape)}"
     )
 @@
 out_shape = (a.shape[0], n)
 if out is not None:
     if out.shape != out_shape:
         raise ValueError(f"out.shape must be {out_shape}, but got {out.shape}")
+    if out.device != a.device:
+        raise ValueError(f"out must be on {a.device}, but got {out.device}")
     if out.dtype != out_dtype:
         raise ValueError(f"out.dtype must be {out_dtype}, but got {out.dtype}")
```

Run this to confirm the current path still forwards `alpha` as a raw pointer without extra normalization:

```bash
#!/bin/bash
set -euo pipefail
echo "=== Python-side NVFP4 validation ==="
sed -n '5158,5217p' flashinfer/gemm/gemm_base.py
echo
echo "=== SM120 binding signature ==="
sed -n '26,38p' csrc/group_gemm_sm120_binding.cu
echo
echo "=== CUTLASS epilogue alpha pointer wiring ==="
sed -n '247,252p' include/flashinfer/gemm/group_gemm_nvfp4_groupwise_sm120.cuh
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` around lines 5159 - 5217, Reject non-flat or cross-device buffers before FFI: validate that alpha (if not None) is a 1D, contiguous torch.Tensor with dtype torch.float32, alpha.shape[0] == num_groups and alpha.device matches the device used for computation (same device as b/a); validate out (if provided) is on the same device as a/b, is contiguous, has shape (a.shape[0], n) and dtype out_dtype; raise clear ValueErrors for non-tensor, non-contiguous, wrong-dtype, wrong-dim, or cross-device cases so raw pointers passed to the SM120 NVFP4 path are always flat, correctly-typed, and device-local.
🧹 Nitpick comments (3)
csrc/group_gemm_mxfp4_groupwise_sm120_kernel_inst.jinja (1)

53-54: Remove extraneous semicolons after namespace closing braces. The semicolons after the closing braces are unnecessary and unconventional in C++.

Suggested fix

```diff
-};  // namespace group_gemm
-};  // namespace flashinfer
+}  // namespace group_gemm
+}  // namespace flashinfer
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/group_gemm_mxfp4_groupwise_sm120_kernel_inst.jinja` around lines 53 - 54, Remove the extraneous semicolons following the closing namespace braces for the namespaces group_gemm and flashinfer: locate the closing braces for namespace group_gemm and namespace flashinfer and delete the trailing ';' characters so the namespace endings read simply "}" without semicolons.

csrc/group_gemm_nvfp4_groupwise_sm120_kernel_inst.jinja (1)
53-54: Remove extraneous semicolons after namespace closing braces. Same as the MXFP4 template: the semicolons after closing braces are unnecessary.

Suggested fix

```diff
-};  // namespace group_gemm
-};  // namespace flashinfer
+}  // namespace group_gemm
+}  // namespace flashinfer
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/group_gemm_nvfp4_groupwise_sm120_kernel_inst.jinja` around lines 53 - 54, The file ends namespace blocks with extraneous semicolons; remove the trailing semicolons after the closing braces for the namespaces 'group_gemm' and 'flashinfer' so the two lines "}; // namespace group_gemm" and "}; // namespace flashinfer" become "}" with comments preserved; update the lines that close the namespaces group_gemm and flashinfer to drop the unnecessary semicolons.

benchmarks/bench_groupwise_grouped_gemm_nvfp4_blackwell_geforce.py (1)
74-89: Lambda captures loop variables by reference; safe but fragile.

The lambda passed to `bench_gpu_time` captures `tile_m`, `tile_n`, and `tile_k` by reference. While this works correctly because the lambda is executed immediately within the same loop iteration, it's a pattern that can cause subtle bugs if the code is refactored. Consider using default argument binding to capture by value:

Suggested fix

```diff
 measurements = bench_gpu_time(
-    lambda: flashinfer.gemm.group_gemm_nvfp4_nt_groupwise(
+    lambda tile_m=tile_m, tile_n=tile_n, tile_k=tile_k: flashinfer.gemm.group_gemm_nvfp4_nt_groupwise(
         a,
         b,
         a_scale,
         b_scale,
         segment_offsets,
         out=out,
         tile_m=tile_m,
         tile_n=tile_n,
         tile_k=tile_k,
     ),
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_groupwise_grouped_gemm_nvfp4_blackwell_geforce.py` around lines 74 - 89, The lambda passed into bench_gpu_time closes over loop variables tile_m, tile_n, tile_k by reference which is fragile; update the call so the lambda captures these values by value (e.g., use default-argument binding: lambda tile_m=tile_m, tile_n=tile_n, tile_k=tile_k: flashinfer.gemm.group_gemm_nvfp4_nt_groupwise(a, b, a_scale, b_scale, segment_offsets, out=out, tile_m=tile_m, tile_n=tile_n, tile_k=tile_k)) or use functools.partial to bind the parameters before passing to bench_gpu_time to ensure stable behavior.
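The late-binding behavior behind this recommendation is a general Python closure property, easy to demonstrate outside the benchmark:

```python
# Closures capture variables, not values: by the time these lambdas run,
# the loop variable i already has its final value.
late = [lambda: i for i in range(3)]
print([f() for f in late])  # [2, 2, 2]

# Default-argument binding snapshots the value at definition time, which is
# what the suggested `lambda tile_m=tile_m, ...` pattern does.
bound = [lambda i=i: i for i in range(3)]
print([f() for f in bound])  # [0, 1, 2]
```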
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 4928-4948: The public docs for
group_gemm_mxfp8_mxfp4_nt_groupwise() are out of sync with runtime checks
(is_sm12x_supported) which now restrict mma_sm, tile_n, and tile_k ranges;
update the documentation for group_gemm_mxfp8_mxfp4_nt_groupwise to list the
exact allowed values used in the code (for SM12x: mma_sm == 1, tile_m == 128,
tile_n == 128, tile_k == 128; otherwise: mma_sm in {1,2}, tile_m == 128, tile_n
in {64,128,192,256}, tile_k in {128,256}) so users aren’t misled and will avoid
ValueError at runtime.
In `@include/flashinfer/gemm/group_gemm_mxfp4_groupwise_sm120.cuh`:
- Around line 221-223: The KernelHardwareInfo instance sets hw_info.device_id =
0 which breaks multi-GPU setups; update the code that constructs
cutlass::KernelHardwareInfo (hw_info) to use the actual CUDA device instead of
hardcoding 0—either query the current device with cudaGetDevice() (or equivalent
helper) and assign that to hw_info.device_id, or change the calling signature to
accept and forward a device_id parameter so sm_count lookup uses the correct
device; ensure hw_info.sm_count still uses sm_count but comes from the chosen
device context.
In `@include/flashinfer/gemm/group_gemm_nvfp4_groupwise_sm120.cuh`:
- Around line 175-203: The code fails when num_groups == 0 because num_threads =
std::min(num_groups, 1024) becomes 0 and num_blocks divides by zero; add an
early-return guard that checks num_groups == 0 before computing
num_threads/num_blocks (and ideally before heavy allocations) to short-circuit
the grouped GEMM path. Locate the allocation/launch setup in
group_gemm_nvfp4_groupwise_sm120 (or the surrounding helper that uses
AlignedAllocator and calls allocator.aligned_alloc) and insert a simple check
for num_groups == 0 that returns success/does nothing so num_threads and
num_blocks are never computed or used. Ensure the guard references num_groups,
num_threads, and num_blocks to prevent the divide-by-zero.
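The divide-by-zero and its guard, sketched in Python (the real code sizes a CUDA grid; the function name here is a hypothetical stand-in for that launch computation):

```python
def grid_size_for_groups(num_groups):
    # Early-return guard: zero groups means no work, and the ceil-division
    # below would otherwise divide by num_threads == 0.
    if num_groups == 0:
        return 0, 0
    num_threads = min(num_groups, 1024)
    num_blocks = (num_groups + num_threads - 1) // num_threads  # ceil-div
    return num_blocks, num_threads


print(grid_size_for_groups(0))     # guarded: nothing to launch
print(grid_size_for_groups(7))     # (1, 7)
print(grid_size_for_groups(2048))  # (2, 1024)
```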
- Around line 221-225: Remove the thread_local caching of sm_count and the
hardcoded hw_info.device_id = 0; instead, at launch time obtain the current
device and fresh SM count: call cudaGetDevice(&hw_info.device_id) and then set
hw_info.sm_count =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id).
Update the code around the sm_count variable and hw_info initialization
(symbols: sm_count, hw_info,
cutlass::KernelHardwareInfo::query_device_multiprocessor_count) in the affected
group_gemm_*_sm120.cuh and the other listed group_gemm files so each launch
queries the current device rather than using a thread-local cached value.
---
Duplicate comments:
In `@csrc/group_gemm_nvfp4_groupwise_sm120.cu`:
- Around line 101-102: The code uses ffi::CUDADeviceGuard constructed from
float_workspace_buffer.device() but calls get_stream(A.device()), which can
mismatch devices; change to use the same tensor/device for both operations
(e.g., construct ffi::CUDADeviceGuard with A.device() and call
get_stream(A.device()), or vice versa) so the device guard and stream source
(float_workspace_buffer or A) are consistent; update the usage of
ffi::CUDADeviceGuard and get_stream to reference the same tensor (A or
float_workspace_buffer) throughout.
In `@flashinfer/gemm/gemm_base.py`:
- Around line 5159-5217: Reject non-flat or cross-device buffers before FFI:
validate that alpha (if not None) is a 1D, contiguous torch.Tensor with dtype
torch.float32, alpha.shape[0] == num_groups and alpha.device matches the device
used for computation (same device as b/a); validate out (if provided) is on the
same device as a/b, is contiguous, has shape (a.shape[0], n) and dtype
out_dtype; raise clear ValueErrors for non-tensor, non-contiguous, wrong-dtype,
wrong-dim, or cross-device cases so raw pointers passed to the SM120 NVFP4 path
are always flat, correctly-typed, and device-local.
---
Nitpick comments:
In `@benchmarks/bench_groupwise_grouped_gemm_nvfp4_blackwell_geforce.py`:
- Around line 74-89: The lambda passed into bench_gpu_time closes over loop
variables tile_m, tile_n, tile_k by reference which is fragile; update the call
so the lambda captures these values by value (e.g., use default-argument
binding: lambda tile_m=tile_m, tile_n=tile_n, tile_k=tile_k:
flashinfer.gemm.group_gemm_nvfp4_nt_groupwise(a, b, a_scale, b_scale,
segment_offsets, out=out, tile_m=tile_m, tile_n=tile_n, tile_k=tile_k)) or use
functools.partial to bind the parameters before passing to bench_gpu_time to
ensure stable behavior.
In `@csrc/group_gemm_mxfp4_groupwise_sm120_kernel_inst.jinja`:
- Around line 53-54: Remove the extraneous semicolons following the closing
namespace braces for the namespaces group_gemm and flashinfer: locate the
closing braces for namespace group_gemm and namespace flashinfer and delete the
trailing ';' characters so the namespace endings read simply "}" without
semicolons.
In `@csrc/group_gemm_nvfp4_groupwise_sm120_kernel_inst.jinja`:
- Around line 53-54: The file ends namespace blocks with extraneous semicolons;
remove the trailing semicolons after the closing braces for the namespaces
'group_gemm' and 'flashinfer' so the two lines "}; // namespace group_gemm" and
"}; // namespace flashinfer" become "}" comments preserved — update the lines
that close the namespaces group_gemm and flashinfer to drop the unnecessary
semicolons.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 5536fc04-9019-4d3c-ab8b-a3a8e0318072
📒 Files selected for processing (15)
- benchmarks/bench_groupwise_grouped_gemm_mxfp4_blackwell.py
- benchmarks/bench_groupwise_grouped_gemm_nvfp4_blackwell_geforce.py
- csrc/group_gemm_mxfp4_groupwise_sm120.cu
- csrc/group_gemm_mxfp4_groupwise_sm120_kernel_inst.jinja
- csrc/group_gemm_nvfp4_groupwise_sm120.cu
- csrc/group_gemm_nvfp4_groupwise_sm120_kernel_inst.jinja
- csrc/group_gemm_sm120_binding.cu
- csrc/tvm_ffi_utils.h
- flashinfer/gemm/__init__.py
- flashinfer/gemm/gemm_base.py
- flashinfer/jit/gemm/core.py
- include/flashinfer/gemm/group_gemm_mxfp4_groupwise_sm120.cuh
- include/flashinfer/gemm/group_gemm_nvfp4_groupwise_sm120.cuh
- tests/gemm/test_group_gemm_fp4.py
- tests/gemm/test_groupwise_scaled_gemm_mxfp4.py
🚧 Files skipped from review as they are similar to previous changes (4)
- flashinfer/gemm/__init__.py
- tests/gemm/test_group_gemm_fp4.py
- benchmarks/bench_groupwise_grouped_gemm_mxfp4_blackwell.py
- csrc/tvm_ffi_utils.h
/bot run
@coderabbitai review
✅ Actions performed: Review triggered.
♻️ Duplicate comments (1)
tests/gemm/test_group_gemm_fp4.py (1)
45-49: ⚠️ Potential issue | 🟡 Minor

Fix return type annotation to match actual return value. The function returns 5 tensors (`a_fp4`, `b_fp4`, `a_scale_padded`, `b_scale`, `alpha`) but the annotation specifies only 4.

💡 Suggested fix

```diff
 def _quantize_nvfp4_group_inputs(
     a_float: torch.Tensor,
     b_float: torch.Tensor,
     m_indptr: torch.Tensor,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gemm/test_group_gemm_fp4.py` around lines 45 - 49, The return type annotation for _quantize_nvfp4_group_inputs is incorrect: the function actually returns five tensors (a_fp4, b_fp4, a_scale_padded, b_scale, alpha) but the signature declares only four; update the function's return annotation to tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] (or an equivalent 5-tuple type) so it matches the actual returned values and helps type-checkers and readers find the correct symbols (a_fp4, b_fp4, a_scale_padded, b_scale, alpha).
🧹 Nitpick comments (3)
csrc/group_gemm_mxfp4_groupwise_sm120_kernel_inst.jinja (1)
53-54: Remove trailing semicolons after namespace closing braces. The semicolons after the closing braces are valid C++ (empty statements) but unconventional. Standard style omits them.

💡 Suggested fix

```diff
-};  // namespace group_gemm
-};  // namespace flashinfer
+}  // namespace group_gemm
+}  // namespace flashinfer
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/group_gemm_mxfp4_groupwise_sm120_kernel_inst.jinja` around lines 53 - 54, Remove the unnecessary trailing semicolons after the namespace closing braces: locate the closing braces for namespace group_gemm and namespace flashinfer in the template (symbols "namespace group_gemm" and "namespace flashinfer") and delete the semicolons that follow the closing '}' characters so the file ends with plain closing braces rather than '};'.

benchmarks/bench_groupwise_grouped_gemm_nvfp4_blackwell_geforce.py (2)
74-86: Lambda captures loop variables by reference.

The lambda on lines 76-86 captures `tile_m`, `tile_n`, and `tile_k` by reference. While this works correctly here because `bench_gpu_time` executes the lambda immediately, consider using default arguments to bind the values explicitly for robustness.

💡 Suggested fix

```diff
 measurements = bench_gpu_time(
-    lambda: flashinfer.gemm.group_gemm_nvfp4_nt_groupwise(
+    lambda tm=tile_m, tn=tile_n, tk=tile_k: flashinfer.gemm.group_gemm_nvfp4_nt_groupwise(
         a,
         b,
         a_scale,
         b_scale,
         segment_offsets,
         out=out,
-        tile_m=tile_m,
-        tile_n=tile_n,
-        tile_k=tile_k,
+        tile_m=tm,
+        tile_n=tn,
+        tile_k=tk,
     ),
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_groupwise_grouped_gemm_nvfp4_blackwell_geforce.py` around lines 74 - 86, The lambda passed to bench_gpu_time captures loop variables tile_m, tile_n, tile_k by reference which can lead to late-binding bugs; change the call so the lambda binds current loop values as defaults (e.g., lambda tile_m=tile_m, tile_n=tile_n, tile_k=tile_k: flashinfer.gemm.group_gemm_nvfp4_nt_groupwise(...)) when invoking flashinfer.gemm.group_gemm_nvfp4_nt_groupwise inside bench_gpu_time to ensure the correct tile parameters are used.
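The pitfall this comment guards against is Python's late binding of closures: a lambda looks up a captured name at call time, not at definition time, so deferred calls all see the last loop value. A minimal, standalone illustration of the problem and of the default-argument fix the review suggests:

```python
# Late-binding pitfall: each lambda in `late` looks up `i` when it is
# called, so every one of them sees the final loop value. Binding `i`
# as a default argument (as the suggested fix does for tile_m/tile_n/
# tile_k) snapshots the current value at definition time instead.
late = [lambda: i for i in range(3)]
bound = [lambda i=i: i for i in range(3)]

assert [f() for f in late] == [2, 2, 2]
assert [f() for f in bound] == [0, 1, 2]
```

In the benchmark the lambda happens to run before the loop variables change, so this is robustness rather than a live bug, as the comment notes.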
26-103: Consider adding a GPU capability check before benchmarking. The MXFP4 benchmark (`bench_groupwise_grouped_gemm_mxfp4_blackwell.py`) includes a runtime capability check. Adding a similar check here would prevent confusing errors when running on unsupported GPUs.
💡 Suggested addition
 def bench_groupwise_grouped_gemm_nvfp4_blackwell(group_size, m, n, k, out_dtype):
+    from flashinfer.utils import get_compute_capability
+    compute_capability = get_compute_capability(torch.device("cuda"))
+    if compute_capability[0] not in [12]:
+        print("group_gemm_nvfp4_nt_groupwise is only supported on SM120/SM121 GPUs.")
+        return
     torch.random.manual_seed(0)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_groupwise_grouped_gemm_nvfp4_blackwell_geforce.py` around lines 26 - 103, The benchmark function bench_groupwise_grouped_gemm_nvfp4_blackwell should guard against running on unsupported GPUs: before seeding/random tensors or calling flashinfer.gemm.group_gemm_nvfp4_nt_groupwise, add a runtime capability check (same style as in bench_groupwise_grouped_gemm_mxfp4_blackwell.py) that queries CUDA device properties (compute capability or a provided flashinfer capability check) and early-returns or prints a skip message if NVFP4/Blackwell features are not available; place this check at the top of bench_groupwise_grouped_gemm_nvfp4_blackwell so the rest of the setup (tensor allocation, a_scale/b_scale, and the benchmarking loop) is skipped on incompatible hardware.
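The guard boils down to checking the device's SM major version. A framework-free sketch of that predicate is below; in the real benchmark the capability tuple would come from `flashinfer.utils.get_compute_capability` (as in the suggested diff), and the helper name here is hypothetical:

```python
def is_sm12x(compute_capability: tuple) -> bool:
    # SM120 and SM121 (the architectures this PR targets) both report
    # compute capability major version 12; anything else should skip
    # the NVFP4 benchmark rather than fail inside the kernel launch.
    major = compute_capability[0]
    return major == 12


# A benchmark entry point would early-return with a message when False.
assert is_sm12x((12, 0)) and is_sm12x((12, 1))
assert not is_sm12x((10, 0))
```

Centralizing the check in a small predicate keeps the skip logic identical between the MXFP4 and NVFP4 benchmarks.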
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@tests/gemm/test_group_gemm_fp4.py`:
- Around line 45-49: The return type annotation for _quantize_nvfp4_group_inputs
is incorrect: the function actually returns five tensors (a_fp4, b_fp4,
a_scale_padded, b_scale, alpha) but the signature declares only four; update the
function's return annotation to tuple[torch.Tensor, torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor] (or an equivalent 5-tuple type) so it matches the
actual returned values and helps type-checkers and readers find the correct
symbols (a_fp4, b_fp4, a_scale_padded, b_scale, alpha).
---
Nitpick comments:
In `@benchmarks/bench_groupwise_grouped_gemm_nvfp4_blackwell_geforce.py`:
- Around line 74-86: The lambda passed to bench_gpu_time captures loop variables
tile_m, tile_n, tile_k by reference which can lead to late-binding bugs; change
the call so the lambda binds current loop values as defaults (e.g., lambda
tile_m=tile_m, tile_n=tile_n, tile_k=tile_k:
flashinfer.gemm.group_gemm_nvfp4_nt_groupwise(...)) when invoking
flashinfer.gemm.group_gemm_nvfp4_nt_groupwise inside bench_gpu_time to ensure
the correct tile parameters are used.
- Around line 26-103: The benchmark function
bench_groupwise_grouped_gemm_nvfp4_blackwell should guard against running on
unsupported GPUs: before seeding/random tensors or calling
flashinfer.gemm.group_gemm_nvfp4_nt_groupwise, add a runtime capability check
(same style as in bench_groupwise_grouped_gemm_mxfp4_blackwell.py) that queries
CUDA device properties (compute capability or a provided flashinfer capability
check) and early-returns or prints a skip message if NVFP4/Blackwell features
are not available; place this check at the top of
bench_groupwise_grouped_gemm_nvfp4_blackwell so the rest of the setup (tensor
allocation, a_scale/b_scale, and the benchmarking loop) is skipped on
incompatible hardware.
In `@csrc/group_gemm_mxfp4_groupwise_sm120_kernel_inst.jinja`:
- Around line 53-54: Remove the unnecessary trailing semicolons after the
namespace closing braces: locate the closing braces for namespace group_gemm and
namespace flashinfer in the template (symbols "namespace group_gemm" and
"namespace flashinfer") and delete the semicolons that follow the closing '}'
characters so the file ends with plain closing braces rather than '};'.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 8b2f1c76-621c-4242-96f4-4c93b3ec55c3
📒 Files selected for processing (17)
- benchmarks/bench_groupwise_grouped_gemm_mxfp4_blackwell.py
- benchmarks/bench_groupwise_grouped_gemm_nvfp4_blackwell_geforce.py
- csrc/group_gemm_mxfp4_groupwise_sm120.cu
- csrc/group_gemm_mxfp4_groupwise_sm120_kernel_inst.jinja
- csrc/group_gemm_nvfp4_groupwise_sm120.cu
- csrc/group_gemm_nvfp4_groupwise_sm120_kernel_inst.jinja
- csrc/group_gemm_sm120_binding.cu
- csrc/tvm_ffi_utils.h
- flashinfer/gemm/__init__.py
- flashinfer/gemm/gemm_base.py
- flashinfer/jit/gemm/core.py
- include/flashinfer/gemm/fp4_gemm_template_sm120.h
- include/flashinfer/gemm/group_gemm_fp8_groupwise_sm120.cuh
- include/flashinfer/gemm/group_gemm_mxfp4_groupwise_sm120.cuh
- include/flashinfer/gemm/group_gemm_nvfp4_groupwise_sm120.cuh
- tests/gemm/test_group_gemm_fp4.py
- tests/gemm/test_groupwise_scaled_gemm_mxfp4.py
🚧 Files skipped from review as they are similar to previous changes (2)
- csrc/group_gemm_nvfp4_groupwise_sm120_kernel_inst.jinja
- csrc/group_gemm_sm120_binding.cu
|
@coderabbitai resume |
|
/bot run |
|
[FAILED] Pipeline #46400558: 0/20 passed |
|
/bot run |
|
[SUCCESS] Pipeline #46568840: 14/20 passed |
|
@aleozlx pinging again: this solves a lot of pain for users related to nvfp4 |
|
@depaulmillz - Looks like it needs some changes to be merged cleanly. EDIT: disregard, tried to apply as a patch. Merges without any issues, now building on DGX Spark... |
|
thanks for the ping. restarting CI |
📌 Description
This MR adds functional support for CUTLASS MXFP4 group GEMMs and NVFP4 group GEMMs on Blackwell GeForce and DGX Spark. It implements MXFP4 group GEMMs to match the existing interface and adds a new interface for NVFP4. The NVFP4 interface aims to match the baseline GEMM interface including support for alpha scaling.
The MR also unguards GDC on CUTLASS kernels for functional correctness.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks before submitting the PR with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed.
- All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit