
Conversation

@raayandhar
Contributor

@raayandhar raayandhar commented Nov 10, 2025

📌 Description

This issue was opened a little while ago (#1974) and I finally got a chance to tackle it: a feature request for BF16 GEMM. I decided to implement it using the CUTLASS backend. The issue poster was using a B200, so I implemented it for B200 (SM100) as well.

🔍 Related Issues

Feature request: #1974

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added high-performance mm_bf16() and bmm_bf16() APIs backed by a Cutlass SM100 BF16 runner with autotuning, workspace management, and runtime tactic selection (including a query for available tactics).
    • Added a JIT generator to emit multiple SM100 BF16 kernel variants for runtime selection.
  • Documentation

    • Added BF16 GEMM API docs and autosummary entries.
  • Tests

    • Added unit tests validating mm_bf16() and bmm_bf16() correctness on supported GPUs.


@coderabbitai
Contributor

coderabbitai bot commented Nov 10, 2025

Walkthrough

Adds SM100 Cutlass BF16 GEMM support: new CUDA runner and FFI entry points, SM100-specific kernel templates and Jinja instantiations, public Python mm_bf16/bmm_bf16 APIs with JIT generator, headers for runner/configs, tests, and documentation. Includes workspace/config probing and tactic selection.
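
For orientation, a minimal usage sketch of the new public APIs described above (argument names, shapes, and dtypes are inferred from this summary and the tests; an SM100/B200 GPU is assumed and the exact signatures may differ):

    import torch
    import flashinfer

    # 2D GEMM: a is (m, k), b is (k, n), both bfloat16 on the same CUDA device.
    a = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)
    c = flashinfer.mm_bf16(a, b, out_dtype=torch.bfloat16)  # -> (m, n)

    # Batched GEMM: A is (batch, m, k), B is (batch, k, n).
    A = torch.randn(4, 128, 256, device="cuda", dtype=torch.bfloat16)
    B = torch.randn(4, 256, 512, device="cuda", dtype=torch.bfloat16)
    C = flashinfer.bmm_bf16(A, B, out_dtype=torch.bfloat16)  # -> (batch, m, n)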

Changes

Cohort / File(s) | Summary
CUDA FFI & Runner Implementation
csrc/bf16_gemm_cutlass.cu
New CUDA source exposing TVM FFI entry points (bf16_gemm, bf16_gemm_tactic_num), runtime config selection (getBf16GemmConfig), generic runGemm<T>, workspace handling, input/output validation, and dispatch between half and __nv_bfloat16. Explicit template instantiations for runners added.
Jinja Instantiations
csrc/bf16_gemm_cutlass.jinja
Jinja template that instantiates multiple BF16/FP16 kernel variants for several CTA/cluster parameter combinations (SM100 variants).
Public Headers — Runner Interface
include/flashinfer/gemm/bf16_gemm_cutlass.h
Adds CutlassBf16GemmRunnerInterface and templated CutlassBf16GemmRunner<T> declarations (gemm, getWorkspaceSize, getConfigs).
Kernel Templates & Launchers
include/flashinfer/gemm/bf16_gemm_cutlass_template.h, include/flashinfer/gemm/bf16_gemm_template_sm100.h
Implements SM100-specific BF16 kernel launchers, arch/cluster dispatch, workspace sizing/probing, type adapters (_1SM, _2SM), tile/cluster enumerations, and macro INSTANCE_BF16_GEMM_TEMPLATE_SM100.
Python API & Exports
flashinfer/gemm/gemm_base.py, flashinfer/gemm/__init__.py, flashinfer/__init__.py
Adds public APIs mm_bf16 and bmm_bf16, module loader get_gemm_sm100_module_cutlass_bf16, internal bf16_gemm_sm100 path invoking the Cutlass BF16 runner, and re-exports mm_bf16/bmm_bf16.
JIT Generation
flashinfer/jit/gemm/core.py, flashinfer/jit/gemm/__init__.py
Adds gen_gemm_sm100_module_cutlass_bf16() JIT generator that renders bf16_gemm_cutlass.jinja variants, assembles source paths (including base .cu), computes nvcc flags, and returns a JitSpec; exports the generator.
Tests & Documentation
tests/gemm/test_mm_bf16.py, tests/gemm/test_bmm_bf16.py, docs/api/gemm.rst
Adds parameterized tests for mm_bf16 and bmm_bf16 (cosine similarity validation) and docs entry for BF16 GEMM in API docs.

Sequence Diagram(s)

sequenceDiagram
    participant Py as Python Test / User
    participant API as mm_bf16 / bmm_bf16
    participant BF16SM as bf16_gemm_sm100
    participant JIT as gen_gemm_sm100_module_cutlass_bf16
    participant FFI as CUDA FFI (bf16_gemm)
    participant Runner as CutlassBf16GemmRunner
    participant Kernel as Cutlass kernel

    Py->>API: call mm_bf16 / bmm_bf16(a,b,out_dtype)
    API->>API: validate shapes & dtype\nallocate/validate out & workspace
    API->>BF16SM: bf16_gemm_sm100(a,b,out,workspace)
    BF16SM->>JIT: load/get SM100 BF16 module (JIT)
    JIT-->>BF16SM: module with bf16_gemm_runner
    BF16SM->>FFI: call bf16_gemm(...) via FFI (tactic)
    FFI->>Runner: getBf16GemmConfig -> choose tactic/config
    FFI->>Runner: runGemm<T>() -> compute workspace & launch
    Runner->>Kernel: launch Cutlass kernel on stream
    Kernel-->>Runner: kernel complete
    Runner-->>FFI: return status
    FFI-->>BF16SM: completed
    BF16SM-->>API: done
    API-->>Py: return output
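
The tests referenced in this review drive the same path through the autotuner. A rough sketch of that usage, assuming flashinfer.autotuner.autotune is the context manager the test imports point at:

    import torch
    import flashinfer
    from flashinfer.autotuner import autotune  # import used by tests/gemm/test_mm_bf16.py

    a = torch.randn(1024, 4096, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(4096, 8192, device="cuda", dtype=torch.bfloat16)

    # Inside the autotune() context the runner profiles the available tactics
    # (bf16_gemm_tactic_num of them) and caches the best one; later calls with
    # matching shapes reuse the selected tactic.
    with autotune():
        out = flashinfer.mm_bf16(a, b, out_dtype=torch.bfloat16)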

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~45 minutes

  • Files needing extra attention:
    • SM100-specific kernel dispatch and tile/cluster mapping (include/flashinfer/gemm/*.h)
    • Workspace sizing, memoization and runtime probing (bf16_gemm_cutlass_template.h, bf16_gemm_cutlass.cu)
    • FFI boundary correctness and dtype handling (csrc/bf16_gemm_cutlass.cu)
    • JIT generator source assembly and nvcc flag choices (flashinfer/jit/gemm/core.py)

Suggested reviewers

  • aleozlx
  • ttyio
  • nvmbreughe
  • djmmoss
  • yzh119
  • cyx-6
  • bkryu
  • wenscarl

Poem

🐰 I hopped through code and wove a thread,
BF16 kernels hum where tactics led,
Tests chirp loud and JITs compile,
Workspace snug — kernels run in style,
A tiny rabbit cheers: performance ahead!

Pre-merge checks

❌ Failed checks (1 warning)
Check name | Status | Explanation | Resolution
Docstring Coverage | ⚠️ Warning | Docstring coverage is 11.43%, which is below the required threshold of 80.00%. | You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name | Status | Explanation
Title check | ✅ Passed | The title 'feat: BF16 GEMM using CUTLASS backend for SM100' clearly and specifically describes the main change: adding BF16 GEMM support using CUTLASS for the SM100 architecture.
Description check | ✅ Passed | The PR description follows the template structure with Description, Related Issues, and Pre-commit/Test checklists mostly complete. All required sections are present and filled with substantive information, including the feature rationale, the related issue link, and test results.


@raayandhar
Contributor Author

raayandhar commented Nov 10, 2025

Currently there is an error about the second matrix being non-contiguous:
RuntimeError: Check failed: (mat2.IsContiguous()) is false: mat2 must be contiguous
I am working on it, but I have limited access to B200s, so it may be a bit difficult. I am also a newbie when it comes to CUTLASS, so I would really appreciate feedback from any experts here, especially concerning tile sizes and similar choices; I am not sure what the best options are (some configurations seem to run into an error about SMEM space, which surprised me).
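
(For context: the error comes from handing the binding a transposed view rather than a materialized buffer; the review threads below converge on calling .contiguous() on it before the FFI call. A tiny, self-contained illustration of the distinction:)

    import torch

    b = torch.randn(256, 512, dtype=torch.bfloat16)
    bt = b.transpose(-2, -1)       # strided view of b, not a copy
    print(bt.is_contiguous())      # False -> trips the mat2.IsContiguous() check
    bt = bt.contiguous()           # materialize a column-major copy
    print(bt.is_contiguous())      # True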

@raayandhar raayandhar force-pushed the user/rdhar/cutlass_bf16_gemm_sm100 branch from dd6216f to aaaee56 on November 10, 2025 04:09
@raayandhar raayandhar changed the title from "[FEAT] BF16 GEMM using CUTLASS backend for SM100" to "feat: BF16 GEMM using CUTLASS backend for SM100" on Nov 10, 2025
@raayandhar raayandhar marked this pull request as ready for review on November 10, 2025 04:11
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 6

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f5a06a4 and 8ce4cb4.

📒 Files selected for processing (13)
  • csrc/bf16_gemm_cutlass.cu (1 hunks)
  • csrc/bf16_gemm_cutlass.jinja (1 hunks)
  • docs/api/gemm.rst (1 hunks)
  • flashinfer/__init__.py (1 hunks)
  • flashinfer/gemm/gemm_base.py (4 hunks)
  • flashinfer/jit/gemm/__init__.py (2 hunks)
  • flashinfer/jit/gemm/core.py (1 hunks)
  • include/flashinfer/gemm/bf16_gemm_cutlass.h (1 hunks)
  • include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1 hunks)
  • include/flashinfer/gemm/bf16_gemm_template_sm100.h (1 hunks)
  • include/flashinfer/gemm/fp8_gemm_cutlass_template.h (0 hunks)
  • tests/gemm/test_bmm_bf16.py (1 hunks)
  • tests/gemm/test_mm_bf16.py (1 hunks)
💤 Files with no reviewable changes (1)
  • include/flashinfer/gemm/fp8_gemm_cutlass_template.h
🧰 Additional context used
🧬 Code graph analysis (10)
flashinfer/jit/gemm/__init__.py (1)
flashinfer/jit/gemm/core.py (1)
  • gen_gemm_sm100_module_cutlass_bf16 (193-237)
include/flashinfer/gemm/bf16_gemm_cutlass.h (3)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
  • flashinfer (41-134)
  • gemm (42-91)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
  • gemm (44-176)
flashinfer/gemm/gemm_base.py (1)
  • CutlassBf16GemmRunner (497-520)
flashinfer/gemm/gemm_base.py (5)
flashinfer/jit/gemm/core.py (2)
  • gen_gemm_sm100_module (240-316)
  • gen_gemm_sm100_module_cutlass_bf16 (193-237)
flashinfer/utils.py (3)
  • supported_compute_capability (773-853)
  • _get_cache_buf (205-211)
  • get_compute_capability (252-255)
flashinfer/autotuner.py (7)
  • TunableRunner (194-247)
  • OptimizationProfile (168-183)
  • AutoTuner (335-784)
  • TuningConfig (101-141)
  • DynamicTensorSpec (41-82)
  • ConstraintSpec (86-97)
  • choose_one (400-529)
csrc/bf16_gemm_cutlass.cu (4)
  • bf16_gemm_tactic_num (149-156)
  • bf16_gemm_tactic_num (149-149)
  • bf16_gemm (144-147)
  • bf16_gemm (144-145)
flashinfer/fused_moe/utils.py (2)
  • get_last_power_of_2_num_tokens_buckets (206-215)
  • last_positive_power_of_2 (183-188)
tests/gemm/test_bmm_bf16.py (3)
flashinfer/autotuner.py (1)
  • autotune (251-262)
flashinfer/gemm/gemm_base.py (1)
  • bmm_bf16 (250-313)
flashinfer/utils.py (2)
  • get_compute_capability (252-255)
  • is_compute_capability_supported (979-994)
flashinfer/__init__.py (1)
flashinfer/gemm/gemm_base.py (2)
  • bmm_bf16 (250-313)
  • mm_bf16 (183-246)
tests/gemm/test_mm_bf16.py (3)
flashinfer/autotuner.py (1)
  • autotune (251-262)
flashinfer/gemm/gemm_base.py (1)
  • mm_bf16 (183-246)
flashinfer/utils.py (2)
  • get_compute_capability (252-255)
  • is_compute_capability_supported (979-994)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (4)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
  • flashinfer (26-60)
  • gemm (27-59)
include/flashinfer/gemm/fp8_gemm_cutlass_template.h (4)
  • flashinfer (41-145)
  • gemm (42-95)
  • std (184-184)
  • std (185-185)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (6)
  • gemm (44-176)
  • _1SM (53-57)
  • cutlass (135-135)
  • cutlass (136-136)
  • cutlass (137-137)
  • cutlass (138-138)
include/flashinfer/gemm/cutlass_gemm_configs.h (1)
  • CutlassTileConfigSM100 (106-425)
flashinfer/jit/gemm/core.py (2)
flashinfer/jit/core.py (2)
  • JitSpec (213-312)
  • gen_jit_spec (315-381)
flashinfer/compilation_context.py (1)
  • get_nvcc_flags_list (50-68)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (2)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
  • flashinfer (26-60)
  • gemm (27-59)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
  • flashinfer (41-134)
  • gemm (42-91)
csrc/bf16_gemm_cutlass.cu (4)
flashinfer/gemm/gemm_base.py (1)
  • CutlassBf16GemmRunner (497-520)
include/flashinfer/gemm/bf16_gemm_cutlass.h (1)
  • CutlassBf16GemmRunnerInterface (29-41)
include/flashinfer/gemm/cutlass_gemm_configs.h (1)
  • CutlassTileConfigSM100 (106-425)
csrc/tvm_ffi_utils.h (2)
  • get_stream (272-274)
  • encode_dlpack_dtype (29-31)
🪛 Clang (14.0.6)
include/flashinfer/gemm/bf16_gemm_cutlass.h

[error] 20-20: 'cuda_runtime_api.h' file not found

(clang-diagnostic-error)

include/flashinfer/gemm/bf16_gemm_cutlass_template.h

[error] 23-23: 'cutlass/arch/arch.h' file not found

(clang-diagnostic-error)

include/flashinfer/gemm/bf16_gemm_template_sm100.h

[error] 23-23: 'cutlass/arch/arch.h' file not found

(clang-diagnostic-error)

🪛 GitHub Actions: pre-commit
flashinfer/__init__.py

[error] 88-88: mypy: Module "flashinfer.gemm" has no attribute "bmm_bf16".


[error] 90-90: mypy: Module "flashinfer.gemm" has no attribute "mm_bf16".

🪛 Ruff (0.14.3)
flashinfer/gemm/gemm_base.py

218-218: Avoid specifying long messages outside the exception class

(TRY003)


220-220: Avoid specifying long messages outside the exception class

(TRY003)


230-232: Avoid specifying long messages outside the exception class

(TRY003)


234-236: Avoid specifying long messages outside the exception class

(TRY003)


238-240: Avoid specifying long messages outside the exception class

(TRY003)


284-284: Avoid specifying long messages outside the exception class

(TRY003)


286-286: Avoid specifying long messages outside the exception class

(TRY003)


297-299: Avoid specifying long messages outside the exception class

(TRY003)


301-303: Avoid specifying long messages outside the exception class

(TRY003)


305-307: Avoid specifying long messages outside the exception class

(TRY003)


500-500: Unused method argument: inputs

(ARG002)


501-501: Unused method argument: profile

(ARG002)


509-509: Unused method argument: do_preparation

(ARG002)


510-510: Unused method argument: kwargs

(ARG002)


592-592: Avoid specifying long messages outside the exception class

(TRY003)

flashinfer/jit/gemm/core.py

233-233: Consider [*nvcc_flags, "-DENABLE_BF16"] instead of concatenation

Replace with [*nvcc_flags, "-DENABLE_BF16"]

(RUF005)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (2)
docs/api/gemm.rst (2)

10-17: Documentation formatting is consistent and well-structured.

The BF16 GEMM subsection follows the established pattern of other GEMM sections in the file (consistent indentation, autosummary directive, toctree configuration). Placement at the beginning of the GEMM API documentation is logical and appropriate.


10-17: Documentation is complete and accurate.

The BF16 GEMM subsection correctly documents mm_bf16 and bmm_bf16—these are the only public-facing BF16 GEMM functions (verified by top-level exports in flashinfer/__init__.py). The bf16_gemm mentioned in the PR summary is an internal C++ binding and tuning identifier, not a public Python API.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

♻️ Duplicate comments (1)
flashinfer/gemm/gemm_base.py (1)

512-520: Materialize the transposed tensor before passing to the CUTLASS runner.

This is the root cause of the runtime error reported in the PR: b.transpose(-2, -1) returns a non-contiguous view, but the C++ binding requires contiguous input. The fix is to call .contiguous() on the transposed tensor.

Apply this fix:

                 a, b, out, workspace_buffer = inputs
                 module.bf16_gemm(
                     a,
-                    b.transpose(-2, -1),
+                    b.transpose(-2, -1).contiguous(),
                     out,
                     workspace_buffer,
                     tactic,
                 )

This issue was already identified in the previous review.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8ce4cb4 and 511d8e0.

📒 Files selected for processing (2)
  • flashinfer/gemm/__init__.py (2 hunks)
  • flashinfer/gemm/gemm_base.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/gemm/gemm_base.py (7)
include/flashinfer/gemm/bf16_gemm_cutlass.h (1)
  • gemm (27-59)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
  • gemm (44-176)
flashinfer/jit/gemm/core.py (1)
  • gen_gemm_sm100_module_cutlass_bf16 (193-237)
flashinfer/utils.py (3)
  • supported_compute_capability (773-853)
  • _get_cache_buf (205-211)
  • get_compute_capability (252-255)
flashinfer/autotuner.py (7)
  • TunableRunner (194-247)
  • OptimizationProfile (168-183)
  • AutoTuner (335-784)
  • TuningConfig (101-141)
  • DynamicTensorSpec (41-82)
  • ConstraintSpec (86-97)
  • choose_one (400-529)
csrc/bf16_gemm_cutlass.cu (4)
  • bf16_gemm_tactic_num (149-156)
  • bf16_gemm_tactic_num (149-149)
  • bf16_gemm (144-147)
  • bf16_gemm (144-145)
flashinfer/fused_moe/utils.py (2)
  • get_last_power_of_2_num_tokens_buckets (206-215)
  • last_positive_power_of_2 (183-188)
🪛 Ruff (0.14.3)
flashinfer/gemm/gemm_base.py

218-218: Avoid specifying long messages outside the exception class

(TRY003)


220-220: Avoid specifying long messages outside the exception class

(TRY003)


230-232: Avoid specifying long messages outside the exception class

(TRY003)


234-236: Avoid specifying long messages outside the exception class

(TRY003)


238-240: Avoid specifying long messages outside the exception class

(TRY003)


284-284: Avoid specifying long messages outside the exception class

(TRY003)


286-286: Avoid specifying long messages outside the exception class

(TRY003)


297-299: Avoid specifying long messages outside the exception class

(TRY003)


301-303: Avoid specifying long messages outside the exception class

(TRY003)


305-307: Avoid specifying long messages outside the exception class

(TRY003)


500-500: Unused method argument: inputs

(ARG002)


501-501: Unused method argument: profile

(ARG002)


509-509: Unused method argument: do_preparation

(ARG002)


510-510: Unused method argument: kwargs

(ARG002)


592-592: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (1)
flashinfer/gemm/__init__.py (1)

1-38: LGTM! Public API exports are correctly wired.

The new BF16 GEMM functions (bmm_bf16 and mm_bf16) are properly imported from gemm_base and exposed through the module's __all__ list, making them available as part of the public API.
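
A minimal sketch of the re-export wiring being described (illustrative only; the real file also carries the pre-existing GEMM exports):

    # flashinfer/gemm/__init__.py (sketch)
    from .gemm_base import bmm_bf16, mm_bf16

    __all__ = [
        "bmm_bf16",
        "mm_bf16",
        # ...existing GEMM exports
    ]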

@raayandhar raayandhar changed the title from "feat: BF16 GEMM using CUTLASS backend for SM100" to "feat: (wip) BF16 GEMM using CUTLASS backend for SM100" on Nov 10, 2025
@raayandhar raayandhar force-pushed the user/rdhar/cutlass_bf16_gemm_sm100 branch from 511d8e0 to fbe5723 on November 12, 2025 01:54
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

♻️ Duplicate comments (5)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)

156-173: Restore workspace probe handling before launching kernels

CutlassBf16GemmRunner::getWorkspaceSize() calls this launcher with workspacePtr == nullptr to learn how big the buffer must be. Today we immediately throw because workspaceBytes == 0, so the probe reports 0 and the next real launch still fails with “insufficient workspace”. Please short‑circuit the probe and return the computed size instead of throwing.

   size_t workspace_size = gemm.get_workspace_size(arguments);
+  if (workspacePtr == nullptr) {
+    return workspace_size;
+  }
   if (workspace_size > workspaceBytes) {
flashinfer/gemm/gemm_base.py (4)

182-246: Validate BF16 MM inputs before dispatch

The CUTLASS path assumes 2‑D bf16 matrices on the same CUDA device with contiguous row‑major layout. Without the early guards we can accept the wrong dtype, mismatched shapes/devices, or a strided view and only fail deep inside the kernel (or produce garbage). Please restore the validation/contiguity fixes before touching the workspace.

-    if backend != "cutlass":
+    if backend != "cutlass":
         raise ValueError(f"Unsupported backend: {backend}. Only cutlass is available.")
     if out_dtype not in (torch.bfloat16, torch.float16):
         raise ValueError("Only bf16 and fp16 outputs are supported.")
+
+    if a.ndim != 2 or b.ndim != 2:
+        raise ValueError(f"mm_bf16 expects 2D tensors. Got a.ndim={a.ndim}, b.ndim={b.ndim}.")
+    if a.dtype != torch.bfloat16 or b.dtype != torch.bfloat16:
+        raise ValueError(f"Inputs must be bfloat16. Got a.dtype={a.dtype}, b.dtype={b.dtype}.")
+    if a.shape[1] != b.shape[0]:
+        raise ValueError(
+            f"Shape mismatch for matrix multiplication. a.shape[1]={a.shape[1]} must equal b.shape[0]={b.shape[0]}."
+        )
+    if a.device != b.device:
+        raise ValueError(f"Device mismatch. a.device={a.device}, b.device={b.device}.")
+    if not a.is_contiguous():
+        a = a.contiguous()
+    if not b.is_contiguous():
+        b = b.contiguous()
+    if out is not None and not out.is_contiguous():
+        raise ValueError("Output tensor must be contiguous for the CUTLASS backend.")

249-313: Do the same validation for BMM

The batched entry point has the same holes: wrong dtype, rank, device, or non‑contiguous slices go straight into CUTLASS and fail later (or worse, corrupt results). Please add the missing checks for 3‑D tensors, matching batch/K dims, same device, and enforce contiguity before launching.

-    if backend != "cutlass":
+    if backend != "cutlass":
         raise ValueError(f"Unsupported backend: {backend}. Only cutlass is available.")
     if out_dtype not in (torch.bfloat16, torch.float16):
         raise ValueError("Only bf16 and fp16 outputs are supported.")
+
+    if a.ndim != 3 or b.ndim != 3:
+        raise ValueError(f"bmm_bf16 expects 3D tensors. Got a.ndim={a.ndim}, b.ndim={b.ndim}.")
+    if a.dtype != torch.bfloat16 or b.dtype != torch.bfloat16:
+        raise ValueError(f"Inputs must be bfloat16. Got a.dtype={a.dtype}, b.dtype={b.dtype}.")
+    if a.shape[0] != b.shape[0]:
+        raise ValueError(
+            f"Batch size mismatch. a.shape[0]={a.shape[0]} must equal b.shape[0]={b.shape[0]}."
+        )
+    if a.shape[2] != b.shape[1]:
+        raise ValueError(
+            f"K dimension mismatch. a.shape[2]={a.shape[2]} must equal b.shape[1]={b.shape[1]}."
+        )
+    if a.device != b.device:
+        raise ValueError(f"Device mismatch. a.device={a.device}, b.device={b.device}.")
+    if not a.is_contiguous():
+        a = a.contiguous()
+    if not b.is_contiguous():
+        b = b.contiguous()
+    if out is not None and not out.is_contiguous():
+        raise ValueError("Output tensor must be contiguous for the CUTLASS backend.")

512-520: Materialize column‑major B before calling the kernel

bf16_gemm still receives b.transpose(-2, -1) directly, which is a non‑contiguous view and reproduces the runtime error (“mat2 must be contiguous”). Please allocate the column‑major buffer before dispatching to CUTLASS.

-                module.bf16_gemm(
-                    a,
-                    b.transpose(-2, -1),
-                    out,
-                    workspace_buffer,
-                    tactic,
-                )
+                b_col_major = b.transpose(-2, -1).contiguous()
+                module.bf16_gemm(
+                    a,
+                    b_col_major,
+                    out,
+                    workspace_buffer,
+                    tactic,
+                )

590-592: Report the actual device when no runner is found

When a lives on a non‑default GPU, torch.device("cuda") queries device 0 and we raise “sm100” even if the tensor was on sm90. Use the tensor’s device so the error reflects reality.

-        major, minor = get_compute_capability(torch.device("cuda"))
+        major, minor = get_compute_capability(a.device)
🧹 Nitpick comments (1)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)

54-91: Cluster shape dispatch with limited configuration support.

The function correctly dispatches based on cluster shape, with appropriate error handling for unsupported configurations. The limitation to only ClusterShape_1x1x1 aligns with the PR author's note about tile size and SMEM constraints during initial development.

Note: Line 66 has a break statement after return, which is unreachable but harmless.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 511d8e0 and fbe5723.

📒 Files selected for processing (14)
  • csrc/bf16_gemm_cutlass.cu (1 hunks)
  • csrc/bf16_gemm_cutlass.jinja (1 hunks)
  • docs/api/gemm.rst (1 hunks)
  • flashinfer/__init__.py (1 hunks)
  • flashinfer/gemm/__init__.py (2 hunks)
  • flashinfer/gemm/gemm_base.py (4 hunks)
  • flashinfer/jit/gemm/__init__.py (2 hunks)
  • flashinfer/jit/gemm/core.py (1 hunks)
  • include/flashinfer/gemm/bf16_gemm_cutlass.h (1 hunks)
  • include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1 hunks)
  • include/flashinfer/gemm/bf16_gemm_template_sm100.h (1 hunks)
  • include/flashinfer/gemm/fp8_gemm_cutlass_template.h (0 hunks)
  • tests/gemm/test_bmm_bf16.py (1 hunks)
  • tests/gemm/test_mm_bf16.py (1 hunks)
💤 Files with no reviewable changes (1)
  • include/flashinfer/gemm/fp8_gemm_cutlass_template.h
🚧 Files skipped from review as they are similar to previous changes (6)
  • tests/gemm/test_mm_bf16.py
  • flashinfer/gemm/__init__.py
  • csrc/bf16_gemm_cutlass.jinja
  • tests/gemm/test_bmm_bf16.py
  • flashinfer/jit/gemm/__init__.py
  • csrc/bf16_gemm_cutlass.cu
🧰 Additional context used
🧬 Code graph analysis (6)
flashinfer/__init__.py (1)
flashinfer/gemm/gemm_base.py (2)
  • bmm_bf16 (250-313)
  • mm_bf16 (183-246)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (3)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
  • flashinfer (26-60)
  • gemm (27-59)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (6)
  • gemm (44-176)
  • _1SM (53-57)
  • cutlass (135-135)
  • cutlass (136-136)
  • cutlass (137-137)
  • cutlass (138-138)
include/flashinfer/gemm/cutlass_gemm_configs.h (1)
  • CutlassTileConfigSM100 (106-425)
flashinfer/jit/gemm/core.py (2)
flashinfer/jit/core.py (2)
  • JitSpec (213-312)
  • gen_jit_spec (315-381)
flashinfer/compilation_context.py (1)
  • get_nvcc_flags_list (50-68)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
  • flashinfer (41-134)
  • gemm (42-91)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
  • gemm (44-176)
flashinfer/gemm/gemm_base.py (5)
flashinfer/jit/gemm/core.py (1)
  • gen_gemm_sm100_module_cutlass_bf16 (193-237)
flashinfer/utils.py (3)
  • supported_compute_capability (773-853)
  • _get_cache_buf (205-211)
  • get_compute_capability (252-255)
flashinfer/autotuner.py (5)
  • TunableRunner (194-247)
  • OptimizationProfile (168-183)
  • AutoTuner (335-784)
  • TuningConfig (101-141)
  • choose_one (400-529)
csrc/bf16_gemm_cutlass.cu (4)
  • bf16_gemm_tactic_num (149-156)
  • bf16_gemm_tactic_num (149-149)
  • bf16_gemm (144-147)
  • bf16_gemm (144-145)
flashinfer/fused_moe/utils.py (2)
  • get_last_power_of_2_num_tokens_buckets (206-215)
  • last_positive_power_of_2 (183-188)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (2)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
  • flashinfer (26-60)
  • gemm (27-59)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
  • flashinfer (41-134)
  • gemm (42-91)
🪛 Clang (14.0.6)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h

[error] 23-23: 'cutlass/arch/arch.h' file not found

(clang-diagnostic-error)

include/flashinfer/gemm/bf16_gemm_cutlass.h

[error] 20-20: 'cuda_runtime_api.h' file not found

(clang-diagnostic-error)

include/flashinfer/gemm/bf16_gemm_template_sm100.h

[error] 23-23: 'cutlass/arch/arch.h' file not found

(clang-diagnostic-error)

🪛 Ruff (0.14.4)
flashinfer/jit/gemm/core.py

233-233: Consider [*nvcc_flags, "-DENABLE_BF16"] instead of concatenation

Replace with [*nvcc_flags, "-DENABLE_BF16"]

(RUF005)

flashinfer/gemm/gemm_base.py

218-218: Avoid specifying long messages outside the exception class

(TRY003)


220-220: Avoid specifying long messages outside the exception class

(TRY003)


230-232: Avoid specifying long messages outside the exception class

(TRY003)


234-236: Avoid specifying long messages outside the exception class

(TRY003)


238-240: Avoid specifying long messages outside the exception class

(TRY003)


284-284: Avoid specifying long messages outside the exception class

(TRY003)


286-286: Avoid specifying long messages outside the exception class

(TRY003)


297-299: Avoid specifying long messages outside the exception class

(TRY003)


301-303: Avoid specifying long messages outside the exception class

(TRY003)


305-307: Avoid specifying long messages outside the exception class

(TRY003)


500-500: Unused method argument: inputs

(ARG002)


501-501: Unused method argument: profile

(ARG002)


509-509: Unused method argument: do_preparation

(ARG002)


510-510: Unused method argument: kwargs

(ARG002)


592-592: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (6)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)

29-41: Well-designed interface for BF16 GEMM runner.

The abstract interface provides a clean contract with appropriate virtual methods for GEMM operations, workspace management, and configuration enumeration. The virtual destructor is correctly included for safe polymorphic deletion.


43-57: Template class declaration follows proper separation pattern.

The template class declaration correctly inherits from the interface and overrides all pure virtual methods. The separation of declaration (here) and definition (in the template header) is appropriate for template code.

include/flashinfer/gemm/bf16_gemm_cutlass_template.h (4)

19-34: Appropriate diagnostic pragmas for CUTLASS integration.

The GCC diagnostic pragmas correctly suppress strict-aliasing warnings around CUTLASS headers, which is necessary since CUTLASS may use type punning internally.


136-143: GEMM implementation correctly delegates to dispatch logic.

The implementation properly forwards all parameters to dispatchToArch with appropriate type casting.


186-210: Configuration enumeration with limited initial support.

The function correctly enumerates candidate configurations by combining tile configs and cluster shapes. The current limitation to a single tile configuration (CtaShape64x64x128B) and cluster shape (ClusterShape_1x1x1) aligns with the PR objectives and the author's noted constraints regarding SMEM space and limited B200 hardware access for testing.

As additional tile sizes and cluster shapes are validated on SM100 hardware, uncomment the relevant lines to expand the configuration space.


99-103: Verify the intentional A↔B and m↔n parameter swap is correct for the kernel expectations.

The parameter swap pattern dispatchGemmClusterShapeSm100<T, arch, 64, 64, 128>(B, A, static_cast<T*>(D), n, m, k, ...) is applied consistently across all tile configurations in both bf16 and fp8 GEMM implementations. This is paired with explicit layout declarations: LayoutA = RowMajor and LayoutB = ColumnMajor.

While the consistency of this pattern across multiple files strongly suggests it is intentional for layout conversion, please confirm that this parameter reordering matches the actual kernel signature and expectations for dispatchGemmClusterShapeSm100.
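
For intuition only (this does not by itself confirm what dispatchGemmClusterShapeSm100 expects): a row-major D of shape (m, n) occupies the same memory as a column-major D^T of shape (n, m), and

    D^T = (A B)^T = B^T A^T

so swapping the operands together with the m/n extents is the standard way to drive a column-major-oriented kernel with a row-major A and a column-major B.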

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (7)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)

166-173: Improve MNK hash mixing to avoid collisions.

XORing h1 ^ h2 ^ h3 collapses different (m,n,k) permutations to the same bucket, so cached workspace sizes can be reused for incompatible shapes. Combine the hashes with a proper mixer instead.

   struct MNKHash {
     size_t operator()(const MNK& mnk) const {
       auto h1 = std::hash<int>{}(std::get<0>(mnk));
       auto h2 = std::hash<int>{}(std::get<1>(mnk));
       auto h3 = std::hash<int>{}(std::get<2>(mnk));
-      return h1 ^ h2 ^ h3;
+      size_t seed = h1;
+      seed ^= h2 + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
+      seed ^= h3 + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
+      return seed;
     }
   };

175-183: Guard the static workspace cache with a mutex.

workspace_hashmap is mutated without synchronization; concurrent calls to getWorkspaceSize will race on find()/operator[]. Protect the cache with a lock.

-  static std::unordered_map<MNK, size_t, MNKHash> workspace_hashmap;
+  static std::unordered_map<MNK, size_t, MNKHash> workspace_hashmap;
+  static std::mutex workspace_mutex;
 
   size_t workspace_size = 0;
-  if (workspace_hashmap.find(std::make_tuple(m, n, k)) == workspace_hashmap.end()) {
-    workspace_size = CutlassBf16GemmRunner<T>::getWorkspaceSizeImpl(m, n, k);
-    workspace_hashmap[std::make_tuple(m, n, k)] = workspace_size;
-  } else {
-    workspace_size = workspace_hashmap[std::make_tuple(m, n, k)];
-  }
+  const MNK key = std::make_tuple(m, n, k);
+  {
+    std::lock_guard<std::mutex> lock(workspace_mutex);
+    auto it = workspace_hashmap.find(key);
+    if (it != workspace_hashmap.end()) {
+      return it->second;
+    }
+    workspace_size = CutlassBf16GemmRunner<T>::getWorkspaceSizeImpl(m, n, k);
+    workspace_hashmap.emplace(key, workspace_size);
+  }
   return workspace_size;
tests/gemm/test_mm_bf16.py (1)

14-21: Skip on CPU-only test environments.

get_compute_capability(torch.device("cuda")) raises when CUDA isn’t available, causing the entire suite to error out instead of skipping. Guard this with if not torch.cuda.is_available(): pytest.skip(...) before the capability query.
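
A self-contained sketch of that guard (the test name is illustrative; the same pattern appears in the bmm diff further down):

    import pytest
    import torch

    def test_mm_bf16_requires_cuda():
        # Skip (rather than error) on CPU-only environments before any CUDA query.
        if not torch.cuda.is_available():
            pytest.skip("CUDA is not available")
        major, minor = torch.cuda.get_device_capability(torch.device("cuda"))
        assert major >= 1  # the real capability check would follow here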

tests/gemm/test_bmm_bf16.py (1)

15-22: Gracefully skip when CUDA is unavailable.

Like the MM test, calling get_compute_capability(torch.device("cuda")) without checking torch.cuda.is_available() hard-fails on CPU-only setups. Add a skip guard before querying the device.

flashinfer/gemm/gemm_base.py (3)

217-240: Validate inputs before firing the kernel.

mm_bf16 still accepts tensors with wrong dtype, shape, or device, which the CUTLASS runner interprets incorrectly (e.g., passing fp16 data corrupts results). Add explicit checks for bf16 dtype, matching inner dimensions, and matching devices at the top of the function so misuse fails fast.

+    if a.dtype != torch.bfloat16 or b.dtype != torch.bfloat16:
+        raise ValueError(
+            f"Inputs must be bfloat16. Got a.dtype={a.dtype}, b.dtype={b.dtype}."
+        )
+    if a.ndim != 2 or b.ndim != 2:
+        raise ValueError(
+            f"Inputs must be 2D matrices. Got a.ndim={a.ndim}, b.ndim={b.ndim}."
+        )
+    if a.shape[1] != b.shape[0]:
+        raise ValueError(
+            f"Shape mismatch: a.shape[1]={a.shape[1]} must equal b.shape[0]={b.shape[0]}."
+        )
+    if a.device != b.device:
+        raise ValueError(
+            f"Device mismatch: a.device={a.device}, b.device={b.device}."
+        )

288-307: Add basic sanity checks for batched inputs.

bmm_bf16 also needs dtype/shape/device validation; otherwise mismatched batch sizes or wrong K dimensions surface as low-level CUTLASS failures. Please mirror the checks from mm_bf16 for 3D tensors (batch, m, k) and (batch, k, n), ensuring matching batch/K dimensions and bf16 dtype.

+    if A.dtype != torch.bfloat16 or B.dtype != torch.bfloat16:
+        raise ValueError(
+            f"Inputs must be bfloat16. Got A.dtype={A.dtype}, B.dtype={B.dtype}."
+        )
+    if A.ndim != 3 or B.ndim != 3:
+        raise ValueError(
+            f"Inputs must be 3D tensors. Got A.ndim={A.ndim}, B.ndim={B.ndim}."
+        )
+    if A.shape[0] != B.shape[0]:
+        raise ValueError(
+            f"Batch mismatch: A.shape[0]={A.shape[0]} != B.shape[0]={B.shape[0]}."
+        )
+    if A.shape[2] != B.shape[1]:
+        raise ValueError(
+            f"K mismatch: A.shape[2]={A.shape[2]} must equal B.shape[1]={B.shape[1]}."
+        )
+    if A.device != B.device:
+        raise ValueError(
+            f"Device mismatch: A.device={A.device}, B.device={B.device}."
+        )

512-519: Make the transposed B operand contiguous.

The CUTLASS binding now enforces mat2.is_contiguous(). Transposing on the fly hands it a strided view and triggers the runtime error you reported. Materialize the column-major buffer before launching.

-                module.bf16_gemm(
-                    a,
-                    b.transpose(-2, -1),
-                    out,
-                    workspace_buffer,
-                    tactic,
-                )
+                b_col_major = b.transpose(-2, -1).contiguous()
+                module.bf16_gemm(
+                    a,
+                    b_col_major,
+                    out,
+                    workspace_buffer,
+                    tactic,
+                )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between fbe5723 and 7f62bb0.

📒 Files selected for processing (4)
  • flashinfer/gemm/gemm_base.py (4 hunks)
  • include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1 hunks)
  • tests/gemm/test_bmm_bf16.py (1 hunks)
  • tests/gemm/test_mm_bf16.py (1 hunks)
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.563Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
📚 Learning: 2025-11-12T03:35:17.563Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.563Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • include/flashinfer/gemm/bf16_gemm_cutlass_template.h
  • flashinfer/gemm/gemm_base.py
🧬 Code graph analysis (4)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (4)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
  • flashinfer (26-60)
  • gemm (27-59)
include/flashinfer/gemm/fp8_gemm_cutlass_template.h (4)
  • flashinfer (41-145)
  • gemm (42-95)
  • std (184-184)
  • std (185-185)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (6)
  • gemm (44-176)
  • _1SM (53-57)
  • cutlass (135-135)
  • cutlass (136-136)
  • cutlass (137-137)
  • cutlass (138-138)
include/flashinfer/gemm/cutlass_gemm_configs.h (1)
  • CutlassTileConfigSM100 (106-425)
tests/gemm/test_mm_bf16.py (2)
flashinfer/gemm/gemm_base.py (1)
  • mm_bf16 (183-246)
flashinfer/utils.py (2)
  • get_compute_capability (252-255)
  • is_compute_capability_supported (979-994)
tests/gemm/test_bmm_bf16.py (3)
flashinfer/autotuner.py (1)
  • autotune (251-262)
flashinfer/gemm/gemm_base.py (1)
  • bmm_bf16 (250-313)
flashinfer/utils.py (2)
  • get_compute_capability (252-255)
  • is_compute_capability_supported (979-994)
flashinfer/gemm/gemm_base.py (8)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)
  • gemm (42-91)
include/flashinfer/gemm/bf16_gemm_cutlass.h (1)
  • gemm (27-59)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
  • gemm (44-176)
flashinfer/jit/gemm/core.py (1)
  • gen_gemm_sm100_module_cutlass_bf16 (193-237)
flashinfer/utils.py (2)
  • supported_compute_capability (773-853)
  • _get_cache_buf (205-211)
flashinfer/autotuner.py (7)
  • TunableRunner (194-247)
  • OptimizationProfile (168-183)
  • AutoTuner (335-784)
  • TuningConfig (101-141)
  • DynamicTensorSpec (41-82)
  • ConstraintSpec (86-97)
  • choose_one (400-529)
csrc/bf16_gemm_cutlass.cu (4)
  • bf16_gemm_tactic_num (149-156)
  • bf16_gemm_tactic_num (149-149)
  • bf16_gemm (144-147)
  • bf16_gemm (144-145)
flashinfer/fused_moe/utils.py (2)
  • get_last_power_of_2_num_tokens_buckets (206-215)
  • last_positive_power_of_2 (183-188)
🪛 Clang (14.0.6)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h

[error] 23-23: 'cutlass/arch/arch.h' file not found

(clang-diagnostic-error)

🪛 Ruff (0.14.4)
flashinfer/gemm/gemm_base.py

218-218: Avoid specifying long messages outside the exception class

(TRY003)


220-220: Avoid specifying long messages outside the exception class

(TRY003)


230-232: Avoid specifying long messages outside the exception class

(TRY003)


234-236: Avoid specifying long messages outside the exception class

(TRY003)


238-240: Avoid specifying long messages outside the exception class

(TRY003)


284-284: Avoid specifying long messages outside the exception class

(TRY003)


286-286: Avoid specifying long messages outside the exception class

(TRY003)


297-299: Avoid specifying long messages outside the exception class

(TRY003)


301-303: Avoid specifying long messages outside the exception class

(TRY003)


305-307: Avoid specifying long messages outside the exception class

(TRY003)


500-500: Unused method argument: inputs

(ARG002)


501-501: Unused method argument: profile

(ARG002)


509-509: Unused method argument: do_preparation

(ARG002)


510-510: Unused method argument: kwargs

(ARG002)

@raayandhar raayandhar force-pushed the user/rdhar/cutlass_bf16_gemm_sm100 branch from d2c8547 to 8a58e45 on November 16, 2025 22:36
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (7)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)

156-170: Allow workspace probes to succeed when no buffer is provided.

CutlassBf16GemmRunner::getWorkspaceSizeImpl invokes this launcher with workspacePtr == nullptr and workspaceBytes == 0 to query the required size. The current code throws before returning the computed workspace_size, breaking workspace queries. Short-circuit when workspacePtr is nullptr to return the size without running the kernel.

   size_t workspace_size = gemm.get_workspace_size(arguments);
+  if (workspacePtr == nullptr) {
+    return workspace_size;
+  }
   if (workspace_size > workspaceBytes) {
     throw std::runtime_error("[Bf16 Gemm Runner] insufficient workspace");
   }
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)

166-173: Hash function prone to collisions.

The MNKHash function uses XOR to combine hash values (h1 ^ h2 ^ h3), which produces collisions for permutations of the same values. For example, (1, 2, 3) and (3, 2, 1) hash identically, potentially returning incorrect workspace sizes.

Use a proper hash combining algorithm:

   struct MNKHash {
     size_t operator()(const MNK& mnk) const {
       auto h1 = std::hash<int>{}(std::get<0>(mnk));
       auto h2 = std::hash<int>{}(std::get<1>(mnk));
       auto h3 = std::hash<int>{}(std::get<2>(mnk));
-      return h1 ^ h2 ^ h3;
+      // Combine hashes properly to avoid collisions
+      size_t seed = h1;
+      seed ^= h2 + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+      seed ^= h3 + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+      return seed;
     }
   };

175-184: Critical: Data race on static workspace cache.

The static workspace_hashmap at Line 175 is accessed concurrently without synchronization. While C++11+ guarantees thread-safe initialization of function-local statics, concurrent access via find() (Line 178) and operator[] (Lines 180, 182) creates data races if getWorkspaceSize is called from multiple threads.

Protect the map with a mutex:

+  static std::mutex workspace_mutex;
   static std::unordered_map<MNK, size_t, MNKHash> workspace_hashmap;

   size_t workspace_size = 0;
+  std::lock_guard<std::mutex> lock(workspace_mutex);
   if (workspace_hashmap.find(std::make_tuple(m, n, k)) == workspace_hashmap.end()) {

Alternatively, use std::shared_mutex with shared (read) and exclusive (write) locking for better concurrent read performance.

tests/gemm/test_bmm_bf16.py (1)

15-22: Guard test behind CUDA availability.

Calling get_compute_capability(torch.device("cuda")) without first checking torch.cuda.is_available() will raise an exception on non-CUDA systems instead of skipping gracefully.

Add an early CUDA check:

 def test_bmm_bf16(b, m, n, k, res_dtype):
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA is not available")
     compute_capability = get_compute_capability(torch.device(device="cuda"))
flashinfer/gemm/gemm_base.py (3)

182-245: Add input validation for dtype, shape, and device consistency.

The function is missing essential input validation that could lead to cryptic errors downstream. Per previous review feedback, please add checks at the beginning of the function:

+    if a.dtype != torch.bfloat16 or b.dtype != torch.bfloat16:
+        raise ValueError(f"Inputs must be bfloat16. Got a.dtype={a.dtype}, b.dtype={b.dtype}.")
+    if a.shape[1] != b.shape[0]:
+        raise ValueError(
+            f"Shape mismatch for matrix multiplication. "
+            f"a.shape[1]={a.shape[1]} must equal b.shape[0]={b.shape[0]}."
+        )
+    if a.device != b.device:
+        raise ValueError(f"Device mismatch. a.device={a.device}, b.device={b.device}.")
+
     if backend != "cutlass":

248-312: Add input validation for dtype, shape, and device consistency.

Similar to mm_bf16, this function lacks essential input validation. Per previous review feedback, please add checks at the beginning:

+    if A.dtype != torch.bfloat16 or B.dtype != torch.bfloat16:
+        raise ValueError(f"Inputs must be bfloat16. Got A.dtype={A.dtype}, B.dtype={B.dtype}.")
+    if A.ndim != 3 or B.ndim != 3:
+        raise ValueError(f"Expected 3D tensors. Got A.ndim={A.ndim}, B.ndim={B.ndim}.")
+    if A.shape[0] != B.shape[0]:
+        raise ValueError(
+            f"Batch size mismatch. A.shape[0]={A.shape[0]} != B.shape[0]={B.shape[0]}."
+        )
+    if A.shape[2] != B.shape[1]:
+        raise ValueError(
+            f"Shape mismatch for batched matrix multiplication. "
+            f"A.shape[2]={A.shape[2]} must equal B.shape[1]={B.shape[1]}."
+        )
+    if A.device != B.device:
+        raise ValueError(f"Device mismatch. A.device={A.device}, B.device={B.device}.")
+
     if backend != "cutlass":

511-519: Make the B operand contiguous before invoking the CUTLASS runner.

transpose(-2, -1) returns a non-contiguous view, which causes the runtime error you reported: "RuntimeError: Check failed: (mat2.IsContiguous()) is false: mat2 must be contiguous". Per previous review feedback, materialize the column-major buffer before launching the kernel:

+                b_col_major = b.transpose(-2, -1).contiguous()
                 module.bf16_gemm(
                     a,
-                    b.transpose(-2, -1),
+                    b_col_major,
                     out,
                     workspace_buffer,
                     tactic,
                 )
🧹 Nitpick comments (2)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)

187-211: Config enumeration is clean but limited to one configuration.

The getConfigs implementation only enumerates CtaShape64x64x128B and ClusterShape_1x1x1, reflecting the WIP status and SMEM constraints. The nested loop pattern is extensible for adding more configs once SMEM issues are resolved.

Would you like help generating a script to analyze SMEM usage across different tile configurations to understand which sizes are viable for SM100?

flashinfer/jit/gemm/core.py (1)

193-237: WIP tile configurations are appropriate for initial testing.

The implementation correctly follows the established pattern from FP8/FP4 modules. The single active tile configuration (64, 64, 128) is a reasonable conservative choice while debugging SMEM constraints on SM100 hardware, especially given your limited B200 access.

Optional style improvement (flagged by static analysis):

-    return gen_jit_spec(
-        "bf16_gemm_cutlass",
-        source_paths,
-        extra_cuda_cflags=nvcc_flags + ["-DENABLE_BF16"],
-        extra_cflags=[
-            "-DFAST_BUILD",
-        ],
-    )
+    return gen_jit_spec(
+        "bf16_gemm_cutlass",
+        source_paths,
+        extra_cuda_cflags=[*nvcc_flags, "-DENABLE_BF16"],
+        extra_cflags=[
+            "-DFAST_BUILD",
+        ],
+    )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d2c8547 and 8a58e45.

📒 Files selected for processing (14)
  • csrc/bf16_gemm_cutlass.cu (1 hunks)
  • csrc/bf16_gemm_cutlass.jinja (1 hunks)
  • docs/api/gemm.rst (1 hunks)
  • flashinfer/__init__.py (1 hunks)
  • flashinfer/gemm/__init__.py (2 hunks)
  • flashinfer/gemm/gemm_base.py (4 hunks)
  • flashinfer/jit/gemm/__init__.py (2 hunks)
  • flashinfer/jit/gemm/core.py (1 hunks)
  • include/flashinfer/gemm/bf16_gemm_cutlass.h (1 hunks)
  • include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1 hunks)
  • include/flashinfer/gemm/bf16_gemm_template_sm100.h (1 hunks)
  • include/flashinfer/gemm/fp8_gemm_cutlass_template.h (0 hunks)
  • tests/gemm/test_bmm_bf16.py (1 hunks)
  • tests/gemm/test_mm_bf16.py (1 hunks)
💤 Files with no reviewable changes (1)
  • include/flashinfer/gemm/fp8_gemm_cutlass_template.h
🚧 Files skipped from review as they are similar to previous changes (2)
  • csrc/bf16_gemm_cutlass.cu
  • tests/gemm/test_mm_bf16.py
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • csrc/bf16_gemm_cutlass.jinja
  • include/flashinfer/gemm/bf16_gemm_template_sm100.h
  • include/flashinfer/gemm/bf16_gemm_cutlass_template.h
  • flashinfer/gemm/gemm_base.py
  • include/flashinfer/gemm/bf16_gemm_cutlass.h
  • flashinfer/__init__.py
🧬 Code graph analysis (9)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (2)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
  • flashinfer (41-134)
  • gemm (42-91)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
  • flashinfer (26-60)
  • gemm (27-59)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (3)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
  • flashinfer (26-60)
  • gemm (27-59)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (6)
  • gemm (44-176)
  • _1SM (53-57)
  • cutlass (135-135)
  • cutlass (136-136)
  • cutlass (137-137)
  • cutlass (138-138)
include/flashinfer/gemm/cutlass_gemm_configs.h (1)
  • CutlassTileConfigSM100 (106-425)
flashinfer/gemm/gemm_base.py (8)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)
  • gemm (42-91)
include/flashinfer/gemm/bf16_gemm_cutlass.h (1)
  • gemm (27-59)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
  • gemm (44-176)
flashinfer/jit/gemm/core.py (1)
  • gen_gemm_sm100_module_cutlass_bf16 (193-237)
flashinfer/utils.py (2)
  • supported_compute_capability (773-853)
  • _get_cache_buf (205-211)
flashinfer/autotuner.py (7)
  • TunableRunner (194-247)
  • OptimizationProfile (168-183)
  • AutoTuner (335-786)
  • TuningConfig (101-141)
  • DynamicTensorSpec (41-82)
  • ConstraintSpec (86-97)
  • choose_one (400-529)
csrc/bf16_gemm_cutlass.cu (4)
  • bf16_gemm_tactic_num (149-156)
  • bf16_gemm_tactic_num (149-149)
  • bf16_gemm (144-147)
  • bf16_gemm (144-145)
flashinfer/fused_moe/utils.py (2)
  • get_last_power_of_2_num_tokens_buckets (206-215)
  • last_positive_power_of_2 (183-188)
include/flashinfer/gemm/bf16_gemm_cutlass.h (3)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
  • flashinfer (41-134)
  • gemm (42-91)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
  • gemm (44-176)
flashinfer/gemm/gemm_base.py (1)
  • CutlassBf16GemmRunner (496-519)
flashinfer/gemm/__init__.py (1)
flashinfer/gemm/gemm_base.py (2)
  • bmm_bf16 (249-312)
  • mm_bf16 (183-245)
flashinfer/__init__.py (1)
flashinfer/gemm/gemm_base.py (2)
  • bmm_bf16 (249-312)
  • mm_bf16 (183-245)
flashinfer/jit/gemm/__init__.py (1)
flashinfer/jit/gemm/core.py (1)
  • gen_gemm_sm100_module_cutlass_bf16 (193-237)
tests/gemm/test_bmm_bf16.py (2)
flashinfer/gemm/gemm_base.py (1)
  • bmm_bf16 (249-312)
flashinfer/utils.py (2)
  • get_compute_capability (252-255)
  • is_compute_capability_supported (979-994)
flashinfer/jit/gemm/core.py (2)
flashinfer/jit/core.py (2)
  • JitSpec (213-312)
  • gen_jit_spec (315-381)
flashinfer/compilation_context.py (1)
  • get_nvcc_flags_list (50-68)
🪛 Clang (14.0.6)
include/flashinfer/gemm/bf16_gemm_template_sm100.h

[error] 23-23: 'cutlass/arch/arch.h' file not found

(clang-diagnostic-error)

include/flashinfer/gemm/bf16_gemm_cutlass_template.h

[error] 23-23: 'cutlass/arch/arch.h' file not found

(clang-diagnostic-error)

include/flashinfer/gemm/bf16_gemm_cutlass.h

[error] 20-20: 'cuda_runtime_api.h' file not found

(clang-diagnostic-error)

🪛 Ruff (0.14.4)
flashinfer/gemm/gemm_base.py

217-217: Avoid specifying long messages outside the exception class

(TRY003)


219-219: Avoid specifying long messages outside the exception class

(TRY003)


229-231: Avoid specifying long messages outside the exception class

(TRY003)


233-235: Avoid specifying long messages outside the exception class

(TRY003)


237-239: Avoid specifying long messages outside the exception class

(TRY003)


283-283: Avoid specifying long messages outside the exception class

(TRY003)


285-285: Avoid specifying long messages outside the exception class

(TRY003)


296-298: Avoid specifying long messages outside the exception class

(TRY003)


300-302: Avoid specifying long messages outside the exception class

(TRY003)


304-306: Avoid specifying long messages outside the exception class

(TRY003)


499-499: Unused method argument: inputs

(ARG002)


500-500: Unused method argument: profile

(ARG002)


508-508: Unused method argument: do_preparation

(ARG002)


509-509: Unused method argument: kwargs

(ARG002)

flashinfer/jit/gemm/core.py

233-233: Consider [*nvcc_flags, "-DENABLE_BF16"] instead of concatenation

Replace with [*nvcc_flags, "-DENABLE_BF16"]

(RUF005)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (12)
include/flashinfer/gemm/bf16_gemm_cutlass.h (1)

1-62: LGTM! Clean interface/implementation pattern.

The abstract interface and templated concrete class follow best practices for extensibility. The separation of public getWorkspaceSize and private getWorkspaceSizeImpl suggests proper encapsulation of workspace size computation logic.

include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)

46-151: LGTM! Standard CUTLASS GEMM setup.

The SMTypeAdapter specializations and launcher configuration follow CUTLASS patterns correctly. Regarding the comment on Line 147: setting fusion_args.alpha = 1.0f and fusion_args.beta = 0.0f is the standard way to configure a GEMM epilogue for D = A*B (no accumulation). This is the right approach.
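Spelled out, the CUTLASS epilogue computes D = alpha * (A * B) + beta * C, so alpha = 1 and beta = 0 reduce it to the plain product D = A * B with no accumulation into a C operand.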

include/flashinfer/gemm/bf16_gemm_cutlass_template.h (3)

136-143: LGTM! Clean forwarding to dispatcher.


145-160: Exception handling is appropriate for config probing.

The pattern of catching and ignoring std::runtime_error when probing workspace sizes is acceptable, as some configurations may legitimately fail due to SMEM constraints. The comment on Line 155 documents the rationale clearly.

Based on learnings


44-134: No changes needed—review comment is accurate.

Verification confirms the bf16 dispatcher is intentionally limited to CtaShape64x64x128B and ClusterShape_1x1x1 (lines 100–103 and 65–67), while other tile configs and cluster shapes remain commented out. This differs from the fp8 implementation, which enables multiple configurations, confirming the bf16 limitation is deliberate due to SMEM constraints as noted. The transpose pattern (swapping B, A and n, m at line 101–102) is correct for layout handling.

docs/api/gemm.rst (1)

10-18: LGTM! Documentation follows existing patterns.

The new BF16 GEMM section properly documents the mm_bf16 and bmm_bf16 entry points, following the same autosummary format as other GEMM types in this file.

flashinfer/__init__.py (1)

88-90: LGTM! BF16 GEMM exports are now available.

The imports of bmm_bf16 and mm_bf16 from the gemm module expose the new BF16 GEMM functionality at the top level. Past review comments indicate the necessary exports were added to flashinfer/gemm/__init__.py.

flashinfer/jit/gemm/__init__.py (1)

22-22: LGTM! JIT generator export follows existing patterns.

The gen_gemm_sm100_module_cutlass_bf16 import and export are consistent with other GEMM generators in this module.

Also applies to: 37-37

flashinfer/gemm/__init__.py (1)

2-2: LGTM! GEMM module exports are properly configured.

The bmm_bf16 and mm_bf16 imports from gemm_base and their inclusion in __all__ enable the top-level imports in flashinfer/__init__.py to work correctly.

Also applies to: 4-4, 25-25, 27-27

tests/gemm/test_bmm_bf16.py (1)

23-34: LGTM! Test logic is sound.

The test correctly creates BF16 inputs, computes a reference with torch.bmm, and validates the bmm_bf16 output using cosine similarity. The threshold of 0.99 is reasonable for BF16 precision.
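For illustration, a minimal sketch of that validation pattern; the shapes are arbitrary, the bmm_bf16 signature is assumed to mirror the mm_bf16 test (out= and out_dtype= keywords), and it needs an SM100-class GPU to run:

import torch
from flashinfer import bmm_bf16  # public API added in this PR; requires an SM100 GPU

b, m, n, k = 4, 32, 64, 128  # illustrative sizes; the real test parametrizes these
A = torch.randn(b, m, k, device="cuda", dtype=torch.bfloat16)
B = torch.randn(b, k, n, device="cuda", dtype=torch.bfloat16)

reference = torch.bmm(A, B)  # BF16 reference
out = torch.empty(b, m, n, device="cuda", dtype=torch.bfloat16)
bmm_bf16(A, B, out=out, out_dtype=torch.bfloat16)

# Cosine similarity over the flattened outputs; 0.99 leaves headroom for BF16 rounding.
cos_sim = torch.nn.functional.cosine_similarity(
    reference.float().flatten(), out.float().flatten(), dim=0
)
assert cos_sim > 0.99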

csrc/bf16_gemm_cutlass.jinja (1)

1-27: LGTM! Clean template structure with conservative defaults.

The template correctly instantiates the SM100 BF16 GEMM kernel with a single-SM cluster configuration (1,1,1), which is appropriate for initial testing. The commented-out multi-SM cluster configurations provide clear guidance for future performance tuning once the basic implementation is validated.

flashinfer/gemm/gemm_base.py (1)

577-616: LGTM! AutoTuner integration follows established patterns.

The function correctly uses a.device for SM version checking and properly integrates with the AutoTuner for dynamic tactic selection. The tuning configuration appropriately profiles on the M dimension using power-of-2 bucketing, matching the pattern used in fp8_gemm_sm100.
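As a rough illustration of the bucketing idea (a sketch only, not the actual get_last_power_of_2_num_tokens_buckets / last_positive_power_of_2 helpers):

def round_down_to_power_of_2(x: int) -> int:
    # Map an arbitrary M to the largest power of two <= M so that a tactic
    # tuned for one bucket is reused across nearby problem sizes.
    p = 1
    while p * 2 <= x:
        p *= 2
    return p

for m in (1, 3, 17, 100, 1000):
    print(m, "->", round_down_to_power_of_2(m))
# 1 -> 1, 3 -> 2, 17 -> 16, 100 -> 64, 1000 -> 512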

@raayandhar raayandhar force-pushed the user/rdhar/cutlass_bf16_gemm_sm100 branch from dcbc17a to 28baee5 on November 17, 2025 05:10
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (7)
tests/gemm/test_mm_bf16.py (2)

25-31: Use row‑major (k, n) weight and avoid .T to keep inputs contiguous.

Generate mat2 as (k, n), use it directly in both the reference and API call. This prevents passing a non‑contiguous transpose and matches the documented contract.

-    mat2 = torch.randn([n, k], device="cuda", dtype=torch.bfloat16)
-
-    reference = torch.mm(input, mat2.T)
+    mat2 = torch.randn([k, n], device="cuda", dtype=torch.bfloat16)
+    reference = torch.mm(input, mat2)
 ...
-        mm_bf16(input, mat2.T, out=out, out_dtype=res_dtype)
+        mm_bf16(input, mat2, out=out, out_dtype=res_dtype)

14-16: Skip on CPU-only to avoid hard failure.

Add CUDA-availability guard before calling get_compute_capability.

 def test_mm_bf16(m: int, n: int, k: int, res_dtype: torch.dtype):
-    compute_capability = get_compute_capability(torch.device(device="cuda"))
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA is not available")
+    compute_capability = get_compute_capability(torch.device(device="cuda"))
tests/gemm/test_bmm_bf16.py (1)

14-16: Skip on CPU-only to avoid hard failure.

Guard get_compute_capability(torch.device("cuda")) with a CUDA-availability check so the test skips instead of crashing on CPU-only runners.

 def test_bmm_bf16(b, m, n, k, res_dtype):
-    compute_capability = get_compute_capability(torch.device(device="cuda"))
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA is not available")
+    compute_capability = get_compute_capability(torch.device(device="cuda"))
flashinfer/gemm/gemm_base.py (3)

182-205: Add essential input validation (dtype, shape, device).

Fail fast with clear errors to avoid cryptic backend failures.

 def mm_bf16(
     a: torch.Tensor,
     b: torch.Tensor,
@@
 ) -> torch.Tensor:
@@
-    if backend != "cutlass":
+    # Basic validations
+    if a.dtype != torch.bfloat16 or b.dtype != torch.bfloat16:
+        raise ValueError(f"Inputs must be bfloat16. Got a.dtype={a.dtype}, b.dtype={b.dtype}.")
+    if a.ndim != 2 or b.ndim != 2:
+        raise ValueError(f"mm_bf16 expects 2D tensors. Got a.ndim={a.ndim}, b.ndim={b.ndim}.")
+    if a.shape[1] != b.shape[0]:
+        raise ValueError(
+            f"Shape mismatch: a.shape[1]={a.shape[1]} must equal b.shape[0]={b.shape[0]}."
+        )
+    if a.device != b.device:
+        raise ValueError(f"Device mismatch. a.device={a.device}, b.device={b.device}.")
+
+    if backend != "cutlass":
         raise ValueError(f"Unsupported backend: {backend}. Only cutlass is available.")

259-283: Add essential input validation (batched dtype/shape/device).

Validate 3D inputs, batch, and K dims before launching the kernel.

 def bmm_bf16(
     A: torch.Tensor,
     B: torch.Tensor,
@@
 ) -> torch.Tensor:
@@
-    if backend != "cutlass":
+    # Basic validations
+    if A.dtype != torch.bfloat16 or B.dtype != torch.bfloat16:
+        raise ValueError(f"Inputs must be bfloat16. Got A.dtype={A.dtype}, B.dtype={B.dtype}.")
+    if A.ndim != 3 or B.ndim != 3:
+        raise ValueError(f"bmm_bf16 expects 3D tensors. Got A.ndim={A.ndim}, B.ndim={B.ndim}.")
+    if A.shape[0] != B.shape[0]:
+        raise ValueError(f"Batch size mismatch: A.shape[0]={A.shape[0]} != B.shape[0]={B.shape[0]}.")
+    if A.shape[2] != B.shape[1]:
+        raise ValueError(
+            f"K mismatch: A.shape[2]={A.shape[2]} must equal B.shape[1]={B.shape[1]}."
+        )
+    if A.device != B.device:
+        raise ValueError(f"Device mismatch. A.device={A.device}, B.device={B.device}.")
+
+    if backend != "cutlass":
         raise ValueError(f"Unsupported backend: {backend}. Only cutlass is available.")

533-541: Fix runtime error: make B contiguous before calling the CUTLASS binding.

b.transpose(-2, -1) is a non‑contiguous view; the C++ binding asserts contiguity (“mat2 must be contiguous”). Materialize column‑major B.

-                module.bf16_gemm(
-                    a,
-                    b.transpose(-2, -1),
-                    out,
-                    workspace_buffer,
-                    tactic,
-                )
+                b_col_major = b.transpose(-2, -1).contiguous()
+                module.bf16_gemm(
+                    a,
+                    b_col_major,
+                    out,
+                    workspace_buffer,
+                    tactic,
+                )
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)

157-175: Thread-safety and hash quality in workspace cache.

  • XOR-combining hashes collides easily.
  • The static workspace_hashmap is accessed unsafely across threads.

Harden both.

@@
-  struct MNKHash {
+  struct MNKHash {
     size_t operator()(const MNK& mnk) const {
       auto h1 = std::hash<int>{}(std::get<0>(mnk));
       auto h2 = std::hash<int>{}(std::get<1>(mnk));
       auto h3 = std::hash<int>{}(std::get<2>(mnk));
-      return h1 ^ h2 ^ h3;
+      // Robust hash combine to reduce collisions
+      size_t seed = h1;
+      seed ^= h2 + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+      seed ^= h3 + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+      return seed;
     }
   };
@@
-  static std::unordered_map<MNK, size_t, MNKHash> workspace_hashmap;
+  static std::mutex workspace_mutex;
+  static std::unordered_map<MNK, size_t, MNKHash> workspace_hashmap;
@@
-  size_t workspace_size = 0;
-  if (workspace_hashmap.find(std::make_tuple(m, n, k)) == workspace_hashmap.end()) {
-    workspace_size = CutlassBf16GemmRunner<T>::getWorkspaceSizeImpl(m, n, k);
-    workspace_hashmap[std::make_tuple(m, n, k)] = workspace_size;
-  } else {
-    workspace_size = workspace_hashmap[std::make_tuple(m, n, k)];
-  }
-  return workspace_size;
+  const MNK key = std::make_tuple(m, n, k);
+  {
+    std::lock_guard<std::mutex> lock(workspace_mutex);
+    auto it = workspace_hashmap.find(key);
+    if (it != workspace_hashmap.end()) {
+      return it->second;
+    }
+  }
+  // Compute outside lock to avoid blocking others; insert with lock.
+  size_t computed = CutlassBf16GemmRunner<T>::getWorkspaceSizeImpl(m, n, k);
+  {
+    std::lock_guard<std::mutex> lock(workspace_mutex);
+    auto it = workspace_hashmap.find(key);
+    if (it == workspace_hashmap.end()) {
+      workspace_hashmap.emplace(key, computed);
+      return computed;
+    }
+    return it->second;
+  }

Also add the include near the top:

-#include <stdexcept>
+#include <stdexcept>
+#include <mutex>
🧹 Nitpick comments (2)
flashinfer/jit/gemm/core.py (1)

229-236: Minor: prefer list splat over concatenation (RUF005).

Use list unpacking for readability in extra_cuda_cflags.

-    return gen_jit_spec(
+    return gen_jit_spec(
         "bf16_gemm_cutlass",
         source_paths,
-        extra_cuda_cflags=nvcc_flags + ["-DENABLE_BF16"],
+        extra_cuda_cflags=[*nvcc_flags, "-DENABLE_BF16"],
         extra_cflags=[
             "-DFAST_BUILD",
         ],
     )
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)

101-104: Remove unused template parameter or use it.

genericBf16GemmKernelLauncherSm100 has template param arch but hardcodes ArchTag = cutlass::arch::Sm100. Either use arch or drop the parameter.

-  using ArchTag = cutlass::arch::Sm100;
+  using ArchTag = arch;
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8a58e45 and 28baee5.

📒 Files selected for processing (14)
  • csrc/bf16_gemm_cutlass.cu (1 hunks)
  • csrc/bf16_gemm_cutlass.jinja (1 hunks)
  • docs/api/gemm.rst (1 hunks)
  • flashinfer/__init__.py (1 hunks)
  • flashinfer/gemm/__init__.py (2 hunks)
  • flashinfer/gemm/gemm_base.py (4 hunks)
  • flashinfer/jit/gemm/__init__.py (2 hunks)
  • flashinfer/jit/gemm/core.py (1 hunks)
  • include/flashinfer/gemm/bf16_gemm_cutlass.h (1 hunks)
  • include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1 hunks)
  • include/flashinfer/gemm/bf16_gemm_template_sm100.h (1 hunks)
  • include/flashinfer/gemm/fp8_gemm_cutlass_template.h (0 hunks)
  • tests/gemm/test_bmm_bf16.py (1 hunks)
  • tests/gemm/test_mm_bf16.py (1 hunks)
💤 Files with no reviewable changes (1)
  • include/flashinfer/gemm/fp8_gemm_cutlass_template.h
🚧 Files skipped from review as they are similar to previous changes (5)
  • docs/api/gemm.rst
  • csrc/bf16_gemm_cutlass.cu
  • flashinfer/__init__.py
  • flashinfer/gemm/__init__.py
  • flashinfer/jit/gemm/__init__.py
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • csrc/bf16_gemm_cutlass.jinja
  • include/flashinfer/gemm/bf16_gemm_template_sm100.h
  • include/flashinfer/gemm/bf16_gemm_cutlass.h
  • flashinfer/gemm/gemm_base.py
  • include/flashinfer/gemm/bf16_gemm_cutlass_template.h
🧬 Code graph analysis (7)
tests/gemm/test_mm_bf16.py (3)
flashinfer/autotuner.py (1)
  • autotune (251-262)
flashinfer/gemm/gemm_base.py (1)
  • mm_bf16 (183-256)
flashinfer/utils.py (2)
  • get_compute_capability (252-255)
  • is_compute_capability_supported (979-994)
tests/gemm/test_bmm_bf16.py (3)
flashinfer/autotuner.py (1)
  • autotune (251-262)
flashinfer/gemm/gemm_base.py (1)
  • bmm_bf16 (260-334)
flashinfer/utils.py (2)
  • get_compute_capability (252-255)
  • is_compute_capability_supported (979-994)
flashinfer/jit/gemm/core.py (2)
flashinfer/jit/core.py (2)
  • JitSpec (213-312)
  • gen_jit_spec (315-381)
flashinfer/compilation_context.py (1)
  • get_nvcc_flags_list (50-68)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (2)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
  • flashinfer (41-125)
  • gemm (42-91)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
  • flashinfer (26-60)
  • gemm (27-59)
include/flashinfer/gemm/bf16_gemm_cutlass.h (3)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
  • flashinfer (41-125)
  • gemm (42-91)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
  • gemm (44-182)
flashinfer/gemm/gemm_base.py (1)
  • CutlassBf16GemmRunner (518-541)
flashinfer/gemm/gemm_base.py (8)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)
  • gemm (42-91)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
  • gemm (44-182)
include/flashinfer/gemm/bf16_gemm_cutlass.h (1)
  • gemm (27-59)
flashinfer/jit/gemm/core.py (1)
  • gen_gemm_sm100_module_cutlass_bf16 (193-236)
flashinfer/utils.py (2)
  • supported_compute_capability (773-853)
  • _get_cache_buf (205-211)
flashinfer/autotuner.py (5)
  • TunableRunner (194-247)
  • OptimizationProfile (168-183)
  • TuningConfig (101-141)
  • DynamicTensorSpec (41-82)
  • ConstraintSpec (86-97)
csrc/bf16_gemm_cutlass.cu (4)
  • bf16_gemm_tactic_num (149-156)
  • bf16_gemm_tactic_num (149-149)
  • bf16_gemm (144-147)
  • bf16_gemm (144-145)
flashinfer/fused_moe/utils.py (2)
  • get_last_power_of_2_num_tokens_buckets (206-215)
  • last_positive_power_of_2 (183-188)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (7)
  • gemm (44-182)
  • _1SM (53-57)
  • _2SM (60-64)
  • cutlass (135-135)
  • cutlass (136-136)
  • cutlass (137-137)
  • cutlass (138-138)
include/flashinfer/gemm/cutlass_gemm_configs.h (1)
  • CutlassTileConfigSM100 (106-425)
🪛 Clang (14.0.6)
include/flashinfer/gemm/bf16_gemm_template_sm100.h

[error] 23-23: 'cutlass/arch/arch.h' file not found

(clang-diagnostic-error)

include/flashinfer/gemm/bf16_gemm_cutlass.h

[error] 20-20: 'cuda_runtime_api.h' file not found

(clang-diagnostic-error)

include/flashinfer/gemm/bf16_gemm_cutlass_template.h

[error] 23-23: 'cutlass/arch/arch.h' file not found

(clang-diagnostic-error)

🪛 Ruff (0.14.4)
flashinfer/jit/gemm/core.py

232-232: Consider [*nvcc_flags, "-DENABLE_BF16"] instead of concatenation

Replace with [*nvcc_flags, "-DENABLE_BF16"]

(RUF005)

flashinfer/gemm/gemm_base.py

228-228: Avoid specifying long messages outside the exception class

(TRY003)


230-230: Avoid specifying long messages outside the exception class

(TRY003)


240-242: Avoid specifying long messages outside the exception class

(TRY003)


244-246: Avoid specifying long messages outside the exception class

(TRY003)


248-250: Avoid specifying long messages outside the exception class

(TRY003)


305-305: Avoid specifying long messages outside the exception class

(TRY003)


307-307: Avoid specifying long messages outside the exception class

(TRY003)


318-320: Avoid specifying long messages outside the exception class

(TRY003)


322-324: Avoid specifying long messages outside the exception class

(TRY003)


326-328: Avoid specifying long messages outside the exception class

(TRY003)


521-521: Unused method argument: inputs

(ARG002)


522-522: Unused method argument: profile

(ARG002)


530-530: Unused method argument: do_preparation

(ARG002)


531-531: Unused method argument: kwargs

(ARG002)

🔇 Additional comments (2)
csrc/bf16_gemm_cutlass.jinja (1)

17-26: Instantiation set looks good.

Coverage of cluster shapes for 1SM/2SM variants matches the SM100 launcher; no issues spotted.

include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)

152-157: Good: explicit workspace probe path.

Early-return on null A/B/D ensures getWorkspaceSizeImpl can probe without needing a buffer. This unblocks tactic sizing. Based on learnings.

@raayandhar raayandhar changed the title from "feat: (wip) BF16 GEMM using CUTLASS backend for SM100" to "feat: BF16 GEMM using CUTLASS backend for SM100" on Nov 17, 2025
@raayandhar
Contributor Author

raayandhar commented Nov 17, 2025

Hi experts, I think this is now ready for review!
I had more trouble than I expected, even with the existing FP8 and FP4 CUTLASS GEMM implementations to reference, and I learned a lot working on this since it was my first time working with CUTLASS.

Right now we are passing all the tests that I wrote for this feature:

Test Results:
(flashinfer) root@spry-shaggy-smilodon:~/flashinfer# pytest tests/gemm/test_bmm_bf16.py
======================================================== test session starts =========================================================
platform linux -- Python 3.12.12, pytest-9.0.1, pluggy-1.6.0
rootdir: /root/flashinfer
configfile: pytest.ini
collected 32 items                                                                                                                   

tests/gemm/test_bmm_bf16.py ................................                                                                   [100%]

========================================================= 32 passed in 2.37s =========================================================
(flashinfer) root@spry-shaggy-smilodon:~/flashinfer# pytest tests/gemm/test_mm_bf16.py
======================================================== test session starts =========================================================
platform linux -- Python 3.12.12, pytest-9.0.1, pluggy-1.6.0
rootdir: /root/flashinfer
configfile: pytest.ini
collected 90 items                                                                                                                   

tests/gemm/test_mm_bf16.py ..........................................................................................          [100%]

========================================================= 90 passed in 3.45s =========================================================

The original issue (#1974) asked whether a CUTLASS-backend BF16 GEMM could do better at small batch sizes. Running linear_mm with autotuning gives the following results (updated script here):

Benchmark Results:
(flashinfer) root@spry-shaggy-smilodon:~/flashinfer# python benchmark_linear.py 
2025-11-17 05:53:37,584 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-17 05:53:37,859 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
batch=1
==============================================================================================================
SUMMARY COMPARISON
==============================================================================================================
Variant                             Median (us)     Std (us)     GFLOPS       BW (GB/s)    Speedup
--------------------------------------------------------------------------------------------------------------
1. Original                         12.354560       0.010645     1357.98      1358.97      1.00x
2. torch.compile()                  12.349120       0.016171     1358.58      1359.57      1.00x
3. max-autotune ncg                 5.017280        0.021958     3343.89      3346.34      2.46x
4. TGV GEMM pdl=False               6.540480        0.010113     2565.14      2567.01      1.89x
5. TGV GEMM pdl=True                6.099520        0.015821     2750.58      2752.59      2.03x
6. MM BF16                          9.650880        0.010973     1738.41      1739.69      1.28x

batch=2
==============================================================================================================
SUMMARY COMPARISON
==============================================================================================================
Variant                             Median (us)     Std (us)     GFLOPS       BW (GB/s)    Speedup
--------------------------------------------------------------------------------------------------------------
1. Original                         11.057280       0.015711     3034.60      1519.52      1.00x
2. torch.compile()                  11.058559       0.009181     3034.25      1519.35      1.00x
3. max-autotune ncg                 11.062080       0.015267     3033.28      1518.86      1.00x
4. TGV GEMM pdl=False               6.560000        0.019785     5115.00      2561.25      1.69x
5. TGV GEMM pdl=True                6.123200        0.017838     5479.88      2743.96      1.81x
6. MM BF16                          7.129600        0.013278     4706.36      2356.62      1.55x

batch=4
==============================================================================================================
SUMMARY COMPARISON
==============================================================================================================
Variant                             Median (us)     Std (us)     GFLOPS       BW (GB/s)    Speedup
--------------------------------------------------------------------------------------------------------------
1. Original                         11.060160       0.012899     6067.62      1521.35      1.00x
2. torch.compile()                  11.062720       0.015620     6066.22      1521.00      1.00x
3. max-autotune ncg                 11.064000       0.015849     6065.52      1520.82      1.00x
4. TGV GEMM pdl=False               6.556480        0.015751     10235.50     2566.37      1.69x
5. TGV GEMM pdl=True                6.122880        0.021713     10960.34     2748.11      1.81x
6. MM BF16                          7.208640        0.012594     9309.50      2334.19      1.53x

batch=8
==============================================================================================================
SUMMARY COMPARISON
==============================================================================================================
Variant                             Median (us)     Std (us)     GFLOPS       BW (GB/s)    Speedup
--------------------------------------------------------------------------------------------------------------
1. Original                         12.005440       0.011114     11179.74     1405.66      1.00x
2. torch.compile()                  11.999360       0.015240     11185.41     1406.37      1.00x
3. max-autotune ncg                 11.997440       0.015384     11187.20     1406.59      1.00x
4. TGV GEMM pdl=False               6.547520        0.011645     20499.02     2577.39      1.83x
5. TGV GEMM pdl=True                6.127040        0.017312     21905.80     2754.27      1.96x
6. MM BF16                          7.579840        0.028135     17707.20     2226.37      1.58x

batch=16
==============================================================================================================
SUMMARY COMPARISON
==============================================================================================================
Variant                             Median (us)     Std (us)     GFLOPS       BW (GB/s)    Speedup
--------------------------------------------------------------------------------------------------------------
1. Original                         8.046400        0.018652     33360.94     2109.49      1.00x
2. torch.compile()                  8.064640        0.016166     33285.48     2104.72      1.00x
3. max-autotune ncg                 8.056320        0.021365     33319.86     2106.90      1.00x
4. TGV GEMM pdl=False               6.656000        0.014604     40329.85     2550.15      1.21x
5. TGV GEMM pdl=True                6.190080        0.013137     43365.43     2742.10      1.30x
6. MM BF16                          7.120960        0.015677     37696.53     2383.64      1.13x

batch=32
==============================================================================================================
SUMMARY COMPARISON
==============================================================================================================
Variant                             Median (us)     Std (us)     GFLOPS       BW (GB/s)    Speedup
--------------------------------------------------------------------------------------------------------------
1. Original                         8.078080        0.018555     66460.21     2125.56      1.00x
2. torch.compile()                  8.084480        0.010366     66407.60     2123.88      1.00x
3. max-autotune ncg                 8.063040        0.016551     66584.18     2129.52      1.00x
4. TGV GEMM pdl=False               6.957120        0.017935     77168.56     2468.04      1.16x
5. TGV GEMM pdl=True                6.484800        0.011767     82789.12     2647.80      1.25x
6. MM BF16                          7.397760        0.004454     72572.09     2321.03      1.09x

batch=64
==============================================================================================================
SUMMARY COMPARISON
==============================================================================================================
Variant                             Median (us)     Std (us)     GFLOPS       BW (GB/s)    Speedup
--------------------------------------------------------------------------------------------------------------
1. Original                         8.182080        0.045039     131230.92    2146.60      1.00x
2. torch.compile()                  8.167680        0.016335     131462.28    2150.38      1.00x
3. max-autotune ncg                 8.189760        0.012934     131107.85    2144.59      1.00x
4. TGV GEMM pdl=False               12.971520       0.028602     82776.87     1354.02      0.63x
5. TGV GEMM pdl=True                12.534400       0.012996     85663.60     1401.24      0.65x
6. MM BF16                          7.150080        0.020649     150172.00    2456.43      1.14x

(flashinfer) root@spry-shaggy-smilodon:~/flashinfer# 

The highlight is that at a larger batch size such as batch=64 we are at ~7 microseconds, while TGV sits at ~12.5-13 microseconds and the original/torch.compile() variants are at ~8 microseconds. I'm a CUTLASS newbie, so adding more tile sizes and cluster shapes and autotuning over a wider space may yield even better performance, since we are slightly slower elsewhere.

For reviewers:

  • I just picked tile sizes and cluster shapes that compiled, taken from the FP8 implementations. We may be able to get better performance with more tile sizes and cluster shapes, especially when targeting smaller batch sizes.
  • Let me know if different test coverage is needed in terms of sizes.
  • This is SM100 only. I think the work can largely be reused for SM120; if you are happy with the changes here, I'd be glad to tackle SM120 as well. On that note, it may also be worth adding a cuDNN backend option if its performance turns out to be better.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 28baee5 and d3a53cd.

📒 Files selected for processing (1)
  • include/flashinfer/gemm/bf16_gemm_template_sm100.h (1 hunks)
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • include/flashinfer/gemm/bf16_gemm_template_sm100.h
🧬 Code graph analysis (1)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (2)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
  • flashinfer (41-125)
  • gemm (42-91)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
  • flashinfer (26-60)
  • gemm (27-59)
🪛 Clang (14.0.6)
include/flashinfer/gemm/bf16_gemm_template_sm100.h

[error] 23-23: 'cutlass/arch/arch.h' file not found

(clang-diagnostic-error)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (2)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (2)

183-190: LGTM! Macro enables flexible template instantiation.

The INSTANCE_BF16_GEMM_TEMPLATE_SM100 macro correctly provides explicit template instantiation control for different tile sizes, cluster shapes, and SM types. The parameter list matches the launcher's template signature, and the macro will be used by the JIT generator to instantiate specific configurations.


49-50: No changes needed; the forward declarations are intentional.

The forward declarations in this file are not problematic duplicates requiring consolidation. The actual struct definitions of _1SM and _2SM are in include/flashinfer/gemm/bf16_gemm_cutlass_template.h (lines 44-45), while the SM100 template files provide independent forward declarations. This is the correct C++ pattern: the base template defines the types, and SM100 template files forward-declare them to specialize SMTypeAdapter<_1SM> and SMTypeAdapter<_2SM> without incurring unnecessary includes. This separation of concerns is appropriate and consistent across all GEMM implementations (bf16, fp8, fp4).

Likely an incorrect or invalid review comment.

throw std::runtime_error("[Bf16 Gemm Runner] insufficient workspace");
}

auto can_implement = gemm.can_implement(arguments);
Contributor Author

Is there any advantage to doing these safety checks this way instead of just using the CUTLASS_CHECK macro? I saw it done this way for FP8 and FP4, so I kept the same pattern, but the two approaches look equivalent to me.

@raayandhar raayandhar force-pushed the user/rdhar/cutlass_bf16_gemm_sm100 branch from 1387bed to a56d74b on November 19, 2025 00:23
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (2)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)

136-176: Guard static workspace cache against concurrent access

CutlassBf16GemmRunner<T>::getWorkspaceSize uses a function‑local static std::unordered_map to memoize workspace sizes, but accesses it without synchronization:

  • workspace_hashmap.find(...) and workspace_hashmap[...] = ... can race when getWorkspaceSize is called concurrently from multiple threads, leading to undefined behavior.

Given this is a header template used across translation units, it’s easy for callers to hit it from multiple threads (e.g., multi-stream inference).

Consider protecting the map with a mutex:

@@
   struct MNKHash {
     size_t operator()(const MNK& mnk) const {
       auto h1 = std::hash<int>{}(std::get<0>(mnk));
       auto h2 = std::hash<int>{}(std::get<1>(mnk));
       auto h3 = std::hash<int>{}(std::get<2>(mnk));
       return h1 ^ h2 ^ h3;
     }
   };
 
-  static std::unordered_map<MNK, size_t, MNKHash> workspace_hashmap;
+  static std::unordered_map<MNK, size_t, MNKHash> workspace_hashmap;
+  static std::mutex workspace_mutex;
@@
-  size_t workspace_size = 0;
-  if (workspace_hashmap.find(std::make_tuple(m, n, k)) == workspace_hashmap.end()) {
-    workspace_size = CutlassBf16GemmRunner<T>::getWorkspaceSizeImpl(m, n, k);
-    workspace_hashmap[std::make_tuple(m, n, k)] = workspace_size;
-  } else {
-    workspace_size = workspace_hashmap[std::make_tuple(m, n, k)];
-  }
-  return workspace_size;
+  const MNK key = std::make_tuple(m, n, k);
+  size_t workspace_size;
+  std::lock_guard<std::mutex> lock(workspace_mutex);
+  auto it = workspace_hashmap.find(key);
+  if (it == workspace_hashmap.end()) {
+    workspace_size = CutlassBf16GemmRunner<T>::getWorkspaceSizeImpl(m, n, k);
+    workspace_hashmap.emplace(key, workspace_size);
+  } else {
+    workspace_size = it->second;
+  }
+  return workspace_size;

You may also want to upgrade MNKHash to a stronger combiner (e.g., boost::hash_combine style) to reduce collision risk, though that’s a secondary concern.

tests/gemm/test_mm_bf16.py (1)

13-21: Skip BF16 MM test cleanly when CUDA is unavailable

This test assumes CUDA is present:

compute_capability = get_compute_capability(torch.device(device="cuda"))

On CPU-only machines this will raise instead of skipping. Guard before any CUDA calls:

 def test_mm_bf16(m: int, n: int, k: int, res_dtype: torch.dtype):
-    compute_capability = get_compute_capability(torch.device(device="cuda"))
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA is not available; skipping mm_bf16 tests.")
+
+    compute_capability = get_compute_capability(torch.device(device="cuda"))

The rest of the test (layouts and reference computation) looks consistent with the mm_bf16 API.

🧹 Nitpick comments (3)
flashinfer/jit/gemm/core.py (1)

193-236: SM100 BF16 JIT generator looks consistent; only a minor style nit on flag concatenation

Generation directory, sources, tile/dtype loops, and NVCC flags match the existing SM100 FP8/FP4 patterns and look correct. You could adopt the small Ruff suggestion for cleanliness:

-    return gen_jit_spec(
-        "bf16_gemm_cutlass",
-        source_paths,
-        extra_cuda_cflags=nvcc_flags + ["-DENABLE_BF16"],
-        extra_cflags=[
-            "-DFAST_BUILD",
-        ],
-    )
+    return gen_jit_spec(
+        "bf16_gemm_cutlass",
+        source_paths,
+        extra_cuda_cflags=[*nvcc_flags, "-DENABLE_BF16"],
+        extra_cflags=["-DFAST_BUILD"],
+    )
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)

66-178: SM100 BF16 kernel launcher matches existing CUTLASS patterns; only minor nits

The launcher’s layout/stride setup, workspace query (!A && !B && !D), can_implement/initialize/run flow, and error handling look sound and consistent with the FP8 templates. The unused arch template parameter and CutlassGemmConfig config argument are harmless but could be dropped or used later if you add per-config tuning; likewise, caching gemm.get_workspace_size(arguments) into a local would avoid recomputing it three times. No functional issues spotted here.

csrc/bf16_gemm_cutlass.cu (1)

49-139: FFI BF16 GEMM wiring, shapes, and workspace handling look correct

The FFI layer’s behavior is consistent end‑to‑end:

  • getBf16GemmConfig pulls configs once and indexes them by tactic, matching the Python autotuner’s “tactic == index” contract.
  • runGemm uses CutlassBf16GemmRunner::getWorkspaceSize to size the workspace, allocates a temporary buffer when the cached one is too small, and always passes a sufficient workspaceBytes to the runner.
  • bf16_bmm_impl’s 2D and 3D shape logic (mat1: (m,k)/(b,m,k), mat2: (n,k)/(b,n,k)) matches the Python side’s extra transpose, and the out‑shape checks are sound.

Only minor nit: the m, n, k parameters of getBf16GemmConfig are currently unused; you could drop them or use them later for heuristic selection.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d3a53cd and a56d74b.

📒 Files selected for processing (13)
  • csrc/bf16_gemm_cutlass.cu (1 hunks)
  • csrc/bf16_gemm_cutlass.jinja (1 hunks)
  • docs/api/gemm.rst (1 hunks)
  • flashinfer/__init__.py (1 hunks)
  • flashinfer/gemm/__init__.py (2 hunks)
  • flashinfer/gemm/gemm_base.py (4 hunks)
  • flashinfer/jit/gemm/__init__.py (2 hunks)
  • flashinfer/jit/gemm/core.py (1 hunks)
  • include/flashinfer/gemm/bf16_gemm_cutlass.h (1 hunks)
  • include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1 hunks)
  • include/flashinfer/gemm/bf16_gemm_template_sm100.h (1 hunks)
  • tests/gemm/test_bmm_bf16.py (1 hunks)
  • tests/gemm/test_mm_bf16.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (5)
  • tests/gemm/test_bmm_bf16.py
  • docs/api/gemm.rst
  • csrc/bf16_gemm_cutlass.jinja
  • flashinfer/__init__.py
  • flashinfer/gemm/__init__.py
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • include/flashinfer/gemm/bf16_gemm_template_sm100.h
  • include/flashinfer/gemm/bf16_gemm_cutlass.h
  • csrc/bf16_gemm_cutlass.cu
  • flashinfer/gemm/gemm_base.py
  • include/flashinfer/gemm/bf16_gemm_cutlass_template.h
🧬 Code graph analysis (8)
flashinfer/jit/gemm/core.py (2)
flashinfer/jit/core.py (2)
  • JitSpec (213-312)
  • gen_jit_spec (315-381)
flashinfer/compilation_context.py (1)
  • get_nvcc_flags_list (50-68)
tests/gemm/test_mm_bf16.py (3)
flashinfer/autotuner.py (1)
  • autotune (251-262)
flashinfer/gemm/gemm_base.py (1)
  • mm_bf16 (183-256)
flashinfer/utils.py (2)
  • get_compute_capability (253-256)
  • is_compute_capability_supported (1020-1035)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (2)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
  • flashinfer (41-125)
  • gemm (42-91)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
  • flashinfer (26-60)
  • gemm (27-59)
include/flashinfer/gemm/bf16_gemm_cutlass.h (3)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (2)
  • flashinfer (41-125)
  • gemm (42-91)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
  • gemm (44-181)
flashinfer/gemm/gemm_base.py (1)
  • CutlassBf16GemmRunner (518-541)
flashinfer/jit/gemm/__init__.py (1)
flashinfer/jit/gemm/core.py (1)
  • gen_gemm_sm100_module_cutlass_bf16 (193-236)
csrc/bf16_gemm_cutlass.cu (4)
flashinfer/gemm/gemm_base.py (1)
  • CutlassBf16GemmRunner (518-541)
include/flashinfer/gemm/bf16_gemm_cutlass.h (1)
  • CutlassBf16GemmRunnerInterface (29-41)
include/flashinfer/gemm/cutlass_gemm_configs.h (1)
  • CutlassTileConfigSM100 (106-425)
csrc/tvm_ffi_utils.h (2)
  • get_stream (272-274)
  • encode_dlpack_dtype (29-31)
flashinfer/gemm/gemm_base.py (8)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
  • gemm (44-181)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)
  • gemm (42-91)
include/flashinfer/gemm/bf16_gemm_cutlass.h (1)
  • gemm (27-59)
flashinfer/jit/gemm/core.py (1)
  • gen_gemm_sm100_module_cutlass_bf16 (193-236)
flashinfer/utils.py (2)
  • supported_compute_capability (814-894)
  • _get_cache_buf (206-212)
flashinfer/autotuner.py (6)
  • TunableRunner (194-247)
  • OptimizationProfile (168-183)
  • AutoTuner (335-786)
  • TuningConfig (101-141)
  • DynamicTensorSpec (41-82)
  • choose_one (400-529)
csrc/bf16_gemm_cutlass.cu (4)
  • bf16_gemm_tactic_num (149-156)
  • bf16_gemm_tactic_num (149-149)
  • bf16_gemm (144-147)
  • bf16_gemm (144-145)
flashinfer/fused_moe/utils.py (2)
  • get_last_power_of_2_num_tokens_buckets (206-215)
  • last_positive_power_of_2 (183-188)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (3)
include/flashinfer/gemm/bf16_gemm_cutlass.h (2)
  • flashinfer (26-60)
  • gemm (27-59)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (7)
  • gemm (44-181)
  • _1SM (53-57)
  • _2SM (60-64)
  • cutlass (135-135)
  • cutlass (136-136)
  • cutlass (137-137)
  • cutlass (138-138)
include/flashinfer/gemm/cutlass_gemm_configs.h (1)
  • CutlassTileConfigSM100 (106-425)
🪛 Clang (14.0.6)
include/flashinfer/gemm/bf16_gemm_template_sm100.h

[error] 23-23: 'cutlass/arch/arch.h' file not found

(clang-diagnostic-error)

include/flashinfer/gemm/bf16_gemm_cutlass.h

[error] 20-20: 'cuda_runtime_api.h' file not found

(clang-diagnostic-error)

include/flashinfer/gemm/bf16_gemm_cutlass_template.h

[error] 23-23: 'cutlass/arch/arch.h' file not found

(clang-diagnostic-error)

🪛 Ruff (0.14.5)
flashinfer/jit/gemm/core.py

232-232: Consider [*nvcc_flags, "-DENABLE_BF16"] instead of concatenation

Replace with [*nvcc_flags, "-DENABLE_BF16"]

(RUF005)

flashinfer/gemm/gemm_base.py

228-228: Avoid specifying long messages outside the exception class

(TRY003)


230-230: Avoid specifying long messages outside the exception class

(TRY003)


240-242: Avoid specifying long messages outside the exception class

(TRY003)


244-246: Avoid specifying long messages outside the exception class

(TRY003)


248-250: Avoid specifying long messages outside the exception class

(TRY003)


305-305: Avoid specifying long messages outside the exception class

(TRY003)


307-307: Avoid specifying long messages outside the exception class

(TRY003)


318-320: Avoid specifying long messages outside the exception class

(TRY003)


322-324: Avoid specifying long messages outside the exception class

(TRY003)


326-328: Avoid specifying long messages outside the exception class

(TRY003)


521-521: Unused method argument: inputs

(ARG002)


522-522: Unused method argument: profile

(ARG002)


530-530: Unused method argument: do_preparation

(ARG002)


531-531: Unused method argument: kwargs

(ARG002)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (3)
flashinfer/jit/gemm/__init__.py (1)

17-38: BF16 SM100 JIT generator correctly exported

Importing gen_gemm_sm100_module_cutlass_bf16 and adding it to __all__ is consistent with the existing FP4/FP8 exports and makes the BF16 generator discoverable where expected.

include/flashinfer/gemm/bf16_gemm_cutlass.h (1)

29-57: BF16 GEMM runner interface is well-shaped and matches usage

Interface methods and the templated CutlassBf16GemmRunner align with the CUDA implementation (gemm, workspace sizing, config enumeration). Header-only surface looks good; no changes needed here.

flashinfer/gemm/gemm_base.py (1)

182-257: BF16 SM100 Python path is coherent; layout and autotuning wiring look good

  • mm_bf16 / bmm_bf16:

    • API contracts (A: (m, k)/(b, m, k), B: (k, n)/(b, k, n), out: (m, n)/(b, m, n), bf16/fp16) are clear and align with the docstring examples.
    • Output allocation and validation mirror existing GEMM helpers and look correct.
    • Relying on the C++ layer for dtype/shape checks is consistent with your earlier stance and avoids duplication.
  • get_gemm_sm100_module_cutlass_bf16 and runner:

    • JIT loading via gen_gemm_sm100_module_cutlass_bf16().build_and_load() matches the established FP8/FP4 pattern.
    • The runner’s get_valid_tactics returns range(module.bf16_gemm_tactic_num()), which fits the FFI config index logic.
    • Transposing b inside the runner before calling module.bf16_gemm correctly converts user‑visible (k, n)/(b, k, n) row‑major weights into the (n, k)/(b, n, k) layout expected by the CUTLASS kernels, and also resolves the earlier “mat2 must be contiguous” issue when callers follow the documented transpose pattern.
  • bf16_gemm_sm100:

    • SM gating via _match_sm_version(a.device, ["100"]) and the supported_compute_capability([100]) decorator on the public APIs tie this path cleanly to B200‑class hardware.
    • AutoTuner integration (dynamic spec on the a tensor’s -2 dim and a matching constraint on out) is parallel to the FP8 GEMM tuning setup and should behave well with varying M.

Net: the BF16 GEMM SM100 path is wired consistently from public Python API through autotuning into the CUTLASS backend. Any rough edges left (e.g., generic “No suitable runners found” assertion on non‑SM100 devices) are minor and acceptable for an SM100‑only feature.

Also applies to: 259-334, 513-547, 599-637
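Concretely, a minimal usage sketch of the documented contract; shapes are illustrative, the autotune() context manager is assumed to be usable without arguments as in the new tests, and an SM100 GPU is required:

import torch
from flashinfer import mm_bf16
from flashinfer.autotuner import autotune

m, n, k = 64, 4096, 4096
a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)  # row-major (m, k)
b = torch.randn(k, n, device="cuda", dtype=torch.bfloat16)  # row-major (k, n); the runner transposes internally
out = torch.empty(m, n, device="cuda", dtype=torch.bfloat16)

with autotune():  # optional: profiles the available CUTLASS tactics for this M bucket
    mm_bf16(a, b, out=out, out_dtype=torch.bfloat16)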

@raayandhar raayandhar requested a review from bkryu as a code owner November 27, 2025 04:23
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🧹 Nitpick comments (1)
flashinfer/gemm/gemm_base.py (1)

603-641: Consider adding support for SM103/SM110 and improving error handling.

Two suggestions for improving robustness:

  1. Limited SM version support: The function only supports SM100 (line 610), while the comparable fp8_gemm_sm100 supports multiple SM versions via the runner_names parameter. Consider whether BF16 GEMM should also support SM103 and SM110 for consistency.

  2. Improve error handling: Line 612 uses assert runners which produces a cryptic AssertionError if no suitable runners are found. Replace with a user-friendly error message similar to the pattern used elsewhere in the codebase.

Apply this diff to improve the error message:

     if _match_sm_version(a.device, ["100"]):
         runners.append(get_gemm_sm100_module_cutlass_bf16().cutlass_bf16_gemm_runner())
-    assert runners, "No suitable runners found"
+    if len(runners) == 0:
+        major, minor = get_compute_capability(a.device)
+        raise ValueError(f"No valid runner found for current device sm{major}{minor}")
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a56d74b and db00e51.

📒 Files selected for processing (1)
  • flashinfer/gemm/gemm_base.py (4 hunks)
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • flashinfer/gemm/gemm_base.py
🧬 Code graph analysis (1)
flashinfer/gemm/gemm_base.py (5)
include/flashinfer/gemm/bf16_gemm_template_sm100.h (1)
  • gemm (44-181)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)
  • gemm (42-91)
include/flashinfer/gemm/bf16_gemm_cutlass.h (1)
  • gemm (27-59)
flashinfer/jit/gemm/core.py (1)
  • gen_gemm_sm100_module_cutlass_bf16 (193-236)
flashinfer/utils.py (1)
  • supported_compute_capability (814-894)
🪛 Ruff (0.14.6)
flashinfer/gemm/gemm_base.py

230-230: Avoid specifying long messages outside the exception class

(TRY003)


232-232: Avoid specifying long messages outside the exception class

(TRY003)


242-244: Avoid specifying long messages outside the exception class

(TRY003)


246-248: Avoid specifying long messages outside the exception class

(TRY003)


250-252: Avoid specifying long messages outside the exception class

(TRY003)


307-307: Avoid specifying long messages outside the exception class

(TRY003)


309-309: Avoid specifying long messages outside the exception class

(TRY003)


320-322: Avoid specifying long messages outside the exception class

(TRY003)


324-326: Avoid specifying long messages outside the exception class

(TRY003)


328-330: Avoid specifying long messages outside the exception class

(TRY003)


523-523: Unused method argument: inputs

(ARG002)


524-524: Unused method argument: profile

(ARG002)


532-532: Unused method argument: do_preparation

(ARG002)


533-533: Unused method argument: kwargs

(ARG002)

🔇 Additional comments (2)
flashinfer/gemm/gemm_base.py (2)

55-55: LGTM!

The import follows the existing naming convention and is correctly placed with other JIT GEMM module imports.


515-549: Unused parameter warnings are false positives.

The static analysis tool flags inputs, profile, do_preparation, and kwargs as unused in the CutlassBf16GemmRunner class (lines 523-524, 532-533). These parameters are required by the TunableRunner interface and cannot be removed.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between db00e51 and 323e6fa.

📒 Files selected for processing (1)
  • flashinfer/gemm/gemm_base.py (4 hunks)
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • flashinfer/gemm/gemm_base.py
🧬 Code graph analysis (1)
flashinfer/gemm/gemm_base.py (3)
include/flashinfer/gemm/bf16_gemm_cutlass_template.h (1)
  • gemm (42-91)
flashinfer/utils.py (1)
  • _get_cache_buf (206-212)
csrc/bf16_gemm_cutlass.cu (4)
  • bf16_gemm_tactic_num (149-156)
  • bf16_gemm_tactic_num (149-149)
  • bf16_gemm (144-147)
  • bf16_gemm (144-145)
🪛 Ruff (0.14.6)
flashinfer/gemm/gemm_base.py

230-230: Avoid specifying long messages outside the exception class

(TRY003)


232-232: Avoid specifying long messages outside the exception class

(TRY003)


242-244: Avoid specifying long messages outside the exception class

(TRY003)


246-248: Avoid specifying long messages outside the exception class

(TRY003)


250-252: Avoid specifying long messages outside the exception class

(TRY003)


307-307: Avoid specifying long messages outside the exception class

(TRY003)


309-309: Avoid specifying long messages outside the exception class

(TRY003)


320-322: Avoid specifying long messages outside the exception class

(TRY003)


324-326: Avoid specifying long messages outside the exception class

(TRY003)


328-330: Avoid specifying long messages outside the exception class

(TRY003)


523-523: Unused method argument: inputs

(ARG002)


524-524: Unused method argument: profile

(ARG002)


532-532: Unused method argument: do_preparation

(ARG002)


533-533: Unused method argument: kwargs

(ARG002)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (2)
flashinfer/gemm/gemm_base.py (2)

515-549: Clarify resolution of transpose contiguity concern.

Earlier reviews flagged that b.transpose(-2, -1) at line 538 creates a non-contiguous view, which reportedly triggered the runtime error "mat2 must be contiguous" you mentioned. You noted that calling .contiguous() caused a SEGFAULT.

The code still passes the non-contiguous transposed view directly, yet you report that local tests pass (32 bmm_bf16 + 90 mm_bf16). This suggests either:

  1. The C++ binding was updated to handle non-contiguous tensors, or
  2. The kernel correctly interprets the strided layout from .transpose(-2, -1), or
  3. The test cases don't trigger the failure path.

Since the FP8 implementation (line 498) uses an identical pattern, this may be the intended approach. However, for future maintainability, could you briefly confirm:

  • Was the C++ side modified to accept non-contiguous mat2?
  • Or does the CUTLASS kernel natively handle the strided transpose layout?

This will help document the resolution and prevent confusion in future reviews.

Note: The unused arguments do_preparation and kwargs (lines 532-533) are part of the TunableRunner interface, so the static analysis warnings can be ignored.
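A small sketch of why the transposed view trips torch-level contiguity checks while still describing a valid column-major layout (illustrative only; it does not settle which of the three explanations above applies):

import torch

b = torch.randn(128, 256, dtype=torch.bfloat16)  # row-major (k, n)
bt = b.transpose(-2, -1)                         # view with shape (n, k), no copy

print(b.is_contiguous(), bt.is_contiguous())     # True False
print(b.stride(), bt.stride())                   # (256, 1) (1, 256)
# bt is exactly a column-major (n, k) matrix over the same storage, so a kernel
# that honors strides (rather than requiring torch contiguity) can consume it
# without a .contiguous() copy.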


603-641: LGTM! Follows established GEMM dispatcher patterns.

The implementation correctly:

  • Validates SM100 support via _match_sm_version
  • Initializes the BF16 runner
  • Configures autotuning with dynamic tensor specs for the M dimension
  • Uses constraint specs to keep output shape consistent with input
  • Integrates with the existing AutoTuner infrastructure

The structure mirrors fp8_gemm_sm100 (lines 571-600), ensuring consistency across GEMM implementations.
