
Add cute dsl mla decode op#2743

Merged
bkryu merged 35 commits into flashinfer-ai:main from limin2021:add_cute_dsl_mla_new
Mar 26, 2026
Conversation

@limin2021
Contributor

@limin2021 limin2021 commented Mar 10, 2026

📌 Description

Integrate NVIDIA's CuTe DSL MLA (Multi-Head Latent Attention) decode kernels for Blackwell SM100 into FlashInfer, supporting both BF16/FP16 and FP8 dtypes.

  • Add CuTe DSL MLA decode kernel files (mla_helpers.py, mla_decode_fp16.py, mla_decode_fp8.py) and compilation wrapper (mla_decode.py)
  • Accept tensors from PyTorch and use zero-cost cute.make_tensor layout reinterpretation inside kernel call, eliminating ~10 us of Python-side .permute() overhead per call
  • Compile with --enable-tvm-ffi for AOT caching via compile_and_cache_cute_dsl_kernel
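The "zero-cost" reinterpretation works because a permute only rewrites stride metadata and never touches the data. A minimal, framework-free sketch of the stride arithmetic (helper names are illustrative, not FlashInfer's API):

```python
# Row-major strides for a given shape (innermost dim has stride 1).
def row_major_strides(shape):
    strides, acc = [], 1
    for dim in reversed(shape):
        strides.append(acc)
        acc *= dim
    return tuple(reversed(strides))

# A permuted view keeps the data in place and just reorders the strides.
def permuted_strides(shape, perm):
    base = row_major_strides(shape)
    return tuple(base[p] for p in perm)

# [B, H, D] row-major, reinterpreted as [H, B, D] with zero data movement:
print(row_major_strides((4, 128, 576)))            # (73728, 576, 1)
print(permuted_strides((4, 128, 576), (1, 0, 2)))  # (576, 73728, 1)
```

This is the same bookkeeping cute.make_tensor/cute.make_layout do inside the kernel call, which is why the Python-side .permute() can be dropped.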

Test plan

  • python -m pytest tests/attention/test_cute_dsl_mla_decode.py — 18 tests passing (FP16 + FP8, various batch/head/seq_len configs)
  • Standalone run functions in mla_decode_fp16.py and mla_decode_fp8.py pass
  • pre-commit run --all-files passes

Functionality:

All 396 configs PASSED with 0 failures.

Test matrix:

  • dtype: bfloat16, float8_e4m3fn
  • page_size: 32, 64
  • batch_size: 1, 2, 4, 16, 32, 64, 128, 256, 512, 768, 1024
  • seq_len: 1024, 4096, 8192
  • q_len_per_request: 1, 2, 4

Status   Count
------   -----
PASSED     396
FAILED       0
TOTAL      396
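As a sanity check, the 396 figure is exactly the product of the matrix dimensions listed above:

```python
# 2 dtypes x 2 page_sizes x 11 batch_sizes x 3 seq_lens x 3 q_lens_per_request
dtypes, page_sizes = 2, 2
batch_sizes, seq_lens, q_lens = 11, 3, 3
total = dtypes * page_sizes * batch_sizes * seq_lens * q_lens
print(total)  # 396
```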

Performance:

  • GPU: NVIDIA Blackwell (SM100a)
  • Model config: DeepSeek-V3 MLA (128 heads, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64)
  • KV seq_len: 8192, page_size: 32
  • Timing: CUPTI, CUDA graph enabled, cold L2 cache, 30 iterations median
  • Date: 2026-03-11

1. FP8 fixed-len (is_var_seq=False → persistent)

Config                                             trtllm-gen (ms)   cute-dsl (ms)    Speedup
---------------------------------------------------------------------------------------------
B=1, q=1, s=8192, ps=32, fp8                                0.0159          0.0166      0.96x
B=32, q=1, s=8192, ps=32, fp8                               0.0522          0.0513      1.02x
B=64, q=1, s=8192, ps=32, fp8                               0.0771          0.0704      1.10x
B=128, q=1, s=8192, ps=32, fp8                              0.1430          0.1336      1.07x
B=256, q=1, s=8192, ps=32, fp8                              0.2825          0.2681      1.05x
B=1, q=4, s=8192, ps=32, fp8                                0.0192          0.0185      1.04x
B=32, q=4, s=8192, ps=32, fp8                               0.1307          0.1214      1.08x
B=64, q=4, s=8192, ps=32, fp8                               0.2612          0.2441      1.07x
B=128, q=4, s=8192, ps=32, fp8                              0.4840          0.4533      1.07x
B=256, q=4, s=8192, ps=32, fp8                              0.9927          0.9359      1.06x

2. FP8 var-seqlen (is_var_seq=True → non-persistent)

Config                                             trtllm-gen (ms)   cute-dsl (ms)    Speedup
---------------------------------------------------------------------------------------------
B=1, q=1, s=8192, ps=32, fp8                                0.0159          0.0164      0.97x
B=32, q=1, s=8192, ps=32, fp8                               0.0463          0.0468      0.99x
B=64, q=1, s=8192, ps=32, fp8                               0.0704          0.0640      1.10x
B=128, q=1, s=8192, ps=32, fp8                              0.1264          0.1020      1.24x
B=256, q=1, s=8192, ps=32, fp8                              0.1873          0.1698      1.10x
B=1, q=4, s=8192, ps=32, fp8                                0.0192          0.0184      1.05x
B=32, q=4, s=8192, ps=32, fp8                               0.1181          0.1037      1.14x
B=64, q=4, s=8192, ps=32, fp8                               0.1851          0.1637      1.13x
B=128, q=4, s=8192, ps=32, fp8                              0.3040          0.2930      1.04x
B=256, q=4, s=8192, ps=32, fp8                              0.5964          0.6038      0.99x

3. BF16 fixed-len (is_var_seq=False → persistent)

Config                                             trtllm-gen (ms)   cute-dsl (ms)    Speedup
---------------------------------------------------------------------------------------------
B=1, q=1, s=8192, ps=32, bf16                               0.0241          0.0185      1.30x
B=32, q=1, s=8192, ps=32, bf16                              0.0824          0.0844      0.98x
B=64, q=1, s=8192, ps=32, bf16                              0.1351          0.1283      1.05x
B=128, q=1, s=8192, ps=32, bf16                             0.2566          0.2441      1.05x
B=256, q=1, s=8192, ps=32, bf16                             0.5106          0.4971      1.03x
B=1, q=4, s=8192, ps=32, bf16                               0.0227          0.0224      1.02x
B=32, q=4, s=8192, ps=32, bf16                              0.2136          0.2096      1.02x
B=64, q=4, s=8192, ps=32, bf16                              0.4284          0.4347      0.99x
B=128, q=4, s=8192, ps=32, bf16                             0.7891          0.8124      0.97x
B=256, q=4, s=8192, ps=32, bf16                             1.6007          1.7218      0.93x

4. BF16 var-seqlen (is_var_seq=True → non-persistent)

Config                                             trtllm-gen (ms)   cute-dsl (ms)    Speedup
---------------------------------------------------------------------------------------------
B=1, q=1, s=8192, ps=32, bf16                               0.0241          0.0184      1.31x
B=32, q=1, s=8192, ps=32, bf16                              0.0746          0.0764      0.98x
B=64, q=1, s=8192, ps=32, bf16                              0.1210          0.1126      1.07x
B=128, q=1, s=8192, ps=32, bf16                             0.2196          0.1805      1.22x
B=256, q=1, s=8192, ps=32, bf16                             0.3392          0.3090      1.10x
B=1, q=4, s=8192, ps=32, bf16                               0.0220          0.0214      1.03x
B=32, q=4, s=8192, ps=32, bf16                              0.1841          0.1743      1.06x
B=64, q=4, s=8192, ps=32, bf16                              0.2929          0.2961      0.99x
B=128, q=4, s=8192, ps=32, bf16                             0.5050          0.5384      0.94x
B=256, q=4, s=8192, ps=32, bf16                             1.0073          1.0903      0.92x
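For reference, the Speedup column in the four tables above appears to be the simple ratio of the two medians (an assumption; the rounding matches the rows shown):

```python
def speedup(trtllm_ms: float, cute_ms: float) -> float:
    # Values above 1.0 mean the cute-dsl kernel is faster than trtllm-gen.
    return round(trtllm_ms / cute_ms, 2)

# B=1, q=1 row of the BF16 var-seqlen table:
print(speedup(0.0241, 0.0184))  # 1.31
```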

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added "cute-dsl" backend for MLA decoding with FP16/BF16/FP8 support and a PyTorch-compatible MLA decode API.
    • Added MLA static-tile scheduling helpers and deterministic kernel/workspace handling.
  • Public API

    • Exposed the new decode API and a torch→Cutlass dtype mapping in the public interface.
  • Tests

    • End-to-end tests for FP16/BF16/FP8, variable sequence lengths, and API path.
  • Benchmarking

    • Benchmarks updated to select and exercise the new backend.

limin2021 and others added 7 commits March 10, 2026 00:35
Integrates NVIDIA's CuTe DSL MLA decode kernels (FP16/FP8) for
Blackwell SM100 as a new "cute-dsl" backend in
trtllm_batch_decode_with_kv_cache_mla().

Key tensor layout insights documented in mla_decode.py:
- c_latent/c_rope kernel layout is [page_size, D, total_pages], not
  [total_tokens, D, 1] — the kernel indexes KV intra-page per physical page
- All fake tensor dimensions must be cute.sym_int() (not static Python ints)
  so cute.assume() receives CuTe Integer types in initialize_workspace()
- lse fake tensor needs stride_order=(0,1,2) for stride[0]=1 compile-time constant
- Do NOT call .contiguous() after .permute() on q/lse/o tensors — it
  collapses to row-major, destroying required non-standard strides
- Separate sym_kv_batch for KV cache (=1, flat pool) vs query batch (=B)

New files:
- flashinfer/cute_dsl/mla_helpers.py
- flashinfer/cute_dsl/mla_decode_fp16.py
- flashinfer/cute_dsl/mla_decode_fp8.py
- flashinfer/cute_dsl/mla_decode.py (compilation wrapper + public API)
- tests/attention/test_cute_dsl_mla_decode.py (14 tests, all passing)
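The "do NOT call .contiguous() after .permute()" insight in the commit message above can be seen with plain stride arithmetic (a hedged sketch; the shapes are illustrative DeepSeek-like values, not the kernel's exact tensors):

```python
def row_major_strides(shape):
    strides, acc = [], 1
    for dim in reversed(shape):
        strides.append(acc)
        acc *= dim
    return tuple(reversed(strides))

shape = (4, 128, 576)  # [B, H, D], row-major in memory
perm = (1, 0, 2)       # kernel expects [H, B, D]
view_shape = tuple(shape[p] for p in perm)
view_strides = tuple(row_major_strides(shape)[p] for p in perm)

# The permuted view keeps non-standard strides the kernel relies on;
# .contiguous() would copy the data and collapse them to row-major.
print(view_strides)                   # (576, 73728, 1) — preserved by permute
print(row_major_strides(view_shape))  # (2304, 576, 1)  — after .contiguous()
```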
- Remove unnecessary .contiguous() on page_table transpose by changing
  fake tensor stride_order from (1,0) to (0,1), matching the original
  kernel's convention of non-contiguous permute(1,0)
- Use torch.full instead of torch.ones * val for block_split_kvs
- Remove redundant .contiguous() on workspace buffer slice
- Remove redundant .to(int32).contiguous() when seq_lens is already int32
- Eliminate output copy_ by writing kernel output directly into caller's
  out tensor via permute view (works for both q_len=1 and q_len>1)
- Fix output allocation order from (B,H,q_len,D) to (B,q_len,H,D) so
  permute back to user layout is naturally contiguous, removing .contiguous()
- Cache split_kv and workspace_size computation via functools.cache
- Remove tensor_api closure wrapper, call compiled_kernel directly
- Add host overhead benchmark script

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…e permutes

Accept contiguous row-major tensors and reinterpret layouts inside the
kernel's __call__ via zero-cost cute.make_tensor + cute.make_layout,
removing ~10 us of Python-side .permute() overhead per call.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…dsl_mla_decode

- Expose is_var_split_kv as a public parameter (default False) to control
  whether to use per-batch variable split_kv or uniform scalar split_kv,
  avoiding a torch.full GPU kernel (~5 us) when not needed.
- Add workspace_buffer size assertion to catch undersized buffers early.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
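The uniform-vs-per-batch split_kv trade-off described in that commit can be sketched as follows (illustrative only; the real dispatch lives in the kernel wrapper):

```python
def make_split_kv(batch_size, split_kv, is_var_split_kv):
    # Per-batch mode materializes a length-B array (on GPU that is a
    # torch.full launch, ~5 us per the commit message); uniform mode is
    # a plain scalar with no kernel launch at all.
    if is_var_split_kv:
        return [split_kv] * batch_size
    return split_kv

print(make_split_kv(4, 2, True))   # [2, 2, 2, 2]
print(make_split_kv(4, 2, False))  # 2
```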
@coderabbitai
Contributor

coderabbitai bot commented Mar 10, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds a CuTe-DSL MLA decode implementation, static tile-scheduler helpers, dtype utilities, tests, and benchmarking/back-end wiring; exposes cute_dsl_mla_decode and torch_to_cutlass_dtype in the CuTe-DSL package and registers a "cute-dsl" MLA backend. (50 words)

Changes

  • Core MLA Implementation (flashinfer/cute_dsl/mla_decode.py, flashinfer/cute_dsl/mla_helpers.py):
    New CuTe-DSL MLA decode kernel wrapper and static tile scheduler: input validation, 3D/4D kv_cache normalization, per-batch split_kv/workspace computation and caching, deterministic kernel compilation/caching, workspace management, persistent/non-persistent tile scheduling, and the PyTorch-compatible API cute_dsl_mla_decode.
  • Public API Exports (flashinfer/cute_dsl/__init__.py):
    Expose cute_dsl_mla_decode (conditional on CuTe-DSL availability) and export torch_to_cutlass_dtype in __all__.
  • Utilities (flashinfer/cute_dsl/utils.py):
    Add torch_to_cutlass_dtype(dtype: torch.dtype) -> cutlass.dtype mapping with validation and TypeError on unsupported dtypes.
  • Integration & Backends (flashinfer/mla.py):
    Add a "cute-dsl" backend to trtllm_batch_decode_with_kv_cache_mla(), delegate to cute_dsl_mla_decode(), and handle tensor/float bmm scale inputs.
  • Tests (tests/attention/test_cute_dsl_mla_decode.py):
    New end-to-end tests and a PyTorch reference implementation covering FP16/BF16/FP8, variable sequence lengths, and API integration (environment-gated); includes deterministic cases and tolerance-aware comparisons.
  • Benchmarking (benchmarks/bench_trtllm_gen_mla.py):
    Add an optional backend argument and CLI flag, backend-aware q_len choices, pass the backend through to decode calls, and per-iteration error handling/logging.
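The torch_to_cutlass_dtype utility summarized above can be sketched with string stand-ins (the real function maps torch.dtype objects to cutlass dtypes; this version is illustrative only):

```python
# Hypothetical string-keyed mapping standing in for torch/cutlass dtypes.
_DTYPE_MAP = {
    "torch.float16": "Float16",
    "torch.bfloat16": "BFloat16",
    "torch.float8_e4m3fn": "Float8E4M3FN",
}

def torch_to_cutlass_dtype(dtype: str) -> str:
    # Mirror the described contract: TypeError on unsupported dtypes.
    try:
        return _DTYPE_MAP[dtype]
    except KeyError:
        raise TypeError(f"unsupported dtype for cute-dsl MLA decode: {dtype}") from None

print(torch_to_cutlass_dtype("torch.bfloat16"))  # BFloat16
```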

Sequence Diagram(s)

sequenceDiagram
    participant Client as PyTorch Caller
    participant API as cute_dsl_mla_decode
    participant Cache as Param/Kernel Cache
    participant Compiler as CUTLASS Compiler
    participant WS as Workspace Manager
    participant Kernel as CUDA/CuTe Kernel

    Client->>API: call(query, kv_cache, params...)
    API->>API: validate & normalize inputs
    API->>Cache: lookup split_kv & workspace_size for batch config
    alt cached
        Cache-->>API: return params
    else
        API->>Compiler: compute split_kv & workspace_size
        Compiler-->>Cache: store params
        Cache-->>API: return params
    end
    API->>Cache: lookup compiled kernel for config
    alt cached
        Cache-->>API: return kernel
    else
        API->>Compiler: compile kernel (symbolic shapes)
        Compiler-->>Cache: store compiled kernel
        Cache-->>API: return kernel
    end
    API->>WS: allocate workspace
    WS-->>API: workspace ptr
    API->>Kernel: launch with tensors & workspace
    Kernel-->>API: produce output
    API-->>Client: return formatted output

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested labels

cute-dsl, run-ci

Suggested reviewers

  • bkryu
  • aleozlx
  • nvmbreughe
  • jimmyzho
  • nv-yunzheq
  • yzh119
  • Anerudhan

Poem

🐰 I hopped through tiles both big and small,
kernels stitched to heed the call,
caches hum and workspaces bloom,
CuTe decoding clears the room.
Hop—compile—launch—replay, I prance and stall.

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning — Docstring coverage is 30.43%, below the required 80.00% threshold. Resolution: write docstrings for the functions that are missing them.

✅ Passed checks (2 passed)

  • Title check — Passed. The title "Add cute dsl mla decode op" accurately summarizes the main change: integrating CuTe DSL MLA decode kernels into FlashInfer.
  • Description check — Passed. The PR description is comprehensive and well-structured, covering purpose, implementation details, test results, and performance benchmarks across multiple configurations.


@limin2021
Contributor Author

@flashinfer-bot run

@limin2021
Contributor Author

/bot run

@flashinfer-bot
Collaborator

@limin2021 is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

🧹 Nitpick comments (1)
flashinfer/cute_dsl/mla_decode.py (1)

45-54: Don’t key the compile cache on unused dynamic dims.

num_heads and seq_len_q never feed KernelClass(...), and the fake tensors already model both dimensions with cute.sym_int(). Keeping them in a @functools.cache key recompiles the same kernel for every (H, q_len) pair.

🔧 Proposed cleanup
 def _get_compiled_mla_kernel(
     is_fp8: bool,
     page_size: int,
-    num_heads: int,
-    seq_len_q: int,
     is_persistent: bool,
     is_var_seq: bool,
     is_var_split_kv: bool,
 ) -> Tuple[Callable, object]:
@@
     tensor_api, kernel_cls = _get_compiled_mla_kernel(
         is_fp8=is_fp8,
         page_size=page_size,
-        num_heads=H,
-        seq_len_q=q_len,
         is_persistent=is_persistent,
         is_var_seq=is_var_seq,
         is_var_split_kv=is_var_split_kv,
     )

Also applies to: 85-93, 360-368
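The over-keying the reviewer describes is easy to reproduce with a toy @functools.cache (a stand-in for the expensive cute.compile; the counter shows the redundant recompiles):

```python
import functools

compile_count = 0

@functools.cache
def get_kernel(is_fp8, page_size, num_heads, seq_len_q):
    # Stand-in for an expensive compile step. Since num_heads and
    # seq_len_q are modeled symbolically inside the kernel, including
    # them in the signature only fragments the cache.
    global compile_count
    compile_count += 1
    return object()

get_kernel(False, 32, 128, 1)
get_kernel(False, 32, 128, 4)  # identical kernel, but a fresh cache entry
print(compile_count)  # 2 — would be 1 with num_heads/seq_len_q removed
```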

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/cute_dsl/mla_decode.py`:
- Around line 325-333: The slice can produce an undersized tensor silently;
before slicing workspace_buffer validate it is on the correct device, has dtype
torch.uint8, and has at least max(workspace_size, 1) bytes; if not, raise a
clear error (or allocate/resize) so the kernel never receives undersized scratch
memory. Locate calls around
BlackwellMultiHeadLatentAttentionForwardFP16.get_workspace_size(...) and the
variables workspace_buffer and workspace_bytes, check device equality
(workspace_buffer.device vs expected device from inputs), verify
workspace_buffer.dtype is torch.uint8, and assert workspace_buffer.numel() >=
max(workspace_size, 1) before performing the workspace_buffer[:workspace_size]
slice. Ensure error messages reference workspace_size and actual buffer size for
easier debugging.

In `@flashinfer/cute_dsl/mla_helpers.py`:
- Around line 264-287: The MLIR serialization round-trip omits the scheduler's
is_valid state so non-persistent schedulers become valid after deserialization;
update __extract_mlir_values__ to include self.is_valid (e.g., append a boolean
representation) and update __new_from_mlir_values__ to consume and restore that
boolean into the new object's is_valid, adjusting the assert(len(values)) and
the slicing offsets accordingly; ensure advance_to_next_work (which sets
is_valid=False for non-persistent schedulers) continues to work with the new
serialized field so deserialized instances preserve exhausted/non-exhausted
state.

In `@flashinfer/mla.py`:
- Around line 771-790: In the backend == "cute-dsl" branch in mla.py you must
reject unsupported knobs instead of silently dropping them: before calling
cute_dsl_mla_decode, validate the parameters sparse_mla_top_k, sinks, and
skip_softmax_threshold_scale_factor and raise a clear error (e.g., ValueError or
NotImplementedError) if any are set to non-default values; keep the existing
call to cute_dsl_mla_decode (passing query, kv_cache, workspace_buffer,
kv_lora_rank, qk_rope_head_dim, block_tables, seq_lens, max_seq_len,
softmax_scale, output_scale, out) but ensure unsupported options are checked and
rejected with a descriptive message referencing those option names.
- Around line 774-788: The cute-dsl path is receiving a wrongly
scaled/statically-cast value because bmm1_scale and bmm2_scale are converted to
Python floats (via .item()) after bmm1_scale was multiplied by log2e; fix by
passing tensor values (not .item()) to cute_dsl_mla_decode so
CUDA-graph/dynamic-tensor semantics are preserved and the backend sees the
correct numeric scale: for softmax_scale pass the original unmultiplied tensor
(or divide the current bmm1_scale by log2e) as a tensor rather than float, and
likewise stop calling .item()/float() for output_scale (bmm2_scale) — update the
cute_dsl_mla_decode call to use tensor softmax_scale=bmm1_scale_tensor and
output_scale=bmm2_scale_tensor (referencing cute_dsl_mla_decode, bmm1_scale,
bmm2_scale).
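The log2e concern comes from exp2-based softmax: such kernels compute exp2((scale·log2 e)·x) instead of exp(scale·x), so the wrapper pre-multiplies the scale, and handing that pre-multiplied value to a backend that does its own folding double-scales it. The underlying identity, as a quick numeric check:

```python
import math

# exp2 with a log2(e)-folded scale reproduces the natural-exp softmax term.
log2e = math.log2(math.e)
scale, x = 0.125, 3.0

folded = 2.0 ** (scale * log2e * x)
direct = math.exp(scale * x)
print(abs(folded - direct) < 1e-12)  # True
```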

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 9fd8c3ec-8099-47fb-891c-f262ddfb6d3c

📥 Commits

Reviewing files that changed from the base of the PR and between fe06b91 and 3c38f20.

📒 Files selected for processing (7)
  • flashinfer/cute_dsl/__init__.py
  • flashinfer/cute_dsl/mla_decode.py
  • flashinfer/cute_dsl/mla_decode_fp16.py
  • flashinfer/cute_dsl/mla_decode_fp8.py
  • flashinfer/cute_dsl/mla_helpers.py
  • flashinfer/mla.py
  • tests/attention/test_cute_dsl_mla_decode.py

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances FlashInfer's capabilities by incorporating highly optimized CuTe DSL MLA decode kernels for Blackwell SM100 architectures. The integration focuses on improving efficiency and performance for multi-head latent attention operations, particularly by streamlining data handling and supporting mixed-precision computations. This update provides a specialized backend that leverages advanced GPU features, leading to faster and more efficient inference for large language models.

Highlights

  • CuTe DSL MLA Decode Kernel Integration: Integrated NVIDIA's CuTe DSL Multi-Head Latent Attention (MLA) decode kernels, specifically optimized for Blackwell SM100 GPUs, into FlashInfer.
  • Data Type Support: The new kernels support both FP16 and FP8 data types, enhancing flexibility and performance for different precision requirements.
  • Performance Optimization: Implemented zero-cost layout reinterpretation for contiguous PyTorch tensors within the kernel, eliminating approximately 10 microseconds of Python-side .permute() overhead per call.
  • AOT Caching: Enabled Ahead-Of-Time (AOT) caching for kernel compilation using --enable-tvm-ffi via compile_and_cache_cute_dsl_kernel, improving kernel load times.
  • New Backend Option: Introduced a new cute-dsl backend option for MLA decoding within the trtllm_batch_decode_with_kv_cache_mla API.


Changelog
  • flashinfer/cute_dsl/__init__.py
    • Imported and exposed the new cute_dsl_mla_decode function.
  • flashinfer/cute_dsl/mla_decode.py
    • Added a new module for CuTe DSL MLA decode kernel integration, including PyTorch API and kernel compilation logic.
  • flashinfer/cute_dsl/mla_helpers.py
    • Added helper classes and functions for the MLA static tile scheduler used by CuTe DSL kernels.
  • flashinfer/mla.py
    • Updated the trtllm_batch_decode_with_kv_cache_mla function to support a new "cute-dsl" backend.
  • tests/attention/test_cute_dsl_mla_decode.py
    • Added comprehensive unit tests for the CuTe DSL MLA decode kernel, covering FP16, FP8, and variable sequence lengths.
Activity
  • The pull request includes new kernel files for CuTe DSL MLA decode (mla_helpers.py, mla_decode_fp16.py, mla_decode_fp8.py) and a compilation wrapper (mla_decode.py).
  • Extensive testing has been performed, with 18 tests passing for FP16 and FP8 across various batch, head, and sequence length configurations.
  • Standalone run functions in mla_decode_fp16.py and mla_decode_fp8.py have passed.
  • Pre-commit hooks were run and passed on all files, ensuring code quality and style compliance.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request integrates NVIDIA's CuTe DSL MLA decode kernels for Blackwell SM100 GPUs into FlashInfer, adding support for both FP16 and FP8 data types. The changes are well-structured, including new kernel wrappers, helper files, and comprehensive tests. I've identified a minor maintainability issue where the kernel class for calculating split_kv and workspace_size is hardcoded, and I've provided a suggestion to make this dynamic based on the data type. Overall, this is a solid contribution that extends FlashInfer's capabilities to new hardware.

Note: Security Review did not run due to the size of the PR.

- Add get_split_kv_simplified() that computes split_kv without max_seq_len
- Remove is_var_split_kv from public API (hardcode False), eliminating
  torch.full GPU kernel overhead per call
- Remove unused bench_cute_dsl_mla_host_overhead.py

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
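One plausible shape of the split_kv math these helpers compute — partitioning a KV sequence into roughly equal chunks — is sketched below (illustrative only; the real get_split_kv_simplified() lives in the kernel wrapper and may use different rules):

```python
import math

def split_lengths(num_kv_pages: int, split_kv: int):
    # Ceil-divide the pages into split_kv chunks; the last chunk
    # absorbs the remainder.
    chunk = math.ceil(num_kv_pages / split_kv)
    return [min(chunk, num_kv_pages - i * chunk) for i in range(split_kv)]

print(split_lengths(256, 3))  # [86, 86, 84]
```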
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

♻️ Duplicate comments (1)
flashinfer/cute_dsl/mla_decode.py (1)

304-309: ⚠️ Potential issue | 🟠 Major

Validate required_workspace before slicing.

When workspace_size == 0, an empty buffer satisfies the current assert, but workspace_buffer[: max(workspace_size, 1)] still produces a 0-length tensor. This path also still assumes same-device torch.uint8 storage even though the check is phrased in bytes.

🧯 Suggested fix
-    assert workspace_buffer.numel() >= workspace_size, (
-        f"workspace_buffer too small: {workspace_buffer.numel()} bytes, "
-        f"need {workspace_size} bytes"
-    )
-    workspace_bytes = workspace_buffer[: max(workspace_size, 1)]
+    if workspace_buffer.device != query.device:
+        raise ValueError("workspace_buffer must be on the same device as query")
+    if workspace_buffer.dtype != torch.uint8 or workspace_buffer.dim() != 1:
+        raise ValueError("workspace_buffer must be a 1-D torch.uint8 tensor")
+    required_workspace = max(workspace_size, 1)
+    if workspace_buffer.numel() < required_workspace:
+        raise ValueError(
+            f"workspace_buffer too small: need {required_workspace} bytes, got {workspace_buffer.numel()}"
+        )
+    workspace_bytes = workspace_buffer[:required_workspace]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/cute_dsl/mla_decode.py` around lines 304 - 309, Validate the
requested workspace size and buffer properties before slicing: assert that
workspace_size (the required_workspace) is non-negative and that
workspace_buffer.numel() (taking element size into account) is >=
workspace_size, and also assert workspace_buffer.dtype is torch.uint8 and on the
expected device; handle workspace_size == 0 explicitly (e.g., set
workspace_bytes = workspace_buffer[:0] or return an empty view) instead of using
max(workspace_size, 1) so you don't assume a 1-byte element, and replace the
existing slice workspace_buffer[: max(workspace_size, 1)] with a slice that
respects the validated workspace_size and the buffer's dtype/device in
functions/variables such as workspace_buffer and workspace_size.
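A framework-free restatement of the validation this comment asks for (a dataclass stub stands in for a torch tensor; all names are illustrative, not FlashInfer's actual code):

```python
from dataclasses import dataclass

@dataclass
class BufferStub:
    n_bytes: int
    dtype: str = "uint8"
    device: str = "cuda:0"

    def numel(self) -> int:
        return self.n_bytes

def check_workspace(buf: BufferStub, workspace_size: int, device: str) -> None:
    # Reject wrong-device, wrong-dtype, or undersized buffers up front so
    # the kernel never receives undersized scratch memory.
    if buf.device != device:
        raise ValueError("workspace_buffer must be on the same device as query")
    if buf.dtype != "uint8":
        raise ValueError("workspace_buffer must be a torch.uint8 tensor")
    if buf.numel() < workspace_size:
        raise ValueError(
            f"workspace_buffer too small: need {workspace_size} bytes, "
            f"got {buf.numel()}"
        )

check_workspace(BufferStub(1 << 20), 4096, "cuda:0")  # large enough: no error
print("ok")
```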
🧹 Nitpick comments (1)
tests/attention/bench_cute_dsl_mla_host_overhead.py (1)

114-133: Profile the current decode path, not the removed permute-based path.

Most of these sections still time the old Python-side permutes/transposes and the is_var_split_kv=True launch shape, while cute_dsl_mla_decode() now uses row-major inputs and is_var_split_kv=False. The section breakdown will be misleading until this helper mirrors the current wrapper or is renamed as a legacy-path profiler.

Also applies to: 149-215

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/bench_cute_dsl_mla_host_overhead.py` around lines 114 - 133,
The profiling helpers (query_reshape, kv_reshape, page_table_transpose) are
still timing the old permute/transpose path and the is_var_split_kv=True layout,
but cute_dsl_mla_decode() now expects row-major inputs with
is_var_split_kv=False; update these helpers (and the analogous blocks at
149-215) to produce and measure the same input layout and shapes used by
cute_dsl_mla_decode(), or rename them to indicate they profile the legacy
permute-based path; specifically modify query_reshape, kv_reshape, and
page_table_transpose to return row-major tensors matching the decode wrapper
(remove permute/.t() usage and use the same slicing/shape conventions as
cute_dsl_mla_decode), and ensure any measure calls use is_var_split_kv=False so
the profiler reflects the current decode path.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/cute_dsl/mla_decode.py`:
- Around line 295-296: Normalize the auxiliary index tensors before dispatch:
convert block_tables (used as page_table_k) to the expected integer dtype and
ensure it's contiguous, and likewise make page_table_fake, cache_seqs_fake and
the reused seq_lens contiguous (and cast to the wrapper-expected dtype, e.g.,
torch.int64) before passing them into the TVM/FFI call; update assignments
around page_table_k, page_table_fake, cache_seqs_fake and the seq_lens usage
(also the similar site at the block around lines 327-328) to use .to(dtype=...,
copy=False) and .contiguous() so mismatched dtypes or strided tensors are
rejected at the wrapper boundary instead of failing deep in the FFI.
- Around line 63-71: The cache for _get_compiled_mla_kernel is fragmenting on
num_heads and seq_len_q even though those axes are compiled as cute.sym_int()
and do not affect cute.compile(); remove num_heads and seq_len_q from the
`@functools.cache` function signature (and any cache keys) so the cache is keyed
only by true compile-time flags (e.g., is_fp8, page_size, is_persistent,
is_var_seq, is_var_split_kv), keep using cute.sym_int() inside the compiled
kernel to represent H and q_len as symbolic ints, and ensure callers still pass
num_heads/seq_len_q only when invoking the returned Callable rather than as
cache inputs.

In `@tests/attention/bench_cute_dsl_mla_host_overhead.py`:
- Around line 84-86: The profiler function profile_sections currently passes an
outdated fifth positional argument max_seq_len into the helper
_get_split_kv_and_workspace_size, causing a TypeError; remove max_seq_len from
the helper call(s) inside profile_sections (and the same redundant argument
usage later around the other call at lines ~137-140) so that
_get_split_kv_and_workspace_size is called with the current four parameters
only, leaving the rest of profile_sections (kv_cache, workspace_buffer,
kv_lora_rank, qk_rope_head_dim, block_tables, seq_lens, softmax_scale,
output_scale, num_iters) unchanged.
- Around line 88-96: Remove the stale unused imports causing Ruff F401 by
deleting the unused symbols from the import block: specifically drop Float32,
Int32 and cutlass (and also remove get_num_sm if it is not referenced
elsewhere). Keep only the actually used symbols such as
_get_compiled_mla_kernel, _get_split_kv_and_workspace_size, _LATENT_DIM,
_ROPE_DIM, _MMA_QK_TILER_MN, _MAX_ACTIVE_CLUSTERS, and
BlackwellMultiHeadLatentAttentionForwardFP16 so the import statement no longer
triggers unused-import lint errors.

---

Duplicate comments:
In `@flashinfer/cute_dsl/mla_decode.py`:
- Around line 304-309: Validate the requested workspace size and buffer
properties before slicing: assert that workspace_size (the required_workspace)
is non-negative and that workspace_buffer.numel() (taking element size into
account) is >= workspace_size, and also assert workspace_buffer.dtype is
torch.uint8 and on the expected device; handle workspace_size == 0 explicitly
(e.g., set workspace_bytes = workspace_buffer[:0] or return an empty view)
instead of using max(workspace_size, 1) so you don't assume a 1-byte element,
and replace the existing slice workspace_buffer[: max(workspace_size, 1)] with a
slice that respects the validated workspace_size and the buffer's dtype/device
in functions/variables such as workspace_buffer and workspace_size.
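A hedged sketch of that boundary validation, with the buffer represented only by its byte count (the real wrapper would additionally check `workspace_buffer.dtype == torch.uint8` and the device; the helper name is illustrative):

```python
def select_workspace(buffer_numel: int, workspace_size: int) -> slice:
    """Validate a requested workspace size against a uint8 buffer's length.

    Returns the slice to apply to the buffer, raising at the wrapper
    boundary instead of failing deep inside the FFI call.
    """
    if workspace_size < 0:
        raise ValueError(f"workspace_size must be non-negative, got {workspace_size}")
    if buffer_numel < workspace_size:
        raise ValueError(
            f"workspace buffer holds {buffer_numel} bytes but "
            f"{workspace_size} bytes are required"
        )
    # workspace_size == 0 becomes an explicit empty view; the previous
    # max(workspace_size, 1) trick silently assumed a 1-byte element type.
    return slice(0, workspace_size)

assert select_workspace(1024, 0) == slice(0, 0)
assert select_workspace(1024, 512) == slice(0, 512)
```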

---

Nitpick comments:
In `@tests/attention/bench_cute_dsl_mla_host_overhead.py`:
- Around line 114-133: The profiling helpers (query_reshape, kv_reshape,
page_table_transpose) are still timing the old permute/transpose path and the
is_var_split_kv=True layout, but cute_dsl_mla_decode() now expects row-major
inputs with is_var_split_kv=False; update these helpers (and the analogous
blocks at 149-215) to produce and measure the same input layout and shapes used
by cute_dsl_mla_decode(), or rename them to indicate they profile the legacy
permute-based path; specifically modify query_reshape, kv_reshape, and
page_table_transpose to return row-major tensors matching the decode wrapper
(remove permute/.t() usage and use the same slicing/shape conventions as
cute_dsl_mla_decode), and ensure any measure calls use is_var_split_kv=False so
the profiler reflects the current decode path.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 100622e9-26bc-47bd-b2ed-b5366850c5d2

📥 Commits

Reviewing files that changed from the base of the PR and between 3c38f20 and b3b0f8b.

📒 Files selected for processing (4)
  • flashinfer/cute_dsl/mla_decode.py
  • flashinfer/cute_dsl/mla_decode_fp16.py
  • flashinfer/cute_dsl/mla_decode_fp8.py
  • tests/attention/bench_cute_dsl_mla_host_overhead.py

- Add torch_to_cutlass_dtype() in utils.py for torch.dtype -> cutlass dtype conversion
- Extend mla_decode_fp16.py can_implement() to accept BFloat16
- Refactor mla_decode.py to support float16/bfloat16/float8_e4m3fn via dtype-aware dispatch
- Add BFloat16 parametrization to test_cute_dsl_mla_decode_fp16 test
- Add backend parameter to bench_trtllm_gen_mla.py benchmark
- Remove unused bench_cute_dsl_mla_host_overhead.py

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

🧹 Nitpick comments (1)
tests/attention/test_cute_dsl_mla_decode.py (1)

311-326: Strengthen the public API test beyond a shape check.

This only proves that the new flashinfer.mla backend branch returns a tensor of the expected size. Since that path is part of the new surface in this PR, please compare out against torch_reference_mla (or the direct cute_dsl_mla_decode result) so scale/dtype wiring bugs do not slip through.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_cute_dsl_mla_decode.py` around lines 311 - 326, The test
only asserts shape for the trtllm_batch_decode_with_kv_cache_mla call; instead
compute a reference result (using torch_reference_mla or calling
cute_dsl_mla_decode with the same inputs) and assert numeric equivalence:
compare out to the reference tensor with torch.allclose (or
torch.testing.assert_allclose) using sensible rtol/atol for the dtype to catch
scale/dtype wiring bugs. Ensure you use the same inputs (query, kv_cache,
workspace_buffer, block_tables, seq_lens, max_seq_len, bmm1_scale, bmm2_scale)
and keep the existing shape check if desired.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/bench_trtllm_gen_mla.py`:
- Around line 150-154: The benchmark currently excludes torch.bfloat16 for the
"cute-dsl" backend (setting dtypes in the args.backend conditional), but
cute_dsl/mla_decode.py now supports bfloat16; update the dtypes assignment so
that torch.bfloat16 is included in the dtypes list for the "cute-dsl" branch
(and revise the inline comment to remove the incorrect "only supports float16"
note). Locate the args.backend conditional that sets dtypes in
bench_trtllm_gen_mla.py and add torch.bfloat16 alongside the existing dtypes for
the "cute-dsl" case.

In `@flashinfer/cute_dsl/mla_decode.py`:
- Around line 307-308: Remove the unused local is_fp8 in mla_decode.py: delete
the line that sets is_fp8 = q_dtype == torch.float8_e4m3fn since q_dtype already
controls the dtype dispatch and is_fp8 is never referenced; leave q_dtype =
query.dtype intact and run pre-commit to ensure Ruff F841 no longer flags the
dead variable.
- Around line 296-305: Replace the runtime-conditional asserts in the checks for
input dtypes and shapes with explicit exceptions so they cannot be removed with
python -O: validate that query.dtype is in _SUPPORTED_DTYPES and raise a
TypeError with a clear message if not; check kv_cache.dtype equals query.dtype
and raise a TypeError if it does not; validate shapes B, q_len, H, D_qk and that
D_qk == kv_lora_rank + qk_rope_head_dim and raise a ValueError with a
descriptive message on mismatch; also enforce kv_lora_rank == _LATENT_DIM and
qk_rope_head_dim == _ROPE_DIM with ValueError if violated. Apply the same
replacements to the analogous checks around the other block referenced (the
assertions at lines 349-353) so all guards use stable exceptions instead of
assert.
- Around line 310-323: Before reinterpreting memory for the kernel, validate
that tensors use a dense row-major layout: check that query, normalized kv_cache
(after handling 4D squeeze) and out are contiguous and have expected
strides/dimensions (for a 4D kv_cache ensure the second dimension == 1 before
squeeze); if any check fails, raise a clear error explaining the required
compact layout. Add these checks immediately before the blocks that split query
into q_latent_k/q_rope_k and kv_cache into c_latent_k/c_rope_k (and the
analogous checks in the later block around lines 358-366), and include the
variable names query, kv_cache, out and a reference to the kernel's __call__
reinterpretation in the error message so callers know why the tensor must be
compact.
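The assert-to-exception guidance above can be sketched as follows, with dtypes shown as strings and all constants and the function signature illustrative (the real wrapper compares `torch.dtype` objects). Unlike `assert`, these guards survive `python -O`:

```python
_LATENT_DIM = 512  # kv_lora_rank the kernel is specialized for
_ROPE_DIM = 64     # qk_rope_head_dim the kernel is specialized for
_SUPPORTED_DTYPES = {"float16", "bfloat16", "float8_e4m3fn"}

def validate_mla_inputs(q_dtype, kv_dtype, d_qk, kv_lora_rank, qk_rope_head_dim):
    if q_dtype not in _SUPPORTED_DTYPES:
        raise TypeError(
            f"query dtype {q_dtype!r} not supported; "
            f"expected one of {sorted(_SUPPORTED_DTYPES)}"
        )
    if kv_dtype != q_dtype:
        raise TypeError(f"kv_cache dtype {kv_dtype!r} must match query dtype {q_dtype!r}")
    if d_qk != kv_lora_rank + qk_rope_head_dim:
        raise ValueError(
            f"query head dim {d_qk} != kv_lora_rank + qk_rope_head_dim "
            f"({kv_lora_rank} + {qk_rope_head_dim})"
        )
    if kv_lora_rank != _LATENT_DIM or qk_rope_head_dim != _ROPE_DIM:
        raise ValueError(
            f"kernel requires kv_lora_rank={_LATENT_DIM}, qk_rope_head_dim={_ROPE_DIM}"
        )

validate_mla_inputs("bfloat16", "bfloat16", 576, 512, 64)  # ok: 512 + 64 == 576
```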

---

Nitpick comments:
In `@tests/attention/test_cute_dsl_mla_decode.py`:
- Around line 311-326: The test only asserts shape for the
trtllm_batch_decode_with_kv_cache_mla call; instead compute a reference result
(using torch_reference_mla or calling cute_dsl_mla_decode with the same inputs)
and assert numeric equivalence: compare out to the reference tensor with
torch.allclose (or torch.testing.assert_allclose) using sensible rtol/atol for
the dtype to catch scale/dtype wiring bugs. Ensure you use the same inputs
(query, kv_cache, workspace_buffer, block_tables, seq_lens, max_seq_len,
bmm1_scale, bmm2_scale) and keep the existing shape check if desired.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: ffbe60f4-3f20-4127-accb-5a06707bee0b

📥 Commits

Reviewing files that changed from the base of the PR and between b3b0f8b and 020fea5.

📒 Files selected for processing (6)
  • benchmarks/bench_trtllm_gen_mla.py
  • flashinfer/cute_dsl/__init__.py
  • flashinfer/cute_dsl/mla_decode.py
  • flashinfer/cute_dsl/mla_decode_fp16.py
  • flashinfer/cute_dsl/utils.py
  • tests/attention/test_cute_dsl_mla_decode.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • flashinfer/cute_dsl/__init__.py

# Conflicts:
#	tests/attention/test_trtllm_gen_mla.py
@bkryu

bkryu commented Mar 17, 2026

/bot run

@flashinfer-bot

GitLab MR !402 has been updated with latest changes, and the CI pipeline #46310895 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot

[FAILED] Pipeline #46310895: 7/20 passed

@bkryu bkryu dismissed their stale review March 18, 2026 16:28

Requested changes have been made. Dismissing "requested changes"

Resolve conflict in flashinfer/mla/_core.py by keeping both
is_var_seq and uses_shared_paged_kv_idx parameters.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@limin2021

/bot run

@flashinfer-bot

GitLab MR !402 has been updated with latest changes, and the CI pipeline #46946647 is currently running. I'll report back once the pipeline job completes.

- Allow BFloat16 output for FP8 input (matching trtllm-gen backend default)
- FP16/BF16 input defaults to same dtype output; FP8 input defaults to BF16 output
- Add out_dtype parameter to cute_dsl_mla_decode for explicit override
- Add uses_shared_paged_kv_idx=False validation for cute-dsl backend
- Skip unsupported 3D page table tests for cute-dsl

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@limin2021

/bot run

@flashinfer-bot

GitLab MR !402 has been updated with latest changes, and the CI pipeline #46949378 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot

[FAILED] Pipeline #46949378: 6/20 passed

bkryu previously requested changes Mar 25, 2026

@bkryu bkryu left a comment


Thanks @limin2021 -- Requesting one change in the unit test file to skip outside of SM100f

Comment on lines +289 to +291

    if backend == "cute-dsl":
        if compute_capability[0] < 10:
            pytest.skip("cute-dsl MLA requires SM100+")

@limin2021 now most tests seem to be passing with the latest updates to CuTe DSL.

Only failure I am noticing is that we are failing this test on SM120 because the CuTe DSL kernel you added is for SM100f. Can you add a skip here?

Contributor Author

done.

The tcgen05 MMA operations only support SM100-SM110. Tighten arch
checks so SM120a (and above) are correctly skipped, and SM110 is
correctly allowed.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
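The tightened gate can be sketched as a single compute-capability range check (function name and `(major, minor)` tuple encoding are illustrative):

```python
def supports_cute_dsl_mla_decode(compute_capability):
    # tcgen05 MMA instructions exist only on SM100-SM110, so SM90 and
    # below, as well as SM120a and above, must be skipped.
    major, minor = compute_capability
    cc = major * 10 + minor
    return 100 <= cc <= 110

assert supports_cute_dsl_mla_decode((10, 0))      # SM100: supported
assert supports_cute_dsl_mla_decode((11, 0))      # SM110: supported
assert not supports_cute_dsl_mla_decode((12, 0))  # SM120: skipped
assert not supports_cute_dsl_mla_decode((9, 0))   # SM90: skipped
```

Checking only `compute_capability[0] < 10`, as in the earlier version of the test skip, would have incorrectly let SM120 through.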
@limin2021
Copy link
Contributor Author

/bot run

@flashinfer-bot
Copy link
Collaborator

GitLab MR !402 has been updated with latest changes, and the CI pipeline #47014511 is currently running. I'll report back once the pipeline job completes.

@bkryu bkryu dismissed their stale review March 26, 2026 00:12

Fix has been delivered. Dismissing review request

@limin2021

/bot run

@flashinfer-bot

GitLab MR !402 has been updated with latest changes, and the CI pipeline #47024987 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot

[FAILED] Pipeline #47024987: 13/20 passed

@bkryu bkryu merged commit 31b63bc into flashinfer-ai:main Mar 26, 2026
29 of 30 checks passed
8 participants