diff --git a/benchmarks/routines/mamba.py b/benchmarks/routines/mamba.py index 0d53dae849..d5fffd57c8 100644 --- a/benchmarks/routines/mamba.py +++ b/benchmarks/routines/mamba.py @@ -14,14 +14,8 @@ limitations under the License. """ -# ============================================================================== -# Triton reference implementation for selective_state_update. -# Imported from tests/mamba/selective_state_update_triton.py to avoid code -# duplication. See that file for the canonical Triton kernel source. -# ============================================================================== - -import importlib import os +import sys from collections import defaultdict import numpy as np @@ -30,6 +24,14 @@ import flashinfer from flashinfer.testing.utils import bench_gpu_time +# Add tests/mamba to sys.path so triton_reference is importable as a package +_repo_root = os.path.normpath( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..") +) +_tests_mamba = os.path.join(_repo_root, "tests", "mamba") +if _tests_mamba not in sys.path: + sys.path.insert(0, _tests_mamba) + from .flashinfer_benchmark_utils import ( dtype_str_to_torch_dtype, get_device, @@ -38,40 +40,9 @@ filter_backends_by_compute_capability, ) -# ---- Import Triton reference kernel from tests/mamba/ ---- -# The canonical Triton selective_state_update lives in tests/mamba/selective_state_update_triton.py. -# We import it here rather than duplicating ~400 lines of kernel code. - - -def _import_triton_reference(): - """Import selective_state_update_triton from tests/mamba/. - - Uses importlib to load the module directly by file path, avoiding sys.path - pollution and fragile relative path assumptions. - """ - # Resolve path: benchmarks/routines/mamba.py -> ../../tests/mamba/selective_state_update_triton.py - _this_dir = os.path.dirname(os.path.abspath(__file__)) - _repo_root = os.path.normpath(os.path.join(_this_dir, "..", "..")) - _triton_ref_path = os.path.join( - _repo_root, "tests", "mamba", "selective_state_update_triton.py" - ) - - if not os.path.isfile(_triton_ref_path): - raise ImportError( - f"Cannot find Triton reference kernel at: {_triton_ref_path}\n" - f"Expected location: /tests/mamba/selective_state_update_triton.py\n" - f"Make sure you are running from within the FlashInfer repository." - ) - - spec = importlib.util.spec_from_file_location( - "selective_state_update_triton", _triton_ref_path - ) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - return module.selective_state_update_triton - - -selective_state_update_triton_reference = _import_triton_reference() +from triton_reference.selective_state_update import ( + selective_state_update_triton as selective_state_update_triton_reference, +) # ============================================================================== diff --git a/csrc/flashinfer_mamba_binding.cu b/csrc/flashinfer_mamba_binding.cu index dfdc5bebf8..2e2453cefc 100644 --- a/csrc/flashinfer_mamba_binding.cu +++ b/csrc/flashinfer_mamba_binding.cu @@ -42,7 +42,8 @@ void selective_state_update( bool disable_state_update, Optional intermediate_states_buffer, // (batch, cache_steps, nheads, dim, dstate) Optional intermediate_state_indices, // (batch,) - int64_t cache_steps); + int64_t cache_steps, + int64_t algorithm); // SSUAlgorithm: 0=auto, 1=simple, 2=vertical, 3=horizontal } // namespace flashinfer::mamba diff --git a/csrc/selective_state_update.cu b/csrc/selective_state_update.cu index d8ada31ed8..3918d3caf8 100644 --- a/csrc/selective_state_update.cu +++ b/csrc/selective_state_update.cu @@ -13,9 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +// clang-format off +// config.inc MUST come before the header: it defines DIM, DSTATE, NTOKENS_MTP +// constexprs that the header's function templates rely on. Reordering breaks compilation. +// NOTE: the .inc file is generated from the jinja templates +#include "selective_state_update_config.inc" #include -#include - +// clang-format on #include "tvm_ffi_utils.h" using namespace flashinfer; @@ -124,87 +128,13 @@ inline void validate_dtype_consistency( } } -// Helper to convert dtype code to string for error messages -inline const char* dtype_code_to_string(int64_t code) { - if (code == bfloat16_code) return "bfloat16"; - if (code == float16_code) return "float16"; - if (code == float32_code) return "float32"; - return "unknown"; -} - -// Type traits to map dtype codes to C++ types -template -struct DTypeToType; - -template <> -struct DTypeToType { - using type = nv_bfloat16; -}; -template <> -struct DTypeToType { - using type = half; -}; -template <> -struct DTypeToType { - using type = float; -}; -template <> -struct DTypeToType { - using type = int32_t; -}; -template <> -struct DTypeToType { - using type = int64_t; -}; - -// Allowed dtype combinations: {state_code, input_code, weight_code, matrixA_code, stateIndex_code} -constexpr std::tuple allowed_dtype_combos[] = { - {bfloat16_code, bfloat16_code, bfloat16_code, float32_code, int32_code}, - {float16_code, bfloat16_code, bfloat16_code, float32_code, int32_code}, - {float32_code, bfloat16_code, bfloat16_code, float32_code, int32_code}, - {bfloat16_code, bfloat16_code, float32_code, float32_code, int32_code}, - {float16_code, bfloat16_code, float32_code, float32_code, int32_code}, - {float32_code, bfloat16_code, float32_code, float32_code, int32_code}, - {bfloat16_code, bfloat16_code, bfloat16_code, float32_code, int64_code}, - {float16_code, bfloat16_code, bfloat16_code, float32_code, int64_code}, - {float32_code, bfloat16_code, bfloat16_code, float32_code, int64_code}, - {bfloat16_code, bfloat16_code, float32_code, float32_code, int64_code}, - {float16_code, bfloat16_code, float32_code, float32_code, int64_code}, - {float32_code, bfloat16_code, float32_code, float32_code, int64_code}, -}; - -// Helper to dispatch to the right template instantiation for STP -template -void dispatchCombo(SelectiveStateUpdateParams& p, cudaStream_t stream) { - using state_t = typename DTypeToType::type; - using input_t = typename DTypeToType::type; - using weight_t = typename DTypeToType::type; - using matrixA_t = typename DTypeToType::type; - using stateIndex_t = typename DTypeToType::type; - invokeSelectiveStateUpdate(p, stream); -} - -// Helper to dispatch to the right template instantiation for MTP -template -void dispatchComboMTP(mtp::SelectiveStateMTPParams& p, cudaStream_t stream) { - using state_t = typename DTypeToType::type; - using input_t = typename DTypeToType::type; - using weight_t = typename DTypeToType::type; - using matrixA_t = typename DTypeToType::type; - using stateIndex_t = typename DTypeToType::type; - mtp::invokeSelectiveStateUpdateMTP(p, - stream); -} - void run_selective_state_update_stp(TensorView const& state, TensorView const& x, TensorView const& dt, TensorView const& A, TensorView const& B, TensorView const& C, TensorView const& D, Optional z, Optional dt_bias, bool dt_softplus, Optional state_batch_indices, int64_t pad_slot_id, Optional out, - bool disable_state_update) { + bool disable_state_update, int64_t algorithm) { // Extract dimensions from input tensors auto const batch = x.size(0); auto const state_cache_size = state.size(0); @@ -344,64 +274,8 @@ void run_selective_state_update_stp(TensorView const& state, TensorView const& x ffi::CUDADeviceGuard device_guard(state.device().device_id); const cudaStream_t stream = get_stream(state.device()); - // Dispatch based on dtype combination - DLDataType state_dtype = state.dtype(); - DLDataType input_dtype = x.dtype(); - DLDataType weight_dtype = dt.dtype(); - DLDataType matrixA_dtype = A.dtype(); - int64_t state_dtype_code = encode_dlpack_dtype(state_dtype); - int64_t input_dtype_code = encode_dlpack_dtype(input_dtype); - int64_t weight_dtype_code = encode_dlpack_dtype(weight_dtype); - int64_t matrixA_dtype_code = encode_dlpack_dtype(matrixA_dtype); - - // Get state_batch_indices dtype, default to int32 if not provided - int64_t stateIndex_dtype_code = int32_code; - if (state_batch_indices.has_value()) { - DLDataType stateIndex_dtype = state_batch_indices.value().dtype(); - stateIndex_dtype_code = encode_dlpack_dtype(stateIndex_dtype); - } - - // Dispatch kernel based on dtype combination - auto dtype_key = std::make_tuple(state_dtype_code, input_dtype_code, weight_dtype_code, - matrixA_dtype_code, stateIndex_dtype_code); - - // Compile-time recursive dispatcher using Y-combinator pattern for lambda self-recursion - auto tryDispatch = [&](const auto& key, auto idx, auto& self) -> bool { - constexpr size_t I = decltype(idx)::value; - if constexpr (I < std::size(allowed_dtype_combos)) { - constexpr auto combo = allowed_dtype_combos[I]; - if (key == combo) { - constexpr auto s = std::get<0>(combo); - constexpr auto i = std::get<1>(combo); - constexpr auto w = std::get<2>(combo); - constexpr auto m = std::get<3>(combo); - constexpr auto si = std::get<4>(combo); - dispatchCombo(p, stream); - return true; - } - return self(key, std::integral_constant{}, self); - } - return false; - }; - - // Dispatch using compile-time type traits - if (!tryDispatch(dtype_key, std::integral_constant{}, tryDispatch)) { - // Unsupported dtype combination - build error message dynamically - std::ostringstream error_msg; - error_msg << "Unsupported dtype combination for selective_state_update: " << "state_dtype=" - << state_dtype.code << ":" << state_dtype.bits << ", " - << "input_dtype=" << input_dtype.code << ":" << input_dtype.bits << ", " - << "weight_dtype=" << weight_dtype.code << ":" << weight_dtype.bits << ", " - << "matrixA_dtype=" << matrixA_dtype.code << ":" << matrixA_dtype.bits - << ". Supported combos include:\n"; - for (const auto& combo : allowed_dtype_combos) { - error_msg << " (state=" << dtype_code_to_string(std::get<0>(combo)) - << ", input=" << dtype_code_to_string(std::get<1>(combo)) - << ", weight=" << dtype_code_to_string(std::get<2>(combo)) - << ", matrixA=" << dtype_code_to_string(std::get<3>(combo)) << ")\n"; - } - TVM_FFI_ICHECK(false) << error_msg.str(); - } + auto algo = static_cast(algorithm); + invokeSelectiveStateUpdate(p, algo, stream); } void run_selective_state_update_mtp( @@ -410,7 +284,7 @@ void run_selective_state_update_mtp( Optional dt_bias, bool dt_softplus, Optional state_batch_indices, int64_t pad_slot_id, Optional out, bool disable_state_update, Optional intermediate_states_buffer, - Optional intermediate_state_indices, int64_t cache_steps) { + Optional intermediate_state_indices, int64_t cache_steps, int64_t algorithm) { // Extract dimensions from input tensors auto const batch = x.size(0); auto const ntokens_mtp = x.size(1); @@ -505,6 +379,15 @@ void run_selective_state_update_mtp( validate_intermediate_state_indices(intermediate_state_indices, batch); validate_intermediate_states_buffer(intermediate_states_buffer); + // Validate that state_batch_indices and intermediate_state_indices have the same dtype + if (state_batch_indices.has_value() && intermediate_state_indices.has_value()) { + DLDataType state_batch_idx_dtype = state_batch_indices.value().dtype(); + DLDataType intermediate_idx_dtype = intermediate_state_indices.value().dtype(); + FLASHINFER_CHECK(state_batch_idx_dtype.code == intermediate_idx_dtype.code && + state_batch_idx_dtype.bits == intermediate_idx_dtype.bits, + "state_batch_indices and intermediate_state_indices must have the same dtype"); + } + // Validate cache_steps is non-negative FLASHINFER_CHECK(cache_steps >= 0, "cache_steps must be non-negative, got ", cache_steps); @@ -588,75 +471,9 @@ void run_selective_state_update_mtp( ffi::CUDADeviceGuard device_guard(state.device().device_id); const cudaStream_t stream = get_stream(state.device()); - // Dispatch based on dtype combination - DLDataType state_dtype = state.dtype(); - DLDataType input_dtype = x.dtype(); - DLDataType weight_dtype = dt.dtype(); - DLDataType matrixA_dtype = A.dtype(); - int64_t state_dtype_code = encode_dlpack_dtype(state_dtype); - int64_t input_dtype_code = encode_dlpack_dtype(input_dtype); - int64_t weight_dtype_code = encode_dlpack_dtype(weight_dtype); - int64_t matrixA_dtype_code = encode_dlpack_dtype(matrixA_dtype); - - // Get stateIndex dtype from whichever index tensor is available - // If both are provided, they must have the same dtype - int64_t stateIndex_dtype_code = int32_code; // default - if (state_batch_indices.has_value() && intermediate_state_indices.has_value()) { - DLDataType state_batch_idx_dtype = state_batch_indices.value().dtype(); - DLDataType intermediate_idx_dtype = intermediate_state_indices.value().dtype(); - FLASHINFER_CHECK(state_batch_idx_dtype.code == intermediate_idx_dtype.code && - state_batch_idx_dtype.bits == intermediate_idx_dtype.bits, - "state_batch_indices and intermediate_state_indices must have the same dtype"); - stateIndex_dtype_code = encode_dlpack_dtype(state_batch_idx_dtype); - } else if (state_batch_indices.has_value()) { - DLDataType state_batch_idx_dtype = state_batch_indices.value().dtype(); - stateIndex_dtype_code = encode_dlpack_dtype(state_batch_idx_dtype); - } else if (intermediate_state_indices.has_value()) { - DLDataType intermediate_idx_dtype = intermediate_state_indices.value().dtype(); - stateIndex_dtype_code = encode_dlpack_dtype(intermediate_idx_dtype); - } - - // Dispatch kernel based on dtype combination - auto dtype_key = std::make_tuple(state_dtype_code, input_dtype_code, weight_dtype_code, - matrixA_dtype_code, stateIndex_dtype_code); - - // Compile-time recursive dispatcher using Y-combinator pattern for lambda self-recursion - auto tryDispatch = [&](const auto& key, auto idx, auto& self) -> bool { - constexpr size_t I = decltype(idx)::value; - if constexpr (I < std::size(allowed_dtype_combos)) { - constexpr auto combo = allowed_dtype_combos[I]; - if (key == combo) { - constexpr auto s = std::get<0>(combo); - constexpr auto i = std::get<1>(combo); - constexpr auto w = std::get<2>(combo); - constexpr auto m = std::get<3>(combo); - constexpr auto si = std::get<4>(combo); - dispatchComboMTP(p, stream); - return true; - } - return self(key, std::integral_constant{}, self); - } - return false; - }; - - // Dispatch using compile-time type traits - if (!tryDispatch(dtype_key, std::integral_constant{}, tryDispatch)) { - // Unsupported dtype combination - build error message dynamically - std::ostringstream error_msg; - error_msg << "Unsupported dtype combination for selective_state_update: " << "state_dtype=" - << state_dtype.code << ":" << state_dtype.bits << ", " - << "input_dtype=" << input_dtype.code << ":" << input_dtype.bits << ", " - << "weight_dtype=" << weight_dtype.code << ":" << weight_dtype.bits << ", " - << "matrixA_dtype=" << matrixA_dtype.code << ":" << matrixA_dtype.bits - << ". Supported combos include:\n"; - for (const auto& combo : allowed_dtype_combos) { - error_msg << " (state=" << dtype_code_to_string(std::get<0>(combo)) - << ", input=" << dtype_code_to_string(std::get<1>(combo)) - << ", weight=" << dtype_code_to_string(std::get<2>(combo)) - << ", matrixA=" << dtype_code_to_string(std::get<3>(combo)) << ")\n"; - } - TVM_FFI_ICHECK(false) << error_msg.str(); - } + auto algo = static_cast(algorithm); + mtp::invokeSelectiveStateUpdateMTP(p, algo, + stream); } // ============================================================================= @@ -668,14 +485,17 @@ void selective_state_update(TensorView state, TensorView x, TensorView dt, Tenso Optional state_batch_indices, int64_t pad_slot_id, TensorView output, bool disable_state_update, Optional intermediate_states_buffer, - Optional intermediate_state_indices, int64_t cache_steps) { + Optional intermediate_state_indices, int64_t cache_steps, + int64_t algorithm) { if (x.dim() == 3) { run_selective_state_update_stp(state, x, dt, A, B, C, D, z, dt_bias, dt_softplus, - state_batch_indices, pad_slot_id, output, disable_state_update); + state_batch_indices, pad_slot_id, output, disable_state_update, + algorithm); } else if (x.dim() == 4) { - run_selective_state_update_mtp( - state, x, dt, A, B, C, D, z, dt_bias, dt_softplus, state_batch_indices, pad_slot_id, output, - disable_state_update, intermediate_states_buffer, intermediate_state_indices, cache_steps); + run_selective_state_update_mtp(state, x, dt, A, B, C, D, z, dt_bias, dt_softplus, + state_batch_indices, pad_slot_id, output, disable_state_update, + intermediate_states_buffer, intermediate_state_indices, + cache_steps, algorithm); } else { FLASHINFER_CHECK(false, "x must have 3 dimensions (single-token) or 4 dimensions (multi-token), got ", diff --git a/csrc/selective_state_update_customize_config.jinja b/csrc/selective_state_update_customize_config.jinja new file mode 100644 index 0000000000..418356212d --- /dev/null +++ b/csrc/selective_state_update_customize_config.jinja @@ -0,0 +1,14 @@ +#pragma once +#include +#include +#include + +using state_t = {{ state_dtype }}; +using input_t = {{ input_dtype }}; +using weight_t = {{ weight_dtype }}; +using matrixA_t = {{ matrixA_dtype }}; +using stateIndex_t = {{ stateIndex_dtype }}; + +constexpr int DIM = {{ dim }}; +constexpr int DSTATE = {{ dstate }}; +constexpr int NTOKENS_MTP = {{ ntokens_mtp }}; diff --git a/csrc/selective_state_update_dtype_inst.jinja b/csrc/selective_state_update_dtype_inst.jinja new file mode 100644 index 0000000000..02dd66322b --- /dev/null +++ b/csrc/selective_state_update_dtype_inst.jinja @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Auto-generated file - do not edit directly. +// Generated by flashinfer/jit/mamba/selective_state_update.py + +#include + +namespace flashinfer::mamba { + +template void invokeSelectiveStateUpdate<{{ input_dtype }}, {{ weight_dtype }}, {{ matrixA_dtype }}, {{ state_dtype }}, {{ stateIndex_dtype }}>( + SelectiveStateUpdateParams& params, SSUAlgorithm algo, cudaStream_t stream); + +namespace mtp { +template void invokeSelectiveStateUpdateMTP<{{ input_dtype }}, {{ weight_dtype }}, {{ matrixA_dtype }}, {{ state_dtype }}, {{ stateIndex_dtype }}>( + SelectiveStateMTPParams& params, SSUAlgorithm algo, cudaStream_t stream); +} // namespace mtp + +} // namespace flashinfer::mamba diff --git a/csrc/selective_state_update_kernel_inst.cu b/csrc/selective_state_update_kernel_inst.cu new file mode 100644 index 0000000000..6dcec72a5d --- /dev/null +++ b/csrc/selective_state_update_kernel_inst.cu @@ -0,0 +1,18 @@ +// clang-format off +// config.inc MUST come before the header: it defines DIM, DSTATE, NTOKENS_MTP +// constexprs that the header's function templates rely on. Reordering breaks compilation. +#include "selective_state_update_config.inc" +#include +// clang-format on + +namespace flashinfer::mamba { + +template void invokeSelectiveStateUpdate( + SelectiveStateUpdateParams&, SSUAlgorithm, cudaStream_t); + +namespace mtp { +template void invokeSelectiveStateUpdateMTP( + SelectiveStateMTPParams&, SSUAlgorithm, cudaStream_t); +} // namespace mtp + +} // namespace flashinfer::mamba diff --git a/flashinfer/aot.py b/flashinfer/aot.py index c0289fd3be..f11ac238bb 100644 --- a/flashinfer/aot.py +++ b/flashinfer/aot.py @@ -25,13 +25,28 @@ import shutil from itertools import product from pathlib import Path -from typing import List, Tuple, Iterator, Optional +from typing import Iterator, List, Optional, Tuple import torch - from packaging.version import Version + +from .compilation_context import CompilationContext +from .jit import JitSpec, build_jit_specs +from .jit import env as jit_env from .jit.activation import act_func_def_str, gen_act_and_mul_module +from .jit.attention import ( + gen_batch_attention_module, + gen_batch_decode_module, + gen_batch_mla_module, + gen_batch_prefill_module, + gen_cudnn_fmha_module, + gen_fmha_cutlass_sm100a_module, + gen_single_decode_module, + gen_single_prefill_module, + gen_trtllm_gen_fmha_module, +) from .jit.cascade import gen_cascade_module +from .jit.cpp_ext import get_cuda_version from .jit.fp4_quantization import ( gen_fp4_quantization_sm90_module, gen_fp4_quantization_sm100_module, @@ -41,58 +56,42 @@ gen_fp4_quantization_sm121_module, ) from .jit.fp8_quantization import gen_mxfp8_quantization_sm100_module -from .jit.gdn import gen_gdn_prefill_sm90_module from .jit.fused_moe import ( - gen_cutlass_fused_moe_sm120_module, - gen_cutlass_fused_moe_sm103_module, - gen_cutlass_fused_moe_sm100_module, gen_cutlass_fused_moe_sm90_module, + gen_cutlass_fused_moe_sm100_module, + gen_cutlass_fused_moe_sm103_module, + gen_cutlass_fused_moe_sm120_module, gen_trtllm_gen_fused_moe_sm100_module, ) +from .jit.gdn import gen_gdn_prefill_sm90_module from .jit.gemm import ( + gen_fp8_blockscale_gemm_sm90_module, gen_gemm_module, gen_gemm_sm90_module, - gen_fp8_blockscale_gemm_sm90_module, gen_gemm_sm100_module, gen_gemm_sm100_module_cutlass_fp4, gen_gemm_sm100_module_cutlass_fp8, gen_gemm_sm100_module_cutlass_mxfp8, - gen_tgv_gemm_sm10x_module, gen_gemm_sm120_module, gen_gemm_sm120_module_cutlass_fp4, + gen_tgv_gemm_sm10x_module, gen_trtllm_gen_gemm_module, gen_trtllm_low_latency_gemm_module, ) -from .jit.spdlog import gen_spdlog_module -from .jit.mla import gen_mla_module from .jit.mamba import ( gen_selective_state_update_module, gen_selective_state_update_sm90_module, - gen_selective_state_update_sm100_module, ) +from .jit.mla import gen_mla_module from .jit.norm import gen_norm_module from .jit.page import gen_page_module from .jit.quantization import gen_quantization_module from .jit.rope import gen_rope_module from .jit.sampling import gen_sampling_module -from .jit.topk import gen_topk_module +from .jit.spdlog import gen_spdlog_module from .jit.tllm_utils import gen_trtllm_utils_module +from .jit.topk import gen_topk_module from .jit.xqa import gen_xqa_module, gen_xqa_module_mla -from .jit.attention import ( - gen_batch_attention_module, - gen_batch_decode_module, - gen_batch_mla_module, - gen_batch_prefill_module, - gen_cudnn_fmha_module, - gen_fmha_cutlass_sm100a_module, - gen_single_decode_module, - gen_single_prefill_module, - gen_trtllm_gen_fmha_module, -) -from .jit import JitSpec, build_jit_specs -from .jit import env as jit_env -from .jit.cpp_ext import get_cuda_version -from .compilation_context import CompilationContext def gen_fa2( @@ -520,11 +519,14 @@ def gen_all_modules( jit_specs.append(gen_fp4_quantization_sm121_module()) if add_comm: - from .jit.comm import gen_trtllm_comm_module, gen_vllm_comm_module - from .jit.comm import gen_nvshmem_module - from .jit.comm import gen_comm_alltoall_module - from .jit.comm import gen_trtllm_mnnvl_comm_module - from .jit.comm import gen_moe_alltoall_module + from .jit.comm import ( + gen_comm_alltoall_module, + gen_moe_alltoall_module, + gen_nvshmem_module, + gen_trtllm_comm_module, + gen_trtllm_mnnvl_comm_module, + gen_vllm_comm_module, + ) jit_specs.append(gen_nvshmem_module()) jit_specs.append(gen_comm_alltoall_module()) @@ -543,14 +545,43 @@ def gen_all_modules( gen_rope_module(), gen_sampling_module(), gen_topk_module(), - gen_selective_state_update_module(), ] - if has_sm90: - jit_specs.append(gen_selective_state_update_sm90_module()) + # selective_state_update: one module per dtype combo per GPU arch + _ssu_dtype_combos = [ + # (state, input, weight, matrixA, stateIndex) + ( + torch.bfloat16, + torch.bfloat16, + torch.bfloat16, + torch.float32, + torch.int64, + ), + (torch.float32, torch.bfloat16, torch.bfloat16, torch.float32, torch.int64), + ] + _ssu_dims = [64] + _ssu_dstates = [128] + _ssu_ntokens = [1, 4, 6, 8] + for dtype_combo, dim, dstate, ntokens in product( + _ssu_dtype_combos, _ssu_dims, _ssu_dstates, _ssu_ntokens + ): + jit_specs.append( + # false positive: mypy can't resolve the signature because flashinfer.jit deps (filelock etc.) + # are absent in mypy's isolated env, causing it to infer an incorrect function signature + gen_selective_state_update_module(*dtype_combo, dim, dstate, ntokens) # type: ignore[call-arg] + ) + if has_sm90 or has_sm100: + for dtype_combo, dim, dstate, ntokens in product( + _ssu_dtype_combos, _ssu_dims, _ssu_dstates, _ssu_ntokens + ): + jit_specs.append( + # same false positive as above + gen_selective_state_update_sm90_module( # type: ignore[call-arg] + *dtype_combo, dim, dstate, ntokens + ) + ) jit_specs.append(gen_trtllm_utils_module()) + if has_sm90: jit_specs.append(gen_gdn_prefill_sm90_module()) - if has_sm100: - jit_specs.append(gen_selective_state_update_sm100_module()) if ( add_xqa and get_cuda_version() > Version("12.8") diff --git a/flashinfer/jit/mamba/__init__.py b/flashinfer/jit/mamba/__init__.py index 8ac01c2455..f6a2628b43 100644 --- a/flashinfer/jit/mamba/__init__.py +++ b/flashinfer/jit/mamba/__init__.py @@ -17,11 +17,9 @@ from .selective_state_update import ( gen_selective_state_update_module, gen_selective_state_update_sm90_module, - gen_selective_state_update_sm100_module, ) __all__ = [ "gen_selective_state_update_module", "gen_selective_state_update_sm90_module", - "gen_selective_state_update_sm100_module", ] diff --git a/flashinfer/jit/mamba/selective_state_update.py b/flashinfer/jit/mamba/selective_state_update.py index a9b18580e2..bba5c3d375 100644 --- a/flashinfer/jit/mamba/selective_state_update.py +++ b/flashinfer/jit/mamba/selective_state_update.py @@ -14,65 +14,178 @@ limitations under the License. """ +import os + +import jinja2 +import torch + from ...compilation_context import CompilationContext from .. import env as jit_env from ..core import JitSpec, gen_jit_spec +from ..utils import write_if_different +# Map torch dtypes to C++ type names +_dtype_map = { + torch.float16: "half", + torch.bfloat16: "nv_bfloat16", + torch.float32: "float", + torch.int32: "int32_t", + torch.int64: "int64_t", +} -def gen_selective_state_update_module() -> JitSpec: - return gen_jit_spec( - "mamba_selective_state_update", - [ - jit_env.FLASHINFER_CSRC_DIR / "selective_state_update.cu", - jit_env.FLASHINFER_CSRC_DIR / "flashinfer_mamba_binding.cu", - ], +# Map torch dtypes to filename-safe strings +_filename_safe_dtype_map = { + torch.float16: "f16", + torch.bfloat16: "bf16", + torch.float32: "f32", + torch.int32: "i32", + torch.int64: "i64", +} + + +def get_selective_state_update_uri( + state_dtype: torch.dtype, + input_dtype: torch.dtype, + weight_dtype: torch.dtype, + matrixA_dtype: torch.dtype, + stateIndex_dtype: torch.dtype, + dim: int, + dstate: int, + ntokens_mtp: int, +) -> str: + s = _filename_safe_dtype_map + return ( + f"selective_state_update_" + f"s_{s[state_dtype]}_i_{s[input_dtype]}_w_{s[weight_dtype]}_" + f"a_{s[matrixA_dtype]}_si_{s[stateIndex_dtype]}_" + f"d_{dim}_ds_{dstate}_nt_{ntokens_mtp}" ) -def gen_selective_state_update_sm90_module() -> JitSpec: - # We use a specialized module for Hopper GPUs due to the explicit use - # of TMA device functions (vertical producer-consumer kernel). - # This supports SM90 (Hopper) only. - # - # Technically, all the kernels in this module can be executed on newer GPUs than Hopper, - # but this kernel ends up being slower than the alternative SM100 module. - # Therefore, this is excluded to reduce the amount of compilation. - compilation_context = CompilationContext() - nvcc_flags = compilation_context.get_nvcc_flags_list(supported_major_versions=[9]) - nvcc_flags += [ - "-DFLASHINFER_MAMBA_ENABLE_SM90", - ] +def _gen_module( + uri: str, + state_dtype: torch.dtype, + input_dtype: torch.dtype, + weight_dtype: torch.dtype, + matrixA_dtype: torch.dtype, + stateIndex_dtype: torch.dtype, + dim: int, + dstate: int, + ntokens_mtp: int, + extra_cuda_cflags: list = None, +) -> JitSpec: + gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri + os.makedirs(gen_directory, exist_ok=True) + + # Render the config .inc + with open( + jit_env.FLASHINFER_CSRC_DIR / "selective_state_update_customize_config.jinja" + ) as f: + config_templ = jinja2.Template(f.read()) + + config_str = config_templ.render( + state_dtype=_dtype_map[state_dtype], + input_dtype=_dtype_map[input_dtype], + weight_dtype=_dtype_map[weight_dtype], + matrixA_dtype=_dtype_map[matrixA_dtype], + stateIndex_dtype=_dtype_map[stateIndex_dtype], + dim=dim, + dstate=dstate, + ntokens_mtp=ntokens_mtp, + ) + write_if_different(gen_directory / "selective_state_update_config.inc", config_str) + + # Copy source files to gen directory (so they can #include the config.inc) + source_paths = [] + for filename in [ + "selective_state_update.cu", + "selective_state_update_kernel_inst.cu", + "flashinfer_mamba_binding.cu", + ]: + src_path = jit_env.FLASHINFER_CSRC_DIR / filename + dest_path = gen_directory / filename + source_paths.append(dest_path) + with open(src_path, "r") as f: + source = f.read() + write_if_different(dest_path, source) return gen_jit_spec( - "mamba_selective_state_update_sm90", - [ - jit_env.FLASHINFER_CSRC_DIR / "selective_state_update.cu", - jit_env.FLASHINFER_CSRC_DIR / "flashinfer_mamba_binding.cu", - ], - extra_cuda_cflags=nvcc_flags, + uri, + source_paths, + extra_cuda_cflags=extra_cuda_cflags or [], ) -def gen_selective_state_update_sm100_module() -> JitSpec: - # We use a specialized module for Blackwell+ GPUs with horizontal - # producer-consumer kernel optimized for SM100 and newer architectures. - # This supports SM100 (Blackwell) and future architectures. - # Technically, the code in this module can compile on sm90 as well, but - # this kernel is a lot slower on hopper than those in the mamba_selective_state_update and - # mamba_selective_state_update_sm90 modules. +def gen_selective_state_update_module( + state_dtype: torch.dtype, + input_dtype: torch.dtype, + weight_dtype: torch.dtype, + matrixA_dtype: torch.dtype, + stateIndex_dtype: torch.dtype, + dim: int, + dstate: int, + ntokens_mtp: int, +) -> JitSpec: + uri = get_selective_state_update_uri( + state_dtype, + input_dtype, + weight_dtype, + matrixA_dtype, + stateIndex_dtype, + dim, + dstate, + ntokens_mtp, + ) + return _gen_module( + uri, + state_dtype, + input_dtype, + weight_dtype, + matrixA_dtype, + stateIndex_dtype, + dim, + dstate, + ntokens_mtp, + ) + + +def gen_selective_state_update_sm90_module( + state_dtype: torch.dtype, + input_dtype: torch.dtype, + weight_dtype: torch.dtype, + matrixA_dtype: torch.dtype, + stateIndex_dtype: torch.dtype, + dim: int, + dstate: int, + ntokens_mtp: int, +) -> JitSpec: + uri = ( + get_selective_state_update_uri( + state_dtype, + input_dtype, + weight_dtype, + matrixA_dtype, + stateIndex_dtype, + dim, + dstate, + ntokens_mtp, + ) + + "_sm90" + ) compilation_context = CompilationContext() nvcc_flags = compilation_context.get_nvcc_flags_list( - supported_major_versions=[10, 11, 12] + supported_major_versions=[9, 10, 11, 12] ) - nvcc_flags += [ - "-DFLASHINFER_MAMBA_ENABLE_SM100", - ] - - return gen_jit_spec( - "mamba_selective_state_update_sm100", - [ - jit_env.FLASHINFER_CSRC_DIR / "selective_state_update.cu", - jit_env.FLASHINFER_CSRC_DIR / "flashinfer_mamba_binding.cu", - ], + nvcc_flags += ["-DFLASHINFER_MAMBA_ENABLE_SM90"] + return _gen_module( + uri, + state_dtype, + input_dtype, + weight_dtype, + matrixA_dtype, + stateIndex_dtype, + dim, + dstate, + ntokens_mtp, extra_cuda_cflags=nvcc_flags, ) diff --git a/flashinfer/mamba/selective_state_update.py b/flashinfer/mamba/selective_state_update.py index 734b0f5c10..294330be88 100644 --- a/flashinfer/mamba/selective_state_update.py +++ b/flashinfer/mamba/selective_state_update.py @@ -23,40 +23,61 @@ from ..jit.mamba import ( gen_selective_state_update_module, gen_selective_state_update_sm90_module, - gen_selective_state_update_sm100_module, ) from ..utils import get_compute_capability, register_custom_op, register_fake_op @functools.cache -def get_selective_state_update_module_base(): - """Get cached JIT-compiled selective_state_update module (base version).""" - return gen_selective_state_update_module().build_and_load() - - -@functools.cache -def get_selective_state_update_module_sm90(): - """Get cached JIT-compiled selective_state_update module (SM90/Hopper version).""" - return gen_selective_state_update_sm90_module().build_and_load() - - -@functools.cache -def get_selective_state_update_module_sm100(): - """Get cached JIT-compiled selective_state_update module (SM100+/Blackwell version).""" - return gen_selective_state_update_sm100_module().build_and_load() +def _get_module( + state_dtype: torch.dtype, + input_dtype: torch.dtype, + weight_dtype: torch.dtype, + matrixA_dtype: torch.dtype, + stateIndex_dtype: torch.dtype, + dim: int, + dstate: int, + ntokens_mtp: int, + sm_major: int, +): + args = ( + state_dtype, + input_dtype, + weight_dtype, + matrixA_dtype, + stateIndex_dtype, + dim, + dstate, + ntokens_mtp, + ) + if sm_major >= 9: + return gen_selective_state_update_sm90_module(*args).build_and_load() + else: + return gen_selective_state_update_module(*args).build_and_load() -def get_selective_state_update_module(device: torch.device): +def get_selective_state_update_module( + device: torch.device, + state_dtype: torch.dtype, + input_dtype: torch.dtype, + weight_dtype: torch.dtype, + matrixA_dtype: torch.dtype, + stateIndex_dtype: torch.dtype, + dim: int, + dstate: int, + ntokens_mtp: int, +): major, _ = get_compute_capability(device) - if major >= 10: - # SM100+ (Blackwell and newer) uses horizontal producer-consumer kernel - return get_selective_state_update_module_sm100() - elif major == 9: - # SM90 (Hopper) uses vertical producer-consumer kernel - return get_selective_state_update_module_sm90() - else: - # Pre-Hopper uses simple kernel - return get_selective_state_update_module_base() + return _get_module( + state_dtype, + input_dtype, + weight_dtype, + matrixA_dtype, + stateIndex_dtype, + dim, + dstate, + ntokens_mtp, + major, + ) @flashinfer_api @@ -78,6 +99,7 @@ def selective_state_update( intermediate_states_buffer: Optional[torch.Tensor] = None, intermediate_state_indices: Optional[torch.Tensor] = None, cache_steps: int = 0, + algorithm: str = "auto", ) -> torch.Tensor: r"""Selective state update operation for Mamba layers (the generation phase). @@ -126,6 +148,10 @@ def selective_state_update( with shape (batch,) cache_steps : int Number of steps/tokens to cache for speculative decoding + algorithm : str + Algorithm to use: "auto" (default, picks the best kernel based on GPU arch, + data types, and problem size), "simple" (all GPUs), "vertical" and "horizontal" + (SM90+ only). MTP mode only supports "auto" or "simple". Returns ------- @@ -178,6 +204,30 @@ def selective_state_update( output = torch.empty_like(x) else: output = out + + # Determine stateIndex dtype from index tensors, default to int32 + stateIndex_dtype = torch.int32 + if state_batch_indices is not None: + stateIndex_dtype = state_batch_indices.dtype + elif intermediate_state_indices is not None: + stateIndex_dtype = intermediate_state_indices.dtype + + # Extract dim/dstate/ntokens for JIT specialization + dim = state.size(2) + dstate = state.size(3) + ntokens_mtp = x.size(1) if x.dim() == 4 else 1 + + if algorithm == "auto": + algorithm_int = 0 + elif algorithm == "simple": + algorithm_int = 1 + elif algorithm == "vertical": + algorithm_int = 2 + elif algorithm == "horizontal": + algorithm_int = 3 + else: + raise ValueError(f"Unknown algorithm: {algorithm}") + _selective_state_update( state, x, @@ -196,6 +246,15 @@ def selective_state_update( intermediate_states_buffer, intermediate_state_indices, cache_steps, + algorithm_int, + state.dtype, + x.dtype, + dt.dtype, + A.dtype, + stateIndex_dtype, + dim, + dstate, + ntokens_mtp, ) return output @@ -222,9 +281,28 @@ def _selective_state_update( intermediate_states_buffer: Optional[torch.Tensor], intermediate_state_indices: Optional[torch.Tensor], cache_steps: int, + algorithm: int, + state_dtype: torch.dtype, + input_dtype: torch.dtype, + weight_dtype: torch.dtype, + matrixA_dtype: torch.dtype, + stateIndex_dtype: torch.dtype, + dim: int, + dstate: int, + ntokens_mtp: int, ) -> None: """Internal function registered with torch.library for torch.compile() support.""" - get_selective_state_update_module(state.device).selective_state_update( + get_selective_state_update_module( + state.device, + state_dtype, + input_dtype, + weight_dtype, + matrixA_dtype, + stateIndex_dtype, + dim, + dstate, + ntokens_mtp, + ).selective_state_update( state, x, dt, @@ -242,6 +320,7 @@ def _selective_state_update( intermediate_states_buffer, intermediate_state_indices, cache_steps, + algorithm, ) @@ -264,6 +343,15 @@ def _selective_state_update_fake( intermediate_states_buffer: Optional[torch.Tensor], intermediate_state_indices: Optional[torch.Tensor], cache_steps: int, + algorithm: int, + state_dtype: torch.dtype, + input_dtype: torch.dtype, + weight_dtype: torch.dtype, + matrixA_dtype: torch.dtype, + stateIndex_dtype: torch.dtype, + dim: int, + dstate: int, + ntokens_mtp: int, ) -> None: """Fake implementation for torch.compile() meta tensor propagation.""" pass diff --git a/include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh b/include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh index 09d927f6fb..af86b8094a 100644 --- a/include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh +++ b/include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh @@ -288,48 +288,40 @@ __global__ void selective_state_update_kernel_simple_mtp(SelectiveStateMTPParams template -void invokeSelectiveStateUpdateMTP(SelectiveStateMTPParams& params, cudaStream_t stream) { +void invokeSelectiveStateUpdateMTP(SelectiveStateMTPParams& params, SSUAlgorithm algorithm, + cudaStream_t stream) { + // MTP only supports the simple kernel + FLASHINFER_CHECK(algorithm == SSUAlgorithm::kAuto || algorithm == SSUAlgorithm::kSimple, + "MTP selective_state_update only supports 'auto' or 'simple' algorithm, got ", + static_cast(algorithm)); // Common alignment checks for all kernels check_ptr_alignment_input_vars(params); - auto kernel_launcher = [&]() { - // Additional alignment checks specific to simple kernel - constexpr auto stateLoadSize = getVectorLoadSizeForFullUtilization(); - using load_state_t = PackedAligned; - - FLASHINFER_CHECK(reinterpret_cast(params.state) % sizeof(load_state_t) == 0, - "state pointer must be aligned to ", sizeof(load_state_t), " bytes"); - FLASHINFER_CHECK((params.dim * params.dstate * sizeof(state_t)) % sizeof(load_state_t) == 0, - "state head stride must be aligned to ", sizeof(load_state_t), " bytes"); - - constexpr int numWarps = 4; - constexpr int stateRowsPerWarpPerStage = 4; - constexpr int stageRows = stateRowsPerWarpPerStage * numWarps; - - dim3 block(warpSize, numWarps); - dim3 grid(params.batch, params.nheads); - - auto func = - selective_state_update_kernel_simple_mtp; - using sram_t = SharedStorageSimple; - constexpr size_t smem_size = sizeof(sram_t); - - // Use FLASHINFER_CHECK instead of FLASHINFER_CUDA_CALL since we're in a void lambda - // (FLASHINFER_CUDA_CALL uses "return e;" which is invalid in void context) - // { - // cudaError_t e = cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, - // smem_size); FLASHINFER_CHECK(e == cudaSuccess, "CUDA Error in cudaFuncSetAttribute: ", - // cudaGetErrorString(e), " (", int(e), ")"); - // } - FLASHINFER_CUDA_CHECK( - cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - - func<<>>(params); - }; - - dispatchDimDstateTokens(params, AllowedDims{}, AllowedDstates{}, AllowedNtokens{}, - kernel_launcher); + constexpr auto stateLoadSize = getVectorLoadSizeForFullUtilization(); + using load_state_t = PackedAligned; + + FLASHINFER_CHECK(reinterpret_cast(params.state) % sizeof(load_state_t) == 0, + "state pointer must be aligned to ", sizeof(load_state_t), " bytes"); + FLASHINFER_CHECK((params.dim * params.dstate * sizeof(state_t)) % sizeof(load_state_t) == 0, + "state head stride must be aligned to ", sizeof(load_state_t), " bytes"); + + constexpr int numWarps = 4; + constexpr int stateRowsPerWarpPerStage = 4; + constexpr int stageRows = stateRowsPerWarpPerStage * numWarps; + + dim3 block(warpSize, numWarps); + dim3 grid(params.batch, params.nheads); + + auto func = + selective_state_update_kernel_simple_mtp; + using sram_t = SharedStorageSimple; + constexpr size_t smem_size = sizeof(sram_t); + + FLASHINFER_CUDA_CHECK( + cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + func<<>>(params); } } // namespace flashinfer::mamba::mtp diff --git a/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh b/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh index 254f2f71ea..58d43ab2af 100644 --- a/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh +++ b/include/flashinfer/mamba/kernel_selective_state_update_stp.cuh @@ -57,17 +57,21 @@ __device__ __forceinline__ int conflict_free_column(int group, int baseCol) { return (baseCol + stateValuesPerBank * bankCycle) % colsPerStage; } -template +template struct SharedStorageSimple { - alignas(alignof(PackedAligned)) input_t x[dim]; - alignas(alignof(PackedAligned)) input_t z[dim]; + alignas(alignof(PackedAligned)) input_t x[rows_per_block]; + alignas(alignof(PackedAligned)) input_t z[rows_per_block]; alignas(alignof(PackedAligned)) input_t B[dstate]; alignas(alignof(PackedAligned)) input_t C[dstate]; - float out[dim]; + float out[rows_per_block]; }; +// Grid: (batch, nheads, cdiv(DIM, ROWS_PER_BLOCK)) +// When ROWS_PER_BLOCK == DIM, degenerates to the non-tiled case (blockIdx.z == 0 always). +// Used when batch*nheads is too small to saturate the GPU: set ROWS_PER_BLOCK < DIM to +// split dim across blocks for better occupancy. template + typename stateIndex_t, int DIM, int DSTATE, int ROWS_PER_BLOCK, int numWarps> __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams params) { auto* __restrict__ output = reinterpret_cast(params.output); auto* __restrict__ state = reinterpret_cast(params.state); @@ -77,8 +81,8 @@ __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams auto const* __restrict__ A = reinterpret_cast(params.A); auto const* __restrict__ B = reinterpret_cast(params.B); auto const* __restrict__ C = reinterpret_cast(params.C); - auto const* __restrict__ D = reinterpret_cast(params.D); // D: (nheads, dim) - auto const* __restrict__ dt_bias = reinterpret_cast(params.dt_bias); // (nheads) + auto const* __restrict__ D = reinterpret_cast(params.D); + auto const* __restrict__ dt_bias = reinterpret_cast(params.dt_bias); auto const* __restrict__ z = reinterpret_cast(params.z); auto const* __restrict__ state_batch_indices = reinterpret_cast(params.state_batch_indices); @@ -87,10 +91,11 @@ __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams int const nheads = params.nheads; int const ngroups = params.ngroups; - constexpr auto rowsPerWarp = (DIM + numWarps - 1) / numWarps; + constexpr auto rowsPerWarp = (ROWS_PER_BLOCK + numWarps - 1) / numWarps; auto const batch = blockIdx.x; auto const head = blockIdx.y; + auto const dim_offset = blockIdx.z * ROWS_PER_BLOCK; auto const group = head / (nheads / ngroups); auto lane = threadIdx.x % warpSize; auto warp = threadIdx.y; @@ -98,7 +103,7 @@ __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams auto const state_batch = (state_batch_indices) ? state_batch_indices[batch] : batch; state += state_batch * params.state_stride_batch + head * DIM * DSTATE; - __shared__ SharedStorageSimple sram; + __shared__ SharedStorageSimple sram; static constexpr auto stateLoadSize = getVectorLoadSizeForFullUtilization(); using load_state_t = PackedAligned; @@ -116,23 +121,21 @@ __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams auto d_value = D ? toFloat(D[head]) : 0.f; + // Load x slice and B (warp 0), z slice and C (warp 1) if (warp == 0) { - for (auto d = lane * load_input_t::count; d < DIM; d += warpSize * load_input_t::count) { - auto* dst = reinterpret_cast(&sram.x[d]); - *dst = *reinterpret_cast( - &x[batch * params.x_stride_batch + head * DIM + d]); + for (auto d = lane; d < ROWS_PER_BLOCK; d += warpSize) { + if (dim_offset + d < DIM) + sram.x[d] = x[batch * params.x_stride_batch + head * DIM + dim_offset + d]; } for (auto i = lane * load_input_t::count; i < DSTATE; i += warpSize * load_input_t::count) { auto* dst = reinterpret_cast(&sram.B[i]); *dst = *reinterpret_cast( &B[batch * params.B_stride_batch + group * DSTATE + i]); } - } else if (warp == 1) { // Load z, C - for (auto d = lane * load_input_t::count; d < DIM; d += warpSize * load_input_t::count) { - auto* dst = reinterpret_cast(&sram.z[d]); - *dst = z ? *reinterpret_cast( - &z[batch * params.z_stride_batch + head * DIM + d]) - : make_zeros(); + } else if (warp == 1) { + for (auto d = lane; d < ROWS_PER_BLOCK; d += warpSize) { + if (dim_offset + d < DIM) + sram.z[d] = z ? z[batch * params.z_stride_batch + head * DIM + dim_offset + d] : input_t(0); } for (auto i = lane * load_input_t::count; i < DSTATE; i += warpSize * load_input_t::count) { auto* dst = reinterpret_cast(&sram.C[i]); @@ -143,11 +146,11 @@ __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams __syncthreads(); for (auto _d = warp * rowsPerWarp; _d < (warp + 1) * rowsPerWarp; _d++) { - auto d = _d; + auto d = dim_offset + _d; if (d >= DIM) break; float x_value = toFloat(sram.x[_d]); - float out_value = d_value * x_value * int(lane == 0); // first lane has the value + float out_value = d_value * x_value * int(lane == 0); for (int i = lane * load_state_t::count; i < DSTATE; i += warpSize * load_state_t::count) { auto rState = make_zeros(); @@ -170,7 +173,6 @@ __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams *reinterpret_cast(&state[d * DSTATE + i]) = rState; } - // warpReduce the out_value out_value = warpReduceSum(out_value); if (lane == 0) { sram.out[_d] = out_value; @@ -180,11 +182,12 @@ __global__ void selective_state_update_kernel_simple(SelectiveStateUpdateParams __syncthreads(); for (int l = lane; l < rowsPerWarp; l += warpSize) { - auto d = warp * rowsPerWarp + l; + auto _d = warp * rowsPerWarp + l; + auto d = dim_offset + _d; if (d < DIM) { - auto out_value = sram.out[d]; + auto out_value = sram.out[_d]; if (z) { - float z_value = toFloat(sram.z[d]); + float z_value = toFloat(sram.z[_d]); float sig_z = __fdividef(1.f, (1.f + __expf(0.f - z_value))); float silu_z = z_value * sig_z; out_value *= silu_z; @@ -210,10 +213,14 @@ struct SharedStorageVertical { barrier_t bar_consumers; }; -template +template __device__ __forceinline__ void producer_func_vertical(SramT& sram, CUtensorMap const& tensorState, - int batch, int head) { + input_t const* x_global_ptr, + input_t const* B_global_ptr, + input_t const* C_global_ptr, + input_t const* z_global_ptr, int batch, + int head) { #ifdef FLASHINFER_MAMBA_ENABLE_SM90 namespace cde = cuda::device::experimental; @@ -222,11 +229,44 @@ __device__ __forceinline__ void producer_func_vertical(SramT& sram, CUtensorMap auto constexpr stagesWriteOnly = numStages; auto constexpr bytesState = rowsPerStage * DSTATE * sizeof(state_t); - auto constexpr bytesToArrive = bytesState; + auto constexpr bytesX = DIM * sizeof(input_t); + auto constexpr bytesB = DSTATE * sizeof(input_t); + auto constexpr bytesC = DSTATE * sizeof(input_t); + auto constexpr bytesZ = hasZ ? DIM * sizeof(input_t) : 0; + auto constexpr bytesInputs = bytesX + bytesB + bytesC + bytesZ; + + // Phase 1, iter 0: fire all input vector loads + state load (if readState) + // All inputs piggyback onto bar_full[0] so consumers get them before stage 0 + { + constexpr auto stage = 0; + constexpr auto d = 0; - // Phase 1: Read only (filling the pipeline) + sram.bar_empty[stage].wait(sram.bar_empty[stage].arrive()); + + cuda::device::memcpy_async_tx(&sram.x[0], x_global_ptr, cuda::aligned_size_t<16>(bytesX), + sram.bar_full[stage]); + cuda::device::memcpy_async_tx(&sram.B[0], B_global_ptr, cuda::aligned_size_t<16>(bytesB), + sram.bar_full[stage]); + cuda::device::memcpy_async_tx(&sram.C[0], C_global_ptr, cuda::aligned_size_t<16>(bytesC), + sram.bar_full[stage]); + if constexpr (hasZ) { + cuda::device::memcpy_async_tx(&sram.z[0], z_global_ptr, cuda::aligned_size_t<16>(bytesZ), + sram.bar_full[stage]); + } + + if constexpr (readState) { + cde::cp_async_bulk_tensor_4d_global_to_shared(&sram.state[stage][0], &tensorState, 0, d, head, + batch, sram.bar_full[stage]); + auto const _ = + cuda::device::barrier_arrive_tx(sram.bar_full[stage], 1, bytesState + bytesInputs); + } else { + auto const _ = cuda::device::barrier_arrive_tx(sram.bar_full[stage], 1, bytesInputs); + } + } + + // Phase 1, iter 1..stagesReadOnly-1: state only (x already in flight) #pragma unroll - for (int iter = 0; iter < stagesReadOnly; ++iter) { + for (int iter = 1; iter < stagesReadOnly; ++iter) { auto const stage = iter % numStages; auto const d = iter * rowsPerStage; @@ -235,8 +275,7 @@ __device__ __forceinline__ void producer_func_vertical(SramT& sram, CUtensorMap if constexpr (readState) { cde::cp_async_bulk_tensor_4d_global_to_shared(&sram.state[stage][0], &tensorState, 0, d, head, batch, sram.bar_full[stage]); - - auto const _ = cuda::device::barrier_arrive_tx(sram.bar_full[stage], 1, bytesToArrive); + auto const _ = cuda::device::barrier_arrive_tx(sram.bar_full[stage], 1, bytesState); } else { auto const _ = sram.bar_full[stage].arrive(); } @@ -267,7 +306,7 @@ __device__ __forceinline__ void producer_func_vertical(SramT& sram, CUtensorMap if constexpr (readState) { cde::cp_async_bulk_tensor_4d_global_to_shared(&sram.state[stage][0], &tensorState, 0, d_read, head, batch, sram.bar_full[stage]); - auto const _ = cuda::device::barrier_arrive_tx(sram.bar_full[stage], 1, bytesToArrive); + auto const _ = cuda::device::barrier_arrive_tx(sram.bar_full[stage], 1, bytesState); } else { auto const _ = sram.bar_full[stage].arrive(); } @@ -417,7 +456,7 @@ __global__ void selective_state_update_kernel_producer_consumer_vertical( auto lane = threadIdx.x % warpSize; auto warp = threadIdx.y; - auto const state_batch = (state_batch_indices) ? state_batch_indices[batch] : batch; + auto const state_batch = (state_batch_indices) ? __ldg(&state_batch_indices[batch]) : batch; extern __shared__ uint8_t sbuffer[]; using sram_t = SharedStorageVertical() { + producer_func_vertical(sram, tensorState, x_global_ptr, B_global_ptr, + C_global_ptr, hasZ ? z_global_ptr : nullptr, + state_batch, head); + }; + auto const dispatch_state = [&]() { if (read_state && write_state) - producer_func_vertical( - sram, tensorState, state_batch, head); - else if (read_state && !write_state) - producer_func_vertical( - sram, tensorState, state_batch, head); + call.template operator()(); + else if (read_state) + call.template operator()(); + else + call.template operator()(); + }; + + cg::invoke_one(cg::coalesced_threads(), [&]() { + if (z_global_ptr) + dispatch_state.template operator()(); else - producer_func_vertical( - sram, tensorState, state_batch, head); + dispatch_state.template operator()(); }); } } else { // consumers - using load_t = PackedAligned; - #pragma unroll // Unblock the producer for (uint8_t stage = 0; stage < numStages; ++stage) { auto const _ = sram.bar_empty[stage].arrive(); } - // Load A - auto const A_value = toFloat(A[head]); + // Load A, D, dt, dt_bias via __ldg (read-only texture cache) — + // these are broadcast scalars read once per block. + auto const A_value = toFloat(__ldg(&A[head])); - // Load D - auto const d_value = D ? toFloat(D[head]) : 0.f; + auto const d_value = D ? toFloat(__ldg(&D[head])) : 0.f; - // load dt_value - auto dt_value = toFloat(dt[batch * params.dt_stride_batch + head]); - if (dt_bias) dt_value += toFloat(dt_bias[head]); + auto dt_value = toFloat(__ldg(&dt[batch * params.dt_stride_batch + head])); + if (dt_bias) dt_value += toFloat(__ldg(&dt_bias[head])); if (params.dt_softplus) { dt_value = thresholded_softplus(dt_value); } auto const dA = __expf(A_value * dt_value); - if (warp == 0) { // Load x, B - for (auto d = lane * load_t::count; d < DIM; d += warpSize * load_t::count) { - auto* dst = reinterpret_cast(&sram.x[d]); - *dst = *reinterpret_cast(&x[batch * params.x_stride_batch + head * DIM + d]); - } - for (auto i = lane * load_t::count; i < DSTATE; i += warpSize * load_t::count) { - auto* dst = reinterpret_cast(&sram.B[i]); - *dst = *reinterpret_cast( - &B[batch * params.B_stride_batch + group * DSTATE + i]); - } - } else if (warp == 1) { // Load z, C - for (auto d = lane * load_t::count; d < DIM; d += warpSize * load_t::count) { - auto* dst = reinterpret_cast(&sram.z[d]); - *dst = - z ? *reinterpret_cast(&z[batch * params.z_stride_batch + head * DIM + d]) - : make_zeros(); - } - for (auto i = lane * load_t::count; i < DSTATE; i += warpSize * load_t::count) { - auto* dst = reinterpret_cast(&sram.C[i]); - *dst = *reinterpret_cast( - &C[batch * params.C_stride_batch + group * DSTATE + i]); - } - } - - sram.bar_consumers.wait(sram.bar_consumers.arrive()); - if (state_batch != params.pad_slot_id) consumer_func_vertical(lane, warp, d_value, dt_value, dA, @@ -518,7 +542,7 @@ __global__ void selective_state_update_kernel_producer_consumer_vertical( rowsPerStage, numStages, false>(lane, warp, d_value, dt_value, dA, sram); - // Write output + // Write output — wait for all consumer warps to finish writing sram.out sram.bar_consumers.wait(sram.bar_consumers.arrive()); auto d = warp * warpSize + lane; if (d < DIM) { @@ -535,11 +559,7 @@ __global__ void selective_state_update_kernel_producer_consumer_vertical( #endif } -// ============================================================================= -// Horizontal Producer-Consumer Kernel for SM100+ (Blackwell and newer) -// ============================================================================= - -#ifdef FLASHINFER_MAMBA_ENABLE_SM100 +#ifdef FLASHINFER_MAMBA_ENABLE_SM90 template -void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, cudaStream_t stream) { +void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, SSUAlgorithm algorithm, + cudaStream_t stream) { auto [sm_major, sm_minor] = GetCudaComputeCapability(); // Common alignment checks for all kernels check_ptr_alignment_input_vars(params); -#ifdef FLASHINFER_MAMBA_ENABLE_SM100 - if (sm_major < 10) // pre-Blackwell -#elif defined(FLASHINFER_MAMBA_ENABLE_SM90) - if (sm_major < 9) // pre-Hopper + // Resolve auto to a concrete algorithm based on GPU architecture and batch size + SSUAlgorithm algo = algorithm; + if (algo == SSUAlgorithm::kAuto) { +#ifdef FLASHINFER_MAMBA_ENABLE_SM90 + if (sm_major < 9) { + algo = SSUAlgorithm::kSimple; + } else { + // At small batch sizes, the tiled simple kernel outperforms producer-consumer + // kernels because it has lower per-block overhead and can still saturate the GPU + // via dim-tiling. Threshold: batch*nheads < 2*num_SMs (i.e. not enough blocks + // for the non-tiled producer-consumer kernels to hide latency). + int const total_blocks = params.batch * params.nheads; + int const num_sms = GetCudaMultiProcessorCount(); + if (total_blocks < num_sms * 2) + algo = SSUAlgorithm::kSimple; + else if (sm_major < 10) + algo = SSUAlgorithm::kVertical; + else + // On Blackwell+: vertical is slightly faster for fp32 state, + // horizontal is faster for fp16/bf16 state. + algo = (sizeof(state_t) == 4) ? SSUAlgorithm::kVertical : SSUAlgorithm::kHorizontal; + } +#else + algo = SSUAlgorithm::kSimple; #endif - { - auto kernel_launcher = [&]() { - // Additional alignment checks specific to simple kernel - constexpr auto stateLoadSize = getVectorLoadSizeForFullUtilization(); - using load_state_t = PackedAligned; - - FLASHINFER_CHECK(reinterpret_cast(params.state) % sizeof(load_state_t) == 0, - "state pointer must be aligned to ", sizeof(load_state_t), " bytes"); - FLASHINFER_CHECK((params.dim * params.dstate * sizeof(state_t)) % sizeof(load_state_t) == 0, - "state head stride must be aligned to ", sizeof(load_state_t), " bytes"); + } - constexpr int numWarps = 4; - dim3 block(warpSize, numWarps); + if (algo == SSUAlgorithm::kSimple) { + constexpr auto stateLoadSize = getVectorLoadSizeForFullUtilization(); + using load_state_t = PackedAligned; + + FLASHINFER_CHECK(reinterpret_cast(params.state) % sizeof(load_state_t) == 0, + "state pointer must be aligned to ", sizeof(load_state_t), " bytes"); + FLASHINFER_CHECK((params.dim * params.dstate * sizeof(state_t)) % sizeof(load_state_t) == 0, + "state head stride must be aligned to ", sizeof(load_state_t), " bytes"); + + constexpr int numWarps = 4; + constexpr int ROWS_PER_BLOCK = 4; + int const total_blocks = params.batch * params.nheads; + int const num_sms = GetCudaMultiProcessorCount(); + + dim3 block(warpSize, numWarps); + if (total_blocks < num_sms * 2) { + // Tiled: split dim across blocks for better GPU occupancy at small batch sizes + int const dim_tiles = (DIM + ROWS_PER_BLOCK - 1) / ROWS_PER_BLOCK; + dim3 grid(params.batch, params.nheads, dim_tiles); + selective_state_update_kernel_simple + <<>>(params); + } else { + // Non-tiled: enough blocks already for full occupancy; ROWS_PER_BLOCK == DIM so blockIdx.z == + // 0 dim3 grid(params.batch, params.nheads); selective_state_update_kernel_simple<<>>(params); - }; - - dispatchDimDstate(params, AllowedDims{}, AllowedDstates{}, kernel_launcher); + DSTATE, DIM, numWarps> + <<>>(params); + } } #ifdef FLASHINFER_MAMBA_ENABLE_SM90 - else { - - auto kernel_launcher = [&]() { - // Note: State uses TMA which requires 128B alignment (checked below) - constexpr auto numConsumers = 4; - constexpr auto numWarps = 1 + numConsumers; - constexpr auto numStages = 3; - constexpr auto rowsPerStage = 4 * numConsumers; - FLASHINFER_CHECK(params.dim % rowsPerStage == 0, "dim must be divisible by ", rowsPerStage, - " for SM90+ kernel"); - auto scan_func = selective_state_update_kernel_producer_consumer_vertical< - input_t, weight_t, matrixA_t, state_t, stateIndex_t, DIM, DSTATE, numConsumers, - rowsPerStage, numStages>; + else if (algo == SSUAlgorithm::kVertical) { + constexpr auto numConsumers = 4; + constexpr auto numWarps = 1 + numConsumers; + constexpr auto numStages = 3; + constexpr auto rowsPerStage = 4 * numConsumers; + FLASHINFER_CHECK(params.dim % rowsPerStage == 0, "dim must be divisible by ", rowsPerStage, + " for vertical kernel"); + auto scan_func = selective_state_update_kernel_producer_consumer_vertical< + input_t, weight_t, matrixA_t, state_t, stateIndex_t, DIM, DSTATE, numConsumers, + rowsPerStage, numStages>; + + dim3 block(warpSize, numWarps); + dim3 grid(params.batch, params.nheads); + + auto state_tensor = + tma::buildNdDescriptor(typeid(state_t), + /*shapes*/ {DSTATE, DIM, params.nheads, params.state_cache_size}, + /*strides*/ {1, DSTATE, DSTATE * DIM, params.state_stride_batch}, + /*tiles*/ {DSTATE, rowsPerStage, 1, 1}, params.state); + + using sram_t = SharedStorageVertical; + constexpr size_t smem_size = sizeof(sram_t); + FLASHINFER_CUDA_CHECK( + cudaFuncSetAttribute(scan_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + scan_func<<>>(params, state_tensor); + } else if (algo == SSUAlgorithm::kHorizontal) { + constexpr auto numConsumers = (DIM / 64) * 4; + constexpr auto numProducers = 1; + constexpr auto numWarps = numProducers + numConsumers; + + constexpr auto sectorSize = 32; // bytes + constexpr auto stageCols = 2 * sectorSize / sizeof(state_t); + + constexpr auto totalStages = DSTATE / stageCols; + constexpr auto numStages = (totalStages >= 4) ? 4 : totalStages; + + auto ratio_launcher = [&]() { + auto scan_func = selective_state_update_kernel_producer_consumer_horizontal< + input_t, weight_t, matrixA_t, state_t, stateIndex_t, DIM, DSTATE, numConsumers, stageCols, + RATIO, numStages>; dim3 block(warpSize, numWarps); dim3 grid(params.batch, params.nheads); - auto nh = params.nheads; - auto dim = params.dim; - auto state_tensor = tma::buildNdDescriptor(typeid(state_t), - /*shapes*/ {DSTATE, DIM, nh, params.state_cache_size}, + /*shapes*/ {DSTATE, DIM, params.nheads, params.state_cache_size}, /*strides*/ {1, DSTATE, DSTATE * DIM, params.state_stride_batch}, - /*tiles*/ {DSTATE, rowsPerStage, 1, 1}, params.state); + /*tiles*/ {stageCols, DIM, 1, 1}, params.state); + static_assert(DSTATE % stageCols == 0 && DSTATE >= stageCols); - // Calculate shared memory size and opt-in to extended shared memory - using sram_t = SharedStorageVertical; + using sram_t = SharedStorageHorizontal; constexpr size_t smem_size = sizeof(sram_t); FLASHINFER_CUDA_CHECK( cudaFuncSetAttribute(scan_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -953,58 +1033,13 @@ void invokeSelectiveStateUpdate(SelectiveStateUpdateParams& params, cudaStream_t scan_func<<>>(params, state_tensor); }; - dispatchDimDstate(params, AllowedDims{}, AllowedDstates{}, kernel_launcher); + dispatchRatio(params, std::integer_sequence{}, ratio_launcher); } #endif - -#ifdef FLASHINFER_MAMBA_ENABLE_SM100 else { - // SM100+ (Blackwell and newer) uses horizontal producer-consumer kernel - auto kernel_launcher = [&]() { - // profiling showed that it's good to have 4 producers per 64 rows - constexpr auto numConsumers = (DIM / 64) * 4; - constexpr auto numProducers = 1; - constexpr auto numWarps = numProducers + numConsumers; - - constexpr auto sectorSize = 32; // bytes - constexpr auto stageCols = 2 * sectorSize / sizeof(state_t); - - constexpr auto totalStages = DSTATE / stageCols; - constexpr auto numStages = (totalStages >= 4) ? 4 : totalStages; - - auto ratio_launcher = [&]() { - auto scan_func = selective_state_update_kernel_producer_consumer_horizontal< - input_t, weight_t, matrixA_t, state_t, stateIndex_t, DIM, DSTATE, numConsumers, - stageCols, RATIO, numStages>; - - dim3 block(warpSize, numWarps); - dim3 grid(params.batch, params.nheads); - - auto nh = params.nheads; - auto dim = params.dim; - - auto state_tensor = - tma::buildNdDescriptor(typeid(state_t), - /*shapes*/ {DSTATE, DIM, nh, params.state_cache_size}, - /*strides*/ {1, DSTATE, DSTATE * DIM, params.state_stride_batch}, - /*tiles*/ {stageCols, DIM, 1, 1}, params.state); - static_assert(DSTATE % stageCols == 0 && DSTATE >= stageCols); - - using sram_t = SharedStorageHorizontal; - constexpr size_t smem_size = sizeof(sram_t); - FLASHINFER_CUDA_CHECK(cudaFuncSetAttribute( - scan_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - - scan_func<<>>(params, state_tensor); - }; - - dispatchRatio(params, std::integer_sequence{}, ratio_launcher); - }; - - dispatchDimDstate(params, AllowedDims{}, AllowedDstates{}, kernel_launcher); + FLASHINFER_CHECK(false, "Unsupported SSU algorithm: ", SSUAlgorithmToString(algo), + ". Vertical/horizontal require FLASHINFER_MAMBA_ENABLE_SM90."); } -#endif } } // namespace flashinfer::mamba diff --git a/include/flashinfer/mamba/selective_state_update.cuh b/include/flashinfer/mamba/selective_state_update.cuh index f8b44c3779..0607d7a0f2 100644 --- a/include/flashinfer/mamba/selective_state_update.cuh +++ b/include/flashinfer/mamba/selective_state_update.cuh @@ -21,12 +21,29 @@ namespace flashinfer::mamba { -// ============================================================================= -// Allowed dispatch values for kernel instantiation -// ============================================================================= -using AllowedDims = std::integer_sequence; -using AllowedDstates = std::integer_sequence; -using AllowedNtokens = std::integer_sequence; +// Host-side algorithm selection for invokeSelectiveStateUpdate dispatch. +// Not stored in kernel params — no register overhead. +enum class SSUAlgorithm : int32_t { + kAuto = 0, + kSimple = 1, + kVertical = 2, + kHorizontal = 3, +}; + +inline const char* SSUAlgorithmToString(SSUAlgorithm algo) { + switch (algo) { + case SSUAlgorithm::kAuto: + return "Auto"; + case SSUAlgorithm::kSimple: + return "Simple"; + case SSUAlgorithm::kVertical: + return "Vertical"; + case SSUAlgorithm::kHorizontal: + return "Horizontal"; + default: + return "Unknown"; + } +} struct SelectiveStateUpdateParams { uint32_t batch{}, nheads{}, dim{}, dstate{}, ngroups{}, state_cache_size{}; diff --git a/tests/mamba/test_selective_state_update_mtp.py b/tests/mamba/test_selective_state_update_mtp.py index 866c32e98a..f295ee6bad 100644 --- a/tests/mamba/test_selective_state_update_mtp.py +++ b/tests/mamba/test_selective_state_update_mtp.py @@ -11,10 +11,30 @@ import flashinfer -from .selective_state_update_triton import selective_state_update_triton +from .triton_reference.selective_state_update import selective_state_update_triton from .utils import create_test_inputs, clone_preserving_strides +# Base combination: batch=64, nheads=64, dim=64, dstate=128, cache_steps=4, +# state_dtype=bf16, weight_dtype=f32, use_out_tensor=True +# Each additional row varies exactly one parameter from the base. +# fmt: off +_BASE_PARAMS = [ + # (batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype, use_out_tensor) + ( 64, 64, 64, 128, 4, torch.bfloat16, torch.float32, True ), # base + ( 1, 64, 64, 128, 4, torch.bfloat16, torch.float32, True ), # batch=1 + ( 4, 64, 64, 128, 4, torch.bfloat16, torch.float32, True ), # batch=4 + ( 64, 8, 64, 128, 4, torch.bfloat16, torch.float32, True ), # nheads=8 + ( 64, 64, 128, 128, 4, torch.bfloat16, torch.float32, True ), # dim=128 + ( 64, 64, 64, 64, 4, torch.bfloat16, torch.float32, True ), # dstate=64 + ( 64, 64, 64, 128, 1, torch.bfloat16, torch.float32, True ), # cache_steps=1 + ( 64, 64, 64, 128, 8, torch.bfloat16, torch.float32, True ), # cache_steps=8 + ( 64, 64, 64, 128, 4, torch.float32, torch.float32, True ), # state_dtype=f32 + ( 64, 64, 64, 128, 4, torch.bfloat16, torch.float32, False), # use_out_tensor=False +] +# fmt: on + + class TestSelectiveStateUpdateMTP: """Test class for multi-token selective state update kernels.""" @@ -25,41 +45,7 @@ class TestSelectiveStateUpdateMTP: INPUT_DTYPE = torch.bfloat16 MATRIX_A_DTYPE = torch.float32 - @pytest.fixture(params=[1, 4]) - def batch(self, request): - return request.param - - @pytest.fixture(params=[8, 32]) - def nheads(self, request): - return request.param - - @pytest.fixture(params=[64, 128]) - def dim(self, request): - return request.param - - @pytest.fixture(params=[64, 128]) - def dstate(self, request): - return request.param - - @pytest.fixture(params=[1, 4, 8]) - def cache_steps(self, request): - """Number of tokens in multi-token mode (T dimension).""" - return request.param - - @pytest.fixture(params=[torch.float32, torch.bfloat16]) - def state_dtype(self, request): - return request.param - - @pytest.fixture(params=[torch.float32]) - def weight_dtype(self, request): - return request.param - - @pytest.fixture(params=[False, True]) - def use_out_tensor(self, request): - return request.param - - @pytest.fixture - def inputs( + def make_inputs( self, batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype ): """Create test inputs for given parameters.""" @@ -79,8 +65,7 @@ def inputs( seed=0, ) - @pytest.fixture - def reference_output(self, inputs): + def make_reference_output(self, inputs): """Compute reference output using triton implementation.""" state_ref = clone_preserving_strides(inputs["state_cache"]) y_ref = selective_state_update_triton( @@ -182,9 +167,26 @@ def _print_mismatch_details(self, ref, test, name): f"diff={diff:.6e}, rel_diff={rel_diff:.6e}" ) - def test_output_correctness(self, inputs, reference_output, use_out_tensor): + @pytest.mark.parametrize( + "batch,nheads,dim,dstate,cache_steps,state_dtype,weight_dtype,use_out_tensor", + _BASE_PARAMS, + ) + def test_output_correctness( + self, + batch, + nheads, + dim, + dstate, + cache_steps, + state_dtype, + weight_dtype, + use_out_tensor, + ): """Test that kernel output matches reference within tolerance.""" - y_ref, state_ref = reference_output + inputs = self.make_inputs( + batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype + ) + y_ref, state_ref = self.make_reference_output(inputs) # Prepare output tensor if requested if use_out_tensor: @@ -207,36 +209,7 @@ def test_output_correctness(self, inputs, reference_output, use_out_tensor): class TestSelectiveStateUpdateMTPWithZ(TestSelectiveStateUpdateMTP): """Test multi-token selective_state_update with z tensor (gating).""" - @pytest.fixture(params=[4]) - def batch(self, request): - return request.param - - @pytest.fixture(params=[8]) - def nheads(self, request): - return request.param - - @pytest.fixture(params=[64]) - def dim(self, request): - return request.param - - @pytest.fixture(params=[64]) - def dstate(self, request): - return request.param - - @pytest.fixture(params=[4]) - def cache_steps(self, request): - return request.param - - @pytest.fixture(params=[torch.bfloat16]) - def state_dtype(self, request): - return request.param - - @pytest.fixture(params=[torch.float32]) - def weight_dtype(self, request): - return request.param - - @pytest.fixture - def inputs( + def make_inputs( self, batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype ): """Create test inputs with z tensor.""" @@ -256,41 +229,56 @@ def inputs( seed=0, ) + @pytest.mark.parametrize( + "batch,nheads,dim,dstate,cache_steps,state_dtype,weight_dtype,use_out_tensor", + [(64, 64, 64, 128, 4, torch.bfloat16, torch.float32, True)], + ) + def test_output_correctness( + self, + batch, + nheads, + dim, + dstate, + cache_steps, + state_dtype, + weight_dtype, + use_out_tensor, + ): + super().test_output_correctness( + batch, + nheads, + dim, + dstate, + cache_steps, + state_dtype, + weight_dtype, + use_out_tensor, + ) + class TestSelectiveStateUpdateMTPDisableStateUpdate(TestSelectiveStateUpdateMTP): """Test multi-token selective_state_update with disable_state_update=True.""" - @pytest.fixture(params=[4]) - def batch(self, request): - return request.param - - @pytest.fixture(params=[32]) - def nheads(self, request): - return request.param - - @pytest.fixture(params=[64]) - def dim(self, request): - return request.param - - @pytest.fixture(params=[64, 128]) - def dstate(self, request): - return request.param - - @pytest.fixture(params=[4, 8]) - def cache_steps(self, request): - return request.param - - @pytest.fixture(params=[torch.bfloat16]) - def state_dtype(self, request): - return request.param - - @pytest.fixture(params=[torch.float32]) - def weight_dtype(self, request): - return request.param - - def test_output_correctness(self, inputs, reference_output, use_out_tensor): + @pytest.mark.parametrize( + "batch,nheads,dim,dstate,cache_steps,state_dtype,weight_dtype,use_out_tensor", + [(64, 64, 64, 128, 4, torch.bfloat16, torch.float32, True)], + ) + def test_output_correctness( + self, + batch, + nheads, + dim, + dstate, + cache_steps, + state_dtype, + weight_dtype, + use_out_tensor, + ): """Test that kernel output matches reference but state is not updated.""" - y_ref, state_ref = reference_output + inputs = self.make_inputs( + batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype + ) + y_ref, _ = self.make_reference_output(inputs) # Save the initial state before running the kernel state_initial = inputs["state_cache"].clone() @@ -343,36 +331,7 @@ def test_output_correctness(self, inputs, reference_output, use_out_tensor): class TestSelectiveStateUpdateMTPWithIntermediateStates(TestSelectiveStateUpdateMTP): """Test multi-token selective_state_update with intermediate states buffer.""" - @pytest.fixture(params=[4]) - def batch(self, request): - return request.param - - @pytest.fixture(params=[32]) - def nheads(self, request): - return request.param - - @pytest.fixture(params=[64]) - def dim(self, request): - return request.param - - @pytest.fixture(params=[64, 128]) - def dstate(self, request): - return request.param - - @pytest.fixture(params=[2, 4, 8]) - def cache_steps(self, request): - return request.param - - @pytest.fixture(params=[torch.bfloat16]) - def state_dtype(self, request): - return request.param - - @pytest.fixture(params=[torch.float32]) - def weight_dtype(self, request): - return request.param - - @pytest.fixture - def inputs( + def make_inputs( self, batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype ): """Create test inputs with intermediate states buffer.""" @@ -392,8 +351,7 @@ def inputs( seed=0, ) - @pytest.fixture - def reference_output(self, inputs): + def make_reference_output(self, inputs): """Compute reference output using triton implementation with intermediate states.""" state_ref = clone_preserving_strides(inputs["state_cache"]) intermediate_states_ref = inputs["intermediate_states_buffer"].clone() @@ -440,9 +398,37 @@ def run_kernel_with_intermediate_states(self, inputs, out=None): cache_steps=inputs["cache_steps"], ) - def test_output_correctness(self, inputs, reference_output, use_out_tensor): + # fmt: off + _INTERMEDIATE_PARAMS = [ + # (batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype, use_out_tensor) + ( 64, 64, 64, 128, 4, torch.bfloat16, torch.float32, True ), # base + ( 64, 64, 64, 64, 4, torch.bfloat16, torch.float32, True ), # dstate=64 + ( 64, 64, 64, 128, 2, torch.bfloat16, torch.float32, True ), # cache_steps=2 + ( 64, 64, 64, 128, 8, torch.bfloat16, torch.float32, True ), # cache_steps=8 + ( 64, 64, 64, 128, 4, torch.bfloat16, torch.float32, False), # use_out_tensor=False + ] + # fmt: on + + @pytest.mark.parametrize( + "batch,nheads,dim,dstate,cache_steps,state_dtype,weight_dtype,use_out_tensor", + _INTERMEDIATE_PARAMS, + ) + def test_output_correctness( + self, + batch, + nheads, + dim, + dstate, + cache_steps, + state_dtype, + weight_dtype, + use_out_tensor, + ): """Test that kernel output matches and intermediate states are cached correctly.""" - y_ref, state_ref, intermediate_states_ref = reference_output + inputs = self.make_inputs( + batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype + ) + y_ref, _state_ref, intermediate_states_ref = self.make_reference_output(inputs) # Prepare output tensor if requested if use_out_tensor: @@ -488,36 +474,7 @@ def test_output_correctness(self, inputs, reference_output, use_out_tensor): class TestSelectiveStateUpdateMTPNonContiguous(TestSelectiveStateUpdateMTP): """Test multi-token selective_state_update with non-contiguous state cache.""" - @pytest.fixture(params=[4]) - def batch(self, request): - return request.param - - @pytest.fixture(params=[8]) - def nheads(self, request): - return request.param - - @pytest.fixture(params=[64]) - def dim(self, request): - return request.param - - @pytest.fixture(params=[64]) - def dstate(self, request): - return request.param - - @pytest.fixture(params=[4]) - def cache_steps(self, request): - return request.param - - @pytest.fixture(params=[torch.bfloat16]) - def state_dtype(self, request): - return request.param - - @pytest.fixture(params=[torch.float32]) - def weight_dtype(self, request): - return request.param - - @pytest.fixture - def inputs( + def make_inputs( self, batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype ): """Create test inputs with non-contiguous state cache (2x batch stride).""" @@ -540,8 +497,7 @@ def inputs( seed=0, ) - @pytest.fixture - def reference_output(self, inputs): + def make_reference_output(self, inputs): """Compute reference output, preserving non-contiguous strides.""" state_ref = clone_preserving_strides(inputs["state_cache"]) y_ref = selective_state_update_triton( @@ -560,38 +516,36 @@ def reference_output(self, inputs): ) return y_ref, state_ref + @pytest.mark.parametrize( + "batch,nheads,dim,dstate,cache_steps,state_dtype,weight_dtype,use_out_tensor", + [(64, 64, 64, 128, 4, torch.bfloat16, torch.float32, True)], + ) + def test_output_correctness( + self, + batch, + nheads, + dim, + dstate, + cache_steps, + state_dtype, + weight_dtype, + use_out_tensor, + ): + super().test_output_correctness( + batch, + nheads, + dim, + dstate, + cache_steps, + state_dtype, + weight_dtype, + use_out_tensor, + ) + class TestSelectiveStateUpdateMTPInt32Indices(TestSelectiveStateUpdateMTP): """Test multi-token selective_state_update with int32 state_batch_indices.""" - @pytest.fixture(params=[4]) - def batch(self, request): - return request.param - - @pytest.fixture(params=[8]) - def nheads(self, request): - return request.param - - @pytest.fixture(params=[64]) - def dim(self, request): - return request.param - - @pytest.fixture(params=[64]) - def dstate(self, request): - return request.param - - @pytest.fixture(params=[4]) - def cache_steps(self, request): - return request.param - - @pytest.fixture(params=[torch.bfloat16]) - def state_dtype(self, request): - return request.param - - @pytest.fixture(params=[torch.float32]) - def weight_dtype(self, request): - return request.param - def run_kernel(self, inputs, out=None, disable_state_update=False): """Run the flashinfer kernel with int32 state_batch_indices.""" # Cast slot_idx to int32 @@ -614,46 +568,51 @@ def run_kernel(self, inputs, out=None, disable_state_update=False): disable_state_update=disable_state_update, ) + @pytest.mark.parametrize( + "batch,nheads,dim,dstate,cache_steps,state_dtype,weight_dtype,use_out_tensor", + [(64, 64, 64, 128, 4, torch.bfloat16, torch.float32, True)], + ) + def test_output_correctness( + self, + batch, + nheads, + dim, + dstate, + cache_steps, + state_dtype, + weight_dtype, + use_out_tensor, + ): + super().test_output_correctness( + batch, + nheads, + dim, + dstate, + cache_steps, + state_dtype, + weight_dtype, + use_out_tensor, + ) + class TestSelectiveStateUpdateMTPVariousNgroups(TestSelectiveStateUpdateMTP): """Test multi-token selective_state_update with various ngroups values.""" - NGROUPS = None # Will be set by fixture - - @pytest.fixture(params=[1, 2, 4, 8]) - def ngroups(self, request): - return request.param - - @pytest.fixture(params=[4]) - def batch(self, request): - return request.param - - @pytest.fixture(params=[32]) - def nheads(self, request): - return request.param - - @pytest.fixture(params=[64]) - def dim(self, request): - return request.param - - @pytest.fixture(params=[64]) - def dstate(self, request): - return request.param - - @pytest.fixture(params=[4]) - def cache_steps(self, request): - return request.param - - @pytest.fixture(params=[torch.bfloat16]) - def state_dtype(self, request): - return request.param - - @pytest.fixture(params=[torch.float32]) - def weight_dtype(self, request): - return request.param - - @pytest.fixture - def inputs( + # fmt: off + _NGROUPS_PARAMS = [ + # (batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype, use_out_tensor, ngroups) + ( 64, 64, 64, 128, 4, torch.bfloat16, torch.float32, True, 1), + ( 64, 64, 64, 128, 4, torch.bfloat16, torch.float32, True, 2), + ( 64, 64, 64, 128, 4, torch.bfloat16, torch.float32, True, 4), + ( 64, 64, 64, 128, 4, torch.bfloat16, torch.float32, True, 8), + ] + # fmt: on + + @pytest.mark.parametrize( + "batch,nheads,dim,dstate,cache_steps,state_dtype,weight_dtype,use_out_tensor,ngroups", + _NGROUPS_PARAMS, + ) + def test_output_correctness( self, batch, nheads, @@ -662,10 +621,11 @@ def inputs( cache_steps, state_dtype, weight_dtype, + use_out_tensor, ngroups, ): - """Create test inputs with specified ngroups.""" - return create_test_inputs( + """Test that kernel output matches reference within tolerance.""" + inputs = create_test_inputs( batch, nheads, dim, @@ -680,38 +640,60 @@ def inputs( cache_steps=cache_steps, seed=0, ) + y_ref, state_ref = self.make_reference_output(inputs) + if use_out_tensor: + out = torch.empty_like(inputs["x"]) + else: + out = None -class TestSelectiveStateUpdateMTPLargeBatch(TestSelectiveStateUpdateMTP): - """Test multi-token selective_state_update with larger batch sizes.""" - - @pytest.fixture(params=[16, 64]) - def batch(self, request): - return request.param - - @pytest.fixture(params=[32]) - def nheads(self, request): - return request.param + y_test = self.run_kernel(inputs, out=out) - @pytest.fixture(params=[64]) - def dim(self, request): - return request.param + if use_out_tensor: + assert y_test.data_ptr() == out.data_ptr(), ( + "Returned tensor should be the same object as the provided output tensor" + ) - @pytest.fixture(params=[64]) - def dstate(self, request): - return request.param + self.assert_outputs_match(y_ref, y_test) + self.assert_states_match(state_ref, inputs["state_cache"], inputs["slot_idx"]) - @pytest.fixture(params=[4, 8]) - def cache_steps(self, request): - return request.param - @pytest.fixture(params=[torch.bfloat16]) - def state_dtype(self, request): - return request.param +class TestSelectiveStateUpdateMTPLargeBatch(TestSelectiveStateUpdateMTP): + """Test multi-token selective_state_update with larger batch sizes.""" - @pytest.fixture(params=[torch.float32]) - def weight_dtype(self, request): - return request.param + # fmt: off + _LARGE_BATCH_PARAMS = [ + # (batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype, use_out_tensor) + ( 16, 64, 64, 128, 4, torch.bfloat16, torch.float32, True ), # batch=16 + ( 256, 64, 64, 128, 4, torch.bfloat16, torch.float32, True ), # batch=256 + ] + # fmt: on + + @pytest.mark.parametrize( + "batch,nheads,dim,dstate,cache_steps,state_dtype,weight_dtype,use_out_tensor", + _LARGE_BATCH_PARAMS, + ) + def test_output_correctness( + self, + batch, + nheads, + dim, + dstate, + cache_steps, + state_dtype, + weight_dtype, + use_out_tensor, + ): + super().test_output_correctness( + batch, + nheads, + dim, + dstate, + cache_steps, + state_dtype, + weight_dtype, + use_out_tensor, + ) class TestSelectiveStateUpdateMTPIndicesDtypeMismatch: diff --git a/tests/mamba/test_selective_state_update_stp.py b/tests/mamba/test_selective_state_update_stp.py index c26a5849b8..d23faa644a 100644 --- a/tests/mamba/test_selective_state_update_stp.py +++ b/tests/mamba/test_selective_state_update_stp.py @@ -3,11 +3,41 @@ import torch import flashinfer +from flashinfer.utils import get_compute_capability -from .selective_state_update_triton import selective_state_update_triton +from .triton_reference.selective_state_update import selective_state_update_triton from .utils import create_test_inputs, clone_preserving_strides +def _get_algorithms(): + """Return list of algorithms supported on the current GPU.""" + major, _ = get_compute_capability(torch.device("cuda")) + algos = ["simple"] + if major >= 9: + algos.extend(["vertical", "horizontal"]) + return algos + + +# Base combination: batch=64, nheads=64, dim=64, dstate=128, state_dtype=bf16, +# weight_dtype=f32, use_out_tensor=True +# Each additional row varies exactly one parameter from the base. +# fmt: off +_BASE_PARAMS = [ + # (batch, nheads, dim, dstate, state_dtype, weight_dtype, use_out_tensor) + ( 64, 64, 64, 128, torch.bfloat16, torch.float32, True ), # base bf16 + ( 64, 64, 64, 128, torch.float32, torch.float32, True ), # state_dtype=f32 + ( 1, 64, 64, 128, torch.bfloat16, torch.float32, True ), # batch=1 + ( 64, 8, 64, 128, torch.bfloat16, torch.float32, True ), # nheads=8 + ( 64, 64, 128, 128, torch.bfloat16, torch.float32, True ), # dim=128 + ( 64, 64, 64, 64, torch.bfloat16, torch.float32, True ), # dstate=64 + ( 64, 64, 64, 256, torch.bfloat16, torch.float32, True ), # dstate=256 + ( 64, 64, 64, 128, torch.float16, torch.float32, True ), # state_dtype=f16 + ( 64, 64, 64, 128, torch.bfloat16, torch.bfloat16, True ), # weight_dtype=bf16 + ( 64, 64, 64, 128, torch.bfloat16, torch.float32, False), # use_out_tensor=False +] +# fmt: on + + class TestSelectiveStateUpdate: """Test class for selective state update kernels.""" @@ -18,36 +48,7 @@ class TestSelectiveStateUpdate: INPUT_DTYPE = torch.bfloat16 MATRIX_A_DTYPE = torch.float32 - @pytest.fixture(params=[1, 64]) - def batch(self, request): - return request.param - - @pytest.fixture(params=[8, 64]) - def nheads(self, request): - return request.param - - @pytest.fixture(params=[64, 128]) - def dim(self, request): - return request.param - - @pytest.fixture(params=[64, 128, 256]) - def dstate(self, request): - return request.param - - @pytest.fixture(params=[torch.float16, torch.bfloat16, torch.float32]) - def state_dtype(self, request): - return request.param - - @pytest.fixture(params=[torch.float32, torch.bfloat16]) - def weight_dtype(self, request): - return request.param - - @pytest.fixture(params=[False, True]) - def use_out_tensor(self, request): - return request.param - - @pytest.fixture - def inputs(self, batch, nheads, dim, dstate, state_dtype, weight_dtype): + def make_inputs(self, batch, nheads, dim, dstate, state_dtype, weight_dtype): """Create test inputs for given parameters.""" return create_test_inputs( batch, @@ -63,8 +64,7 @@ def inputs(self, batch, nheads, dim, dstate, state_dtype, weight_dtype): seed=0, ) - @pytest.fixture - def reference_output(self, inputs): + def make_reference_output(self, inputs): """Compute reference output using triton implementation.""" state_ref = inputs["state_cache"].clone() y_ref = selective_state_update_triton( @@ -83,7 +83,7 @@ def reference_output(self, inputs): ) return y_ref, state_ref - def run_kernel(self, inputs, out=None): + def run_kernel(self, inputs, out=None, algorithm="auto"): """Run the flashinfer kernel and return output.""" return flashinfer.mamba.selective_state_update( inputs["state_cache"], @@ -99,6 +99,7 @@ def run_kernel(self, inputs, out=None): state_batch_indices=inputs["slot_idx"], pad_slot_id=-1, out=out, + algorithm=algorithm, ) def assert_outputs_match(self, y_ref, y_test, msg_prefix=""): @@ -165,20 +166,32 @@ def _print_mismatch_details(self, ref, test, name): f"diff={diff:.6e}, rel_diff={rel_diff:.6e}" ) - def test_output_correctness(self, inputs, reference_output, use_out_tensor): + @pytest.mark.parametrize("algorithm", _get_algorithms()) + @pytest.mark.parametrize( + "batch,nheads,dim,dstate,state_dtype,weight_dtype,use_out_tensor", _BASE_PARAMS + ) + def test_output_correctness( + self, + batch, + nheads, + dim, + dstate, + state_dtype, + weight_dtype, + use_out_tensor, + algorithm, + ): """Test that kernel output matches reference within tolerance.""" - y_ref, state_ref = reference_output + inputs = self.make_inputs(batch, nheads, dim, dstate, state_dtype, weight_dtype) + y_ref, state_ref = self.make_reference_output(inputs) # Prepare output tensor if requested if use_out_tensor: - batch = inputs["x"].shape[0] - nheads = inputs["x"].shape[1] - dim = inputs["x"].shape[2] out = torch.empty(batch, nheads, dim, dtype=self.INPUT_DTYPE, device="cuda") else: out = None - y_test = self.run_kernel(inputs, out=out) + y_test = self.run_kernel(inputs, out=out, algorithm=algorithm) # Verify output tensor identity if provided if use_out_tensor: @@ -186,39 +199,19 @@ def test_output_correctness(self, inputs, reference_output, use_out_tensor): "Returned tensor should be the same object as the provided output tensor" ) - self.assert_outputs_match(y_ref, y_test) - self.assert_states_match(state_ref, inputs["state_cache"], inputs["slot_idx"]) + self.assert_outputs_match(y_ref, y_test, msg_prefix=f"[{algorithm}] ") + self.assert_states_match( + state_ref, + inputs["state_cache"], + inputs["slot_idx"], + msg_prefix=f"[{algorithm}] ", + ) class TestSelectiveStateUpdateWithZ(TestSelectiveStateUpdate): """Test selective_state_update with z tensor (gating).""" - @pytest.fixture(params=[1]) - def batch(self, request): - return request.param - - @pytest.fixture(params=[8]) - def nheads(self, request): - return request.param - - @pytest.fixture(params=[64]) - def dim(self, request): - return request.param - - @pytest.fixture(params=[128]) - def dstate(self, request): - return request.param - - @pytest.fixture(params=[torch.bfloat16]) - def state_dtype(self, request): - return request.param - - @pytest.fixture(params=[torch.bfloat16]) - def weight_dtype(self, request): - return request.param - - @pytest.fixture - def inputs(self, batch, nheads, dim, dstate, state_dtype, weight_dtype): + def make_inputs(self, batch, nheads, dim, dstate, state_dtype, weight_dtype): """Create test inputs with z tensor.""" return create_test_inputs( batch, @@ -234,35 +227,38 @@ def inputs(self, batch, nheads, dim, dstate, state_dtype, weight_dtype): seed=0, ) + @pytest.mark.parametrize("algorithm", _get_algorithms()) + @pytest.mark.parametrize( + "batch,nheads,dim,dstate,state_dtype,weight_dtype,use_out_tensor", + [(64, 64, 64, 128, torch.bfloat16, torch.float32, True)], + ) + def test_output_correctness( + self, + batch, + nheads, + dim, + dstate, + state_dtype, + weight_dtype, + use_out_tensor, + algorithm, + ): + super().test_output_correctness( + batch, + nheads, + dim, + dstate, + state_dtype, + weight_dtype, + use_out_tensor, + algorithm, + ) + class TestSelectiveStateUpdateDisableStateUpdate(TestSelectiveStateUpdate): """Test selective_state_update with disable_state_update=True.""" - @pytest.fixture(params=[1]) - def batch(self, request): - return request.param - - @pytest.fixture(params=[8]) - def nheads(self, request): - return request.param - - @pytest.fixture(params=[128]) - def dim(self, request): - return request.param - - @pytest.fixture(params=[64]) - def dstate(self, request): - return request.param - - @pytest.fixture(params=[torch.bfloat16]) - def state_dtype(self, request): - return request.param - - @pytest.fixture(params=[torch.bfloat16]) - def weight_dtype(self, request): - return request.param - - def run_kernel(self, inputs, out=None): + def run_kernel(self, inputs, out=None, algorithm="auto"): """Run the flashinfer kernel with disable_state_update=True.""" return flashinfer.mamba.selective_state_update( inputs["state_cache"], @@ -279,25 +275,38 @@ def run_kernel(self, inputs, out=None): pad_slot_id=-1, out=out, disable_state_update=True, + algorithm=algorithm, ) - def test_output_correctness(self, inputs, reference_output, use_out_tensor): + @pytest.mark.parametrize( + "batch,nheads,dim,dstate,state_dtype,weight_dtype,use_out_tensor", + [(64, 64, 64, 128, torch.bfloat16, torch.float32, True)], + ) + def test_output_correctness( + self, + batch, + nheads, + dim, + dstate, + state_dtype, + weight_dtype, + use_out_tensor, + algorithm="auto", + ): """Test that kernel output matches reference but state is not updated.""" - y_ref, state_ref = reference_output + inputs = self.make_inputs(batch, nheads, dim, dstate, state_dtype, weight_dtype) + y_ref, _state_ref = self.make_reference_output(inputs) # Save the initial state before running the kernel state_initial = inputs["state_cache"].clone() # Prepare output tensor if requested if use_out_tensor: - batch = inputs["x"].shape[0] - nheads = inputs["x"].shape[1] - dim = inputs["x"].shape[2] out = torch.empty(batch, nheads, dim, dtype=self.INPUT_DTYPE, device="cuda") else: out = None - y_test = self.run_kernel(inputs, out=out) + y_test = self.run_kernel(inputs, out=out, algorithm=algorithm) # Verify output tensor identity if provided if use_out_tensor: @@ -339,20 +348,7 @@ def test_output_correctness(self, inputs, reference_output, use_out_tensor): class TestSelectiveStateUpdateNonContiguous(TestSelectiveStateUpdate): """Test selective_state_update with non-contiguous state cache.""" - @pytest.fixture(params=[128]) - def dstate(self, request): - return request.param - - @pytest.fixture(params=[torch.bfloat16]) - def state_dtype(self, request): - return request.param - - @pytest.fixture(params=[torch.float32]) - def weight_dtype(self, request): - return request.param - - @pytest.fixture - def inputs(self, batch, nheads, dim, dstate, state_dtype, weight_dtype): + def make_inputs(self, batch, nheads, dim, dstate, state_dtype, weight_dtype): """Create test inputs with non-contiguous state cache (2x batch stride).""" noncontiguous_batch_stride = 2 * nheads * dim * dstate @@ -371,8 +367,7 @@ def inputs(self, batch, nheads, dim, dstate, state_dtype, weight_dtype): seed=0, ) - @pytest.fixture - def reference_output(self, inputs): + def make_reference_output(self, inputs): """Compute reference output, preserving non-contiguous strides.""" state_ref = clone_preserving_strides(inputs["state_cache"]) y_ref = selective_state_update_triton( @@ -391,35 +386,38 @@ def reference_output(self, inputs): ) return y_ref, state_ref + @pytest.mark.parametrize("algorithm", _get_algorithms()) + @pytest.mark.parametrize( + "batch,nheads,dim,dstate,state_dtype,weight_dtype,use_out_tensor", + [(64, 64, 64, 128, torch.bfloat16, torch.float32, True)], + ) + def test_output_correctness( + self, + batch, + nheads, + dim, + dstate, + state_dtype, + weight_dtype, + use_out_tensor, + algorithm, + ): + super().test_output_correctness( + batch, + nheads, + dim, + dstate, + state_dtype, + weight_dtype, + use_out_tensor, + algorithm, + ) + class TestSelectiveStateUpdateInt32Indices(TestSelectiveStateUpdate): """Test selective_state_update with int32 state_batch_indices.""" - @pytest.fixture(params=[1]) - def batch(self, request): - return request.param - - @pytest.fixture(params=[8]) - def nheads(self, request): - return request.param - - @pytest.fixture(params=[64]) - def dim(self, request): - return request.param - - @pytest.fixture(params=[128]) - def dstate(self, request): - return request.param - - @pytest.fixture(params=[torch.bfloat16]) - def state_dtype(self, request): - return request.param - - @pytest.fixture(params=[torch.bfloat16]) - def weight_dtype(self, request): - return request.param - - def run_kernel(self, inputs, out=None): + def run_kernel(self, inputs, out=None, algorithm="auto"): """Run the flashinfer kernel with int32 state_batch_indices.""" # Cast slot_idx to int32 slot_idx_int32 = inputs["slot_idx"].to(torch.int32) @@ -438,6 +436,34 @@ def run_kernel(self, inputs, out=None): state_batch_indices=slot_idx_int32, pad_slot_id=-1, out=out, + algorithm=algorithm, + ) + + @pytest.mark.parametrize("algorithm", _get_algorithms()) + @pytest.mark.parametrize( + "batch,nheads,dim,dstate,state_dtype,weight_dtype,use_out_tensor", + [(64, 64, 64, 128, torch.bfloat16, torch.float32, True)], + ) + def test_output_correctness( + self, + batch, + nheads, + dim, + dstate, + state_dtype, + weight_dtype, + use_out_tensor, + algorithm, + ): + super().test_output_correctness( + batch, + nheads, + dim, + dstate, + state_dtype, + weight_dtype, + use_out_tensor, + algorithm, ) diff --git a/tests/mamba/triton_reference/__init__.py b/tests/mamba/triton_reference/__init__.py new file mode 100644 index 0000000000..b7a96808f5 --- /dev/null +++ b/tests/mamba/triton_reference/__init__.py @@ -0,0 +1,6 @@ +""" +Triton reference implementations for Mamba kernels. + +This package contains production-level Triton implementations used as +reference for testing CUDA/CUTLASS kernel implementations. +""" diff --git a/tests/mamba/selective_state_update_triton.py b/tests/mamba/triton_reference/selective_state_update.py similarity index 97% rename from tests/mamba/selective_state_update_triton.py rename to tests/mamba/triton_reference/selective_state_update.py index c40f90612a..88d20bf251 100644 --- a/tests/mamba/selective_state_update_triton.py +++ b/tests/mamba/triton_reference/selective_state_update.py @@ -9,25 +9,10 @@ import torch import triton import triton.language as tl -from packaging import version -PAD_SLOT_ID = -1 - -TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0") - -if TRITON3: +from .softplus import softplus - @triton.jit - def softplus(dt): - dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt) - return dt - -else: - - @triton.jit - def softplus(dt): - dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt) - return dt +PAD_SLOT_ID = -1 @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) diff --git a/tests/mamba/triton_reference/softplus.py b/tests/mamba/triton_reference/softplus.py new file mode 100644 index 0000000000..b2ce0d58b2 --- /dev/null +++ b/tests/mamba/triton_reference/softplus.py @@ -0,0 +1,26 @@ +# Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/mamba_ssm.py + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright (c) 2024, Tri Dao, Albert Gu. + +import triton +import triton.language as tl +from packaging import version + +TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0") + +if TRITON3: + + @triton.jit + def softplus(dt): + dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt) + return dt + +else: + + @triton.jit + def softplus(dt): + dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt) + return dt