Merged
35 commits
0129aeb
feat: Add CuTe DSL MLA decode kernel for Blackwell SM100
limin2021 Mar 10, 2026
d95da37
style: Fix pre-commit lint/format/type errors in MLA decode kernel files
limin2021 Mar 10, 2026
5cc2f74
chore: Update copyright year to 2026
limin2021 Mar 10, 2026
3c38f20
feat: Add dtype assertions and FP8 tests for cute_dsl_mla_decode
limin2021 Mar 10, 2026
7b12c25
perf: Reduce host overhead in cute_dsl_mla_decode
limin2021 Mar 10, 2026
a8c3b5a
perf: Move permute logic into kernel __call__ to eliminate Python-sid…
limin2021 Mar 10, 2026
7420b6b
feat: Add is_var_split_kv parameter and workspace size check to cute_…
limin2021 Mar 10, 2026
c480145
style: Fix trailing whitespace and ruff formatting
limin2021 Mar 10, 2026
b3b0f8b
perf: Simplify split_kv computation and remove is_var_split_kv parameter
limin2021 Mar 11, 2026
020fea5
feat: Add BFloat16 support to CuTe DSL MLA decode kernel
limin2021 Mar 11, 2026
842b624
minor.
limin2021 Mar 11, 2026
deee819
format
limin2021 Mar 11, 2026
5cef493
refactor: Replace hardcoded MLA config constants with function parame…
limin2021 Mar 11, 2026
2ece5a7
refactor: Split can_implement check from kernel compilation to avoid …
limin2021 Mar 11, 2026
98eae77
fix: Align cute-dsl output shape with trtllm-gen and fix tensor scale…
limin2021 Mar 11, 2026
a4a8723
fix workspace None issue.
limin2021 Mar 11, 2026
c2e769e
fix: align assumed_align values with kernel's from_dlpack settings
limin2021 Mar 11, 2026
b66eb4e
perf: add divisibility hints and opt-level 2 for CuTe DSL MLA compila…
limin2021 Mar 12, 2026
114460f
format.
limin2021 Mar 12, 2026
104c9fe
feat: add is_var_seq parameter for auto persistent/non-persistent str…
limin2021 Mar 12, 2026
0999000
doc update
limin2021 Mar 12, 2026
80a93fd
fix: address review feedback for CuTe DSL MLA decode
limin2021 Mar 12, 2026
9dddc5b
fix: resolve merge conflict and validate unsupported args in cute-dsl…
limin2021 Mar 13, 2026
72bec07
fix: add compat shim for cutlass-dsl setmaxregister API
limin2021 Mar 16, 2026
a913f90
refactor: move MLA CuTe DSL kernels to flashinfer/mla/cute_dsl/
limin2021 Mar 16, 2026
82cc2c3
fix: update copyright year to 2026 in flashinfer/mla/
limin2021 Mar 16, 2026
2423c57
fix: update copyright years to 2026
limin2021 Mar 16, 2026
037eab6
fix: add compat shim for cutlass-dsl get_max_tmem_alloc_cols API
limin2021 Mar 17, 2026
c09a2cc
feat: add cute-dsl backend to test_trtllm_gen_mla uniform testing
limin2021 Mar 17, 2026
96fb34e
feat: add cute-dsl backend support for MLA microbenchmark
limin2021 Mar 17, 2026
0455e5c
Merge remote-tracking branch 'origin/main' into add_cute_dsl_mla_new
limin2021 Mar 17, 2026
fcc1b78
Merge origin/main into add_cute_dsl_mla_new
limin2021 Mar 25, 2026
81ec3aa
feat: support flexible output dtype for CuTe DSL MLA FP8 decode kernel
limin2021 Mar 25, 2026
4d221fd
fix: skip CuTe DSL MLA tests on unsupported archs (SM120+)
limin2021 Mar 26, 2026
adafa35
Merge remote-tracking branch 'origin/main' into add_cute_dsl_mla_new
limin2021 Mar 26, 2026
6 changes: 3 additions & 3 deletions benchmarks/README.md
@@ -19,7 +19,7 @@ Currently supports testing attention, gemm, fused MOE, normalization, quantizati
- `BatchPrefillWithRaggedKVCacheWrapper` - Prefill attention with ragged KV cache.
- Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` (cudnn-native) and `trtllm_ragged_attention_deepseek`.
- `BatchMLAPagedAttentionWrapper` - MLA attention proposed in DeepSeek series of models.
- Also supports computationally similar `trtllm_batch_decode_with_kv_cache_mla`.
- Also supports computationally similar `trtllm_batch_decode_with_kv_cache_mla` (trtllm-native) and CuTe DSL MLA decode kernel (cute-dsl, SM100+).
- GEMM:
- `gemm_fp8_nt_groupwise` - GEMM with FP8 data types using groupwise scaling.
- `group_gemm_fp8_nt_groupwise` - Group GEMM with FP8 data types using groupwise scaling.
@@ -191,7 +191,7 @@ The output CSV will contain detailed metrics including:
| `--verbose`, `-v` | Print additional information (can be used multiple times for more verbosity, e.g. `-vv`) |
| `--case_tag` | Optional tag for the test case, useful for annotating or filtering results in the output CSV. |
| `--generate_repro_command`| If set, prints a reproducer command for the test case and stores it in the output CSV. |
| `--backends` | Space-separated list of backends to test, e.g. fa2, fa2_tc, fa3, auto, cudnn, cudnn-native, cutlass, trtllm, trtllm-gen, trtllm-native, cublas. (`auto` currently supported for `BatchDecodeWithPagedKVCacheWrapper` and `BatchPrefillWithPagedKVCacheWrapper`.)|
| `--backends` | Space-separated list of backends to test, e.g. fa2, fa2_tc, fa3, auto, cudnn, cudnn-native, cutlass, trtllm, trtllm-gen, trtllm-native, cute-dsl, cublas. (`auto` currently supported for `BatchDecodeWithPagedKVCacheWrapper` and `BatchPrefillWithPagedKVCacheWrapper`.)|

### Attention Flags
| Flag | Description |
@@ -464,7 +464,7 @@ Legend:
| **BatchDecodeWithPagedKVCacheWrapper** | fa2 | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn, trtllm-gen, trtllm-native | fa2, fa2_tc, cudnn, trtllm-gen, trtllm-native | fa2, fa2_tc, cudnn |
| **BatchPrefillWithPagedKVCacheWrapper** | | fa2, cudnn, cudnn-native | fa2, cudnn, cudnn-native | fa2, cudnn, cudnn-native | fa2, fa3, cudnn, cudnn-native | fa2, cudnn, cudnn-native, trtllm-gen, trtllm-native | fa2, cudnn, cudnn-native, trtllm-gen, trtllm-native | fa2, cudnn, cudnn-native |
| **BatchPrefillWithRaggedKVCacheWrapper** | | fa2, cudnn, cudnn-native | fa2, cudnn, cudnn-native | fa2, cudnn, cudnn-native | fa2, fa3, cudnn, cudnn-native | fa2, cudnn, cudnn-native, cutlass, trtllm-native | fa2, cudnn, cudnn-native, cutlass, trtllm-native | fa2, cudnn, cudnn-native |
| **BatchMLAPagedAttentionWrapper** | | fa2 | fa2 | fa2 | fa2, fa3 | fa2, cutlass, trtllm-native | fa2, cutlass, trtllm-native | fa2 |
| **BatchMLAPagedAttentionWrapper** | | fa2 | fa2 | fa2 | fa2, fa3 | fa2, cutlass, trtllm-native, cute-dsl | fa2, cutlass, trtllm-native | fa2 |
| **gemm_fp8_nt_groupwise** | | | | | | cutlass | cutlass | |
| **group_gemm_fp8_nt_groupwise** | | | | | | cutlass | cutlass | |
| **bmm_fp8** | | | | cudnn, cublas | cudnn, cublas | cudnn, cublas, cutlass | cudnn, cublas, cutlass | cudnn, cublas |
48 changes: 42 additions & 6 deletions benchmarks/bench_trtllm_gen_mla.py
@@ -10,7 +10,9 @@
kv_lora_rank = 512


def bench_trtllm_mla(batch_size, q_len_per_request, seq_len, page_size, dtype):
def bench_trtllm_mla(
batch_size, q_len_per_request, seq_len, page_size, dtype, backend="auto"
):
torch.manual_seed(42)
device = "cuda:0"

@@ -81,6 +83,7 @@ def bench_trtllm_mla(batch_size, q_len_per_request, seq_len, page_size, dtype):
max_seq_len=max_seq_len,
bmm1_scale=1.0 / ((128 + 64) ** 0.5),
bmm2_scale=1.0,
backend=backend,
)
# benchmark
measurements = bench_gpu_time(
@@ -96,6 +99,7 @@ def bench_trtllm_mla(batch_size, q_len_per_request, seq_len, page_size, dtype):
max_seq_len=max_seq_len,
bmm1_scale=1.0 / ((128 + 64) ** 0.5),
bmm2_scale=1.0,
backend=backend,
),
dry_run_iters=5,
repeat_iters=30,
@@ -126,19 +130,51 @@ def bench_trtllm_mla(batch_size, q_len_per_request, seq_len, page_size, dtype):
* q_len_per_request
)
print(
f"batch_size={batch_size}, q_len_per_request={q_len_per_request}, seq_len={seq_len}, num_q_heads={num_q_heads}, qk_nope_head_dim={qk_nope_head_dim}, qk_rope_head_dim={qk_rope_head_dim}, kv_lora_rank={kv_lora_rank}, page_size={page_size}"
f"backend={backend}, batch_size={batch_size}, q_len_per_request={q_len_per_request}, seq_len={seq_len}, num_q_heads={num_q_heads}, qk_nope_head_dim={qk_nope_head_dim}, qk_rope_head_dim={qk_rope_head_dim}, kv_lora_rank={kv_lora_rank}, page_size={page_size}"
)
print(f"execution time: {ms:.4f} ms")
print(f"memory bandwidth: {total_mem_bytes / ms / 1e6:.2f} GB/s")
print(f"FLOPs: {flops / ms / 1e9:.2f} TFLOPs/s")


if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser(description="Benchmark trtllm MLA decode")
parser.add_argument(
"--backend",
type=str,
default="auto",
help="Backend to use (auto, trtllm-gen, cute-dsl)",
)
args = parser.parse_args()

if args.backend == "cute-dsl":
q_lens = [1, 2, 4]
else:
q_lens = [1, 2, 4, 8, 16]

for dtype in [torch.bfloat16, torch.float8_e4m3fn]:
for page_size in [32, 64]:
for batch_size in [1, 2, 4, 16, 32, 64, 128, 256, 512, 768, 1024]:
for seq_len in [1024, 4096, 8192]:
for q_len_per_request in [1, 2, 4, 8, 16]:
bench_trtllm_mla(
batch_size, q_len_per_request, seq_len, page_size, dtype
)
for q_len_per_request in q_lens:
try:
bench_trtllm_mla(
batch_size,
q_len_per_request,
seq_len,
page_size,
dtype,
backend=args.backend,
)
except ValueError as e:
print(f"SKIPPED: {e}")
print()
except Exception as e:
print(
f"ERROR: batch_size={batch_size}, q_len={q_len_per_request}, "
f"seq_len={seq_len}, page_size={page_size}, dtype={dtype}, "
f"backend={args.backend}: {type(e).__name__}: {e}"
)
print()
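The updated benchmark loop treats a `ValueError` from the kernel as an expected unsupported configuration (printed as `SKIPPED`), while any other exception is reported as `ERROR` so the sweep continues. A minimal standalone sketch of that pattern, with `bench_stub` as a hypothetical stand-in for `bench_trtllm_mla`:

```python
def bench_stub(batch_size, q_len_per_request, backend):
    # Hypothetical stand-in for bench_trtllm_mla: here the cute-dsl
    # backend only accepts q_len_per_request <= 4, mirroring the
    # reduced q_lens list the driver uses for that backend.
    if backend == "cute-dsl" and q_len_per_request > 4:
        raise ValueError(f"q_len_per_request={q_len_per_request} unsupported")
    return f"ok: bs={batch_size}, q_len={q_len_per_request}"


def run_case(batch_size, q_len_per_request, backend):
    # Expected-unsupported configs surface as SKIPPED; anything else
    # is reported as ERROR so the rest of the sweep still runs.
    try:
        return bench_stub(batch_size, q_len_per_request, backend)
    except ValueError as e:
        return f"SKIPPED: {e}"
    except Exception as e:
        return f"ERROR: {type(e).__name__}: {e}"


print(run_case(16, 2, "cute-dsl"))
print(run_case(16, 8, "cute-dsl"))
```

The try/except sits inside the innermost loop, so one unsupported `(batch_size, q_len, seq_len)` combination does not abort the whole sweep.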
23 changes: 23 additions & 0 deletions benchmarks/routines/attention.py
@@ -112,6 +112,7 @@ def parse_attention_args(line, parser):
"trtllm-gen",
"trtllm-native",
"trtllm-gen-native", # Deprecated, will be removed in future
"cute-dsl",
],
help="Kernel backends to test. Default: fa2. backend=auto is only supported for BatchDecodeWithPagedKVCacheWrapper and BatchPrefillWithPagedKVCacheWrapper.",
)
@@ -2152,6 +2153,13 @@ def testBatchMLAPagedAttentionWrapper(args):
remove_trtllm_native = True
if remove_trtllm_native:
backends.remove("trtllm-native")
if "cute-dsl" in backends:
remove_cute_dsl = False
if num_qo_heads < 128:
print("[INFO] cute-dsl MLA backend requires num_heads >= 128. Skipping.")
remove_cute_dsl = True
if remove_cute_dsl:
backends.remove("cute-dsl")
if len(backends) == 0:
print("[ERROR] No backends to test. Exiting.")
return res
@@ -2337,6 +2345,21 @@ def run_backend_wrapper(
bmm1_scale=sm_scale,
bmm2_scale=1.0,
).squeeze(1)
elif backend == "cute-dsl":
return flashinfer.mla.trtllm_batch_decode_with_kv_cache_mla(
query=q.unsqueeze(1),
kv_cache=kv_cache.unsqueeze(1),
workspace_buffer=workspace_buffer,
qk_nope_head_dim=128,
kv_lora_rank=head_dim_ckv,
qk_rope_head_dim=head_dim_kpe,
block_tables=block_tables,
seq_lens=actual_seq_lens_kv.flatten(),
max_seq_len=s_kv,
bmm1_scale=sm_scale,
bmm2_scale=1.0,
backend="cute-dsl",
).squeeze(1)
else:
print(f"[ERROR] Unsupported backend: {backend}")
return None
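The gating added in `testBatchMLAPagedAttentionWrapper` drops `cute-dsl` when the head count is too small for the kernel. A self-contained sketch of that filter (the function name is hypothetical; the real code mutates the `backends` list inline):

```python
def filter_mla_backends(backends, num_qo_heads):
    # The cute-dsl MLA decode kernel requires num_qo_heads >= 128,
    # so the backend is dropped (with a notice) for smaller head counts.
    backends = list(backends)
    if "cute-dsl" in backends and num_qo_heads < 128:
        print("[INFO] cute-dsl MLA backend requires num_heads >= 128. Skipping.")
        backends.remove("cute-dsl")
    return backends


print(filter_mla_backends(["fa2", "cute-dsl"], 64))   # cute-dsl dropped
print(filter_mla_backends(["fa2", "cute-dsl"], 128))  # cute-dsl kept
```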
3 changes: 2 additions & 1 deletion benchmarks/routines/flashinfer_benchmark_utils.py
@@ -325,12 +325,13 @@ def dtype_str_to_torch_dtype(dtype_str):
},
"BatchMLAPagedAttentionWrapper": {
# NOTE: trtllm-native calls trtllm_batch_decode_with_kv_cache_mla
# NOTE: cute-dsl calls trtllm_batch_decode_with_kv_cache_mla(backend="cute-dsl")
"7.5": [],
"8.0": ["fa2"],
"8.6": ["fa2"],
"8.9": ["fa2"],
"9.0": ["fa2", "fa3"],
"10.0": ["fa2", "cutlass", "trtllm-native"],
"10.0": ["fa2", "cutlass", "trtllm-native", "cute-dsl"],
"10.3": ["fa2", "cutlass", "trtllm-native"],
"12.0": ["fa2"],
},
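The support table above keys backend availability on the compute-capability string, with `cute-dsl` added only under `"10.0"`. A minimal sketch of how such a table can gate a user's requested backends (the `supported` helper is hypothetical; the benchmark harness does this filtering internally):

```python
# Per-architecture backend support, mirroring the entries added for
# BatchMLAPagedAttentionWrapper (abridged to the rows shown in the diff).
MLA_SUPPORT = {
    "9.0": ["fa2", "fa3"],
    "10.0": ["fa2", "cutlass", "trtllm-native", "cute-dsl"],
    "10.3": ["fa2", "cutlass", "trtllm-native"],
    "12.0": ["fa2"],
}


def supported(requested, cc):
    # Intersect requested backends with the table for this arch,
    # preserving the order in which they were requested.
    allowed = set(MLA_SUPPORT.get(cc, []))
    return [b for b in requested if b in allowed]


print(supported(["cute-dsl", "fa2"], "10.0"))
print(supported(["cute-dsl", "fa2"], "10.3"))
```

Note that `cute-dsl` is deliberately absent from `"10.3"`, matching the SM100-only restriction of the kernel.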
2 changes: 2 additions & 0 deletions flashinfer/cute_dsl/__init__.py
Collaborator:

We are moving kernels out of this directory, maybe create a module
flashinfer/mla/ and move these kernels under flashinfer/mla/cute_dsl?

Contributor Author:

done.

@@ -29,6 +29,7 @@
is_cute_dsl_available,
make_ptr,
get_cutlass_dtype,
torch_to_cutlass_dtype,
get_num_sm,
convert_sf_to_mma_layout,
convert_sf_from_mma_layout,
@@ -77,6 +78,7 @@
"is_cute_dsl_available",
"make_ptr",
"get_cutlass_dtype",
"torch_to_cutlass_dtype",
"get_num_sm",
# Scale factor layout conversion utilities
"convert_sf_to_mma_layout",
14 changes: 14 additions & 0 deletions flashinfer/cute_dsl/utils.py
@@ -51,6 +51,20 @@ def get_cutlass_dtype(dtype: str) -> cutlass.dtype:
return dtype_map[dtype]


def torch_to_cutlass_dtype(dtype: torch.dtype) -> cutlass.dtype:
"""Return the corresponding cutlass dtype for the given torch.dtype."""
dtype_map = {
torch.float16: cutlass.Float16,
torch.bfloat16: cutlass.BFloat16,
torch.float32: cutlass.Float32,
torch.float8_e5m2: cutlass.Float8E5M2,
torch.float8_e4m3fn: cutlass.Float8E4M3FN,
}
if dtype not in dtype_map:
raise TypeError(f"{dtype} is not supported by cutlass")
return dtype_map[dtype]


def cutlass_to_torch_dtype(cutlass_dtype):
"""
Return the corresponding torch.dtype per the given DSL type
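The new `torch_to_cutlass_dtype` helper follows a simple lookup-with-fail-fast pattern. Since `torch` and `cutlass` may not be importable outside a CUDA environment, the sketch below reproduces the pattern with string stand-ins for the dtype objects (the real helper maps e.g. `torch.bfloat16` to `cutlass.BFloat16`):

```python
# String stand-ins for the torch -> cutlass dtype objects.
_DTYPE_MAP = {
    "float16": "Float16",
    "bfloat16": "BFloat16",
    "float32": "Float32",
    "float8_e5m2": "Float8E5M2",
    "float8_e4m3fn": "Float8E4M3FN",
}


def torch_to_cutlass_dtype(name):
    # Unknown dtypes fail fast with TypeError, matching the real helper,
    # rather than silently returning None from dict.get().
    if name not in _DTYPE_MAP:
        raise TypeError(f"{name} is not supported by cutlass")
    return _DTYPE_MAP[name]


print(torch_to_cutlass_dtype("bfloat16"))  # BFloat16
```

Raising `TypeError` at the mapping boundary keeps dtype mismatches from surfacing later as opaque kernel-compilation errors.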
15 changes: 15 additions & 0 deletions flashinfer/mla/__init__.py
@@ -0,0 +1,15 @@
# Copyright (c) 2026 by FlashInfer team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ._core import * # noqa: F401,F403
79 changes: 70 additions & 9 deletions flashinfer/mla.py → flashinfer/mla/_core.py
@@ -1,5 +1,5 @@
"""
Copyright (c) 2023 by FlashInfer team.
Copyright (c) 2026 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -20,10 +20,10 @@

import torch

from .api_logging import flashinfer_api
from .jit import gen_batch_mla_module, gen_trtllm_gen_fmha_module, setup_cubin_loader
from .jit.mla import gen_mla_module
from .utils import (
from ..api_logging import flashinfer_api
from ..jit import gen_batch_mla_module, gen_trtllm_gen_fmha_module, setup_cubin_loader
from ..jit.mla import gen_mla_module
from ..utils import (
MaskMode,
_check_block_tables_shape,
check_shape_dtype_device,
@@ -33,7 +33,7 @@
get_device_sm_count,
log2e,
)
from .xqa import xqa_mla
from ..xqa import xqa_mla


def _check_cutlass_shape(q_nope_pe, ckv_kpe_cache, kv_len, page_table):
@@ -607,6 +607,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
skip_softmax_threshold_scale_factor: Optional[float] = None,
enable_pdl: bool | None = None,
backend: str = "auto",
is_var_seq: bool = True,
uses_shared_paged_kv_idx: bool = True,
) -> torch.Tensor:
"""
@@ -628,20 +629,26 @@
max_seq_len: max sequence length for kv_cache
out: output tensor, if not provided, will be allocated internally
bmm1_scale: fused scale for mla bmm1 input.
when using trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32.
When using ``trtllm-gen`` backend, it can be a ``torch.Tensor`` with dtype ``torch.float32``.
When using ``cute-dsl`` backend, only ``float`` values are supported.
bmm2_scale: fused scale for mla bmm2 input.
when using trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32.
When using ``trtllm-gen`` backend, it can be a ``torch.Tensor`` with dtype ``torch.float32``.
When using ``cute-dsl`` backend, only ``float`` values are supported.
sinks: additional value per head in the denominator of the softmax.
skip_softmax_threshold_scale_factor: threshold scale factor for skipping softmax operations.
Providing a value for this parameter enables skip-softmax sparsity as described in: https://arxiv.org/abs/2512.12087
If no value is provided, then standard attention is used.
Setting the threshold to a higher value generally increases kernel performance at the cost of accuracy degradation.
The actual threshold value equals the provided threshold_scale_factor divided by the context length.
backend : str = "auto"
The implementation backend, could be ``auto``/``xqa`` or ``trtllm-gen``. Defaults to ``auto``.
The implementation backend, could be ``auto``/``xqa``, ``trtllm-gen``, or ``cute-dsl``. Defaults to ``auto``.
When set to ``auto``, the backend will be chosen based on the device architecture and kernel availability.
For sm_100 and sm_103 (blackwell architecture), ``auto`` will choose ``trtllm-gen`` backend.
For sm_120 (blackwell architecture), ``auto`` will choose ``xqa`` backend.
is_var_seq : bool
Whether sequence lengths vary across requests in the batch.
If True, each request may have its own sequence length.
Otherwise, the sequence length is fixed for all requests in the batch.
uses_shared_paged_kv_idx : bool = True
Whether the K and V page indices are shared as a unified index.
True (default) uses vLLM/FlashInfer layout with a 2D page table.
@@ -790,6 +797,60 @@ def trtllm_batch_decode_with_kv_cache_mla(
)

return out
elif backend == "cute-dsl":
cc = get_compute_capability(query.device)
if cc[0] < 10:
raise RuntimeError(
f"cute-dsl backend (MLA decode kernel) requires SM100+, got SM{cc[0]}{cc[1]}"
)
from .cute_dsl import cute_dsl_mla_decode

if isinstance(bmm1_scale, torch.Tensor):
raise ValueError(
"cute-dsl backend (MLA decode kernel) does not support tensor bmm1_scale, "
"please pass a float value"
)
if isinstance(bmm2_scale, torch.Tensor):
raise ValueError(
"cute-dsl backend (MLA decode kernel) does not support tensor bmm2_scale, "
"please pass a float value"
)
if sinks is not None:
raise ValueError(
"cute-dsl backend (MLA decode kernel) does not support sinks"
)
if sparse_mla_top_k > 0:
raise ValueError(
"cute-dsl backend (MLA decode kernel) does not support sparse_mla_top_k"
)
if enable_pdl is not None:
raise ValueError(
"cute-dsl backend (MLA decode kernel) does not support enable_pdl"
)
if skip_softmax_threshold_scale_factor is not None:
raise ValueError(
"cute-dsl backend (MLA decode kernel) does not support skip_softmax_threshold_scale_factor"
)
if not uses_shared_paged_kv_idx:
raise ValueError(
"cute-dsl backend (MLA decode kernel) does not support separate KV page indices "
"(uses_shared_paged_kv_idx=False)"
)

return cute_dsl_mla_decode(
query=query,
kv_cache=kv_cache,
workspace_buffer=workspace_buffer,
kv_lora_rank=kv_lora_rank,
qk_rope_head_dim=qk_rope_head_dim,
block_tables=block_tables,
seq_lens=seq_lens,
max_seq_len=max_seq_len,
softmax_scale=bmm1_scale,
output_scale=bmm2_scale,
out=out,
is_var_seq=is_var_seq,
)
else:
raise ValueError(f"Backend {backend} not supported")

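The `cute-dsl` branch added to `trtllm_batch_decode_with_kv_cache_mla` is mostly guard clauses: it rejects every argument the kernel does not support before dispatching. A standalone sketch of those checks (function name and the `sm_major` parameter are hypothetical; the real code reads the compute capability from the query device and models tensor scales as `torch.Tensor`):

```python
def validate_cute_dsl_args(sm_major, bmm1_scale, bmm2_scale, sinks=None,
                           sparse_mla_top_k=0, enable_pdl=None,
                           uses_shared_paged_kv_idx=True):
    # Guard clauses mirroring the cute-dsl dispatch branch: the kernel
    # is SM100+-only, takes only float scales, and has no support for
    # sinks, sparse top-k, PDL, or separate K/V page indices.
    if sm_major < 10:
        raise RuntimeError("cute-dsl backend requires SM100+")
    for name, scale in (("bmm1_scale", bmm1_scale), ("bmm2_scale", bmm2_scale)):
        if not isinstance(scale, float):
            raise ValueError(f"cute-dsl backend does not support tensor {name}")
    if sinks is not None:
        raise ValueError("cute-dsl backend does not support sinks")
    if sparse_mla_top_k > 0:
        raise ValueError("cute-dsl backend does not support sparse_mla_top_k")
    if enable_pdl is not None:
        raise ValueError("cute-dsl backend does not support enable_pdl")
    if not uses_shared_paged_kv_idx:
        raise ValueError("cute-dsl backend requires a shared paged KV index")
    return True


print(validate_cute_dsl_args(10, 1.0 / (192 ** 0.5), 1.0))
```

Raising `ValueError` for unsupported arguments (rather than ignoring them) is what lets the benchmark driver above treat these configurations as clean skips.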
30 changes: 30 additions & 0 deletions flashinfer/mla/cute_dsl/__init__.py
@@ -0,0 +1,30 @@
# Copyright (c) 2026 by FlashInfer team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
CuTe DSL MLA Decode Kernels for Blackwell SM100.
"""

from flashinfer.cute_dsl.utils import is_cute_dsl_available

if is_cute_dsl_available():
from .mla_decode import cute_dsl_mla_decode

__all__ = [
"is_cute_dsl_available",
]

if is_cute_dsl_available():
__all__ += [
"cute_dsl_mla_decode",
]
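The new package `__init__.py` gates its exports on CuTe DSL availability, so `cute_dsl_mla_decode` is only re-exported when the DSL can actually be imported. The pattern can be sketched as a pure function (`build_exports` and its boolean argument are illustrative stand-ins for the module-level `if is_cute_dsl_available():` blocks):

```python
def build_exports(cute_dsl_available):
    # The availability probe itself is always exported; the kernel
    # symbol is appended only when the CuTe DSL import would succeed.
    exports = ["is_cute_dsl_available"]
    if cute_dsl_available:
        exports.append("cute_dsl_mla_decode")
    return exports


print(build_exports(False))
print(build_exports(True))
```

This keeps `import flashinfer.mla.cute_dsl` from failing on machines without the DSL installed, while callers can still probe `is_cute_dsl_available()` before touching the kernel.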