int16 Block-Scaled State and Stochastic Rounding for SSU (mamba) #2645
yzh119 merged 69 commits into flashinfer-ai:main from
Conversation
Move the test input generation helper from test_selective_state_update.py to a new test_utils.py module for reuse across tests. The refactored function adds support for multi-token mode, intermediate state buffers, and configurable state cache strides.
…struct
- Add helper functions for tensor validation and dtype checks
- Move output tensor to Optional and update checks accordingly
- Add state_stride_batch and update_state fields to SelectiveStateUpdateParams
- Refactor kernel param usage for clarity and consistency
Extract dispatchDimDstate and dispatchRatio helpers to simplify kernel dispatch code and reduce duplication.
Support int32 and int64 state_batch_indices in selective_state_update
- Add kernel and dispatcher support for int32/int64 state_batch_indices
- Update tests to cover int32 indices
- Fix test_utils to use int64 slot_idx by default
- Remove int32 type check to allow both int32 and int64 index types
- Add stateIndex_t template parameter to kernels for index type dispatch
- Extract kernel implementations to new selective_state_update_stp.cuh
- Remove unused TMA helper functions from create_tensor_map.cuh
- Add comprehensive MTP (multi-token prediction) test suite
…checks
- Add common.cuh with kernel dispatch helpers and alignment checks
- Split and rename kernel_selective_state_update_stp.cuh, add kernel_selective_state_update_mtp.cuh
- Refactor Python selective_state_update to clarify dimension handling
- Add test for dtype mismatch between state_batch_indices and intermediate_state_indices
- Update test_utils to generate int64 intermediate_slot_idx by default
- Remove redundant input type check in validate_intermediate_state_indices
Always define state_batch_idx (either from state_batch_indices or pid_b) to mirror the CUDA kernel's state_batch variable. This allows the intermediate state caching logic to use a simple check of `state_batch_idx != pad_slot_id` without requiring an extra HAS_STATE_BATCH_INDICES guard, matching the CUDA kernel behavior. addresses: flashinfer-ai#2444 (comment)
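The pattern the commit describes can be sketched in plain Python (a toy stand-in for the Triton/CUDA code; the `pad_slot_id` value and function name here are illustrative):

```python
# Toy Python stand-in for the indexing pattern above (values illustrative;
# the real code is a Triton kernel mirroring the CUDA state_batch variable).
pad_slot_id = -1

def resolve_state_batch_idx(pid_b, state_batch_indices=None):
    # Always defined: either from the optional indices tensor or from pid_b.
    if state_batch_indices is not None:
        return state_batch_indices[pid_b]
    return pid_b

# Caching then needs only one check, with no HAS_STATE_BATCH_INDICES guard:
idx = resolve_state_batch_idx(2, [5, -1, 7])
should_cache = idx != pad_slot_id
```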
- Add test_chunk_scan_combined.py comparing the CUTLASS CuTe DSL Blackwell implementation against the Triton reference
- Move selective_state_update_triton.py into the triton_reference/ package
- Add Triton reference implementations for Mamba2 SSD kernels:
  - ssd_combined.py (main entry point)
  - ssd_chunk_scan.py, ssd_chunk_state.py, ssd_state_passing.py
  - ssd_bmm.py, softplus.py (utilities)
# Conflicts:
#   tests/mamba/selective_state_update_triton.py
#   tests/mamba/test_selective_state_update_mtp.py
#   tests/mamba/test_selective_state_update_stp.py
- Move dtype dispatch and instantiation to codegen via Jinja templates
- Generate config and instantiation files per dtype combination
- Update Python JIT logic to build/load kernels for specific dtypes
- Remove C++ dtype dispatch helpers from selective_state_update.cu
- Update kernel launcher comment for clarity on consumer warps
Support explicit algorithm choice (auto/simple/vertical/horizontal) for selective_state_update and MTP kernels. Update kernel signatures, Python bindings, and JIT module generation to include algorithm and compile-time shape parameters (dim, dstate, ntokens_mtp). Refactor dispatch logic for SM90/SM100 architectures.
… .cu files
The config.inc defines DIM, DSTATE, NTOKENS_MTP as constexpr globals that the header's function templates rely on. With the previous order (header first, config second), NVCC's lenient two-phase lookup masked the issue, but a fresh JIT compilation after cache clearing would fail with 'identifier DIM/DSTATE is undefined' errors. clang-format is disabled for these includes because it reorders them alphabetically, which breaks compilation.
AI-assisted
Assign each of the 4 consumer warps a single tensor to load (x, B, z, C) instead of warps 0 and 1 each loading two tensors sequentially. This maximizes memory-level parallelism during the load phase. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Replace cartesian-product fixture parametrization with explicit rows: one base case plus one row per parameter deviation. Cuts the test count from ~200+ (MTP) and ~144+ (STP) down to ~26 and ~15 respectively. AI-assisted Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
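The explicit-rows scheme can be illustrated with made-up parameter names and values: one base configuration plus one row per single-parameter deviation, instead of a full cartesian product.

```python
# Made-up illustration of the row-based parametrization described above.
BASE = dict(batch=2, nheads=4, dim=64, dstate=64)
DEVIATIONS = [dict(batch=8), dict(nheads=32), dict(dim=128), dict(dstate=128)]
ROWS = [BASE] + [{**BASE, **d} for d in DEVIATIONS]
# 1 + 4 = 5 rows, vs 2**4 = 16 for a cartesian product of two values per axis
```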
- Parametrize tests to run with all supported algorithms
- Update test logic to pass algorithm argument through
- Improve test output messages to include algorithm name
- Add utility to detect available algorithms based on GPU arch
🧹 Nitpick comments (1)
tests/mamba/test_philox_rounding.py (1)
339-341: `test_stochastic_rounding_sw` is effectively sm100a-gated. Line 340 pulls in
`stochastic_round_module`, which skips on `major < 10` (lines 225-226). So this "software fallback" test won't run on older GPUs. Consider splitting into:
- SW-only correctness test (all GPUs), and
- SW-vs-HW parity test (sm100a+ only).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/mamba/test_philox_rounding.py` around lines 339 - 341, The test_stochastic_rounding_sw currently depends on stochastic_round_module (which is skipped for major < 10), effectively gating the whole test; split it into two tests: 1) a SW-only correctness test (keep function name test_stochastic_rounding_sw_correctness) that only accepts stochastic_round_sw_module and seed and asserts expected outputs on all GPUs (remove stochastic_round_module from its parameters), and 2) a SW-vs-HW parity test (e.g., test_stochastic_rounding_sw_parity) that accepts both stochastic_round_sw_module and stochastic_round_module and is guarded by the existing sm100a/major>=10 skip logic (reuse the same skip condition used where stochastic_round_module is defined) to compare outputs between SW and HW implementations.
On GPUs with compute capability < 10, use regular rounding in Triton reference for stochastic rounding tests, matching hardware support.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/mamba/test_selective_state_update_stp.py`:
- Around line 742-746: _SR_PARAMS is defined as a mutable list which triggers
the RUF012 lint for class attributes; change its definition from a list to an
immutable tuple so the class attribute cannot be mutated. Locate the _SR_PARAMS
variable in tests/mamba/test_selective_state_update_stp.py and replace the
surrounding square brackets with parentheses while preserving all elements and
comments so the values and order (e.g., entries like (64, 64, 64, 128,
torch.float16, torch.float32, True)) remain unchanged.
- Line 490: The parameter names that are unused should be prefixed with an
underscore to silence Ruff ARG002; rename the unused parameter state_dtype in
the make_inputs(...) signatures to _state_dtype, and rename the unused catch-all
in assert_states_match(...) from **kwargs to **_kwargs; apply the same
underscore-prefix change to the other identical function signatures in this file
(the other make_inputs and assert_states_match occurrences) so behavior is
unchanged but linter warnings are suppressed.
ℹ️ Review info
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
tests/mamba/test_selective_state_update_mtp.py
tests/mamba/test_selective_state_update_stp.py
/bot run
[FAILED] Pipeline #45162408: 1/20 passed
/bot run
[FAILED] Pipeline #45180729: 9/20 passed
yzh119 left a comment
LGTM overall, some minor comments
@@ -1,330 +0,0 @@
# Issue self-claim workflow for external contributors
Please revert the change on this file.
intermediate_state_indices : Optional[torch.Tensor]
    Optional indices mapping batch elements to intermediate state buffer positions
    with shape (batch,)
rand_seed : Optional[int]
do we consider cudagraph compatibility? If so we might also consider a device-side random seed (stored in an integer GPU tensor of size 1).
Oh I didn't know we are to change the seed on the fly...
(enable cuda graphs)
- Change rand_seed argument from int to CUDA int64 tensor for Philox stochastic rounding, ensuring CUDA graph compatibility
- Update C++/CUDA kernels and Python bindings to accept device-side seed
- Add validation for rand_seed tensor shape, dtype, and device
- Update tests to use tensor-based rand_seed
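A hedged sketch of the seed validation the commit lists; names and messages are assumptions, and a real implementation would check a torch.Tensor, mimicked here by a tiny stand-in class so the logic is self-contained:

```python
# Stand-in for a torch.Tensor so the validation sketch runs anywhere.
class SeedTensor:
    def __init__(self, numel, dtype, device):
        self.numel, self.dtype, self.device = numel, dtype, device

def validate_rand_seed(seed):
    # Keeping the seed on-device lets CUDA graph replays pick up a new
    # value without re-capturing the graph.
    if seed.numel != 1:
        raise ValueError("rand_seed must contain exactly one element")
    if seed.dtype != "int64":
        raise ValueError("rand_seed must be an int64 tensor")
    if seed.device != "cuda":
        raise ValueError("rand_seed must reside on the GPU")
```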
…shinfer-ai#2645)
<img width="1257" height="1571" alt="image" src="https://github.com/user-attachments/assets/d7fcb86c-76c5-4c04-905e-09d1b14a0690" />
<img width="1126" height="1407" alt="image" src="https://github.com/user-attachments/assets/e01aa38d-b859-46cd-b471-f47c9b2f3761" />
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
Motivation
The `selective_state_update` kernels (single-token STP and multi-token MTP) store SSM state in memory between steps. This PR adds two complementary features for reducing state memory bandwidth and improving numerical quality: int16 block-scaled quantization for 2× memory footprint reduction, and Philox-based stochastic rounding for statistically unbiased fp32→fp16 conversion.
int16 Block-Scaled State
The state tensor can now be stored as int16 with a per-row (per DIM-row) float32 decode scale, enabling 2× compression vs fp16 at low accuracy loss.
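The per-row scheme can be sketched in pure Python (illustrative only, not the kernel code; the kernel performs the same two passes with warp reductions and vectorized stores):

```python
# Pure-Python sketch of per-row int16 block scaling with a float32 decode scale.
INT16_MAX = 32767

def quantize_row(row):
    # Pass 1: row-wise absolute max (a warp-lane max reduction in the kernel).
    amax = max(abs(v) for v in row) or 1.0
    decode_scale = amax / INT16_MAX      # float32 scale stored per DIM-row
    encode_scale = 1.0 / decode_scale
    # Pass 2: convert and store as int16.
    q = [max(-INT16_MAX - 1, min(INT16_MAX, round(v * encode_scale))) for v in row]
    return q, decode_scale

def dequantize_row(q, decode_scale):
    return [v * decode_scale for v in q]

q, scale = quantize_row([0.5, -1.0, 0.25])
approx = dequantize_row(q, scale)        # close to the original values
```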
Kernel changes (`kernel_selective_state_update_stp.cuh`, `kernel_selective_state_update_mtp.cuh`)
Added a `state_scale_t` template parameter (replacing a boolean `scaleState` flag: `void` means no scaling, `float` enables it). When scaling is active, the kernel does a 2-pass quantization: compute the row max across warp lanes, derive encode/decode scales, then convert and store. Intermediate state writes for MTP likewise quantize before writing to global memory, and the decode scale is stored alongside.
Vertical algorithm (`kernel_selective_state_update_stp.cuh`)
The existing vertical/TMA path was extended with int16 support; TMA alignment requirements were tightened to 128 bytes accordingly.
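The tightened alignment requirement can be expressed as a simple predicate (a sketch; the actual checks live in the C++ validation helpers and the names here are assumptions):

```python
TMA_ALIGN_BYTES = 128  # tightened requirement for the int16 vertical/TMA path

def tma_layout_ok(base_addr, row_stride_elems, elem_size_bytes):
    # Both the base address and the row stride (in bytes) must be
    # 128-byte aligned for TMA to describe the tensor.
    stride_bytes = row_stride_elems * elem_size_bytes
    return base_addr % TMA_ALIGN_BYTES == 0 and stride_bytes % TMA_ALIGN_BYTES == 0
```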
Python/JIT plumbing (`selective_state_update.py`, `selective_state_update_customize_config.jinja`, `selective_state_update.cu`)
A `state_scale` tensor and its dtype flow through from the Python API into the JIT codegen and kernel launch. The Triton reference was updated to match the per-block scaling logic for bitwise-comparable tests.
Tests (`test_selective_state_update_stp.py`, `test_selective_state_update_mtp.py`)
End-to-end tests check dequantized state and output correctness against the Triton reference for int16 state across a range of batch/head/dim/dstate configurations. Tests also verify that passing `intermediate_states` with int16 scaled state is correctly rejected.
Stochastic Rounding for fp16 State
When state is fp16, truncation-based conversion from fp32 accumulation introduces systematic bias. Stochastic rounding is statistically unbiased: it rounds up or down with probability proportional to the fractional remainder.
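The rounding rule can be illustrated with a toy grid in pure Python; the kernel rounds fp32 to the fp16 grid (via `cvt.rs` on SM100+, software elsewhere), but the unbiasedness property is the same:

```python
import random

# Round down with probability (1 - frac), up with probability frac, where
# frac is the fractional remainder. Here the "grid" is multiples of `step`.
def stochastic_round(x, step, rng):
    lo = (x // step) * step
    frac = (x - lo) / step              # fractional remainder in [0, 1)
    return lo + step if rng.random() < frac else lo

rng = random.Random(0)
samples = [stochastic_round(0.3, 1.0, rng) for _ in range(10_000)]
mean = sum(samples) / len(samples)      # unbiased: hovers near 0.3
```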
Philox PRNG (`conversion.cuh`)
A Philox-4x32 implementation matching Triton's `tl.randint` exactly (bitwise verified in tests), with a template parameter for the number of rounds. `cvt_rs_f16_f32` implements the actual stochastic conversion: software emulation on older architectures, PTX `cvt.rs.f16x2.f32` on SM100+.
Kernel integration (both STP and MTP kernels)
A `PHILOX_ROUNDS` template parameter controls whether stochastic rounding is active. When > 0, all fp32→fp16 state stores use `cvt_rs_f16_f32` with Philox-generated noise. Restricted to fp16 state via `static_assert`.
Philox-4x32 amortization
Each Philox call natively produces 4 random integers. Rather than calling once per element (discarding 3 of 4 outputs), the kernels call `philox_randint4x` once per 4 elements and index `rand_ints[k % 4]`, cutting PRNG work by 4×.
Bug fix
Philox random offsets now correctly include batch and head strides, matching the per-element addressing used in the kernel.
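For reference, a pure-Python Philox-4x32 sketch with the 4× amortized consumption pattern. The multiplier/Weyl constants are the standard Philox-4x32 ones (Random123); the kernel's exact counter/offset layout, and Triton's, may differ, so treat this as illustrative rather than bitwise-matching:

```python
MASK32 = 0xFFFFFFFF
PHILOX_M0, PHILOX_M1 = 0xD2511F53, 0xCD9E8D57   # round multipliers
PHILOX_W0, PHILOX_W1 = 0x9E3779B9, 0xBB67AE85   # Weyl key increments

def philox_randint4x(seed, offset, rounds=10):
    c = [offset & MASK32, (offset >> 32) & MASK32, 0, 0]   # 128-bit counter
    k = [seed & MASK32, (seed >> 32) & MASK32]             # 64-bit key
    for _ in range(rounds):
        p0, p1 = PHILOX_M0 * c[0], PHILOX_M1 * c[2]        # 32x32 -> 64-bit
        c = [(p1 >> 32) ^ c[1] ^ k[0], p1 & MASK32,
             (p0 >> 32) ^ c[3] ^ k[1], p0 & MASK32]
        k = [(k[0] + PHILOX_W0) & MASK32, (k[1] + PHILOX_W1) & MASK32]
    return c  # four 32-bit random integers per call

def sr_noise(seed, n):
    # Amortization: one Philox call per 4 elements instead of one per element.
    out, rand_ints = [], None
    for i in range(n):
        if i % 4 == 0:
            rand_ints = philox_randint4x(seed, offset=i // 4)
        out.append(rand_ints[i % 4])
    return out
```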
Tests (`test_philox_rounding.py`, extended MTP/STP tests)
Bitwise match of the Philox PRNG vs Triton, hardware vs software stochastic rounding on SM100, and tolerance-based correctness checks for SR state updates with and without intermediate states.
Performance
The MTP kernel additionally received dim-tiling across `blockIdx.z` (splitting the DIM dimension across grid blocks when `batch * nheads < num_sms * 2`), saturating the GPU at small batch sizes and closing the gap vs the Triton reference in the undersaturated regime.
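The dispatch heuristic above can be sketched as follows (the function and parameter names are assumptions, not the kernel's actual API):

```python
# Sketch of the dim-tiling grid heuristic described above.
def mtp_grid(batch, nheads, dim, dim_tile, num_sms):
    n_dim_tiles = 1
    if batch * nheads < num_sms * 2:
        # Undersaturated launch: split DIM across blockIdx.z to fill the GPU.
        n_dim_tiles = max(1, dim // dim_tile)
    return (batch, nheads, n_dim_tiles)  # (blockIdx.x, .y, .z extents)
```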
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.
🧪 Tests
- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).
Summary by CodeRabbit
New Features
- int16 state storage with per-state/per-tensor scaling and intermediate-state quantization.
- Optional per-state scaling and Philox-based stochastic rounding (new optional inputs: state_scale, intermediate_state_scales, rand_seed, philox_rounds).
- Tiled kernel/layout optimizations and a new warp-level max reduction utility.
Tests
- Extensive coverage for int16, intermediate-state paths, and stochastic rounding (hardware and software fallbacks).
Chores
- Removed issue-management CI workflow.
- Added ignore rules for Zed editor.