From 0129aebbc5916b9e14689264d665eeef58608c66 Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Tue, 10 Mar 2026 00:35:14 -0700 Subject: [PATCH 01/31] feat: Add CuTe DSL MLA decode kernel for Blackwell SM100 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Integrates NVIDIA's CuTe DSL MLA decode kernels (FP16/FP8) for Blackwell SM100 as a new "cute-dsl" backend in trtllm_batch_decode_with_kv_cache_mla(). Key tensor layout insights documented in mla_decode.py: - c_latent/c_rope kernel layout is [page_size, D, total_pages], not [total_tokens, D, 1] — the kernel indexes KV intra-page per physical page - All fake tensor dimensions must be cute.sym_int() (not static Python ints) so cute.assume() receives CuTe Integer types in initialize_workspace() - lse fake tensor needs stride_order=(0,1,2) for stride[0]=1 compile-time constant - Do NOT call .contiguous() after .permute() on q/lse/o tensors — it collapses to row-major, destroying required non-standard strides - Separate sym_kv_batch for KV cache (=1, flat pool) vs query batch (=B) New files: - flashinfer/cute_dsl/mla_helpers.py - flashinfer/cute_dsl/mla_decode_fp16.py - flashinfer/cute_dsl/mla_decode_fp8.py - flashinfer/cute_dsl/mla_decode.py (compilation wrapper + public API) - tests/attention/test_cute_dsl_mla_decode.py (14 tests, all passing) --- flashinfer/cute_dsl/__init__.py | 3 + flashinfer/cute_dsl/mla_decode.py | 390 ++ flashinfer/cute_dsl/mla_decode_fp16.py | 4385 +++++++++++++++++++ flashinfer/cute_dsl/mla_decode_fp8.py | 4356 ++++++++++++++++++ flashinfer/cute_dsl/mla_helpers.py | 304 ++ flashinfer/mla.py | 16 + tests/attention/test_cute_dsl_mla_decode.py | 290 ++ 7 files changed, 9744 insertions(+) create mode 100644 flashinfer/cute_dsl/mla_decode.py create mode 100644 flashinfer/cute_dsl/mla_decode_fp16.py create mode 100644 flashinfer/cute_dsl/mla_decode_fp8.py create mode 100644 flashinfer/cute_dsl/mla_helpers.py create mode 100644 tests/attention/test_cute_dsl_mla_decode.py diff --git a/flashinfer/cute_dsl/__init__.py b/flashinfer/cute_dsl/__init__.py index 940031453d..6b9b31a2af 100644 --- a/flashinfer/cute_dsl/__init__.py +++ b/flashinfer/cute_dsl/__init__.py @@ -53,6 +53,7 @@ add_rmsnorm_fp4quant, AddRMSNormFP4QuantKernel, ) + from .mla_decode import cute_dsl_mla_decode __all__ = [ # Utils (always available) @@ -79,4 +80,6 @@ # Add + RMSNorm + FP4 Quantization "add_rmsnorm_fp4quant", "AddRMSNormFP4QuantKernel", + # MLA Decode + "cute_dsl_mla_decode", ] diff --git a/flashinfer/cute_dsl/mla_decode.py b/flashinfer/cute_dsl/mla_decode.py new file mode 100644 index 0000000000..02834e3d45 --- /dev/null +++ b/flashinfer/cute_dsl/mla_decode.py @@ -0,0 +1,390 @@ +# Copyright (c) 2025 by FlashInfer team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +CuTe DSL MLA Decode Kernel Integration +======================================= + +Wraps NVIDIA's CuTe DSL MLA decode kernels (FP16/FP8) for Blackwell SM100 +and exposes them via a PyTorch API compatible with FlashInfer's MLA backend. +""" + +import functools +from typing import Callable, Optional, Tuple + +import cutlass +import cutlass.cute as cute +import torch +from cutlass import Float32, Int32 + +from .mla_decode_fp16 import BlackwellMultiHeadLatentAttentionForwardFP16 +from .mla_decode_fp8 import BlackwellMultiHeadLatentAttentionForwardFP8 +from .utils import get_num_sm + + +# Default kernel configuration — matches DeepSeek-V2/V3 MLA dimensions +_LATENT_DIM = 512 +_ROPE_DIM = 64 +_MMA_QK_TILER_MN = (128, 128) +_MMA_PV_TILER_MN = (128, 256) +_MAX_ACTIVE_CLUSTERS = 2 +_SKIP_CORRECTION_THRESHOLD = 0.0 + + +@functools.cache +def _get_compiled_mla_kernel( + is_fp8: bool, + page_size: int, + num_heads: int, + seq_len_q: int, + is_persistent: bool, + is_var_seq: bool, + is_var_split_kv: bool, +) -> Tuple[Callable, object]: + """Compile and cache an MLA decode kernel. + + Returns (compiled_kernel_closure, kernel_class_instance). + The kernel_class_instance is needed for get_split_kv() and get_workspace_size(). + """ + KernelClass = ( + BlackwellMultiHeadLatentAttentionForwardFP8 + if is_fp8 + else BlackwellMultiHeadLatentAttentionForwardFP16 + ) + + kernel_obj = KernelClass( + acc_dtype=cutlass.Float32, + lse_dtype=cutlass.Float32, + mma_qk_tiler_mn=_MMA_QK_TILER_MN, + mma_pv_tiler_mn=_MMA_PV_TILER_MN, + max_active_clusters=_MAX_ACTIVE_CLUSTERS, + page_size=page_size, + skip_correction_threshold=_SKIP_CORRECTION_THRESHOLD, + is_persistent=is_persistent, + is_var_seq=is_var_seq, + is_var_split_kv=is_var_split_kv, + ) + + cutlass_dtype = cutlass.Float8E4M3FN if is_fp8 else cutlass.Float16 + + # All dimensions as sym_int — this matches the original kernel's use of + # mark_compact_shape_dynamic, which makes ALL shapes dynamic CuTe Integers. + # Static Python ints would cause cute.assume() to fail with AttributeError + # inside initialize_workspace() since it expects DSL Integer types. + sym_heads = cute.sym_int() + sym_latent = cute.sym_int() + sym_seq_q = cute.sym_int() + sym_rope = cute.sym_int() + sym_batch = cute.sym_int() # query/output batch dimension + sym_kv_batch = cute.sym_int() # KV cache batch dim (flat pool, =1 in paged mode) + sym_seq_kv = cute.sym_int() + sym_page_count = cute.sym_int() + sym_workspace_size = cute.sym_int() + + # q_latent: [num_heads, latent_dim, seq_len_q, batch_size] — stride[1]==1 + q_latent_fake = cute.runtime.make_fake_compact_tensor( + cutlass_dtype, + (sym_heads, sym_latent, sym_seq_q, sym_batch), + stride_order=(3, 0, 2, 1), + assumed_align=128, + ) + # q_rope: [num_heads, rope_dim, seq_len_q, batch_size] — stride[1]==1 + q_rope_fake = cute.runtime.make_fake_compact_tensor( + cutlass_dtype, + (sym_heads, sym_rope, sym_seq_q, sym_batch), + stride_order=(3, 0, 2, 1), + assumed_align=128, + ) + # c_latent: [seq_len_k, latent_dim, kv_batch] — stride[1]==1 + # kv_batch is a separate sym_int from query batch: paged KV cache uses a flat + # pool so kv_batch=1 at runtime, while query batch can be any value. + c_latent_fake = cute.runtime.make_fake_compact_tensor( + cutlass_dtype, + (sym_seq_kv, sym_latent, sym_kv_batch), + stride_order=(2, 0, 1), + assumed_align=128, + ) + # c_rope: [seq_len_k, rope_dim, kv_batch] — stride[1]==1 + c_rope_fake = cute.runtime.make_fake_compact_tensor( + cutlass_dtype, + (sym_seq_kv, sym_rope, sym_kv_batch), + stride_order=(2, 0, 1), + assumed_align=128, + ) + # page_table: [page_count, batch_size] + page_table_fake = cute.runtime.make_fake_compact_tensor( + cutlass.Int32, + (sym_page_count, sym_batch), + stride_order=(1, 0), + assumed_align=128, + ) + # o: [num_heads, latent_dim, seq_len_q, batch_size] — stride[1]==1 + o_fake = cute.runtime.make_fake_compact_tensor( + cutlass_dtype, + (sym_heads, sym_latent, sym_seq_q, sym_batch), + stride_order=(3, 0, 2, 1), + assumed_align=128, + ) + # lse: [num_heads, seq_len_q, batch_size] — stride[0]==1 (num_heads dim is contiguous) + # stride_order[d]=rank: dim0 rank=0 means dim0 is fastest → stride[0]=1 compile-time constant + lse_fake = cute.runtime.make_fake_compact_tensor( + cutlass.Float32, + (sym_heads, sym_seq_q, sym_batch), + stride_order=(0, 1, 2), + assumed_align=128, + ) + # workspace: 1-D + workspace_fake = cute.runtime.make_fake_compact_tensor( + cutlass.Uint8, + (sym_workspace_size,), + assumed_align=128, + ) + # cache_seqs: [batch_size] — int32 + cache_seqs_fake = cute.runtime.make_fake_compact_tensor( + cutlass.Int32, + (sym_batch,), + assumed_align=128, + ) + # block_split_kvs: [batch_size] — int32 + block_split_kvs_fake = cute.runtime.make_fake_compact_tensor( + cutlass.Int32, + (sym_batch,), + assumed_align=128, + ) + + stream_fake = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) + + compiled_kernel = cute.compile( + kernel_obj, + q_latent_fake, + q_rope_fake, + c_latent_fake, + c_rope_fake, + page_table_fake, + o_fake, + lse_fake, + workspace_fake, + Int32(1), # split_kv placeholder + cache_seqs_fake, + block_split_kvs_fake, + Float32(1.0), # softmax_scale placeholder + Float32(1.0), # output_scale placeholder + stream_fake, + options="--enable-tvm-ffi", + ) + + def tensor_api( + q_latent: torch.Tensor, + q_rope: torch.Tensor, + c_latent: torch.Tensor, + c_rope: torch.Tensor, + page_table: torch.Tensor, + o: torch.Tensor, + lse: torch.Tensor, + workspace: torch.Tensor, + split_kv: int, + cache_seqs: torch.Tensor, + block_split_kvs: torch.Tensor, + softmax_scale: float, + output_scale: float, + ) -> None: + nonlocal compiled_kernel + compiled_kernel( + q_latent, + q_rope, + c_latent, + c_rope, + page_table, + o, + lse, + workspace, + Int32(split_kv), + cache_seqs, + block_split_kvs, + Float32(softmax_scale), + Float32(output_scale), + ) + + return tensor_api, kernel_obj + + +def cute_dsl_mla_decode( + query: torch.Tensor, + kv_cache: torch.Tensor, + workspace_buffer: torch.Tensor, + kv_lora_rank: int, + qk_rope_head_dim: int, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + max_seq_len: int, + softmax_scale: float, + output_scale: float = 1.0, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """CuTe DSL MLA decode kernel for Blackwell SM100. + + Parameters + ---------- + query : torch.Tensor + [B, q_len, H, D_qk] where D_qk = kv_lora_rank + qk_rope_head_dim + kv_cache : torch.Tensor + [num_pages, page_size, D_ckv + D_kpe] (3D) or [num_pages, 1, page_size, D_ckv + D_kpe] (4D) + workspace_buffer : torch.Tensor + Pre-allocated workspace buffer. + kv_lora_rank : int + Latent dimension (e.g. 512). + qk_rope_head_dim : int + RoPE dimension (e.g. 64). + block_tables : torch.Tensor + [B, max_pages] — page table indices. + seq_lens : torch.Tensor + [B] — per-request KV sequence lengths. + max_seq_len : int + Maximum sequence length across the batch. + softmax_scale : float + Scale factor for QK^T before softmax. + output_scale : float + Scale factor applied to the output. + out : Optional[torch.Tensor] + Pre-allocated output tensor [B, H, kv_lora_rank]. + + Returns + ------- + torch.Tensor + Output tensor [B, H, kv_lora_rank]. + """ + B, q_len, H, D_qk = query.shape + assert D_qk == kv_lora_rank + qk_rope_head_dim + assert kv_lora_rank == _LATENT_DIM + assert qk_rope_head_dim == _ROPE_DIM + + is_fp8 = query.dtype == torch.float8_e4m3fn + + # Handle 3D vs 4D kv_cache: normalize to 3D [num_pages, page_size, D_total] + if kv_cache.dim() == 4: + # [num_pages, 1, page_size, D_total] -> [num_pages, page_size, D_total] + kv_cache = kv_cache.squeeze(1) + page_size = kv_cache.shape[1] + D_total = kv_cache.shape[2] + assert D_total == kv_lora_rank + qk_rope_head_dim + + # Split query into latent and rope components + q_nope = query[..., :kv_lora_rank] # [B, q_len, H, latent_dim] + q_rope = query[..., kv_lora_rank:] # [B, q_len, H, rope_dim] + + # Reshape to kernel layout: [B, q_len, H, D] -> [H, D, q_len, B] + # Do NOT call .contiguous() — permute gives stride[1]=1 which the kernel requires. + # .contiguous() would rearrange to row-major making stride[3]=1 instead. + q_latent_k = q_nope.permute(2, 3, 1, 0) # [H, latent_dim, q_len, B], stride[1]=1 + q_rope_k = q_rope.permute(2, 3, 1, 0) # [H, rope_dim, q_len, B], stride[1]=1 + + # Total number of physical pages in the KV cache pool + num_pages = kv_cache.shape[0] + + # Reshape KV cache to kernel layout [page_size, D, num_pages]. + # The kernel indexes via page_table: for batch b, page p, offset t: + # c_latent[t, d, page_table[p, b]] = token (page_table[p,b]*page_size + t)'s latent[d] + # kv_cache: [num_pages, page_size, D_total] with strides (page_size*D_total, D_total, 1) + # After permute(1, 2, 0) on latent slice: [page_size, latent_dim, num_pages] + # strides = (D_total, 1, page_size*D_total) → stride[1]=1 ✓ + c_latent_k = kv_cache[:, :, :kv_lora_rank].permute(1, 2, 0) # [page_size, latent_dim, num_pages] + c_rope_k = kv_cache[:, :, kv_lora_rank:].permute(1, 2, 0) # [page_size, rope_dim, num_pages] + + # Page table: [B, max_pages] -> [max_pages, B] + page_table_k = block_tables.t().contiguous().to(torch.int32) + + # Determine split_kv and workspace + is_persistent = True + is_var_seq = True + is_var_split_kv = True + max_active_blocks = get_num_sm(query.device) + + split_kv = BlackwellMultiHeadLatentAttentionForwardFP16.get_split_kv( + B, q_len, max_seq_len, _MMA_QK_TILER_MN, max_active_blocks + ) + + workspace_size = BlackwellMultiHeadLatentAttentionForwardFP16.get_workspace_size( + H, q_len, _LATENT_DIM, B, split_kv, cutlass.Float32 + ) + + # Prepare workspace tensor + if workspace_size > 0: + workspace_bytes = workspace_buffer[: workspace_size].contiguous() + else: + workspace_bytes = workspace_buffer[:1].contiguous() + + # Allocate output: [H, latent_dim, q_len, B] with stride[1]==1 + # torch.empty(B, H, q_len, D) has row-major strides (H*q_len*D, q_len*D, D, 1). + # After permute(1, 3, 2, 0) → shape [H, D, q_len, B] with strides (q_len*D, 1, D, H*q_len*D). + # Do NOT call .contiguous() — that would collapse to row-major making stride[3]=1. + out_dtype = torch.float8_e4m3fn if is_fp8 else torch.float16 + o_k = torch.empty( + (B, H, q_len, _LATENT_DIM), dtype=out_dtype, device=query.device + ).permute(1, 3, 2, 0) # [H, latent_dim, q_len, B], stride[1]=1 + + # LSE: [H, q_len, B] with stride[0]==1 (H dim is contiguous). + # torch.empty(B, q_len, H) has row-major strides (q_len*H, H, 1). + # After permute(2, 1, 0) → shape [H, q_len, B] with strides (1, H, q_len*H). + # Do NOT call .contiguous() — that would make stride[2]=1 instead of stride[0]=1. + lse_k = torch.empty( + (B, q_len, H), dtype=torch.float32, device=query.device + ).permute(2, 1, 0) # [H, q_len, B], stride[0]=1 + + # cache_seqs: per-batch sequence lengths + cache_seqs = seq_lens.to(torch.int32).contiguous() + + # block_split_kvs: per-batch split_kv values + # Compute per-batch split_kv based on actual sequence lengths + block_split_kvs = torch.ones(B, dtype=torch.int32, device=query.device) * split_kv + + # Get compiled kernel + tensor_api, kernel_cls = _get_compiled_mla_kernel( + is_fp8=is_fp8, + page_size=page_size, + num_heads=H, + seq_len_q=q_len, + is_persistent=is_persistent, + is_var_seq=is_var_seq, + is_var_split_kv=is_var_split_kv, + ) + + # Call the kernel + tensor_api( + q_latent_k, + q_rope_k, + c_latent_k, + c_rope_k, + page_table_k, + o_k, + lse_k, + workspace_bytes.view(torch.uint8), + split_kv, + cache_seqs, + block_split_kvs, + softmax_scale, + output_scale, + ) + + # Reshape output: [H, latent_dim, q_len, B] -> [B, q_len, H, latent_dim] + result = o_k.permute(3, 2, 0, 1).contiguous() + + # Squeeze q_len dimension if it's 1: [B, 1, H, D] -> [B, H, D] + if q_len == 1: + result = result.squeeze(1) + + if out is not None: + out.copy_(result) + return out + + return result diff --git a/flashinfer/cute_dsl/mla_decode_fp16.py b/flashinfer/cute_dsl/mla_decode_fp16.py new file mode 100644 index 0000000000..baae56b65b --- /dev/null +++ b/flashinfer/cute_dsl/mla_decode_fp16.py @@ -0,0 +1,4385 @@ +# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import os +import sys +import argparse +import math +from typing import Type, Tuple, Optional +from types import SimpleNamespace + +import torch +import torch.nn.functional as F +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +import cutlass.cute.testing as testing +import cutlass.cute.nvgpu.tcgen05 as tcgen05 +from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode +import cutlass.cute.nvgpu.cpasync as cpasync +import cutlass.utils as utils +import cutlass.pipeline as pipeline +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait +import cutlass.torch as cutlass_torch +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cute.runtime import from_dlpack +from cutlass.base_dsl.arch import Arch +from cutlass.cutlass_dsl import BaseDSL + + +from .mla_helpers import ( + ceil_div, + MAX_SPLITS, + LOG2_E, + MLAStaticTileScheduler, + MLAStaticTileSchedulerParams, + create_mla_static_tile_scheduler, + create_mla_static_tile_scheduler_params, +) + +""" +A Multi-Head Latent Attention (MLA) example with FP16 data type for the NVIDIA Blackwell SM100 architecture using CUTE DSL + +This example demonstrates an implementation of inference of multi-head latent attention using a TMA + Blackwell +SM100 TensorCore warp-specialized persistent kernel. The implementation integrates the (Qc + Qr)*(Kc + Kr)^T +matrix multiplication, softmax normalization, and softmax((Qc + Qr)*(Kc + Kr)^T)*Vc into a single kernel. +The kernel provides support for page table storage and variable-length KV cache sequences. It implements KV splitting +functionality to minimize latency when processing long KV sequences. + +The kernel implements key optimizations including: +- Warp specialization for different computation phases (load, MMA, softmax, correction, epilogue) +- Pipeline stages between different warps for overlapping computation and memory access +- Support for different precision data types +- Two sub-kernels (split KV kernel and reduction kernel) that enable split KV processing + +To run this example: + +.. code-block:: bash + + python examples/blackwell/mla_fp16.py \ + --batch_size 4 --latent_dim 512 --rope_dim 64 \ + --num_heads 128 --seq_len_q 1 --seq_len_k 1024 \ + --in_dtype Float16 --out_dtype Float16 \ + --acc_dtype Float32 --lse_dtype Float32 \ + --is_var_seq --is_var_split_kv \ + --is_persistent + +The above example runs Multi-Head Latent Attention (MLA) with the following configuration: +- Batch size: 4 +- Sequence length of Q: 1 +- Sequence length of K: 1024 +- Latent dimension: 512 +- RoPE dimension: 64 +- Number of heads: 128 +- Data types: Float16 (input), Float16 (output), Float32 (accumulation and LSE) + +It utilizes page table storage for the KV cache and enables both variable-length KV cache sequences +and variable split KV processing with persistent scheduling. + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/blackwell/mla_fp16.py \ + --batch_size 4 --latent_dim 512 --rope_dim 64 \ + --num_heads 128 --seq_len_q 1 --seq_len_k 1024 \ + --in_dtype Float16 --out_dtype Float16 \ + --acc_dtype Float32 --lse_dtype Float32 \ + --is_var_seq --is_var_split_kv \ + --is_persistent --warmup_iterations 3 \ + --iterations 10 --skip_ref_check + +Constraints for this example: +* Data type requirements: + - Input/output: Float16 + - Accumulation and LSE: Float32 +* Fixed architecture parameters: + - Number of attention heads: 128 + - Latent dimension: 512 + - RoPE dimension: 64 +* Input query modes should be (NumHeads, LatentDim/RopeDim, SeqLenQ, BatchSize) +* Input kv latent/rope modes should be (SeqLenK, LatentDim/RopeDim, BatchSize) +* Query sequence length must be 1-4 +* Only supports 2-CTA instructions +* Variable sequence length requires page table storage enabled +""" + + +class BlackwellMultiHeadLatentAttentionForwardFP16: + def __init__( + self, + acc_dtype: Type[cutlass.Numeric], + lse_dtype: Type[cutlass.Numeric], + mma_qk_tiler_mn: Tuple[int, int], + mma_pv_tiler_mn: Tuple[int, int], + max_active_clusters: int, + page_size: int, + skip_correction_threshold: float, + is_persistent: bool, + is_var_seq: bool, + is_var_split_kv: bool, + ): + """Initializes the configuration for a Blackwell Multi-Head Latent Attention (MLA) kernel. + + :param acc_dtype: Data type for accumulation S and O + :type acc_dtype: Type[cutlass.Numeric] + :param lse_dtype: Data type for output LSE + :type lse_dtype: Type[cutlass.Numeric] + :param mma_s_tiler: The (H, K) tile shape of the MMA instruction for S + :type mma_s_tiler: Tuple[int, int] + :param mma_p_tiler: The (H, D) tile shape of the MMA instruction for P + :type mma_p_tiler: Tuple[int, int] + :param max_active_clusters: Maximum number of active clusters + :type max_active_clusters: int + :param page_size: The page size of the page table + :type page_size: int + :param skip_correction_threshold: Threshold to skip correction + :type skip_correction_threshold: float + :param is_persistent: Whether to use persistent kernel mode + :type is_persistent: bool + :param is_var_seq: Whether to use variable sequence length + :type is_var_seq: bool + :param is_var_split_kv: Whether to use variable split KV + :type is_var_split_kv: bool + """ + + self.latent_dim = 512 + self.rope_dim = 64 + self.acc_dtype = acc_dtype + self.lse_dtype = lse_dtype + self.mma_qk_tiler_mn = mma_qk_tiler_mn + self.mma_pv_tiler_mn = mma_pv_tiler_mn + self.max_active_clusters = max_active_clusters + self.skip_correction_threshold = skip_correction_threshold + self.is_persistent = is_persistent + self.page_size = page_size + self.is_var_seq = is_var_seq + self.is_var_split_kv = is_var_split_kv + self.cluster_shape_mnk = (2, 1, 1) + self.use_2cta_instrs = True + # When using 2 CTAs with m=128: warps 0-1 handle accumulation for first half [0, n/2), + # while warps 2-3 handle accumulation for second half [n/2, n) + self.warps_in_n = 2 + self.num_compute_warps = 4 + self.threads_per_warp = 32 + mma_qk_tiler_k = self.rope_dim + self.mma_qk_tiler = ( + self.mma_qk_tiler_mn[0], + self.mma_qk_tiler_mn[1], + mma_qk_tiler_k, + ) + self.mma_qk_rope_tiler = ( + self.mma_qk_tiler_mn[0], + self.mma_qk_tiler_mn[1], + self.rope_dim, + ) + self.mma_pv_tiler = ( + self.mma_pv_tiler_mn[0], + self.mma_pv_tiler_mn[1], + self.mma_qk_tiler[1] * self.mma_qk_tiler[2] // self.mma_pv_tiler_mn[1], + ) + self.iterations_qk_latent = self.latent_dim // self.mma_qk_tiler[2] + self.iterations_qk_rope = mma_qk_tiler_k // self.mma_qk_tiler[2] + self.iterations_qk = self.iterations_qk_latent + self.iterations_qk_rope + self.iterations_pv_k = self.mma_qk_tiler[1] // self.mma_pv_tiler[2] + self.iterations_pv_n = self.latent_dim // self.mma_pv_tiler[1] + + # Set specialized warp ids + self.compute_warp_ids = (0, 1, 2, 3) + self.correction_warp_ids = (4, 5, 6, 7) + self.mma_warp_id = 8 + + self.load_tma_warp_id = 9 + self.load_pt_warp_id = 10 + self.empty_warp_ids = (11,) + self.threads_per_cta = self.threads_per_warp * len( + ( + self.mma_warp_id, + self.load_tma_warp_id, + self.load_pt_warp_id, + *self.compute_warp_ids, + *self.correction_warp_ids, + *self.empty_warp_ids, + ) + ) + + # register settings + self.softmax_reg_num = 192 + self.correction_reg_num = 208 + self.other_reg_num = 96 + # Named barriers + self.tmem_ptr_sync_bar = pipeline.NamedBarrier( + barrier_id=1, + num_threads=( + self.threads_per_warp + + self.threads_per_warp * self.num_compute_warps * 2 + ), + ) + self.softmax_exchange_sync_bar = pipeline.NamedBarrier( + barrier_id=2, num_threads=(self.threads_per_warp * self.num_compute_warps) + ) + self.epilogue_exchange_sync_bar = pipeline.NamedBarrier( + barrier_id=3, num_threads=(self.threads_per_warp * self.num_compute_warps) + ) + + def _setup_attributes(self): + """Set up configurations and parameters for the MLA kernel operation. + + This method initializes and configures various attributes required for the + execution of the multi-head latent attention kernel, mainly about the pipeline stages: + + - Sets up staging parameters for Q, K, V inputs and accumulator data + - Configures pipeline stages for softmax, correction, and epilogue operations + """ + + self.load_q_stage = 1 + self.load_kv_stage = 15 + self.mma_s_stage = 2 + self.p_mma_stage = 2 + self.p_cor_stage = 2 + self.mma_o_stage = 1 + self.load_pt_stage = 4 + + self.tmem_o_offset = self.mma_s_stage * self.mma_qk_tiler[1] // self.warps_in_n + self.correction_factor_offset = ( + self.tmem_o_offset + self.latent_dim // self.warps_in_n + ) + + @cute.jit + def __call__( + self, + q_latent: cute.Tensor, + q_rope: cute.Tensor, + c_latent: cute.Tensor, + c_rope: cute.Tensor, + page_table: cute.Tensor, + o: cute.Tensor, + lse: cute.Tensor, + workspace: cute.Tensor, + split_kv: cutlass.Int32, + cache_seqs: Optional[cute.Tensor], + block_split_kvs: Optional[cute.Tensor], + softmax_scale: cutlass.Float32, + output_scale: cutlass.Float32, + stream: cuda.CUstream, + ): + """Execute the Multi-Head Latent Attention operation on the provided tensors. + + The method handles: + 1. Initialization of workspace for temporary split KV buffers + 2. Validation of tensor data types + 3. Initialization of hardware-specific parameters and memory layouts + 4. Configuration of TMA (Tensor Memory Access) operations + 5. Grid and work scheduling computation + 6. Kernel launch(split KV kernel and reduction kernel) with appropriate parameters + + :param q_latent: The query tensor with shape [num_head, latent_dim, seq_len_q, batch_size] + :type q_latent: cute.Tensor + :param q_rope: The query RoPE tensor with shape [num_head, rope_dim, seq_len_q, batch_size] + :type q_rope: cute.Tensor + :param c_latent: The key tensor with shape [seq_len_k, latent_dim, batch_size] + :type c_latent: cute.Tensor + :param c_rope: The key RoPE tensor with shape [seq_len_k, rope_dim, batch_size] + :type c_rope: cute.Tensor + :param page_table: The page table tensor with shape [page_count, batch_size] + :type page_table: cute.Tensor + :param o: The output tensor with shape [num_head, latent_dim, seq_len_q, batch_size] + :type o: cute.Tensor + :param lse: The LSE tensor with shape [num_head, seq_len_q, batch_size] + :type lse: cute.Tensor + :param workspace: The workspace tensor with 1-d shape prepared for acc_o and acc_lse + :type workspace: cute.Tensor + :param split_kv: The scalar factor for split KV + :type split_kv: cutlass.Int32 + :param cache_seqs: The cache sequences tensor with shape [batch_size] + :type cache_seqs: cute.Tensor + :param block_split_kvs: The block split KV tensor with shape [batch_size] + :type block_split_kvs: cute.Tensor + :param softmax_scale: The scale factor for softmax + :type softmax_scale: cutlass.Float32 + :param output_scale: The scale factor for the output + :type output_scale: cutlass.Float32 + :param stream: The CUDA stream to execute the kernel on + :type stream: cuda.CUstream + + :raises TypeError: If tensor data types don't match or aren't supported + """ + + # setup static attributes before smem/grid/tma computation + self.q_dtype = q_latent.element_type + self.k_dtype = c_latent.element_type + self.v_dtype = c_latent.element_type + self.o_dtype = o.element_type + + # check type consistency + if cutlass.const_expr( + self.q_dtype != self.k_dtype or self.q_dtype != self.v_dtype + ): + raise TypeError( + f"Type mismatch: {self.q_dtype} != {self.k_dtype} or {self.q_dtype} != {self.v_dtype}" + ) + # check leading dimensions of input/output + if cutlass.const_expr(q_latent.stride[1] != 1 or q_rope.stride[1] != 1): + raise ValueError("q_latent and q_rope must have leading dimension 1") + if cutlass.const_expr(c_latent.stride[1] != 1 or c_rope.stride[1] != 1): + raise ValueError("c_latent and c_rope must have leading dimension 1") + if cutlass.const_expr(o.stride[1] != 1): + raise ValueError("o must have leading dimension 1") + if cutlass.const_expr(lse.stride[0] != 1): + raise ValueError("lse must have leading dimension 0") + + acc_o, acc_lse = self.initialize_workspace( + q_latent.shape[0], + q_latent.shape[1], + q_latent.shape[2], + q_latent.shape[3], + split_kv, + self.acc_dtype, + workspace, + ) + + c_latent_tranpose_layout = cute.select(c_latent.layout, mode=[1, 0, 2]) + c_latent_transpose = cute.make_tensor( + c_latent.iterator, c_latent_tranpose_layout + ) + + self.q_major_mode = tcgen05.OperandMajorMode.K + self.k_major_mode = tcgen05.OperandMajorMode.K + self.v_major_mode = tcgen05.OperandMajorMode.MN + + self._setup_attributes() + + cta_group = tcgen05.CtaGroup.TWO + # the intermediate tensor p is from smem & k-major + p_major_mode = tcgen05.OperandMajorMode.K + qk_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.q_dtype, + self.q_major_mode, + self.k_major_mode, + self.acc_dtype, + cta_group, + self.mma_qk_tiler[:2], + ) + pv_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.v_dtype, + p_major_mode, + self.v_major_mode, + self.acc_dtype, + cta_group, + self.mma_pv_tiler[:2], + ) + + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), + (qk_tiled_mma.thr_id.shape,), + ) + + self.epi_tile = self.mma_pv_tiler[:2] + + q_latent_smem_layout_staged = sm100_utils.make_smem_layout_a( + qk_tiled_mma, + self.mma_qk_tiler, + self.q_dtype, + (self.iterations_qk_latent * self.load_q_stage), + ) + q_latent_smem_layout_staged = cute.logical_divide( + q_latent_smem_layout_staged, (None, None, None, self.iterations_qk_latent) + ) + q_rope_smem_layout_staged = sm100_utils.make_smem_layout_a( + qk_tiled_mma, + self.mma_qk_rope_tiler, + self.q_dtype, + self.load_q_stage, + ) + + # rope reuse the same smem layout as latent + kc_smem_layout_staged = sm100_utils.make_smem_layout_b( + qk_tiled_mma, + self.mma_qk_tiler, + self.k_dtype, + self.load_kv_stage, + ) + kc_page_tile_size = min( + self.page_size, qk_tiled_mma.op.shape_mnk[0] // qk_tiled_mma.thr_id.shape + ) + + kc_smem_layout_for_tma = sm100_utils.make_smem_layout( + OperandMajorMode.K, + (self.mma_qk_tiler[0] // qk_tiled_mma.thr_id.shape, self.mma_qk_tiler[2]), + self.k_dtype, + self.load_kv_stage, + ) + kc_smem_layout_for_tma = cute.tiled_divide( + kc_smem_layout_for_tma, (kc_page_tile_size, self.mma_qk_tiler[2]) + ) + + p_smem_layout_staged = sm100_utils.make_smem_layout_a( + pv_tiled_mma, + self.mma_pv_tiler, + self.q_dtype, + (self.iterations_pv_k * self.p_mma_stage), + ) + p_smem_layout_staged = cute.logical_divide( + p_smem_layout_staged, (None, None, None, self.iterations_pv_k) + ) + + vc_smem_layout_staged = sm100_utils.make_smem_layout_b( + pv_tiled_mma, + self.mma_pv_tiler, + self.v_dtype, + self.load_kv_stage, + ) + vc_page_tile_size = min(self.page_size, self.mma_pv_tiler[2]) + vc_smem_layout_for_tma = sm100_utils.make_smem_layout( + OperandMajorMode.MN, + (self.mma_pv_tiler[1] // pv_tiled_mma.thr_id.shape, self.mma_pv_tiler[2]), + self.v_dtype, + self.load_kv_stage, + ) + vc_smem_layout_for_tma = cute.tiled_divide( + vc_smem_layout_for_tma, + ( + pv_tiled_mma.op.shape_mnk[1] // pv_tiled_mma.thr_id.shape, + vc_page_tile_size, + ), + ) + # TMA load for Q latent and rope + tma_load_op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp(cta_group) + + q_latent_smem_layout = cute.select(q_latent_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_q_latent, tma_tensor_q_latent = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q_latent, + q_latent_smem_layout, + self.mma_qk_tiler, + qk_tiled_mma, + cta_layout_vmnk.shape, + ) + q_rope_smem_layout = cute.select(q_rope_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_q_rope, tma_tensor_q_rope = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q_rope, + q_rope_smem_layout, + self.mma_qk_rope_tiler, + qk_tiled_mma, + cta_layout_vmnk.shape, + ) + # TMA load for c latent and k rope + kc_smem_layout = cute.select(kc_smem_layout_for_tma, mode=[0]) + tma_atom_c_latent, tma_tensor_c_latent = self.make_paged_tiled_tma_atom( + tma_load_op, + c_latent, + kc_smem_layout, + (self.mma_qk_tiler[1], self.mma_qk_tiler[2]), + qk_tiled_mma, + is_k_load=True, + ) + tma_atom_c_rope, tma_tensor_c_rope = self.make_paged_tiled_tma_atom( + tma_load_op, + c_rope, + kc_smem_layout, + (self.mma_qk_tiler[1], self.mma_qk_tiler[2]), + qk_tiled_mma, + is_k_load=True, + ) + # TMA load for c latent transpose + vc_smem_layout = cute.select(vc_smem_layout_for_tma, mode=[0]) + tma_atom_c_latent_transpose, tma_tensor_c_latent_transpose = ( + self.make_paged_tiled_tma_atom( + tma_load_op, + c_latent_transpose, + vc_smem_layout, + (self.mma_pv_tiler[1], self.mma_pv_tiler[2]), + pv_tiled_mma, + is_k_load=False, + ) + ) + + q_latent_copy_size = ( + cute.size_in_bytes(self.q_dtype, q_latent_smem_layout) + * cute.size(qk_tiled_mma.thr_id.shape) + * self.iterations_qk_latent + ) + q_rope_copy_size = ( + cute.size_in_bytes(self.q_dtype, q_rope_smem_layout) + * cute.size(qk_tiled_mma.thr_id.shape) + * self.iterations_qk_rope + ) + q_copy_size = q_latent_copy_size + q_rope_copy_size + kc_copy_size = cute.size_in_bytes( + self.k_dtype, cute.select(kc_smem_layout_staged, mode=[0, 1, 2]) + ) * cute.size(qk_tiled_mma.thr_id.shape) + vc_copy_size = cute.size_in_bytes( + self.v_dtype, cute.select(vc_smem_layout_staged, mode=[0, 1, 2]) + ) * cute.size(pv_tiled_mma.thr_id.shape) + assert kc_copy_size == vc_copy_size, ( + "kc_copy_size and vc_copy_size must be the same" + ) + + self.tma_copy_q_bytes = q_copy_size + self.tma_copy_kc_bytes = kc_copy_size + + tile_sched_params, grid = self._compute_grid( + o, + split_kv, + self.cluster_shape_mnk, + self.max_active_clusters, + self.is_persistent, + ) + + @cute.struct + class SplitKVKernelSharedStorage: + # Pipeline barriers + load_q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.load_q_stage * 2] + load_kv_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.load_kv_stage * 2 + ] + mma_s_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.mma_s_stage * 2] + p_mma_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.p_mma_stage * 2] + p_cor_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.p_cor_stage * 2] + mma_o_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.mma_o_stage * 2] + load_pt_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.load_pt_stage * 2 + ] + # Tmem dealloc cluster barrier + tmem_dealloc_mbar_ptr: cutlass.Int64 + + # Tmem holding buffer + tmem_holding_buf: cutlass.Int32 + # Smem tensors + softmax_smem_exchange: cute.struct.MemRange[ + self.acc_dtype, self.num_compute_warps * self.threads_per_warp + ] + epilogue_smem_exchange: cute.struct.MemRange[ + self.acc_dtype, self.num_compute_warps * self.threads_per_warp + ] + smem_q_latent: cute.struct.Align[ + cute.struct.MemRange[ + self.q_dtype, cute.cosize(q_latent_smem_layout_staged) + ], + 1024, + ] + smem_q_rope: cute.struct.Align[ + cute.struct.MemRange[ + self.q_dtype, cute.cosize(q_rope_smem_layout_staged) + ], + 1024, + ] + smem_kc: cute.struct.Align[ + cute.struct.MemRange[self.k_dtype, cute.cosize(kc_smem_layout_staged)], + 1024, + ] + smem_p: cute.struct.Align[ + cute.struct.MemRange[self.q_dtype, cute.cosize(p_smem_layout_staged)], + 1024, + ] + smem_page_table: cute.struct.MemRange[ + cutlass.Int32, self.load_pt_stage * self.mma_qk_tiler[1] // 2 + ] + + softmax_scale_log2 = softmax_scale * LOG2_E + self.split_kv_kernel( + qk_tiled_mma, + pv_tiled_mma, + tma_atom_q_latent, + tma_tensor_q_latent, + tma_atom_q_rope, + tma_tensor_q_rope, + tma_atom_c_latent, + tma_tensor_c_latent, + tma_atom_c_rope, + tma_tensor_c_rope, + tma_atom_c_latent_transpose, + tma_tensor_c_latent_transpose, + page_table, + o, + lse, + acc_o, + acc_lse, + split_kv, + cache_seqs, + block_split_kvs, + softmax_scale_log2, + output_scale, + q_latent_smem_layout_staged, + q_rope_smem_layout_staged, + kc_smem_layout_staged, + p_smem_layout_staged, + vc_smem_layout_staged, + kc_smem_layout_for_tma, + vc_smem_layout_for_tma, + cta_layout_vmnk, + tile_sched_params, + SplitKVKernelSharedStorage, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=self.cluster_shape_mnk, + smem=SplitKVKernelSharedStorage.size_in_bytes(), + stream=stream, + min_blocks_per_mp=1, + ) + if cutlass.const_expr(acc_o is not None): + self.reduction_kernel( + o, + lse, + acc_o, + acc_lse, + split_kv, + cache_seqs, + block_split_kvs, + ).launch( + grid=(q_latent.shape[0], q_latent.shape[2], q_latent.shape[3]), + block=[self.threads_per_warp * self.num_compute_warps, 1, 1], + smem=MAX_SPLITS * self.acc_dtype.width // 8, + stream=stream, + min_blocks_per_mp=1, + ) + + @cute.jit + def make_paged_tiled_tma_atom( + self, + tma_load_op: cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp, + gmem: cute.Tensor, + smem_layout: cute.Layout, + mma_tiler, + tiled_mma: cute.TiledMma, + is_k_load: bool, + ): + ident = cute.make_identity_layout(gmem.shape) + g_tile = cute.composition(ident, mma_tiler) + cta_mn = mma_tiler[0] // tiled_mma.thr_id.shape + cta_v_map = cute.flat_divide(g_tile, (cta_mn,)) + cta_v_map = cute.select(cta_v_map, mode=[0, 2]) + page_tile_size = ( + min(self.page_size, cta_mn) + if is_k_load + else min(self.page_size, mma_tiler[1]) + ) + cta_v_map = cute.zipped_divide( + cta_v_map, + (page_tile_size, mma_tiler[1]) if is_k_load else (cta_mn, page_tile_size), + ) + cta_v_map = cute.select(cta_v_map, mode=[0]) + from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir + + res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load( + gmem.value, + smem_layout.value, + cta_v_map, + tma_load_op._to_ir(), + num_multicast=1, + ) + return cute.CopyAtom( + tma_load_op, cpasync.CopyBulkTensorTileG2SNonExecTrait(res[0]) + ), res[1] + + @cute.kernel + def split_kv_kernel( + self, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + tma_atom_q_latent: Optional[cute.CopyAtom], + mQL: cute.Tensor, + tma_atom_q_rope: Optional[cute.CopyAtom], + mQR: cute.Tensor, + tma_atom_c_latent: Optional[cute.CopyAtom], + mCL: cute.Tensor, + tma_atom_c_rope: Optional[cute.CopyAtom], + mKR: cute.Tensor, + tma_atom_c_latent_transpose: Optional[cute.CopyAtom], + mCLT: cute.Tensor, + mPT: cute.Tensor, + mO: Optional[cute.Tensor], + mLSE: Optional[cute.Tensor], + mAccO: Optional[cute.Tensor], + mAccLSE: Optional[cute.Tensor], + split_kv: cutlass.Int32, + cache_seqs: cute.Tensor, + block_split_kvs: cute.Tensor, + softmax_scale_log2: cutlass.Float32, + output_scale: cutlass.Float32, + q_latent_smem_layout_staged: cute.ComposedLayout, + q_rope_smem_layout_staged: cute.ComposedLayout, + kc_smem_layout_staged: cute.ComposedLayout, + p_smem_layout_staged: cute.ComposedLayout, + vc_smem_layout_staged: cute.ComposedLayout, + kc_smem_layout_for_tma: cute.ComposedLayout, + vc_smem_layout_for_tma: cute.ComposedLayout, + cta_layout_vmnk: cute.Layout, + tile_sched_params: MLAStaticTileSchedulerParams, + SharedStorage: cutlass.Constexpr, + ): + """The device split_kv kernel implementation of the Multi-Head Latent Attention. + + This kernel coordinates multiple specialized warps to perform different phases of the MLA computation: + 1. Load warp: Loads Q/C latent/rope data from global memory to shared memory using TMA + 2. MMA warp: Performs matrix multiplications (Q*K^T and P*V) + 3. Compute warps: Compute softmax and do rescaling on accumulators, and store the intermediate/final results + to global memory + + The kernel produces either intermediate or final results of the MLA computation based on the split_kv parameter. + When split_kv is 1, the kernel generates the final results directly. Otherwise, it produces intermediate results + that will later be combined by a reduction kernel. + + The kernel implements a complex pipeline with overlapping computation and memory operations, + using tensor memory access (TMA) for efficient data loading, warp specialization for different + computation phases. + + :param tiled_mma_qk: Tiled MMA for Q*K^T + :type tiled_mma_qk: cute.TiledMma + :param tiled_mma_pv: Tiled MMA for P*V + :type tiled_mma_pv: cute.TiledMma + :param tma_atom_q_latent: TMA copy atom for query latent tensor + :type tma_atom_q_latent: cute.CopyAtom + :param mQL: query latent tensor + :type mQL: cute.Tensor + :param tma_atom_q_rope: TMA copy atom for query rope tensor + :type tma_atom_q_rope: cute.CopyAtom + :param mKR: Compressed rope tensor + :type mKR: cute.Tensor + :param tma_atom_c_latent: TMA copy atom for c latent tensor + :type tma_atom_c_latent: cute.CopyAtom + :param mCL: Compressed latent tensor + :type mCL: cute.Tensor + :param tma_atom_c_rope: TMA copy atom for c rope tensor + :type tma_atom_c_rope: cute.CopyAtom + :param mCLT: Compressed latent transpose tensor + :type mCLT: cute.Tensor + :param mPT: Page table tensor + :type mPT: cute.Tensor + :param mO: Output tensor + :type mO: cute.Tensor + :param mLSE: Log-sum-exp tensor + :type mLSE: cute.Tensor + :param mAccO: Intermediate accumulator output tensor + :type mAccO: cute.Tensor + :param mAccLSE: Intermediate accumulator log-sum-exp tensor + :type mAccLSE: cute.Tensor + :param split_kv: The split_kv parameter + :type split_kv: cutlass.Int32 + :param cache_seqs: The variable sequence length tensor + :type cache_seqs: cute.Tensor + :param block_split_kvs: The per-block split_kv values tensor + :type block_split_kvs: cute.Tensor + :param softmax_scale_log2: The log2 scale factor for softmax + :type softmax_scale_log2: cutlass.Float32 + :param output_scale: The scale factor for the output + :type output_scale: cutlass.Float32 + :param q_latent_smem_layout_staged: Shared memory layout for query latent tensor + :type q_latent_smem_layout_staged: cute.ComposedLayout + :param q_rope_smem_layout_staged: Shared memory layout for query rope tensor + :type q_rope_smem_layout_staged: cute.ComposedLayout + :param kc_smem_layout_staged: Shared memory layout for key/value latent/rope tensor + :type kc_smem_layout_staged: cute.ComposedLayout + :param p_smem_layout_staged: Shared memory layout for probability matrix + :type p_smem_layout_staged: cute.ComposedLayout + :param vc_smem_layout_staged: Shared memory layout for value tensor + :type vc_smem_layout_staged: cute.ComposedLayout + :param kc_smem_layout_for_tma: Shared memory layout for key/value latent tensor for TMA + :type kc_smem_layout_for_tma: cute.ComposedLayout + :param vc_smem_layout_for_tma: Shared memory layout for value tensor for TMA + :type vc_smem_layout_for_tma: cute.ComposedLayout + :param cta_layout_vmnk: Layout for compute threads + :type cta_layout_vmnk: cute.Layout + :param tile_sched_params: Scheduling parameters for work distribution + :type tile_sched_params: MLAStaticTileSchedulerParams + :param SharedStorage: Shared storage for the kernel + :type SharedStorage: cutlass.Constexpr + """ + + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma_qk.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + + # Prefetch tma descriptor + if warp_idx == self.mma_warp_id: + cpasync.prefetch_descriptor(tma_atom_q_latent) + cpasync.prefetch_descriptor(tma_atom_q_rope) + cpasync.prefetch_descriptor(tma_atom_c_latent) + cpasync.prefetch_descriptor(tma_atom_c_rope) + cpasync.prefetch_descriptor(tma_atom_c_latent_transpose) + + # Alloc + smem = utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + + # Tensor memory dealloc barrier init + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=self.tmem_ptr_sync_bar, + allocator_warp_id=self.mma_warp_id, + is_two_cta=self.use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + + load_q_pipeline = self.make_and_init_load_qkv_pipeline( + storage.load_q_mbar_ptr.data_ptr(), + cta_layout_vmnk, + self.load_q_stage, + self.tma_copy_q_bytes, + ) + load_kv_pipeline = self.make_and_init_load_qkv_pipeline( + storage.load_kv_mbar_ptr.data_ptr(), + cta_layout_vmnk, + self.load_kv_stage, + self.tma_copy_kc_bytes, + ) + mma_s_pipeline = self.make_and_init_mma_s_pipeline( + storage.mma_s_mbar_ptr.data_ptr(), cta_layout_vmnk + ) + p_mma_pipeline = self.make_and_init_p_mma_pipeline( + storage.p_mma_mbar_ptr.data_ptr(), cta_layout_vmnk + ) + p_cor_pipeline = self.make_and_init_p_cor_pipeline( + storage.p_cor_mbar_ptr.data_ptr() + ) + mma_o_pipeline = self.make_and_init_mma_o_pipeline( + storage.mma_o_mbar_ptr.data_ptr(), cta_layout_vmnk + ) + load_pt_pipeline = self.make_and_init_load_pt_pipeline( + storage.load_pt_mbar_ptr.data_ptr() + ) + + # Cluster arrive after barrier init + pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mnk, is_relaxed=True) + + # Generate smem tensor Q/KC/VC/exchange + # (MMA, MMA_H, MMA_R, PIPE) + sQ = storage.smem_q_latent.get_tensor( + q_latent_smem_layout_staged.outer, swizzle=q_latent_smem_layout_staged.inner + ) + sQ_rope = storage.smem_q_rope.get_tensor( + q_rope_smem_layout_staged.outer, swizzle=q_rope_smem_layout_staged.inner + ) + # (MMA, MMA_K, MMA_R, PIPE) + sKC = storage.smem_kc.get_tensor( + kc_smem_layout_staged.outer, swizzle=kc_smem_layout_staged.inner + ) + sKC_for_tma = storage.smem_kc.get_tensor( + kc_smem_layout_for_tma.outer, + swizzle=kc_smem_layout_for_tma.inner, + ) + # (MMA, MMA_D, MMA_K, PIPE) + # reuse smem + sVC_ptr = cute.recast_ptr(sKC.iterator, vc_smem_layout_staged.inner) + sVC = cute.make_tensor(sVC_ptr, vc_smem_layout_staged.outer) + sVC_for_tma = cute.make_tensor(sVC_ptr, vc_smem_layout_for_tma.outer) + # (MMA, MMA_H, MMA_K) + sP = storage.smem_p.get_tensor( + p_smem_layout_staged.outer, swizzle=p_smem_layout_staged.inner + ) + sPT = storage.smem_page_table.get_tensor( + cute.make_layout((self.mma_qk_tiler[1] // 2, self.load_pt_stage)) + ) + # (compute_threads,) + softmax_smem_exchange = storage.softmax_smem_exchange.get_tensor( + cute.make_layout(self.num_compute_warps * self.threads_per_warp) + ) + epilogue_smem_exchange = storage.epilogue_smem_exchange.get_tensor( + cute.make_layout(self.num_compute_warps * self.threads_per_warp) + ) + + # + # Cluster wait before tensor memory alloc + # + pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mnk) + + # /////////////////////////////////////////////////////////////////////////////// + # Load warps, including page table and data tensors + # /////////////////////////////////////////////////////////////////////////////// + + if warp_idx >= self.empty_warp_ids[0] and warp_idx <= self.empty_warp_ids[-1]: + cute.arch.setmaxregister_decrease(self.other_reg_num) + if warp_idx == self.load_pt_warp_id: + cute.arch.setmaxregister_decrease(self.other_reg_num) + load_pt_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.load_pt_stage + ) + tile_sched = create_mla_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + blk_coord = work_tile.tile_idx + k_index, k_tile_count, local_split_kv = self.get_k_tile_count( + split_kv, + cache_seqs, + block_split_kvs, + blk_coord, + ) + if k_tile_count > 0: + load_pt_common_params = SimpleNamespace( + blk_coord=blk_coord, + load_pt_pipeline=load_pt_pipeline, + mPT=mPT, + sPT=sPT, + tidx=tidx, + page_size=mCL.shape[0], + ) + load_pt_producer_state = self.load_page_table( + load_pt_common_params, + k_index, + k_tile_count, + load_pt_producer_state, + ) + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + load_pt_pipeline.producer_tail(load_pt_producer_state) + if warp_idx == self.load_tma_warp_id: + cute.arch.setmaxregister_decrease(self.other_reg_num) + load_q_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.load_q_stage + ) + load_kv_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.load_kv_stage + ) + load_pt_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.load_pt_stage + ) + load_pt_release_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.load_pt_stage + ) + tile_sched = create_mla_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + blk_coord = work_tile.tile_idx + k_index, k_tile_count, local_split_kv = self.get_k_tile_count( + split_kv, + cache_seqs, + block_split_kvs, + blk_coord, + ) + if k_tile_count > 0: + # Construct fixed common/tma_qk/tma_pv params for load_tma + tma_common_params = SimpleNamespace( + blk_coord=blk_coord, + local_split_kv=local_split_kv, + load_q_pipeline=load_q_pipeline, + load_kv_pipeline=load_kv_pipeline, + mPT=mPT, + sPT=sPT, + load_pt_pipeline=load_pt_pipeline, + ) + tma_qk_params = SimpleNamespace( + tiled_mma_qk=tiled_mma_qk, + tma_atom_q_latent=tma_atom_q_latent, + tma_atom_q_rope=tma_atom_q_rope, + tma_atom_c_latent=tma_atom_c_latent, + tma_atom_c_rope=tma_atom_c_rope, + mQL=mQL, + mQR=mQR, + mCL=mCL, + mKR=mKR, + sQ=sQ, + sQ_rope=sQ_rope, + sKC=sKC_for_tma, + ) + tma_pv_params = SimpleNamespace( + tiled_mma_pv=tiled_mma_pv, + tma_atom_c_latent_transpose=tma_atom_c_latent_transpose, + mCL=mCL, + mKR=mKR, + mCLT=mCLT, + sVC=sVC_for_tma, + ) + # Load tma + ( + load_q_producer_state, + load_kv_producer_state, + load_pt_consumer_state, + load_pt_release_state, + ) = self.load_tma( + tma_common_params, + tma_qk_params, + tma_pv_params, + k_index, + k_tile_count, + load_q_producer_state, + load_kv_producer_state, + load_pt_consumer_state, + load_pt_release_state, + ) + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + load_q_pipeline.producer_tail(load_q_producer_state) + load_kv_pipeline.producer_tail(load_kv_producer_state) + + # /////////////////////////////////////////////////////////////////////////////// + # MMA warp + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.mma_warp_id: + cute.arch.setmaxregister_decrease(self.other_reg_num) + # Alloc tensor memory buffer + tmem.allocate(cute.arch.get_max_tmem_alloc_cols("sm_100")) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + + load_q_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.load_q_stage + ) + load_kv_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.load_kv_stage + ) + mma_s_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.mma_s_stage + ) + p_mma_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.p_mma_stage + ) + mma_o_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.mma_o_stage + ) + tile_sched = create_mla_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + blk_coord = work_tile.tile_idx + k_index, k_tile_count, local_split_kv = self.get_k_tile_count( + split_kv, cache_seqs, block_split_kvs, blk_coord + ) + if k_tile_count > 0: + mma_common_params = SimpleNamespace( + blk_coord=blk_coord, + local_split_kv=local_split_kv, + load_q_pipeline=load_q_pipeline, + load_kv_pipeline=load_kv_pipeline, + tmem_ptr=tmem_ptr, + is_leader_cta=is_leader_cta, + L=mCL.shape[1], + ) + mma_qk_params = SimpleNamespace( + mma_s_pipeline=mma_s_pipeline, + sQ=sQ, + sQ_rope=sQ_rope, + sKC=sKC, + ) + mma_pv_params = SimpleNamespace( + p_mma_pipeline=p_mma_pipeline, + mma_o_pipeline=mma_o_pipeline, + sP=sP, + sVC=sVC, + ) + ( + tiled_mma_qk, + tiled_mma_pv, + load_q_consumer_state, + load_kv_consumer_state, + mma_s_producer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) = self.mma( + mma_common_params, + mma_qk_params, + mma_pv_params, + k_tile_count, + tiled_mma_qk, + tiled_mma_pv, + load_q_consumer_state, + load_kv_consumer_state, + mma_s_producer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + mma_s_pipeline.producer_tail(mma_s_producer_state) + mma_o_pipeline.producer_tail(mma_o_producer_state) + + tmem.relinquish_alloc_permit() + tmem.free(tmem_ptr) + + # /////////////////////////////////////////////////////////////////////////////// + # Compute warp + # /////////////////////////////////////////////////////////////////////////////// + if ( + warp_idx >= self.compute_warp_ids[0] + and warp_idx <= self.compute_warp_ids[-1] + ): + cute.arch.setmaxregister_increase(self.softmax_reg_num) + mma_s_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.mma_s_stage + ) + p_mma_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.p_mma_stage + ) + p_cor_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.p_cor_stage + ) + mma_o_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.mma_o_stage + ) + # sync with mma warp before retrieving tmem ptr + tmem.wait_for_alloc() + + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + + tile_sched = create_mla_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + blk_coord = work_tile.tile_idx + k_index, k_tile_count, local_split_kv = self.get_k_tile_count( + split_kv, cache_seqs, block_split_kvs, blk_coord + ) + if k_tile_count > 0: + compute_common_params = SimpleNamespace( + blk_coord=blk_coord, + split_kv=split_kv, + local_split_kv=local_split_kv, + smem_exchange=softmax_smem_exchange, + mAccO=mAccO, + mO=mO, + K=cache_seqs[blk_coord[2]], + L=mCL.shape[1], + tmem_ptr=tmem_ptr, + tidx=tidx, + p_cor_pipeline=p_cor_pipeline, + ) + compute_softmax_params = SimpleNamespace( + tiled_mma_qk=tiled_mma_qk, + sP=sP, + mma_s_pipeline=mma_s_pipeline, + p_mma_pipeline=p_mma_pipeline, + softmax_scale_log2=softmax_scale_log2, + ) + mma_s_consumer_state, p_mma_producer_state, p_cor_producer_state = ( + self.compute( + compute_common_params, + compute_softmax_params, + k_index=k_index, + k_tile_count=k_tile_count, + mma_s_consumer_state=mma_s_consumer_state, + p_mma_producer_state=p_mma_producer_state, + p_cor_producer_state=p_cor_producer_state, + ) + ) + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + p_cor_pipeline.producer_tail(p_cor_producer_state) + + # /////////////////////////////////////////////////////////////////////////////// + # Correction warp + # /////////////////////////////////////////////////////////////////////////////// + if ( + warp_idx >= self.correction_warp_ids[0] + and warp_idx <= self.correction_warp_ids[-1] + ): + cute.arch.setmaxregister_increase(self.correction_reg_num) + p_cor_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.p_cor_stage + ) + mma_o_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.mma_o_stage + ) + # sync with mma warp before retrieving tmem ptr + tmem.wait_for_alloc() + + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + + tile_sched = create_mla_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + blk_coord = work_tile.tile_idx + k_index, k_tile_count, local_split_kv = self.get_k_tile_count( + split_kv, cache_seqs, block_split_kvs, blk_coord + ) + if k_tile_count > 0: + compute_common_params = SimpleNamespace( + blk_coord=blk_coord, + split_kv=split_kv, + local_split_kv=local_split_kv, + smem_exchange=epilogue_smem_exchange, + mAccO=mAccO, + mO=mO, + K=cache_seqs[blk_coord[2]], + L=mCL.shape[1], + H=mQL.shape[0], + tmem_ptr=tmem_ptr, + tidx=tidx, + tiled_mma_pv=tiled_mma_pv, + p_cor_pipeline=p_cor_pipeline, + mma_o_pipeline=mma_o_pipeline, + ) + compute_epilogue_params = SimpleNamespace( + output_scale=output_scale, + softmax_scale_log2=softmax_scale_log2, + mAccLSE=mAccLSE, + mLSE=mLSE, + ) + p_cor_consumer_state, mma_o_consumer_state = self.correction( + compute_common_params, + compute_epilogue_params, + k_tile_count=k_tile_count, + p_cor_consumer_state=p_cor_consumer_state, + mma_o_consumer_state=mma_o_consumer_state, + ) + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + return + + @cute.kernel + def reduction_kernel( + self, + mO: cute.Tensor, + mLSE: cute.Tensor, + mAccO: cute.Tensor, + mAccLSE: cute.Tensor, + split_kv: cutlass.Int32, + cache_seqs: cute.Tensor, + block_split_kvs: cute.Tensor, + ): + """The reduction kernel for Multi-Head Latent Attention (MLA) that combines intermediate results + from multiple split_kv blocks into final outputs. + + :param mO: Output tensor for storing final results + :type mO: cute.Tensor + :param mLSE: Log-sum-exp tensor for storing final LSE values + :type mLSE: cute.Tensor + :param mAccO: Accumulated output tensor from split_kv blocks + :type mAccO: cute.Tensor + :param mAccLSE: Accumulated LSE tensor from split_kv blocks + :type mAccLSE: cute.Tensor + :param split_kv: Number of split_kv blocks + :type split_kv: cutlass.Int32 + :param cache_seqs: Cache sequence lengths tensor + :type cache_seqs: cute.Tensor + :param block_split_kvs: Per-block split_kv values tensor (for variable split_kv) + :type block_split_kvs: cute.Tensor + """ + bidx, bidy, bidz = cute.arch.block_idx() + tidx, _, _ = cute.arch.thread_idx() + blk_coord = (bidx, bidy, bidz) + local_split_kv = ( + block_split_kvs[blk_coord[2]] if self.is_var_split_kv else split_kv + ) + k_tile_total = cute.ceil_div(cache_seqs[blk_coord[2]], self.mma_qk_tiler[1]) + k_tile_per_cta = cute.ceil_div(k_tile_total, local_split_kv) + local_split_kv = cute.ceil_div(k_tile_total, k_tile_per_cta) + + # Alloc shared memory + smem = utils.SmemAllocator() + storage = smem.allocate(MAX_SPLITS * self.acc_dtype.width // 8, 16) + lse_scale_ptr = cute.recast_ptr(storage, dtype=self.acc_dtype) + smem_lse_scale = cute.make_tensor(lse_scale_ptr, cute.make_layout(MAX_SPLITS)) + + gLSE = mAccLSE[blk_coord[0], None, blk_coord[1], blk_coord[2]] + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + if warp_idx == 0: + # calculate the global lse and exp ^ (local_lse - global_lse) + lse_per_thread = cute.ceil_div(MAX_SPLITS, self.threads_per_warp) + + local_lse = cute.make_rmem_tensor( + cute.make_layout(lse_per_thread), self.lse_dtype + ) + lse_max = -self.lse_dtype.inf + # find the max lse + for i in cutlass.range_constexpr(lse_per_thread): + split_kv_idx = tidx + i * self.threads_per_warp + local_lse[i] = ( + gLSE[split_kv_idx] + if cute.elem_less(split_kv_idx, local_split_kv) + else -self.lse_dtype.inf + ) + # reduce the local lse + lse_max = cute.arch.fmax(lse_max, local_lse[i]) + lse_max = cute.arch.warp_reduction_max(lse_max) + lse_max = lse_max if lse_max != -self.lse_dtype.inf else 0.0 + # calculate sum_lse + sum_lse = 0.0 + for i in cutlass.range_constexpr(lse_per_thread): + sum_lse += cute.math.exp2(local_lse[i] - lse_max, fastmath=True) + sum_lse = cute.arch.warp_reduction_sum(sum_lse) + # calculate the global_lse + global_lse = ( + lse_max + cute.math.log2(sum_lse, fastmath=True) + if not sum_lse == self.lse_dtype(0.0) or sum_lse != sum_lse + else self.lse_dtype.inf + ) + if tidx == 0: + mLSE[blk_coord[0], blk_coord[1], blk_coord[2]] = global_lse + # store the scale to shared memory + for i in cutlass.range_constexpr(lse_per_thread): + split_kv_idx = tidx + i * self.threads_per_warp + if cute.elem_less(split_kv_idx, local_split_kv): + smem_lse_scale[split_kv_idx] = cute.math.exp2( + local_lse[i] - global_lse, fastmath=True + ) + + pipeline.sync(barrier_id=4) + + elements_per_thread = cute.ceil_div( + self.latent_dim, self.threads_per_warp * self.num_compute_warps + ) + gAccO = mAccO[blk_coord[0], None, None, blk_coord[1], blk_coord[2]] + rAccO = cute.make_rmem_tensor( + cute.make_layout(elements_per_thread), self.acc_dtype + ) + rO = cute.make_rmem_tensor(cute.make_layout(elements_per_thread), self.o_dtype) + rAccO.fill(0.0) + for i in range(local_split_kv): + for j in cutlass.range_constexpr(elements_per_thread): + element_idx = tidx + j * self.threads_per_warp * self.num_compute_warps + rAccO[j] += gAccO[i, element_idx] * smem_lse_scale[i] + rO.store(rAccO.load().to(self.o_dtype)) + for j in cutlass.range_constexpr(elements_per_thread): + element_idx = tidx + j * self.threads_per_warp * self.num_compute_warps + mO[blk_coord[0], element_idx, blk_coord[1], blk_coord[2]] = rO[j] + return + + @staticmethod + def get_split_kv( + B: int, S: int, K: int, mma_qk_tiler_mn: tuple, max_active_blocks: int + ) -> int: + """Get the proper split_kv value for the MLA kernel based on parameters. + + :param B: Batch size + :type B: int + :param S: Sequence length + :type S: int + :param K: Sequence length + :type K: int + :param mma_qk_tiler_mn: MLA tiling parameters + :type mma_qk_tiler_mn: tuple + :param max_active_blocks: Maximum number of active blocks + :type max_active_blocks: int + :return: Split_kv value + :rtype: int + """ + max_splits = ceil_div(K, mma_qk_tiler_mn[1]) + blocks_per_batch = max(1, max_active_blocks // B // (S * 2)) + split_heur = min(max_splits, blocks_per_batch) + # {$nv-internal-release begin} + # TODO: figure out the error of make_tile with dynamic int_tuple + # {$nv-internal-release end} + k_waves = ceil_div(max_splits, split_heur) + split_wave_aware = ceil_div(max_splits, k_waves) + max_split_kv = 32 + return min(split_wave_aware, max_split_kv) + + @cute.jit + def get_k_tile_count( + self, + split_kv: cutlass.Int32, + cache_seqs: cute.Tensor, + block_split_kvs: cute.Tensor, + blk_coord: cute.Coord, + ) -> tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32]: + """Get the current k_index, k_tile_count, and local split_kv value for the MLA kernel. + + :param split_kv: Split_kv value + :type split_kv: cutlass.Int32 + :param cache_seqs: Cache sequence lengths tensor + :type cache_seqs: cute.Tensor + :param block_split_kvs: Per-block split_kv values tensor + :type block_split_kvs: cute.Tensor + :param blk_coord: Block coordinate + :type blk_coord: cute.Coord + :return: k_index, k_tile_count, split_kv + :rtype: tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32] + """ + K = cache_seqs[blk_coord[2]] + if cutlass.const_expr(self.is_var_split_kv): + split_kv = block_split_kvs[blk_coord[2]] + + k_tile_total = cute.ceil_div(K, self.mma_qk_tiler[1]) + # {$nv-internal-release begin} + # TODO: figure out the error of make_tile with dynamic int_tuple + # {$nv-internal-release end} + k_tile_per_cta = cute.ceil_div(k_tile_total, split_kv) + k_index = blk_coord[3] * k_tile_per_cta + k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index) + return k_index, k_tile_count, split_kv + + @cute.jit + def load_page_table( + self, + common_params: SimpleNamespace, + k_index: cutlass.Int32, + k_tile_count: cutlass.Int32, + load_pt_producer_state: pipeline.PipelineState, + ) -> pipeline.PipelineState: + """Load warp to load page table. Updates the load pt producer state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param k_index: The k index + :type k_index: cutlass.Int32 + :param k_tile_count: The k tile count + :type k_tile_count: cutlass.Int32 + :param load_pt_producer_state: The load pt producer state + :type load_pt_producer_state: pipeline.PipelineState + + :return: The load pt producer state + :rtype: pipeline.PipelineState + """ + mPT = common_params.mPT[None, common_params.blk_coord[2]] + page_per_tile = self.mma_qk_tiler[1] // self.page_size + tidx = common_params.tidx % self.threads_per_warp + + load_pt_pipeline = common_params.load_pt_pipeline + while k_tile_count > 0: + load_pt_pipeline.producer_acquire(load_pt_producer_state) + + elem_per_thread = cute.ceil_div(page_per_tile, self.threads_per_warp) + + # atom_async_copy: async copy atom for page table load + atom_async_copy = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.ALWAYS), + cutlass.Int32, + num_bits_per_copy=cutlass.Int32.width, + ) + mPT_for_copy = cute.flat_divide(mPT, (1,)) + sPT_for_copy = cute.flat_divide(common_params.sPT, (1,)) + # elem_per_thread is a dynamic value depends on the page_size setting. + for i in range(elem_per_thread): + idx = i * self.threads_per_warp + tidx + if cute.elem_less( + k_index * page_per_tile + idx, mPT.shape[0] + ) and cute.elem_less(idx, page_per_tile): + cute.copy( + atom_async_copy, + mPT_for_copy[None, k_index * page_per_tile + idx], + sPT_for_copy[None, idx, load_pt_producer_state.index], + ) + else: + sPT_for_copy[None, idx, load_pt_producer_state.index].fill(0) + mbar_ptr = load_pt_pipeline.producer_get_barrier(load_pt_producer_state) + load_pt_pipeline.producer_commit(load_pt_producer_state) + load_pt_producer_state.advance() + k_index += 1 + k_tile_count -= 1 + + return load_pt_producer_state + + @cute.jit + def load_tma( + self, + common_params: SimpleNamespace, + qk_params: SimpleNamespace, + v_params: SimpleNamespace, + k_index: cutlass.Int32, + k_tile_count: cutlass.Int32, + load_q_producer_state: pipeline.PipelineState, + load_kv_producer_state: pipeline.PipelineState, + load_pt_consumer_state: pipeline.PipelineState, + load_pt_release_state: pipeline.PipelineState, + ) -> tuple[ + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + ]: + """Load wrap to load Q/C latent/rope tensors. Updates the load qkv producer state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param qk_params: The qk parameters + :type qk_params: SimpleNamespace + :param v_params: The v parameters + :type v_params: SimpleNamespace + :param k_index: The k index + :type k_index: cutlass.Int32 + :param k_tile_count: The k tile count + :type k_tile_count: cutlass.Int32 + :param load_q_producer_state: The load q producer state + :type load_q_producer_state: pipeline.PipelineState + :param load_kv_producer_state: The load kv producer state + :type load_kv_producer_state: pipeline.PipelineState + :param load_pt_consumer_state: The load pt consumer state + :type load_pt_consumer_state: pipeline.PipelineState + :param load_pt_release_state: The load pt release state + :type load_pt_release_state: pipeline.PipelineState + + :return: The load q producer state, load kv producer state, load pt consumer state, and load pt release state + :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] + """ + # page table + mPT = common_params.mPT[None, common_params.blk_coord[2]] + + # Flatten divide and partition global tensors for QK TMA load + # (bM, bK, rM, rK, rL) + mma_qk_tiler_mk = cute.select(self.mma_qk_tiler, mode=[0, 2]) + gQL = cute.flat_divide(qk_params.mQL, mma_qk_tiler_mk) + mma_qk_tiler_mk_rope = cute.select(self.mma_qk_rope_tiler, mode=[0, 2]) + gQR = cute.flat_divide(qk_params.mQR, mma_qk_tiler_mk_rope) + + thr_mma_qk = qk_params.tiled_mma_qk.get_slice( + common_params.blk_coord[0] % cute.size(qk_params.tiled_mma_qk.thr_id) + ) + tSgQL = thr_mma_qk.partition_A(gQL) + tSgQR = thr_mma_qk.partition_A(gQR) + + cta_m = min( + qk_params.tiled_mma_qk.op.shape_mnk[0] + // qk_params.tiled_mma_qk.thr_id.shape, + self.page_size, + ) + page_tile_size = min(self.page_size, cta_m) + gCL = cute.tiled_divide(qk_params.mCL, (page_tile_size, self.mma_qk_tiler[2])) + tSgCL = ( + gCL[ + None, + common_params.blk_coord[0] % qk_params.tiled_mma_qk.thr_id.shape, + None, + None, + ] + if cta_m < self.page_size + else gCL[None, 0, None, None] + ) + gKR = cute.tiled_divide(qk_params.mKR, (page_tile_size, self.mma_qk_tiler[2])) + tSgKR = ( + gKR[ + None, + common_params.blk_coord[0] % qk_params.tiled_mma_qk.thr_id.shape, + None, + None, + ] + if cta_m < self.page_size + else gKR[None, 0, None, None] + ) + + # tma partition for q, k latent/rope + # smem: ((atom_v, rest_v), STAGE) + # gmem: ((atom_v, rest_v), RestM, RestK, RestL) + tQsQ, tQLgQL_mkl = cpasync.tma_partition( + qk_params.tma_atom_q_latent, + 0, + cute.make_layout(1), + cute.group_modes(qk_params.sQ, 0, 3), + cute.group_modes(tSgQL, 0, 3), + ) + + tQsQ_rope, tQRgQR_mkl = cpasync.tma_partition( + qk_params.tma_atom_q_rope, + 0, + cute.make_layout(1), + cute.group_modes(qk_params.sQ_rope, 0, 3), + cute.group_modes(tSgQR, 0, 3), + ) + + tKCsKC, tCLgCL = cpasync.tma_partition( + qk_params.tma_atom_c_latent, + 0, + cute.make_layout(1), + qk_params.sKC, + tSgCL, + ) + + _, tKRgKR = cpasync.tma_partition( + qk_params.tma_atom_c_rope, + 0, + cute.make_layout(1), + qk_params.sKC, + tSgKR, + ) + + tQLgQL = tQLgQL_mkl[ + None, None, None, common_params.blk_coord[1], common_params.blk_coord[2] + ] + tQRgQR = tQRgQR_mkl[ + None, None, None, common_params.blk_coord[1], common_params.blk_coord[2] + ] + + # Flatten divide and partition global tensors for V TMA load + page_tile_size = min(self.page_size, self.mma_pv_tiler[2]) + gCLT = cute.flat_divide(v_params.mCLT, (self.mma_pv_tiler[1], page_tile_size)) + cta_n = self.mma_pv_tiler[1] // v_params.tiled_mma_pv.thr_id.shape + gCLT = cute.logical_divide(gCLT, (cta_n,))[ + (None, common_params.blk_coord[0]), None, None, None, None + ] + tOgCLT = cute.tiled_divide(gCLT, (cta_n, page_tile_size)) + tOgCLT = tOgCLT[None, 0, 0, None, None, None] + + # tma partition for vc + # smem: ((atom_v, rest_v), STAGE) + # gmem: ((atom_v, rest_v), RestM, RestK, RestL) + tVCsVC, tCLTgCLT = cpasync.tma_partition( + v_params.tma_atom_c_latent_transpose, + 0, + cute.make_layout(1), + v_params.sVC, + tOgCLT, + ) + + # set extra params + common_params.mPT = mPT + qk_params.tQLgQL = tQLgQL + qk_params.tQRgQR = tQRgQR + qk_params.tCLgCL = tCLgCL + qk_params.tKRgKR = tKRgKR + qk_params.tQsQ = tQsQ + qk_params.tQsQ_rope = tQsQ_rope + qk_params.tKCsKC = tKCsKC + v_params.tCLTgCLT = tCLTgCLT + v_params.tVCsVC = tVCsVC + + load_q_producer_state, load_kv_producer_state, load_pt_consumer_state = ( + self.load_tma_qk_one_k_tile( + common_params, + qk_params, + k_index, + k_tile_count, + load_q_producer_state, + load_kv_producer_state, + load_pt_consumer_state, + load_q=True, + ) + ) + k_index += 1 + k_tile_count -= 1 + while k_tile_count > 0: + # {$nv-internal-release begin} + # TODO: figure out how to support SingleNamespace/struct in ast + # {$nv-internal-release end} + load_q_producer_state, load_kv_producer_state, load_pt_consumer_state = ( + self.load_tma_qk_one_k_tile( + common_params, + qk_params, + k_index, + k_tile_count, + load_q_producer_state, + load_kv_producer_state, + load_pt_consumer_state, + load_q=False, + ) + ) + load_kv_producer_state, load_pt_release_state = self.load_tma_v_one_k_tile( + common_params, + v_params, + k_index - 1, + load_kv_producer_state, + load_pt_release_state, + ) + k_index += 1 + k_tile_count -= 1 + + # load last v tile + load_kv_producer_state, load_pt_release_state = self.load_tma_v_one_k_tile( + common_params, + v_params, + k_index - 1, + load_kv_producer_state, + load_pt_release_state, + ) + return ( + load_q_producer_state, + load_kv_producer_state, + load_pt_consumer_state, + load_pt_release_state, + ) + + @cute.jit + def load_tma_qk_one_k_tile( + self, + common_params: SimpleNamespace, + qk_params: SimpleNamespace, + k_index: cutlass.Int32, + k_tile_count: cutlass.Int32, + load_q_producer_state: pipeline.PipelineState, + load_kv_producer_state: pipeline.PipelineState, + load_pt_consumer_state: pipeline.PipelineState, + load_q: bool, + ) -> tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState]: + """Load one k-tile of Q/C latent/rope tensors. Updates the load qkv producer state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param qk_params: The qk parameters + :type qk_params: SimpleNamespace + :param k_index: The k index + :type k_index: cutlass.Int32 + :param k_tile_count: The k tile count + :type k_tile_count: cutlass.Int32 + :param load_q_producer_state: The load q producer state + :type load_q_producer_state: pipeline.PipelineState + :param load_kv_producer_state: The load kv producer state + :type load_kv_producer_state: pipeline.PipelineState + :param load_pt_consumer_state: The load pt consumer state + :type load_pt_consumer_state: pipeline.PipelineState + :param load_q: Whether to load q + :type load_q: bool + + :return: The load q producer state, load kv producer state, and load pt consumer state + :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] + """ + page_per_tile = ceil_div( + self.mma_qk_tiler[1] // self.page_size, qk_params.tiled_mma_qk.thr_id.shape + ) + common_params.load_pt_pipeline.consumer_wait(load_pt_consumer_state) + page_table_stage = load_pt_consumer_state.index + load_pt_consumer_state.advance() + k_idx = cute.make_rmem_tensor(cute.make_layout(page_per_tile), cutlass.Int32) + for i in cutlass.range_constexpr(page_per_tile): + k_idx[i] = ( + common_params.sPT[0, page_table_stage] + if self.mma_qk_tiler[1] // self.page_size == 1 + else common_params.sPT[ + i + common_params.blk_coord[0] * page_per_tile, page_table_stage + ] + ) + # load q once at first iteration + if cutlass.const_expr(load_q): + common_params.load_q_pipeline.producer_acquire(load_q_producer_state) + # get the mbar ptr from pipeline. + tma_bar_ptr = common_params.load_q_pipeline.producer_get_barrier( + load_q_producer_state + ) + for i in cutlass.range(self.iterations_qk_latent): + # load q latent + cute.copy( + qk_params.tma_atom_q_latent, + qk_params.tQLgQL[None, 0, i], + qk_params.tQsQ[None, (i, 0)], + tma_bar_ptr=tma_bar_ptr, + ) + for i in cutlass.range(self.iterations_qk_rope): + # load q rope + cute.copy( + qk_params.tma_atom_q_rope, + qk_params.tQRgQR[None, 0, i], + qk_params.tQsQ_rope[None, i], + tma_bar_ptr=tma_bar_ptr, + ) + load_q_producer_state.advance() + load_kv_pipeline = common_params.load_kv_pipeline + tma_bar_ptr = load_kv_pipeline.producer_get_barrier(load_kv_producer_state) + for i in cutlass.range(self.iterations_qk_latent): + # get the mbar ptr from pipeline. + tma_bar_ptr = load_kv_pipeline.producer_get_barrier(load_kv_producer_state) + load_kv_pipeline.producer_acquire(load_kv_producer_state) + for k in cutlass.range(page_per_tile): + # load k latent + cute.copy( + qk_params.tma_atom_c_latent, + qk_params.tCLgCL[None, i, k_idx[k]], + qk_params.tKCsKC[None, k, 0, load_kv_producer_state.index], + tma_bar_ptr=tma_bar_ptr, + ) + load_kv_producer_state.advance() + + for i in cutlass.range(self.iterations_qk_rope): + # get the mbar ptr from pipeline. + tma_bar_ptr = load_kv_pipeline.producer_get_barrier(load_kv_producer_state) + load_kv_pipeline.producer_acquire(load_kv_producer_state) + for k in cutlass.range(page_per_tile): + # load k rope + cute.copy( + qk_params.tma_atom_c_rope, + qk_params.tKRgKR[None, i, k_idx[k]], + qk_params.tKCsKC[None, k, 0, load_kv_producer_state.index], + tma_bar_ptr=tma_bar_ptr, + ) + load_kv_producer_state.advance() + + return load_q_producer_state, load_kv_producer_state, load_pt_consumer_state + + @cute.jit + def load_tma_v_one_k_tile( + self, + common_params: SimpleNamespace, + v_params: SimpleNamespace, + k_index: cutlass.Int32, + load_kv_producer_state: pipeline.PipelineState, + load_pt_release_state: pipeline.PipelineState, + ) -> tuple[pipeline.PipelineState, pipeline.PipelineState]: + """Load one k-tile of compressed latent transpose tensor(v). Updates the load qkv producer state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param v_params: The load tma v parameters + :type v_params: SimpleNamespace + :param k_index: The k index + :type k_index: cutlass.Int32 + :param load_kv_producer_state: The load qkv producer state + :type load_kv_producer_state: pipeline.PipelineState + :param load_pt_release_state: The load pt release state + :type load_pt_release_state: pipeline.PipelineState + + :return: The load kv producer state and load pt release state + :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState] + """ + page_per_tile = self.mma_pv_tiler[2] * self.iterations_pv_k // self.page_size + page_per_subtile = ceil_div(page_per_tile, self.iterations_pv_k) + k_idx = cute.make_rmem_tensor(cute.make_layout(page_per_tile), cutlass.Int32) + page_table_stage = load_pt_release_state.index + for i in cutlass.range(page_per_tile): + k_idx[i] = ( + common_params.sPT[0, page_table_stage] + if page_per_tile == 1 + else common_params.sPT[i, page_table_stage] + ) + common_params.load_pt_pipeline.consumer_release(load_pt_release_state) + load_pt_release_state.advance() + load_kv_pipeline = common_params.load_kv_pipeline + tma_bar_ptr = load_kv_pipeline.producer_get_barrier(load_kv_producer_state) + for i in cutlass.range(self.iterations_pv_k): + for j in cutlass.range(self.iterations_pv_n): + # get the mbar ptr from pipeline. + tma_bar_ptr = load_kv_pipeline.producer_get_barrier( + load_kv_producer_state + ) + load_kv_pipeline.producer_acquire(load_kv_producer_state) + for k in cutlass.range(page_per_subtile): + k_idx_i = k_idx[ + k + + i + // ceil_div(self.iterations_pv_k, page_per_tile) + * page_per_subtile + ] + cute.copy( + v_params.tma_atom_c_latent_transpose, + v_params.tCLTgCLT[ + None, + j, + i % ceil_div(self.iterations_pv_k, page_per_tile), + k_idx_i, + ], + v_params.tVCsVC[None, 0, k, load_kv_producer_state.index], + tma_bar_ptr=tma_bar_ptr, + ) + + load_kv_producer_state.advance() + return load_kv_producer_state, load_pt_release_state + + @cute.jit + def mma( + self, + common_params: SimpleNamespace, + qk_params: SimpleNamespace, + pv_params: SimpleNamespace, + k_tile_count: cutlass.Int32, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + load_q_consumer_state: pipeline.PipelineState, + load_kv_consumer_state: pipeline.PipelineState, + mma_s_producer_state: pipeline.PipelineState, + p_mma_consumer_state: pipeline.PipelineState, + mma_o_producer_state: pipeline.PipelineState, + ) -> tuple[ + cute.TiledMma, + cute.TiledMma, + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + ]: + """MMA warp to compute the result of Q*K^T and P*V. Updates the tiled mma and pipeline states. + + :param common_params: The common parameters for mma qk and pv + :type common_params: SimpleNamespace + :param qk_params: The mma qk parameters + :type qk_params: SimpleNamespace + :param pv_params: The mma pv parameters + :type pv_params: SimpleNamespace + :param k_tile_count: The k tile count + :type k_tile_count: cutlass.Int32 + :param tiled_mma_qk: The tiled mma qk + :type tiled_mma_qk: cute.TiledMma + :param tiled_mma_pv: The tiled mma pv + :type tiled_mma_pv: cute.TiledMma + :param load_q_consumer_state: The load q consumer state + :type load_q_consumer_state: pipeline.PipelineState + :param load_kv_consumer_state: The load kv consumer state + :type load_kv_consumer_state: pipeline.PipelineState + :param mma_s_producer_state: The mma s producer state + :type mma_s_producer_state: pipeline.PipelineState + :param p_mma_consumer_state: The p mma consumer state + :type p_mma_consumer_state: pipeline.PipelineState + :param mma_o_producer_state: The mma o producer state + :type mma_o_producer_state: pipeline.PipelineState + + :return: The tiled mma qk, the tiled mma pv, the load q consumer state, the load kv consumer state, the mma s producer state, the p mma consumer state, and the mma o producer state + :rtype: tuple[cute.TiledMma, cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] + """ + + tSrQ = tiled_mma_qk.make_fragment_A(qk_params.sQ) + tSrQ_rope = tiled_mma_qk.make_fragment_A(qk_params.sQ_rope) + tSrKC = tiled_mma_qk.make_fragment_B(qk_params.sKC) + tOrP = tiled_mma_pv.make_fragment_A(pv_params.sP) + tOrVC = tiled_mma_pv.make_fragment_B(pv_params.sVC) + + tStS_shape = tiled_mma_qk.partition_shape_C( + cute.select(self.mma_qk_tiler, mode=[0, 1]) + ) + tStS_staged_fake = tiled_mma_qk.make_fragment_C( + cute.append(tStS_shape, self.mma_s_stage) + ) + # use real tmem ptr for tStS + tStS_staged = cute.make_tensor(common_params.tmem_ptr, tStS_staged_fake.layout) + tOtO_shape = tiled_mma_pv.partition_shape_C( + cute.select(self.mma_pv_tiler, mode=[0, 1]) + ) + # mma O has 1 stage. + tOtO = tiled_mma_pv.make_fragment_C(tOtO_shape) + tOtO_layout = cute.append( + tOtO.layout, + cute.make_layout( + common_params.L // self.mma_pv_tiler[1], + stride=self.mma_pv_tiler[1] // self.warps_in_n, + ), + ) + tOtO_staged = cute.make_tensor( + tStS_staged.iterator + self.tmem_o_offset, tOtO_layout + ) + + # set more parameters + qk_params.tSrQ = tSrQ + qk_params.tSrQ_rope = tSrQ_rope + qk_params.tSrKC = tSrKC + qk_params.tStS_staged = tStS_staged + pv_params.tOrP = tOrP + pv_params.tOrVC = tOrVC + pv_params.tOtO_staged = tOtO_staged + + # mma O accumulates on K, so the accumlate flag is set to False once before all K blocks. + tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, False) + load_q_pipeline = common_params.load_q_pipeline + if common_params.is_leader_cta: + load_q_release_state = load_q_consumer_state.clone() + + ( + tiled_mma_qk, + load_q_consumer_state, + load_kv_consumer_state, + mma_s_producer_state, + ) = self.mma_qk( + common_params, + qk_params, + tiled_mma_qk, + load_q_consumer_state, + load_kv_consumer_state, + mma_s_producer_state, + wait_q=True, + ) + k_tile_count -= 1 + while k_tile_count > 0: + ( + tiled_mma_qk, + load_q_consumer_state, + load_kv_consumer_state, + mma_s_producer_state, + ) = self.mma_qk( + common_params, + qk_params, + tiled_mma_qk, + load_q_consumer_state, + load_kv_consumer_state, + mma_s_producer_state, + wait_q=False, + ) + ( + tiled_mma_pv, + load_kv_consumer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) = self.mma_pv( + common_params, + pv_params, + tiled_mma_pv, + load_kv_consumer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) + k_tile_count -= 1 + + # release q consumer states + load_q_pipeline.consumer_release(load_q_release_state) + load_q_release_state.advance() + ( + tiled_mma_pv, + load_kv_consumer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) = self.mma_pv( + common_params, + pv_params, + tiled_mma_pv, + load_kv_consumer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) + + return ( + tiled_mma_qk, + tiled_mma_pv, + load_q_consumer_state, + load_kv_consumer_state, + mma_s_producer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) + + @cute.jit + def mma_qk( + self, + common_params: SimpleNamespace, + qk_params: SimpleNamespace, + tiled_mma_qk: cute.TiledMma, + load_q_consumer_state: pipeline.PipelineState, + load_kv_consumer_state: pipeline.PipelineState, + mma_s_producer_state: pipeline.PipelineState, + wait_q: bool, + ) -> tuple[ + cute.TiledMma, + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + ]: + """Compute one k-tile of mma for Q*K^T. Updates the tiled MMA QK and pipeline states. + + :param qk_params: The qk parameters + :type qk_params: SimpleNamespace + :param tiled_mma_qk: The tiled mma qk + :type tiled_mma_qk: cute.TiledMma + :param load_q_consumer_state: The load q consumer state + :type load_q_consumer_state: pipeline.PipelineState + :param load_kv_consumer_state: The load kv consumer state + :type load_kv_consumer_state: pipeline.PipelineState + :param mma_s_producer_state: The mma s producer state + :type mma_s_producer_state: pipeline.PipelineState + + :return: The tiled mma qk, the load q consumer state, the load kv consumer state, and the mma s producer state + :rtype: tuple[cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] + """ + tStS = qk_params.tStS_staged[None, None, None, mma_s_producer_state.index] + + qk_params.mma_s_pipeline.producer_acquire(mma_s_producer_state) + tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, False) + load_q_pipeline = common_params.load_q_pipeline + load_kv_pipeline = common_params.load_kv_pipeline + if cutlass.const_expr(wait_q): + load_q_pipeline.consumer_wait(load_q_consumer_state) + load_q_consumer_state.advance() + for q_stage in range(self.iterations_qk_latent): + load_kv_pipeline.consumer_wait(load_kv_consumer_state) + kc_stage = load_kv_consumer_state.index + for k_block in cutlass.range(cute.size(qk_params.tSrQ.shape[2])): + cute.gemm( + tiled_mma_qk, + tStS, + qk_params.tSrQ[None, None, k_block, q_stage], + qk_params.tSrKC[None, None, k_block, kc_stage], + tStS, + ) + tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, True) + load_kv_pipeline.consumer_release(load_kv_consumer_state) + load_kv_consumer_state.advance() + for q_stage in range(self.iterations_qk_rope): + load_kv_pipeline.consumer_wait(load_kv_consumer_state) + kc_stage = load_kv_consumer_state.index + for k_block in cutlass.range(self.rope_dim // tiled_mma_qk.shape_mnk[2]): + cute.gemm( + tiled_mma_qk, + tStS, + qk_params.tSrQ_rope[None, None, k_block, q_stage], + qk_params.tSrKC[None, None, k_block, kc_stage], + tStS, + ) + tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, True) + load_kv_pipeline.consumer_release(load_kv_consumer_state) + load_kv_consumer_state.advance() + + qk_params.mma_s_pipeline.producer_commit(mma_s_producer_state) + mma_s_producer_state.advance() + return ( + tiled_mma_qk, + load_q_consumer_state, + load_kv_consumer_state, + mma_s_producer_state, + ) + + @cute.jit + def mma_pv( + self, + common_params: SimpleNamespace, + pv_params: SimpleNamespace, + tiled_mma_pv: cute.TiledMma, + load_kv_consumer_state: pipeline.PipelineState, + p_mma_consumer_state: pipeline.PipelineState, + mma_o_producer_state: pipeline.PipelineState, + ) -> tuple[ + cute.TiledMma, + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + ]: + """Compute one k-tile of mma for P*V. Updates the tiled mma pv and pipeline states. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param pv_params: The pv parameters + :type pv_params: SimpleNamespace + :param tiled_mma_pv: The tiled mma pv + :type tiled_mma_pv: cute.TiledMma + :param load_kv_consumer_state: The load kv consumer state + :type load_kv_consumer_state: pipeline.PipelineState + :param p_mma_consumer_state: The P MMA consumer state + :type p_mma_consumer_state: pipeline.PipelineState + :param mma_o_producer_state: The MMA o producer state + :type mma_o_producer_state: pipeline.PipelineState + + :return: The tiled mma pv, the load qkv consumer state, the P MMA consumer state, and the MMA o producer state + :rtype: tuple[cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] + """ + + pv_params.mma_o_pipeline.producer_acquire(mma_o_producer_state) + pv_params.p_mma_pipeline.consumer_wait(p_mma_consumer_state) + load_kv_pipeline = common_params.load_kv_pipeline + for p_stage in range(self.iterations_pv_k): + accumulate_flag = tiled_mma_pv.get(tcgen05.Field.ACCUMULATE) + for acc_stage in range(self.iterations_pv_n): + load_kv_pipeline.consumer_wait(load_kv_consumer_state) + tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, accumulate_flag) + vc_stage = load_kv_consumer_state.index + tOtO = pv_params.tOtO_staged[None, None, None, acc_stage] + for k_block in cutlass.range(pv_params.tOrP.shape[2]): + cute.gemm( + tiled_mma_pv, + tOtO, + pv_params.tOrP[ + None, + None, + k_block, + (p_stage, p_mma_consumer_state.index), + ], + pv_params.tOrVC[None, None, k_block, vc_stage], + tOtO, + ) + tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, True) + load_kv_pipeline.consumer_release(load_kv_consumer_state) + load_kv_consumer_state.advance() + pv_params.p_mma_pipeline.consumer_release(p_mma_consumer_state) + p_mma_consumer_state.advance() + pv_params.mma_o_pipeline.producer_commit(mma_o_producer_state) + mma_o_producer_state.advance() + + return ( + tiled_mma_pv, + load_kv_consumer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) + + @cute.jit + def compute( + self, + common_params: SimpleNamespace, + softmax_params: SimpleNamespace, + k_index: cutlass.Int32, + k_tile_count: cutlass.Int32, + mma_s_consumer_state: pipeline.PipelineState, + p_mma_producer_state: pipeline.PipelineState, + p_cor_producer_state: pipeline.PipelineState, + ) -> tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState]: + """Compute warp to compute the result of softmax, rescale, and epilogue. Updates the related pipeline states. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param softmax_params: The softmax parameters + :type softmax_params: SimpleNamespace + :param k_index: The index of the k-tile + :type k_index: cutlass.Int32 + :param k_tile_count: The number of k-tiles + :type k_tile_count: cutlass.Int32 + :param mma_s_consumer_state: The MMA s consumer state + :type mma_s_consumer_state: pipeline.PipelineState + :param p_mma_producer_state: The P MMA producer state + :type p_mma_producer_state: pipeline.PipelineState + :param p_cor_producer_state: The P correction producer state + :type p_cor_producer_state: pipeline.PipelineState + + :return: The MMA s consumer state, the P MMA producer state, and the P correction producer state + :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] + """ + + k_tile_total = cute.ceil_div(common_params.K, self.mma_qk_tiler[1]) + + row_max = -self.acc_dtype.inf + row_sum = self.acc_dtype(0) + correction_factor = self.acc_dtype(1) + common_params.p_cor_pipeline.producer_acquire(p_cor_producer_state) + + # no mask applied + while k_tile_count > 1: + ( + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max, + row_sum, + correction_factor, + ) = self.softmax( + common_params, + softmax_params, + k_index, + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max, + row_sum, + correction_factor, + False, + False, + ) + k_index = k_index + 1 + k_tile_count = k_tile_count - 1 + + # mask applied + if cutlass.const_expr(common_params.mAccO is not None): + ( + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max, + row_sum, + correction_factor, + ) = self.softmax( + common_params, + softmax_params, + k_index, + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max, + row_sum, + correction_factor, + k_index == k_tile_total - 1, + True, + ) + else: + ( + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max, + row_sum, + correction_factor, + ) = self.softmax( + common_params, + softmax_params, + k_index, + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max, + row_sum, + correction_factor, + True, + True, + ) + + return mma_s_consumer_state, p_mma_producer_state, p_cor_producer_state + + @cute.jit + def correction( + self, + common_params: SimpleNamespace, + epilogue_params: SimpleNamespace, + k_tile_count: cutlass.Int32, + p_cor_consumer_state: pipeline.PipelineState, + mma_o_consumer_state: pipeline.PipelineState, + ) -> tuple[pipeline.PipelineState, pipeline.PipelineState]: + """Compute warp to compute the result of softmax, rescale, and epilogue. Updates the related pipeline states. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param epilogue_params: The epilogue parameters + :type epilogue_params: SimpleNamespace + :param k_index: The index of the k-tile + :type k_index: cutlass.Int32 + :param k_tile_count: The number of k-tiles + :type k_tile_count: cutlass.Int32 + :param p_cor_consumer_state: The P correction consumer state + :type p_cor_consumer_state: pipeline.PipelineState + :param mma_o_consumer_state: The MMA o consumer state + :type mma_o_consumer_state: pipeline.PipelineState + + :return: The P correction consumer state, and the MMA o consumer state + :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState] + """ + + k_tile_count_init = k_tile_count + while k_tile_count > 0: + p_cor_consumer_state, row_sum, row_max, correction_factor, no_correction = ( + self.get_correction_factor(common_params, p_cor_consumer_state) + ) + if k_tile_count_init != k_tile_count: + mma_o_consumer_state = self.rescale( + common_params, + mma_o_consumer_state, + correction_factor, + no_correction, + ) + k_tile_count = k_tile_count - 1 + if k_tile_count == 0: + mma_o_consumer_state = self.epilogue( + common_params, + epilogue_params, + mma_o_consumer_state, + row_sum, + row_max, + ) + + return p_cor_consumer_state, mma_o_consumer_state + + @cute.jit + def exchange_p_cor_metadata( + self, + common_params: SimpleNamespace, + softmax_params: SimpleNamespace, + correction_factor: cutlass.Float32, + row_sum: cutlass.Float32, + row_max: cutlass.Float32, + row_max_new: cutlass.Float32, + tAcc: cute.Tensor, + tidx: cutlass.Int32, + p_cor_producer_state: pipeline.PipelineState, + ) -> pipeline.PipelineState: + """Compute the correction factor for the last k tile.""" + no_correction = 0 + if ( + row_max_new - row_max + ) * softmax_params.softmax_scale_log2 <= self.skip_correction_threshold: + no_correction = 1 + row_max_new = row_max + + # pad for 4x32b + corr_layout = cute.make_layout( + (tAcc.shape[0], (4, tAcc.shape[1][1]), self.mma_s_stage), + stride=(tAcc.stride[0], (1, tAcc.stride[1][1]), 4), + ) + tCor = cute.make_tensor( + common_params.tmem_ptr + self.correction_factor_offset, + corr_layout, + ) + cCor = cute.make_identity_tensor(tCor.shape) + corr_tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(4)), self.acc_dtype + ) + corr_tmem_store_tiled_copy = tcgen05.make_tmem_copy(corr_tmem_store_atom, tCor) + corr_tmem_store_thr_copy = corr_tmem_store_tiled_copy.get_slice(tidx) + cCor_for_copy = corr_tmem_store_thr_copy.partition_S(cCor) + tCor_for_copy = corr_tmem_store_thr_copy.partition_D(tCor) + rCor = cute.make_fragment_like( + cCor_for_copy[None, None, None, 0], self.acc_dtype + ) + rCor_int = cute.make_tensor( + cute.recast_ptr(rCor.iterator, dtype=cutlass.Int32), rCor.layout + ) + rCor[0] = row_sum + rCor[1] = row_max_new + rCor[2] = correction_factor + rCor_int[3] = no_correction + + cute.copy( + corr_tmem_store_tiled_copy, + rCor, + tCor_for_copy[None, None, None, p_cor_producer_state.index], + ) + # fence between tmem store and correction warp + cute.arch.fence_view_async_tmem_store() + common_params.p_cor_pipeline.producer_commit(p_cor_producer_state) + p_cor_producer_state.advance() + return p_cor_producer_state, row_max_new + + @cute.jit + def softmax( + self, + common_params: SimpleNamespace, + softmax_params: SimpleNamespace, + k_index: cutlass.Int32, + mma_s_consumer_state: pipeline.PipelineState, + p_mma_producer_state: pipeline.PipelineState, + p_cor_producer_state: pipeline.PipelineState, + row_max: cutlass.Float32, + row_sum: cutlass.Float32, + correction_factor: cutlass.Float32, + is_last_tile: bool, + is_local_last_tile: cutlass.Boolean, + ) -> tuple[ + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + cutlass.Float32, + cutlass.Float32, + cutlass.Float32, + ]: + """Softmax for one k-tile. Updates the related pipeline states and returns the computed results. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param softmax_params: The softmax parameters + :type softmax_params: SimpleNamespace + :param k_index: The index of the k-tile + :type k_index: cutlass.Int32 + :param mma_s_consumer_state: The MMA s consumer state + :type mma_s_consumer_state: pipeline.PipelineState + :param p_mma_producer_state: The P MMA producer state + :type p_mma_producer_state: pipeline.PipelineState + :param p_cor_producer_state: The P correction producer state + :type p_cor_producer_state: pipeline.PipelineState + :param row_max: The row max + :type row_max: cutlass.Float32 + :param row_sum: The row sum + :type row_sum: cutlass.Float32 + :param correction_factor: The correction factor + :type correction_factor: cutlass.Float32 + :param is_last_tile: Whether the last tile + :type is_last_tile: bool + :param is_local_last_tile: Whether the last tile is local + :type is_local_last_tile: cutlass.Boolean + + :return: The MMA s consumer state, the P MMA producer state, the P correction producer state, the row max, the row sum, and the correction factor + :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, cutlass.Float32, cutlass.Float32, cutlass.Float32] + """ + + softmax_params.p_mma_pipeline.producer_acquire(p_mma_producer_state) + softmax_params.mma_s_pipeline.consumer_wait(mma_s_consumer_state) + + # load S from tmem + tStS_shape = softmax_params.tiled_mma_qk.partition_shape_C( + cute.select(self.mma_qk_tiler, mode=[0, 1]) + ) + tStS_staged_fake = softmax_params.tiled_mma_qk.make_fragment_C( + cute.append(tStS_shape, self.mma_s_stage) + ) + tStS_staged = cute.make_tensor(common_params.tmem_ptr, tStS_staged_fake.layout) + tStS = tStS_staged[None, None, None, mma_s_consumer_state.index] + + tAcc = tStS[(None, None), 0, 0] + cta_qk_tiler = ( + self.mma_qk_tiler[0] // self.cluster_shape_mnk[0], + self.mma_qk_tiler[1], + self.mma_qk_tiler[2], + ) + cS = cute.make_identity_tensor(cute.select(cta_qk_tiler, mode=[0, 1])) + + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype + ) + tmem_tiled_copy = tcgen05.make_tmem_copy(tmem_load_atom, tAcc) + + tidx = common_params.tidx % (self.num_compute_warps * self.threads_per_warp) + + tmem_thr_copy = tmem_tiled_copy.get_slice(tidx) + tTR_tAcc = tmem_thr_copy.partition_S(tAcc) + tTR_tS = tmem_thr_copy.partition_D(cS) + + tTR_rAcc = cute.make_fragment_like(tTR_tS, self.acc_dtype) + + row_max_new = row_max + arch = BaseDSL._get_dsl().get_arch_enum() + if cutlass.const_expr(arch >= Arch.sm_100 and arch <= Arch.sm_100f): + cute.copy(tmem_tiled_copy, tTR_tAcc, tTR_rAcc) + for i in cutlass.range_constexpr(cute.size(tTR_rAcc)): + if is_last_tile: + tTR_rAcc[i] = ( + tTR_rAcc[i] + if cute.elem_less( + tTR_tS[i][1] + self.mma_qk_tiler[1] * k_index, + common_params.K, + ) + else -self.acc_dtype.inf + ) + # reduction for row_max + row_max_new = tTR_rAcc.load().reduce(cute.ReductionOp.MAX, row_max_new, 0) + + elif cutlass.const_expr(arch >= Arch.sm_103 and arch <= Arch.sm_103f): + tmem_load_red_atom = cute.make_copy_atom( + tcgen05.copy.LdRed32x32bOp( + tcgen05.copy.Repetition(64), redOp=tcgen05.TmemLoadRedOp.MAX + ), + self.acc_dtype, + ) + tmem_red_tiled_copy = tcgen05.make_tmem_copy(tmem_load_red_atom, tAcc) + tmem_red_thr_copy = tmem_red_tiled_copy.get_slice(tidx) + tTR_tAcc_red = tmem_red_thr_copy.partition_S(tAcc) + tTR_tS_red = tmem_red_thr_copy.partition_D(cS) + tTR_rAcc_red = cute.make_fragment_like(tTR_tS_red, self.acc_dtype) + tTR_rMax = cute.make_rmem_tensor( + cute.make_layout((1, tTR_tS_red.shape[1], tTR_tS_red.shape[2])), + self.acc_dtype, + ) + cute.copy( + tmem_red_tiled_copy, + tTR_tAcc_red, + (tTR_rAcc_red, tTR_rMax), + ) + tTR_rAcc = cute.make_tensor(tTR_rAcc_red.iterator, tTR_rAcc.layout) + if is_last_tile: + for i in cutlass.range_constexpr(cute.size(tTR_rAcc)): + tTR_rAcc[i] = ( + tTR_rAcc[i] + if cute.elem_less( + tTR_tS[i][1] + self.mma_qk_tiler[1] * k_index, + common_params.K, + ) + else -self.acc_dtype.inf + ) + # reduction for row_max + row_max_new = tTR_rAcc.load().reduce( + cute.ReductionOp.MAX, row_max_new, 0 + ) + else: + row_max_new = cute.arch.fmax(row_max_new, tTR_rMax[0]) + + # if warps in N is 2, reduce row_max across warps (0, 1) and (2, 3) + if cutlass.const_expr(self.warps_in_n == 2): + common_params.smem_exchange[tidx] = row_max_new + self.softmax_exchange_sync_bar.wait() + row_max_new = cute.arch.fmax( + row_max_new, + common_params.smem_exchange[ + (tidx + 64) % (self.num_compute_warps * self.threads_per_warp) + ], + ) + + # find correction factor + correction_factor = cute.math.exp2( + (row_max - row_max_new) * softmax_params.softmax_scale_log2, fastmath=True + ) + # split kv case + if cutlass.const_expr(not is_local_last_tile): + p_cor_producer_state, row_max_new = self.exchange_p_cor_metadata( + common_params, + softmax_params, + correction_factor, + row_sum, + row_max, + row_max_new, + tAcc, + tidx, + p_cor_producer_state, + ) + + # softmax + fma_b = softmax_params.softmax_scale_log2 + fma_c = (0.0 - row_max_new) * softmax_params.softmax_scale_log2 + + for i in cutlass.range(cute.size(tTR_rAcc), vectorize=True, unroll_full=True): + tTR_rAcc[i] = tTR_rAcc[i] * fma_b + fma_c + tTR_rAcc[i] = cute.math.exp2(tTR_rAcc[i], fastmath=True) + + tTR_rS = cute.make_fragment_like(tTR_tS, self.q_dtype) + + # quantize + tTR_rS.store(tTR_rAcc.load().to(self.q_dtype)) + + # create sP + sP = softmax_params.sP[None, None, None, (None, p_mma_producer_state.index)] + sP_mk_view = cute.make_tensor( + sP.iterator, + cute.make_layout( + ( + (sP.shape[0][0], sP.shape[1]), + (sP.shape[0][1], sP.shape[2], sP.shape[3]), + ), + stride=( + (sP.stride[0][0], sP.stride[1]), + (sP.stride[0][1], sP.stride[2], sP.stride[3]), + ), + ), + ) + # {$nv-internal-release begin} + # TODO: figure out if we could use A tmem for pv. + # {$nv-internal-release end} + # change to PISL + sP_wo_swizzle_iter = cute.recast_ptr(sP.iterator, swizzle_=None) + swizzle_bits = ( + int(math.log2(self.mma_pv_tiler[2] * self.q_dtype.width // 8 // 32)) + 1 + ) + swizzle_base = 3 if self.q_dtype.width == 16 else 4 + sP_swizzle = cute.make_swizzle(swizzle_bits, swizzle_base, 3) + sP_mk_view = cute.make_tensor( + sP_wo_swizzle_iter, + cute.make_composed_layout(sP_swizzle, 0, sP_mk_view.layout), + ) + universal_copy_bits = 128 + smem_copy_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.q_dtype, + num_bits_per_copy=universal_copy_bits, + ) + smem_tiled_copy = cute.make_tiled_copy_D(smem_copy_atom, tmem_tiled_copy) + smem_thr_copy = smem_tiled_copy.get_slice(tidx) + rP_copy_view = smem_thr_copy.retile(tTR_rS) + sP_copy_view = smem_thr_copy.partition_D(sP_mk_view) + cute.copy(smem_tiled_copy, rP_copy_view, sP_copy_view) + + # fence between smem store and mma o + cute.arch.fence_view_async_shared() + softmax_params.p_mma_pipeline.producer_commit(p_mma_producer_state) + p_mma_producer_state.advance() + + # row_sum, using `add_packed_f32x2` to reduce the number of instructions + row_sum = row_sum * correction_factor + row_sum_vec = (0.0, 0.0) + for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2): + row_sum_vec = cute.arch.add_packed_f32x2( + row_sum_vec, (tTR_rAcc[i], tTR_rAcc[i + 1]) + ) + row_sum = row_sum_vec[0] + row_sum_vec[1] + row_sum + + # split kv case + if cutlass.const_expr(is_local_last_tile): + p_cor_producer_state, row_max_new = self.exchange_p_cor_metadata( + common_params, + softmax_params, + correction_factor, + row_sum, + row_max, + row_max_new, + tAcc, + tidx, + p_cor_producer_state, + ) + + # store correction factor/row_sum/row_max to tmem for correction warp + common_params.p_cor_pipeline.producer_acquire(p_cor_producer_state) + + # fence between tmem load and mma s + cute.arch.fence_view_async_tmem_load() + + softmax_params.mma_s_pipeline.consumer_release(mma_s_consumer_state) + mma_s_consumer_state.advance() + + return ( + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max_new, + row_sum, + correction_factor, + ) + + @cute.jit + def _tmem_load_partition( + self, common_params: SimpleNamespace, tiled_mma_pv: cute.TiledMma, iter_n: int + ) -> tuple[ + cute.TiledMma, cute.TiledMma, cute.TiledMma, cute.TiledMma, cute.TiledMma + ]: + """Tensor memory load partition for rescale and epilogue. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param tiled_mma_pv: The tiled mma pv + :type tiled_mma_pv: cute.TiledMma + :param iter_n: The iteration number + :type iter_n: int + + :return: The tiled mma pv, the tiled mma pv, the tiled mma pv, the tiled mma pv, the tiled mma pv + :rtype: tuple[cute.TiledMma, cute.TiledMma, cute.TiledMma, cute.TiledMma, cute.TiledMma] + """ + + tOtO_shape = tiled_mma_pv.partition_shape_C( + cute.select(self.mma_pv_tiler, mode=[0, 1]) + ) + tOtO = tiled_mma_pv.make_fragment_C(tOtO_shape) + tOtO_layout = cute.append( + tOtO.layout, + cute.make_layout( + common_params.L // self.mma_pv_tiler[1], + stride=self.mma_pv_tiler[1] // self.warps_in_n, + ), + ) + tOtO = cute.make_tensor( + common_params.tmem_ptr + self.tmem_o_offset, tOtO_layout + ) + tOtO = tOtO[None, None, None, iter_n] + + tAcc = tOtO[(None, None), 0, 0] + + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype + ) + tmem_load_tiled_copy = tcgen05.make_tmem_copy(tmem_load_atom, tAcc) + # {$nv-internal-release begin} + # TODO: supports size() on tiled copy. + # {$nv-internal-release end} + tmem_load_thr_copy = tmem_load_tiled_copy.get_slice( + common_params.tidx % (self.num_compute_warps * self.threads_per_warp) + ) + + cta_pv_tiler = ( + self.mma_pv_tiler[0] // self.cluster_shape_mnk[0], + self.mma_pv_tiler[1], + self.mma_pv_tiler[2], + ) + # Flatten divide and partition global tensors for O + cta_pv_tiler_mn = cute.select(cta_pv_tiler, mode=[0, 1]) + + gO = None + if cutlass.const_expr(common_params.mAccO is not None): + gO = cute.local_tile( + common_params.mAccO[None, common_params.blk_coord[3], None, None, None], + cta_pv_tiler_mn, + ( + common_params.blk_coord[0], + iter_n, + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + ) + cO = cute.local_tile( + cute.make_identity_tensor( + common_params.mAccO[ + None, common_params.blk_coord[3], None, None, None + ].shape + ), + cta_pv_tiler_mn, + ( + common_params.blk_coord[0], + iter_n, + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + ) + else: + gO = cute.local_tile( + common_params.mO, + cta_pv_tiler_mn, + ( + common_params.blk_coord[0], + iter_n, + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + ) + cO = cute.local_tile( + cute.make_identity_tensor(common_params.mO.shape), + cta_pv_tiler_mn, + ( + common_params.blk_coord[0], + iter_n, + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + ) + tTR_tAcc = tmem_load_thr_copy.partition_S(tAcc) + tTR_gO = tmem_load_thr_copy.partition_D(gO) + tTR_cO = tmem_load_thr_copy.partition_D(cO) + tTR_rAcc = cute.make_fragment_like(tTR_gO, self.acc_dtype) + return tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc + + def get_correction_factor( + self, + common_params: SimpleNamespace, + p_cor_consumer_state: pipeline.PipelineState, + ) -> tuple[ + pipeline.PipelineState, + cutlass.Float32, + cutlass.Float32, + cutlass.Float32, + cutlass.Int32, + ]: + """Get the correction factor from the P correction consumer state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param p_cor_consumer_state: The P correction consumer state + :type p_cor_consumer_state: pipeline.PipelineState + + :return: The P correction consumer state, the row_sum, the row_max, and the correction factor + :rtype: tuple[pipeline.PipelineState, cutlass.Float32, cutlass.Float32, cutlass.Float32, cutlass.Int32] + """ + common_params.p_cor_pipeline.consumer_wait(p_cor_consumer_state) + tidx = common_params.tidx % (self.num_compute_warps * self.threads_per_warp) + # load correction factor + _, tAcc, _, _, _, _ = self._tmem_load_partition( + common_params, common_params.tiled_mma_pv, 0 + ) + corr_layout = cute.make_layout( + (tAcc.shape[0], (4, tAcc.shape[1][1]), self.p_cor_stage), + stride=(tAcc.stride[0], (1, tAcc.stride[1][1]), 4), + ) + tCor = cute.make_tensor( + common_params.tmem_ptr + self.correction_factor_offset, corr_layout + ) + cCor = cute.make_identity_tensor(tCor.shape) + corr_tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(4)), self.acc_dtype + ) + corr_tmem_load_tiled_copy = tcgen05.make_tmem_copy(corr_tmem_load_atom, tCor) + corr_tmem_load_thr_copy = corr_tmem_load_tiled_copy.get_slice(tidx) + tCor_for_copy = corr_tmem_load_thr_copy.partition_S(tCor) + cCor_for_copy = corr_tmem_load_thr_copy.partition_D(cCor) + rCor = cute.make_fragment_like( + cCor_for_copy[None, None, None, 0], self.acc_dtype + ) + rCor_int = cute.make_tensor( + cute.recast_ptr(rCor.iterator, dtype=cutlass.Int32), rCor.layout + ) + cute.copy( + corr_tmem_load_tiled_copy, + tCor_for_copy[None, None, None, p_cor_consumer_state.index], + rCor, + ) + row_sum = rCor[0] + row_max = rCor[1] + correction_factor = rCor[2] + no_correction = rCor_int[3] + + common_params.p_cor_pipeline.consumer_release(p_cor_consumer_state) + p_cor_consumer_state.advance() + return p_cor_consumer_state, row_sum, row_max, correction_factor, no_correction + + @cute.jit + def rescale( + self, + common_params: SimpleNamespace, + mma_o_consumer_state: pipeline.PipelineState, + correction_factor: cutlass.Float32, + no_correction: cutlass.Int32, + ) -> pipeline.PipelineState: + """Rescale for one k-tile. Updates the related pipeline state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param mma_o_consumer_state: The mma o consumer state + :type mma_o_consumer_state: pipeline.PipelineState + :param correction_factor: The correction factor + :type correction_factor: cutlass.Float32 + :param no_correction: Whether to apply correction factor + :type no_correction: cutlass.Int32 + + :return: The MMA o consumer state + :rtype: pipeline.PipelineState + """ + skip_correction = cute.arch.vote_all_sync(no_correction == 1) + common_params.mma_o_pipeline.consumer_wait(mma_o_consumer_state) + if not skip_correction: + for iter_n in cutlass.range_constexpr(self.iterations_pv_n): + # tmem load tiled copy and partition results. + tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc = ( + self._tmem_load_partition( + common_params, common_params.tiled_mma_pv, iter_n + ) + ) + + # tmem store tiled copy + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype + ) + tmem_store_tiled_copy = tcgen05.make_tmem_copy(tmem_store_atom, tAcc) + + # load o + cute.copy(tmem_load_tiled_copy, tTR_tAcc, tTR_rAcc) + # rescale, using `mul_packed_f32x2` to reduce the number of instructions + for i in cutlass.range( + cute.size(tTR_rAcc), vectorize=True, unroll_full=True + ): + tTR_rAcc[i] = tTR_rAcc[i] * correction_factor + + # store o to tensor memory for next k tile + cute.copy(tmem_store_tiled_copy, tTR_rAcc, tTR_tAcc) + + cute.arch.fence_view_async_tmem_store() + common_params.mma_o_pipeline.consumer_release(mma_o_consumer_state) + mma_o_consumer_state.advance() + + return mma_o_consumer_state + + @cute.jit + def epilogue( + self, + common_params: SimpleNamespace, + epilogue_params: SimpleNamespace, + mma_o_consumer_state: pipeline.PipelineState, + row_sum: cutlass.Float32, + row_max: cutlass.Float32, + ) -> pipeline.PipelineState: + """Epilogue for one k-tile. Updates the related pipeline state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param epilogue_params: The epilogue parameters + :type epilogue_params: SimpleNamespace + :param mma_o_consumer_state: The mma o consumer state + :type mma_o_consumer_state: pipeline.PipelineState + :param row_sum: The row sum + :type row_sum: cutlass.Float32 + :param row_max: The row max + :type row_max: cutlass.Float32 + + :return: The MMA o consumer state + :rtype: pipeline.PipelineState + """ + + tidx = common_params.tidx % (self.num_compute_warps * self.threads_per_warp) + + # exchange row_sum between warps (0, 1) and (2, 3) + if cutlass.const_expr(self.warps_in_n == 2): + common_params.smem_exchange[tidx] = row_sum + self.epilogue_exchange_sync_bar.wait() + # (64, 2) + row_sum = ( + row_sum + + common_params.smem_exchange[ + (tidx + 64) % (self.num_compute_warps * self.threads_per_warp) + ] + ) + # mma_o pipeline consumer wait + common_params.mma_o_pipeline.consumer_wait(mma_o_consumer_state) + for iter_n in cutlass.range_constexpr(self.iterations_pv_n): + # tmem load tiled copy and partition results. + tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc = ( + self._tmem_load_partition( + common_params, common_params.tiled_mma_pv, iter_n + ) + ) + + # load o + cute.copy(tmem_load_tiled_copy, tTR_tAcc, tTR_rAcc) + + # apply output scale and normalize by row_sum + for i in cutlass.range( + cute.size(tTR_rAcc), vectorize=True, unroll_full=True + ): + tTR_rAcc[i] = ( + tTR_rAcc[i] + * epilogue_params.output_scale + * cute.arch.rcp_approx(row_sum) + ) + + # store o to global memory + tR2G_rO_src = None + tR2G_rO_dst = tTR_gO + if cutlass.const_expr(common_params.mAccO is None): + tR2G_rO_src = cute.make_fragment_like(tTR_gO, self.o_dtype) + # using final output dtype for o + tR2G_rO_src.store(tTR_rAcc.load().to(self.o_dtype)) + else: + # using accumulate dtype for o + tR2G_rO_src = tTR_rAcc + + if cute.elem_less(tTR_cO[0][0], common_params.H): + cute.autovec_copy( + tR2G_rO_src, + tR2G_rO_dst, + l1c_evict_priority=cute.nvgpu.CacheEvictionPriority.NO_ALLOCATE, + ) + + # store the lse to global memory + cta_pv_tiler = ( + self.mma_pv_tiler[0] // self.cluster_shape_mnk[0], + self.mma_pv_tiler[1], + self.mma_pv_tiler[2], + ) + gLSE = None + cLSE = None + if cutlass.const_expr(epilogue_params.mAccLSE is None): + gLSE = cute.local_tile( + epilogue_params.mLSE, + (cta_pv_tiler[0], 1, 1), + ( + common_params.blk_coord[0], + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + (1, 1, 1), + ) + cLSE = cute.local_tile( + cute.make_identity_tensor(epilogue_params.mLSE.shape), + (cta_pv_tiler[0], 1, 1), + ( + common_params.blk_coord[0], + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + (1, 1, 1), + ) + + else: + gLSE = cute.local_tile( + epilogue_params.mAccLSE[ + None, common_params.blk_coord[3], None, None + ], + (cta_pv_tiler[0], 1, 1), + ( + common_params.blk_coord[0], + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + (1, 1, 1), + ) + cLSE = cute.local_tile( + cute.make_identity_tensor( + epilogue_params.mAccLSE[ + None, common_params.blk_coord[3], None, None + ].shape + ), + (cta_pv_tiler[0], 1, 1), + ( + common_params.blk_coord[0], + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + (1, 1, 1), + ) + lse = ( + cute.math.log2(row_sum, fastmath=True) + + epilogue_params.softmax_scale_log2 * row_max + ) + if cutlass.const_expr(self.warps_in_n == 2): + if cute.elem_less(cLSE[tidx][0], common_params.H): + gLSE[tidx] = lse + + cute.arch.fence_view_async_tmem_load() + common_params.mma_o_pipeline.consumer_release(mma_o_consumer_state) + mma_o_consumer_state.advance() + + return mma_o_consumer_state + + def make_and_init_load_pt_pipeline(self, load_pt_mbar_ptr): + """Create and initialize the load page table pipeline. + + :param load_pt_mbar_ptr: The load page table mbar pointer + :type load_pt_mbar_ptr: cute.Tensor + + :return: The load page table pipeline + :rtype: pipeline.PipelineAsync + """ + load_pt_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * len([self.load_pt_warp_id]), + ) + load_pt_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * len([self.load_tma_warp_id]), + ) + return pipeline.PipelineCpAsync.create( + barrier_storage=load_pt_mbar_ptr, + num_stages=self.load_pt_stage, + producer_group=load_pt_producer_group, + consumer_group=load_pt_consumer_group, + defer_sync=True, + ) + + def make_and_init_load_qkv_pipeline( + self, load_qkv_mbar_ptr, cta_layout_vmnk, load_stages, tx_count + ) -> pipeline.PipelineTmaUmma: + """Create and initialize the tma load qkv pipeline. + + :param load_qkv_mbar_ptr: The load qkv mbar pointer + :type load_qkv_mbar_ptr: cute.Tensor + :param cta_layout_vmnk: The cta layout vmnk + :type cta_layout_vmnk: tuple[int, int, int] + :param load_stages: The load stages + :type load_stages: list[int] + :param tx_count: The tx count + :type tx_count: int + + :return: The tma load qkv pipeline + :rtype: pipeline.PipelineTmaUmma + """ + load_qkv_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.load_tma_warp_id]) + ) + load_qkv_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + return pipeline.PipelineTmaUmma.create( + barrier_storage=load_qkv_mbar_ptr, + num_stages=load_stages, + producer_group=load_qkv_producer_group, + consumer_group=load_qkv_consumer_group, + tx_count=tx_count, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + + def make_and_init_mma_s_pipeline( + self, mma_s_mbar_ptr, cta_layout_vmnk + ) -> pipeline.PipelineUmmaAsync: + """Create and initialize the mma s pipeline. + + :param mma_s_mbar_ptr: The mma s mbar pointer + :type mma_s_mbar_ptr: cute.Tensor + :param cta_layout_vmnk: The cta layout vmnk + :type cta_layout_vmnk: tuple[int, int, int] + + :return: The mma s pipeline + :rtype: pipeline.PipelineUmmaAsync + """ + + mma_s_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + consumer_thread_size = ( + self.threads_per_warp + * len(self.compute_warp_ids) + * self.cluster_shape_mnk[0] + ) + mma_s_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + consumer_thread_size, + ) + return pipeline.PipelineUmmaAsync.create( + barrier_storage=mma_s_mbar_ptr, + num_stages=self.mma_s_stage, + producer_group=mma_s_producer_group, + consumer_group=mma_s_consumer_group, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + + def make_and_init_p_mma_pipeline( + self, p_mma_mbar_ptr, cta_layout_vmnk + ) -> pipeline.PipelineAsyncUmma: + """Create and initialize the p mma pipeline. + + :param p_mma_mbar_ptr: The p mma mbar pointer + :type p_mma_mbar_ptr: cute.Tensor + :param cta_layout_vmnk: The cta layout vmnk + :type cta_layout_vmnk: tuple[int, int, int] + + :return: The p mma pipeline + :rtype: pipeline.PipelineAsyncUmma + """ + + producer_thread_size = ( + self.threads_per_warp + * len(self.compute_warp_ids) + * self.cluster_shape_mnk[0] + ) + p_mma_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + producer_thread_size, + ) + p_mma_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + return pipeline.PipelineAsyncUmma.create( + barrier_storage=p_mma_mbar_ptr, + num_stages=self.p_mma_stage, + producer_group=p_mma_producer_group, + consumer_group=p_mma_consumer_group, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + + def make_and_init_p_cor_pipeline( + self, p_cor_mbar_ptr + ) -> pipeline.PipelineAsyncUmma: + """Create and initialize the p correction pipeline. + + :param p_cor_mbar_ptr: The p correction mbar pointer + :type p_cor_mbar_ptr: cute.Tensor + + :return: The p correction pipeline + :rtype: pipeline.PipelineAsyncUmma + """ + + producer_thread_size = self.threads_per_warp * len(self.compute_warp_ids) + p_cor_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + producer_thread_size, + ) + p_cor_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + producer_thread_size, + ) + return pipeline.PipelineAsync.create( + barrier_storage=p_cor_mbar_ptr, + num_stages=self.p_cor_stage, + producer_group=p_cor_producer_group, + consumer_group=p_cor_consumer_group, + defer_sync=True, + ) + + def make_and_init_mma_o_pipeline( + self, mma_o_mbar_ptr, cta_layout_vmnk + ) -> pipeline.PipelineUmmaAsync: + """Create and initialize the mma o pipeline. + + :param mma_o_mbar_ptr: The mma o mbar pointer + :type mma_o_mbar_ptr: cute.Tensor + :param cta_layout_vmnk: The cta layout vmnk + :type cta_layout_vmnk: tuple[int, int, int] + + :return: The mma o pipeline + :rtype: pipeline.PipelineUmmaAsync + """ + + mma_o_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + consumer_thread_size = ( + self.threads_per_warp + * len(self.compute_warp_ids) + * self.cluster_shape_mnk[0] + ) + mma_o_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + consumer_thread_size, + ) + return pipeline.PipelineUmmaAsync.create( + barrier_storage=mma_o_mbar_ptr, + num_stages=self.mma_o_stage, + producer_group=mma_o_producer_group, + consumer_group=mma_o_consumer_group, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + + @staticmethod + def _compute_grid( + o: cute.Tensor, + split_kv: cutlass.Int32, + cluster_shape_mnk: Tuple[int, int, int], + max_active_clusters: int, + is_persistent: bool, + ) -> Tuple[MLAStaticTileSchedulerParams, Tuple[int, int, int]]: + """Compute grid shape for the output tensor C. + + :param c: The output tensor C + :type c: cute.Tensor + :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type cta_tile_shape_mnk: tuple[int, int, int] + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] + + :return: Tile scheduler parameters and grid shape. + :rtype: tuple[MLAStaticTileSchedulerParams, tuple[int, int, int]] + """ + o_shape = o.shape + tile_sched_params = create_mla_static_tile_scheduler_params( + is_persistent, + cute.size(o_shape[3]), + cute.size(o_shape[2]), + cluster_shape_mnk, + split_kv, + ) + grid = MLAStaticTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + + return tile_sched_params, grid + + @staticmethod + def get_workspace_size( + H: int, + S: int, + D: int, + B: int, + split_kv: int, + acc_dtype: Type[cutlass.Numeric], + ) -> int: + """Get the extra workspace(device memory) size for the MLA kernel when split_kv is not 1. + + :param H: The height of the output tensor C + :type H: int + :param S: The sequence length of the output tensor C + :type S: int + :param D: The depth of the output tensor C + :type D: int + :param B: The batch size of the output tensor C + :type B: int + :param split_kv: The split key-value of the output tensor C + :type split_kv: int + :param acc_dtype: The data type of the output tensor C + :type acc_dtype: Type[cutlass.Numeric] + + :return: The workspace size for the MLA kernel + :rtype: int + """ + if split_kv == 1: + return 0 + return B * H * S * split_kv * (D + 1) * acc_dtype.width // 8 + + @cute.jit + def initialize_workspace( + self, + H: cutlass.Int32, + D: cutlass.Int32, + S: cutlass.Int32, + B: cutlass.Int32, + split_kv: cutlass.Int32, + acc_dtype: Type[cutlass.Numeric], + workspace: cute.Tensor, + ) -> tuple[cute.Tensor, cute.Tensor]: + """Initialize the workspace for the MLA kernel. Construct the intermediate tensors + acc_o and acc_lse. + + :param H: The height of the output tensor C + :type H: cutlass.Int32 + :param D: The depth of the output tensor C + :type D: cutlass.Int32 + :param S: The sequence length of the output tensor C + :type S: cutlass.Int32 + :param B: The batch size of the output tensor C + :type B: cutlass.Int32 + :param split_kv: The split key-value of the output tensor C + :type split_kv: cutlass.Int32 + :param acc_dtype: The data type of the output tensor C + :type acc_dtype: Type[cutlass.Numeric] + :param workspace: The workspace tensor + :type workspace: cute.Tensor + + :return: The output tensor C and the workspace tensor + :rtype: tuple[cute.Tensor, cute.Tensor] + """ + acc_o, acc_lse = None, None + if cutlass.const_expr(workspace is not None): + align = 256 // self.q_dtype.width + acc_o_layout = cute.make_layout( + (H, split_kv, D, S, B), + stride=( + cute.assume(split_kv * D, align), + cute.assume(D, align), + 1, + cute.assume(split_kv * H * D, align), + cute.assume(H * split_kv * S * D, align), + ), + ) + acc_o_iter = cute.recast_ptr(workspace.iterator, dtype=acc_dtype) + acc_o = cute.make_tensor(acc_o_iter, acc_o_layout) + acc_lse_layout = cute.make_layout( + (H, split_kv, S, B), + stride=(split_kv, 1, H * split_kv, H * split_kv * S), + ) + acc_lse_iter = cute.recast_ptr( + workspace.iterator + cute.cosize(acc_o_layout) * acc_dtype.width // 8, + dtype=acc_dtype, + ) + acc_lse = cute.make_tensor(acc_lse_iter, acc_lse_layout) + return acc_o, acc_lse + + @staticmethod + def can_implement( + B: int, + S: int, + K: int, + H: int, + L: int, + R: int, + in_dtype: Type[cutlass.Numeric], + out_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + lse_dtype: Type[cutlass.Numeric], + mma_qk_tiler_mn: Tuple[int, int], + mma_pv_tiler_mn: Tuple[int, int], + split_kv: int, + is_persistent: bool, + is_var_seq: bool, + is_var_split_kv: bool, + page_size: int, + ) -> bool: + """Check if the MLA kernel can be implemented. + + :param B: The batch size of the output tensor C + :type B: int + :param S: The sequence length of the output tensor C + :type S: int + :param K: The width of the output tensor KV + :type K: int + :param H: The number of heads of the output tensor C + :type H: int + :param L: The number of latent dimensions of the tensor KV + :type L: int + :param R: The number of rope dimensions of the tensor C_rope + :type R: int + :param in_dtype: The data type of the input tensor + :type in_dtype: Type[cutlass.Numeric] + :param out_dtype: The data type of the output tensor + :type out_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param lse_dtype: The data type of the log-sum-exp + :type lse_dtype: Type[cutlass.Numeric] + :param mma_qk_tiler_mn: The tile shape of the query-key matrix multiplication + :type mma_qk_tiler_mn: Tuple[int, int] + :param mma_pv_tiler_mn: The tile shape of the probability-value matrix multiplication + :type mma_pv_tiler_mn: Tuple[int, int] + :param split_kv: The split key-value of the output tensor C + :type split_kv: int + :param is_persistent: Whether to use persistent kernel optimization + :type is_persistent: bool + :param is_var_seq: Whether to use variable sequence length + :type is_var_seq: bool + :param is_var_split_kv: Whether to use variable split_kv + :type is_var_split_kv: bool + :param page_size: The page size of the page table + :type page_size: int + + :return: Whether the MLA kernel can be implemented + :rtype: bool + """ + if L != 512 or R != 64: + return False + if in_dtype not in [cutlass.Float16]: + return False + if out_dtype not in [cutlass.Float16]: + return False + if acc_dtype != cutlass.Float32 or lse_dtype != cutlass.Float32: + return False + # page size equals 1 is prohibited by tma specification, not 128B aligned. + if mma_qk_tiler_mn[1] % page_size != 0 or page_size == 1: + return False + if mma_qk_tiler_mn[0] != mma_pv_tiler_mn[0] or mma_qk_tiler_mn[0] != 128: + return False + if is_var_split_kv and not is_var_seq: + return False + if H > 128 or (H < 128 and split_kv != 1): + return False + if S < 1 or S > 4: + return False + if K <= 0: + return False + return True + + +def run( + batch_size: int, + seq_len_q: int, + seq_len_k: int, + num_heads: int, + latent_dim: int, + rope_dim: int, + in_dtype: Type[cutlass.Numeric], + out_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + lse_dtype: Type[cutlass.Numeric], + mma_qk_tiler_mn: Tuple[int, int], + mma_pv_tiler_mn: Tuple[int, int], + split_kv: int, + is_persistent: bool, + is_var_seq: bool, + is_var_split_kv: bool, + page_size: int, + softmax_scale: float, + output_scale: float, + skip_correction_threshold: float, + tolerance: float, + warmup_iterations: int, + iterations: int, + skip_ref_check: bool, + use_cold_l2: bool, + **kwargs, +): + """Execute Multi-Head Latent Attention (MLA) on Blackwell architecture and validate results. + + This function creates random input tensors for query latent/rope, compressed latent/rope, and value, + then performs the complete MLA computation pipeline. It supports configurable data types, tiling parameters, + page table, variable sequence length, and variable split_kv. Results can be validated against a PyTorch reference + implementation or run multiple times for performance measurement. + + :param batch_size: Batch size + :type batch_size: int + :param seq_len_q: Sequence length of Q + :type seq_len_q: int + :param seq_len_k: Sequence length of K + :type seq_len_k: int + :param num_heads: Number of heads + :type num_heads: int + :param latent_dim: dimension of query/compressed latent + :type latent_dim: int + :param rope_dim: dimension of query/compressed rope + :type rope_dim: int + :param in_dtype: Input data type for query/compressed latent/rope tensors + :type in_dtype: Type[cutlass.Numeric] + :param out_dtype: Output data type for attention output + :type out_dtype: Type[cutlass.Numeric] + :param acc_dtype: Accumulator data type for query-key matrix multiplication + :type acc_dtype: Type[cutlass.Numeric] + :param lse_dtype: Accumulator data type for log-sum-exp + :type lse_dtype: Type[cutlass.Numeric] + :param mma_qk_tiler_mn: Matrix multiply accumulate tile shape (M, N) for query-key matrix multiplication + :type mma_qk_tiler_mn: Tuple[int, int] + :param mma_pv_tiler_mn: Matrix multiply accumulate tile shape (M, N) for probability-value matrix multiplication + :type mma_pv_tiler_mn: Tuple[int, int] + :param split_kv: Split key-value + :type split_kv: int + :param is_persistent: Whether to use persistent kernel optimization + :type is_persistent: bool + :param is_var_seq: Whether to use variable sequence length + :type is_var_seq: bool + :param is_var_split_kv: Whether to use variable split_kv + :type is_var_split_kv: bool + :param page_size: Page size of the page table + :type page_size: int + :param softmax_scale: Attention score scaling factor + :type softmax_scale: float + :param output_scale: Output scaling factor + :type output_scale: float + :param skip_correction_threshold: Threshold to skip correction + :type skip_correction_threshold: float + :param tolerance: Maximum acceptable error for validation + :type tolerance: float + :param warmup_iterations: Number of warmup iterations + :type warmup_iterations: int + :param iterations: Number of iterations to run for performance testing + :type iterations: int + :param skip_ref_check: Skip validation against reference implementation + :type skip_ref_check: bool + :param use_cold_l2: Whether to use cold L2 cache + :type use_cold_l2: bool + + :raises ValueError: If input shapes are incompatible or head dimension is unsupported + :raises RuntimeError: If GPU is unavailable for computation + """ + + print("Running Blackwell MLA test with:") + print(f" batch_size: {batch_size}") + print(f" seq_len_q: {seq_len_q}") + print(f" seq_len_k: {seq_len_k}") + print(f" num_heads: {num_heads}") + print(f" latent_dim: {latent_dim}") + print(f" rope_dim: {rope_dim}") + print(f" in_dtype: {in_dtype}") + print(f" out_dtype: {out_dtype}") + print(f" acc_dtype: {acc_dtype}") + print(f" mma_qk_tiler_mn: {mma_qk_tiler_mn}") + print(f" mma_pv_tiler_mn: {mma_pv_tiler_mn}") + print(f" split_kv: {split_kv}") + print(f" is_persistent: {is_persistent}") + print(f" is_var_seq: {is_var_seq}") + print(f" is_var_split_kv: {is_var_split_kv}") + print(f" page_size: {page_size}") + print(f" softmax_scale: {softmax_scale}") + print(f" output_scale: {output_scale}") + print(f" skip_correction_threshold: {skip_correction_threshold}") + print(f" tolerance: {tolerance}") + print(f" warmup_iterations: {warmup_iterations}") + print(f" iterations: {iterations}") + print(f" skip_ref_check: {skip_ref_check}") + print(f" use_cold_l2: {use_cold_l2}") + + # Prepare pytorch tensors: Q, K, V (random from 0 to 2) and O (all zero) + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + if not BlackwellMultiHeadLatentAttentionForwardFP16.can_implement( + batch_size, + seq_len_q, + seq_len_k, + num_heads, + latent_dim, + rope_dim, + in_dtype, + out_dtype, + acc_dtype, + lse_dtype, + mma_qk_tiler_mn, + mma_pv_tiler_mn, + split_kv, + is_persistent, + is_var_seq, + is_var_split_kv, + page_size, + ): + raise TypeError( + f"Unsupported testcase {batch_size}, {seq_len_q}, {seq_len_k}, {num_heads}, {latent_dim}, {rope_dim}, {in_dtype}, {out_dtype}, {acc_dtype}, {lse_dtype}, {mma_qk_tiler_mn}, {mma_pv_tiler_mn}, {split_kv}, {is_persistent}, {is_var_seq}, {is_var_split_kv}, {page_size}" + ) + + torch.manual_seed(1111) + + def create_data_tensor( + B, + HK, + D, + dtype, + is_dynamic_layout=True, + page_table=None, + cache_seqs=None, + is_lse=False, + seq_len_q=None, + ): + shape = (B, HK, D) + if page_table is not None: + if cache_seqs is not None: + max_seq_len = torch.max(cache_seqs) + shape = (B * ceil_div(max_seq_len, page_size), page_size, D) + else: + shape = (B * ceil_div(HK, page_size), page_size, D) + + if seq_len_q is not None: + shape = (B, seq_len_q, HK, D) + + permute_order = (1, 2, 0) + stride_order = (2, 0, 1) + leading_dim = 1 + if is_lse: + shape = (B, seq_len_q, HK) + permute_order = (2, 1, 0) + stride_order = (2, 1, 0) + leading_dim = 0 + elif seq_len_q is not None: + permute_order = (2, 3, 1, 0) + stride_order = (3, 2, 0, 1) + leading_dim = 1 + + init_config = cutlass.torch.RandomInitConfig(min_val=-2, max_val=2) + + torch_dtype = ( + cutlass_torch.dtype(dtype) if dtype != cutlass.Float8E4M3FN else torch.int8 + ) + + # Create dtype torch tensor (cpu) + torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( + shape, + torch_dtype, + permute_order=permute_order, + init_type=cutlass.torch.TensorInitType.RANDOM, + init_config=init_config, + ) + + # Create dtype torch tensor (gpu) + torch_tensor_gpu = torch_tensor_cpu.cuda() + + # Create f32 torch tensor (cpu) + f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32) + + # Create dtype cute tensor (gpu) + cute_tensor = from_dlpack(torch_tensor_gpu, assumed_align=16) + cute_tensor.element_type = dtype + if is_dynamic_layout: + cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=leading_dim) + if not is_lse: + cute_tensor = cute_tensor.mark_compact_shape_dynamic( + mode=leading_dim, + stride_order=stride_order, + divisibility=(128 // dtype.width), + ) + + cute_tensor = cutlass_torch.convert_cute_tensor( + f32_torch_tensor, + cute_tensor, + dtype, + is_dynamic_layout=is_dynamic_layout, + ) + + return f32_torch_tensor, cute_tensor, torch_tensor_gpu + + def create_cache_seqs(batch_size, seq_len_k, is_var_seq): + cache_seqs_ref = torch.ones(batch_size, dtype=torch.int32) * seq_len_k + cache_seqs_gpu = cache_seqs_ref.cuda() + cache_seqs = from_dlpack(cache_seqs_gpu, assumed_align=16).mark_layout_dynamic() + if is_var_seq: + max_seq_len = seq_len_k + min_seq_len = int(seq_len_k * 0.8) + cache_seqs_ref = cutlass_torch.create_and_permute_torch_tensor( + (batch_size,), + torch.int32, + init_type=cutlass.torch.TensorInitType.RANDOM, + init_config=cutlass.torch.RandomInitConfig( + min_val=min_seq_len, max_val=max_seq_len + 1 + ), + ) + cache_seqs_gpu = cache_seqs_ref.cuda() + cache_seqs = from_dlpack( + cache_seqs_gpu, + assumed_align=16, + ).mark_layout_dynamic() + return cache_seqs_ref, cache_seqs, cache_seqs_gpu + + def create_page_table(batch_size, seq_len_k, is_var_seq, page_size): + max_seq_len = seq_len_k if not is_var_seq else torch.max(cache_seqs_ref) + page_count = ceil_div(max_seq_len, page_size) + page_table_ref = torch.empty([batch_size, page_count], dtype=torch.int32) + # use transposed index for page table to make sure the value is in bound of `batch_size * seq_len_block`. In practice, the value could be any positive values. This setting is only for testing purpose. + for b in range(batch_size): + for j in range(page_count): + page_table_ref[b, j] = b + j * batch_size + page_table_gpu = page_table_ref.permute(1, 0).cuda() + page_table = from_dlpack(page_table_gpu, assumed_align=16).mark_layout_dynamic( + leading_dim=0 + ) + return page_table_ref, page_table, page_table_gpu + + def create_block_split_kvs( + batch_size, + split_kv, + cache_seqs_ref, + is_var_split_kv, + mma_qk_tiler_mn, + cluster_shape_mnk, + max_active_clusters, + ): + block_split_kvs_ref, block_split_kvs, block_split_kvs_gpu = None, None, None + # check if split_kv is valid otherwise do auto setting of split_kv + if is_var_split_kv: + block_split_kvs_ref = torch.zeros([batch_size], dtype=torch.int32) + for b in range(batch_size): + block_split_kvs_ref[b] = ( + BlackwellMultiHeadLatentAttentionForwardFP16.get_split_kv( + batch_size, + seq_len_q, + cache_seqs_ref[b].item(), + mma_qk_tiler_mn, + max_active_clusters * cluster_shape_mnk[0], + ) + ) + split_kv = torch.max(block_split_kvs_ref).item() + block_split_kvs_gpu = block_split_kvs_ref.cuda() + block_split_kvs = from_dlpack( + block_split_kvs_gpu, assumed_align=16 + ).mark_layout_dynamic() + elif split_kv <= 0: + split_kv = BlackwellMultiHeadLatentAttentionForwardFP16.get_split_kv( + batch_size, + seq_len_q, + cache_seqs_ref[0].item(), + mma_qk_tiler_mn, + max_active_clusters * cluster_shape_mnk[0], + ) + return split_kv, block_split_kvs_ref, block_split_kvs, block_split_kvs_gpu + + def create_workspace( + num_heads, seq_len_q, latent_dim, batch_size, split_kv, acc_dtype + ): + workspace_size = ( + BlackwellMultiHeadLatentAttentionForwardFP16.get_workspace_size( + num_heads, + seq_len_q, + latent_dim, + batch_size, + split_kv, + acc_dtype, + ) + ) + + workspace, workspace_torch = None, None + if workspace_size > 0: + workspace_torch = torch.empty([workspace_size], dtype=torch.int8).cuda() + workspace = from_dlpack(workspace_torch, assumed_align=32) + return workspace, workspace_torch + + cache_seqs_ref, cache_seqs, cache_seqs_torch = create_cache_seqs( + batch_size, seq_len_k, is_var_seq + ) + page_table_ref, page_table, page_table_torch = create_page_table( + batch_size, seq_len_k, is_var_seq, page_size + ) + cluster_shape_mnk = (2, 1, 1) + hardware_info = utils.HardwareInfo() + max_active_clusters = hardware_info.get_max_active_clusters( + cluster_shape_mnk[0] * cluster_shape_mnk[1] + ) + split_kv, block_split_kvs_ref, block_split_kvs, block_split_kvs_torch = ( + create_block_split_kvs( + batch_size, + split_kv, + cache_seqs_ref, + is_var_split_kv, + mma_qk_tiler_mn, + cluster_shape_mnk, + max_active_clusters, + ) + ) + + q_latent_ref, q_latent, q_latent_torch = create_data_tensor( + batch_size, + num_heads, + latent_dim, + in_dtype, + is_dynamic_layout=True, + seq_len_q=seq_len_q, + ) + q_rope_ref, q_rope, q_rope_torch = create_data_tensor( + batch_size, + num_heads, + rope_dim, + in_dtype, + is_dynamic_layout=True, + seq_len_q=seq_len_q, + ) + + c_latent_ref, c_latent, c_latent_torch = create_data_tensor( + batch_size, + seq_len_k, + latent_dim, + in_dtype, + is_dynamic_layout=True, + page_table=page_table, + cache_seqs=cache_seqs_ref, + ) + c_rope_ref, c_rope, c_rope_torch = create_data_tensor( + batch_size, + seq_len_k, + rope_dim, + in_dtype, + is_dynamic_layout=True, + page_table=page_table, + cache_seqs=cache_seqs_ref, + ) + o_ref, o, o_torch = create_data_tensor( + batch_size, + num_heads, + latent_dim, + out_dtype, + is_dynamic_layout=True, + seq_len_q=seq_len_q, + ) + lse_ref, lse, lse_torch = create_data_tensor( + batch_size, + num_heads, + 1, + lse_dtype, + is_dynamic_layout=True, + is_lse=True, + seq_len_q=seq_len_q, + ) + workspace, workspace_torch = create_workspace( + num_heads, seq_len_q, latent_dim, batch_size, split_kv, acc_dtype + ) + + mla = BlackwellMultiHeadLatentAttentionForwardFP16( + acc_dtype, + lse_dtype, + mma_qk_tiler_mn, + mma_pv_tiler_mn, + max_active_clusters, + page_size, + skip_correction_threshold, + is_persistent, + is_var_seq, + is_var_split_kv, + ) + + # Get current CUDA stream from PyTorch + torch_stream = torch.cuda.current_stream() + # Get the raw stream pointer as a CUstream + stream = cuda.CUstream(torch_stream.cuda_stream) + + # compile mla kernel + compiled_mla = cute.compile( + mla, + q_latent, + q_rope, + c_latent, + c_rope, + page_table, + o, + lse, + workspace, + split_kv, + cache_seqs, + block_split_kvs, + softmax_scale, + output_scale, + stream, + options="--opt-level 2", + ) + + def torch_reference_mla( + q_latent, + q_rope, + c_latent, + c_rope, + page_table, + cache_seqs, + softmax_scale=1.0, + output_scale=1.0, + ): + # expand and concat q_latent and q_rope to have the dimension of sequence length for q + q_ref = torch.cat([q_latent, q_rope], dim=1).permute(3, 2, 0, 1) + # expand and concat c_latent and c_rope to have the dimension of num_heads for k and v + page_count = page_table_ref.shape[1] + k_ref_paged = ( + torch.cat([c_latent, c_rope], dim=1) + .permute(2, 0, 1) + .reshape(batch_size * page_count, page_size, latent_dim + rope_dim) + ) + v_ref_paged = c_latent.permute(2, 0, 1).reshape( + batch_size * page_count, page_size, latent_dim + ) + + if is_var_seq: + max_seq_len = torch.max(cache_seqs_ref) + else: + max_seq_len = seq_len_k + + k_ref = torch.zeros([batch_size, 1, max_seq_len, latent_dim + rope_dim]) + v_ref = torch.zeros([batch_size, 1, max_seq_len, latent_dim]) + k_ref = torch.index_select( + k_ref_paged, 0, torch.flatten(page_table_ref) + ).reshape(batch_size, 1, -1, latent_dim + rope_dim)[:, :, :max_seq_len, :] + v_ref = torch.index_select( + v_ref_paged, 0, torch.flatten(page_table_ref) + ).reshape(batch_size, 1, -1, latent_dim)[:, :, :max_seq_len, :] + for b in range(batch_size): + k_ref[b, :, cache_seqs_ref[b] :, :] = 0 + v_ref[b, :, cache_seqs_ref[b] :, :] = 0 + import torch.nn.functional as F + + o_ref = F.scaled_dot_product_attention( + q_ref, + k_ref, + v_ref, + attn_mask=None, + dropout_p=0.0, + scale=softmax_scale, + is_causal=False, + ) + s_ref = torch.einsum("bhld,bhsd->bhls", q_ref, k_ref) + s_ref_max, s_ref_max_pos = torch.max(s_ref, dim=-1, keepdim=True) + softmax_scale_log2 = LOG2_E * softmax_scale + s_ref_sum = torch.sum( + torch.exp2((s_ref - s_ref_max) * softmax_scale_log2), dim=-1, keepdim=True + ) + + lse_ref = s_ref_max * softmax_scale_log2 + torch.log2(s_ref_sum) + lse_ref = lse_ref.squeeze(3).permute(2, 1, 0) + o_ref = o_ref * output_scale + o_ref = o_ref.permute(2, 3, 1, 0) + + return o_ref, lse_ref + + if skip_correction_threshold > 0.0: + print( + "Skipping correction verification since skip_correction_threshold is greater than 0.0..." + ) + skip_ref_check = True + if not skip_ref_check: + # Execute kernel once for reference checking + compiled_mla( + q_latent, + q_rope, + c_latent, + c_rope, + page_table, + o, + lse, + workspace, + split_kv, + cache_seqs, + block_split_kvs, + softmax_scale, + output_scale, + stream, + ) + torch.cuda.synchronize() + + print("Verifying results...") + if in_dtype == cutlass.Float8E4M3FN: + tolerance = 0.13 + o_ref, lse_ref = torch_reference_mla( + q_latent_ref, + q_rope_ref, + c_latent_ref, + c_rope_ref, + page_table, + cache_seqs, + softmax_scale, + output_scale, + ) + + if out_dtype in [cutlass.Float8E5M2, cutlass.Float8E4M3FN]: + # {$nv-internal-release begin} + # todo: not sure why, but the below `cute.testing.convert` will cause bus error occasionally in local and ci. + # {$nv-internal-release end} + # convert o back to f32 for comparison + o_fp32, o_fp32_torch = cutlass_torch.cute_tensor_like( + torch.empty(*o_torch.shape, dtype=torch.float32), + cutlass.Float32, + is_dynamic_layout=True, + assumed_align=16, + ) + cute.testing.convert(o, o_fp32) + o = o_fp32_torch.cpu() + ref_fp8, _ = cutlass_torch.cute_tensor_like( + torch.empty( + *o_ref.permute(3, 2, 0, 1).shape, dtype=torch.uint8 + ).permute(2, 3, 1, 0), + out_dtype, + is_dynamic_layout=True, + assumed_align=16, + ) + o_ref_gpu = o_ref.cuda() + o_ref_f32 = from_dlpack(o_ref_gpu).mark_layout_dynamic(leading_dim=1) + + # convert ref : f32 -> fp8 -> f32 + cute.testing.convert(o_ref_f32, ref_fp8) + cute.testing.convert(ref_fp8, o_ref_f32) + + o_ref = o_ref_gpu.cpu() + else: + o = o_torch.cpu().to(torch.float32) + lse = lse_torch.cpu() + lse_ref = lse_ref.to(cutlass.torch.dtype(lse_dtype)) + # Assert close results + torch.testing.assert_close(o, o_ref, atol=tolerance, rtol=1e-05) + torch.testing.assert_close(lse, lse_ref, atol=tolerance, rtol=1e-05) + print("Results verified successfully!") + + def generate_tensors(): + _, cache_seqs, _ = create_cache_seqs(batch_size, seq_len_k, is_var_seq) + _, page_table, _ = create_page_table( + batch_size, seq_len_k, is_var_seq, page_size + ) + _split_kv, _, block_split_kvs, _ = create_block_split_kvs( + batch_size, + split_kv, + cache_seqs_ref, + is_var_split_kv, + mma_qk_tiler_mn, + cluster_shape_mnk, + max_active_clusters, + ) + + _, q_latent, _ = create_data_tensor( + batch_size, + num_heads, + latent_dim, + in_dtype, + is_dynamic_layout=True, + seq_len_q=seq_len_q, + ) + _, q_rope, _ = create_data_tensor( + batch_size, + num_heads, + rope_dim, + in_dtype, + is_dynamic_layout=True, + seq_len_q=seq_len_q, + ) + + _, c_latent, _ = create_data_tensor( + batch_size, + seq_len_k, + latent_dim, + in_dtype, + is_dynamic_layout=True, + page_table=page_table, + cache_seqs=cache_seqs_ref, + ) + _, c_rope, _ = create_data_tensor( + batch_size, + seq_len_k, + rope_dim, + in_dtype, + is_dynamic_layout=True, + page_table=page_table, + cache_seqs=cache_seqs_ref, + ) + _, o, _ = create_data_tensor( + batch_size, + num_heads, + latent_dim, + out_dtype, + is_dynamic_layout=True, + seq_len_q=seq_len_q, + ) + _, lse, _ = create_data_tensor( + batch_size, + num_heads, + 1, + lse_dtype, + is_dynamic_layout=True, + is_lse=True, + seq_len_q=seq_len_q, + ) + workspace, workspace_torch = create_workspace( + num_heads, seq_len_q, latent_dim, batch_size, _split_kv, acc_dtype + ) + return testing.JitArguments( + q_latent, + q_rope, + c_latent, + c_rope, + page_table, + o, + lse, + workspace, + _split_kv, + cache_seqs, + block_split_kvs, + softmax_scale, + output_scale, + stream, + ) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + q_latent_torch.numel() * q_latent_torch.element_size() + + q_rope_torch.numel() * q_rope_torch.element_size() + + c_latent_torch.numel() * c_latent_torch.element_size() + + c_rope_torch.numel() * c_rope_torch.element_size() + + o_torch.numel() * o_torch.element_size() + + lse_torch.numel() * lse_torch.element_size() + + cache_seqs_torch.numel() * cache_seqs_torch.element_size() + ) + one_workspace_bytes += ( + page_table_torch.numel() * page_table_torch.element_size() + ) + if is_var_split_kv: + one_workspace_bytes += ( + block_split_kvs_torch.numel() * block_split_kvs_torch.element_size() + ) + if workspace_torch is not None: + one_workspace_bytes += ( + workspace_torch.numel() * workspace_torch.element_size() + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + avg_time_us = testing.benchmark( + compiled_mla, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + return avg_time_us # Return execution time in microseconds + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: + try: + return tuple(int(x.strip()) for x in s.split(",")) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + def parse_mma_tiler(s: str) -> Tuple[int, int, Tuple[int, int]]: + ret = parse_comma_separated_ints(s) + if len(ret) != 2: + raise argparse.ArgumentTypeError( + "Invalid format. Expected 2 comma-separated integers." + ) + return (ret[0], ret[1]) + + parser = argparse.ArgumentParser(description="Example of MLA on Blackwell.") + + parser.add_argument( + "--in_dtype", + type=cutlass.dtype, + default=cutlass.Float16, + help="Input data type", + ) + + parser.add_argument( + "--out_dtype", + type=cutlass.dtype, + default=cutlass.Float16, + help="Output data type", + ) + + parser.add_argument( + "--acc_dtype", + type=cutlass.dtype, + default=cutlass.Float32, + help="Accumulator data type", + ) + + parser.add_argument( + "--lse_dtype", + type=cutlass.dtype, + default=cutlass.Float32, + help="LSE data type", + ) + parser.add_argument( + "--mma_qk_tiler_mn", + type=parse_mma_tiler, + default=(128, 128), + help="MMA tile shape (H, K)", + ) + parser.add_argument( + "--mma_pv_tiler_mn", + type=parse_mma_tiler, + default=(128, 256), + help="MMA tile shape (H, D)", + ) + + parser.add_argument( + "--is_persistent", + action="store_true", + help="Is persistent", + ) + + parser.add_argument( + "--batch_size", + type=int, + default=1, + help="Batch size", + ) + + parser.add_argument( + "--seq_len_q", + type=int, + default=1, + help="Sequence length of Q", + ) + + parser.add_argument( + "--seq_len_k", + type=int, + default=128, + help="Sequence length of K/V", + ) + + parser.add_argument( + "--num_heads", + type=int, + default=128, + help="Number of heads of Q", + ) + + parser.add_argument( + "--latent_dim", + type=int, + default=512, + help="Latent dimension of Q/C", + ) + + parser.add_argument( + "--rope_dim", + type=int, + default=64, + help="Rope dimension of Q/C", + ) + + parser.add_argument( + "--is_var_seq", + action="store_true", + help="Use variable length of sequence length or not", + ) + + parser.add_argument( + "--is_var_split_kv", + action="store_true", + help="Use variable length of split kv or not", + ) + + parser.add_argument( + "--page_size", + type=int, + default=128, + help="Page size of page table", + ) + + parser.add_argument( + "--split_kv", + type=int, + default=-1, + help="Split KV setting", + ) + + parser.add_argument( + "--softmax_scale", + type=float, + default=0.0416, + help="Scaling factor to scale softmax", + ) + + parser.add_argument( + "--output_scale", + type=float, + default=1.0, + help="Scaling factor to scale output", + ) + + parser.add_argument( + "--skip_correction_threshold", + type=float, + default=0.0, + help="Skip correction threshold", + ) + + parser.add_argument( + "--tolerance", type=float, default=1e-02, help="Tolerance for validation" + ) + + parser.add_argument( + "--warmup_iterations", + type=int, + default=0, + help="Number of iterations for warmup", + ) + + parser.add_argument( + "--iterations", + type=int, + default=1, + help="Number of iterations after warmup", + ) + + parser.add_argument( + "--skip_ref_check", + action="store_true", + help="Skip reference check", + ) + + parser.add_argument( + "--use_cold_l2", + action="store_true", + help="Use cold L2 cache", + ) + + args = parser.parse_args() + + run( + args.batch_size, + args.seq_len_q, + args.seq_len_k, + args.num_heads, + args.latent_dim, + args.rope_dim, + args.in_dtype, + args.out_dtype, + args.acc_dtype, + args.lse_dtype, + args.mma_qk_tiler_mn, + args.mma_pv_tiler_mn, + args.split_kv, + args.is_persistent, + args.is_var_seq, + args.is_var_split_kv, + args.page_size, + args.softmax_scale, + args.output_scale, + args.skip_correction_threshold, + args.tolerance, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, + ) + + print("PASS") diff --git a/flashinfer/cute_dsl/mla_decode_fp8.py b/flashinfer/cute_dsl/mla_decode_fp8.py new file mode 100644 index 0000000000..ae987df5ba --- /dev/null +++ b/flashinfer/cute_dsl/mla_decode_fp8.py @@ -0,0 +1,4356 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import os +import sys +import argparse +import math +from typing import Type, Tuple, Optional +from types import SimpleNamespace + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +import cutlass.cute.testing as testing +from cutlass.cute.nvgpu import tcgen05 +from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode +import cutlass.cute.nvgpu.cpasync as cpasync +import cutlass.utils as utils +import cutlass.pipeline as pipeline +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cute.runtime import from_dlpack +from cutlass.cute.arch import Arch +from cutlass.cutlass_dsl import BaseDSL + + +from .mla_helpers import ( + ceil_div, + MAX_SPLITS, + LOG2_E, + MLAStaticTileScheduler, + MLAStaticTileSchedulerParams, + create_mla_static_tile_scheduler, + create_mla_static_tile_scheduler_params, +) + +""" +A Multi-Head Latent Attention (MLA) example using fp8 as input/output for the NVIDIA Blackwell SM100 architecture using CUTE DSL + +This example demonstrates an implementation of inference of multi-head latent attention using a TMA + Blackwell +SM100 TensorCore warp-specialized persistent kernel. The implementation integrates the (Qc + Qr)*(Kc + Kr)^T +matrix multiplication, softmax normalization, and softmax((Qc + Qr)*(Kc + Kr)^T)*Vc into a single kernel. +The kernel provides support for page table storage and variable-length KV cache sequences. It implements KV splitting +functionality to minimize latency when processing long KV sequences. + +The kernel implements key optimizations including: +- Warp specialization for different computation phases (load, MMA, softmax, correction, epilogue) +- Pipeline stages between different warps for overlapping computation and memory access +- Support for different precision data types +- Two sub-kernels (split KV kernel and reduction kernel) that enable split KV processing + +To run this example: + +.. code-block:: bash + + python examples/blackwell/mla_fp8.py \ + --batch_size 4 --latent_dim 512 --rope_dim 64 \ + --num_heads 128 --seq_len_q 1 --seq_len_k 1024 \ + --in_dtype Float8E4M3FN --out_dtype Float8E4M3FN \ + --acc_dtype Float32 --lse_dtype Float32 \ + --is_var_seq --is_var_split_kv \ + --is_persistent + +The above example runs Multi-Head Latent Attention (MLA) with the following configuration: +- Batch size: 4 +- Sequence length of Q: 1 +- Sequence length of K: 1024 +- Latent dimension: 512 +- RoPE dimension: 64 +- Number of heads: 128 +- Data types: Float8E4M3FN (input), Float8E4M3FN (output), Float32 (accumulation and LSE) + +It utilizes page table storage for the KV cache and enables both variable-length KV cache sequences +and variable split KV processing with persistent scheduling. + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/blackwell/mla_fp8.py \ + --batch_size 4 --latent_dim 512 --rope_dim 64 \ + --num_heads 128 --seq_len_q 1 --seq_len_k 1024 \ + --in_dtype Float8E4M3FN --out_dtype Float8E4M3FN \ + --acc_dtype Float32 --lse_dtype Float32 \ + --is_var_seq --is_var_split_kv \ + --is_persistent --warmup_iterations 3 \ + --iterations 10 --skip_ref_check + +Constraints for this example: +* Data type requirements: + - Input/output: Float8E4M3FN + - Accumulation and LSE: Float32 +* Fixed architecture parameters: + - Number of attention heads: 128 + - Latent dimension: 512 + - RoPE dimension: 64 +* Input query modes should be (NumHeads, LatentDim/RopeDim, SeqLenQ, BatchSize) +* Input kv latent/rope modes should be (SeqLenK, LatentDim/RopeDim, BatchSize) +* Query sequence length must be 1-4 +* Only supports 2-CTA instructions +* Variable sequence length requires page table storage enabled +""" + + +class BlackwellMultiHeadLatentAttentionForwardFP8: + def __init__( + self, + acc_dtype: Type[cutlass.Numeric], + lse_dtype: Type[cutlass.Numeric], + mma_qk_tiler_mn: Tuple[int, int], + mma_pv_tiler_mn: Tuple[int, int], + max_active_clusters: int, + page_size: int, + skip_correction_threshold: float, + is_persistent: bool, + is_var_seq: bool, + is_var_split_kv: bool, + ): + """Initializes the configuration for a Blackwell Multi-Head Latent Attention (MLA) kernel. + + :param acc_dtype: Data type for accumulation S and O + :type acc_dtype: Type[cutlass.Numeric] + :param lse_dtype: Data type for output LSE + :type lse_dtype: Type[cutlass.Numeric] + :param mma_s_tiler: The (H, K) tile shape of the MMA instruction for S + :type mma_s_tiler: Tuple[int, int] + :param mma_p_tiler: The (H, D) tile shape of the MMA instruction for P + :type mma_p_tiler: Tuple[int, int] + :param max_active_clusters: Maximum number of active clusters + :type max_active_clusters: int + :param page_size: The page size + :type page_size: int + :param skip_correction_threshold: Threshold to skip correction + :type skip_correction_threshold: float + :param is_persistent: Whether to use persistent kernel mode + :type is_persistent: bool + :param is_var_seq: Whether to use variable sequence length + :type is_var_seq: bool + :param is_var_split_kv: Whether to use variable split KV + :type is_var_split_kv: bool + """ + + self.latent_dim = 512 + self.rope_dim = 64 + self.acc_dtype = acc_dtype + self.lse_dtype = lse_dtype + self.mma_qk_tiler_mn = mma_qk_tiler_mn + self.mma_pv_tiler_mn = mma_pv_tiler_mn + self.max_active_clusters = max_active_clusters + self.skip_correction_threshold = skip_correction_threshold + self.is_persistent = is_persistent + self.page_size = page_size + self.is_var_seq = is_var_seq + self.is_var_split_kv = is_var_split_kv + self.cluster_shape_mnk = (2, 1, 1) + self.use_2cta_instrs = True + # When using 2 CTAs with m=128: warps 0-1 handle accumulation for first half [0, n/2), + # while warps 2-3 handle accumulation for second half [n/2, n) + self.warps_in_n = 2 + self.num_compute_warps = 4 + self.threads_per_warp = 32 + mma_qk_tiler_k = self.rope_dim * 2 + self.mma_qk_tiler = ( + self.mma_qk_tiler_mn[0], + self.mma_qk_tiler_mn[1], + mma_qk_tiler_k, + ) + self.mma_qk_rope_tiler = ( + self.mma_qk_tiler_mn[0], + self.mma_qk_tiler_mn[1], + self.rope_dim, + ) + self.mma_pv_tiler = ( + self.mma_pv_tiler_mn[0], + self.mma_pv_tiler_mn[1], + self.mma_qk_tiler[1] * self.mma_qk_tiler[2] // self.mma_pv_tiler_mn[1], + ) + self.iterations_qk_latent = self.latent_dim // self.mma_qk_tiler[2] + self.iterations_qk_rope = 1 + self.iterations_qk = self.iterations_qk_latent + self.iterations_qk_rope + self.iterations_pv_k = self.mma_qk_tiler[1] // self.mma_pv_tiler[2] + self.iterations_pv_n = self.latent_dim // self.mma_pv_tiler[1] + + # Set specialized warp ids + self.compute_warp_ids = (0, 1, 2, 3) + self.correction_warp_ids = (4, 5, 6, 7) + self.mma_warp_id = 8 + self.load_tma_k_warp_id = 9 + self.load_tma_v_warp_id = 10 + self.empty_warp_ids = (11,) + self.threads_per_cta = self.threads_per_warp * len( + ( + self.mma_warp_id, + self.load_tma_k_warp_id, + self.load_tma_v_warp_id, + *self.compute_warp_ids, + *self.correction_warp_ids, + *self.empty_warp_ids, + ) + ) + + # register settings + self.softmax_reg_num = 192 + self.correction_reg_num = 256 + self.other_reg_num = 48 + # Named barriers + self.tmem_ptr_sync_bar = pipeline.NamedBarrier( + barrier_id=1, + num_threads=( + self.threads_per_warp + + self.threads_per_warp * self.num_compute_warps * 2 + ), + ) + self.softmax_exchange_sync_bar = pipeline.NamedBarrier( + barrier_id=2, num_threads=(self.threads_per_warp * self.num_compute_warps) + ) + self.epilogue_exchange_sync_bar = pipeline.NamedBarrier( + barrier_id=3, num_threads=(self.threads_per_warp * self.num_compute_warps) + ) + + def _setup_attributes(self): + """Set up configurations and parameters for the MLA kernel operation. + + This method initializes and configures various attributes required for the + execution of the multi-head latent attention kernel, mainly about the pipeline stages: + + - Sets up staging parameters for Q, K, V inputs and accumulator data + - Configures pipeline stages for softmax, correction, and epilogue operations + """ + + self.load_q_stage = 1 + self.load_k_stage = 3 + self.load_v_stage = 2 + self.mma_s_stage = 2 + self.p_mma_stage = 2 + self.p_cor_stage = 2 + self.mma_o_stage = 2 + + self.tmem_o_offset = self.mma_s_stage * self.mma_qk_tiler[1] // self.warps_in_n + self.correction_factor_offset = ( + self.tmem_o_offset + self.latent_dim // self.warps_in_n + ) + + @cute.jit + def __call__( + self, + q_latent: cute.Tensor, + q_rope: cute.Tensor, + c_latent: cute.Tensor, + c_rope: cute.Tensor, + page_table: cute.Tensor, + o: cute.Tensor, + lse: cute.Tensor, + workspace: cute.Tensor, + split_kv: cutlass.Int32, + cache_seqs: Optional[cute.Tensor], + block_split_kvs: Optional[cute.Tensor], + softmax_scale: cutlass.Float32, + output_scale: cutlass.Float32, + stream: cuda.CUstream, + ): + """Execute the Multi-Head Latent Attention operation on the provided tensors. + + The method handles: + 1. Initialization of workspace for temporary split KV buffers + 2. Validation of tensor data types + 3. Initialization of hardware-specific parameters and memory layouts + 4. Configuration of TMA (Tensor Memory Access) operations + 5. Grid and work scheduling computation + 6. Kernel launch(split KV kernel and reduction kernel) with appropriate parameters + + :param q_latent: The query tensor with shape [num_head, latent_dim, seq_len_q, batch_size] + :type q_latent: cute.Tensor + :param q_rope: The query RoPE tensor with shape [num_head, rope_dim, seq_len_q, batch_size] + :type q_rope: cute.Tensor + :param c_latent: The key tensor with shape [seq_len_k, latent_dim, batch_size] + :type c_latent: cute.Tensor + :param c_rope: The key RoPE tensor with shape [seq_len_k, rope_dim, batch_size] + :type c_rope: cute.Tensor + :param page_table: The page table tensor with shape [page_count, batch_size] + :type page_table: cute.Tensor + :param o: The output tensor with shape [num_head, latent_dim, seq_len_q, batch_size] + :type o: cute.Tensor + :param lse: The LSE tensor with shape [num_head, seq_len_q, batch_size] + :type lse: cute.Tensor + :param workspace: The workspace tensor with 1-d shape prepared for acc_o and acc_lse + :type workspace: cute.Tensor + :param split_kv: The scalar factor for split KV + :type split_kv: cutlass.Int32 + :param cache_seqs: The cache sequences tensor with shape [batch_size] + :type cache_seqs: cute.Tensor + :param block_split_kvs: The block split KV tensor with shape [batch_size] + :type block_split_kvs: cute.Tensor + :param softmax_scale: The scale factor for softmax + :type softmax_scale: cutlass.Float32 + :param output_scale: The scale factor for the output + :type output_scale: cutlass.Float32 + :param stream: The CUDA stream to execute the kernel on + :type stream: cuda.CUstream + + :raises TypeError: If tensor data types don't match or aren't supported + """ + + # setup static attributes before smem/grid/tma computation + self.q_dtype = q_latent.element_type + self.k_dtype = c_latent.element_type + self.v_dtype = c_latent.element_type + self.o_dtype = o.element_type + + # check type consistency + if cutlass.const_expr( + self.q_dtype != self.k_dtype or self.q_dtype != self.v_dtype + ): + raise TypeError( + f"Type mismatch: {self.q_dtype} != {self.k_dtype} or {self.q_dtype} != {self.v_dtype}" + ) + # check leading dimensions of input/output + if cutlass.const_expr(q_latent.stride[1] != 1 or q_rope.stride[1] != 1): + raise ValueError("q_latent and q_rope must have leading dimension 1") + if cutlass.const_expr(c_latent.stride[1] != 1 or c_rope.stride[1] != 1): + raise ValueError("c_latent and c_rope must have leading dimension 1") + if cutlass.const_expr(o.stride[1] != 1): + raise ValueError("o must have leading dimension 1") + if cutlass.const_expr(lse.stride[0] != 1): + raise ValueError("lse must have leading dimension 0") + + acc_o, acc_lse = self.initialize_workspace( + q_latent.shape[0], + q_latent.shape[1], + q_latent.shape[2], + q_latent.shape[3], + split_kv, + self.acc_dtype, + workspace, + ) + + c_latent_tranpose_layout = cute.select(c_latent.layout, mode=[1, 0, 2]) + c_latent_transpose = cute.make_tensor( + c_latent.iterator, c_latent_tranpose_layout + ) + + self.q_major_mode = OperandMajorMode.K + self.k_major_mode = OperandMajorMode.K + self.v_major_mode = OperandMajorMode.MN + + self._setup_attributes() + + cta_group = tcgen05.CtaGroup.TWO + # the intermediate tensor p is from smem & k-major + p_major_mode = OperandMajorMode.K + qk_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.q_dtype, + self.q_major_mode, + self.k_major_mode, + self.acc_dtype, + cta_group, + self.mma_qk_tiler[:2], + ) + pv_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.v_dtype, + p_major_mode, + self.v_major_mode, + self.acc_dtype, + cta_group, + self.mma_pv_tiler[:2], + ) + + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), + (qk_tiled_mma.thr_id.shape,), + ) + + self.epi_tile = self.mma_pv_tiler[:2] + + q_latent_smem_layout_staged = sm100_utils.make_smem_layout_a( + qk_tiled_mma, + self.mma_qk_tiler, + self.q_dtype, + (self.iterations_qk_latent * self.load_q_stage), + ) + q_latent_smem_layout_staged = cute.logical_divide( + q_latent_smem_layout_staged, (None, None, None, self.iterations_qk_latent) + ) + q_rope_smem_layout_staged = sm100_utils.make_smem_layout_a( + qk_tiled_mma, + self.mma_qk_rope_tiler, + self.q_dtype, + self.load_q_stage, + ) + + kc_latent_smem_layout_staged = sm100_utils.make_smem_layout_b( + qk_tiled_mma, + self.mma_qk_tiler, + self.k_dtype, + (self.iterations_qk_latent * self.load_k_stage), + ) + kc_page_tile_size = min( + self.page_size, qk_tiled_mma.op.shape_mnk[0] // qk_tiled_mma.thr_id.shape + ) + kc_latent_smem_layout_staged = cute.logical_divide( + kc_latent_smem_layout_staged, (None, None, None, self.iterations_qk_latent) + ) + + kc_latent_smem_layout_for_tma = sm100_utils.make_smem_layout( + OperandMajorMode.K, + (self.mma_qk_tiler[0] // qk_tiled_mma.thr_id.shape, self.mma_qk_tiler[2]), + self.k_dtype, + (self.iterations_qk_latent * self.load_k_stage), + ) + kc_latent_smem_layout_for_tma = cute.tiled_divide( + kc_latent_smem_layout_for_tma, (kc_page_tile_size, self.mma_qk_tiler[2]) + ) + kc_latent_smem_layout_for_tma = cute.logical_divide( + kc_latent_smem_layout_for_tma, (None, None, None, self.iterations_qk_latent) + ) + + kc_rope_smem_layout_staged = sm100_utils.make_smem_layout_b( + qk_tiled_mma, + self.mma_qk_rope_tiler, + self.k_dtype, + self.load_k_stage, + ) + kc_rope_smem_layout_for_tma = sm100_utils.make_smem_layout( + OperandMajorMode.K, + ( + self.mma_qk_rope_tiler[0] // qk_tiled_mma.thr_id.shape, + self.mma_qk_rope_tiler[2], + ), + self.k_dtype, + (self.iterations_qk_rope * self.load_k_stage), + ) + kc_rope_smem_layout_for_tma = cute.tiled_divide( + kc_rope_smem_layout_for_tma, (kc_page_tile_size, self.mma_qk_rope_tiler[2]) + ) + + p_smem_layout_staged = sm100_utils.make_smem_layout_a( + pv_tiled_mma, + self.mma_pv_tiler, + self.q_dtype, + (self.iterations_pv_k * self.p_mma_stage), + ) + p_smem_layout_staged = cute.logical_divide( + p_smem_layout_staged, (None, None, None, self.iterations_pv_k) + ) + + vc_smem_layout_staged = sm100_utils.make_smem_layout_b( + pv_tiled_mma, + self.mma_pv_tiler, + self.v_dtype, + (self.iterations_pv_k * self.iterations_pv_n * self.load_v_stage), + ) + vc_smem_layout_staged = cute.logical_divide( + cute.logical_divide( + vc_smem_layout_staged, + (None, None, None, self.iterations_pv_k * self.iterations_pv_n), + ), + (None, None, None, (self.iterations_pv_n, None)), + ) + vc_page_tile_size = min(self.page_size, self.mma_pv_tiler[2]) + vc_smem_layout_for_tma = sm100_utils.make_smem_layout( + OperandMajorMode.MN, + (self.mma_pv_tiler[1] // pv_tiled_mma.thr_id.shape, self.mma_pv_tiler[2]), + self.v_dtype, + (self.iterations_pv_k * self.iterations_pv_n * self.load_v_stage), + ) + vc_smem_layout_for_tma = cute.tiled_divide( + vc_smem_layout_for_tma, + ( + pv_tiled_mma.op.shape_mnk[1] // pv_tiled_mma.thr_id.shape, + vc_page_tile_size, + ), + ) + vc_smem_layout_for_tma = cute.logical_divide( + cute.logical_divide( + vc_smem_layout_for_tma, + (None, None, None, self.iterations_pv_k * self.iterations_pv_n), + ), + (None, None, None, (self.iterations_pv_n, None)), + ) + # TMA load for Q latent and rope + tma_load_op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp(cta_group) + + q_smem_layout = cute.select(q_latent_smem_layout_staged, mode=[0, 1, 2]) + + tma_atom_q_latent, tma_tensor_q_latent = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q_latent, + q_smem_layout, + self.mma_qk_tiler, + qk_tiled_mma, + cta_layout_vmnk.shape, + ) + q_rope_smem_layout = cute.select(q_rope_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_q_rope, tma_tensor_q_rope = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q_rope, + q_rope_smem_layout, + self.mma_qk_rope_tiler, + qk_tiled_mma, + cta_layout_vmnk.shape, + ) + # TMA load for c latent and k rope + kc_smem_layout = cute.select(kc_latent_smem_layout_for_tma, mode=[0]) + tma_atom_c_latent, tma_tensor_c_latent = self.make_paged_tiled_tma_atom( + tma_load_op, + c_latent, + kc_smem_layout, + (self.mma_qk_tiler[1], self.mma_qk_tiler[2]), + qk_tiled_mma, + is_k_load=True, + ) + kc_rope_smem_layout = cute.select(kc_rope_smem_layout_for_tma, mode=[0]) + tma_atom_c_rope, tma_tensor_c_rope = self.make_paged_tiled_tma_atom( + tma_load_op, + c_rope, + kc_rope_smem_layout, + (self.mma_qk_rope_tiler[1], self.mma_qk_rope_tiler[2]), + qk_tiled_mma, + is_k_load=True, + ) + + # TMA load for c latent transpose + vc_smem_layout = cute.select(vc_smem_layout_for_tma, mode=[0]) + tma_atom_c_latent_transpose, tma_tensor_c_latent_transpose = ( + self.make_paged_tiled_tma_atom( + tma_load_op, + c_latent_transpose, + vc_smem_layout, + (self.mma_pv_tiler[1], self.mma_pv_tiler[2]), + pv_tiled_mma, + is_k_load=False, + ) + ) + + q_latent_copy_size = ( + cute.size_in_bytes(self.q_dtype, q_smem_layout) + * cute.size(qk_tiled_mma.thr_id.shape) + * self.iterations_qk_latent + ) + q_rope_copy_size = ( + cute.size_in_bytes(self.q_dtype, q_rope_smem_layout) + * cute.size(qk_tiled_mma.thr_id.shape) + * self.iterations_qk_rope + ) + kc_latent_copy_size = ( + cute.size_in_bytes( + self.k_dtype, + cute.select(kc_latent_smem_layout_staged, mode=[0, 1, 2]), + ) + * cute.size(qk_tiled_mma.thr_id.shape) + * self.iterations_qk_latent + ) + kc_rope_copy_size = ( + cute.size_in_bytes( + self.k_dtype, + cute.select(kc_rope_smem_layout_staged, mode=[0, 1, 2]), + ) + * cute.size(qk_tiled_mma.thr_id.shape) + * self.iterations_qk_rope + ) + vc_copy_size = ( + cute.size_in_bytes( + self.v_dtype, cute.select(vc_smem_layout_staged, mode=[0, 1, 2]) + ) + * cute.size(pv_tiled_mma.thr_id.shape) + * self.iterations_pv_n + * self.iterations_pv_k + ) + + self.tma_copy_q_bytes = q_latent_copy_size + q_rope_copy_size + self.tma_copy_kc_bytes = kc_latent_copy_size + kc_rope_copy_size + self.tma_copy_vc_bytes = vc_copy_size + + tile_sched_params, grid = self._compute_grid( + o, + split_kv, + self.cluster_shape_mnk, + self.max_active_clusters, + self.is_persistent, + ) + + @cute.struct + class SplitKVKernelSharedStorage: + # Pipeline barriers + load_q_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.load_q_stage * 2] + load_k_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.load_k_stage * 2] + load_v_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.load_v_stage * 2] + mma_s_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.mma_s_stage * 2] + p_mma_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.p_mma_stage * 2] + p_cor_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.p_cor_stage * 2] + mma_o_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.mma_o_stage * 2] + + # Smem tensors + smem_p: cute.struct.Align[ + cute.struct.MemRange[self.q_dtype, cute.cosize(p_smem_layout_staged)], + 1024, + ] + smem_kc_latent: cute.struct.Align[ + cute.struct.MemRange[ + self.k_dtype, cute.cosize(kc_latent_smem_layout_staged) + ], + 1024, + ] + + smem_kc_rope: cute.struct.Align[ + cute.struct.MemRange[ + self.k_dtype, cute.cosize(kc_rope_smem_layout_staged) + ], + 1024, + ] + smem_q_latent: cute.struct.Align[ + cute.struct.MemRange[ + self.q_dtype, cute.cosize(q_latent_smem_layout_staged) + ], + 1024, + ] + smem_q_rope: cute.struct.Align[ + cute.struct.MemRange[ + self.q_dtype, cute.cosize(q_rope_smem_layout_staged) + ], + 1024, + ] + smem_vc: cute.struct.Align[ + cute.struct.MemRange[self.v_dtype, cute.cosize(vc_smem_layout_staged)], + 1024, + ] + softmax_smem_exchange: cute.struct.MemRange[ + self.acc_dtype, self.num_compute_warps * self.threads_per_warp + ] + epilogue_smem_exchange: cute.struct.MemRange[ + self.acc_dtype, self.num_compute_warps * self.threads_per_warp + ] + + # Tmem dealloc cluster barrier + tmem_dealloc_mbar_ptr: cutlass.Int64 + + # Tmem holding buffer + tmem_holding_buf: cutlass.Int32 + + softmax_scale_log2 = softmax_scale * LOG2_E + + self.split_kv_kernel( + qk_tiled_mma, + pv_tiled_mma, + tma_atom_q_latent, + tma_tensor_q_latent, + tma_atom_q_rope, + tma_tensor_q_rope, + tma_atom_c_latent, + tma_tensor_c_latent, + tma_atom_c_rope, + tma_tensor_c_rope, + tma_atom_c_latent_transpose, + tma_tensor_c_latent_transpose, + page_table, + o, + lse, + acc_o, + acc_lse, + split_kv, + cache_seqs, + block_split_kvs, + softmax_scale_log2, + output_scale, + q_latent_smem_layout_staged, + q_rope_smem_layout_staged, + kc_latent_smem_layout_staged, + kc_rope_smem_layout_staged, + p_smem_layout_staged, + vc_smem_layout_staged, + kc_latent_smem_layout_for_tma, + kc_rope_smem_layout_for_tma, + vc_smem_layout_for_tma, + cta_layout_vmnk, + tile_sched_params, + SplitKVKernelSharedStorage, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=self.cluster_shape_mnk, + smem=SplitKVKernelSharedStorage.size_in_bytes(), + stream=stream, + min_blocks_per_mp=1, + ) + if cutlass.const_expr(acc_o is not None): + self.reduction_kernel( + o, + lse, + acc_o, + acc_lse, + split_kv, + cache_seqs, + block_split_kvs, + ).launch( + grid=(q_latent.shape[0], q_latent.shape[2], q_latent.shape[3]), + block=[self.threads_per_warp * self.num_compute_warps, 1, 1], + smem=MAX_SPLITS * self.acc_dtype.width // 8, + stream=stream, + min_blocks_per_mp=1, + ) + + @cute.jit + def make_paged_tiled_tma_atom( + self, + tma_load_op: cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp, + gmem: cute.Tensor, + smem_layout: cute.Layout, + mma_tiler, + tiled_mma: cute.TiledMma, + is_k_load: bool, + ): + ident = cute.make_identity_layout(gmem.shape) + g_tile = cute.composition(ident, mma_tiler) + cta_mn = mma_tiler[0] // tiled_mma.thr_id.shape + cta_v_map = cute.flat_divide(g_tile, (cta_mn,)) + cta_v_map = cute.select(cta_v_map, mode=[0, 2]) + page_tile_size = ( + min(self.page_size, cta_mn) + if is_k_load + else min(self.page_size, mma_tiler[1]) + ) + cta_v_map = cute.zipped_divide( + cta_v_map, + (page_tile_size, mma_tiler[1]) if is_k_load else (cta_mn, page_tile_size), + ) + cta_v_map = cute.select(cta_v_map, mode=[0]) + from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir + + res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load( + gmem.value, + smem_layout.value, + cta_v_map, + tma_load_op._to_ir(), + num_multicast=1, + ) + return cute.CopyAtom( + tma_load_op, cpasync.CopyBulkTensorTileG2SNonExecTrait(res[0]) + ), res[1] + + @cute.kernel + def split_kv_kernel( + self, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + tma_atom_q_latent: Optional[cute.CopyAtom], + mQL: cute.Tensor, + tma_atom_q_rope: Optional[cute.CopyAtom], + mQR: cute.Tensor, + tma_atom_c_latent: Optional[cute.CopyAtom], + mCL: cute.Tensor, + tma_atom_c_rope: Optional[cute.CopyAtom], + mKR: cute.Tensor, + tma_atom_c_latent_transpose: Optional[cute.CopyAtom], + mCLT: cute.Tensor, + mPT: cute.Tensor, + mO: Optional[cute.Tensor], + mLSE: Optional[cute.Tensor], + mAccO: Optional[cute.Tensor], + mAccLSE: Optional[cute.Tensor], + split_kv: cutlass.Int32, + cache_seqs: cute.Tensor, + block_split_kvs: cute.Tensor, + softmax_scale_log2: cutlass.Float32, + output_scale: cutlass.Float32, + q_latent_smem_layout_staged: cute.ComposedLayout, + q_rope_smem_layout_staged: cute.ComposedLayout, + kc_latent_smem_layout_staged: cute.ComposedLayout, + kc_rope_smem_layout_staged: cute.ComposedLayout, + p_smem_layout_staged: cute.ComposedLayout, + vc_smem_layout_staged: cute.ComposedLayout, + kc_latent_smem_layout_for_tma: Optional[cute.ComposedLayout], + kc_rope_smem_layout_for_tma: Optional[cute.ComposedLayout], + vc_smem_layout_for_tma: Optional[cute.ComposedLayout], + cta_layout_vmnk: cute.Layout, + tile_sched_params: MLAStaticTileSchedulerParams, + SharedStorage: cutlass.Constexpr, + ): + """The device split_kv kernel implementation of the Multi-Head Latent Attention. + + This kernel coordinates multiple specialized warps to perform different phases of the MLA computation: + 1. Load warp: Loads Q/C latent/rope data from global memory to shared memory using TMA + 2. MMA warp: Performs matrix multiplications (Q*K^T and P*V) + 3. Compute warps: Compute softmax and do rescaling on accumulators, and store the intermediate/final results + to global memory + + The kernel produces either intermediate or final results of the MLA computation based on the split_kv parameter. + When split_kv is 1, the kernel generates the final results directly. Otherwise, it produces intermediate results + that will later be combined by a reduction kernel. + + The kernel implements a complex pipeline with overlapping computation and memory operations, + using tensor memory access (TMA) for efficient data loading, warp specialization for different + computation phases. + + :param tiled_mma_qk: Tiled MMA for Q*K^T + :type tiled_mma_qk: cute.TiledMma + :param tiled_mma_pv: Tiled MMA for P*V + :type tiled_mma_pv: cute.TiledMma + :param tma_atom_q_latent: TMA copy atom for query latent tensor + :type tma_atom_q_latent: cute.CopyAtom + :param mQL: query latent tensor + :type mQL: cute.Tensor + :param tma_atom_q_rope: TMA copy atom for query rope tensor + :type tma_atom_q_rope: cute.CopyAtom + :param mKR: Compressed rope tensor + :type mKR: cute.Tensor + :param tma_atom_c_latent: TMA copy atom for c latent tensor + :type tma_atom_c_latent: cute.CopyAtom + :param mCL: Compressed latent tensor + :type mCL: cute.Tensor + :param tma_atom_c_rope: TMA copy atom for c rope tensor + :type tma_atom_c_rope: cute.CopyAtom + :param mCLT: Compressed latent transpose tensor + :type mCLT: cute.Tensor + :param mPT: Page table tensor + :type mPT: cute.Tensor + :param mO: Output tensor + :type mO: cute.Tensor + :param mLSE: Log-sum-exp tensor + :type mLSE: cute.Tensor + :param mAccO: Intermediate accumulator output tensor + :type mAccO: cute.Tensor + :param mAccLSE: Intermediate accumulator log-sum-exp tensor + :type mAccLSE: cute.Tensor + :param split_kv: The split_kv parameter + :type split_kv: cutlass.Int32 + :param cache_seqs: The variable sequence length tensor + :type cache_seqs: cute.Tensor + :param block_split_kvs: The per-block split_kv values tensor + :type block_split_kvs: cute.Tensor + :param softmax_scale_log2: The log2 scale factor for softmax + :type softmax_scale_log2: cutlass.Float32 + :param output_scale: The scale factor for the output + :type output_scale: cutlass.Float32 + :param q_latent_smem_layout_staged: Shared memory layout for query tensor + :type q_latent_smem_layout_staged: cute.ComposedLayout + :param q_rope_smem_layout_staged: Shared memory layout for query rope tensor + :type q_rope_smem_layout_staged: cute.ComposedLayout + :param kc_latent_smem_layout_staged: Shared memory layout for key tensor + :type kc_latent_smem_layout_staged: cute.ComposedLayout + :param kc_rope_smem_layout_staged: Shared memory layout for key rope tensor + :type kc_rope_smem_layout_staged: cute.ComposedLayout + :param p_smem_layout_staged: Shared memory layout for probability matrix + :type p_smem_layout_staged: cute.ComposedLayout + :param vc_smem_layout_staged: Shared memory layout for value tensor + :type vc_smem_layout_staged: cute.ComposedLayout + :param cta_layout_vmnk: Layout for compute threads + :type cta_layout_vmnk: cute.Layout + :param tile_sched_params: Scheduling parameters for work distribution + :type tile_sched_params: MLAStaticTileSchedulerParams + :param SharedStorage: Shared storage for the kernel + :type SharedStorage: cutlass.Constexpr + """ + + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma_qk.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + + # Prefetch tma descriptor + if warp_idx == self.mma_warp_id: + cpasync.prefetch_descriptor(tma_atom_q_latent) + cpasync.prefetch_descriptor(tma_atom_q_rope) + cpasync.prefetch_descriptor(tma_atom_c_latent) + cpasync.prefetch_descriptor(tma_atom_c_rope) + cpasync.prefetch_descriptor(tma_atom_c_latent_transpose) + + # Alloc + smem = utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + + # Tensor memory dealloc barrier init + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=self.tmem_ptr_sync_bar, + allocator_warp_id=self.mma_warp_id, + is_two_cta=self.use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + + load_q_pipeline = self.make_and_init_load_qkv_pipeline( + storage.load_q_mbar_ptr.data_ptr(), + cta_layout_vmnk, + self.load_q_stage, + self.tma_copy_q_bytes, + ) + load_k_pipeline = self.make_and_init_load_qkv_pipeline( + storage.load_k_mbar_ptr.data_ptr(), + cta_layout_vmnk, + self.load_k_stage, + self.tma_copy_kc_bytes, + ) + load_v_pipeline = self.make_and_init_load_qkv_pipeline( + storage.load_v_mbar_ptr.data_ptr(), + cta_layout_vmnk, + self.load_v_stage, + self.tma_copy_vc_bytes, + ) + mma_s_pipeline = self.make_and_init_mma_s_pipeline( + storage.mma_s_mbar_ptr.data_ptr(), cta_layout_vmnk + ) + p_mma_pipeline = self.make_and_init_p_mma_pipeline( + storage.p_mma_mbar_ptr.data_ptr(), cta_layout_vmnk + ) + p_cor_pipeline = self.make_and_init_p_cor_pipeline( + storage.p_cor_mbar_ptr.data_ptr() + ) + mma_o_pipeline = self.make_and_init_mma_o_pipeline( + storage.mma_o_mbar_ptr.data_ptr(), cta_layout_vmnk + ) + + # Cluster arrive after barrier init + pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mnk, is_relaxed=True) + + # Generate smem tensor Q/KC/VC/exchange + # (MMA, MMA_H, MMA_R, PIPE) + sQ = storage.smem_q_latent.get_tensor( + q_latent_smem_layout_staged.outer, swizzle=q_latent_smem_layout_staged.inner + ) + sQ_rope = storage.smem_q_rope.get_tensor( + q_rope_smem_layout_staged.outer, swizzle=q_rope_smem_layout_staged.inner + ) + # (MMA, MMA_K, MMA_R, PIPE) + sKC = storage.smem_kc_latent.get_tensor( + kc_latent_smem_layout_staged.outer, + swizzle=kc_latent_smem_layout_staged.inner, + ) + sKC_rope = storage.smem_kc_rope.get_tensor( + kc_rope_smem_layout_staged.outer, swizzle=kc_rope_smem_layout_staged.inner + ) + sKC_for_tma = storage.smem_kc_latent.get_tensor( + kc_latent_smem_layout_for_tma.outer, + swizzle=kc_latent_smem_layout_for_tma.inner, + ) + sKC_rope_for_tma = storage.smem_kc_rope.get_tensor( + kc_rope_smem_layout_for_tma.outer, swizzle=kc_rope_smem_layout_for_tma.inner + ) + # (MMA, MMA_D, MMA_K, PIPE) + sVC = storage.smem_vc.get_tensor( + vc_smem_layout_staged.outer, swizzle=vc_smem_layout_staged.inner + ) + sVC_for_tma = storage.smem_vc.get_tensor( + vc_smem_layout_for_tma.outer, swizzle=vc_smem_layout_for_tma.inner + ) + # (MMA, MMA_H, MMA_K) + sP = storage.smem_p.get_tensor( + p_smem_layout_staged.outer, swizzle=p_smem_layout_staged.inner + ) + # (compute_threads,) + softmax_smem_exchange = storage.softmax_smem_exchange.get_tensor( + cute.make_layout(self.num_compute_warps * self.threads_per_warp) + ) + epilogue_smem_exchange = storage.epilogue_smem_exchange.get_tensor( + cute.make_layout(self.num_compute_warps * self.threads_per_warp) + ) + + # + # Cluster wait before tensor memory alloc + # + pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mnk) + + # /////////////////////////////////////////////////////////////////////////////// + # Load warps, including page table and data tensors + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx >= self.empty_warp_ids[0] and warp_idx <= self.empty_warp_ids[-1]: + cute.arch.setmaxregister_decrease(self.other_reg_num) + + if warp_idx == self.load_tma_k_warp_id: + cute.arch.setmaxregister_decrease(self.other_reg_num) + load_q_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.load_q_stage + ) + load_k_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.load_k_stage + ) + tile_sched = create_mla_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + blk_coord = work_tile.tile_idx + k_index, k_tile_count, local_split_kv = self.get_k_tile_count( + split_kv, + cache_seqs, + block_split_kvs, + blk_coord, + ) + if k_tile_count > 0: + # Construct fixed common/tma_qk/tma_pv params for load_tma + tma_common_params = SimpleNamespace( + blk_coord=blk_coord, + local_split_kv=local_split_kv, + load_q_pipeline=load_q_pipeline, + load_k_pipeline=load_k_pipeline, + load_v_pipeline=load_v_pipeline, + mPT=mPT, + ) + tma_qk_params = SimpleNamespace( + tiled_mma_qk=tiled_mma_qk, + tma_atom_q_latent=tma_atom_q_latent, + tma_atom_q_rope=tma_atom_q_rope, + tma_atom_c_latent=tma_atom_c_latent, + tma_atom_c_rope=tma_atom_c_rope, + mQL=mQL, + mQR=mQR, + mCL=mCL, + mKR=mKR, + sQ=sQ, + sQ_rope=sQ_rope, + sKC=sKC_for_tma, + sKC_rope=sKC_rope_for_tma, + ) + # Load tma + load_q_producer_state, load_k_producer_state = self.load_tma_qk( + tma_common_params, + tma_qk_params, + k_index, + k_tile_count, + load_q_producer_state, + load_k_producer_state, + ) + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + load_q_pipeline.producer_tail(load_q_producer_state) + load_k_pipeline.producer_tail(load_k_producer_state) + + if warp_idx == self.load_tma_v_warp_id: + cute.arch.setmaxregister_decrease(self.other_reg_num) + load_v_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.load_v_stage + ) + tile_sched = create_mla_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + blk_coord = work_tile.tile_idx + k_index, k_tile_count, local_split_kv = self.get_k_tile_count( + split_kv, + cache_seqs, + block_split_kvs, + blk_coord, + ) + if k_tile_count > 0: + # Construct fixed common/tma_qk/tma_pv params for load_tma + tma_common_params = SimpleNamespace( + blk_coord=blk_coord, + local_split_kv=local_split_kv, + load_v_pipeline=load_v_pipeline, + mPT=mPT, + ) + tma_pv_params = SimpleNamespace( + tiled_mma_pv=tiled_mma_pv, + tma_atom_c_latent_transpose=tma_atom_c_latent_transpose, + mCLT=mCLT, + sVC=sVC_for_tma, + ) + # Load tma + load_v_producer_state = self.load_tma_v( + tma_common_params, + tma_pv_params, + k_index, + k_tile_count, + load_v_producer_state, + ) + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + load_v_pipeline.producer_tail(load_v_producer_state) + + # /////////////////////////////////////////////////////////////////////////////// + # MMA warp + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.mma_warp_id: + cute.arch.setmaxregister_decrease(self.other_reg_num) + # Alloc tensor memory buffer + tmem.allocate(cute.arch.get_max_tmem_alloc_cols("sm_100")) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + + load_q_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.load_q_stage + ) + load_k_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.load_k_stage + ) + load_v_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.load_v_stage + ) + mma_s_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.mma_s_stage + ) + p_mma_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.p_mma_stage + ) + mma_o_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.mma_o_stage + ) + tile_sched = create_mla_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + blk_coord = work_tile.tile_idx + k_index, k_tile_count, local_split_kv = self.get_k_tile_count( + split_kv, cache_seqs, block_split_kvs, blk_coord + ) + if k_tile_count > 0: + mma_common_params = SimpleNamespace( + blk_coord=blk_coord, + local_split_kv=local_split_kv, + load_q_pipeline=load_q_pipeline, + load_k_pipeline=load_k_pipeline, + load_v_pipeline=load_v_pipeline, + tmem_ptr=tmem_ptr, + is_leader_cta=is_leader_cta, + L=mCL.shape[1], + ) + mma_qk_params = SimpleNamespace( + mma_s_pipeline=mma_s_pipeline, + sQ=sQ, + sQ_rope=sQ_rope, + sKC=sKC, + sKC_rope=sKC_rope, + ) + mma_pv_params = SimpleNamespace( + p_mma_pipeline=p_mma_pipeline, + mma_o_pipeline=mma_o_pipeline, + sP=sP, + sVC=sVC, + ) + ( + tiled_mma_qk, + tiled_mma_pv, + load_q_consumer_state, + load_k_consumer_state, + load_v_consumer_state, + mma_s_producer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) = self.mma( + mma_common_params, + mma_qk_params, + mma_pv_params, + k_tile_count, + tiled_mma_qk, + tiled_mma_pv, + load_q_consumer_state, + load_k_consumer_state, + load_v_consumer_state, + mma_s_producer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + mma_s_pipeline.producer_tail(mma_s_producer_state) + mma_o_pipeline.producer_tail(mma_o_producer_state) + + tmem.relinquish_alloc_permit() + tmem.free(tmem_ptr) + + # /////////////////////////////////////////////////////////////////////////////// + # Compute warp + # /////////////////////////////////////////////////////////////////////////////// + if ( + warp_idx >= self.compute_warp_ids[0] + and warp_idx <= self.compute_warp_ids[-1] + ): + cute.arch.setmaxregister_increase(self.softmax_reg_num) + mma_s_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.mma_s_stage + ) + p_mma_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.p_mma_stage + ) + p_cor_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.p_cor_stage + ) + mma_o_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.mma_o_stage + ) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + + tile_sched = create_mla_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + blk_coord = work_tile.tile_idx + k_index, k_tile_count, local_split_kv = self.get_k_tile_count( + split_kv, cache_seqs, block_split_kvs, blk_coord + ) + if k_tile_count > 0: + compute_common_params = SimpleNamespace( + blk_coord=blk_coord, + split_kv=split_kv, + local_split_kv=local_split_kv, + smem_exchange=softmax_smem_exchange, + mAccO=mAccO, + mO=mO, + K=cache_seqs[blk_coord[2]], + L=mCL.shape[1], + tmem_ptr=tmem_ptr, + tidx=tidx, + p_cor_pipeline=p_cor_pipeline, + ) + compute_softmax_params = SimpleNamespace( + tiled_mma_qk=tiled_mma_qk, + sP=sP, + mma_s_pipeline=mma_s_pipeline, + p_mma_pipeline=p_mma_pipeline, + softmax_scale_log2=softmax_scale_log2, + ) + mma_s_consumer_state, p_mma_producer_state, p_cor_producer_state = ( + self.compute( + compute_common_params, + compute_softmax_params, + k_index=k_index, + k_tile_count=k_tile_count, + mma_s_consumer_state=mma_s_consumer_state, + p_mma_producer_state=p_mma_producer_state, + p_cor_producer_state=p_cor_producer_state, + ) + ) + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + p_cor_pipeline.producer_tail(p_cor_producer_state) + + # /////////////////////////////////////////////////////////////////////////////// + # Correction warp + # /////////////////////////////////////////////////////////////////////////////// + if ( + warp_idx >= self.correction_warp_ids[0] + and warp_idx <= self.correction_warp_ids[-1] + ): + cute.arch.setmaxregister_increase(self.correction_reg_num) + p_cor_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.p_cor_stage + ) + mma_o_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.mma_o_stage + ) + # sync with mma warp before retrieving tmem ptr + tmem.wait_for_alloc() + + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + + tile_sched = create_mla_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + blk_coord = work_tile.tile_idx + k_index, k_tile_count, local_split_kv = self.get_k_tile_count( + split_kv, cache_seqs, block_split_kvs, blk_coord + ) + if k_tile_count > 0: + compute_common_params = SimpleNamespace( + blk_coord=blk_coord, + split_kv=split_kv, + local_split_kv=local_split_kv, + smem_exchange=epilogue_smem_exchange, + mAccO=mAccO, + mO=mO, + K=cache_seqs[blk_coord[2]], + L=mCL.shape[1], + H=mQL.shape[0], + tmem_ptr=tmem_ptr, + tidx=tidx, + tiled_mma_pv=tiled_mma_pv, + p_cor_pipeline=p_cor_pipeline, + mma_o_pipeline=mma_o_pipeline, + ) + compute_epilogue_params = SimpleNamespace( + output_scale=output_scale, + softmax_scale_log2=softmax_scale_log2, + mAccLSE=mAccLSE, + mLSE=mLSE, + ) + p_cor_consumer_state, mma_o_consumer_state = self.correction( + compute_common_params, + compute_epilogue_params, + k_tile_count=k_tile_count, + p_cor_consumer_state=p_cor_consumer_state, + mma_o_consumer_state=mma_o_consumer_state, + ) + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + return + + @cute.kernel + def reduction_kernel( + self, + mO: cute.Tensor, + mLSE: cute.Tensor, + mAccO: cute.Tensor, + mAccLSE: cute.Tensor, + split_kv: cutlass.Int32, + cache_seqs: cute.Tensor, + block_split_kvs: cute.Tensor, + ): + """The reduction kernel for Multi-Head Latent Attention (MLA) that combines intermediate results + from multiple split_kv blocks into final outputs. + + :param mO: Output tensor for storing final results + :type mO: cute.Tensor + :param mLSE: Log-sum-exp tensor for storing final LSE values + :type mLSE: cute.Tensor + :param mAccO: Accumulated output tensor from split_kv blocks + :type mAccO: cute.Tensor + :param mAccLSE: Accumulated LSE tensor from split_kv blocks + :type mAccLSE: cute.Tensor + :param split_kv: Number of split_kv blocks + :type split_kv: cutlass.Int32 + :param cache_seqs: Cache sequence lengths tensor + :type cache_seqs: cute.Tensor + :param block_split_kvs: Per-block split_kv values tensor (for variable split_kv) + :type block_split_kvs: cute.Tensor + """ + bidx, bidy, bidz = cute.arch.block_idx() + tidx, _, _ = cute.arch.thread_idx() + blk_coord = (bidx, bidy, bidz) + local_split_kv = ( + block_split_kvs[blk_coord[2]] if self.is_var_split_kv else split_kv + ) + k_tile_total = cute.ceil_div(cache_seqs[blk_coord[2]], self.mma_qk_tiler[1]) + k_tile_per_cta = cute.ceil_div(k_tile_total, local_split_kv) + local_split_kv = cute.ceil_div(k_tile_total, k_tile_per_cta) + + # Alloc shared memory + smem = utils.SmemAllocator() + storage = smem.allocate(MAX_SPLITS * self.acc_dtype.width // 8, 16) + lse_scale_ptr = cute.recast_ptr(storage, dtype=self.acc_dtype) + smem_lse_scale = cute.make_tensor(lse_scale_ptr, cute.make_layout(MAX_SPLITS)) + + gLSE = mAccLSE[blk_coord[0], None, blk_coord[1], blk_coord[2]] + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + if warp_idx == 0: + # calculate the global lse and exp ^ (local_lse - global_lse) + lse_per_thread = cute.ceil_div(MAX_SPLITS, self.threads_per_warp) + + local_lse = cute.make_rmem_tensor( + cute.make_layout(lse_per_thread), self.lse_dtype + ) + lse_max = -self.lse_dtype.inf + # find the max lse + for i in cutlass.range_constexpr(lse_per_thread): + split_kv_idx = tidx + i * self.threads_per_warp + local_lse[i] = ( + gLSE[split_kv_idx] + if cute.elem_less(split_kv_idx, local_split_kv) + else -self.lse_dtype.inf + ) + # reduce the local lse + lse_max = cute.arch.fmax(lse_max, local_lse[i]) + lse_max = cute.arch.warp_reduction_max(lse_max) + lse_max = lse_max if lse_max != -self.lse_dtype.inf else 0.0 + # calculate sum_lse + sum_lse = 0.0 + for i in cutlass.range_constexpr(lse_per_thread): + sum_lse += cute.math.exp2(local_lse[i] - lse_max, fastmath=True) + sum_lse = cute.arch.warp_reduction_sum(sum_lse) + # calculate the global_lse + global_lse = ( + lse_max + cute.math.log2(sum_lse, fastmath=True) + if not sum_lse == self.lse_dtype(0.0) or sum_lse != sum_lse + else self.lse_dtype.inf + ) + if tidx == 0: + mLSE[blk_coord[0], blk_coord[1], blk_coord[2]] = global_lse + # store the scale to shared memory + for i in cutlass.range_constexpr(lse_per_thread): + split_kv_idx = tidx + i * self.threads_per_warp + if cute.elem_less(split_kv_idx, local_split_kv): + smem_lse_scale[split_kv_idx] = cute.math.exp2( + local_lse[i] - global_lse, fastmath=True + ) + + pipeline.sync(barrier_id=4) + + elements_per_thread = cute.ceil_div( + self.latent_dim, self.threads_per_warp * self.num_compute_warps + ) + gAccO = mAccO[blk_coord[0], None, None, blk_coord[1], blk_coord[2]] + rAccO = cute.make_rmem_tensor( + cute.make_layout(elements_per_thread), self.acc_dtype + ) + rO = cute.make_rmem_tensor(cute.make_layout(elements_per_thread), self.o_dtype) + rAccO.fill(0.0) + for i in range(local_split_kv): + for j in cutlass.range_constexpr(elements_per_thread): + element_idx = tidx + j * self.threads_per_warp * self.num_compute_warps + rAccO[j] += gAccO[i, element_idx] * smem_lse_scale[i] + rO.store(rAccO.load().to(self.o_dtype)) + for j in cutlass.range_constexpr(elements_per_thread): + element_idx = tidx + j * self.threads_per_warp * self.num_compute_warps + mO[blk_coord[0], element_idx, blk_coord[1], blk_coord[2]] = rO[j] + return + + @staticmethod + def get_split_kv( + B: int, S: int, K: int, mma_qk_tiler_mn: tuple, max_active_blocks: int + ) -> int: + """Get the proper split_kv value for the MLA kernel based on parameters. + + :param B: Batch size + :type B: int + :param S: Sequence length + :type S: int + :param K: Sequence length + :type K: int + :param mma_qk_tiler_mn: MLA tiling parameters + :type mma_qk_tiler_mn: tuple + :param max_active_blocks: Maximum number of active blocks + :type max_active_blocks: int + :return: Split_kv value + :rtype: int + """ + max_splits = ceil_div(K, mma_qk_tiler_mn[1]) + blocks_per_batch = max(1, max_active_blocks // B // (S * 2)) + split_heur = min(max_splits, blocks_per_batch) + # {$nv-internal-release begin} + # TODO: figure out the error of make_tile with dynamic int_tuple + # {$nv-internal-release end} + k_waves = ceil_div(max_splits, split_heur) + split_wave_aware = ceil_div(max_splits, k_waves) + max_split_kv = 32 + return min(split_wave_aware, max_split_kv) + + @cute.jit + def get_k_tile_count( + self, + split_kv: cutlass.Int32, + cache_seqs: cute.Tensor, + block_split_kvs: cute.Tensor, + blk_coord: cute.Coord, + ) -> tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32]: + """Get the current k_index, k_tile_count, and local split_kv value for the MLA kernel. + + :param split_kv: Split_kv value + :type split_kv: cutlass.Int32 + :param cache_seqs: Cache sequence lengths tensor + :type cache_seqs: cute.Tensor + :param block_split_kvs: Per-block split_kv values tensor + :type block_split_kvs: cute.Tensor + :param blk_coord: Block coordinate + :type blk_coord: cute.Coord + :return: k_index, k_tile_count, split_kv + :rtype: tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32] + """ + K = cache_seqs[blk_coord[2]] + if cutlass.const_expr(self.is_var_split_kv): + split_kv = block_split_kvs[blk_coord[2]] + + k_tile_total = cute.ceil_div(K, self.mma_qk_tiler[1]) + # {$nv-internal-release begin} + # TODO: figure out the error of make_tile with dynamic int_tuple + # {$nv-internal-release end} + k_tile_per_cta = cute.ceil_div(k_tile_total, split_kv) + k_index = blk_coord[3] * k_tile_per_cta + k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index) + return k_index, k_tile_count, split_kv + + @cute.jit + def load_tma_qk( + self, + common_params: SimpleNamespace, + qk_params: SimpleNamespace, + k_index: cutlass.Int32, + k_tile_count: cutlass.Int32, + load_q_producer_state: pipeline.PipelineState | None = None, + load_k_producer_state: pipeline.PipelineState | None = None, + ) -> tuple[pipeline.PipelineState, pipeline.PipelineState]: + """Load wrap to load Q/K tensors. Updates the load qk producer state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param qk_params: The qk parameters + :type qk_params: SimpleNamespace + :param k_index: The k index + :type k_index: cutlass.Int32 + :param k_tile_count: The k tile count + :type k_tile_count: cutlass.Int32 + :param load_q_producer_state: The load q producer state + :type load_q_producer_state: pipeline.PipelineState + :param load_k_producer_state: The load k producer state + :type load_k_producer_state: pipeline.PipelineState + + :return: The load q producer state and load k producer state + :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState] + """ + # page table + mPT = common_params.mPT[None, common_params.blk_coord[2]] + + # Flatten divide and partition global tensors for QK TMA load + # (bM, bK, rM, rK, rL) + mma_qk_tiler_mk = cute.select(self.mma_qk_tiler, mode=[0, 2]) + gQL = cute.flat_divide(qk_params.mQL, mma_qk_tiler_mk) + mma_qk_tiler_mk_rope = cute.select(self.mma_qk_rope_tiler, mode=[0, 2]) + gQR = cute.flat_divide(qk_params.mQR, mma_qk_tiler_mk_rope) + + thr_mma_qk = qk_params.tiled_mma_qk.get_slice( + common_params.blk_coord[0] % cute.size(qk_params.tiled_mma_qk.thr_id) + ) + tSgQL = thr_mma_qk.partition_A(gQL) + tSgQR = thr_mma_qk.partition_A(gQR) + + cta_m = min( + qk_params.tiled_mma_qk.op.shape_mnk[0] + // qk_params.tiled_mma_qk.thr_id.shape, + self.page_size, + ) + page_tile_size = min(self.page_size, cta_m) + gCL = cute.tiled_divide(qk_params.mCL, (page_tile_size, self.mma_qk_tiler[2])) + tSgCL = ( + gCL[ + None, + common_params.blk_coord[0] % qk_params.tiled_mma_qk.thr_id.shape, + None, + None, + ] + if cta_m < self.page_size + else gCL[None, 0, None, None] + ) + gKR = cute.tiled_divide( + qk_params.mKR, (page_tile_size, self.mma_qk_rope_tiler[2]) + ) + tSgKR = ( + gKR[ + None, + common_params.blk_coord[0] % qk_params.tiled_mma_qk.thr_id.shape, + None, + None, + ] + if cta_m < self.page_size + else gKR[None, 0, None, None] + ) + # tma partition for q, k latent/rope + + # smem: ((atom_v, rest_v), STAGE) + # gmem: ((atom_v, rest_v), RestM, RestK, RestL) + tQsQ, tQLgQL_mkl = cpasync.tma_partition( + qk_params.tma_atom_q_latent, + 0, + cute.make_layout(1), + cute.group_modes(qk_params.sQ, 0, 3), + cute.group_modes(tSgQL, 0, 3), + ) + + tQsQ_rope, tQRgQR_mkl = cpasync.tma_partition( + qk_params.tma_atom_q_rope, + 0, + cute.make_layout(1), + cute.group_modes(qk_params.sQ_rope, 0, 3), + cute.group_modes(tSgQR, 0, 3), + ) + tKCsKC, tCLgCL = cpasync.tma_partition( + qk_params.tma_atom_c_latent, + 0, + cute.make_layout(1), + qk_params.sKC, + tSgCL, + ) + + tKCsKC_rope, tKRgKR = cpasync.tma_partition( + qk_params.tma_atom_c_rope, + 0, + cute.make_layout(1), + qk_params.sKC_rope, + tSgKR, + ) + + tQLgQL = tQLgQL_mkl[ + None, None, None, common_params.blk_coord[1], common_params.blk_coord[2] + ] + tQRgQR = tQRgQR_mkl[ + None, None, None, common_params.blk_coord[1], common_params.blk_coord[2] + ] + + # set extra params + common_params.mPT = mPT + qk_params.tQLgQL = tQLgQL + qk_params.tQRgQR = tQRgQR + qk_params.tCLgCL = tCLgCL + qk_params.tKRgKR = tKRgKR + qk_params.tQsQ = tQsQ + qk_params.tQsQ_rope = tQsQ_rope + qk_params.tKCsKC = tKCsKC + qk_params.tKCsKC_rope = tKCsKC_rope + + k_tile_count_init = k_tile_count + while k_tile_count > 0: + # {$nv-internal-release begin} + # TODO: figure out how to support SingleNamespace/struct in ast + # {$nv-internal-release end} + load_q_producer_state, load_k_producer_state = self.load_tma_qk_one_k_tile( + common_params, + qk_params, + k_index, + k_tile_count, + load_q_producer_state, + load_k_producer_state, + load_q=k_tile_count_init == k_tile_count, + ) + k_index += 1 + k_tile_count -= 1 + + return load_q_producer_state, load_k_producer_state + + @cute.jit + def load_tma_v( + self, + common_params: SimpleNamespace, + v_params: SimpleNamespace, + k_index: cutlass.Int32, + k_tile_count: cutlass.Int32, + load_v_producer_state: pipeline.PipelineState, + ) -> pipeline.PipelineState: + """Load wrap to load V tensors. Updates the load v producer state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param v_params: The v parameters + :type v_params: SimpleNamespace + :param k_index: The k index + :type k_index: cutlass.Int32 + :param k_tile_count: The k tile count + :type k_tile_count: cutlass.Int32 + :param load_v_producer_state: The load v producer state + :type load_v_producer_state: pipeline.PipelineState + + :return: The load v producer state + :rtype: pipeline.PipelineState + """ + # page table + mPT = common_params.mPT[None, common_params.blk_coord[2]] + + # Flatten divide and partition global tensors for V TMA load + page_tile_size = min(self.page_size, self.mma_pv_tiler[2]) + gCLT = cute.flat_divide(v_params.mCLT, (self.mma_pv_tiler[1], page_tile_size)) + cta_n = self.mma_pv_tiler[1] // v_params.tiled_mma_pv.thr_id.shape + gCLT = cute.logical_divide(gCLT, (cta_n,))[ + (None, common_params.blk_coord[0]), None, None, None, None + ] + tOgCLT = cute.tiled_divide(gCLT, (cta_n, page_tile_size)) + tOgCLT = tOgCLT[None, 0, 0, None, None, None] + # tma partition for vc + # smem: ((atom_v, rest_v), STAGE) + # gmem: ((atom_v, rest_v), RestM, RestK, RestL) + tVCsVC, tCLTgCLT = cpasync.tma_partition( + v_params.tma_atom_c_latent_transpose, + 0, + cute.make_layout(1), + v_params.sVC, + tOgCLT, + ) + + # set extra params + common_params.mPT = mPT + v_params.tCLTgCLT = tCLTgCLT + v_params.tVCsVC = tVCsVC + + while k_tile_count > 0: + # {$nv-internal-release begin} + # TODO: figure out how to support SingleNamespace/struct in ast + # {$nv-internal-release end} + load_v_producer_state = self.load_tma_v_one_k_tile( + common_params, + v_params, + k_index, + load_v_producer_state, + ) + k_index += 1 + k_tile_count -= 1 + return load_v_producer_state + + @cute.jit + def load_tma_qk_one_k_tile( + self, + common_params: SimpleNamespace, + qk_params: SimpleNamespace, + k_index: cutlass.Int32, + k_tile_count: cutlass.Int32, + load_q_producer_state: pipeline.PipelineState, + load_k_producer_state: pipeline.PipelineState, + load_q: bool, + ) -> tuple[pipeline.PipelineState, pipeline.PipelineState]: + """Load one k-tile of Q/C latent/rope tensors. Updates the load qkv producer state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param qk_params: The qk parameters + :type qk_params: SimpleNamespace + :param k_index: The k index + :type k_index: cutlass.Int32 + :param k_tile_count: The k tile count + :type k_tile_count: cutlass.Int32 + :param load_q_producer_state: The load q producer state + :type load_q_producer_state: pipeline.PipelineState + :param load_k_producer_state: The load kv producer state + :type load_k_producer_state: pipeline.PipelineState + :param load_q: Whether to load q + :type load_q: bool + + :return: The load q producer state and load kv producer state + :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState] + """ + page_per_tile = ceil_div( + self.mma_qk_tiler[1] // self.page_size, qk_params.tiled_mma_qk.thr_id.shape + ) + k_idx = cute.make_rmem_tensor(cute.make_layout(page_per_tile), cutlass.Int32) + for i in cutlass.range_constexpr(page_per_tile): + k_idx[i] = ( + common_params.mPT[k_index] + if self.mma_qk_tiler[1] // self.page_size == 1 + else common_params.mPT[ + ( + k_index * qk_params.tiled_mma_qk.thr_id.shape + + common_params.blk_coord[0] + ) + * page_per_tile + + i + ] + ) + # load q once at first iteration + load_q_pipeline = common_params.load_q_pipeline + if load_q: + # get the mbar ptr from pipeline. + tma_bar_ptr = load_q_pipeline.producer_get_barrier(load_q_producer_state) + # expect the extra bytes for q. + load_q_pipeline.producer_acquire(load_q_producer_state) + for i in cutlass.range_constexpr(self.iterations_qk_latent): + # load q latent + cute.copy( + qk_params.tma_atom_q_latent, + qk_params.tQLgQL[None, 0, i], + qk_params.tQsQ[None, (i, 0)], + tma_bar_ptr=tma_bar_ptr, + ) + for i in cutlass.range_constexpr(self.iterations_qk_rope): + # load q rope + cute.copy( + qk_params.tma_atom_q_rope, + qk_params.tQRgQR[None, 0, i], + qk_params.tQsQ_rope[None, i], + tma_bar_ptr=tma_bar_ptr, + ) + load_q_producer_state.advance() + # get the mbar ptr from pipeline. + tma_bar_ptr = common_params.load_k_pipeline.producer_get_barrier( + load_k_producer_state + ) + common_params.load_k_pipeline.producer_acquire(load_k_producer_state) + for i in range(self.iterations_qk_latent): + for k in range(page_per_tile): + # load k latent + cute.copy( + qk_params.tma_atom_c_latent, + qk_params.tCLgCL[None, i, k_idx[k]], + qk_params.tKCsKC[None, k, 0, (i, load_k_producer_state.index)], + tma_bar_ptr=tma_bar_ptr, + ) + + for i in cutlass.range_constexpr(self.iterations_qk_rope): + for k in cutlass.range_constexpr(page_per_tile): + # load k rope + cute.copy( + qk_params.tma_atom_c_rope, + qk_params.tKRgKR[None, i, k_idx[k]], + qk_params.tKCsKC_rope[None, k, 0, load_k_producer_state.index], + tma_bar_ptr=tma_bar_ptr, + ) + load_k_producer_state.advance() + + return load_q_producer_state, load_k_producer_state + + @cute.jit + def load_tma_v_one_k_tile( + self, + common_params: SimpleNamespace, + v_params: SimpleNamespace, + k_index: cutlass.Int32, + load_v_producer_state: pipeline.PipelineState, + ) -> pipeline.PipelineState: + """Load one k-tile of compressed latent transpose tensor(v). Updates the load qkv producer state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param v_params: The load tma v parameters + :type v_params: SimpleNamespace + :param k_index: The k index + :type k_index: cutlass.Int32 + :param load_v_producer_state: The load v producer state + :type load_v_producer_state: pipeline.PipelineState + + :return: The load qkv producer state + :rtype: pipeline.PipelineState + """ + page_per_tile = self.mma_pv_tiler[2] * self.iterations_pv_k // self.page_size + page_per_subtile = ceil_div(page_per_tile, self.iterations_pv_k) + k_idx = cute.make_rmem_tensor(cute.make_layout(page_per_tile), cutlass.Int32) + for i in cutlass.range_constexpr(page_per_tile): + k_idx[i] = ( + common_params.mPT[k_index] + if page_per_tile == 1 + else common_params.mPT[k_index * page_per_tile + i] + ) + # get the mbar ptr from pipeline. + tma_bar_ptr = common_params.load_v_pipeline.producer_get_barrier( + load_v_producer_state + ) + common_params.load_v_pipeline.producer_acquire(load_v_producer_state) + for j in cutlass.range_constexpr(self.iterations_pv_n): + for i in cutlass.range_constexpr(self.iterations_pv_k): + if cutlass.const_expr(page_per_tile > 1): + for k in cutlass.range_constexpr(page_per_subtile): + k_idx_i = k_idx[k + i * page_per_subtile] + cute.copy( + v_params.tma_atom_c_latent_transpose, + v_params.tCLTgCLT[None, j, 0, k_idx_i], + v_params.tVCsVC[ + None, 0, k, ((j, i), load_v_producer_state.index) + ], + tma_bar_ptr=tma_bar_ptr, + ) + else: + cute.copy( + v_params.tma_atom_c_latent_transpose, + v_params.tCLTgCLT[None, j, i, k_idx[0]], + v_params.tVCsVC[ + None, 0, 0, ((j, i), load_v_producer_state.index) + ], + tma_bar_ptr=tma_bar_ptr, + ) + load_v_producer_state.advance() + return load_v_producer_state + + @cute.jit + def mma( + self, + common_params: SimpleNamespace, + qk_params: SimpleNamespace, + pv_params: SimpleNamespace, + k_tile_count: cutlass.Int32, + tiled_mma_qk: cute.TiledMma, + tiled_mma_pv: cute.TiledMma, + load_q_consumer_state: pipeline.PipelineState, + load_k_consumer_state: pipeline.PipelineState, + load_v_consumer_state: pipeline.PipelineState, + mma_s_producer_state: pipeline.PipelineState, + p_mma_consumer_state: pipeline.PipelineState, + mma_o_producer_state: pipeline.PipelineState, + ) -> tuple[ + cute.TiledMma, + cute.TiledMma, + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + ]: + """MMA warp to compute the result of Q*K^T and P*V. Updates the tiled mma and pipeline states. + + :param common_params: The common parameters for mma qk and pv + :type common_params: SimpleNamespace + :param qk_params: The mma qk parameters + :type qk_params: SimpleNamespace + :param pv_params: The mma pv parameters + :type pv_params: SimpleNamespace + :param k_tile_count: The k tile count + :type k_tile_count: cutlass.Int32 + :param tiled_mma_qk: The tiled mma qk + :type tiled_mma_qk: cute.TiledMma + :param tiled_mma_pv: The tiled mma pv + :type tiled_mma_pv: cute.TiledMma + :param load_q_consumer_state: The load q consumer state + :type load_q_consumer_state: pipeline.PipelineState + :param load_k_consumer_state: The load k consumer state + :type load_k_consumer_state: pipeline.PipelineState + :param load_v_consumer_state: The load v consumer state + :type load_v_consumer_state: pipeline.PipelineState + :param mma_s_producer_state: The mma s producer state + :type mma_s_producer_state: pipeline.PipelineState + :param p_mma_consumer_state: The p mma consumer state + :type p_mma_consumer_state: pipeline.PipelineState + :param mma_o_producer_state: The mma o producer state + :type mma_o_producer_state: pipeline.PipelineState + + :return: The tiled mma qk, the tiled mma pv, the load q consumer state, the load k consumer state, the load v consumer state, the mma s producer state, the p mma consumer state, and the mma o producer state + :rtype: tuple[cute.TiledMma, cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] + """ + + tSrQ = tiled_mma_qk.make_fragment_A(qk_params.sQ) + tSrQ_rope = tiled_mma_qk.make_fragment_A(qk_params.sQ_rope) + tSrKC = tiled_mma_qk.make_fragment_B(qk_params.sKC) + tSrKC_rope = tiled_mma_qk.make_fragment_B(qk_params.sKC_rope) + tOrP = tiled_mma_pv.make_fragment_A(pv_params.sP) + tOrVC = tiled_mma_pv.make_fragment_B(pv_params.sVC) + + tStS_shape = tiled_mma_qk.partition_shape_C( + cute.select(self.mma_qk_tiler, mode=[0, 1]) + ) + tStS_staged_fake = tiled_mma_qk.make_fragment_C( + cute.append(tStS_shape, self.mma_s_stage) + ) + # use real tmem ptr for tStS + tStS_staged = cute.make_tensor(common_params.tmem_ptr, tStS_staged_fake.layout) + tOtO_shape = tiled_mma_pv.partition_shape_C( + cute.select(self.mma_pv_tiler, mode=[0, 1]) + ) + # mma O has 1 stage. + tOtO = tiled_mma_pv.make_fragment_C(tOtO_shape) + tOtO_layout = cute.append( + tOtO.layout, + cute.make_layout( + common_params.L // self.mma_pv_tiler[1], + stride=self.mma_pv_tiler[1] // self.warps_in_n, + ), + ) + tOtO_staged = cute.make_tensor( + tStS_staged.iterator + self.tmem_o_offset, tOtO_layout + ) + + # set more parameters + qk_params.tSrQ = tSrQ + qk_params.tSrQ_rope = tSrQ_rope + qk_params.tSrKC = tSrKC + qk_params.tSrKC_rope = tSrKC_rope + qk_params.tStS_staged = tStS_staged + pv_params.tOrP = tOrP + pv_params.tOrVC = tOrVC + pv_params.tOtO_staged = tOtO_staged + + # mma O accumulates on K, so the accumlate flag is set to False once before all K blocks. + tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, False) + load_q_pipeline = common_params.load_q_pipeline + if common_params.is_leader_cta: + load_q_release_state = load_q_consumer_state.clone() + ( + tiled_mma_qk, + load_q_consumer_state, + load_k_consumer_state, + mma_s_producer_state, + ) = self.mma_qk( + common_params, + qk_params, + tiled_mma_qk, + load_q_consumer_state, + load_k_consumer_state, + mma_s_producer_state, + wait_q=True, + ) + k_tile_count -= 1 + + while k_tile_count > 0: + ( + tiled_mma_qk, + load_q_consumer_state, + load_k_consumer_state, + mma_s_producer_state, + ) = self.mma_qk( + common_params, + qk_params, + tiled_mma_qk, + load_q_consumer_state, + load_k_consumer_state, + mma_s_producer_state, + wait_q=False, + ) + ( + tiled_mma_pv, + load_v_consumer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) = self.mma_pv( + common_params, + pv_params, + tiled_mma_pv, + load_v_consumer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) + k_tile_count -= 1 + # release q consumer states + load_q_pipeline.consumer_release(load_q_release_state) + load_q_release_state.advance() + ( + tiled_mma_pv, + load_v_consumer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) = self.mma_pv( + common_params, + pv_params, + tiled_mma_pv, + load_v_consumer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) + + return ( + tiled_mma_qk, + tiled_mma_pv, + load_q_consumer_state, + load_k_consumer_state, + load_v_consumer_state, + mma_s_producer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) + + @cute.jit + def mma_qk( + self, + common_params: SimpleNamespace, + qk_params: SimpleNamespace, + tiled_mma_qk: cute.TiledMma, + load_q_consumer_state: pipeline.PipelineState, + load_k_consumer_state: pipeline.PipelineState, + mma_s_producer_state: pipeline.PipelineState, + wait_q: bool, + ) -> tuple[ + cute.TiledMma, + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + ]: + """Compute one k-tile of mma for Q*K^T. Updates the tiled MMA QK and pipeline states. + + :param qk_params: The qk parameters + :type qk_params: SimpleNamespace + :param tiled_mma_qk: The tiled mma qk + :type tiled_mma_qk: cute.TiledMma + :param load_q_consumer_state: The load q consumer state + :type load_q_consumer_state: pipeline.PipelineState + :param load_k_consumer_state: The load k consumer state + :type load_k_consumer_state: pipeline.PipelineState + :param mma_s_producer_state: The mma s producer state + :type mma_s_producer_state: pipeline.PipelineState + + :return: The tiled mma qk, the load q consumer state, the load k consumer state, and the mma s producer state + :rtype: tuple[cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] + """ + tStS = qk_params.tStS_staged[None, None, None, mma_s_producer_state.index] + + qk_params.mma_s_pipeline.producer_acquire(mma_s_producer_state) + tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, False) + load_q_pipeline = common_params.load_q_pipeline + load_k_pipeline = common_params.load_k_pipeline + if cutlass.const_expr(wait_q): + load_q_pipeline.consumer_wait(load_q_consumer_state) + load_k_pipeline.consumer_wait(load_k_consumer_state) + for q_stage in range(self.iterations_qk_latent): + kc_stage = load_k_consumer_state.index + for k_block in cutlass.range_constexpr(cute.size(qk_params.tSrQ.shape[2])): + cute.gemm( + tiled_mma_qk, + tStS, + qk_params.tSrQ[None, None, k_block, (q_stage, 0)], + qk_params.tSrKC[None, None, k_block, (q_stage, kc_stage)], + tStS, + ) + tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, True) + + for q_stage in range(self.iterations_qk_rope): + kc_stage = load_k_consumer_state.index + for k_block in cutlass.range_constexpr( + self.rope_dim // tiled_mma_qk.shape_mnk[2] + ): + cute.gemm( + tiled_mma_qk, + tStS, + qk_params.tSrQ_rope[None, None, k_block, q_stage], + qk_params.tSrKC_rope[None, None, k_block, kc_stage], + tStS, + ) + tiled_mma_qk.set(tcgen05.Field.ACCUMULATE, True) + load_k_pipeline.consumer_release(load_k_consumer_state) + load_k_consumer_state.advance() + if cutlass.const_expr(wait_q): + load_q_consumer_state.advance() + + qk_params.mma_s_pipeline.producer_commit(mma_s_producer_state) + mma_s_producer_state.advance() + return ( + tiled_mma_qk, + load_q_consumer_state, + load_k_consumer_state, + mma_s_producer_state, + ) + + @cute.jit + def mma_pv( + self, + common_params: SimpleNamespace, + pv_params: SimpleNamespace, + tiled_mma_pv: cute.TiledMma, + load_v_consumer_state: pipeline.PipelineState, + p_mma_consumer_state: pipeline.PipelineState, + mma_o_producer_state: pipeline.PipelineState, + ) -> tuple[ + cute.TiledMma, + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + ]: + """Compute one k-tile of mma for P*V. Updates the tiled mma pv and pipeline states. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param pv_params: The pv parameters + :type pv_params: SimpleNamespace + :param tiled_mma_pv: The tiled mma pv + :type tiled_mma_pv: cute.TiledMma + :param load_v_consumer_state: The load v consumer state + :type load_v_consumer_state: pipeline.PipelineState + :param p_mma_consumer_state: The P MMA consumer state + :type p_mma_consumer_state: pipeline.PipelineState + :param mma_o_producer_state: The MMA o producer state + :type mma_o_producer_state: pipeline.PipelineState + + :return: The tiled mma pv, the load v consumer state, the P MMA consumer state, and the MMA o producer state + :rtype: tuple[cute.TiledMma, pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] + """ + + pv_params.p_mma_pipeline.consumer_wait(p_mma_consumer_state) + load_v_pipeline = common_params.load_v_pipeline + accumulate_flag = tiled_mma_pv.get(tcgen05.Field.ACCUMULATE) + mma_o_pipeline = pv_params.mma_o_pipeline + + load_v_pipeline.consumer_wait(load_v_consumer_state) + vc_stage = load_v_consumer_state.index + for acc_stage in range(self.iterations_pv_n): + mma_o_pipeline.producer_acquire(mma_o_producer_state) + tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, accumulate_flag) + for p_stage in range(self.iterations_pv_k): + tOtO = pv_params.tOtO_staged[None, None, None, acc_stage] + for k_block in cutlass.range_constexpr(pv_params.tOrP.shape[2]): + cute.gemm( + tiled_mma_pv, + tOtO, + pv_params.tOrP[ + None, + None, + k_block, + (p_stage, p_mma_consumer_state.index), + ], + pv_params.tOrVC[ + None, None, k_block, ((acc_stage, p_stage), vc_stage) + ], + tOtO, + ) + tiled_mma_pv.set(tcgen05.Field.ACCUMULATE, True) + + mma_o_pipeline.producer_commit(mma_o_producer_state) + mma_o_producer_state.advance() + load_v_pipeline.consumer_release(load_v_consumer_state) + load_v_consumer_state.advance() + pv_params.p_mma_pipeline.consumer_release(p_mma_consumer_state) + p_mma_consumer_state.advance() + + return ( + tiled_mma_pv, + load_v_consumer_state, + p_mma_consumer_state, + mma_o_producer_state, + ) + + @cute.jit + def compute( + self, + common_params: SimpleNamespace, + softmax_params: SimpleNamespace, + k_index: cutlass.Int32, + k_tile_count: cutlass.Int32, + mma_s_consumer_state: pipeline.PipelineState, + p_mma_producer_state: pipeline.PipelineState, + p_cor_producer_state: pipeline.PipelineState, + ) -> tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState]: + """Compute warp to compute the result of softmax, rescale, and epilogue. Updates the related pipeline states. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param softmax_params: The softmax parameters + :type softmax_params: SimpleNamespace + :param k_index: The index of the k-tile + :type k_index: cutlass.Int32 + :param k_tile_count: The number of k-tiles + :type k_tile_count: cutlass.Int32 + :param mma_s_consumer_state: The MMA s consumer state + :type mma_s_consumer_state: pipeline.PipelineState + :param p_mma_producer_state: The P MMA producer state + :type p_mma_producer_state: pipeline.PipelineState + :param p_cor_producer_state: The P correction producer state + :type p_cor_producer_state: pipeline.PipelineState + + :return: The MMA s consumer state, the P MMA producer state, and the P correction producer state + :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState] + """ + + k_tile_total = cute.ceil_div(common_params.K, self.mma_qk_tiler[1]) + + row_max = -self.acc_dtype.inf + row_sum = self.acc_dtype(0) + correction_factor = self.acc_dtype(1) + common_params.p_cor_pipeline.producer_acquire(p_cor_producer_state) + + # no mask applied + while k_tile_count > 1: + ( + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max, + row_sum, + correction_factor, + ) = self.softmax( + common_params, + softmax_params, + k_index, + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max, + row_sum, + correction_factor, + False, + False, + ) + k_index = k_index + 1 + k_tile_count = k_tile_count - 1 + + # mask applied + if cutlass.const_expr(common_params.mAccO is not None): + ( + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max, + row_sum, + correction_factor, + ) = self.softmax( + common_params, + softmax_params, + k_index, + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max, + row_sum, + correction_factor, + k_index == k_tile_total - 1, + True, + ) + else: + ( + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max, + row_sum, + correction_factor, + ) = self.softmax( + common_params, + softmax_params, + k_index, + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max, + row_sum, + correction_factor, + True, + True, + ) + + return mma_s_consumer_state, p_mma_producer_state, p_cor_producer_state + + @cute.jit + def correction( + self, + common_params: SimpleNamespace, + epilogue_params: SimpleNamespace, + k_tile_count: cutlass.Int32, + p_cor_consumer_state: pipeline.PipelineState, + mma_o_consumer_state: pipeline.PipelineState, + ) -> tuple[pipeline.PipelineState, pipeline.PipelineState]: + """Compute warp to compute the result of softmax, rescale, and epilogue. Updates the related pipeline states. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param epilogue_params: The epilogue parameters + :type epilogue_params: SimpleNamespace + :param k_index: The index of the k-tile + :type k_index: cutlass.Int32 + :param k_tile_count: The number of k-tiles + :type k_tile_count: cutlass.Int32 + :param p_cor_consumer_state: The P correction consumer state + :type p_cor_consumer_state: pipeline.PipelineState + :param mma_o_consumer_state: The MMA o consumer state + :type mma_o_consumer_state: pipeline.PipelineState + + :return: The P correction consumer state, and the MMA o consumer state + :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState] + """ + + k_tile_count_init = k_tile_count + while k_tile_count > 0: + p_cor_consumer_state, row_sum, row_max, correction_factor, no_correction = ( + self.get_correction_factor(common_params, p_cor_consumer_state) + ) + if k_tile_count_init != k_tile_count: + mma_o_consumer_state = self.rescale( + common_params, + mma_o_consumer_state, + correction_factor, + no_correction, + ) + k_tile_count = k_tile_count - 1 + if k_tile_count == 0: + mma_o_consumer_state = self.epilogue( + common_params, + epilogue_params, + mma_o_consumer_state, + row_sum, + row_max, + ) + return p_cor_consumer_state, mma_o_consumer_state + + @cute.jit + def exchange_p_cor_metadata( + self, + common_params: SimpleNamespace, + softmax_params: SimpleNamespace, + correction_factor: cutlass.Float32, + row_sum: cutlass.Float32, + row_max: cutlass.Float32, + row_max_new: cutlass.Float32, + tAcc: cute.Tensor, + tidx: cutlass.Int32, + p_cor_producer_state: pipeline.PipelineState, + ) -> tuple[pipeline.PipelineState, cutlass.Float32]: + """Compute the correction factor for the last k tile.""" + no_correction = 0 + if ( + row_max_new - row_max + ) * softmax_params.softmax_scale_log2 <= self.skip_correction_threshold: + no_correction = 1 + row_max_new = row_max + + # pad for 4x32b + corr_layout = cute.make_layout( + (tAcc.shape[0], (4, tAcc.shape[1][1]), self.mma_s_stage), + stride=(tAcc.stride[0], (1, tAcc.stride[1][1]), 4), + ) + tCor = cute.make_tensor( + common_params.tmem_ptr + self.correction_factor_offset, + corr_layout, + ) + cCor = cute.make_identity_tensor(tCor.shape) + corr_tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(4)), self.acc_dtype + ) + corr_tmem_store_tiled_copy = tcgen05.make_tmem_copy(corr_tmem_store_atom, tCor) + corr_tmem_store_thr_copy = corr_tmem_store_tiled_copy.get_slice(tidx) + cCor_for_copy = corr_tmem_store_thr_copy.partition_S(cCor) + tCor_for_copy = corr_tmem_store_thr_copy.partition_D(tCor) + rCor = cute.make_fragment_like( + cCor_for_copy[None, None, None, 0], self.acc_dtype + ) + rCor_int = cute.make_tensor( + cute.recast_ptr(rCor.iterator, dtype=cutlass.Int32), rCor.layout + ) + rCor[0] = row_sum + rCor[1] = row_max_new + rCor[2] = correction_factor + rCor_int[3] = no_correction + + cute.copy( + corr_tmem_store_tiled_copy, + rCor, + tCor_for_copy[None, None, None, p_cor_producer_state.index], + ) + # fence between tmem store and correction warp + cute.arch.fence_view_async_tmem_store() + common_params.p_cor_pipeline.producer_commit(p_cor_producer_state) + p_cor_producer_state.advance() + return p_cor_producer_state, row_max_new + + @cute.jit + def softmax( + self, + common_params: SimpleNamespace, + softmax_params: SimpleNamespace, + k_index: cutlass.Int32, + mma_s_consumer_state: pipeline.PipelineState, + p_mma_producer_state: pipeline.PipelineState, + p_cor_producer_state: pipeline.PipelineState, + row_max: cutlass.Float32, + row_sum: cutlass.Float32, + correction_factor: cutlass.Float32, + is_last_tile: bool, + is_local_last_tile: cutlass.Boolean, + ) -> tuple[ + pipeline.PipelineState, + pipeline.PipelineState, + pipeline.PipelineState, + cutlass.Float32, + cutlass.Float32, + cutlass.Float32, + ]: + """Softmax for one k-tile. Updates the related pipeline states and returns the computed results. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param softmax_params: The softmax parameters + :type softmax_params: SimpleNamespace + :param k_index: The index of the k-tile + :type k_index: cutlass.Int32 + :param mma_s_consumer_state: The MMA s consumer state + :type mma_s_consumer_state: pipeline.PipelineState + :param p_mma_producer_state: The P MMA producer state + :type p_mma_producer_state: pipeline.PipelineState + :param p_cor_producer_state: The P correction producer state + :type p_cor_producer_state: pipeline.PipelineState + :param row_max: The row max + :type row_max: cutlass.Float32 + :param row_sum: The row sum + :type row_sum: cutlass.Float32 + :param correction_factor: The correction factor + :type correction_factor: cutlass.Float32 + :param is_last_tile: Whether the last tile + :type is_last_tile: bool + :param is_local_last_tile: Whether the last tile is local + :type is_local_last_tile: cutlass.Boolean + + :return: The MMA s consumer state, the P MMA producer state, the P correction producer state, the row max, the row sum, and the correction factor + :rtype: tuple[pipeline.PipelineState, pipeline.PipelineState, pipeline.PipelineState, cutlass.Float32, cutlass.Float32, cutlass.Float32] + """ + + softmax_params.p_mma_pipeline.producer_acquire(p_mma_producer_state) + softmax_params.mma_s_pipeline.consumer_wait(mma_s_consumer_state) + + # load S from tmem + tStS_shape = softmax_params.tiled_mma_qk.partition_shape_C( + cute.select(self.mma_qk_tiler, mode=[0, 1]) + ) + tStS_staged_fake = softmax_params.tiled_mma_qk.make_fragment_C( + cute.append(tStS_shape, self.mma_s_stage) + ) + tStS_staged = cute.make_tensor(common_params.tmem_ptr, tStS_staged_fake.layout) + tStS = tStS_staged[None, None, None, mma_s_consumer_state.index] + + tAcc = tStS[(None, None), 0, 0] + cta_qk_tiler = ( + self.mma_qk_tiler[0] // self.cluster_shape_mnk[0], + self.mma_qk_tiler[1], + self.mma_qk_tiler[2], + ) + cS = cute.make_identity_tensor(cute.select(cta_qk_tiler, mode=[0, 1])) + + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype + ) + tmem_tiled_copy = tcgen05.make_tmem_copy(tmem_load_atom, tAcc) + + tidx = common_params.tidx % (self.num_compute_warps * self.threads_per_warp) + + tmem_thr_copy = tmem_tiled_copy.get_slice(tidx) + tTR_tAcc = tmem_thr_copy.partition_S(tAcc) + tTR_tS = tmem_thr_copy.partition_D(cS) + + tTR_rAcc = cute.make_fragment_like(tTR_tS, self.acc_dtype) + + row_max_new = row_max + arch = BaseDSL._get_dsl().get_arch_enum() + if cutlass.const_expr(arch >= Arch.sm_100 and arch <= Arch.sm_100f): + cute.copy(tmem_tiled_copy, tTR_tAcc, tTR_rAcc) + for i in cutlass.range_constexpr(cute.size(tTR_rAcc)): + if is_last_tile: + tTR_rAcc[i] = ( + tTR_rAcc[i] + if cute.elem_less( + tTR_tS[i][1] + self.mma_qk_tiler[1] * k_index, + common_params.K, + ) + else -self.acc_dtype.inf + ) + # reduction for row_max + row_max_new = tTR_rAcc.load().reduce(cute.ReductionOp.MAX, row_max_new, 0) + elif cutlass.const_expr(arch >= Arch.sm_103 and arch <= Arch.sm_103f): + tmem_load_red_atom = cute.make_copy_atom( + tcgen05.copy.LdRed32x32bOp( + tcgen05.copy.Repetition(64), redOp=tcgen05.TmemLoadRedOp.MAX + ), + self.acc_dtype, + ) + tmem_red_tiled_copy = tcgen05.make_tmem_copy(tmem_load_red_atom, tAcc) + tmem_red_thr_copy = tmem_red_tiled_copy.get_slice(tidx) + tTR_tAcc_red = tmem_red_thr_copy.partition_S(tAcc) + tTR_tS_red = tmem_red_thr_copy.partition_D(cS) + tTR_rAcc_red = cute.make_fragment_like(tTR_tS_red, self.acc_dtype) + tTR_rMax = cute.make_rmem_tensor( + cute.make_layout((1, tTR_tS_red.shape[1], tTR_tS_red.shape[2])), + self.acc_dtype, + ) + cute.copy( + tmem_red_tiled_copy, + tTR_tAcc_red, + (tTR_rAcc_red, tTR_rMax), + ) + tTR_rAcc = cute.make_tensor(tTR_rAcc_red.iterator, tTR_rAcc.layout) + if is_last_tile: + for i in cutlass.range_constexpr(cute.size(tTR_rAcc)): + tTR_rAcc[i] = ( + tTR_rAcc[i] + if cute.elem_less( + tTR_tS[i][1] + self.mma_qk_tiler[1] * k_index, + common_params.K, + ) + else -self.acc_dtype.inf + ) + # reduction for row_max + row_max_new = tTR_rAcc.load().reduce( + cute.ReductionOp.MAX, row_max_new, 0 + ) + else: + row_max_new = cute.arch.fmax(row_max_new, tTR_rMax[0]) + + # if warps in N is 2, reduce row_max across warps (0, 1) and (2, 3) + if cutlass.const_expr(self.warps_in_n == 2): + common_params.smem_exchange[tidx] = row_max_new + self.softmax_exchange_sync_bar.wait() + row_max_new = cute.arch.fmax( + row_max_new, + common_params.smem_exchange[ + (tidx + 64) % (self.num_compute_warps * self.threads_per_warp) + ], + ) + + # find correction factor + correction_factor = cute.math.exp2( + (row_max - row_max_new) * softmax_params.softmax_scale_log2, fastmath=True + ) + # split kv case + if cutlass.const_expr(not is_local_last_tile): + p_cor_producer_state, row_max_new = self.exchange_p_cor_metadata( + common_params, + softmax_params, + correction_factor, + row_sum, + row_max, + row_max_new, + tAcc, + tidx, + p_cor_producer_state, + ) + + # softmax + fma_b = softmax_params.softmax_scale_log2 + fma_c = (0.0 - row_max_new) * softmax_params.softmax_scale_log2 + + for i in cutlass.range(cute.size(tTR_rAcc), vectorize=True, unroll_full=True): + tTR_rAcc[i] = tTR_rAcc[i] * fma_b + fma_c + tTR_rAcc[i] = cute.math.exp2(tTR_rAcc[i], fastmath=True) + + tTR_rS = cute.make_fragment_like(tTR_tS, self.q_dtype) + + # quantize + tTR_rS.store(tTR_rAcc.load().to(self.q_dtype)) + + # create sP + sP = softmax_params.sP[None, None, None, (None, p_mma_producer_state.index)] + sP_mk_view = cute.make_tensor( + sP.iterator, + cute.make_layout( + ( + (sP.shape[0][0], sP.shape[1]), + (sP.shape[0][1], sP.shape[2], sP.shape[3]), + ), + stride=( + (sP.stride[0][0], sP.stride[1]), + (sP.stride[0][1], sP.stride[2], sP.stride[3]), + ), + ), + ) + # {$nv-internal-release begin} + # TODO: figure out if we could use A tmem for pv. + # {$nv-internal-release end} + # change to PISL + sP_wo_swizzle_iter = cute.recast_ptr(sP.iterator, swizzle_=None) + swizzle_bits = ( + int(math.log2(self.mma_pv_tiler[2] * self.q_dtype.width // 8 // 32)) + 1 + ) + swizzle_base = 3 if self.q_dtype.width == 16 else 4 + sP_swizzle = cute.make_swizzle(swizzle_bits, swizzle_base, 3) + sP_mk_view = cute.make_tensor( + sP_wo_swizzle_iter, + cute.make_composed_layout(sP_swizzle, 0, sP_mk_view.layout), + ) + universal_copy_bits = 128 + smem_copy_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.q_dtype, + num_bits_per_copy=universal_copy_bits, + ) + smem_tiled_copy = cute.make_tiled_copy_D(smem_copy_atom, tmem_tiled_copy) + smem_thr_copy = smem_tiled_copy.get_slice(tidx) + rP_copy_view = smem_thr_copy.retile(tTR_rS) + sP_copy_view = smem_thr_copy.partition_D(sP_mk_view) + cute.copy(smem_tiled_copy, rP_copy_view, sP_copy_view) + + # fence between smem store and mma o + cute.arch.fence_view_async_shared() + softmax_params.p_mma_pipeline.producer_commit(p_mma_producer_state) + p_mma_producer_state.advance() + + # row_sum, using `add_packed_f32x2` to reduce the number of instructions + row_sum = row_sum * correction_factor + row_sum_vec = (0.0, 0.0) + for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc), 2): + row_sum_vec = cute.arch.add_packed_f32x2( + row_sum_vec, (tTR_rAcc[i], tTR_rAcc[i + 1]) + ) + row_sum = row_sum_vec[0] + row_sum_vec[1] + row_sum + + # split kv case + if cutlass.const_expr(is_local_last_tile): + p_cor_producer_state, row_max_new = self.exchange_p_cor_metadata( + common_params, + softmax_params, + correction_factor, + row_sum, + row_max, + row_max_new, + tAcc, + tidx, + p_cor_producer_state, + ) + + # store correction factor/row_sum/row_max to tmem for correction warp + common_params.p_cor_pipeline.producer_acquire(p_cor_producer_state) + + # fence between tmem load and mma s + cute.arch.fence_view_async_tmem_load() + + softmax_params.mma_s_pipeline.consumer_release(mma_s_consumer_state) + mma_s_consumer_state.advance() + + return ( + mma_s_consumer_state, + p_mma_producer_state, + p_cor_producer_state, + row_max_new, + row_sum, + correction_factor, + ) + + @cute.jit + def _tmem_load_partition( + self, common_params: SimpleNamespace, tiled_mma_pv: cute.TiledMma, iter_n: int + ) -> tuple[ + cute.TiledMma, cute.TiledMma, cute.TiledMma, cute.TiledMma, cute.TiledMma + ]: + """Tensor memory load partition for rescale and epilogue. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param tiled_mma_pv: The tiled mma pv + :type tiled_mma_pv: cute.TiledMma + :param iter_n: The iteration number + :type iter_n: int + + :return: The tiled mma pv, the tiled mma pv, the tiled mma pv, the tiled mma pv, the tiled mma pv + :rtype: tuple[cute.TiledMma, cute.TiledMma, cute.TiledMma, cute.TiledMma, cute.TiledMma] + """ + + tOtO_shape = tiled_mma_pv.partition_shape_C( + cute.select(self.mma_pv_tiler, mode=[0, 1]) + ) + tOtO = tiled_mma_pv.make_fragment_C(tOtO_shape) + tOtO_layout = cute.append( + tOtO.layout, + cute.make_layout( + common_params.L // self.mma_pv_tiler[1], + stride=self.mma_pv_tiler[1] // self.warps_in_n, + ), + ) + tOtO = cute.make_tensor( + common_params.tmem_ptr + self.tmem_o_offset, tOtO_layout + ) + tOtO = tOtO[None, None, None, iter_n] + + tAcc = tOtO[(None, None), 0, 0] + + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype + ) + tmem_load_tiled_copy = tcgen05.make_tmem_copy(tmem_load_atom, tAcc) + # {$nv-internal-release begin} + # TODO: supports size() on tiled copy. + # {$nv-internal-release end} + tmem_load_thr_copy = tmem_load_tiled_copy.get_slice( + common_params.tidx % (self.num_compute_warps * self.threads_per_warp) + ) + + cta_pv_tiler = ( + self.mma_pv_tiler[0] // self.cluster_shape_mnk[0], + self.mma_pv_tiler[1], + self.mma_pv_tiler[2], + ) + # Flatten divide and partition global tensors for O + cta_pv_tiler_mn = cute.select(cta_pv_tiler, mode=[0, 1]) + + gO = None + if cutlass.const_expr(common_params.mAccO is not None): + gO = cute.local_tile( + common_params.mAccO[None, common_params.blk_coord[3], None, None, None], + cta_pv_tiler_mn, + ( + common_params.blk_coord[0], + iter_n, + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + ) + cO = cute.local_tile( + cute.make_identity_tensor( + common_params.mAccO[ + None, common_params.blk_coord[3], None, None, None + ].shape + ), + cta_pv_tiler_mn, + ( + common_params.blk_coord[0], + iter_n, + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + ) + else: + gO = cute.local_tile( + common_params.mO, + cta_pv_tiler_mn, + ( + common_params.blk_coord[0], + iter_n, + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + ) + cO = cute.local_tile( + cute.make_identity_tensor(common_params.mO.shape), + cta_pv_tiler_mn, + ( + common_params.blk_coord[0], + iter_n, + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + ) + tTR_tAcc = tmem_load_thr_copy.partition_S(tAcc) + tTR_gO = tmem_load_thr_copy.partition_D(gO) + tTR_cO = tmem_load_thr_copy.partition_D(cO) + tTR_rAcc = cute.make_fragment_like(tTR_gO, self.acc_dtype) + return tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc + + def get_correction_factor( + self, + common_params: SimpleNamespace, + p_cor_consumer_state: pipeline.PipelineState, + ) -> tuple[ + pipeline.PipelineState, + cutlass.Float32, + cutlass.Float32, + cutlass.Float32, + cutlass.Int32, + ]: + """Get the correction factor from the P correction consumer state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param p_cor_consumer_state: The P correction consumer state + :type p_cor_consumer_state: pipeline.PipelineState + + :return: The P correction consumer state, the row_sum, the row_max, and the correction factor + :rtype: tuple[pipeline.PipelineState, cutlass.Float32, cutlass.Float32, cutlass.Float32, cutlass.Int32] + """ + common_params.p_cor_pipeline.consumer_wait(p_cor_consumer_state) + tidx = common_params.tidx % (self.num_compute_warps * self.threads_per_warp) + # load correction factor + _, tAcc, _, _, _, _ = self._tmem_load_partition( + common_params, common_params.tiled_mma_pv, 0 + ) + corr_layout = cute.make_layout( + (tAcc.shape[0], (4, tAcc.shape[1][1]), self.p_cor_stage), + stride=(tAcc.stride[0], (1, tAcc.stride[1][1]), 4), + ) + tCor = cute.make_tensor( + common_params.tmem_ptr + self.correction_factor_offset, corr_layout + ) + cCor = cute.make_identity_tensor(tCor.shape) + corr_tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(4)), self.acc_dtype + ) + corr_tmem_load_tiled_copy = tcgen05.make_tmem_copy(corr_tmem_load_atom, tCor) + corr_tmem_load_thr_copy = corr_tmem_load_tiled_copy.get_slice(tidx) + tCor_for_copy = corr_tmem_load_thr_copy.partition_S(tCor) + cCor_for_copy = corr_tmem_load_thr_copy.partition_D(cCor) + rCor = cute.make_fragment_like( + cCor_for_copy[None, None, None, 0], self.acc_dtype + ) + rCor_int = cute.make_tensor( + cute.recast_ptr(rCor.iterator, dtype=cutlass.Int32), rCor.layout + ) + cute.copy( + corr_tmem_load_tiled_copy, + tCor_for_copy[None, None, None, p_cor_consumer_state.index], + rCor, + ) + row_sum = rCor[0] + row_max = rCor[1] + correction_factor = rCor[2] + no_correction = rCor_int[3] + + common_params.p_cor_pipeline.consumer_release(p_cor_consumer_state) + p_cor_consumer_state.advance() + return p_cor_consumer_state, row_sum, row_max, correction_factor, no_correction + + @cute.jit + def rescale( + self, + common_params: SimpleNamespace, + mma_o_consumer_state: pipeline.PipelineState, + correction_factor: cutlass.Float32, + no_correction: cutlass.Int32, + ) -> pipeline.PipelineState: + """Rescale for one k-tile. Updates the related pipeline state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param mma_o_consumer_state: The mma o consumer state + :type mma_o_consumer_state: pipeline.PipelineState + :param correction_factor: The correction factor + :type correction_factor: cutlass.Float32 + :param no_correction: Whether to apply correction factor + :type no_correction: cutlass.Int32 + + :return: The MMA o consumer state + :rtype: pipeline.PipelineState + """ + skip_correction = cute.arch.vote_all_sync(no_correction == 1) + for iter_n in cutlass.range_constexpr(self.iterations_pv_n): + common_params.mma_o_pipeline.consumer_wait(mma_o_consumer_state) + if not skip_correction: + # tmem load tiled copy and partition results. + tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc = ( + self._tmem_load_partition( + common_params, common_params.tiled_mma_pv, iter_n + ) + ) + + # tmem store tiled copy + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.acc_dtype + ) + tmem_store_tiled_copy = tcgen05.make_tmem_copy(tmem_store_atom, tAcc) + + # load o + cute.copy(tmem_load_tiled_copy, tTR_tAcc, tTR_rAcc) + # rescale, using `mul_packed_f32x2` to reduce the number of instructions + for i in cutlass.range( + cute.size(tTR_rAcc), vectorize=True, unroll_full=True + ): + tTR_rAcc[i] = tTR_rAcc[i] * correction_factor + + # store o to tensor memory for next k tile + cute.copy(tmem_store_tiled_copy, tTR_rAcc, tTR_tAcc) + + cute.arch.fence_view_async_tmem_store() + common_params.mma_o_pipeline.consumer_release(mma_o_consumer_state) + mma_o_consumer_state.advance() + + return mma_o_consumer_state + + @cute.jit + def epilogue( + self, + common_params: SimpleNamespace, + epilogue_params: SimpleNamespace, + mma_o_consumer_state: pipeline.PipelineState, + row_sum: cutlass.Float32, + row_max: cutlass.Float32, + ) -> pipeline.PipelineState: + """Epilogue for one k-tile. Updates the related pipeline state. + + :param common_params: The common parameters + :type common_params: SimpleNamespace + :param epilogue_params: The epilogue parameters + :type epilogue_params: SimpleNamespace + :param mma_o_consumer_state: The mma o consumer state + :type mma_o_consumer_state: pipeline.PipelineState + :param row_sum: The row sum + :type row_sum: cutlass.Float32 + :param row_max: The row max + :type row_max: cutlass.Float32 + + :return: The MMA o consumer state + :rtype: pipeline.PipelineState + """ + + tidx = common_params.tidx % (self.num_compute_warps * self.threads_per_warp) + + # exchange row_sum between warps (0, 1) and (2, 3) + if cutlass.const_expr(self.warps_in_n == 2): + common_params.smem_exchange[tidx] = row_sum + self.epilogue_exchange_sync_bar.wait() + # (64, 2) + row_sum = ( + row_sum + + common_params.smem_exchange[ + (tidx + 64) % (self.num_compute_warps * self.threads_per_warp) + ] + ) + # mma_o pipeline consumer wait + for iter_n in cutlass.range_constexpr(self.iterations_pv_n): + common_params.mma_o_pipeline.consumer_wait(mma_o_consumer_state) + # tmem load tiled copy and partition results. + tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc = ( + self._tmem_load_partition( + common_params, common_params.tiled_mma_pv, iter_n + ) + ) + + # load o + cute.copy(tmem_load_tiled_copy, tTR_tAcc, tTR_rAcc) + + # apply output scale and normalize by row_sum + for i in cutlass.range( + cute.size(tTR_rAcc), vectorize=True, unroll_full=True + ): + tTR_rAcc[i] = ( + tTR_rAcc[i] + * epilogue_params.output_scale + * cute.arch.rcp_approx(row_sum) + ) + + # store o to global memory + tR2G_rO_src = None + tR2G_rO_dst = tTR_gO + if cutlass.const_expr(common_params.mAccO is None): + tR2G_rO_src = cute.make_fragment_like(tTR_gO, self.o_dtype) + # using final output dtype for o + tR2G_rO_src.store(tTR_rAcc.load().to(self.o_dtype)) + else: + # using accumulate dtype for o + tR2G_rO_src = tTR_rAcc + + if cute.elem_less(tTR_cO[0][0], common_params.H): + cute.autovec_copy( + tR2G_rO_src, + tR2G_rO_dst, + l1c_evict_priority=cute.nvgpu.CacheEvictionPriority.NO_ALLOCATE, + ) + + # store the lse to global memory + cta_pv_tiler = ( + self.mma_pv_tiler[0] // self.cluster_shape_mnk[0], + self.mma_pv_tiler[1], + self.mma_pv_tiler[2], + ) + gLSE = None + cLSE = None + if cutlass.const_expr(epilogue_params.mAccLSE is None): + gLSE = cute.local_tile( + epilogue_params.mLSE, + (cta_pv_tiler[0], 1, 1), + ( + common_params.blk_coord[0], + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + (1, 1, 1), + ) + cLSE = cute.local_tile( + cute.make_identity_tensor(epilogue_params.mLSE.shape), + (cta_pv_tiler[0], 1, 1), + ( + common_params.blk_coord[0], + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + (1, 1, 1), + ) + + else: + gLSE = cute.local_tile( + epilogue_params.mAccLSE[ + None, common_params.blk_coord[3], None, None + ], + (cta_pv_tiler[0], 1, 1), + ( + common_params.blk_coord[0], + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + (1, 1, 1), + ) + cLSE = cute.local_tile( + cute.make_identity_tensor( + epilogue_params.mAccLSE[ + None, common_params.blk_coord[3], None, None + ].shape + ), + (cta_pv_tiler[0], 1, 1), + ( + common_params.blk_coord[0], + common_params.blk_coord[1], + common_params.blk_coord[2], + ), + (1, 1, 1), + ) + lse = ( + cute.math.log2(row_sum, fastmath=True) + + epilogue_params.softmax_scale_log2 * row_max + ) + if cutlass.const_expr(self.warps_in_n == 2): + if cute.elem_less(cLSE[tidx][0], common_params.H): + gLSE[tidx] = lse + + cute.arch.fence_view_async_tmem_load() + common_params.mma_o_pipeline.consumer_release(mma_o_consumer_state) + mma_o_consumer_state.advance() + + return mma_o_consumer_state + + def make_and_init_load_qkv_pipeline( + self, load_qkv_mbar_ptr, cta_layout_vmnk, load_stages, tx_count + ) -> pipeline.PipelineTmaUmma: + """Create and initialize the tma load qkv pipeline. + + :param load_qkv_mbar_ptr: The load qkv mbar pointer + :type load_qkv_mbar_ptr: cute.Tensor + :param cta_layout_vmnk: The cta layout vmnk + :type cta_layout_vmnk: tuple[int, int, int] + :param load_stages: The load stages + :type load_stages: list[int] + :param tx_count: The tx count + :type tx_count: int + + :return: The tma load qkv pipeline + :rtype: pipeline.PipelineTmaUmma + """ + load_qkv_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.load_tma_k_warp_id]) + ) + load_qkv_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + return pipeline.PipelineTmaUmma.create( + barrier_storage=load_qkv_mbar_ptr, + num_stages=load_stages, + producer_group=load_qkv_producer_group, + consumer_group=load_qkv_consumer_group, + tx_count=tx_count, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + + def make_and_init_mma_s_pipeline( + self, mma_s_mbar_ptr, cta_layout_vmnk + ) -> pipeline.PipelineUmmaAsync: + """Create and initialize the mma s pipeline. + + :param mma_s_mbar_ptr: The mma s mbar pointer + :type mma_s_mbar_ptr: cute.Tensor + :param cta_layout_vmnk: The cta layout vmnk + :type cta_layout_vmnk: tuple[int, int, int] + + :return: The mma s pipeline + :rtype: pipeline.PipelineUmmaAsync + """ + + mma_s_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + consumer_thread_size = ( + self.threads_per_warp + * len(self.compute_warp_ids) + * self.cluster_shape_mnk[0] + ) + mma_s_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + consumer_thread_size, + ) + return pipeline.PipelineUmmaAsync.create( + barrier_storage=mma_s_mbar_ptr, + num_stages=self.mma_s_stage, + producer_group=mma_s_producer_group, + consumer_group=mma_s_consumer_group, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + + def make_and_init_p_mma_pipeline( + self, p_mma_mbar_ptr, cta_layout_vmnk + ) -> pipeline.PipelineAsyncUmma: + """Create and initialize the p mma pipeline. + + :param p_mma_mbar_ptr: The p mma mbar pointer + :type p_mma_mbar_ptr: cute.Tensor + :param cta_layout_vmnk: The cta layout vmnk + :type cta_layout_vmnk: tuple[int, int, int] + + :return: The p mma pipeline + :rtype: pipeline.PipelineAsyncUmma + """ + + producer_thread_size = ( + self.threads_per_warp + * len(self.compute_warp_ids) + * self.cluster_shape_mnk[0] + ) + p_mma_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + producer_thread_size, + ) + p_mma_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + return pipeline.PipelineAsyncUmma.create( + barrier_storage=p_mma_mbar_ptr, + num_stages=self.p_mma_stage, + producer_group=p_mma_producer_group, + consumer_group=p_mma_consumer_group, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + + def make_and_init_p_cor_pipeline( + self, p_cor_mbar_ptr + ) -> pipeline.PipelineAsyncUmma: + """Create and initialize the p correction pipeline. + + :param p_cor_mbar_ptr: The p correction mbar pointer + :type p_cor_mbar_ptr: cute.Tensor + + :return: The p correction pipeline + :rtype: pipeline.PipelineAsyncUmma + """ + + producer_thread_size = self.threads_per_warp * len(self.compute_warp_ids) + p_cor_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + producer_thread_size, + ) + p_cor_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + producer_thread_size, + ) + return pipeline.PipelineAsync.create( + barrier_storage=p_cor_mbar_ptr, + num_stages=self.p_cor_stage, + producer_group=p_cor_producer_group, + consumer_group=p_cor_consumer_group, + defer_sync=True, + ) + + def make_and_init_mma_o_pipeline( + self, mma_o_mbar_ptr, cta_layout_vmnk + ) -> pipeline.PipelineUmmaAsync: + """Create and initialize the mma o pipeline. + + :param mma_o_mbar_ptr: The mma o mbar pointer + :type mma_o_mbar_ptr: cute.Tensor + :param cta_layout_vmnk: The cta layout vmnk + :type cta_layout_vmnk: tuple[int, int, int] + + :return: The mma o pipeline + :rtype: pipeline.PipelineUmmaAsync + """ + + mma_o_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_warp_id]) + ) + consumer_thread_size = ( + self.threads_per_warp + * len(self.compute_warp_ids) + * self.cluster_shape_mnk[0] + ) + mma_o_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + consumer_thread_size, + ) + return pipeline.PipelineUmmaAsync.create( + barrier_storage=mma_o_mbar_ptr, + num_stages=self.mma_o_stage, + producer_group=mma_o_producer_group, + consumer_group=mma_o_consumer_group, + cta_layout_vmnk=cta_layout_vmnk, + defer_sync=True, + ) + + @staticmethod + def _compute_grid( + o: cute.Tensor, + split_kv: cutlass.Int32, + cluster_shape_mnk: Tuple[int, int, int], + max_active_clusters: int, + is_persistent: bool, + ) -> Tuple[MLAStaticTileSchedulerParams, Tuple[int, int, int]]: + """Compute grid shape for the output tensor C. + + :param c: The output tensor C + :type c: cute.Tensor + :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type cta_tile_shape_mnk: tuple[int, int, int] + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] + + :return: Tile scheduler parameters and grid shape. + :rtype: tuple[MLAStaticTileSchedulerParams, tuple[int, int, int]] + """ + o_shape = o.shape + tile_sched_params = create_mla_static_tile_scheduler_params( + is_persistent, + cute.size(o_shape[3]), + cute.size(o_shape[2]), + cluster_shape_mnk, + split_kv, + ) + grid = MLAStaticTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + + return tile_sched_params, grid + + @staticmethod + def get_workspace_size( + H: int, + S: int, + D: int, + B: int, + split_kv: int, + acc_dtype: Type[cutlass.Numeric], + ) -> int: + """Get the extra workspace(device memory) size for the MLA kernel when split_kv is not 1. + + :param H: The height of the output tensor C + :type H: int + :param S: The sequence length of the output tensor C + :type S: int + :param D: The depth of the output tensor C + :type D: int + :param B: The batch size of the output tensor C + :type B: int + :param split_kv: The split key-value of the output tensor C + :type split_kv: int + :param acc_dtype: The data type of the output tensor C + :type acc_dtype: Type[cutlass.Numeric] + + :return: The workspace size for the MLA kernel + :rtype: int + """ + if split_kv == 1: + return 0 + return B * H * S * split_kv * (D + 1) * acc_dtype.width // 8 + + @cute.jit + def initialize_workspace( + self, + H: cutlass.Int32, + D: cutlass.Int32, + S: cutlass.Int32, + B: cutlass.Int32, + split_kv: cutlass.Int32, + acc_dtype: Type[cutlass.Numeric], + workspace: cute.Tensor, + ) -> tuple[cute.Tensor, cute.Tensor]: + """Initialize the workspace for the MLA kernel. Construct the intermediate tensors + acc_o and acc_lse. + + :param H: The height of the output tensor C + :type H: cutlass.Int32 + :param D: The depth of the output tensor C + :type D: cutlass.Int32 + :param S: The sequence length of the output tensor C + :type S: cutlass.Int32 + :param B: The batch size of the output tensor C + :type B: cutlass.Int32 + :param split_kv: The split key-value of the output tensor C + :type split_kv: cutlass.Int32 + :param acc_dtype: The data type of the output tensor C + :type acc_dtype: Type[cutlass.Numeric] + :param workspace: The workspace tensor + :type workspace: cute.Tensor + + :return: The output tensor C and the workspace tensor + :rtype: tuple[cute.Tensor, cute.Tensor] + """ + acc_o, acc_lse = None, None + if cutlass.const_expr(workspace is not None): + align = 256 // self.q_dtype.width + acc_o_layout = cute.make_layout( + (H, split_kv, D, S, B), + stride=( + cute.assume(split_kv * D, align), + cute.assume(D, align), + 1, + cute.assume(split_kv * H * D, align), + cute.assume(H * split_kv * S * D, align), + ), + ) + acc_o_iter = cute.recast_ptr(workspace.iterator, dtype=acc_dtype) + acc_o = cute.make_tensor(acc_o_iter, acc_o_layout) + acc_lse_layout = cute.make_layout( + (H, split_kv, S, B), + stride=(split_kv, 1, H * split_kv, H * split_kv * S), + ) + acc_lse_iter = cute.recast_ptr( + workspace.iterator + cute.cosize(acc_o_layout) * acc_dtype.width // 8, + dtype=acc_dtype, + ) + acc_lse = cute.make_tensor(acc_lse_iter, acc_lse_layout) + return acc_o, acc_lse + + @staticmethod + def can_implement( + B: int, + S: int, + K: int, + H: int, + L: int, + R: int, + in_dtype: Type[cutlass.Numeric], + out_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + lse_dtype: Type[cutlass.Numeric], + mma_qk_tiler_mn: Tuple[int, int], + mma_pv_tiler_mn: Tuple[int, int], + split_kv: int, + is_persistent: bool, + is_var_seq: bool, + is_var_split_kv: bool, + page_size: int, + ) -> bool: + """Check if the MLA kernel can be implemented. + + :param B: The batch size of the output tensor C + :type B: int + :param S: The sequence length of the output tensor C + :type S: int + :param K: The width of the output tensor KV + :type K: int + :param H: The number of heads of the output tensor C + :type H: int + :param L: The number of latent dimensions of the tensor KV + :type L: int + :param R: The number of rope dimensions of the tensor C_rope + :type R: int + :param in_dtype: The data type of the input tensor + :type in_dtype: Type[cutlass.Numeric] + :param out_dtype: The data type of the output tensor + :type out_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param lse_dtype: The data type of the log-sum-exp + :type lse_dtype: Type[cutlass.Numeric] + :param mma_qk_tiler_mn: The tile shape of the query-key matrix multiplication + :type mma_qk_tiler_mn: Tuple[int, int] + :param mma_pv_tiler_mn: The tile shape of the probability-value matrix multiplication + :type mma_pv_tiler_mn: Tuple[int, int] + :param split_kv: The split key-value of the output tensor C + :type split_kv: int + :param is_persistent: Whether to use persistent kernel optimization + :type is_persistent: bool + :param is_var_seq: Whether to use variable sequence length + :type is_var_seq: bool + :param is_var_split_kv: Whether to use variable split_kv + :type is_var_split_kv: bool + :param page_size: The page size of the page table + :type page_size: int + + :return: Whether the MLA kernel can be implemented + :rtype: bool + """ + if L != 512 or R != 64: + return False + if in_dtype not in [cutlass.Float8E4M3FN]: + return False + if out_dtype not in [cutlass.Float8E4M3FN]: + return False + if acc_dtype != cutlass.Float32 or lse_dtype != cutlass.Float32: + return False + # page size equals 1 is prohibited by tma specification, not 128B aligned. + if mma_qk_tiler_mn[1] % page_size != 0 or page_size == 1: + return False + if mma_qk_tiler_mn[0] != mma_pv_tiler_mn[0] or mma_qk_tiler_mn[0] != 128: + return False + if is_var_split_kv and not is_var_seq: + return False + if H > 128 or (H < 128 and split_kv != 1): + return False + if S <= 0 or S > 4: + return False + if K <= 0: + return False + return True + + +def run( + batch_size: int, + seq_len_q: int, + seq_len_k: int, + num_heads: int, + latent_dim: int, + rope_dim: int, + in_dtype: Type[cutlass.Numeric], + out_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + lse_dtype: Type[cutlass.Numeric], + mma_qk_tiler_mn: Tuple[int, int], + mma_pv_tiler_mn: Tuple[int, int], + split_kv: int, + is_persistent: bool, + is_var_seq: bool, + is_var_split_kv: bool, + page_size: int, + softmax_scale: float, + output_scale: float, + skip_correction_threshold: float, + tolerance: float, + warmup_iterations: int, + iterations: int, + skip_ref_check: bool, + use_cold_l2: bool, + **kwargs, +): + """Execute Multi-Head Latent Attention (MLA) on Blackwell architecture and validate results. + + This function creates random input tensors for query latent/rope, compressed latent/rope, and value, + then performs the complete MLA computation pipeline. It supports configurable data types, tiling parameters, + page table, variable sequence length, and variable split_kv. Results can be validated against a PyTorch reference + implementation or run multiple times for performance measurement. + + :param batch_size: Batch size + :type batch_size: int + :param seq_len_q: Sequence length of Q + :type seq_len_q: int + :param seq_len_k: Sequence length of K + :type seq_len_k: int + :param num_heads: Number of heads + :type num_heads: int + :param latent_dim: dimension of query/compressed latent + :type latent_dim: int + :param rope_dim: dimension of query/compressed rope + :type rope_dim: int + :param in_dtype: Input data type for query/compressed latent/rope tensors + :type in_dtype: Type[cutlass.Numeric] + :param out_dtype: Output data type for attention output + :type out_dtype: Type[cutlass.Numeric] + :param acc_dtype: Accumulator data type for query-key matrix multiplication + :type acc_dtype: Type[cutlass.Numeric] + :param lse_dtype: Accumulator data type for log-sum-exp + :type lse_dtype: Type[cutlass.Numeric] + :param mma_qk_tiler_mn: Matrix multiply accumulate tile shape (M, N) for query-key matrix multiplication + :type mma_qk_tiler_mn: Tuple[int, int] + :param mma_pv_tiler_mn: Matrix multiply accumulate tile shape (M, N) for probability-value matrix multiplication + :type mma_pv_tiler_mn: Tuple[int, int] + :param split_kv: Split key-value + :type split_kv: int + :param is_persistent: Whether to use persistent kernel optimization + :type is_persistent: bool + :param is_var_seq: Whether to use variable sequence length + :type is_var_seq: bool + :param is_var_split_kv: Whether to use variable split_kv + :type is_var_split_kv: bool + :param page_size: Page size of the page table + :type page_size: int + :param softmax_scale: Attention score scaling factor + :type softmax_scale: float + :param output_scale: Output scaling factor + :type output_scale: float + :param skip_correction_threshold: Threshold to skip correction + :type skip_correction_threshold: float + :param tolerance: Maximum acceptable error for validation + :type tolerance: float + :param warmup_iterations: Number of warmup iterations + :type warmup_iterations: int + :param iterations: Number of iterations to run for performance testing + :type iterations: int + :param skip_ref_check: Skip validation against reference implementation + :type skip_ref_check: bool + :param use_cold_l2: Whether to use cold L2 cache + :type use_cold_l2: bool + + :raises ValueError: If input shapes are incompatible or head dimension is unsupported + :raises RuntimeError: If GPU is unavailable for computation + """ + + print("Running Blackwell MLA test with:") + print(f" batch_size: {batch_size}") + print(f" seq_len_q: {seq_len_q}") + print(f" seq_len_k: {seq_len_k}") + print(f" num_heads: {num_heads}") + print(f" latent_dim: {latent_dim}") + print(f" rope_dim: {rope_dim}") + print(f" in_dtype: {in_dtype}") + print(f" out_dtype: {out_dtype}") + print(f" acc_dtype: {acc_dtype}") + print(f" mma_qk_tiler_mn: {mma_qk_tiler_mn}") + print(f" mma_pv_tiler_mn: {mma_pv_tiler_mn}") + print(f" split_kv: {split_kv}") + print(f" is_persistent: {is_persistent}") + print(f" is_var_seq: {is_var_seq}") + print(f" is_var_split_kv: {is_var_split_kv}") + print(f" page_size: {page_size}") + print(f" softmax_scale: {softmax_scale}") + print(f" output_scale: {output_scale}") + print(f" skip_correction_threshold: {skip_correction_threshold}") + print(f" tolerance: {tolerance}") + print(f" warmup_iterations: {warmup_iterations}") + print(f" iterations: {iterations}") + print(f" skip_ref_check: {skip_ref_check}") + print(f" use_cold_l2: {use_cold_l2}") + + import torch + import cutlass.torch as cutlass_torch + + # Prepare pytorch tensors: Q, K, V (random from 0 to 2) and O (all zero) + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + if not BlackwellMultiHeadLatentAttentionForwardFP8.can_implement( + batch_size, + seq_len_q, + seq_len_k, + num_heads, + latent_dim, + rope_dim, + in_dtype, + out_dtype, + acc_dtype, + lse_dtype, + mma_qk_tiler_mn, + mma_pv_tiler_mn, + split_kv, + is_persistent, + is_var_seq, + is_var_split_kv, + page_size, + ): + raise TypeError( + f"Unsupported testcase {batch_size}, {seq_len_q}, {seq_len_k}, {num_heads}, {latent_dim}, {rope_dim}, {in_dtype}, {out_dtype}, {acc_dtype}, {lse_dtype}, {mma_qk_tiler_mn}, {mma_pv_tiler_mn}, {split_kv}, {is_persistent}, {is_var_seq}, {is_var_split_kv}, {page_size}" + ) + + torch.manual_seed(1111) + + def create_data_tensor( + B, + HK, + D, + dtype, + is_dynamic_layout=True, + page_table=None, + cache_seqs=None, + is_lse=False, + seq_len_q=None, + ): + shape = (B, HK, D) + if page_table is not None: + if cache_seqs is not None: + max_seq_len = torch.max(cache_seqs) + shape = (B * ceil_div(max_seq_len, page_size), page_size, D) + else: + shape = (B * ceil_div(HK, page_size), page_size, D) + + if seq_len_q is not None: + shape = (B, seq_len_q, HK, D) + + permute_order = (1, 2, 0) + stride_order = (2, 0, 1) + leading_dim = 1 + if is_lse: + shape = (B, seq_len_q, HK) + permute_order = (2, 1, 0) + stride_order = (2, 1, 0) + leading_dim = 0 + elif seq_len_q is not None: + permute_order = (2, 3, 1, 0) + stride_order = (3, 2, 0, 1) + leading_dim = 1 + + init_config = cutlass.torch.RandomInitConfig(min_val=-2, max_val=2) + + torch_dtype = ( + cutlass_torch.dtype(dtype) if dtype != cutlass.Float8E4M3FN else torch.int8 + ) + + # Create dtype torch tensor (cpu) + torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( + shape, + torch_dtype, + permute_order=permute_order, + init_type=cutlass.torch.TensorInitType.RANDOM, + init_config=init_config, + ) + + # Create dtype torch tensor (gpu) + torch_tensor_gpu = torch_tensor_cpu.cuda() + + # Create f32 torch tensor (cpu) + f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32) + + # Create dtype cute tensor (gpu) + cute_tensor = from_dlpack(torch_tensor_gpu, assumed_align=16) + cute_tensor.element_type = dtype + if is_dynamic_layout: + cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=leading_dim) + if not is_lse: + cute_tensor = cute_tensor.mark_compact_shape_dynamic( + mode=leading_dim, + stride_order=stride_order, + divisibility=(128 // dtype.width), + ) + + cute_tensor = cutlass_torch.convert_cute_tensor( + f32_torch_tensor, + cute_tensor, + dtype, + is_dynamic_layout=is_dynamic_layout, + ) + + return f32_torch_tensor, cute_tensor, torch_tensor_gpu + + def create_cache_seqs(batch_size, seq_len_k, is_var_seq): + cache_seqs_ref = torch.ones(batch_size, dtype=torch.int32) * seq_len_k + cache_seqs_gpu = cache_seqs_ref.cuda() + cache_seqs = from_dlpack(cache_seqs_gpu, assumed_align=16).mark_layout_dynamic() + if is_var_seq: + max_seq_len = seq_len_k + min_seq_len = int(seq_len_k * 0.8) + cache_seqs_ref = cutlass_torch.create_and_permute_torch_tensor( + (batch_size,), + torch.int32, + init_type=cutlass.torch.TensorInitType.RANDOM, + init_config=cutlass.torch.RandomInitConfig( + min_val=min_seq_len, max_val=max_seq_len + 1 + ), + ) + cache_seqs_gpu = cache_seqs_ref.cuda() + cache_seqs = from_dlpack( + cache_seqs_gpu, + assumed_align=16, + ).mark_layout_dynamic() + return cache_seqs_ref, cache_seqs, cache_seqs_gpu + + def create_page_table(batch_size, seq_len_k, is_var_seq, page_size): + max_seq_len = seq_len_k if not is_var_seq else torch.max(cache_seqs_ref) + page_count = ceil_div(max_seq_len, page_size) + page_table_ref = torch.empty([batch_size, page_count], dtype=torch.int32) + # use transposed index for page table to make sure the value is in bound of `batch_size * seq_len_block`. In practice, the value could be any positive values. This setting is only for testing purpose. + for b in range(batch_size): + for j in range(page_count): + page_table_ref[b, j] = b + j * batch_size + page_table_gpu = page_table_ref.permute(1, 0).cuda() + page_table = from_dlpack(page_table_gpu, assumed_align=16).mark_layout_dynamic( + leading_dim=0 + ) + return page_table_ref, page_table, page_table_gpu + + def create_block_split_kvs( + batch_size, + split_kv, + cache_seqs_ref, + is_var_split_kv, + mma_qk_tiler_mn, + cluster_shape_mnk, + max_active_clusters, + ): + block_split_kvs_ref, block_split_kvs, block_split_kvs_gpu = None, None, None + # check if split_kv is valid otherwise do auto setting of split_kv + if is_var_split_kv: + block_split_kvs_ref = torch.zeros([batch_size], dtype=torch.int32) + for b in range(batch_size): + block_split_kvs_ref[b] = ( + BlackwellMultiHeadLatentAttentionForwardFP8.get_split_kv( + batch_size, + seq_len_q, + cache_seqs_ref[b].item(), + mma_qk_tiler_mn, + max_active_clusters * cluster_shape_mnk[0], + ) + ) + split_kv = torch.max(block_split_kvs_ref).item() + block_split_kvs_gpu = block_split_kvs_ref.cuda() + block_split_kvs = from_dlpack( + block_split_kvs_gpu, assumed_align=16 + ).mark_layout_dynamic() + elif split_kv <= 0: + split_kv = BlackwellMultiHeadLatentAttentionForwardFP8.get_split_kv( + batch_size, + seq_len_q, + cache_seqs_ref[0].item(), + mma_qk_tiler_mn, + max_active_clusters * cluster_shape_mnk[0], + ) + return split_kv, block_split_kvs_ref, block_split_kvs, block_split_kvs_gpu + + def create_workspace( + num_heads, seq_len_q, latent_dim, batch_size, split_kv, acc_dtype + ): + workspace_size = BlackwellMultiHeadLatentAttentionForwardFP8.get_workspace_size( + num_heads, + seq_len_q, + latent_dim, + batch_size, + split_kv, + acc_dtype, + ) + + workspace, workspace_torch = None, None + if workspace_size > 0: + workspace_torch = torch.empty([workspace_size], dtype=torch.int8).cuda() + workspace = from_dlpack(workspace_torch, assumed_align=32) + return workspace, workspace_torch + + cache_seqs_ref, cache_seqs, cache_seqs_torch = create_cache_seqs( + batch_size, seq_len_k, is_var_seq + ) + page_table_ref, page_table, page_table_torch = create_page_table( + batch_size, seq_len_k, is_var_seq, page_size + ) + cluster_shape_mnk = (2, 1, 1) + hardware_info = utils.HardwareInfo() + max_active_clusters = hardware_info.get_max_active_clusters( + cluster_shape_mnk[0] * cluster_shape_mnk[1] + ) + split_kv, block_split_kvs_ref, block_split_kvs, block_split_kvs_torch = ( + create_block_split_kvs( + batch_size, + split_kv, + cache_seqs_ref, + is_var_split_kv, + mma_qk_tiler_mn, + cluster_shape_mnk, + max_active_clusters, + ) + ) + + q_latent_ref, q_latent, q_latent_torch = create_data_tensor( + batch_size, + num_heads, + latent_dim, + in_dtype, + is_dynamic_layout=True, + seq_len_q=seq_len_q, + ) + q_rope_ref, q_rope, q_rope_torch = create_data_tensor( + batch_size, + num_heads, + rope_dim, + in_dtype, + is_dynamic_layout=True, + seq_len_q=seq_len_q, + ) + + c_latent_ref, c_latent, c_latent_torch = create_data_tensor( + batch_size, + seq_len_k, + latent_dim, + in_dtype, + is_dynamic_layout=True, + page_table=page_table, + cache_seqs=cache_seqs_ref, + ) + c_rope_ref, c_rope, c_rope_torch = create_data_tensor( + batch_size, + seq_len_k, + rope_dim, + in_dtype, + is_dynamic_layout=True, + page_table=page_table, + cache_seqs=cache_seqs_ref, + ) + o_ref, o, o_torch = create_data_tensor( + batch_size, + num_heads, + latent_dim, + out_dtype, + is_dynamic_layout=True, + seq_len_q=seq_len_q, + ) + lse_ref, lse, lse_torch = create_data_tensor( + batch_size, + num_heads, + 1, + lse_dtype, + is_dynamic_layout=True, + is_lse=True, + seq_len_q=seq_len_q, + ) + workspace, workspace_torch = create_workspace( + num_heads, seq_len_q, latent_dim, batch_size, split_kv, acc_dtype + ) + + mla = BlackwellMultiHeadLatentAttentionForwardFP8( + acc_dtype, + lse_dtype, + mma_qk_tiler_mn, + mma_pv_tiler_mn, + max_active_clusters, + page_size, + skip_correction_threshold, + is_persistent, + is_var_seq, + is_var_split_kv, + ) + + # Get current CUDA stream from PyTorch + torch_stream = torch.cuda.current_stream() + # Get the raw stream pointer as a CUstream + stream = cuda.CUstream(torch_stream.cuda_stream) + + # compile mla kernel + compiled_mla = cute.compile( + mla, + q_latent, + q_rope, + c_latent, + c_rope, + page_table, + o, + lse, + workspace, + split_kv, + cache_seqs, + block_split_kvs, + softmax_scale, + output_scale, + stream, + options="--opt-level 2", + ) + + def torch_reference_mla( + q_latent, + q_rope, + c_latent, + c_rope, + page_table, + cache_seqs, + softmax_scale=1.0, + output_scale=1.0, + ): + # expand and concat q_latent and q_rope to have the dimension of sequence length for q + q_ref = torch.cat([q_latent, q_rope], dim=1).permute(3, 2, 0, 1) + # expand and concat c_latent and c_rope to have the dimension of num_heads for k and v + page_count = page_table_ref.shape[1] + k_ref_paged = ( + torch.cat([c_latent, c_rope], dim=1) + .permute(2, 0, 1) + .reshape(batch_size * page_count, page_size, latent_dim + rope_dim) + ) + v_ref_paged = c_latent.permute(2, 0, 1).reshape( + batch_size * page_count, page_size, latent_dim + ) + + if is_var_seq: + max_seq_len = torch.max(cache_seqs_ref) + else: + max_seq_len = seq_len_k + + k_ref = torch.zeros([batch_size, 1, max_seq_len, latent_dim + rope_dim]) + v_ref = torch.zeros([batch_size, 1, max_seq_len, latent_dim]) + k_ref = torch.index_select( + k_ref_paged, 0, torch.flatten(page_table_ref) + ).reshape(batch_size, 1, -1, latent_dim + rope_dim)[:, :, :max_seq_len, :] + v_ref = torch.index_select( + v_ref_paged, 0, torch.flatten(page_table_ref) + ).reshape(batch_size, 1, -1, latent_dim)[:, :, :max_seq_len, :] + for b in range(batch_size): + k_ref[b, :, cache_seqs_ref[b] :, :] = 0 + v_ref[b, :, cache_seqs_ref[b] :, :] = 0 + import torch.nn.functional as F + + o_ref = F.scaled_dot_product_attention( + q_ref, + k_ref, + v_ref, + attn_mask=None, + dropout_p=0.0, + scale=softmax_scale, + is_causal=False, + ) + s_ref = torch.einsum("bhld,bhsd->bhls", q_ref, k_ref) + s_ref_max, s_ref_max_pos = torch.max(s_ref, dim=-1, keepdim=True) + softmax_scale_log2 = LOG2_E * softmax_scale + s_ref_sum = torch.sum( + torch.exp2((s_ref - s_ref_max) * softmax_scale_log2), dim=-1, keepdim=True + ) + + lse_ref = s_ref_max * softmax_scale_log2 + torch.log2(s_ref_sum) + lse_ref = lse_ref.squeeze(3).permute(2, 1, 0) + o_ref = o_ref * output_scale + o_ref = o_ref.permute(2, 3, 1, 0) + + return o_ref, lse_ref + + if skip_correction_threshold > 0.0: + print( + "Skipping correction verification since skip_correction_threshold is greater than 0.0..." + ) + skip_ref_check = True + if not skip_ref_check: + # Execute kernel once for reference checking + compiled_mla( + q_latent, + q_rope, + c_latent, + c_rope, + page_table, + o, + lse, + workspace, + split_kv, + cache_seqs, + block_split_kvs, + softmax_scale, + output_scale, + stream, + ) + torch.cuda.synchronize() + + print("Verifying results...") + if in_dtype == cutlass.Float8E4M3FN: + tolerance = 0.13 + o_ref, lse_ref = torch_reference_mla( + q_latent_ref, + q_rope_ref, + c_latent_ref, + c_rope_ref, + page_table, + cache_seqs, + softmax_scale, + output_scale, + ) + + if out_dtype in [cutlass.Float8E5M2, cutlass.Float8E4M3FN]: + # {$nv-internal-release begin} + # todo: not sure why, but the below `cute.testing.convert` will cause bus error occasionally in local and ci. + # {$nv-internal-release end} + # convert o back to f32 for comparison + o_fp32, o_fp32_torch = cutlass_torch.cute_tensor_like( + torch.empty(*o_torch.shape, dtype=torch.float32), + cutlass.Float32, + is_dynamic_layout=True, + assumed_align=16, + ) + cute.testing.convert(o, o_fp32) + o = o_fp32_torch.cpu() + ref_fp8, _ = cutlass_torch.cute_tensor_like( + torch.empty( + *o_ref.permute(3, 2, 0, 1).shape, dtype=torch.uint8 + ).permute(2, 3, 1, 0), + out_dtype, + is_dynamic_layout=True, + assumed_align=16, + ) + o_ref_gpu = o_ref.cuda() + o_ref_f32 = from_dlpack(o_ref_gpu).mark_layout_dynamic(leading_dim=1) + + # convert ref : f32 -> fp8 -> f32 + cute.testing.convert(o_ref_f32, ref_fp8) + cute.testing.convert(ref_fp8, o_ref_f32) + + o_ref = o_ref_gpu.cpu() + else: + o = o_torch.cpu().to(torch.float32) + lse = lse_torch.cpu() + lse_ref = lse_ref.to(cutlass.torch.dtype(lse_dtype)) + # Assert close results + torch.testing.assert_close(o, o_ref, atol=tolerance, rtol=1e-05) + torch.testing.assert_close(lse, lse_ref, atol=tolerance, rtol=1e-05) + print("Results verified successfully!") + + def generate_tensors(): + _, cache_seqs, _ = create_cache_seqs(batch_size, seq_len_k, is_var_seq) + _, page_table, _ = create_page_table( + batch_size, seq_len_k, is_var_seq, page_size + ) + _split_kv, _, block_split_kvs, _ = create_block_split_kvs( + batch_size, + split_kv, + cache_seqs_ref, + is_var_split_kv, + mma_qk_tiler_mn, + cluster_shape_mnk, + max_active_clusters, + ) + + _, q_latent, _ = create_data_tensor( + batch_size, + num_heads, + latent_dim, + in_dtype, + is_dynamic_layout=True, + seq_len_q=seq_len_q, + ) + _, q_rope, _ = create_data_tensor( + batch_size, + num_heads, + rope_dim, + in_dtype, + is_dynamic_layout=True, + seq_len_q=seq_len_q, + ) + + _, c_latent, _ = create_data_tensor( + batch_size, + seq_len_k, + latent_dim, + in_dtype, + is_dynamic_layout=True, + page_table=page_table, + cache_seqs=cache_seqs_ref, + ) + _, c_rope, _ = create_data_tensor( + batch_size, + seq_len_k, + rope_dim, + in_dtype, + is_dynamic_layout=True, + page_table=page_table, + cache_seqs=cache_seqs_ref, + ) + _, o, _ = create_data_tensor( + batch_size, + num_heads, + latent_dim, + out_dtype, + is_dynamic_layout=True, + seq_len_q=seq_len_q, + ) + _, lse, _ = create_data_tensor( + batch_size, + num_heads, + 1, + lse_dtype, + is_dynamic_layout=True, + is_lse=True, + seq_len_q=seq_len_q, + ) + workspace, workspace_torch = create_workspace( + num_heads, seq_len_q, latent_dim, batch_size, _split_kv, acc_dtype + ) + return testing.JitArguments( + q_latent, + q_rope, + c_latent, + c_rope, + page_table, + o, + lse, + workspace, + _split_kv, + cache_seqs, + block_split_kvs, + softmax_scale, + output_scale, + stream, + ) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + q_latent_torch.numel() * q_latent_torch.element_size() + + q_rope_torch.numel() * q_rope_torch.element_size() + + c_latent_torch.numel() * c_latent_torch.element_size() + + c_rope_torch.numel() * c_rope_torch.element_size() + + o_torch.numel() * o_torch.element_size() + + lse_torch.numel() * lse_torch.element_size() + + cache_seqs_torch.numel() * cache_seqs_torch.element_size() + ) + one_workspace_bytes += ( + page_table_torch.numel() * page_table_torch.element_size() + ) + if is_var_split_kv: + one_workspace_bytes += ( + block_split_kvs_torch.numel() * block_split_kvs_torch.element_size() + ) + if workspace_torch is not None: + one_workspace_bytes += ( + workspace_torch.numel() * workspace_torch.element_size() + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + avg_time_us = testing.benchmark( + compiled_mla, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + return avg_time_us # Return execution time in microseconds + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: + try: + return tuple(int(x.strip()) for x in s.split(",")) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + def parse_mma_tiler(s: str) -> Tuple[int, int, Tuple[int, int]]: + ret = parse_comma_separated_ints(s) + if len(ret) != 2: + raise argparse.ArgumentTypeError( + "Invalid format. Expected 2 comma-separated integers." + ) + return (ret[0], ret[1]) + + parser = argparse.ArgumentParser(description="Example of MLA on Blackwell.") + + parser.add_argument( + "--in_dtype", + type=cutlass.dtype, + default=cutlass.Float8E4M3FN, + help="Input data type", + ) + + parser.add_argument( + "--out_dtype", + type=cutlass.dtype, + default=cutlass.Float8E4M3FN, + help="Output data type", + ) + + parser.add_argument( + "--acc_dtype", + type=cutlass.dtype, + default=cutlass.Float32, + help="Accumulator data type", + ) + + parser.add_argument( + "--lse_dtype", + type=cutlass.dtype, + default=cutlass.Float32, + help="LSE data type", + ) + parser.add_argument( + "--mma_qk_tiler_mn", + type=parse_mma_tiler, + default=(128, 128), + help="MMA tile shape (H, K)", + ) + parser.add_argument( + "--mma_pv_tiler_mn", + type=parse_mma_tiler, + default=(128, 256), + help="MMA tile shape (H, D)", + ) + + parser.add_argument( + "--is_persistent", + action="store_true", + help="Is persistent", + ) + + parser.add_argument( + "--batch_size", + type=int, + default=1, + help="Batch size", + ) + + parser.add_argument( + "--seq_len_q", + type=int, + default=1, + help="Sequence length of Q", + ) + + parser.add_argument( + "--seq_len_k", + type=int, + default=128, + help="Sequence length of K/V", + ) + + parser.add_argument( + "--num_heads", + type=int, + default=128, + help="Number of heads of Q", + ) + + parser.add_argument( + "--latent_dim", + type=int, + default=512, + help="Latent dimension of Q/C", + ) + + parser.add_argument( + "--rope_dim", + type=int, + default=64, + help="Rope dimension of Q/C", + ) + + parser.add_argument( + "--is_var_seq", + action="store_true", + help="Use variable length of sequence length or not", + ) + + parser.add_argument( + "--is_var_split_kv", + action="store_true", + help="Use variable length of split kv or not", + ) + + parser.add_argument( + "--page_size", + type=int, + default=128, + help="Page size of page table", + ) + + parser.add_argument( + "--split_kv", + type=int, + default=-1, + help="Split KV setting", + ) + + parser.add_argument( + "--softmax_scale", + type=float, + default=0.0416, + help="Scaling factor to scale softmax", + ) + + parser.add_argument( + "--output_scale", + type=float, + default=1.0, + help="Scaling factor to scale output", + ) + parser.add_argument( + "--skip_correction_threshold", + type=float, + default=0.0, + help="Threshold to skip correction", + ) + + parser.add_argument( + "--tolerance", type=float, default=1e-02, help="Tolerance for validation" + ) + + parser.add_argument( + "--warmup_iterations", + type=int, + default=0, + help="Number of iterations for warmup", + ) + + parser.add_argument( + "--iterations", + type=int, + default=1, + help="Number of iterations after warmup", + ) + + parser.add_argument( + "--skip_ref_check", + action="store_true", + help="Skip reference check", + ) + + parser.add_argument( + "--use_cold_l2", + action="store_true", + help="Use cold L2 cache", + ) + + args = parser.parse_args() + + run( + args.batch_size, + args.seq_len_q, + args.seq_len_k, + args.num_heads, + args.latent_dim, + args.rope_dim, + args.in_dtype, + args.out_dtype, + args.acc_dtype, + args.lse_dtype, + args.mma_qk_tiler_mn, + args.mma_pv_tiler_mn, + args.split_kv, + args.is_persistent, + args.is_var_seq, + args.is_var_split_kv, + args.page_size, + args.softmax_scale, + args.output_scale, + args.skip_correction_threshold, + args.tolerance, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, + ) + + print("PASS") diff --git a/flashinfer/cute_dsl/mla_helpers.py b/flashinfer/cute_dsl/mla_helpers.py new file mode 100644 index 0000000000..1790b3c882 --- /dev/null +++ b/flashinfer/cute_dsl/mla_helpers.py @@ -0,0 +1,304 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +import cutlass +import cutlass.cute as cute + + +class MLAStaticTileSchedulerParams: + def __init__( + self, + is_persistent: bool, + problem_shape_b: cute.Int32, + problem_shape_s: cute.Int32, + cluster_shape_mnk: cute.Shape, + split_kv: cutlass.Int32, + *, + problem_shape_b_fdd: cute.FastDivmodDivisor = None, + problem_shape_s_fdd: cute.FastDivmodDivisor = None, + split_kv_fdd: cute.FastDivmodDivisor = None, + loc=None, + ip=None, + ): + """The static tile scheduler parameters prepared for MLA static tile scheduler. + + :param is_persistent: Whether to use persistent kernel mode + :type is_persistent: bool + :param problem_shape_b: The shape of the problem + :type problem_shape_b: cute.Int32 + :param problem_shape_s: The shape of the problem in sequence length Q dimension + :type problem_shape_s: cute.Int32 + :param cluster_shape_mnk: The shape of the cluster + :type cluster_shape_mnk: cute.Shape + :param split_kv: The scalar factor for split KV + """ + self.is_persistent = is_persistent + self.problem_shape_b = problem_shape_b + self.problem_shape_s = problem_shape_s + self.problem_shape_b_fdd = problem_shape_b_fdd + self.problem_shape_s_fdd = problem_shape_s_fdd + self.cluster_shape_mnk = cluster_shape_mnk + self.split_kv = split_kv + self.split_kv_fdd = split_kv_fdd + if cutlass.const_expr(problem_shape_b_fdd is None): + self.problem_shape_b_fdd = cute.fast_divmod_create_divisor( + problem_shape_b, loc=loc, ip=ip + ) + if cutlass.const_expr(problem_shape_s_fdd is None): + self.problem_shape_s_fdd = cute.fast_divmod_create_divisor( + problem_shape_s, loc=loc, ip=ip + ) + if cutlass.const_expr(split_kv_fdd is None): + self.split_kv_fdd = cute.fast_divmod_create_divisor( + split_kv, loc=loc, ip=ip + ) + self.loc = loc + self.ip = ip + + def __extract_mlir_values__(self): + values = cutlass.extract_mlir_values(self.problem_shape_b) + values += cutlass.extract_mlir_values(self.problem_shape_s) + values += cutlass.extract_mlir_values(self.split_kv) + values += cutlass.extract_mlir_values(self.problem_shape_b_fdd) + values += cutlass.extract_mlir_values(self.problem_shape_s_fdd) + values += cutlass.extract_mlir_values(self.split_kv_fdd) + return values + + def __new_from_mlir_values__(self, values): + problem_shape_b = cutlass.new_from_mlir_values( + self.problem_shape_b, (values[0],) + ) + problem_shape_s = cutlass.new_from_mlir_values( + self.problem_shape_s, (values[1],) + ) + split_kv = cutlass.new_from_mlir_values(self.split_kv, (values[2],)) + problem_shape_b_fdd = cutlass.new_from_mlir_values( + self.problem_shape_b_fdd, (values[3],) + ) + problem_shape_s_fdd = cutlass.new_from_mlir_values( + self.problem_shape_s_fdd, (values[4],) + ) + split_kv_fdd = cutlass.new_from_mlir_values(self.split_kv_fdd, (values[5],)) + return MLAStaticTileSchedulerParams( + self.is_persistent, + problem_shape_b, + problem_shape_s, + self.cluster_shape_mnk, + split_kv, + problem_shape_b_fdd=problem_shape_b_fdd, + problem_shape_s_fdd=problem_shape_s_fdd, + split_kv_fdd=split_kv_fdd, + loc=self.loc, + ) + + +def create_mla_static_tile_scheduler_params( + is_persistent: bool, + problem_shape_b: cute.Int32, + problem_shape_s: cute.Int32, + cluster_shape_mnk: cute.Shape, + split_kv: cutlass.Int32, +) -> MLAStaticTileSchedulerParams: + return MLAStaticTileSchedulerParams( + is_persistent, problem_shape_b, problem_shape_s, cluster_shape_mnk, split_kv + ) + + +class WorkTileInfo: + def __init__(self, blk_coord: cute.Coord, is_valid: bool): + self.blk_coord = blk_coord + self.is_valid = cutlass.Boolean(is_valid) + + def __extract_mlir_values__(self): + values = cutlass.extract_mlir_values(self.blk_coord) + values += cutlass.extract_mlir_values(self.is_valid) + return values + + def __new_from_mlir_values__(self, values): + new_tile_idx = cutlass.new_from_mlir_values(self.blk_coord, values[:-1]) + new_is_valid_tile = cutlass.new_from_mlir_values(self.is_valid, [values[-1]]) + return WorkTileInfo(new_tile_idx, new_is_valid_tile) + + @property + def is_valid_tile(self) -> cutlass.Boolean: + return self.is_valid + + @property + def tile_idx(self) -> cute.Coord: + return self.blk_coord + + +class MLAStaticTileScheduler: + def __init__( + self, + params: MLAStaticTileSchedulerParams, + current_work_linear_idx: cutlass.Int32, + blk_coord: cute.Coord, + grid_shape: cute.Shape, + *, + is_valid: bool = True, + loc=None, + ip=None, + ): + """The static tile scheduler for MLA split kv kernel. + Based on `is_persistent`, it provides 2 modes for use: + - Persistent mode: Launch fixed blocks and reschedule the data blocks. + - Non-persistent mode: Launch dynamic blocks and exit when the current work is done. + + :param params: The static tile scheduler parameters + :type params: MLAStaticTileSchedulerParams + :param current_work_linear_idx: The linear index of the current work + :type current_work_linear_idx: cutlass.Int32 + :param blk_coord: The coordinate of the current work + :type blk_coord: cute.Coord + :param grid_shape: The shape of the grid + :type grid_shape: cute.Shape + :param is_valid: Whether the current work is valid + :type is_valid: bool + """ + self.params = params + self.blk_coord = blk_coord + self.grid_shape = grid_shape + self.current_work_linear_idx = current_work_linear_idx + if params.is_persistent: + self.persistent_blk_layout = cute.make_layout( + ( + params.cluster_shape_mnk[0], + params.problem_shape_s, + params.problem_shape_b, + params.split_kv, + ), + loc=loc, + ip=ip, + ) + self.num_blocks = cute.size(self.persistent_blk_layout, loc=loc, ip=ip) + # Used for persistent scheduling + self.num_persistent_sm = cute.size(grid_shape, loc=loc, ip=ip) + else: + self.is_valid = is_valid + self.loc = loc + self.ip = ip + + @staticmethod + def get_grid_shape( + params: MLAStaticTileSchedulerParams, + max_active_clusters: int, + *, + loc=None, + ip=None, + ) -> cute.Shape: + # called by host + grid_shape = ( + params.cluster_shape_mnk[0], + params.problem_shape_b * params.problem_shape_s, + params.split_kv, + ) + if params.is_persistent: + return ( + cutlass.min( + max_active_clusters * cute.size(params.cluster_shape_mnk), + cute.size(grid_shape, loc=loc, ip=ip), + ), + 1, + 1, + ) + else: + return grid_shape + + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + is_valid = ( + self.current_work_linear_idx < self.num_blocks + if self.params.is_persistent + else self.is_valid + ) + + if self.params.is_persistent: + current_work_cluster_batch, cluster_idx = ( + self.current_work_linear_idx // self.params.cluster_shape_mnk[0], + self.current_work_linear_idx % self.params.cluster_shape_mnk[0], + ) + current_work_s_batch, s_idx = divmod( + current_work_cluster_batch, self.params.problem_shape_s_fdd + ) + current_work_b_batch, b_idx = divmod( + current_work_s_batch, self.params.problem_shape_b_fdd + ) + _, split_kv_idx = divmod(current_work_b_batch, self.params.split_kv_fdd) + + blk_coord = (cluster_idx, s_idx, b_idx, split_kv_idx) + else: + s_idx, b_idx = divmod(self.blk_coord[1], self.params.problem_shape_b_fdd) + blk_coord = (self.blk_coord[0], s_idx, b_idx, self.blk_coord[2]) + + return WorkTileInfo(blk_coord, is_valid) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def advance_to_next_work(self, *, advance_count=1, loc=None, ip=None): + if self.params.is_persistent: + self.current_work_linear_idx += advance_count * self.num_persistent_sm + else: + self.is_valid = False + + def __extract_mlir_values__(self): + values = cutlass.extract_mlir_values(self.params) + values.extend(cutlass.extract_mlir_values(self.current_work_linear_idx)) + values.extend(cutlass.extract_mlir_values(self.blk_coord)) + values.extend(cutlass.extract_mlir_values(self.grid_shape)) + return values + + def __new_from_mlir_values__(self, values): + assert len(values) == 13 + new_params = cutlass.new_from_mlir_values(self.params, values[0:6]) + new_current_work_linear_idx = cutlass.new_from_mlir_values( + self.current_work_linear_idx, [values[6]] + ) + new_blk_coord = cutlass.new_from_mlir_values(self.blk_coord, values[7:10]) + new_grid_shape = cutlass.new_from_mlir_values(self.grid_shape, values[10:]) + return MLAStaticTileScheduler( + new_params, new_current_work_linear_idx, new_blk_coord, new_grid_shape + ) + + +def create_mla_static_tile_scheduler( + params: MLAStaticTileSchedulerParams, + blk_coord: cute.Coord, + grid_shape: cute.Shape, +) -> MLAStaticTileScheduler: + return MLAStaticTileScheduler(params, blk_coord[0], blk_coord, grid_shape) + + +LOG2_E = 1.4426950408889634074 +# avoid register indexing on array. +MAX_SPLITS = 256 + + +def ceil_div(a: int, b: int) -> int: + return (a + b - 1) // b diff --git a/flashinfer/mla.py b/flashinfer/mla.py index b78f62101e..1c449db70d 100644 --- a/flashinfer/mla.py +++ b/flashinfer/mla.py @@ -768,6 +768,22 @@ def trtllm_batch_decode_with_kv_cache_mla( ) return out + elif backend == "cute-dsl": + from .cute_dsl.mla_decode import cute_dsl_mla_decode + + return cute_dsl_mla_decode( + query=query, + kv_cache=kv_cache, + workspace_buffer=workspace_buffer, + kv_lora_rank=kv_lora_rank, + qk_rope_head_dim=qk_rope_head_dim, + block_tables=block_tables, + seq_lens=seq_lens, + max_seq_len=max_seq_len, + softmax_scale=bmm1_scale if isinstance(bmm1_scale, float) else float(bmm1_scale.item()), + output_scale=bmm2_scale if isinstance(bmm2_scale, float) else float(bmm2_scale.item()), + out=out, + ) else: raise ValueError(f"Backend {backend} not supported") diff --git a/tests/attention/test_cute_dsl_mla_decode.py b/tests/attention/test_cute_dsl_mla_decode.py new file mode 100644 index 0000000000..97e5bc6285 --- /dev/null +++ b/tests/attention/test_cute_dsl_mla_decode.py @@ -0,0 +1,290 @@ +# Copyright (c) 2025 by FlashInfer team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for CuTe DSL MLA decode kernel.""" + +import pytest +import torch +import torch.nn.functional as F + +from flashinfer.utils import is_sm100a_supported +from flashinfer.cute_dsl import is_cute_dsl_available + + +def skip_if_unsupported(): + if not is_sm100a_supported(torch.device("cuda")): + pytest.skip("Requires SM100a (Blackwell)") + if not is_cute_dsl_available(): + pytest.skip("CuTe DSL not available") + + +def torch_reference_mla( + q_nope, + q_rope, + c_latent, + c_rope, + page_table, + cache_seqs, + softmax_scale, + output_scale, + page_size, +): + """PyTorch reference implementation for MLA decode. + + Args: + q_nope: [B, q_len, H, latent_dim] + q_rope: [B, q_len, H, rope_dim] + c_latent: [num_pages * page_size, latent_dim] + c_rope: [num_pages * page_size, rope_dim] + page_table: [B, max_pages] + cache_seqs: [B] — actual sequence lengths + softmax_scale: float + output_scale: float + page_size: int + """ + B, q_len, H, latent_dim = q_nope.shape + rope_dim = q_rope.shape[-1] + + outputs = [] + for b in range(B): + seq_len = cache_seqs[b].item() + num_pages_needed = (seq_len + page_size - 1) // page_size + + # Gather KV for this batch via page table + page_indices = page_table[b, :num_pages_needed] + kv_indices = [] + for p in page_indices: + start = p.item() * page_size + kv_indices.extend(range(start, start + page_size)) + kv_indices = kv_indices[:seq_len] + kv_indices_t = torch.tensor(kv_indices, device=q_nope.device) + + k_latent = c_latent[kv_indices_t] # [seq_len, latent_dim] + k_rope = c_rope[kv_indices_t] # [seq_len, rope_dim] + + # q: [q_len, H, D], k: [seq_len, D] + q_lat_b = q_nope[b] # [q_len, H, latent_dim] + q_rope_b = q_rope[b] # [q_len, H, rope_dim] + + # Compute attention scores + # QK^T = q_latent @ k_latent^T + q_rope @ k_rope^T + # [q_len, H, latent_dim] @ [latent_dim, seq_len] -> [q_len, H, seq_len] + attn_latent = torch.einsum("qhd,kd->qhk", q_lat_b.float(), k_latent.float()) + attn_rope = torch.einsum("qhd,kd->qhk", q_rope_b.float(), k_rope.float()) + attn = (attn_latent + attn_rope) * softmax_scale + + # Softmax + attn = F.softmax(attn, dim=-1) + + # Output: attn @ V (V = k_latent for MLA) + # [q_len, H, seq_len] @ [seq_len, latent_dim] -> [q_len, H, latent_dim] + out_b = torch.einsum("qhk,kd->qhd", attn, k_latent.float()) + out_b = out_b * output_scale + outputs.append(out_b) + + return torch.stack(outputs, dim=0) # [B, q_len, H, latent_dim] + + +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("seq_len_k", [128, 512, 2048]) +@pytest.mark.parametrize("page_size", [128]) +def test_cute_dsl_mla_decode_fp16(batch_size, seq_len_k, page_size): + """Test FP16 MLA decode kernel.""" + skip_if_unsupported() + + from flashinfer.cute_dsl.mla_decode import cute_dsl_mla_decode + + torch.manual_seed(42) + device = torch.device("cuda") + + num_heads = 128 + latent_dim = 512 + rope_dim = 64 + q_len = 1 + softmax_scale = 1.0 / (latent_dim**0.5) + output_scale = 1.0 + + # Allocate query: [B, q_len, H, D_qk] + D_qk = latent_dim + rope_dim + query = torch.randn(batch_size, q_len, num_heads, D_qk, dtype=torch.float16, device=device) + + # Allocate paged KV cache + num_pages_per_batch = (seq_len_k + page_size - 1) // page_size + total_pages = num_pages_per_batch * batch_size + 10 # extra pages + kv_cache = torch.randn(total_pages, page_size, latent_dim + rope_dim, dtype=torch.float16, device=device) + + # Page table: [B, max_pages] — sequential assignment + block_tables = torch.zeros(batch_size, num_pages_per_batch, dtype=torch.int32, device=device) + for b in range(batch_size): + for p in range(num_pages_per_batch): + block_tables[b, p] = b * num_pages_per_batch + p + + # Sequence lengths + seq_lens = torch.full((batch_size,), seq_len_k, dtype=torch.int32, device=device) + + # Workspace + workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device=device) + + # Run kernel + out = cute_dsl_mla_decode( + query=query, + kv_cache=kv_cache, + workspace_buffer=workspace_buffer, + kv_lora_rank=latent_dim, + qk_rope_head_dim=rope_dim, + block_tables=block_tables, + seq_lens=seq_lens, + max_seq_len=seq_len_k, + softmax_scale=softmax_scale, + output_scale=output_scale, + ) + + # Reference + kv_flat = kv_cache.reshape(-1, latent_dim + rope_dim) + c_latent_ref = kv_flat[:, :latent_dim] + c_rope_ref = kv_flat[:, latent_dim:] + q_nope = query[..., :latent_dim] + q_rope = query[..., latent_dim:] + + ref_out = torch_reference_mla( + q_nope, q_rope, c_latent_ref, c_rope_ref, + block_tables, seq_lens, softmax_scale, output_scale, page_size, + ) + + if q_len == 1: + ref_out = ref_out.squeeze(1) + + ref_out_fp16 = ref_out.to(torch.float16) + + # Check with tolerance appropriate for FP16 + torch.testing.assert_close(out, ref_out_fp16, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("seq_len_k", [128, 512]) +def test_cute_dsl_mla_decode_variable_seq_len(batch_size, seq_len_k, page_size=128): + """Test MLA decode with variable sequence lengths across the batch.""" + skip_if_unsupported() + + from flashinfer.cute_dsl.mla_decode import cute_dsl_mla_decode + + torch.manual_seed(42) + device = torch.device("cuda") + + num_heads = 128 + latent_dim = 512 + rope_dim = 64 + q_len = 1 + softmax_scale = 1.0 / (latent_dim**0.5) + output_scale = 1.0 + D_qk = latent_dim + rope_dim + + query = torch.randn(batch_size, q_len, num_heads, D_qk, dtype=torch.float16, device=device) + + # Variable sequence lengths + max_seq_len = seq_len_k + seq_lens = torch.randint( + page_size, max_seq_len + 1, (batch_size,), dtype=torch.int32, device=device + ) + + max_pages_per_batch = (max_seq_len + page_size - 1) // page_size + total_pages = max_pages_per_batch * batch_size + 10 + kv_cache = torch.randn(total_pages, page_size, D_qk, dtype=torch.float16, device=device) + + block_tables = torch.zeros(batch_size, max_pages_per_batch, dtype=torch.int32, device=device) + for b in range(batch_size): + for p in range(max_pages_per_batch): + block_tables[b, p] = b * max_pages_per_batch + p + + workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device=device) + + out = cute_dsl_mla_decode( + query=query, + kv_cache=kv_cache, + workspace_buffer=workspace_buffer, + kv_lora_rank=latent_dim, + qk_rope_head_dim=rope_dim, + block_tables=block_tables, + seq_lens=seq_lens, + max_seq_len=max_seq_len, + softmax_scale=softmax_scale, + output_scale=output_scale, + ) + + # Reference + kv_flat = kv_cache.reshape(-1, D_qk) + c_latent_ref = kv_flat[:, :latent_dim] + c_rope_ref = kv_flat[:, latent_dim:] + q_nope = query[..., :latent_dim] + q_rope = query[..., latent_dim:] + + ref_out = torch_reference_mla( + q_nope, q_rope, c_latent_ref, c_rope_ref, + block_tables, seq_lens, softmax_scale, output_scale, page_size, + ) + if q_len == 1: + ref_out = ref_out.squeeze(1) + ref_out_fp16 = ref_out.to(torch.float16) + + torch.testing.assert_close(out, ref_out_fp16, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("seq_len_k", [128, 512]) +def test_cute_dsl_mla_decode_via_api(batch_size, seq_len_k, page_size=128): + """Test MLA decode via the trtllm_batch_decode_with_kv_cache_mla API with cute-dsl backend.""" + skip_if_unsupported() + + from flashinfer.mla import trtllm_batch_decode_with_kv_cache_mla + + torch.manual_seed(42) + device = torch.device("cuda") + + num_heads = 128 + latent_dim = 512 + rope_dim = 64 + q_len = 1 + softmax_scale = 1.0 / (latent_dim**0.5) + D_qk = latent_dim + rope_dim + + query = torch.randn(batch_size, q_len, num_heads, D_qk, dtype=torch.float16, device=device) + + num_pages_per_batch = (seq_len_k + page_size - 1) // page_size + total_pages = num_pages_per_batch * batch_size + 10 + kv_cache = torch.randn(total_pages, page_size, D_qk, dtype=torch.float16, device=device) + + block_tables = torch.zeros(batch_size, num_pages_per_batch, dtype=torch.int32, device=device) + for b in range(batch_size): + for p in range(num_pages_per_batch): + block_tables[b, p] = b * num_pages_per_batch + p + + seq_lens = torch.full((batch_size,), seq_len_k, dtype=torch.int32, device=device) + workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device=device) + + out = trtllm_batch_decode_with_kv_cache_mla( + query=query, + kv_cache=kv_cache, + workspace_buffer=workspace_buffer, + qk_nope_head_dim=latent_dim, + kv_lora_rank=latent_dim, + qk_rope_head_dim=rope_dim, + block_tables=block_tables, + seq_lens=seq_lens, + max_seq_len=seq_len_k, + bmm1_scale=softmax_scale, + bmm2_scale=1.0, + backend="cute-dsl", + ) + + assert out.shape == (batch_size, num_heads, latent_dim) From d95da37a2c153b99f962beeb1f28cead1735080b Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Tue, 10 Mar 2026 00:52:29 -0700 Subject: [PATCH 02/31] style: Fix pre-commit lint/format/type errors in MLA decode kernel files --- flashinfer/cute_dsl/mla_decode.py | 15 ++--- flashinfer/cute_dsl/mla_decode_fp16.py | 17 +++--- flashinfer/cute_dsl/mla_decode_fp8.py | 14 ++--- flashinfer/mla.py | 8 ++- tests/attention/test_cute_dsl_mla_decode.py | 63 ++++++++++++++++----- 5 files changed, 76 insertions(+), 41 deletions(-) diff --git a/flashinfer/cute_dsl/mla_decode.py b/flashinfer/cute_dsl/mla_decode.py index 02834e3d45..fff9fc38c1 100644 --- a/flashinfer/cute_dsl/mla_decode.py +++ b/flashinfer/cute_dsl/mla_decode.py @@ -86,7 +86,7 @@ def _get_compiled_mla_kernel( sym_latent = cute.sym_int() sym_seq_q = cute.sym_int() sym_rope = cute.sym_int() - sym_batch = cute.sym_int() # query/output batch dimension + sym_batch = cute.sym_int() # query/output batch dimension sym_kv_batch = cute.sym_int() # KV cache batch dim (flat pool, =1 in paged mode) sym_seq_kv = cute.sym_int() sym_page_count = cute.sym_int() @@ -289,17 +289,18 @@ def cute_dsl_mla_decode( q_latent_k = q_nope.permute(2, 3, 1, 0) # [H, latent_dim, q_len, B], stride[1]=1 q_rope_k = q_rope.permute(2, 3, 1, 0) # [H, rope_dim, q_len, B], stride[1]=1 - # Total number of physical pages in the KV cache pool - num_pages = kv_cache.shape[0] - # Reshape KV cache to kernel layout [page_size, D, num_pages]. # The kernel indexes via page_table: for batch b, page p, offset t: # c_latent[t, d, page_table[p, b]] = token (page_table[p,b]*page_size + t)'s latent[d] # kv_cache: [num_pages, page_size, D_total] with strides (page_size*D_total, D_total, 1) # After permute(1, 2, 0) on latent slice: [page_size, latent_dim, num_pages] # strides = (D_total, 1, page_size*D_total) → stride[1]=1 ✓ - c_latent_k = kv_cache[:, :, :kv_lora_rank].permute(1, 2, 0) # [page_size, latent_dim, num_pages] - c_rope_k = kv_cache[:, :, kv_lora_rank:].permute(1, 2, 0) # [page_size, rope_dim, num_pages] + c_latent_k = kv_cache[:, :, :kv_lora_rank].permute( + 1, 2, 0 + ) # [page_size, latent_dim, num_pages] + c_rope_k = kv_cache[:, :, kv_lora_rank:].permute( + 1, 2, 0 + ) # [page_size, rope_dim, num_pages] # Page table: [B, max_pages] -> [max_pages, B] page_table_k = block_tables.t().contiguous().to(torch.int32) @@ -320,7 +321,7 @@ def cute_dsl_mla_decode( # Prepare workspace tensor if workspace_size > 0: - workspace_bytes = workspace_buffer[: workspace_size].contiguous() + workspace_bytes = workspace_buffer[:workspace_size].contiguous() else: workspace_bytes = workspace_buffer[:1].contiguous() diff --git a/flashinfer/cute_dsl/mla_decode_fp16.py b/flashinfer/cute_dsl/mla_decode_fp16.py index baae56b65b..c693ae18c5 100644 --- a/flashinfer/cute_dsl/mla_decode_fp16.py +++ b/flashinfer/cute_dsl/mla_decode_fp16.py @@ -26,15 +26,12 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import os -import sys import argparse import math from typing import Type, Tuple, Optional from types import SimpleNamespace import torch -import torch.nn.functional as F import cuda.bindings.driver as cuda import cutlass @@ -641,7 +638,7 @@ class SplitKVKernelSharedStorage: grid=grid, block=[self.threads_per_cta, 1, 1], cluster=self.cluster_shape_mnk, - smem=SplitKVKernelSharedStorage.size_in_bytes(), + smem=SplitKVKernelSharedStorage.size_in_bytes(), # type: ignore[attr-defined] stream=stream, min_blocks_per_mp=1, ) @@ -1331,7 +1328,7 @@ def reduction_kernel( # calculate the global_lse global_lse = ( lse_max + cute.math.log2(sum_lse, fastmath=True) - if not sum_lse == self.lse_dtype(0.0) or sum_lse != sum_lse + if not sum_lse == self.lse_dtype(0.0) or sum_lse != sum_lse # noqa: SIM201 else self.lse_dtype.inf ) if tidx == 0: @@ -1482,7 +1479,7 @@ def load_page_table( ) else: sPT_for_copy[None, idx, load_pt_producer_state.index].fill(0) - mbar_ptr = load_pt_pipeline.producer_get_barrier(load_pt_producer_state) + mbar_ptr = load_pt_pipeline.producer_get_barrier(load_pt_producer_state) # noqa: F841 load_pt_pipeline.producer_commit(load_pt_producer_state) load_pt_producer_state.advance() k_index += 1 @@ -2038,7 +2035,7 @@ def mma( mma_o_producer_state, ) - return ( + return ( # type: ignore[return-value] tiled_mma_qk, tiled_mma_pv, load_q_consumer_state, @@ -2798,7 +2795,7 @@ def _tmem_load_partition( tTR_gO = tmem_load_thr_copy.partition_D(gO) tTR_cO = tmem_load_thr_copy.partition_D(cO) tTR_rAcc = cute.make_fragment_like(tTR_gO, self.acc_dtype) - return tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc + return tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc # type: ignore[return-value] def get_correction_factor( self, @@ -4173,7 +4170,7 @@ def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: try: return tuple(int(x.strip()) for x in s.split(",")) except ValueError: - raise argparse.ArgumentTypeError( + raise argparse.ArgumentTypeError( # noqa: B904 "Invalid format. Expected comma-separated integers." ) @@ -4183,7 +4180,7 @@ def parse_mma_tiler(s: str) -> Tuple[int, int, Tuple[int, int]]: raise argparse.ArgumentTypeError( "Invalid format. Expected 2 comma-separated integers." ) - return (ret[0], ret[1]) + return (ret[0], ret[1]) # type: ignore[return-value] parser = argparse.ArgumentParser(description="Example of MLA on Blackwell.") diff --git a/flashinfer/cute_dsl/mla_decode_fp8.py b/flashinfer/cute_dsl/mla_decode_fp8.py index ae987df5ba..4372d8aaa2 100644 --- a/flashinfer/cute_dsl/mla_decode_fp8.py +++ b/flashinfer/cute_dsl/mla_decode_fp8.py @@ -26,8 +26,6 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import os -import sys import argparse import math from typing import Type, Tuple, Optional @@ -704,7 +702,7 @@ class SplitKVKernelSharedStorage: grid=grid, block=[self.threads_per_cta, 1, 1], cluster=self.cluster_shape_mnk, - smem=SplitKVKernelSharedStorage.size_in_bytes(), + smem=SplitKVKernelSharedStorage.size_in_bytes(), # type: ignore[attr-defined] stream=stream, min_blocks_per_mp=1, ) @@ -1394,7 +1392,7 @@ def reduction_kernel( # calculate the global_lse global_lse = ( lse_max + cute.math.log2(sum_lse, fastmath=True) - if not sum_lse == self.lse_dtype(0.0) or sum_lse != sum_lse + if not sum_lse == self.lse_dtype(0.0) or sum_lse != sum_lse # noqa: SIM201 else self.lse_dtype.inf ) if tidx == 0: @@ -2029,7 +2027,7 @@ def mma( mma_o_producer_state, ) - return ( + return ( # type: ignore[return-value] tiled_mma_qk, tiled_mma_pv, load_q_consumer_state, @@ -2794,7 +2792,7 @@ def _tmem_load_partition( tTR_gO = tmem_load_thr_copy.partition_D(gO) tTR_cO = tmem_load_thr_copy.partition_D(cO) tTR_rAcc = cute.make_fragment_like(tTR_gO, self.acc_dtype) - return tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc + return tmem_load_tiled_copy, tAcc, tTR_tAcc, tTR_gO, tTR_cO, tTR_rAcc # type: ignore[return-value] def get_correction_factor( self, @@ -4145,7 +4143,7 @@ def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: try: return tuple(int(x.strip()) for x in s.split(",")) except ValueError: - raise argparse.ArgumentTypeError( + raise argparse.ArgumentTypeError( # noqa: B904 "Invalid format. Expected comma-separated integers." ) @@ -4155,7 +4153,7 @@ def parse_mma_tiler(s: str) -> Tuple[int, int, Tuple[int, int]]: raise argparse.ArgumentTypeError( "Invalid format. Expected 2 comma-separated integers." ) - return (ret[0], ret[1]) + return (ret[0], ret[1]) # type: ignore[return-value] parser = argparse.ArgumentParser(description="Example of MLA on Blackwell.") diff --git a/flashinfer/mla.py b/flashinfer/mla.py index 1c449db70d..eaedd02165 100644 --- a/flashinfer/mla.py +++ b/flashinfer/mla.py @@ -780,8 +780,12 @@ def trtllm_batch_decode_with_kv_cache_mla( block_tables=block_tables, seq_lens=seq_lens, max_seq_len=max_seq_len, - softmax_scale=bmm1_scale if isinstance(bmm1_scale, float) else float(bmm1_scale.item()), - output_scale=bmm2_scale if isinstance(bmm2_scale, float) else float(bmm2_scale.item()), + softmax_scale=bmm1_scale + if isinstance(bmm1_scale, float) + else float(bmm1_scale.item()), + output_scale=bmm2_scale + if isinstance(bmm2_scale, float) + else float(bmm2_scale.item()), out=out, ) else: diff --git a/tests/attention/test_cute_dsl_mla_decode.py b/tests/attention/test_cute_dsl_mla_decode.py index 97e5bc6285..f1ad3988ec 100644 --- a/tests/attention/test_cute_dsl_mla_decode.py +++ b/tests/attention/test_cute_dsl_mla_decode.py @@ -54,7 +54,6 @@ def torch_reference_mla( page_size: int """ B, q_len, H, latent_dim = q_nope.shape - rope_dim = q_rope.shape[-1] outputs = [] for b in range(B): @@ -117,15 +116,25 @@ def test_cute_dsl_mla_decode_fp16(batch_size, seq_len_k, page_size): # Allocate query: [B, q_len, H, D_qk] D_qk = latent_dim + rope_dim - query = torch.randn(batch_size, q_len, num_heads, D_qk, dtype=torch.float16, device=device) + query = torch.randn( + batch_size, q_len, num_heads, D_qk, dtype=torch.float16, device=device + ) # Allocate paged KV cache num_pages_per_batch = (seq_len_k + page_size - 1) // page_size total_pages = num_pages_per_batch * batch_size + 10 # extra pages - kv_cache = torch.randn(total_pages, page_size, latent_dim + rope_dim, dtype=torch.float16, device=device) + kv_cache = torch.randn( + total_pages, + page_size, + latent_dim + rope_dim, + dtype=torch.float16, + device=device, + ) # Page table: [B, max_pages] — sequential assignment - block_tables = torch.zeros(batch_size, num_pages_per_batch, dtype=torch.int32, device=device) + block_tables = torch.zeros( + batch_size, num_pages_per_batch, dtype=torch.int32, device=device + ) for b in range(batch_size): for p in range(num_pages_per_batch): block_tables[b, p] = b * num_pages_per_batch + p @@ -158,8 +167,15 @@ def test_cute_dsl_mla_decode_fp16(batch_size, seq_len_k, page_size): q_rope = query[..., latent_dim:] ref_out = torch_reference_mla( - q_nope, q_rope, c_latent_ref, c_rope_ref, - block_tables, seq_lens, softmax_scale, output_scale, page_size, + q_nope, + q_rope, + c_latent_ref, + c_rope_ref, + block_tables, + seq_lens, + softmax_scale, + output_scale, + page_size, ) if q_len == 1: @@ -190,7 +206,9 @@ def test_cute_dsl_mla_decode_variable_seq_len(batch_size, seq_len_k, page_size=1 output_scale = 1.0 D_qk = latent_dim + rope_dim - query = torch.randn(batch_size, q_len, num_heads, D_qk, dtype=torch.float16, device=device) + query = torch.randn( + batch_size, q_len, num_heads, D_qk, dtype=torch.float16, device=device + ) # Variable sequence lengths max_seq_len = seq_len_k @@ -200,9 +218,13 @@ def test_cute_dsl_mla_decode_variable_seq_len(batch_size, seq_len_k, page_size=1 max_pages_per_batch = (max_seq_len + page_size - 1) // page_size total_pages = max_pages_per_batch * batch_size + 10 - kv_cache = torch.randn(total_pages, page_size, D_qk, dtype=torch.float16, device=device) + kv_cache = torch.randn( + total_pages, page_size, D_qk, dtype=torch.float16, device=device + ) - block_tables = torch.zeros(batch_size, max_pages_per_batch, dtype=torch.int32, device=device) + block_tables = torch.zeros( + batch_size, max_pages_per_batch, dtype=torch.int32, device=device + ) for b in range(batch_size): for p in range(max_pages_per_batch): block_tables[b, p] = b * max_pages_per_batch + p @@ -230,8 +252,15 @@ def test_cute_dsl_mla_decode_variable_seq_len(batch_size, seq_len_k, page_size=1 q_rope = query[..., latent_dim:] ref_out = torch_reference_mla( - q_nope, q_rope, c_latent_ref, c_rope_ref, - block_tables, seq_lens, softmax_scale, output_scale, page_size, + q_nope, + q_rope, + c_latent_ref, + c_rope_ref, + block_tables, + seq_lens, + softmax_scale, + output_scale, + page_size, ) if q_len == 1: ref_out = ref_out.squeeze(1) @@ -258,13 +287,19 @@ def test_cute_dsl_mla_decode_via_api(batch_size, seq_len_k, page_size=128): softmax_scale = 1.0 / (latent_dim**0.5) D_qk = latent_dim + rope_dim - query = torch.randn(batch_size, q_len, num_heads, D_qk, dtype=torch.float16, device=device) + query = torch.randn( + batch_size, q_len, num_heads, D_qk, dtype=torch.float16, device=device + ) num_pages_per_batch = (seq_len_k + page_size - 1) // page_size total_pages = num_pages_per_batch * batch_size + 10 - kv_cache = torch.randn(total_pages, page_size, D_qk, dtype=torch.float16, device=device) + kv_cache = torch.randn( + total_pages, page_size, D_qk, dtype=torch.float16, device=device + ) - block_tables = torch.zeros(batch_size, num_pages_per_batch, dtype=torch.int32, device=device) + block_tables = torch.zeros( + batch_size, num_pages_per_batch, dtype=torch.int32, device=device + ) for b in range(batch_size): for p in range(num_pages_per_batch): block_tables[b, p] = b * num_pages_per_batch + p From 5cc2f749b1908ee235493821ef41b12dd45849e4 Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Tue, 10 Mar 2026 01:00:00 -0700 Subject: [PATCH 03/31] chore: Update copyright year to 2026 --- flashinfer/cute_dsl/mla_decode.py | 2 +- tests/attention/test_cute_dsl_mla_decode.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flashinfer/cute_dsl/mla_decode.py b/flashinfer/cute_dsl/mla_decode.py index fff9fc38c1..f11ce89b72 100644 --- a/flashinfer/cute_dsl/mla_decode.py +++ b/flashinfer/cute_dsl/mla_decode.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 by FlashInfer team. +# Copyright (c) 2026 by FlashInfer team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/attention/test_cute_dsl_mla_decode.py b/tests/attention/test_cute_dsl_mla_decode.py index f1ad3988ec..a7f419db0c 100644 --- a/tests/attention/test_cute_dsl_mla_decode.py +++ b/tests/attention/test_cute_dsl_mla_decode.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 by FlashInfer team. +# Copyright (c) 2026 by FlashInfer team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 3c38f20cfe376d93c162e60be03cd54b5fb714d3 Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Tue, 10 Mar 2026 01:09:54 -0700 Subject: [PATCH 04/31] feat: Add dtype assertions and FP8 tests for cute_dsl_mla_decode --- flashinfer/cute_dsl/mla_decode.py | 7 ++ tests/attention/test_cute_dsl_mla_decode.py | 88 +++++++++++++++++++++ 2 files changed, 95 insertions(+) diff --git a/flashinfer/cute_dsl/mla_decode.py b/flashinfer/cute_dsl/mla_decode.py index f11ce89b72..ee8d4e4a43 100644 --- a/flashinfer/cute_dsl/mla_decode.py +++ b/flashinfer/cute_dsl/mla_decode.py @@ -264,6 +264,13 @@ def cute_dsl_mla_decode( torch.Tensor Output tensor [B, H, kv_lora_rank]. """ + assert query.dtype in ( + torch.float16, + torch.float8_e4m3fn, + ), f"cute_dsl_mla_decode only supports float16 and float8_e4m3fn, got {query.dtype}" + assert kv_cache.dtype == query.dtype, ( + f"kv_cache dtype {kv_cache.dtype} must match query dtype {query.dtype}" + ) B, q_len, H, D_qk = query.shape assert D_qk == kv_lora_rank + qk_rope_head_dim assert kv_lora_rank == _LATENT_DIM diff --git a/tests/attention/test_cute_dsl_mla_decode.py b/tests/attention/test_cute_dsl_mla_decode.py index a7f419db0c..8e9018e8a9 100644 --- a/tests/attention/test_cute_dsl_mla_decode.py +++ b/tests/attention/test_cute_dsl_mla_decode.py @@ -323,3 +323,91 @@ def test_cute_dsl_mla_decode_via_api(batch_size, seq_len_k, page_size=128): ) assert out.shape == (batch_size, num_heads, latent_dim) + + +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("seq_len_k", [128, 512]) +@pytest.mark.parametrize("page_size", [128]) +def test_cute_dsl_mla_decode_fp8(batch_size, seq_len_k, page_size): + """Test FP8 MLA decode kernel against FP32 reference.""" + skip_if_unsupported() + + from flashinfer.cute_dsl.mla_decode import cute_dsl_mla_decode + + torch.manual_seed(42) + device = torch.device("cuda") + + num_heads = 128 + latent_dim = 512 + rope_dim = 64 + q_len = 1 + softmax_scale = 1.0 / (latent_dim**0.5) + output_scale = 1.0 + D_qk = latent_dim + rope_dim + + # Create FP8 query and KV cache (cast from small-valued FP16 to stay in FP8 range) + query = ( + torch.randn( + batch_size, q_len, num_heads, D_qk, dtype=torch.float16, device=device + ) + * 0.1 + ).to(torch.float8_e4m3fn) + + num_pages_per_batch = (seq_len_k + page_size - 1) // page_size + total_pages = num_pages_per_batch * batch_size + 10 + kv_cache = ( + torch.randn(total_pages, page_size, D_qk, dtype=torch.float16, device=device) + * 0.1 + ).to(torch.float8_e4m3fn) + + block_tables = torch.zeros( + batch_size, num_pages_per_batch, dtype=torch.int32, device=device + ) + for b in range(batch_size): + for p in range(num_pages_per_batch): + block_tables[b, p] = b * num_pages_per_batch + p + + seq_lens = torch.full((batch_size,), seq_len_k, dtype=torch.int32, device=device) + workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device=device) + + out = cute_dsl_mla_decode( + query=query, + kv_cache=kv_cache, + workspace_buffer=workspace_buffer, + kv_lora_rank=latent_dim, + qk_rope_head_dim=rope_dim, + block_tables=block_tables, + seq_lens=seq_lens, + max_seq_len=seq_len_k, + softmax_scale=softmax_scale, + output_scale=output_scale, + ) + + assert out.dtype == torch.float8_e4m3fn + assert out.shape == (batch_size, num_heads, latent_dim) + + # Reference: compute in FP32 using FP8 values dequantized to FP32 + kv_flat = kv_cache.reshape(-1, D_qk).to(torch.float32) + c_latent_ref = kv_flat[:, :latent_dim] + c_rope_ref = kv_flat[:, latent_dim:] + q_nope = query[..., :latent_dim].to(torch.float32) + q_rope_tensor = query[..., latent_dim:].to(torch.float32) + + ref_out = torch_reference_mla( + q_nope, + q_rope_tensor, + c_latent_ref, + c_rope_ref, + block_tables, + seq_lens, + softmax_scale, + output_scale, + page_size, + ) + if q_len == 1: + ref_out = ref_out.squeeze(1) + + # Compare outputs in FP32; FP8 has limited precision so use wider tolerance + torch.testing.assert_close( + out.to(torch.float32), ref_out.to(torch.float32), atol=0.1, rtol=0.1 + ) From 7b12c25b3be804deb4498603eb624e5f249d08dd Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Tue, 10 Mar 2026 02:24:48 -0700 Subject: [PATCH 05/31] perf: Reduce host overhead in cute_dsl_mla_decode - Remove unnecessary .contiguous() on page_table transpose by changing fake tensor stride_order from (1,0) to (0,1), matching the original kernel's convention of non-contiguous permute(1,0) - Use torch.full instead of torch.ones * val for block_split_kvs - Remove redundant .contiguous() on workspace buffer slice - Remove redundant .to(int32).contiguous() when seq_lens is already int32 - Eliminate output copy_ by writing kernel output directly into caller's out tensor via permute view (works for both q_len=1 and q_len>1) - Fix output allocation order from (B,H,q_len,D) to (B,q_len,H,D) so permute back to user layout is naturally contiguous, removing .contiguous() - Cache split_kv and workspace_size computation via functools.cache - Remove tensor_api closure wrapper, call compiled_kernel directly - Add host overhead benchmark script Co-Authored-By: Claude Opus 4.6 --- flashinfer/cute_dsl/mla_decode.py | 204 +++++++--------- .../bench_cute_dsl_mla_host_overhead.py | 229 ++++++++++++++++++ 2 files changed, 318 insertions(+), 115 deletions(-) create mode 100644 tests/attention/bench_cute_dsl_mla_host_overhead.py diff --git a/flashinfer/cute_dsl/mla_decode.py b/flashinfer/cute_dsl/mla_decode.py index ee8d4e4a43..dc18f56f60 100644 --- a/flashinfer/cute_dsl/mla_decode.py +++ b/flashinfer/cute_dsl/mla_decode.py @@ -42,6 +42,24 @@ _SKIP_CORRECTION_THRESHOLD = 0.0 +@functools.cache +def _get_split_kv_and_workspace_size( + B: int, + q_len: int, + max_seq_len: int, + H: int, + max_active_blocks: int, +) -> Tuple[int, int]: + """Cache split_kv and workspace_size since they are deterministic for the same params.""" + split_kv = BlackwellMultiHeadLatentAttentionForwardFP16.get_split_kv( + B, q_len, max_seq_len, _MMA_QK_TILER_MN, max_active_blocks + ) + workspace_size = BlackwellMultiHeadLatentAttentionForwardFP16.get_workspace_size( + H, q_len, _LATENT_DIM, B, split_kv, cutlass.Float32 + ) + return split_kv, workspace_size + + @functools.cache def _get_compiled_mla_kernel( is_fp8: bool, @@ -51,11 +69,14 @@ def _get_compiled_mla_kernel( is_persistent: bool, is_var_seq: bool, is_var_split_kv: bool, -) -> Tuple[Callable, object]: +) -> Callable: """Compile and cache an MLA decode kernel. - Returns (compiled_kernel_closure, kernel_class_instance). - The kernel_class_instance is needed for get_split_kv() and get_workspace_size(). + Returns a callable that accepts (q_latent, q_rope, c_latent, c_rope, + page_table, o, lse, workspace, split_kv_scalar, cache_seqs, + block_split_kvs, softmax_scale_scalar, output_scale_scalar). + + All scalar arguments must be pre-wrapped as Int32/Float32. """ KernelClass = ( BlackwellMultiHeadLatentAttentionForwardFP8 @@ -122,11 +143,14 @@ def _get_compiled_mla_kernel( stride_order=(2, 0, 1), assumed_align=128, ) - # page_table: [page_count, batch_size] + # page_table: [page_count, batch_size] with stride[0]==1 + # Matches the original kernel's convention: page_table_ref.permute(1, 0) gives + # strides (1, page_count), so dim0(page_count) is the contiguous dimension. + # This allows passing block_tables.t() directly without .contiguous(). page_table_fake = cute.runtime.make_fake_compact_tensor( cutlass.Int32, (sym_page_count, sym_batch), - stride_order=(1, 0), + stride_order=(0, 1), assumed_align=128, ) # o: [num_heads, latent_dim, seq_len_q, batch_size] — stride[1]==1 @@ -184,39 +208,7 @@ def _get_compiled_mla_kernel( options="--enable-tvm-ffi", ) - def tensor_api( - q_latent: torch.Tensor, - q_rope: torch.Tensor, - c_latent: torch.Tensor, - c_rope: torch.Tensor, - page_table: torch.Tensor, - o: torch.Tensor, - lse: torch.Tensor, - workspace: torch.Tensor, - split_kv: int, - cache_seqs: torch.Tensor, - block_split_kvs: torch.Tensor, - softmax_scale: float, - output_scale: float, - ) -> None: - nonlocal compiled_kernel - compiled_kernel( - q_latent, - q_rope, - c_latent, - c_rope, - page_table, - o, - lse, - workspace, - Int32(split_kv), - cache_seqs, - block_split_kvs, - Float32(softmax_scale), - Float32(output_scale), - ) - - return tensor_api, kernel_obj + return compiled_kernel def cute_dsl_mla_decode( @@ -280,95 +272,78 @@ def cute_dsl_mla_decode( # Handle 3D vs 4D kv_cache: normalize to 3D [num_pages, page_size, D_total] if kv_cache.dim() == 4: - # [num_pages, 1, page_size, D_total] -> [num_pages, page_size, D_total] kv_cache = kv_cache.squeeze(1) page_size = kv_cache.shape[1] - D_total = kv_cache.shape[2] - assert D_total == kv_lora_rank + qk_rope_head_dim - - # Split query into latent and rope components - q_nope = query[..., :kv_lora_rank] # [B, q_len, H, latent_dim] - q_rope = query[..., kv_lora_rank:] # [B, q_len, H, rope_dim] - # Reshape to kernel layout: [B, q_len, H, D] -> [H, D, q_len, B] + # Split query into latent and rope components and reshape to kernel layout. + # [B, q_len, H, D] -> slice -> permute -> [H, D, q_len, B] with stride[1]=1. # Do NOT call .contiguous() — permute gives stride[1]=1 which the kernel requires. - # .contiguous() would rearrange to row-major making stride[3]=1 instead. - q_latent_k = q_nope.permute(2, 3, 1, 0) # [H, latent_dim, q_len, B], stride[1]=1 - q_rope_k = q_rope.permute(2, 3, 1, 0) # [H, rope_dim, q_len, B], stride[1]=1 + q_latent_k = query[..., :kv_lora_rank].permute(2, 3, 1, 0) + q_rope_k = query[..., kv_lora_rank:].permute(2, 3, 1, 0) # Reshape KV cache to kernel layout [page_size, D, num_pages]. - # The kernel indexes via page_table: for batch b, page p, offset t: - # c_latent[t, d, page_table[p, b]] = token (page_table[p,b]*page_size + t)'s latent[d] - # kv_cache: [num_pages, page_size, D_total] with strides (page_size*D_total, D_total, 1) - # After permute(1, 2, 0) on latent slice: [page_size, latent_dim, num_pages] - # strides = (D_total, 1, page_size*D_total) → stride[1]=1 ✓ - c_latent_k = kv_cache[:, :, :kv_lora_rank].permute( - 1, 2, 0 - ) # [page_size, latent_dim, num_pages] - c_rope_k = kv_cache[:, :, kv_lora_rank:].permute( - 1, 2, 0 - ) # [page_size, rope_dim, num_pages] - - # Page table: [B, max_pages] -> [max_pages, B] - page_table_k = block_tables.t().contiguous().to(torch.int32) - - # Determine split_kv and workspace - is_persistent = True - is_var_seq = True - is_var_split_kv = True - max_active_blocks = get_num_sm(query.device) + # The kernel indexes via page_table: c_latent[intra_page_offset, d, physical_page_idx]. + # After permute: strides = (D_total, 1, page_size*D_total) → stride[1]=1 ✓ + c_latent_k = kv_cache[:, :, :kv_lora_rank].permute(1, 2, 0) + c_rope_k = kv_cache[:, :, kv_lora_rank:].permute(1, 2, 0) - split_kv = BlackwellMultiHeadLatentAttentionForwardFP16.get_split_kv( - B, q_len, max_seq_len, _MMA_QK_TILER_MN, max_active_blocks - ) + # Page table: [B, max_pages] -> [max_pages, B] (view only, no copy needed). + # The kernel accepts non-contiguous strides via CuTe layout, matching the original + # kernel's convention of page_table_ref.permute(1, 0) without .contiguous(). + page_table_k = block_tables.permute(1, 0) - workspace_size = BlackwellMultiHeadLatentAttentionForwardFP16.get_workspace_size( - H, q_len, _LATENT_DIM, B, split_kv, cutlass.Float32 + # Cached split_kv and workspace_size computation + max_active_blocks = get_num_sm(query.device) + split_kv, workspace_size = _get_split_kv_and_workspace_size( + B, q_len, max_seq_len, H, max_active_blocks ) - # Prepare workspace tensor - if workspace_size > 0: - workspace_bytes = workspace_buffer[:workspace_size].contiguous() - else: - workspace_bytes = workspace_buffer[:1].contiguous() + # Prepare workspace — slice of contiguous 1D buffer is already contiguous + workspace_bytes = workspace_buffer[: max(workspace_size, 1)] - # Allocate output: [H, latent_dim, q_len, B] with stride[1]==1 - # torch.empty(B, H, q_len, D) has row-major strides (H*q_len*D, q_len*D, D, 1). - # After permute(1, 3, 2, 0) → shape [H, D, q_len, B] with strides (q_len*D, 1, D, H*q_len*D). - # Do NOT call .contiguous() — that would collapse to row-major making stride[3]=1. + # Output buffer setup: kernel needs [H, D, q_len, B] with stride[1]==1. + # If caller provides `out`, reuse it directly via permute to avoid allocation + copy_. + # q_len==1: out [B, H, D] → permute(1,2,0) → [H, D, B] → unsqueeze(2) → [H, D, 1, B] + # q_len >1: out [B, q_len, H, D] → permute(2,3,1,0) → [H, D, q_len, B] + # Both give stride[1]=1 ✓, kernel writes directly into out's memory. out_dtype = torch.float8_e4m3fn if is_fp8 else torch.float16 - o_k = torch.empty( - (B, H, q_len, _LATENT_DIM), dtype=out_dtype, device=query.device - ).permute(1, 3, 2, 0) # [H, latent_dim, q_len, B], stride[1]=1 - - # LSE: [H, q_len, B] with stride[0]==1 (H dim is contiguous). - # torch.empty(B, q_len, H) has row-major strides (q_len*H, H, 1). - # After permute(2, 1, 0) → shape [H, q_len, B] with strides (1, H, q_len*H). - # Do NOT call .contiguous() — that would make stride[2]=1 instead of stride[0]=1. + if out is not None: + if q_len == 1: + o_k = out.permute(1, 2, 0).unsqueeze(2) + else: + o_k = out.permute(2, 3, 1, 0) + else: + # Allocate as [B, q_len, H, D] so that permute back is already contiguous. + # permute(2, 3, 1, 0) → [H, D, q_len, B] with stride[1]=1 ✓ + o_k = torch.empty( + (B, q_len, H, _LATENT_DIM), dtype=out_dtype, device=query.device + ).permute(2, 3, 1, 0) + + # LSE: [H, q_len, B] with stride[0]==1 (H dim is contiguous) lse_k = torch.empty( (B, q_len, H), dtype=torch.float32, device=query.device - ).permute(2, 1, 0) # [H, q_len, B], stride[0]=1 - - # cache_seqs: per-batch sequence lengths - cache_seqs = seq_lens.to(torch.int32).contiguous() + ).permute(2, 1, 0) - # block_split_kvs: per-batch split_kv values - # Compute per-batch split_kv based on actual sequence lengths - block_split_kvs = torch.ones(B, dtype=torch.int32, device=query.device) * split_kv + # cache_seqs: per-batch sequence lengths (skip .to() if already int32) + cache_seqs = seq_lens if seq_lens.dtype == torch.int32 else seq_lens.to(torch.int32) + + # TOOD: this will trigger a kernel. + # block_split_kvs: uniform split_kv for all batches + block_split_kvs = torch.full((B,), split_kv, dtype=torch.int32, device=query.device) - # Get compiled kernel - tensor_api, kernel_cls = _get_compiled_mla_kernel( + # Get compiled kernel (cached after first compile) + compiled_kernel = _get_compiled_mla_kernel( is_fp8=is_fp8, page_size=page_size, num_heads=H, seq_len_q=q_len, - is_persistent=is_persistent, - is_var_seq=is_var_seq, - is_var_split_kv=is_var_split_kv, + is_persistent=True, + is_var_seq=True, + is_var_split_kv=True, ) # Call the kernel - tensor_api( + compiled_kernel( q_latent_k, q_rope_k, c_latent_k, @@ -376,23 +351,22 @@ def cute_dsl_mla_decode( page_table_k, o_k, lse_k, - workspace_bytes.view(torch.uint8), - split_kv, + workspace_bytes, + Int32(split_kv), cache_seqs, block_split_kvs, - softmax_scale, - output_scale, + Float32(softmax_scale), + Float32(output_scale), ) - # Reshape output: [H, latent_dim, q_len, B] -> [B, q_len, H, latent_dim] - result = o_k.permute(3, 2, 0, 1).contiguous() + # If out was provided, kernel already wrote into it — return directly. + if out is not None: + return out - # Squeeze q_len dimension if it's 1: [B, 1, H, D] -> [B, H, D] + # No out provided: reshape kernel output [H, D, q_len, B] -> [B, (q_len,) H, D] + # The permute back is always contiguous because we allocated as [B, q_len, H, D]. + result = o_k.permute(3, 2, 0, 1) if q_len == 1: result = result.squeeze(1) - if out is not None: - out.copy_(result) - return out - return result diff --git a/tests/attention/bench_cute_dsl_mla_host_overhead.py b/tests/attention/bench_cute_dsl_mla_host_overhead.py new file mode 100644 index 0000000000..79e74de236 --- /dev/null +++ b/tests/attention/bench_cute_dsl_mla_host_overhead.py @@ -0,0 +1,229 @@ +#!/usr/bin/env python3 +"""Benchmark host overhead of cute_dsl_mla_decode. + +Measures Python-side overhead by timing many iterations without GPU sync +between them — the GPU queue never drains so we're measuring purely +host-side work (tensor reshaping, TVM-FFI dispatch, etc.). +""" + +import time + +import torch + +from flashinfer.cute_dsl.mla_decode import cute_dsl_mla_decode + + +def bench_host_overhead( + batch_size: int = 4, + seq_len_k: int = 2048, + page_size: int = 128, + num_iters: int = 1000, + warmup_iters: int = 50, +): + device = torch.device("cuda") + num_heads = 128 + latent_dim = 512 + rope_dim = 64 + q_len = 1 + D_qk = latent_dim + rope_dim + softmax_scale = 1.0 / (latent_dim**0.5) + output_scale = 1.0 + + query = torch.randn(batch_size, q_len, num_heads, D_qk, dtype=torch.float16, device=device) + num_pages_per_batch = (seq_len_k + page_size - 1) // page_size + total_pages = num_pages_per_batch * batch_size + 10 + kv_cache = torch.randn(total_pages, page_size, D_qk, dtype=torch.float16, device=device) + + block_tables = torch.zeros(batch_size, num_pages_per_batch, dtype=torch.int32, device=device) + for b in range(batch_size): + for p in range(num_pages_per_batch): + block_tables[b, p] = b * num_pages_per_batch + p + + seq_lens = torch.full((batch_size,), seq_len_k, dtype=torch.int32, device=device) + workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device=device) + + # Warmup — includes compilation on first call + print("Warming up...") + for _ in range(warmup_iters): + cute_dsl_mla_decode( + query=query, kv_cache=kv_cache, workspace_buffer=workspace_buffer, + kv_lora_rank=latent_dim, qk_rope_head_dim=rope_dim, + block_tables=block_tables, seq_lens=seq_lens, max_seq_len=seq_len_k, + softmax_scale=softmax_scale, output_scale=output_scale, + ) + torch.cuda.synchronize() + + # Benchmark: no sync between iterations → measures host overhead only + print(f"Benchmarking {num_iters} iterations (no inter-iteration sync)...") + torch.cuda.synchronize() + t0 = time.perf_counter() + for _ in range(num_iters): + cute_dsl_mla_decode( + query=query, kv_cache=kv_cache, workspace_buffer=workspace_buffer, + kv_lora_rank=latent_dim, qk_rope_head_dim=rope_dim, + block_tables=block_tables, seq_lens=seq_lens, max_seq_len=seq_len_k, + softmax_scale=softmax_scale, output_scale=output_scale, + ) + torch.cuda.synchronize() + t1 = time.perf_counter() + + total_us = (t1 - t0) * 1e6 + per_call_us = total_us / num_iters + print(f"Total: {total_us:.0f} us for {num_iters} calls") + print(f"Per call: {per_call_us:.1f} us") + + # Also measure with line-level profiling of the key sections + print("\n--- Profiling individual sections ---") + profile_sections(query, kv_cache, workspace_buffer, latent_dim, rope_dim, + block_tables, seq_lens, seq_len_k, softmax_scale, output_scale, + num_iters=num_iters) + + return per_call_us + + +def profile_sections(query, kv_cache, workspace_buffer, kv_lora_rank, qk_rope_head_dim, + block_tables, seq_lens, max_seq_len, softmax_scale, output_scale, + num_iters=1000): + """Profile individual sections of cute_dsl_mla_decode to find hotspots.""" + from flashinfer.cute_dsl.mla_decode import ( + _get_compiled_mla_kernel, + _get_split_kv_and_workspace_size, + _LATENT_DIM, _ROPE_DIM, _MMA_QK_TILER_MN, _MAX_ACTIVE_CLUSTERS, + BlackwellMultiHeadLatentAttentionForwardFP16, + ) + from flashinfer.cute_dsl.utils import get_num_sm + from cutlass import Float32, Int32 + import cutlass + + B, q_len, H, D_qk = query.shape + page_size = kv_cache.shape[1] + is_fp8 = query.dtype == torch.float8_e4m3fn + device = query.device + + timings = {} + + def measure(name, fn): + torch.cuda.synchronize() + t0 = time.perf_counter() + for _ in range(num_iters): + fn() + torch.cuda.synchronize() + elapsed_us = (time.perf_counter() - t0) * 1e6 / num_iters + timings[name] = elapsed_us + + # 1. Query split + permute + def query_reshape(): + q_nope = query[..., :kv_lora_rank] + q_rope = query[..., kv_lora_rank:] + q_latent_k = q_nope.permute(2, 3, 1, 0) + q_rope_k = q_rope.permute(2, 3, 1, 0) + return q_latent_k, q_rope_k + measure("query_split+permute", query_reshape) + + # 2. KV cache split + permute + def kv_reshape(): + c_latent_k = kv_cache[:, :, :kv_lora_rank].permute(1, 2, 0) + c_rope_k = kv_cache[:, :, kv_lora_rank:].permute(1, 2, 0) + return c_latent_k, c_rope_k + measure("kv_split+permute", kv_reshape) + + # 3. Page table transpose + def page_table_transpose(): + return block_tables.t().contiguous().to(torch.int32) + measure("page_table_transpose", page_table_transpose) + + # 4. split_kv + workspace_size computation (cached) + max_active_blocks = get_num_sm(device) + def compute_split(): + return _get_split_kv_and_workspace_size( + B, q_len, max_seq_len, H, max_active_blocks + ) + measure("compute_split_kv+workspace(cached)", compute_split) + + # 5. Workspace slice + split_kv, workspace_size = compute_split() + def workspace_slice(): + return workspace_buffer[:max(workspace_size, 1)] + measure("workspace_slice(no .contiguous())", workspace_slice) + + # 6. Output + LSE allocation + out_dtype = torch.float8_e4m3fn if is_fp8 else torch.float16 + def alloc_output(): + o_k = torch.empty((B, H, q_len, _LATENT_DIM), dtype=out_dtype, device=device).permute(1, 3, 2, 0) + lse_k = torch.empty((B, q_len, H), dtype=torch.float32, device=device).permute(2, 1, 0) + return o_k, lse_k + measure("alloc_output+lse", alloc_output) + + # 7. cache_seqs + block_split_kvs creation + def create_aux_tensors(): + if seq_lens.dtype == torch.int32 and seq_lens.is_contiguous(): + cache_seqs = seq_lens + else: + cache_seqs = seq_lens.to(torch.int32).contiguous() + block_split_kvs = torch.full((B,), split_kv, dtype=torch.int32, device=device) + return cache_seqs, block_split_kvs + measure("create_aux_tensors(optimized)", create_aux_tensors) + + # 8. _get_compiled_mla_kernel (should be cached) + def get_kernel(): + return _get_compiled_mla_kernel( + is_fp8=is_fp8, page_size=page_size, num_heads=H, seq_len_q=q_len, + is_persistent=True, is_var_seq=True, is_var_split_kv=True, + ) + measure("get_compiled_kernel(cached)", get_kernel) + + # 9. Kernel call only (prepare everything, measure just the call) + compiled_kernel = get_kernel() + q_latent_k = query[..., :kv_lora_rank].permute(2, 3, 1, 0) + q_rope_k = query[..., kv_lora_rank:].permute(2, 3, 1, 0) + c_latent_k = kv_cache[:, :, :kv_lora_rank].permute(1, 2, 0) + c_rope_k = kv_cache[:, :, kv_lora_rank:].permute(1, 2, 0) + page_table_k = block_tables.t().contiguous() + o_k = torch.empty((B, H, q_len, _LATENT_DIM), dtype=out_dtype, device=device).permute(1, 3, 2, 0) + lse_k = torch.empty((B, q_len, H), dtype=torch.float32, device=device).permute(2, 1, 0) + ws = workspace_buffer[:max(workspace_size, 1)] + cache_seqs = seq_lens.to(torch.int32).contiguous() + block_split_kvs_t = torch.full((B,), split_kv, dtype=torch.int32, device=device) + split_kv_scalar = Int32(split_kv) + softmax_scale_scalar = Float32(softmax_scale) + output_scale_scalar = Float32(output_scale) + + def kernel_call_pre_cached(): + compiled_kernel( + q_latent_k, q_rope_k, c_latent_k, c_rope_k, page_table_k, + o_k, lse_k, ws, + split_kv_scalar, cache_seqs, block_split_kvs_t, + softmax_scale_scalar, output_scale_scalar, + ) + measure("kernel_call(pre-cached scalars)", kernel_call_pre_cached) + + def kernel_call_per_call(): + compiled_kernel( + q_latent_k, q_rope_k, c_latent_k, c_rope_k, page_table_k, + o_k, lse_k, ws, + Int32(split_kv), cache_seqs, block_split_kvs_t, + Float32(softmax_scale), Float32(output_scale), + ) + measure("kernel_call(per-call scalars)", kernel_call_per_call) + + # 10. Output reshape + def output_reshape(): + result = o_k.permute(3, 2, 0, 1).contiguous() + if q_len == 1: + result = result.squeeze(1) + return result + measure("output_reshape", output_reshape) + + # Print results + print(f"{'Section':<35} {'us/call':>10}") + print("-" * 47) + total = 0.0 + for name, us in timings.items(): + print(f" {name:<33} {us:>10.1f}") + total += us + print("-" * 47) + print(f" {'SUM':<33} {total:>10.1f}") + + +if __name__ == "__main__": + bench_host_overhead() From a8c3b5a72b9b817cd532d568377a0090a9131003 Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Tue, 10 Mar 2026 03:13:12 -0700 Subject: [PATCH 06/31] perf: Move permute logic into kernel __call__ to eliminate Python-side permutes Accept contiguous row-major tensors and reinterpret layouts inside the kernel's __call__ via zero-cost cute.make_tensor + cute.make_layout, removing ~10 us of Python-side .permute() overhead per call. Co-Authored-By: Claude Opus 4.6 --- flashinfer/cute_dsl/mla_decode.py | 107 ++++++++++------------ flashinfer/cute_dsl/mla_decode_fp16.py | 121 ++++++++++++++++--------- flashinfer/cute_dsl/mla_decode_fp8.py | 121 ++++++++++++++++--------- 3 files changed, 204 insertions(+), 145 deletions(-) diff --git a/flashinfer/cute_dsl/mla_decode.py b/flashinfer/cute_dsl/mla_decode.py index dc18f56f60..1a64491669 100644 --- a/flashinfer/cute_dsl/mla_decode.py +++ b/flashinfer/cute_dsl/mla_decode.py @@ -113,59 +113,60 @@ def _get_compiled_mla_kernel( sym_page_count = cute.sym_int() sym_workspace_size = cute.sym_int() - # q_latent: [num_heads, latent_dim, seq_len_q, batch_size] — stride[1]==1 + # All tensors use contiguous row-major layout (stride_order descending). + # The kernel's __call__ reinterprets them to the required layout via + # cute.make_tensor zero-cost metadata shuffle. + + # q_latent: [batch_size, seq_len_q, num_heads, latent_dim] — contiguous + # make_fake_compact_tensor stride_order: value 0 = fastest (stride=1) q_latent_fake = cute.runtime.make_fake_compact_tensor( cutlass_dtype, - (sym_heads, sym_latent, sym_seq_q, sym_batch), - stride_order=(3, 0, 2, 1), + (sym_batch, sym_seq_q, sym_heads, sym_latent), + stride_order=(3, 2, 1, 0), assumed_align=128, ) - # q_rope: [num_heads, rope_dim, seq_len_q, batch_size] — stride[1]==1 + # q_rope: [batch_size, seq_len_q, num_heads, rope_dim] — contiguous q_rope_fake = cute.runtime.make_fake_compact_tensor( cutlass_dtype, - (sym_heads, sym_rope, sym_seq_q, sym_batch), - stride_order=(3, 0, 2, 1), + (sym_batch, sym_seq_q, sym_heads, sym_rope), + stride_order=(3, 2, 1, 0), assumed_align=128, ) - # c_latent: [seq_len_k, latent_dim, kv_batch] — stride[1]==1 + # c_latent: [kv_batch, seq_len_k, latent_dim] — contiguous # kv_batch is a separate sym_int from query batch: paged KV cache uses a flat - # pool so kv_batch=1 at runtime, while query batch can be any value. + # pool so kv_batch=num_pages at runtime, while query batch can be any value. c_latent_fake = cute.runtime.make_fake_compact_tensor( cutlass_dtype, - (sym_seq_kv, sym_latent, sym_kv_batch), - stride_order=(2, 0, 1), + (sym_kv_batch, sym_seq_kv, sym_latent), + stride_order=(2, 1, 0), assumed_align=128, ) - # c_rope: [seq_len_k, rope_dim, kv_batch] — stride[1]==1 + # c_rope: [kv_batch, seq_len_k, rope_dim] — contiguous c_rope_fake = cute.runtime.make_fake_compact_tensor( cutlass_dtype, - (sym_seq_kv, sym_rope, sym_kv_batch), - stride_order=(2, 0, 1), + (sym_kv_batch, sym_seq_kv, sym_rope), + stride_order=(2, 1, 0), assumed_align=128, ) - # page_table: [page_count, batch_size] with stride[0]==1 - # Matches the original kernel's convention: page_table_ref.permute(1, 0) gives - # strides (1, page_count), so dim0(page_count) is the contiguous dimension. - # This allows passing block_tables.t() directly without .contiguous(). + # page_table: [batch_size, page_count] — contiguous page_table_fake = cute.runtime.make_fake_compact_tensor( cutlass.Int32, - (sym_page_count, sym_batch), - stride_order=(0, 1), + (sym_batch, sym_page_count), + stride_order=(1, 0), assumed_align=128, ) - # o: [num_heads, latent_dim, seq_len_q, batch_size] — stride[1]==1 + # o: [batch_size, seq_len_q, num_heads, latent_dim] — contiguous o_fake = cute.runtime.make_fake_compact_tensor( cutlass_dtype, - (sym_heads, sym_latent, sym_seq_q, sym_batch), - stride_order=(3, 0, 2, 1), + (sym_batch, sym_seq_q, sym_heads, sym_latent), + stride_order=(3, 2, 1, 0), assumed_align=128, ) - # lse: [num_heads, seq_len_q, batch_size] — stride[0]==1 (num_heads dim is contiguous) - # stride_order[d]=rank: dim0 rank=0 means dim0 is fastest → stride[0]=1 compile-time constant + # lse: [batch_size, seq_len_q, num_heads] — contiguous lse_fake = cute.runtime.make_fake_compact_tensor( cutlass.Float32, - (sym_heads, sym_seq_q, sym_batch), - stride_order=(0, 1, 2), + (sym_batch, sym_seq_q, sym_heads), + stride_order=(2, 1, 0), assumed_align=128, ) # workspace: 1-D @@ -275,22 +276,18 @@ def cute_dsl_mla_decode( kv_cache = kv_cache.squeeze(1) page_size = kv_cache.shape[1] - # Split query into latent and rope components and reshape to kernel layout. - # [B, q_len, H, D] -> slice -> permute -> [H, D, q_len, B] with stride[1]=1. - # Do NOT call .contiguous() — permute gives stride[1]=1 which the kernel requires. - q_latent_k = query[..., :kv_lora_rank].permute(2, 3, 1, 0) - q_rope_k = query[..., kv_lora_rank:].permute(2, 3, 1, 0) + # Split query into latent and rope components — keep contiguous [B, q_len, H, D]. + # The kernel's __call__ reinterprets to [H, D, q_len, B] via zero-cost make_tensor. + q_latent_k = query[..., :kv_lora_rank] + q_rope_k = query[..., kv_lora_rank:] - # Reshape KV cache to kernel layout [page_size, D, num_pages]. - # The kernel indexes via page_table: c_latent[intra_page_offset, d, physical_page_idx]. - # After permute: strides = (D_total, 1, page_size*D_total) → stride[1]=1 ✓ - c_latent_k = kv_cache[:, :, :kv_lora_rank].permute(1, 2, 0) - c_rope_k = kv_cache[:, :, kv_lora_rank:].permute(1, 2, 0) + # KV cache slices — keep contiguous [num_pages, page_size, D]. + # The kernel reinterprets to [page_size, D, num_pages] internally. + c_latent_k = kv_cache[:, :, :kv_lora_rank] + c_rope_k = kv_cache[:, :, kv_lora_rank:] - # Page table: [B, max_pages] -> [max_pages, B] (view only, no copy needed). - # The kernel accepts non-contiguous strides via CuTe layout, matching the original - # kernel's convention of page_table_ref.permute(1, 0) without .contiguous(). - page_table_k = block_tables.permute(1, 0) + # Page table: [B, max_pages] — passed directly, kernel reinterprets. + page_table_k = block_tables # Cached split_kv and workspace_size computation max_active_blocks = get_num_sm(query.device) @@ -301,33 +298,28 @@ def cute_dsl_mla_decode( # Prepare workspace — slice of contiguous 1D buffer is already contiguous workspace_bytes = workspace_buffer[: max(workspace_size, 1)] - # Output buffer setup: kernel needs [H, D, q_len, B] with stride[1]==1. - # If caller provides `out`, reuse it directly via permute to avoid allocation + copy_. - # q_len==1: out [B, H, D] → permute(1,2,0) → [H, D, B] → unsqueeze(2) → [H, D, 1, B] - # q_len >1: out [B, q_len, H, D] → permute(2,3,1,0) → [H, D, q_len, B] - # Both give stride[1]=1 ✓, kernel writes directly into out's memory. + # Output buffer: contiguous [B, q_len, H, D]. + # Kernel reinterprets to [H, D, q_len, B] internally via zero-cost make_tensor. out_dtype = torch.float8_e4m3fn if is_fp8 else torch.float16 if out is not None: if q_len == 1: - o_k = out.permute(1, 2, 0).unsqueeze(2) + o_k = out.unsqueeze(1) # [B, H, D] → [B, 1, H, D] else: - o_k = out.permute(2, 3, 1, 0) + o_k = out else: - # Allocate as [B, q_len, H, D] so that permute back is already contiguous. - # permute(2, 3, 1, 0) → [H, D, q_len, B] with stride[1]=1 ✓ o_k = torch.empty( (B, q_len, H, _LATENT_DIM), dtype=out_dtype, device=query.device - ).permute(2, 3, 1, 0) + ) - # LSE: [H, q_len, B] with stride[0]==1 (H dim is contiguous) + # LSE: contiguous [B, q_len, H]. Kernel reinterprets to [H, q_len, B]. lse_k = torch.empty( (B, q_len, H), dtype=torch.float32, device=query.device - ).permute(2, 1, 0) + ) # cache_seqs: per-batch sequence lengths (skip .to() if already int32) cache_seqs = seq_lens if seq_lens.dtype == torch.int32 else seq_lens.to(torch.int32) - # TOOD: this will trigger a kernel. + # TOOD: this will trigger a kernel. Need to remove it. # block_split_kvs: uniform split_kv for all batches block_split_kvs = torch.full((B,), split_kv, dtype=torch.int32, device=query.device) @@ -363,10 +355,7 @@ def cute_dsl_mla_decode( if out is not None: return out - # No out provided: reshape kernel output [H, D, q_len, B] -> [B, (q_len,) H, D] - # The permute back is always contiguous because we allocated as [B, q_len, H, D]. - result = o_k.permute(3, 2, 0, 1) + # o_k is already [B, q_len, H, D] contiguous — just squeeze for q_len==1. if q_len == 1: - result = result.squeeze(1) - - return result + return o_k.squeeze(1) + return o_k diff --git a/flashinfer/cute_dsl/mla_decode_fp16.py b/flashinfer/cute_dsl/mla_decode_fp16.py index c693ae18c5..5ca005b07c 100644 --- a/flashinfer/cute_dsl/mla_decode_fp16.py +++ b/flashinfer/cute_dsl/mla_decode_fp16.py @@ -296,19 +296,19 @@ def __call__( 5. Grid and work scheduling computation 6. Kernel launch(split KV kernel and reduction kernel) with appropriate parameters - :param q_latent: The query tensor with shape [num_head, latent_dim, seq_len_q, batch_size] + :param q_latent: The query tensor with shape [batch_size, seq_len_q, num_head, latent_dim] (contiguous) :type q_latent: cute.Tensor - :param q_rope: The query RoPE tensor with shape [num_head, rope_dim, seq_len_q, batch_size] + :param q_rope: The query RoPE tensor with shape [batch_size, seq_len_q, num_head, rope_dim] (contiguous) :type q_rope: cute.Tensor - :param c_latent: The key tensor with shape [seq_len_k, latent_dim, batch_size] + :param c_latent: The key tensor with shape [num_pages, page_size, latent_dim] (contiguous) :type c_latent: cute.Tensor - :param c_rope: The key RoPE tensor with shape [seq_len_k, rope_dim, batch_size] + :param c_rope: The key RoPE tensor with shape [num_pages, page_size, rope_dim] (contiguous) :type c_rope: cute.Tensor - :param page_table: The page table tensor with shape [page_count, batch_size] + :param page_table: The page table tensor with shape [batch_size, page_count] (contiguous) :type page_table: cute.Tensor - :param o: The output tensor with shape [num_head, latent_dim, seq_len_q, batch_size] + :param o: The output tensor with shape [batch_size, seq_len_q, num_head, latent_dim] (contiguous) :type o: cute.Tensor - :param lse: The LSE tensor with shape [num_head, seq_len_q, batch_size] + :param lse: The LSE tensor with shape [batch_size, seq_len_q, num_head] (contiguous) :type lse: cute.Tensor :param workspace: The workspace tensor with 1-d shape prepared for acc_o and acc_lse :type workspace: cute.Tensor @@ -341,15 +341,53 @@ def __call__( raise TypeError( f"Type mismatch: {self.q_dtype} != {self.k_dtype} or {self.q_dtype} != {self.v_dtype}" ) - # check leading dimensions of input/output - if cutlass.const_expr(q_latent.stride[1] != 1 or q_rope.stride[1] != 1): - raise ValueError("q_latent and q_rope must have leading dimension 1") - if cutlass.const_expr(c_latent.stride[1] != 1 or c_rope.stride[1] != 1): - raise ValueError("c_latent and c_rope must have leading dimension 1") - if cutlass.const_expr(o.stride[1] != 1): - raise ValueError("o must have leading dimension 1") - if cutlass.const_expr(lse.stride[0] != 1): - raise ValueError("lse must have leading dimension 0") + # Reinterpret contiguous [B, S_q, H, D] as [H, D, S_q, B] + # Input stride: (S_q*H*D, H*D, D, 1) → Target: (D, 1, H*D, S_q*H*D) + def _reinterpret_4d(t): + return cute.make_tensor( + t.iterator, + cute.make_layout( + (t.shape[2], t.shape[3], t.shape[1], t.shape[0]), + stride=(t.stride[2], t.stride[3], t.stride[1], t.stride[0]), + ), + ) + + q_latent = _reinterpret_4d(q_latent) + q_rope = _reinterpret_4d(q_rope) + o = _reinterpret_4d(o) + + # Reinterpret contiguous [num_pages, page_size, D] as [page_size, D, num_pages] + # Input stride: (PS*D, D, 1) → Target: (D, 1, PS*D) + def _reinterpret_3d_kv(t): + return cute.make_tensor( + t.iterator, + cute.make_layout( + (t.shape[1], t.shape[2], t.shape[0]), + stride=(t.stride[1], t.stride[2], t.stride[0]), + ), + ) + + c_latent = _reinterpret_3d_kv(c_latent) + c_rope = _reinterpret_3d_kv(c_rope) + + # Reinterpret contiguous [B, page_count] as [page_count, B] + page_table = cute.make_tensor( + page_table.iterator, + cute.make_layout( + (page_table.shape[1], page_table.shape[0]), + stride=(page_table.stride[1], page_table.stride[0]), + ), + ) + + # Reinterpret contiguous [B, S_q, H] as [H, S_q, B] + # Input stride: (S_q*H, H, 1) → Target: (1, H, S_q*H) + lse = cute.make_tensor( + lse.iterator, + cute.make_layout( + (lse.shape[2], lse.shape[1], lse.shape[0]), + stride=(lse.stride[2], lse.stride[1], lse.stride[0]), + ), + ) acc_o, acc_lse = self.initialize_workspace( q_latent.shape[0], @@ -3632,18 +3670,17 @@ def create_data_tensor( if seq_len_q is not None: shape = (B, seq_len_q, HK, D) - permute_order = (1, 2, 0) - stride_order = (2, 0, 1) - leading_dim = 1 + # Contiguous row-major: last dim has stride 1 (highest stride_order value = fastest) if is_lse: shape = (B, seq_len_q, HK) - permute_order = (2, 1, 0) - stride_order = (2, 1, 0) - leading_dim = 0 + leading_dim = 2 + stride_order = (0, 1, 2) elif seq_len_q is not None: - permute_order = (2, 3, 1, 0) - stride_order = (3, 2, 0, 1) - leading_dim = 1 + leading_dim = 3 + stride_order = (0, 1, 2, 3) + else: + leading_dim = 2 + stride_order = (0, 1, 2) init_config = cutlass.torch.RandomInitConfig(min_val=-2, max_val=2) @@ -3651,11 +3688,10 @@ def create_data_tensor( cutlass_torch.dtype(dtype) if dtype != cutlass.Float8E4M3FN else torch.int8 ) - # Create dtype torch tensor (cpu) + # Create contiguous dtype torch tensor (cpu) — no permute torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( shape, torch_dtype, - permute_order=permute_order, init_type=cutlass.torch.TensorInitType.RANDOM, init_config=init_config, ) @@ -3717,9 +3753,9 @@ def create_page_table(batch_size, seq_len_k, is_var_seq, page_size): for b in range(batch_size): for j in range(page_count): page_table_ref[b, j] = b + j * batch_size - page_table_gpu = page_table_ref.permute(1, 0).cuda() + page_table_gpu = page_table_ref.cuda() # contiguous [B, page_count] page_table = from_dlpack(page_table_gpu, assumed_align=16).mark_layout_dynamic( - leading_dim=0 + leading_dim=1 ) return page_table_ref, page_table, page_table_gpu @@ -3908,16 +3944,17 @@ def torch_reference_mla( softmax_scale=1.0, output_scale=1.0, ): - # expand and concat q_latent and q_rope to have the dimension of sequence length for q - q_ref = torch.cat([q_latent, q_rope], dim=1).permute(3, 2, 0, 1) - # expand and concat c_latent and c_rope to have the dimension of num_heads for k and v + # Ref tensors are now contiguous: + # q_latent/q_rope: [B, S_q, H, D] + # c_latent/c_rope: [num_pages, page_size, D] + # Concat along last dim and reshape for SDPA [B, S_q, H, D_total] + q_ref = torch.cat([q_latent, q_rope], dim=3) + # KV cache: concat along last dim, already [num_pages, page_size, D_total] page_count = page_table_ref.shape[1] - k_ref_paged = ( - torch.cat([c_latent, c_rope], dim=1) - .permute(2, 0, 1) - .reshape(batch_size * page_count, page_size, latent_dim + rope_dim) + k_ref_paged = torch.cat([c_latent, c_rope], dim=2).reshape( + batch_size * page_count, page_size, latent_dim + rope_dim ) - v_ref_paged = c_latent.permute(2, 0, 1).reshape( + v_ref_paged = c_latent.reshape( batch_size * page_count, page_size, latent_dim ) @@ -3956,9 +3993,9 @@ def torch_reference_mla( ) lse_ref = s_ref_max * softmax_scale_log2 + torch.log2(s_ref_sum) - lse_ref = lse_ref.squeeze(3).permute(2, 1, 0) + lse_ref = lse_ref.squeeze(3) # [B, S_q, H] o_ref = o_ref * output_scale - o_ref = o_ref.permute(2, 3, 1, 0) + # o_ref already [B, S_q, H, D_latent] — matches contiguous output layout return o_ref, lse_ref @@ -4015,15 +4052,13 @@ def torch_reference_mla( cute.testing.convert(o, o_fp32) o = o_fp32_torch.cpu() ref_fp8, _ = cutlass_torch.cute_tensor_like( - torch.empty( - *o_ref.permute(3, 2, 0, 1).shape, dtype=torch.uint8 - ).permute(2, 3, 1, 0), + torch.empty(*o_ref.shape, dtype=torch.uint8), out_dtype, is_dynamic_layout=True, assumed_align=16, ) o_ref_gpu = o_ref.cuda() - o_ref_f32 = from_dlpack(o_ref_gpu).mark_layout_dynamic(leading_dim=1) + o_ref_f32 = from_dlpack(o_ref_gpu).mark_layout_dynamic(leading_dim=3) # convert ref : f32 -> fp8 -> f32 cute.testing.convert(o_ref_f32, ref_fp8) diff --git a/flashinfer/cute_dsl/mla_decode_fp8.py b/flashinfer/cute_dsl/mla_decode_fp8.py index 4372d8aaa2..121f05585a 100644 --- a/flashinfer/cute_dsl/mla_decode_fp8.py +++ b/flashinfer/cute_dsl/mla_decode_fp8.py @@ -293,19 +293,19 @@ def __call__( 5. Grid and work scheduling computation 6. Kernel launch(split KV kernel and reduction kernel) with appropriate parameters - :param q_latent: The query tensor with shape [num_head, latent_dim, seq_len_q, batch_size] + :param q_latent: The query tensor with shape [batch_size, seq_len_q, num_head, latent_dim] (contiguous) :type q_latent: cute.Tensor - :param q_rope: The query RoPE tensor with shape [num_head, rope_dim, seq_len_q, batch_size] + :param q_rope: The query RoPE tensor with shape [batch_size, seq_len_q, num_head, rope_dim] (contiguous) :type q_rope: cute.Tensor - :param c_latent: The key tensor with shape [seq_len_k, latent_dim, batch_size] + :param c_latent: The key tensor with shape [num_pages, page_size, latent_dim] (contiguous) :type c_latent: cute.Tensor - :param c_rope: The key RoPE tensor with shape [seq_len_k, rope_dim, batch_size] + :param c_rope: The key RoPE tensor with shape [num_pages, page_size, rope_dim] (contiguous) :type c_rope: cute.Tensor - :param page_table: The page table tensor with shape [page_count, batch_size] + :param page_table: The page table tensor with shape [batch_size, page_count] (contiguous) :type page_table: cute.Tensor - :param o: The output tensor with shape [num_head, latent_dim, seq_len_q, batch_size] + :param o: The output tensor with shape [batch_size, seq_len_q, num_head, latent_dim] (contiguous) :type o: cute.Tensor - :param lse: The LSE tensor with shape [num_head, seq_len_q, batch_size] + :param lse: The LSE tensor with shape [batch_size, seq_len_q, num_head] (contiguous) :type lse: cute.Tensor :param workspace: The workspace tensor with 1-d shape prepared for acc_o and acc_lse :type workspace: cute.Tensor @@ -338,15 +338,53 @@ def __call__( raise TypeError( f"Type mismatch: {self.q_dtype} != {self.k_dtype} or {self.q_dtype} != {self.v_dtype}" ) - # check leading dimensions of input/output - if cutlass.const_expr(q_latent.stride[1] != 1 or q_rope.stride[1] != 1): - raise ValueError("q_latent and q_rope must have leading dimension 1") - if cutlass.const_expr(c_latent.stride[1] != 1 or c_rope.stride[1] != 1): - raise ValueError("c_latent and c_rope must have leading dimension 1") - if cutlass.const_expr(o.stride[1] != 1): - raise ValueError("o must have leading dimension 1") - if cutlass.const_expr(lse.stride[0] != 1): - raise ValueError("lse must have leading dimension 0") + # Reinterpret contiguous [B, S_q, H, D] as [H, D, S_q, B] + # Input stride: (S_q*H*D, H*D, D, 1) → Target: (D, 1, H*D, S_q*H*D) + def _reinterpret_4d(t): + return cute.make_tensor( + t.iterator, + cute.make_layout( + (t.shape[2], t.shape[3], t.shape[1], t.shape[0]), + stride=(t.stride[2], t.stride[3], t.stride[1], t.stride[0]), + ), + ) + + q_latent = _reinterpret_4d(q_latent) + q_rope = _reinterpret_4d(q_rope) + o = _reinterpret_4d(o) + + # Reinterpret contiguous [num_pages, page_size, D] as [page_size, D, num_pages] + # Input stride: (PS*D, D, 1) → Target: (D, 1, PS*D) + def _reinterpret_3d_kv(t): + return cute.make_tensor( + t.iterator, + cute.make_layout( + (t.shape[1], t.shape[2], t.shape[0]), + stride=(t.stride[1], t.stride[2], t.stride[0]), + ), + ) + + c_latent = _reinterpret_3d_kv(c_latent) + c_rope = _reinterpret_3d_kv(c_rope) + + # Reinterpret contiguous [B, page_count] as [page_count, B] + page_table = cute.make_tensor( + page_table.iterator, + cute.make_layout( + (page_table.shape[1], page_table.shape[0]), + stride=(page_table.stride[1], page_table.stride[0]), + ), + ) + + # Reinterpret contiguous [B, S_q, H] as [H, S_q, B] + # Input stride: (S_q*H, H, 1) → Target: (1, H, S_q*H) + lse = cute.make_tensor( + lse.iterator, + cute.make_layout( + (lse.shape[2], lse.shape[1], lse.shape[0]), + stride=(lse.stride[2], lse.stride[1], lse.stride[0]), + ), + ) acc_o, acc_lse = self.initialize_workspace( q_latent.shape[0], @@ -3607,18 +3645,17 @@ def create_data_tensor( if seq_len_q is not None: shape = (B, seq_len_q, HK, D) - permute_order = (1, 2, 0) - stride_order = (2, 0, 1) - leading_dim = 1 + # Contiguous row-major: last dim has stride 1 (highest stride_order value = fastest) if is_lse: shape = (B, seq_len_q, HK) - permute_order = (2, 1, 0) - stride_order = (2, 1, 0) - leading_dim = 0 + leading_dim = 2 + stride_order = (0, 1, 2) elif seq_len_q is not None: - permute_order = (2, 3, 1, 0) - stride_order = (3, 2, 0, 1) - leading_dim = 1 + leading_dim = 3 + stride_order = (0, 1, 2, 3) + else: + leading_dim = 2 + stride_order = (0, 1, 2) init_config = cutlass.torch.RandomInitConfig(min_val=-2, max_val=2) @@ -3626,11 +3663,10 @@ def create_data_tensor( cutlass_torch.dtype(dtype) if dtype != cutlass.Float8E4M3FN else torch.int8 ) - # Create dtype torch tensor (cpu) + # Create contiguous dtype torch tensor (cpu) — no permute torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( shape, torch_dtype, - permute_order=permute_order, init_type=cutlass.torch.TensorInitType.RANDOM, init_config=init_config, ) @@ -3692,9 +3728,9 @@ def create_page_table(batch_size, seq_len_k, is_var_seq, page_size): for b in range(batch_size): for j in range(page_count): page_table_ref[b, j] = b + j * batch_size - page_table_gpu = page_table_ref.permute(1, 0).cuda() + page_table_gpu = page_table_ref.cuda() # contiguous [B, page_count] page_table = from_dlpack(page_table_gpu, assumed_align=16).mark_layout_dynamic( - leading_dim=0 + leading_dim=1 ) return page_table_ref, page_table, page_table_gpu @@ -3881,16 +3917,17 @@ def torch_reference_mla( softmax_scale=1.0, output_scale=1.0, ): - # expand and concat q_latent and q_rope to have the dimension of sequence length for q - q_ref = torch.cat([q_latent, q_rope], dim=1).permute(3, 2, 0, 1) - # expand and concat c_latent and c_rope to have the dimension of num_heads for k and v + # Ref tensors are now contiguous: + # q_latent/q_rope: [B, S_q, H, D] + # c_latent/c_rope: [num_pages, page_size, D] + # Concat along last dim and reshape for SDPA [B, S_q, H, D_total] + q_ref = torch.cat([q_latent, q_rope], dim=3) + # KV cache: concat along last dim, already [num_pages, page_size, D_total] page_count = page_table_ref.shape[1] - k_ref_paged = ( - torch.cat([c_latent, c_rope], dim=1) - .permute(2, 0, 1) - .reshape(batch_size * page_count, page_size, latent_dim + rope_dim) + k_ref_paged = torch.cat([c_latent, c_rope], dim=2).reshape( + batch_size * page_count, page_size, latent_dim + rope_dim ) - v_ref_paged = c_latent.permute(2, 0, 1).reshape( + v_ref_paged = c_latent.reshape( batch_size * page_count, page_size, latent_dim ) @@ -3929,9 +3966,9 @@ def torch_reference_mla( ) lse_ref = s_ref_max * softmax_scale_log2 + torch.log2(s_ref_sum) - lse_ref = lse_ref.squeeze(3).permute(2, 1, 0) + lse_ref = lse_ref.squeeze(3) # [B, S_q, H] o_ref = o_ref * output_scale - o_ref = o_ref.permute(2, 3, 1, 0) + # o_ref already [B, S_q, H, D_latent] — matches contiguous output layout return o_ref, lse_ref @@ -3988,15 +4025,13 @@ def torch_reference_mla( cute.testing.convert(o, o_fp32) o = o_fp32_torch.cpu() ref_fp8, _ = cutlass_torch.cute_tensor_like( - torch.empty( - *o_ref.permute(3, 2, 0, 1).shape, dtype=torch.uint8 - ).permute(2, 3, 1, 0), + torch.empty(*o_ref.shape, dtype=torch.uint8), out_dtype, is_dynamic_layout=True, assumed_align=16, ) o_ref_gpu = o_ref.cuda() - o_ref_f32 = from_dlpack(o_ref_gpu).mark_layout_dynamic(leading_dim=1) + o_ref_f32 = from_dlpack(o_ref_gpu).mark_layout_dynamic(leading_dim=3) # convert ref : f32 -> fp8 -> f32 cute.testing.convert(o_ref_f32, ref_fp8) From 7420b6b124e9708e007751ddd9eb3ab1744aed35 Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Tue, 10 Mar 2026 04:21:51 -0700 Subject: [PATCH 07/31] feat: Add is_var_split_kv parameter and workspace size check to cute_dsl_mla_decode - Expose is_var_split_kv as a public parameter (default False) to control whether to use per-batch variable split_kv or uniform scalar split_kv, avoiding a torch.full GPU kernel (~5 us) when not needed. - Add workspace_buffer size assertion to catch undersized buffers early. Co-Authored-By: Claude Opus 4.6 --- flashinfer/cute_dsl/mla_decode.py | 45 ++++++++++++++++++++++--------- 1 file changed, 33 insertions(+), 12 deletions(-) diff --git a/flashinfer/cute_dsl/mla_decode.py b/flashinfer/cute_dsl/mla_decode.py index 1a64491669..4ac40735b5 100644 --- a/flashinfer/cute_dsl/mla_decode.py +++ b/flashinfer/cute_dsl/mla_decode.py @@ -181,12 +181,15 @@ def _get_compiled_mla_kernel( (sym_batch,), assumed_align=128, ) - # block_split_kvs: [batch_size] — int32 - block_split_kvs_fake = cute.runtime.make_fake_compact_tensor( - cutlass.Int32, - (sym_batch,), - assumed_align=128, - ) + # block_split_kvs: [batch_size] — int32 (only needed for is_var_split_kv=True) + if is_var_split_kv: + block_split_kvs_fake = cute.runtime.make_fake_compact_tensor( + cutlass.Int32, + (sym_batch,), + assumed_align=128, + ) + else: + block_split_kvs_fake = None stream_fake = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) @@ -211,7 +214,10 @@ def _get_compiled_mla_kernel( return compiled_kernel - +# TODO: need to tell users the max size of the workspace in the doc, so that they can allocate the workspace_buffer. +# TODO: how to set split_kv, is_persistent, is_var_seq, is_var_split_kv? +# TODO: check if page_size setup is right. +# TODO: query[..., :kv_lora_rank], do we need to remove such kind of slice and move the logic to call routine in the kernel file. def cute_dsl_mla_decode( query: torch.Tensor, kv_cache: torch.Tensor, @@ -223,6 +229,7 @@ def cute_dsl_mla_decode( max_seq_len: int, softmax_scale: float, output_scale: float = 1.0, + is_var_split_kv: bool = False, out: Optional[torch.Tensor] = None, ) -> torch.Tensor: """CuTe DSL MLA decode kernel for Blackwell SM100. @@ -249,6 +256,10 @@ def cute_dsl_mla_decode( Scale factor for QK^T before softmax. output_scale : float Scale factor applied to the output. + is_var_split_kv : bool + Whether to use variable split_kv per batch. When False (default), + uses a uniform scalar split_kv, avoiding a torch.full GPU kernel. + When True, allocates a per-batch block_split_kvs tensor. out : Optional[torch.Tensor] Pre-allocated output tensor [B, H, kv_lora_rank]. @@ -296,6 +307,10 @@ def cute_dsl_mla_decode( ) # Prepare workspace — slice of contiguous 1D buffer is already contiguous + assert workspace_buffer.numel() >= workspace_size, ( + f"workspace_buffer too small: {workspace_buffer.numel()} bytes, " + f"need {workspace_size} bytes" + ) workspace_bytes = workspace_buffer[: max(workspace_size, 1)] # Output buffer: contiguous [B, q_len, H, D]. @@ -318,10 +333,16 @@ def cute_dsl_mla_decode( # cache_seqs: per-batch sequence lengths (skip .to() if already int32) cache_seqs = seq_lens if seq_lens.dtype == torch.int32 else seq_lens.to(torch.int32) - - # TOOD: this will trigger a kernel. Need to remove it. - # block_split_kvs: uniform split_kv for all batches - block_split_kvs = torch.full((B,), split_kv, dtype=torch.int32, device=query.device) + + # block_split_kvs: only needed when is_var_split_kv=True + if is_var_split_kv: + # TODO: this will trigger a kernel. + # TODO: need to align with the test in kernel file. + block_split_kvs = torch.full( + (B,), split_kv, dtype=torch.int32, device=query.device + ) + else: + block_split_kvs = None # Get compiled kernel (cached after first compile) compiled_kernel = _get_compiled_mla_kernel( @@ -331,7 +352,7 @@ def cute_dsl_mla_decode( seq_len_q=q_len, is_persistent=True, is_var_seq=True, - is_var_split_kv=True, + is_var_split_kv=is_var_split_kv, ) # Call the kernel From c480145465543bdfeea27662e922af46942e75f4 Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Tue, 10 Mar 2026 04:24:36 -0700 Subject: [PATCH 08/31] style: Fix trailing whitespace and ruff formatting Co-Authored-By: Claude Opus 4.6 --- flashinfer/cute_dsl/mla_decode.py | 7 +++---- flashinfer/cute_dsl/mla_decode_fp16.py | 5 ++--- flashinfer/cute_dsl/mla_decode_fp8.py | 5 ++--- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/flashinfer/cute_dsl/mla_decode.py b/flashinfer/cute_dsl/mla_decode.py index 4ac40735b5..e39e68f27e 100644 --- a/flashinfer/cute_dsl/mla_decode.py +++ b/flashinfer/cute_dsl/mla_decode.py @@ -214,6 +214,7 @@ def _get_compiled_mla_kernel( return compiled_kernel + # TODO: need to tell users the max size of the workspace in the doc, so that they can allocate the workspace_buffer. # TODO: how to set split_kv, is_persistent, is_var_seq, is_var_split_kv? # TODO: check if page_size setup is right. @@ -327,16 +328,14 @@ def cute_dsl_mla_decode( ) # LSE: contiguous [B, q_len, H]. Kernel reinterprets to [H, q_len, B]. - lse_k = torch.empty( - (B, q_len, H), dtype=torch.float32, device=query.device - ) + lse_k = torch.empty((B, q_len, H), dtype=torch.float32, device=query.device) # cache_seqs: per-batch sequence lengths (skip .to() if already int32) cache_seqs = seq_lens if seq_lens.dtype == torch.int32 else seq_lens.to(torch.int32) # block_split_kvs: only needed when is_var_split_kv=True if is_var_split_kv: - # TODO: this will trigger a kernel. + # TODO: this will trigger a kernel. # TODO: need to align with the test in kernel file. block_split_kvs = torch.full( (B,), split_kv, dtype=torch.int32, device=query.device diff --git a/flashinfer/cute_dsl/mla_decode_fp16.py b/flashinfer/cute_dsl/mla_decode_fp16.py index 5ca005b07c..3806d41be6 100644 --- a/flashinfer/cute_dsl/mla_decode_fp16.py +++ b/flashinfer/cute_dsl/mla_decode_fp16.py @@ -341,6 +341,7 @@ def __call__( raise TypeError( f"Type mismatch: {self.q_dtype} != {self.k_dtype} or {self.q_dtype} != {self.v_dtype}" ) + # Reinterpret contiguous [B, S_q, H, D] as [H, D, S_q, B] # Input stride: (S_q*H*D, H*D, D, 1) → Target: (D, 1, H*D, S_q*H*D) def _reinterpret_4d(t): @@ -3954,9 +3955,7 @@ def torch_reference_mla( k_ref_paged = torch.cat([c_latent, c_rope], dim=2).reshape( batch_size * page_count, page_size, latent_dim + rope_dim ) - v_ref_paged = c_latent.reshape( - batch_size * page_count, page_size, latent_dim - ) + v_ref_paged = c_latent.reshape(batch_size * page_count, page_size, latent_dim) if is_var_seq: max_seq_len = torch.max(cache_seqs_ref) diff --git a/flashinfer/cute_dsl/mla_decode_fp8.py b/flashinfer/cute_dsl/mla_decode_fp8.py index 121f05585a..6ad55d99dd 100644 --- a/flashinfer/cute_dsl/mla_decode_fp8.py +++ b/flashinfer/cute_dsl/mla_decode_fp8.py @@ -338,6 +338,7 @@ def __call__( raise TypeError( f"Type mismatch: {self.q_dtype} != {self.k_dtype} or {self.q_dtype} != {self.v_dtype}" ) + # Reinterpret contiguous [B, S_q, H, D] as [H, D, S_q, B] # Input stride: (S_q*H*D, H*D, D, 1) → Target: (D, 1, H*D, S_q*H*D) def _reinterpret_4d(t): @@ -3927,9 +3928,7 @@ def torch_reference_mla( k_ref_paged = torch.cat([c_latent, c_rope], dim=2).reshape( batch_size * page_count, page_size, latent_dim + rope_dim ) - v_ref_paged = c_latent.reshape( - batch_size * page_count, page_size, latent_dim - ) + v_ref_paged = c_latent.reshape(batch_size * page_count, page_size, latent_dim) if is_var_seq: max_seq_len = torch.max(cache_seqs_ref) From b3b0f8b97eb31d194f22e057163bfc1f31d73756 Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Tue, 10 Mar 2026 19:33:30 -0700 Subject: [PATCH 09/31] perf: Simplify split_kv computation and remove is_var_split_kv parameter - Add get_split_kv_simplified() that computes split_kv without max_seq_len - Remove is_var_split_kv from public API (hardcode False), eliminating torch.full GPU kernel overhead per call - Remove unused bench_cute_dsl_mla_host_overhead.py Co-Authored-By: Claude Opus 4.6 --- flashinfer/cute_dsl/mla_decode.py | 23 +++++------------------ flashinfer/cute_dsl/mla_decode_fp16.py | 8 ++++++++ flashinfer/cute_dsl/mla_decode_fp8.py | 8 ++++++++ 3 files changed, 21 insertions(+), 18 deletions(-) diff --git a/flashinfer/cute_dsl/mla_decode.py b/flashinfer/cute_dsl/mla_decode.py index e39e68f27e..b9b7b4d007 100644 --- a/flashinfer/cute_dsl/mla_decode.py +++ b/flashinfer/cute_dsl/mla_decode.py @@ -46,13 +46,12 @@ def _get_split_kv_and_workspace_size( B: int, q_len: int, - max_seq_len: int, H: int, max_active_blocks: int, ) -> Tuple[int, int]: """Cache split_kv and workspace_size since they are deterministic for the same params.""" - split_kv = BlackwellMultiHeadLatentAttentionForwardFP16.get_split_kv( - B, q_len, max_seq_len, _MMA_QK_TILER_MN, max_active_blocks + split_kv = BlackwellMultiHeadLatentAttentionForwardFP16.get_split_kv_simplified( + B, q_len, max_active_blocks ) workspace_size = BlackwellMultiHeadLatentAttentionForwardFP16.get_workspace_size( H, q_len, _LATENT_DIM, B, split_kv, cutlass.Float32 @@ -230,7 +229,6 @@ def cute_dsl_mla_decode( max_seq_len: int, softmax_scale: float, output_scale: float = 1.0, - is_var_split_kv: bool = False, out: Optional[torch.Tensor] = None, ) -> torch.Tensor: """CuTe DSL MLA decode kernel for Blackwell SM100. @@ -257,10 +255,6 @@ def cute_dsl_mla_decode( Scale factor for QK^T before softmax. output_scale : float Scale factor applied to the output. - is_var_split_kv : bool - Whether to use variable split_kv per batch. When False (default), - uses a uniform scalar split_kv, avoiding a torch.full GPU kernel. - When True, allocates a per-batch block_split_kvs tensor. out : Optional[torch.Tensor] Pre-allocated output tensor [B, H, kv_lora_rank]. @@ -304,7 +298,7 @@ def cute_dsl_mla_decode( # Cached split_kv and workspace_size computation max_active_blocks = get_num_sm(query.device) split_kv, workspace_size = _get_split_kv_and_workspace_size( - B, q_len, max_seq_len, H, max_active_blocks + B, q_len, H, max_active_blocks ) # Prepare workspace — slice of contiguous 1D buffer is already contiguous @@ -333,15 +327,8 @@ def cute_dsl_mla_decode( # cache_seqs: per-batch sequence lengths (skip .to() if already int32) cache_seqs = seq_lens if seq_lens.dtype == torch.int32 else seq_lens.to(torch.int32) - # block_split_kvs: only needed when is_var_split_kv=True - if is_var_split_kv: - # TODO: this will trigger a kernel. - # TODO: need to align with the test in kernel file. - block_split_kvs = torch.full( - (B,), split_kv, dtype=torch.int32, device=query.device - ) - else: - block_split_kvs = None + is_var_split_kv = False + block_split_kvs = None # Get compiled kernel (cached after first compile) compiled_kernel = _get_compiled_mla_kernel( diff --git a/flashinfer/cute_dsl/mla_decode_fp16.py b/flashinfer/cute_dsl/mla_decode_fp16.py index 3806d41be6..905654923a 100644 --- a/flashinfer/cute_dsl/mla_decode_fp16.py +++ b/flashinfer/cute_dsl/mla_decode_fp16.py @@ -1430,6 +1430,14 @@ def get_split_kv( split_wave_aware = ceil_div(max_splits, k_waves) max_split_kv = 32 return min(split_wave_aware, max_split_kv) + + @staticmethod + def get_split_kv_simplified( + B: int, S: int, max_active_blocks: int + ) -> int: + blocks_per_batch = max(1, max_active_blocks // B // (S * 2)) + max_split_kv = 32 + return min(blocks_per_batch, max_split_kv) @cute.jit def get_k_tile_count( diff --git a/flashinfer/cute_dsl/mla_decode_fp8.py b/flashinfer/cute_dsl/mla_decode_fp8.py index 6ad55d99dd..28548c57b8 100644 --- a/flashinfer/cute_dsl/mla_decode_fp8.py +++ b/flashinfer/cute_dsl/mla_decode_fp8.py @@ -1494,6 +1494,14 @@ def get_split_kv( split_wave_aware = ceil_div(max_splits, k_waves) max_split_kv = 32 return min(split_wave_aware, max_split_kv) + + @staticmethod + def get_split_kv_simplified( + B: int, S: int, max_active_blocks: int + ) -> int: + blocks_per_batch = max(1, max_active_blocks // B // (S * 2)) + max_split_kv = 32 + return min(blocks_per_batch, max_split_kv) @cute.jit def get_k_tile_count( From 020fea5810ef1f90cc2d5b74e75f19a2a12c3353 Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Tue, 10 Mar 2026 21:07:37 -0700 Subject: [PATCH 10/31] feat: Add BFloat16 support to CuTe DSL MLA decode kernel - Add torch_to_cutlass_dtype() in utils.py for torch.dtype -> cutlass dtype conversion - Extend mla_decode_fp16.py can_implement() to accept BFloat16 - Refactor mla_decode.py to support float16/bfloat16/float8_e4m3fn via dtype-aware dispatch - Add BFloat16 parametrization to test_cute_dsl_mla_decode_fp16 test - Add backend parameter to bench_trtllm_gen_mla.py benchmark - Remove unused bench_cute_dsl_mla_host_overhead.py Co-Authored-By: Claude Opus 4.6 --- benchmarks/bench_trtllm_gen_mla.py | 47 +++- flashinfer/cute_dsl/__init__.py | 2 + flashinfer/cute_dsl/mla_decode.py | 84 +++++-- flashinfer/cute_dsl/mla_decode_fp16.py | 4 +- flashinfer/cute_dsl/utils.py | 14 ++ .../bench_cute_dsl_mla_host_overhead.py | 229 ------------------ tests/attention/test_cute_dsl_mla_decode.py | 23 +- 7 files changed, 135 insertions(+), 268 deletions(-) delete mode 100644 tests/attention/bench_cute_dsl_mla_host_overhead.py diff --git a/benchmarks/bench_trtllm_gen_mla.py b/benchmarks/bench_trtllm_gen_mla.py index 7f09050fe1..8846f330ab 100644 --- a/benchmarks/bench_trtllm_gen_mla.py +++ b/benchmarks/bench_trtllm_gen_mla.py @@ -10,7 +10,7 @@ kv_lora_rank = 512 -def bench_trtllm_mla(batch_size, q_len_per_request, seq_len, page_size, dtype): +def bench_trtllm_mla(batch_size, q_len_per_request, seq_len, page_size, dtype, backend="auto"): torch.manual_seed(42) device = "cuda:0" @@ -81,6 +81,7 @@ def bench_trtllm_mla(batch_size, q_len_per_request, seq_len, page_size, dtype): max_seq_len=max_seq_len, bmm1_scale=1.0 / ((128 + 64) ** 0.5), bmm2_scale=1.0, + backend=backend, ) # benchmark measurements = bench_gpu_time( @@ -96,6 +97,7 @@ def bench_trtllm_mla(batch_size, q_len_per_request, seq_len, page_size, dtype): max_seq_len=max_seq_len, bmm1_scale=1.0 / ((128 + 64) ** 0.5), bmm2_scale=1.0, + backend=backend, ), dry_run_iters=5, repeat_iters=30, @@ -126,7 +128,7 @@ def bench_trtllm_mla(batch_size, q_len_per_request, seq_len, page_size, dtype): * q_len_per_request ) print( - f"batch_size={batch_size}, q_len_per_request={q_len_per_request}, seq_len={seq_len}, num_q_heads={num_q_heads}, qk_nope_head_dim={qk_nope_head_dim}, qk_rope_head_dim={qk_rope_head_dim}, kv_lora_rank={kv_lora_rank}, page_size={page_size}" + f"backend={backend}, batch_size={batch_size}, q_len_per_request={q_len_per_request}, seq_len={seq_len}, num_q_heads={num_q_heads}, qk_nope_head_dim={qk_nope_head_dim}, qk_rope_head_dim={qk_rope_head_dim}, kv_lora_rank={kv_lora_rank}, page_size={page_size}" ) print(f"execution time: {ms:.4f} ms") print(f"memory bandwidth: {total_mem_bytes / ms / 1e6:.2f} GB/s") @@ -134,11 +136,44 @@ def bench_trtllm_mla(batch_size, q_len_per_request, seq_len, page_size, dtype): if __name__ == "__main__": - for dtype in [torch.bfloat16, torch.float8_e4m3fn]: + import argparse + + parser = argparse.ArgumentParser(description="Benchmark trtllm MLA decode") + parser.add_argument( + "--backend", + type=str, + default="auto", + help="Backend to use (auto, trtllm-gen, cute-dsl)", + ) + args = parser.parse_args() + + # cute-dsl only supports float16 and float8_e4m3fn + if args.backend == "cute-dsl": + dtypes = [torch.float16, torch.float8_e4m3fn] + else: + dtypes = [torch.bfloat16, torch.float8_e4m3fn] + + for dtype in dtypes: for page_size in [32, 64]: for batch_size in [1, 2, 4, 16, 32, 64, 128, 256, 512, 768, 1024]: for seq_len in [1024, 4096, 8192]: for q_len_per_request in [1, 2, 4, 8, 16]: - bench_trtllm_mla( - batch_size, q_len_per_request, seq_len, page_size, dtype - ) + try: + bench_trtllm_mla( + batch_size, + q_len_per_request, + seq_len, + page_size, + dtype, + backend=args.backend, + ) + except ValueError as e: + print(f"SKIPPED: {e}") + print() + except Exception as e: + print( + f"ERROR: batch_size={batch_size}, q_len={q_len_per_request}, " + f"seq_len={seq_len}, page_size={page_size}, dtype={dtype}, " + f"backend={args.backend}: {type(e).__name__}: {e}" + ) + print() diff --git a/flashinfer/cute_dsl/__init__.py b/flashinfer/cute_dsl/__init__.py index 6b9b31a2af..c667ddcef5 100644 --- a/flashinfer/cute_dsl/__init__.py +++ b/flashinfer/cute_dsl/__init__.py @@ -29,6 +29,7 @@ is_cute_dsl_available, make_ptr, get_cutlass_dtype, + torch_to_cutlass_dtype, get_num_sm, convert_sf_to_mma_layout, convert_sf_from_mma_layout, @@ -60,6 +61,7 @@ "is_cute_dsl_available", "make_ptr", "get_cutlass_dtype", + "torch_to_cutlass_dtype", "get_num_sm", # Scale factor layout conversion utilities "convert_sf_to_mma_layout", diff --git a/flashinfer/cute_dsl/mla_decode.py b/flashinfer/cute_dsl/mla_decode.py index b9b7b4d007..f0f42cbd4c 100644 --- a/flashinfer/cute_dsl/mla_decode.py +++ b/flashinfer/cute_dsl/mla_decode.py @@ -16,7 +16,7 @@ CuTe DSL MLA Decode Kernel Integration ======================================= -Wraps NVIDIA's CuTe DSL MLA decode kernels (FP16/FP8) for Blackwell SM100 +Wraps NVIDIA's CuTe DSL MLA decode kernels (FP16/BF16/FP8) for Blackwell SM100 and exposes them via a PyTorch API compatible with FlashInfer's MLA backend. """ @@ -30,7 +30,7 @@ from .mla_decode_fp16 import BlackwellMultiHeadLatentAttentionForwardFP16 from .mla_decode_fp8 import BlackwellMultiHeadLatentAttentionForwardFP8 -from .utils import get_num_sm +from .utils import get_num_sm, torch_to_cutlass_dtype # Default kernel configuration — matches DeepSeek-V2/V3 MLA dimensions @@ -41,6 +41,8 @@ _MAX_ACTIVE_CLUSTERS = 2 _SKIP_CORRECTION_THRESHOLD = 0.0 +_SUPPORTED_DTYPES = {torch.float16, torch.bfloat16, torch.float8_e4m3fn} + @functools.cache def _get_split_kv_and_workspace_size( @@ -61,7 +63,7 @@ def _get_split_kv_and_workspace_size( @functools.cache def _get_compiled_mla_kernel( - is_fp8: bool, + torch_dtype: torch.dtype, page_size: int, num_heads: int, seq_len_q: int, @@ -77,12 +79,39 @@ def _get_compiled_mla_kernel( All scalar arguments must be pre-wrapped as Int32/Float32. """ + is_fp8 = torch_dtype == torch.float8_e4m3fn KernelClass = ( BlackwellMultiHeadLatentAttentionForwardFP8 if is_fp8 else BlackwellMultiHeadLatentAttentionForwardFP16 ) + cutlass_dtype = torch_to_cutlass_dtype(torch_dtype) + if not KernelClass.can_implement( + 1, # B (runtime, use placeholder) + seq_len_q, + 1, # K (runtime, use placeholder) + num_heads, + _LATENT_DIM, + _ROPE_DIM, + cutlass_dtype, + cutlass_dtype, + cutlass.Float32, + cutlass.Float32, + _MMA_QK_TILER_MN, + _MMA_PV_TILER_MN, + 1, # split_kv (runtime, use 1 to pass the H<128 check) + is_persistent, + is_var_seq, + is_var_split_kv, + page_size, + ): + raise ValueError( + f"cute_dsl_mla_decode: unsupported configuration " + f"(q_len={seq_len_q}, num_heads={num_heads}, page_size={page_size}, " + f"dtype={torch_dtype})" + ) + kernel_obj = KernelClass( acc_dtype=cutlass.Float32, lse_dtype=cutlass.Float32, @@ -96,8 +125,6 @@ def _get_compiled_mla_kernel( is_var_split_kv=is_var_split_kv, ) - cutlass_dtype = cutlass.Float8E4M3FN if is_fp8 else cutlass.Float16 - # All dimensions as sym_int — this matches the original kernel's use of # mark_compact_shape_dynamic, which makes ALL shapes dynamic CuTe Integers. # Static Python ints would cause cute.assume() to fail with AttributeError @@ -168,9 +195,9 @@ def _get_compiled_mla_kernel( stride_order=(2, 1, 0), assumed_align=128, ) - # workspace: 1-D + # workspace: 1-D (int8 to match typical torch workspace buffers) workspace_fake = cute.runtime.make_fake_compact_tensor( - cutlass.Uint8, + cutlass.Int8, (sym_workspace_size,), assumed_align=128, ) @@ -214,9 +241,6 @@ def _get_compiled_mla_kernel( return compiled_kernel -# TODO: need to tell users the max size of the workspace in the doc, so that they can allocate the workspace_buffer. -# TODO: how to set split_kv, is_persistent, is_var_seq, is_var_split_kv? -# TODO: check if page_size setup is right. # TODO: query[..., :kv_lora_rank], do we need to remove such kind of slice and move the logic to call routine in the kernel file. def cute_dsl_mla_decode( query: torch.Tensor, @@ -240,7 +264,13 @@ def cute_dsl_mla_decode( kv_cache : torch.Tensor [num_pages, page_size, D_ckv + D_kpe] (3D) or [num_pages, 1, page_size, D_ckv + D_kpe] (4D) workspace_buffer : torch.Tensor - Pre-allocated workspace buffer. + Pre-allocated workspace buffer (uint8). Required size depends on batch size + and split_kv (auto-computed from B, q_len, and number of SMs): + + - Formula: ``B * H * q_len * split_kv * (kv_lora_rank + 1) * 4`` bytes + (0 when split_kv == 1, which happens when B >= num_SMs / 2) + - Typical max: ~18 MB on a 148-SM GPU (e.g. B=4..8, H=128, D=512) + - Safe default: 128 MB covers all realistic configurations kv_lora_rank : int Latent dimension (e.g. 512). qk_rope_head_dim : int @@ -263,10 +293,9 @@ def cute_dsl_mla_decode( torch.Tensor Output tensor [B, H, kv_lora_rank]. """ - assert query.dtype in ( - torch.float16, - torch.float8_e4m3fn, - ), f"cute_dsl_mla_decode only supports float16 and float8_e4m3fn, got {query.dtype}" + assert query.dtype in _SUPPORTED_DTYPES, ( + f"cute_dsl_mla_decode only supports {_SUPPORTED_DTYPES}, got {query.dtype}" + ) assert kv_cache.dtype == query.dtype, ( f"kv_cache dtype {kv_cache.dtype} must match query dtype {query.dtype}" ) @@ -275,7 +304,8 @@ def cute_dsl_mla_decode( assert kv_lora_rank == _LATENT_DIM assert qk_rope_head_dim == _ROPE_DIM - is_fp8 = query.dtype == torch.float8_e4m3fn + q_dtype = query.dtype + is_fp8 = q_dtype == torch.float8_e4m3fn # Handle 3D vs 4D kv_cache: normalize to 3D [num_pages, page_size, D_total] if kv_cache.dim() == 4: @@ -292,16 +322,30 @@ def cute_dsl_mla_decode( c_latent_k = kv_cache[:, :, :kv_lora_rank] c_rope_k = kv_cache[:, :, kv_lora_rank:] - # Page table: [B, max_pages] — passed directly, kernel reinterprets. + # Page table: [B, max_pages]: passed directly, kernel reinterprets. page_table_k = block_tables + # Runtime validation (int comparisons only, negligible overhead) + if max_seq_len <= 0: + raise ValueError(f"max_seq_len must be > 0, got {max_seq_len}") + if H < 128 and H != 1: + raise ValueError( + f"cute_dsl_mla_decode requires num_heads == 128 (or 1), got {H}" + ) + # Cached split_kv and workspace_size computation max_active_blocks = get_num_sm(query.device) split_kv, workspace_size = _get_split_kv_and_workspace_size( B, q_len, H, max_active_blocks ) - # Prepare workspace — slice of contiguous 1D buffer is already contiguous + if H < 128 and split_kv != 1: + raise ValueError( + f"cute_dsl_mla_decode: num_heads={H} < 128 requires split_kv==1, " + f"got split_kv={split_kv}" + ) + + # Prepare workspace: slice of contiguous 1D buffer is already contiguous assert workspace_buffer.numel() >= workspace_size, ( f"workspace_buffer too small: {workspace_buffer.numel()} bytes, " f"need {workspace_size} bytes" @@ -310,7 +354,7 @@ def cute_dsl_mla_decode( # Output buffer: contiguous [B, q_len, H, D]. # Kernel reinterprets to [H, D, q_len, B] internally via zero-cost make_tensor. - out_dtype = torch.float8_e4m3fn if is_fp8 else torch.float16 + out_dtype = q_dtype if out is not None: if q_len == 1: o_k = out.unsqueeze(1) # [B, H, D] → [B, 1, H, D] @@ -332,7 +376,7 @@ def cute_dsl_mla_decode( # Get compiled kernel (cached after first compile) compiled_kernel = _get_compiled_mla_kernel( - is_fp8=is_fp8, + torch_dtype=q_dtype, page_size=page_size, num_heads=H, seq_len_q=q_len, diff --git a/flashinfer/cute_dsl/mla_decode_fp16.py b/flashinfer/cute_dsl/mla_decode_fp16.py index 905654923a..e3fa3a833c 100644 --- a/flashinfer/cute_dsl/mla_decode_fp16.py +++ b/flashinfer/cute_dsl/mla_decode_fp16.py @@ -3490,9 +3490,9 @@ def can_implement( """ if L != 512 or R != 64: return False - if in_dtype not in [cutlass.Float16]: + if in_dtype not in [cutlass.Float16, cutlass.BFloat16]: return False - if out_dtype not in [cutlass.Float16]: + if out_dtype not in [cutlass.Float16, cutlass.BFloat16]: return False if acc_dtype != cutlass.Float32 or lse_dtype != cutlass.Float32: return False diff --git a/flashinfer/cute_dsl/utils.py b/flashinfer/cute_dsl/utils.py index b61f83ba57..653bbb5c40 100644 --- a/flashinfer/cute_dsl/utils.py +++ b/flashinfer/cute_dsl/utils.py @@ -51,6 +51,20 @@ def get_cutlass_dtype(dtype: str) -> cutlass.dtype: return dtype_map[dtype] +def torch_to_cutlass_dtype(dtype: torch.dtype) -> cutlass.dtype: + """Return the corresponding cutlass dtype for the given torch.dtype.""" + dtype_map = { + torch.float16: cutlass.Float16, + torch.bfloat16: cutlass.BFloat16, + torch.float32: cutlass.Float32, + torch.float8_e5m2: cutlass.Float8E5M2, + torch.float8_e4m3fn: cutlass.Float8E4M3FN, + } + if dtype not in dtype_map: + raise TypeError(f"{dtype} is not supported by cutlass") + return dtype_map[dtype] + + def cutlass_to_torch_dtype(cutlass_dtype): """ Return the corresponding torch.dtype per the given DSL type diff --git a/tests/attention/bench_cute_dsl_mla_host_overhead.py b/tests/attention/bench_cute_dsl_mla_host_overhead.py deleted file mode 100644 index 79e74de236..0000000000 --- a/tests/attention/bench_cute_dsl_mla_host_overhead.py +++ /dev/null @@ -1,229 +0,0 @@ -#!/usr/bin/env python3 -"""Benchmark host overhead of cute_dsl_mla_decode. - -Measures Python-side overhead by timing many iterations without GPU sync -between them — the GPU queue never drains so we're measuring purely -host-side work (tensor reshaping, TVM-FFI dispatch, etc.). -""" - -import time - -import torch - -from flashinfer.cute_dsl.mla_decode import cute_dsl_mla_decode - - -def bench_host_overhead( - batch_size: int = 4, - seq_len_k: int = 2048, - page_size: int = 128, - num_iters: int = 1000, - warmup_iters: int = 50, -): - device = torch.device("cuda") - num_heads = 128 - latent_dim = 512 - rope_dim = 64 - q_len = 1 - D_qk = latent_dim + rope_dim - softmax_scale = 1.0 / (latent_dim**0.5) - output_scale = 1.0 - - query = torch.randn(batch_size, q_len, num_heads, D_qk, dtype=torch.float16, device=device) - num_pages_per_batch = (seq_len_k + page_size - 1) // page_size - total_pages = num_pages_per_batch * batch_size + 10 - kv_cache = torch.randn(total_pages, page_size, D_qk, dtype=torch.float16, device=device) - - block_tables = torch.zeros(batch_size, num_pages_per_batch, dtype=torch.int32, device=device) - for b in range(batch_size): - for p in range(num_pages_per_batch): - block_tables[b, p] = b * num_pages_per_batch + p - - seq_lens = torch.full((batch_size,), seq_len_k, dtype=torch.int32, device=device) - workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device=device) - - # Warmup — includes compilation on first call - print("Warming up...") - for _ in range(warmup_iters): - cute_dsl_mla_decode( - query=query, kv_cache=kv_cache, workspace_buffer=workspace_buffer, - kv_lora_rank=latent_dim, qk_rope_head_dim=rope_dim, - block_tables=block_tables, seq_lens=seq_lens, max_seq_len=seq_len_k, - softmax_scale=softmax_scale, output_scale=output_scale, - ) - torch.cuda.synchronize() - - # Benchmark: no sync between iterations → measures host overhead only - print(f"Benchmarking {num_iters} iterations (no inter-iteration sync)...") - torch.cuda.synchronize() - t0 = time.perf_counter() - for _ in range(num_iters): - cute_dsl_mla_decode( - query=query, kv_cache=kv_cache, workspace_buffer=workspace_buffer, - kv_lora_rank=latent_dim, qk_rope_head_dim=rope_dim, - block_tables=block_tables, seq_lens=seq_lens, max_seq_len=seq_len_k, - softmax_scale=softmax_scale, output_scale=output_scale, - ) - torch.cuda.synchronize() - t1 = time.perf_counter() - - total_us = (t1 - t0) * 1e6 - per_call_us = total_us / num_iters - print(f"Total: {total_us:.0f} us for {num_iters} calls") - print(f"Per call: {per_call_us:.1f} us") - - # Also measure with line-level profiling of the key sections - print("\n--- Profiling individual sections ---") - profile_sections(query, kv_cache, workspace_buffer, latent_dim, rope_dim, - block_tables, seq_lens, seq_len_k, softmax_scale, output_scale, - num_iters=num_iters) - - return per_call_us - - -def profile_sections(query, kv_cache, workspace_buffer, kv_lora_rank, qk_rope_head_dim, - block_tables, seq_lens, max_seq_len, softmax_scale, output_scale, - num_iters=1000): - """Profile individual sections of cute_dsl_mla_decode to find hotspots.""" - from flashinfer.cute_dsl.mla_decode import ( - _get_compiled_mla_kernel, - _get_split_kv_and_workspace_size, - _LATENT_DIM, _ROPE_DIM, _MMA_QK_TILER_MN, _MAX_ACTIVE_CLUSTERS, - BlackwellMultiHeadLatentAttentionForwardFP16, - ) - from flashinfer.cute_dsl.utils import get_num_sm - from cutlass import Float32, Int32 - import cutlass - - B, q_len, H, D_qk = query.shape - page_size = kv_cache.shape[1] - is_fp8 = query.dtype == torch.float8_e4m3fn - device = query.device - - timings = {} - - def measure(name, fn): - torch.cuda.synchronize() - t0 = time.perf_counter() - for _ in range(num_iters): - fn() - torch.cuda.synchronize() - elapsed_us = (time.perf_counter() - t0) * 1e6 / num_iters - timings[name] = elapsed_us - - # 1. Query split + permute - def query_reshape(): - q_nope = query[..., :kv_lora_rank] - q_rope = query[..., kv_lora_rank:] - q_latent_k = q_nope.permute(2, 3, 1, 0) - q_rope_k = q_rope.permute(2, 3, 1, 0) - return q_latent_k, q_rope_k - measure("query_split+permute", query_reshape) - - # 2. KV cache split + permute - def kv_reshape(): - c_latent_k = kv_cache[:, :, :kv_lora_rank].permute(1, 2, 0) - c_rope_k = kv_cache[:, :, kv_lora_rank:].permute(1, 2, 0) - return c_latent_k, c_rope_k - measure("kv_split+permute", kv_reshape) - - # 3. Page table transpose - def page_table_transpose(): - return block_tables.t().contiguous().to(torch.int32) - measure("page_table_transpose", page_table_transpose) - - # 4. split_kv + workspace_size computation (cached) - max_active_blocks = get_num_sm(device) - def compute_split(): - return _get_split_kv_and_workspace_size( - B, q_len, max_seq_len, H, max_active_blocks - ) - measure("compute_split_kv+workspace(cached)", compute_split) - - # 5. Workspace slice - split_kv, workspace_size = compute_split() - def workspace_slice(): - return workspace_buffer[:max(workspace_size, 1)] - measure("workspace_slice(no .contiguous())", workspace_slice) - - # 6. Output + LSE allocation - out_dtype = torch.float8_e4m3fn if is_fp8 else torch.float16 - def alloc_output(): - o_k = torch.empty((B, H, q_len, _LATENT_DIM), dtype=out_dtype, device=device).permute(1, 3, 2, 0) - lse_k = torch.empty((B, q_len, H), dtype=torch.float32, device=device).permute(2, 1, 0) - return o_k, lse_k - measure("alloc_output+lse", alloc_output) - - # 7. cache_seqs + block_split_kvs creation - def create_aux_tensors(): - if seq_lens.dtype == torch.int32 and seq_lens.is_contiguous(): - cache_seqs = seq_lens - else: - cache_seqs = seq_lens.to(torch.int32).contiguous() - block_split_kvs = torch.full((B,), split_kv, dtype=torch.int32, device=device) - return cache_seqs, block_split_kvs - measure("create_aux_tensors(optimized)", create_aux_tensors) - - # 8. _get_compiled_mla_kernel (should be cached) - def get_kernel(): - return _get_compiled_mla_kernel( - is_fp8=is_fp8, page_size=page_size, num_heads=H, seq_len_q=q_len, - is_persistent=True, is_var_seq=True, is_var_split_kv=True, - ) - measure("get_compiled_kernel(cached)", get_kernel) - - # 9. Kernel call only (prepare everything, measure just the call) - compiled_kernel = get_kernel() - q_latent_k = query[..., :kv_lora_rank].permute(2, 3, 1, 0) - q_rope_k = query[..., kv_lora_rank:].permute(2, 3, 1, 0) - c_latent_k = kv_cache[:, :, :kv_lora_rank].permute(1, 2, 0) - c_rope_k = kv_cache[:, :, kv_lora_rank:].permute(1, 2, 0) - page_table_k = block_tables.t().contiguous() - o_k = torch.empty((B, H, q_len, _LATENT_DIM), dtype=out_dtype, device=device).permute(1, 3, 2, 0) - lse_k = torch.empty((B, q_len, H), dtype=torch.float32, device=device).permute(2, 1, 0) - ws = workspace_buffer[:max(workspace_size, 1)] - cache_seqs = seq_lens.to(torch.int32).contiguous() - block_split_kvs_t = torch.full((B,), split_kv, dtype=torch.int32, device=device) - split_kv_scalar = Int32(split_kv) - softmax_scale_scalar = Float32(softmax_scale) - output_scale_scalar = Float32(output_scale) - - def kernel_call_pre_cached(): - compiled_kernel( - q_latent_k, q_rope_k, c_latent_k, c_rope_k, page_table_k, - o_k, lse_k, ws, - split_kv_scalar, cache_seqs, block_split_kvs_t, - softmax_scale_scalar, output_scale_scalar, - ) - measure("kernel_call(pre-cached scalars)", kernel_call_pre_cached) - - def kernel_call_per_call(): - compiled_kernel( - q_latent_k, q_rope_k, c_latent_k, c_rope_k, page_table_k, - o_k, lse_k, ws, - Int32(split_kv), cache_seqs, block_split_kvs_t, - Float32(softmax_scale), Float32(output_scale), - ) - measure("kernel_call(per-call scalars)", kernel_call_per_call) - - # 10. Output reshape - def output_reshape(): - result = o_k.permute(3, 2, 0, 1).contiguous() - if q_len == 1: - result = result.squeeze(1) - return result - measure("output_reshape", output_reshape) - - # Print results - print(f"{'Section':<35} {'us/call':>10}") - print("-" * 47) - total = 0.0 - for name, us in timings.items(): - print(f" {name:<33} {us:>10.1f}") - total += us - print("-" * 47) - print(f" {'SUM':<33} {total:>10.1f}") - - -if __name__ == "__main__": - bench_host_overhead() diff --git a/tests/attention/test_cute_dsl_mla_decode.py b/tests/attention/test_cute_dsl_mla_decode.py index 8e9018e8a9..0de6794a3d 100644 --- a/tests/attention/test_cute_dsl_mla_decode.py +++ b/tests/attention/test_cute_dsl_mla_decode.py @@ -98,8 +98,9 @@ def torch_reference_mla( @pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("seq_len_k", [128, 512, 2048]) @pytest.mark.parametrize("page_size", [128]) -def test_cute_dsl_mla_decode_fp16(batch_size, seq_len_k, page_size): - """Test FP16 MLA decode kernel.""" +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_cute_dsl_mla_decode_fp16(batch_size, seq_len_k, page_size, dtype): + """Test FP16/BF16 MLA decode kernel.""" skip_if_unsupported() from flashinfer.cute_dsl.mla_decode import cute_dsl_mla_decode @@ -117,7 +118,7 @@ def test_cute_dsl_mla_decode_fp16(batch_size, seq_len_k, page_size): # Allocate query: [B, q_len, H, D_qk] D_qk = latent_dim + rope_dim query = torch.randn( - batch_size, q_len, num_heads, D_qk, dtype=torch.float16, device=device + batch_size, q_len, num_heads, D_qk, dtype=dtype, device=device ) # Allocate paged KV cache @@ -127,7 +128,7 @@ def test_cute_dsl_mla_decode_fp16(batch_size, seq_len_k, page_size): total_pages, page_size, latent_dim + rope_dim, - dtype=torch.float16, + dtype=dtype, device=device, ) @@ -143,7 +144,7 @@ def test_cute_dsl_mla_decode_fp16(batch_size, seq_len_k, page_size): seq_lens = torch.full((batch_size,), seq_len_k, dtype=torch.int32, device=device) # Workspace - workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device=device) + workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device=device) # Run kernel out = cute_dsl_mla_decode( @@ -181,10 +182,10 @@ def test_cute_dsl_mla_decode_fp16(batch_size, seq_len_k, page_size): if q_len == 1: ref_out = ref_out.squeeze(1) - ref_out_fp16 = ref_out.to(torch.float16) + ref_out_cast = ref_out.to(dtype) - # Check with tolerance appropriate for FP16 - torch.testing.assert_close(out, ref_out_fp16, atol=1e-2, rtol=1e-2) + # Check with tolerance appropriate for FP16/BF16 + torch.testing.assert_close(out, ref_out_cast, atol=1e-2, rtol=1e-2) @pytest.mark.parametrize("batch_size", [1, 4]) @@ -229,7 +230,7 @@ def test_cute_dsl_mla_decode_variable_seq_len(batch_size, seq_len_k, page_size=1 for p in range(max_pages_per_batch): block_tables[b, p] = b * max_pages_per_batch + p - workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device=device) + workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device=device) out = cute_dsl_mla_decode( query=query, @@ -305,7 +306,7 @@ def test_cute_dsl_mla_decode_via_api(batch_size, seq_len_k, page_size=128): block_tables[b, p] = b * num_pages_per_batch + p seq_lens = torch.full((batch_size,), seq_len_k, dtype=torch.int32, device=device) - workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device=device) + workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device=device) out = trtllm_batch_decode_with_kv_cache_mla( query=query, @@ -368,7 +369,7 @@ def test_cute_dsl_mla_decode_fp8(batch_size, seq_len_k, page_size): block_tables[b, p] = b * num_pages_per_batch + p seq_lens = torch.full((batch_size,), seq_len_k, dtype=torch.int32, device=device) - workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device=device) + workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device=device) out = cute_dsl_mla_decode( query=query, From 842b624b35a6f6059b47ddc8871a0c05589fd02c Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Tue, 10 Mar 2026 22:44:26 -0700 Subject: [PATCH 11/31] minor. --- benchmarks/bench_trtllm_gen_mla.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/benchmarks/bench_trtllm_gen_mla.py b/benchmarks/bench_trtllm_gen_mla.py index 8846f330ab..ef6210e8d9 100644 --- a/benchmarks/bench_trtllm_gen_mla.py +++ b/benchmarks/bench_trtllm_gen_mla.py @@ -134,7 +134,6 @@ def bench_trtllm_mla(batch_size, q_len_per_request, seq_len, page_size, dtype, b print(f"memory bandwidth: {total_mem_bytes / ms / 1e6:.2f} GB/s") print(f"FLOPs: {flops / ms / 1e9:.2f} TFLOPs/s") - if __name__ == "__main__": import argparse @@ -146,18 +145,17 @@ def bench_trtllm_mla(batch_size, q_len_per_request, seq_len, page_size, dtype, b help="Backend to use (auto, trtllm-gen, cute-dsl)", ) args = parser.parse_args() - - # cute-dsl only supports float16 and float8_e4m3fn + if args.backend == "cute-dsl": - dtypes = [torch.float16, torch.float8_e4m3fn] + q_lens = [1, 2, 4] else: - dtypes = [torch.bfloat16, torch.float8_e4m3fn] + q_lens = [1, 2, 4, 8, 16] - for dtype in dtypes: + for dtype in [torch.bfloat16, torch.float8_e4m3fn]: for page_size in [32, 64]: for batch_size in [1, 2, 4, 16, 32, 64, 128, 256, 512, 768, 1024]: for seq_len in [1024, 4096, 8192]: - for q_len_per_request in [1, 2, 4, 8, 16]: + for q_len_per_request in q_lens: try: bench_trtllm_mla( batch_size, From deee8197b90b1ee10cffc8207e4253dd006bb0d4 Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Tue, 10 Mar 2026 22:49:50 -0700 Subject: [PATCH 12/31] format --- benchmarks/bench_trtllm_gen_mla.py | 7 +++++-- flashinfer/cute_dsl/mla_decode.py | 1 - flashinfer/cute_dsl/mla_decode_fp16.py | 6 ++---- flashinfer/cute_dsl/mla_decode_fp8.py | 6 ++---- tests/attention/test_cute_dsl_mla_decode.py | 4 +--- 5 files changed, 10 insertions(+), 14 deletions(-) diff --git a/benchmarks/bench_trtllm_gen_mla.py b/benchmarks/bench_trtllm_gen_mla.py index ef6210e8d9..a739ccc21b 100644 --- a/benchmarks/bench_trtllm_gen_mla.py +++ b/benchmarks/bench_trtllm_gen_mla.py @@ -10,7 +10,9 @@ kv_lora_rank = 512 -def bench_trtllm_mla(batch_size, q_len_per_request, seq_len, page_size, dtype, backend="auto"): +def bench_trtllm_mla( + batch_size, q_len_per_request, seq_len, page_size, dtype, backend="auto" +): torch.manual_seed(42) device = "cuda:0" @@ -134,6 +136,7 @@ def bench_trtllm_mla(batch_size, q_len_per_request, seq_len, page_size, dtype, b print(f"memory bandwidth: {total_mem_bytes / ms / 1e6:.2f} GB/s") print(f"FLOPs: {flops / ms / 1e9:.2f} TFLOPs/s") + if __name__ == "__main__": import argparse @@ -145,7 +148,7 @@ def bench_trtllm_mla(batch_size, q_len_per_request, seq_len, page_size, dtype, b help="Backend to use (auto, trtllm-gen, cute-dsl)", ) args = parser.parse_args() - + if args.backend == "cute-dsl": q_lens = [1, 2, 4] else: diff --git a/flashinfer/cute_dsl/mla_decode.py b/flashinfer/cute_dsl/mla_decode.py index f0f42cbd4c..64d738a26e 100644 --- a/flashinfer/cute_dsl/mla_decode.py +++ b/flashinfer/cute_dsl/mla_decode.py @@ -305,7 +305,6 @@ def cute_dsl_mla_decode( assert qk_rope_head_dim == _ROPE_DIM q_dtype = query.dtype - is_fp8 = q_dtype == torch.float8_e4m3fn # Handle 3D vs 4D kv_cache: normalize to 3D [num_pages, page_size, D_total] if kv_cache.dim() == 4: diff --git a/flashinfer/cute_dsl/mla_decode_fp16.py b/flashinfer/cute_dsl/mla_decode_fp16.py index e3fa3a833c..f4b266706f 100644 --- a/flashinfer/cute_dsl/mla_decode_fp16.py +++ b/flashinfer/cute_dsl/mla_decode_fp16.py @@ -1430,11 +1430,9 @@ def get_split_kv( split_wave_aware = ceil_div(max_splits, k_waves) max_split_kv = 32 return min(split_wave_aware, max_split_kv) - + @staticmethod - def get_split_kv_simplified( - B: int, S: int, max_active_blocks: int - ) -> int: + def get_split_kv_simplified(B: int, S: int, max_active_blocks: int) -> int: blocks_per_batch = max(1, max_active_blocks // B // (S * 2)) max_split_kv = 32 return min(blocks_per_batch, max_split_kv) diff --git a/flashinfer/cute_dsl/mla_decode_fp8.py b/flashinfer/cute_dsl/mla_decode_fp8.py index 28548c57b8..0e7e1a6e19 100644 --- a/flashinfer/cute_dsl/mla_decode_fp8.py +++ b/flashinfer/cute_dsl/mla_decode_fp8.py @@ -1494,11 +1494,9 @@ def get_split_kv( split_wave_aware = ceil_div(max_splits, k_waves) max_split_kv = 32 return min(split_wave_aware, max_split_kv) - + @staticmethod - def get_split_kv_simplified( - B: int, S: int, max_active_blocks: int - ) -> int: + def get_split_kv_simplified(B: int, S: int, max_active_blocks: int) -> int: blocks_per_batch = max(1, max_active_blocks // B // (S * 2)) max_split_kv = 32 return min(blocks_per_batch, max_split_kv) diff --git a/tests/attention/test_cute_dsl_mla_decode.py b/tests/attention/test_cute_dsl_mla_decode.py index 0de6794a3d..cf5c004a3f 100644 --- a/tests/attention/test_cute_dsl_mla_decode.py +++ b/tests/attention/test_cute_dsl_mla_decode.py @@ -117,9 +117,7 @@ def test_cute_dsl_mla_decode_fp16(batch_size, seq_len_k, page_size, dtype): # Allocate query: [B, q_len, H, D_qk] D_qk = latent_dim + rope_dim - query = torch.randn( - batch_size, q_len, num_heads, D_qk, dtype=dtype, device=device - ) + query = torch.randn(batch_size, q_len, num_heads, D_qk, dtype=dtype, device=device) # Allocate paged KV cache num_pages_per_batch = (seq_len_k + page_size - 1) // page_size From 5cef49316c422d837b4111d6a7c7fd6757e7f501 Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Tue, 10 Mar 2026 23:08:02 -0700 Subject: [PATCH 13/31] refactor: Replace hardcoded MLA config constants with function parameters Move module-level constants (_LATENT_DIM, _ROPE_DIM, _MMA_QK_TILER_MN, _MMA_PV_TILER_MN, _MAX_ACTIVE_CLUSTERS, _SUPPORTED_DTYPES) into their respective functions. Query max_active_clusters from hardware dynamically instead of hardcoding 2. Pass kv_lora_rank and qk_rope_head_dim as function parameters for flexibility. Co-Authored-By: Claude Opus 4.6 --- flashinfer/cute_dsl/mla_decode.py | 56 ++++++++++++++++--------------- 1 file changed, 29 insertions(+), 27 deletions(-) diff --git a/flashinfer/cute_dsl/mla_decode.py b/flashinfer/cute_dsl/mla_decode.py index 64d738a26e..72ea914c13 100644 --- a/flashinfer/cute_dsl/mla_decode.py +++ b/flashinfer/cute_dsl/mla_decode.py @@ -30,18 +30,7 @@ from .mla_decode_fp16 import BlackwellMultiHeadLatentAttentionForwardFP16 from .mla_decode_fp8 import BlackwellMultiHeadLatentAttentionForwardFP8 -from .utils import get_num_sm, torch_to_cutlass_dtype - - -# Default kernel configuration — matches DeepSeek-V2/V3 MLA dimensions -_LATENT_DIM = 512 -_ROPE_DIM = 64 -_MMA_QK_TILER_MN = (128, 128) -_MMA_PV_TILER_MN = (128, 256) -_MAX_ACTIVE_CLUSTERS = 2 -_SKIP_CORRECTION_THRESHOLD = 0.0 - -_SUPPORTED_DTYPES = {torch.float16, torch.bfloat16, torch.float8_e4m3fn} +from .utils import get_max_active_clusters, get_num_sm, torch_to_cutlass_dtype @functools.cache @@ -49,6 +38,7 @@ def _get_split_kv_and_workspace_size( B: int, q_len: int, H: int, + kv_lora_rank: int, max_active_blocks: int, ) -> Tuple[int, int]: """Cache split_kv and workspace_size since they are deterministic for the same params.""" @@ -56,7 +46,7 @@ def _get_split_kv_and_workspace_size( B, q_len, max_active_blocks ) workspace_size = BlackwellMultiHeadLatentAttentionForwardFP16.get_workspace_size( - H, q_len, _LATENT_DIM, B, split_kv, cutlass.Float32 + H, q_len, kv_lora_rank, B, split_kv, cutlass.Float32 ) return split_kv, workspace_size @@ -67,9 +57,12 @@ def _get_compiled_mla_kernel( page_size: int, num_heads: int, seq_len_q: int, + kv_lora_rank: int, + qk_rope_head_dim: int, is_persistent: bool, is_var_seq: bool, is_var_split_kv: bool, + skip_correction_threshold: float = 0.0, ) -> Callable: """Compile and cache an MLA decode kernel. @@ -79,6 +72,10 @@ def _get_compiled_mla_kernel( All scalar arguments must be pre-wrapped as Int32/Float32. """ + mma_qk_tiler_mn = (128, 128) + mma_pv_tiler_mn = (128, 256) + cluster_shape_mnk = (2, 1, 1) + is_fp8 = torch_dtype == torch.float8_e4m3fn KernelClass = ( BlackwellMultiHeadLatentAttentionForwardFP8 @@ -92,14 +89,14 @@ def _get_compiled_mla_kernel( seq_len_q, 1, # K (runtime, use placeholder) num_heads, - _LATENT_DIM, - _ROPE_DIM, + kv_lora_rank, + qk_rope_head_dim, cutlass_dtype, cutlass_dtype, cutlass.Float32, cutlass.Float32, - _MMA_QK_TILER_MN, - _MMA_PV_TILER_MN, + mma_qk_tiler_mn, + mma_pv_tiler_mn, 1, # split_kv (runtime, use 1 to pass the H<128 check) is_persistent, is_var_seq, @@ -115,11 +112,13 @@ def _get_compiled_mla_kernel( kernel_obj = KernelClass( acc_dtype=cutlass.Float32, lse_dtype=cutlass.Float32, - mma_qk_tiler_mn=_MMA_QK_TILER_MN, - mma_pv_tiler_mn=_MMA_PV_TILER_MN, - max_active_clusters=_MAX_ACTIVE_CLUSTERS, + mma_qk_tiler_mn=mma_qk_tiler_mn, + mma_pv_tiler_mn=mma_pv_tiler_mn, + max_active_clusters=get_max_active_clusters( + cluster_shape_mnk[0] * cluster_shape_mnk[1] + ), page_size=page_size, - skip_correction_threshold=_SKIP_CORRECTION_THRESHOLD, + skip_correction_threshold=skip_correction_threshold, is_persistent=is_persistent, is_var_seq=is_var_seq, is_var_split_kv=is_var_split_kv, @@ -293,16 +292,15 @@ def cute_dsl_mla_decode( torch.Tensor Output tensor [B, H, kv_lora_rank]. """ - assert query.dtype in _SUPPORTED_DTYPES, ( - f"cute_dsl_mla_decode only supports {_SUPPORTED_DTYPES}, got {query.dtype}" + supported_dtypes = {torch.float16, torch.bfloat16, torch.float8_e4m3fn} + assert query.dtype in supported_dtypes, ( + f"cute_dsl_mla_decode only supports {supported_dtypes}, got {query.dtype}" ) assert kv_cache.dtype == query.dtype, ( f"kv_cache dtype {kv_cache.dtype} must match query dtype {query.dtype}" ) B, q_len, H, D_qk = query.shape assert D_qk == kv_lora_rank + qk_rope_head_dim - assert kv_lora_rank == _LATENT_DIM - assert qk_rope_head_dim == _ROPE_DIM q_dtype = query.dtype @@ -335,7 +333,7 @@ def cute_dsl_mla_decode( # Cached split_kv and workspace_size computation max_active_blocks = get_num_sm(query.device) split_kv, workspace_size = _get_split_kv_and_workspace_size( - B, q_len, H, max_active_blocks + B, q_len, H, kv_lora_rank, max_active_blocks ) if H < 128 and split_kv != 1: @@ -361,7 +359,7 @@ def cute_dsl_mla_decode( o_k = out else: o_k = torch.empty( - (B, q_len, H, _LATENT_DIM), dtype=out_dtype, device=query.device + (B, q_len, H, kv_lora_rank), dtype=out_dtype, device=query.device ) # LSE: contiguous [B, q_len, H]. Kernel reinterprets to [H, q_len, B]. @@ -372,6 +370,7 @@ def cute_dsl_mla_decode( is_var_split_kv = False block_split_kvs = None + skip_correction_threshold = 0.0 # Get compiled kernel (cached after first compile) compiled_kernel = _get_compiled_mla_kernel( @@ -379,9 +378,12 @@ def cute_dsl_mla_decode( page_size=page_size, num_heads=H, seq_len_q=q_len, + kv_lora_rank=kv_lora_rank, + qk_rope_head_dim=qk_rope_head_dim, is_persistent=True, is_var_seq=True, is_var_split_kv=is_var_split_kv, + skip_correction_threshold=skip_correction_threshold, ) # Call the kernel From 2ece5a763de07266940d2cefa943459162330074 Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Tue, 10 Mar 2026 23:43:44 -0700 Subject: [PATCH 14/31] refactor: Split can_implement check from kernel compilation to avoid cache fragmentation Separate _check_can_implement (validates num_heads/seq_len_q) from _get_compiled_mla_kernel (does actual compilation). Both are cached with @functools.cache, but the compilation cache key no longer includes num_heads and seq_len_q since they don't affect the compiled kernel (all dimensions are sym_int). Co-Authored-By: Claude Opus 4.6 --- flashinfer/cute_dsl/mla_decode.py | 63 +++++++++++++++++++++++-------- 1 file changed, 48 insertions(+), 15 deletions(-) diff --git a/flashinfer/cute_dsl/mla_decode.py b/flashinfer/cute_dsl/mla_decode.py index 72ea914c13..561b7a92bf 100644 --- a/flashinfer/cute_dsl/mla_decode.py +++ b/flashinfer/cute_dsl/mla_decode.py @@ -52,7 +52,7 @@ def _get_split_kv_and_workspace_size( @functools.cache -def _get_compiled_mla_kernel( +def _check_can_implement( torch_dtype: torch.dtype, page_size: int, num_heads: int, @@ -62,19 +62,10 @@ def _get_compiled_mla_kernel( is_persistent: bool, is_var_seq: bool, is_var_split_kv: bool, - skip_correction_threshold: float = 0.0, -) -> Callable: - """Compile and cache an MLA decode kernel. - - Returns a callable that accepts (q_latent, q_rope, c_latent, c_rope, - page_table, o, lse, workspace, split_kv_scalar, cache_seqs, - block_split_kvs, softmax_scale_scalar, output_scale_scalar). - - All scalar arguments must be pre-wrapped as Int32/Float32. - """ +) -> None: + """Check if the kernel supports the given configuration (cached).""" mma_qk_tiler_mn = (128, 128) mma_pv_tiler_mn = (128, 256) - cluster_shape_mnk = (2, 1, 1) is_fp8 = torch_dtype == torch.float8_e4m3fn KernelClass = ( @@ -82,7 +73,6 @@ def _get_compiled_mla_kernel( if is_fp8 else BlackwellMultiHeadLatentAttentionForwardFP16 ) - cutlass_dtype = torch_to_cutlass_dtype(torch_dtype) if not KernelClass.can_implement( 1, # B (runtime, use placeholder) @@ -109,6 +99,38 @@ def _get_compiled_mla_kernel( f"dtype={torch_dtype})" ) + +@functools.cache +def _get_compiled_mla_kernel( + torch_dtype: torch.dtype, + page_size: int, + kv_lora_rank: int, + qk_rope_head_dim: int, + is_persistent: bool, + is_var_seq: bool, + is_var_split_kv: bool, + skip_correction_threshold: float = 0.0, +) -> Callable: + """Compile and cache an MLA decode kernel. + + Returns a callable that accepts (q_latent, q_rope, c_latent, c_rope, + page_table, o, lse, workspace, split_kv_scalar, cache_seqs, + block_split_kvs, softmax_scale_scalar, output_scale_scalar). + + All scalar arguments must be pre-wrapped as Int32/Float32. + """ + mma_qk_tiler_mn = (128, 128) + mma_pv_tiler_mn = (128, 256) + cluster_shape_mnk = (2, 1, 1) + + is_fp8 = torch_dtype == torch.float8_e4m3fn + KernelClass = ( + BlackwellMultiHeadLatentAttentionForwardFP8 + if is_fp8 + else BlackwellMultiHeadLatentAttentionForwardFP16 + ) + cutlass_dtype = torch_to_cutlass_dtype(torch_dtype) + kernel_obj = KernelClass( acc_dtype=cutlass.Float32, lse_dtype=cutlass.Float32, @@ -372,8 +394,8 @@ def cute_dsl_mla_decode( block_split_kvs = None skip_correction_threshold = 0.0 - # Get compiled kernel (cached after first compile) - compiled_kernel = _get_compiled_mla_kernel( + # Validate configuration (cached, negligible overhead after first call) + _check_can_implement( torch_dtype=q_dtype, page_size=page_size, num_heads=H, @@ -383,6 +405,17 @@ def cute_dsl_mla_decode( is_persistent=True, is_var_seq=True, is_var_split_kv=is_var_split_kv, + ) + + # Get compiled kernel (cached after first compile) + compiled_kernel = _get_compiled_mla_kernel( + torch_dtype=q_dtype, + page_size=page_size, + kv_lora_rank=kv_lora_rank, + qk_rope_head_dim=qk_rope_head_dim, + is_persistent=True, + is_var_seq=True, + is_var_split_kv=is_var_split_kv, skip_correction_threshold=skip_correction_threshold, ) From 98eae77455a169f3e0bb31b852cc81e27883ca90 Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Wed, 11 Mar 2026 00:04:04 -0700 Subject: [PATCH 15/31] fix: Align cute-dsl output shape with trtllm-gen and fix tensor scale handling - Output shape is now [B, q_len, H, D] (no squeeze), matching trtllm-gen - Reject tensor bmm1_scale/bmm2_scale in cute-dsl backend (was silently using log2e-transformed values from the trtllm-gen path) - Add cross-backend test comparing cute-dsl vs trtllm-gen numerically Co-Authored-By: Claude Opus 4.6 --- flashinfer/cute_dsl/mla_decode.py | 13 ++-- flashinfer/mla.py | 19 ++++-- tests/attention/test_cute_dsl_mla_decode.py | 76 ++++++++++++++++++--- 3 files changed, 83 insertions(+), 25 deletions(-) diff --git a/flashinfer/cute_dsl/mla_decode.py b/flashinfer/cute_dsl/mla_decode.py index 561b7a92bf..da72e4b0a4 100644 --- a/flashinfer/cute_dsl/mla_decode.py +++ b/flashinfer/cute_dsl/mla_decode.py @@ -307,12 +307,12 @@ def cute_dsl_mla_decode( output_scale : float Scale factor applied to the output. out : Optional[torch.Tensor] - Pre-allocated output tensor [B, H, kv_lora_rank]. + Pre-allocated output tensor [B, q_len, H, kv_lora_rank]. Returns ------- torch.Tensor - Output tensor [B, H, kv_lora_rank]. + Output tensor [B, q_len, H, kv_lora_rank]. """ supported_dtypes = {torch.float16, torch.bfloat16, torch.float8_e4m3fn} assert query.dtype in supported_dtypes, ( @@ -375,10 +375,7 @@ def cute_dsl_mla_decode( # Kernel reinterprets to [H, D, q_len, B] internally via zero-cost make_tensor. out_dtype = q_dtype if out is not None: - if q_len == 1: - o_k = out.unsqueeze(1) # [B, H, D] → [B, 1, H, D] - else: - o_k = out + o_k = out else: o_k = torch.empty( (B, q_len, H, kv_lora_rank), dtype=out_dtype, device=query.device @@ -440,7 +437,5 @@ def cute_dsl_mla_decode( if out is not None: return out - # o_k is already [B, q_len, H, D] contiguous — just squeeze for q_len==1. - if q_len == 1: - return o_k.squeeze(1) + # o_k is [B, q_len, H, D] — return as-is to match trtllm-gen output shape. return o_k diff --git a/flashinfer/mla.py b/flashinfer/mla.py index eaedd02165..5414398a8c 100644 --- a/flashinfer/mla.py +++ b/flashinfer/mla.py @@ -771,6 +771,17 @@ def trtllm_batch_decode_with_kv_cache_mla( elif backend == "cute-dsl": from .cute_dsl.mla_decode import cute_dsl_mla_decode + if isinstance(bmm1_scale, torch.Tensor): + raise ValueError( + "cute-dsl backend does not support tensor bmm1_scale, " + "please pass a float value" + ) + if isinstance(bmm2_scale, torch.Tensor): + raise ValueError( + "cute-dsl backend does not support tensor bmm2_scale, " + "please pass a float value" + ) + return cute_dsl_mla_decode( query=query, kv_cache=kv_cache, @@ -780,12 +791,8 @@ def trtllm_batch_decode_with_kv_cache_mla( block_tables=block_tables, seq_lens=seq_lens, max_seq_len=max_seq_len, - softmax_scale=bmm1_scale - if isinstance(bmm1_scale, float) - else float(bmm1_scale.item()), - output_scale=bmm2_scale - if isinstance(bmm2_scale, float) - else float(bmm2_scale.item()), + softmax_scale=bmm1_scale, + output_scale=bmm2_scale, out=out, ) else: diff --git a/tests/attention/test_cute_dsl_mla_decode.py b/tests/attention/test_cute_dsl_mla_decode.py index cf5c004a3f..d6960af476 100644 --- a/tests/attention/test_cute_dsl_mla_decode.py +++ b/tests/attention/test_cute_dsl_mla_decode.py @@ -177,9 +177,6 @@ def test_cute_dsl_mla_decode_fp16(batch_size, seq_len_k, page_size, dtype): page_size, ) - if q_len == 1: - ref_out = ref_out.squeeze(1) - ref_out_cast = ref_out.to(dtype) # Check with tolerance appropriate for FP16/BF16 @@ -261,8 +258,6 @@ def test_cute_dsl_mla_decode_variable_seq_len(batch_size, seq_len_k, page_size=1 output_scale, page_size, ) - if q_len == 1: - ref_out = ref_out.squeeze(1) ref_out_fp16 = ref_out.to(torch.float16) torch.testing.assert_close(out, ref_out_fp16, atol=1e-2, rtol=1e-2) @@ -321,7 +316,71 @@ def test_cute_dsl_mla_decode_via_api(batch_size, seq_len_k, page_size=128): backend="cute-dsl", ) - assert out.shape == (batch_size, num_heads, latent_dim) + assert out.shape == (batch_size, q_len, num_heads, latent_dim) + + +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("seq_len_k", [128, 512]) +def test_cute_dsl_vs_trtllm_gen(batch_size, seq_len_k, page_size=64): + """Test cute-dsl backend output matches trtllm-gen backend output.""" + skip_if_unsupported() + + from flashinfer.mla import trtllm_batch_decode_with_kv_cache_mla + + torch.manual_seed(42) + device = torch.device("cuda") + + num_heads = 128 + latent_dim = 512 + rope_dim = 64 + q_len = 1 + softmax_scale = 1.0 / (latent_dim**0.5) + D_qk = latent_dim + rope_dim + + query = torch.randn( + batch_size, q_len, num_heads, D_qk, dtype=torch.bfloat16, device=device + ) + + num_pages_per_batch = (seq_len_k + page_size - 1) // page_size + total_pages = num_pages_per_batch * batch_size + 10 + # trtllm-gen expects 4D kv_cache: [num_pages, 1, page_size, D] + kv_cache = torch.randn( + total_pages, 1, page_size, D_qk, dtype=torch.bfloat16, device=device + ) + + block_tables = torch.zeros( + batch_size, num_pages_per_batch, dtype=torch.int32, device=device + ) + for b in range(batch_size): + for p in range(num_pages_per_batch): + block_tables[b, p] = b * num_pages_per_batch + p + + seq_lens = torch.full((batch_size,), seq_len_k, dtype=torch.int32, device=device) + workspace_buffer = torch.zeros(256 * 1024 * 1024, dtype=torch.int8, device=device) + + common_args = dict( + query=query, + kv_cache=kv_cache, + workspace_buffer=workspace_buffer, + qk_nope_head_dim=latent_dim, + kv_lora_rank=latent_dim, + qk_rope_head_dim=rope_dim, + block_tables=block_tables, + seq_lens=seq_lens, + max_seq_len=seq_len_k, + bmm1_scale=softmax_scale, + bmm2_scale=1.0, + ) + + out_trtllm = trtllm_batch_decode_with_kv_cache_mla(**common_args, backend="trtllm-gen") + out_cute_dsl = trtllm_batch_decode_with_kv_cache_mla(**common_args, backend="cute-dsl") + + torch.testing.assert_close( + out_cute_dsl.to(torch.float32), + out_trtllm.to(torch.float32), + atol=1e-2, + rtol=1e-2, + ) @pytest.mark.parametrize("batch_size", [1, 4]) @@ -383,7 +442,7 @@ def test_cute_dsl_mla_decode_fp8(batch_size, seq_len_k, page_size): ) assert out.dtype == torch.float8_e4m3fn - assert out.shape == (batch_size, num_heads, latent_dim) + assert out.shape == (batch_size, q_len, num_heads, latent_dim) # Reference: compute in FP32 using FP8 values dequantized to FP32 kv_flat = kv_cache.reshape(-1, D_qk).to(torch.float32) @@ -403,9 +462,6 @@ def test_cute_dsl_mla_decode_fp8(batch_size, seq_len_k, page_size): output_scale, page_size, ) - if q_len == 1: - ref_out = ref_out.squeeze(1) - # Compare outputs in FP32; FP8 has limited precision so use wider tolerance torch.testing.assert_close( out.to(torch.float32), ref_out.to(torch.float32), atol=0.1, rtol=0.1 From a4a8723d3e3cfd316400fd1518adfc2687f3cfa2 Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Wed, 11 Mar 2026 02:31:34 -0700 Subject: [PATCH 16/31] fix workspace None issue. --- flashinfer/cute_dsl/mla_decode.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/flashinfer/cute_dsl/mla_decode.py b/flashinfer/cute_dsl/mla_decode.py index da72e4b0a4..5d78fbe33e 100644 --- a/flashinfer/cute_dsl/mla_decode.py +++ b/flashinfer/cute_dsl/mla_decode.py @@ -110,6 +110,7 @@ def _get_compiled_mla_kernel( is_var_seq: bool, is_var_split_kv: bool, skip_correction_threshold: float = 0.0, + is_workspace_size_zero: bool = False, ) -> Callable: """Compile and cache an MLA decode kernel. @@ -216,12 +217,15 @@ def _get_compiled_mla_kernel( stride_order=(2, 1, 0), assumed_align=128, ) - # workspace: 1-D (int8 to match typical torch workspace buffers) - workspace_fake = cute.runtime.make_fake_compact_tensor( - cutlass.Int8, - (sym_workspace_size,), - assumed_align=128, - ) + if is_workspace_size_zero: + workspace_fake = None + else: + # workspace: 1-D (int8 to match typical torch workspace buffers) + workspace_fake = cute.runtime.make_fake_compact_tensor( + cutlass.Int8, + (sym_workspace_size,), + assumed_align=128, + ) # cache_seqs: [batch_size] — int32 cache_seqs_fake = cute.runtime.make_fake_compact_tensor( cutlass.Int32, @@ -369,8 +373,11 @@ def cute_dsl_mla_decode( f"workspace_buffer too small: {workspace_buffer.numel()} bytes, " f"need {workspace_size} bytes" ) - workspace_bytes = workspace_buffer[: max(workspace_size, 1)] - + is_workspace_size_zero = workspace_size == 0 + if is_workspace_size_zero: + workspace_bytes = None + else: + workspace_bytes = workspace_buffer[: workspace_size] # Output buffer: contiguous [B, q_len, H, D]. # Kernel reinterprets to [H, D, q_len, B] internally via zero-cost make_tensor. out_dtype = q_dtype @@ -405,6 +412,8 @@ def cute_dsl_mla_decode( ) # Get compiled kernel (cached after first compile) + # Note: when is_workspace_size_zero is True, workspace_bytes is None and it will launch one kernel without workspace. + # Otherwise, workspace_bytes is not None and it will launch two kernels. compiled_kernel = _get_compiled_mla_kernel( torch_dtype=q_dtype, page_size=page_size, @@ -414,6 +423,7 @@ def cute_dsl_mla_decode( is_var_seq=True, is_var_split_kv=is_var_split_kv, skip_correction_threshold=skip_correction_threshold, + is_workspace_size_zero=is_workspace_size_zero, ) # Call the kernel From c2e769e1da4576502f2c86d24bb9cdc6e8b2aa40 Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Wed, 11 Mar 2026 03:33:03 -0700 Subject: [PATCH 17/31] fix: align assumed_align values with kernel's from_dlpack settings Match the fake tensor assumed_align values in mla_decode.py to the actual alignment used in mla_decode_fp16.py's from_dlpack calls: tensors use 16 bytes, workspace uses 32 bytes. Co-Authored-By: Claude Opus 4.6 --- flashinfer/cute_dsl/mla_decode.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/flashinfer/cute_dsl/mla_decode.py b/flashinfer/cute_dsl/mla_decode.py index 5d78fbe33e..328dbf65d6 100644 --- a/flashinfer/cute_dsl/mla_decode.py +++ b/flashinfer/cute_dsl/mla_decode.py @@ -171,14 +171,14 @@ def _get_compiled_mla_kernel( cutlass_dtype, (sym_batch, sym_seq_q, sym_heads, sym_latent), stride_order=(3, 2, 1, 0), - assumed_align=128, + assumed_align=16, ) # q_rope: [batch_size, seq_len_q, num_heads, rope_dim] — contiguous q_rope_fake = cute.runtime.make_fake_compact_tensor( cutlass_dtype, (sym_batch, sym_seq_q, sym_heads, sym_rope), stride_order=(3, 2, 1, 0), - assumed_align=128, + assumed_align=16, ) # c_latent: [kv_batch, seq_len_k, latent_dim] — contiguous # kv_batch is a separate sym_int from query batch: paged KV cache uses a flat @@ -187,35 +187,35 @@ def _get_compiled_mla_kernel( cutlass_dtype, (sym_kv_batch, sym_seq_kv, sym_latent), stride_order=(2, 1, 0), - assumed_align=128, + assumed_align=16, ) # c_rope: [kv_batch, seq_len_k, rope_dim] — contiguous c_rope_fake = cute.runtime.make_fake_compact_tensor( cutlass_dtype, (sym_kv_batch, sym_seq_kv, sym_rope), stride_order=(2, 1, 0), - assumed_align=128, + assumed_align=16, ) # page_table: [batch_size, page_count] — contiguous page_table_fake = cute.runtime.make_fake_compact_tensor( cutlass.Int32, (sym_batch, sym_page_count), stride_order=(1, 0), - assumed_align=128, + assumed_align=16, ) # o: [batch_size, seq_len_q, num_heads, latent_dim] — contiguous o_fake = cute.runtime.make_fake_compact_tensor( cutlass_dtype, (sym_batch, sym_seq_q, sym_heads, sym_latent), stride_order=(3, 2, 1, 0), - assumed_align=128, + assumed_align=16, ) # lse: [batch_size, seq_len_q, num_heads] — contiguous lse_fake = cute.runtime.make_fake_compact_tensor( cutlass.Float32, (sym_batch, sym_seq_q, sym_heads), stride_order=(2, 1, 0), - assumed_align=128, + assumed_align=16, ) if is_workspace_size_zero: workspace_fake = None @@ -224,20 +224,20 @@ def _get_compiled_mla_kernel( workspace_fake = cute.runtime.make_fake_compact_tensor( cutlass.Int8, (sym_workspace_size,), - assumed_align=128, + assumed_align=32, ) # cache_seqs: [batch_size] — int32 cache_seqs_fake = cute.runtime.make_fake_compact_tensor( cutlass.Int32, (sym_batch,), - assumed_align=128, + assumed_align=16, ) # block_split_kvs: [batch_size] — int32 (only needed for is_var_split_kv=True) if is_var_split_kv: block_split_kvs_fake = cute.runtime.make_fake_compact_tensor( cutlass.Int32, (sym_batch,), - assumed_align=128, + assumed_align=16, ) else: block_split_kvs_fake = None From b66eb4e465f9d07b7f180c73007256d3a3c97b59 Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Wed, 11 Mar 2026 18:47:18 -0700 Subject: [PATCH 18/31] perf: add divisibility hints and opt-level 2 for CuTe DSL MLA compilation Add divisibility=16 to sym_latent and sym_rope to enable better compiler optimizations, and bump compilation opt-level to 2. Co-Authored-By: Claude Opus 4.6 --- flashinfer/cute_dsl/mla_decode.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flashinfer/cute_dsl/mla_decode.py b/flashinfer/cute_dsl/mla_decode.py index 328dbf65d6..2be16e10c4 100644 --- a/flashinfer/cute_dsl/mla_decode.py +++ b/flashinfer/cute_dsl/mla_decode.py @@ -152,9 +152,9 @@ def _get_compiled_mla_kernel( # Static Python ints would cause cute.assume() to fail with AttributeError # inside initialize_workspace() since it expects DSL Integer types. sym_heads = cute.sym_int() - sym_latent = cute.sym_int() + sym_latent = cute.sym_int(divisibility=16) sym_seq_q = cute.sym_int() - sym_rope = cute.sym_int() + sym_rope = cute.sym_int(divisibility=16) sym_batch = cute.sym_int() # query/output batch dimension sym_kv_batch = cute.sym_int() # KV cache batch dim (flat pool, =1 in paged mode) sym_seq_kv = cute.sym_int() @@ -260,7 +260,7 @@ def _get_compiled_mla_kernel( Float32(1.0), # softmax_scale placeholder Float32(1.0), # output_scale placeholder stream_fake, - options="--enable-tvm-ffi", + options="--enable-tvm-ffi --opt-level 2", ) return compiled_kernel From 114460fdcc754a9898620b4219965ea5c51a26d0 Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Wed, 11 Mar 2026 21:18:03 -0700 Subject: [PATCH 19/31] format. --- flashinfer/cute_dsl/mla_decode.py | 2 +- tests/attention/test_cute_dsl_mla_decode.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/flashinfer/cute_dsl/mla_decode.py b/flashinfer/cute_dsl/mla_decode.py index 2be16e10c4..b97a7c0f07 100644 --- a/flashinfer/cute_dsl/mla_decode.py +++ b/flashinfer/cute_dsl/mla_decode.py @@ -377,7 +377,7 @@ def cute_dsl_mla_decode( if is_workspace_size_zero: workspace_bytes = None else: - workspace_bytes = workspace_buffer[: workspace_size] + workspace_bytes = workspace_buffer[:workspace_size] # Output buffer: contiguous [B, q_len, H, D]. # Kernel reinterprets to [H, D, q_len, B] internally via zero-cost make_tensor. out_dtype = q_dtype diff --git a/tests/attention/test_cute_dsl_mla_decode.py b/tests/attention/test_cute_dsl_mla_decode.py index d6960af476..cefb43f6c0 100644 --- a/tests/attention/test_cute_dsl_mla_decode.py +++ b/tests/attention/test_cute_dsl_mla_decode.py @@ -372,8 +372,12 @@ def test_cute_dsl_vs_trtllm_gen(batch_size, seq_len_k, page_size=64): bmm2_scale=1.0, ) - out_trtllm = trtllm_batch_decode_with_kv_cache_mla(**common_args, backend="trtllm-gen") - out_cute_dsl = trtllm_batch_decode_with_kv_cache_mla(**common_args, backend="cute-dsl") + out_trtllm = trtllm_batch_decode_with_kv_cache_mla( + **common_args, backend="trtllm-gen" + ) + out_cute_dsl = trtllm_batch_decode_with_kv_cache_mla( + **common_args, backend="cute-dsl" + ) torch.testing.assert_close( out_cute_dsl.to(torch.float32), From 104c9fec69c43960dd17a65b18ded8975d6b94fa Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Thu, 12 Mar 2026 00:21:40 -0700 Subject: [PATCH 20/31] feat: add is_var_seq parameter for auto persistent/non-persistent strategy Add is_var_seq parameter to cute_dsl_mla_decode and trtllm_batch_decode_with_kv_cache_mla. When is_var_seq=False (fixed-length), use persistent mode; when is_var_seq=True (variable-length), use non-persistent mode. Update tests accordingly. Co-Authored-By: Claude Opus 4.6 --- flashinfer/cute_dsl/mla_decode.py | 12 ++++++++---- flashinfer/mla.py | 2 ++ tests/attention/test_cute_dsl_mla_decode.py | 7 +++++-- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/flashinfer/cute_dsl/mla_decode.py b/flashinfer/cute_dsl/mla_decode.py index b97a7c0f07..0ee0634462 100644 --- a/flashinfer/cute_dsl/mla_decode.py +++ b/flashinfer/cute_dsl/mla_decode.py @@ -279,6 +279,7 @@ def cute_dsl_mla_decode( softmax_scale: float, output_scale: float = 1.0, out: Optional[torch.Tensor] = None, + is_var_seq: bool = True, ) -> torch.Tensor: """CuTe DSL MLA decode kernel for Blackwell SM100. @@ -398,6 +399,9 @@ def cute_dsl_mla_decode( block_split_kvs = None skip_correction_threshold = 0.0 + # for fix-length, set is_persistent to True; otherwise, set to False. + is_persistent = not is_var_seq + # Validate configuration (cached, negligible overhead after first call) _check_can_implement( torch_dtype=q_dtype, @@ -406,8 +410,8 @@ def cute_dsl_mla_decode( seq_len_q=q_len, kv_lora_rank=kv_lora_rank, qk_rope_head_dim=qk_rope_head_dim, - is_persistent=True, - is_var_seq=True, + is_persistent=is_persistent, + is_var_seq=is_var_seq, is_var_split_kv=is_var_split_kv, ) @@ -419,8 +423,8 @@ def cute_dsl_mla_decode( page_size=page_size, kv_lora_rank=kv_lora_rank, qk_rope_head_dim=qk_rope_head_dim, - is_persistent=True, - is_var_seq=True, + is_persistent=is_persistent, + is_var_seq=is_var_seq, is_var_split_kv=is_var_split_kv, skip_correction_threshold=skip_correction_threshold, is_workspace_size_zero=is_workspace_size_zero, diff --git a/flashinfer/mla.py b/flashinfer/mla.py index 5414398a8c..52c7ec331e 100644 --- a/flashinfer/mla.py +++ b/flashinfer/mla.py @@ -603,6 +603,7 @@ def trtllm_batch_decode_with_kv_cache_mla( skip_softmax_threshold_scale_factor: Optional[float] = None, enable_pdl: bool | None = None, backend: str = "auto", + is_var_seq: bool = True, ) -> torch.Tensor: """ Parameters @@ -794,6 +795,7 @@ def trtllm_batch_decode_with_kv_cache_mla( softmax_scale=bmm1_scale, output_scale=bmm2_scale, out=out, + is_var_seq=is_var_seq, ) else: raise ValueError(f"Backend {backend} not supported") diff --git a/tests/attention/test_cute_dsl_mla_decode.py b/tests/attention/test_cute_dsl_mla_decode.py index cefb43f6c0..7a7fda65ee 100644 --- a/tests/attention/test_cute_dsl_mla_decode.py +++ b/tests/attention/test_cute_dsl_mla_decode.py @@ -156,6 +156,7 @@ def test_cute_dsl_mla_decode_fp16(batch_size, seq_len_k, page_size, dtype): max_seq_len=seq_len_k, softmax_scale=softmax_scale, output_scale=output_scale, + is_var_seq=False, ) # Reference @@ -238,6 +239,7 @@ def test_cute_dsl_mla_decode_variable_seq_len(batch_size, seq_len_k, page_size=1 max_seq_len=max_seq_len, softmax_scale=softmax_scale, output_scale=output_scale, + is_var_seq=True, ) # Reference @@ -314,6 +316,7 @@ def test_cute_dsl_mla_decode_via_api(batch_size, seq_len_k, page_size=128): bmm1_scale=softmax_scale, bmm2_scale=1.0, backend="cute-dsl", + is_var_seq=False, ) assert out.shape == (batch_size, q_len, num_heads, latent_dim) @@ -373,10 +376,10 @@ def test_cute_dsl_vs_trtllm_gen(batch_size, seq_len_k, page_size=64): ) out_trtllm = trtllm_batch_decode_with_kv_cache_mla( - **common_args, backend="trtllm-gen" + **common_args, backend="trtllm-gen", is_var_seq=False ) out_cute_dsl = trtllm_batch_decode_with_kv_cache_mla( - **common_args, backend="cute-dsl" + **common_args, backend="cute-dsl", is_var_seq=False ) torch.testing.assert_close( From 0999000ae90f3840c0b309245c96c7e4ca52e18f Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Thu, 12 Mar 2026 00:47:14 -0700 Subject: [PATCH 21/31] doc update --- flashinfer/cute_dsl/mla_decode.py | 4 ++++ flashinfer/mla.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/flashinfer/cute_dsl/mla_decode.py b/flashinfer/cute_dsl/mla_decode.py index 0ee0634462..65a25c06b2 100644 --- a/flashinfer/cute_dsl/mla_decode.py +++ b/flashinfer/cute_dsl/mla_decode.py @@ -313,6 +313,10 @@ def cute_dsl_mla_decode( Scale factor applied to the output. out : Optional[torch.Tensor] Pre-allocated output tensor [B, q_len, H, kv_lora_rank]. + is_var_seq : bool + Whether the sequence length is variable. + If True, the sequence length is variable. + Otherwise,the sequence length is fixed for all the requests in the batch. Returns ------- diff --git a/flashinfer/mla.py b/flashinfer/mla.py index 52c7ec331e..75cd6e6c6a 100644 --- a/flashinfer/mla.py +++ b/flashinfer/mla.py @@ -634,6 +634,10 @@ def trtllm_batch_decode_with_kv_cache_mla( When set to ``auto``, the backend will be chosen based on the device architecture and kernel availability. For sm_100 and sm_103 (blackwell architecture), ``auto`` will choose ``trtllm-gen`` backend. For sm_120 (blackwell architecture), ``auto`` will choose ``xqa`` backend. + is_var_seq : bool + Whether the sequence length is variable. + If True, the sequence length is variable. + Otherwise,the sequence length is fixed for all the requests in the batch. Note ---- From 80a93fd3063eb4dcbc5d7052cebec0e11347bc15 Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Thu, 12 Mar 2026 01:38:52 -0700 Subject: [PATCH 22/31] fix: address review feedback for CuTe DSL MLA decode - Add SM100+ architecture check for cute-dsl backend - Add workspace_buffer dtype assertion (must be torch.int8) - Add q_len=2 test coverage for multi-token decode - Add comments for tile sizes, cluster shape, and alignment values - Clarify num_heads validation logic Co-Authored-By: Claude Opus 4.6 --- flashinfer/cute_dsl/mla_decode.py | 46 +++++++++++++-------- flashinfer/mla.py | 9 +++- tests/attention/test_cute_dsl_mla_decode.py | 4 +- 3 files changed, 37 insertions(+), 22 deletions(-) diff --git a/flashinfer/cute_dsl/mla_decode.py b/flashinfer/cute_dsl/mla_decode.py index 65a25c06b2..95a7cb6bdb 100644 --- a/flashinfer/cute_dsl/mla_decode.py +++ b/flashinfer/cute_dsl/mla_decode.py @@ -120,8 +120,11 @@ def _get_compiled_mla_kernel( All scalar arguments must be pre-wrapped as Int32/Float32. """ + # Tile sizes for Blackwell mma. + # (128, 128) for QK and (128, 256) for PV. mma_qk_tiler_mn = (128, 128) mma_pv_tiler_mn = (128, 256) + # 2 CTAs along M (num_heads) cluster_shape_mnk = (2, 1, 1) is_fp8 = torch_dtype == torch.float8_e4m3fn @@ -161,39 +164,40 @@ def _get_compiled_mla_kernel( sym_page_count = cute.sym_int() sym_workspace_size = cute.sym_int() - # All tensors use contiguous row-major layout (stride_order descending). - # The kernel's __call__ reinterprets them to the required layout via - # cute.make_tensor zero-cost metadata shuffle. + # q_latent, q_rope, c_latent, c_rope are slices of contiguous tensors on + # the last dim (e.g. query[..., :kv_lora_rank]), so they are NOT contiguous: + # stride[-2] = D_qk (original full last dim), not the sliced shape. + # Use make_fake_tensor with fully dynamic strides so the compiled kernel + # reads actual strides from the runtime tensor. Last-dim stride is always 1. - # q_latent: [batch_size, seq_len_q, num_heads, latent_dim] — contiguous - # make_fake_compact_tensor stride_order: value 0 = fastest (stride=1) - q_latent_fake = cute.runtime.make_fake_compact_tensor( + # q_latent: [batch_size, seq_len_q, num_heads, latent_dim] — non-contiguous slice + q_latent_fake = cute.runtime.make_fake_tensor( cutlass_dtype, (sym_batch, sym_seq_q, sym_heads, sym_latent), - stride_order=(3, 2, 1, 0), + stride=(cute.sym_int(), cute.sym_int(), cute.sym_int(), 1), assumed_align=16, ) - # q_rope: [batch_size, seq_len_q, num_heads, rope_dim] — contiguous - q_rope_fake = cute.runtime.make_fake_compact_tensor( + # q_rope: [batch_size, seq_len_q, num_heads, rope_dim] — non-contiguous slice + q_rope_fake = cute.runtime.make_fake_tensor( cutlass_dtype, (sym_batch, sym_seq_q, sym_heads, sym_rope), - stride_order=(3, 2, 1, 0), + stride=(cute.sym_int(), cute.sym_int(), cute.sym_int(), 1), assumed_align=16, ) - # c_latent: [kv_batch, seq_len_k, latent_dim] — contiguous + # c_latent: [kv_batch, seq_len_k, latent_dim] — non-contiguous slice # kv_batch is a separate sym_int from query batch: paged KV cache uses a flat # pool so kv_batch=num_pages at runtime, while query batch can be any value. - c_latent_fake = cute.runtime.make_fake_compact_tensor( + c_latent_fake = cute.runtime.make_fake_tensor( cutlass_dtype, (sym_kv_batch, sym_seq_kv, sym_latent), - stride_order=(2, 1, 0), + stride=(cute.sym_int(), cute.sym_int(), 1), assumed_align=16, ) - # c_rope: [kv_batch, seq_len_k, rope_dim] — contiguous - c_rope_fake = cute.runtime.make_fake_compact_tensor( + # c_rope: [kv_batch, seq_len_k, rope_dim] — non-contiguous slice + c_rope_fake = cute.runtime.make_fake_tensor( cutlass_dtype, (sym_kv_batch, sym_seq_kv, sym_rope), - stride_order=(2, 1, 0), + stride=(cute.sym_int(), cute.sym_int(), 1), assumed_align=16, ) # page_table: [batch_size, page_count] — contiguous @@ -220,7 +224,8 @@ def _get_compiled_mla_kernel( if is_workspace_size_zero: workspace_fake = None else: - # workspace: 1-D (int8 to match typical torch workspace buffers) + # workspace: 1-D int8 buffer. 32-byte alignment because workspace stores + # fp32 partial sums internally, requiring stricter alignment than tensors. workspace_fake = cute.runtime.make_fake_compact_tensor( cutlass.Int8, (sym_workspace_size,), @@ -356,9 +361,11 @@ def cute_dsl_mla_decode( # Runtime validation (int comparisons only, negligible overhead) if max_seq_len <= 0: raise ValueError(f"max_seq_len must be > 0, got {max_seq_len}") + # H=128: standard DeepSeek-V3 MLA config; H=1: used by split-kv reduction path. + # Values 2..127 are not supported by the kernel's tile config. if H < 128 and H != 1: raise ValueError( - f"cute_dsl_mla_decode requires num_heads == 128 (or 1), got {H}" + f"cute_dsl_mla_decode requires num_heads >= 128 (or 1 for reduction), got {H}" ) # Cached split_kv and workspace_size computation @@ -374,6 +381,9 @@ def cute_dsl_mla_decode( ) # Prepare workspace: slice of contiguous 1D buffer is already contiguous + assert workspace_buffer.dtype == torch.int8, ( + f"workspace_buffer must be torch.int8, got {workspace_buffer.dtype}" + ) assert workspace_buffer.numel() >= workspace_size, ( f"workspace_buffer too small: {workspace_buffer.numel()} bytes, " f"need {workspace_size} bytes" diff --git a/flashinfer/mla.py b/flashinfer/mla.py index 75cd6e6c6a..262e7ff484 100644 --- a/flashinfer/mla.py +++ b/flashinfer/mla.py @@ -774,16 +774,21 @@ def trtllm_batch_decode_with_kv_cache_mla( return out elif backend == "cute-dsl": + cc = get_compute_capability(query.device) + if cc[0] < 10: + raise RuntimeError( + f"cute-dsl backend (MLA decode kernel) requires SM100+, got SM{cc[0]}{cc[1]}" + ) from .cute_dsl.mla_decode import cute_dsl_mla_decode if isinstance(bmm1_scale, torch.Tensor): raise ValueError( - "cute-dsl backend does not support tensor bmm1_scale, " + "cute-dsl backend (MLA decode kernel) does not support tensor bmm1_scale, " "please pass a float value" ) if isinstance(bmm2_scale, torch.Tensor): raise ValueError( - "cute-dsl backend does not support tensor bmm2_scale, " + "cute-dsl backend (MLA decode kernel) does not support tensor bmm2_scale, " "please pass a float value" ) diff --git a/tests/attention/test_cute_dsl_mla_decode.py b/tests/attention/test_cute_dsl_mla_decode.py index 7a7fda65ee..026899d617 100644 --- a/tests/attention/test_cute_dsl_mla_decode.py +++ b/tests/attention/test_cute_dsl_mla_decode.py @@ -99,7 +99,8 @@ def torch_reference_mla( @pytest.mark.parametrize("seq_len_k", [128, 512, 2048]) @pytest.mark.parametrize("page_size", [128]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -def test_cute_dsl_mla_decode_fp16(batch_size, seq_len_k, page_size, dtype): +@pytest.mark.parametrize("q_len", [1, 2]) +def test_cute_dsl_mla_decode_fp16(batch_size, seq_len_k, page_size, dtype, q_len): """Test FP16/BF16 MLA decode kernel.""" skip_if_unsupported() @@ -111,7 +112,6 @@ def test_cute_dsl_mla_decode_fp16(batch_size, seq_len_k, page_size, dtype): num_heads = 128 latent_dim = 512 rope_dim = 64 - q_len = 1 softmax_scale = 1.0 / (latent_dim**0.5) output_scale = 1.0 From 72bec0790d290cea4bc21bf7a5d9948844ad3c7d Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Sun, 15 Mar 2026 18:19:19 -0700 Subject: [PATCH 23/31] fix: add compat shim for cutlass-dsl setmaxregister API cutlass-dsl <4.4 only has warpgroup_reg_{dealloc,alloc} (deprecated), while >=4.4 renamed them to setmaxregister_{decrease,increase}. Fall back to the old names so CI with older cutlass-dsl still works. Co-Authored-By: Claude Opus 4.6 --- flashinfer/cute_dsl/mla_decode_fp16.py | 25 +++++++++++++++++++------ flashinfer/cute_dsl/mla_decode_fp8.py | 25 +++++++++++++++++++------ 2 files changed, 38 insertions(+), 12 deletions(-) diff --git a/flashinfer/cute_dsl/mla_decode_fp16.py b/flashinfer/cute_dsl/mla_decode_fp16.py index f4b266706f..70e5d219ec 100644 --- a/flashinfer/cute_dsl/mla_decode_fp16.py +++ b/flashinfer/cute_dsl/mla_decode_fp16.py @@ -38,6 +38,19 @@ import cutlass.cute as cute import cutlass.cute.testing as testing import cutlass.cute.nvgpu.tcgen05 as tcgen05 + +# Compat shim: setmaxregister_{decrease,increase} added in cutlass-dsl 4.4; +# older versions only have the deprecated warpgroup_reg_{dealloc,alloc}. +_setmaxregister_decrease = getattr( + cute.arch, + "setmaxregister_decrease", + getattr(cute.arch, "warpgroup_reg_dealloc", None), +) +_setmaxregister_increase = getattr( + cute.arch, + "setmaxregister_increase", + getattr(cute.arch, "warpgroup_reg_alloc", None), +) from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode import cutlass.cute.nvgpu.cpasync as cpasync import cutlass.utils as utils @@ -955,9 +968,9 @@ def split_kv_kernel( # /////////////////////////////////////////////////////////////////////////////// if warp_idx >= self.empty_warp_ids[0] and warp_idx <= self.empty_warp_ids[-1]: - cute.arch.setmaxregister_decrease(self.other_reg_num) + _setmaxregister_decrease(self.other_reg_num) if warp_idx == self.load_pt_warp_id: - cute.arch.setmaxregister_decrease(self.other_reg_num) + _setmaxregister_decrease(self.other_reg_num) load_pt_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.load_pt_stage ) @@ -992,7 +1005,7 @@ def split_kv_kernel( work_tile = tile_sched.get_current_work() load_pt_pipeline.producer_tail(load_pt_producer_state) if warp_idx == self.load_tma_warp_id: - cute.arch.setmaxregister_decrease(self.other_reg_num) + _setmaxregister_decrease(self.other_reg_num) load_q_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.load_q_stage ) @@ -1077,7 +1090,7 @@ def split_kv_kernel( # MMA warp # /////////////////////////////////////////////////////////////////////////////// if warp_idx == self.mma_warp_id: - cute.arch.setmaxregister_decrease(self.other_reg_num) + _setmaxregister_decrease(self.other_reg_num) # Alloc tensor memory buffer tmem.allocate(cute.arch.get_max_tmem_alloc_cols("sm_100")) tmem.wait_for_alloc() @@ -1166,7 +1179,7 @@ def split_kv_kernel( warp_idx >= self.compute_warp_ids[0] and warp_idx <= self.compute_warp_ids[-1] ): - cute.arch.setmaxregister_increase(self.softmax_reg_num) + _setmaxregister_increase(self.softmax_reg_num) mma_s_consumer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.mma_s_stage ) @@ -1236,7 +1249,7 @@ def split_kv_kernel( warp_idx >= self.correction_warp_ids[0] and warp_idx <= self.correction_warp_ids[-1] ): - cute.arch.setmaxregister_increase(self.correction_reg_num) + _setmaxregister_increase(self.correction_reg_num) p_cor_consumer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.p_cor_stage ) diff --git a/flashinfer/cute_dsl/mla_decode_fp8.py b/flashinfer/cute_dsl/mla_decode_fp8.py index 0e7e1a6e19..fcc3d468cd 100644 --- a/flashinfer/cute_dsl/mla_decode_fp8.py +++ b/flashinfer/cute_dsl/mla_decode_fp8.py @@ -38,6 +38,19 @@ import cutlass.cute.testing as testing from cutlass.cute.nvgpu import tcgen05 from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode + +# Compat shim: setmaxregister_{decrease,increase} added in cutlass-dsl 4.4; +# older versions only have the deprecated warpgroup_reg_{dealloc,alloc}. +_setmaxregister_decrease = getattr( + cute.arch, + "setmaxregister_decrease", + getattr(cute.arch, "warpgroup_reg_dealloc", None), +) +_setmaxregister_increase = getattr( + cute.arch, + "setmaxregister_increase", + getattr(cute.arch, "warpgroup_reg_alloc", None), +) import cutlass.cute.nvgpu.cpasync as cpasync import cutlass.utils as utils import cutlass.pipeline as pipeline @@ -1027,10 +1040,10 @@ def split_kv_kernel( # Load warps, including page table and data tensors # /////////////////////////////////////////////////////////////////////////////// if warp_idx >= self.empty_warp_ids[0] and warp_idx <= self.empty_warp_ids[-1]: - cute.arch.setmaxregister_decrease(self.other_reg_num) + _setmaxregister_decrease(self.other_reg_num) if warp_idx == self.load_tma_k_warp_id: - cute.arch.setmaxregister_decrease(self.other_reg_num) + _setmaxregister_decrease(self.other_reg_num) load_q_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.load_q_stage ) @@ -1090,7 +1103,7 @@ def split_kv_kernel( load_k_pipeline.producer_tail(load_k_producer_state) if warp_idx == self.load_tma_v_warp_id: - cute.arch.setmaxregister_decrease(self.other_reg_num) + _setmaxregister_decrease(self.other_reg_num) load_v_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.load_v_stage ) @@ -1136,7 +1149,7 @@ def split_kv_kernel( # MMA warp # /////////////////////////////////////////////////////////////////////////////// if warp_idx == self.mma_warp_id: - cute.arch.setmaxregister_decrease(self.other_reg_num) + _setmaxregister_decrease(self.other_reg_num) # Alloc tensor memory buffer tmem.allocate(cute.arch.get_max_tmem_alloc_cols("sm_100")) tmem.wait_for_alloc() @@ -1232,7 +1245,7 @@ def split_kv_kernel( warp_idx >= self.compute_warp_ids[0] and warp_idx <= self.compute_warp_ids[-1] ): - cute.arch.setmaxregister_increase(self.softmax_reg_num) + _setmaxregister_increase(self.softmax_reg_num) mma_s_consumer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.mma_s_stage ) @@ -1300,7 +1313,7 @@ def split_kv_kernel( warp_idx >= self.correction_warp_ids[0] and warp_idx <= self.correction_warp_ids[-1] ): - cute.arch.setmaxregister_increase(self.correction_reg_num) + _setmaxregister_increase(self.correction_reg_num) p_cor_consumer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.p_cor_stage ) From a913f909f105dc8d958e95749639b8acfdd97b16 Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Sun, 15 Mar 2026 18:28:57 -0700 Subject: [PATCH 24/31] refactor: move MLA CuTe DSL kernels to flashinfer/mla/cute_dsl/ Move MLA kernel files from flashinfer/cute_dsl/ to flashinfer/mla/cute_dsl/ to co-locate them with the MLA module. Convert flashinfer/mla.py to a package (flashinfer/mla/__init__.py + _core.py) following the pattern used by flashinfer/gemm/ and flashinfer/norm/. Co-Authored-By: Claude Opus 4.6 --- flashinfer/cute_dsl/__init__.py | 3 -- flashinfer/mla/__init__.py | 15 ++++++++++ flashinfer/{mla.py => mla/_core.py} | 12 ++++---- flashinfer/mla/cute_dsl/__init__.py | 30 +++++++++++++++++++ flashinfer/{ => mla}/cute_dsl/mla_decode.py | 6 +++- .../{ => mla}/cute_dsl/mla_decode_fp16.py | 0 .../{ => mla}/cute_dsl/mla_decode_fp8.py | 0 flashinfer/{ => mla}/cute_dsl/mla_helpers.py | 0 tests/attention/test_cute_dsl_mla_decode.py | 6 ++-- 9 files changed, 59 insertions(+), 13 deletions(-) create mode 100644 flashinfer/mla/__init__.py rename flashinfer/{mla.py => mla/_core.py} (99%) create mode 100644 flashinfer/mla/cute_dsl/__init__.py rename flashinfer/{ => mla}/cute_dsl/mla_decode.py (99%) rename flashinfer/{ => mla}/cute_dsl/mla_decode_fp16.py (100%) rename flashinfer/{ => mla}/cute_dsl/mla_decode_fp8.py (100%) rename flashinfer/{ => mla}/cute_dsl/mla_helpers.py (100%) diff --git a/flashinfer/cute_dsl/__init__.py b/flashinfer/cute_dsl/__init__.py index 3510995a60..1fd4e6e385 100644 --- a/flashinfer/cute_dsl/__init__.py +++ b/flashinfer/cute_dsl/__init__.py @@ -54,7 +54,6 @@ add_rmsnorm_fp4quant, AddRMSNormFP4QuantKernel, ) - from .mla_decode import cute_dsl_mla_decode # Backwards-compatible re-exports from flashinfer.norm.kernels submodule from ..norm.kernels import ( @@ -100,8 +99,6 @@ # Add + RMSNorm + FP4 Quantization "add_rmsnorm_fp4quant", "AddRMSNormFP4QuantKernel", - # MLA Decode - "cute_dsl_mla_decode", # Norm kernels (CuTe DSL) - backwards-compatible re-exports "RMSNormKernel", "QKRMSNormKernel", diff --git a/flashinfer/mla/__init__.py b/flashinfer/mla/__init__.py new file mode 100644 index 0000000000..9addfd8849 --- /dev/null +++ b/flashinfer/mla/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2023 by FlashInfer team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._core import * # noqa: F401,F403 diff --git a/flashinfer/mla.py b/flashinfer/mla/_core.py similarity index 99% rename from flashinfer/mla.py rename to flashinfer/mla/_core.py index ab957f5823..65f40ae4d8 100644 --- a/flashinfer/mla.py +++ b/flashinfer/mla/_core.py @@ -20,10 +20,10 @@ import torch -from .api_logging import flashinfer_api -from .jit import gen_batch_mla_module, gen_trtllm_gen_fmha_module, setup_cubin_loader -from .jit.mla import gen_mla_module -from .utils import ( +from ..api_logging import flashinfer_api +from ..jit import gen_batch_mla_module, gen_trtllm_gen_fmha_module, setup_cubin_loader +from ..jit.mla import gen_mla_module +from ..utils import ( MaskMode, check_shape_dtype_device, determine_mla_backend, @@ -32,7 +32,7 @@ get_device_sm_count, log2e, ) -from .xqa import xqa_mla +from ..xqa import xqa_mla def _check_cutlass_shape(q_nope_pe, ckv_kpe_cache, kv_len, page_table): @@ -779,7 +779,7 @@ def trtllm_batch_decode_with_kv_cache_mla( raise RuntimeError( f"cute-dsl backend (MLA decode kernel) requires SM100+, got SM{cc[0]}{cc[1]}" ) - from .cute_dsl.mla_decode import cute_dsl_mla_decode + from .cute_dsl import cute_dsl_mla_decode if isinstance(bmm1_scale, torch.Tensor): raise ValueError( diff --git a/flashinfer/mla/cute_dsl/__init__.py b/flashinfer/mla/cute_dsl/__init__.py new file mode 100644 index 0000000000..24572e9913 --- /dev/null +++ b/flashinfer/mla/cute_dsl/__init__.py @@ -0,0 +1,30 @@ +# Copyright (c) 2026 by FlashInfer team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +CuTe DSL MLA Decode Kernels for Blackwell SM100. +""" + +from flashinfer.cute_dsl.utils import is_cute_dsl_available + +if is_cute_dsl_available(): + from .mla_decode import cute_dsl_mla_decode + +__all__ = [ + "is_cute_dsl_available", +] + +if is_cute_dsl_available(): + __all__ += [ + "cute_dsl_mla_decode", + ] diff --git a/flashinfer/cute_dsl/mla_decode.py b/flashinfer/mla/cute_dsl/mla_decode.py similarity index 99% rename from flashinfer/cute_dsl/mla_decode.py rename to flashinfer/mla/cute_dsl/mla_decode.py index 95a7cb6bdb..45f7356626 100644 --- a/flashinfer/cute_dsl/mla_decode.py +++ b/flashinfer/mla/cute_dsl/mla_decode.py @@ -30,7 +30,11 @@ from .mla_decode_fp16 import BlackwellMultiHeadLatentAttentionForwardFP16 from .mla_decode_fp8 import BlackwellMultiHeadLatentAttentionForwardFP8 -from .utils import get_max_active_clusters, get_num_sm, torch_to_cutlass_dtype +from flashinfer.cute_dsl.utils import ( + get_max_active_clusters, + get_num_sm, + torch_to_cutlass_dtype, +) @functools.cache diff --git a/flashinfer/cute_dsl/mla_decode_fp16.py b/flashinfer/mla/cute_dsl/mla_decode_fp16.py similarity index 100% rename from flashinfer/cute_dsl/mla_decode_fp16.py rename to flashinfer/mla/cute_dsl/mla_decode_fp16.py diff --git a/flashinfer/cute_dsl/mla_decode_fp8.py b/flashinfer/mla/cute_dsl/mla_decode_fp8.py similarity index 100% rename from flashinfer/cute_dsl/mla_decode_fp8.py rename to flashinfer/mla/cute_dsl/mla_decode_fp8.py diff --git a/flashinfer/cute_dsl/mla_helpers.py b/flashinfer/mla/cute_dsl/mla_helpers.py similarity index 100% rename from flashinfer/cute_dsl/mla_helpers.py rename to flashinfer/mla/cute_dsl/mla_helpers.py diff --git a/tests/attention/test_cute_dsl_mla_decode.py b/tests/attention/test_cute_dsl_mla_decode.py index 026899d617..28ba68c4c4 100644 --- a/tests/attention/test_cute_dsl_mla_decode.py +++ b/tests/attention/test_cute_dsl_mla_decode.py @@ -104,7 +104,7 @@ def test_cute_dsl_mla_decode_fp16(batch_size, seq_len_k, page_size, dtype, q_len """Test FP16/BF16 MLA decode kernel.""" skip_if_unsupported() - from flashinfer.cute_dsl.mla_decode import cute_dsl_mla_decode + from flashinfer.mla.cute_dsl import cute_dsl_mla_decode torch.manual_seed(42) device = torch.device("cuda") @@ -190,7 +190,7 @@ def test_cute_dsl_mla_decode_variable_seq_len(batch_size, seq_len_k, page_size=1 """Test MLA decode with variable sequence lengths across the batch.""" skip_if_unsupported() - from flashinfer.cute_dsl.mla_decode import cute_dsl_mla_decode + from flashinfer.mla.cute_dsl import cute_dsl_mla_decode torch.manual_seed(42) device = torch.device("cuda") @@ -397,7 +397,7 @@ def test_cute_dsl_mla_decode_fp8(batch_size, seq_len_k, page_size): """Test FP8 MLA decode kernel against FP32 reference.""" skip_if_unsupported() - from flashinfer.cute_dsl.mla_decode import cute_dsl_mla_decode + from flashinfer.mla.cute_dsl import cute_dsl_mla_decode torch.manual_seed(42) device = torch.device("cuda") From 82cc2c366a68041088fb140bc5bcb459f7c7e795 Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Sun, 15 Mar 2026 18:44:04 -0700 Subject: [PATCH 25/31] fix: update copyright year to 2026 in flashinfer/mla/ Co-Authored-By: Claude Opus 4.6 --- flashinfer/mla/__init__.py | 2 +- flashinfer/mla/_core.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flashinfer/mla/__init__.py b/flashinfer/mla/__init__.py index 9addfd8849..5ca5348a41 100644 --- a/flashinfer/mla/__init__.py +++ b/flashinfer/mla/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 by FlashInfer team. +# Copyright (c) 2026 by FlashInfer team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/flashinfer/mla/_core.py b/flashinfer/mla/_core.py index 65f40ae4d8..b69256e826 100644 --- a/flashinfer/mla/_core.py +++ b/flashinfer/mla/_core.py @@ -1,5 +1,5 @@ """ -Copyright (c) 2023 by FlashInfer team. +Copyright (c) 2026 by FlashInfer team. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. From 2423c576ce21f7c6673b025fd94ef7e7ff0dfc31 Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Sun, 15 Mar 2026 18:46:27 -0700 Subject: [PATCH 26/31] fix: update copyright years to 2026 Co-Authored-By: Claude Opus 4.6 --- flashinfer/mla/cute_dsl/mla_decode_fp16.py | 2 +- flashinfer/mla/cute_dsl/mla_decode_fp8.py | 2 +- flashinfer/mla/cute_dsl/mla_helpers.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/flashinfer/mla/cute_dsl/mla_decode_fp16.py b/flashinfer/mla/cute_dsl/mla_decode_fp16.py index 70e5d219ec..9a5933f065 100644 --- a/flashinfer/mla/cute_dsl/mla_decode_fp16.py +++ b/flashinfer/mla/cute_dsl/mla_decode_fp16.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # Redistribution and use in source and binary forms, with or without diff --git a/flashinfer/mla/cute_dsl/mla_decode_fp8.py b/flashinfer/mla/cute_dsl/mla_decode_fp8.py index fcc3d468cd..b0e35e8ff4 100644 --- a/flashinfer/mla/cute_dsl/mla_decode_fp8.py +++ b/flashinfer/mla/cute_dsl/mla_decode_fp8.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # Redistribution and use in source and binary forms, with or without diff --git a/flashinfer/mla/cute_dsl/mla_helpers.py b/flashinfer/mla/cute_dsl/mla_helpers.py index 1790b3c882..ac2bee49df 100644 --- a/flashinfer/mla/cute_dsl/mla_helpers.py +++ b/flashinfer/mla/cute_dsl/mla_helpers.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # Redistribution and use in source and binary forms, with or without From 037eab65be5c622cd2976ee780aa2713f63e2cb6 Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Mon, 16 Mar 2026 17:39:32 -0700 Subject: [PATCH 27/31] fix: add compat shim for cutlass-dsl get_max_tmem_alloc_cols API get_max_tmem_alloc_cols was added in cutlass-dsl 4.4; older versions (e.g. 4.3.4) don't have it, causing AttributeError on import. Provide a local fallback that returns the same TMEM capacity constants. Co-Authored-By: Claude Opus 4.6 (1M context) --- flashinfer/mla/cute_dsl/mla_decode_fp16.py | 18 +++++++++++++++++- flashinfer/mla/cute_dsl/mla_decode_fp8.py | 16 +++++++++++++++- 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/flashinfer/mla/cute_dsl/mla_decode_fp16.py b/flashinfer/mla/cute_dsl/mla_decode_fp16.py index 9a5933f065..d91385088f 100644 --- a/flashinfer/mla/cute_dsl/mla_decode_fp16.py +++ b/flashinfer/mla/cute_dsl/mla_decode_fp16.py @@ -39,6 +39,7 @@ import cutlass.cute.testing as testing import cutlass.cute.nvgpu.tcgen05 as tcgen05 +# TODO: Remove this hook helper function after nvidia-cutlass-dsl 4.3.x is no longer supported. # Compat shim: setmaxregister_{decrease,increase} added in cutlass-dsl 4.4; # older versions only have the deprecated warpgroup_reg_{dealloc,alloc}. _setmaxregister_decrease = getattr( @@ -51,6 +52,21 @@ "setmaxregister_increase", getattr(cute.arch, "warpgroup_reg_alloc", None), ) + +# Compat shim: get_max_tmem_alloc_cols added in cutlass-dsl 4.4; +# older versions don't have it, so we provide a fallback implementation. +_TMEM_MAX_ALLOC_COLUMNS_MAP = {"sm_100": 512, "sm_103": 512, "sm_120": 512} + + +# TODO: Remove this hook helper function after nvidia-cutlass-dsl 4.3.x is no longer supported. +def _get_max_tmem_alloc_cols(compute_capability: str) -> int: + if hasattr(cute.arch, "get_max_tmem_alloc_cols"): + return cute.arch.get_max_tmem_alloc_cols(compute_capability) + if compute_capability not in _TMEM_MAX_ALLOC_COLUMNS_MAP: + raise ValueError(f"Unsupported compute capability: {compute_capability}") + return _TMEM_MAX_ALLOC_COLUMNS_MAP[compute_capability] + + from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode import cutlass.cute.nvgpu.cpasync as cpasync import cutlass.utils as utils @@ -1092,7 +1108,7 @@ def split_kv_kernel( if warp_idx == self.mma_warp_id: _setmaxregister_decrease(self.other_reg_num) # Alloc tensor memory buffer - tmem.allocate(cute.arch.get_max_tmem_alloc_cols("sm_100")) + tmem.allocate(_get_max_tmem_alloc_cols("sm_100")) tmem.wait_for_alloc() tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) diff --git a/flashinfer/mla/cute_dsl/mla_decode_fp8.py b/flashinfer/mla/cute_dsl/mla_decode_fp8.py index b0e35e8ff4..8d50aab2e7 100644 --- a/flashinfer/mla/cute_dsl/mla_decode_fp8.py +++ b/flashinfer/mla/cute_dsl/mla_decode_fp8.py @@ -51,6 +51,20 @@ "setmaxregister_increase", getattr(cute.arch, "warpgroup_reg_alloc", None), ) + +# Compat shim: get_max_tmem_alloc_cols added in cutlass-dsl 4.4; +# older versions don't have it, so we provide a fallback implementation. +_TMEM_MAX_ALLOC_COLUMNS_MAP = {"sm_100": 512, "sm_103": 512, "sm_120": 512} + + +def _get_max_tmem_alloc_cols(compute_capability: str) -> int: + if hasattr(cute.arch, "get_max_tmem_alloc_cols"): + return cute.arch.get_max_tmem_alloc_cols(compute_capability) + if compute_capability not in _TMEM_MAX_ALLOC_COLUMNS_MAP: + raise ValueError(f"Unsupported compute capability: {compute_capability}") + return _TMEM_MAX_ALLOC_COLUMNS_MAP[compute_capability] + + import cutlass.cute.nvgpu.cpasync as cpasync import cutlass.utils as utils import cutlass.pipeline as pipeline @@ -1151,7 +1165,7 @@ def split_kv_kernel( if warp_idx == self.mma_warp_id: _setmaxregister_decrease(self.other_reg_num) # Alloc tensor memory buffer - tmem.allocate(cute.arch.get_max_tmem_alloc_cols("sm_100")) + tmem.allocate(_get_max_tmem_alloc_cols("sm_100")) tmem.wait_for_alloc() tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) From c09a2ccabf522b00ece80560b9560b4751e30ba7 Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Mon, 16 Mar 2026 19:35:19 -0700 Subject: [PATCH 28/31] feat: add cute-dsl backend to test_trtllm_gen_mla uniform testing - Add "cute-dsl" to backend parametrize in test_trtllm_batch_decode_mla - Add skip conditions for cute-dsl unsupported features (dynamic_scale, enable_pdl, skip_softmax, num_heads < 128) - Add torch_reference_mla as fallback when fa2 reference diverges from cute-dsl in certain configs (e.g. q_len>1 with small page_size) - Cast cute-dsl fp8 output to bf16 before comparison - Re-zero workspace buffer each test to avoid cross-backend contamination - Update docstring for backend, bmm1_scale, bmm2_scale params Co-Authored-By: Claude Opus 4.6 (1M context) --- flashinfer/mla/_core.py | 8 +- tests/attention/test_trtllm_gen_mla.py | 127 ++++++++++++++++++++----- 2 files changed, 106 insertions(+), 29 deletions(-) diff --git a/flashinfer/mla/_core.py b/flashinfer/mla/_core.py index b69256e826..5c0ad62b7a 100644 --- a/flashinfer/mla/_core.py +++ b/flashinfer/mla/_core.py @@ -620,9 +620,11 @@ def trtllm_batch_decode_with_kv_cache_mla( max_seq_len: max sequence length for kv_cache out: output tensor, if not provided, will be allocated internally bmm1_scale: fused scale for mla bmm1 input. - when using trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32. + When using ``trtllm-gen`` backend, it can be a ``torch.Tensor`` with dtype ``torch.float32``. + When using ``cute-dsl`` backend, only ``float`` values are supported. bmm2_scale: fused scale for mla bmm2 input. - when using trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32. + When using ``trtllm-gen`` backend, it can be a ``torch.Tensor`` with dtype ``torch.float32``. + When using ``cute-dsl`` backend, only ``float`` values are supported. sinks: additional value per head in the denominator of the softmax. skip_softmax_threshold_scale_factor: threshold scale factor for skipping softmax operations. Providing a value for this parameter enables skip-softmax sparsity as described in: https://arxiv.org/abs/2512.12087 @@ -630,7 +632,7 @@ def trtllm_batch_decode_with_kv_cache_mla( Setting the threshold to a higher value generally increases kernel performance at the cost of accuracy degradation. The actual threshold value equals the provided threshold_scale_factor divided by the context length. backend : str = "auto" - The implementation backend, could be ``auto``/``xqa`` or ``trtllm-gen``. Defaults to ``auto``. + The implementation backend, could be ``auto``/``xqa``, ``trtllm-gen``, or ``cute-dsl``. Defaults to ``auto``. When set to ``auto``, the backend will be chosen based on the device architecture and kernel availability. For sm_100 and sm_103 (blackwell architecture), ``auto`` will choose ``trtllm-gen`` backend. For sm_120 (blackwell architecture), ``auto`` will choose ``xqa`` backend. diff --git a/tests/attention/test_trtllm_gen_mla.py b/tests/attention/test_trtllm_gen_mla.py index 60ba7ab927..9eb31ac599 100644 --- a/tests/attention/test_trtllm_gen_mla.py +++ b/tests/attention/test_trtllm_gen_mla.py @@ -1,5 +1,6 @@ import pytest import torch +import torch.nn.functional as F import random import flashinfer @@ -214,6 +215,50 @@ def scaled_dot_product_attention( return out_ref, lse_ref +def torch_reference_mla( + query, + kv_cache, + block_tables, + seq_lens, + kv_lora_rank, + qk_rope_head_dim, + softmax_scale, + output_scale, + page_size, +): + """PyTorch reference for MLA decode. Returns [B, q_len, H, kv_lora_rank].""" + B, q_len, H, D_qk = query.shape + kv_flat = kv_cache.reshape(-1, D_qk) + c_latent = kv_flat[:, :kv_lora_rank] + c_rope = kv_flat[:, kv_lora_rank:] + q_nope = query[..., :kv_lora_rank] + q_rope = query[..., kv_lora_rank:] + + outputs = [] + for b in range(B): + seq_len = seq_lens[b].item() + num_pages = (seq_len + page_size - 1) // page_size + pages = block_tables[b, :num_pages] + kv_indices = [] + for p in pages: + start = p.item() * page_size + kv_indices.extend(range(start, start + page_size)) + kv_indices = kv_indices[:seq_len] + kv_idx_t = torch.tensor(kv_indices, device=query.device) + + k_lat = c_latent[kv_idx_t] # [seq_len, kv_lora_rank] + k_rope = c_rope[kv_idx_t] # [seq_len, rope_dim] + + attn_lat = torch.einsum("qhd,kd->qhk", q_nope[b].float(), k_lat.float()) + attn_rope = torch.einsum("qhd,kd->qhk", q_rope[b].float(), k_rope.float()) + attn = (attn_lat + attn_rope) * softmax_scale + attn = F.softmax(attn, dim=-1) + out_b = torch.einsum("qhk,kd->qhd", attn, k_lat.float()) * output_scale + outputs.append(out_b) + + return torch.stack(outputs, dim=0) # [B, q_len, H, kv_lora_rank] + + def trtllm_batch_decode_mla( layer_dimensions: MLALayerDimensions, batch_size: int, @@ -238,6 +283,15 @@ def trtllm_batch_decode_mla( if backend == "trtllm-gen": if compute_capability[0] != 10: pytest.skip("TRTLLM-GEN MLA only supports SM100 and SM103 GPUs") + if backend == "cute-dsl": + if compute_capability[0] < 10: + pytest.skip("cute-dsl MLA requires SM100+") + if dynamic_scale: + pytest.skip("cute-dsl does not support dynamic_scale") + if enable_pdl is not None: + pytest.skip("cute-dsl does not support enable_pdl") + if skips_softmax: + pytest.skip("cute-dsl does not support skip_softmax") if dynamic_scale and dtype != torch.float8_e4m3fn: pytest.skip("Dynamic scale is not supported for non-fp8 dtype") @@ -315,6 +369,9 @@ def trtllm_batch_decode_mla( global_trtllm_gen_fmha_workspace_buffer = torch.zeros( workspace_size, dtype=torch.int8, device=device ) + # trtllm-gen requires zero-initialized workspace (counter region); + # re-zero each time since other backends (e.g. cute-dsl) may share and dirty it. + global_trtllm_gen_fmha_workspace_buffer.zero_() workspace_buffer = global_trtllm_gen_fmha_workspace_buffer workspace_buffer_ref = global_workspace_buffer @@ -340,7 +397,8 @@ def trtllm_batch_decode_mla( ) # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer is the counter workspace, size might change in the future - assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all() + if backend == "trtllm-gen": + assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all() # Run reference attention and align output sm_scale = scale / ( @@ -395,39 +453,54 @@ def trtllm_batch_decode_mla( o_ref = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=False) - if backend == "trtllm-gen": + # cute-dsl fp8 kernel outputs fp8; cast to bf16 to match trtllm-gen / reference + if backend == "cute-dsl" and output.dtype == torch.float8_e4m3fn: + output = output.to(torch.bfloat16) + + if backend in ("trtllm-gen", "cute-dsl"): # check is nan assert not torch.isnan(o_ref).any(), "o_ref is nan" assert not torch.isnan(output).any(), "output is nan" + o_ref_view = o_ref.view( + batch_size, q_len_per_request, layer_dimensions.num_heads, -1 + ) + if dtype == torch.float8_e4m3fn: - try: - torch.testing.assert_close( - output, - o_ref.view( - batch_size, q_len_per_request, layer_dimensions.num_heads, -1 - ), - rtol=1e-1, - atol=1e-1, - ) # todo: do reference with normal attention? - except AssertionError as e: - print("output:", output) - print("o_ref:", o_ref) - raise e + rtol, atol = 1e-1, 1e-1 else: - try: - torch.testing.assert_close( - output, - o_ref.view( - batch_size, q_len_per_request, layer_dimensions.num_heads, -1 - ), - rtol=1e-2, - atol=1e-2, + rtol, atol = 1e-2, 1e-2 + + try: + torch.testing.assert_close(output, o_ref_view, rtol=rtol, atol=atol) + except AssertionError as fa2_err: + if backend == "cute-dsl": + # fa2 reference may diverge from cute-dsl in some configs; + # fall back to torch reference as ground truth. + query_for_ref = ( + query.to(torch.bfloat16) if dtype == torch.float8_e4m3fn else query + ) + kv_for_ref = ( + kv_cache.to(torch.bfloat16) + if dtype == torch.float8_e4m3fn + else kv_cache ) - except AssertionError as e: + o_torch_ref = torch_reference_mla( + query_for_ref, + kv_for_ref, + block_tables, + seq_lens_tensor, + layer_dimensions.head_dimensions.kv_lora_rank, + layer_dimensions.head_dimensions.qk_rope_head_dim, + softmax_scale=sm_scale, + output_scale=1.0, + page_size=page_size, + ).to(output.dtype) + torch.testing.assert_close(output, o_torch_ref, rtol=rtol, atol=atol) + else: print("output:", output) print("o_ref:", o_ref) - raise e + raise fa2_err elif backend == "xqa": atol = 0.05 rtol = 0.05 @@ -712,7 +785,7 @@ def trtllm_batch_decode_mla_sparse( ) # todo(Yingyi): verify larger q_len_per_request @pytest.mark.parametrize("dynamic_scale", [False]) @pytest.mark.parametrize("enable_pdl", [True, False, None]) -@pytest.mark.parametrize("backend", ["trtllm-gen", "xqa"]) +@pytest.mark.parametrize("backend", ["trtllm-gen", "xqa", "cute-dsl"]) @pytest.mark.parametrize("skips_softmax", [False, True]) def test_trtllm_batch_decode_mla( layer_dimensions: MLALayerDimensions, @@ -728,6 +801,8 @@ def test_trtllm_batch_decode_mla( ): if backend == "xqa" and layer_dimensions.head_dimensions == smaller_mla_dimensions: pytest.skip("XQA MLA does not support smaller MLA dimensions yet.") + if backend == "cute-dsl" and layer_dimensions.num_heads < 128: + pytest.skip("cute-dsl MLA requires num_heads >= 128") trtllm_batch_decode_mla( layer_dimensions, From 96fb34ed9315b0c565383f9f37ae0e6034b7e7dc Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Mon, 16 Mar 2026 20:24:11 -0700 Subject: [PATCH 29/31] feat: add cute-dsl backend support for MLA microbenchmark - Add "cute-dsl" to --backends choices for BatchMLAPagedAttentionWrapper - Add dispatch logic calling trtllm_batch_decode_with_kv_cache_mla with backend="cute-dsl" - Register cute-dsl in SM10.0 compute capability map - Update README with cute-dsl backend documentation Co-Authored-By: Claude Opus 4.6 (1M context) --- benchmarks/README.md | 6 ++--- benchmarks/routines/attention.py | 23 +++++++++++++++++++ .../routines/flashinfer_benchmark_utils.py | 3 ++- 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/benchmarks/README.md b/benchmarks/README.md index b3e3e3b22f..198719d8bd 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -19,7 +19,7 @@ Currently supports testing attention, gemm, fused MOE, normalization, quantizati - `BatchPrefillWithRaggedKVCacheWrapper` - Prefill attention with ragged KV cache. - Also supports computationally similar `cudnn_batch_prefill_with_kv_cache` (cudnn-native) and `trtllm_ragged_attention_deepseek`. - `BatchMLAPagedAttentionWrapper` - MLA attention proposed in DeepSeek series of models. - - Also supports computationally similar `trtllm_batch_decode_with_kv_cache_mla`. + - Also supports computationally similar `trtllm_batch_decode_with_kv_cache_mla` (trtllm-native) and CuTe DSL MLA decode kernel (cute-dsl, SM100+). - GEMM: - `gemm_fp8_nt_groupwise` - GEMM with FP8 data types using groupwise scaling. - `group_gemm_fp8_nt_groupwise` - Group GEMM with FP8 data types using groupwise scaling. @@ -191,7 +191,7 @@ The output CSV will contain detailed metrics including: | `--verbose`, `-v` | Print additional information (can be used multiple times for more verbosity, e.g. `-vv`) | | `--case_tag` | Optional tag for the test case, useful for annotating or filtering results in the output CSV. | | `--generate_repro_command`| If set, prints a reproducer command for the test case and stores it in the output CSV. | -| `--backends` | Space-separated list of backends to test, e.g. fa2, fa2_tc, fa3, auto, cudnn, cudnn-native, cutlass, trtllm, trtllm-gen, trtllm-native, cublas. (`auto` currently supported for `BatchDecodeWithPagedKVCacheWrapper` and `BatchPrefillWithPagedKVCacheWrapper`.)| +| `--backends` | Space-separated list of backends to test, e.g. fa2, fa2_tc, fa3, auto, cudnn, cudnn-native, cutlass, trtllm, trtllm-gen, trtllm-native, cute-dsl, cublas. (`auto` currently supported for `BatchDecodeWithPagedKVCacheWrapper` and `BatchPrefillWithPagedKVCacheWrapper`.)| ### Attention Flags | Flag | Description | @@ -464,7 +464,7 @@ Legend: | **BatchDecodeWithPagedKVCacheWrapper** | fa2 | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn | fa2, fa2_tc, cudnn, trtllm-gen, trtllm-native | fa2, fa2_tc, cudnn, trtllm-gen, trtllm-native | fa2, fa2_tc, cudnn | | **BatchPrefillWithPagedKVCacheWrapper** | | fa2, cudnn, cudnn-native | fa2, cudnn, cudnn-native | fa2, cudnn, cudnn-native | fa2, fa3, cudnn, cudnn-native | fa2, cudnn, cudnn-native, trtllm-gen, trtllm-native | fa2, cudnn, cudnn-native, trtllm-gen, trtllm-native | fa2, cudnn, cudnn-native | | **BatchPrefillWithRaggedKVCacheWrapper** | | fa2, cudnn, cudnn-native | fa2, cudnn, cudnn-native | fa2, cudnn, cudnn-native | fa2, fa3, cudnn, cudnn-native | fa2, cudnn, cudnn-native, cutlass, trtllm-native | fa2, cudnn, cudnn-native, cutlass, trtllm-native | fa2, cudnn, cudnn-native | -| **BatchMLAPagedAttentionWrapper** | | fa2 | fa2 | fa2 | fa2, fa3 | fa2, cutlass, trtllm-native | fa2, cutlass, trtllm-native | fa2 | +| **BatchMLAPagedAttentionWrapper** | | fa2 | fa2 | fa2 | fa2, fa3 | fa2, cutlass, trtllm-native, cute-dsl | fa2, cutlass, trtllm-native | fa2 | | **gemm_fp8_nt_groupwise** | | | | | | cutlass | cutlass | | | **group_gemm_fp8_nt_groupwise** | | | | | | cutlass | cutlass | | | **bmm_fp8** | | | | cudnn, cublas | cudnn, cublas | cudnn, cublas, cutlass | cudnn, cublas, cutlass | cudnn, cublas | diff --git a/benchmarks/routines/attention.py b/benchmarks/routines/attention.py index 717f82d92e..58d8e9c68a 100644 --- a/benchmarks/routines/attention.py +++ b/benchmarks/routines/attention.py @@ -110,6 +110,7 @@ def parse_attention_args(line, parser): "trtllm-gen", "trtllm-native", "trtllm-gen-native", # Deprecated, will be removed in future + "cute-dsl", ], help="Kernel backends to test. Default: fa2. backend=auto is only supported for BatchDecodeWithPagedKVCacheWrapper and BatchPrefillWithPagedKVCacheWrapper.", ) @@ -2122,6 +2123,13 @@ def testBatchMLAPagedAttentionWrapper(args): remove_trtllm_native = True if remove_trtllm_native: backends.remove("trtllm-native") + if "cute-dsl" in backends: + remove_cute_dsl = False + if num_qo_heads < 128: + print("[INFO] cute-dsl MLA backend requires num_heads >= 128. Skipping.") + remove_cute_dsl = True + if remove_cute_dsl: + backends.remove("cute-dsl") if len(backends) == 0: print("[ERROR] No backends to test. Exiting.") return res @@ -2307,6 +2315,21 @@ def run_backend_wrapper( bmm1_scale=sm_scale, bmm2_scale=1.0, ).squeeze(1) + elif backend == "cute-dsl": + return flashinfer.mla.trtllm_batch_decode_with_kv_cache_mla( + query=q.unsqueeze(1), + kv_cache=kv_cache.unsqueeze(1), + workspace_buffer=workspace_buffer, + qk_nope_head_dim=128, + kv_lora_rank=head_dim_ckv, + qk_rope_head_dim=head_dim_kpe, + block_tables=block_tables, + seq_lens=actual_seq_lens_kv.flatten(), + max_seq_len=s_kv, + bmm1_scale=sm_scale, + bmm2_scale=1.0, + backend="cute-dsl", + ).squeeze(1) else: print(f"[ERROR] Unsupported backend: {backend}") return None diff --git a/benchmarks/routines/flashinfer_benchmark_utils.py b/benchmarks/routines/flashinfer_benchmark_utils.py index 95fe833b3c..a98c7f60e2 100644 --- a/benchmarks/routines/flashinfer_benchmark_utils.py +++ b/benchmarks/routines/flashinfer_benchmark_utils.py @@ -322,12 +322,13 @@ def dtype_str_to_torch_dtype(dtype_str): }, "BatchMLAPagedAttentionWrapper": { # NOTE: trtllm-native calls trtllm_batch_decode_with_kv_cache_mla + # NOTE: cute-dsl calls trtllm_batch_decode_with_kv_cache_mla(backend="cute-dsl") "7.5": [], "8.0": ["fa2"], "8.6": ["fa2"], "8.9": ["fa2"], "9.0": ["fa2", "fa3"], - "10.0": ["fa2", "cutlass", "trtllm-native"], + "10.0": ["fa2", "cutlass", "trtllm-native", "cute-dsl"], "10.3": ["fa2", "cutlass", "trtllm-native"], "12.0": ["fa2"], }, From 81ec3aa711bad92b082b7cbd89fda97b736db690 Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Wed, 25 Mar 2026 00:39:58 -0700 Subject: [PATCH 30/31] feat: support flexible output dtype for CuTe DSL MLA FP8 decode kernel - Allow BFloat16 output for FP8 input (matching trtllm-gen backend default) - FP16/BF16 input defaults to same dtype output; FP8 input defaults to BF16 output - Add out_dtype parameter to cute_dsl_mla_decode for explicit override - Add uses_shared_paged_kv_idx=False validation for cute-dsl backend - Skip unsupported 3D page table tests for cute-dsl Co-Authored-By: Claude Opus 4.6 (1M context) --- flashinfer/mla/_core.py | 5 +++ flashinfer/mla/cute_dsl/mla_decode.py | 34 ++++++++++++++++----- flashinfer/mla/cute_dsl/mla_decode_fp8.py | 2 +- tests/attention/test_cute_dsl_mla_decode.py | 2 +- tests/attention/test_trtllm_gen_mla.py | 2 ++ 5 files changed, 36 insertions(+), 9 deletions(-) diff --git a/flashinfer/mla/_core.py b/flashinfer/mla/_core.py index 2991861afc..bca53627bf 100644 --- a/flashinfer/mla/_core.py +++ b/flashinfer/mla/_core.py @@ -831,6 +831,11 @@ def trtllm_batch_decode_with_kv_cache_mla( raise ValueError( "cute-dsl backend (MLA decode kernel) does not support skip_softmax_threshold_scale_factor" ) + if not uses_shared_paged_kv_idx: + raise ValueError( + "cute-dsl backend (MLA decode kernel) does not support separate KV page indices " + "(uses_shared_paged_kv_idx=False)" + ) return cute_dsl_mla_decode( query=query, diff --git a/flashinfer/mla/cute_dsl/mla_decode.py b/flashinfer/mla/cute_dsl/mla_decode.py index 45f7356626..ad19eb821b 100644 --- a/flashinfer/mla/cute_dsl/mla_decode.py +++ b/flashinfer/mla/cute_dsl/mla_decode.py @@ -58,6 +58,7 @@ def _get_split_kv_and_workspace_size( @functools.cache def _check_can_implement( torch_dtype: torch.dtype, + torch_out_dtype: torch.dtype, page_size: int, num_heads: int, seq_len_q: int, @@ -77,7 +78,8 @@ def _check_can_implement( if is_fp8 else BlackwellMultiHeadLatentAttentionForwardFP16 ) - cutlass_dtype = torch_to_cutlass_dtype(torch_dtype) + cutlass_in_dtype = torch_to_cutlass_dtype(torch_dtype) + cutlass_out_dtype = torch_to_cutlass_dtype(torch_out_dtype) if not KernelClass.can_implement( 1, # B (runtime, use placeholder) seq_len_q, @@ -85,8 +87,8 @@ def _check_can_implement( num_heads, kv_lora_rank, qk_rope_head_dim, - cutlass_dtype, - cutlass_dtype, + cutlass_in_dtype, + cutlass_out_dtype, cutlass.Float32, cutlass.Float32, mma_qk_tiler_mn, @@ -100,13 +102,14 @@ def _check_can_implement( raise ValueError( f"cute_dsl_mla_decode: unsupported configuration " f"(q_len={seq_len_q}, num_heads={num_heads}, page_size={page_size}, " - f"dtype={torch_dtype})" + f"in_dtype={torch_dtype}, out_dtype={torch_out_dtype})" ) @functools.cache def _get_compiled_mla_kernel( torch_dtype: torch.dtype, + torch_out_dtype: torch.dtype, page_size: int, kv_lora_rank: int, qk_rope_head_dim: int, @@ -138,6 +141,7 @@ def _get_compiled_mla_kernel( else BlackwellMultiHeadLatentAttentionForwardFP16 ) cutlass_dtype = torch_to_cutlass_dtype(torch_dtype) + cutlass_out_dtype = torch_to_cutlass_dtype(torch_out_dtype) kernel_obj = KernelClass( acc_dtype=cutlass.Float32, @@ -213,7 +217,7 @@ def _get_compiled_mla_kernel( ) # o: [batch_size, seq_len_q, num_heads, latent_dim] — contiguous o_fake = cute.runtime.make_fake_compact_tensor( - cutlass_dtype, + cutlass_out_dtype, (sym_batch, sym_seq_q, sym_heads, sym_latent), stride_order=(3, 2, 1, 0), assumed_align=16, @@ -288,6 +292,7 @@ def cute_dsl_mla_decode( softmax_scale: float, output_scale: float = 1.0, out: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, is_var_seq: bool = True, ) -> torch.Tensor: """CuTe DSL MLA decode kernel for Blackwell SM100. @@ -322,6 +327,10 @@ def cute_dsl_mla_decode( Scale factor applied to the output. out : Optional[torch.Tensor] Pre-allocated output tensor [B, q_len, H, kv_lora_rank]. + out_dtype : Optional[torch.dtype] + Output data type. If None, defaults to torch.bfloat16 (matching trtllm-gen backend). + Supported values: torch.bfloat16, torch.float8_e4m3fn (FP8 input only), + torch.float16, torch.bfloat16 (FP16/BF16 input). is_var_seq : bool Whether the sequence length is variable. If True, the sequence length is variable. @@ -343,6 +352,16 @@ def cute_dsl_mla_decode( assert D_qk == kv_lora_rank + qk_rope_head_dim q_dtype = query.dtype + # Resolve output dtype: for FP8 input, default to bfloat16 (matching trtllm-gen backend); + # for FP16/BF16 input, default to same as input. Allow override via out_dtype or out tensor. + if out is not None: + o_dtype = out.dtype + elif out_dtype is not None: + o_dtype = out_dtype + elif q_dtype == torch.float8_e4m3fn: + o_dtype = torch.bfloat16 + else: + o_dtype = q_dtype # Handle 3D vs 4D kv_cache: normalize to 3D [num_pages, page_size, D_total] if kv_cache.dim() == 4: @@ -399,12 +418,11 @@ def cute_dsl_mla_decode( workspace_bytes = workspace_buffer[:workspace_size] # Output buffer: contiguous [B, q_len, H, D]. # Kernel reinterprets to [H, D, q_len, B] internally via zero-cost make_tensor. - out_dtype = q_dtype if out is not None: o_k = out else: o_k = torch.empty( - (B, q_len, H, kv_lora_rank), dtype=out_dtype, device=query.device + (B, q_len, H, kv_lora_rank), dtype=o_dtype, device=query.device ) # LSE: contiguous [B, q_len, H]. Kernel reinterprets to [H, q_len, B]. @@ -423,6 +441,7 @@ def cute_dsl_mla_decode( # Validate configuration (cached, negligible overhead after first call) _check_can_implement( torch_dtype=q_dtype, + torch_out_dtype=o_dtype, page_size=page_size, num_heads=H, seq_len_q=q_len, @@ -438,6 +457,7 @@ def cute_dsl_mla_decode( # Otherwise, workspace_bytes is not None and it will launch two kernels. compiled_kernel = _get_compiled_mla_kernel( torch_dtype=q_dtype, + torch_out_dtype=o_dtype, page_size=page_size, kv_lora_rank=kv_lora_rank, qk_rope_head_dim=qk_rope_head_dim, diff --git a/flashinfer/mla/cute_dsl/mla_decode_fp8.py b/flashinfer/mla/cute_dsl/mla_decode_fp8.py index 8d50aab2e7..d0e5d83242 100644 --- a/flashinfer/mla/cute_dsl/mla_decode_fp8.py +++ b/flashinfer/mla/cute_dsl/mla_decode_fp8.py @@ -3489,7 +3489,7 @@ def can_implement( return False if in_dtype not in [cutlass.Float8E4M3FN]: return False - if out_dtype not in [cutlass.Float8E4M3FN]: + if out_dtype not in [cutlass.Float8E4M3FN, cutlass.BFloat16]: return False if acc_dtype != cutlass.Float32 or lse_dtype != cutlass.Float32: return False diff --git a/tests/attention/test_cute_dsl_mla_decode.py b/tests/attention/test_cute_dsl_mla_decode.py index 28ba68c4c4..4a84d47f51 100644 --- a/tests/attention/test_cute_dsl_mla_decode.py +++ b/tests/attention/test_cute_dsl_mla_decode.py @@ -448,7 +448,7 @@ def test_cute_dsl_mla_decode_fp8(batch_size, seq_len_k, page_size): output_scale=output_scale, ) - assert out.dtype == torch.float8_e4m3fn + assert out.dtype == torch.bfloat16 assert out.shape == (batch_size, q_len, num_heads, latent_dim) # Reference: compute in FP32 using FP8 values dequantized to FP32 diff --git a/tests/attention/test_trtllm_gen_mla.py b/tests/attention/test_trtllm_gen_mla.py index 9ecaf3a445..ceb0d271af 100755 --- a/tests/attention/test_trtllm_gen_mla.py +++ b/tests/attention/test_trtllm_gen_mla.py @@ -295,6 +295,8 @@ def trtllm_batch_decode_mla( pytest.skip("cute-dsl does not support enable_pdl") if skips_softmax: pytest.skip("cute-dsl does not support skip_softmax") + if not uses_shared_paged_kv_idx: + pytest.skip("cute-dsl does not support separate KV page indices") if dynamic_scale and dtype != torch.float8_e4m3fn: pytest.skip("Dynamic scale is not supported for non-fp8 dtype") From 4d221fd99a133803a67356c68ebf6d2a3871e10f Mon Sep 17 00:00:00 2001 From: Mindy Li <11663212+limin2021@users.noreply.github.com> Date: Wed, 25 Mar 2026 17:09:47 -0700 Subject: [PATCH 31/31] fix: skip CuTe DSL MLA tests on unsupported archs (SM120+) The tcgen05 MMA operations only support SM100-SM110. Tighten arch checks so SM120a (and above) are correctly skipped, and SM110 is correctly allowed. Co-Authored-By: Claude Opus 4.6 --- tests/attention/test_cute_dsl_mla_decode.py | 7 ++++--- tests/attention/test_trtllm_gen_mla.py | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/attention/test_cute_dsl_mla_decode.py b/tests/attention/test_cute_dsl_mla_decode.py index 4a84d47f51..d9427460f5 100644 --- a/tests/attention/test_cute_dsl_mla_decode.py +++ b/tests/attention/test_cute_dsl_mla_decode.py @@ -18,13 +18,14 @@ import torch import torch.nn.functional as F -from flashinfer.utils import is_sm100a_supported +from flashinfer.utils import is_sm100a_supported, is_sm110a_supported from flashinfer.cute_dsl import is_cute_dsl_available def skip_if_unsupported(): - if not is_sm100a_supported(torch.device("cuda")): - pytest.skip("Requires SM100a (Blackwell)") + device = torch.device("cuda") + if not (is_sm100a_supported(device) or is_sm110a_supported(device)): + pytest.skip("Requires SM100-SM110 (tcgen05)") if not is_cute_dsl_available(): pytest.skip("CuTe DSL not available") diff --git a/tests/attention/test_trtllm_gen_mla.py b/tests/attention/test_trtllm_gen_mla.py index ceb0d271af..19baa8d182 100755 --- a/tests/attention/test_trtllm_gen_mla.py +++ b/tests/attention/test_trtllm_gen_mla.py @@ -287,8 +287,8 @@ def trtllm_batch_decode_mla( if compute_capability[0] != 10: pytest.skip("TRTLLM-GEN MLA only supports SM100 and SM103 GPUs") if backend == "cute-dsl": - if compute_capability[0] < 10: - pytest.skip("cute-dsl MLA requires SM100+") + if compute_capability[0] not in (10, 11): + pytest.skip("cute-dsl MLA requires SM100-SM110 (tcgen05)") if dynamic_scale: pytest.skip("cute-dsl does not support dynamic_scale") if enable_pdl is not None: