
feat: add MXFP8 GEMM support for SM120#2902

Open
samuellees wants to merge 1 commit into main from sam-mxfp8-sm120-clean

Conversation

@samuellees
Collaborator

@samuellees samuellees commented Mar 27, 2026

Resolves #2728

Key changes:

  • New CUTLASS kernel template: mxfp8_gemm_cutlass_template_sm120.h using Sm1xxBlkScaledConfig with 3 CTA tile configs (128×128×128, 256×128×128, 128×256×128)
  • New C++ launcher: csrc/mxfp8_gemm_cutlass_sm120.cu with full input validation (K%32==0, N%32==0, M unconstrained, swizzled-only scale format)
  • SM120 kernel hardcodes the hardware-native swizzled scale layout (Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA/SFB); linear (2D) scale is explicitly rejected with a clear error message
  • JIT module generator gen_gemm_sm120_module_cutlass_mxfp8() in flashinfer/jit/gemm/core.py
  • AOT registration under has_sm120 or has_sm121 in flashinfer/aot.py
  • Updated mm_mxfp8 docstring with SM12x swizzled-only note
  • Added mm_mxfp8 to docs/api/gemm.rst
  • Test suite: tests/gemm/test_mm_mxfp8_sm120.py covering arbitrary M (including non-multiples of 128), N/K in {128,256,512,1024}, both bf16/fp16 output, cos_sim > 0.99 accuracy threshold
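The shape constraints listed above can be sketched as a small validation helper. This is a hypothetical Python mirror of the checks (the real validation lives in the C++ launcher); the function name is illustrative:

```python
def check_mxfp8_shapes(m: int, n: int, k: int) -> None:
    """Hypothetical mirror of the SM120 launcher's input validation:
    K and N must be multiples of 32 (the MXFP8 scale block covers 32
    elements along K); M is unconstrained, so any row count passes."""
    if k % 32 != 0:
        raise ValueError(f"K={k} must be a multiple of 32")
    if n % 32 != 0:
        raise ValueError(f"N={n} must be a multiple of 32")

check_mxfp8_shapes(7, 128, 256)  # arbitrary M, aligned N/K: accepted
```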

AI-assisted

📌 Description

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added MXFP8 GEMM support for SM120 (Blackwell) GPUs with optimized kernels for BF16 and FP16 output formats.
  • Documentation

    • Added API documentation for MXFP8 matrix multiplication operations.
  • Tests

    • Added comprehensive tests validating SM120 MXFP8 GEMM across multiple matrix sizes and configurations.


Implement `mm_mxfp8` / `bmm_mxfp8` support for SM12x GPUs (RTX PRO 6000,
RTX 5090, RTX 5080) via CUTLASS 4.x blockscaled GEMM.

Key changes:
- New CUTLASS kernel template: `mxfp8_gemm_cutlass_template_sm120.h` using
  `Sm1xxBlkScaledConfig` with 3 CTA tile configs (128×128×128, 256×128×128,
  128×256×128)
- New C++ launcher: `csrc/mxfp8_gemm_cutlass_sm120.cu` with full input
  validation (K%32==0, N%32==0, M unconstrained, swizzled-only scale format)
- SM120 kernel hardcodes the hardware-native swizzled scale layout
  (Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA/SFB); linear (2D) scale is
  explicitly rejected with a clear error message
- JIT module generator `gen_gemm_sm120_module_cutlass_mxfp8()` in
  `flashinfer/jit/gemm/core.py`
- AOT registration under `has_sm120 or has_sm121` in `flashinfer/aot.py`
- Updated `mm_mxfp8` docstring with SM12x swizzled-only note
- Added `mm_mxfp8` to `docs/api/gemm.rst`
- Test suite: `tests/gemm/test_mm_mxfp8_sm120.py` covering arbitrary M
  (including non-multiples of 128), N/K in {128,256,512,1024}, both bf16/fp16
  output, cos_sim > 0.99 accuracy threshold

AI-assisted

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@coderabbitai
Contributor

coderabbitai bot commented Mar 27, 2026

📝 Walkthrough

Walkthrough

This PR adds SM120 MXFP8 GEMM support by implementing SM120-specific kernel templates, dispatchers, JIT generators, and Python integration. It includes new CUDA kernels, template instantiations, runtime dispatchers with configuration selection, documentation, and comprehensive tests for the SM120 architecture.

Changes

Cohort / File(s) Summary
Core SM120 MXFP8 GEMM Implementation
csrc/mxfp8_gemm_cutlass_sm120.cu, csrc/mxfp8_gemm_cutlass_sm120.jinja
Added SM120 FFI entry points (mxfp8_gemm, mxfp8_gemm_tactic_num), workspace allocation logic, input validation for tensor ranks/dtypes/scale layouts, and Jinja template for kernel instantiations across multiple CTA tile shapes and data types.
Header Template & Dispatch Layer
include/flashinfer/gemm/mxfp8_gemm_template_sm120.h, include/flashinfer/gemm/mxfp8_gemm_cutlass_template_sm120.h
Introduced SM120-specific kernel launcher templates with full CUTLASS integration, dispatcher logic for selecting CTA shapes, workspace memoization via static hash map, and error handling for unsupported configurations.
Configuration & Tile Support
include/flashinfer/gemm/cutlass_gemm_configs.h
Added CtaShape128x256x128B tile configuration to CutlassTileConfigSM120 enumeration.
Python JIT & AOT Integration
flashinfer/jit/gemm/core.py, flashinfer/jit/gemm/__init__.py, flashinfer/aot.py
Created gen_gemm_sm120_module_cutlass_mxfp8() JIT generator with SM120-specific NVCC flags, exported new generator symbol, and wired SM120 MXFP8 GEMM into AOT module generation pipeline.
Module Routing & API
flashinfer/gemm/gemm_base.py
Added cached SM120 MXFP8 CUTLASS module getter, extended get_cutlass_mxfp8_gemm_module() to route SM major 12, expanded supported compute capabilities to include SM120/SM121, and updated API documentation with SM12x-specific swizzled-scale constraints.
Documentation
docs/api/gemm.rst
Added MXFP8 GEMM API documentation section with autosummary entry for mm_mxfp8.
Tests
tests/gemm/test_mm_mxfp8_sm120.py
Added comprehensive test suite validating SM120 MXFP8 GEMM functional correctness across multiple M/N/K dimensions, output dtypes, tactic enumeration, and code generation for swizzled-scale layouts.
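The accuracy criterion the tests use can be sketched with a plain cosine-similarity helper. This is a stdlib-only sketch for illustration; the actual tests presumably compare the MXFP8 output against a reference GEMM computed with torch:

```python
import math

def cos_sim(a, b):
    # Cosine similarity between two flat vectors; the test suite requires
    # cos_sim > 0.99 between the MXFP8 result and the reference output.
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)
```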

Sequence Diagram(s)

sequenceDiagram
    participant PyAPI as Python API<br/>(mm_mxfp8)
    participant PyGemm as gemm_base.py<br/>(module routing)
    participant Dispatch as SM120 Dispatcher<br/>(mxfp8_gemm_cutlass_sm120.cu)
    participant Runner as CutlassMxfp8GemmRunnerSm120<br/>(template_sm120.h)
    participant Kernel as CUTLASS Kernel<br/>(SM120 arch)
    participant CUDA as CUDA Device

    PyAPI->>PyGemm: get_cutlass_mxfp8_gemm_module(sm_major=12)
    PyGemm->>Dispatch: Load SM120 MXFP8 module
    PyAPI->>Dispatch: mxfp8_gemm(mat1, mat2, scales, workspace, tactic)
    Dispatch->>Dispatch: Validate inputs (dtypes, ranks, scale layout)
    Dispatch->>Dispatch: Select CutlassGemmConfig from tactic
    Dispatch->>Runner: gemm(A, B, scales, workspace)
    Runner->>Runner: dispatchToArch (select CTA shape)
    Runner->>Kernel: Launch kernel with CTA dimensions
    Kernel->>CUDA: Execute MXFP8 block-scaled GEMM
    CUDA-->>Kernel: Return results
    Kernel-->>Runner: Workspace size / Status
    Runner-->>Dispatch: Completion
    Dispatch-->>PyAPI: Output tensor
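The "Select CutlassGemmConfig from tactic" step above can be sketched as a tactic-to-tile mapping. The three tile shapes come from the PR description; the list order, the -1 default, and the function name are illustrative assumptions (the real enum lives in cutlass_gemm_configs.h):

```python
# Hypothetical mapping from tactic index to CTA tile shape (M, N, K),
# mirroring the three configs the PR adds for SM120.
CTA_TILES = [
    (128, 128, 128),
    (256, 128, 128),
    (128, 256, 128),
]

def select_tile(tactic: int):
    if tactic == -1:  # assumed default: fall back to the first config
        return CTA_TILES[0]
    if not 0 <= tactic < len(CTA_TILES):
        raise ValueError(f"tactic {tactic} out of range [0, {len(CTA_TILES)})")
    return CTA_TILES[tactic]
```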

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related issues

Possibly related PRs

Suggested labels

run-ci, op: gemm

Suggested reviewers

  • nvmbreughe
  • cyx-6
  • jimmyzho
  • jiahanc
  • bkryu
  • yzh119
  • aleozlx
  • djmmoss

Poem

🐰 A brand new SM120 GEMM kernel is here,
With MXFP8 magic and swizzled scales clear,
Templates dispatch through dimensions so fine,
Workspace and tactics in perfect design,
One-twenty computing with grace and with cheer!

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 28.00%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.

✅ Passed checks (4 passed)

  • Title check — ✅ Passed: the title 'feat: add MXFP8 GEMM support for SM120' clearly and concisely summarizes the main change and is directly related to the primary objective of the changeset.
  • Description check — ✅ Passed: the PR description details the key changes, links resolved issue #2728, documents the new templates/launchers/generators, specifies alignment constraints and test coverage, and includes a completed checklist meeting template requirements.
  • Linked Issues check — ✅ Passed: the changes fulfill issue #2728: SM120 MXFP8 BlockScaled GEMM via CUTLASS with three tile configs (128×128×128, 256×128×128, 128×256×128), swizzled-only scale layout, input validation (K%32==0, N%32==0), and comprehensive test coverage.
  • Out of Scope Changes check — ✅ Passed: all changes are tightly scoped to SM120 MXFP8 GEMM support, with no unrelated modifications.



Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request implements MXFP8 GEMM support for SM120 (Blackwell) GPUs using CUTLASS. The changes include the core CUDA implementation, JIT and AOT build system integration, documentation updates, and new test cases. Feedback identifies an improvement opportunity to refactor duplicated logic for retrieving GEMM configurations into a shared helper function to enhance maintainability.

Comment on lines +198 to +205
int64_t mxfp8_gemm_tactic_num_sm120() {
auto getCutlassConfigs = []() {
CutlassMxfp8GemmRunnerSm120<__nv_bfloat16> gemmRunner;
return gemmRunner.getConfigs();
};
static int64_t totalTactics = getCutlassConfigs().size();
return totalTactics;
}
Contributor


medium

The logic for retrieving CUTLASS GEMM configurations is duplicated from getMxfp8GemmConfigSm120. To improve maintainability and avoid code duplication, you could introduce a shared helper function.

For example, you could create a helper function within a detail namespace inside torch_ext:

namespace torch_ext {
namespace detail {
inline const std::vector<CutlassGemmConfig>& GetMxfp8GemmConfigsSm120() {
  static const std::vector<CutlassGemmConfig> kGlobalConfigs = []() {
    CutlassMxfp8GemmRunnerSm120<__nv_bfloat16> gemmRunner;
    return gemmRunner.getConfigs();
  }();
  return kGlobalConfigs;
}
}  // namespace detail

Then, getMxfp8GemmConfigSm120 and mxfp8_gemm_tactic_num_sm120 can both use this helper:

// In getMxfp8GemmConfigSm120
const auto& globalConfigs = detail::GetMxfp8GemmConfigsSm120();

// In mxfp8_gemm_tactic_num_sm120
int64_t mxfp8_gemm_tactic_num_sm120() {
  return detail::GetMxfp8GemmConfigsSm120().size();
}

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/gemm/gemm_base.py (1)

3409-3420: ⚠️ Potential issue | 🟠 Major

Reject SM12x-incompatible CUTLASS inputs during backend selection.

The docstring now says SM12x CUTLASS only supports 1D layout_128x4 scales, but _cutlass_gemm_mxfp8_requirement() still returns True unconditionally. That means backend="auto" can still pick CUTLASS for 2D scales, use_8x4_sf_layout=True, or N/K values that the SM120 launcher rejects, and the failure happens late instead of at API validation time.

♻️ Suggested guard
 @supported_compute_capability([100, 103, 110, 120, 121])
 def _cutlass_gemm_mxfp8_requirement(
     a: torch.Tensor,
     b: torch.Tensor,
@@
     use_8x4_sf_layout: bool = True,
     backend: Literal["cutlass", "cute-dsl", "trtllm", "auto"] = "auto",
 ):
+    if is_sm12x_supported(a.device):
+        if a.shape[1] % 32 != 0 or b.shape[1] % 32 != 0:
+            raise ValueError(
+                "SM120/SM121 CUTLASS MXFP8 requires K and N to be multiples of 32."
+            )
+        if use_8x4_sf_layout:
+            raise ValueError(
+                "SM120/SM121 CUTLASS MXFP8 only supports SfLayout.layout_128x4."
+            )
+        if a_descale.ndim != 1 or b_descale.ndim != 1:
+            raise ValueError(
+                "SM120/SM121 CUTLASS MXFP8 only supports 1D swizzled scale tensors."
+            )
     return True

Also applies to: 3933-3935

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 3409 - 3420, Update
_cutlass_gemm_mxfp8_requirement to proactively reject inputs that SM12x CUTLASS
cannot support: detect when backend selection might pick CUTLASS (backend in
{"auto","cutlass"}) and return False for SM12x GPUs if any of the SM12x-only
constraints are violated — e.g., a_descale or b_descale are not 1D
layout_128x4-style scales, use_8x4_sf_layout is True when SM12x doesn't support
that layout, or N/K tensor dimensions/sizes fall into ranges the SM120 launcher
rejects; apply the same validation logic to the analogous check around lines
3933-3935 so backend="auto" won't choose CUTLASS for incompatible 2D scales,
layouts, or N/K values. Ensure you reference and use the function name
_cutlass_gemm_mxfp8_requirement (and the similar guard at 3933-3935) to locate
and implement these condition checks and return False when constraints are not
met.
🧹 Nitpick comments (3)
include/flashinfer/gemm/mxfp8_gemm_cutlass_template_sm120.h (1)

107-113: Document why the workspace probe swallows std::runtime_error.

This catch is part of config probing, but without a short rationale it reads like accidental error suppression.

Suggested change
-        } catch (std::runtime_error&) {
+        } catch (std::runtime_error&) {
+          // Swallow errors when a candidate config is ineligible or exceeds SMEM.
           continue;
         }
Based on learnings, swallowed `std::runtime_error` probes in `getWorkspaceSizeImpl` are acceptable here, but they should carry a brief rationale comment.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/gemm/mxfp8_gemm_cutlass_template_sm120.h` around lines 107
- 113, The workspace-size probe in getWorkspaceSizeImpl currently catches and
swallows std::runtime_error thrown by dispatchToArch (used to probe candidate
gemmConfig) without explanation; add a brief comment above the try/catch
explaining that probing may throw for unsupported/config-invalid configs and
that these exceptions are intentionally ignored to allow trying other configs
(i.e., this is a non-fatal probe, not a logic bug), referencing dispatchToArch
and workspace_size so readers know why the catch exists and that it only applies
to probing, not to actual execution.
tests/gemm/test_mm_mxfp8_sm120.py (2)

28-45: The SM120-specific paths still need direct coverage.

This file only exercises 2D swizzled cases with N/K values that are multiples of 128. Please add at least one 3D batched case, one N/K case that's only a multiple of 32, and one pytest.raises case for linear scales so the new bmm_mxfp8, alignment, and rejection paths are actually tested.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gemm/test_mm_mxfp8_sm120.py` around lines 28 - 45, Extend the test
matrix in tests/gemm/test_mm_mxfp8_sm120.py to cover SM120-specific code paths
by (1) adding a 3D batched input case that calls _prepare_mxfp8 and exercises
bmm_mxfp8 (e.g., shapes with batch>1) so the batched branch is executed; (2)
adding at least one case where N and K are multiples of 32 but not 128 (e.g., 96
or 160) to hit the alignment path that handles 32-granularity; and (3) adding a
pytest.raises test that calls mxfp8_quantize/_prepare_mxfp8 (using
SfLayout.layout_linear) with an input shape that should trigger the linear-scale
rejection so the code path that raises for invalid linear scales is covered;
ensure the new tests reference _prepare_mxfp8, mxfp8_quantize, bmm_mxfp8 and use
SfLayout.layout_linear / SfLayout.layout_128x4 to select swizzled vs linear
flows.

81-109: all_tactics only exercises the default path once.

This proves the module reports a tactic count, but it never dispatches tactics 1 and 2. A broken secondary CTA instantiation would still pass here, so either iterate each tactic explicitly or rename the test to match its real coverage.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gemm/test_mm_mxfp8_sm120.py` around lines 81 - 109, The test
test_mm_mxfp8_sm120_all_tactics only exercises the default tactic; query
module.mxfp8_gemm_tactic_num() and loop i from 0 to num_tactics-1, invoking the
same execution path for each tactic (use the module or mm_mxfp8 API that accepts
a tactic/index parameter or call the module’s tactic-dispatch method) and
validate shape, finiteness and cosine similarity per-tactic; alternatively if
you intend to only check tactic count, rename the test to reflect that (e.g.,
test_mm_mxfp8_sm120_reports_tactics) instead of claiming it exercises all
tactics.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: d7b77a63-ee11-4a91-8929-f087c895b0d9

📥 Commits

Reviewing files that changed from the base of the PR and between 31b63bc and 77156f7.

📒 Files selected for processing (11)
  • csrc/mxfp8_gemm_cutlass_sm120.cu
  • csrc/mxfp8_gemm_cutlass_sm120.jinja
  • docs/api/gemm.rst
  • flashinfer/aot.py
  • flashinfer/gemm/gemm_base.py
  • flashinfer/jit/gemm/__init__.py
  • flashinfer/jit/gemm/core.py
  • include/flashinfer/gemm/cutlass_gemm_configs.h
  • include/flashinfer/gemm/mxfp8_gemm_cutlass_template_sm120.h
  • include/flashinfer/gemm/mxfp8_gemm_template_sm120.h
  • tests/gemm/test_mm_mxfp8_sm120.py

Comment on lines +30 to +33
.. autosummary::
:toctree: ../generated

mm_mxfp8
Contributor


⚠️ Potential issue | 🟡 Minor

Add bmm_mxfp8 to the generated GEMM docs.

The new section only exposes mm_mxfp8, so the batched MXFP8 API stays undiscoverable in the generated reference docs.

📝 Suggested update
 .. autosummary::
     :toctree: ../generated
 
     mm_mxfp8
+    bmm_mxfp8
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/api/gemm.rst` around lines 30 - 33, The autosummary list for GEMM
currently only documents mm_mxfp8, omitting the batched API bmm_mxfp8; update
the autosummary block to include bmm_mxfp8 alongside mm_mxfp8 so the batched
MXFP8 function is included in the generated reference docs (i.e., add the symbol
name bmm_mxfp8 into the same autosummary list that contains mm_mxfp8).

Comment on lines +3271 to +3286
@functools.cache
def get_gemm_sm120_module_cutlass_mxfp8():
"""Get the SM120/121 MXFP8 GEMM module."""
module = gen_gemm_sm120_module_cutlass_mxfp8().build_and_load()
return _create_cutlass_mxfp8_gemm_module(
module, "flashinfer::cutlass_mxfp8_gemm", "cutlass_mxfp8_gemm"
)


def get_cutlass_mxfp8_gemm_module(
sm_major: int,
):
if sm_major in [10, 11]:
return get_gemm_sm100_module_cutlass_mxfp8()
elif sm_major in [12]:
return get_gemm_sm120_module_cutlass_mxfp8()
Contributor


⚠️ Potential issue | 🟠 Major

SM120 CUTLASS is wired only into mm_mxfp8().

get_gemm_sm120_module_cutlass_mxfp8() is only consumed by the 2D path. bmm_mxfp8() below still hard-codes the cuDNN runner via mxfp8_gemm_sm100(..., ["cudnn"]), and _cudnn_bmm_mxfp8_requirement still only allows SM100/103, so batched MXFP8 remains unavailable on SM120/121.
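The routing gap can be illustrated with a minimal sketch. The function name and return values here are hypothetical stand-ins; the real fix would thread get_cutlass_mxfp8_gemm_module through bmm_mxfp8 rather than return a backend string:

```python
# Hypothetical backend selection mirroring this comment: the 2D path routes
# sm_major 12 to CUTLASS, while the batched path is currently hard-coded to
# the cuDNN SM100 runner and never reaches SM120/121.
def pick_bmm_backend(sm_major: int) -> str:
    if sm_major in (10, 11):
        return "cudnn"    # existing SM100/103 path
    if sm_major == 12:
        return "cutlass"  # proposed SM120/121 wiring
    raise ValueError(f"bmm_mxfp8 unsupported on sm_major={sm_major}")
```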

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 3271 - 3286, The batched MXFP8
path isn't using the SM120/121 CUTLASS module: update the bmm path to support
SM120/121 by wiring the CUTLASS runner there and relaxing the cuDNN requirement
checks. Specifically, modify bmm_mxfp8 to prefer the CUTLASS module for sm_major
12 (use get_cutlass_mxfp8_gemm_module or call the equivalent cutlass runner
instead of hard-coding mxfp8_gemm_sm100(..., ["cudnn"])), and update
_cudnn_bmm_mxfp8_requirement to allow SM120/121 (or to defer to the CUTLASS
option) so batched MXFP8 can run on SM120/121 with the CUTLASS implementation.

Comment on lines +15 to +20
def _is_sm120_available() -> bool:
try:
cc = get_compute_capability(torch.device("cuda"))
return cc[0] == 12
except Exception:
return False
Contributor


⚠️ Potential issue | 🟡 Minor


Replace except Exception with specific exception handling or add availability check.

The current code catches all exceptions, which can silently mask unexpected failures. Either:

  1. Catch only RuntimeError (which torch.cuda.get_device_capability() raises when CUDA is unavailable)
  2. Or add torch.cuda.is_available() check first, matching the pattern used elsewhere in the codebase (e.g., tests/gdn/test_prefill_delta_rule.py, tests/model_optimizations/test_tinygemm2.py)
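The suggested fix can be sketched with the two dependencies injected so it runs without a GPU. `cuda_available` and `get_cc` are stand-ins for `torch.cuda.is_available` and flashinfer's `get_compute_capability`:

```python
def is_sm120_available(cuda_available, get_cc) -> bool:
    # Check device availability first, then catch only RuntimeError from
    # the capability query; any other exception still surfaces as a bug.
    if not cuda_available():
        return False
    try:
        major, _minor = get_cc()
    except RuntimeError:
        return False
    return major == 12
```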
🧰 Tools
🪛 Ruff (0.15.7)

[warning] 19-19: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gemm/test_mm_mxfp8_sm120.py` around lines 15 - 20, The helper
_is_sm120_available should avoid catching all exceptions; update it to first
check torch.cuda.is_available() and then call
get_compute_capability(torch.device("cuda")), or alternatively catch only
RuntimeError from get_compute_capability; specifically modify the
_is_sm120_available function to return False immediately if
torch.cuda.is_available() is False, and when calling get_compute_capability (the
symbol in question) catch RuntimeError only and use that to return False instead
of a broad except Exception.

@johnnynunez
Contributor

is this compatible with DGX Spark? @samuellees



Development

Successfully merging this pull request may close these issues.

[Feature] migrate SM120 MXFP8 BlockScaled GEMM from TRTLLM
