
feat: Add support for bmm mxfp8#2256

Merged
yzh119 merged 4 commits into flashinfer-ai:main from danisereb:support_mxfp8
Dec 25, 2025

Conversation

@danisereb
Contributor

@danisereb danisereb commented Dec 22, 2025

📌 Description

Adds support for batched GEMM with MXFP8 quantization (bmm_mxfp8).

At this time only the cuDNN backend is supported.

Added the test tests/gemm/test_bmm_mxfp8.py.

Added the bmm_mxfp8 routine to flashinfer_benchmark.

Benchmark results for bmm_mxfp8 (on B200 GPU):

python benchmarks/flashinfer_benchmark.py \
--routine bmm_mxfp8 -vv \
--num_iters 30 \
--batch_size 128 \
--m 512 --n 512 --k 4096 \
--out_dtype bfloat16 \
--backends cudnn \
--refcheck

[PERF] cudnn          :: median time 0.117 ms; std 0.001 ms; achieved tflops 2347.650 TFLOPs/sec; achieved tb_per_sec 0.040 TB/sec

And bmm_fp8 for comparison:

python benchmarks/flashinfer_benchmark.py \
--routine bmm_fp8 -vv \
--num_iters 30 \
--batch_size 128 \
--m 512 --n 512 --k 4096 \
--input_dtype fp8_e4m3 \
--mat2_dtype fp8_e4m3 \
--out_dtype bfloat16 \
--backends cudnn \
--refcheck

[PERF] cudnn          :: median time 0.116 ms; std 0.001 ms; achieved tflops 2369.049 TFLOPs/sec; achieved tb_per_sec 0.041 TB/sec

When running ncu, the kernel nvjet_sm100_qqtst_128x256_128x6_2x1_2cta_v_bz_Avec32UE8M0_Bvec32UE8M0_NNT appears to be the one dispatched.

🔍 Related Issues

#2209

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or my preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added MXFP8 (mixed 8-bit float) batched matrix multiplication with cuDNN acceleration and package-level export.
  • Tests

    • Added parameterized tests validating MXFP8 BMM against reference results across shapes, dtypes, layouts, backends, and autotune modes.
  • Chores

    • Updated benchmark catalog and backend-support mappings to include MXFP8 BMM.


@coderabbitai
Contributor

coderabbitai bot commented Dec 22, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough


Adds MXFP8 batched matrix-multiplication (bmm_mxfp8): new public API, cuDNN MXFP8 graph creation/execution and autotune runner, benchmark/test integration, and dtype/backend support mappings.

Changes

  • Benchmarks (benchmarks/routines/flashinfer_benchmark_utils.py, benchmarks/routines/gemm.py): Added bmm_mxfp8 to the gemm routine list and dtype→backend mappings (cudnn listed for CC 10.0/10.3); added testBmmMxfp8, imported mxfp8_quantize, and wired dispatch in run_gemm_test.
  • Top-level exports (flashinfer/__init__.py): Exported bmm_mxfp8 at package level.
  • GEMM package exports (flashinfer/gemm/__init__.py): Imported and exported bmm_mxfp8; updated __all__.
  • Core MXFP8 GEMM implementation (flashinfer/gemm/gemm_base.py): New MXFP8 cuDNN graph and execution: graph creation, execution-plan helpers, block-scale dimension calculation, runners (_cudnn_gemm_mxfp8_runner, mxfp8_gemm_sm100), validation/requirement checks, heuristic, and the public bmm_mxfp8 API.
  • Tests (tests/gemm/test_bmm_mxfp8.py): New parameterized pytest validating MXFP8 BMM (quantization via mxfp8_quantize, autotune path, CC gating, cosine similarity vs. float reference).
  • Manifests / packaging (requirements.txt, pyproject.toml, setup.py): Listed in manifest; no content shown in diff but referenced.

Sequence Diagram(s)

sequenceDiagram
    participant Test as Test/Benchmark
    participant Quant as mxfp8_quantize
    participant API as bmm_mxfp8 (Public API)
    participant Planner as cuDNN Graph / Planner
    participant Runner as Tunable Runner / Autotune
    participant CuDNN as cuDNN Executor
    participant GPU as GPU Compute

    rect rgb(230,240,255)
    Note over Test,Quant: Input preparation & quantization
    Test->>Quant: raw tensors -> quantized tensors + scales
    Quant-->>Test: quantized A,B and scales
    end

    rect rgb(240,255,230)
    Note over Test,API: Execution request
    Test->>API: call bmm_mxfp8(A_q, B_q, scales, params)
    API->>API: validate dtypes / CC / problem size
    API->>Planner: build cuDNN graph & descriptors (block scales, dequant steps)
    Planner-->>API: execution plan / graph
    API->>Runner: request autotuned runner (tactic selection)
    Runner->>CuDNN: invoke selected tactic on graph
    end

    rect rgb(255,240,230)
    Note over CuDNN,GPU: Compute
    CuDNN->>GPU: execute graph (MXFP8 matmul + dequant)
    GPU-->>CuDNN: result tensor
    CuDNN-->>API: output tensor (cast to requested dtype)
    end

    rect rgb(245,230,255)
    Note over API,Test: Validation
    API-->>Test: result tensor
    Test->>Test: verify shape, dtype, NaNs, cosine-similarity vs reference
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes


Suggested reviewers

  • cyx-6
  • aleozlx
  • bkryu
  • djmmoss
  • yzh119
  • nvmbreughe

Poem

🐰
I hop through tensors, scales in tow,
MXFP8 graphs begin to glow,
Blocks and plans align just right,
Batched mats leap into GPU flight,
A tiny nibble — math goes bright!

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 40.00%, which is below the required 80.00% threshold. Run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Title check ✅ Passed: The title clearly and concisely summarizes the main change: adding support for MXFP8 batch matrix multiplication.
  • Description check ✅ Passed: The description follows the template structure with populated Description and Related Issues sections, but the Tests checklist items are not marked as complete despite tests being added.


@gemini-code-assist
Contributor

Summary of Changes

Hello @danisereb, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates MXFP8 (microscaling FP8) quantization into the FlashInfer library for batched matrix multiplication (BMM) operations. The primary goal is to improve the performance and memory footprint of GEMM computations by using an FP8 format for data and FP8_E8M0 for scales, specifically via the cuDNN backend. This addition expands FlashInfer's capabilities in efficient low-precision arithmetic for deep learning workloads.

Highlights

  • New Feature: bmm_mxfp8 Support: This pull request introduces support for batched matrix multiplication with MXFP8 (microscaling FP8) quantization, enabling more efficient GEMM operations.
  • cuDNN Backend Only: Initially, the bmm_mxfp8 operation is exclusively supported through the cuDNN backend, targeting specific NVIDIA GPU architectures (SM100/SM103).
  • Benchmarking and Testing: A new test file tests/gemm/test_bmm_mxfp8.py has been added, along with integration into the flashinfer_benchmark utility, including performance benchmarks that show comparable TFLOPs/sec to bmm_fp8.
  • Core Implementation: The core implementation in flashinfer/gemm/gemm_base.py includes functions for building and executing cuDNN graphs for MXFP8 GEMM, handling block scale dimensions, and integrating with the autotuner.
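As background on the scale format mentioned above: MXFP8 pairs FP8 element data with one shared E8M0 scale (an 8-bit, exponent-only power of two) per 32-element block. A minimal encode/decode sketch, written from the format's definition rather than from this PR's code:

```python
import math

def e8m0_encode(scale: float) -> int:
    # E8M0 stores only a biased exponent: value = 2 ** (code - 127).
    # Round the requested scale up to the nearest power of two.
    exponent = math.ceil(math.log2(scale))
    return max(0, min(255, exponent + 127))

def e8m0_decode(code: int) -> float:
    return 2.0 ** (code - 127)
```

Because there is no mantissa, every representable scale is an exact power of two; arbitrary scales round up (e.g. 0.75 encodes to the code for 1.0).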


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for batched matrix multiplication with MXFP8 data types (bmm_mxfp8), currently leveraging the cuDNN backend. The changes are well-structured, adding the new routine to the core library, including it in the benchmark suite, and providing a dedicated test file.

My review focuses on correctness and potential improvements. I've identified a few critical issues in the benchmark implementation and the core cuDNN graph execution logic that could lead to incorrect results or metrics. I've also noted opportunities to improve test coverage and address some TODO items related to autotuning configuration. Overall, this is a solid addition, and addressing these points will enhance its robustness.

Comment on lines +942 to +946
problem_bytes = (
    m * k * torch.float8_e4m3fn.itemsize
    + n * k * torch.float8_e4m3fn.itemsize
    + m * n * res_dtype.itemsize
)
Contributor


high

The calculation for problem_bytes appears to be missing the batch_size. The problem_flops calculation on line 939 correctly includes it. This will lead to incorrect bandwidth reporting in the benchmark.

Suggested change
problem_bytes = (
    m * k * torch.float8_e4m3fn.itemsize
    + n * k * torch.float8_e4m3fn.itemsize
    + m * n * res_dtype.itemsize
)
problem_bytes = (
    m * k * torch.float8_e4m3fn.itemsize
    + n * k * torch.float8_e4m3fn.itemsize
    + m * n * res_dtype.itemsize
) * batch_size

Contributor Author

@danisereb danisereb Dec 22, 2025


Fixed, benchmark results after fix:

python benchmarks/flashinfer_benchmark.py \
--routine bmm_mxfp8 -vv \
--num_iters 30 \
--batch_size 128 \
--m 512 --n 512 --k 4096 \
--out_dtype bfloat16 \
--backends cudnn \
--refcheck

[PERF] cudnn          :: median time 0.117 ms; std 0.001 ms; achieved tflops 2344.958 TFLOPs/sec; achieved tb_per_sec 5.152 TB/sec

No major change in TFLOPs, but tb_per_sec increased to ~5 TB/s.
The achieved HBM bandwidth is still below the ~8 TB/s peak memory bandwidth of a single B200 (based on this spec: https://www.nvidia.com/en-eu/data-center/dgx-b200/).

This fix is possibly also required in testBmmFp8.
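The corrected numbers can be sanity-checked by hand. A small helper (illustrative only, not the benchmark's actual code) reproducing the TFLOPs and TB/s arithmetic for batch_size=128, m=n=512, k=4096 at the reported 0.117 ms median:

```python
def bmm_throughput(batch, m, n, k, median_ms, in_itemsize=1, out_itemsize=2):
    # FLOPs: 2*m*n*k multiply-adds per batch element.
    flops = 2 * batch * m * n * k
    # Bytes: two FP8 inputs (1 byte/elem) plus one BF16 output (2 bytes/elem).
    nbytes = batch * (m * k * in_itemsize + n * k * in_itemsize + m * n * out_itemsize)
    seconds = median_ms * 1e-3
    return flops / seconds / 1e12, nbytes / seconds / 1e12

tflops, tb_per_sec = bmm_throughput(128, 512, 512, 4096, 0.117)
# tflops ≈ 2349, tb_per_sec ≈ 5.16, consistent with the reported figures
```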

Comment on lines +4066 to +4068
    "mxfp8_gemm",  # TODO: check if this is correct
    runners,
    _FP8_GEMM_SM100_TUNING_CONFIG,  # TODO: check if this is correct
Contributor


medium

The autotuner is configured with mxfp8_gemm as the key and reuses _FP8_GEMM_SM100_TUNING_CONFIG. While this might work, it's worth considering if a dedicated tuning configuration for mxfp8 would be more optimal, as the performance characteristics might differ from standard FP8 GEMM. The TODO comments also suggest this might be a temporary solution.

Comment on lines +14 to +16
@pytest.mark.parametrize("input_dtype", [torch.bfloat16])
@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
@pytest.mark.parametrize("res_dtype", [torch.bfloat16])
Contributor


medium

The test parameterization for input_dtype and res_dtype is limited to torch.bfloat16.

  1. The bmm_mxfp8 function also supports torch.float16 as an output dtype, which should be added to the tests for better coverage.
  2. The docstring for bmm_mxfp8 mentions support for fp8_e5m2 input, but the current mxfp8_quantize function only produces fp8_e4m3fn. This creates a discrepancy between the documented API and the testable functionality. It would be beneficial to either update the quantization function to support e5m2 and add it to the tests, or update the docstring to reflect the current limitation.
Suggested change
@pytest.mark.parametrize("input_dtype", [torch.bfloat16])
@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
@pytest.mark.parametrize("res_dtype", [torch.bfloat16])
@pytest.mark.parametrize("input_dtype", [torch.bfloat16])
@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16])

Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
return res


def testBmmMxfp8(args):
Contributor Author


If you think it's better, I can merge this with the existing testBmmFp8.

Collaborator


cc @bkryu for opinion

Collaborator

@aleozlx aleozlx left a comment


left minor comments. looks good so far

@danisereb danisereb marked this pull request as ready for review December 24, 2025 09:55
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (1)
tests/gemm/test_bmm_mxfp8.py (1)

10-76: Solid end-to-end MXFP8 BMM test; consider a couple of small adjustments

Strengths:

  • Capability gating: Skips SM < 100 and SM 11/12 explicitly, which matches current support for bmm_mxfp8 (SM100/103 only).
  • Quantization/use of mxfp8_quantize: Both A and B are quantized via mxfp8_quantize, and you assert numel(input_mxfp8) == numel(input_scale) * 32, which encodes the 32-element block-size invariant.
  • Autotune path: The test exercises both autotuned and non-autotuned paths via with autotune(auto_tuning): ... bmm_mxfp8(...).
  • Validation: Checks shape, dtype, NaNs, and cosine similarity vs torch.bmm with a conservative min_cos_sim = 0.9.

Two minor suggestions:

  1. Output dtype coverage: bmm_mxfp8 also supports torch.float16 outputs. Adding torch.float16 to the res_dtype parametrization would improve coverage and catch any FP16-specific issues.

  2. Scale-layout parameter: You parametrize is_sf_swizzled_layout but the GEMM path always treats scales with reordering_type=F8_128x4. Please double-check that the MXFP8 quantizer and cuDNN graph agree on semantics for the non-swizzled case; if only swizzled scales are supported today, it might be better to restrict this test to True (and document that limitation) until the other mode is fully wired.

Overall, the test is well-structured and provides good functional coverage for the new path.
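The cosine-similarity validation described above can be illustrated with a plain-Python version of the metric (the test itself presumably operates on torch tensors; this sketch only shows the check being applied):

```python
import math

def cosine_similarity(a, b):
    # Cosine similarity between two flat vectors: dot(a, b) / (|a| * |b|).
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def check_output(result, reference, min_cos_sim=0.9):
    # Mirrors the test's validation: no NaNs, and direction close to reference.
    assert not any(math.isnan(x) for x in result), "NaNs in output"
    assert cosine_similarity(result, reference) >= min_cos_sim
```

A threshold of 0.9 is deliberately loose; quantization to 8-bit blocks perturbs magnitudes, but the output should still point in nearly the same direction as the float reference.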

🧹 Nitpick comments (7)
flashinfer/gemm/gemm_base.py (4)

1339-1375: Workspace handling and tactic-specific sizes look correct, but consider cache integration

The MXFP8 execute_cudnn_gemm_mxfp8_graph mirrors the FP4 path and correctly uses graph.get_workspace_size(tactic) when a specific plan index is provided, falling back to graph.get_workspace_size() for the heuristic case. The local reallocation of workspace_buffer when too small is fine functionally, but it bypasses _get_cache_buf’s caching, so large workspaces will be reallocated on subsequent calls instead of being reused.

If workspace sizes for MXFP8 end up consistently larger than DEFAULT_WORKSPACE_SIZE, consider either:

  • increasing the default for the MXFP8 cache key, or
  • pushing the resized buffer back into the cache when you grow it.

This is non-blocking and mainly a perf/fragmentation tweak.
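The "push the resized buffer back into the cache" option could look roughly like the following. This is a hypothetical sketch: flashinfer's `_get_cache_buf` has its own signature, and a real implementation would allocate a torch uint8 tensor on the device rather than a bytearray.

```python
_workspace_cache: dict = {}

def get_workspace(key: str, required_size: int):
    # Grow-and-store: if the cached buffer is missing or too small, allocate
    # a larger one and store it back so later calls reuse it instead of
    # reallocating on every oversized request.
    buf = _workspace_cache.get(key)
    if buf is None or len(buf) < required_size:
        buf = bytearray(required_size)  # stand-in for torch.empty(size, dtype=torch.uint8)
        _workspace_cache[key] = buf
    return buf
```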


3799-3821: Block-scale dimension helper is clear; maybe document assumptions

_calculate_block_scale_dims embeds the “indestructible 128x4 block” logic into a single helper and uses div_up twice on K to align to (block_size, 4) groups, which matches the intended 128×4 granularity.

If this formula is tied directly to the cuDNN MXFP8 blockscale layout (e.g., FP8_128x4 requirements), a short docstring note about the relationship between (m, n, k, block_size) and the expected scale tensor shapes would make future maintenance safer, especially if cuDNN layouts evolve.
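A plausible reconstruction of that formula is sketched below. The double div_up on K is stated above; the 128-row padding for M/N is an assumption inferred from the 128x4 block description, so the exact constants may differ from the real helper.

```python
def div_up(a: int, b: int) -> int:
    return (a + b - 1) // b

def calculate_block_scale_dims(m: int, n: int, k: int, block_size: int = 32):
    # M and N are padded to whole 128-row blocks; the K scale count
    # (one scale per block_size elements) is padded to groups of 4,
    # matching the indestructible 128x4 scale-block granularity.
    dim_m = div_up(m, 128) * 128
    dim_n = div_up(n, 128) * 128
    dim_k = div_up(div_up(k, block_size), 4) * 4
    return dim_m, dim_n, dim_k
```

For the benchmarked shape (m=n=512, k=4096, block_size=32) this yields scale dims (512, 512, 128): K contributes 4096/32 = 128 scales, already a multiple of 4.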


3952-4047: MXFP8 cuDNN runner wiring is consistent with the autotuner, with minor nits

  • _get_cudnn_mxfp8_gemm_graph’s out parameter is unused; it can be removed from the signature and call sites unless you expect to use it for layout decisions later. This would also quiet the static analyzer.
  • _cudnn_gemm_mxfp8_runner.get_valid_tactics currently returns [0] and the forward path passes tactic through to _cudnn_gemm_mxfp8, which maps tactic == -1 to the generic “build all plans and let cuDNN choose” path and tactic >= 0 to a specific plan index. This is compatible with AutoTuner (fallback will still use tactic=-1), but means tuning will only ever consider plan index 0. If cuDNN reports multiple valid plans, you may eventually want to enumerate them via graph.get_execution_plan_count() as in the FP4 runner and profile more than one.

Both points are non-blocking; behavior is correct as-is.
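Enumerating and profiling more than one plan, as suggested, reduces to the usual time-each-candidate pattern. A generic, framework-free sketch (the real AutoTuner uses CUDA events and warmup iterations; tactic -1 denotes the heuristic fallback, matching the convention above):

```python
import time

def choose_best_tactic(run_fn, tactics):
    # Profile each candidate tactic and keep the fastest; fall back to the
    # heuristic tactic (-1) when no candidates are available.
    if not tactics:
        return -1
    best_tactic, best_time = None, float("inf")
    for tactic in tactics:
        start = time.perf_counter()
        run_fn(tactic)
        elapsed = time.perf_counter() - start
        if elapsed < best_time:
            best_tactic, best_time = tactic, elapsed
    return best_tactic
```

With get_valid_tactics returning [0], this loop degenerates to always picking plan 0; enumerating graph.get_execution_plan_count() plans would make the comparison meaningful.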


4049-4072: MXFP8 GEMM autotuning reuses FP8 tuning config; acceptable but potentially suboptimal

mxfp8_gemm_sm100 wires MXFP8 GEMM into the autotuner under the "mxfp8_gemm" key and reuses _FP8_GEMM_SM100_TUNING_CONFIG. This matches the fp8_gemm_sm100 pattern and should work functionally, since the inputs list layout is identical ([a, b, scale_a, scale_b, out, workspace_buffer]) and the constraints only depend on a and out shapes.

If MXFP8 kernels end up with different sweet spots (e.g., K or batch-size sensitivities) from standard FP8 kernels, consider introducing a dedicated tuning config later; no change is required now.

benchmarks/routines/gemm.py (3)

26-47: run_gemm_test wiring for bmm_mxfp8 is straightforward

Adding the elif args.routine == "bmm_mxfp8": return testBmmMxfp8(args) branch integrates MXFP8 cleanly into the existing dispatch function without affecting other routines.

One minor nit: the --autotune help text in parse_gemm_args still only mentions mm_fp4 and bmm_fp8, but you now support autotuning for bmm_mxfp8 as well. Consider updating the help string to avoid confusion.


150-156: Autotune help text is slightly stale wrt bmm_mxfp8

The --autotune argument’s help string currently says:

Enable autotuner warmup for supported routines (mm_fp4 and bmm_fp8).

Since testBmmMxfp8 also honors --autotune via autotune_supported_backends = ["cudnn"] and a warmup loop, it would be more accurate to mention bmm_mxfp8 here as well.


764-966: bmm_mxfp8 benchmark implementation looks correct; a few small suggestions

Positives:

  • Argument parsing & backend filtering: You reuse dtype_str_to_torch_dtype and filter_backends_by_compute_capability, with routine_cc_to_supported_backends["bmm_mxfp8"] limiting to "cudnn" on SM100/103. Autotune backend filtering is consistent with bmm_fp8/mm_fp4.
  • Input and quantization: Inputs are [batch_size, m, k] and [batch_size, n, k]ᵀ (so [b, k, n]) in BF16, then quantized via mxfp8_quantize. Shapes and layouts match the bmm_mxfp8 docstring expectations (A: [b, m, k], B: [b, k, n]).
  • Reference & validation: Reference uses torch.bmm(input, mat2). Validation uses cosine similarity with min_cos_sim = 0.9, same metric style as testBmmFp8 but with a looser threshold; that’s reasonable for MXFP8.
  • Autotune integration: Warmup under with autotune(True) is correctly wrapped around bmm_mxfp8 calls for backends that support autotuning ("cudnn").

Minor nits:

  1. Bandwidth accounting comment: The problem_bytes formula intentionally ignores the scale tensors and only accounts for FP8 inputs and BF16/FP16 outputs, but the comment suggests “approximate as 1 byte per element for simplicity.” Either include approximate scale traffic or rephrase the comment to “ignore scale tensors as their traffic is comparatively small” to avoid confusion.

  2. Autotune help alignment: As mentioned earlier, you may want to update the global --autotune help text to mention bmm_mxfp8.

Functionally this benchmark is sound.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between df82616 and a7c6acd.

📒 Files selected for processing (6)
  • benchmarks/routines/flashinfer_benchmark_utils.py
  • benchmarks/routines/gemm.py
  • flashinfer/__init__.py
  • flashinfer/gemm/__init__.py
  • flashinfer/gemm/gemm_base.py
  • tests/gemm/test_bmm_mxfp8.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • flashinfer/gemm/gemm_base.py
🧬 Code graph analysis (4)
flashinfer/__init__.py (1)
flashinfer/gemm/gemm_base.py (1)
  • bmm_mxfp8 (4144-4202)
benchmarks/routines/gemm.py (3)
flashinfer/fp8_quantization.py (1)
  • mxfp8_quantize (147-180)
flashinfer/gemm/gemm_base.py (1)
  • bmm_mxfp8 (4144-4202)
flashinfer/testing/utils.py (1)
  • bench_gpu_time (1484-1631)
tests/gemm/test_bmm_mxfp8.py (3)
flashinfer/autotuner.py (1)
  • autotune (251-262)
flashinfer/gemm/gemm_base.py (1)
  • bmm_mxfp8 (4144-4202)
flashinfer/utils.py (1)
  • get_compute_capability (258-261)
flashinfer/gemm/gemm_base.py (2)
flashinfer/autotuner.py (7)
  • TunableRunner (194-247)
  • get_valid_tactics (196-214)
  • OptimizationProfile (168-183)
  • forward (220-244)
  • AutoTuner (335-791)
  • get (362-365)
  • choose_one (400-534)
flashinfer/utils.py (2)
  • supported_compute_capability (819-899)
  • backend_requirement (902-1184)
🪛 Ruff (0.14.10)

benchmarks/routines/gemm.py
  • 812-814, 862, 923-925: Avoid specifying long messages outside the exception class (TRY003)

flashinfer/gemm/gemm_base.py
  • 3836, 3838, 3841, 3843, 3845, 4092-4095, 4110, 4112-4114, 4185, 4188: Avoid specifying long messages outside the exception class (TRY003)
  • 3956: Unused function argument: out (ARG001)
  • 4019: Unused method argument: inputs (ARG002)
  • 4020: Unused method argument: profile (ARG002)
  • 4030: Unused method argument: do_preparation (ARG002)
  • 4031: Unused method argument: kwargs (ARG002)
  • 4077-4083: Unused function arguments: A, B, A_scale, B_scale, dtype, out, backend (ARG001)
  • 4101, 4102, 4104, 4105: Unused function arguments: A_scale, B_scale, out, backend (ARG001)
  • 4122-4128: Unused function arguments: A, B, A_scale, B_scale, dtype, out, backend (ARG001)

🔇 Additional comments (5)
flashinfer/gemm/gemm_base.py (1)

4136-4202: bmm_mxfp8 public API and backend_requirement integration look consistent

The bmm_mxfp8 entry point:

  • Is properly guarded by @backend_requirement using _cudnn_bmm_mxfp8_requirement and _check_bmm_mxfp8_problem_size, with compute capability restricted to SM100/103 via @supported_compute_capability.
  • Validates dtype via _validate_mxfp8_output_dtype, ensuring BF16/FP16 only.
  • Enforces backend == "cudnn" at the Python API level and errors out clearly if cuDNN is unavailable.
  • Allocates out lazily with shape (b, m, n) and uses _get_cache_buf for a workspace buffer, then dispatches into mxfp8_gemm_sm100 with ["cudnn"] runner selection.

Given the current implementation only supports the cuDNN backend, having backend: Literal["cudnn"] is appropriate; the heuristic_func is effectively dormant but harmless. Overall this wiring matches the existing FP8/FP4 patterns.

flashinfer/gemm/__init__.py (1)

1-27: Public export of bmm_mxfp8 is correct and consistent

The new import/export of bmm_mxfp8 from .gemm_base and its inclusion in __all__ align with other GEMM APIs (e.g., bmm_fp8, mm_fp4). This makes the new MXFP8 BMM routine available as flashinfer.gemm.bmm_mxfp8 in a consistent way.

flashinfer/__init__.py (1)

87-91: Top-level bmm_mxfp8 export matches existing GEMM API surface

Importing bmm_mxfp8 from .gemm and re-exporting it at the package root is consistent with how bmm_fp8, mm_fp4, and mm_fp8 are exposed. This is sufficient to surface MXFP8 BMM as flashinfer.bmm_mxfp8.

benchmarks/routines/flashinfer_benchmark_utils.py (2)

92-112: Including bmm_mxfp8 in gemm benchmark_apis is correct

Adding "bmm_mxfp8" to the "gemm" list ensures the new routine is discoverable by the benchmark harness and aligns with how bmm_fp8 and other GEMM routines are registered.


240-249: Backend mapping for bmm_mxfp8 matches capability checks

The routine_cc_to_supported_backends["bmm_mxfp8"] entry only enables "cudnn" on compute capabilities "10.0" and "10.3", matching the @supported_compute_capability([100, 103]) on _cudnn_bmm_mxfp8_requirement. This keeps the benchmark frontend in sync with the backend_requirement logic and avoids running MXFP8 BMM on unsupported architectures.

No changes needed here.
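The gating described here reduces to a routine → compute-capability → backend lookup followed by a filter. An illustrative reconstruction (the function and mapping names follow the benchmark utils discussed above, but the exact structure is an assumption):

```python
routine_cc_to_supported_backends = {
    # bmm_mxfp8 is only enabled for cuDNN on SM100/SM103, per the PR.
    "bmm_mxfp8": {"10.0": ["cudnn"], "10.3": ["cudnn"]},
}

def filter_backends_by_compute_capability(requested, routine, cc):
    # Keep only the requested backends that the mapping lists for this
    # routine on this compute capability; unknown routines/CCs yield [].
    supported = routine_cc_to_supported_backends.get(routine, {}).get(cc, [])
    return [b for b in requested if b in supported]
```

For example, requesting ["cudnn", "cutlass"] for bmm_mxfp8 on CC 10.0 keeps only "cudnn", and any request on CC 9.0 yields an empty list, so the routine is skipped there.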

Comment on lines +3824 to +3950
def create_cudnn_execution_plans_mxfp8_gemm(
a_shape,
a_stride,
a_type, # cudnn.data_type, FP8_E4M3 or FP8_E5M2
b_shape,
b_stride,
b_type, # cudnn.data_type, FP8_E4M3 or FP8_E5M2
block_size,
o_type, # cudnn.data_type, BF16 or FP16
device,
):
if len(a_shape) != 3:
raise ValueError(f"A shape must be 3D, got {a_shape}")
if len(b_shape) != 3:
raise ValueError(f"B shape must be 3D, got {b_shape}")

if a_type not in [cudnn.data_type.FP8_E4M3, cudnn.data_type.FP8_E5M2]:
raise ValueError(f"A type must be FP8_E4M3 or FP8_E5M2, got {a_type}")
if b_type not in [cudnn.data_type.FP8_E4M3, cudnn.data_type.FP8_E5M2]:
raise ValueError(f"B type must be FP8_E4M3 or FP8_E5M2, got {b_type}")
if o_type not in [cudnn.data_type.BFLOAT16, cudnn.data_type.HALF]:
raise ValueError(f"Output type must be BF16 or FP16, got {o_type}")

# Extract batch, m, n, k dimensions
b_dim = a_shape[0]
m = a_shape[1]
k = a_shape[2]
n = b_shape[2]

# Calculate block scale dimensions using indestructible block formula
block_scale_dim_m, block_scale_dim_n, block_scale_dim_k = (
_calculate_block_scale_dims(m, n, k, block_size)
)

# For mxfp8, scale tensors need to be reshaped to 3D with correct strides
# cuDNN expects K-major layout: stride for K dimension should be 1
# For block_descale_a: shape [b, block_scale_dim_m, block_scale_dim_k], stride [block_scale_dim_m * block_scale_dim_k, block_scale_dim_k, 1]
# For block_descale_b: shape [b, block_scale_dim_k, block_scale_dim_n], stride [block_scale_dim_n * block_scale_dim_k, 1, block_scale_dim_k]

a_descale_shape = (b_dim, block_scale_dim_m, block_scale_dim_k)
a_descale_stride = (
block_scale_dim_m * block_scale_dim_k,
block_scale_dim_k,
1,
)

b_descale_shape = (b_dim, block_scale_dim_k, block_scale_dim_n)
b_descale_stride = (
block_scale_dim_n * block_scale_dim_k,
1,
block_scale_dim_k,
)

# MXFP8 uses FP8_E4M3/FP8_E5M2 for quantized data
# MXFP8 uses FP8_E8M0 for scale data
scale_type = cudnn.data_type.FP8_E8M0

stream = torch.cuda.current_stream(device)
with cudnn.graph(_get_cudnn_handle(stream)) as (graph, _):
a_cudnn_tensor = graph.tensor(
name="a",
dim=tuple(a_shape), # [b, m, k]
stride=tuple(a_stride), # [m * k, k, 1]
data_type=a_type,
)
b_cudnn_tensor = graph.tensor(
name="b",
dim=tuple(b_shape), # [b, k, n]
stride=tuple(b_stride), # [k * n, 1, k]
data_type=b_type,
)
block_descale_a_cudnn_tensor = graph.tensor(
name="block_descale_a",
dim=a_descale_shape,
stride=a_descale_stride,
data_type=scale_type,
reordering_type=cudnn.tensor_reordering.F8_128x4,
)
block_descale_b_cudnn_tensor = graph.tensor(
name="block_descale_b",
dim=b_descale_shape,
stride=b_descale_stride,
data_type=scale_type,
reordering_type=cudnn.tensor_reordering.F8_128x4,
)

# Dequantize the input tensors
dequant_a_tensor = graph.block_scale_dequantize(
a_cudnn_tensor,
block_descale_a_cudnn_tensor,
block_size=[1, block_size],
name="dequant_a",
)
dequant_a_tensor.set_data_type(cudnn.data_type.FLOAT)
dequant_b_tensor = graph.block_scale_dequantize(
b_cudnn_tensor,
block_descale_b_cudnn_tensor,
block_size=[block_size, 1],
name="dequant_b",
)
dequant_b_tensor.set_data_type(cudnn.data_type.FLOAT)

# The actual matmul operation
c_tensor = graph.matmul(
dequant_a_tensor,
dequant_b_tensor,
compute_data_type=cudnn.data_type.FLOAT,
name="gemm",
)
c_tensor.set_data_type(cudnn.data_type.FLOAT)

# Output the dequantized result with the specified output dtype
c_tensor.set_output(True).set_data_type(o_type)
c_final_cudnn_tensor = c_tensor

a_cudnn_tensor.set_uid(UIDs.A_UID.value)
b_cudnn_tensor.set_uid(UIDs.B_UID.value)
block_descale_a_cudnn_tensor.set_uid(UIDs.BLOCK_DESCALE_A_UID.value)
block_descale_b_cudnn_tensor.set_uid(UIDs.BLOCK_DESCALE_B_UID.value)
c_final_cudnn_tensor.set_uid(UIDs.O_UID.value)

graph.validate()
graph.build_operation_graph()
graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.B])

return graph
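For reference, the FP8_E8M0 scale type selected above is an exponent-only encoding: one biased-exponent byte representing a power-of-two scale (bias 127, with 0xFF reserved for NaN per the OCP MX spec). A tiny decoder sketch, illustrative only and not part of this PR:

```python
def e8m0_to_float(byte: int) -> float:
    """Decode an FP8 E8M0 scale byte: value = 2**(byte - 127); 0xFF encodes NaN."""
    if not 0 <= byte <= 0xFF:
        raise ValueError(f"E8M0 byte out of range: {byte}")
    if byte == 0xFF:
        return float("nan")
    return 2.0 ** (byte - 127)

one = e8m0_to_float(127)  # the bias value 127 encodes scale 1.0
half = e8m0_to_float(126)  # neighbors step by powers of two
```

Because every scale is an exact power of two, dequantization is a pure exponent shift with no rounding error in the scale itself.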

Contributor

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Scripts executed against flashinfer-ai/flashinfer to verify the review (outputs elided):

# Verify the file and function exist, and examine the exact code
cat -n flashinfer/gemm/gemm_base.py | sed -n '3824,3950p'
cat -n flashinfer/gemm/gemm_base.py | sed -n '3950,3990p'
cat -n flashinfer/gemm/gemm_base.py | sed -n '3952,4000p'
cat -n flashinfer/gemm/gemm_base.py | sed -n '4098,4130p'
cat -n flashinfer/gemm/gemm_base.py | sed -n '4120,4180p'
cat -n flashinfer/gemm/gemm_base.py | sed -n '4180,4220p'

# Check the validation helpers and the quantization path
rg "_check_bmm_mxfp8_problem_size" -n
rg "def _calculate_block_scale_dims" -A 20
rg "def mxfp8_quantize" -n -A 20
cat -n flashinfer/fp8_quantization.py | sed -n '147,250p'

# Look for existing stride/contiguity/scale-size validation
rg "stride\(-1\)" -n flashinfer/gemm/gemm_base.py
rg "contiguous|stride" flashinfer/fp8_quantization.py -n -B 2 -A 2
rg "_scale.*numel\|assert.*stride" flashinfer/gemm/gemm_base.py -n -B 2 -A 2
rg "contiguous\(\)" flashinfer/gemm/gemm_base.py -n
rg "is_contiguous\|stride\(-1\)" flashinfer/gemm/gemm_base.py -n

# Trace callers and the public API
rg "create_cudnn_execution_plans_mxfp8_gemm" -n -B 3 -A 3
rg "def bmm_mxfp8" flashinfer/gemm/gemm_base.py -A 50 | head -80
rg "A_scale|B_scale" flashinfer/gemm/gemm_base.py -n | head -30
rg "cudnn_bmm_mxfp8_requirement" flashinfer/gemm/gemm_base.py -A 30

# Inspect the tests
find . -path "*/test*" -name "*.py" -type f | grep -i mxfp8 | head -10
rg "bmm_mxfp8" flashinfer/tests -l
cat tests/gemm/test_bmm_mxfp8.py
rg "mxfp8_quantize" tests/gemm/test_bmm_mxfp8.py -B 5 -A 10

Add stride and scale tensor validation to prevent layout mismatches

create_cudnn_execution_plans_mxfp8_gemm hard-codes block-descale shapes/strides assuming K-major (stride[-1] == 1) and contiguous inputs, with no runtime validation. Add two defensive checks:

  1. In _check_bmm_mxfp8_problem_size, validate that A.stride(-1) == 1 and B.stride(1) == 1 (K contiguous in both operands, matching the row-major A / column-major B layouts the graph assumes) to enforce the K-major requirement and fail fast on unexpected tensor layouts.

  2. In the same function, assert that A_scale.numel() == b_dim * block_scale_dim_m * block_scale_dim_k and B_scale.numel() == b_dim * block_scale_dim_k * block_scale_dim_n to catch mismatches between quantization layout and cuDNN's graph expectations.

These checks are defensive and won't affect correct usage with mxfp8_quantize, but will make debugging easier if scale tensors are manually created or inputs have unexpected layouts.
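A minimal, framework-free sketch of both checks (plain shape/stride tuples stand in for torch tensors; block_size=32 and the 128×4 "indestructible block" padding rule are taken from _calculate_block_scale_dims as described in this PR, and since B is documented column-major its contiguous K dimension is dim 1 — treat the helper below as illustrative, not the actual implementation):

```python
def _ceil_div(a: int, b: int) -> int:
    return -(-a // b)

def check_mxfp8_layout(a_shape, a_stride, b_shape, b_stride,
                       a_scale_numel, b_scale_numel, block_size=32):
    """Fail fast on non-K-major operands or mis-sized block-scale tensors."""
    if a_stride[-1] != 1:
        raise ValueError(f"A must be K-major (stride[-1] == 1), got {a_stride}")
    if b_stride[1] != 1:
        raise ValueError(f"B must be column-major (stride[1] == 1), got {b_stride}")
    b_dim, m, k = a_shape
    n = b_shape[2]
    # Indestructible-block padding: M and N rounded up to multiples of 128,
    # K / block_size rounded up to a multiple of 4.
    dim_m = _ceil_div(m, 128) * 128
    dim_n = _ceil_div(n, 128) * 128
    dim_k = _ceil_div(_ceil_div(k, block_size), 4) * 4
    if a_scale_numel != b_dim * dim_m * dim_k:
        raise ValueError(f"A_scale numel {a_scale_numel} != {b_dim * dim_m * dim_k}")
    if b_scale_numel != b_dim * dim_k * dim_n:
        raise ValueError(f"B_scale numel {b_scale_numel} != {b_dim * dim_k * dim_n}")
    return dim_m, dim_n, dim_k

# The benchmark shape from this PR: b=128, m=n=512, k=4096.
dims = check_mxfp8_layout(
    (128, 512, 4096), (512 * 4096, 4096, 1),  # A: row-major [b, m, k]
    (128, 4096, 512), (4096 * 512, 1, 4096),  # B: column-major [b, k, n]
    128 * 512 * 128, 128 * 128 * 512,
)
```

For the benchmark shape this yields block-scale dims (512, 512, 128): M and N are already multiples of 128, and k / block_size = 4096 / 32 = 128 is already a multiple of 4.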

🧰 Tools
🪛 Ruff (0.14.10)

3836-3836: Avoid specifying long messages outside the exception class

(TRY003)


3838-3838: Avoid specifying long messages outside the exception class

(TRY003)


3841-3841: Avoid specifying long messages outside the exception class

(TRY003)


3843-3843: Avoid specifying long messages outside the exception class

(TRY003)


3845-3845: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In flashinfer/gemm/gemm_base.py around lines 3824 to 3950, add defensive
validation in the helper _check_bmm_mxfp8_problem_size to ensure inputs follow
the K-major layout and scale-tensor shapes match the computed block-scale dims:
check that A.stride(-1) == 1 and B.stride(1) == 1 and raise a ValueError with a
clear message if not, and check that A_scale.numel() == b_dim *
block_scale_dim_m * block_scale_dim_k and B_scale.numel() == b_dim *
block_scale_dim_k * block_scale_dim_n and raise ValueError if either count
mismatches, so the function fails fast on unexpected tensor layouts or scale
size mismatches.

Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
@aleozlx
Collaborator

aleozlx commented Dec 24, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !216 has been created, and the CI pipeline #40747548 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🧹 Nitpick comments (1)
flashinfer/gemm/gemm_base.py (1)

4184-4188: Remove redundant backend validation

Since backend is annotated Literal["cudnn"] and the backend_requirement decorator already validates the backend at runtime, the explicit checks at lines 4184-4188 are redundant (the Literal annotation alone is not enforced at runtime).

🔎 Simplification
-    if backend != "cudnn":
-        raise ValueError(f"Invalid backend: {backend}")
-
-    if not CUDNN_AVAILABLE:
-        raise ValueError("cudnn is not available")
-
     if out is None:
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a7c6acd and 8fef4bf.

📒 Files selected for processing (1)
  • flashinfer/gemm/gemm_base.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • flashinfer/gemm/gemm_base.py
🪛 Ruff (0.14.10)
flashinfer/gemm/gemm_base.py

3836-3836: Avoid specifying long messages outside the exception class

(TRY003)


3838-3838: Avoid specifying long messages outside the exception class

(TRY003)


3841-3841: Avoid specifying long messages outside the exception class

(TRY003)


3843-3843: Avoid specifying long messages outside the exception class

(TRY003)


3845-3845: Avoid specifying long messages outside the exception class

(TRY003)


3956-3956: Unused function argument: out

(ARG001)


4019-4019: Unused method argument: inputs

(ARG002)


4020-4020: Unused method argument: profile

(ARG002)


4030-4030: Unused method argument: do_preparation

(ARG002)


4031-4031: Unused method argument: kwargs

(ARG002)


4077-4077: Unused function argument: A

(ARG001)


4078-4078: Unused function argument: B

(ARG001)


4079-4079: Unused function argument: A_scale

(ARG001)


4080-4080: Unused function argument: B_scale

(ARG001)


4081-4081: Unused function argument: dtype

(ARG001)


4082-4082: Unused function argument: out

(ARG001)


4083-4083: Unused function argument: backend

(ARG001)


4092-4095: Avoid specifying long messages outside the exception class

(TRY003)


4101-4101: Unused function argument: A_scale

(ARG001)


4102-4102: Unused function argument: B_scale

(ARG001)


4104-4104: Unused function argument: out

(ARG001)


4105-4105: Unused function argument: backend

(ARG001)


4110-4110: Avoid specifying long messages outside the exception class

(TRY003)


4112-4114: Avoid specifying long messages outside the exception class

(TRY003)


4122-4122: Unused function argument: A

(ARG001)


4123-4123: Unused function argument: B

(ARG001)


4124-4124: Unused function argument: A_scale

(ARG001)


4125-4125: Unused function argument: B_scale

(ARG001)


4126-4126: Unused function argument: dtype

(ARG001)


4127-4127: Unused function argument: out

(ARG001)


4128-4128: Unused function argument: backend

(ARG001)


4185-4185: Avoid specifying long messages outside the exception class

(TRY003)


4188-4188: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (10)
flashinfer/gemm/gemm_base.py (10)

1339-1376: LGTM: cuDNN graph execution correctly handles workspace and tactics

The execution logic properly manages workspace allocation, tactic-based plan selection, and variant pack construction for MXFP8 operations.


3799-3821: LGTM: Block scale dimension calculation implements the indestructible 128×4 layout

The formula correctly pads M and N to multiples of 128, and pads K (divided by block_size) to multiples of 4, matching cuDNN's expected layout requirements for MXFP8.


3952-3978: LGTM: Graph construction wrapper correctly delegates to cached builder

The function appropriately extracts tensor metadata and handles tactic-based plan selection. The unused out parameter maintains API consistency with similar functions.


3980-4013: LGTM: MXFP8 execution correctly uses block_size=32 and integrates graph operations

The function properly sets the MXFP8 block size and coordinates graph retrieval with execution.


4015-4047: LGTM: Runner implementation follows established pattern with appropriate tactic handling

The single-tactic approach ([0]) is consistent with other cuDNN runners that rely on cuDNN's internal heuristics. The TODO is noted but not blocking.


4049-4073: Consider dedicated tuning configuration for MXFP8 in future optimization

The function reuses _FP8_GEMM_SM100_TUNING_CONFIG from standard FP8 GEMM. While this works, MXFP8's different block structure (32-element blocks vs standard FP8) may benefit from specialized tuning parameters in future performance work. The TODOs at lines 4066 and 4068 correctly flag this for follow-up.

As noted in past reviews, this is acceptable for initial implementation.


4075-4087: LGTM: Backend availability check correctly validates cuDNN presence

This requirement function appropriately verifies cuDNN availability for SM100/103. The unused parameters are part of the backend_requirement decorator interface.


4089-4096: LGTM: Output dtype validation correctly restricts to bf16/fp16

The validation logic is clear and provides helpful error messages.


4098-4118: LGTM: Problem size validation correctly checks 3D shapes and K-dimension matching

The validation logic properly ensures tensors are 3D and that the K dimensions align (A.shape[2] == B.shape[1]). The error messages provide full shape information for debugging.


4120-4134: LGTM: Heuristic function correctly filters for cuDNN backend

The backend selection logic appropriately checks cuDNN availability. Unused parameters are part of the heuristic function interface.

Comment on lines +3823 to +3950
@functools.cache
def create_cudnn_execution_plans_mxfp8_gemm(
a_shape,
a_stride,
a_type, # cudnn.data_type, FP8_E4M3 or FP8_E5M2
b_shape,
b_stride,
b_type, # cudnn.data_type, FP8_E4M3 or FP8_E5M2
block_size,
o_type, # cudnn.data_type, BF16 or FP16
device,
):
if len(a_shape) != 3:
raise ValueError(f"A shape must be 3D, got {a_shape}")
if len(b_shape) != 3:
raise ValueError(f"B shape must be 3D, got {b_shape}")

if a_type not in [cudnn.data_type.FP8_E4M3, cudnn.data_type.FP8_E5M2]:
raise ValueError(f"A type must be FP8_E4M3 or FP8_E5M2, got {a_type}")
if b_type not in [cudnn.data_type.FP8_E4M3, cudnn.data_type.FP8_E5M2]:
raise ValueError(f"B type must be FP8_E4M3 or FP8_E5M2, got {b_type}")
if o_type not in [cudnn.data_type.BFLOAT16, cudnn.data_type.HALF]:
raise ValueError(f"Output type must be BF16 or FP16, got {o_type}")

# Extract batch, m, n, k dimensions
b_dim = a_shape[0]
m = a_shape[1]
k = a_shape[2]
n = b_shape[2]

# Calculate block scale dimensions using indestructible block formula
block_scale_dim_m, block_scale_dim_n, block_scale_dim_k = (
_calculate_block_scale_dims(m, n, k, block_size)
)

# For mxfp8, scale tensors need to be reshaped to 3D with correct strides
# cuDNN expects K-major layout: stride for K dimension should be 1
# For block_descale_a: shape [b, block_scale_dim_m, block_scale_dim_k], stride [block_scale_dim_m * block_scale_dim_k, block_scale_dim_k, 1]
# For block_descale_b: shape [b, block_scale_dim_k, block_scale_dim_n], stride [block_scale_dim_n * block_scale_dim_k, 1, block_scale_dim_k]

a_descale_shape = (b_dim, block_scale_dim_m, block_scale_dim_k)
a_descale_stride = (
block_scale_dim_m * block_scale_dim_k,
block_scale_dim_k,
1,
)

b_descale_shape = (b_dim, block_scale_dim_k, block_scale_dim_n)
b_descale_stride = (
block_scale_dim_n * block_scale_dim_k,
1,
block_scale_dim_k,
)

# MXFP8 uses FP8_E4M3/FP8_E5M2 for quantized data
# MXFP8 uses FP8_E8M0 for scale data
scale_type = cudnn.data_type.FP8_E8M0

stream = torch.cuda.current_stream(device)
with cudnn.graph(_get_cudnn_handle(stream)) as (graph, _):
a_cudnn_tensor = graph.tensor(
name="a",
dim=tuple(a_shape), # [b, m, k]
stride=tuple(a_stride), # [m * k, k, 1]
data_type=a_type,
)
b_cudnn_tensor = graph.tensor(
name="b",
dim=tuple(b_shape), # [b, k, n]
stride=tuple(b_stride), # [k * n, 1, k]
data_type=b_type,
)
block_descale_a_cudnn_tensor = graph.tensor(
name="block_descale_a",
dim=a_descale_shape,
stride=a_descale_stride,
data_type=scale_type,
reordering_type=cudnn.tensor_reordering.F8_128x4,
)
block_descale_b_cudnn_tensor = graph.tensor(
name="block_descale_b",
dim=b_descale_shape,
stride=b_descale_stride,
data_type=scale_type,
reordering_type=cudnn.tensor_reordering.F8_128x4,
)

# Dequantize the input tensors
dequant_a_tensor = graph.block_scale_dequantize(
a_cudnn_tensor,
block_descale_a_cudnn_tensor,
block_size=[1, block_size],
name="dequant_a",
)
dequant_a_tensor.set_data_type(cudnn.data_type.FLOAT)
dequant_b_tensor = graph.block_scale_dequantize(
b_cudnn_tensor,
block_descale_b_cudnn_tensor,
block_size=[block_size, 1],
name="dequant_b",
)
dequant_b_tensor.set_data_type(cudnn.data_type.FLOAT)

# The actual matmul operation
c_tensor = graph.matmul(
dequant_a_tensor,
dequant_b_tensor,
compute_data_type=cudnn.data_type.FLOAT,
name="gemm",
)
c_tensor.set_data_type(cudnn.data_type.FLOAT)

# Output the dequantized result with the specified output dtype
c_tensor.set_output(True).set_data_type(o_type)
c_final_cudnn_tensor = c_tensor

a_cudnn_tensor.set_uid(UIDs.A_UID.value)
b_cudnn_tensor.set_uid(UIDs.B_UID.value)
block_descale_a_cudnn_tensor.set_uid(UIDs.BLOCK_DESCALE_A_UID.value)
block_descale_b_cudnn_tensor.set_uid(UIDs.BLOCK_DESCALE_B_UID.value)
c_final_cudnn_tensor.set_uid(UIDs.O_UID.value)

graph.validate()
graph.build_operation_graph()
graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.B])

return graph

Contributor

🛠️ Refactor suggestion | 🟠 Major

Add input tensor and scale tensor validation to prevent silent failures

This function hard-codes K-major strides (stride[-1] == 1) and specific block-scale shapes when constructing the cuDNN graph, but performs no runtime validation that the actual input tensors a and b have the expected layout, or that the scale tensors match the computed dimensions. Add defensive checks:

  1. Before graph construction, validate that the input strides match K-major expectations (e.g., a_stride[-1] == 1 for row-major A and b_stride[1] == 1 for column-major B).
  2. In the calling code (or here if scale tensors are accessible), assert that scale tensor element counts match block_scale_dim_m * block_scale_dim_k for A and block_scale_dim_k * block_scale_dim_n for B.

These checks ensure early, clear failures if quantization or layout assumptions are violated, rather than silent cuDNN errors or incorrect results.

Based on past review comments indicating this validation gap.
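As context for the dequantize ops in the graph above, block_scale_dequantize multiplies each block_size-wide run along K by its single E8M0 scale before the FP32 matmul. A toy pure-Python sketch of that semantics (block size shrunk from 32 to 4 for readability; lists stand in for tensors — illustrative only, not cuDNN's implementation):

```python
def block_scale_dequantize(q, scales, block_size=4):
    """Dequantize row-major q[m][k] using one scale per block_size-wide run along k."""
    return [
        [q[i][j] * scales[i][j // block_size] for j in range(len(q[i]))]
        for i in range(len(q))
    ]

# One row, two K-blocks: first block scaled by 0.5, second by 2.0.
q = [[1.0, 2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 1.0]]
scales = [[0.5, 2.0]]  # E8M0-style power-of-two scales
out = block_scale_dequantize(q, scales)
```

If a scale tensor is mis-sized, the `j // block_size` index here raises an IndexError, which is exactly the class of mismatch the proposed numel checks would catch before cuDNN sees the tensors.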


Comment on lines +4144 to +4202
def bmm_mxfp8(
A: torch.Tensor,
B: torch.Tensor,
A_scale: torch.Tensor,
B_scale: torch.Tensor,
dtype: torch.dtype,
out: Optional[torch.Tensor] = None,
backend: Literal["cudnn"] = "cudnn",
) -> torch.Tensor:
r"""BMM MXFP8

Parameters
----------
A: torch.Tensor
Input tensor, shape (b, m, k), fp8 e4m3 or fp8 e5m2.

B: torch.Tensor
Mat2 tensor, shape (b, k, n), should be column major, fp8 e4m3 or fp8 e5m2.

A_scale: torch.Tensor
Scale tensor for A, uint8 (fp8 e8m0 format).

B_scale: torch.Tensor
Scale tensor for B, uint8 (fp8 e8m0 format).

dtype: torch.dtype
out dtype, bf16 or fp16.

out: Optional[torch.Tensor]
Out tensor, shape (b, m, n), bf16 or fp16, defaults to ``None``.

backend: Literal["cudnn"]
The backend to use for the operation. Defaults to ``"cudnn"``.

Returns
-------
out: torch.Tensor
Out tensor, shape (b, m, n), bf16 or fp16.
"""

if backend != "cudnn":
raise ValueError(f"Invalid backend: {backend}")

if not CUDNN_AVAILABLE:
raise ValueError("cudnn is not available")

if out is None:
out = torch.empty(
(A.shape[0], A.shape[1], B.shape[2]),
device=A.device,
dtype=dtype,
)

workspace_buffer = _get_cache_buf(
"bmm_mxfp8_workspace", DEFAULT_WORKSPACE_SIZE, A.device
)

mxfp8_gemm_sm100(A, B, A_scale, B_scale, out, workspace_buffer, ["cudnn"])
return out
Contributor

🛠️ Refactor suggestion | 🟠 Major

Add input tensor validation to ensure correct dtypes and scale tensor shapes

The function should validate that:

  1. A and B have FP8 dtypes (torch.float8_e4m3fn or torch.float8_e5m2)
  2. A_scale and B_scale have dtype torch.uint8 (FP8_E8M0 format)
  3. Scale tensor shapes match the expected block-scale dimensions computed from A and B shapes

These checks would fail fast on incorrect inputs rather than producing cuDNN errors or incorrect results downstream.

🔎 Example validation to add after line 4195
+    # Validate input dtypes
+    if A.dtype not in (torch.float8_e4m3fn, torch.float8_e5m2):
+        raise ValueError(
+            f"A must have FP8 dtype (torch.float8_e4m3fn or torch.float8_e5m2), got {A.dtype}"
+        )
+    if B.dtype not in (torch.float8_e4m3fn, torch.float8_e5m2):
+        raise ValueError(
+            f"B must have FP8 dtype (torch.float8_e4m3fn or torch.float8_e5m2), got {B.dtype}"
+        )
+    
+    # Validate scale tensor dtypes
+    if A_scale.dtype != torch.uint8:
+        raise ValueError(f"A_scale must be uint8 (FP8_E8M0), got {A_scale.dtype}")
+    if B_scale.dtype != torch.uint8:
+        raise ValueError(f"B_scale must be uint8 (FP8_E8M0), got {B_scale.dtype}")
+    
+    # Validate scale tensor shapes
+    block_size = 32  # MXFP8 block size
+    b_dim, m, k = A.shape
+    n = B.shape[2]
+    block_scale_dim_m, block_scale_dim_n, block_scale_dim_k = _calculate_block_scale_dims(m, n, k, block_size)
+    expected_a_scale_shape = (b_dim, block_scale_dim_m, block_scale_dim_k)
+    expected_b_scale_shape = (b_dim, block_scale_dim_k, block_scale_dim_n)
+    if A_scale.shape != expected_a_scale_shape:
+        raise ValueError(
+            f"A_scale shape mismatch. Expected {expected_a_scale_shape}, got {A_scale.shape}"
+        )
+    if B_scale.shape != expected_b_scale_shape:
+        raise ValueError(
+            f"B_scale shape mismatch. Expected {expected_b_scale_shape}, got {B_scale.shape}"
+        )

🤖 Prompt for AI Agents
In flashinfer/gemm/gemm_base.py around lines 4144 to 4202, add input validation
after the docstring (around line 4195): verify A.dtype and B.dtype are one of
torch.float8_e4m3fn or torch.float8_e5m2 and raise TypeError otherwise; verify
A_scale.dtype and B_scale.dtype are torch.uint8 and raise TypeError otherwise;
validate A_scale and B_scale shapes match the expected block-scale dimensions
derived from A and B (compute expected scale shapes from A.shape and B.shape
using the same FP8 blocking logic used by mxfp8_gemm_sm100 or a small helper:
blocks_k = ceil_div(A.shape[2], FP8_SCALE_BLOCK), expected_A_scale_shape =
(A.shape[0], blocks_k) and expected_B_scale_shape = (B.shape[0],
ceil_div(B.shape[2], FP8_SCALE_BLOCK)) or use the existing helper if present),
and raise ValueError with clear messages if shapes mismatch.

Invalid use of get_workspace_size.

Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
flashinfer/gemm/gemm_base.py (1)

3819-3946: Consider adding stride validation for K-major layout requirement

The function hard-codes K-major strides for the block-descale tensors in the cuDNN graph (the K dimension always gets stride 1). While this will work correctly when used with mxfp8_quantize, adding defensive checks would help catch misuse:

# At the start of the function, after shape validation:
if a_stride[-1] != 1:
    raise ValueError(f"A must have K-major layout (stride[-1] == 1), got stride {a_stride}")
if b_stride[1] != 1:
    raise ValueError(f"B must be column-major (K-major along dim 1, stride[1] == 1), got stride {b_stride}")

This is a defensive measure—correct usage won't trigger these checks, but they would provide clear error messages for debugging if tensors with unexpected layouts are passed.

Based on past review comments.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8fef4bf and 3447cbe.

📒 Files selected for processing (1)
  • flashinfer/gemm/gemm_base.py
🧰 Additional context used
🪛 Ruff (0.14.10)
flashinfer/gemm/gemm_base.py

3832-3832: Avoid specifying long messages outside the exception class

(TRY003)


3834-3834: Avoid specifying long messages outside the exception class

(TRY003)


3837-3837: Avoid specifying long messages outside the exception class

(TRY003)


3839-3839: Avoid specifying long messages outside the exception class

(TRY003)


3841-3841: Avoid specifying long messages outside the exception class

(TRY003)


3952-3952: Unused function argument: out

(ARG001)


4015-4015: Unused method argument: inputs

(ARG002)


4016-4016: Unused method argument: profile

(ARG002)


4026-4026: Unused method argument: do_preparation

(ARG002)


4027-4027: Unused method argument: kwargs

(ARG002)


4073-4073: Unused function argument: A

(ARG001)


4074-4074: Unused function argument: B

(ARG001)


4075-4075: Unused function argument: A_scale

(ARG001)


4076-4076: Unused function argument: B_scale

(ARG001)


4077-4077: Unused function argument: dtype

(ARG001)


4078-4078: Unused function argument: out

(ARG001)


4079-4079: Unused function argument: backend

(ARG001)


4088-4091: Avoid specifying long messages outside the exception class

(TRY003)


4097-4097: Unused function argument: A_scale

(ARG001)


4098-4098: Unused function argument: B_scale

(ARG001)


4100-4100: Unused function argument: out

(ARG001)


4101-4101: Unused function argument: backend

(ARG001)


4106-4106: Avoid specifying long messages outside the exception class

(TRY003)


4108-4110: Avoid specifying long messages outside the exception class

(TRY003)


4118-4118: Unused function argument: A

(ARG001)


4119-4119: Unused function argument: B

(ARG001)


4120-4120: Unused function argument: A_scale

(ARG001)


4121-4121: Unused function argument: B_scale

(ARG001)


4122-4122: Unused function argument: dtype

(ARG001)


4123-4123: Unused function argument: out

(ARG001)


4124-4124: Unused function argument: backend

(ARG001)


4181-4181: Avoid specifying long messages outside the exception class

(TRY003)


4184-4184: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (7)
flashinfer/gemm/gemm_base.py (7)

1339-1372: LGTM!

The execution function correctly handles the variant pack setup and workspace buffer sizing. The removal of the .view(torch.float8_e4m3fn) calls (flagged in previous review) is correct since MXFP8 tensors should already have the appropriate dtype.


3795-3817: LGTM!

The block scale dimension calculation correctly applies the indestructible block formula with proper ceiling division.


4045-4068: Verify tuning configuration is appropriate for MXFP8

The function reuses _FP8_GEMM_SM100_TUNING_CONFIG which was designed for FP8 GEMM. While the tensor layout is similar (3D with dynamic M dimension), MXFP8's block-scaled nature might benefit from different bucketing strategies in the future.

The current implementation is functional, but consider adding a dedicated _MXFP8_GEMM_SM100_TUNING_CONFIG if profiling shows different optimal configurations.


4011-4043: LGTM!

The runner follows the same pattern as _cudnn_gemm_fp8_runner, using [0] as the default tactic which delegates to cuDNN's internal heuristics. The TODO comment can be removed since this is consistent with the existing FP8 implementation.


4140-4198: LGTM!

The public API is well-documented and follows the established patterns from bmm_fp8. The explicit backend and cuDNN availability checks provide clear error messages despite some redundancy with the decorator validation.


3948-4008: LGTM!

The graph creation and execution functions follow the established FP4 GEMM patterns. The unused out parameter (flagged by static analysis) is retained for API consistency, which is the same pattern used in _get_cudnn_fp4_gemm_graph.


4071-4091: LGTM!

The requirement checker and dtype validator follow the established patterns from FP8/FP4 implementations. The unused function arguments (flagged by static analysis) are required by the @backend_requirement decorator's expected function signature.

Comment on lines +4094 to +4113
def _check_bmm_mxfp8_problem_size(
A: torch.Tensor,
B: torch.Tensor,
A_scale: torch.Tensor,
B_scale: torch.Tensor,
dtype: torch.dtype,
out: Optional[torch.Tensor] = None,
backend: Literal["cudnn"] = "cudnn",
):
# Check input tensors
if A.ndim != 3 or B.ndim != 3:
# A is [b, m, k], B is [b, k, n]
raise ValueError(f"bmm_mxfp8 accepts 3d tensors, got {A.shape=} and {B.shape=}")
if A.shape[2] != B.shape[1]:
raise ValueError(
f"K dimension (last dim of A) mismatch in bmm_mxfp8. got {A.shape=}, {B.shape=}"
)

_validate_mxfp8_output_dtype(dtype)
return True
⚠️ Potential issue | 🟡 Minor

Add batch dimension and input dtype validation

The function validates K dimensions but misses batch dimension matching and input dtype validation. Consider adding:

🔎 Proposed validation additions
 def _check_bmm_mxfp8_problem_size(
     A: torch.Tensor,
     B: torch.Tensor,
     A_scale: torch.Tensor,
     B_scale: torch.Tensor,
     dtype: torch.dtype,
     out: Optional[torch.Tensor] = None,
     backend: Literal["cudnn"] = "cudnn",
 ):
     # Check input tensors
     if A.ndim != 3 or B.ndim != 3:
         # A is [b, m, k], B is [b, k, n]
         raise ValueError(f"bmm_mxfp8 accepts 3d tensors, got {A.shape=} and {B.shape=}")
+    if A.shape[0] != B.shape[0]:
+        raise ValueError(
+            f"Batch dimension mismatch in bmm_mxfp8. got {A.shape[0]=}, {B.shape[0]=}"
+        )
     if A.shape[2] != B.shape[1]:
         raise ValueError(
             f"K dimension (last dim of A) mismatch in bmm_mxfp8. got {A.shape=}, {B.shape=}"
         )
+    # Validate input dtypes
+    if A.dtype not in (torch.float8_e4m3fn, torch.float8_e5m2):
+        raise ValueError(f"A must have FP8 dtype, got {A.dtype}")
+    if B.dtype not in (torch.float8_e4m3fn, torch.float8_e5m2):
+        raise ValueError(f"B must have FP8 dtype, got {B.dtype}")
+    if A_scale.dtype != torch.uint8:
+        raise ValueError(f"A_scale must be uint8 (FP8_E8M0), got {A_scale.dtype}")
+    if B_scale.dtype != torch.uint8:
+        raise ValueError(f"B_scale must be uint8 (FP8_E8M0), got {B_scale.dtype}")
 
     _validate_mxfp8_output_dtype(dtype)
     return True

These checks help fail fast with clear error messages rather than producing cuDNN errors or incorrect results downstream. Based on past review comments.
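A torch-free sketch of the fail-fast behavior these shape checks give (the dtype checks are omitted here since they need torch types; the function name is invented for illustration):

```python
def check_bmm_shapes(a_shape, b_shape):
    """Mirror of the proposed bmm_mxfp8 shape validation (illustrative only)."""
    if len(a_shape) != 3 or len(b_shape) != 3:
        # A is [b, m, k], B is [b, k, n]
        raise ValueError(f"bmm_mxfp8 accepts 3d tensors, got {a_shape=} and {b_shape=}")
    if a_shape[0] != b_shape[0]:
        raise ValueError(f"Batch dimension mismatch, got {a_shape[0]=}, {b_shape[0]=}")
    if a_shape[2] != b_shape[1]:
        raise ValueError(f"K dimension mismatch, got {a_shape=}, {b_shape=}")
    return True

print(check_bmm_shapes((128, 512, 4096), (128, 4096, 512)))  # -> True
try:
    check_bmm_shapes((2, 512, 4096), (4, 4096, 512))
except ValueError as e:
    print(e)  # readable error instead of an opaque cuDNN failure
```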

🧰 Tools
🪛 Ruff (0.14.10)

4097-4097: Unused function argument: A_scale

(ARG001)


4098-4098: Unused function argument: B_scale

(ARG001)


4100-4100: Unused function argument: out

(ARG001)


4101-4101: Unused function argument: backend

(ARG001)


4106-4106: Avoid specifying long messages outside the exception class

(TRY003)


4108-4110: Avoid specifying long messages outside the exception class

(TRY003)

@flashinfer-bot
Collaborator

[FAILED] Pipeline #40747548: 1/20 passed

@aleozlx
Collaborator

aleozlx commented Dec 25, 2025

The 33 test_bmm_fp8 failures in the pipeline were expected; they also occur on main.

@aleozlx
Collaborator

aleozlx commented Dec 25, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !216 has been updated with latest changes, and the CI pipeline #40779177 is currently running. I'll report back once the pipeline job completes.
