From 2ac49c1d4bcf4831ae62a13fffee86299250b9f8 Mon Sep 17 00:00:00 2001 From: Prasun Gera Date: Tue, 12 May 2026 01:22:43 -0700 Subject: [PATCH 1/9] Restore monolithic CuTe-DSL MLA decode alongside modular, gated by cute_dsl_impl= (AI-assisted) PR #2805 refactored the monolithic CuTe-DSL MLA decode kernel into a modular structure and removed the original implementation. The original authors want it kept available because the modular path is still maturing. Restore it under the cute-dsl backend (no new backend name) and let the user pick: flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( ..., backend="cute-dsl", cute_dsl_impl="auto" | "modular" | "monolithic") Layout: flashinfer/cute_dsl/attention/ monolithic/ - restored kernels (verbatim from before #2805, relocated to live next to the modular code). Includes the H<128 / Kimi K2.5 fix from #3235 backported (workspace pads H to max(H, 128) when split_kv != 1; can_implement no longer rejects H<128). wrappers/ - existing modular standalone + wrapper. mla_dispatch.py - new dispatcher in front of both impls. Dispatcher contract: - "auto" (default): monolithic, but auto-promotes to modular when a modular-only feature is requested (currently sinks). - "modular" : strict, always modular. - "monolithic" : strict, raises ValueError if a modular-only feature is requested rather than silently substituting. The dispatcher strips modular-only kwargs (sinks=None) before forwarding to monolithic, so callers can pass sinks= unconditionally without breaking the monolithic path. Sinks support: - trtllm_batch_decode_with_kv_cache_mla(sinks=...) on backend= "cute-dsl" now constructs an AttentionWithSink variant inside the modular standalone, instead of being rejected at the API boundary. - AttentionWithSink gains value-based __hash__/__eq__ keyed on (data_ptr, shape, dtype) so @functools.cache on _compile_mla_kernel correctly reuses compiled kernels across invocations with the same sinks tensor. Without this, a fresh variant per call hashed by object identity, JIT-recompiled the kernel on every iteration, and made cuda-graph + sinks bench measurements appear to hang. Tests: - test_cute_dsl_mla_decode.py: existing standalone and public-API tests now parametrize over modular/monolithic via a cute_dsl_impl fixture; new minimal sinks tests pin the auto/modular dispatch branches and the monolithic+sinks ValueError contract. Wrapper sinks numerics remain covered by the pre-existing test_cute_dsl_mla_decode_attention_sink. - test_trtllm_gen_mla.py: comment near the cute-dsl skip refreshed to reflect the dispatcher's cute_dsl_impl behaviour. Bench: - bench_trtllm_gen_mla.py grows a focused 6-cell with_sinks=True sub-sweep (B in {1,16,128} x S in {1024,8192} at q_len=1, page=64, bf16) on top of the existing main sweep, instead of doubling the full grid. Argument list deduplicated into a common_kwargs dict so warmup and benchmark calls cannot drift. --- benchmarks/bench_trtllm_gen_mla.py | 87 +- flashinfer/cute_dsl/attention/__init__.py | 9 +- .../cute_dsl/attention/fusion/variant.py | 22 + flashinfer/cute_dsl/attention/mla_dispatch.py | 126 + .../cute_dsl/attention/monolithic/__init__.py | 41 + .../attention/monolithic/mla_decode.py | 490 ++ .../attention/monolithic/mla_decode_fp16.py | 4259 +++++++++++++++++ .../attention/monolithic/mla_decode_fp8.py | 4230 ++++++++++++++++ .../attention/monolithic/mla_helpers.py | 304 ++ .../cute_dsl/attention/wrappers/batch_mla.py | 53 +- flashinfer/mla/_core.py | 44 +- tests/attention/test_cute_dsl_mla_decode.py | 132 +- tests/attention/test_trtllm_gen_mla.py | 5 + 13 files changed, 9770 insertions(+), 32 deletions(-) create mode 100644 flashinfer/cute_dsl/attention/mla_dispatch.py create mode 100644 flashinfer/cute_dsl/attention/monolithic/__init__.py create mode 100644 flashinfer/cute_dsl/attention/monolithic/mla_decode.py create mode 100644 flashinfer/cute_dsl/attention/monolithic/mla_decode_fp16.py create mode 100644 flashinfer/cute_dsl/attention/monolithic/mla_decode_fp8.py create mode 100644 flashinfer/cute_dsl/attention/monolithic/mla_helpers.py diff --git a/benchmarks/bench_trtllm_gen_mla.py b/benchmarks/bench_trtllm_gen_mla.py index a739ccc21b..f846c70633 100644 --- a/benchmarks/bench_trtllm_gen_mla.py +++ b/benchmarks/bench_trtllm_gen_mla.py @@ -11,8 +11,21 @@ def bench_trtllm_mla( - batch_size, q_len_per_request, seq_len, page_size, dtype, backend="auto" + batch_size, + q_len_per_request, + seq_len, + page_size, + dtype, + backend="auto", + with_sinks=False, ): + """Benchmark a single (config, backend, sinks?) cell. + + `with_sinks=True` allocates a per-head sinks tensor and passes it via + the `sinks=` kwarg. Currently supported by trtllm-gen and cute-dsl + on MLA decode; xqa MLA rejects sinks. Output line includes the sinks + flag so two cells can be diffed cleanly. + """ torch.manual_seed(42) device = "cuda:0" @@ -69,9 +82,13 @@ def bench_trtllm_mla( # todo(Yingyi): calculate the actual size of workspace buffer workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) - # Run decode-MLA - # warmup - flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( + sinks = ( + torch.randn(num_q_heads, dtype=torch.float32, device=device) + if with_sinks + else None + ) + + common_kwargs = dict( query=query, kv_cache=kv_cache.unsqueeze(1), workspace_buffer=workspace_buffer, @@ -83,23 +100,24 @@ def bench_trtllm_mla( max_seq_len=max_seq_len, bmm1_scale=1.0 / ((128 + 64) ** 0.5), bmm2_scale=1.0, + sinks=sinks, backend=backend, ) + + # Run decode-MLA + # warmup + flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(**common_kwargs) # benchmark + # NOTE: cold_l2_cache=True is requested but silently degrades to warm-L2 + # because the inputs are captured in the lambda's closure rather than + # passed via input_kwargs= (bench_gpu_time only flushes/rotates GPU + # tensors it can find by introspecting its own args). We accept this: + # both backends are measured under identical warm-L2 conditions, so + # cross-backend comparisons remain fair, only absolute GB/s numbers + # are optimistic vs. a real cold-cache serving workload. measurements = bench_gpu_time( lambda: flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( - query=query, - kv_cache=kv_cache.unsqueeze(1), - workspace_buffer=workspace_buffer, - qk_nope_head_dim=qk_nope_head_dim, - kv_lora_rank=kv_lora_rank, - qk_rope_head_dim=qk_rope_head_dim, - block_tables=block_tables, - seq_lens=seq_lens_tensor, - max_seq_len=max_seq_len, - bmm1_scale=1.0 / ((128 + 64) ** 0.5), - bmm2_scale=1.0, - backend=backend, + **common_kwargs ), dry_run_iters=5, repeat_iters=30, @@ -130,7 +148,7 @@ def bench_trtllm_mla( * q_len_per_request ) print( - f"backend={backend}, batch_size={batch_size}, q_len_per_request={q_len_per_request}, seq_len={seq_len}, num_q_heads={num_q_heads}, qk_nope_head_dim={qk_nope_head_dim}, qk_rope_head_dim={qk_rope_head_dim}, kv_lora_rank={kv_lora_rank}, page_size={page_size}" + f"backend={backend}, sinks={with_sinks}, batch_size={batch_size}, q_len_per_request={q_len_per_request}, seq_len={seq_len}, num_q_heads={num_q_heads}, qk_nope_head_dim={qk_nope_head_dim}, qk_rope_head_dim={qk_rope_head_dim}, kv_lora_rank={kv_lora_rank}, page_size={page_size}" ) print(f"execution time: {ms:.4f} ms") print(f"memory bandwidth: {total_mem_bytes / ms / 1e6:.2f} GB/s") @@ -154,6 +172,11 @@ def bench_trtllm_mla( else: q_lens = [1, 2, 4, 8, 16] + # Main perf sweep — without sinks, same shape grid as the original + # script. Doubling every cell with a sinks pass would explode runtime + # without adding signal: sinks adds a small per-call overhead that's + # uniform across shapes, so a focused sub-sweep below is enough to + # characterise it. for dtype in [torch.bfloat16, torch.float8_e4m3fn]: for page_size in [32, 64]: for batch_size in [1, 2, 4, 16, 32, 64, 128, 256, 512, 768, 1024]: @@ -167,6 +190,7 @@ def bench_trtllm_mla( page_size, dtype, backend=args.backend, + with_sinks=False, ) except ValueError as e: print(f"SKIPPED: {e}") @@ -178,3 +202,32 @@ def bench_trtllm_mla( f"backend={args.backend}: {type(e).__name__}: {e}" ) print() + + # Focused sinks sub-sweep — small representative grid that exercises + # the sinks code path on both backends that support it (trtllm-gen and + # cute-dsl; xqa MLA rejects sinks). Pairs with the no-sinks rows above + # at the same shapes so users can read the sinks overhead off the diff. + if args.backend != "xqa": + print() + print("=" * 72) + print("Focused sinks sub-sweep (with_sinks=True)") + print("=" * 72) + for batch_size in [1, 16, 128]: + for seq_len in [1024, 8192]: + try: + bench_trtllm_mla( + batch_size, + q_len_per_request=1, + seq_len=seq_len, + page_size=64, + dtype=torch.bfloat16, + backend=args.backend, + with_sinks=True, + ) + except Exception as e: + print( + f"ERROR: backend={args.backend}, sinks=True, " + f"batch_size={batch_size}, seq_len={seq_len}: " + f"{type(e).__name__}: {e}" + ) + print() diff --git a/flashinfer/cute_dsl/attention/__init__.py b/flashinfer/cute_dsl/attention/__init__.py index 34b43c4bb9..be49852a9c 100644 --- a/flashinfer/cute_dsl/attention/__init__.py +++ b/flashinfer/cute_dsl/attention/__init__.py @@ -75,5 +75,12 @@ ) from .wrappers.batch_mla import ( BatchMLADecodeCuteDSLWrapper, - cute_dsl_mla_decode, ) + +# MLA decode is reached via a dispatcher that picks the modular or monolithic +# implementation based on the cute_dsl_impl= kwarg (default "auto" = pick +# monolithic, auto-promote to modular on modular-only features like sinks). +# See mla_dispatch.py for the full selection contract. The modular and +# monolithic standalone functions remain importable from their original +# locations for tests/benchmarks that need to bypass the dispatcher. +from .mla_dispatch import cute_dsl_mla_decode diff --git a/flashinfer/cute_dsl/attention/fusion/variant.py b/flashinfer/cute_dsl/attention/fusion/variant.py index 47873edd84..9460a8dfc3 100644 --- a/flashinfer/cute_dsl/attention/fusion/variant.py +++ b/flashinfer/cute_dsl/attention/fusion/variant.py @@ -419,6 +419,28 @@ def __init__(self, sink): def extra_params(self): return self._sink + # Value-based hash/eq so this variant can serve as a stable + # @functools.cache key for kernel compilation. The compiled kernel does + # not specialize on the sink *values* — only on the variant type and the + # params *shape* (the runtime sinks tensor is bound at launch time, not + # baked into the kernel). So two AttentionWithSink built from sinks + # tensors of the same shape and dtype are equivalent for caching, even + # if the underlying storage differs. This lets callers who reconstruct + # the variant per-call (e.g. the standalone cute_dsl_mla_decode function) + # still hit the kernel cache and avoid catastrophic per-call JIT + # recompiles. + def _cache_key(self): + return (type(self), tuple(self._sink.shape), self._sink.dtype) + + def __hash__(self): + return hash(self._cache_key()) + + def __eq__(self, other): + return ( + isinstance(other, AttentionWithSink) + and self._cache_key() == other._cache_key() + ) + @cute.jit def update_statistics(self, kv_tile_idx, qo_head_idx, m, d, scale): # Guard: on non-first tiles, return (m, d) unchanged. Computing diff --git a/flashinfer/cute_dsl/attention/mla_dispatch.py b/flashinfer/cute_dsl/attention/mla_dispatch.py new file mode 100644 index 0000000000..f41976dff9 --- /dev/null +++ b/flashinfer/cute_dsl/attention/mla_dispatch.py @@ -0,0 +1,126 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +"""Dispatcher between the modular and monolithic CuTe DSL MLA decode kernels. + +Both implementations share the public ``backend="cute-dsl"`` user surface in +``trtllm_batch_decode_with_kv_cache_mla``. Implementation selection is +controlled by the ``cute_dsl_impl`` kwarg, with three valid values: + +* ``"auto"`` (default) — library picks the right implementation. + Monolithic by default, automatically promoted to modular when the call + uses a feature monolithic doesn't support (currently: ``sinks``). +* ``"modular"`` — strict. Always run the modular implementation. +* ``"monolithic"`` — strict. Always run the monolithic implementation; + raise :class:`ValueError` if the call uses any modular-only feature. + No silent fallback — the contract is "you asked for monolithic, you + get monolithic, or a clear error". + +The strict modes exist so users can pin the implementation for +differential debugging or perf characterisation without worrying that +the dispatcher will silently substitute something else. +""" + +from __future__ import annotations + +import logging +from typing import Optional + +logger = logging.getLogger(__name__) + +VALID_IMPLS = ("auto", "modular", "monolithic") +_DEFAULT_IMPL = "monolithic" + +# One-shot log so users can confirm which impl they actually got, without +# spamming on every kernel call. Keyed by resolved impl so the very first +# call of each impl in a process logs once. +_logged_impls: set[str] = set() + + +# Modular-only features. When any of these is requested: +# * ``cute_dsl_impl="auto"`` → silently promote to modular +# * ``cute_dsl_impl="modular"`` → already correct, run as-is +# * ``cute_dsl_impl="monolithic"`` → raise ValueError +# +# Add new entries here as more variants are exposed through the standalone +# signature. +MODULAR_ONLY_KWARGS = ("sinks",) + + +def _has_modular_only_feature(kwargs: dict) -> Optional[str]: + """Return the name of the first modular-only kwarg present (and non-None), + or None if no modular-only feature was requested.""" + for name in MODULAR_ONLY_KWARGS: + if kwargs.get(name) is not None: + return name + return None + + +def _resolve_impl(*, requested: str, kwargs: dict) -> str: + """Map a user request and call kwargs to a concrete impl name. + + See module docstring for the contract. ``requested`` must already be + one of :data:`VALID_IMPLS`. + """ + if requested not in VALID_IMPLS: + raise ValueError( + f"Invalid cute_dsl_impl={requested!r}; expected one of {VALID_IMPLS}" + ) + + needs_modular = _has_modular_only_feature(kwargs) + + if requested == "auto": + return "modular" if needs_modular is not None else _DEFAULT_IMPL + + if requested == "monolithic" and needs_modular is not None: + raise ValueError( + f"cute_dsl_impl='monolithic' was requested but the call uses " + f"{needs_modular!r}, which is only supported by the modular " + f"implementation. Use cute_dsl_impl='auto' (default, picks the " + f"right impl based on the call) or cute_dsl_impl='modular'." + ) + + return requested # "modular" or "monolithic" + + +def cute_dsl_mla_decode(*args, cute_dsl_impl: str = "auto", **kwargs): + """Run CuTe DSL MLA decode using the resolved implementation. + + Forwards all positional and keyword arguments verbatim to the underlying + implementation. See + :func:`flashinfer.cute_dsl.attention.wrappers.batch_mla.cute_dsl_mla_decode` + (modular, supports ``sinks=``) and + :func:`flashinfer.cute_dsl.attention.monolithic.mla_decode.cute_dsl_mla_decode` + (monolithic, no variant support) — their signatures are otherwise + identical. + + Parameters + ---------- + cute_dsl_impl : str, default ``"auto"`` + ``"auto"`` (default) lets the dispatcher pick: monolithic by + default, modular when the call uses a modular-only feature + (currently ``sinks``). ``"modular"`` and ``"monolithic"`` are + strict — the dispatcher will not silently switch implementations, + and ``"monolithic"`` raises :class:`ValueError` if the call uses + a modular-only feature. + """ + impl = _resolve_impl(requested=cute_dsl_impl, kwargs=kwargs) + + if impl not in _logged_impls: + _logged_impls.add(impl) + logger.info( + "flashinfer.cute_dsl MLA decode using impl=%s", + impl, + ) + + # Imports are deferred so that selecting one impl never imports/JITs the + # other. Each impl's import path triggers heavy CuTe-DSL machinery. + if impl == "monolithic": + # Strip modular-only kwargs. The strict resolver above guarantees + # they're all None when we land here, but the monolithic standalone + # signature predates these kwargs and would TypeError on them. + kwargs = {k: v for k, v in kwargs.items() if k not in MODULAR_ONLY_KWARGS} + from .monolithic.mla_decode import cute_dsl_mla_decode as _impl + else: + from .wrappers.batch_mla import cute_dsl_mla_decode as _impl + return _impl(*args, **kwargs) diff --git a/flashinfer/cute_dsl/attention/monolithic/__init__.py b/flashinfer/cute_dsl/attention/monolithic/__init__.py new file mode 100644 index 0000000000..df53e16a6c --- /dev/null +++ b/flashinfer/cute_dsl/attention/monolithic/__init__.py @@ -0,0 +1,41 @@ +# Copyright (c) 2026 by FlashInfer team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Monolithic CuTe DSL MLA decode kernels for Blackwell SM100. + +This subpackage hosts the original single-file MLA decode kernels +(``BlackwellMultiHeadLatentAttentionForwardFP16`` / ``…FP8``) that were +introduced in #2743 and #2901. They were removed in #2805 in favor of the +modular kernel under ``flashinfer.cute_dsl.attention`` and restored here as +an alternate implementation under the same ``backend="cute-dsl"`` user +surface. + +Selection between the modular and monolithic implementations is handled by +``flashinfer.cute_dsl.attention.cute_dsl_mla_decode``; this module is not +intended to be imported directly by users. To force the monolithic path, +pass ``cute_dsl_impl="monolithic"`` to the public API call. +""" + +from flashinfer.cute_dsl.utils import is_cute_dsl_available + +if is_cute_dsl_available(): + from .mla_decode import cute_dsl_mla_decode + +__all__ = [ + "is_cute_dsl_available", +] + +if is_cute_dsl_available(): + __all__ += [ + "cute_dsl_mla_decode", + ] diff --git a/flashinfer/cute_dsl/attention/monolithic/mla_decode.py b/flashinfer/cute_dsl/attention/monolithic/mla_decode.py new file mode 100644 index 0000000000..08f4e616e8 --- /dev/null +++ b/flashinfer/cute_dsl/attention/monolithic/mla_decode.py @@ -0,0 +1,490 @@ +# Copyright (c) 2026 by FlashInfer team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +CuTe DSL MLA Decode Kernel Integration +======================================= + +Wraps NVIDIA's CuTe DSL MLA decode kernels (FP16/BF16/FP8) for Blackwell SM100 +and exposes them via a PyTorch API compatible with FlashInfer's MLA backend. +""" + +import functools +from typing import Callable, Optional, Tuple + +import cutlass +import cutlass.cute as cute +import torch +from cutlass import Float32, Int32 + +from flashinfer.utils import device_support_pdl + +from .mla_decode_fp16 import BlackwellMultiHeadLatentAttentionForwardFP16 +from .mla_decode_fp8 import BlackwellMultiHeadLatentAttentionForwardFP8 +from flashinfer.cute_dsl.utils import ( + get_max_active_clusters, + get_num_sm, + torch_to_cutlass_dtype, +) + + +@functools.cache +def _get_split_kv_and_workspace_size( + B: int, + q_len: int, + H: int, + kv_lora_rank: int, + max_active_blocks: int, +) -> Tuple[int, int]: + """Cache split_kv and workspace_size since they are deterministic for the same params.""" + split_kv = BlackwellMultiHeadLatentAttentionForwardFP16.get_split_kv_simplified( + B, q_len, max_active_blocks + ) + workspace_size = BlackwellMultiHeadLatentAttentionForwardFP16.get_workspace_size( + H, q_len, kv_lora_rank, B, split_kv, cutlass.Float32 + ) + return split_kv, workspace_size + + +@functools.cache +def _check_can_implement( + torch_dtype: torch.dtype, + torch_out_dtype: torch.dtype, + page_size: int, + num_heads: int, + seq_len_q: int, + kv_lora_rank: int, + qk_rope_head_dim: int, + is_persistent: bool, + is_var_seq: bool, + is_var_split_kv: bool, +) -> None: + """Check if the kernel supports the given configuration (cached).""" + mma_qk_tiler_mn = (128, 128) + mma_pv_tiler_mn = (128, 256) + + is_fp8 = torch_dtype == torch.float8_e4m3fn + KernelClass = ( + BlackwellMultiHeadLatentAttentionForwardFP8 + if is_fp8 + else BlackwellMultiHeadLatentAttentionForwardFP16 + ) + cutlass_in_dtype = torch_to_cutlass_dtype(torch_dtype) + cutlass_out_dtype = torch_to_cutlass_dtype(torch_out_dtype) + if not KernelClass.can_implement( + 1, # B (runtime, use placeholder) + seq_len_q, + 1, # K (runtime, use placeholder) + num_heads, + kv_lora_rank, + qk_rope_head_dim, + cutlass_in_dtype, + cutlass_out_dtype, + cutlass.Float32, + cutlass.Float32, + mma_qk_tiler_mn, + mma_pv_tiler_mn, + is_persistent, + is_var_seq, + is_var_split_kv, + page_size, + ): + raise ValueError( + f"cute_dsl_mla_decode: unsupported configuration " + f"(q_len={seq_len_q}, num_heads={num_heads}, page_size={page_size}, " + f"in_dtype={torch_dtype}, out_dtype={torch_out_dtype})" + ) + + +@functools.cache +def _get_compiled_mla_kernel( + torch_dtype: torch.dtype, + torch_out_dtype: torch.dtype, + page_size: int, + kv_lora_rank: int, + qk_rope_head_dim: int, + is_persistent: bool, + is_var_seq: bool, + is_var_split_kv: bool, + skip_correction_threshold: float = 0.0, + is_workspace_size_zero: bool = False, + enable_pdl: bool = False, +) -> Callable: + """Compile and cache an MLA decode kernel. + + Returns a callable that accepts (q_latent, q_rope, c_latent, c_rope, + page_table, o, lse, workspace, split_kv_scalar, cache_seqs, + block_split_kvs, softmax_scale_scalar, output_scale_scalar). + + All scalar arguments must be pre-wrapped as Int32/Float32. + """ + # Tile sizes for Blackwell mma. + # (128, 128) for QK and (128, 256) for PV. + mma_qk_tiler_mn = (128, 128) + mma_pv_tiler_mn = (128, 256) + # 2 CTAs along M (num_heads) + cluster_shape_mnk = (2, 1, 1) + + is_fp8 = torch_dtype == torch.float8_e4m3fn + KernelClass = ( + BlackwellMultiHeadLatentAttentionForwardFP8 + if is_fp8 + else BlackwellMultiHeadLatentAttentionForwardFP16 + ) + cutlass_dtype = torch_to_cutlass_dtype(torch_dtype) + cutlass_out_dtype = torch_to_cutlass_dtype(torch_out_dtype) + + kernel_obj = KernelClass( + acc_dtype=cutlass.Float32, + lse_dtype=cutlass.Float32, + mma_qk_tiler_mn=mma_qk_tiler_mn, + mma_pv_tiler_mn=mma_pv_tiler_mn, + max_active_clusters=get_max_active_clusters( + cluster_shape_mnk[0] * cluster_shape_mnk[1] + ), + page_size=page_size, + skip_correction_threshold=skip_correction_threshold, + is_persistent=is_persistent, + is_var_seq=is_var_seq, + is_var_split_kv=is_var_split_kv, + enable_pdl=enable_pdl, + ) + + # All dimensions as sym_int — this matches the original kernel's use of + # mark_compact_shape_dynamic, which makes ALL shapes dynamic CuTe Integers. + # Static Python ints would cause cute.assume() to fail with AttributeError + # inside initialize_workspace() since it expects DSL Integer types. + sym_heads = cute.sym_int() + sym_latent = cute.sym_int(divisibility=16) + sym_seq_q = cute.sym_int() + sym_rope = cute.sym_int(divisibility=16) + sym_batch = cute.sym_int() # query/output batch dimension + sym_kv_batch = cute.sym_int() # KV cache batch dim (flat pool, =1 in paged mode) + sym_seq_kv = cute.sym_int() + sym_page_count = cute.sym_int() + sym_workspace_size = cute.sym_int() + + # q_latent, q_rope, c_latent, c_rope are slices of contiguous tensors on + # the last dim (e.g. query[..., :kv_lora_rank]), so they are NOT contiguous: + # stride[-2] = D_qk (original full last dim), not the sliced shape. + # Use make_fake_tensor with fully dynamic strides so the compiled kernel + # reads actual strides from the runtime tensor. Last-dim stride is always 1. + + # q_latent: [batch_size, seq_len_q, num_heads, latent_dim] — non-contiguous slice + q_latent_fake = cute.runtime.make_fake_tensor( + cutlass_dtype, + (sym_batch, sym_seq_q, sym_heads, sym_latent), + stride=(cute.sym_int(), cute.sym_int(), cute.sym_int(), 1), + assumed_align=16, + ) + # q_rope: [batch_size, seq_len_q, num_heads, rope_dim] — non-contiguous slice + q_rope_fake = cute.runtime.make_fake_tensor( + cutlass_dtype, + (sym_batch, sym_seq_q, sym_heads, sym_rope), + stride=(cute.sym_int(), cute.sym_int(), cute.sym_int(), 1), + assumed_align=16, + ) + # c_latent: [kv_batch, seq_len_k, latent_dim] — non-contiguous slice + # kv_batch is a separate sym_int from query batch: paged KV cache uses a flat + # pool so kv_batch=num_pages at runtime, while query batch can be any value. + c_latent_fake = cute.runtime.make_fake_tensor( + cutlass_dtype, + (sym_kv_batch, sym_seq_kv, sym_latent), + stride=(cute.sym_int(), cute.sym_int(), 1), + assumed_align=16, + ) + # c_rope: [kv_batch, seq_len_k, rope_dim] — non-contiguous slice + c_rope_fake = cute.runtime.make_fake_tensor( + cutlass_dtype, + (sym_kv_batch, sym_seq_kv, sym_rope), + stride=(cute.sym_int(), cute.sym_int(), 1), + assumed_align=16, + ) + # page_table: [batch_size, page_count] — contiguous + page_table_fake = cute.runtime.make_fake_compact_tensor( + cutlass.Int32, + (sym_batch, sym_page_count), + stride_order=(1, 0), + assumed_align=16, + ) + # o: [batch_size, seq_len_q, num_heads, latent_dim] — contiguous + o_fake = cute.runtime.make_fake_compact_tensor( + cutlass_out_dtype, + (sym_batch, sym_seq_q, sym_heads, sym_latent), + stride_order=(3, 2, 1, 0), + assumed_align=16, + ) + # lse: [batch_size, seq_len_q, num_heads] — contiguous + lse_fake = cute.runtime.make_fake_compact_tensor( + cutlass.Float32, + (sym_batch, sym_seq_q, sym_heads), + stride_order=(2, 1, 0), + assumed_align=16, + ) + if is_workspace_size_zero: + workspace_fake = None + else: + # workspace: 1-D int8 buffer. 32-byte alignment because workspace stores + # fp32 partial sums internally, requiring stricter alignment than tensors. + workspace_fake = cute.runtime.make_fake_compact_tensor( + cutlass.Int8, + (sym_workspace_size,), + assumed_align=32, + ) + # cache_seqs: [batch_size] — int32 + cache_seqs_fake = cute.runtime.make_fake_compact_tensor( + cutlass.Int32, + (sym_batch,), + assumed_align=16, + ) + # block_split_kvs: [batch_size] — int32 (only needed for is_var_split_kv=True) + if is_var_split_kv: + block_split_kvs_fake = cute.runtime.make_fake_compact_tensor( + cutlass.Int32, + (sym_batch,), + assumed_align=16, + ) + else: + block_split_kvs_fake = None + + stream_fake = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) + + compiled_kernel = cute.compile( + kernel_obj, + q_latent_fake, + q_rope_fake, + c_latent_fake, + c_rope_fake, + page_table_fake, + o_fake, + lse_fake, + workspace_fake, + Int32(1), # split_kv placeholder + cache_seqs_fake, + block_split_kvs_fake, + Float32(1.0), # softmax_scale placeholder + Float32(1.0), # output_scale placeholder + stream_fake, + options="--enable-tvm-ffi --opt-level 2", + ) + + return compiled_kernel + + +# TODO: query[..., :kv_lora_rank], do we need to remove such kind of slice and move the logic to call routine in the kernel file. +def cute_dsl_mla_decode( + query: torch.Tensor, + kv_cache: torch.Tensor, + workspace_buffer: torch.Tensor, + kv_lora_rank: int, + qk_rope_head_dim: int, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + max_seq_len: int, + softmax_scale: float, + output_scale: float = 1.0, + out: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + is_var_seq: bool = True, + enable_pdl: Optional[bool] = None, +) -> torch.Tensor: + """CuTe DSL MLA decode kernel for Blackwell SM100. + + Parameters + ---------- + query : torch.Tensor + [B, q_len, H, D_qk] where D_qk = kv_lora_rank + qk_rope_head_dim + kv_cache : torch.Tensor + [num_pages, page_size, D_ckv + D_kpe] (3D) or [num_pages, 1, page_size, D_ckv + D_kpe] (4D) + workspace_buffer : torch.Tensor + Pre-allocated workspace buffer (uint8). Required size depends on batch size + and split_kv (auto-computed from B, q_len, and number of SMs): + + - Formula: ``B * H * q_len * split_kv * (kv_lora_rank + 1) * 4`` bytes + (0 when split_kv == 1, which happens when B >= num_SMs / 2) + - Typical max: ~18 MB on a 148-SM GPU (e.g. B=4..8, H=128, D=512) + - Safe default: 128 MB covers all realistic configurations + kv_lora_rank : int + Latent dimension (e.g. 512). + qk_rope_head_dim : int + RoPE dimension (e.g. 64). + block_tables : torch.Tensor + [B, max_pages] — page table indices. + seq_lens : torch.Tensor + [B] — per-request KV sequence lengths. + max_seq_len : int + Maximum sequence length across the batch. + softmax_scale : float + Scale factor for QK^T before softmax. + output_scale : float + Scale factor applied to the output. + out : Optional[torch.Tensor] + Pre-allocated output tensor [B, q_len, H, kv_lora_rank]. + out_dtype : Optional[torch.dtype] + Output data type. If None, defaults to torch.bfloat16 (matching trtllm-gen backend). + Supported values: torch.bfloat16, torch.float8_e4m3fn (FP8 input only), + torch.float16, torch.bfloat16 (FP16/BF16 input). + is_var_seq : bool + Whether the sequence length is variable. + If True, the sequence length is variable. + Otherwise,the sequence length is fixed for all the requests in the batch. + enable_pdl : Optional[bool], default=None + Whether to enable Programmatic Dependent Launch (PDL). + If None, auto-detects based on device capability. + + Returns + ------- + torch.Tensor + Output tensor [B, q_len, H, kv_lora_rank]. + """ + supported_dtypes = {torch.float16, torch.bfloat16, torch.float8_e4m3fn} + assert query.dtype in supported_dtypes, ( + f"cute_dsl_mla_decode only supports {supported_dtypes}, got {query.dtype}" + ) + assert kv_cache.dtype == query.dtype, ( + f"kv_cache dtype {kv_cache.dtype} must match query dtype {query.dtype}" + ) + B, q_len, H, D_qk = query.shape + assert D_qk == kv_lora_rank + qk_rope_head_dim + + q_dtype = query.dtype + # Resolve output dtype: for FP8 input, default to bfloat16 (matching trtllm-gen backend); + # for FP16/BF16 input, default to same as input. Allow override via out_dtype or out tensor. + if out is not None: + o_dtype = out.dtype + elif out_dtype is not None: + o_dtype = out_dtype + elif q_dtype == torch.float8_e4m3fn: + o_dtype = torch.bfloat16 + else: + o_dtype = q_dtype + + # Handle 3D vs 4D kv_cache: normalize to 3D [num_pages, page_size, D_total] + if kv_cache.dim() == 4: + kv_cache = kv_cache.squeeze(1) + page_size = kv_cache.shape[1] + + # Split query into latent and rope components — keep contiguous [B, q_len, H, D]. + # The kernel's __call__ reinterprets to [H, D, q_len, B] via zero-cost make_tensor. + q_latent_k = query[..., :kv_lora_rank] + q_rope_k = query[..., kv_lora_rank:] + + # KV cache slices — keep contiguous [num_pages, page_size, D]. + # The kernel reinterprets to [page_size, D, num_pages] internally. + c_latent_k = kv_cache[:, :, :kv_lora_rank] + c_rope_k = kv_cache[:, :, kv_lora_rank:] + + # Page table: [B, max_pages]: passed directly, kernel reinterprets. + page_table_k = block_tables + + if max_seq_len <= 0: + raise ValueError(f"max_seq_len must be > 0, got {max_seq_len}") + + # Cached split_kv and workspace_size computation + max_active_blocks = get_num_sm(query.device) + split_kv, workspace_size = _get_split_kv_and_workspace_size( + B, q_len, H, kv_lora_rank, max_active_blocks + ) + + # Prepare workspace: slice of contiguous 1D buffer is already contiguous + assert workspace_buffer.dtype == torch.int8, ( + f"workspace_buffer must be torch.int8, got {workspace_buffer.dtype}" + ) + assert workspace_buffer.numel() >= workspace_size, ( + f"workspace_buffer too small: {workspace_buffer.numel()} bytes, " + f"need {workspace_size} bytes" + ) + is_workspace_size_zero = workspace_size == 0 + if is_workspace_size_zero: + workspace_bytes = None + else: + workspace_bytes = workspace_buffer[:workspace_size] + # Output buffer: contiguous [B, q_len, H, D]. + # Kernel reinterprets to [H, D, q_len, B] internally via zero-cost make_tensor. + if out is not None: + o_k = out + else: + o_k = torch.empty( + (B, q_len, H, kv_lora_rank), dtype=o_dtype, device=query.device + ) + + # LSE: contiguous [B, q_len, H]. Kernel reinterprets to [H, q_len, B]. + lse_k = torch.empty((B, q_len, H), dtype=torch.float32, device=query.device) + + # cache_seqs: per-batch sequence lengths (skip .to() if already int32) + cache_seqs = seq_lens if seq_lens.dtype == torch.int32 else seq_lens.to(torch.int32) + + is_var_split_kv = False + block_split_kvs = None + skip_correction_threshold = 0.0 + + # for fix-length, set is_persistent to True; otherwise, set to False. + is_persistent = not is_var_seq + + # Validate configuration (cached, negligible overhead after first call) + _check_can_implement( + torch_dtype=q_dtype, + torch_out_dtype=o_dtype, + page_size=page_size, + num_heads=H, + seq_len_q=q_len, + kv_lora_rank=kv_lora_rank, + qk_rope_head_dim=qk_rope_head_dim, + is_persistent=is_persistent, + is_var_seq=is_var_seq, + is_var_split_kv=is_var_split_kv, + ) + + enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl + + # Get compiled kernel (cached after first compile) + # Note: when is_workspace_size_zero is True, workspace_bytes is None and it will launch one kernel without workspace. + # Otherwise, workspace_bytes is not None and it will launch two kernels. + compiled_kernel = _get_compiled_mla_kernel( + torch_dtype=q_dtype, + torch_out_dtype=o_dtype, + page_size=page_size, + kv_lora_rank=kv_lora_rank, + qk_rope_head_dim=qk_rope_head_dim, + is_persistent=is_persistent, + is_var_seq=is_var_seq, + is_var_split_kv=is_var_split_kv, + skip_correction_threshold=skip_correction_threshold, + is_workspace_size_zero=is_workspace_size_zero, + enable_pdl=enable_pdl, + ) + + # Call the kernel + compiled_kernel( + q_latent_k, + q_rope_k, + c_latent_k, + c_rope_k, + page_table_k, + o_k, + lse_k, + workspace_bytes, + Int32(split_kv), + cache_seqs, + block_split_kvs, + Float32(softmax_scale), + Float32(output_scale), + ) + + # If out was provided, kernel already wrote into it — return directly. + if out is not None: + return out + + # o_k is [B, q_len, H, D] — return as-is to match trtllm-gen output shape. + return o_k diff --git a/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp16.py b/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp16.py new file mode 100644 index 0000000000..aa3b2cd475 --- /dev/null +++ b/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp16.py @@ -0,0 +1,4259 @@ +# Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +import math +from typing import Type, Tuple, Optional +from types import SimpleNamespace + +import torch +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +import cutlass.cute.testing as testing +import cutlass.cute.nvgpu.tcgen05 as tcgen05 + +# TODO: Remove this hook helper function after nvidia-cutlass-dsl 4.3.x is no longer supported. +# Compat shim: setmaxregister_{decrease,increase} added in cutlass-dsl 4.4; +# older versions only have the deprecated warpgroup_reg_{dealloc,alloc}. +_setmaxregister_decrease = getattr( + cute.arch, + "setmaxregister_decrease", + getattr(cute.arch, "warpgroup_reg_dealloc", None), +) +_setmaxregister_increase = getattr( + cute.arch, + "setmaxregister_increase", + getattr(cute.arch, "warpgroup_reg_alloc", None), +) + +# Compat shim: get_max_tmem_alloc_cols added in cutlass-dsl 4.4; +# older versions don't have it, so we provide a fallback implementation. +_TMEM_MAX_ALLOC_COLUMNS_MAP = {"sm_100": 512, "sm_103": 512, "sm_120": 512} + + +# TODO: Remove this hook helper function after nvidia-cutlass-dsl 4.3.x is no longer supported. +def _get_max_tmem_alloc_cols(compute_capability: str) -> int: + if hasattr(cute.arch, "get_max_tmem_alloc_cols"): + return cute.arch.get_max_tmem_alloc_cols(compute_capability) + if compute_capability not in _TMEM_MAX_ALLOC_COLUMNS_MAP: + raise ValueError(f"Unsupported compute capability: {compute_capability}") + return _TMEM_MAX_ALLOC_COLUMNS_MAP[compute_capability] + + +from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode +import cutlass.cute.nvgpu.cpasync as cpasync +import cutlass.utils as utils +import cutlass.pipeline as pipeline +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait +import cutlass.torch as cutlass_torch +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cute.runtime import from_dlpack +from cutlass.base_dsl.arch import Arch +from cutlass.cutlass_dsl import BaseDSL + + +from .mla_helpers import ( + ceil_div, + MAX_SPLITS, + LOG2_E, + MLAStaticTileScheduler, + MLAStaticTileSchedulerParams, + create_mla_static_tile_scheduler, + create_mla_static_tile_scheduler_params, +) + +""" +A Multi-Head Latent Attention (MLA) example with FP16 data type for the NVIDIA Blackwell SM100 architecture using CUTE DSL + +This example demonstrates an implementation of inference of multi-head latent attention using a TMA + Blackwell +SM100 TensorCore warp-specialized persistent kernel. The implementation integrates the (Qc + Qr)*(Kc + Kr)^T +matrix multiplication, softmax normalization, and softmax((Qc + Qr)*(Kc + Kr)^T)*Vc into a single kernel. +The kernel provides support for page table storage and variable-length KV cache sequences. It implements KV splitting +functionality to minimize latency when processing long KV sequences. + +The kernel implements key optimizations including: +- Warp specialization for different computation phases (load, MMA, softmax, correction, epilogue) +- Pipeline stages between different warps for overlapping computation and memory access +- Support for different precision data types +- Two sub-kernels (split KV kernel and reduction kernel) that enable split KV processing + +To run this example: + +.. code-block:: bash + + python examples/blackwell/mla_fp16.py \ + --batch_size 4 --latent_dim 512 --rope_dim 64 \ + --num_heads 128 --seq_len_q 1 --seq_len_k 1024 \ + --in_dtype Float16 --out_dtype Float16 \ + --acc_dtype Float32 --lse_dtype Float32 \ + --is_var_seq --is_var_split_kv \ + --is_persistent + +The above example runs Multi-Head Latent Attention (MLA) with the following configuration: +- Batch size: 4 +- Sequence length of Q: 1 +- Sequence length of K: 1024 +- Latent dimension: 512 +- RoPE dimension: 64 +- Number of heads: 128 +- Data types: Float16 (input), Float16 (output), Float32 (accumulation and LSE) + +It utilizes page table storage for the KV cache and enables both variable-length KV cache sequences +and variable split KV processing with persistent scheduling. + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/blackwell/mla_fp16.py \ + --batch_size 4 --latent_dim 512 --rope_dim 64 \ + --num_heads 128 --seq_len_q 1 --seq_len_k 1024 \ + --in_dtype Float16 --out_dtype Float16 \ + --acc_dtype Float32 --lse_dtype Float32 \ + --is_var_seq --is_var_split_kv \ + --is_persistent --warmup_iterations 3 \ + --iterations 10 --skip_ref_check + +Constraints for this example: +* Data type requirements: + - Input/output: Float16 + - Accumulation and LSE: Float32 +* Fixed architecture parameters: + - Number of attention heads: 128 + - Latent dimension: 512 + - RoPE dimension: 64 +* Input query modes should be (NumHeads, LatentDim/RopeDim, SeqLenQ, BatchSize) +* Input kv latent/rope modes should be (SeqLenK, LatentDim/RopeDim, BatchSize) +* Query sequence length must be 1-4 +* Only supports 2-CTA instructions +* Variable sequence length requires page table storage enabled +""" + + +class BlackwellMultiHeadLatentAttentionForwardFP16: + def __init__( + self, + acc_dtype: Type[cutlass.Numeric], + lse_dtype: Type[cutlass.Numeric], + mma_qk_tiler_mn: Tuple[int, int], + mma_pv_tiler_mn: Tuple[int, int], + max_active_clusters: int, + page_size: int, + skip_correction_threshold: float, + is_persistent: bool, + is_var_seq: bool, + is_var_split_kv: bool, + enable_pdl: bool, + ): + """Initializes the configuration for a Blackwell Multi-Head Latent Attention (MLA) kernel. + + :param acc_dtype: Data type for accumulation S and O + :type acc_dtype: Type[cutlass.Numeric] + :param lse_dtype: Data type for output LSE + :type lse_dtype: Type[cutlass.Numeric] + :param mma_s_tiler: The (H, K) tile shape of the MMA instruction for S + :type mma_s_tiler: Tuple[int, int] + :param mma_p_tiler: The (H, D) tile shape of the MMA instruction for P + :type mma_p_tiler: Tuple[int, int] + :param max_active_clusters: Maximum number of active clusters + :type max_active_clusters: int + :param page_size: The page size of the page table + :type page_size: int + :param skip_correction_threshold: Threshold to skip correction + :type skip_correction_threshold: float + :param is_persistent: Whether to use persistent kernel mode + :type is_persistent: bool + :param is_var_seq: Whether to use variable sequence length + :type is_var_seq: bool + :param is_var_split_kv: Whether to use variable split KV + :type is_var_split_kv: bool + :param enable_pdl: Whether to use PDL + :type enable_pdl: bool + """ + + self.latent_dim = 512 + self.rope_dim = 64 + self.acc_dtype = acc_dtype + self.lse_dtype = lse_dtype + self.mma_qk_tiler_mn = mma_qk_tiler_mn + self.mma_pv_tiler_mn = mma_pv_tiler_mn + self.max_active_clusters = max_active_clusters + self.skip_correction_threshold = skip_correction_threshold + self.is_persistent = is_persistent + self.page_size = page_size + self.is_var_seq = is_var_seq + self.is_var_split_kv = is_var_split_kv + self.enable_pdl = enable_pdl + self.cluster_shape_mnk = (2, 1, 1) + self.use_2cta_instrs = True + # When using 2 CTAs with m=128: warps 0-1 handle accumulation for first half [0, n/2), + # while warps 2-3 handle accumulation for second half [n/2, n) + self.warps_in_n = 2 + self.num_compute_warps = 4 + self.threads_per_warp = 32 + mma_qk_tiler_k = self.rope_dim + self.mma_qk_tiler = ( + self.mma_qk_tiler_mn[0], + self.mma_qk_tiler_mn[1], + mma_qk_tiler_k, + ) + self.mma_qk_rope_tiler = ( + self.mma_qk_tiler_mn[0], + self.mma_qk_tiler_mn[1], + self.rope_dim, + ) + self.mma_pv_tiler = ( + self.mma_pv_tiler_mn[0], + self.mma_pv_tiler_mn[1], + self.mma_qk_tiler[1] * self.mma_qk_tiler[2] // self.mma_pv_tiler_mn[1], + ) + self.iterations_qk_latent = self.latent_dim // self.mma_qk_tiler[2] + self.iterations_qk_rope = mma_qk_tiler_k // self.mma_qk_tiler[2] + self.iterations_qk = self.iterations_qk_latent + self.iterations_qk_rope + self.iterations_pv_k = self.mma_qk_tiler[1] // self.mma_pv_tiler[2] + self.iterations_pv_n = self.latent_dim // self.mma_pv_tiler[1] + + # Set specialized warp ids + self.compute_warp_ids = (0, 1, 2, 3) + self.correction_warp_ids = (4, 5, 6, 7) + self.mma_warp_id = 8 + + self.load_tma_warp_id = 9 + self.load_pt_warp_id = 10 + self.empty_warp_ids = (11,) + self.threads_per_cta = self.threads_per_warp * len( + ( + self.mma_warp_id, + self.load_tma_warp_id, + self.load_pt_warp_id, + *self.compute_warp_ids, + *self.correction_warp_ids, + *self.empty_warp_ids, + ) + ) + + # register settings + self.softmax_reg_num = 192 + self.correction_reg_num = 208 + self.other_reg_num = 96 + # Named barriers + self.tmem_ptr_sync_bar = pipeline.NamedBarrier( + barrier_id=1, + num_threads=( + self.threads_per_warp + + self.threads_per_warp * self.num_compute_warps * 2 + ), + ) + self.softmax_exchange_sync_bar = pipeline.NamedBarrier( + barrier_id=2, num_threads=(self.threads_per_warp * self.num_compute_warps) + ) + self.epilogue_exchange_sync_bar = pipeline.NamedBarrier( + barrier_id=3, num_threads=(self.threads_per_warp * self.num_compute_warps) + ) + + def _setup_attributes(self): + """Set up configurations and parameters for the MLA kernel operation. + + This method initializes and configures various attributes required for the + execution of the multi-head latent attention kernel, mainly about the pipeline stages: + + - Sets up staging parameters for Q, K, V inputs and accumulator data + - Configures pipeline stages for softmax, correction, and epilogue operations + """ + + self.load_q_stage = 1 + self.load_kv_stage = 15 + self.mma_s_stage = 2 + self.p_mma_stage = 2 + self.p_cor_stage = 2 + self.mma_o_stage = 1 + self.load_pt_stage = 4 + + self.tmem_o_offset = self.mma_s_stage * self.mma_qk_tiler[1] // self.warps_in_n + self.correction_factor_offset = ( + self.tmem_o_offset + self.latent_dim // self.warps_in_n + ) + + @cute.jit + def __call__( + self, + q_latent: cute.Tensor, + q_rope: cute.Tensor, + c_latent: cute.Tensor, + c_rope: cute.Tensor, + page_table: cute.Tensor, + o: cute.Tensor, + lse: cute.Tensor, + workspace: cute.Tensor, + split_kv: cutlass.Int32, + cache_seqs: Optional[cute.Tensor], + block_split_kvs: Optional[cute.Tensor], + softmax_scale: cutlass.Float32, + output_scale: cutlass.Float32, + stream: cuda.CUstream, + ): + """Execute the Multi-Head Latent Attention operation on the provided tensors. + + The method handles: + 1. Initialization of workspace for temporary split KV buffers + 2. Validation of tensor data types + 3. Initialization of hardware-specific parameters and memory layouts + 4. Configuration of TMA (Tensor Memory Access) operations + 5. Grid and work scheduling computation + 6. Kernel launch(split KV kernel and reduction kernel) with appropriate parameters + + :param q_latent: The query tensor with shape [batch_size, seq_len_q, num_head, latent_dim] (contiguous) + :type q_latent: cute.Tensor + :param q_rope: The query RoPE tensor with shape [batch_size, seq_len_q, num_head, rope_dim] (contiguous) + :type q_rope: cute.Tensor + :param c_latent: The key tensor with shape [num_pages, page_size, latent_dim] (contiguous) + :type c_latent: cute.Tensor + :param c_rope: The key RoPE tensor with shape [num_pages, page_size, rope_dim] (contiguous) + :type c_rope: cute.Tensor + :param page_table: The page table tensor with shape [batch_size, page_count] (contiguous) + :type page_table: cute.Tensor + :param o: The output tensor with shape [batch_size, seq_len_q, num_head, latent_dim] (contiguous) + :type o: cute.Tensor + :param lse: The LSE tensor with shape [batch_size, seq_len_q, num_head] (contiguous) + :type lse: cute.Tensor + :param workspace: The workspace tensor with 1-d shape prepared for acc_o and acc_lse + :type workspace: cute.Tensor + :param split_kv: The scalar factor for split KV + :type split_kv: cutlass.Int32 + :param cache_seqs: The cache sequences tensor with shape [batch_size] + :type cache_seqs: cute.Tensor + :param block_split_kvs: The block split KV tensor with shape [batch_size] + :type block_split_kvs: cute.Tensor + :param softmax_scale: The scale factor for softmax + :type softmax_scale: cutlass.Float32 + :param output_scale: The scale factor for the output + :type output_scale: cutlass.Float32 + :param stream: The CUDA stream to execute the kernel on + :type stream: cuda.CUstream + + :raises TypeError: If tensor data types don't match or aren't supported + """ + + # setup static attributes before smem/grid/tma computation + self.q_dtype = q_latent.element_type + self.k_dtype = c_latent.element_type + self.v_dtype = c_latent.element_type + self.o_dtype = o.element_type + + # check type consistency + if cutlass.const_expr( + self.q_dtype != self.k_dtype or self.q_dtype != self.v_dtype + ): + raise TypeError( + f"Type mismatch: {self.q_dtype} != {self.k_dtype} or {self.q_dtype} != {self.v_dtype}" + ) + + # Reinterpret contiguous [B, S_q, H, D] as [H, D, S_q, B] + # Input stride: (S_q*H*D, H*D, D, 1) → Target: (D, 1, H*D, S_q*H*D) + def _reinterpret_4d(t): + return cute.make_tensor( + t.iterator, + cute.make_layout( + (t.shape[2], t.shape[3], t.shape[1], t.shape[0]), + stride=(t.stride[2], t.stride[3], t.stride[1], t.stride[0]), + ), + ) + + q_latent = _reinterpret_4d(q_latent) + q_rope = _reinterpret_4d(q_rope) + o = _reinterpret_4d(o) + + # Reinterpret contiguous [num_pages, page_size, D] as [page_size, D, num_pages] + # Input stride: (PS*D, D, 1) → Target: (D, 1, PS*D) + def _reinterpret_3d_kv(t): + return cute.make_tensor( + t.iterator, + cute.make_layout( + (t.shape[1], t.shape[2], t.shape[0]), + stride=(t.stride[1], t.stride[2], t.stride[0]), + ), + ) + + c_latent = _reinterpret_3d_kv(c_latent) + c_rope = _reinterpret_3d_kv(c_rope) + + # Reinterpret contiguous [B, page_count] as [page_count, B] + page_table = cute.make_tensor( + page_table.iterator, + cute.make_layout( + (page_table.shape[1], page_table.shape[0]), + stride=(page_table.stride[1], page_table.stride[0]), + ), + ) + + # Reinterpret contiguous [B, S_q, H] as [H, S_q, B] + # Input stride: (S_q*H, H, 1) → Target: (1, H, S_q*H) + lse = cute.make_tensor( + lse.iterator, + cute.make_layout( + (lse.shape[2], lse.shape[1], lse.shape[0]), + stride=(lse.stride[2], lse.stride[1], lse.stride[0]), + ), + ) + + acc_o, acc_lse = self.initialize_workspace( + q_latent.shape[0], + q_latent.shape[1], + q_latent.shape[2], + q_latent.shape[3], + split_kv, + self.acc_dtype, + workspace, + ) + + c_latent_tranpose_layout = cute.select(c_latent.layout, mode=[1, 0, 2]) + c_latent_transpose = cute.make_tensor( + c_latent.iterator, c_latent_tranpose_layout + ) + + self.q_major_mode = tcgen05.OperandMajorMode.K + self.k_major_mode = tcgen05.OperandMajorMode.K + self.v_major_mode = tcgen05.OperandMajorMode.MN + + self._setup_attributes() + + cta_group = tcgen05.CtaGroup.TWO + # the intermediate tensor p is from smem & k-major + p_major_mode = tcgen05.OperandMajorMode.K + qk_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.q_dtype, + self.q_major_mode, + self.k_major_mode, + self.acc_dtype, + cta_group, + self.mma_qk_tiler[:2], + ) + pv_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.v_dtype, + p_major_mode, + self.v_major_mode, + self.acc_dtype, + cta_group, + self.mma_pv_tiler[:2], + ) + + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), + (qk_tiled_mma.thr_id.shape,), + ) + + self.epi_tile = self.mma_pv_tiler[:2] + + q_latent_smem_layout_staged = sm100_utils.make_smem_layout_a( + qk_tiled_mma, + self.mma_qk_tiler, + self.q_dtype, + (self.iterations_qk_latent * self.load_q_stage), + ) + q_latent_smem_layout_staged = cute.logical_divide( + q_latent_smem_layout_staged, (None, None, None, self.iterations_qk_latent) + ) + q_rope_smem_layout_staged = sm100_utils.make_smem_layout_a( + qk_tiled_mma, + self.mma_qk_rope_tiler, + self.q_dtype, + self.load_q_stage, + ) + + # rope reuse the same smem layout as latent + kc_smem_layout_staged = sm100_utils.make_smem_layout_b( + qk_tiled_mma, + self.mma_qk_tiler, + self.k_dtype, + self.load_kv_stage, + ) + kc_page_tile_size = min( + self.page_size, qk_tiled_mma.op.shape_mnk[0] // qk_tiled_mma.thr_id.shape + ) + + kc_smem_layout_for_tma = sm100_utils.make_smem_layout( + OperandMajorMode.K, + (self.mma_qk_tiler[0] // qk_tiled_mma.thr_id.shape, self.mma_qk_tiler[2]), + self.k_dtype, + self.load_kv_stage, + ) + kc_smem_layout_for_tma = cute.tiled_divide( + kc_smem_layout_for_tma, (kc_page_tile_size, self.mma_qk_tiler[2]) + ) + + p_smem_layout_staged = sm100_utils.make_smem_layout_a( + pv_tiled_mma, + self.mma_pv_tiler, + self.q_dtype, + (self.iterations_pv_k * self.p_mma_stage), + ) + p_smem_layout_staged = cute.logical_divide( + p_smem_layout_staged, (None, None, None, self.iterations_pv_k) + ) + + vc_smem_layout_staged = sm100_utils.make_smem_layout_b( + pv_tiled_mma, + self.mma_pv_tiler, + self.v_dtype, + self.load_kv_stage, + ) + vc_page_tile_size = min(self.page_size, self.mma_pv_tiler[2]) + vc_smem_layout_for_tma = sm100_utils.make_smem_layout( + OperandMajorMode.MN, + (self.mma_pv_tiler[1] // pv_tiled_mma.thr_id.shape, self.mma_pv_tiler[2]), + self.v_dtype, + self.load_kv_stage, + ) + vc_smem_layout_for_tma = cute.tiled_divide( + vc_smem_layout_for_tma, + ( + pv_tiled_mma.op.shape_mnk[1] // pv_tiled_mma.thr_id.shape, + vc_page_tile_size, + ), + ) + # TMA load for Q latent and rope + tma_load_op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp(cta_group) + + q_latent_smem_layout = cute.select(q_latent_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_q_latent, tma_tensor_q_latent = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q_latent, + q_latent_smem_layout, + self.mma_qk_tiler, + qk_tiled_mma, + cta_layout_vmnk.shape, + ) + q_rope_smem_layout = cute.select(q_rope_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_q_rope, tma_tensor_q_rope = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q_rope, + q_rope_smem_layout, + self.mma_qk_rope_tiler, + qk_tiled_mma, + cta_layout_vmnk.shape, + ) + # TMA load for c latent and k rope + kc_smem_layout = cute.select(kc_smem_layout_for_tma, mode=[0]) + tma_atom_c_latent, tma_tensor_c_latent = self.make_paged_tiled_tma_atom( + tma_load_op, + c_latent, + kc_smem_layout, + (self.mma_qk_tiler[1], self.mma_qk_tiler[2]), + qk_tiled_mma, + is_k_load=True, + ) + tma_atom_c_rope, tma_tensor_c_rope = self.make_paged_tiled_tma_atom( + tma_load_op, + c_rope, + kc_smem_layout, + (self.mma_qk_tiler[1], self.mma_qk_tiler[2]), + qk_tiled_mma, + is_k_load=True, + ) + # TMA load for c latent transpose + vc_smem_layout = cute.select(vc_smem_layout_for_tma, mode=[0]) + tma_atom_c_latent_transpose, tma_tensor_c_latent_transpose = ( + self.make_paged_tiled_tma_atom( + tma_load_op, + c_latent_transpose, + vc_smem_layout, + (self.mma_pv_tiler[1], self.mma_pv_tiler[2]), + pv_tiled_mma, + is_k_load=False, + ) + ) + + q_latent_copy_size = ( + cute.size_in_bytes(self.q_dtype, q_latent_smem_layout) + * cute.size(qk_tiled_mma.thr_id.shape) + * self.iterations_qk_latent + ) + q_rope_copy_size = ( + cute.size_in_bytes(self.q_dtype, q_rope_smem_layout) + * cute.size(qk_tiled_mma.thr_id.shape) + * self.iterations_qk_rope + ) + q_copy_size = q_latent_copy_size + q_rope_copy_size + kc_copy_size = cute.size_in_bytes( + self.k_dtype, cute.select(kc_smem_layout_staged, mode=[0, 1, 2]) + ) * cute.size(qk_tiled_mma.thr_id.shape) + vc_copy_size = cute.size_in_bytes( + self.v_dtype, cute.select(vc_smem_layout_staged, mode=[0, 1, 2]) + ) * cute.size(pv_tiled_mma.thr_id.shape) + assert kc_copy_size == vc_copy_size, ( + "kc_copy_size and vc_copy_size must be the same" + ) + + self.tma_copy_q_bytes = q_copy_size + self.tma_copy_kc_bytes = kc_copy_size + + tile_sched_params, grid = self._compute_grid( + o, + split_kv, + self.cluster_shape_mnk, + self.max_active_clusters, + self.is_persistent, + ) + + @cute.struct + class SplitKVKernelSharedStorage: + # Pipeline barriers + load_q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.load_q_stage * 2] + load_kv_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.load_kv_stage * 2 + ] + mma_s_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.mma_s_stage * 2] + p_mma_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.p_mma_stage * 2] + p_cor_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.p_cor_stage * 2] + mma_o_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.mma_o_stage * 2] + load_pt_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.load_pt_stage * 2 + ] + # Tmem dealloc cluster barrier + tmem_dealloc_mbar_ptr: cutlass.Int64 + + # Tmem holding buffer + tmem_holding_buf: cutlass.Int32 + # Smem tensors + softmax_smem_exchange: cute.struct.MemRange[ + self.acc_dtype, self.num_compute_warps * self.threads_per_warp + ] + epilogue_smem_exchange: cute.struct.MemRange[ + self.acc_dtype, self.num_compute_warps * self.threads_per_warp + ] + smem_q_latent: cute.struct.Align[ + cute.struct.MemRange[ + self.q_dtype, cute.cosize(q_latent_smem_layout_staged) + ], + 1024, + ] + smem_q_rope: cute.struct.Align[ + cute.struct.MemRange[ + self.q_dtype, cute.cosize(q_rope_smem_layout_staged) + ], + 1024, + ] + smem_kc: cute.struct.Align[ + cute.struct.MemRange[self.k_dtype, cute.cosize(kc_smem_layout_staged)], + 1024, + ] + smem_p: cute.struct.Align[ + cute.struct.MemRange[self.q_dtype, cute.cosize(p_smem_layout_staged)], + 1024, + ] + smem_page_table: cute.struct.MemRange[ + cutlass.Int32, self.load_pt_stage * self.mma_qk_tiler[1] // 2 + ] + + softmax_scale_log2 = softmax_scale * LOG2_E + self.split_kv_kernel( + qk_tiled_mma, + pv_tiled_mma, + tma_atom_q_latent, + tma_tensor_q_latent, + tma_atom_q_rope, + tma_tensor_q_rope, + tma_atom_c_latent, + tma_tensor_c_latent, + tma_atom_c_rope, + tma_tensor_c_rope, + tma_atom_c_latent_transpose, + tma_tensor_c_latent_transpose, + page_table, + o, + lse, + acc_o, + acc_lse, + split_kv, + cache_seqs, + block_split_kvs, + softmax_scale_log2, + output_scale, + q_latent_smem_layout_staged, + q_rope_smem_layout_staged, + kc_smem_layout_staged, + p_smem_layout_staged, + vc_smem_layout_staged, + kc_smem_layout_for_tma, + vc_smem_layout_for_tma, + cta_layout_vmnk, + tile_sched_params, + SplitKVKernelSharedStorage, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=self.cluster_shape_mnk, + smem=SplitKVKernelSharedStorage.size_in_bytes(), # type: ignore[attr-defined] + stream=stream, + min_blocks_per_mp=1, + use_pdl=self.enable_pdl, + ) + if cutlass.const_expr(acc_o is not None): + self.reduction_kernel( + o, + lse, + acc_o, + acc_lse, + split_kv, + cache_seqs, + block_split_kvs, + ).launch( + grid=(q_latent.shape[0], q_latent.shape[2], q_latent.shape[3]), + block=[self.threads_per_warp * self.num_compute_warps, 1, 1], + smem=MAX_SPLITS * self.acc_dtype.width // 8, + stream=stream, + min_blocks_per_mp=1, + use_pdl=self.enable_pdl, + ) + + @cute.jit + def make_paged_tiled_tma_atom( + self, + tma_load_op: cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp, + gmem: cute.Tensor, + smem_layout: cute.Layout, + mma_tiler, + tiled_mma: cute.TiledMma, + is_k_load: bool, + ): + ident = cute.make_identity_layout(gmem.shape) + g_tile = cute.composition(ident, mma_tiler) + cta_mn = mma_tiler[0] // tiled_mma.thr_id.shape + cta_v_map = cute.flat_divide(g_tile, (cta_mn,)) + cta_v_map = cute.select(cta_v_map, mode=[0, 2]) + page_tile_size = ( + min(self.page_size, cta_mn) + if is_k_load + else min(self.page_size, mma_tiler[1]) + ) + cta_v_map = cute.zipped_divide( + cta_v_map, + (page_tile_size, mma_tiler[1]) if is_k_load else (cta_mn, page_tile_size), + ) + cta_v_map = cute.select(cta_v_map, mode=[0]) + from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir + + res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load( + gmem.value, + smem_layout.value, + cta_v_map, + tma_load_op._to_ir(), + num_multicast=1, + ) + return cute.CopyAtom( + tma_load_op, cpasync.CopyBulkTensorTileG2SNonExecTrait(res[0]) + ), res[1] + + @cute.kernel + def split_kv_kernel( + self, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + tma_atom_q_latent: Optional[cute.CopyAtom], + mQL: cute.Tensor, + tma_atom_q_rope: Optional[cute.CopyAtom], + mQR: cute.Tensor, + tma_atom_c_latent: Optional[cute.CopyAtom], + mCL: cute.Tensor, + tma_atom_c_rope: Optional[cute.CopyAtom], + mKR: cute.Tensor, + tma_atom_c_latent_transpose: Optional[cute.CopyAtom], + mCLT: cute.Tensor, + mPT: cute.Tensor, + mO: Optional[cute.Tensor], + mLSE: Optional[cute.Tensor], + mAccO: Optional[cute.Tensor], + mAccLSE: Optional[cute.Tensor], + split_kv: cutlass.Int32, + cache_seqs: cute.Tensor, + block_split_kvs: cute.Tensor, + softmax_scale_log2: cutlass.Float32, + output_scale: cutlass.Float32, + q_latent_smem_layout_staged: cute.ComposedLayout, + q_rope_smem_layout_staged: cute.ComposedLayout, + kc_smem_layout_staged: cute.ComposedLayout, + p_smem_layout_staged: cute.ComposedLayout, + vc_smem_layout_staged: cute.ComposedLayout, + kc_smem_layout_for_tma: cute.ComposedLayout, + vc_smem_layout_for_tma: cute.ComposedLayout, + cta_layout_vmnk: cute.Layout, + tile_sched_params: MLAStaticTileSchedulerParams, + SharedStorage: cutlass.Constexpr, + ): + """The device split_kv kernel implementation of the Multi-Head Latent Attention. + + This kernel coordinates multiple specialized warps to perform different phases of the MLA computation: + 1. Load warp: Loads Q/C latent/rope data from global memory to shared memory using TMA + 2. MMA warp: Performs matrix multiplications (Q*K^T and P*V) + 3. Compute warps: Compute softmax and do rescaling on accumulators, and store the intermediate/final results + to global memory + + The kernel produces either intermediate or final results of the MLA computation based on the split_kv parameter. + When split_kv is 1, the kernel generates the final results directly. Otherwise, it produces intermediate results + that will later be combined by a reduction kernel. + + The kernel implements a complex pipeline with overlapping computation and memory operations, + using tensor memory access (TMA) for efficient data loading, warp specialization for different + computation phases. + + :param tiled_mma_qk: Tiled MMA for Q*K^T + :type tiled_mma_qk: cute.TiledMma + :param tiled_mma_pv: Tiled MMA for P*V + :type tiled_mma_pv: cute.TiledMma + :param tma_atom_q_latent: TMA copy atom for query latent tensor + :type tma_atom_q_latent: cute.CopyAtom + :param mQL: query latent tensor + :type mQL: cute.Tensor + :param tma_atom_q_rope: TMA copy atom for query rope tensor + :type tma_atom_q_rope: cute.CopyAtom + :param mKR: Compressed rope tensor + :type mKR: cute.Tensor + :param tma_atom_c_latent: TMA copy atom for c latent tensor + :type tma_atom_c_latent: cute.CopyAtom + :param mCL: Compressed latent tensor + :type mCL: cute.Tensor + :param tma_atom_c_rope: TMA copy atom for c rope tensor + :type tma_atom_c_rope: cute.CopyAtom + :param mCLT: Compressed latent transpose tensor + :type mCLT: cute.Tensor + :param mPT: Page table tensor + :type mPT: cute.Tensor + :param mO: Output tensor + :type mO: cute.Tensor + :param mLSE: Log-sum-exp tensor + :type mLSE: cute.Tensor + :param mAccO: Intermediate accumulator output tensor + :type mAccO: cute.Tensor + :param mAccLSE: Intermediate accumulator log-sum-exp tensor + :type mAccLSE: cute.Tensor + :param split_kv: The split_kv parameter + :type split_kv: cutlass.Int32 + :param cache_seqs: The variable sequence length tensor + :type cache_seqs: cute.Tensor + :param block_split_kvs: The per-block split_kv values tensor + :type block_split_kvs: cute.Tensor + :param softmax_scale_log2: The log2 scale factor for softmax + :type softmax_scale_log2: cutlass.Float32 + :param output_scale: The scale factor for the output + :type output_scale: cutlass.Float32 + :param q_latent_smem_layout_staged: Shared memory layout for query latent tensor + :type q_latent_smem_layout_staged: cute.ComposedLayout + :param q_rope_smem_layout_staged: Shared memory layout for query rope tensor + :type q_rope_smem_layout_staged: cute.ComposedLayout + :param kc_smem_layout_staged: Shared memory layout for key/value latent/rope tensor + :type kc_smem_layout_staged: cute.ComposedLayout + :param p_smem_layout_staged: Shared memory layout for probability matrix + :type p_smem_layout_staged: cute.ComposedLayout + :param vc_smem_layout_staged: Shared memory layout for value tensor + :type vc_smem_layout_staged: cute.ComposedLayout + :param kc_smem_layout_for_tma: Shared memory layout for key/value latent tensor for TMA + :type kc_smem_layout_for_tma: cute.ComposedLayout + :param vc_smem_layout_for_tma: Shared memory layout for value tensor for TMA + :type vc_smem_layout_for_tma: cute.ComposedLayout + :param cta_layout_vmnk: Layout for compute threads + :type cta_layout_vmnk: cute.Layout + :param tile_sched_params: Scheduling parameters for work distribution + :type tile_sched_params: MLAStaticTileSchedulerParams + :param SharedStorage: Shared storage for the kernel + :type SharedStorage: cutlass.Constexpr + """ + + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma_qk.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + + # Prefetch tma descriptor + if warp_idx == self.mma_warp_id: + cpasync.prefetch_descriptor(tma_atom_q_latent) + cpasync.prefetch_descriptor(tma_atom_q_rope) + cpasync.prefetch_descriptor(tma_atom_c_latent) + cpasync.prefetch_descriptor(tma_atom_c_rope) + cpasync.prefetch_descriptor(tma_atom_c_latent_transpose) + + # Alloc + smem = utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + + # Tensor memory dealloc barrier init + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=self.tmem_ptr_sync_bar, + allocator_warp_id=self.mma_warp_id, + is_two_cta=self.use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + + load_q_pipeline = self.make_and_init_load_qkv_pipeline( + storage.load_q_mbar_ptr.data_ptr(), + cta_layout_vmnk, + self.load_q_stage, + self.tma_copy_q_bytes, + ) + load_kv_pipeline = self.make_and_init_load_qkv_pipeline( + storage.load_kv_mbar_ptr.data_ptr(), + cta_layout_vmnk, + self.load_kv_stage, + self.tma_copy_kc_bytes, + ) + mma_s_pipeline = self.make_and_init_mma_s_pipeline( + storage.mma_s_mbar_ptr.data_ptr(), cta_layout_vmnk + ) + p_mma_pipeline = self.make_and_init_p_mma_pipeline( + storage.p_mma_mbar_ptr.data_ptr(), cta_layout_vmnk + ) + p_cor_pipeline = self.make_and_init_p_cor_pipeline( + storage.p_cor_mbar_ptr.data_ptr() + ) + mma_o_pipeline = self.make_and_init_mma_o_pipeline( + storage.mma_o_mbar_ptr.data_ptr(), cta_layout_vmnk + ) + load_pt_pipeline = self.make_and_init_load_pt_pipeline( + storage.load_pt_mbar_ptr.data_ptr() + ) + + # Cluster arrive after barrier init + pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mnk, is_relaxed=True) + + # Generate smem tensor Q/KC/VC/exchange + # (MMA, MMA_H, MMA_R, PIPE) + sQ = storage.smem_q_latent.get_tensor( + q_latent_smem_layout_staged.outer, swizzle=q_latent_smem_layout_staged.inner + ) + sQ_rope = storage.smem_q_rope.get_tensor( + q_rope_smem_layout_staged.outer, swizzle=q_rope_smem_layout_staged.inner + ) + # (MMA, MMA_K, MMA_R, PIPE) + sKC = storage.smem_kc.get_tensor( + kc_smem_layout_staged.outer, swizzle=kc_smem_layout_staged.inner + ) + sKC_for_tma = storage.smem_kc.get_tensor( + kc_smem_layout_for_tma.outer, + swizzle=kc_smem_layout_for_tma.inner, + ) + # (MMA, MMA_D, MMA_K, PIPE) + # reuse smem + sVC_ptr = cute.recast_ptr(sKC.iterator, vc_smem_layout_staged.inner) + sVC = cute.make_tensor(sVC_ptr, vc_smem_layout_staged.outer) + sVC_for_tma = cute.make_tensor(sVC_ptr, vc_smem_layout_for_tma.outer) + # (MMA, MMA_H, MMA_K) + sP = storage.smem_p.get_tensor( + p_smem_layout_staged.outer, swizzle=p_smem_layout_staged.inner + ) + sPT = storage.smem_page_table.get_tensor( + cute.make_layout((self.mma_qk_tiler[1] // 2, self.load_pt_stage)) + ) + # (compute_threads,) + softmax_smem_exchange = storage.softmax_smem_exchange.get_tensor( + cute.make_layout(self.num_compute_warps * self.threads_per_warp) + ) + epilogue_smem_exchange = storage.epilogue_smem_exchange.get_tensor( + cute.make_layout(self.num_compute_warps * self.threads_per_warp) + ) + + # + # Cluster wait before tensor memory alloc + # + pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mnk) + + if cutlass.const_expr(self.enable_pdl): + cute.arch.griddepcontrol_wait() + + # /////////////////////////////////////////////////////////////////////////////// + # Load warps, including page table and data tensors + # /////////////////////////////////////////////////////////////////////////////// + + if warp_idx >= self.empty_warp_ids[0] and warp_idx <= self.empty_warp_ids[-1]: + _setmaxregister_decrease(self.other_reg_num) + if warp_idx == self.load_pt_warp_id: + _setmaxregister_decrease(self.other_reg_num) + load_pt_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.load_pt_stage + ) + tile_sched = create_mla_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + blk_coord = work_tile.tile_idx + k_index, k_tile_count, local_split_kv = self.get_k_tile_count( + split_kv, + cache_seqs, + block_split_kvs, + blk_coord, + ) + if k_tile_count > 0: + load_pt_common_params = SimpleNamespace( + blk_coord=blk_coord, + load_pt_pipeline=load_pt_pipeline, + mPT=mPT, + sPT=sPT, + tidx=tidx, + page_size=mCL.shape[0], + ) + load_pt_producer_state = self.load_page_table( + load_pt_common_params, + k_index, + k_tile_count, + load_pt_producer_state, + ) + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + load_pt_pipeline.producer_tail(load_pt_producer_state) + if warp_idx == self.load_tma_warp_id: + _setmaxregister_decrease(self.other_reg_num) + load_q_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.load_q_stage + ) + load_kv_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.load_kv_stage + ) + load_pt_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.load_pt_stage + ) + load_pt_release_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.load_pt_stage + ) + tile_sched = create_mla_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + blk_coord = work_tile.tile_idx + k_index, k_tile_count, local_split_kv = self.get_k_tile_count( + split_kv, + cache_seqs, + block_split_kvs, + blk_coord, + ) + if k_tile_count > 0: + # Construct fixed common/tma_qk/tma_pv params for load_tma + tma_common_params = SimpleNamespace( + blk_coord=blk_coord, + local_split_kv=local_split_kv, + load_q_pipeline=load_q_pipeline, + load_kv_pipeline=load_kv_pipeline, + mPT=mPT, + sPT=sPT, + load_pt_pipeline=load_pt_pipeline, + ) + tma_qk_params = SimpleNamespace( + tiled_mma_qk=tiled_mma_qk, + tma_atom_q_latent=tma_atom_q_latent, + tma_atom_q_rope=tma_atom_q_rope, + tma_atom_c_latent=tma_atom_c_latent, + tma_atom_c_rope=tma_atom_c_rope, + mQL=mQL, + mQR=mQR, + mCL=mCL, + mKR=mKR, + sQ=sQ, + sQ_rope=sQ_rope, + sKC=sKC_for_tma, + ) + tma_pv_params = SimpleNamespace( + tiled_mma_pv=tiled_mma_pv, + tma_atom_c_latent_transpose=tma_atom_c_latent_transpose, + mCL=mCL, + mKR=mKR, + mCLT=mCLT, + sVC=sVC_for_tma, + ) + # Load tma + ( + load_q_producer_state, + load_kv_producer_state, + load_pt_consumer_state, + load_pt_release_state, + ) = self.load_tma( + tma_common_params, + tma_qk_params, + tma_pv_params, + k_index, + k_tile_count, + load_q_producer_state, + load_kv_producer_state, + load_pt_consumer_state, + load_pt_release_state, + ) + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + load_q_pipeline.producer_tail(load_q_producer_state) + load_kv_pipeline.producer_tail(load_kv_producer_state) + + # /////////////////////////////////////////////////////////////////////////////// + # MMA warp + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.mma_warp_id: + _setmaxregister_decrease(self.other_reg_num) + # Alloc tensor memory buffer + tmem.allocate(_get_max_tmem_alloc_cols("sm_100")) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + + load_q_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.load_q_stage + ) + load_kv_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.load_kv_stage + ) + mma_s_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.mma_s_stage + ) + p_mma_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.p_mma_stage + ) + mma_o_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.mma_o_stage + ) + tile_sched = create_mla_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + blk_coord = work_tile.tile_idx + k_index, k_tile_count, local_split_kv = self.get_k_tile_count( + split_kv, cache_seqs, block_split_kvs, blk_coord + ) + if k_tile_count > 0: + mma_common_params = SimpleNamespace( + blk_coord=blk_coord, + local_split_kv=local_split_kv, + load_q_pipeline=load_q_pipeline, + load_kv_pipeline=load_kv_pipeline, + tmem_ptr=tmem_ptr, + is_leader_cta=is_leader_cta, + L=mCL.shape[1], + ) + mma_qk_params = SimpleNamespace( + mma_s_pipeline=mma_s_pipeline, + sQ=sQ, + sQ_rope=sQ_rope, + sKC=sKC, + ) + mma_pv_params = SimpleNamespace( + p_mma_pipeline=p_mma_pipeline, + mma_o_pipeline=mma_o_pipeline, + sP=sP, + sVC=sVC, + ) + ( + tiled_mma_qk, + tiled_mma_pv, + load_q_consumer_state, + load_kv_consumer_state, + mma_s_producer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) = self.mma( + mma_common_params, + mma_qk_params, + mma_pv_params, + k_tile_count, + tiled_mma_qk, + tiled_mma_pv, + load_q_consumer_state, + load_kv_consumer_state, + mma_s_producer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + mma_s_pipeline.producer_tail(mma_s_producer_state) + mma_o_pipeline.producer_tail(mma_o_producer_state) + + tmem.relinquish_alloc_permit() + tmem.free(tmem_ptr) + if cutlass.const_expr(self.enable_pdl): + cute.arch.griddepcontrol_launch_dependents() + + # /////////////////////////////////////////////////////////////////////////////// + # Compute warp + # /////////////////////////////////////////////////////////////////////////////// + if ( + warp_idx >= self.compute_warp_ids[0] + and warp_idx <= self.compute_warp_ids[-1] + ): + _setmaxregister_increase(self.softmax_reg_num) + mma_s_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.mma_s_stage + ) + p_mma_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.p_mma_stage + ) + p_cor_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.p_cor_stage + ) + mma_o_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.mma_o_stage + ) + # sync with mma warp before retrieving tmem ptr + tmem.wait_for_alloc() + + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + + tile_sched = create_mla_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + blk_coord = work_tile.tile_idx + k_index, k_tile_count, local_split_kv = self.get_k_tile_count( + split_kv, cache_seqs, block_split_kvs, blk_coord + ) + if k_tile_count > 0: + compute_common_params = SimpleNamespace( + blk_coord=blk_coord, + split_kv=split_kv, + local_split_kv=local_split_kv, + smem_exchange=softmax_smem_exchange, + mAccO=mAccO, + mO=mO, + K=cache_seqs[blk_coord[2]], + L=mCL.shape[1], + tmem_ptr=tmem_ptr, + tidx=tidx, + p_cor_pipeline=p_cor_pipeline, + ) + compute_softmax_params = SimpleNamespace( + tiled_mma_qk=tiled_mma_qk, + sP=sP, + mma_s_pipeline=mma_s_pipeline, + p_mma_pipeline=p_mma_pipeline, + softmax_scale_log2=softmax_scale_log2, + ) + mma_s_consumer_state, p_mma_producer_state, p_cor_producer_state = ( + self.compute( + compute_common_params, + compute_softmax_params, + k_index=k_index, + k_tile_count=k_tile_count, + mma_s_consumer_state=mma_s_consumer_state, + p_mma_producer_state=p_mma_producer_state, + p_cor_producer_state=p_cor_producer_state, + ) + ) + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + p_cor_pipeline.producer_tail(p_cor_producer_state) + + # /////////////////////////////////////////////////////////////////////////////// + # Correction warp + # /////////////////////////////////////////////////////////////////////////////// + if ( + warp_idx >= self.correction_warp_ids[0] + and warp_idx <= self.correction_warp_ids[-1] + ): + _setmaxregister_increase(self.correction_reg_num) + p_cor_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.p_cor_stage + ) + mma_o_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.mma_o_stage + ) + # sync with mma warp before retrieving tmem ptr + tmem.wait_for_alloc() + + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + + tile_sched = create_mla_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + blk_coord = work_tile.tile_idx + k_index, k_tile_count, local_split_kv = self.get_k_tile_count( + split_kv, cache_seqs, block_split_kvs, blk_coord + ) + if k_tile_count > 0: + compute_common_params = SimpleNamespace( + blk_coord=blk_coord, + split_kv=split_kv, + local_split_kv=local_split_kv, + smem_exchange=epilogue_smem_exchange, + mAccO=mAccO, + mO=mO, + K=cache_seqs[blk_coord[2]], + L=mCL.shape[1], + H=mQL.shape[0], + tmem_ptr=tmem_ptr, + tidx=tidx, + tiled_mma_pv=tiled_mma_pv, + p_cor_pipeline=p_cor_pipeline, + mma_o_pipeline=mma_o_pipeline, + ) + compute_epilogue_params = SimpleNamespace( + output_scale=output_scale, + softmax_scale_log2=softmax_scale_log2, + mAccLSE=mAccLSE, + mLSE=mLSE, + ) + p_cor_consumer_state, mma_o_consumer_state = self.correction( + compute_common_params, + compute_epilogue_params, + k_tile_count=k_tile_count, + p_cor_consumer_state=p_cor_consumer_state, + mma_o_consumer_state=mma_o_consumer_state, + ) + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + return + + @cute.kernel + def reduction_kernel( + self, + mO: cute.Tensor, + mLSE: cute.Tensor, + mAccO: cute.Tensor, + mAccLSE: cute.Tensor, + split_kv: cutlass.Int32, + cache_seqs: cute.Tensor, + block_split_kvs: cute.Tensor, + ): + """The reduction kernel for Multi-Head Latent Attention (MLA) that combines intermediate results + from multiple split_kv blocks into final outputs. + + :param mO: Output tensor for storing final results + :type mO: cute.Tensor + :param mLSE: Log-sum-exp tensor for storing final LSE values + :type mLSE: cute.Tensor + :param mAccO: Accumulated output tensor from split_kv blocks + :type mAccO: cute.Tensor + :param mAccLSE: Accumulated LSE tensor from split_kv blocks + :type mAccLSE: cute.Tensor + :param split_kv: Number of split_kv blocks + :type split_kv: cutlass.Int32 + :param cache_seqs: Cache sequence lengths tensor + :type cache_seqs: cute.Tensor + :param block_split_kvs: Per-block split_kv values tensor (for variable split_kv) + :type block_split_kvs: cute.Tensor + """ + bidx, bidy, bidz = cute.arch.block_idx() + tidx, _, _ = cute.arch.thread_idx() + blk_coord = (bidx, bidy, bidz) + local_split_kv = ( + block_split_kvs[blk_coord[2]] if self.is_var_split_kv else split_kv + ) + k_tile_total = cute.ceil_div(cache_seqs[blk_coord[2]], self.mma_qk_tiler[1]) + k_tile_per_cta = cute.ceil_div(k_tile_total, local_split_kv) + local_split_kv = cute.ceil_div(k_tile_total, k_tile_per_cta) + + # Alloc shared memory + smem = utils.SmemAllocator() + storage = smem.allocate(MAX_SPLITS * self.acc_dtype.width // 8, 16) + lse_scale_ptr = cute.recast_ptr(storage, dtype=self.acc_dtype) + smem_lse_scale = cute.make_tensor(lse_scale_ptr, cute.make_layout(MAX_SPLITS)) + + if cutlass.const_expr(self.enable_pdl): + cute.arch.griddepcontrol_wait() + gLSE = mAccLSE[blk_coord[0], None, blk_coord[1], blk_coord[2]] + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + if warp_idx == 0: + # calculate the global lse and exp ^ (local_lse - global_lse) + lse_per_thread = cute.ceil_div(MAX_SPLITS, self.threads_per_warp) + + local_lse = cute.make_rmem_tensor( + cute.make_layout(lse_per_thread), self.lse_dtype + ) + lse_max = -self.lse_dtype.inf + # find the max lse + for i in cutlass.range_constexpr(lse_per_thread): + split_kv_idx = tidx + i * self.threads_per_warp + local_lse[i] = ( + gLSE[split_kv_idx] + if cute.elem_less(split_kv_idx, local_split_kv) + else -self.lse_dtype.inf + ) + # reduce the local lse + lse_max = cute.arch.fmax(lse_max, local_lse[i]) + lse_max = cute.arch.warp_reduction_max(lse_max) + lse_max = lse_max if lse_max != -self.lse_dtype.inf else 0.0 + # calculate sum_lse + sum_lse = 0.0 + for i in cutlass.range_constexpr(lse_per_thread): + sum_lse += cute.math.exp2(local_lse[i] - lse_max, fastmath=True) + sum_lse = cute.arch.warp_reduction_sum(sum_lse) + # calculate the global_lse + global_lse = ( + lse_max + cute.math.log2(sum_lse, fastmath=True) + if not sum_lse == self.lse_dtype(0.0) or sum_lse != sum_lse # noqa: SIM201 + else self.lse_dtype.inf + ) + if tidx == 0: + mLSE[blk_coord[0], blk_coord[1], blk_coord[2]] = global_lse + # store the scale to shared memory + for i in cutlass.range_constexpr(lse_per_thread): + split_kv_idx = tidx + i * self.threads_per_warp + if cute.elem_less(split_kv_idx, local_split_kv): + smem_lse_scale[split_kv_idx] = cute.math.exp2( + local_lse[i] - global_lse, fastmath=True + ) + + pipeline.sync(barrier_id=4) + + elements_per_thread = cute.ceil_div( + self.latent_dim, self.threads_per_warp * self.num_compute_warps + ) + gAccO = mAccO[blk_coord[0], None, None, blk_coord[1], blk_coord[2]] + rAccO = cute.make_rmem_tensor( + cute.make_layout(elements_per_thread), self.acc_dtype + ) + rO = cute.make_rmem_tensor(cute.make_layout(elements_per_thread), self.o_dtype) + rAccO.fill(0.0) + for i in range(local_split_kv): + for j in cutlass.range_constexpr(elements_per_thread): + element_idx = tidx + j * self.threads_per_warp * self.num_compute_warps + rAccO[j] += gAccO[i, element_idx] * smem_lse_scale[i] + rO.store(rAccO.load().to(self.o_dtype)) + for j in cutlass.range_constexpr(elements_per_thread): + element_idx = tidx + j * self.threads_per_warp * self.num_compute_warps + mO[blk_coord[0], element_idx, blk_coord[1], blk_coord[2]] = rO[j] + if cutlass.const_expr(self.enable_pdl): + cute.arch.griddepcontrol_launch_dependents() + return + + @staticmethod + def get_split_kv( + B: int, S: int, K: int, mma_qk_tiler_mn: tuple, max_active_blocks: int + ) -> int: + """Get the proper split_kv value for the MLA kernel based on parameters. + + :param B: Batch size + :type B: int + :param S: Sequence length + :type S: int + :param K: Sequence length + :type K: int + :param mma_qk_tiler_mn: MLA tiling parameters + :type mma_qk_tiler_mn: tuple + :param max_active_blocks: Maximum number of active blocks + :type max_active_blocks: int + :return: Split_kv value + :rtype: int + """ + max_splits = ceil_div(K, mma_qk_tiler_mn[1]) + blocks_per_batch = max(1, max_active_blocks // B // (S * 2)) + split_heur = min(max_splits, blocks_per_batch) + # {$nv-internal-release begin} + # TODO: figure out the error of make_tile with dynamic int_tuple + # {$nv-internal-release end} + k_waves = ceil_div(max_splits, split_heur) + split_wave_aware = ceil_div(max_splits, k_waves) + max_split_kv = 32 + return min(split_wave_aware, max_split_kv) + + @staticmethod + def get_split_kv_simplified(B: int, S: int, max_active_blocks: int) -> int: + blocks_per_batch = max(1, max_active_blocks // B // (S * 2)) + max_split_kv = 32 + return min(blocks_per_batch, max_split_kv) + + @cute.jit + def get_k_tile_count( + self, + split_kv: cutlass.Int32, + cache_seqs: cute.Tensor, + block_split_kvs: cute.Tensor, + blk_coord: cute.Coord, + ) -> tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32]: + """Get the current k_index, k_tile_count, and local split_kv value for the MLA kernel. + + :param split_kv: Split_kv value + :type split_kv: cutlass.Int32 + :param cache_seqs: Cache sequence lengths tensor + :type cache_seqs: cute.Tensor + :param block_split_kvs: Per-block split_kv values tensor + :type block_split_kvs: cute.Tensor + :param blk_coord: Block coordinate + :type blk_coord: cute.Coord + :return: k_index, k_tile_count, split_kv + :rtype: tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32] + """ + K = cache_seqs[blk_coord[2]] + if cutlass.const_expr(self.is_var_split_kv): + split_kv = block_split_kvs[blk_coord[2]] + + k_tile_total = cute.ceil_div(K, self.mma_qk_tiler[1]) + # {$nv-internal-release begin} + # TODO: figure out the error of make_tile with dynamic int_tuple + # {$nv-internal-release end} + k_tile_per_cta = cute.ceil_div(k_tile_total, split_kv) + k_index = blk_coord[3] * k_tile_per_cta + k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index) + return k_index, k_tile_count, split_kv + + @cute.jit + def load_page_table( + self, + common_params: SimpleNamespace, + k_index: cutlass.Int32, + k_tile_count: cutlass.Int32, + load_pt_producer_state: pipeline.PipelineState, + ) -> pipeline.PipelineState: + """Load warp to load page table. Updates the load pt producer state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param k_index: The k index + :type k_index: cutlass.Int32 + :param k_tile_count: The k tile count + :type k_tile_count: cutlass.Int32 + :param load_pt_producer_state: The load pt producer state + :type load_pt_producer_state: pipeline.PipelineState + + :return: The load pt producer state + :rtype: pipeline.PipelineState + """ + mPT = common_params.mPT[None, common_params.blk_coord[2]] + page_per_tile = self.mma_qk_tiler[1] // self.page_size + tidx = common_params.tidx % self.threads_per_warp + + load_pt_pipeline = common_params.load_pt_pipeline + while k_tile_count > 0: + load_pt_pipeline.producer_acquire(load_pt_producer_state) + + elem_per_thread = cute.ceil_div(page_per_tile, self.threads_per_warp) + + # atom_async_copy: async copy atom for page table load + atom_async_copy = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.ALWAYS), + cutlass.Int32, + num_bits_per_copy=cutlass.Int32.width, + ) + mPT_for_copy = cute.flat_divide(mPT, (1,)) + sPT_for_copy = cute.flat_divide(common_params.sPT, (1,)) + # elem_per_thread is a dynamic value depends on the page_size setting. + for i in range(elem_per_thread): + idx = i * self.threads_per_warp + tidx + if cute.elem_less( + k_index * page_per_tile + idx, mPT.shape[0] + ) and cute.elem_less(idx, page_per_tile): + cute.copy( + atom_async_copy, + mPT_for_copy[None, k_index * page_per_tile + idx], + sPT_for_copy[None, idx, load_pt_producer_state.index], + ) + else: + sPT_for_copy[None, idx, load_pt_producer_state.index].fill(0) + mbar_ptr = load_pt_pipeline.producer_get_barrier(load_pt_producer_state) # noqa: F841 + load_pt_pipeline.producer_commit(load_pt_producer_state) + load_pt_producer_state.advance() + k_index += 1 + k_tile_count -= 1 + + return load_pt_producer_state + + @cute.jit + def load_tma( + self, + common_params: SimpleNamespace, + qk_params: SimpleNamespace, + v_params: SimpleNamespace, + k_index: cutlass.Int32, + k_tile_count: cutlass.Int32, + load_q_producer_state: pipeline.PipelineState, + load_kv_producer_state: pipeline.PipelineState, + load_pt_consumer_state: pipeline.PipelineState, + load_pt_release_state: pipeline.PipelineState, + ) -> tuple[ + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + ]: + """Load wrap to load Q/C latent/rope tensors. Updates the load qkv producer state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param qk_params: The qk parameters + :type qk_params: SimpleNamespace + :param v_params: The v parameters + :type v_params: SimpleNamespace + :param k_index: The k index + :type k_index: cutlass.Int32 + :param k_tile_count: The k tile count + :type k_tile_count: cutlass.Int32 + :param load_q_producer_state: The load q producer state + :type load_q_producer_state: pipeline.PipelineState + :param load_kv_producer_state: The load kv producer state + :type load_kv_producer_state: pipeline.PipelineState + :param load_pt_consumer_state: The load pt consumer state + :type load_pt_consumer_state: pipeline.PipelineState + :param load_pt_release_state: The load pt release state + :type load_pt_release_state: pipeline.PipelineState + + :return: The load q producer state, load kv producer state, load pt consumer state, and load pt release state + :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] + """ + # page table + mPT = common_params.mPT[None, common_params.blk_coord[2]] + + # Flatten divide and partition global tensors for QK TMA load + # (bM, bK, rM, rK, rL) + mma_qk_tiler_mk = cute.select(self.mma_qk_tiler, mode=[0, 2]) + gQL = cute.flat_divide(qk_params.mQL, mma_qk_tiler_mk) + mma_qk_tiler_mk_rope = cute.select(self.mma_qk_rope_tiler, mode=[0, 2]) + gQR = cute.flat_divide(qk_params.mQR, mma_qk_tiler_mk_rope) + + thr_mma_qk = qk_params.tiled_mma_qk.get_slice( + common_params.blk_coord[0] % cute.size(qk_params.tiled_mma_qk.thr_id) + ) + tSgQL = thr_mma_qk.partition_A(gQL) + tSgQR = thr_mma_qk.partition_A(gQR) + + cta_m = min( + qk_params.tiled_mma_qk.op.shape_mnk[0] + // qk_params.tiled_mma_qk.thr_id.shape, + self.page_size, + ) + page_tile_size = min(self.page_size, cta_m) + gCL = cute.tiled_divide(qk_params.mCL, (page_tile_size, self.mma_qk_tiler[2])) + tSgCL = ( + gCL[ + None, + common_params.blk_coord[0] % qk_params.tiled_mma_qk.thr_id.shape, + None, + None, + ] + if cta_m < self.page_size + else gCL[None, 0, None, None] + ) + gKR = cute.tiled_divide(qk_params.mKR, (page_tile_size, self.mma_qk_tiler[2])) + tSgKR = ( + gKR[ + None, + common_params.blk_coord[0] % qk_params.tiled_mma_qk.thr_id.shape, + None, + None, + ] + if cta_m < self.page_size + else gKR[None, 0, None, None] + ) + + # tma partition for q, k latent/rope + # smem: ((atom_v, rest_v), STAGE) + # gmem: ((atom_v, rest_v), RestM, RestK, RestL) + tQsQ, tQLgQL_mkl = cpasync.tma_partition( + qk_params.tma_atom_q_latent, + 0, + cute.make_layout(1), + cute.group_modes(qk_params.sQ, 0, 3), + cute.group_modes(tSgQL, 0, 3), + ) + + tQsQ_rope, tQRgQR_mkl = cpasync.tma_partition( + qk_params.tma_atom_q_rope, + 0, + cute.make_layout(1), + cute.group_modes(qk_params.sQ_rope, 0, 3), + cute.group_modes(tSgQR, 0, 3), + ) + + tKCsKC, tCLgCL = cpasync.tma_partition( + qk_params.tma_atom_c_latent, + 0, + cute.make_layout(1), + qk_params.sKC, + tSgCL, + ) + + _, tKRgKR = cpasync.tma_partition( + qk_params.tma_atom_c_rope, + 0, + cute.make_layout(1), + qk_params.sKC, + tSgKR, + ) + + tQLgQL = tQLgQL_mkl[ + None, None, None, common_params.blk_coord[1], common_params.blk_coord[2] + ] + tQRgQR = tQRgQR_mkl[ + None, None, None, common_params.blk_coord[1], common_params.blk_coord[2] + ] + + # Flatten divide and partition global tensors for V TMA load + page_tile_size = min(self.page_size, self.mma_pv_tiler[2]) + gCLT = cute.flat_divide(v_params.mCLT, (self.mma_pv_tiler[1], page_tile_size)) + cta_n = self.mma_pv_tiler[1] // v_params.tiled_mma_pv.thr_id.shape + gCLT = cute.logical_divide(gCLT, (cta_n,))[ + (None, common_params.blk_coord[0]), None, None, None, None + ] + tOgCLT = cute.tiled_divide(gCLT, (cta_n, page_tile_size)) + tOgCLT = tOgCLT[None, 0, 0, None, None, None] + + # tma partition for vc + # smem: ((atom_v, rest_v), STAGE) + # gmem: ((atom_v, rest_v), RestM, RestK, RestL) + tVCsVC, tCLTgCLT = cpasync.tma_partition( + v_params.tma_atom_c_latent_transpose, + 0, + cute.make_layout(1), + v_params.sVC, + tOgCLT, + ) + + # set extra params + common_params.mPT = mPT + qk_params.tQLgQL = tQLgQL + qk_params.tQRgQR = tQRgQR + qk_params.tCLgCL = tCLgCL + qk_params.tKRgKR = tKRgKR + qk_params.tQsQ = tQsQ + qk_params.tQsQ_rope = tQsQ_rope + qk_params.tKCsKC = tKCsKC + v_params.tCLTgCLT = tCLTgCLT + v_params.tVCsVC = tVCsVC + + load_q_producer_state, load_kv_producer_state, load_pt_consumer_state = ( + self.load_tma_qk_one_k_tile( + common_params, + qk_params, + k_index, + k_tile_count, + load_q_producer_state, + load_kv_producer_state, + load_pt_consumer_state, + load_q=True, + ) + ) + k_index += 1 + k_tile_count -= 1 + while k_tile_count > 0: + # {$nv-internal-release begin} + # TODO: figure out how to support SingleNamespace/struct in ast + # {$nv-internal-release end} + load_q_producer_state, load_kv_producer_state, load_pt_consumer_state = ( + self.load_tma_qk_one_k_tile( + common_params, + qk_params, + k_index, + k_tile_count, + load_q_producer_state, + load_kv_producer_state, + load_pt_consumer_state, + load_q=False, + ) + ) + load_kv_producer_state, load_pt_release_state = self.load_tma_v_one_k_tile( + common_params, + v_params, + k_index - 1, + load_kv_producer_state, + load_pt_release_state, + ) + k_index += 1 + k_tile_count -= 1 + + # load last v tile + load_kv_producer_state, load_pt_release_state = self.load_tma_v_one_k_tile( + common_params, + v_params, + k_index - 1, + load_kv_producer_state, + load_pt_release_state, + ) + return ( + load_q_producer_state, + load_kv_producer_state, + load_pt_consumer_state, + load_pt_release_state, + ) + + @cute.jit + def load_tma_qk_one_k_tile( + self, + common_params: SimpleNamespace, + qk_params: SimpleNamespace, + k_index: cutlass.Int32, + k_tile_count: cutlass.Int32, + load_q_producer_state: pipeline.PipelineState, + load_kv_producer_state: pipeline.PipelineState, + load_pt_consumer_state: pipeline.PipelineState, + load_q: bool, + ) -> tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState]: + """Load one k-tile of Q/C latent/rope tensors. Updates the load qkv producer state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param qk_params: The qk parameters + :type qk_params: SimpleNamespace + :param k_index: The k index + :type k_index: cutlass.Int32 + :param k_tile_count: The k tile count + :type k_tile_count: cutlass.Int32 + :param load_q_producer_state: The load q producer state + :type load_q_producer_state: pipeline.PipelineState + :param load_kv_producer_state: The load kv producer state + :type load_kv_producer_state: pipeline.PipelineState + :param load_pt_consumer_state: The load pt consumer state + :type load_pt_consumer_state: pipeline.PipelineState + :param load_q: Whether to load q + :type load_q: bool + + :return: The load q producer state, load kv producer state, and load pt consumer state + :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] + """ + page_per_tile = ceil_div( + self.mma_qk_tiler[1] // self.page_size, qk_params.tiled_mma_qk.thr_id.shape + ) + common_params.load_pt_pipeline.consumer_wait(load_pt_consumer_state) + page_table_stage = load_pt_consumer_state.index + load_pt_consumer_state.advance() + k_idx = cute.make_rmem_tensor(cute.make_layout(page_per_tile), cutlass.Int32) + for i in cutlass.range_constexpr(page_per_tile): + k_idx[i] = ( + common_params.sPT[0, page_table_stage] + if self.mma_qk_tiler[1] // self.page_size == 1 + else common_params.sPT[ + i + common_params.blk_coord[0] * page_per_tile, page_table_stage + ] + ) + # load q once at first iteration + if cutlass.const_expr(load_q): + common_params.load_q_pipeline.producer_acquire(load_q_producer_state) + # get the mbar ptr from pipeline. + tma_bar_ptr = common_params.load_q_pipeline.producer_get_barrier( + load_q_producer_state + ) + for i in cutlass.range(self.iterations_qk_latent): + # load q latent + cute.copy( + qk_params.tma_atom_q_latent, + qk_params.tQLgQL[None, 0, i], + qk_params.tQsQ[None, (i, 0)], + tma_bar_ptr=tma_bar_ptr, + ) + for i in cutlass.range(self.iterations_qk_rope): + # load q rope + cute.copy( + qk_params.tma_atom_q_rope, + qk_params.tQRgQR[None, 0, i], + qk_params.tQsQ_rope[None, i], + tma_bar_ptr=tma_bar_ptr, + ) + load_q_producer_state.advance() + load_kv_pipeline = common_params.load_kv_pipeline + tma_bar_ptr = load_kv_pipeline.producer_get_barrier(load_kv_producer_state) + for i in cutlass.range(self.iterations_qk_latent): + # get the mbar ptr from pipeline. + tma_bar_ptr = load_kv_pipeline.producer_get_barrier(load_kv_producer_state) + load_kv_pipeline.producer_acquire(load_kv_producer_state) + for k in cutlass.range(page_per_tile): + # load k latent + cute.copy( + qk_params.tma_atom_c_latent, + qk_params.tCLgCL[None, i, k_idx[k]], + qk_params.tKCsKC[None, k, 0, load_kv_producer_state.index], + tma_bar_ptr=tma_bar_ptr, + ) + load_kv_producer_state.advance() + + for i in cutlass.range(self.iterations_qk_rope): + # get the mbar ptr from pipeline. + tma_bar_ptr = load_kv_pipeline.producer_get_barrier(load_kv_producer_state) + load_kv_pipeline.producer_acquire(load_kv_producer_state) + for k in cutlass.range(page_per_tile): + # load k rope + cute.copy( + qk_params.tma_atom_c_rope, + qk_params.tKRgKR[None, i, k_idx[k]], + qk_params.tKCsKC[None, k, 0, load_kv_producer_state.index], + tma_bar_ptr=tma_bar_ptr, + ) + load_kv_producer_state.advance() + + return load_q_producer_state, load_kv_producer_state, load_pt_consumer_state + + @cute.jit + def load_tma_v_one_k_tile( + self, + common_params: SimpleNamespace, + v_params: SimpleNamespace, + k_index: cutlass.Int32, + load_kv_producer_state: pipeline.PipelineState, + load_pt_release_state: pipeline.PipelineState, + ) -> tuple[pipeline.PipelineState, pipeline.PipelineState]: + """Load one k-tile of compressed latent transpose tensor(v). Updates the load qkv producer state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param v_params: The load tma v parameters + :type v_params: SimpleNamespace + :param k_index: The k index + :type k_index: cutlass.Int32 + :param load_kv_producer_state: The load qkv producer state + :type load_kv_producer_state: pipeline.PipelineState + :param load_pt_release_state: The load pt release state + :type load_pt_release_state: pipeline.PipelineState + + :return: The load kv producer state and load pt release state + :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState] + """ + page_per_tile = self.mma_pv_tiler[2] * self.iterations_pv_k // self.page_size + page_per_subtile = ceil_div(page_per_tile, self.iterations_pv_k) + k_idx = cute.make_rmem_tensor(cute.make_layout(page_per_tile), cutlass.Int32) + page_table_stage = load_pt_release_state.index + for i in cutlass.range(page_per_tile): + k_idx[i] = ( + common_params.sPT[0, page_table_stage] + if page_per_tile == 1 + else common_params.sPT[i, page_table_stage] + ) + common_params.load_pt_pipeline.consumer_release(load_pt_release_state) + load_pt_release_state.advance() + load_kv_pipeline = common_params.load_kv_pipeline + tma_bar_ptr = load_kv_pipeline.producer_get_barrier(load_kv_producer_state) + for i in cutlass.range(self.iterations_pv_k): + for j in cutlass.range(self.iterations_pv_n): + # get the mbar ptr from pipeline. + tma_bar_ptr = load_kv_pipeline.producer_get_barrier( + load_kv_producer_state + ) + load_kv_pipeline.producer_acquire(load_kv_producer_state) + for k in cutlass.range(page_per_subtile): + k_idx_i = k_idx[ + k + + i + // ceil_div(self.iterations_pv_k, page_per_tile) + * page_per_subtile + ] + cute.copy( + v_params.tma_atom_c_latent_transpose, + v_params.tCLTgCLT[ + None, + j, + i % ceil_div(self.iterations_pv_k, page_per_tile), + k_idx_i, + ], + v_params.tVCsVC[None, 0, k, load_kv_producer_state.index], + tma_bar_ptr=tma_bar_ptr, + ) + + load_kv_producer_state.advance() + return load_kv_producer_state, load_pt_release_state + + @cute.jit + def mma( + self, + common_params: SimpleNamespace, + qk_params: SimpleNamespace, + pv_params: SimpleNamespace, + k_tile_count: cutlass.Int32, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + load_q_consumer_state: pipeline.PipelineState, + load_kv_consumer_state: pipeline.PipelineState, + mma_s_producer_state: pipeline.PipelineState, + p_mma_consumer_state: pipeline.PipelineState, + mma_o_producer_state: pipeline.PipelineState, + ) -> tuple[ + cute.TiledMma, + cute.TiledMma, + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + ]: + """MMA warp to compute the result of Q*K^T and P*V. Updates the tiled mma and pipeline states. + + :param common_params: The common parameters for mma qk and pv + :type common_params: SimpleNamespace + :param qk_params: The mma qk parameters + :type qk_params: SimpleNamespace + :param pv_params: The mma pv parameters + :type pv_params: SimpleNamespace + :param k_tile_count: The k tile count + :type k_tile_count: cutlass.Int32 + :param tiled_mma_qk: The tiled mma qk + :type tiled_mma_qk: cute.TiledMma + :param tiled_mma_pv: The tiled mma pv + :type tiled_mma_pv: cute.TiledMma + :param load_q_consumer_state: The load q consumer state + :type load_q_consumer_state: pipeline.PipelineState + :param load_kv_consumer_state: The load kv consumer state + :type load_kv_consumer_state: pipeline.PipelineState + :param mma_s_producer_state: The mma s producer state + :type mma_s_producer_state: pipeline.PipelineState + :param p_mma_consumer_state: The p mma consumer state + :type p_mma_consumer_state: pipeline.PipelineState + :param mma_o_producer_state: The mma o producer state + :type mma_o_producer_state: pipeline.PipelineState + + :return: The tiled mma qk, the tiled mma pv, the load q consumer state, the load kv consumer state, the mma s producer state, the p mma consumer state, and the mma o producer state + :rtype: tuple[cute.TiledMma, cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] + """ + + tSrQ = tiled_mma_qk.make_fragment_A(qk_params.sQ) + tSrQ_rope = tiled_mma_qk.make_fragment_A(qk_params.sQ_rope) + tSrKC = tiled_mma_qk.make_fragment_B(qk_params.sKC) + tOrP = tiled_mma_pv.make_fragment_A(pv_params.sP) + tOrVC = tiled_mma_pv.make_fragment_B(pv_params.sVC) + + tStS_shape = tiled_mma_qk.partition_shape_C( + cute.select(self.mma_qk_tiler, mode=[0, 1]) + ) + tStS_staged_fake = tiled_mma_qk.make_fragment_C( + cute.append(tStS_shape, self.mma_s_stage) + ) + # use real tmem ptr for tStS + tStS_staged = cute.make_tensor(common_params.tmem_ptr, tStS_staged_fake.layout) + tOtO_shape = tiled_mma_pv.partition_shape_C( + cute.select(self.mma_pv_tiler, mode=[0, 1]) + ) + # mma O has 1 stage. + tOtO = tiled_mma_pv.make_fragment_C(tOtO_shape) + tOtO_layout = cute.append( + tOtO.layout, + cute.make_layout( + common_params.L // self.mma_pv_tiler[1], + stride=self.mma_pv_tiler[1] // self.warps_in_n, + ), + ) + tOtO_staged = cute.make_tensor( + tStS_staged.iterator + self.tmem_o_offset, tOtO_layout + ) + + # set more parameters + qk_params.tSrQ = tSrQ + qk_params.tSrQ_rope = tSrQ_rope + qk_params.tSrKC = tSrKC + qk_params.tStS_staged = tStS_staged + pv_params.tOrP = tOrP + pv_params.tOrVC = tOrVC + pv_params.tOtO_staged = tOtO_staged + + # mma O accumulates on K, so the accumlate flag is set to False once before all K blocks. + tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, False) + load_q_pipeline = common_params.load_q_pipeline + if common_params.is_leader_cta: + load_q_release_state = load_q_consumer_state.clone() + + ( + tiled_mma_qk, + load_q_consumer_state, + load_kv_consumer_state, + mma_s_producer_state, + ) = self.mma_qk( + common_params, + qk_params, + tiled_mma_qk, + load_q_consumer_state, + load_kv_consumer_state, + mma_s_producer_state, + wait_q=True, + ) + k_tile_count -= 1 + while k_tile_count > 0: + ( + tiled_mma_qk, + load_q_consumer_state, + load_kv_consumer_state, + mma_s_producer_state, + ) = self.mma_qk( + common_params, + qk_params, + tiled_mma_qk, + load_q_consumer_state, + load_kv_consumer_state, + mma_s_producer_state, + wait_q=False, + ) + ( + tiled_mma_pv, + load_kv_consumer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) = self.mma_pv( + common_params, + pv_params, + tiled_mma_pv, + load_kv_consumer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) + k_tile_count -= 1 + + # release q consumer states + load_q_pipeline.consumer_release(load_q_release_state) + load_q_release_state.advance() + ( + tiled_mma_pv, + load_kv_consumer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) = self.mma_pv( + common_params, + pv_params, + tiled_mma_pv, + load_kv_consumer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) + + return ( # type: ignore[return-value] + tiled_mma_qk, + tiled_mma_pv, + load_q_consumer_state, + load_kv_consumer_state, + mma_s_producer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) + + @cute.jit + def mma_qk( + self, + common_params: SimpleNamespace, + qk_params: SimpleNamespace, + tiled_mma_qk: cute.TiledMma, + load_q_consumer_state: pipeline.PipelineState, + load_kv_consumer_state: pipeline.PipelineState, + mma_s_producer_state: pipeline.PipelineState, + wait_q: bool, + ) -> tuple[ + cute.TiledMma, + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + ]: + """Compute one k-tile of mma for Q*K^T. Updates the tiled MMA QK and pipeline states. + + :param qk_params: The qk parameters + :type qk_params: SimpleNamespace + :param tiled_mma_qk: The tiled mma qk + :type tiled_mma_qk: cute.TiledMma + :param load_q_consumer_state: The load q consumer state + :type load_q_consumer_state: pipeline.PipelineState + :param load_kv_consumer_state: The load kv consumer state + :type load_kv_consumer_state: pipeline.PipelineState + :param mma_s_producer_state: The mma s producer state + :type mma_s_producer_state: pipeline.PipelineState + + :return: The tiled mma qk, the load q consumer state, the load kv consumer state, and the mma s producer state + :rtype: tuple[cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] + """ + tStS = qk_params.tStS_staged[None, None, None, mma_s_producer_state.index] + + qk_params.mma_s_pipeline.producer_acquire(mma_s_producer_state) + tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, False) + load_q_pipeline = common_params.load_q_pipeline + load_kv_pipeline = common_params.load_kv_pipeline + if cutlass.const_expr(wait_q): + load_q_pipeline.consumer_wait(load_q_consumer_state) + load_q_consumer_state.advance() + for q_stage in range(self.iterations_qk_latent): + load_kv_pipeline.consumer_wait(load_kv_consumer_state) + kc_stage = load_kv_consumer_state.index + for k_block in cutlass.range(cute.size(qk_params.tSrQ.shape[2])): + cute.gemm( + tiled_mma_qk, + tStS, + qk_params.tSrQ[None, None, k_block, q_stage], + qk_params.tSrKC[None, None, k_block, kc_stage], + tStS, + ) + tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, True) + load_kv_pipeline.consumer_release(load_kv_consumer_state) + load_kv_consumer_state.advance() + for q_stage in range(self.iterations_qk_rope): + load_kv_pipeline.consumer_wait(load_kv_consumer_state) + kc_stage = load_kv_consumer_state.index + for k_block in cutlass.range(self.rope_dim // tiled_mma_qk.shape_mnk[2]): + cute.gemm( + tiled_mma_qk, + tStS, + qk_params.tSrQ_rope[None, None, k_block, q_stage], + qk_params.tSrKC[None, None, k_block, kc_stage], + tStS, + ) + tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, True) + load_kv_pipeline.consumer_release(load_kv_consumer_state) + load_kv_consumer_state.advance() + + qk_params.mma_s_pipeline.producer_commit(mma_s_producer_state) + mma_s_producer_state.advance() + return ( + tiled_mma_qk, + load_q_consumer_state, + load_kv_consumer_state, + mma_s_producer_state, + ) + + @cute.jit + def mma_pv( + self, + common_params: SimpleNamespace, + pv_params: SimpleNamespace, + tiled_mma_pv: cute.TiledMma, + load_kv_consumer_state: pipeline.PipelineState, + p_mma_consumer_state: pipeline.PipelineState, + mma_o_producer_state: pipeline.PipelineState, + ) -> tuple[ + cute.TiledMma, + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + ]: + """Compute one k-tile of mma for P*V. Updates the tiled mma pv and pipeline states. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param pv_params: The pv parameters + :type pv_params: SimpleNamespace + :param tiled_mma_pv: The tiled mma pv + :type tiled_mma_pv: cute.TiledMma + :param load_kv_consumer_state: The load kv consumer state + :type load_kv_consumer_state: pipeline.PipelineState + :param p_mma_consumer_state: The P MMA consumer state + :type p_mma_consumer_state: pipeline.PipelineState + :param mma_o_producer_state: The MMA o producer state + :type mma_o_producer_state: pipeline.PipelineState + + :return: The tiled mma pv, the load qkv consumer state, the P MMA consumer state, and the MMA o producer state + :rtype: tuple[cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] + """ + + pv_params.mma_o_pipeline.producer_acquire(mma_o_producer_state) + pv_params.p_mma_pipeline.consumer_wait(p_mma_consumer_state) + load_kv_pipeline = common_params.load_kv_pipeline + for p_stage in range(self.iterations_pv_k): + accumulate_flag = tiled_mma_pv.get(tcgen05.Field.ACCUMULATE) + for acc_stage in range(self.iterations_pv_n): + load_kv_pipeline.consumer_wait(load_kv_consumer_state) + tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, accumulate_flag) + vc_stage = load_kv_consumer_state.index + tOtO = pv_params.tOtO_staged[None, None, None, acc_stage] + for k_block in cutlass.range(pv_params.tOrP.shape[2]): + cute.gemm( + tiled_mma_pv, + tOtO, + pv_params.tOrP[ + None, + None, + k_block, + (p_stage, p_mma_consumer_state.index), + ], + pv_params.tOrVC[None, None, k_block, vc_stage], + tOtO, + ) + tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, True) + load_kv_pipeline.consumer_release(load_kv_consumer_state) + load_kv_consumer_state.advance() + pv_params.p_mma_pipeline.consumer_release(p_mma_consumer_state) + p_mma_consumer_state.advance() + pv_params.mma_o_pipeline.producer_commit(mma_o_producer_state) + mma_o_producer_state.advance() + + return ( + tiled_mma_pv, + load_kv_consumer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) + + @cute.jit + def compute( + self, + common_params: SimpleNamespace, + softmax_params: SimpleNamespace, + k_index: cutlass.Int32, + k_tile_count: cutlass.Int32, + mma_s_consumer_state: pipeline.PipelineState, + p_mma_producer_state: pipeline.PipelineState, + p_cor_producer_state: pipeline.PipelineState, + ) -> tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState]: + """Compute warp to compute the result of softmax, rescale, and epilogue. Updates the related pipeline states. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param softmax_params: The softmax parameters + :type softmax_params: SimpleNamespace + :param k_index: The index of the k-tile + :type k_index: cutlass.Int32 + :param k_tile_count: The number of k-tiles + :type k_tile_count: cutlass.Int32 + :param mma_s_consumer_state: The MMA s consumer state + :type mma_s_consumer_state: pipeline.PipelineState + :param p_mma_producer_state: The P MMA producer state + :type p_mma_producer_state: pipeline.PipelineState + :param p_cor_producer_state: The P correction producer state + :type p_cor_producer_state: pipeline.PipelineState + + :return: The MMA s consumer state, the P MMA producer state, and the P correction producer state + :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] + """ + + k_tile_total = cute.ceil_div(common_params.K, self.mma_qk_tiler[1]) + + row_max = -self.acc_dtype.inf + row_sum = self.acc_dtype(0) + correction_factor = self.acc_dtype(1) + common_params.p_cor_pipeline.producer_acquire(p_cor_producer_state) + + # no mask applied + while k_tile_count > 1: + ( + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max, + row_sum, + correction_factor, + ) = self.softmax( + common_params, + softmax_params, + k_index, + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max, + row_sum, + correction_factor, + False, + False, + ) + k_index = k_index + 1 + k_tile_count = k_tile_count - 1 + + # mask applied + if cutlass.const_expr(common_params.mAccO is not None): + ( + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max, + row_sum, + correction_factor, + ) = self.softmax( + common_params, + softmax_params, + k_index, + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max, + row_sum, + correction_factor, + k_index == k_tile_total - 1, + True, + ) + else: + ( + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max, + row_sum, + correction_factor, + ) = self.softmax( + common_params, + softmax_params, + k_index, + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max, + row_sum, + correction_factor, + True, + True, + ) + + return mma_s_consumer_state, p_mma_producer_state, p_cor_producer_state + + @cute.jit + def correction( + self, + common_params: SimpleNamespace, + epilogue_params: SimpleNamespace, + k_tile_count: cutlass.Int32, + p_cor_consumer_state: pipeline.PipelineState, + mma_o_consumer_state: pipeline.PipelineState, + ) -> tuple[pipeline.PipelineState, pipeline.PipelineState]: + """Compute warp to compute the result of softmax, rescale, and epilogue. Updates the related pipeline states. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param epilogue_params: The epilogue parameters + :type epilogue_params: SimpleNamespace + :param k_index: The index of the k-tile + :type k_index: cutlass.Int32 + :param k_tile_count: The number of k-tiles + :type k_tile_count: cutlass.Int32 + :param p_cor_consumer_state: The P correction consumer state + :type p_cor_consumer_state: pipeline.PipelineState + :param mma_o_consumer_state: The MMA o consumer state + :type mma_o_consumer_state: pipeline.PipelineState + + :return: The P correction consumer state, and the MMA o consumer state + :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState] + """ + + k_tile_count_init = k_tile_count + while k_tile_count > 0: + p_cor_consumer_state, row_sum, row_max, correction_factor, no_correction = ( + self.get_correction_factor(common_params, p_cor_consumer_state) + ) + if k_tile_count_init != k_tile_count: + mma_o_consumer_state = self.rescale( + common_params, + mma_o_consumer_state, + correction_factor, + no_correction, + ) + k_tile_count = k_tile_count - 1 + if k_tile_count == 0: + mma_o_consumer_state = self.epilogue( + common_params, + epilogue_params, + mma_o_consumer_state, + row_sum, + row_max, + ) + + return p_cor_consumer_state, mma_o_consumer_state + + @cute.jit + def exchange_p_cor_metadata( + self, + common_params: SimpleNamespace, + softmax_params: SimpleNamespace, + correction_factor: cutlass.Float32, + row_sum: cutlass.Float32, + row_max: cutlass.Float32, + row_max_new: cutlass.Float32, + tAcc: cute.Tensor, + tidx: cutlass.Int32, + p_cor_producer_state: pipeline.PipelineState, + ) -> pipeline.PipelineState: + """Compute the correction factor for the last k tile.""" + no_correction = 0 + if ( + row_max_new - row_max + ) * softmax_params.softmax_scale_log2 <= self.skip_correction_threshold: + no_correction = 1 + row_max_new = row_max + + # pad for 4x32b + corr_layout = cute.make_layout( + (tAcc.shape[0], (4, tAcc.shape[1][1]), self.mma_s_stage), + stride=(tAcc.stride[0], (1, tAcc.stride[1][1]), 4), + ) + tCor = cute.make_tensor( + common_params.tmem_ptr + self.correction_factor_offset, + corr_layout, + ) + cCor = cute.make_identity_tensor(tCor.shape) + corr_tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(4)), self.acc_dtype + ) + corr_tmem_store_tiled_copy = tcgen05.make_tmem_copy(corr_tmem_store_atom, tCor) + corr_tmem_store_thr_copy = corr_tmem_store_tiled_copy.get_slice(tidx) + cCor_for_copy = corr_tmem_store_thr_copy.partition_S(cCor) + tCor_for_copy = corr_tmem_store_thr_copy.partition_D(tCor) + rCor = cute.make_fragment_like( + cCor_for_copy[None, None, None, 0], self.acc_dtype + ) + rCor_int = cute.make_tensor( + cute.recast_ptr(rCor.iterator, dtype=cutlass.Int32), rCor.layout + ) + rCor[0] = row_sum + rCor[1] = row_max_new + rCor[2] = correction_factor + rCor_int[3] = no_correction + + cute.copy( + corr_tmem_store_tiled_copy, + rCor, + tCor_for_copy[None, None, None, p_cor_producer_state.index], + ) + # fence between tmem store and correction warp + cute.arch.fence_view_async_tmem_store() + common_params.p_cor_pipeline.producer_commit(p_cor_producer_state) + p_cor_producer_state.advance() + return p_cor_producer_state, row_max_new + + @cute.jit + def softmax( + self, + common_params: SimpleNamespace, + softmax_params: SimpleNamespace, + k_index: cutlass.Int32, + mma_s_consumer_state: pipeline.PipelineState, + p_mma_producer_state: pipeline.PipelineState, + p_cor_producer_state: pipeline.PipelineState, + row_max: cutlass.Float32, + row_sum: cutlass.Float32, + correction_factor: cutlass.Float32, + is_last_tile: bool, + is_local_last_tile: cutlass.Boolean, + ) -> tuple[ + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + cutlass.Float32, + cutlass.Float32, + cutlass.Float32, + ]: + """Softmax for one k-tile. Updates the related pipeline states and returns the computed results. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param softmax_params: The softmax parameters + :type softmax_params: SimpleNamespace + :param k_index: The index of the k-tile + :type k_index: cutlass.Int32 + :param mma_s_consumer_state: The MMA s consumer state + :type mma_s_consumer_state: pipeline.PipelineState + :param p_mma_producer_state: The P MMA producer state + :type p_mma_producer_state: pipeline.PipelineState + :param p_cor_producer_state: The P correction producer state + :type p_cor_producer_state: pipeline.PipelineState + :param row_max: The row max + :type row_max: cutlass.Float32 + :param row_sum: The row sum + :type row_sum: cutlass.Float32 + :param correction_factor: The correction factor + :type correction_factor: cutlass.Float32 + :param is_last_tile: Whether the last tile + :type is_last_tile: bool + :param is_local_last_tile: Whether the last tile is local + :type is_local_last_tile: cutlass.Boolean + + :return: The MMA s consumer state, the P MMA producer state, the P correction producer state, the row max, the row sum, and the correction factor + :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, cutlass.Float32, cutlass.Float32, cutlass.Float32] + """ + + softmax_params.p_mma_pipeline.producer_acquire(p_mma_producer_state) + softmax_params.mma_s_pipeline.consumer_wait(mma_s_consumer_state) + + # load S from tmem + tStS_shape = softmax_params.tiled_mma_qk.partition_shape_C( + cute.select(self.mma_qk_tiler, mode=[0, 1]) + ) + tStS_staged_fake = softmax_params.tiled_mma_qk.make_fragment_C( + cute.append(tStS_shape, self.mma_s_stage) + ) + tStS_staged = cute.make_tensor(common_params.tmem_ptr, tStS_staged_fake.layout) + tStS = tStS_staged[None, None, None, mma_s_consumer_state.index] + + tAcc = tStS[(None, None), 0, 0] + cta_qk_tiler = ( + self.mma_qk_tiler[0] // self.cluster_shape_mnk[0], + self.mma_qk_tiler[1], + self.mma_qk_tiler[2], + ) + cS = cute.make_identity_tensor(cute.select(cta_qk_tiler, mode=[0, 1])) + + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype + ) + tmem_tiled_copy = tcgen05.make_tmem_copy(tmem_load_atom, tAcc) + + tidx = common_params.tidx % (self.num_compute_warps * self.threads_per_warp) + + tmem_thr_copy = tmem_tiled_copy.get_slice(tidx) + tTR_tAcc = tmem_thr_copy.partition_S(tAcc) + tTR_tS = tmem_thr_copy.partition_D(cS) + + tTR_rAcc = cute.make_fragment_like(tTR_tS, self.acc_dtype) + + row_max_new = row_max + arch = BaseDSL._get_dsl().get_arch_enum() + if cutlass.const_expr(arch >= Arch.sm_100 and arch <= Arch.sm_100f): + cute.copy(tmem_tiled_copy, tTR_tAcc, tTR_rAcc) + for i in cutlass.range_constexpr(cute.size(tTR_rAcc)): + if is_last_tile: + tTR_rAcc[i] = ( + tTR_rAcc[i] + if cute.elem_less( + tTR_tS[i][1] + self.mma_qk_tiler[1] * k_index, + common_params.K, + ) + else -self.acc_dtype.inf + ) + # reduction for row_max + row_max_new = tTR_rAcc.load().reduce(cute.ReductionOp.MAX, row_max_new, 0) + + elif cutlass.const_expr(arch >= Arch.sm_103 and arch <= Arch.sm_103f): + tmem_load_red_atom = cute.make_copy_atom( + tcgen05.copy.LdRed32x32bOp( + tcgen05.copy.Repetition(64), redOp=tcgen05.TmemLoadRedOp.MAX + ), + self.acc_dtype, + ) + tmem_red_tiled_copy = tcgen05.make_tmem_copy(tmem_load_red_atom, tAcc) + tmem_red_thr_copy = tmem_red_tiled_copy.get_slice(tidx) + tTR_tAcc_red = tmem_red_thr_copy.partition_S(tAcc) + tTR_tS_red = tmem_red_thr_copy.partition_D(cS) + tTR_rAcc_red = cute.make_fragment_like(tTR_tS_red, self.acc_dtype) + tTR_rMax = cute.make_rmem_tensor( + cute.make_layout((1, tTR_tS_red.shape[1], tTR_tS_red.shape[2])), + self.acc_dtype, + ) + cute.copy( + tmem_red_tiled_copy, + tTR_tAcc_red, + (tTR_rAcc_red, tTR_rMax), + ) + tTR_rAcc = cute.make_tensor(tTR_rAcc_red.iterator, tTR_rAcc.layout) + if is_last_tile: + for i in cutlass.range_constexpr(cute.size(tTR_rAcc)): + tTR_rAcc[i] = ( + tTR_rAcc[i] + if cute.elem_less( + tTR_tS[i][1] + self.mma_qk_tiler[1] * k_index, + common_params.K, + ) + else -self.acc_dtype.inf + ) + # reduction for row_max + row_max_new = tTR_rAcc.load().reduce( + cute.ReductionOp.MAX, row_max_new, 0 + ) + else: + row_max_new = cute.arch.fmax(row_max_new, tTR_rMax[0]) + + # if warps in N is 2, reduce row_max across warps (0, 1) and (2, 3) + if cutlass.const_expr(self.warps_in_n == 2): + common_params.smem_exchange[tidx] = row_max_new + self.softmax_exchange_sync_bar.wait() + row_max_new = cute.arch.fmax( + row_max_new, + common_params.smem_exchange[ + (tidx + 64) % (self.num_compute_warps * self.threads_per_warp) + ], + ) + + # find correction factor + correction_factor = cute.math.exp2( + (row_max - row_max_new) * softmax_params.softmax_scale_log2, fastmath=True + ) + # split kv case + if cutlass.const_expr(not is_local_last_tile): + p_cor_producer_state, row_max_new = self.exchange_p_cor_metadata( + common_params, + softmax_params, + correction_factor, + row_sum, + row_max, + row_max_new, + tAcc, + tidx, + p_cor_producer_state, + ) + + # softmax + fma_b = softmax_params.softmax_scale_log2 + fma_c = (0.0 - row_max_new) * softmax_params.softmax_scale_log2 + + for i in cutlass.range(cute.size(tTR_rAcc), vectorize=True, unroll_full=True): + tTR_rAcc[i] = tTR_rAcc[i] * fma_b + fma_c + tTR_rAcc[i] = cute.math.exp2(tTR_rAcc[i], fastmath=True) + + tTR_rS = cute.make_fragment_like(tTR_tS, self.q_dtype) + + # quantize + tTR_rS.store(tTR_rAcc.load().to(self.q_dtype)) + + # create sP + sP = softmax_params.sP[None, None, None, (None, p_mma_producer_state.index)] + sP_mk_view = cute.make_tensor( + sP.iterator, + cute.make_layout( + ( + (sP.shape[0][0], sP.shape[1]), + (sP.shape[0][1], sP.shape[2], sP.shape[3]), + ), + stride=( + (sP.stride[0][0], sP.stride[1]), + (sP.stride[0][1], sP.stride[2], sP.stride[3]), + ), + ), + ) + # {$nv-internal-release begin} + # TODO: figure out if we could use A tmem for pv. + # {$nv-internal-release end} + # change to PISL + sP_wo_swizzle_iter = cute.recast_ptr(sP.iterator, swizzle_=None) + swizzle_bits = ( + int(math.log2(self.mma_pv_tiler[2] * self.q_dtype.width // 8 // 32)) + 1 + ) + swizzle_base = 3 if self.q_dtype.width == 16 else 4 + sP_swizzle = cute.make_swizzle(swizzle_bits, swizzle_base, 3) + sP_mk_view = cute.make_tensor( + sP_wo_swizzle_iter, + cute.make_composed_layout(sP_swizzle, 0, sP_mk_view.layout), + ) + universal_copy_bits = 128 + smem_copy_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.q_dtype, + num_bits_per_copy=universal_copy_bits, + ) + smem_tiled_copy = cute.make_tiled_copy_D(smem_copy_atom, tmem_tiled_copy) + smem_thr_copy = smem_tiled_copy.get_slice(tidx) + rP_copy_view = smem_thr_copy.retile(tTR_rS) + sP_copy_view = smem_thr_copy.partition_D(sP_mk_view) + cute.copy(smem_tiled_copy, rP_copy_view, sP_copy_view) + + # fence between smem store and mma o + cute.arch.fence_view_async_shared() + softmax_params.p_mma_pipeline.producer_commit(p_mma_producer_state) + p_mma_producer_state.advance() + + # row_sum, using `add_packed_f32x2` to reduce the number of instructions + row_sum = row_sum * correction_factor + row_sum_vec = (0.0, 0.0) + for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2): + row_sum_vec = cute.arch.add_packed_f32x2( + row_sum_vec, (tTR_rAcc[i], tTR_rAcc[i + 1]) + ) + row_sum = row_sum_vec[0] + row_sum_vec[1] + row_sum + + # split kv case + if cutlass.const_expr(is_local_last_tile): + p_cor_producer_state, row_max_new = self.exchange_p_cor_metadata( + common_params, + softmax_params, + correction_factor, + row_sum, + row_max, + row_max_new, + tAcc, + tidx, + p_cor_producer_state, + ) + + # store correction factor/row_sum/row_max to tmem for correction warp + common_params.p_cor_pipeline.producer_acquire(p_cor_producer_state) + + # fence between tmem load and mma s + cute.arch.fence_view_async_tmem_load() + + softmax_params.mma_s_pipeline.consumer_release(mma_s_consumer_state) + mma_s_consumer_state.advance() + + return ( + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max_new, + row_sum, + correction_factor, + ) + + @cute.jit + def _tmem_load_partition( + self, common_params: SimpleNamespace, tiled_mma_pv: cute.TiledMma, iter_n: int + ) -> tuple[ + cute.TiledMma, cute.TiledMma, cute.TiledMma, cute.TiledMma, cute.TiledMma + ]: + """Tensor memory load partition for rescale and epilogue. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param tiled_mma_pv: The tiled mma pv + :type tiled_mma_pv: cute.TiledMma + :param iter_n: The iteration number + :type iter_n: int + + :return: The tiled mma pv, the tiled mma pv, the tiled mma pv, the tiled mma pv, the tiled mma pv + :rtype: tuple[cute.TiledMma, cute.TiledMma, cute.TiledMma, cute.TiledMma, cute.TiledMma] + """ + + tOtO_shape = tiled_mma_pv.partition_shape_C( + cute.select(self.mma_pv_tiler, mode=[0, 1]) + ) + tOtO = tiled_mma_pv.make_fragment_C(tOtO_shape) + tOtO_layout = cute.append( + tOtO.layout, + cute.make_layout( + common_params.L // self.mma_pv_tiler[1], + stride=self.mma_pv_tiler[1] // self.warps_in_n, + ), + ) + tOtO = cute.make_tensor( + common_params.tmem_ptr + self.tmem_o_offset, tOtO_layout + ) + tOtO = tOtO[None, None, None, iter_n] + + tAcc = tOtO[(None, None), 0, 0] + + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype + ) + tmem_load_tiled_copy = tcgen05.make_tmem_copy(tmem_load_atom, tAcc) + # {$nv-internal-release begin} + # TODO: supports size() on tiled copy. + # {$nv-internal-release end} + tmem_load_thr_copy = tmem_load_tiled_copy.get_slice( + common_params.tidx % (self.num_compute_warps * self.threads_per_warp) + ) + + cta_pv_tiler = ( + self.mma_pv_tiler[0] // self.cluster_shape_mnk[0], + self.mma_pv_tiler[1], + self.mma_pv_tiler[2], + ) + # Flatten divide and partition global tensors for O + cta_pv_tiler_mn = cute.select(cta_pv_tiler, mode=[0, 1]) + + gO = None + if cutlass.const_expr(common_params.mAccO is not None): + gO = cute.local_tile( + common_params.mAccO[None, common_params.blk_coord[3], None, None, None], + cta_pv_tiler_mn, + ( + common_params.blk_coord[0], + iter_n, + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + ) + cO = cute.local_tile( + cute.make_identity_tensor( + common_params.mAccO[ + None, common_params.blk_coord[3], None, None, None + ].shape + ), + cta_pv_tiler_mn, + ( + common_params.blk_coord[0], + iter_n, + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + ) + else: + gO = cute.local_tile( + common_params.mO, + cta_pv_tiler_mn, + ( + common_params.blk_coord[0], + iter_n, + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + ) + cO = cute.local_tile( + cute.make_identity_tensor(common_params.mO.shape), + cta_pv_tiler_mn, + ( + common_params.blk_coord[0], + iter_n, + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + ) + tTR_tAcc = tmem_load_thr_copy.partition_S(tAcc) + tTR_gO = tmem_load_thr_copy.partition_D(gO) + tTR_cO = tmem_load_thr_copy.partition_D(cO) + tTR_rAcc = cute.make_fragment_like(tTR_gO, self.acc_dtype) + return tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc # type: ignore[return-value] + + def get_correction_factor( + self, + common_params: SimpleNamespace, + p_cor_consumer_state: pipeline.PipelineState, + ) -> tuple[ + pipeline.PipelineState, + cutlass.Float32, + cutlass.Float32, + cutlass.Float32, + cutlass.Int32, + ]: + """Get the correction factor from the P correction consumer state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param p_cor_consumer_state: The P correction consumer state + :type p_cor_consumer_state: pipeline.PipelineState + + :return: The P correction consumer state, the row_sum, the row_max, and the correction factor + :rtype: tuple[pipeline.PipelineState, cutlass.Float32, cutlass.Float32, cutlass.Float32, cutlass.Int32] + """ + common_params.p_cor_pipeline.consumer_wait(p_cor_consumer_state) + tidx = common_params.tidx % (self.num_compute_warps * self.threads_per_warp) + # load correction factor + _, tAcc, _, _, _, _ = self._tmem_load_partition( + common_params, common_params.tiled_mma_pv, 0 + ) + corr_layout = cute.make_layout( + (tAcc.shape[0], (4, tAcc.shape[1][1]), self.p_cor_stage), + stride=(tAcc.stride[0], (1, tAcc.stride[1][1]), 4), + ) + tCor = cute.make_tensor( + common_params.tmem_ptr + self.correction_factor_offset, corr_layout + ) + cCor = cute.make_identity_tensor(tCor.shape) + corr_tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(4)), self.acc_dtype + ) + corr_tmem_load_tiled_copy = tcgen05.make_tmem_copy(corr_tmem_load_atom, tCor) + corr_tmem_load_thr_copy = corr_tmem_load_tiled_copy.get_slice(tidx) + tCor_for_copy = corr_tmem_load_thr_copy.partition_S(tCor) + cCor_for_copy = corr_tmem_load_thr_copy.partition_D(cCor) + rCor = cute.make_fragment_like( + cCor_for_copy[None, None, None, 0], self.acc_dtype + ) + rCor_int = cute.make_tensor( + cute.recast_ptr(rCor.iterator, dtype=cutlass.Int32), rCor.layout + ) + cute.copy( + corr_tmem_load_tiled_copy, + tCor_for_copy[None, None, None, p_cor_consumer_state.index], + rCor, + ) + row_sum = rCor[0] + row_max = rCor[1] + correction_factor = rCor[2] + no_correction = rCor_int[3] + + common_params.p_cor_pipeline.consumer_release(p_cor_consumer_state) + p_cor_consumer_state.advance() + return p_cor_consumer_state, row_sum, row_max, correction_factor, no_correction + + @cute.jit + def rescale( + self, + common_params: SimpleNamespace, + mma_o_consumer_state: pipeline.PipelineState, + correction_factor: cutlass.Float32, + no_correction: cutlass.Int32, + ) -> pipeline.PipelineState: + """Rescale for one k-tile. Updates the related pipeline state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param mma_o_consumer_state: The mma o consumer state + :type mma_o_consumer_state: pipeline.PipelineState + :param correction_factor: The correction factor + :type correction_factor: cutlass.Float32 + :param no_correction: Whether to apply correction factor + :type no_correction: cutlass.Int32 + + :return: The MMA o consumer state + :rtype: pipeline.PipelineState + """ + skip_correction = cute.arch.vote_all_sync(no_correction == 1) + common_params.mma_o_pipeline.consumer_wait(mma_o_consumer_state) + if not skip_correction: + for iter_n in cutlass.range_constexpr(self.iterations_pv_n): + # tmem load tiled copy and partition results. + tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc = ( + self._tmem_load_partition( + common_params, common_params.tiled_mma_pv, iter_n + ) + ) + + # tmem store tiled copy + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype + ) + tmem_store_tiled_copy = tcgen05.make_tmem_copy(tmem_store_atom, tAcc) + + # load o + cute.copy(tmem_load_tiled_copy, tTR_tAcc, tTR_rAcc) + # rescale, using `mul_packed_f32x2` to reduce the number of instructions + for i in cutlass.range( + cute.size(tTR_rAcc), vectorize=True, unroll_full=True + ): + tTR_rAcc[i] = tTR_rAcc[i] * correction_factor + + # store o to tensor memory for next k tile + cute.copy(tmem_store_tiled_copy, tTR_rAcc, tTR_tAcc) + + cute.arch.fence_view_async_tmem_store() + common_params.mma_o_pipeline.consumer_release(mma_o_consumer_state) + mma_o_consumer_state.advance() + + return mma_o_consumer_state + + @cute.jit + def epilogue( + self, + common_params: SimpleNamespace, + epilogue_params: SimpleNamespace, + mma_o_consumer_state: pipeline.PipelineState, + row_sum: cutlass.Float32, + row_max: cutlass.Float32, + ) -> pipeline.PipelineState: + """Epilogue for one k-tile. Updates the related pipeline state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param epilogue_params: The epilogue parameters + :type epilogue_params: SimpleNamespace + :param mma_o_consumer_state: The mma o consumer state + :type mma_o_consumer_state: pipeline.PipelineState + :param row_sum: The row sum + :type row_sum: cutlass.Float32 + :param row_max: The row max + :type row_max: cutlass.Float32 + + :return: The MMA o consumer state + :rtype: pipeline.PipelineState + """ + + tidx = common_params.tidx % (self.num_compute_warps * self.threads_per_warp) + + # exchange row_sum between warps (0, 1) and (2, 3) + if cutlass.const_expr(self.warps_in_n == 2): + common_params.smem_exchange[tidx] = row_sum + self.epilogue_exchange_sync_bar.wait() + # (64, 2) + row_sum = ( + row_sum + + common_params.smem_exchange[ + (tidx + 64) % (self.num_compute_warps * self.threads_per_warp) + ] + ) + # mma_o pipeline consumer wait + common_params.mma_o_pipeline.consumer_wait(mma_o_consumer_state) + for iter_n in cutlass.range_constexpr(self.iterations_pv_n): + # tmem load tiled copy and partition results. + tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc = ( + self._tmem_load_partition( + common_params, common_params.tiled_mma_pv, iter_n + ) + ) + + # load o + cute.copy(tmem_load_tiled_copy, tTR_tAcc, tTR_rAcc) + + # apply output scale and normalize by row_sum + for i in cutlass.range( + cute.size(tTR_rAcc), vectorize=True, unroll_full=True + ): + tTR_rAcc[i] = ( + tTR_rAcc[i] + * epilogue_params.output_scale + * cute.arch.rcp_approx(row_sum) + ) + + # store o to global memory + tR2G_rO_src = None + tR2G_rO_dst = tTR_gO + if cutlass.const_expr(common_params.mAccO is None): + tR2G_rO_src = cute.make_fragment_like(tTR_gO, self.o_dtype) + # using final output dtype for o + tR2G_rO_src.store(tTR_rAcc.load().to(self.o_dtype)) + else: + # using accumulate dtype for o + tR2G_rO_src = tTR_rAcc + + if cute.elem_less(tTR_cO[0][0], common_params.H): + cute.autovec_copy( + tR2G_rO_src, + tR2G_rO_dst, + l1c_evict_priority=cute.nvgpu.CacheEvictionPriority.NO_ALLOCATE, + ) + + # store the lse to global memory + cta_pv_tiler = ( + self.mma_pv_tiler[0] // self.cluster_shape_mnk[0], + self.mma_pv_tiler[1], + self.mma_pv_tiler[2], + ) + gLSE = None + cLSE = None + if cutlass.const_expr(epilogue_params.mAccLSE is None): + gLSE = cute.local_tile( + epilogue_params.mLSE, + (cta_pv_tiler[0], 1, 1), + ( + common_params.blk_coord[0], + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + (1, 1, 1), + ) + cLSE = cute.local_tile( + cute.make_identity_tensor(epilogue_params.mLSE.shape), + (cta_pv_tiler[0], 1, 1), + ( + common_params.blk_coord[0], + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + (1, 1, 1), + ) + + else: + gLSE = cute.local_tile( + epilogue_params.mAccLSE[ + None, common_params.blk_coord[3], None, None + ], + (cta_pv_tiler[0], 1, 1), + ( + common_params.blk_coord[0], + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + (1, 1, 1), + ) + cLSE = cute.local_tile( + cute.make_identity_tensor( + epilogue_params.mAccLSE[ + None, common_params.blk_coord[3], None, None + ].shape + ), + (cta_pv_tiler[0], 1, 1), + ( + common_params.blk_coord[0], + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + (1, 1, 1), + ) + lse = ( + cute.math.log2(row_sum, fastmath=True) + + epilogue_params.softmax_scale_log2 * row_max + ) + if cutlass.const_expr(self.warps_in_n == 2): + if cute.elem_less(cLSE[tidx][0], common_params.H): + gLSE[tidx] = lse + + cute.arch.fence_view_async_tmem_load() + common_params.mma_o_pipeline.consumer_release(mma_o_consumer_state) + mma_o_consumer_state.advance() + + return mma_o_consumer_state + + def make_and_init_load_pt_pipeline(self, load_pt_mbar_ptr): + """Create and initialize the load page table pipeline. + + :param load_pt_mbar_ptr: The load page table mbar pointer + :type load_pt_mbar_ptr: cute.Tensor + + :return: The load page table pipeline + :rtype: pipeline.PipelineAsync + """ + load_pt_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * len([self.load_pt_warp_id]), + ) + load_pt_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * len([self.load_tma_warp_id]), + ) + return pipeline.PipelineCpAsync.create( + barrier_storage=load_pt_mbar_ptr, + num_stages=self.load_pt_stage, + producer_group=load_pt_producer_group, + consumer_group=load_pt_consumer_group, + defer_sync=True, + ) + + def make_and_init_load_qkv_pipeline( + self, load_qkv_mbar_ptr, cta_layout_vmnk, load_stages, tx_count + ) -> pipeline.PipelineTmaUmma: + """Create and initialize the tma load qkv pipeline. + + :param load_qkv_mbar_ptr: The load qkv mbar pointer + :type load_qkv_mbar_ptr: cute.Tensor + :param cta_layout_vmnk: The cta layout vmnk + :type cta_layout_vmnk: tuple[int, int, int] + :param load_stages: The load stages + :type load_stages: list[int] + :param tx_count: The tx count + :type tx_count: int + + :return: The tma load qkv pipeline + :rtype: pipeline.PipelineTmaUmma + """ + load_qkv_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.load_tma_warp_id]) + ) + load_qkv_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + return pipeline.PipelineTmaUmma.create( + barrier_storage=load_qkv_mbar_ptr, + num_stages=load_stages, + producer_group=load_qkv_producer_group, + consumer_group=load_qkv_consumer_group, + tx_count=tx_count, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + + def make_and_init_mma_s_pipeline( + self, mma_s_mbar_ptr, cta_layout_vmnk + ) -> pipeline.PipelineUmmaAsync: + """Create and initialize the mma s pipeline. + + :param mma_s_mbar_ptr: The mma s mbar pointer + :type mma_s_mbar_ptr: cute.Tensor + :param cta_layout_vmnk: The cta layout vmnk + :type cta_layout_vmnk: tuple[int, int, int] + + :return: The mma s pipeline + :rtype: pipeline.PipelineUmmaAsync + """ + + mma_s_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + consumer_thread_size = ( + self.threads_per_warp + * len(self.compute_warp_ids) + * self.cluster_shape_mnk[0] + ) + mma_s_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + consumer_thread_size, + ) + return pipeline.PipelineUmmaAsync.create( + barrier_storage=mma_s_mbar_ptr, + num_stages=self.mma_s_stage, + producer_group=mma_s_producer_group, + consumer_group=mma_s_consumer_group, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + + def make_and_init_p_mma_pipeline( + self, p_mma_mbar_ptr, cta_layout_vmnk + ) -> pipeline.PipelineAsyncUmma: + """Create and initialize the p mma pipeline. + + :param p_mma_mbar_ptr: The p mma mbar pointer + :type p_mma_mbar_ptr: cute.Tensor + :param cta_layout_vmnk: The cta layout vmnk + :type cta_layout_vmnk: tuple[int, int, int] + + :return: The p mma pipeline + :rtype: pipeline.PipelineAsyncUmma + """ + + producer_thread_size = ( + self.threads_per_warp + * len(self.compute_warp_ids) + * self.cluster_shape_mnk[0] + ) + p_mma_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + producer_thread_size, + ) + p_mma_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + return pipeline.PipelineAsyncUmma.create( + barrier_storage=p_mma_mbar_ptr, + num_stages=self.p_mma_stage, + producer_group=p_mma_producer_group, + consumer_group=p_mma_consumer_group, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + + def make_and_init_p_cor_pipeline( + self, p_cor_mbar_ptr + ) -> pipeline.PipelineAsyncUmma: + """Create and initialize the p correction pipeline. + + :param p_cor_mbar_ptr: The p correction mbar pointer + :type p_cor_mbar_ptr: cute.Tensor + + :return: The p correction pipeline + :rtype: pipeline.PipelineAsyncUmma + """ + + producer_thread_size = self.threads_per_warp * len(self.compute_warp_ids) + p_cor_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + producer_thread_size, + ) + p_cor_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + producer_thread_size, + ) + return pipeline.PipelineAsync.create( + barrier_storage=p_cor_mbar_ptr, + num_stages=self.p_cor_stage, + producer_group=p_cor_producer_group, + consumer_group=p_cor_consumer_group, + defer_sync=True, + ) + + def make_and_init_mma_o_pipeline( + self, mma_o_mbar_ptr, cta_layout_vmnk + ) -> pipeline.PipelineUmmaAsync: + """Create and initialize the mma o pipeline. + + :param mma_o_mbar_ptr: The mma o mbar pointer + :type mma_o_mbar_ptr: cute.Tensor + :param cta_layout_vmnk: The cta layout vmnk + :type cta_layout_vmnk: tuple[int, int, int] + + :return: The mma o pipeline + :rtype: pipeline.PipelineUmmaAsync + """ + + mma_o_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + consumer_thread_size = ( + self.threads_per_warp + * len(self.compute_warp_ids) + * self.cluster_shape_mnk[0] + ) + mma_o_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + consumer_thread_size, + ) + return pipeline.PipelineUmmaAsync.create( + barrier_storage=mma_o_mbar_ptr, + num_stages=self.mma_o_stage, + producer_group=mma_o_producer_group, + consumer_group=mma_o_consumer_group, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + + @staticmethod + def _compute_grid( + o: cute.Tensor, + split_kv: cutlass.Int32, + cluster_shape_mnk: Tuple[int, int, int], + max_active_clusters: int, + is_persistent: bool, + ) -> Tuple[MLAStaticTileSchedulerParams, Tuple[int, int, int]]: + """Compute grid shape for the output tensor C. + + :param c: The output tensor C + :type c: cute.Tensor + :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type cta_tile_shape_mnk: tuple[int, int, int] + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] + + :return: Tile scheduler parameters and grid shape. + :rtype: tuple[MLAStaticTileSchedulerParams, tuple[int, int, int]] + """ + o_shape = o.shape + tile_sched_params = create_mla_static_tile_scheduler_params( + is_persistent, + cute.size(o_shape[3]), + cute.size(o_shape[2]), + cluster_shape_mnk, + split_kv, + ) + grid = MLAStaticTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + + return tile_sched_params, grid + + @staticmethod + def get_workspace_size( + H: int, + S: int, + D: int, + B: int, + split_kv: int, + acc_dtype: Type[cutlass.Numeric], + ) -> int: + """Get the extra workspace(device memory) size for the MLA kernel when split_kv is not 1. + + :param H: The height of the output tensor C + :type H: int + :param S: The sequence length of the output tensor C + :type S: int + :param D: The depth of the output tensor C + :type D: int + :param B: The batch size of the output tensor C + :type B: int + :param split_kv: The split key-value of the output tensor C + :type split_kv: int + :param acc_dtype: The data type of the output tensor C + :type acc_dtype: Type[cutlass.Numeric] + + :return: The workspace size for the MLA kernel + :rtype: int + """ + if split_kv == 1: + return 0 + # Decode packs heads into a physical 128-wide MMA-M tile. For H < 128, + # split-KV partials can still touch the padded head lanes before + # reduction, so size the workspace for max(H, 128). Mirrors the same + # padding applied in initialize_workspace(). See #3235. + workspace_heads = max(H, 128) + return B * workspace_heads * S * split_kv * (D + 1) * acc_dtype.width // 8 + + @cute.jit + def initialize_workspace( + self, + H: cutlass.Int32, + D: cutlass.Int32, + S: cutlass.Int32, + B: cutlass.Int32, + split_kv: cutlass.Int32, + acc_dtype: Type[cutlass.Numeric], + workspace: cute.Tensor, + ) -> tuple[cute.Tensor, cute.Tensor]: + """Initialize the workspace for the MLA kernel. Construct the intermediate tensors + acc_o and acc_lse. + + :param H: The height of the output tensor C + :type H: cutlass.Int32 + :param D: The depth of the output tensor C + :type D: cutlass.Int32 + :param S: The sequence length of the output tensor C + :type S: cutlass.Int32 + :param B: The batch size of the output tensor C + :type B: cutlass.Int32 + :param split_kv: The split key-value of the output tensor C + :type split_kv: cutlass.Int32 + :param acc_dtype: The data type of the output tensor C + :type acc_dtype: Type[cutlass.Numeric] + :param workspace: The workspace tensor + :type workspace: cute.Tensor + + :return: The output tensor C and the workspace tensor + :rtype: tuple[cute.Tensor, cute.Tensor] + """ + acc_o, acc_lse = None, None + if cutlass.const_expr(workspace is not None): + # Pad head dim to the physical 128-wide MMA-M tile. Without this, + # H<128 split-KV partials write past the workspace. See #3235. + workspace_H = cutlass.max(H, cutlass.Int32(128)) + align = 256 // self.q_dtype.width + acc_o_layout = cute.make_layout( + (workspace_H, split_kv, D, S, B), + stride=( + cute.assume(split_kv * D, align), + cute.assume(D, align), + 1, + cute.assume(split_kv * workspace_H * D, align), + cute.assume(workspace_H * split_kv * S * D, align), + ), + ) + acc_o_iter = cute.recast_ptr(workspace.iterator, dtype=acc_dtype) + acc_o = cute.make_tensor(acc_o_iter, acc_o_layout) + acc_lse_layout = cute.make_layout( + (workspace_H, split_kv, S, B), + stride=( + split_kv, + 1, + workspace_H * split_kv, + workspace_H * split_kv * S, + ), + ) + acc_lse_iter = cute.recast_ptr( + workspace.iterator + cute.cosize(acc_o_layout) * acc_dtype.width // 8, + dtype=acc_dtype, + ) + acc_lse = cute.make_tensor(acc_lse_iter, acc_lse_layout) + return acc_o, acc_lse + + @staticmethod + def can_implement( + B: int, + S: int, + K: int, + H: int, + L: int, + R: int, + in_dtype: Type[cutlass.Numeric], + out_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + lse_dtype: Type[cutlass.Numeric], + mma_qk_tiler_mn: Tuple[int, int], + mma_pv_tiler_mn: Tuple[int, int], + is_persistent: bool, + is_var_seq: bool, + is_var_split_kv: bool, + page_size: int, + ) -> bool: + """Check if the MLA kernel can be implemented. + + :param B: The batch size of the output tensor C + :type B: int + :param S: The sequence length of the output tensor C + :type S: int + :param K: The width of the output tensor KV + :type K: int + :param H: The number of heads of the output tensor C + :type H: int + :param L: The number of latent dimensions of the tensor KV + :type L: int + :param R: The number of rope dimensions of the tensor C_rope + :type R: int + :param in_dtype: The data type of the input tensor + :type in_dtype: Type[cutlass.Numeric] + :param out_dtype: The data type of the output tensor + :type out_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param lse_dtype: The data type of the log-sum-exp + :type lse_dtype: Type[cutlass.Numeric] + :param mma_qk_tiler_mn: The tile shape of the query-key matrix multiplication + :type mma_qk_tiler_mn: Tuple[int, int] + :param mma_pv_tiler_mn: The tile shape of the probability-value matrix multiplication + :type mma_pv_tiler_mn: Tuple[int, int] + :param is_persistent: Whether to use persistent kernel optimization + :type is_persistent: bool + :param is_var_seq: Whether to use variable sequence length + :type is_var_seq: bool + :param is_var_split_kv: Whether to use variable split_kv + :type is_var_split_kv: bool + :param page_size: The page size of the page table + :type page_size: int + + :return: Whether the MLA kernel can be implemented + :rtype: bool + """ + if L != 512 or R != 64: + return False + if in_dtype not in [cutlass.Float16, cutlass.BFloat16]: + return False + if out_dtype not in [cutlass.Float16, cutlass.BFloat16]: + return False + if acc_dtype != cutlass.Float32 or lse_dtype != cutlass.Float32: + return False + # page size equals 1 is prohibited by tma specification, not 128B aligned. + if mma_qk_tiler_mn[1] % page_size != 0 or page_size == 1: + return False + if mma_qk_tiler_mn[0] != mma_pv_tiler_mn[0] or mma_qk_tiler_mn[0] != 128: + return False + if is_var_split_kv and not is_var_seq: + return False + if H > 128: + return False + if S < 1 or S > 4: + return False + if K <= 0: + return False + return True + + +def run( + batch_size: int, + seq_len_q: int, + seq_len_k: int, + num_heads: int, + latent_dim: int, + rope_dim: int, + in_dtype: Type[cutlass.Numeric], + out_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + lse_dtype: Type[cutlass.Numeric], + mma_qk_tiler_mn: Tuple[int, int], + mma_pv_tiler_mn: Tuple[int, int], + split_kv: int, + is_persistent: bool, + is_var_seq: bool, + is_var_split_kv: bool, + page_size: int, + softmax_scale: float, + output_scale: float, + skip_correction_threshold: float, + tolerance: float, + warmup_iterations: int, + iterations: int, + skip_ref_check: bool, + use_cold_l2: bool, + enable_pdl: bool = False, + **kwargs, +): + """Execute Multi-Head Latent Attention (MLA) on Blackwell architecture and validate results. + + This function creates random input tensors for query latent/rope, compressed latent/rope, and value, + then performs the complete MLA computation pipeline. It supports configurable data types, tiling parameters, + page table, variable sequence length, and variable split_kv. Results can be validated against a PyTorch reference + implementation or run multiple times for performance measurement. + + :param batch_size: Batch size + :type batch_size: int + :param seq_len_q: Sequence length of Q + :type seq_len_q: int + :param seq_len_k: Sequence length of K + :type seq_len_k: int + :param num_heads: Number of heads + :type num_heads: int + :param latent_dim: dimension of query/compressed latent + :type latent_dim: int + :param rope_dim: dimension of query/compressed rope + :type rope_dim: int + :param in_dtype: Input data type for query/compressed latent/rope tensors + :type in_dtype: Type[cutlass.Numeric] + :param out_dtype: Output data type for attention output + :type out_dtype: Type[cutlass.Numeric] + :param acc_dtype: Accumulator data type for query-key matrix multiplication + :type acc_dtype: Type[cutlass.Numeric] + :param lse_dtype: Accumulator data type for log-sum-exp + :type lse_dtype: Type[cutlass.Numeric] + :param mma_qk_tiler_mn: Matrix multiply accumulate tile shape (M, N) for query-key matrix multiplication + :type mma_qk_tiler_mn: Tuple[int, int] + :param mma_pv_tiler_mn: Matrix multiply accumulate tile shape (M, N) for probability-value matrix multiplication + :type mma_pv_tiler_mn: Tuple[int, int] + :param split_kv: Split key-value + :type split_kv: int + :param is_persistent: Whether to use persistent kernel optimization + :type is_persistent: bool + :param is_var_seq: Whether to use variable sequence length + :type is_var_seq: bool + :param is_var_split_kv: Whether to use variable split_kv + :type is_var_split_kv: bool + :param page_size: Page size of the page table + :type page_size: int + :param softmax_scale: Attention score scaling factor + :type softmax_scale: float + :param output_scale: Output scaling factor + :type output_scale: float + :param skip_correction_threshold: Threshold to skip correction + :type skip_correction_threshold: float + :param tolerance: Maximum acceptable error for validation + :type tolerance: float + :param warmup_iterations: Number of warmup iterations + :type warmup_iterations: int + :param iterations: Number of iterations to run for performance testing + :type iterations: int + :param skip_ref_check: Skip validation against reference implementation + :type skip_ref_check: bool + :param use_cold_l2: Whether to use cold L2 cache + :type use_cold_l2: bool + + :raises ValueError: If input shapes are incompatible or head dimension is unsupported + :raises RuntimeError: If GPU is unavailable for computation + """ + + print("Running Blackwell MLA test with:") + print(f" batch_size: {batch_size}") + print(f" seq_len_q: {seq_len_q}") + print(f" seq_len_k: {seq_len_k}") + print(f" num_heads: {num_heads}") + print(f" latent_dim: {latent_dim}") + print(f" rope_dim: {rope_dim}") + print(f" in_dtype: {in_dtype}") + print(f" out_dtype: {out_dtype}") + print(f" acc_dtype: {acc_dtype}") + print(f" mma_qk_tiler_mn: {mma_qk_tiler_mn}") + print(f" mma_pv_tiler_mn: {mma_pv_tiler_mn}") + print(f" split_kv: {split_kv}") + print(f" is_persistent: {is_persistent}") + print(f" is_var_seq: {is_var_seq}") + print(f" is_var_split_kv: {is_var_split_kv}") + print(f" page_size: {page_size}") + print(f" softmax_scale: {softmax_scale}") + print(f" output_scale: {output_scale}") + print(f" skip_correction_threshold: {skip_correction_threshold}") + print(f" tolerance: {tolerance}") + print(f" warmup_iterations: {warmup_iterations}") + print(f" iterations: {iterations}") + print(f" skip_ref_check: {skip_ref_check}") + print(f" use_cold_l2: {use_cold_l2}") + + # Prepare pytorch tensors: Q, K, V (random from 0 to 2) and O (all zero) + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + if not BlackwellMultiHeadLatentAttentionForwardFP16.can_implement( + batch_size, + seq_len_q, + seq_len_k, + num_heads, + latent_dim, + rope_dim, + in_dtype, + out_dtype, + acc_dtype, + lse_dtype, + mma_qk_tiler_mn, + mma_pv_tiler_mn, + is_persistent, + is_var_seq, + is_var_split_kv, + page_size, + ): + raise TypeError( + f"Unsupported testcase {batch_size}, {seq_len_q}, {seq_len_k}, {num_heads}, {latent_dim}, {rope_dim}, {in_dtype}, {out_dtype}, {acc_dtype}, {lse_dtype}, {mma_qk_tiler_mn}, {mma_pv_tiler_mn}, {split_kv}, {is_persistent}, {is_var_seq}, {is_var_split_kv}, {page_size}" + ) + + torch.manual_seed(1111) + + def create_data_tensor( + B, + HK, + D, + dtype, + is_dynamic_layout=True, + page_table=None, + cache_seqs=None, + is_lse=False, + seq_len_q=None, + ): + shape = (B, HK, D) + if page_table is not None: + if cache_seqs is not None: + max_seq_len = torch.max(cache_seqs) + shape = (B * ceil_div(max_seq_len, page_size), page_size, D) + else: + shape = (B * ceil_div(HK, page_size), page_size, D) + + if seq_len_q is not None: + shape = (B, seq_len_q, HK, D) + + # Contiguous row-major: last dim has stride 1 (highest stride_order value = fastest) + if is_lse: + shape = (B, seq_len_q, HK) + leading_dim = 2 + stride_order = (0, 1, 2) + elif seq_len_q is not None: + leading_dim = 3 + stride_order = (0, 1, 2, 3) + else: + leading_dim = 2 + stride_order = (0, 1, 2) + + init_config = cutlass.torch.RandomInitConfig(min_val=-2, max_val=2) + + torch_dtype = ( + cutlass_torch.dtype(dtype) if dtype != cutlass.Float8E4M3FN else torch.int8 + ) + + # Create contiguous dtype torch tensor (cpu) — no permute + torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( + shape, + torch_dtype, + init_type=cutlass.torch.TensorInitType.RANDOM, + init_config=init_config, + ) + + # Create dtype torch tensor (gpu) + torch_tensor_gpu = torch_tensor_cpu.cuda() + + # Create f32 torch tensor (cpu) + f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32) + + # Create dtype cute tensor (gpu) + cute_tensor = from_dlpack(torch_tensor_gpu, assumed_align=16) + cute_tensor.element_type = dtype + if is_dynamic_layout: + cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=leading_dim) + if not is_lse: + cute_tensor = cute_tensor.mark_compact_shape_dynamic( + mode=leading_dim, + stride_order=stride_order, + divisibility=(128 // dtype.width), + ) + + cute_tensor = cutlass_torch.convert_cute_tensor( + f32_torch_tensor, + cute_tensor, + dtype, + is_dynamic_layout=is_dynamic_layout, + ) + + return f32_torch_tensor, cute_tensor, torch_tensor_gpu + + def create_cache_seqs(batch_size, seq_len_k, is_var_seq): + cache_seqs_ref = torch.ones(batch_size, dtype=torch.int32) * seq_len_k + cache_seqs_gpu = cache_seqs_ref.cuda() + cache_seqs = from_dlpack(cache_seqs_gpu, assumed_align=16).mark_layout_dynamic() + if is_var_seq: + max_seq_len = seq_len_k + min_seq_len = int(seq_len_k * 0.8) + cache_seqs_ref = cutlass_torch.create_and_permute_torch_tensor( + (batch_size,), + torch.int32, + init_type=cutlass.torch.TensorInitType.RANDOM, + init_config=cutlass.torch.RandomInitConfig( + min_val=min_seq_len, max_val=max_seq_len + 1 + ), + ) + cache_seqs_gpu = cache_seqs_ref.cuda() + cache_seqs = from_dlpack( + cache_seqs_gpu, + assumed_align=16, + ).mark_layout_dynamic() + return cache_seqs_ref, cache_seqs, cache_seqs_gpu + + def create_page_table(batch_size, seq_len_k, is_var_seq, page_size): + max_seq_len = seq_len_k if not is_var_seq else torch.max(cache_seqs_ref) + page_count = ceil_div(max_seq_len, page_size) + page_table_ref = torch.empty([batch_size, page_count], dtype=torch.int32) + # use transposed index for page table to make sure the value is in bound of `batch_size * seq_len_block`. In practice, the value could be any positive values. This setting is only for testing purpose. + for b in range(batch_size): + for j in range(page_count): + page_table_ref[b, j] = b + j * batch_size + page_table_gpu = page_table_ref.cuda() # contiguous [B, page_count] + page_table = from_dlpack(page_table_gpu, assumed_align=16).mark_layout_dynamic( + leading_dim=1 + ) + return page_table_ref, page_table, page_table_gpu + + def create_block_split_kvs( + batch_size, + split_kv, + cache_seqs_ref, + is_var_split_kv, + mma_qk_tiler_mn, + cluster_shape_mnk, + max_active_clusters, + ): + block_split_kvs_ref, block_split_kvs, block_split_kvs_gpu = None, None, None + # check if split_kv is valid otherwise do auto setting of split_kv + if is_var_split_kv: + block_split_kvs_ref = torch.zeros([batch_size], dtype=torch.int32) + for b in range(batch_size): + block_split_kvs_ref[b] = ( + BlackwellMultiHeadLatentAttentionForwardFP16.get_split_kv( + batch_size, + seq_len_q, + cache_seqs_ref[b].item(), + mma_qk_tiler_mn, + max_active_clusters * cluster_shape_mnk[0], + ) + ) + split_kv = torch.max(block_split_kvs_ref).item() + block_split_kvs_gpu = block_split_kvs_ref.cuda() + block_split_kvs = from_dlpack( + block_split_kvs_gpu, assumed_align=16 + ).mark_layout_dynamic() + elif split_kv <= 0: + split_kv = BlackwellMultiHeadLatentAttentionForwardFP16.get_split_kv( + batch_size, + seq_len_q, + cache_seqs_ref[0].item(), + mma_qk_tiler_mn, + max_active_clusters * cluster_shape_mnk[0], + ) + return split_kv, block_split_kvs_ref, block_split_kvs, block_split_kvs_gpu + + def create_workspace( + num_heads, seq_len_q, latent_dim, batch_size, split_kv, acc_dtype + ): + workspace_size = ( + BlackwellMultiHeadLatentAttentionForwardFP16.get_workspace_size( + num_heads, + seq_len_q, + latent_dim, + batch_size, + split_kv, + acc_dtype, + ) + ) + + workspace, workspace_torch = None, None + if workspace_size > 0: + workspace_torch = torch.empty([workspace_size], dtype=torch.int8).cuda() + workspace = from_dlpack(workspace_torch, assumed_align=32) + return workspace, workspace_torch + + cache_seqs_ref, cache_seqs, cache_seqs_torch = create_cache_seqs( + batch_size, seq_len_k, is_var_seq + ) + page_table_ref, page_table, page_table_torch = create_page_table( + batch_size, seq_len_k, is_var_seq, page_size + ) + cluster_shape_mnk = (2, 1, 1) + hardware_info = utils.HardwareInfo() + max_active_clusters = hardware_info.get_max_active_clusters( + cluster_shape_mnk[0] * cluster_shape_mnk[1] + ) + split_kv, block_split_kvs_ref, block_split_kvs, block_split_kvs_torch = ( + create_block_split_kvs( + batch_size, + split_kv, + cache_seqs_ref, + is_var_split_kv, + mma_qk_tiler_mn, + cluster_shape_mnk, + max_active_clusters, + ) + ) + + q_latent_ref, q_latent, q_latent_torch = create_data_tensor( + batch_size, + num_heads, + latent_dim, + in_dtype, + is_dynamic_layout=True, + seq_len_q=seq_len_q, + ) + q_rope_ref, q_rope, q_rope_torch = create_data_tensor( + batch_size, + num_heads, + rope_dim, + in_dtype, + is_dynamic_layout=True, + seq_len_q=seq_len_q, + ) + + c_latent_ref, c_latent, c_latent_torch = create_data_tensor( + batch_size, + seq_len_k, + latent_dim, + in_dtype, + is_dynamic_layout=True, + page_table=page_table, + cache_seqs=cache_seqs_ref, + ) + c_rope_ref, c_rope, c_rope_torch = create_data_tensor( + batch_size, + seq_len_k, + rope_dim, + in_dtype, + is_dynamic_layout=True, + page_table=page_table, + cache_seqs=cache_seqs_ref, + ) + o_ref, o, o_torch = create_data_tensor( + batch_size, + num_heads, + latent_dim, + out_dtype, + is_dynamic_layout=True, + seq_len_q=seq_len_q, + ) + lse_ref, lse, lse_torch = create_data_tensor( + batch_size, + num_heads, + 1, + lse_dtype, + is_dynamic_layout=True, + is_lse=True, + seq_len_q=seq_len_q, + ) + workspace, workspace_torch = create_workspace( + num_heads, seq_len_q, latent_dim, batch_size, split_kv, acc_dtype + ) + + mla = BlackwellMultiHeadLatentAttentionForwardFP16( + acc_dtype, + lse_dtype, + mma_qk_tiler_mn, + mma_pv_tiler_mn, + max_active_clusters, + page_size, + skip_correction_threshold, + is_persistent, + is_var_seq, + is_var_split_kv, + enable_pdl, + ) + + # Get current CUDA stream from PyTorch + torch_stream = torch.cuda.current_stream() + # Get the raw stream pointer as a CUstream + stream = cuda.CUstream(torch_stream.cuda_stream) + + # compile mla kernel + compiled_mla = cute.compile( + mla, + q_latent, + q_rope, + c_latent, + c_rope, + page_table, + o, + lse, + workspace, + split_kv, + cache_seqs, + block_split_kvs, + softmax_scale, + output_scale, + stream, + options="--opt-level 2", + ) + + def torch_reference_mla( + q_latent, + q_rope, + c_latent, + c_rope, + page_table, + cache_seqs, + softmax_scale=1.0, + output_scale=1.0, + ): + # Ref tensors are now contiguous: + # q_latent/q_rope: [B, S_q, H, D] + # c_latent/c_rope: [num_pages, page_size, D] + # Concat along last dim and reshape for SDPA [B, S_q, H, D_total] + q_ref = torch.cat([q_latent, q_rope], dim=3) + # KV cache: concat along last dim, already [num_pages, page_size, D_total] + page_count = page_table_ref.shape[1] + k_ref_paged = torch.cat([c_latent, c_rope], dim=2).reshape( + batch_size * page_count, page_size, latent_dim + rope_dim + ) + v_ref_paged = c_latent.reshape(batch_size * page_count, page_size, latent_dim) + + if is_var_seq: + max_seq_len = torch.max(cache_seqs_ref) + else: + max_seq_len = seq_len_k + + k_ref = torch.zeros([batch_size, 1, max_seq_len, latent_dim + rope_dim]) + v_ref = torch.zeros([batch_size, 1, max_seq_len, latent_dim]) + k_ref = torch.index_select( + k_ref_paged, 0, torch.flatten(page_table_ref) + ).reshape(batch_size, 1, -1, latent_dim + rope_dim)[:, :, :max_seq_len, :] + v_ref = torch.index_select( + v_ref_paged, 0, torch.flatten(page_table_ref) + ).reshape(batch_size, 1, -1, latent_dim)[:, :, :max_seq_len, :] + for b in range(batch_size): + k_ref[b, :, cache_seqs_ref[b] :, :] = 0 + v_ref[b, :, cache_seqs_ref[b] :, :] = 0 + import torch.nn.functional as F + + o_ref = F.scaled_dot_product_attention( + q_ref, + k_ref, + v_ref, + attn_mask=None, + dropout_p=0.0, + scale=softmax_scale, + is_causal=False, + ) + s_ref = torch.einsum("bhld,bhsd->bhls", q_ref, k_ref) + s_ref_max, s_ref_max_pos = torch.max(s_ref, dim=-1, keepdim=True) + softmax_scale_log2 = LOG2_E * softmax_scale + s_ref_sum = torch.sum( + torch.exp2((s_ref - s_ref_max) * softmax_scale_log2), dim=-1, keepdim=True + ) + + lse_ref = s_ref_max * softmax_scale_log2 + torch.log2(s_ref_sum) + lse_ref = lse_ref.squeeze(3) # [B, S_q, H] + o_ref = o_ref * output_scale + # o_ref already [B, S_q, H, D_latent] — matches contiguous output layout + + return o_ref, lse_ref + + if skip_correction_threshold > 0.0: + print( + "Skipping correction verification since skip_correction_threshold is greater than 0.0..." + ) + skip_ref_check = True + if not skip_ref_check: + # Execute kernel once for reference checking + compiled_mla( + q_latent, + q_rope, + c_latent, + c_rope, + page_table, + o, + lse, + workspace, + split_kv, + cache_seqs, + block_split_kvs, + softmax_scale, + output_scale, + stream, + ) + torch.cuda.synchronize() + + print("Verifying results...") + if in_dtype == cutlass.Float8E4M3FN: + tolerance = 0.13 + o_ref, lse_ref = torch_reference_mla( + q_latent_ref, + q_rope_ref, + c_latent_ref, + c_rope_ref, + page_table, + cache_seqs, + softmax_scale, + output_scale, + ) + + if out_dtype in [cutlass.Float8E5M2, cutlass.Float8E4M3FN]: + # {$nv-internal-release begin} + # todo: not sure why, but the below `cute.testing.convert` will cause bus error occasionally in local and ci. + # {$nv-internal-release end} + # convert o back to f32 for comparison + o_fp32, o_fp32_torch = cutlass_torch.cute_tensor_like( + torch.empty(*o_torch.shape, dtype=torch.float32), + cutlass.Float32, + is_dynamic_layout=True, + assumed_align=16, + ) + cute.testing.convert(o, o_fp32) + o = o_fp32_torch.cpu() + ref_fp8, _ = cutlass_torch.cute_tensor_like( + torch.empty(*o_ref.shape, dtype=torch.uint8), + out_dtype, + is_dynamic_layout=True, + assumed_align=16, + ) + o_ref_gpu = o_ref.cuda() + o_ref_f32 = from_dlpack(o_ref_gpu).mark_layout_dynamic(leading_dim=3) + + # convert ref : f32 -> fp8 -> f32 + cute.testing.convert(o_ref_f32, ref_fp8) + cute.testing.convert(ref_fp8, o_ref_f32) + + o_ref = o_ref_gpu.cpu() + else: + o = o_torch.cpu().to(torch.float32) + lse = lse_torch.cpu() + lse_ref = lse_ref.to(cutlass.torch.dtype(lse_dtype)) + # Assert close results + torch.testing.assert_close(o, o_ref, atol=tolerance, rtol=1e-05) + torch.testing.assert_close(lse, lse_ref, atol=tolerance, rtol=1e-05) + print("Results verified successfully!") + + def generate_tensors(): + _, cache_seqs, _ = create_cache_seqs(batch_size, seq_len_k, is_var_seq) + _, page_table, _ = create_page_table( + batch_size, seq_len_k, is_var_seq, page_size + ) + _split_kv, _, block_split_kvs, _ = create_block_split_kvs( + batch_size, + split_kv, + cache_seqs_ref, + is_var_split_kv, + mma_qk_tiler_mn, + cluster_shape_mnk, + max_active_clusters, + ) + + _, q_latent, _ = create_data_tensor( + batch_size, + num_heads, + latent_dim, + in_dtype, + is_dynamic_layout=True, + seq_len_q=seq_len_q, + ) + _, q_rope, _ = create_data_tensor( + batch_size, + num_heads, + rope_dim, + in_dtype, + is_dynamic_layout=True, + seq_len_q=seq_len_q, + ) + + _, c_latent, _ = create_data_tensor( + batch_size, + seq_len_k, + latent_dim, + in_dtype, + is_dynamic_layout=True, + page_table=page_table, + cache_seqs=cache_seqs_ref, + ) + _, c_rope, _ = create_data_tensor( + batch_size, + seq_len_k, + rope_dim, + in_dtype, + is_dynamic_layout=True, + page_table=page_table, + cache_seqs=cache_seqs_ref, + ) + _, o, _ = create_data_tensor( + batch_size, + num_heads, + latent_dim, + out_dtype, + is_dynamic_layout=True, + seq_len_q=seq_len_q, + ) + _, lse, _ = create_data_tensor( + batch_size, + num_heads, + 1, + lse_dtype, + is_dynamic_layout=True, + is_lse=True, + seq_len_q=seq_len_q, + ) + workspace, workspace_torch = create_workspace( + num_heads, seq_len_q, latent_dim, batch_size, _split_kv, acc_dtype + ) + return testing.JitArguments( + q_latent, + q_rope, + c_latent, + c_rope, + page_table, + o, + lse, + workspace, + _split_kv, + cache_seqs, + block_split_kvs, + softmax_scale, + output_scale, + stream, + ) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + q_latent_torch.numel() * q_latent_torch.element_size() + + q_rope_torch.numel() * q_rope_torch.element_size() + + c_latent_torch.numel() * c_latent_torch.element_size() + + c_rope_torch.numel() * c_rope_torch.element_size() + + o_torch.numel() * o_torch.element_size() + + lse_torch.numel() * lse_torch.element_size() + + cache_seqs_torch.numel() * cache_seqs_torch.element_size() + ) + one_workspace_bytes += ( + page_table_torch.numel() * page_table_torch.element_size() + ) + if is_var_split_kv: + one_workspace_bytes += ( + block_split_kvs_torch.numel() * block_split_kvs_torch.element_size() + ) + if workspace_torch is not None: + one_workspace_bytes += ( + workspace_torch.numel() * workspace_torch.element_size() + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + avg_time_us = testing.benchmark( + compiled_mla, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + return avg_time_us # Return execution time in microseconds diff --git a/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp8.py b/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp8.py new file mode 100644 index 0000000000..a4fcd119e4 --- /dev/null +++ b/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp8.py @@ -0,0 +1,4230 @@ +# Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +import math +from typing import Type, Tuple, Optional +from types import SimpleNamespace + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +import cutlass.cute.testing as testing +from cutlass.cute.nvgpu import tcgen05 +from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode + +# Compat shim: setmaxregister_{decrease,increase} added in cutlass-dsl 4.4; +# older versions only have the deprecated warpgroup_reg_{dealloc,alloc}. +_setmaxregister_decrease = getattr( + cute.arch, + "setmaxregister_decrease", + getattr(cute.arch, "warpgroup_reg_dealloc", None), +) +_setmaxregister_increase = getattr( + cute.arch, + "setmaxregister_increase", + getattr(cute.arch, "warpgroup_reg_alloc", None), +) + +# Compat shim: get_max_tmem_alloc_cols added in cutlass-dsl 4.4; +# older versions don't have it, so we provide a fallback implementation. +_TMEM_MAX_ALLOC_COLUMNS_MAP = {"sm_100": 512, "sm_103": 512, "sm_120": 512} + + +def _get_max_tmem_alloc_cols(compute_capability: str) -> int: + if hasattr(cute.arch, "get_max_tmem_alloc_cols"): + return cute.arch.get_max_tmem_alloc_cols(compute_capability) + if compute_capability not in _TMEM_MAX_ALLOC_COLUMNS_MAP: + raise ValueError(f"Unsupported compute capability: {compute_capability}") + return _TMEM_MAX_ALLOC_COLUMNS_MAP[compute_capability] + + +import cutlass.cute.nvgpu.cpasync as cpasync +import cutlass.utils as utils +import cutlass.pipeline as pipeline +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cute.runtime import from_dlpack +from cutlass.cute.arch import Arch +from cutlass.cutlass_dsl import BaseDSL + + +from .mla_helpers import ( + ceil_div, + MAX_SPLITS, + LOG2_E, + MLAStaticTileScheduler, + MLAStaticTileSchedulerParams, + create_mla_static_tile_scheduler, + create_mla_static_tile_scheduler_params, +) + +""" +A Multi-Head Latent Attention (MLA) example using fp8 as input/output for the NVIDIA Blackwell SM100 architecture using CUTE DSL + +This example demonstrates an implementation of inference of multi-head latent attention using a TMA + Blackwell +SM100 TensorCore warp-specialized persistent kernel. The implementation integrates the (Qc + Qr)*(Kc + Kr)^T +matrix multiplication, softmax normalization, and softmax((Qc + Qr)*(Kc + Kr)^T)*Vc into a single kernel. +The kernel provides support for page table storage and variable-length KV cache sequences. It implements KV splitting +functionality to minimize latency when processing long KV sequences. + +The kernel implements key optimizations including: +- Warp specialization for different computation phases (load, MMA, softmax, correction, epilogue) +- Pipeline stages between different warps for overlapping computation and memory access +- Support for different precision data types +- Two sub-kernels (split KV kernel and reduction kernel) that enable split KV processing + +To run this example: + +.. code-block:: bash + + python examples/blackwell/mla_fp8.py \ + --batch_size 4 --latent_dim 512 --rope_dim 64 \ + --num_heads 128 --seq_len_q 1 --seq_len_k 1024 \ + --in_dtype Float8E4M3FN --out_dtype Float8E4M3FN \ + --acc_dtype Float32 --lse_dtype Float32 \ + --is_var_seq --is_var_split_kv \ + --is_persistent + +The above example runs Multi-Head Latent Attention (MLA) with the following configuration: +- Batch size: 4 +- Sequence length of Q: 1 +- Sequence length of K: 1024 +- Latent dimension: 512 +- RoPE dimension: 64 +- Number of heads: 128 +- Data types: Float8E4M3FN (input), Float8E4M3FN (output), Float32 (accumulation and LSE) + +It utilizes page table storage for the KV cache and enables both variable-length KV cache sequences +and variable split KV processing with persistent scheduling. + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/blackwell/mla_fp8.py \ + --batch_size 4 --latent_dim 512 --rope_dim 64 \ + --num_heads 128 --seq_len_q 1 --seq_len_k 1024 \ + --in_dtype Float8E4M3FN --out_dtype Float8E4M3FN \ + --acc_dtype Float32 --lse_dtype Float32 \ + --is_var_seq --is_var_split_kv \ + --is_persistent --warmup_iterations 3 \ + --iterations 10 --skip_ref_check + +Constraints for this example: +* Data type requirements: + - Input/output: Float8E4M3FN + - Accumulation and LSE: Float32 +* Fixed architecture parameters: + - Number of attention heads: 128 + - Latent dimension: 512 + - RoPE dimension: 64 +* Input query modes should be (NumHeads, LatentDim/RopeDim, SeqLenQ, BatchSize) +* Input kv latent/rope modes should be (SeqLenK, LatentDim/RopeDim, BatchSize) +* Query sequence length must be 1-4 +* Only supports 2-CTA instructions +* Variable sequence length requires page table storage enabled +""" + + +class BlackwellMultiHeadLatentAttentionForwardFP8: + def __init__( + self, + acc_dtype: Type[cutlass.Numeric], + lse_dtype: Type[cutlass.Numeric], + mma_qk_tiler_mn: Tuple[int, int], + mma_pv_tiler_mn: Tuple[int, int], + max_active_clusters: int, + page_size: int, + skip_correction_threshold: float, + is_persistent: bool, + is_var_seq: bool, + is_var_split_kv: bool, + enable_pdl: bool, + ): + """Initializes the configuration for a Blackwell Multi-Head Latent Attention (MLA) kernel. + + :param acc_dtype: Data type for accumulation S and O + :type acc_dtype: Type[cutlass.Numeric] + :param lse_dtype: Data type for output LSE + :type lse_dtype: Type[cutlass.Numeric] + :param mma_s_tiler: The (H, K) tile shape of the MMA instruction for S + :type mma_s_tiler: Tuple[int, int] + :param mma_p_tiler: The (H, D) tile shape of the MMA instruction for P + :type mma_p_tiler: Tuple[int, int] + :param max_active_clusters: Maximum number of active clusters + :type max_active_clusters: int + :param page_size: The page size + :type page_size: int + :param skip_correction_threshold: Threshold to skip correction + :type skip_correction_threshold: float + :param is_persistent: Whether to use persistent kernel mode + :type is_persistent: bool + :param is_var_seq: Whether to use variable sequence length + :type is_var_seq: bool + :param is_var_split_kv: Whether to use variable split KV + :type is_var_split_kv: bool + :param enable_pdl: Whether to use PDL + :type enable_pdl: bool + """ + + self.latent_dim = 512 + self.rope_dim = 64 + self.acc_dtype = acc_dtype + self.lse_dtype = lse_dtype + self.mma_qk_tiler_mn = mma_qk_tiler_mn + self.mma_pv_tiler_mn = mma_pv_tiler_mn + self.max_active_clusters = max_active_clusters + self.skip_correction_threshold = skip_correction_threshold + self.is_persistent = is_persistent + self.page_size = page_size + self.is_var_seq = is_var_seq + self.is_var_split_kv = is_var_split_kv + self.enable_pdl = enable_pdl + self.cluster_shape_mnk = (2, 1, 1) + self.use_2cta_instrs = True + # When using 2 CTAs with m=128: warps 0-1 handle accumulation for first half [0, n/2), + # while warps 2-3 handle accumulation for second half [n/2, n) + self.warps_in_n = 2 + self.num_compute_warps = 4 + self.threads_per_warp = 32 + mma_qk_tiler_k = self.rope_dim * 2 + self.mma_qk_tiler = ( + self.mma_qk_tiler_mn[0], + self.mma_qk_tiler_mn[1], + mma_qk_tiler_k, + ) + self.mma_qk_rope_tiler = ( + self.mma_qk_tiler_mn[0], + self.mma_qk_tiler_mn[1], + self.rope_dim, + ) + self.mma_pv_tiler = ( + self.mma_pv_tiler_mn[0], + self.mma_pv_tiler_mn[1], + self.mma_qk_tiler[1] * self.mma_qk_tiler[2] // self.mma_pv_tiler_mn[1], + ) + self.iterations_qk_latent = self.latent_dim // self.mma_qk_tiler[2] + self.iterations_qk_rope = 1 + self.iterations_qk = self.iterations_qk_latent + self.iterations_qk_rope + self.iterations_pv_k = self.mma_qk_tiler[1] // self.mma_pv_tiler[2] + self.iterations_pv_n = self.latent_dim // self.mma_pv_tiler[1] + + # Set specialized warp ids + self.compute_warp_ids = (0, 1, 2, 3) + self.correction_warp_ids = (4, 5, 6, 7) + self.mma_warp_id = 8 + self.load_tma_k_warp_id = 9 + self.load_tma_v_warp_id = 10 + self.empty_warp_ids = (11,) + self.threads_per_cta = self.threads_per_warp * len( + ( + self.mma_warp_id, + self.load_tma_k_warp_id, + self.load_tma_v_warp_id, + *self.compute_warp_ids, + *self.correction_warp_ids, + *self.empty_warp_ids, + ) + ) + + # register settings + self.softmax_reg_num = 192 + self.correction_reg_num = 256 + self.other_reg_num = 48 + # Named barriers + self.tmem_ptr_sync_bar = pipeline.NamedBarrier( + barrier_id=1, + num_threads=( + self.threads_per_warp + + self.threads_per_warp * self.num_compute_warps * 2 + ), + ) + self.softmax_exchange_sync_bar = pipeline.NamedBarrier( + barrier_id=2, num_threads=(self.threads_per_warp * self.num_compute_warps) + ) + self.epilogue_exchange_sync_bar = pipeline.NamedBarrier( + barrier_id=3, num_threads=(self.threads_per_warp * self.num_compute_warps) + ) + + def _setup_attributes(self): + """Set up configurations and parameters for the MLA kernel operation. + + This method initializes and configures various attributes required for the + execution of the multi-head latent attention kernel, mainly about the pipeline stages: + + - Sets up staging parameters for Q, K, V inputs and accumulator data + - Configures pipeline stages for softmax, correction, and epilogue operations + """ + + self.load_q_stage = 1 + self.load_k_stage = 3 + self.load_v_stage = 2 + self.mma_s_stage = 2 + self.p_mma_stage = 2 + self.p_cor_stage = 2 + self.mma_o_stage = 2 + + self.tmem_o_offset = self.mma_s_stage * self.mma_qk_tiler[1] // self.warps_in_n + self.correction_factor_offset = ( + self.tmem_o_offset + self.latent_dim // self.warps_in_n + ) + + @cute.jit + def __call__( + self, + q_latent: cute.Tensor, + q_rope: cute.Tensor, + c_latent: cute.Tensor, + c_rope: cute.Tensor, + page_table: cute.Tensor, + o: cute.Tensor, + lse: cute.Tensor, + workspace: cute.Tensor, + split_kv: cutlass.Int32, + cache_seqs: Optional[cute.Tensor], + block_split_kvs: Optional[cute.Tensor], + softmax_scale: cutlass.Float32, + output_scale: cutlass.Float32, + stream: cuda.CUstream, + ): + """Execute the Multi-Head Latent Attention operation on the provided tensors. + + The method handles: + 1. Initialization of workspace for temporary split KV buffers + 2. Validation of tensor data types + 3. Initialization of hardware-specific parameters and memory layouts + 4. Configuration of TMA (Tensor Memory Access) operations + 5. Grid and work scheduling computation + 6. Kernel launch(split KV kernel and reduction kernel) with appropriate parameters + + :param q_latent: The query tensor with shape [batch_size, seq_len_q, num_head, latent_dim] (contiguous) + :type q_latent: cute.Tensor + :param q_rope: The query RoPE tensor with shape [batch_size, seq_len_q, num_head, rope_dim] (contiguous) + :type q_rope: cute.Tensor + :param c_latent: The key tensor with shape [num_pages, page_size, latent_dim] (contiguous) + :type c_latent: cute.Tensor + :param c_rope: The key RoPE tensor with shape [num_pages, page_size, rope_dim] (contiguous) + :type c_rope: cute.Tensor + :param page_table: The page table tensor with shape [batch_size, page_count] (contiguous) + :type page_table: cute.Tensor + :param o: The output tensor with shape [batch_size, seq_len_q, num_head, latent_dim] (contiguous) + :type o: cute.Tensor + :param lse: The LSE tensor with shape [batch_size, seq_len_q, num_head] (contiguous) + :type lse: cute.Tensor + :param workspace: The workspace tensor with 1-d shape prepared for acc_o and acc_lse + :type workspace: cute.Tensor + :param split_kv: The scalar factor for split KV + :type split_kv: cutlass.Int32 + :param cache_seqs: The cache sequences tensor with shape [batch_size] + :type cache_seqs: cute.Tensor + :param block_split_kvs: The block split KV tensor with shape [batch_size] + :type block_split_kvs: cute.Tensor + :param softmax_scale: The scale factor for softmax + :type softmax_scale: cutlass.Float32 + :param output_scale: The scale factor for the output + :type output_scale: cutlass.Float32 + :param stream: The CUDA stream to execute the kernel on + :type stream: cuda.CUstream + + :raises TypeError: If tensor data types don't match or aren't supported + """ + + # setup static attributes before smem/grid/tma computation + self.q_dtype = q_latent.element_type + self.k_dtype = c_latent.element_type + self.v_dtype = c_latent.element_type + self.o_dtype = o.element_type + + # check type consistency + if cutlass.const_expr( + self.q_dtype != self.k_dtype or self.q_dtype != self.v_dtype + ): + raise TypeError( + f"Type mismatch: {self.q_dtype} != {self.k_dtype} or {self.q_dtype} != {self.v_dtype}" + ) + + # Reinterpret contiguous [B, S_q, H, D] as [H, D, S_q, B] + # Input stride: (S_q*H*D, H*D, D, 1) → Target: (D, 1, H*D, S_q*H*D) + def _reinterpret_4d(t): + return cute.make_tensor( + t.iterator, + cute.make_layout( + (t.shape[2], t.shape[3], t.shape[1], t.shape[0]), + stride=(t.stride[2], t.stride[3], t.stride[1], t.stride[0]), + ), + ) + + q_latent = _reinterpret_4d(q_latent) + q_rope = _reinterpret_4d(q_rope) + o = _reinterpret_4d(o) + + # Reinterpret contiguous [num_pages, page_size, D] as [page_size, D, num_pages] + # Input stride: (PS*D, D, 1) → Target: (D, 1, PS*D) + def _reinterpret_3d_kv(t): + return cute.make_tensor( + t.iterator, + cute.make_layout( + (t.shape[1], t.shape[2], t.shape[0]), + stride=(t.stride[1], t.stride[2], t.stride[0]), + ), + ) + + c_latent = _reinterpret_3d_kv(c_latent) + c_rope = _reinterpret_3d_kv(c_rope) + + # Reinterpret contiguous [B, page_count] as [page_count, B] + page_table = cute.make_tensor( + page_table.iterator, + cute.make_layout( + (page_table.shape[1], page_table.shape[0]), + stride=(page_table.stride[1], page_table.stride[0]), + ), + ) + + # Reinterpret contiguous [B, S_q, H] as [H, S_q, B] + # Input stride: (S_q*H, H, 1) → Target: (1, H, S_q*H) + lse = cute.make_tensor( + lse.iterator, + cute.make_layout( + (lse.shape[2], lse.shape[1], lse.shape[0]), + stride=(lse.stride[2], lse.stride[1], lse.stride[0]), + ), + ) + + acc_o, acc_lse = self.initialize_workspace( + q_latent.shape[0], + q_latent.shape[1], + q_latent.shape[2], + q_latent.shape[3], + split_kv, + self.acc_dtype, + workspace, + ) + + c_latent_tranpose_layout = cute.select(c_latent.layout, mode=[1, 0, 2]) + c_latent_transpose = cute.make_tensor( + c_latent.iterator, c_latent_tranpose_layout + ) + + self.q_major_mode = OperandMajorMode.K + self.k_major_mode = OperandMajorMode.K + self.v_major_mode = OperandMajorMode.MN + + self._setup_attributes() + + cta_group = tcgen05.CtaGroup.TWO + # the intermediate tensor p is from smem & k-major + p_major_mode = OperandMajorMode.K + qk_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.q_dtype, + self.q_major_mode, + self.k_major_mode, + self.acc_dtype, + cta_group, + self.mma_qk_tiler[:2], + ) + pv_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.v_dtype, + p_major_mode, + self.v_major_mode, + self.acc_dtype, + cta_group, + self.mma_pv_tiler[:2], + ) + + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), + (qk_tiled_mma.thr_id.shape,), + ) + + self.epi_tile = self.mma_pv_tiler[:2] + + q_latent_smem_layout_staged = sm100_utils.make_smem_layout_a( + qk_tiled_mma, + self.mma_qk_tiler, + self.q_dtype, + (self.iterations_qk_latent * self.load_q_stage), + ) + q_latent_smem_layout_staged = cute.logical_divide( + q_latent_smem_layout_staged, (None, None, None, self.iterations_qk_latent) + ) + q_rope_smem_layout_staged = sm100_utils.make_smem_layout_a( + qk_tiled_mma, + self.mma_qk_rope_tiler, + self.q_dtype, + self.load_q_stage, + ) + + kc_latent_smem_layout_staged = sm100_utils.make_smem_layout_b( + qk_tiled_mma, + self.mma_qk_tiler, + self.k_dtype, + (self.iterations_qk_latent * self.load_k_stage), + ) + kc_page_tile_size = min( + self.page_size, qk_tiled_mma.op.shape_mnk[0] // qk_tiled_mma.thr_id.shape + ) + kc_latent_smem_layout_staged = cute.logical_divide( + kc_latent_smem_layout_staged, (None, None, None, self.iterations_qk_latent) + ) + + kc_latent_smem_layout_for_tma = sm100_utils.make_smem_layout( + OperandMajorMode.K, + (self.mma_qk_tiler[0] // qk_tiled_mma.thr_id.shape, self.mma_qk_tiler[2]), + self.k_dtype, + (self.iterations_qk_latent * self.load_k_stage), + ) + kc_latent_smem_layout_for_tma = cute.tiled_divide( + kc_latent_smem_layout_for_tma, (kc_page_tile_size, self.mma_qk_tiler[2]) + ) + kc_latent_smem_layout_for_tma = cute.logical_divide( + kc_latent_smem_layout_for_tma, (None, None, None, self.iterations_qk_latent) + ) + + kc_rope_smem_layout_staged = sm100_utils.make_smem_layout_b( + qk_tiled_mma, + self.mma_qk_rope_tiler, + self.k_dtype, + self.load_k_stage, + ) + kc_rope_smem_layout_for_tma = sm100_utils.make_smem_layout( + OperandMajorMode.K, + ( + self.mma_qk_rope_tiler[0] // qk_tiled_mma.thr_id.shape, + self.mma_qk_rope_tiler[2], + ), + self.k_dtype, + (self.iterations_qk_rope * self.load_k_stage), + ) + kc_rope_smem_layout_for_tma = cute.tiled_divide( + kc_rope_smem_layout_for_tma, (kc_page_tile_size, self.mma_qk_rope_tiler[2]) + ) + + p_smem_layout_staged = sm100_utils.make_smem_layout_a( + pv_tiled_mma, + self.mma_pv_tiler, + self.q_dtype, + (self.iterations_pv_k * self.p_mma_stage), + ) + p_smem_layout_staged = cute.logical_divide( + p_smem_layout_staged, (None, None, None, self.iterations_pv_k) + ) + + vc_smem_layout_staged = sm100_utils.make_smem_layout_b( + pv_tiled_mma, + self.mma_pv_tiler, + self.v_dtype, + (self.iterations_pv_k * self.iterations_pv_n * self.load_v_stage), + ) + vc_smem_layout_staged = cute.logical_divide( + cute.logical_divide( + vc_smem_layout_staged, + (None, None, None, self.iterations_pv_k * self.iterations_pv_n), + ), + (None, None, None, (self.iterations_pv_n, None)), + ) + vc_page_tile_size = min(self.page_size, self.mma_pv_tiler[2]) + vc_smem_layout_for_tma = sm100_utils.make_smem_layout( + OperandMajorMode.MN, + (self.mma_pv_tiler[1] // pv_tiled_mma.thr_id.shape, self.mma_pv_tiler[2]), + self.v_dtype, + (self.iterations_pv_k * self.iterations_pv_n * self.load_v_stage), + ) + vc_smem_layout_for_tma = cute.tiled_divide( + vc_smem_layout_for_tma, + ( + pv_tiled_mma.op.shape_mnk[1] // pv_tiled_mma.thr_id.shape, + vc_page_tile_size, + ), + ) + vc_smem_layout_for_tma = cute.logical_divide( + cute.logical_divide( + vc_smem_layout_for_tma, + (None, None, None, self.iterations_pv_k * self.iterations_pv_n), + ), + (None, None, None, (self.iterations_pv_n, None)), + ) + # TMA load for Q latent and rope + tma_load_op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp(cta_group) + + q_smem_layout = cute.select(q_latent_smem_layout_staged, mode=[0, 1, 2]) + + tma_atom_q_latent, tma_tensor_q_latent = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q_latent, + q_smem_layout, + self.mma_qk_tiler, + qk_tiled_mma, + cta_layout_vmnk.shape, + ) + q_rope_smem_layout = cute.select(q_rope_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_q_rope, tma_tensor_q_rope = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q_rope, + q_rope_smem_layout, + self.mma_qk_rope_tiler, + qk_tiled_mma, + cta_layout_vmnk.shape, + ) + # TMA load for c latent and k rope + kc_smem_layout = cute.select(kc_latent_smem_layout_for_tma, mode=[0]) + tma_atom_c_latent, tma_tensor_c_latent = self.make_paged_tiled_tma_atom( + tma_load_op, + c_latent, + kc_smem_layout, + (self.mma_qk_tiler[1], self.mma_qk_tiler[2]), + qk_tiled_mma, + is_k_load=True, + ) + kc_rope_smem_layout = cute.select(kc_rope_smem_layout_for_tma, mode=[0]) + tma_atom_c_rope, tma_tensor_c_rope = self.make_paged_tiled_tma_atom( + tma_load_op, + c_rope, + kc_rope_smem_layout, + (self.mma_qk_rope_tiler[1], self.mma_qk_rope_tiler[2]), + qk_tiled_mma, + is_k_load=True, + ) + + # TMA load for c latent transpose + vc_smem_layout = cute.select(vc_smem_layout_for_tma, mode=[0]) + tma_atom_c_latent_transpose, tma_tensor_c_latent_transpose = ( + self.make_paged_tiled_tma_atom( + tma_load_op, + c_latent_transpose, + vc_smem_layout, + (self.mma_pv_tiler[1], self.mma_pv_tiler[2]), + pv_tiled_mma, + is_k_load=False, + ) + ) + + q_latent_copy_size = ( + cute.size_in_bytes(self.q_dtype, q_smem_layout) + * cute.size(qk_tiled_mma.thr_id.shape) + * self.iterations_qk_latent + ) + q_rope_copy_size = ( + cute.size_in_bytes(self.q_dtype, q_rope_smem_layout) + * cute.size(qk_tiled_mma.thr_id.shape) + * self.iterations_qk_rope + ) + kc_latent_copy_size = ( + cute.size_in_bytes( + self.k_dtype, + cute.select(kc_latent_smem_layout_staged, mode=[0, 1, 2]), + ) + * cute.size(qk_tiled_mma.thr_id.shape) + * self.iterations_qk_latent + ) + kc_rope_copy_size = ( + cute.size_in_bytes( + self.k_dtype, + cute.select(kc_rope_smem_layout_staged, mode=[0, 1, 2]), + ) + * cute.size(qk_tiled_mma.thr_id.shape) + * self.iterations_qk_rope + ) + vc_copy_size = ( + cute.size_in_bytes( + self.v_dtype, cute.select(vc_smem_layout_staged, mode=[0, 1, 2]) + ) + * cute.size(pv_tiled_mma.thr_id.shape) + * self.iterations_pv_n + * self.iterations_pv_k + ) + + self.tma_copy_q_bytes = q_latent_copy_size + q_rope_copy_size + self.tma_copy_kc_bytes = kc_latent_copy_size + kc_rope_copy_size + self.tma_copy_vc_bytes = vc_copy_size + + tile_sched_params, grid = self._compute_grid( + o, + split_kv, + self.cluster_shape_mnk, + self.max_active_clusters, + self.is_persistent, + ) + + @cute.struct + class SplitKVKernelSharedStorage: + # Pipeline barriers + load_q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.load_q_stage * 2] + load_k_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.load_k_stage * 2] + load_v_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.load_v_stage * 2] + mma_s_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.mma_s_stage * 2] + p_mma_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.p_mma_stage * 2] + p_cor_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.p_cor_stage * 2] + mma_o_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.mma_o_stage * 2] + + # Smem tensors + smem_p: cute.struct.Align[ + cute.struct.MemRange[self.q_dtype, cute.cosize(p_smem_layout_staged)], + 1024, + ] + smem_kc_latent: cute.struct.Align[ + cute.struct.MemRange[ + self.k_dtype, cute.cosize(kc_latent_smem_layout_staged) + ], + 1024, + ] + + smem_kc_rope: cute.struct.Align[ + cute.struct.MemRange[ + self.k_dtype, cute.cosize(kc_rope_smem_layout_staged) + ], + 1024, + ] + smem_q_latent: cute.struct.Align[ + cute.struct.MemRange[ + self.q_dtype, cute.cosize(q_latent_smem_layout_staged) + ], + 1024, + ] + smem_q_rope: cute.struct.Align[ + cute.struct.MemRange[ + self.q_dtype, cute.cosize(q_rope_smem_layout_staged) + ], + 1024, + ] + smem_vc: cute.struct.Align[ + cute.struct.MemRange[self.v_dtype, cute.cosize(vc_smem_layout_staged)], + 1024, + ] + softmax_smem_exchange: cute.struct.MemRange[ + self.acc_dtype, self.num_compute_warps * self.threads_per_warp + ] + epilogue_smem_exchange: cute.struct.MemRange[ + self.acc_dtype, self.num_compute_warps * self.threads_per_warp + ] + + # Tmem dealloc cluster barrier + tmem_dealloc_mbar_ptr: cutlass.Int64 + + # Tmem holding buffer + tmem_holding_buf: cutlass.Int32 + + softmax_scale_log2 = softmax_scale * LOG2_E + + self.split_kv_kernel( + qk_tiled_mma, + pv_tiled_mma, + tma_atom_q_latent, + tma_tensor_q_latent, + tma_atom_q_rope, + tma_tensor_q_rope, + tma_atom_c_latent, + tma_tensor_c_latent, + tma_atom_c_rope, + tma_tensor_c_rope, + tma_atom_c_latent_transpose, + tma_tensor_c_latent_transpose, + page_table, + o, + lse, + acc_o, + acc_lse, + split_kv, + cache_seqs, + block_split_kvs, + softmax_scale_log2, + output_scale, + q_latent_smem_layout_staged, + q_rope_smem_layout_staged, + kc_latent_smem_layout_staged, + kc_rope_smem_layout_staged, + p_smem_layout_staged, + vc_smem_layout_staged, + kc_latent_smem_layout_for_tma, + kc_rope_smem_layout_for_tma, + vc_smem_layout_for_tma, + cta_layout_vmnk, + tile_sched_params, + SplitKVKernelSharedStorage, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=self.cluster_shape_mnk, + smem=SplitKVKernelSharedStorage.size_in_bytes(), # type: ignore[attr-defined] + stream=stream, + min_blocks_per_mp=1, + use_pdl=self.enable_pdl, + ) + if cutlass.const_expr(acc_o is not None): + self.reduction_kernel( + o, + lse, + acc_o, + acc_lse, + split_kv, + cache_seqs, + block_split_kvs, + ).launch( + grid=(q_latent.shape[0], q_latent.shape[2], q_latent.shape[3]), + block=[self.threads_per_warp * self.num_compute_warps, 1, 1], + smem=MAX_SPLITS * self.acc_dtype.width // 8, + stream=stream, + min_blocks_per_mp=1, + use_pdl=self.enable_pdl, + ) + + @cute.jit + def make_paged_tiled_tma_atom( + self, + tma_load_op: cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp, + gmem: cute.Tensor, + smem_layout: cute.Layout, + mma_tiler, + tiled_mma: cute.TiledMma, + is_k_load: bool, + ): + ident = cute.make_identity_layout(gmem.shape) + g_tile = cute.composition(ident, mma_tiler) + cta_mn = mma_tiler[0] // tiled_mma.thr_id.shape + cta_v_map = cute.flat_divide(g_tile, (cta_mn,)) + cta_v_map = cute.select(cta_v_map, mode=[0, 2]) + page_tile_size = ( + min(self.page_size, cta_mn) + if is_k_load + else min(self.page_size, mma_tiler[1]) + ) + cta_v_map = cute.zipped_divide( + cta_v_map, + (page_tile_size, mma_tiler[1]) if is_k_load else (cta_mn, page_tile_size), + ) + cta_v_map = cute.select(cta_v_map, mode=[0]) + from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir + + res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load( + gmem.value, + smem_layout.value, + cta_v_map, + tma_load_op._to_ir(), + num_multicast=1, + ) + return cute.CopyAtom( + tma_load_op, cpasync.CopyBulkTensorTileG2SNonExecTrait(res[0]) + ), res[1] + + @cute.kernel + def split_kv_kernel( + self, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + tma_atom_q_latent: Optional[cute.CopyAtom], + mQL: cute.Tensor, + tma_atom_q_rope: Optional[cute.CopyAtom], + mQR: cute.Tensor, + tma_atom_c_latent: Optional[cute.CopyAtom], + mCL: cute.Tensor, + tma_atom_c_rope: Optional[cute.CopyAtom], + mKR: cute.Tensor, + tma_atom_c_latent_transpose: Optional[cute.CopyAtom], + mCLT: cute.Tensor, + mPT: cute.Tensor, + mO: Optional[cute.Tensor], + mLSE: Optional[cute.Tensor], + mAccO: Optional[cute.Tensor], + mAccLSE: Optional[cute.Tensor], + split_kv: cutlass.Int32, + cache_seqs: cute.Tensor, + block_split_kvs: cute.Tensor, + softmax_scale_log2: cutlass.Float32, + output_scale: cutlass.Float32, + q_latent_smem_layout_staged: cute.ComposedLayout, + q_rope_smem_layout_staged: cute.ComposedLayout, + kc_latent_smem_layout_staged: cute.ComposedLayout, + kc_rope_smem_layout_staged: cute.ComposedLayout, + p_smem_layout_staged: cute.ComposedLayout, + vc_smem_layout_staged: cute.ComposedLayout, + kc_latent_smem_layout_for_tma: Optional[cute.ComposedLayout], + kc_rope_smem_layout_for_tma: Optional[cute.ComposedLayout], + vc_smem_layout_for_tma: Optional[cute.ComposedLayout], + cta_layout_vmnk: cute.Layout, + tile_sched_params: MLAStaticTileSchedulerParams, + SharedStorage: cutlass.Constexpr, + ): + """The device split_kv kernel implementation of the Multi-Head Latent Attention. + + This kernel coordinates multiple specialized warps to perform different phases of the MLA computation: + 1. Load warp: Loads Q/C latent/rope data from global memory to shared memory using TMA + 2. MMA warp: Performs matrix multiplications (Q*K^T and P*V) + 3. Compute warps: Compute softmax and do rescaling on accumulators, and store the intermediate/final results + to global memory + + The kernel produces either intermediate or final results of the MLA computation based on the split_kv parameter. + When split_kv is 1, the kernel generates the final results directly. Otherwise, it produces intermediate results + that will later be combined by a reduction kernel. + + The kernel implements a complex pipeline with overlapping computation and memory operations, + using tensor memory access (TMA) for efficient data loading, warp specialization for different + computation phases. + + :param tiled_mma_qk: Tiled MMA for Q*K^T + :type tiled_mma_qk: cute.TiledMma + :param tiled_mma_pv: Tiled MMA for P*V + :type tiled_mma_pv: cute.TiledMma + :param tma_atom_q_latent: TMA copy atom for query latent tensor + :type tma_atom_q_latent: cute.CopyAtom + :param mQL: query latent tensor + :type mQL: cute.Tensor + :param tma_atom_q_rope: TMA copy atom for query rope tensor + :type tma_atom_q_rope: cute.CopyAtom + :param mKR: Compressed rope tensor + :type mKR: cute.Tensor + :param tma_atom_c_latent: TMA copy atom for c latent tensor + :type tma_atom_c_latent: cute.CopyAtom + :param mCL: Compressed latent tensor + :type mCL: cute.Tensor + :param tma_atom_c_rope: TMA copy atom for c rope tensor + :type tma_atom_c_rope: cute.CopyAtom + :param mCLT: Compressed latent transpose tensor + :type mCLT: cute.Tensor + :param mPT: Page table tensor + :type mPT: cute.Tensor + :param mO: Output tensor + :type mO: cute.Tensor + :param mLSE: Log-sum-exp tensor + :type mLSE: cute.Tensor + :param mAccO: Intermediate accumulator output tensor + :type mAccO: cute.Tensor + :param mAccLSE: Intermediate accumulator log-sum-exp tensor + :type mAccLSE: cute.Tensor + :param split_kv: The split_kv parameter + :type split_kv: cutlass.Int32 + :param cache_seqs: The variable sequence length tensor + :type cache_seqs: cute.Tensor + :param block_split_kvs: The per-block split_kv values tensor + :type block_split_kvs: cute.Tensor + :param softmax_scale_log2: The log2 scale factor for softmax + :type softmax_scale_log2: cutlass.Float32 + :param output_scale: The scale factor for the output + :type output_scale: cutlass.Float32 + :param q_latent_smem_layout_staged: Shared memory layout for query tensor + :type q_latent_smem_layout_staged: cute.ComposedLayout + :param q_rope_smem_layout_staged: Shared memory layout for query rope tensor + :type q_rope_smem_layout_staged: cute.ComposedLayout + :param kc_latent_smem_layout_staged: Shared memory layout for key tensor + :type kc_latent_smem_layout_staged: cute.ComposedLayout + :param kc_rope_smem_layout_staged: Shared memory layout for key rope tensor + :type kc_rope_smem_layout_staged: cute.ComposedLayout + :param p_smem_layout_staged: Shared memory layout for probability matrix + :type p_smem_layout_staged: cute.ComposedLayout + :param vc_smem_layout_staged: Shared memory layout for value tensor + :type vc_smem_layout_staged: cute.ComposedLayout + :param cta_layout_vmnk: Layout for compute threads + :type cta_layout_vmnk: cute.Layout + :param tile_sched_params: Scheduling parameters for work distribution + :type tile_sched_params: MLAStaticTileSchedulerParams + :param SharedStorage: Shared storage for the kernel + :type SharedStorage: cutlass.Constexpr + """ + + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma_qk.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + + # Prefetch tma descriptor + if warp_idx == self.mma_warp_id: + cpasync.prefetch_descriptor(tma_atom_q_latent) + cpasync.prefetch_descriptor(tma_atom_q_rope) + cpasync.prefetch_descriptor(tma_atom_c_latent) + cpasync.prefetch_descriptor(tma_atom_c_rope) + cpasync.prefetch_descriptor(tma_atom_c_latent_transpose) + + # Alloc + smem = utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + + # Tensor memory dealloc barrier init + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=self.tmem_ptr_sync_bar, + allocator_warp_id=self.mma_warp_id, + is_two_cta=self.use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + + load_q_pipeline = self.make_and_init_load_qkv_pipeline( + storage.load_q_mbar_ptr.data_ptr(), + cta_layout_vmnk, + self.load_q_stage, + self.tma_copy_q_bytes, + ) + load_k_pipeline = self.make_and_init_load_qkv_pipeline( + storage.load_k_mbar_ptr.data_ptr(), + cta_layout_vmnk, + self.load_k_stage, + self.tma_copy_kc_bytes, + ) + load_v_pipeline = self.make_and_init_load_qkv_pipeline( + storage.load_v_mbar_ptr.data_ptr(), + cta_layout_vmnk, + self.load_v_stage, + self.tma_copy_vc_bytes, + ) + mma_s_pipeline = self.make_and_init_mma_s_pipeline( + storage.mma_s_mbar_ptr.data_ptr(), cta_layout_vmnk + ) + p_mma_pipeline = self.make_and_init_p_mma_pipeline( + storage.p_mma_mbar_ptr.data_ptr(), cta_layout_vmnk + ) + p_cor_pipeline = self.make_and_init_p_cor_pipeline( + storage.p_cor_mbar_ptr.data_ptr() + ) + mma_o_pipeline = self.make_and_init_mma_o_pipeline( + storage.mma_o_mbar_ptr.data_ptr(), cta_layout_vmnk + ) + + # Cluster arrive after barrier init + pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mnk, is_relaxed=True) + + # Generate smem tensor Q/KC/VC/exchange + # (MMA, MMA_H, MMA_R, PIPE) + sQ = storage.smem_q_latent.get_tensor( + q_latent_smem_layout_staged.outer, swizzle=q_latent_smem_layout_staged.inner + ) + sQ_rope = storage.smem_q_rope.get_tensor( + q_rope_smem_layout_staged.outer, swizzle=q_rope_smem_layout_staged.inner + ) + # (MMA, MMA_K, MMA_R, PIPE) + sKC = storage.smem_kc_latent.get_tensor( + kc_latent_smem_layout_staged.outer, + swizzle=kc_latent_smem_layout_staged.inner, + ) + sKC_rope = storage.smem_kc_rope.get_tensor( + kc_rope_smem_layout_staged.outer, swizzle=kc_rope_smem_layout_staged.inner + ) + sKC_for_tma = storage.smem_kc_latent.get_tensor( + kc_latent_smem_layout_for_tma.outer, + swizzle=kc_latent_smem_layout_for_tma.inner, + ) + sKC_rope_for_tma = storage.smem_kc_rope.get_tensor( + kc_rope_smem_layout_for_tma.outer, swizzle=kc_rope_smem_layout_for_tma.inner + ) + # (MMA, MMA_D, MMA_K, PIPE) + sVC = storage.smem_vc.get_tensor( + vc_smem_layout_staged.outer, swizzle=vc_smem_layout_staged.inner + ) + sVC_for_tma = storage.smem_vc.get_tensor( + vc_smem_layout_for_tma.outer, swizzle=vc_smem_layout_for_tma.inner + ) + # (MMA, MMA_H, MMA_K) + sP = storage.smem_p.get_tensor( + p_smem_layout_staged.outer, swizzle=p_smem_layout_staged.inner + ) + # (compute_threads,) + softmax_smem_exchange = storage.softmax_smem_exchange.get_tensor( + cute.make_layout(self.num_compute_warps * self.threads_per_warp) + ) + epilogue_smem_exchange = storage.epilogue_smem_exchange.get_tensor( + cute.make_layout(self.num_compute_warps * self.threads_per_warp) + ) + + # + # Cluster wait before tensor memory alloc + # + pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mnk) + + if cutlass.const_expr(self.enable_pdl): + cute.arch.griddepcontrol_wait() + + # /////////////////////////////////////////////////////////////////////////////// + # Load warps, including page table and data tensors + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx >= self.empty_warp_ids[0] and warp_idx <= self.empty_warp_ids[-1]: + _setmaxregister_decrease(self.other_reg_num) + + if warp_idx == self.load_tma_k_warp_id: + _setmaxregister_decrease(self.other_reg_num) + load_q_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.load_q_stage + ) + load_k_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.load_k_stage + ) + tile_sched = create_mla_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + blk_coord = work_tile.tile_idx + k_index, k_tile_count, local_split_kv = self.get_k_tile_count( + split_kv, + cache_seqs, + block_split_kvs, + blk_coord, + ) + if k_tile_count > 0: + # Construct fixed common/tma_qk/tma_pv params for load_tma + tma_common_params = SimpleNamespace( + blk_coord=blk_coord, + local_split_kv=local_split_kv, + load_q_pipeline=load_q_pipeline, + load_k_pipeline=load_k_pipeline, + load_v_pipeline=load_v_pipeline, + mPT=mPT, + ) + tma_qk_params = SimpleNamespace( + tiled_mma_qk=tiled_mma_qk, + tma_atom_q_latent=tma_atom_q_latent, + tma_atom_q_rope=tma_atom_q_rope, + tma_atom_c_latent=tma_atom_c_latent, + tma_atom_c_rope=tma_atom_c_rope, + mQL=mQL, + mQR=mQR, + mCL=mCL, + mKR=mKR, + sQ=sQ, + sQ_rope=sQ_rope, + sKC=sKC_for_tma, + sKC_rope=sKC_rope_for_tma, + ) + # Load tma + load_q_producer_state, load_k_producer_state = self.load_tma_qk( + tma_common_params, + tma_qk_params, + k_index, + k_tile_count, + load_q_producer_state, + load_k_producer_state, + ) + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + load_q_pipeline.producer_tail(load_q_producer_state) + load_k_pipeline.producer_tail(load_k_producer_state) + + if warp_idx == self.load_tma_v_warp_id: + _setmaxregister_decrease(self.other_reg_num) + load_v_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.load_v_stage + ) + tile_sched = create_mla_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + blk_coord = work_tile.tile_idx + k_index, k_tile_count, local_split_kv = self.get_k_tile_count( + split_kv, + cache_seqs, + block_split_kvs, + blk_coord, + ) + if k_tile_count > 0: + # Construct fixed common/tma_qk/tma_pv params for load_tma + tma_common_params = SimpleNamespace( + blk_coord=blk_coord, + local_split_kv=local_split_kv, + load_v_pipeline=load_v_pipeline, + mPT=mPT, + ) + tma_pv_params = SimpleNamespace( + tiled_mma_pv=tiled_mma_pv, + tma_atom_c_latent_transpose=tma_atom_c_latent_transpose, + mCLT=mCLT, + sVC=sVC_for_tma, + ) + # Load tma + load_v_producer_state = self.load_tma_v( + tma_common_params, + tma_pv_params, + k_index, + k_tile_count, + load_v_producer_state, + ) + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + load_v_pipeline.producer_tail(load_v_producer_state) + + # /////////////////////////////////////////////////////////////////////////////// + # MMA warp + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.mma_warp_id: + _setmaxregister_decrease(self.other_reg_num) + # Alloc tensor memory buffer + tmem.allocate(_get_max_tmem_alloc_cols("sm_100")) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + + load_q_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.load_q_stage + ) + load_k_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.load_k_stage + ) + load_v_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.load_v_stage + ) + mma_s_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.mma_s_stage + ) + p_mma_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.p_mma_stage + ) + mma_o_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.mma_o_stage + ) + tile_sched = create_mla_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + blk_coord = work_tile.tile_idx + k_index, k_tile_count, local_split_kv = self.get_k_tile_count( + split_kv, cache_seqs, block_split_kvs, blk_coord + ) + if k_tile_count > 0: + mma_common_params = SimpleNamespace( + blk_coord=blk_coord, + local_split_kv=local_split_kv, + load_q_pipeline=load_q_pipeline, + load_k_pipeline=load_k_pipeline, + load_v_pipeline=load_v_pipeline, + tmem_ptr=tmem_ptr, + is_leader_cta=is_leader_cta, + L=mCL.shape[1], + ) + mma_qk_params = SimpleNamespace( + mma_s_pipeline=mma_s_pipeline, + sQ=sQ, + sQ_rope=sQ_rope, + sKC=sKC, + sKC_rope=sKC_rope, + ) + mma_pv_params = SimpleNamespace( + p_mma_pipeline=p_mma_pipeline, + mma_o_pipeline=mma_o_pipeline, + sP=sP, + sVC=sVC, + ) + ( + tiled_mma_qk, + tiled_mma_pv, + load_q_consumer_state, + load_k_consumer_state, + load_v_consumer_state, + mma_s_producer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) = self.mma( + mma_common_params, + mma_qk_params, + mma_pv_params, + k_tile_count, + tiled_mma_qk, + tiled_mma_pv, + load_q_consumer_state, + load_k_consumer_state, + load_v_consumer_state, + mma_s_producer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + mma_s_pipeline.producer_tail(mma_s_producer_state) + mma_o_pipeline.producer_tail(mma_o_producer_state) + + tmem.relinquish_alloc_permit() + tmem.free(tmem_ptr) + if cutlass.const_expr(self.enable_pdl): + cute.arch.griddepcontrol_launch_dependents() + + # /////////////////////////////////////////////////////////////////////////////// + # Compute warp + # /////////////////////////////////////////////////////////////////////////////// + if ( + warp_idx >= self.compute_warp_ids[0] + and warp_idx <= self.compute_warp_ids[-1] + ): + _setmaxregister_increase(self.softmax_reg_num) + mma_s_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.mma_s_stage + ) + p_mma_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.p_mma_stage + ) + p_cor_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.p_cor_stage + ) + mma_o_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.mma_o_stage + ) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + + tile_sched = create_mla_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + blk_coord = work_tile.tile_idx + k_index, k_tile_count, local_split_kv = self.get_k_tile_count( + split_kv, cache_seqs, block_split_kvs, blk_coord + ) + if k_tile_count > 0: + compute_common_params = SimpleNamespace( + blk_coord=blk_coord, + split_kv=split_kv, + local_split_kv=local_split_kv, + smem_exchange=softmax_smem_exchange, + mAccO=mAccO, + mO=mO, + K=cache_seqs[blk_coord[2]], + L=mCL.shape[1], + tmem_ptr=tmem_ptr, + tidx=tidx, + p_cor_pipeline=p_cor_pipeline, + ) + compute_softmax_params = SimpleNamespace( + tiled_mma_qk=tiled_mma_qk, + sP=sP, + mma_s_pipeline=mma_s_pipeline, + p_mma_pipeline=p_mma_pipeline, + softmax_scale_log2=softmax_scale_log2, + ) + mma_s_consumer_state, p_mma_producer_state, p_cor_producer_state = ( + self.compute( + compute_common_params, + compute_softmax_params, + k_index=k_index, + k_tile_count=k_tile_count, + mma_s_consumer_state=mma_s_consumer_state, + p_mma_producer_state=p_mma_producer_state, + p_cor_producer_state=p_cor_producer_state, + ) + ) + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + p_cor_pipeline.producer_tail(p_cor_producer_state) + + # /////////////////////////////////////////////////////////////////////////////// + # Correction warp + # /////////////////////////////////////////////////////////////////////////////// + if ( + warp_idx >= self.correction_warp_ids[0] + and warp_idx <= self.correction_warp_ids[-1] + ): + _setmaxregister_increase(self.correction_reg_num) + p_cor_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.p_cor_stage + ) + mma_o_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.mma_o_stage + ) + # sync with mma warp before retrieving tmem ptr + tmem.wait_for_alloc() + + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + + tile_sched = create_mla_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + blk_coord = work_tile.tile_idx + k_index, k_tile_count, local_split_kv = self.get_k_tile_count( + split_kv, cache_seqs, block_split_kvs, blk_coord + ) + if k_tile_count > 0: + compute_common_params = SimpleNamespace( + blk_coord=blk_coord, + split_kv=split_kv, + local_split_kv=local_split_kv, + smem_exchange=epilogue_smem_exchange, + mAccO=mAccO, + mO=mO, + K=cache_seqs[blk_coord[2]], + L=mCL.shape[1], + H=mQL.shape[0], + tmem_ptr=tmem_ptr, + tidx=tidx, + tiled_mma_pv=tiled_mma_pv, + p_cor_pipeline=p_cor_pipeline, + mma_o_pipeline=mma_o_pipeline, + ) + compute_epilogue_params = SimpleNamespace( + output_scale=output_scale, + softmax_scale_log2=softmax_scale_log2, + mAccLSE=mAccLSE, + mLSE=mLSE, + ) + p_cor_consumer_state, mma_o_consumer_state = self.correction( + compute_common_params, + compute_epilogue_params, + k_tile_count=k_tile_count, + p_cor_consumer_state=p_cor_consumer_state, + mma_o_consumer_state=mma_o_consumer_state, + ) + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + return + + @cute.kernel + def reduction_kernel( + self, + mO: cute.Tensor, + mLSE: cute.Tensor, + mAccO: cute.Tensor, + mAccLSE: cute.Tensor, + split_kv: cutlass.Int32, + cache_seqs: cute.Tensor, + block_split_kvs: cute.Tensor, + ): + """The reduction kernel for Multi-Head Latent Attention (MLA) that combines intermediate results + from multiple split_kv blocks into final outputs. + + :param mO: Output tensor for storing final results + :type mO: cute.Tensor + :param mLSE: Log-sum-exp tensor for storing final LSE values + :type mLSE: cute.Tensor + :param mAccO: Accumulated output tensor from split_kv blocks + :type mAccO: cute.Tensor + :param mAccLSE: Accumulated LSE tensor from split_kv blocks + :type mAccLSE: cute.Tensor + :param split_kv: Number of split_kv blocks + :type split_kv: cutlass.Int32 + :param cache_seqs: Cache sequence lengths tensor + :type cache_seqs: cute.Tensor + :param block_split_kvs: Per-block split_kv values tensor (for variable split_kv) + :type block_split_kvs: cute.Tensor + """ + bidx, bidy, bidz = cute.arch.block_idx() + tidx, _, _ = cute.arch.thread_idx() + blk_coord = (bidx, bidy, bidz) + local_split_kv = ( + block_split_kvs[blk_coord[2]] if self.is_var_split_kv else split_kv + ) + k_tile_total = cute.ceil_div(cache_seqs[blk_coord[2]], self.mma_qk_tiler[1]) + k_tile_per_cta = cute.ceil_div(k_tile_total, local_split_kv) + local_split_kv = cute.ceil_div(k_tile_total, k_tile_per_cta) + + # Alloc shared memory + smem = utils.SmemAllocator() + storage = smem.allocate(MAX_SPLITS * self.acc_dtype.width // 8, 16) + lse_scale_ptr = cute.recast_ptr(storage, dtype=self.acc_dtype) + smem_lse_scale = cute.make_tensor(lse_scale_ptr, cute.make_layout(MAX_SPLITS)) + + if cutlass.const_expr(self.enable_pdl): + cute.arch.griddepcontrol_wait() + gLSE = mAccLSE[blk_coord[0], None, blk_coord[1], blk_coord[2]] + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + if warp_idx == 0: + # calculate the global lse and exp ^ (local_lse - global_lse) + lse_per_thread = cute.ceil_div(MAX_SPLITS, self.threads_per_warp) + + local_lse = cute.make_rmem_tensor( + cute.make_layout(lse_per_thread), self.lse_dtype + ) + lse_max = -self.lse_dtype.inf + # find the max lse + for i in cutlass.range_constexpr(lse_per_thread): + split_kv_idx = tidx + i * self.threads_per_warp + local_lse[i] = ( + gLSE[split_kv_idx] + if cute.elem_less(split_kv_idx, local_split_kv) + else -self.lse_dtype.inf + ) + # reduce the local lse + lse_max = cute.arch.fmax(lse_max, local_lse[i]) + lse_max = cute.arch.warp_reduction_max(lse_max) + lse_max = lse_max if lse_max != -self.lse_dtype.inf else 0.0 + # calculate sum_lse + sum_lse = 0.0 + for i in cutlass.range_constexpr(lse_per_thread): + sum_lse += cute.math.exp2(local_lse[i] - lse_max, fastmath=True) + sum_lse = cute.arch.warp_reduction_sum(sum_lse) + # calculate the global_lse + global_lse = ( + lse_max + cute.math.log2(sum_lse, fastmath=True) + if not sum_lse == self.lse_dtype(0.0) or sum_lse != sum_lse # noqa: SIM201 + else self.lse_dtype.inf + ) + if tidx == 0: + mLSE[blk_coord[0], blk_coord[1], blk_coord[2]] = global_lse + # store the scale to shared memory + for i in cutlass.range_constexpr(lse_per_thread): + split_kv_idx = tidx + i * self.threads_per_warp + if cute.elem_less(split_kv_idx, local_split_kv): + smem_lse_scale[split_kv_idx] = cute.math.exp2( + local_lse[i] - global_lse, fastmath=True + ) + + pipeline.sync(barrier_id=4) + + elements_per_thread = cute.ceil_div( + self.latent_dim, self.threads_per_warp * self.num_compute_warps + ) + gAccO = mAccO[blk_coord[0], None, None, blk_coord[1], blk_coord[2]] + rAccO = cute.make_rmem_tensor( + cute.make_layout(elements_per_thread), self.acc_dtype + ) + rO = cute.make_rmem_tensor(cute.make_layout(elements_per_thread), self.o_dtype) + rAccO.fill(0.0) + for i in range(local_split_kv): + for j in cutlass.range_constexpr(elements_per_thread): + element_idx = tidx + j * self.threads_per_warp * self.num_compute_warps + rAccO[j] += gAccO[i, element_idx] * smem_lse_scale[i] + rO.store(rAccO.load().to(self.o_dtype)) + for j in cutlass.range_constexpr(elements_per_thread): + element_idx = tidx + j * self.threads_per_warp * self.num_compute_warps + mO[blk_coord[0], element_idx, blk_coord[1], blk_coord[2]] = rO[j] + if cutlass.const_expr(self.enable_pdl): + cute.arch.griddepcontrol_launch_dependents() + return + + @staticmethod + def get_split_kv( + B: int, S: int, K: int, mma_qk_tiler_mn: tuple, max_active_blocks: int + ) -> int: + """Get the proper split_kv value for the MLA kernel based on parameters. + + :param B: Batch size + :type B: int + :param S: Sequence length + :type S: int + :param K: Sequence length + :type K: int + :param mma_qk_tiler_mn: MLA tiling parameters + :type mma_qk_tiler_mn: tuple + :param max_active_blocks: Maximum number of active blocks + :type max_active_blocks: int + :return: Split_kv value + :rtype: int + """ + max_splits = ceil_div(K, mma_qk_tiler_mn[1]) + blocks_per_batch = max(1, max_active_blocks // B // (S * 2)) + split_heur = min(max_splits, blocks_per_batch) + # {$nv-internal-release begin} + # TODO: figure out the error of make_tile with dynamic int_tuple + # {$nv-internal-release end} + k_waves = ceil_div(max_splits, split_heur) + split_wave_aware = ceil_div(max_splits, k_waves) + max_split_kv = 32 + return min(split_wave_aware, max_split_kv) + + @staticmethod + def get_split_kv_simplified(B: int, S: int, max_active_blocks: int) -> int: + blocks_per_batch = max(1, max_active_blocks // B // (S * 2)) + max_split_kv = 32 + return min(blocks_per_batch, max_split_kv) + + @cute.jit + def get_k_tile_count( + self, + split_kv: cutlass.Int32, + cache_seqs: cute.Tensor, + block_split_kvs: cute.Tensor, + blk_coord: cute.Coord, + ) -> tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32]: + """Get the current k_index, k_tile_count, and local split_kv value for the MLA kernel. + + :param split_kv: Split_kv value + :type split_kv: cutlass.Int32 + :param cache_seqs: Cache sequence lengths tensor + :type cache_seqs: cute.Tensor + :param block_split_kvs: Per-block split_kv values tensor + :type block_split_kvs: cute.Tensor + :param blk_coord: Block coordinate + :type blk_coord: cute.Coord + :return: k_index, k_tile_count, split_kv + :rtype: tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32] + """ + K = cache_seqs[blk_coord[2]] + if cutlass.const_expr(self.is_var_split_kv): + split_kv = block_split_kvs[blk_coord[2]] + + k_tile_total = cute.ceil_div(K, self.mma_qk_tiler[1]) + # {$nv-internal-release begin} + # TODO: figure out the error of make_tile with dynamic int_tuple + # {$nv-internal-release end} + k_tile_per_cta = cute.ceil_div(k_tile_total, split_kv) + k_index = blk_coord[3] * k_tile_per_cta + k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index) + return k_index, k_tile_count, split_kv + + @cute.jit + def load_tma_qk( + self, + common_params: SimpleNamespace, + qk_params: SimpleNamespace, + k_index: cutlass.Int32, + k_tile_count: cutlass.Int32, + load_q_producer_state: pipeline.PipelineState | None = None, + load_k_producer_state: pipeline.PipelineState | None = None, + ) -> tuple[pipeline.PipelineState, pipeline.PipelineState]: + """Load wrap to load Q/K tensors. Updates the load qk producer state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param qk_params: The qk parameters + :type qk_params: SimpleNamespace + :param k_index: The k index + :type k_index: cutlass.Int32 + :param k_tile_count: The k tile count + :type k_tile_count: cutlass.Int32 + :param load_q_producer_state: The load q producer state + :type load_q_producer_state: pipeline.PipelineState + :param load_k_producer_state: The load k producer state + :type load_k_producer_state: pipeline.PipelineState + + :return: The load q producer state and load k producer state + :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState] + """ + # page table + mPT = common_params.mPT[None, common_params.blk_coord[2]] + + # Flatten divide and partition global tensors for QK TMA load + # (bM, bK, rM, rK, rL) + mma_qk_tiler_mk = cute.select(self.mma_qk_tiler, mode=[0, 2]) + gQL = cute.flat_divide(qk_params.mQL, mma_qk_tiler_mk) + mma_qk_tiler_mk_rope = cute.select(self.mma_qk_rope_tiler, mode=[0, 2]) + gQR = cute.flat_divide(qk_params.mQR, mma_qk_tiler_mk_rope) + + thr_mma_qk = qk_params.tiled_mma_qk.get_slice( + common_params.blk_coord[0] % cute.size(qk_params.tiled_mma_qk.thr_id) + ) + tSgQL = thr_mma_qk.partition_A(gQL) + tSgQR = thr_mma_qk.partition_A(gQR) + + cta_m = min( + qk_params.tiled_mma_qk.op.shape_mnk[0] + // qk_params.tiled_mma_qk.thr_id.shape, + self.page_size, + ) + page_tile_size = min(self.page_size, cta_m) + gCL = cute.tiled_divide(qk_params.mCL, (page_tile_size, self.mma_qk_tiler[2])) + tSgCL = ( + gCL[ + None, + common_params.blk_coord[0] % qk_params.tiled_mma_qk.thr_id.shape, + None, + None, + ] + if cta_m < self.page_size + else gCL[None, 0, None, None] + ) + gKR = cute.tiled_divide( + qk_params.mKR, (page_tile_size, self.mma_qk_rope_tiler[2]) + ) + tSgKR = ( + gKR[ + None, + common_params.blk_coord[0] % qk_params.tiled_mma_qk.thr_id.shape, + None, + None, + ] + if cta_m < self.page_size + else gKR[None, 0, None, None] + ) + # tma partition for q, k latent/rope + + # smem: ((atom_v, rest_v), STAGE) + # gmem: ((atom_v, rest_v), RestM, RestK, RestL) + tQsQ, tQLgQL_mkl = cpasync.tma_partition( + qk_params.tma_atom_q_latent, + 0, + cute.make_layout(1), + cute.group_modes(qk_params.sQ, 0, 3), + cute.group_modes(tSgQL, 0, 3), + ) + + tQsQ_rope, tQRgQR_mkl = cpasync.tma_partition( + qk_params.tma_atom_q_rope, + 0, + cute.make_layout(1), + cute.group_modes(qk_params.sQ_rope, 0, 3), + cute.group_modes(tSgQR, 0, 3), + ) + tKCsKC, tCLgCL = cpasync.tma_partition( + qk_params.tma_atom_c_latent, + 0, + cute.make_layout(1), + qk_params.sKC, + tSgCL, + ) + + tKCsKC_rope, tKRgKR = cpasync.tma_partition( + qk_params.tma_atom_c_rope, + 0, + cute.make_layout(1), + qk_params.sKC_rope, + tSgKR, + ) + + tQLgQL = tQLgQL_mkl[ + None, None, None, common_params.blk_coord[1], common_params.blk_coord[2] + ] + tQRgQR = tQRgQR_mkl[ + None, None, None, common_params.blk_coord[1], common_params.blk_coord[2] + ] + + # set extra params + common_params.mPT = mPT + qk_params.tQLgQL = tQLgQL + qk_params.tQRgQR = tQRgQR + qk_params.tCLgCL = tCLgCL + qk_params.tKRgKR = tKRgKR + qk_params.tQsQ = tQsQ + qk_params.tQsQ_rope = tQsQ_rope + qk_params.tKCsKC = tKCsKC + qk_params.tKCsKC_rope = tKCsKC_rope + + k_tile_count_init = k_tile_count + while k_tile_count > 0: + # {$nv-internal-release begin} + # TODO: figure out how to support SingleNamespace/struct in ast + # {$nv-internal-release end} + load_q_producer_state, load_k_producer_state = self.load_tma_qk_one_k_tile( + common_params, + qk_params, + k_index, + k_tile_count, + load_q_producer_state, + load_k_producer_state, + load_q=k_tile_count_init == k_tile_count, + ) + k_index += 1 + k_tile_count -= 1 + + return load_q_producer_state, load_k_producer_state + + @cute.jit + def load_tma_v( + self, + common_params: SimpleNamespace, + v_params: SimpleNamespace, + k_index: cutlass.Int32, + k_tile_count: cutlass.Int32, + load_v_producer_state: pipeline.PipelineState, + ) -> pipeline.PipelineState: + """Load wrap to load V tensors. Updates the load v producer state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param v_params: The v parameters + :type v_params: SimpleNamespace + :param k_index: The k index + :type k_index: cutlass.Int32 + :param k_tile_count: The k tile count + :type k_tile_count: cutlass.Int32 + :param load_v_producer_state: The load v producer state + :type load_v_producer_state: pipeline.PipelineState + + :return: The load v producer state + :rtype: pipeline.PipelineState + """ + # page table + mPT = common_params.mPT[None, common_params.blk_coord[2]] + + # Flatten divide and partition global tensors for V TMA load + page_tile_size = min(self.page_size, self.mma_pv_tiler[2]) + gCLT = cute.flat_divide(v_params.mCLT, (self.mma_pv_tiler[1], page_tile_size)) + cta_n = self.mma_pv_tiler[1] // v_params.tiled_mma_pv.thr_id.shape + gCLT = cute.logical_divide(gCLT, (cta_n,))[ + (None, common_params.blk_coord[0]), None, None, None, None + ] + tOgCLT = cute.tiled_divide(gCLT, (cta_n, page_tile_size)) + tOgCLT = tOgCLT[None, 0, 0, None, None, None] + # tma partition for vc + # smem: ((atom_v, rest_v), STAGE) + # gmem: ((atom_v, rest_v), RestM, RestK, RestL) + tVCsVC, tCLTgCLT = cpasync.tma_partition( + v_params.tma_atom_c_latent_transpose, + 0, + cute.make_layout(1), + v_params.sVC, + tOgCLT, + ) + + # set extra params + common_params.mPT = mPT + v_params.tCLTgCLT = tCLTgCLT + v_params.tVCsVC = tVCsVC + + while k_tile_count > 0: + # {$nv-internal-release begin} + # TODO: figure out how to support SingleNamespace/struct in ast + # {$nv-internal-release end} + load_v_producer_state = self.load_tma_v_one_k_tile( + common_params, + v_params, + k_index, + load_v_producer_state, + ) + k_index += 1 + k_tile_count -= 1 + return load_v_producer_state + + @cute.jit + def load_tma_qk_one_k_tile( + self, + common_params: SimpleNamespace, + qk_params: SimpleNamespace, + k_index: cutlass.Int32, + k_tile_count: cutlass.Int32, + load_q_producer_state: pipeline.PipelineState, + load_k_producer_state: pipeline.PipelineState, + load_q: bool, + ) -> tuple[pipeline.PipelineState, pipeline.PipelineState]: + """Load one k-tile of Q/C latent/rope tensors. Updates the load qkv producer state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param qk_params: The qk parameters + :type qk_params: SimpleNamespace + :param k_index: The k index + :type k_index: cutlass.Int32 + :param k_tile_count: The k tile count + :type k_tile_count: cutlass.Int32 + :param load_q_producer_state: The load q producer state + :type load_q_producer_state: pipeline.PipelineState + :param load_k_producer_state: The load kv producer state + :type load_k_producer_state: pipeline.PipelineState + :param load_q: Whether to load q + :type load_q: bool + + :return: The load q producer state and load kv producer state + :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState] + """ + page_per_tile = ceil_div( + self.mma_qk_tiler[1] // self.page_size, qk_params.tiled_mma_qk.thr_id.shape + ) + k_idx = cute.make_rmem_tensor(cute.make_layout(page_per_tile), cutlass.Int32) + for i in cutlass.range_constexpr(page_per_tile): + k_idx[i] = ( + common_params.mPT[k_index] + if self.mma_qk_tiler[1] // self.page_size == 1 + else common_params.mPT[ + ( + k_index * qk_params.tiled_mma_qk.thr_id.shape + + common_params.blk_coord[0] + ) + * page_per_tile + + i + ] + ) + # load q once at first iteration + load_q_pipeline = common_params.load_q_pipeline + if load_q: + # get the mbar ptr from pipeline. + tma_bar_ptr = load_q_pipeline.producer_get_barrier(load_q_producer_state) + # expect the extra bytes for q. + load_q_pipeline.producer_acquire(load_q_producer_state) + for i in cutlass.range_constexpr(self.iterations_qk_latent): + # load q latent + cute.copy( + qk_params.tma_atom_q_latent, + qk_params.tQLgQL[None, 0, i], + qk_params.tQsQ[None, (i, 0)], + tma_bar_ptr=tma_bar_ptr, + ) + for i in cutlass.range_constexpr(self.iterations_qk_rope): + # load q rope + cute.copy( + qk_params.tma_atom_q_rope, + qk_params.tQRgQR[None, 0, i], + qk_params.tQsQ_rope[None, i], + tma_bar_ptr=tma_bar_ptr, + ) + load_q_producer_state.advance() + # get the mbar ptr from pipeline. + tma_bar_ptr = common_params.load_k_pipeline.producer_get_barrier( + load_k_producer_state + ) + common_params.load_k_pipeline.producer_acquire(load_k_producer_state) + for i in range(self.iterations_qk_latent): + for k in range(page_per_tile): + # load k latent + cute.copy( + qk_params.tma_atom_c_latent, + qk_params.tCLgCL[None, i, k_idx[k]], + qk_params.tKCsKC[None, k, 0, (i, load_k_producer_state.index)], + tma_bar_ptr=tma_bar_ptr, + ) + + for i in cutlass.range_constexpr(self.iterations_qk_rope): + for k in cutlass.range_constexpr(page_per_tile): + # load k rope + cute.copy( + qk_params.tma_atom_c_rope, + qk_params.tKRgKR[None, i, k_idx[k]], + qk_params.tKCsKC_rope[None, k, 0, load_k_producer_state.index], + tma_bar_ptr=tma_bar_ptr, + ) + load_k_producer_state.advance() + + return load_q_producer_state, load_k_producer_state + + @cute.jit + def load_tma_v_one_k_tile( + self, + common_params: SimpleNamespace, + v_params: SimpleNamespace, + k_index: cutlass.Int32, + load_v_producer_state: pipeline.PipelineState, + ) -> pipeline.PipelineState: + """Load one k-tile of compressed latent transpose tensor(v). Updates the load qkv producer state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param v_params: The load tma v parameters + :type v_params: SimpleNamespace + :param k_index: The k index + :type k_index: cutlass.Int32 + :param load_v_producer_state: The load v producer state + :type load_v_producer_state: pipeline.PipelineState + + :return: The load qkv producer state + :rtype: pipeline.PipelineState + """ + page_per_tile = self.mma_pv_tiler[2] * self.iterations_pv_k // self.page_size + page_per_subtile = ceil_div(page_per_tile, self.iterations_pv_k) + k_idx = cute.make_rmem_tensor(cute.make_layout(page_per_tile), cutlass.Int32) + for i in cutlass.range_constexpr(page_per_tile): + k_idx[i] = ( + common_params.mPT[k_index] + if page_per_tile == 1 + else common_params.mPT[k_index * page_per_tile + i] + ) + # get the mbar ptr from pipeline. + tma_bar_ptr = common_params.load_v_pipeline.producer_get_barrier( + load_v_producer_state + ) + common_params.load_v_pipeline.producer_acquire(load_v_producer_state) + for j in cutlass.range_constexpr(self.iterations_pv_n): + for i in cutlass.range_constexpr(self.iterations_pv_k): + if cutlass.const_expr(page_per_tile > 1): + for k in cutlass.range_constexpr(page_per_subtile): + k_idx_i = k_idx[k + i * page_per_subtile] + cute.copy( + v_params.tma_atom_c_latent_transpose, + v_params.tCLTgCLT[None, j, 0, k_idx_i], + v_params.tVCsVC[ + None, 0, k, ((j, i), load_v_producer_state.index) + ], + tma_bar_ptr=tma_bar_ptr, + ) + else: + cute.copy( + v_params.tma_atom_c_latent_transpose, + v_params.tCLTgCLT[None, j, i, k_idx[0]], + v_params.tVCsVC[ + None, 0, 0, ((j, i), load_v_producer_state.index) + ], + tma_bar_ptr=tma_bar_ptr, + ) + load_v_producer_state.advance() + return load_v_producer_state + + @cute.jit + def mma( + self, + common_params: SimpleNamespace, + qk_params: SimpleNamespace, + pv_params: SimpleNamespace, + k_tile_count: cutlass.Int32, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + load_q_consumer_state: pipeline.PipelineState, + load_k_consumer_state: pipeline.PipelineState, + load_v_consumer_state: pipeline.PipelineState, + mma_s_producer_state: pipeline.PipelineState, + p_mma_consumer_state: pipeline.PipelineState, + mma_o_producer_state: pipeline.PipelineState, + ) -> tuple[ + cute.TiledMma, + cute.TiledMma, + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + ]: + """MMA warp to compute the result of Q*K^T and P*V. Updates the tiled mma and pipeline states. + + :param common_params: The common parameters for mma qk and pv + :type common_params: SimpleNamespace + :param qk_params: The mma qk parameters + :type qk_params: SimpleNamespace + :param pv_params: The mma pv parameters + :type pv_params: SimpleNamespace + :param k_tile_count: The k tile count + :type k_tile_count: cutlass.Int32 + :param tiled_mma_qk: The tiled mma qk + :type tiled_mma_qk: cute.TiledMma + :param tiled_mma_pv: The tiled mma pv + :type tiled_mma_pv: cute.TiledMma + :param load_q_consumer_state: The load q consumer state + :type load_q_consumer_state: pipeline.PipelineState + :param load_k_consumer_state: The load k consumer state + :type load_k_consumer_state: pipeline.PipelineState + :param load_v_consumer_state: The load v consumer state + :type load_v_consumer_state: pipeline.PipelineState + :param mma_s_producer_state: The mma s producer state + :type mma_s_producer_state: pipeline.PipelineState + :param p_mma_consumer_state: The p mma consumer state + :type p_mma_consumer_state: pipeline.PipelineState + :param mma_o_producer_state: The mma o producer state + :type mma_o_producer_state: pipeline.PipelineState + + :return: The tiled mma qk, the tiled mma pv, the load q consumer state, the load k consumer state, the load v consumer state, the mma s producer state, the p mma consumer state, and the mma o producer state + :rtype: tuple[cute.TiledMma, cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] + """ + + tSrQ = tiled_mma_qk.make_fragment_A(qk_params.sQ) + tSrQ_rope = tiled_mma_qk.make_fragment_A(qk_params.sQ_rope) + tSrKC = tiled_mma_qk.make_fragment_B(qk_params.sKC) + tSrKC_rope = tiled_mma_qk.make_fragment_B(qk_params.sKC_rope) + tOrP = tiled_mma_pv.make_fragment_A(pv_params.sP) + tOrVC = tiled_mma_pv.make_fragment_B(pv_params.sVC) + + tStS_shape = tiled_mma_qk.partition_shape_C( + cute.select(self.mma_qk_tiler, mode=[0, 1]) + ) + tStS_staged_fake = tiled_mma_qk.make_fragment_C( + cute.append(tStS_shape, self.mma_s_stage) + ) + # use real tmem ptr for tStS + tStS_staged = cute.make_tensor(common_params.tmem_ptr, tStS_staged_fake.layout) + tOtO_shape = tiled_mma_pv.partition_shape_C( + cute.select(self.mma_pv_tiler, mode=[0, 1]) + ) + # mma O has 1 stage. + tOtO = tiled_mma_pv.make_fragment_C(tOtO_shape) + tOtO_layout = cute.append( + tOtO.layout, + cute.make_layout( + common_params.L // self.mma_pv_tiler[1], + stride=self.mma_pv_tiler[1] // self.warps_in_n, + ), + ) + tOtO_staged = cute.make_tensor( + tStS_staged.iterator + self.tmem_o_offset, tOtO_layout + ) + + # set more parameters + qk_params.tSrQ = tSrQ + qk_params.tSrQ_rope = tSrQ_rope + qk_params.tSrKC = tSrKC + qk_params.tSrKC_rope = tSrKC_rope + qk_params.tStS_staged = tStS_staged + pv_params.tOrP = tOrP + pv_params.tOrVC = tOrVC + pv_params.tOtO_staged = tOtO_staged + + # mma O accumulates on K, so the accumlate flag is set to False once before all K blocks. + tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, False) + load_q_pipeline = common_params.load_q_pipeline + if common_params.is_leader_cta: + load_q_release_state = load_q_consumer_state.clone() + ( + tiled_mma_qk, + load_q_consumer_state, + load_k_consumer_state, + mma_s_producer_state, + ) = self.mma_qk( + common_params, + qk_params, + tiled_mma_qk, + load_q_consumer_state, + load_k_consumer_state, + mma_s_producer_state, + wait_q=True, + ) + k_tile_count -= 1 + + while k_tile_count > 0: + ( + tiled_mma_qk, + load_q_consumer_state, + load_k_consumer_state, + mma_s_producer_state, + ) = self.mma_qk( + common_params, + qk_params, + tiled_mma_qk, + load_q_consumer_state, + load_k_consumer_state, + mma_s_producer_state, + wait_q=False, + ) + ( + tiled_mma_pv, + load_v_consumer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) = self.mma_pv( + common_params, + pv_params, + tiled_mma_pv, + load_v_consumer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) + k_tile_count -= 1 + # release q consumer states + load_q_pipeline.consumer_release(load_q_release_state) + load_q_release_state.advance() + ( + tiled_mma_pv, + load_v_consumer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) = self.mma_pv( + common_params, + pv_params, + tiled_mma_pv, + load_v_consumer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) + + return ( # type: ignore[return-value] + tiled_mma_qk, + tiled_mma_pv, + load_q_consumer_state, + load_k_consumer_state, + load_v_consumer_state, + mma_s_producer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) + + @cute.jit + def mma_qk( + self, + common_params: SimpleNamespace, + qk_params: SimpleNamespace, + tiled_mma_qk: cute.TiledMma, + load_q_consumer_state: pipeline.PipelineState, + load_k_consumer_state: pipeline.PipelineState, + mma_s_producer_state: pipeline.PipelineState, + wait_q: bool, + ) -> tuple[ + cute.TiledMma, + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + ]: + """Compute one k-tile of mma for Q*K^T. Updates the tiled MMA QK and pipeline states. + + :param qk_params: The qk parameters + :type qk_params: SimpleNamespace + :param tiled_mma_qk: The tiled mma qk + :type tiled_mma_qk: cute.TiledMma + :param load_q_consumer_state: The load q consumer state + :type load_q_consumer_state: pipeline.PipelineState + :param load_k_consumer_state: The load k consumer state + :type load_k_consumer_state: pipeline.PipelineState + :param mma_s_producer_state: The mma s producer state + :type mma_s_producer_state: pipeline.PipelineState + + :return: The tiled mma qk, the load q consumer state, the load k consumer state, and the mma s producer state + :rtype: tuple[cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] + """ + tStS = qk_params.tStS_staged[None, None, None, mma_s_producer_state.index] + + qk_params.mma_s_pipeline.producer_acquire(mma_s_producer_state) + tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, False) + load_q_pipeline = common_params.load_q_pipeline + load_k_pipeline = common_params.load_k_pipeline + if cutlass.const_expr(wait_q): + load_q_pipeline.consumer_wait(load_q_consumer_state) + load_k_pipeline.consumer_wait(load_k_consumer_state) + for q_stage in range(self.iterations_qk_latent): + kc_stage = load_k_consumer_state.index + for k_block in cutlass.range_constexpr(cute.size(qk_params.tSrQ.shape[2])): + cute.gemm( + tiled_mma_qk, + tStS, + qk_params.tSrQ[None, None, k_block, (q_stage, 0)], + qk_params.tSrKC[None, None, k_block, (q_stage, kc_stage)], + tStS, + ) + tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, True) + + for q_stage in range(self.iterations_qk_rope): + kc_stage = load_k_consumer_state.index + for k_block in cutlass.range_constexpr( + self.rope_dim // tiled_mma_qk.shape_mnk[2] + ): + cute.gemm( + tiled_mma_qk, + tStS, + qk_params.tSrQ_rope[None, None, k_block, q_stage], + qk_params.tSrKC_rope[None, None, k_block, kc_stage], + tStS, + ) + tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, True) + load_k_pipeline.consumer_release(load_k_consumer_state) + load_k_consumer_state.advance() + if cutlass.const_expr(wait_q): + load_q_consumer_state.advance() + + qk_params.mma_s_pipeline.producer_commit(mma_s_producer_state) + mma_s_producer_state.advance() + return ( + tiled_mma_qk, + load_q_consumer_state, + load_k_consumer_state, + mma_s_producer_state, + ) + + @cute.jit + def mma_pv( + self, + common_params: SimpleNamespace, + pv_params: SimpleNamespace, + tiled_mma_pv: cute.TiledMma, + load_v_consumer_state: pipeline.PipelineState, + p_mma_consumer_state: pipeline.PipelineState, + mma_o_producer_state: pipeline.PipelineState, + ) -> tuple[ + cute.TiledMma, + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + ]: + """Compute one k-tile of mma for P*V. Updates the tiled mma pv and pipeline states. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param pv_params: The pv parameters + :type pv_params: SimpleNamespace + :param tiled_mma_pv: The tiled mma pv + :type tiled_mma_pv: cute.TiledMma + :param load_v_consumer_state: The load v consumer state + :type load_v_consumer_state: pipeline.PipelineState + :param p_mma_consumer_state: The P MMA consumer state + :type p_mma_consumer_state: pipeline.PipelineState + :param mma_o_producer_state: The MMA o producer state + :type mma_o_producer_state: pipeline.PipelineState + + :return: The tiled mma pv, the load v consumer state, the P MMA consumer state, and the MMA o producer state + :rtype: tuple[cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] + """ + + pv_params.p_mma_pipeline.consumer_wait(p_mma_consumer_state) + load_v_pipeline = common_params.load_v_pipeline + accumulate_flag = tiled_mma_pv.get(tcgen05.Field.ACCUMULATE) + mma_o_pipeline = pv_params.mma_o_pipeline + + load_v_pipeline.consumer_wait(load_v_consumer_state) + vc_stage = load_v_consumer_state.index + for acc_stage in range(self.iterations_pv_n): + mma_o_pipeline.producer_acquire(mma_o_producer_state) + tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, accumulate_flag) + for p_stage in range(self.iterations_pv_k): + tOtO = pv_params.tOtO_staged[None, None, None, acc_stage] + for k_block in cutlass.range_constexpr(pv_params.tOrP.shape[2]): + cute.gemm( + tiled_mma_pv, + tOtO, + pv_params.tOrP[ + None, + None, + k_block, + (p_stage, p_mma_consumer_state.index), + ], + pv_params.tOrVC[ + None, None, k_block, ((acc_stage, p_stage), vc_stage) + ], + tOtO, + ) + tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, True) + + mma_o_pipeline.producer_commit(mma_o_producer_state) + mma_o_producer_state.advance() + load_v_pipeline.consumer_release(load_v_consumer_state) + load_v_consumer_state.advance() + pv_params.p_mma_pipeline.consumer_release(p_mma_consumer_state) + p_mma_consumer_state.advance() + + return ( + tiled_mma_pv, + load_v_consumer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) + + @cute.jit + def compute( + self, + common_params: SimpleNamespace, + softmax_params: SimpleNamespace, + k_index: cutlass.Int32, + k_tile_count: cutlass.Int32, + mma_s_consumer_state: pipeline.PipelineState, + p_mma_producer_state: pipeline.PipelineState, + p_cor_producer_state: pipeline.PipelineState, + ) -> tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState]: + """Compute warp to compute the result of softmax, rescale, and epilogue. Updates the related pipeline states. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param softmax_params: The softmax parameters + :type softmax_params: SimpleNamespace + :param k_index: The index of the k-tile + :type k_index: cutlass.Int32 + :param k_tile_count: The number of k-tiles + :type k_tile_count: cutlass.Int32 + :param mma_s_consumer_state: The MMA s consumer state + :type mma_s_consumer_state: pipeline.PipelineState + :param p_mma_producer_state: The P MMA producer state + :type p_mma_producer_state: pipeline.PipelineState + :param p_cor_producer_state: The P correction producer state + :type p_cor_producer_state: pipeline.PipelineState + + :return: The MMA s consumer state, the P MMA producer state, and the P correction producer state + :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] + """ + + k_tile_total = cute.ceil_div(common_params.K, self.mma_qk_tiler[1]) + + row_max = -self.acc_dtype.inf + row_sum = self.acc_dtype(0) + correction_factor = self.acc_dtype(1) + common_params.p_cor_pipeline.producer_acquire(p_cor_producer_state) + + # no mask applied + while k_tile_count > 1: + ( + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max, + row_sum, + correction_factor, + ) = self.softmax( + common_params, + softmax_params, + k_index, + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max, + row_sum, + correction_factor, + False, + False, + ) + k_index = k_index + 1 + k_tile_count = k_tile_count - 1 + + # mask applied + if cutlass.const_expr(common_params.mAccO is not None): + ( + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max, + row_sum, + correction_factor, + ) = self.softmax( + common_params, + softmax_params, + k_index, + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max, + row_sum, + correction_factor, + k_index == k_tile_total - 1, + True, + ) + else: + ( + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max, + row_sum, + correction_factor, + ) = self.softmax( + common_params, + softmax_params, + k_index, + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max, + row_sum, + correction_factor, + True, + True, + ) + + return mma_s_consumer_state, p_mma_producer_state, p_cor_producer_state + + @cute.jit + def correction( + self, + common_params: SimpleNamespace, + epilogue_params: SimpleNamespace, + k_tile_count: cutlass.Int32, + p_cor_consumer_state: pipeline.PipelineState, + mma_o_consumer_state: pipeline.PipelineState, + ) -> tuple[pipeline.PipelineState, pipeline.PipelineState]: + """Compute warp to compute the result of softmax, rescale, and epilogue. Updates the related pipeline states. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param epilogue_params: The epilogue parameters + :type epilogue_params: SimpleNamespace + :param k_index: The index of the k-tile + :type k_index: cutlass.Int32 + :param k_tile_count: The number of k-tiles + :type k_tile_count: cutlass.Int32 + :param p_cor_consumer_state: The P correction consumer state + :type p_cor_consumer_state: pipeline.PipelineState + :param mma_o_consumer_state: The MMA o consumer state + :type mma_o_consumer_state: pipeline.PipelineState + + :return: The P correction consumer state, and the MMA o consumer state + :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState] + """ + + k_tile_count_init = k_tile_count + while k_tile_count > 0: + p_cor_consumer_state, row_sum, row_max, correction_factor, no_correction = ( + self.get_correction_factor(common_params, p_cor_consumer_state) + ) + if k_tile_count_init != k_tile_count: + mma_o_consumer_state = self.rescale( + common_params, + mma_o_consumer_state, + correction_factor, + no_correction, + ) + k_tile_count = k_tile_count - 1 + if k_tile_count == 0: + mma_o_consumer_state = self.epilogue( + common_params, + epilogue_params, + mma_o_consumer_state, + row_sum, + row_max, + ) + return p_cor_consumer_state, mma_o_consumer_state + + @cute.jit + def exchange_p_cor_metadata( + self, + common_params: SimpleNamespace, + softmax_params: SimpleNamespace, + correction_factor: cutlass.Float32, + row_sum: cutlass.Float32, + row_max: cutlass.Float32, + row_max_new: cutlass.Float32, + tAcc: cute.Tensor, + tidx: cutlass.Int32, + p_cor_producer_state: pipeline.PipelineState, + ) -> tuple[pipeline.PipelineState, cutlass.Float32]: + """Compute the correction factor for the last k tile.""" + no_correction = 0 + if ( + row_max_new - row_max + ) * softmax_params.softmax_scale_log2 <= self.skip_correction_threshold: + no_correction = 1 + row_max_new = row_max + + # pad for 4x32b + corr_layout = cute.make_layout( + (tAcc.shape[0], (4, tAcc.shape[1][1]), self.mma_s_stage), + stride=(tAcc.stride[0], (1, tAcc.stride[1][1]), 4), + ) + tCor = cute.make_tensor( + common_params.tmem_ptr + self.correction_factor_offset, + corr_layout, + ) + cCor = cute.make_identity_tensor(tCor.shape) + corr_tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(4)), self.acc_dtype + ) + corr_tmem_store_tiled_copy = tcgen05.make_tmem_copy(corr_tmem_store_atom, tCor) + corr_tmem_store_thr_copy = corr_tmem_store_tiled_copy.get_slice(tidx) + cCor_for_copy = corr_tmem_store_thr_copy.partition_S(cCor) + tCor_for_copy = corr_tmem_store_thr_copy.partition_D(tCor) + rCor = cute.make_fragment_like( + cCor_for_copy[None, None, None, 0], self.acc_dtype + ) + rCor_int = cute.make_tensor( + cute.recast_ptr(rCor.iterator, dtype=cutlass.Int32), rCor.layout + ) + rCor[0] = row_sum + rCor[1] = row_max_new + rCor[2] = correction_factor + rCor_int[3] = no_correction + + cute.copy( + corr_tmem_store_tiled_copy, + rCor, + tCor_for_copy[None, None, None, p_cor_producer_state.index], + ) + # fence between tmem store and correction warp + cute.arch.fence_view_async_tmem_store() + common_params.p_cor_pipeline.producer_commit(p_cor_producer_state) + p_cor_producer_state.advance() + return p_cor_producer_state, row_max_new + + @cute.jit + def softmax( + self, + common_params: SimpleNamespace, + softmax_params: SimpleNamespace, + k_index: cutlass.Int32, + mma_s_consumer_state: pipeline.PipelineState, + p_mma_producer_state: pipeline.PipelineState, + p_cor_producer_state: pipeline.PipelineState, + row_max: cutlass.Float32, + row_sum: cutlass.Float32, + correction_factor: cutlass.Float32, + is_last_tile: bool, + is_local_last_tile: cutlass.Boolean, + ) -> tuple[ + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + cutlass.Float32, + cutlass.Float32, + cutlass.Float32, + ]: + """Softmax for one k-tile. Updates the related pipeline states and returns the computed results. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param softmax_params: The softmax parameters + :type softmax_params: SimpleNamespace + :param k_index: The index of the k-tile + :type k_index: cutlass.Int32 + :param mma_s_consumer_state: The MMA s consumer state + :type mma_s_consumer_state: pipeline.PipelineState + :param p_mma_producer_state: The P MMA producer state + :type p_mma_producer_state: pipeline.PipelineState + :param p_cor_producer_state: The P correction producer state + :type p_cor_producer_state: pipeline.PipelineState + :param row_max: The row max + :type row_max: cutlass.Float32 + :param row_sum: The row sum + :type row_sum: cutlass.Float32 + :param correction_factor: The correction factor + :type correction_factor: cutlass.Float32 + :param is_last_tile: Whether the last tile + :type is_last_tile: bool + :param is_local_last_tile: Whether the last tile is local + :type is_local_last_tile: cutlass.Boolean + + :return: The MMA s consumer state, the P MMA producer state, the P correction producer state, the row max, the row sum, and the correction factor + :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, cutlass.Float32, cutlass.Float32, cutlass.Float32] + """ + + softmax_params.p_mma_pipeline.producer_acquire(p_mma_producer_state) + softmax_params.mma_s_pipeline.consumer_wait(mma_s_consumer_state) + + # load S from tmem + tStS_shape = softmax_params.tiled_mma_qk.partition_shape_C( + cute.select(self.mma_qk_tiler, mode=[0, 1]) + ) + tStS_staged_fake = softmax_params.tiled_mma_qk.make_fragment_C( + cute.append(tStS_shape, self.mma_s_stage) + ) + tStS_staged = cute.make_tensor(common_params.tmem_ptr, tStS_staged_fake.layout) + tStS = tStS_staged[None, None, None, mma_s_consumer_state.index] + + tAcc = tStS[(None, None), 0, 0] + cta_qk_tiler = ( + self.mma_qk_tiler[0] // self.cluster_shape_mnk[0], + self.mma_qk_tiler[1], + self.mma_qk_tiler[2], + ) + cS = cute.make_identity_tensor(cute.select(cta_qk_tiler, mode=[0, 1])) + + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype + ) + tmem_tiled_copy = tcgen05.make_tmem_copy(tmem_load_atom, tAcc) + + tidx = common_params.tidx % (self.num_compute_warps * self.threads_per_warp) + + tmem_thr_copy = tmem_tiled_copy.get_slice(tidx) + tTR_tAcc = tmem_thr_copy.partition_S(tAcc) + tTR_tS = tmem_thr_copy.partition_D(cS) + + tTR_rAcc = cute.make_fragment_like(tTR_tS, self.acc_dtype) + + row_max_new = row_max + arch = BaseDSL._get_dsl().get_arch_enum() + if cutlass.const_expr(arch >= Arch.sm_100 and arch <= Arch.sm_100f): + cute.copy(tmem_tiled_copy, tTR_tAcc, tTR_rAcc) + for i in cutlass.range_constexpr(cute.size(tTR_rAcc)): + if is_last_tile: + tTR_rAcc[i] = ( + tTR_rAcc[i] + if cute.elem_less( + tTR_tS[i][1] + self.mma_qk_tiler[1] * k_index, + common_params.K, + ) + else -self.acc_dtype.inf + ) + # reduction for row_max + row_max_new = tTR_rAcc.load().reduce(cute.ReductionOp.MAX, row_max_new, 0) + elif cutlass.const_expr(arch >= Arch.sm_103 and arch <= Arch.sm_103f): + tmem_load_red_atom = cute.make_copy_atom( + tcgen05.copy.LdRed32x32bOp( + tcgen05.copy.Repetition(64), redOp=tcgen05.TmemLoadRedOp.MAX + ), + self.acc_dtype, + ) + tmem_red_tiled_copy = tcgen05.make_tmem_copy(tmem_load_red_atom, tAcc) + tmem_red_thr_copy = tmem_red_tiled_copy.get_slice(tidx) + tTR_tAcc_red = tmem_red_thr_copy.partition_S(tAcc) + tTR_tS_red = tmem_red_thr_copy.partition_D(cS) + tTR_rAcc_red = cute.make_fragment_like(tTR_tS_red, self.acc_dtype) + tTR_rMax = cute.make_rmem_tensor( + cute.make_layout((1, tTR_tS_red.shape[1], tTR_tS_red.shape[2])), + self.acc_dtype, + ) + cute.copy( + tmem_red_tiled_copy, + tTR_tAcc_red, + (tTR_rAcc_red, tTR_rMax), + ) + tTR_rAcc = cute.make_tensor(tTR_rAcc_red.iterator, tTR_rAcc.layout) + if is_last_tile: + for i in cutlass.range_constexpr(cute.size(tTR_rAcc)): + tTR_rAcc[i] = ( + tTR_rAcc[i] + if cute.elem_less( + tTR_tS[i][1] + self.mma_qk_tiler[1] * k_index, + common_params.K, + ) + else -self.acc_dtype.inf + ) + # reduction for row_max + row_max_new = tTR_rAcc.load().reduce( + cute.ReductionOp.MAX, row_max_new, 0 + ) + else: + row_max_new = cute.arch.fmax(row_max_new, tTR_rMax[0]) + + # if warps in N is 2, reduce row_max across warps (0, 1) and (2, 3) + if cutlass.const_expr(self.warps_in_n == 2): + common_params.smem_exchange[tidx] = row_max_new + self.softmax_exchange_sync_bar.wait() + row_max_new = cute.arch.fmax( + row_max_new, + common_params.smem_exchange[ + (tidx + 64) % (self.num_compute_warps * self.threads_per_warp) + ], + ) + + # find correction factor + correction_factor = cute.math.exp2( + (row_max - row_max_new) * softmax_params.softmax_scale_log2, fastmath=True + ) + # split kv case + if cutlass.const_expr(not is_local_last_tile): + p_cor_producer_state, row_max_new = self.exchange_p_cor_metadata( + common_params, + softmax_params, + correction_factor, + row_sum, + row_max, + row_max_new, + tAcc, + tidx, + p_cor_producer_state, + ) + + # softmax + fma_b = softmax_params.softmax_scale_log2 + fma_c = (0.0 - row_max_new) * softmax_params.softmax_scale_log2 + + for i in cutlass.range(cute.size(tTR_rAcc), vectorize=True, unroll_full=True): + tTR_rAcc[i] = tTR_rAcc[i] * fma_b + fma_c + tTR_rAcc[i] = cute.math.exp2(tTR_rAcc[i], fastmath=True) + + tTR_rS = cute.make_fragment_like(tTR_tS, self.q_dtype) + + # quantize + tTR_rS.store(tTR_rAcc.load().to(self.q_dtype)) + + # create sP + sP = softmax_params.sP[None, None, None, (None, p_mma_producer_state.index)] + sP_mk_view = cute.make_tensor( + sP.iterator, + cute.make_layout( + ( + (sP.shape[0][0], sP.shape[1]), + (sP.shape[0][1], sP.shape[2], sP.shape[3]), + ), + stride=( + (sP.stride[0][0], sP.stride[1]), + (sP.stride[0][1], sP.stride[2], sP.stride[3]), + ), + ), + ) + # {$nv-internal-release begin} + # TODO: figure out if we could use A tmem for pv. + # {$nv-internal-release end} + # change to PISL + sP_wo_swizzle_iter = cute.recast_ptr(sP.iterator, swizzle_=None) + swizzle_bits = ( + int(math.log2(self.mma_pv_tiler[2] * self.q_dtype.width // 8 // 32)) + 1 + ) + swizzle_base = 3 if self.q_dtype.width == 16 else 4 + sP_swizzle = cute.make_swizzle(swizzle_bits, swizzle_base, 3) + sP_mk_view = cute.make_tensor( + sP_wo_swizzle_iter, + cute.make_composed_layout(sP_swizzle, 0, sP_mk_view.layout), + ) + universal_copy_bits = 128 + smem_copy_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.q_dtype, + num_bits_per_copy=universal_copy_bits, + ) + smem_tiled_copy = cute.make_tiled_copy_D(smem_copy_atom, tmem_tiled_copy) + smem_thr_copy = smem_tiled_copy.get_slice(tidx) + rP_copy_view = smem_thr_copy.retile(tTR_rS) + sP_copy_view = smem_thr_copy.partition_D(sP_mk_view) + cute.copy(smem_tiled_copy, rP_copy_view, sP_copy_view) + + # fence between smem store and mma o + cute.arch.fence_view_async_shared() + softmax_params.p_mma_pipeline.producer_commit(p_mma_producer_state) + p_mma_producer_state.advance() + + # row_sum, using `add_packed_f32x2` to reduce the number of instructions + row_sum = row_sum * correction_factor + row_sum_vec = (0.0, 0.0) + for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2): + row_sum_vec = cute.arch.add_packed_f32x2( + row_sum_vec, (tTR_rAcc[i], tTR_rAcc[i + 1]) + ) + row_sum = row_sum_vec[0] + row_sum_vec[1] + row_sum + + # split kv case + if cutlass.const_expr(is_local_last_tile): + p_cor_producer_state, row_max_new = self.exchange_p_cor_metadata( + common_params, + softmax_params, + correction_factor, + row_sum, + row_max, + row_max_new, + tAcc, + tidx, + p_cor_producer_state, + ) + + # store correction factor/row_sum/row_max to tmem for correction warp + common_params.p_cor_pipeline.producer_acquire(p_cor_producer_state) + + # fence between tmem load and mma s + cute.arch.fence_view_async_tmem_load() + + softmax_params.mma_s_pipeline.consumer_release(mma_s_consumer_state) + mma_s_consumer_state.advance() + + return ( + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max_new, + row_sum, + correction_factor, + ) + + @cute.jit + def _tmem_load_partition( + self, common_params: SimpleNamespace, tiled_mma_pv: cute.TiledMma, iter_n: int + ) -> tuple[ + cute.TiledMma, cute.TiledMma, cute.TiledMma, cute.TiledMma, cute.TiledMma + ]: + """Tensor memory load partition for rescale and epilogue. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param tiled_mma_pv: The tiled mma pv + :type tiled_mma_pv: cute.TiledMma + :param iter_n: The iteration number + :type iter_n: int + + :return: The tiled mma pv, the tiled mma pv, the tiled mma pv, the tiled mma pv, the tiled mma pv + :rtype: tuple[cute.TiledMma, cute.TiledMma, cute.TiledMma, cute.TiledMma, cute.TiledMma] + """ + + tOtO_shape = tiled_mma_pv.partition_shape_C( + cute.select(self.mma_pv_tiler, mode=[0, 1]) + ) + tOtO = tiled_mma_pv.make_fragment_C(tOtO_shape) + tOtO_layout = cute.append( + tOtO.layout, + cute.make_layout( + common_params.L // self.mma_pv_tiler[1], + stride=self.mma_pv_tiler[1] // self.warps_in_n, + ), + ) + tOtO = cute.make_tensor( + common_params.tmem_ptr + self.tmem_o_offset, tOtO_layout + ) + tOtO = tOtO[None, None, None, iter_n] + + tAcc = tOtO[(None, None), 0, 0] + + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype + ) + tmem_load_tiled_copy = tcgen05.make_tmem_copy(tmem_load_atom, tAcc) + # {$nv-internal-release begin} + # TODO: supports size() on tiled copy. + # {$nv-internal-release end} + tmem_load_thr_copy = tmem_load_tiled_copy.get_slice( + common_params.tidx % (self.num_compute_warps * self.threads_per_warp) + ) + + cta_pv_tiler = ( + self.mma_pv_tiler[0] // self.cluster_shape_mnk[0], + self.mma_pv_tiler[1], + self.mma_pv_tiler[2], + ) + # Flatten divide and partition global tensors for O + cta_pv_tiler_mn = cute.select(cta_pv_tiler, mode=[0, 1]) + + gO = None + if cutlass.const_expr(common_params.mAccO is not None): + gO = cute.local_tile( + common_params.mAccO[None, common_params.blk_coord[3], None, None, None], + cta_pv_tiler_mn, + ( + common_params.blk_coord[0], + iter_n, + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + ) + cO = cute.local_tile( + cute.make_identity_tensor( + common_params.mAccO[ + None, common_params.blk_coord[3], None, None, None + ].shape + ), + cta_pv_tiler_mn, + ( + common_params.blk_coord[0], + iter_n, + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + ) + else: + gO = cute.local_tile( + common_params.mO, + cta_pv_tiler_mn, + ( + common_params.blk_coord[0], + iter_n, + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + ) + cO = cute.local_tile( + cute.make_identity_tensor(common_params.mO.shape), + cta_pv_tiler_mn, + ( + common_params.blk_coord[0], + iter_n, + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + ) + tTR_tAcc = tmem_load_thr_copy.partition_S(tAcc) + tTR_gO = tmem_load_thr_copy.partition_D(gO) + tTR_cO = tmem_load_thr_copy.partition_D(cO) + tTR_rAcc = cute.make_fragment_like(tTR_gO, self.acc_dtype) + return tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc # type: ignore[return-value] + + def get_correction_factor( + self, + common_params: SimpleNamespace, + p_cor_consumer_state: pipeline.PipelineState, + ) -> tuple[ + pipeline.PipelineState, + cutlass.Float32, + cutlass.Float32, + cutlass.Float32, + cutlass.Int32, + ]: + """Get the correction factor from the P correction consumer state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param p_cor_consumer_state: The P correction consumer state + :type p_cor_consumer_state: pipeline.PipelineState + + :return: The P correction consumer state, the row_sum, the row_max, and the correction factor + :rtype: tuple[pipeline.PipelineState, cutlass.Float32, cutlass.Float32, cutlass.Float32, cutlass.Int32] + """ + common_params.p_cor_pipeline.consumer_wait(p_cor_consumer_state) + tidx = common_params.tidx % (self.num_compute_warps * self.threads_per_warp) + # load correction factor + _, tAcc, _, _, _, _ = self._tmem_load_partition( + common_params, common_params.tiled_mma_pv, 0 + ) + corr_layout = cute.make_layout( + (tAcc.shape[0], (4, tAcc.shape[1][1]), self.p_cor_stage), + stride=(tAcc.stride[0], (1, tAcc.stride[1][1]), 4), + ) + tCor = cute.make_tensor( + common_params.tmem_ptr + self.correction_factor_offset, corr_layout + ) + cCor = cute.make_identity_tensor(tCor.shape) + corr_tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(4)), self.acc_dtype + ) + corr_tmem_load_tiled_copy = tcgen05.make_tmem_copy(corr_tmem_load_atom, tCor) + corr_tmem_load_thr_copy = corr_tmem_load_tiled_copy.get_slice(tidx) + tCor_for_copy = corr_tmem_load_thr_copy.partition_S(tCor) + cCor_for_copy = corr_tmem_load_thr_copy.partition_D(cCor) + rCor = cute.make_fragment_like( + cCor_for_copy[None, None, None, 0], self.acc_dtype + ) + rCor_int = cute.make_tensor( + cute.recast_ptr(rCor.iterator, dtype=cutlass.Int32), rCor.layout + ) + cute.copy( + corr_tmem_load_tiled_copy, + tCor_for_copy[None, None, None, p_cor_consumer_state.index], + rCor, + ) + row_sum = rCor[0] + row_max = rCor[1] + correction_factor = rCor[2] + no_correction = rCor_int[3] + + common_params.p_cor_pipeline.consumer_release(p_cor_consumer_state) + p_cor_consumer_state.advance() + return p_cor_consumer_state, row_sum, row_max, correction_factor, no_correction + + @cute.jit + def rescale( + self, + common_params: SimpleNamespace, + mma_o_consumer_state: pipeline.PipelineState, + correction_factor: cutlass.Float32, + no_correction: cutlass.Int32, + ) -> pipeline.PipelineState: + """Rescale for one k-tile. Updates the related pipeline state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param mma_o_consumer_state: The mma o consumer state + :type mma_o_consumer_state: pipeline.PipelineState + :param correction_factor: The correction factor + :type correction_factor: cutlass.Float32 + :param no_correction: Whether to apply correction factor + :type no_correction: cutlass.Int32 + + :return: The MMA o consumer state + :rtype: pipeline.PipelineState + """ + skip_correction = cute.arch.vote_all_sync(no_correction == 1) + for iter_n in cutlass.range_constexpr(self.iterations_pv_n): + common_params.mma_o_pipeline.consumer_wait(mma_o_consumer_state) + if not skip_correction: + # tmem load tiled copy and partition results. + tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc = ( + self._tmem_load_partition( + common_params, common_params.tiled_mma_pv, iter_n + ) + ) + + # tmem store tiled copy + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype + ) + tmem_store_tiled_copy = tcgen05.make_tmem_copy(tmem_store_atom, tAcc) + + # load o + cute.copy(tmem_load_tiled_copy, tTR_tAcc, tTR_rAcc) + # rescale, using `mul_packed_f32x2` to reduce the number of instructions + for i in cutlass.range( + cute.size(tTR_rAcc), vectorize=True, unroll_full=True + ): + tTR_rAcc[i] = tTR_rAcc[i] * correction_factor + + # store o to tensor memory for next k tile + cute.copy(tmem_store_tiled_copy, tTR_rAcc, tTR_tAcc) + + cute.arch.fence_view_async_tmem_store() + common_params.mma_o_pipeline.consumer_release(mma_o_consumer_state) + mma_o_consumer_state.advance() + + return mma_o_consumer_state + + @cute.jit + def epilogue( + self, + common_params: SimpleNamespace, + epilogue_params: SimpleNamespace, + mma_o_consumer_state: pipeline.PipelineState, + row_sum: cutlass.Float32, + row_max: cutlass.Float32, + ) -> pipeline.PipelineState: + """Epilogue for one k-tile. Updates the related pipeline state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param epilogue_params: The epilogue parameters + :type epilogue_params: SimpleNamespace + :param mma_o_consumer_state: The mma o consumer state + :type mma_o_consumer_state: pipeline.PipelineState + :param row_sum: The row sum + :type row_sum: cutlass.Float32 + :param row_max: The row max + :type row_max: cutlass.Float32 + + :return: The MMA o consumer state + :rtype: pipeline.PipelineState + """ + + tidx = common_params.tidx % (self.num_compute_warps * self.threads_per_warp) + + # exchange row_sum between warps (0, 1) and (2, 3) + if cutlass.const_expr(self.warps_in_n == 2): + common_params.smem_exchange[tidx] = row_sum + self.epilogue_exchange_sync_bar.wait() + # (64, 2) + row_sum = ( + row_sum + + common_params.smem_exchange[ + (tidx + 64) % (self.num_compute_warps * self.threads_per_warp) + ] + ) + # mma_o pipeline consumer wait + for iter_n in cutlass.range_constexpr(self.iterations_pv_n): + common_params.mma_o_pipeline.consumer_wait(mma_o_consumer_state) + # tmem load tiled copy and partition results. + tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc = ( + self._tmem_load_partition( + common_params, common_params.tiled_mma_pv, iter_n + ) + ) + + # load o + cute.copy(tmem_load_tiled_copy, tTR_tAcc, tTR_rAcc) + + # apply output scale and normalize by row_sum + for i in cutlass.range( + cute.size(tTR_rAcc), vectorize=True, unroll_full=True + ): + tTR_rAcc[i] = ( + tTR_rAcc[i] + * epilogue_params.output_scale + * cute.arch.rcp_approx(row_sum) + ) + + # store o to global memory + tR2G_rO_src = None + tR2G_rO_dst = tTR_gO + if cutlass.const_expr(common_params.mAccO is None): + tR2G_rO_src = cute.make_fragment_like(tTR_gO, self.o_dtype) + # using final output dtype for o + tR2G_rO_src.store(tTR_rAcc.load().to(self.o_dtype)) + else: + # using accumulate dtype for o + tR2G_rO_src = tTR_rAcc + + if cute.elem_less(tTR_cO[0][0], common_params.H): + cute.autovec_copy( + tR2G_rO_src, + tR2G_rO_dst, + l1c_evict_priority=cute.nvgpu.CacheEvictionPriority.NO_ALLOCATE, + ) + + # store the lse to global memory + cta_pv_tiler = ( + self.mma_pv_tiler[0] // self.cluster_shape_mnk[0], + self.mma_pv_tiler[1], + self.mma_pv_tiler[2], + ) + gLSE = None + cLSE = None + if cutlass.const_expr(epilogue_params.mAccLSE is None): + gLSE = cute.local_tile( + epilogue_params.mLSE, + (cta_pv_tiler[0], 1, 1), + ( + common_params.blk_coord[0], + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + (1, 1, 1), + ) + cLSE = cute.local_tile( + cute.make_identity_tensor(epilogue_params.mLSE.shape), + (cta_pv_tiler[0], 1, 1), + ( + common_params.blk_coord[0], + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + (1, 1, 1), + ) + + else: + gLSE = cute.local_tile( + epilogue_params.mAccLSE[ + None, common_params.blk_coord[3], None, None + ], + (cta_pv_tiler[0], 1, 1), + ( + common_params.blk_coord[0], + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + (1, 1, 1), + ) + cLSE = cute.local_tile( + cute.make_identity_tensor( + epilogue_params.mAccLSE[ + None, common_params.blk_coord[3], None, None + ].shape + ), + (cta_pv_tiler[0], 1, 1), + ( + common_params.blk_coord[0], + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + (1, 1, 1), + ) + lse = ( + cute.math.log2(row_sum, fastmath=True) + + epilogue_params.softmax_scale_log2 * row_max + ) + if cutlass.const_expr(self.warps_in_n == 2): + if cute.elem_less(cLSE[tidx][0], common_params.H): + gLSE[tidx] = lse + + cute.arch.fence_view_async_tmem_load() + common_params.mma_o_pipeline.consumer_release(mma_o_consumer_state) + mma_o_consumer_state.advance() + + return mma_o_consumer_state + + def make_and_init_load_qkv_pipeline( + self, load_qkv_mbar_ptr, cta_layout_vmnk, load_stages, tx_count + ) -> pipeline.PipelineTmaUmma: + """Create and initialize the tma load qkv pipeline. + + :param load_qkv_mbar_ptr: The load qkv mbar pointer + :type load_qkv_mbar_ptr: cute.Tensor + :param cta_layout_vmnk: The cta layout vmnk + :type cta_layout_vmnk: tuple[int, int, int] + :param load_stages: The load stages + :type load_stages: list[int] + :param tx_count: The tx count + :type tx_count: int + + :return: The tma load qkv pipeline + :rtype: pipeline.PipelineTmaUmma + """ + load_qkv_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.load_tma_k_warp_id]) + ) + load_qkv_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + return pipeline.PipelineTmaUmma.create( + barrier_storage=load_qkv_mbar_ptr, + num_stages=load_stages, + producer_group=load_qkv_producer_group, + consumer_group=load_qkv_consumer_group, + tx_count=tx_count, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + + def make_and_init_mma_s_pipeline( + self, mma_s_mbar_ptr, cta_layout_vmnk + ) -> pipeline.PipelineUmmaAsync: + """Create and initialize the mma s pipeline. + + :param mma_s_mbar_ptr: The mma s mbar pointer + :type mma_s_mbar_ptr: cute.Tensor + :param cta_layout_vmnk: The cta layout vmnk + :type cta_layout_vmnk: tuple[int, int, int] + + :return: The mma s pipeline + :rtype: pipeline.PipelineUmmaAsync + """ + + mma_s_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + consumer_thread_size = ( + self.threads_per_warp + * len(self.compute_warp_ids) + * self.cluster_shape_mnk[0] + ) + mma_s_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + consumer_thread_size, + ) + return pipeline.PipelineUmmaAsync.create( + barrier_storage=mma_s_mbar_ptr, + num_stages=self.mma_s_stage, + producer_group=mma_s_producer_group, + consumer_group=mma_s_consumer_group, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + + def make_and_init_p_mma_pipeline( + self, p_mma_mbar_ptr, cta_layout_vmnk + ) -> pipeline.PipelineAsyncUmma: + """Create and initialize the p mma pipeline. + + :param p_mma_mbar_ptr: The p mma mbar pointer + :type p_mma_mbar_ptr: cute.Tensor + :param cta_layout_vmnk: The cta layout vmnk + :type cta_layout_vmnk: tuple[int, int, int] + + :return: The p mma pipeline + :rtype: pipeline.PipelineAsyncUmma + """ + + producer_thread_size = ( + self.threads_per_warp + * len(self.compute_warp_ids) + * self.cluster_shape_mnk[0] + ) + p_mma_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + producer_thread_size, + ) + p_mma_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + return pipeline.PipelineAsyncUmma.create( + barrier_storage=p_mma_mbar_ptr, + num_stages=self.p_mma_stage, + producer_group=p_mma_producer_group, + consumer_group=p_mma_consumer_group, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + + def make_and_init_p_cor_pipeline( + self, p_cor_mbar_ptr + ) -> pipeline.PipelineAsyncUmma: + """Create and initialize the p correction pipeline. + + :param p_cor_mbar_ptr: The p correction mbar pointer + :type p_cor_mbar_ptr: cute.Tensor + + :return: The p correction pipeline + :rtype: pipeline.PipelineAsyncUmma + """ + + producer_thread_size = self.threads_per_warp * len(self.compute_warp_ids) + p_cor_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + producer_thread_size, + ) + p_cor_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + producer_thread_size, + ) + return pipeline.PipelineAsync.create( + barrier_storage=p_cor_mbar_ptr, + num_stages=self.p_cor_stage, + producer_group=p_cor_producer_group, + consumer_group=p_cor_consumer_group, + defer_sync=True, + ) + + def make_and_init_mma_o_pipeline( + self, mma_o_mbar_ptr, cta_layout_vmnk + ) -> pipeline.PipelineUmmaAsync: + """Create and initialize the mma o pipeline. + + :param mma_o_mbar_ptr: The mma o mbar pointer + :type mma_o_mbar_ptr: cute.Tensor + :param cta_layout_vmnk: The cta layout vmnk + :type cta_layout_vmnk: tuple[int, int, int] + + :return: The mma o pipeline + :rtype: pipeline.PipelineUmmaAsync + """ + + mma_o_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + consumer_thread_size = ( + self.threads_per_warp + * len(self.compute_warp_ids) + * self.cluster_shape_mnk[0] + ) + mma_o_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + consumer_thread_size, + ) + return pipeline.PipelineUmmaAsync.create( + barrier_storage=mma_o_mbar_ptr, + num_stages=self.mma_o_stage, + producer_group=mma_o_producer_group, + consumer_group=mma_o_consumer_group, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + + @staticmethod + def _compute_grid( + o: cute.Tensor, + split_kv: cutlass.Int32, + cluster_shape_mnk: Tuple[int, int, int], + max_active_clusters: int, + is_persistent: bool, + ) -> Tuple[MLAStaticTileSchedulerParams, Tuple[int, int, int]]: + """Compute grid shape for the output tensor C. + + :param c: The output tensor C + :type c: cute.Tensor + :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type cta_tile_shape_mnk: tuple[int, int, int] + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] + + :return: Tile scheduler parameters and grid shape. + :rtype: tuple[MLAStaticTileSchedulerParams, tuple[int, int, int]] + """ + o_shape = o.shape + tile_sched_params = create_mla_static_tile_scheduler_params( + is_persistent, + cute.size(o_shape[3]), + cute.size(o_shape[2]), + cluster_shape_mnk, + split_kv, + ) + grid = MLAStaticTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + + return tile_sched_params, grid + + @staticmethod + def get_workspace_size( + H: int, + S: int, + D: int, + B: int, + split_kv: int, + acc_dtype: Type[cutlass.Numeric], + ) -> int: + """Get the extra workspace(device memory) size for the MLA kernel when split_kv is not 1. + + :param H: The height of the output tensor C + :type H: int + :param S: The sequence length of the output tensor C + :type S: int + :param D: The depth of the output tensor C + :type D: int + :param B: The batch size of the output tensor C + :type B: int + :param split_kv: The split key-value of the output tensor C + :type split_kv: int + :param acc_dtype: The data type of the output tensor C + :type acc_dtype: Type[cutlass.Numeric] + + :return: The workspace size for the MLA kernel + :rtype: int + """ + if split_kv == 1: + return 0 + # Decode packs heads into a physical 128-wide MMA-M tile. For H < 128, + # split-KV partials can still touch the padded head lanes before + # reduction, so size the workspace for max(H, 128). Mirrors the same + # padding applied in initialize_workspace(). See #3235. + workspace_heads = max(H, 128) + return B * workspace_heads * S * split_kv * (D + 1) * acc_dtype.width // 8 + + @cute.jit + def initialize_workspace( + self, + H: cutlass.Int32, + D: cutlass.Int32, + S: cutlass.Int32, + B: cutlass.Int32, + split_kv: cutlass.Int32, + acc_dtype: Type[cutlass.Numeric], + workspace: cute.Tensor, + ) -> tuple[cute.Tensor, cute.Tensor]: + """Initialize the workspace for the MLA kernel. Construct the intermediate tensors + acc_o and acc_lse. + + :param H: The height of the output tensor C + :type H: cutlass.Int32 + :param D: The depth of the output tensor C + :type D: cutlass.Int32 + :param S: The sequence length of the output tensor C + :type S: cutlass.Int32 + :param B: The batch size of the output tensor C + :type B: cutlass.Int32 + :param split_kv: The split key-value of the output tensor C + :type split_kv: cutlass.Int32 + :param acc_dtype: The data type of the output tensor C + :type acc_dtype: Type[cutlass.Numeric] + :param workspace: The workspace tensor + :type workspace: cute.Tensor + + :return: The output tensor C and the workspace tensor + :rtype: tuple[cute.Tensor, cute.Tensor] + """ + acc_o, acc_lse = None, None + if cutlass.const_expr(workspace is not None): + # Pad head dim to the physical 128-wide MMA-M tile. Without this, + # H<128 split-KV partials write past the workspace. See #3235. + workspace_H = cutlass.max(H, cutlass.Int32(128)) + align = 256 // self.q_dtype.width + acc_o_layout = cute.make_layout( + (workspace_H, split_kv, D, S, B), + stride=( + cute.assume(split_kv * D, align), + cute.assume(D, align), + 1, + cute.assume(split_kv * workspace_H * D, align), + cute.assume(workspace_H * split_kv * S * D, align), + ), + ) + acc_o_iter = cute.recast_ptr(workspace.iterator, dtype=acc_dtype) + acc_o = cute.make_tensor(acc_o_iter, acc_o_layout) + acc_lse_layout = cute.make_layout( + (workspace_H, split_kv, S, B), + stride=( + split_kv, + 1, + workspace_H * split_kv, + workspace_H * split_kv * S, + ), + ) + acc_lse_iter = cute.recast_ptr( + workspace.iterator + cute.cosize(acc_o_layout) * acc_dtype.width // 8, + dtype=acc_dtype, + ) + acc_lse = cute.make_tensor(acc_lse_iter, acc_lse_layout) + return acc_o, acc_lse + + @staticmethod + def can_implement( + B: int, + S: int, + K: int, + H: int, + L: int, + R: int, + in_dtype: Type[cutlass.Numeric], + out_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + lse_dtype: Type[cutlass.Numeric], + mma_qk_tiler_mn: Tuple[int, int], + mma_pv_tiler_mn: Tuple[int, int], + is_persistent: bool, + is_var_seq: bool, + is_var_split_kv: bool, + page_size: int, + ) -> bool: + """Check if the MLA kernel can be implemented. + + :param B: The batch size of the output tensor C + :type B: int + :param S: The sequence length of the output tensor C + :type S: int + :param K: The width of the output tensor KV + :type K: int + :param H: The number of heads of the output tensor C + :type H: int + :param L: The number of latent dimensions of the tensor KV + :type L: int + :param R: The number of rope dimensions of the tensor C_rope + :type R: int + :param in_dtype: The data type of the input tensor + :type in_dtype: Type[cutlass.Numeric] + :param out_dtype: The data type of the output tensor + :type out_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param lse_dtype: The data type of the log-sum-exp + :type lse_dtype: Type[cutlass.Numeric] + :param mma_qk_tiler_mn: The tile shape of the query-key matrix multiplication + :type mma_qk_tiler_mn: Tuple[int, int] + :param mma_pv_tiler_mn: The tile shape of the probability-value matrix multiplication + :type mma_pv_tiler_mn: Tuple[int, int] + :param is_persistent: Whether to use persistent kernel optimization + :type is_persistent: bool + :param is_var_seq: Whether to use variable sequence length + :type is_var_seq: bool + :param is_var_split_kv: Whether to use variable split_kv + :type is_var_split_kv: bool + :param page_size: The page size of the page table + :type page_size: int + + :return: Whether the MLA kernel can be implemented + :rtype: bool + """ + if L != 512 or R != 64: + return False + if in_dtype not in [cutlass.Float8E4M3FN]: + return False + if out_dtype not in [cutlass.Float8E4M3FN, cutlass.BFloat16]: + return False + if acc_dtype != cutlass.Float32 or lse_dtype != cutlass.Float32: + return False + # page size equals 1 is prohibited by tma specification, not 128B aligned. + if mma_qk_tiler_mn[1] % page_size != 0 or page_size == 1: + return False + if mma_qk_tiler_mn[0] != mma_pv_tiler_mn[0] or mma_qk_tiler_mn[0] != 128: + return False + if is_var_split_kv and not is_var_seq: + return False + if H > 128: + return False + if S <= 0 or S > 4: + return False + if K <= 0: + return False + return True + + +def run( + batch_size: int, + seq_len_q: int, + seq_len_k: int, + num_heads: int, + latent_dim: int, + rope_dim: int, + in_dtype: Type[cutlass.Numeric], + out_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + lse_dtype: Type[cutlass.Numeric], + mma_qk_tiler_mn: Tuple[int, int], + mma_pv_tiler_mn: Tuple[int, int], + split_kv: int, + is_persistent: bool, + is_var_seq: bool, + is_var_split_kv: bool, + page_size: int, + softmax_scale: float, + output_scale: float, + skip_correction_threshold: float, + tolerance: float, + warmup_iterations: int, + iterations: int, + skip_ref_check: bool, + use_cold_l2: bool, + enable_pdl: bool = False, + **kwargs, +): + """Execute Multi-Head Latent Attention (MLA) on Blackwell architecture and validate results. + + This function creates random input tensors for query latent/rope, compressed latent/rope, and value, + then performs the complete MLA computation pipeline. It supports configurable data types, tiling parameters, + page table, variable sequence length, and variable split_kv. Results can be validated against a PyTorch reference + implementation or run multiple times for performance measurement. + + :param batch_size: Batch size + :type batch_size: int + :param seq_len_q: Sequence length of Q + :type seq_len_q: int + :param seq_len_k: Sequence length of K + :type seq_len_k: int + :param num_heads: Number of heads + :type num_heads: int + :param latent_dim: dimension of query/compressed latent + :type latent_dim: int + :param rope_dim: dimension of query/compressed rope + :type rope_dim: int + :param in_dtype: Input data type for query/compressed latent/rope tensors + :type in_dtype: Type[cutlass.Numeric] + :param out_dtype: Output data type for attention output + :type out_dtype: Type[cutlass.Numeric] + :param acc_dtype: Accumulator data type for query-key matrix multiplication + :type acc_dtype: Type[cutlass.Numeric] + :param lse_dtype: Accumulator data type for log-sum-exp + :type lse_dtype: Type[cutlass.Numeric] + :param mma_qk_tiler_mn: Matrix multiply accumulate tile shape (M, N) for query-key matrix multiplication + :type mma_qk_tiler_mn: Tuple[int, int] + :param mma_pv_tiler_mn: Matrix multiply accumulate tile shape (M, N) for probability-value matrix multiplication + :type mma_pv_tiler_mn: Tuple[int, int] + :param split_kv: Split key-value + :type split_kv: int + :param is_persistent: Whether to use persistent kernel optimization + :type is_persistent: bool + :param is_var_seq: Whether to use variable sequence length + :type is_var_seq: bool + :param is_var_split_kv: Whether to use variable split_kv + :type is_var_split_kv: bool + :param page_size: Page size of the page table + :type page_size: int + :param softmax_scale: Attention score scaling factor + :type softmax_scale: float + :param output_scale: Output scaling factor + :type output_scale: float + :param skip_correction_threshold: Threshold to skip correction + :type skip_correction_threshold: float + :param tolerance: Maximum acceptable error for validation + :type tolerance: float + :param warmup_iterations: Number of warmup iterations + :type warmup_iterations: int + :param iterations: Number of iterations to run for performance testing + :type iterations: int + :param skip_ref_check: Skip validation against reference implementation + :type skip_ref_check: bool + :param use_cold_l2: Whether to use cold L2 cache + :type use_cold_l2: bool + + :raises ValueError: If input shapes are incompatible or head dimension is unsupported + :raises RuntimeError: If GPU is unavailable for computation + """ + + print("Running Blackwell MLA test with:") + print(f" batch_size: {batch_size}") + print(f" seq_len_q: {seq_len_q}") + print(f" seq_len_k: {seq_len_k}") + print(f" num_heads: {num_heads}") + print(f" latent_dim: {latent_dim}") + print(f" rope_dim: {rope_dim}") + print(f" in_dtype: {in_dtype}") + print(f" out_dtype: {out_dtype}") + print(f" acc_dtype: {acc_dtype}") + print(f" mma_qk_tiler_mn: {mma_qk_tiler_mn}") + print(f" mma_pv_tiler_mn: {mma_pv_tiler_mn}") + print(f" split_kv: {split_kv}") + print(f" is_persistent: {is_persistent}") + print(f" is_var_seq: {is_var_seq}") + print(f" is_var_split_kv: {is_var_split_kv}") + print(f" page_size: {page_size}") + print(f" softmax_scale: {softmax_scale}") + print(f" output_scale: {output_scale}") + print(f" skip_correction_threshold: {skip_correction_threshold}") + print(f" tolerance: {tolerance}") + print(f" warmup_iterations: {warmup_iterations}") + print(f" iterations: {iterations}") + print(f" skip_ref_check: {skip_ref_check}") + print(f" use_cold_l2: {use_cold_l2}") + + import torch + import cutlass.torch as cutlass_torch + + # Prepare pytorch tensors: Q, K, V (random from 0 to 2) and O (all zero) + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + if not BlackwellMultiHeadLatentAttentionForwardFP8.can_implement( + batch_size, + seq_len_q, + seq_len_k, + num_heads, + latent_dim, + rope_dim, + in_dtype, + out_dtype, + acc_dtype, + lse_dtype, + mma_qk_tiler_mn, + mma_pv_tiler_mn, + is_persistent, + is_var_seq, + is_var_split_kv, + page_size, + ): + raise TypeError( + f"Unsupported testcase {batch_size}, {seq_len_q}, {seq_len_k}, {num_heads}, {latent_dim}, {rope_dim}, {in_dtype}, {out_dtype}, {acc_dtype}, {lse_dtype}, {mma_qk_tiler_mn}, {mma_pv_tiler_mn}, {split_kv}, {is_persistent}, {is_var_seq}, {is_var_split_kv}, {page_size}" + ) + + torch.manual_seed(1111) + + def create_data_tensor( + B, + HK, + D, + dtype, + is_dynamic_layout=True, + page_table=None, + cache_seqs=None, + is_lse=False, + seq_len_q=None, + ): + shape = (B, HK, D) + if page_table is not None: + if cache_seqs is not None: + max_seq_len = torch.max(cache_seqs) + shape = (B * ceil_div(max_seq_len, page_size), page_size, D) + else: + shape = (B * ceil_div(HK, page_size), page_size, D) + + if seq_len_q is not None: + shape = (B, seq_len_q, HK, D) + + # Contiguous row-major: last dim has stride 1 (highest stride_order value = fastest) + if is_lse: + shape = (B, seq_len_q, HK) + leading_dim = 2 + stride_order = (0, 1, 2) + elif seq_len_q is not None: + leading_dim = 3 + stride_order = (0, 1, 2, 3) + else: + leading_dim = 2 + stride_order = (0, 1, 2) + + init_config = cutlass.torch.RandomInitConfig(min_val=-2, max_val=2) + + torch_dtype = ( + cutlass_torch.dtype(dtype) if dtype != cutlass.Float8E4M3FN else torch.int8 + ) + + # Create contiguous dtype torch tensor (cpu) — no permute + torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( + shape, + torch_dtype, + init_type=cutlass.torch.TensorInitType.RANDOM, + init_config=init_config, + ) + + # Create dtype torch tensor (gpu) + torch_tensor_gpu = torch_tensor_cpu.cuda() + + # Create f32 torch tensor (cpu) + f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32) + + # Create dtype cute tensor (gpu) + cute_tensor = from_dlpack(torch_tensor_gpu, assumed_align=16) + cute_tensor.element_type = dtype + if is_dynamic_layout: + cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=leading_dim) + if not is_lse: + cute_tensor = cute_tensor.mark_compact_shape_dynamic( + mode=leading_dim, + stride_order=stride_order, + divisibility=(128 // dtype.width), + ) + + cute_tensor = cutlass_torch.convert_cute_tensor( + f32_torch_tensor, + cute_tensor, + dtype, + is_dynamic_layout=is_dynamic_layout, + ) + + return f32_torch_tensor, cute_tensor, torch_tensor_gpu + + def create_cache_seqs(batch_size, seq_len_k, is_var_seq): + cache_seqs_ref = torch.ones(batch_size, dtype=torch.int32) * seq_len_k + cache_seqs_gpu = cache_seqs_ref.cuda() + cache_seqs = from_dlpack(cache_seqs_gpu, assumed_align=16).mark_layout_dynamic() + if is_var_seq: + max_seq_len = seq_len_k + min_seq_len = int(seq_len_k * 0.8) + cache_seqs_ref = cutlass_torch.create_and_permute_torch_tensor( + (batch_size,), + torch.int32, + init_type=cutlass.torch.TensorInitType.RANDOM, + init_config=cutlass.torch.RandomInitConfig( + min_val=min_seq_len, max_val=max_seq_len + 1 + ), + ) + cache_seqs_gpu = cache_seqs_ref.cuda() + cache_seqs = from_dlpack( + cache_seqs_gpu, + assumed_align=16, + ).mark_layout_dynamic() + return cache_seqs_ref, cache_seqs, cache_seqs_gpu + + def create_page_table(batch_size, seq_len_k, is_var_seq, page_size): + max_seq_len = seq_len_k if not is_var_seq else torch.max(cache_seqs_ref) + page_count = ceil_div(max_seq_len, page_size) + page_table_ref = torch.empty([batch_size, page_count], dtype=torch.int32) + # use transposed index for page table to make sure the value is in bound of `batch_size * seq_len_block`. In practice, the value could be any positive values. This setting is only for testing purpose. + for b in range(batch_size): + for j in range(page_count): + page_table_ref[b, j] = b + j * batch_size + page_table_gpu = page_table_ref.cuda() # contiguous [B, page_count] + page_table = from_dlpack(page_table_gpu, assumed_align=16).mark_layout_dynamic( + leading_dim=1 + ) + return page_table_ref, page_table, page_table_gpu + + def create_block_split_kvs( + batch_size, + split_kv, + cache_seqs_ref, + is_var_split_kv, + mma_qk_tiler_mn, + cluster_shape_mnk, + max_active_clusters, + ): + block_split_kvs_ref, block_split_kvs, block_split_kvs_gpu = None, None, None + # check if split_kv is valid otherwise do auto setting of split_kv + if is_var_split_kv: + block_split_kvs_ref = torch.zeros([batch_size], dtype=torch.int32) + for b in range(batch_size): + block_split_kvs_ref[b] = ( + BlackwellMultiHeadLatentAttentionForwardFP8.get_split_kv( + batch_size, + seq_len_q, + cache_seqs_ref[b].item(), + mma_qk_tiler_mn, + max_active_clusters * cluster_shape_mnk[0], + ) + ) + split_kv = torch.max(block_split_kvs_ref).item() + block_split_kvs_gpu = block_split_kvs_ref.cuda() + block_split_kvs = from_dlpack( + block_split_kvs_gpu, assumed_align=16 + ).mark_layout_dynamic() + elif split_kv <= 0: + split_kv = BlackwellMultiHeadLatentAttentionForwardFP8.get_split_kv( + batch_size, + seq_len_q, + cache_seqs_ref[0].item(), + mma_qk_tiler_mn, + max_active_clusters * cluster_shape_mnk[0], + ) + return split_kv, block_split_kvs_ref, block_split_kvs, block_split_kvs_gpu + + def create_workspace( + num_heads, seq_len_q, latent_dim, batch_size, split_kv, acc_dtype + ): + workspace_size = BlackwellMultiHeadLatentAttentionForwardFP8.get_workspace_size( + num_heads, + seq_len_q, + latent_dim, + batch_size, + split_kv, + acc_dtype, + ) + + workspace, workspace_torch = None, None + if workspace_size > 0: + workspace_torch = torch.empty([workspace_size], dtype=torch.int8).cuda() + workspace = from_dlpack(workspace_torch, assumed_align=32) + return workspace, workspace_torch + + cache_seqs_ref, cache_seqs, cache_seqs_torch = create_cache_seqs( + batch_size, seq_len_k, is_var_seq + ) + page_table_ref, page_table, page_table_torch = create_page_table( + batch_size, seq_len_k, is_var_seq, page_size + ) + cluster_shape_mnk = (2, 1, 1) + hardware_info = utils.HardwareInfo() + max_active_clusters = hardware_info.get_max_active_clusters( + cluster_shape_mnk[0] * cluster_shape_mnk[1] + ) + split_kv, block_split_kvs_ref, block_split_kvs, block_split_kvs_torch = ( + create_block_split_kvs( + batch_size, + split_kv, + cache_seqs_ref, + is_var_split_kv, + mma_qk_tiler_mn, + cluster_shape_mnk, + max_active_clusters, + ) + ) + + q_latent_ref, q_latent, q_latent_torch = create_data_tensor( + batch_size, + num_heads, + latent_dim, + in_dtype, + is_dynamic_layout=True, + seq_len_q=seq_len_q, + ) + q_rope_ref, q_rope, q_rope_torch = create_data_tensor( + batch_size, + num_heads, + rope_dim, + in_dtype, + is_dynamic_layout=True, + seq_len_q=seq_len_q, + ) + + c_latent_ref, c_latent, c_latent_torch = create_data_tensor( + batch_size, + seq_len_k, + latent_dim, + in_dtype, + is_dynamic_layout=True, + page_table=page_table, + cache_seqs=cache_seqs_ref, + ) + c_rope_ref, c_rope, c_rope_torch = create_data_tensor( + batch_size, + seq_len_k, + rope_dim, + in_dtype, + is_dynamic_layout=True, + page_table=page_table, + cache_seqs=cache_seqs_ref, + ) + o_ref, o, o_torch = create_data_tensor( + batch_size, + num_heads, + latent_dim, + out_dtype, + is_dynamic_layout=True, + seq_len_q=seq_len_q, + ) + lse_ref, lse, lse_torch = create_data_tensor( + batch_size, + num_heads, + 1, + lse_dtype, + is_dynamic_layout=True, + is_lse=True, + seq_len_q=seq_len_q, + ) + workspace, workspace_torch = create_workspace( + num_heads, seq_len_q, latent_dim, batch_size, split_kv, acc_dtype + ) + + mla = BlackwellMultiHeadLatentAttentionForwardFP8( + acc_dtype, + lse_dtype, + mma_qk_tiler_mn, + mma_pv_tiler_mn, + max_active_clusters, + page_size, + skip_correction_threshold, + is_persistent, + is_var_seq, + is_var_split_kv, + enable_pdl, + ) + + # Get current CUDA stream from PyTorch + torch_stream = torch.cuda.current_stream() + # Get the raw stream pointer as a CUstream + stream = cuda.CUstream(torch_stream.cuda_stream) + + # compile mla kernel + compiled_mla = cute.compile( + mla, + q_latent, + q_rope, + c_latent, + c_rope, + page_table, + o, + lse, + workspace, + split_kv, + cache_seqs, + block_split_kvs, + softmax_scale, + output_scale, + stream, + options="--opt-level 2", + ) + + def torch_reference_mla( + q_latent, + q_rope, + c_latent, + c_rope, + page_table, + cache_seqs, + softmax_scale=1.0, + output_scale=1.0, + ): + # Ref tensors are now contiguous: + # q_latent/q_rope: [B, S_q, H, D] + # c_latent/c_rope: [num_pages, page_size, D] + # Concat along last dim and reshape for SDPA [B, S_q, H, D_total] + q_ref = torch.cat([q_latent, q_rope], dim=3) + # KV cache: concat along last dim, already [num_pages, page_size, D_total] + page_count = page_table_ref.shape[1] + k_ref_paged = torch.cat([c_latent, c_rope], dim=2).reshape( + batch_size * page_count, page_size, latent_dim + rope_dim + ) + v_ref_paged = c_latent.reshape(batch_size * page_count, page_size, latent_dim) + + if is_var_seq: + max_seq_len = torch.max(cache_seqs_ref) + else: + max_seq_len = seq_len_k + + k_ref = torch.zeros([batch_size, 1, max_seq_len, latent_dim + rope_dim]) + v_ref = torch.zeros([batch_size, 1, max_seq_len, latent_dim]) + k_ref = torch.index_select( + k_ref_paged, 0, torch.flatten(page_table_ref) + ).reshape(batch_size, 1, -1, latent_dim + rope_dim)[:, :, :max_seq_len, :] + v_ref = torch.index_select( + v_ref_paged, 0, torch.flatten(page_table_ref) + ).reshape(batch_size, 1, -1, latent_dim)[:, :, :max_seq_len, :] + for b in range(batch_size): + k_ref[b, :, cache_seqs_ref[b] :, :] = 0 + v_ref[b, :, cache_seqs_ref[b] :, :] = 0 + import torch.nn.functional as F + + o_ref = F.scaled_dot_product_attention( + q_ref, + k_ref, + v_ref, + attn_mask=None, + dropout_p=0.0, + scale=softmax_scale, + is_causal=False, + ) + s_ref = torch.einsum("bhld,bhsd->bhls", q_ref, k_ref) + s_ref_max, s_ref_max_pos = torch.max(s_ref, dim=-1, keepdim=True) + softmax_scale_log2 = LOG2_E * softmax_scale + s_ref_sum = torch.sum( + torch.exp2((s_ref - s_ref_max) * softmax_scale_log2), dim=-1, keepdim=True + ) + + lse_ref = s_ref_max * softmax_scale_log2 + torch.log2(s_ref_sum) + lse_ref = lse_ref.squeeze(3) # [B, S_q, H] + o_ref = o_ref * output_scale + # o_ref already [B, S_q, H, D_latent] — matches contiguous output layout + + return o_ref, lse_ref + + if skip_correction_threshold > 0.0: + print( + "Skipping correction verification since skip_correction_threshold is greater than 0.0..." + ) + skip_ref_check = True + if not skip_ref_check: + # Execute kernel once for reference checking + compiled_mla( + q_latent, + q_rope, + c_latent, + c_rope, + page_table, + o, + lse, + workspace, + split_kv, + cache_seqs, + block_split_kvs, + softmax_scale, + output_scale, + stream, + ) + torch.cuda.synchronize() + + print("Verifying results...") + if in_dtype == cutlass.Float8E4M3FN: + tolerance = 0.13 + o_ref, lse_ref = torch_reference_mla( + q_latent_ref, + q_rope_ref, + c_latent_ref, + c_rope_ref, + page_table, + cache_seqs, + softmax_scale, + output_scale, + ) + + if out_dtype in [cutlass.Float8E5M2, cutlass.Float8E4M3FN]: + # {$nv-internal-release begin} + # todo: not sure why, but the below `cute.testing.convert` will cause bus error occasionally in local and ci. + # {$nv-internal-release end} + # convert o back to f32 for comparison + o_fp32, o_fp32_torch = cutlass_torch.cute_tensor_like( + torch.empty(*o_torch.shape, dtype=torch.float32), + cutlass.Float32, + is_dynamic_layout=True, + assumed_align=16, + ) + cute.testing.convert(o, o_fp32) + o = o_fp32_torch.cpu() + ref_fp8, _ = cutlass_torch.cute_tensor_like( + torch.empty(*o_ref.shape, dtype=torch.uint8), + out_dtype, + is_dynamic_layout=True, + assumed_align=16, + ) + o_ref_gpu = o_ref.cuda() + o_ref_f32 = from_dlpack(o_ref_gpu).mark_layout_dynamic(leading_dim=3) + + # convert ref : f32 -> fp8 -> f32 + cute.testing.convert(o_ref_f32, ref_fp8) + cute.testing.convert(ref_fp8, o_ref_f32) + + o_ref = o_ref_gpu.cpu() + else: + o = o_torch.cpu().to(torch.float32) + lse = lse_torch.cpu() + lse_ref = lse_ref.to(cutlass.torch.dtype(lse_dtype)) + # Assert close results + torch.testing.assert_close(o, o_ref, atol=tolerance, rtol=1e-05) + torch.testing.assert_close(lse, lse_ref, atol=tolerance, rtol=1e-05) + print("Results verified successfully!") + + def generate_tensors(): + _, cache_seqs, _ = create_cache_seqs(batch_size, seq_len_k, is_var_seq) + _, page_table, _ = create_page_table( + batch_size, seq_len_k, is_var_seq, page_size + ) + _split_kv, _, block_split_kvs, _ = create_block_split_kvs( + batch_size, + split_kv, + cache_seqs_ref, + is_var_split_kv, + mma_qk_tiler_mn, + cluster_shape_mnk, + max_active_clusters, + ) + + _, q_latent, _ = create_data_tensor( + batch_size, + num_heads, + latent_dim, + in_dtype, + is_dynamic_layout=True, + seq_len_q=seq_len_q, + ) + _, q_rope, _ = create_data_tensor( + batch_size, + num_heads, + rope_dim, + in_dtype, + is_dynamic_layout=True, + seq_len_q=seq_len_q, + ) + + _, c_latent, _ = create_data_tensor( + batch_size, + seq_len_k, + latent_dim, + in_dtype, + is_dynamic_layout=True, + page_table=page_table, + cache_seqs=cache_seqs_ref, + ) + _, c_rope, _ = create_data_tensor( + batch_size, + seq_len_k, + rope_dim, + in_dtype, + is_dynamic_layout=True, + page_table=page_table, + cache_seqs=cache_seqs_ref, + ) + _, o, _ = create_data_tensor( + batch_size, + num_heads, + latent_dim, + out_dtype, + is_dynamic_layout=True, + seq_len_q=seq_len_q, + ) + _, lse, _ = create_data_tensor( + batch_size, + num_heads, + 1, + lse_dtype, + is_dynamic_layout=True, + is_lse=True, + seq_len_q=seq_len_q, + ) + workspace, workspace_torch = create_workspace( + num_heads, seq_len_q, latent_dim, batch_size, _split_kv, acc_dtype + ) + return testing.JitArguments( + q_latent, + q_rope, + c_latent, + c_rope, + page_table, + o, + lse, + workspace, + _split_kv, + cache_seqs, + block_split_kvs, + softmax_scale, + output_scale, + stream, + ) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + q_latent_torch.numel() * q_latent_torch.element_size() + + q_rope_torch.numel() * q_rope_torch.element_size() + + c_latent_torch.numel() * c_latent_torch.element_size() + + c_rope_torch.numel() * c_rope_torch.element_size() + + o_torch.numel() * o_torch.element_size() + + lse_torch.numel() * lse_torch.element_size() + + cache_seqs_torch.numel() * cache_seqs_torch.element_size() + ) + one_workspace_bytes += ( + page_table_torch.numel() * page_table_torch.element_size() + ) + if is_var_split_kv: + one_workspace_bytes += ( + block_split_kvs_torch.numel() * block_split_kvs_torch.element_size() + ) + if workspace_torch is not None: + one_workspace_bytes += ( + workspace_torch.numel() * workspace_torch.element_size() + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + avg_time_us = testing.benchmark( + compiled_mla, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + return avg_time_us # Return execution time in microseconds diff --git a/flashinfer/cute_dsl/attention/monolithic/mla_helpers.py b/flashinfer/cute_dsl/attention/monolithic/mla_helpers.py new file mode 100644 index 0000000000..ac2bee49df --- /dev/null +++ b/flashinfer/cute_dsl/attention/monolithic/mla_helpers.py @@ -0,0 +1,304 @@ +# Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +import cutlass +import cutlass.cute as cute + + +class MLAStaticTileSchedulerParams: + def __init__( + self, + is_persistent: bool, + problem_shape_b: cute.Int32, + problem_shape_s: cute.Int32, + cluster_shape_mnk: cute.Shape, + split_kv: cutlass.Int32, + *, + problem_shape_b_fdd: cute.FastDivmodDivisor = None, + problem_shape_s_fdd: cute.FastDivmodDivisor = None, + split_kv_fdd: cute.FastDivmodDivisor = None, + loc=None, + ip=None, + ): + """The static tile scheduler parameters prepared for MLA static tile scheduler. + + :param is_persistent: Whether to use persistent kernel mode + :type is_persistent: bool + :param problem_shape_b: The shape of the problem + :type problem_shape_b: cute.Int32 + :param problem_shape_s: The shape of the problem in sequence length Q dimension + :type problem_shape_s: cute.Int32 + :param cluster_shape_mnk: The shape of the cluster + :type cluster_shape_mnk: cute.Shape + :param split_kv: The scalar factor for split KV + """ + self.is_persistent = is_persistent + self.problem_shape_b = problem_shape_b + self.problem_shape_s = problem_shape_s + self.problem_shape_b_fdd = problem_shape_b_fdd + self.problem_shape_s_fdd = problem_shape_s_fdd + self.cluster_shape_mnk = cluster_shape_mnk + self.split_kv = split_kv + self.split_kv_fdd = split_kv_fdd + if cutlass.const_expr(problem_shape_b_fdd is None): + self.problem_shape_b_fdd = cute.fast_divmod_create_divisor( + problem_shape_b, loc=loc, ip=ip + ) + if cutlass.const_expr(problem_shape_s_fdd is None): + self.problem_shape_s_fdd = cute.fast_divmod_create_divisor( + problem_shape_s, loc=loc, ip=ip + ) + if cutlass.const_expr(split_kv_fdd is None): + self.split_kv_fdd = cute.fast_divmod_create_divisor( + split_kv, loc=loc, ip=ip + ) + self.loc = loc + self.ip = ip + + def __extract_mlir_values__(self): + values = cutlass.extract_mlir_values(self.problem_shape_b) + values += cutlass.extract_mlir_values(self.problem_shape_s) + values += cutlass.extract_mlir_values(self.split_kv) + values += cutlass.extract_mlir_values(self.problem_shape_b_fdd) + values += cutlass.extract_mlir_values(self.problem_shape_s_fdd) + values += cutlass.extract_mlir_values(self.split_kv_fdd) + return values + + def __new_from_mlir_values__(self, values): + problem_shape_b = cutlass.new_from_mlir_values( + self.problem_shape_b, (values[0],) + ) + problem_shape_s = cutlass.new_from_mlir_values( + self.problem_shape_s, (values[1],) + ) + split_kv = cutlass.new_from_mlir_values(self.split_kv, (values[2],)) + problem_shape_b_fdd = cutlass.new_from_mlir_values( + self.problem_shape_b_fdd, (values[3],) + ) + problem_shape_s_fdd = cutlass.new_from_mlir_values( + self.problem_shape_s_fdd, (values[4],) + ) + split_kv_fdd = cutlass.new_from_mlir_values(self.split_kv_fdd, (values[5],)) + return MLAStaticTileSchedulerParams( + self.is_persistent, + problem_shape_b, + problem_shape_s, + self.cluster_shape_mnk, + split_kv, + problem_shape_b_fdd=problem_shape_b_fdd, + problem_shape_s_fdd=problem_shape_s_fdd, + split_kv_fdd=split_kv_fdd, + loc=self.loc, + ) + + +def create_mla_static_tile_scheduler_params( + is_persistent: bool, + problem_shape_b: cute.Int32, + problem_shape_s: cute.Int32, + cluster_shape_mnk: cute.Shape, + split_kv: cutlass.Int32, +) -> MLAStaticTileSchedulerParams: + return MLAStaticTileSchedulerParams( + is_persistent, problem_shape_b, problem_shape_s, cluster_shape_mnk, split_kv + ) + + +class WorkTileInfo: + def __init__(self, blk_coord: cute.Coord, is_valid: bool): + self.blk_coord = blk_coord + self.is_valid = cutlass.Boolean(is_valid) + + def __extract_mlir_values__(self): + values = cutlass.extract_mlir_values(self.blk_coord) + values += cutlass.extract_mlir_values(self.is_valid) + return values + + def __new_from_mlir_values__(self, values): + new_tile_idx = cutlass.new_from_mlir_values(self.blk_coord, values[:-1]) + new_is_valid_tile = cutlass.new_from_mlir_values(self.is_valid, [values[-1]]) + return WorkTileInfo(new_tile_idx, new_is_valid_tile) + + @property + def is_valid_tile(self) -> cutlass.Boolean: + return self.is_valid + + @property + def tile_idx(self) -> cute.Coord: + return self.blk_coord + + +class MLAStaticTileScheduler: + def __init__( + self, + params: MLAStaticTileSchedulerParams, + current_work_linear_idx: cutlass.Int32, + blk_coord: cute.Coord, + grid_shape: cute.Shape, + *, + is_valid: bool = True, + loc=None, + ip=None, + ): + """The static tile scheduler for MLA split kv kernel. + Based on `is_persistent`, it provides 2 modes for use: + - Persistent mode: Launch fixed blocks and reschedule the data blocks. + - Non-persistent mode: Launch dynamic blocks and exit when the current work is done. + + :param params: The static tile scheduler parameters + :type params: MLAStaticTileSchedulerParams + :param current_work_linear_idx: The linear index of the current work + :type current_work_linear_idx: cutlass.Int32 + :param blk_coord: The coordinate of the current work + :type blk_coord: cute.Coord + :param grid_shape: The shape of the grid + :type grid_shape: cute.Shape + :param is_valid: Whether the current work is valid + :type is_valid: bool + """ + self.params = params + self.blk_coord = blk_coord + self.grid_shape = grid_shape + self.current_work_linear_idx = current_work_linear_idx + if params.is_persistent: + self.persistent_blk_layout = cute.make_layout( + ( + params.cluster_shape_mnk[0], + params.problem_shape_s, + params.problem_shape_b, + params.split_kv, + ), + loc=loc, + ip=ip, + ) + self.num_blocks = cute.size(self.persistent_blk_layout, loc=loc, ip=ip) + # Used for persistent scheduling + self.num_persistent_sm = cute.size(grid_shape, loc=loc, ip=ip) + else: + self.is_valid = is_valid + self.loc = loc + self.ip = ip + + @staticmethod + def get_grid_shape( + params: MLAStaticTileSchedulerParams, + max_active_clusters: int, + *, + loc=None, + ip=None, + ) -> cute.Shape: + # called by host + grid_shape = ( + params.cluster_shape_mnk[0], + params.problem_shape_b * params.problem_shape_s, + params.split_kv, + ) + if params.is_persistent: + return ( + cutlass.min( + max_active_clusters * cute.size(params.cluster_shape_mnk), + cute.size(grid_shape, loc=loc, ip=ip), + ), + 1, + 1, + ) + else: + return grid_shape + + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + is_valid = ( + self.current_work_linear_idx < self.num_blocks + if self.params.is_persistent + else self.is_valid + ) + + if self.params.is_persistent: + current_work_cluster_batch, cluster_idx = ( + self.current_work_linear_idx // self.params.cluster_shape_mnk[0], + self.current_work_linear_idx % self.params.cluster_shape_mnk[0], + ) + current_work_s_batch, s_idx = divmod( + current_work_cluster_batch, self.params.problem_shape_s_fdd + ) + current_work_b_batch, b_idx = divmod( + current_work_s_batch, self.params.problem_shape_b_fdd + ) + _, split_kv_idx = divmod(current_work_b_batch, self.params.split_kv_fdd) + + blk_coord = (cluster_idx, s_idx, b_idx, split_kv_idx) + else: + s_idx, b_idx = divmod(self.blk_coord[1], self.params.problem_shape_b_fdd) + blk_coord = (self.blk_coord[0], s_idx, b_idx, self.blk_coord[2]) + + return WorkTileInfo(blk_coord, is_valid) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def advance_to_next_work(self, *, advance_count=1, loc=None, ip=None): + if self.params.is_persistent: + self.current_work_linear_idx += advance_count * self.num_persistent_sm + else: + self.is_valid = False + + def __extract_mlir_values__(self): + values = cutlass.extract_mlir_values(self.params) + values.extend(cutlass.extract_mlir_values(self.current_work_linear_idx)) + values.extend(cutlass.extract_mlir_values(self.blk_coord)) + values.extend(cutlass.extract_mlir_values(self.grid_shape)) + return values + + def __new_from_mlir_values__(self, values): + assert len(values) == 13 + new_params = cutlass.new_from_mlir_values(self.params, values[0:6]) + new_current_work_linear_idx = cutlass.new_from_mlir_values( + self.current_work_linear_idx, [values[6]] + ) + new_blk_coord = cutlass.new_from_mlir_values(self.blk_coord, values[7:10]) + new_grid_shape = cutlass.new_from_mlir_values(self.grid_shape, values[10:]) + return MLAStaticTileScheduler( + new_params, new_current_work_linear_idx, new_blk_coord, new_grid_shape + ) + + +def create_mla_static_tile_scheduler( + params: MLAStaticTileSchedulerParams, + blk_coord: cute.Coord, + grid_shape: cute.Shape, +) -> MLAStaticTileScheduler: + return MLAStaticTileScheduler(params, blk_coord[0], blk_coord, grid_shape) + + +LOG2_E = 1.4426950408889634074 +# avoid register indexing on array. +MAX_SPLITS = 256 + + +def ceil_div(a: int, b: int) -> int: + return (a + b - 1) // b diff --git a/flashinfer/cute_dsl/attention/wrappers/batch_mla.py b/flashinfer/cute_dsl/attention/wrappers/batch_mla.py index e3d203db22..b2a917c23e 100644 --- a/flashinfer/cute_dsl/attention/wrappers/batch_mla.py +++ b/flashinfer/cute_dsl/attention/wrappers/batch_mla.py @@ -29,7 +29,7 @@ ) from ..config import AttentionFusion -from ..fusion.variant import AttentionVariant, StandardAttention +from ..fusion.variant import AttentionVariant, AttentionWithSink, StandardAttention from ..mla_decode import BlackwellMultiLatentAttentionForward from ..mla_decode_fp8 import BlackwellMultiLatentAttentionForwardFP8 from ..mla_config import MLAConfig @@ -258,8 +258,15 @@ def _compile_mla_kernel( Uses ``@functools.cache`` so repeated calls with the same arguments return the previously compiled kernel in microseconds rather than recompiling (~3 s). For standard attention pass ``variant=None`` - (the default); for custom variants pass the variant instance (hashable - by identity). + (the default); for custom variants pass the variant instance. + + Variants must define value-based ``__hash__``/``__eq__`` for the cache + to work correctly across freshly-constructed instances — see + ``AttentionWithSink`` for an example. Variants that don't override + these (the base ``AttentionVariant`` class doesn't) hash by Python + identity, which is fine for the wrapper pattern (``plan()`` stores + the variant on ``self`` and reuses it across ``run()`` calls) but + breaks any caller that reconstructs the variant per-call. ``AttentionFusion`` is constructed *inside* this function so it never appears in the cache key (it is unhashable). @@ -681,6 +688,7 @@ def cute_dsl_mla_decode( out_dtype: Optional[torch.dtype] = None, is_var_seq: bool = True, enable_pdl: Optional[bool] = None, + sinks: Optional[torch.Tensor] = None, ) -> torch.Tensor: """CuTe DSL MLA decode kernel for Blackwell SM100 (modular variant). @@ -721,6 +729,12 @@ def cute_dsl_mla_decode( enable_pdl : Optional[bool], default=None Whether to enable Programmatic Dependent Launch (PDL). If None, auto-detects based on device capability. + sinks : Optional[torch.Tensor], default=None + Per-head sink values added to the softmax denominator on the first + KV tile (modular-only feature, implemented via the + ``AttentionWithSink`` variant). Shape ``(num_qo_heads,)``; will be + cast to float32 internally. When ``None`` (default), runs standard + softmax attention. Returns ------- @@ -812,6 +826,32 @@ def cute_dsl_mla_decode( is_persistent = not is_var_seq + # Optional variant (currently only AttentionWithSink, exposed via the + # `sinks=` kwarg). Building the variant + extracting params here mirrors + # what BatchMLADecodeCuteDSLWrapper.plan() does. AttentionWithSink + # defines value-based __hash__/__eq__ keyed on the sinks tensor shape and + # dtype, so re-creating the variant per call still hits + # _compile_mla_kernel's @functools.cache as long as those don't change. + variant: Optional[AttentionVariant] = None + params_torch: Optional[torch.Tensor] = None + params_shape: Optional[tuple] = None + if sinks is not None: + # Validate on the *input* tensor: post-conversion .to() returns a + # fresh contiguous tensor, so checking after would silently mask a + # caller's mistake (and never fire). + if not sinks.is_contiguous(): + raise ValueError( + f"sinks tensor must be contiguous, got strides {sinks.stride()} " + f"for shape {sinks.shape}" + ) + variant = AttentionWithSink(sinks) + # NOTE: .to(dtype).to(device) is a no-op (returns same tensor) when + # already fp32 + on query.device — the common case. When sinks is + # supplied in a different dtype/device, this allocates per call; + # callers in tight loops should pre-cast. + params_torch = variant.extra_params.to(torch.float32).to(query.device) + params_shape = tuple(params_torch.shape) + # Validate configuration (cached, negligible overhead after first call) _check_can_implement( torch_dtype=q_dtype, @@ -828,7 +868,8 @@ def cute_dsl_mla_decode( enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl - # Get compiled kernel (cached after first compile) + # Get compiled kernel (cached after first compile). Pass the variant + # only when non-standard so the StandardAttention cache key stays stable. compiled_kernel = _compile_mla_kernel( torch_dtype=q_dtype, torch_out_dtype=o_dtype, @@ -841,6 +882,8 @@ def cute_dsl_mla_decode( skip_correction_threshold=skip_correction_threshold, is_workspace_size_zero=is_workspace_size_zero, enable_pdl=enable_pdl, + variant=variant, + params_shape=params_shape, ) # Call the kernel @@ -858,7 +901,7 @@ def cute_dsl_mla_decode( block_split_kvs, Float32(softmax_scale), Float32(output_scale), - None, # params_in (no variant in standalone function) + params_torch, # variant params tensor (None when no variant) ) if out is not None: diff --git a/flashinfer/mla/_core.py b/flashinfer/mla/_core.py index 894be4962c..1aa6d72793 100644 --- a/flashinfer/mla/_core.py +++ b/flashinfer/mla/_core.py @@ -648,6 +648,7 @@ def trtllm_batch_decode_with_kv_cache_mla( backend: str = "auto", is_var_seq: bool = True, uses_shared_paged_kv_idx: bool = True, + cute_dsl_impl: str = "auto", ) -> torch.Tensor: """ Parameters @@ -674,6 +675,10 @@ def trtllm_batch_decode_with_kv_cache_mla( When using ``trtllm-gen`` backend, it can be a ``torch.Tensor`` with dtype ``torch.float32``. When using ``cute-dsl`` backend, only ``float`` values are supported. sinks: additional value per head in the denominator of the softmax. + Supported by all three backends. On ``cute-dsl`` this requires + the modular implementation; ``cute_dsl_impl="auto"`` (the default) + promotes to modular automatically, and ``cute_dsl_impl="monolithic"`` + with sinks set raises :class:`ValueError`. skip_softmax_threshold_scale_factor: threshold scale factor for skipping softmax operations. Providing a value for this parameter enables skip-softmax sparsity as described in: https://arxiv.org/abs/2512.12087 If no value is provided, then standard attention is used. @@ -684,6 +689,9 @@ def trtllm_batch_decode_with_kv_cache_mla( When set to ``auto``, the backend will be chosen based on the device architecture and kernel availability. For sm_100 and sm_103 (blackwell architecture), ``auto`` will choose ``trtllm-gen`` backend. For sm_120 (blackwell architecture), ``auto`` will choose ``xqa`` backend. + The ``cute-dsl`` backend has two interchangeable implementations + (``monolithic`` and ``modular``) on the same shape/dtype envelope; + which one runs is controlled by the ``cute_dsl_impl`` kwarg below. is_var_seq : bool Whether the sequence length is variable. If True, the sequence length is variable. @@ -693,6 +701,17 @@ def trtllm_batch_decode_with_kv_cache_mla( True (default) uses vLLM/FlashInfer layout with a 2D page table. False uses TRT-LLM layout with a 3D page table ``[batch_size, 2, max_num_pages_per_seq]``. False is only supported for trtllm-gen backend. + cute_dsl_impl : str = "auto" + Which cute-dsl implementation to use. Honored only when + ``backend="cute-dsl"``; ignored for other backends. + + * ``"auto"`` (default) — picks monolithic by default, automatically + promoted to modular when the call uses a feature monolithic + doesn't support (currently ``sinks``). + * ``"modular"`` — strict. Always run the modular kernels. + * ``"monolithic"`` — strict. Always run the monolithic kernels; + raise :class:`ValueError` if the call uses any modular-only + feature (e.g. ``sinks``). Note ---- @@ -862,10 +881,27 @@ def trtllm_batch_decode_with_kv_cache_mla( "cute-dsl backend (MLA decode kernel) does not support tensor bmm2_scale, " "please pass a float value" ) + # `sinks` is supported via the modular AttentionWithSink variant; the + # dispatcher in flashinfer.cute_dsl.attention.mla_dispatch will force + # impl="modular" when sinks is set (monolithic has no variant path). + # The public sinks signature is Optional[List[torch.Tensor]] for + # legacy reasons, but every backend (xqa, trtllm-gen, cute-dsl) + # treats it as a single per-head tensor. Normalise here so the + # downstream cute-dsl path sees a tensor or None; reject the + # ambiguous len>1 case loudly rather than silently dropping tail + # entries. + cute_dsl_sinks: Optional[torch.Tensor] = None if sinks is not None: - raise ValueError( - "cute-dsl backend (MLA decode kernel) does not support sinks" - ) + if isinstance(sinks, (list, tuple)): + if len(sinks) != 1: + raise ValueError( + f"cute-dsl backend (MLA decode kernel) expects sinks " + f"to be a single tensor or a length-1 list/tuple; got " + f"len={len(sinks)}." + ) + cute_dsl_sinks = sinks[0] + else: + cute_dsl_sinks = sinks if sparse_mla_top_k > 0: raise ValueError( "cute-dsl backend (MLA decode kernel) does not support sparse_mla_top_k" @@ -894,6 +930,8 @@ def trtllm_batch_decode_with_kv_cache_mla( out=out, is_var_seq=is_var_seq, enable_pdl=enable_pdl, + sinks=cute_dsl_sinks, + cute_dsl_impl=cute_dsl_impl, ) else: raise ValueError(f"Backend {backend} not supported") diff --git a/tests/attention/test_cute_dsl_mla_decode.py b/tests/attention/test_cute_dsl_mla_decode.py index 6147d10c86..4938f922ad 100644 --- a/tests/attention/test_cute_dsl_mla_decode.py +++ b/tests/attention/test_cute_dsl_mla_decode.py @@ -30,6 +30,16 @@ def skip_if_unsupported(): pytest.skip("CuTe DSL not available") +# Tests that exercise the standalone cute_dsl_mla_decode function or the +# public trtllm_batch_decode_with_kv_cache_mla(backend="cute-dsl") path +# pass this fixture's value as the cute_dsl_impl= kwarg, exercising both +# implementations explicitly. Variant tests use BatchMLADecodeCuteDSLWrapper +# directly (which is modular-only) and are not parametrized here. +@pytest.fixture(params=["modular", "monolithic"], ids=["modular", "monolithic"]) +def cute_dsl_impl(request): + return request.param + + def torch_reference_mla( q_nope, q_rope, @@ -103,7 +113,7 @@ def torch_reference_mla( @pytest.mark.parametrize("q_len", [1, 2]) @pytest.mark.parametrize("enable_pdl", [True, False]) def test_cute_dsl_mla_decode_fp16( - batch_size, seq_len_k, page_size, dtype, q_len, enable_pdl + batch_size, seq_len_k, page_size, dtype, q_len, enable_pdl, cute_dsl_impl ): """Test FP16/BF16 MLA decode kernel.""" skip_if_unsupported() @@ -162,6 +172,7 @@ def test_cute_dsl_mla_decode_fp16( output_scale=output_scale, is_var_seq=False, enable_pdl=enable_pdl, + cute_dsl_impl=cute_dsl_impl, ) # Reference @@ -193,7 +204,9 @@ def test_cute_dsl_mla_decode_fp16( @pytest.mark.parametrize("seq_len_k", [128, 512, 2048]) @pytest.mark.parametrize("page_size", [32, 128]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -def test_cute_dsl_mla_decode_variable_seq_len(batch_size, seq_len_k, page_size, dtype): +def test_cute_dsl_mla_decode_variable_seq_len( + batch_size, seq_len_k, page_size, dtype, cute_dsl_impl +): """Test MLA decode with variable sequence lengths across the batch.""" skip_if_unsupported() @@ -242,6 +255,7 @@ def test_cute_dsl_mla_decode_variable_seq_len(batch_size, seq_len_k, page_size, softmax_scale=softmax_scale, output_scale=output_scale, is_var_seq=True, + cute_dsl_impl=cute_dsl_impl, ) kv_flat = kv_cache.reshape(-1, D_qk) @@ -270,7 +284,7 @@ def test_cute_dsl_mla_decode_variable_seq_len(batch_size, seq_len_k, page_size, @pytest.mark.parametrize("seq_len_k", [128, 512]) @pytest.mark.parametrize("num_heads", [128, 64]) def test_cute_dsl_mla_decode_via_api( - batch_size, seq_len_k, num_heads, page_size=128, enable_pdl=False + batch_size, seq_len_k, num_heads, cute_dsl_impl, page_size=128, enable_pdl=False ): """Test MLA decode via the trtllm_batch_decode_with_kv_cache_mla API with cute-dsl backend.""" skip_if_unsupported() @@ -321,6 +335,7 @@ def test_cute_dsl_mla_decode_via_api( backend="cute-dsl", is_var_seq=False, enable_pdl=enable_pdl, + cute_dsl_impl=cute_dsl_impl, ) assert out.shape == (batch_size, q_len, num_heads, latent_dim) @@ -330,7 +345,9 @@ def test_cute_dsl_mla_decode_via_api( @pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("seq_len_k", [128, 512]) @pytest.mark.parametrize("enable_pdl", [True, False]) -def test_cute_dsl_vs_trtllm_gen(batch_size, seq_len_k, enable_pdl, page_size=64): +def test_cute_dsl_vs_trtllm_gen( + batch_size, seq_len_k, enable_pdl, cute_dsl_impl, page_size=64 +): """Test cute-dsl backend output matches trtllm-gen backend output.""" skip_if_unsupported() @@ -385,7 +402,10 @@ def test_cute_dsl_vs_trtllm_gen(batch_size, seq_len_k, enable_pdl, page_size=64) **common_args, backend="trtllm-gen", is_var_seq=False ) out_cute_dsl = trtllm_batch_decode_with_kv_cache_mla( - **common_args, backend="cute-dsl", is_var_seq=False + **common_args, + backend="cute-dsl", + is_var_seq=False, + cute_dsl_impl=cute_dsl_impl, ) torch.testing.assert_close( @@ -402,7 +422,7 @@ def test_cute_dsl_vs_trtllm_gen(batch_size, seq_len_k, enable_pdl, page_size=64) @pytest.mark.parametrize("num_heads", [128, 64]) @pytest.mark.parametrize("enable_pdl", [False]) def test_cute_dsl_mla_decode_fp8( - batch_size, seq_len_k, page_size, num_heads, enable_pdl + batch_size, seq_len_k, page_size, num_heads, enable_pdl, cute_dsl_impl ): """Test FP8 MLA decode kernel against FP32 reference.""" skip_if_unsupported() @@ -456,6 +476,7 @@ def test_cute_dsl_mla_decode_fp8( softmax_scale=softmax_scale, output_scale=output_scale, enable_pdl=enable_pdl, + cute_dsl_impl=cute_dsl_impl, ) assert out.dtype == torch.bfloat16 @@ -852,6 +873,105 @@ def test_cute_dsl_mla_decode_attention_sink(batch_size, seq_len_k, page_size): torch.testing.assert_close(out, ref_out_cast, atol=1e-2, rtol=1e-2) +@pytest.mark.parametrize("cute_dsl_impl_arg", ["auto", "modular"]) +def test_cute_dsl_mla_decode_via_api_with_sinks(cute_dsl_impl_arg): + """Public trtllm_batch_decode_with_kv_cache_mla(backend='cute-dsl', sinks=...) + works end-to-end on both ``cute_dsl_impl="auto"`` (which auto-promotes + to modular due to sinks) and ``cute_dsl_impl="modular"`` (explicit). + The ``cute_dsl_impl="monolithic"`` case is the strict-error contract + covered separately by test_via_api_monolithic_with_sinks_raises below. + + Single shape is sufficient — sinks correctness across shapes is + already covered by test_cute_dsl_mla_decode_attention_sink; this + test pins the dispatcher's auto/modular branches at the public API. + """ + skip_if_unsupported() + batch_size, seq_len_k, page_size = 4, 2048, 64 + + from flashinfer.mla import trtllm_batch_decode_with_kv_cache_mla + + torch.manual_seed(42) + dtype = torch.bfloat16 + + ( + query, + kv_cache, + block_tables, + seq_lens, + workspace_buffer, + num_heads, + latent_dim, + rope_dim, + ) = _make_mla_test_data(batch_size, seq_len_k, page_size, dtype) + sink = torch.randn((num_heads,), dtype=dtype, device="cuda") + + # The public API takes a 4D KV cache: [num_pages, 1, page_size, D] + out = trtllm_batch_decode_with_kv_cache_mla( + query=query, + kv_cache=kv_cache.unsqueeze(1), + workspace_buffer=workspace_buffer, + qk_nope_head_dim=latent_dim, + kv_lora_rank=latent_dim, + qk_rope_head_dim=rope_dim, + block_tables=block_tables, + seq_lens=seq_lens, + max_seq_len=seq_len_k, + bmm1_scale=1.0 / (latent_dim**0.5), + bmm2_scale=1.0, + sinks=sink, + backend="cute-dsl", + is_var_seq=False, + cute_dsl_impl=cute_dsl_impl_arg, + ) + assert out.shape == (batch_size, query.shape[1], num_heads, latent_dim) + assert torch.isfinite(out).all(), ( + "public-API cute-dsl with sinks produced non-finite values" + ) + + +def test_via_api_monolithic_with_sinks_raises(): + """Strict-mode contract: cute_dsl_impl='monolithic' + sinks must raise + ValueError, never silently substitute modular. Inputs are minimal — + we just need to reach the dispatcher's resolver, not actually run the + kernel.""" + skip_if_unsupported() + + from flashinfer.mla import trtllm_batch_decode_with_kv_cache_mla + + torch.manual_seed(42) + dtype = torch.bfloat16 + ( + query, + kv_cache, + block_tables, + seq_lens, + workspace_buffer, + num_heads, + latent_dim, + rope_dim, + ) = _make_mla_test_data(batch_size=1, seq_len_k=128, page_size=64, dtype=dtype) + sink = torch.randn((num_heads,), dtype=dtype, device="cuda") + + with pytest.raises(ValueError, match="monolithic.*sinks.*modular"): + trtllm_batch_decode_with_kv_cache_mla( + query=query, + kv_cache=kv_cache.unsqueeze(1), + workspace_buffer=workspace_buffer, + qk_nope_head_dim=latent_dim, + kv_lora_rank=latent_dim, + qk_rope_head_dim=rope_dim, + block_tables=block_tables, + seq_lens=seq_lens, + max_seq_len=128, + bmm1_scale=1.0 / (latent_dim**0.5), + bmm2_scale=1.0, + sinks=sink, + backend="cute-dsl", + is_var_seq=False, + cute_dsl_impl="monolithic", + ) + + @pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("seq_len_k", [256, 2048]) @pytest.mark.parametrize("page_size", [64, 128]) diff --git a/tests/attention/test_trtllm_gen_mla.py b/tests/attention/test_trtllm_gen_mla.py index bac935ea62..835fdb6b65 100755 --- a/tests/attention/test_trtllm_gen_mla.py +++ b/tests/attention/test_trtllm_gen_mla.py @@ -823,6 +823,11 @@ def test_trtllm_batch_decode_mla( and layer_dimensions.head_dimensions == smaller_mla_dimensions ): pytest.skip("cute-dsl MLA requires 512 latent dim and 64 rope dim") + # Note: the cute-dsl branch runs the default implementation (auto → + # monolithic, since this test passes no sinks). We don't parametrize + # cute_dsl_impl here because this test's purpose is public-API smoke + # across all three backends; the modular vs monolithic matrix lives + # in test_cute_dsl_mla_decode.py. trtllm_batch_decode_mla( layer_dimensions, batch_size, From ebcec6c6ad3c69613bd7173be6469c2c0683887a Mon Sep 17 00:00:00 2001 From: jingzec Date: Wed, 13 May 2026 02:04:44 -0700 Subject: [PATCH 2/9] add fold_sq for optimization of num heads < 128 --- benchmarks/bench_trtllm_gen_mla.py | 5 +- .../attention/monolithic/mla_decode.py | 27 +- .../attention/monolithic/mla_decode_fp16.py | 242 ++++++++++++++++-- .../attention/monolithic/mla_decode_fp8.py | 242 ++++++++++++++++-- tests/attention/test_cute_dsl_mla_decode.py | 175 +++++++++++++ 5 files changed, 648 insertions(+), 43 deletions(-) diff --git a/benchmarks/bench_trtllm_gen_mla.py b/benchmarks/bench_trtllm_gen_mla.py index f846c70633..16482f219a 100644 --- a/benchmarks/bench_trtllm_gen_mla.py +++ b/benchmarks/bench_trtllm_gen_mla.py @@ -167,10 +167,7 @@ def bench_trtllm_mla( ) args = parser.parse_args() - if args.backend == "cute-dsl": - q_lens = [1, 2, 4] - else: - q_lens = [1, 2, 4, 8, 16] + q_lens = [1, 2, 4, 8, 16] # Main perf sweep — without sinks, same shape grid as the original # script. Doubling every cell with a sinks pass would explode runtime diff --git a/flashinfer/cute_dsl/attention/monolithic/mla_decode.py b/flashinfer/cute_dsl/attention/monolithic/mla_decode.py index 08f4e616e8..67c8a9a831 100644 --- a/flashinfer/cute_dsl/attention/monolithic/mla_decode.py +++ b/flashinfer/cute_dsl/attention/monolithic/mla_decode.py @@ -51,8 +51,17 @@ def _get_split_kv_and_workspace_size( split_kv = BlackwellMultiHeadLatentAttentionForwardFP16.get_split_kv_simplified( B, q_len, max_active_blocks ) + # When folding S_q into heads, the workspace dims are the effective dims + # (num_heads * F, q_len // F). get_workspace_size already pads H<128 to + # 128, so passing num_heads_eff and seq_len_q_eff yields the right size. + mma_qk_tile_m = 128 + fold_sq_ratio = BlackwellMultiHeadLatentAttentionForwardFP16.compute_fold_sq_ratio( + H, q_len, mma_qk_tile_m + ) + num_heads_eff = H * fold_sq_ratio + seq_len_q_eff = q_len // fold_sq_ratio workspace_size = BlackwellMultiHeadLatentAttentionForwardFP16.get_workspace_size( - H, q_len, kv_lora_rank, B, split_kv, cutlass.Float32 + num_heads_eff, seq_len_q_eff, kv_lora_rank, B, split_kv, cutlass.Float32 ) return split_kv, workspace_size @@ -114,6 +123,8 @@ def _get_compiled_mla_kernel( page_size: int, kv_lora_rank: int, qk_rope_head_dim: int, + num_heads: int, + seq_len_q: int, is_persistent: bool, is_var_seq: bool, is_var_split_kv: bool, @@ -145,6 +156,14 @@ def _get_compiled_mla_kernel( cutlass_dtype = torch_to_cutlass_dtype(torch_dtype) cutlass_out_dtype = torch_to_cutlass_dtype(torch_out_dtype) + # Derive the seq_len_q-into-heads fold factor. F > 1 means the kernel + # repacks the [H, S_q] tile to [H*F, S_q/F] internally so MTP / spec-decoding + # with H < 128 fully populates the 128-wide MMA-M tile. + fold_sq_ratio = KernelClass.compute_fold_sq_ratio( + num_heads, seq_len_q, mma_qk_tiler_mn[0] + ) + fold_sq = fold_sq_ratio > 1 + kernel_obj = KernelClass( acc_dtype=cutlass.Float32, lse_dtype=cutlass.Float32, @@ -159,6 +178,9 @@ def _get_compiled_mla_kernel( is_var_seq=is_var_seq, is_var_split_kv=is_var_split_kv, enable_pdl=enable_pdl, + num_heads=num_heads, + seq_len_q=seq_len_q, + fold_sq=fold_sq, ) # All dimensions as sym_int — this matches the original kernel's use of @@ -431,6 +453,7 @@ def cute_dsl_mla_decode( # for fix-length, set is_persistent to True; otherwise, set to False. is_persistent = not is_var_seq + print(f"is_persistent: {is_persistent}") # Validate configuration (cached, negligible overhead after first call) _check_can_implement( @@ -457,6 +480,8 @@ def cute_dsl_mla_decode( page_size=page_size, kv_lora_rank=kv_lora_rank, qk_rope_head_dim=qk_rope_head_dim, + num_heads=H, + seq_len_q=q_len, is_persistent=is_persistent, is_var_seq=is_var_seq, is_var_split_kv=is_var_split_kv, diff --git a/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp16.py b/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp16.py index aa3b2cd475..5ee65218b5 100644 --- a/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp16.py +++ b/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp16.py @@ -171,6 +171,9 @@ def __init__( is_var_seq: bool, is_var_split_kv: bool, enable_pdl: bool, + num_heads: int = 128, + seq_len_q: int = 1, + fold_sq: bool = False, ): """Initializes the configuration for a Blackwell Multi-Head Latent Attention (MLA) kernel. @@ -196,6 +199,17 @@ def __init__( :type is_var_split_kv: bool :param enable_pdl: Whether to use PDL :type enable_pdl: bool + :param num_heads: Number of attention heads (pre-fold). Used for the + per-row spec-decoding (MTP) causal mask q_token_index computation. + :type num_heads: int + :param seq_len_q: Query sequence length (pre-fold). Combined with + ``num_heads`` to derive the per-row q_token used by the causal mask. + :type seq_len_q: int + :param fold_sq: Whether to fold tokens of ``seq_len_q`` into the head + dimension so the M tile becomes [F sub_q_tok][num_heads heads]. + Required when ``num_heads < mma_qk_tiler_mn[0]`` and ``seq_len_q > 1`` + so the M tile is fully populated. + :type fold_sq: bool """ self.latent_dim = 512 @@ -211,6 +225,24 @@ def __init__( self.is_var_seq = is_var_seq self.is_var_split_kv = is_var_split_kv self.enable_pdl = enable_pdl + # Original (pre-fold) num_heads and seq_len_q used for per-row + # spec-decoding (MTP) causal q_token_index computation. When fold_sq is + # True the M tile is laid out as [F sub_q_tok][num_heads heads]; the + # full q_tok for row r is blk_coord[1] * F + (r // num_heads). + self.num_heads = num_heads + self.seq_len_q = seq_len_q + # fold_sq (caller-controlled): whether the folding code path is enabled. + # fold_sq_ratio (derived): fold factor F ≥ 1; the largest divisor of + # seq_len_q with num_heads * F ≤ M_tile and F ≤ seq_len_q. When the + # caller passes fold_sq=False, the kernel ignores the ratio. + # When fold_sq=True but the derived ratio is 1, the folding branch + # is taken with F=1 (a no-op transform). + self.fold_sq = fold_sq + self.fold_sq_ratio = ( + BlackwellMultiHeadLatentAttentionForwardFP16.compute_fold_sq_ratio( + num_heads, seq_len_q, mma_qk_tiler_mn[0] + ) + ) self.cluster_shape_mnk = (2, 1, 1) self.use_2cta_instrs = True # When using 2 CTAs with m=128: warps 0-1 handle accumulation for first half [0, n/2), @@ -423,6 +455,48 @@ def _reinterpret_3d_kv(t): ), ) + # When num_heads < M tile, fold up to F = fold_sq_ratio tokens of + # seq_len_q into the head dimension so M_eff = num_heads * F (≤ M_tile). + # E.g., H=32, S_q=4 → F=4, M_eff=128, S_q_eff=1 + # E.g., H=32, S_q=8 → F=4, M_eff=128, S_q_eff=2 + # This works because MLA shares KV across all heads/queries independently. + # Tensor layout: [H, D, S_q, B] → [H*F, D, S_q/F, B]; relies on + # stride_S == stride_H * H (always true for contiguous [B, S_q, H, D] + # tensors after _reinterpret_4d). + if cutlass.const_expr(self.fold_sq): + F = self.fold_sq_ratio + + def _fold_sq_4d(t): + return cute.make_tensor( + t.iterator, + cute.make_layout( + ( + t.shape[0] * F, + t.shape[1], + t.shape[2] // F, + t.shape[3], + ), + stride=( + t.stride[0], + t.stride[1], + t.stride[2] * F, + t.stride[3], + ), + ), + ) + + q_latent = _fold_sq_4d(q_latent) + q_rope = _fold_sq_4d(q_rope) + o = _fold_sq_4d(o) + # LSE: [H, S_q, B] → [H*F, S_q/F, B] + lse = cute.make_tensor( + lse.iterator, + cute.make_layout( + (lse.shape[0] * F, lse.shape[1] // F, lse.shape[2]), + stride=(lse.stride[0], lse.stride[1] * F, lse.stride[2]), + ), + ) + acc_o, acc_lse = self.initialize_workspace( q_latent.shape[0], q_latent.shape[1], @@ -2324,8 +2398,20 @@ def compute( correction_factor = self.acc_dtype(1) common_params.p_cor_pipeline.producer_acquire(p_cor_producer_state) - # no mask applied - while k_tile_count > 1: + # Number of tiles from the global-K end that may contain causal-masked + # positions. Min k_bound = K - (S_q-1), which can span up to + # ceil((fold_sq_ratio-2)/tile_N)+1 tiles (tile-boundary-crossing case). For + # S_q=1 this reduces to 1 tile — identical to a plain K-bound check. + tile_n = self.mma_qk_tiler[1] + mask_tile_count = (self.fold_sq_ratio - 2 + tile_n - 1) // tile_n + 1 + + # first_mask_tile_idx is the global index of the first tile that may + # need masking. Runtime because it depends on K (per-batch in + # var-seq / split-KV). + first_mask_tile_idx = k_tile_total - mask_tile_count + + # Phase 1: pure unmasked bulk tiles (all columns strictly < min k_bound). + while k_tile_count > 1 and k_index < first_mask_tile_idx: ( mma_s_consumer_state, p_mma_producer_state, @@ -2349,8 +2435,37 @@ def compute( k_index = k_index + 1 k_tile_count = k_tile_count - 1 - # mask applied + # Phase 2: intermediate tiles that overlap the causal/K-bound region + # but are not this work-split's final tile. + while k_tile_count > 1: + ( + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max, + row_sum, + correction_factor, + ) = self.softmax( + common_params, + softmax_params, + k_index, + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max, + row_sum, + correction_factor, + True, + False, + ) + k_index = k_index + 1 + k_tile_count = k_tile_count - 1 + + # Phase 3: this work-split's final tile. if cutlass.const_expr(common_params.mAccO is not None): + # Split-KV: only apply mask when this final tile is globally in + # the mask region (covers both last-split last-tile and straddling + # splits). Runtime comparison. ( mma_s_consumer_state, p_mma_producer_state, @@ -2368,7 +2483,7 @@ def compute( row_max, row_sum, correction_factor, - k_index == k_tile_total - 1, + k_index >= first_mask_tile_idx, True, ) else: @@ -2519,7 +2634,7 @@ def softmax( row_max: cutlass.Float32, row_sum: cutlass.Float32, correction_factor: cutlass.Float32, - is_last_tile: bool, + apply_mask: bool, is_local_last_tile: cutlass.Boolean, ) -> tuple[ pipeline.PipelineState, @@ -2549,8 +2664,10 @@ def softmax( :type row_sum: cutlass.Float32 :param correction_factor: The correction factor :type correction_factor: cutlass.Float32 - :param is_last_tile: Whether the last tile - :type is_last_tile: bool + :param apply_mask: Whether the tile needs K-bound / causal masking (Python bool + for the unmasked/masked bulk loops; runtime cutlass.Boolean for the + split-KV final iter where mask only applies on the global last tile). + :type apply_mask: bool | cutlass.Boolean :param is_local_last_tile: Whether the last tile is local :type is_local_last_tile: cutlass.Boolean @@ -2593,18 +2710,39 @@ def softmax( tTR_rAcc = cute.make_fragment_like(tTR_tS, self.acc_dtype) row_max_new = row_max + # Spec-decoding (MTP) causal mask: each row represents one (q_token, head) + # pair; row r's effective K bound is K - (S_q - 1 - q_tok(r)). + # With fold factor F = self.fold_sq_ratio (fold_sq=True), the M tile is + # laid out as [F sub_q_tok][num_heads heads] and there are S_q/F outer + # chunks indexed by blk_coord[1]: + # q_tok(r) = blk_coord[1] * F + (r_global // num_heads) + # r_global = row_in_cta + cluster_idx * (M_tile / cluster_m) + # When fold_sq=False this reduces to q_tok = blk_coord[1]. For S_q=1 + # this further reduces to k_bound = K (plain K-bound check). + # Masked positions are filled with a large negative sentinel (not -inf) + # to avoid NaN propagation when an entire row becomes masked. + cta_m_rows = self.mma_qk_tiler[0] // self.cluster_shape_mnk[0] arch = BaseDSL._get_dsl().get_arch_enum() if cutlass.const_expr(arch >= Arch.sm_100 and arch <= Arch.sm_100f): cute.copy(tmem_tiled_copy, tTR_tAcc, tTR_rAcc) for i in cutlass.range_constexpr(cute.size(tTR_rAcc)): - if is_last_tile: + if apply_mask: + if cutlass.const_expr(self.fold_sq): + q_tok = ( + common_params.blk_coord[1] * self.fold_sq_ratio + + (tTR_tS[i][0] + common_params.blk_coord[0] * cta_m_rows) + // self.num_heads + ) + else: + q_tok = common_params.blk_coord[1] + k_bound = common_params.K - (self.seq_len_q - 1) + q_tok tTR_rAcc[i] = ( tTR_rAcc[i] if cute.elem_less( tTR_tS[i][1] + self.mma_qk_tiler[1] * k_index, - common_params.K, + k_bound, ) - else -self.acc_dtype.inf + else self.acc_dtype(-1.0e6) ) # reduction for row_max row_max_new = tTR_rAcc.load().reduce(cute.ReductionOp.MAX, row_max_new, 0) @@ -2631,21 +2769,32 @@ def softmax( (tTR_rAcc_red, tTR_rMax), ) tTR_rAcc = cute.make_tensor(tTR_rAcc_red.iterator, tTR_rAcc.layout) - if is_last_tile: + if apply_mask: for i in cutlass.range_constexpr(cute.size(tTR_rAcc)): + if cutlass.const_expr(self.fold_sq): + q_tok = ( + common_params.blk_coord[1] * self.fold_sq_ratio + + (tTR_tS[i][0] + common_params.blk_coord[0] * cta_m_rows) + // self.num_heads + ) + else: + q_tok = common_params.blk_coord[1] + k_bound = common_params.K - (self.seq_len_q - 1) + q_tok tTR_rAcc[i] = ( tTR_rAcc[i] if cute.elem_less( tTR_tS[i][1] + self.mma_qk_tiler[1] * k_index, - common_params.K, + k_bound, ) - else -self.acc_dtype.inf + else self.acc_dtype(-1.0e6) ) - # reduction for row_max + # reduction for row_max after manual masking row_max_new = tTR_rAcc.load().reduce( cute.ReductionOp.MAX, row_max_new, 0 ) else: + # sm_103 pre-computed max via reduction is valid here because + # tTR_rAcc is unmodified (no mask applied to this tile). row_max_new = cute.arch.fmax(row_max_new, tTR_rMax[0]) # if warps in N is 2, reduce row_max across warps (0, 1) and (2, 3) @@ -3381,6 +3530,25 @@ def _compute_grid( return tile_sched_params, grid + @staticmethod + def compute_fold_sq_ratio(num_heads: int, seq_len_q: int, m_tile: int) -> int: + """Derive the seq_len_q-into-heads fold factor F. + + Returns the largest integer F such that: + - F divides seq_len_q evenly + - num_heads * F ≤ m_tile + - 1 ≤ F ≤ seq_len_q + + F=1 means no folding (i.e. ``fold_sq`` should be False at the caller). + """ + if num_heads >= m_tile: + return 1 + max_fold = min(seq_len_q, m_tile // num_heads) + for f in range(max_fold, 0, -1): + if seq_len_q % f == 0: + return f + return 1 + @staticmethod def get_workspace_size( H: int, @@ -3555,9 +3723,12 @@ def can_implement( return False if is_var_split_kv and not is_var_seq: return False - if H > 128: + if mma_qk_tiler_mn[0] < H: return False - if S < 1 or S > 4: + # When H < M tile, fold up to F tokens of S into H (M_eff = H*F ≤ M_tile). + # F is auto-picked as the largest divisor of S with H*F ≤ M_tile. + # F=1 always works, so any (H ≤ M_tile, S ≥ 1) is implementable. + if S < 1: return False if K <= 0: return False @@ -3889,6 +4060,14 @@ def create_workspace( max_active_clusters = hardware_info.get_max_active_clusters( cluster_shape_mnk[0] * cluster_shape_mnk[1] ) + # When num_heads < M tile, fold up to F tokens of seq_len_q into heads, + # capped so num_heads * F ≤ M_tile. F must divide seq_len_q evenly. + # The class derives F internally; we mirror it here for split/workspace. + # H*F may be < M tile; TMA zero-fills OOB rows and epilogue guards skip padded output. + fold_sq_ratio = BlackwellMultiHeadLatentAttentionForwardFP16.compute_fold_sq_ratio( + num_heads, seq_len_q, mma_qk_tiler_mn[0] + ) + fold_sq = fold_sq_ratio > 1 split_kv, block_split_kvs_ref, block_split_kvs, block_split_kvs_torch = ( create_block_split_kvs( batch_size, @@ -3953,8 +4132,11 @@ def create_workspace( is_lse=True, seq_len_q=seq_len_q, ) + # Use effective dimensions for workspace when folding S_q into heads + num_heads_eff = num_heads * fold_sq_ratio + seq_len_q_eff = seq_len_q // fold_sq_ratio workspace, workspace_torch = create_workspace( - num_heads, seq_len_q, latent_dim, batch_size, split_kv, acc_dtype + num_heads_eff, seq_len_q_eff, latent_dim, batch_size, split_kv, acc_dtype ) mla = BlackwellMultiHeadLatentAttentionForwardFP16( @@ -3969,6 +4151,9 @@ def create_workspace( is_var_seq, is_var_split_kv, enable_pdl, + num_heads=num_heads, + seq_len_q=seq_len_q, + fold_sq=fold_sq, ) # Get current CUDA stream from PyTorch @@ -4036,16 +4221,35 @@ def torch_reference_mla( v_ref[b, :, cache_seqs_ref[b] :, :] = 0 import torch.nn.functional as F + # Always-on spec-decoding (MTP) causal mask: for Q token qi ∈ [0, S_q) + # and batch b, valid KV positions are [0, cache_seqs_ref[b] - S_q + 1 + qi). + # For S_q=1 this reduces to the plain K-bound check. SDPA treats + # q_ref=[B, S_q, H, D_total] as batch=B, group=S_q, query-seq=H, + # dim=D_total, so the mask is indexed by the group (S_q) dim and + # broadcasts over the query-seq (H) dim. + S_q_actual = q_ref.shape[1] + max_K_len = k_ref.shape[2] + attn_mask = torch.zeros(batch_size, S_q_actual, 1, max_K_len, dtype=torch.bool) + for b in range(batch_size): + Kb = int(cache_seqs_ref[b]) + for qi in range(S_q_actual): + upper = max(0, Kb - S_q_actual + 1 + qi) + attn_mask[b, qi, 0, :upper] = True + attn_mask_sdpa = attn_mask.to(q_ref.device) if q_ref.is_cuda else attn_mask + o_ref = F.scaled_dot_product_attention( q_ref, k_ref, v_ref, - attn_mask=None, + attn_mask=attn_mask_sdpa, dropout_p=0.0, scale=softmax_scale, is_causal=False, ) s_ref = torch.einsum("bhld,bhsd->bhls", q_ref, k_ref) + s_ref = s_ref.masked_fill( + ~attn_mask.to(s_ref.device).expand_as(s_ref), float("-inf") + ) s_ref_max, s_ref_max_pos = torch.max(s_ref, dim=-1, keepdim=True) softmax_scale_log2 = LOG2_E * softmax_scale s_ref_sum = torch.sum( @@ -4202,7 +4406,7 @@ def generate_tensors(): seq_len_q=seq_len_q, ) workspace, workspace_torch = create_workspace( - num_heads, seq_len_q, latent_dim, batch_size, _split_kv, acc_dtype + num_heads_eff, seq_len_q_eff, latent_dim, batch_size, _split_kv, acc_dtype ) return testing.JitArguments( q_latent, diff --git a/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp8.py b/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp8.py index a4fcd119e4..b6648ff7a8 100644 --- a/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp8.py +++ b/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp8.py @@ -167,6 +167,9 @@ def __init__( is_var_seq: bool, is_var_split_kv: bool, enable_pdl: bool, + num_heads: int = 128, + seq_len_q: int = 1, + fold_sq: bool = False, ): """Initializes the configuration for a Blackwell Multi-Head Latent Attention (MLA) kernel. @@ -192,6 +195,17 @@ def __init__( :type is_var_split_kv: bool :param enable_pdl: Whether to use PDL :type enable_pdl: bool + :param num_heads: Number of attention heads (pre-fold). Used for the + per-row spec-decoding (MTP) causal mask q_token_index computation. + :type num_heads: int + :param seq_len_q: Query sequence length (pre-fold). Combined with + ``num_heads`` to derive the per-row q_token used by the causal mask. + :type seq_len_q: int + :param fold_sq: Whether to fold tokens of ``seq_len_q`` into the head + dimension so the M tile becomes [F sub_q_tok][num_heads heads]. + Required when ``num_heads < mma_qk_tiler_mn[0]`` and ``seq_len_q > 1`` + so the M tile is fully populated. + :type fold_sq: bool """ self.latent_dim = 512 @@ -207,6 +221,24 @@ def __init__( self.is_var_seq = is_var_seq self.is_var_split_kv = is_var_split_kv self.enable_pdl = enable_pdl + # Original (pre-fold) num_heads and seq_len_q used for per-row + # spec-decoding (MTP) causal q_token_index computation. When fold_sq is + # True the M tile is laid out as [F sub_q_tok][num_heads heads]; the + # full q_tok for row r is blk_coord[1] * F + (r // num_heads). + self.num_heads = num_heads + self.seq_len_q = seq_len_q + # fold_sq (caller-controlled): whether the folding code path is enabled. + # fold_sq_ratio (derived): fold factor F ≥ 1; the largest divisor of + # seq_len_q with num_heads * F ≤ M_tile and F ≤ seq_len_q. When the + # caller passes fold_sq=False, the kernel ignores the ratio. + # When fold_sq=True but the derived ratio is 1, the folding branch + # is taken with F=1 (a no-op transform). + self.fold_sq = fold_sq + self.fold_sq_ratio = ( + BlackwellMultiHeadLatentAttentionForwardFP8.compute_fold_sq_ratio( + num_heads, seq_len_q, mma_qk_tiler_mn[0] + ) + ) self.cluster_shape_mnk = (2, 1, 1) self.use_2cta_instrs = True # When using 2 CTAs with m=128: warps 0-1 handle accumulation for first half [0, n/2), @@ -418,6 +450,48 @@ def _reinterpret_3d_kv(t): ), ) + # When num_heads < M tile, fold up to F = fold_sq_ratio tokens of + # seq_len_q into the head dimension so M_eff = num_heads * F (≤ M_tile). + # E.g., H=32, S_q=4 → F=4, M_eff=128, S_q_eff=1 + # E.g., H=32, S_q=8 → F=4, M_eff=128, S_q_eff=2 + # This works because MLA shares KV across all heads/queries independently. + # Tensor layout: [H, D, S_q, B] → [H*F, D, S_q/F, B]; relies on + # stride_S == stride_H * H (always true for contiguous [B, S_q, H, D] + # tensors after _reinterpret_4d). + if cutlass.const_expr(self.fold_sq): + F = self.fold_sq_ratio + + def _fold_sq_4d(t): + return cute.make_tensor( + t.iterator, + cute.make_layout( + ( + t.shape[0] * F, + t.shape[1], + t.shape[2] // F, + t.shape[3], + ), + stride=( + t.stride[0], + t.stride[1], + t.stride[2] * F, + t.stride[3], + ), + ), + ) + + q_latent = _fold_sq_4d(q_latent) + q_rope = _fold_sq_4d(q_rope) + o = _fold_sq_4d(o) + # LSE: [H, S_q, B] → [H*F, S_q/F, B] + lse = cute.make_tensor( + lse.iterator, + cute.make_layout( + (lse.shape[0] * F, lse.shape[1] // F, lse.shape[2]), + stride=(lse.stride[0], lse.stride[1] * F, lse.stride[2]), + ), + ) + acc_o, acc_lse = self.initialize_workspace( q_latent.shape[0], q_latent.shape[1], @@ -2321,8 +2395,20 @@ def compute( correction_factor = self.acc_dtype(1) common_params.p_cor_pipeline.producer_acquire(p_cor_producer_state) - # no mask applied - while k_tile_count > 1: + # Number of tiles from the global-K end that may contain causal-masked + # positions. Min k_bound = K - (S_q-1), which can span up to + # ceil((fold_sq_ratio-2)/tile_N)+1 tiles (tile-boundary-crossing case). For + # S_q=1 this reduces to 1 tile — identical to a plain K-bound check. + tile_n = self.mma_qk_tiler[1] + mask_tile_count = (self.fold_sq_ratio - 2 + tile_n - 1) // tile_n + 1 + + # first_mask_tile_idx is the global index of the first tile that may + # need masking. Runtime because it depends on K (per-batch in + # var-seq / split-KV). + first_mask_tile_idx = k_tile_total - mask_tile_count + + # Phase 1: pure unmasked bulk tiles (all columns strictly < min k_bound). + while k_tile_count > 1 and k_index < first_mask_tile_idx: ( mma_s_consumer_state, p_mma_producer_state, @@ -2346,8 +2432,37 @@ def compute( k_index = k_index + 1 k_tile_count = k_tile_count - 1 - # mask applied + # Phase 2: intermediate tiles that overlap the causal/K-bound region + # but are not this work-split's final tile. + while k_tile_count > 1: + ( + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max, + row_sum, + correction_factor, + ) = self.softmax( + common_params, + softmax_params, + k_index, + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max, + row_sum, + correction_factor, + True, + False, + ) + k_index = k_index + 1 + k_tile_count = k_tile_count - 1 + + # Phase 3: this work-split's final tile. if cutlass.const_expr(common_params.mAccO is not None): + # Split-KV: only apply mask when this final tile is globally in + # the mask region (covers both last-split last-tile and straddling + # splits). Runtime comparison. ( mma_s_consumer_state, p_mma_producer_state, @@ -2365,7 +2480,7 @@ def compute( row_max, row_sum, correction_factor, - k_index == k_tile_total - 1, + k_index >= first_mask_tile_idx, True, ) else: @@ -2515,7 +2630,7 @@ def softmax( row_max: cutlass.Float32, row_sum: cutlass.Float32, correction_factor: cutlass.Float32, - is_last_tile: bool, + apply_mask: bool, is_local_last_tile: cutlass.Boolean, ) -> tuple[ pipeline.PipelineState, @@ -2545,8 +2660,10 @@ def softmax( :type row_sum: cutlass.Float32 :param correction_factor: The correction factor :type correction_factor: cutlass.Float32 - :param is_last_tile: Whether the last tile - :type is_last_tile: bool + :param apply_mask: Whether the tile needs K-bound / causal masking (Python bool + for the unmasked/masked bulk loops; runtime cutlass.Boolean for the + split-KV final iter where mask only applies on the global last tile). + :type apply_mask: bool | cutlass.Boolean :param is_local_last_tile: Whether the last tile is local :type is_local_last_tile: cutlass.Boolean @@ -2589,18 +2706,39 @@ def softmax( tTR_rAcc = cute.make_fragment_like(tTR_tS, self.acc_dtype) row_max_new = row_max + # Spec-decoding (MTP) causal mask: each row represents one (q_token, head) + # pair; row r's effective K bound is K - (S_q - 1 - q_tok(r)). + # With fold factor F = self.fold_sq_ratio (fold_sq=True), the M tile is + # laid out as [F sub_q_tok][num_heads heads] and there are S_q/F outer + # chunks indexed by blk_coord[1]: + # q_tok(r) = blk_coord[1] * F + (r_global // num_heads) + # r_global = row_in_cta + cluster_idx * (M_tile / cluster_m) + # When fold_sq=False this reduces to q_tok = blk_coord[1]. For S_q=1 + # this further reduces to k_bound = K (plain K-bound check). + # Masked positions are filled with a large negative sentinel (not -inf) + # to avoid NaN propagation when an entire row becomes masked. + cta_m_rows = self.mma_qk_tiler[0] // self.cluster_shape_mnk[0] arch = BaseDSL._get_dsl().get_arch_enum() if cutlass.const_expr(arch >= Arch.sm_100 and arch <= Arch.sm_100f): cute.copy(tmem_tiled_copy, tTR_tAcc, tTR_rAcc) for i in cutlass.range_constexpr(cute.size(tTR_rAcc)): - if is_last_tile: + if apply_mask: + if cutlass.const_expr(self.fold_sq): + q_tok = ( + common_params.blk_coord[1] * self.fold_sq_ratio + + (tTR_tS[i][0] + common_params.blk_coord[0] * cta_m_rows) + // self.num_heads + ) + else: + q_tok = common_params.blk_coord[1] + k_bound = common_params.K - (self.seq_len_q - 1) + q_tok tTR_rAcc[i] = ( tTR_rAcc[i] if cute.elem_less( tTR_tS[i][1] + self.mma_qk_tiler[1] * k_index, - common_params.K, + k_bound, ) - else -self.acc_dtype.inf + else self.acc_dtype(-1.0e6) ) # reduction for row_max row_max_new = tTR_rAcc.load().reduce(cute.ReductionOp.MAX, row_max_new, 0) @@ -2626,21 +2764,32 @@ def softmax( (tTR_rAcc_red, tTR_rMax), ) tTR_rAcc = cute.make_tensor(tTR_rAcc_red.iterator, tTR_rAcc.layout) - if is_last_tile: + if apply_mask: for i in cutlass.range_constexpr(cute.size(tTR_rAcc)): + if cutlass.const_expr(self.fold_sq): + q_tok = ( + common_params.blk_coord[1] * self.fold_sq_ratio + + (tTR_tS[i][0] + common_params.blk_coord[0] * cta_m_rows) + // self.num_heads + ) + else: + q_tok = common_params.blk_coord[1] + k_bound = common_params.K - (self.seq_len_q - 1) + q_tok tTR_rAcc[i] = ( tTR_rAcc[i] if cute.elem_less( tTR_tS[i][1] + self.mma_qk_tiler[1] * k_index, - common_params.K, + k_bound, ) - else -self.acc_dtype.inf + else self.acc_dtype(-1.0e6) ) - # reduction for row_max + # reduction for row_max after manual masking row_max_new = tTR_rAcc.load().reduce( cute.ReductionOp.MAX, row_max_new, 0 ) else: + # sm_103 pre-computed max via reduction is valid here because + # tTR_rAcc is unmodified (no mask applied to this tile). row_max_new = cute.arch.fmax(row_max_new, tTR_rMax[0]) # if warps in N is 2, reduce row_max across warps (0, 1) and (2, 3) @@ -3351,6 +3500,25 @@ def _compute_grid( return tile_sched_params, grid + @staticmethod + def compute_fold_sq_ratio(num_heads: int, seq_len_q: int, m_tile: int) -> int: + """Derive the seq_len_q-into-heads fold factor F. + + Returns the largest integer F such that: + - F divides seq_len_q evenly + - num_heads * F ≤ m_tile + - 1 ≤ F ≤ seq_len_q + + F=1 means no folding (i.e. ``fold_sq`` should be False at the caller). + """ + if num_heads >= m_tile: + return 1 + max_fold = min(seq_len_q, m_tile // num_heads) + for f in range(max_fold, 0, -1): + if seq_len_q % f == 0: + return f + return 1 + @staticmethod def get_workspace_size( H: int, @@ -3525,9 +3693,12 @@ def can_implement( return False if is_var_split_kv and not is_var_seq: return False - if H > 128: + if mma_qk_tiler_mn[0] < H: return False - if S <= 0 or S > 4: + # When H < M tile, fold up to F tokens of S into H (M_eff = H*F ≤ M_tile). + # F is auto-picked as the largest divisor of S with H*F ≤ M_tile. + # F=1 always works, so any (H ≤ M_tile, S ≥ 1) is implementable. + if S < 1: return False if K <= 0: return False @@ -3860,6 +4031,14 @@ def create_workspace( max_active_clusters = hardware_info.get_max_active_clusters( cluster_shape_mnk[0] * cluster_shape_mnk[1] ) + # When num_heads < M tile, fold up to F tokens of seq_len_q into heads, + # capped so num_heads * F ≤ M_tile. F must divide seq_len_q evenly. + # The class derives F internally; we mirror it here for split/workspace. + # H*F may be < M tile; TMA zero-fills OOB rows and epilogue guards skip padded output. + fold_sq_ratio = BlackwellMultiHeadLatentAttentionForwardFP8.compute_fold_sq_ratio( + num_heads, seq_len_q, mma_qk_tiler_mn[0] + ) + fold_sq = fold_sq_ratio > 1 split_kv, block_split_kvs_ref, block_split_kvs, block_split_kvs_torch = ( create_block_split_kvs( batch_size, @@ -3924,8 +4103,11 @@ def create_workspace( is_lse=True, seq_len_q=seq_len_q, ) + # Use effective dimensions for workspace when folding S_q into heads + num_heads_eff = num_heads * fold_sq_ratio + seq_len_q_eff = seq_len_q // fold_sq_ratio workspace, workspace_torch = create_workspace( - num_heads, seq_len_q, latent_dim, batch_size, split_kv, acc_dtype + num_heads_eff, seq_len_q_eff, latent_dim, batch_size, split_kv, acc_dtype ) mla = BlackwellMultiHeadLatentAttentionForwardFP8( @@ -3940,6 +4122,9 @@ def create_workspace( is_var_seq, is_var_split_kv, enable_pdl, + num_heads=num_heads, + seq_len_q=seq_len_q, + fold_sq=fold_sq, ) # Get current CUDA stream from PyTorch @@ -4007,16 +4192,35 @@ def torch_reference_mla( v_ref[b, :, cache_seqs_ref[b] :, :] = 0 import torch.nn.functional as F + # Always-on spec-decoding (MTP) causal mask: for Q token qi ∈ [0, S_q) + # and batch b, valid KV positions are [0, cache_seqs_ref[b] - S_q + 1 + qi). + # For S_q=1 this reduces to the plain K-bound check. SDPA treats + # q_ref=[B, S_q, H, D_total] as batch=B, group=S_q, query-seq=H, + # dim=D_total, so the mask is indexed by the group (S_q) dim and + # broadcasts over the query-seq (H) dim. + S_q_actual = q_ref.shape[1] + max_K_len = k_ref.shape[2] + attn_mask = torch.zeros(batch_size, S_q_actual, 1, max_K_len, dtype=torch.bool) + for b in range(batch_size): + Kb = int(cache_seqs_ref[b]) + for qi in range(S_q_actual): + upper = max(0, Kb - S_q_actual + 1 + qi) + attn_mask[b, qi, 0, :upper] = True + attn_mask_sdpa = attn_mask.to(q_ref.device) if q_ref.is_cuda else attn_mask + o_ref = F.scaled_dot_product_attention( q_ref, k_ref, v_ref, - attn_mask=None, + attn_mask=attn_mask_sdpa, dropout_p=0.0, scale=softmax_scale, is_causal=False, ) s_ref = torch.einsum("bhld,bhsd->bhls", q_ref, k_ref) + s_ref = s_ref.masked_fill( + ~attn_mask.to(s_ref.device).expand_as(s_ref), float("-inf") + ) s_ref_max, s_ref_max_pos = torch.max(s_ref, dim=-1, keepdim=True) softmax_scale_log2 = LOG2_E * softmax_scale s_ref_sum = torch.sum( @@ -4173,7 +4377,7 @@ def generate_tensors(): seq_len_q=seq_len_q, ) workspace, workspace_torch = create_workspace( - num_heads, seq_len_q, latent_dim, batch_size, _split_kv, acc_dtype + num_heads_eff, seq_len_q_eff, latent_dim, batch_size, _split_kv, acc_dtype ) return testing.JitArguments( q_latent, diff --git a/tests/attention/test_cute_dsl_mla_decode.py b/tests/attention/test_cute_dsl_mla_decode.py index 4938f922ad..6e60c52911 100644 --- a/tests/attention/test_cute_dsl_mla_decode.py +++ b/tests/attention/test_cute_dsl_mla_decode.py @@ -50,9 +50,17 @@ def torch_reference_mla( softmax_scale, output_scale, page_size, + apply_mtp_mask=False, ): """PyTorch reference implementation for MLA decode. + When ``apply_mtp_mask`` is True, applies the spec-decoding (MTP) causal + mask the monolithic kernel uses: for q_token qi ∈ [0, q_len), valid KV + positions are [0, seq_len - q_len + 1 + qi). For q_len=1 this reduces + to the plain K-bound check (no-op). The modular implementation does + not apply this mask, so callers exercising the modular path should + leave ``apply_mtp_mask=False``. + Args: q_nope: [B, q_len, H, latent_dim] q_rope: [B, q_len, H, rope_dim] @@ -63,6 +71,7 @@ def torch_reference_mla( softmax_scale: float output_scale: float page_size: int + apply_mtp_mask: bool — whether to apply the MTP causal mask. """ B, q_len, H, latent_dim = q_nope.shape @@ -94,6 +103,14 @@ def torch_reference_mla( attn_rope = torch.einsum("qhd,kd->qhk", q_rope_b.float(), k_rope.float()) attn = (attn_latent + attn_rope) * softmax_scale + # Spec-decoding (MTP) causal mask: row qi's k_bound is seq_len-(q_len-1)+qi. + if apply_mtp_mask and q_len > 1: + mask = torch.zeros(q_len, seq_len, dtype=torch.bool, device=attn.device) + for qi in range(q_len): + upper = max(0, seq_len - q_len + 1 + qi) + mask[qi, :upper] = True + attn = attn.masked_fill(~mask.unsqueeze(1), float("-inf")) + # Softmax attn = F.softmax(attn, dim=-1) @@ -182,6 +199,7 @@ def test_cute_dsl_mla_decode_fp16( q_nope = query[..., :latent_dim] q_rope = query[..., latent_dim:] + # Monolithic applies the MTP causal mask for q_len > 1; modular does not. ref_out = torch_reference_mla( q_nope, q_rope, @@ -192,6 +210,7 @@ def test_cute_dsl_mla_decode_fp16( softmax_scale, output_scale, page_size, + apply_mtp_mask=(cute_dsl_impl == "monolithic"), ) ref_out_cast = ref_out.to(dtype) @@ -200,6 +219,162 @@ def test_cute_dsl_mla_decode_fp16( torch.testing.assert_close(out, ref_out_cast, atol=1e-2, rtol=1e-2) +# Exercises the spec-decoding (MTP) causal mask + fold_sq path: num_heads < 128 +# forces the kernel to pack F = compute_fold_sq_ratio(H, q_len, 128) tokens of +# q_len into the head dim so the 128-wide MMA-M tile is fully populated. +# (H=128, q_len=any) → F=1 (no fold), (H=64, q_len=2) → F=2, (H=64, q_len=4) → F=2, +# (H=32, q_len=4) → F=4, (H=32, q_len=2) → F=2. All paths share the same +# kernel; the MTP causal mask is applied uniformly for q_len > 1. +# Monolithic-only: the modular path doesn't implement fold_sq or the MTP mask. +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("seq_len_k", [128, 1024]) +@pytest.mark.parametrize("num_heads", [16, 32, 64]) +@pytest.mark.parametrize("q_len", [2, 4]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float8_e4m3fn]) +def test_cute_dsl_mla_decode_fold_sq( + batch_size, seq_len_k, num_heads, q_len, dtype, cute_dsl_impl +): + """Verify the MTP causal mask + fold_sq packing for H ≤ 128 and q_len > 1.""" + if cute_dsl_impl != "monolithic": + pytest.skip("fold_sq / MTP causal mask are monolithic-only features") + skip_if_unsupported() + + from flashinfer.cute_dsl.attention import cute_dsl_mla_decode + + torch.manual_seed(42) + device = torch.device("cuda") + + page_size = 64 + latent_dim = 512 + rope_dim = 64 + softmax_scale = 1.0 / (latent_dim**0.5) + output_scale = 1.0 + D_qk = latent_dim + rope_dim + + # torch.randn doesn't support fp8; for FP8 inputs create as fp16 then convert. + is_fp8 = dtype == torch.float8_e4m3fn + if is_fp8: + query = ( + torch.randn( + batch_size, q_len, num_heads, D_qk, dtype=torch.float16, device=device + ) + * 0.1 + ).to(torch.float8_e4m3fn) + else: + query = torch.randn( + batch_size, q_len, num_heads, D_qk, dtype=dtype, device=device + ) + + num_pages_per_batch = (seq_len_k + page_size - 1) // page_size + total_pages = num_pages_per_batch * batch_size + 10 + if is_fp8: + kv_cache = ( + torch.randn( + total_pages, page_size, D_qk, dtype=torch.float16, device=device + ) + * 0.1 + ).to(torch.float8_e4m3fn) + else: + kv_cache = torch.randn(total_pages, page_size, D_qk, dtype=dtype, device=device) + + block_tables = torch.zeros( + batch_size, num_pages_per_batch, dtype=torch.int32, device=device + ) + for b in range(batch_size): + for p in range(num_pages_per_batch): + block_tables[b, p] = b * num_pages_per_batch + p + + seq_lens = torch.full((batch_size,), seq_len_k, dtype=torch.int32, device=device) + + workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device=device) + + out = cute_dsl_mla_decode( + query=query, + kv_cache=kv_cache, + workspace_buffer=workspace_buffer, + kv_lora_rank=latent_dim, + qk_rope_head_dim=rope_dim, + block_tables=block_tables, + seq_lens=seq_lens, + max_seq_len=seq_len_k, + softmax_scale=softmax_scale, + output_scale=output_scale, + is_var_seq=False, + cute_dsl_impl=cute_dsl_impl, + ) + + # FP8 input → BF16 output (default), so do the reference in FP32 with wider tolerance. + if is_fp8: + kv_flat = kv_cache.reshape(-1, D_qk).to(torch.float32) + q_nope = query[..., :latent_dim].to(torch.float32) + q_rope = query[..., latent_dim:].to(torch.float32) + else: + kv_flat = kv_cache.reshape(-1, D_qk) + q_nope = query[..., :latent_dim] + q_rope = query[..., latent_dim:] + c_latent_ref = kv_flat[:, :latent_dim] + c_rope_ref = kv_flat[:, latent_dim:] + + # Monolithic-only test — kernel always applies the MTP causal mask here. + ref_out = torch_reference_mla( + q_nope, + q_rope, + c_latent_ref, + c_rope_ref, + block_tables, + seq_lens, + softmax_scale, + output_scale, + page_size, + apply_mtp_mask=True, + ) + + if is_fp8: + # FP8 has limited precision; compare in FP32 with wider tolerance. + torch.testing.assert_close( + out.to(torch.float32), ref_out.to(torch.float32), atol=0.1, rtol=0.1 + ) + else: + ref_out_cast = ref_out.to(dtype) + torch.testing.assert_close(out, ref_out_cast, atol=1e-2, rtol=1e-2) + + +def test_compute_fold_sq_ratio(): + """Unit test the static helper used by both run() and the wrapper.""" + if not is_cute_dsl_available(): + pytest.skip("CuTe DSL not available") + from flashinfer.cute_dsl.attention.monolithic.mla_decode_fp16 import ( + BlackwellMultiHeadLatentAttentionForwardFP16 as FP16, + ) + from flashinfer.cute_dsl.attention.monolithic.mla_decode_fp8 import ( + BlackwellMultiHeadLatentAttentionForwardFP8 as FP8, + ) + + cases = [ + # (num_heads, seq_len_q, m_tile, expected) + (128, 1, 128, 1), # H == M_tile → no fold + (128, 4, 128, 1), # H == M_tile → no fold + (64, 1, 128, 1), # seq_len_q=1 → F=1 + (64, 2, 128, 2), # exact divisor, H*F=128 ≤ M_tile + (64, 4, 128, 2), # H*F ≤ 128 caps F at 2; 4 % 2 == 0 + (64, 3, 128, 1), # 3's only divisors are 1 and 3; H*3=192 > M_tile → F=1 + (32, 4, 128, 4), # tighter pack: F=4, H*F=128 + (32, 8, 128, 4), # capped by M_tile/H = 4 + (32, 3, 128, 3), # max_fold=min(3, 4)=3; 3 % 3 == 0 → F=3 + (16, 8, 128, 8), # max_fold=min(8, 8)=8; 8 % 8 == 0 → F=8 + (16, 6, 128, 6), # max_fold=min(6, 8)=6; 6 % 6 == 0 → F=6 + ] + for H, S_q, m_tile, expected in cases: + assert FP16.compute_fold_sq_ratio(H, S_q, m_tile) == expected, ( + f"FP16.compute_fold_sq_ratio({H}, {S_q}, {m_tile}) " + f"= {FP16.compute_fold_sq_ratio(H, S_q, m_tile)}, expected {expected}" + ) + assert FP8.compute_fold_sq_ratio(H, S_q, m_tile) == expected, ( + f"FP8.compute_fold_sq_ratio({H}, {S_q}, {m_tile}) " + f"= {FP8.compute_fold_sq_ratio(H, S_q, m_tile)}, expected {expected}" + ) + + @pytest.mark.parametrize("batch_size", [1, 4, 16]) @pytest.mark.parametrize("seq_len_k", [128, 512, 2048]) @pytest.mark.parametrize("page_size", [32, 128]) From f285855fc7782e565b83f06378dacb93378589f7 Mon Sep 17 00:00:00 2001 From: jingzec Date: Wed, 13 May 2026 04:40:12 -0700 Subject: [PATCH 3/9] fix split_kv calculation --- flashinfer/cute_dsl/attention/monolithic/mla_decode.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/flashinfer/cute_dsl/attention/monolithic/mla_decode.py b/flashinfer/cute_dsl/attention/monolithic/mla_decode.py index 67c8a9a831..722a690a06 100644 --- a/flashinfer/cute_dsl/attention/monolithic/mla_decode.py +++ b/flashinfer/cute_dsl/attention/monolithic/mla_decode.py @@ -48,9 +48,6 @@ def _get_split_kv_and_workspace_size( max_active_blocks: int, ) -> Tuple[int, int]: """Cache split_kv and workspace_size since they are deterministic for the same params.""" - split_kv = BlackwellMultiHeadLatentAttentionForwardFP16.get_split_kv_simplified( - B, q_len, max_active_blocks - ) # When folding S_q into heads, the workspace dims are the effective dims # (num_heads * F, q_len // F). get_workspace_size already pads H<128 to # 128, so passing num_heads_eff and seq_len_q_eff yields the right size. @@ -60,6 +57,9 @@ def _get_split_kv_and_workspace_size( ) num_heads_eff = H * fold_sq_ratio seq_len_q_eff = q_len // fold_sq_ratio + split_kv = BlackwellMultiHeadLatentAttentionForwardFP16.get_split_kv_simplified( + B, seq_len_q_eff, max_active_blocks + ) workspace_size = BlackwellMultiHeadLatentAttentionForwardFP16.get_workspace_size( num_heads_eff, seq_len_q_eff, kv_lora_rank, B, split_kv, cutlass.Float32 ) @@ -453,7 +453,6 @@ def cute_dsl_mla_decode( # for fix-length, set is_persistent to True; otherwise, set to False. is_persistent = not is_var_seq - print(f"is_persistent: {is_persistent}") # Validate configuration (cached, negligible overhead after first call) _check_can_implement( From b908aa407fa3d1451b19bd10cd7d36582a6935c2 Mon Sep 17 00:00:00 2001 From: jingzec Date: Sun, 24 May 2026 20:23:01 -0700 Subject: [PATCH 4/9] minor fix --- flashinfer/cute_dsl/attention/monolithic/mla_decode_fp16.py | 4 ++-- flashinfer/cute_dsl/attention/monolithic/mla_helpers.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp16.py b/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp16.py index 5ee65218b5..77ba90699a 100644 --- a/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp16.py +++ b/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp16.py @@ -250,7 +250,7 @@ def __init__( self.warps_in_n = 2 self.num_compute_warps = 4 self.threads_per_warp = 32 - mma_qk_tiler_k = self.rope_dim + mma_qk_tiler_k = self.rope_dim if self.seq_len_q == 1 else self.rope_dim * 2 self.mma_qk_tiler = ( self.mma_qk_tiler_mn[0], self.mma_qk_tiler_mn[1], @@ -321,7 +321,7 @@ def _setup_attributes(self): """ self.load_q_stage = 1 - self.load_kv_stage = 15 + self.load_kv_stage = 15 if self.seq_len_q == 1 else 7 self.mma_s_stage = 2 self.p_mma_stage = 2 self.p_cor_stage = 2 diff --git a/flashinfer/cute_dsl/attention/monolithic/mla_helpers.py b/flashinfer/cute_dsl/attention/monolithic/mla_helpers.py index ac2bee49df..722bcfdf93 100644 --- a/flashinfer/cute_dsl/attention/monolithic/mla_helpers.py +++ b/flashinfer/cute_dsl/attention/monolithic/mla_helpers.py @@ -253,7 +253,7 @@ def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: blk_coord = (cluster_idx, s_idx, b_idx, split_kv_idx) else: - s_idx, b_idx = divmod(self.blk_coord[1], self.params.problem_shape_b_fdd) + b_idx, s_idx = divmod(self.blk_coord[1], self.params.problem_shape_s_fdd) blk_coord = (self.blk_coord[0], s_idx, b_idx, self.blk_coord[2]) return WorkTileInfo(blk_coord, is_valid) From d924e6bd67ff0b3188b125f4d2aab09330ab45d4 Mon Sep 17 00:00:00 2001 From: jingzec Date: Mon, 25 May 2026 02:51:07 -0700 Subject: [PATCH 5/9] improve fp8 perf --- .../attention/monolithic/mla_decode_fp8.py | 881 +++++++++++++----- 1 file changed, 654 insertions(+), 227 deletions(-) diff --git a/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp8.py b/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp8.py index b6648ff7a8..f152ee7b25 100644 --- a/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp8.py +++ b/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp8.py @@ -268,42 +268,87 @@ def __init__( self.iterations_pv_k = self.mma_qk_tiler[1] // self.mma_pv_tiler[2] self.iterations_pv_n = self.latent_dim // self.mma_pv_tiler[1] - # Set specialized warp ids + # Set specialized warp ids. + # Compute (softmax) warp groups: g0 = warps 0-3 (even k-tiles), g1 = warps + # 12-15 (odd k-tiles). Correction warps 4-7. MMA split: W8 issues mma_qk + # only, W11 issues mma_pv only (was empty filler). self.compute_warp_ids = (0, 1, 2, 3) self.correction_warp_ids = (4, 5, 6, 7) - self.mma_warp_id = 8 + self.mma_qk_warp_id = 8 self.load_tma_k_warp_id = 9 self.load_tma_v_warp_id = 10 - self.empty_warp_ids = (11,) + self.mma_pv_warp_id = 11 # was empty_warp_ids[0] + # Second softmax warp group (g1): odd k-tiles. Mirrors compute_warp_ids. + self.second_compute_warp_ids = (12, 13, 14, 15) + self.num_total_compute_warps = self.num_compute_warps + len( + self.second_compute_warp_ids + ) self.threads_per_cta = self.threads_per_warp * len( ( - self.mma_warp_id, + self.mma_qk_warp_id, self.load_tma_k_warp_id, self.load_tma_v_warp_id, *self.compute_warp_ids, + *self.second_compute_warp_ids, *self.correction_warp_ids, - *self.empty_warp_ids, + self.mma_pv_warp_id, ) ) - # register settings - self.softmax_reg_num = 192 - self.correction_reg_num = 256 - self.other_reg_num = 48 + # register settings. Sized to fit 16-warp launch (mma_qk + mma_pv + + # 8 compute + 4 correction + 2 load): + # 8 softmax * 160 + 4 correction * 160 + 4 other * 32 + # = 1280 + 640 + 128 = 2048 regs/thread × 32 = 65536 (fits 64K). + self.softmax_reg_num = 160 + self.correction_reg_num = 160 + self.other_reg_num = 32 # Named barriers + # tmem_ptr_sync_bar: mma_qk(1) + mma_pv(1) + 8 compute + 4 correction = 14 warps × 32 self.tmem_ptr_sync_bar = pipeline.NamedBarrier( barrier_id=1, num_threads=( - self.threads_per_warp - + self.threads_per_warp * self.num_compute_warps * 2 + self.threads_per_warp * 2 # mma_qk + mma_pv + + self.threads_per_warp * self.num_total_compute_warps # 8 compute + + self.threads_per_warp * self.num_compute_warps # 4 correction ), ) - self.softmax_exchange_sync_bar = pipeline.NamedBarrier( - barrier_id=2, num_threads=(self.threads_per_warp * self.num_compute_warps) + # softmax_exchange covers BOTH compute groups (8 warps) for cross-group + # row_max / row_sum merge. + self.softmax_exchange_sync_bar_0 = pipeline.NamedBarrier( + barrier_id=2, + num_threads=(self.threads_per_warp * self.num_compute_warps), ) - self.epilogue_exchange_sync_bar = pipeline.NamedBarrier( - barrier_id=3, num_threads=(self.threads_per_warp * self.num_compute_warps) + self.softmax_exchange_sync_bar_1 = pipeline.NamedBarrier( + barrier_id=3, + num_threads=(self.threads_per_warp * self.num_compute_warps), ) + self.epilogue_exchange_sync_bar = pipeline.NamedBarrier( + barrier_id=4, + num_threads=(self.threads_per_warp * self.num_compute_warps), + ) + # Pingpong order barriers (OrderedSequenceBarrier<1,2> pattern). Each + # group waits on its own bar via arrive_and_wait (= bar.sync); the OTHER + # group signals via split-phase .arrive() (non-blocking). num_threads + # MUST cover both groups (256 = wait-side 128 + signal-side 128) so the + # bar releases only after BOTH have arrived — that's the cross-group + # serialization that gives the TMEM peer-read its happens-before. + # Init-phase trick: g1 pre-arrives bar_0 once at warp setup so g0's + # first arrive_and_wait completes without waiting for g1's loop arrive. + self.softmax_order_bar_0 = pipeline.NamedBarrier( + barrier_id=5, + num_threads=(self.threads_per_warp * self.num_total_compute_warps), + ) + self.softmax_order_bar_1 = pipeline.NamedBarrier( + barrier_id=6, + num_threads=(self.threads_per_warp * self.num_total_compute_warps), + ) + # Init seed for TMEM corr (row_max, row_sum). load_other_group_metadata + # reads peer's prev-tile metadata; the first tile of each group needs a + # valid "no prev yet" seed: row_max = -inf, row_sum = 0 form the online + # softmax identity (fmax(x, -inf) = x; running_sum * 1 + s = s). + self.init_row_max = -float("inf") + # TMEM corr stage: 4 32-bit cols (row_sum, row_max, corr, no_corr). + self.tmem_corr_stage_cols = 4 def _setup_attributes(self): """Set up configurations and parameters for the MLA kernel operation. @@ -792,8 +837,10 @@ class SplitKVKernelSharedStorage: cute.struct.MemRange[self.v_dtype, cute.cosize(vc_smem_layout_staged)], 1024, ] + # 2softmax: doubled so both compute groups can write simultaneously. + # g0 slots [0, 128); g1 slots [128, 256). softmax_smem_exchange: cute.struct.MemRange[ - self.acc_dtype, self.num_compute_warps * self.threads_per_warp + self.acc_dtype, 2 * self.num_compute_warps * self.threads_per_warp ] epilogue_smem_exchange: cute.struct.MemRange[ self.acc_dtype, self.num_compute_warps * self.threads_per_warp @@ -1029,7 +1076,7 @@ def split_kv_kernel( is_leader_cta = mma_tile_coord_v == 0 # Prefetch tma descriptor - if warp_idx == self.mma_warp_id: + if warp_idx == self.mma_qk_warp_id: cpasync.prefetch_descriptor(tma_atom_q_latent) cpasync.prefetch_descriptor(tma_atom_q_rope) cpasync.prefetch_descriptor(tma_atom_c_latent) @@ -1040,11 +1087,16 @@ def split_kv_kernel( smem = utils.SmemAllocator() storage = smem.allocate(SharedStorage) - # Tensor memory dealloc barrier init + # Tensor memory dealloc barrier init. + # TMEM lifetime is owned by mma_pv warp (W11) — the LAST TMEM user. + # W8 (mma_qk) finishes earlier via mma_s pipeline back-pressure, but W11 + # keeps reading P / writing O until its mma_o.producer_tail. Putting + # allocate + free on W11 avoids the race where W8 frees TMEM while + # W11 still has in-flight mma_pv. See OPT#14. tmem = utils.TmemAllocator( storage.tmem_holding_buf, barrier_for_retrieve=self.tmem_ptr_sync_bar, - allocator_warp_id=self.mma_warp_id, + allocator_warp_id=self.mma_pv_warp_id, is_two_cta=self.use_2cta_instrs, two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, ) @@ -1117,9 +1169,9 @@ def split_kv_kernel( sP = storage.smem_p.get_tensor( p_smem_layout_staged.outer, swizzle=p_smem_layout_staged.inner ) - # (compute_threads,) + # (compute_threads,) — doubled for 2softmax (both groups exchange concurrently). softmax_smem_exchange = storage.softmax_smem_exchange.get_tensor( - cute.make_layout(self.num_compute_warps * self.threads_per_warp) + cute.make_layout(2 * self.num_compute_warps * self.threads_per_warp) ) epilogue_smem_exchange = storage.epilogue_smem_exchange.get_tensor( cute.make_layout(self.num_compute_warps * self.threads_per_warp) @@ -1136,8 +1188,7 @@ def split_kv_kernel( # /////////////////////////////////////////////////////////////////////////////// # Load warps, including page table and data tensors # /////////////////////////////////////////////////////////////////////////////// - if warp_idx >= self.empty_warp_ids[0] and warp_idx <= self.empty_warp_ids[-1]: - _setmaxregister_decrease(self.other_reg_num) + # Note: warp 11 (formerly empty filler) is now mma_pv_warp_id — handled below. if warp_idx == self.load_tma_k_warp_id: _setmaxregister_decrease(self.other_reg_num) @@ -1243,12 +1294,12 @@ def split_kv_kernel( load_v_pipeline.producer_tail(load_v_producer_state) # /////////////////////////////////////////////////////////////////////////////// - # MMA warp + # MMA-QK warp (W8): issues all mma_qk, produces S via mma_s pipeline. + # Does NOT allocate or free TMEM (W11 owns TMEM lifetime — OPT#14). # /////////////////////////////////////////////////////////////////////////////// - if warp_idx == self.mma_warp_id: + if warp_idx == self.mma_qk_warp_id: _setmaxregister_decrease(self.other_reg_num) - # Alloc tensor memory buffer - tmem.allocate(_get_max_tmem_alloc_cols("sm_100")) + # TMEM allocation done by W11; W8 just waits and retrieves the ptr. tmem.wait_for_alloc() tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) @@ -1258,18 +1309,9 @@ def split_kv_kernel( load_k_consumer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.load_k_stage ) - load_v_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.load_v_stage - ) mma_s_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.mma_s_stage ) - p_mma_consumer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.p_mma_stage - ) - mma_o_producer_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.mma_o_stage - ) tile_sched = create_mla_static_tile_scheduler( tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() ) @@ -1285,7 +1327,6 @@ def split_kv_kernel( local_split_kv=local_split_kv, load_q_pipeline=load_q_pipeline, load_k_pipeline=load_k_pipeline, - load_v_pipeline=load_v_pipeline, tmem_ptr=tmem_ptr, is_leader_cta=is_leader_cta, L=mCL.shape[1], @@ -1297,41 +1338,88 @@ def split_kv_kernel( sKC=sKC, sKC_rope=sKC_rope, ) - mma_pv_params = SimpleNamespace( - p_mma_pipeline=p_mma_pipeline, - mma_o_pipeline=mma_o_pipeline, - sP=sP, - sVC=sVC, - ) ( tiled_mma_qk, - tiled_mma_pv, load_q_consumer_state, load_k_consumer_state, - load_v_consumer_state, mma_s_producer_state, - p_mma_consumer_state, - mma_o_producer_state, ) = self.mma( mma_common_params, mma_qk_params, - mma_pv_params, k_tile_count, tiled_mma_qk, - tiled_mma_pv, load_q_consumer_state, load_k_consumer_state, - load_v_consumer_state, mma_s_producer_state, + ) + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + mma_s_pipeline.producer_tail(mma_s_producer_state) + # TMEM relinquish/free done by W11 (mma_pv warp, allocator). + + # /////////////////////////////////////////////////////////////////////////////// + # MMA-PV warp (W11): owns TMEM lifetime. Issues all mma_pv. + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.mma_pv_warp_id: + _setmaxregister_decrease(self.other_reg_num) + # W11 (mma_pv) owns TMEM lifetime: allocate here, free after the loop. + tmem.allocate(_get_max_tmem_alloc_cols("sm_100")) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + + load_v_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.load_v_stage + ) + p_mma_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.p_mma_stage + ) + mma_o_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.mma_o_stage + ) + tile_sched = create_mla_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + blk_coord = work_tile.tile_idx + k_index, k_tile_count, local_split_kv = self.get_k_tile_count( + split_kv, cache_seqs, block_split_kvs, blk_coord + ) + if k_tile_count > 0: + mma_pv_common_params = SimpleNamespace( + blk_coord=blk_coord, + local_split_kv=local_split_kv, + load_v_pipeline=load_v_pipeline, + tmem_ptr=tmem_ptr, + is_leader_cta=is_leader_cta, + L=mCL.shape[1], + ) + mma_pv_only_params = SimpleNamespace( + p_mma_pipeline=p_mma_pipeline, + mma_o_pipeline=mma_o_pipeline, + sP=sP, + sVC=sVC, + ) + ( + tiled_mma_pv, + load_v_consumer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) = self.mma_pv_warp_body( + mma_pv_common_params, + mma_pv_only_params, + k_tile_count, + tiled_mma_pv, + load_v_consumer_state, p_mma_consumer_state, mma_o_producer_state, ) tile_sched.advance_to_next_work() work_tile = tile_sched.get_current_work() - mma_s_pipeline.producer_tail(mma_s_producer_state) mma_o_pipeline.producer_tail(mma_o_producer_state) - + # W11 is the allocator; safe to free now that all mma_pv has retired. tmem.relinquish_alloc_permit() tmem.free(tmem_ptr) if cutlass.const_expr(self.enable_pdl): @@ -1399,11 +1487,90 @@ def split_kv_kernel( mma_s_consumer_state=mma_s_consumer_state, p_mma_producer_state=p_mma_producer_state, p_cor_producer_state=p_cor_producer_state, + is_second_compute_warp=False, + ) + ) + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # /////////////////////////////////////////////////////////////////////////////// + # Compute warp — second group (g1, warps 12-15, odd k-tiles). + # 2softmax Strategy A: alternates k-tiles with g0; cross-group merges + # inside softmax() via softmax_exchange_sync_bar. + # /////////////////////////////////////////////////////////////////////////////// + if ( + warp_idx >= self.second_compute_warp_ids[0] + and warp_idx <= self.second_compute_warp_ids[-1] + ): + _setmaxregister_increase(self.softmax_reg_num) + mma_s_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.mma_s_stage + ) + p_mma_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.p_mma_stage + ) + p_cor_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.p_cor_stage + ) + mma_o_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.mma_o_stage + ) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + + tile_sched = create_mla_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + mma_s_consumer_state.advance() + p_mma_producer_state.advance() + p_cor_producer_state.advance() + # Pingpong init-phase trick: g1 pre-arrives softmax_order_bar_0 once + # so g0's FIRST softmax_order_bar_0.arrive_and_wait() completes + # without waiting for g1's first loop arrive. Replaces a per-iter + # is_first_tile guard. Bar is num_total_compute_warps (256 threads) + # so wait-side 128 + signal-side 128 = 256 → release. + while work_tile.is_valid_tile: + blk_coord = work_tile.tile_idx + k_index, k_tile_count, local_split_kv = self.get_k_tile_count( + split_kv, cache_seqs, block_split_kvs, blk_coord + ) + if k_tile_count > 0: + compute_common_params = SimpleNamespace( + blk_coord=blk_coord, + split_kv=split_kv, + local_split_kv=local_split_kv, + smem_exchange=softmax_smem_exchange, + mAccO=mAccO, + mO=mO, + K=cache_seqs[blk_coord[2]], + L=mCL.shape[1], + tmem_ptr=tmem_ptr, + tidx=tidx, + p_cor_pipeline=p_cor_pipeline, + ) + compute_softmax_params = SimpleNamespace( + tiled_mma_qk=tiled_mma_qk, + sP=sP, + mma_s_pipeline=mma_s_pipeline, + p_mma_pipeline=p_mma_pipeline, + softmax_scale_log2=softmax_scale_log2, + ) + mma_s_consumer_state, p_mma_producer_state, p_cor_producer_state = ( + self.compute( + compute_common_params, + compute_softmax_params, + k_index=k_index, + k_tile_count=k_tile_count, + mma_s_consumer_state=mma_s_consumer_state, + p_mma_producer_state=p_mma_producer_state, + p_cor_producer_state=p_cor_producer_state, + is_second_compute_warp=True, ) ) tile_sched.advance_to_next_work() work_tile = tile_sched.get_current_work() - p_cor_pipeline.producer_tail(p_cor_producer_state) + # NOTE: g1 skips p_cor_pipeline.producer_tail — g0 already does it. # /////////////////////////////////////////////////////////////////////////////// # Correction warp @@ -1465,7 +1632,6 @@ def split_kv_kernel( ) tile_sched.advance_to_next_work() work_tile = tile_sched.get_current_work() - return @cute.kernel @@ -2030,62 +2196,27 @@ def mma( self, common_params: SimpleNamespace, qk_params: SimpleNamespace, - pv_params: SimpleNamespace, k_tile_count: cutlass.Int32, tiled_mma_qk: cute.TiledMma, - tiled_mma_pv: cute.TiledMma, load_q_consumer_state: pipeline.PipelineState, load_k_consumer_state: pipeline.PipelineState, - load_v_consumer_state: pipeline.PipelineState, mma_s_producer_state: pipeline.PipelineState, - p_mma_consumer_state: pipeline.PipelineState, - mma_o_producer_state: pipeline.PipelineState, ) -> tuple[ - cute.TiledMma, cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, - pipeline.PipelineState, - pipeline.PipelineState, ]: - """MMA warp to compute the result of Q*K^T and P*V. Updates the tiled mma and pipeline states. - - :param common_params: The common parameters for mma qk and pv - :type common_params: SimpleNamespace - :param qk_params: The mma qk parameters - :type qk_params: SimpleNamespace - :param pv_params: The mma pv parameters - :type pv_params: SimpleNamespace - :param k_tile_count: The k tile count - :type k_tile_count: cutlass.Int32 - :param tiled_mma_qk: The tiled mma qk - :type tiled_mma_qk: cute.TiledMma - :param tiled_mma_pv: The tiled mma pv - :type tiled_mma_pv: cute.TiledMma - :param load_q_consumer_state: The load q consumer state - :type load_q_consumer_state: pipeline.PipelineState - :param load_k_consumer_state: The load k consumer state - :type load_k_consumer_state: pipeline.PipelineState - :param load_v_consumer_state: The load v consumer state - :type load_v_consumer_state: pipeline.PipelineState - :param mma_s_producer_state: The mma s producer state - :type mma_s_producer_state: pipeline.PipelineState - :param p_mma_consumer_state: The p mma consumer state - :type p_mma_consumer_state: pipeline.PipelineState - :param mma_o_producer_state: The mma o producer state - :type mma_o_producer_state: pipeline.PipelineState + """MMA-QK warp body (W8). Issues K mma_qk calls producing S to TMEM. - :return: The tiled mma qk, the tiled mma pv, the load q consumer state, the load k consumer state, the load v consumer state, the mma s producer state, the p mma consumer state, and the mma o producer state - :rtype: tuple[cute.TiledMma, cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] + Split from the previous fused mma_qk_pv: PV is now handled by + mma_pv_warp_body on W11. See OPT#16 / OPT#14 (TMEM allocator on W11). """ tSrQ = tiled_mma_qk.make_fragment_A(qk_params.sQ) tSrQ_rope = tiled_mma_qk.make_fragment_A(qk_params.sQ_rope) tSrKC = tiled_mma_qk.make_fragment_B(qk_params.sKC) tSrKC_rope = tiled_mma_qk.make_fragment_B(qk_params.sKC_rope) - tOrP = tiled_mma_pv.make_fragment_A(pv_params.sP) - tOrVC = tiled_mma_pv.make_fragment_B(pv_params.sVC) tStS_shape = tiled_mma_qk.partition_shape_C( cute.select(self.mma_qk_tiler, mode=[0, 1]) @@ -2093,55 +2224,20 @@ def mma( tStS_staged_fake = tiled_mma_qk.make_fragment_C( cute.append(tStS_shape, self.mma_s_stage) ) - # use real tmem ptr for tStS tStS_staged = cute.make_tensor(common_params.tmem_ptr, tStS_staged_fake.layout) - tOtO_shape = tiled_mma_pv.partition_shape_C( - cute.select(self.mma_pv_tiler, mode=[0, 1]) - ) - # mma O has 1 stage. - tOtO = tiled_mma_pv.make_fragment_C(tOtO_shape) - tOtO_layout = cute.append( - tOtO.layout, - cute.make_layout( - common_params.L // self.mma_pv_tiler[1], - stride=self.mma_pv_tiler[1] // self.warps_in_n, - ), - ) - tOtO_staged = cute.make_tensor( - tStS_staged.iterator + self.tmem_o_offset, tOtO_layout - ) - # set more parameters qk_params.tSrQ = tSrQ qk_params.tSrQ_rope = tSrQ_rope qk_params.tSrKC = tSrKC qk_params.tSrKC_rope = tSrKC_rope qk_params.tStS_staged = tStS_staged - pv_params.tOrP = tOrP - pv_params.tOrVC = tOrVC - pv_params.tOtO_staged = tOtO_staged - # mma O accumulates on K, so the accumlate flag is set to False once before all K blocks. - tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, False) load_q_pipeline = common_params.load_q_pipeline if common_params.is_leader_cta: load_q_release_state = load_q_consumer_state.clone() - ( - tiled_mma_qk, - load_q_consumer_state, - load_k_consumer_state, - mma_s_producer_state, - ) = self.mma_qk( - common_params, - qk_params, - tiled_mma_qk, - load_q_consumer_state, - load_k_consumer_state, - mma_s_producer_state, - wait_q=True, - ) - k_tile_count -= 1 + load_q_pipeline.consumer_wait(load_q_consumer_state) + load_q_consumer_state.advance() while k_tile_count > 0: ( tiled_mma_qk, @@ -2157,6 +2253,68 @@ def mma( mma_s_producer_state, wait_q=False, ) + k_tile_count -= 1 + # release q consumer states + load_q_pipeline.consumer_release(load_q_release_state) + load_q_release_state.advance() + # NOTE: mma_pv (mainloop + epilog) is now handled by warp 11 (mma_pv_warp_body). + + return ( + tiled_mma_qk, + load_q_consumer_state, + load_k_consumer_state, + mma_s_producer_state, + ) + + @cute.jit + def mma_pv_warp_body( + self, + common_params: SimpleNamespace, + pv_params: SimpleNamespace, + k_tile_count: cutlass.Int32, + tiled_mma_pv: cute.TiledMma, + load_v_consumer_state: pipeline.PipelineState, + p_mma_consumer_state: pipeline.PipelineState, + mma_o_producer_state: pipeline.PipelineState, + ) -> tuple[ + cute.TiledMma, + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + ]: + """PV-only MMA warp body (W11). Runs K mma_pv calls — one per k-tile. + + W11 owns TMEM lifetime; allocate/free happen in the caller __call__ + branch around this loop. ACCUMULATE=False is set once before the loop; + each mma_pv accumulates into the same TMEM O region across k-tiles. + """ + tOrP = tiled_mma_pv.make_fragment_A(pv_params.sP) + tOrVC = tiled_mma_pv.make_fragment_B(pv_params.sVC) + + tOtO_shape = tiled_mma_pv.partition_shape_C( + cute.select(self.mma_pv_tiler, mode=[0, 1]) + ) + tOtO = tiled_mma_pv.make_fragment_C(tOtO_shape) + tOtO_layout = cute.append( + tOtO.layout, + cute.make_layout( + common_params.L // self.mma_pv_tiler[1], + stride=self.mma_pv_tiler[1] // self.warps_in_n, + ), + ) + tOtO_staged = cute.make_tensor( + common_params.tmem_ptr + self.tmem_o_offset, tOtO_layout + ) + + pv_params.tOrP = tOrP + pv_params.tOrVC = tOrVC + pv_params.tOtO_staged = tOtO_staged + + # mma O accumulates across k-tiles, set ACCUMULATE=False once before loop. + tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, False) + + if common_params.is_leader_cta: + while k_tile_count > 0: ( tiled_mma_pv, load_v_consumer_state, @@ -2171,30 +2329,10 @@ def mma( mma_o_producer_state, ) k_tile_count -= 1 - # release q consumer states - load_q_pipeline.consumer_release(load_q_release_state) - load_q_release_state.advance() - ( - tiled_mma_pv, - load_v_consumer_state, - p_mma_consumer_state, - mma_o_producer_state, - ) = self.mma_pv( - common_params, - pv_params, - tiled_mma_pv, - load_v_consumer_state, - p_mma_consumer_state, - mma_o_producer_state, - ) return ( # type: ignore[return-value] - tiled_mma_qk, tiled_mma_pv, - load_q_consumer_state, - load_k_consumer_state, load_v_consumer_state, - mma_s_producer_state, p_mma_consumer_state, mma_o_producer_state, ) @@ -2356,6 +2494,21 @@ def mma_pv( mma_o_producer_state, ) + @cute.jit + def softmax_advance_to_next_group( + self, + common_params: SimpleNamespace, + p_mma_producer_state: pipeline.PipelineState, + mma_s_consumer_state: pipeline.PipelineState, + p_cor_producer_state: pipeline.PipelineState, + ) -> tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState]: + """Advance the P MMA producer state, the MMA s consumer state, and the P correction producer state to the next group.""" + p_mma_producer_state.advance() + mma_s_consumer_state.advance() + p_cor_producer_state.advance() + common_params.p_cor_pipeline.producer_acquire(p_cor_producer_state) + return p_mma_producer_state, mma_s_consumer_state, p_cor_producer_state + @cute.jit def compute( self, @@ -2366,6 +2519,7 @@ def compute( mma_s_consumer_state: pipeline.PipelineState, p_mma_producer_state: pipeline.PipelineState, p_cor_producer_state: pipeline.PipelineState, + is_second_compute_warp: bool, ) -> tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState]: """Compute warp to compute the result of softmax, rescale, and epilogue. Updates the related pipeline states. @@ -2383,6 +2537,8 @@ def compute( :type p_mma_producer_state: pipeline.PipelineState :param p_cor_producer_state: The P correction producer state :type p_cor_producer_state: pipeline.PipelineState + :param is_second_compute_warp: True for g1 (odd k-tiles), False for g0 + :type is_second_compute_warp: bool :return: The MMA s consumer state, the P MMA producer state, and the P correction producer state :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] @@ -2390,11 +2546,32 @@ def compute( k_tile_total = cute.ceil_div(common_params.K, self.mma_qk_tiler[1]) - row_max = -self.acc_dtype.inf + # 2softmax: row_max initialised from the init seed (-inf) — same value + # that init_p_cor_metadata writes to peer-readable TMEM below, so the + # first tile's load_other_group_metadata + fmax(row_max_new, peer) gives + # the correct identity for online softmax. + row_max = self.acc_dtype(self.init_row_max) row_sum = self.acc_dtype(0) correction_factor = self.acc_dtype(1) + odd_k_tile = k_tile_count % 2 == 1 + # 2softmax: g0 takes even k-tiles, g1 takes odd k-tiles. g1 advances all + # pipeline states once at entry to start on stage 1 (one stage per group). + if cutlass.const_expr(is_second_compute_warp): + k_index = k_index + 1 + k_tile_count = k_tile_count // 2 + else: + k_tile_count = (k_tile_count + 1) // 2 # g0 takes the extra tile if odd common_params.p_cor_pipeline.producer_acquire(p_cor_producer_state) + # Seed this group's home p_cor TMEM stage so the FIRST tile's peer-read + # in load_other_group_metadata returns (row_max=-inf, row_sum=0) instead + # of uninitialised TMEM. Must run after producer_acquire so this group + # owns its stage write. + if cutlass.const_expr(is_second_compute_warp): + self.init_p_cor_metadata( + common_params, softmax_params, p_cor_producer_state + ) + self.softmax_order_bar_0.arrive() # Number of tiles from the global-K end that may contain causal-masked # positions. Min k_bound = K - (S_q-1), which can span up to # ceil((fold_sq_ratio-2)/tile_N)+1 tiles (tile-boundary-crossing case). For @@ -2408,33 +2585,13 @@ def compute( first_mask_tile_idx = k_tile_total - mask_tile_count # Phase 1: pure unmasked bulk tiles (all columns strictly < min k_bound). - while k_tile_count > 1 and k_index < first_mask_tile_idx: - ( - mma_s_consumer_state, - p_mma_producer_state, - p_cor_producer_state, - row_max, - row_sum, - correction_factor, - ) = self.softmax( - common_params, - softmax_params, - k_index, - mma_s_consumer_state, - p_mma_producer_state, - p_cor_producer_state, - row_max, - row_sum, - correction_factor, - False, - False, + # 2softmax: each group steps by 2 k-tiles; phase boundaries respect this. + while k_tile_count > 0 and k_index < first_mask_tile_idx: + is_local_last_tile = ( + False + if cutlass.const_expr(common_params.mAccO is None) + else k_tile_count == 1 ) - k_index = k_index + 1 - k_tile_count = k_tile_count - 1 - - # Phase 2: intermediate tiles that overlap the causal/K-bound region - # but are not this work-split's final tile. - while k_tile_count > 1: ( mma_s_consumer_state, p_mma_producer_state, @@ -2452,38 +2609,25 @@ def compute( row_max, row_sum, correction_factor, - True, + is_second_compute_warp, False, + is_local_last_tile, ) - k_index = k_index + 1 + k_index = k_index + 2 k_tile_count = k_tile_count - 1 + if k_tile_count > 0: + p_mma_producer_state, mma_s_consumer_state, p_cor_producer_state = ( + self.softmax_advance_to_next_group( + common_params, + p_mma_producer_state, + mma_s_consumer_state, + p_cor_producer_state, + ) + ) - # Phase 3: this work-split's final tile. - if cutlass.const_expr(common_params.mAccO is not None): - # Split-KV: only apply mask when this final tile is globally in - # the mask region (covers both last-split last-tile and straddling - # splits). Runtime comparison. - ( - mma_s_consumer_state, - p_mma_producer_state, - p_cor_producer_state, - row_max, - row_sum, - correction_factor, - ) = self.softmax( - common_params, - softmax_params, - k_index, - mma_s_consumer_state, - p_mma_producer_state, - p_cor_producer_state, - row_max, - row_sum, - correction_factor, - k_index >= first_mask_tile_idx, - True, - ) - else: + # Phase 2: remaining tiles that overlap the causal / K-bound region, + # including this work-split's final tile. + while k_tile_count > 0: ( mma_s_consumer_state, p_mma_producer_state, @@ -2501,10 +2645,42 @@ def compute( row_max, row_sum, correction_factor, + is_second_compute_warp, True, True, ) + k_index = k_index + 2 + k_tile_count = k_tile_count - 1 + if k_tile_count > 0: + p_mma_producer_state, mma_s_consumer_state, p_cor_producer_state = ( + self.softmax_advance_to_next_group( + common_params, + p_mma_producer_state, + mma_s_consumer_state, + p_cor_producer_state, + ) + ) + if odd_k_tile: + if cutlass.const_expr(is_second_compute_warp): + # next first compute warp in this wave + p_mma_producer_state.advance() + mma_s_consumer_state.advance() + p_cor_producer_state.advance() + # first compute warp in next wave + p_mma_producer_state.advance() + mma_s_consumer_state.advance() + p_cor_producer_state.advance() + else: + p_mma_producer_state.advance() + mma_s_consumer_state.advance() + p_cor_producer_state.advance() + if cutlass.const_expr(is_second_compute_warp): + if odd_k_tile: + self.softmax_order_bar_1.arrive_and_wait() + else: + if not odd_k_tile: + self.softmax_order_bar_0.arrive_and_wait() return mma_s_consumer_state, p_mma_producer_state, p_cor_producer_state @cute.jit @@ -2618,6 +2794,174 @@ def exchange_p_cor_metadata( p_cor_producer_state.advance() return p_cor_producer_state, row_max_new + @cute.jit + def init_p_cor_metadata( + self, + common_params: SimpleNamespace, + softmax_params: SimpleNamespace, + p_cor_producer_state: pipeline.PipelineState, + ) -> None: + """Seed the TMEM correction-factor stage owned by this softmax warp. + + Slot order matches exchange_p_cor_metadata (4 × 32-bit per stage): + rCor[0] = row_sum → 0.0 + rCor[1] = row_max → -inf (the load-bearing seed) + rCor[2] = correction_factor → 1.0 + rCor[3] = no_correction (Int32 via recast) → 1 + + Called by each softmax warpgroup once after producer_acquire on its + starting stage (g0 → stage 0, g1 → stage 1), before entering the K-tile + mainloop. Without this seed, the FIRST tile's load_other_group_metadata + would read uninitialised TMEM and corrupt the running (row_max, row_sum). + """ + init_tidx = common_params.tidx % ( + self.num_compute_warps * self.threads_per_warp + ) + init_tStS_shape = softmax_params.tiled_mma_qk.partition_shape_C( + cute.select(self.mma_qk_tiler, mode=[0, 1]) + ) + init_tStS_layout = softmax_params.tiled_mma_qk.make_fragment_C( + cute.append(init_tStS_shape, self.mma_s_stage) + ).layout + init_tStS = cute.make_tensor(common_params.tmem_ptr, init_tStS_layout) + init_tAcc = init_tStS[(None, None), 0, 0, 0] + + init_corr_layout = cute.make_layout( + (init_tAcc.shape[0], 4, self.mma_s_stage), + stride=(init_tAcc.stride[0], 1, self.tmem_corr_stage_cols), + ) + init_tCor = cute.make_tensor( + common_params.tmem_ptr + self.correction_factor_offset, + init_corr_layout, + ) + init_cCor = cute.make_identity_tensor(init_tCor.shape) + init_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(4)), self.acc_dtype + ) + init_tiled_copy = tcgen05.make_tmem_copy(init_store_atom, init_tCor) + init_thr_copy = init_tiled_copy.get_slice(init_tidx) + init_cCor_part = init_thr_copy.partition_S(init_cCor) + init_tCor_part = init_thr_copy.partition_D(init_tCor) + init_rCor = cute.make_fragment_like( + init_cCor_part[None, None, None, 0], self.acc_dtype + ) + init_rCor_int = cute.make_tensor( + cute.recast_ptr(init_rCor.iterator, dtype=cutlass.Int32), + init_rCor.layout, + ) + init_rCor[0] = self.acc_dtype(0.0) + init_rCor[1] = self.acc_dtype(self.init_row_max) + init_rCor[2] = self.acc_dtype(1.0) + init_rCor_int[3] = cutlass.Int32(1) + cute.copy( + init_tiled_copy, + init_rCor, + init_tCor_part[None, None, None, p_cor_producer_state.index], + ) + cute.arch.fence_view_async_tmem_store() + + @cute.jit + def load_other_group_metadata( + self, + common_params: SimpleNamespace, + softmax_params: SimpleNamespace, + p_cor_producer_state: pipeline.PipelineState, + ) -> tuple[cutlass.Float32, cutlass.Float32]: + """Load (row_max, row_sum) from the OTHER softmax warpgroup's home + p_cor TMEM stage. + + Each group's exchange_p_cor_metadata writes to its OWN home stage + (g0 → stage 0, g1 → stage 1, due to the `advance×2` skip in softmax). + The OPPOSITE stage thus holds the peer's most recent metadata (or the + init_p_cor_metadata seed before any cross-group write). + + Safety: NO pipeline acquire/release — this is a pure cross-group + READ. The caller MUST have done a softmax_order_bar.arrive_and_wait() + with the peer right before this load so the peer's exchange_p_cor TMEM + store is happens-before. + + Returns: (row_max, row_sum) from slot 1 and slot 0 respectively. + """ + other_stage = (p_cor_producer_state.index + 1) % self.mma_s_stage + + tidx = common_params.tidx % (self.num_compute_warps * self.threads_per_warp) + tStS_shape = softmax_params.tiled_mma_qk.partition_shape_C( + cute.select(self.mma_qk_tiler, mode=[0, 1]) + ) + tStS_layout = softmax_params.tiled_mma_qk.make_fragment_C( + cute.append(tStS_shape, self.mma_s_stage) + ).layout + tStS = cute.make_tensor(common_params.tmem_ptr, tStS_layout) + tAcc = tStS[(None, None), 0, 0, 0] + + corr_layout = cute.make_layout( + (tAcc.shape[0], 4, self.mma_s_stage), + stride=(tAcc.stride[0], 1, self.tmem_corr_stage_cols), + ) + tCor = cute.make_tensor( + common_params.tmem_ptr + self.correction_factor_offset, corr_layout + ) + cCor = cute.make_identity_tensor(tCor.shape) + load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(4)), self.acc_dtype + ) + load_tiled_copy = tcgen05.make_tmem_copy(load_atom, tCor) + load_thr_copy = load_tiled_copy.get_slice(tidx) + tCor_part = load_thr_copy.partition_S(tCor) + cCor_part = load_thr_copy.partition_D(cCor) + rCor = cute.make_fragment_like(cCor_part[None, None, None, 0], self.acc_dtype) + cute.copy( + load_tiled_copy, + tCor_part[None, None, None, other_stage], + rCor, + ) + return rCor[1], rCor[0] # (row_max, row_sum) + + @cute.jit + def store_p_cor_row_sum( + self, + common_params: SimpleNamespace, + row_sum: cutlass.Float32, + saved_stage_idx: cutlass.Int32, + tAcc: cute.Tensor, + tidx: cutlass.Int32, + ) -> None: + """Late store of row_sum ONLY (slot 0 of corr region) at the previously + committed stage. No commit / advance — caller already committed via + exchange_p_cor_metadata's full store. Used in pingpong's split-write + scheme: early store commits (row_max + corr + no_corr + placeholder + row_sum) so the correction warp can rescale; this late store patches + in the real row_sum once it has been reduced. Correction warp uses + row_sum only at the LAST tile (separate is_local_last_tile path), so + race-on-row_sum is benign. + """ + corr_layout_1 = cute.make_layout( + (tAcc.shape[0], 1, self.mma_s_stage), + stride=(tAcc.stride[0], 1, self.tmem_corr_stage_cols), + ) + tCor = cute.make_tensor( + common_params.tmem_ptr + self.correction_factor_offset, + corr_layout_1, + ) + cCor = cute.make_identity_tensor(tCor.shape) + store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(1)), self.acc_dtype + ) + tiled_copy = tcgen05.make_tmem_copy(store_atom, tCor) + thr_copy = tiled_copy.get_slice(tidx) + cCor_for_copy = thr_copy.partition_S(cCor) + tCor_for_copy = thr_copy.partition_D(tCor) + rCor = cute.make_fragment_like( + cCor_for_copy[None, None, None, 0], self.acc_dtype + ) + rCor[0] = row_sum + cute.copy( + tiled_copy, + rCor, + tCor_for_copy[None, None, None, saved_stage_idx], + ) + cute.arch.fence_view_async_tmem_store() + @cute.jit def softmax( self, @@ -2630,6 +2974,7 @@ def softmax( row_max: cutlass.Float32, row_sum: cutlass.Float32, correction_factor: cutlass.Float32, + is_second_compute_warp: bool, apply_mask: bool, is_local_last_tile: cutlass.Boolean, ) -> tuple[ @@ -2671,7 +3016,12 @@ def softmax( :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, cutlass.Float32, cutlass.Float32, cutlass.Float32] """ - softmax_params.p_mma_pipeline.producer_acquire(p_mma_producer_state) + softmax_exchange_sync_bar = ( + self.softmax_exchange_sync_bar_1 + if is_second_compute_warp + else self.softmax_exchange_sync_bar_0 + ) + softmax_params.mma_s_pipeline.consumer_wait(mma_s_consumer_state) # load S from tmem @@ -2742,7 +3092,11 @@ def softmax( ) # reduction for row_max row_max_new = tTR_rAcc.load().reduce(cute.ReductionOp.MAX, row_max_new, 0) - elif cutlass.const_expr(arch >= Arch.sm_103 and arch <= Arch.sm_103f): + elif cutlass.const_expr( + (arch >= Arch.sm_101 and arch <= Arch.sm_101f) + or (arch >= Arch.sm_103 and arch <= Arch.sm_103f) + or (arch >= Arch.sm_110 and arch <= Arch.sm_110f) + ): tmem_load_red_atom = cute.make_copy_atom( tcgen05.copy.LdRed32x32bOp( tcgen05.copy.Repetition(64), redOp=tcgen05.TmemLoadRedOp.MAX @@ -2788,27 +3142,75 @@ def softmax( cute.ReductionOp.MAX, row_max_new, 0 ) else: - # sm_103 pre-computed max via reduction is valid here because + # sm_101+ pre-computed max via reduction is valid here because # tTR_rAcc is unmodified (no mask applied to this tile). row_max_new = cute.arch.fmax(row_max_new, tTR_rMax[0]) + # fence between tmem load and mma s + cute.arch.fence_view_async_tmem_load() - # if warps in N is 2, reduce row_max across warps (0, 1) and (2, 3) + softmax_params.mma_s_pipeline.consumer_release(mma_s_consumer_state) + + # Intra-group warps_in_n=2 exchange across warps (0,1)↔(2,3) within + # the group. Each group writes into its own half of softmax_smem_exchange: + # g0 → slots [0, 128), g1 → slots [128, 256). The named barrier covers + # both groups (one arrive_and_wait serves both the intra-group reduce + # AND the cross-group Sync #1 below — see the second smem_exchange read + # which gives this thread the PEER GROUP's row_max). + _group_offset = self.num_compute_warps * self.threads_per_warp + if cutlass.const_expr(is_second_compute_warp): + _my_base = _group_offset + _peer_base = 0 + else: + _my_base = 0 + _peer_base = _group_offset if cutlass.const_expr(self.warps_in_n == 2): - common_params.smem_exchange[tidx] = row_max_new - self.softmax_exchange_sync_bar.wait() + common_params.smem_exchange[_my_base + tidx] = row_max_new + softmax_exchange_sync_bar.arrive_and_wait() row_max_new = cute.arch.fmax( row_max_new, common_params.smem_exchange[ - (tidx + 64) % (self.num_compute_warps * self.threads_per_warp) + _my_base + + (tidx + 64) % (self.num_compute_warps * self.threads_per_warp) ], ) - # find correction factor + # === 2softmax cross-group merge via TMEM peer-read (tunePerf pattern) === + # Pingpong A: wait for the OTHER group to release us. The peer's + # exchange_p_cor_metadata TMEM store on its prev tile is happens-before + # the .arrive() it issues at the bottom of its prev iter, so this + # arrive_and_wait gives us the acquire memory ordering needed for the + # load_other_group_metadata read below. g0's first wait is satisfied + # by g1's pre-arrive of bar_0 at warp setup (init-phase trick). + if cutlass.const_expr(is_second_compute_warp): + self.softmax_order_bar_1.arrive_and_wait() + else: + self.softmax_order_bar_0.arrive_and_wait() + # cute.nvgpu.cfence() + + # Serial inheritance: peer's prev-tile metadata IS the GLOBAL running + # state right before THIS tile in serial tile order (pingpong serializes + # tiles 0,1,2,3,... across g0/g1). Override (row_max, row_sum) with + # peer's prev values so the subsequent correction = exp2(prev_global - + # this_global) and the row_sum update gives running_sum_after_this_tile + # = GLOBAL state. + other_row_max, other_row_sum = self.load_other_group_metadata( + common_params, softmax_params, p_cor_producer_state + ) + row_max_new = cute.arch.fmax(row_max_new, other_row_max) + row_max = other_row_max + row_sum = other_row_sum + + # find correction factor (uses inherited row_max from peer = prev global max) correction_factor = cute.math.exp2( (row_max - row_max_new) * softmax_params.softmax_scale_log2, fastmath=True ) - # split kv case - if cutlass.const_expr(not is_local_last_tile): + saved_p_cor_idx = p_cor_producer_state.index + # Early store of (row_max_new, correction_factor, no_correction) — row_sum + # field carries the inherited peer-prev value (the global sum BEFORE this + # tile), which is a placeholder; the real updated row_sum is patched in + # via store_p_cor_row_sum below at saved_p_cor_idx. Safe because the + # correction warp uses row_sum only at the LAST tile path (separate is_local_last_tile branch). + if not is_local_last_tile: p_cor_producer_state, row_max_new = self.exchange_p_cor_metadata( common_params, softmax_params, @@ -2873,6 +3275,8 @@ def softmax( smem_thr_copy = smem_tiled_copy.get_slice(tidx) rP_copy_view = smem_thr_copy.retile(tTR_rS) sP_copy_view = smem_thr_copy.partition_D(sP_mk_view) + + softmax_params.p_mma_pipeline.producer_acquire(p_mma_producer_state) cute.copy(smem_tiled_copy, rP_copy_view, sP_copy_view) # fence between smem store and mma o @@ -2880,7 +3284,13 @@ def softmax( softmax_params.p_mma_pipeline.producer_commit(p_mma_producer_state) p_mma_producer_state.advance() - # row_sum, using `add_packed_f32x2` to reduce the number of instructions + # row_sum update: tunePerf peer-inheritance pattern. + # row_sum entering this block = peer's prev-tile row_sum (the global + # running sum BEFORE this tile, inherited via load_other_group_metadata). + # Apply correction (rescale prev to new row_max base) and add this tile's + # locally reduced exp sum → running_sum_after_this_tile = GLOBAL state. + # No cross-group SMEM exchange needed; the next iter's peer-read picks + # up this group's freshly written row_sum from TMEM corr. row_sum = row_sum * correction_factor row_sum_vec = (0.0, 0.0) for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2): @@ -2889,8 +3299,26 @@ def softmax( ) row_sum = row_sum_vec[0] + row_sum_vec[1] + row_sum + # Late store of row_sum ONLY — patches in the real row_sum at the slot + # committed by the early store above so peer's next-iter + # load_other_group_metadata sees the updated global row_sum. No commit + # / advance (already done by early store). + if not is_local_last_tile: + self.store_p_cor_row_sum( + common_params, + row_sum, + saved_p_cor_idx, + tAcc, + tidx, + ) + + # Pingpong B — signal the OTHER group to start its critical section. + # Split-phase .arrive() contributes this group's threads to the other + # group's bar and falls through immediately. cfence: control-flow fence + # to prevent ptxas from hoisting subsequent work above bar.arrive. + # split kv case - if cutlass.const_expr(is_local_last_tile): + if is_local_last_tile: p_cor_producer_state, row_max_new = self.exchange_p_cor_metadata( common_params, softmax_params, @@ -2902,14 +3330,13 @@ def softmax( tidx, p_cor_producer_state, ) + # cute.nvgpu.cfence() + if cutlass.const_expr(is_second_compute_warp): + self.softmax_order_bar_0.arrive() # g1 → g0 + else: + self.softmax_order_bar_1.arrive() # g0 → g1 + # cute.nvgpu.cfence() - # store correction factor/row_sum/row_max to tmem for correction warp - common_params.p_cor_pipeline.producer_acquire(p_cor_producer_state) - - # fence between tmem load and mma s - cute.arch.fence_view_async_tmem_load() - - softmax_params.mma_s_pipeline.consumer_release(mma_s_consumer_state) mma_s_consumer_state.advance() return ( @@ -3320,7 +3747,7 @@ def make_and_init_load_qkv_pipeline( pipeline.Agent.Thread, len([self.load_tma_k_warp_id]) ) load_qkv_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, len([self.mma_warp_id]) + pipeline.Agent.Thread, len([self.mma_qk_warp_id]) ) return pipeline.PipelineTmaUmma.create( barrier_storage=load_qkv_mbar_ptr, @@ -3347,7 +3774,7 @@ def make_and_init_mma_s_pipeline( """ mma_s_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, len([self.mma_warp_id]) + pipeline.Agent.Thread, len([self.mma_qk_warp_id]) ) consumer_thread_size = ( self.threads_per_warp @@ -3391,7 +3818,7 @@ def make_and_init_p_mma_pipeline( producer_thread_size, ) p_mma_consumer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, len([self.mma_warp_id]) + pipeline.Agent.Thread, len([self.mma_pv_warp_id]) ) return pipeline.PipelineAsyncUmma.create( barrier_storage=p_mma_mbar_ptr, @@ -3446,7 +3873,7 @@ def make_and_init_mma_o_pipeline( """ mma_o_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, len([self.mma_warp_id]) + pipeline.Agent.Thread, len([self.mma_pv_warp_id]) ) consumer_thread_size = ( self.threads_per_warp From 39070d7bb172cd360a9338dc82a47f9b54f08c01 Mon Sep 17 00:00:00 2001 From: jingzec Date: Mon, 25 May 2026 06:44:36 -0700 Subject: [PATCH 6/9] add lse support --- .../attention/monolithic/mla_decode.py | 61 +++++++++++--- .../attention/monolithic/mla_decode_fp16.py | 13 ++- .../attention/monolithic/mla_decode_fp8.py | 13 ++- .../cute_dsl/attention/wrappers/batch_mla.py | 33 +++++++- flashinfer/mla/_core.py | 27 +++--- tests/attention/test_cute_dsl_mla_decode.py | 82 +++++++++++++++++-- 6 files changed, 194 insertions(+), 35 deletions(-) diff --git a/flashinfer/cute_dsl/attention/monolithic/mla_decode.py b/flashinfer/cute_dsl/attention/monolithic/mla_decode.py index 722a690a06..abcc10ddda 100644 --- a/flashinfer/cute_dsl/attention/monolithic/mla_decode.py +++ b/flashinfer/cute_dsl/attention/monolithic/mla_decode.py @@ -21,7 +21,7 @@ """ import functools -from typing import Callable, Optional, Tuple +from typing import Callable, Optional, Tuple, Union import cutlass import cutlass.cute as cute @@ -320,7 +320,9 @@ def cute_dsl_mla_decode( out_dtype: Optional[torch.dtype] = None, is_var_seq: bool = True, enable_pdl: Optional[bool] = None, -) -> torch.Tensor: + lse: Optional[torch.Tensor] = None, + return_lse: bool = False, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """CuTe DSL MLA decode kernel for Blackwell SM100. Parameters @@ -364,11 +366,26 @@ def cute_dsl_mla_decode( enable_pdl : Optional[bool], default=None Whether to enable Programmatic Dependent Launch (PDL). If None, auto-detects based on device capability. + lse : Optional[torch.Tensor] + Pre-allocated Log-Sum-Exp buffer. Accepted shapes (dtype must be + ``torch.float32``): + + * ``[B, q_len, H]`` (native kernel layout, no reshape), or + * ``[B * q_len, H]`` (matches ``trtllm-gen`` shape; the wrapper + reshapes via ``.view`` to the native layout). + + If ``return_lse`` is True and this is None, a buffer of the native + ``[B, q_len, H]`` shape is allocated internally. + return_lse : bool + Whether to return LSE values. When True, the function returns + ``(out, lse)`` (the ``lse`` tensor returned is in whatever shape + the caller supplied; if no ``lse`` was supplied, ``[B, q_len, H]``). Returns ------- - torch.Tensor - Output tensor [B, q_len, H, kv_lora_rank]. + torch.Tensor or Tuple[torch.Tensor, torch.Tensor] + Output tensor [B, q_len, H, kv_lora_rank] when ``return_lse=False``; + otherwise ``(output, lse)``. """ supported_dtypes = {torch.float16, torch.bfloat16, torch.float8_e4m3fn} assert query.dtype in supported_dtypes, ( @@ -442,7 +459,24 @@ def cute_dsl_mla_decode( ) # LSE: contiguous [B, q_len, H]. Kernel reinterprets to [H, q_len, B]. - lse_k = torch.empty((B, q_len, H), dtype=torch.float32, device=query.device) + # If caller supplied an `lse` buffer in either the native 3D shape or the + # 2D trtllm-gen shape [B*q_len, H], reshape to the 3D native layout for + # the kernel call. + if lse is not None: + if lse.dtype != torch.float32: + raise ValueError(f"lse must be torch.float32, got {lse.dtype}") + if lse.shape == (B, q_len, H): + lse_k = lse + elif lse.shape == (B * q_len, H): + # Native kernel layout is 3D; .view is zero-cost when contiguous. + lse_k = lse.view(B, q_len, H) + else: + raise ValueError( + f"lse must have shape (B, q_len, H)=({B}, {q_len}, {H}) " + f"or (B*q_len, H)=({B * q_len}, {H}); got {tuple(lse.shape)}" + ) + else: + lse_k = torch.empty((B, q_len, H), dtype=torch.float32, device=query.device) # cache_seqs: per-batch sequence lengths (skip .to() if already int32) cache_seqs = seq_lens if seq_lens.dtype == torch.int32 else seq_lens.to(torch.int32) @@ -506,9 +540,16 @@ def cute_dsl_mla_decode( Float32(output_scale), ) - # If out was provided, kernel already wrote into it — return directly. - if out is not None: - return out + # Pick the output to return: caller-provided buffer (already written + # in-place) or the freshly allocated o_k. o_k is [B, q_len, H, D], + # matching trtllm-gen output shape. + out_tensor = out if out is not None else o_k + + if return_lse: + # Return the lse tensor in the shape the caller supplied (or 3D when + # we allocated it). When caller passed 2D, lse_k is a .view into + # that same memory, so returning the original `lse` keeps the + # caller's expected shape. + return out_tensor, (lse if lse is not None else lse_k) - # o_k is [B, q_len, H, D] — return as-is to match trtllm-gen output shape. - return o_k + return out_tensor diff --git a/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp16.py b/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp16.py index b871f7aa6b..f16695bda7 100644 --- a/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp16.py +++ b/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp16.py @@ -1452,7 +1452,12 @@ def reduction_kernel( else self.lse_dtype.inf ) if tidx == 0: - mLSE[blk_coord[0], blk_coord[1], blk_coord[2]] = global_lse + # Convert from kernel-internal log2 base to the natural-log + # convention exposed to callers (matches trtllm-gen / flash-attn). + # `1.0 / LOG2_E == ln(2)`. + mLSE[blk_coord[0], blk_coord[1], blk_coord[2]] = global_lse * ( + 1.0 / LOG2_E + ) # store the scale to shared memory for i in cutlass.range_constexpr(lse_per_thread): split_kv_idx = tidx + i * self.threads_per_warp @@ -3257,6 +3262,12 @@ def epilogue( cute.math.log2(row_sum, fastmath=True) + epilogue_params.softmax_scale_log2 * row_max ) + # When writing directly to the user-facing mLSE (single-tile, + # no split-KV merge), convert from log2 base to natural log. + # When writing the per-split intermediate (mAccLSE branch), keep + # log2 base so the merge code above can use exp2 / log2 ops. + if cutlass.const_expr(epilogue_params.mAccLSE is None): + lse = lse * (1.0 / LOG2_E) if cutlass.const_expr(self.warps_in_n == 2): if cute.elem_less(cLSE[tidx][0], common_params.H): gLSE[tidx] = lse diff --git a/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp8.py b/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp8.py index 3bc35451e4..4388d3dbfc 100644 --- a/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp8.py +++ b/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp8.py @@ -1682,7 +1682,12 @@ def reduction_kernel( else self.lse_dtype.inf ) if tidx == 0: - mLSE[blk_coord[0], blk_coord[1], blk_coord[2]] = global_lse + # Convert from kernel-internal log2 base to the natural-log + # convention exposed to callers (matches trtllm-gen / flash-attn). + # `1.0 / LOG2_E == ln(2)`. + mLSE[blk_coord[0], blk_coord[1], blk_coord[2]] = global_lse * ( + 1.0 / LOG2_E + ) # store the scale to shared memory for i in cutlass.range_constexpr(lse_per_thread): split_kv_idx = tidx + i * self.threads_per_warp @@ -3683,6 +3688,12 @@ def epilogue( cute.math.log2(row_sum, fastmath=True) + epilogue_params.softmax_scale_log2 * row_max ) + # When writing directly to the user-facing mLSE (single-tile, + # no split-KV merge), convert from log2 base to natural log. + # When writing the per-split intermediate (mAccLSE branch), keep + # log2 base so the merge code above can use exp2 / log2 ops. + if cutlass.const_expr(epilogue_params.mAccLSE is None): + lse = lse * (1.0 / LOG2_E) if cutlass.const_expr(self.warps_in_n == 2): if cute.elem_less(cLSE[tidx][0], common_params.H): gLSE[tidx] = lse diff --git a/flashinfer/cute_dsl/attention/wrappers/batch_mla.py b/flashinfer/cute_dsl/attention/wrappers/batch_mla.py index 76d5975531..2dec49b283 100644 --- a/flashinfer/cute_dsl/attention/wrappers/batch_mla.py +++ b/flashinfer/cute_dsl/attention/wrappers/batch_mla.py @@ -12,7 +12,7 @@ """ import functools -from typing import Callable, Optional, Tuple +from typing import Callable, Optional, Tuple, Union import cutlass import cutlass.cute as cute @@ -689,7 +689,9 @@ def cute_dsl_mla_decode( is_var_seq: bool = True, enable_pdl: Optional[bool] = None, sinks: Optional[torch.Tensor] = None, -) -> torch.Tensor: + lse: Optional[torch.Tensor] = None, + return_lse: bool = False, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """CuTe DSL MLA decode kernel for Blackwell SM100 (modular variant). Parameters @@ -735,6 +737,16 @@ def cute_dsl_mla_decode( ``AttentionWithSink`` variant). Shape ``(num_qo_heads,)``; will be cast to float32 internally. When ``None`` (default), runs standard softmax attention. + lse : Optional[torch.Tensor] + **Not supported on the modular path yet** — raises + :class:`NotImplementedError` when non-None. Use the monolithic + path (``cute_dsl_impl='monolithic'`` or the default + ``cute_dsl_impl='auto'`` when no modular-only feature is + requested) for LSE output. + return_lse : bool + **Not supported on the modular path yet** — raises + :class:`NotImplementedError` when True. Same workaround as + ``lse=``. Returns ------- @@ -814,7 +826,19 @@ def cute_dsl_mla_decode( (B, q_len, H, kv_lora_rank), dtype=o_dtype, device=query.device ) - # LSE buffer + # LSE: the modular path writes LSE in log2 base directly to its internal + # buffer and does not convert. Exposing that as a user-facing tensor + # would silently disagree with the trtllm-gen + monolithic convention + # (natural log), so explicitly refuse the request here until the + # modular kernel is updated to convert at the final store site. + if return_lse or lse is not None: + raise NotImplementedError( + "cute_dsl_mla_decode modular path does not support return_lse / " + "lse output yet — use cute_dsl_impl='monolithic' (default 'auto' " + "also picks monolithic when no modular-only feature is requested) " + "for LSE support." + ) + # Internal buffer for the kernel call; never returned to the user. lse_k = torch.empty((B, q_len, H), dtype=torch.float32, device=query.device) # cache_seqs: per-batch sequence lengths @@ -914,7 +938,8 @@ def cute_dsl_mla_decode( params_torch, # variant params tensor (None when no variant) ) + # `return_lse=True` is guarded above for the modular path, so we only + # return the output tensor here. if out is not None: return out - return o_k diff --git a/flashinfer/mla/_core.py b/flashinfer/mla/_core.py index 628d67658b..7810f110a0 100644 --- a/flashinfer/mla/_core.py +++ b/flashinfer/mla/_core.py @@ -1127,14 +1127,20 @@ def trtllm_batch_decode_with_kv_cache_mla( False uses TRT-LLM layout with a 3D page table ``[batch_size, 2, max_num_pages_per_seq]``. False is only supported for trtllm-gen backend. lse : Optional[torch.Tensor] = None - Optional pre-allocated buffer for Log-Sum-Exp values. Only supported by - ``trtllm-gen`` backend. Must have shape - ``[batch_size * q_len_per_request, num_qo_heads]`` with dtype - ``torch.float32``. If ``return_lse`` is True and this is None, a buffer - will be allocated. + Optional pre-allocated buffer for Log-Sum-Exp values. Supported by + ``trtllm-gen`` and ``cute-dsl`` backends. Must have dtype + ``torch.float32``. Accepted shapes: + + * ``[batch_size * q_len_per_request, num_qo_heads]`` (trtllm-gen + native; accepted by both backends), or + * ``[batch_size, q_len_per_request, num_qo_heads]`` (cute-dsl native; + also accepted by cute-dsl). + + If ``return_lse`` is True and this is None, a buffer will be + allocated by the backend. return_lse : bool = False - Whether to return LSE values. Only supported by ``trtllm-gen`` backend. - When True, the function returns ``(out, lse)``. + Whether to return LSE values. Supported by ``trtllm-gen`` and + ``cute-dsl`` backends. When True, the function returns ``(out, lse)``. cute_dsl_impl : str = "auto" Which cute-dsl implementation to use. Honored only when ``backend="cute-dsl"``; ignored for other backends. @@ -1376,11 +1382,6 @@ def trtllm_batch_decode_with_kv_cache_mla( "cute-dsl backend (MLA decode kernel) does not support separate KV page indices " "(uses_shared_paged_kv_idx=False)" ) - if return_lse or lse is not None: - raise NotImplementedError( - "cute-dsl backend (MLA decode kernel) does not support return_lse/lse output" - ) - return cute_dsl_mla_decode( query=query, kv_cache=kv_cache, @@ -1397,6 +1398,8 @@ def trtllm_batch_decode_with_kv_cache_mla( enable_pdl=enable_pdl, sinks=cute_dsl_sinks, cute_dsl_impl=cute_dsl_impl, + lse=lse, + return_lse=return_lse, ) else: raise ValueError(f"Backend {backend} not supported") diff --git a/tests/attention/test_cute_dsl_mla_decode.py b/tests/attention/test_cute_dsl_mla_decode.py index dd70943ead..9e5f292004 100644 --- a/tests/attention/test_cute_dsl_mla_decode.py +++ b/tests/attention/test_cute_dsl_mla_decode.py @@ -51,6 +51,7 @@ def torch_reference_mla( output_scale, page_size, apply_mtp_mask=False, + return_lse=False, ): """PyTorch reference implementation for MLA decode. @@ -61,6 +62,10 @@ def torch_reference_mla( not apply this mask, so callers exercising the modular path should leave ``apply_mtp_mask=False``. + When ``return_lse=True``, also returns the Log-Sum-Exp of the + pre-softmax scores: ``LSE = log(sum(exp(QK^T * softmax_scale)))`` + in natural log, matching the cute_dsl kernel's LSE convention. + Args: q_nope: [B, q_len, H, latent_dim] q_rope: [B, q_len, H, rope_dim] @@ -72,10 +77,12 @@ def torch_reference_mla( output_scale: float page_size: int apply_mtp_mask: bool — whether to apply the MTP causal mask. + return_lse: bool — also return LSE [B, q_len, H] (float32). """ B, q_len, H, latent_dim = q_nope.shape outputs = [] + lses = [] for b in range(B): seq_len = cache_seqs[b].item() num_pages_needed = (seq_len + page_size - 1) // page_size @@ -111,6 +118,10 @@ def torch_reference_mla( mask[qi, :upper] = True attn = attn.masked_fill(~mask.unsqueeze(1), float("-inf")) + if return_lse: + # LSE = logsumexp over the KV dimension (natural log). + lses.append(torch.logsumexp(attn, dim=-1)) # [q_len, H] + # Softmax attn = F.softmax(attn, dim=-1) @@ -120,7 +131,10 @@ def torch_reference_mla( out_b = out_b * output_scale outputs.append(out_b) - return torch.stack(outputs, dim=0) # [B, q_len, H, latent_dim] + out_stack = torch.stack(outputs, dim=0) # [B, q_len, H, latent_dim] + if return_lse: + return out_stack, torch.stack(lses, dim=0) # ([B,q_len,H,D], [B,q_len,H]) + return out_stack @pytest.mark.parametrize("batch_size", [1, 4, 32]) @@ -175,8 +189,12 @@ def test_cute_dsl_mla_decode_fp16( # Workspace workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device=device) - # Run kernel - out = cute_dsl_mla_decode( + # Run kernel. Request LSE in the native 3D [B, q_len, H] shape; the + # wrapper also accepts [B*q_len, H] (trtllm-gen shape) which gets + # reshaped internally. + # LSE output is currently monolithic-only; the modular path raises + # NotImplementedError, so only request it on the monolithic path. + result = cute_dsl_mla_decode( query=query, kv_cache=kv_cache, workspace_buffer=workspace_buffer, @@ -190,7 +208,15 @@ def test_cute_dsl_mla_decode_fp16( is_var_seq=False, enable_pdl=enable_pdl, cute_dsl_impl=cute_dsl_impl, + return_lse=(cute_dsl_impl == "monolithic"), ) + if cute_dsl_impl == "monolithic": + out, lse = result + assert lse.dtype == torch.float32 + assert lse.shape == (batch_size, q_len, num_heads) + else: + out = result + lse = None # Reference kv_flat = kv_cache.reshape(-1, latent_dim + rope_dim) @@ -200,7 +226,7 @@ def test_cute_dsl_mla_decode_fp16( q_rope = query[..., latent_dim:] # Monolithic applies the MTP causal mask for q_len > 1; modular does not. - ref_out = torch_reference_mla( + ref = torch_reference_mla( q_nope, q_rope, c_latent_ref, @@ -211,12 +237,21 @@ def test_cute_dsl_mla_decode_fp16( output_scale, page_size, apply_mtp_mask=(cute_dsl_impl == "monolithic"), + return_lse=(cute_dsl_impl == "monolithic"), ) + if cute_dsl_impl == "monolithic": + ref_out, ref_lse = ref + else: + ref_out = ref + ref_lse = None ref_out_cast = ref_out.to(dtype) # Check with tolerance appropriate for FP16/BF16 torch.testing.assert_close(out, ref_out_cast, atol=1e-2, rtol=1e-2) + if cute_dsl_impl == "monolithic": + # LSE is float32 — tighter tolerance. + torch.testing.assert_close(lse, ref_lse, atol=1e-2, rtol=1e-2) # Exercises the spec-decoding (MTP) causal mask + fold_sq path: num_heads < 128 @@ -639,7 +674,16 @@ def test_cute_dsl_mla_decode_fp8( seq_lens = torch.full((batch_size,), seq_len_k, dtype=torch.int32, device=device) workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device=device) - out = cute_dsl_mla_decode( + # Exercise the 2D trtllm-gen-style lse buffer here for coverage when + # available (monolithic only — the modular path raises NotImplementedError + # for LSE output). The wrapper reshapes via .view to the kernel's native + # [B, q_len, H] layout. + lse_buf = ( + torch.empty((batch_size * q_len, num_heads), dtype=torch.float32, device=device) + if cute_dsl_impl == "monolithic" + else None + ) + result = cute_dsl_mla_decode( query=query, kv_cache=kv_cache, workspace_buffer=workspace_buffer, @@ -652,7 +696,19 @@ def test_cute_dsl_mla_decode_fp8( output_scale=output_scale, enable_pdl=enable_pdl, cute_dsl_impl=cute_dsl_impl, - ) + lse=lse_buf, + return_lse=(cute_dsl_impl == "monolithic"), + ) + if cute_dsl_impl == "monolithic": + out, lse = result + # Caller-supplied buffer must be returned (identity), not a copy. + assert lse.data_ptr() == lse_buf.data_ptr() + assert lse.shape == (batch_size * q_len, num_heads) + assert lse.dtype == torch.float32 + assert torch.isfinite(lse).all(), "FP8 cute-dsl MLA LSE produced non-finite" + else: + out = result + lse = None assert out.dtype == torch.bfloat16 assert out.shape == (batch_size, q_len, num_heads, latent_dim) @@ -665,7 +721,7 @@ def test_cute_dsl_mla_decode_fp8( q_nope = query[..., :latent_dim].to(torch.float32) q_rope_tensor = query[..., latent_dim:].to(torch.float32) - ref_out = torch_reference_mla( + ref = torch_reference_mla( q_nope, q_rope_tensor, c_latent_ref, @@ -675,11 +731,23 @@ def test_cute_dsl_mla_decode_fp8( softmax_scale, output_scale, page_size, + return_lse=(cute_dsl_impl == "monolithic"), ) + if cute_dsl_impl == "monolithic": + ref_out, ref_lse = ref + else: + ref_out = ref + ref_lse = None # Compare outputs in FP32; FP8 has limited precision so use wider tolerance torch.testing.assert_close( out.to(torch.float32), ref_out.to(torch.float32), atol=0.1, rtol=0.1 ) + if cute_dsl_impl == "monolithic": + # LSE reshaped back to native shape for comparison. FP8 quantization + # noise propagates into LSE so use the same wide tolerance as `out`. + torch.testing.assert_close( + lse.view(batch_size, q_len, num_heads), ref_lse, atol=0.1, rtol=0.1 + ) # --------------------------------------------------------------------------- From 9f0d7c1914cbd29c0ea24e7a9c161928e3e96ebe Mon Sep 17 00:00:00 2001 From: jingzec Date: Mon, 25 May 2026 09:00:43 -0700 Subject: [PATCH 7/9] minor fix --- flashinfer/cute_dsl/attention/monolithic/mla_decode_fp16.py | 4 ++-- flashinfer/cute_dsl/attention/monolithic/mla_decode_fp8.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp16.py b/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp16.py index f16695bda7..26285fdbc2 100644 --- a/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp16.py +++ b/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp16.py @@ -2370,10 +2370,10 @@ def compute( # Number of tiles from the global-K end that may contain causal-masked # positions. Min k_bound = K - (S_q-1), which can span up to - # ceil((fold_sq_ratio-2)/tile_N)+1 tiles (tile-boundary-crossing case). For + # ceil((seq_len_q-2)/tile_N)+1 tiles (tile-boundary-crossing case). For # S_q=1 this reduces to 1 tile — identical to a plain K-bound check. tile_n = self.mma_qk_tiler[1] - mask_tile_count = (self.fold_sq_ratio - 2 + tile_n - 1) // tile_n + 1 + mask_tile_count = (self.seq_len_q - 2 + tile_n - 1) // tile_n + 1 # first_mask_tile_idx is the global index of the first tile that may # need masking. Runtime because it depends on K (per-batch in diff --git a/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp8.py b/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp8.py index 4388d3dbfc..eebb6ecd71 100644 --- a/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp8.py +++ b/flashinfer/cute_dsl/attention/monolithic/mla_decode_fp8.py @@ -2546,10 +2546,10 @@ def compute( self.softmax_order_bar_0.arrive() # Number of tiles from the global-K end that may contain causal-masked # positions. Min k_bound = K - (S_q-1), which can span up to - # ceil((fold_sq_ratio-2)/tile_N)+1 tiles (tile-boundary-crossing case). For + # ceil((seq_len_q-2)/tile_N)+1 tiles (tile-boundary-crossing case). For # S_q=1 this reduces to 1 tile — identical to a plain K-bound check. tile_n = self.mma_qk_tiler[1] - mask_tile_count = (self.fold_sq_ratio - 2 + tile_n - 1) // tile_n + 1 + mask_tile_count = (self.seq_len_q - 2 + tile_n - 1) // tile_n + 1 # first_mask_tile_idx is the global index of the first tile that may # need masking. Runtime because it depends on K (per-batch in From 131dab485b97cfd3fe80da15e5a67340f032dfd5 Mon Sep 17 00:00:00 2001 From: jingzec Date: Mon, 25 May 2026 18:36:46 -0700 Subject: [PATCH 8/9] fix mypy --- .../cute_dsl/attention/wrappers/batch_mla.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/flashinfer/cute_dsl/attention/wrappers/batch_mla.py b/flashinfer/cute_dsl/attention/wrappers/batch_mla.py index 2dec49b283..49518c9847 100644 --- a/flashinfer/cute_dsl/attention/wrappers/batch_mla.py +++ b/flashinfer/cute_dsl/attention/wrappers/batch_mla.py @@ -688,9 +688,9 @@ def cute_dsl_mla_decode( out_dtype: Optional[torch.dtype] = None, is_var_seq: bool = True, enable_pdl: Optional[bool] = None, - sinks: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, return_lse: bool = False, + sinks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """CuTe DSL MLA decode kernel for Blackwell SM100 (modular variant). @@ -731,12 +731,6 @@ def cute_dsl_mla_decode( enable_pdl : Optional[bool], default=None Whether to enable Programmatic Dependent Launch (PDL). If None, auto-detects based on device capability. - sinks : Optional[torch.Tensor], default=None - Per-head sink values added to the softmax denominator on the first - KV tile (modular-only feature, implemented via the - ``AttentionWithSink`` variant). Shape ``(num_qo_heads,)``; will be - cast to float32 internally. When ``None`` (default), runs standard - softmax attention. lse : Optional[torch.Tensor] **Not supported on the modular path yet** — raises :class:`NotImplementedError` when non-None. Use the monolithic @@ -747,6 +741,14 @@ def cute_dsl_mla_decode( **Not supported on the modular path yet** — raises :class:`NotImplementedError` when True. Same workaround as ``lse=``. + sinks : Optional[torch.Tensor], default=None + Per-head sink values added to the softmax denominator on the first + KV tile (modular-only feature, implemented via the + ``AttentionWithSink`` variant). Shape ``(num_qo_heads,)``; will be + cast to float32 internally. When ``None`` (default), runs standard + softmax attention. Kept as the last parameter so the modular + signature is a strict prefix-extension of the monolithic one (lets + ``mla_dispatch._impl`` assignment type-check across both branches). Returns ------- From 375dc752ff1417d35309f36e9dfdeebaf000e967 Mon Sep 17 00:00:00 2001 From: jingzec Date: Wed, 27 May 2026 23:54:47 -0700 Subject: [PATCH 9/9] fix comment --- flashinfer/mla/_core.py | 33 +++++++++++++++++---- tests/attention/test_cute_dsl_mla_decode.py | 2 +- 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/flashinfer/mla/_core.py b/flashinfer/mla/_core.py index 24b8001d8f..80b4b83197 100644 --- a/flashinfer/mla/_core.py +++ b/flashinfer/mla/_core.py @@ -1825,8 +1825,6 @@ def trtllm_batch_decode_with_kv_cache_mla( sm_count = get_device_sm_count(query.device) block_size = kv_cache.size(-2) - if block_size != 32 and block_size != 64: - raise ValueError(f"Supported block_size are 32 and 64, got {block_size}") if skip_softmax_threshold_scale_factor is not None and sparse_mla_top_k != 0: raise ValueError("skip_softmax is not supported for sparse MLA") @@ -1858,12 +1856,33 @@ def trtllm_batch_decode_with_kv_cache_mla( "out", ) + # Remember the caller-supplied lse so we can return it in its original + # shape: 2D ``(B*q_len, H)`` stays 2D, 3D ``(B, q_len, H)`` stays 3D, and + # an allocated default stays 2D. Internally we normalize to 2D for the + # backend dispatch (matches trtllm-gen's native layout). + user_lse = lse if return_lse: - lse_shape = (query.size(0) * query.size(1), query.size(2)) + flat_lse_shape = (query.size(0) * query.size(1), query.size(2)) + nested_lse_shape = (query.size(0), query.size(1), query.size(2)) if lse is None: - lse = torch.empty(lse_shape, dtype=torch.float32, device=query.device) + lse = torch.empty(flat_lse_shape, dtype=torch.float32, device=query.device) + user_lse = lse + elif tuple(lse.shape) == flat_lse_shape: + check_shape_dtype_device( + lse, flat_lse_shape, torch.float32, query.device, "lse" + ) + elif tuple(lse.shape) == nested_lse_shape: + check_shape_dtype_device( + lse, nested_lse_shape, torch.float32, query.device, "lse" + ) + # Normalize to 2D for the backend; .view shares storage so the + # kernel writes propagate back to user_lse automatically. + lse = lse.view(flat_lse_shape) else: - check_shape_dtype_device(lse, lse_shape, torch.float32, query.device, "lse") + raise ValueError( + f"lse must have shape {flat_lse_shape} or {nested_lse_shape}; " + f"got {tuple(lse.shape)}" + ) page_size = kv_cache.shape[-2] cute_dsl_reason = _cute_dsl_incompatibility_reason( @@ -1966,7 +1985,9 @@ def trtllm_batch_decode_with_kv_cache_mla( ) runner(inputs=inputs, tactic=tactic) if return_lse: - return out, lse + # Return the lse in the same shape the caller supplied (2D or 3D), + # or 2D ``(B*q_len, H)`` when we allocated the default. + return out, user_lse return out diff --git a/tests/attention/test_cute_dsl_mla_decode.py b/tests/attention/test_cute_dsl_mla_decode.py index d0d0a5225a..9e5f292004 100644 --- a/tests/attention/test_cute_dsl_mla_decode.py +++ b/tests/attention/test_cute_dsl_mla_decode.py @@ -494,7 +494,7 @@ def test_cute_dsl_mla_decode_variable_seq_len( @pytest.mark.parametrize("seq_len_k", [128, 512]) @pytest.mark.parametrize("num_heads", [128, 64]) def test_cute_dsl_mla_decode_via_api( - batch_size, seq_len_k, num_heads, cute_dsl_impl, page_size=32, enable_pdl=False + batch_size, seq_len_k, num_heads, cute_dsl_impl, page_size=128, enable_pdl=False ): """Test MLA decode via the trtllm_batch_decode_with_kv_cache_mla API with cute-dsl backend.""" skip_if_unsupported()