Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
d0a53b5
adopt reference implementation from sglang
ishovkun Jan 29, 2026
320a72c
Extract create_test_inputs to shared test_utils module
ishovkun Jan 29, 2026
4022f10
Rename test to reflect that it's a single-token test file
ishovkun Jan 29, 2026
a8bc286
Add multi-token support to the interface of selective_state_update
ishovkun Jan 29, 2026
2e70ea4
Refactor selective_state_update: add validation helpers and update param
ishovkun Jan 29, 2026
295ae56
Non-contiguous state
ishovkun Jan 29, 2026
5541624
Simplify code for template dispatching
ishovkun Jan 29, 2026
ab33cc1
Refactor dispatch logic in selective_state_update.cuh
ishovkun Jan 29, 2026
26271a9
Refactor pointer alignment checking away from the logic.
ishovkun Jan 29, 2026
f3f02f5
Support int32 and int64 state_batch_indices in selective_state_update
ishovkun Jan 29, 2026
1cb4ac7
Refactor Mamba selective state update kernel dispatch and add dtype
ishovkun Jan 30, 2026
3265bd5
Merge branch 'flashinfer-ai:main' into main
ishovkun Jan 30, 2026
9d6d35c
Fix simple stp kernel to only write state if a flag is provided
ishovkun Jan 30, 2026
5b5756d
Fix Triton kernel intermediate state caching to match CUDA behavior
ishovkun Jan 30, 2026
e3f751e
Merge branch 'main' of github.com:ishovkun/flashinfer-dev
ishovkun Jan 31, 2026
fb693d0
Add Mamba2 SSD chunk scan test and reorganize Triton refs
ishovkun Feb 3, 2026
0ce5d47
Merge branch 'main' of github.com:ishovkun/flashinfer-dev
ishovkun Feb 17, 2026
304fd59
Enable .jinja templates for mamba
ishovkun Feb 17, 2026
329bfd0
Remove SM100 module, unify SM90+ selective state update handling
ishovkun Feb 17, 2026
f464097
Add algorithm selection to selective_state_update kernels
ishovkun Feb 18, 2026
c65670c
Fix include order: config.inc before header in selective_state_update…
ishovkun Feb 18, 2026
44b6c25
Parallelize consumer warp loads in vertical SSU kernel
ishovkun Feb 18, 2026
eff403c
Reduce test combinations in SSU tests to base + independent deviations
ishovkun Feb 18, 2026
afc7c6a
Add algorithm parameter to selective_state_update tests
ishovkun Feb 19, 2026
74accb0
Merge branch 'flashinfer-ai:main' into main
ishovkun Feb 19, 2026
1d42007
Update selective_state_update instantiations to include SSUAlgorithm
ishovkun Feb 19, 2026
61d88bd
Clarify algorithm selection docstring in selective_state_update
ishovkun Feb 19, 2026
ead4943
Merge branch 'main' of github.com:ishovkun/flashinfer-dev
ishovkun Feb 19, 2026
6f6a3d7
Remove chunk scan combined kernels as they are irrelevant to this PR
ishovkun Feb 19, 2026
de96dd5
Remove ssd_chunk_state.py Triton reference implementation (irrelevant to
ishovkun Feb 19, 2026
4c30f07
Delete test_utils.py
ishovkun Feb 19, 2026
1f1c2f4
Suppress mypy false positive for gen_selective_state_update calls
ishovkun Feb 19, 2026
157ecb5
Move Triton reference kernel to triton_reference subdir and update
ishovkun Feb 19, 2026
f32b63b
mark an unused variable with "_" in a test
ishovkun Feb 19, 2026
2656202
rename an unused test variable to _state_ref
ishovkun Feb 19, 2026
5580d28
Refactor Triton reference import for selective_state_update
ishovkun Feb 19, 2026
58f56cd
Fixes aot compilation of the gdn_prefill_sm90 module
ishovkun Feb 20, 2026
5d8184e
Substantially reduce the number of SSU aot compilation units. Limited to
ishovkun Feb 20, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 12 additions & 41 deletions benchmarks/routines/mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,8 @@
limitations under the License.
"""

# ==============================================================================
# Triton reference implementation for selective_state_update.
# Imported from tests/mamba/selective_state_update_triton.py to avoid code
# duplication. See that file for the canonical Triton kernel source.
# ==============================================================================

import importlib
import os
import sys
from collections import defaultdict

import numpy as np
Expand All @@ -30,6 +24,14 @@
import flashinfer
from flashinfer.testing.utils import bench_gpu_time

# Add tests/mamba to sys.path so triton_reference is importable as a package
_repo_root = os.path.normpath(
os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..")
)
_tests_mamba = os.path.join(_repo_root, "tests", "mamba")
if _tests_mamba not in sys.path:
sys.path.insert(0, _tests_mamba)

from .flashinfer_benchmark_utils import (
dtype_str_to_torch_dtype,
get_device,
Expand All @@ -38,40 +40,9 @@
filter_backends_by_compute_capability,
)

# ---- Import Triton reference kernel from tests/mamba/ ----
# The canonical Triton selective_state_update lives in tests/mamba/selective_state_update_triton.py.
# We import it here rather than duplicating ~400 lines of kernel code.


def _import_triton_reference():
"""Import selective_state_update_triton from tests/mamba/.

Uses importlib to load the module directly by file path, avoiding sys.path
pollution and fragile relative path assumptions.
"""
# Resolve path: benchmarks/routines/mamba.py -> ../../tests/mamba/selective_state_update_triton.py
_this_dir = os.path.dirname(os.path.abspath(__file__))
_repo_root = os.path.normpath(os.path.join(_this_dir, "..", ".."))
_triton_ref_path = os.path.join(
_repo_root, "tests", "mamba", "selective_state_update_triton.py"
)

if not os.path.isfile(_triton_ref_path):
raise ImportError(
f"Cannot find Triton reference kernel at: {_triton_ref_path}\n"
f"Expected location: <repo>/tests/mamba/selective_state_update_triton.py\n"
f"Make sure you are running from within the FlashInfer repository."
)

spec = importlib.util.spec_from_file_location(
"selective_state_update_triton", _triton_ref_path
)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module.selective_state_update_triton


selective_state_update_triton_reference = _import_triton_reference()
from triton_reference.selective_state_update import (
selective_state_update_triton as selective_state_update_triton_reference,
)


# ==============================================================================
Expand Down
3 changes: 2 additions & 1 deletion csrc/flashinfer_mamba_binding.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ void selective_state_update(
bool disable_state_update,
Optional<TensorView> intermediate_states_buffer, // (batch, cache_steps, nheads, dim, dstate)
Optional<TensorView> intermediate_state_indices, // (batch,)
int64_t cache_steps);
int64_t cache_steps,
int64_t algorithm); // SSUAlgorithm: 0=auto, 1=simple, 2=vertical, 3=horizontal

} // namespace flashinfer::mamba

Expand Down
240 changes: 30 additions & 210 deletions csrc/selective_state_update.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,13 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// clang-format off
// config.inc MUST come before the header: it defines DIM, DSTATE, NTOKENS_MTP
// constexprs that the header's function templates rely on. Reordering breaks compilation.
// NOTE: the .inc file is generated from the jinja templates
#include "selective_state_update_config.inc"
#include <flashinfer/mamba/selective_state_update.cuh>
#include <sstream>

// clang-format on
#include "tvm_ffi_utils.h"

using namespace flashinfer;
Expand Down Expand Up @@ -124,87 +128,13 @@ inline void validate_dtype_consistency(
}
}

// Helper to convert dtype code to string for error messages
inline const char* dtype_code_to_string(int64_t code) {
if (code == bfloat16_code) return "bfloat16";
if (code == float16_code) return "float16";
if (code == float32_code) return "float32";
return "unknown";
}

// Type traits to map dtype codes to C++ types
template <int64_t code>
struct DTypeToType;

template <>
struct DTypeToType<bfloat16_code> {
using type = nv_bfloat16;
};
template <>
struct DTypeToType<float16_code> {
using type = half;
};
template <>
struct DTypeToType<float32_code> {
using type = float;
};
template <>
struct DTypeToType<int32_code> {
using type = int32_t;
};
template <>
struct DTypeToType<int64_code> {
using type = int64_t;
};

// Allowed dtype combinations: {state_code, input_code, weight_code, matrixA_code, stateIndex_code}
constexpr std::tuple<int64_t, int64_t, int64_t, int64_t, int64_t> allowed_dtype_combos[] = {
{bfloat16_code, bfloat16_code, bfloat16_code, float32_code, int32_code},
{float16_code, bfloat16_code, bfloat16_code, float32_code, int32_code},
{float32_code, bfloat16_code, bfloat16_code, float32_code, int32_code},
{bfloat16_code, bfloat16_code, float32_code, float32_code, int32_code},
{float16_code, bfloat16_code, float32_code, float32_code, int32_code},
{float32_code, bfloat16_code, float32_code, float32_code, int32_code},
{bfloat16_code, bfloat16_code, bfloat16_code, float32_code, int64_code},
{float16_code, bfloat16_code, bfloat16_code, float32_code, int64_code},
{float32_code, bfloat16_code, bfloat16_code, float32_code, int64_code},
{bfloat16_code, bfloat16_code, float32_code, float32_code, int64_code},
{float16_code, bfloat16_code, float32_code, float32_code, int64_code},
{float32_code, bfloat16_code, float32_code, float32_code, int64_code},
};

// Helper to dispatch to the right template instantiation for STP
template <int64_t state_code, int64_t input_code, int64_t weight_code, int64_t matrixA_code,
int64_t stateIndex_code>
void dispatchCombo(SelectiveStateUpdateParams& p, cudaStream_t stream) {
using state_t = typename DTypeToType<state_code>::type;
using input_t = typename DTypeToType<input_code>::type;
using weight_t = typename DTypeToType<weight_code>::type;
using matrixA_t = typename DTypeToType<matrixA_code>::type;
using stateIndex_t = typename DTypeToType<stateIndex_code>::type;
invokeSelectiveStateUpdate<input_t, weight_t, matrixA_t, state_t, stateIndex_t>(p, stream);
}

// Helper to dispatch to the right template instantiation for MTP
template <int64_t state_code, int64_t input_code, int64_t weight_code, int64_t matrixA_code,
int64_t stateIndex_code>
void dispatchComboMTP(mtp::SelectiveStateMTPParams& p, cudaStream_t stream) {
using state_t = typename DTypeToType<state_code>::type;
using input_t = typename DTypeToType<input_code>::type;
using weight_t = typename DTypeToType<weight_code>::type;
using matrixA_t = typename DTypeToType<matrixA_code>::type;
using stateIndex_t = typename DTypeToType<stateIndex_code>::type;
mtp::invokeSelectiveStateUpdateMTP<input_t, weight_t, matrixA_t, state_t, stateIndex_t>(p,
stream);
}

void run_selective_state_update_stp(TensorView const& state, TensorView const& x,
TensorView const& dt, TensorView const& A, TensorView const& B,
TensorView const& C, TensorView const& D,
Optional<TensorView> z, Optional<TensorView> dt_bias,
bool dt_softplus, Optional<TensorView> state_batch_indices,
int64_t pad_slot_id, Optional<TensorView> out,
bool disable_state_update) {
bool disable_state_update, int64_t algorithm) {
// Extract dimensions from input tensors
auto const batch = x.size(0);
auto const state_cache_size = state.size(0);
Expand Down Expand Up @@ -344,64 +274,8 @@ void run_selective_state_update_stp(TensorView const& state, TensorView const& x
ffi::CUDADeviceGuard device_guard(state.device().device_id);
const cudaStream_t stream = get_stream(state.device());

// Dispatch based on dtype combination
DLDataType state_dtype = state.dtype();
DLDataType input_dtype = x.dtype();
DLDataType weight_dtype = dt.dtype();
DLDataType matrixA_dtype = A.dtype();
int64_t state_dtype_code = encode_dlpack_dtype(state_dtype);
int64_t input_dtype_code = encode_dlpack_dtype(input_dtype);
int64_t weight_dtype_code = encode_dlpack_dtype(weight_dtype);
int64_t matrixA_dtype_code = encode_dlpack_dtype(matrixA_dtype);

// Get state_batch_indices dtype, default to int32 if not provided
int64_t stateIndex_dtype_code = int32_code;
if (state_batch_indices.has_value()) {
DLDataType stateIndex_dtype = state_batch_indices.value().dtype();
stateIndex_dtype_code = encode_dlpack_dtype(stateIndex_dtype);
}

// Dispatch kernel based on dtype combination
auto dtype_key = std::make_tuple(state_dtype_code, input_dtype_code, weight_dtype_code,
matrixA_dtype_code, stateIndex_dtype_code);

// Compile-time recursive dispatcher using Y-combinator pattern for lambda self-recursion
auto tryDispatch = [&](const auto& key, auto idx, auto& self) -> bool {
constexpr size_t I = decltype(idx)::value;
if constexpr (I < std::size(allowed_dtype_combos)) {
constexpr auto combo = allowed_dtype_combos[I];
if (key == combo) {
constexpr auto s = std::get<0>(combo);
constexpr auto i = std::get<1>(combo);
constexpr auto w = std::get<2>(combo);
constexpr auto m = std::get<3>(combo);
constexpr auto si = std::get<4>(combo);
dispatchCombo<s, i, w, m, si>(p, stream);
return true;
}
return self(key, std::integral_constant<size_t, I + 1>{}, self);
}
return false;
};

// Dispatch using compile-time type traits
if (!tryDispatch(dtype_key, std::integral_constant<size_t, 0>{}, tryDispatch)) {
// Unsupported dtype combination - build error message dynamically
std::ostringstream error_msg;
error_msg << "Unsupported dtype combination for selective_state_update: " << "state_dtype="
<< state_dtype.code << ":" << state_dtype.bits << ", "
<< "input_dtype=" << input_dtype.code << ":" << input_dtype.bits << ", "
<< "weight_dtype=" << weight_dtype.code << ":" << weight_dtype.bits << ", "
<< "matrixA_dtype=" << matrixA_dtype.code << ":" << matrixA_dtype.bits
<< ". Supported combos include:\n";
for (const auto& combo : allowed_dtype_combos) {
error_msg << " (state=" << dtype_code_to_string(std::get<0>(combo))
<< ", input=" << dtype_code_to_string(std::get<1>(combo))
<< ", weight=" << dtype_code_to_string(std::get<2>(combo))
<< ", matrixA=" << dtype_code_to_string(std::get<3>(combo)) << ")\n";
}
TVM_FFI_ICHECK(false) << error_msg.str();
}
auto algo = static_cast<SSUAlgorithm>(algorithm);
invokeSelectiveStateUpdate<input_t, weight_t, matrixA_t, state_t, stateIndex_t>(p, algo, stream);
}

void run_selective_state_update_mtp(
Expand All @@ -410,7 +284,7 @@ void run_selective_state_update_mtp(
Optional<TensorView> dt_bias, bool dt_softplus, Optional<TensorView> state_batch_indices,
int64_t pad_slot_id, Optional<TensorView> out, bool disable_state_update,
Optional<TensorView> intermediate_states_buffer,
Optional<TensorView> intermediate_state_indices, int64_t cache_steps) {
Optional<TensorView> intermediate_state_indices, int64_t cache_steps, int64_t algorithm) {
// Extract dimensions from input tensors
auto const batch = x.size(0);
auto const ntokens_mtp = x.size(1);
Expand Down Expand Up @@ -505,6 +379,15 @@ void run_selective_state_update_mtp(
validate_intermediate_state_indices(intermediate_state_indices, batch);
validate_intermediate_states_buffer(intermediate_states_buffer);

// Validate that state_batch_indices and intermediate_state_indices have the same dtype
if (state_batch_indices.has_value() && intermediate_state_indices.has_value()) {
DLDataType state_batch_idx_dtype = state_batch_indices.value().dtype();
DLDataType intermediate_idx_dtype = intermediate_state_indices.value().dtype();
FLASHINFER_CHECK(state_batch_idx_dtype.code == intermediate_idx_dtype.code &&
state_batch_idx_dtype.bits == intermediate_idx_dtype.bits,
"state_batch_indices and intermediate_state_indices must have the same dtype");
}

// Validate cache_steps is non-negative
FLASHINFER_CHECK(cache_steps >= 0, "cache_steps must be non-negative, got ", cache_steps);

Expand Down Expand Up @@ -588,75 +471,9 @@ void run_selective_state_update_mtp(
ffi::CUDADeviceGuard device_guard(state.device().device_id);
const cudaStream_t stream = get_stream(state.device());

// Dispatch based on dtype combination
DLDataType state_dtype = state.dtype();
DLDataType input_dtype = x.dtype();
DLDataType weight_dtype = dt.dtype();
DLDataType matrixA_dtype = A.dtype();
int64_t state_dtype_code = encode_dlpack_dtype(state_dtype);
int64_t input_dtype_code = encode_dlpack_dtype(input_dtype);
int64_t weight_dtype_code = encode_dlpack_dtype(weight_dtype);
int64_t matrixA_dtype_code = encode_dlpack_dtype(matrixA_dtype);

// Get stateIndex dtype from whichever index tensor is available
// If both are provided, they must have the same dtype
int64_t stateIndex_dtype_code = int32_code; // default
if (state_batch_indices.has_value() && intermediate_state_indices.has_value()) {
DLDataType state_batch_idx_dtype = state_batch_indices.value().dtype();
DLDataType intermediate_idx_dtype = intermediate_state_indices.value().dtype();
FLASHINFER_CHECK(state_batch_idx_dtype.code == intermediate_idx_dtype.code &&
state_batch_idx_dtype.bits == intermediate_idx_dtype.bits,
"state_batch_indices and intermediate_state_indices must have the same dtype");
stateIndex_dtype_code = encode_dlpack_dtype(state_batch_idx_dtype);
} else if (state_batch_indices.has_value()) {
DLDataType state_batch_idx_dtype = state_batch_indices.value().dtype();
stateIndex_dtype_code = encode_dlpack_dtype(state_batch_idx_dtype);
} else if (intermediate_state_indices.has_value()) {
DLDataType intermediate_idx_dtype = intermediate_state_indices.value().dtype();
stateIndex_dtype_code = encode_dlpack_dtype(intermediate_idx_dtype);
}

// Dispatch kernel based on dtype combination
auto dtype_key = std::make_tuple(state_dtype_code, input_dtype_code, weight_dtype_code,
matrixA_dtype_code, stateIndex_dtype_code);

// Compile-time recursive dispatcher using Y-combinator pattern for lambda self-recursion
auto tryDispatch = [&](const auto& key, auto idx, auto& self) -> bool {
constexpr size_t I = decltype(idx)::value;
if constexpr (I < std::size(allowed_dtype_combos)) {
constexpr auto combo = allowed_dtype_combos[I];
if (key == combo) {
constexpr auto s = std::get<0>(combo);
constexpr auto i = std::get<1>(combo);
constexpr auto w = std::get<2>(combo);
constexpr auto m = std::get<3>(combo);
constexpr auto si = std::get<4>(combo);
dispatchComboMTP<s, i, w, m, si>(p, stream);
return true;
}
return self(key, std::integral_constant<size_t, I + 1>{}, self);
}
return false;
};

// Dispatch using compile-time type traits
if (!tryDispatch(dtype_key, std::integral_constant<size_t, 0>{}, tryDispatch)) {
// Unsupported dtype combination - build error message dynamically
std::ostringstream error_msg;
error_msg << "Unsupported dtype combination for selective_state_update: " << "state_dtype="
<< state_dtype.code << ":" << state_dtype.bits << ", "
<< "input_dtype=" << input_dtype.code << ":" << input_dtype.bits << ", "
<< "weight_dtype=" << weight_dtype.code << ":" << weight_dtype.bits << ", "
<< "matrixA_dtype=" << matrixA_dtype.code << ":" << matrixA_dtype.bits
<< ". Supported combos include:\n";
for (const auto& combo : allowed_dtype_combos) {
error_msg << " (state=" << dtype_code_to_string(std::get<0>(combo))
<< ", input=" << dtype_code_to_string(std::get<1>(combo))
<< ", weight=" << dtype_code_to_string(std::get<2>(combo))
<< ", matrixA=" << dtype_code_to_string(std::get<3>(combo)) << ")\n";
}
TVM_FFI_ICHECK(false) << error_msg.str();
}
auto algo = static_cast<SSUAlgorithm>(algorithm);
mtp::invokeSelectiveStateUpdateMTP<input_t, weight_t, matrixA_t, state_t, stateIndex_t>(p, algo,
stream);
}

// =============================================================================
Expand All @@ -668,14 +485,17 @@ void selective_state_update(TensorView state, TensorView x, TensorView dt, Tenso
Optional<TensorView> state_batch_indices, int64_t pad_slot_id,
TensorView output, bool disable_state_update,
Optional<TensorView> intermediate_states_buffer,
Optional<TensorView> intermediate_state_indices, int64_t cache_steps) {
Optional<TensorView> intermediate_state_indices, int64_t cache_steps,
int64_t algorithm) {
if (x.dim() == 3) {
run_selective_state_update_stp(state, x, dt, A, B, C, D, z, dt_bias, dt_softplus,
state_batch_indices, pad_slot_id, output, disable_state_update);
state_batch_indices, pad_slot_id, output, disable_state_update,
algorithm);
} else if (x.dim() == 4) {
run_selective_state_update_mtp(
state, x, dt, A, B, C, D, z, dt_bias, dt_softplus, state_batch_indices, pad_slot_id, output,
disable_state_update, intermediate_states_buffer, intermediate_state_indices, cache_steps);
run_selective_state_update_mtp(state, x, dt, A, B, C, D, z, dt_bias, dt_softplus,
state_batch_indices, pad_slot_id, output, disable_state_update,
intermediate_states_buffer, intermediate_state_indices,
cache_steps, algorithm);
} else {
FLASHINFER_CHECK(false,
"x must have 3 dimensions (single-token) or 4 dimensions (multi-token), got ",
Expand Down
14 changes: 14 additions & 0 deletions csrc/selective_state_update_customize_config.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#pragma once
// Generated configuration header for the selective_state_update kernels.
// The {{ ... }} placeholders are substituted by the jinja codegen step;
// the resulting .inc file must be included BEFORE selective_state_update.cuh,
// whose function templates rely on the aliases and constexprs defined here.
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cstdint>

// Element-type aliases consumed by the kernel templates.
// Expected substitutions are CUDA scalar types such as nv_bfloat16, half,
// float, int32_t, int64_t (see the dtype combos dispatched by the caller).
using state_t = {{ state_dtype }};
using input_t = {{ input_dtype }};
using weight_t = {{ weight_dtype }};
using matrixA_t = {{ matrixA_dtype }};
// Index type for state_batch_indices / intermediate_state_indices
// (int32_t or int64_t).
using stateIndex_t = {{ stateIndex_dtype }};

// Compile-time problem dimensions baked into this translation unit.
constexpr int DIM = {{ dim }};
constexpr int DSTATE = {{ dstate }};
// Number of tokens handled per step in the multi-token (MTP) variant.
constexpr int NTOKENS_MTP = {{ ntokens_mtp }};
Loading
Loading