Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions benchmarks/bench_trtllm_gen_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,7 @@ def bench_trtllm_mla(
)
args = parser.parse_args()

if args.backend == "cute-dsl":
q_lens = [1, 2, 4]
else:
q_lens = [1, 2, 4, 8, 16]
q_lens = [1, 2, 4, 8, 16]

# Main perf sweep β€” without sinks, same shape grid as the original
# script. Doubling every cell with a sinks pass would explode runtime
Expand Down
89 changes: 77 additions & 12 deletions flashinfer/cute_dsl/attention/monolithic/mla_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"""

import functools
from typing import Callable, Optional, Tuple
from typing import Callable, Optional, Tuple, Union

import cutlass
import cutlass.cute as cute
Expand All @@ -48,11 +48,20 @@ def _get_split_kv_and_workspace_size(
max_active_blocks: int,
) -> Tuple[int, int]:
"""Cache split_kv and workspace_size since they are deterministic for the same params."""
# When folding S_q into heads, the workspace dims are the effective dims
# (num_heads * F, q_len // F). get_workspace_size already pads H<128 to
# 128, so passing num_heads_eff and seq_len_q_eff yields the right size.
mma_qk_tile_m = 128
fold_sq_ratio = BlackwellMultiHeadLatentAttentionForwardFP16.compute_fold_sq_ratio(
H, q_len, mma_qk_tile_m
)
num_heads_eff = H * fold_sq_ratio
seq_len_q_eff = q_len // fold_sq_ratio
split_kv = BlackwellMultiHeadLatentAttentionForwardFP16.get_split_kv_simplified(
B, q_len, max_active_blocks
B, seq_len_q_eff, max_active_blocks
)
workspace_size = BlackwellMultiHeadLatentAttentionForwardFP16.get_workspace_size(
H, q_len, kv_lora_rank, B, split_kv, cutlass.Float32
num_heads_eff, seq_len_q_eff, kv_lora_rank, B, split_kv, cutlass.Float32
)
return split_kv, workspace_size

Expand Down Expand Up @@ -114,6 +123,8 @@ def _get_compiled_mla_kernel(
page_size: int,
kv_lora_rank: int,
qk_rope_head_dim: int,
num_heads: int,
seq_len_q: int,
is_persistent: bool,
is_var_seq: bool,
is_var_split_kv: bool,
Expand Down Expand Up @@ -145,6 +156,14 @@ def _get_compiled_mla_kernel(
cutlass_dtype = torch_to_cutlass_dtype(torch_dtype)
cutlass_out_dtype = torch_to_cutlass_dtype(torch_out_dtype)

# Derive the seq_len_q-into-heads fold factor. F > 1 means the kernel
# repacks the [H, S_q] tile to [H*F, S_q/F] internally so MTP / spec-decoding
# with H < 128 fully populates the 128-wide MMA-M tile.
fold_sq_ratio = KernelClass.compute_fold_sq_ratio(
num_heads, seq_len_q, mma_qk_tiler_mn[0]
)
fold_sq = fold_sq_ratio > 1

kernel_obj = KernelClass(
acc_dtype=cutlass.Float32,
lse_dtype=cutlass.Float32,
Expand All @@ -159,6 +178,9 @@ def _get_compiled_mla_kernel(
is_var_seq=is_var_seq,
is_var_split_kv=is_var_split_kv,
enable_pdl=enable_pdl,
num_heads=num_heads,
seq_len_q=seq_len_q,
fold_sq=fold_sq,
)

# All dimensions as sym_int β€” this matches the original kernel's use of
Expand Down Expand Up @@ -298,7 +320,9 @@ def cute_dsl_mla_decode(
out_dtype: Optional[torch.dtype] = None,
is_var_seq: bool = True,
enable_pdl: Optional[bool] = None,
) -> torch.Tensor:
lse: Optional[torch.Tensor] = None,
return_lse: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""CuTe DSL MLA decode kernel for Blackwell SM100.

Parameters
Expand Down Expand Up @@ -342,11 +366,26 @@ def cute_dsl_mla_decode(
enable_pdl : Optional[bool], default=None
Whether to enable Programmatic Dependent Launch (PDL).
If None, auto-detects based on device capability.
lse : Optional[torch.Tensor]
Pre-allocated Log-Sum-Exp buffer. Accepted shapes (dtype must be
``torch.float32``):

* ``[B, q_len, H]`` (native kernel layout, no reshape), or
* ``[B * q_len, H]`` (matches ``trtllm-gen`` shape; the wrapper
reshapes via ``.view`` to the native layout).

If ``return_lse`` is True and this is None, a buffer of the native
``[B, q_len, H]`` shape is allocated internally.
return_lse : bool
Whether to return LSE values. When True, the function returns
``(out, lse)`` (the ``lse`` tensor returned is in whatever shape
the caller supplied; if no ``lse`` was supplied, ``[B, q_len, H]``).

Returns
-------
torch.Tensor
Output tensor [B, q_len, H, kv_lora_rank].
torch.Tensor or Tuple[torch.Tensor, torch.Tensor]
Output tensor [B, q_len, H, kv_lora_rank] when ``return_lse=False``;
otherwise ``(output, lse)``.
"""
supported_dtypes = {torch.float16, torch.bfloat16, torch.float8_e4m3fn}
assert query.dtype in supported_dtypes, (
Expand Down Expand Up @@ -420,7 +459,24 @@ def cute_dsl_mla_decode(
)

# LSE: contiguous [B, q_len, H]. Kernel reinterprets to [H, q_len, B].
lse_k = torch.empty((B, q_len, H), dtype=torch.float32, device=query.device)
# If caller supplied an `lse` buffer in either the native 3D shape or the
# 2D trtllm-gen shape [B*q_len, H], reshape to the 3D native layout for
# the kernel call.
if lse is not None:
if lse.dtype != torch.float32:
raise ValueError(f"lse must be torch.float32, got {lse.dtype}")
if lse.shape == (B, q_len, H):
lse_k = lse
elif lse.shape == (B * q_len, H):
# Native kernel layout is 3D; .view is zero-cost when contiguous.
lse_k = lse.view(B, q_len, H)
else:
raise ValueError(
f"lse must have shape (B, q_len, H)=({B}, {q_len}, {H}) "
f"or (B*q_len, H)=({B * q_len}, {H}); got {tuple(lse.shape)}"
)
else:
lse_k = torch.empty((B, q_len, H), dtype=torch.float32, device=query.device)
Comment thread
Observer007 marked this conversation as resolved.

# cache_seqs: per-batch sequence lengths (skip .to() if already int32)
cache_seqs = seq_lens if seq_lens.dtype == torch.int32 else seq_lens.to(torch.int32)
Expand Down Expand Up @@ -457,6 +513,8 @@ def cute_dsl_mla_decode(
page_size=page_size,
kv_lora_rank=kv_lora_rank,
qk_rope_head_dim=qk_rope_head_dim,
num_heads=H,
seq_len_q=q_len,
is_persistent=is_persistent,
is_var_seq=is_var_seq,
is_var_split_kv=is_var_split_kv,
Expand All @@ -482,9 +540,16 @@ def cute_dsl_mla_decode(
Float32(output_scale),
)

# If out was provided, kernel already wrote into it β€” return directly.
if out is not None:
return out
# Pick the output to return: caller-provided buffer (already written
# in-place) or the freshly allocated o_k. o_k is [B, q_len, H, D],
# matching trtllm-gen output shape.
out_tensor = out if out is not None else o_k

if return_lse:
# Return the lse tensor in the shape the caller supplied (or 3D when
# we allocated it). When caller passed 2D, lse_k is a .view into
# that same memory, so returning the original `lse` keeps the
# caller's expected shape.
return out_tensor, (lse if lse is not None else lse_k)

# o_k is [B, q_len, H, D] β€” return as-is to match trtllm-gen output shape.
return o_k
return out_tensor
Loading
Loading