From d0a53b54dc5483e4ac14ce7bbc029b0ce8170448 Mon Sep 17 00:00:00 2001 From: Igor Shovkun Date: Wed, 28 Jan 2026 21:29:54 -0800 Subject: [PATCH 01/33] adopt reference implementation from sglang --- tests/mamba/selective_state_update_triton.py | 343 +++++++++++++------ 1 file changed, 245 insertions(+), 98 deletions(-) diff --git a/tests/mamba/selective_state_update_triton.py b/tests/mamba/selective_state_update_triton.py index 2d46ffe0c9..ef76073f8f 100644 --- a/tests/mamba/selective_state_update_triton.py +++ b/tests/mamba/selective_state_update_triton.py @@ -1,42 +1,33 @@ -# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py -# Copyright (c) 2024, Tri Dao, Albert Gu. -# -# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/mamba_ssm.py + # SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright (c) 2024, Tri Dao, Albert Gu. 
+# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py import torch import triton import triton.language as tl from packaging import version +PAD_SLOT_ID = -1 + TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0") if TRITON3: @triton.jit def softplus(dt): - return tl.math.log(tl.math.exp(dt) + 1) + dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt) + return dt else: @triton.jit def softplus(dt): - return tl.math.log1p(tl.exp(dt)) - - -PAD_SLOT_ID = -1 + dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt) + return dt @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) @@ -51,7 +42,29 @@ def softplus(dt): @triton.heuristics( {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])} ) -@triton.jit +@triton.heuristics( + { + "CACHE_INTERMEDIATE_STATES": lambda args: args["intermediate_states_buffer"] + is not None + } +) +@triton.heuristics( + { + "HAS_EAGLE_TREE_CUSTOM_ATTN_MASK": lambda args: args[ + "retrieve_parent_token_ptr" + ] + is not None + } +) +@triton.heuristics( + { + "HAS_INTERMEDIATE_STATE_INDICES": lambda args: args[ + "intermediate_state_indices_ptr" + ] + is not None + } +) +@triton.jit(do_not_specialize=["T"]) def _selective_scan_update_kernel( # Pointers to matrices state_ptr, @@ -66,8 +79,13 @@ def _selective_scan_update_kernel( out_ptr, state_batch_indices_ptr, pad_slot_id, + intermediate_states_buffer, + cache_steps, + retrieve_parent_token_ptr, + intermediate_state_indices_ptr, # Matrix dimensions batch, + T, nheads, dim, dstate, @@ -78,9 +96,11 @@ def _selective_scan_update_kernel( stride_state_dim, stride_state_dstate, stride_x_batch, + stride_x_T, stride_x_head, stride_x_dim, stride_dt_batch, + stride_dt_T, stride_dt_head, stride_dt_dim, stride_dt_bias_head, @@ -89,19 +109,25 @@ def _selective_scan_update_kernel( stride_A_dim, stride_A_dstate, stride_B_batch, + stride_B_T, stride_B_group, stride_B_dstate, 
stride_C_batch, + stride_C_T, stride_C_group, stride_C_dstate, stride_D_head, stride_D_dim, stride_z_batch, + stride_z_T, stride_z_head, stride_z_dim, stride_out_batch, + stride_out_T, stride_out_head, stride_out_dim, + stride_retrieve_parent_token_batch, + stride_retrieve_parent_token_T, # Meta-parameters DT_SOFTPLUS: tl.constexpr, TIE_HDIM: tl.constexpr, @@ -110,6 +136,10 @@ def _selective_scan_update_kernel( HAS_D: tl.constexpr, HAS_Z: tl.constexpr, HAS_STATE_BATCH_INDICES: tl.constexpr, + DISABLE_STATE_UPDATE: tl.constexpr, + CACHE_INTERMEDIATE_STATES: tl.constexpr, + HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: tl.constexpr, + HAS_INTERMEDIATE_STATE_INDICES: tl.constexpr, BLOCK_SIZE_DSTATE: tl.constexpr, ): pid_m = tl.program_id(axis=0) @@ -121,7 +151,7 @@ def _selective_scan_update_kernel( # is the same as the batch id. if HAS_STATE_BATCH_INDICES: state_batch_indices_ptr += pid_b - state_batch_idx = tl.load(state_batch_indices_ptr) + state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64) state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head else: state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head @@ -142,70 +172,129 @@ def _selective_scan_update_kernel( state_ptrs = state_ptr + ( offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate ) - x_ptrs = x_ptr + offs_m * stride_x_dim - dt_ptrs = dt_ptr + offs_m * stride_dt_dim + + mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate) + if HAS_STATE_BATCH_INDICES: + mask &= state_batch_idx != pad_slot_id + state = tl.load(state_ptrs, mask=mask, other=0.0).to(tl.float32) + if HAS_DT_BIAS: dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim if HAS_D: D_ptr += pid_h * stride_D_head - A_ptrs = A_ptr + ( - offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate - ) - B_ptrs = B_ptr + offs_n * stride_B_dstate - C_ptrs = C_ptr + offs_n * stride_C_dstate - if HAS_D: D_ptrs = D_ptr + offs_m * stride_D_dim - if HAS_Z: - z_ptrs = z_ptr + offs_m * 
stride_z_dim - out_ptrs = out_ptr + offs_m * stride_out_dim - mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate) - if HAS_STATE_BATCH_INDICES: - mask &= state_batch_idx != pad_slot_id - state = tl.load(state_ptrs, mask=mask, other=0.0) - - x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - if not TIE_HDIM: - dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - if HAS_DT_BIAS: - dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - if DT_SOFTPLUS: - dt = tl.where(dt <= 20.0, softplus(dt), dt) - A = tl.load( - A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0 - ).to(tl.float32) - dA = tl.exp(A * dt[:, None]) - else: - dt = tl.load(dt_ptr).to(tl.float32) - if HAS_DT_BIAS: - dt += tl.load(dt_bias_ptr).to(tl.float32) - if DT_SOFTPLUS: - dt = tl.where(dt <= 20.0, softplus(dt), dt) - A = tl.load(A_ptr).to(tl.float32) - dA = tl.exp(A * dt) # scalar, not a matrix - - B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) - C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) - if HAS_D: - D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - if HAS_Z: - z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + A_ptrs = A_ptr + offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate - if not TIE_HDIM: - dB = B[None, :] * dt[:, None] - else: - dB = B * dt # vector of size (dstate,) - state = state * dA + dB * x[:, None] + cache_idx = -1 + if CACHE_INTERMEDIATE_STATES: + if HAS_INTERMEDIATE_STATE_INDICES: + intermediate_state_idx = tl.load(intermediate_state_indices_ptr + pid_b).to( + tl.int64 + ) + cache_idx = intermediate_state_idx + elif HAS_STATE_BATCH_INDICES: + cache_idx = state_batch_idx + else: + cache_idx = pid_b - mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate) - if HAS_STATE_BATCH_INDICES: - mask &= state_batch_idx != pad_slot_id - tl.store(state_ptrs, state, mask=mask) - out = tl.sum(state * C[None, :], 
axis=1) - if HAS_D: - out += x * D - if HAS_Z: - out *= z * tl.sigmoid(z) - tl.store(out_ptrs, out, mask=offs_m < dim) + current_step_idx = 0 + for _ in range(T): + if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: + if current_step_idx != 0 and cache_idx >= 0: + parent_ptr = ( + retrieve_parent_token_ptr + + pid_b * stride_retrieve_parent_token_batch + + current_step_idx * stride_retrieve_parent_token_T + ) + parent_step_idx = tl.load(parent_ptr).to(tl.int32) + + if parent_step_idx >= 0 and parent_step_idx < T: + step_offset = parent_step_idx * nheads * dim * dstate + cache_ptr = ( + intermediate_states_buffer + + cache_idx * cache_steps * nheads * dim * dstate + + step_offset + + pid_h * dim * dstate + + offs_m[:, None] * dstate + + offs_n[None, :] + ) + state = tl.load(cache_ptr, mask=mask, other=0.0).to(tl.float32) + + x_ptrs = x_ptr + offs_m * stride_x_dim + dt_ptrs = dt_ptr + offs_m * stride_dt_dim + B_ptrs = B_ptr + offs_n * stride_B_dstate + C_ptrs = C_ptr + offs_n * stride_C_dstate + if HAS_Z: + z_ptrs = z_ptr + offs_m * stride_z_dim + out_ptrs = out_ptr + offs_m * stride_out_dim + + x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if not TIE_HDIM: + dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if HAS_DT_BIAS: + dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if DT_SOFTPLUS: + dt = softplus(dt) + A = tl.load( + A_ptrs, + mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), + other=0.0, + ).to(tl.float32) + dA = tl.exp(A * dt[:, None]) + else: + dt = tl.load(dt_ptr).to(tl.float32) + if HAS_DT_BIAS: + dt += tl.load(dt_bias_ptr).to(tl.float32) + if DT_SOFTPLUS: + dt = softplus(dt) + A = tl.load(A_ptr).to(tl.float32) + dA = tl.exp(A * dt) # scalar, not a matrix + + B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + if HAS_D: + D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if HAS_Z: + z = 
tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + + dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt + state = state * dA + dB * x[:, None] + + if CACHE_INTERMEDIATE_STATES: + if HAS_STATE_BATCH_INDICES: + if state_batch_idx != pad_slot_id: + cache_ptr_base = ( + intermediate_states_buffer + + cache_idx * cache_steps * nheads * dim * dstate + + current_step_idx * nheads * dim * dstate + + pid_h * dim * dstate + ) + cache_ptrs = cache_ptr_base + ( + offs_m[:, None] * dstate + offs_n[None, :] + ) + tl.store( + cache_ptrs, state.to(cache_ptrs.dtype.element_ty), mask=mask + ) + + out = tl.sum(state * C[None, :], axis=1) + if HAS_D: + out += x * D + if HAS_Z: + out *= z * tl.sigmoid(z) + tl.store(out_ptrs, out, mask=offs_m < dim) + + current_step_idx += 1 # noqa: SIM113 + + x_ptr += stride_x_T + dt_ptr += stride_dt_T + B_ptr += stride_B_T + C_ptr += stride_C_T + out_ptr += stride_out_T + if HAS_Z: + z_ptr += stride_z_T + + if not DISABLE_STATE_UPDATE: + tl.store(state_ptrs, state.to(state_ptrs.dtype.element_ty), mask=mask) def selective_state_update_triton( @@ -221,17 +310,23 @@ def selective_state_update_triton( dt_softplus=False, state_batch_indices=None, pad_slot_id=PAD_SLOT_ID, + out=None, + disable_state_update=False, + intermediate_states_buffer=None, + cache_steps=None, + retrieve_parent_token=None, + intermediate_state_indices=None, ): """ Argument: state: (batch, dim, dstate) or (batch, nheads, dim, dstate) - x: (batch, dim) or (batch, nheads, dim) - dt: (batch, dim) or (batch, nheads, dim) + x: (batch, dim) or (batch, nheads, dim) for single-token or (batch, T, nheads, dim) for multi-token + dt: (batch, dim) or (batch, nheads, dim) for single-token or (batch, T, nheads, dim) for multi-token A: (dim, dstate) or (nheads, dim, dstate) - B: (batch, dstate) or (batch, ngroups, dstate) - C: (batch, dstate) or (batch, ngroups, dstate) + B: (batch, dstate) or (batch, ngroups, dstate) for single-token or (batch, T, ngroups, dstate) for multi-token 
+ C: (batch, dstate) or (batch, ngroups, dstate) for single-token or (batch, T, ngroups, dstate) for multi-token D: (dim,) or (nheads, dim) - z: (batch, dim) or (batch, nheads, dim) + z: (batch, dim) or (batch, nheads, dim) for single-token or (batch, T, nheads, dim) for multi-token dt_bias: (dim,) or (nheads, dim) pad_slot_id: int if cache_indices is passed, lets the kernel identify padded @@ -239,38 +334,63 @@ def selective_state_update_triton( for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] in this case, the kernel will not process entries at indices 0 and 3 - Return: - out: (batch, dim) or (batch, nheads, dim) + out: Preallocated ssm output tensor. Assume same shape as x. + In-place updated. + disable_state_update: If True, don't write back to state (for speculative verify) + intermediate_states_buffer: Buffer to cache intermediate states + cache_steps: Total number of steps in the buffer + retrieve_parent_token: (batch, T) tensor of parent token indices for EAGLE tree attention + intermediate_state_indices: (batch,) tensor of indices for intermediate_states_buffer operations. + If provided, uses these indices instead of state_batch_indices for the buffer. 
""" - has_heads = state.dim() > 3 + # Track original x dimensionality to squeeze output appropriately + x_orig_dim = x.dim() + if state.dim() == 3: state = state.unsqueeze(1) if x.dim() == 2: x = x.unsqueeze(1) + if x.dim() == 3: + x = x.unsqueeze(1) if dt.dim() == 2: dt = dt.unsqueeze(1) + if dt.dim() == 3: + dt = dt.unsqueeze(1) if A.dim() == 2: A = A.unsqueeze(0) if B.dim() == 2: B = B.unsqueeze(1) + if B.dim() == 3: + B = B.unsqueeze(1) if C.dim() == 2: C = C.unsqueeze(1) + if C.dim() == 3: + C = C.unsqueeze(1) if D is not None and D.dim() == 1: D = D.unsqueeze(0) - if z is not None and z.dim() == 2: - z = z.unsqueeze(1) + if z is not None: + if z.dim() == 2: + z = z.unsqueeze(1) + if z.dim() == 3: + z = z.unsqueeze(1) if dt_bias is not None and dt_bias.dim() == 1: dt_bias = dt_bias.unsqueeze(0) + if out is None: + out = torch.empty_like(x) + if out.dim() == 2: + out = out.unsqueeze(1) + if out.dim() == 3: + out = out.unsqueeze(1) _, nheads, dim, dstate = state.shape - batch = x.shape[0] + batch, T, _, _ = x.shape - assert x.shape == (batch, nheads, dim) + assert x.shape == (batch, T, nheads, dim) assert dt.shape == x.shape assert A.shape == (nheads, dim, dstate) - ngroups = B.shape[1] + ngroups = B.shape[2] assert nheads % ngroups == 0, "nheads must be divisible by ngroups" - assert B.shape == (batch, ngroups, dstate) + assert B.shape == (batch, T, ngroups, dstate) assert C.shape == B.shape if D is not None: assert D.shape == (nheads, dim) @@ -280,9 +400,14 @@ def selective_state_update_triton( assert dt_bias.shape == (nheads, dim) if state_batch_indices is not None: assert state_batch_indices.shape == (batch,) - out = torch.empty_like(x) + assert out.shape == x.shape + grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads) - z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0) + z_strides = ( + (z.stride(0), z.stride(1), z.stride(2), z.stride(3)) + if z is not None + else (0, 0, 0, 0) + ) # We don't want 
autotune since it will overwrite the state # We instead tune by hand. BLOCK_SIZE_M, num_warps = ( @@ -298,8 +423,15 @@ def selective_state_update_triton( A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(-1) == 0 - and (dt_bias is not None and dt_bias.stride(-1) == 0) + and (dt_bias is None or dt_bias.stride(-1) == 0) + ) + + retrieve_parent_token_strides = ( + (retrieve_parent_token.stride(0), retrieve_parent_token.stride(1)) + if retrieve_parent_token is not None + else (0, 0) ) + with torch.cuda.device(x.device.index): _selective_scan_update_kernel[grid]( state, @@ -314,7 +446,12 @@ def selective_state_update_triton( out, state_batch_indices, pad_slot_id, + intermediate_states_buffer, + cache_steps if cache_steps is not None else 0, + retrieve_parent_token, + intermediate_state_indices, batch, + T, nheads, dim, dstate, @@ -326,9 +463,11 @@ def selective_state_update_triton( x.stride(0), x.stride(1), x.stride(2), + x.stride(3), dt.stride(0), dt.stride(1), dt.stride(2), + dt.stride(3), *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else (0, 0), A.stride(0), A.stride(1), @@ -336,21 +475,29 @@ def selective_state_update_triton( B.stride(0), B.stride(1), B.stride(2), + B.stride(3), C.stride(0), C.stride(1), C.stride(2), + C.stride(3), *(D.stride(0), D.stride(1)) if D is not None else (0, 0), z_strides[0], z_strides[1], z_strides[2], + z_strides[3], out.stride(0), out.stride(1), out.stride(2), + out.stride(3), + retrieve_parent_token_strides[0], + retrieve_parent_token_strides[1], dt_softplus, tie_hdim, BLOCK_SIZE_M, + DISABLE_STATE_UPDATE=disable_state_update, num_warps=num_warps, ) - if not has_heads: + # Squeeze T dimension if original x didn't have it (was 2D or 3D) + if x_orig_dim < 4: out = out.squeeze(1) return out From 320a72c13cdbe255d227ea9eb9e28800ef610ecb Mon Sep 17 00:00:00 2001 From: Igor Shovkun Date: Wed, 28 Jan 2026 21:51:58 -0800 Subject: [PATCH 02/33] Extract create_test_inputs to shared test_utils module Move the test input 
generation helper from test_selective_state_update.py to a new test_utils.py module for reuse across tests. The refactored function adds support for multi-token mode, intermediate state buffers, and configurable state cache strides. --- tests/mamba/test_selective_state_update.py | 99 ++------- tests/mamba/test_utils.py | 226 +++++++++++++++++++++ 2 files changed, 239 insertions(+), 86 deletions(-) create mode 100644 tests/mamba/test_utils.py diff --git a/tests/mamba/test_selective_state_update.py b/tests/mamba/test_selective_state_update.py index 4a92340f55..35d53a541c 100644 --- a/tests/mamba/test_selective_state_update.py +++ b/tests/mamba/test_selective_state_update.py @@ -5,82 +5,7 @@ import flashinfer from .selective_state_update_triton import selective_state_update_triton - - -def create_test_inputs( - batch_size, - nheads, - dim, - dstate, - ngroups, - input_dtype, - weight_dtype, - matrixA_dtype, - state_dtype, - z_none=True, -): - # Set seed for reproducibility - torch.manual_seed(0) - - device = torch.device("cuda") - - # if we use the cache, then the state indices are taken from a specific slot - # so the state in the kernel will have batch as the first dimension, but it will - # only come from a particular slot; the full tensor first dim is larger - ssm_state_cache_size = max(384, int(2 * batch_size)) - - state_cache = torch.randn( - ssm_state_cache_size, nheads, dim, dstate, dtype=state_dtype, device=device - ) - - x = torch.randn(batch_size, nheads, dim, dtype=input_dtype, device=device) - - dt = torch.randn(batch_size, nheads, dtype=weight_dtype, device=device).as_strided( - (batch_size, nheads, dim), (nheads, 1, 0) - ) - # The dtype of A is separate in nemotron nano v3. - # A only has one value per head (as discussed in mamba 2 block) hence the strides. 
- A_base = torch.rand(nheads, dtype=matrixA_dtype, device=device) - A = A_base.as_strided((nheads, dim, dstate), (1, 0, 0)) - assert A.stride() == (1, 0, 0) - - # B and C - (batch_size, ngroups, dstate) - B = torch.randn(batch_size, ngroups, dstate, dtype=input_dtype, device=device) - C = torch.randn(batch_size, ngroups, dstate, dtype=input_dtype, device=device) - - # D - (nheads, dim) with strides (1, 0) - one value per head - D = torch.randn(nheads, dtype=weight_dtype, device=device).as_strided( - (nheads, dim), (1, 0) - ) - - dt_bias = torch.randn(nheads, dtype=weight_dtype, device=device).as_strided( - (nheads, dim), (1, 0) - ) - - # Slot indices for state batching - (batch_size,) - slot_idx = torch.randperm(ssm_state_cache_size, dtype=torch.int32, device=device)[ - :batch_size - ] - - # Create z tensor if z_none is False - z = ( - None - if z_none - else torch.randn(batch_size, nheads, dim, dtype=input_dtype, device=device) - ) - - return { - "state_cache": state_cache, - "x": x, - "dt": dt, - "A": A, - "B": B, - "C": C, - "D": D, - "z": z, - "dt_bias": dt_bias, - "slot_idx": slot_idx, - } +from .test_utils import create_test_inputs @pytest.mark.parametrize("batch", [1, 64]) @@ -112,10 +37,11 @@ def test_selective_state_update( dstate, ngroups, input_dtype, - weight_dtype, - matrixA_dtype, + weight_dtype=weight_dtype, + matrixA_dtype=matrixA_dtype, state_dtype=state_dtype, - z_none=True, + generate_z=False, + seed=0, ) state = inputs["state_cache"] @@ -128,7 +54,7 @@ def test_selective_state_update( inputs["B"], inputs["C"], D=inputs["D"], - z=inputs["z"], + z=inputs.get("z"), dt_bias=inputs["dt_bias"], dt_softplus=delta_softplus, state_batch_indices=inputs["slot_idx"], @@ -149,7 +75,7 @@ def test_selective_state_update( inputs["B"], inputs["C"], D=inputs["D"], - z=inputs["z"], + z=inputs.get("z"), dt_bias=inputs["dt_bias"], dt_softplus=delta_softplus, state_batch_indices=inputs["slot_idx"], @@ -264,10 +190,11 @@ def 
test_selective_state_update_with_z(use_out_tensor): dstate, ngroups, input_dtype, - weight_dtype, - matrixA_dtype, + weight_dtype=weight_dtype, + matrixA_dtype=matrixA_dtype, state_dtype=state_dtype, - z_none=False, + generate_z=True, + seed=0, ) state = inputs["state_cache"] @@ -280,7 +207,7 @@ def test_selective_state_update_with_z(use_out_tensor): inputs["B"], inputs["C"], D=inputs["D"], - z=inputs["z"], + z=inputs.get("z"), dt_bias=inputs["dt_bias"], dt_softplus=delta_softplus, state_batch_indices=inputs["slot_idx"], @@ -301,7 +228,7 @@ def test_selective_state_update_with_z(use_out_tensor): inputs["B"], inputs["C"], D=inputs["D"], - z=inputs["z"], + z=inputs.get("z"), dt_bias=inputs["dt_bias"], dt_softplus=delta_softplus, state_batch_indices=inputs["slot_idx"], diff --git a/tests/mamba/test_utils.py b/tests/mamba/test_utils.py new file mode 100644 index 0000000000..64995b823b --- /dev/null +++ b/tests/mamba/test_utils.py @@ -0,0 +1,226 @@ +from typing import Any, Dict, Optional + +import numpy as np +import torch + + +def create_test_inputs( + batch_size: int, + nheads: int, + dim: int, + dstate: int, + ngroups: int, + input_dtype: torch.dtype, + weight_dtype: torch.dtype = torch.float32, + state_dtype: torch.dtype = torch.bfloat16, + matrixA_dtype: torch.dtype = torch.float32, + generate_z: bool = False, + generate_intermediate_states_buffer: bool = False, + cache_steps: Optional[int] = None, + generate_retrieve_parent_token: bool = False, + state_cache_batch_stride: Optional[int] = None, + device: str = "cuda", + seed: int = 42, +) -> Dict[str, Any]: + """ + Create test inputs for selective_state_update functions. + + This function generates all necessary tensors for testing selective state + update kernels, supporting both single-token and multi-token (speculative + decoding) scenarios. + + Arguments: + batch_size: Number of sequences in the batch. + nheads: Number of attention heads. + dim: Head dimension (headdim). + dstate: SSM state size. 
+ ngroups: Number of groups for B and C matrices. + input_dtype: Data type for input tensors (x, B, C, z) - from model config.json (typically bf16). + weight_dtype: Data type for weight tensors (D, dt, dt_bias) - hardcoded fp32 in mamba2_mixer.py. + state_dtype: Data type for state tensor - user configurable (bf16/fp16/fp32). Defaults to input_dtype. + matrixA_dtype: Data type for the A matrix - hardcoded fp32 in mamba2_mixer.py. + generate_z: If True, generate z tensor for gating. + generate_intermediate_states_buffer: If True, generate buffer for + caching intermediate states during speculative decoding. + cache_steps: Number of steps/tokens to cache. Required if + generate_intermediate_states_buffer is True. Also determines + T dimension when > 1 (multi-token mode). + generate_retrieve_parent_token: If True, generate tensor for EAGLE + tree attention parent token retrieval. + state_cache_batch_stride: Optional batch stride for ssm_state_cache. + If None, defaults to contiguous stride (nheads * dim * dstate). + Must be >= nheads * dim * dstate if specified. + device: Device to create tensors on. + seed: Random seed for reproducibility. 
+ + Returns: + Dictionary containing all generated tensors with the following keys: + - state_cache: (total_entries, nheads, dim, dstate) + - x: (batch_size, [T,] nheads, dim) - T present if cache_steps provided + - dt: (batch_size, [T,] nheads, dim) - T present if cache_steps provided + - A: (nheads, dim, dstate) + - B: (batch_size, [T,] ngroups, dstate) - T present if cache_steps provided + - C: (batch_size, [T,] ngroups, dstate) - T present if cache_steps provided + - D: (nheads, dim) + - dt_bias: (nheads, dim) + - slot_idx: (batch_size,) + - z: (batch_size, [T,] nheads, dim) - only if generate_z=True, T present if cache_steps provided + - intermediate_states_buffer: (batch_size, cache_steps, nheads, dim, dstate) + - only if generate_intermediate_states_buffer=True + - intermediate_slot_idx: (batch_size,) + - only if generate_intermediate_states_buffer=True + - retrieve_parent_token: (batch_size, T) + - only if generate_retrieve_parent_token=True + - cache_steps: int - only if cache_steps is provided + """ + # Set seeds for reproducibility + torch.manual_seed(seed) + np.random.seed(seed) + + # Determine if we're in multi-token mode + # Always use 4D tensors when cache_steps is provided (even for cache_steps=1) + T = cache_steps if cache_steps is not None else None + + # If we use the cache, then the state indices are taken from a specific slot + # so the state in the kernel will have batch as the first dimension, but it will + # only come from a particular slot; the full tensor first dim is larger + ssm_state_cache_size = max(384, batch_size * 10) + + # State dtype defaults to input_dtype if not specified + + # SSM state cache: (total_entries, nheads, dim, dstate) + # Calculate the contiguous batch stride + contiguous_batch_stride = nheads * dim * dstate + + # Use provided batch stride or default to contiguous + if state_cache_batch_stride is None: + state_cache_batch_stride = contiguous_batch_stride + + # Validate that batch stride is large enough + if 
state_cache_batch_stride < contiguous_batch_stride: + raise ValueError( + f"state_cache_batch_stride ({state_cache_batch_stride}) must be >= " + f"contiguous stride ({contiguous_batch_stride} = nheads * dim * dstate)" + ) + + total_elements = ssm_state_cache_size * state_cache_batch_stride + state_cache_flat = torch.randn(total_elements, dtype=state_dtype, device=device) + state_cache = state_cache_flat.as_strided( + (ssm_state_cache_size, nheads, dim, dstate), + (state_cache_batch_stride, dim * dstate, dstate, 1), + ) + + # Input x: (batch_size, [T,] nheads, dim) + if T is not None: + x = torch.randn(batch_size, T, nheads, dim, device=device, dtype=input_dtype) + else: + x = torch.randn(batch_size, nheads, dim, dtype=input_dtype, device=device) + + # dt: (batch_size, [T,] nheads, dim) with strides that broadcast dim + # dt uses weight_dtype (fp32) as per mamba2_mixer.py + # dt has T dimension for multi-token mode, matching x shape + if T is not None: + dt_base = torch.randn(batch_size, T, nheads, dtype=weight_dtype, device=device) + dt = dt_base.as_strided( + (batch_size, T, nheads, dim), (T * nheads, nheads, 1, 0) + ) + else: + dt_base = torch.randn(batch_size, nheads, dtype=weight_dtype, device=device) + dt = dt_base.as_strided((batch_size, nheads, dim), (nheads, 1, 0)) + + # A matrix: (nheads, dim, dstate) with strides (1, 0, 0) - one value per head + # A should be negative for stability + A_base = -torch.rand(nheads, dtype=matrixA_dtype, device=device) - 1.0 + A = A_base.as_strided((nheads, dim, dstate), (1, 0, 0)) + + # B: (batch_size, T, ngroups, dstate) + # C: (batch_size, ngroups, dstate) + if T is not None: + B = torch.randn( + batch_size, T, ngroups, dstate, device=device, dtype=input_dtype + ) + C = torch.randn( + batch_size, T, ngroups, dstate, device=device, dtype=input_dtype + ) + else: + B = torch.randn(batch_size, ngroups, dstate, dtype=input_dtype, device=device) + C = torch.randn(batch_size, ngroups, dstate, dtype=input_dtype, device=device) + + 
# D: (nheads, dim) with strides (1, 0) - one value per head + D_base = torch.randn(nheads, dtype=weight_dtype, device=device) + D = D_base.as_strided((nheads, dim), (1, 0)) + + # dt_bias: (nheads, dim) with strides (1, 0) - one value per head + dt_bias_base = torch.rand(nheads, dtype=weight_dtype, device=device) - 4.0 + dt_bias = dt_bias_base.as_strided((nheads, dim), (1, 0)) + + # Slot indices for state batching - (batch_size,) + slot_idx = torch.randperm(ssm_state_cache_size, dtype=torch.int32, device=device)[ + :batch_size + ] + + # Build result dictionary + result = { + "state_cache": state_cache, + "x": x, + "dt": dt, + "A": A, + "B": B, + "C": C, + "D": D, + "dt_bias": dt_bias, + "slot_idx": slot_idx, + } + + # Optional: z tensor for gating + # z: (batch_size, [T,] nheads, dim) - has T dimension for multi-token mode, matching x shape + if generate_z: + if T is not None: + z = torch.randn( + batch_size, T, nheads, dim, dtype=input_dtype, device=device + ) + else: + z = torch.randn(batch_size, nheads, dim, dtype=input_dtype, device=device) + result["z"] = z + + # Optional: intermediate states buffer for speculative decoding + if generate_intermediate_states_buffer: + if cache_steps is None: + raise ValueError( + "cache_steps must be provided when generate_intermediate_states_buffer=True" + ) + intermediate_states_buffer = torch.zeros( + batch_size, + cache_steps, + nheads, + dim, + dstate, + dtype=state_dtype, + device=device, + ) + result["intermediate_states_buffer"] = intermediate_states_buffer + result["cache_steps"] = cache_steps + # Also generate indices mapping batch elements to intermediate state buffer positions + intermediate_slot_idx = torch.arange( + batch_size, dtype=torch.int32, device=device + ) + result["intermediate_slot_idx"] = intermediate_slot_idx + + # Optional: retrieve_parent_token for EAGLE tree attention + if generate_retrieve_parent_token: + if T is None or T <= 1: + raise ValueError( + "cache_steps > 1 required when 
generate_retrieve_parent_token=True" + ) + # Create a simple linear chain structure by default + # Token 0: parent = -1 (initial state) + # Token t: parent = t - 1 (previous token) + retrieve_parent_token = torch.zeros( + batch_size, T, dtype=torch.int32, device=device + ) + retrieve_parent_token[:, 0] = -1 # First token uses initial state + for t in range(1, T): + retrieve_parent_token[:, t] = t - 1 + result["retrieve_parent_token"] = retrieve_parent_token + + return result From 4022f10713f4230b54bef342b693a17a032b0140 Mon Sep 17 00:00:00 2001 From: Igor Shovkun Date: Wed, 28 Jan 2026 22:19:57 -0800 Subject: [PATCH 03/33] Rename test to reflect that it's an single-token test file --- tests/mamba/test_selective_state_update.py | 253 --------------- .../mamba/test_selective_state_update_stp.py | 301 ++++++++++++++++++ 2 files changed, 301 insertions(+), 253 deletions(-) delete mode 100644 tests/mamba/test_selective_state_update.py create mode 100644 tests/mamba/test_selective_state_update_stp.py diff --git a/tests/mamba/test_selective_state_update.py b/tests/mamba/test_selective_state_update.py deleted file mode 100644 index 35d53a541c..0000000000 --- a/tests/mamba/test_selective_state_update.py +++ /dev/null @@ -1,253 +0,0 @@ -import numpy as np -import pytest -import torch - -import flashinfer - -from .selective_state_update_triton import selective_state_update_triton -from .test_utils import create_test_inputs - - -@pytest.mark.parametrize("batch", [1, 64]) -@pytest.mark.parametrize("nheads", [8, 64]) -@pytest.mark.parametrize("dim", [64, 128]) -@pytest.mark.parametrize("dstate", [64, 128, 256]) -@pytest.mark.parametrize("state_dtype", [torch.float16, torch.bfloat16, torch.float32]) -@pytest.mark.parametrize("weight_dtype", [torch.float32, torch.bfloat16]) -@pytest.mark.parametrize("use_out_tensor", [False, True]) -def test_selective_state_update( - batch, - nheads, - dim, - dstate, - state_dtype, - weight_dtype, - use_out_tensor, -): - """Test 
selective_state_update correctness against reference implementation.""" - ngroups = 8 - delta_softplus = True - input_dtype = torch.bfloat16 - matrixA_dtype = torch.float32 - - inputs = create_test_inputs( - batch, - nheads, - dim, - dstate, - ngroups, - input_dtype, - weight_dtype=weight_dtype, - matrixA_dtype=matrixA_dtype, - state_dtype=state_dtype, - generate_z=False, - seed=0, - ) - - state = inputs["state_cache"] - state_ref = state.clone() - y_ref = selective_state_update_triton( - state_ref, - inputs["x"], - inputs["dt"], - inputs["A"], - inputs["B"], - inputs["C"], - D=inputs["D"], - z=inputs.get("z"), - dt_bias=inputs["dt_bias"], - dt_softplus=delta_softplus, - state_batch_indices=inputs["slot_idx"], - pad_slot_id=-1, - ) - - # Prepare output tensor if use_out_tensor is True - if use_out_tensor: - out = torch.empty(batch, nheads, dim, dtype=input_dtype, device="cuda") - else: - out = None - - y_test = flashinfer.mamba.selective_state_update( - state, - inputs["x"], - inputs["dt"], - inputs["A"], - inputs["B"], - inputs["C"], - D=inputs["D"], - z=inputs.get("z"), - dt_bias=inputs["dt_bias"], - dt_softplus=delta_softplus, - state_batch_indices=inputs["slot_idx"], - pad_slot_id=-1, - out=out, - ) - - # Verify the returned tensor is the same object as the provided output tensor - if use_out_tensor: - assert y_test.data_ptr() == out.data_ptr(), ( - "Returned tensor should be the same object as the provided output tensor" - ) - - atol = 1e-3 - rtol = 1e-2 - outputs_match = torch.allclose(y_ref, y_test, atol=atol, rtol=rtol) - - if outputs_match: - print(f"✓ Outputs match within tolerance (atol={atol}, rtol={rtol})") - else: - print(f"✗ Outputs do NOT match within tolerance (atol={atol}, rtol={rtol})") - - # Detailed comparison using numpy testing - y_ref_np = y_ref.detach().cpu().float().numpy() - y_test_np = y_test.detach().cpu().float().numpy() - print(f"dtypes: ref {y_ref_np.dtype}, test {y_test_np.dtype}") - - print("\nDetailed mismatch analysis:") - 
mismatch_mask = ~np.isclose(y_ref_np, y_test_np, atol=atol, rtol=rtol) - num_mismatches = np.sum(mismatch_mask) - total_elements = y_ref_np.size - - print( - f"Number of mismatched elements: {num_mismatches} / {total_elements} ({100 * num_mismatches / total_elements:.2f}%)" - ) - - mismatch_indices = np.argwhere(mismatch_mask) - print("First few mismatch locations (up to 10):") - for idx in mismatch_indices[:10]: - idx_tuple = tuple(int(i) for i in idx) - ref_val = y_ref_np[idx_tuple] - test_val = y_test_np[idx_tuple] - diff = abs(ref_val - test_val) - rel_diff = diff / (abs(ref_val) + 1e-8) - print( - f" Index {idx_tuple}: ref={ref_val:.6f}, test={test_val:.6f}, diff={diff:.6e}, rel_diff={rel_diff:.6e}" - ) - - assert outputs_match - - # Check if states match within tolerance - states_match = torch.allclose( - state_ref[inputs["slot_idx"]], - state[inputs["slot_idx"]], - atol=atol, - rtol=rtol, - ) - - if states_match: - print(f"✓ States match within tolerance (atol={atol}, rtol={rtol})") - else: - print(f"✗ States do NOT match within tolerance (atol={atol}, rtol={rtol})") - - # Detailed comparison using numpy testing - state_ref_np = state_ref[inputs["slot_idx"]].detach().cpu().float().numpy() - state_test_np = state[inputs["slot_idx"]].detach().cpu().float().numpy() - - print("\nDetailed state mismatch analysis:") - state_mismatch_mask = ~np.isclose( - state_ref_np, state_test_np, atol=atol, rtol=rtol - ) - num_state_mismatches = np.sum(state_mismatch_mask) - total_state_elements = state_ref_np.size - - print( - f"Number of mismatched state elements: {num_state_mismatches} / {total_state_elements} ({100 * num_state_mismatches / total_state_elements:.2f}%)" - ) - - state_mismatch_indices = np.argwhere(state_mismatch_mask) - print("First few state mismatch locations (up to 10):") - for idx in state_mismatch_indices[:10]: - idx_tuple = tuple(int(i) for i in idx) - ref_val = state_ref_np[idx_tuple] - test_val = state_test_np[idx_tuple] - diff = abs(ref_val - 
test_val) - rel_diff = diff / (abs(ref_val) + 1e-8) - print( - f" Index{idx_tuple}: ref={ref_val:.6f}, test={test_val:.6f}, diff={diff:.6e}, rel_diff={rel_diff:.6e}" - ) - - assert states_match - - -@pytest.mark.parametrize("use_out_tensor", [False, True]) -def test_selective_state_update_with_z(use_out_tensor): - """Test selective_state_update with z tensor (not None).""" - batch = 1 - nheads = 8 - dim = 64 - dstate = 128 - ngroups = 8 - delta_softplus = True - input_dtype = torch.bfloat16 - weight_dtype = torch.bfloat16 - matrixA_dtype = torch.float32 - state_dtype = torch.bfloat16 - - inputs = create_test_inputs( - batch, - nheads, - dim, - dstate, - ngroups, - input_dtype, - weight_dtype=weight_dtype, - matrixA_dtype=matrixA_dtype, - state_dtype=state_dtype, - generate_z=True, - seed=0, - ) - - state = inputs["state_cache"] - state_ref = state.clone() - y_ref = selective_state_update_triton( - state_ref, - inputs["x"], - inputs["dt"], - inputs["A"], - inputs["B"], - inputs["C"], - D=inputs["D"], - z=inputs.get("z"), - dt_bias=inputs["dt_bias"], - dt_softplus=delta_softplus, - state_batch_indices=inputs["slot_idx"], - pad_slot_id=-1, - ) - - # Prepare output tensor if use_out_tensor is True - if use_out_tensor: - out = torch.empty(batch, nheads, dim, dtype=input_dtype, device="cuda") - else: - out = None - - y_test = flashinfer.mamba.selective_state_update( - state, - inputs["x"], - inputs["dt"], - inputs["A"], - inputs["B"], - inputs["C"], - D=inputs["D"], - z=inputs.get("z"), - dt_bias=inputs["dt_bias"], - dt_softplus=delta_softplus, - state_batch_indices=inputs["slot_idx"], - pad_slot_id=-1, - out=out, - ) - - # Verify the returned tensor is the same object as the provided output tensor - if use_out_tensor: - assert y_test.data_ptr() == out.data_ptr(), ( - "Returned tensor should be the same object as the provided output tensor" - ) - - atol = 1e-3 - rtol = 1e-2 - torch.testing.assert_close(y_ref, y_test, atol=atol, rtol=rtol) - torch.testing.assert_close( - 
state_ref[inputs["slot_idx"]], - state[inputs["slot_idx"]], - atol=atol, - rtol=rtol, - ) diff --git a/tests/mamba/test_selective_state_update_stp.py b/tests/mamba/test_selective_state_update_stp.py new file mode 100644 index 0000000000..331386ef67 --- /dev/null +++ b/tests/mamba/test_selective_state_update_stp.py @@ -0,0 +1,301 @@ +import numpy as np +import pytest +import torch + +import flashinfer + +from .selective_state_update_triton import selective_state_update_triton +from .test_utils import create_test_inputs + + +def clone_preserving_strides(tensor): + """Clone a tensor while preserving its strides (non-contiguous layout).""" + result = torch.empty_strided( + tensor.size(), tensor.stride(), dtype=tensor.dtype, device=tensor.device + ) + result.copy_(tensor) + return result + + +class TestSelectiveStateUpdate: + """Test class for selective state update kernels.""" + + # Test configuration + ATOL = 1e-3 + RTOL = 1e-2 + NGROUPS = 8 + INPUT_DTYPE = torch.bfloat16 + MATRIX_A_DTYPE = torch.float32 + + @pytest.fixture(params=[1, 64]) + def batch(self, request): + return request.param + + @pytest.fixture(params=[8, 64]) + def nheads(self, request): + return request.param + + @pytest.fixture(params=[64, 128]) + def dim(self, request): + return request.param + + @pytest.fixture(params=[64, 128, 256]) + def dstate(self, request): + return request.param + + @pytest.fixture(params=[torch.float16, torch.bfloat16, torch.float32]) + def state_dtype(self, request): + return request.param + + @pytest.fixture(params=[torch.float32, torch.bfloat16]) + def weight_dtype(self, request): + return request.param + + @pytest.fixture(params=[False, True]) + def use_out_tensor(self, request): + return request.param + + @pytest.fixture + def inputs(self, batch, nheads, dim, dstate, state_dtype, weight_dtype): + """Create test inputs for given parameters.""" + return create_test_inputs( + batch, + nheads, + dim, + dstate, + self.NGROUPS, + self.INPUT_DTYPE, + weight_dtype=weight_dtype, 
+ matrixA_dtype=self.MATRIX_A_DTYPE, + state_dtype=state_dtype, + generate_z=False, + seed=0, + ) + + @pytest.fixture + def reference_output(self, inputs): + """Compute reference output using triton implementation.""" + state_ref = inputs["state_cache"].clone() + y_ref = selective_state_update_triton( + state_ref, + inputs["x"], + inputs["dt"], + inputs["A"], + inputs["B"], + inputs["C"], + D=inputs["D"], + z=inputs.get("z"), + dt_bias=inputs["dt_bias"], + dt_softplus=True, + state_batch_indices=inputs["slot_idx"], + pad_slot_id=-1, + ) + return y_ref, state_ref + + def run_kernel(self, inputs, out=None): + """Run the flashinfer kernel and return output.""" + return flashinfer.mamba.selective_state_update( + inputs["state_cache"], + inputs["x"], + inputs["dt"], + inputs["A"], + inputs["B"], + inputs["C"], + D=inputs["D"], + z=inputs.get("z"), + dt_bias=inputs["dt_bias"], + dt_softplus=True, + state_batch_indices=inputs["slot_idx"], + pad_slot_id=-1, + out=out, + ) + + def assert_outputs_match(self, y_ref, y_test, msg_prefix=""): + """Assert outputs match with detailed error reporting.""" + outputs_match = torch.allclose(y_ref, y_test, atol=self.ATOL, rtol=self.RTOL) + + if outputs_match: + print( + f"✓ {msg_prefix}Outputs match within tolerance (atol={self.ATOL}, rtol={self.RTOL})" + ) + else: + print( + f"✗ {msg_prefix}Outputs do NOT match within tolerance (atol={self.ATOL}, rtol={self.RTOL})" + ) + self._print_mismatch_details(y_ref, y_test, "output") + + assert outputs_match + + def assert_states_match(self, state_ref, state_test, slot_idx, msg_prefix=""): + """Assert states match with detailed error reporting.""" + state_ref_batch = state_ref[slot_idx] + state_test_batch = state_test[slot_idx] + states_match = torch.allclose( + state_ref_batch, state_test_batch, atol=self.ATOL, rtol=self.RTOL + ) + + if states_match: + print( + f"✓ {msg_prefix}States match within tolerance (atol={self.ATOL}, rtol={self.RTOL})" + ) + else: + print( + f"✗ {msg_prefix}States do 
NOT match within tolerance (atol={self.ATOL}, rtol={self.RTOL})" + ) + self._print_mismatch_details(state_ref_batch, state_test_batch, "state") + + assert states_match + + def _print_mismatch_details(self, ref, test, name): + """Print detailed mismatch analysis.""" + ref_np = ref.detach().cpu().float().numpy() + test_np = test.detach().cpu().float().numpy() + + mismatch_mask = ~np.isclose(ref_np, test_np, atol=self.ATOL, rtol=self.RTOL) + num_mismatches = np.sum(mismatch_mask) + total_elements = ref_np.size + + print(f"\nDetailed {name} mismatch analysis:") + print( + f"Number of mismatched elements: {num_mismatches} / {total_elements} " + f"({100 * num_mismatches / total_elements:.2f}%)" + ) + + mismatch_indices = np.argwhere(mismatch_mask) + print(f"First few {name} mismatch locations (up to 10):") + for idx in mismatch_indices[:10]: + idx_tuple = tuple(int(i) for i in idx) + ref_val = ref_np[idx_tuple] + test_val = test_np[idx_tuple] + diff = abs(ref_val - test_val) + rel_diff = diff / (abs(ref_val) + 1e-8) + print( + f" Index {idx_tuple}: ref={ref_val:.6f}, test={test_val:.6f}, " + f"diff={diff:.6e}, rel_diff={rel_diff:.6e}" + ) + + def test_output_correctness(self, inputs, reference_output, use_out_tensor): + """Test that kernel output matches reference within tolerance.""" + y_ref, state_ref = reference_output + + # Prepare output tensor if requested + if use_out_tensor: + batch = inputs["x"].shape[0] + nheads = inputs["x"].shape[1] + dim = inputs["x"].shape[2] + out = torch.empty(batch, nheads, dim, dtype=self.INPUT_DTYPE, device="cuda") + else: + out = None + + y_test = self.run_kernel(inputs, out=out) + + # Verify output tensor identity if provided + if use_out_tensor: + assert y_test.data_ptr() == out.data_ptr(), ( + "Returned tensor should be the same object as the provided output tensor" + ) + + self.assert_outputs_match(y_ref, y_test) + self.assert_states_match(state_ref, inputs["state_cache"], inputs["slot_idx"]) + + +class 
TestSelectiveStateUpdateWithZ(TestSelectiveStateUpdate): + """Test selective_state_update with z tensor (gating).""" + + @pytest.fixture(params=[1]) + def batch(self, request): + return request.param + + @pytest.fixture(params=[8]) + def nheads(self, request): + return request.param + + @pytest.fixture(params=[64]) + def dim(self, request): + return request.param + + @pytest.fixture(params=[128]) + def dstate(self, request): + return request.param + + @pytest.fixture(params=[torch.bfloat16]) + def state_dtype(self, request): + return request.param + + @pytest.fixture(params=[torch.bfloat16]) + def weight_dtype(self, request): + return request.param + + @pytest.fixture + def inputs(self, batch, nheads, dim, dstate, state_dtype, weight_dtype): + """Create test inputs with z tensor.""" + return create_test_inputs( + batch, + nheads, + dim, + dstate, + self.NGROUPS, + self.INPUT_DTYPE, + weight_dtype=weight_dtype, + matrixA_dtype=self.MATRIX_A_DTYPE, + state_dtype=state_dtype, + generate_z=True, + seed=0, + ) + + +@pytest.mark.xfail(reason="Non-contiguous state cache not yet supported") +class TestSelectiveStateUpdateNonContiguous(TestSelectiveStateUpdate): + """Test selective_state_update with non-contiguous state cache.""" + + @pytest.fixture(params=[64, 128]) + def dstate(self, request): + return request.param + + @pytest.fixture(params=[torch.bfloat16, torch.float32]) + def state_dtype(self, request): + return request.param + + @pytest.fixture(params=[torch.float32]) + def weight_dtype(self, request): + return request.param + + @pytest.fixture + def inputs(self, batch, nheads, dim, dstate, state_dtype, weight_dtype): + """Create test inputs with non-contiguous state cache (2x batch stride).""" + noncontiguous_batch_stride = 2 * nheads * dim * dstate + + return create_test_inputs( + batch, + nheads, + dim, + dstate, + self.NGROUPS, + self.INPUT_DTYPE, + weight_dtype=weight_dtype, + matrixA_dtype=self.MATRIX_A_DTYPE, + state_dtype=state_dtype, + generate_z=False, + 
state_cache_batch_stride=noncontiguous_batch_stride, + seed=0, + ) + + @pytest.fixture + def reference_output(self, inputs): + """Compute reference output, preserving non-contiguous strides.""" + state_ref = clone_preserving_strides(inputs["state_cache"]) + y_ref = selective_state_update_triton( + state_ref, + inputs["x"], + inputs["dt"], + inputs["A"], + inputs["B"], + inputs["C"], + D=inputs["D"], + z=inputs.get("z"), + dt_bias=inputs["dt_bias"], + dt_softplus=True, + state_batch_indices=inputs["slot_idx"], + pad_slot_id=-1, + ) + return y_ref, state_ref From a8bc286f74d3075ea41ccc83dfa6871315f98a63 Mon Sep 17 00:00:00 2001 From: Igor Shovkun Date: Wed, 28 Jan 2026 22:40:59 -0800 Subject: [PATCH 04/33] Add multi-token support to the interface of selective_state_update --- csrc/flashinfer_mamba_binding.cu | 27 +++++++-- csrc/selective_state_update.cu | 46 +++++++++------ flashinfer/mamba/selective_state_update.py | 67 +++++++++++++++++----- 3 files changed, 105 insertions(+), 35 deletions(-) diff --git a/csrc/flashinfer_mamba_binding.cu b/csrc/flashinfer_mamba_binding.cu index 347fd0fb4f..dfdc5bebf8 100644 --- a/csrc/flashinfer_mamba_binding.cu +++ b/csrc/flashinfer_mamba_binding.cu @@ -20,10 +20,29 @@ using tvm::ffi::Optional; namespace flashinfer::mamba { -void selective_state_update(TensorView state, TensorView x, TensorView dt, TensorView output, - TensorView A, TensorView B, TensorView C, TensorView D, - Optional z, Optional dt_bias, bool dt_softplus, - Optional state_batch_indices, int64_t pad_slot_id); +void selective_state_update( + TensorView state, // (batch, dim, dstate) or (batch, nheads, dim, dstate) + TensorView x, // (batch, dim) or (batch, nheads, dim) for single-token + // or (batch, T, nheads, dim) for multi-token + TensorView dt, // (batch, dim) or (batch, nheads, dim) for single-token + // or (batch, T, nheads, dim) for multi-token + TensorView A, // (dim, dstate) or (nheads, dim, dstate) + TensorView B, // (batch, dstate) or (batch, ngroups, 
dstate) for single-token + // or (batch, T, ngroups, dstate) for multi-token + TensorView C, // (batch, dstate) or (batch, ngroups, dstate) for single-token + // or (batch, T, ngroups, dstate) for multi-token + TensorView D, // (dim,) or (nheads, dim) + Optional z, // (batch, dim) or (batch, nheads, dim) for single-token + // or (batch, T, nheads, dim) for multi-token + Optional dt_bias, // (dim,) or (nheads, dim) + bool dt_softplus, + Optional state_batch_indices, // (batch,) + int64_t pad_slot_id, + TensorView output, // same as x + bool disable_state_update, + Optional intermediate_states_buffer, // (batch, cache_steps, nheads, dim, dstate) + Optional intermediate_state_indices, // (batch,) + int64_t cache_steps); } // namespace flashinfer::mamba diff --git a/csrc/selective_state_update.cu b/csrc/selective_state_update.cu index f2dbd121c5..021be04b05 100644 --- a/csrc/selective_state_update.cu +++ b/csrc/selective_state_update.cu @@ -22,10 +22,25 @@ using tvm::ffi::Optional; namespace flashinfer::mamba { -void selective_state_update(TensorView state, TensorView x, TensorView dt, TensorView output, - TensorView A, TensorView B, TensorView C, TensorView D, - Optional z, Optional dt_bias, bool dt_softplus, - Optional state_batch_indices, int64_t pad_slot_id) { +// New function signature with multi-token support +void selective_state_update(TensorView state, TensorView x, TensorView dt, TensorView A, + TensorView B, TensorView C, TensorView D, Optional z, + Optional dt_bias, bool dt_softplus, + Optional state_batch_indices, int64_t pad_slot_id, + TensorView output, bool disable_state_update, + Optional intermediate_states_buffer, + Optional intermediate_state_indices, int64_t cache_steps) { + // // TODO: Implement multi-token support + // FLASHINFER_CHECK(false, "selective_state_update with new signature not yet implemented"); + // } + + // Old function - commented out for reference + // void selective_state_update_old(TensorView state, TensorView x, TensorView dt, 
TensorView + // output, + // TensorView A, TensorView B, TensorView C, TensorView D, + // Optional z, Optional dt_bias, bool + // dt_softplus, Optional state_batch_indices, int64_t + // pad_slot_id) { auto const batch = x.size(0); auto const state_cache_size = state.size(0); auto const nheads = state.size(1); @@ -187,43 +202,42 @@ void selective_state_update(TensorView state, TensorView x, TensorView dt, Tenso auto dtype_key = std::make_tuple(state_dtype_code, input_dtype_code, weight_dtype_code, matrixA_dtype_code); - if (dtype_key == std::make_tuple(/*state*/ bfloat16_code, /*input */ bfloat16_code, - /*weight */ bfloat16_code, /*matrixA */ float32_code)) { + if (dtype_key == std::make_tuple(bfloat16_code, bfloat16_code, bfloat16_code, float32_code)) { using state_t = nv_bfloat16; using input_t = nv_bfloat16; using weight_t = nv_bfloat16; using matrixA_t = float; invokeSelectiveStateUpdate(p, stream); - } else if (dtype_key == std::make_tuple(/*state*/ float16_code, /*input */ bfloat16_code, - /*weight */ bfloat16_code, /*matrixA */ float32_code)) { + } else if (dtype_key == + std::make_tuple(float16_code, bfloat16_code, bfloat16_code, float32_code)) { using state_t = half; using input_t = nv_bfloat16; using weight_t = nv_bfloat16; using matrixA_t = float; invokeSelectiveStateUpdate(p, stream); - } else if (dtype_key == std::make_tuple(/*state*/ float32_code, /*input */ bfloat16_code, - /*weight */ bfloat16_code, /*matrixA */ float32_code)) { + } else if (dtype_key == + std::make_tuple(float32_code, bfloat16_code, bfloat16_code, float32_code)) { using state_t = float; using input_t = nv_bfloat16; using weight_t = nv_bfloat16; using matrixA_t = float; invokeSelectiveStateUpdate(p, stream); - } else if (dtype_key == std::make_tuple(/*state*/ bfloat16_code, /*input */ bfloat16_code, - /*weight */ float32_code, /*matrixA */ float32_code)) { + } else if (dtype_key == + std::make_tuple(bfloat16_code, bfloat16_code, float32_code, float32_code)) { using state_t = 
nv_bfloat16; using input_t = nv_bfloat16; using weight_t = float; using matrixA_t = float; invokeSelectiveStateUpdate(p, stream); - } else if (dtype_key == std::make_tuple(/*state*/ float16_code, /*input */ bfloat16_code, - /*weight */ float32_code, /*matrixA */ float32_code)) { + } else if (dtype_key == + std::make_tuple(float16_code, bfloat16_code, float32_code, float32_code)) { using state_t = half; using input_t = nv_bfloat16; using weight_t = float; using matrixA_t = float; invokeSelectiveStateUpdate(p, stream); - } else if (dtype_key == std::make_tuple(/*state*/ float32_code, /*input */ bfloat16_code, - /*weight */ float32_code, /*matrixA */ float32_code)) { + } else if (dtype_key == + std::make_tuple(float32_code, bfloat16_code, float32_code, float32_code)) { using state_t = float; using input_t = nv_bfloat16; using weight_t = float; diff --git a/flashinfer/mamba/selective_state_update.py b/flashinfer/mamba/selective_state_update.py index 4971b7eb8c..07ef073b33 100644 --- a/flashinfer/mamba/selective_state_update.py +++ b/flashinfer/mamba/selective_state_update.py @@ -73,7 +73,11 @@ def selective_state_update( dt_softplus: bool = False, state_batch_indices: Optional[torch.Tensor] = None, pad_slot_id: int = -1, - out: torch.Tensor | None = None, + out: Optional[torch.Tensor] = None, + disable_state_update: bool = False, + intermediate_states_buffer: Optional[torch.Tensor] = None, + intermediate_state_indices: Optional[torch.Tensor] = None, + cache_steps: int = 0, ) -> torch.Tensor: r"""Selective state update operation for Mamba layers (the generation phase). 
@@ -82,36 +86,52 @@ def selective_state_update( state : torch.Tensor State tensor with shape (state_cache_size, dim, dstate) or (state_cache_size, nheads, dim, dstate) x : torch.Tensor - Input tensor with shape (batch, dim) or (batch, nheads, dim) + Input tensor with shape (batch, dim) or (batch, nheads, dim) for single-token + or (batch, T, nheads, dim) for multi-token dt : torch.Tensor - Delta time tensor with shape (batch, dim) or (batch, nheads, dim) + Delta time tensor with shape (batch, dim) or (batch, nheads, dim) for single-token + or (batch, T, nheads, dim) for multi-token A : torch.Tensor A matrix with shape (dim, dstate) or (nheads, dim, dstate) B : torch.Tensor - B matrix with shape (batch, dstate) or (batch, ngroups, dstate) + B matrix with shape (batch, dstate) or (batch, ngroups, dstate) for single-token + or (batch, T, ngroups, dstate) for multi-token C : torch.Tensor - C matrix with shape (batch, dstate) or (batch, ngroups, dstate) + C matrix with shape (batch, dstate) or (batch, ngroups, dstate) for single-token + or (batch, T, ngroups, dstate) for multi-token D : torch.Tensor D vector with shape (dim,) or (nheads, dim) z : Optional[torch.Tensor] - Optional z tensor with shape (batch, dim) or (batch, nheads, dim) + Optional z tensor with shape (batch, dim) or (batch, nheads, dim) for single-token + or (batch, T, nheads, dim) for multi-token dt_bias : Optional[torch.Tensor] Optional dt bias with shape (dim,) or (nheads, dim) dt_softplus : bool Whether to apply softplus to dt state_batch_indices : Optional[torch.Tensor] - Optional batch indices for cache processing + Optional batch indices for cache processing with shape (batch,) pad_slot_id : int If state_batch_indices is passed, lets the kernel identify padded entries that will not be processed. 
For example: state_batch_indices = [pad_slot_id, 1, 20, pad_slot_id] in this case, the kernel will not process entries at indices 0 and 3 - out : torch.Tensor | None - Optional output tensor + out : Optional[torch.Tensor] + Optional output tensor (same shape as x) + disable_state_update : bool + If True, skip updating the state tensor (useful for speculative decoding verification) + intermediate_states_buffer : Optional[torch.Tensor] + Optional buffer for caching intermediate states during speculative decoding + with shape (batch, cache_steps, nheads, dim, dstate) + intermediate_state_indices : Optional[torch.Tensor] + Optional indices mapping batch elements to intermediate state buffer positions + with shape (batch,) + cache_steps : int + Number of steps/tokens to cache for speculative decoding Returns ------- output : torch.Tensor - Output tensor with shape (batch, dim) or (batch, nheads, dim) + Output tensor with shape (batch, dim) or (batch, nheads, dim) for single-token + or (batch, T, nheads, dim) for multi-token """ if state.dim() == 3: state = state.unsqueeze(1) @@ -139,7 +159,6 @@ def selective_state_update( state, x, dt, - output, A, B, C, @@ -149,18 +168,23 @@ def selective_state_update( dt_softplus, state_batch_indices, pad_slot_id, + output, + disable_state_update, + intermediate_states_buffer, + intermediate_state_indices, + cache_steps, ) return output @register_custom_op( - "flashinfer::selective_state_update", mutates_args=("state", "output") + "flashinfer::selective_state_update", + mutates_args=("state", "output", "intermediate_states_buffer"), ) def _selective_state_update( state: torch.Tensor, x: torch.Tensor, dt: torch.Tensor, - output: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, @@ -170,13 +194,17 @@ def _selective_state_update( dt_softplus: bool, state_batch_indices: Optional[torch.Tensor], pad_slot_id: int, + output: torch.Tensor, + disable_state_update: bool, + intermediate_states_buffer: Optional[torch.Tensor], + 
intermediate_state_indices: Optional[torch.Tensor], + cache_steps: int, ) -> None: """Internal function registered with torch.library for torch.compile() support.""" get_selective_state_update_module(state.device).selective_state_update( state, x, dt, - output, A, B, C, @@ -186,6 +214,11 @@ def _selective_state_update( dt_softplus, state_batch_indices, pad_slot_id, + output, + disable_state_update, + intermediate_states_buffer, + intermediate_state_indices, + cache_steps, ) @@ -194,7 +227,6 @@ def _selective_state_update_fake( state: torch.Tensor, x: torch.Tensor, dt: torch.Tensor, - output: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, @@ -204,6 +236,11 @@ def _selective_state_update_fake( dt_softplus: bool, state_batch_indices: Optional[torch.Tensor], pad_slot_id: int, + output: torch.Tensor, + disable_state_update: bool, + intermediate_states_buffer: Optional[torch.Tensor], + intermediate_state_indices: Optional[torch.Tensor], + cache_steps: int, ) -> None: """Fake implementation for torch.compile() meta tensor propagation.""" pass From 2e70ea4de793523b751be7feaeb4214bd9726969 Mon Sep 17 00:00:00 2001 From: Igor Shovkun Date: Thu, 29 Jan 2026 09:55:37 -0800 Subject: [PATCH 05/33] Refactor selective_state_update: add validation helpers and update param struct - Add helper functions for tensor validation and dtype checks - Move output tensor to Optional and update checks accordingly - Add state_stride_batch and update_state fields to SelectiveStateUpdateParams - Refactor kernel param usage for clarity and consistency --- csrc/selective_state_update.cu | 531 ++++++++++++++---- .../mamba/selective_state_update.cuh | 255 ++++----- 2 files changed, 564 insertions(+), 222 deletions(-) diff --git a/csrc/selective_state_update.cu b/csrc/selective_state_update.cu index 021be04b05..bd7d96f74a 100644 --- a/csrc/selective_state_update.cu +++ b/csrc/selective_state_update.cu @@ -22,25 +22,117 @@ using tvm::ffi::Optional; namespace flashinfer::mamba { -// 
New function signature with multi-token support -void selective_state_update(TensorView state, TensorView x, TensorView dt, TensorView A, - TensorView B, TensorView C, TensorView D, Optional z, - Optional dt_bias, bool dt_softplus, - Optional state_batch_indices, int64_t pad_slot_id, - TensorView output, bool disable_state_update, - Optional intermediate_states_buffer, - Optional intermediate_state_indices, int64_t cache_steps) { - // // TODO: Implement multi-token support - // FLASHINFER_CHECK(false, "selective_state_update with new signature not yet implemented"); - // } - - // Old function - commented out for reference - // void selective_state_update_old(TensorView state, TensorView x, TensorView dt, TensorView - // output, - // TensorView A, TensorView B, TensorView C, TensorView D, - // Optional z, Optional dt_bias, bool - // dt_softplus, Optional state_batch_indices, int64_t - // pad_slot_id) { +static inline void validate_state_tensor(TensorView const& state) { + CHECK_CUDA(state); + CHECK_DIM(4, state); // state: {state_cache_size, nheads, dim, dstate} + // Check that dimensions 1, 2, 3 are contiguous (batch dimension can be non-contiguous) + auto strides = state.strides(); + auto sizes = state.sizes(); + FLASHINFER_CHECK(strides[3] == 1, "state dimension 3 (dstate) must have stride 1"); + FLASHINFER_CHECK(strides[2] == sizes[3], + "state dimension 2 (dim) must be contiguous with dimension 3"); + FLASHINFER_CHECK(strides[1] == sizes[2] * sizes[3], + "state dimension 1 (nheads) must be contiguous with dimension 2"); +} + +inline void validate_D_tensor(TensorView const& D, int64_t nheads, int64_t dim) { + CHECK_CUDA(D); + CHECK_DIM(2, D); // D: {nheads, dim} + FLASHINFER_CHECK(D.size(0) == nheads, "D.size(0) must equal nheads"); + FLASHINFER_CHECK(D.size(1) == dim, "D.size(1) must equal dim"); + FLASHINFER_CHECK(D.stride(0) == 1, "D.stride(0) must be 1, got ", D.stride(0)); + FLASHINFER_CHECK(D.stride(1) == 0, "D.stride(1) must be 0 (broadcasted), got ", 
D.stride(1)); +} + +inline void validate_A_tensor(TensorView const& A, int64_t nheads, int64_t dim, int64_t dstate) { + CHECK_CUDA(A); + CHECK_DIM(3, A); // A: {nheads, dim, dstate} + FLASHINFER_CHECK(A.size(0) == nheads, "A.size(0) must equal nheads"); + FLASHINFER_CHECK(A.size(1) == dim, "A.size(1) must equal dim"); + FLASHINFER_CHECK(A.size(2) == dstate, "A.size(2) must equal dstate"); + FLASHINFER_CHECK(A.stride(0) == 1, "A.stride(0) must be 1, got ", A.stride(0)); + FLASHINFER_CHECK(A.stride(1) == 0, "A.stride(1) must be 0 (broadcasted), got ", A.stride(1)); + FLASHINFER_CHECK(A.stride(2) == 0, "A.stride(2) must be 0 (broadcasted), got ", A.stride(2)); +} + +inline void validate_dt_bias_tensor(Optional const& dt_bias, int64_t nheads, + int64_t dim) { + if (!dt_bias.has_value()) return; + auto const& bias = dt_bias.value(); + CHECK_CUDA(bias); + CHECK_DIM(2, bias); // dt_bias: {nheads, dim} + FLASHINFER_CHECK(bias.size(0) == nheads, "dt_bias.size(0) must equal nheads"); + FLASHINFER_CHECK(bias.size(1) == dim, "dt_bias.size(1) must equal dim"); + FLASHINFER_CHECK(bias.stride(0) == 1, "dt_bias.stride(0) must be 1, got ", bias.stride(0)); + FLASHINFER_CHECK(bias.stride(1) == 0, "dt_bias.stride(1) must be 0 (broadcasted), got ", + bias.stride(1)); +} + +inline void validate_state_batch_indices(Optional const& state_batch_indices, + int64_t batch) { + if (!state_batch_indices.has_value()) return; + CHECK_DIM(1, (*state_batch_indices)); + CHECK_CONTIGUOUS((*state_batch_indices)); + FLASHINFER_CHECK(state_batch_indices.value().size(0) == batch, + "state_batch_indices.shape must be (", batch, ")"); + CHECK_INPUT_TYPE(state_batch_indices.value(), dl_int32); +} + +inline void validate_intermediate_state_indices( + Optional const& intermediate_state_indices, int64_t batch) { + if (!intermediate_state_indices.has_value()) return; + CHECK_CUDA(intermediate_state_indices.value()); + CHECK_DIM(1, intermediate_state_indices.value()); + 
CHECK_CONTIGUOUS(intermediate_state_indices.value()); + FLASHINFER_CHECK(intermediate_state_indices.value().size(0) == batch, + "intermediate_state_indices.shape must be (", batch, ")"); + CHECK_INPUT_TYPE(intermediate_state_indices.value(), dl_int32); +} + +inline void validate_intermediate_states_buffer( + Optional const& intermediate_states_buffer) { + if (!intermediate_states_buffer.has_value()) return; + CHECK_CUDA(intermediate_states_buffer.value()); + CHECK_CONTIGUOUS(intermediate_states_buffer.value()); +} + +// Validates dtype consistency across tensors +inline void validate_dtypes( + TensorView const& state, TensorView const& dt, TensorView const& D, TensorView const& x, + TensorView const& B, TensorView const& C, Optional const& dt_bias, + Optional const& z = Optional(), + Optional const& out = Optional(), + Optional const& intermediate_states_buffer = Optional()) { + auto state_dtype = state.dtype(); + auto weight_dtype = dt.dtype(); + auto input_dtype = x.dtype(); + FLASHINFER_CHECK(D.dtype() == weight_dtype, "D must have the same dtype as dt"); + FLASHINFER_CHECK(B.dtype() == input_dtype, "B must have the same dtype as x"); + FLASHINFER_CHECK(C.dtype() == input_dtype, "C must have the same dtype as x"); + if (dt_bias.has_value()) { + FLASHINFER_CHECK(dt_bias.value().dtype() == weight_dtype, + "dt_bias must have the same dtype as dt"); + } + if (z.has_value()) { + FLASHINFER_CHECK(z.value().dtype() == input_dtype, "z must have the same dtype as x"); + } + if (out.has_value()) { + FLASHINFER_CHECK(out.value().dtype() == input_dtype, "out must have the same dtype as x"); + } + if (intermediate_states_buffer.has_value()) { + FLASHINFER_CHECK(intermediate_states_buffer.value().dtype() == state_dtype, + "intermediate_states_buffer must have the same dtype as state"); + } +} + +void run_selective_state_update_stp(TensorView const& state, TensorView const& x, + TensorView const& dt, TensorView const& A, TensorView const& B, + TensorView const& C, TensorView 
const& D, + Optional z, Optional dt_bias, + bool dt_softplus, Optional state_batch_indices, + int64_t pad_slot_id, Optional out, + bool disable_state_update) { + // Extract dimensions from input tensors auto const batch = x.size(0); auto const state_cache_size = state.size(0); auto const nheads = state.size(1); @@ -49,85 +141,54 @@ void selective_state_update(TensorView state, TensorView x, TensorView dt, Tenso auto const ngroups = B.size(1); FLASHINFER_CHECK(state_cache_size >= batch, "state.size(0) must be >= x.size(0)"); - FLASHINFER_CHECK(nheads % ngroups == 0, "nheads must be divisible by ngroups"); // Check x shape and strides + CHECK_CUDA(x); CHECK_DIM(3, x); FLASHINFER_CHECK(x.size(1) == nheads, "x.size(1) must equal nheads"); FLASHINFER_CHECK(x.size(2) == dim, "x.size(2) must equal dim"); - CHECK_LAST_DIM_CONTIGUOUS_INPUT(x); + CHECK_LAST_DIM_CONTIGUOUS(x); FLASHINFER_CHECK(x.stride(1) == dim, "x.stride(1) must equal dim, got ", x.stride(1), - " expected ", x.size(2)); - - // Check output shape and strides - CHECK_DIM(3, output); - CHECK_LAST_DIM_CONTIGUOUS(output); - FLASHINFER_CHECK(output.size(1) == nheads, "output.size(1) must equal nheads"); - FLASHINFER_CHECK(output.size(2) == dim, "output.size(2) must equal dim"); - FLASHINFER_CHECK(output.stride(1) == dim, "output.stride(1) must equal dim"); + " expected ", dim); // Check dt shape and strides CHECK_CUDA(dt); CHECK_DIM(3, dt); // dt: {batch, nheads, dim} + FLASHINFER_CHECK(dt.size(0) == batch, "dt.size(0) must equal batch"); FLASHINFER_CHECK(dt.size(1) == nheads, "dt.size(1) must equal nheads"); FLASHINFER_CHECK(dt.size(2) == dim, "dt.size(2) must equal dim"); FLASHINFER_CHECK(dt.stride(1) == 1, "dt.stride(1) must be 1, got ", dt.stride(1)); FLASHINFER_CHECK(dt.stride(2) == 0, "dt.stride(2) must be 0 (broadcasted), got ", dt.stride(2)); - // Check state - fully contiguous - CHECK_INPUT(state); // CUDA + fully contiguous (uses TVM FFI) - CHECK_DIM(4, state); // state: {state_cache_size, nheads, dim, 
dstate} + // Validate common tensors using helper functions + validate_state_tensor(state); + validate_D_tensor(D, nheads, dim); + validate_A_tensor(A, nheads, dim, dstate); + validate_dt_bias_tensor(dt_bias, nheads, dim); + validate_state_batch_indices(state_batch_indices, batch); // Check B shape and strides CHECK_CUDA(B); - CHECK_DIM(3, B); // B: {batch, B.size(1), dstate} + CHECK_DIM(3, B); // B: {batch, ngroups, dstate} FLASHINFER_CHECK(B.size(0) == batch, "B.size(0) must equal batch"); FLASHINFER_CHECK(B.size(1) == ngroups, "B.size(1) must equal ngroups"); FLASHINFER_CHECK(B.size(2) == dstate, "B.size(2) must equal dstate"); - CHECK_LAST_DIM_CONTIGUOUS(B); // stride(2) == 1 - FLASHINFER_CHECK(B.stride(1) == B.size(2), "B.stride(1) must equal dstate, got ", B.stride(1), - " expected ", B.size(2)); + CHECK_LAST_DIM_CONTIGUOUS(B); + FLASHINFER_CHECK(B.stride(1) == dstate, "B.stride(1) must equal dstate, got ", B.stride(1), + " expected ", dstate); // Check C shape and strides CHECK_CUDA(C); - CHECK_LAST_DIM_CONTIGUOUS(C); // stride(2) == 1 - CHECK_DIM(3, C); // C: {batch, C.size(1), dstate} - FLASHINFER_CHECK(C.stride(1) == C.size(2), "C.stride(1) must equal dstate, got ", C.stride(1), - " expected ", C.size(2)); + CHECK_DIM(3, C); // C: {batch, ngroups, dstate} FLASHINFER_CHECK(C.size(0) == batch, "C.size(0) must equal batch"); FLASHINFER_CHECK(C.size(1) == ngroups, "C.size(1) must equal ngroups"); FLASHINFER_CHECK(C.size(2) == dstate, "C.size(2) must equal dstate"); + CHECK_LAST_DIM_CONTIGUOUS(C); + FLASHINFER_CHECK(C.stride(1) == dstate, "C.stride(1) must equal dstate, got ", C.stride(1), + " expected ", dstate); - // Check D - specific stride patterns indicating broadcasting - CHECK_CUDA(D); - CHECK_DIM(2, D); // D: {nheads, dim} - FLASHINFER_CHECK(D.size(0) == nheads, "D.size(0) must equal nheads"); - FLASHINFER_CHECK(D.size(1) == dim, "D.size(1) must equal dim"); - FLASHINFER_CHECK(D.stride(0) == 1, "D.stride(0) must be 1, got ", D.stride(0)); - 
FLASHINFER_CHECK(D.stride(1) == 0, "D.stride(1) must be 0 (broadcasted), got ", D.stride(1)); - - // Check A - specific stride patterns indicating broadcasting - CHECK_CUDA(A); - CHECK_DIM(3, A); // A: {nheads, dim, dstate} - FLASHINFER_CHECK(A.size(0) == nheads, "A.size(0) must equal nheads"); - FLASHINFER_CHECK(A.size(1) == dim, "A.size(1) must equal dim"); - FLASHINFER_CHECK(A.size(2) == dstate, "A.size(2) must equal dstate"); - FLASHINFER_CHECK(A.stride(1) == 0, "A.stride(1) must be 0 (broadcasted), got ", A.stride(1)); - FLASHINFER_CHECK(A.stride(2) == 0, "A.stride(2) must be 0 (broadcasted), got ", A.stride(2)); - - // Optional dt_bias check - if (dt_bias.has_value()) { - auto& bias = dt_bias.value(); - CHECK_CUDA(bias); - CHECK_DIM(2, bias); // dt_bias: {nheads, dim} - FLASHINFER_CHECK(bias.size(0) == nheads, "dt_bias.size(0) must equal nheads"); - FLASHINFER_CHECK(bias.size(1) == dim, "dt_bias.size(1) must equal dim"); - FLASHINFER_CHECK(bias.stride(0) == 1, "dt_bias.stride(0) must be 1, got ", bias.stride(0)); - FLASHINFER_CHECK(bias.stride(1) == 0, "dt_bias.stride(1) must be 0 (broadcasted), got ", - bias.stride(1)); - } - + // Optional z check if (z.has_value()) { auto& z_tensor = z.value(); CHECK_CUDA(z_tensor); @@ -135,20 +196,31 @@ void selective_state_update(TensorView state, TensorView x, TensorView dt, Tenso FLASHINFER_CHECK(z_tensor.size(0) == batch, "z.size(0) must equal batch"); FLASHINFER_CHECK(z_tensor.size(1) == nheads, "z.size(1) must equal nheads"); FLASHINFER_CHECK(z_tensor.size(2) == dim, "z.size(2) must equal dim"); - CHECK_LAST_DIM_CONTIGUOUS_INPUT(z_tensor); + CHECK_LAST_DIM_CONTIGUOUS(z_tensor); FLASHINFER_CHECK(z_tensor.stride(1) == dim, "z.stride(1) must equal dim, got ", - z_tensor.stride(1), " expected ", z_tensor.size(2)); + z_tensor.stride(1), " expected ", dim); } - if (state_batch_indices) { - CHECK_DIM(1, (*state_batch_indices)); - FLASHINFER_CHECK(state_batch_indices.value().size(0) == batch, - "state_batch_indices.shape 
must be (", batch, ")"); + // Check output tensor if provided + if (out.has_value()) { + auto& output = out.value(); + CHECK_CUDA(output); + CHECK_CONTIGUOUS(output); + CHECK_DIM(3, output); + FLASHINFER_CHECK(output.size(0) == batch, "out.size(0) must equal batch"); + FLASHINFER_CHECK(output.size(1) == nheads, "out.size(1) must equal nheads"); + FLASHINFER_CHECK(output.size(2) == dim, "out.size(2) must equal dim"); + CHECK_LAST_DIM_CONTIGUOUS(output); + FLASHINFER_CHECK(output.stride(1) == dim, "out.stride(1) must equal dim"); } + // Validate dtype consistency + validate_dtypes(state, dt, D, x, B, C, dt_bias, z, out); + + // Initialize params struct SelectiveStateUpdateParams p; - // copy dimensions + // Copy dimensions p.batch = batch; p.nheads = nheads; p.dim = dim; @@ -157,31 +229,43 @@ void selective_state_update(TensorView state, TensorView x, TensorView dt, Tenso p.state_cache_size = state_cache_size; p.dt_softplus = dt_softplus; p.pad_slot_id = pad_slot_id; + p.update_state = !disable_state_update; // Copy strides p.x_stride_batch = x.stride(0); p.dt_stride_batch = dt.stride(0); p.B_stride_batch = B.stride(0); p.C_stride_batch = C.stride(0); - p.out_stride_batch = output.stride(0); - if (state_batch_indices) p.state_batch_indices = state_batch_indices.value().data_ptr(); + if (out.has_value()) { + p.out_stride_batch = out.value().stride(0); + } else { + p.out_stride_batch = 0; + } + p.state_stride_batch = state.stride(0); + if (state_batch_indices.has_value()) { + p.state_batch_indices = const_cast(state_batch_indices.value().data_ptr()); + } // Copy pointers - p.state = state.data_ptr(); - p.x = x.data_ptr(); - p.dt = dt.data_ptr(); - p.output = output.data_ptr(); - if (dt_bias) { - p.dt_bias = dt_bias.value().data_ptr(); + p.state = const_cast(state.data_ptr()); + p.x = const_cast(x.data_ptr()); + p.dt = const_cast(dt.data_ptr()); + if (out.has_value()) { + p.output = out.value().data_ptr(); + } else { + p.output = nullptr; + } + if (dt_bias.has_value()) 
{ + p.dt_bias = const_cast(dt_bias.value().data_ptr()); } - if (z) { - p.z = z.value().data_ptr(); + if (z.has_value()) { + p.z = const_cast(z.value().data_ptr()); p.z_stride_batch = z.value().stride(0); } - p.A = A.data_ptr(); - p.B = B.data_ptr(); - p.C = C.data_ptr(); - p.D = D.data_ptr(); + p.A = const_cast(A.data_ptr()); + p.B = const_cast(B.data_ptr()); + p.C = const_cast(C.data_ptr()); + p.D = const_cast(D.data_ptr()); // Set device and get stream ffi::CUDADeviceGuard device_guard(state.device().device_id); @@ -198,7 +282,7 @@ void selective_state_update(TensorView state, TensorView x, TensorView dt, Tenso int64_t weight_dtype_code = encode_dlpack_dtype(weight_dtype); int64_t matrixA_dtype_code = encode_dlpack_dtype(matrixA_dtype); - // Pack all dtype codes into a single value for switching + // Dispatch kernel based on dtype combination auto dtype_key = std::make_tuple(state_dtype_code, input_dtype_code, weight_dtype_code, matrixA_dtype_code); @@ -246,12 +330,11 @@ void selective_state_update(TensorView state, TensorView x, TensorView dt, Tenso } else { // Default case: unsupported dtype combination TVM_FFI_ICHECK(false) - << "Unsupported dtype combination for selective_state_update: " - << "state_dtype=" << state_dtype.code << ":" << state_dtype.bits << ", " - << "input_dtype=" << input_dtype.code << ":" << input_dtype.bits << ", " - << "weight_dtype=" << weight_dtype.code << ":" << weight_dtype.bits << ", " - << "matrixA_dtype=" << matrixA_dtype.code << ":" << matrixA_dtype.bits - << ". Supported combos include:\n" + << "Unsupported dtype combination for selective_state_update: " << "state_dtype=" + << state_dtype.code << ":" << state_dtype.bits << ", " << "input_dtype=" << input_dtype.code + << ":" << input_dtype.bits << ", " << "weight_dtype=" << weight_dtype.code << ":" + << weight_dtype.bits << ", " << "matrixA_dtype=" << matrixA_dtype.code << ":" + << matrixA_dtype.bits << ". 
Supported combos include:\n" << " (state=bfloat16, input=bfloat16, weight=bfloat16, matrixA=float32)\n" << " (state=float16, input=bfloat16, weight=bfloat16, matrixA=float32)\n" << " (state=float32, input=bfloat16, weight=bfloat16, matrixA=float32)\n" @@ -261,4 +344,258 @@ void selective_state_update(TensorView state, TensorView x, TensorView dt, Tenso } } +// ============================================================================= +// Generic dispatcher - routes to single-token or multi-token based on x.dim() +// ============================================================================= +void selective_state_update(TensorView state, TensorView x, TensorView dt, TensorView A, + TensorView B, TensorView C, TensorView D, Optional z, + Optional dt_bias, bool dt_softplus, + Optional state_batch_indices, int64_t pad_slot_id, + TensorView output, bool disable_state_update, + Optional intermediate_states_buffer, + Optional intermediate_state_indices, int64_t cache_steps) { + if (x.dim() == 3) { + run_selective_state_update_stp(state, x, dt, A, B, C, D, z, dt_bias, dt_softplus, + state_batch_indices, pad_slot_id, output, disable_state_update); + } else if (x.dim() == 4) { + FLASHINFER_CHECK(false, "x must have 3 dimensions (single-token), got ", x.dim()); + // run_selective_state_update_mtp(state, x, dt, A, B, C, D, z, dt_bias, dt_softplus, + // state_batch_indices, pad_slot_id, output, + // disable_state_update, intermediate_states_buffer, + // intermediate_state_indices, cache_steps); + } else { + FLASHINFER_CHECK(false, + "x must have 3 dimensions (single-token) or 4 dimensions (multi-token), got ", + x.dim()); + } +} + +// Old function - commented out for reference +// void selective_state_update_old(TensorView state, TensorView x, TensorView dt, TensorView +// output, +// TensorView A, TensorView B, TensorView C, TensorView D, +// Optional z, Optional dt_bias, bool +// dt_softplus, Optional state_batch_indices, int64_t +// pad_slot_id) { +// auto const batch = 
x.size(0); +// auto const state_cache_size = state.size(0); +// auto const nheads = state.size(1); +// auto const dim = state.size(2); +// auto const dstate = state.size(3); +// auto const ngroups = B.size(1); + +// FLASHINFER_CHECK(state_cache_size >= batch, "state.size(0) must be >= x.size(0)"); + +// FLASHINFER_CHECK(nheads % ngroups == 0, "nheads must be divisible by ngroups"); + +// // Check x shape and strides +// CHECK_DIM(3, x); +// FLASHINFER_CHECK(x.size(1) == nheads, "x.size(1) must equal nheads"); +// FLASHINFER_CHECK(x.size(2) == dim, "x.size(2) must equal dim"); +// CHECK_LAST_DIM_CONTIGUOUS_INPUT(x); +// FLASHINFER_CHECK(x.stride(1) == dim, "x.stride(1) must equal dim, got ", x.stride(1), +// " expected ", x.size(2)); + +// // Check output shape and strides +// CHECK_DIM(3, output); +// CHECK_LAST_DIM_CONTIGUOUS(output); +// FLASHINFER_CHECK(output.size(1) == nheads, "output.size(1) must equal nheads"); +// FLASHINFER_CHECK(output.size(2) == dim, "output.size(2) must equal dim"); +// FLASHINFER_CHECK(output.stride(1) == dim, "output.stride(1) must equal dim"); + +// // Check dt shape and strides +// CHECK_CUDA(dt); +// CHECK_DIM(3, dt); // dt: {batch, nheads, dim} +// FLASHINFER_CHECK(dt.size(1) == nheads, "dt.size(1) must equal nheads"); +// FLASHINFER_CHECK(dt.size(2) == dim, "dt.size(2) must equal dim"); +// FLASHINFER_CHECK(dt.stride(1) == 1, "dt.stride(1) must be 1, got ", dt.stride(1)); +// FLASHINFER_CHECK(dt.stride(2) == 0, "dt.stride(2) must be 0 (broadcasted), got ", +// dt.stride(2)); + +// // Check state - fully contiguous +// CHECK_INPUT(state); // CUDA + fully contiguous (uses TVM FFI) +// CHECK_DIM(4, state); // state: {state_cache_size, nheads, dim, dstate} + +// // Check B shape and strides +// CHECK_CUDA(B); +// CHECK_DIM(3, B); // B: {batch, B.size(1), dstate} +// FLASHINFER_CHECK(B.size(0) == batch, "B.size(0) must equal batch"); +// FLASHINFER_CHECK(B.size(1) == ngroups, "B.size(1) must equal ngroups"); +// 
FLASHINFER_CHECK(B.size(2) == dstate, "B.size(2) must equal dstate"); +// CHECK_LAST_DIM_CONTIGUOUS(B); // stride(2) == 1 +// FLASHINFER_CHECK(B.stride(1) == B.size(2), "B.stride(1) must equal dstate, got ", B.stride(1), +// " expected ", B.size(2)); + +// // Check C shape and strides +// CHECK_CUDA(C); +// CHECK_LAST_DIM_CONTIGUOUS(C); // stride(2) == 1 +// CHECK_DIM(3, C); // C: {batch, C.size(1), dstate} +// FLASHINFER_CHECK(C.stride(1) == C.size(2), "C.stride(1) must equal dstate, got ", C.stride(1), +// " expected ", C.size(2)); +// FLASHINFER_CHECK(C.size(0) == batch, "C.size(0) must equal batch"); +// FLASHINFER_CHECK(C.size(1) == ngroups, "C.size(1) must equal ngroups"); +// FLASHINFER_CHECK(C.size(2) == dstate, "C.size(2) must equal dstate"); + +// // Check D - specific stride patterns indicating broadcasting +// CHECK_CUDA(D); +// CHECK_DIM(2, D); // D: {nheads, dim} +// FLASHINFER_CHECK(D.size(0) == nheads, "D.size(0) must equal nheads"); +// FLASHINFER_CHECK(D.size(1) == dim, "D.size(1) must equal dim"); +// FLASHINFER_CHECK(D.stride(0) == 1, "D.stride(0) must be 1, got ", D.stride(0)); +// FLASHINFER_CHECK(D.stride(1) == 0, "D.stride(1) must be 0 (broadcasted), got ", D.stride(1)); + +// // Check A - specific stride patterns indicating broadcasting +// CHECK_CUDA(A); +// CHECK_DIM(3, A); // A: {nheads, dim, dstate} +// FLASHINFER_CHECK(A.size(0) == nheads, "A.size(0) must equal nheads"); +// FLASHINFER_CHECK(A.size(1) == dim, "A.size(1) must equal dim"); +// FLASHINFER_CHECK(A.size(2) == dstate, "A.size(2) must equal dstate"); +// FLASHINFER_CHECK(A.stride(1) == 0, "A.stride(1) must be 0 (broadcasted), got ", A.stride(1)); +// FLASHINFER_CHECK(A.stride(2) == 0, "A.stride(2) must be 0 (broadcasted), got ", A.stride(2)); + +// // Optional dt_bias check +// if (dt_bias.has_value()) { +// auto& bias = dt_bias.value(); +// CHECK_CUDA(bias); +// CHECK_DIM(2, bias); // dt_bias: {nheads, dim} +// FLASHINFER_CHECK(bias.size(0) == nheads, "dt_bias.size(0) must 
equal nheads"); +// FLASHINFER_CHECK(bias.size(1) == dim, "dt_bias.size(1) must equal dim"); +// FLASHINFER_CHECK(bias.stride(0) == 1, "dt_bias.stride(0) must be 1, got ", bias.stride(0)); +// FLASHINFER_CHECK(bias.stride(1) == 0, "dt_bias.stride(1) must be 0 (broadcasted), got ", +// bias.stride(1)); +// } + +// if (z.has_value()) { +// auto& z_tensor = z.value(); +// CHECK_CUDA(z_tensor); +// CHECK_DIM(3, z_tensor); // z: {batch, nheads, dim} +// FLASHINFER_CHECK(z_tensor.size(0) == batch, "z.size(0) must equal batch"); +// FLASHINFER_CHECK(z_tensor.size(1) == nheads, "z.size(1) must equal nheads"); +// FLASHINFER_CHECK(z_tensor.size(2) == dim, "z.size(2) must equal dim"); +// CHECK_LAST_DIM_CONTIGUOUS_INPUT(z_tensor); +// FLASHINFER_CHECK(z_tensor.stride(1) == dim, "z.stride(1) must equal dim, got ", +// z_tensor.stride(1), " expected ", z_tensor.size(2)); +// } + +// if (state_batch_indices) { +// CHECK_DIM(1, (*state_batch_indices)); +// FLASHINFER_CHECK(state_batch_indices.value().size(0) == batch, +// "state_batch_indices.shape must be (", batch, ")"); +// } + +// SelectiveStateUpdateParams p; + +// // copy dimensions +// p.batch = batch; +// p.nheads = nheads; +// p.dim = dim; +// p.dstate = dstate; +// p.ngroups = ngroups; +// p.state_cache_size = state_cache_size; +// p.dt_softplus = dt_softplus; +// p.pad_slot_id = pad_slot_id; + +// // Copy strides +// p.x_stride_batch = x.stride(0); +// p.dt_stride_batch = dt.stride(0); +// p.B_stride_batch = B.stride(0); +// p.C_stride_batch = C.stride(0); +// p.out_stride_batch = output.stride(0); +// if (state_batch_indices) p.state_batch_indices = state_batch_indices.value().data_ptr(); + +// // Copy pointers +// p.state = state.data_ptr(); +// p.x = x.data_ptr(); +// p.dt = dt.data_ptr(); +// p.output = output.data_ptr(); +// if (dt_bias) { +// p.dt_bias = dt_bias.value().data_ptr(); +// } +// if (z) { +// p.z = z.value().data_ptr(); +// p.z_stride_batch = z.value().stride(0); +// } +// p.A = A.data_ptr(); +// p.B 
= B.data_ptr(); +// p.C = C.data_ptr(); +// p.D = D.data_ptr(); + +// // Set device and get stream +// ffi::CUDADeviceGuard device_guard(state.device().device_id); +// const cudaStream_t stream = get_stream(state.device()); + +// // Dispatch based on dtype combination +// DLDataType state_dtype = state.dtype(); +// DLDataType input_dtype = x.dtype(); +// DLDataType weight_dtype = dt.dtype(); +// DLDataType matrixA_dtype = A.dtype(); + +// int64_t state_dtype_code = encode_dlpack_dtype(state_dtype); +// int64_t input_dtype_code = encode_dlpack_dtype(input_dtype); +// int64_t weight_dtype_code = encode_dlpack_dtype(weight_dtype); +// int64_t matrixA_dtype_code = encode_dlpack_dtype(matrixA_dtype); + +// // Pack all dtype codes into a single value for switching +// auto dtype_key = +// std::make_tuple(state_dtype_code, input_dtype_code, weight_dtype_code, matrixA_dtype_code); + +// if (dtype_key == std::make_tuple(bfloat16_code, bfloat16_code, bfloat16_code, float32_code)) { +// using state_t = nv_bfloat16; +// using input_t = nv_bfloat16; +// using weight_t = nv_bfloat16; +// using matrixA_t = float; +// invokeSelectiveStateUpdate(p, stream); +// } else if (dtype_key == +// std::make_tuple(float16_code, bfloat16_code, bfloat16_code, float32_code)) { +// using state_t = half; +// using input_t = nv_bfloat16; +// using weight_t = nv_bfloat16; +// using matrixA_t = float; +// invokeSelectiveStateUpdate(p, stream); +// } else if (dtype_key == +// std::make_tuple(float32_code, bfloat16_code, bfloat16_code, float32_code)) { +// using state_t = float; +// using input_t = nv_bfloat16; +// using weight_t = nv_bfloat16; +// using matrixA_t = float; +// invokeSelectiveStateUpdate(p, stream); +// } else if (dtype_key == +// std::make_tuple(bfloat16_code, bfloat16_code, float32_code, float32_code)) { +// using state_t = nv_bfloat16; +// using input_t = nv_bfloat16; +// using weight_t = float; +// using matrixA_t = float; +// invokeSelectiveStateUpdate(p, stream); +// } else if 
(dtype_key == +// std::make_tuple(float16_code, bfloat16_code, float32_code, float32_code)) { +// using state_t = half; +// using input_t = nv_bfloat16; +// using weight_t = float; +// using matrixA_t = float; +// invokeSelectiveStateUpdate(p, stream); +// } else if (dtype_key == +// std::make_tuple(float32_code, bfloat16_code, float32_code, float32_code)) { +// using state_t = float; +// using input_t = nv_bfloat16; +// using weight_t = float; +// using matrixA_t = float; +// invokeSelectiveStateUpdate(p, stream); +// } else { +// // Default case: unsupported dtype combination +// TVM_FFI_ICHECK(false) +// << "Unsupported dtype combination for selective_state_update: " +// << "state_dtype=" << state_dtype.code << ":" << state_dtype.bits << ", " +// << "input_dtype=" << input_dtype.code << ":" << input_dtype.bits << ", " +// << "weight_dtype=" << weight_dtype.code << ":" << weight_dtype.bits << ", " +// << "matrixA_dtype=" << matrixA_dtype.code << ":" << matrixA_dtype.bits +// << ". Supported combos include:\n" +// << " (state=bfloat16, input=bfloat16, weight=bfloat16, matrixA=float32)\n" +// << " (state=float16, input=bfloat16, weight=bfloat16, matrixA=float32)\n" +// << " (state=float32, input=bfloat16, weight=bfloat16, matrixA=float32)\n" +// << " (state=bfloat16, input=bfloat16, weight=float32, matrixA=float32)\n" +// << " (state=float16, input=bfloat16, weight=float32, matrixA=float32)\n" +// << " (state=float32, input=bfloat16, weight=float32, matrixA=float32)"; +// } +// } + } // namespace flashinfer::mamba diff --git a/include/flashinfer/mamba/selective_state_update.cuh b/include/flashinfer/mamba/selective_state_update.cuh index 364e7e5a03..e4da79585f 100644 --- a/include/flashinfer/mamba/selective_state_update.cuh +++ b/include/flashinfer/mamba/selective_state_update.cuh @@ -38,10 +38,9 @@ constexpr unsigned warpSize = 32; struct SelectiveStateUpdateParams { uint32_t batch{}, nheads{}, dim{}, dstate{}, ngroups{}, state_cache_size{}; int32_t 
pad_slot_id{-1}; - bool dt_softplus{false}; int64_t x_stride_batch{}, dt_stride_batch{}, B_stride_batch{}, C_stride_batch{}, - out_stride_batch{}, z_stride_batch{}; + out_stride_batch{}, z_stride_batch{}, state_stride_batch{}; void* __restrict__ state{nullptr}; // state_t: (state_cache_size, nheads, dim, dstate) void* __restrict__ x{nullptr}; // input_t: (batch, nheads, dim) @@ -55,6 +54,9 @@ struct SelectiveStateUpdateParams { void* __restrict__ z{nullptr}; // input_t: (batch, nheads, dim) void* __restrict__ output{nullptr}; // input_t: (batch, nheads, dim) void* __restrict__ state_batch_indices{nullptr}; // state_batch_indices: (batch,) + + bool dt_softplus{false}; + bool update_state{true}; }; __forceinline__ __device__ float softplus(float x) { return __logf(1.f + __expf(x)); } @@ -135,20 +137,14 @@ struct SharedStorageSimple { template __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams params) { - auto* __restrict__ output = - reinterpret_cast(params.output); // output: (batch, nheads, dim) - auto* __restrict__ state = - reinterpret_cast(params.state); // state: (batch, nheads, dim, dstate) - - auto const* __restrict__ x = - reinterpret_cast(params.x); // x: (batch, nheads, dim) - auto const* __restrict__ dt = - reinterpret_cast(params.dt); // dt: (batch, nheads) - auto const* __restrict__ A = reinterpret_cast(params.A); // A: (nheads) - auto const* __restrict__ B = - reinterpret_cast(params.B); // B: (batch, ngroups, dstate) - auto const* __restrict__ C = - reinterpret_cast(params.C); // C: (batch, ngroups, dstate) + auto* __restrict__ output = reinterpret_cast(params.output); + auto* __restrict__ state = reinterpret_cast(params.state); + + auto const* __restrict__ x = reinterpret_cast(params.x); + auto const* __restrict__ dt = reinterpret_cast(params.dt); + auto const* __restrict__ A = reinterpret_cast(params.A); + auto const* __restrict__ B = reinterpret_cast(params.B); + auto const* __restrict__ C = reinterpret_cast(params.C); 
auto const* __restrict__ D = reinterpret_cast(params.D); // D: (nheads, dim) auto const* __restrict__ dt_bias = reinterpret_cast(params.dt_bias); // (nheads) auto const* __restrict__ z = reinterpret_cast(params.z); @@ -168,7 +164,7 @@ __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams auto warp = threadIdx.y; auto const state_batch = (state_batch_indices) ? state_batch_indices[batch] : batch; - state += (state_batch * nheads + head) * DIM * DSTATE; + state += state_batch * params.state_stride_batch + head * DIM * DSTATE; __shared__ SharedStorageSimple sram; @@ -875,10 +871,23 @@ __global__ void selective_state_update_kernel_producer_consumer_horizontal( #endif // FLASHINFER_MAMBA_ENABLE_SM100 +template +std::string format_array(const T (&arr)[N]) { + std::ostringstream oss; + for (size_t i = 0; i < N; ++i) { + if (i > 0) oss << ", "; + oss << arr[i]; + } + return oss.str(); +} + template void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, cudaStream_t stream) { auto [sm_major, sm_minor] = GetCudaComputeCapability(); + constexpr int allowed_dstates[] = {64, 128, 256}; + constexpr int allowed_dims[] = {64, 128, 256}; + #ifdef FLASHINFER_MAMBA_ENABLE_SM100 if (sm_major < 10) // pre-Blackwell #elif defined(FLASHINFER_MAMBA_ENABLE_SM90) @@ -921,32 +930,31 @@ void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, cudaStream_t numWarps><<>>(params); }; - auto dispatch_dstate = [&]() { - switch (params.dstate) { - case 64: - dispatch_dim_dstate.template operator()(); - break; - case 128: - dispatch_dim_dstate.template operator()(); - break; - case 256: - dispatch_dim_dstate.template operator()(); - break; - default: - FLASHINFER_CHECK(false, "Unsupported dstate value. 
Supported values are: 64, 128, 256"); + auto dispatch_dstate = [&]() { + if (params.dstate == DSTATE) { + dispatch_dim_dstate.template operator()(); + return true; } + return false; }; - switch (params.dim) { - case 64: - dispatch_dstate.template operator()<64>(); - break; - case 128: - dispatch_dstate.template operator()<128>(); - break; - default: - FLASHINFER_CHECK(false, "Unsupported dim value. Supported values are: 64, 128"); - } + auto dispatch_dim = [&]() { + if (params.dim == DIM) { + bool dispatched = [&](std::integer_sequence) { + return (dispatch_dstate.template operator()() || ...); + }(std::make_integer_sequence{}); + FLASHINFER_CHECK(dispatched, "Unsupported dstate value: ", params.dstate, + ".\nSupported values: ", format_array(allowed_dstates)); + return true; + } + return false; + }; + + bool dim_dispatched = [&](std::integer_sequence) { + return (dispatch_dim.template operator()() || ...); + }(std::make_integer_sequence{}); + FLASHINFER_CHECK(dim_dispatched, "Unsupported dim value: ", params.dim, + ".\nSupported values: ", format_array(allowed_dims)); } #ifdef FLASHINFER_MAMBA_ENABLE_SM90 else { @@ -1007,38 +1015,39 @@ void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, cudaStream_t scan_func<<>>(params, tensorState); }; - auto dispatch_dstate = [&]() { - switch (params.dstate) { - case 64: - dispatch_dim_dstate.template operator()(); - break; - case 128: - dispatch_dim_dstate.template operator()(); - break; - case 256: - dispatch_dim_dstate.template operator()(); - break; - default: - FLASHINFER_CHECK(false, "Unsupported dstate value. Supported values are: 64, 128, 256"); + auto dispatch_dstate = [&]() { + if (params.dstate == DSTATE) { + dispatch_dim_dstate.template operator()(); + return true; } + return false; }; - switch (params.dim) { - case 64: - dispatch_dstate.template operator()<64>(); - break; - case 128: - dispatch_dstate.template operator()<128>(); - break; - default: - FLASHINFER_CHECK(false, "Unsupported dim value. 
Supported values are: 64, 128"); - } + auto dispatch_dim = [&]() { + if (params.dim == DIM) { + bool dispatched = [&](std::integer_sequence) { + return (dispatch_dstate.template operator()() || ...); + }(std::make_integer_sequence{}); + FLASHINFER_CHECK(dispatched, "Unsupported dstate value: ", params.dstate, + ".\nSupported values: ", format_array(allowed_dstates)); + return true; + } + return false; + }; + + bool dim_dispatched = [&](std::integer_sequence) { + return (dispatch_dim.template operator()() || ...); + }(std::make_integer_sequence{}); + FLASHINFER_CHECK(dim_dispatched, "Unsupported dim value: ", params.dim, + ".\nSupported values: ", format_array(allowed_dims)); } #endif #ifdef FLASHINFER_MAMBA_ENABLE_SM100 else { // SM100+ (Blackwell and newer) uses horizontal producer-consumer kernel + constexpr int allowed_heads_groups_ratios[] = {1, 8, 16}; + auto dispatch_dim_dstate = [&]() { // Alignment checks for vectorized loads in Blackwell kernel using load_input_t = PackedAligned; @@ -1073,76 +1082,72 @@ void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, cudaStream_t constexpr auto totalStages = DSTATE / stageCols; constexpr auto numStages = (totalStages >= 4) ? 
4 : totalStages; - auto dispatch_heads_groups_ratio = [&]() { - auto scan_func = selective_state_update_kernel_producer_consumer_horizontal< - input_t, weight_t, matrixA_t, state_t, DIM, DSTATE, numConsumers, stageCols, - headsGroupsRatio, numStages>; - - dim3 block(warpSize, numWarps); - dim3 grid(params.batch, params.nheads); - - auto nh = params.nheads; - auto dim = params.dim; - - FLASHINFER_CHECK(reinterpret_cast(params.state) % 128 == - 0); // TMA requires 128B aligned - auto tensorState = tma::createTensorMap( - params.state, params.state_cache_size * nh * dim, DSTATE, DIM, stageCols); - static_assert(DSTATE % stageCols == 0 && DSTATE >= stageCols); - - // Calculate shared memory size and opt-in to extended shared memory - using sram_t = SharedStorageHorizontal; - constexpr size_t smem_size = sizeof(sram_t); - FLASHINFER_CUDA_CHECK(cudaFuncSetAttribute( - scan_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - - scan_func<<>>(params, tensorState); + auto dispatch_ratio = [&]() { + if (params.nheads / params.ngroups == RATIO) { + auto scan_func = selective_state_update_kernel_producer_consumer_horizontal< + input_t, weight_t, matrixA_t, state_t, DIM, DSTATE, numConsumers, stageCols, RATIO, + numStages>; + + dim3 block(warpSize, numWarps); + dim3 grid(params.batch, params.nheads); + + auto nh = params.nheads; + auto dim = params.dim; + + FLASHINFER_CHECK(reinterpret_cast(params.state) % 128 == + 0); // TMA requires 128B aligned + auto tensorState = tma::createTensorMap( + params.state, params.state_cache_size * nh * dim, DSTATE, DIM, stageCols); + static_assert(DSTATE % stageCols == 0 && DSTATE >= stageCols); + + // Calculate shared memory size and opt-in to extended shared memory + using sram_t = SharedStorageHorizontal; + constexpr size_t smem_size = sizeof(sram_t); + FLASHINFER_CUDA_CHECK(cudaFuncSetAttribute( + scan_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + scan_func<<>>(params, tensorState); + return true; + } + 
return false; }; - switch (params.nheads / params.ngroups) { - case 1: - dispatch_heads_groups_ratio.template operator()<1>(); - break; - case 8: - dispatch_heads_groups_ratio.template operator()<8>(); - break; - case 16: - dispatch_heads_groups_ratio.template operator()<16>(); - break; - default: - FLASHINFER_CHECK(false, - "Unsupported nheads/ngroups ratio: ", params.nheads / params.ngroups, - ". Supported values are: 1, 8, 16"); + bool ratio_dispatched = + [&](std::integer_sequence) { + return (dispatch_ratio.template operator()() || ...); + }(std::make_integer_sequence{}); + FLASHINFER_CHECK(ratio_dispatched, + "Unsupported nheads/ngroups ratio: ", params.nheads / params.ngroups, + ".\nSupported values: ", format_array(allowed_heads_groups_ratios)); + }; + + auto dispatch_dstate = [&]() { + if (params.dstate == DSTATE) { + dispatch_dim_dstate.template operator()(); + return true; } + return false; }; - auto dispatch_dstate = [&]() { - switch (params.dstate) { - case 64: - dispatch_dim_dstate.template operator()(); - break; - case 128: - dispatch_dim_dstate.template operator()(); - break; - case 256: - dispatch_dim_dstate.template operator()(); - break; - default: - FLASHINFER_CHECK(false, "Unsupported dstate value. Supported values are: 64, 128, 256"); + auto dispatch_dim = [&]() { + if (params.dim == DIM) { + bool dispatched = [&](std::integer_sequence) { + return (dispatch_dstate.template operator()() || ...); + }(std::make_integer_sequence{}); + FLASHINFER_CHECK(dispatched, "Unsupported dstate value: ", params.dstate, + ".\nSupported values: ", format_array(allowed_dstates)); + return true; } + return false; }; - switch (params.dim) { - case 64: - dispatch_dstate.template operator()<64>(); - break; - case 128: - dispatch_dstate.template operator()<128>(); - break; - default: - FLASHINFER_CHECK(false, "Unsupported dim value. 
Supported values are: 64, 128"); - } + bool dim_dispatched = [&](std::integer_sequence) { + return (dispatch_dim.template operator()() || ...); + }(std::make_integer_sequence{}); + FLASHINFER_CHECK(dim_dispatched, "Unsupported dim value: ", params.dim, + ".\nSupported values: ", format_array(allowed_dims)); } #endif } From 295ae56fa5d69e247f0767321c7a8132e787e7e7 Mon Sep 17 00:00:00 2001 From: Igor Shovkun Date: Thu, 29 Jan 2026 10:36:07 -0800 Subject: [PATCH 06/33] Non-contiguous state --- .../flashinfer/mamba/create_tensor_map.cuh | 113 ++++++++ .../mamba/selective_state_update.cuh | 243 ++++++++++++------ 2 files changed, 277 insertions(+), 79 deletions(-) diff --git a/include/flashinfer/mamba/create_tensor_map.cuh b/include/flashinfer/mamba/create_tensor_map.cuh index 4208622dfa..d58ca94bf8 100644 --- a/include/flashinfer/mamba/create_tensor_map.cuh +++ b/include/flashinfer/mamba/create_tensor_map.cuh @@ -85,4 +85,117 @@ inline CUtensorMap createTensorMap(void* matrix_ptr, uint32_t matrix_height, uin return tensor_map; } +inline CUtensorMap buildNdDescriptor(std::type_info const& dtype, + std::vector const& shapes, + std::vector const& strides, + std::vector const& tileShapes, void* gmemAddr) { + // The multiplication factor of the data padding in SMEM. + CUtensorMap desc{}; + CUtensorMapDataType tmaDataFormat; + int dtype_size{}; + if (dtype == typeid(float)) { + tmaDataFormat = CU_TENSOR_MAP_DATA_TYPE_FLOAT32; + dtype_size = sizeof(float); + } else if (dtype == typeid(half)) { + tmaDataFormat = CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + dtype_size = sizeof(half); + } else if (dtype == typeid(__nv_bfloat16)) { + tmaDataFormat = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + dtype_size = sizeof(__nv_bfloat16); + } else { + throw std::invalid_argument("buildNdDescriptor: unsupported dtype"); + } + + // The swizzle type. 
+ CUtensorMapSwizzle swizzleType{CU_TENSOR_MAP_SWIZZLE_NONE}; + + // Check gmem address must be 16B-aligned + FLASHINFER_CHECK((reinterpret_cast(gmemAddr) & 0b1111) == 0, + "Tensor must be 16B-aligned"); + + // Check shape must be in range [1, 2^32] + int32_t dim = shapes.size(); + // dimensions for batched gemm with blocked layout. + // Check shape range. + for (int32_t ii = 0; ii < dim; ++ii) { + FLASHINFER_CHECK(shapes[ii] >= (uint64_t(1))); // Size must be min 1 + FLASHINFER_CHECK(shapes[ii] <= (uint64_t(1) << 32)); // Size must be max 2^32 + } + + // TMA descriptor does not store the zeroth stride and assumes it is 1. + FLASHINFER_CHECK(static_cast(strides.size()) == dim); + FLASHINFER_CHECK(strides[0] == 1); + + // Build strides in bytes. + // cuTensorMapEncodeTiled ignores the stride of the first dimension (implicitly 1). + std::vector stridesInBytes(dim - 1); + for (int32_t ii = 0; ii < dim - 1; ++ii) { + stridesInBytes[ii] = strides[ii + 1] * dtype_size; + } + + // Build box dim array. If tileShapes is smaller than dim, just fill with 1s. + FLASHINFER_CHECK(static_cast(tileShapes.size()) <= dim); + std::vector boxDim(dim, 1); + boxDim[0] = tileShapes[0]; + for (size_t ii = 1; ii < tileShapes.size(); ++ii) { + if (tileShapes[ii] > 256) { + std::cerr << "buildNdTmaDescriptor: boxDim too large " << tileShapes[ii] << std::endl; + FLASHINFER_CHECK(false); + } else { + boxDim[ii] = tileShapes[ii]; + } + } + + // Set tile strides to 1; + std::vector tileStrides(dim, 1); + + // Build the descriptor. 
+ CUresult result = + cuTensorMapEncodeTiled(&desc, tmaDataFormat, + /*tensorRank=*/dim, gmemAddr, shapes.data(), stridesInBytes.data(), + boxDim.data(), tileStrides.data(), + /*interleave=*/CU_TENSOR_MAP_INTERLEAVE_NONE, swizzleType, + /*l2Promotion=*/CU_TENSOR_MAP_L2_PROMOTION_L2_128B, + /*oobFill=*/CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); + + if (result != CUDA_SUCCESS) { + char const* errorString; + cuGetErrorString(result, &errorString); + std::stringstream ss; + ss << "Error: Failed to initialize the TMA descriptor " << result << std::endl; + + ss << "tmaFormat: " << static_cast(tmaDataFormat) << " dim: " << dim + << " gmem: " << gmemAddr << std::endl; + + ss << "Shape: "; + for (int ii = 0; ii < dim; ++ii) { + ss << shapes[ii] << " "; + } + ss << std::endl; + + ss << "Stride: "; + for (int ii = 0; ii < dim - 1; ++ii) { + ss << stridesInBytes[ii] << " "; + } + ss << std::endl; + + ss << "tileShapes: "; + for (int ii = 0; ii < dim; ++ii) { + ss << boxDim[ii] << " "; + } + ss << std::endl; + + ss << "tileStrides: "; + for (int ii = 0; ii < dim; ++ii) { + ss << tileStrides[ii] << " "; + } + ss << std::endl; + ss << "swizzleType: " << int(swizzleType) << std::endl; + ss << "(in " << __FILE__ << ":" << __LINE__ << ")" << std::endl; + throw std::runtime_error(ss.str()); + } + + return desc; +} + } // namespace flashinfer::mamba::tma diff --git a/include/flashinfer/mamba/selective_state_update.cuh b/include/flashinfer/mamba/selective_state_update.cuh index e4da79585f..41c9830d63 100644 --- a/include/flashinfer/mamba/selective_state_update.cuh +++ b/include/flashinfer/mamba/selective_state_update.cuh @@ -264,7 +264,7 @@ __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams template -struct SharedStorage { +struct SharedStorageVertical { alignas(128) state_t state[numStages][rowsPerStage * dstate]; alignas(alignof(PackedAligned)) input_t x[dim]; float out[dim]; // dt is special cause we're gonna store input in there as well @@ -278,12 +278,100 @@ 
struct SharedStorage { barrier_t bar_consumers; }; +template +__device__ __forceinline__ void producer_func_vertical(SramT& sram, CUtensorMap const& tensorState, + int batch, int head) { +#ifdef FLASHINFER_MAMBA_ENABLE_SM90 + namespace cde = cuda::device::experimental; + + auto constexpr stagesReadOnly = numStages; + auto constexpr stagesBoth = DIM / rowsPerStage - numStages; + auto constexpr stagesWriteOnly = numStages; + + auto constexpr bytesState = rowsPerStage * DSTATE * sizeof(state_t); + auto constexpr bytesToArrive = bytesState; + + // Phase 1: Read only (filling the pipeline) +#pragma unroll + for (int iter = 0; iter < stagesReadOnly; ++iter) { + auto const stage = iter % numStages; + auto const d = iter * rowsPerStage; + + sram.bar_empty[stage].wait(sram.bar_empty[stage].arrive()); + + if constexpr (readState) { + cde::cp_async_bulk_tensor_4d_global_to_shared(&sram.state[stage][0], &tensorState, 0, d, head, + batch, sram.bar_full[stage]); + + auto const _ = cuda::device::barrier_arrive_tx(sram.bar_full[stage], 1, bytesToArrive); + } else { + auto const _ = sram.bar_full[stage].arrive(); + } + } + + // Phase 2: Both read and write (steady state) +#pragma unroll + for (int iter = 0; iter < stagesBoth; ++iter) { + auto const stage = (stagesReadOnly + iter) % numStages; + auto const d_read = (stagesReadOnly + iter) * rowsPerStage; + auto const d_write = iter * rowsPerStage; + + sram.bar_empty[stage].wait(sram.bar_empty[stage].arrive()); + + if constexpr (readState || writeState) { + // Unblock async proxy for writeback + cde::fence_proxy_async_shared_cta(); + // Writeback + if constexpr (writeState) { + cde::cp_async_bulk_tensor_4d_shared_to_global(&tensorState, 0, d_write, head, batch, + &sram.state[stage][0]); + + cde::cp_async_bulk_commit_group(); + cde::cp_async_bulk_wait_group_read<0>(); + } + + // Read next + if constexpr (readState) { + cde::cp_async_bulk_tensor_4d_global_to_shared(&sram.state[stage][0], &tensorState, 0, + d_read, head, batch, 
sram.bar_full[stage]); + auto const _ = cuda::device::barrier_arrive_tx(sram.bar_full[stage], 1, bytesToArrive); + } else { + auto const _ = sram.bar_full[stage].arrive(); + } + } else { + auto const _ = sram.bar_full[stage].arrive(); + } + } + + // Phase 3: Write only (draining the pipeline) +#pragma unroll + for (int iter = 0; iter < stagesWriteOnly; ++iter) { + auto const stage = (stagesReadOnly + stagesBoth + iter) % numStages; + auto const d_write = (stagesBoth + iter) * rowsPerStage; + + sram.bar_empty[stage].wait(sram.bar_empty[stage].arrive()); + + if constexpr (writeState) { + // Unblock async proxy for writeback + cde::fence_proxy_async_shared_cta(); + cde::cp_async_bulk_tensor_4d_shared_to_global(&tensorState, 0, d_write, head, batch, + &sram.state[stage][0]); + + cde::cp_async_bulk_commit_group(); + cde::cp_async_bulk_wait_group_read<0>(); + } + } +#endif +} + template __device__ __forceinline__ void consumer_func_vertical( int lane, int warp, float d_value, float dt_value, float dA, - SharedStorage& - sram) { + SharedStorageVertical& sram) { +#ifdef FLASHINFER_MAMBA_ENABLE_SM90 namespace cde = cuda::device::experimental; for (auto dBegin = 0, stage = 0; dBegin < DIM; dBegin += rowsPerStage, stage = (stage + 1) % numStages) { @@ -364,6 +452,7 @@ __device__ __forceinline__ void consumer_func_vertical( cde::fence_proxy_async_shared_cta(); auto _ = sram.bar_empty[stage].arrive(); } +#endif } template ; - // Use dynamic shared memory to allow opting into extended shared memory on SM90+ - extern __shared__ __align__(128) char smem[]; - sram_t& sram = *reinterpret_cast(smem); + extern __shared__ uint8_t sbuffer[]; + using sram_t = SharedStorageVertical; + auto& sram = *reinterpret_cast(sbuffer); namespace cde = cuda::device::experimental; namespace cg = cooperative_groups; @@ -421,41 +509,22 @@ __global__ void selective_state_update_kernel_producer_consumer_vertical( if (warp == consumerWarps) // producer { - auto const state_offset = (state_batch * nheads + 
head) * DIM; - - for (int d = 0, stage = 0; d < DIM + rowsPerStage * numStages; - d += rowsPerStage, stage = (stage + 1) % numStages) { - if (lane == 0) { - cg::invoke_one(cg::coalesced_threads(), [&]() { - sram.bar_empty[stage].wait(sram.bar_empty[stage].arrive()); - - if (state_batch != params.pad_slot_id) { - // Writeback - if (d >= rowsPerStage * numStages) { - cde::cp_async_bulk_tensor_2d_shared_to_global( - &tensorState, - /*x*/ 0, - /*y*/ state_offset + d - rowsPerStage * numStages, &sram.state[stage][0]); - cde::cp_async_bulk_commit_group(); - cde::cp_async_bulk_wait_group_read<0>(); - } - - if (d < DIM) { - cde::cp_async_bulk_tensor_2d_global_to_shared(&sram.state[stage][0], &tensorState, - /*x*/ 0, /*y*/ state_offset + d, - sram.bar_full[stage]); + // auto const state_offset = (state_batch * nheads + head) * DIM; + auto const read_state = (state_batch != params.pad_slot_id); + auto const write_state = read_state && params.update_state; - // Unblock the consumers - auto constexpr bytesState = rowsPerStage * DSTATE * sizeof(state_t); - auto constexpr bytesToArrive = bytesState; - auto const _ = - cuda::device::barrier_arrive_tx(sram.bar_full[stage], 1, bytesToArrive); - } - } else { - auto const _ = sram.bar_full[stage].arrive(); - } - }); - } + if (lane == 0) { + cg::invoke_one(cg::coalesced_threads(), [&]() { + if (read_state && write_state) + producer_func_vertical( + sram, tensorState, state_batch, head); + else if (read_state && !write_state) + producer_func_vertical( + sram, tensorState, state_batch, head); + else + producer_func_vertical( + sram, tensorState, state_batch, head); + }); } } else { // consumers @@ -553,11 +622,11 @@ struct SharedStorageHorizontal { barrier_t bar_consumers; }; -template +template __device__ __forceinline__ void producer_func_horizontal(SramT& sram, - CUtensorMap const& tensorState, - int state_offset) { + CUtensorMap const& tensorState, int batch, + int head) { namespace cde = cuda::device::experimental; auto constexpr 
stagesReadOnly = numStages; @@ -575,9 +644,9 @@ __device__ __forceinline__ void producer_func_horizontal(SramT& sram, sram.bar_empty[stage].wait(sram.bar_empty[stage].arrive()); - if constexpr (useStateCache) { - cde::cp_async_bulk_tensor_2d_global_to_shared(&sram.state[stage][0], &tensorState, /*x*/ i, - /*y*/ state_offset, sram.bar_full[stage]); + if constexpr (readState) { + cde::cp_async_bulk_tensor_4d_global_to_shared(&sram.state[stage][0], &tensorState, i, 0, head, + batch, sram.bar_full[stage]); auto const _ = cuda::device::barrier_arrive_tx(sram.bar_full[stage], 1, bytesToArrive); } else { auto const _ = sram.bar_full[stage].arrive(); @@ -593,20 +662,25 @@ __device__ __forceinline__ void producer_func_horizontal(SramT& sram, sram.bar_empty[stage].wait(sram.bar_empty[stage].arrive()); - if constexpr (useStateCache) { + if constexpr (readState || writeState) { // Unblock async proxy for writeback cde::fence_proxy_async_shared_cta(); // Writeback - cde::cp_async_bulk_tensor_2d_shared_to_global(&tensorState, /*x*/ i_write, - /*y*/ state_offset, &sram.state[stage][0]); - cde::cp_async_bulk_commit_group(); - cde::cp_async_bulk_wait_group_read<0>(); + if constexpr (writeState) { + cde::cp_async_bulk_tensor_4d_shared_to_global(&tensorState, i_write, 0, head, batch, + &sram.state[stage][0]); + cde::cp_async_bulk_commit_group(); + cde::cp_async_bulk_wait_group_read<0>(); + } // Read next - cde::cp_async_bulk_tensor_2d_global_to_shared(&sram.state[stage][0], &tensorState, - /*x*/ i_read, /*y*/ state_offset, - sram.bar_full[stage]); - auto const _ = cuda::device::barrier_arrive_tx(sram.bar_full[stage], 1, bytesToArrive); + if constexpr (readState) { + cde::cp_async_bulk_tensor_4d_global_to_shared(&sram.state[stage][0], &tensorState, i_read, + 0, head, batch, sram.bar_full[stage]); + auto const _ = cuda::device::barrier_arrive_tx(sram.bar_full[stage], 1, bytesToArrive); + } else { + auto const _ = sram.bar_full[stage].arrive(); + } } else { auto const _ = 
sram.bar_full[stage].arrive(); } @@ -620,11 +694,11 @@ __device__ __forceinline__ void producer_func_horizontal(SramT& sram, sram.bar_empty[stage].wait(sram.bar_empty[stage].arrive()); - if constexpr (useStateCache) { + if constexpr (writeState) { // Unblock async proxy for writeback cde::fence_proxy_async_shared_cta(); - cde::cp_async_bulk_tensor_2d_shared_to_global(&tensorState, /*x*/ i_write, - /*y*/ state_offset, &sram.state[stage][0]); + cde::cp_async_bulk_tensor_4d_shared_to_global(&tensorState, i_write, 0, head, batch, + &sram.state[stage][0]); cde::cp_async_bulk_commit_group(); cde::cp_async_bulk_wait_group_read<0>(); } @@ -781,15 +855,19 @@ __global__ void selective_state_update_kernel_producer_consumer_horizontal( if (warp == consumerWarps) // producer { - auto const state_offset = (state_batch * nheads + head) * DIM; + auto const read_state = (state_batch != params.pad_slot_id); + auto const write_state = read_state && params.update_state; cg::invoke_one(cg::coalesced_threads(), [&]() { - if (state_batch != params.pad_slot_id) - producer_func_horizontal( - sram, tensorState, state_offset); + if (read_state && write_state) + producer_func_horizontal( + sram, tensorState, state_batch, head); + else if (read_state && !write_state) + producer_func_horizontal( + sram, tensorState, state_batch, head); else - producer_func_horizontal( - sram, tensorState, state_offset); + producer_func_horizontal( + sram, tensorState, state_batch, head); }); } else { // consumers @@ -1000,19 +1078,20 @@ void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, cudaStream_t auto nh = params.nheads; auto dim = params.dim; - FLASHINFER_CHECK(reinterpret_cast(params.state) % 128 == - 0); // TMA requires 128B aligned - auto tensorState = tma::createTensorMap( - params.state, params.state_cache_size * nh * dim, DSTATE, rowsPerStage, DSTATE); + auto state_tensor = + tma::buildNdDescriptor(typeid(state_t), + /*shapes*/ {DSTATE, DIM, nh, params.state_cache_size}, + /*strides*/ 
{1, DSTATE, DSTATE * DIM, params.state_stride_batch}, + /*tiles*/ {DSTATE, rowsPerStage, 1, 1}, params.state); // Calculate shared memory size and opt-in to extended shared memory - using sram_t = SharedStorage; + using sram_t = SharedStorageVertical; constexpr size_t smem_size = sizeof(sram_t); FLASHINFER_CUDA_CHECK( cudaFuncSetAttribute(scan_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - scan_func<<>>(params, tensorState); + scan_func<<>>(params, state_tensor); }; auto dispatch_dstate = [&]() { @@ -1094,10 +1173,16 @@ void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, cudaStream_t auto nh = params.nheads; auto dim = params.dim; - FLASHINFER_CHECK(reinterpret_cast(params.state) % 128 == - 0); // TMA requires 128B aligned - auto tensorState = tma::createTensorMap( - params.state, params.state_cache_size * nh * dim, DSTATE, DIM, stageCols); + // FLASHINFER_CHECK(reinterpret_cast(params.state) % 128 == + // 0); // TMA requires 128B aligned + // auto tensorState = tma::createTensorMap( + // params.state, params.state_cache_size * nh * dim, DSTATE, DIM, stageCols); + + auto state_tensor = tma::buildNdDescriptor( + typeid(state_t), + /*shapes*/ {DSTATE, DIM, nh, params.state_cache_size}, + /*strides*/ {1, DSTATE, DSTATE * DIM, params.state_stride_batch}, + /*tiles*/ {stageCols, DIM, 1, 1}, params.state); static_assert(DSTATE % stageCols == 0 && DSTATE >= stageCols); // Calculate shared memory size and opt-in to extended shared memory @@ -1107,7 +1192,7 @@ void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, cudaStream_t FLASHINFER_CUDA_CHECK(cudaFuncSetAttribute( scan_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - scan_func<<>>(params, tensorState); + scan_func<<>>(params, state_tensor); return true; } return false; From 55416243399530cddedc439defac4c3f7ea2aff2 Mon Sep 17 00:00:00 2001 From: Igor Shovkun Date: Thu, 29 Jan 2026 11:17:40 -0800 Subject: [PATCH 07/33] Simplify code for template 
dispatching --- csrc/selective_state_update.cu | 136 +++++++++++------- .../mamba/test_selective_state_update_stp.py | 106 +++++++++++++- 2 files changed, 184 insertions(+), 58 deletions(-) diff --git a/csrc/selective_state_update.cu b/csrc/selective_state_update.cu index bd7d96f74a..a3b0c9bee7 100644 --- a/csrc/selective_state_update.cu +++ b/csrc/selective_state_update.cu @@ -14,6 +14,7 @@ * limitations under the License. */ #include +#include #include "tvm_ffi_utils.h" @@ -125,6 +126,51 @@ inline void validate_dtypes( } } +// Helper to convert dtype code to string for error messages +inline const char* dtype_code_to_string(int64_t code) { + if (code == bfloat16_code) return "bfloat16"; + if (code == float16_code) return "float16"; + if (code == float32_code) return "float32"; + return "unknown"; +} + +// Type traits to map dtype codes to C++ types +template +struct DTypeToType; + +template <> +struct DTypeToType { + using type = nv_bfloat16; +}; +template <> +struct DTypeToType { + using type = half; +}; +template <> +struct DTypeToType { + using type = float; +}; + +// Allowed dtype combinations: {state_code, input_code, weight_code, matrixA_code} +constexpr std::tuple allowed_dtype_combos[] = { + {bfloat16_code, bfloat16_code, bfloat16_code, float32_code}, + {float16_code, bfloat16_code, bfloat16_code, float32_code}, + {float32_code, bfloat16_code, bfloat16_code, float32_code}, + {bfloat16_code, bfloat16_code, float32_code, float32_code}, + {float16_code, bfloat16_code, float32_code, float32_code}, + {float32_code, bfloat16_code, float32_code, float32_code}, +}; + +// Helper to dispatch to the right template instantiation +template +void dispatchCombo(SelectiveStateUpdateParams& p, cudaStream_t stream) { + using state_t = typename DTypeToType::type; + using input_t = typename DTypeToType::type; + using weight_t = typename DTypeToType::type; + using matrixA_t = typename DTypeToType::type; + invokeSelectiveStateUpdate(p, stream); +} + void 
run_selective_state_update_stp(TensorView const& state, TensorView const& x, TensorView const& dt, TensorView const& A, TensorView const& B, TensorView const& C, TensorView const& D, @@ -286,61 +332,41 @@ void run_selective_state_update_stp(TensorView const& state, TensorView const& x auto dtype_key = std::make_tuple(state_dtype_code, input_dtype_code, weight_dtype_code, matrixA_dtype_code); - if (dtype_key == std::make_tuple(bfloat16_code, bfloat16_code, bfloat16_code, float32_code)) { - using state_t = nv_bfloat16; - using input_t = nv_bfloat16; - using weight_t = nv_bfloat16; - using matrixA_t = float; - invokeSelectiveStateUpdate(p, stream); - } else if (dtype_key == - std::make_tuple(float16_code, bfloat16_code, bfloat16_code, float32_code)) { - using state_t = half; - using input_t = nv_bfloat16; - using weight_t = nv_bfloat16; - using matrixA_t = float; - invokeSelectiveStateUpdate(p, stream); - } else if (dtype_key == - std::make_tuple(float32_code, bfloat16_code, bfloat16_code, float32_code)) { - using state_t = float; - using input_t = nv_bfloat16; - using weight_t = nv_bfloat16; - using matrixA_t = float; - invokeSelectiveStateUpdate(p, stream); - } else if (dtype_key == - std::make_tuple(bfloat16_code, bfloat16_code, float32_code, float32_code)) { - using state_t = nv_bfloat16; - using input_t = nv_bfloat16; - using weight_t = float; - using matrixA_t = float; - invokeSelectiveStateUpdate(p, stream); - } else if (dtype_key == - std::make_tuple(float16_code, bfloat16_code, float32_code, float32_code)) { - using state_t = half; - using input_t = nv_bfloat16; - using weight_t = float; - using matrixA_t = float; - invokeSelectiveStateUpdate(p, stream); - } else if (dtype_key == - std::make_tuple(float32_code, bfloat16_code, float32_code, float32_code)) { - using state_t = float; - using input_t = nv_bfloat16; - using weight_t = float; - using matrixA_t = float; - invokeSelectiveStateUpdate(p, stream); - } else { - // Default case: unsupported dtype 
combination - TVM_FFI_ICHECK(false) - << "Unsupported dtype combination for selective_state_update: " << "state_dtype=" - << state_dtype.code << ":" << state_dtype.bits << ", " << "input_dtype=" << input_dtype.code - << ":" << input_dtype.bits << ", " << "weight_dtype=" << weight_dtype.code << ":" - << weight_dtype.bits << ", " << "matrixA_dtype=" << matrixA_dtype.code << ":" - << matrixA_dtype.bits << ". Supported combos include:\n" - << " (state=bfloat16, input=bfloat16, weight=bfloat16, matrixA=float32)\n" - << " (state=float16, input=bfloat16, weight=bfloat16, matrixA=float32)\n" - << " (state=float32, input=bfloat16, weight=bfloat16, matrixA=float32)\n" - << " (state=bfloat16, input=bfloat16, weight=float32, matrixA=float32)\n" - << " (state=float16, input=bfloat16, weight=float32, matrixA=float32)\n" - << " (state=float32, input=bfloat16, weight=float32, matrixA=float32)"; + // Compile-time recursive dispatcher using Y-combinator pattern for lambda self-recursion + auto tryDispatch = [&](const auto& key, auto idx, auto& self) -> bool { + constexpr size_t I = decltype(idx)::value; + if constexpr (I < std::size(allowed_dtype_combos)) { + constexpr auto combo = allowed_dtype_combos[I]; + if (key == combo) { + constexpr auto s = std::get<0>(combo); + constexpr auto i = std::get<1>(combo); + constexpr auto w = std::get<2>(combo); + constexpr auto m = std::get<3>(combo); + dispatchCombo(p, stream); + return true; + } + return self(key, std::integral_constant{}, self); + } + return false; + }; + + // Dispatch using compile-time type traits + if (!tryDispatch(dtype_key, std::integral_constant{}, tryDispatch)) { + // Unsupported dtype combination - build error message dynamically + std::ostringstream error_msg; + error_msg << "Unsupported dtype combination for selective_state_update: " + << "state_dtype=" << state_dtype.code << ":" << state_dtype.bits << ", " + << "input_dtype=" << input_dtype.code << ":" << input_dtype.bits << ", " + << "weight_dtype=" << 
weight_dtype.code << ":" << weight_dtype.bits << ", " + << "matrixA_dtype=" << matrixA_dtype.code << ":" << matrixA_dtype.bits + << ". Supported combos include:\n"; + for (const auto& combo : allowed_dtype_combos) { + error_msg << " (state=" << dtype_code_to_string(std::get<0>(combo)) + << ", input=" << dtype_code_to_string(std::get<1>(combo)) + << ", weight=" << dtype_code_to_string(std::get<2>(combo)) + << ", matrixA=" << dtype_code_to_string(std::get<3>(combo)) << ")\n"; + } + TVM_FFI_ICHECK(false) << error_msg.str(); } } diff --git a/tests/mamba/test_selective_state_update_stp.py b/tests/mamba/test_selective_state_update_stp.py index 331386ef67..9e87cbe0c8 100644 --- a/tests/mamba/test_selective_state_update_stp.py +++ b/tests/mamba/test_selective_state_update_stp.py @@ -244,15 +244,115 @@ def inputs(self, batch, nheads, dim, dstate, state_dtype, weight_dtype): ) -@pytest.mark.xfail(reason="Non-contiguous state cache not yet supported") +class TestSelectiveStateUpdateDisableStateUpdate(TestSelectiveStateUpdate): + """Test selective_state_update with disable_state_update=True.""" + + @pytest.fixture(params=[1]) + def batch(self, request): + return request.param + + @pytest.fixture(params=[8]) + def nheads(self, request): + return request.param + + @pytest.fixture(params=[128]) + def dim(self, request): + return request.param + + @pytest.fixture(params=[64]) + def dstate(self, request): + return request.param + + @pytest.fixture(params=[torch.bfloat16]) + def state_dtype(self, request): + return request.param + + @pytest.fixture(params=[torch.bfloat16]) + def weight_dtype(self, request): + return request.param + + def run_kernel(self, inputs, out=None): + """Run the flashinfer kernel with disable_state_update=True.""" + return flashinfer.mamba.selective_state_update( + inputs["state_cache"], + inputs["x"], + inputs["dt"], + inputs["A"], + inputs["B"], + inputs["C"], + D=inputs["D"], + z=inputs.get("z"), + dt_bias=inputs["dt_bias"], + dt_softplus=True, + 
state_batch_indices=inputs["slot_idx"], + pad_slot_id=-1, + out=out, + disable_state_update=True, + ) + + def test_output_correctness(self, inputs, reference_output, use_out_tensor): + """Test that kernel output matches reference but state is not updated.""" + y_ref, state_ref = reference_output + + # Save the initial state before running the kernel + state_initial = inputs["state_cache"].clone() + + # Prepare output tensor if requested + if use_out_tensor: + batch = inputs["x"].shape[0] + nheads = inputs["x"].shape[1] + dim = inputs["x"].shape[2] + out = torch.empty(batch, nheads, dim, dtype=self.INPUT_DTYPE, device="cuda") + else: + out = None + + y_test = self.run_kernel(inputs, out=out) + + # Verify output tensor identity if provided + if use_out_tensor: + assert y_test.data_ptr() == out.data_ptr(), ( + "Returned tensor should be the same object as the provided output tensor" + ) + + # Check that output is still correct + self.assert_outputs_match(y_ref, y_test, msg_prefix="[disable_state_update] ") + + # Check that state was NOT updated (should remain the same as initial) + state_after = inputs["state_cache"] + state_unchanged = torch.allclose( + state_initial, state_after, atol=1e-8, rtol=1e-8 + ) + + if state_unchanged: + print("✓ [disable_state_update] State cache was not modified (as expected)") + else: + print( + "✗ [disable_state_update] State cache was modified (should remain unchanged!)" + ) + # Show where state changed + state_initial_np = state_initial.detach().cpu().float().numpy() + state_after_np = state_after.detach().cpu().float().numpy() + mismatch_mask = ~np.isclose( + state_initial_np, state_after_np, atol=1e-8, rtol=1e-8 + ) + num_changed = np.sum(mismatch_mask) + print( + f"Number of changed state elements: {num_changed} / {state_initial_np.size}" + ) + + assert state_unchanged, ( + "State should not be updated when disable_state_update=True" + ) + + class TestSelectiveStateUpdateNonContiguous(TestSelectiveStateUpdate): """Test 
selective_state_update with non-contiguous state cache.""" - @pytest.fixture(params=[64, 128]) + @pytest.fixture(params=[128]) def dstate(self, request): return request.param - @pytest.fixture(params=[torch.bfloat16, torch.float32]) + @pytest.fixture(params=[torch.bfloat16]) def state_dtype(self, request): return request.param From ab33cc17d663e5879aa1bc4769d8ff8af977f608 Mon Sep 17 00:00:00 2001 From: Igor Shovkun Date: Thu, 29 Jan 2026 12:18:08 -0800 Subject: [PATCH 08/33] Refactor dispatch logic in selective_state_update.cuh Extract dispatchDimDstate and dispatchRatio helpers to simplify kernel dispatch code and reduce duplication. --- .../mamba/selective_state_update.cuh | 211 ++++++++---------- 1 file changed, 89 insertions(+), 122 deletions(-) diff --git a/include/flashinfer/mamba/selective_state_update.cuh b/include/flashinfer/mamba/selective_state_update.cuh index 41c9830d63..e95054b21e 100644 --- a/include/flashinfer/mamba/selective_state_update.cuh +++ b/include/flashinfer/mamba/selective_state_update.cuh @@ -959,6 +959,61 @@ std::string format_array(const T (&arr)[N]) { return oss.str(); } +// Helper function to dispatch dim and dstate with a kernel launcher +template +void dispatchDimDstate(SelectiveStateUpdateParams& params, + std::integer_sequence, + std::integer_sequence, KernelLauncher&& launcher) { + constexpr int allowed_dims[] = {AllowedDims...}; + constexpr int allowed_dstates[] = {AllowedDstates...}; + + auto dispatch_dim_dstate = [&]() { + launcher.template operator()(); + }; + + auto dispatch_dstate = [&]() { + if (params.dstate == DSTATE) { + dispatch_dim_dstate.template operator()(); + return true; + } + return false; + }; + + auto dispatch_dim = [&]() { + if (params.dim == DIM) { + bool dispatched = (dispatch_dstate.template operator()() || ...); + FLASHINFER_CHECK(dispatched, "Unsupported dstate value: ", params.dstate, + ".\nSupported values: ", format_array(allowed_dstates)); + return true; + } + return false; + }; + + bool 
dim_dispatched = (dispatch_dim.template operator()() || ...); + FLASHINFER_CHECK(dim_dispatched, "Unsupported dim value: ", params.dim, + ".\nSupported values: ", format_array(allowed_dims)); +} + +// Helper function to dispatch ratio with a kernel launcher +template +void dispatchRatio(SelectiveStateUpdateParams& params, std::integer_sequence, + KernelLauncher&& launcher) { + constexpr int allowed_ratios[] = {AllowedRatios...}; + + auto dispatch_single_ratio = [&]() { + if (params.nheads / params.ngroups == RATIO) { + launcher.template operator()(); + return true; + } + return false; + }; + + bool ratio_dispatched = (dispatch_single_ratio.template operator()() || ...); + FLASHINFER_CHECK(ratio_dispatched, + "Unsupported nheads/ngroups ratio: ", params.nheads / params.ngroups, + ".\nSupported values: ", format_array(allowed_ratios)); +} + template void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, cudaStream_t stream) { auto [sm_major, sm_minor] = GetCudaComputeCapability(); @@ -972,7 +1027,7 @@ void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, cudaStream_t if (sm_major < 9) // pre-Hopper #endif { - auto dispatch_dim_dstate = [&]() { + auto kernel_launcher = [&]() { // Alignment checks for vectorized loads in simple kernel constexpr auto stateLoadSize = getVectorLoadSizeForFullUtilization(); using load_state_t = PackedAligned; @@ -1008,36 +1063,13 @@ void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, cudaStream_t numWarps><<>>(params); }; - auto dispatch_dstate = [&]() { - if (params.dstate == DSTATE) { - dispatch_dim_dstate.template operator()(); - return true; - } - return false; - }; - - auto dispatch_dim = [&]() { - if (params.dim == DIM) { - bool dispatched = [&](std::integer_sequence) { - return (dispatch_dstate.template operator()() || ...); - }(std::make_integer_sequence{}); - FLASHINFER_CHECK(dispatched, "Unsupported dstate value: ", params.dstate, - ".\nSupported values: ", 
format_array(allowed_dstates)); - return true; - } - return false; - }; - - bool dim_dispatched = [&](std::integer_sequence) { - return (dispatch_dim.template operator()() || ...); - }(std::make_integer_sequence{}); - FLASHINFER_CHECK(dim_dispatched, "Unsupported dim value: ", params.dim, - ".\nSupported values: ", format_array(allowed_dims)); + dispatchDimDstate(params, std::integer_sequence{}, + std::integer_sequence{}, kernel_launcher); } #ifdef FLASHINFER_MAMBA_ENABLE_SM90 else { - auto dispatch_dim_dstate = [&]() { + auto kernel_launcher = [&]() { // Alignment checks for vectorized loads in Hopper kernel // Note: State uses TMA which requires 128B alignment (checked below) // x, z, B, and C use PackedAligned @@ -1094,40 +1126,15 @@ void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, cudaStream_t scan_func<<>>(params, state_tensor); }; - auto dispatch_dstate = [&]() { - if (params.dstate == DSTATE) { - dispatch_dim_dstate.template operator()(); - return true; - } - return false; - }; - - auto dispatch_dim = [&]() { - if (params.dim == DIM) { - bool dispatched = [&](std::integer_sequence) { - return (dispatch_dstate.template operator()() || ...); - }(std::make_integer_sequence{}); - FLASHINFER_CHECK(dispatched, "Unsupported dstate value: ", params.dstate, - ".\nSupported values: ", format_array(allowed_dstates)); - return true; - } - return false; - }; - - bool dim_dispatched = [&](std::integer_sequence) { - return (dispatch_dim.template operator()() || ...); - }(std::make_integer_sequence{}); - FLASHINFER_CHECK(dim_dispatched, "Unsupported dim value: ", params.dim, - ".\nSupported values: ", format_array(allowed_dims)); + dispatchDimDstate(params, std::integer_sequence{}, + std::integer_sequence{}, kernel_launcher); } #endif #ifdef FLASHINFER_MAMBA_ENABLE_SM100 else { // SM100+ (Blackwell and newer) uses horizontal producer-consumer kernel - constexpr int allowed_heads_groups_ratios[] = {1, 8, 16}; - - auto dispatch_dim_dstate = [&]() { + auto 
kernel_launcher = [&]() { // Alignment checks for vectorized loads in Blackwell kernel using load_input_t = PackedAligned; @@ -1161,78 +1168,38 @@ void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, cudaStream_t constexpr auto totalStages = DSTATE / stageCols; constexpr auto numStages = (totalStages >= 4) ? 4 : totalStages; - auto dispatch_ratio = [&]() { - if (params.nheads / params.ngroups == RATIO) { - auto scan_func = selective_state_update_kernel_producer_consumer_horizontal< - input_t, weight_t, matrixA_t, state_t, DIM, DSTATE, numConsumers, stageCols, RATIO, - numStages>; - - dim3 block(warpSize, numWarps); - dim3 grid(params.batch, params.nheads); - - auto nh = params.nheads; - auto dim = params.dim; - - // FLASHINFER_CHECK(reinterpret_cast(params.state) % 128 == - // 0); // TMA requires 128B aligned - // auto tensorState = tma::createTensorMap( - // params.state, params.state_cache_size * nh * dim, DSTATE, DIM, stageCols); - - auto state_tensor = tma::buildNdDescriptor( - typeid(state_t), - /*shapes*/ {DSTATE, DIM, nh, params.state_cache_size}, - /*strides*/ {1, DSTATE, DSTATE * DIM, params.state_stride_batch}, - /*tiles*/ {stageCols, DIM, 1, 1}, params.state); - static_assert(DSTATE % stageCols == 0 && DSTATE >= stageCols); - - // Calculate shared memory size and opt-in to extended shared memory - using sram_t = SharedStorageHorizontal; - constexpr size_t smem_size = sizeof(sram_t); - FLASHINFER_CUDA_CHECK(cudaFuncSetAttribute( - scan_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - - scan_func<<>>(params, state_tensor); - return true; - } - return false; - }; + auto ratio_launcher = [&]() { + auto scan_func = selective_state_update_kernel_producer_consumer_horizontal< + input_t, weight_t, matrixA_t, state_t, DIM, DSTATE, numConsumers, stageCols, RATIO, + numStages>; - bool ratio_dispatched = - [&](std::integer_sequence) { - return (dispatch_ratio.template operator()() || ...); - }(std::make_integer_sequence{}); - 
FLASHINFER_CHECK(ratio_dispatched, - "Unsupported nheads/ngroups ratio: ", params.nheads / params.ngroups, - ".\nSupported values: ", format_array(allowed_heads_groups_ratios)); - }; + dim3 block(warpSize, numWarps); + dim3 grid(params.batch, params.nheads); - auto dispatch_dstate = [&]() { - if (params.dstate == DSTATE) { - dispatch_dim_dstate.template operator()(); - return true; - } - return false; - }; + auto nh = params.nheads; + auto dim = params.dim; - auto dispatch_dim = [&]() { - if (params.dim == DIM) { - bool dispatched = [&](std::integer_sequence) { - return (dispatch_dstate.template operator()() || ...); - }(std::make_integer_sequence{}); - FLASHINFER_CHECK(dispatched, "Unsupported dstate value: ", params.dstate, - ".\nSupported values: ", format_array(allowed_dstates)); - return true; - } - return false; + auto state_tensor = + tma::buildNdDescriptor(typeid(state_t), + /*shapes*/ {DSTATE, DIM, nh, params.state_cache_size}, + /*strides*/ {1, DSTATE, DSTATE * DIM, params.state_stride_batch}, + /*tiles*/ {stageCols, DIM, 1, 1}, params.state); + static_assert(DSTATE % stageCols == 0 && DSTATE >= stageCols); + + using sram_t = SharedStorageHorizontal; + constexpr size_t smem_size = sizeof(sram_t); + FLASHINFER_CUDA_CHECK(cudaFuncSetAttribute( + scan_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + scan_func<<>>(params, state_tensor); + }; + + dispatchRatio(params, std::integer_sequence{}, ratio_launcher); }; - bool dim_dispatched = [&](std::integer_sequence) { - return (dispatch_dim.template operator()() || ...); - }(std::make_integer_sequence{}); - FLASHINFER_CHECK(dim_dispatched, "Unsupported dim value: ", params.dim, - ".\nSupported values: ", format_array(allowed_dims)); + dispatchDimDstate(params, std::integer_sequence{}, + std::integer_sequence{}, kernel_launcher); } #endif } From 26271a9faae6998d8d197558a0e7de4e712c3332 Mon Sep 17 00:00:00 2001 From: Igor Shovkun Date: Thu, 29 Jan 2026 12:38:48 -0800 Subject: [PATCH 09/33] 
Refactor pointer alignement checking away from the logic. --- .../mamba/selective_state_update.cuh | 93 ++++++------------- 1 file changed, 28 insertions(+), 65 deletions(-) diff --git a/include/flashinfer/mamba/selective_state_update.cuh b/include/flashinfer/mamba/selective_state_update.cuh index e95054b21e..e9af5b227b 100644 --- a/include/flashinfer/mamba/selective_state_update.cuh +++ b/include/flashinfer/mamba/selective_state_update.cuh @@ -1014,6 +1014,30 @@ void dispatchRatio(SelectiveStateUpdateParams& params, std::integer_sequence +void check_ptr_alignment_input_vars(const SelectiveStateUpdateParams& params) { + using load_input_t = PackedAligned; + FLASHINFER_CHECK(reinterpret_cast(params.x) % sizeof(load_input_t) == 0, + "x pointer must be aligned to ", sizeof(load_input_t), " bytes"); + FLASHINFER_CHECK((params.x_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0, + "x batch stride must be aligned to ", sizeof(load_input_t), " bytes"); + if (params.z) { + FLASHINFER_CHECK(reinterpret_cast(params.z) % sizeof(load_input_t) == 0, + "z pointer must be aligned to ", sizeof(load_input_t), " bytes"); + FLASHINFER_CHECK((params.z_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0, + "z batch stride must be aligned to ", sizeof(load_input_t), " bytes"); + } + FLASHINFER_CHECK(reinterpret_cast(params.B) % sizeof(load_input_t) == 0, + "B pointer must be aligned to ", sizeof(load_input_t), " bytes"); + FLASHINFER_CHECK(reinterpret_cast(params.C) % sizeof(load_input_t) == 0, + "C pointer must be aligned to ", sizeof(load_input_t), " bytes"); + FLASHINFER_CHECK((params.B_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0, + "B batch stride must be aligned to ", sizeof(load_input_t), " bytes"); + FLASHINFER_CHECK((params.C_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0, + "C batch stride must be aligned to ", sizeof(load_input_t), " bytes"); +} + template void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, 
cudaStream_t stream) { auto [sm_major, sm_minor] = GetCudaComputeCapability(); @@ -1021,6 +1045,9 @@ void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, cudaStream_t constexpr int allowed_dstates[] = {64, 128, 256}; constexpr int allowed_dims[] = {64, 128, 256}; + // Common alignment checks for all kernels + check_ptr_alignment_input_vars(params); + #ifdef FLASHINFER_MAMBA_ENABLE_SM100 if (sm_major < 10) // pre-Blackwell #elif defined(FLASHINFER_MAMBA_ENABLE_SM90) @@ -1028,31 +1055,12 @@ void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, cudaStream_t #endif { auto kernel_launcher = [&]() { - // Alignment checks for vectorized loads in simple kernel + // Additional alignment checks specific to simple kernel constexpr auto stateLoadSize = getVectorLoadSizeForFullUtilization(); using load_state_t = PackedAligned; - using load_input_t = PackedAligned; FLASHINFER_CHECK(reinterpret_cast(params.state) % sizeof(load_state_t) == 0, "state pointer must be aligned to ", sizeof(load_state_t), " bytes"); - FLASHINFER_CHECK(reinterpret_cast(params.x) % sizeof(load_input_t) == 0, - "x pointer must be aligned to ", sizeof(load_input_t), " bytes"); - FLASHINFER_CHECK((params.x_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0, - "x batch stride must be aligned to ", sizeof(load_input_t), " bytes"); - if (params.z) { - FLASHINFER_CHECK(reinterpret_cast(params.z) % sizeof(load_input_t) == 0, - "z pointer must be aligned to ", sizeof(load_input_t), " bytes"); - FLASHINFER_CHECK((params.z_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0, - "z batch stride must be aligned to ", sizeof(load_input_t), " bytes"); - } - FLASHINFER_CHECK(reinterpret_cast(params.B) % sizeof(load_input_t) == 0, - "B pointer must be aligned to ", sizeof(load_input_t), " bytes"); - FLASHINFER_CHECK(reinterpret_cast(params.C) % sizeof(load_input_t) == 0, - "C pointer must be aligned to ", sizeof(load_input_t), " bytes"); - 
FLASHINFER_CHECK((params.B_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0, - "B batch stride must be aligned to ", sizeof(load_input_t), " bytes"); - FLASHINFER_CHECK((params.C_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0, - "C batch stride must be aligned to ", sizeof(load_input_t), " bytes"); FLASHINFER_CHECK((params.dim * params.dstate * sizeof(state_t)) % sizeof(load_state_t) == 0, "state head stride must be aligned to ", sizeof(load_state_t), " bytes"); @@ -1070,30 +1078,7 @@ void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, cudaStream_t else { auto kernel_launcher = [&]() { - // Alignment checks for vectorized loads in Hopper kernel // Note: State uses TMA which requires 128B alignment (checked below) - // x, z, B, and C use PackedAligned - using load_input_t = PackedAligned; - - FLASHINFER_CHECK(reinterpret_cast(params.x) % sizeof(load_input_t) == 0, - "x pointer must be aligned to ", sizeof(load_input_t), " bytes"); - FLASHINFER_CHECK((params.x_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0, - "x batch stride must be aligned to ", sizeof(load_input_t), " bytes"); - if (params.z) { - FLASHINFER_CHECK(reinterpret_cast(params.z) % sizeof(load_input_t) == 0, - "z pointer must be aligned to ", sizeof(load_input_t), " bytes"); - FLASHINFER_CHECK((params.z_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0, - "z batch stride must be aligned to ", sizeof(load_input_t), " bytes"); - } - FLASHINFER_CHECK(reinterpret_cast(params.B) % sizeof(load_input_t) == 0, - "B pointer must be aligned to ", sizeof(load_input_t), " bytes"); - FLASHINFER_CHECK(reinterpret_cast(params.C) % sizeof(load_input_t) == 0, - "C pointer must be aligned to ", sizeof(load_input_t), " bytes"); - FLASHINFER_CHECK((params.B_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0, - "B batch stride must be aligned to ", sizeof(load_input_t), " bytes"); - FLASHINFER_CHECK((params.C_stride_batch * sizeof(input_t)) % 
sizeof(load_input_t) == 0, - "C batch stride must be aligned to ", sizeof(load_input_t), " bytes"); - constexpr auto numConsumers = 4; constexpr auto numWarps = 1 + numConsumers; constexpr auto numStages = 3; @@ -1135,28 +1120,6 @@ void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, cudaStream_t else { // SM100+ (Blackwell and newer) uses horizontal producer-consumer kernel auto kernel_launcher = [&]() { - // Alignment checks for vectorized loads in Blackwell kernel - using load_input_t = PackedAligned; - - FLASHINFER_CHECK(reinterpret_cast(params.x) % sizeof(load_input_t) == 0, - "x pointer must be aligned to ", sizeof(load_input_t), " bytes"); - FLASHINFER_CHECK((params.x_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0, - "x batch stride must be aligned to ", sizeof(load_input_t), " bytes"); - if (params.z) { - FLASHINFER_CHECK(reinterpret_cast(params.z) % sizeof(load_input_t) == 0, - "z pointer must be aligned to ", sizeof(load_input_t), " bytes"); - FLASHINFER_CHECK((params.z_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0, - "z batch stride must be aligned to ", sizeof(load_input_t), " bytes"); - } - FLASHINFER_CHECK(reinterpret_cast(params.B) % sizeof(load_input_t) == 0, - "B pointer must be aligned to ", sizeof(load_input_t), " bytes"); - FLASHINFER_CHECK(reinterpret_cast(params.C) % sizeof(load_input_t) == 0, - "C pointer must be aligned to ", sizeof(load_input_t), " bytes"); - FLASHINFER_CHECK((params.B_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0, - "B batch stride must be aligned to ", sizeof(load_input_t), " bytes"); - FLASHINFER_CHECK((params.C_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0, - "C batch stride must be aligned to ", sizeof(load_input_t), " bytes"); - // profiling showed that it's good to have 4 producers per 64 rows constexpr auto numConsumers = (DIM / 64) * 4; constexpr auto numProducers = 1; From f3f02f5fd42b9bbbb5cce776db5033c80213a9d9 Mon Sep 17 00:00:00 2001 From: 
Igor Shovkun Date: Thu, 29 Jan 2026 13:42:03 -0800 Subject: [PATCH 10/33] Support int32 and int64 state_batch_indices in selective_state_update - Add kernel and dispatcher support for int32/int64 state_batch_indices - Update tests to cover int32 indices - Fix test_utils to use int64 slot_idx by default Support int32 and int64 state_batch_indices in selective_state_update - Remove int32 type check to allow both int32 and int64 index types - Add stateIndex_t template parameter to kernels for index type dispatch - Extract kernel implementations to new selective_state_update_stp.cuh - Remove unused TMA helper functions from create_tensor_map.cuh - Add comprehensive MTP (multi-token prediction) test suite --- csrc/selective_state_update.cu | 537 ++++---- .../flashinfer/mamba/create_tensor_map.cuh | 75 -- .../mamba/selective_state_update.cuh | 1098 +--------------- .../mamba/selective_state_update_stp.cuh | 1105 +++++++++++++++++ .../mamba/test_selective_state_update_mtp.py | 714 +++++++++++ .../mamba/test_selective_state_update_stp.py | 60 +- tests/mamba/test_utils.py | 11 +- 7 files changed, 2190 insertions(+), 1410 deletions(-) create mode 100644 include/flashinfer/mamba/selective_state_update_stp.cuh create mode 100644 tests/mamba/test_selective_state_update_mtp.py diff --git a/csrc/selective_state_update.cu b/csrc/selective_state_update.cu index a3b0c9bee7..d925f1ccdb 100644 --- a/csrc/selective_state_update.cu +++ b/csrc/selective_state_update.cu @@ -76,7 +76,6 @@ inline void validate_state_batch_indices(Optional const& state_batch CHECK_CONTIGUOUS((*state_batch_indices)); FLASHINFER_CHECK(state_batch_indices.value().size(0) == batch, "state_batch_indices.shape must be (", batch, ")"); - CHECK_INPUT_TYPE(state_batch_indices.value(), dl_int32); } inline void validate_intermediate_state_indices( @@ -98,7 +97,7 @@ inline void validate_intermediate_states_buffer( } // Validates dtype consistency across tensors -inline void validate_dtypes( +inline void 
validate_dtype_consistency( TensorView const& state, TensorView const& dt, TensorView const& D, TensorView const& x, TensorView const& B, TensorView const& C, Optional const& dt_bias, Optional const& z = Optional(), @@ -150,25 +149,41 @@ template <> struct DTypeToType { using type = float; }; +template <> +struct DTypeToType { + using type = int32_t; +}; +template <> +struct DTypeToType { + using type = int64_t; +}; -// Allowed dtype combinations: {state_code, input_code, weight_code, matrixA_code} -constexpr std::tuple allowed_dtype_combos[] = { - {bfloat16_code, bfloat16_code, bfloat16_code, float32_code}, - {float16_code, bfloat16_code, bfloat16_code, float32_code}, - {float32_code, bfloat16_code, bfloat16_code, float32_code}, - {bfloat16_code, bfloat16_code, float32_code, float32_code}, - {float16_code, bfloat16_code, float32_code, float32_code}, - {float32_code, bfloat16_code, float32_code, float32_code}, +// Allowed dtype combinations: {state_code, input_code, weight_code, matrixA_code, stateIndex_code} +constexpr std::tuple allowed_dtype_combos[] = { + {bfloat16_code, bfloat16_code, bfloat16_code, float32_code, int32_code}, + {float16_code, bfloat16_code, bfloat16_code, float32_code, int32_code}, + {float32_code, bfloat16_code, bfloat16_code, float32_code, int32_code}, + {bfloat16_code, bfloat16_code, float32_code, float32_code, int32_code}, + {float16_code, bfloat16_code, float32_code, float32_code, int32_code}, + {float32_code, bfloat16_code, float32_code, float32_code, int32_code}, + {bfloat16_code, bfloat16_code, bfloat16_code, float32_code, int64_code}, + {float16_code, bfloat16_code, bfloat16_code, float32_code, int64_code}, + {float32_code, bfloat16_code, bfloat16_code, float32_code, int64_code}, + {bfloat16_code, bfloat16_code, float32_code, float32_code, int64_code}, + {float16_code, bfloat16_code, float32_code, float32_code, int64_code}, + {float32_code, bfloat16_code, float32_code, float32_code, int64_code}, }; // Helper to dispatch to the right 
template instantiation -template +template void dispatchCombo(SelectiveStateUpdateParams& p, cudaStream_t stream) { using state_t = typename DTypeToType::type; using input_t = typename DTypeToType::type; using weight_t = typename DTypeToType::type; using matrixA_t = typename DTypeToType::type; - invokeSelectiveStateUpdate(p, stream); + using stateIndex_t = typename DTypeToType::type; + invokeSelectiveStateUpdate(p, stream); } void run_selective_state_update_stp(TensorView const& state, TensorView const& x, @@ -261,7 +276,7 @@ void run_selective_state_update_stp(TensorView const& state, TensorView const& x } // Validate dtype consistency - validate_dtypes(state, dt, D, x, B, C, dt_bias, z, out); + validate_dtype_consistency(state, dt, D, x, B, C, dt_bias, z, out); // Initialize params struct SelectiveStateUpdateParams p; @@ -322,15 +337,265 @@ void run_selective_state_update_stp(TensorView const& state, TensorView const& x DLDataType input_dtype = x.dtype(); DLDataType weight_dtype = dt.dtype(); DLDataType matrixA_dtype = A.dtype(); + int64_t state_dtype_code = encode_dlpack_dtype(state_dtype); + int64_t input_dtype_code = encode_dlpack_dtype(input_dtype); + int64_t weight_dtype_code = encode_dlpack_dtype(weight_dtype); + int64_t matrixA_dtype_code = encode_dlpack_dtype(matrixA_dtype); + + // Get state_batch_indices dtype, default to int32 if not provided + int64_t stateIndex_dtype_code = int32_code; + if (state_batch_indices.has_value()) { + DLDataType stateIndex_dtype = state_batch_indices.value().dtype(); + stateIndex_dtype_code = encode_dlpack_dtype(stateIndex_dtype); + } + + // Dispatch kernel based on dtype combination + auto dtype_key = std::make_tuple(state_dtype_code, input_dtype_code, weight_dtype_code, + matrixA_dtype_code, stateIndex_dtype_code); + + // Compile-time recursive dispatcher using Y-combinator pattern for lambda self-recursion + auto tryDispatch = [&](const auto& key, auto idx, auto& self) -> bool { + constexpr size_t I = 
decltype(idx)::value; + if constexpr (I < std::size(allowed_dtype_combos)) { + constexpr auto combo = allowed_dtype_combos[I]; + if (key == combo) { + constexpr auto s = std::get<0>(combo); + constexpr auto i = std::get<1>(combo); + constexpr auto w = std::get<2>(combo); + constexpr auto m = std::get<3>(combo); + constexpr auto si = std::get<4>(combo); + dispatchCombo(p, stream); + return true; + } + return self(key, std::integral_constant{}, self); + } + return false; + }; + + // Dispatch using compile-time type traits + if (!tryDispatch(dtype_key, std::integral_constant{}, tryDispatch)) { + // Unsupported dtype combination - build error message dynamically + std::ostringstream error_msg; + error_msg << "Unsupported dtype combination for selective_state_update: " << "state_dtype=" + << state_dtype.code << ":" << state_dtype.bits << ", " + << "input_dtype=" << input_dtype.code << ":" << input_dtype.bits << ", " + << "weight_dtype=" << weight_dtype.code << ":" << weight_dtype.bits << ", " + << "matrixA_dtype=" << matrixA_dtype.code << ":" << matrixA_dtype.bits + << ". 
Supported combos include:\n"; + for (const auto& combo : allowed_dtype_combos) { + error_msg << " (state=" << dtype_code_to_string(std::get<0>(combo)) + << ", input=" << dtype_code_to_string(std::get<1>(combo)) + << ", weight=" << dtype_code_to_string(std::get<2>(combo)) + << ", matrixA=" << dtype_code_to_string(std::get<3>(combo)) << ")\n"; + } + TVM_FFI_ICHECK(false) << error_msg.str(); + } +} + +void run_selective_state_update_mtp( + TensorView const& state, TensorView const& x, TensorView const& dt, TensorView const& A, + TensorView const& B, TensorView const& C, TensorView const& D, Optional z, + Optional dt_bias, bool dt_softplus, Optional state_batch_indices, + int64_t pad_slot_id, Optional out, bool disable_state_update, + Optional intermediate_states_buffer, + Optional intermediate_state_indices, int64_t cache_steps) { + // Extract dimensions from input tensors + auto const batch = x.size(0); + auto const ntokens_mtp = x.size(1); + auto const state_cache_size = state.size(0); + auto const nheads = state.size(1); + auto const dim = state.size(2); + auto const dstate = state.size(3); + auto const ngroups = B.size(2); + + FLASHINFER_CHECK(state_cache_size >= batch, "state.size(0) must be >= x.size(0)"); + FLASHINFER_CHECK(nheads % ngroups == 0, "nheads must be divisible by ngroups"); + // Check x shape and strides + CHECK_CUDA(x); + CHECK_DIM(4, x); + FLASHINFER_CHECK(x.size(2) == nheads, "x.size(2) must equal nheads"); + FLASHINFER_CHECK(x.size(3) == dim, "x.size(3) must equal dim"); + CHECK_LAST_DIM_CONTIGUOUS(x); + FLASHINFER_CHECK(x.stride(2) == dim, "x.stride(2) must equal dim, got ", x.stride(2), + " expected ", dim); + + // Check dt shape and strides + CHECK_CUDA(dt); + CHECK_DIM(4, dt); // dt: {batch, ntokens_mtp, nheads, dim} + FLASHINFER_CHECK(dt.size(0) == batch, "dt.size(0) must equal batch =", batch); + FLASHINFER_CHECK(dt.size(1) == ntokens_mtp, "dt.size(1) must equal ntokens_mtp =", ntokens_mtp); + FLASHINFER_CHECK(dt.size(2) == nheads, 
"dt.size(2) must equal nheads"); + FLASHINFER_CHECK(dt.size(3) == dim, "dt.size(3) must equal dim"); + FLASHINFER_CHECK(dt.stride(2) == 1, "dt.stride(2) must be 1, got ", dt.stride(2)); + FLASHINFER_CHECK(dt.stride(3) == 0, "dt.stride(3) must be 0 (broadcasted), got ", dt.stride(3)); + + // Validate common tensors using helper functions + validate_state_tensor(state); + validate_D_tensor(D, nheads, dim); + validate_A_tensor(A, nheads, dim, dstate); + validate_dt_bias_tensor(dt_bias, nheads, dim); + validate_state_batch_indices(state_batch_indices, batch); + + // Check B shape and strides + CHECK_CUDA(B); + CHECK_DIM(4, B); // B: {batch, ntokens_mtp, ngroups, dstate} + FLASHINFER_CHECK(B.size(0) == batch, "B.size(0) must equal batch =", batch); + FLASHINFER_CHECK(B.size(1) == ntokens_mtp, "B.size(1) must equal ntokens_mtp =", ntokens_mtp); + FLASHINFER_CHECK(B.size(2) == ngroups, "B.size(2) must equal ngroups =", ngroups); + FLASHINFER_CHECK(B.size(3) == dstate, "B.size(3) must equal dstate =", dstate); + CHECK_LAST_DIM_CONTIGUOUS(B); + FLASHINFER_CHECK(B.stride(2) == dstate, "B.stride(2) must equal dstate, got ", B.stride(2), + " expected ", dstate); + + // Check C shape and strides + CHECK_CUDA(C); + CHECK_DIM(4, C); // C: {batch, ntokens_mtp, ngroups, dstate} + FLASHINFER_CHECK(C.size(0) == batch, "C.size(0) must equal batch"); + FLASHINFER_CHECK(C.size(1) == ntokens_mtp, "C.size(1) must equal ntokens_mtp =", ntokens_mtp); + FLASHINFER_CHECK(C.size(2) == ngroups, "C.size(2) must equal ngroups"); + FLASHINFER_CHECK(C.size(3) == dstate, "C.size(3) must equal dstate"); + CHECK_LAST_DIM_CONTIGUOUS(C); + FLASHINFER_CHECK(C.stride(2) == dstate, "C.stride(2) must equal dstate, got ", C.stride(2), + " expected ", dstate); + + // Optional z check + if (z.has_value()) { + auto& z_tensor = z.value(); + CHECK_CUDA(z_tensor); + CHECK_DIM(4, z_tensor); // z: {batch, ntokens_mtp, nheads, dim} + FLASHINFER_CHECK(z_tensor.size(0) == batch, "z.size(0) must equal batch"); + 
FLASHINFER_CHECK(z_tensor.size(1) == ntokens_mtp, "z.size(1) must equal ntokens_mtp"); + FLASHINFER_CHECK(z_tensor.size(2) == nheads, "z.size(2) must equal nheads"); + FLASHINFER_CHECK(z_tensor.size(3) == dim, "z.size(3) must equal dim"); + CHECK_LAST_DIM_CONTIGUOUS(z_tensor); + FLASHINFER_CHECK(z_tensor.stride(2) == dim, "z.stride(2) must equal dim, got ", + z_tensor.stride(2), " expected ", dim); + } + + // Check output tensor if provided + if (out.has_value()) { + auto& output = out.value(); + CHECK_CUDA(output); + CHECK_DIM(4, output); + FLASHINFER_CHECK(output.size(0) == batch, "out.size(0) must equal batch = ", batch); + FLASHINFER_CHECK(output.size(1) == ntokens_mtp, + "out.size(1) must equal ntokens_mtp = ", ntokens_mtp); + FLASHINFER_CHECK(output.size(2) == nheads, "out.size(2) must equal nheads = ", nheads); + FLASHINFER_CHECK(output.size(3) == dim, "out.size(3) must equal dim = ", dim); + CHECK_LAST_DIM_CONTIGUOUS(output); + FLASHINFER_CHECK(output.stride(2) == dim, "out.stride(2) = ", output.stride(2), + " must equal dim = ", dim); + } + + // Validate dtype consistency + validate_dtype_consistency(state, dt, D, x, B, C, dt_bias, z, out, intermediate_states_buffer); + validate_intermediate_state_indices(intermediate_state_indices, batch); + validate_intermediate_states_buffer(intermediate_states_buffer); + + // Validate cache_steps is non-negative + FLASHINFER_CHECK(cache_steps >= 0, "cache_steps must be non-negative, got ", cache_steps); + + // Initialize MTP params struct + mtp::SelectiveStateMTPParams p; + + // Copy dimensions (inherited from base) + p.batch = batch; + p.nheads = nheads; + p.dim = dim; + p.dstate = dstate; + p.ngroups = ngroups; + p.state_cache_size = state_cache_size; + p.dt_softplus = dt_softplus; + p.pad_slot_id = pad_slot_id; + + // MTP-specific dimensions + p.ntokens_mtp = ntokens_mtp; + p.cache_steps = static_cast(cache_steps); + p.update_state = !disable_state_update; + + // Copy batch strides (inherited) + p.x_stride_batch = 
x.stride(0); + p.dt_stride_batch = dt.stride(0); + p.B_stride_batch = B.stride(0); + p.C_stride_batch = C.stride(0); + if (out.has_value()) { + p.out_stride_batch = out.value().stride(0); + } else { + p.out_stride_batch = 0; + } + p.state_stride_batch = state.stride(0); + + // Copy MTP strides + p.x_stride_mtp = x.stride(1); + p.dt_stride_mtp = dt.stride(1); + p.B_stride_mtp = B.stride(1); + p.C_stride_mtp = C.stride(1); + if (out.has_value()) { + p.out_stride_mtp = out.value().stride(1); + } else { + p.out_stride_mtp = 0; + } + + if (state_batch_indices.has_value()) { + p.state_batch_indices = const_cast(state_batch_indices.value().data_ptr()); + } + + if (intermediate_states_buffer.has_value()) { + p.intermediate_states = const_cast(intermediate_states_buffer.value().data_ptr()); + p.intermediate_state_stride_batch = intermediate_states_buffer.value().stride(0); + } + + if (intermediate_state_indices.has_value()) { + p.intermediate_state_indices = const_cast(intermediate_state_indices.value().data_ptr()); + } + + // Copy pointers + p.state = const_cast(state.data_ptr()); + p.x = const_cast(x.data_ptr()); + p.dt = const_cast(dt.data_ptr()); + if (out.has_value()) { + p.output = out.value().data_ptr(); + } else { + p.output = nullptr; + } + if (dt_bias.has_value()) { + p.dt_bias = const_cast(dt_bias.value().data_ptr()); + } + if (z.has_value()) { + p.z = const_cast(z.value().data_ptr()); + p.z_stride_batch = z.value().stride(0); + p.z_stride_mtp = z.value().stride(1); + } + p.A = const_cast(A.data_ptr()); + p.B = const_cast(B.data_ptr()); + p.C = const_cast(C.data_ptr()); + p.D = const_cast(D.data_ptr()); + + // Set device and get stream + ffi::CUDADeviceGuard device_guard(state.device().device_id); + const cudaStream_t stream = get_stream(state.device()); + + // Dispatch based on dtype combination + DLDataType state_dtype = state.dtype(); + DLDataType input_dtype = x.dtype(); + DLDataType weight_dtype = dt.dtype(); + DLDataType matrixA_dtype = A.dtype(); int64_t 
state_dtype_code = encode_dlpack_dtype(state_dtype); int64_t input_dtype_code = encode_dlpack_dtype(input_dtype); int64_t weight_dtype_code = encode_dlpack_dtype(weight_dtype); int64_t matrixA_dtype_code = encode_dlpack_dtype(matrixA_dtype); + // Get intermediate_state_indices dtype, default to int32 if not provided + int64_t intermediateStateIndex_dtype_code = int32_code; + if (intermediate_state_indices.has_value()) { + DLDataType intermediateStateIndex_dtype = intermediate_state_indices.value().dtype(); + intermediateStateIndex_dtype_code = encode_dlpack_dtype(intermediateStateIndex_dtype); + } + // Dispatch kernel based on dtype combination - auto dtype_key = - std::make_tuple(state_dtype_code, input_dtype_code, weight_dtype_code, matrixA_dtype_code); + auto dtype_key = std::make_tuple(state_dtype_code, input_dtype_code, weight_dtype_code, + matrixA_dtype_code, intermediateStateIndex_dtype_code); // Compile-time recursive dispatcher using Y-combinator pattern for lambda self-recursion auto tryDispatch = [&](const auto& key, auto idx, auto& self) -> bool { @@ -342,7 +607,8 @@ void run_selective_state_update_stp(TensorView const& state, TensorView const& x constexpr auto i = std::get<1>(combo); constexpr auto w = std::get<2>(combo); constexpr auto m = std::get<3>(combo); - dispatchCombo(p, stream); + constexpr auto si = std::get<4>(combo); + dispatchCombo(p, stream); return true; } return self(key, std::integral_constant{}, self); @@ -354,8 +620,8 @@ void run_selective_state_update_stp(TensorView const& state, TensorView const& x if (!tryDispatch(dtype_key, std::integral_constant{}, tryDispatch)) { // Unsupported dtype combination - build error message dynamically std::ostringstream error_msg; - error_msg << "Unsupported dtype combination for selective_state_update: " - << "state_dtype=" << state_dtype.code << ":" << state_dtype.bits << ", " + error_msg << "Unsupported dtype combination for selective_state_update: " << "state_dtype=" + << state_dtype.code << ":" 
<< state_dtype.bits << ", " << "input_dtype=" << input_dtype.code << ":" << input_dtype.bits << ", " << "weight_dtype=" << weight_dtype.code << ":" << weight_dtype.bits << ", " << "matrixA_dtype=" << matrixA_dtype.code << ":" << matrixA_dtype.bits @@ -385,10 +651,9 @@ void selective_state_update(TensorView state, TensorView x, TensorView dt, Tenso state_batch_indices, pad_slot_id, output, disable_state_update); } else if (x.dim() == 4) { FLASHINFER_CHECK(false, "x must have 3 dimensions (single-token), got ", x.dim()); - // run_selective_state_update_mtp(state, x, dt, A, B, C, D, z, dt_bias, dt_softplus, - // state_batch_indices, pad_slot_id, output, - // disable_state_update, intermediate_states_buffer, - // intermediate_state_indices, cache_steps); + run_selective_state_update_mtp( + state, x, dt, A, B, C, D, z, dt_bias, dt_softplus, state_batch_indices, pad_slot_id, output, + disable_state_update, intermediate_states_buffer, intermediate_state_indices, cache_steps); } else { FLASHINFER_CHECK(false, "x must have 3 dimensions (single-token) or 4 dimensions (multi-token), got ", @@ -396,232 +661,4 @@ void selective_state_update(TensorView state, TensorView x, TensorView dt, Tenso } } -// Old function - commented out for reference -// void selective_state_update_old(TensorView state, TensorView x, TensorView dt, TensorView -// output, -// TensorView A, TensorView B, TensorView C, TensorView D, -// Optional z, Optional dt_bias, bool -// dt_softplus, Optional state_batch_indices, int64_t -// pad_slot_id) { -// auto const batch = x.size(0); -// auto const state_cache_size = state.size(0); -// auto const nheads = state.size(1); -// auto const dim = state.size(2); -// auto const dstate = state.size(3); -// auto const ngroups = B.size(1); - -// FLASHINFER_CHECK(state_cache_size >= batch, "state.size(0) must be >= x.size(0)"); - -// FLASHINFER_CHECK(nheads % ngroups == 0, "nheads must be divisible by ngroups"); - -// // Check x shape and strides -// CHECK_DIM(3, x); -// 
FLASHINFER_CHECK(x.size(1) == nheads, "x.size(1) must equal nheads"); -// FLASHINFER_CHECK(x.size(2) == dim, "x.size(2) must equal dim"); -// CHECK_LAST_DIM_CONTIGUOUS_INPUT(x); -// FLASHINFER_CHECK(x.stride(1) == dim, "x.stride(1) must equal dim, got ", x.stride(1), -// " expected ", x.size(2)); - -// // Check output shape and strides -// CHECK_DIM(3, output); -// CHECK_LAST_DIM_CONTIGUOUS(output); -// FLASHINFER_CHECK(output.size(1) == nheads, "output.size(1) must equal nheads"); -// FLASHINFER_CHECK(output.size(2) == dim, "output.size(2) must equal dim"); -// FLASHINFER_CHECK(output.stride(1) == dim, "output.stride(1) must equal dim"); - -// // Check dt shape and strides -// CHECK_CUDA(dt); -// CHECK_DIM(3, dt); // dt: {batch, nheads, dim} -// FLASHINFER_CHECK(dt.size(1) == nheads, "dt.size(1) must equal nheads"); -// FLASHINFER_CHECK(dt.size(2) == dim, "dt.size(2) must equal dim"); -// FLASHINFER_CHECK(dt.stride(1) == 1, "dt.stride(1) must be 1, got ", dt.stride(1)); -// FLASHINFER_CHECK(dt.stride(2) == 0, "dt.stride(2) must be 0 (broadcasted), got ", -// dt.stride(2)); - -// // Check state - fully contiguous -// CHECK_INPUT(state); // CUDA + fully contiguous (uses TVM FFI) -// CHECK_DIM(4, state); // state: {state_cache_size, nheads, dim, dstate} - -// // Check B shape and strides -// CHECK_CUDA(B); -// CHECK_DIM(3, B); // B: {batch, B.size(1), dstate} -// FLASHINFER_CHECK(B.size(0) == batch, "B.size(0) must equal batch"); -// FLASHINFER_CHECK(B.size(1) == ngroups, "B.size(1) must equal ngroups"); -// FLASHINFER_CHECK(B.size(2) == dstate, "B.size(2) must equal dstate"); -// CHECK_LAST_DIM_CONTIGUOUS(B); // stride(2) == 1 -// FLASHINFER_CHECK(B.stride(1) == B.size(2), "B.stride(1) must equal dstate, got ", B.stride(1), -// " expected ", B.size(2)); - -// // Check C shape and strides -// CHECK_CUDA(C); -// CHECK_LAST_DIM_CONTIGUOUS(C); // stride(2) == 1 -// CHECK_DIM(3, C); // C: {batch, C.size(1), dstate} -// FLASHINFER_CHECK(C.stride(1) == C.size(2), 
"C.stride(1) must equal dstate, got ", C.stride(1), -// " expected ", C.size(2)); -// FLASHINFER_CHECK(C.size(0) == batch, "C.size(0) must equal batch"); -// FLASHINFER_CHECK(C.size(1) == ngroups, "C.size(1) must equal ngroups"); -// FLASHINFER_CHECK(C.size(2) == dstate, "C.size(2) must equal dstate"); - -// // Check D - specific stride patterns indicating broadcasting -// CHECK_CUDA(D); -// CHECK_DIM(2, D); // D: {nheads, dim} -// FLASHINFER_CHECK(D.size(0) == nheads, "D.size(0) must equal nheads"); -// FLASHINFER_CHECK(D.size(1) == dim, "D.size(1) must equal dim"); -// FLASHINFER_CHECK(D.stride(0) == 1, "D.stride(0) must be 1, got ", D.stride(0)); -// FLASHINFER_CHECK(D.stride(1) == 0, "D.stride(1) must be 0 (broadcasted), got ", D.stride(1)); - -// // Check A - specific stride patterns indicating broadcasting -// CHECK_CUDA(A); -// CHECK_DIM(3, A); // A: {nheads, dim, dstate} -// FLASHINFER_CHECK(A.size(0) == nheads, "A.size(0) must equal nheads"); -// FLASHINFER_CHECK(A.size(1) == dim, "A.size(1) must equal dim"); -// FLASHINFER_CHECK(A.size(2) == dstate, "A.size(2) must equal dstate"); -// FLASHINFER_CHECK(A.stride(1) == 0, "A.stride(1) must be 0 (broadcasted), got ", A.stride(1)); -// FLASHINFER_CHECK(A.stride(2) == 0, "A.stride(2) must be 0 (broadcasted), got ", A.stride(2)); - -// // Optional dt_bias check -// if (dt_bias.has_value()) { -// auto& bias = dt_bias.value(); -// CHECK_CUDA(bias); -// CHECK_DIM(2, bias); // dt_bias: {nheads, dim} -// FLASHINFER_CHECK(bias.size(0) == nheads, "dt_bias.size(0) must equal nheads"); -// FLASHINFER_CHECK(bias.size(1) == dim, "dt_bias.size(1) must equal dim"); -// FLASHINFER_CHECK(bias.stride(0) == 1, "dt_bias.stride(0) must be 1, got ", bias.stride(0)); -// FLASHINFER_CHECK(bias.stride(1) == 0, "dt_bias.stride(1) must be 0 (broadcasted), got ", -// bias.stride(1)); -// } - -// if (z.has_value()) { -// auto& z_tensor = z.value(); -// CHECK_CUDA(z_tensor); -// CHECK_DIM(3, z_tensor); // z: {batch, nheads, dim} -// 
FLASHINFER_CHECK(z_tensor.size(0) == batch, "z.size(0) must equal batch"); -// FLASHINFER_CHECK(z_tensor.size(1) == nheads, "z.size(1) must equal nheads"); -// FLASHINFER_CHECK(z_tensor.size(2) == dim, "z.size(2) must equal dim"); -// CHECK_LAST_DIM_CONTIGUOUS_INPUT(z_tensor); -// FLASHINFER_CHECK(z_tensor.stride(1) == dim, "z.stride(1) must equal dim, got ", -// z_tensor.stride(1), " expected ", z_tensor.size(2)); -// } - -// if (state_batch_indices) { -// CHECK_DIM(1, (*state_batch_indices)); -// FLASHINFER_CHECK(state_batch_indices.value().size(0) == batch, -// "state_batch_indices.shape must be (", batch, ")"); -// } - -// SelectiveStateUpdateParams p; - -// // copy dimensions -// p.batch = batch; -// p.nheads = nheads; -// p.dim = dim; -// p.dstate = dstate; -// p.ngroups = ngroups; -// p.state_cache_size = state_cache_size; -// p.dt_softplus = dt_softplus; -// p.pad_slot_id = pad_slot_id; - -// // Copy strides -// p.x_stride_batch = x.stride(0); -// p.dt_stride_batch = dt.stride(0); -// p.B_stride_batch = B.stride(0); -// p.C_stride_batch = C.stride(0); -// p.out_stride_batch = output.stride(0); -// if (state_batch_indices) p.state_batch_indices = state_batch_indices.value().data_ptr(); - -// // Copy pointers -// p.state = state.data_ptr(); -// p.x = x.data_ptr(); -// p.dt = dt.data_ptr(); -// p.output = output.data_ptr(); -// if (dt_bias) { -// p.dt_bias = dt_bias.value().data_ptr(); -// } -// if (z) { -// p.z = z.value().data_ptr(); -// p.z_stride_batch = z.value().stride(0); -// } -// p.A = A.data_ptr(); -// p.B = B.data_ptr(); -// p.C = C.data_ptr(); -// p.D = D.data_ptr(); - -// // Set device and get stream -// ffi::CUDADeviceGuard device_guard(state.device().device_id); -// const cudaStream_t stream = get_stream(state.device()); - -// // Dispatch based on dtype combination -// DLDataType state_dtype = state.dtype(); -// DLDataType input_dtype = x.dtype(); -// DLDataType weight_dtype = dt.dtype(); -// DLDataType matrixA_dtype = A.dtype(); - -// int64_t 
state_dtype_code = encode_dlpack_dtype(state_dtype); -// int64_t input_dtype_code = encode_dlpack_dtype(input_dtype); -// int64_t weight_dtype_code = encode_dlpack_dtype(weight_dtype); -// int64_t matrixA_dtype_code = encode_dlpack_dtype(matrixA_dtype); - -// // Pack all dtype codes into a single value for switching -// auto dtype_key = -// std::make_tuple(state_dtype_code, input_dtype_code, weight_dtype_code, matrixA_dtype_code); - -// if (dtype_key == std::make_tuple(bfloat16_code, bfloat16_code, bfloat16_code, float32_code)) { -// using state_t = nv_bfloat16; -// using input_t = nv_bfloat16; -// using weight_t = nv_bfloat16; -// using matrixA_t = float; -// invokeSelectiveStateUpdate(p, stream); -// } else if (dtype_key == -// std::make_tuple(float16_code, bfloat16_code, bfloat16_code, float32_code)) { -// using state_t = half; -// using input_t = nv_bfloat16; -// using weight_t = nv_bfloat16; -// using matrixA_t = float; -// invokeSelectiveStateUpdate(p, stream); -// } else if (dtype_key == -// std::make_tuple(float32_code, bfloat16_code, bfloat16_code, float32_code)) { -// using state_t = float; -// using input_t = nv_bfloat16; -// using weight_t = nv_bfloat16; -// using matrixA_t = float; -// invokeSelectiveStateUpdate(p, stream); -// } else if (dtype_key == -// std::make_tuple(bfloat16_code, bfloat16_code, float32_code, float32_code)) { -// using state_t = nv_bfloat16; -// using input_t = nv_bfloat16; -// using weight_t = float; -// using matrixA_t = float; -// invokeSelectiveStateUpdate(p, stream); -// } else if (dtype_key == -// std::make_tuple(float16_code, bfloat16_code, float32_code, float32_code)) { -// using state_t = half; -// using input_t = nv_bfloat16; -// using weight_t = float; -// using matrixA_t = float; -// invokeSelectiveStateUpdate(p, stream); -// } else if (dtype_key == -// std::make_tuple(float32_code, bfloat16_code, float32_code, float32_code)) { -// using state_t = float; -// using input_t = nv_bfloat16; -// using weight_t = float; -// 
using matrixA_t = float; -// invokeSelectiveStateUpdate(p, stream); -// } else { -// // Default case: unsupported dtype combination -// TVM_FFI_ICHECK(false) -// << "Unsupported dtype combination for selective_state_update: " -// << "state_dtype=" << state_dtype.code << ":" << state_dtype.bits << ", " -// << "input_dtype=" << input_dtype.code << ":" << input_dtype.bits << ", " -// << "weight_dtype=" << weight_dtype.code << ":" << weight_dtype.bits << ", " -// << "matrixA_dtype=" << matrixA_dtype.code << ":" << matrixA_dtype.bits -// << ". Supported combos include:\n" -// << " (state=bfloat16, input=bfloat16, weight=bfloat16, matrixA=float32)\n" -// << " (state=float16, input=bfloat16, weight=bfloat16, matrixA=float32)\n" -// << " (state=float32, input=bfloat16, weight=bfloat16, matrixA=float32)\n" -// << " (state=bfloat16, input=bfloat16, weight=float32, matrixA=float32)\n" -// << " (state=float16, input=bfloat16, weight=float32, matrixA=float32)\n" -// << " (state=float32, input=bfloat16, weight=float32, matrixA=float32)"; -// } -// } - } // namespace flashinfer::mamba diff --git a/include/flashinfer/mamba/create_tensor_map.cuh b/include/flashinfer/mamba/create_tensor_map.cuh index d58ca94bf8..8d08cff69d 100644 --- a/include/flashinfer/mamba/create_tensor_map.cuh +++ b/include/flashinfer/mamba/create_tensor_map.cuh @@ -8,83 +8,8 @@ #include #include -#ifndef gpuErrchk -#define gpuErrchk(ans) \ - { \ - gpuAssert((ans), __FILE__, __LINE__); \ - } -#endif - -static inline void gpuAssert(cudaError_t code, const char* file, int line, bool abort = true) { - if (code != cudaSuccess) { - fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line); - std::cout << "GPU assert failed" << std::endl; - if (abort) exit(code); - } -} namespace flashinfer::mamba::tma { -// namespace cde = cuda::device::experimental; - -static inline PFN_cuTensorMapEncodeTiled_v12000 get_cuTensorMapEncodeTiled() { - // Get pointer to cuTensorMapEncodeTiled - 
cudaDriverEntryPointQueryResult driver_status; - void* cuTensorMapEncodeTiled_ptr = nullptr; - gpuErrchk(cudaGetDriverEntryPointByVersion("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, - 12000, cudaEnableDefault, &driver_status)); - - if (driver_status != cudaDriverEntryPointSuccess) { - std::cerr << "Could not get cuTensorMapEncodeTiled driver entry point" << std::endl; - abort(); - } - - return reinterpret_cast(cuTensorMapEncodeTiled_ptr); -} - -template -inline CUtensorMap createTensorMap(void* matrix_ptr, uint32_t matrix_height, uint32_t matrix_width, - uint32_t tile_height, uint32_t tile_width) { - CUtensorMap tensor_map{}; - constexpr uint32_t rank = 2; - - std::array matrix_dim = {matrix_width, matrix_height}; - std::array stride = {matrix_width * sizeof(Dtype)}; - std::array box_size = {tile_width, tile_height}; - std::array elem_stride = {1, 1}; - - // CUtensorMapDataType dtype_format = CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16; - CUtensorMapDataType dtype_format; - if constexpr (std::is_same_v) { - dtype_format = CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16; - } else if constexpr (std::is_same_v) { - dtype_format = CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32; - } else if constexpr (std::is_same_v) { - dtype_format = CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; - } else { - static_assert([]() { return false; }(), "Unsupported data type for TMA tensor map"); - return tensor_map; // shut the compiler up - } - - auto cuTensorMapEncodeTiled = get_cuTensorMapEncodeTiled(); - CUresult res = cuTensorMapEncodeTiled( - &tensor_map, dtype_format, rank, matrix_ptr, matrix_dim.data(), stride.data(), - box_size.data(), elem_stride.data(), CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, - CU_TENSOR_MAP_SWIZZLE_NONE, CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE, - CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); - - if (res != CUDA_SUCCESS) { - const char* err_name = nullptr; - const char* err_str 
= nullptr; - cuGetErrorName(res, &err_name); - cuGetErrorString(res, &err_str); - std::cerr << "Could not create a tensor map" << std::endl; - std::cerr << "Error is: " << err_name << ": " << err_str << std::endl; - abort(); - } - - return tensor_map; -} - inline CUtensorMap buildNdDescriptor(std::type_info const& dtype, std::vector const& shapes, std::vector const& strides, diff --git a/include/flashinfer/mamba/selective_state_update.cuh b/include/flashinfer/mamba/selective_state_update.cuh index e9af5b227b..64e0eb2b0c 100644 --- a/include/flashinfer/mamba/selective_state_update.cuh +++ b/include/flashinfer/mamba/selective_state_update.cuh @@ -16,23 +16,21 @@ #ifndef FLASHINFER_MAMBA_SELECTIVE_STATE_UPDATE_CUH_ #define FLASHINFER_MAMBA_SELECTIVE_STATE_UPDATE_CUH_ -#include -#include -#include +// #include +// #include +// #include -#include +// #include #include -#include +// #include -#include "../utils.cuh" -#include "../vec_dtypes.cuh" -#include "conversion.cuh" -#include "create_tensor_map.cuh" +// #include "../utils.cuh" +// #include "../vec_dtypes.cuh" +// #include "conversion.cuh" +// #include "create_tensor_map.cuh" namespace flashinfer::mamba { -using namespace conversion; - constexpr unsigned warpSize = 32; struct SelectiveStateUpdateParams { @@ -101,1072 +99,24 @@ __device__ __forceinline__ float warpReduceSum(float val) { return val; } -// Computes a conflict-free column index for shared memory access. -// This permutation avoids bank conflicts when threads access strided patterns. -// -// Without permutation (baseCol directly): -// Thread 0 -> Bank 0, Thread 32 -> Bank 0, Thread 64 -> Bank 0 (conflict!) -// -// With permutation (adding bankCycle offset): -// bankCycle = which "round" of 32 banks we're in -// By offsetting each round by 1 bank: -// Thread 0 -> Bank 0 -// Thread 32 -> Bank 1 (offset by 1) -// Thread 64 -> Bank 2 (offset by 2) -// -// Visual: (stateValuesPerBank=1, numBanks=32, colsPerStage=128) -// baseCol: 0 1 2 ... 31 | 32 33 34 ... 
63 | 64 ... -// bankCycle: 0 0 0 ... 0 | 1 1 1 ... 1 | 2 ... -// ii: 0 1 2 ... 31 | 33 34 35 ... 64 | 66 ... (mod colsPerStage) -template -__device__ __forceinline__ int conflict_free_column(int group, int baseCol) { - auto const seq_index = group * colsPerStage + baseCol; - auto const bankCycle = (seq_index / stateValuesPerBank) / numBanks; - return (baseCol + stateValuesPerBank * bankCycle) % colsPerStage; -} - -template -struct SharedStorageSimple { - alignas(alignof(PackedAligned)) input_t x[dim]; - float out[dim]; - alignas(alignof(PackedAligned)) input_t z[dim]; - alignas(alignof(PackedAligned)) input_t B[dstate]; - alignas(alignof(PackedAligned)) input_t C[dstate]; -}; - -template -__global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams params) { - auto* __restrict__ output = reinterpret_cast(params.output); - auto* __restrict__ state = reinterpret_cast(params.state); - - auto const* __restrict__ x = reinterpret_cast(params.x); - auto const* __restrict__ dt = reinterpret_cast(params.dt); - auto const* __restrict__ A = reinterpret_cast(params.A); - auto const* __restrict__ B = reinterpret_cast(params.B); - auto const* __restrict__ C = reinterpret_cast(params.C); - auto const* __restrict__ D = reinterpret_cast(params.D); // D: (nheads, dim) - auto const* __restrict__ dt_bias = reinterpret_cast(params.dt_bias); // (nheads) - auto const* __restrict__ z = reinterpret_cast(params.z); - auto const* __restrict__ state_batch_indices = - reinterpret_cast(params.state_batch_indices); - bool const dt_softplus = params.dt_softplus; - - int const nheads = params.nheads; - int const ngroups = params.ngroups; - - constexpr auto rowsPerWarp = (DIM + numWarps - 1) / numWarps; - - auto const batch = blockIdx.x; - auto const head = blockIdx.y; - auto const group = head / (nheads / ngroups); - auto lane = threadIdx.x % warpSize; - auto warp = threadIdx.y; - - auto const state_batch = (state_batch_indices) ? 
state_batch_indices[batch] : batch; - state += state_batch * params.state_stride_batch + head * DIM * DSTATE; - - __shared__ SharedStorageSimple sram; - - static constexpr auto stateLoadSize = getVectorLoadSizeForFullUtilization(); - using load_state_t = PackedAligned; - using load_input_t = PackedAligned; - - auto const A_value = toFloat(A[head]); - - auto dt_value = toFloat(dt[batch * params.dt_stride_batch + head]); - if (dt_bias) dt_value += toFloat(dt_bias[head]); - if (dt_softplus) { - dt_value = thresholded_softplus(dt_value); - } - - auto const dA = __expf(A_value * dt_value); - - auto d_value = D ? toFloat(D[head]) : 0.f; - - if (warp == 0) { - for (auto d = lane * load_input_t::count; d < DIM; d += warpSize * load_input_t::count) { - auto* dst = reinterpret_cast(&sram.x[d]); - *dst = *reinterpret_cast( - &x[batch * params.x_stride_batch + head * DIM + d]); - } - for (auto i = lane * load_input_t::count; i < DSTATE; i += warpSize * load_input_t::count) { - auto* dst = reinterpret_cast(&sram.B[i]); - *dst = *reinterpret_cast( - &B[batch * params.B_stride_batch + group * DSTATE + i]); - } - } else if (warp == 1) { // Load z, C - for (auto d = lane * load_input_t::count; d < DIM; d += warpSize * load_input_t::count) { - auto* dst = reinterpret_cast(&sram.z[d]); - *dst = z ? 
*reinterpret_cast( - &z[batch * params.z_stride_batch + head * DIM + d]) - : make_zeros(); - } - for (auto i = lane * load_input_t::count; i < DSTATE; i += warpSize * load_input_t::count) { - auto* dst = reinterpret_cast(&sram.C[i]); - *dst = *reinterpret_cast( - &C[batch * params.C_stride_batch + group * DSTATE + i]); - } - } - __syncthreads(); - - for (auto _d = warp * rowsPerWarp; _d < (warp + 1) * rowsPerWarp; _d++) { - auto d = _d; - if (d >= DIM) break; - - float x_value = toFloat(sram.x[_d]); - float out_value = d_value * x_value * int(lane == 0); // first lane has the value - - for (int i = lane * load_state_t::count; i < DSTATE; i += warpSize * load_state_t::count) { - auto rState = make_zeros(); - if (state_batch != params.pad_slot_id) - rState = *reinterpret_cast(&state[d * DSTATE + i]); - - for (int ii = 0; ii < load_state_t::count; ii++) { - auto state_value = toFloat(rState.val[ii]); - auto B_value = toFloat(sram.B[i + ii]); - auto C_value = toFloat(sram.C[i + ii]); - - auto const dB = B_value * dt_value; - auto const new_state = state_value * dA + dB * x_value; - - convertAndStore(&rState.val[ii], new_state); - - out_value += new_state * C_value; - } - if (state_batch != params.pad_slot_id) - *reinterpret_cast(&state[d * DSTATE + i]) = rState; - } - - // warpReduce the out_value - out_value = warpReduceSum(out_value); - if (lane == 0) { - sram.out[_d] = out_value; - } - } - - __syncthreads(); - - for (int l = lane; l < rowsPerWarp; l += warpSize) { - auto d = warp * rowsPerWarp + l; - if (d < DIM) { - auto out_value = sram.out[d]; - if (z) { - float z_value = toFloat(sram.z[d]); - float sig_z = __fdividef(1.f, (1.f + __expf(0.f - z_value))); - float silu_z = z_value * sig_z; - out_value *= silu_z; - } - convertAndStore(&output[batch * params.out_stride_batch + head * DIM + d], out_value); - } - } -} - -template -struct SharedStorageVertical { - alignas(128) state_t state[numStages][rowsPerStage * dstate]; - alignas(alignof(PackedAligned)) input_t 
x[dim]; - float out[dim]; // dt is special cause we're gonna store input in there as well - alignas(alignof(PackedAligned)) input_t z[dim]; - alignas(alignof(PackedAligned)) input_t B[dstate]; - alignas(alignof(PackedAligned)) input_t C[dstate]; - - using barrier_t = cuda::barrier; - barrier_t bar_empty[numStages]; - barrier_t bar_full[numStages]; - barrier_t bar_consumers; -}; - -template -__device__ __forceinline__ void producer_func_vertical(SramT& sram, CUtensorMap const& tensorState, - int batch, int head) { -#ifdef FLASHINFER_MAMBA_ENABLE_SM90 - namespace cde = cuda::device::experimental; - - auto constexpr stagesReadOnly = numStages; - auto constexpr stagesBoth = DIM / rowsPerStage - numStages; - auto constexpr stagesWriteOnly = numStages; - - auto constexpr bytesState = rowsPerStage * DSTATE * sizeof(state_t); - auto constexpr bytesToArrive = bytesState; - - // Phase 1: Read only (filling the pipeline) -#pragma unroll - for (int iter = 0; iter < stagesReadOnly; ++iter) { - auto const stage = iter % numStages; - auto const d = iter * rowsPerStage; - - sram.bar_empty[stage].wait(sram.bar_empty[stage].arrive()); - - if constexpr (readState) { - cde::cp_async_bulk_tensor_4d_global_to_shared(&sram.state[stage][0], &tensorState, 0, d, head, - batch, sram.bar_full[stage]); - - auto const _ = cuda::device::barrier_arrive_tx(sram.bar_full[stage], 1, bytesToArrive); - } else { - auto const _ = sram.bar_full[stage].arrive(); - } - } - - // Phase 2: Both read and write (steady state) -#pragma unroll - for (int iter = 0; iter < stagesBoth; ++iter) { - auto const stage = (stagesReadOnly + iter) % numStages; - auto const d_read = (stagesReadOnly + iter) * rowsPerStage; - auto const d_write = iter * rowsPerStage; - - sram.bar_empty[stage].wait(sram.bar_empty[stage].arrive()); - - if constexpr (readState || writeState) { - // Unblock async proxy for writeback - cde::fence_proxy_async_shared_cta(); - // Writeback - if constexpr (writeState) { - 
cde::cp_async_bulk_tensor_4d_shared_to_global(&tensorState, 0, d_write, head, batch, - &sram.state[stage][0]); - - cde::cp_async_bulk_commit_group(); - cde::cp_async_bulk_wait_group_read<0>(); - } - - // Read next - if constexpr (readState) { - cde::cp_async_bulk_tensor_4d_global_to_shared(&sram.state[stage][0], &tensorState, 0, - d_read, head, batch, sram.bar_full[stage]); - auto const _ = cuda::device::barrier_arrive_tx(sram.bar_full[stage], 1, bytesToArrive); - } else { - auto const _ = sram.bar_full[stage].arrive(); - } - } else { - auto const _ = sram.bar_full[stage].arrive(); - } - } - - // Phase 3: Write only (draining the pipeline) -#pragma unroll - for (int iter = 0; iter < stagesWriteOnly; ++iter) { - auto const stage = (stagesReadOnly + stagesBoth + iter) % numStages; - auto const d_write = (stagesBoth + iter) * rowsPerStage; - - sram.bar_empty[stage].wait(sram.bar_empty[stage].arrive()); - - if constexpr (writeState) { - // Unblock async proxy for writeback - cde::fence_proxy_async_shared_cta(); - cde::cp_async_bulk_tensor_4d_shared_to_global(&tensorState, 0, d_write, head, batch, - &sram.state[stage][0]); - - cde::cp_async_bulk_commit_group(); - cde::cp_async_bulk_wait_group_read<0>(); - } - } -#endif -} - -template -__device__ __forceinline__ void consumer_func_vertical( - int lane, int warp, float d_value, float dt_value, float dA, - SharedStorageVertical& sram) { -#ifdef FLASHINFER_MAMBA_ENABLE_SM90 - namespace cde = cuda::device::experimental; - for (auto dBegin = 0, stage = 0; dBegin < DIM; - dBegin += rowsPerStage, stage = (stage + 1) % numStages) { - // wait for the producer - sram.bar_full[stage].wait(sram.bar_full[stage].arrive()); - -#pragma unroll - for (auto dd = warp; dd < rowsPerStage; dd += consumerWarps) { - auto d = dBegin + dd; - float const x_value = toFloat(sram.x[d]); - float out_value = d_value * x_value * int(lane == 0); // first lane has the value - - constexpr auto bankSize = sizeof(uint32_t); - constexpr auto 
stateValuesPerBank = bankSize / sizeof(state_t); - - if constexpr (sizeof(state_t) == sizeof(input_t)) { - for (int i = lane * stateValuesPerBank; i < DSTATE; i += warpSize * stateValuesPerBank) { - auto* sState_ptr = reinterpret_cast(&sram.state[stage][dd * DSTATE + i]); - uint32_t rState = *sState_ptr; - auto* rState_ptr = reinterpret_cast(&rState); - - uint32_t rB = *reinterpret_cast(&sram.B[i]); - auto* rB_ptr = reinterpret_cast(&rB); - - uint32_t rC = *reinterpret_cast(&sram.C[i]); - auto* rC_ptr = reinterpret_cast(&rC); - - for (int e = 0; e < stateValuesPerBank; e++) { - float state_value; - if constexpr (!useStateCache) { - state_value = 0.f; - } else { - state_value = toFloat(rState_ptr[e]); - } - auto const B_value = toFloat(rB_ptr[e]); - auto const C_value = toFloat(rC_ptr[e]); - - auto const dB = B_value * dt_value; - auto const new_state = state_value * dA + dB * x_value; - - convertAndStore(&rState_ptr[e], new_state); - out_value += new_state * C_value; - } - *sState_ptr = rState; - } - } else { - for (int i = lane * stateValuesPerBank; i < DSTATE; i += warpSize * stateValuesPerBank) { - auto* sState_ptr = reinterpret_cast(&sram.state[stage][dd * DSTATE + i]); - uint32_t rState = *sState_ptr; - auto* rState_ptr = reinterpret_cast(&rState); - - for (int e = 0; e < stateValuesPerBank; e++) { - float state_value; - if constexpr (!useStateCache) { - state_value = 0.f; - } else { - state_value = toFloat(rState_ptr[e]); - } - auto const B_value = toFloat(sram.B[i + e]); - auto const C_value = toFloat(sram.C[i + e]); - auto const dB = B_value * dt_value; - auto const new_state = state_value * dA + dB * x_value; - - convertAndStore(&rState_ptr[e], new_state); - out_value += new_state * C_value; - } - *sState_ptr = rState; - } - } - - out_value = warpReduceSum(out_value); - if (lane == 0) { - sram.out[d] = out_value; - } - } - - // Unblock producer - cde::fence_proxy_async_shared_cta(); - auto _ = sram.bar_empty[stage].arrive(); - } -#endif -} - -template 
-__global__ void selective_state_update_kernel_producer_consumer_vertical( - SelectiveStateUpdateParams params, __grid_constant__ CUtensorMap const tensorState) { -#ifdef FLASHINFER_MAMBA_ENABLE_SM90 - auto* __restrict__ output = reinterpret_cast(params.output); - - auto const* __restrict__ x = reinterpret_cast(params.x); - auto const* __restrict__ dt = reinterpret_cast(params.dt); - auto const* __restrict__ A = reinterpret_cast(params.A); - auto const* __restrict__ B = reinterpret_cast(params.B); - auto const* __restrict__ C = reinterpret_cast(params.C); - auto const* __restrict__ D = reinterpret_cast(params.D); - auto const* __restrict__ dt_bias = reinterpret_cast(params.dt_bias); - auto const* __restrict__ z = reinterpret_cast(params.z); - auto const* __restrict__ state_batch_indices = - reinterpret_cast(params.state_batch_indices); - - int const nheads = params.nheads; - int const ngroups = params.ngroups; - - constexpr auto numWarps = 1 + consumerWarps; - - auto const batch = blockIdx.x; - auto const head = blockIdx.y; - auto const group = head / (nheads / ngroups); - auto lane = threadIdx.x % warpSize; - auto warp = threadIdx.y; - - auto const state_batch = (state_batch_indices) ? 
state_batch_indices[batch] : batch; - - extern __shared__ uint8_t sbuffer[]; - using sram_t = SharedStorageVertical; - auto& sram = *reinterpret_cast(sbuffer); - - namespace cde = cuda::device::experimental; - namespace cg = cooperative_groups; - - for (int stage = warp; stage < numStages; stage += numWarps) { - if (lane > 0) continue; - constexpr auto num_arrivals = 1 + consumerWarps * warpSize; - init(&sram.bar_empty[stage], num_arrivals); - init(&sram.bar_full[stage], num_arrivals); - // signal to async proxy that barriers are initilized - cde::fence_proxy_async_shared_cta(); - } - if (lane == 0 && warp == 0) { - init(&sram.bar_consumers, warpSize * consumerWarps); - } - __syncthreads(); - - if (warp == consumerWarps) // producer - { - // auto const state_offset = (state_batch * nheads + head) * DIM; - auto const read_state = (state_batch != params.pad_slot_id); - auto const write_state = read_state && params.update_state; - - if (lane == 0) { - cg::invoke_one(cg::coalesced_threads(), [&]() { - if (read_state && write_state) - producer_func_vertical( - sram, tensorState, state_batch, head); - else if (read_state && !write_state) - producer_func_vertical( - sram, tensorState, state_batch, head); - else - producer_func_vertical( - sram, tensorState, state_batch, head); - }); - } - } else { // consumers - - using load_t = PackedAligned; - -#pragma unroll - // Unblock the producer - for (uint8_t stage = 0; stage < numStages; ++stage) { - auto const _ = sram.bar_empty[stage].arrive(); - } - - // Load A - auto const A_value = toFloat(A[head]); - - // Load D - auto const d_value = D ? 
toFloat(D[head]) : 0.f; - - // load dt_value - auto dt_value = toFloat(dt[batch * params.dt_stride_batch + head]); - if (dt_bias) dt_value += toFloat(dt_bias[head]); - if (params.dt_softplus) { - dt_value = thresholded_softplus(dt_value); - } - auto const dA = __expf(A_value * dt_value); - - if (warp == 0) { // Load x, B - for (auto d = lane * load_t::count; d < DIM; d += warpSize * load_t::count) { - auto* dst = reinterpret_cast(&sram.x[d]); - *dst = *reinterpret_cast(&x[batch * params.x_stride_batch + head * DIM + d]); - } - for (auto i = lane * load_t::count; i < DSTATE; i += warpSize * load_t::count) { - auto* dst = reinterpret_cast(&sram.B[i]); - *dst = *reinterpret_cast( - &B[batch * params.B_stride_batch + group * DSTATE + i]); - } - } else if (warp == 1) { // Load z, C - for (auto d = lane * load_t::count; d < DIM; d += warpSize * load_t::count) { - auto* dst = reinterpret_cast(&sram.z[d]); - *dst = - z ? *reinterpret_cast(&z[batch * params.z_stride_batch + head * DIM + d]) - : make_zeros(); - } - for (auto i = lane * load_t::count; i < DSTATE; i += warpSize * load_t::count) { - auto* dst = reinterpret_cast(&sram.C[i]); - *dst = *reinterpret_cast( - &C[batch * params.C_stride_batch + group * DSTATE + i]); - } - } - - sram.bar_consumers.wait(sram.bar_consumers.arrive()); - - if (state_batch != params.pad_slot_id) - consumer_func_vertical(lane, warp, d_value, dt_value, dA, - sram); - else - consumer_func_vertical(lane, warp, d_value, dt_value, dA, - sram); - - // Write output - sram.bar_consumers.wait(sram.bar_consumers.arrive()); - auto d = warp * warpSize + lane; - if (d < DIM) { - auto out_value = sram.out[d]; - if (z) { - float z_value = toFloat(sram.z[d]); - float sig_z = __fdividef(1.f, (1.f + __expf(0.f - z_value))); - float silu_z = z_value * sig_z; - out_value *= silu_z; - } - convertAndStore(&output[batch * params.out_stride_batch + head * DIM + d], out_value); - } - } -#endif -} - -// 
============================================================================= -// Horizontal Producer-Consumer Kernel for SM100+ (Blackwell and newer) -// ============================================================================= - -#ifdef FLASHINFER_MAMBA_ENABLE_SM100 - -template -struct SharedStorageHorizontal { - alignas(128) state_t state[numStages][dim * stageCols]; - alignas(alignof(PackedAligned)) input_t B[dstate]; - alignas(alignof(PackedAligned)) input_t C[dstate]; +namespace mtp { +// Extended params struct for multi-token prediction (MTP) +struct SelectiveStateMTPParams : public SelectiveStateUpdateParams { + uint32_t ntokens_mtp{1}; + uint64_t cache_steps{0}; - using barrier_t = cuda::barrier; - barrier_t bar_empty[numStages]; - barrier_t bar_full[numStages]; - barrier_t bar_consumers; + // MTP-specific strides for the token dimension + int64_t x_stride_mtp{}, dt_stride_mtp{}, B_stride_mtp{}, C_stride_mtp{}, out_stride_mtp{}, + z_stride_mtp{}; + void* __restrict__ intermediate_states{ + nullptr}; // state_t: (ntokens_mtp, state_cache_size, nheads, dim, dstate) + void* __restrict__ intermediate_state_indices{nullptr}; // (batch,) + int64_t intermediate_state_stride_batch{}; // stride for batch dimension of intermediate_states }; - -template -__device__ __forceinline__ void producer_func_horizontal(SramT& sram, - CUtensorMap const& tensorState, int batch, - int head) { - namespace cde = cuda::device::experimental; - - auto constexpr stagesReadOnly = numStages; - auto constexpr stagesBoth = DSTATE / colsPerStage - numStages; - auto constexpr stagesWriteOnly = numStages; - - auto constexpr bytesState = DIM * colsPerStage * sizeof(state_t); - auto constexpr bytesToArrive = bytesState; - - // Phase 1: Read only (filling the pipeline) -#pragma unroll - for (int iter = 0; iter < stagesReadOnly; ++iter) { - auto const stage = iter % numStages; - auto const i = iter * colsPerStage; - - sram.bar_empty[stage].wait(sram.bar_empty[stage].arrive()); - - if 
constexpr (readState) { - cde::cp_async_bulk_tensor_4d_global_to_shared(&sram.state[stage][0], &tensorState, i, 0, head, - batch, sram.bar_full[stage]); - auto const _ = cuda::device::barrier_arrive_tx(sram.bar_full[stage], 1, bytesToArrive); - } else { - auto const _ = sram.bar_full[stage].arrive(); - } - } - - // Phase 2: Both read and write (steady state) -#pragma unroll - for (int iter = 0; iter < stagesBoth; ++iter) { - auto const stage = (stagesReadOnly + iter) % numStages; - auto const i_read = (stagesReadOnly + iter) * colsPerStage; - auto const i_write = iter * colsPerStage; - - sram.bar_empty[stage].wait(sram.bar_empty[stage].arrive()); - - if constexpr (readState || writeState) { - // Unblock async proxy for writeback - cde::fence_proxy_async_shared_cta(); - // Writeback - if constexpr (writeState) { - cde::cp_async_bulk_tensor_4d_shared_to_global(&tensorState, i_write, 0, head, batch, - &sram.state[stage][0]); - cde::cp_async_bulk_commit_group(); - cde::cp_async_bulk_wait_group_read<0>(); - } - - // Read next - if constexpr (readState) { - cde::cp_async_bulk_tensor_4d_global_to_shared(&sram.state[stage][0], &tensorState, i_read, - 0, head, batch, sram.bar_full[stage]); - auto const _ = cuda::device::barrier_arrive_tx(sram.bar_full[stage], 1, bytesToArrive); - } else { - auto const _ = sram.bar_full[stage].arrive(); - } - } else { - auto const _ = sram.bar_full[stage].arrive(); - } - } - - // Phase 3: Write only (draining the pipeline) -#pragma unroll - for (int iter = 0; iter < stagesWriteOnly; ++iter) { - auto const stage = (stagesReadOnly + stagesBoth + iter) % numStages; - auto const i_write = (stagesBoth + iter) * colsPerStage; - - sram.bar_empty[stage].wait(sram.bar_empty[stage].arrive()); - - if constexpr (writeState) { - // Unblock async proxy for writeback - cde::fence_proxy_async_shared_cta(); - cde::cp_async_bulk_tensor_4d_shared_to_global(&tensorState, i_write, 0, head, batch, - &sram.state[stage][0]); - cde::cp_async_bulk_commit_group(); - 
cde::cp_async_bulk_wait_group_read<0>(); - } - } -} - -template -__device__ __forceinline__ void consumer_func_horizontal( - int d, int member, float A_value, float dt_value, float x_value, - SharedStorageHorizontal& sram, - float& out_value) { - namespace cde = cuda::device::experimental; - constexpr auto lanesPerRow = (consumerWarps * warpSize) / DIM; - constexpr auto itemsPerThread = colsPerStage / lanesPerRow; - auto const group = d % (warpSize / lanesPerRow); - - // #pragma unroll 1 - for (int iBegin = 0, stage = 0; iBegin < DSTATE; - iBegin += colsPerStage, stage = (stage + 1) % numStages) { - // wait for the producer - sram.bar_full[stage].wait(sram.bar_full[stage].arrive()); - - constexpr auto bankSize = sizeof(uint32_t); - constexpr auto stateValuesPerBank = bankSize / sizeof(state_t); - constexpr auto numBanks = 32; - if constexpr (sizeof(state_t) == sizeof(input_t)) { -#pragma unroll - for (int item = 0; item < itemsPerThread; item += stateValuesPerBank) { - auto const baseCol = item + member * itemsPerThread; - // If I just use baseCol as the index, a lot of bank conflicts will arise. 
- // - auto const ii = - conflict_free_column(group, baseCol); - - auto const i = iBegin + ii; - - auto* sState_ptr = reinterpret_cast(&sram.state[stage][d * colsPerStage + ii]); - uint32_t rState = *sState_ptr; - auto* rState_ptr = reinterpret_cast(&rState); - - uint32_t rB = *reinterpret_cast(&sram.B[i]); - auto* rB_ptr = reinterpret_cast(&rB); - - uint32_t rC = *reinterpret_cast(&sram.C[i]); - auto* rC_ptr = reinterpret_cast(&rC); - - for (int e = 0; e < stateValuesPerBank; e++) { - float state_value; - if constexpr (!useStateCache) { - state_value = 0.f; - } else { - state_value = toFloat(rState_ptr[e]); - } - - auto const B_value = toFloat(rB_ptr[e]); - auto const C_value = toFloat(rC_ptr[e]); - - auto const dA = __expf(A_value * dt_value); - auto const dB = B_value * dt_value; - auto const new_state = state_value * dA + dB * x_value; - - convertAndStore(&rState_ptr[e], new_state); - out_value += new_state * C_value; - } - *sState_ptr = rState; - } - } else { - for (int item = 0; item < itemsPerThread; item += stateValuesPerBank) { - auto const baseCol = item + member * itemsPerThread; - auto const ii = - conflict_free_column(group, baseCol); - auto const i = iBegin + ii; - - auto* sState_ptr = reinterpret_cast(&sram.state[stage][d * colsPerStage + ii]); - uint32_t rState = *sState_ptr; - auto* rState_ptr = reinterpret_cast(&rState); - - for (int e = 0; e < stateValuesPerBank; e++) { - float state_value; - if constexpr (!useStateCache) { - state_value = 0.f; - } else { - state_value = toFloat(rState_ptr[e]); - } - - auto const B_value = toFloat(sram.B[i + e]); - auto const C_value = toFloat(sram.C[i + e]); - - auto const dA = __expf(A_value * dt_value); - auto const dB = B_value * dt_value; - auto const new_state = state_value * dA + dB * x_value; - - convertAndStore(&rState_ptr[e], new_state); - out_value += new_state * C_value; - } - *sState_ptr = rState; - } - } - - auto _ = sram.bar_empty[stage].arrive(); - } -} - -template -__global__ void 
selective_state_update_kernel_producer_consumer_horizontal( - SelectiveStateUpdateParams params, __grid_constant__ CUtensorMap const tensorState) { - auto* __restrict__ output = reinterpret_cast(params.output); - auto const* __restrict__ x = reinterpret_cast(params.x); - auto const* __restrict__ dt = reinterpret_cast(params.dt); - auto const* __restrict__ A = reinterpret_cast(params.A); - auto const* __restrict__ B = reinterpret_cast(params.B); - auto const* __restrict__ C = reinterpret_cast(params.C); - auto const* __restrict__ D = reinterpret_cast(params.D); - auto const* __restrict__ dt_bias = reinterpret_cast(params.dt_bias); - auto const* __restrict__ z = reinterpret_cast(params.z); - auto const* __restrict__ state_batch_indices = - reinterpret_cast(params.state_batch_indices); - - int const nheads = params.nheads; - - constexpr auto numWarps = 1 + consumerWarps; - - auto const batch = blockIdx.x; - auto const head = blockIdx.y; - auto const group = head / headsGroupsRatio; - auto lane = threadIdx.x % warpSize; - auto warp = threadIdx.y; - - auto const state_batch = (state_batch_indices) ? 
state_batch_indices[batch] : batch; - - extern __shared__ uint8_t sbuffer[]; - using sram_t = SharedStorageHorizontal; - auto& sram = *reinterpret_cast(sbuffer); - - namespace cde = cuda::device::experimental; - namespace cg = cooperative_groups; - - for (int stage = warp; stage < numStages; stage += numWarps) { - if (lane > 0) continue; - constexpr auto num_arrivals = 1 + consumerWarps * warpSize; - init(&sram.bar_empty[stage], num_arrivals); - init(&sram.bar_full[stage], num_arrivals); - // signal to async proxy that barriers are initilized - cde::fence_proxy_async_shared_cta(); - } - if (lane == 0 && warp == 0) { - init(&sram.bar_consumers, warpSize * consumerWarps); - } - __syncthreads(); - - if (warp == consumerWarps) // producer - { - auto const read_state = (state_batch != params.pad_slot_id); - auto const write_state = read_state && params.update_state; - - cg::invoke_one(cg::coalesced_threads(), [&]() { - if (read_state && write_state) - producer_func_horizontal( - sram, tensorState, state_batch, head); - else if (read_state && !write_state) - producer_func_horizontal( - sram, tensorState, state_batch, head); - else - producer_func_horizontal( - sram, tensorState, state_batch, head); - }); - } else { // consumers - - using load_t = PackedAligned; - - // Unblock the producer -#pragma unroll - for (auto stage = 0; stage < numStages; ++stage) { - auto const _ = sram.bar_empty[stage].arrive(); - } - - // Load A - auto const A_value = toFloat(A[head]); - - // Load D - auto const d_value = D ? 
toFloat(D[head]) : 0.f; - - // load dt_value - auto dt_value = toFloat(dt[batch * params.dt_stride_batch + head]); - if (dt_bias) dt_value += toFloat(dt_bias[head]); - if (params.dt_softplus) { - dt_value = thresholded_softplus(dt_value); - } - - if (warp == 0) { // Load B - for (auto d = lane * load_t::count; d < DSTATE; d += warpSize * load_t::count) { - auto* dst = reinterpret_cast(&sram.B[d]); - *dst = *reinterpret_cast( - &B[batch * params.B_stride_batch + group * DSTATE + d]); - } - } else if (warp == 1) { // Load C - for (auto i = lane * load_t::count; i < DSTATE; i += warpSize * load_t::count) { - auto* dst = reinterpret_cast(&sram.C[i]); - *dst = *reinterpret_cast( - &C[batch * params.C_stride_batch + group * DSTATE + i]); - } - } - - constexpr auto lanesPerRow = (consumerWarps * warpSize) / DIM; - static_assert(lanesPerRow >= 1); - constexpr auto rowsPerWarp = warpSize / lanesPerRow; - auto const group = lane % rowsPerWarp; - auto const member = lane / rowsPerWarp; - auto const d = warp * rowsPerWarp + group; - auto const x_value = toFloat(x[batch * params.x_stride_batch + head * DIM + d]); - auto const z_value = z ? 
toFloat(z[batch * params.z_stride_batch + head * DIM + d]) : 0.f; - - sram.bar_consumers.wait(sram.bar_consumers.arrive()); - - // Thread - float out_value = 0.f; - if (state_batch != params.pad_slot_id) - consumer_func_horizontal(d, member, A_value, dt_value, x_value, - sram, out_value); - else - consumer_func_horizontal(d, member, A_value, dt_value, - x_value, sram, out_value); - - out_value += __shfl_down_sync(UINT32_MAX, out_value, 16); - if constexpr (lanesPerRow == 4) { - out_value += __shfl_down_sync(UINT32_MAX, out_value, 8); - } - - if (member == 0) { - out_value += d_value * x_value; - - // Write output - if (z) { - float sig_z = __fdividef(1.f, (1.f + __expf(0.f - z_value))); - float silu_z = z_value * sig_z; - out_value *= silu_z; - } - convertAndStore(&output[batch * params.out_stride_batch + head * DIM + d], out_value); - } - } -} - -#endif // FLASHINFER_MAMBA_ENABLE_SM100 - -template -std::string format_array(const T (&arr)[N]) { - std::ostringstream oss; - for (size_t i = 0; i < N; ++i) { - if (i > 0) oss << ", "; - oss << arr[i]; - } - return oss.str(); -} - -// Helper function to dispatch dim and dstate with a kernel launcher -template -void dispatchDimDstate(SelectiveStateUpdateParams& params, - std::integer_sequence, - std::integer_sequence, KernelLauncher&& launcher) { - constexpr int allowed_dims[] = {AllowedDims...}; - constexpr int allowed_dstates[] = {AllowedDstates...}; - - auto dispatch_dim_dstate = [&]() { - launcher.template operator()(); - }; - - auto dispatch_dstate = [&]() { - if (params.dstate == DSTATE) { - dispatch_dim_dstate.template operator()(); - return true; - } - return false; - }; - - auto dispatch_dim = [&]() { - if (params.dim == DIM) { - bool dispatched = (dispatch_dstate.template operator()() || ...); - FLASHINFER_CHECK(dispatched, "Unsupported dstate value: ", params.dstate, - ".\nSupported values: ", format_array(allowed_dstates)); - return true; - } - return false; - }; - - bool dim_dispatched = 
(dispatch_dim.template operator()() || ...); - FLASHINFER_CHECK(dim_dispatched, "Unsupported dim value: ", params.dim, - ".\nSupported values: ", format_array(allowed_dims)); -} - -// Helper function to dispatch ratio with a kernel launcher -template -void dispatchRatio(SelectiveStateUpdateParams& params, std::integer_sequence, - KernelLauncher&& launcher) { - constexpr int allowed_ratios[] = {AllowedRatios...}; - - auto dispatch_single_ratio = [&]() { - if (params.nheads / params.ngroups == RATIO) { - launcher.template operator()(); - return true; - } - return false; - }; - - bool ratio_dispatched = (dispatch_single_ratio.template operator()() || ...); - FLASHINFER_CHECK(ratio_dispatched, - "Unsupported nheads/ngroups ratio: ", params.nheads / params.ngroups, - ".\nSupported values: ", format_array(allowed_ratios)); -} - -// Check alignment for common input variables (x, z, B, C) -template -void check_ptr_alignment_input_vars(const SelectiveStateUpdateParams& params) { - using load_input_t = PackedAligned; - FLASHINFER_CHECK(reinterpret_cast(params.x) % sizeof(load_input_t) == 0, - "x pointer must be aligned to ", sizeof(load_input_t), " bytes"); - FLASHINFER_CHECK((params.x_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0, - "x batch stride must be aligned to ", sizeof(load_input_t), " bytes"); - if (params.z) { - FLASHINFER_CHECK(reinterpret_cast(params.z) % sizeof(load_input_t) == 0, - "z pointer must be aligned to ", sizeof(load_input_t), " bytes"); - FLASHINFER_CHECK((params.z_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0, - "z batch stride must be aligned to ", sizeof(load_input_t), " bytes"); - } - FLASHINFER_CHECK(reinterpret_cast(params.B) % sizeof(load_input_t) == 0, - "B pointer must be aligned to ", sizeof(load_input_t), " bytes"); - FLASHINFER_CHECK(reinterpret_cast(params.C) % sizeof(load_input_t) == 0, - "C pointer must be aligned to ", sizeof(load_input_t), " bytes"); - FLASHINFER_CHECK((params.B_stride_batch * 
sizeof(input_t)) % sizeof(load_input_t) == 0, - "B batch stride must be aligned to ", sizeof(load_input_t), " bytes"); - FLASHINFER_CHECK((params.C_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0, - "C batch stride must be aligned to ", sizeof(load_input_t), " bytes"); -} - -template -void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, cudaStream_t stream) { - auto [sm_major, sm_minor] = GetCudaComputeCapability(); - - constexpr int allowed_dstates[] = {64, 128, 256}; - constexpr int allowed_dims[] = {64, 128, 256}; - - // Common alignment checks for all kernels - check_ptr_alignment_input_vars(params); - -#ifdef FLASHINFER_MAMBA_ENABLE_SM100 - if (sm_major < 10) // pre-Blackwell -#elif defined(FLASHINFER_MAMBA_ENABLE_SM90) - if (sm_major < 9) // pre-Hopper -#endif - { - auto kernel_launcher = [&]() { - // Additional alignment checks specific to simple kernel - constexpr auto stateLoadSize = getVectorLoadSizeForFullUtilization(); - using load_state_t = PackedAligned; - - FLASHINFER_CHECK(reinterpret_cast(params.state) % sizeof(load_state_t) == 0, - "state pointer must be aligned to ", sizeof(load_state_t), " bytes"); - FLASHINFER_CHECK((params.dim * params.dstate * sizeof(state_t)) % sizeof(load_state_t) == 0, - "state head stride must be aligned to ", sizeof(load_state_t), " bytes"); - - constexpr int numWarps = 4; - dim3 block(warpSize, numWarps); - dim3 grid(params.batch, params.nheads); - selective_state_update_kernel_simple<<>>(params); - }; - - dispatchDimDstate(params, std::integer_sequence{}, - std::integer_sequence{}, kernel_launcher); - } -#ifdef FLASHINFER_MAMBA_ENABLE_SM90 - else { - - auto kernel_launcher = [&]() { - // Note: State uses TMA which requires 128B alignment (checked below) - constexpr auto numConsumers = 4; - constexpr auto numWarps = 1 + numConsumers; - constexpr auto numStages = 3; - constexpr auto rowsPerStage = 4 * numConsumers; - FLASHINFER_CHECK(params.dim % rowsPerStage == 0, "dim must be divisible by ", 
rowsPerStage, - " for SM90+ kernel"); - auto scan_func = selective_state_update_kernel_producer_consumer_vertical< - input_t, weight_t, matrixA_t, state_t, DIM, DSTATE, numConsumers, rowsPerStage, - numStages>; - - dim3 block(warpSize, numWarps); - dim3 grid(params.batch, params.nheads); - - auto nh = params.nheads; - auto dim = params.dim; - - auto state_tensor = - tma::buildNdDescriptor(typeid(state_t), - /*shapes*/ {DSTATE, DIM, nh, params.state_cache_size}, - /*strides*/ {1, DSTATE, DSTATE * DIM, params.state_stride_batch}, - /*tiles*/ {DSTATE, rowsPerStage, 1, 1}, params.state); - - // Calculate shared memory size and opt-in to extended shared memory - using sram_t = SharedStorageVertical; - constexpr size_t smem_size = sizeof(sram_t); - FLASHINFER_CUDA_CHECK( - cudaFuncSetAttribute(scan_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - - scan_func<<>>(params, state_tensor); - }; - - dispatchDimDstate(params, std::integer_sequence{}, - std::integer_sequence{}, kernel_launcher); - } -#endif - -#ifdef FLASHINFER_MAMBA_ENABLE_SM100 - else { - // SM100+ (Blackwell and newer) uses horizontal producer-consumer kernel - auto kernel_launcher = [&]() { - // profiling showed that it's good to have 4 producers per 64 rows - constexpr auto numConsumers = (DIM / 64) * 4; - constexpr auto numProducers = 1; - constexpr auto numWarps = numProducers + numConsumers; - - constexpr auto sectorSize = 32; // bytes - constexpr auto stageCols = 2 * sectorSize / sizeof(state_t); - - constexpr auto totalStages = DSTATE / stageCols; - constexpr auto numStages = (totalStages >= 4) ? 
4 : totalStages; - - auto ratio_launcher = [&]() { - auto scan_func = selective_state_update_kernel_producer_consumer_horizontal< - input_t, weight_t, matrixA_t, state_t, DIM, DSTATE, numConsumers, stageCols, RATIO, - numStages>; - - dim3 block(warpSize, numWarps); - dim3 grid(params.batch, params.nheads); - - auto nh = params.nheads; - auto dim = params.dim; - - auto state_tensor = - tma::buildNdDescriptor(typeid(state_t), - /*shapes*/ {DSTATE, DIM, nh, params.state_cache_size}, - /*strides*/ {1, DSTATE, DSTATE * DIM, params.state_stride_batch}, - /*tiles*/ {stageCols, DIM, 1, 1}, params.state); - static_assert(DSTATE % stageCols == 0 && DSTATE >= stageCols); - - using sram_t = SharedStorageHorizontal; - constexpr size_t smem_size = sizeof(sram_t); - FLASHINFER_CUDA_CHECK(cudaFuncSetAttribute( - scan_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - - scan_func<<>>(params, state_tensor); - }; - - dispatchRatio(params, std::integer_sequence{}, ratio_launcher); - }; - - dispatchDimDstate(params, std::integer_sequence{}, - std::integer_sequence{}, kernel_launcher); - } -#endif -} +} // namespace mtp } // namespace flashinfer::mamba +#include "selective_state_update_stp.cuh" + #endif // FLASHINFER_MAMBA_SELECTIVE_STATE_UPDATE_CUH_ diff --git a/include/flashinfer/mamba/selective_state_update_stp.cuh b/include/flashinfer/mamba/selective_state_update_stp.cuh new file mode 100644 index 0000000000..5c54d25af1 --- /dev/null +++ b/include/flashinfer/mamba/selective_state_update_stp.cuh @@ -0,0 +1,1105 @@ +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// #ifndef FLASHINFER_MAMBA_SELECTIVE_STATE_UPDATE_CUH_ +// #define FLASHINFER_MAMBA_SELECTIVE_STATE_UPDATE_CUH_ + +#include +#include +#include + +#include +#include +#include + +#include "../utils.cuh" +#include "../vec_dtypes.cuh" +#include "conversion.cuh" +#include "create_tensor_map.cuh" + +namespace flashinfer::mamba { + +using namespace conversion; + +// Computes a conflict-free column index for shared memory access. +// This permutation avoids bank conflicts when threads access strided patterns. +// +// Without permutation (baseCol directly): +// Thread 0 -> Bank 0, Thread 32 -> Bank 0, Thread 64 -> Bank 0 (conflict!) +// +// With permutation (adding bankCycle offset): +// bankCycle = which "round" of 32 banks we're in +// By offsetting each round by 1 bank: +// Thread 0 -> Bank 0 +// Thread 32 -> Bank 1 (offset by 1) +// Thread 64 -> Bank 2 (offset by 2) +// +// Visual: (stateValuesPerBank=1, numBanks=32, colsPerStage=128) +// baseCol: 0 1 2 ... 31 | 32 33 34 ... 63 | 64 ... +// bankCycle: 0 0 0 ... 0 | 1 1 1 ... 1 | 2 ... +// ii: 0 1 2 ... 31 | 33 34 35 ... 64 | 66 ... 
(mod colsPerStage) +template +__device__ __forceinline__ int conflict_free_column(int group, int baseCol) { + auto const seq_index = group * colsPerStage + baseCol; + auto const bankCycle = (seq_index / stateValuesPerBank) / numBanks; + return (baseCol + stateValuesPerBank * bankCycle) % colsPerStage; +} + +template +struct SharedStorageSimple { + alignas(alignof(PackedAligned)) input_t x[dim]; + float out[dim]; + alignas(alignof(PackedAligned)) input_t z[dim]; + alignas(alignof(PackedAligned)) input_t B[dstate]; + alignas(alignof(PackedAligned)) input_t C[dstate]; +}; + +template +__global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams params) { + auto* __restrict__ output = reinterpret_cast(params.output); + auto* __restrict__ state = reinterpret_cast(params.state); + + auto const* __restrict__ x = reinterpret_cast(params.x); + auto const* __restrict__ dt = reinterpret_cast(params.dt); + auto const* __restrict__ A = reinterpret_cast(params.A); + auto const* __restrict__ B = reinterpret_cast(params.B); + auto const* __restrict__ C = reinterpret_cast(params.C); + auto const* __restrict__ D = reinterpret_cast(params.D); // D: (nheads, dim) + auto const* __restrict__ dt_bias = reinterpret_cast(params.dt_bias); // (nheads) + auto const* __restrict__ z = reinterpret_cast(params.z); + auto const* __restrict__ state_batch_indices = + reinterpret_cast(params.state_batch_indices); + bool const dt_softplus = params.dt_softplus; + + int const nheads = params.nheads; + int const ngroups = params.ngroups; + + constexpr auto rowsPerWarp = (DIM + numWarps - 1) / numWarps; + + auto const batch = blockIdx.x; + auto const head = blockIdx.y; + auto const group = head / (nheads / ngroups); + auto lane = threadIdx.x % warpSize; + auto warp = threadIdx.y; + + auto const state_batch = (state_batch_indices) ? 
state_batch_indices[batch] : batch; + state += state_batch * params.state_stride_batch + head * DIM * DSTATE; + + __shared__ SharedStorageSimple sram; + + static constexpr auto stateLoadSize = getVectorLoadSizeForFullUtilization(); + using load_state_t = PackedAligned; + using load_input_t = PackedAligned; + + auto const A_value = toFloat(A[head]); + + auto dt_value = toFloat(dt[batch * params.dt_stride_batch + head]); + if (dt_bias) dt_value += toFloat(dt_bias[head]); + if (dt_softplus) { + dt_value = thresholded_softplus(dt_value); + } + + auto const dA = __expf(A_value * dt_value); + + auto d_value = D ? toFloat(D[head]) : 0.f; + + if (warp == 0) { + for (auto d = lane * load_input_t::count; d < DIM; d += warpSize * load_input_t::count) { + auto* dst = reinterpret_cast(&sram.x[d]); + *dst = *reinterpret_cast( + &x[batch * params.x_stride_batch + head * DIM + d]); + } + for (auto i = lane * load_input_t::count; i < DSTATE; i += warpSize * load_input_t::count) { + auto* dst = reinterpret_cast(&sram.B[i]); + *dst = *reinterpret_cast( + &B[batch * params.B_stride_batch + group * DSTATE + i]); + } + } else if (warp == 1) { // Load z, C + for (auto d = lane * load_input_t::count; d < DIM; d += warpSize * load_input_t::count) { + auto* dst = reinterpret_cast(&sram.z[d]); + *dst = z ? 
*reinterpret_cast( + &z[batch * params.z_stride_batch + head * DIM + d]) + : make_zeros(); + } + for (auto i = lane * load_input_t::count; i < DSTATE; i += warpSize * load_input_t::count) { + auto* dst = reinterpret_cast(&sram.C[i]); + *dst = *reinterpret_cast( + &C[batch * params.C_stride_batch + group * DSTATE + i]); + } + } + __syncthreads(); + + for (auto _d = warp * rowsPerWarp; _d < (warp + 1) * rowsPerWarp; _d++) { + auto d = _d; + if (d >= DIM) break; + + float x_value = toFloat(sram.x[_d]); + float out_value = d_value * x_value * int(lane == 0); // first lane has the value + + for (int i = lane * load_state_t::count; i < DSTATE; i += warpSize * load_state_t::count) { + auto rState = make_zeros(); + if (state_batch != params.pad_slot_id) + rState = *reinterpret_cast(&state[d * DSTATE + i]); + + for (int ii = 0; ii < load_state_t::count; ii++) { + auto state_value = toFloat(rState.val[ii]); + auto B_value = toFloat(sram.B[i + ii]); + auto C_value = toFloat(sram.C[i + ii]); + + auto const dB = B_value * dt_value; + auto const new_state = state_value * dA + dB * x_value; + + convertAndStore(&rState.val[ii], new_state); + + out_value += new_state * C_value; + } + if (state_batch != params.pad_slot_id) + *reinterpret_cast(&state[d * DSTATE + i]) = rState; + } + + // warpReduce the out_value + out_value = warpReduceSum(out_value); + if (lane == 0) { + sram.out[_d] = out_value; + } + } + + __syncthreads(); + + for (int l = lane; l < rowsPerWarp; l += warpSize) { + auto d = warp * rowsPerWarp + l; + if (d < DIM) { + auto out_value = sram.out[d]; + if (z) { + float z_value = toFloat(sram.z[d]); + float sig_z = __fdividef(1.f, (1.f + __expf(0.f - z_value))); + float silu_z = z_value * sig_z; + out_value *= silu_z; + } + convertAndStore(&output[batch * params.out_stride_batch + head * DIM + d], out_value); + } + } +} + +template +struct SharedStorageVertical { + alignas(128) state_t state[numStages][rowsPerStage * dstate]; + alignas(alignof(PackedAligned)) input_t 
x[dim]; + float out[dim]; // dt is special cause we're gonna store input in there as well + alignas(alignof(PackedAligned)) input_t z[dim]; + alignas(alignof(PackedAligned)) input_t B[dstate]; + alignas(alignof(PackedAligned)) input_t C[dstate]; + + using barrier_t = cuda::barrier; + barrier_t bar_empty[numStages]; + barrier_t bar_full[numStages]; + barrier_t bar_consumers; +}; + +template +__device__ __forceinline__ void producer_func_vertical(SramT& sram, CUtensorMap const& tensorState, + int batch, int head) { +#ifdef FLASHINFER_MAMBA_ENABLE_SM90 + namespace cde = cuda::device::experimental; + + auto constexpr stagesReadOnly = numStages; + auto constexpr stagesBoth = DIM / rowsPerStage - numStages; + auto constexpr stagesWriteOnly = numStages; + + auto constexpr bytesState = rowsPerStage * DSTATE * sizeof(state_t); + auto constexpr bytesToArrive = bytesState; + + // Phase 1: Read only (filling the pipeline) +#pragma unroll + for (int iter = 0; iter < stagesReadOnly; ++iter) { + auto const stage = iter % numStages; + auto const d = iter * rowsPerStage; + + sram.bar_empty[stage].wait(sram.bar_empty[stage].arrive()); + + if constexpr (readState) { + cde::cp_async_bulk_tensor_4d_global_to_shared(&sram.state[stage][0], &tensorState, 0, d, head, + batch, sram.bar_full[stage]); + + auto const _ = cuda::device::barrier_arrive_tx(sram.bar_full[stage], 1, bytesToArrive); + } else { + auto const _ = sram.bar_full[stage].arrive(); + } + } + + // Phase 2: Both read and write (steady state) +#pragma unroll + for (int iter = 0; iter < stagesBoth; ++iter) { + auto const stage = (stagesReadOnly + iter) % numStages; + auto const d_read = (stagesReadOnly + iter) * rowsPerStage; + auto const d_write = iter * rowsPerStage; + + sram.bar_empty[stage].wait(sram.bar_empty[stage].arrive()); + + if constexpr (readState || writeState) { + // Unblock async proxy for writeback + cde::fence_proxy_async_shared_cta(); + // Writeback + if constexpr (writeState) { + 
cde::cp_async_bulk_tensor_4d_shared_to_global(&tensorState, 0, d_write, head, batch, + &sram.state[stage][0]); + + cde::cp_async_bulk_commit_group(); + cde::cp_async_bulk_wait_group_read<0>(); + } + + // Read next + if constexpr (readState) { + cde::cp_async_bulk_tensor_4d_global_to_shared(&sram.state[stage][0], &tensorState, 0, + d_read, head, batch, sram.bar_full[stage]); + auto const _ = cuda::device::barrier_arrive_tx(sram.bar_full[stage], 1, bytesToArrive); + } else { + auto const _ = sram.bar_full[stage].arrive(); + } + } else { + auto const _ = sram.bar_full[stage].arrive(); + } + } + + // Phase 3: Write only (draining the pipeline) +#pragma unroll + for (int iter = 0; iter < stagesWriteOnly; ++iter) { + auto const stage = (stagesReadOnly + stagesBoth + iter) % numStages; + auto const d_write = (stagesBoth + iter) * rowsPerStage; + + sram.bar_empty[stage].wait(sram.bar_empty[stage].arrive()); + + if constexpr (writeState) { + // Unblock async proxy for writeback + cde::fence_proxy_async_shared_cta(); + cde::cp_async_bulk_tensor_4d_shared_to_global(&tensorState, 0, d_write, head, batch, + &sram.state[stage][0]); + + cde::cp_async_bulk_commit_group(); + cde::cp_async_bulk_wait_group_read<0>(); + } + } +#endif +} + +template +__device__ __forceinline__ void consumer_func_vertical( + int lane, int warp, float d_value, float dt_value, float dA, + SharedStorageVertical& sram) { +#ifdef FLASHINFER_MAMBA_ENABLE_SM90 + namespace cde = cuda::device::experimental; + for (auto dBegin = 0, stage = 0; dBegin < DIM; + dBegin += rowsPerStage, stage = (stage + 1) % numStages) { + // wait for the producer + sram.bar_full[stage].wait(sram.bar_full[stage].arrive()); + +#pragma unroll + for (auto dd = warp; dd < rowsPerStage; dd += consumerWarps) { + auto d = dBegin + dd; + float const x_value = toFloat(sram.x[d]); + float out_value = d_value * x_value * int(lane == 0); // first lane has the value + + constexpr auto bankSize = sizeof(uint32_t); + constexpr auto 
stateValuesPerBank = bankSize / sizeof(state_t); + + if constexpr (sizeof(state_t) == sizeof(input_t)) { + for (int i = lane * stateValuesPerBank; i < DSTATE; i += warpSize * stateValuesPerBank) { + auto* sState_ptr = reinterpret_cast(&sram.state[stage][dd * DSTATE + i]); + uint32_t rState = *sState_ptr; + auto* rState_ptr = reinterpret_cast(&rState); + + uint32_t rB = *reinterpret_cast(&sram.B[i]); + auto* rB_ptr = reinterpret_cast(&rB); + + uint32_t rC = *reinterpret_cast(&sram.C[i]); + auto* rC_ptr = reinterpret_cast(&rC); + + for (int e = 0; e < stateValuesPerBank; e++) { + float state_value; + if constexpr (!useStateCache) { + state_value = 0.f; + } else { + state_value = toFloat(rState_ptr[e]); + } + auto const B_value = toFloat(rB_ptr[e]); + auto const C_value = toFloat(rC_ptr[e]); + + auto const dB = B_value * dt_value; + auto const new_state = state_value * dA + dB * x_value; + + convertAndStore(&rState_ptr[e], new_state); + out_value += new_state * C_value; + } + *sState_ptr = rState; + } + } else { + for (int i = lane * stateValuesPerBank; i < DSTATE; i += warpSize * stateValuesPerBank) { + auto* sState_ptr = reinterpret_cast(&sram.state[stage][dd * DSTATE + i]); + uint32_t rState = *sState_ptr; + auto* rState_ptr = reinterpret_cast(&rState); + + for (int e = 0; e < stateValuesPerBank; e++) { + float state_value; + if constexpr (!useStateCache) { + state_value = 0.f; + } else { + state_value = toFloat(rState_ptr[e]); + } + auto const B_value = toFloat(sram.B[i + e]); + auto const C_value = toFloat(sram.C[i + e]); + auto const dB = B_value * dt_value; + auto const new_state = state_value * dA + dB * x_value; + + convertAndStore(&rState_ptr[e], new_state); + out_value += new_state * C_value; + } + *sState_ptr = rState; + } + } + + out_value = warpReduceSum(out_value); + if (lane == 0) { + sram.out[d] = out_value; + } + } + + // Unblock producer + cde::fence_proxy_async_shared_cta(); + auto _ = sram.bar_empty[stage].arrive(); + } +#endif +} + +template 
+__global__ void selective_state_update_kernel_producer_consumer_vertical( + SelectiveStateUpdateParams params, __grid_constant__ CUtensorMap const tensorState) { +#ifdef FLASHINFER_MAMBA_ENABLE_SM90 + auto* __restrict__ output = reinterpret_cast(params.output); + + auto const* __restrict__ x = reinterpret_cast(params.x); + auto const* __restrict__ dt = reinterpret_cast(params.dt); + auto const* __restrict__ A = reinterpret_cast(params.A); + auto const* __restrict__ B = reinterpret_cast(params.B); + auto const* __restrict__ C = reinterpret_cast(params.C); + auto const* __restrict__ D = reinterpret_cast(params.D); + auto const* __restrict__ dt_bias = reinterpret_cast(params.dt_bias); + auto const* __restrict__ z = reinterpret_cast(params.z); + auto const* __restrict__ state_batch_indices = + reinterpret_cast(params.state_batch_indices); + + int const nheads = params.nheads; + int const ngroups = params.ngroups; + + constexpr auto numWarps = 1 + consumerWarps; + + auto const batch = blockIdx.x; + auto const head = blockIdx.y; + auto const group = head / (nheads / ngroups); + auto lane = threadIdx.x % warpSize; + auto warp = threadIdx.y; + + auto const state_batch = (state_batch_indices) ? 
state_batch_indices[batch] : batch; + + extern __shared__ uint8_t sbuffer[]; + using sram_t = SharedStorageVertical; + auto& sram = *reinterpret_cast(sbuffer); + + namespace cde = cuda::device::experimental; + namespace cg = cooperative_groups; + + for (int stage = warp; stage < numStages; stage += numWarps) { + if (lane > 0) continue; + constexpr auto num_arrivals = 1 + consumerWarps * warpSize; + init(&sram.bar_empty[stage], num_arrivals); + init(&sram.bar_full[stage], num_arrivals); + // signal to async proxy that barriers are initilized + cde::fence_proxy_async_shared_cta(); + } + if (lane == 0 && warp == 0) { + init(&sram.bar_consumers, warpSize * consumerWarps); + } + __syncthreads(); + + if (warp == consumerWarps) // producer + { + // auto const state_offset = (state_batch * nheads + head) * DIM; + auto const read_state = (state_batch != params.pad_slot_id); + auto const write_state = read_state && params.update_state; + + if (lane == 0) { + cg::invoke_one(cg::coalesced_threads(), [&]() { + if (read_state && write_state) + producer_func_vertical( + sram, tensorState, state_batch, head); + else if (read_state && !write_state) + producer_func_vertical( + sram, tensorState, state_batch, head); + else + producer_func_vertical( + sram, tensorState, state_batch, head); + }); + } + } else { // consumers + + using load_t = PackedAligned; + +#pragma unroll + // Unblock the producer + for (uint8_t stage = 0; stage < numStages; ++stage) { + auto const _ = sram.bar_empty[stage].arrive(); + } + + // Load A + auto const A_value = toFloat(A[head]); + + // Load D + auto const d_value = D ? 
toFloat(D[head]) : 0.f; + + // load dt_value + auto dt_value = toFloat(dt[batch * params.dt_stride_batch + head]); + if (dt_bias) dt_value += toFloat(dt_bias[head]); + if (params.dt_softplus) { + dt_value = thresholded_softplus(dt_value); + } + auto const dA = __expf(A_value * dt_value); + + if (warp == 0) { // Load x, B + for (auto d = lane * load_t::count; d < DIM; d += warpSize * load_t::count) { + auto* dst = reinterpret_cast(&sram.x[d]); + *dst = *reinterpret_cast(&x[batch * params.x_stride_batch + head * DIM + d]); + } + for (auto i = lane * load_t::count; i < DSTATE; i += warpSize * load_t::count) { + auto* dst = reinterpret_cast(&sram.B[i]); + *dst = *reinterpret_cast( + &B[batch * params.B_stride_batch + group * DSTATE + i]); + } + } else if (warp == 1) { // Load z, C + for (auto d = lane * load_t::count; d < DIM; d += warpSize * load_t::count) { + auto* dst = reinterpret_cast(&sram.z[d]); + *dst = + z ? *reinterpret_cast(&z[batch * params.z_stride_batch + head * DIM + d]) + : make_zeros(); + } + for (auto i = lane * load_t::count; i < DSTATE; i += warpSize * load_t::count) { + auto* dst = reinterpret_cast(&sram.C[i]); + *dst = *reinterpret_cast( + &C[batch * params.C_stride_batch + group * DSTATE + i]); + } + } + + sram.bar_consumers.wait(sram.bar_consumers.arrive()); + + if (state_batch != params.pad_slot_id) + consumer_func_vertical(lane, warp, d_value, dt_value, dA, + sram); + else + consumer_func_vertical(lane, warp, d_value, dt_value, dA, + sram); + + // Write output + sram.bar_consumers.wait(sram.bar_consumers.arrive()); + auto d = warp * warpSize + lane; + if (d < DIM) { + auto out_value = sram.out[d]; + if (z) { + float z_value = toFloat(sram.z[d]); + float sig_z = __fdividef(1.f, (1.f + __expf(0.f - z_value))); + float silu_z = z_value * sig_z; + out_value *= silu_z; + } + convertAndStore(&output[batch * params.out_stride_batch + head * DIM + d], out_value); + } + } +#endif +} + +// 
============================================================================= +// Horizontal Producer-Consumer Kernel for SM100+ (Blackwell and newer) +// ============================================================================= + +#ifdef FLASHINFER_MAMBA_ENABLE_SM100 + +template +struct SharedStorageHorizontal { + alignas(128) state_t state[numStages][dim * stageCols]; + alignas(alignof(PackedAligned)) input_t B[dstate]; + alignas(alignof(PackedAligned)) input_t C[dstate]; + + using barrier_t = cuda::barrier; + barrier_t bar_empty[numStages]; + barrier_t bar_full[numStages]; + barrier_t bar_consumers; +}; + +template +__device__ __forceinline__ void producer_func_horizontal(SramT& sram, + CUtensorMap const& tensorState, int batch, + int head) { + namespace cde = cuda::device::experimental; + + auto constexpr stagesReadOnly = numStages; + auto constexpr stagesBoth = DSTATE / colsPerStage - numStages; + auto constexpr stagesWriteOnly = numStages; + + auto constexpr bytesState = DIM * colsPerStage * sizeof(state_t); + auto constexpr bytesToArrive = bytesState; + + // Phase 1: Read only (filling the pipeline) +#pragma unroll + for (int iter = 0; iter < stagesReadOnly; ++iter) { + auto const stage = iter % numStages; + auto const i = iter * colsPerStage; + + sram.bar_empty[stage].wait(sram.bar_empty[stage].arrive()); + + if constexpr (readState) { + cde::cp_async_bulk_tensor_4d_global_to_shared(&sram.state[stage][0], &tensorState, i, 0, head, + batch, sram.bar_full[stage]); + auto const _ = cuda::device::barrier_arrive_tx(sram.bar_full[stage], 1, bytesToArrive); + } else { + auto const _ = sram.bar_full[stage].arrive(); + } + } + + // Phase 2: Both read and write (steady state) +#pragma unroll + for (int iter = 0; iter < stagesBoth; ++iter) { + auto const stage = (stagesReadOnly + iter) % numStages; + auto const i_read = (stagesReadOnly + iter) * colsPerStage; + auto const i_write = iter * colsPerStage; + + 
sram.bar_empty[stage].wait(sram.bar_empty[stage].arrive()); + + if constexpr (readState || writeState) { + // Unblock async proxy for writeback + cde::fence_proxy_async_shared_cta(); + // Writeback + if constexpr (writeState) { + cde::cp_async_bulk_tensor_4d_shared_to_global(&tensorState, i_write, 0, head, batch, + &sram.state[stage][0]); + cde::cp_async_bulk_commit_group(); + cde::cp_async_bulk_wait_group_read<0>(); + } + + // Read next + if constexpr (readState) { + cde::cp_async_bulk_tensor_4d_global_to_shared(&sram.state[stage][0], &tensorState, i_read, + 0, head, batch, sram.bar_full[stage]); + auto const _ = cuda::device::barrier_arrive_tx(sram.bar_full[stage], 1, bytesToArrive); + } else { + auto const _ = sram.bar_full[stage].arrive(); + } + } else { + auto const _ = sram.bar_full[stage].arrive(); + } + } + + // Phase 3: Write only (draining the pipeline) +#pragma unroll + for (int iter = 0; iter < stagesWriteOnly; ++iter) { + auto const stage = (stagesReadOnly + stagesBoth + iter) % numStages; + auto const i_write = (stagesBoth + iter) * colsPerStage; + + sram.bar_empty[stage].wait(sram.bar_empty[stage].arrive()); + + if constexpr (writeState) { + // Unblock async proxy for writeback + cde::fence_proxy_async_shared_cta(); + cde::cp_async_bulk_tensor_4d_shared_to_global(&tensorState, i_write, 0, head, batch, + &sram.state[stage][0]); + cde::cp_async_bulk_commit_group(); + cde::cp_async_bulk_wait_group_read<0>(); + } + } +} + +template +__device__ __forceinline__ void consumer_func_horizontal( + int d, int member, float A_value, float dt_value, float x_value, + SharedStorageHorizontal& sram, + float& out_value) { + namespace cde = cuda::device::experimental; + constexpr auto lanesPerRow = (consumerWarps * warpSize) / DIM; + constexpr auto itemsPerThread = colsPerStage / lanesPerRow; + auto const group = d % (warpSize / lanesPerRow); + + // #pragma unroll 1 + for (int iBegin = 0, stage = 0; iBegin < DSTATE; + iBegin += colsPerStage, stage = (stage + 1) % 
numStages) { + // wait for the producer + sram.bar_full[stage].wait(sram.bar_full[stage].arrive()); + + constexpr auto bankSize = sizeof(uint32_t); + constexpr auto stateValuesPerBank = bankSize / sizeof(state_t); + constexpr auto numBanks = 32; + if constexpr (sizeof(state_t) == sizeof(input_t)) { +#pragma unroll + for (int item = 0; item < itemsPerThread; item += stateValuesPerBank) { + auto const baseCol = item + member * itemsPerThread; + // If I just use baseCol as the index, a lot of bank conflicts will arise. + // + auto const ii = + conflict_free_column(group, baseCol); + + auto const i = iBegin + ii; + + auto* sState_ptr = reinterpret_cast(&sram.state[stage][d * colsPerStage + ii]); + uint32_t rState = *sState_ptr; + auto* rState_ptr = reinterpret_cast(&rState); + + uint32_t rB = *reinterpret_cast(&sram.B[i]); + auto* rB_ptr = reinterpret_cast(&rB); + + uint32_t rC = *reinterpret_cast(&sram.C[i]); + auto* rC_ptr = reinterpret_cast(&rC); + + for (int e = 0; e < stateValuesPerBank; e++) { + float state_value; + if constexpr (!useStateCache) { + state_value = 0.f; + } else { + state_value = toFloat(rState_ptr[e]); + } + + auto const B_value = toFloat(rB_ptr[e]); + auto const C_value = toFloat(rC_ptr[e]); + + auto const dA = __expf(A_value * dt_value); + auto const dB = B_value * dt_value; + auto const new_state = state_value * dA + dB * x_value; + + convertAndStore(&rState_ptr[e], new_state); + out_value += new_state * C_value; + } + *sState_ptr = rState; + } + } else { + for (int item = 0; item < itemsPerThread; item += stateValuesPerBank) { + auto const baseCol = item + member * itemsPerThread; + auto const ii = + conflict_free_column(group, baseCol); + auto const i = iBegin + ii; + + auto* sState_ptr = reinterpret_cast(&sram.state[stage][d * colsPerStage + ii]); + uint32_t rState = *sState_ptr; + auto* rState_ptr = reinterpret_cast(&rState); + + for (int e = 0; e < stateValuesPerBank; e++) { + float state_value; + if constexpr (!useStateCache) { + 
state_value = 0.f; + } else { + state_value = toFloat(rState_ptr[e]); + } + + auto const B_value = toFloat(sram.B[i + e]); + auto const C_value = toFloat(sram.C[i + e]); + + auto const dA = __expf(A_value * dt_value); + auto const dB = B_value * dt_value; + auto const new_state = state_value * dA + dB * x_value; + + convertAndStore(&rState_ptr[e], new_state); + out_value += new_state * C_value; + } + *sState_ptr = rState; + } + } + + auto _ = sram.bar_empty[stage].arrive(); + } +} + +template +__global__ void selective_state_update_kernel_producer_consumer_horizontal( + SelectiveStateUpdateParams params, __grid_constant__ CUtensorMap const tensorState) { + auto* __restrict__ output = reinterpret_cast(params.output); + auto const* __restrict__ x = reinterpret_cast(params.x); + auto const* __restrict__ dt = reinterpret_cast(params.dt); + auto const* __restrict__ A = reinterpret_cast(params.A); + auto const* __restrict__ B = reinterpret_cast(params.B); + auto const* __restrict__ C = reinterpret_cast(params.C); + auto const* __restrict__ D = reinterpret_cast(params.D); + auto const* __restrict__ dt_bias = reinterpret_cast(params.dt_bias); + auto const* __restrict__ z = reinterpret_cast(params.z); + auto const* __restrict__ state_batch_indices = + reinterpret_cast(params.state_batch_indices); + + int const nheads = params.nheads; + + constexpr auto numWarps = 1 + consumerWarps; + + auto const batch = blockIdx.x; + auto const head = blockIdx.y; + auto const group = head / headsGroupsRatio; + auto lane = threadIdx.x % warpSize; + auto warp = threadIdx.y; + + auto const state_batch = (state_batch_indices) ? 
state_batch_indices[batch] : batch; +
+  extern __shared__ uint8_t sbuffer[];
+  using sram_t = SharedStorageHorizontal;
+  auto& sram = *reinterpret_cast(sbuffer);
+
+  namespace cde = cuda::device::experimental;
+  namespace cg = cooperative_groups;
+
+  for (int stage = warp; stage < numStages; stage += numWarps) {
+    if (lane > 0) continue;
+    constexpr auto num_arrivals = 1 + consumerWarps * warpSize;
+    init(&sram.bar_empty[stage], num_arrivals);
+    init(&sram.bar_full[stage], num_arrivals);
+    // signal to async proxy that barriers are initialized
+    cde::fence_proxy_async_shared_cta();
+  }
+  if (lane == 0 && warp == 0) {
+    init(&sram.bar_consumers, warpSize * consumerWarps);
+  }
+  __syncthreads();
+
+  if (warp == consumerWarps)  // producer
+  {
+    auto const read_state = (state_batch != params.pad_slot_id);
+    auto const write_state = read_state && params.update_state;
+
+    cg::invoke_one(cg::coalesced_threads(), [&]() {
+      if (read_state && write_state)
+        producer_func_horizontal(
+            sram, tensorState, state_batch, head);
+      else if (read_state && !write_state)
+        producer_func_horizontal(
+            sram, tensorState, state_batch, head);
+      else
+        producer_func_horizontal(
+            sram, tensorState, state_batch, head);
+    });
+  } else {  // consumers
+
+    using load_t = PackedAligned;
+
+    // Unblock the producer
+#pragma unroll
+    for (auto stage = 0; stage < numStages; ++stage) {
+      auto const _ = sram.bar_empty[stage].arrive();
+    }
+
+    // Load A
+    auto const A_value = toFloat(A[head]);
+
+    // Load D
+    auto const d_value = D ? 
toFloat(D[head]) : 0.f; + + // load dt_value + auto dt_value = toFloat(dt[batch * params.dt_stride_batch + head]); + if (dt_bias) dt_value += toFloat(dt_bias[head]); + if (params.dt_softplus) { + dt_value = thresholded_softplus(dt_value); + } + + if (warp == 0) { // Load B + for (auto d = lane * load_t::count; d < DSTATE; d += warpSize * load_t::count) { + auto* dst = reinterpret_cast(&sram.B[d]); + *dst = *reinterpret_cast( + &B[batch * params.B_stride_batch + group * DSTATE + d]); + } + } else if (warp == 1) { // Load C + for (auto i = lane * load_t::count; i < DSTATE; i += warpSize * load_t::count) { + auto* dst = reinterpret_cast(&sram.C[i]); + *dst = *reinterpret_cast( + &C[batch * params.C_stride_batch + group * DSTATE + i]); + } + } + + constexpr auto lanesPerRow = (consumerWarps * warpSize) / DIM; + static_assert(lanesPerRow >= 1); + constexpr auto rowsPerWarp = warpSize / lanesPerRow; + auto const group = lane % rowsPerWarp; + auto const member = lane / rowsPerWarp; + auto const d = warp * rowsPerWarp + group; + auto const x_value = toFloat(x[batch * params.x_stride_batch + head * DIM + d]); + auto const z_value = z ? 
toFloat(z[batch * params.z_stride_batch + head * DIM + d]) : 0.f; + + sram.bar_consumers.wait(sram.bar_consumers.arrive()); + + // Thread + float out_value = 0.f; + if (state_batch != params.pad_slot_id) + consumer_func_horizontal(d, member, A_value, dt_value, x_value, + sram, out_value); + else + consumer_func_horizontal(d, member, A_value, dt_value, + x_value, sram, out_value); + + out_value += __shfl_down_sync(UINT32_MAX, out_value, 16); + if constexpr (lanesPerRow == 4) { + out_value += __shfl_down_sync(UINT32_MAX, out_value, 8); + } + + if (member == 0) { + out_value += d_value * x_value; + + // Write output + if (z) { + float sig_z = __fdividef(1.f, (1.f + __expf(0.f - z_value))); + float silu_z = z_value * sig_z; + out_value *= silu_z; + } + convertAndStore(&output[batch * params.out_stride_batch + head * DIM + d], out_value); + } + } +} + +#endif // FLASHINFER_MAMBA_ENABLE_SM100 + +template +std::string format_array(const T (&arr)[N]) { + std::ostringstream oss; + for (size_t i = 0; i < N; ++i) { + if (i > 0) oss << ", "; + oss << arr[i]; + } + return oss.str(); +} + +// Helper function to dispatch dim and dstate with a kernel launcher +template +void dispatchDimDstate(SelectiveStateUpdateParams& params, + std::integer_sequence, + std::integer_sequence, KernelLauncher&& launcher) { + constexpr int allowed_dims[] = {AllowedDims...}; + constexpr int allowed_dstates[] = {AllowedDstates...}; + + auto dispatch_dim_dstate = [&]() { + launcher.template operator()(); + }; + + auto dispatch_dstate = [&]() { + if (params.dstate == DSTATE) { + dispatch_dim_dstate.template operator()(); + return true; + } + return false; + }; + + auto dispatch_dim = [&]() { + if (params.dim == DIM) { + bool dispatched = (dispatch_dstate.template operator()() || ...); + FLASHINFER_CHECK(dispatched, "Unsupported dstate value: ", params.dstate, + ".\nSupported values: ", format_array(allowed_dstates)); + return true; + } + return false; + }; + + bool dim_dispatched = 
(dispatch_dim.template operator()() || ...); + FLASHINFER_CHECK(dim_dispatched, "Unsupported dim value: ", params.dim, + ".\nSupported values: ", format_array(allowed_dims)); +} + +// Helper function to dispatch ratio with a kernel launcher +template +void dispatchRatio(SelectiveStateUpdateParams& params, std::integer_sequence, + KernelLauncher&& launcher) { + constexpr int allowed_ratios[] = {AllowedRatios...}; + + auto dispatch_single_ratio = [&]() { + if (params.nheads / params.ngroups == RATIO) { + launcher.template operator()(); + return true; + } + return false; + }; + + bool ratio_dispatched = (dispatch_single_ratio.template operator()() || ...); + FLASHINFER_CHECK(ratio_dispatched, + "Unsupported nheads/ngroups ratio: ", params.nheads / params.ngroups, + ".\nSupported values: ", format_array(allowed_ratios)); +} + +// Check alignment for common input variables (x, z, B, C) +template +void check_ptr_alignment_input_vars(const SelectiveStateUpdateParams& params) { + using load_input_t = PackedAligned; + FLASHINFER_CHECK(reinterpret_cast(params.x) % sizeof(load_input_t) == 0, + "x pointer must be aligned to ", sizeof(load_input_t), " bytes"); + FLASHINFER_CHECK((params.x_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0, + "x batch stride must be aligned to ", sizeof(load_input_t), " bytes"); + if (params.z) { + FLASHINFER_CHECK(reinterpret_cast(params.z) % sizeof(load_input_t) == 0, + "z pointer must be aligned to ", sizeof(load_input_t), " bytes"); + FLASHINFER_CHECK((params.z_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0, + "z batch stride must be aligned to ", sizeof(load_input_t), " bytes"); + } + FLASHINFER_CHECK(reinterpret_cast(params.B) % sizeof(load_input_t) == 0, + "B pointer must be aligned to ", sizeof(load_input_t), " bytes"); + FLASHINFER_CHECK(reinterpret_cast(params.C) % sizeof(load_input_t) == 0, + "C pointer must be aligned to ", sizeof(load_input_t), " bytes"); + FLASHINFER_CHECK((params.B_stride_batch * 
sizeof(input_t)) % sizeof(load_input_t) == 0, + "B batch stride must be aligned to ", sizeof(load_input_t), " bytes"); + FLASHINFER_CHECK((params.C_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0, + "C batch stride must be aligned to ", sizeof(load_input_t), " bytes"); +} + +template +void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, cudaStream_t stream) { + auto [sm_major, sm_minor] = GetCudaComputeCapability(); + + constexpr int allowed_dstates[] = {64, 128, 256}; + constexpr int allowed_dims[] = {64, 128, 256}; + + // Common alignment checks for all kernels + check_ptr_alignment_input_vars(params); + +#ifdef FLASHINFER_MAMBA_ENABLE_SM100 + if (sm_major < 10) // pre-Blackwell +#elif defined(FLASHINFER_MAMBA_ENABLE_SM90) + if (sm_major < 9) // pre-Hopper +#endif + { + auto kernel_launcher = [&]() { + // Additional alignment checks specific to simple kernel + constexpr auto stateLoadSize = getVectorLoadSizeForFullUtilization(); + using load_state_t = PackedAligned; + + FLASHINFER_CHECK(reinterpret_cast(params.state) % sizeof(load_state_t) == 0, + "state pointer must be aligned to ", sizeof(load_state_t), " bytes"); + FLASHINFER_CHECK((params.dim * params.dstate * sizeof(state_t)) % sizeof(load_state_t) == 0, + "state head stride must be aligned to ", sizeof(load_state_t), " bytes"); + + constexpr int numWarps = 4; + dim3 block(warpSize, numWarps); + dim3 grid(params.batch, params.nheads); + selective_state_update_kernel_simple<<>>(params); + }; + + dispatchDimDstate(params, std::integer_sequence{}, + std::integer_sequence{}, kernel_launcher); + } +#ifdef FLASHINFER_MAMBA_ENABLE_SM90 + else { + + auto kernel_launcher = [&]() { + // Note: State uses TMA which requires 128B alignment (checked below) + constexpr auto numConsumers = 4; + constexpr auto numWarps = 1 + numConsumers; + constexpr auto numStages = 3; + constexpr auto rowsPerStage = 4 * numConsumers; + FLASHINFER_CHECK(params.dim % rowsPerStage == 0, "dim must be divisible by ", 
rowsPerStage, +
+                       " for SM90+ kernel");
+      auto scan_func = selective_state_update_kernel_producer_consumer_vertical<
+          input_t, weight_t, matrixA_t, state_t, stateIndex_t, DIM, DSTATE, numConsumers,
+          rowsPerStage, numStages>;
+
+      dim3 block(warpSize, numWarps);
+      dim3 grid(params.batch, params.nheads);
+
+      auto nh = params.nheads;
+      auto dim = params.dim;
+
+      auto state_tensor =
+          tma::buildNdDescriptor(typeid(state_t),
+                                 /*shapes*/ {DSTATE, DIM, nh, params.state_cache_size},
+                                 /*strides*/ {1, DSTATE, DSTATE * DIM, params.state_stride_batch},
+                                 /*tiles*/ {DSTATE, rowsPerStage, 1, 1}, params.state);
+
+      // Calculate shared memory size and opt-in to extended shared memory
+      using sram_t = SharedStorageVertical;
+      constexpr size_t smem_size = sizeof(sram_t);
+      FLASHINFER_CUDA_CHECK(
+          cudaFuncSetAttribute(scan_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+
+      scan_func<<>>(params, state_tensor);
+    };
+
+    dispatchDimDstate(params, std::integer_sequence{},
+                      std::integer_sequence{}, kernel_launcher);
+  }
+#endif
+
+#ifdef FLASHINFER_MAMBA_ENABLE_SM100
+  else {
+    // SM100+ (Blackwell and newer) uses horizontal producer-consumer kernel
+    auto kernel_launcher = [&]() {
+      // profiling showed that it's good to have 4 consumers per 64 rows
+      constexpr auto numConsumers = (DIM / 64) * 4;
+      constexpr auto numProducers = 1;
+      constexpr auto numWarps = numProducers + numConsumers;
+
+      constexpr auto sectorSize = 32;  // bytes
+      constexpr auto stageCols = 2 * sectorSize / sizeof(state_t);
+
+      constexpr auto totalStages = DSTATE / stageCols;
+      constexpr auto numStages = (totalStages >= 4) ? 
4 : totalStages; + + auto ratio_launcher = [&]() { + auto scan_func = selective_state_update_kernel_producer_consumer_horizontal< + input_t, weight_t, matrixA_t, state_t, stateIndex_t, DIM, DSTATE, numConsumers, + stageCols, RATIO, numStages>; + + dim3 block(warpSize, numWarps); + dim3 grid(params.batch, params.nheads); + + auto nh = params.nheads; + auto dim = params.dim; + + auto state_tensor = + tma::buildNdDescriptor(typeid(state_t), + /*shapes*/ {DSTATE, DIM, nh, params.state_cache_size}, + /*strides*/ {1, DSTATE, DSTATE * DIM, params.state_stride_batch}, + /*tiles*/ {stageCols, DIM, 1, 1}, params.state); + static_assert(DSTATE % stageCols == 0 && DSTATE >= stageCols); + + using sram_t = SharedStorageHorizontal; + constexpr size_t smem_size = sizeof(sram_t); + FLASHINFER_CUDA_CHECK(cudaFuncSetAttribute( + scan_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + scan_func<<>>(params, state_tensor); + }; + + dispatchRatio(params, std::integer_sequence{}, ratio_launcher); + }; + + dispatchDimDstate(params, std::integer_sequence{}, + std::integer_sequence{}, kernel_launcher); + } +#endif +} + +} // namespace flashinfer::mamba diff --git a/tests/mamba/test_selective_state_update_mtp.py b/tests/mamba/test_selective_state_update_mtp.py new file mode 100644 index 0000000000..afa1b6d202 --- /dev/null +++ b/tests/mamba/test_selective_state_update_mtp.py @@ -0,0 +1,714 @@ +""" +Multi-Token Prediction (MTP) tests for selective_state_update. + +These tests verify the selective_state_update kernel works correctly with +multi-token inputs (batch, T, nheads, dim) for speculative decoding scenarios. 
+""" + +import numpy as np +import pytest +import torch + +import flashinfer + +from .selective_state_update_triton import selective_state_update_triton +from .test_utils import create_test_inputs, clone_preserving_strides + + +class TestSelectiveStateUpdateMTP: + """Test class for multi-token selective state update kernels.""" + + # Test configuration + ATOL = 1e-3 + RTOL = 1e-2 + NGROUPS = 8 + INPUT_DTYPE = torch.bfloat16 + MATRIX_A_DTYPE = torch.float32 + + @pytest.fixture(params=[1, 4]) + def batch(self, request): + return request.param + + @pytest.fixture(params=[8, 32]) + def nheads(self, request): + return request.param + + @pytest.fixture(params=[64, 128]) + def dim(self, request): + return request.param + + @pytest.fixture(params=[64, 128]) + def dstate(self, request): + return request.param + + @pytest.fixture(params=[1, 4, 8]) + def cache_steps(self, request): + """Number of tokens in multi-token mode (T dimension).""" + return request.param + + @pytest.fixture(params=[torch.float32, torch.bfloat16]) + def state_dtype(self, request): + return request.param + + @pytest.fixture(params=[torch.float32]) + def weight_dtype(self, request): + return request.param + + @pytest.fixture(params=[False, True]) + def use_out_tensor(self, request): + return request.param + + @pytest.fixture + def inputs( + self, batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype + ): + """Create test inputs for given parameters.""" + return create_test_inputs( + batch, + nheads, + dim, + dstate, + self.NGROUPS, + self.INPUT_DTYPE, + weight_dtype=weight_dtype, + matrixA_dtype=self.MATRIX_A_DTYPE, + state_dtype=state_dtype, + generate_z=False, + generate_intermediate_states_buffer=False, + cache_steps=cache_steps, + seed=0, + ) + + @pytest.fixture + def reference_output(self, inputs): + """Compute reference output using triton implementation.""" + state_ref = clone_preserving_strides(inputs["state_cache"]) + y_ref = selective_state_update_triton( + state_ref, + 
inputs["x"], + inputs["dt"], + inputs["A"], + inputs["B"], + inputs["C"], + D=inputs["D"], + z=inputs.get("z"), + dt_bias=inputs["dt_bias"], + dt_softplus=True, + state_batch_indices=inputs["slot_idx"], + pad_slot_id=-1, + ) + return y_ref, state_ref + + def run_kernel(self, inputs, out=None, disable_state_update=False): + """Run the flashinfer kernel and return output.""" + return flashinfer.mamba.selective_state_update( + inputs["state_cache"], + inputs["x"], + inputs["dt"], + inputs["A"], + inputs["B"], + inputs["C"], + D=inputs["D"], + z=inputs.get("z"), + dt_bias=inputs["dt_bias"], + dt_softplus=True, + state_batch_indices=inputs["slot_idx"], + pad_slot_id=-1, + out=out, + disable_state_update=disable_state_update, + ) + + def assert_outputs_match(self, y_ref, y_test, msg_prefix=""): + """Assert outputs match with detailed error reporting.""" + outputs_match = torch.allclose(y_ref, y_test, atol=self.ATOL, rtol=self.RTOL) + + if outputs_match: + print( + f"✓ {msg_prefix}Outputs match within tolerance (atol={self.ATOL}, rtol={self.RTOL})" + ) + else: + print( + f"✗ {msg_prefix}Outputs do NOT match within tolerance (atol={self.ATOL}, rtol={self.RTOL})" + ) + self._print_mismatch_details(y_ref, y_test, "output") + + assert outputs_match + + def assert_states_match(self, state_ref, state_test, slot_idx, msg_prefix=""): + """Assert states match with detailed error reporting.""" + state_ref_batch = state_ref[slot_idx] + state_test_batch = state_test[slot_idx] + states_match = torch.allclose( + state_ref_batch, state_test_batch, atol=self.ATOL, rtol=self.RTOL + ) + + if states_match: + print( + f"✓ {msg_prefix}States match within tolerance (atol={self.ATOL}, rtol={self.RTOL})" + ) + else: + print( + f"✗ {msg_prefix}States do NOT match within tolerance (atol={self.ATOL}, rtol={self.RTOL})" + ) + self._print_mismatch_details(state_ref_batch, state_test_batch, "state") + + assert states_match + + def _print_mismatch_details(self, ref, test, name): + """Print detailed 
mismatch analysis.""" + ref_np = ref.detach().cpu().float().numpy() + test_np = test.detach().cpu().float().numpy() + + mismatch_mask = ~np.isclose(ref_np, test_np, atol=self.ATOL, rtol=self.RTOL) + num_mismatches = np.sum(mismatch_mask) + total_elements = ref_np.size + + print(f"\nDetailed {name} mismatch analysis:") + print( + f"Number of mismatched elements: {num_mismatches} / {total_elements} " + f"({100 * num_mismatches / total_elements:.2f}%)" + ) + + mismatch_indices = np.argwhere(mismatch_mask) + print(f"First few {name} mismatch locations (up to 10):") + for idx in mismatch_indices[:10]: + idx_tuple = tuple(int(i) for i in idx) + ref_val = ref_np[idx_tuple] + test_val = test_np[idx_tuple] + diff = abs(ref_val - test_val) + rel_diff = diff / (abs(ref_val) + 1e-8) + print( + f" Index {idx_tuple}: ref={ref_val:.6f}, test={test_val:.6f}, " + f"diff={diff:.6e}, rel_diff={rel_diff:.6e}" + ) + + def test_output_correctness(self, inputs, reference_output, use_out_tensor): + """Test that kernel output matches reference within tolerance.""" + y_ref, state_ref = reference_output + + # Prepare output tensor if requested + if use_out_tensor: + out = torch.empty_like(inputs["x"]) + else: + out = None + + y_test = self.run_kernel(inputs, out=out) + + # Verify output tensor identity if provided + if use_out_tensor: + assert y_test.data_ptr() == out.data_ptr(), ( + "Returned tensor should be the same object as the provided output tensor" + ) + + self.assert_outputs_match(y_ref, y_test) + self.assert_states_match(state_ref, inputs["state_cache"], inputs["slot_idx"]) + + +class TestSelectiveStateUpdateMTPWithZ(TestSelectiveStateUpdateMTP): + """Test multi-token selective_state_update with z tensor (gating).""" + + @pytest.fixture(params=[4]) + def batch(self, request): + return request.param + + @pytest.fixture(params=[8]) + def nheads(self, request): + return request.param + + @pytest.fixture(params=[64]) + def dim(self, request): + return request.param + + 
@pytest.fixture(params=[64]) + def dstate(self, request): + return request.param + + @pytest.fixture(params=[4]) + def cache_steps(self, request): + return request.param + + @pytest.fixture(params=[torch.bfloat16]) + def state_dtype(self, request): + return request.param + + @pytest.fixture(params=[torch.float32]) + def weight_dtype(self, request): + return request.param + + @pytest.fixture + def inputs( + self, batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype + ): + """Create test inputs with z tensor.""" + return create_test_inputs( + batch, + nheads, + dim, + dstate, + self.NGROUPS, + self.INPUT_DTYPE, + weight_dtype=weight_dtype, + matrixA_dtype=self.MATRIX_A_DTYPE, + state_dtype=state_dtype, + generate_z=True, + generate_intermediate_states_buffer=False, + cache_steps=cache_steps, + seed=0, + ) + + +class TestSelectiveStateUpdateMTPDisableStateUpdate(TestSelectiveStateUpdateMTP): + """Test multi-token selective_state_update with disable_state_update=True.""" + + @pytest.fixture(params=[4]) + def batch(self, request): + return request.param + + @pytest.fixture(params=[32]) + def nheads(self, request): + return request.param + + @pytest.fixture(params=[64]) + def dim(self, request): + return request.param + + @pytest.fixture(params=[64, 128]) + def dstate(self, request): + return request.param + + @pytest.fixture(params=[4, 8]) + def cache_steps(self, request): + return request.param + + @pytest.fixture(params=[torch.bfloat16]) + def state_dtype(self, request): + return request.param + + @pytest.fixture(params=[torch.float32]) + def weight_dtype(self, request): + return request.param + + def test_output_correctness(self, inputs, reference_output, use_out_tensor): + """Test that kernel output matches reference but state is not updated.""" + y_ref, state_ref = reference_output + + # Save the initial state before running the kernel + state_initial = inputs["state_cache"].clone() + + # Prepare output tensor if requested + if use_out_tensor: + out = 
torch.empty_like(inputs["x"]) + else: + out = None + + y_test = self.run_kernel(inputs, out=out, disable_state_update=True) + + # Verify output tensor identity if provided + if use_out_tensor: + assert y_test.data_ptr() == out.data_ptr(), ( + "Returned tensor should be the same object as the provided output tensor" + ) + + # Check that output is still correct + self.assert_outputs_match(y_ref, y_test, msg_prefix="[disable_state_update] ") + + # Check that state was NOT updated (should remain the same as initial) + state_after = inputs["state_cache"] + state_unchanged = torch.allclose( + state_initial, state_after, atol=1e-8, rtol=1e-8 + ) + + if state_unchanged: + print("✓ [disable_state_update] State cache was not modified (as expected)") + else: + print( + "✗ [disable_state_update] State cache was modified (should remain unchanged!)" + ) + # Show where state changed + state_initial_np = state_initial.detach().cpu().float().numpy() + state_after_np = state_after.detach().cpu().float().numpy() + mismatch_mask = ~np.isclose( + state_initial_np, state_after_np, atol=1e-8, rtol=1e-8 + ) + num_changed = np.sum(mismatch_mask) + print( + f"Number of changed state elements: {num_changed} / {state_initial_np.size}" + ) + + assert state_unchanged, ( + "State should not be updated when disable_state_update=True" + ) + + +class TestSelectiveStateUpdateMTPWithIntermediateStates(TestSelectiveStateUpdateMTP): + """Test multi-token selective_state_update with intermediate states buffer.""" + + @pytest.fixture(params=[4]) + def batch(self, request): + return request.param + + @pytest.fixture(params=[32]) + def nheads(self, request): + return request.param + + @pytest.fixture(params=[64]) + def dim(self, request): + return request.param + + @pytest.fixture(params=[64, 128]) + def dstate(self, request): + return request.param + + @pytest.fixture(params=[2, 4, 8]) + def cache_steps(self, request): + return request.param + + @pytest.fixture(params=[torch.bfloat16]) + def 
state_dtype(self, request): + return request.param + + @pytest.fixture(params=[torch.float32]) + def weight_dtype(self, request): + return request.param + + @pytest.fixture + def inputs( + self, batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype + ): + """Create test inputs with intermediate states buffer.""" + return create_test_inputs( + batch, + nheads, + dim, + dstate, + self.NGROUPS, + self.INPUT_DTYPE, + weight_dtype=weight_dtype, + matrixA_dtype=self.MATRIX_A_DTYPE, + state_dtype=state_dtype, + generate_z=False, + generate_intermediate_states_buffer=True, + cache_steps=cache_steps, + seed=0, + ) + + @pytest.fixture + def reference_output(self, inputs): + """Compute reference output using triton implementation with intermediate states.""" + state_ref = clone_preserving_strides(inputs["state_cache"]) + intermediate_states_ref = inputs["intermediate_states_buffer"].clone() + + y_ref = selective_state_update_triton( + state_ref, + inputs["x"], + inputs["dt"], + inputs["A"], + inputs["B"], + inputs["C"], + D=inputs["D"], + z=inputs.get("z"), + dt_bias=inputs["dt_bias"], + dt_softplus=True, + state_batch_indices=inputs["slot_idx"], + pad_slot_id=-1, + disable_state_update=True, + intermediate_states_buffer=intermediate_states_ref, + cache_steps=inputs["cache_steps"], + intermediate_state_indices=inputs["intermediate_slot_idx"], + ) + return y_ref, state_ref, intermediate_states_ref + + def run_kernel_with_intermediate_states(self, inputs, out=None): + """Run the flashinfer kernel with intermediate states buffer.""" + return flashinfer.mamba.selective_state_update( + inputs["state_cache"], + inputs["x"], + inputs["dt"], + inputs["A"], + inputs["B"], + inputs["C"], + D=inputs["D"], + z=inputs.get("z"), + dt_bias=inputs["dt_bias"], + dt_softplus=True, + state_batch_indices=inputs["slot_idx"], + pad_slot_id=-1, + out=out, + disable_state_update=True, + intermediate_states_buffer=inputs["intermediate_states_buffer"], + 
intermediate_state_indices=inputs["intermediate_slot_idx"], + cache_steps=inputs["cache_steps"], + ) + + def test_output_correctness(self, inputs, reference_output, use_out_tensor): + """Test that kernel output matches and intermediate states are cached correctly.""" + y_ref, state_ref, intermediate_states_ref = reference_output + + # Prepare output tensor if requested + if use_out_tensor: + out = torch.empty_like(inputs["x"]) + else: + out = None + + y_test = self.run_kernel_with_intermediate_states(inputs, out=out) + + # Verify output tensor identity if provided + if use_out_tensor: + assert y_test.data_ptr() == out.data_ptr(), ( + "Returned tensor should be the same object as the provided output tensor" + ) + + # Check output + self.assert_outputs_match(y_ref, y_test, msg_prefix="[intermediate_states] ") + + # Check intermediate states were cached correctly + cache_steps = inputs["cache_steps"] + intermediate_states_test = inputs["intermediate_states_buffer"] + + for t in range(cache_steps): + cached_state_ref = intermediate_states_ref[:, t, :, :, :] + cached_state_test = intermediate_states_test[:, t, :, :, :] + + states_match = torch.allclose( + cached_state_ref, cached_state_test, atol=self.ATOL, rtol=self.RTOL + ) + + max_diff = (cached_state_ref - cached_state_test).abs().max().item() + if states_match: + print(f"✓ Intermediate state {t} matches (max_diff={max_diff:.6e})") + else: + print(f"✗ Intermediate state {t} mismatch (max_diff={max_diff:.6e})") + self._print_mismatch_details( + cached_state_ref, cached_state_test, f"intermediate_state_{t}" + ) + + assert states_match, f"Intermediate state at step {t} mismatch" + + +class TestSelectiveStateUpdateMTPNonContiguous(TestSelectiveStateUpdateMTP): + """Test multi-token selective_state_update with non-contiguous state cache.""" + + @pytest.fixture(params=[4]) + def batch(self, request): + return request.param + + @pytest.fixture(params=[8]) + def nheads(self, request): + return request.param + + 
@pytest.fixture(params=[64]) + def dim(self, request): + return request.param + + @pytest.fixture(params=[64]) + def dstate(self, request): + return request.param + + @pytest.fixture(params=[4]) + def cache_steps(self, request): + return request.param + + @pytest.fixture(params=[torch.bfloat16]) + def state_dtype(self, request): + return request.param + + @pytest.fixture(params=[torch.float32]) + def weight_dtype(self, request): + return request.param + + @pytest.fixture + def inputs( + self, batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype + ): + """Create test inputs with non-contiguous state cache (2x batch stride).""" + noncontiguous_batch_stride = 2 * nheads * dim * dstate + + return create_test_inputs( + batch, + nheads, + dim, + dstate, + self.NGROUPS, + self.INPUT_DTYPE, + weight_dtype=weight_dtype, + matrixA_dtype=self.MATRIX_A_DTYPE, + state_dtype=state_dtype, + generate_z=False, + generate_intermediate_states_buffer=False, + state_cache_batch_stride=noncontiguous_batch_stride, + cache_steps=cache_steps, + seed=0, + ) + + @pytest.fixture + def reference_output(self, inputs): + """Compute reference output, preserving non-contiguous strides.""" + state_ref = clone_preserving_strides(inputs["state_cache"]) + y_ref = selective_state_update_triton( + state_ref, + inputs["x"], + inputs["dt"], + inputs["A"], + inputs["B"], + inputs["C"], + D=inputs["D"], + z=inputs.get("z"), + dt_bias=inputs["dt_bias"], + dt_softplus=True, + state_batch_indices=inputs["slot_idx"], + pad_slot_id=-1, + ) + return y_ref, state_ref + + +class TestSelectiveStateUpdateMTPInt32Indices(TestSelectiveStateUpdateMTP): + """Test multi-token selective_state_update with int32 state_batch_indices.""" + + @pytest.fixture(params=[4]) + def batch(self, request): + return request.param + + @pytest.fixture(params=[8]) + def nheads(self, request): + return request.param + + @pytest.fixture(params=[64]) + def dim(self, request): + return request.param + + 
@pytest.fixture(params=[64]) + def dstate(self, request): + return request.param + + @pytest.fixture(params=[4]) + def cache_steps(self, request): + return request.param + + @pytest.fixture(params=[torch.bfloat16]) + def state_dtype(self, request): + return request.param + + @pytest.fixture(params=[torch.float32]) + def weight_dtype(self, request): + return request.param + + def run_kernel(self, inputs, out=None, disable_state_update=False): + """Run the flashinfer kernel with int32 state_batch_indices.""" + # Cast slot_idx to int32 + slot_idx_int32 = inputs["slot_idx"].to(torch.int32) + + return flashinfer.mamba.selective_state_update( + inputs["state_cache"], + inputs["x"], + inputs["dt"], + inputs["A"], + inputs["B"], + inputs["C"], + D=inputs["D"], + z=inputs.get("z"), + dt_bias=inputs["dt_bias"], + dt_softplus=True, + state_batch_indices=slot_idx_int32, + pad_slot_id=-1, + out=out, + disable_state_update=disable_state_update, + ) + + +class TestSelectiveStateUpdateMTPVariousNgroups(TestSelectiveStateUpdateMTP): + """Test multi-token selective_state_update with various ngroups values.""" + + NGROUPS = None # Will be set by fixture + + @pytest.fixture(params=[1, 2, 4, 8]) + def ngroups(self, request): + return request.param + + @pytest.fixture(params=[4]) + def batch(self, request): + return request.param + + @pytest.fixture(params=[32]) + def nheads(self, request): + return request.param + + @pytest.fixture(params=[64]) + def dim(self, request): + return request.param + + @pytest.fixture(params=[64]) + def dstate(self, request): + return request.param + + @pytest.fixture(params=[4]) + def cache_steps(self, request): + return request.param + + @pytest.fixture(params=[torch.bfloat16]) + def state_dtype(self, request): + return request.param + + @pytest.fixture(params=[torch.float32]) + def weight_dtype(self, request): + return request.param + + @pytest.fixture + def inputs( + self, + batch, + nheads, + dim, + dstate, + cache_steps, + state_dtype, + weight_dtype, 
+ ngroups, + ): + """Create test inputs with specified ngroups.""" + return create_test_inputs( + batch, + nheads, + dim, + dstate, + ngroups, + self.INPUT_DTYPE, + weight_dtype=weight_dtype, + matrixA_dtype=self.MATRIX_A_DTYPE, + state_dtype=state_dtype, + generate_z=False, + generate_intermediate_states_buffer=False, + cache_steps=cache_steps, + seed=0, + ) + + +class TestSelectiveStateUpdateMTPLargeBatch(TestSelectiveStateUpdateMTP): + """Test multi-token selective_state_update with larger batch sizes.""" + + @pytest.fixture(params=[16, 64]) + def batch(self, request): + return request.param + + @pytest.fixture(params=[32]) + def nheads(self, request): + return request.param + + @pytest.fixture(params=[64]) + def dim(self, request): + return request.param + + @pytest.fixture(params=[64]) + def dstate(self, request): + return request.param + + @pytest.fixture(params=[4, 8]) + def cache_steps(self, request): + return request.param + + @pytest.fixture(params=[torch.bfloat16]) + def state_dtype(self, request): + return request.param + + @pytest.fixture(params=[torch.float32]) + def weight_dtype(self, request): + return request.param diff --git a/tests/mamba/test_selective_state_update_stp.py b/tests/mamba/test_selective_state_update_stp.py index 9e87cbe0c8..4aba115a20 100644 --- a/tests/mamba/test_selective_state_update_stp.py +++ b/tests/mamba/test_selective_state_update_stp.py @@ -5,16 +5,7 @@ import flashinfer from .selective_state_update_triton import selective_state_update_triton -from .test_utils import create_test_inputs - - -def clone_preserving_strides(tensor): - """Clone a tensor while preserving its strides (non-contiguous layout).""" - result = torch.empty_strided( - tensor.size(), tensor.stride(), dtype=tensor.dtype, device=tensor.device - ) - result.copy_(tensor) - return result +from .test_utils import create_test_inputs, clone_preserving_strides class TestSelectiveStateUpdate: @@ -399,3 +390,52 @@ def reference_output(self, inputs): pad_slot_id=-1, ) 
return y_ref, state_ref + + +class TestSelectiveStateUpdateInt32Indices(TestSelectiveStateUpdate): + """Test selective_state_update with int32 state_batch_indices.""" + + @pytest.fixture(params=[1]) + def batch(self, request): + return request.param + + @pytest.fixture(params=[8]) + def nheads(self, request): + return request.param + + @pytest.fixture(params=[64]) + def dim(self, request): + return request.param + + @pytest.fixture(params=[128]) + def dstate(self, request): + return request.param + + @pytest.fixture(params=[torch.bfloat16]) + def state_dtype(self, request): + return request.param + + @pytest.fixture(params=[torch.bfloat16]) + def weight_dtype(self, request): + return request.param + + def run_kernel(self, inputs, out=None): + """Run the flashinfer kernel with int32 state_batch_indices.""" + # Cast slot_idx to int32 + slot_idx_int32 = inputs["slot_idx"].to(torch.int32) + + return flashinfer.mamba.selective_state_update( + inputs["state_cache"], + inputs["x"], + inputs["dt"], + inputs["A"], + inputs["B"], + inputs["C"], + D=inputs["D"], + z=inputs.get("z"), + dt_bias=inputs["dt_bias"], + dt_softplus=True, + state_batch_indices=slot_idx_int32, + pad_slot_id=-1, + out=out, + ) diff --git a/tests/mamba/test_utils.py b/tests/mamba/test_utils.py index 64995b823b..240d501d8c 100644 --- a/tests/mamba/test_utils.py +++ b/tests/mamba/test_utils.py @@ -4,6 +4,15 @@ import torch +def clone_preserving_strides(tensor): + """Clone a tensor while preserving its strides (non-contiguous layout).""" + result = torch.empty_strided( + tensor.size(), tensor.stride(), dtype=tensor.dtype, device=tensor.device + ) + result.copy_(tensor) + return result + + def create_test_inputs( batch_size: int, nheads: int, @@ -155,7 +164,7 @@ def create_test_inputs( dt_bias = dt_bias_base.as_strided((nheads, dim), (1, 0)) # Slot indices for state batching - (batch_size,) - slot_idx = torch.randperm(ssm_state_cache_size, dtype=torch.int32, device=device)[ + slot_idx = 
torch.randperm(ssm_state_cache_size, dtype=torch.int64, device=device)[ :batch_size ] From 1cb4ac734d77129c4893b37a03ce346f642f3b0a Mon Sep 17 00:00:00 2001 From: Igor Shovkun Date: Fri, 30 Jan 2026 08:36:47 -0800 Subject: [PATCH 11/33] Refactor Mamba selective state update kernel dispatch and add dtype checks - Add common.cuh with kernel dispatch helpers and alignment checks - Split and rename kernel_selective_state_update_stp.cuh, add kernel_selective_state_update_mtp.cuh - Refactor Python selective_state_update to clarify dimension handling - Add test for dtype mismatch between state_batch_indices and intermediate_state_indices - Update test_utils to generate int64 intermediate_slot_idx by default - Remove redundant input type check in validate_intermediate_state_indices --- csrc/selective_state_update.cu | 42 ++- flashinfer/mamba/selective_state_update.py | 39 +- include/flashinfer/mamba/common.cuh | 196 ++++++++++ .../kernel_selective_state_update_mtp.cuh | 335 ++++++++++++++++++ ... 
=> kernel_selective_state_update_stp.cuh} | 107 +----- .../mamba/selective_state_update.cuh | 87 ++--- .../mamba/test_selective_state_update_mtp.py | 59 +++ .../mamba/test_selective_state_update_stp.py | 47 +++ tests/mamba/test_utils.py | 4 +- 9 files changed, 728 insertions(+), 188 deletions(-) create mode 100644 include/flashinfer/mamba/common.cuh create mode 100644 include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh rename include/flashinfer/mamba/{selective_state_update_stp.cuh => kernel_selective_state_update_stp.cuh} (89%) diff --git a/csrc/selective_state_update.cu b/csrc/selective_state_update.cu index d925f1ccdb..d8ada31ed8 100644 --- a/csrc/selective_state_update.cu +++ b/csrc/selective_state_update.cu @@ -86,7 +86,6 @@ inline void validate_intermediate_state_indices( CHECK_CONTIGUOUS(intermediate_state_indices.value()); FLASHINFER_CHECK(intermediate_state_indices.value().size(0) == batch, "intermediate_state_indices.shape must be (", batch, ")"); - CHECK_INPUT_TYPE(intermediate_state_indices.value(), dl_int32); } inline void validate_intermediate_states_buffer( @@ -174,7 +173,7 @@ constexpr std::tuple allowed_dtype_ {float32_code, bfloat16_code, float32_code, float32_code, int64_code}, }; -// Helper to dispatch to the right template instantiation +// Helper to dispatch to the right template instantiation for STP template void dispatchCombo(SelectiveStateUpdateParams& p, cudaStream_t stream) { @@ -186,6 +185,19 @@ void dispatchCombo(SelectiveStateUpdateParams& p, cudaStream_t stream) { invokeSelectiveStateUpdate(p, stream); } +// Helper to dispatch to the right template instantiation for MTP +template +void dispatchComboMTP(mtp::SelectiveStateMTPParams& p, cudaStream_t stream) { + using state_t = typename DTypeToType::type; + using input_t = typename DTypeToType::type; + using weight_t = typename DTypeToType::type; + using matrixA_t = typename DTypeToType::type; + using stateIndex_t = typename DTypeToType::type; + 
mtp::invokeSelectiveStateUpdateMTP(p, + stream); +} + void run_selective_state_update_stp(TensorView const& state, TensorView const& x, TensorView const& dt, TensorView const& A, TensorView const& B, TensorView const& C, TensorView const& D, @@ -586,16 +598,27 @@ void run_selective_state_update_mtp( int64_t weight_dtype_code = encode_dlpack_dtype(weight_dtype); int64_t matrixA_dtype_code = encode_dlpack_dtype(matrixA_dtype); - // Get intermediate_state_indices dtype, default to int32 if not provided - int64_t intermediateStateIndex_dtype_code = int32_code; - if (intermediate_state_indices.has_value()) { - DLDataType intermediateStateIndex_dtype = intermediate_state_indices.value().dtype(); - intermediateStateIndex_dtype_code = encode_dlpack_dtype(intermediateStateIndex_dtype); + // Get stateIndex dtype from whichever index tensor is available + // If both are provided, they must have the same dtype + int64_t stateIndex_dtype_code = int32_code; // default + if (state_batch_indices.has_value() && intermediate_state_indices.has_value()) { + DLDataType state_batch_idx_dtype = state_batch_indices.value().dtype(); + DLDataType intermediate_idx_dtype = intermediate_state_indices.value().dtype(); + FLASHINFER_CHECK(state_batch_idx_dtype.code == intermediate_idx_dtype.code && + state_batch_idx_dtype.bits == intermediate_idx_dtype.bits, + "state_batch_indices and intermediate_state_indices must have the same dtype"); + stateIndex_dtype_code = encode_dlpack_dtype(state_batch_idx_dtype); + } else if (state_batch_indices.has_value()) { + DLDataType state_batch_idx_dtype = state_batch_indices.value().dtype(); + stateIndex_dtype_code = encode_dlpack_dtype(state_batch_idx_dtype); + } else if (intermediate_state_indices.has_value()) { + DLDataType intermediate_idx_dtype = intermediate_state_indices.value().dtype(); + stateIndex_dtype_code = encode_dlpack_dtype(intermediate_idx_dtype); } // Dispatch kernel based on dtype combination auto dtype_key = std::make_tuple(state_dtype_code, 
input_dtype_code, weight_dtype_code, - matrixA_dtype_code, intermediateStateIndex_dtype_code); + matrixA_dtype_code, stateIndex_dtype_code); // Compile-time recursive dispatcher using Y-combinator pattern for lambda self-recursion auto tryDispatch = [&](const auto& key, auto idx, auto& self) -> bool { @@ -608,7 +631,7 @@ void run_selective_state_update_mtp( constexpr auto w = std::get<2>(combo); constexpr auto m = std::get<3>(combo); constexpr auto si = std::get<4>(combo); - dispatchCombo(p, stream); + dispatchComboMTP(p, stream); return true; } return self(key, std::integral_constant{}, self); @@ -650,7 +673,6 @@ void selective_state_update(TensorView state, TensorView x, TensorView dt, Tenso run_selective_state_update_stp(state, x, dt, A, B, C, D, z, dt_bias, dt_softplus, state_batch_indices, pad_slot_id, output, disable_state_update); } else if (x.dim() == 4) { - FLASHINFER_CHECK(false, "x must have 3 dimensions (single-token), got ", x.dim()); run_selective_state_update_mtp( state, x, dt, A, B, C, D, z, dt_bias, dt_softplus, state_batch_indices, pad_slot_id, output, disable_state_update, intermediate_states_buffer, intermediate_state_indices, cache_steps); diff --git a/flashinfer/mamba/selective_state_update.py b/flashinfer/mamba/selective_state_update.py index 07ef073b33..734b0f5c10 100644 --- a/flashinfer/mamba/selective_state_update.py +++ b/flashinfer/mamba/selective_state_update.py @@ -133,24 +133,47 @@ def selective_state_update( Output tensor with shape (batch, dim) or (batch, nheads, dim) for single-token or (batch, T, nheads, dim) for multi-token """ + # Determine if we're in multi-token mode (more than 1 token) + is_mtp = cache_steps >= 1 + if state.dim() == 3: state = state.unsqueeze(1) + if A.dim() == 2: + A = A.unsqueeze(0) + if D.dim() == 1: + D = D.unsqueeze(0) + if dt_bias is not None and dt_bias.dim() == 1: + dt_bias = dt_bias.unsqueeze(0) + + # Handle x, dt, B, C, z dimensions based on mode + # For single-token: 2D -> 3D (batch, nheads, dim) + 
# For multi-token: 3D -> 4D (batch, T, nheads, dim) if x.dim() == 2: x = x.unsqueeze(1) + if is_mtp and x.dim() == 3: + # Add T dimension for MTP mode: (batch, nheads, dim) -> (batch, T, nheads, dim) + x = x.unsqueeze(1) + if dt.dim() == 2: dt = dt.unsqueeze(1) - if A.dim() == 2: - A = A.unsqueeze(0) + if is_mtp and dt.dim() == 3: + dt = dt.unsqueeze(1) + if B.dim() == 2: B = B.unsqueeze(1) + if is_mtp and B.dim() == 3: + B = B.unsqueeze(1) + if C.dim() == 2: C = C.unsqueeze(1) - if D.dim() == 1: - D = D.unsqueeze(0) - if z is not None and z.dim() == 2: - z = z.unsqueeze(1) - if dt_bias is not None and dt_bias.dim() == 1: - dt_bias = dt_bias.unsqueeze(0) + if is_mtp and C.dim() == 3: + C = C.unsqueeze(1) + + if z is not None: + if z.dim() == 2: + z = z.unsqueeze(1) + if is_mtp and z.dim() == 3: + z = z.unsqueeze(1) if out is None: output = torch.empty_like(x) else: diff --git a/include/flashinfer/mamba/common.cuh b/include/flashinfer/mamba/common.cuh new file mode 100644 index 0000000000..f7e89029de --- /dev/null +++ b/include/flashinfer/mamba/common.cuh @@ -0,0 +1,196 @@ +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef FLASHINFER_MAMBA_COMMON_CUH_ +#define FLASHINFER_MAMBA_COMMON_CUH_ + +#include + +#include +#include +#include + +#include "../utils.cuh" + +namespace flashinfer::mamba { + +constexpr unsigned warpSize = 32; + +// ============================================================================= +// Common types and utilities +// ============================================================================= + +// Simple packed vector type for loading N elements of type T +template +struct alignas(N * sizeof(T)) PackedAligned { + T val[N]; + static constexpr int count = N; + using dtype = T; +}; + +template +__device__ __forceinline__ auto make_zeros() -> load_t { + load_t ret{}; +#pragma unroll + for (int i = 0; i < ret.count; i++) + ret.val[i] = typename load_t::dtype{}; // default initialization + return ret; +}; + +// Computes the vector load size that ensures full warp utilization. +// Avoids cases like: dstate=64, load_t = sizeof(float4)/sizeof(f16), warpsize=32 (32 * 8 > 64) +// in which case a part of the warp would be idle. +template +inline constexpr auto getVectorLoadSizeForFullUtilization() -> unsigned { + static_assert(sizeof(float4) >= sizeof(T)); + constexpr unsigned maxHardwareLoadSize = sizeof(float4) / sizeof(T); + constexpr unsigned maxLogicalLoadSize = (unsigned)DSTATE / warpSize; + return maxHardwareLoadSize < maxLogicalLoadSize ? maxHardwareLoadSize : maxLogicalLoadSize; +} + +__device__ __forceinline__ float warpReduceSum(float val) { + for (int s = warpSize / 2; s > 0; s /= 2) { + val += __shfl_down_sync(UINT32_MAX, val, s); + } + return val; +} + +__forceinline__ __device__ float softplus(float x) { return __logf(1.f + __expf(x)); } + +__device__ __forceinline__ float thresholded_softplus(float dt_value) { + constexpr float threshold = 20.f; + return (dt_value <= threshold) ? 
softplus(dt_value) : dt_value; +} + +// ============================================================================= +// Dispatch helpers +// ============================================================================= + +// Format an integer_sequence as a comma-separated string for error messages +template +std::string format_sequence(std::integer_sequence) { + std::ostringstream oss; + bool first = true; + ((oss << (first ? (first = false, "") : ", ") << Values), ...); + return oss.str(); +} + +// Helper function to dispatch dim and dstate with a kernel launcher +template +void dispatchDimDstate(ParamsType& params, std::integer_sequence dims_seq, + std::integer_sequence dstates_seq, + KernelLauncher&& launcher) { + auto dispatch_dstate = [&]() { + auto try_dstate = [&]() { + if (params.dstate == DSTATE) { + launcher.template operator()(); + return true; + } + return false; + }; + bool dispatched = (try_dstate.template operator()() || ...); + FLASHINFER_CHECK(dispatched, "Unsupported dstate value: ", params.dstate, + ".\nSupported values: ", format_sequence(dstates_seq)); + }; + + auto try_dim = [&]() { + if (params.dim == DIM) { + dispatch_dstate.template operator()(); + return true; + } + return false; + }; + + bool dim_dispatched = (try_dim.template operator()() || ...); + FLASHINFER_CHECK(dim_dispatched, "Unsupported dim value: ", params.dim, + ".\nSupported values: ", format_sequence(dims_seq)); +} + +// Helper function to dispatch ratio with a kernel launcher +template +void dispatchRatio(ParamsType& params, std::integer_sequence ratios_seq, + KernelLauncher&& launcher) { + auto try_ratio = [&]() { + if (params.nheads / params.ngroups == RATIO) { + launcher.template operator()(); + return true; + } + return false; + }; + + bool ratio_dispatched = (try_ratio.template operator()() || ...); + FLASHINFER_CHECK(ratio_dispatched, + "Unsupported nheads/ngroups ratio: ", params.nheads / params.ngroups, + ".\nSupported values: ", format_sequence(ratios_seq)); +} + 
+// Helper function to dispatch dim, dstate, and ntokens_mtp with a kernel launcher +// Reuses dispatchDimDstate by wrapping the launcher to add token dispatch +template +void dispatchDimDstateTokens(ParamsType& params, + std::integer_sequence dims_seq, + std::integer_sequence dstates_seq, + std::integer_sequence tokens_seq, + KernelLauncher&& launcher) { + // Wrap the launcher to add token dispatch as the innermost level + auto dim_dstate_launcher = [&]() { + auto try_tokens = [&]() { + if (params.ntokens_mtp == TOKENS_MTP) { + launcher.template operator()(); + return true; + } + return false; + }; + bool dispatched = (try_tokens.template operator()() || ...); + FLASHINFER_CHECK(dispatched, "Unsupported ntokens_mtp value: ", params.ntokens_mtp, + ".\nSupported values: ", format_sequence(tokens_seq)); + }; + + dispatchDimDstate(params, dims_seq, dstates_seq, dim_dstate_launcher); +} + +// ============================================================================= +// Alignment checks +// ============================================================================= + +// Check alignment for common input variables (x, z, B, C) +// Works for both STP (SelectiveStateUpdateParams) and MTP (SelectiveStateMTPParams) +template +void check_ptr_alignment_input_vars(const ParamsType& params) { + using load_input_t = PackedAligned; + FLASHINFER_CHECK(reinterpret_cast(params.x) % sizeof(load_input_t) == 0, + "x pointer must be aligned to ", sizeof(load_input_t), " bytes"); + FLASHINFER_CHECK((params.x_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0, + "x batch stride must be aligned to ", sizeof(load_input_t), " bytes"); + if (params.z) { + FLASHINFER_CHECK(reinterpret_cast(params.z) % sizeof(load_input_t) == 0, + "z pointer must be aligned to ", sizeof(load_input_t), " bytes"); + FLASHINFER_CHECK((params.z_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0, + "z batch stride must be aligned to ", sizeof(load_input_t), " bytes"); + } + 
FLASHINFER_CHECK(reinterpret_cast(params.B) % sizeof(load_input_t) == 0, + "B pointer must be aligned to ", sizeof(load_input_t), " bytes"); + FLASHINFER_CHECK(reinterpret_cast(params.C) % sizeof(load_input_t) == 0, + "C pointer must be aligned to ", sizeof(load_input_t), " bytes"); + FLASHINFER_CHECK((params.B_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0, + "B batch stride must be aligned to ", sizeof(load_input_t), " bytes"); + FLASHINFER_CHECK((params.C_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0, + "C batch stride must be aligned to ", sizeof(load_input_t), " bytes"); +} + +} // namespace flashinfer::mamba + +#endif // FLASHINFER_MAMBA_COMMON_CUH_ diff --git a/include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh b/include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh new file mode 100644 index 0000000000..09d927f6fb --- /dev/null +++ b/include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh @@ -0,0 +1,335 @@ +#include +#include +#include + +#include +#include +#include + +#include "../utils.cuh" +#include "../vec_dtypes.cuh" +#include "common.cuh" +#include "conversion.cuh" +#include "create_tensor_map.cuh" + +namespace flashinfer::mamba::mtp { + +using namespace conversion; + +template +struct SharedStorageSimple { + input_t x[TOKENS_MTP][DIM]; + float out[TOKENS_MTP][DIM]; + input_t z[TOKENS_MTP][DIM]; + input_t B[TOKENS_MTP][DSTATE]; + input_t C[TOKENS_MTP][DSTATE]; + state_t state[STATE_ROWS][DSTATE]; +}; + +template +__global__ void selective_state_update_kernel_simple_mtp(SelectiveStateMTPParams params) { + auto* __restrict__ output = reinterpret_cast(params.output); + auto* __restrict__ state = reinterpret_cast(params.state); + auto* __restrict__ intermediate_states = reinterpret_cast(params.intermediate_states); + + auto const* __restrict__ x = reinterpret_cast(params.x); + auto const* __restrict__ dt = reinterpret_cast(params.dt); + auto const* __restrict__ A = reinterpret_cast(params.A); + auto 
const* __restrict__ B = reinterpret_cast(params.B); + auto const* __restrict__ C = reinterpret_cast(params.C); + auto const* __restrict__ D = reinterpret_cast(params.D); + auto const* __restrict__ dt_bias = reinterpret_cast(params.dt_bias); + auto const* __restrict__ z = reinterpret_cast(params.z); + auto const* __restrict__ state_batch_indices = + reinterpret_cast(params.state_batch_indices); + auto const* __restrict__ intermediate_state_indices = + reinterpret_cast(params.intermediate_state_indices); + bool const dt_softplus = params.dt_softplus; + + int const nheads = params.nheads; + int const ngroups = params.ngroups; + + auto const batch = blockIdx.x; + auto const head = blockIdx.y; + auto const group = head / (nheads / ngroups); + auto lane = threadIdx.x % warpSize; + auto warp = threadIdx.y; + + auto const state_batch = (state_batch_indices) ? state_batch_indices[batch] : batch; + auto const intermediate_cache_idx = + intermediate_state_indices ? intermediate_state_indices[batch] : state_batch; + state += state_batch * params.state_stride_batch + head * DIM * DSTATE; + + constexpr auto stateRowsPerWarpPerStage = 4; + constexpr auto stageRows = stateRowsPerWarpPerStage * numWarps; + + extern __shared__ __align__(128) char smem[]; + auto& sram = + *reinterpret_cast*>( + smem); + + static constexpr auto stateLoadSize = getVectorLoadSizeForFullUtilization(); + using load_state_t = PackedAligned; + using load_input_t = PackedAligned; + using load_weight_t = PackedAligned; + + auto const A_value = toFloat(A[head]); + auto const d_value = D ? toFloat(D[head]) : 0.f; + auto const dt_bias_value = dt_bias ? 
toFloat(dt_bias[head]) : 0.f; + + // Loop over multiple tokens + if (warp == 0) { // Load x: gmem -> smem + for (int mtp_step = 0; mtp_step < TOKENS_MTP; mtp_step++) { + for (auto d = lane * load_input_t::count; d < DIM; d += warpSize * load_input_t::count) { + auto* dst = reinterpret_cast(&sram.x[mtp_step][d]); + *dst = *reinterpret_cast( + &x[batch * params.x_stride_batch + mtp_step * params.x_stride_mtp + head * DIM + d]); + } + } + } else if (warp == 1) { // Load B: gmem -> smem + for (int mtp_step = 0; mtp_step < TOKENS_MTP; mtp_step++) { + for (auto i = lane * load_input_t::count; i < DSTATE; i += warpSize * load_input_t::count) { + auto* dst = reinterpret_cast(&sram.B[mtp_step][i]); + *dst = *reinterpret_cast( + &B[batch * params.B_stride_batch + mtp_step * params.B_stride_mtp + group * DSTATE + + i]); + } + } + } else if (warp == 2) { // Load z: gmem -> smem + for (int mtp_step = 0; mtp_step < TOKENS_MTP; mtp_step++) { + for (auto d = lane * load_input_t::count; d < DIM; d += warpSize * load_input_t::count) { + auto* dst = reinterpret_cast(&sram.z[mtp_step][d]); + *dst = z ? 
*reinterpret_cast( + &z[batch * params.z_stride_batch + mtp_step * params.z_stride_mtp + + head * DIM + d]) + : make_zeros(); + } + } + } + // Load C: gmem -> smem + else if (warp == 3) { + for (int mtp_step = 0; mtp_step < TOKENS_MTP; mtp_step++) { + for (auto i = lane * load_input_t::count; i < DSTATE; i += warpSize * load_input_t::count) { + auto* dst = reinterpret_cast(&sram.C[mtp_step][i]); + *dst = *reinterpret_cast( + &C[batch * params.C_stride_batch + mtp_step * params.C_stride_mtp + group * DSTATE + + i]); + } + } + } + + float rdt[TOKENS_MTP]; + for (int step = 0; step < TOKENS_MTP; step++) { + auto dt_value = + dt_bias_value + + toFloat(dt[batch * params.dt_stride_batch + step * params.dt_stride_mtp + head]); + if (dt_softplus) { + dt_value = thresholded_softplus(dt_value); + } + rdt[step] = dt_value; + } + + __syncthreads(); + + for (auto dBegin = 0; dBegin < DIM; dBegin += stageRows) { + // Load state gmem -> smem + for (int warpRow = 0; warpRow < stateRowsPerWarpPerStage; warpRow++) { + auto dd = warp * stateRowsPerWarpPerStage + warpRow; + auto d = dBegin + dd; + if (d < DIM) { + if (state_batch != params.pad_slot_id) { + for (int i = lane * load_state_t::count; i < DSTATE; + i += warpSize * load_state_t::count) { + auto* dst = reinterpret_cast(&sram.state[dd][i]); + *dst = *reinterpret_cast(&state[d * DSTATE + i]); + } + } + } + } + + // Compute how many input_t elements to pack per SRAM load based on DSTATE/warpSize ratio + constexpr auto stateValuesPerThread = DSTATE / warpSize; + // We will be loading two-banks worth of input_t at a time instead of 1 in order to reduce the + // load on LSU. + constexpr auto maxPackedElements = sizeof(uint64_t) / sizeof(input_t); + constexpr auto packedSramLdInputElements = + (stateValuesPerThread >= maxPackedElements) ? 
maxPackedElements : stateValuesPerThread; + static_assert(stateValuesPerThread % packedSramLdInputElements == 0, + "stateValuesPerThread must be divisible by packedSramLdInputElements"); + using packed_input_t = PackedAligned; + float rState[stateValuesPerThread]; + packed_input_t rB; + packed_input_t rC; + + for (int warpRow = 0; warpRow < stateRowsPerWarpPerStage; warpRow++) { + auto dd = warp * stateRowsPerWarpPerStage + warpRow; + auto d = dBegin + dd; + + if (d >= DIM) break; + + // Load state smem -> rmem + // There is a bank conflict here, but we are not in a hot loop and we must align the state + // indices with the input indices + for (int ii = 0; ii < stateValuesPerThread; ii++) { + int i = lane * packed_input_t::count + + (ii / packed_input_t::count) * warpSize * packed_input_t::count + + (ii % packed_input_t::count); + rState[ii] = + (state_batch != params.pad_slot_id && i < DSTATE) ? toFloat(sram.state[dd][i]) : 0.f; + } + + for (int step = 0; step < TOKENS_MTP; step++) { + float x_value = toFloat(sram.x[step][d]); + float out_value = d_value * x_value * int(lane == 0); // first lane has the value + + // Compute dt value for this token + auto dt_value = rdt[step]; + auto const dA = __expf(A_value * dt_value); + + // Process state in groups of packed_input_t::count to match B/C bank-aligned loads + for (int ii = 0; ii < stateValuesPerThread; ii += packed_input_t::count) { + int base_i = lane * packed_input_t::count + + (ii / packed_input_t::count) * warpSize * packed_input_t::count; + + // Bank-aligned load for B and C + rB = *reinterpret_cast(&sram.B[step][base_i]); + rC = *reinterpret_cast(&sram.C[step][base_i]); + +#pragma unroll + for (int k = 0; k < packed_input_t::count; k++) { + auto& state_value = rState[ii + k]; + auto B_value = toFloat(rB.val[k]); + auto C_value = toFloat(rC.val[k]); + + auto const dB = B_value * dt_value; + auto const new_state = state_value * dA + dB * x_value; + state_value = new_state; + + out_value += new_state * C_value; 
+ } + + if constexpr (sizeof(state_t) == sizeof(input_t)) { + if (intermediate_states) { + using packed_state_t = PackedAligned; + packed_state_t rStateOut; +#pragma unroll + for (int k = 0; k < packed_input_t::count; k++) { + convertAndStore(&rStateOut.val[k], rState[ii + k]); + } + *reinterpret_cast(&sram.state[dd][base_i]) = rStateOut; + } + } else { + if (intermediate_states) { +#pragma unroll + for (int k = 0; k < packed_input_t::count; k++) { + convertAndStore(&sram.state[dd][base_i + k], rState[ii + k]); + } + } + } + } + + out_value = warpReduceSum(out_value); + if (lane == 0) { + sram.out[step][d] = out_value; + } + + if (intermediate_states && state_batch != params.pad_slot_id) { + for (int i = lane * load_state_t::count; i < DSTATE; + i += warpSize * load_state_t::count) { + auto* src = reinterpret_cast(&sram.state[dd][i]); + auto* dst = reinterpret_cast( + &intermediate_states[intermediate_cache_idx * + params.intermediate_state_stride_batch + + step * nheads * DIM * DSTATE + head * DIM * DSTATE + + d * DSTATE + i]); + *dst = *src; + } + } + } + + // Update state if enabled and not padded + if (params.update_state && state_batch != params.pad_slot_id) { + // Store to rmem -> smem + for (int ii = 0; ii < stateValuesPerThread; ii++) { + int i = lane * packed_input_t::count + + (ii / packed_input_t::count) * warpSize * packed_input_t::count + + (ii % packed_input_t::count); + if (i < DSTATE) { + convertAndStore(&sram.state[dd][i], rState[ii]); + } + } + // store smem -> gmem + for (int i = lane * load_state_t::count; i < DSTATE; i += warpSize * load_state_t::count) { + auto* src = reinterpret_cast(&sram.state[dd][i]); + *reinterpret_cast(&state[d * DSTATE + i]) = *src; + } + } + } + } + + __syncthreads(); + + for (auto step = warp; step < TOKENS_MTP; step += numWarps) { + for (auto d = lane; d < DIM; d += warpSize) { + auto out_value = sram.out[step][d]; + if (z) { + float z_value = toFloat(sram.z[step][d]); + float sig_z = __fdividef(1.f, (1.f + 
__expf(0.f - z_value))); + float silu_z = z_value * sig_z; + out_value *= silu_z; + } + auto* dst = reinterpret_cast( + &output[batch * params.out_stride_batch + step * params.out_stride_mtp + head * DIM + d]); + convertAndStore(dst, out_value); + } + } +} + +template +void invokeSelectiveStateUpdateMTP(SelectiveStateMTPParams& params, cudaStream_t stream) { + // Common alignment checks for all kernels + check_ptr_alignment_input_vars(params); + + auto kernel_launcher = [&]() { + // Additional alignment checks specific to simple kernel + constexpr auto stateLoadSize = getVectorLoadSizeForFullUtilization(); + using load_state_t = PackedAligned; + + FLASHINFER_CHECK(reinterpret_cast(params.state) % sizeof(load_state_t) == 0, + "state pointer must be aligned to ", sizeof(load_state_t), " bytes"); + FLASHINFER_CHECK((params.dim * params.dstate * sizeof(state_t)) % sizeof(load_state_t) == 0, + "state head stride must be aligned to ", sizeof(load_state_t), " bytes"); + + constexpr int numWarps = 4; + constexpr int stateRowsPerWarpPerStage = 4; + constexpr int stageRows = stateRowsPerWarpPerStage * numWarps; + + dim3 block(warpSize, numWarps); + dim3 grid(params.batch, params.nheads); + + auto func = + selective_state_update_kernel_simple_mtp; + using sram_t = SharedStorageSimple; + constexpr size_t smem_size = sizeof(sram_t); + + // Use FLASHINFER_CHECK instead of FLASHINFER_CUDA_CALL since we're in a void lambda + // (FLASHINFER_CUDA_CALL uses "return e;" which is invalid in void context) + // { + // cudaError_t e = cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, + // smem_size); FLASHINFER_CHECK(e == cudaSuccess, "CUDA Error in cudaFuncSetAttribute: ", + // cudaGetErrorString(e), " (", int(e), ")"); + // } + FLASHINFER_CUDA_CHECK( + cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + func<<>>(params); + }; + + dispatchDimDstateTokens(params, AllowedDims{}, AllowedDstates{}, AllowedNtokens{}, + 
kernel_launcher); +} + +} // namespace flashinfer::mamba::mtp diff --git a/include/flashinfer/mamba/selective_state_update_stp.cuh b/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh similarity index 89% rename from include/flashinfer/mamba/selective_state_update_stp.cuh rename to include/flashinfer/mamba/kernel_selective_state_update_stp.cuh index 5c54d25af1..88be0f2a4a 100644 --- a/include/flashinfer/mamba/selective_state_update_stp.cuh +++ b/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh @@ -21,11 +21,11 @@ #include #include -#include #include #include "../utils.cuh" #include "../vec_dtypes.cuh" +#include "common.cuh" #include "conversion.cuh" #include "create_tensor_map.cuh" @@ -60,10 +60,10 @@ __device__ __forceinline__ int conflict_free_column(int group, int baseCol) { template struct SharedStorageSimple { alignas(alignof(PackedAligned)) input_t x[dim]; - float out[dim]; alignas(alignof(PackedAligned)) input_t z[dim]; alignas(alignof(PackedAligned)) input_t B[dstate]; alignas(alignof(PackedAligned)) input_t C[dstate]; + float out[dim]; }; template )) input_t x[dim]; - float out[dim]; // dt is special cause we're gonna store input in there as well alignas(alignof(PackedAligned)) input_t z[dim]; alignas(alignof(PackedAligned)) input_t B[dstate]; alignas(alignof(PackedAligned)) input_t C[dstate]; + float out[dim]; // dt is special cause we're gonna store input in there as well using barrier_t = cuda::barrier; barrier_t bar_empty[numStages]; @@ -883,103 +883,11 @@ __global__ void selective_state_update_kernel_producer_consumer_horizontal( #endif // FLASHINFER_MAMBA_ENABLE_SM100 -template -std::string format_array(const T (&arr)[N]) { - std::ostringstream oss; - for (size_t i = 0; i < N; ++i) { - if (i > 0) oss << ", "; - oss << arr[i]; - } - return oss.str(); -} - -// Helper function to dispatch dim and dstate with a kernel launcher -template -void dispatchDimDstate(SelectiveStateUpdateParams& params, - std::integer_sequence, - 
std::integer_sequence, KernelLauncher&& launcher) { - constexpr int allowed_dims[] = {AllowedDims...}; - constexpr int allowed_dstates[] = {AllowedDstates...}; - - auto dispatch_dim_dstate = [&]() { - launcher.template operator()(); - }; - - auto dispatch_dstate = [&]() { - if (params.dstate == DSTATE) { - dispatch_dim_dstate.template operator()(); - return true; - } - return false; - }; - - auto dispatch_dim = [&]() { - if (params.dim == DIM) { - bool dispatched = (dispatch_dstate.template operator()() || ...); - FLASHINFER_CHECK(dispatched, "Unsupported dstate value: ", params.dstate, - ".\nSupported values: ", format_array(allowed_dstates)); - return true; - } - return false; - }; - - bool dim_dispatched = (dispatch_dim.template operator()() || ...); - FLASHINFER_CHECK(dim_dispatched, "Unsupported dim value: ", params.dim, - ".\nSupported values: ", format_array(allowed_dims)); -} - -// Helper function to dispatch ratio with a kernel launcher -template -void dispatchRatio(SelectiveStateUpdateParams& params, std::integer_sequence, - KernelLauncher&& launcher) { - constexpr int allowed_ratios[] = {AllowedRatios...}; - - auto dispatch_single_ratio = [&]() { - if (params.nheads / params.ngroups == RATIO) { - launcher.template operator()(); - return true; - } - return false; - }; - - bool ratio_dispatched = (dispatch_single_ratio.template operator()() || ...); - FLASHINFER_CHECK(ratio_dispatched, - "Unsupported nheads/ngroups ratio: ", params.nheads / params.ngroups, - ".\nSupported values: ", format_array(allowed_ratios)); -} - -// Check alignment for common input variables (x, z, B, C) -template -void check_ptr_alignment_input_vars(const SelectiveStateUpdateParams& params) { - using load_input_t = PackedAligned; - FLASHINFER_CHECK(reinterpret_cast(params.x) % sizeof(load_input_t) == 0, - "x pointer must be aligned to ", sizeof(load_input_t), " bytes"); - FLASHINFER_CHECK((params.x_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0, - "x batch stride must 
be aligned to ", sizeof(load_input_t), " bytes"); - if (params.z) { - FLASHINFER_CHECK(reinterpret_cast(params.z) % sizeof(load_input_t) == 0, - "z pointer must be aligned to ", sizeof(load_input_t), " bytes"); - FLASHINFER_CHECK((params.z_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0, - "z batch stride must be aligned to ", sizeof(load_input_t), " bytes"); - } - FLASHINFER_CHECK(reinterpret_cast(params.B) % sizeof(load_input_t) == 0, - "B pointer must be aligned to ", sizeof(load_input_t), " bytes"); - FLASHINFER_CHECK(reinterpret_cast(params.C) % sizeof(load_input_t) == 0, - "C pointer must be aligned to ", sizeof(load_input_t), " bytes"); - FLASHINFER_CHECK((params.B_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0, - "B batch stride must be aligned to ", sizeof(load_input_t), " bytes"); - FLASHINFER_CHECK((params.C_stride_batch * sizeof(input_t)) % sizeof(load_input_t) == 0, - "C batch stride must be aligned to ", sizeof(load_input_t), " bytes"); -} - template void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, cudaStream_t stream) { auto [sm_major, sm_minor] = GetCudaComputeCapability(); - constexpr int allowed_dstates[] = {64, 128, 256}; - constexpr int allowed_dims[] = {64, 128, 256}; - // Common alignment checks for all kernels check_ptr_alignment_input_vars(params); @@ -1006,8 +914,7 @@ void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, cudaStream_t DSTATE, numWarps><<>>(params); }; - dispatchDimDstate(params, std::integer_sequence{}, - std::integer_sequence{}, kernel_launcher); + dispatchDimDstate(params, AllowedDims{}, AllowedDstates{}, kernel_launcher); } #ifdef FLASHINFER_MAMBA_ENABLE_SM90 else { @@ -1046,8 +953,7 @@ void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, cudaStream_t scan_func<<>>(params, state_tensor); }; - dispatchDimDstate(params, std::integer_sequence{}, - std::integer_sequence{}, kernel_launcher); + dispatchDimDstate(params, AllowedDims{}, AllowedDstates{}, 
kernel_launcher); } #endif @@ -1096,8 +1002,7 @@ void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, cudaStream_t dispatchRatio(params, std::integer_sequence{}, ratio_launcher); }; - dispatchDimDstate(params, std::integer_sequence{}, - std::integer_sequence{}, kernel_launcher); + dispatchDimDstate(params, AllowedDims{}, AllowedDstates{}, kernel_launcher); } #endif } diff --git a/include/flashinfer/mamba/selective_state_update.cuh b/include/flashinfer/mamba/selective_state_update.cuh index 64e0eb2b0c..f8b44c3779 100644 --- a/include/flashinfer/mamba/selective_state_update.cuh +++ b/include/flashinfer/mamba/selective_state_update.cuh @@ -16,22 +16,17 @@ #ifndef FLASHINFER_MAMBA_SELECTIVE_STATE_UPDATE_CUH_ #define FLASHINFER_MAMBA_SELECTIVE_STATE_UPDATE_CUH_ -// #include -// #include -// #include - -// #include #include -// #include - -// #include "../utils.cuh" -// #include "../vec_dtypes.cuh" -// #include "conversion.cuh" -// #include "create_tensor_map.cuh" +#include namespace flashinfer::mamba { -constexpr unsigned warpSize = 32; +// ============================================================================= +// Allowed dispatch values for kernel instantiation +// ============================================================================= +using AllowedDims = std::integer_sequence; +using AllowedDstates = std::integer_sequence; +using AllowedNtokens = std::integer_sequence; struct SelectiveStateUpdateParams { uint32_t batch{}, nheads{}, dim{}, dstate{}, ngroups{}, state_cache_size{}; @@ -40,65 +35,22 @@ struct SelectiveStateUpdateParams { int64_t x_stride_batch{}, dt_stride_batch{}, B_stride_batch{}, C_stride_batch{}, out_stride_batch{}, z_stride_batch{}, state_stride_batch{}; - void* __restrict__ state{nullptr}; // state_t: (state_cache_size, nheads, dim, dstate) - void* __restrict__ x{nullptr}; // input_t: (batch, nheads, dim) - void* __restrict__ dt{ - nullptr}; // weight_t: (batch, nheads) but pretends to be (batch, nheads, dim) - void* 
__restrict__ dt_bias{nullptr}; // weight_t (nheads) but pretends to be (nheads, dim) - void* __restrict__ A{nullptr}; // matrixA_t: (nheads) but pretends to be (nheads, dim, dstate) - void* __restrict__ B{nullptr}; // input_t: (batch, ngroups, dstate) - void* __restrict__ C{nullptr}; // input_t: (batch, ngroups, dstate) - void* __restrict__ D{nullptr}; // weight_t: (nheads) but pretends to be (nheads, dim) - void* __restrict__ z{nullptr}; // input_t: (batch, nheads, dim) - void* __restrict__ output{nullptr}; // input_t: (batch, nheads, dim) - void* __restrict__ state_batch_indices{nullptr}; // state_batch_indices: (batch,) + void* __restrict__ state{nullptr}; + void* __restrict__ x{nullptr}; + void* __restrict__ dt{nullptr}; + void* __restrict__ dt_bias{nullptr}; + void* __restrict__ A{nullptr}; + void* __restrict__ B{nullptr}; + void* __restrict__ C{nullptr}; + void* __restrict__ D{nullptr}; + void* __restrict__ z{nullptr}; + void* __restrict__ output{nullptr}; + void* __restrict__ state_batch_indices{nullptr}; bool dt_softplus{false}; bool update_state{true}; }; -__forceinline__ __device__ float softplus(float x) { return __logf(1.f + __expf(x)); } - -__device__ __forceinline__ float thresholded_softplus(float dt_value) { - constexpr float threshold = 20.f; - return (dt_value <= threshold) ? softplus(dt_value) : dt_value; -} - -// Simple packed vector type for loading N elements of type T -template -struct alignas(N * sizeof(T)) PackedAligned { - T val[N]; - static constexpr int count = N; - using dtype = T; -}; - -template -__device__ __forceinline__ auto make_zeros() -> load_t { - load_t ret{}; -#pragma unroll - for (int i = 0; i < ret.count; i++) - ret.val[i] = typename load_t::dtype{}; // default initialization - return ret; -}; - -// Computes the vector load size that ensures full warp utilization. -// Avoids cases like: dstate=64, load_t = sizeof(float4)/sizeof(f16), warpsize=32 (32 * 8 > 64) -// in which case a part of the warp would be idle. 
-template -inline constexpr auto getVectorLoadSizeForFullUtilization() -> unsigned { - static_assert(sizeof(float4) >= sizeof(T)); - constexpr unsigned maxHardwareLoadSize = sizeof(float4) / sizeof(T); - constexpr unsigned maxLogicalLoadSize = (unsigned)DSTATE / warpSize; - return maxHardwareLoadSize < maxLogicalLoadSize ? maxHardwareLoadSize : maxLogicalLoadSize; -} - -__device__ __forceinline__ float warpReduceSum(float val) { - for (int s = warpSize / 2; s > 0; s /= 2) { - val += __shfl_down_sync(UINT32_MAX, val, s); - } - return val; -} - namespace mtp { // Extended params struct for multi-token prediction (MTP) struct SelectiveStateMTPParams : public SelectiveStateUpdateParams { @@ -117,6 +69,7 @@ struct SelectiveStateMTPParams : public SelectiveStateUpdateParams { } // namespace flashinfer::mamba -#include "selective_state_update_stp.cuh" +#include "kernel_selective_state_update_mtp.cuh" +#include "kernel_selective_state_update_stp.cuh" #endif // FLASHINFER_MAMBA_SELECTIVE_STATE_UPDATE_CUH_ diff --git a/tests/mamba/test_selective_state_update_mtp.py b/tests/mamba/test_selective_state_update_mtp.py index afa1b6d202..38e1592d22 100644 --- a/tests/mamba/test_selective_state_update_mtp.py +++ b/tests/mamba/test_selective_state_update_mtp.py @@ -712,3 +712,62 @@ def state_dtype(self, request): @pytest.fixture(params=[torch.float32]) def weight_dtype(self, request): return request.param + + +class TestSelectiveStateUpdateMTPIndicesDtypeMismatch: + """Test that selective_state_update fails with dtype mismatch between indices.""" + + def test_state_batch_idx_and_intermediate_idx_dtype_mismatch_should_fail(self): + """Test that state_batch_indices and intermediate_state_indices dtype mismatch raises an error.""" + batch = 4 + nheads = 32 + dim = 64 + dstate = 128 + ngroups = 8 + cache_steps = 4 + + # Create inputs with intermediate states buffer + inputs = create_test_inputs( + batch, + nheads, + dim, + dstate, + ngroups, + input_dtype=torch.bfloat16, + 
weight_dtype=torch.float32, + matrixA_dtype=torch.float32, + state_dtype=torch.bfloat16, + generate_z=False, + generate_intermediate_states_buffer=True, + cache_steps=cache_steps, + seed=0, + ) + + # Convert state_batch_indices to int64 (default is typically int64) + inputs["slot_idx"] = inputs["slot_idx"].to(torch.int64) + + # Convert intermediate_state_indices to int32 (different dtype) + inputs["intermediate_slot_idx"] = inputs["intermediate_slot_idx"].to( + torch.int32 + ) + + # This should fail due to dtype mismatch between indices + with pytest.raises((RuntimeError, ValueError)): + flashinfer.mamba.selective_state_update( + inputs["state_cache"], + inputs["x"], + inputs["dt"], + inputs["A"], + inputs["B"], + inputs["C"], + D=inputs["D"], + z=inputs.get("z"), + dt_bias=inputs["dt_bias"], + dt_softplus=True, + state_batch_indices=inputs["slot_idx"], + pad_slot_id=-1, + disable_state_update=True, + intermediate_states_buffer=inputs["intermediate_states_buffer"], + intermediate_state_indices=inputs["intermediate_slot_idx"], + cache_steps=inputs["cache_steps"], + ) diff --git a/tests/mamba/test_selective_state_update_stp.py b/tests/mamba/test_selective_state_update_stp.py index 4aba115a20..566a93df9c 100644 --- a/tests/mamba/test_selective_state_update_stp.py +++ b/tests/mamba/test_selective_state_update_stp.py @@ -439,3 +439,50 @@ def run_kernel(self, inputs, out=None): pad_slot_id=-1, out=out, ) + + +class TestSelectiveStateUpdateDtypeMismatch: + """Test that selective_state_update fails with dtype mismatch between D and dt.""" + + def test_d_f32_dt_bf16_should_fail(self): + """Test that D (f32) and dt (bf16) dtype mismatch raises an error.""" + batch = 1 + nheads = 8 + dim = 64 + dstate = 128 + ngroups = 8 + + # Create inputs with standard dtypes + inputs = create_test_inputs( + batch, + nheads, + dim, + dstate, + ngroups, + input_dtype=torch.bfloat16, + weight_dtype=torch.bfloat16, + matrixA_dtype=torch.float32, + state_dtype=torch.bfloat16, + 
generate_z=False, + seed=0, + ) + + # Override D to be float32 (while dt remains bfloat16) + inputs["D"] = inputs["D"].to(torch.float32) + + # This should fail due to dtype mismatch + with pytest.raises((RuntimeError, ValueError)): + flashinfer.mamba.selective_state_update( + inputs["state_cache"], + inputs["x"], + inputs["dt"], + inputs["A"], + inputs["B"], + inputs["C"], + D=inputs["D"], + z=inputs.get("z"), + dt_bias=inputs["dt_bias"], + dt_softplus=True, + state_batch_indices=inputs["slot_idx"], + pad_slot_id=-1, + ) diff --git a/tests/mamba/test_utils.py b/tests/mamba/test_utils.py index 240d501d8c..6f33e930e3 100644 --- a/tests/mamba/test_utils.py +++ b/tests/mamba/test_utils.py @@ -211,7 +211,7 @@ def create_test_inputs( result["cache_steps"] = cache_steps # Also generate indices mapping batch elements to intermediate state buffer positions intermediate_slot_idx = torch.arange( - batch_size, dtype=torch.int32, device=device + batch_size, dtype=torch.int64, device=device ) result["intermediate_slot_idx"] = intermediate_slot_idx @@ -225,7 +225,7 @@ def create_test_inputs( # Token 0: parent = -1 (initial state) # Token t: parent = t - 1 (previous token) retrieve_parent_token = torch.zeros( - batch_size, T, dtype=torch.int32, device=device + batch_size, T, dtype=torch.int64, device=device ) retrieve_parent_token[:, 0] = -1 # First token uses initial state for t in range(1, T): From 9d6d35ce953d084182aef51552cda9c5f0c76cbd Mon Sep 17 00:00:00 2001 From: Igor Shovkun Date: Fri, 30 Jan 2026 09:24:25 -0800 Subject: [PATCH 12/33] Fix simple stp kernel to only write state if a flag is provided --- include/flashinfer/mamba/kernel_selective_state_update_stp.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh b/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh index 88be0f2a4a..254f2f71ea 100644 --- a/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh +++ 
b/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh @@ -166,7 +166,7 @@ __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams out_value += new_state * C_value; } - if (state_batch != params.pad_slot_id) + if (params.update_state && state_batch != params.pad_slot_id) *reinterpret_cast(&state[d * DSTATE + i]) = rState; } From 5b5756d6f0d1ce4c00ab482900de55486e613ae4 Mon Sep 17 00:00:00 2001 From: Igor Shovkun Date: Fri, 30 Jan 2026 14:52:01 -0800 Subject: [PATCH 13/33] Fix Triton kernel intermediate state caching to match CUDA behavior Always define state_batch_idx (either from state_batch_indices or pid_b) to mirror the CUDA kernel's state_batch variable. This allows the intermediate state caching logic to use a simple check of `state_batch_idx != pad_slot_id` without requiring an extra HAS_STATE_BATCH_INDICES guard, matching the CUDA kernel behavior. addresses: https://github.com/flashinfer-ai/flashinfer/pull/2444#discussion_r2747238335 --- tests/mamba/selective_state_update_triton.py | 26 +++++++++----------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/tests/mamba/selective_state_update_triton.py b/tests/mamba/selective_state_update_triton.py index ef76073f8f..c40f90612a 100644 --- a/tests/mamba/selective_state_update_triton.py +++ b/tests/mamba/selective_state_update_triton.py @@ -154,6 +154,7 @@ def _selective_scan_update_kernel( state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64) state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head else: + state_batch_idx = pid_b state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head @@ -261,20 +262,17 @@ def _selective_scan_update_kernel( state = state * dA + dB * x[:, None] if CACHE_INTERMEDIATE_STATES: - if HAS_STATE_BATCH_INDICES: - if state_batch_idx != pad_slot_id: - cache_ptr_base = ( - intermediate_states_buffer - + cache_idx * cache_steps * nheads * dim * 
dstate - + current_step_idx * nheads * dim * dstate - + pid_h * dim * dstate - ) - cache_ptrs = cache_ptr_base + ( - offs_m[:, None] * dstate + offs_n[None, :] - ) - tl.store( - cache_ptrs, state.to(cache_ptrs.dtype.element_ty), mask=mask - ) + if state_batch_idx != pad_slot_id: + cache_ptr_base = ( + intermediate_states_buffer + + cache_idx * cache_steps * nheads * dim * dstate + + current_step_idx * nheads * dim * dstate + + pid_h * dim * dstate + ) + cache_ptrs = cache_ptr_base + ( + offs_m[:, None] * dstate + offs_n[None, :] + ) + tl.store(cache_ptrs, state.to(cache_ptrs.dtype.element_ty), mask=mask) out = tl.sum(state * C[None, :], axis=1) if HAS_D: From fb693d00c6931db4aa6e01debfda61b5414f0d4b Mon Sep 17 00:00:00 2001 From: Igor Shovkun Date: Mon, 2 Feb 2026 21:43:27 -0800 Subject: [PATCH 14/33] Add Mamba2 SSD chunk scan test and reorganize Triton refs - Add test_chunk_scan_combined.py comparing CUTLASS CuTe DSL Blackwell implementation against Triton reference - Move selective_state_update_triton.py into triton_reference/ package - Add Triton reference implementations for Mamba2 SSD kernels: - ssd_combined.py (main entry point) - ssd_chunk_scan.py, ssd_chunk_state.py, ssd_state_passing.py - ssd_bmm.py, softplus.py (utilities) --- tests/mamba/selective_state_update_triton.py | 501 ------------ tests/mamba/test_chunk_scan_combined.py | 703 ++++++++++++++++ .../mamba/test_selective_state_update_mtp.py | 2 +- .../mamba/test_selective_state_update_stp.py | 2 +- tests/mamba/triton_reference/__init__.py | 6 + tests/mamba/triton_reference/softplus.py | 26 + tests/mamba/triton_reference/ssd_bmm.py | 272 ++++++ .../mamba/triton_reference/ssd_chunk_scan.py | 628 ++++++++++++++ .../mamba/triton_reference/ssd_chunk_state.py | 771 ++++++++++++++++++ tests/mamba/triton_reference/ssd_combined.py | 265 ++++++ .../triton_reference/ssd_state_passing.py | 282 +++++++ 11 files changed, 2955 insertions(+), 503 deletions(-) delete mode 100644 
tests/mamba/selective_state_update_triton.py create mode 100644 tests/mamba/test_chunk_scan_combined.py create mode 100644 tests/mamba/triton_reference/__init__.py create mode 100644 tests/mamba/triton_reference/softplus.py create mode 100644 tests/mamba/triton_reference/ssd_bmm.py create mode 100644 tests/mamba/triton_reference/ssd_chunk_scan.py create mode 100644 tests/mamba/triton_reference/ssd_chunk_state.py create mode 100644 tests/mamba/triton_reference/ssd_combined.py create mode 100644 tests/mamba/triton_reference/ssd_state_passing.py diff --git a/tests/mamba/selective_state_update_triton.py b/tests/mamba/selective_state_update_triton.py deleted file mode 100644 index c40f90612a..0000000000 --- a/tests/mamba/selective_state_update_triton.py +++ /dev/null @@ -1,501 +0,0 @@ -# Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/mamba_ssm.py - -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# Copyright (c) 2024, Tri Dao, Albert Gu. 
-# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py - -import torch -import triton -import triton.language as tl -from packaging import version - -PAD_SLOT_ID = -1 - -TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0") - -if TRITON3: - - @triton.jit - def softplus(dt): - dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt) - return dt - -else: - - @triton.jit - def softplus(dt): - dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt) - return dt - - -@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) -@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) -@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) -@triton.heuristics( - { - "HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"] - is not None - } -) -@triton.heuristics( - {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])} -) -@triton.heuristics( - { - "CACHE_INTERMEDIATE_STATES": lambda args: args["intermediate_states_buffer"] - is not None - } -) -@triton.heuristics( - { - "HAS_EAGLE_TREE_CUSTOM_ATTN_MASK": lambda args: args[ - "retrieve_parent_token_ptr" - ] - is not None - } -) -@triton.heuristics( - { - "HAS_INTERMEDIATE_STATE_INDICES": lambda args: args[ - "intermediate_state_indices_ptr" - ] - is not None - } -) -@triton.jit(do_not_specialize=["T"]) -def _selective_scan_update_kernel( - # Pointers to matrices - state_ptr, - x_ptr, - dt_ptr, - dt_bias_ptr, - A_ptr, - B_ptr, - C_ptr, - D_ptr, - z_ptr, - out_ptr, - state_batch_indices_ptr, - pad_slot_id, - intermediate_states_buffer, - cache_steps, - retrieve_parent_token_ptr, - intermediate_state_indices_ptr, - # Matrix dimensions - batch, - T, - nheads, - dim, - dstate, - nheads_ngroups_ratio, - # Strides - stride_state_batch, - stride_state_head, - stride_state_dim, - stride_state_dstate, - stride_x_batch, - stride_x_T, - stride_x_head, - stride_x_dim, - 
stride_dt_batch, - stride_dt_T, - stride_dt_head, - stride_dt_dim, - stride_dt_bias_head, - stride_dt_bias_dim, - stride_A_head, - stride_A_dim, - stride_A_dstate, - stride_B_batch, - stride_B_T, - stride_B_group, - stride_B_dstate, - stride_C_batch, - stride_C_T, - stride_C_group, - stride_C_dstate, - stride_D_head, - stride_D_dim, - stride_z_batch, - stride_z_T, - stride_z_head, - stride_z_dim, - stride_out_batch, - stride_out_T, - stride_out_head, - stride_out_dim, - stride_retrieve_parent_token_batch, - stride_retrieve_parent_token_T, - # Meta-parameters - DT_SOFTPLUS: tl.constexpr, - TIE_HDIM: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - HAS_DT_BIAS: tl.constexpr, - HAS_D: tl.constexpr, - HAS_Z: tl.constexpr, - HAS_STATE_BATCH_INDICES: tl.constexpr, - DISABLE_STATE_UPDATE: tl.constexpr, - CACHE_INTERMEDIATE_STATES: tl.constexpr, - HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: tl.constexpr, - HAS_INTERMEDIATE_STATE_INDICES: tl.constexpr, - BLOCK_SIZE_DSTATE: tl.constexpr, -): - pid_m = tl.program_id(axis=0) - pid_b = tl.program_id(axis=1) - pid_h = tl.program_id(axis=2) - - # If HAS_STATE_BATCH_INDICES is true, then the ssm state's batch coordinate - # is taken from the state_batch_indices_ptr Otherwise, the state coordinate - # is the same as the batch id. 
- if HAS_STATE_BATCH_INDICES: - state_batch_indices_ptr += pid_b - state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64) - state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head - else: - state_batch_idx = pid_b - state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head - - x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head - dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head - if HAS_DT_BIAS: - dt_bias_ptr += pid_h * stride_dt_bias_head - A_ptr += pid_h * stride_A_head - B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group - C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group - if HAS_Z: - z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head - out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) - state_ptrs = state_ptr + ( - offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate - ) - - mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate) - if HAS_STATE_BATCH_INDICES: - mask &= state_batch_idx != pad_slot_id - state = tl.load(state_ptrs, mask=mask, other=0.0).to(tl.float32) - - if HAS_DT_BIAS: - dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim - if HAS_D: - D_ptr += pid_h * stride_D_head - D_ptrs = D_ptr + offs_m * stride_D_dim - A_ptrs = A_ptr + offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate - - cache_idx = -1 - if CACHE_INTERMEDIATE_STATES: - if HAS_INTERMEDIATE_STATE_INDICES: - intermediate_state_idx = tl.load(intermediate_state_indices_ptr + pid_b).to( - tl.int64 - ) - cache_idx = intermediate_state_idx - elif HAS_STATE_BATCH_INDICES: - cache_idx = state_batch_idx - else: - cache_idx = pid_b - - current_step_idx = 0 - for _ in range(T): - if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: - if current_step_idx != 0 and cache_idx >= 0: - parent_ptr = ( - retrieve_parent_token_ptr - + pid_b 
* stride_retrieve_parent_token_batch - + current_step_idx * stride_retrieve_parent_token_T - ) - parent_step_idx = tl.load(parent_ptr).to(tl.int32) - - if parent_step_idx >= 0 and parent_step_idx < T: - step_offset = parent_step_idx * nheads * dim * dstate - cache_ptr = ( - intermediate_states_buffer - + cache_idx * cache_steps * nheads * dim * dstate - + step_offset - + pid_h * dim * dstate - + offs_m[:, None] * dstate - + offs_n[None, :] - ) - state = tl.load(cache_ptr, mask=mask, other=0.0).to(tl.float32) - - x_ptrs = x_ptr + offs_m * stride_x_dim - dt_ptrs = dt_ptr + offs_m * stride_dt_dim - B_ptrs = B_ptr + offs_n * stride_B_dstate - C_ptrs = C_ptr + offs_n * stride_C_dstate - if HAS_Z: - z_ptrs = z_ptr + offs_m * stride_z_dim - out_ptrs = out_ptr + offs_m * stride_out_dim - - x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - if not TIE_HDIM: - dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - if HAS_DT_BIAS: - dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - if DT_SOFTPLUS: - dt = softplus(dt) - A = tl.load( - A_ptrs, - mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), - other=0.0, - ).to(tl.float32) - dA = tl.exp(A * dt[:, None]) - else: - dt = tl.load(dt_ptr).to(tl.float32) - if HAS_DT_BIAS: - dt += tl.load(dt_bias_ptr).to(tl.float32) - if DT_SOFTPLUS: - dt = softplus(dt) - A = tl.load(A_ptr).to(tl.float32) - dA = tl.exp(A * dt) # scalar, not a matrix - - B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) - C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) - if HAS_D: - D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - if HAS_Z: - z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - - dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt - state = state * dA + dB * x[:, None] - - if CACHE_INTERMEDIATE_STATES: - if state_batch_idx != pad_slot_id: - cache_ptr_base = ( - intermediate_states_buffer - + cache_idx * 
cache_steps * nheads * dim * dstate - + current_step_idx * nheads * dim * dstate - + pid_h * dim * dstate - ) - cache_ptrs = cache_ptr_base + ( - offs_m[:, None] * dstate + offs_n[None, :] - ) - tl.store(cache_ptrs, state.to(cache_ptrs.dtype.element_ty), mask=mask) - - out = tl.sum(state * C[None, :], axis=1) - if HAS_D: - out += x * D - if HAS_Z: - out *= z * tl.sigmoid(z) - tl.store(out_ptrs, out, mask=offs_m < dim) - - current_step_idx += 1 # noqa: SIM113 - - x_ptr += stride_x_T - dt_ptr += stride_dt_T - B_ptr += stride_B_T - C_ptr += stride_C_T - out_ptr += stride_out_T - if HAS_Z: - z_ptr += stride_z_T - - if not DISABLE_STATE_UPDATE: - tl.store(state_ptrs, state.to(state_ptrs.dtype.element_ty), mask=mask) - - -def selective_state_update_triton( - state, - x, - dt, - A, - B, - C, - D=None, - z=None, - dt_bias=None, - dt_softplus=False, - state_batch_indices=None, - pad_slot_id=PAD_SLOT_ID, - out=None, - disable_state_update=False, - intermediate_states_buffer=None, - cache_steps=None, - retrieve_parent_token=None, - intermediate_state_indices=None, -): - """ - Argument: - state: (batch, dim, dstate) or (batch, nheads, dim, dstate) - x: (batch, dim) or (batch, nheads, dim) for single-token or (batch, T, nheads, dim) for multi-token - dt: (batch, dim) or (batch, nheads, dim) for single-token or (batch, T, nheads, dim) for multi-token - A: (dim, dstate) or (nheads, dim, dstate) - B: (batch, dstate) or (batch, ngroups, dstate) for single-token or (batch, T, ngroups, dstate) for multi-token - C: (batch, dstate) or (batch, ngroups, dstate) for single-token or (batch, T, ngroups, dstate) for multi-token - D: (dim,) or (nheads, dim) - z: (batch, dim) or (batch, nheads, dim) for single-token or (batch, T, nheads, dim) for multi-token - dt_bias: (dim,) or (nheads, dim) - pad_slot_id: int - if cache_indices is passed, lets the kernel identify padded - entries that will not be processed, - for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] - in this case, the 
kernel will not process entries at - indices 0 and 3 - out: Preallocated ssm output tensor. Assume same shape as x. - In-place updated. - disable_state_update: If True, don't write back to state (for speculative verify) - intermediate_states_buffer: Buffer to cache intermediate states - cache_steps: Total number of steps in the buffer - retrieve_parent_token: (batch, T) tensor of parent token indices for EAGLE tree attention - intermediate_state_indices: (batch,) tensor of indices for intermediate_states_buffer operations. - If provided, uses these indices instead of state_batch_indices for the buffer. - """ - # Track original x dimensionality to squeeze output appropriately - x_orig_dim = x.dim() - - if state.dim() == 3: - state = state.unsqueeze(1) - if x.dim() == 2: - x = x.unsqueeze(1) - if x.dim() == 3: - x = x.unsqueeze(1) - if dt.dim() == 2: - dt = dt.unsqueeze(1) - if dt.dim() == 3: - dt = dt.unsqueeze(1) - if A.dim() == 2: - A = A.unsqueeze(0) - if B.dim() == 2: - B = B.unsqueeze(1) - if B.dim() == 3: - B = B.unsqueeze(1) - if C.dim() == 2: - C = C.unsqueeze(1) - if C.dim() == 3: - C = C.unsqueeze(1) - if D is not None and D.dim() == 1: - D = D.unsqueeze(0) - if z is not None: - if z.dim() == 2: - z = z.unsqueeze(1) - if z.dim() == 3: - z = z.unsqueeze(1) - if dt_bias is not None and dt_bias.dim() == 1: - dt_bias = dt_bias.unsqueeze(0) - if out is None: - out = torch.empty_like(x) - if out.dim() == 2: - out = out.unsqueeze(1) - if out.dim() == 3: - out = out.unsqueeze(1) - - _, nheads, dim, dstate = state.shape - batch, T, _, _ = x.shape - - assert x.shape == (batch, T, nheads, dim) - assert dt.shape == x.shape - assert A.shape == (nheads, dim, dstate) - ngroups = B.shape[2] - assert nheads % ngroups == 0, "nheads must be divisible by ngroups" - assert B.shape == (batch, T, ngroups, dstate) - assert C.shape == B.shape - if D is not None: - assert D.shape == (nheads, dim) - if z is not None: - assert z.shape == x.shape - if dt_bias is not None: - assert 
dt_bias.shape == (nheads, dim) - if state_batch_indices is not None: - assert state_batch_indices.shape == (batch,) - assert out.shape == x.shape - - grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads) - z_strides = ( - (z.stride(0), z.stride(1), z.stride(2), z.stride(3)) - if z is not None - else (0, 0, 0, 0) - ) - # We don't want autotune since it will overwrite the state - # We instead tune by hand. - BLOCK_SIZE_M, num_warps = ( - (32, 4) - if dstate <= 16 - else ( - (16, 4) - if dstate <= 32 - else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8)))) - ) - ) - tie_hdim = ( - A.stride(-1) == 0 - and A.stride(-2) == 0 - and dt.stride(-1) == 0 - and (dt_bias is None or dt_bias.stride(-1) == 0) - ) - - retrieve_parent_token_strides = ( - (retrieve_parent_token.stride(0), retrieve_parent_token.stride(1)) - if retrieve_parent_token is not None - else (0, 0) - ) - - with torch.cuda.device(x.device.index): - _selective_scan_update_kernel[grid]( - state, - x, - dt, - dt_bias, - A, - B, - C, - D, - z, - out, - state_batch_indices, - pad_slot_id, - intermediate_states_buffer, - cache_steps if cache_steps is not None else 0, - retrieve_parent_token, - intermediate_state_indices, - batch, - T, - nheads, - dim, - dstate, - nheads // ngroups, - state.stride(0), - state.stride(1), - state.stride(2), - state.stride(3), - x.stride(0), - x.stride(1), - x.stride(2), - x.stride(3), - dt.stride(0), - dt.stride(1), - dt.stride(2), - dt.stride(3), - *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else (0, 0), - A.stride(0), - A.stride(1), - A.stride(2), - B.stride(0), - B.stride(1), - B.stride(2), - B.stride(3), - C.stride(0), - C.stride(1), - C.stride(2), - C.stride(3), - *(D.stride(0), D.stride(1)) if D is not None else (0, 0), - z_strides[0], - z_strides[1], - z_strides[2], - z_strides[3], - out.stride(0), - out.stride(1), - out.stride(2), - out.stride(3), - retrieve_parent_token_strides[0], - retrieve_parent_token_strides[1], 
- dt_softplus, - tie_hdim, - BLOCK_SIZE_M, - DISABLE_STATE_UPDATE=disable_state_update, - num_warps=num_warps, - ) - # Squeeze T dimension if original x didn't have it (was 2D or 3D) - if x_orig_dim < 4: - out = out.squeeze(1) - return out diff --git a/tests/mamba/test_chunk_scan_combined.py b/tests/mamba/test_chunk_scan_combined.py new file mode 100644 index 0000000000..dd51b8de58 --- /dev/null +++ b/tests/mamba/test_chunk_scan_combined.py @@ -0,0 +1,703 @@ +""" +Test for Mamba2 SSD (Structured State-Space Duality) chunk scan combined kernel. + +Compares the CUTLASS CuTe DSL Blackwell implementation against the production +Triton implementation. +""" + +import sys +from pathlib import Path + +import numpy as np +import pytest +import torch + +# Add CUTLASS mamba2_ssd path +CUTLASS_MAMBA2_SSD_PATH = ( + Path(__file__).resolve().parents[2] + / "3rdparty" + / "cutlass" + / "examples" + / "python" + / "CuTeDSL" + / "blackwell" + / "mamba2_ssd" +) +sys.path.insert(0, str(CUTLASS_MAMBA2_SSD_PATH)) + +# Import Triton reference +from .triton_reference.ssd_chunk_state import _chunk_cumsum_fwd +from .triton_reference.ssd_combined import _mamba_chunk_scan_combined_fwd + + +def is_blackwell_available(): + """Check if Blackwell GPU (SM100) is available.""" + if not torch.cuda.is_available(): + return False + major, minor = torch.cuda.get_device_capability() + return major >= 10 # SM100 = Blackwell + + +# Skip all tests if not on Blackwell +pytestmark = pytest.mark.skipif( + not is_blackwell_available(), + reason="Blackwell GPU (SM100+) required for CuTe DSL Mamba2 SSD kernel", +) + + +def import_cutlass_modules(): + """Import CUTLASS modules (only when needed, as they require SM100).""" + import cuda.bindings.driver as cuda + import cutlass + import cutlass.cute as cute + import cutlass.torch as cutlass_torch + from cutlass.cute.runtime import from_dlpack + from mamba2_ssd import SSDKernel + + return { + "cuda": cuda, + "cutlass": cutlass, + "cute": cute, + "cutlass_torch": 
cutlass_torch, + "from_dlpack": from_dlpack, + "SSDKernel": SSDKernel, + } + + +class CutlassSSDWrapper: + """ + Wrapper around CUTLASS CuTe DSL SSD kernel to match Triton API. + + The CUTLASS kernel expects: + - Preprocessed cumsum_delta (step 1 already done) + - Specific tensor layouts different from Triton + - Output tensors preallocated + + This wrapper: + 1. Computes cumsum using Triton's step 1 (_chunk_cumsum_fwd) + 2. Converts tensors to CUTLASS layout + 3. Calls CUTLASS kernel + 4. Converts output back to Triton layout + """ + + def __init__( + self, + chunk_size: int, + headdim: int, + dstate: int, + has_d: bool = True, + d_has_hdim: bool = False, + io_dtype=None, + cumsum_dtype=None, + acc_dtype=None, + ): + """ + Initialize the wrapper. + + Args: + chunk_size: L - size of each chunk + headdim: D - head dimension + dstate: N - state dimension + has_d: Whether to fuse D scaling (Y += X*D) + d_has_hdim: If True, D is (headdim, nheads), else (1, nheads) + io_dtype: Input/output dtype (default: cutlass.BFloat16) + cumsum_dtype: Cumsum intermediate dtype (default: cutlass.Float32) + acc_dtype: Accumulator dtype (default: cutlass.Float32) + """ + self.modules = import_cutlass_modules() + cutlass = self.modules["cutlass"] + + self.chunk_size = chunk_size + self.headdim = headdim + self.dstate = dstate + self.has_d = has_d + self.d_has_hdim = d_has_hdim + + self.io_dtype = io_dtype or cutlass.BFloat16 + self.cumsum_dtype = cumsum_dtype or cutlass.Float32 + self.acc_dtype = acc_dtype or cutlass.Float32 + + # Create the kernel + SSDKernel = self.modules["SSDKernel"] + self.kernel = SSDKernel( + self.io_dtype, + self.cumsum_dtype, + self.acc_dtype, + chunk_size, + headdim, + dstate, + has_d, + d_has_hdim, + ) + + self._compiled_kernel = None + + def _create_cutlass_tensor(self, shape, permute_order, dtype, dynamic_modes): + """ + Create a tensor using the exact logic from mamba2_ssd.py to ensure compatibility. 
+ + Args: + shape: Base shape of the tensor (before permutation) + permute_order: Order to permute dimensions + dtype: CUTLASS dtype + dynamic_modes: List of modes to mark as dynamic + + Returns: + (cute_tensor, torch_tensor): The CuTe tensor wrapper and the underlying PyTorch tensor on GPU + """ + cutlass_torch = self.modules["cutlass_torch"] + from_dlpack = self.modules["from_dlpack"] + + # Create a dummy CPU tensor with the base layout to establish the permutation pattern + # mimicking create_and_permute_tensor from mamba2_ssd.py + base_tensor = torch.empty(*shape, dtype=torch.float32) + permuted_tensor = base_tensor.permute(permute_order) + + # Move to GPU with target dtype - this creates the specific layout CUTLASS expects + torch_dtype = cutlass_torch.dtype(dtype) + dst_tensor = permuted_tensor.to(torch_dtype).cuda() + + # Create CuTe tensor + cute_tensor = from_dlpack(dst_tensor, assumed_align=16) + for mode in dynamic_modes: + cute_tensor = cute_tensor.mark_compact_shape_dynamic( + mode=mode, stride_order=dst_tensor.dim_order() + ) + + return cute_tensor, dst_tensor + + def __call__( + self, + x, + dt, + A, + B, + C, + chunk_size, + D=None, + z=None, + dt_bias=None, + initial_states=None, + seq_idx=None, + dt_softplus=False, + dt_limit=(0.0, float("inf")), + ): + """ + Run the SSD kernel with Triton-compatible API. 
+ + Args: + x: (batch, seqlen, nheads, headdim) + dt: (batch, seqlen, nheads) + A: (nheads,) + B: (batch, seqlen, ngroups, dstate) + C: (batch, seqlen, ngroups, dstate) + chunk_size: Size of chunks + D: Optional (nheads, headdim) or (nheads,) + z: Optional gating tensor (not supported yet) + dt_bias: Optional (nheads,) + initial_states: Optional (batch, nheads, headdim, dstate) + seq_idx: Optional sequence indices (not supported yet) + dt_softplus: Whether to apply softplus to dt + dt_limit: Limits for dt values + + Returns: + out: (batch, seqlen, nheads, headdim) + final_states: (batch, nheads, headdim, dstate) + """ + cutlass = self.modules["cutlass"] + cute = self.modules["cute"] + + # Validate inputs + batch, seqlen, nheads, headdim = x.shape + _, _, ngroups, dstate = B.shape + nchunks = seqlen // chunk_size + + assert seqlen % chunk_size == 0, ( + f"seqlen ({seqlen}) must be divisible by chunk_size ({chunk_size})" + ) + assert headdim == self.headdim, f"headdim mismatch: {headdim} vs {self.headdim}" + assert dstate == self.dstate, f"dstate mismatch: {dstate} vs {self.dstate}" + assert chunk_size == self.chunk_size, ( + f"chunk_size mismatch: {chunk_size} vs {self.chunk_size}" + ) + + if z is not None: + raise NotImplementedError("z (gating) not yet supported in CUTLASS wrapper") + if seq_idx is not None: + raise NotImplementedError("seq_idx not yet supported in CUTLASS wrapper") + if initial_states is not None: + raise NotImplementedError( + "initial_states not yet supported in CUTLASS wrapper" + ) + + # Step 1: Compute cumsum using Triton kernel + # dA_cumsum: (batch, nheads, nchunks, chunk_size) + # dt_processed: (batch, nheads, nchunks, chunk_size) - after softplus/bias + dA_cumsum, dt_processed = _chunk_cumsum_fwd( + dt, + A, + chunk_size, + dt_bias=dt_bias, + dt_softplus=dt_softplus, + dt_limit=dt_limit, + ) + + # Convert tensors to CUTLASS layout using the same pattern as mamba2_ssd.py + # Key: create contiguous tensor in base shape, then permute to get 
correct strides + # CUTLASS expects specific permuted layouts for each tensor + + # x: Triton (batch, seqlen, nheads, headdim) -> CUTLASS (headdim, chunk_size, nchunks, nheads, batch) + x_reshaped = x.reshape(batch, nchunks, chunk_size, nheads, headdim) + x_tensor, x_dst = self._create_cutlass_tensor( + [batch, nheads, headdim, nchunks, chunk_size], + [2, 4, 3, 1, 0], + self.io_dtype, + [2, 3, 4], + ) + x_dst.copy_(x_reshaped.permute(4, 2, 1, 3, 0).to(x_dst.dtype)) + + # delta (dt_processed): (batch, nheads, nchunks, chunk_size) -> (chunk_size, nchunks, nheads, batch) + delta_tensor, delta_dst = self._create_cutlass_tensor( + [batch, nheads, nchunks, chunk_size], [3, 2, 1, 0], self.io_dtype, [1, 2, 3] + ) + delta_dst.copy_(dt_processed.permute(3, 2, 1, 0).to(delta_dst.dtype)) + + # cumsum_delta (dA_cumsum): same layout as delta + cumsum_delta_tensor, cumsum_delta_dst = self._create_cutlass_tensor( + [batch, nheads, nchunks, chunk_size], + [3, 2, 1, 0], + self.cumsum_dtype, + [1, 2, 3], + ) + cumsum_delta_dst.copy_(dA_cumsum.permute(3, 2, 1, 0).to(cumsum_delta_dst.dtype)) + + # B: Triton (batch, seqlen, ngroups, dstate) -> CUTLASS (chunk_size, dstate, nchunks, ngroups, batch) + B_reshaped = B.reshape(batch, nchunks, chunk_size, ngroups, dstate) + b_tensor, b_dst = self._create_cutlass_tensor( + [batch, ngroups, dstate, nchunks, chunk_size], + [4, 2, 3, 1, 0], + self.io_dtype, + [2, 3, 4], + ) + b_dst.copy_(B_reshaped.permute(2, 4, 1, 3, 0).to(b_dst.dtype)) + + # C: same layout as B + C_reshaped = C.reshape(batch, nchunks, chunk_size, ngroups, dstate) + c_tensor, c_dst = self._create_cutlass_tensor( + [batch, ngroups, dstate, nchunks, chunk_size], + [4, 2, 3, 1, 0], + self.io_dtype, + [2, 3, 4], + ) + c_dst.copy_(C_reshaped.permute(2, 4, 1, 3, 0).to(c_dst.dtype)) + + # D: (nheads,) -> CUTLASS (1, nheads) or (headdim, nheads) + if self.has_d and D is not None: + if self.d_has_hdim: + # D is (nheads, headdim) -> (headdim, nheads) + if D.dim() == 1: + D = 
D.unsqueeze(1).expand(-1, headdim) + d_tensor, d_dst = self._create_cutlass_tensor( + [nheads, headdim], [1, 0], self.io_dtype, [1] + ) + d_dst.copy_(D.t().to(d_dst.dtype)) + else: + # D is (nheads,) -> (1, nheads) + if D.dim() == 2: + D = D[:, 0] + d_tensor, d_dst = self._create_cutlass_tensor( + [nheads, 1], [1, 0], self.io_dtype, [1] + ) + d_dst.copy_(D.unsqueeze(0).to(d_dst.dtype)) + else: + d_tensor = None + + # Output tensors + # y: (chunk_size, headdim, nchunks, nheads, batch) + y_tensor, y_cutlass = self._create_cutlass_tensor( + [batch, nheads, headdim, nchunks, chunk_size], + [4, 2, 3, 1, 0], + self.io_dtype, + [2, 3, 4], + ) + + # fstate: (headdim, dstate, nheads, batch) + fstate_tensor, fstate_cutlass = self._create_cutlass_tensor( + [batch, nheads, headdim, dstate], [2, 3, 1, 0], self.io_dtype, [2, 3] + ) + + # Get max active clusters + hardware_info = cutlass.utils.HardwareInfo() + max_active_clusters = hardware_info.get_max_active_clusters(1) + + stream = cutlass.cuda.default_stream() + + # Compile kernel if not already done + if self._compiled_kernel is None: + self._compiled_kernel = cute.compile( + self.kernel, + x_tensor, + cumsum_delta_tensor, + delta_tensor, + b_tensor, + c_tensor, + y_tensor, + fstate_tensor, + d_tensor, + max_active_clusters, + stream, + ) + + # Run kernel + self._compiled_kernel( + x_tensor, + cumsum_delta_tensor, + delta_tensor, + b_tensor, + c_tensor, + y_tensor, + fstate_tensor, + d_tensor, + stream, + ) + + # Convert outputs back to Triton layout + # y_cutlass is (L, D, C, EH, B) + # We need to map it back to (batch, seqlen, nheads, headdim) + # Permute (L, D, C, EH, B) -> (B, C, L, EH, D) + y_permuted = y_cutlass.permute(4, 2, 0, 3, 1) + y_out = y_permuted.reshape(batch, seqlen, nheads, headdim) + + # fstate_cutlass is (D, N, EH, B) + # We need (batch, nheads, headdim, dstate) + # Permute (D, N, EH, B) -> (B, EH, D, N) + fstate_out = fstate_cutlass.permute(3, 2, 0, 1).contiguous() + + return y_out, fstate_out + + +class 
TestChunkScanCombined: + """Test class for chunk scan combined kernel.""" + + # Test configuration - slightly relaxed tolerance for bf16 precision + ATOL = 5e-2 + RTOL = 5e-2 + INPUT_DTYPE = torch.bfloat16 + + @pytest.fixture(params=[1, 2]) + def batch(self, request): + return request.param + + @pytest.fixture(params=[8]) # nheads must be divisible by ngroups + def nheads(self, request): + return request.param + + @pytest.fixture(params=[64]) # Must match kernel's D + def headdim(self, request): + return request.param + + @pytest.fixture( + params=[128] + ) # Must match kernel's N (CUTLASS kernel is hardcoded for N=128) + def dstate(self, request): + return request.param + + @pytest.fixture(params=[128]) # Must match kernel's L + def chunk_size(self, request): + return request.param + + @pytest.fixture(params=[1, 4]) # Number of chunks + def nchunks(self, request): + return request.param + + @pytest.fixture(params=[8]) # ngroups divides nheads + def ngroups(self, request): + return request.param + + @pytest.fixture + def inputs(self, batch, nheads, headdim, dstate, chunk_size, nchunks, ngroups): + """Create test inputs.""" + torch.manual_seed(42) + + seqlen = chunk_size * nchunks + + # x: (batch, seqlen, nheads, headdim) + x = torch.randn( + batch, seqlen, nheads, headdim, dtype=self.INPUT_DTYPE, device="cuda" + ) + + # dt: (batch, seqlen, nheads) + dt = torch.randn(batch, seqlen, nheads, dtype=torch.float32, device="cuda") + + # A: (nheads,) - should be negative for stability + A = -torch.rand(nheads, dtype=torch.float32, device="cuda") - 1.0 + + # B: (batch, seqlen, ngroups, dstate) + B = torch.randn( + batch, seqlen, ngroups, dstate, dtype=self.INPUT_DTYPE, device="cuda" + ) + + # C: (batch, seqlen, ngroups, dstate) + C = torch.randn( + batch, seqlen, ngroups, dstate, dtype=self.INPUT_DTYPE, device="cuda" + ) + + # D: (nheads, headdim) or (nheads,) + D = torch.randn(nheads, dtype=self.INPUT_DTYPE, device="cuda") + + # dt_bias: (nheads,) + dt_bias = 
torch.rand(nheads, dtype=torch.float32, device="cuda") - 4.0 + + return { + "x": x, + "dt": dt, + "A": A, + "B": B, + "C": C, + "D": D, + "dt_bias": dt_bias, + "chunk_size": chunk_size, + "seqlen": seqlen, + "nheads": nheads, + "headdim": headdim, + "dstate": dstate, + "ngroups": ngroups, + } + + @pytest.fixture + def reference_output(self, inputs): + """Compute reference output using Triton implementation.""" + out, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd( + inputs["x"], + inputs["dt"], + inputs["A"], + inputs["B"], + inputs["C"], + inputs["chunk_size"], + D=inputs["D"], + z=None, + dt_bias=inputs["dt_bias"], + initial_states=None, + seq_idx=None, + dt_softplus=True, + ) + return out, final_states + + def _print_mismatch_details(self, ref, test, name, atol, rtol): + """Print detailed mismatch analysis.""" + ref_np = ref.detach().cpu().float().numpy() + test_np = test.detach().cpu().float().numpy() + + mismatch_mask = ~np.isclose(ref_np, test_np, atol=atol, rtol=rtol) + num_mismatches = np.sum(mismatch_mask) + total_elements = ref_np.size + + print(f"\nDetailed {name} mismatch analysis:") + print( + f"Number of mismatched elements: {num_mismatches} / {total_elements} " + f"({100 * num_mismatches / total_elements:.2f}%)" + ) + + if num_mismatches > 0: + mismatch_indices = np.argwhere(mismatch_mask) + print(f"First few {name} mismatch locations (up to 10):") + for idx in mismatch_indices[:10]: + idx_tuple = tuple(int(i) for i in idx) + ref_val = ref_np[idx_tuple] + test_val = test_np[idx_tuple] + diff = abs(ref_val - test_val) + rel_diff = diff / (abs(ref_val) + 1e-8) + print( + f" Index {idx_tuple}: ref={ref_val:.6f}, test={test_val:.6f}, " + f"diff={diff:.6e}, rel_diff={rel_diff:.6e}" + ) + + def test_output_correctness(self, inputs, reference_output): + """Test that CUTLASS kernel output matches Triton reference.""" + out_ref, final_states_ref = reference_output + + # Create CUTLASS wrapper + wrapper = CutlassSSDWrapper( + 
chunk_size=inputs["chunk_size"], + headdim=inputs["headdim"], + dstate=inputs["dstate"], + has_d=True, + d_has_hdim=False, # D is (nheads,) not (nheads, headdim) + ) + + # Run CUTLASS kernel + out_test, final_states_test = wrapper( + inputs["x"], + inputs["dt"], + inputs["A"], + inputs["B"], + inputs["C"], + inputs["chunk_size"], + D=inputs["D"], + dt_bias=inputs["dt_bias"], + dt_softplus=True, + ) + + # Compare outputs - cast to same dtype for comparison + out_ref_cmp = out_ref.to(out_test.dtype) + out_match = torch.allclose( + out_ref_cmp, out_test, atol=self.ATOL, rtol=self.RTOL + ) + + if out_match: + print( + f"✓ Outputs match within tolerance (atol={self.ATOL}, rtol={self.RTOL})" + ) + else: + print("✗ Outputs do NOT match within tolerance") + self._print_mismatch_details( + out_ref_cmp, out_test, "output", self.ATOL, self.RTOL + ) + + # Compare final states - cast to same dtype for comparison + final_states_ref_cmp = final_states_ref.to(final_states_test.dtype) + states_match = torch.allclose( + final_states_ref_cmp, final_states_test, atol=self.ATOL, rtol=self.RTOL + ) + + if states_match: + print( + f"✓ Final states match within tolerance (atol={self.ATOL}, rtol={self.RTOL})" + ) + else: + print("✗ Final states do NOT match within tolerance") + self._print_mismatch_details( + final_states_ref_cmp, + final_states_test, + "final_states", + self.ATOL, + self.RTOL, + ) + + assert out_match, "Output mismatch between CUTLASS and Triton" + assert states_match, "Final states mismatch between CUTLASS and Triton" + + +class TestChunkScanCombinedNoD(TestChunkScanCombined): + """Test chunk scan without D scaling.""" + + @pytest.fixture(params=[1]) + def batch(self, request): + return request.param + + @pytest.fixture(params=[1]) + def nchunks(self, request): + return request.param + + @pytest.fixture + def inputs(self, batch, nheads, headdim, dstate, chunk_size, nchunks, ngroups): + """Create test inputs without D.""" + torch.manual_seed(42) + + seqlen = chunk_size * 
nchunks + + x = torch.randn( + batch, seqlen, nheads, headdim, dtype=self.INPUT_DTYPE, device="cuda" + ) + dt = torch.randn(batch, seqlen, nheads, dtype=torch.float32, device="cuda") + A = -torch.rand(nheads, dtype=torch.float32, device="cuda") - 1.0 + B = torch.randn( + batch, seqlen, ngroups, dstate, dtype=self.INPUT_DTYPE, device="cuda" + ) + C = torch.randn( + batch, seqlen, ngroups, dstate, dtype=self.INPUT_DTYPE, device="cuda" + ) + dt_bias = torch.rand(nheads, dtype=torch.float32, device="cuda") - 4.0 + + return { + "x": x, + "dt": dt, + "A": A, + "B": B, + "C": C, + "D": None, + "dt_bias": dt_bias, + "chunk_size": chunk_size, + "seqlen": seqlen, + "nheads": nheads, + "headdim": headdim, + "dstate": dstate, + "ngroups": ngroups, + } + + @pytest.fixture + def reference_output(self, inputs): + """Compute reference output using Triton implementation without D.""" + out, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd( + inputs["x"], + inputs["dt"], + inputs["A"], + inputs["B"], + inputs["C"], + inputs["chunk_size"], + D=None, + z=None, + dt_bias=inputs["dt_bias"], + initial_states=None, + seq_idx=None, + dt_softplus=True, + ) + return out, final_states + + def test_output_correctness(self, inputs, reference_output): + """Test without D scaling.""" + out_ref, final_states_ref = reference_output + + wrapper = CutlassSSDWrapper( + chunk_size=inputs["chunk_size"], + headdim=inputs["headdim"], + dstate=inputs["dstate"], + has_d=False, + d_has_hdim=False, + ) + + out_test, final_states_test = wrapper( + inputs["x"], + inputs["dt"], + inputs["A"], + inputs["B"], + inputs["C"], + inputs["chunk_size"], + D=None, + dt_bias=inputs["dt_bias"], + dt_softplus=True, + ) + + # Cast to same dtype for comparison + out_ref_cmp = out_ref.to(out_test.dtype) + final_states_ref_cmp = final_states_ref.to(final_states_test.dtype) + + out_match = torch.allclose( + out_ref_cmp, out_test, atol=self.ATOL, rtol=self.RTOL + ) + states_match = torch.allclose( + 
final_states_ref_cmp, final_states_test, atol=self.ATOL, rtol=self.RTOL + ) + + if out_match: + print("✓ [NoD] Outputs match within tolerance") + else: + print("✗ [NoD] Outputs do NOT match") + self._print_mismatch_details( + out_ref_cmp, out_test, "output", self.ATOL, self.RTOL + ) + + if states_match: + print("✓ [NoD] Final states match within tolerance") + else: + print("✗ [NoD] Final states do NOT match") + self._print_mismatch_details( + final_states_ref_cmp, + final_states_test, + "final_states", + self.ATOL, + self.RTOL, + ) + + assert out_match + assert states_match diff --git a/tests/mamba/test_selective_state_update_mtp.py b/tests/mamba/test_selective_state_update_mtp.py index 38e1592d22..022936b3bd 100644 --- a/tests/mamba/test_selective_state_update_mtp.py +++ b/tests/mamba/test_selective_state_update_mtp.py @@ -11,7 +11,7 @@ import flashinfer -from .selective_state_update_triton import selective_state_update_triton +from .triton_reference.selective_state_update import selective_state_update_triton from .test_utils import create_test_inputs, clone_preserving_strides diff --git a/tests/mamba/test_selective_state_update_stp.py b/tests/mamba/test_selective_state_update_stp.py index 566a93df9c..705d8b0f2b 100644 --- a/tests/mamba/test_selective_state_update_stp.py +++ b/tests/mamba/test_selective_state_update_stp.py @@ -4,7 +4,7 @@ import flashinfer -from .selective_state_update_triton import selective_state_update_triton +from .triton_reference.selective_state_update import selective_state_update_triton from .test_utils import create_test_inputs, clone_preserving_strides diff --git a/tests/mamba/triton_reference/__init__.py b/tests/mamba/triton_reference/__init__.py new file mode 100644 index 0000000000..b7a96808f5 --- /dev/null +++ b/tests/mamba/triton_reference/__init__.py @@ -0,0 +1,6 @@ +""" +Triton reference implementations for Mamba kernels. 
+ +This package contains production-level Triton implementations used as +reference for testing CUDA/CUTLASS kernel implementations. +""" diff --git a/tests/mamba/triton_reference/softplus.py b/tests/mamba/triton_reference/softplus.py new file mode 100644 index 0000000000..b2ce0d58b2 --- /dev/null +++ b/tests/mamba/triton_reference/softplus.py @@ -0,0 +1,26 @@ +# Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/mamba_ssm.py + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright (c) 2024, Tri Dao, Albert Gu. + +import triton +import triton.language as tl +from packaging import version + +TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0") + +if TRITON3: + + @triton.jit + def softplus(dt): + dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt) + return dt + +else: + + @triton.jit + def softplus(dt): + dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt) + return dt diff --git a/tests/mamba/triton_reference/ssd_bmm.py b/tests/mamba/triton_reference/ssd_bmm.py new file mode 100644 index 0000000000..846949e0ed --- /dev/null +++ b/tests/mamba/triton_reference/ssd_bmm.py @@ -0,0 +1,272 @@ +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_bmm.py +# Copyright (c) 2024, Tri Dao, Albert Gu. +# +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, + num_stages=5, + num_warps=2, + ), + triton.Config( + {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=5, + num_warps=2, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=2, + ), + ], + key=["chunk_size", "K", "IS_CAUSAL"], +) +@triton.jit +def _bmm_chunk_fwd_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + out_ptr, + seq_idx_ptr, + # Matrix dimensions + seqlen, + chunk_size, + K, + ngroups, + stride_a_batch, + stride_a_seqlen, + stride_a_head, + stride_ak, + stride_b_batch, + stride_b_seqlen, + stride_b_head, + stride_bk, + stride_out_batch, + stride_out_chunk, + stride_out_head, + stride_outm, + stride_outn, + stride_seq_idx_batch, + stride_seq_idx_seqlen, + # Meta-parameters + IS_CAUSAL: tl.constexpr, + dot_dtype: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid_b = tl.program_id(axis=1) + 
pid_ch = tl.program_id(axis=2).to(tl.int64) + pid_c = pid_ch // ngroups + pid_h = pid_ch - pid_c * ngroups + num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + if IS_CAUSAL: + if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M: + return + a_ptr += ( + pid_b * stride_a_batch + + pid_c * chunk_size * stride_a_seqlen + + pid_h * stride_a_head + ) + b_ptr += ( + pid_b * stride_b_batch + + pid_c * chunk_size * stride_b_seqlen + + pid_h * stride_b_head + ) + if HAS_SEQ_IDX: + seq_idx_ptr += ( + pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + ) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen) + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) + & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ).to(dot_dtype) + b = tl.load( + b_ptrs, + mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) + & (offs_n[None, :] < chunk_size_limit), + other=0.0, + ).to(dot_dtype) + acc += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + if HAS_SEQ_IDX: + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + seq_idx_m = tl.load( + seq_idx_ptr + offs_m * stride_seq_idx_seqlen, + mask=offs_m < chunk_size_limit, + other=-1, + ) + seq_idx_n = tl.load( + seq_idx_ptr + offs_n * stride_seq_idx_seqlen, + mask=offs_n < chunk_size_limit, + 
other=-2, + ) + acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0) + out = acc.to(out_ptr.dtype.element_ty) + + out_ptr += ( + pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head + ) + out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn) + tl.store( + out_ptrs, + out, + mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size), + ) + + +def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None): + """ + Argument: + a: (batch, seqlen, k) or (batch, seqlen, ngroups, k) + b: (batch, seqlen, k) or (batch, seqlen, ngroups, k) + seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out. + causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are + guaranteed to be correct. + Return: + out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size) + """ + # Check constraints. + has_groups = a.dim() == 4 + if not has_groups: + batch, seqlen, k = a.shape + else: + batch, seqlen, ngroups, k = a.shape + assert b.shape == a.shape + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if a.stride(-1) != 1 and a.stride(1) != 1: + a = a.contiguous() + if b.stride(-1) != 1 and b.stride(1) != 1: + b = b.contiguous() + nchunks = math.ceil(seqlen / chunk_size) + # Allocates output. 
+ out_dtype = a.dtype if output_dtype is None else output_dtype + out = torch.empty( + ( + (batch, nchunks, chunk_size, chunk_size) + if not has_groups + else (batch, nchunks, ngroups, chunk_size, chunk_size) + ), + device=a.device, + dtype=out_dtype, + ) + dot_dtype = ( + tl.bfloat16 + if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 + else ( + tl.float16 + if a.dtype == torch.float16 or b.dtype == torch.float16 + else tl.float32 + ) + ) + grid = lambda META: ( + triton.cdiv(chunk_size, META["BLOCK_SIZE_M"]) + * triton.cdiv(chunk_size, META["BLOCK_SIZE_N"]), + batch, + nchunks if not has_groups else nchunks * ngroups, + ) + with torch.cuda.device(a.device.index): + _bmm_chunk_fwd_kernel[grid]( + a, + b, + out, + seq_idx, + seqlen, + chunk_size, + k, + ngroups if has_groups else 1, + a.stride(0), + a.stride(1), + 0 if not has_groups else a.stride(2), + a.stride(-1), + b.stride(0), + b.stride(1), + 0 if not has_groups else b.stride(2), + b.stride(-1), + out.stride(0), + out.stride(1), + 0 if not has_groups else out.stride(2), + out.stride(-2), + out.stride(-1), + *( + (seq_idx.stride(0), seq_idx.stride(1)) + if seq_idx is not None + else (0, 0) + ), + causal, + dot_dtype, + HAS_SEQ_IDX=seq_idx is not None, + ) + return out diff --git a/tests/mamba/triton_reference/ssd_chunk_scan.py b/tests/mamba/triton_reference/ssd_chunk_scan.py new file mode 100644 index 0000000000..67cdad636c --- /dev/null +++ b/tests/mamba/triton_reference/ssd_chunk_scan.py @@ -0,0 +1,628 @@ +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_scan.py +# Copyright (c) 2024, Tri Dao, Albert Gu. +# +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import triton +import triton.language as tl +from packaging import version + +TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0") + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, + num_stages=5, + num_warps=2, + ), + triton.Config( + {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=5, + num_warps=2, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=2, + ), + ], + key=["chunk_size", "hdim", "dstate", "IS_CAUSAL"], +) +@triton.jit +def 
_chunk_scan_fwd_kernel( + # Pointers to matrices + cb_ptr, + x_ptr, + z_ptr, + out_ptr, + out_x_ptr, + dt_ptr, + dA_cumsum_ptr, + seq_idx_ptr, + C_ptr, + states_ptr, + D_ptr, + initstates_ptr, + chunk_indices_ptr, + chunk_offsets_ptr, + chunk_meta_num, + # Matrix dimensions + chunk_size, + hdim, + dstate, + batch, + seqlen, + nheads_ngroups_ratio, + # Strides + stride_cb_batch, + stride_cb_chunk, + stride_cb_head, + stride_cb_csize_m, + stride_cb_csize_k, + stride_x_batch, + stride_x_seqlen, + stride_x_head, + stride_x_hdim, + stride_z_batch, + stride_z_seqlen, + stride_z_head, + stride_z_hdim, + stride_out_batch, + stride_out_seqlen, + stride_out_head, + stride_out_hdim, + stride_dt_batch, + stride_dt_chunk, + stride_dt_head, + stride_dt_csize, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + stride_seq_idx_batch, + stride_seq_idx_seqlen, + stride_C_batch, + stride_C_seqlen, + stride_C_head, + stride_C_dstate, + stride_states_batch, + stride_states_chunk, + stride_states_head, + stride_states_hdim, + stride_states_dstate, + stride_init_states_batch, + stride_init_states_head, + stride_init_states_hdim, + stride_init_states_dstate, + stride_D_head, + # Meta-parameters + IS_CAUSAL: tl.constexpr, + HAS_D: tl.constexpr, + D_HAS_HDIM: tl.constexpr, + HAS_Z: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, + IS_TRITON_22: tl.constexpr, + HAS_INITSTATES: tl.constexpr, +): + pid_bc = tl.program_id(axis=1).to(tl.int64) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + if not HAS_INITSTATES: + c_idx = pid_c + c_off = 0 + else: + c_idx = tl.load(chunk_indices_ptr + pid_c, mask=pid_c > -1, other=0) + c_off = tl.load(chunk_offsets_ptr + pid_c, mask=pid_c > -1, other=0) + + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = 
tl.program_id(axis=0) % num_pid_n + cb_ptr += ( + pid_b * stride_cb_batch + + c_idx * stride_cb_chunk + + (pid_h // nheads_ngroups_ratio) * stride_cb_head + ) + x_ptr += ( + pid_b * stride_x_batch + + c_idx * chunk_size * stride_x_seqlen + + pid_h * stride_x_head + ) + dt_ptr += pid_b * stride_dt_batch + c_idx * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += ( + pid_b * stride_dA_cs_batch + + c_idx * stride_dA_cs_chunk + + pid_h * stride_dA_cs_head + ) + C_ptr += ( + pid_b * stride_C_batch + + c_idx * chunk_size * stride_C_seqlen + + (pid_h // nheads_ngroups_ratio) * stride_C_head + ) + + # M-block offsets and prev states + # - logic in next block may override these if there is an active offset + offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M) + prev_states_ptr = ( + states_ptr + + pid_b * stride_states_batch + + c_idx * stride_states_chunk + + pid_h * stride_states_head + ) + prev_states_hdim = stride_states_hdim + prev_states_dstate = stride_states_dstate + + chunk_size_limit = min(chunk_size, seqlen - c_idx * chunk_size) + if HAS_SEQ_IDX: + seq_idx_ptr += ( + pid_b * stride_seq_idx_batch + c_idx * chunk_size * stride_seq_idx_seqlen + ) + + # - we only need seq_idx_prev to be aligned to chunk boundary + seq_idx_prev = tl.load( + seq_idx_ptr - stride_seq_idx_seqlen, mask=c_idx >= 1, other=0 + ) + + if HAS_INITSTATES: + # if there are init states, we only need seq_idx_m to point + # what is the current seq_idx + + # get current seq idx + if (pid_m * BLOCK_SIZE_M + c_off) < chunk_size_limit: + seq_idx_m = tl.load( + seq_idx_ptr + + (pid_m * BLOCK_SIZE_M + c_off) * stride_seq_idx_seqlen, + ) + + # - recall that in ssd_state_passing, for the case c_off == 0 + # i.e., the very first sequence, we made states_ptr hold its initial state + # so this edge case is taken care of + if ( + (c_off == 0) + and ( + seq_idx_prev != seq_idx_m + ) # if a seq is changed exactly on boundary + or (c_off > 0) # implies a new example (pseudo chunk) + ): + # 
- replace prev_states_ptr with init_states
+                prev_states_ptr = (
+                    initstates_ptr
+                    + seq_idx_m * stride_init_states_batch
+                    + pid_h * stride_init_states_head
+                )
+                prev_states_hdim = stride_init_states_hdim  # override strides
+                prev_states_dstate = stride_init_states_dstate
+
+    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    dA_cs_m = tl.load(
+        dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0
+    ).to(tl.float32)
+
+    # - handle chunk state limit
+    if HAS_INITSTATES:
+        # have to split this if otherwise compilation will have problems
+        dA_cs_m_boundary = 0.0
+
+        # get the c_idx for the next (logical) chunk
+        c_idx_n = tl.load(
+            chunk_indices_ptr + (pid_c + 1),
+            mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num,
+            other=-1,  # to trigger different chunk
+        )
+
+        # - there are things to consider
+        # A. if c_off > 0 then we need to move the dA_cs boundary to ensure correct
+        # contribution of past states
+        # B. if c_off_n < chunk_size_limit, then we need to adjust this so as not to
+        # encroach into the next sequence, where c_off_n is the offset of the next
+        # (logical) chunk.
+        # An equivalent check for B is c_idx == c_idx_n, where there is repetition in
+        # (logical) chunk indices.
+
+        if (c_idx == c_idx_n) or c_off > 0:
+            # get the next offset
+            c_off_n = tl.load(
+                chunk_offsets_ptr + (pid_c + 1),
+                mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num,
+                other=chunk_size,
+            )
+
+            # in this case, adjust down the chunk_size_limit
+            if c_idx == c_idx_n:
+                chunk_size_limit = min(c_off_n, chunk_size_limit)
+
+            # get the cs at the offset boundary
+            # - c_off == 0 is a passthrough
+            # - We need dA_cs at the boundary, defined by c_off - no need
+            # to increase pointer by pid_m (it is a constant offset,
+            # i.e. 
the same for all blocks) + dA_cs_m_boundary = tl.load( + dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize, + mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)), + other=0.0, + ).to(tl.float32) + + if HAS_SEQ_IDX: + # - handle seq idx when HAS_INITSTATES==False + if not HAS_INITSTATES: + seq_idx_m = tl.load( + seq_idx_ptr + offs_m * stride_seq_idx_seqlen, + mask=offs_m < chunk_size_limit, + other=-1, + ) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Without the if (pid_c > -1), with Triton 2.1.0, I get + # Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mm a layout conversion"' failed. + # With Triton 2.2.0, this works + if IS_TRITON_22 or c_idx > -1: + # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 + offs_k_dstate = tl.arange( + 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K + ) + C_ptrs = C_ptr + ( + offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate + ) + + prev_states_ptrs = prev_states_ptr + ( + offs_n[None, :] * prev_states_hdim + + offs_k_dstate[:, None] * prev_states_dstate + ) + if HAS_SEQ_IDX: + if not HAS_INITSTATES: + # - this is for continuous batching where there is no init states + scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) + else: + # - if there is initstates, we will rely on prev_states, no zeroing + # required. 
+ scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary) + else: + scale_m = tl.exp(dA_cs_m) + if BLOCK_SIZE_DSTATE <= 128: + C = tl.load( + C_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) + & (offs_k_dstate[None, :] < dstate), + other=0.0, + ) + + prev_states = tl.load( + prev_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), + other=0.0, + ) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + acc = tl.dot(C, prev_states) * scale_m[:, None] + else: + for k in range(0, dstate, BLOCK_SIZE_K): + C = tl.load( + C_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) + & (offs_k_dstate[None, :] < dstate - k), + other=0.0, + ) + # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty) + prev_states = tl.load( + prev_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate - k) + & (offs_n[None, :] < hdim), + other=0.0, + ) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + acc += tl.dot(C, prev_states) + C_ptrs += BLOCK_SIZE_K + prev_states_ptrs += BLOCK_SIZE_K + acc *= scale_m[:, None] + + offs_k = tl.arange(0, BLOCK_SIZE_K) + c_off + cb_ptrs = cb_ptr + ( + offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k + ) + x_ptrs = x_ptr + ( + offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim + ) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + K_MAX = ( + chunk_size_limit + if not IS_CAUSAL + else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) + ) + for k in range(0, K_MAX, BLOCK_SIZE_K): + cb = tl.load( + cb_ptrs, + mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k), + other=0.0, + ).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to( + tl.float32 + ) + # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j]. + # So we don't need masking wrt seq_idx here. 
+ cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :]) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) + cb *= dt_k + if IS_CAUSAL: + mask = offs_m[:, None] >= k + offs_k[None, :] + cb = tl.where(mask, cb, 0.0) + cb = cb.to(x_ptr.dtype.element_ty) + x = tl.load( + x_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim), + other=0.0, + ) + acc += tl.dot(cb, x) + cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + + offs_out_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M) + offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + if HAS_D: + if D_HAS_HDIM: + D = tl.load( + D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0 + ).to(tl.float32) + else: + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + x_residual = tl.load( + x_ptr + + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim), + mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), + other=0.0, + ).to(tl.float32) + acc += x_residual * D + + if HAS_Z: + out_x_ptr += ( + pid_b * stride_out_batch + + c_idx * chunk_size * stride_out_seqlen + + pid_h * stride_out_head + ) + out_x_ptrs = out_x_ptr + ( + stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] + ) + tl.store( + out_x_ptrs, + acc, + mask=(offs_out_m[:, None] < chunk_size_limit) + & (offs_out_n[None, :] < hdim), + ) + + z_ptr += ( + pid_b * stride_z_batch + + c_idx * chunk_size * stride_z_seqlen + + pid_h * stride_z_head + ) + z_ptrs = z_ptr + ( + stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :] + ) + z = tl.load( + z_ptrs, + mask=(offs_out_m[:, None] < chunk_size_limit) + & (offs_out_n[None, :] < hdim), + other=0.0, + ).to(tl.float32) + acc *= z * tl.sigmoid(z) + + out_ptr += ( + pid_b * stride_out_batch + + c_idx * chunk_size * 
stride_out_seqlen + + pid_h * stride_out_head + ) + out_ptrs = out_ptr + ( + stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim + ) + tl.store( + out_ptrs, + acc, + mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), + ) + + +def _chunk_scan_fwd( + cb, + x, + dt, + dA_cumsum, + C, + states, + D=None, + z=None, + seq_idx=None, + chunk_indices=None, + chunk_offsets=None, + initial_states=None, + out=None, +): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = C.shape + assert nheads % ngroups == 0 + assert C.shape == (batch, seqlen, ngroups, dstate) + assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + if z is not None: + assert z.shape == x.shape + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + assert states.shape == (batch, nchunks, nheads, headdim, dstate) + + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + + if initial_states is not None: + # with initial states, we need to take care of how + # seq_idx crosses the boundaries + assert batch == 1, "chunk scan only supports initial states with batch 1" + assert chunk_indices is not None and chunk_offsets is not None, ( + "chunk_indices and chunk_offsets should have been set" + ) + else: + chunk_indices, chunk_offsets = None, None + else: + chunk_indices, chunk_offsets = None, None + + if out is None: + out = torch.empty_like(x) + assert out.shape == x.shape + + if z is not None: + out_x = torch.empty_like(x) + assert out_x.stride() == out.stride() + else: + out_x = None + + grid = lambda META: ( + triton.cdiv(chunk_size, META["BLOCK_SIZE_M"]) + * triton.cdiv(headdim, META["BLOCK_SIZE_N"]), + batch * nchunks if chunk_offsets is None else len(chunk_offsets), + nheads, + ) + z_strides = ( + (z.stride(0), 
z.stride(1), z.stride(2), z.stride(3)) + if z is not None + else (0, 0, 0, 0) + ) + _chunk_scan_fwd_kernel[grid]( + cb, + x, + z, + out, + out_x, + dt, + dA_cumsum, + seq_idx, + C, + states, + D, + initial_states, + chunk_indices, + chunk_offsets, + len(chunk_indices) if chunk_indices is not None else 0, + chunk_size, + headdim, + dstate, + batch, + seqlen, + nheads // ngroups, + cb.stride(0), + cb.stride(1), + cb.stride(2), + cb.stride(3), + cb.stride(4), + x.stride(0), + x.stride(1), + x.stride(2), + x.stride(3), + z_strides[0], + z_strides[1], + z_strides[2], + z_strides[3], + out.stride(0), + out.stride(1), + out.stride(2), + out.stride(3), + dt.stride(0), + dt.stride(2), + dt.stride(1), + dt.stride(3), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + C.stride(0), + C.stride(1), + C.stride(2), + C.stride(3), + states.stride(0), + states.stride(1), + states.stride(2), + states.stride(3), + states.stride(4), + *( + ( + initial_states.stride(0), + initial_states.stride(1), + initial_states.stride(2), + initial_states.stride(3), + ) + if initial_states is not None + else (0, 0, 0, 0) + ), + D.stride(0) if D is not None else 0, + True, + D is not None, + D.dim() == 2 if D is not None else True, + BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), + HAS_Z=z is not None, + HAS_SEQ_IDX=seq_idx is not None, + IS_TRITON_22=TRITON_22, + HAS_INITSTATES=initial_states is not None, + ) + return out_x if z is not None else out diff --git a/tests/mamba/triton_reference/ssd_chunk_state.py b/tests/mamba/triton_reference/ssd_chunk_state.py new file mode 100644 index 0000000000..14b9d28851 --- /dev/null +++ b/tests/mamba/triton_reference/ssd_chunk_state.py @@ -0,0 +1,771 @@ +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_state.py +# Copyright (c) 2024, Tri Dao, Albert Gu. 
+# +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +import triton +import triton.language as tl + +from .softplus import softplus + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE_H": 1}), + triton.Config({"BLOCK_SIZE_H": 2}), + triton.Config({"BLOCK_SIZE_H": 4}), + triton.Config({"BLOCK_SIZE_H": 8}), + triton.Config({"BLOCK_SIZE_H": 16}), + triton.Config({"BLOCK_SIZE_H": 32}), + triton.Config({"BLOCK_SIZE_H": 64}), + ], + key=["chunk_size", "nheads"], +) +@triton.jit +def _chunk_cumsum_fwd_kernel( + # Pointers to matrices + dt_ptr, + A_ptr, + dt_bias_ptr, + dt_out_ptr, + dA_cumsum_ptr, + # Matrix dimension + batch, + seqlen, + nheads, + chunk_size, + dt_min, + dt_max, + # Strides + stride_dt_batch, + stride_dt_seqlen, + stride_dt_head, + stride_A_head, + stride_dt_bias_head, + stride_dt_out_batch, + stride_dt_out_chunk, + stride_dt_out_head, + stride_dt_out_csize, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + # Meta-parameters + DT_SOFTPLUS: tl.constexpr, + HAS_DT_BIAS: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, + BLOCK_SIZE_CHUNK: tl.constexpr, +): + pid_b = tl.program_id(axis=0) + + # if dt is long, may cause problems, so use 64 bit + # https://github.com/triton-lang/triton/issues/1058 + pid_c = tl.program_id(axis=1).to(tl.int64) + pid_h 
= tl.program_id(axis=2) + dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen + dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + + offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) + offs_c = tl.arange(0, BLOCK_SIZE_CHUNK) + dt_ptrs = dt_ptr + ( + offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen + ) + A_ptrs = A_ptr + offs_h * stride_A_head + dt_out_ptrs = dt_out_ptr + ( + offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize + ) + dA_cs_ptrs = dA_cumsum_ptr + ( + offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize + ) + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + dt = tl.load( + dt_ptrs, + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), + other=0.0, + ).to(tl.float32) + if HAS_DT_BIAS: + dt_bias = tl.load( + dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0 + ).to(tl.float32) + dt += dt_bias[:, None] + if DT_SOFTPLUS: + dt = tl.where(dt <= 20.0, softplus(dt), dt) + # As of Triton 2.2.0, tl.clamp is not available yet + # dt = tl.clamp(dt, dt_min, dt_max) + dt = tl.minimum(tl.maximum(dt, dt_min), dt_max) + dt = tl.where( + (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0 + ) + tl.store( + dt_out_ptrs, + dt, + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size), + ) + A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32) + dA = dt * A[:, None] + dA_cs = tl.cumsum(dA, axis=1) + tl.store( + dA_cs_ptrs, + dA_cs, + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size), + ) + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + 
triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, + num_stages=5, + num_warps=2, + ), + triton.Config( + {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=5, + num_warps=2, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=2, + ), + ], + key=["hdim", "dstate", "chunk_size"], +) +@triton.jit +def _chunk_state_fwd_kernel( + # Pointers to matrices + x_ptr, + b_ptr, + states_ptr, + dt_ptr, + dA_cumsum_ptr, + seq_idx_ptr, + # Matrix dimensions + hdim, + dstate, + chunk_size, + batch, + seqlen, + nheads_ngroups_ratio, + # Strides + stride_x_batch, + stride_x_seqlen, + stride_x_head, + stride_x_hdim, + stride_b_batch, + stride_b_seqlen, + stride_b_head, + stride_b_dstate, + stride_states_batch, + stride_states_chunk, + stride_states_head, + stride_states_hdim, + stride_states_dstate, + stride_dt_batch, + stride_dt_chunk, + stride_dt_head, + stride_dt_csize, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + stride_seq_idx_batch, + stride_seq_idx_seqlen, + # Meta-parameters + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1).to(tl.int64) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % 
num_pid_n + b_ptr += ( + pid_b * stride_b_batch + + pid_c * chunk_size * stride_b_seqlen + + (pid_h // nheads_ngroups_ratio) * stride_b_head + ) + x_ptr += ( + pid_b * stride_x_batch + + pid_c * chunk_size * stride_x_seqlen + + pid_h * stride_x_head + ) + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += ( + pid_b * stride_dA_cs_batch + + pid_c * stride_dA_cs_chunk + + pid_h * stride_dA_cs_head + ) + if HAS_SEQ_IDX: + seq_idx_ptr += ( + pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + ) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + x_ptrs = x_ptr + ( + offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen + ) + b_ptrs = b_ptr + ( + offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen + ) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to( + tl.float32 + ) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + if HAS_SEQ_IDX: + seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + if HAS_SEQ_IDX: + seq_idx_last = tl.load( + seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen + ) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, chunk_size_limit, BLOCK_SIZE_K): + x = tl.load( + x_ptrs, + mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k), + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate), + other=0.0, + ).to(tl.float32) + dA_cs_k = tl.load( + dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0 + ).to(tl.float32) + if HAS_SEQ_IDX: + seq_idx_k = tl.load( + seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1 + ) + dt_k = 
tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to( + tl.float32 + ) + if not HAS_SEQ_IDX: + # scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k + scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k + else: + # scale = tl.where(seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0) + scale = tl.where( + seq_idx_k == seq_idx_last, + tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k, + 0.0, + ) + b *= scale[:, None] + b = b.to(x_ptr.dtype.element_ty) + acc += tl.dot(x, b) + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + b_ptrs += BLOCK_SIZE_K * stride_b_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + if HAS_SEQ_IDX: + seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen + states = acc.to(states_ptr.dtype.element_ty) + + states_ptr += ( + pid_b * stride_states_batch + + pid_c * stride_states_chunk + + pid_h * stride_states_head + ) + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + states_ptrs = states_ptr + ( + offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate + ) + c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) + tl.store(states_ptrs, states, mask=c_mask) + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, + 
num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, + num_stages=5, + num_warps=2, + ), + triton.Config( + {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=5, + num_warps=2, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=2, + ), + ], + key=["hdim", "dstate", "chunk_size"], +) +@triton.jit +def _chunk_state_varlen_kernel( + # Pointers to matrices + x_ptr, + b_ptr, + dt_ptr, + dA_cumsum_ptr, + chunk_states_ptr, + cu_seqlens_ptr, + states_ptr, + initstates_ptr, + # Matrix dimensions + hdim, + dstate, + chunk_size, + seqlen, + nheads_ngroups_ratio, + # Strides + stride_x_seqlen, + stride_x_head, + stride_x_hdim, + stride_b_seqlen, + stride_b_head, + stride_b_dstate, + stride_dt_chunk, + stride_dt_head, + stride_dt_csize, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + stride_chunk_states_chunk, + stride_chunk_states_head, + stride_chunk_states_hdim, + stride_chunk_states_dstate, + stride_states_batch, + stride_states_head, + stride_states_hdim, + stride_states_dstate, + stride_init_states_batch, + stride_init_states_head, + stride_init_states_hdim, + stride_init_states_dstate, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + HAS_INITSTATES: tl.constexpr, +): + pid_b = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + end_idx = tl.load(cu_seqlens_ptr + pid_b + 1) + pid_c = (end_idx - 1) // chunk_size + b_ptr += ( + pid_c * chunk_size * stride_b_seqlen + + (pid_h // nheads_ngroups_ratio) * stride_b_head + ) + x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * 
stride_dA_cs_head + chunk_states_ptr += ( + pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head + ) + + if HAS_INITSTATES: + # if there are init states provided, we differentiate between states (which + # are boundary conditions at a chunk boundary) and initstates (which are boundary + # conditions when a new example in a cont batch starts) + initstates_ptr += pid_h * stride_init_states_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + x_ptrs = x_ptr + ( + offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen + ) + b_ptrs = b_ptr + ( + offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen + ) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cs_last = tl.load( + dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize + ).to(tl.float32) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + + chunk_size_limit = end_idx - pid_c * chunk_size + start_idx = tl.load(cu_seqlens_ptr + pid_b) + start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, chunk_size_limit, BLOCK_SIZE_K): + x = tl.load( + x_ptrs, + mask=(offs_m[:, None] < hdim) + & (offs_k[None, :] < chunk_size_limit - k) + & (offs_k[None, :] >= start_idx_cur - k), + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) + & (offs_n[None, :] < dstate) + & (offs_k[:, None] >= start_idx_cur - k), + other=0.0, + ).to(tl.float32) + dA_cs_k = tl.load( + dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0 + ).to(tl.float32) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to( + tl.float32 + ) + scale = tl.where( + (offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k), + tl.exp(dA_cs_last - dA_cs_k) * dt_k, + 0.0, + ) + b *= scale[:, None] + b = 
b.to(x_ptr.dtype.element_ty)
+        acc += tl.dot(x, b)
+        x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
+        b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
+        dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
+        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
+
+    # If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk
+    # If HAS_INITSTATES==True need to consider two possibilities
+    # - if start_idx < pid_c * chunk_size, then we need to take the past_states_ptrs
+    # - if start_idx >= pid_c * chunk_size, then we need to insert initstates
+    if (
+        (start_idx < pid_c * chunk_size)  # first chunk
+        or (HAS_INITSTATES)
+    ):
+        dA_cs_boundary = 0.0  # default
+
+        if not HAS_INITSTATES:
+            past_states_ptrs = chunk_states_ptr + (
+                offs_m[:, None] * stride_chunk_states_hdim
+                + offs_n[None, :] * stride_chunk_states_dstate
+            )
+        else:
+            # - this seems repetitive, but it's to help the compiler
+            if start_idx < pid_c * chunk_size:
+                past_states_ptrs = chunk_states_ptr + (
+                    offs_m[:, None] * stride_chunk_states_hdim
+                    + offs_n[None, :] * stride_chunk_states_dstate
+                )
+            else:
+                past_states_ptrs = initstates_ptr + (
+                    pid_b * stride_init_states_batch
+                    + offs_m[:, None] * stride_init_states_hdim
+                    + offs_n[None, :] * stride_init_states_dstate
+                )
+
+            # need to adjust the boundary
+            if start_idx > pid_c * chunk_size:
+                dA_cs_boundary = tl.load(
+                    dA_cumsum_ptr
+                    + (start_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize
+                ).to(tl.float32)
+
+        past_states = tl.load(
+            past_states_ptrs,
+            mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate),
+            other=0.0,
+        ).to(tl.float32)
+
+        scale = tl.exp(dA_cs_last - dA_cs_boundary)
+        acc += past_states * scale
+
+    states = acc.to(states_ptr.dtype.element_ty)
+
+    states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
+    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    states_ptrs = states_ptr + (
+        offs_m[:, None] * 
stride_states_hdim + offs_n[None, :] * stride_states_dstate + ) + c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) + tl.store(states_ptrs, states, mask=c_mask) + + +def _chunk_cumsum_fwd( + dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf")) +): + batch, seqlen, nheads = dt.shape + assert A.shape == (nheads,) + if dt_bias is not None: + assert dt_bias.shape == (nheads,) + nchunks = math.ceil(seqlen / chunk_size) + dt_out = torch.empty( + batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32 + ) + dA_cumsum = torch.empty( + batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32 + ) + grid_chunk_cs = lambda META: ( + batch, + nchunks, + triton.cdiv(nheads, META["BLOCK_SIZE_H"]), + ) + with torch.cuda.device(dt.device.index): + _chunk_cumsum_fwd_kernel[grid_chunk_cs]( + dt, + A, + dt_bias, + dt_out, + dA_cumsum, + batch, + seqlen, + nheads, + chunk_size, + dt_limit[0], + dt_limit[1], + dt.stride(0), + dt.stride(1), + dt.stride(2), + A.stride(0), + dt_bias.stride(0) if dt_bias is not None else 0, + dt_out.stride(0), + dt_out.stride(2), + dt_out.stride(1), + dt_out.stride(3), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), + dt_softplus, + HAS_DT_BIAS=dt_bias is not None, + BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size), + ) + return dA_cumsum, dt_out + + +def _chunk_state_fwd( + B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_fp32=True +): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if states is not None: + assert states.shape == (batch, nchunks, nheads, headdim, dstate) + else: + states_dtype = torch.float32 if 
states_in_fp32 else B.dtype + states = torch.empty( + (batch, nchunks, nheads, headdim, dstate), + device=x.device, + dtype=states_dtype, + ) + grid = lambda META: ( + triton.cdiv(headdim, META["BLOCK_SIZE_M"]) + * triton.cdiv(dstate, META["BLOCK_SIZE_N"]), + batch * nchunks, + nheads, + ) + with torch.cuda.device(x.device.index): + _chunk_state_fwd_kernel[grid]( + x, + B, + states, + dt, + dA_cumsum, + seq_idx, + headdim, + dstate, + chunk_size, + batch, + seqlen, + nheads // ngroups, + x.stride(0), + x.stride(1), + x.stride(2), + x.stride(3), + B.stride(0), + B.stride(1), + B.stride(2), + B.stride(-1), + states.stride(0), + states.stride(1), + states.stride(2), + states.stride(3), + states.stride(4), + dt.stride(0), + dt.stride(2), + dt.stride(1), + dt.stride(3), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), + *( + (seq_idx.stride(0), seq_idx.stride(1)) + if seq_idx is not None + else (0, 0) + ), + HAS_SEQ_IDX=seq_idx is not None, + ) + return states + + +def chunk_state_varlen( + B, x, dt, dA_cumsum, cu_seqlens, chunk_states, initial_states=None +): + total_seqlen, nheads, headdim = x.shape + _, nchunks, chunk_size = dt.shape + _, ngroups, dstate = B.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + assert nheads % ngroups == 0 + assert B.shape == (total_seqlen, ngroups, dstate) + assert dt.shape == (nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + assert chunk_states.shape == (nchunks, nheads, headdim, dstate) + + if initial_states is not None: + assert initial_states.shape == (batch, nheads, headdim, dstate) + + states = torch.empty( + batch, + nheads, + headdim, + dstate, + dtype=chunk_states.dtype, + device=chunk_states.device, + ) + grid = lambda META: ( + triton.cdiv(headdim, META["BLOCK_SIZE_M"]) + * triton.cdiv(dstate, META["BLOCK_SIZE_N"]), + batch, + nheads, + ) + with torch.cuda.device(x.device.index): + _chunk_state_varlen_kernel[grid]( + x, + B, + dt, + 
dA_cumsum, + chunk_states, + cu_seqlens, + states, + initial_states, + headdim, + dstate, + chunk_size, + total_seqlen, + nheads // ngroups, + x.stride(0), + x.stride(1), + x.stride(2), + B.stride(0), + B.stride(1), + B.stride(2), + dt.stride(1), + dt.stride(0), + dt.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + chunk_states.stride(0), + chunk_states.stride(1), + chunk_states.stride(2), + chunk_states.stride(3), + states.stride(0), + states.stride(1), + states.stride(2), + states.stride(3), + *( + ( + initial_states.stride(0), + initial_states.stride(1), + initial_states.stride(2), + initial_states.stride(3), + ) + if initial_states is not None + else (0, 0, 0, 0) + ), + HAS_INITSTATES=initial_states is not None, + ) + return states diff --git a/tests/mamba/triton_reference/ssd_combined.py b/tests/mamba/triton_reference/ssd_combined.py new file mode 100644 index 0000000000..8b65b38a1b --- /dev/null +++ b/tests/mamba/triton_reference/ssd_combined.py @@ -0,0 +1,265 @@ +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_combined.py +# Copyright (c) 2024, Tri Dao, Albert Gu. +# +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +from einops import rearrange + +from .ssd_bmm import _bmm_chunk_fwd +from .ssd_chunk_scan import _chunk_scan_fwd +from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_state_fwd, chunk_state_varlen +from .ssd_state_passing import _state_passing_fwd + + +def is_int_pow_2(n): + return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0 + + +def _mamba_chunk_scan_combined_fwd( + x, + dt, + A, + B, + C, + chunk_size, + D=None, + z=None, + dt_bias=None, + initial_states=None, + seq_idx=None, + chunk_indices=None, + chunk_offsets=None, + cu_seqlens=None, + dt_softplus=False, + dt_limit=(0.0, float("inf")), + state_dtype=None, + out=None, +): + assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2" + batch, seqlen, nheads, headdim = x.shape + _, _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + assert dt.shape == (batch, seqlen, nheads) + assert A.shape == (nheads,) + assert C.shape == B.shape + if z is not None: + assert z.shape == x.shape + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if ( + x.stride(-1) != 1 and x.stride(1) != 1 + ): # Either M or K dimension should be contiguous + x = x.contiguous() + if ( + z is not None and z.stride(-1) != 1 and z.stride(1) != 1 + ): # Either M or K dimension should be contiguous + z = z.contiguous() + if D is not None and D.stride(-1) != 1: + D = D.contiguous() + if initial_states is not None: + if cu_seqlens is None: + assert initial_states.shape == (batch, nheads, headdim, dstate) + else: + assert initial_states.shape == ( + len(cu_seqlens) - 1, + nheads, + headdim, + dstate, + ) + + # This function executes 5 sub-functions for computing mamba + # - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/ + 
# which has a minimal implementation to understand the below operations + # - as explained by the blog, mamba is a special case of causal attention + # - the idea is to chunk the attention matrix and compute each + # submatrix separately using different optimizations. + # - see the blog and paper for a visualization of the submatrices + # which we refer to in the comments below + + # 1. Compute chunked cumsum of A * dt + # - here dt may go through a softplus activation + dA_cumsum, dt = _chunk_cumsum_fwd( + dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit + ) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + # - for handling chunked prefill, this requires i) initial_states + # ii) seq_idx iii) is_cont_batched and (iv) chunk_offsets to be all specified. + # - When a new seq_idx is detected, we will stop passing the prev_state + # and switch accordingly to the init_state corresponding to the new seq_idx. + # - We will also make sure that the dA_cumsum is taken only from the start of the + # sequence (hence we need the full dA_cumsum tensor and not just the values at chunk boundaries) + # - this will ensure that states will be updated with the rightmost flushed seq_idx + # of the previous chunk. This implies that the first chunk of states is either 0 + # or equal to init_states of the first example. + states, final_states = _state_passing_fwd( + rearrange(states, "... p n -> ... (p n)"), + dA_cumsum, + initial_states=( + rearrange(initial_states, "... p n -> ... 
(p n)") + if initial_states is not None + else None + ), + seq_idx=seq_idx, + chunk_size=chunk_size, + out_dtype=state_dtype if state_dtype is not None else C.dtype, + is_cont_batched=cu_seqlens is not None, + chunk_offsets=chunk_offsets, + ) + states, final_states = ( + rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states] + ) + + # 4. Compute batched matrix multiply for C_j^T B_i terms + CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32) + + # 5. Scan and compute the diagonal blocks, taking into + # account past causal states. + # - if initial states are provided, then states information will be + # augmented with initial_states. + # - to do this properly, we need to account for example changes in + # the continuous batch, therefore we introduce pseudo chunks, which is + # a chunk that is split up each time an example changes. + # - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had + # a seq_idx change, in which case we take states information from + # init_states. 
+ out_x = _chunk_scan_fwd( + CB, + x, + dt, + dA_cumsum, + C, + states, + D=D, + z=z, + seq_idx=seq_idx, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + initial_states=initial_states, + out=out, + ) + if cu_seqlens is None: + return out_x, dt, dA_cumsum, states, final_states + else: + assert batch == 1, ( + "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1" + ) + varlen_states = chunk_state_varlen( + B.squeeze(0), + x.squeeze(0), + dt.squeeze(0), + dA_cumsum.squeeze(0), + cu_seqlens, + states.squeeze(0), + initial_states=initial_states, + ) + return out_x, dt, dA_cumsum, states, final_states, varlen_states + + +def mamba_chunk_scan_combined( + x, + dt, + A, + B, + C, + chunk_size, + D=None, + z=None, + dt_bias=None, + initial_states=None, + seq_idx=None, + chunk_indices=None, + chunk_offsets=None, + cu_seqlens=None, + dt_softplus=False, + dt_limit=(0.0, float("inf")), + out=None, + return_final_states=False, + return_varlen_states=False, + state_dtype=None, +): + """ + Argument: + x: (batch, seqlen, nheads, headdim) + dt: (batch, seqlen, nheads) + A: (nheads) + B: (batch, seqlen, ngroups, dstate) + C: (batch, seqlen, ngroups, dstate) + chunk_size: int + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, nheads, headdim) + dt_bias: (nheads,) + initial_states: (batch, nheads, headdim, dstate) + seq_idx: (batch, seqlen) + cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True + dt_softplus: Whether to apply softplus to dt + out: Preallocated output tensor + state_dtype: The data type of the ssm state + """ + + if not return_varlen_states: + cu_seqlens = None + else: + assert cu_seqlens is not None, ( + "cu_seqlens must be provided if return_varlen_states is True" + ) + out_x, dt_out, dA_cumsum, states, final_states, *rest = ( + _mamba_chunk_scan_combined_fwd( + x, + dt, + A, + B, + C, + chunk_size, + D=D, + z=z, + dt_bias=dt_bias, + initial_states=initial_states, + seq_idx=seq_idx, + 
chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + cu_seqlens=cu_seqlens, + dt_softplus=dt_softplus, + dt_limit=dt_limit, + out=out, + state_dtype=state_dtype, + ) + ) + if not return_varlen_states: + if not return_final_states: + return + else: + return final_states + else: + varlen_states = rest[0] + return ( + (varlen_states) + if not return_final_states + else (final_states, varlen_states) + ) diff --git a/tests/mamba/triton_reference/ssd_state_passing.py b/tests/mamba/triton_reference/ssd_state_passing.py new file mode 100644 index 0000000000..df0fe2b27f --- /dev/null +++ b/tests/mamba/triton_reference/ssd_state_passing.py @@ -0,0 +1,282 @@ +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_state_passing.py +# Copyright (c) 2024, Tri Dao, Albert Gu. +# +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 64}), + triton.Config({"BLOCK_SIZE": 128}), + triton.Config({"BLOCK_SIZE": 256}), + triton.Config({"BLOCK_SIZE": 512}), + triton.Config({"BLOCK_SIZE": 1024}), + triton.Config({"BLOCK_SIZE": 2048}), + ], + key=["dim"], +) +@triton.jit +def _state_passing_fwd_kernel( + # Pointers to matrices + states_ptr, + out_ptr, + final_states_ptr, + dA_cs_ptr, + initstates_ptr, + seq_idx_ptr, + chunk_offsets_ptr, + chunk_meta_num, + # Matrix dimensions + dim, + nchunks, + seqlen, + chunk_size, + # Strides + stride_states_batch, + stride_states_chunk, + stride_states_head, + stride_states_dim, + stride_out_batch, + stride_out_chunk, + stride_out_head, + stride_out_dim, + stride_final_states_batch, + stride_final_states_head, + stride_final_states_dim, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + stride_initstates_batch, + stride_initstates_head, + stride_initstates_dim, + stride_seq_idx_batch, + stride_seq_idx_seqlen, + # Meta-parameters + HAS_INITSTATES: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + IS_CONT_BATCHED: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid_b = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + pid_m = tl.program_id(axis=0) + states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head + dA_cs_ptr += ( + pid_b * stride_dA_cs_batch + + pid_h * stride_dA_cs_head + + (chunk_size - 1) * stride_dA_cs_csize + ) + out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + final_states_ptr += ( + pid_b * stride_final_states_batch + pid_h * stride_final_states_head + ) + if HAS_INITSTATES: + initstates_ptr += pid_h * stride_initstates_head + if not IS_CONT_BATCHED: + initstates_ptr += pid_b * stride_initstates_batch + + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + states_ptrs = states_ptr + offs_m 
* stride_states_dim + out_ptrs = out_ptr + offs_m * stride_out_dim + final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim + + # - states will be the past state of the sequence that continues on the current check + if not HAS_INITSTATES: + states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) + else: + initstates_ptr += offs_m * stride_initstates_dim + initstates_ptrs = initstates_ptr + # - for cont batches, for the first chunk mean it will be the first batch's + # init state + states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + tl.store(out_ptrs, states, mask=offs_m < dim) + out_ptrs += stride_out_chunk + prev_seq_idx_chunk_end = 0 + logical_chunk_idx = 0 + for c in range(nchunks): + new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + dA_cs = tl.load(dA_cs_ptr).to(tl.float32) + scale_mask = True + if HAS_SEQ_IDX: + # - the seq to pass forward is the one that is flushed to the right + # boundary. + # - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk. + seq_idx_chunk_end = tl.load( + seq_idx_ptr + + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen + ) + if HAS_INITSTATES: + if IS_CONT_BATCHED and prev_seq_idx_chunk_end != seq_idx_chunk_end: + # this means in the current chunk the rightmost flushed seq + # has changed. 
+ # - so we do not propagate the state from previous chunk + # - but rather we load that sequence's init state + initstates_ptrs = ( + initstates_ptr + seq_idx_chunk_end * stride_initstates_batch + ) + + # - update state with seq_idx_new's init state + states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to( + tl.float32 + ) + + # - we need to consider the cumsum only of the last sequence in the chunk + # - find its starting position (given by c_off of the logical chunk index) + # - and subtract the cumsum just before that position from the total cumsum + # - first, update the logical chunk index (add the number of sequences in the current physical chunk): + # sequence index at the start of the current chunk + seq_idx_chunk_start = tl.load( + seq_idx_ptr + + min(c * chunk_size, seqlen) * stride_seq_idx_seqlen + ) + logical_chunk_idx += seq_idx_chunk_end - seq_idx_chunk_start + # - load the chunk offset: + c_off = tl.load( + chunk_offsets_ptr + logical_chunk_idx, + mask=logical_chunk_idx < chunk_meta_num, + other=0, + ) + # - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything + if c_off > 0: + # - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset + dA_cs_boundary = tl.load( + dA_cs_ptr + - (chunk_size - 1) * stride_dA_cs_csize + + (c_off - 1) * stride_dA_cs_csize, + mask=(c_off - 1) > -1 and c_off < chunk_size, + other=0.0, + ) + dA_cs -= dA_cs_boundary + + # - increment logical chunk index for every physical chunk + logical_chunk_idx += 1 + else: + scale_mask = seq_idx_chunk_end == prev_seq_idx_chunk_end + prev_seq_idx_chunk_end = seq_idx_chunk_end + + scale = tl.where(scale_mask, tl.exp(dA_cs), 0.0) + states = scale * states + new_states + if c < nchunks - 1: + tl.store(out_ptrs, states, mask=offs_m < dim) + else: + tl.store(final_states_ptrs, states, mask=offs_m < dim) + states_ptrs += stride_states_chunk + dA_cs_ptr += 
stride_dA_cs_chunk + out_ptrs += stride_out_chunk + + +def _state_passing_fwd( + states, + dA_cumsum, + initial_states=None, + seq_idx=None, + chunk_size=None, + out_dtype=None, + is_cont_batched=False, + chunk_offsets=None, +): + batch, nchunks, nheads, dim = states.shape + if chunk_size is None: + chunk_size = dA_cumsum.shape[-1] + else: + assert chunk_size == dA_cumsum.shape[-1] + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + if initial_states is not None: + if is_cont_batched: + # - if cu_seqlens is provided, then the initial states + # are used for continuous batching. In which case we + # require seq_idx to be provided + assert seq_idx is not None, ( + "seq_idx must be provided for continuous batching" + ) + # - we also need chunk_offsets to be provided, to account + # for computation of dA_cumsum from the start of the + # sequence + assert chunk_offsets is not None, ( + "chunk_offsets must be provided for continuous batching" + ) + else: + # - this is the regular batching case, where initial + # states are used are for each example of the batch. 
+ assert initial_states.shape == (batch, nheads, dim) + + if seq_idx is not None: + seqlen = seq_idx.shape[-1] + assert seq_idx.shape == (batch, seqlen) + out_dtype = states.dtype if out_dtype is None else out_dtype + out = torch.empty( + (batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype + ) + final_states = torch.empty( + (batch, nheads, dim), device=states.device, dtype=torch.float32 + ) + grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE"]), batch, nheads) + with torch.cuda.device(states.device.index): + _state_passing_fwd_kernel[grid]( + states, + out, + final_states, + dA_cumsum, + initial_states, + seq_idx, + chunk_offsets, + len(chunk_offsets) if chunk_offsets is not None else 0, + dim, + nchunks, + seqlen if seq_idx is not None else 0, + chunk_size, + states.stride(0), + states.stride(1), + states.stride(2), + states.stride(3), + out.stride(0), + out.stride(1), + out.stride(2), + out.stride(3), + final_states.stride(0), + final_states.stride(1), + final_states.stride(2), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), + *( + ( + initial_states.stride(0), + initial_states.stride(1), + initial_states.stride(2), + ) + if initial_states is not None + else (0, 0, 0) + ), + *( + (seq_idx.stride(0), seq_idx.stride(1)) + if seq_idx is not None + else (0, 0) + ), + HAS_INITSTATES=initial_states is not None, + HAS_SEQ_IDX=seq_idx is not None, + IS_CONT_BATCHED=is_cont_batched, + ) + return out, final_states From 304fd59489f71ea6332831d91ed1fdb40c638db5 Mon Sep 17 00:00:00 2001 From: Igor Shovkun Date: Tue, 17 Feb 2026 14:07:19 -0800 Subject: [PATCH 15/33] Enable .jinja templates for mamba - Move dtype dispatch and instantiation to codegen via Jinja templates - Generate config and instantiation files per dtype combination - Update Python JIT logic to build/load kernels for specific dtypes - Remove C++ dtype dispatch helpers from selective_state_update.cu - Update kernel launcher comment for clarity on 
consumer warps --- csrc/selective_state_update.cu | 215 ++---------------- ...ective_state_update_customize_config.jinja | 10 + csrc/selective_state_update_dtype_inst.jinja | 32 +++ csrc/selective_state_update_kernel_inst.cu | 15 ++ flashinfer/aot.py | 102 ++++++--- .../jit/mamba/selective_state_update.py | 181 +++++++++++---- flashinfer/mamba/selective_state_update.py | 89 +++++--- .../kernel_selective_state_update_stp.cuh | 6 +- 8 files changed, 345 insertions(+), 305 deletions(-) create mode 100644 csrc/selective_state_update_customize_config.jinja create mode 100644 csrc/selective_state_update_dtype_inst.jinja create mode 100644 csrc/selective_state_update_kernel_inst.cu diff --git a/csrc/selective_state_update.cu b/csrc/selective_state_update.cu index d8ada31ed8..7990aa167c 100644 --- a/csrc/selective_state_update.cu +++ b/csrc/selective_state_update.cu @@ -14,8 +14,8 @@ * limitations under the License. */ #include -#include +#include "selective_state_update_config.inc" #include "tvm_ffi_utils.h" using namespace flashinfer; @@ -124,80 +124,6 @@ inline void validate_dtype_consistency( } } -// Helper to convert dtype code to string for error messages -inline const char* dtype_code_to_string(int64_t code) { - if (code == bfloat16_code) return "bfloat16"; - if (code == float16_code) return "float16"; - if (code == float32_code) return "float32"; - return "unknown"; -} - -// Type traits to map dtype codes to C++ types -template -struct DTypeToType; - -template <> -struct DTypeToType { - using type = nv_bfloat16; -}; -template <> -struct DTypeToType { - using type = half; -}; -template <> -struct DTypeToType { - using type = float; -}; -template <> -struct DTypeToType { - using type = int32_t; -}; -template <> -struct DTypeToType { - using type = int64_t; -}; - -// Allowed dtype combinations: {state_code, input_code, weight_code, matrixA_code, stateIndex_code} -constexpr std::tuple allowed_dtype_combos[] = { - {bfloat16_code, bfloat16_code, bfloat16_code, 
float32_code, int32_code}, - {float16_code, bfloat16_code, bfloat16_code, float32_code, int32_code}, - {float32_code, bfloat16_code, bfloat16_code, float32_code, int32_code}, - {bfloat16_code, bfloat16_code, float32_code, float32_code, int32_code}, - {float16_code, bfloat16_code, float32_code, float32_code, int32_code}, - {float32_code, bfloat16_code, float32_code, float32_code, int32_code}, - {bfloat16_code, bfloat16_code, bfloat16_code, float32_code, int64_code}, - {float16_code, bfloat16_code, bfloat16_code, float32_code, int64_code}, - {float32_code, bfloat16_code, bfloat16_code, float32_code, int64_code}, - {bfloat16_code, bfloat16_code, float32_code, float32_code, int64_code}, - {float16_code, bfloat16_code, float32_code, float32_code, int64_code}, - {float32_code, bfloat16_code, float32_code, float32_code, int64_code}, -}; - -// Helper to dispatch to the right template instantiation for STP -template -void dispatchCombo(SelectiveStateUpdateParams& p, cudaStream_t stream) { - using state_t = typename DTypeToType::type; - using input_t = typename DTypeToType::type; - using weight_t = typename DTypeToType::type; - using matrixA_t = typename DTypeToType::type; - using stateIndex_t = typename DTypeToType::type; - invokeSelectiveStateUpdate(p, stream); -} - -// Helper to dispatch to the right template instantiation for MTP -template -void dispatchComboMTP(mtp::SelectiveStateMTPParams& p, cudaStream_t stream) { - using state_t = typename DTypeToType::type; - using input_t = typename DTypeToType::type; - using weight_t = typename DTypeToType::type; - using matrixA_t = typename DTypeToType::type; - using stateIndex_t = typename DTypeToType::type; - mtp::invokeSelectiveStateUpdateMTP(p, - stream); -} - void run_selective_state_update_stp(TensorView const& state, TensorView const& x, TensorView const& dt, TensorView const& A, TensorView const& B, TensorView const& C, TensorView const& D, @@ -344,64 +270,7 @@ void run_selective_state_update_stp(TensorView const& state, 
TensorView const& x ffi::CUDADeviceGuard device_guard(state.device().device_id); const cudaStream_t stream = get_stream(state.device()); - // Dispatch based on dtype combination - DLDataType state_dtype = state.dtype(); - DLDataType input_dtype = x.dtype(); - DLDataType weight_dtype = dt.dtype(); - DLDataType matrixA_dtype = A.dtype(); - int64_t state_dtype_code = encode_dlpack_dtype(state_dtype); - int64_t input_dtype_code = encode_dlpack_dtype(input_dtype); - int64_t weight_dtype_code = encode_dlpack_dtype(weight_dtype); - int64_t matrixA_dtype_code = encode_dlpack_dtype(matrixA_dtype); - - // Get state_batch_indices dtype, default to int32 if not provided - int64_t stateIndex_dtype_code = int32_code; - if (state_batch_indices.has_value()) { - DLDataType stateIndex_dtype = state_batch_indices.value().dtype(); - stateIndex_dtype_code = encode_dlpack_dtype(stateIndex_dtype); - } - - // Dispatch kernel based on dtype combination - auto dtype_key = std::make_tuple(state_dtype_code, input_dtype_code, weight_dtype_code, - matrixA_dtype_code, stateIndex_dtype_code); - - // Compile-time recursive dispatcher using Y-combinator pattern for lambda self-recursion - auto tryDispatch = [&](const auto& key, auto idx, auto& self) -> bool { - constexpr size_t I = decltype(idx)::value; - if constexpr (I < std::size(allowed_dtype_combos)) { - constexpr auto combo = allowed_dtype_combos[I]; - if (key == combo) { - constexpr auto s = std::get<0>(combo); - constexpr auto i = std::get<1>(combo); - constexpr auto w = std::get<2>(combo); - constexpr auto m = std::get<3>(combo); - constexpr auto si = std::get<4>(combo); - dispatchCombo(p, stream); - return true; - } - return self(key, std::integral_constant{}, self); - } - return false; - }; - - // Dispatch using compile-time type traits - if (!tryDispatch(dtype_key, std::integral_constant{}, tryDispatch)) { - // Unsupported dtype combination - build error message dynamically - std::ostringstream error_msg; - error_msg << "Unsupported 
dtype combination for selective_state_update: " << "state_dtype=" - << state_dtype.code << ":" << state_dtype.bits << ", " - << "input_dtype=" << input_dtype.code << ":" << input_dtype.bits << ", " - << "weight_dtype=" << weight_dtype.code << ":" << weight_dtype.bits << ", " - << "matrixA_dtype=" << matrixA_dtype.code << ":" << matrixA_dtype.bits - << ". Supported combos include:\n"; - for (const auto& combo : allowed_dtype_combos) { - error_msg << " (state=" << dtype_code_to_string(std::get<0>(combo)) - << ", input=" << dtype_code_to_string(std::get<1>(combo)) - << ", weight=" << dtype_code_to_string(std::get<2>(combo)) - << ", matrixA=" << dtype_code_to_string(std::get<3>(combo)) << ")\n"; - } - TVM_FFI_ICHECK(false) << error_msg.str(); - } + invokeSelectiveStateUpdate(p, stream); } void run_selective_state_update_mtp( @@ -505,6 +374,15 @@ void run_selective_state_update_mtp( validate_intermediate_state_indices(intermediate_state_indices, batch); validate_intermediate_states_buffer(intermediate_states_buffer); + // Validate that state_batch_indices and intermediate_state_indices have the same dtype + if (state_batch_indices.has_value() && intermediate_state_indices.has_value()) { + DLDataType state_batch_idx_dtype = state_batch_indices.value().dtype(); + DLDataType intermediate_idx_dtype = intermediate_state_indices.value().dtype(); + FLASHINFER_CHECK(state_batch_idx_dtype.code == intermediate_idx_dtype.code && + state_batch_idx_dtype.bits == intermediate_idx_dtype.bits, + "state_batch_indices and intermediate_state_indices must have the same dtype"); + } + // Validate cache_steps is non-negative FLASHINFER_CHECK(cache_steps >= 0, "cache_steps must be non-negative, got ", cache_steps); @@ -588,75 +466,8 @@ void run_selective_state_update_mtp( ffi::CUDADeviceGuard device_guard(state.device().device_id); const cudaStream_t stream = get_stream(state.device()); - // Dispatch based on dtype combination - DLDataType state_dtype = state.dtype(); - DLDataType input_dtype 
= x.dtype(); - DLDataType weight_dtype = dt.dtype(); - DLDataType matrixA_dtype = A.dtype(); - int64_t state_dtype_code = encode_dlpack_dtype(state_dtype); - int64_t input_dtype_code = encode_dlpack_dtype(input_dtype); - int64_t weight_dtype_code = encode_dlpack_dtype(weight_dtype); - int64_t matrixA_dtype_code = encode_dlpack_dtype(matrixA_dtype); - - // Get stateIndex dtype from whichever index tensor is available - // If both are provided, they must have the same dtype - int64_t stateIndex_dtype_code = int32_code; // default - if (state_batch_indices.has_value() && intermediate_state_indices.has_value()) { - DLDataType state_batch_idx_dtype = state_batch_indices.value().dtype(); - DLDataType intermediate_idx_dtype = intermediate_state_indices.value().dtype(); - FLASHINFER_CHECK(state_batch_idx_dtype.code == intermediate_idx_dtype.code && - state_batch_idx_dtype.bits == intermediate_idx_dtype.bits, - "state_batch_indices and intermediate_state_indices must have the same dtype"); - stateIndex_dtype_code = encode_dlpack_dtype(state_batch_idx_dtype); - } else if (state_batch_indices.has_value()) { - DLDataType state_batch_idx_dtype = state_batch_indices.value().dtype(); - stateIndex_dtype_code = encode_dlpack_dtype(state_batch_idx_dtype); - } else if (intermediate_state_indices.has_value()) { - DLDataType intermediate_idx_dtype = intermediate_state_indices.value().dtype(); - stateIndex_dtype_code = encode_dlpack_dtype(intermediate_idx_dtype); - } - - // Dispatch kernel based on dtype combination - auto dtype_key = std::make_tuple(state_dtype_code, input_dtype_code, weight_dtype_code, - matrixA_dtype_code, stateIndex_dtype_code); - - // Compile-time recursive dispatcher using Y-combinator pattern for lambda self-recursion - auto tryDispatch = [&](const auto& key, auto idx, auto& self) -> bool { - constexpr size_t I = decltype(idx)::value; - if constexpr (I < std::size(allowed_dtype_combos)) { - constexpr auto combo = allowed_dtype_combos[I]; - if (key == combo) { - 
constexpr auto s = std::get<0>(combo); - constexpr auto i = std::get<1>(combo); - constexpr auto w = std::get<2>(combo); - constexpr auto m = std::get<3>(combo); - constexpr auto si = std::get<4>(combo); - dispatchComboMTP(p, stream); - return true; - } - return self(key, std::integral_constant{}, self); - } - return false; - }; - - // Dispatch using compile-time type traits - if (!tryDispatch(dtype_key, std::integral_constant{}, tryDispatch)) { - // Unsupported dtype combination - build error message dynamically - std::ostringstream error_msg; - error_msg << "Unsupported dtype combination for selective_state_update: " << "state_dtype=" - << state_dtype.code << ":" << state_dtype.bits << ", " - << "input_dtype=" << input_dtype.code << ":" << input_dtype.bits << ", " - << "weight_dtype=" << weight_dtype.code << ":" << weight_dtype.bits << ", " - << "matrixA_dtype=" << matrixA_dtype.code << ":" << matrixA_dtype.bits - << ". Supported combos include:\n"; - for (const auto& combo : allowed_dtype_combos) { - error_msg << " (state=" << dtype_code_to_string(std::get<0>(combo)) - << ", input=" << dtype_code_to_string(std::get<1>(combo)) - << ", weight=" << dtype_code_to_string(std::get<2>(combo)) - << ", matrixA=" << dtype_code_to_string(std::get<3>(combo)) << ")\n"; - } - TVM_FFI_ICHECK(false) << error_msg.str(); - } + mtp::invokeSelectiveStateUpdateMTP(p, + stream); } // ============================================================================= diff --git a/csrc/selective_state_update_customize_config.jinja b/csrc/selective_state_update_customize_config.jinja new file mode 100644 index 0000000000..a2f70798c7 --- /dev/null +++ b/csrc/selective_state_update_customize_config.jinja @@ -0,0 +1,10 @@ +#pragma once +#include +#include +#include + +using state_t = {{ state_dtype }}; +using input_t = {{ input_dtype }}; +using weight_t = {{ weight_dtype }}; +using matrixA_t = {{ matrixA_dtype }}; +using stateIndex_t = {{ stateIndex_dtype }}; diff --git 
a/csrc/selective_state_update_dtype_inst.jinja b/csrc/selective_state_update_dtype_inst.jinja new file mode 100644 index 0000000000..4879afaed2 --- /dev/null +++ b/csrc/selective_state_update_dtype_inst.jinja @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Auto-generated file - do not edit directly. +// Generated by flashinfer/jit/mamba/selective_state_update.py + +#include + +namespace flashinfer::mamba { + +template void invokeSelectiveStateUpdate<{{ input_dtype }}, {{ weight_dtype }}, {{ matrixA_dtype }}, {{ state_dtype }}, {{ stateIndex_dtype }}>( + SelectiveStateUpdateParams& params, cudaStream_t stream); + +namespace mtp { +template void invokeSelectiveStateUpdateMTP<{{ input_dtype }}, {{ weight_dtype }}, {{ matrixA_dtype }}, {{ state_dtype }}, {{ stateIndex_dtype }}>( + SelectiveStateMTPParams& params, cudaStream_t stream); +} // namespace mtp + +} // namespace flashinfer::mamba diff --git a/csrc/selective_state_update_kernel_inst.cu b/csrc/selective_state_update_kernel_inst.cu new file mode 100644 index 0000000000..e84e152c9b --- /dev/null +++ b/csrc/selective_state_update_kernel_inst.cu @@ -0,0 +1,15 @@ +#include + +#include "selective_state_update_config.inc" + +namespace flashinfer::mamba { + +template void invokeSelectiveStateUpdate( + SelectiveStateUpdateParams&, cudaStream_t); + +namespace mtp { +template void invokeSelectiveStateUpdateMTP( + SelectiveStateMTPParams&, 
cudaStream_t); +} // namespace mtp + +} // namespace flashinfer::mamba diff --git a/flashinfer/aot.py b/flashinfer/aot.py index c0289fd3be..78a0397ec4 100644 --- a/flashinfer/aot.py +++ b/flashinfer/aot.py @@ -25,13 +25,28 @@ import shutil from itertools import product from pathlib import Path -from typing import List, Tuple, Iterator, Optional +from typing import Iterator, List, Optional, Tuple import torch - from packaging.version import Version + +from .compilation_context import CompilationContext +from .jit import JitSpec, build_jit_specs +from .jit import env as jit_env from .jit.activation import act_func_def_str, gen_act_and_mul_module +from .jit.attention import ( + gen_batch_attention_module, + gen_batch_decode_module, + gen_batch_mla_module, + gen_batch_prefill_module, + gen_cudnn_fmha_module, + gen_fmha_cutlass_sm100a_module, + gen_single_decode_module, + gen_single_prefill_module, + gen_trtllm_gen_fmha_module, +) from .jit.cascade import gen_cascade_module +from .jit.cpp_ext import get_cuda_version from .jit.fp4_quantization import ( gen_fp4_quantization_sm90_module, gen_fp4_quantization_sm100_module, @@ -41,58 +56,43 @@ gen_fp4_quantization_sm121_module, ) from .jit.fp8_quantization import gen_mxfp8_quantization_sm100_module -from .jit.gdn import gen_gdn_prefill_sm90_module from .jit.fused_moe import ( - gen_cutlass_fused_moe_sm120_module, - gen_cutlass_fused_moe_sm103_module, - gen_cutlass_fused_moe_sm100_module, gen_cutlass_fused_moe_sm90_module, + gen_cutlass_fused_moe_sm100_module, + gen_cutlass_fused_moe_sm103_module, + gen_cutlass_fused_moe_sm120_module, gen_trtllm_gen_fused_moe_sm100_module, ) +from .jit.gdn import gen_gdn_prefill_sm90_module from .jit.gemm import ( + gen_fp8_blockscale_gemm_sm90_module, gen_gemm_module, gen_gemm_sm90_module, - gen_fp8_blockscale_gemm_sm90_module, gen_gemm_sm100_module, gen_gemm_sm100_module_cutlass_fp4, gen_gemm_sm100_module_cutlass_fp8, gen_gemm_sm100_module_cutlass_mxfp8, - gen_tgv_gemm_sm10x_module, 
gen_gemm_sm120_module, gen_gemm_sm120_module_cutlass_fp4, + gen_tgv_gemm_sm10x_module, gen_trtllm_gen_gemm_module, gen_trtllm_low_latency_gemm_module, ) -from .jit.spdlog import gen_spdlog_module -from .jit.mla import gen_mla_module from .jit.mamba import ( gen_selective_state_update_module, gen_selective_state_update_sm90_module, gen_selective_state_update_sm100_module, ) +from .jit.mla import gen_mla_module from .jit.norm import gen_norm_module from .jit.page import gen_page_module from .jit.quantization import gen_quantization_module from .jit.rope import gen_rope_module from .jit.sampling import gen_sampling_module -from .jit.topk import gen_topk_module +from .jit.spdlog import gen_spdlog_module from .jit.tllm_utils import gen_trtllm_utils_module +from .jit.topk import gen_topk_module from .jit.xqa import gen_xqa_module, gen_xqa_module_mla -from .jit.attention import ( - gen_batch_attention_module, - gen_batch_decode_module, - gen_batch_mla_module, - gen_batch_prefill_module, - gen_cudnn_fmha_module, - gen_fmha_cutlass_sm100a_module, - gen_single_decode_module, - gen_single_prefill_module, - gen_trtllm_gen_fmha_module, -) -from .jit import JitSpec, build_jit_specs -from .jit import env as jit_env -from .jit.cpp_ext import get_cuda_version -from .compilation_context import CompilationContext def gen_fa2( @@ -520,11 +520,14 @@ def gen_all_modules( jit_specs.append(gen_fp4_quantization_sm121_module()) if add_comm: - from .jit.comm import gen_trtllm_comm_module, gen_vllm_comm_module - from .jit.comm import gen_nvshmem_module - from .jit.comm import gen_comm_alltoall_module - from .jit.comm import gen_trtllm_mnnvl_comm_module - from .jit.comm import gen_moe_alltoall_module + from .jit.comm import ( + gen_comm_alltoall_module, + gen_moe_alltoall_module, + gen_nvshmem_module, + gen_trtllm_comm_module, + gen_trtllm_mnnvl_comm_module, + gen_vllm_comm_module, + ) jit_specs.append(gen_nvshmem_module()) jit_specs.append(gen_comm_alltoall_module()) @@ -543,14 +546,45 @@ def 
gen_all_modules( gen_rope_module(), gen_sampling_module(), gen_topk_module(), - gen_selective_state_update_module(), ] + # selective_state_update: one module per dtype combo per GPU arch + _ssu_dtype_combos = [ + # (state, input, weight, matrixA, stateIndex) + ( + torch.bfloat16, + torch.bfloat16, + torch.bfloat16, + torch.float32, + torch.int32, + ), + (torch.float16, torch.bfloat16, torch.bfloat16, torch.float32, torch.int32), + (torch.float32, torch.bfloat16, torch.bfloat16, torch.float32, torch.int32), + (torch.bfloat16, torch.bfloat16, torch.float32, torch.float32, torch.int32), + (torch.float16, torch.bfloat16, torch.float32, torch.float32, torch.int32), + (torch.float32, torch.bfloat16, torch.float32, torch.float32, torch.int32), + ( + torch.bfloat16, + torch.bfloat16, + torch.bfloat16, + torch.float32, + torch.int64, + ), + (torch.float16, torch.bfloat16, torch.bfloat16, torch.float32, torch.int64), + (torch.float32, torch.bfloat16, torch.bfloat16, torch.float32, torch.int64), + (torch.bfloat16, torch.bfloat16, torch.float32, torch.float32, torch.int64), + (torch.float16, torch.bfloat16, torch.float32, torch.float32, torch.int64), + (torch.float32, torch.bfloat16, torch.float32, torch.float32, torch.int64), + ] + for combo in _ssu_dtype_combos: + jit_specs.append(gen_selective_state_update_module(*combo)) if has_sm90: - jit_specs.append(gen_selective_state_update_sm90_module()) + for combo in _ssu_dtype_combos: + jit_specs.append(gen_selective_state_update_sm90_module(*combo)) jit_specs.append(gen_trtllm_utils_module()) jit_specs.append(gen_gdn_prefill_sm90_module()) if has_sm100: - jit_specs.append(gen_selective_state_update_sm100_module()) + for combo in _ssu_dtype_combos: + jit_specs.append(gen_selective_state_update_sm100_module(*combo)) if ( add_xqa and get_cuda_version() > Version("12.8") diff --git a/flashinfer/jit/mamba/selective_state_update.py b/flashinfer/jit/mamba/selective_state_update.py index a9b18580e2..16f10ed57c 100644 --- 
a/flashinfer/jit/mamba/selective_state_update.py +++ b/flashinfer/jit/mamba/selective_state_update.py @@ -14,65 +14,164 @@ limitations under the License. """ +import os + +import jinja2 +import torch + from ...compilation_context import CompilationContext from .. import env as jit_env from ..core import JitSpec, gen_jit_spec +from ..utils import write_if_different + +# Map torch dtypes to C++ type names +_dtype_map = { + torch.float16: "half", + torch.bfloat16: "nv_bfloat16", + torch.float32: "float", + torch.int32: "int32_t", + torch.int64: "int64_t", +} + +# Map torch dtypes to filename-safe strings +_filename_safe_dtype_map = { + torch.float16: "f16", + torch.bfloat16: "bf16", + torch.float32: "f32", + torch.int32: "i32", + torch.int64: "i64", +} + + +def get_selective_state_update_uri( + state_dtype: torch.dtype, + input_dtype: torch.dtype, + weight_dtype: torch.dtype, + matrixA_dtype: torch.dtype, + stateIndex_dtype: torch.dtype, +) -> str: + s = _filename_safe_dtype_map + return ( + f"selective_state_update_" + f"s_{s[state_dtype]}_i_{s[input_dtype]}_w_{s[weight_dtype]}_" + f"a_{s[matrixA_dtype]}_si_{s[stateIndex_dtype]}" + ) + + +def _gen_module( + uri: str, + state_dtype: torch.dtype, + input_dtype: torch.dtype, + weight_dtype: torch.dtype, + matrixA_dtype: torch.dtype, + stateIndex_dtype: torch.dtype, + extra_cuda_cflags: list = None, +) -> JitSpec: + gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri + os.makedirs(gen_directory, exist_ok=True) + + # Render the config .inc + with open( + jit_env.FLASHINFER_CSRC_DIR / "selective_state_update_customize_config.jinja" + ) as f: + config_templ = jinja2.Template(f.read()) + + config_str = config_templ.render( + state_dtype=_dtype_map[state_dtype], + input_dtype=_dtype_map[input_dtype], + weight_dtype=_dtype_map[weight_dtype], + matrixA_dtype=_dtype_map[matrixA_dtype], + stateIndex_dtype=_dtype_map[stateIndex_dtype], + ) + write_if_different(gen_directory / "selective_state_update_config.inc", config_str) + # 
Copy source files to gen directory (so they can #include the config.inc) + source_paths = [] + for filename in [ + "selective_state_update.cu", + "selective_state_update_kernel_inst.cu", + "flashinfer_mamba_binding.cu", + ]: + src_path = jit_env.FLASHINFER_CSRC_DIR / filename + dest_path = gen_directory / filename + source_paths.append(dest_path) + with open(src_path, "r") as f: + source = f.read() + write_if_different(dest_path, source) -def gen_selective_state_update_module() -> JitSpec: return gen_jit_spec( - "mamba_selective_state_update", - [ - jit_env.FLASHINFER_CSRC_DIR / "selective_state_update.cu", - jit_env.FLASHINFER_CSRC_DIR / "flashinfer_mamba_binding.cu", - ], + uri, + source_paths, + extra_cuda_cflags=extra_cuda_cflags or [], + ) + + +def gen_selective_state_update_module( + state_dtype: torch.dtype, + input_dtype: torch.dtype, + weight_dtype: torch.dtype, + matrixA_dtype: torch.dtype, + stateIndex_dtype: torch.dtype, +) -> JitSpec: + uri = get_selective_state_update_uri( + state_dtype, input_dtype, weight_dtype, matrixA_dtype, stateIndex_dtype + ) + return _gen_module( + uri, state_dtype, input_dtype, weight_dtype, matrixA_dtype, stateIndex_dtype ) -def gen_selective_state_update_sm90_module() -> JitSpec: - # We use a specialized module for Hopper GPUs due to the explicit use - # of TMA device functions (vertical producer-consumer kernel). - # This supports SM90 (Hopper) only. - # - # Technically, all the kernels in this module can be executed on newer GPUs than Hopper, - # but this kernel ends up being slower than the alternative SM100 module. - # Therefore, this is excluded to reduce the amount of compilation. 
+def gen_selective_state_update_sm90_module( + state_dtype: torch.dtype, + input_dtype: torch.dtype, + weight_dtype: torch.dtype, + matrixA_dtype: torch.dtype, + stateIndex_dtype: torch.dtype, +) -> JitSpec: + uri = ( + get_selective_state_update_uri( + state_dtype, input_dtype, weight_dtype, matrixA_dtype, stateIndex_dtype + ) + + "_sm90" + ) compilation_context = CompilationContext() nvcc_flags = compilation_context.get_nvcc_flags_list(supported_major_versions=[9]) - nvcc_flags += [ - "-DFLASHINFER_MAMBA_ENABLE_SM90", - ] - - return gen_jit_spec( - "mamba_selective_state_update_sm90", - [ - jit_env.FLASHINFER_CSRC_DIR / "selective_state_update.cu", - jit_env.FLASHINFER_CSRC_DIR / "flashinfer_mamba_binding.cu", - ], + nvcc_flags += ["-DFLASHINFER_MAMBA_ENABLE_SM90"] + return _gen_module( + uri, + state_dtype, + input_dtype, + weight_dtype, + matrixA_dtype, + stateIndex_dtype, extra_cuda_cflags=nvcc_flags, ) -def gen_selective_state_update_sm100_module() -> JitSpec: - # We use a specialized module for Blackwell+ GPUs with horizontal - # producer-consumer kernel optimized for SM100 and newer architectures. - # This supports SM100 (Blackwell) and future architectures. - # Technically, the code in this module can compile on sm90 as well, but - # this kernel is a lot slower on hopper than those in the mamba_selective_state_update and - # mamba_selective_state_update_sm90 modules. 
+def gen_selective_state_update_sm100_module( + state_dtype: torch.dtype, + input_dtype: torch.dtype, + weight_dtype: torch.dtype, + matrixA_dtype: torch.dtype, + stateIndex_dtype: torch.dtype, +) -> JitSpec: + uri = ( + get_selective_state_update_uri( + state_dtype, input_dtype, weight_dtype, matrixA_dtype, stateIndex_dtype + ) + + "_sm100" + ) compilation_context = CompilationContext() nvcc_flags = compilation_context.get_nvcc_flags_list( supported_major_versions=[10, 11, 12] ) - nvcc_flags += [ - "-DFLASHINFER_MAMBA_ENABLE_SM100", - ] - - return gen_jit_spec( - "mamba_selective_state_update_sm100", - [ - jit_env.FLASHINFER_CSRC_DIR / "selective_state_update.cu", - jit_env.FLASHINFER_CSRC_DIR / "flashinfer_mamba_binding.cu", - ], + nvcc_flags += ["-DFLASHINFER_MAMBA_ENABLE_SM100"] + return _gen_module( + uri, + state_dtype, + input_dtype, + weight_dtype, + matrixA_dtype, + stateIndex_dtype, extra_cuda_cflags=nvcc_flags, ) diff --git a/flashinfer/mamba/selective_state_update.py b/flashinfer/mamba/selective_state_update.py index 734b0f5c10..5cbd205da5 100644 --- a/flashinfer/mamba/selective_state_update.py +++ b/flashinfer/mamba/selective_state_update.py @@ -29,34 +29,41 @@ @functools.cache -def get_selective_state_update_module_base(): - """Get cached JIT-compiled selective_state_update module (base version).""" - return gen_selective_state_update_module().build_and_load() - - -@functools.cache -def get_selective_state_update_module_sm90(): - """Get cached JIT-compiled selective_state_update module (SM90/Hopper version).""" - return gen_selective_state_update_sm90_module().build_and_load() - - -@functools.cache -def get_selective_state_update_module_sm100(): - """Get cached JIT-compiled selective_state_update module (SM100+/Blackwell version).""" - return gen_selective_state_update_sm100_module().build_and_load() +def _get_module( + state_dtype: torch.dtype, + input_dtype: torch.dtype, + weight_dtype: torch.dtype, + matrixA_dtype: torch.dtype, + stateIndex_dtype: 
torch.dtype, + sm_major: int, +): + dtype_args = ( + state_dtype, + input_dtype, + weight_dtype, + matrixA_dtype, + stateIndex_dtype, + ) + if sm_major >= 10: + return gen_selective_state_update_sm100_module(*dtype_args).build_and_load() + elif sm_major == 9: + return gen_selective_state_update_sm90_module(*dtype_args).build_and_load() + else: + return gen_selective_state_update_module(*dtype_args).build_and_load() -def get_selective_state_update_module(device: torch.device): +def get_selective_state_update_module( + device: torch.device, + state_dtype: torch.dtype, + input_dtype: torch.dtype, + weight_dtype: torch.dtype, + matrixA_dtype: torch.dtype, + stateIndex_dtype: torch.dtype, +): major, _ = get_compute_capability(device) - if major >= 10: - # SM100+ (Blackwell and newer) uses horizontal producer-consumer kernel - return get_selective_state_update_module_sm100() - elif major == 9: - # SM90 (Hopper) uses vertical producer-consumer kernel - return get_selective_state_update_module_sm90() - else: - # Pre-Hopper uses simple kernel - return get_selective_state_update_module_base() + return _get_module( + state_dtype, input_dtype, weight_dtype, matrixA_dtype, stateIndex_dtype, major + ) @flashinfer_api @@ -178,6 +185,14 @@ def selective_state_update( output = torch.empty_like(x) else: output = out + + # Determine stateIndex dtype from index tensors, default to int32 + stateIndex_dtype = torch.int32 + if state_batch_indices is not None: + stateIndex_dtype = state_batch_indices.dtype + elif intermediate_state_indices is not None: + stateIndex_dtype = intermediate_state_indices.dtype + _selective_state_update( state, x, @@ -196,6 +211,11 @@ def selective_state_update( intermediate_states_buffer, intermediate_state_indices, cache_steps, + state.dtype, + x.dtype, + dt.dtype, + A.dtype, + stateIndex_dtype, ) return output @@ -222,9 +242,21 @@ def _selective_state_update( intermediate_states_buffer: Optional[torch.Tensor], intermediate_state_indices: 
Optional[torch.Tensor], cache_steps: int, + state_dtype: torch.dtype, + input_dtype: torch.dtype, + weight_dtype: torch.dtype, + matrixA_dtype: torch.dtype, + stateIndex_dtype: torch.dtype, ) -> None: """Internal function registered with torch.library for torch.compile() support.""" - get_selective_state_update_module(state.device).selective_state_update( + get_selective_state_update_module( + state.device, + state_dtype, + input_dtype, + weight_dtype, + matrixA_dtype, + stateIndex_dtype, + ).selective_state_update( state, x, dt, @@ -264,6 +296,11 @@ def _selective_state_update_fake( intermediate_states_buffer: Optional[torch.Tensor], intermediate_state_indices: Optional[torch.Tensor], cache_steps: int, + state_dtype: torch.dtype, + input_dtype: torch.dtype, + weight_dtype: torch.dtype, + matrixA_dtype: torch.dtype, + stateIndex_dtype: torch.dtype, ) -> None: """Fake implementation for torch.compile() meta tensor propagation.""" pass diff --git a/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh b/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh index 254f2f71ea..81304f3116 100644 --- a/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh +++ b/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh @@ -961,13 +961,15 @@ void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, cudaStream_t else { // SM100+ (Blackwell and newer) uses horizontal producer-consumer kernel auto kernel_launcher = [&]() { - // profiling showed that it's good to have 4 producers per 64 rows + // profiling showed that it's good to have 4 consumer warps per 64 rows constexpr auto numConsumers = (DIM / 64) * 4; constexpr auto numProducers = 1; constexpr auto numWarps = numProducers + numConsumers; constexpr auto sectorSize = 32; // bytes - constexpr auto stageCols = 2 * sectorSize / sizeof(state_t); + constexpr auto stageCols = + 2 * sectorSize / + sizeof(state_t); // bf16 has 16 columns per stage, fp32 has 8 columns per stage constexpr 
auto totalStages = DSTATE / stageCols; constexpr auto numStages = (totalStages >= 4) ? 4 : totalStages; From 329bfd0dbb9e72fb4e40c994d4edbfc048fb2450 Mon Sep 17 00:00:00 2001 From: Igor Shovkun Date: Tue, 17 Feb 2026 14:30:10 -0800 Subject: [PATCH 16/33] Remove SM100 module, unify SM90+ selective state update handling --- flashinfer/aot.py | 6 +--- flashinfer/jit/mamba/__init__.py | 2 -- .../jit/mamba/selective_state_update.py | 31 ++----------------- flashinfer/mamba/selective_state_update.py | 5 +-- .../kernel_selective_state_update_stp.cuh | 29 +++++------------ 5 files changed, 11 insertions(+), 62 deletions(-) diff --git a/flashinfer/aot.py b/flashinfer/aot.py index 78a0397ec4..5df208ae1c 100644 --- a/flashinfer/aot.py +++ b/flashinfer/aot.py @@ -81,7 +81,6 @@ from .jit.mamba import ( gen_selective_state_update_module, gen_selective_state_update_sm90_module, - gen_selective_state_update_sm100_module, ) from .jit.mla import gen_mla_module from .jit.norm import gen_norm_module @@ -577,14 +576,11 @@ def gen_all_modules( ] for combo in _ssu_dtype_combos: jit_specs.append(gen_selective_state_update_module(*combo)) - if has_sm90: + if has_sm90 or has_sm100: for combo in _ssu_dtype_combos: jit_specs.append(gen_selective_state_update_sm90_module(*combo)) jit_specs.append(gen_trtllm_utils_module()) jit_specs.append(gen_gdn_prefill_sm90_module()) - if has_sm100: - for combo in _ssu_dtype_combos: - jit_specs.append(gen_selective_state_update_sm100_module(*combo)) if ( add_xqa and get_cuda_version() > Version("12.8") diff --git a/flashinfer/jit/mamba/__init__.py b/flashinfer/jit/mamba/__init__.py index 8ac01c2455..f6a2628b43 100644 --- a/flashinfer/jit/mamba/__init__.py +++ b/flashinfer/jit/mamba/__init__.py @@ -17,11 +17,9 @@ from .selective_state_update import ( gen_selective_state_update_module, gen_selective_state_update_sm90_module, - gen_selective_state_update_sm100_module, ) __all__ = [ "gen_selective_state_update_module", "gen_selective_state_update_sm90_module", 
- "gen_selective_state_update_sm100_module", ] diff --git a/flashinfer/jit/mamba/selective_state_update.py b/flashinfer/jit/mamba/selective_state_update.py index 16f10ed57c..9eb69d9936 100644 --- a/flashinfer/jit/mamba/selective_state_update.py +++ b/flashinfer/jit/mamba/selective_state_update.py @@ -135,37 +135,10 @@ def gen_selective_state_update_sm90_module( + "_sm90" ) compilation_context = CompilationContext() - nvcc_flags = compilation_context.get_nvcc_flags_list(supported_major_versions=[9]) - nvcc_flags += ["-DFLASHINFER_MAMBA_ENABLE_SM90"] - return _gen_module( - uri, - state_dtype, - input_dtype, - weight_dtype, - matrixA_dtype, - stateIndex_dtype, - extra_cuda_cflags=nvcc_flags, - ) - - -def gen_selective_state_update_sm100_module( - state_dtype: torch.dtype, - input_dtype: torch.dtype, - weight_dtype: torch.dtype, - matrixA_dtype: torch.dtype, - stateIndex_dtype: torch.dtype, -) -> JitSpec: - uri = ( - get_selective_state_update_uri( - state_dtype, input_dtype, weight_dtype, matrixA_dtype, stateIndex_dtype - ) - + "_sm100" - ) - compilation_context = CompilationContext() nvcc_flags = compilation_context.get_nvcc_flags_list( - supported_major_versions=[10, 11, 12] + supported_major_versions=[9, 10, 11, 12] ) - nvcc_flags += ["-DFLASHINFER_MAMBA_ENABLE_SM100"] + nvcc_flags += ["-DFLASHINFER_MAMBA_ENABLE_SM90"] return _gen_module( uri, state_dtype, diff --git a/flashinfer/mamba/selective_state_update.py b/flashinfer/mamba/selective_state_update.py index 5cbd205da5..7635e0f408 100644 --- a/flashinfer/mamba/selective_state_update.py +++ b/flashinfer/mamba/selective_state_update.py @@ -23,7 +23,6 @@ from ..jit.mamba import ( gen_selective_state_update_module, gen_selective_state_update_sm90_module, - gen_selective_state_update_sm100_module, ) from ..utils import get_compute_capability, register_custom_op, register_fake_op @@ -44,9 +43,7 @@ def _get_module( matrixA_dtype, stateIndex_dtype, ) - if sm_major >= 10: - return 
gen_selective_state_update_sm100_module(*dtype_args).build_and_load() - elif sm_major == 9: + if sm_major >= 9: return gen_selective_state_update_sm90_module(*dtype_args).build_and_load() else: return gen_selective_state_update_module(*dtype_args).build_and_load() diff --git a/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh b/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh index 81304f3116..46ff69997b 100644 --- a/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh +++ b/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh @@ -535,11 +535,7 @@ __global__ void selective_state_update_kernel_producer_consumer_vertical( #endif } -// ============================================================================= -// Horizontal Producer-Consumer Kernel for SM100+ (Blackwell and newer) -// ============================================================================= - -#ifdef FLASHINFER_MAMBA_ENABLE_SM100 +#ifdef FLASHINFER_MAMBA_ENABLE_SM90 template @@ -891,9 +887,7 @@ void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, cudaStream_t // Common alignment checks for all kernels check_ptr_alignment_input_vars(params); -#ifdef FLASHINFER_MAMBA_ENABLE_SM100 - if (sm_major < 10) // pre-Blackwell -#elif defined(FLASHINFER_MAMBA_ENABLE_SM90) +#ifdef FLASHINFER_MAMBA_ENABLE_SM90 if (sm_major < 9) // pre-Hopper #endif { @@ -917,10 +911,9 @@ void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, cudaStream_t dispatchDimDstate(params, AllowedDims{}, AllowedDstates{}, kernel_launcher); } #ifdef FLASHINFER_MAMBA_ENABLE_SM90 - else { - + else if (sm_major < 10) { + // SM90 (Hopper) uses vertical producer-consumer kernel auto kernel_launcher = [&]() { - // Note: State uses TMA which requires 128B alignment (checked below) constexpr auto numConsumers = 4; constexpr auto numWarps = 1 + numConsumers; constexpr auto numStages = 3; @@ -943,7 +936,6 @@ void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& 
params, cudaStream_t /*strides*/ {1, DSTATE, DSTATE * DIM, params.state_stride_batch}, /*tiles*/ {DSTATE, rowsPerStage, 1, 1}, params.state); - // Calculate shared memory size and opt-in to extended shared memory using sram_t = SharedStorageVertical; constexpr size_t smem_size = sizeof(sram_t); @@ -954,22 +946,15 @@ void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, cudaStream_t }; dispatchDimDstate(params, AllowedDims{}, AllowedDstates{}, kernel_launcher); - } -#endif - -#ifdef FLASHINFER_MAMBA_ENABLE_SM100 - else { + } else { // SM100+ (Blackwell and newer) uses horizontal producer-consumer kernel auto kernel_launcher = [&]() { - // profiling showed that it's good to have 4 consumer warps per 64 rows constexpr auto numConsumers = (DIM / 64) * 4; constexpr auto numProducers = 1; constexpr auto numWarps = numProducers + numConsumers; constexpr auto sectorSize = 32; // bytes - constexpr auto stageCols = - 2 * sectorSize / - sizeof(state_t); // bf16 has 16 columns per stage, fp32 has 8 columns per stage + constexpr auto stageCols = 2 * sectorSize / sizeof(state_t); constexpr auto totalStages = DSTATE / stageCols; constexpr auto numStages = (totalStages >= 4) ? 4 : totalStages; From f4640971b6ca2676b8bcda8e89a27d2a7c9a535b Mon Sep 17 00:00:00 2001 From: Igor Shovkun Date: Wed, 18 Feb 2026 11:22:16 -0800 Subject: [PATCH 17/33] Add algorithm selection to selective_state_update kernels Support explicit algorithm choice (auto/simple/vertical/horizontal) for selective_state_update and MTP kernels. Update kernel signatures, Python bindings, and JIT module generation to include algorithm and compile-time shape parameters (dim, dstate, ntokens_mtp). Refactor dispatch logic for SM90/SM100 architectures. 
--- csrc/flashinfer_mamba_binding.cu | 3 +- csrc/selective_state_update.cu | 23 ++- ...ective_state_update_customize_config.jinja | 4 + csrc/selective_state_update_kernel_inst.cu | 4 +- flashinfer/aot.py | 21 ++- .../jit/mamba/selective_state_update.py | 49 +++++- flashinfer/mamba/selective_state_update.py | 62 ++++++- .../kernel_selective_state_update_mtp.cuh | 70 ++++---- .../kernel_selective_state_update_stp.cuh | 160 +++++++++--------- .../mamba/selective_state_update.cuh | 14 +- 10 files changed, 259 insertions(+), 151 deletions(-) diff --git a/csrc/flashinfer_mamba_binding.cu b/csrc/flashinfer_mamba_binding.cu index dfdc5bebf8..2e2453cefc 100644 --- a/csrc/flashinfer_mamba_binding.cu +++ b/csrc/flashinfer_mamba_binding.cu @@ -42,7 +42,8 @@ void selective_state_update( bool disable_state_update, Optional intermediate_states_buffer, // (batch, cache_steps, nheads, dim, dstate) Optional intermediate_state_indices, // (batch,) - int64_t cache_steps); + int64_t cache_steps, + int64_t algorithm); // SSUAlgorithm: 0=auto, 1=simple, 2=vertical, 3=horizontal } // namespace flashinfer::mamba diff --git a/csrc/selective_state_update.cu b/csrc/selective_state_update.cu index 7990aa167c..b943af821d 100644 --- a/csrc/selective_state_update.cu +++ b/csrc/selective_state_update.cu @@ -130,7 +130,7 @@ void run_selective_state_update_stp(TensorView const& state, TensorView const& x Optional z, Optional dt_bias, bool dt_softplus, Optional state_batch_indices, int64_t pad_slot_id, Optional out, - bool disable_state_update) { + bool disable_state_update, int64_t algorithm) { // Extract dimensions from input tensors auto const batch = x.size(0); auto const state_cache_size = state.size(0); @@ -270,7 +270,8 @@ void run_selective_state_update_stp(TensorView const& state, TensorView const& x ffi::CUDADeviceGuard device_guard(state.device().device_id); const cudaStream_t stream = get_stream(state.device()); - invokeSelectiveStateUpdate(p, stream); + auto algo = static_cast(algorithm); 
+ invokeSelectiveStateUpdate(p, algo, stream); } void run_selective_state_update_mtp( @@ -279,7 +280,7 @@ void run_selective_state_update_mtp( Optional dt_bias, bool dt_softplus, Optional state_batch_indices, int64_t pad_slot_id, Optional out, bool disable_state_update, Optional intermediate_states_buffer, - Optional intermediate_state_indices, int64_t cache_steps) { + Optional intermediate_state_indices, int64_t cache_steps, int64_t algorithm) { // Extract dimensions from input tensors auto const batch = x.size(0); auto const ntokens_mtp = x.size(1); @@ -466,7 +467,8 @@ void run_selective_state_update_mtp( ffi::CUDADeviceGuard device_guard(state.device().device_id); const cudaStream_t stream = get_stream(state.device()); - mtp::invokeSelectiveStateUpdateMTP(p, + auto algo = static_cast(algorithm); + mtp::invokeSelectiveStateUpdateMTP(p, algo, stream); } @@ -479,14 +481,17 @@ void selective_state_update(TensorView state, TensorView x, TensorView dt, Tenso Optional state_batch_indices, int64_t pad_slot_id, TensorView output, bool disable_state_update, Optional intermediate_states_buffer, - Optional intermediate_state_indices, int64_t cache_steps) { + Optional intermediate_state_indices, int64_t cache_steps, + int64_t algorithm) { if (x.dim() == 3) { run_selective_state_update_stp(state, x, dt, A, B, C, D, z, dt_bias, dt_softplus, - state_batch_indices, pad_slot_id, output, disable_state_update); + state_batch_indices, pad_slot_id, output, disable_state_update, + algorithm); } else if (x.dim() == 4) { - run_selective_state_update_mtp( - state, x, dt, A, B, C, D, z, dt_bias, dt_softplus, state_batch_indices, pad_slot_id, output, - disable_state_update, intermediate_states_buffer, intermediate_state_indices, cache_steps); + run_selective_state_update_mtp(state, x, dt, A, B, C, D, z, dt_bias, dt_softplus, + state_batch_indices, pad_slot_id, output, disable_state_update, + intermediate_states_buffer, intermediate_state_indices, + cache_steps, algorithm); } else { 
FLASHINFER_CHECK(false, "x must have 3 dimensions (single-token) or 4 dimensions (multi-token), got ", diff --git a/csrc/selective_state_update_customize_config.jinja b/csrc/selective_state_update_customize_config.jinja index a2f70798c7..418356212d 100644 --- a/csrc/selective_state_update_customize_config.jinja +++ b/csrc/selective_state_update_customize_config.jinja @@ -8,3 +8,7 @@ using input_t = {{ input_dtype }}; using weight_t = {{ weight_dtype }}; using matrixA_t = {{ matrixA_dtype }}; using stateIndex_t = {{ stateIndex_dtype }}; + +constexpr int DIM = {{ dim }}; +constexpr int DSTATE = {{ dstate }}; +constexpr int NTOKENS_MTP = {{ ntokens_mtp }}; diff --git a/csrc/selective_state_update_kernel_inst.cu b/csrc/selective_state_update_kernel_inst.cu index e84e152c9b..11088b3290 100644 --- a/csrc/selective_state_update_kernel_inst.cu +++ b/csrc/selective_state_update_kernel_inst.cu @@ -5,11 +5,11 @@ namespace flashinfer::mamba { template void invokeSelectiveStateUpdate( - SelectiveStateUpdateParams&, cudaStream_t); + SelectiveStateUpdateParams&, SSUAlgorithm, cudaStream_t); namespace mtp { template void invokeSelectiveStateUpdateMTP( - SelectiveStateMTPParams&, cudaStream_t); + SelectiveStateMTPParams&, SSUAlgorithm, cudaStream_t); } // namespace mtp } // namespace flashinfer::mamba diff --git a/flashinfer/aot.py b/flashinfer/aot.py index 5df208ae1c..a1705cf0e9 100644 --- a/flashinfer/aot.py +++ b/flashinfer/aot.py @@ -574,11 +574,24 @@ def gen_all_modules( (torch.float16, torch.bfloat16, torch.float32, torch.float32, torch.int64), (torch.float32, torch.bfloat16, torch.float32, torch.float32, torch.int64), ] - for combo in _ssu_dtype_combos: - jit_specs.append(gen_selective_state_update_module(*combo)) + _ssu_dims = [64, 128, 256] + _ssu_dstates = [64, 128, 256] + _ssu_ntokens = [1, 2, 4, 6, 8, 12, 16] + for dtype_combo, dim, dstate, ntokens in product( + _ssu_dtype_combos, _ssu_dims, _ssu_dstates, _ssu_ntokens + ): + jit_specs.append( + 
gen_selective_state_update_module(*dtype_combo, dim, dstate, ntokens) + ) if has_sm90 or has_sm100: - for combo in _ssu_dtype_combos: - jit_specs.append(gen_selective_state_update_sm90_module(*combo)) + for dtype_combo, dim, dstate, ntokens in product( + _ssu_dtype_combos, _ssu_dims, _ssu_dstates, _ssu_ntokens + ): + jit_specs.append( + gen_selective_state_update_sm90_module( + *dtype_combo, dim, dstate, ntokens + ) + ) jit_specs.append(gen_trtllm_utils_module()) jit_specs.append(gen_gdn_prefill_sm90_module()) diff --git a/flashinfer/jit/mamba/selective_state_update.py b/flashinfer/jit/mamba/selective_state_update.py index 9eb69d9936..bba5c3d375 100644 --- a/flashinfer/jit/mamba/selective_state_update.py +++ b/flashinfer/jit/mamba/selective_state_update.py @@ -49,12 +49,16 @@ def get_selective_state_update_uri( weight_dtype: torch.dtype, matrixA_dtype: torch.dtype, stateIndex_dtype: torch.dtype, + dim: int, + dstate: int, + ntokens_mtp: int, ) -> str: s = _filename_safe_dtype_map return ( f"selective_state_update_" f"s_{s[state_dtype]}_i_{s[input_dtype]}_w_{s[weight_dtype]}_" - f"a_{s[matrixA_dtype]}_si_{s[stateIndex_dtype]}" + f"a_{s[matrixA_dtype]}_si_{s[stateIndex_dtype]}_" + f"d_{dim}_ds_{dstate}_nt_{ntokens_mtp}" ) @@ -65,6 +69,9 @@ def _gen_module( weight_dtype: torch.dtype, matrixA_dtype: torch.dtype, stateIndex_dtype: torch.dtype, + dim: int, + dstate: int, + ntokens_mtp: int, extra_cuda_cflags: list = None, ) -> JitSpec: gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri @@ -82,6 +89,9 @@ def _gen_module( weight_dtype=_dtype_map[weight_dtype], matrixA_dtype=_dtype_map[matrixA_dtype], stateIndex_dtype=_dtype_map[stateIndex_dtype], + dim=dim, + dstate=dstate, + ntokens_mtp=ntokens_mtp, ) write_if_different(gen_directory / "selective_state_update_config.inc", config_str) @@ -112,12 +122,30 @@ def gen_selective_state_update_module( weight_dtype: torch.dtype, matrixA_dtype: torch.dtype, stateIndex_dtype: torch.dtype, + dim: int, + dstate: int, + ntokens_mtp: 
int, ) -> JitSpec: uri = get_selective_state_update_uri( - state_dtype, input_dtype, weight_dtype, matrixA_dtype, stateIndex_dtype + state_dtype, + input_dtype, + weight_dtype, + matrixA_dtype, + stateIndex_dtype, + dim, + dstate, + ntokens_mtp, ) return _gen_module( - uri, state_dtype, input_dtype, weight_dtype, matrixA_dtype, stateIndex_dtype + uri, + state_dtype, + input_dtype, + weight_dtype, + matrixA_dtype, + stateIndex_dtype, + dim, + dstate, + ntokens_mtp, ) @@ -127,10 +155,20 @@ def gen_selective_state_update_sm90_module( weight_dtype: torch.dtype, matrixA_dtype: torch.dtype, stateIndex_dtype: torch.dtype, + dim: int, + dstate: int, + ntokens_mtp: int, ) -> JitSpec: uri = ( get_selective_state_update_uri( - state_dtype, input_dtype, weight_dtype, matrixA_dtype, stateIndex_dtype + state_dtype, + input_dtype, + weight_dtype, + matrixA_dtype, + stateIndex_dtype, + dim, + dstate, + ntokens_mtp, ) + "_sm90" ) @@ -146,5 +184,8 @@ def gen_selective_state_update_sm90_module( weight_dtype, matrixA_dtype, stateIndex_dtype, + dim, + dstate, + ntokens_mtp, extra_cuda_cflags=nvcc_flags, ) diff --git a/flashinfer/mamba/selective_state_update.py b/flashinfer/mamba/selective_state_update.py index 7635e0f408..25f47cb68b 100644 --- a/flashinfer/mamba/selective_state_update.py +++ b/flashinfer/mamba/selective_state_update.py @@ -34,19 +34,25 @@ def _get_module( weight_dtype: torch.dtype, matrixA_dtype: torch.dtype, stateIndex_dtype: torch.dtype, + dim: int, + dstate: int, + ntokens_mtp: int, sm_major: int, ): - dtype_args = ( + args = ( state_dtype, input_dtype, weight_dtype, matrixA_dtype, stateIndex_dtype, + dim, + dstate, + ntokens_mtp, ) if sm_major >= 9: - return gen_selective_state_update_sm90_module(*dtype_args).build_and_load() + return gen_selective_state_update_sm90_module(*args).build_and_load() else: - return gen_selective_state_update_module(*dtype_args).build_and_load() + return gen_selective_state_update_module(*args).build_and_load() def 
get_selective_state_update_module( @@ -56,10 +62,21 @@ def get_selective_state_update_module( weight_dtype: torch.dtype, matrixA_dtype: torch.dtype, stateIndex_dtype: torch.dtype, + dim: int, + dstate: int, + ntokens_mtp: int, ): major, _ = get_compute_capability(device) return _get_module( - state_dtype, input_dtype, weight_dtype, matrixA_dtype, stateIndex_dtype, major + state_dtype, + input_dtype, + weight_dtype, + matrixA_dtype, + stateIndex_dtype, + dim, + dstate, + ntokens_mtp, + major, ) @@ -82,6 +99,7 @@ def selective_state_update( intermediate_states_buffer: Optional[torch.Tensor] = None, intermediate_state_indices: Optional[torch.Tensor] = None, cache_steps: int = 0, + algorithm: str = "auto", ) -> torch.Tensor: r"""Selective state update operation for Mamba layers (the generation phase). @@ -130,6 +148,10 @@ def selective_state_update( with shape (batch,) cache_steps : int Number of steps/tokens to cache for speculative decoding + algorithm : str + Algorithm to use: "auto" (default, selects based on GPU arch), + "simple" (all GPUs), "vertical" (SM90+), "horizontal" (SM100+). + MTP mode only supports "auto" or "simple". 
Returns ------- @@ -190,6 +212,22 @@ def selective_state_update( elif intermediate_state_indices is not None: stateIndex_dtype = intermediate_state_indices.dtype + # Extract dim/dstate/ntokens for JIT specialization + dim = state.size(2) + dstate = state.size(3) + ntokens_mtp = x.size(1) if x.dim() == 4 else 1 + + if algorithm == "auto": + algorithm_int = 0 + elif algorithm == "simple": + algorithm_int = 1 + elif algorithm == "vertical": + algorithm_int = 2 + elif algorithm == "horizontal": + algorithm_int = 3 + else: + raise ValueError(f"Unknown algorithm: {algorithm}") + _selective_state_update( state, x, @@ -208,11 +246,15 @@ def selective_state_update( intermediate_states_buffer, intermediate_state_indices, cache_steps, + algorithm_int, state.dtype, x.dtype, dt.dtype, A.dtype, stateIndex_dtype, + dim, + dstate, + ntokens_mtp, ) return output @@ -239,11 +281,15 @@ def _selective_state_update( intermediate_states_buffer: Optional[torch.Tensor], intermediate_state_indices: Optional[torch.Tensor], cache_steps: int, + algorithm: int, state_dtype: torch.dtype, input_dtype: torch.dtype, weight_dtype: torch.dtype, matrixA_dtype: torch.dtype, stateIndex_dtype: torch.dtype, + dim: int, + dstate: int, + ntokens_mtp: int, ) -> None: """Internal function registered with torch.library for torch.compile() support.""" get_selective_state_update_module( @@ -253,6 +299,9 @@ def _selective_state_update( weight_dtype, matrixA_dtype, stateIndex_dtype, + dim, + dstate, + ntokens_mtp, ).selective_state_update( state, x, @@ -271,6 +320,7 @@ def _selective_state_update( intermediate_states_buffer, intermediate_state_indices, cache_steps, + algorithm, ) @@ -293,11 +343,15 @@ def _selective_state_update_fake( intermediate_states_buffer: Optional[torch.Tensor], intermediate_state_indices: Optional[torch.Tensor], cache_steps: int, + algorithm: int, state_dtype: torch.dtype, input_dtype: torch.dtype, weight_dtype: torch.dtype, matrixA_dtype: torch.dtype, stateIndex_dtype: torch.dtype, + 
dim: int, + dstate: int, + ntokens_mtp: int, ) -> None: """Fake implementation for torch.compile() meta tensor propagation.""" pass diff --git a/include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh b/include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh index 09d927f6fb..af86b8094a 100644 --- a/include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh +++ b/include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh @@ -288,48 +288,40 @@ __global__ void selective_state_update_kernel_simple_mtp(SelectiveStateMTPParams template -void invokeSelectiveStateUpdateMTP(SelectiveStateMTPParams& params, cudaStream_t stream) { +void invokeSelectiveStateUpdateMTP(SelectiveStateMTPParams& params, SSUAlgorithm algorithm, + cudaStream_t stream) { + // MTP only supports the simple kernel + FLASHINFER_CHECK(algorithm == SSUAlgorithm::kAuto || algorithm == SSUAlgorithm::kSimple, + "MTP selective_state_update only supports 'auto' or 'simple' algorithm, got ", + static_cast(algorithm)); // Common alignment checks for all kernels check_ptr_alignment_input_vars(params); - auto kernel_launcher = [&]() { - // Additional alignment checks specific to simple kernel - constexpr auto stateLoadSize = getVectorLoadSizeForFullUtilization(); - using load_state_t = PackedAligned; - - FLASHINFER_CHECK(reinterpret_cast(params.state) % sizeof(load_state_t) == 0, - "state pointer must be aligned to ", sizeof(load_state_t), " bytes"); - FLASHINFER_CHECK((params.dim * params.dstate * sizeof(state_t)) % sizeof(load_state_t) == 0, - "state head stride must be aligned to ", sizeof(load_state_t), " bytes"); - - constexpr int numWarps = 4; - constexpr int stateRowsPerWarpPerStage = 4; - constexpr int stageRows = stateRowsPerWarpPerStage * numWarps; - - dim3 block(warpSize, numWarps); - dim3 grid(params.batch, params.nheads); - - auto func = - selective_state_update_kernel_simple_mtp; - using sram_t = SharedStorageSimple; - constexpr size_t smem_size = sizeof(sram_t); - - // Use 
FLASHINFER_CHECK instead of FLASHINFER_CUDA_CALL since we're in a void lambda - // (FLASHINFER_CUDA_CALL uses "return e;" which is invalid in void context) - // { - // cudaError_t e = cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, - // smem_size); FLASHINFER_CHECK(e == cudaSuccess, "CUDA Error in cudaFuncSetAttribute: ", - // cudaGetErrorString(e), " (", int(e), ")"); - // } - FLASHINFER_CUDA_CHECK( - cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - - func<<>>(params); - }; - - dispatchDimDstateTokens(params, AllowedDims{}, AllowedDstates{}, AllowedNtokens{}, - kernel_launcher); + constexpr auto stateLoadSize = getVectorLoadSizeForFullUtilization(); + using load_state_t = PackedAligned; + + FLASHINFER_CHECK(reinterpret_cast(params.state) % sizeof(load_state_t) == 0, + "state pointer must be aligned to ", sizeof(load_state_t), " bytes"); + FLASHINFER_CHECK((params.dim * params.dstate * sizeof(state_t)) % sizeof(load_state_t) == 0, + "state head stride must be aligned to ", sizeof(load_state_t), " bytes"); + + constexpr int numWarps = 4; + constexpr int stateRowsPerWarpPerStage = 4; + constexpr int stageRows = stateRowsPerWarpPerStage * numWarps; + + dim3 block(warpSize, numWarps); + dim3 grid(params.batch, params.nheads); + + auto func = + selective_state_update_kernel_simple_mtp; + using sram_t = SharedStorageSimple; + constexpr size_t smem_size = sizeof(sram_t); + + FLASHINFER_CUDA_CHECK( + cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + func<<>>(params); } } // namespace flashinfer::mamba::mtp diff --git a/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh b/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh index 46ff69997b..ee0fbb3344 100644 --- a/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh +++ b/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh @@ -881,63 +881,99 @@ __global__ void 
selective_state_update_kernel_producer_consumer_horizontal( template -void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, cudaStream_t stream) { +void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, SSUAlgorithm algorithm, + cudaStream_t stream) { auto [sm_major, sm_minor] = GetCudaComputeCapability(); // Common alignment checks for all kernels check_ptr_alignment_input_vars(params); + // Resolve auto to a concrete algorithm based on GPU architecture + SSUAlgorithm algo = algorithm; + if (algo == SSUAlgorithm::kAuto) { #ifdef FLASHINFER_MAMBA_ENABLE_SM90 - if (sm_major < 9) // pre-Hopper + if (sm_major < 9) + algo = SSUAlgorithm::kSimple; + else if (sm_major < 10) + algo = SSUAlgorithm::kVertical; + else + algo = SSUAlgorithm::kHorizontal; +#else + algo = SSUAlgorithm::kSimple; #endif - { - auto kernel_launcher = [&]() { - // Additional alignment checks specific to simple kernel - constexpr auto stateLoadSize = getVectorLoadSizeForFullUtilization(); - using load_state_t = PackedAligned; + } - FLASHINFER_CHECK(reinterpret_cast(params.state) % sizeof(load_state_t) == 0, - "state pointer must be aligned to ", sizeof(load_state_t), " bytes"); - FLASHINFER_CHECK((params.dim * params.dstate * sizeof(state_t)) % sizeof(load_state_t) == 0, - "state head stride must be aligned to ", sizeof(load_state_t), " bytes"); + if (algo == SSUAlgorithm::kSimple) { + constexpr auto stateLoadSize = getVectorLoadSizeForFullUtilization(); + using load_state_t = PackedAligned; - constexpr int numWarps = 4; - dim3 block(warpSize, numWarps); - dim3 grid(params.batch, params.nheads); - selective_state_update_kernel_simple<<>>(params); - }; + FLASHINFER_CHECK(reinterpret_cast(params.state) % sizeof(load_state_t) == 0, + "state pointer must be aligned to ", sizeof(load_state_t), " bytes"); + FLASHINFER_CHECK((params.dim * params.dstate * sizeof(state_t)) % sizeof(load_state_t) == 0, + "state head stride must be aligned to ", sizeof(load_state_t), " bytes"); - 
dispatchDimDstate(params, AllowedDims{}, AllowedDstates{}, kernel_launcher); + constexpr int numWarps = 4; + dim3 block(warpSize, numWarps); + dim3 grid(params.batch, params.nheads); + selective_state_update_kernel_simple<<>>(params); } #ifdef FLASHINFER_MAMBA_ENABLE_SM90 - else if (sm_major < 10) { - // SM90 (Hopper) uses vertical producer-consumer kernel - auto kernel_launcher = [&]() { - constexpr auto numConsumers = 4; - constexpr auto numWarps = 1 + numConsumers; - constexpr auto numStages = 3; - constexpr auto rowsPerStage = 4 * numConsumers; - FLASHINFER_CHECK(params.dim % rowsPerStage == 0, "dim must be divisible by ", rowsPerStage, - " for SM90+ kernel"); - auto scan_func = selective_state_update_kernel_producer_consumer_vertical< - input_t, weight_t, matrixA_t, state_t, stateIndex_t, DIM, DSTATE, numConsumers, - rowsPerStage, numStages>; + else if (algo == SSUAlgorithm::kVertical) { + constexpr auto numConsumers = 4; + constexpr auto numWarps = 1 + numConsumers; + constexpr auto numStages = 3; + constexpr auto rowsPerStage = 4 * numConsumers; + FLASHINFER_CHECK(params.dim % rowsPerStage == 0, "dim must be divisible by ", rowsPerStage, + " for vertical kernel"); + auto scan_func = selective_state_update_kernel_producer_consumer_vertical< + input_t, weight_t, matrixA_t, state_t, stateIndex_t, DIM, DSTATE, numConsumers, + rowsPerStage, numStages>; + + dim3 block(warpSize, numWarps); + dim3 grid(params.batch, params.nheads); + + auto state_tensor = + tma::buildNdDescriptor(typeid(state_t), + /*shapes*/ {DSTATE, DIM, params.nheads, params.state_cache_size}, + /*strides*/ {1, DSTATE, DSTATE * DIM, params.state_stride_batch}, + /*tiles*/ {DSTATE, rowsPerStage, 1, 1}, params.state); + + using sram_t = SharedStorageVertical; + constexpr size_t smem_size = sizeof(sram_t); + FLASHINFER_CUDA_CHECK( + cudaFuncSetAttribute(scan_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + scan_func<<>>(params, state_tensor); + } else if (algo == 
SSUAlgorithm::kHorizontal) { + constexpr auto numConsumers = (DIM / 64) * 4; + constexpr auto numProducers = 1; + constexpr auto numWarps = numProducers + numConsumers; + + constexpr auto sectorSize = 32; // bytes + constexpr auto stageCols = 2 * sectorSize / sizeof(state_t); + + constexpr auto totalStages = DSTATE / stageCols; + constexpr auto numStages = (totalStages >= 4) ? 4 : totalStages; + + auto ratio_launcher = [&]() { + auto scan_func = selective_state_update_kernel_producer_consumer_horizontal< + input_t, weight_t, matrixA_t, state_t, stateIndex_t, DIM, DSTATE, numConsumers, stageCols, + RATIO, numStages>; dim3 block(warpSize, numWarps); dim3 grid(params.batch, params.nheads); - auto nh = params.nheads; - auto dim = params.dim; - auto state_tensor = tma::buildNdDescriptor(typeid(state_t), - /*shapes*/ {DSTATE, DIM, nh, params.state_cache_size}, + /*shapes*/ {DSTATE, DIM, params.nheads, params.state_cache_size}, /*strides*/ {1, DSTATE, DSTATE * DIM, params.state_stride_batch}, - /*tiles*/ {DSTATE, rowsPerStage, 1, 1}, params.state); + /*tiles*/ {stageCols, DIM, 1, 1}, params.state); + static_assert(DSTATE % stageCols == 0 && DSTATE >= stageCols); - using sram_t = SharedStorageVertical; + using sram_t = SharedStorageHorizontal; constexpr size_t smem_size = sizeof(sram_t); FLASHINFER_CUDA_CHECK( cudaFuncSetAttribute(scan_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -945,53 +981,13 @@ void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, cudaStream_t scan_func<<>>(params, state_tensor); }; - dispatchDimDstate(params, AllowedDims{}, AllowedDstates{}, kernel_launcher); - } else { - // SM100+ (Blackwell and newer) uses horizontal producer-consumer kernel - auto kernel_launcher = [&]() { - constexpr auto numConsumers = (DIM / 64) * 4; - constexpr auto numProducers = 1; - constexpr auto numWarps = numProducers + numConsumers; - - constexpr auto sectorSize = 32; // bytes - constexpr auto stageCols = 2 * sectorSize / 
sizeof(state_t); - - constexpr auto totalStages = DSTATE / stageCols; - constexpr auto numStages = (totalStages >= 4) ? 4 : totalStages; - - auto ratio_launcher = [&]() { - auto scan_func = selective_state_update_kernel_producer_consumer_horizontal< - input_t, weight_t, matrixA_t, state_t, stateIndex_t, DIM, DSTATE, numConsumers, - stageCols, RATIO, numStages>; - - dim3 block(warpSize, numWarps); - dim3 grid(params.batch, params.nheads); - - auto nh = params.nheads; - auto dim = params.dim; - - auto state_tensor = - tma::buildNdDescriptor(typeid(state_t), - /*shapes*/ {DSTATE, DIM, nh, params.state_cache_size}, - /*strides*/ {1, DSTATE, DSTATE * DIM, params.state_stride_batch}, - /*tiles*/ {stageCols, DIM, 1, 1}, params.state); - static_assert(DSTATE % stageCols == 0 && DSTATE >= stageCols); - - using sram_t = SharedStorageHorizontal; - constexpr size_t smem_size = sizeof(sram_t); - FLASHINFER_CUDA_CHECK(cudaFuncSetAttribute( - scan_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - - scan_func<<>>(params, state_tensor); - }; - - dispatchRatio(params, std::integer_sequence{}, ratio_launcher); - }; - - dispatchDimDstate(params, AllowedDims{}, AllowedDstates{}, kernel_launcher); + dispatchRatio(params, std::integer_sequence{}, ratio_launcher); } #endif + else { + FLASHINFER_CHECK(false, "Unsupported SSU algorithm: ", static_cast(algo), + ". 
Vertical/horizontal require FLASHINFER_MAMBA_ENABLE_SM90."); + } } } // namespace flashinfer::mamba diff --git a/include/flashinfer/mamba/selective_state_update.cuh b/include/flashinfer/mamba/selective_state_update.cuh index f8b44c3779..b3517f848f 100644 --- a/include/flashinfer/mamba/selective_state_update.cuh +++ b/include/flashinfer/mamba/selective_state_update.cuh @@ -21,12 +21,14 @@ namespace flashinfer::mamba { -// ============================================================================= -// Allowed dispatch values for kernel instantiation -// ============================================================================= -using AllowedDims = std::integer_sequence; -using AllowedDstates = std::integer_sequence; -using AllowedNtokens = std::integer_sequence; +// Host-side algorithm selection for invokeSelectiveStateUpdate dispatch. +// Not stored in kernel params — no register overhead. +enum class SSUAlgorithm : int32_t { + kAuto = 0, + kSimple = 1, + kVertical = 2, + kHorizontal = 3, +}; struct SelectiveStateUpdateParams { uint32_t batch{}, nheads{}, dim{}, dstate{}, ngroups{}, state_cache_size{}; From c65670c77f92d28343d5ff01432037e89da2ac84 Mon Sep 17 00:00:00 2001 From: Igor Shovkun Date: Wed, 18 Feb 2026 12:10:22 -0800 Subject: [PATCH 18/33] Fix include order: config.inc before header in selective_state_update .cu files The config.inc defines DIM, DSTATE, NTOKENS_MTP as constexpr globals that the header's function templates rely on. With the previous order (header first, config second), NVCC's lenient two-phase lookup masked the issue, but a fresh JIT compilation after cache clearing would fail with 'identifier DIM/DSTATE is undefined' errors. clang-format is disabled for these includes because it reorders them alphabetically, which breaks compilation. 
AI-assisted --- csrc/selective_state_update.cu | 7 +++++-- csrc/selective_state_update_kernel_inst.cu | 7 +++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/csrc/selective_state_update.cu b/csrc/selective_state_update.cu index b943af821d..7afe8b21e0 100644 --- a/csrc/selective_state_update.cu +++ b/csrc/selective_state_update.cu @@ -13,9 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include - +// clang-format off +// config.inc MUST come before the header: it defines DIM, DSTATE, NTOKENS_MTP +// constexprs that the header's function templates rely on. Reordering breaks compilation. #include "selective_state_update_config.inc" +#include +// clang-format on #include "tvm_ffi_utils.h" using namespace flashinfer; diff --git a/csrc/selective_state_update_kernel_inst.cu b/csrc/selective_state_update_kernel_inst.cu index 11088b3290..6dcec72a5d 100644 --- a/csrc/selective_state_update_kernel_inst.cu +++ b/csrc/selective_state_update_kernel_inst.cu @@ -1,6 +1,9 @@ -#include - +// clang-format off +// config.inc MUST come before the header: it defines DIM, DSTATE, NTOKENS_MTP +// constexprs that the header's function templates rely on. Reordering breaks compilation. #include "selective_state_update_config.inc" +#include +// clang-format on namespace flashinfer::mamba { From 44b6c256534b9b970a3a7fcf87e07946e69c5f0e Mon Sep 17 00:00:00 2001 From: Igor Shovkun Date: Wed, 18 Feb 2026 12:55:57 -0800 Subject: [PATCH 19/33] Parallelize consumer warp loads in vertical SSU kernel Assign each of the 4 consumer warps a single tensor to load (x, B, z, C) instead of warps 0 and 1 each loading two tensors sequentially. This maximizes memory-level parallelism during the load phase. 
Co-Authored-By: Claude Sonnet 4.6 --- .../mamba/kernel_selective_state_update_stp.cuh | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh b/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh index ee0fbb3344..474491091a 100644 --- a/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh +++ b/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh @@ -483,23 +483,25 @@ __global__ void selective_state_update_kernel_producer_consumer_vertical( } auto const dA = __expf(A_value * dt_value); - if (warp == 0) { // Load x, B + if (warp == 0) { // Load x for (auto d = lane * load_t::count; d < DIM; d += warpSize * load_t::count) { auto* dst = reinterpret_cast(&sram.x[d]); *dst = *reinterpret_cast(&x[batch * params.x_stride_batch + head * DIM + d]); } + } else if (warp == 1) { // Load B for (auto i = lane * load_t::count; i < DSTATE; i += warpSize * load_t::count) { auto* dst = reinterpret_cast(&sram.B[i]); *dst = *reinterpret_cast( &B[batch * params.B_stride_batch + group * DSTATE + i]); } - } else if (warp == 1) { // Load z, C + } else if (warp == 2) { // Load z for (auto d = lane * load_t::count; d < DIM; d += warpSize * load_t::count) { auto* dst = reinterpret_cast(&sram.z[d]); *dst = z ? *reinterpret_cast(&z[batch * params.z_stride_batch + head * DIM + d]) : make_zeros(); } + } else if (warp == 3) { // Load C for (auto i = lane * load_t::count; i < DSTATE; i += warpSize * load_t::count) { auto* dst = reinterpret_cast(&sram.C[i]); *dst = *reinterpret_cast( @@ -985,7 +987,7 @@ void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, SSUAlgorithm } #endif else { - FLASHINFER_CHECK(false, "Unsupported SSU algorithm: ", static_cast(algo), + FLASHINFER_CHECK(false, "Unsupported SSU algorithm: ", SSUAlgorithmToString(algo), ". 
Vertical/horizontal require FLASHINFER_MAMBA_ENABLE_SM90."); } } From eff403c1166e2259e0442d927e8e05f814b59222 Mon Sep 17 00:00:00 2001 From: Igor Shovkun Date: Wed, 18 Feb 2026 14:15:26 -0800 Subject: [PATCH 20/33] Reduce test combinations in SSU tests to base + independent deviations Replace cartesian-product fixture parametrization with explicit rows: one base case plus one row per parameter deviation. Cuts the test count from ~200+ (MTP) and ~144+ (STP) down to ~26 and ~15 respectively. AI-assisted Co-Authored-By: Claude Sonnet 4.6 --- .../mamba/test_selective_state_update_mtp.py | 494 +++++++++--------- .../mamba/test_selective_state_update_stp.py | 207 +++----- 2 files changed, 313 insertions(+), 388 deletions(-) diff --git a/tests/mamba/test_selective_state_update_mtp.py b/tests/mamba/test_selective_state_update_mtp.py index 866c32e98a..4430d305b0 100644 --- a/tests/mamba/test_selective_state_update_mtp.py +++ b/tests/mamba/test_selective_state_update_mtp.py @@ -15,6 +15,26 @@ from .utils import create_test_inputs, clone_preserving_strides +# Base combination: batch=64, nheads=64, dim=64, dstate=128, cache_steps=4, +# state_dtype=bf16, weight_dtype=f32, use_out_tensor=True +# Each additional row varies exactly one parameter from the base. 
+# fmt: off +_BASE_PARAMS = [ + # (batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype, use_out_tensor) + ( 64, 64, 64, 128, 4, torch.bfloat16, torch.float32, True ), # base + ( 1, 64, 64, 128, 4, torch.bfloat16, torch.float32, True ), # batch=1 + ( 4, 64, 64, 128, 4, torch.bfloat16, torch.float32, True ), # batch=4 + ( 64, 8, 64, 128, 4, torch.bfloat16, torch.float32, True ), # nheads=8 + ( 64, 64, 128, 128, 4, torch.bfloat16, torch.float32, True ), # dim=128 + ( 64, 64, 64, 64, 4, torch.bfloat16, torch.float32, True ), # dstate=64 + ( 64, 64, 64, 128, 1, torch.bfloat16, torch.float32, True ), # cache_steps=1 + ( 64, 64, 64, 128, 8, torch.bfloat16, torch.float32, True ), # cache_steps=8 + ( 64, 64, 64, 128, 4, torch.float32, torch.float32, True ), # state_dtype=f32 + ( 64, 64, 64, 128, 4, torch.bfloat16, torch.float32, False), # use_out_tensor=False +] +# fmt: on + + class TestSelectiveStateUpdateMTP: """Test class for multi-token selective state update kernels.""" @@ -25,41 +45,7 @@ class TestSelectiveStateUpdateMTP: INPUT_DTYPE = torch.bfloat16 MATRIX_A_DTYPE = torch.float32 - @pytest.fixture(params=[1, 4]) - def batch(self, request): - return request.param - - @pytest.fixture(params=[8, 32]) - def nheads(self, request): - return request.param - - @pytest.fixture(params=[64, 128]) - def dim(self, request): - return request.param - - @pytest.fixture(params=[64, 128]) - def dstate(self, request): - return request.param - - @pytest.fixture(params=[1, 4, 8]) - def cache_steps(self, request): - """Number of tokens in multi-token mode (T dimension).""" - return request.param - - @pytest.fixture(params=[torch.float32, torch.bfloat16]) - def state_dtype(self, request): - return request.param - - @pytest.fixture(params=[torch.float32]) - def weight_dtype(self, request): - return request.param - - @pytest.fixture(params=[False, True]) - def use_out_tensor(self, request): - return request.param - - @pytest.fixture - def inputs( + def make_inputs( self, 
batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype ): """Create test inputs for given parameters.""" @@ -79,8 +65,7 @@ def inputs( seed=0, ) - @pytest.fixture - def reference_output(self, inputs): + def make_reference_output(self, inputs): """Compute reference output using triton implementation.""" state_ref = clone_preserving_strides(inputs["state_cache"]) y_ref = selective_state_update_triton( @@ -182,9 +167,26 @@ def _print_mismatch_details(self, ref, test, name): f"diff={diff:.6e}, rel_diff={rel_diff:.6e}" ) - def test_output_correctness(self, inputs, reference_output, use_out_tensor): + @pytest.mark.parametrize( + "batch,nheads,dim,dstate,cache_steps,state_dtype,weight_dtype,use_out_tensor", + _BASE_PARAMS, + ) + def test_output_correctness( + self, + batch, + nheads, + dim, + dstate, + cache_steps, + state_dtype, + weight_dtype, + use_out_tensor, + ): """Test that kernel output matches reference within tolerance.""" - y_ref, state_ref = reference_output + inputs = self.make_inputs( + batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype + ) + y_ref, state_ref = self.make_reference_output(inputs) # Prepare output tensor if requested if use_out_tensor: @@ -207,36 +209,7 @@ def test_output_correctness(self, inputs, reference_output, use_out_tensor): class TestSelectiveStateUpdateMTPWithZ(TestSelectiveStateUpdateMTP): """Test multi-token selective_state_update with z tensor (gating).""" - @pytest.fixture(params=[4]) - def batch(self, request): - return request.param - - @pytest.fixture(params=[8]) - def nheads(self, request): - return request.param - - @pytest.fixture(params=[64]) - def dim(self, request): - return request.param - - @pytest.fixture(params=[64]) - def dstate(self, request): - return request.param - - @pytest.fixture(params=[4]) - def cache_steps(self, request): - return request.param - - @pytest.fixture(params=[torch.bfloat16]) - def state_dtype(self, request): - return request.param - - 
@pytest.fixture(params=[torch.float32]) - def weight_dtype(self, request): - return request.param - - @pytest.fixture - def inputs( + def make_inputs( self, batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype ): """Create test inputs with z tensor.""" @@ -256,41 +229,56 @@ def inputs( seed=0, ) + @pytest.mark.parametrize( + "batch,nheads,dim,dstate,cache_steps,state_dtype,weight_dtype,use_out_tensor", + [(64, 64, 64, 128, 4, torch.bfloat16, torch.float32, True)], + ) + def test_output_correctness( + self, + batch, + nheads, + dim, + dstate, + cache_steps, + state_dtype, + weight_dtype, + use_out_tensor, + ): + super().test_output_correctness( + batch, + nheads, + dim, + dstate, + cache_steps, + state_dtype, + weight_dtype, + use_out_tensor, + ) + class TestSelectiveStateUpdateMTPDisableStateUpdate(TestSelectiveStateUpdateMTP): """Test multi-token selective_state_update with disable_state_update=True.""" - @pytest.fixture(params=[4]) - def batch(self, request): - return request.param - - @pytest.fixture(params=[32]) - def nheads(self, request): - return request.param - - @pytest.fixture(params=[64]) - def dim(self, request): - return request.param - - @pytest.fixture(params=[64, 128]) - def dstate(self, request): - return request.param - - @pytest.fixture(params=[4, 8]) - def cache_steps(self, request): - return request.param - - @pytest.fixture(params=[torch.bfloat16]) - def state_dtype(self, request): - return request.param - - @pytest.fixture(params=[torch.float32]) - def weight_dtype(self, request): - return request.param - - def test_output_correctness(self, inputs, reference_output, use_out_tensor): + @pytest.mark.parametrize( + "batch,nheads,dim,dstate,cache_steps,state_dtype,weight_dtype,use_out_tensor", + [(64, 64, 64, 128, 4, torch.bfloat16, torch.float32, True)], + ) + def test_output_correctness( + self, + batch, + nheads, + dim, + dstate, + cache_steps, + state_dtype, + weight_dtype, + use_out_tensor, + ): """Test that kernel output 
matches reference but state is not updated.""" - y_ref, state_ref = reference_output + inputs = self.make_inputs( + batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype + ) + y_ref, state_ref = self.make_reference_output(inputs) # Save the initial state before running the kernel state_initial = inputs["state_cache"].clone() @@ -343,36 +331,7 @@ def test_output_correctness(self, inputs, reference_output, use_out_tensor): class TestSelectiveStateUpdateMTPWithIntermediateStates(TestSelectiveStateUpdateMTP): """Test multi-token selective_state_update with intermediate states buffer.""" - @pytest.fixture(params=[4]) - def batch(self, request): - return request.param - - @pytest.fixture(params=[32]) - def nheads(self, request): - return request.param - - @pytest.fixture(params=[64]) - def dim(self, request): - return request.param - - @pytest.fixture(params=[64, 128]) - def dstate(self, request): - return request.param - - @pytest.fixture(params=[2, 4, 8]) - def cache_steps(self, request): - return request.param - - @pytest.fixture(params=[torch.bfloat16]) - def state_dtype(self, request): - return request.param - - @pytest.fixture(params=[torch.float32]) - def weight_dtype(self, request): - return request.param - - @pytest.fixture - def inputs( + def make_inputs( self, batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype ): """Create test inputs with intermediate states buffer.""" @@ -392,8 +351,7 @@ def inputs( seed=0, ) - @pytest.fixture - def reference_output(self, inputs): + def make_reference_output(self, inputs): """Compute reference output using triton implementation with intermediate states.""" state_ref = clone_preserving_strides(inputs["state_cache"]) intermediate_states_ref = inputs["intermediate_states_buffer"].clone() @@ -440,9 +398,37 @@ def run_kernel_with_intermediate_states(self, inputs, out=None): cache_steps=inputs["cache_steps"], ) - def test_output_correctness(self, inputs, reference_output, use_out_tensor): + # fmt: off 
+ _INTERMEDIATE_PARAMS = [ + # (batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype, use_out_tensor) + ( 64, 64, 64, 128, 4, torch.bfloat16, torch.float32, True ), # base + ( 64, 64, 64, 64, 4, torch.bfloat16, torch.float32, True ), # dstate=64 + ( 64, 64, 64, 128, 2, torch.bfloat16, torch.float32, True ), # cache_steps=2 + ( 64, 64, 64, 128, 8, torch.bfloat16, torch.float32, True ), # cache_steps=8 + ( 64, 64, 64, 128, 4, torch.bfloat16, torch.float32, False), # use_out_tensor=False + ] + # fmt: on + + @pytest.mark.parametrize( + "batch,nheads,dim,dstate,cache_steps,state_dtype,weight_dtype,use_out_tensor", + _INTERMEDIATE_PARAMS, + ) + def test_output_correctness( + self, + batch, + nheads, + dim, + dstate, + cache_steps, + state_dtype, + weight_dtype, + use_out_tensor, + ): """Test that kernel output matches and intermediate states are cached correctly.""" - y_ref, state_ref, intermediate_states_ref = reference_output + inputs = self.make_inputs( + batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype + ) + y_ref, state_ref, intermediate_states_ref = self.make_reference_output(inputs) # Prepare output tensor if requested if use_out_tensor: @@ -488,36 +474,7 @@ def test_output_correctness(self, inputs, reference_output, use_out_tensor): class TestSelectiveStateUpdateMTPNonContiguous(TestSelectiveStateUpdateMTP): """Test multi-token selective_state_update with non-contiguous state cache.""" - @pytest.fixture(params=[4]) - def batch(self, request): - return request.param - - @pytest.fixture(params=[8]) - def nheads(self, request): - return request.param - - @pytest.fixture(params=[64]) - def dim(self, request): - return request.param - - @pytest.fixture(params=[64]) - def dstate(self, request): - return request.param - - @pytest.fixture(params=[4]) - def cache_steps(self, request): - return request.param - - @pytest.fixture(params=[torch.bfloat16]) - def state_dtype(self, request): - return request.param - - 
@pytest.fixture(params=[torch.float32]) - def weight_dtype(self, request): - return request.param - - @pytest.fixture - def inputs( + def make_inputs( self, batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype ): """Create test inputs with non-contiguous state cache (2x batch stride).""" @@ -540,8 +497,7 @@ def inputs( seed=0, ) - @pytest.fixture - def reference_output(self, inputs): + def make_reference_output(self, inputs): """Compute reference output, preserving non-contiguous strides.""" state_ref = clone_preserving_strides(inputs["state_cache"]) y_ref = selective_state_update_triton( @@ -560,38 +516,36 @@ def reference_output(self, inputs): ) return y_ref, state_ref + @pytest.mark.parametrize( + "batch,nheads,dim,dstate,cache_steps,state_dtype,weight_dtype,use_out_tensor", + [(64, 64, 64, 128, 4, torch.bfloat16, torch.float32, True)], + ) + def test_output_correctness( + self, + batch, + nheads, + dim, + dstate, + cache_steps, + state_dtype, + weight_dtype, + use_out_tensor, + ): + super().test_output_correctness( + batch, + nheads, + dim, + dstate, + cache_steps, + state_dtype, + weight_dtype, + use_out_tensor, + ) + class TestSelectiveStateUpdateMTPInt32Indices(TestSelectiveStateUpdateMTP): """Test multi-token selective_state_update with int32 state_batch_indices.""" - @pytest.fixture(params=[4]) - def batch(self, request): - return request.param - - @pytest.fixture(params=[8]) - def nheads(self, request): - return request.param - - @pytest.fixture(params=[64]) - def dim(self, request): - return request.param - - @pytest.fixture(params=[64]) - def dstate(self, request): - return request.param - - @pytest.fixture(params=[4]) - def cache_steps(self, request): - return request.param - - @pytest.fixture(params=[torch.bfloat16]) - def state_dtype(self, request): - return request.param - - @pytest.fixture(params=[torch.float32]) - def weight_dtype(self, request): - return request.param - def run_kernel(self, inputs, out=None, 
disable_state_update=False): """Run the flashinfer kernel with int32 state_batch_indices.""" # Cast slot_idx to int32 @@ -614,46 +568,51 @@ def run_kernel(self, inputs, out=None, disable_state_update=False): disable_state_update=disable_state_update, ) + @pytest.mark.parametrize( + "batch,nheads,dim,dstate,cache_steps,state_dtype,weight_dtype,use_out_tensor", + [(64, 64, 64, 128, 4, torch.bfloat16, torch.float32, True)], + ) + def test_output_correctness( + self, + batch, + nheads, + dim, + dstate, + cache_steps, + state_dtype, + weight_dtype, + use_out_tensor, + ): + super().test_output_correctness( + batch, + nheads, + dim, + dstate, + cache_steps, + state_dtype, + weight_dtype, + use_out_tensor, + ) + class TestSelectiveStateUpdateMTPVariousNgroups(TestSelectiveStateUpdateMTP): """Test multi-token selective_state_update with various ngroups values.""" - NGROUPS = None # Will be set by fixture - - @pytest.fixture(params=[1, 2, 4, 8]) - def ngroups(self, request): - return request.param - - @pytest.fixture(params=[4]) - def batch(self, request): - return request.param - - @pytest.fixture(params=[32]) - def nheads(self, request): - return request.param - - @pytest.fixture(params=[64]) - def dim(self, request): - return request.param - - @pytest.fixture(params=[64]) - def dstate(self, request): - return request.param - - @pytest.fixture(params=[4]) - def cache_steps(self, request): - return request.param - - @pytest.fixture(params=[torch.bfloat16]) - def state_dtype(self, request): - return request.param - - @pytest.fixture(params=[torch.float32]) - def weight_dtype(self, request): - return request.param - - @pytest.fixture - def inputs( + # fmt: off + _NGROUPS_PARAMS = [ + # (batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype, use_out_tensor, ngroups) + ( 64, 64, 64, 128, 4, torch.bfloat16, torch.float32, True, 1), + ( 64, 64, 64, 128, 4, torch.bfloat16, torch.float32, True, 2), + ( 64, 64, 64, 128, 4, torch.bfloat16, torch.float32, True, 4), + ( 
64, 64, 64, 128, 4, torch.bfloat16, torch.float32, True, 8), + ] + # fmt: on + + @pytest.mark.parametrize( + "batch,nheads,dim,dstate,cache_steps,state_dtype,weight_dtype,use_out_tensor,ngroups", + _NGROUPS_PARAMS, + ) + def test_output_correctness( self, batch, nheads, @@ -662,10 +621,11 @@ def inputs( cache_steps, state_dtype, weight_dtype, + use_out_tensor, ngroups, ): - """Create test inputs with specified ngroups.""" - return create_test_inputs( + """Test that kernel output matches reference within tolerance.""" + inputs = create_test_inputs( batch, nheads, dim, @@ -680,38 +640,60 @@ def inputs( cache_steps=cache_steps, seed=0, ) + y_ref, state_ref = self.make_reference_output(inputs) + if use_out_tensor: + out = torch.empty_like(inputs["x"]) + else: + out = None -class TestSelectiveStateUpdateMTPLargeBatch(TestSelectiveStateUpdateMTP): - """Test multi-token selective_state_update with larger batch sizes.""" - - @pytest.fixture(params=[16, 64]) - def batch(self, request): - return request.param - - @pytest.fixture(params=[32]) - def nheads(self, request): - return request.param + y_test = self.run_kernel(inputs, out=out) - @pytest.fixture(params=[64]) - def dim(self, request): - return request.param + if use_out_tensor: + assert y_test.data_ptr() == out.data_ptr(), ( + "Returned tensor should be the same object as the provided output tensor" + ) - @pytest.fixture(params=[64]) - def dstate(self, request): - return request.param + self.assert_outputs_match(y_ref, y_test) + self.assert_states_match(state_ref, inputs["state_cache"], inputs["slot_idx"]) - @pytest.fixture(params=[4, 8]) - def cache_steps(self, request): - return request.param - @pytest.fixture(params=[torch.bfloat16]) - def state_dtype(self, request): - return request.param +class TestSelectiveStateUpdateMTPLargeBatch(TestSelectiveStateUpdateMTP): + """Test multi-token selective_state_update with larger batch sizes.""" - @pytest.fixture(params=[torch.float32]) - def weight_dtype(self, request): - 
return request.param + # fmt: off + _LARGE_BATCH_PARAMS = [ + # (batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype, use_out_tensor) + ( 16, 64, 64, 128, 4, torch.bfloat16, torch.float32, True ), # batch=16 + ( 256, 64, 64, 128, 4, torch.bfloat16, torch.float32, True ), # batch=256 + ] + # fmt: on + + @pytest.mark.parametrize( + "batch,nheads,dim,dstate,cache_steps,state_dtype,weight_dtype,use_out_tensor", + _LARGE_BATCH_PARAMS, + ) + def test_output_correctness( + self, + batch, + nheads, + dim, + dstate, + cache_steps, + state_dtype, + weight_dtype, + use_out_tensor, + ): + super().test_output_correctness( + batch, + nheads, + dim, + dstate, + cache_steps, + state_dtype, + weight_dtype, + use_out_tensor, + ) class TestSelectiveStateUpdateMTPIndicesDtypeMismatch: diff --git a/tests/mamba/test_selective_state_update_stp.py b/tests/mamba/test_selective_state_update_stp.py index c26a5849b8..ef007eddd1 100644 --- a/tests/mamba/test_selective_state_update_stp.py +++ b/tests/mamba/test_selective_state_update_stp.py @@ -8,6 +8,26 @@ from .utils import create_test_inputs, clone_preserving_strides +# Base combination: batch=64, nheads=64, dim=64, dstate=128, state_dtype=bf16, +# weight_dtype=f32, use_out_tensor=True +# Each additional row varies exactly one parameter from the base. 
+# fmt: off +_BASE_PARAMS = [ + # (batch, nheads, dim, dstate, state_dtype, weight_dtype, use_out_tensor) + ( 64, 64, 64, 128, torch.bfloat16, torch.float32, True ), # base bf16 + ( 64, 64, 64, 128, torch.float32, torch.float32, True ), # state_dtype=f32 + ( 1, 64, 64, 128, torch.bfloat16, torch.float32, True ), # batch=1 + ( 64, 8, 64, 128, torch.bfloat16, torch.float32, True ), # nheads=8 + ( 64, 64, 128, 128, torch.bfloat16, torch.float32, True ), # dim=128 + ( 64, 64, 64, 64, torch.bfloat16, torch.float32, True ), # dstate=64 + ( 64, 64, 64, 256, torch.bfloat16, torch.float32, True ), # dstate=256 + ( 64, 64, 64, 128, torch.float16, torch.float32, True ), # state_dtype=f16 + ( 64, 64, 64, 128, torch.bfloat16, torch.bfloat16, True ), # weight_dtype=bf16 + ( 64, 64, 64, 128, torch.bfloat16, torch.float32, False), # use_out_tensor=False +] +# fmt: on + + class TestSelectiveStateUpdate: """Test class for selective state update kernels.""" @@ -18,36 +38,7 @@ class TestSelectiveStateUpdate: INPUT_DTYPE = torch.bfloat16 MATRIX_A_DTYPE = torch.float32 - @pytest.fixture(params=[1, 64]) - def batch(self, request): - return request.param - - @pytest.fixture(params=[8, 64]) - def nheads(self, request): - return request.param - - @pytest.fixture(params=[64, 128]) - def dim(self, request): - return request.param - - @pytest.fixture(params=[64, 128, 256]) - def dstate(self, request): - return request.param - - @pytest.fixture(params=[torch.float16, torch.bfloat16, torch.float32]) - def state_dtype(self, request): - return request.param - - @pytest.fixture(params=[torch.float32, torch.bfloat16]) - def weight_dtype(self, request): - return request.param - - @pytest.fixture(params=[False, True]) - def use_out_tensor(self, request): - return request.param - - @pytest.fixture - def inputs(self, batch, nheads, dim, dstate, state_dtype, weight_dtype): + def make_inputs(self, batch, nheads, dim, dstate, state_dtype, weight_dtype): """Create test inputs for given parameters.""" return 
create_test_inputs( batch, @@ -63,8 +54,7 @@ def inputs(self, batch, nheads, dim, dstate, state_dtype, weight_dtype): seed=0, ) - @pytest.fixture - def reference_output(self, inputs): + def make_reference_output(self, inputs): """Compute reference output using triton implementation.""" state_ref = inputs["state_cache"].clone() y_ref = selective_state_update_triton( @@ -165,15 +155,18 @@ def _print_mismatch_details(self, ref, test, name): f"diff={diff:.6e}, rel_diff={rel_diff:.6e}" ) - def test_output_correctness(self, inputs, reference_output, use_out_tensor): + @pytest.mark.parametrize( + "batch,nheads,dim,dstate,state_dtype,weight_dtype,use_out_tensor", _BASE_PARAMS + ) + def test_output_correctness( + self, batch, nheads, dim, dstate, state_dtype, weight_dtype, use_out_tensor + ): """Test that kernel output matches reference within tolerance.""" - y_ref, state_ref = reference_output + inputs = self.make_inputs(batch, nheads, dim, dstate, state_dtype, weight_dtype) + y_ref, state_ref = self.make_reference_output(inputs) # Prepare output tensor if requested if use_out_tensor: - batch = inputs["x"].shape[0] - nheads = inputs["x"].shape[1] - dim = inputs["x"].shape[2] out = torch.empty(batch, nheads, dim, dtype=self.INPUT_DTYPE, device="cuda") else: out = None @@ -193,32 +186,7 @@ def test_output_correctness(self, inputs, reference_output, use_out_tensor): class TestSelectiveStateUpdateWithZ(TestSelectiveStateUpdate): """Test selective_state_update with z tensor (gating).""" - @pytest.fixture(params=[1]) - def batch(self, request): - return request.param - - @pytest.fixture(params=[8]) - def nheads(self, request): - return request.param - - @pytest.fixture(params=[64]) - def dim(self, request): - return request.param - - @pytest.fixture(params=[128]) - def dstate(self, request): - return request.param - - @pytest.fixture(params=[torch.bfloat16]) - def state_dtype(self, request): - return request.param - - @pytest.fixture(params=[torch.bfloat16]) - def 
weight_dtype(self, request): - return request.param - - @pytest.fixture - def inputs(self, batch, nheads, dim, dstate, state_dtype, weight_dtype): + def make_inputs(self, batch, nheads, dim, dstate, state_dtype, weight_dtype): """Create test inputs with z tensor.""" return create_test_inputs( batch, @@ -234,34 +202,21 @@ def inputs(self, batch, nheads, dim, dstate, state_dtype, weight_dtype): seed=0, ) + @pytest.mark.parametrize( + "batch,nheads,dim,dstate,state_dtype,weight_dtype,use_out_tensor", + [(64, 64, 64, 128, torch.bfloat16, torch.float32, True)], + ) + def test_output_correctness( + self, batch, nheads, dim, dstate, state_dtype, weight_dtype, use_out_tensor + ): + super().test_output_correctness( + batch, nheads, dim, dstate, state_dtype, weight_dtype, use_out_tensor + ) + class TestSelectiveStateUpdateDisableStateUpdate(TestSelectiveStateUpdate): """Test selective_state_update with disable_state_update=True.""" - @pytest.fixture(params=[1]) - def batch(self, request): - return request.param - - @pytest.fixture(params=[8]) - def nheads(self, request): - return request.param - - @pytest.fixture(params=[128]) - def dim(self, request): - return request.param - - @pytest.fixture(params=[64]) - def dstate(self, request): - return request.param - - @pytest.fixture(params=[torch.bfloat16]) - def state_dtype(self, request): - return request.param - - @pytest.fixture(params=[torch.bfloat16]) - def weight_dtype(self, request): - return request.param - def run_kernel(self, inputs, out=None): """Run the flashinfer kernel with disable_state_update=True.""" return flashinfer.mamba.selective_state_update( @@ -281,18 +236,22 @@ def run_kernel(self, inputs, out=None): disable_state_update=True, ) - def test_output_correctness(self, inputs, reference_output, use_out_tensor): + @pytest.mark.parametrize( + "batch,nheads,dim,dstate,state_dtype,weight_dtype,use_out_tensor", + [(64, 64, 64, 128, torch.bfloat16, torch.float32, True)], + ) + def test_output_correctness( + self, 
batch, nheads, dim, dstate, state_dtype, weight_dtype, use_out_tensor + ): """Test that kernel output matches reference but state is not updated.""" - y_ref, state_ref = reference_output + inputs = self.make_inputs(batch, nheads, dim, dstate, state_dtype, weight_dtype) + y_ref, state_ref = self.make_reference_output(inputs) # Save the initial state before running the kernel state_initial = inputs["state_cache"].clone() # Prepare output tensor if requested if use_out_tensor: - batch = inputs["x"].shape[0] - nheads = inputs["x"].shape[1] - dim = inputs["x"].shape[2] out = torch.empty(batch, nheads, dim, dtype=self.INPUT_DTYPE, device="cuda") else: out = None @@ -339,20 +298,7 @@ def test_output_correctness(self, inputs, reference_output, use_out_tensor): class TestSelectiveStateUpdateNonContiguous(TestSelectiveStateUpdate): """Test selective_state_update with non-contiguous state cache.""" - @pytest.fixture(params=[128]) - def dstate(self, request): - return request.param - - @pytest.fixture(params=[torch.bfloat16]) - def state_dtype(self, request): - return request.param - - @pytest.fixture(params=[torch.float32]) - def weight_dtype(self, request): - return request.param - - @pytest.fixture - def inputs(self, batch, nheads, dim, dstate, state_dtype, weight_dtype): + def make_inputs(self, batch, nheads, dim, dstate, state_dtype, weight_dtype): """Create test inputs with non-contiguous state cache (2x batch stride).""" noncontiguous_batch_stride = 2 * nheads * dim * dstate @@ -371,8 +317,7 @@ def inputs(self, batch, nheads, dim, dstate, state_dtype, weight_dtype): seed=0, ) - @pytest.fixture - def reference_output(self, inputs): + def make_reference_output(self, inputs): """Compute reference output, preserving non-contiguous strides.""" state_ref = clone_preserving_strides(inputs["state_cache"]) y_ref = selective_state_update_triton( @@ -391,34 +336,21 @@ def reference_output(self, inputs): ) return y_ref, state_ref + @pytest.mark.parametrize( + 
"batch,nheads,dim,dstate,state_dtype,weight_dtype,use_out_tensor", + [(64, 64, 64, 128, torch.bfloat16, torch.float32, True)], + ) + def test_output_correctness( + self, batch, nheads, dim, dstate, state_dtype, weight_dtype, use_out_tensor + ): + super().test_output_correctness( + batch, nheads, dim, dstate, state_dtype, weight_dtype, use_out_tensor + ) + class TestSelectiveStateUpdateInt32Indices(TestSelectiveStateUpdate): """Test selective_state_update with int32 state_batch_indices.""" - @pytest.fixture(params=[1]) - def batch(self, request): - return request.param - - @pytest.fixture(params=[8]) - def nheads(self, request): - return request.param - - @pytest.fixture(params=[64]) - def dim(self, request): - return request.param - - @pytest.fixture(params=[128]) - def dstate(self, request): - return request.param - - @pytest.fixture(params=[torch.bfloat16]) - def state_dtype(self, request): - return request.param - - @pytest.fixture(params=[torch.bfloat16]) - def weight_dtype(self, request): - return request.param - def run_kernel(self, inputs, out=None): """Run the flashinfer kernel with int32 state_batch_indices.""" # Cast slot_idx to int32 @@ -440,6 +372,17 @@ def run_kernel(self, inputs, out=None): out=out, ) + @pytest.mark.parametrize( + "batch,nheads,dim,dstate,state_dtype,weight_dtype,use_out_tensor", + [(64, 64, 64, 128, torch.bfloat16, torch.float32, True)], + ) + def test_output_correctness( + self, batch, nheads, dim, dstate, state_dtype, weight_dtype, use_out_tensor + ): + super().test_output_correctness( + batch, nheads, dim, dstate, state_dtype, weight_dtype, use_out_tensor + ) + class TestSelectiveStateUpdateDtypeMismatch: """Test that selective_state_update fails with dtype mismatch between D and dt.""" From afc7c6a05bb8b73d89f2b38982f537714454e15b Mon Sep 17 00:00:00 2001 From: Igor Shovkun Date: Wed, 18 Feb 2026 21:26:18 -0800 Subject: [PATCH 21/33] Add algorithm parameter to selective_state_update tests - Parametrize tests to run with all 
supported algorithms - Update test logic to pass algorithm argument through - Improve test output messages to include algorithm name - Add utility to detect available algorithms based on GPU arch --- .../kernel_selective_state_update_stp.cuh | 230 +++++++++++------- .../mamba/selective_state_update.cuh | 15 ++ .../mamba/test_selective_state_update_stp.py | 113 +++++++-- 3 files changed, 253 insertions(+), 105 deletions(-) diff --git a/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh b/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh index 474491091a..58d43ab2af 100644 --- a/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh +++ b/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh @@ -57,17 +57,21 @@ __device__ __forceinline__ int conflict_free_column(int group, int baseCol) { return (baseCol + stateValuesPerBank * bankCycle) % colsPerStage; } -template +template struct SharedStorageSimple { - alignas(alignof(PackedAligned)) input_t x[dim]; - alignas(alignof(PackedAligned)) input_t z[dim]; + alignas(alignof(PackedAligned)) input_t x[rows_per_block]; + alignas(alignof(PackedAligned)) input_t z[rows_per_block]; alignas(alignof(PackedAligned)) input_t B[dstate]; alignas(alignof(PackedAligned)) input_t C[dstate]; - float out[dim]; + float out[rows_per_block]; }; +// Grid: (batch, nheads, cdiv(DIM, ROWS_PER_BLOCK)) +// When ROWS_PER_BLOCK == DIM, degenerates to the non-tiled case (blockIdx.z == 0 always). +// Used when batch*nheads is too small to saturate the GPU: set ROWS_PER_BLOCK < DIM to +// split dim across blocks for better occupancy. 
template + typename stateIndex_t, int DIM, int DSTATE, int ROWS_PER_BLOCK, int numWarps> __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams params) { auto* __restrict__ output = reinterpret_cast(params.output); auto* __restrict__ state = reinterpret_cast(params.state); @@ -77,8 +81,8 @@ __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams auto const* __restrict__ A = reinterpret_cast(params.A); auto const* __restrict__ B = reinterpret_cast(params.B); auto const* __restrict__ C = reinterpret_cast(params.C); - auto const* __restrict__ D = reinterpret_cast(params.D); // D: (nheads, dim) - auto const* __restrict__ dt_bias = reinterpret_cast(params.dt_bias); // (nheads) + auto const* __restrict__ D = reinterpret_cast(params.D); + auto const* __restrict__ dt_bias = reinterpret_cast(params.dt_bias); auto const* __restrict__ z = reinterpret_cast(params.z); auto const* __restrict__ state_batch_indices = reinterpret_cast(params.state_batch_indices); @@ -87,10 +91,11 @@ __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams int const nheads = params.nheads; int const ngroups = params.ngroups; - constexpr auto rowsPerWarp = (DIM + numWarps - 1) / numWarps; + constexpr auto rowsPerWarp = (ROWS_PER_BLOCK + numWarps - 1) / numWarps; auto const batch = blockIdx.x; auto const head = blockIdx.y; + auto const dim_offset = blockIdx.z * ROWS_PER_BLOCK; auto const group = head / (nheads / ngroups); auto lane = threadIdx.x % warpSize; auto warp = threadIdx.y; @@ -98,7 +103,7 @@ __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams auto const state_batch = (state_batch_indices) ? 
state_batch_indices[batch] : batch; state += state_batch * params.state_stride_batch + head * DIM * DSTATE; - __shared__ SharedStorageSimple sram; + __shared__ SharedStorageSimple sram; static constexpr auto stateLoadSize = getVectorLoadSizeForFullUtilization(); using load_state_t = PackedAligned; @@ -116,23 +121,21 @@ __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams auto d_value = D ? toFloat(D[head]) : 0.f; + // Load x slice and B (warp 0), z slice and C (warp 1) if (warp == 0) { - for (auto d = lane * load_input_t::count; d < DIM; d += warpSize * load_input_t::count) { - auto* dst = reinterpret_cast(&sram.x[d]); - *dst = *reinterpret_cast( - &x[batch * params.x_stride_batch + head * DIM + d]); + for (auto d = lane; d < ROWS_PER_BLOCK; d += warpSize) { + if (dim_offset + d < DIM) + sram.x[d] = x[batch * params.x_stride_batch + head * DIM + dim_offset + d]; } for (auto i = lane * load_input_t::count; i < DSTATE; i += warpSize * load_input_t::count) { auto* dst = reinterpret_cast(&sram.B[i]); *dst = *reinterpret_cast( &B[batch * params.B_stride_batch + group * DSTATE + i]); } - } else if (warp == 1) { // Load z, C - for (auto d = lane * load_input_t::count; d < DIM; d += warpSize * load_input_t::count) { - auto* dst = reinterpret_cast(&sram.z[d]); - *dst = z ? *reinterpret_cast( - &z[batch * params.z_stride_batch + head * DIM + d]) - : make_zeros(); + } else if (warp == 1) { + for (auto d = lane; d < ROWS_PER_BLOCK; d += warpSize) { + if (dim_offset + d < DIM) + sram.z[d] = z ? 
z[batch * params.z_stride_batch + head * DIM + dim_offset + d] : input_t(0); } for (auto i = lane * load_input_t::count; i < DSTATE; i += warpSize * load_input_t::count) { auto* dst = reinterpret_cast(&sram.C[i]); @@ -143,11 +146,11 @@ __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams __syncthreads(); for (auto _d = warp * rowsPerWarp; _d < (warp + 1) * rowsPerWarp; _d++) { - auto d = _d; + auto d = dim_offset + _d; if (d >= DIM) break; float x_value = toFloat(sram.x[_d]); - float out_value = d_value * x_value * int(lane == 0); // first lane has the value + float out_value = d_value * x_value * int(lane == 0); for (int i = lane * load_state_t::count; i < DSTATE; i += warpSize * load_state_t::count) { auto rState = make_zeros(); @@ -170,7 +173,6 @@ __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams *reinterpret_cast(&state[d * DSTATE + i]) = rState; } - // warpReduce the out_value out_value = warpReduceSum(out_value); if (lane == 0) { sram.out[_d] = out_value; @@ -180,11 +182,12 @@ __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams __syncthreads(); for (int l = lane; l < rowsPerWarp; l += warpSize) { - auto d = warp * rowsPerWarp + l; + auto _d = warp * rowsPerWarp + l; + auto d = dim_offset + _d; if (d < DIM) { - auto out_value = sram.out[d]; + auto out_value = sram.out[_d]; if (z) { - float z_value = toFloat(sram.z[d]); + float z_value = toFloat(sram.z[_d]); float sig_z = __fdividef(1.f, (1.f + __expf(0.f - z_value))); float silu_z = z_value * sig_z; out_value *= silu_z; @@ -210,10 +213,14 @@ struct SharedStorageVertical { barrier_t bar_consumers; }; -template +template __device__ __forceinline__ void producer_func_vertical(SramT& sram, CUtensorMap const& tensorState, - int batch, int head) { + input_t const* x_global_ptr, + input_t const* B_global_ptr, + input_t const* C_global_ptr, + input_t const* z_global_ptr, int batch, + int head) { #ifdef FLASHINFER_MAMBA_ENABLE_SM90 
namespace cde = cuda::device::experimental; @@ -222,11 +229,44 @@ __device__ __forceinline__ void producer_func_vertical(SramT& sram, CUtensorMap auto constexpr stagesWriteOnly = numStages; auto constexpr bytesState = rowsPerStage * DSTATE * sizeof(state_t); - auto constexpr bytesToArrive = bytesState; + auto constexpr bytesX = DIM * sizeof(input_t); + auto constexpr bytesB = DSTATE * sizeof(input_t); + auto constexpr bytesC = DSTATE * sizeof(input_t); + auto constexpr bytesZ = hasZ ? DIM * sizeof(input_t) : 0; + auto constexpr bytesInputs = bytesX + bytesB + bytesC + bytesZ; + + // Phase 1, iter 0: fire all input vector loads + state load (if readState) + // All inputs piggyback onto bar_full[0] so consumers get them before stage 0 + { + constexpr auto stage = 0; + constexpr auto d = 0; - // Phase 1: Read only (filling the pipeline) + sram.bar_empty[stage].wait(sram.bar_empty[stage].arrive()); + + cuda::device::memcpy_async_tx(&sram.x[0], x_global_ptr, cuda::aligned_size_t<16>(bytesX), + sram.bar_full[stage]); + cuda::device::memcpy_async_tx(&sram.B[0], B_global_ptr, cuda::aligned_size_t<16>(bytesB), + sram.bar_full[stage]); + cuda::device::memcpy_async_tx(&sram.C[0], C_global_ptr, cuda::aligned_size_t<16>(bytesC), + sram.bar_full[stage]); + if constexpr (hasZ) { + cuda::device::memcpy_async_tx(&sram.z[0], z_global_ptr, cuda::aligned_size_t<16>(bytesZ), + sram.bar_full[stage]); + } + + if constexpr (readState) { + cde::cp_async_bulk_tensor_4d_global_to_shared(&sram.state[stage][0], &tensorState, 0, d, head, + batch, sram.bar_full[stage]); + auto const _ = + cuda::device::barrier_arrive_tx(sram.bar_full[stage], 1, bytesState + bytesInputs); + } else { + auto const _ = cuda::device::barrier_arrive_tx(sram.bar_full[stage], 1, bytesInputs); + } + } + + // Phase 1, iter 1..stagesReadOnly-1: state only (x already in flight) #pragma unroll - for (int iter = 0; iter < stagesReadOnly; ++iter) { + for (int iter = 1; iter < stagesReadOnly; ++iter) { auto const stage = iter % 
numStages; auto const d = iter * rowsPerStage; @@ -235,8 +275,7 @@ __device__ __forceinline__ void producer_func_vertical(SramT& sram, CUtensorMap if constexpr (readState) { cde::cp_async_bulk_tensor_4d_global_to_shared(&sram.state[stage][0], &tensorState, 0, d, head, batch, sram.bar_full[stage]); - - auto const _ = cuda::device::barrier_arrive_tx(sram.bar_full[stage], 1, bytesToArrive); + auto const _ = cuda::device::barrier_arrive_tx(sram.bar_full[stage], 1, bytesState); } else { auto const _ = sram.bar_full[stage].arrive(); } @@ -267,7 +306,7 @@ __device__ __forceinline__ void producer_func_vertical(SramT& sram, CUtensorMap if constexpr (readState) { cde::cp_async_bulk_tensor_4d_global_to_shared(&sram.state[stage][0], &tensorState, 0, d_read, head, batch, sram.bar_full[stage]); - auto const _ = cuda::device::barrier_arrive_tx(sram.bar_full[stage], 1, bytesToArrive); + auto const _ = cuda::device::barrier_arrive_tx(sram.bar_full[stage], 1, bytesState); } else { auto const _ = sram.bar_full[stage].arrive(); } @@ -417,7 +456,7 @@ __global__ void selective_state_update_kernel_producer_consumer_vertical( auto lane = threadIdx.x % warpSize; auto warp = threadIdx.y; - auto const state_batch = (state_batch_indices) ? state_batch_indices[batch] : batch; + auto const state_batch = (state_batch_indices) ? __ldg(&state_batch_indices[batch]) : batch; extern __shared__ uint8_t sbuffer[]; using sram_t = SharedStorageVertical() { + producer_func_vertical(sram, tensorState, x_global_ptr, B_global_ptr, + C_global_ptr, hasZ ? 
z_global_ptr : nullptr, + state_batch, head); + }; + auto const dispatch_state = [&]() { if (read_state && write_state) - producer_func_vertical( - sram, tensorState, state_batch, head); - else if (read_state && !write_state) - producer_func_vertical( - sram, tensorState, state_batch, head); + call.template operator()(); + else if (read_state) + call.template operator()(); else - producer_func_vertical( - sram, tensorState, state_batch, head); + call.template operator()(); + }; + + cg::invoke_one(cg::coalesced_threads(), [&]() { + if (z_global_ptr) + dispatch_state.template operator()(); + else + dispatch_state.template operator()(); }); } } else { // consumers - using load_t = PackedAligned; - #pragma unroll // Unblock the producer for (uint8_t stage = 0; stage < numStages; ++stage) { auto const _ = sram.bar_empty[stage].arrive(); } - // Load A - auto const A_value = toFloat(A[head]); + // Load A, D, dt, dt_bias via __ldg (read-only texture cache) — + // these are broadcast scalars read once per block. + auto const A_value = toFloat(__ldg(&A[head])); - // Load D - auto const d_value = D ? toFloat(D[head]) : 0.f; + auto const d_value = D ? 
toFloat(__ldg(&D[head])) : 0.f; - // load dt_value - auto dt_value = toFloat(dt[batch * params.dt_stride_batch + head]); - if (dt_bias) dt_value += toFloat(dt_bias[head]); + auto dt_value = toFloat(__ldg(&dt[batch * params.dt_stride_batch + head])); + if (dt_bias) dt_value += toFloat(__ldg(&dt_bias[head])); if (params.dt_softplus) { dt_value = thresholded_softplus(dt_value); } auto const dA = __expf(A_value * dt_value); - if (warp == 0) { // Load x - for (auto d = lane * load_t::count; d < DIM; d += warpSize * load_t::count) { - auto* dst = reinterpret_cast(&sram.x[d]); - *dst = *reinterpret_cast(&x[batch * params.x_stride_batch + head * DIM + d]); - } - } else if (warp == 1) { // Load B - for (auto i = lane * load_t::count; i < DSTATE; i += warpSize * load_t::count) { - auto* dst = reinterpret_cast(&sram.B[i]); - *dst = *reinterpret_cast( - &B[batch * params.B_stride_batch + group * DSTATE + i]); - } - } else if (warp == 2) { // Load z - for (auto d = lane * load_t::count; d < DIM; d += warpSize * load_t::count) { - auto* dst = reinterpret_cast(&sram.z[d]); - *dst = - z ? 
*reinterpret_cast(&z[batch * params.z_stride_batch + head * DIM + d]) - : make_zeros(); - } - } else if (warp == 3) { // Load C - for (auto i = lane * load_t::count; i < DSTATE; i += warpSize * load_t::count) { - auto* dst = reinterpret_cast(&sram.C[i]); - *dst = *reinterpret_cast( - &C[batch * params.C_stride_batch + group * DSTATE + i]); - } - } - - sram.bar_consumers.wait(sram.bar_consumers.arrive()); - if (state_batch != params.pad_slot_id) consumer_func_vertical(lane, warp, d_value, dt_value, dA, @@ -520,7 +542,7 @@ __global__ void selective_state_update_kernel_producer_consumer_vertical( rowsPerStage, numStages, false>(lane, warp, d_value, dt_value, dA, sram); - // Write output + // Write output — wait for all consumer warps to finish writing sram.out sram.bar_consumers.wait(sram.bar_consumers.arrive()); auto d = warp * warpSize + lane; if (d < DIM) { @@ -890,16 +912,28 @@ void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, SSUAlgorithm // Common alignment checks for all kernels check_ptr_alignment_input_vars(params); - // Resolve auto to a concrete algorithm based on GPU architecture + // Resolve auto to a concrete algorithm based on GPU architecture and batch size SSUAlgorithm algo = algorithm; if (algo == SSUAlgorithm::kAuto) { #ifdef FLASHINFER_MAMBA_ENABLE_SM90 - if (sm_major < 9) + if (sm_major < 9) { algo = SSUAlgorithm::kSimple; - else if (sm_major < 10) - algo = SSUAlgorithm::kVertical; - else - algo = SSUAlgorithm::kHorizontal; + } else { + // At small batch sizes, the tiled simple kernel outperforms producer-consumer + // kernels because it has lower per-block overhead and can still saturate the GPU + // via dim-tiling. Threshold: batch*nheads < 2*num_SMs (i.e. not enough blocks + // for the non-tiled producer-consumer kernels to hide latency). 
+ int const total_blocks = params.batch * params.nheads; + int const num_sms = GetCudaMultiProcessorCount(); + if (total_blocks < num_sms * 2) + algo = SSUAlgorithm::kSimple; + else if (sm_major < 10) + algo = SSUAlgorithm::kVertical; + else + // On Blackwell+: vertical is slightly faster for fp32 state, + // horizontal is faster for fp16/bf16 state. + algo = (sizeof(state_t) == 4) ? SSUAlgorithm::kVertical : SSUAlgorithm::kHorizontal; + } #else algo = SSUAlgorithm::kSimple; #endif @@ -915,10 +949,26 @@ void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, SSUAlgorithm "state head stride must be aligned to ", sizeof(load_state_t), " bytes"); constexpr int numWarps = 4; + constexpr int ROWS_PER_BLOCK = 4; + int const total_blocks = params.batch * params.nheads; + int const num_sms = GetCudaMultiProcessorCount(); + dim3 block(warpSize, numWarps); - dim3 grid(params.batch, params.nheads); - selective_state_update_kernel_simple<<>>(params); + if (total_blocks < num_sms * 2) { + // Tiled: split dim across blocks for better GPU occupancy at small batch sizes + int const dim_tiles = (DIM + ROWS_PER_BLOCK - 1) / ROWS_PER_BLOCK; + dim3 grid(params.batch, params.nheads, dim_tiles); + selective_state_update_kernel_simple + <<>>(params); + } else { + // Non-tiled: enough blocks already for full occupancy; ROWS_PER_BLOCK == DIM so blockIdx.z == + // 0 + dim3 grid(params.batch, params.nheads); + selective_state_update_kernel_simple + <<>>(params); + } } #ifdef FLASHINFER_MAMBA_ENABLE_SM90 else if (algo == SSUAlgorithm::kVertical) { diff --git a/include/flashinfer/mamba/selective_state_update.cuh b/include/flashinfer/mamba/selective_state_update.cuh index b3517f848f..0607d7a0f2 100644 --- a/include/flashinfer/mamba/selective_state_update.cuh +++ b/include/flashinfer/mamba/selective_state_update.cuh @@ -30,6 +30,21 @@ enum class SSUAlgorithm : int32_t { kHorizontal = 3, }; +inline const char* SSUAlgorithmToString(SSUAlgorithm algo) { + switch (algo) { + case 
SSUAlgorithm::kAuto: + return "Auto"; + case SSUAlgorithm::kSimple: + return "Simple"; + case SSUAlgorithm::kVertical: + return "Vertical"; + case SSUAlgorithm::kHorizontal: + return "Horizontal"; + default: + return "Unknown"; + } +} + struct SelectiveStateUpdateParams { uint32_t batch{}, nheads{}, dim{}, dstate{}, ngroups{}, state_cache_size{}; int32_t pad_slot_id{-1}; diff --git a/tests/mamba/test_selective_state_update_stp.py b/tests/mamba/test_selective_state_update_stp.py index ef007eddd1..6b8b293d07 100644 --- a/tests/mamba/test_selective_state_update_stp.py +++ b/tests/mamba/test_selective_state_update_stp.py @@ -3,11 +3,21 @@ import torch import flashinfer +from flashinfer.utils import get_compute_capability from .selective_state_update_triton import selective_state_update_triton from .utils import create_test_inputs, clone_preserving_strides +def _get_algorithms(): + """Return list of algorithms supported on the current GPU.""" + major, _ = get_compute_capability(torch.device("cuda")) + algos = ["simple"] + if major >= 9: + algos.extend(["vertical", "horizontal"]) + return algos + + # Base combination: batch=64, nheads=64, dim=64, dstate=128, state_dtype=bf16, # weight_dtype=f32, use_out_tensor=True # Each additional row varies exactly one parameter from the base. 
@@ -73,7 +83,7 @@ def make_reference_output(self, inputs): ) return y_ref, state_ref - def run_kernel(self, inputs, out=None): + def run_kernel(self, inputs, out=None, algorithm="auto"): """Run the flashinfer kernel and return output.""" return flashinfer.mamba.selective_state_update( inputs["state_cache"], @@ -89,6 +99,7 @@ def run_kernel(self, inputs, out=None): state_batch_indices=inputs["slot_idx"], pad_slot_id=-1, out=out, + algorithm=algorithm, ) def assert_outputs_match(self, y_ref, y_test, msg_prefix=""): @@ -155,11 +166,20 @@ def _print_mismatch_details(self, ref, test, name): f"diff={diff:.6e}, rel_diff={rel_diff:.6e}" ) + @pytest.mark.parametrize("algorithm", _get_algorithms()) @pytest.mark.parametrize( "batch,nheads,dim,dstate,state_dtype,weight_dtype,use_out_tensor", _BASE_PARAMS ) def test_output_correctness( - self, batch, nheads, dim, dstate, state_dtype, weight_dtype, use_out_tensor + self, + batch, + nheads, + dim, + dstate, + state_dtype, + weight_dtype, + use_out_tensor, + algorithm, ): """Test that kernel output matches reference within tolerance.""" inputs = self.make_inputs(batch, nheads, dim, dstate, state_dtype, weight_dtype) @@ -171,7 +191,7 @@ def test_output_correctness( else: out = None - y_test = self.run_kernel(inputs, out=out) + y_test = self.run_kernel(inputs, out=out, algorithm=algorithm) # Verify output tensor identity if provided if use_out_tensor: @@ -179,8 +199,13 @@ def test_output_correctness( "Returned tensor should be the same object as the provided output tensor" ) - self.assert_outputs_match(y_ref, y_test) - self.assert_states_match(state_ref, inputs["state_cache"], inputs["slot_idx"]) + self.assert_outputs_match(y_ref, y_test, msg_prefix=f"[{algorithm}] ") + self.assert_states_match( + state_ref, + inputs["state_cache"], + inputs["slot_idx"], + msg_prefix=f"[{algorithm}] ", + ) class TestSelectiveStateUpdateWithZ(TestSelectiveStateUpdate): @@ -202,22 +227,38 @@ def make_inputs(self, batch, nheads, dim, dstate, 
state_dtype, weight_dtype): seed=0, ) + @pytest.mark.parametrize("algorithm", _get_algorithms()) @pytest.mark.parametrize( "batch,nheads,dim,dstate,state_dtype,weight_dtype,use_out_tensor", [(64, 64, 64, 128, torch.bfloat16, torch.float32, True)], ) def test_output_correctness( - self, batch, nheads, dim, dstate, state_dtype, weight_dtype, use_out_tensor + self, + batch, + nheads, + dim, + dstate, + state_dtype, + weight_dtype, + use_out_tensor, + algorithm, ): super().test_output_correctness( - batch, nheads, dim, dstate, state_dtype, weight_dtype, use_out_tensor + batch, + nheads, + dim, + dstate, + state_dtype, + weight_dtype, + use_out_tensor, + algorithm, ) class TestSelectiveStateUpdateDisableStateUpdate(TestSelectiveStateUpdate): """Test selective_state_update with disable_state_update=True.""" - def run_kernel(self, inputs, out=None): + def run_kernel(self, inputs, out=None, algorithm="auto"): """Run the flashinfer kernel with disable_state_update=True.""" return flashinfer.mamba.selective_state_update( inputs["state_cache"], @@ -234,6 +275,7 @@ def run_kernel(self, inputs, out=None): pad_slot_id=-1, out=out, disable_state_update=True, + algorithm=algorithm, ) @pytest.mark.parametrize( @@ -241,7 +283,15 @@ def run_kernel(self, inputs, out=None): [(64, 64, 64, 128, torch.bfloat16, torch.float32, True)], ) def test_output_correctness( - self, batch, nheads, dim, dstate, state_dtype, weight_dtype, use_out_tensor + self, + batch, + nheads, + dim, + dstate, + state_dtype, + weight_dtype, + use_out_tensor, + algorithm="auto", ): """Test that kernel output matches reference but state is not updated.""" inputs = self.make_inputs(batch, nheads, dim, dstate, state_dtype, weight_dtype) @@ -256,7 +306,7 @@ def test_output_correctness( else: out = None - y_test = self.run_kernel(inputs, out=out) + y_test = self.run_kernel(inputs, out=out, algorithm=algorithm) # Verify output tensor identity if provided if use_out_tensor: @@ -336,22 +386,38 @@ def 
make_reference_output(self, inputs): ) return y_ref, state_ref + @pytest.mark.parametrize("algorithm", _get_algorithms()) @pytest.mark.parametrize( "batch,nheads,dim,dstate,state_dtype,weight_dtype,use_out_tensor", [(64, 64, 64, 128, torch.bfloat16, torch.float32, True)], ) def test_output_correctness( - self, batch, nheads, dim, dstate, state_dtype, weight_dtype, use_out_tensor + self, + batch, + nheads, + dim, + dstate, + state_dtype, + weight_dtype, + use_out_tensor, + algorithm, ): super().test_output_correctness( - batch, nheads, dim, dstate, state_dtype, weight_dtype, use_out_tensor + batch, + nheads, + dim, + dstate, + state_dtype, + weight_dtype, + use_out_tensor, + algorithm, ) class TestSelectiveStateUpdateInt32Indices(TestSelectiveStateUpdate): """Test selective_state_update with int32 state_batch_indices.""" - def run_kernel(self, inputs, out=None): + def run_kernel(self, inputs, out=None, algorithm="auto"): """Run the flashinfer kernel with int32 state_batch_indices.""" # Cast slot_idx to int32 slot_idx_int32 = inputs["slot_idx"].to(torch.int32) @@ -370,17 +436,34 @@ def run_kernel(self, inputs, out=None): state_batch_indices=slot_idx_int32, pad_slot_id=-1, out=out, + algorithm=algorithm, ) + @pytest.mark.parametrize("algorithm", _get_algorithms()) @pytest.mark.parametrize( "batch,nheads,dim,dstate,state_dtype,weight_dtype,use_out_tensor", [(64, 64, 64, 128, torch.bfloat16, torch.float32, True)], ) def test_output_correctness( - self, batch, nheads, dim, dstate, state_dtype, weight_dtype, use_out_tensor + self, + batch, + nheads, + dim, + dstate, + state_dtype, + weight_dtype, + use_out_tensor, + algorithm, ): super().test_output_correctness( - batch, nheads, dim, dstate, state_dtype, weight_dtype, use_out_tensor + batch, + nheads, + dim, + dstate, + state_dtype, + weight_dtype, + use_out_tensor, + algorithm, ) From 1d42007ad43cf1d6a34a40f12b1f9341f283f77f Mon Sep 17 00:00:00 2001 From: Igor Shovkun Date: Wed, 18 Feb 2026 22:01:44 -0800 Subject: [PATCH 
22/33] Update selective_state_update instantiations to include SSUAlgorithm param --- csrc/selective_state_update_dtype_inst.jinja | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/selective_state_update_dtype_inst.jinja b/csrc/selective_state_update_dtype_inst.jinja index 4879afaed2..02dd66322b 100644 --- a/csrc/selective_state_update_dtype_inst.jinja +++ b/csrc/selective_state_update_dtype_inst.jinja @@ -22,11 +22,11 @@ namespace flashinfer::mamba { template void invokeSelectiveStateUpdate<{{ input_dtype }}, {{ weight_dtype }}, {{ matrixA_dtype }}, {{ state_dtype }}, {{ stateIndex_dtype }}>( - SelectiveStateUpdateParams& params, cudaStream_t stream); + SelectiveStateUpdateParams& params, SSUAlgorithm algo, cudaStream_t stream); namespace mtp { template void invokeSelectiveStateUpdateMTP<{{ input_dtype }}, {{ weight_dtype }}, {{ matrixA_dtype }}, {{ state_dtype }}, {{ stateIndex_dtype }}>( - SelectiveStateMTPParams& params, cudaStream_t stream); + SelectiveStateMTPParams& params, SSUAlgorithm algo, cudaStream_t stream); } // namespace mtp } // namespace flashinfer::mamba From 61d88bd26387bd9bd43886ab383c1640d5c9932c Mon Sep 17 00:00:00 2001 From: Igor Shovkun Date: Wed, 18 Feb 2026 22:07:40 -0800 Subject: [PATCH 23/33] Clarify algorithm selection docstring in selective_state_update --- flashinfer/mamba/selective_state_update.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flashinfer/mamba/selective_state_update.py b/flashinfer/mamba/selective_state_update.py index 25f47cb68b..294330be88 100644 --- a/flashinfer/mamba/selective_state_update.py +++ b/flashinfer/mamba/selective_state_update.py @@ -149,9 +149,9 @@ def selective_state_update( cache_steps : int Number of steps/tokens to cache for speculative decoding algorithm : str - Algorithm to use: "auto" (default, selects based on GPU arch), - "simple" (all GPUs), "vertical" (SM90+), "horizontal" (SM100+). - MTP mode only supports "auto" or "simple". 
+ Algorithm to use: "auto" (default, picks the best kernel based on GPU arch, + data types, and problem size), "simple" (all GPUs), "vertical" and "horizontal" + (SM90+ only). MTP mode only supports "auto" or "simple". Returns ------- From 6f6a3d78021e80f076a2f8d6e5d4f359a3f2ebbe Mon Sep 17 00:00:00 2001 From: Igor Shovkun Date: Wed, 18 Feb 2026 22:20:24 -0800 Subject: [PATCH 24/33] Remove chunk scan combined kernels as they are irrelevant to this PR --- tests/mamba/test_chunk_scan_combined.py | 703 ------------------ tests/mamba/triton_reference/ssd_bmm.py | 272 ------- .../mamba/triton_reference/ssd_chunk_scan.py | 628 ---------------- tests/mamba/triton_reference/ssd_combined.py | 265 ------- .../triton_reference/ssd_state_passing.py | 282 ------- 5 files changed, 2150 deletions(-) delete mode 100644 tests/mamba/test_chunk_scan_combined.py delete mode 100644 tests/mamba/triton_reference/ssd_bmm.py delete mode 100644 tests/mamba/triton_reference/ssd_chunk_scan.py delete mode 100644 tests/mamba/triton_reference/ssd_combined.py delete mode 100644 tests/mamba/triton_reference/ssd_state_passing.py diff --git a/tests/mamba/test_chunk_scan_combined.py b/tests/mamba/test_chunk_scan_combined.py deleted file mode 100644 index dd51b8de58..0000000000 --- a/tests/mamba/test_chunk_scan_combined.py +++ /dev/null @@ -1,703 +0,0 @@ -""" -Test for Mamba2 SSD (Structured State-Space Duality) chunk scan combined kernel. - -Compares the CUTLASS CuTe DSL Blackwell implementation against the production -Triton implementation. 
-""" - -import sys -from pathlib import Path - -import numpy as np -import pytest -import torch - -# Add CUTLASS mamba2_ssd path -CUTLASS_MAMBA2_SSD_PATH = ( - Path(__file__).resolve().parents[2] - / "3rdparty" - / "cutlass" - / "examples" - / "python" - / "CuTeDSL" - / "blackwell" - / "mamba2_ssd" -) -sys.path.insert(0, str(CUTLASS_MAMBA2_SSD_PATH)) - -# Import Triton reference -from .triton_reference.ssd_chunk_state import _chunk_cumsum_fwd -from .triton_reference.ssd_combined import _mamba_chunk_scan_combined_fwd - - -def is_blackwell_available(): - """Check if Blackwell GPU (SM100) is available.""" - if not torch.cuda.is_available(): - return False - major, minor = torch.cuda.get_device_capability() - return major >= 10 # SM100 = Blackwell - - -# Skip all tests if not on Blackwell -pytestmark = pytest.mark.skipif( - not is_blackwell_available(), - reason="Blackwell GPU (SM100+) required for CuTe DSL Mamba2 SSD kernel", -) - - -def import_cutlass_modules(): - """Import CUTLASS modules (only when needed, as they require SM100).""" - import cuda.bindings.driver as cuda - import cutlass - import cutlass.cute as cute - import cutlass.torch as cutlass_torch - from cutlass.cute.runtime import from_dlpack - from mamba2_ssd import SSDKernel - - return { - "cuda": cuda, - "cutlass": cutlass, - "cute": cute, - "cutlass_torch": cutlass_torch, - "from_dlpack": from_dlpack, - "SSDKernel": SSDKernel, - } - - -class CutlassSSDWrapper: - """ - Wrapper around CUTLASS CuTe DSL SSD kernel to match Triton API. - - The CUTLASS kernel expects: - - Preprocessed cumsum_delta (step 1 already done) - - Specific tensor layouts different from Triton - - Output tensors preallocated - - This wrapper: - 1. Computes cumsum using Triton's step 1 (_chunk_cumsum_fwd) - 2. Converts tensors to CUTLASS layout - 3. Calls CUTLASS kernel - 4. 
Converts output back to Triton layout - """ - - def __init__( - self, - chunk_size: int, - headdim: int, - dstate: int, - has_d: bool = True, - d_has_hdim: bool = False, - io_dtype=None, - cumsum_dtype=None, - acc_dtype=None, - ): - """ - Initialize the wrapper. - - Args: - chunk_size: L - size of each chunk - headdim: D - head dimension - dstate: N - state dimension - has_d: Whether to fuse D scaling (Y += X*D) - d_has_hdim: If True, D is (headdim, nheads), else (1, nheads) - io_dtype: Input/output dtype (default: cutlass.BFloat16) - cumsum_dtype: Cumsum intermediate dtype (default: cutlass.Float32) - acc_dtype: Accumulator dtype (default: cutlass.Float32) - """ - self.modules = import_cutlass_modules() - cutlass = self.modules["cutlass"] - - self.chunk_size = chunk_size - self.headdim = headdim - self.dstate = dstate - self.has_d = has_d - self.d_has_hdim = d_has_hdim - - self.io_dtype = io_dtype or cutlass.BFloat16 - self.cumsum_dtype = cumsum_dtype or cutlass.Float32 - self.acc_dtype = acc_dtype or cutlass.Float32 - - # Create the kernel - SSDKernel = self.modules["SSDKernel"] - self.kernel = SSDKernel( - self.io_dtype, - self.cumsum_dtype, - self.acc_dtype, - chunk_size, - headdim, - dstate, - has_d, - d_has_hdim, - ) - - self._compiled_kernel = None - - def _create_cutlass_tensor(self, shape, permute_order, dtype, dynamic_modes): - """ - Create a tensor using the exact logic from mamba2_ssd.py to ensure compatibility. 
- - Args: - shape: Base shape of the tensor (before permutation) - permute_order: Order to permute dimensions - dtype: CUTLASS dtype - dynamic_modes: List of modes to mark as dynamic - - Returns: - (cute_tensor, torch_tensor): The CuTe tensor wrapper and the underlying PyTorch tensor on GPU - """ - cutlass_torch = self.modules["cutlass_torch"] - from_dlpack = self.modules["from_dlpack"] - - # Create a dummy CPU tensor with the base layout to establish the permutation pattern - # mimicking create_and_permute_tensor from mamba2_ssd.py - base_tensor = torch.empty(*shape, dtype=torch.float32) - permuted_tensor = base_tensor.permute(permute_order) - - # Move to GPU with target dtype - this creates the specific layout CUTLASS expects - torch_dtype = cutlass_torch.dtype(dtype) - dst_tensor = permuted_tensor.to(torch_dtype).cuda() - - # Create CuTe tensor - cute_tensor = from_dlpack(dst_tensor, assumed_align=16) - for mode in dynamic_modes: - cute_tensor = cute_tensor.mark_compact_shape_dynamic( - mode=mode, stride_order=dst_tensor.dim_order() - ) - - return cute_tensor, dst_tensor - - def __call__( - self, - x, - dt, - A, - B, - C, - chunk_size, - D=None, - z=None, - dt_bias=None, - initial_states=None, - seq_idx=None, - dt_softplus=False, - dt_limit=(0.0, float("inf")), - ): - """ - Run the SSD kernel with Triton-compatible API. 
- - Args: - x: (batch, seqlen, nheads, headdim) - dt: (batch, seqlen, nheads) - A: (nheads,) - B: (batch, seqlen, ngroups, dstate) - C: (batch, seqlen, ngroups, dstate) - chunk_size: Size of chunks - D: Optional (nheads, headdim) or (nheads,) - z: Optional gating tensor (not supported yet) - dt_bias: Optional (nheads,) - initial_states: Optional (batch, nheads, headdim, dstate) - seq_idx: Optional sequence indices (not supported yet) - dt_softplus: Whether to apply softplus to dt - dt_limit: Limits for dt values - - Returns: - out: (batch, seqlen, nheads, headdim) - final_states: (batch, nheads, headdim, dstate) - """ - cutlass = self.modules["cutlass"] - cute = self.modules["cute"] - - # Validate inputs - batch, seqlen, nheads, headdim = x.shape - _, _, ngroups, dstate = B.shape - nchunks = seqlen // chunk_size - - assert seqlen % chunk_size == 0, ( - f"seqlen ({seqlen}) must be divisible by chunk_size ({chunk_size})" - ) - assert headdim == self.headdim, f"headdim mismatch: {headdim} vs {self.headdim}" - assert dstate == self.dstate, f"dstate mismatch: {dstate} vs {self.dstate}" - assert chunk_size == self.chunk_size, ( - f"chunk_size mismatch: {chunk_size} vs {self.chunk_size}" - ) - - if z is not None: - raise NotImplementedError("z (gating) not yet supported in CUTLASS wrapper") - if seq_idx is not None: - raise NotImplementedError("seq_idx not yet supported in CUTLASS wrapper") - if initial_states is not None: - raise NotImplementedError( - "initial_states not yet supported in CUTLASS wrapper" - ) - - # Step 1: Compute cumsum using Triton kernel - # dA_cumsum: (batch, nheads, nchunks, chunk_size) - # dt_processed: (batch, nheads, nchunks, chunk_size) - after softplus/bias - dA_cumsum, dt_processed = _chunk_cumsum_fwd( - dt, - A, - chunk_size, - dt_bias=dt_bias, - dt_softplus=dt_softplus, - dt_limit=dt_limit, - ) - - # Convert tensors to CUTLASS layout using the same pattern as mamba2_ssd.py - # Key: create contiguous tensor in base shape, then permute to get 
correct strides - # CUTLASS expects specific permuted layouts for each tensor - - # x: Triton (batch, seqlen, nheads, headdim) -> CUTLASS (headdim, chunk_size, nchunks, nheads, batch) - x_reshaped = x.reshape(batch, nchunks, chunk_size, nheads, headdim) - x_tensor, x_dst = self._create_cutlass_tensor( - [batch, nheads, headdim, nchunks, chunk_size], - [2, 4, 3, 1, 0], - self.io_dtype, - [2, 3, 4], - ) - x_dst.copy_(x_reshaped.permute(4, 2, 1, 3, 0).to(x_dst.dtype)) - - # delta (dt_processed): (batch, nheads, nchunks, chunk_size) -> (chunk_size, nchunks, nheads, batch) - delta_tensor, delta_dst = self._create_cutlass_tensor( - [batch, nheads, nchunks, chunk_size], [3, 2, 1, 0], self.io_dtype, [1, 2, 3] - ) - delta_dst.copy_(dt_processed.permute(3, 2, 1, 0).to(delta_dst.dtype)) - - # cumsum_delta (dA_cumsum): same layout as delta - cumsum_delta_tensor, cumsum_delta_dst = self._create_cutlass_tensor( - [batch, nheads, nchunks, chunk_size], - [3, 2, 1, 0], - self.cumsum_dtype, - [1, 2, 3], - ) - cumsum_delta_dst.copy_(dA_cumsum.permute(3, 2, 1, 0).to(cumsum_delta_dst.dtype)) - - # B: Triton (batch, seqlen, ngroups, dstate) -> CUTLASS (chunk_size, dstate, nchunks, ngroups, batch) - B_reshaped = B.reshape(batch, nchunks, chunk_size, ngroups, dstate) - b_tensor, b_dst = self._create_cutlass_tensor( - [batch, ngroups, dstate, nchunks, chunk_size], - [4, 2, 3, 1, 0], - self.io_dtype, - [2, 3, 4], - ) - b_dst.copy_(B_reshaped.permute(2, 4, 1, 3, 0).to(b_dst.dtype)) - - # C: same layout as B - C_reshaped = C.reshape(batch, nchunks, chunk_size, ngroups, dstate) - c_tensor, c_dst = self._create_cutlass_tensor( - [batch, ngroups, dstate, nchunks, chunk_size], - [4, 2, 3, 1, 0], - self.io_dtype, - [2, 3, 4], - ) - c_dst.copy_(C_reshaped.permute(2, 4, 1, 3, 0).to(c_dst.dtype)) - - # D: (nheads,) -> CUTLASS (1, nheads) or (headdim, nheads) - if self.has_d and D is not None: - if self.d_has_hdim: - # D is (nheads, headdim) -> (headdim, nheads) - if D.dim() == 1: - D = 
D.unsqueeze(1).expand(-1, headdim) - d_tensor, d_dst = self._create_cutlass_tensor( - [nheads, headdim], [1, 0], self.io_dtype, [1] - ) - d_dst.copy_(D.t().to(d_dst.dtype)) - else: - # D is (nheads,) -> (1, nheads) - if D.dim() == 2: - D = D[:, 0] - d_tensor, d_dst = self._create_cutlass_tensor( - [nheads, 1], [1, 0], self.io_dtype, [1] - ) - d_dst.copy_(D.unsqueeze(0).to(d_dst.dtype)) - else: - d_tensor = None - - # Output tensors - # y: (chunk_size, headdim, nchunks, nheads, batch) - y_tensor, y_cutlass = self._create_cutlass_tensor( - [batch, nheads, headdim, nchunks, chunk_size], - [4, 2, 3, 1, 0], - self.io_dtype, - [2, 3, 4], - ) - - # fstate: (headdim, dstate, nheads, batch) - fstate_tensor, fstate_cutlass = self._create_cutlass_tensor( - [batch, nheads, headdim, dstate], [2, 3, 1, 0], self.io_dtype, [2, 3] - ) - - # Get max active clusters - hardware_info = cutlass.utils.HardwareInfo() - max_active_clusters = hardware_info.get_max_active_clusters(1) - - stream = cutlass.cuda.default_stream() - - # Compile kernel if not already done - if self._compiled_kernel is None: - self._compiled_kernel = cute.compile( - self.kernel, - x_tensor, - cumsum_delta_tensor, - delta_tensor, - b_tensor, - c_tensor, - y_tensor, - fstate_tensor, - d_tensor, - max_active_clusters, - stream, - ) - - # Run kernel - self._compiled_kernel( - x_tensor, - cumsum_delta_tensor, - delta_tensor, - b_tensor, - c_tensor, - y_tensor, - fstate_tensor, - d_tensor, - stream, - ) - - # Convert outputs back to Triton layout - # y_cutlass is (L, D, C, EH, B) - # We need to map it back to (batch, seqlen, nheads, headdim) - # Permute (L, D, C, EH, B) -> (B, C, L, EH, D) - y_permuted = y_cutlass.permute(4, 2, 0, 3, 1) - y_out = y_permuted.reshape(batch, seqlen, nheads, headdim) - - # fstate_cutlass is (D, N, EH, B) - # We need (batch, nheads, headdim, dstate) - # Permute (D, N, EH, B) -> (B, EH, D, N) - fstate_out = fstate_cutlass.permute(3, 2, 0, 1).contiguous() - - return y_out, fstate_out - - -class 
TestChunkScanCombined: - """Test class for chunk scan combined kernel.""" - - # Test configuration - slightly relaxed tolerance for bf16 precision - ATOL = 5e-2 - RTOL = 5e-2 - INPUT_DTYPE = torch.bfloat16 - - @pytest.fixture(params=[1, 2]) - def batch(self, request): - return request.param - - @pytest.fixture(params=[8]) # nheads must be divisible by ngroups - def nheads(self, request): - return request.param - - @pytest.fixture(params=[64]) # Must match kernel's D - def headdim(self, request): - return request.param - - @pytest.fixture( - params=[128] - ) # Must match kernel's N (CUTLASS kernel is hardcoded for N=128) - def dstate(self, request): - return request.param - - @pytest.fixture(params=[128]) # Must match kernel's L - def chunk_size(self, request): - return request.param - - @pytest.fixture(params=[1, 4]) # Number of chunks - def nchunks(self, request): - return request.param - - @pytest.fixture(params=[8]) # ngroups divides nheads - def ngroups(self, request): - return request.param - - @pytest.fixture - def inputs(self, batch, nheads, headdim, dstate, chunk_size, nchunks, ngroups): - """Create test inputs.""" - torch.manual_seed(42) - - seqlen = chunk_size * nchunks - - # x: (batch, seqlen, nheads, headdim) - x = torch.randn( - batch, seqlen, nheads, headdim, dtype=self.INPUT_DTYPE, device="cuda" - ) - - # dt: (batch, seqlen, nheads) - dt = torch.randn(batch, seqlen, nheads, dtype=torch.float32, device="cuda") - - # A: (nheads,) - should be negative for stability - A = -torch.rand(nheads, dtype=torch.float32, device="cuda") - 1.0 - - # B: (batch, seqlen, ngroups, dstate) - B = torch.randn( - batch, seqlen, ngroups, dstate, dtype=self.INPUT_DTYPE, device="cuda" - ) - - # C: (batch, seqlen, ngroups, dstate) - C = torch.randn( - batch, seqlen, ngroups, dstate, dtype=self.INPUT_DTYPE, device="cuda" - ) - - # D: (nheads, headdim) or (nheads,) - D = torch.randn(nheads, dtype=self.INPUT_DTYPE, device="cuda") - - # dt_bias: (nheads,) - dt_bias = 
torch.rand(nheads, dtype=torch.float32, device="cuda") - 4.0 - - return { - "x": x, - "dt": dt, - "A": A, - "B": B, - "C": C, - "D": D, - "dt_bias": dt_bias, - "chunk_size": chunk_size, - "seqlen": seqlen, - "nheads": nheads, - "headdim": headdim, - "dstate": dstate, - "ngroups": ngroups, - } - - @pytest.fixture - def reference_output(self, inputs): - """Compute reference output using Triton implementation.""" - out, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd( - inputs["x"], - inputs["dt"], - inputs["A"], - inputs["B"], - inputs["C"], - inputs["chunk_size"], - D=inputs["D"], - z=None, - dt_bias=inputs["dt_bias"], - initial_states=None, - seq_idx=None, - dt_softplus=True, - ) - return out, final_states - - def _print_mismatch_details(self, ref, test, name, atol, rtol): - """Print detailed mismatch analysis.""" - ref_np = ref.detach().cpu().float().numpy() - test_np = test.detach().cpu().float().numpy() - - mismatch_mask = ~np.isclose(ref_np, test_np, atol=atol, rtol=rtol) - num_mismatches = np.sum(mismatch_mask) - total_elements = ref_np.size - - print(f"\nDetailed {name} mismatch analysis:") - print( - f"Number of mismatched elements: {num_mismatches} / {total_elements} " - f"({100 * num_mismatches / total_elements:.2f}%)" - ) - - if num_mismatches > 0: - mismatch_indices = np.argwhere(mismatch_mask) - print(f"First few {name} mismatch locations (up to 10):") - for idx in mismatch_indices[:10]: - idx_tuple = tuple(int(i) for i in idx) - ref_val = ref_np[idx_tuple] - test_val = test_np[idx_tuple] - diff = abs(ref_val - test_val) - rel_diff = diff / (abs(ref_val) + 1e-8) - print( - f" Index {idx_tuple}: ref={ref_val:.6f}, test={test_val:.6f}, " - f"diff={diff:.6e}, rel_diff={rel_diff:.6e}" - ) - - def test_output_correctness(self, inputs, reference_output): - """Test that CUTLASS kernel output matches Triton reference.""" - out_ref, final_states_ref = reference_output - - # Create CUTLASS wrapper - wrapper = CutlassSSDWrapper( - 
chunk_size=inputs["chunk_size"], - headdim=inputs["headdim"], - dstate=inputs["dstate"], - has_d=True, - d_has_hdim=False, # D is (nheads,) not (nheads, headdim) - ) - - # Run CUTLASS kernel - out_test, final_states_test = wrapper( - inputs["x"], - inputs["dt"], - inputs["A"], - inputs["B"], - inputs["C"], - inputs["chunk_size"], - D=inputs["D"], - dt_bias=inputs["dt_bias"], - dt_softplus=True, - ) - - # Compare outputs - cast to same dtype for comparison - out_ref_cmp = out_ref.to(out_test.dtype) - out_match = torch.allclose( - out_ref_cmp, out_test, atol=self.ATOL, rtol=self.RTOL - ) - - if out_match: - print( - f"✓ Outputs match within tolerance (atol={self.ATOL}, rtol={self.RTOL})" - ) - else: - print("✗ Outputs do NOT match within tolerance") - self._print_mismatch_details( - out_ref_cmp, out_test, "output", self.ATOL, self.RTOL - ) - - # Compare final states - cast to same dtype for comparison - final_states_ref_cmp = final_states_ref.to(final_states_test.dtype) - states_match = torch.allclose( - final_states_ref_cmp, final_states_test, atol=self.ATOL, rtol=self.RTOL - ) - - if states_match: - print( - f"✓ Final states match within tolerance (atol={self.ATOL}, rtol={self.RTOL})" - ) - else: - print("✗ Final states do NOT match within tolerance") - self._print_mismatch_details( - final_states_ref_cmp, - final_states_test, - "final_states", - self.ATOL, - self.RTOL, - ) - - assert out_match, "Output mismatch between CUTLASS and Triton" - assert states_match, "Final states mismatch between CUTLASS and Triton" - - -class TestChunkScanCombinedNoD(TestChunkScanCombined): - """Test chunk scan without D scaling.""" - - @pytest.fixture(params=[1]) - def batch(self, request): - return request.param - - @pytest.fixture(params=[1]) - def nchunks(self, request): - return request.param - - @pytest.fixture - def inputs(self, batch, nheads, headdim, dstate, chunk_size, nchunks, ngroups): - """Create test inputs without D.""" - torch.manual_seed(42) - - seqlen = chunk_size * 
nchunks - - x = torch.randn( - batch, seqlen, nheads, headdim, dtype=self.INPUT_DTYPE, device="cuda" - ) - dt = torch.randn(batch, seqlen, nheads, dtype=torch.float32, device="cuda") - A = -torch.rand(nheads, dtype=torch.float32, device="cuda") - 1.0 - B = torch.randn( - batch, seqlen, ngroups, dstate, dtype=self.INPUT_DTYPE, device="cuda" - ) - C = torch.randn( - batch, seqlen, ngroups, dstate, dtype=self.INPUT_DTYPE, device="cuda" - ) - dt_bias = torch.rand(nheads, dtype=torch.float32, device="cuda") - 4.0 - - return { - "x": x, - "dt": dt, - "A": A, - "B": B, - "C": C, - "D": None, - "dt_bias": dt_bias, - "chunk_size": chunk_size, - "seqlen": seqlen, - "nheads": nheads, - "headdim": headdim, - "dstate": dstate, - "ngroups": ngroups, - } - - @pytest.fixture - def reference_output(self, inputs): - """Compute reference output using Triton implementation without D.""" - out, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd( - inputs["x"], - inputs["dt"], - inputs["A"], - inputs["B"], - inputs["C"], - inputs["chunk_size"], - D=None, - z=None, - dt_bias=inputs["dt_bias"], - initial_states=None, - seq_idx=None, - dt_softplus=True, - ) - return out, final_states - - def test_output_correctness(self, inputs, reference_output): - """Test without D scaling.""" - out_ref, final_states_ref = reference_output - - wrapper = CutlassSSDWrapper( - chunk_size=inputs["chunk_size"], - headdim=inputs["headdim"], - dstate=inputs["dstate"], - has_d=False, - d_has_hdim=False, - ) - - out_test, final_states_test = wrapper( - inputs["x"], - inputs["dt"], - inputs["A"], - inputs["B"], - inputs["C"], - inputs["chunk_size"], - D=None, - dt_bias=inputs["dt_bias"], - dt_softplus=True, - ) - - # Cast to same dtype for comparison - out_ref_cmp = out_ref.to(out_test.dtype) - final_states_ref_cmp = final_states_ref.to(final_states_test.dtype) - - out_match = torch.allclose( - out_ref_cmp, out_test, atol=self.ATOL, rtol=self.RTOL - ) - states_match = torch.allclose( - 
final_states_ref_cmp, final_states_test, atol=self.ATOL, rtol=self.RTOL - ) - - if out_match: - print("✓ [NoD] Outputs match within tolerance") - else: - print("✗ [NoD] Outputs do NOT match") - self._print_mismatch_details( - out_ref_cmp, out_test, "output", self.ATOL, self.RTOL - ) - - if states_match: - print("✓ [NoD] Final states match within tolerance") - else: - print("✗ [NoD] Final states do NOT match") - self._print_mismatch_details( - final_states_ref_cmp, - final_states_test, - "final_states", - self.ATOL, - self.RTOL, - ) - - assert out_match - assert states_match diff --git a/tests/mamba/triton_reference/ssd_bmm.py b/tests/mamba/triton_reference/ssd_bmm.py deleted file mode 100644 index 846949e0ed..0000000000 --- a/tests/mamba/triton_reference/ssd_bmm.py +++ /dev/null @@ -1,272 +0,0 @@ -# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_bmm.py -# Copyright (c) 2024, Tri Dao, Albert Gu. -# -# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import math - -import torch -import triton -import triton.language as tl - - -@triton.autotune( - configs=[ - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, - num_stages=3, - num_warps=8, - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, - num_stages=5, - num_warps=2, - ), - triton.Config( - {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, - num_stages=5, - num_warps=2, - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=2, - ), - ], - key=["chunk_size", "K", "IS_CAUSAL"], -) -@triton.jit -def _bmm_chunk_fwd_kernel( - # Pointers to matrices - a_ptr, - b_ptr, - out_ptr, - seq_idx_ptr, - # Matrix dimensions - seqlen, - chunk_size, - K, - ngroups, - stride_a_batch, - stride_a_seqlen, - stride_a_head, - stride_ak, - stride_b_batch, - stride_b_seqlen, - stride_b_head, - stride_bk, - stride_out_batch, - stride_out_chunk, - stride_out_head, - stride_outm, - stride_outn, - stride_seq_idx_batch, - stride_seq_idx_seqlen, - # Meta-parameters - IS_CAUSAL: tl.constexpr, - dot_dtype: tl.constexpr, - HAS_SEQ_IDX: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, -): - pid_b = tl.program_id(axis=1) - pid_ch = tl.program_id(axis=2).to(tl.int64) - pid_c = pid_ch // ngroups - pid_h = pid_ch - pid_c * 
ngroups - num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N) - pid_m = tl.program_id(axis=0) // num_pid_n - pid_n = tl.program_id(axis=0) % num_pid_n - if IS_CAUSAL: - if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M: - return - a_ptr += ( - pid_b * stride_a_batch - + pid_c * chunk_size * stride_a_seqlen - + pid_h * stride_a_head - ) - b_ptr += ( - pid_b * stride_b_batch - + pid_c * chunk_size * stride_b_seqlen - + pid_h * stride_b_head - ) - if HAS_SEQ_IDX: - seq_idx_ptr += ( - pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen - ) - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak) - b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen) - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load( - a_ptrs, - mask=(offs_m[:, None] < chunk_size_limit) - & (offs_k[None, :] < K - k * BLOCK_SIZE_K), - other=0.0, - ).to(dot_dtype) - b = tl.load( - b_ptrs, - mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) - & (offs_n[None, :] < chunk_size_limit), - other=0.0, - ).to(dot_dtype) - acc += tl.dot(a, b) - a_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += BLOCK_SIZE_K * stride_bk - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - if HAS_SEQ_IDX: - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - seq_idx_m = tl.load( - seq_idx_ptr + offs_m * stride_seq_idx_seqlen, - mask=offs_m < chunk_size_limit, - other=-1, - ) - seq_idx_n = tl.load( - seq_idx_ptr + offs_n * stride_seq_idx_seqlen, - mask=offs_n < chunk_size_limit, - other=-2, - ) - acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0) - out = 
acc.to(out_ptr.dtype.element_ty) - - out_ptr += ( - pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head - ) - out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn) - tl.store( - out_ptrs, - out, - mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size), - ) - - -def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None): - """ - Argument: - a: (batch, seqlen, k) or (batch, seqlen, ngroups, k) - b: (batch, seqlen, k) or (batch, seqlen, ngroups, k) - seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out. - causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are - guaranteed to be correct. - Return: - out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size) - """ - # Check constraints. - has_groups = a.dim() == 4 - if not has_groups: - batch, seqlen, k = a.shape - else: - batch, seqlen, ngroups, k = a.shape - assert b.shape == a.shape - if seq_idx is not None: - assert seq_idx.shape == (batch, seqlen) - if a.stride(-1) != 1 and a.stride(1) != 1: - a = a.contiguous() - if b.stride(-1) != 1 and b.stride(1) != 1: - b = b.contiguous() - nchunks = math.ceil(seqlen / chunk_size) - # Allocates output. 
- out_dtype = a.dtype if output_dtype is None else output_dtype - out = torch.empty( - ( - (batch, nchunks, chunk_size, chunk_size) - if not has_groups - else (batch, nchunks, ngroups, chunk_size, chunk_size) - ), - device=a.device, - dtype=out_dtype, - ) - dot_dtype = ( - tl.bfloat16 - if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 - else ( - tl.float16 - if a.dtype == torch.float16 or b.dtype == torch.float16 - else tl.float32 - ) - ) - grid = lambda META: ( - triton.cdiv(chunk_size, META["BLOCK_SIZE_M"]) - * triton.cdiv(chunk_size, META["BLOCK_SIZE_N"]), - batch, - nchunks if not has_groups else nchunks * ngroups, - ) - with torch.cuda.device(a.device.index): - _bmm_chunk_fwd_kernel[grid]( - a, - b, - out, - seq_idx, - seqlen, - chunk_size, - k, - ngroups if has_groups else 1, - a.stride(0), - a.stride(1), - 0 if not has_groups else a.stride(2), - a.stride(-1), - b.stride(0), - b.stride(1), - 0 if not has_groups else b.stride(2), - b.stride(-1), - out.stride(0), - out.stride(1), - 0 if not has_groups else out.stride(2), - out.stride(-2), - out.stride(-1), - *( - (seq_idx.stride(0), seq_idx.stride(1)) - if seq_idx is not None - else (0, 0) - ), - causal, - dot_dtype, - HAS_SEQ_IDX=seq_idx is not None, - ) - return out diff --git a/tests/mamba/triton_reference/ssd_chunk_scan.py b/tests/mamba/triton_reference/ssd_chunk_scan.py deleted file mode 100644 index 67cdad636c..0000000000 --- a/tests/mamba/triton_reference/ssd_chunk_scan.py +++ /dev/null @@ -1,628 +0,0 @@ -# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_scan.py -# Copyright (c) 2024, Tri Dao, Albert Gu. -# -# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import triton -import triton.language as tl -from packaging import version - -TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0") - - -@triton.autotune( - configs=[ - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, - num_stages=3, - num_warps=8, - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, - num_stages=5, - num_warps=2, - ), - triton.Config( - {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, - num_stages=5, - num_warps=2, - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=2, - ), - ], - key=["chunk_size", "hdim", "dstate", "IS_CAUSAL"], -) -@triton.jit -def 
_chunk_scan_fwd_kernel( - # Pointers to matrices - cb_ptr, - x_ptr, - z_ptr, - out_ptr, - out_x_ptr, - dt_ptr, - dA_cumsum_ptr, - seq_idx_ptr, - C_ptr, - states_ptr, - D_ptr, - initstates_ptr, - chunk_indices_ptr, - chunk_offsets_ptr, - chunk_meta_num, - # Matrix dimensions - chunk_size, - hdim, - dstate, - batch, - seqlen, - nheads_ngroups_ratio, - # Strides - stride_cb_batch, - stride_cb_chunk, - stride_cb_head, - stride_cb_csize_m, - stride_cb_csize_k, - stride_x_batch, - stride_x_seqlen, - stride_x_head, - stride_x_hdim, - stride_z_batch, - stride_z_seqlen, - stride_z_head, - stride_z_hdim, - stride_out_batch, - stride_out_seqlen, - stride_out_head, - stride_out_hdim, - stride_dt_batch, - stride_dt_chunk, - stride_dt_head, - stride_dt_csize, - stride_dA_cs_batch, - stride_dA_cs_chunk, - stride_dA_cs_head, - stride_dA_cs_csize, - stride_seq_idx_batch, - stride_seq_idx_seqlen, - stride_C_batch, - stride_C_seqlen, - stride_C_head, - stride_C_dstate, - stride_states_batch, - stride_states_chunk, - stride_states_head, - stride_states_hdim, - stride_states_dstate, - stride_init_states_batch, - stride_init_states_head, - stride_init_states_hdim, - stride_init_states_dstate, - stride_D_head, - # Meta-parameters - IS_CAUSAL: tl.constexpr, - HAS_D: tl.constexpr, - D_HAS_HDIM: tl.constexpr, - HAS_Z: tl.constexpr, - HAS_SEQ_IDX: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - BLOCK_SIZE_DSTATE: tl.constexpr, - IS_TRITON_22: tl.constexpr, - HAS_INITSTATES: tl.constexpr, -): - pid_bc = tl.program_id(axis=1).to(tl.int64) - pid_c = pid_bc // batch - pid_b = pid_bc - pid_c * batch - if not HAS_INITSTATES: - c_idx = pid_c - c_off = 0 - else: - c_idx = tl.load(chunk_indices_ptr + pid_c, mask=pid_c > -1, other=0) - c_off = tl.load(chunk_offsets_ptr + pid_c, mask=pid_c > -1, other=0) - - pid_h = tl.program_id(axis=2) - num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) - pid_m = tl.program_id(axis=0) // num_pid_n - pid_n = 
tl.program_id(axis=0) % num_pid_n - cb_ptr += ( - pid_b * stride_cb_batch - + c_idx * stride_cb_chunk - + (pid_h // nheads_ngroups_ratio) * stride_cb_head - ) - x_ptr += ( - pid_b * stride_x_batch - + c_idx * chunk_size * stride_x_seqlen - + pid_h * stride_x_head - ) - dt_ptr += pid_b * stride_dt_batch + c_idx * stride_dt_chunk + pid_h * stride_dt_head - dA_cumsum_ptr += ( - pid_b * stride_dA_cs_batch - + c_idx * stride_dA_cs_chunk - + pid_h * stride_dA_cs_head - ) - C_ptr += ( - pid_b * stride_C_batch - + c_idx * chunk_size * stride_C_seqlen - + (pid_h // nheads_ngroups_ratio) * stride_C_head - ) - - # M-block offsets and prev states - # - logic in next block may override these if there is an active offset - offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M) - prev_states_ptr = ( - states_ptr - + pid_b * stride_states_batch - + c_idx * stride_states_chunk - + pid_h * stride_states_head - ) - prev_states_hdim = stride_states_hdim - prev_states_dstate = stride_states_dstate - - chunk_size_limit = min(chunk_size, seqlen - c_idx * chunk_size) - if HAS_SEQ_IDX: - seq_idx_ptr += ( - pid_b * stride_seq_idx_batch + c_idx * chunk_size * stride_seq_idx_seqlen - ) - - # - we only need seq_idx_prev to be aligned to chunk boundary - seq_idx_prev = tl.load( - seq_idx_ptr - stride_seq_idx_seqlen, mask=c_idx >= 1, other=0 - ) - - if HAS_INITSTATES: - # if there are init states, we only need seq_idx_m to point - # what is the current seq_idx - - # get current seq idx - if (pid_m * BLOCK_SIZE_M + c_off) < chunk_size_limit: - seq_idx_m = tl.load( - seq_idx_ptr - + (pid_m * BLOCK_SIZE_M + c_off) * stride_seq_idx_seqlen, - ) - - # - recall that in ssd_state_passing, for the case c_off == 0 - # i.e., the very first sequence, we made states_ptr hold its initial state - # so this edge case is taken care of - if ( - (c_off == 0) - and ( - seq_idx_prev != seq_idx_m - ) # if a seq is changed exactly on boundary - or (c_off > 0) # implies a new example (pseudo chunk) - ): - # 
- replace prev_states_ptr with init_states - prev_states_ptr = ( - initstates_ptr - + seq_idx_m * stride_init_states_batch - + pid_h * stride_init_states_head - ) - prev_states_hdim = stride_init_states_hdim # override strides - prev_states_dstate = stride_init_states_dstate - - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - dA_cs_m = tl.load( - dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0 - ).to(tl.float32) - - # - handle chunk state limit - if HAS_INITSTATES: - # have to split this if otherwise compilation will have problems - dA_cs_m_boundary = 0.0 - - # get the c_idx for the next (logica) chunk - c_idx_n = tl.load( - chunk_indices_ptr + (pid_c + 1), - mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num, - other=-1, # to trigger different chunk - ) - - # - there are things to consider - # A. if c_off > 0 then we need to move the dA_cs boundary to ensure correct - # contribution of past states - # B. if c_off_n < chunk_size_limit, then we need to adjust this so as not to - # encroach into the next sequence, where c_off_n is the offset of the next - # (logical) chunk. - # An equivalent check for B is c_idx == c_idx_n, where there is repetition in - # (logical) chunk indices. - - if (c_idx == c_idx_n) or c_off > 0: - # get the next offset - c_off_n = tl.load( - chunk_offsets_ptr + (pid_c + 1), - mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num, - other=chunk_size, - ) - - # in this case, adjust down the chunk_size_limit - if c_idx == c_idx_n: - chunk_size_limit = min(c_off_n, chunk_size_limit) - - # get the cs at the offset boundary - # - c_off == 0 is a passthrough - # - We need dA_cs at the boundary, defined by c_off - no need - # to increase pointer by pid_m (it is a constant offset, - # i.e. 
the same for all blocks) - dA_cs_m_boundary = tl.load( - dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize, - mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)), - other=0.0, - ).to(tl.float32) - - if HAS_SEQ_IDX: - # - handle seq idx when HAS_INITSTATES==False - if not HAS_INITSTATES: - seq_idx_m = tl.load( - seq_idx_ptr + offs_m * stride_seq_idx_seqlen, - mask=offs_m < chunk_size_limit, - other=-1, - ) - - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - # Without the if (pid_c > -1), with Triton 2.1.0, I get - # Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mm a layout conversion"' failed. - # With Triton 2.2.0, this works - if IS_TRITON_22 or c_idx > -1: - # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 - offs_k_dstate = tl.arange( - 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K - ) - C_ptrs = C_ptr + ( - offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate - ) - - prev_states_ptrs = prev_states_ptr + ( - offs_n[None, :] * prev_states_hdim - + offs_k_dstate[:, None] * prev_states_dstate - ) - if HAS_SEQ_IDX: - if not HAS_INITSTATES: - # - this is for continuous batching where there is no init states - scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) - else: - # - if there is initstates, we will rely on prev_states, no zeroing - # required. 
- scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary) - else: - scale_m = tl.exp(dA_cs_m) - if BLOCK_SIZE_DSTATE <= 128: - C = tl.load( - C_ptrs, - mask=(offs_m[:, None] < chunk_size_limit) - & (offs_k_dstate[None, :] < dstate), - other=0.0, - ) - - prev_states = tl.load( - prev_states_ptrs, - mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), - other=0.0, - ) - prev_states = prev_states.to(C_ptr.dtype.element_ty) - acc = tl.dot(C, prev_states) * scale_m[:, None] - else: - for k in range(0, dstate, BLOCK_SIZE_K): - C = tl.load( - C_ptrs, - mask=(offs_m[:, None] < chunk_size_limit) - & (offs_k_dstate[None, :] < dstate - k), - other=0.0, - ) - # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty) - prev_states = tl.load( - prev_states_ptrs, - mask=(offs_k_dstate[:, None] < dstate - k) - & (offs_n[None, :] < hdim), - other=0.0, - ) - prev_states = prev_states.to(C_ptr.dtype.element_ty) - acc += tl.dot(C, prev_states) - C_ptrs += BLOCK_SIZE_K - prev_states_ptrs += BLOCK_SIZE_K - acc *= scale_m[:, None] - - offs_k = tl.arange(0, BLOCK_SIZE_K) + c_off - cb_ptrs = cb_ptr + ( - offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k - ) - x_ptrs = x_ptr + ( - offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim - ) - dt_ptrs = dt_ptr + offs_k * stride_dt_csize - dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize - K_MAX = ( - chunk_size_limit - if not IS_CAUSAL - else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) - ) - for k in range(0, K_MAX, BLOCK_SIZE_K): - cb = tl.load( - cb_ptrs, - mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k), - other=0.0, - ).to(tl.float32) - dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to( - tl.float32 - ) - # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j]. - # So we don't need masking wrt seq_idx here. 
- cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :]) - dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) - cb *= dt_k - if IS_CAUSAL: - mask = offs_m[:, None] >= k + offs_k[None, :] - cb = tl.where(mask, cb, 0.0) - cb = cb.to(x_ptr.dtype.element_ty) - x = tl.load( - x_ptrs, - mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim), - other=0.0, - ) - acc += tl.dot(cb, x) - cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k - x_ptrs += BLOCK_SIZE_K * stride_x_seqlen - dt_ptrs += BLOCK_SIZE_K * stride_dt_csize - dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize - - offs_out_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M) - offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - - if HAS_D: - if D_HAS_HDIM: - D = tl.load( - D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0 - ).to(tl.float32) - else: - D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) - x_residual = tl.load( - x_ptr - + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim), - mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), - other=0.0, - ).to(tl.float32) - acc += x_residual * D - - if HAS_Z: - out_x_ptr += ( - pid_b * stride_out_batch - + c_idx * chunk_size * stride_out_seqlen - + pid_h * stride_out_head - ) - out_x_ptrs = out_x_ptr + ( - stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] - ) - tl.store( - out_x_ptrs, - acc, - mask=(offs_out_m[:, None] < chunk_size_limit) - & (offs_out_n[None, :] < hdim), - ) - - z_ptr += ( - pid_b * stride_z_batch - + c_idx * chunk_size * stride_z_seqlen - + pid_h * stride_z_head - ) - z_ptrs = z_ptr + ( - stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :] - ) - z = tl.load( - z_ptrs, - mask=(offs_out_m[:, None] < chunk_size_limit) - & (offs_out_n[None, :] < hdim), - other=0.0, - ).to(tl.float32) - acc *= z * tl.sigmoid(z) - - out_ptr += ( - pid_b * stride_out_batch - + c_idx * chunk_size * 
stride_out_seqlen - + pid_h * stride_out_head - ) - out_ptrs = out_ptr + ( - stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim - ) - tl.store( - out_ptrs, - acc, - mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), - ) - - -def _chunk_scan_fwd( - cb, - x, - dt, - dA_cumsum, - C, - states, - D=None, - z=None, - seq_idx=None, - chunk_indices=None, - chunk_offsets=None, - initial_states=None, - out=None, -): - batch, seqlen, nheads, headdim = x.shape - _, _, nchunks, chunk_size = dt.shape - _, _, ngroups, dstate = C.shape - assert nheads % ngroups == 0 - assert C.shape == (batch, seqlen, ngroups, dstate) - assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) - if z is not None: - assert z.shape == x.shape - if D is not None: - assert D.shape == (nheads, headdim) or D.shape == (nheads,) - assert dt.shape == (batch, nheads, nchunks, chunk_size) - assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) - assert states.shape == (batch, nchunks, nheads, headdim, dstate) - - if seq_idx is not None: - assert seq_idx.shape == (batch, seqlen) - - if initial_states is not None: - # with initial states, we need to take care of how - # seq_idx crosses the boundaries - assert batch == 1, "chunk scan only supports initial states with batch 1" - assert chunk_indices is not None and chunk_offsets is not None, ( - "chunk_indices and chunk_offsets should have been set" - ) - else: - chunk_indices, chunk_offsets = None, None - else: - chunk_indices, chunk_offsets = None, None - - if out is None: - out = torch.empty_like(x) - assert out.shape == x.shape - - if z is not None: - out_x = torch.empty_like(x) - assert out_x.stride() == out.stride() - else: - out_x = None - - grid = lambda META: ( - triton.cdiv(chunk_size, META["BLOCK_SIZE_M"]) - * triton.cdiv(headdim, META["BLOCK_SIZE_N"]), - batch * nchunks if chunk_offsets is None else len(chunk_offsets), - nheads, - ) - z_strides = ( - (z.stride(0), 
z.stride(1), z.stride(2), z.stride(3)) - if z is not None - else (0, 0, 0, 0) - ) - _chunk_scan_fwd_kernel[grid]( - cb, - x, - z, - out, - out_x, - dt, - dA_cumsum, - seq_idx, - C, - states, - D, - initial_states, - chunk_indices, - chunk_offsets, - len(chunk_indices) if chunk_indices is not None else 0, - chunk_size, - headdim, - dstate, - batch, - seqlen, - nheads // ngroups, - cb.stride(0), - cb.stride(1), - cb.stride(2), - cb.stride(3), - cb.stride(4), - x.stride(0), - x.stride(1), - x.stride(2), - x.stride(3), - z_strides[0], - z_strides[1], - z_strides[2], - z_strides[3], - out.stride(0), - out.stride(1), - out.stride(2), - out.stride(3), - dt.stride(0), - dt.stride(2), - dt.stride(1), - dt.stride(3), - dA_cumsum.stride(0), - dA_cumsum.stride(2), - dA_cumsum.stride(1), - dA_cumsum.stride(3), - *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), - C.stride(0), - C.stride(1), - C.stride(2), - C.stride(3), - states.stride(0), - states.stride(1), - states.stride(2), - states.stride(3), - states.stride(4), - *( - ( - initial_states.stride(0), - initial_states.stride(1), - initial_states.stride(2), - initial_states.stride(3), - ) - if initial_states is not None - else (0, 0, 0, 0) - ), - D.stride(0) if D is not None else 0, - True, - D is not None, - D.dim() == 2 if D is not None else True, - BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), - HAS_Z=z is not None, - HAS_SEQ_IDX=seq_idx is not None, - IS_TRITON_22=TRITON_22, - HAS_INITSTATES=initial_states is not None, - ) - return out_x if z is not None else out diff --git a/tests/mamba/triton_reference/ssd_combined.py b/tests/mamba/triton_reference/ssd_combined.py deleted file mode 100644 index 8b65b38a1b..0000000000 --- a/tests/mamba/triton_reference/ssd_combined.py +++ /dev/null @@ -1,265 +0,0 @@ -# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_combined.py -# Copyright (c) 2024, Tri Dao, Albert Gu. 
-# -# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -from einops import rearrange - -from .ssd_bmm import _bmm_chunk_fwd -from .ssd_chunk_scan import _chunk_scan_fwd -from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_state_fwd, chunk_state_varlen -from .ssd_state_passing import _state_passing_fwd - - -def is_int_pow_2(n): - return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0 - - -def _mamba_chunk_scan_combined_fwd( - x, - dt, - A, - B, - C, - chunk_size, - D=None, - z=None, - dt_bias=None, - initial_states=None, - seq_idx=None, - chunk_indices=None, - chunk_offsets=None, - cu_seqlens=None, - dt_softplus=False, - dt_limit=(0.0, float("inf")), - state_dtype=None, - out=None, -): - assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2" - batch, seqlen, nheads, headdim = x.shape - _, _, ngroups, dstate = B.shape - assert nheads % ngroups == 0 - assert B.shape == (batch, seqlen, ngroups, dstate) - assert dt.shape == (batch, seqlen, nheads) - assert A.shape == (nheads,) - assert C.shape == B.shape - if z is not None: - assert z.shape == x.shape - if D is not None: - assert D.shape == (nheads, headdim) or D.shape == (nheads,) - if seq_idx is not None: - assert seq_idx.shape == (batch, seqlen) - if B.stride(-1) != 1: - B = B.contiguous() - if C.stride(-1) != 1: - C = C.contiguous() - if ( - 
x.stride(-1) != 1 and x.stride(1) != 1 - ): # Either M or K dimension should be contiguous - x = x.contiguous() - if ( - z is not None and z.stride(-1) != 1 and z.stride(1) != 1 - ): # Either M or K dimension should be contiguous - z = z.contiguous() - if D is not None and D.stride(-1) != 1: - D = D.contiguous() - if initial_states is not None: - if cu_seqlens is None: - assert initial_states.shape == (batch, nheads, headdim, dstate) - else: - assert initial_states.shape == ( - len(cu_seqlens) - 1, - nheads, - headdim, - dstate, - ) - - # This function executes 5 sub-functions for computing mamba - # - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/ - # which has a minimal implementation to understand the below operations - # - as explained by the blog, mamba is a special case of causal attention - # - the idea is to chunk the attention matrix and compute each - # submatrix separately using different optimizations. - # - see the blog and paper for a visualization of the submatrices - # which we refer to in the comments below - - # 1. Compute chunked cumsum of A * dt - # - here dt may go through a softplus activation - dA_cumsum, dt = _chunk_cumsum_fwd( - dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit - ) - - # 2. Compute the state for each intra-chunk - # (right term of low-rank factorization of off-diagonal blocks; B terms) - states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True) - - # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries - # (middle term of factorization of off-diag blocks; A terms) - # - for handling chunked prefill, this requires i) initial_states - # ii) seq_idx iii) is_cont_batched and (iv) chunk_offsets to be all specified. - # - When a new seq_idx is detected, we will stop passing the prev_state - # and switch accordingly to the init_state corresponding to the new seq_idx. 
- # - We will also make sure that the dA_cumsum is taken only from the start of the - # sequence (hence we need the full dA_cumsum tensor and not just the values at chunk boundaries) - # - this will ensure that states will be updated with the rightmost flushed seq_idx - # of the previous chunk. This implies that the first chunk of states is either 0 - # or equal to init_states of the first example. - states, final_states = _state_passing_fwd( - rearrange(states, "... p n -> ... (p n)"), - dA_cumsum, - initial_states=( - rearrange(initial_states, "... p n -> ... (p n)") - if initial_states is not None - else None - ), - seq_idx=seq_idx, - chunk_size=chunk_size, - out_dtype=state_dtype if state_dtype is not None else C.dtype, - is_cont_batched=cu_seqlens is not None, - chunk_offsets=chunk_offsets, - ) - states, final_states = ( - rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states] - ) - - # 4. Compute batched matrix multiply for C_j^T B_i terms - CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32) - - # 5. Scan and compute the diagonal blocks, taking into - # account past causal states. - # - if initial states are provided, then states information will be - # augmented with initial_states. - # - to do this properly, we need to account for example changes in - # the continuous batch, therefore we introduce pseudo chunks, which is - # a chunk that is split up each time an example changes. - # - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had - # a seq_idx change, in which case we take states information from - # init_states. 
- out_x = _chunk_scan_fwd( - CB, - x, - dt, - dA_cumsum, - C, - states, - D=D, - z=z, - seq_idx=seq_idx, - chunk_indices=chunk_indices, - chunk_offsets=chunk_offsets, - initial_states=initial_states, - out=out, - ) - if cu_seqlens is None: - return out_x, dt, dA_cumsum, states, final_states - else: - assert batch == 1, ( - "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1" - ) - varlen_states = chunk_state_varlen( - B.squeeze(0), - x.squeeze(0), - dt.squeeze(0), - dA_cumsum.squeeze(0), - cu_seqlens, - states.squeeze(0), - initial_states=initial_states, - ) - return out_x, dt, dA_cumsum, states, final_states, varlen_states - - -def mamba_chunk_scan_combined( - x, - dt, - A, - B, - C, - chunk_size, - D=None, - z=None, - dt_bias=None, - initial_states=None, - seq_idx=None, - chunk_indices=None, - chunk_offsets=None, - cu_seqlens=None, - dt_softplus=False, - dt_limit=(0.0, float("inf")), - out=None, - return_final_states=False, - return_varlen_states=False, - state_dtype=None, -): - """ - Argument: - x: (batch, seqlen, nheads, headdim) - dt: (batch, seqlen, nheads) - A: (nheads) - B: (batch, seqlen, ngroups, dstate) - C: (batch, seqlen, ngroups, dstate) - chunk_size: int - D: (nheads, headdim) or (nheads,) - z: (batch, seqlen, nheads, headdim) - dt_bias: (nheads,) - initial_states: (batch, nheads, headdim, dstate) - seq_idx: (batch, seqlen) - cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True - dt_softplus: Whether to apply softplus to dt - out: Preallocated output tensor - state_dtype: The data type of the ssm state - """ - - if not return_varlen_states: - cu_seqlens = None - else: - assert cu_seqlens is not None, ( - "cu_seqlens must be provided if return_varlen_states is True" - ) - out_x, dt_out, dA_cumsum, states, final_states, *rest = ( - _mamba_chunk_scan_combined_fwd( - x, - dt, - A, - B, - C, - chunk_size, - D=D, - z=z, - dt_bias=dt_bias, - initial_states=initial_states, - seq_idx=seq_idx, - 
chunk_indices=chunk_indices, - chunk_offsets=chunk_offsets, - cu_seqlens=cu_seqlens, - dt_softplus=dt_softplus, - dt_limit=dt_limit, - out=out, - state_dtype=state_dtype, - ) - ) - if not return_varlen_states: - if not return_final_states: - return - else: - return final_states - else: - varlen_states = rest[0] - return ( - (varlen_states) - if not return_final_states - else (final_states, varlen_states) - ) diff --git a/tests/mamba/triton_reference/ssd_state_passing.py b/tests/mamba/triton_reference/ssd_state_passing.py deleted file mode 100644 index df0fe2b27f..0000000000 --- a/tests/mamba/triton_reference/ssd_state_passing.py +++ /dev/null @@ -1,282 +0,0 @@ -# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_state_passing.py -# Copyright (c) 2024, Tri Dao, Albert Gu. -# -# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import torch -import triton -import triton.language as tl - - -@triton.autotune( - configs=[ - triton.Config({"BLOCK_SIZE": 64}), - triton.Config({"BLOCK_SIZE": 128}), - triton.Config({"BLOCK_SIZE": 256}), - triton.Config({"BLOCK_SIZE": 512}), - triton.Config({"BLOCK_SIZE": 1024}), - triton.Config({"BLOCK_SIZE": 2048}), - ], - key=["dim"], -) -@triton.jit -def _state_passing_fwd_kernel( - # Pointers to matrices - states_ptr, - out_ptr, - final_states_ptr, - dA_cs_ptr, - initstates_ptr, - seq_idx_ptr, - chunk_offsets_ptr, - chunk_meta_num, - # Matrix dimensions - dim, - nchunks, - seqlen, - chunk_size, - # Strides - stride_states_batch, - stride_states_chunk, - stride_states_head, - stride_states_dim, - stride_out_batch, - stride_out_chunk, - stride_out_head, - stride_out_dim, - stride_final_states_batch, - stride_final_states_head, - stride_final_states_dim, - stride_dA_cs_batch, - stride_dA_cs_chunk, - stride_dA_cs_head, - stride_dA_cs_csize, - stride_initstates_batch, - stride_initstates_head, - stride_initstates_dim, - stride_seq_idx_batch, - stride_seq_idx_seqlen, - # Meta-parameters - HAS_INITSTATES: tl.constexpr, - HAS_SEQ_IDX: tl.constexpr, - IS_CONT_BATCHED: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - pid_b = tl.program_id(axis=1) - pid_h = tl.program_id(axis=2) - pid_m = tl.program_id(axis=0) - states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head - dA_cs_ptr += ( - pid_b * stride_dA_cs_batch - + pid_h * stride_dA_cs_head - + (chunk_size - 1) * stride_dA_cs_csize - ) - out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head - final_states_ptr += ( - pid_b * stride_final_states_batch + pid_h * stride_final_states_head - ) - if HAS_INITSTATES: - initstates_ptr += pid_h * stride_initstates_head - if not IS_CONT_BATCHED: - initstates_ptr += pid_b * stride_initstates_batch - - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch - - offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - states_ptrs = states_ptr + offs_m 
* stride_states_dim - out_ptrs = out_ptr + offs_m * stride_out_dim - final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim - - # - states will be the past state of the sequence that continues on the current check - if not HAS_INITSTATES: - states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) - else: - initstates_ptr += offs_m * stride_initstates_dim - initstates_ptrs = initstates_ptr - # - for cont batches, for the first chunk mean it will be the first batch's - # init state - states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - tl.store(out_ptrs, states, mask=offs_m < dim) - out_ptrs += stride_out_chunk - prev_seq_idx_chunk_end = 0 - logical_chunk_idx = 0 - for c in range(nchunks): - new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - dA_cs = tl.load(dA_cs_ptr).to(tl.float32) - scale_mask = True - if HAS_SEQ_IDX: - # - the seq to pass forward is the one that is flushed to the right - # boundary. - # - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk. - seq_idx_chunk_end = tl.load( - seq_idx_ptr - + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen - ) - if HAS_INITSTATES: - if IS_CONT_BATCHED and prev_seq_idx_chunk_end != seq_idx_chunk_end: - # this means in the current chunk the rightmost flushed seq - # has changed. 
- # - so we do not propagate the state from previous chunk - # - but rather we load that sequence's init state - initstates_ptrs = ( - initstates_ptr + seq_idx_chunk_end * stride_initstates_batch - ) - - # - update state with seq_idx_new's init state - states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to( - tl.float32 - ) - - # - we need to consider the cumsum only of the last sequence in the chunk - # - find its starting position (given by c_off of the logical chunk index) - # - and subtract the cumsum just before that position from the total cumsum - # - first, update the logical chunk index (add the number of sequences in the current physical chunk): - # sequence index at the start of the current chunk - seq_idx_chunk_start = tl.load( - seq_idx_ptr - + min(c * chunk_size, seqlen) * stride_seq_idx_seqlen - ) - logical_chunk_idx += seq_idx_chunk_end - seq_idx_chunk_start - # - load the chunk offset: - c_off = tl.load( - chunk_offsets_ptr + logical_chunk_idx, - mask=logical_chunk_idx < chunk_meta_num, - other=0, - ) - # - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything - if c_off > 0: - # - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset - dA_cs_boundary = tl.load( - dA_cs_ptr - - (chunk_size - 1) * stride_dA_cs_csize - + (c_off - 1) * stride_dA_cs_csize, - mask=(c_off - 1) > -1 and c_off < chunk_size, - other=0.0, - ) - dA_cs -= dA_cs_boundary - - # - increment logical chunk index for every physical chunk - logical_chunk_idx += 1 - else: - scale_mask = seq_idx_chunk_end == prev_seq_idx_chunk_end - prev_seq_idx_chunk_end = seq_idx_chunk_end - - scale = tl.where(scale_mask, tl.exp(dA_cs), 0.0) - states = scale * states + new_states - if c < nchunks - 1: - tl.store(out_ptrs, states, mask=offs_m < dim) - else: - tl.store(final_states_ptrs, states, mask=offs_m < dim) - states_ptrs += stride_states_chunk - dA_cs_ptr += 
stride_dA_cs_chunk - out_ptrs += stride_out_chunk - - -def _state_passing_fwd( - states, - dA_cumsum, - initial_states=None, - seq_idx=None, - chunk_size=None, - out_dtype=None, - is_cont_batched=False, - chunk_offsets=None, -): - batch, nchunks, nheads, dim = states.shape - if chunk_size is None: - chunk_size = dA_cumsum.shape[-1] - else: - assert chunk_size == dA_cumsum.shape[-1] - assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) - if initial_states is not None: - if is_cont_batched: - # - if cu_seqlens is provided, then the initial states - # are used for continuous batching. In which case we - # require seq_idx to be provided - assert seq_idx is not None, ( - "seq_idx must be provided for continuous batching" - ) - # - we also need chunk_offsets to be provided, to account - # for computation of dA_cumsum from the start of the - # sequence - assert chunk_offsets is not None, ( - "chunk_offsets must be provided for continuous batching" - ) - else: - # - this is the regular batching case, where initial - # states are used are for each example of the batch. 
- assert initial_states.shape == (batch, nheads, dim) - - if seq_idx is not None: - seqlen = seq_idx.shape[-1] - assert seq_idx.shape == (batch, seqlen) - out_dtype = states.dtype if out_dtype is None else out_dtype - out = torch.empty( - (batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype - ) - final_states = torch.empty( - (batch, nheads, dim), device=states.device, dtype=torch.float32 - ) - grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE"]), batch, nheads) - with torch.cuda.device(states.device.index): - _state_passing_fwd_kernel[grid]( - states, - out, - final_states, - dA_cumsum, - initial_states, - seq_idx, - chunk_offsets, - len(chunk_offsets) if chunk_offsets is not None else 0, - dim, - nchunks, - seqlen if seq_idx is not None else 0, - chunk_size, - states.stride(0), - states.stride(1), - states.stride(2), - states.stride(3), - out.stride(0), - out.stride(1), - out.stride(2), - out.stride(3), - final_states.stride(0), - final_states.stride(1), - final_states.stride(2), - dA_cumsum.stride(0), - dA_cumsum.stride(2), - dA_cumsum.stride(1), - dA_cumsum.stride(3), - *( - ( - initial_states.stride(0), - initial_states.stride(1), - initial_states.stride(2), - ) - if initial_states is not None - else (0, 0, 0) - ), - *( - (seq_idx.stride(0), seq_idx.stride(1)) - if seq_idx is not None - else (0, 0) - ), - HAS_INITSTATES=initial_states is not None, - HAS_SEQ_IDX=seq_idx is not None, - IS_CONT_BATCHED=is_cont_batched, - ) - return out, final_states From de96dd5833f9c1aeecafd1a278f9529fd4eb414f Mon Sep 17 00:00:00 2001 From: Igor Shovkun Date: Wed, 18 Feb 2026 22:40:19 -0800 Subject: [PATCH 25/33] Remove ssd_chunk_state.py Triton reference implementation (irrelevant to PR) --- .../mamba/triton_reference/ssd_chunk_state.py | 771 ------------------ 1 file changed, 771 deletions(-) delete mode 100644 tests/mamba/triton_reference/ssd_chunk_state.py diff --git a/tests/mamba/triton_reference/ssd_chunk_state.py 
b/tests/mamba/triton_reference/ssd_chunk_state.py deleted file mode 100644 index 14b9d28851..0000000000 --- a/tests/mamba/triton_reference/ssd_chunk_state.py +++ /dev/null @@ -1,771 +0,0 @@ -# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_state.py -# Copyright (c) 2024, Tri Dao, Albert Gu. -# -# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import math - -import torch -import triton -import triton.language as tl - -from .softplus import softplus - - -@triton.autotune( - configs=[ - triton.Config({"BLOCK_SIZE_H": 1}), - triton.Config({"BLOCK_SIZE_H": 2}), - triton.Config({"BLOCK_SIZE_H": 4}), - triton.Config({"BLOCK_SIZE_H": 8}), - triton.Config({"BLOCK_SIZE_H": 16}), - triton.Config({"BLOCK_SIZE_H": 32}), - triton.Config({"BLOCK_SIZE_H": 64}), - ], - key=["chunk_size", "nheads"], -) -@triton.jit -def _chunk_cumsum_fwd_kernel( - # Pointers to matrices - dt_ptr, - A_ptr, - dt_bias_ptr, - dt_out_ptr, - dA_cumsum_ptr, - # Matrix dimension - batch, - seqlen, - nheads, - chunk_size, - dt_min, - dt_max, - # Strides - stride_dt_batch, - stride_dt_seqlen, - stride_dt_head, - stride_A_head, - stride_dt_bias_head, - stride_dt_out_batch, - stride_dt_out_chunk, - stride_dt_out_head, - stride_dt_out_csize, - stride_dA_cs_batch, - stride_dA_cs_chunk, - stride_dA_cs_head, - stride_dA_cs_csize, - # Meta-parameters - DT_SOFTPLUS: tl.constexpr, - HAS_DT_BIAS: tl.constexpr, - BLOCK_SIZE_H: tl.constexpr, - BLOCK_SIZE_CHUNK: tl.constexpr, -): - pid_b = tl.program_id(axis=0) - - # if dt is long, may cause problems, so use 64 bit - # https://github.com/triton-lang/triton/issues/1058 - pid_c = tl.program_id(axis=1).to(tl.int64) - pid_h = tl.program_id(axis=2) - dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen - dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk - dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk - - offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) - offs_c = tl.arange(0, BLOCK_SIZE_CHUNK) - dt_ptrs = dt_ptr + ( - offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen - ) - A_ptrs = A_ptr + offs_h * stride_A_head - dt_out_ptrs = dt_out_ptr + ( - offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize - ) - dA_cs_ptrs = dA_cumsum_ptr + ( - offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * 
stride_dA_cs_csize - ) - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - - dt = tl.load( - dt_ptrs, - mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), - other=0.0, - ).to(tl.float32) - if HAS_DT_BIAS: - dt_bias = tl.load( - dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0 - ).to(tl.float32) - dt += dt_bias[:, None] - if DT_SOFTPLUS: - dt = tl.where(dt <= 20.0, softplus(dt), dt) - # As of Triton 2.2.0, tl.clamp is not available yet - # dt = tl.clamp(dt, dt_min, dt_max) - dt = tl.minimum(tl.maximum(dt, dt_min), dt_max) - dt = tl.where( - (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0 - ) - tl.store( - dt_out_ptrs, - dt, - mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size), - ) - A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32) - dA = dt * A[:, None] - dA_cs = tl.cumsum(dA, axis=1) - tl.store( - dA_cs_ptrs, - dA_cs, - mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size), - ) - - -@triton.autotune( - configs=[ - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, - num_stages=3, - num_warps=8, - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, - num_stages=5, - num_warps=2, - ), - triton.Config( - {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, - num_stages=5, - num_warps=2, 
- ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=2, - ), - ], - key=["hdim", "dstate", "chunk_size"], -) -@triton.jit -def _chunk_state_fwd_kernel( - # Pointers to matrices - x_ptr, - b_ptr, - states_ptr, - dt_ptr, - dA_cumsum_ptr, - seq_idx_ptr, - # Matrix dimensions - hdim, - dstate, - chunk_size, - batch, - seqlen, - nheads_ngroups_ratio, - # Strides - stride_x_batch, - stride_x_seqlen, - stride_x_head, - stride_x_hdim, - stride_b_batch, - stride_b_seqlen, - stride_b_head, - stride_b_dstate, - stride_states_batch, - stride_states_chunk, - stride_states_head, - stride_states_hdim, - stride_states_dstate, - stride_dt_batch, - stride_dt_chunk, - stride_dt_head, - stride_dt_csize, - stride_dA_cs_batch, - stride_dA_cs_chunk, - stride_dA_cs_head, - stride_dA_cs_csize, - stride_seq_idx_batch, - stride_seq_idx_seqlen, - # Meta-parameters - HAS_SEQ_IDX: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, -): - pid_bc = tl.program_id(axis=1).to(tl.int64) - pid_c = pid_bc // batch - pid_b = pid_bc - pid_c * batch - pid_h = tl.program_id(axis=2) - num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) - pid_m = tl.program_id(axis=0) // num_pid_n - pid_n = tl.program_id(axis=0) % num_pid_n - b_ptr += ( - pid_b * stride_b_batch - + pid_c * chunk_size * stride_b_seqlen - + (pid_h // nheads_ngroups_ratio) * stride_b_head - ) - x_ptr += ( - pid_b * stride_x_batch - + pid_c * chunk_size * stride_x_seqlen - + pid_h * stride_x_head - ) - dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head - dA_cumsum_ptr += ( - pid_b * stride_dA_cs_batch - + pid_c * stride_dA_cs_chunk - + pid_h * stride_dA_cs_head - ) - if HAS_SEQ_IDX: - seq_idx_ptr += ( - pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen - ) - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k 
= tl.arange(0, BLOCK_SIZE_K) - x_ptrs = x_ptr + ( - offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen - ) - b_ptrs = b_ptr + ( - offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen - ) - dt_ptrs = dt_ptr + offs_k * stride_dt_csize - dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to( - tl.float32 - ) - dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize - if HAS_SEQ_IDX: - seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen - - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - if HAS_SEQ_IDX: - seq_idx_last = tl.load( - seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen - ) - - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, chunk_size_limit, BLOCK_SIZE_K): - x = tl.load( - x_ptrs, - mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k), - other=0.0, - ) - b = tl.load( - b_ptrs, - mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate), - other=0.0, - ).to(tl.float32) - dA_cs_k = tl.load( - dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0 - ).to(tl.float32) - if HAS_SEQ_IDX: - seq_idx_k = tl.load( - seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1 - ) - dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to( - tl.float32 - ) - if not HAS_SEQ_IDX: - # scale = tl.exp((dA_cs_last - dA_cs_k)) * dt_k - scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k - else: - # scale = tl.where(seq_idx_k == seq_idx_last, tl.exp((dA_cs_last - dA_cs_k)) * dt_k, 0.0) - scale = tl.where( - seq_idx_k == seq_idx_last, - tl.exp(tl.minimum((dA_cs_last - dA_cs_k), 0.0)) * dt_k, - 0.0, - ) - b *= scale[:, None] - b = b.to(x_ptr.dtype.element_ty) - acc += tl.dot(x, b) - x_ptrs += BLOCK_SIZE_K * stride_x_seqlen - b_ptrs += BLOCK_SIZE_K * stride_b_seqlen - dt_ptrs += BLOCK_SIZE_K * stride_dt_csize - dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize - if 
HAS_SEQ_IDX: - seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen - states = acc.to(states_ptr.dtype.element_ty) - - states_ptr += ( - pid_b * stride_states_batch - + pid_c * stride_states_chunk - + pid_h * stride_states_head - ) - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - states_ptrs = states_ptr + ( - offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate - ) - c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) - tl.store(states_ptrs, states, mask=c_mask) - - -@triton.autotune( - configs=[ - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, - num_stages=3, - num_warps=8, - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=4, - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, - num_stages=5, - num_warps=2, - ), - triton.Config( - {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, - num_stages=5, - num_warps=2, - ), - triton.Config( - {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, - num_stages=4, - num_warps=2, - ), - ], - key=["hdim", "dstate", "chunk_size"], -) -@triton.jit -def _chunk_state_varlen_kernel( - # Pointers to matrices - x_ptr, - b_ptr, - dt_ptr, - dA_cumsum_ptr, - chunk_states_ptr, - cu_seqlens_ptr, - states_ptr, - initstates_ptr, - # Matrix dimensions - hdim, - dstate, - chunk_size, - seqlen, - nheads_ngroups_ratio, - 
# Strides - stride_x_seqlen, - stride_x_head, - stride_x_hdim, - stride_b_seqlen, - stride_b_head, - stride_b_dstate, - stride_dt_chunk, - stride_dt_head, - stride_dt_csize, - stride_dA_cs_chunk, - stride_dA_cs_head, - stride_dA_cs_csize, - stride_chunk_states_chunk, - stride_chunk_states_head, - stride_chunk_states_hdim, - stride_chunk_states_dstate, - stride_states_batch, - stride_states_head, - stride_states_hdim, - stride_states_dstate, - stride_init_states_batch, - stride_init_states_head, - stride_init_states_hdim, - stride_init_states_dstate, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - HAS_INITSTATES: tl.constexpr, -): - pid_b = tl.program_id(axis=1) - pid_h = tl.program_id(axis=2) - num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) - pid_m = tl.program_id(axis=0) // num_pid_n - pid_n = tl.program_id(axis=0) % num_pid_n - end_idx = tl.load(cu_seqlens_ptr + pid_b + 1) - pid_c = (end_idx - 1) // chunk_size - b_ptr += ( - pid_c * chunk_size * stride_b_seqlen - + (pid_h // nheads_ngroups_ratio) * stride_b_head - ) - x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head - dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head - dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - chunk_states_ptr += ( - pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head - ) - - if HAS_INITSTATES: - # if there are init states provided, we differentiate between states (which - # are boundary conditions at a chunk boundary) and initstates (which are boundary - # conditions when a new example in a cont batch starts) - initstates_ptr += pid_h * stride_init_states_head - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - x_ptrs = x_ptr + ( - offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen - ) - b_ptrs = b_ptr + ( - offs_n[None, :] * 
stride_b_dstate + offs_k[:, None] * stride_b_seqlen - ) - dt_ptrs = dt_ptr + offs_k * stride_dt_csize - dA_cs_last = tl.load( - dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize - ).to(tl.float32) - dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize - - chunk_size_limit = end_idx - pid_c * chunk_size - start_idx = tl.load(cu_seqlens_ptr + pid_b) - start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0) - - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, chunk_size_limit, BLOCK_SIZE_K): - x = tl.load( - x_ptrs, - mask=(offs_m[:, None] < hdim) - & (offs_k[None, :] < chunk_size_limit - k) - & (offs_k[None, :] >= start_idx_cur - k), - other=0.0, - ) - b = tl.load( - b_ptrs, - mask=(offs_k[:, None] < chunk_size_limit - k) - & (offs_n[None, :] < dstate) - & (offs_k[:, None] >= start_idx_cur - k), - other=0.0, - ).to(tl.float32) - dA_cs_k = tl.load( - dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0 - ).to(tl.float32) - dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to( - tl.float32 - ) - scale = tl.where( - (offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k), - tl.exp(dA_cs_last - dA_cs_k) * dt_k, - 0.0, - ) - b *= scale[:, None] - b = b.to(x_ptr.dtype.element_ty) - acc += tl.dot(x, b) - x_ptrs += BLOCK_SIZE_K * stride_x_seqlen - b_ptrs += BLOCK_SIZE_K * stride_b_seqlen - dt_ptrs += BLOCK_SIZE_K * stride_dt_csize - dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize - - # If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk - # If HAS_INITSTATES==True need to consider two possiblties - # - if start_idx < pid_c * chunk_size, then we need to take the past_states_ptrs - # - if state_idx >= pid * chunk_size, then we need to insert initstates - if ( - (start_idx < pid_c * chunk_size) # first chunk - or (HAS_INITSTATES) - ): - dA_cs_boundary = 0.0 # default - - if not HAS_INITSTATES: - past_states_ptrs = 
chunk_states_ptr + ( - offs_m[:, None] * stride_chunk_states_hdim - + offs_n[None, :] * stride_chunk_states_dstate - ) - else: - # - this seems repetitve, buts its to help the compiler - if start_idx < pid_c * chunk_size: - past_states_ptrs = chunk_states_ptr + ( - offs_m[:, None] * stride_chunk_states_hdim - + offs_n[None, :] * stride_chunk_states_dstate - ) - else: - past_states_ptrs = initstates_ptr + ( - pid_b * stride_init_states_batch - + offs_m[:, None] * stride_init_states_hdim - + offs_n[None, :] * stride_init_states_dstate - ) - - # need to adjust the boundary - if start_idx > pid_c * chunk_size: - dA_cs_boundary = tl.load( - dA_cumsum_ptr - + (start_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize - ).to(tl.float32) - - past_states = tl.load( - past_states_ptrs, - mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate), - other=0.0, - ).to(tl.float32) - - scale = tl.exp(dA_cs_last - dA_cs_boundary) - acc += past_states * scale - - states = acc.to(states_ptr.dtype.element_ty) - - states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - states_ptrs = states_ptr + ( - offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate - ) - c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) - tl.store(states_ptrs, states, mask=c_mask) - - -def _chunk_cumsum_fwd( - dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf")) -): - batch, seqlen, nheads = dt.shape - assert A.shape == (nheads,) - if dt_bias is not None: - assert dt_bias.shape == (nheads,) - nchunks = math.ceil(seqlen / chunk_size) - dt_out = torch.empty( - batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32 - ) - dA_cumsum = torch.empty( - batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32 - ) - grid_chunk_cs = lambda META: ( - batch, - nchunks, - triton.cdiv(nheads, 
META["BLOCK_SIZE_H"]), - ) - with torch.cuda.device(dt.device.index): - _chunk_cumsum_fwd_kernel[grid_chunk_cs]( - dt, - A, - dt_bias, - dt_out, - dA_cumsum, - batch, - seqlen, - nheads, - chunk_size, - dt_limit[0], - dt_limit[1], - dt.stride(0), - dt.stride(1), - dt.stride(2), - A.stride(0), - dt_bias.stride(0) if dt_bias is not None else 0, - dt_out.stride(0), - dt_out.stride(2), - dt_out.stride(1), - dt_out.stride(3), - dA_cumsum.stride(0), - dA_cumsum.stride(2), - dA_cumsum.stride(1), - dA_cumsum.stride(3), - dt_softplus, - HAS_DT_BIAS=dt_bias is not None, - BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size), - ) - return dA_cumsum, dt_out - - -def _chunk_state_fwd( - B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_fp32=True -): - batch, seqlen, nheads, headdim = x.shape - _, _, nchunks, chunk_size = dt.shape - _, _, ngroups, dstate = B.shape - assert nheads % ngroups == 0 - assert B.shape == (batch, seqlen, ngroups, dstate) - assert dt.shape == (batch, nheads, nchunks, chunk_size) - assert dA_cumsum.shape == dt.shape - if seq_idx is not None: - assert seq_idx.shape == (batch, seqlen) - if states is not None: - assert states.shape == (batch, nchunks, nheads, headdim, dstate) - else: - states_dtype = torch.float32 if states_in_fp32 else B.dtype - states = torch.empty( - (batch, nchunks, nheads, headdim, dstate), - device=x.device, - dtype=states_dtype, - ) - grid = lambda META: ( - triton.cdiv(headdim, META["BLOCK_SIZE_M"]) - * triton.cdiv(dstate, META["BLOCK_SIZE_N"]), - batch * nchunks, - nheads, - ) - with torch.cuda.device(x.device.index): - _chunk_state_fwd_kernel[grid]( - x, - B, - states, - dt, - dA_cumsum, - seq_idx, - headdim, - dstate, - chunk_size, - batch, - seqlen, - nheads // ngroups, - x.stride(0), - x.stride(1), - x.stride(2), - x.stride(3), - B.stride(0), - B.stride(1), - B.stride(2), - B.stride(-1), - states.stride(0), - states.stride(1), - states.stride(2), - states.stride(3), - states.stride(4), - dt.stride(0), - dt.stride(2), - 
dt.stride(1), - dt.stride(3), - dA_cumsum.stride(0), - dA_cumsum.stride(2), - dA_cumsum.stride(1), - dA_cumsum.stride(3), - *( - (seq_idx.stride(0), seq_idx.stride(1)) - if seq_idx is not None - else (0, 0) - ), - HAS_SEQ_IDX=seq_idx is not None, - ) - return states - - -def chunk_state_varlen( - B, x, dt, dA_cumsum, cu_seqlens, chunk_states, initial_states=None -): - total_seqlen, nheads, headdim = x.shape - _, nchunks, chunk_size = dt.shape - _, ngroups, dstate = B.shape - batch = cu_seqlens.shape[0] - 1 - cu_seqlens = cu_seqlens.contiguous() - assert nheads % ngroups == 0 - assert B.shape == (total_seqlen, ngroups, dstate) - assert dt.shape == (nheads, nchunks, chunk_size) - assert dA_cumsum.shape == dt.shape - assert chunk_states.shape == (nchunks, nheads, headdim, dstate) - - if initial_states is not None: - assert initial_states.shape == (batch, nheads, headdim, dstate) - - states = torch.empty( - batch, - nheads, - headdim, - dstate, - dtype=chunk_states.dtype, - device=chunk_states.device, - ) - grid = lambda META: ( - triton.cdiv(headdim, META["BLOCK_SIZE_M"]) - * triton.cdiv(dstate, META["BLOCK_SIZE_N"]), - batch, - nheads, - ) - with torch.cuda.device(x.device.index): - _chunk_state_varlen_kernel[grid]( - x, - B, - dt, - dA_cumsum, - chunk_states, - cu_seqlens, - states, - initial_states, - headdim, - dstate, - chunk_size, - total_seqlen, - nheads // ngroups, - x.stride(0), - x.stride(1), - x.stride(2), - B.stride(0), - B.stride(1), - B.stride(2), - dt.stride(1), - dt.stride(0), - dt.stride(2), - dA_cumsum.stride(1), - dA_cumsum.stride(0), - dA_cumsum.stride(2), - chunk_states.stride(0), - chunk_states.stride(1), - chunk_states.stride(2), - chunk_states.stride(3), - states.stride(0), - states.stride(1), - states.stride(2), - states.stride(3), - *( - ( - initial_states.stride(0), - initial_states.stride(1), - initial_states.stride(2), - initial_states.stride(3), - ) - if initial_states is not None - else (0, 0, 0, 0) - ), - HAS_INITSTATES=initial_states 
is not None, - ) - return states From 4c30f07be382c6e892c62a04b1e1c6fb3774d6fc Mon Sep 17 00:00:00 2001 From: Igor Shovkun Date: Wed, 18 Feb 2026 22:42:10 -0800 Subject: [PATCH 26/33] Delete test_utils.py --- tests/mamba/test_utils.py | 235 -------------------------------------- 1 file changed, 235 deletions(-) delete mode 100644 tests/mamba/test_utils.py diff --git a/tests/mamba/test_utils.py b/tests/mamba/test_utils.py deleted file mode 100644 index 6f33e930e3..0000000000 --- a/tests/mamba/test_utils.py +++ /dev/null @@ -1,235 +0,0 @@ -from typing import Any, Dict, Optional - -import numpy as np -import torch - - -def clone_preserving_strides(tensor): - """Clone a tensor while preserving its strides (non-contiguous layout).""" - result = torch.empty_strided( - tensor.size(), tensor.stride(), dtype=tensor.dtype, device=tensor.device - ) - result.copy_(tensor) - return result - - -def create_test_inputs( - batch_size: int, - nheads: int, - dim: int, - dstate: int, - ngroups: int, - input_dtype: torch.dtype, - weight_dtype: torch.dtype = torch.float32, - state_dtype: torch.dtype = torch.bfloat16, - matrixA_dtype: torch.dtype = torch.float32, - generate_z: bool = False, - generate_intermediate_states_buffer: bool = False, - cache_steps: Optional[int] = None, - generate_retrieve_parent_token: bool = False, - state_cache_batch_stride: Optional[int] = None, - device: str = "cuda", - seed: int = 42, -) -> Dict[str, Any]: - """ - Create test inputs for selective_state_update functions. - - This function generates all necessary tensors for testing selective state - update kernels, supporting both single-token and multi-token (speculative - decoding) scenarios. - - Arguments: - batch_size: Number of sequences in the batch. - nheads: Number of attention heads. - dim: Head dimension (headdim). - dstate: SSM state size. - ngroups: Number of groups for B and C matrices. - input_dtype: Data type for input tensors (x, B, C, z) - from model config.json (typically bf16). 
- weight_dtype: Data type for weight tensors (D, dt, dt_bias) - hardcoded fp32 in mamba2_mixer.py. - state_dtype: Data type for state tensor - user configurable (bf16/fp16/fp32). Defaults to input_dtype. - matrixA_dtype: Data type for the A matrix - hardcoded fp32 in mamba2_mixer.py. - generate_z: If True, generate z tensor for gating. - generate_intermediate_states_buffer: If True, generate buffer for - caching intermediate states during speculative decoding. - cache_steps: Number of steps/tokens to cache. Required if - generate_intermediate_states_buffer is True. Also determines - T dimension when > 1 (multi-token mode). - generate_retrieve_parent_token: If True, generate tensor for EAGLE - tree attention parent token retrieval. - state_cache_batch_stride: Optional batch stride for ssm_state_cache. - If None, defaults to contiguous stride (nheads * dim * dstate). - Must be >= nheads * dim * dstate if specified. - device: Device to create tensors on. - seed: Random seed for reproducibility. 
- - Returns: - Dictionary containing all generated tensors with the following keys: - - state_cache: (total_entries, nheads, dim, dstate) - - x: (batch_size, [T,] nheads, dim) - T present if cache_steps provided - - dt: (batch_size, [T,] nheads, dim) - T present if cache_steps provided - - A: (nheads, dim, dstate) - - B: (batch_size, [T,] ngroups, dstate) - T present if cache_steps provided - - C: (batch_size, [T,] ngroups, dstate) - T present if cache_steps provided - - D: (nheads, dim) - - dt_bias: (nheads, dim) - - slot_idx: (batch_size,) - - z: (batch_size, [T,] nheads, dim) - only if generate_z=True, T present if cache_steps provided - - intermediate_states_buffer: (batch_size, cache_steps, nheads, dim, dstate) - - only if generate_intermediate_states_buffer=True - - intermediate_slot_idx: (batch_size,) - - only if generate_intermediate_states_buffer=True - - retrieve_parent_token: (batch_size, T) - - only if generate_retrieve_parent_token=True - - cache_steps: int - only if cache_steps is provided - """ - # Set seeds for reproducibility - torch.manual_seed(seed) - np.random.seed(seed) - - # Determine if we're in multi-token mode - # Always use 4D tensors when cache_steps is provided (even for cache_steps=1) - T = cache_steps if cache_steps is not None else None - - # If we use the cache, then the state indices are taken from a specific slot - # so the state in the kernel will have batch as the first dimension, but it will - # only come from a particular slot; the full tensor first dim is larger - ssm_state_cache_size = max(384, batch_size * 10) - - # State dtype defaults to input_dtype if not specified - - # SSM state cache: (total_entries, nheads, dim, dstate) - # Calculate the contiguous batch stride - contiguous_batch_stride = nheads * dim * dstate - - # Use provided batch stride or default to contiguous - if state_cache_batch_stride is None: - state_cache_batch_stride = contiguous_batch_stride - - # Validate that batch stride is large enough - if 
state_cache_batch_stride < contiguous_batch_stride: - raise ValueError( - f"state_cache_batch_stride ({state_cache_batch_stride}) must be >= " - f"contiguous stride ({contiguous_batch_stride} = nheads * dim * dstate)" - ) - - total_elements = ssm_state_cache_size * state_cache_batch_stride - state_cache_flat = torch.randn(total_elements, dtype=state_dtype, device=device) - state_cache = state_cache_flat.as_strided( - (ssm_state_cache_size, nheads, dim, dstate), - (state_cache_batch_stride, dim * dstate, dstate, 1), - ) - - # Input x: (batch_size, [T,] nheads, dim) - if T is not None: - x = torch.randn(batch_size, T, nheads, dim, device=device, dtype=input_dtype) - else: - x = torch.randn(batch_size, nheads, dim, dtype=input_dtype, device=device) - - # dt: (batch_size, [T,] nheads, dim) with strides that broadcast dim - # dt uses weight_dtype (fp32) as per mamba2_mixer.py - # dt has T dimension for multi-token mode, matching x shape - if T is not None: - dt_base = torch.randn(batch_size, T, nheads, dtype=weight_dtype, device=device) - dt = dt_base.as_strided( - (batch_size, T, nheads, dim), (T * nheads, nheads, 1, 0) - ) - else: - dt_base = torch.randn(batch_size, nheads, dtype=weight_dtype, device=device) - dt = dt_base.as_strided((batch_size, nheads, dim), (nheads, 1, 0)) - - # A matrix: (nheads, dim, dstate) with strides (1, 0, 0) - one value per head - # A should be negative for stability - A_base = -torch.rand(nheads, dtype=matrixA_dtype, device=device) - 1.0 - A = A_base.as_strided((nheads, dim, dstate), (1, 0, 0)) - - # B: (batch_size, T, ngroups, dstate) - # C: (batch_size, ngroups, dstate) - if T is not None: - B = torch.randn( - batch_size, T, ngroups, dstate, device=device, dtype=input_dtype - ) - C = torch.randn( - batch_size, T, ngroups, dstate, device=device, dtype=input_dtype - ) - else: - B = torch.randn(batch_size, ngroups, dstate, dtype=input_dtype, device=device) - C = torch.randn(batch_size, ngroups, dstate, dtype=input_dtype, device=device) - - 
# D: (nheads, dim) with strides (1, 0) - one value per head - D_base = torch.randn(nheads, dtype=weight_dtype, device=device) - D = D_base.as_strided((nheads, dim), (1, 0)) - - # dt_bias: (nheads, dim) with strides (1, 0) - one value per head - dt_bias_base = torch.rand(nheads, dtype=weight_dtype, device=device) - 4.0 - dt_bias = dt_bias_base.as_strided((nheads, dim), (1, 0)) - - # Slot indices for state batching - (batch_size,) - slot_idx = torch.randperm(ssm_state_cache_size, dtype=torch.int64, device=device)[ - :batch_size - ] - - # Build result dictionary - result = { - "state_cache": state_cache, - "x": x, - "dt": dt, - "A": A, - "B": B, - "C": C, - "D": D, - "dt_bias": dt_bias, - "slot_idx": slot_idx, - } - - # Optional: z tensor for gating - # z: (batch_size, [T,] nheads, dim) - has T dimension for multi-token mode, matching x shape - if generate_z: - if T is not None: - z = torch.randn( - batch_size, T, nheads, dim, dtype=input_dtype, device=device - ) - else: - z = torch.randn(batch_size, nheads, dim, dtype=input_dtype, device=device) - result["z"] = z - - # Optional: intermediate states buffer for speculative decoding - if generate_intermediate_states_buffer: - if cache_steps is None: - raise ValueError( - "cache_steps must be provided when generate_intermediate_states_buffer=True" - ) - intermediate_states_buffer = torch.zeros( - batch_size, - cache_steps, - nheads, - dim, - dstate, - dtype=state_dtype, - device=device, - ) - result["intermediate_states_buffer"] = intermediate_states_buffer - result["cache_steps"] = cache_steps - # Also generate indices mapping batch elements to intermediate state buffer positions - intermediate_slot_idx = torch.arange( - batch_size, dtype=torch.int64, device=device - ) - result["intermediate_slot_idx"] = intermediate_slot_idx - - # Optional: retrieve_parent_token for EAGLE tree attention - if generate_retrieve_parent_token: - if T is None or T <= 1: - raise ValueError( - "cache_steps > 1 required when 
generate_retrieve_parent_token=True" - ) - # Create a simple linear chain structure by default - # Token 0: parent = -1 (initial state) - # Token t: parent = t - 1 (previous token) - retrieve_parent_token = torch.zeros( - batch_size, T, dtype=torch.int64, device=device - ) - retrieve_parent_token[:, 0] = -1 # First token uses initial state - for t in range(1, T): - retrieve_parent_token[:, t] = t - 1 - result["retrieve_parent_token"] = retrieve_parent_token - - return result From 1f1c2f4d484f9a0f08162a384a18c0dab02660ef Mon Sep 17 00:00:00 2001 From: Igor Shovkun Date: Wed, 18 Feb 2026 23:12:12 -0800 Subject: [PATCH 27/33] Suppress mypy false positive for gen_selective_state_update calls --- flashinfer/aot.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/flashinfer/aot.py b/flashinfer/aot.py index a1705cf0e9..e45ea05491 100644 --- a/flashinfer/aot.py +++ b/flashinfer/aot.py @@ -581,14 +581,17 @@ def gen_all_modules( _ssu_dtype_combos, _ssu_dims, _ssu_dstates, _ssu_ntokens ): jit_specs.append( - gen_selective_state_update_module(*dtype_combo, dim, dstate, ntokens) + # false positive: mypy can't resolve the signature because flashinfer.jit deps (filelock etc.) 
+ # are absent in mypy's isolated env, causing it to infer an incorrect function signature + gen_selective_state_update_module(*dtype_combo, dim, dstate, ntokens) # type: ignore[call-arg] ) if has_sm90 or has_sm100: for dtype_combo, dim, dstate, ntokens in product( _ssu_dtype_combos, _ssu_dims, _ssu_dstates, _ssu_ntokens ): jit_specs.append( - gen_selective_state_update_sm90_module( + # same false positive as above + gen_selective_state_update_sm90_module( # type: ignore[call-arg] *dtype_combo, dim, dstate, ntokens ) ) From 157ecb53b1c79d7f9c75838c43f10c44da168c3b Mon Sep 17 00:00:00 2001 From: Igor Shovkun Date: Thu, 19 Feb 2026 07:52:06 -0800 Subject: [PATCH 28/33] Move Triton reference kernel to triton_reference subdir and update imports --- benchmarks/routines/mamba.py | 6 ++-- csrc/selective_state_update.cu | 1 + .../mamba/test_selective_state_update_mtp.py | 2 +- .../mamba/test_selective_state_update_stp.py | 4 +-- .../selective_state_update.py} | 28 ++++++++----------- 5 files changed, 19 insertions(+), 22 deletions(-) rename tests/mamba/{selective_state_update_triton.py => triton_reference/selective_state_update.py} (97%) diff --git a/benchmarks/routines/mamba.py b/benchmarks/routines/mamba.py index 0d53dae849..a8bdd5640e 100644 --- a/benchmarks/routines/mamba.py +++ b/benchmarks/routines/mamba.py @@ -49,17 +49,17 @@ def _import_triton_reference(): Uses importlib to load the module directly by file path, avoiding sys.path pollution and fragile relative path assumptions. 
""" - # Resolve path: benchmarks/routines/mamba.py -> ../../tests/mamba/selective_state_update_triton.py + # Resolve path: benchmarks/routines/mamba.py -> ../../tests/mamba/triton_reference/selective_state_update.py _this_dir = os.path.dirname(os.path.abspath(__file__)) _repo_root = os.path.normpath(os.path.join(_this_dir, "..", "..")) _triton_ref_path = os.path.join( - _repo_root, "tests", "mamba", "selective_state_update_triton.py" + _repo_root, "tests", "mamba", "triton_reference", "selective_state_update.py" ) if not os.path.isfile(_triton_ref_path): raise ImportError( f"Cannot find Triton reference kernel at: {_triton_ref_path}\n" - f"Expected location: /tests/mamba/selective_state_update_triton.py\n" + f"Expected location: /tests/mamba/triton_reference/selective_state_update.py\n" f"Make sure you are running from within the FlashInfer repository." ) diff --git a/csrc/selective_state_update.cu b/csrc/selective_state_update.cu index 7afe8b21e0..3918d3caf8 100644 --- a/csrc/selective_state_update.cu +++ b/csrc/selective_state_update.cu @@ -16,6 +16,7 @@ // clang-format off // config.inc MUST come before the header: it defines DIM, DSTATE, NTOKENS_MTP // constexprs that the header's function templates rely on. Reordering breaks compilation. 
+// NOTE: the .inc file is generated from the jinja templates #include "selective_state_update_config.inc" #include // clang-format on diff --git a/tests/mamba/test_selective_state_update_mtp.py b/tests/mamba/test_selective_state_update_mtp.py index 4430d305b0..a51c5e23f3 100644 --- a/tests/mamba/test_selective_state_update_mtp.py +++ b/tests/mamba/test_selective_state_update_mtp.py @@ -11,7 +11,7 @@ import flashinfer -from .selective_state_update_triton import selective_state_update_triton +from .triton_reference.selective_state_update import selective_state_update_triton from .utils import create_test_inputs, clone_preserving_strides diff --git a/tests/mamba/test_selective_state_update_stp.py b/tests/mamba/test_selective_state_update_stp.py index 6b8b293d07..d23faa644a 100644 --- a/tests/mamba/test_selective_state_update_stp.py +++ b/tests/mamba/test_selective_state_update_stp.py @@ -5,7 +5,7 @@ import flashinfer from flashinfer.utils import get_compute_capability -from .selective_state_update_triton import selective_state_update_triton +from .triton_reference.selective_state_update import selective_state_update_triton from .utils import create_test_inputs, clone_preserving_strides @@ -295,7 +295,7 @@ def test_output_correctness( ): """Test that kernel output matches reference but state is not updated.""" inputs = self.make_inputs(batch, nheads, dim, dstate, state_dtype, weight_dtype) - y_ref, state_ref = self.make_reference_output(inputs) + y_ref, _state_ref = self.make_reference_output(inputs) # Save the initial state before running the kernel state_initial = inputs["state_cache"].clone() diff --git a/tests/mamba/selective_state_update_triton.py b/tests/mamba/triton_reference/selective_state_update.py similarity index 97% rename from tests/mamba/selective_state_update_triton.py rename to tests/mamba/triton_reference/selective_state_update.py index c40f90612a..f574a6ebae 100644 --- a/tests/mamba/selective_state_update_triton.py +++ 
b/tests/mamba/triton_reference/selective_state_update.py @@ -9,25 +9,21 @@ import torch import triton import triton.language as tl -from packaging import version -PAD_SLOT_ID = -1 - -TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0") - -if TRITON3: +try: + from .softplus import softplus # noqa: F401 +except ImportError: + # Fallback when loaded standalone via importlib (no package context) + import os as _os + import importlib.util as _ilu - @triton.jit - def softplus(dt): - dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt) - return dt + _softplus_path = _os.path.join(_os.path.dirname(__file__), "softplus.py") + _sp_spec = _ilu.spec_from_file_location("softplus", _softplus_path) + _sp_mod = _ilu.module_from_spec(_sp_spec) + _sp_spec.loader.exec_module(_sp_mod) + softplus = _sp_mod.softplus -else: - - @triton.jit - def softplus(dt): - dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt) - return dt +PAD_SLOT_ID = -1 @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) From f32b63b6e3b2dc5af8a4c22b98ac2863be63e942 Mon Sep 17 00:00:00 2001 From: Igor Shovkun Date: Thu, 19 Feb 2026 08:13:02 -0800 Subject: [PATCH 29/33] mark an unused variable with "_" in a test --- tests/mamba/test_selective_state_update_mtp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/mamba/test_selective_state_update_mtp.py b/tests/mamba/test_selective_state_update_mtp.py index a51c5e23f3..e49ff8c494 100644 --- a/tests/mamba/test_selective_state_update_mtp.py +++ b/tests/mamba/test_selective_state_update_mtp.py @@ -278,7 +278,7 @@ def test_output_correctness( inputs = self.make_inputs( batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype ) - y_ref, state_ref = self.make_reference_output(inputs) + y_ref, _ = self.make_reference_output(inputs) # Save the initial state before running the kernel state_initial = inputs["state_cache"].clone() From 2656202aadbe2ee26e10d659412563ad75b16f04 Mon Sep 17 
00:00:00 2001 From: Igor Shovkun Date: Thu, 19 Feb 2026 08:14:20 -0800 Subject: [PATCH 30/33] rename an unused test variable to _state_ref --- tests/mamba/test_selective_state_update_mtp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/mamba/test_selective_state_update_mtp.py b/tests/mamba/test_selective_state_update_mtp.py index e49ff8c494..f295ee6bad 100644 --- a/tests/mamba/test_selective_state_update_mtp.py +++ b/tests/mamba/test_selective_state_update_mtp.py @@ -428,7 +428,7 @@ def test_output_correctness( inputs = self.make_inputs( batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype ) - y_ref, state_ref, intermediate_states_ref = self.make_reference_output(inputs) + y_ref, _state_ref, intermediate_states_ref = self.make_reference_output(inputs) # Prepare output tensor if requested if use_out_tensor: From 5580d2877e2f064a8949eb7eb984919e82bff3ce Mon Sep 17 00:00:00 2001 From: Igor Shovkun Date: Thu, 19 Feb 2026 08:24:44 -0800 Subject: [PATCH 31/33] Refactor Triton reference import for selective_state_update Simplify import logic by adding tests/mamba to sys.path and using standard package imports for the Triton reference kernel and softplus. --- benchmarks/routines/mamba.py | 53 +++++-------------- .../selective_state_update.py | 13 +---- 2 files changed, 13 insertions(+), 53 deletions(-) diff --git a/benchmarks/routines/mamba.py b/benchmarks/routines/mamba.py index a8bdd5640e..d5fffd57c8 100644 --- a/benchmarks/routines/mamba.py +++ b/benchmarks/routines/mamba.py @@ -14,14 +14,8 @@ limitations under the License. """ -# ============================================================================== -# Triton reference implementation for selective_state_update. -# Imported from tests/mamba/selective_state_update_triton.py to avoid code -# duplication. See that file for the canonical Triton kernel source. 
-# ============================================================================== - -import importlib import os +import sys from collections import defaultdict import numpy as np @@ -30,6 +24,14 @@ import flashinfer from flashinfer.testing.utils import bench_gpu_time +# Add tests/mamba to sys.path so triton_reference is importable as a package +_repo_root = os.path.normpath( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..") +) +_tests_mamba = os.path.join(_repo_root, "tests", "mamba") +if _tests_mamba not in sys.path: + sys.path.insert(0, _tests_mamba) + from .flashinfer_benchmark_utils import ( dtype_str_to_torch_dtype, get_device, @@ -38,40 +40,9 @@ filter_backends_by_compute_capability, ) -# ---- Import Triton reference kernel from tests/mamba/ ---- -# The canonical Triton selective_state_update lives in tests/mamba/selective_state_update_triton.py. -# We import it here rather than duplicating ~400 lines of kernel code. - - -def _import_triton_reference(): - """Import selective_state_update_triton from tests/mamba/. - - Uses importlib to load the module directly by file path, avoiding sys.path - pollution and fragile relative path assumptions. - """ - # Resolve path: benchmarks/routines/mamba.py -> ../../tests/mamba/triton_reference/selective_state_update.py - _this_dir = os.path.dirname(os.path.abspath(__file__)) - _repo_root = os.path.normpath(os.path.join(_this_dir, "..", "..")) - _triton_ref_path = os.path.join( - _repo_root, "tests", "mamba", "triton_reference", "selective_state_update.py" - ) - - if not os.path.isfile(_triton_ref_path): - raise ImportError( - f"Cannot find Triton reference kernel at: {_triton_ref_path}\n" - f"Expected location: /tests/mamba/triton_reference/selective_state_update.py\n" - f"Make sure you are running from within the FlashInfer repository." 
- ) - - spec = importlib.util.spec_from_file_location( - "selective_state_update_triton", _triton_ref_path - ) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - return module.selective_state_update_triton - - -selective_state_update_triton_reference = _import_triton_reference() +from triton_reference.selective_state_update import ( + selective_state_update_triton as selective_state_update_triton_reference, +) # ============================================================================== diff --git a/tests/mamba/triton_reference/selective_state_update.py b/tests/mamba/triton_reference/selective_state_update.py index f574a6ebae..88d20bf251 100644 --- a/tests/mamba/triton_reference/selective_state_update.py +++ b/tests/mamba/triton_reference/selective_state_update.py @@ -10,18 +10,7 @@ import triton import triton.language as tl -try: - from .softplus import softplus # noqa: F401 -except ImportError: - # Fallback when loaded standalone via importlib (no package context) - import os as _os - import importlib.util as _ilu - - _softplus_path = _os.path.join(_os.path.dirname(__file__), "softplus.py") - _sp_spec = _ilu.spec_from_file_location("softplus", _softplus_path) - _sp_mod = _ilu.module_from_spec(_sp_spec) - _sp_spec.loader.exec_module(_sp_mod) - softplus = _sp_mod.softplus +from .softplus import softplus PAD_SLOT_ID = -1 From 58f56cd09d664c34c2386c5158c2c0363faab91c Mon Sep 17 00:00:00 2001 From: Igor Shovkun Date: Fri, 20 Feb 2026 09:42:24 -0800 Subject: [PATCH 32/33] Fixes aot compilation of the gdn_prefill_sm90 module --- flashinfer/aot.py | 1 + 1 file changed, 1 insertion(+) diff --git a/flashinfer/aot.py b/flashinfer/aot.py index e45ea05491..251634ad94 100644 --- a/flashinfer/aot.py +++ b/flashinfer/aot.py @@ -596,6 +596,7 @@ def gen_all_modules( ) ) jit_specs.append(gen_trtllm_utils_module()) + if has_sm90: jit_specs.append(gen_gdn_prefill_sm90_module()) if ( From 5d8184ec7b71b8a8c382134551c06410becad53d Mon Sep 17 00:00:00 
2001 From: Igor Shovkun Date: Fri, 20 Feb 2026 10:05:17 -0800 Subject: [PATCH 33/33] Substantially reduce the number of SSU aot compilation units. Limited to only those used by major frameworks. --- flashinfer/aot.py | 24 ++++-------------------- 1 file changed, 4 insertions(+), 20 deletions(-) diff --git a/flashinfer/aot.py b/flashinfer/aot.py index 251634ad94..f11ac238bb 100644 --- a/flashinfer/aot.py +++ b/flashinfer/aot.py @@ -548,19 +548,7 @@ def gen_all_modules( ] # selective_state_update: one module per dtype combo per GPU arch _ssu_dtype_combos = [ - # (state, input, weight, matrixA, stateIndex) - ( - torch.bfloat16, - torch.bfloat16, - torch.bfloat16, - torch.float32, - torch.int32, - ), - (torch.float16, torch.bfloat16, torch.bfloat16, torch.float32, torch.int32), - (torch.float32, torch.bfloat16, torch.bfloat16, torch.float32, torch.int32), - (torch.bfloat16, torch.bfloat16, torch.float32, torch.float32, torch.int32), - (torch.float16, torch.bfloat16, torch.float32, torch.float32, torch.int32), - (torch.float32, torch.bfloat16, torch.float32, torch.float32, torch.int32), + # (state, input, weight, matrixA, stateIndex) ( torch.bfloat16, torch.bfloat16, @@ -568,15 +556,11 @@ def gen_all_modules( torch.float32, torch.int64, ), - (torch.float16, torch.bfloat16, torch.bfloat16, torch.float32, torch.int64), (torch.float32, torch.bfloat16, torch.bfloat16, torch.float32, torch.int64), - (torch.bfloat16, torch.bfloat16, torch.float32, torch.float32, torch.int64), - (torch.float16, torch.bfloat16, torch.float32, torch.float32, torch.int64), - (torch.float32, torch.bfloat16, torch.float32, torch.float32, torch.int64), ] - _ssu_dims = [64, 128, 256] - _ssu_dstates = [64, 128, 256] - _ssu_ntokens = [1, 2, 4, 6, 8, 12, 16] + _ssu_dims = [64] + _ssu_dstates = [128] + _ssu_ntokens = [1, 4, 6, 8] for dtype_combo, dim, dstate, ntokens in product( _ssu_dtype_combos, _ssu_dims, _ssu_dstates, _ssu_ntokens ):