
int16 Block-Scaled State and Stochastic Rounding for SSU (mamba) #2645

Merged
yzh119 merged 69 commits into flashinfer-ai:main from ishovkun:main
Mar 4, 2026
Conversation

@ishovkun
Contributor

@ishovkun ishovkun commented Feb 26, 2026

Motivation

The selective_state_update kernels (single-token STP and multi-token MTP) store SSM state in memory between steps. This PR adds two complementary features for reducing state memory bandwidth and improving numerical quality: int16 block-scaled quantization for 2× memory footprint reduction, and Philox-based stochastic rounding for statistically unbiased fp32→fp16 conversion.


int16 Block-Scaled State

The state tensor can now be stored as int16 with a per-row (per DIM-row) float32 decode scale, enabling 2× compression vs fp16 at low accuracy loss.

Kernel changes (kernel_selective_state_update_stp.cuh, kernel_selective_state_update_mtp.cuh)
Added a state_scale_t template parameter (replacing a boolean scaleState flag — void means no scaling, float enables it). When scaling is active, the kernel does a 2-pass quantization: compute the row max across warp lanes, derive encode/decode scales, then convert and store. Intermediate state writes for MTP likewise quantize before writing to global memory, and the decode scale is stored alongside.
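A host-side NumPy sketch of this two-pass scheme (function names here are illustrative, and the kernel's exact rounding/saturation details may differ):

```python
import numpy as np

def quantize_rows_int16(state_f32, eps=1e-12):
    """Two-pass per-row quantization: pass 1 finds the row maximum
    (the kernel uses a warp max-reduction), pass 2 converts and stores
    int16 values plus one fp32 decode scale per DIM row."""
    row_max = np.abs(state_f32).max(axis=-1, keepdims=True)          # pass 1
    decode_scale = np.maximum(row_max, eps) / 32767.0
    q = np.clip(np.rint(state_f32 / decode_scale), -32767, 32767)    # pass 2
    return q.astype(np.int16), decode_scale.astype(np.float32)

def dequantize_rows_int16(q, decode_scale):
    return q.astype(np.float32) * decode_scale

rng = np.random.default_rng(0)
state = rng.standard_normal((4, 128)).astype(np.float32)
q, scale = quantize_rows_int16(state)
max_err = np.abs(dequantize_rows_int16(q, scale) - state).max()
```

The reconstruction error per element is bounded by half the row's decode scale, which is what makes the 2× compression cheap in accuracy terms.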

Vertical algorithm (kernel_selective_state_update_stp.cuh)
The existing vertical/TMA path was extended with int16 support; TMA alignment requirements were tightened to 128 bytes accordingly.

Python/JIT plumbing (selective_state_update.py, selective_state_update_customize_config.jinja, selective_state_update.cu)
state_scale tensor and its dtype flow through from the Python API into the JIT codegen and kernel launch. The Triton reference was updated to match the per-block scaling logic for bitwise-comparable tests.

Tests (test_selective_state_update_stp.py, test_selective_state_update_mtp.py)
End-to-end tests check dequantized state and output correctness against the Triton reference for int16 state across a range of batch/head/dim/dstate configurations. Tests also verify that passing intermediate_states with int16 scaled state is correctly rejected.


Stochastic Rounding for fp16 State

When state is fp16, truncation-based conversion from fp32 accumulation introduces systematic bias. Stochastic rounding is statistically unbiased: it rounds up or down with probability proportional to the fractional remainder.
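A scalar toy model of the rule — not the kernel's bit-level fp16 conversion, just the probability behavior:

```python
import random

def stochastic_round(x, step=1.0):
    """Round x down to a multiple of `step`, or up with probability equal
    to the fractional remainder; unbiased in expectation."""
    lo = (x // step) * step
    frac = (x - lo) / step
    return lo + step if random.random() < frac else lo

random.seed(0)
n = 100_000
mean = sum(stochastic_round(0.3) for _ in range(n)) / n
# Truncation would return 0.0 every time; the stochastic mean tracks 0.3.
```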

Philox PRNG (conversion.cuh)
A Philox-4x32 implementation matching Triton's tl.randint exactly (bitwise verified in tests). Template parameter for number of rounds. cvt_rs_f16_f32 implements the actual stochastic conversion — software emulation on older architectures, PTX cvt.rs.f16x2.f32 on SM100+.
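For reference, the generic Philox-4x32 round structure (with the standard Random123 multipliers and Weyl key increments) looks like this in Python; that this matches Triton's `tl.randint` bit-for-bit is an assumption the PR's tests verify, not something checked in this sketch:

```python
MASK32 = 0xFFFFFFFF
# Standard Philox-4x32 multipliers and Weyl key constants (Random123).
M0, M1 = 0xD2511F53, 0xCD9E8D57
W0, W1 = 0x9E3779B9, 0xBB67AE85

def philox_4x32(counter, key, rounds=10):
    """Counter-based PRNG: four 32-bit counter words, two 32-bit key words,
    a configurable number of rounds (a template parameter in the kernel)."""
    c0, c1, c2, c3 = counter
    k0, k1 = key
    for _ in range(rounds):
        p0 = M0 * c0                    # 64-bit product, split hi/lo
        p1 = M1 * c2
        c0, c1, c2, c3 = (
            ((p1 >> 32) ^ c1 ^ k0) & MASK32,
            p1 & MASK32,
            ((p0 >> 32) ^ c3 ^ k1) & MASK32,
            p0 & MASK32,
        )
        k0 = (k0 + W0) & MASK32         # bump the key between rounds
        k1 = (k1 + W1) & MASK32
    return (c0, c1, c2, c3)

r = philox_4x32((0, 0, 0, 0), (42, 0))
```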

Kernel integration (both STP and MTP kernels)
PHILOX_ROUNDS template parameter controls whether stochastic rounding is active. When > 0, all fp32→fp16 state stores use cvt_rs_f16_f32 with Philox-generated noise. Restricted to fp16 state via static_assert.

Philox-4x32 amortization
Each Philox call natively produces 4 random integers. Rather than calling once per element (discarding 3 of 4 outputs), the kernels call philox_randint4x once per 4 elements and index rand_ints[k % 4], cutting PRNG work by 4×.
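The amortization pattern, sketched with a stand-in generator (`philox_randint4x` is the name from the PR; the body below is just a placeholder hash, not the real PRNG):

```python
def philox_randint4x(counter):
    """Placeholder for the kernel's Philox call: like the real PRNG, it
    returns four 32-bit integers per invocation."""
    h = (counter * 0x9E3779B97F4A7C15 + 0x123456789) & ((1 << 64) - 1)
    return [(h >> s) & 0xFFFFFFFF for s in (0, 16, 32, 48)]

def per_element_random_bits(num_elements):
    out, rand_ints = [], None
    for k in range(num_elements):
        if k % 4 == 0:                  # one PRNG call per 4 elements...
            rand_ints = philox_randint4x(k // 4)
        out.append(rand_ints[k % 4])    # ...and all 4 outputs consumed
    return out

bits = per_element_random_bits(10)
```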

Bug fix
Philox random offsets now correctly include batch and head strides, matching the per-element addressing used in the kernel.

Tests (test_philox_rounding.py, extended MTP/STP tests)
Bitwise match of Philox PRNG vs Triton, hardware vs software stochastic rounding on SM100, and tolerance-based correctness checks for SR state updates with and without intermediate states.


Performance

The MTP kernel additionally received dim-tiling across blockIdx.z (splitting the DIM dimension across grid blocks when batch * nheads < num_sms * 2), saturating the GPU at small batch sizes and closing the gap vs the Triton reference in the undersaturated regime.
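The shape of that launch heuristic, as an illustrative sketch (threshold taken from the description above; the actual launch code and tile sizes may differ):

```python
def mtp_grid(batch, nheads, dim, dim_tile, num_sms):
    """Tile DIM across blockIdx.z only when batch * nheads alone would
    leave the GPU undersaturated (fewer than 2 blocks per SM)."""
    ntiles_z = 1
    if batch * nheads < num_sms * 2:
        ntiles_z = max(1, dim // dim_tile)
    return (batch, nheads, ntiles_z)

# Small batch: DIM is split to fill the GPU; large batch: no split needed.
small = mtp_grid(batch=1, nheads=8, dim=128, dim_tile=32, num_sms=132)
large = mtp_grid(batch=512, nheads=8, dim=128, dim_tile=32, num_sms=132)
```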

![image](https://github.com/user-attachments/assets/d7fcb86c-76c5-4c04-905e-09d1b14a0690)
![image](https://github.com/user-attachments/assets/e01aa38d-b859-46cd-b471-f47c9b2f3761)

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Summary by CodeRabbit

  • New Features

    • int16 state storage with per-state/per-tensor scaling and intermediate-state quantization.
    • Optional per-state scaling and Philox-based stochastic rounding (new optional inputs: state_scale, intermediate_state_scales, rand_seed, philox_rounds).
    • Tiled kernel/layout optimizations and a new warp-level max reduction utility.
  • Tests

    • Extensive coverage for int16, intermediate-state paths, and stochastic rounding (hardware and software fallbacks).
  • Chores

    • Removed issue-management CI workflow.
    • Added ignore rules for Zed editor.

ishovkun and others added 30 commits January 28, 2026 21:29
Move the test input generation helper from
test_selective_state_update.py
to a new test_utils.py module for reuse across tests. The refactored
function adds support for multi-token mode, intermediate state buffers,
and configurable state cache strides.
struct

- Add helper functions for tensor validation and dtype checks
- Move output tensor to Optional and update checks accordingly
- Add state_stride_batch and update_state fields to
  SelectiveStateUpdateParams
- Refactor kernel param usage for clarity and consistency
Extract dispatchDimDstate and dispatchRatio helpers to simplify
kernel dispatch code and reduce duplication.
- Add kernel and dispatcher support for int32/int64 state_batch_indices
- Update tests to cover int32 indices
- Fix test_utils to use int64 slot_idx by default
  Support int32 and int64 state_batch_indices in selective_state_update

- Remove int32 type check to allow both int32 and int64 index types
- Add stateIndex_t template parameter to kernels for index type dispatch
- Extract kernel implementations to new selective_state_update_stp.cuh
- Remove unused TMA helper functions from create_tensor_map.cuh
- Add comprehensive MTP (multi-token prediction) test suite
checks

- Add common.cuh with kernel dispatch helpers and alignment checks
- Split and rename kernel_selective_state_update_stp.cuh, add
  kernel_selective_state_update_mtp.cuh
- Refactor Python selective_state_update to clarify dimension handling
- Add test for dtype mismatch between state_batch_indices and
  intermediate_state_indices
- Update test_utils to generate int64 intermediate_slot_idx by default
- Remove redundant input type check in
  validate_intermediate_state_indices
Always define state_batch_idx (either from state_batch_indices or pid_b)
to mirror the CUDA kernel's state_batch variable. This allows the
intermediate state caching logic to use a simple check of
`state_batch_idx != pad_slot_id` without requiring an extra
HAS_STATE_BATCH_INDICES guard, matching the CUDA kernel behavior.

addresses:
flashinfer-ai#2444 (comment)
- Add test_chunk_scan_combined.py comparing CUTLASS CuTe DSL
  Blackwell implementation against Triton reference
- Move selective_state_update_triton.py into triton_reference/ package
- Add Triton reference implementations for Mamba2 SSD kernels:
  - ssd_combined.py (main entry point)
  - ssd_chunk_scan.py, ssd_chunk_state.py, ssd_state_passing.py
  - ssd_bmm.py, softplus.py (utilities)
# Conflicts:
#	tests/mamba/selective_state_update_triton.py
#	tests/mamba/test_selective_state_update_mtp.py
#	tests/mamba/test_selective_state_update_stp.py
- Move dtype dispatch and instantiation to codegen via Jinja templates
- Generate config and instantiation files per dtype combination
- Update Python JIT logic to build/load kernels for specific dtypes
- Remove C++ dtype dispatch helpers from selective_state_update.cu
- Update kernel launcher comment for clarity on consumer warps
Support explicit algorithm choice (auto/simple/vertical/horizontal)
for selective_state_update and MTP kernels. Update kernel signatures,
Python bindings, and JIT module generation to include algorithm and
compile-time shape parameters (dim, dstate, ntokens_mtp). Refactor
dispatch logic for SM90/SM100 architectures.
… .cu files

The config.inc defines DIM, DSTATE, NTOKENS_MTP as constexpr globals
that the header's function templates rely on. With the previous order
(header first, config second), NVCC's lenient two-phase lookup masked
the issue, but a fresh JIT compilation after cache clearing would fail
with 'identifier DIM/DSTATE is undefined' errors.

clang-format is disabled for these includes because it reorders them
alphabetically, which breaks compilation.

AI-assisted
Assign each of the 4 consumer warps a single tensor to load (x, B, z, C)
instead of warps 0 and 1 each loading two tensors sequentially. This
maximizes memory-level parallelism during the load phase.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Replace cartesian-product fixture parametrization with explicit rows:
one base case plus one row per parameter deviation. Cuts the test count
from ~200+ (MTP) and ~144+ (STP) down to ~26 and ~15 respectively.

AI-assisted

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Parametrize tests to run with all supported algorithms
- Update test logic to pass algorithm argument through
- Improve test output messages to include algorithm name
- Add utility to detect available algorithms based on GPU arch
Contributor

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
tests/mamba/test_philox_rounding.py (1)

339-341: test_stochastic_rounding_sw is effectively sm100a-gated.

Line 340 pulls in stochastic_round_module, which skips on major < 10 (lines 225-226). So this "software fallback" test won't run on older GPUs. Consider splitting into:

  1. SW-only correctness test (all GPUs), and
  2. SW-vs-HW parity test (sm100a+ only).


📥 Commits

Reviewing files that changed from the base of the PR and between 2ca355e and fa95f9d.

📒 Files selected for processing (1)
  • tests/mamba/test_philox_rounding.py

On GPUs with compute capability < 10, use regular rounding in Triton
reference for stochastic rounding tests, matching hardware support.
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Inline comments:
In `@tests/mamba/test_selective_state_update_stp.py`:
- Around line 742-746: _SR_PARAMS is defined as a mutable list which triggers
the RUF012 lint for class attributes; change its definition from a list to an
immutable tuple so the class attribute cannot be mutated. Locate the _SR_PARAMS
variable in tests/mamba/test_selective_state_update_stp.py and replace the
surrounding square brackets with parentheses while preserving all elements and
comments so the values and order (e.g., entries like (64, 64, 64, 128,
torch.float16, torch.float32, True)) remain unchanged.
- Line 490: The parameter names that are unused should be prefixed with an
underscore to silence Ruff ARG002; rename the unused parameter state_dtype in
the make_inputs(...) signatures to _state_dtype, and rename the unused catch-all
in assert_states_match(...) from **kwargs to **_kwargs; apply the same
underscore-prefix change to the other identical function signatures in this file
(the other make_inputs and assert_states_match occurrences) so behavior is
unchanged but linter warnings are suppressed.


📥 Commits

Reviewing files that changed from the base of the PR and between fa95f9d and 3b20f6e.

📒 Files selected for processing (2)
  • tests/mamba/test_selective_state_update_mtp.py
  • tests/mamba/test_selective_state_update_stp.py

@bkryu
Collaborator

bkryu commented Mar 2, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !352 has been updated with latest changes, and the CI pipeline #45162408 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #45162408: 1/20 passed

@bkryu
Copy link
Copy Markdown
Collaborator

bkryu commented Mar 2, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !352 has been created, and the CI pipeline #45180729 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #45180729: 9/20 passed

Collaborator

@yzh119 yzh119 left a comment


LGTM overall, some minor comments

@@ -1,330 +0,0 @@
# Issue self-claim workflow for external contributors
Collaborator


Please revert the change on this file.

intermediate_state_indices : Optional[torch.Tensor]
Optional indices mapping batch elements to intermediate state buffer positions
with shape (batch,)
rand_seed : Optional[int]
Collaborator


do we consider cudagraph compatibility? If so we might also consider device-side random seed (stored in a integer gpu tensor with size 1).

Contributor Author


Oh I didn't know we are to change the seed on the fly...

ishovkun added 2 commits March 3, 2026 15:19
(enable cuda graphs)

- Change rand_seed argument from int to CUDA int64 tensor for Philox
  stochastic rounding, ensuring CUDA graph compatibility
- Update C++/CUDA kernels and Python bindings to accept device-side seed
- Add validation for rand_seed tensor shape, dtype, and device
- Update tests to use tensor-based rand_seed
Collaborator

@yzh119 yzh119 left a comment


I'm good with the current status, thanks for your contribution @ishovkun !

@yzh119 yzh119 enabled auto-merge (squash) March 4, 2026 04:21
@yzh119 yzh119 disabled auto-merge March 4, 2026 17:30
@yzh119 yzh119 merged commit b0e7eb7 into flashinfer-ai:main Mar 4, 2026
30 of 37 checks passed
ameynaik-hub pushed a commit to ameynaik-hub/flashinfer that referenced this pull request Mar 18, 2026