diff --git a/.gitignore b/.gitignore index 1a89e54605..156be6f629 100644 --- a/.gitignore +++ b/.gitignore @@ -24,6 +24,10 @@ flashinfer/cute_dsl/benchmark_gated_delta_rule.py # vscode .vscode/ +# zed text editor +.zed/ +.rules + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/csrc/flashinfer_mamba_binding.cu b/csrc/flashinfer_mamba_binding.cu index 2e2453cefc..9461cbfa08 100644 --- a/csrc/flashinfer_mamba_binding.cu +++ b/csrc/flashinfer_mamba_binding.cu @@ -38,10 +38,13 @@ void selective_state_update( bool dt_softplus, Optional state_batch_indices, // (batch,) int64_t pad_slot_id, - TensorView output, // same as x + Optional state_scale, // float32: (state_cache_size, nheads, dim) + TensorView output, // same as x bool disable_state_update, Optional intermediate_states_buffer, // (batch, cache_steps, nheads, dim, dstate) Optional intermediate_state_indices, // (batch,) + Optional intermediate_state_scales, // float32: (batch, cache_steps, nheads, dim) + Optional rand_seed, // device-side int64 tensor for Philox rounding int64_t cache_steps, int64_t algorithm); // SSUAlgorithm: 0=auto, 1=simple, 2=vertical, 3=horizontal diff --git a/csrc/selective_state_update.cu b/csrc/selective_state_update.cu index 3918d3caf8..3a00f5ba9a 100644 --- a/csrc/selective_state_update.cu +++ b/csrc/selective_state_update.cu @@ -16,7 +16,7 @@ // clang-format off // config.inc MUST come before the header: it defines DIM, DSTATE, NTOKENS_MTP // constexprs that the header's function templates rely on. Reordering breaks compilation. 
-// NOTE: the .inc file is generated from the jinja templates +// NOTE: the .inc file is generated from the jinja template csrc/selective_state_update_customize_config.jinja #include "selective_state_update_config.inc" #include // clang-format on @@ -99,6 +99,22 @@ inline void validate_intermediate_states_buffer( CHECK_CONTIGUOUS(intermediate_states_buffer.value()); } +inline void validate_state_scale(Optional const& state_scale, int64_t state_cache_size, + int64_t nheads, int64_t dim) { + if (!state_scale.has_value()) return; + auto const& scale = state_scale.value(); + CHECK_CUDA(scale); + CHECK_DIM(3, scale); // state_scale: {state_cache_size, nheads, dim} + FLASHINFER_CHECK(scale.size(0) == state_cache_size, + "state_scale.size(0) must equal state_cache_size"); + FLASHINFER_CHECK(scale.size(1) == nheads, "state_scale.size(1) must equal nheads"); + FLASHINFER_CHECK(scale.size(2) == dim, "state_scale.size(2) must equal dim"); + // Inner dims (nheads, dim) must be contiguous + FLASHINFER_CHECK(scale.stride(2) == 1, "state_scale.stride(2) must be 1, got ", scale.stride(2)); + FLASHINFER_CHECK(scale.stride(1) == dim, "state_scale.stride(1) must equal dim, got ", + scale.stride(1)); +} + // Validates dtype consistency across tensors inline void validate_dtype_consistency( TensorView const& state, TensorView const& dt, TensorView const& D, TensorView const& x, @@ -133,8 +149,9 @@ void run_selective_state_update_stp(TensorView const& state, TensorView const& x TensorView const& C, TensorView const& D, Optional z, Optional dt_bias, bool dt_softplus, Optional state_batch_indices, - int64_t pad_slot_id, Optional out, - bool disable_state_update, int64_t algorithm) { + Optional state_scale, int64_t pad_slot_id, + Optional out, bool disable_state_update, + Optional rand_seed, int64_t algorithm) { // Extract dimensions from input tensors auto const batch = x.size(0); auto const state_cache_size = state.size(0); @@ -219,6 +236,7 @@ void 
run_selective_state_update_stp(TensorView const& state, TensorView const& x // Validate dtype consistency validate_dtype_consistency(state, dt, D, x, B, C, dt_bias, z, out); + validate_state_scale(state_scale, state_cache_size, nheads, dim); // Initialize params struct SelectiveStateUpdateParams p; @@ -248,6 +266,18 @@ void run_selective_state_update_stp(TensorView const& state, TensorView const& x if (state_batch_indices.has_value()) { p.state_batch_indices = const_cast(state_batch_indices.value().data_ptr()); } + if (state_scale.has_value()) { + p.state_scale = state_scale.value().data_ptr(); + p.state_scale_stride_batch = state_scale.value().stride(0); + } + if (rand_seed.has_value()) { + auto const& rs = rand_seed.value(); + CHECK_CUDA(rs); + FLASHINFER_CHECK(rs.numel() == 1, + "rand_seed must be a single-element tensor, got numel=", rs.numel()); + FLASHINFER_CHECK(rs.dtype().code == kDLInt && rs.dtype().bits == 64, "rand_seed must be int64"); + p.rand_seed = static_cast(rs.data_ptr()); + } // Copy pointers p.state = const_cast(state.data_ptr()); @@ -275,16 +305,18 @@ void run_selective_state_update_stp(TensorView const& state, TensorView const& x const cudaStream_t stream = get_stream(state.device()); auto algo = static_cast(algorithm); - invokeSelectiveStateUpdate(p, algo, stream); + invokeSelectiveStateUpdate( + p, algo, stream); } void run_selective_state_update_mtp( TensorView const& state, TensorView const& x, TensorView const& dt, TensorView const& A, TensorView const& B, TensorView const& C, TensorView const& D, Optional z, Optional dt_bias, bool dt_softplus, Optional state_batch_indices, - int64_t pad_slot_id, Optional out, bool disable_state_update, - Optional intermediate_states_buffer, - Optional intermediate_state_indices, int64_t cache_steps, int64_t algorithm) { + Optional state_scale, int64_t pad_slot_id, Optional out, + bool disable_state_update, Optional intermediate_states_buffer, + Optional intermediate_state_indices, Optional 
intermediate_state_scales, + Optional rand_seed, int64_t cache_steps, int64_t algorithm) { // Extract dimensions from input tensors auto const batch = x.size(0); auto const ntokens_mtp = x.size(1); @@ -378,6 +410,7 @@ void run_selective_state_update_mtp( validate_dtype_consistency(state, dt, D, x, B, C, dt_bias, z, out, intermediate_states_buffer); validate_intermediate_state_indices(intermediate_state_indices, batch); validate_intermediate_states_buffer(intermediate_states_buffer); + validate_state_scale(state_scale, state_cache_size, nheads, dim); // Validate that state_batch_indices and intermediate_state_indices have the same dtype if (state_batch_indices.has_value() && intermediate_state_indices.has_value()) { @@ -435,6 +468,10 @@ void run_selective_state_update_mtp( if (state_batch_indices.has_value()) { p.state_batch_indices = const_cast(state_batch_indices.value().data_ptr()); } + if (state_scale.has_value()) { + p.state_scale = state_scale.value().data_ptr(); + p.state_scale_stride_batch = state_scale.value().stride(0); + } if (intermediate_states_buffer.has_value()) { p.intermediate_states = const_cast(intermediate_states_buffer.value().data_ptr()); @@ -445,6 +482,30 @@ void run_selective_state_update_mtp( p.intermediate_state_indices = const_cast(intermediate_state_indices.value().data_ptr()); } + if (intermediate_state_scales.has_value()) { + auto const& iscales = intermediate_state_scales.value(); + CHECK_CUDA(iscales); + CHECK_CONTIGUOUS(iscales); + CHECK_DIM(4, iscales); // (batch, cache_steps, nheads, dim) + FLASHINFER_CHECK(iscales.size(0) == batch, + "intermediate_state_scales.size(0) must equal batch"); + FLASHINFER_CHECK(iscales.size(1) == cache_steps, + "intermediate_state_scales.size(1) must equal cache_steps"); + FLASHINFER_CHECK(iscales.size(2) == nheads, + "intermediate_state_scales.size(2) must equal nheads"); + FLASHINFER_CHECK(iscales.size(3) == dim, "intermediate_state_scales.size(3) must equal dim"); + p.intermediate_state_scales = 
iscales.data_ptr(); + p.intermediate_state_scales_stride_batch = iscales.stride(0); + } + if (rand_seed.has_value()) { + auto const& rs = rand_seed.value(); + CHECK_CUDA(rs); + FLASHINFER_CHECK(rs.numel() == 1, + "rand_seed must be a single-element tensor, got numel=", rs.numel()); + FLASHINFER_CHECK(rs.dtype().code == kDLInt && rs.dtype().bits == 64, "rand_seed must be int64"); + p.rand_seed = static_cast(rs.data_ptr()); + } + // Copy pointers p.state = const_cast(state.data_ptr()); p.x = const_cast(x.data_ptr()); @@ -472,30 +533,29 @@ void run_selective_state_update_mtp( const cudaStream_t stream = get_stream(state.device()); auto algo = static_cast(algorithm); - mtp::invokeSelectiveStateUpdateMTP(p, algo, - stream); + mtp::invokeSelectiveStateUpdateMTP(p, algo, stream); } // ============================================================================= // Generic dispatcher - routes to single-token or multi-token based on x.dim() // ============================================================================= -void selective_state_update(TensorView state, TensorView x, TensorView dt, TensorView A, - TensorView B, TensorView C, TensorView D, Optional z, - Optional dt_bias, bool dt_softplus, - Optional state_batch_indices, int64_t pad_slot_id, - TensorView output, bool disable_state_update, - Optional intermediate_states_buffer, - Optional intermediate_state_indices, int64_t cache_steps, - int64_t algorithm) { +void selective_state_update( + TensorView state, TensorView x, TensorView dt, TensorView A, TensorView B, TensorView C, + TensorView D, Optional z, Optional dt_bias, bool dt_softplus, + Optional state_batch_indices, int64_t pad_slot_id, Optional state_scale, + TensorView output, bool disable_state_update, Optional intermediate_states_buffer, + Optional intermediate_state_indices, Optional intermediate_state_scales, + Optional rand_seed, int64_t cache_steps, int64_t algorithm) { if (x.dim() == 3) { run_selective_state_update_stp(state, x, dt, A, B, C, D, z, 
dt_bias, dt_softplus, - state_batch_indices, pad_slot_id, output, disable_state_update, - algorithm); + state_batch_indices, state_scale, pad_slot_id, output, + disable_state_update, rand_seed, algorithm); } else if (x.dim() == 4) { - run_selective_state_update_mtp(state, x, dt, A, B, C, D, z, dt_bias, dt_softplus, - state_batch_indices, pad_slot_id, output, disable_state_update, - intermediate_states_buffer, intermediate_state_indices, - cache_steps, algorithm); + run_selective_state_update_mtp( + state, x, dt, A, B, C, D, z, dt_bias, dt_softplus, state_batch_indices, state_scale, + pad_slot_id, output, disable_state_update, intermediate_states_buffer, + intermediate_state_indices, intermediate_state_scales, rand_seed, cache_steps, algorithm); } else { FLASHINFER_CHECK(false, "x must have 3 dimensions (single-token) or 4 dimensions (multi-token), got ", diff --git a/csrc/selective_state_update_customize_config.jinja b/csrc/selective_state_update_customize_config.jinja index 418356212d..4e6fb6caa4 100644 --- a/csrc/selective_state_update_customize_config.jinja +++ b/csrc/selective_state_update_customize_config.jinja @@ -8,7 +8,13 @@ using input_t = {{ input_dtype }}; using weight_t = {{ weight_dtype }}; using matrixA_t = {{ matrixA_dtype }}; using stateIndex_t = {{ stateIndex_dtype }}; +// Type for block-scale decode factors (e.g. float, __half). +// void = no scaling (state_t is used as-is). +using state_scale_t = {{ state_scale_type }}; constexpr int DIM = {{ dim }}; constexpr int DSTATE = {{ dstate }}; constexpr int NTOKENS_MTP = {{ ntokens_mtp }}; +// Philox PRNG rounds for stochastic rounding of fp16 state stores. +// 0 = no stochastic rounding; typical value = 10. 
+constexpr int PHILOX_ROUNDS = {{ philox_rounds }}; diff --git a/csrc/selective_state_update_kernel_inst.cu b/csrc/selective_state_update_kernel_inst.cu index 6dcec72a5d..c734fa6d13 100644 --- a/csrc/selective_state_update_kernel_inst.cu +++ b/csrc/selective_state_update_kernel_inst.cu @@ -7,12 +7,14 @@ namespace flashinfer::mamba { -template void invokeSelectiveStateUpdate( - SelectiveStateUpdateParams&, SSUAlgorithm, cudaStream_t); +template void invokeSelectiveStateUpdate(SelectiveStateUpdateParams&, SSUAlgorithm, + cudaStream_t); namespace mtp { -template void invokeSelectiveStateUpdateMTP( - SelectiveStateMTPParams&, SSUAlgorithm, cudaStream_t); +template void invokeSelectiveStateUpdateMTP(SelectiveStateMTPParams&, SSUAlgorithm, + cudaStream_t); } // namespace mtp } // namespace flashinfer::mamba diff --git a/flashinfer/aot.py b/flashinfer/aot.py index f11ac238bb..7a727c0fa3 100644 --- a/flashinfer/aot.py +++ b/flashinfer/aot.py @@ -548,15 +548,32 @@ def gen_all_modules( ] # selective_state_update: one module per dtype combo per GPU arch _ssu_dtype_combos = [ - # (state, input, weight, matrixA, stateIndex) + # (state, input, weight, matrixA, stateIndex, state_scale_dtype) ( torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.float32, torch.int64, + None, + ), + # int16 state (block-scaled quantization, scale stored as float32) + ( + torch.int16, + torch.bfloat16, + torch.bfloat16, + torch.float32, + torch.int64, + torch.float32, + ), + ( + torch.float32, + torch.bfloat16, + torch.bfloat16, + torch.float32, + torch.int64, + None, ), - (torch.float32, torch.bfloat16, torch.bfloat16, torch.float32, torch.int64), ] _ssu_dims = [64] _ssu_dstates = [128] diff --git a/flashinfer/jit/mamba/selective_state_update.py b/flashinfer/jit/mamba/selective_state_update.py index bba5c3d375..15b1e13094 100644 --- a/flashinfer/jit/mamba/selective_state_update.py +++ b/flashinfer/jit/mamba/selective_state_update.py @@ -15,6 +15,7 @@ """ import os +from typing import Optional 
import jinja2 import torch @@ -29,6 +30,7 @@ torch.float16: "half", torch.bfloat16: "nv_bfloat16", torch.float32: "float", + torch.int16: "int16_t", torch.int32: "int32_t", torch.int64: "int64_t", } @@ -38,6 +40,7 @@ torch.float16: "f16", torch.bfloat16: "bf16", torch.float32: "f32", + torch.int16: "i16", torch.int32: "i32", torch.int64: "i64", } @@ -49,17 +52,24 @@ def get_selective_state_update_uri( weight_dtype: torch.dtype, matrixA_dtype: torch.dtype, stateIndex_dtype: torch.dtype, + state_scale_dtype: Optional[torch.dtype], dim: int, dstate: int, ntokens_mtp: int, + philox_rounds: int = 0, ) -> str: s = _filename_safe_dtype_map - return ( + uri = ( f"selective_state_update_" f"s_{s[state_dtype]}_i_{s[input_dtype]}_w_{s[weight_dtype]}_" f"a_{s[matrixA_dtype]}_si_{s[stateIndex_dtype]}_" f"d_{dim}_ds_{dstate}_nt_{ntokens_mtp}" ) + if state_scale_dtype is not None: + uri += f"_sc_{s[state_scale_dtype]}" + if philox_rounds > 0: + uri += f"_pr_{philox_rounds}" + return uri def _gen_module( @@ -69,9 +79,11 @@ def _gen_module( weight_dtype: torch.dtype, matrixA_dtype: torch.dtype, stateIndex_dtype: torch.dtype, + state_scale_dtype: Optional[torch.dtype], dim: int, dstate: int, ntokens_mtp: int, + philox_rounds: int = 0, extra_cuda_cflags: list = None, ) -> JitSpec: gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri @@ -83,6 +95,9 @@ def _gen_module( ) as f: config_templ = jinja2.Template(f.read()) + state_scale_type = ( + _dtype_map[state_scale_dtype] if state_scale_dtype is not None else "void" + ) config_str = config_templ.render( state_dtype=_dtype_map[state_dtype], input_dtype=_dtype_map[input_dtype], @@ -92,6 +107,8 @@ def _gen_module( dim=dim, dstate=dstate, ntokens_mtp=ntokens_mtp, + state_scale_type=state_scale_type, + philox_rounds=philox_rounds, ) write_if_different(gen_directory / "selective_state_update_config.inc", config_str) @@ -122,9 +139,11 @@ def gen_selective_state_update_module( weight_dtype: torch.dtype, matrixA_dtype: torch.dtype, 
stateIndex_dtype: torch.dtype, + state_scale_dtype: Optional[torch.dtype], dim: int, dstate: int, ntokens_mtp: int, + philox_rounds: int = 0, ) -> JitSpec: uri = get_selective_state_update_uri( state_dtype, @@ -132,9 +151,11 @@ def gen_selective_state_update_module( weight_dtype, matrixA_dtype, stateIndex_dtype, + state_scale_dtype, dim, dstate, ntokens_mtp, + philox_rounds, ) return _gen_module( uri, @@ -143,9 +164,12 @@ def gen_selective_state_update_module( weight_dtype, matrixA_dtype, stateIndex_dtype, + state_scale_dtype, dim, dstate, ntokens_mtp, + philox_rounds=philox_rounds, + extra_cuda_cflags=["-lineinfo"], ) @@ -155,9 +179,11 @@ def gen_selective_state_update_sm90_module( weight_dtype: torch.dtype, matrixA_dtype: torch.dtype, stateIndex_dtype: torch.dtype, + state_scale_dtype: Optional[torch.dtype], dim: int, dstate: int, ntokens_mtp: int, + philox_rounds: int = 0, ) -> JitSpec: uri = ( get_selective_state_update_uri( @@ -166,9 +192,11 @@ def gen_selective_state_update_sm90_module( weight_dtype, matrixA_dtype, stateIndex_dtype, + state_scale_dtype, dim, dstate, ntokens_mtp, + philox_rounds, ) + "_sm90" ) @@ -184,8 +212,10 @@ def gen_selective_state_update_sm90_module( weight_dtype, matrixA_dtype, stateIndex_dtype, + state_scale_dtype, dim, dstate, ntokens_mtp, + philox_rounds=philox_rounds, extra_cuda_cflags=nvcc_flags, ) diff --git a/flashinfer/mamba/selective_state_update.py b/flashinfer/mamba/selective_state_update.py index 294330be88..92447a923e 100644 --- a/flashinfer/mamba/selective_state_update.py +++ b/flashinfer/mamba/selective_state_update.py @@ -38,6 +38,8 @@ def _get_module( dstate: int, ntokens_mtp: int, sm_major: int, + state_scale_dtype: Optional[torch.dtype] = None, + philox_rounds: int = 0, ): args = ( state_dtype, @@ -45,9 +47,11 @@ def _get_module( weight_dtype, matrixA_dtype, stateIndex_dtype, + state_scale_dtype, dim, dstate, ntokens_mtp, + philox_rounds, ) if sm_major >= 9: return 
gen_selective_state_update_sm90_module(*args).build_and_load() @@ -65,6 +69,8 @@ def get_selective_state_update_module( dim: int, dstate: int, ntokens_mtp: int, + state_scale_dtype: Optional[torch.dtype] = None, + philox_rounds: int = 0, ): major, _ = get_compute_capability(device) return _get_module( @@ -77,6 +83,8 @@ def get_selective_state_update_module( dstate, ntokens_mtp, major, + state_scale_dtype, + philox_rounds, ) @@ -94,10 +102,14 @@ def selective_state_update( dt_softplus: bool = False, state_batch_indices: Optional[torch.Tensor] = None, pad_slot_id: int = -1, + state_scale: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, disable_state_update: bool = False, intermediate_states_buffer: Optional[torch.Tensor] = None, intermediate_state_indices: Optional[torch.Tensor] = None, + intermediate_state_scales: Optional[torch.Tensor] = None, + rand_seed: Optional[torch.Tensor] = None, + philox_rounds: int = 10, cache_steps: int = 0, algorithm: str = "auto", ) -> torch.Tensor: @@ -136,6 +148,9 @@ def selective_state_update( If state_batch_indices is passed, lets the kernel identify padded entries that will not be processed. For example: state_batch_indices = [pad_slot_id, 1, 20, pad_slot_id] in this case, the kernel will not process entries at indices 0 and 3 + state_scale : Optional[torch.Tensor] + Optional float32 scale tensor with shape (state_cache_size, nheads, dim) + for int16 state quantization with block scaling out : Optional[torch.Tensor] Optional output tensor (same shape as x) disable_state_update : bool @@ -146,6 +161,17 @@ def selective_state_update( intermediate_state_indices : Optional[torch.Tensor] Optional indices mapping batch elements to intermediate state buffer positions with shape (batch,) + rand_seed : Optional[torch.Tensor] + Optional single-element int64 CUDA tensor for stochastic rounding seed (Philox-4x32 PRNG). 
+ Using a device-side tensor (rather than a host integer) ensures CUDA graph compatibility, + since the graph captures the pointer and the seed value can be updated between replays. + When provided, state values are stochastically rounded before storing to fp16. + When None, no stochastic rounding is applied (regardless of philox_rounds). + Cannot be used together with state_scale. + philox_rounds : int + Number of Philox-4x32 PRNG rounds for stochastic rounding (default 10, + matching Triton's tl.randint). Only effective when rand_seed is not None; + ignored otherwise. Must be non-negative. cache_steps : int Number of steps/tokens to cache for speculative decoding algorithm : str @@ -200,6 +226,34 @@ def selective_state_update( z = z.unsqueeze(1) if is_mtp and z.dim() == 3: z = z.unsqueeze(1) + # Normalize state_scale to 3D: (state_cache_size, nheads, dim) + if state_scale is not None and state_scale.dim() == 4 and state_scale.size(-1) == 1: + state_scale = state_scale.squeeze(-1) + + # Validate rand_seed and philox_rounds + if rand_seed is not None: + if not isinstance(rand_seed, torch.Tensor): + raise TypeError( + f"rand_seed must be a CUDA int64 tensor, got {type(rand_seed).__name__}" + ) + if rand_seed.numel() != 1: + raise ValueError( + f"rand_seed must be a single-element tensor, got numel={rand_seed.numel()}" + ) + if rand_seed.dtype != torch.int64: + raise ValueError(f"rand_seed must have dtype int64, got {rand_seed.dtype}") + if not rand_seed.is_cuda: + raise ValueError("rand_seed must be a CUDA tensor") + if state_scale is not None: + raise ValueError("rand_seed and state_scale cannot both be provided") + if philox_rounds <= 0: + raise ValueError( + f"philox_rounds must be > 0 when rand_seed is provided, got {philox_rounds}" + ) + else: + # No stochastic rounding when rand_seed is None + philox_rounds = 0 + if out is None: output = torch.empty_like(x) else: @@ -241,12 +295,16 @@ def selective_state_update( dt_softplus, state_batch_indices, pad_slot_id, + 
state_scale, output, disable_state_update, intermediate_states_buffer, intermediate_state_indices, + intermediate_state_scales, + rand_seed, cache_steps, algorithm_int, + philox_rounds, state.dtype, x.dtype, dt.dtype, @@ -261,7 +319,13 @@ def selective_state_update( @register_custom_op( "flashinfer::selective_state_update", - mutates_args=("state", "output", "intermediate_states_buffer"), + mutates_args=( + "state", + "output", + "intermediate_states_buffer", + "state_scale", + "intermediate_state_scales", + ), ) def _selective_state_update( state: torch.Tensor, @@ -276,12 +340,16 @@ def _selective_state_update( dt_softplus: bool, state_batch_indices: Optional[torch.Tensor], pad_slot_id: int, + state_scale: Optional[torch.Tensor], output: torch.Tensor, disable_state_update: bool, intermediate_states_buffer: Optional[torch.Tensor], intermediate_state_indices: Optional[torch.Tensor], + intermediate_state_scales: Optional[torch.Tensor], + rand_seed: Optional[torch.Tensor], cache_steps: int, algorithm: int, + philox_rounds: int, state_dtype: torch.dtype, input_dtype: torch.dtype, weight_dtype: torch.dtype, @@ -302,6 +370,8 @@ def _selective_state_update( dim, dstate, ntokens_mtp, + state_scale_dtype=state_scale.dtype if state_scale is not None else None, + philox_rounds=philox_rounds, ).selective_state_update( state, x, @@ -315,10 +385,13 @@ def _selective_state_update( dt_softplus, state_batch_indices, pad_slot_id, + state_scale, output, disable_state_update, intermediate_states_buffer, intermediate_state_indices, + intermediate_state_scales, + rand_seed, cache_steps, algorithm, ) @@ -338,12 +411,16 @@ def _selective_state_update_fake( dt_softplus: bool, state_batch_indices: Optional[torch.Tensor], pad_slot_id: int, + state_scale: Optional[torch.Tensor], output: torch.Tensor, disable_state_update: bool, intermediate_states_buffer: Optional[torch.Tensor], intermediate_state_indices: Optional[torch.Tensor], + intermediate_state_scales: Optional[torch.Tensor], + 
rand_seed: Optional[torch.Tensor], cache_steps: int, algorithm: int, + philox_rounds: int, state_dtype: torch.dtype, input_dtype: torch.dtype, weight_dtype: torch.dtype, diff --git a/include/flashinfer/mamba/common.cuh b/include/flashinfer/mamba/common.cuh index f7e89029de..fb0b9a0830 100644 --- a/include/flashinfer/mamba/common.cuh +++ b/include/flashinfer/mamba/common.cuh @@ -67,6 +67,13 @@ __device__ __forceinline__ float warpReduceSum(float val) { return val; } +__device__ __forceinline__ float warpReduceMax(float val) { + for (int s = warpSize / 2; s > 0; s /= 2) { + val = max(val, __shfl_down_sync(UINT32_MAX, val, s)); + } + return val; +} + __forceinline__ __device__ float softplus(float x) { return __logf(1.f + __expf(x)); } __device__ __forceinline__ float thresholded_softplus(float dt_value) { diff --git a/include/flashinfer/mamba/conversion.cuh b/include/flashinfer/mamba/conversion.cuh index 7e8a84bf8a..100538d9ec 100644 --- a/include/flashinfer/mamba/conversion.cuh +++ b/include/flashinfer/mamba/conversion.cuh @@ -18,6 +18,10 @@ inline __device__ float toFloat(__half h) { return __half2float(h); } inline __device__ float toFloat(__nv_bfloat16 val) { return __bfloat162float(val); } #endif +// No accuracy loss: int16_t range [-32768, 32767] fits exactly in float32 +// (24-bit mantissa represents all integers up to 2^24 = 16M exactly). +inline __device__ float toFloat(int16_t val) { return static_cast(val); } + inline __device__ void convertAndStore(float* output, float input) { *output = input; } inline __device__ void convertAndStore(__half* output, float input) { @@ -30,4 +34,151 @@ inline __device__ void convertAndStore(__nv_bfloat16* output, float input) { } #endif +inline __device__ void convertAndStore(int16_t* output, float input) { + // Symmetric clip: [-max, max] (not [-max-1, max]) so that negation is safe. + // Matches Triton reference which clips to [-32767, 32767] before storing. 
+ constexpr float int16_max = static_cast(std::numeric_limits::max()); + input = fminf(fmaxf(input, -int16_max), int16_max); + *output = static_cast(__float2int_rn(input)); +} + +// ============================================================================= +// Philox-4x32 PRNG (matches Triton's tl.randint) +// ============================================================================= + +// Generates four pseudorandom uint32s from (seed, offset) using the Philox-4x32 algorithm. +// Produces bit-identical output to Triton's tl.randint4x(seed, offset, n_rounds). +// All four outputs (c0..c3) are independent and uniformly distributed. +template +__device__ __forceinline__ void philox_randint4x(int64_t seed, uint32_t offset, uint32_t& r0, + uint32_t& r1, uint32_t& r2, uint32_t& r3) { + constexpr uint32_t PHILOX_KEY_A = 0x9E3779B9u; + constexpr uint32_t PHILOX_KEY_B = 0xBB67AE85u; + constexpr uint32_t PHILOX_ROUND_A = 0xD2511F53u; + constexpr uint32_t PHILOX_ROUND_B = 0xCD9E8D57u; + + uint32_t k0 = static_cast(static_cast(seed)); + uint32_t k1 = static_cast(static_cast(seed) >> 32); + uint32_t c0 = offset, c1 = 0, c2 = 0, c3 = 0; + +#pragma unroll + for (int i = 0; i < n_rounds; i++) { + uint32_t _c0 = c0, _c2 = c2; + c0 = __umulhi(PHILOX_ROUND_B, _c2) ^ c1 ^ k0; + c2 = __umulhi(PHILOX_ROUND_A, _c0) ^ c3 ^ k1; + c1 = PHILOX_ROUND_B * _c2; + c3 = PHILOX_ROUND_A * _c0; + k0 += PHILOX_KEY_A; + k1 += PHILOX_KEY_B; + } + r0 = c0; + r1 = c1; + r2 = c2; + r3 = c3; +} + +// Generates a pseudorandom uint32 from (seed, offset) using the Philox-4x32 algorithm. +// Produces bit-identical output to Triton's tl.randint(seed, offset, n_rounds). +// NOTE: This discards 3 of the 4 Philox outputs. For better throughput, use +// philox_randint4x to get all 4 outputs from a single Philox invocation. 
+template +__device__ __forceinline__ uint32_t philox_randint(int64_t seed, uint32_t offset) { + uint32_t r0, r1, r2, r3; + philox_randint4x(seed, offset, r0, r1, r2, r3); + return r0; +} + +// ============================================================================= +// Stochastic rounding: fp32 → fp16 +// ============================================================================= + +// Software stochastic rounding: convert one fp32 value to fp16 using 13 random bits. +// Adds random noise at the sub-fp16-mantissa position, then truncates. +// rand13: 13-bit random value in bits [12:0]. +__device__ __forceinline__ uint16_t cvt_rs_f16_sw(float x, uint32_t rand13) { + uint32_t bits = __float_as_uint(x); + uint32_t sign = bits & 0x80000000u; + uint32_t abs_bits = bits & 0x7FFFFFFFu; + + // fp32 has 23 mantissa bits, fp16 has 10. The 13 LSBs are the remainder. + // Add 13-bit random noise at bits [12:0]. Carry into bit 13 → round up. + abs_bits += (rand13 & 0x1FFFu); + + // Convert to fp16 by truncation. + uint32_t f32_exp = (abs_bits >> 23) & 0xFFu; + uint32_t f32_mantissa = abs_bits & 0x7FFFFFu; + + uint16_t f16_bits; + if (f32_exp == 0xFF) { + f16_bits = (f32_mantissa != 0) ? 0x7E00u : 0x7C00u; // NaN or Inf + } else if (f32_exp > 142) { // 127 + 15 = 142 → overflow to Inf + f16_bits = 0x7C00u; + } else if (f32_exp < 113) { // 127 - 14 = 113 → underflow to zero + f16_bits = 0; + } else { + uint16_t f16_exp = static_cast(f32_exp - 112); // rebias: 127→15 + uint16_t f16_mantissa = static_cast(f32_mantissa >> 13); + f16_bits = (f16_exp << 10) | f16_mantissa; + } + + return static_cast(sign >> 16) | f16_bits; +} + +// Forward declaration (defined below, after cvt_rs_f16x2_f32). +__device__ __forceinline__ uint32_t cvt_rs_f16x2_f32(float a, float b, uint32_t rbits); + +// Stochastic rounding: convert one fp32 value to fp16 using 13 random bits. +// On sm_100a+: uses PTX cvt.rs.f16x2.f32 with a dummy zero second input. +// On other archs: software emulation. 
+__device__ __forceinline__ __half cvt_rs_f16_f32(float x, uint32_t rand13) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000 && defined(__CUDA_ARCH_FEAT_SM100_ALL) + // Pack rand13 into rbits[12:0] (for PTX operand b → low half → our x). + // High half gets zero noise for the dummy input. + uint32_t rbits = rand13 & 0x1FFFu; + uint32_t packed = cvt_rs_f16x2_f32(x, 0.0f, rbits); + return __ushort_as_half(static_cast(packed & 0xFFFFu)); +#else + return __ushort_as_half(cvt_rs_f16_sw(x, rand13)); +#endif +} + +// Stochastic rounding: convert two fp32 values to packed fp16x2 using random bits. +// On sm_100a+: uses PTX cvt.rs.f16x2.f32 instruction. +// On other archs: software emulation matching the hardware behavior. +// +// rbits layout (from PTX docs): +// bits [28:16] = 13 random bits for PTX operand "a" (→ d[31:16], high half) +// bits [12:0] = 13 random bits for PTX operand "b" (→ d[15:0], low half) +// bits [31:29] and [15:13] = unused (zero) +// from: https://docs.nvidia.com/cuda/parallel-thread-execution/#cvt-rs-rbits-layout-f16 +// +// Our asm maps: %1→C++ a→PTX b→d[15:0], %2→C++ b→PTX a→d[31:16] +// So: C++ a uses rbits[12:0], C++ b uses rbits[28:16]. 
+__device__ __forceinline__ uint32_t cvt_rs_f16x2_f32(float a, float b, uint32_t rbits) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000 && defined(__CUDA_ARCH_FEAT_SM100_ALL) + uint32_t packed; + asm("cvt.rs.f16x2.f32 %0, %2, %1, %3;" + : "=r"(packed) + : "r"(__float_as_uint(a)), "r"(__float_as_uint(b)), "r"(rbits)); + return packed; +#else + uint32_t rand_a = rbits & 0x1FFFu; // bits [12:0] → C++ a (PTX b → low half) + uint32_t rand_b = (rbits >> 16) & 0x1FFFu; // bits [28:16] → C++ b (PTX a → high half) + uint16_t a_fp16 = __half_as_ushort(cvt_rs_f16_f32(a, rand_a)); + uint16_t b_fp16 = __half_as_ushort(cvt_rs_f16_f32(b, rand_b)); + return static_cast(a_fp16) | (static_cast(b_fp16) << 16); +#endif +} + +// Stochastic rounding store: generates Philox random bits and converts fp32 → fp16 in one call. +// PHILOX_ROUNDS: number of Philox rounds (compile-time), must be > 0. +// seed: Philox seed (from params.rand_seed). +// offset: unique per-element offset (e.g. d * DSTATE + i) for deterministic randomness. 
+template +inline __device__ void convertSRAndStore(__half* output, float input, int64_t seed, + uint32_t offset) { + uint32_t rand = philox_randint(seed, offset); + *output = cvt_rs_f16_f32(input, rand & 0x1FFFu); +} + } // namespace flashinfer::mamba::conversion diff --git a/include/flashinfer/mamba/create_tensor_map.cuh b/include/flashinfer/mamba/create_tensor_map.cuh index 8d08cff69d..fcf7a84430 100644 --- a/include/flashinfer/mamba/create_tensor_map.cuh +++ b/include/flashinfer/mamba/create_tensor_map.cuh @@ -8,6 +8,8 @@ #include #include +#include "../utils.cuh" + namespace flashinfer::mamba::tma { inline CUtensorMap buildNdDescriptor(std::type_info const& dtype, @@ -27,6 +29,9 @@ inline CUtensorMap buildNdDescriptor(std::type_info const& dtype, } else if (dtype == typeid(__nv_bfloat16)) { tmaDataFormat = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; dtype_size = sizeof(__nv_bfloat16); + } else if (dtype == typeid(int16_t)) { + tmaDataFormat = CU_TENSOR_MAP_DATA_TYPE_UINT16; + dtype_size = sizeof(int16_t); } else { throw std::invalid_argument("buildNdDescriptor: unsupported dtype"); } @@ -34,9 +39,8 @@ inline CUtensorMap buildNdDescriptor(std::type_info const& dtype, // The swizzle type. 
CUtensorMapSwizzle swizzleType{CU_TENSOR_MAP_SWIZZLE_NONE}; - // Check gmem address must be 16B-aligned - FLASHINFER_CHECK((reinterpret_cast(gmemAddr) & 0b1111) == 0, - "Tensor must be 16B-aligned"); + // Check gmem address must be 128B-aligned for TMA + FLASHINFER_CHECK_TMA_ALIGNED(gmemAddr); // Check shape must be in range [1, 2^32] int32_t dim = shapes.size(); diff --git a/include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh b/include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh index af86b8094a..37c96a302e 100644 --- a/include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh +++ b/include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh @@ -5,6 +5,7 @@ #include #include #include +#include #include "../utils.cuh" #include "../vec_dtypes.cuh" @@ -16,19 +17,27 @@ namespace flashinfer::mamba::mtp { using namespace conversion; -template +template struct SharedStorageSimple { - input_t x[TOKENS_MTP][DIM]; - float out[TOKENS_MTP][DIM]; - input_t z[TOKENS_MTP][DIM]; - input_t B[TOKENS_MTP][DSTATE]; - input_t C[TOKENS_MTP][DSTATE]; - state_t state[STATE_ROWS][DSTATE]; + static constexpr bool scaleState = !std::is_same_v; + alignas(alignof(PackedAligned)) input_t x[TOKENS_MTP][ROWS_PER_BLOCK]; + alignas(alignof(PackedAligned)) float out[TOKENS_MTP][ROWS_PER_BLOCK]; + alignas(alignof(PackedAligned)) input_t z[TOKENS_MTP][ROWS_PER_BLOCK]; + alignas(alignof(PackedAligned)) input_t B[TOKENS_MTP][DSTATE]; + alignas(alignof(PackedAligned)) input_t C[TOKENS_MTP][DSTATE]; + alignas(alignof(PackedAligned)) state_t state[STATE_ROWS][DSTATE]; + alignas(alignof(PackedAligned)) + std::conditional_t state_scale[STATE_ROWS]; }; +// Grid: (batch, nheads, cdiv(DIM, ROWS_PER_BLOCK)) +// When ROWS_PER_BLOCK == DIM, degenerates to the non-tiled case (blockIdx.z == 0 always). 
template + typename stateIndex_t, typename state_scale_t, int TOKENS_MTP, int DIM, int DSTATE, + int ROWS_PER_BLOCK, int PHILOX_ROUNDS, int numWarps> __global__ void selective_state_update_kernel_simple_mtp(SelectiveStateMTPParams params) { + constexpr bool scaleState = !std::is_same_v; auto* __restrict__ output = reinterpret_cast(params.output); auto* __restrict__ state = reinterpret_cast(params.state); auto* __restrict__ intermediate_states = reinterpret_cast(params.intermediate_states); @@ -52,22 +61,33 @@ __global__ void selective_state_update_kernel_simple_mtp(SelectiveStateMTPParams auto const batch = blockIdx.x; auto const head = blockIdx.y; + auto const dim_offset = blockIdx.z * ROWS_PER_BLOCK; auto const group = head / (nheads / ngroups); auto lane = threadIdx.x % warpSize; auto warp = threadIdx.y; + // State scale pointer (only used when scaleState == true) + [[maybe_unused]] auto* __restrict__ state_scale = + reinterpret_cast(params.state_scale); + + // Load device-side Philox seed once into a register + [[maybe_unused]] int64_t const rand_seed = params.rand_seed ? *params.rand_seed : 0; + auto const state_batch = (state_batch_indices) ? state_batch_indices[batch] : batch; auto const intermediate_cache_idx = intermediate_state_indices ? 
intermediate_state_indices[batch] : state_batch; - state += state_batch * params.state_stride_batch + head * DIM * DSTATE; + auto const state_ptr_offset = state_batch * params.state_stride_batch + head * DIM * DSTATE; + state += state_ptr_offset; + if constexpr (scaleState) { + state_scale += state_batch * params.state_scale_stride_batch + head * DIM; + } constexpr auto stateRowsPerWarpPerStage = 4; constexpr auto stageRows = stateRowsPerWarpPerStage * numWarps; extern __shared__ __align__(128) char smem[]; - auto& sram = - *reinterpret_cast*>( - smem); + auto& sram = *reinterpret_cast*>(smem); static constexpr auto stateLoadSize = getVectorLoadSizeForFullUtilization(); using load_state_t = PackedAligned; @@ -81,10 +101,14 @@ __global__ void selective_state_update_kernel_simple_mtp(SelectiveStateMTPParams // Loop over multiple tokens if (warp == 0) { // Load x: gmem -> smem for (int mtp_step = 0; mtp_step < TOKENS_MTP; mtp_step++) { - for (auto d = lane * load_input_t::count; d < DIM; d += warpSize * load_input_t::count) { - auto* dst = reinterpret_cast(&sram.x[mtp_step][d]); - *dst = *reinterpret_cast( - &x[batch * params.x_stride_batch + mtp_step * params.x_stride_mtp + head * DIM + d]); + for (auto d = lane * load_input_t::count; d < ROWS_PER_BLOCK; + d += warpSize * load_input_t::count) { + if (dim_offset + d < DIM) { + auto* dst = reinterpret_cast(&sram.x[mtp_step][d]); + *dst = *reinterpret_cast( + &x[batch * params.x_stride_batch + mtp_step * params.x_stride_mtp + head * DIM + + dim_offset + d]); + } } } } else if (warp == 1) { // Load B: gmem -> smem @@ -98,12 +122,15 @@ __global__ void selective_state_update_kernel_simple_mtp(SelectiveStateMTPParams } } else if (warp == 2) { // Load z: gmem -> smem for (int mtp_step = 0; mtp_step < TOKENS_MTP; mtp_step++) { - for (auto d = lane * load_input_t::count; d < DIM; d += warpSize * load_input_t::count) { - auto* dst = reinterpret_cast(&sram.z[mtp_step][d]); - *dst = z ? 
*reinterpret_cast( - &z[batch * params.z_stride_batch + mtp_step * params.z_stride_mtp + - head * DIM + d]) - : make_zeros(); + for (auto d = lane * load_input_t::count; d < ROWS_PER_BLOCK; + d += warpSize * load_input_t::count) { + if (dim_offset + d < DIM) { + auto* dst = reinterpret_cast(&sram.z[mtp_step][d]); + *dst = z ? *reinterpret_cast( + &z[batch * params.z_stride_batch + mtp_step * params.z_stride_mtp + + head * DIM + dim_offset + d]) + : make_zeros(); + } } } } @@ -132,21 +159,31 @@ __global__ void selective_state_update_kernel_simple_mtp(SelectiveStateMTPParams __syncthreads(); - for (auto dBegin = 0; dBegin < DIM; dBegin += stageRows) { + for (auto dBegin = 0; dBegin < ROWS_PER_BLOCK; dBegin += stageRows) { // Load state gmem -> smem for (int warpRow = 0; warpRow < stateRowsPerWarpPerStage; warpRow++) { auto dd = warp * stateRowsPerWarpPerStage + warpRow; auto d = dBegin + dd; - if (d < DIM) { + if (dim_offset + d < DIM) { if (state_batch != params.pad_slot_id) { for (int i = lane * load_state_t::count; i < DSTATE; i += warpSize * load_state_t::count) { auto* dst = reinterpret_cast(&sram.state[dd][i]); - *dst = *reinterpret_cast(&state[d * DSTATE + i]); + *dst = *reinterpret_cast(&state[(dim_offset + d) * DSTATE + i]); } } } } + // Load state_scale gmem -> smem (contiguous across warpRows) + if constexpr (scaleState) { + for (int warpRow = lane; warpRow < stateRowsPerWarpPerStage; warpRow += warpSize) { + auto dd = warp * stateRowsPerWarpPerStage + warpRow; + auto d = dBegin + dd; + if (dim_offset + d < DIM && state_batch != params.pad_slot_id) { + sram.state_scale[dd] = state_scale[dim_offset + d]; + } + } + } // Compute how many input_t elements to pack per SRAM load based on DSTATE/warpSize ratio constexpr auto stateValuesPerThread = DSTATE / warpSize; @@ -164,23 +201,28 @@ __global__ void selective_state_update_kernel_simple_mtp(SelectiveStateMTPParams for (int warpRow = 0; warpRow < stateRowsPerWarpPerStage; warpRow++) { auto dd = warp * 
stateRowsPerWarpPerStage + warpRow; - auto d = dBegin + dd; + auto d = dim_offset + dBegin + dd; // global DIM index if (d >= DIM) break; // Load state smem -> rmem // There is a bank conflict here, but we are not in a hot loop and we must align the state // indices with the input indices + float state_decode_scale = 1.f; + if constexpr (scaleState) { + if (state_batch != params.pad_slot_id) state_decode_scale = toFloat(sram.state_scale[dd]); + } for (int ii = 0; ii < stateValuesPerThread; ii++) { int i = lane * packed_input_t::count + (ii / packed_input_t::count) * warpSize * packed_input_t::count + (ii % packed_input_t::count); - rState[ii] = - (state_batch != params.pad_slot_id && i < DSTATE) ? toFloat(sram.state[dd][i]) : 0.f; + rState[ii] = (state_batch != params.pad_slot_id && i < DSTATE) + ? toFloat(sram.state[dd][i]) * state_decode_scale + : 0.f; } for (int step = 0; step < TOKENS_MTP; step++) { - float x_value = toFloat(sram.x[step][d]); + float x_value = toFloat(sram.x[step][d - dim_offset]); float out_value = d_value * x_value * int(lane == 0); // first lane has the value // Compute dt value for this token @@ -209,32 +251,90 @@ __global__ void selective_state_update_kernel_simple_mtp(SelectiveStateMTPParams out_value += new_state * C_value; } - if constexpr (sizeof(state_t) == sizeof(input_t)) { - if (intermediate_states) { - using packed_state_t = PackedAligned; - packed_state_t rStateOut; + // Store intermediate state to smem (non-scaleState path) + if constexpr (!scaleState) { + if constexpr (sizeof(state_t) == sizeof(input_t)) { + if (intermediate_states) { + using packed_state_t = PackedAligned; + packed_state_t rStateOut; + // Philox-4x32 produces 4 random ints per call; amortize across packed elements. 
+ [[maybe_unused]] uint32_t rand_ints[4]; #pragma unroll - for (int k = 0; k < packed_input_t::count; k++) { - convertAndStore(&rStateOut.val[k], rState[ii + k]); + for (int k = 0; k < packed_input_t::count; k++) { + if constexpr (PHILOX_ROUNDS > 0) { + // SR only applies to fp16 state, so packed count is always >= 2. + static_assert(packed_input_t::count >= 2, + "Stochastic rounding requires fp16 state (packed count >= 2)"); + if (k % 4 == 0) + philox_randint4x( + rand_seed, state_ptr_offset + d * DSTATE + base_i + k, rand_ints[0], + rand_ints[1], rand_ints[2], rand_ints[3]); + rStateOut.val[k] = cvt_rs_f16_f32(rState[ii + k], rand_ints[k % 4] & 0x1FFFu); + } else { + convertAndStore(&rStateOut.val[k], rState[ii + k]); + } + } + *reinterpret_cast(&sram.state[dd][base_i]) = rStateOut; } - *reinterpret_cast(&sram.state[dd][base_i]) = rStateOut; - } - } else { - if (intermediate_states) { + } else { + if (intermediate_states) { + // Philox-4x32 produces 4 random ints per call; amortize across packed elements. 
+ [[maybe_unused]] uint32_t rand_ints[4]; #pragma unroll - for (int k = 0; k < packed_input_t::count; k++) { - convertAndStore(&sram.state[dd][base_i + k], rState[ii + k]); + for (int k = 0; k < packed_input_t::count; k++) { + if constexpr (PHILOX_ROUNDS > 0) { + if (k % 4 == 0) + philox_randint4x( + rand_seed, state_ptr_offset + d * DSTATE + base_i + k, rand_ints[0], + rand_ints[1], rand_ints[2], rand_ints[3]); + sram.state[dd][base_i + k] = + cvt_rs_f16_f32(rState[ii + k], rand_ints[k % 4] & 0x1FFFu); + } else { + convertAndStore(&sram.state[dd][base_i + k], rState[ii + k]); + } + } + } + } + } + } + + // For scaleState + intermediate_states: quantize rState → sram.state with block scaling + if constexpr (scaleState) { + if (intermediate_states && state_batch != params.pad_slot_id) { + // 2-pass: compute max, then encode + float istate_max = std::numeric_limits::lowest(); + for (int ii = 0; ii < stateValuesPerThread; ii++) { + istate_max = fmaxf(istate_max, fabsf(rState[ii])); + } + istate_max = warpReduceMax(istate_max); + istate_max = __shfl_sync(UINT32_MAX, istate_max, 0); + float const ie_scale = + (istate_max == 0.f) + ? 
1.f + : static_cast(std::numeric_limits::max()) / istate_max; + float const id_scale = 1.f / ie_scale; + + // Encode rState → sram.state + for (int ii = 0; ii < stateValuesPerThread; ii++) { + int i = lane * packed_input_t::count + + (ii / packed_input_t::count) * warpSize * packed_input_t::count + + (ii % packed_input_t::count); + if (i < DSTATE) { + convertAndStore(&sram.state[dd][i], rState[ii] * ie_scale); } } + // Store decode scale to smem for later gmem write + if (lane == 0) sram.state_scale[dd] = id_scale; } } out_value = warpReduceSum(out_value); if (lane == 0) { - sram.out[step][d] = out_value; + sram.out[step][d - dim_offset] = out_value; } if (intermediate_states && state_batch != params.pad_slot_id) { + // Write intermediate state smem → gmem for (int i = lane * load_state_t::count; i < DSTATE; i += warpSize * load_state_t::count) { auto* src = reinterpret_cast(&sram.state[dd][i]); @@ -245,18 +345,67 @@ __global__ void selective_state_update_kernel_simple_mtp(SelectiveStateMTPParams d * DSTATE + i]); *dst = *src; } + // Write intermediate state decode scale → gmem + if constexpr (scaleState) { + if (lane == 0) { + auto* iscales = reinterpret_cast(params.intermediate_state_scales); + iscales[intermediate_cache_idx * params.intermediate_state_scales_stride_batch + + step * nheads * DIM + head * DIM + d] = sram.state_scale[dd]; + } + } } } // Update state if enabled and not padded if (params.update_state && state_batch != params.pad_slot_id) { - // Store to rmem -> smem - for (int ii = 0; ii < stateValuesPerThread; ii++) { - int i = lane * packed_input_t::count + - (ii / packed_input_t::count) * warpSize * packed_input_t::count + - (ii % packed_input_t::count); - if (i < DSTATE) { - convertAndStore(&sram.state[dd][i], rState[ii]); + // When intermediate_states is enabled, sram.state[dd] already holds the + // stochastically-rounded (or scaled) state from the last token step's intermediate write. 
+ // Skip the redundant Philox PRNG / re-quantization and write directly to gmem. + if (!intermediate_states) { + if constexpr (scaleState) { + // 2-pass quantization: compute max, then re-encode + float new_state_max = std::numeric_limits::lowest(); + for (int ii = 0; ii < stateValuesPerThread; ii++) { + new_state_max = fmaxf(new_state_max, fabsf(rState[ii])); + } + new_state_max = warpReduceMax(new_state_max); + new_state_max = __shfl_sync(UINT32_MAX, new_state_max, 0); + float const new_encode_scale = + (new_state_max == 0.f) + ? 1.f + : static_cast(std::numeric_limits::max()) / new_state_max; + float const new_decode_scale = 1.f / new_encode_scale; + + // Re-encode state values and store to smem + for (int ii = 0; ii < stateValuesPerThread; ii++) { + int i = lane * packed_input_t::count + + (ii / packed_input_t::count) * warpSize * packed_input_t::count + + (ii % packed_input_t::count); + if (i < DSTATE) { + convertAndStore(&sram.state[dd][i], rState[ii] * new_encode_scale); + } + } + if (lane == 0) convertAndStore(&sram.state_scale[dd], new_decode_scale); + } else { + // Store to rmem -> smem + // Philox-4x32 produces 4 random ints per call; amortize across consecutive elements. 
+ [[maybe_unused]] uint32_t rand_ints[4]; + for (int ii = 0; ii < stateValuesPerThread; ii++) { + int i = lane * packed_input_t::count + + (ii / packed_input_t::count) * warpSize * packed_input_t::count + + (ii % packed_input_t::count); + if (i < DSTATE) { + if constexpr (PHILOX_ROUNDS > 0) { + if (ii % 4 == 0) + philox_randint4x(rand_seed, state_ptr_offset + d * DSTATE + i, + rand_ints[0], rand_ints[1], rand_ints[2], + rand_ints[3]); + sram.state[dd][i] = cvt_rs_f16_f32(rState[ii], rand_ints[ii % 4] & 0x1FFFu); + } else { + convertAndStore(&sram.state[dd][i], rState[ii]); + } + } + } } } // store smem -> gmem @@ -266,30 +415,51 @@ __global__ void selective_state_update_kernel_simple_mtp(SelectiveStateMTPParams } } } + // Store state_scale smem -> gmem (contiguous across warpRows) + if constexpr (scaleState) { + if (params.update_state && state_batch != params.pad_slot_id) { + for (int warpRow = lane; warpRow < stateRowsPerWarpPerStage; warpRow += warpSize) { + auto dd = warp * stateRowsPerWarpPerStage + warpRow; + auto d = dim_offset + dBegin + dd; + if (d < DIM) { + state_scale[d] = sram.state_scale[dd]; + } + } + } + } } __syncthreads(); for (auto step = warp; step < TOKENS_MTP; step += numWarps) { - for (auto d = lane; d < DIM; d += warpSize) { - auto out_value = sram.out[step][d]; - if (z) { - float z_value = toFloat(sram.z[step][d]); - float sig_z = __fdividef(1.f, (1.f + __expf(0.f - z_value))); - float silu_z = z_value * sig_z; - out_value *= silu_z; + for (auto d = lane; d < ROWS_PER_BLOCK; d += warpSize) { + if (dim_offset + d < DIM) { + auto out_value = sram.out[step][d]; + if (z) { + float z_value = toFloat(sram.z[step][d]); + float sig_z = __fdividef(1.f, (1.f + __expf(0.f - z_value))); + float silu_z = z_value * sig_z; + out_value *= silu_z; + } + auto* dst = reinterpret_cast( + &output[batch * params.out_stride_batch + step * params.out_stride_mtp + head * DIM + + dim_offset + d]); + convertAndStore(dst, out_value); } - auto* dst = reinterpret_cast( 
- &output[batch * params.out_stride_batch + step * params.out_stride_mtp + head * DIM + d]); - convertAndStore(dst, out_value); } } } template + typename stateIndex_t, typename state_scale_t> void invokeSelectiveStateUpdateMTP(SelectiveStateMTPParams& params, SSUAlgorithm algorithm, cudaStream_t stream) { + constexpr bool scaleState = !std::is_same_v; + // Stochastic rounding is only implemented for fp16 state + if constexpr (PHILOX_ROUNDS > 0) { + static_assert(std::is_same_v, + "Stochastic rounding (PHILOX_ROUNDS > 0) only supports fp16 state"); + } // MTP only supports the simple kernel FLASHINFER_CHECK(algorithm == SSUAlgorithm::kAuto || algorithm == SSUAlgorithm::kSimple, "MTP selective_state_update only supports 'auto' or 'simple' algorithm, got ", @@ -300,28 +470,45 @@ void invokeSelectiveStateUpdateMTP(SelectiveStateMTPParams& params, SSUAlgorithm constexpr auto stateLoadSize = getVectorLoadSizeForFullUtilization(); using load_state_t = PackedAligned; - FLASHINFER_CHECK(reinterpret_cast(params.state) % sizeof(load_state_t) == 0, - "state pointer must be aligned to ", sizeof(load_state_t), " bytes"); + FLASHINFER_CHECK_ALIGNMENT(params.state, sizeof(load_state_t)); FLASHINFER_CHECK((params.dim * params.dstate * sizeof(state_t)) % sizeof(load_state_t) == 0, "state head stride must be aligned to ", sizeof(load_state_t), " bytes"); constexpr int numWarps = 4; constexpr int stateRowsPerWarpPerStage = 4; - constexpr int stageRows = stateRowsPerWarpPerStage * numWarps; + constexpr int stateRowsPerBlockPerStage = stateRowsPerWarpPerStage * numWarps; + int const total_tiles = params.batch * params.nheads; + int const num_sms = GetCudaMultiProcessorCount(); dim3 block(warpSize, numWarps); - dim3 grid(params.batch, params.nheads); - - auto func = - selective_state_update_kernel_simple_mtp; - using sram_t = SharedStorageSimple; - constexpr size_t smem_size = sizeof(sram_t); - - FLASHINFER_CUDA_CHECK( - cudaFuncSetAttribute(func, 
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - - func<<>>(params); + if (total_tiles < num_sms * 2) { + // Small tile per CTA (stateRowsPerBlockPerStage * DSTATE): split dim across grid.z for GPU + // occupancy + int const dim_tiles = (DIM + stateRowsPerBlockPerStage - 1) / stateRowsPerBlockPerStage; + dim3 grid(params.batch, params.nheads, dim_tiles); + auto func = selective_state_update_kernel_simple_mtp< + input_t, weight_t, matrixA_t, state_t, stateIndex_t, state_scale_t, NTOKENS_MTP, DIM, + DSTATE, stateRowsPerBlockPerStage, PHILOX_ROUNDS, numWarps>; + using sram_t = + SharedStorageSimple; + constexpr size_t smem_size = sizeof(sram_t); + FLASHINFER_CUDA_CHECK( + cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + func<<>>(params); + } else { + // Full tile per CTA (DIM * DSTATE): enough blocks for occupancy, no dim splitting needed + dim3 grid(params.batch, params.nheads); + auto func = selective_state_update_kernel_simple_mtp; + using sram_t = SharedStorageSimple; + constexpr size_t smem_size = sizeof(sram_t); + FLASHINFER_CUDA_CHECK( + cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + func<<>>(params); + } } } // namespace flashinfer::mamba::mtp diff --git a/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh b/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh index 58d43ab2af..a9efc232ee 100644 --- a/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh +++ b/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh @@ -22,6 +22,7 @@ #include #include +#include #include "../utils.cuh" #include "../vec_dtypes.cuh" @@ -57,13 +58,16 @@ __device__ __forceinline__ int conflict_free_column(int group, int baseCol) { return (baseCol + stateValuesPerBank * bankCycle) % colsPerStage; } -template +template struct SharedStorageSimple { + static constexpr bool scaleState = !std::is_same_v; alignas(alignof(PackedAligned)) input_t x[rows_per_block]; 
alignas(alignof(PackedAligned)) input_t z[rows_per_block]; alignas(alignof(PackedAligned)) input_t B[dstate]; alignas(alignof(PackedAligned)) input_t C[dstate]; float out[rows_per_block]; + alignas(alignof(PackedAligned)) + std::conditional_t state_scale[rows_per_block]; }; // Grid: (batch, nheads, cdiv(DIM, ROWS_PER_BLOCK)) @@ -71,8 +75,10 @@ struct SharedStorageSimple { // Used when batch*nheads is too small to saturate the GPU: set ROWS_PER_BLOCK < DIM to // split dim across blocks for better occupancy. template + typename stateIndex_t, typename state_scale_t, int DIM, int DSTATE, int ROWS_PER_BLOCK, + int numWarps, int PHILOX_ROUNDS> __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams params) { + constexpr bool scaleState = !std::is_same_v; auto* __restrict__ output = reinterpret_cast(params.output); auto* __restrict__ state = reinterpret_cast(params.state); @@ -88,6 +94,13 @@ __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams reinterpret_cast(params.state_batch_indices); bool const dt_softplus = params.dt_softplus; + // State scale pointer (only used when scaleState == true) + [[maybe_unused]] auto* __restrict__ state_scale = + reinterpret_cast(params.state_scale); + + // Load device-side Philox seed once into a register + [[maybe_unused]] int64_t const rand_seed = params.rand_seed ? *params.rand_seed : 0; + int const nheads = params.nheads; int const ngroups = params.ngroups; @@ -101,9 +114,13 @@ __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams auto warp = threadIdx.y; auto const state_batch = (state_batch_indices) ? 
state_batch_indices[batch] : batch; - state += state_batch * params.state_stride_batch + head * DIM * DSTATE; + auto const state_ptr_offset = state_batch * params.state_stride_batch + head * DIM * DSTATE; + state += state_ptr_offset; + if constexpr (scaleState) { + state_scale += state_batch * params.state_scale_stride_batch + head * DIM; + } - __shared__ SharedStorageSimple sram; + __shared__ SharedStorageSimple sram; static constexpr auto stateLoadSize = getVectorLoadSizeForFullUtilization(); using load_state_t = PackedAligned; @@ -127,16 +144,28 @@ __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams if (dim_offset + d < DIM) sram.x[d] = x[batch * params.x_stride_batch + head * DIM + dim_offset + d]; } + if constexpr (scaleState) { + for (auto d = lane; d < ROWS_PER_BLOCK; d += warpSize) { + if (dim_offset + d < DIM) sram.state_scale[d] = state_scale[dim_offset + d]; + } + } + } else if (warp == 1) { for (auto i = lane * load_input_t::count; i < DSTATE; i += warpSize * load_input_t::count) { auto* dst = reinterpret_cast(&sram.B[i]); *dst = *reinterpret_cast( &B[batch * params.B_stride_batch + group * DSTATE + i]); } - } else if (warp == 1) { - for (auto d = lane; d < ROWS_PER_BLOCK; d += warpSize) { - if (dim_offset + d < DIM) - sram.z[d] = z ? 
z[batch * params.z_stride_batch + head * DIM + dim_offset + d] : input_t(0); - } + } else if (warp == 2) { + if (z) + for (auto d = lane; d < ROWS_PER_BLOCK; d += warpSize) { + if (dim_offset + d < DIM) + sram.z[d] = z[batch * params.z_stride_batch + head * DIM + dim_offset + d]; + } + else + for (auto d = lane; d < ROWS_PER_BLOCK; d += warpSize) { + if (dim_offset + d < DIM) sram.z[d] = input_t(0); + } + } else if (warp == 3) { for (auto i = lane * load_input_t::count; i < DSTATE; i += warpSize * load_input_t::count) { auto* dst = reinterpret_cast(&sram.C[i]); *dst = *reinterpret_cast( @@ -150,33 +179,87 @@ __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams if (d >= DIM) break; float x_value = toFloat(sram.x[_d]); + float state_decode_scale = 1.f; + [[maybe_unused]] float new_state_max = std::numeric_limits::lowest(); + if constexpr (scaleState) { + state_decode_scale = toFloat(sram.state_scale[_d]); + } float out_value = d_value * x_value * int(lane == 0); - for (int i = lane * load_state_t::count; i < DSTATE; i += warpSize * load_state_t::count) { + // When scaleState, keep new_state values in registers to avoid re-reading gmem + // and recomputing in the quantization pass. Each thread covers DSTATE/warpSize elements. + [[maybe_unused]] float rNewState[scaleState ? DSTATE / warpSize : 1]; + + // update the out value and compute the max state and state sum. + // Philox-4x32 produces 4 random ints per call; reuse across up to 4 consecutive elements. + // Refresh every time the ii-within-outer-loop crosses a multiple-of-4 boundary. + // Works for any count (1, 2, 4, 8, ...): count<=4 refreshes once per outer iter (or less), + // count>4 (e.g. count=8 for bf16+DSTATE=256) refreshes count/4 times per outer iter. 
+ [[maybe_unused]] uint32_t rand_ints[4]; + for (int iter = 0, i = lane * load_state_t::count; i < DSTATE; + iter++, i += warpSize * load_state_t::count) { auto rState = make_zeros(); if (state_batch != params.pad_slot_id) rState = *reinterpret_cast(&state[d * DSTATE + i]); for (int ii = 0; ii < load_state_t::count; ii++) { - auto state_value = toFloat(rState.val[ii]); + if constexpr (PHILOX_ROUNDS > 0 && !scaleState) { + if (ii % 4 == 0) + philox_randint4x(rand_seed, state_ptr_offset + d * DSTATE + i + ii, + rand_ints[0], rand_ints[1], rand_ints[2], rand_ints[3]); + } + + auto state_value = toFloat(rState.val[ii]) * state_decode_scale; auto B_value = toFloat(sram.B[i + ii]); auto C_value = toFloat(sram.C[i + ii]); auto const dB = B_value * dt_value; auto const new_state = state_value * dA + dB * x_value; - - convertAndStore(&rState.val[ii], new_state); + if constexpr (scaleState) { + new_state_max = fmaxf(new_state_max, fabsf(new_state)); + rNewState[iter * load_state_t::count + ii] = new_state; + } else if constexpr (PHILOX_ROUNDS > 0) { + rState.val[ii] = cvt_rs_f16_f32(new_state, rand_ints[ii % 4] & 0x1FFFu); + } else { + convertAndStore(&rState.val[ii], new_state); + } out_value += new_state * C_value; } - if (params.update_state && state_batch != params.pad_slot_id) + if (!scaleState && params.update_state && state_batch != params.pad_slot_id) { *reinterpret_cast(&state[d * DSTATE + i]) = rState; + } } - out_value = warpReduceSum(out_value); if (lane == 0) { sram.out[_d] = out_value; } + + if constexpr (scaleState) { + if (params.update_state && state_batch != params.pad_slot_id) { + new_state_max = warpReduceMax(new_state_max); + new_state_max = __shfl_sync(UINT32_MAX, new_state_max, 0); // broadcast to all lanes + float const new_state_encode_scale = + (new_state_max == 0.f) + ? 
1.f + : static_cast(std::numeric_limits::max()) / new_state_max; + float const new_state_decode_scale = 1.f / new_state_encode_scale; + + for (int iter = 0, i = lane * load_state_t::count; i < DSTATE; + iter++, i += warpSize * load_state_t::count) { + auto rState = make_zeros(); + for (int ii = 0; ii < load_state_t::count; ii++) { + convertAndStore(&rState.val[ii], + rNewState[iter * load_state_t::count + ii] * new_state_encode_scale); + } + *reinterpret_cast(&state[d * DSTATE + i]) = rState; + } + + if (lane == 0) { + convertAndStore(&sram.state_scale[_d], new_state_decode_scale); + } + } + } } __syncthreads(); @@ -195,17 +278,30 @@ __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams convertAndStore(&output[batch * params.out_stride_batch + head * DIM + d], out_value); } } + if constexpr (scaleState) { + if (params.update_state && state_batch != params.pad_slot_id) { + for (int l = lane; l < rowsPerWarp; l += warpSize) { + auto _d = warp * rowsPerWarp + l; + auto d = dim_offset + _d; + if (d < DIM) { + state_scale[d] = sram.state_scale[_d]; + } + } + } + } } template + typename state_scale_t, int rowsPerStage, int dim, int dstate, uint8_t numStages> struct SharedStorageVertical { + static constexpr bool scaleState = !std::is_same_v; alignas(128) state_t state[numStages][rowsPerStage * dstate]; alignas(alignof(PackedAligned)) input_t x[dim]; alignas(alignof(PackedAligned)) input_t z[dim]; alignas(alignof(PackedAligned)) input_t B[dstate]; alignas(alignof(PackedAligned)) input_t C[dstate]; - float out[dim]; // dt is special cause we're gonna store input in there as well + float out[dim]; + alignas(128) std::conditional_t state_scale[dim * scaleState]; using barrier_t = cuda::barrier; barrier_t bar_empty[numStages]; @@ -213,14 +309,14 @@ struct SharedStorageVertical { barrier_t bar_consumers; }; -template -__device__ __forceinline__ void producer_func_vertical(SramT& sram, CUtensorMap const& tensorState, - input_t const* x_global_ptr, - 
input_t const* B_global_ptr, - input_t const* C_global_ptr, - input_t const* z_global_ptr, int batch, - int head) { +template +__device__ __forceinline__ void producer_func_vertical( + SramT& sram, CUtensorMap const& tensorState, input_t const* x_global_ptr, + input_t const* B_global_ptr, input_t const* C_global_ptr, input_t const* z_global_ptr, + [[maybe_unused]] void const* state_scale_ptr, int batch, int head) { + constexpr bool scaleState = !std::is_same_v; #ifdef FLASHINFER_MAMBA_ENABLE_SM90 namespace cde = cuda::device::experimental; @@ -233,7 +329,13 @@ __device__ __forceinline__ void producer_func_vertical(SramT& sram, CUtensorMap auto constexpr bytesB = DSTATE * sizeof(input_t); auto constexpr bytesC = DSTATE * sizeof(input_t); auto constexpr bytesZ = hasZ ? DIM * sizeof(input_t) : 0; - auto constexpr bytesInputs = bytesX + bytesB + bytesC + bytesZ; + auto constexpr bytesStateScale = []() constexpr { + if constexpr (scaleState) + return DIM * sizeof(state_scale_t); + else + return size_t(0); + }(); + auto constexpr bytesInputs = bytesX + bytesB + bytesC + bytesZ + bytesStateScale; // Phase 1, iter 0: fire all input vector loads + state load (if readState) // All inputs piggyback onto bar_full[0] so consumers get them before stage 0 @@ -253,6 +355,11 @@ __device__ __forceinline__ void producer_func_vertical(SramT& sram, CUtensorMap cuda::device::memcpy_async_tx(&sram.z[0], z_global_ptr, cuda::aligned_size_t<16>(bytesZ), sram.bar_full[stage]); } + if constexpr (scaleState) { + cuda::device::memcpy_async_tx( + &sram.state_scale[0], static_cast(state_scale_ptr), + cuda::aligned_size_t<16>(bytesStateScale), sram.bar_full[stage]); + } if constexpr (readState) { cde::cp_async_bulk_tensor_4d_global_to_shared(&sram.state[stage][0], &tensorState, 0, d, head, @@ -336,12 +443,15 @@ __device__ __forceinline__ void producer_func_vertical(SramT& sram, CUtensorMap #endif } -template +template __device__ __forceinline__ void consumer_func_vertical( int lane, int warp, 
float d_value, float dt_value, float dA, - SharedStorageVertical& sram) { + SharedStorageVertical& sram, + int64_t rand_seed, [[maybe_unused]] int64_t state_ptr_offset) { + constexpr bool scaleState = !std::is_same_v; #ifdef FLASHINFER_MAMBA_ENABLE_SM90 namespace cde = cuda::device::experimental; for (auto dBegin = 0, stage = 0; dBegin < DIM; @@ -358,8 +468,22 @@ __device__ __forceinline__ void consumer_func_vertical( constexpr auto bankSize = sizeof(uint32_t); constexpr auto stateValuesPerBank = bankSize / sizeof(state_t); + [[maybe_unused]] float state_decode_scale = 1.f; + [[maybe_unused]] float new_state_max = std::numeric_limits::lowest(); + if constexpr (scaleState) { + state_decode_scale = toFloat(sram.state_scale[d]); + } + // Register buffer for 2-pass quantization (min size 1 to avoid zero-sized array in device + // code) + [[maybe_unused]] float rNewState[scaleState ? DSTATE / warpSize : 1]; + + // Philox-4x32 produces 4 random ints per call; reuse across up to 4 consecutive elements. + // Refresh when e % 4 == 0. stateValuesPerBank is constexpr so all branches compile away. 
+ [[maybe_unused]] uint32_t rand_ints[4]; if constexpr (sizeof(state_t) == sizeof(input_t)) { - for (int i = lane * stateValuesPerBank; i < DSTATE; i += warpSize * stateValuesPerBank) { + for (int iter = 0, i = lane * stateValuesPerBank; i < DSTATE; + iter++, i += warpSize * stateValuesPerBank) { + // load a bank-worth of states auto* sState_ptr = reinterpret_cast(&sram.state[stage][dd * DSTATE + i]); uint32_t rState = *sState_ptr; auto* rState_ptr = reinterpret_cast(&rState); @@ -371,11 +495,21 @@ __device__ __forceinline__ void consumer_func_vertical( auto* rC_ptr = reinterpret_cast(&rC); for (int e = 0; e < stateValuesPerBank; e++) { + if constexpr (PHILOX_ROUNDS > 0 && !scaleState) { + if (e % 4 == 0) + philox_randint4x(rand_seed, state_ptr_offset + d * DSTATE + i + e, + rand_ints[0], rand_ints[1], rand_ints[2], + rand_ints[3]); + } + float state_value; if constexpr (!useStateCache) { state_value = 0.f; } else { state_value = toFloat(rState_ptr[e]); + if constexpr (scaleState) { + state_value *= state_decode_scale; + } } auto const B_value = toFloat(rB_ptr[e]); auto const C_value = toFloat(rC_ptr[e]); @@ -383,33 +517,62 @@ __device__ __forceinline__ void consumer_func_vertical( auto const dB = B_value * dt_value; auto const new_state = state_value * dA + dB * x_value; - convertAndStore(&rState_ptr[e], new_state); + if constexpr (scaleState) { + new_state_max = fmaxf(new_state_max, fabsf(new_state)); + rNewState[iter * stateValuesPerBank + e] = new_state; + } else if constexpr (PHILOX_ROUNDS > 0) { + rState_ptr[e] = cvt_rs_f16_f32(new_state, rand_ints[e % 4] & 0x1FFFu); + } else { + convertAndStore(&rState_ptr[e], new_state); + } out_value += new_state * C_value; } - *sState_ptr = rState; + if constexpr (!scaleState) { + *sState_ptr = rState; + } } } else { - for (int i = lane * stateValuesPerBank; i < DSTATE; i += warpSize * stateValuesPerBank) { + for (int iter = 0, i = lane * stateValuesPerBank; i < DSTATE; + iter++, i += warpSize * stateValuesPerBank) { 
auto* sState_ptr = reinterpret_cast(&sram.state[stage][dd * DSTATE + i]); uint32_t rState = *sState_ptr; auto* rState_ptr = reinterpret_cast(&rState); for (int e = 0; e < stateValuesPerBank; e++) { + if constexpr (PHILOX_ROUNDS > 0 && !scaleState) { + if (e % 4 == 0) + philox_randint4x(rand_seed, state_ptr_offset + d * DSTATE + i + e, + rand_ints[0], rand_ints[1], rand_ints[2], + rand_ints[3]); + } + float state_value; if constexpr (!useStateCache) { state_value = 0.f; } else { state_value = toFloat(rState_ptr[e]); + if constexpr (scaleState) { + state_value *= state_decode_scale; + } } auto const B_value = toFloat(sram.B[i + e]); auto const C_value = toFloat(sram.C[i + e]); auto const dB = B_value * dt_value; auto const new_state = state_value * dA + dB * x_value; - convertAndStore(&rState_ptr[e], new_state); + if constexpr (scaleState) { + new_state_max = fmaxf(new_state_max, fabsf(new_state)); + rNewState[iter * stateValuesPerBank + e] = new_state; + } else if constexpr (PHILOX_ROUNDS > 0) { + rState_ptr[e] = cvt_rs_f16_f32(new_state, rand_ints[e % 4] & 0x1FFFu); + } else { + convertAndStore(&rState_ptr[e], new_state); + } out_value += new_state * C_value; } - *sState_ptr = rState; + if constexpr (!scaleState) { + *sState_ptr = rState; + } } } @@ -417,6 +580,33 @@ __device__ __forceinline__ void consumer_func_vertical( if (lane == 0) { sram.out[d] = out_value; } + + // 2nd pass: quantize new state values and write back to sram for TMA writeback + if constexpr (scaleState && useStateCache) { + new_state_max = warpReduceMax(new_state_max); + new_state_max = __shfl_sync(UINT32_MAX, new_state_max, 0); + float const new_encode_scale = + (new_state_max == 0.f) + ? 
1.f + : static_cast(std::numeric_limits::max()) / new_state_max; + float const new_decode_scale = 1.f / new_encode_scale; + + for (int iter = 0, i = lane * stateValuesPerBank; i < DSTATE; + iter++, i += warpSize * stateValuesPerBank) { + auto* sState_ptr = reinterpret_cast(&sram.state[stage][dd * DSTATE + i]); + uint32_t rState; + auto* rState_ptr = reinterpret_cast(&rState); + for (int e = 0; e < stateValuesPerBank; e++) { + convertAndStore(&rState_ptr[e], + rNewState[iter * stateValuesPerBank + e] * new_encode_scale); + } + *sState_ptr = rState; + } + + if (lane == 0) { + convertAndStore(&sram.state_scale[d], new_decode_scale); + } + } } // Unblock producer @@ -427,10 +617,11 @@ __device__ __forceinline__ void consumer_func_vertical( } template + typename stateIndex_t, typename state_scale_t, int DIM, int DSTATE, int PHILOX_ROUNDS, + int consumerWarps = 1, int rowsPerStage = 4, int numStages = 1> __global__ void selective_state_update_kernel_producer_consumer_vertical( SelectiveStateUpdateParams params, __grid_constant__ CUtensorMap const tensorState) { + constexpr bool scaleState = !std::is_same_v; #ifdef FLASHINFER_MAMBA_ENABLE_SM90 auto* __restrict__ output = reinterpret_cast(params.output); @@ -444,6 +635,11 @@ __global__ void selective_state_update_kernel_producer_consumer_vertical( auto const* __restrict__ z = reinterpret_cast(params.z); auto const* __restrict__ state_batch_indices = reinterpret_cast(params.state_batch_indices); + [[maybe_unused]] auto* __restrict__ state_scale = + reinterpret_cast(params.state_scale); + + // Load device-side Philox seed once into a register + [[maybe_unused]] int64_t const rand_seed = params.rand_seed ? *params.rand_seed : 0; int const nheads = params.nheads; int const ngroups = params.ngroups; @@ -457,10 +653,12 @@ __global__ void selective_state_update_kernel_producer_consumer_vertical( auto warp = threadIdx.y; auto const state_batch = (state_batch_indices) ? 
__ldg(&state_batch_indices[batch]) : batch; + auto const state_ptr_offset = + static_cast(state_batch) * params.state_stride_batch + head * DIM * DSTATE; extern __shared__ uint8_t sbuffer[]; - using sram_t = SharedStorageVertical; + using sram_t = SharedStorageVertical; auto& sram = *reinterpret_cast(sbuffer); namespace cde = cuda::device::experimental; @@ -490,11 +688,15 @@ __global__ void selective_state_update_kernel_producer_consumer_vertical( auto const* B_global_ptr = &B[batch * params.B_stride_batch + group * DSTATE]; auto const* C_global_ptr = &C[batch * params.C_stride_batch + group * DSTATE]; auto const* z_global_ptr = z ? &z[batch * params.z_stride_batch + head * DIM] : nullptr; + [[maybe_unused]] void const* state_scale_ptr = nullptr; + if constexpr (scaleState) { + state_scale_ptr = &state_scale[state_batch * params.state_scale_stride_batch + head * DIM]; + } auto const call = [&]() { - producer_func_vertical(sram, tensorState, x_global_ptr, B_global_ptr, - C_global_ptr, hasZ ? z_global_ptr : nullptr, - state_batch, head); + producer_func_vertical( + sram, tensorState, x_global_ptr, B_global_ptr, C_global_ptr, + hasZ ? 
z_global_ptr : nullptr, state_scale_ptr, state_batch, head); }; auto const dispatch_state = [&]() { if (read_state && write_state) @@ -534,13 +736,13 @@ __global__ void selective_state_update_kernel_producer_consumer_vertical( auto const dA = __expf(A_value * dt_value); if (state_batch != params.pad_slot_id) - consumer_func_vertical(lane, warp, d_value, dt_value, dA, - sram); + consumer_func_vertical( + lane, warp, d_value, dt_value, dA, sram, rand_seed, state_ptr_offset); else - consumer_func_vertical(lane, warp, d_value, dt_value, dA, - sram); + consumer_func_vertical( + lane, warp, d_value, dt_value, dA, sram, rand_seed, state_ptr_offset); // Write output — wait for all consumer warps to finish writing sram.out sram.bar_consumers.wait(sram.bar_consumers.arrive()); @@ -555,6 +757,14 @@ __global__ void selective_state_update_kernel_producer_consumer_vertical( } convertAndStore(&output[batch * params.out_stride_batch + head * DIM + d], out_value); } + if constexpr (scaleState) { + if (params.update_state && state_batch != params.pad_slot_id) { + if (d < DIM) { + state_scale[state_batch * params.state_scale_stride_batch + head * DIM + d] = + sram.state_scale[d]; + } + } + } } #endif } @@ -659,12 +869,13 @@ __device__ __forceinline__ void producer_func_horizontal(SramT& sram, } template + int DSTATE, int PHILOX_ROUNDS, int consumerWarps, int colsPerStage, int numStages, + bool useStateCache> __device__ __forceinline__ void consumer_func_horizontal( int d, int member, float A_value, float dt_value, float x_value, SharedStorageHorizontal& sram, - float& out_value) { + float& out_value, int64_t rand_seed, [[maybe_unused]] int64_t state_ptr_offset) { namespace cde = cuda::device::experimental; constexpr auto lanesPerRow = (consumerWarps * warpSize) / DIM; constexpr auto itemsPerThread = colsPerStage / lanesPerRow; @@ -679,12 +890,15 @@ __device__ __forceinline__ void consumer_func_horizontal( constexpr auto bankSize = sizeof(uint32_t); constexpr auto stateValuesPerBank = 
bankSize / sizeof(state_t); constexpr auto numBanks = 32; + // Philox-4x32 produces 4 random ints per call; reuse across up to 4 consecutive elements. + // flat_e tracks position across outer+inner loops; refresh every 4 elements. + // Loop is fully unrolled (#pragma unroll), so the modulo and branch compile away. + [[maybe_unused]] uint32_t rand_ints[4]; if constexpr (sizeof(state_t) == sizeof(input_t)) { #pragma unroll for (int item = 0; item < itemsPerThread; item += stateValuesPerBank) { auto const baseCol = item + member * itemsPerThread; // If I just use baseCol as the index, a lot of bank conflicts will arise. - // auto const ii = conflict_free_column(group, baseCol); @@ -701,6 +915,14 @@ __device__ __forceinline__ void consumer_func_horizontal( auto* rC_ptr = reinterpret_cast(&rC); for (int e = 0; e < stateValuesPerBank; e++) { + int flat_e = item + e; + if constexpr (PHILOX_ROUNDS > 0) { + if (flat_e % 4 == 0) + philox_randint4x(rand_seed, state_ptr_offset + d * DSTATE + i + e, + rand_ints[0], rand_ints[1], rand_ints[2], + rand_ints[3]); + } + float state_value; if constexpr (!useStateCache) { state_value = 0.f; @@ -715,12 +937,18 @@ __device__ __forceinline__ void consumer_func_horizontal( auto const dB = B_value * dt_value; auto const new_state = state_value * dA + dB * x_value; - convertAndStore(&rState_ptr[e], new_state); + // TODO: when stateValuesPerBank == 2, could use cvt_rs_f16x2_f32 for both at once + if constexpr (PHILOX_ROUNDS > 0) { + rState_ptr[e] = cvt_rs_f16_f32(new_state, rand_ints[flat_e % 4] & 0x1FFFu); + } else { + convertAndStore(&rState_ptr[e], new_state); + } out_value += new_state * C_value; } *sState_ptr = rState; } } else { +#pragma unroll for (int item = 0; item < itemsPerThread; item += stateValuesPerBank) { auto const baseCol = item + member * itemsPerThread; auto const ii = @@ -732,6 +960,14 @@ __device__ __forceinline__ void consumer_func_horizontal( auto* rState_ptr = reinterpret_cast(&rState); for (int e = 0; e < 
stateValuesPerBank; e++) { + int flat_e = item + e; + if constexpr (PHILOX_ROUNDS > 0) { + if (flat_e % 4 == 0) + philox_randint4x(rand_seed, state_ptr_offset + d * DSTATE + i + e, + rand_ints[0], rand_ints[1], rand_ints[2], + rand_ints[3]); + } + float state_value; if constexpr (!useStateCache) { state_value = 0.f; @@ -746,7 +982,12 @@ __device__ __forceinline__ void consumer_func_horizontal( auto const dB = B_value * dt_value; auto const new_state = state_value * dA + dB * x_value; - convertAndStore(&rState_ptr[e], new_state); + // TODO: when stateValuesPerBank == 2, could use cvt_rs_f16x2_f32 for both at once + if constexpr (PHILOX_ROUNDS > 0) { + rState_ptr[e] = cvt_rs_f16_f32(new_state, rand_ints[flat_e % 4] & 0x1FFFu); + } else { + convertAndStore(&rState_ptr[e], new_state); + } out_value += new_state * C_value; } *sState_ptr = rState; @@ -758,8 +999,8 @@ __device__ __forceinline__ void consumer_func_horizontal( } template + typename stateIndex_t, int DIM, int DSTATE, int PHILOX_ROUNDS, int headsGroupsRatio, + int consumerWarps, int colsPerStage, int numStages = 1> __global__ void selective_state_update_kernel_producer_consumer_horizontal( SelectiveStateUpdateParams params, __grid_constant__ CUtensorMap const tensorState) { auto* __restrict__ output = reinterpret_cast(params.output); @@ -774,6 +1015,9 @@ __global__ void selective_state_update_kernel_producer_consumer_horizontal( auto const* __restrict__ state_batch_indices = reinterpret_cast(params.state_batch_indices); + // Load device-side Philox seed once into a register + [[maybe_unused]] int64_t const rand_seed = params.rand_seed ? *params.rand_seed : 0; + int const nheads = params.nheads; constexpr auto numWarps = 1 + consumerWarps; @@ -785,6 +1029,8 @@ __global__ void selective_state_update_kernel_producer_consumer_horizontal( auto warp = threadIdx.y; auto const state_batch = (state_batch_indices) ? 
state_batch_indices[batch] : batch; + auto const state_ptr_offset = + static_cast(state_batch) * params.state_stride_batch + head * DIM * DSTATE; extern __shared__ uint8_t sbuffer[]; using sram_t = SharedStorageHorizontal(d, member, A_value, dt_value, x_value, - sram, out_value); + consumer_func_horizontal( + d, member, A_value, dt_value, x_value, sram, out_value, rand_seed, state_ptr_offset); else - consumer_func_horizontal(d, member, A_value, dt_value, - x_value, sram, out_value); + consumer_func_horizontal( + d, member, A_value, dt_value, x_value, sram, out_value, rand_seed, state_ptr_offset); out_value += __shfl_down_sync(UINT32_MAX, out_value, 16); if constexpr (lanesPerRow == 4) { @@ -904,9 +1150,15 @@ __global__ void selective_state_update_kernel_producer_consumer_horizontal( #endif // FLASHINFER_MAMBA_ENABLE_SM90 (horizontal kernel) template + typename stateIndex_t, typename state_scale_t> void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, SSUAlgorithm algorithm, cudaStream_t stream) { + constexpr bool scaleState = !std::is_same_v; + // Stochastic rounding is only implemented for fp16 state + if constexpr (PHILOX_ROUNDS > 0) { + static_assert(std::is_same_v, + "Stochastic rounding (PHILOX_ROUNDS > 0) only supports fp16 state"); + } auto [sm_major, sm_minor] = GetCudaComputeCapability(); // Common alignment checks for all kernels @@ -929,6 +1181,9 @@ void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, SSUAlgorithm algo = SSUAlgorithm::kSimple; else if (sm_major < 10) algo = SSUAlgorithm::kVertical; + else if (scaleState) + // Horizontal kernel cannot do 2-pass quantization, so always use vertical + algo = SSUAlgorithm::kVertical; else // On Blackwell+: vertical is slightly faster for fp32 state, // horizontal is faster for fp16/bf16 state. 
@@ -943,8 +1198,7 @@ void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, SSUAlgorithm constexpr auto stateLoadSize = getVectorLoadSizeForFullUtilization(); using load_state_t = PackedAligned; - FLASHINFER_CHECK(reinterpret_cast(params.state) % sizeof(load_state_t) == 0, - "state pointer must be aligned to ", sizeof(load_state_t), " bytes"); + FLASHINFER_CHECK_ALIGNMENT(params.state, sizeof(load_state_t)); FLASHINFER_CHECK((params.dim * params.dstate * sizeof(state_t)) % sizeof(load_state_t) == 0, "state head stride must be aligned to ", sizeof(load_state_t), " bytes"); @@ -958,15 +1212,15 @@ void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, SSUAlgorithm // Tiled: split dim across blocks for better GPU occupancy at small batch sizes int const dim_tiles = (DIM + ROWS_PER_BLOCK - 1) / ROWS_PER_BLOCK; dim3 grid(params.batch, params.nheads, dim_tiles); - selective_state_update_kernel_simple - <<>>(params); + selective_state_update_kernel_simple<<>>(params); } else { // Non-tiled: enough blocks already for full occupancy; ROWS_PER_BLOCK == DIM so blockIdx.z == // 0 dim3 grid(params.batch, params.nheads); - selective_state_update_kernel_simple + selective_state_update_kernel_simple <<>>(params); } } @@ -978,9 +1232,17 @@ void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, SSUAlgorithm constexpr auto rowsPerStage = 4 * numConsumers; FLASHINFER_CHECK(params.dim % rowsPerStage == 0, "dim must be divisible by ", rowsPerStage, " for vertical kernel"); + + // TMA alignment checks for all pointers loaded via memcpy_async_tx + FLASHINFER_CHECK_TMA_ALIGNED(params.x); + FLASHINFER_CHECK_TMA_ALIGNED(params.B); + FLASHINFER_CHECK_TMA_ALIGNED(params.C); + if (params.z) FLASHINFER_CHECK_TMA_ALIGNED(params.z); + if constexpr (scaleState) FLASHINFER_CHECK_TMA_ALIGNED(params.state_scale); + auto scan_func = selective_state_update_kernel_producer_consumer_vertical< - input_t, weight_t, matrixA_t, state_t, stateIndex_t, DIM, DSTATE, 
numConsumers, - rowsPerStage, numStages>; + input_t, weight_t, matrixA_t, state_t, stateIndex_t, state_scale_t, DIM, DSTATE, + PHILOX_ROUNDS, numConsumers, rowsPerStage, numStages>; dim3 block(warpSize, numWarps); dim3 grid(params.batch, params.nheads); @@ -991,14 +1253,18 @@ void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, SSUAlgorithm /*strides*/ {1, DSTATE, DSTATE * DIM, params.state_stride_batch}, /*tiles*/ {DSTATE, rowsPerStage, 1, 1}, params.state); - using sram_t = SharedStorageVertical; + using sram_t = SharedStorageVertical; constexpr size_t smem_size = sizeof(sram_t); FLASHINFER_CUDA_CHECK( cudaFuncSetAttribute(scan_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); scan_func<<>>(params, state_tensor); } else if (algo == SSUAlgorithm::kHorizontal) { + FLASHINFER_CHECK( + !scaleState, + "Horizontal kernel does not support scaled state (int16). " + "Cannot do 2-pass quantization because dstate tiles are discarded after processing."); constexpr auto numConsumers = (DIM / 64) * 4; constexpr auto numProducers = 1; constexpr auto numWarps = numProducers + numConsumers; @@ -1011,8 +1277,8 @@ void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, SSUAlgorithm auto ratio_launcher = [&]() { auto scan_func = selective_state_update_kernel_producer_consumer_horizontal< - input_t, weight_t, matrixA_t, state_t, stateIndex_t, DIM, DSTATE, numConsumers, stageCols, - RATIO, numStages>; + input_t, weight_t, matrixA_t, state_t, stateIndex_t, DIM, DSTATE, PHILOX_ROUNDS, RATIO, + numConsumers, stageCols, numStages>; dim3 block(warpSize, numWarps); dim3 grid(params.batch, params.nheads); diff --git a/include/flashinfer/mamba/selective_state_update.cuh b/include/flashinfer/mamba/selective_state_update.cuh index 0607d7a0f2..83eda2712c 100644 --- a/include/flashinfer/mamba/selective_state_update.cuh +++ b/include/flashinfer/mamba/selective_state_update.cuh @@ -50,7 +50,7 @@ struct SelectiveStateUpdateParams { int32_t pad_slot_id{-1}; 
int64_t x_stride_batch{}, dt_stride_batch{}, B_stride_batch{}, C_stride_batch{}, - out_stride_batch{}, z_stride_batch{}, state_stride_batch{}; + out_stride_batch{}, z_stride_batch{}, state_stride_batch{}, state_scale_stride_batch{}; void* __restrict__ state{nullptr}; void* __restrict__ x{nullptr}; @@ -63,9 +63,16 @@ struct SelectiveStateUpdateParams { void* __restrict__ z{nullptr}; void* __restrict__ output{nullptr}; void* __restrict__ state_batch_indices{nullptr}; + // Block-scale decode factors for quantized state: float32 (state_cache_size, nheads, dim, 1) + void* __restrict__ state_scale{nullptr}; bool dt_softplus{false}; bool update_state{true}; + + // Philox PRNG seed for stochastic rounding of fp16 state stores. + // Only used when the kernel is compiled with NUM_PHILOX_ROUNDS > 0. + // Device-side pointer to a single int64_t value. + const int64_t* rand_seed{nullptr}; }; namespace mtp { @@ -77,10 +84,12 @@ struct SelectiveStateMTPParams : public SelectiveStateUpdateParams { // MTP-specific strides for the token dimension int64_t x_stride_mtp{}, dt_stride_mtp{}, B_stride_mtp{}, C_stride_mtp{}, out_stride_mtp{}, z_stride_mtp{}; + int64_t intermediate_state_stride_batch{}, intermediate_state_scales_stride_batch{}; void* __restrict__ intermediate_states{ nullptr}; // state_t: (ntokens_mtp, state_cache_size, nheads, dim, dstate) void* __restrict__ intermediate_state_indices{nullptr}; // (batch,) - int64_t intermediate_state_stride_batch{}; // stride for batch dimension of intermediate_states + void* __restrict__ intermediate_state_scales{ + nullptr}; // float: (batch, cache_steps, nheads, dim) }; } // namespace mtp diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index 2841e6b164..c7edf5ab57 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -64,6 +64,12 @@ ") at ", __FILE__, ":", __LINE__, " in ", STR(func)); \ } while (0) +#define FLASHINFER_CHECK_ALIGNMENT(ptr, size_bytes) \ + 
FLASHINFER_CHECK(reinterpret_cast(ptr) % (size_bytes) == 0, #ptr, \ + " must be aligned to ", (size_bytes), " bytes, got address ", (uintptr_t)(ptr)) + +#define FLASHINFER_CHECK_TMA_ALIGNED(ptr) FLASHINFER_CHECK_ALIGNMENT(ptr, 128) + #define DISPATCH_USE_FP16_QK_REDUCTION(use_fp16_qk_reduction, USE_FP16_QK_REDUCTION, ...) \ if (use_fp16_qk_reduction) { \ FLASHINFER_ERROR("FP16_QK_REDUCTION disabled at compile time"); \ diff --git a/tests/mamba/test_philox_rounding.py b/tests/mamba/test_philox_rounding.py new file mode 100644 index 0000000000..66c344c87f --- /dev/null +++ b/tests/mamba/test_philox_rounding.py @@ -0,0 +1,432 @@ +"""Tests for Philox PRNG and stochastic rounding primitives. + +Test 1: CUDA vs Triton Philox randint — bitwise comparison (any GPU). +Test 2: CUDA vs Triton stochastic rounding (cvt.rs.f16x2.f32) — bitwise comparison (sm_100a+). +""" + +import pathlib + +import pytest +import torch +import triton +import triton.language as tl +from torch.utils.cpp_extension import load_inline + +from flashinfer.utils import get_compute_capability + + +# --------------------------------------------------------------------------- +# Triton reference: tl.randint +# --------------------------------------------------------------------------- +@triton.jit +def _triton_philox_kernel( + out_ptr, + seed_ptr, + n_elements, + N_ROUNDS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + seed = tl.load(seed_ptr) + rand = tl.randint(seed, offsets, N_ROUNDS) + # rand is uint32 but Triton stores it as int32 bit-pattern + tl.store(out_ptr + offsets, rand, mask=mask) + + +def triton_philox(seed: int, n_elements: int, n_rounds: int) -> torch.Tensor: + """Run the Triton philox kernel and return int32 tensor (uint32 bit-pattern).""" + seed_t = torch.tensor([seed], dtype=torch.int64, device="cuda") + out = torch.empty(n_elements, dtype=torch.int32, device="cuda") + 
BLOCK_SIZE = 1024 + grid = ((n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE,) + _triton_philox_kernel[grid](out, seed_t, n_elements, n_rounds, BLOCK_SIZE) + return out + + +# --------------------------------------------------------------------------- +# Triton reference: convert_rs_fp16x2 (stochastic rounding) +# --------------------------------------------------------------------------- +@triton.jit +def _triton_convert_rs_kernel( + out_ptr, + fp32_ptr, + rand_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + """Apply stochastic rounding: fp32 → fp16 using random bits.""" + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(fp32_ptr + offsets, mask=mask) + rand = tl.load(rand_ptr + offsets, mask=mask) + # cvt.rs.f16x2.f32: stochastic rounding of fp32 pair → fp16x2 + y = tl.inline_asm_elementwise( + asm="""{ + cvt.rs.f16x2.f32 $0, $2, $1, $3; + }""", + constraints=("=r,r,r,r,r"), + args=(x, rand), + dtype=tl.float16, + is_pure=True, + pack=2, + ) + tl.store(out_ptr + offsets, y, mask=mask) + + +def triton_stochastic_round( + fp32_values: torch.Tensor, rand_bits: torch.Tensor +) -> torch.Tensor: + """Stochastic-round fp32 → fp16 using random bits via Triton PTX.""" + assert fp32_values.dtype == torch.float32 + assert rand_bits.dtype == torch.int32 + n = fp32_values.numel() + # n must be even for fp16x2 packing + assert n % 2 == 0, "n_elements must be even for fp16x2 packing" + out = torch.empty(n, dtype=torch.float16, device="cuda") + BLOCK_SIZE = 1024 + grid = ((n + BLOCK_SIZE - 1) // BLOCK_SIZE,) + _triton_convert_rs_kernel[grid](out, fp32_values, rand_bits, n, BLOCK_SIZE) + return out + + +# --------------------------------------------------------------------------- +# CUDA sources +# --------------------------------------------------------------------------- +_FLASHINFER_INCLUDE = str(pathlib.Path(__file__).resolve().parents[2] / "include") + +_PHILOX_CUDA_SOURCE = r""" +#include +#include + 
+template +__global__ void philox_kernel(int32_t* out, int64_t seed, int n_elements) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= n_elements) return; + uint32_t result = flashinfer::mamba::conversion::philox_randint(seed, (uint32_t)idx); + out[idx] = static_cast(result); +} + +torch::Tensor cuda_philox(int64_t seed, int n_elements, int n_rounds) { + auto out = torch::empty({n_elements}, torch::dtype(torch::kInt32).device(torch::kCUDA)); + int threads = 256; + int blocks = (n_elements + threads - 1) / threads; + switch (n_rounds) { + case 1: philox_kernel<1><<>>(out.data_ptr(), seed, n_elements); break; + case 4: philox_kernel<4><<>>(out.data_ptr(), seed, n_elements); break; + case 10: philox_kernel<10><<>>(out.data_ptr(), seed, n_elements); break; + default: TORCH_CHECK(false, "Unsupported n_rounds: ", n_rounds); + } + return out; +} +""" + +_PHILOX_CPP_SOURCE = r""" +torch::Tensor cuda_philox(int64_t seed, int n_elements, int n_rounds); +""" + +_STOCHASTIC_ROUND_CUDA_SOURCE = r""" +#include +#include +#include + +__global__ void stochastic_round_kernel(half* out, const float* fp32_in, + const int32_t* rand_in, int n_elements) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int pair_idx = idx * 2; + if (pair_idx + 1 >= n_elements) return; + + float a = fp32_in[pair_idx]; + float b = fp32_in[pair_idx + 1]; + // Triton pack=2 uses rand from the first element of the pair + uint32_t rand = *reinterpret_cast(&rand_in[pair_idx]); + + uint32_t packed = flashinfer::mamba::conversion::cvt_rs_f16x2_f32(a, b, rand); + *reinterpret_cast(&out[pair_idx]) = packed; +} + +torch::Tensor cuda_stochastic_round(torch::Tensor fp32_values, torch::Tensor rand_bits) { + int n = fp32_values.numel(); + auto out = torch::empty({n}, torch::dtype(torch::kFloat16).device(torch::kCUDA)); + int n_pairs = n / 2; + int threads = 256; + int blocks = (n_pairs + threads - 1) / threads; + stochastic_round_kernel<<>>( + reinterpret_cast(out.data_ptr()), + 
fp32_values.data_ptr(), + rand_bits.data_ptr(), + n); + return out; +} +""" + +_STOCHASTIC_ROUND_CPP_SOURCE = r""" +torch::Tensor cuda_stochastic_round(torch::Tensor fp32_values, torch::Tensor rand_bits); +""" + +_STOCHASTIC_ROUND_SINGLE_CUDA_SOURCE = r""" +#include +#include +#include + +// Each thread converts one fp32 value to fp16 using cvt_rs_f16_f32 (single-value). +// rand13_in contains 13-bit random values (one per element, stored as int32). +__global__ void stochastic_round_single_kernel(half* out, const float* fp32_in, + const int32_t* rand13_in, int n_elements) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= n_elements) return; + + float x = fp32_in[idx]; + uint32_t rand13 = static_cast(rand13_in[idx]) & 0x1FFFu; + out[idx] = flashinfer::mamba::conversion::cvt_rs_f16_f32(x, rand13); +} + +torch::Tensor cuda_stochastic_round_single(torch::Tensor fp32_values, torch::Tensor rand13_bits) { + int n = fp32_values.numel(); + auto out = torch::empty({n}, torch::dtype(torch::kFloat16).device(torch::kCUDA)); + int threads = 256; + int blocks = (n + threads - 1) / threads; + stochastic_round_single_kernel<<>>( + reinterpret_cast(out.data_ptr()), + fp32_values.data_ptr(), + rand13_bits.data_ptr(), + n); + return out; +} +""" + +_STOCHASTIC_ROUND_SINGLE_CPP_SOURCE = r""" +torch::Tensor cuda_stochastic_round_single(torch::Tensor fp32_values, torch::Tensor rand13_bits); +""" + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- +@pytest.fixture(scope="module") +def philox_module(): + """Compile philox_randint test kernel (works on any GPU).""" + return load_inline( + name="test_philox", + cpp_sources=[_PHILOX_CPP_SOURCE], + cuda_sources=[_PHILOX_CUDA_SOURCE], + extra_include_paths=[_FLASHINFER_INCLUDE], + functions=["cuda_philox"], + verbose=False, + ) + + +@pytest.fixture(scope="module") +def stochastic_round_module(): + """Compile 
cvt_rs_f16x2_f32 test kernel with sm_100a (hardware PTX path).""" + major, minor = get_compute_capability(torch.device("cuda")) + if major < 10: + pytest.skip("cvt.rs.f16x2.f32 requires sm_100a (Blackwell or later)") + # Append 'a' suffix for SM >= 9, matching flashinfer/compilation_context.py:44-45 + minor_str = f"{minor}a" if major >= 9 else str(minor) + gencode = f"-gencode=arch=compute_{major}{minor_str},code=sm_{major}{minor_str}" + return load_inline( + name="test_stochastic_round", + cpp_sources=[_STOCHASTIC_ROUND_CPP_SOURCE], + cuda_sources=[_STOCHASTIC_ROUND_CUDA_SOURCE], + extra_include_paths=[_FLASHINFER_INCLUDE], + extra_cuda_cflags=[gencode], + functions=["cuda_stochastic_round"], + verbose=False, + ) + + +@pytest.fixture(scope="module") +def stochastic_round_sw_module(): + """Compile cvt_rs_f16x2_f32 test kernel without sm_100a (software fallback path).""" + return load_inline( + name="test_stochastic_round_sw", + cpp_sources=[_STOCHASTIC_ROUND_CPP_SOURCE], + cuda_sources=[_STOCHASTIC_ROUND_CUDA_SOURCE], + extra_include_paths=[_FLASHINFER_INCLUDE], + functions=["cuda_stochastic_round"], + verbose=False, + ) + + +@pytest.fixture(scope="module") +def stochastic_round_single_module(): + """Compile cvt_rs_f16_f32 single-value test kernel (software path, any GPU).""" + return load_inline( + name="test_stochastic_round_single", + cpp_sources=[_STOCHASTIC_ROUND_SINGLE_CPP_SOURCE], + cuda_sources=[_STOCHASTIC_ROUND_SINGLE_CUDA_SOURCE], + extra_include_paths=[_FLASHINFER_INCLUDE], + functions=["cuda_stochastic_round_single"], + verbose=False, + ) + + +# --------------------------------------------------------------------------- +# Test 1: Philox randint (any GPU) +# --------------------------------------------------------------------------- +@pytest.mark.parametrize("n_rounds", [1, 4, 10]) +@pytest.mark.parametrize("seed", [0, 42, 123456, 2**31 - 1]) +def test_philox_randint(philox_module, seed, n_rounds): + """Bitwise comparison of CUDA philox_randint vs 
Triton tl.randint.""" + n_elements = 1024 + + cuda_out = philox_module.cuda_philox(seed, n_elements, n_rounds) + triton_out = triton_philox(seed, n_elements, n_rounds) + + mismatches = (cuda_out != triton_out).sum().item() + if mismatches > 0: + diff_idx = torch.where(cuda_out != triton_out)[0][:10] + for idx in diff_idx: + i = idx.item() + c = cuda_out[i].item() & 0xFFFFFFFF + t = triton_out[i].item() & 0xFFFFFFFF + print(f" offset={i}: CUDA=0x{c:08X}, Triton=0x{t:08X}") + + assert mismatches == 0, ( + f"seed={seed}, n_rounds={n_rounds}: {mismatches}/{n_elements} mismatches" + ) + print(f" seed={seed}, n_rounds={n_rounds}: all {n_elements} values match") + + +# --------------------------------------------------------------------------- +# Test 2: Stochastic rounding (sm_100a+ only) +# --------------------------------------------------------------------------- +@pytest.mark.parametrize("seed", [0, 42, 99999]) +def test_stochastic_rounding(stochastic_round_module, seed): + """Bitwise comparison of CUDA vs Triton stochastic rounding (cvt.rs.f16x2.f32).""" + n_elements = 1024 # must be even + torch.manual_seed(seed) + + fp32_values = torch.randn(n_elements, dtype=torch.float32, device="cuda") + rand_bits = torch.randint( + -(2**31), 2**31, (n_elements,), dtype=torch.int32, device="cuda" + ) + + cuda_out = stochastic_round_module.cuda_stochastic_round(fp32_values, rand_bits) + triton_out = triton_stochastic_round(fp32_values, rand_bits) + + # Compare as raw uint16 bit patterns + cuda_bits = cuda_out.view(torch.int16) + triton_bits = triton_out.view(torch.int16) + + mismatches = (cuda_bits != triton_bits).sum().item() + if mismatches > 0: + diff_idx = torch.where(cuda_bits != triton_bits)[0][:10] + for idx in diff_idx: + i = idx.item() + pair = i // 2 + cb = cuda_bits[i].item() & 0xFFFF + tb = triton_bits[i].item() & 0xFFFF + cf = cuda_out[i].item() + tf = triton_out[i].item() + rb = rand_bits[pair].item() & 0xFFFFFFFF + print( + f" elem={i} (pair={pair}): 
fp32={fp32_values[i].item():.6f}, " + f"rand=0x{rb:08X}, CUDA=0x{cb:04X}({cf}), Triton=0x{tb:04X}({tf})" + ) + + assert mismatches == 0, f"seed={seed}: {mismatches}/{n_elements} mismatches" + print(f" seed={seed}: all {n_elements} fp16 values match bitwise (hw)") + + +# --------------------------------------------------------------------------- +# Test 3: Stochastic rounding software fallback (any GPU, verified on Blackwell) +# --------------------------------------------------------------------------- +@pytest.mark.parametrize("seed", [0, 42, 99999]) +def test_stochastic_rounding_sw( + stochastic_round_sw_module, stochastic_round_module, seed +): + """Software stochastic rounding matches hardware PTX path bitwise.""" + n_elements = 1024 # must be even + torch.manual_seed(seed) + + fp32_values = torch.randn(n_elements, dtype=torch.float32, device="cuda") + rand_bits = torch.randint( + -(2**31), 2**31, (n_elements,), dtype=torch.int32, device="cuda" + ) + + sw_out = stochastic_round_sw_module.cuda_stochastic_round(fp32_values, rand_bits) + hw_out = stochastic_round_module.cuda_stochastic_round(fp32_values, rand_bits) + + # Compare as raw uint16 bit patterns + sw_bits = sw_out.view(torch.int16) + hw_bits = hw_out.view(torch.int16) + + mismatches = (sw_bits != hw_bits).sum().item() + if mismatches > 0: + diff_idx = torch.where(sw_bits != hw_bits)[0][:10] + for idx in diff_idx: + i = idx.item() + pair = i // 2 + sb = sw_bits[i].item() & 0xFFFF + hb = hw_bits[i].item() & 0xFFFF + sf = sw_out[i].item() + hf = hw_out[i].item() + rb = rand_bits[pair].item() & 0xFFFFFFFF + print( + f" elem={i} (pair={pair}): fp32={fp32_values[i].item():.6f}, " + f"rand=0x{rb:08X}, SW=0x{sb:04X}({sf}), HW=0x{hb:04X}({hf})" + ) + + assert mismatches == 0, f"seed={seed}: {mismatches}/{n_elements} mismatches" + print(f" seed={seed}: all {n_elements} fp16 values match bitwise (sw vs hw)") + + +# --------------------------------------------------------------------------- +# Test 4: Single-value 
stochastic rounding matches pair-wise (any GPU) +# --------------------------------------------------------------------------- +@pytest.mark.parametrize("seed", [0, 42, 99999]) +def test_stochastic_rounding_single_vs_pair( + stochastic_round_single_module, stochastic_round_sw_module, seed +): + """cvt_rs_f16_f32 (single) matches the corresponding element from cvt_rs_f16x2_f32 (pair).""" + n_elements = 1024 # must be even + torch.manual_seed(seed) + + fp32_values = torch.randn(n_elements, dtype=torch.float32, device="cuda") + # Generate 13-bit random values per element + rand13 = torch.randint(0, 8192, (n_elements,), dtype=torch.int32, device="cuda") + + # Single-value path: cvt_rs_f16_f32(x, rand13) for each element + single_out = stochastic_round_single_module.cuda_stochastic_round_single( + fp32_values, rand13 + ) + + # Pair-wise path: cvt_rs_f16x2_f32(a, b, rbits) where rbits packs rand13 for both + # rbits layout: bits[12:0] = rand for C++ a (low half), bits[28:16] = rand for C++ b (high half) + rand_a = rand13[0::2] # even elements + rand_b = rand13[1::2] # odd elements + rbits = (rand_a & 0x1FFF) | ((rand_b & 0x1FFF) << 16) + # Expand rbits back to n_elements (pair-wise kernel reads from pair_idx) + rbits_expanded = torch.zeros(n_elements, dtype=torch.int32, device="cuda") + rbits_expanded[0::2] = rbits + rbits_expanded[1::2] = rbits # pair kernel reads from pair_idx = even index + + pair_out = stochastic_round_sw_module.cuda_stochastic_round( + fp32_values, rbits_expanded + ) + + # Compare as raw bit patterns + single_bits = single_out.view(torch.int16) + pair_bits = pair_out.view(torch.int16) + + mismatches = (single_bits != pair_bits).sum().item() + if mismatches > 0: + diff_idx = torch.where(single_bits != pair_bits)[0][:10] + for idx in diff_idx: + i = idx.item() + sb = single_bits[i].item() & 0xFFFF + pb = pair_bits[i].item() & 0xFFFF + sf = single_out[i].item() + pf = pair_out[i].item() + r13 = rand13[i].item() & 0x1FFF + print( + f" elem={i}: 
fp32={fp32_values[i].item():.6f}, " + f"rand13=0x{r13:04X}, single=0x{sb:04X}({sf}), pair=0x{pb:04X}({pf})" + ) + + assert mismatches == 0, f"seed={seed}: {mismatches}/{n_elements} mismatches" + print(f" seed={seed}: all {n_elements} fp16 values match bitwise (single vs pair)") diff --git a/tests/mamba/test_selective_state_update_mtp.py b/tests/mamba/test_selective_state_update_mtp.py index f295ee6bad..0dc91e444c 100644 --- a/tests/mamba/test_selective_state_update_mtp.py +++ b/tests/mamba/test_selective_state_update_mtp.py @@ -10,6 +10,7 @@ import torch import flashinfer +from flashinfer.utils import get_compute_capability from .triton_reference.selective_state_update import selective_state_update_triton from .utils import create_test_inputs, clone_preserving_strides @@ -19,7 +20,7 @@ # state_dtype=bf16, weight_dtype=f32, use_out_tensor=True # Each additional row varies exactly one parameter from the base. # fmt: off -_BASE_PARAMS = [ +_BASE_PARAMS = ( # (batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype, use_out_tensor) ( 64, 64, 64, 128, 4, torch.bfloat16, torch.float32, True ), # base ( 1, 64, 64, 128, 4, torch.bfloat16, torch.float32, True ), # batch=1 @@ -31,7 +32,7 @@ ( 64, 64, 64, 128, 8, torch.bfloat16, torch.float32, True ), # cache_steps=8 ( 64, 64, 64, 128, 4, torch.float32, torch.float32, True ), # state_dtype=f32 ( 64, 64, 64, 128, 4, torch.bfloat16, torch.float32, False), # use_out_tensor=False -] +) # fmt: on @@ -399,14 +400,16 @@ def run_kernel_with_intermediate_states(self, inputs, out=None): ) # fmt: off - _INTERMEDIATE_PARAMS = [ + _INTERMEDIATE_PARAMS = ( # (batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype, use_out_tensor) - ( 64, 64, 64, 128, 4, torch.bfloat16, torch.float32, True ), # base + ( 64, 64, 64, 128, 4, torch.bfloat16, torch.float32, True ), # base bf16 ( 64, 64, 64, 64, 4, torch.bfloat16, torch.float32, True ), # dstate=64 ( 64, 64, 64, 128, 2, torch.bfloat16, torch.float32, True ), # 
cache_steps=2 ( 64, 64, 64, 128, 8, torch.bfloat16, torch.float32, True ), # cache_steps=8 ( 64, 64, 64, 128, 4, torch.bfloat16, torch.float32, False), # use_out_tensor=False - ] + ( 64, 64, 64, 128, 4, torch.float16, torch.float32, True ), # state_dtype=f16 + ( 64, 64, 64, 128, 4, torch.float32, torch.float32, True ), # state_dtype=f32 + ) # fmt: on @pytest.mark.parametrize( @@ -599,13 +602,13 @@ class TestSelectiveStateUpdateMTPVariousNgroups(TestSelectiveStateUpdateMTP): """Test multi-token selective_state_update with various ngroups values.""" # fmt: off - _NGROUPS_PARAMS = [ + _NGROUPS_PARAMS = ( # (batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype, use_out_tensor, ngroups) ( 64, 64, 64, 128, 4, torch.bfloat16, torch.float32, True, 1), ( 64, 64, 64, 128, 4, torch.bfloat16, torch.float32, True, 2), ( 64, 64, 64, 128, 4, torch.bfloat16, torch.float32, True, 4), ( 64, 64, 64, 128, 4, torch.bfloat16, torch.float32, True, 8), - ] + ) # fmt: on @pytest.mark.parametrize( @@ -662,11 +665,11 @@ class TestSelectiveStateUpdateMTPLargeBatch(TestSelectiveStateUpdateMTP): """Test multi-token selective_state_update with larger batch sizes.""" # fmt: off - _LARGE_BATCH_PARAMS = [ + _LARGE_BATCH_PARAMS = ( # (batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype, use_out_tensor) ( 16, 64, 64, 128, 4, torch.bfloat16, torch.float32, True ), # batch=16 ( 256, 64, 64, 128, 4, torch.bfloat16, torch.float32, True ), # batch=256 - ] + ) # fmt: on @pytest.mark.parametrize( @@ -696,6 +699,331 @@ def test_output_correctness( ) +# fmt: off +_INT16_MTP_PARAMS = ( + # (batch, nheads, dim, dstate, cache_steps, weight_dtype, use_out_tensor) + ( 64, 64, 64, 128, 4, torch.float32, True ), # base + ( 64, 64, 128, 128, 4, torch.float32, True ), # dim=128 + ( 64, 64, 64, 128, 4, torch.bfloat16, True ), # weight_dtype=bf16 +) +# fmt: on + + +class TestSelectiveStateUpdateMTPInt16(TestSelectiveStateUpdateMTP): + """Test multi-token selective_state_update with int16 
quantized state and block scaling.""" + + ATOL = 1e-1 + RTOL = 1e-2 + + def make_inputs( + self, batch, nheads, dim, dstate, cache_steps, _state_dtype, weight_dtype + ): + """Create test inputs with int16 state.""" + return create_test_inputs( + batch, + nheads, + dim, + dstate, + self.NGROUPS, + self.INPUT_DTYPE, + weight_dtype=weight_dtype, + matrixA_dtype=self.MATRIX_A_DTYPE, + state_dtype=torch.int16, + generate_z=False, + generate_intermediate_states_buffer=False, + cache_steps=cache_steps, + seed=0, + ) + + def make_reference_output(self, inputs): + """Compute reference output using Triton with state_scale.""" + state_ref = clone_preserving_strides(inputs["state_cache"]) + state_scale_ref = inputs["state_scale"].clone() + y_ref = selective_state_update_triton( + state_ref, + inputs["x"], + inputs["dt"], + inputs["A"], + inputs["B"], + inputs["C"], + D=inputs["D"], + z=inputs.get("z"), + dt_bias=inputs["dt_bias"], + dt_softplus=True, + state_batch_indices=inputs["slot_idx"], + pad_slot_id=-1, + state_scale=state_scale_ref, + ) + return y_ref, state_ref, state_scale_ref + + def run_kernel(self, inputs, out=None, disable_state_update=False): + """Run the flashinfer kernel with state_scale.""" + return flashinfer.mamba.selective_state_update( + inputs["state_cache"], + inputs["x"], + inputs["dt"], + inputs["A"], + inputs["B"], + inputs["C"], + D=inputs["D"], + z=inputs.get("z"), + dt_bias=inputs["dt_bias"], + dt_softplus=True, + state_batch_indices=inputs["slot_idx"], + pad_slot_id=-1, + out=out, + disable_state_update=disable_state_update, + state_scale=inputs["state_scale"], + ) + + def assert_states_match( + self, + state_ref, + state_test, + slot_idx, + msg_prefix="", + state_scale_ref=None, + state_scale_test=None, + ): + """Compare dequantized int16 states.""" + state_ref_batch = state_ref[slot_idx].float() + state_test_batch = state_test[slot_idx].float() + + # Dequantize using the respective scales + if state_scale_ref is not None: + scale_ref = 
state_scale_ref[slot_idx] + if scale_ref.dim() == 3: + scale_ref = scale_ref.unsqueeze(-1) + state_ref_batch = state_ref_batch * scale_ref + if state_scale_test is not None: + scale_test = state_scale_test[slot_idx] + if scale_test.dim() == 3: + scale_test = scale_test.unsqueeze(-1) + state_test_batch = state_test_batch * scale_test + + states_match = torch.allclose( + state_ref_batch, state_test_batch, atol=self.ATOL, rtol=self.RTOL + ) + + if states_match: + print( + f"✓ {msg_prefix}Dequantized states match within tolerance (atol={self.ATOL}, rtol={self.RTOL})" + ) + else: + print(f"✗ {msg_prefix}Dequantized states do NOT match") + diff = (state_ref_batch - state_test_batch).abs() + print( + f" Max diff: {diff.max().item():.6f}, Mean diff: {diff.mean().item():.6f}" + ) + + assert states_match + + @pytest.mark.parametrize( + "batch,nheads,dim,dstate,cache_steps,state_dtype,weight_dtype,use_out_tensor", + [ + (b, nh, d, ds, cs, torch.int16, wd, uo) + for b, nh, d, ds, cs, wd, uo in _INT16_MTP_PARAMS + ], + ) + def test_output_correctness( + self, + batch, + nheads, + dim, + dstate, + cache_steps, + state_dtype, + weight_dtype, + use_out_tensor, + ): + inputs = self.make_inputs( + batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype + ) + y_ref, state_ref, state_scale_ref = self.make_reference_output(inputs) + + if use_out_tensor: + out = torch.empty_like(inputs["x"]) + else: + out = None + + y_test = self.run_kernel(inputs, out=out) + + if use_out_tensor: + assert y_test.data_ptr() == out.data_ptr() + + self.assert_outputs_match(y_ref, y_test) + self.assert_states_match( + state_ref, + inputs["state_cache"], + inputs["slot_idx"], + state_scale_ref=state_scale_ref, + state_scale_test=inputs["state_scale"], + ) + + +class TestSelectiveStateUpdateMTPInt16IntermediateStates( + TestSelectiveStateUpdateMTPWithIntermediateStates +): + """Test int16 scaled state with intermediate_states buffer.""" + + ATOL = 1e-1 + RTOL = 1e-2 + + def make_inputs( + self, 
batch, nheads, dim, dstate, cache_steps, _state_dtype, weight_dtype + ): + """Create test inputs with int16 state and intermediate states buffer.""" + return create_test_inputs( + batch, + nheads, + dim, + dstate, + self.NGROUPS, + self.INPUT_DTYPE, + weight_dtype=weight_dtype, + matrixA_dtype=self.MATRIX_A_DTYPE, + state_dtype=torch.int16, + generate_z=False, + generate_intermediate_states_buffer=True, + cache_steps=cache_steps, + seed=0, + ) + + def make_reference_output(self, inputs): + """Compute reference output using Triton with state_scale and intermediate states.""" + state_ref = clone_preserving_strides(inputs["state_cache"]) + state_scale_ref = inputs["state_scale"].clone() + intermediate_states_ref = inputs["intermediate_states_buffer"].clone() + intermediate_state_scales_ref = inputs["intermediate_state_scales"].clone() + + y_ref = selective_state_update_triton( + state_ref, + inputs["x"], + inputs["dt"], + inputs["A"], + inputs["B"], + inputs["C"], + D=inputs["D"], + z=inputs.get("z"), + dt_bias=inputs["dt_bias"], + dt_softplus=True, + state_batch_indices=inputs["slot_idx"], + pad_slot_id=-1, + disable_state_update=True, + intermediate_states_buffer=intermediate_states_ref, + cache_steps=inputs["cache_steps"], + intermediate_state_indices=inputs["intermediate_slot_idx"], + state_scale=state_scale_ref, + intermediate_state_scales=intermediate_state_scales_ref, + ) + return y_ref, state_ref, intermediate_states_ref, intermediate_state_scales_ref + + def run_kernel_with_intermediate_states(self, inputs, out=None): + """Run the flashinfer kernel with int16 state and intermediate states.""" + return flashinfer.mamba.selective_state_update( + inputs["state_cache"], + inputs["x"], + inputs["dt"], + inputs["A"], + inputs["B"], + inputs["C"], + D=inputs["D"], + z=inputs.get("z"), + dt_bias=inputs["dt_bias"], + dt_softplus=True, + state_batch_indices=inputs["slot_idx"], + pad_slot_id=-1, + out=out, + disable_state_update=True, + 
intermediate_states_buffer=inputs["intermediate_states_buffer"], + intermediate_state_indices=inputs["intermediate_slot_idx"], + cache_steps=inputs["cache_steps"], + state_scale=inputs["state_scale"], + intermediate_state_scales=inputs["intermediate_state_scales"], + ) + + # fmt: off + _INT16_INTERMEDIATE_PARAMS = ( + # (batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype, use_out_tensor) + ( 64, 64, 64, 128, 4, torch.int16, torch.float32, True ), # base + ( 64, 64, 64, 128, 2, torch.int16, torch.float32, True ), # cache_steps=2 + ( 64, 64, 64, 128, 8, torch.int16, torch.float32, True ), # cache_steps=8 + ) + # fmt: on + + @pytest.mark.parametrize( + "batch,nheads,dim,dstate,cache_steps,state_dtype,weight_dtype,use_out_tensor", + _INT16_INTERMEDIATE_PARAMS, + ) + def test_output_correctness( + self, + batch, + nheads, + dim, + dstate, + cache_steps, + state_dtype, + weight_dtype, + use_out_tensor, + ): + """Test output and dequantized intermediate states match.""" + inputs = self.make_inputs( + batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype + ) + y_ref, _state_ref, intermediate_states_ref, iscales_ref = ( + self.make_reference_output(inputs) + ) + + if use_out_tensor: + out = torch.empty_like(inputs["x"]) + else: + out = None + + y_test = self.run_kernel_with_intermediate_states(inputs, out=out) + + if use_out_tensor: + assert y_test.data_ptr() == out.data_ptr() + + self.assert_outputs_match(y_ref, y_test, msg_prefix="[int16_intermediate] ") + + # Compare dequantized intermediate states + cache_steps_val = inputs["cache_steps"] + intermediate_states_test = inputs["intermediate_states_buffer"] + iscales_test = inputs["intermediate_state_scales"] + + for t in range(cache_steps_val): + # Dequantize: int16 * scale → float + ref_deq = ( + intermediate_states_ref[:, t, :, :, :].float() + * iscales_ref[:, t, :, :, None] + ) + test_deq = ( + intermediate_states_test[:, t, :, :, :].float() + * iscales_test[:, t, :, :, None] + ) + + 
states_match = torch.allclose( + ref_deq, test_deq, atol=self.ATOL, rtol=self.RTOL + ) + + max_diff = (ref_deq - test_deq).abs().max().item() + if states_match: + print( + f"✓ Intermediate state {t} (dequantized) matches " + f"(max_diff={max_diff:.6e})" + ) + else: + print( + f"✗ Intermediate state {t} (dequantized) mismatch " + f"(max_diff={max_diff:.6e})" + ) + self._print_mismatch_details( + ref_deq, test_deq, f"intermediate_state_{t}" + ) + + assert states_match, f"Intermediate state at step {t} mismatch" + + class TestSelectiveStateUpdateMTPIndicesDtypeMismatch: """Test that selective_state_update fails with dtype mismatch between indices.""" @@ -753,3 +1081,292 @@ def test_state_batch_idx_and_intermediate_idx_dtype_mismatch_should_fail(self): intermediate_state_indices=inputs["intermediate_slot_idx"], cache_steps=inputs["cache_steps"], ) + + +class TestSelectiveStateUpdateMTPStochasticRounding(TestSelectiveStateUpdateMTP): + """Test fp16 state with stochastic rounding vs Triton reference.""" + + ATOL = 0.001 + RTOL = 0.01 + + RAND_SEED = torch.tensor(42, dtype=torch.int64, device="cuda") + + def make_inputs( + self, batch, nheads, dim, dstate, cache_steps, _state_dtype, weight_dtype + ): + """Create test inputs with fp16 state.""" + return create_test_inputs( + batch, + nheads, + dim, + dstate, + self.NGROUPS, + self.INPUT_DTYPE, + weight_dtype=weight_dtype, + matrixA_dtype=self.MATRIX_A_DTYPE, + state_dtype=torch.float16, + generate_z=False, + generate_intermediate_states_buffer=False, + cache_steps=cache_steps, + seed=0, + ) + + def make_reference_output(self, inputs): + """Compute reference output using Triton with stochastic rounding.""" + state_ref = clone_preserving_strides(inputs["state_cache"]) + major, _ = get_compute_capability(torch.device("cuda")) + # Triton cvt.rs.f16x2.f32 requires SM100a+; on older GPUs the Triton + # reference falls back to regular rounding while the CUDA kernel still + # exercises its software stochastic rounding path. 
+ rand_seed = self.RAND_SEED if major >= 10 else None + y_ref = selective_state_update_triton( + state_ref, + inputs["x"], + inputs["dt"], + inputs["A"], + inputs["B"], + inputs["C"], + D=inputs["D"], + z=inputs.get("z"), + dt_bias=inputs["dt_bias"], + dt_softplus=True, + state_batch_indices=inputs["slot_idx"], + pad_slot_id=-1, + rand_seed=rand_seed, + ) + return y_ref, state_ref + + def run_kernel(self, inputs, out=None, disable_state_update=False): + """Run the flashinfer kernel with stochastic rounding.""" + return flashinfer.mamba.selective_state_update( + inputs["state_cache"], + inputs["x"], + inputs["dt"], + inputs["A"], + inputs["B"], + inputs["C"], + D=inputs["D"], + z=inputs.get("z"), + dt_bias=inputs["dt_bias"], + dt_softplus=True, + state_batch_indices=inputs["slot_idx"], + pad_slot_id=-1, + out=out, + disable_state_update=disable_state_update, + rand_seed=self.RAND_SEED, + ) + + def assert_states_match(self, state_ref, state_test, slot_idx, msg_prefix=""): + """Assert states match within tolerance (SR path has different FP operation order than Triton).""" + state_ref_batch = state_ref[slot_idx] + state_test_batch = state_test[slot_idx] + states_match = torch.allclose( + state_ref_batch.float(), + state_test_batch.float(), + atol=self.ATOL, + rtol=self.RTOL, + ) + + if states_match: + print( + f"✓ {msg_prefix}States match within tolerance (atol={self.ATOL}, rtol={self.RTOL})" + ) + else: + max_diff = ( + (state_ref_batch.float() - state_test_batch.float()).abs().max().item() + ) + print(f"✗ {msg_prefix}States do NOT match (max_diff={max_diff:.6e})") + self._print_mismatch_details(state_ref_batch, state_test_batch, "state") + + assert states_match + + # fmt: off + _SR_PARAMS = ( + # (batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype, use_out_tensor) + ( 64, 64, 64, 128, 4, torch.float16, torch.float32, True ), # base + ( 64, 64, 64, 64, 4, torch.float16, torch.float32, True ), # dstate=64 + ) + # fmt: on + + @pytest.mark.parametrize( + 
"batch,nheads,dim,dstate,cache_steps,state_dtype,weight_dtype,use_out_tensor", + _SR_PARAMS, + ) + def test_output_correctness( + self, + batch, + nheads, + dim, + dstate, + cache_steps, + state_dtype, + weight_dtype, + use_out_tensor, + ): + super().test_output_correctness( + batch, + nheads, + dim, + dstate, + cache_steps, + state_dtype, + weight_dtype, + use_out_tensor, + ) + + +class TestSelectiveStateUpdateMTPStochasticRoundingWithIntermediateStates( + TestSelectiveStateUpdateMTPWithIntermediateStates +): + """Test fp16 state with stochastic rounding + intermediate states.""" + + ATOL = 0.001 + RTOL = 0.01 + + RAND_SEED = torch.tensor(42, dtype=torch.int64, device="cuda") + + def make_inputs( + self, batch, nheads, dim, dstate, cache_steps, _state_dtype, weight_dtype + ): + """Create test inputs with fp16 state and intermediate states.""" + return create_test_inputs( + batch, + nheads, + dim, + dstate, + self.NGROUPS, + self.INPUT_DTYPE, + weight_dtype=weight_dtype, + matrixA_dtype=self.MATRIX_A_DTYPE, + state_dtype=torch.float16, + generate_z=False, + generate_intermediate_states_buffer=True, + cache_steps=cache_steps, + seed=0, + ) + + def make_reference_output(self, inputs): + """Compute reference output using Triton with SR and intermediate states.""" + state_ref = clone_preserving_strides(inputs["state_cache"]) + intermediate_states_ref = inputs["intermediate_states_buffer"].clone() + major, _ = get_compute_capability(torch.device("cuda")) + # Triton cvt.rs.f16x2.f32 requires SM100a+; on older GPUs the Triton + # reference falls back to regular rounding while the CUDA kernel still + # exercises its software stochastic rounding path. 
+ rand_seed = self.RAND_SEED if major >= 10 else None + + y_ref = selective_state_update_triton( + state_ref, + inputs["x"], + inputs["dt"], + inputs["A"], + inputs["B"], + inputs["C"], + D=inputs["D"], + z=inputs.get("z"), + dt_bias=inputs["dt_bias"], + dt_softplus=True, + state_batch_indices=inputs["slot_idx"], + pad_slot_id=-1, + disable_state_update=True, + intermediate_states_buffer=intermediate_states_ref, + cache_steps=inputs["cache_steps"], + intermediate_state_indices=inputs["intermediate_slot_idx"], + rand_seed=rand_seed, + ) + return y_ref, state_ref, intermediate_states_ref + + def run_kernel_with_intermediate_states(self, inputs, out=None): + """Run the flashinfer kernel with SR and intermediate states.""" + return flashinfer.mamba.selective_state_update( + inputs["state_cache"], + inputs["x"], + inputs["dt"], + inputs["A"], + inputs["B"], + inputs["C"], + D=inputs["D"], + z=inputs.get("z"), + dt_bias=inputs["dt_bias"], + dt_softplus=True, + state_batch_indices=inputs["slot_idx"], + pad_slot_id=-1, + out=out, + disable_state_update=True, + intermediate_states_buffer=inputs["intermediate_states_buffer"], + intermediate_state_indices=inputs["intermediate_slot_idx"], + cache_steps=inputs["cache_steps"], + rand_seed=self.RAND_SEED, + ) + + # fmt: off + _SR_INTERMEDIATE_PARAMS = ( + # (batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype, use_out_tensor) + ( 64, 64, 64, 128, 4, torch.float16, torch.float32, True ), # base + ( 64, 64, 64, 64, 4, torch.float16, torch.float32, True ), # dstate=64 + ) + # fmt: on + + @pytest.mark.parametrize( + "batch,nheads,dim,dstate,cache_steps,state_dtype,weight_dtype,use_out_tensor", + _SR_INTERMEDIATE_PARAMS, + ) + def test_output_correctness( + self, + batch, + nheads, + dim, + dstate, + cache_steps, + state_dtype, + weight_dtype, + use_out_tensor, + ): + """Test output and intermediate states match bitwise with SR.""" + inputs = self.make_inputs( + batch, nheads, dim, dstate, cache_steps, state_dtype, 
weight_dtype + ) + y_ref, _state_ref, intermediate_states_ref = self.make_reference_output(inputs) + + if use_out_tensor: + out = torch.empty_like(inputs["x"]) + else: + out = None + + y_test = self.run_kernel_with_intermediate_states(inputs, out=out) + + if use_out_tensor: + assert y_test.data_ptr() == out.data_ptr() + + # Output comparison (approximate — output is fp32 sum, not directly SR'd) + outputs_match = torch.allclose(y_ref, y_test, atol=1e-3, rtol=1e-2) + if outputs_match: + print("✓ [SR_intermediate] Outputs match within tolerance") + else: + print("✗ [SR_intermediate] Outputs do NOT match") + self._print_mismatch_details(y_ref, y_test, "output") + assert outputs_match + + # Intermediate states: tolerance-based match (SR uses different FP op order than Triton) + cache_steps_val = inputs["cache_steps"] + intermediate_states_test = inputs["intermediate_states_buffer"] + + for t in range(cache_steps_val): + ref_state = intermediate_states_ref[:, t, :, :, :] + test_state = intermediate_states_test[:, t, :, :, :] + + states_match = torch.allclose( + ref_state.float(), test_state.float(), atol=self.ATOL, rtol=self.RTOL + ) + + if states_match: + max_diff = (ref_state.float() - test_state.float()).abs().max().item() + print(f"✓ Intermediate state {t} matches (max_diff={max_diff:.6e})") + else: + max_diff = (ref_state.float() - test_state.float()).abs().max().item() + print(f"✗ Intermediate state {t} mismatch (max_diff={max_diff:.6e})") + self._print_mismatch_details( + ref_state, test_state, f"intermediate_state_{t}" + ) + + assert states_match, f"Intermediate state at step {t} mismatch" diff --git a/tests/mamba/test_selective_state_update_stp.py b/tests/mamba/test_selective_state_update_stp.py index d23faa644a..7da054d18f 100644 --- a/tests/mamba/test_selective_state_update_stp.py +++ b/tests/mamba/test_selective_state_update_stp.py @@ -467,6 +467,309 @@ def test_output_correctness( ) +# fmt: off +_INT16_PARAMS = [ + # (batch, nheads, dim, dstate, 
weight_dtype, use_out_tensor) + ( 64, 64, 64, 128, torch.float32, True ), # base + ( 1, 64, 64, 128, torch.float32, True ), # batch=1 + ( 64, 8, 64, 128, torch.float32, True ), # nheads=8 + ( 64, 64, 128, 128, torch.float32, True ), # dim=128 + ( 64, 64, 64, 64, torch.float32, True ), # dstate=64 + ( 64, 64, 64, 256, torch.float32, True ), # dstate=256 + ( 64, 64, 64, 128, torch.bfloat16, True ), # weight_dtype=bf16 +] +# fmt: on + + +class TestSelectiveStateUpdateInt16(TestSelectiveStateUpdate): + """Test selective_state_update with int16 quantized state and block scaling.""" + + ATOL = 1e-1 + RTOL = 1e-2 + + def make_inputs(self, batch, nheads, dim, dstate, _state_dtype, weight_dtype): + """Create test inputs with int16 state.""" + return create_test_inputs( + batch, + nheads, + dim, + dstate, + self.NGROUPS, + self.INPUT_DTYPE, + weight_dtype=weight_dtype, + matrixA_dtype=self.MATRIX_A_DTYPE, + state_dtype=torch.int16, + generate_z=False, + seed=0, + ) + + def make_reference_output(self, inputs): + """Compute reference output using Triton with state_scale.""" + state_ref = inputs["state_cache"].clone() + state_scale_ref = inputs["state_scale"].clone() + y_ref = selective_state_update_triton( + state_ref, + inputs["x"], + inputs["dt"], + inputs["A"], + inputs["B"], + inputs["C"], + D=inputs["D"], + z=inputs.get("z"), + dt_bias=inputs["dt_bias"], + dt_softplus=True, + state_batch_indices=inputs["slot_idx"], + pad_slot_id=-1, + state_scale=state_scale_ref, + ) + return y_ref, state_ref, state_scale_ref + + def run_kernel(self, inputs, out=None, algorithm="auto"): + """Run the flashinfer kernel with state_scale.""" + return flashinfer.mamba.selective_state_update( + inputs["state_cache"], + inputs["x"], + inputs["dt"], + inputs["A"], + inputs["B"], + inputs["C"], + D=inputs["D"], + z=inputs.get("z"), + dt_bias=inputs["dt_bias"], + dt_softplus=True, + state_batch_indices=inputs["slot_idx"], + pad_slot_id=-1, + out=out, + state_scale=inputs["state_scale"], + 
algorithm=algorithm, + ) + + def assert_states_match( + self, + state_ref, + state_test, + slot_idx, + msg_prefix="", + state_scale_ref=None, + state_scale_test=None, + ): + """Compare dequantized int16 states.""" + state_ref_batch = state_ref[slot_idx].float() + state_test_batch = state_test[slot_idx].float() + + # Dequantize using the respective scales + if state_scale_ref is not None: + scale_ref = state_scale_ref[ + slot_idx + ] # (batch, nheads, dim, 1) or (batch, nheads, dim) + if scale_ref.dim() == 3: + scale_ref = scale_ref.unsqueeze(-1) + state_ref_batch = state_ref_batch * scale_ref + if state_scale_test is not None: + scale_test = state_scale_test[slot_idx] + if scale_test.dim() == 3: + scale_test = scale_test.unsqueeze(-1) + state_test_batch = state_test_batch * scale_test + + states_match = torch.allclose( + state_ref_batch, state_test_batch, atol=self.ATOL, rtol=self.RTOL + ) + + if states_match: + print( + f"✓ {msg_prefix}Dequantized states match within tolerance (atol={self.ATOL}, rtol={self.RTOL})" + ) + else: + print(f"✗ {msg_prefix}Dequantized states do NOT match") + diff = (state_ref_batch - state_test_batch).abs() + print( + f" Max diff: {diff.max().item():.6f}, Mean diff: {diff.mean().item():.6f}" + ) + + assert states_match + + @pytest.mark.parametrize( + "algorithm", [a for a in _get_algorithms() if a != "horizontal"] + ) + @pytest.mark.parametrize( + "batch,nheads,dim,dstate,state_dtype,weight_dtype,use_out_tensor", + [(b, nh, d, ds, torch.int16, wd, uo) for b, nh, d, ds, wd, uo in _INT16_PARAMS], + ) + def test_output_correctness( + self, + batch, + nheads, + dim, + dstate, + state_dtype, + weight_dtype, + use_out_tensor, + algorithm, + ): + inputs = self.make_inputs(batch, nheads, dim, dstate, state_dtype, weight_dtype) + y_ref, state_ref, state_scale_ref = self.make_reference_output(inputs) + + if use_out_tensor: + out = torch.empty(batch, nheads, dim, dtype=self.INPUT_DTYPE, device="cuda") + else: + out = None + + y_test = 
self.run_kernel(inputs, out=out, algorithm=algorithm) + + if use_out_tensor: + assert y_test.data_ptr() == out.data_ptr() + + self.assert_outputs_match(y_ref, y_test, msg_prefix=f"[{algorithm}] ") + self.assert_states_match( + state_ref, + inputs["state_cache"], + inputs["slot_idx"], + msg_prefix=f"[{algorithm}] ", + state_scale_ref=state_scale_ref, + state_scale_test=inputs["state_scale"], + ) + + +def _get_algorithms_no_horizontal(): + """Return algorithms that support stochastic rounding (no horizontal).""" + major, _ = get_compute_capability(torch.device("cuda")) + algos = ["simple"] + if major >= 9: + algos.append("vertical") + return algos + + +class TestSelectiveStateUpdateStochasticRounding(TestSelectiveStateUpdate): + """Test fp16 state with stochastic rounding vs Triton reference.""" + + ATOL = 0.001 + RTOL = 0.01 + + RAND_SEED = torch.tensor(42, dtype=torch.int64, device="cuda") + + def make_inputs(self, batch, nheads, dim, dstate, _state_dtype, weight_dtype): + """Create test inputs with fp16 state.""" + return create_test_inputs( + batch, + nheads, + dim, + dstate, + self.NGROUPS, + self.INPUT_DTYPE, + weight_dtype=weight_dtype, + matrixA_dtype=self.MATRIX_A_DTYPE, + state_dtype=torch.float16, + generate_z=False, + seed=0, + ) + + def make_reference_output(self, inputs): + """Compute reference output using Triton with stochastic rounding.""" + state_ref = inputs["state_cache"].clone() + major, _ = get_compute_capability(torch.device("cuda")) + # Triton cvt.rs.f16x2.f32 requires SM100a+; on older GPUs the Triton + # reference falls back to regular rounding while the CUDA kernel still + # exercises its software stochastic rounding path. 
+ rand_seed = self.RAND_SEED if major >= 10 else None + y_ref = selective_state_update_triton( + state_ref, + inputs["x"], + inputs["dt"], + inputs["A"], + inputs["B"], + inputs["C"], + D=inputs["D"], + z=inputs.get("z"), + dt_bias=inputs["dt_bias"], + dt_softplus=True, + state_batch_indices=inputs["slot_idx"], + pad_slot_id=-1, + rand_seed=rand_seed, + ) + return y_ref, state_ref + + def run_kernel(self, inputs, out=None, algorithm="auto"): + """Run the flashinfer kernel with stochastic rounding.""" + return flashinfer.mamba.selective_state_update( + inputs["state_cache"], + inputs["x"], + inputs["dt"], + inputs["A"], + inputs["B"], + inputs["C"], + D=inputs["D"], + z=inputs.get("z"), + dt_bias=inputs["dt_bias"], + dt_softplus=True, + state_batch_indices=inputs["slot_idx"], + pad_slot_id=-1, + out=out, + algorithm=algorithm, + rand_seed=self.RAND_SEED, + ) + + def assert_states_match( + self, state_ref, state_test, slot_idx, msg_prefix="", **_kwargs + ): + """Assert states match within tolerance (SR path has different FP operation order than Triton).""" + state_ref_batch = state_ref[slot_idx] + state_test_batch = state_test[slot_idx] + states_match = torch.allclose( + state_ref_batch.float(), + state_test_batch.float(), + atol=self.ATOL, + rtol=self.RTOL, + ) + + if states_match: + print( + f"✓ {msg_prefix}States match within tolerance (atol={self.ATOL}, rtol={self.RTOL})" + ) + else: + max_diff = ( + (state_ref_batch.float() - state_test_batch.float()).abs().max().item() + ) + print(f"✗ {msg_prefix}States do NOT match (max_diff={max_diff:.6e})") + self._print_mismatch_details(state_ref_batch, state_test_batch, "state") + + assert states_match + + # fmt: off + _SR_PARAMS = ( + # (batch, nheads, dim, dstate, state_dtype, weight_dtype, use_out_tensor) + ( 64, 64, 64, 128, torch.float16, torch.float32, True ), # base + ( 64, 64, 64, 64, torch.float16, torch.float32, True ), # dstate=64 + ) + # fmt: on + + @pytest.mark.parametrize("algorithm", 
_get_algorithms_no_horizontal()) + @pytest.mark.parametrize( + "batch,nheads,dim,dstate,state_dtype,weight_dtype,use_out_tensor", + _SR_PARAMS, + ) + def test_output_correctness( + self, + batch, + nheads, + dim, + dstate, + state_dtype, + weight_dtype, + use_out_tensor, + algorithm, + ): + super().test_output_correctness( + batch, + nheads, + dim, + dstate, + state_dtype, + weight_dtype, + use_out_tensor, + algorithm, + ) + + class TestSelectiveStateUpdateDtypeMismatch: """Test that selective_state_update fails with dtype mismatch between D and dt.""" diff --git a/tests/mamba/triton_reference/selective_state_update.py b/tests/mamba/triton_reference/selective_state_update.py index 88d20bf251..4013417810 100644 --- a/tests/mamba/triton_reference/selective_state_update.py +++ b/tests/mamba/triton_reference/selective_state_update.py @@ -15,6 +15,22 @@ PAD_SLOT_ID = -1 +@triton.jit +def convert_rs_fp16x2(x: tl.tensor, rand: tl.tensor) -> tl.tensor: + """Stochastic rounding: fp32 pair → fp16x2 using random bits (PTX cvt.rs.f16x2.f32).""" + y = tl.inline_asm_elementwise( + asm="""{ + cvt.rs.f16x2.f32 $0, $2, $1, $3; + }""", + constraints=("=r,r,r,r,r"), + args=(x, rand), + dtype=tl.float16, + is_pure=True, + pack=2, + ) + return y + + @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) @@ -49,10 +65,24 @@ is not None } ) +@triton.heuristics( + {"HAS_STATE_SCALE": lambda args: args["state_scale_ptr"] is not None} +) +@triton.heuristics( + { + "HAS_INTERMEDIATE_STATE_SCALES": lambda args: args[ + "intermediate_state_scales_ptr" + ] + is not None + } +) +@triton.heuristics({"USE_RS_ROUNDING": lambda args: args["rand_seed_ptr"] is not None}) @triton.jit(do_not_specialize=["T"]) def _selective_scan_update_kernel( # Pointers to matrices state_ptr, + state_scale_ptr, + rand_seed_ptr, x_ptr, dt_ptr, dt_bias_ptr, 
@@ -68,6 +98,7 @@ def _selective_scan_update_kernel( cache_steps, retrieve_parent_token_ptr, intermediate_state_indices_ptr, + intermediate_state_scales_ptr, # Matrix dimensions batch, T, @@ -80,6 +111,9 @@ def _selective_scan_update_kernel( stride_state_head, stride_state_dim, stride_state_dstate, + stride_state_scale_batch, + stride_state_scale_head, + stride_state_scale_dim, stride_x_batch, stride_x_T, stride_x_head, @@ -125,6 +159,10 @@ def _selective_scan_update_kernel( CACHE_INTERMEDIATE_STATES: tl.constexpr, HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: tl.constexpr, HAS_INTERMEDIATE_STATE_INDICES: tl.constexpr, + HAS_STATE_SCALE: tl.constexpr, + HAS_INTERMEDIATE_STATE_SCALES: tl.constexpr, + USE_RS_ROUNDING: tl.constexpr, + PHILOX_ROUNDS: tl.constexpr, BLOCK_SIZE_DSTATE: tl.constexpr, ): pid_m = tl.program_id(axis=0) @@ -138,9 +176,18 @@ def _selective_scan_update_kernel( state_batch_indices_ptr += pid_b state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64) state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head + if HAS_STATE_SCALE: + state_scale_ptr += ( + state_batch_idx * stride_state_scale_batch + + pid_h * stride_state_scale_head + ) else: state_batch_idx = pid_b state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head + if HAS_STATE_SCALE: + state_scale_ptr += ( + pid_b * stride_state_scale_batch + pid_h * stride_state_scale_head + ) x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head @@ -164,6 +211,16 @@ def _selective_scan_update_kernel( mask &= state_batch_idx != pad_slot_id state = tl.load(state_ptrs, mask=mask, other=0.0).to(tl.float32) + if HAS_STATE_SCALE: + state_scale_ptrs = state_scale_ptr + offs_m[:, None] * stride_state_scale_dim + scales_mask = offs_m[:, None] < dim + if HAS_STATE_BATCH_INDICES: + scales_mask = scales_mask & (state_batch_idx != pad_slot_id) + decode_scale = tl.load(state_scale_ptrs, mask=scales_mask, other=1.0).to( + tl.float32 + 
) + state = state * decode_scale + if HAS_DT_BIAS: dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim if HAS_D: @@ -205,6 +262,19 @@ def _selective_scan_update_kernel( + offs_n[None, :] ) state = tl.load(cache_ptr, mask=mask, other=0.0).to(tl.float32) + if HAS_INTERMEDIATE_STATE_SCALES: + # Dequantize using the stored per-step scale + iscale_load_ptrs = ( + intermediate_state_scales_ptr + + cache_idx * cache_steps * nheads * dim + + parent_step_idx * nheads * dim + + pid_h * dim + + offs_m + ) + iscale = tl.load( + iscale_load_ptrs, mask=offs_m < dim, other=1.0 + ).to(tl.float32) + state = state * iscale[:, None] x_ptrs = x_ptr + offs_m * stride_x_dim dt_ptrs = dt_ptr + offs_m * stride_dt_dim @@ -257,7 +327,52 @@ def _selective_scan_update_kernel( cache_ptrs = cache_ptr_base + ( offs_m[:, None] * dstate + offs_n[None, :] ) - tl.store(cache_ptrs, state.to(cache_ptrs.dtype.element_ty), mask=mask) + if HAS_INTERMEDIATE_STATE_SCALES: + # Quantize intermediate state with per-step block scaling + int16_max_i: tl.constexpr = 32767.0 + amax_i = tl.max(tl.abs(state), axis=-1) + encode_scale_i = tl.where(amax_i == 0.0, 1.0, int16_max_i / amax_i) + decode_scale_i = 1.0 / encode_scale_i + # Store intermediate decode scale + iscale_ptrs = ( + intermediate_state_scales_ptr + + cache_idx * cache_steps * nheads * dim + + current_step_idx * nheads * dim + + pid_h * dim + + offs_m + ) + tl.store(iscale_ptrs, decode_scale_i, mask=offs_m < dim) + # Quantize and store + q_state = state * encode_scale_i[:, None] + q_state = tl.extra.cuda.libdevice.round(q_state) + q_state = tl.minimum(tl.maximum(q_state, -int16_max_i), int16_max_i) + tl.store( + cache_ptrs, q_state.to(cache_ptrs.dtype.element_ty), mask=mask + ) + elif USE_RS_ROUNDING: + rand_seed = tl.load(rand_seed_ptr) + if HAS_STATE_BATCH_INDICES: + rand_offsets = ( + state_batch_idx * stride_state_batch + + pid_h * stride_state_head + ) + else: + rand_offsets = ( + pid_b * stride_state_batch + pid_h * stride_state_head + ) + 
rand_offsets += ( + offs_m[:, None] * stride_state_dim + + offs_n[None, :] * stride_state_dstate + ) + if PHILOX_ROUNDS > 0: + rand = tl.randint(rand_seed, rand_offsets, PHILOX_ROUNDS) + else: + rand = tl.randint(rand_seed, rand_offsets) + tl.store(cache_ptrs, convert_rs_fp16x2(state, rand), mask=mask) + else: + tl.store( + cache_ptrs, state.to(cache_ptrs.dtype.element_ty), mask=mask + ) out = tl.sum(state * C[None, :], axis=1) if HAS_D: @@ -277,7 +392,40 @@ def _selective_scan_update_kernel( z_ptr += stride_z_T if not DISABLE_STATE_UPDATE: - tl.store(state_ptrs, state.to(state_ptrs.dtype.element_ty), mask=mask) + if HAS_STATE_SCALE: + # Quantize state back with block scaling (int16 path) + int16_max: tl.constexpr = 32767.0 + amax = tl.max(tl.abs(state), axis=-1, keep_dims=True) + encode_scale = tl.where(amax == 0.0, 1.0, int16_max / amax) + new_decode_scale = 1.0 / encode_scale + # Store updated decode scales + dst_scales_mask = offs_m[:, None] < dim + tl.store(state_scale_ptrs, new_decode_scale, mask=dst_scales_mask) + # Quantize + state = state * encode_scale + state = tl.extra.cuda.libdevice.round(state) + state = tl.minimum(tl.maximum(state, -int16_max), int16_max) + tl.store(state_ptrs, state.to(state_ptrs.dtype.element_ty), mask=mask) + elif USE_RS_ROUNDING: + # Stochastic rounding for fp16 state + rand_seed = tl.load(rand_seed_ptr) + if HAS_STATE_BATCH_INDICES: + rand_offsets = ( + state_batch_idx * stride_state_batch + pid_h * stride_state_head + ) + else: + rand_offsets = pid_b * stride_state_batch + pid_h * stride_state_head + rand_offsets += ( + offs_m[:, None] * stride_state_dim + + offs_n[None, :] * stride_state_dstate + ) + if PHILOX_ROUNDS > 0: + rand = tl.randint(rand_seed, rand_offsets, PHILOX_ROUNDS) + else: + rand = tl.randint(rand_seed, rand_offsets) + tl.store(state_ptrs, convert_rs_fp16x2(state, rand), mask=mask) + else: + tl.store(state_ptrs, state.to(state_ptrs.dtype.element_ty), mask=mask) def selective_state_update_triton( @@ -299,6 
+447,10 @@ def selective_state_update_triton( cache_steps=None, retrieve_parent_token=None, intermediate_state_indices=None, + state_scale=None, + intermediate_state_scales=None, + rand_seed=None, + philox_rounds=0, ): """ Argument: @@ -325,6 +477,13 @@ def selective_state_update_triton( retrieve_parent_token: (batch, T) tensor of parent token indices for EAGLE tree attention intermediate_state_indices: (batch,) tensor of indices for intermediate_states_buffer operations. If provided, uses these indices instead of state_batch_indices for the buffer. + state_scale: Optional (batch, nheads, dim) or (batch, nheads, dim, 1) float32 tensor of + decode scales for block-scaled state quantization. When provided, the kernel + dequantizes state on load and requantizes on store. + rand_seed: Optional scalar tensor (int64 or int32) for Philox PRNG seed. + When provided, enables stochastic rounding for fp16 state stores. + Mutually exclusive with state_scale (int16 block scaling). + philox_rounds: Number of Philox PRNG rounds (0 = default, typically 10). 
""" # Track original x dimensionality to squeeze output appropriately x_orig_dim = x.dim() @@ -409,6 +568,20 @@ def selective_state_update_triton( and (dt_bias is None or dt_bias.stride(-1) == 0) ) + if state_scale is not None: + assert state_scale.dtype == torch.float32 + assert state_scale.shape == ( + state.shape[0], + nheads, + dim, + ) or state_scale.shape == (state.shape[0], nheads, dim, 1) + + state_scale_strides = ( + (state_scale.stride(0), state_scale.stride(1), state_scale.stride(2)) + if state_scale is not None + else (0, 0, 0) + ) + retrieve_parent_token_strides = ( (retrieve_parent_token.stride(0), retrieve_parent_token.stride(1)) if retrieve_parent_token is not None @@ -418,6 +591,8 @@ def selective_state_update_triton( with torch.cuda.device(x.device.index): _selective_scan_update_kernel[grid]( state, + state_scale, + rand_seed, x, dt, dt_bias, @@ -433,6 +608,7 @@ def selective_state_update_triton( cache_steps if cache_steps is not None else 0, retrieve_parent_token, intermediate_state_indices, + intermediate_state_scales, batch, T, nheads, @@ -443,6 +619,9 @@ def selective_state_update_triton( state.stride(1), state.stride(2), state.stride(3), + state_scale_strides[0], + state_scale_strides[1], + state_scale_strides[2], x.stride(0), x.stride(1), x.stride(2), @@ -478,6 +657,7 @@ def selective_state_update_triton( tie_hdim, BLOCK_SIZE_M, DISABLE_STATE_UPDATE=disable_state_update, + PHILOX_ROUNDS=philox_rounds, num_warps=num_warps, ) # Squeeze T dimension if original x didn't have it (was 2D or 3D) diff --git a/tests/mamba/utils.py b/tests/mamba/utils.py index 6f33e930e3..9793b36480 100644 --- a/tests/mamba/utils.py +++ b/tests/mamba/utils.py @@ -46,7 +46,8 @@ def create_test_inputs( ngroups: Number of groups for B and C matrices. input_dtype: Data type for input tensors (x, B, C, z) - from model config.json (typically bf16). weight_dtype: Data type for weight tensors (D, dt, dt_bias) - hardcoded fp32 in mamba2_mixer.py. 
- state_dtype: Data type for state tensor - user configurable (bf16/fp16/fp32). Defaults to input_dtype. + state_dtype: Data type for state tensor - user configurable (bf16/fp16/fp32/int16). Defaults to input_dtype. + When int16, generates float state, quantizes to int16, and returns state_scale. matrixA_dtype: Data type for the A matrix - hardcoded fp32 in mamba2_mixer.py. generate_z: If True, generate z tensor for gating. generate_intermediate_states_buffer: If True, generate buffer for @@ -65,6 +66,7 @@ def create_test_inputs( Returns: Dictionary containing all generated tensors with the following keys: - state_cache: (total_entries, nheads, dim, dstate) + - state_scale: (total_entries, nheads, dim, 1) float32 - only when state_dtype=int16, else None - x: (batch_size, [T,] nheads, dim) - T present if cache_steps provided - dt: (batch_size, [T,] nheads, dim) - T present if cache_steps provided - A: (nheads, dim, dstate) @@ -113,11 +115,39 @@ def create_test_inputs( ) total_elements = ssm_state_cache_size * state_cache_batch_stride - state_cache_flat = torch.randn(total_elements, dtype=state_dtype, device=device) - state_cache = state_cache_flat.as_strided( - (ssm_state_cache_size, nheads, dim, dstate), - (state_cache_batch_stride, dim * dstate, dstate, 1), - ) + if state_dtype == torch.int16: + # Generate in float, quantize to int16, and produce decode scales + state_cache_f32 = torch.randn( + ssm_state_cache_size, + nheads, + dim, + dstate, + device=device, + dtype=torch.float32, + ) + int16_max = torch.iinfo(torch.int16).max # 32767 + amax = torch.amax(torch.abs(state_cache_f32), dim=-1, keepdim=True) + encode_scale = torch.where(amax == 0.0, torch.ones_like(amax), int16_max / amax) + state_cache_int16 = ( + torch.round(state_cache_f32 * encode_scale) + .clamp(-int16_max, int16_max) + .to(torch.int16) + ) + state_scale = 1.0 / encode_scale # (ssm_state_cache_size, nheads, dim, 1) + # Re-layout with the requested batch stride + state_cache_flat = 
torch.zeros(total_elements, dtype=torch.int16, device=device) + state_cache = state_cache_flat.as_strided( + (ssm_state_cache_size, nheads, dim, dstate), + (state_cache_batch_stride, dim * dstate, dstate, 1), + ) + state_cache.copy_(state_cache_int16) + else: + state_cache_flat = torch.randn(total_elements, dtype=state_dtype, device=device) + state_cache = state_cache_flat.as_strided( + (ssm_state_cache_size, nheads, dim, dstate), + (state_cache_batch_stride, dim * dstate, dstate, 1), + ) + state_scale = None # Input x: (batch_size, [T,] nheads, dim) if T is not None: @@ -171,6 +201,7 @@ def create_test_inputs( # Build result dictionary result = { "state_cache": state_cache, + "state_scale": state_scale, "x": x, "dt": dt, "A": A, @@ -214,6 +245,17 @@ def create_test_inputs( batch_size, dtype=torch.int64, device=device ) result["intermediate_slot_idx"] = intermediate_slot_idx + # For int16 state + intermediate states, generate per-step decode scales + if state_dtype == torch.int16: + intermediate_state_scales = torch.zeros( + batch_size, + cache_steps, + nheads, + dim, + dtype=torch.float32, + device=device, + ) + result["intermediate_state_scales"] = intermediate_state_scales # Optional: retrieve_parent_token for EAGLE tree attention if generate_retrieve_parent_token: