Merged
Commits
69 commits
d0a53b5
adopt reference implementation from sglang
ishovkun Jan 29, 2026
320a72c
Extract create_test_inputs to shared test_utils module
ishovkun Jan 29, 2026
4022f10
Rename test to reflect that it's a single-token test file
ishovkun Jan 29, 2026
a8bc286
Add multi-token support to the interface of selective_state_update
ishovkun Jan 29, 2026
2e70ea4
Refactor selective_state_update: add validation helpers and update param
ishovkun Jan 29, 2026
295ae56
Non-contiguous state
ishovkun Jan 29, 2026
5541624
Simplify code for template dispatching
ishovkun Jan 29, 2026
ab33cc1
Refactor dispatch logic in selective_state_update.cuh
ishovkun Jan 29, 2026
26271a9
Refactor pointer alignment checking away from the logic.
ishovkun Jan 29, 2026
f3f02f5
Support int32 and int64 state_batch_indices in selective_state_update
ishovkun Jan 29, 2026
1cb4ac7
Refactor Mamba selective state update kernel dispatch and add dtype
ishovkun Jan 30, 2026
3265bd5
Merge branch 'flashinfer-ai:main' into main
ishovkun Jan 30, 2026
9d6d35c
Fix simple stp kernel to only write state if a flag is provided
ishovkun Jan 30, 2026
5b5756d
Fix Triton kernel intermediate state caching to match CUDA behavior
ishovkun Jan 30, 2026
e3f751e
Merge branch 'main' of github.com:ishovkun/flashinfer-dev
ishovkun Jan 31, 2026
fb693d0
Add Mamba2 SSD chunk scan test and reorganize Triton refs
ishovkun Feb 3, 2026
0ce5d47
Merge branch 'main' of github.com:ishovkun/flashinfer-dev
ishovkun Feb 17, 2026
304fd59
Enable .jinja templates for mamba
ishovkun Feb 17, 2026
329bfd0
Remove SM100 module, unify SM90+ selective state update handling
ishovkun Feb 17, 2026
f464097
Add algorithm selection to selective_state_update kernels
ishovkun Feb 18, 2026
c65670c
Fix include order: config.inc before header in selective_state_update…
ishovkun Feb 18, 2026
44b6c25
Parallelize consumer warp loads in vertical SSU kernel
ishovkun Feb 18, 2026
eff403c
Reduce test combinations in SSU tests to base + independent deviations
ishovkun Feb 18, 2026
afc7c6a
Add algorithm parameter to selective_state_update tests
ishovkun Feb 19, 2026
74accb0
Merge branch 'flashinfer-ai:main' into main
ishovkun Feb 19, 2026
1d42007
Update selective_state_update instantiations to include SSUAlgorithm
ishovkun Feb 19, 2026
61d88bd
Clarify algorithm selection docstring in selective_state_update
ishovkun Feb 19, 2026
ead4943
Merge branch 'main' of github.com:ishovkun/flashinfer-dev
ishovkun Feb 19, 2026
6f6a3d7
Remove chunk scan combined kernels as they are irrelevant to this PR
ishovkun Feb 19, 2026
de96dd5
Remove ssd_chunk_state.py Triton reference implementation (irrelevant to
ishovkun Feb 19, 2026
4c30f07
Delete test_utils.py
ishovkun Feb 19, 2026
1f1c2f4
Suppress mypy false positive for gen_selective_state_update calls
ishovkun Feb 19, 2026
157ecb5
Move Triton reference kernel to triton_reference subdir and update
ishovkun Feb 19, 2026
f32b63b
mark an unused variable with "_" in a test
ishovkun Feb 19, 2026
2656202
rename an unused test variable to _state_ref
ishovkun Feb 19, 2026
5580d28
Refactor Triton reference import for selective_state_update
ishovkun Feb 19, 2026
8738964
Add int16 state quantization with block scaling to
ishovkun Feb 19, 2026
02db096
Add int16 quantized state support to selective_state_update
ishovkun Feb 20, 2026
58f56cd
Fixes aot compilation of the gdn_prefill_sm90 module
ishovkun Feb 20, 2026
d4e33de
Merge branch 'main' into ssu_int16
ishovkun Feb 20, 2026
5d8184e
Substantially reduce the number of SSU aot compilation units. Limited to
ishovkun Feb 20, 2026
9775391
Merge branch 'main' into ssu_int16
ishovkun Feb 20, 2026
7f1173f
Add int16 support for block scaling in selective_state_update kernel
ishovkun Feb 20, 2026
35cc7ba
Add int16 block scaling support to selective_state_update MTP
ishovkun Feb 20, 2026
6cf61b7
Fix rNewState array size calculation for scaleState flag
ishovkun Feb 20, 2026
e9ab619
Refactor selective_state_update to use state_scale dtype
ishovkun Feb 23, 2026
60b627e
Add Philox-4x32 PRNG matching Triton tl.randint and tests
ishovkun Feb 24, 2026
b873d10
Refactor philox_randint to template and add rounding tests
ishovkun Feb 24, 2026
3292662
Stochastic rounding support for fp16 state update (plumbing)
ishovkun Feb 25, 2026
b206a5f
Implement stochastic rounding for fp16 state in selective_state_update
ishovkun Feb 25, 2026
c70efbd
Optimize Philox PRNG usage in selective_state_update kernel
ishovkun Feb 26, 2026
181d80d
Fix Philox random offset calculation for state updates
ishovkun Feb 26, 2026
fd1af7c
Remove .plans directory from .gitignore
ishovkun Feb 26, 2026
ff8dfde
Merge branch 'ssu_int16': int16 block-scaled state and stochastic rou…
ishovkun Feb 26, 2026
60bbb5d
Merge remote-tracking branch 'upstream/main'
ishovkun Feb 26, 2026
c01eced
Replace asserts with if checks in the python wrapper
ishovkun Feb 27, 2026
0bf77aa
Remove redundant dtype check for state_batch_indices and
ishovkun Feb 27, 2026
deb48a8
Use tuples instead of lists for parameter sets in tests
ishovkun Feb 27, 2026
e1d9dc3
Fix selective_state_update argument order for state_scale_dtype
ishovkun Feb 27, 2026
b30db63
Replace float pointer casts with __float_as_uint in conversion kernels
ishovkun Feb 27, 2026
317f6bb
Handle zero max value in state scaling calculations
ishovkun Feb 27, 2026
852b9a2
Add static_assert for fp16 state in SR branch to check that an edge case
ishovkun Feb 27, 2026
2ca355e
`if not philox_rounds > 0:` → `if philox_rounds <= 0:` — same semantics,
ishovkun Feb 27, 2026
fa95f9d
Fix SM gencode flag to match current device compute capability (fixed
ishovkun Mar 2, 2026
3b20f6e
Fix Triton reference to skip stochastic rounding on pre-SM100a GPUs
ishovkun Mar 2, 2026
f93b325
Rename unused state_dtype and kwargs parameters to _state_dtype and
ishovkun Mar 2, 2026
6b65cff
Change _SR_PARAMS from list to tuple in test_selective_state_update_stp
ishovkun Mar 2, 2026
4cb8c40
Restore issue-claim.yml accidentally deleted during rebase
ishovkun Mar 3, 2026
10a36c4
Refactor selective_state_update to use device-side rand_seed tensor
ishovkun Mar 3, 2026
4 changes: 4 additions & 0 deletions .gitignore
@@ -24,6 +24,10 @@ flashinfer/cute_dsl/benchmark_gated_delta_rule.py
# vscode
.vscode/

# zed text editor
.zed/
yzh119 marked this conversation as resolved.
.rules

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
5 changes: 4 additions & 1 deletion csrc/flashinfer_mamba_binding.cu
@@ -38,10 +38,13 @@ void selective_state_update(
bool dt_softplus,
Optional<TensorView> state_batch_indices, // (batch,)
int64_t pad_slot_id,
TensorView output, // same as x
Optional<TensorView> state_scale, // float32: (state_cache_size, nheads, dim)
TensorView output, // same as x
bool disable_state_update,
Optional<TensorView> intermediate_states_buffer, // (batch, cache_steps, nheads, dim, dstate)
Optional<TensorView> intermediate_state_indices, // (batch,)
Optional<TensorView> intermediate_state_scales, // float32: (batch, cache_steps, nheads, dim)
Optional<TensorView> rand_seed, // device-side int64 tensor for Philox rounding
int64_t cache_steps,
int64_t algorithm); // SSUAlgorithm: 0=auto, 1=simple, 2=vertical, 3=horizontal

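The binding passes the algorithm selector as a bare `int64_t` whose values are documented only in the trailing comment. A hypothetical Python-side mirror of that mapping (the enum name and its location are assumptions for illustration, not part of this PR) could look like:

```python
from enum import IntEnum


class SSUAlgorithm(IntEnum):
    """Mirrors the int64 `algorithm` argument of selective_state_update:
    0 lets the dispatcher choose; 1-3 force a specific kernel layout."""

    AUTO = 0
    SIMPLE = 1
    VERTICAL = 2
    HORIZONTAL = 3
```

A caller would then pass e.g. `algorithm=int(SSUAlgorithm.VERTICAL)` instead of a magic `2`.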
106 changes: 83 additions & 23 deletions csrc/selective_state_update.cu
@@ -16,7 +16,7 @@
// clang-format off
// config.inc MUST come before the header: it defines DIM, DSTATE, NTOKENS_MTP
// constexprs that the header's function templates rely on. Reordering breaks compilation.
// NOTE: the .inc file is generated from the jinja templates
// NOTE: the .inc file is generated from the jinja template csrc/selective_state_update_customize_config.jinja
#include "selective_state_update_config.inc"
#include <flashinfer/mamba/selective_state_update.cuh>
// clang-format on
@@ -99,6 +99,22 @@ inline void validate_intermediate_states_buffer(
CHECK_CONTIGUOUS(intermediate_states_buffer.value());
}

inline void validate_state_scale(Optional<TensorView> const& state_scale, int64_t state_cache_size,
int64_t nheads, int64_t dim) {
if (!state_scale.has_value()) return;
auto const& scale = state_scale.value();
CHECK_CUDA(scale);
CHECK_DIM(3, scale); // state_scale: {state_cache_size, nheads, dim}
FLASHINFER_CHECK(scale.size(0) == state_cache_size,
"state_scale.size(0) must equal state_cache_size");
FLASHINFER_CHECK(scale.size(1) == nheads, "state_scale.size(1) must equal nheads");
FLASHINFER_CHECK(scale.size(2) == dim, "state_scale.size(2) must equal dim");
// Inner dims (nheads, dim) must be contiguous
FLASHINFER_CHECK(scale.stride(2) == 1, "state_scale.stride(2) must be 1, got ", scale.stride(2));
FLASHINFER_CHECK(scale.stride(1) == dim, "state_scale.stride(1) must equal dim, got ",
scale.stride(1));
}
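The layout contract enforced above — 3-D, exact sizes, contiguous `(nheads, dim)` inner dims, with the outer batch stride left free so a larger state cache can be sliced — can be sketched host-side. This standalone checker over plain `(shape, strides)` tuples is illustrative, not code from the PR:

```python
def check_state_scale_layout(shape, strides, state_cache_size, nheads, dim):
    """Replicates validate_state_scale's contract: 3-D tensor of shape
    (state_cache_size, nheads, dim) whose inner two dims are contiguous.
    strides[0] may exceed nheads * dim (e.g. a row-sliced cache)."""
    if len(shape) != 3:
        raise ValueError("state_scale must be 3-D")
    if tuple(shape) != (state_cache_size, nheads, dim):
        raise ValueError(
            f"expected shape {(state_cache_size, nheads, dim)}, got {tuple(shape)}")
    if strides[2] != 1:
        raise ValueError(f"state_scale.stride(2) must be 1, got {strides[2]}")
    if strides[1] != dim:
        raise ValueError(f"state_scale.stride(1) must equal dim, got {strides[1]}")
    return True
```

Note the deliberate asymmetry: only the inner strides are pinned, so a non-contiguous batch stride (as in the "Non-contiguous state" commit) still passes.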

// Validates dtype consistency across tensors
inline void validate_dtype_consistency(
TensorView const& state, TensorView const& dt, TensorView const& D, TensorView const& x,
@@ -133,8 +149,9 @@ void run_selective_state_update_stp(TensorView const& state, TensorView const& x
TensorView const& C, TensorView const& D,
Optional<TensorView> z, Optional<TensorView> dt_bias,
bool dt_softplus, Optional<TensorView> state_batch_indices,
int64_t pad_slot_id, Optional<TensorView> out,
bool disable_state_update, int64_t algorithm) {
Optional<TensorView> state_scale, int64_t pad_slot_id,
Optional<TensorView> out, bool disable_state_update,
Optional<TensorView> rand_seed, int64_t algorithm) {
// Extract dimensions from input tensors
auto const batch = x.size(0);
auto const state_cache_size = state.size(0);
@@ -219,6 +236,7 @@ void run_selective_state_update_stp(TensorView const& state, TensorView const& x

// Validate dtype consistency
validate_dtype_consistency(state, dt, D, x, B, C, dt_bias, z, out);
validate_state_scale(state_scale, state_cache_size, nheads, dim);

// Initialize params struct
SelectiveStateUpdateParams p;
@@ -248,6 +266,18 @@ void run_selective_state_update_stp(TensorView const& state, TensorView const& x
if (state_batch_indices.has_value()) {
p.state_batch_indices = const_cast<void*>(state_batch_indices.value().data_ptr());
}
if (state_scale.has_value()) {
p.state_scale = state_scale.value().data_ptr();
p.state_scale_stride_batch = state_scale.value().stride(0);
}
if (rand_seed.has_value()) {
auto const& rs = rand_seed.value();
CHECK_CUDA(rs);
FLASHINFER_CHECK(rs.numel() == 1,
"rand_seed must be a single-element tensor, got numel=", rs.numel());
FLASHINFER_CHECK(rs.dtype().code == kDLInt && rs.dtype().bits == 64, "rand_seed must be int64");
p.rand_seed = static_cast<const int64_t*>(rs.data_ptr());
}

// Copy pointers
p.state = const_cast<void*>(state.data_ptr());
@@ -275,16 +305,18 @@ void run_selective_state_update_stp(TensorView const& state, TensorView const& x
const cudaStream_t stream = get_stream(state.device());

auto algo = static_cast<SSUAlgorithm>(algorithm);
invokeSelectiveStateUpdate<input_t, weight_t, matrixA_t, state_t, stateIndex_t>(p, algo, stream);
invokeSelectiveStateUpdate<input_t, weight_t, matrixA_t, state_t, stateIndex_t, state_scale_t>(
p, algo, stream);
}

void run_selective_state_update_mtp(
TensorView const& state, TensorView const& x, TensorView const& dt, TensorView const& A,
TensorView const& B, TensorView const& C, TensorView const& D, Optional<TensorView> z,
Optional<TensorView> dt_bias, bool dt_softplus, Optional<TensorView> state_batch_indices,
int64_t pad_slot_id, Optional<TensorView> out, bool disable_state_update,
Optional<TensorView> intermediate_states_buffer,
Optional<TensorView> intermediate_state_indices, int64_t cache_steps, int64_t algorithm) {
Optional<TensorView> state_scale, int64_t pad_slot_id, Optional<TensorView> out,
bool disable_state_update, Optional<TensorView> intermediate_states_buffer,
Optional<TensorView> intermediate_state_indices, Optional<TensorView> intermediate_state_scales,
Optional<TensorView> rand_seed, int64_t cache_steps, int64_t algorithm) {
// Extract dimensions from input tensors
auto const batch = x.size(0);
auto const ntokens_mtp = x.size(1);
@@ -378,6 +410,7 @@ void run_selective_state_update_mtp(
validate_dtype_consistency(state, dt, D, x, B, C, dt_bias, z, out, intermediate_states_buffer);
validate_intermediate_state_indices(intermediate_state_indices, batch);
validate_intermediate_states_buffer(intermediate_states_buffer);
validate_state_scale(state_scale, state_cache_size, nheads, dim);

// Validate that state_batch_indices and intermediate_state_indices have the same dtype
if (state_batch_indices.has_value() && intermediate_state_indices.has_value()) {
@@ -435,6 +468,10 @@ void run_selective_state_update_mtp(
if (state_batch_indices.has_value()) {
p.state_batch_indices = const_cast<void*>(state_batch_indices.value().data_ptr());
}
if (state_scale.has_value()) {
p.state_scale = state_scale.value().data_ptr();
p.state_scale_stride_batch = state_scale.value().stride(0);
}

if (intermediate_states_buffer.has_value()) {
p.intermediate_states = const_cast<void*>(intermediate_states_buffer.value().data_ptr());
@@ -445,6 +482,30 @@
p.intermediate_state_indices = const_cast<void*>(intermediate_state_indices.value().data_ptr());
}

if (intermediate_state_scales.has_value()) {
auto const& iscales = intermediate_state_scales.value();
CHECK_CUDA(iscales);
CHECK_CONTIGUOUS(iscales);
CHECK_DIM(4, iscales); // (batch, cache_steps, nheads, dim)
FLASHINFER_CHECK(iscales.size(0) == batch,
"intermediate_state_scales.size(0) must equal batch");
FLASHINFER_CHECK(iscales.size(1) == cache_steps,
"intermediate_state_scales.size(1) must equal cache_steps");
FLASHINFER_CHECK(iscales.size(2) == nheads,
"intermediate_state_scales.size(2) must equal nheads");
FLASHINFER_CHECK(iscales.size(3) == dim, "intermediate_state_scales.size(3) must equal dim");
p.intermediate_state_scales = iscales.data_ptr();
p.intermediate_state_scales_stride_batch = iscales.stride(0);
}
if (rand_seed.has_value()) {
auto const& rs = rand_seed.value();
CHECK_CUDA(rs);
FLASHINFER_CHECK(rs.numel() == 1,
"rand_seed must be a single-element tensor, got numel=", rs.numel());
FLASHINFER_CHECK(rs.dtype().code == kDLInt && rs.dtype().bits == 64, "rand_seed must be int64");
p.rand_seed = static_cast<const int64_t*>(rs.data_ptr());
}

// Copy pointers
p.state = const_cast<void*>(state.data_ptr());
p.x = const_cast<void*>(x.data_ptr());
@@ -472,30 +533,29 @@
const cudaStream_t stream = get_stream(state.device());

auto algo = static_cast<SSUAlgorithm>(algorithm);
mtp::invokeSelectiveStateUpdateMTP<input_t, weight_t, matrixA_t, state_t, stateIndex_t>(p, algo,
stream);
mtp::invokeSelectiveStateUpdateMTP<input_t, weight_t, matrixA_t, state_t, stateIndex_t,
state_scale_t>(p, algo, stream);
}

// =============================================================================
// Generic dispatcher - routes to single-token or multi-token based on x.dim()
// =============================================================================
void selective_state_update(TensorView state, TensorView x, TensorView dt, TensorView A,
TensorView B, TensorView C, TensorView D, Optional<TensorView> z,
Optional<TensorView> dt_bias, bool dt_softplus,
Optional<TensorView> state_batch_indices, int64_t pad_slot_id,
TensorView output, bool disable_state_update,
Optional<TensorView> intermediate_states_buffer,
Optional<TensorView> intermediate_state_indices, int64_t cache_steps,
int64_t algorithm) {
void selective_state_update(
TensorView state, TensorView x, TensorView dt, TensorView A, TensorView B, TensorView C,
TensorView D, Optional<TensorView> z, Optional<TensorView> dt_bias, bool dt_softplus,
Optional<TensorView> state_batch_indices, int64_t pad_slot_id, Optional<TensorView> state_scale,
TensorView output, bool disable_state_update, Optional<TensorView> intermediate_states_buffer,
Optional<TensorView> intermediate_state_indices, Optional<TensorView> intermediate_state_scales,
Optional<TensorView> rand_seed, int64_t cache_steps, int64_t algorithm) {
if (x.dim() == 3) {
run_selective_state_update_stp(state, x, dt, A, B, C, D, z, dt_bias, dt_softplus,
state_batch_indices, pad_slot_id, output, disable_state_update,
algorithm);
state_batch_indices, state_scale, pad_slot_id, output,
disable_state_update, rand_seed, algorithm);
} else if (x.dim() == 4) {
run_selective_state_update_mtp(state, x, dt, A, B, C, D, z, dt_bias, dt_softplus,
state_batch_indices, pad_slot_id, output, disable_state_update,
intermediate_states_buffer, intermediate_state_indices,
cache_steps, algorithm);
run_selective_state_update_mtp(
state, x, dt, A, B, C, D, z, dt_bias, dt_softplus, state_batch_indices, state_scale,
pad_slot_id, output, disable_state_update, intermediate_states_buffer,
intermediate_state_indices, intermediate_state_scales, rand_seed, cache_steps, algorithm);
} else {
FLASHINFER_CHECK(false,
"x must have 3 dimensions (single-token) or 4 dimensions (multi-token), got ",
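The generic dispatcher's shape-based routing — 3-D `x` goes to the single-token (STP) path, 4-D to the multi-token (MTP) path, anything else is an error — can be sketched as (names here are illustrative):

```python
def route_selective_state_update(x_ndim: int) -> str:
    """Mirrors the C++ dispatcher's routing on x.dim()."""
    if x_ndim == 3:
        return "stp"  # single-token: run_selective_state_update_stp
    if x_ndim == 4:
        return "mtp"  # multi-token: run_selective_state_update_mtp
    raise ValueError(
        f"x must have 3 (single-token) or 4 (multi-token) dimensions, got {x_ndim}")
```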
6 changes: 6 additions & 0 deletions csrc/selective_state_update_customize_config.jinja
@@ -8,7 +8,13 @@ using input_t = {{ input_dtype }};
using weight_t = {{ weight_dtype }};
using matrixA_t = {{ matrixA_dtype }};
using stateIndex_t = {{ stateIndex_dtype }};
// Type for block-scale decode factors (e.g. float, __half).
// void = no scaling (state_t is used as-is).
using state_scale_t = {{ state_scale_type }};

constexpr int DIM = {{ dim }};
constexpr int DSTATE = {{ dstate }};
constexpr int NTOKENS_MTP = {{ ntokens_mtp }};
// Philox PRNG rounds for stochastic rounding of fp16 state stores.
// 0 = no stochastic rounding; typical value = 10.
constexpr int PHILOX_ROUNDS = {{ philox_rounds }};
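`PHILOX_ROUNDS` configures the counter-based PRNG that drives stochastic rounding of fp16 state stores. Below is a pure-Python sketch of the standard Philox-4x32 round function (the scheme the "Philox-4x32 PRNG matching Triton tl.randint" commit refers to) together with a simplified grid-based stochastic round; the kernel's exact bit manipulation of the fp16 mantissa may differ:

```python
import math

_M0, _M1 = 0xD2511F53, 0xCD9E8D57  # Philox round multipliers
_W0, _W1 = 0x9E3779B9, 0xBB67AE85  # Weyl key increments
_MASK = 0xFFFFFFFF


def philox4x32(counter, key, rounds=10):
    """Standard Philox-4x32: 4-word counter + 2-word key -> 4 random u32."""
    c, k = list(counter), list(key)
    for _ in range(rounds):
        hi0, lo0 = divmod(_M0 * c[0], 1 << 32)  # 32x32 -> 64-bit mul, split
        hi1, lo1 = divmod(_M1 * c[2], 1 << 32)
        c = [(hi1 ^ c[1] ^ k[0]) & _MASK, lo1,
             (hi0 ^ c[3] ^ k[1]) & _MASK, lo0]
        k = [(k[0] + _W0) & _MASK, (k[1] + _W1) & _MASK]
    return c


def stochastic_round(x, step, rand_u32):
    """Round x to a multiple of `step`, rounding up with probability equal
    to the fractional position between the two neighboring grid points."""
    base = math.floor(x / step)
    frac = x / step - base
    if rand_u32 / 2.0**32 < frac:
        base += 1
    return base * step
```

The key property is that the result is unbiased in expectation, so repeated quantized state updates do not accumulate a systematic drift the way round-to-nearest can; exactly representable values are returned unchanged regardless of the random word.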
10 changes: 6 additions & 4 deletions csrc/selective_state_update_kernel_inst.cu
@@ -7,12 +7,14 @@

namespace flashinfer::mamba {

template void invokeSelectiveStateUpdate<input_t, weight_t, matrixA_t, state_t, stateIndex_t>(
SelectiveStateUpdateParams&, SSUAlgorithm, cudaStream_t);
template void invokeSelectiveStateUpdate<input_t, weight_t, matrixA_t, state_t, stateIndex_t,
state_scale_t>(SelectiveStateUpdateParams&, SSUAlgorithm,
cudaStream_t);

namespace mtp {
template void invokeSelectiveStateUpdateMTP<input_t, weight_t, matrixA_t, state_t, stateIndex_t>(
SelectiveStateMTPParams&, SSUAlgorithm, cudaStream_t);
template void invokeSelectiveStateUpdateMTP<input_t, weight_t, matrixA_t, state_t, stateIndex_t,
state_scale_t>(SelectiveStateMTPParams&, SSUAlgorithm,
cudaStream_t);
} // namespace mtp

} // namespace flashinfer::mamba
21 changes: 19 additions & 2 deletions flashinfer/aot.py
@@ -548,15 +548,32 @@ def gen_all_modules(
]
# selective_state_update: one module per dtype combo per GPU arch
_ssu_dtype_combos = [
# (state, input, weight, matrixA, stateIndex)
# (state, input, weight, matrixA, stateIndex, state_scale_dtype)
(
torch.bfloat16,
torch.bfloat16,
torch.bfloat16,
torch.float32,
torch.int64,
None,
),
# int16 state (block-scaled quantization, scale stored as float32)
(
torch.int16,
torch.bfloat16,
torch.bfloat16,
torch.float32,
torch.int64,
torch.float32,
),
(
torch.float32,
torch.bfloat16,
torch.bfloat16,
torch.float32,
torch.int64,
None,
),
(torch.float32, torch.bfloat16, torch.bfloat16, torch.float32, torch.int64),
]
_ssu_dims = [64]
_ssu_dstates = [128]
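The new AOT combo pairs an `int16` state dtype with a `float32` per-`(head, dim)` decode scale. A pure-Python sketch of symmetric block quantization with the zero-max guard (cf. the "Handle zero max value in state scaling calculations" commit); the kernel-side scheme may differ in detail:

```python
def quantize_block(values):
    """Symmetric int16 quantization of one scale block.
    Returns (int16 codes, float32 decode scale)."""
    m = max(abs(v) for v in values)
    if m == 0.0:
        return [0] * len(values), 1.0  # zero-max guard: avoid divide-by-zero
    scale = m / 32767.0
    codes = [max(-32767, min(32767, round(v / scale))) for v in values]
    return codes, scale


def dequantize_block(codes, scale):
    """Decode: state_fp = state_int16 * scale."""
    return [c * scale for c in codes]
```

With one scale per `(head, dim)` row of the state, the relative quantization error of each block is bounded by roughly `1/32767` of its largest element, while halving state-cache memory versus fp32.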