
Add varlen and speculative decoding support to selective state update#2700

Merged
aleozlx merged 7 commits into flashinfer-ai:main from roikoren755:feat/selective-state-update-update
Mar 21, 2026

Conversation

@roikoren755
Contributor

@roikoren755 commented Mar 5, 2026

📌 Description

vLLM uses a different scheme for speculative decoding and prefix caching than SGLang and TRT-LLM. Specifically, it relies on:

  • dst_state_batch_indices - tells the kernel where in the state tensor to store the newly computed state
  • cu_seqlens - allows a varying number of tokens per sequence when speculative decoding is enabled
  • num_accepted_tokens - selects the index in the state tensor from which to read each sequence's initial cached state in speculative decoding

This PR adds support for all of these, while keeping support for previous variants, and without hurting performance.
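The relationship between the three new arguments can be sketched in plain Python (no GPU required). Note this is an illustration of the layout conventions described above, not the FlashInfer API; in particular, the `num_accepted_tokens[i] - 1` column convention for the read slot is an assumption for illustration.

```python
# Illustrative sketch of the varlen layout: sequence i owns flattened
# tokens [cu_seqlens[i], cu_seqlens[i+1]) of the packed token dimension.

def token_ranges(cu_seqlens):
    """Map each sequence to its (start, end) slice of the flattened token dim."""
    return [(cu_seqlens[i], cu_seqlens[i + 1]) for i in range(len(cu_seqlens) - 1)]

def read_slots(state_batch_indices_2d, num_accepted_tokens):
    """Per sequence, pick the cached-state slot written after the last accepted
    draft token (assumed convention: row i, column num_accepted_tokens[i] - 1)."""
    return [row[n - 1] for row, n in zip(state_batch_indices_2d, num_accepted_tokens)]

# Three sequences of lengths 2, 1, 3 flattened into 6 tokens:
cu = [0, 2, 3, 6]
assert token_ranges(cu) == [(0, 2), (2, 3), (3, 6)]

# 2D state indices (one slot per draft token) and per-sequence accepted counts:
slots = [[10, 11], [20, 21], [30, 31]]
accepted = [2, 1, 1]
assert read_slots(slots, accepted) == [11, 20, 30]
```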

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Varlen (variable-length) multi-token support with per-sequence cu_seqlens and num_accepted_tokens; optional dst_state_batch_indices for separate read/write state slots during speculative decoding.
  • API

    • Public selective_state_update APIs extended to accept dst_state_batch_indices, cu_seqlens, num_accepted_tokens and related dtype/options; docstrings updated.
  • Tests

    • Added comprehensive varlen unit tests and a Triton-based varlen reference implementation validating layouts, padding, and per-slot semantics.
  • Chores

    • Benchmarks and tooling updated to exercise varlen mode and algorithm selection.

@coderabbitai
Contributor

coderabbitai bot commented Mar 5, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.


Walkthrough

Adds varlen (variable-length sequence) and speculative-decoding support to selective_state_update by introducing optional dst_state_batch_indices, cu_seqlens, and num_accepted_tokens and wiring them through Python APIs, JIT/module generation, C++ dispatchers, kernel params, CUDA kernels, tests, and benchmarks.
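The dispatch described above can be summarized as a small decision rule. The sketch below is inferred from this walkthrough and the dispatcher's error message (3D input without cu_seqlens → single-token, 4D → fixed-length multi-token, 3D with cu_seqlens → varlen multi-token); `select_kernel` and the mode names are illustrative, not the actual C++ dispatcher.

```python
def select_kernel(x_dim, has_cu_seqlens):
    """Illustrative kernel-path selection, inferred from the PR walkthrough."""
    if x_dim == 3 and not has_cu_seqlens:
        return "stp"          # single-token path
    if x_dim == 4 and not has_cu_seqlens:
        return "mtp"          # fixed-length multi-token path
    if x_dim == 3 and has_cu_seqlens:
        return "mtp_varlen"   # variable-length multi-token path
    raise ValueError(f"unsupported combination: dim={x_dim}, cu_seqlens={has_cu_seqlens}")

assert select_kernel(3, False) == "stp"
assert select_kernel(4, False) == "mtp"
assert select_kernel(3, True) == "mtp_varlen"
```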

Changes

Cohort / File(s) Summary
C++ Bindings & Dispatcher
csrc/flashinfer_mamba_binding.cu, csrc/selective_state_update.cu
Added optional dst_state_batch_indices, cu_seqlens, num_accepted_tokens to public/native signatures; extended validation, packing, and dispatcher logic to select STP vs MTP/varlen paths and carry new strides/indices.
Python API & JIT Module Gen
flashinfer/mamba/selective_state_update.py, flashinfer/jit/mamba/selective_state_update.py
Extended public/internal Python signatures and dtype plumbing to accept/forward dst_state_batch_indices, cu_seqlens, num_accepted_tokens; updated module URI/config generation.
Kernel Param Structs & Headers
include/flashinfer/mamba/selective_state_update.cuh, csrc/selective_state_update_customize_config.jinja
Added dst_state_batch_indices pointer and strides to params; added cu_seqlens/num_accepted_tokens members and corresponding type aliases in config template.
STP Kernel Internals
include/flashinfer/mamba/kernel_selective_state_update_stp.cuh
Threaded dst_batch/dst_state_batch through producer/consumer helpers and switched write paths to use destination batch when provided.
MTP / Varlen Kernel Internals
include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh
Added per-sequence cu_seqlens/num_accepted_tokens handling, per-sequence base/stride math, seq_len guards, zero-fill beyond seq_len, and gated reads/writes to dst/intermediate buffers.
C++ Implementation Logic
csrc/selective_state_update.cu
Refactored validation and stride packing for 3D varlen vs 4D non-varlen shapes; added dtype/shape checks, mutual-exclusion constraints, and updated runner signatures.
Triton Reference & Tests
tests/mamba/triton_reference/selective_state_update_varlen.py, tests/mamba/test_selective_state_update_varlen.py
Added Triton varlen-capable reference implementation and comprehensive unit tests exercising varlen, dst_state_batch_indices, and num_accepted_tokens semantics.
Benchmarks & Misc
benchmarks/routines/mamba.py, .gitignore
Added --varlen and --algorithm flags, varlen input preparation and dispatch; .idea/ added to .gitignore.

Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant Caller
    participant Python
    participant Binding as CppBinding
    participant Dispatcher
    participant Kernel as CUDAKernel
    participant GPUState as GPUStateMemory

    Caller->>Python: selective_state_update(..., dst_state_batch_indices?, cu_seqlens?, num_accepted_tokens?)
    Python->>Binding: call native binding with optional args
    Binding->>Dispatcher: pack params, detect varlen (cu_seqlens + x.dim)
    Dispatcher->>Kernel: launch STP or MTP/varlen kernel with cu_seqlens/num_accepted_tokens/dst_indices
    Kernel->>GPUState: read source state (state_batch_indices or inferred)
    Kernel->>Kernel: compute updates per-token/sequence (respect cu_seqlens, num_accepted_tokens)
    Kernel->>GPUState: write updates to dst_state_batch_indices or intermediate/out buffers
    Kernel-->>Dispatcher: completion
    Dispatcher-->>Binding: return status/output
    Binding-->>Python: return updated tensor
    Python-->>Caller: deliver result

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested labels

op: norm, ready

Suggested reviewers

  • cyx-6
  • jiahanc
  • IwakuraRein
  • bkryu
  • yzh119
  • Anerudhan
  • kahyunnam
  • jimmyzho

Poem

"I hopped through cu_seqlens along the way,
dst slots ready where tokens play,
Accepted counts show which state to read,
Varlen hops now guide each token's speed,
A rabbit cheers: update, then stay!"

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 48.28%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Title check — ✅ Passed: the PR title clearly and concisely describes the main change (adding varlen and speculative decoding support to selective state update), which aligns with the primary objectives of the changeset.
  • Description check — ✅ Passed: the PR description explains what changes are made (dst_state_batch_indices, cu_seqlens, num_accepted_tokens support), why they're needed (vLLM compatibility), includes completed checklist items, and acknowledges testing was performed.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

Tip

You can validate your CodeRabbit configuration file in your editor.

If your editor has YAML language server, you can enable auto-completion and validation by adding # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json at the top of your CodeRabbit configuration file.

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the selective_state_update operation by integrating features essential for advanced speculative decoding and prefix caching schemes, particularly those used in systems like vLLM. It introduces new parameters and modifies existing kernels to handle variable-length sequences, manage separate read/write state cache slots, and precisely control initial state selection. These changes ensure compatibility with modern inference optimizations without compromising performance.

Highlights

  • Support for dst_state_batch_indices: Introduced a new parameter dst_state_batch_indices to allow writing updated states to different cache slots than where they were read from, enabling advanced state management for speculative decoding.
  • Variable-Length Sequence Support (cu_seqlens): Added cu_seqlens to enable processing of variable-length sequences, where tokens are flattened into a single batch dimension, improving efficiency for diverse batch sizes.
  • Initial State Selection (num_accepted_tokens): Implemented num_accepted_tokens to specify from which index in the state tensor to read the initial cached state per sequence, crucial for speculative decoding scenarios.
  • Enhanced state_batch_indices Flexibility: Updated state_batch_indices to support 2D tensors (N, max_seqlen), providing more granular control over state access for multi-token processing and speculative decoding.


Changelog
  • csrc/flashinfer_mamba_binding.cu
    • Updated selective_state_update function signature to include dst_state_batch_indices, cu_seqlens, and num_accepted_tokens.
    • Modified parameter comments to reflect support for varlen multi-token inputs and 2D state_batch_indices.
  • csrc/selective_state_update.cu
    • Adjusted validate_state_batch_indices to accept both 1D and 2D tensor shapes.
    • Integrated dst_state_batch_indices into run_selective_state_update_stp and run_selective_state_update_mtp function signatures and parameter packing.
    • Refactored run_selective_state_update_mtp to dynamically handle input tensor dimensions based on the presence of cu_seqlens (varlen vs. fixed-length multi-token).
    • Updated the main selective_state_update dispatcher to route calls based on input dimensions and cu_seqlens availability.
  • flashinfer/mamba/selective_state_update.py
    • Extended Python selective_state_update function signature and docstrings to include dst_state_batch_indices, cu_seqlens, and num_accepted_tokens.
    • Modified input tensor dimension handling to correctly process varlen inputs without unnecessary unsqueeze operations.
    • Updated ntokens_mtp calculation to account for cache_steps in varlen mode.
    • Ensured new parameters are correctly passed to the underlying C++ kernels.
  • include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh
    • Updated grid dimension comments to reflect batch_or_n_sequences.
    • Added cu_seqlens, num_accepted_tokens, and dst_state_batch_indices to kernel parameters.
    • Implemented logic to determine sequence length and initial token index based on cu_seqlens and num_accepted_tokens.
    • Adjusted memory access patterns and strides for input tensors (x, dt, B, C, z, out) to support varlen processing.
    • Modified state update logic to write to dst_state_batch_indices when provided, enabling separate read/write state slots.
    • Added boundary checks (if (step >= seq_len)) within loops to handle variable sequence lengths correctly.
  • include/flashinfer/mamba/kernel_selective_state_update_stp.cuh
    • Added dst_state_batch_indices to kernel parameters and introduced dst_state_batch for write operations.
    • Updated state and state scale write operations to target dst_state and dst_state_scale when dst_state_batch_indices is used.
    • Modified producer functions (producer_func_vertical, producer_func_horizontal) to accept and utilize dst_batch for state writebacks.
  • include/flashinfer/mamba/selective_state_update.cuh
    • Added dst_state_batch_indices and stride members for state_batch_indices and dst_state_batch_indices to SelectiveStateUpdateParams.
    • Included cu_seqlens and num_accepted_tokens in SelectiveStateMTPParams.
  • tests/mamba/test_selective_state_update_varlen.py
    • Added a new test file to validate dst_state_batch_indices, cu_seqlens, num_accepted_tokens, and 2D state_batch_indices functionality.
    • Implemented tests for scenarios including distinct source/destination state slots, 2D indices with single-token sequences, uniform and variable-length varlen sequences, and initial state selection via num_accepted_tokens.
  • tests/mamba/triton_reference/selective_state_update_varlen.py
    • Added a new Triton reference implementation for selective_state_update that supports varlen and speculative decoding features.
    • Provided a Triton kernel (_selective_scan_update_kernel) and a Python wrapper (selective_state_update_varlen_triton) for accurate comparison against the FlashInfer implementation.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review — /gemini review: performs a code review for the current pull request in its current state.
  • Pull Request Summary — /gemini summary: provides a summary of the current pull request in its current state.
  • Comment — @gemini-code-assist: responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help — /gemini help: displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request introduces support for vLLM-style speculative decoding and prefix caching by adding dst_state_batch_indices, cu_seqlens, and num_accepted_tokens parameters, enabling varlen inputs and separate read/write state slots. However, a security audit identified several vulnerabilities related to missing or insufficient validation of input tensors and indices, primarily manifesting as potential out-of-bounds (OOB) memory access within CUDA kernels. This could lead to sensitive data leakage or memory corruption. Critical findings include an integer underflow risk with cu_seqlens, a regression where batch size validation against the state cache size was removed, and a general lack of bounds checking for user-supplied indices. It is crucial to address these OOB access risks, especially by restoring a safety check that was removed during refactoring.

Comment on lines +337 to +366
int64_t batch;
int64_t ntokens_mtp;

auto const state_cache_size = state.size(0);
auto const nheads = state.size(1);
auto const dim = state.size(2);
auto const dstate = state.size(3);
auto const ngroups = B.size(2);

FLASHINFER_CHECK(state_cache_size >= batch, "state.size(0) must be >= x.size(0)");
FLASHINFER_CHECK(nheads % ngroups == 0, "nheads must be divisible by ngroups");

// Check x shape and strides
CHECK_CUDA(x);
CHECK_DIM(4, x);
FLASHINFER_CHECK(x.size(2) == nheads, "x.size(2) must equal nheads");
FLASHINFER_CHECK(x.size(3) == dim, "x.size(3) must equal dim");
CHECK_LAST_DIM_CONTIGUOUS(x);
FLASHINFER_CHECK(x.stride(2) == dim, "x.stride(2) must equal dim, got ", x.stride(2),
" expected ", dim);
if (is_varlen) {
CHECK_DIM(3, x); // x: {total_tokens, nheads, dim}
FLASHINFER_CHECK(x.size(1) == nheads, "x.size(1) must equal nheads");
FLASHINFER_CHECK(x.size(2) == dim, "x.size(2) must equal dim");
CHECK_LAST_DIM_CONTIGUOUS(x);
FLASHINFER_CHECK(x.stride(1) == dim, "x.stride(1) must equal dim");
batch = cu_seqlens.value().size(0) - 1;
FLASHINFER_CHECK(cache_steps >= 1,
"cache_steps must be >= 1 in varlen mode (specifies max_seqlen)");
ntokens_mtp = cache_steps;
} else {
CHECK_DIM(4, x); // x: {batch, ntokens_mtp, nheads, dim}
batch = x.size(0);
ntokens_mtp = x.size(1);
FLASHINFER_CHECK(x.size(2) == nheads, "x.size(2) must equal nheads");
FLASHINFER_CHECK(x.size(3) == dim, "x.size(3) must equal dim");
CHECK_LAST_DIM_CONTIGUOUS(x);
FLASHINFER_CHECK(x.stride(2) == dim, "x.stride(2) must equal dim, got ", x.stride(2),
" expected ", dim);
}
Contributor


security-high high

The removal of the FLASHINFER_CHECK(state_cache_size >= batch, ...) validation in the multi-token prediction (MTP) path is a critical security regression. Without this check, if state_batch_indices is not provided and the input batch size exceeds the state_cache_size, the kernel will perform out-of-bounds memory access on the state tensor. This could lead to data leakage or corruption. This check is still present in the single-token path (run_selective_state_update_stp) and is essential for preventing OOB access. Please restore this validation check in run_selective_state_update_mtp for the non-varlen case.

FLASHINFER_CHECK(x.size(2) == dim, "x.size(2) must equal dim");
CHECK_LAST_DIM_CONTIGUOUS(x);
FLASHINFER_CHECK(x.stride(1) == dim, "x.stride(1) must equal dim");
batch = cu_seqlens.value().size(0) - 1;
Contributor


security-high high

The calculation of the batch size from cu_seqlens is vulnerable to an integer underflow if an empty cu_seqlens tensor is provided. Specifically, batch = cu_seqlens.value().size(0) - 1 will result in -1. When this value is assigned to params.batch (a uint32_t), it becomes 0xFFFFFFFF. This extremely large value is used as the grid dimension for the kernel launch and subsequently used within the kernel to index into cu_seqlens and other tensors, leading to out-of-bounds memory access. Please add a check to ensure cu_seqlens has at least one element: FLASHINFER_CHECK(cs.size(0) >= 1, "cu_seqlens must have at least one element");.
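The wraparound described in this finding is easy to reproduce in isolation. The sketch below uses `ctypes.c_uint32` as a stand-in for the assignment to the uint32 `params.batch` field, and `check_cu_seqlens_size` is a hypothetical name for the suggested guard:

```python
import ctypes

# An empty cu_seqlens gives size(0) - 1 == -1 in signed host arithmetic...
batch = 0 - 1  # cu_seqlens.size(0) == 0

# ...which wraps to 0xFFFFFFFF when stored in a uint32 field such as params.batch,
# and that value would then be used as a kernel grid dimension.
wrapped = ctypes.c_uint32(batch).value
assert wrapped == 0xFFFFFFFF  # 4294967295

# The suggested guard rejects the degenerate input before the cast:
def check_cu_seqlens_size(n):
    if n < 1:
        raise ValueError("cu_seqlens must have at least one element")
    return n - 1  # batch
```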

Comment on lines 110 to 111
auto const state_ptr_offset = state_batch * params.state_stride_batch + head * DIM * DSTATE;
state += state_ptr_offset;
Contributor


security-high high

The kernel retrieves indices from state_batch_indices, dst_state_batch_indices, and intermediate_state_indices and uses them to calculate offsets for accessing the state and intermediate_states tensors without any bounds checking. These indices are not validated to be within the valid range [0, state_cache_size). An attacker providing malicious indices could read from or write to arbitrary memory locations within the GPU's address space, potentially leading to sensitive data leakage from other users' sessions or memory corruption. Please implement bounds checks against params.state_cache_size before using these indices.
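A host-side guard of the kind this finding asks for could look like the following plain-Python sketch; `validate_indices` and the `[0, cache_size)` convention are illustrative stand-ins, not FlashInfer code (and an in-kernel check would still be needed wherever host validation can be bypassed):

```python
def validate_indices(indices, cache_size):
    """Reject any state index outside [0, cache_size) before it reaches the kernel."""
    for pos, idx in enumerate(indices):
        if not 0 <= idx < cache_size:
            raise IndexError(
                f"index {idx} at position {pos} out of range [0, {cache_size})"
            )

validate_indices([0, 3, 7], cache_size=8)      # all in range: passes silently
try:
    validate_indices([0, 8], cache_size=8)     # 8 is one past the end
except IndexError as e:
    assert "out of range" in str(e)
```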

Comment on lines 76 to 84
inline void validate_state_batch_indices(Optional<TensorView> const& state_batch_indices,
int64_t batch) {
if (!state_batch_indices.has_value()) return;
CHECK_DIM(1, (*state_batch_indices));
CHECK_CONTIGUOUS((*state_batch_indices));
FLASHINFER_CHECK(state_batch_indices.value().size(0) == batch,
"state_batch_indices.shape must be (", batch, ")");
auto const& sbi = state_batch_indices.value();
FLASHINFER_CHECK(sbi.dim() == 1 || sbi.dim() == 2, "state_batch_indices must be 1D or 2D, got ",
sbi.dim(), "D");
FLASHINFER_CHECK(sbi.size(0) >= batch, "state_batch_indices.size(0) must be >= batch (", batch,
")");
}
Contributor


security-medium medium

The validate_state_batch_indices function only validates the first dimension of the indices tensor. However, in multi-token mode, the kernel also accesses the second dimension using the step index, which can go up to cache_steps - 1. If the provided tensor is 2D and its second dimension is smaller than cache_steps, an out-of-bounds read will occur in the kernel. Please update the validation logic to check sbi.size(1) >= cache_steps when the tensor is 2D and used in the MTP path.

Comment on lines +551 to 559
if (num_accepted_tokens.has_value()) {
auto const& nat = num_accepted_tokens.value();
CHECK_CUDA(nat);
CHECK_DIM(1, nat);
CHECK_CONTIGUOUS(nat);
FLASHINFER_CHECK(nat.dtype().code == kDLInt && nat.dtype().bits == 32,
"num_accepted_tokens must be int32");
p.num_accepted_tokens = const_cast<void*>(nat.data_ptr());
}
Contributor


security-medium medium

Missing validation for the size of the num_accepted_tokens tensor. The kernel accesses this tensor using seq_idx, which ranges from 0 to batch - 1. If num_accepted_tokens.size(0) is less than batch, an out-of-bounds read will occur. Please add a validation check: FLASHINFER_CHECK(nat.size(0) >= batch, "num_accepted_tokens.size(0) must be >= batch");.

@aleozlx added the v0.6.6 (release blocker for 0.6.6) label on Mar 5, 2026
Contributor

@coderabbitai bot left a comment


🧹 Nitpick comments (1)
tests/mamba/test_selective_state_update_varlen.py (1)

80-162: Recommend adding architecture checks for consistency with other mamba tests.

While the varlen implementation provides fallback kernels for SM80+, other mamba tests (test_selective_state_update_stp.py, test_selective_state_update_mtp.py) explicitly use get_compute_capability() to document architecture support. Adding similar checks here would improve consistency and clarity. This is optional since tests will run on any CUDA GPU, but aligns with the testing pattern used elsewhere in the module.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/mamba/test_selective_state_update_varlen.py` around lines 80 - 162, Add
the same CUDA architecture guard used in other mamba tests: call
get_compute_capability() and skip the test (or class) when the compute
capability is below the minimum you support for the varlen path (mirror checks
in test_selective_state_update_stp.py/test_selective_state_update_mtp.py); place
this check at the start of TestSelectiveStateUpdateDstIndices or inside
test_dst_different_from_src so the test is skipped on unsupported GPUs,
referencing the test name TestSelectiveStateUpdateDstIndices and
test_dst_different_from_src to locate where to insert the guard.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@tests/mamba/test_selective_state_update_varlen.py`:
- Around line 80-162: Add the same CUDA architecture guard used in other mamba
tests: call get_compute_capability() and skip the test (or class) when the
compute capability is below the minimum you support for the varlen path (mirror
checks in
test_selective_state_update_stp.py/test_selective_state_update_mtp.py); place
this check at the start of TestSelectiveStateUpdateDstIndices or inside
test_dst_different_from_src so the test is skipped on unsupported GPUs,
referencing the test name TestSelectiveStateUpdateDstIndices and
test_dst_different_from_src to locate where to insert the guard.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 9cc54fc5-eea3-4d4e-8877-ad7540693e79

📥 Commits

Reviewing files that changed from the base of the PR and between 858d8ff and 85cb984.

📒 Files selected for processing (8)
  • csrc/flashinfer_mamba_binding.cu
  • csrc/selective_state_update.cu
  • flashinfer/mamba/selective_state_update.py
  • include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh
  • include/flashinfer/mamba/kernel_selective_state_update_stp.cuh
  • include/flashinfer/mamba/selective_state_update.cuh
  • tests/mamba/test_selective_state_update_varlen.py
  • tests/mamba/triton_reference/selective_state_update_varlen.py

@roikoren755 changed the title from "Selective State Update update" to "Add varlen and speculative decoding support to selective state update" on Mar 5, 2026
Contributor

@coderabbitai bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
benchmarks/routines/mamba.py (1)

302-379: ⚠️ Potential issue | 🟠 Major

Potential index out of bounds when allocating src_indices and dst_indices.

The ssm_state_cache_size is set to max(384, batch_size * 10), but varlen mode requires 2 * n_seqs * max_seqlen unique indices for non-overlapping src_indices and dst_indices. When 2 * batch_size * cache_steps > ssm_state_cache_size, line 367-369 will cause an index error.

Example: batch_size=10, cache_steps=50 requires 1000 indices but only 384 are available.

🐛 Proposed fix
     ## Prepare input tensors
-    ssm_state_cache_size = max(384, batch_size * 10)
+    if is_varlen:
+        # Varlen needs non-overlapping src and dst indices
+        ssm_state_cache_size = max(384, 2 * batch_size * cache_steps)
+    else:
+        ssm_state_cache_size = max(384, batch_size * 10)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/routines/mamba.py` around lines 302 - 379, ssm_state_cache_size
can be too small for varlen because src_indices and dst_indices slice 2 * n_seqs
* max_seqlen entries from perm; update the allocation to ensure capacity by
computing required = 2 * n_seqs * max_seqlen (or 2 * batch_size * max_seqlen)
and set ssm_state_cache_size = max(384, batch_size * 10, required) before
creating state_cache and perm, or alternatively guard the perm sampling by
generating torch.randperm(required, device=device) when in varlen mode so that
src_indices and dst_indices cannot index out of bounds (references:
ssm_state_cache_size, src_indices, dst_indices, perm, n_seqs, max_seqlen).
🧹 Nitpick comments (2)
csrc/selective_state_update.cu (2)

280-292: Minor inconsistency in CUDA validation between state_batch_indices and dst_state_batch_indices.

CHECK_CUDA is called for dst_state_batch_indices (line 288) but not for state_batch_indices (lines 280-285). For consistency, consider adding CHECK_CUDA validation to validate_state_batch_indices helper function so both tensors are validated uniformly.

♻️ Suggested improvement
 inline void validate_state_batch_indices(Optional<TensorView> const& state_batch_indices,
                                          int64_t batch, int64_t max_seqlen = 1) {
   if (!state_batch_indices.has_value()) return;
   auto const& sbi = state_batch_indices.value();
+  CHECK_CUDA(sbi);
   FLASHINFER_CHECK(sbi.dim() == 1 || sbi.dim() == 2, "state_batch_indices must be 1D or 2D, got ",
                    sbi.dim(), "D");

Then remove CHECK_CUDA(dsbi) from line 288 and line 556.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/selective_state_update.cu` around lines 280 - 292, state_batch_indices
is not validated with CHECK_CUDA while dst_state_batch_indices is; move the
CHECK_CUDA check into the existing validate_state_batch_indices helper so both
tensors are validated uniformly (call CHECK_CUDA on the tensor inside
validate_state_batch_indices), then remove the redundant CHECK_CUDA(dsbi) calls
that remain (e.g., the current CHECK_CUDA before assigning
p.dst_state_batch_indices and any other standalone CHECK_CUDA(dsbi) usage such
as the one later in the file).

668-683: Consider updating the error message to mention varlen mode.

The error message states "3 dimensions (single-token) or 4 dimensions (multi-token)" but doesn't mention that 3D with cu_seqlens is also valid for varlen multi-token mode. This could confuse users who provide 3D input but forget cu_seqlens.

📝 Suggested improvement
   } else {
     FLASHINFER_CHECK(false,
-                    "x must have 3 dimensions (single-token) or 4 dimensions (multi-token), got ",
+                    "x must have 3 dimensions (single-token, or varlen multi-token with cu_seqlens) "
+                    "or 4 dimensions (multi-token), got ",
                      x.dim());
   }
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/selective_state_update.cu` around lines 668 - 683, The error message in
the selection branch is misleading for varlen mode; update the FLASHINFER_CHECK
call so its message mentions that 3D input is valid either for single-token or
for varlen multi-token when cu_seqlens is provided (i.e., clarify "3 dimensions
(single-token) or 4 dimensions (multi-token) or 3 dimensions with cu_seqlens
(varlen/multi-token)"). Edit the FLASHINFER_CHECK invocation near the x.dim()
checks (referencing x.dim(), has_cu_seqlens / cu_seqlens) to include that
wording so users supplying 3D + cu_seqlens won't be confused.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Outside diff comments:
In `@benchmarks/routines/mamba.py`:
- Around line 302-379: ssm_state_cache_size can be too small for varlen because
src_indices and dst_indices slice 2 * n_seqs * max_seqlen entries from perm;
update the allocation to ensure capacity by computing required = 2 * n_seqs *
max_seqlen (or 2 * batch_size * max_seqlen) and set ssm_state_cache_size =
max(384, batch_size * 10, required) before creating state_cache and perm, or
alternatively guard the perm sampling by generating torch.randperm(required,
device=device) when in varlen mode so that src_indices and dst_indices cannot
index out of bounds (references: ssm_state_cache_size, src_indices, dst_indices,
perm, n_seqs, max_seqlen).
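
As a concrete illustration, the capacity rule described above can be sketched in a few lines of plain Python (the names `ssm_state_cache_size`, `n_seqs`, and `max_seqlen` mirror the benchmark; the helper function itself is hypothetical):

```python
def ssm_state_cache_capacity(batch_size: int, n_seqs: int, max_seqlen: int) -> int:
    # Varlen mode slices 2 * n_seqs * max_seqlen distinct slots out of a
    # permutation for src_indices and dst_indices, so the cache must hold
    # at least that many entries on top of the existing floor.
    required = 2 * n_seqs * max_seqlen
    return max(384, batch_size * 10, required)

# With short sequences the old floor still dominates; with long ones the
# varlen requirement takes over and prevents the out-of-bounds slice.
print(ssm_state_cache_capacity(8, 8, 4))   # 384
print(ssm_state_cache_capacity(8, 8, 64))  # 1024
```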

---

Nitpick comments:
In `@csrc/selective_state_update.cu`:
- Around line 280-292: state_batch_indices is not validated with CHECK_CUDA
while dst_state_batch_indices is; move the CHECK_CUDA check into the existing
validate_state_batch_indices helper so both tensors are validated uniformly
(call CHECK_CUDA on the tensor inside validate_state_batch_indices), then remove
the redundant CHECK_CUDA(dsbi) calls that remain (e.g., the current CHECK_CUDA
before assigning p.dst_state_batch_indices and any other standalone
CHECK_CUDA(dsbi) usage such as the one later in the file).
- Around line 668-683: The error message in the selection branch is misleading
for varlen mode; update the FLASHINFER_CHECK call so its message mentions that
3D input is valid either for single-token or for varlen multi-token when
cu_seqlens is provided (i.e., clarify "3 dimensions (single-token) or 4
dimensions (multi-token) or 3 dimensions with cu_seqlens (varlen/multi-token)").
Edit the FLASHINFER_CHECK invocation near the x.dim() checks (referencing
x.dim(), has_cu_seqlens / cu_seqlens) to include that wording so users supplying
3D + cu_seqlens won't be confused.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 0087c120-5ace-4992-b8f0-422e323dc90c

📥 Commits

Reviewing files that changed from the base of the PR and between 85cb984 and 1bb2fc4.

📒 Files selected for processing (3)
  • .gitignore
  • benchmarks/routines/mamba.py
  • csrc/selective_state_update.cu

@aleozlx
Collaborator

aleozlx commented Mar 9, 2026

/bot run

@aleozlx added the run-ci label Mar 9, 2026
@flashinfer-bot
Collaborator

GitLab MR !391 has been created, and the CI pipeline #45680534 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #45680534: 9/20 passed

@aleozlx removed the v0.6.6 release blocker for 0.6.6 label
@roikoren755 force-pushed the feat/selective-state-update-update branch from 1bb2fc4 to 17328df on March 15, 2026 13:55
@coderabbitai bot left a comment

Actionable comments posted: 5

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/routines/mamba.py`:
- Around line 296-300: When is_varlen is true the code needs 2 * n_seqs *
max_seqlen cache slots but state_cache is still sized independently; before
materializing the varlen src/dst indices (the perm[...] slices and subsequent
reshape) ensure state_cache is grown/resized to at least 2 * n_seqs * max_seqlen
(use n_seqs = batch_size and max_seqlen = cache_steps) so the second perm slice
and reshape won't be too short; apply the same change to the analogous block
around the perm/reshape at lines 360-370 so both varlen paths expand state_cache
before slicing.

In `@csrc/selective_state_update.cu`:
- Around line 77-87: The code only checks shapes for state_batch_indices (sbi)
but never verifies it's a CUDA tensor before later packing its raw pointer; add
a GPU-device check immediately after extracting sbi and shape checks (same place
where you validate sizes) to ensure sbi.is_cuda() (or the project macro
equivalent) and error out if not CUDA, mirroring the validation done for
dst_state_batch_indices so a host tensor is never dereferenced from device code.

In `@flashinfer/mamba/selective_state_update.py`:
- Around line 278-283: Reject variable-length inputs whose longest sequence
exceeds cache_steps by validating cu_seqlens before setting ntokens_mtp: when
is_varlen is True, compute the maximum span from cu_seqlens (e.g.,
max(cu_seqlens[i+1]-cu_seqlens[i]) or equivalent) and if it is greater than
cache_steps raise an error (or return) instead of assigning ntokens_mtp =
cache_steps; update the early branch around is_varlen/ntokens_mtp in
selective_state_update.py to perform this guard so the kernel never silently
truncates tails.

In `@include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh`:
- Around line 410-432: The dst-slot write path that uses dst_state_batch_indices
currently ignores params.update_state and still mutates
params.state/params.state_scale; modify the block guarded by "if
(has_dst_indices) { ... }" (the code that computes dst_idx, uses dst_state_ptr
and writes into params.state and the dst_scale write using params.state_scale
and sram.state_scale) to first check params.update_state (e.g., if
(params.update_state) before performing any writes) so that when update_state is
false the dst writes are skipped as well.
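
The `cache_steps` guard requested for `selective_state_update.py` above can be sketched in pure Python (the torch version would compute the spans as `cu_seqlens[1:] - cu_seqlens[:-1]`; plain lists stand in for the tensor here):

```python
def max_seqlen_from_cu_seqlens(cu_seqlens):
    # cu_seqlens is an exclusive prefix sum of per-sequence token counts,
    # e.g. [0, 3, 3, 8] describes sequences of length 3, 0, and 5.
    return max(b - a for a, b in zip(cu_seqlens, cu_seqlens[1:]))

def check_cache_steps(cu_seqlens, cache_steps):
    # Fail fast instead of letting the specialized kernel silently
    # truncate the tail of any sequence longer than cache_steps.
    longest = max_seqlen_from_cu_seqlens(cu_seqlens)
    if cache_steps < longest:
        raise ValueError(
            f"cache_steps ({cache_steps}) must be >= max sequence length ({longest})"
        )

check_cache_steps([0, 3, 3, 8], 8)  # ok: the longest span is 5
```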

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: c3b081c9-8956-4bda-b11e-48955319b1e7

📥 Commits

Reviewing files that changed from the base of the PR and between 1bb2fc4 and 17328df.

📒 Files selected for processing (10)
  • .gitignore
  • benchmarks/routines/mamba.py
  • csrc/flashinfer_mamba_binding.cu
  • csrc/selective_state_update.cu
  • flashinfer/mamba/selective_state_update.py
  • include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh
  • include/flashinfer/mamba/kernel_selective_state_update_stp.cuh
  • include/flashinfer/mamba/selective_state_update.cuh
  • tests/mamba/test_selective_state_update_varlen.py
  • tests/mamba/triton_reference/selective_state_update_varlen.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • include/flashinfer/mamba/selective_state_update.cuh

Comment on lines +296 to +300
    if is_varlen:
        n_seqs = batch_size
        max_seqlen = cache_steps
        total_tokens = n_seqs * max_seqlen


⚠️ Potential issue | 🟠 Major

Grow the cache before materializing varlen src/dst indices.

This path needs 2 * n_seqs * max_seqlen distinct slots, but the benchmark still sizes state_cache independently of max_seqlen. Once cache_steps gets large enough, the second perm[...] slice is too short and the reshape fails before timing starts.

🛠️ Proposed fix
-    ssm_state_cache_size = max(384, batch_size * 10)
+    ssm_state_cache_size = max(384, batch_size * 10)
+    if is_varlen:
+        ssm_state_cache_size = max(
+            ssm_state_cache_size, 2 * n_seqs * max_seqlen
+        )

Also applies to: 360-370

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/routines/mamba.py` around lines 296 - 300, When is_varlen is true
the code needs 2 * n_seqs * max_seqlen cache slots but state_cache is still
sized independently; before materializing the varlen src/dst indices (the
perm[...] slices and subsequent reshape) ensure state_cache is grown/resized to
at least 2 * n_seqs * max_seqlen (use n_seqs = batch_size and max_seqlen =
cache_steps) so the second perm slice and reshape won't be too short; apply the
same change to the analogous block around the perm/reshape at lines 360-370 so
both varlen paths expand state_cache before slicing.

Comment on lines +77 to +87
                                         int64_t batch, int64_t max_seqlen = 1) {
  if (!state_batch_indices.has_value()) return;
  CHECK_DIM(1, (*state_batch_indices));
  CHECK_CONTIGUOUS((*state_batch_indices));
  FLASHINFER_CHECK(state_batch_indices.value().size(0) == batch,
                   "state_batch_indices.shape must be (", batch, ")");
  auto const& sbi = state_batch_indices.value();
  FLASHINFER_CHECK(sbi.dim() == 1 || sbi.dim() == 2, "state_batch_indices must be 1D or 2D, got ",
                   sbi.dim(), "D");
  FLASHINFER_CHECK(sbi.size(0) >= batch, "state_batch_indices.size(0) must be >= batch (", batch,
                   ")");
  if (sbi.dim() == 2) {
    FLASHINFER_CHECK(sbi.size(1) >= max_seqlen,
                     "state_batch_indices.size(1) must be >= max_seqlen (", max_seqlen, ")");
  }

⚠️ Potential issue | 🟠 Major

Validate state_batch_indices on CUDA before packing its raw pointer.

After widening this helper to 1D/2D, state_batch_indices only gets shape checks here. Unlike dst_state_batch_indices, it never hits a later CHECK_CUDA, so a host tensor can still be dereferenced from device code.

🛠️ Proposed fix
 inline void validate_state_batch_indices(Optional<TensorView> const& state_batch_indices,
                                          int64_t batch, int64_t max_seqlen = 1) {
   if (!state_batch_indices.has_value()) return;
   auto const& sbi = state_batch_indices.value();
+  CHECK_CUDA(sbi);
   FLASHINFER_CHECK(sbi.dim() == 1 || sbi.dim() == 2, "state_batch_indices must be 1D or 2D, got ",
                    sbi.dim(), "D");
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/selective_state_update.cu` around lines 77 - 87, The code only checks
shapes for state_batch_indices (sbi) but never verifies it's a CUDA tensor
before later packing its raw pointer; add a GPU-device check immediately after
extracting sbi and shape checks (same place where you validate sizes) to ensure
sbi.is_cuda() (or the project macro equivalent) and error out if not CUDA,
mirroring the validation done for dst_state_batch_indices so a host tensor is
never dereferenced from device code.

Comment on lines +278 to +283
    if is_varlen:
        ntokens_mtp = cache_steps
    elif x.dim() == 4:
        ntokens_mtp = x.size(1)
    else:
        ntokens_mtp = 1

⚠️ Potential issue | 🟠 Major

Reject varlen inputs whose longest sequence exceeds cache_steps.

ntokens_mtp is specialized directly from cache_steps. If any cu_seqlens span is longer, the kernel only processes the prefix and leaves the tail tokens unwritten.

🛡️ Proposed guard
     if is_varlen:
+        max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item())
+        if cache_steps < max_seqlen:
+            raise ValueError(
+                f"cache_steps ({cache_steps}) must be >= max sequence length ({max_seqlen})"
+            )
         ntokens_mtp = cache_steps
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
    if is_varlen:
        ntokens_mtp = cache_steps
    elif x.dim() == 4:
        ntokens_mtp = x.size(1)
    else:
        ntokens_mtp = 1

    if is_varlen:
        max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item())
        if cache_steps < max_seqlen:
            raise ValueError(
                f"cache_steps ({cache_steps}) must be >= max sequence length ({max_seqlen})"
            )
        ntokens_mtp = cache_steps
    elif x.dim() == 4:
        ntokens_mtp = x.size(1)
    else:
        ntokens_mtp = 1
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/mamba/selective_state_update.py` around lines 278 - 283, Reject
variable-length inputs whose longest sequence exceeds cache_steps by validating
cu_seqlens before setting ntokens_mtp: when is_varlen is True, compute the
maximum span from cu_seqlens (e.g., max(cu_seqlens[i+1]-cu_seqlens[i]) or
equivalent) and if it is greater than cache_steps raise an error (or return)
instead of assigning ntokens_mtp = cache_steps; update the early branch around
is_varlen/ntokens_mtp in selective_state_update.py to perform this guard so the
kernel never silently truncates tails.

Comment on lines +410 to +432
if (state_batch != params.pad_slot_id) {
  if (has_dst_indices) {
    auto dst_idx = static_cast<int64_t>(
        dst_state_batch_indices[seq_idx * params.dst_state_batch_indices_stride_batch +
                                step * params.dst_state_batch_indices_stride_T]);
    if (dst_idx != params.pad_slot_id) {
      auto* dst_state_ptr = reinterpret_cast<state_t*>(params.state);
      for (int i = lane * load_state_t::count; i < DSTATE;
           i += warpSize * load_state_t::count) {
        auto* src = reinterpret_cast<load_state_t*>(&sram.state[dd][i]);
        *reinterpret_cast<load_state_t*>(
            &dst_state_ptr[dst_idx * params.state_stride_batch + head * DIM * DSTATE +
                           d * DSTATE + i]) = *src;
      }
      if constexpr (scaleState) {
        if (lane == 0) {
          auto* dst_scale = reinterpret_cast<state_scale_t*>(params.state_scale);
          dst_scale[dst_idx * params.state_scale_stride_batch + head * DIM + d] =
              sram.state_scale[dd];
        }
      }
    }
  } else if (has_intermediate) {

⚠️ Potential issue | 🟠 Major

Gate dst-slot writes on params.update_state.

disable_state_update=True currently suppresses only the final source-slot write. The new per-token dst_state_batch_indices path still stores into params.state, so verification runs mutate the cache anyway.

🛠️ Proposed fix
-          if (has_dst_indices) {
+          if (params.update_state && has_dst_indices) {
             auto dst_idx = static_cast<int64_t>(
                 dst_state_batch_indices[seq_idx * params.dst_state_batch_indices_stride_batch +
                                         step * params.dst_state_batch_indices_stride_T]);
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if (state_batch != params.pad_slot_id) {
  if (has_dst_indices) {
    auto dst_idx = static_cast<int64_t>(
        dst_state_batch_indices[seq_idx * params.dst_state_batch_indices_stride_batch +
                                step * params.dst_state_batch_indices_stride_T]);
    if (dst_idx != params.pad_slot_id) {
      auto* dst_state_ptr = reinterpret_cast<state_t*>(params.state);
      for (int i = lane * load_state_t::count; i < DSTATE;
           i += warpSize * load_state_t::count) {
        auto* src = reinterpret_cast<load_state_t*>(&sram.state[dd][i]);
        *reinterpret_cast<load_state_t*>(
            &dst_state_ptr[dst_idx * params.state_stride_batch + head * DIM * DSTATE +
                           d * DSTATE + i]) = *src;
      }
      if constexpr (scaleState) {
        if (lane == 0) {
          auto* dst_scale = reinterpret_cast<state_scale_t*>(params.state_scale);
          dst_scale[dst_idx * params.state_scale_stride_batch + head * DIM + d] =
              sram.state_scale[dd];
        }
      }
    }
  } else if (has_intermediate) {

if (state_batch != params.pad_slot_id) {
  if (params.update_state && has_dst_indices) {
    auto dst_idx = static_cast<int64_t>(
        dst_state_batch_indices[seq_idx * params.dst_state_batch_indices_stride_batch +
                                step * params.dst_state_batch_indices_stride_T]);
    if (dst_idx != params.pad_slot_id) {
      auto* dst_state_ptr = reinterpret_cast<state_t*>(params.state);
      for (int i = lane * load_state_t::count; i < DSTATE;
           i += warpSize * load_state_t::count) {
        auto* src = reinterpret_cast<load_state_t*>(&sram.state[dd][i]);
        *reinterpret_cast<load_state_t*>(
            &dst_state_ptr[dst_idx * params.state_stride_batch + head * DIM * DSTATE +
                           d * DSTATE + i]) = *src;
      }
      if constexpr (scaleState) {
        if (lane == 0) {
          auto* dst_scale = reinterpret_cast<state_scale_t*>(params.state_scale);
          dst_scale[dst_idx * params.state_scale_stride_batch + head * DIM + d] =
              sram.state_scale[dd];
        }
      }
    }
  } else if (has_intermediate) {
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh` around lines
410 - 432, The dst-slot write path that uses dst_state_batch_indices currently
ignores params.update_state and still mutates params.state/params.state_scale;
modify the block guarded by "if (has_dst_indices) { ... }" (the code that
computes dst_idx, uses dst_state_ptr and writes into params.state and the
dst_scale write using params.state_scale and sram.state_scale) to first check
params.update_state (e.g., if (params.update_state) before performing any
writes) so that when update_state is false the dst writes are skipped as well.

@coderabbitai bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (4)
include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh (1)

411-433: ⚠️ Potential issue | 🟠 Major

Gate dst-slot writes with params.update_state.

Line 412’s dst write branch still mutates params.state when disable_state_update=True (params.update_state == false).

🛠️ Suggested fix
-          if (has_dst_indices) {
+          if (params.update_state && has_dst_indices) {
             auto dst_idx = static_cast<int64_t>(
                 dst_state_batch_indices[seq_idx * params.dst_state_batch_indices_stride_batch +
                                         step * params.dst_state_batch_indices_stride_T]);
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh` around lines
411 - 433, The dst-slot write path currently always mutates params.state and
params.state_scale; guard those writes with the update flag by checking
params.update_state before performing the memory stores—i.e., inside the
has_dst_indices branch around the loop that writes into dst_state_ptr and around
the scaleState block that writes into dst_scale, skip the writes when
params.update_state is false so no mutation occurs when
disable_state_update=True; reference the symbols dst_state_ptr, params.state,
params.state_scale, params.update_state, and sram.state/sram.state_scale to
locate where to add the conditional.
csrc/selective_state_update.cu (2)

76-87: ⚠️ Potential issue | 🟠 Major

Require CUDA memory for state_batch_indices before pointer packing.

state_batch_indices is shape-validated but never device-validated before data_ptr() is consumed by CUDA kernels.

🛠️ Suggested fix
 inline void validate_state_batch_indices(Optional<TensorView> const& state_batch_indices,
                                          int64_t batch, int64_t max_seqlen = 1) {
   if (!state_batch_indices.has_value()) return;
   auto const& sbi = state_batch_indices.value();
+  CHECK_CUDA(sbi);
   FLASHINFER_CHECK(sbi.dim() == 1 || sbi.dim() == 2, "state_batch_indices must be 1D or 2D, got ",
                    sbi.dim(), "D");

Also applies to: 280-285, 548-553

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/selective_state_update.cu` around lines 76 - 87, The shape checks in
validate_state_batch_indices validate dims/sizes but do not ensure
state_batch_indices is resident on the CUDA device before its pointer is
consumed by kernels; add a device check (e.g., assert or FLASHINFER_CHECK that
the TensorView sbi is on CUDA: sbi.is_cuda() or sbi.device().is_cuda()) right
after retrieving sbi and before any code that will call data_ptr() and be passed
to CUDA kernels; apply the same CUDA-device validation to the other similar
validation sites that handle state_batch_indices before pointer packing.

358-377: ⚠️ Potential issue | 🔴 Critical

Varlen path should validate flattened leading dimensions across tensors.

In varlen mode, dt/B/C/z/out are not checked against x.size(0). A shorter tensor can be indexed out-of-bounds via bos + step.

🛡️ Suggested checks
   if (is_varlen) {
     CHECK_DIM(3, x);  // x: {total_tokens, nheads, dim}
+    int64_t const total_tokens = x.size(0);
     FLASHINFER_CHECK(x.size(1) == nheads, "x.size(1) must equal nheads");
     FLASHINFER_CHECK(x.size(2) == dim, "x.size(2) must equal dim");
@@
   if (is_varlen) {
     CHECK_DIM(3, dt);  // dt: {total_tokens, nheads, dim}
+    FLASHINFER_CHECK(dt.size(0) == x.size(0), "dt.size(0) must equal x.size(0) in varlen mode");
@@
   if (is_varlen) {
     CHECK_DIM(3, B);  // B: {total_tokens, ngroups, dstate}
+    FLASHINFER_CHECK(B.size(0) == x.size(0), "B.size(0) must equal x.size(0) in varlen mode");
@@
   if (is_varlen) {
     CHECK_DIM(3, C);  // C: {total_tokens, ngroups, dstate}
+    FLASHINFER_CHECK(C.size(0) == x.size(0), "C.size(0) must equal x.size(0) in varlen mode");
@@
     if (is_varlen) {
       CHECK_DIM(3, z_tensor);  // z: {total_tokens, nheads, dim}
+      FLASHINFER_CHECK(z_tensor.size(0) == x.size(0),
+                       "z.size(0) must equal x.size(0) in varlen mode");
@@
     if (is_varlen) {
       CHECK_DIM(3, output);  // out: {total_tokens, nheads, dim}
+      FLASHINFER_CHECK(output.size(0) == x.size(0),
+                       "out.size(0) must equal x.size(0) in varlen mode");

Also applies to: 386-472

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/selective_state_update.cu` around lines 358 - 377, In the is_varlen
branch add explicit validation that all tensors indexed by flattened token
positions (dt, B, C, z, out) have their leading flattened dimension equal to
x.size(0) (the total_tokens computed from x) so indexing with bos + step cannot
go out-of-bounds; use cu_seqlens.value().size(0)-1 (or the same total_tokens
variable) to compare against dt.size(0), B.size(0), C.size(0), z.size(0),
out.size(0) and emit FLASHINFER_CHECK errors when they differ; also ensure any
use of cache_steps/bos/step is guarded by these checks so bos + step <
x.size(0).
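
To see why the bounds invariant holds once those checks pass, here is a hypothetical pure-Python model of the varlen addressing (each token is reached as `bos + step` on the flattened token axis, with `bos` taken from `cu_seqlens`):

```python
def varlen_token_indices(cu_seqlens):
    # Each sequence reads tokens at bos + step for step in [0, seqlen).
    # If every tensor's leading dimension equals cu_seqlens[-1], all of
    # these flat indices are in bounds by construction.
    indices = []
    for seq in range(len(cu_seqlens) - 1):
        bos = cu_seqlens[seq]
        for step in range(cu_seqlens[seq + 1] - bos):
            indices.append(bos + step)
    return indices

idx = varlen_token_indices([0, 3, 3, 8])
assert max(idx) < 8  # total_tokens = cu_seqlens[-1]
```

A tensor whose leading dimension is shorter than `cu_seqlens[-1]` would be indexed past its end by the final steps of the last sequence, which is exactly what the suggested `FLASHINFER_CHECK`s rule out.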
flashinfer/mamba/selective_state_update.py (1)

286-291: ⚠️ Potential issue | 🟠 Major

Add a varlen guard: cache_steps must cover the longest sequence.

Without this check, varlen sequences longer than cache_steps are truncated by the specialized kernel token budget.

🛡️ Suggested guard
     if is_varlen:
+        max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item())
+        if cache_steps < max_seqlen:
+            raise ValueError(
+                f"cache_steps ({cache_steps}) must be >= max sequence length ({max_seqlen})"
+            )
         ntokens_mtp = cache_steps
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/mamba/selective_state_update.py` around lines 286 - 291, When
is_varlen is True the code assumes cache_steps covers the full token length but
doesn't verify it; add a guard in selective_state_update to compute the longest
sequence length from the input (e.g. derive max_seq_len from x.size(1) or the
provided lengths tensor) and assert or raise a clear ValueError if cache_steps <
max_seq_len so the specialized kernel token budget won't truncate sequences;
update the branch that sets ntokens_mtp (the is_varlen branch) to perform this
check and fail fast with an informative message referencing cache_steps and
max_seq_len.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@csrc/selective_state_update.cu`:
- Around line 561-569: Reject/ignore cu_seqlens when input tensor x is
non-varlen (x.dim() == 4): in the block that assigns p.cu_seqlens from
cu_seqlens (the code referencing cu_seqlens, cs, and p.cu_seqlens), add a guard
that checks x.dim() and if x.dim() == 4 then either DCHECK/FLASHINFER_CHECK that
cu_seqlens is not provided or simply do not set p.cu_seqlens (leave it null) and
log/raise an error; apply the same guard and behavior in the analogous block
around the later assignment at lines 669-674 so the kernel will not switch to
varlen addressing when x is 4D.
- Around line 570-579: When num_accepted_tokens is provided, ensure
state_batch_indices is a 2D CUDA tensor (not 1D) so the kernel can read
state_batch_indices[seq_idx, init_token_idx]; add checks after
FLASHINFER_CHECK(state_batch_indices.has_value()) to validate
state_batch_indices.dim()==2, CHECK_CUDA(state_batch_indices.value()),
CHECK_CONTIGUOUS(state_batch_indices.value()),
FLASHINFER_CHECK(state_batch_indices.value().size(0) >= batch, ...) and
FLASHINFER_CHECK(state_batch_indices.value().size(1) > 0, ...), then set
p.state_batch_indices =
const_cast<void*>(state_batch_indices.value().data_ptr()) alongside
p.num_accepted_tokens to ensure nonzero stride_T and correct indexing.
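
For intuition, the 2D indexing that check enables can be modeled in plain Python; the row-major strides and the `num_accepted - 1` read offset are illustrative assumptions, not the kernel's exact convention:

```python
def read_initial_state_slot(indices_flat, stride_batch, stride_t,
                            seq_idx, num_accepted):
    # 2D state_batch_indices[seq_idx, init_token_idx] laid out row-major:
    # flat offset = seq_idx * stride_batch + init_token_idx * stride_t.
    # Speculative decoding resumes from the slot written for the last
    # accepted token (assumed here to be index num_accepted - 1).
    init_token_idx = num_accepted - 1
    return indices_flat[seq_idx * stride_batch + init_token_idx * stride_t]

# Two sequences, three speculative slots each.
flat = [10, 11, 12, 20, 21, 22]
print(read_initial_state_slot(flat, 3, 1, seq_idx=1, num_accepted=2))  # 21
```

With a 1D tensor the token stride collapses to zero, which is why the comment insists on a 2D CUDA tensor with a nonzero `stride_T` whenever `num_accepted_tokens` is provided.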

---

Duplicate comments:
In `@csrc/selective_state_update.cu`:
- Around line 76-87: The shape checks in validate_state_batch_indices validate
dims/sizes but do not ensure state_batch_indices is resident on the CUDA device
before its pointer is consumed by kernels; add a device check (e.g., assert or
FLASHINFER_CHECK that the TensorView sbi is on CUDA: sbi.is_cuda() or
sbi.device().is_cuda()) right after retrieving sbi and before any code that will
call data_ptr() and be passed to CUDA kernels; apply the same CUDA-device
validation to the other similar validation sites that handle state_batch_indices
before pointer packing.
- Around line 358-377: In the is_varlen branch add explicit validation that all
tensors indexed by flattened token positions (dt, B, C, z, out) have their
leading flattened dimension equal to x.size(0) (the total_tokens computed from
x) so indexing with bos + step cannot go out-of-bounds; use
cu_seqlens.value().size(0)-1 (or the same total_tokens variable) to compare
against dt.size(0), B.size(0), C.size(0), z.size(0), out.size(0) and emit
FLASHINFER_CHECK errors when they differ; also ensure any use of
cache_steps/bos/step is guarded by these checks so bos + step < x.size(0).

In `@flashinfer/mamba/selective_state_update.py`:
- Around line 286-291: When is_varlen is True the code assumes cache_steps
covers the full token length but doesn't verify it; add a guard in
selective_state_update to compute the longest sequence length from the input
(e.g. derive max_seq_len from x.size(1) or the provided lengths tensor) and
assert or raise a clear ValueError if cache_steps < max_seq_len so the
specialized kernel token budget won't truncate sequences; update the branch that
sets ntokens_mtp (the is_varlen branch) to perform this check and fail fast with
an informative message referencing cache_steps and max_seq_len.

In `@include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh`:
- Around line 411-433: The dst-slot write path currently always mutates
params.state and params.state_scale; guard those writes with the update flag by
checking params.update_state before performing the memory stores—i.e., inside
the has_dst_indices branch around the loop that writes into dst_state_ptr and
around the scaleState block that writes into dst_scale, skip the writes when
params.update_state is false so no mutation occurs when
disable_state_update=True; reference the symbols dst_state_ptr, params.state,
params.state_scale, params.update_state, and sram.state/sram.state_scale to
locate where to add the conditional.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 3f3148d4-1983-4797-a149-e6055374c8c2

📥 Commits

Reviewing files that changed from the base of the PR and between 17328df and 30c2c3e.

📒 Files selected for processing (8)
  • csrc/flashinfer_mamba_binding.cu
  • csrc/selective_state_update.cu
  • csrc/selective_state_update_customize_config.jinja
  • flashinfer/jit/mamba/selective_state_update.py
  • flashinfer/mamba/selective_state_update.py
  • include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh
  • include/flashinfer/mamba/selective_state_update.cuh
  • tests/mamba/test_selective_state_update_varlen.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • include/flashinfer/mamba/selective_state_update.cuh
  • tests/mamba/test_selective_state_update_varlen.py

Comment on lines +561 to +569
  if (cu_seqlens.has_value()) {
    auto const& cs = cu_seqlens.value();
    CHECK_CUDA(cs);
    CHECK_DIM(1, cs);
    CHECK_CONTIGUOUS(cs);
    FLASHINFER_CHECK(cs.size(0) == batch + 1, "cu_seqlens.size(0) must equal n_sequences + 1 (",
                     batch + 1, ")");
    p.cu_seqlens = const_cast<void*>(cs.data_ptr());
  }

⚠️ Potential issue | 🔴 Critical

Disallow cu_seqlens for non-varlen (x.dim()==4) inputs.

cu_seqlens is packed whenever present, and the kernel switches to varlen addressing based on that pointer even if x is 4D.

🧭 Suggested fix
   bool const is_varlen = (x.dim() == 3 && cu_seqlens.has_value());
+  FLASHINFER_CHECK(!(cu_seqlens.has_value() && x.dim() != 3),
+                   "cu_seqlens is only supported when x is 3D varlen layout");
@@
-  if (cu_seqlens.has_value()) {
+  if (is_varlen) {
     auto const& cs = cu_seqlens.value();
     CHECK_CUDA(cs);
     CHECK_DIM(1, cs);

Also applies to: 669-674

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/selective_state_update.cu` around lines 561 - 569, Reject/ignore
cu_seqlens when input tensor x is non-varlen (x.dim() == 4): in the block that
assigns p.cu_seqlens from cu_seqlens (the code referencing cu_seqlens, cs, and
p.cu_seqlens), add a guard that checks x.dim() and if x.dim() == 4 then either
DCHECK/FLASHINFER_CHECK that cu_seqlens is not provided or simply do not set
p.cu_seqlens (leave it null) and log/raise an error; apply the same guard and
behavior in the analogous block around the later assignment at lines 669-674 so
the kernel will not switch to varlen addressing when x is 4D.

@roikoren755 force-pushed the feat/selective-state-update-update branch from 30c2c3e to fd5eed6 on March 18, 2026 14:23
@coderabbitai bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
flashinfer/mamba/selective_state_update.py (1)

274-282: ⚠️ Potential issue | 🟠 Major

Validate that all index tensors share one dtype.

The generated module specializes a single stateIndex_t, but this picks the first non-None dtype and forwards the others unchanged. If state_batch_indices, dst_state_batch_indices, and/or intermediate_state_indices differ, at least one tensor will be reinterpreted with the wrong element width in the CUDA path.

🛠️ Proposed fix
-    # Determine stateIndex dtype from index tensors, default to int32
-    stateIndex_dtype = torch.int32
-    if state_batch_indices is not None:
-        stateIndex_dtype = state_batch_indices.dtype
-    elif dst_state_batch_indices is not None:
-        stateIndex_dtype = dst_state_batch_indices.dtype
-    elif intermediate_state_indices is not None:
-        stateIndex_dtype = intermediate_state_indices.dtype
+    # Determine stateIndex dtype from index tensors, default to int32.
+    # All index tensors in one launch must share the same dtype because the
+    # generated module only specializes a single stateIndex_t.
+    index_dtypes = {
+        tensor.dtype
+        for tensor in (
+            state_batch_indices,
+            dst_state_batch_indices,
+            intermediate_state_indices,
+        )
+        if tensor is not None
+    }
+    if len(index_dtypes) > 1:
+        raise ValueError(
+            "state_batch_indices, dst_state_batch_indices, and "
+            "intermediate_state_indices must share the same dtype"
+        )
+    stateIndex_dtype = next(iter(index_dtypes), torch.int32)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/mamba/selective_state_update.py` around lines 274 - 282, The
current logic in selective_state_update.py picks the first non-None dtype into
stateIndex_dtype without ensuring the other index tensors match; update the
block that sets stateIndex_dtype to validate that all non-None tensors among
state_batch_indices, dst_state_batch_indices, and intermediate_state_indices
share the same dtype (compare their .dtype to the chosen stateIndex_dtype) and
raise a clear ValueError if any mismatch is found, so the CUDA path won't
reinterpret tensors with the wrong element width; alternatively, if you prefer
automatic fixes, cast any mismatched tensors to the chosen stateIndex_dtype
before continuing, but be explicit about which approach you take in the
error/logic.
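
The set-based dtype check the review proposes can be sketched in isolation. This is a hedged, standalone illustration: `Tensor` is a minimal stand-in for a torch tensor, and `pick_state_index_dtype` is a hypothetical helper name, not the actual function in `selective_state_update.py`.

```python
from collections import namedtuple

# Minimal stand-in for a torch tensor; only .dtype matters here.
Tensor = namedtuple("Tensor", ["dtype"])


def pick_state_index_dtype(*index_tensors, default="int32"):
    """Return the dtype shared by all non-None index tensors, or `default`
    if none are given; raise if the tensors disagree, since the generated
    module specializes only a single stateIndex_t."""
    dtypes = {t.dtype for t in index_tensors if t is not None}
    if len(dtypes) > 1:
        raise ValueError(
            "state_batch_indices, dst_state_batch_indices, and "
            "intermediate_state_indices must share the same dtype, got: "
            + ", ".join(sorted(dtypes))
        )
    return next(iter(dtypes), default)
```

The same pattern works with real torch dtypes, since `torch.dtype` objects are hashable and compare by identity.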
include/flashinfer/mamba/kernel_selective_state_update_stp.cuh (1)

670-727: ⚠️ Potential issue | 🔴 Critical

Reject padded destination slots before enabling SM90 writeback.

The vertical/horizontal SM90 paths only gate writeback on the source slot. If dst_state_batch_indices contains pad_slot_id, the producers still issue TMA writes to batch -1, and the vertical scaled-state path also stores decode scales through that padded destination.

🛠️ Proposed fix
-    auto const write_state = read_state && params.update_state;
+    auto const write_state =
+        read_state && params.update_state && dst_state_batch != params.pad_slot_id;
@@
-      if (params.update_state && state_batch != params.pad_slot_id) {
+      if (params.update_state && state_batch != params.pad_slot_id &&
+          dst_state_batch != params.pad_slot_id) {
         if (d < DIM) {
           state_scale[dst_state_batch * params.state_scale_stride_batch + head * DIM + d] =
               sram.state_scale[d];
         }

Also applies to: 787-791, 1058-1107

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/mamba/kernel_selective_state_update_stp.cuh` around lines
670 - 727, The code currently gates writeback only on the source slot
(read_state) causing producers to issue SM90/TMA writes when dst_state_batch
equals pad_slot_id; compute dst_state_batch (from
params.dst_state_batch_indices) early and reject padded destination slots by
changing the write enable to also require dst_state_batch != params.pad_slot_id
(e.g., set write_state = read_state && params.update_state && dst_state_batch !=
params.pad_slot_id), and use that write_state when instantiating/calling
producer_func_vertical/producer paths and before any scaled-state stores so no
SM90/TMA or state-scale writes occur for padded destination slots (apply same
guard to the horizontal/vertical SM90 paths and the other occurrences noted
around the dst-related blocks).
♻️ Duplicate comments (7)
flashinfer/mamba/selective_state_update.py (1)

286-291: ⚠️ Potential issue | 🟠 Major

Reject varlen sequences longer than cache_steps.

ntokens_mtp is specialized directly from cache_steps, while the MTP kernel iterates exactly that many steps. If any cu_seqlens span is longer, the tail tokens are silently skipped and their outputs/state never get written.

🛠️ Proposed fix
-    if is_varlen:
-        ntokens_mtp = cache_steps
+    if is_varlen:
+        max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item())
+        if cache_steps < max_seqlen:
+            raise ValueError(
+                f"cache_steps ({cache_steps}) must be >= max sequence length ({max_seqlen})"
+            )
+        ntokens_mtp = cache_steps
     elif x.dim() == 4:
         ntokens_mtp = x.size(1)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/mamba/selective_state_update.py` around lines 286 - 291, The
varlen branch sets ntokens_mtp = cache_steps unconditionally which lets the MTP
kernel iterate cache_steps and silently drop any varlen spans longer than
cache_steps; before assigning ntokens_mtp when is_varlen is true, check
cu_seqlens (the input cumulative sequence lengths for varlen batches) for any
span length > cache_steps and raise an explicit error (ValueError) if found;
otherwise keep ntokens_mtp = cache_steps. Reference the is_varlen branch,
ntokens_mtp, cache_steps, and cu_seqlens to locate where to add this validation.
benchmarks/routines/mamba.py (1)

296-307: ⚠️ Potential issue | 🟠 Major

Grow the cache before materializing varlen src/dst indices.

Varlen mode needs 2 * n_seqs * max_seqlen distinct slots, but ssm_state_cache_size is still independent of max_seqlen. Larger cache_steps makes the second perm[...] slice too short and the reshape fails before benchmarking starts.

🛠️ Proposed fix
     ## Prepare input tensors
     ssm_state_cache_size = max(384, batch_size * 10)
+    if is_varlen:
+        ssm_state_cache_size = max(
+            ssm_state_cache_size, 2 * n_seqs * max_seqlen
+        )
 
     # State cache: (total_entries, nheads, dim, dstate) - contiguous
     state_cache = torch.randn(

Also applies to: 360-369

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/routines/mamba.py` around lines 296 - 307, When is_varlen is true,
ssm_state_cache_size must be grown to accommodate the varlen indices: ensure
ssm_state_cache_size is set to at least max(current_min, 2 * n_seqs *
max_seqlen) before creating state_cache and before any perm[...] slicing; update
the calculation that sets ssm_state_cache_size (used to allocate state_cache) to
use max(384, batch_size * 10, 2 * n_seqs * max_seqlen) so the subsequent perm
slices and reshape succeed, and apply the same change in the later block that
also computes ssm_state_cache_size (the second occurrence around the other
perm/reshape usage).
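
The cache sizing this comment asks for reduces to a single max over three lower bounds. A hedged sketch (`ssm_state_cache_entries` is a hypothetical name; the benchmark script computes this inline):

```python
def ssm_state_cache_entries(batch_size, n_seqs, max_seqlen, is_varlen):
    """Minimum number of state-cache slots: the existing floor of
    max(384, batch_size * 10), grown in varlen mode so that
    2 * n_seqs * max_seqlen distinct src/dst slots exist."""
    size = max(384, batch_size * 10)
    if is_varlen:
        size = max(size, 2 * n_seqs * max_seqlen)
    return size
```

Without the varlen term, a large `max_seqlen` makes the second `perm[...]` slice shorter than the reshape target, which is the failure the review describes.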
include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh (1)

411-432: ⚠️ Potential issue | 🟠 Major

Honor update_state on the dst-slot path.

disable_state_update=True currently suppresses only the final source-slot write. When dst_state_batch_indices is present, this block still writes into params.state / params.state_scale, so verification runs mutate the cache anyway.

🛠️ Proposed fix
-          if (has_dst_indices) {
+          if (params.update_state && has_dst_indices) {
             auto dst_idx = static_cast<int64_t>(
                 dst_state_batch_indices[seq_idx * params.dst_state_batch_indices_stride_batch +
                                         step * params.dst_state_batch_indices_stride_T]);
             if (dst_idx != params.pad_slot_id) {
               auto* dst_state_ptr = reinterpret_cast<state_t*>(params.state);
@@
-          } else if (has_intermediate) {
+          } else if (has_intermediate) {
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh` around lines
411 - 432, The dst-slot write path still updates params.state and
params.state_scale even when state updates should be disabled; wrap the writes
inside the dst-state block (the loop writing into params.state and the
scaleState branch that writes into params.state_scale) with a guard that
respects the update flag (e.g., if constexpr (!disable_state_update) or the
existing update_state template/flag), so that when disable_state_update is true
no writes occur to dst_state_batch_indices/dst_idx -> params.state or
params.state_scale; reference dst_state_batch_indices, dst_idx, params.state,
params.state_scale, scaleState, sram.state, load_state_t and the existing
dst-slot write loop to locate the changes.
csrc/selective_state_update.cu (4)

570-579: ⚠️ Potential issue | 🟠 Major

Require 2D state_batch_indices when num_accepted_tokens is provided.

The kernel path uses per-token accepted offsets; with 1D state_batch_indices, state_batch_indices_stride_T becomes 0 (Line 552), so accepted-token indexing is ignored.

🛠️ Suggested fix
   if (num_accepted_tokens.has_value()) {
@@
     FLASHINFER_CHECK(state_batch_indices.has_value(),
                      "state_batch_indices is required when num_accepted_tokens is provided");
+    FLASHINFER_CHECK(state_batch_indices.value().dim() == 2,
+                     "state_batch_indices must be 2D when num_accepted_tokens is provided");
     p.num_accepted_tokens = const_cast<void*>(nat.data_ptr());
   }

561-569: ⚠️ Potential issue | 🔴 Critical

Reject cu_seqlens for 4D input and only pack it in varlen mode.

cu_seqlens is currently packed whenever present (Line 561), while dispatcher still routes x.dim()==4 into MTP (Line 669). This can incorrectly enable varlen addressing for non-varlen layout.

🛠️ Suggested fix
   bool const has_cu_seqlens = cu_seqlens.has_value();
+  FLASHINFER_CHECK(!(has_cu_seqlens && x.dim() != 3),
+                   "cu_seqlens is only supported when x is 3D varlen layout");
@@
-  if (cu_seqlens.has_value()) {
+  if (is_varlen) {
     auto const& cs = cu_seqlens.value();
     CHECK_CUDA(cs);
     CHECK_DIM(1, cs);

Also applies to: 664-670

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/selective_state_update.cu` around lines 561 - 569, The code currently
packs cu_seqlens into p.cu_seqlens whenever present, which allows varlen
addressing for inputs with x.dim()==4 (MTP); change the logic so cu_seqlens is
only accepted and assigned when the input is in varlen mode (i.e., not 4D).
Concretely, in the block handling cu_seqlens (the code that reads cs and sets
p.cu_seqlens) add a guard that rejects or ignores cu_seqlens if x.dim() == 4 (or
check the varlen flag used by the dispatcher) and only const_cast and assign to
p.cu_seqlens when varlen is true; apply the same change to the other symmetric
packing site that currently always assigns p.cu_seqlens.

358-472: ⚠️ Potential issue | 🔴 Critical

Varlen mode is missing first-dimension token-count consistency checks.

In varlen branches, dt, B, C, optional z, and optional out are not checked against x.size(0) on their first dimension. This can allow undersized tensors and out-of-bounds access when indexed by token offsets.

🛠️ Suggested fix
   if (is_varlen) {
     CHECK_DIM(3, x);  // x: {total_tokens, nheads, dim}
+    auto const total_tokens = x.size(0);
     FLASHINFER_CHECK(x.size(1) == nheads, "x.size(1) must equal nheads");
     FLASHINFER_CHECK(x.size(2) == dim, "x.size(2) must equal dim");
@@
   if (is_varlen) {
     CHECK_DIM(3, dt);  // dt: {total_tokens, nheads, dim}
+    FLASHINFER_CHECK(dt.size(0) == x.size(0), "dt.size(0) must equal total_tokens");
@@
   if (is_varlen) {
     CHECK_DIM(3, B);  // B: {total_tokens, ngroups, dstate}
+    FLASHINFER_CHECK(B.size(0) == x.size(0), "B.size(0) must equal total_tokens");
@@
   if (is_varlen) {
     CHECK_DIM(3, C);  // C: {total_tokens, ngroups, dstate}
+    FLASHINFER_CHECK(C.size(0) == x.size(0), "C.size(0) must equal total_tokens");
@@
     if (is_varlen) {
       CHECK_DIM(3, z_tensor);  // z: {total_tokens, nheads, dim}
+      FLASHINFER_CHECK(z_tensor.size(0) == x.size(0), "z.size(0) must equal total_tokens");
@@
     if (is_varlen) {
       CHECK_DIM(3, output);  // out: {total_tokens, nheads, dim}
+      FLASHINFER_CHECK(output.size(0) == x.size(0), "out.size(0) must equal total_tokens");
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/selective_state_update.cu` around lines 358 - 472, The varlen branches
fail to validate that dt, B, C, z (z_tensor) and out (output) have their first
dimension equal to x.size(0) (total_tokens), which risks OOB when indexing by
token offsets; update the checks inside each is_varlen block to assert
dt.size(0) == x.size(0), B.size(0) == x.size(0), C.size(0) == x.size(0), and if
present z_tensor.size(0) == x.size(0) and output.size(0) == x.size(0) (use the
same FLASHINFER_CHECK style as other checks), referencing the existing symbols
dt, B, C, z_tensor, output and x.size(0)/total_tokens to locate where to add
these assertions.

76-87: ⚠️ Potential issue | 🔴 Critical

Add CUDA-device validation for state_batch_indices before pointer packing.

state_batch_indices is used to populate raw kernel pointers (Line 282 and Line 550), but validate_state_batch_indices never enforces CUDA residency. A host tensor here can be dereferenced from device code.

🛠️ Suggested fix
 inline void validate_state_batch_indices(Optional<TensorView> const& state_batch_indices,
                                          int64_t batch, int64_t max_seqlen = 1) {
   if (!state_batch_indices.has_value()) return;
   auto const& sbi = state_batch_indices.value();
+  CHECK_CUDA(sbi);
   FLASHINFER_CHECK(sbi.dim() == 1 || sbi.dim() == 2, "state_batch_indices must be 1D or 2D, got ",
                    sbi.dim(), "D");
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/selective_state_update.cu` around lines 76 - 87,
validate_state_batch_indices currently only checks shape/size but not device
residency, so a host tensor can be used when packing raw kernel pointers (e.g.,
where state_batch_indices is dereferenced to build device pointers for kernels).
Update validate_state_batch_indices to assert the tensor is on CUDA (check
sbi.is_cuda() or sbi.device().is_cuda()) before returning; if not, raise a
FLASHINFER_CHECK/appropriate error explaining it must be a CUDA tensor. Keep the
existing shape/size checks (sbi.dim(), sbi.size(...)) and apply this device
check early (before any pointer-packing sites that consume state_batch_indices).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@csrc/selective_state_update.cu`:
- Around line 346-367: When is_varlen is true, validate cu_seqlens length before
computing batch: check that cu_seqlens.has_value() and cu_seqlens->size(0) >= 2
(so batch = cu_seqlens.value().size(0) - 1 is non-negative) and emit a
FLASHINFER_CHECK with a clear message if not; update the same guard in the other
varlen block that computes batch (the later block around ntokens/offset
handling) to avoid deriving batch = -1 from an empty cu_seqlens. Ensure you
perform this check prior to assigning batch and before any subsequent uses of
batch or indexing into cu_seqlens.

In `@include/flashinfer/mamba/kernel_selective_state_update_stp.cuh`:
- Around line 121-136: The code computes
dst_state_batch/dst_state/dst_state_scale from dst_sbi but then only uses the
source-slot write guards, so if dst_state_batch_indices contains a padded slot
(pad_slot_id) the kernel can write through a negative/invalid cache index; fix
by computing a boolean dst_valid (e.g., dst_sbi != nullptr && dst_state_batch !=
pad_slot_id) and use that same dst_valid wherever writes to dst_state or
dst_state_scale occur (mirror the existing source-slot guard logic), and ensure
any arithmetic that constructs dst_state and dst_state_scale is only used when
dst_valid is true to avoid creating/using invalid pointers (apply same change
around the other occurrences mentioned: the blocks around lines with
dst_state/dst_state_scale at the other offsets).

---

Outside diff comments:
In `@flashinfer/mamba/selective_state_update.py`:
- Around line 274-282: The current logic in selective_state_update.py picks the
first non-None dtype into stateIndex_dtype without ensuring the other index
tensors match; update the block that sets stateIndex_dtype to validate that all
non-None tensors among state_batch_indices, dst_state_batch_indices, and
intermediate_state_indices share the same dtype (compare their .dtype to the
chosen stateIndex_dtype) and raise a clear ValueError if any mismatch is found,
so the CUDA path won't reinterpret tensors with the wrong element width;
alternatively, if you prefer automatic fixes, cast any mismatched tensors to the
chosen stateIndex_dtype before continuing, but be explicit about which approach
you take in the error/logic.

In `@include/flashinfer/mamba/kernel_selective_state_update_stp.cuh`:
- Around line 670-727: The code currently gates writeback only on the source
slot (read_state) causing producers to issue SM90/TMA writes when
dst_state_batch equals pad_slot_id; compute dst_state_batch (from
params.dst_state_batch_indices) early and reject padded destination slots by
changing the write enable to also require dst_state_batch != params.pad_slot_id
(e.g., set write_state = read_state && params.update_state && dst_state_batch !=
params.pad_slot_id), and use that write_state when instantiating/calling
producer_func_vertical/producer paths and before any scaled-state stores so no
SM90/TMA or state-scale writes occur for padded destination slots (apply same
guard to the horizontal/vertical SM90 paths and the other occurrences noted
around the dst-related blocks).

---

Duplicate comments:
In `@benchmarks/routines/mamba.py`:
- Around line 296-307: When is_varlen is true, ssm_state_cache_size must be
grown to accommodate the varlen indices: ensure ssm_state_cache_size is set to
at least max(current_min, 2 * n_seqs * max_seqlen) before creating state_cache
and before any perm[...] slicing; update the calculation that sets
ssm_state_cache_size (used to allocate state_cache) to use max(384, batch_size *
10, 2 * n_seqs * max_seqlen) so the subsequent perm slices and reshape succeed,
and apply the same change in the later block that also computes
ssm_state_cache_size (the second occurrence around the other perm/reshape
usage).

In `@csrc/selective_state_update.cu`:
- Around line 561-569: The code currently packs cu_seqlens into p.cu_seqlens
whenever present, which allows varlen addressing for inputs with x.dim()==4
(MTP); change the logic so cu_seqlens is only accepted and assigned when the
input is in varlen mode (i.e., not 4D). Concretely, in the block handling
cu_seqlens (the code that reads cs and sets p.cu_seqlens) add a guard that
rejects or ignores cu_seqlens if x.dim() == 4 (or check the varlen flag used by
the dispatcher) and only const_cast and assign to p.cu_seqlens when varlen is
true; apply the same change to the other symmetric packing site that currently
always assigns p.cu_seqlens.
- Around line 358-472: The varlen branches fail to validate that dt, B, C, z
(z_tensor) and out (output) have their first dimension equal to x.size(0)
(total_tokens), which risks OOB when indexing by token offsets; update the
checks inside each is_varlen block to assert dt.size(0) == x.size(0), B.size(0)
== x.size(0), C.size(0) == x.size(0), and if present z_tensor.size(0) ==
x.size(0) and output.size(0) == x.size(0) (use the same FLASHINFER_CHECK style
as other checks), referencing the existing symbols dt, B, C, z_tensor, output
and x.size(0)/total_tokens to locate where to add these assertions.
- Around line 76-87: validate_state_batch_indices currently only checks
shape/size but not device residency, so a host tensor can be used when packing
raw kernel pointers (e.g., where state_batch_indices is dereferenced to build
device pointers for kernels). Update validate_state_batch_indices to assert the
tensor is on CUDA (check sbi.is_cuda() or sbi.device().is_cuda()) before
returning; if not, raise a FLASHINFER_CHECK/appropriate error explaining it must
be a CUDA tensor. Keep the existing shape/size checks (sbi.dim(), sbi.size(...))
and apply this device check early (before any pointer-packing sites that consume
state_batch_indices).

In `@flashinfer/mamba/selective_state_update.py`:
- Around line 286-291: The varlen branch sets ntokens_mtp = cache_steps
unconditionally which lets the MTP kernel iterate cache_steps and silently drop
any varlen spans longer than cache_steps; before assigning ntokens_mtp when
is_varlen is true, check cu_seqlens (the input cumulative sequence lengths for
varlen batches) for any span length > cache_steps and raise an explicit error
(ValueError) if found; otherwise keep ntokens_mtp = cache_steps. Reference the
is_varlen branch, ntokens_mtp, cache_steps, and cu_seqlens to locate where to
add this validation.

In `@include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh`:
- Around line 411-432: The dst-slot write path still updates params.state and
params.state_scale even when state updates should be disabled; wrap the writes
inside the dst-state block (the loop writing into params.state and the
scaleState branch that writes into params.state_scale) with a guard that
respects the update flag (e.g., if constexpr (!disable_state_update) or the
existing update_state template/flag), so that when disable_state_update is true
no writes occur to dst_state_batch_indices/dst_idx -> params.state or
params.state_scale; reference dst_state_batch_indices, dst_idx, params.state,
params.state_scale, scaleState, sram.state, load_state_t and the existing
dst-slot write loop to locate the changes.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: ec58034c-444e-40c0-b640-dabfc9d0299a

📥 Commits

Reviewing files that changed from the base of the PR and between 30c2c3e and fd5eed6.

📒 Files selected for processing (12)
  • .gitignore
  • benchmarks/routines/mamba.py
  • csrc/flashinfer_mamba_binding.cu
  • csrc/selective_state_update.cu
  • csrc/selective_state_update_customize_config.jinja
  • flashinfer/jit/mamba/selective_state_update.py
  • flashinfer/mamba/selective_state_update.py
  • include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh
  • include/flashinfer/mamba/kernel_selective_state_update_stp.cuh
  • include/flashinfer/mamba/selective_state_update.cuh
  • tests/mamba/test_selective_state_update_varlen.py
  • tests/mamba/triton_reference/selective_state_update_varlen.py
🚧 Files skipped from review as they are similar to previous changes (3)
  • include/flashinfer/mamba/selective_state_update.cuh
  • csrc/selective_state_update_customize_config.jinja
  • .gitignore

Comment on lines +346 to +367
bool const is_varlen = (x.dim() == 3 && cu_seqlens.has_value());
// Extract dimensions from input tensors
auto const batch = x.size(0);
auto const ntokens_mtp = x.size(1);
int64_t batch;
int64_t ntokens_mtp;

auto const state_cache_size = state.size(0);
auto const nheads = state.size(1);
auto const dim = state.size(2);
auto const dstate = state.size(3);
auto const ngroups = B.size(2);

FLASHINFER_CHECK(state_cache_size >= batch, "state.size(0) must be >= x.size(0)");
FLASHINFER_CHECK(nheads % ngroups == 0, "nheads must be divisible by ngroups");

// Check x shape and strides
CHECK_CUDA(x);
CHECK_DIM(4, x);
FLASHINFER_CHECK(x.size(2) == nheads, "x.size(2) must equal nheads");
FLASHINFER_CHECK(x.size(3) == dim, "x.size(3) must equal dim");
CHECK_LAST_DIM_CONTIGUOUS(x);
FLASHINFER_CHECK(x.stride(2) == dim, "x.stride(2) must equal dim, got ", x.stride(2),
" expected ", dim);
if (is_varlen) {
CHECK_DIM(3, x); // x: {total_tokens, nheads, dim}
FLASHINFER_CHECK(x.size(1) == nheads, "x.size(1) must equal nheads");
FLASHINFER_CHECK(x.size(2) == dim, "x.size(2) must equal dim");
CHECK_LAST_DIM_CONTIGUOUS(x);
FLASHINFER_CHECK(x.stride(1) == dim, "x.stride(1) must equal dim");
batch = cu_seqlens.value().size(0) - 1;
FLASHINFER_CHECK(cache_steps >= 1,
"cache_steps must be >= 1 in varlen mode (specifies max_seqlen)");
ntokens_mtp = cache_steps;
Contributor


⚠️ Potential issue | 🟠 Major

Validate cu_seqlens length before deriving batch.

batch is computed as cu_seqlens.size(0) - 1 on Line 364 before enforcing a minimum length. If cu_seqlens is empty, batch becomes -1 and downstream checks can pass incorrectly.

🛠️ Suggested fix
   if (is_varlen) {
+    auto const& cs = cu_seqlens.value();
+    CHECK_CUDA(cs);
+    CHECK_DIM(1, cs);
+    CHECK_CONTIGUOUS(cs);
+    FLASHINFER_CHECK(cs.size(0) >= 1, "cu_seqlens must have at least one element");
     CHECK_DIM(3, x);  // x: {total_tokens, nheads, dim}
@@
-    batch = cu_seqlens.value().size(0) - 1;
+    batch = cs.size(0) - 1;
@@
-  if (cu_seqlens.has_value()) {
-    auto const& cs = cu_seqlens.value();
-    CHECK_CUDA(cs);
-    CHECK_DIM(1, cs);
-    CHECK_CONTIGUOUS(cs);
+  if (is_varlen) {
+    auto const& cs = cu_seqlens.value();
     FLASHINFER_CHECK(cs.size(0) == batch + 1, "cu_seqlens.size(0) must equal n_sequences + 1 (",
                      batch + 1, ")");

Also applies to: 561-567

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/selective_state_update.cu` around lines 346 - 367, When is_varlen is
true, validate cu_seqlens length before computing batch: check that
cu_seqlens.has_value() and cu_seqlens->size(0) >= 2 (so batch =
cu_seqlens.value().size(0) - 1 is non-negative) and emit a FLASHINFER_CHECK with
a clear message if not; update the same guard in the other varlen block that
computes batch (the later block around ntokens/offset handling) to avoid
deriving batch = -1 from an empty cu_seqlens. Ensure you perform this check
prior to assigning batch and before any subsequent uses of batch or indexing
into cu_seqlens.
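
The guard this comment asks for is small but easy to get wrong: `len(cu_seqlens) - 1` on an empty tensor yields -1, which then slips past downstream `>=` checks. A hedged Python sketch with a plain list standing in for the tensor (`derive_batch` is a hypothetical name):

```python
def derive_batch(cu_seqlens):
    """Number of sequences is len(cu_seqlens) - 1, guarded so that an
    empty or single-element cu_seqlens cannot produce batch <= 0."""
    if cu_seqlens is None or len(cu_seqlens) < 2:
        raise ValueError(
            "cu_seqlens must have at least 2 elements (n_sequences + 1)"
        )
    return len(cu_seqlens) - 1
```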

Comment on lines +121 to +136
auto const* __restrict__ dst_sbi =
reinterpret_cast<stateIndex_t const*>(params.dst_state_batch_indices);
auto const dst_state_batch =
dst_sbi ? static_cast<int64_t>(dst_sbi[batch * params.dst_state_batch_indices_stride_batch])
: state_batch;
auto const state_ptr_offset = state_batch * params.state_stride_batch + head * DIM * DSTATE;
state += state_ptr_offset;
auto* __restrict__ dst_state = reinterpret_cast<state_t*>(params.state) +
dst_state_batch * params.state_stride_batch + head * DIM * DSTATE;
if constexpr (scaleState) {
state_scale += state_batch * params.state_scale_stride_batch + head * DIM;
}
[[maybe_unused]] auto* __restrict__ dst_state_scale =
scaleState ? reinterpret_cast<state_scale_t*>(params.state_scale) +
dst_state_batch * params.state_scale_stride_batch + head * DIM
: nullptr;
Contributor


⚠️ Potential issue | 🔴 Critical

Guard simple-kernel writeback when dst_state_batch_indices is padded.

dst_state / dst_state_scale are derived from dst_state_batch, but every write guard only checks the source slot. If a caller uses pad_slot_id in dst_state_batch_indices, this path writes through a negative cache index.

🛠️ Proposed fix
   auto const dst_state_batch =
       dst_sbi ? static_cast<int64_t>(dst_sbi[batch * params.dst_state_batch_indices_stride_batch])
               : state_batch;
+  auto const dst_writable = dst_state_batch != params.pad_slot_id;
   auto const state_ptr_offset = state_batch * params.state_stride_batch + head * DIM * DSTATE;
   state += state_ptr_offset;
-  auto* __restrict__ dst_state = reinterpret_cast<state_t*>(params.state) +
-                                 dst_state_batch * params.state_stride_batch + head * DIM * DSTATE;
+  auto* __restrict__ dst_state =
+      dst_writable
+          ? reinterpret_cast<state_t*>(params.state) +
+                dst_state_batch * params.state_stride_batch + head * DIM * DSTATE
+          : nullptr;
@@
-      if (!scaleState && params.update_state && state_batch != params.pad_slot_id) {
+      if (!scaleState && params.update_state && state_batch != params.pad_slot_id &&
+          dst_writable) {
         *reinterpret_cast<load_state_t*>(&dst_state[d * DSTATE + i]) = rState;
       }
@@
-      if (params.update_state && state_batch != params.pad_slot_id) {
+      if (params.update_state && state_batch != params.pad_slot_id && dst_writable) {
@@
-    if (params.update_state && state_batch != params.pad_slot_id) {
+    if (params.update_state && state_batch != params.pad_slot_id && dst_writable) {
       for (int l = lane; l < rowsPerWarp; l += warpSize) {

Also applies to: 244-246, 254-270, 297-303

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/mamba/kernel_selective_state_update_stp.cuh` around lines
121 - 136, The code computes dst_state_batch/dst_state/dst_state_scale from
dst_sbi but then only uses the source-slot write guards, so if
dst_state_batch_indices contains a padded slot (pad_slot_id) the kernel can
write through a negative/invalid cache index; fix by computing a boolean
dst_valid (e.g., dst_sbi != nullptr && dst_state_batch != pad_slot_id) and use
that same dst_valid wherever writes to dst_state or dst_state_scale occur
(mirror the existing source-slot guard logic), and ensure any arithmetic that
constructs dst_state and dst_state_scale is only used when dst_valid is true to
avoid creating/using invalid pointers (apply same change around the other
occurrences mentioned: the blocks around lines with dst_state/dst_state_scale at
the other offsets).

@aleozlx
Collaborator

aleozlx commented Mar 19, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !391 has been updated with latest changes, and the CI pipeline #46539983 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #46539983: 13/20 passed

@roikoren755 roikoren755 force-pushed the feat/selective-state-update-update branch from fd5eed6 to b6e179b Compare March 20, 2026 10:30
@aleozlx
Collaborator

aleozlx commented Mar 20, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !391 has been updated with latest changes, and the CI pipeline #46618531 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #46618531: 14/20 passed

@aleozlx
Collaborator

aleozlx commented Mar 20, 2026

tests clean

@aleozlx aleozlx enabled auto-merge (squash) March 20, 2026 23:12
@aleozlx aleozlx merged commit 6fef570 into flashinfer-ai:main Mar 21, 2026
31 of 33 checks passed