feat: Add cuBLASLt backend for mm_bf16 and enable multi-tactic autotuning for FP8/MXFP8 runners #2914

vadiklyutiy wants to merge 15 commits into flashinfer-ai:main

Conversation
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
📝 Walkthrough

Adds cuBLASLt-backed BF16 GEMM and FP8 BMM algorithm enumeration with serialized algo buffers and tactic-based execution, exposes new FFI bindings and a JIT/AOT module for the cuBLASLt BF16 GEMM, and extends the CLI and tests accordingly.
Sequence Diagram

```mermaid
sequenceDiagram
    participant App as Application / Test
    participant Runner as GEMM/BMM Runner
    participant Cache as Algo Cache (CPU)
    participant Enumerator as Algo Enumerator (FFI)
    participant cuBLASLt as cuBLASLt (Heuristics & Executor)
    App->>Runner: forward(inputs, tactic? / auto_tune?)
    Runner->>Cache: lookup(shape, dtype)
    alt cache hit
        Cache-->>Runner: algo_buffer, count
    else cache miss
        Runner->>Enumerator: get_algorithms(A, B, workspace)
        Enumerator->>cuBLASLt: query heuristics (workspace limit)
        cuBLASLt-->>Enumerator: [algo_t ...]
        Enumerator->>Enumerator: serialize algos -> algo_buffer (CPU)
        Enumerator-->>Runner: algo_count, algo_buffer
        Runner->>Cache: store(shape, dtype) -> algo_buffer
    end
    alt tactic >= 0 or autotune loop
        loop try tactics 0..N-1
            Runner->>cuBLASLt: run_with_algo(algo_buf, idx, workspace, stream)
            cuBLASLt-->>Runner: status/result
        end
        Runner-->>App: output
    else default
        Runner->>cuBLASLt: run_with_algo(algo_buf, idx=0)
        cuBLASLt-->>Runner: result
        Runner-->>App: output
    end
```
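The tactic loop in the diagram can be sketched in pure Python. This is a toy model of the autotune branch, not FlashInfer's implementation: `run_with_algo` and `autotune` are placeholder names, and the timing here just simulates work so the loop has something to compare.

```python
import time

def run_with_algo(algo_buf, idx):
    """Placeholder for the backend call; here it only simulates work."""
    # Pretend higher tactic indices are slower, for demonstration purposes.
    time.sleep(0.001 * (idx + 1))
    return "ok"

def autotune(algo_buf, algo_count):
    """Try every enumerated tactic and return the fastest index."""
    best_idx, best_time = 0, float("inf")
    for idx in range(algo_count):
        start = time.perf_counter()
        run_with_algo(algo_buf, idx)
        elapsed = time.perf_counter() - start
        if elapsed < best_time:
            best_idx, best_time = idx, elapsed
    return best_idx

print(autotune(b"algo-bytes", 3))
```

In the real runner the chosen index is then passed back as `tactic` on subsequent calls, which corresponds to the `run_with_algo(algo_buf, idx)` dispatch in the diagram's default branch.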
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Code Review
This pull request introduces a new cublaslt backend for BF16 GEMM operations, enabling heuristic algorithm selection and caching to minimize runtime overhead. It also extends the FP8 BMM implementation with similar algorithm selection capabilities and updates the autotuner to handle non-hashable values. Review feedback correctly identified a cache key collision risk in the algorithm caching logic and a shape mismatch in the new test cases for non-square matrices.
/gemini review
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/aot.py (1)
Lines 490-498: ⚠️ Potential issue | 🟠 Major. Register the new AOT module for SM103 too.

Line 497 is currently nested under `if has_sm100:`, so an AOT build targeting only `compute_103` never packages `mm_bf16_cublaslt` even though this PR adds that backend for SM103 as well. `has_sm103` is already computed in this function, so the new append needs its own `if has_sm100 or has_sm103:` guard.

♻️ Proposed change

```diff
     if has_sm100:
         jit_specs.append(gen_fp4_quantization_sm100_module())
         jit_specs.append(gen_cutlass_fused_moe_sm100_module())
         jit_specs.append(gen_gemm_sm100_module())
         jit_specs.append(gen_gemm_sm100_module_cutlass_fp4())
         jit_specs.append(gen_gemm_sm100_module_cutlass_fp8())
         jit_specs.append(gen_gemm_sm100_module_cutlass_mxfp8())
-        jit_specs.append(gen_mm_bf16_cublaslt_module())
+    if has_sm100 or has_sm103:
+        jit_specs.append(gen_mm_bf16_cublaslt_module())
+    if has_sm100:
         # Add TGV GEMM modules for both bf16 and fp16
         jit_specs.append(
             gen_tgv_gemm_sm10x_module(torch.bfloat16, use_sm_100f=False)
         )
```

As per coding guidelines, `flashinfer/aot.py` should "Register new operations in `flashinfer/aot.py` for AOT (Ahead-of-Time) compilation into pre-compiled packages".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/aot.py` around lines 490 - 498, The mm_bf16 cublaslt AOT module registration is incorrectly placed inside the has_sm100-only block so builds targeting compute_103 skip it; change the logic around jit_specs.append(gen_mm_bf16_cublaslt_module()) so it is executed when either has_sm100 or has_sm103 is true (i.e., wrap or move that append under an if has_sm100 or has_sm103 guard), keeping the other SM100-only appends (gen_fp4_quantization_sm100_module, gen_cutlass_fused_moe_sm100_module, gen_gemm_sm100_module*, etc.) unchanged.
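The guard pattern the comment asks for can be illustrated with a minimal, self-contained sketch. The `has_sm100`/`has_sm103` flags and `gen_*` names mirror the review; the string-based module registry is a stand-in, not FlashInfer's real build code:

```python
def collect_jit_specs(has_sm100: bool, has_sm103: bool) -> list:
    """Stand-in for flashinfer/aot.py's spec-collection logic."""
    jit_specs = []
    if has_sm100:
        jit_specs.append("gen_gemm_sm100_module")  # SM100-only modules stay gated
    # The cuBLASLt BF16 module is valid on SM100 *and* SM103,
    # so it needs its own, wider guard:
    if has_sm100 or has_sm103:
        jit_specs.append("gen_mm_bf16_cublaslt_module")
    return jit_specs

# A build targeting only compute_103 now still packages the cuBLASLt module.
print(collect_jit_specs(has_sm100=False, has_sm103=True))
# → ['gen_mm_bf16_cublaslt_module']
```

With the original nesting, the `has_sm100=False, has_sm103=True` case would have returned an empty list, which is exactly the packaging gap the review flags.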
🧹 Nitpick comments (1)
flashinfer/jit/gemm/core.py (1)
Lines 53-60: Scope this JIT spec to SM10x builds.

The new backend is SM100/SM103-gated, but lines 54-59 don't pass any arch-scoped NVCC flags. Mixed-arch builds will compile/package `mm_bf16_cublaslt` for every target in `FLASHINFER_CUDA_ARCH_LIST`, unlike the other GEMM generators in this file.

♻️ Proposed change

```diff
 def gen_mm_bf16_cublaslt_module() -> JitSpec:
+    nvcc_flags = current_compilation_context.get_nvcc_flags_list(
+        supported_major_versions=[10]
+    )
     return gen_jit_spec(
         "mm_bf16_cublaslt",
         [
             jit_env.FLASHINFER_CSRC_DIR / "mm_bf16_cublaslt.cu",
         ],
+        extra_cuda_cflags=nvcc_flags,
         extra_ldflags=["-lcublas", "-lcublasLt"],
     )
```

As per coding guidelines, `flashinfer/jit/**/*.py` should "Specify supported NVIDIA SM major versions in JIT modules using `supported_major_versions` parameter to limit compilation to specific GPU architectures".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/jit/gemm/core.py` around lines 53 - 60, The JIT spec gen_mm_bf16_cublaslt_module currently calls gen_jit_spec without arch scoping, causing mixed-arch builds; update the gen_jit_spec invocation in gen_mm_bf16_cublaslt_module to pass supported_major_versions=[10] (or the list containing SM10x major version) so the module is only compiled for SM100/SM103-class GPUs (refer to gen_mm_bf16_cublaslt_module, gen_jit_spec, and JitSpec to locate the change).
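What `supported_major_versions` filtering accomplishes can be shown with a toy flag builder. The function name, the flag format, and the arch strings below are illustrative; only the idea of dropping architectures outside the supported major versions comes from the review:

```python
def nvcc_flags_for(arch_list: list, supported_major_versions: list) -> list:
    """Toy arch-scoped flag builder: keep only archs whose major
    version is supported, and emit -gencode flags for them."""
    flags = []
    for arch in arch_list:            # e.g. "90", "100", "103"
        major = int(arch) // 10
        if major in supported_major_versions:
            flags.append(f"-gencode=arch=compute_{arch},code=sm_{arch}")
    return flags

# With supported_major_versions=[10], SM90 is filtered out of a mixed-arch list:
print(nvcc_flags_for(["90", "100", "103"], supported_major_versions=[10]))
```

Without the filter, a mixed-arch `FLASHINFER_CUDA_ARCH_LIST` would emit `-gencode` entries for every architecture, which is the over-broad packaging the nitpick describes.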
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@csrc/mm_bf16_cublaslt.cu`:
- Around line 56-76: Reject invalid tensor residency and dtype before crossing
host/device: ensure algo_buffer is host (CPU) memory, contiguous, and uint8
(check algo_buffer.device().is_cpu() and algo_buffer.dtype()==torch::kUInt8 and
keep CHECK_CONTIGUOUS(algo_buffer)), and ensure workspace_buffer is CUDA device
memory (check workspace_buffer.device().is_cuda()) before passing
algo_buffer.data_ptr() to host-side memcpy helpers or
workspace_buffer.data_ptr() to cublasLt calls; use the same TVM_FFI_ICHECK (or
TVM_FFI_ICHECK_EQ) style used for other checks to return clear errors, and keep
get_algorithms / cublasLtMatmul calls unchanged otherwise so pointers are safe.
In `@flashinfer/gemm/gemm_base.py`:
- Around line 980-983: The BF16 cuBLASLt algorithm cache key in _get_algos is
using b.shape[0] (K) twice, so N is never included and different N values
collide; change the key construction in _get_algos to use b.shape[1] for N
(since mm_bf16 receives b in (K, N) layout) and make the same correction in the
other equivalent cache-key site (the second occurrence around the
mm_bf16-related logic) so the tuple becomes (M, N, K, compute_dt) (or equivalent
ordering used elsewhere) to uniquely key by N.
- Around line 165-179: The current path assumes at least one algorithm exists
and forces tactic=0 when tactic >= count, but if self._get_algos(inputs) returns
count==0 you must avoid calling module.bmm_fp8_run_with_algo with an
uninitialized algo buffer; update the block around self._get_algos(inputs) to
check if count == 0 and handle it (e.g., raise a clear RuntimeError or fall back
to the non-algo execution path) before computing/adjusting tactic and before
calling module.bmm_fp8_run_with_algo; refer to _get_algos and
module.bmm_fp8_run_with_algo (and the local variable tactic) to locate where to
add the guard.
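The cache-key collision described above is easy to reproduce with plain tuples: keying on `b.shape[0]` (which is K in the `(K, N)` layout) means N never enters the key. The shapes below are illustrative, not FlashInfer's actual cache code:

```python
# Two problems with the same M and K but different N.
# b is stored in (K, N) layout, as mm_bf16 expects.
a_shape = (128, 256)          # (M, K)
b_small = (256, 64)           # (K, N=64)
b_large = (256, 512)          # (K, N=512)

def buggy_key(a, b):
    # Uses b[0] (which is K) twice: N is never part of the key.
    return (a[0], b[0], a[1])

def fixed_key(a, b):
    # (M, N, K): b[1] is N in the (K, N) layout.
    return (a[0], b[1], a[1])

assert buggy_key(a_shape, b_small) == buggy_key(a_shape, b_large)  # collision!
assert fixed_key(a_shape, b_small) != fixed_key(a_shape, b_large)
print("buggy keys collide; fixed keys do not")
```

A collision means the algo buffer enumerated for one N would be silently reused for a different N, which is why the review asks for `b.shape[1]` in both cache-key sites.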
---
Outside diff comments:
In `@flashinfer/aot.py`:
- Around line 490-498: The mm_bf16 cublaslt AOT module registration is
incorrectly placed inside the has_sm100-only block so builds targeting
compute_103 skip it; change the logic around
jit_specs.append(gen_mm_bf16_cublaslt_module()) so it is executed when either
has_sm100 or has_sm103 is true (i.e., wrap or move that append under an if
has_sm100 or has_sm103 guard), keeping the other SM100-only appends
(gen_fp4_quantization_sm100_module, gen_cutlass_fused_moe_sm100_module,
gen_gemm_sm100_module*, etc.) unchanged.
---
Nitpick comments:
In `@flashinfer/jit/gemm/core.py`:
- Around line 53-60: The JIT spec gen_mm_bf16_cublaslt_module currently calls
gen_jit_spec without arch scoping, causing mixed-arch builds; update the
gen_jit_spec invocation in gen_mm_bf16_cublaslt_module to pass
supported_major_versions=[10] (or the list containing SM10x major version) so
the module is only compiled for SM100/SM103-class GPUs (refer to
gen_mm_bf16_cublaslt_module, gen_jit_spec, and JitSpec to locate the change).
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: cbf109a0-b993-446a-8aa7-c78df2ecc3d2
📒 Files selected for processing (13)

- benchmarks/routines/gemm.py
- csrc/bmm_fp8.cu
- csrc/flashinfer_gemm_binding.cu
- csrc/mm_bf16_cublaslt.cu
- flashinfer/aot.py
- flashinfer/autotuner.py
- flashinfer/gemm/gemm_base.py
- flashinfer/jit/gemm/__init__.py
- flashinfer/jit/gemm/core.py
- include/flashinfer/gemm/bmm_fp8.cuh
- include/flashinfer/gemm/mm_bf16_cublaslt.cuh
- tests/gemm/test_bmm_bf16.py
- tests/gemm/test_mm_bf16.py
⚠️ Warning: Gemini is experiencing higher than usual traffic and was unable to create the review. Please try again in a few hours by commenting `/gemini review`.
Actionable comments posted: 1
♻️ Duplicate comments (2)
flashinfer/gemm/gemm_base.py (2)
Lines 1037-1040: ⚠️ Potential issue | 🟡 Minor. Report the BF16 zero-algo `N` dimension from `b.shape[1]`.

Line 1039 still prints `N={b.shape[0]}`, which is `K` for the `(K, N)` layout accepted by `mm_bf16(...)`. That makes the new diagnostic point people at the wrong dimension.

Suggested fix

```diff
-                f"M={a.shape[0]}, N={b.shape[0]}, K={a.shape[1]}, "
+                f"M={a.shape[0]}, N={b.shape[1]}, K={a.shape[1]}, "
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` around lines 1037 - 1040, The error message in the cuBLASLt heuristic failure prints the wrong N dimension (uses b.shape[0], which is K for the (K, N) layout); update the RuntimeError message in gemm_base.py where the heuristic failure is raised to report N using b.shape[1] instead of b.shape[0], and ensure the message still includes M (a.shape[0]), K (a.shape[1]), and dtype (compute_out.dtype) so callers of mm_bf16(...) see the correct diagnostic.
Lines 165-179: ⚠️ Potential issue | 🟠 Major. Guard `count == 0` before `bmm_fp8_run_with_algo()`.

If `bmm_fp8_get_algos()` returns `0`, line 167 still falls through to `tactic = 0` and line 169 dispatches with an uninitialized algo record. The BF16 cuBLASLt runner below already fails fast on this edge case; this path needs the same check.

Suggested fix

```diff
         if tactic >= 0:
             algo_buf, count = self._get_algos(inputs)
+            if count == 0:
+                raise RuntimeError(
+                    "cuBLASLt heuristic returned zero FP8 algorithms for "
+                    f"A={tuple(a.shape)}, B={tuple(b.shape)}, out={tuple(out.shape)}."
+                )
             if tactic >= count:
                 tactic = 0
             module.bmm_fp8_run_with_algo(
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` around lines 165 - 179, The code calls module.bmm_fp8_run_with_algo(...) using algo_buf from self._get_algos(inputs) without handling the case when count == 0; update the block in the tactic >= 0 branch (where tactic, _get_algos, algo_buf, count are used) to check if count == 0 and fail fast or return/raise (matching the BF16 cuBLASLt runner behavior) instead of proceeding—i.e., after calling self._get_algos(inputs) if count == 0 log/raise an error or return early so bmm_fp8_run_with_algo is never invoked with an uninitialized algo_buf.
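Stripped of the torch plumbing, the guard the review asks for is a fail-fast check between enumeration and dispatch. `get_algos` and `run` below are stand-ins for the module's FFI calls, and the "tiny shapes yield no algorithm" rule is invented purely to exercise the zero-count path:

```python
def get_algos(shape):
    """Stand-in enumerator: pretend no algorithm fits tiny shapes."""
    if min(shape) < 8:
        return b"", 0              # empty buffer, zero algorithms
    return b"serialized-algos", 3

def run(shape, tactic):
    algo_buf, count = get_algos(shape)
    if count == 0:
        # Fail fast instead of dispatching with an uninitialized algo record.
        raise RuntimeError(f"zero algorithms returned for shape {shape}")
    if tactic >= count:
        tactic = 0                 # clamp out-of-range tactic indices
    return ("run_with_algo", tactic)

print(run((128, 128), tactic=7))   # → ('run_with_algo', 0)
try:
    run((4, 4), tactic=0)
except RuntimeError as e:
    print("guarded:", e)
```

The clamp-to-zero behavior matches the existing `if tactic >= count: tactic = 0` logic; the new `count == 0` branch is what keeps that clamp from masking an empty algo buffer.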
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 128-143: The code fetches a cuBLAS handle with
torch.cuda.current_blas_handle() without ensuring the current CUDA device
matches the input tensor's device, which can dispatch work to the wrong GPU;
update all call sites (e.g., inside CublasFp8GemmRunner and
CublasltBf16GemmRunner where module.bmm_fp8_get_algos and similar are invoked)
to switch the CUDA context to the tensor device before calling
torch.cuda.current_blas_handle() (use torch.cuda.device(a.device) or
torch.cuda.set_device(a.device) as a guard) and then restore/exit the context so
the retrieved handle is bound to a.device; apply the same pattern at every
flagged location (the calls at the sites mentioned in the review).
---
Duplicate comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 1037-1040: The error message in the cuBLASLt heuristic failure
prints the wrong N dimension (uses b.shape[0], which is K for the (K, N)
layout); update the RuntimeError message in gemm_base.py where the heuristic
failure is raised to report N using b.shape[1] instead of b.shape[0], and ensure
the message still includes M (a.shape[0]), K (a.shape[1]), and dtype
(compute_out.dtype) so callers of mm_bf16(...) see the correct diagnostic.
- Around line 165-179: The code calls module.bmm_fp8_run_with_algo(...) using
algo_buf from self._get_algos(inputs) without handling the case when count == 0;
update the block in the tactic >= 0 branch (where tactic, _get_algos, algo_buf,
count are used) to check if count == 0 and fail fast or return/raise (matching
the BF16 cuBLASLt runner behavior) instead of proceeding—i.e., after calling
self._get_algos(inputs) if count == 0 log/raise an error or return early so
bmm_fp8_run_with_algo is never invoked with an uninitialized algo_buf.
ℹ️ Review info: Run ID: 714631ec-98a6-48ee-a06a-5ffc06099110
📒 Files selected for processing (2)

- flashinfer/gemm/gemm_base.py
- tests/gemm/test_mm_bf16.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/gemm/test_mm_bf16.py
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/gemm/gemm_base.py (1)
Lines 118-189: ⚠️ Potential issue | 🟠 Major. Include the dtype-dependent tactic space in the autotuner key.

These runners now expose tactic lists that depend on `a.dtype`, `b.dtype`, and/or the compute/output dtype, but they don't override `get_cache_key_extras()`. That lets autotuned choices bleed across FP8 format or BF16/FP32 output variants, so a cached tactic index can point at a different algo/plan list on the next call.

💡 Suggested fix

```diff
 class CublasFp8GemmRunner(TunableRunner):
+    def get_cache_key_extras(self, inputs: List[torch.Tensor]) -> tuple:
+        a, b, _, _, out, _ = inputs
+        return (a.dtype, b.dtype, out.dtype)
+
     def __init__(self):
         self._algo_cache: dict = {}
 ...
 class CublasltBf16GemmRunner(TunableRunner):
+    def get_cache_key_extras(self, inputs: List[torch.Tensor]) -> tuple:
+        _, _, _, _, out, _ = inputs
+        return (self._compute_dtype(out.dtype),)
+
     def __init__(self):
         self._algo_cache: dict = {}
 ...
 class CudnnFp8GemmRunner(TunableRunner):
+    def get_cache_key_extras(self, inputs: List[torch.Tensor]) -> tuple:
+        a, b, _, _, out, _ = inputs
+        return (a.dtype, b.dtype, out.dtype)
+
     def get_valid_tactics(
         self,
         inputs: List[torch.Tensor],
         profile: OptimizationProfile,
     ) -> List[int]:
 ...
 class CudnnMxfp8GemmRunner(TunableRunner):
+    def get_cache_key_extras(self, inputs: List[torch.Tensor]) -> tuple:
+        a, b, _, _, out, _ = inputs
+        return (a.dtype, b.dtype, out.dtype)
+
     def get_valid_tactics(
         self,
         inputs: List[torch.Tensor],
         profile: OptimizationProfile,
     ) -> List[int]:
```

Also applies to: 973-1061, 2969-3006, 7039-7073
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` around lines 118 - 189, The autotuner key for CublasFp8GemmRunner must include tensor dtypes so tactic indices don't get reused across different FP8/BF16/FP32 variants; add/override get_cache_key_extras(self, inputs) in the CublasFp8GemmRunner class to return a tuple of the dtype identifiers (e.g., a.dtype, b.dtype, out.dtype and any scale tensor dtypes if relevant) derived from the inputs used in get_valid_tactics/forward so the tuner cache separates entries by compute/output formats; update the same pattern for the other runner classes mentioned (the ones at the other ranges) that expose dtype-dependent tactic lists.
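Why dtype must participate in the tuner key can be shown with a toy tuner cache. The `get_cache_key_extras` name follows the review's suggestion; the tuner, the runner class, and the "best tactic per dtype" table are all illustrative:

```python
class ToyTuner:
    """Caches a chosen tactic index per (shape, extras) key."""
    def __init__(self):
        self.cache = {}

    def tune(self, runner, shape):
        key = (shape, runner.get_cache_key_extras())
        if key not in self.cache:
            self.cache[key] = runner.pick_best_tactic()
        return self.cache[key]

class Fp8Runner:
    def __init__(self, out_dtype):
        self.out_dtype = out_dtype

    def get_cache_key_extras(self):
        return (self.out_dtype,)   # dtype participates in the cache key

    def pick_best_tactic(self):
        # Pretend different output dtypes have different best tactics.
        return {"bf16": 2, "fp32": 5}[self.out_dtype]

tuner = ToyTuner()
shape = (128, 256, 64)
print(tuner.tune(Fp8Runner("bf16"), shape))  # → 2
print(tuner.tune(Fp8Runner("fp32"), shape))  # → 5, not the cached bf16 tactic
```

If `get_cache_key_extras()` returned an empty tuple, the second call would hit the bf16 entry and return tactic 2 for an fp32 output, which is the cross-variant bleed the comment warns about.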
♻️ Duplicate comments (1)
flashinfer/gemm/gemm_base.py (1)
Lines 122-145: ⚠️ Potential issue | 🟠 Major. Fetch the BLAS handle under `a.device`.

`torch.cuda.current_blas_handle()` returns the current cuBLAS handle, and cuBLAS handles are device-scoped. These four lookups happen without first switching to the input tensor's device, so multi-GPU calls can pick up a handle for the wrong GPU. Wrap each lookup in `with torch.cuda.device(a.device): ...`. (docs.pytorch.org)

Also applies to: 163-188, 985-1013, 1030-1057
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` around lines 122 - 145, The code in _get_algos obtains a cuBLAS handle via torch.cuda.current_blas_handle() without ensuring the correct device context, which can cause wrong-GPU handles in multi-GPU runs; wrap the handle lookup (and any other CUDA device-specific lookups) in a device context using with torch.cuda.device(a.device): and call torch.cuda.current_blas_handle() inside that block, then use that handle for module.bmm_fp8_get_algos; apply the same pattern to the other similar lookup sites flagged (the other cuBLAS handle/get calls in this file) so each lookup is performed under the input tensor’s device context.
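The failure mode is the classic "handle bound to whatever device happens to be current". A minimal pure-Python model of it (no torch involved; `device_guard`, `set_device`, and `current_blas_handle` here are stand-ins that mimic the torch APIs the review names):

```python
import threading

_current_device = threading.local()
_handles = {}                      # one handle per device, like cuBLAS

def set_device(dev):
    _current_device.dev = dev

class device_guard:
    """Mimics `with torch.cuda.device(dev):` by save/restore of the current device."""
    def __init__(self, dev):
        self.dev = dev
    def __enter__(self):
        self.prev = getattr(_current_device, "dev", 0)
        set_device(self.dev)
    def __exit__(self, *exc):
        set_device(self.prev)

def current_blas_handle():
    dev = getattr(_current_device, "dev", 0)
    return _handles.setdefault(dev, f"handle-for-gpu{dev}")

set_device(0)
tensor_device = 1                  # the input tensor lives on GPU 1

wrong = current_blas_handle()      # looked up without switching devices
with device_guard(tensor_device):
    right = current_blas_handle()  # looked up under the tensor's device

print(wrong, right)                # → handle-for-gpu0 handle-for-gpu1
```

The fix the review proposes is exactly the `with device_guard(...)` shape: perform the handle lookup (and the FFI call that consumes it) inside the tensor's device context so the handle matches the data.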
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 441-443: The docstring for the out parameter incorrectly omits the
TGV path; update the out parameter description in gemm_base.py so it states that
preallocated outputs are supported by the TGV backend as well (since
_tgv_gemm_requirement() only restricts non-BF16 out_dtype and
TGVGemmRunner.forward() writes into the provided out), e.g. change the "Enabled
for CUTLASS, cuDNN, and cuBLASLt backends" phrase to include TGV or use a more
general phrasing that lists TGVGemmRunner/TGV as supported.
---
Outside diff comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 118-189: The autotuner key for CublasFp8GemmRunner must include
tensor dtypes so tactic indices don't get reused across different FP8/BF16/FP32
variants; add/override get_cache_key_extras(self, inputs) in the
CublasFp8GemmRunner class to return a tuple of the dtype identifiers (e.g.,
a.dtype, b.dtype, out.dtype and any scale tensor dtypes if relevant) derived
from the inputs used in get_valid_tactics/forward so the tuner cache separates
entries by compute/output formats; update the same pattern for the other runner
classes mentioned (the ones at the other ranges) that expose dtype-dependent
tactic lists.
---
Duplicate comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 122-145: The code in _get_algos obtains a cuBLAS handle via
torch.cuda.current_blas_handle() without ensuring the correct device context,
which can cause wrong-GPU handles in multi-GPU runs; wrap the handle lookup (and
any other CUDA device-specific lookups) in a device context using with
torch.cuda.device(a.device): and call torch.cuda.current_blas_handle() inside
that block, then use that handle for module.bmm_fp8_get_algos; apply the same
pattern to the other similar lookup sites flagged (the other cuBLAS handle/get
calls in this file) so each lookup is performed under the input tensor’s device
context.
ℹ️ Review info: Run ID: f5e57a92-8f6e-492d-b499-e7227d7c8ce2
📒 Files selected for processing (5)

- csrc/mm_bf16_cublaslt.cu
- flashinfer/aot.py
- flashinfer/autotuner.py
- flashinfer/gemm/gemm_base.py
- flashinfer/jit/gemm/core.py
🚧 Files skipped from review as they are similar to previous changes (3)
- flashinfer/aot.py
- flashinfer/autotuner.py
- csrc/mm_bf16_cublaslt.cu
♻️ Duplicate comments (1)
flashinfer/gemm/gemm_base.py (1)
Lines 444-445: ⚠️ Potential issue | 🟡 Minor. `out` docs still omit TGV preallocated output support.

This wording is still misleading: TGV accepts a preallocated `out` (with bf16 constraints), so the docs should not imply `out` is unsupported on TGV.

Suggested doc wording

```diff
-        out: Optional[torch.Tensor]
-            Out tensor, shape (m, n), bf16, fp16, or fp32. Enabled for CUTLASS, cuDNN, and cuBLASLt
-            backends. Defaults to ``None``.
+        out: Optional[torch.Tensor]
+            Out tensor, shape (m, n). Preallocated output is supported by all backends.
+            TGV requires ``torch.bfloat16`` output; CUTLASS, cuDNN, and cuBLASLt also support fp16/fp32.
+            Defaults to ``None``.
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` around lines 444 - 445, Update the docstring for the parameter "out" in gemm_base.py to explicitly state that TGV backend accepts a preallocated out tensor (subject to TGV's bf16 dtype constraints) rather than implying out is unsupported on TGV; keep existing notes that out can be bf16/fp16/fp32 for CUTLASS, cuDNN, and cuBLASLt, and add a sentence like "TGV also accepts a preallocated out tensor but currently requires bf16 dtype" referencing the "out" parameter and the TGV backend to make the support and dtype constraint clear.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 444-445: Update the docstring for the parameter "out" in
gemm_base.py to explicitly state that TGV backend accepts a preallocated out
tensor (subject to TGV's bf16 dtype constraints) rather than implying out is
unsupported on TGV; keep existing notes that out can be bf16/fp16/fp32 for
CUTLASS, cuDNN, and cuBLASLt, and add a sentence like "TGV also accepts a
preallocated out tensor but currently requires bf16 dtype" referencing the "out"
parameter and the TGV backend to make the support and dtype constraint clear.
ℹ️ Review info: Run ID: 77499ae3-7c49-4096-b218-91c4e9ae5172
📒 Files selected for processing (1)
flashinfer/gemm/gemm_base.py
Summary

- `mm_bf16`: new `backend="cublaslt"` option (gated to SM100/SM103). Autotuning across all available cuBLASLt algorithms via `get_valid_tactics()`.
- `CublasFp8GemmRunner`, `CudnnFp8GemmRunner`, and `CudnnMxfp8GemmRunner` previously hardcoded a single tactic (`return [0]` / `return [-1]`), preventing the autotuner from exploring better algorithms (the same is done for FP4 and FP16). Now all three enumerate available algorithms/plans via `get_valid_tactics()` and pass the selected tactic through to execution.
- Tests: `cublaslt` and `auto` backends plus an `auto_tuning` parameter in `test_mm_bf16.py`, an `auto` backend in `test_bmm_bf16.py`, and a dedicated edge-case test for zero-algorithm handling.

Test Results
All GEMM/BMM test suites pass with 0 failures:
- test_mm_bf16.py
- test_mm_fp8.py
- test_mm_fp4.py
- test_mm_mxfp8.py
- test_bmm_bf16.py
- test_bmm_fp8.py
- test_bmm_mxfp8.py

Test additions:
- `test_mm_bf16.py`: added `cublaslt` and `auto` to backend parametrize, added `auto_tuning` parameter, added `test_cublaslt_bf16_runner_zero_algos` edge-case test.
- `test_bmm_bf16.py`: added `auto` backend.
New Features
Bug Fixes
Tests