Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
bcbd2aa
It seems that I can call my CUDA wrapper from Python
Dec 22, 2025
44d2b41
check all the kernel inputs. ready to develop dispatch code.
Dec 22, 2025
3d00fdf
simple implementation of selective_state_update is working
Jan 5, 2026
ebe9070
ported the hopper version + runtime dispatch check
Jan 5, 2026
8e80802
Passed pre-commit checks
ishovkun Jan 6, 2026
8568d67
Update flashinfer/mamba/selective_state_update.py
ishovkun Jan 7, 2026
7bb48cc
remove unreachable code
Jan 7, 2026
5812938
Improve docstring for input state shape
ishovkun Jan 7, 2026
9883eab
Simple kernel also uses fast_exp
Jan 7, 2026
8aed468
Remove error check of the kernel launch (it was a debugging leftover).
Jan 7, 2026
5c65e53
no need for the fast_exp function
Jan 7, 2026
681a5ba
Remove unnecessary None check for D before unsqueeze
Jan 7, 2026
5e8536e
Hoist dA computation outside the innermost loop.
Jan 7, 2026
acab6bd
support for non-none z
Jan 7, 2026
84bdc07
stage forgotten z handling in the test
Jan 7, 2026
edb5269
test: handle z
Jan 7, 2026
3e97354
make sure that batch size does not exceed the state cache size
Jan 7, 2026
7ba913d
do use matrixA_dtype in the test
Jan 7, 2026
83c5ecb
Add selective state update module to JIT specs
Jan 7, 2026
c0f96aa
Fix dt_bias stride check to handle None value
Jan 7, 2026
03a8570
handle a few different DIM and DSTATE
Jan 7, 2026
ce2e868
selective state: test various dims and state sizes
Jan 7, 2026
d0cd9da
Fix
Jan 7, 2026
62eeef8
Merge branch 'flashinfer-ai:main' into main
ishovkun Jan 8, 2026
0f5f4b8
formatting
ishovkun Jan 8, 2026
728e0c3
Merge branch 'main' of github.com:ishovkun/flashinfer-dev
ishovkun Jan 8, 2026
a35c408
Add SM90 Mamba selective state update JIT module
ishovkun Jan 8, 2026
81163a2
ifdef guards for the hopper+ implementation for aot
ishovkun Jan 8, 2026
d2a147a
export both selective_state_update and selective_state_update_sm90 mo…
ishovkun Jan 8, 2026
040734e
Add alignment checks for Mamba state update
ishovkun Jan 9, 2026
6d9aa1b
Fix `toFloat` usage in Mamba selective state update kernel
ishovkun Jan 9, 2026
d22eec4
avoid an ambiguous variable name
ishovkun Jan 9, 2026
59c9f85
Support SM90+ for Mamba selective state update
ishovkun Jan 9, 2026
c443c54
Improve FLASHINFER_CHECK error message in selective_state_update
ishovkun Jan 9, 2026
95fa14d
Refactor SM90 module to use CompilationContext for nvcc flags
ishovkun Jan 9, 2026
dfc0a38
Exclude an unnecessary compiler flag as it is part of the default flags.
ishovkun Jan 12, 2026
baa55a1
comment about the use of TMA to justify the necessity of sm90 module
ishovkun Jan 12, 2026
5afc7f6
a comment about the choice of A tensor in the unit test
ishovkun Jan 12, 2026
0d9c71b
Use torch.testing.assert_allclose instead of torch.allclose as the
ishovkun Jan 12, 2026
5cda288
Merge branch 'flashinfer-ai:main' into main
ishovkun Jan 12, 2026
d99923c
fix obsoleted bf16 ifdefs
ishovkun Jan 12, 2026
c7b231a
Merge branch 'flashinfer-ai:main' into main
ishovkun Jan 16, 2026
7aeb97a
Add state_dtype support for fp16/bf16/fp32 in selective_state_update
ishovkun Jan 16, 2026
9388e9d
Add alignment checks for vectorized loads in Mamba SSU kernel
ishovkun Jan 16, 2026
bc78e7f
Fix shared memory alignment for vectorized loads in Mamba SSU
ishovkun Jan 16, 2026
48d5f05
Improve error message for selective_state_update dtype mismatch
ishovkun Jan 16, 2026
f5b2c16
Fix output loop to use all warp lanes in Mamba kernel
ishovkun Jan 16, 2026
78b770b
Add mixed dtype support for selective_state_update
ishovkun Jan 17, 2026
5f52fb8
Add missing alignment checks for x and z in Mamba SSU
ishovkun Jan 17, 2026
65ee5c2
init
yzh119 Jan 18, 2026
1938c5c
address coderabbit comments
yzh119 Jan 18, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 50 additions & 11 deletions csrc/selective_state_update.cu
Original file line number Diff line number Diff line change
Expand Up @@ -187,24 +187,63 @@ void selective_state_update(TensorView state, TensorView x, TensorView dt, Tenso
auto dtype_key =
std::make_tuple(state_dtype_code, input_dtype_code, weight_dtype_code, matrixA_dtype_code);

// Currently only support: input_t = weight_t = state_t = bfloat16, matrixA_t = float
if (dtype_key == std::make_tuple(bfloat16_code, bfloat16_code, bfloat16_code, float32_code)) {
if (dtype_key == std::make_tuple(/*state*/ bfloat16_code, /*input */ bfloat16_code,
/*weight */ bfloat16_code, /*matrixA */ float32_code)) {
using state_t = nv_bfloat16;
using input_t = nv_bfloat16;
using weight_t = nv_bfloat16;
using matrixA_t = float;

invokeSelectiveStateUpdate<input_t, weight_t, matrixA_t, state_t>(p, stream);
} else if (dtype_key == std::make_tuple(/*state*/ float16_code, /*input */ bfloat16_code,
/*weight */ bfloat16_code, /*matrixA */ float32_code)) {
using state_t = half;
using input_t = nv_bfloat16;
using weight_t = nv_bfloat16;
using matrixA_t = float;
invokeSelectiveStateUpdate<input_t, weight_t, matrixA_t, state_t>(p, stream);
} else if (dtype_key == std::make_tuple(/*state*/ float32_code, /*input */ bfloat16_code,
/*weight */ bfloat16_code, /*matrixA */ float32_code)) {
using state_t = float;
using input_t = nv_bfloat16;
using weight_t = nv_bfloat16;
using matrixA_t = float;
invokeSelectiveStateUpdate<input_t, weight_t, matrixA_t, state_t>(p, stream);
} else if (dtype_key == std::make_tuple(/*state*/ bfloat16_code, /*input */ bfloat16_code,
/*weight */ float32_code, /*matrixA */ float32_code)) {
using state_t = nv_bfloat16;
using input_t = nv_bfloat16;
using weight_t = float;
using matrixA_t = float;
invokeSelectiveStateUpdate<input_t, weight_t, matrixA_t, state_t>(p, stream);
} else if (dtype_key == std::make_tuple(/*state*/ float16_code, /*input */ bfloat16_code,
/*weight */ float32_code, /*matrixA */ float32_code)) {
using state_t = half;
using input_t = nv_bfloat16;
using weight_t = float;
using matrixA_t = float;
invokeSelectiveStateUpdate<input_t, weight_t, matrixA_t, state_t>(p, stream);
} else if (dtype_key == std::make_tuple(/*state*/ float32_code, /*input */ bfloat16_code,
/*weight */ float32_code, /*matrixA */ float32_code)) {
using state_t = float;
using input_t = nv_bfloat16;
using weight_t = float;
using matrixA_t = float;
invokeSelectiveStateUpdate<input_t, weight_t, matrixA_t, state_t>(p, stream);
} else {
// Default case: unsupported dtype combination
TVM_FFI_ICHECK(false) << "Unsupported dtype combination for selective_state_update: "
<< "state_dtype=" << state_dtype.code << ":" << state_dtype.bits << ", "
<< "input_dtype=" << input_dtype.code << ":" << input_dtype.bits << ", "
<< "weight_dtype=" << weight_dtype.code << ":" << weight_dtype.bits
<< ", "
<< "matrixA_dtype=" << matrixA_dtype.code << ":" << matrixA_dtype.bits
<< ". Currently only support: "
<< "state=bfloat16, input=bfloat16, weight=bfloat16, matrixA=float32";
TVM_FFI_ICHECK(false)
<< "Unsupported dtype combination for selective_state_update: "
<< "state_dtype=" << state_dtype.code << ":" << state_dtype.bits << ", "
<< "input_dtype=" << input_dtype.code << ":" << input_dtype.bits << ", "
<< "weight_dtype=" << weight_dtype.code << ":" << weight_dtype.bits << ", "
<< "matrixA_dtype=" << matrixA_dtype.code << ":" << matrixA_dtype.bits
<< ". Supported combos include:\n"
<< " (state=bfloat16, input=bfloat16, weight=bfloat16, matrixA=float32)\n"
<< " (state=float16, input=bfloat16, weight=bfloat16, matrixA=float32)\n"
<< " (state=float32, input=bfloat16, weight=bfloat16, matrixA=float32)\n"
<< " (state=bfloat16, input=bfloat16, weight=float32, matrixA=float32)\n"
<< " (state=float16, input=bfloat16, weight=float32, matrixA=float32)\n"
<< " (state=float32, input=bfloat16, weight=float32, matrixA=float32)";
}
}

Expand Down
Loading