feat: Add support for bmm mxfp8 (#2256)
Conversation
> Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough

Adds MXFP8 batched matrix multiplication (bmm_mxfp8): a new public API, cuDNN MXFP8 graph creation/execution, an autotune runner, benchmark/test integration, and dtype/backend support mappings.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Test as Test/Benchmark
    participant Quant as mxfp8_quantize
    participant API as bmm_mxfp8 (Public API)
    participant Planner as cuDNN Graph / Planner
    participant Runner as Tunable Runner / Autotune
    participant CuDNN as cuDNN Executor
    participant GPU as GPU Compute
    rect rgb(230,240,255)
    Note over Test,Quant: Input preparation & quantization
    Test->>Quant: raw tensors -> quantized tensors + scales
    Quant-->>Test: quantized A,B and scales
    end
    rect rgb(240,255,230)
    Note over Test,API: Execution request
    Test->>API: call bmm_mxfp8(A_q, B_q, scales, params)
    API->>API: validate dtypes / CC / problem size
    API->>Planner: build cuDNN graph & descriptors (block scales, dequant steps)
    Planner-->>API: execution plan / graph
    API->>Runner: request autotuned runner (tactic selection)
    Runner->>CuDNN: invoke selected tactic on graph
    end
    rect rgb(255,240,230)
    Note over CuDNN,GPU: Compute
    CuDNN->>GPU: execute graph (MXFP8 matmul + dequant)
    GPU-->>CuDNN: result tensor
    CuDNN-->>API: output tensor (cast to requested dtype)
    end
    rect rgb(245,230,255)
    Note over API,Test: Validation
    API-->>Test: result tensor
    Test->>Test: verify shape, dtype, NaNs, cosine-similarity vs reference
    end
```
Estimated code review effort: 🎯 4 (Complex), ⏱️ ~45 minutes.
Summary of Changes (Gemini Code Assist)

This pull request integrates MXFP8 (microscaling FP8) quantization into the FlashInfer library for batched matrix multiplication (BMM). The primary goal is to reduce the memory footprint and improve the performance of GEMM computations by using the FP8 format for data and FP8_E8M0 for block scales, via the cuDNN backend. This addition expands FlashInfer's support for efficient low-precision arithmetic in deep-learning workloads.
Code Review
This pull request introduces support for batched matrix multiplication with MXFP8 data types (bmm_mxfp8), currently leveraging the cuDNN backend. The changes are well-structured, adding the new routine to the core library, including it in the benchmark suite, and providing a dedicated test file.
My review focuses on correctness and potential improvements. I've identified a few critical issues in the benchmark implementation and the core cuDNN graph execution logic that could lead to incorrect results or metrics. I've also noted opportunities to improve test coverage and address some TODO items related to autotuning configuration. Overall, this is a solid addition, and addressing these points will enhance its robustness.
benchmarks/routines/gemm.py (outdated)

```python
problem_bytes = (
    m * k * torch.float8_e4m3fn.itemsize
    + n * k * torch.float8_e4m3fn.itemsize
    + m * n * res_dtype.itemsize
)
```
The calculation for problem_bytes is missing the batch_size factor. The problem_flops calculation on line 939 correctly includes it; without it here, the benchmark reports incorrect bandwidth.
```diff
 problem_bytes = (
     m * k * torch.float8_e4m3fn.itemsize
     + n * k * torch.float8_e4m3fn.itemsize
     + m * n * res_dtype.itemsize
-)
+) * batch_size
```
Fixed, benchmark results after fix:
```shell
python benchmarks/flashinfer_benchmark.py \
    --routine bmm_mxfp8 -vv \
    --num_iters 30 \
    --batch_size 128 \
    --m 512 --n 512 --k 4096 \
    --out_dtype bfloat16 \
    --backends cudnn \
    --refcheck
```
[PERF] cudnn :: median time 0.117 ms; std 0.001 ms; achieved tflops 2344.958 TFLOPs/sec; achieved tb_per_sec 5.152 TB/sec
No major change in TFLOPs, but tb_per_sec increased to ~5 TB/s.
The achieved HBM bandwidth is still under the 8 TB/s peak memory bandwidth of a single B200 (based on this spec: https://www.nvidia.com/en-eu/data-center/dgx-b200/).
This fix is possibly also needed in testBmmFp8.
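For reference, the corrected metrics can be reproduced from the benchmark parameters above with plain arithmetic (a rough sketch: FP8 inputs at 1 byte/element, the BF16 output at 2 bytes/element, scale-tensor traffic ignored, matching the benchmark's problem_bytes formula):

```python
# Reproduce the reported TFLOPs and TB/s from the run above.
batch_size, m, n, k = 128, 512, 512, 4096
median_time_s = 0.117e-3  # median time reported by the benchmark, in seconds

problem_flops = 2 * batch_size * m * n * k
problem_bytes = (m * k * 1 + n * k * 1 + m * n * 2) * batch_size  # * batch_size is the fix

tflops = problem_flops / median_time_s / 1e12
tb_per_sec = problem_bytes / median_time_s / 1e12

# Gives roughly 2349 TFLOPs/sec and 5.16 TB/sec, matching the reported
# 2344.958 TFLOPs/sec and 5.152 TB/sec up to rounding of the median time.
print(f"{tflops:.0f} TFLOPs/sec, {tb_per_sec:.2f} TB/sec")
```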
```python
    "mxfp8_gemm",  # TODO: check if this is correct
    runners,
    _FP8_GEMM_SM100_TUNING_CONFIG,  # TODO: check if this is correct
```
The autotuner is configured with mxfp8_gemm as the key and reuses _FP8_GEMM_SM100_TUNING_CONFIG. While this might work, it's worth considering if a dedicated tuning configuration for mxfp8 would be more optimal, as the performance characteristics might differ from standard FP8 GEMM. The TODO comments also suggest this might be a temporary solution.
```python
@pytest.mark.parametrize("input_dtype", [torch.bfloat16])
@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
@pytest.mark.parametrize("res_dtype", [torch.bfloat16])
```
The test parameterization for input_dtype and res_dtype is limited to torch.bfloat16.
- The bmm_mxfp8 function also supports torch.float16 as an output dtype; it should be added to the tests for better coverage.
- The docstring for bmm_mxfp8 mentions support for fp8_e5m2 input, but the current mxfp8_quantize function only produces fp8_e4m3fn. This creates a discrepancy between the documented API and the testable functionality. It would be beneficial to either update the quantization function to support e5m2 and add it to the tests, or update the docstring to reflect the current limitation.
```diff
 @pytest.mark.parametrize("input_dtype", [torch.bfloat16])
 @pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
-@pytest.mark.parametrize("res_dtype", [torch.bfloat16])
+@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16])
```
Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
```python
    return res


def testBmmMxfp8(args):
```
If you think it's better, I can merge this with the existing testBmmFp8.
aleozlx
left a comment
left minor comments. looks good so far
Actionable comments posted: 2
♻️ Duplicate comments (1)
tests/gemm/test_bmm_mxfp8.py (1)
10-76: Solid end-to-end MXFP8 BMM test; consider a couple of small adjustments

Strengths:

- Capability gating: skips SM < 100 and SM 11/12 explicitly, which matches current support for bmm_mxfp8 (SM100/103 only).
- Quantization: both A and B are quantized via mxfp8_quantize, and the test asserts numel(input_mxfp8) == numel(input_scale) * 32, which encodes the 32-element block-size invariant.
- Autotune path: the test exercises both autotuned and non-autotuned paths via `with autotune(auto_tuning): ... bmm_mxfp8(...)`.
- Validation: checks shape, dtype, NaNs, and cosine similarity vs torch.bmm with a conservative min_cos_sim = 0.9.

Two minor suggestions:

- Output dtype coverage: bmm_mxfp8 also supports torch.float16 outputs. Adding torch.float16 to the res_dtype parametrization would improve coverage and catch any FP16-specific issues.
- Scale-layout parameter: you parametrize is_sf_swizzled_layout, but the GEMM path always treats scales with reordering_type=F8_128x4. Please double-check that the MXFP8 quantizer and the cuDNN graph agree on semantics for the non-swizzled case; if only swizzled scales are supported today, it might be better to restrict this test to True (and document that limitation) until the other mode is fully wired.

Overall, the test is well structured and provides good functional coverage for the new path.
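The cosine-similarity acceptance criterion used by the test can be sketched in plain Python (a minimal illustration; the real test compares flattened output tensors against a torch.bmm reference, and the sample vectors below are hypothetical):

```python
import math

def cosine_similarity(a, b):
    # Cosine similarity between two flat vectors: dot(a, b) / (|a| * |b|).
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)

MIN_COS_SIM = 0.9  # conservative threshold, as in the test

reference = [1.0, 2.0, 3.0, 4.0]        # stand-in for the torch.bmm reference
quantized = [1.05, 1.95, 3.1, 3.9]      # hypothetical MXFP8 round-trip result

assert cosine_similarity(reference, quantized) > MIN_COS_SIM
```

Cosine similarity tolerates the small per-element perturbations that block quantization introduces while still catching gross errors such as wrong scales or transposed operands.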
🧹 Nitpick comments (7)
flashinfer/gemm/gemm_base.py (4)
1339-1375: Workspace handling and tactic-specific sizes look correct, but consider cache integration

The MXFP8 execute_cudnn_gemm_mxfp8_graph mirrors the FP4 path and correctly uses graph.get_workspace_size(tactic) when a specific plan index is provided, falling back to graph.get_workspace_size() for the heuristic case. The local reallocation of workspace_buffer when it is too small is functionally fine, but it bypasses _get_cache_buf's caching, so large workspaces will be reallocated on subsequent calls instead of being reused.

If workspace sizes for MXFP8 end up consistently larger than DEFAULT_WORKSPACE_SIZE, consider either:

- increasing the default for the MXFP8 cache key, or
- pushing the resized buffer back into the cache when you grow it.

This is non-blocking and mainly a perf/fragmentation tweak.
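The "push the resized buffer back into the cache" idea could look roughly like this (a sketch only; the cache dict, the 32 MiB default, and the bytearray standing in for a GPU buffer are all hypothetical stand-ins for FlashInfer's real _get_cache_buf machinery):

```python
# Hypothetical workspace cache: key -> bytearray standing in for a GPU buffer.
_workspace_cache = {}

def get_workspace(key, required_size, default_size=32 * 1024 * 1024):
    """Return a cached workspace at least `required_size` bytes large,
    growing (and re-caching) it when a tactic needs more than the default."""
    buf = _workspace_cache.get(key)
    if buf is None:
        buf = bytearray(max(required_size, default_size))
        _workspace_cache[key] = buf
    elif len(buf) < required_size:
        # Grow and write the larger buffer back so later calls reuse it,
        # instead of reallocating on every invocation.
        buf = bytearray(required_size)
        _workspace_cache[key] = buf
    return buf
```

With this shape, a one-off large tactic grows the cached buffer once, and subsequent calls with the same key reuse the grown buffer rather than reallocating.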
3799-3821: Block-scale dimension helper is clear; maybe document assumptions
_calculate_block_scale_dims embeds the "indestructible 128x4 block" logic in a single helper and applies div_up twice on K to align to (block_size, 4) groups, which matches the intended 128×4 granularity.

If this formula is tied directly to the cuDNN MXFP8 block-scale layout (e.g., FP8_128x4 requirements), a short docstring note about the relationship between (m, n, k, block_size) and the expected scale tensor shapes would make future maintenance safer, especially if cuDNN layouts evolve.
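To make the "div_up twice on K" remark concrete, here is one plausible reading of the helper, assuming MXFP8's 32-element blocks along K and cuDNN's 128x4 scale tiles; the real _calculate_block_scale_dims is not shown in this excerpt and may differ:

```python
def div_up(x, y):
    # Ceiling division.
    return (x + y - 1) // y

def calculate_block_scale_dims(m, n, k, block_size=32):
    """Hypothetical sketch of the block-scale shape math: one scale per
    `block_size` elements along K, with M/N padded up to 128 rows and the
    K-scale count padded to a multiple of 4 (cuDNN's 128x4 tiles)."""
    block_scale_dim_m = div_up(m, 128) * 128
    block_scale_dim_n = div_up(n, 128) * 128
    # div_up twice on K: first into block_size groups, then align to 4.
    block_scale_dim_k = div_up(div_up(k, block_size), 4) * 4
    return block_scale_dim_m, block_scale_dim_n, block_scale_dim_k
```

For example, the benchmark shape (m=512, n=512, k=4096) with block_size=32 would yield 4096 / 32 = 128 K-scales, already a multiple of 4, so the scale tensors under this reading would be [b, 512, 128] and [b, 128, 512].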
3952-4047: MXFP8 cuDNN runner wiring is consistent with the autotuner, with minor nits
- _get_cudnn_mxfp8_gemm_graph's out parameter is unused; it can be removed from the signature and call sites unless you expect to use it for layout decisions later. This would also quiet the static analyzer.
- _cudnn_gemm_mxfp8_runner.get_valid_tactics currently returns [0], and the forward path passes tactic through to _cudnn_gemm_mxfp8, which maps tactic == -1 to the generic "build all plans and let cuDNN choose" path and tactic >= 0 to a specific plan index. This is compatible with AutoTuner (fallback will still use tactic=-1), but it means tuning will only ever consider plan index 0. If cuDNN reports multiple valid plans, you may eventually want to enumerate them via graph.get_execution_plan_count(), as in the FP4 runner, and profile more than one.

Both points are non-blocking; behavior is correct as-is.
4049-4072: MXFP8 GEMM autotuning reuses FP8 tuning config; acceptable but potentially suboptimal
mxfp8_gemm_sm100 wires MXFP8 GEMM into the autotuner under the "mxfp8_gemm" key and reuses _FP8_GEMM_SM100_TUNING_CONFIG. This matches the fp8_gemm_sm100 pattern and should work functionally, since the input list layout is identical ([a, b, scale_a, scale_b, out, workspace_buffer]) and the constraints only depend on the a and out shapes.

If MXFP8 kernels end up with different sweet spots (e.g., K or batch-size sensitivities) from standard FP8 kernels, consider introducing a dedicated tuning config later; no change is required now.
benchmarks/routines/gemm.py (3)
26-47: run_gemm_test wiring for bmm_mxfp8 is straightforward

Adding the `elif args.routine == "bmm_mxfp8": return testBmmMxfp8(args)` branch integrates MXFP8 cleanly into the existing dispatch function without affecting other routines.

One minor nit: the --autotune help text in parse_gemm_args still only mentions mm_fp4 and bmm_fp8, but you now support autotuning for bmm_mxfp8 as well. Consider updating the help string to avoid confusion.
150-156: Autotune help text is slightly stale wrt bmm_mxfp8

The --autotune argument's help string currently says:

> Enable autotuner warmup for supported routines (mm_fp4 and bmm_fp8).

Since testBmmMxfp8 also honors --autotune via autotune_supported_backends = ["cudnn"] and a warmup loop, it would be more accurate to mention bmm_mxfp8 here as well.
764-966: bmm_mxfp8 benchmark implementation looks correct; a few small suggestions

Positives:

- Argument parsing and backend filtering: reuses dtype_str_to_torch_dtype and filter_backends_by_compute_capability, with routine_cc_to_supported_backends["bmm_mxfp8"] limiting to "cudnn" on SM100/103. Autotune backend filtering is consistent with bmm_fp8/mm_fp4.
- Input and quantization: inputs are [batch_size, m, k] and [batch_size, n, k]ᵀ (so [b, k, n]) in BF16, then quantized via mxfp8_quantize. Shapes and layouts match the bmm_mxfp8 docstring expectations (A: [b, m, k], B: [b, k, n]).
- Reference and validation: the reference uses torch.bmm(input, mat2). Validation uses cosine similarity with min_cos_sim = 0.9, the same metric style as testBmmFp8 but with a looser threshold, which is reasonable for MXFP8.
- Autotune integration: warmup under `with autotune(True)` is correctly wrapped around bmm_mxfp8 calls for backends that support autotuning ("cudnn").

Minor nits:

- Bandwidth accounting comment: the problem_bytes formula intentionally ignores the scale tensors and only accounts for FP8 inputs and BF16/FP16 outputs, but the comment says "approximate as 1 byte per element for simplicity". Either include approximate scale traffic or rephrase the comment to "ignore scale tensors as their traffic is comparatively small" to avoid confusion.
- Autotune help alignment: as mentioned earlier, you may want to update the global --autotune help text to mention bmm_mxfp8.

Functionally this benchmark is sound.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- benchmarks/routines/flashinfer_benchmark_utils.py
- benchmarks/routines/gemm.py
- flashinfer/__init__.py
- flashinfer/gemm/__init__.py
- flashinfer/gemm/gemm_base.py
- tests/gemm/test_bmm_mxfp8.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
flashinfer/gemm/gemm_base.py
🧬 Code graph analysis (4)
flashinfer/__init__.py (1)
- flashinfer/gemm/gemm_base.py (1): bmm_mxfp8 (4144-4202)

benchmarks/routines/gemm.py (3)
- flashinfer/fp8_quantization.py (1): mxfp8_quantize (147-180)
- flashinfer/gemm/gemm_base.py (1): bmm_mxfp8 (4144-4202)
- flashinfer/testing/utils.py (1): bench_gpu_time (1484-1631)

tests/gemm/test_bmm_mxfp8.py (3)
- flashinfer/autotuner.py (1): autotune (251-262)
- flashinfer/gemm/gemm_base.py (1): bmm_mxfp8 (4144-4202)
- flashinfer/utils.py (1): get_compute_capability (258-261)

flashinfer/gemm/gemm_base.py (2)
- flashinfer/autotuner.py (7): TunableRunner (194-247), get_valid_tactics (196-214), OptimizationProfile (168-183), forward (220-244), AutoTuner (335-791), get (362-365), choose_one (400-534)
- flashinfer/utils.py (2): supported_compute_capability (819-899), backend_requirement (902-1184)
🪛 Ruff (0.14.10)
benchmarks/routines/gemm.py
812-814: Avoid specifying long messages outside the exception class
(TRY003)
862-862: Avoid specifying long messages outside the exception class
(TRY003)
923-925: Avoid specifying long messages outside the exception class
(TRY003)
flashinfer/gemm/gemm_base.py
3836-3836: Avoid specifying long messages outside the exception class
(TRY003)
3838-3838: Avoid specifying long messages outside the exception class
(TRY003)
3841-3841: Avoid specifying long messages outside the exception class
(TRY003)
3843-3843: Avoid specifying long messages outside the exception class
(TRY003)
3845-3845: Avoid specifying long messages outside the exception class
(TRY003)
3956-3956: Unused function argument: out
(ARG001)
4019-4019: Unused method argument: inputs
(ARG002)
4020-4020: Unused method argument: profile
(ARG002)
4030-4030: Unused method argument: do_preparation
(ARG002)
4031-4031: Unused method argument: kwargs
(ARG002)
4077-4077: Unused function argument: A
(ARG001)
4078-4078: Unused function argument: B
(ARG001)
4079-4079: Unused function argument: A_scale
(ARG001)
4080-4080: Unused function argument: B_scale
(ARG001)
4081-4081: Unused function argument: dtype
(ARG001)
4082-4082: Unused function argument: out
(ARG001)
4083-4083: Unused function argument: backend
(ARG001)
4092-4095: Avoid specifying long messages outside the exception class
(TRY003)
4101-4101: Unused function argument: A_scale
(ARG001)
4102-4102: Unused function argument: B_scale
(ARG001)
4104-4104: Unused function argument: out
(ARG001)
4105-4105: Unused function argument: backend
(ARG001)
4110-4110: Avoid specifying long messages outside the exception class
(TRY003)
4112-4114: Avoid specifying long messages outside the exception class
(TRY003)
4122-4122: Unused function argument: A
(ARG001)
4123-4123: Unused function argument: B
(ARG001)
4124-4124: Unused function argument: A_scale
(ARG001)
4125-4125: Unused function argument: B_scale
(ARG001)
4126-4126: Unused function argument: dtype
(ARG001)
4127-4127: Unused function argument: out
(ARG001)
4128-4128: Unused function argument: backend
(ARG001)
4185-4185: Avoid specifying long messages outside the exception class
(TRY003)
4188-4188: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (5)
flashinfer/gemm/gemm_base.py (1)
4136-4202: bmm_mxfp8 public API and backend_requirement integration look consistent

The bmm_mxfp8 entry point:

- is properly guarded by @backend_requirement using _cudnn_bmm_mxfp8_requirement and _check_bmm_mxfp8_problem_size, with compute capability restricted to SM100/103 via @supported_compute_capability;
- validates dtype via _validate_mxfp8_output_dtype, ensuring BF16/FP16 only;
- enforces backend == "cudnn" at the Python API level and errors out clearly if cuDNN is unavailable;
- allocates out lazily with shape (b, m, n) and uses _get_cache_buf for a workspace buffer, then dispatches into mxfp8_gemm_sm100 with ["cudnn"] runner selection.

Given the current implementation only supports the cuDNN backend, having backend: Literal["cudnn"] is appropriate; the heuristic_func is effectively dormant but harmless. Overall this wiring matches the existing FP8/FP4 patterns.

flashinfer/gemm/__init__.py (1)

1-27: Public export of bmm_mxfp8 is correct and consistent

The new import/export of bmm_mxfp8 from .gemm_base and its inclusion in __all__ align with other GEMM APIs (e.g., bmm_fp8, mm_fp4). This makes the new MXFP8 BMM routine available as flashinfer.gemm.bmm_mxfp8 in a consistent way.

flashinfer/__init__.py (1)

87-91: Top-level bmm_mxfp8 export matches the existing GEMM API surface

Importing bmm_mxfp8 from .gemm and re-exporting it at the package root is consistent with how bmm_fp8, mm_fp4, and mm_fp8 are exposed. This is sufficient to surface MXFP8 BMM as flashinfer.bmm_mxfp8.

benchmarks/routines/flashinfer_benchmark_utils.py (2)

92-112: Including bmm_mxfp8 in the gemm benchmark_apis is correct

Adding "bmm_mxfp8" to the "gemm" list ensures the new routine is discoverable by the benchmark harness and aligns with how bmm_fp8 and other GEMM routines are registered.

240-249: Backend mapping for bmm_mxfp8 matches capability checks

The routine_cc_to_supported_backends["bmm_mxfp8"] entry only enables "cudnn" on compute capabilities "10.0" and "10.3", matching the @supported_compute_capability([100, 103]) on _cudnn_bmm_mxfp8_requirement. This keeps the benchmark frontend in sync with the backend_requirement logic and avoids running MXFP8 BMM on unsupported architectures.

No changes needed here.
flashinfer/gemm/gemm_base.py

```python
def create_cudnn_execution_plans_mxfp8_gemm(
    a_shape,
    a_stride,
    a_type,  # cudnn.data_type, FP8_E4M3 or FP8_E5M2
    b_shape,
    b_stride,
    b_type,  # cudnn.data_type, FP8_E4M3 or FP8_E5M2
    block_size,
    o_type,  # cudnn.data_type, BF16 or FP16
    device,
):
    if len(a_shape) != 3:
        raise ValueError(f"A shape must be 3D, got {a_shape}")
    if len(b_shape) != 3:
        raise ValueError(f"B shape must be 3D, got {b_shape}")

    if a_type not in [cudnn.data_type.FP8_E4M3, cudnn.data_type.FP8_E5M2]:
        raise ValueError(f"A type must be FP8_E4M3 or FP8_E5M2, got {a_type}")
    if b_type not in [cudnn.data_type.FP8_E4M3, cudnn.data_type.FP8_E5M2]:
        raise ValueError(f"B type must be FP8_E4M3 or FP8_E5M2, got {b_type}")
    if o_type not in [cudnn.data_type.BFLOAT16, cudnn.data_type.HALF]:
        raise ValueError(f"Output type must be BF16 or FP16, got {o_type}")

    # Extract batch, m, n, k dimensions
    b_dim = a_shape[0]
    m = a_shape[1]
    k = a_shape[2]
    n = b_shape[2]

    # Calculate block scale dimensions using indestructible block formula
    block_scale_dim_m, block_scale_dim_n, block_scale_dim_k = (
        _calculate_block_scale_dims(m, n, k, block_size)
    )

    # For mxfp8, scale tensors need to be reshaped to 3D with correct strides
    # cuDNN expects K-major layout: stride for K dimension should be 1
    # For block_descale_a: shape [b, block_scale_dim_m, block_scale_dim_k],
    #   stride [block_scale_dim_m * block_scale_dim_k, block_scale_dim_k, 1]
    # For block_descale_b: shape [b, block_scale_dim_k, block_scale_dim_n],
    #   stride [block_scale_dim_n * block_scale_dim_k, 1, block_scale_dim_k]

    a_descale_shape = (b_dim, block_scale_dim_m, block_scale_dim_k)
    a_descale_stride = (
        block_scale_dim_m * block_scale_dim_k,
        block_scale_dim_k,
        1,
    )

    b_descale_shape = (b_dim, block_scale_dim_k, block_scale_dim_n)
    b_descale_stride = (
        block_scale_dim_n * block_scale_dim_k,
        1,
        block_scale_dim_k,
    )

    # MXFP8 uses FP8_E4M3/FP8_E5M2 for quantized data
    # MXFP8 uses FP8_E8M0 for scale data
    scale_type = cudnn.data_type.FP8_E8M0

    stream = torch.cuda.current_stream(device)
    with cudnn.graph(_get_cudnn_handle(stream)) as (graph, _):
        a_cudnn_tensor = graph.tensor(
            name="a",
            dim=tuple(a_shape),  # [b, m, k]
            stride=tuple(a_stride),  # [m * k, k, 1]
            data_type=a_type,
        )
        b_cudnn_tensor = graph.tensor(
            name="b",
            dim=tuple(b_shape),  # [b, k, n]
            stride=tuple(b_stride),  # [k * n, 1, k]
            data_type=b_type,
        )
        block_descale_a_cudnn_tensor = graph.tensor(
            name="block_descale_a",
            dim=a_descale_shape,
            stride=a_descale_stride,
            data_type=scale_type,
            reordering_type=cudnn.tensor_reordering.F8_128x4,
        )
        block_descale_b_cudnn_tensor = graph.tensor(
            name="block_descale_b",
            dim=b_descale_shape,
            stride=b_descale_stride,
            data_type=scale_type,
            reordering_type=cudnn.tensor_reordering.F8_128x4,
        )

        # Dequantize the input tensors
        dequant_a_tensor = graph.block_scale_dequantize(
            a_cudnn_tensor,
            block_descale_a_cudnn_tensor,
            block_size=[1, block_size],
            name="dequant_a",
        )
        dequant_a_tensor.set_data_type(cudnn.data_type.FLOAT)
        dequant_b_tensor = graph.block_scale_dequantize(
            b_cudnn_tensor,
            block_descale_b_cudnn_tensor,
            block_size=[block_size, 1],
            name="dequant_b",
        )
        dequant_b_tensor.set_data_type(cudnn.data_type.FLOAT)

        # The actual matmul operation
        c_tensor = graph.matmul(
            dequant_a_tensor,
            dequant_b_tensor,
            compute_data_type=cudnn.data_type.FLOAT,
            name="gemm",
        )
        c_tensor.set_data_type(cudnn.data_type.FLOAT)

        # Output the dequantized result with the specified output dtype
        c_tensor.set_output(True).set_data_type(o_type)
        c_final_cudnn_tensor = c_tensor

        a_cudnn_tensor.set_uid(UIDs.A_UID.value)
        b_cudnn_tensor.set_uid(UIDs.B_UID.value)
        block_descale_a_cudnn_tensor.set_uid(UIDs.BLOCK_DESCALE_A_UID.value)
        block_descale_b_cudnn_tensor.set_uid(UIDs.BLOCK_DESCALE_B_UID.value)
        c_final_cudnn_tensor.set_uid(UIDs.O_UID.value)

        graph.validate()
        graph.build_operation_graph()
        graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.B])

    return graph
```
🧩 Analysis chain
Repository: flashinfer-ai/flashinfer

🏁 Scripts executed:

```shell
# First, verify the file and function exist, and examine the exact code
cat -n flashinfer/gemm/gemm_base.py | sed -n '3824,3950p'

# Check for the _check_bmm_mxfp8_problem_size function
rg "_check_bmm_mxfp8_problem_size" -n

# Look for the mxfp8_quantize function to understand scale tensor creation
rg "def mxfp8_quantize" -n -A 20

# Check for any existing stride validation in the codebase
rg "stride\(-1\)" -n flashinfer/gemm/gemm_base.py

cat -n flashinfer/gemm/gemm_base.py | sed -n '4098,4130p'

# Look at the mxfp8_quantize function more completely
cat -n flashinfer/fp8_quantization.py | sed -n '147,250p'

# Search for any contiguity or stride checks related to mxfp8
rg "contiguous|stride" flashinfer/fp8_quantization.py -n -B 2 -A 2

# Check for any scale tensor validation or size checks
rg "_scale.*numel\|assert.*stride" flashinfer/gemm/gemm_base.py -n -B 2 -A 2

# Check the _calculate_block_scale_dims function
rg "def _calculate_block_scale_dims" -A 20

# Find where create_cudnn_execution_plans_mxfp8_gemm is called and what inputs are passed
rg "create_cudnn_execution_plans_mxfp8_gemm" -n -B 3 -A 3

# Check how A_scale and B_scale are used in the bmm_mxfp8 function
rg "def bmm_mxfp8" flashinfer/gemm/gemm_base.py -A 50 | head -80

# Look at the full bmm_mxfp8 function for complete context
cat -n flashinfer/gemm/gemm_base.py | sed -n '3952,4000p'

# Check for tests that verify behavior with different stride/layout configurations
rg "bmm_mxfp8" flashinfer/tests -l

# Check whether a/b strides are validated before create_cudnn_execution_plans_mxfp8_gemm
cat -n flashinfer/gemm/gemm_base.py | sed -n '3950,3990p'

# Find test files
find . -path "*/test*" -name "*.py" -type f | grep -i mxfp8 | head -10

# Check the full bmm_mxfp8 function including requirement checking
cat -n flashinfer/gemm/gemm_base.py | sed -n '4120,4180p'

# Look for where A_scale and B_scale are reshaped or prepared
rg "A_scale|B_scale" flashinfer/gemm/gemm_base.py -n | head -30

# Check what validation happens in cudnn_bmm_mxfp8_requirement
rg "cudnn_bmm_mxfp8_requirement" flashinfer/gemm/gemm_base.py -A 30

# Look at the test to see what kinds of inputs are being tested
cat tests/gemm/test_bmm_mxfp8.py

# Check what mxfp8_quantize actually returns for the scale tensor shape
rg "mxfp8_quantize" tests/gemm/test_bmm_mxfp8.py -B 5 -A 10

# Check for contiguity checks anywhere in the mxfp8 code path
rg "contiguous\(\)" flashinfer/gemm/gemm_base.py -n

# Check whether any code path validates that tensors have the expected layout
rg "is_contiguous\|stride\(-1\)" flashinfer/gemm/gemm_base.py -n

# Look at whether the scale tensors are reshaped or validated
cat -n flashinfer/gemm/gemm_base.py | sed -n '4180,4220p'
```
Add stride and scale tensor validation to prevent layout mismatches
create_cudnn_execution_plans_mxfp8_gemm hard-codes block-descale shapes/strides assuming K-major (stride[-1] == 1) and contiguous inputs, with no runtime validation. Add two defensive checks:
- In _check_bmm_mxfp8_problem_size, validate that A.stride(-1) == 1 and B.stride(-1) == 1 to enforce the K-major layout requirement and fail fast on unexpected tensor layouts.
- In the same function, assert that A_scale.numel() == b_dim * block_scale_dim_m * block_scale_dim_k and B_scale.numel() == b_dim * block_scale_dim_k * block_scale_dim_n to catch mismatches between the quantization layout and cuDNN's graph expectations.

These checks are defensive and won't affect correct usage with mxfp8_quantize, but they will make debugging easier if scale tensors are created manually or inputs have unexpected layouts.
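The two proposed checks can be sketched as follows (a layout-only sketch; tensors are represented by their last-dimension stride and element counts so the logic is framework-independent, and the function name is hypothetical):

```python
def check_mxfp8_layout(
    a_last_stride, b_last_stride,
    a_scale_numel, b_scale_numel,
    b_dim, block_scale_dim_m, block_scale_dim_n, block_scale_dim_k,
):
    """Fail fast on non-K-major inputs or scale tensors whose element
    counts disagree with the computed block-scale dimensions."""
    if a_last_stride != 1 or b_last_stride != 1:
        raise ValueError("A and B must be K-major (stride(-1) == 1)")
    if a_scale_numel != b_dim * block_scale_dim_m * block_scale_dim_k:
        raise ValueError("A_scale element count does not match block-scale dims")
    if b_scale_numel != b_dim * block_scale_dim_k * block_scale_dim_n:
        raise ValueError("B_scale element count does not match block-scale dims")
```

In the real helper the inputs would come from A.stride(-1), A_scale.numel(), and the tuple returned by _calculate_block_scale_dims, so the check adds no extra tensor traffic.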
🤖 Prompt for AI Agents
In flashinfer/gemm/gemm_base.py around lines 3824 to 3950, add defensive
validation in the helper _check_bmm_mxfp8_problem_size to ensure inputs follow
the K-major layout and scale-tensor shapes match the computed block-scale dims:
check that A.stride(-1) == 1 and B.stride(1) == 1 (B is column-major) and raise a ValueError with a
clear message if not, and check that A_scale.numel() == b_dim *
block_scale_dim_m * block_scale_dim_k and B_scale.numel() == b_dim *
block_scale_dim_k * block_scale_dim_n and raise ValueError if either count
mismatches, so the function fails fast on unexpected tensor layouts or scale
size mismatches.
Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
Force-pushed from a7c6acd to 8fef4bf

/bot run
Actionable comments posted: 2
🧹 Nitpick comments (1)
flashinfer/gemm/gemm_base.py (1)
4184-4188: Remove redundant backend validation

Since `backend` has type `Literal["cudnn"]`, the checks at lines 4184-4188 are redundant — the type system already ensures only `"cudnn"` is accepted. The `backend_requirement` decorator also handles backend validation.

🔎 Simplification

```diff
-    if backend != "cudnn":
-        raise ValueError(f"Invalid backend: {backend}")
-
-    if not CUDNN_AVAILABLE:
-        raise ValueError("cudnn is not available")
-
     if out is None:
```
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
flashinfer/gemm/gemm_base.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
flashinfer/gemm/gemm_base.py
🪛 Ruff (0.14.10)
flashinfer/gemm/gemm_base.py
3836-3836: Avoid specifying long messages outside the exception class
(TRY003)
3838-3838: Avoid specifying long messages outside the exception class
(TRY003)
3841-3841: Avoid specifying long messages outside the exception class
(TRY003)
3843-3843: Avoid specifying long messages outside the exception class
(TRY003)
3845-3845: Avoid specifying long messages outside the exception class
(TRY003)
3956-3956: Unused function argument: out
(ARG001)
4019-4019: Unused method argument: inputs
(ARG002)
4020-4020: Unused method argument: profile
(ARG002)
4030-4030: Unused method argument: do_preparation
(ARG002)
4031-4031: Unused method argument: kwargs
(ARG002)
4077-4077: Unused function argument: A
(ARG001)
4078-4078: Unused function argument: B
(ARG001)
4079-4079: Unused function argument: A_scale
(ARG001)
4080-4080: Unused function argument: B_scale
(ARG001)
4081-4081: Unused function argument: dtype
(ARG001)
4082-4082: Unused function argument: out
(ARG001)
4083-4083: Unused function argument: backend
(ARG001)
4092-4095: Avoid specifying long messages outside the exception class
(TRY003)
4101-4101: Unused function argument: A_scale
(ARG001)
4102-4102: Unused function argument: B_scale
(ARG001)
4104-4104: Unused function argument: out
(ARG001)
4105-4105: Unused function argument: backend
(ARG001)
4110-4110: Avoid specifying long messages outside the exception class
(TRY003)
4112-4114: Avoid specifying long messages outside the exception class
(TRY003)
4122-4122: Unused function argument: A
(ARG001)
4123-4123: Unused function argument: B
(ARG001)
4124-4124: Unused function argument: A_scale
(ARG001)
4125-4125: Unused function argument: B_scale
(ARG001)
4126-4126: Unused function argument: dtype
(ARG001)
4127-4127: Unused function argument: out
(ARG001)
4128-4128: Unused function argument: backend
(ARG001)
4185-4185: Avoid specifying long messages outside the exception class
(TRY003)
4188-4188: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (10)
flashinfer/gemm/gemm_base.py (10)
1339-1376: LGTM: cuDNN graph execution correctly handles workspace and tactics
The execution logic properly manages workspace allocation, tactic-based plan selection, and variant pack construction for MXFP8 operations.

3799-3821: LGTM: Block scale dimension calculation implements the indestructible 128×4 layout
The formula correctly pads M and N to multiples of 128, and pads K (divided by block_size) to multiples of 4, matching cuDNN's expected layout requirements for MXFP8.
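The padding described above can be written as a short helper (a sketch based only on this description; the function name and exact rounding are assumptions, not the repository's code):

```python
import math


def calculate_block_scale_dims(m, n, k, block_size):
    # Pad M and N up to multiples of 128; pad K / block_size up to a multiple of 4.
    # This mirrors cuDNN's F8_128x4 block-scale reordering layout as described above.
    block_scale_dim_m = math.ceil(m / 128) * 128
    block_scale_dim_n = math.ceil(n / 128) * 128
    block_scale_dim_k = math.ceil(k / block_size / 4) * 4
    return block_scale_dim_m, block_scale_dim_n, block_scale_dim_k
```

For example, with block_size=32, an (m=100, n=200, k=64) problem yields block-scale dims (128, 256, 4).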
3952-3978: LGTM: Graph construction wrapper correctly delegates to cached builder
The function appropriately extracts tensor metadata and handles tactic-based plan selection. The unused `out` parameter maintains API consistency with similar functions.

3980-4013: LGTM: MXFP8 execution correctly uses block_size=32 and integrates graph operations
The function properly sets the MXFP8 block size and coordinates graph retrieval with execution.

4015-4047: LGTM: Runner implementation follows established pattern with appropriate tactic handling
The single-tactic approach (`[0]`) is consistent with other cuDNN runners that rely on cuDNN's internal heuristics. The TODO is noted but not blocking.

4049-4073: Consider dedicated tuning configuration for MXFP8 in future optimization
The function reuses `_FP8_GEMM_SM100_TUNING_CONFIG` from standard FP8 GEMM. While this works, MXFP8's different block structure (32-element blocks vs standard FP8) may benefit from specialized tuning parameters in future performance work. The TODOs at lines 4066 and 4068 correctly flag this for follow-up.
As noted in past reviews, this is acceptable for the initial implementation.

4075-4087: LGTM: Backend availability check correctly validates cuDNN presence
This requirement function appropriately verifies cuDNN availability for SM100/103. The unused parameters are part of the `backend_requirement` decorator interface.

4089-4096: LGTM: Output dtype validation correctly restricts to bf16/fp16
The validation logic is clear and provides helpful error messages.

4098-4118: LGTM: Problem size validation correctly checks 3D shapes and K-dimension matching
The validation logic properly ensures tensors are 3D and that the K dimensions align (A.shape[2] == B.shape[1]). The error messages provide full shape information for debugging.

4120-4134: LGTM: Heuristic function correctly filters for cuDNN backend
The backend selection logic appropriately checks cuDNN availability. Unused parameters are part of the heuristic function interface.
```python
@functools.cache
def create_cudnn_execution_plans_mxfp8_gemm(
    a_shape,
    a_stride,
    a_type,  # cudnn.data_type, FP8_E4M3 or FP8_E5M2
    b_shape,
    b_stride,
    b_type,  # cudnn.data_type, FP8_E4M3 or FP8_E5M2
    block_size,
    o_type,  # cudnn.data_type, BF16 or FP16
    device,
):
    if len(a_shape) != 3:
        raise ValueError(f"A shape must be 3D, got {a_shape}")
    if len(b_shape) != 3:
        raise ValueError(f"B shape must be 3D, got {b_shape}")

    if a_type not in [cudnn.data_type.FP8_E4M3, cudnn.data_type.FP8_E5M2]:
        raise ValueError(f"A type must be FP8_E4M3 or FP8_E5M2, got {a_type}")
    if b_type not in [cudnn.data_type.FP8_E4M3, cudnn.data_type.FP8_E5M2]:
        raise ValueError(f"B type must be FP8_E4M3 or FP8_E5M2, got {b_type}")
    if o_type not in [cudnn.data_type.BFLOAT16, cudnn.data_type.HALF]:
        raise ValueError(f"Output type must be BF16 or FP16, got {o_type}")

    # Extract batch, m, n, k dimensions
    b_dim = a_shape[0]
    m = a_shape[1]
    k = a_shape[2]
    n = b_shape[2]

    # Calculate block scale dimensions using indestructible block formula
    block_scale_dim_m, block_scale_dim_n, block_scale_dim_k = (
        _calculate_block_scale_dims(m, n, k, block_size)
    )

    # For mxfp8, scale tensors need to be reshaped to 3D with correct strides
    # cuDNN expects K-major layout: stride for K dimension should be 1
    # For block_descale_a: shape [b, block_scale_dim_m, block_scale_dim_k],
    #   stride [block_scale_dim_m * block_scale_dim_k, block_scale_dim_k, 1]
    # For block_descale_b: shape [b, block_scale_dim_k, block_scale_dim_n],
    #   stride [block_scale_dim_n * block_scale_dim_k, 1, block_scale_dim_k]
    a_descale_shape = (b_dim, block_scale_dim_m, block_scale_dim_k)
    a_descale_stride = (
        block_scale_dim_m * block_scale_dim_k,
        block_scale_dim_k,
        1,
    )

    b_descale_shape = (b_dim, block_scale_dim_k, block_scale_dim_n)
    b_descale_stride = (
        block_scale_dim_n * block_scale_dim_k,
        1,
        block_scale_dim_k,
    )

    # MXFP8 uses FP8_E4M3/FP8_E5M2 for quantized data
    # MXFP8 uses FP8_E8M0 for scale data
    scale_type = cudnn.data_type.FP8_E8M0

    stream = torch.cuda.current_stream(device)
    with cudnn.graph(_get_cudnn_handle(stream)) as (graph, _):
        a_cudnn_tensor = graph.tensor(
            name="a",
            dim=tuple(a_shape),  # [b, m, k]
            stride=tuple(a_stride),  # [m * k, k, 1]
            data_type=a_type,
        )
        b_cudnn_tensor = graph.tensor(
            name="b",
            dim=tuple(b_shape),  # [b, k, n]
            stride=tuple(b_stride),  # [k * n, 1, k]
            data_type=b_type,
        )
        block_descale_a_cudnn_tensor = graph.tensor(
            name="block_descale_a",
            dim=a_descale_shape,
            stride=a_descale_stride,
            data_type=scale_type,
            reordering_type=cudnn.tensor_reordering.F8_128x4,
        )
        block_descale_b_cudnn_tensor = graph.tensor(
            name="block_descale_b",
            dim=b_descale_shape,
            stride=b_descale_stride,
            data_type=scale_type,
            reordering_type=cudnn.tensor_reordering.F8_128x4,
        )

        # Dequantize the input tensors
        dequant_a_tensor = graph.block_scale_dequantize(
            a_cudnn_tensor,
            block_descale_a_cudnn_tensor,
            block_size=[1, block_size],
            name="dequant_a",
        )
        dequant_a_tensor.set_data_type(cudnn.data_type.FLOAT)
        dequant_b_tensor = graph.block_scale_dequantize(
            b_cudnn_tensor,
            block_descale_b_cudnn_tensor,
            block_size=[block_size, 1],
            name="dequant_b",
        )
        dequant_b_tensor.set_data_type(cudnn.data_type.FLOAT)

        # The actual matmul operation
        c_tensor = graph.matmul(
            dequant_a_tensor,
            dequant_b_tensor,
            compute_data_type=cudnn.data_type.FLOAT,
            name="gemm",
        )
        c_tensor.set_data_type(cudnn.data_type.FLOAT)

        # Output the dequantized result with the specified output dtype
        c_tensor.set_output(True).set_data_type(o_type)
        c_final_cudnn_tensor = c_tensor

        a_cudnn_tensor.set_uid(UIDs.A_UID.value)
        b_cudnn_tensor.set_uid(UIDs.B_UID.value)
        block_descale_a_cudnn_tensor.set_uid(UIDs.BLOCK_DESCALE_A_UID.value)
        block_descale_b_cudnn_tensor.set_uid(UIDs.BLOCK_DESCALE_B_UID.value)
        c_final_cudnn_tensor.set_uid(UIDs.O_UID.value)

        graph.validate()
        graph.build_operation_graph()
        graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.B])

    return graph
```
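The FP8_E8M0 scale type used for the block descales above stores only a biased 8-bit exponent, with no sign or mantissa bits. A decode sketch, based on my understanding of the OCP Microscaling (MX) spec (bias 127, 0xFF reserved for NaN) rather than on this PR's code:

```python
def decode_e8m0(byte_val: int) -> float:
    # E8M0: one unsigned exponent byte; value = 2**(byte - 127).
    # 0xFF is the reserved NaN encoding per the MX spec (assumption noted above).
    if byte_val == 0xFF:
        return float("nan")
    return 2.0 ** (byte_val - 127)
```

So a stored byte of 127 decodes to a scale of 1.0, and every increment doubles the scale; this is why the scale tensors are plain uint8 on the PyTorch side.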
🛠️ Refactor suggestion | 🟠 Major
Add input tensor and scale tensor validation to prevent silent failures
This function hard-codes K-major strides (`stride[-1] == 1` for A) and specific block-scale shapes when constructing the cuDNN graph, but performs no runtime validation that the actual input tensors `a` and `b` have the expected layout, or that the scale tensors match the computed dimensions. Add defensive checks:

- Before graph construction, validate that the input strides match K-major expectations (e.g., `a_stride[-1] == 1` and `b_stride[1] == 1`, since B is column-major).
- In the calling code (or here, if the scale tensors are accessible), assert that scale tensor element counts match `block_scale_dim_m * block_scale_dim_k` for A and `block_scale_dim_k * block_scale_dim_n` for B.
These checks ensure early, clear failures if quantization or layout assumptions are violated, rather than silent cuDNN errors or incorrect results.
Based on past review comments indicating this validation gap.
🧰 Tools
🪛 Ruff (0.14.10)
3836-3836: Avoid specifying long messages outside the exception class
(TRY003)
3838-3838: Avoid specifying long messages outside the exception class
(TRY003)
3841-3841: Avoid specifying long messages outside the exception class
(TRY003)
3843-3843: Avoid specifying long messages outside the exception class
(TRY003)
3845-3845: Avoid specifying long messages outside the exception class
(TRY003)
```python
def bmm_mxfp8(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    dtype: torch.dtype,
    out: Optional[torch.Tensor] = None,
    backend: Literal["cudnn"] = "cudnn",
) -> torch.Tensor:
    r"""BMM MXFP8

    Parameters
    ----------
    A: torch.Tensor
        Input tensor, shape (b, m, k), fp8 e4m3 or fp8 e5m2.

    B: torch.Tensor
        Mat2 tensor, shape (b, k, n), should be column major, fp8 e4m3 or fp8 e5m2.

    A_scale: torch.Tensor
        Scale tensor for A, uint8 (fp8 e8m0 format).

    B_scale: torch.Tensor
        Scale tensor for B, uint8 (fp8 e8m0 format).

    dtype: torch.dtype
        out dtype, bf16 or fp16.

    out: Optional[torch.Tensor]
        Out tensor, shape (b, m, n), bf16 or fp16, defaults to ``None``.

    backend: Literal["cudnn"]
        The backend to use for the operation. Defaults to ``"cudnn"``.

    Returns
    -------
    out: torch.Tensor
        Out tensor, shape (b, m, n), bf16 or fp16.
    """
    if backend != "cudnn":
        raise ValueError(f"Invalid backend: {backend}")

    if not CUDNN_AVAILABLE:
        raise ValueError("cudnn is not available")

    if out is None:
        out = torch.empty(
            (A.shape[0], A.shape[1], B.shape[2]),
            device=A.device,
            dtype=dtype,
        )

    workspace_buffer = _get_cache_buf(
        "bmm_mxfp8_workspace", DEFAULT_WORKSPACE_SIZE, A.device
    )

    mxfp8_gemm_sm100(A, B, A_scale, B_scale, out, workspace_buffer, ["cudnn"])
    return out
```
🛠️ Refactor suggestion | 🟠 Major
Add input tensor validation to ensure correct dtypes and scale tensor shapes
The function should validate that:
- `A` and `B` have FP8 dtypes (`torch.float8_e4m3fn` or `torch.float8_e5m2`)
- `A_scale` and `B_scale` have dtype `torch.uint8` (FP8_E8M0 format)
- Scale tensor shapes match the expected block-scale dimensions computed from A and B shapes
These checks would fail fast on incorrect inputs rather than producing cuDNN errors or incorrect results downstream.
🔎 Example validation to add after line 4195
```diff
+    # Validate input dtypes
+    if A.dtype not in (torch.float8_e4m3fn, torch.float8_e5m2):
+        raise ValueError(
+            f"A must have FP8 dtype (torch.float8_e4m3fn or torch.float8_e5m2), got {A.dtype}"
+        )
+    if B.dtype not in (torch.float8_e4m3fn, torch.float8_e5m2):
+        raise ValueError(
+            f"B must have FP8 dtype (torch.float8_e4m3fn or torch.float8_e5m2), got {B.dtype}"
+        )
+
+    # Validate scale tensor dtypes
+    if A_scale.dtype != torch.uint8:
+        raise ValueError(f"A_scale must be uint8 (FP8_E8M0), got {A_scale.dtype}")
+    if B_scale.dtype != torch.uint8:
+        raise ValueError(f"B_scale must be uint8 (FP8_E8M0), got {B_scale.dtype}")
+
+    # Validate scale tensor shapes
+    block_size = 32  # MXFP8 block size
+    b_dim, m, k = A.shape
+    n = B.shape[2]
+    block_scale_dim_m, block_scale_dim_n, block_scale_dim_k = (
+        _calculate_block_scale_dims(m, n, k, block_size)
+    )
+    expected_a_scale_shape = (b_dim, block_scale_dim_m, block_scale_dim_k)
+    expected_b_scale_shape = (b_dim, block_scale_dim_k, block_scale_dim_n)
+    if A_scale.shape != expected_a_scale_shape:
+        raise ValueError(
+            f"A_scale shape mismatch. Expected {expected_a_scale_shape}, got {A_scale.shape}"
+        )
+    if B_scale.shape != expected_b_scale_shape:
+        raise ValueError(
+            f"B_scale shape mismatch. Expected {expected_b_scale_shape}, got {B_scale.shape}"
+        )
```

🧰 Tools
🪛 Ruff (0.14.10)
4185-4185: Avoid specifying long messages outside the exception class
(TRY003)
4188-4188: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In flashinfer/gemm/gemm_base.py around lines 4144 to 4202, add input validation
after the docstring (around line 4195): verify A.dtype and B.dtype are one of
torch.float8_e4m3fn or torch.float8_e5m2 and raise TypeError otherwise; verify
A_scale.dtype and B_scale.dtype are torch.uint8 and raise TypeError otherwise;
validate A_scale and B_scale shapes match the expected block-scale dimensions
derived from A and B (compute expected scale shapes from A.shape and B.shape
using the same FP8 blocking logic used by mxfp8_gemm_sm100 or a small helper:
blocks_k = ceil_div(A.shape[2], FP8_SCALE_BLOCK), expected_A_scale_shape =
(A.shape[0], blocks_k) and expected_B_scale_shape = (B.shape[0],
ceil_div(B.shape[2], FP8_SCALE_BLOCK)) or use the existing helper if present),
and raise ValueError with clear messages if shapes mismatch.
Invalid use of get_workspace_size.
Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
Actionable comments posted: 1
🧹 Nitpick comments (1)
flashinfer/gemm/gemm_base.py (1)
3819-3946: Consider adding stride validation for the K-major layout requirement

The function hard-codes K-major strides in the cuDNN graph (lines 3856-3871 show the stride configuration, with a `stride[-1] == 1` expectation for A). While this will work correctly when used with `mxfp8_quantize`, adding defensive checks would help catch misuse:

```python
# At the start of the function, after shape validation:
if a_stride[-1] != 1:
    raise ValueError(f"A must have K-major layout (stride[-1] == 1), got stride {a_stride}")
if b_stride[1] != 1:
    raise ValueError(f"B must be column-major (stride[1] == 1), got stride {b_stride}")
```

This is a defensive measure: correct usage won't trigger these checks, but they would provide clear error messages for debugging if tensors with unexpected layouts are passed.
Based on past review comments.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
flashinfer/gemm/gemm_base.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
flashinfer/gemm/gemm_base.py
🪛 Ruff (0.14.10)
flashinfer/gemm/gemm_base.py
3832-3832: Avoid specifying long messages outside the exception class
(TRY003)
3834-3834: Avoid specifying long messages outside the exception class
(TRY003)
3837-3837: Avoid specifying long messages outside the exception class
(TRY003)
3839-3839: Avoid specifying long messages outside the exception class
(TRY003)
3841-3841: Avoid specifying long messages outside the exception class
(TRY003)
3952-3952: Unused function argument: out
(ARG001)
4015-4015: Unused method argument: inputs
(ARG002)
4016-4016: Unused method argument: profile
(ARG002)
4026-4026: Unused method argument: do_preparation
(ARG002)
4027-4027: Unused method argument: kwargs
(ARG002)
4073-4073: Unused function argument: A
(ARG001)
4074-4074: Unused function argument: B
(ARG001)
4075-4075: Unused function argument: A_scale
(ARG001)
4076-4076: Unused function argument: B_scale
(ARG001)
4077-4077: Unused function argument: dtype
(ARG001)
4078-4078: Unused function argument: out
(ARG001)
4079-4079: Unused function argument: backend
(ARG001)
4088-4091: Avoid specifying long messages outside the exception class
(TRY003)
4097-4097: Unused function argument: A_scale
(ARG001)
4098-4098: Unused function argument: B_scale
(ARG001)
4100-4100: Unused function argument: out
(ARG001)
4101-4101: Unused function argument: backend
(ARG001)
4106-4106: Avoid specifying long messages outside the exception class
(TRY003)
4108-4110: Avoid specifying long messages outside the exception class
(TRY003)
4118-4118: Unused function argument: A
(ARG001)
4119-4119: Unused function argument: B
(ARG001)
4120-4120: Unused function argument: A_scale
(ARG001)
4121-4121: Unused function argument: B_scale
(ARG001)
4122-4122: Unused function argument: dtype
(ARG001)
4123-4123: Unused function argument: out
(ARG001)
4124-4124: Unused function argument: backend
(ARG001)
4181-4181: Avoid specifying long messages outside the exception class
(TRY003)
4184-4184: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (7)
flashinfer/gemm/gemm_base.py (7)
1339-1372: LGTM!
The execution function correctly handles the variant pack setup and workspace buffer sizing. The removal of the `.view(torch.float8_e4m3fn)` calls (flagged in a previous review) is correct since MXFP8 tensors should already have the appropriate dtype.

3795-3817: LGTM!
The block scale dimension calculation correctly applies the indestructible block formula with proper ceiling division.

4045-4068: Verify tuning configuration is appropriate for MXFP8
The function reuses `_FP8_GEMM_SM100_TUNING_CONFIG`, which was designed for FP8 GEMM. While the tensor layout is similar (3D with a dynamic M dimension), MXFP8's block-scaled nature might benefit from different bucketing strategies in the future.
The current implementation is functional, but consider adding a dedicated `_MXFP8_GEMM_SM100_TUNING_CONFIG` if profiling shows different optimal configurations.

4011-4043: LGTM!
The runner follows the same pattern as `_cudnn_gemm_fp8_runner`, using `[0]` as the default tactic, which delegates to cuDNN's internal heuristics. The TODO comment can be removed since this is consistent with the existing FP8 implementation.

4140-4198: LGTM!
The public API is well documented and follows the established patterns from `bmm_fp8`. The explicit backend and cuDNN availability checks provide clear error messages despite some redundancy with the decorator validation.

3948-4008: LGTM!
The graph creation and execution functions follow the established FP4 GEMM patterns. The unused `out` parameter (flagged by static analysis) is retained for API consistency, the same pattern used in `_get_cudnn_fp4_gemm_graph`.

4071-4091: LGTM!
The requirement checker and dtype validator follow the established patterns from the FP8/FP4 implementations. The unused function arguments (flagged by static analysis) are required by the `@backend_requirement` decorator's expected function signature.
```python
def _check_bmm_mxfp8_problem_size(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    dtype: torch.dtype,
    out: Optional[torch.Tensor] = None,
    backend: Literal["cudnn"] = "cudnn",
):
    # Check input tensors
    if A.ndim != 3 or B.ndim != 3:
        # A is [b, m, k], B is [b, k, n]
        raise ValueError(f"bmm_mxfp8 accepts 3d tensors, got {A.shape=} and {B.shape=}")
    if A.shape[2] != B.shape[1]:
        raise ValueError(
            f"K dimension (last dim of A) mismatch in bmm_mxfp8. got {A.shape=}, {B.shape=}"
        )

    _validate_mxfp8_output_dtype(dtype)
    return True
```
Add batch dimension and input dtype validation
The function validates K dimensions but misses batch dimension matching and input dtype validation. Consider adding:
🔎 Proposed validation additions
```diff
 def _check_bmm_mxfp8_problem_size(
     A: torch.Tensor,
     B: torch.Tensor,
     A_scale: torch.Tensor,
     B_scale: torch.Tensor,
     dtype: torch.dtype,
     out: Optional[torch.Tensor] = None,
     backend: Literal["cudnn"] = "cudnn",
 ):
     # Check input tensors
     if A.ndim != 3 or B.ndim != 3:
         # A is [b, m, k], B is [b, k, n]
         raise ValueError(f"bmm_mxfp8 accepts 3d tensors, got {A.shape=} and {B.shape=}")
+    if A.shape[0] != B.shape[0]:
+        raise ValueError(
+            f"Batch dimension mismatch in bmm_mxfp8. got {A.shape[0]=}, {B.shape[0]=}"
+        )
     if A.shape[2] != B.shape[1]:
         raise ValueError(
             f"K dimension (last dim of A) mismatch in bmm_mxfp8. got {A.shape=}, {B.shape=}"
         )
+    # Validate input dtypes
+    if A.dtype not in (torch.float8_e4m3fn, torch.float8_e5m2):
+        raise ValueError(f"A must have FP8 dtype, got {A.dtype}")
+    if B.dtype not in (torch.float8_e4m3fn, torch.float8_e5m2):
+        raise ValueError(f"B must have FP8 dtype, got {B.dtype}")
+    if A_scale.dtype != torch.uint8:
+        raise ValueError(f"A_scale must be uint8 (FP8_E8M0), got {A_scale.dtype}")
+    if B_scale.dtype != torch.uint8:
+        raise ValueError(f"B_scale must be uint8 (FP8_E8M0), got {B_scale.dtype}")
     _validate_mxfp8_output_dtype(dtype)
     return True
```

These checks help fail fast with clear error messages rather than producing cuDNN errors or incorrect results downstream. Based on past review comments.
🧰 Tools
🪛 Ruff (0.14.10)
4097-4097: Unused function argument: A_scale
(ARG001)
4098-4098: Unused function argument: B_scale
(ARG001)
4100-4100: Unused function argument: out
(ARG001)
4101-4101: Unused function argument: backend
(ARG001)
4106-4106: Avoid specifying long messages outside the exception class
(TRY003)
4108-4110: Avoid specifying long messages outside the exception class
(TRY003)
[FAILED] Pipeline #40747548: 1/20 passed

The 33 test_bmm_fp8 failures in the pipeline were expected from main.

/bot run
📌 Description
Add support for GEMM with MXFP8 (`bmm_mxfp8`). At this time only cuDNN is supported.

Added test `tests/gemm/test_bmm_mxfp8.py`.
Added routine `bmm_mxfp8` to `flashinfer_benchmark`.

Benchmark results for `bmm_mxfp8` (on a B200 GPU):

And `bmm_fp8` for comparison:

When running `ncu`, the kernel `nvjet_sm100_qqtst_128x256_128x6_2x1_2cta_v_bz_Avec32UE8M0_Bvec32UE8M0_NNT` seems to trigger.

🔍 Related Issues

#2209
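For intuition about what the MXFP8 path computes, the per-block scaling can be illustrated with a toy reference (a sketch for illustration only, not the library's code): each run of 32 quantized values along K shares one decoded E8M0 scale.

```python
def block_dequant_row(q_row, scales, block_size=32):
    # Toy per-row MXFP8-style dequant: element i is multiplied by the decoded
    # scale of its block (block index = i // block_size).
    return [q * scales[i // block_size] for i, q in enumerate(q_row)]
```

A matmul over such rows is then an ordinary fp32 dot product of the dequantized values, which is exactly what the cuDNN graph's block_scale_dequantize + matmul sequence expresses.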
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Tests
Chores