
feat: Add cuBLASLt backend for mm_bf16 and enable multi-tactic autotuning for FP8/MXFP8 runners #2914

Open
vadiklyutiy wants to merge 15 commits into flashinfer-ai:main from vadiklyutiy:mm-bf16-cublaslt

Conversation


@vadiklyutiy vadiklyutiy commented Mar 30, 2026

Summary

  • Add cuBLASLt backend for mm_bf16: new backend="cublaslt" option (gated to SM100/SM103). Autotuning across all available cuBLASLt algorithms via get_valid_tactics().
  • Enable multi-tactic autotuning for single-tactic GEMM runners: CublasFp8GemmRunner, CudnnFp8GemmRunner, and CudnnMxfp8GemmRunner previously hardcoded a single tactic (return [0] / return [-1]), preventing the autotuner from exploring better algorithms. Now all three enumerate available algorithms/plans via get_valid_tactics() and pass the selected tactic through to execution, matching what the FP4 and FP16 runners already do.
  • Improve test coverage: added cublaslt and auto backends + auto_tuning parameter to test_mm_bf16.py, auto backend to test_bmm_bf16.py, and a dedicated edge-case test for zero-algorithm handling.
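The single-tactic-to-multi-tactic change can be sketched in plain Python. This is a minimal stand-in, not FlashInfer's actual autotuner: the class name, the timing loop, and the kernel callables here are simplified assumptions; only the `get_valid_tactics()` / `forward(tactic=...)` shape comes from the PR description.

```python
import time

# Minimal stand-in for a tunable GEMM runner. The PR's key change: instead of
# hardcoding `return [0]`, get_valid_tactics() enumerates every backend
# algorithm so the autotuner can actually explore them.
class MultiTacticRunner:
    def __init__(self, kernels):
        self._kernels = kernels  # one callable per backend algorithm

    def get_valid_tactics(self, inputs):
        return list(range(len(self._kernels)))  # previously: return [0]

    def forward(self, inputs, tactic=0):
        return self._kernels[tactic](*inputs)

def autotune(runner, inputs):
    # Time each tactic once and keep the fastest (a real tuner warms up and
    # repeats measurements; this sketch takes a single timing).
    best_tactic, best_time = 0, float("inf")
    for tactic in runner.get_valid_tactics(inputs):
        start = time.perf_counter()
        runner.forward(inputs, tactic=tactic)
        elapsed = time.perf_counter() - start
        if elapsed < best_time:
            best_tactic, best_time = tactic, elapsed
    return best_tactic
```

With a single-element tactic list the loop degenerates to the old fixed-tactic behavior, which is why the hardcoded `return [0]` made autotuning a no-op for these runners.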

Test Results

All GEMM/BMM test suites pass with 0 failures:

| Test File | Passed | Failed |
| --- | --- | --- |
| test_mm_bf16.py | 1441 | 0 |
| test_mm_fp8.py | 30 | 0 |
| test_mm_fp4.py | 1440 | 0 |
| test_mm_mxfp8.py | 1843 | 0 |
| test_bmm_bf16.py | 144 | 0 |
| test_bmm_fp8.py | 1188 | 0 |
| test_bmm_mxfp8.py | 288 | 0 |
| **Total** | **6374** | **0** |

Test additions:

  • test_mm_bf16.py: added cublaslt and auto to backend parametrize, added auto_tuning parameter, added test_cublaslt_bf16_runner_zero_algos edge-case test.
  • test_bmm_bf16.py: added auto backend.
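A hedged reconstruction of the resulting test matrix: the real suite uses `pytest.mark.parametrize`, and the pre-existing backend names below are placeholders; only `cublaslt`, `auto`, and the `auto_tuning` dimension are taken from the PR.

```python
import itertools

# "cublaslt" and "auto" are the entries added by this PR; the other backend
# names are assumptions standing in for whatever the suite already covered.
BACKENDS = ["cudnn", "cublas", "cublaslt", "auto"]
AUTO_TUNING = [False, True]  # the new auto_tuning parameter

def build_cases(shapes):
    # Mirrors stacked @pytest.mark.parametrize decorators: a full cross-product
    # of problem shapes, backends, and the autotuning flag.
    return list(itertools.product(shapes, BACKENDS, AUTO_TUNING))
```

Each new parametrize dimension multiplies the case count, which is consistent with the jump to 1441 collected tests in test_mm_bf16.py.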

Summary by CodeRabbit

  • New Features

    • Added a cuBLASLt BF16 GEMM backend with selectable algorithm tactics; module generated for supported hardware.
    • Exposed APIs to enumerate and run FP8/BF16 GEMM algorithms for explicit tactic selection.
    • Expanded backend choices with "cublaslt" and broader "auto" routing.
  • Bug Fixes

    • Hardened autotuner hashing to avoid failures with unhashable attributes.
  • Tests

    • Extended tests for new backends, autotuning options, auto routing, and zero-algorithm failure handling.

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>

coderabbitai bot commented Mar 30, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.


No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 7003c174-3bc7-431c-8ffc-9d591c38fb5b

📥 Commits

Reviewing files that changed from the base of the PR and between 9e82e66 and b5f957d.

📒 Files selected for processing (1)
  • flashinfer/gemm/gemm_base.py

📝 Walkthrough

Walkthrough

Adds cuBLASLt-backed BF16 GEMM and FP8 BMM algorithm enumeration with serialized algo buffers and tactic-based execution, exposes new FFI bindings and JIT/AOT module for cublasLt BF16 GEMM, extends CLI/tests with cublaslt/auto, and hardens autotuner hashing.

Changes

Cohort / File(s) Summary
CLI & Tests
benchmarks/routines/gemm.py, tests/gemm/test_bmm_bf16.py, tests/gemm/test_mm_bf16.py
Added cublaslt and auto backend choices, added auto_tuning test dimension, adjusted backend skip logic, and added a zero-algorithms test for the cublaslt BF16 runner.
FP8 BMM runtime & bindings
csrc/bmm_fp8.cu, csrc/flashinfer_gemm_binding.cu, include/flashinfer/gemm/bmm_fp8.cuh
Changed workspace length units to bytes for runtime calls; added algorithm enumeration API (bmm_fp8_get_algos) and run-with-algo API (bmm_fp8_run_with_algo); added cuBLASLt FP8 descriptor helpers and serialized algo storage.
BF16 cuBLASLt implementation
csrc/mm_bf16_cublaslt.cu, include/flashinfer/gemm/mm_bf16_cublaslt.cuh
New cuBLASLt-backed BF16 GEMM: descriptor factories, algorithm enumeration (get_algorithms), serialized algo buffer format, run-with-algo execution, and FFI-exported getter/runner functions.
GEMM routing & runners
flashinfer/gemm/gemm_base.py
Added cublaslt backend and SM100/SM103 gating; implemented cublasLt BF16 runner with algo enumeration/cache and tactic-based execution; updated FP8 and cuDNN runners to enumerate and accept explicit tactics.
JIT / AOT module generation
flashinfer/jit/gemm/core.py, flashinfer/jit/gemm/__init__.py, flashinfer/aot.py
Added gen_mm_bf16_cublaslt_module, exported it, and wired it into AOT/JIT generation for SM100/SM103 with -lcublas/-lcublasLt linking.
Autotuner robustness
flashinfer/autotuner.py
Made TunableRunner.__hash__ resilient to unhashable attributes by skipping *_cache fields and falling back to id(v) for unhashable values.
FFI exports
csrc/flashinfer_gemm_binding.cu
Registered new FFI exports: bmm_fp8_get_algos, bmm_fp8_run_with_algo, mm_bf16_cublaslt_get_algos, and mm_bf16_cublaslt_run_with_algo.
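The autotuner-hashing hardening described above can be illustrated with a small sketch. The attribute names here are illustrative, not the actual TunableRunner fields; the two behaviors shown (skip `*_cache` attributes, fall back to `id(v)` for unhashable values) are the ones the walkthrough describes.

```python
class RunnerSketch:
    # A runner holding a dict-valued algo cache: dicts are unhashable, so a
    # naive hash over all attribute values would raise TypeError.
    def __init__(self):
        self._algo_cache = {}
        self.backend = "cublaslt"
        self.workspace_mb = 32

    def __hash__(self):
        items = []
        for k, v in sorted(vars(self).items()):
            if k.endswith("_cache"):
                continue                   # skip *_cache fields entirely
            try:
                items.append((k, hash(v)))
            except TypeError:
                items.append((k, id(v)))   # unhashable value -> identity fallback
        return hash(tuple(items))
```

Skipping the cache fields also keeps the hash stable as the cache fills up between autotuner calls.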

Sequence Diagram

sequenceDiagram
    participant App as Application / Test
    participant Runner as GEMM/BMM Runner
    participant Cache as Algo Cache (CPU)
    participant Enumerator as Algo Enumerator (FFI)
    participant cuBLASLt as cuBLASLt (Heuristics & Executor)

    App->>Runner: forward(inputs, tactic? / auto_tune?)
    Runner->>Cache: lookup(shape, dtype)
    alt cache hit
        Cache-->>Runner: algo_buffer, count
    else cache miss
        Runner->>Enumerator: get_algorithms(A,B,workspace) 
        Enumerator->>cuBLASLt: query heuristics (workspace limit)
        cuBLASLt-->>Enumerator: [algo_t ...]
        Enumerator->>Enumerator: serialize algos -> algo_buffer (CPU)
        Enumerator-->>Runner: algo_count, algo_buffer
        Runner->>Cache: store(shape,dtype)->algo_buffer
    end
    alt tactic >= 0 or autotune loop
        loop try tactics 0..N-1
            Runner->>cuBLASLt: run_with_algo(algo_buf, idx, workspace, stream)
            cuBLASLt-->>Runner: status/result
        end
        Runner->>App: output
    else default
        Runner->>cuBLASLt: run_with_algo(algo_buf, idx=0)
        cuBLASLt-->>Runner: result
        Runner-->>App: output
    end
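The cache-hit/cache-miss flow in the diagram above can be sketched without CUDA. The enumerator callable stands in for the `mm_bf16_cublaslt_get_algos` FFI call; the class and key layout are plain-Python assumptions.

```python
class AlgoCache:
    def __init__(self, enumerate_fn):
        self._cache = {}
        self._enumerate = enumerate_fn  # stand-in for the FFI enumerator

    def get(self, m, n, k, dtype):
        key = (m, n, k, dtype)          # one heuristic query per problem shape
        if key not in self._cache:      # miss: query heuristics, store the algos
            self._cache[key] = self._enumerate(m, n, k, dtype)
        return self._cache[key]         # hit: reuse the stored algo buffer
```

The point of the CPU-side cache is that the (potentially slow) heuristic query and algo serialization happen once per shape/dtype, while every subsequent forward call only indexes into the stored buffer.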

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested labels

op: gemm, run-ci

Suggested reviewers

  • bkryu
  • nvmbreughe
  • jimmyzho
  • yzh119
  • jiahanc
  • yongwww
  • cyx-6

Poem

🐇 I bounded through code with a curious twitch,
I cached all the algos inside my little niche.
cuBLASLt hummed, I picked which to try,
I hop, I run — the fastest one flies high. ✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 16.95%, which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title clearly and specifically describes the main changes: adding cuBLASLt backend for mm_bf16 and enabling multi-tactic autotuning for FP8/MXFP8 runners, matching the core objectives. |
| Description check | ✅ Passed | The PR description provides a clear summary section covering all major changes, includes comprehensive test results showing 6374 tests passed with 0 failures, and documents specific test additions; however, it lacks the required pre-commit checks, related issues link, and reviewer notes sections from the template. |

✏️ Tip: You can configure your own custom pre-merge checks in the settings.




@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new cublaslt backend for BF16 GEMM operations, enabling heuristic algorithm selection and caching to minimize runtime overhead. It also extends the FP8 BMM implementation with similar algorithm selection capabilities and updates the autotuner to handle non-hashable values. Review feedback correctly identified a cache key collision risk in the algorithm caching logic and a shape mismatch in the new test cases for non-square matrices.

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
@vadiklyutiy
Author

/gemini review


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/aot.py (1)

490-498: ⚠️ Potential issue | 🟠 Major

Register the new AOT module for SM103 too.

Line 497 is currently nested under if has_sm100:, so an AOT build targeting only compute_103 never packages mm_bf16_cublaslt even though this PR adds that backend for SM103 as well. has_sm103 is already computed in this function, so the new append needs its own if has_sm100 or has_sm103: guard.

♻️ Proposed change
         if has_sm100:
             jit_specs.append(gen_fp4_quantization_sm100_module())
             jit_specs.append(gen_cutlass_fused_moe_sm100_module())
             jit_specs.append(gen_gemm_sm100_module())
             jit_specs.append(gen_gemm_sm100_module_cutlass_fp4())
             jit_specs.append(gen_gemm_sm100_module_cutlass_fp8())
             jit_specs.append(gen_gemm_sm100_module_cutlass_mxfp8())
-            jit_specs.append(gen_mm_bf16_cublaslt_module())
+        if has_sm100 or has_sm103:
+            jit_specs.append(gen_mm_bf16_cublaslt_module())
+        if has_sm100:
             # Add TGV GEMM modules for both bf16 and fp16
             jit_specs.append(
                 gen_tgv_gemm_sm10x_module(torch.bfloat16, use_sm_100f=False)
             )

As per coding guidelines, flashinfer/aot.py should "Register new operations in flashinfer/aot.py for AOT (Ahead-of-Time) compilation into pre-compiled packages".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/aot.py` around lines 490 - 498, The mm_bf16 cublaslt AOT module
registration is incorrectly placed inside the has_sm100-only block so builds
targeting compute_103 skip it; change the logic around
jit_specs.append(gen_mm_bf16_cublaslt_module()) so it is executed when either
has_sm100 or has_sm103 is true (i.e., wrap or move that append under an if
has_sm100 or has_sm103 guard), keeping the other SM100-only appends
(gen_fp4_quantization_sm100_module, gen_cutlass_fused_moe_sm100_module,
gen_gemm_sm100_module*, etc.) unchanged.
🧹 Nitpick comments (1)
flashinfer/jit/gemm/core.py (1)

53-60: Scope this JIT spec to SM10x builds.

The new backend is SM100/SM103-gated, but Lines 54-59 don't pass any arch-scoped NVCC flags. Mixed-arch builds will compile/package mm_bf16_cublaslt for every target in FLASHINFER_CUDA_ARCH_LIST, unlike the other GEMM generators in this file.

♻️ Proposed change
 def gen_mm_bf16_cublaslt_module() -> JitSpec:
+    nvcc_flags = current_compilation_context.get_nvcc_flags_list(
+        supported_major_versions=[10]
+    )
     return gen_jit_spec(
         "mm_bf16_cublaslt",
         [
             jit_env.FLASHINFER_CSRC_DIR / "mm_bf16_cublaslt.cu",
         ],
+        extra_cuda_cflags=nvcc_flags,
         extra_ldflags=["-lcublas", "-lcublasLt"],
     )

As per coding guidelines, flashinfer/jit/**/*.py should "Specify supported NVIDIA SM major versions in JIT modules using supported_major_versions parameter to limit compilation to specific GPU architectures".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/jit/gemm/core.py` around lines 53 - 60, The JIT spec
gen_mm_bf16_cublaslt_module currently calls gen_jit_spec without arch scoping,
causing mixed-arch builds; update the gen_jit_spec invocation in
gen_mm_bf16_cublaslt_module to pass supported_major_versions=[10] (or the list
containing SM10x major version) so the module is only compiled for
SM100/SM103-class GPUs (refer to gen_mm_bf16_cublaslt_module, gen_jit_spec, and
JitSpec to locate the change).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@csrc/mm_bf16_cublaslt.cu`:
- Around line 56-76: Reject invalid tensor residency and dtype before crossing
host/device: ensure algo_buffer is host (CPU) memory, contiguous, and uint8
(check algo_buffer.device().is_cpu() and algo_buffer.dtype()==torch::kUInt8 and
keep CHECK_CONTIGUOUS(algo_buffer)), and ensure workspace_buffer is CUDA device
memory (check workspace_buffer.device().is_cuda()) before passing
algo_buffer.data_ptr() to host-side memcpy helpers or
workspace_buffer.data_ptr() to cublasLt calls; use the same TVM_FFI_ICHECK (or
TVM_FFI_ICHECK_EQ) style used for other checks to return clear errors, and keep
get_algorithms / cublasLtMatmul calls unchanged otherwise so pointers are safe.

In `@flashinfer/gemm/gemm_base.py`:
- Around line 980-983: The BF16 cuBLASLt algorithm cache key in _get_algos is
using b.shape[0] (K) twice, so N is never included and different N values
collide; change the key construction in _get_algos to use b.shape[1] for N
(since mm_bf16 receives b in (K, N) layout) and make the same correction in the
other equivalent cache-key site (the second occurrence around the
mm_bf16-related logic) so the tuple becomes (M, N, K, compute_dt) (or equivalent
ordering used elsewhere) to uniquely key by N.
- Around line 165-179: The current path assumes at least one algorithm exists
and forces tactic=0 when tactic >= count, but if self._get_algos(inputs) returns
count==0 you must avoid calling module.bmm_fp8_run_with_algo with an
uninitialized algo buffer; update the block around self._get_algos(inputs) to
check if count == 0 and handle it (e.g., raise a clear RuntimeError or fall back
to the non-algo execution path) before computing/adjusting tactic and before
calling module.bmm_fp8_run_with_algo; refer to _get_algos and
module.bmm_fp8_run_with_algo (and the local variable tactic) to locate where to
add the guard.
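The cache-key collision flagged for _get_algos can be demonstrated with bare shape tuples (no tensors needed; the key layout is inferred from the review text, where b arrives in (K, N) layout):

```python
def buggy_key(a_shape, b_shape, dtype):
    # b is (K, N); b_shape[0] is K, so N never enters the key -> collisions
    return (a_shape[0], b_shape[0], a_shape[1], dtype)

def fixed_key(a_shape, b_shape, dtype):
    return (a_shape[0], b_shape[1], a_shape[1], dtype)  # (M, N, K, dtype)

a = (128, 256)          # (M, K)
b_narrow = (256, 64)    # (K, N=64)
b_wide = (256, 1024)    # (K, N=1024)
```

With the buggy key, both B shapes map to (128, 256, 256, dtype), so a tactic tuned for N=64 would be silently reused for N=1024.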

---

Outside diff comments:
In `@flashinfer/aot.py`:
- Around line 490-498: The mm_bf16 cublaslt AOT module registration is
incorrectly placed inside the has_sm100-only block so builds targeting
compute_103 skip it; change the logic around
jit_specs.append(gen_mm_bf16_cublaslt_module()) so it is executed when either
has_sm100 or has_sm103 is true (i.e., wrap or move that append under an if
has_sm100 or has_sm103 guard), keeping the other SM100-only appends
(gen_fp4_quantization_sm100_module, gen_cutlass_fused_moe_sm100_module,
gen_gemm_sm100_module*, etc.) unchanged.

---

Nitpick comments:
In `@flashinfer/jit/gemm/core.py`:
- Around line 53-60: The JIT spec gen_mm_bf16_cublaslt_module currently calls
gen_jit_spec without arch scoping, causing mixed-arch builds; update the
gen_jit_spec invocation in gen_mm_bf16_cublaslt_module to pass
supported_major_versions=[10] (or the list containing SM10x major version) so
the module is only compiled for SM100/SM103-class GPUs (refer to
gen_mm_bf16_cublaslt_module, gen_jit_spec, and JitSpec to locate the change).

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: cbf109a0-b993-446a-8aa7-c78df2ecc3d2

📥 Commits

Reviewing files that changed from the base of the PR and between 4941606 and 4616733.

📒 Files selected for processing (13)
  • benchmarks/routines/gemm.py
  • csrc/bmm_fp8.cu
  • csrc/flashinfer_gemm_binding.cu
  • csrc/mm_bf16_cublaslt.cu
  • flashinfer/aot.py
  • flashinfer/autotuner.py
  • flashinfer/gemm/gemm_base.py
  • flashinfer/jit/gemm/__init__.py
  • flashinfer/jit/gemm/core.py
  • include/flashinfer/gemm/bmm_fp8.cuh
  • include/flashinfer/gemm/mm_bf16_cublaslt.cuh
  • tests/gemm/test_bmm_bf16.py
  • tests/gemm/test_mm_bf16.py

@gemini-code-assist

Warning

Gemini is experiencing higher than usual traffic and was unable to create the review. Please try again in a few hours by commenting /gemini review.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (2)
flashinfer/gemm/gemm_base.py (2)

1037-1040: ⚠️ Potential issue | 🟡 Minor

Report the BF16 zero-algo N dimension from b.shape[1].

Line 1039 still prints N={b.shape[0]}, which is K for the (K, N) layout accepted by mm_bf16(...). That makes the new diagnostic point people at the wrong dimension.

Suggested fix
-                        f"M={a.shape[0]}, N={b.shape[0]}, K={a.shape[1]}, "
+                        f"M={a.shape[0]}, N={b.shape[1]}, K={a.shape[1]}, "
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 1037 - 1040, The error message in
the cuBLASLt heuristic failure prints the wrong N dimension (uses b.shape[0],
which is K for the (K, N) layout); update the RuntimeError message in
gemm_base.py where the heuristic failure is raised to report N using b.shape[1]
instead of b.shape[0], and ensure the message still includes M (a.shape[0]), K
(a.shape[1]), and dtype (compute_out.dtype) so callers of mm_bf16(...) see the
correct diagnostic.

165-179: ⚠️ Potential issue | 🟠 Major

Guard count == 0 before bmm_fp8_run_with_algo().

If bmm_fp8_get_algos() returns 0, Line 167 still falls through to tactic = 0 and Line 169 dispatches with an uninitialized algo record. The BF16 cuBLASLt runner below already fails fast on this edge case; this path needs the same check.

Suggested fix
                 if tactic >= 0:
                     algo_buf, count = self._get_algos(inputs)
+                    if count == 0:
+                        raise RuntimeError(
+                            "cuBLASLt heuristic returned zero FP8 algorithms for "
+                            f"A={tuple(a.shape)}, B={tuple(b.shape)}, out={tuple(out.shape)}."
+                        )
                     if tactic >= count:
                         tactic = 0
                     module.bmm_fp8_run_with_algo(
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 165 - 179, The code calls
module.bmm_fp8_run_with_algo(...) using algo_buf from self._get_algos(inputs)
without handling the case when count == 0; update the block in the tactic >= 0
branch (where tactic, _get_algos, algo_buf, count are used) to check if count ==
0 and fail fast or return/raise (matching the BF16 cuBLASLt runner behavior)
instead of proceeding—i.e., after calling self._get_algos(inputs) if count == 0
log/raise an error or return early so bmm_fp8_run_with_algo is never invoked
with an uninitialized algo_buf.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 128-143: The code fetches a cuBLAS handle with
torch.cuda.current_blas_handle() without ensuring the current CUDA device
matches the input tensor's device, which can dispatch work to the wrong GPU;
update all call sites (e.g., inside CublasFp8GemmRunner and
CublasltBf16GemmRunner where module.bmm_fp8_get_algos and similar are invoked)
to switch the CUDA context to the tensor device before calling
torch.cuda.current_blas_handle() (use torch.cuda.device(a.device) or
torch.cuda.set_device(a.device) as a guard) and then restore/exit the context so
the retrieved handle is bound to a.device; apply the same pattern at every
flagged location (the calls at the sites mentioned in the review).

---

Duplicate comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 1037-1040: The error message in the cuBLASLt heuristic failure
prints the wrong N dimension (uses b.shape[0], which is K for the (K, N)
layout); update the RuntimeError message in gemm_base.py where the heuristic
failure is raised to report N using b.shape[1] instead of b.shape[0], and ensure
the message still includes M (a.shape[0]), K (a.shape[1]), and dtype
(compute_out.dtype) so callers of mm_bf16(...) see the correct diagnostic.
- Around line 165-179: The code calls module.bmm_fp8_run_with_algo(...) using
algo_buf from self._get_algos(inputs) without handling the case when count == 0;
update the block in the tactic >= 0 branch (where tactic, _get_algos, algo_buf,
count are used) to check if count == 0 and fail fast or return/raise (matching
the BF16 cuBLASLt runner behavior) instead of proceeding—i.e., after calling
self._get_algos(inputs) if count == 0 log/raise an error or return early so
bmm_fp8_run_with_algo is never invoked with an uninitialized algo_buf.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 714631ec-98a6-48ee-a06a-5ffc06099110

📥 Commits

Reviewing files that changed from the base of the PR and between 4616733 and 37e2abf.

📒 Files selected for processing (2)
  • flashinfer/gemm/gemm_base.py
  • tests/gemm/test_mm_bf16.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/gemm/test_mm_bf16.py

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/gemm/gemm_base.py (1)

118-189: ⚠️ Potential issue | 🟠 Major

Include dtype-dependent tactic space in the autotuner key.

These runners now expose tactic lists that depend on a.dtype, b.dtype, and/or the compute/output dtype, but they don’t override get_cache_key_extras(). That lets autotuned choices bleed across FP8 format or BF16/FP32 output variants, so a cached tactic index can point at a different algo/plan list on the next call.

💡 Suggested fix
 class CublasFp8GemmRunner(TunableRunner):
+    def get_cache_key_extras(self, inputs: List[torch.Tensor]) -> tuple:
+        a, b, _, _, out, _ = inputs
+        return (a.dtype, b.dtype, out.dtype)
+
     def __init__(self):
         self._algo_cache: dict = {}
 ...
 class CublasltBf16GemmRunner(TunableRunner):
+    def get_cache_key_extras(self, inputs: List[torch.Tensor]) -> tuple:
+        _, _, _, _, out, _ = inputs
+        return (self._compute_dtype(out.dtype),)
+
     def __init__(self):
         self._algo_cache: dict = {}
 ...
 class CudnnFp8GemmRunner(TunableRunner):
+    def get_cache_key_extras(self, inputs: List[torch.Tensor]) -> tuple:
+        a, b, _, _, out, _ = inputs
+        return (a.dtype, b.dtype, out.dtype)
+
     def get_valid_tactics(
         self,
         inputs: List[torch.Tensor],
         profile: OptimizationProfile,
     ) -> List[int]:
 ...
 class CudnnMxfp8GemmRunner(TunableRunner):
+    def get_cache_key_extras(self, inputs: List[torch.Tensor]) -> tuple:
+        a, b, _, _, out, _ = inputs
+        return (a.dtype, b.dtype, out.dtype)
+
     def get_valid_tactics(
         self,
         inputs: List[torch.Tensor],
         profile: OptimizationProfile,
     ) -> List[int]:

Also applies to: 973-1061, 2969-3006, 7039-7073

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 118 - 189, The autotuner key for
CublasFp8GemmRunner must include tensor dtypes so tactic indices don't get
reused across different FP8/BF16/FP32 variants; add/override
get_cache_key_extras(self, inputs) in the CublasFp8GemmRunner class to return a
tuple of the dtype identifiers (e.g., a.dtype, b.dtype, out.dtype and any scale
tensor dtypes if relevant) derived from the inputs used in
get_valid_tactics/forward so the tuner cache separates entries by compute/output
formats; update the same pattern for the other runner classes mentioned (the
ones at the other ranges) that expose dtype-dependent tactic lists.
♻️ Duplicate comments (1)
flashinfer/gemm/gemm_base.py (1)

122-145: ⚠️ Potential issue | 🟠 Major

Fetch the BLAS handle under a.device.

torch.cuda.current_blas_handle() returns the current cuBLAS handle, and cuBLAS handles are device-scoped. These four lookups happen without first switching to the input tensor’s device, so multi-GPU calls can pick up a handle for the wrong GPU. Wrap each lookup in with torch.cuda.device(a.device): .... (docs.pytorch.org)

Also applies to: 163-188, 985-1013, 1030-1057

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 122 - 145, The code in _get_algos
obtains a cuBLAS handle via torch.cuda.current_blas_handle() without ensuring
the correct device context, which can cause wrong-GPU handles in multi-GPU runs;
wrap the handle lookup (and any other CUDA device-specific lookups) in a device
context using with torch.cuda.device(a.device): and call
torch.cuda.current_blas_handle() inside that block, then use that handle for
module.bmm_fp8_get_algos; apply the same pattern to the other similar lookup
sites flagged (the other cuBLAS handle/get calls in this file) so each lookup is
performed under the input tensor’s device context.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 441-443: The docstring for the out parameter incorrectly omits the
TGV path; update the out parameter description in gemm_base.py so it states that
preallocated outputs are supported by the TGV backend as well (since
_tgv_gemm_requirement() only restricts non-BF16 out_dtype and
TGVGemmRunner.forward() writes into the provided out), e.g. change the "Enabled
for CUTLASS, cuDNN, and cuBLASLt backends" phrase to include TGV or use a more
general phrasing that lists TGVGemmRunner/TGV as supported.

---

Outside diff comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 118-189: The autotuner key for CublasFp8GemmRunner must include
tensor dtypes so tactic indices don't get reused across different FP8/BF16/FP32
variants; add/override get_cache_key_extras(self, inputs) in the
CublasFp8GemmRunner class to return a tuple of the dtype identifiers (e.g.,
a.dtype, b.dtype, out.dtype and any scale tensor dtypes if relevant) derived
from the inputs used in get_valid_tactics/forward so the tuner cache separates
entries by compute/output formats; update the same pattern for the other runner
classes mentioned (the ones at the other ranges) that expose dtype-dependent
tactic lists.

---

Duplicate comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 122-145: The code in _get_algos obtains a cuBLAS handle via
torch.cuda.current_blas_handle() without ensuring the correct device context,
which can cause wrong-GPU handles in multi-GPU runs; wrap the handle lookup (and
any other CUDA device-specific lookups) in a device context using with
torch.cuda.device(a.device): and call torch.cuda.current_blas_handle() inside
that block, then use that handle for module.bmm_fp8_get_algos; apply the same
pattern to the other similar lookup sites flagged (the other cuBLAS handle/get
calls in this file) so each lookup is performed under the input tensor’s device
context.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: f5e57a92-8f6e-492d-b499-e7227d7c8ce2

📥 Commits

Reviewing files that changed from the base of the PR and between 37e2abf and cac20a4.

📒 Files selected for processing (5)
  • csrc/mm_bf16_cublaslt.cu
  • flashinfer/aot.py
  • flashinfer/autotuner.py
  • flashinfer/gemm/gemm_base.py
  • flashinfer/jit/gemm/core.py
🚧 Files skipped from review as they are similar to previous changes (3)
  • flashinfer/aot.py
  • flashinfer/autotuner.py
  • csrc/mm_bf16_cublaslt.cu


@coderabbitai coderabbitai bot left a comment


♻️ Duplicate comments (1)
flashinfer/gemm/gemm_base.py (1)

444-445: ⚠️ Potential issue | 🟡 Minor

out docs still omit TGV preallocated output support.

This wording is still misleading: TGV accepts preallocated out (with bf16 constraints), so the docs should not imply out is unsupported on TGV.

Suggested doc wording
-    out: Optional[torch.Tensor]
-        Out tensor, shape (m, n), bf16, fp16, or fp32. Enabled for CUTLASS, cuDNN, and cuBLASLt
-        backends. Defaults to ``None``.
+    out: Optional[torch.Tensor]
+        Out tensor, shape (m, n). Preallocated output is supported by all backends.
+        TGV requires ``torch.bfloat16`` output; CUTLASS, cuDNN, and cuBLASLt also support fp16/fp32.
+        Defaults to ``None``.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 444 - 445, Update the docstring
for the parameter "out" in gemm_base.py to explicitly state that TGV backend
accepts a preallocated out tensor (subject to TGV's bf16 dtype constraints)
rather than implying out is unsupported on TGV; keep existing notes that out can
be bf16/fp16/fp32 for CUTLASS, cuDNN, and cuBLASLt, and add a sentence like "TGV
also accepts a preallocated out tensor but currently requires bf16 dtype"
referencing the "out" parameter and the TGV backend to make the support and
dtype constraint clear.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 444-445: Update the docstring for the parameter "out" in
gemm_base.py to explicitly state that TGV backend accepts a preallocated out
tensor (subject to TGV's bf16 dtype constraints) rather than implying out is
unsupported on TGV; keep existing notes that out can be bf16/fp16/fp32 for
CUTLASS, cuDNN, and cuBLASLt, and add a sentence like "TGV also accepts a
preallocated out tensor but currently requires bf16 dtype" referencing the "out"
parameter and the TGV backend to make the support and dtype constraint clear.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 77499ae3-7c49-4096-b218-91c4e9ae5172

📥 Commits

Reviewing files that changed from the base of the PR and between cac20a4 and 9e82e66.

📒 Files selected for processing (1)
  • flashinfer/gemm/gemm_base.py

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>