feat: BF16 GEMM using CUTLASS backend for SM100 #2070
base: main
Conversation
Walkthrough
Adds SM100 Cutlass BF16 GEMM support: new CUDA runner and FFI entry points, SM100-specific kernel templates and Jinja instantiations, public Python mm_bf16/bmm_bf16 APIs with a JIT generator, headers for the runner/configs, tests, and documentation. Includes workspace/config probing and tactic selection.
Sequence Diagram(s)
sequenceDiagram
participant Py as Python Test / User
participant API as mm_bf16 / bmm_bf16
participant BF16SM as bf16_gemm_sm100
participant JIT as gen_gemm_sm100_module_cutlass_bf16
participant FFI as CUDA FFI (bf16_gemm)
participant Runner as CutlassBf16GemmRunner
participant Kernel as Cutlass kernel
Py->>API: call mm_bf16 / bmm_bf16(a,b,out_dtype)
API->>API: validate shapes & dtype\nallocate/validate out & workspace
API->>BF16SM: bf16_gemm_sm100(a,b,out,workspace)
BF16SM->>JIT: load/get SM100 BF16 module (JIT)
JIT-->>BF16SM: module with bf16_gemm_runner
BF16SM->>FFI: call bf16_gemm(...) via FFI (tactic)
FFI->>Runner: getBf16GemmConfig -> choose tactic/config
FFI->>Runner: runGemm<T>() -> compute workspace & launch
Runner->>Kernel: launch Cutlass kernel on stream
Kernel-->>Runner: kernel complete
Runner-->>FFI: return status
FFI-->>BF16SM: completed
BF16SM-->>API: done
API-->>Py: return output
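To make the flow above concrete, here is a hypothetical usage sketch based on the tests in this PR; the exact signature lives in flashinfer/gemm/gemm_base.py and may differ in detail (e.g., the optional out argument).

```python
import torch
import flashinfer

# Hypothetical example (shapes chosen arbitrarily): multiply two bfloat16
# matrices on a supported GPU; mm_bf16 dispatches to the CUTLASS SM100
# backend internally and returns a bfloat16 result.
a = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
b = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)

out = flashinfer.mm_bf16(a, b, out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([128, 512])
```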
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~45 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning), ✅ Passed checks (2 passed)
Currently there is an error about the second matrix being non-contiguous: RuntimeError: Check failed: (mat2.IsContiguous()) is false: mat2 must be contiguous.
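A minimal sketch of what triggers this kind of error and how it is typically resolved (shapes are arbitrary; this only illustrates the contiguity issue, not the PR's final fix):

```python
import torch

b = torch.randn(128, 256, dtype=torch.bfloat16)

b_view = b.transpose(-2, -1)       # a strided view, not a new buffer
assert not b_view.is_contiguous()  # this is what the C++ check rejects

b_col_major = b_view.contiguous()  # materialize the column-major copy
assert b_col_major.is_contiguous() # safe to hand to a binding that requires contiguity
```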
Force-pushed: dd6216f to aaaee56
Actionable comments posted: 6
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (13)
- csrc/bf16_gemm_cutlass.cu (1 hunks)
- csrc/bf16_gemm_cutlass.jinja (1 hunks)
- docs/api/gemm.rst (1 hunks)
- flashinfer/__init__.py (1 hunks)
- flashinfer/gemm/gemm_base.py (4 hunks)
- flashinfer/jit/gemm/__init__.py (2 hunks)
- flashinfer/jit/gemm/core.py (1 hunks)
- include/flashinfer/gemm/bf16_gemm_cutlass.h (1 hunks)
- include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1 hunks)
- include/flashinfer/gemm/bf16_gemm_template_sm100.h (1 hunks)
- include/flashinfer/gemm/fp8_gemm_cutlass_template.h (0 hunks)
- tests/gemm/test_bmm_bf16.py (1 hunks)
- tests/gemm/test_mm_bf16.py (1 hunks)
💤 Files with no reviewable changes (1)
- include/flashinfer/gemm/fp8_gemm_cutlass_template.h
🧰 Additional context used
🧬 Code graph analysis (10)
flashinfer/jit/gemm/__init__.py (1)
flashinfer/jit/gemm/core.py (1)
gen_gemm_sm100_module_cutlass_bf16(193-237)
include/flashinfer/gemm/bf16_gemm_cutlass.h (3)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
flashinfer(41-134)gemm(42-91)include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
gemm(44-176)flashinfer/gemm/gemm_base.py (1)
CutlassBf16GemmRunner(497-520)
flashinfer/gemm/gemm_base.py (5)
flashinfer/jit/gemm/core.py (2)
gen_gemm_sm100_module(240-316)gen_gemm_sm100_module_cutlass_bf16(193-237)flashinfer/utils.py (3)
supported_compute_capability(773-853)_get_cache_buf(205-211)get_compute_capability(252-255)flashinfer/autotuner.py (7)
TunableRunner(194-247)OptimizationProfile(168-183)AutoTuner(335-784)TuningConfig(101-141)DynamicTensorSpec(41-82)ConstraintSpec(86-97)choose_one(400-529)csrc/bf16_gemm_cutlass.cu (4)
bf16_gemm_tactic_num(149-156)bf16_gemm_tactic_num(149-149)bf16_gemm(144-147)bf16_gemm(144-145)flashinfer/fused_moe/utils.py (2)
get_last_power_of_2_num_tokens_buckets(206-215)last_positive_power_of_2(183-188)
tests/gemm/test_bmm_bf16.py (3)
flashinfer/autotuner.py (1)
autotune(251-262)flashinfer/gemm/gemm_base.py (1)
bmm_bf16(250-313)flashinfer/utils.py (2)
get_compute_capability(252-255)is_compute_capability_supported(979-994)
flashinfer/__init__.py (1)
flashinfer/gemm/gemm_base.py (2)
bmm_bf16(250-313)mm_bf16(183-246)
tests/gemm/test_mm_bf16.py (3)
flashinfer/autotuner.py (1)
autotune(251-262)flashinfer/gemm/gemm_base.py (1)
mm_bf16(183-246)flashinfer/utils.py (2)
get_compute_capability(252-255)is_compute_capability_supported(979-994)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (4)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
flashinfer(26-60)gemm(27-59)include/flashinfer/gemm/fp8_gemm_cutlass_template.h (4)
flashinfer(41-145)gemm(42-95)std(184-184)std(185-185)include/flashinfer/gemm/bf16_gemm_template_sm100.h (6)
gemm(44-176)_1SM(53-57)cutlass(135-135)cutlass(136-136)cutlass(137-137)cutlass(138-138)include/flashinfer/gemm/cutlass_gemm_configs.h (1)
CutlassTileConfigSM100(106-425)
flashinfer/jit/gemm/core.py (2)
flashinfer/jit/core.py (2)
JitSpec(213-312)gen_jit_spec(315-381)flashinfer/compilation_context.py (1)
get_nvcc_flags_list(50-68)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (2)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
flashinfer(26-60)gemm(27-59)include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
flashinfer(41-134)gemm(42-91)
csrc/bf16_gemm_cutlass.cu (4)
flashinfer/gemm/gemm_base.py (1)
CutlassBf16GemmRunner(497-520)include/flashinfer/gemm/bf16_gemm_cutlass.h (1)
CutlassBf16GemmRunnerInterface(29-41)include/flashinfer/gemm/cutlass_gemm_configs.h (1)
CutlassTileConfigSM100(106-425)csrc/tvm_ffi_utils.h (2)
get_stream(272-274)encode_dlpack_dtype(29-31)
🪛 Clang (14.0.6)
include/flashinfer/gemm/bf16_gemm_cutlass.h
[error] 20-20: 'cuda_runtime_api.h' file not found
(clang-diagnostic-error)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h
[error] 23-23: 'cutlass/arch/arch.h' file not found
(clang-diagnostic-error)
include/flashinfer/gemm/bf16_gemm_template_sm100.h
[error] 23-23: 'cutlass/arch/arch.h' file not found
(clang-diagnostic-error)
🪛 GitHub Actions: pre-commit
flashinfer/__init__.py
[error] 88-88: mypy: Module "flashinfer.gemm" has no attribute "bmm_bf16".
[error] 90-90: mypy: Module "flashinfer.gemm" has no attribute "mm_bf16".
🪛 Ruff (0.14.3)
flashinfer/gemm/gemm_base.py
218-218: Avoid specifying long messages outside the exception class
(TRY003)
220-220: Avoid specifying long messages outside the exception class
(TRY003)
230-232: Avoid specifying long messages outside the exception class
(TRY003)
234-236: Avoid specifying long messages outside the exception class
(TRY003)
238-240: Avoid specifying long messages outside the exception class
(TRY003)
284-284: Avoid specifying long messages outside the exception class
(TRY003)
286-286: Avoid specifying long messages outside the exception class
(TRY003)
297-299: Avoid specifying long messages outside the exception class
(TRY003)
301-303: Avoid specifying long messages outside the exception class
(TRY003)
305-307: Avoid specifying long messages outside the exception class
(TRY003)
500-500: Unused method argument: inputs
(ARG002)
501-501: Unused method argument: profile
(ARG002)
509-509: Unused method argument: do_preparation
(ARG002)
510-510: Unused method argument: kwargs
(ARG002)
592-592: Avoid specifying long messages outside the exception class
(TRY003)
flashinfer/jit/gemm/core.py
233-233: Consider [*nvcc_flags, "-DENABLE_BF16"] instead of concatenation
Replace with [*nvcc_flags, "-DENABLE_BF16"]
(RUF005)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (2)
docs/api/gemm.rst (2)
10-17: Documentation formatting is consistent and well-structured. The BF16 GEMM subsection follows the established pattern of other GEMM sections in the file (consistent indentation, autosummary directive, toctree configuration). Placement at the beginning of the GEMM API documentation is logical and appropriate.
10-17: Documentation is complete and accurate. The BF16 GEMM subsection correctly documents mm_bf16 and bmm_bf16 — these are the only public-facing BF16 GEMM functions (verified by top-level exports in flashinfer/__init__.py). The bf16_gemm mentioned in the PR summary is an internal C++ binding and tuning identifier, not a public Python API.
Actionable comments posted: 3
♻️ Duplicate comments (1)
flashinfer/gemm/gemm_base.py (1)
512-520: Materialize the transposed tensor before passing to the CUTLASS runner. This is the root cause of the runtime error reported in the PR: b.transpose(-2, -1) returns a non-contiguous view, but the C++ binding requires contiguous input. The fix is to call .contiguous() on the transposed tensor. Apply this fix:

  a, b, out, workspace_buffer = inputs
  module.bf16_gemm(
      a,
-     b.transpose(-2, -1),
+     b.transpose(-2, -1).contiguous(),
      out,
      workspace_buffer,
      tactic,
  )

This issue was already identified in the previous review.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- flashinfer/gemm/__init__.py (2 hunks)
- flashinfer/gemm/gemm_base.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/gemm/gemm_base.py (7)
include/flashinfer/gemm/bf16_gemm_cutlass.h (1)
gemm(27-59)include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
gemm(44-176)flashinfer/jit/gemm/core.py (1)
gen_gemm_sm100_module_cutlass_bf16(193-237)flashinfer/utils.py (3)
supported_compute_capability(773-853)_get_cache_buf(205-211)get_compute_capability(252-255)flashinfer/autotuner.py (7)
TunableRunner(194-247)OptimizationProfile(168-183)AutoTuner(335-784)TuningConfig(101-141)DynamicTensorSpec(41-82)ConstraintSpec(86-97)choose_one(400-529)csrc/bf16_gemm_cutlass.cu (4)
bf16_gemm_tactic_num(149-156)bf16_gemm_tactic_num(149-149)bf16_gemm(144-147)bf16_gemm(144-145)flashinfer/fused_moe/utils.py (2)
get_last_power_of_2_num_tokens_buckets(206-215)last_positive_power_of_2(183-188)
🪛 Ruff (0.14.3)
flashinfer/gemm/gemm_base.py
218-218: Avoid specifying long messages outside the exception class
(TRY003)
220-220: Avoid specifying long messages outside the exception class
(TRY003)
230-232: Avoid specifying long messages outside the exception class
(TRY003)
234-236: Avoid specifying long messages outside the exception class
(TRY003)
238-240: Avoid specifying long messages outside the exception class
(TRY003)
284-284: Avoid specifying long messages outside the exception class
(TRY003)
286-286: Avoid specifying long messages outside the exception class
(TRY003)
297-299: Avoid specifying long messages outside the exception class
(TRY003)
301-303: Avoid specifying long messages outside the exception class
(TRY003)
305-307: Avoid specifying long messages outside the exception class
(TRY003)
500-500: Unused method argument: inputs
(ARG002)
501-501: Unused method argument: profile
(ARG002)
509-509: Unused method argument: do_preparation
(ARG002)
510-510: Unused method argument: kwargs
(ARG002)
592-592: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (1)
flashinfer/gemm/__init__.py (1)
1-38: LGTM! Public API exports are correctly wired. The new BF16 GEMM functions (bmm_bf16 and mm_bf16) are properly imported from gemm_base and exposed through the module's __all__ list, making them available as part of the public API.
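For reference, a minimal sketch of the export wiring being described (illustrative only; the real flashinfer/gemm/__init__.py contains additional exports):

```python
# flashinfer/gemm/__init__.py (sketch)
from .gemm_base import bmm_bf16, mm_bf16

__all__ = [
    "bmm_bf16",
    "mm_bf16",
]
```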
Force-pushed: 511d8e0 to fbe5723
Actionable comments posted: 3
♻️ Duplicate comments (5)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
156-173: Restore workspace probe handling before launching kernels
CutlassBf16GemmRunner::getWorkspaceSize() calls this launcher with workspacePtr == nullptr to learn how big the buffer must be. Today we immediately throw because workspaceBytes == 0, so the probe reports 0 and the next real launch still fails with “insufficient workspace”. Please short‑circuit the probe and return the computed size instead of throwing.

  size_t workspace_size = gemm.get_workspace_size(arguments);
+ if (workspacePtr == nullptr) {
+   return workspace_size;
+ }
  if (workspace_size > workspaceBytes) {

flashinfer/gemm/gemm_base.py (4)
182-246: Validate BF16 MM inputs before dispatch. The CUTLASS path assumes 2‑D bf16 matrices on the same CUDA device with contiguous row‑major layout. Without the early guards we can accept the wrong dtype, mismatched shapes/devices, or a strided view and only fail deep inside the kernel (or produce garbage). Please restore the validation/contiguity fixes before touching the workspace.
- if backend != "cutlass": + if backend != "cutlass": raise ValueError(f"Unsupported backend: {backend}. Only cutlass is available.") if out_dtype not in (torch.bfloat16, torch.float16): raise ValueError("Only bf16 and fp16 outputs are supported.") + + if a.ndim != 2 or b.ndim != 2: + raise ValueError(f"mm_bf16 expects 2D tensors. Got a.ndim={a.ndim}, b.ndim={b.ndim}.") + if a.dtype != torch.bfloat16 or b.dtype != torch.bfloat16: + raise ValueError(f"Inputs must be bfloat16. Got a.dtype={a.dtype}, b.dtype={b.dtype}.") + if a.shape[1] != b.shape[0]: + raise ValueError( + f"Shape mismatch for matrix multiplication. a.shape[1]={a.shape[1]} must equal b.shape[0]={b.shape[0]}." + ) + if a.device != b.device: + raise ValueError(f"Device mismatch. a.device={a.device}, b.device={b.device}.") + if not a.is_contiguous(): + a = a.contiguous() + if not b.is_contiguous(): + b = b.contiguous() + if out is not None and not out.is_contiguous(): + raise ValueError("Output tensor must be contiguous for the CUTLASS backend.")
249-313: Do the same validation for BMM. The batched entry point has the same holes: wrong dtype, rank, device, or non‑contiguous slices go straight into CUTLASS and fail later (or worse, corrupt results). Please add the missing checks for 3‑D tensors, matching batch/K dims, same device, and enforce contiguity before launching.
- if backend != "cutlass": + if backend != "cutlass": raise ValueError(f"Unsupported backend: {backend}. Only cutlass is available.") if out_dtype not in (torch.bfloat16, torch.float16): raise ValueError("Only bf16 and fp16 outputs are supported.") + + if a.ndim != 3 or b.ndim != 3: + raise ValueError(f"bmm_bf16 expects 3D tensors. Got a.ndim={a.ndim}, b.ndim={b.ndim}.") + if a.dtype != torch.bfloat16 or b.dtype != torch.bfloat16: + raise ValueError(f"Inputs must be bfloat16. Got a.dtype={a.dtype}, b.dtype={b.dtype}.") + if a.shape[0] != b.shape[0]: + raise ValueError( + f"Batch size mismatch. a.shape[0]={a.shape[0]} must equal b.shape[0]={b.shape[0]}." + ) + if a.shape[2] != b.shape[1]: + raise ValueError( + f"K dimension mismatch. a.shape[2]={a.shape[2]} must equal b.shape[1]={b.shape[1]}." + ) + if a.device != b.device: + raise ValueError(f"Device mismatch. a.device={a.device}, b.device={b.device}.") + if not a.is_contiguous(): + a = a.contiguous() + if not b.is_contiguous(): + b = b.contiguous() + if out is not None and not out.is_contiguous(): + raise ValueError("Output tensor must be contiguous for the CUTLASS backend.")
512-520: Materialize column‑major B before calling the kernel
bf16_gemm still receives b.transpose(-2, -1) directly, which is a non‑contiguous view and reproduces the runtime error (“mat2 must be contiguous”). Please allocate the column‑major buffer before dispatching to CUTLASS.

- module.bf16_gemm(
-     a,
-     b.transpose(-2, -1),
-     out,
-     workspace_buffer,
-     tactic,
- )
+ b_col_major = b.transpose(-2, -1).contiguous()
+ module.bf16_gemm(
+     a,
+     b_col_major,
+     out,
+     workspace_buffer,
+     tactic,
+ )
590-592: Report the actual device when no runner is found. When a lives on a non‑default GPU, torch.device("cuda") queries device 0 and we raise “sm100” even if the tensor was on sm90. Use the tensor’s device so the error reflects reality.

- major, minor = get_compute_capability(torch.device("cuda"))
+ major, minor = get_compute_capability(a.device)
🧹 Nitpick comments (1)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)
54-91: Cluster shape dispatch with limited configuration support. The function correctly dispatches based on cluster shape, with appropriate error handling for unsupported configurations. The limitation to only ClusterShape_1x1x1 aligns with the PR author's note about tile size and SMEM constraints during initial development.
Note: Line 66 has a break statement after return, which is unreachable but harmless.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (14)
- csrc/bf16_gemm_cutlass.cu (1 hunks)
- csrc/bf16_gemm_cutlass.jinja (1 hunks)
- docs/api/gemm.rst (1 hunks)
- flashinfer/__init__.py (1 hunks)
- flashinfer/gemm/__init__.py (2 hunks)
- flashinfer/gemm/gemm_base.py (4 hunks)
- flashinfer/jit/gemm/__init__.py (2 hunks)
- flashinfer/jit/gemm/core.py (1 hunks)
- include/flashinfer/gemm/bf16_gemm_cutlass.h (1 hunks)
- include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1 hunks)
- include/flashinfer/gemm/bf16_gemm_template_sm100.h (1 hunks)
- include/flashinfer/gemm/fp8_gemm_cutlass_template.h (0 hunks)
- tests/gemm/test_bmm_bf16.py (1 hunks)
- tests/gemm/test_mm_bf16.py (1 hunks)
💤 Files with no reviewable changes (1)
- include/flashinfer/gemm/fp8_gemm_cutlass_template.h
🚧 Files skipped from review as they are similar to previous changes (6)
- tests/gemm/test_mm_bf16.py
- flashinfer/gemm/__init__.py
- csrc/bf16_gemm_cutlass.jinja
- tests/gemm/test_bmm_bf16.py
- flashinfer/jit/gemm/__init__.py
- csrc/bf16_gemm_cutlass.cu
🧰 Additional context used
🧬 Code graph analysis (6)
flashinfer/__init__.py (1)
flashinfer/gemm/gemm_base.py (2)
bmm_bf16(250-313)mm_bf16(183-246)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (3)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
flashinfer(26-60)gemm(27-59)include/flashinfer/gemm/bf16_gemm_template_sm100.h (6)
gemm(44-176)_1SM(53-57)cutlass(135-135)cutlass(136-136)cutlass(137-137)cutlass(138-138)include/flashinfer/gemm/cutlass_gemm_configs.h (1)
CutlassTileConfigSM100(106-425)
flashinfer/jit/gemm/core.py (2)
flashinfer/jit/core.py (2)
JitSpec(213-312)gen_jit_spec(315-381)flashinfer/compilation_context.py (1)
get_nvcc_flags_list(50-68)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
flashinfer(41-134)gemm(42-91)include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
gemm(44-176)
flashinfer/gemm/gemm_base.py (5)
flashinfer/jit/gemm/core.py (1)
gen_gemm_sm100_module_cutlass_bf16(193-237)flashinfer/utils.py (3)
supported_compute_capability(773-853)_get_cache_buf(205-211)get_compute_capability(252-255)flashinfer/autotuner.py (5)
TunableRunner(194-247)OptimizationProfile(168-183)AutoTuner(335-784)TuningConfig(101-141)choose_one(400-529)csrc/bf16_gemm_cutlass.cu (4)
bf16_gemm_tactic_num(149-156)bf16_gemm_tactic_num(149-149)bf16_gemm(144-147)bf16_gemm(144-145)flashinfer/fused_moe/utils.py (2)
get_last_power_of_2_num_tokens_buckets(206-215)last_positive_power_of_2(183-188)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (2)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
flashinfer(26-60)gemm(27-59)include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
flashinfer(41-134)gemm(42-91)
🪛 Clang (14.0.6)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h
[error] 23-23: 'cutlass/arch/arch.h' file not found
(clang-diagnostic-error)
include/flashinfer/gemm/bf16_gemm_cutlass.h
[error] 20-20: 'cuda_runtime_api.h' file not found
(clang-diagnostic-error)
include/flashinfer/gemm/bf16_gemm_template_sm100.h
[error] 23-23: 'cutlass/arch/arch.h' file not found
(clang-diagnostic-error)
🪛 Ruff (0.14.4)
flashinfer/jit/gemm/core.py
233-233: Consider [*nvcc_flags, "-DENABLE_BF16"] instead of concatenation
Replace with [*nvcc_flags, "-DENABLE_BF16"]
(RUF005)
flashinfer/gemm/gemm_base.py
218-218: Avoid specifying long messages outside the exception class
(TRY003)
220-220: Avoid specifying long messages outside the exception class
(TRY003)
230-232: Avoid specifying long messages outside the exception class
(TRY003)
234-236: Avoid specifying long messages outside the exception class
(TRY003)
238-240: Avoid specifying long messages outside the exception class
(TRY003)
284-284: Avoid specifying long messages outside the exception class
(TRY003)
286-286: Avoid specifying long messages outside the exception class
(TRY003)
297-299: Avoid specifying long messages outside the exception class
(TRY003)
301-303: Avoid specifying long messages outside the exception class
(TRY003)
305-307: Avoid specifying long messages outside the exception class
(TRY003)
500-500: Unused method argument: inputs
(ARG002)
501-501: Unused method argument: profile
(ARG002)
509-509: Unused method argument: do_preparation
(ARG002)
510-510: Unused method argument: kwargs
(ARG002)
592-592: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (6)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
29-41: Well-designed interface for BF16 GEMM runner. The abstract interface provides a clean contract with appropriate virtual methods for GEMM operations, workspace management, and configuration enumeration. The virtual destructor is correctly included for safe polymorphic deletion.
43-57: Template class declaration follows proper separation pattern. The template class declaration correctly inherits from the interface and overrides all pure virtual methods. The separation of declaration (here) and definition (in the template header) is appropriate for template code.
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (4)
19-34: Appropriate diagnostic pragmas for CUTLASS integration. The GCC diagnostic pragmas correctly suppress strict-aliasing warnings around CUTLASS headers, which is necessary since CUTLASS may use type punning internally.
136-143: GEMM implementation correctly delegates to dispatch logic. The implementation properly forwards all parameters to dispatchToArch with appropriate type casting.
186-210: Configuration enumeration with limited initial support. The function correctly enumerates candidate configurations by combining tile configs and cluster shapes. The current limitation to a single tile configuration (CtaShape64x64x128B) and cluster shape (ClusterShape_1x1x1) aligns with the PR objectives and the author's noted constraints regarding SMEM space and limited B200 hardware access for testing.
As additional tile sizes and cluster shapes are validated on SM100 hardware, uncomment the relevant lines to expand the configuration space.
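A small Python sketch of the enumeration pattern described above (names and values are placeholders, not the actual C++ config types):

```python
from itertools import product

# Candidate tile shapes and cluster shapes; only one of each is enabled for
# now, mirroring the CtaShape64x64x128B / ClusterShape_1x1x1 limitation.
TILE_SHAPES = [(64, 64, 128)]
CLUSTER_SHAPES = [(1, 1, 1)]

# Cross the two lists to build the candidate config (tactic) list; adding a
# validated tile or cluster shape automatically grows the search space.
candidate_configs = list(product(TILE_SHAPES, CLUSTER_SHAPES))
print(candidate_configs)  # [((64, 64, 128), (1, 1, 1))]
```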
99-103: Verify the intentional A↔B and m↔n parameter swap is correct for the kernel expectations. The parameter swap pattern dispatchGemmClusterShapeSm100<T, arch, 64, 64, 128>(B, A, static_cast<T*>(D), n, m, k, ...) is applied consistently across all tile configurations in both bf16 and fp8 GEMM implementations. This is paired with explicit layout declarations: LayoutA = RowMajor and LayoutB = ColumnMajor. While the consistency of this pattern across multiple files strongly suggests it is intentional for layout conversion, please confirm that this parameter reordering matches the actual kernel signature and expectations for dispatchGemmClusterShapeSm100.
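As background for the swap question, the identity usually invoked to justify this kind of operand/extent exchange is the transpose rule for matrix products: computing the product with swapped operands and swapped m/n yields the transposed result, which lets a row-major output be reinterpreted as column-major. Whether that matches this kernel's conventions is exactly what the comment asks to verify.

```latex
C = AB \quad\Longleftrightarrow\quad C^{\mathsf{T}} = B^{\mathsf{T}} A^{\mathsf{T}}
```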
Actionable comments posted: 2
♻️ Duplicate comments (7)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
166-173: Improve MNK hash mixing to avoid collisions. XORing h1 ^ h2 ^ h3 collapses different (m, n, k) permutations to the same bucket, so cached workspace sizes can be reused for incompatible shapes. Combine the hashes with a proper mixer instead.

  struct MNKHash {
    size_t operator()(const MNK& mnk) const {
      auto h1 = std::hash<int>{}(std::get<0>(mnk));
      auto h2 = std::hash<int>{}(std::get<1>(mnk));
      auto h3 = std::hash<int>{}(std::get<2>(mnk));
-     return h1 ^ h2 ^ h3;
+     size_t seed = h1;
+     seed ^= h2 + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
+     seed ^= h3 + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
+     return seed;
    }
  };
175-183: Guard the static workspace cache with a mutex. workspace_hashmap is mutated without synchronization; concurrent calls to getWorkspaceSize will race on find()/operator[]. Protect the cache with a lock.

- static std::unordered_map<MNK, size_t, MNKHash> workspace_hashmap;
+ static std::unordered_map<MNK, size_t, MNKHash> workspace_hashmap;
+ static std::mutex workspace_mutex;
  size_t workspace_size = 0;
- if (workspace_hashmap.find(std::make_tuple(m, n, k)) == workspace_hashmap.end()) {
-   workspace_size = CutlassBf16GemmRunner<T>::getWorkspaceSizeImpl(m, n, k);
-   workspace_hashmap[std::make_tuple(m, n, k)] = workspace_size;
- } else {
-   workspace_size = workspace_hashmap[std::make_tuple(m, n, k)];
- }
+ const MNK key = std::make_tuple(m, n, k);
+ {
+   std::lock_guard<std::mutex> lock(workspace_mutex);
+   auto it = workspace_hashmap.find(key);
+   if (it != workspace_hashmap.end()) {
+     return it->second;
+   }
+   workspace_size = CutlassBf16GemmRunner<T>::getWorkspaceSizeImpl(m, n, k);
+   workspace_hashmap.emplace(key, workspace_size);
+ }
  return workspace_size;

tests/gemm/test_mm_bf16.py (1)
14-21: Skip on CPU-only test environments. get_compute_capability(torch.device("cuda")) raises when CUDA isn’t available, causing the entire suite to error out instead of skipping. Guard this with if not torch.cuda.is_available(): pytest.skip(...) before the capability query.

tests/gemm/test_bmm_bf16.py (1)
15-22: Gracefully skip when CUDA is unavailable. Like the MM test, calling get_compute_capability(torch.device("cuda")) without checking torch.cuda.is_available() hard-fails on CPU-only setups. Add a skip guard before querying the device.

flashinfer/gemm/gemm_base.py (3)
217-240: Validate inputs before firing the kernel. mm_bf16 still accepts tensors with wrong dtype, shape, or device, which the CUTLASS runner interprets incorrectly (e.g., passing fp16 data corrupts results). Add explicit checks for bf16 dtype, matching inner dimensions, and matching devices at the top of the function so misuse fails fast.

+ if a.dtype != torch.bfloat16 or b.dtype != torch.bfloat16:
+     raise ValueError(
+         f"Inputs must be bfloat16. Got a.dtype={a.dtype}, b.dtype={b.dtype}."
+     )
+ if a.ndim != 2 or b.ndim != 2:
+     raise ValueError(
+         f"Inputs must be 2D matrices. Got a.ndim={a.ndim}, b.ndim={b.ndim}."
+     )
+ if a.shape[1] != b.shape[0]:
+     raise ValueError(
+         f"Shape mismatch: a.shape[1]={a.shape[1]} must equal b.shape[0]={b.shape[0]}."
+     )
+ if a.device != b.device:
+     raise ValueError(
+         f"Device mismatch: a.device={a.device}, b.device={b.device}."
+     )
288-307: Add basic sanity checks for batched inputs. bmm_bf16 also needs dtype/shape/device validation; otherwise mismatched batch sizes or wrong K dimensions surface as low-level CUTLASS failures. Please mirror the checks from mm_bf16 for 3D tensors (batch, m, k) and (batch, k, n), ensuring matching batch/K dimensions and bf16 dtype.

+ if A.dtype != torch.bfloat16 or B.dtype != torch.bfloat16:
+     raise ValueError(
+         f"Inputs must be bfloat16. Got A.dtype={A.dtype}, B.dtype={B.dtype}."
+     )
+ if A.ndim != 3 or B.ndim != 3:
+     raise ValueError(
+         f"Inputs must be 3D tensors. Got A.ndim={A.ndim}, B.ndim={B.ndim}."
+     )
+ if A.shape[0] != B.shape[0]:
+     raise ValueError(
+         f"Batch mismatch: A.shape[0]={A.shape[0]} != B.shape[0]={B.shape[0]}."
+     )
+ if A.shape[2] != B.shape[1]:
+     raise ValueError(
+         f"K mismatch: A.shape[2]={A.shape[2]} must equal B.shape[1]={B.shape[1]}."
+     )
+ if A.device != B.device:
+     raise ValueError(
+         f"Device mismatch: A.device={A.device}, B.device={B.device}."
+     )
512-519: Make the transposed B operand contiguous. The CUTLASS binding now enforces mat2.is_contiguous(). Transposing on the fly hands it a strided view and triggers the runtime error you reported. Materialize the column-major buffer before launching.

- module.bf16_gemm(
-     a,
-     b.transpose(-2, -1),
+ b_col_major = b.transpose(-2, -1).contiguous()
+ module.bf16_gemm(
+     a,
+     b_col_major,
      out,
      workspace_buffer,
      tactic,
  )
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
flashinfer/gemm/gemm_base.py(4 hunks)include/flashinfer/gemm/bf16_gemm_cutlass_template.h(1 hunks)tests/gemm/test_bmm_bf16.py(1 hunks)tests/gemm/test_mm_bf16.py(1 hunks)
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.563Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
📚 Learning: 2025-11-12T03:35:17.563Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.563Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
include/flashinfer/gemm/bf16_gemm_cutlass_template.h
flashinfer/gemm/gemm_base.py
🧬 Code graph analysis (4)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (4)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
flashinfer(26-60)gemm(27-59)include/flashinfer/gemm/fp8_gemm_cutlass_template.h (4)
flashinfer(41-145)gemm(42-95)std(184-184)std(185-185)include/flashinfer/gemm/bf16_gemm_template_sm100.h (6)
gemm(44-176)_1SM(53-57)cutlass(135-135)cutlass(136-136)cutlass(137-137)cutlass(138-138)include/flashinfer/gemm/cutlass_gemm_configs.h (1)
CutlassTileConfigSM100(106-425)
tests/gemm/test_mm_bf16.py (2)
flashinfer/gemm/gemm_base.py (1)
mm_bf16(183-246)flashinfer/utils.py (2)
get_compute_capability(252-255)is_compute_capability_supported(979-994)
tests/gemm/test_bmm_bf16.py (3)
flashinfer/autotuner.py (1)
autotune(251-262)flashinfer/gemm/gemm_base.py (1)
bmm_bf16(250-313)flashinfer/utils.py (2)
get_compute_capability(252-255)is_compute_capability_supported(979-994)
flashinfer/gemm/gemm_base.py (8)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)
gemm(42-91)include/flashinfer/gemm/bf16_gemm_cutlass.h (1)
gemm(27-59)include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
gemm(44-176)flashinfer/jit/gemm/core.py (1)
gen_gemm_sm100_module_cutlass_bf16(193-237)flashinfer/utils.py (2)
supported_compute_capability(773-853)_get_cache_buf(205-211)flashinfer/autotuner.py (7)
TunableRunner(194-247)OptimizationProfile(168-183)AutoTuner(335-784)TuningConfig(101-141)DynamicTensorSpec(41-82)ConstraintSpec(86-97)choose_one(400-529)csrc/bf16_gemm_cutlass.cu (4)
bf16_gemm_tactic_num(149-156)bf16_gemm_tactic_num(149-149)bf16_gemm(144-147)bf16_gemm(144-145)flashinfer/fused_moe/utils.py (2)
get_last_power_of_2_num_tokens_buckets(206-215)last_positive_power_of_2(183-188)
🪛 Clang (14.0.6)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h
[error] 23-23: 'cutlass/arch/arch.h' file not found
(clang-diagnostic-error)
🪛 Ruff (0.14.4)
flashinfer/gemm/gemm_base.py
218-218: Avoid specifying long messages outside the exception class
(TRY003)
220-220: Avoid specifying long messages outside the exception class
(TRY003)
230-232: Avoid specifying long messages outside the exception class
(TRY003)
234-236: Avoid specifying long messages outside the exception class
(TRY003)
238-240: Avoid specifying long messages outside the exception class
(TRY003)
284-284: Avoid specifying long messages outside the exception class
(TRY003)
286-286: Avoid specifying long messages outside the exception class
(TRY003)
297-299: Avoid specifying long messages outside the exception class
(TRY003)
301-303: Avoid specifying long messages outside the exception class
(TRY003)
305-307: Avoid specifying long messages outside the exception class
(TRY003)
500-500: Unused method argument: inputs
(ARG002)
501-501: Unused method argument: profile
(ARG002)
509-509: Unused method argument: do_preparation
(ARG002)
510-510: Unused method argument: kwargs
(ARG002)
Force-pushed: d2c8547 to 8a58e45
Actionable comments posted: 0
♻️ Duplicate comments (7)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
156-170: Allow workspace probes to succeed when no buffer is provided.
CutlassBf16GemmRunner::getWorkspaceSizeImplinvokes this launcher withworkspacePtr == nullptrandworkspaceBytes == 0to query the required size. The current code throws before returning the computedworkspace_size, breaking workspace queries. Short-circuit whenworkspacePtrisnullptrto return the size without running the kernel.size_t workspace_size = gemm.get_workspace_size(arguments); + if (workspacePtr == nullptr) { + return workspace_size; + } if (workspace_size > workspaceBytes) { throw std::runtime_error("[Bf16 Gemm Runner] insufficient workspace"); }include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
166-173: Hash function prone to collisions.The
MNKHashfunction uses XOR to combine hash values (h1 ^ h2 ^ h3), which produces collisions for permutations of the same values. For example,(1, 2, 3)and(3, 2, 1)hash identically, potentially returning incorrect workspace sizes.Use a proper hash combining algorithm:
struct MNKHash { size_t operator()(const MNK& mnk) const { auto h1 = std::hash<int>{}(std::get<0>(mnk)); auto h2 = std::hash<int>{}(std::get<1>(mnk)); auto h3 = std::hash<int>{}(std::get<2>(mnk)); - return h1 ^ h2 ^ h3; + // Combine hashes properly to avoid collisions + size_t seed = h1; + seed ^= h2 + 0x9e3779b9 + (seed << 6) + (seed >> 2); + seed ^= h3 + 0x9e3779b9 + (seed << 6) + (seed >> 2); + return seed; } };
175-184: Critical: Data race on static workspace cache.The static
workspace_hashmapat Line 175 is accessed concurrently without synchronization. While C++11+ guarantees thread-safe initialization of function-local statics, concurrent access viafind()(Line 178) andoperator[](Lines 180, 182) creates data races ifgetWorkspaceSizeis called from multiple threads.Protect the map with a mutex:
+ static std::mutex workspace_mutex; static std::unordered_map<MNK, size_t, MNKHash> workspace_hashmap; size_t workspace_size = 0; + std::lock_guard<std::mutex> lock(workspace_mutex); if (workspace_hashmap.find(std::make_tuple(m, n, k)) == workspace_hashmap.end()) {Alternatively, use
std::shared_mutexwith shared (read) and exclusive (write) locking for better concurrent read performance.tests/gemm/test_bmm_bf16.py (1)
15-22: Guard test behind CUDA availability.Calling
get_compute_capability(torch.device("cuda"))without first checkingtorch.cuda.is_available()will raise an exception on non-CUDA systems instead of skipping gracefully.Add an early CUDA check:
def test_bmm_bf16(b, m, n, k, res_dtype): + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") compute_capability = get_compute_capability(torch.device(device="cuda"))flashinfer/gemm/gemm_base.py (3)
182-245: Add input validation for dtype, shape, and device consistency.The function is missing essential input validation that could lead to cryptic errors downstream. Per previous review feedback, please add checks at the beginning of the function:
+ if a.dtype != torch.bfloat16 or b.dtype != torch.bfloat16: + raise ValueError(f"Inputs must be bfloat16. Got a.dtype={a.dtype}, b.dtype={b.dtype}.") + if a.shape[1] != b.shape[0]: + raise ValueError( + f"Shape mismatch for matrix multiplication. " + f"a.shape[1]={a.shape[1]} must equal b.shape[0]={b.shape[0]}." + ) + if a.device != b.device: + raise ValueError(f"Device mismatch. a.device={a.device}, b.device={b.device}.") + if backend != "cutlass":
248-312: Add input validation for dtype, shape, and device consistency.Similar to
mm_bf16, this function lacks essential input validation. Per previous review feedback, please add checks at the beginning:+ if A.dtype != torch.bfloat16 or B.dtype != torch.bfloat16: + raise ValueError(f"Inputs must be bfloat16. Got A.dtype={A.dtype}, B.dtype={B.dtype}.") + if A.ndim != 3 or B.ndim != 3: + raise ValueError(f"Expected 3D tensors. Got A.ndim={A.ndim}, B.ndim={B.ndim}.") + if A.shape[0] != B.shape[0]: + raise ValueError( + f"Batch size mismatch. A.shape[0]={A.shape[0]} != B.shape[0]={B.shape[0]}." + ) + if A.shape[2] != B.shape[1]: + raise ValueError( + f"Shape mismatch for batched matrix multiplication. " + f"A.shape[2]={A.shape[2]} must equal B.shape[1]={B.shape[1]}." + ) + if A.device != B.device: + raise ValueError(f"Device mismatch. A.device={A.device}, B.device={B.device}.") + if backend != "cutlass":
511-519: Make the B operand contiguous before invoking the CUTLASS runner.
transpose(-2, -1)returns a non-contiguous view, which causes the runtime error you reported: "RuntimeError: Check failed: (mat2.IsContiguous()) is false: mat2 must be contiguous". Per previous review feedback, materialize the column-major buffer before launching the kernel:+ b_col_major = b.transpose(-2, -1).contiguous() module.bf16_gemm( a, - b.transpose(-2, -1), + b_col_major, out, workspace_buffer, tactic, )
🧹 Nitpick comments (2)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)
187-211: Config enumeration is clean but limited to one configuration.The
getConfigsimplementation only enumeratesCtaShape64x64x128BandClusterShape_1x1x1, reflecting the WIP status and SMEM constraints. The nested loop pattern is extensible for adding more configs once SMEM issues are resolved.Would you like help generating a script to analyze SMEM usage across different tile configurations to understand which sizes are viable for SM100?
flashinfer/jit/gemm/core.py (1)
193-237: WIP tile configurations are appropriate for initial testing.The implementation correctly follows the established pattern from FP8/FP4 modules. The single active tile configuration (64, 64, 128) is a reasonable conservative choice while debugging SMEM constraints on SM100 hardware, especially given your limited B200 access.
Optional style improvement (flagged by static analysis):
- return gen_jit_spec( - "bf16_gemm_cutlass", - source_paths, - extra_cuda_cflags=nvcc_flags + ["-DENABLE_BF16"], - extra_cflags=[ - "-DFAST_BUILD", - ], - ) + return gen_jit_spec( + "bf16_gemm_cutlass", + source_paths, + extra_cuda_cflags=[*nvcc_flags, "-DENABLE_BF16"], + extra_cflags=[ + "-DFAST_BUILD", + ], + )
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (14)
csrc/bf16_gemm_cutlass.cu(1 hunks)csrc/bf16_gemm_cutlass.jinja(1 hunks)docs/api/gemm.rst(1 hunks)flashinfer/__init__.py(1 hunks)flashinfer/gemm/__init__.py(2 hunks)flashinfer/gemm/gemm_base.py(4 hunks)flashinfer/jit/gemm/__init__.py(2 hunks)flashinfer/jit/gemm/core.py(1 hunks)include/flashinfer/gemm/bf16_gemm_cutlass.h(1 hunks)include/flashinfer/gemm/bf16_gemm_cutlass_template.h(1 hunks)include/flashinfer/gemm/bf16_gemm_template_sm100.h(1 hunks)include/flashinfer/gemm/fp8_gemm_cutlass_template.h(0 hunks)tests/gemm/test_bmm_bf16.py(1 hunks)tests/gemm/test_mm_bf16.py(1 hunks)
💤 Files with no reviewable changes (1)
- include/flashinfer/gemm/fp8_gemm_cutlass_template.h
🚧 Files skipped from review as they are similar to previous changes (2)
- csrc/bf16_gemm_cutlass.cu
- tests/gemm/test_mm_bf16.py
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
csrc/bf16_gemm_cutlass.jinjainclude/flashinfer/gemm/bf16_gemm_template_sm100.hinclude/flashinfer/gemm/bf16_gemm_cutlass_template.hflashinfer/gemm/gemm_base.pyinclude/flashinfer/gemm/bf16_gemm_cutlass.hflashinfer/__init__.py
🧬 Code graph analysis (9)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (2)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
flashinfer(41-134)gemm(42-91)include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
flashinfer(26-60)gemm(27-59)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (3)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
flashinfer(26-60)gemm(27-59)include/flashinfer/gemm/bf16_gemm_template_sm100.h (6)
gemm(44-176)_1SM(53-57)cutlass(135-135)cutlass(136-136)cutlass(137-137)cutlass(138-138)include/flashinfer/gemm/cutlass_gemm_configs.h (1)
CutlassTileConfigSM100(106-425)
flashinfer/gemm/gemm_base.py (8)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)
gemm(42-91)include/flashinfer/gemm/bf16_gemm_cutlass.h (1)
gemm(27-59)include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
gemm(44-176)flashinfer/jit/gemm/core.py (1)
gen_gemm_sm100_module_cutlass_bf16(193-237)flashinfer/utils.py (2)
supported_compute_capability(773-853)_get_cache_buf(205-211)flashinfer/autotuner.py (7)
TunableRunner(194-247)OptimizationProfile(168-183)AutoTuner(335-786)TuningConfig(101-141)DynamicTensorSpec(41-82)ConstraintSpec(86-97)choose_one(400-529)csrc/bf16_gemm_cutlass.cu (4)
bf16_gemm_tactic_num(149-156)bf16_gemm_tactic_num(149-149)bf16_gemm(144-147)bf16_gemm(144-145)flashinfer/fused_moe/utils.py (2)
get_last_power_of_2_num_tokens_buckets(206-215)last_positive_power_of_2(183-188)
include/flashinfer/gemm/bf16_gemm_cutlass.h (3)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
flashinfer(41-134)gemm(42-91)include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
gemm(44-176)flashinfer/gemm/gemm_base.py (1)
CutlassBf16GemmRunner(496-519)
flashinfer/gemm/__init__.py (1)
flashinfer/gemm/gemm_base.py (2)
bmm_bf16(249-312)mm_bf16(183-245)
flashinfer/__init__.py (1)
flashinfer/gemm/gemm_base.py (2)
bmm_bf16(249-312)mm_bf16(183-245)
flashinfer/jit/gemm/__init__.py (1)
flashinfer/jit/gemm/core.py (1)
gen_gemm_sm100_module_cutlass_bf16(193-237)
tests/gemm/test_bmm_bf16.py (2)
flashinfer/gemm/gemm_base.py (1)
bmm_bf16(249-312)flashinfer/utils.py (2)
get_compute_capability(252-255)is_compute_capability_supported(979-994)
flashinfer/jit/gemm/core.py (2)
flashinfer/jit/core.py (2)
JitSpec(213-312)gen_jit_spec(315-381)flashinfer/compilation_context.py (1)
get_nvcc_flags_list(50-68)
🪛 Clang (14.0.6)
include/flashinfer/gemm/bf16_gemm_template_sm100.h
[error] 23-23: 'cutlass/arch/arch.h' file not found
(clang-diagnostic-error)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h
[error] 23-23: 'cutlass/arch/arch.h' file not found
(clang-diagnostic-error)
include/flashinfer/gemm/bf16_gemm_cutlass.h
[error] 20-20: 'cuda_runtime_api.h' file not found
(clang-diagnostic-error)
🪛 Ruff (0.14.4)
flashinfer/gemm/gemm_base.py
217-217: Avoid specifying long messages outside the exception class
(TRY003)
219-219: Avoid specifying long messages outside the exception class
(TRY003)
229-231: Avoid specifying long messages outside the exception class
(TRY003)
233-235: Avoid specifying long messages outside the exception class
(TRY003)
237-239: Avoid specifying long messages outside the exception class
(TRY003)
283-283: Avoid specifying long messages outside the exception class
(TRY003)
285-285: Avoid specifying long messages outside the exception class
(TRY003)
296-298: Avoid specifying long messages outside the exception class
(TRY003)
300-302: Avoid specifying long messages outside the exception class
(TRY003)
304-306: Avoid specifying long messages outside the exception class
(TRY003)
499-499: Unused method argument: inputs
(ARG002)
500-500: Unused method argument: profile
(ARG002)
508-508: Unused method argument: do_preparation
(ARG002)
509-509: Unused method argument: kwargs
(ARG002)
flashinfer/jit/gemm/core.py
233-233: Consider [*nvcc_flags, "-DENABLE_BF16"] instead of concatenation
Replace with [*nvcc_flags, "-DENABLE_BF16"]
(RUF005)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (12)
include/flashinfer/gemm/bf16_gemm_cutlass.h (1)
1-62: LGTM! Clean interface/implementation pattern.The abstract interface and templated concrete class follow best practices for extensibility. The separation of public getWorkspaceSize and private getWorkspaceSizeImpl suggests proper encapsulation of workspace size computation logic.
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
46-151: LGTM! Standard CUTLASS GEMM setup.The SMTypeAdapter specializations and launcher configuration follow CUTLASS patterns correctly. Regarding the comment on Line 147: setting
fusion_args.alpha = 1.0fandfusion_args.beta = 0.0fis the standard way to configure a GEMM epilogue forD = A*B(no accumulation). This is the right approach.include/flashinfer/gemm/bf16_gemm_cutlass_template.h (3)
136-143: LGTM! Clean forwarding to dispatcher.
145-160: Exception handling is appropriate for config probing.The pattern of catching and ignoring
std::runtime_errorwhen probing workspace sizes is acceptable, as some configurations may legitimately fail due to SMEM constraints. The comment on Line 155 documents the rationale clearly.Based on learnings
44-134: No changes needed—review comment is accurate.Verification confirms the bf16 dispatcher is intentionally limited to
CtaShape64x64x128BandClusterShape_1x1x1(lines 100–103 and 65–67), while other tile configs and cluster shapes remain commented out. This differs from the fp8 implementation, which enables multiple configurations, confirming the bf16 limitation is deliberate due to SMEM constraints as noted. The transpose pattern (swapping B, A and n, m at line 101–102) is correct for layout handling.docs/api/gemm.rst (1)
10-18: LGTM! Documentation follows existing patterns.The new BF16 GEMM section properly documents the
mm_bf16andbmm_bf16entry points, following the same autosummary format as other GEMM types in this file.flashinfer/__init__.py (1)
88-90: LGTM! BF16 GEMM exports are now available.The imports of
bmm_bf16andmm_bf16from thegemmmodule expose the new BF16 GEMM functionality at the top level. Past review comments indicate the necessary exports were added toflashinfer/gemm/__init__.py.flashinfer/jit/gemm/__init__.py (1)
22-22: LGTM! JIT generator export follows existing patterns.The
gen_gemm_sm100_module_cutlass_bf16import and export are consistent with other GEMM generators in this module.Also applies to: 37-37
flashinfer/gemm/__init__.py (1)
2-2: LGTM! GEMM module exports are properly configured.The
bmm_bf16andmm_bf16imports fromgemm_baseand their inclusion in__all__enable the top-level imports inflashinfer/__init__.pyto work correctly.Also applies to: 4-4, 25-25, 27-27
tests/gemm/test_bmm_bf16.py (1)
23-34: LGTM! Test logic is sound.The test correctly creates BF16 inputs, computes a reference with
torch.bmm, and validates thebmm_bf16output using cosine similarity. The threshold of 0.99 is reasonable for BF16 precision.csrc/bf16_gemm_cutlass.jinja (1)
1-27: LGTM! Clean template structure with conservative defaults.The template correctly instantiates the SM100 BF16 GEMM kernel with a single-SM cluster configuration (1,1,1), which is appropriate for initial testing. The commented-out multi-SM cluster configurations provide clear guidance for future performance tuning once the basic implementation is validated.
flashinfer/gemm/gemm_base.py (1)
577-616: LGTM! AutoTuner integration follows established patterns.The function correctly uses
a.devicefor SM version checking and properly integrates with the AutoTuner for dynamic tactic selection. The tuning configuration appropriately profiles on the M dimension using power-of-2 bucketing, matching the pattern used infp8_gemm_sm100.
Force-pushed: dcbc17a to 28baee5
Actionable comments posted: 0
♻️ Duplicate comments (7)
tests/gemm/test_mm_bf16.py (2)
25-31: Use row‑major (k, n) weight and avoid .T to keep inputs contiguous. Generate mat2 as (k, n), use it directly in both the reference and API call. This prevents passing a non‑contiguous transpose and matches the documented contract.

- mat2 = torch.randn([n, k], device="cuda", dtype=torch.bfloat16)
-
- reference = torch.mm(input, mat2.T)
+ mat2 = torch.randn([k, n], device="cuda", dtype=torch.bfloat16)
+ reference = torch.mm(input, mat2)
  ...
- mm_bf16(input, mat2.T, out=out, out_dtype=res_dtype)
+ mm_bf16(input, mat2, out=out, out_dtype=res_dtype)
14-16: Skip on CPU-only to avoid hard failure. Add CUDA-availability guard before calling get_compute_capability.

  def test_mm_bf16(m: int, n: int, k: int, res_dtype: torch.dtype):
-     compute_capability = get_compute_capability(torch.device(device="cuda"))
+     if not torch.cuda.is_available():
+         pytest.skip("CUDA is not available")
+     compute_capability = get_compute_capability(torch.device(device="cuda"))

tests/gemm/test_bmm_bf16.py (1)
14-16: Skip on CPU-only to avoid hard failure. Guard get_compute_capability(torch.device("cuda")) with a CUDA-availability check so the test skips instead of crashing on CPU-only runners.

  def test_bmm_bf16(b, m, n, k, res_dtype):
-     compute_capability = get_compute_capability(torch.device(device="cuda"))
+     if not torch.cuda.is_available():
+         pytest.skip("CUDA is not available")
+     compute_capability = get_compute_capability(torch.device(device="cuda"))

flashinfer/gemm/gemm_base.py (3)
182-205: Add essential input validation (dtype, shape, device).Fail fast with clear errors to avoid cryptic backend failures.
def mm_bf16( a: torch.Tensor, b: torch.Tensor, @@ ) -> torch.Tensor: @@ - if backend != "cutlass": + # Basic validations + if a.dtype != torch.bfloat16 or b.dtype != torch.bfloat16: + raise ValueError(f"Inputs must be bfloat16. Got a.dtype={a.dtype}, b.dtype={b.dtype}.") + if a.ndim != 2 or b.ndim != 2: + raise ValueError(f"mm_bf16 expects 2D tensors. Got a.ndim={a.ndim}, b.ndim={b.ndim}.") + if a.shape[1] != b.shape[0]: + raise ValueError( + f"Shape mismatch: a.shape[1]={a.shape[1]} must equal b.shape[0]={b.shape[0]}." + ) + if a.device != b.device: + raise ValueError(f"Device mismatch. a.device={a.device}, b.device={b.device}.") + + if backend != "cutlass": raise ValueError(f"Unsupported backend: {backend}. Only cutlass is available.")
259-283: Add essential input validation (batched dtype/shape/device).Validate 3D inputs, batch, and K dims before launching the kernel.
def bmm_bf16( A: torch.Tensor, B: torch.Tensor, @@ ) -> torch.Tensor: @@ - if backend != "cutlass": + # Basic validations + if A.dtype != torch.bfloat16 or B.dtype != torch.bfloat16: + raise ValueError(f"Inputs must be bfloat16. Got A.dtype={A.dtype}, B.dtype={B.dtype}.") + if A.ndim != 3 or B.ndim != 3: + raise ValueError(f"bmm_bf16 expects 3D tensors. Got A.ndim={A.ndim}, B.ndim={B.ndim}.") + if A.shape[0] != B.shape[0]: + raise ValueError(f"Batch size mismatch: A.shape[0]={A.shape[0]} != B.shape[0]={B.shape[0]}.") + if A.shape[2] != B.shape[1]: + raise ValueError( + f"K mismatch: A.shape[2]={A.shape[2]} must equal B.shape[1]={B.shape[1]}." + ) + if A.device != B.device: + raise ValueError(f"Device mismatch. A.device={A.device}, B.device={B.device}.") + + if backend != "cutlass": raise ValueError(f"Unsupported backend: {backend}. Only cutlass is available.")
533-541: Fix runtime error: make B contiguous before calling the CUTLASS binding. b.transpose(-2, -1) is a non‑contiguous view; the C++ binding asserts contiguity (“mat2 must be contiguous”). Materialize column‑major B.

- module.bf16_gemm(
-     a,
-     b.transpose(-2, -1),
-     out,
-     workspace_buffer,
-     tactic,
- )
+ b_col_major = b.transpose(-2, -1).contiguous()
+ module.bf16_gemm(
+     a,
+     b_col_major,
+     out,
+     workspace_buffer,
+     tactic,
+ )

include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)
157-175: Thread-safety and hash quality in workspace cache.
- XOR-combining hashes collides easily.
- The static workspace_hashmap is accessed unsafely across threads.
@@ - struct MNKHash { + struct MNKHash { size_t operator()(const MNK& mnk) const { auto h1 = std::hash<int>{}(std::get<0>(mnk)); auto h2 = std::hash<int>{}(std::get<1>(mnk)); auto h3 = std::hash<int>{}(std::get<2>(mnk)); - return h1 ^ h2 ^ h3; + // Robust hash combine to reduce collisions + size_t seed = h1; + seed ^= h2 + 0x9e3779b9 + (seed << 6) + (seed >> 2); + seed ^= h3 + 0x9e3779b9 + (seed << 6) + (seed >> 2); + return seed; } }; @@ - static std::unordered_map<MNK, size_t, MNKHash> workspace_hashmap; + static std::mutex workspace_mutex; + static std::unordered_map<MNK, size_t, MNKHash> workspace_hashmap; @@ - size_t workspace_size = 0; - if (workspace_hashmap.find(std::make_tuple(m, n, k)) == workspace_hashmap.end()) { - workspace_size = CutlassBf16GemmRunner<T>::getWorkspaceSizeImpl(m, n, k); - workspace_hashmap[std::make_tuple(m, n, k)] = workspace_size; - } else { - workspace_size = workspace_hashmap[std::make_tuple(m, n, k)]; - } - return workspace_size; + const MNK key = std::make_tuple(m, n, k); + { + std::lock_guard<std::mutex> lock(workspace_mutex); + auto it = workspace_hashmap.find(key); + if (it != workspace_hashmap.end()) { + return it->second; + } + } + // Compute outside lock to avoid blocking others; insert with lock. + size_t computed = CutlassBf16GemmRunner<T>::getWorkspaceSizeImpl(m, n, k); + { + std::lock_guard<std::mutex> lock(workspace_mutex); + auto it = workspace_hashmap.find(key); + if (it == workspace_hashmap.end()) { + workspace_hashmap.emplace(key, computed); + return computed; + } + return it->second; + }Also add the include near the top:
-#include <stdexcept> +#include <stdexcept> +#include <mutex>
🧹 Nitpick comments (2)
flashinfer/jit/gemm/core.py (1)
229-236: Minor: prefer list splat over concatenation (RUF005). Use list unpacking for readability in extra_cuda_cflags.

  return gen_jit_spec(
      "bf16_gemm_cutlass",
      source_paths,
-     extra_cuda_cflags=nvcc_flags + ["-DENABLE_BF16"],
+     extra_cuda_cflags=[*nvcc_flags, "-DENABLE_BF16"],
      extra_cflags=[
          "-DFAST_BUILD",
      ],
  )

include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
101-104: Remove unused template parameter or use it. genericBf16GemmKernelLauncherSm100 has template param arch but hardcodes ArchTag = cutlass::arch::Sm100. Either use arch or drop the parameter.

- using ArchTag = cutlass::arch::Sm100;
+ using ArchTag = arch;
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (14)
csrc/bf16_gemm_cutlass.cu(1 hunks)csrc/bf16_gemm_cutlass.jinja(1 hunks)docs/api/gemm.rst(1 hunks)flashinfer/__init__.py(1 hunks)flashinfer/gemm/__init__.py(2 hunks)flashinfer/gemm/gemm_base.py(4 hunks)flashinfer/jit/gemm/__init__.py(2 hunks)flashinfer/jit/gemm/core.py(1 hunks)include/flashinfer/gemm/bf16_gemm_cutlass.h(1 hunks)include/flashinfer/gemm/bf16_gemm_cutlass_template.h(1 hunks)include/flashinfer/gemm/bf16_gemm_template_sm100.h(1 hunks)include/flashinfer/gemm/fp8_gemm_cutlass_template.h(0 hunks)tests/gemm/test_bmm_bf16.py(1 hunks)tests/gemm/test_mm_bf16.py(1 hunks)
💤 Files with no reviewable changes (1)
- include/flashinfer/gemm/fp8_gemm_cutlass_template.h
🚧 Files skipped from review as they are similar to previous changes (5)
- docs/api/gemm.rst
- csrc/bf16_gemm_cutlass.cu
- flashinfer/__init__.py
- flashinfer/gemm/__init__.py
- flashinfer/jit/gemm/__init__.py
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
- csrc/bf16_gemm_cutlass.jinja
- include/flashinfer/gemm/bf16_gemm_template_sm100.h
- include/flashinfer/gemm/bf16_gemm_cutlass.h
- flashinfer/gemm/gemm_base.py
- include/flashinfer/gemm/bf16_gemm_cutlass_template.h
🧬 Code graph analysis (7)
tests/gemm/test_mm_bf16.py (3)
flashinfer/autotuner.py (1)
autotune(251-262)flashinfer/gemm/gemm_base.py (1)
mm_bf16(183-256)flashinfer/utils.py (2)
get_compute_capability(252-255)is_compute_capability_supported(979-994)
tests/gemm/test_bmm_bf16.py (3)
flashinfer/autotuner.py (1)
autotune(251-262)flashinfer/gemm/gemm_base.py (1)
bmm_bf16(260-334)flashinfer/utils.py (2)
get_compute_capability(252-255)is_compute_capability_supported(979-994)
flashinfer/jit/gemm/core.py (2)
flashinfer/jit/core.py (2)
JitSpec(213-312)gen_jit_spec(315-381)flashinfer/compilation_context.py (1)
get_nvcc_flags_list(50-68)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (2)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
flashinfer(41-125)gemm(42-91)include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
flashinfer(26-60)gemm(27-59)
include/flashinfer/gemm/bf16_gemm_cutlass.h (3)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
flashinfer(41-125)gemm(42-91)include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
gemm(44-182)flashinfer/gemm/gemm_base.py (1)
CutlassBf16GemmRunner(518-541)
flashinfer/gemm/gemm_base.py (8)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)
gemm(42-91)include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
gemm(44-182)include/flashinfer/gemm/bf16_gemm_cutlass.h (1)
gemm(27-59)flashinfer/jit/gemm/core.py (1)
gen_gemm_sm100_module_cutlass_bf16(193-236)flashinfer/utils.py (2)
supported_compute_capability(773-853)_get_cache_buf(205-211)flashinfer/autotuner.py (5)
TunableRunner(194-247)OptimizationProfile(168-183)TuningConfig(101-141)DynamicTensorSpec(41-82)ConstraintSpec(86-97)csrc/bf16_gemm_cutlass.cu (4)
bf16_gemm_tactic_num(149-156)bf16_gemm_tactic_num(149-149)bf16_gemm(144-147)bf16_gemm(144-145)flashinfer/fused_moe/utils.py (2)
get_last_power_of_2_num_tokens_buckets(206-215)last_positive_power_of_2(183-188)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (7)
gemm(44-182)_1SM(53-57)_2SM(60-64)cutlass(135-135)cutlass(136-136)cutlass(137-137)cutlass(138-138)include/flashinfer/gemm/cutlass_gemm_configs.h (1)
CutlassTileConfigSM100(106-425)
🪛 Clang (14.0.6)
include/flashinfer/gemm/bf16_gemm_template_sm100.h
[error] 23-23: 'cutlass/arch/arch.h' file not found
(clang-diagnostic-error)
include/flashinfer/gemm/bf16_gemm_cutlass.h
[error] 20-20: 'cuda_runtime_api.h' file not found
(clang-diagnostic-error)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h
[error] 23-23: 'cutlass/arch/arch.h' file not found
(clang-diagnostic-error)
🪛 Ruff (0.14.4)
flashinfer/jit/gemm/core.py
232-232: Consider [*nvcc_flags, "-DENABLE_BF16"] instead of concatenation
Replace with [*nvcc_flags, "-DENABLE_BF16"]
(RUF005)
flashinfer/gemm/gemm_base.py
228-228: Avoid specifying long messages outside the exception class
(TRY003)
230-230: Avoid specifying long messages outside the exception class
(TRY003)
240-242: Avoid specifying long messages outside the exception class
(TRY003)
244-246: Avoid specifying long messages outside the exception class
(TRY003)
248-250: Avoid specifying long messages outside the exception class
(TRY003)
305-305: Avoid specifying long messages outside the exception class
(TRY003)
307-307: Avoid specifying long messages outside the exception class
(TRY003)
318-320: Avoid specifying long messages outside the exception class
(TRY003)
322-324: Avoid specifying long messages outside the exception class
(TRY003)
326-328: Avoid specifying long messages outside the exception class
(TRY003)
521-521: Unused method argument: inputs
(ARG002)
522-522: Unused method argument: profile
(ARG002)
530-530: Unused method argument: do_preparation
(ARG002)
531-531: Unused method argument: kwargs
(ARG002)
🔇 Additional comments (2)
csrc/bf16_gemm_cutlass.jinja (1)
17-26: Instantiation set looks good. Coverage of cluster shapes for 1SM/2SM variants matches the SM100 launcher; no issues spotted.
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
152-157: Good: explicit workspace probe path. Early-return on null A/B/D ensures `getWorkspaceSizeImpl` can probe without needing a buffer. This unblocks tactic sizing. Based on learnings.
Hi experts, I think this is now ready for review! Right now we are passing all the tests that I wrote for this feature:

Test Results (click to expand)

The original issue (#1974) was to see if a CUTLASS-backend GEMM for BF16 could do better at smaller batch sizes. Now, using

Benchmark Results (click to expand)

but the highlight is that at a larger batch size like batch=64, we are at ~7 microseconds while TGV is at ~12.5-13 microseconds, and the original/torch.compile() path is at ~8 microseconds (a rough reproduction sketch follows below). I'm a CUTLASS newbie, so maybe adding more tile sizes and cluster shapes / autotuning wider can get even better performance, since we are slightly worse elsewhere...

For reviewers:
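A rough way to reproduce the timing comparison above; this is only a sketch, and it assumes the public `mm_bf16` entry point and the `autotune` context manager behave as they do in tests/gemm/test_mm_bf16.py (the n/k sizes below are placeholders, not the benchmarked shapes):

```python
import torch
from flashinfer import mm_bf16
from flashinfer.autotuner import autotune

m, n, k = 64, 4096, 4096  # m=64 matches the "batch=64" case above; n/k are placeholders
a = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
b = torch.randn(k, n, dtype=torch.bfloat16, device="cuda")

with autotune():  # warm up and let the tuner pick a tactic (assumed usage, mirrors the tests)
    mm_bf16(a, b)

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(100):
    mm_bf16(a, b)
end.record()
torch.cuda.synchronize()
print(f"{start.elapsed_time(end) / 100 * 1e3:.1f} us per GEMM")  # elapsed_time is in ms
```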
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
include/flashinfer/gemm/bf16_gemm_template_sm100.h(1 hunks)
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
include/flashinfer/gemm/bf16_gemm_template_sm100.h
🧬 Code graph analysis (1)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (2)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
flashinfer(41-125)gemm(42-91)include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
flashinfer(26-60)gemm(27-59)
🪛 Clang (14.0.6)
include/flashinfer/gemm/bf16_gemm_template_sm100.h
[error] 23-23: 'cutlass/arch/arch.h' file not found
(clang-diagnostic-error)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (2)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (2)
183-190: LGTM! Macro enables flexible template instantiation. The `INSTANCE_BF16_GEMM_TEMPLATE_SM100` macro correctly provides explicit template instantiation control for different tile sizes, cluster shapes, and SM types. The parameter list matches the launcher's template signature, and the macro will be used by the JIT generator to instantiate specific configurations.
49-50: The forward declarations in this file are not problematic duplicates requiring consolidation. The actual struct definitions of `_1SM` and `_2SM` are in include/flashinfer/gemm/bf16_gemm_cutlass_template.h (lines 44-45), while the SM100 template files provide independent forward declarations. This is the correct C++ pattern: the base template defines the types, and the SM100 template files forward-declare them to specialize `SMTypeAdapter<_1SM>` and `SMTypeAdapter<_2SM>` without incurring unnecessary includes. This separation of concerns is appropriate and consistent across all GEMM implementations (bf16, fp8, fp4).

Likely an incorrect or invalid review comment.
```cpp
      throw std::runtime_error("[Bf16 Gemm Runner] insufficient workspace");
    }

    auto can_implement = gemm.can_implement(arguments);
```
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is there any advantage to doing these safety checks this way instead of just using the CUTLASS_CHECK macro? I saw it done this way for FP8 and FP4, so I kept it this way. But just wondering because it seems the same?
Signed-off-by: raayandhar <[email protected]>
Signed-off-by: raayandhar <[email protected]>
Signed-off-by: raayandhar <[email protected]>
Signed-off-by: raayandhar <[email protected]>
Signed-off-by: Raayan Dhar [email protected] <[email protected]>
Signed-off-by: Raayan Dhar [email protected] <[email protected]>
Signed-off-by: Raayan Dhar [email protected] <[email protected]>
Signed-off-by: raayandhar <[email protected]>
Signed-off-by: raayandhar <[email protected]>
1387bed to a56d74b (Compare)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
♻️ Duplicate comments (2)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)
136-176: Guard static workspace cache against concurrent access
CutlassBf16GemmRunner<T>::getWorkspaceSizeuses a function‑local staticstd::unordered_mapto memoize workspace sizes, but accesses it without synchronization:
workspace_hashmap.find(...)andworkspace_hashmap[...] = ...can race whengetWorkspaceSizeis called concurrently from multiple threads, leading to undefined behavior.Given this is a header template used across translation units, it’s easy for callers to hit it from multiple threads (e.g., multi-stream inference).
Consider protecting the map with a mutex:
@@ struct MNKHash { size_t operator()(const MNK& mnk) const { auto h1 = std::hash<int>{}(std::get<0>(mnk)); auto h2 = std::hash<int>{}(std::get<1>(mnk)); auto h3 = std::hash<int>{}(std::get<2>(mnk)); return h1 ^ h2 ^ h3; } }; - static std::unordered_map<MNK, size_t, MNKHash> workspace_hashmap; + static std::unordered_map<MNK, size_t, MNKHash> workspace_hashmap; + static std::mutex workspace_mutex; @@ - size_t workspace_size = 0; - if (workspace_hashmap.find(std::make_tuple(m, n, k)) == workspace_hashmap.end()) { - workspace_size = CutlassBf16GemmRunner<T>::getWorkspaceSizeImpl(m, n, k); - workspace_hashmap[std::make_tuple(m, n, k)] = workspace_size; - } else { - workspace_size = workspace_hashmap[std::make_tuple(m, n, k)]; - } - return workspace_size; + const MNK key = std::make_tuple(m, n, k); + size_t workspace_size; + std::lock_guard<std::mutex> lock(workspace_mutex); + auto it = workspace_hashmap.find(key); + if (it == workspace_hashmap.end()) { + workspace_size = CutlassBf16GemmRunner<T>::getWorkspaceSizeImpl(m, n, k); + workspace_hashmap.emplace(key, workspace_size); + } else { + workspace_size = it->second; + } + return workspace_size;You may also want to upgrade
MNKHashto a stronger combiner (e.g.,boost::hash_combinestyle) to reduce collision risk, though that’s a secondary concern.tests/gemm/test_mm_bf16.py (1)
13-21: Skip BF16 MM test cleanly when CUDA is unavailableThis test assumes CUDA is present:
compute_capability = get_compute_capability(torch.device(device="cuda"))On CPU-only machines this will raise instead of skipping. Guard before any CUDA calls:
def test_mm_bf16(m: int, n: int, k: int, res_dtype: torch.dtype): - compute_capability = get_compute_capability(torch.device(device="cuda")) + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available; skipping mm_bf16 tests.") + + compute_capability = get_compute_capability(torch.device(device="cuda"))The rest of the test (layouts and reference computation) looks consistent with the mm_bf16 API.
🧹 Nitpick comments (3)
flashinfer/jit/gemm/core.py (1)
193-236: SM100 BF16 JIT generator looks consistent; only a minor style nit on flag concatenationGeneration directory, sources, tile/dtype loops, and NVCC flags match the existing SM100 FP8/FP4 patterns and look correct. You could adopt the small Ruff suggestion for cleanliness:
- return gen_jit_spec( - "bf16_gemm_cutlass", - source_paths, - extra_cuda_cflags=nvcc_flags + ["-DENABLE_BF16"], - extra_cflags=[ - "-DFAST_BUILD", - ], - ) + return gen_jit_spec( + "bf16_gemm_cutlass", + source_paths, + extra_cuda_cflags=[*nvcc_flags, "-DENABLE_BF16"], + extra_cflags=["-DFAST_BUILD"], + )include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
66-178: SM100 BF16 kernel launcher matches existing CUTLASS patterns; only minor nitsThe launcher’s layout/stride setup, workspace query (
!A && !B && !D),can_implement/initialize/runflow, and error handling look sound and consistent with the FP8 templates. The unusedarchtemplate parameter andCutlassGemmConfig configargument are harmless but could be dropped or used later if you add per-config tuning; likewise, cachinggemm.get_workspace_size(arguments)into a local would avoid recomputing it three times. No functional issues spotted here.csrc/bf16_gemm_cutlass.cu (1)
49-139: FFI BF16 GEMM wiring, shapes, and workspace handling look correctThe FFI layer’s behavior is consistent end‑to‑end:
getBf16GemmConfigpulls configs once and indexes them by tactic, matching the Python autotuner’s “tactic == index” contract.runGemmusesCutlassBf16GemmRunner::getWorkspaceSizeto size the workspace, allocates a temporary buffer when the cached one is too small, and always passes a sufficientworkspaceBytesto the runner.bf16_bmm_impl’s 2D and 3D shape logic (mat1: (m,k)/(b,m,k), mat2: (n,k)/(b,n,k)) matches the Python side’s extra transpose, and the out‑shape checks are sound.Only minor nit: the
m, n, kparameters ofgetBf16GemmConfigare currently unused; you could drop them or use them later for heuristic selection.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (13)
csrc/bf16_gemm_cutlass.cu(1 hunks)csrc/bf16_gemm_cutlass.jinja(1 hunks)docs/api/gemm.rst(1 hunks)flashinfer/__init__.py(1 hunks)flashinfer/gemm/__init__.py(2 hunks)flashinfer/gemm/gemm_base.py(4 hunks)flashinfer/jit/gemm/__init__.py(2 hunks)flashinfer/jit/gemm/core.py(1 hunks)include/flashinfer/gemm/bf16_gemm_cutlass.h(1 hunks)include/flashinfer/gemm/bf16_gemm_cutlass_template.h(1 hunks)include/flashinfer/gemm/bf16_gemm_template_sm100.h(1 hunks)tests/gemm/test_bmm_bf16.py(1 hunks)tests/gemm/test_mm_bf16.py(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (5)
- tests/gemm/test_bmm_bf16.py
- docs/api/gemm.rst
- csrc/bf16_gemm_cutlass.jinja
- flashinfer/__init__.py
- flashinfer/gemm/__init__.py
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
- include/flashinfer/gemm/bf16_gemm_template_sm100.h
- include/flashinfer/gemm/bf16_gemm_cutlass.h
- csrc/bf16_gemm_cutlass.cu
- flashinfer/gemm/gemm_base.py
- include/flashinfer/gemm/bf16_gemm_cutlass_template.h
🧬 Code graph analysis (8)
flashinfer/jit/gemm/core.py (2)
flashinfer/jit/core.py (2)
JitSpec(213-312)gen_jit_spec(315-381)flashinfer/compilation_context.py (1)
get_nvcc_flags_list(50-68)
tests/gemm/test_mm_bf16.py (3)
flashinfer/autotuner.py (1)
autotune(251-262)flashinfer/gemm/gemm_base.py (1)
mm_bf16(183-256)flashinfer/utils.py (2)
get_compute_capability(253-256)is_compute_capability_supported(1020-1035)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (2)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
flashinfer(41-125)gemm(42-91)include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
flashinfer(26-60)gemm(27-59)
include/flashinfer/gemm/bf16_gemm_cutlass.h (3)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
flashinfer(41-125)gemm(42-91)include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
gemm(44-181)flashinfer/gemm/gemm_base.py (1)
CutlassBf16GemmRunner(518-541)
flashinfer/jit/gemm/__init__.py (1)
flashinfer/jit/gemm/core.py (1)
gen_gemm_sm100_module_cutlass_bf16(193-236)
csrc/bf16_gemm_cutlass.cu (4)
flashinfer/gemm/gemm_base.py (1)
CutlassBf16GemmRunner(518-541)include/flashinfer/gemm/bf16_gemm_cutlass.h (1)
CutlassBf16GemmRunnerInterface(29-41)include/flashinfer/gemm/cutlass_gemm_configs.h (1)
CutlassTileConfigSM100(106-425)csrc/tvm_ffi_utils.h (2)
get_stream(272-274)encode_dlpack_dtype(29-31)
flashinfer/gemm/gemm_base.py (8)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
gemm(44-181)include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)
gemm(42-91)include/flashinfer/gemm/bf16_gemm_cutlass.h (1)
gemm(27-59)flashinfer/jit/gemm/core.py (1)
gen_gemm_sm100_module_cutlass_bf16(193-236)flashinfer/utils.py (2)
supported_compute_capability(814-894)_get_cache_buf(206-212)flashinfer/autotuner.py (6)
TunableRunner(194-247)OptimizationProfile(168-183)AutoTuner(335-786)TuningConfig(101-141)DynamicTensorSpec(41-82)choose_one(400-529)csrc/bf16_gemm_cutlass.cu (4)
bf16_gemm_tactic_num(149-156)bf16_gemm_tactic_num(149-149)bf16_gemm(144-147)bf16_gemm(144-145)flashinfer/fused_moe/utils.py (2)
get_last_power_of_2_num_tokens_buckets(206-215)last_positive_power_of_2(183-188)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (3)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
flashinfer(26-60)gemm(27-59)include/flashinfer/gemm/bf16_gemm_template_sm100.h (7)
gemm(44-181)_1SM(53-57)_2SM(60-64)cutlass(135-135)cutlass(136-136)cutlass(137-137)cutlass(138-138)include/flashinfer/gemm/cutlass_gemm_configs.h (1)
CutlassTileConfigSM100(106-425)
🪛 Clang (14.0.6)
include/flashinfer/gemm/bf16_gemm_template_sm100.h
[error] 23-23: 'cutlass/arch/arch.h' file not found
(clang-diagnostic-error)
include/flashinfer/gemm/bf16_gemm_cutlass.h
[error] 20-20: 'cuda_runtime_api.h' file not found
(clang-diagnostic-error)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h
[error] 23-23: 'cutlass/arch/arch.h' file not found
(clang-diagnostic-error)
🪛 Ruff (0.14.5)
flashinfer/jit/gemm/core.py
232-232: Consider [*nvcc_flags, "-DENABLE_BF16"] instead of concatenation
Replace with [*nvcc_flags, "-DENABLE_BF16"]
(RUF005)
flashinfer/gemm/gemm_base.py
228-228: Avoid specifying long messages outside the exception class
(TRY003)
230-230: Avoid specifying long messages outside the exception class
(TRY003)
240-242: Avoid specifying long messages outside the exception class
(TRY003)
244-246: Avoid specifying long messages outside the exception class
(TRY003)
248-250: Avoid specifying long messages outside the exception class
(TRY003)
305-305: Avoid specifying long messages outside the exception class
(TRY003)
307-307: Avoid specifying long messages outside the exception class
(TRY003)
318-320: Avoid specifying long messages outside the exception class
(TRY003)
322-324: Avoid specifying long messages outside the exception class
(TRY003)
326-328: Avoid specifying long messages outside the exception class
(TRY003)
521-521: Unused method argument: inputs
(ARG002)
522-522: Unused method argument: profile
(ARG002)
530-530: Unused method argument: do_preparation
(ARG002)
531-531: Unused method argument: kwargs
(ARG002)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (3)
flashinfer/jit/gemm/__init__.py (1)
17-38: BF16 SM100 JIT generator correctly exported. Importing `gen_gemm_sm100_module_cutlass_bf16` and adding it to `__all__` is consistent with the existing FP4/FP8 exports and makes the BF16 generator discoverable where expected.

include/flashinfer/gemm/bf16_gemm_cutlass.h (1)
29-57: BF16 GEMM runner interface is well-shaped and matches usage. Interface methods and the templated `CutlassBf16GemmRunner` align with the CUDA implementation (gemm, workspace sizing, config enumeration). Header-only surface looks good; no changes needed here.

flashinfer/gemm/gemm_base.py (1)
182-257: BF16 SM100 Python path is coherent; layout and autotuning wiring look good
mm_bf16/bmm_bf16:
- API contracts (A: (m, k)/(b, m, k), B: (k, n)/(b, k, n), out: (m, n)/(b, m, n), bf16/fp16) are clear and align with the docstring examples.
- Output allocation and validation mirror existing GEMM helpers and look correct.
- Relying on the C++ layer for dtype/shape checks is consistent with your earlier stance and avoids duplication.
`get_gemm_sm100_module_cutlass_bf16` and runner:

- JIT loading via `gen_gemm_sm100_module_cutlass_bf16().build_and_load()` matches the established FP8/FP4 pattern.
- The runner's `get_valid_tactics` returns `range(module.bf16_gemm_tactic_num())`, which fits the FFI config index logic.
- Transposing `b` inside the runner before calling `module.bf16_gemm` correctly converts user-visible (k, n)/(b, k, n) row-major weights into the (n, k)/(b, n, k) layout expected by the CUTLASS kernels, and also resolves the earlier "mat2 must be contiguous" issue when callers follow the documented transpose pattern.
bf16_gemm_sm100:
- SM gating via `_match_sm_version(a.device, ["100"])` and the `supported_compute_capability([100])` decorator on the public APIs tie this path cleanly to B200-class hardware.
- AutoTuner integration (dynamic spec on the `a` tensor's -2 dim and a matching constraint on `out`) is parallel to the FP8 GEMM tuning setup and should behave well with varying M.
Also applies to: 259-334, 513-547, 599-637
Signed-off-by: Raayan Dhar [email protected] <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
🧹 Nitpick comments (1)
flashinfer/gemm/gemm_base.py (1)
603-641: Consider adding support for SM103/SM110 and improving error handling. Two suggestions for improving robustness:
1. Limited SM version support: The function only supports SM100 (line 610), while the comparable `fp8_gemm_sm100` supports multiple SM versions via the `runner_names` parameter. Consider whether BF16 GEMM should also support SM103 and SM110 for consistency.
2. Improve error handling: Line 612 uses `assert runners`, which produces a cryptic `AssertionError` if no suitable runners are found. Replace with a user-friendly error message similar to the pattern used elsewhere in the codebase.

Apply this diff to improve the error message:
```diff
     if _match_sm_version(a.device, ["100"]):
         runners.append(get_gemm_sm100_module_cutlass_bf16().cutlass_bf16_gemm_runner())
-    assert runners, "No suitable runners found"
+    if len(runners) == 0:
+        major, minor = get_compute_capability(a.device)
+        raise ValueError(f"No valid runner found for current device sm{major}{minor}")
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
flashinfer/gemm/gemm_base.py(4 hunks)
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
flashinfer/gemm/gemm_base.py
🧬 Code graph analysis (1)
flashinfer/gemm/gemm_base.py (5)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
gemm(44-181)include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)
gemm(42-91)include/flashinfer/gemm/bf16_gemm_cutlass.h (1)
gemm(27-59)flashinfer/jit/gemm/core.py (1)
gen_gemm_sm100_module_cutlass_bf16(193-236)flashinfer/utils.py (1)
supported_compute_capability(814-894)
🪛 Ruff (0.14.6)
flashinfer/gemm/gemm_base.py
230-230: Avoid specifying long messages outside the exception class
(TRY003)
232-232: Avoid specifying long messages outside the exception class
(TRY003)
242-244: Avoid specifying long messages outside the exception class
(TRY003)
246-248: Avoid specifying long messages outside the exception class
(TRY003)
250-252: Avoid specifying long messages outside the exception class
(TRY003)
307-307: Avoid specifying long messages outside the exception class
(TRY003)
309-309: Avoid specifying long messages outside the exception class
(TRY003)
320-322: Avoid specifying long messages outside the exception class
(TRY003)
324-326: Avoid specifying long messages outside the exception class
(TRY003)
328-330: Avoid specifying long messages outside the exception class
(TRY003)
523-523: Unused method argument: inputs
(ARG002)
524-524: Unused method argument: profile
(ARG002)
532-532: Unused method argument: do_preparation
(ARG002)
533-533: Unused method argument: kwargs
(ARG002)
🔇 Additional comments (2)
flashinfer/gemm/gemm_base.py (2)
55-55: LGTM! The import follows the existing naming convention and is correctly placed with other JIT GEMM module imports.
515-549: Unused parameter warnings are false positives. The static analysis tool flags `inputs`, `profile`, `do_preparation`, and `kwargs` as unused in the `CutlassBf16GemmRunner` class (lines 523-524, 532-533). These parameters are required by the `TunableRunner` interface and cannot be removed.
Signed-off-by: Raayan Dhar [email protected] <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
flashinfer/gemm/gemm_base.py(4 hunks)
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
flashinfer/gemm/gemm_base.py
🧬 Code graph analysis (1)
flashinfer/gemm/gemm_base.py (3)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)
gemm(42-91)flashinfer/utils.py (1)
_get_cache_buf(206-212)csrc/bf16_gemm_cutlass.cu (4)
bf16_gemm_tactic_num(149-156)bf16_gemm_tactic_num(149-149)bf16_gemm(144-147)bf16_gemm(144-145)
🪛 Ruff (0.14.6)
flashinfer/gemm/gemm_base.py
230-230: Avoid specifying long messages outside the exception class
(TRY003)
232-232: Avoid specifying long messages outside the exception class
(TRY003)
242-244: Avoid specifying long messages outside the exception class
(TRY003)
246-248: Avoid specifying long messages outside the exception class
(TRY003)
250-252: Avoid specifying long messages outside the exception class
(TRY003)
307-307: Avoid specifying long messages outside the exception class
(TRY003)
309-309: Avoid specifying long messages outside the exception class
(TRY003)
320-322: Avoid specifying long messages outside the exception class
(TRY003)
324-326: Avoid specifying long messages outside the exception class
(TRY003)
328-330: Avoid specifying long messages outside the exception class
(TRY003)
523-523: Unused method argument: inputs
(ARG002)
524-524: Unused method argument: profile
(ARG002)
532-532: Unused method argument: do_preparation
(ARG002)
533-533: Unused method argument: kwargs
(ARG002)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (2)
flashinfer/gemm/gemm_base.py (2)
515-549: Clarify resolution of transpose contiguity concern. Earlier reviews flagged that `b.transpose(-2, -1)` at line 538 creates a non-contiguous view, which reportedly triggered the runtime error "mat2 must be contiguous" you mentioned. You noted that calling `.contiguous()` caused a SEGFAULT.
- The C++ binding was updated to handle non-contiguous tensors, or
- The kernel correctly interprets the strided layout from
.transpose(-2, -1), or- The test cases don't trigger the failure path.
Since the FP8 implementation (line 498) uses an identical pattern, this may be the intended approach. However, for future maintainability, could you briefly confirm:
- Was the C++ side modified to accept non-contiguous
mat2?- Or does the CUTLASS kernel natively handle the strided transpose layout?
This will help document the resolution and prevent confusion in future reviews.
Note: The unused arguments
do_preparationandkwargs(lines 532-533) are part of theTunableRunnerinterface, so the static analysis warnings can be ignored.
603-641: LGTM! Follows established GEMM dispatcher patterns.The implementation correctly:
- Validates SM100 support via
_match_sm_version- Initializes the BF16 runner
- Configures autotuning with dynamic tensor specs for the M dimension
- Uses constraint specs to keep output shape consistent with input
- Integrates with the existing
AutoTunerinfrastructureThe structure mirrors
fp8_gemm_sm100(lines 571-600), ensuring consistency across GEMM implementations.
📌 Description
This issue was opened a little while ago (#1974) and I finally got a chance to tackle it. Feature request for BF16 GEMM. I decided to try and implement using CUTLASS backend. The issue poster was using B200 so I implemented for B200 (SM100) as well.
🔍 Related Issues
Feature request: #1974
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Documentation
Tests
✏️ Tip: You can customize this high-level summary in your review settings.