Add varlen and speculative decoding support to selective state update #2700
Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. This behavior can be configured in the CodeRabbit settings.
📝 Walkthrough
Adds varlen (variable-length sequence) and speculative-decoding support to selective_state_update by introducing optional dst_state_batch_indices, cu_seqlens, and num_accepted_tokens arguments.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Caller
    participant Python as PythonWrapper
    participant Binding as CppBinding
    participant Dispatcher
    participant Kernel as CUDAKernel
    participant GPUState as GPUStateMemory
    Caller->>Python: selective_state_update(..., dst_state_batch_indices?, cu_seqlens?, num_accepted_tokens?)
    Python->>Binding: call native binding with optional args
    Binding->>Dispatcher: pack params, detect varlen (cu_seqlens + x.dim)
    Dispatcher->>Kernel: launch STP or MTP/varlen kernel with cu_seqlens/num_accepted_tokens/dst_indices
    Kernel->>GPUState: read source state (state_batch_indices or inferred)
    Kernel->>Kernel: compute updates per-token/sequence (respect cu_seqlens, num_accepted_tokens)
    Kernel->>GPUState: write updates to dst_state_batch_indices or intermediate/out buffers
    Kernel-->>Dispatcher: completion
    Dispatcher-->>Binding: return status/output
    Binding-->>Python: return updated tensor
    Python-->>Caller: deliver result
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Summary of Changes: Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed: this pull request significantly enhances the selective_state_update kernel with varlen and speculative-decoding support.
Code Review
This pull request introduces support for vLLM-style speculative decoding and prefix caching by adding dst_state_batch_indices, cu_seqlens, and num_accepted_tokens parameters, enabling varlen inputs and separate read/write state slots. However, a security audit identified several vulnerabilities related to missing or insufficient validation of input tensors and indices, primarily manifesting as potential out-of-bounds (OOB) memory access within CUDA kernels. This could lead to sensitive data leakage or memory corruption. Critical findings include an integer underflow risk with cu_seqlens, a regression where batch size validation against the state cache size was removed, and a general lack of bounds checking for user-supplied indices. It is crucial to address these OOB access risks, especially by restoring a safety check that was removed during refactoring.
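For reviewers unfamiliar with the varlen convention these parameters follow, here is a minimal pure-Python sketch of how cu_seqlens packs variable-length sequences into one flattened input (the actual kernels operate on CUDA int32 tensors; the list below only illustrates the indexing convention):

```python
# cu_seqlens holds cumulative sequence boundaries: sequence i occupies
# rows [cu_seqlens[i], cu_seqlens[i+1]) of the packed {total_tokens, nheads, dim} input.
cu_seqlens = [0, 3, 5, 9]  # three sequences of lengths 3, 2, 4

batch = len(cu_seqlens) - 1                                      # number of sequences
seq_lens = [cu_seqlens[i + 1] - cu_seqlens[i] for i in range(batch)]
total_tokens = cu_seqlens[-1]                                    # rows in the packed x tensor

print(batch, seq_lens, total_tokens)  # 3 [3, 2, 4] 9
```

This is also why `batch = cu_seqlens.size(0) - 1` appears in the binding code reviewed below: the boundaries array always has one more entry than there are sequences.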
```cpp
int64_t batch;
int64_t ntokens_mtp;

auto const state_cache_size = state.size(0);
auto const nheads = state.size(1);
auto const dim = state.size(2);
auto const dstate = state.size(3);
auto const ngroups = B.size(2);

FLASHINFER_CHECK(state_cache_size >= batch, "state.size(0) must be >= x.size(0)");
FLASHINFER_CHECK(nheads % ngroups == 0, "nheads must be divisible by ngroups");

// Check x shape and strides
CHECK_CUDA(x);
CHECK_DIM(4, x);
FLASHINFER_CHECK(x.size(2) == nheads, "x.size(2) must equal nheads");
FLASHINFER_CHECK(x.size(3) == dim, "x.size(3) must equal dim");
CHECK_LAST_DIM_CONTIGUOUS(x);
FLASHINFER_CHECK(x.stride(2) == dim, "x.stride(2) must equal dim, got ", x.stride(2),
                 " expected ", dim);
if (is_varlen) {
  CHECK_DIM(3, x);  // x: {total_tokens, nheads, dim}
  FLASHINFER_CHECK(x.size(1) == nheads, "x.size(1) must equal nheads");
  FLASHINFER_CHECK(x.size(2) == dim, "x.size(2) must equal dim");
  CHECK_LAST_DIM_CONTIGUOUS(x);
  FLASHINFER_CHECK(x.stride(1) == dim, "x.stride(1) must equal dim");
  batch = cu_seqlens.value().size(0) - 1;
  FLASHINFER_CHECK(cache_steps >= 1,
                   "cache_steps must be >= 1 in varlen mode (specifies max_seqlen)");
  ntokens_mtp = cache_steps;
} else {
  CHECK_DIM(4, x);  // x: {batch, ntokens_mtp, nheads, dim}
  batch = x.size(0);
  ntokens_mtp = x.size(1);
  FLASHINFER_CHECK(x.size(2) == nheads, "x.size(2) must equal nheads");
  FLASHINFER_CHECK(x.size(3) == dim, "x.size(3) must equal dim");
  CHECK_LAST_DIM_CONTIGUOUS(x);
  FLASHINFER_CHECK(x.stride(2) == dim, "x.stride(2) must equal dim, got ", x.stride(2),
                   " expected ", dim);
}
```
The removal of the FLASHINFER_CHECK(state_cache_size >= batch, ...) validation in the multi-token prediction (MTP) path is a critical security regression. Without this check, if state_batch_indices is not provided and the input batch size exceeds the state_cache_size, the kernel will perform out-of-bounds memory access on the state tensor. This could lead to data leakage or corruption. This check is still present in the single-token path (run_selective_state_update_stp) and is essential for preventing OOB access. Please restore this validation check in run_selective_state_update_mtp for the non-varlen case.
```cpp
FLASHINFER_CHECK(x.size(2) == dim, "x.size(2) must equal dim");
CHECK_LAST_DIM_CONTIGUOUS(x);
FLASHINFER_CHECK(x.stride(1) == dim, "x.stride(1) must equal dim");
batch = cu_seqlens.value().size(0) - 1;
```
The calculation of the batch size from cu_seqlens is vulnerable to an integer underflow if an empty cu_seqlens tensor is provided. Specifically, batch = cu_seqlens.value().size(0) - 1 will result in -1. When this value is assigned to params.batch (a uint32_t), it becomes 0xFFFFFFFF. This extremely large value is used as the grid dimension for the kernel launch and subsequently used within the kernel to index into cu_seqlens and other tensors, leading to out-of-bounds memory access. Please add a check to ensure cu_seqlens has at least one element: FLASHINFER_CHECK(cs.size(0) >= 1, "cu_seqlens must have at least one element");.
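The underflow is easy to reproduce on the host side. The sketch below is illustrative plain Python, not the binding code itself; the C++ performs the same signed-to-unsigned conversion implicitly when the int64 result is stored into the uint32_t params field:

```python
# An empty cu_seqlens tensor gives size(0) == 0, so the batch computation underflows.
cu_seqlens_size = 0
batch = cu_seqlens_size - 1          # -1 as a signed integer

# Storing -1 into a uint32_t field wraps modulo 2**32:
params_batch = batch & 0xFFFFFFFF
print(hex(params_batch))  # 0xffffffff — then used as the kernel grid dimension
```

A grid dimension of 0xFFFFFFFF means the kernel iterates far past every valid sequence, which is exactly the OOB behavior the check `cs.size(0) >= 1` prevents.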
```cpp
auto const state_ptr_offset = state_batch * params.state_stride_batch + head * DIM * DSTATE;
state += state_ptr_offset;
```
The kernel retrieves indices from state_batch_indices, dst_state_batch_indices, and intermediate_state_indices and uses them to calculate offsets for accessing the state and intermediate_states tensors without any bounds checking. These indices are not validated to be within the valid range [0, state_cache_size). An attacker providing malicious indices could read from or write to arbitrary memory locations within the GPU's address space, potentially leading to sensitive data leakage from other users' sessions or memory corruption. Please implement bounds checks against params.state_cache_size before using these indices.
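One way to close this gap is a host-side validation pass before launch. The Python sketch below is hypothetical (plain lists instead of tensors; `check_indices` is not an existing helper), but the invariant it enforces mirrors the kernel parameters `pad_slot_id` and `state_cache_size`:

```python
def check_indices(indices, state_cache_size, pad_slot_id=-1):
    """Reject any slot index outside [0, state_cache_size); pad_slot_id entries are skipped."""
    for idx in indices:
        if idx == pad_slot_id:
            continue  # padding slots are legitimately out of range and never dereferenced
        if not (0 <= idx < state_cache_size):
            raise ValueError(f"slot index {idx} out of range [0, {state_cache_size})")

check_indices([0, 5, -1, 7], state_cache_size=8)   # ok: -1 is the pad slot
try:
    check_indices([0, 8], state_cache_size=8)      # 8 is one past the last valid slot
except ValueError as e:
    print(e)  # slot index 8 out of range [0, 8)
```

Note that validating on the host requires a device-to-host copy of the index tensors (a synchronization point); the alternative is a device-side bounds check inside the kernel that skips or traps on invalid indices.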
```cpp
inline void validate_state_batch_indices(Optional<TensorView> const& state_batch_indices,
                                         int64_t batch) {
  if (!state_batch_indices.has_value()) return;
  CHECK_DIM(1, (*state_batch_indices));
  CHECK_CONTIGUOUS((*state_batch_indices));
  FLASHINFER_CHECK(state_batch_indices.value().size(0) == batch,
                   "state_batch_indices.shape must be (", batch, ")");
  auto const& sbi = state_batch_indices.value();
  FLASHINFER_CHECK(sbi.dim() == 1 || sbi.dim() == 2, "state_batch_indices must be 1D or 2D, got ",
                   sbi.dim(), "D");
  FLASHINFER_CHECK(sbi.size(0) >= batch, "state_batch_indices.size(0) must be >= batch (", batch,
                   ")");
}
```
The validate_state_batch_indices function only validates the first dimension of the indices tensor. However, in multi-token mode, the kernel also accesses the second dimension using the step index, which can go up to cache_steps - 1. If the provided tensor is 2D and its second dimension is smaller than cache_steps, an out-of-bounds read will occur in the kernel. Please update the validation logic to check sbi.size(1) >= cache_steps when the tensor is 2D and used in the MTP path.
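The full shape requirement can be stated compactly. This is a hypothetical Python helper over plain shape tuples, not the actual C++ validator; it only illustrates the rule the comment asks for:

```python
def validate_indices_shape(shape, batch, cache_steps):
    """state_batch_indices must be 1D (>= batch,) or 2D (>= batch, >= cache_steps)."""
    if len(shape) not in (1, 2):
        raise ValueError(f"indices must be 1D or 2D, got {len(shape)}D")
    if shape[0] < batch:
        raise ValueError(f"size(0)={shape[0]} must be >= batch={batch}")
    if len(shape) == 2 and shape[1] < cache_steps:
        raise ValueError(f"size(1)={shape[1]} must be >= cache_steps={cache_steps}")

validate_indices_shape((4,), batch=4, cache_steps=3)      # 1D: ok
validate_indices_shape((4, 3), batch=4, cache_steps=3)    # 2D, second dim covers all steps: ok
try:
    validate_indices_shape((4, 2), batch=4, cache_steps=3)  # kernel would read step index 2 OOB
except ValueError as e:
    print(e)
```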
```cpp
if (num_accepted_tokens.has_value()) {
  auto const& nat = num_accepted_tokens.value();
  CHECK_CUDA(nat);
  CHECK_DIM(1, nat);
  CHECK_CONTIGUOUS(nat);
  FLASHINFER_CHECK(nat.dtype().code == kDLInt && nat.dtype().bits == 32,
                   "num_accepted_tokens must be int32");
  p.num_accepted_tokens = const_cast<void*>(nat.data_ptr());
}
```
Missing validation for the size of the num_accepted_tokens tensor. The kernel accesses this tensor using seq_idx, which ranges from 0 to batch - 1. If num_accepted_tokens.size(0) is less than batch, an out-of-bounds read will occur. Please add a validation check: FLASHINFER_CHECK(nat.size(0) >= batch, "num_accepted_tokens.size(0) must be >= batch");.
🧹 Nitpick comments (1)
tests/mamba/test_selective_state_update_varlen.py (1)

80-162: Recommend adding architecture checks for consistency with other mamba tests.

While the varlen implementation provides fallback kernels for SM80+, other mamba tests (test_selective_state_update_stp.py, test_selective_state_update_mtp.py) explicitly use `get_compute_capability()` to document architecture support. Adding similar checks here would improve consistency and clarity. This is optional since the tests will run on any CUDA GPU, but it aligns with the testing pattern used elsewhere in the module.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/mamba/test_selective_state_update_varlen.py` around lines 80-162, add the same CUDA architecture guard used in other mamba tests: call get_compute_capability() and skip the test (or class) when the compute capability is below the minimum supported for the varlen path (mirroring the checks in test_selective_state_update_stp.py/test_selective_state_update_mtp.py). Place this check at the start of TestSelectiveStateUpdateDstIndices or inside test_dst_different_from_src so the test is skipped on unsupported GPUs.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 9cc54fc5-eea3-4d4e-8877-ad7540693e79
📒 Files selected for processing (8)
- csrc/flashinfer_mamba_binding.cu
- csrc/selective_state_update.cu
- flashinfer/mamba/selective_state_update.py
- include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh
- include/flashinfer/mamba/kernel_selective_state_update_stp.cuh
- include/flashinfer/mamba/selective_state_update.cuh
- tests/mamba/test_selective_state_update_varlen.py
- tests/mamba/triton_reference/selective_state_update_varlen.py
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
benchmarks/routines/mamba.py (1)

302-379: ⚠️ Potential issue | 🟠 Major — Potential index out of bounds when allocating `src_indices` and `dst_indices`.

`ssm_state_cache_size` is set to `max(384, batch_size * 10)`, but varlen mode requires `2 * n_seqs * max_seqlen` unique indices for non-overlapping `src_indices` and `dst_indices`. When `2 * batch_size * cache_steps > ssm_state_cache_size`, lines 367-369 will cause an index error. Example: `batch_size=10, cache_steps=50` requires 1000 indices, but only 384 are available.

🐛 Proposed fix

```diff
 ## Prepare input tensors
-ssm_state_cache_size = max(384, batch_size * 10)
+if is_varlen:
+    # Varlen needs non-overlapping src and dst indices
+    ssm_state_cache_size = max(384, 2 * batch_size * cache_steps)
+else:
+    ssm_state_cache_size = max(384, batch_size * 10)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/routines/mamba.py` around lines 302-379, ssm_state_cache_size can be too small for varlen because src_indices and dst_indices slice 2 * n_seqs * max_seqlen entries from perm; update the allocation to ensure capacity by computing required = 2 * n_seqs * max_seqlen (or 2 * batch_size * max_seqlen) and set ssm_state_cache_size = max(384, batch_size * 10, required) before creating state_cache and perm, or alternatively guard the perm sampling by generating torch.randperm(required, device=device) in varlen mode so that src_indices and dst_indices cannot index out of bounds (references: ssm_state_cache_size, src_indices, dst_indices, perm, n_seqs, max_seqlen).
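The arithmetic behind this finding, as a runnable illustration (values taken from the example in the comment above):

```python
batch_size, cache_steps = 10, 50
n_seqs, max_seqlen = batch_size, cache_steps

ssm_state_cache_size = max(384, batch_size * 10)  # current sizing -> 384
required = 2 * n_seqs * max_seqlen                # non-overlapping src + dst slots -> 1000

print(ssm_state_cache_size, required, required > ssm_state_cache_size)  # 384 1000 True
```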
🧹 Nitpick comments (2)
csrc/selective_state_update.cu (2)

280-292: Minor inconsistency in CUDA validation between `state_batch_indices` and `dst_state_batch_indices`.

`CHECK_CUDA` is called for `dst_state_batch_indices` (line 288) but not for `state_batch_indices` (lines 280-285). For consistency, consider adding `CHECK_CUDA` validation to the `validate_state_batch_indices` helper function so both tensors are validated uniformly.

♻️ Suggested improvement

```diff
 inline void validate_state_batch_indices(Optional<TensorView> const& state_batch_indices,
                                          int64_t batch, int64_t max_seqlen = 1) {
   if (!state_batch_indices.has_value()) return;
   auto const& sbi = state_batch_indices.value();
+  CHECK_CUDA(sbi);
   FLASHINFER_CHECK(sbi.dim() == 1 || sbi.dim() == 2, "state_batch_indices must be 1D or 2D, got ",
                    sbi.dim(), "D");
```

Then remove `CHECK_CUDA(dsbi)` from line 288 and line 556.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/selective_state_update.cu` around lines 280-292, state_batch_indices is not validated with CHECK_CUDA while dst_state_batch_indices is; move the CHECK_CUDA check into the existing validate_state_batch_indices helper so both tensors are validated uniformly (call CHECK_CUDA on the tensor inside validate_state_batch_indices), then remove the redundant CHECK_CUDA(dsbi) calls that remain.

668-683: Consider updating the error message to mention varlen mode.

The error message states "3 dimensions (single-token) or 4 dimensions (multi-token)" but doesn't mention that 3D with `cu_seqlens` is also valid for varlen multi-token mode. This could confuse users who provide 3D input but forget `cu_seqlens`.

📝 Suggested improvement

```diff
 } else {
   FLASHINFER_CHECK(false,
-                   "x must have 3 dimensions (single-token) or 4 dimensions (multi-token), got ",
+                   "x must have 3 dimensions (single-token, or varlen multi-token with cu_seqlens) "
+                   "or 4 dimensions (multi-token), got ",
                    x.dim());
 }
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/selective_state_update.cu` around lines 668-683, the error message in the selection branch is misleading for varlen mode; update the FLASHINFER_CHECK call so its message mentions that 3D input is valid either for single-token or for varlen multi-token when cu_seqlens is provided, so users supplying 3D + cu_seqlens won't be confused.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 0087c120-5ace-4992-b8f0-422e323dc90c
📒 Files selected for processing (3)
- .gitignore
- benchmarks/routines/mamba.py
- csrc/selective_state_update.cu
/bot run

[SUCCESS] Pipeline #45680534: 9/20 passed
1bb2fc4 to 17328df (compare)
Actionable comments posted: 5
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/routines/mamba.py`:
- Around line 296-300: When is_varlen is true the code needs 2 * n_seqs *
max_seqlen cache slots but state_cache is still sized independently; before
materializing the varlen src/dst indices (the perm[...] slices and subsequent
reshape) ensure state_cache is grown/resized to at least 2 * n_seqs * max_seqlen
(use n_seqs = batch_size and max_seqlen = cache_steps) so the second perm slice
and reshape won't be too short; apply the same change to the analogous block
around the perm/reshape at lines 360-370 so both varlen paths expand state_cache
before slicing.
In `@csrc/selective_state_update.cu`:
- Around line 77-87: The code only checks shapes for state_batch_indices (sbi)
but never verifies it's a CUDA tensor before later packing its raw pointer; add
a GPU-device check immediately after extracting sbi and shape checks (same place
where you validate sizes) to ensure sbi.is_cuda() (or the project macro
equivalent) and error out if not CUDA, mirroring the validation done for
dst_state_batch_indices so a host tensor is never dereferenced from device code.
In `@flashinfer/mamba/selective_state_update.py`:
- Around line 278-283: Reject variable-length inputs whose longest sequence
exceeds cache_steps by validating cu_seqlens before setting ntokens_mtp: when
is_varlen is True, compute the maximum span from cu_seqlens (e.g.,
max(cu_seqlens[i+1]-cu_seqlens[i]) or equivalent) and if it is greater than
cache_steps raise an error (or return) instead of assigning ntokens_mtp =
cache_steps; update the early branch around is_varlen/ntokens_mtp in
selective_state_update.py to perform this guard so the kernel never silently
truncates tails.
In `@include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh`:
- Around line 410-432: The dst-slot write path that uses dst_state_batch_indices
currently ignores params.update_state and still mutates
params.state/params.state_scale; modify the block guarded by "if
(has_dst_indices) { ... }" (the code that computes dst_idx, uses dst_state_ptr
and writes into params.state and the dst_scale write using params.state_scale
and sram.state_scale) to first check params.update_state (e.g., if
(params.update_state) before performing any writes) so that when update_state is
false the dst writes are skipped as well.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: c3b081c9-8956-4bda-b11e-48955319b1e7
📒 Files selected for processing (10)
- .gitignore
- benchmarks/routines/mamba.py
- csrc/flashinfer_mamba_binding.cu
- csrc/selective_state_update.cu
- flashinfer/mamba/selective_state_update.py
- include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh
- include/flashinfer/mamba/kernel_selective_state_update_stp.cuh
- include/flashinfer/mamba/selective_state_update.cuh
- tests/mamba/test_selective_state_update_varlen.py
- tests/mamba/triton_reference/selective_state_update_varlen.py
🚧 Files skipped from review as they are similar to previous changes (1)
- include/flashinfer/mamba/selective_state_update.cuh
```python
if is_varlen:
    n_seqs = batch_size
    max_seqlen = cache_steps
    total_tokens = n_seqs * max_seqlen
```
Grow the cache before materializing varlen src/dst indices.
This path needs 2 * n_seqs * max_seqlen distinct slots, but the benchmark still sizes state_cache independently of max_seqlen. Once cache_steps gets large enough, the second perm[...] slice is too short and the reshape fails before timing starts.
🛠️ Proposed fix

```diff
 ssm_state_cache_size = max(384, batch_size * 10)
+if is_varlen:
+    ssm_state_cache_size = max(
+        ssm_state_cache_size, 2 * n_seqs * max_seqlen
+    )
```

Also applies to: 360-370
```cpp
                                         int64_t batch, int64_t max_seqlen = 1) {
  if (!state_batch_indices.has_value()) return;
  CHECK_DIM(1, (*state_batch_indices));
  CHECK_CONTIGUOUS((*state_batch_indices));
  FLASHINFER_CHECK(state_batch_indices.value().size(0) == batch,
                   "state_batch_indices.shape must be (", batch, ")");
  auto const& sbi = state_batch_indices.value();
  FLASHINFER_CHECK(sbi.dim() == 1 || sbi.dim() == 2, "state_batch_indices must be 1D or 2D, got ",
                   sbi.dim(), "D");
  FLASHINFER_CHECK(sbi.size(0) >= batch, "state_batch_indices.size(0) must be >= batch (", batch,
                   ")");
  if (sbi.dim() == 2) {
    FLASHINFER_CHECK(sbi.size(1) >= max_seqlen,
                     "state_batch_indices.size(1) must be >= max_seqlen (", max_seqlen, ")");
  }
```
Validate state_batch_indices on CUDA before packing its raw pointer.
After widening this helper to 1D/2D, state_batch_indices only gets shape checks here. Unlike dst_state_batch_indices, it never hits a later CHECK_CUDA, so a host tensor can still be dereferenced from device code.
🛠️ Proposed fix
```diff
 inline void validate_state_batch_indices(Optional<TensorView> const& state_batch_indices,
                                          int64_t batch, int64_t max_seqlen = 1) {
   if (!state_batch_indices.has_value()) return;
   auto const& sbi = state_batch_indices.value();
+  CHECK_CUDA(sbi);
   FLASHINFER_CHECK(sbi.dim() == 1 || sbi.dim() == 2, "state_batch_indices must be 1D or 2D, got ",
                    sbi.dim(), "D");
```
```python
if is_varlen:
    ntokens_mtp = cache_steps
elif x.dim() == 4:
    ntokens_mtp = x.size(1)
else:
    ntokens_mtp = 1
```
Reject varlen inputs whose longest sequence exceeds cache_steps.
ntokens_mtp is specialized directly from cache_steps. If any cu_seqlens span is longer, the kernel only processes the prefix and leaves the tail tokens unwritten.
🛡️ Proposed guard

```diff
 if is_varlen:
+    max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item())
+    if cache_steps < max_seqlen:
+        raise ValueError(
+            f"cache_steps ({cache_steps}) must be >= max sequence length ({max_seqlen})"
+        )
     ntokens_mtp = cache_steps
```

📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
if is_varlen:
    max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item())
    if cache_steps < max_seqlen:
        raise ValueError(
            f"cache_steps ({cache_steps}) must be >= max sequence length ({max_seqlen})"
        )
    ntokens_mtp = cache_steps
elif x.dim() == 4:
    ntokens_mtp = x.size(1)
else:
    ntokens_mtp = 1
```
```cpp
if (state_batch != params.pad_slot_id) {
  if (has_dst_indices) {
    auto dst_idx = static_cast<int64_t>(
        dst_state_batch_indices[seq_idx * params.dst_state_batch_indices_stride_batch +
                                step * params.dst_state_batch_indices_stride_T]);
    if (dst_idx != params.pad_slot_id) {
      auto* dst_state_ptr = reinterpret_cast<state_t*>(params.state);
      for (int i = lane * load_state_t::count; i < DSTATE;
           i += warpSize * load_state_t::count) {
        auto* src = reinterpret_cast<load_state_t*>(&sram.state[dd][i]);
        *reinterpret_cast<load_state_t*>(
            &dst_state_ptr[dst_idx * params.state_stride_batch + head * DIM * DSTATE +
                           d * DSTATE + i]) = *src;
      }
      if constexpr (scaleState) {
        if (lane == 0) {
          auto* dst_scale = reinterpret_cast<state_scale_t*>(params.state_scale);
          dst_scale[dst_idx * params.state_scale_stride_batch + head * DIM + d] =
              sram.state_scale[dd];
        }
      }
    }
  } else if (has_intermediate) {
```
Gate dst-slot writes on params.update_state.
disable_state_update=True currently suppresses only the final source-slot write. The new per-token dst_state_batch_indices path still stores into params.state, so verification runs mutate the cache anyway.
🛠️ Proposed fix

```diff
-      if (has_dst_indices) {
+      if (params.update_state && has_dst_indices) {
         auto dst_idx = static_cast<int64_t>(
             dst_state_batch_indices[seq_idx * params.dst_state_batch_indices_stride_batch +
                                     step * params.dst_state_batch_indices_stride_T]);
```

📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```cpp
if (state_batch != params.pad_slot_id) {
  if (params.update_state && has_dst_indices) {
    auto dst_idx = static_cast<int64_t>(
        dst_state_batch_indices[seq_idx * params.dst_state_batch_indices_stride_batch +
                                step * params.dst_state_batch_indices_stride_T]);
    if (dst_idx != params.pad_slot_id) {
      auto* dst_state_ptr = reinterpret_cast<state_t*>(params.state);
      for (int i = lane * load_state_t::count; i < DSTATE;
           i += warpSize * load_state_t::count) {
        auto* src = reinterpret_cast<load_state_t*>(&sram.state[dd][i]);
        *reinterpret_cast<load_state_t*>(
            &dst_state_ptr[dst_idx * params.state_stride_batch + head * DIM * DSTATE +
                           d * DSTATE + i]) = *src;
      }
      if constexpr (scaleState) {
        if (lane == 0) {
          auto* dst_scale = reinterpret_cast<state_scale_t*>(params.state_scale);
          dst_scale[dst_idx * params.state_scale_stride_batch + head * DIM + d] =
              sram.state_scale[dd];
        }
      }
    }
  } else if (has_intermediate) {
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh` around lines
410 - 432, The dst-slot write path that uses dst_state_batch_indices currently
ignores params.update_state and still mutates params.state/params.state_scale;
modify the block guarded by "if (has_dst_indices) { ... }" (the code that
computes dst_idx, uses dst_state_ptr and writes into params.state and the
dst_scale write using params.state_scale and sram.state_scale) to first check
params.update_state (e.g., if (params.update_state) before performing any
writes) so that when update_state is false the dst writes are skipped as well.
Actionable comments posted: 2
♻️ Duplicate comments (4)
include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh (1)
411-433: ⚠️ Potential issue | 🟠 Major

Gate dst-slot writes with `params.update_state`.

Line 412's dst write branch still mutates `params.state` when `disable_state_update=True` (`params.update_state == false`).

🛠️ Suggested fix

```diff
-        if (has_dst_indices) {
+        if (params.update_state && has_dst_indices) {
           auto dst_idx = static_cast<int64_t>(
               dst_state_batch_indices[seq_idx * params.dst_state_batch_indices_stride_batch +
                                       step * params.dst_state_batch_indices_stride_T]);
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh` around lines 411-433, the dst-slot write path currently always mutates `params.state` and `params.state_scale`; guard those writes by checking `params.update_state` before performing the memory stores, i.e., inside the `has_dst_indices` branch around the loop that writes into `dst_state_ptr` and around the `scaleState` block that writes into `dst_scale`, so that no mutation occurs when `disable_state_update=True`.

csrc/selective_state_update.cu (2)
76-87: ⚠️ Potential issue | 🟠 Major

Require CUDA memory for `state_batch_indices` before pointer packing.

`state_batch_indices` is shape-validated but never device-validated before `data_ptr()` is consumed by CUDA kernels.

🛠️ Suggested fix

```diff
 inline void validate_state_batch_indices(Optional<TensorView> const& state_batch_indices,
                                          int64_t batch, int64_t max_seqlen = 1) {
   if (!state_batch_indices.has_value()) return;
   auto const& sbi = state_batch_indices.value();
+  CHECK_CUDA(sbi);
   FLASHINFER_CHECK(sbi.dim() == 1 || sbi.dim() == 2, "state_batch_indices must be 1D or 2D, got ",
                    sbi.dim(), "D");
```

Also applies to: 280-285, 548-553

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@csrc/selective_state_update.cu` around lines 76-87, the shape checks in `validate_state_batch_indices` validate dims/sizes but do not ensure `state_batch_indices` is resident on the CUDA device before its pointer is consumed by kernels; add a device check (e.g., `CHECK_CUDA(sbi)`) right after retrieving `sbi` and before any code that calls `data_ptr()`, and apply the same CUDA-device validation to the other sites that handle `state_batch_indices` before pointer packing.
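The validation ordering the comment asks for (device residency first, then shape) can be sketched without CUDA at all. In the snippet below, `FakeTensor` is a hypothetical stand-in for `TensorView`, exposing just the `.is_cuda`, `.dim()`, and shape accessors the check needs; it is an illustration of the check order, not the binding code.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class FakeTensor:
    """Minimal stand-in for TensorView (hypothetical, for illustration)."""

    shape: Tuple[int, ...]
    is_cuda: bool

    def dim(self) -> int:
        return len(self.shape)


def validate_state_batch_indices(sbi: Optional[FakeTensor], batch: int) -> None:
    """Sketch of the suggested validation: device residency is checked
    before any shape logic, and well before a raw pointer would be packed."""
    if sbi is None:
        return
    if not sbi.is_cuda:
        raise ValueError("state_batch_indices must be a CUDA tensor")
    if sbi.dim() not in (1, 2):
        raise ValueError(f"state_batch_indices must be 1D or 2D, got {sbi.dim()}D")
    if sbi.shape[0] < batch:
        raise ValueError("state_batch_indices.size(0) must cover the batch")


validate_state_batch_indices(FakeTensor((8, 4), is_cuda=True), batch=8)  # passes
```

A host tensor fails fast with a readable error instead of a device-side fault, which is the point of moving `CHECK_CUDA` ahead of the pointer packing.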
358-377: ⚠️ Potential issue | 🔴 Critical

Varlen path should validate flattened leading dimensions across tensors.

In varlen mode, `dt`/`B`/`C`/`z`/`out` are not checked against `x.size(0)`. A shorter tensor can be indexed out-of-bounds via `bos + step`.

🛡️ Suggested checks

```diff
 if (is_varlen) {
   CHECK_DIM(3, x);  // x: {total_tokens, nheads, dim}
+  int64_t const total_tokens = x.size(0);
   FLASHINFER_CHECK(x.size(1) == nheads, "x.size(1) must equal nheads");
   FLASHINFER_CHECK(x.size(2) == dim, "x.size(2) must equal dim");
@@
 if (is_varlen) {
   CHECK_DIM(3, dt);  // dt: {total_tokens, nheads, dim}
+  FLASHINFER_CHECK(dt.size(0) == x.size(0), "dt.size(0) must equal x.size(0) in varlen mode");
@@
 if (is_varlen) {
   CHECK_DIM(3, B);  // B: {total_tokens, ngroups, dstate}
+  FLASHINFER_CHECK(B.size(0) == x.size(0), "B.size(0) must equal x.size(0) in varlen mode");
@@
 if (is_varlen) {
   CHECK_DIM(3, C);  // C: {total_tokens, ngroups, dstate}
+  FLASHINFER_CHECK(C.size(0) == x.size(0), "C.size(0) must equal x.size(0) in varlen mode");
@@
 if (is_varlen) {
   CHECK_DIM(3, z_tensor);  // z: {total_tokens, nheads, dim}
+  FLASHINFER_CHECK(z_tensor.size(0) == x.size(0),
+                   "z.size(0) must equal x.size(0) in varlen mode");
@@
 if (is_varlen) {
   CHECK_DIM(3, output);  // out: {total_tokens, nheads, dim}
+  FLASHINFER_CHECK(output.size(0) == x.size(0),
+                   "out.size(0) must equal x.size(0) in varlen mode");
```

Also applies to: 386-472

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@csrc/selective_state_update.cu` around lines 358-472, in the `is_varlen` branch add explicit validation that all tensors indexed by flattened token positions (`dt`, `B`, `C`, `z`, `out`) have their leading dimension equal to `x.size(0)` (the total token count) so indexing with `bos + step` cannot go out of bounds; emit `FLASHINFER_CHECK` errors when they differ.

flashinfer/mamba/selective_state_update.py (1)
286-291: ⚠️ Potential issue | 🟠 Major

Add a varlen guard: `cache_steps` must cover the longest sequence.

Without this check, varlen sequences longer than `cache_steps` are truncated by the specialized kernel token budget.

🛡️ Suggested guard

```diff
     if is_varlen:
+        max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item())
+        if cache_steps < max_seqlen:
+            raise ValueError(
+                f"cache_steps ({cache_steps}) must be >= max sequence length ({max_seqlen})"
+            )
         ntokens_mtp = cache_steps
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@flashinfer/mamba/selective_state_update.py` around lines 286-291, when `is_varlen` is True the code assumes `cache_steps` covers the longest sequence but never verifies it; compute the maximum sequence length from `cu_seqlens` and raise a clear `ValueError` if `cache_steps < max_seqlen` so the specialized kernel token budget cannot silently truncate sequences. Update the branch that sets `ntokens_mtp` to perform this check and fail fast with an informative message referencing `cache_steps` and the maximum sequence length.
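The guard above boils down to differencing the cumulative offsets. A tensor-free sketch of the same check, with plain Python lists standing in for the `cu_seqlens` tensor (the helper name is hypothetical):

```python
def check_cache_steps(cu_seqlens, cache_steps):
    """Raise if any varlen sequence exceeds the kernel's token budget.

    cu_seqlens holds cumulative token offsets, so consecutive
    differences give the per-sequence lengths. Mirrors the guard
    suggested in the review (illustrative helper, not library API).
    """
    seqlens = [b - a for a, b in zip(cu_seqlens, cu_seqlens[1:])]
    max_seqlen = max(seqlens, default=0)
    if cache_steps < max_seqlen:
        raise ValueError(
            f"cache_steps ({cache_steps}) must be >= max sequence length ({max_seqlen})"
        )
    return max_seqlen


# three packed sequences of lengths 2, 5, 3
assert check_cache_steps([0, 2, 7, 10], cache_steps=8) == 5
```

With `cache_steps=4` the same offsets raise, which is the fail-fast behavior the review wants instead of silently dropping the tail tokens of the 5-token sequence.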
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@csrc/selective_state_update.cu`:
- Around line 561-569: Reject/ignore cu_seqlens when input tensor x is
non-varlen (x.dim() == 4): in the block that assigns p.cu_seqlens from
cu_seqlens (the code referencing cu_seqlens, cs, and p.cu_seqlens), add a guard
that checks x.dim() and if x.dim() == 4 then either DCHECK/FLASHINFER_CHECK that
cu_seqlens is not provided or simply do not set p.cu_seqlens (leave it null) and
log/raise an error; apply the same guard and behavior in the analogous block
around the later assignment at lines 669-674 so the kernel will not switch to
varlen addressing when x is 4D.
- Around line 570-579: When num_accepted_tokens is provided, ensure
state_batch_indices is a 2D CUDA tensor (not 1D) so the kernel can read
state_batch_indices[seq_idx, init_token_idx]; add checks after
FLASHINFER_CHECK(state_batch_indices.has_value()) to validate
state_batch_indices.dim()==2, CHECK_CUDA(state_batch_indices.value()),
CHECK_CONTIGUOUS(state_batch_indices.value()),
FLASHINFER_CHECK(state_batch_indices.value().size(0) >= batch, ...) and
FLASHINFER_CHECK(state_batch_indices.value().size(1) > 0, ...), then set
p.state_batch_indices =
const_cast<void*>(state_batch_indices.value().data_ptr()) alongside
p.num_accepted_tokens to ensure nonzero stride_T and correct indexing.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 3f3148d4-1983-4797-a149-e6055374c8c2
📒 Files selected for processing (8)
- csrc/flashinfer_mamba_binding.cu
- csrc/selective_state_update.cu
- csrc/selective_state_update_customize_config.jinja
- flashinfer/jit/mamba/selective_state_update.py
- flashinfer/mamba/selective_state_update.py
- include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh
- include/flashinfer/mamba/selective_state_update.cuh
- tests/mamba/test_selective_state_update_varlen.py
🚧 Files skipped from review as they are similar to previous changes (2)
- include/flashinfer/mamba/selective_state_update.cuh
- tests/mamba/test_selective_state_update_varlen.py
```cpp
  if (cu_seqlens.has_value()) {
    auto const& cs = cu_seqlens.value();
    CHECK_CUDA(cs);
    CHECK_DIM(1, cs);
    CHECK_CONTIGUOUS(cs);
    FLASHINFER_CHECK(cs.size(0) == batch + 1, "cu_seqlens.size(0) must equal n_sequences + 1 (",
                     batch + 1, ")");
    p.cu_seqlens = const_cast<void*>(cs.data_ptr());
  }
```
Disallow `cu_seqlens` for non-varlen (`x.dim() == 4`) inputs.

`cu_seqlens` is packed whenever present, and the kernel switches to varlen addressing based on that pointer even if `x` is 4D.

🧭 Suggested fix

```diff
 bool const is_varlen = (x.dim() == 3 && cu_seqlens.has_value());
+FLASHINFER_CHECK(!(cu_seqlens.has_value() && x.dim() != 3),
+                 "cu_seqlens is only supported when x is 3D varlen layout");
@@
-  if (cu_seqlens.has_value()) {
+  if (is_varlen) {
     auto const& cs = cu_seqlens.value();
     CHECK_CUDA(cs);
     CHECK_DIM(1, cs);
```

Also applies to: 669-674
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@csrc/selective_state_update.cu` around lines 561 - 569, Reject/ignore
cu_seqlens when input tensor x is non-varlen (x.dim() == 4): in the block that
assigns p.cu_seqlens from cu_seqlens (the code referencing cu_seqlens, cs, and
p.cu_seqlens), add a guard that checks x.dim() and if x.dim() == 4 then either
DCHECK/FLASHINFER_CHECK that cu_seqlens is not provided or simply do not set
p.cu_seqlens (leave it null) and log/raise an error; apply the same guard and
behavior in the analogous block around the later assignment at lines 669-674 so
the kernel will not switch to varlen addressing when x is 4D.
Force-pushed from 30c2c3e to fd5eed6 (Compare)
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
flashinfer/mamba/selective_state_update.py (1)
274-282: ⚠️ Potential issue | 🟠 Major

Validate that all index tensors share one dtype.

The generated module specializes a single `stateIndex_t`, but this picks the first non-`None` dtype and forwards the others unchanged. If `state_batch_indices`, `dst_state_batch_indices`, and/or `intermediate_state_indices` differ, at least one tensor will be reinterpreted with the wrong element width in the CUDA path.

🛠️ Proposed fix

```diff
-    # Determine stateIndex dtype from index tensors, default to int32
-    stateIndex_dtype = torch.int32
-    if state_batch_indices is not None:
-        stateIndex_dtype = state_batch_indices.dtype
-    elif dst_state_batch_indices is not None:
-        stateIndex_dtype = dst_state_batch_indices.dtype
-    elif intermediate_state_indices is not None:
-        stateIndex_dtype = intermediate_state_indices.dtype
+    # Determine stateIndex dtype from index tensors, default to int32.
+    # All index tensors in one launch must share the same dtype because the
+    # generated module only specializes a single stateIndex_t.
+    index_dtypes = {
+        tensor.dtype
+        for tensor in (
+            state_batch_indices,
+            dst_state_batch_indices,
+            intermediate_state_indices,
+        )
+        if tensor is not None
+    }
+    if len(index_dtypes) > 1:
+        raise ValueError(
+            "state_batch_indices, dst_state_batch_indices, and "
+            "intermediate_state_indices must share the same dtype"
+        )
+    stateIndex_dtype = next(iter(index_dtypes), torch.int32)
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@flashinfer/mamba/selective_state_update.py` around lines 274-282, the current logic picks the first non-`None` dtype into `stateIndex_dtype` without ensuring the other index tensors match; validate that all non-`None` tensors among `state_batch_indices`, `dst_state_batch_indices`, and `intermediate_state_indices` share the same dtype and raise a clear `ValueError` on any mismatch, so the CUDA path won't reinterpret a tensor with the wrong element width. Alternatively, cast mismatched tensors to the chosen dtype, but be explicit about which approach is taken.

include/flashinfer/mamba/kernel_selective_state_update_stp.cuh (1)
670-727: ⚠️ Potential issue | 🔴 Critical

Reject padded destination slots before enabling SM90 writeback.

The vertical/horizontal SM90 paths only gate writeback on the source slot. If `dst_state_batch_indices` contains `pad_slot_id`, the producers still issue TMA writes to batch `-1`, and the vertical scaled-state path also stores decode scales through that padded destination.

🛠️ Proposed fix

```diff
-  auto const write_state = read_state && params.update_state;
+  auto const write_state =
+      read_state && params.update_state && dst_state_batch != params.pad_slot_id;
@@
-  if (params.update_state && state_batch != params.pad_slot_id) {
+  if (params.update_state && state_batch != params.pad_slot_id &&
+      dst_state_batch != params.pad_slot_id) {
     if (d < DIM) {
       state_scale[dst_state_batch * params.state_scale_stride_batch + head * DIM + d] =
           sram.state_scale[d];
     }
```

Also applies to: 787-791, 1058-1107

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/mamba/kernel_selective_state_update_stp.cuh` around lines 670-727, the code gates writeback only on the source slot, so producers issue SM90/TMA writes when `dst_state_batch` equals `pad_slot_id`; compute `dst_state_batch` early and require `dst_state_batch != params.pad_slot_id` in the write enable (e.g., `write_state = read_state && params.update_state && dst_state_batch != params.pad_slot_id`), apply it to the producer paths and before any scaled-state stores, and mirror the same guard in the other occurrences noted above.
♻️ Duplicate comments (7)
flashinfer/mamba/selective_state_update.py (1)
286-291: ⚠️ Potential issue | 🟠 Major

Reject varlen sequences longer than `cache_steps`.

`ntokens_mtp` is specialized directly from `cache_steps`, while the MTP kernel iterates exactly that many steps. If any `cu_seqlens` span is longer, the tail tokens are silently skipped and their outputs/state never get written.

🛠️ Proposed fix

```diff
     if is_varlen:
-        ntokens_mtp = cache_steps
+        max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item())
+        if cache_steps < max_seqlen:
+            raise ValueError(
+                f"cache_steps ({cache_steps}) must be >= max sequence length ({max_seqlen})"
+            )
+        ntokens_mtp = cache_steps
     elif x.dim() == 4:
         ntokens_mtp = x.size(1)
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@flashinfer/mamba/selective_state_update.py` around lines 286-291, the varlen branch sets `ntokens_mtp = cache_steps` unconditionally, which lets the MTP kernel silently drop tokens from any `cu_seqlens` span longer than `cache_steps`; before assigning `ntokens_mtp`, check the span lengths derived from `cu_seqlens` and raise an explicit `ValueError` if any exceeds `cache_steps`.

benchmarks/routines/mamba.py (1)
296-307: ⚠️ Potential issue | 🟠 Major

Grow the cache before materializing varlen src/dst indices.

Varlen mode needs `2 * n_seqs * max_seqlen` distinct slots, but `ssm_state_cache_size` is still independent of `max_seqlen`. Larger `cache_steps` makes the second `perm[...]` slice too short and the reshape fails before benchmarking starts.

🛠️ Proposed fix

```diff
     ## Prepare input tensors
     ssm_state_cache_size = max(384, batch_size * 10)
+    if is_varlen:
+        ssm_state_cache_size = max(
+            ssm_state_cache_size, 2 * n_seqs * max_seqlen
+        )
     # State cache: (total_entries, nheads, dim, dstate) - contiguous
     state_cache = torch.randn(
```

Also applies to: 360-369

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@benchmarks/routines/mamba.py` around lines 296-307, when `is_varlen` is true, `ssm_state_cache_size` must be grown to at least `2 * n_seqs * max_seqlen` before allocating `state_cache` and slicing `perm[...]`; use `max(384, batch_size * 10, 2 * n_seqs * max_seqlen)` so the subsequent perm slices and reshape succeed, and apply the same change at the second occurrence around lines 360-369.

include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh (1)
411-432: ⚠️ Potential issue | 🟠 Major

Honor `update_state` on the dst-slot path.

`disable_state_update=True` currently suppresses only the final source-slot write. When `dst_state_batch_indices` is present, this block still writes into `params.state`/`params.state_scale`, so verification runs mutate the cache anyway.

🛠️ Proposed fix

```diff
-        if (has_dst_indices) {
+        if (params.update_state && has_dst_indices) {
           auto dst_idx = static_cast<int64_t>(
               dst_state_batch_indices[seq_idx * params.dst_state_batch_indices_stride_batch +
                                       step * params.dst_state_batch_indices_stride_T]);
           if (dst_idx != params.pad_slot_id) {
             auto* dst_state_ptr = reinterpret_cast<state_t*>(params.state);
@@
         } else if (has_intermediate) {
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh` around lines 411-432, the dst-slot write path still updates `params.state` and `params.state_scale` even when state updates are disabled; guard the writes inside the dst-state block (the loop writing into `params.state` and the `scaleState` branch writing `params.state_scale`) with `params.update_state` so no writes occur when `disable_state_update` is true.

csrc/selective_state_update.cu (4)
570-579: ⚠️ Potential issue | 🟠 Major

Require 2D `state_batch_indices` when `num_accepted_tokens` is provided.

The kernel path uses per-token accepted offsets; with 1D `state_batch_indices`, `state_batch_indices_stride_T` becomes 0 (Line 552), so accepted-token indexing is ignored.

🛠️ Suggested fix

```diff
   if (num_accepted_tokens.has_value()) {
@@
     FLASHINFER_CHECK(state_batch_indices.has_value(),
                      "state_batch_indices is required when num_accepted_tokens is provided");
+    FLASHINFER_CHECK(state_batch_indices.value().dim() == 2,
+                     "state_batch_indices must be 2D when num_accepted_tokens is provided");
     p.num_accepted_tokens = const_cast<void*>(nat.data_ptr());
   }
```
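Why a 1D table silently breaks accepted-token indexing can be modeled in a few lines. The sketch below is an illustrative Python model of the source-slot lookup, not the kernel code: a 2D index table selects a per-token slot via the `init_token_idx` offset, while a 1D table (where `stride_T` is effectively 0) always resolves to the sequence's single entry, so `num_accepted_tokens` has no effect.

```python
def source_slot(state_batch_indices, seq_idx, init_token_idx):
    """Model of the kernel's source-slot read (illustrative, hypothetical).

    With a 2D table the accepted-token offset picks a per-token slot;
    with a 1D table the offset is silently ignored, which is why the
    review requires 2D indices when num_accepted_tokens is provided.
    """
    row = state_batch_indices[seq_idx]
    if isinstance(row, list):  # 2D table: per-token slots along stride_T
        return row[init_token_idx]
    return row  # 1D table: stride_T == 0, offset has no effect


table_2d = [[10, 11, 12], [20, 21, 22]]
table_1d = [10, 20]
assert source_slot(table_2d, 1, 2) == 22  # honors the accepted-token offset
assert source_slot(table_1d, 1, 2) == 20  # offset ignored
```

The suggested `dim() == 2` check turns the second case into an upfront error instead of a silently wrong state read.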
561-569: ⚠️ Potential issue | 🔴 Critical

Reject `cu_seqlens` for 4D input and only pack it in varlen mode.

`cu_seqlens` is currently packed whenever present (Line 561), while the dispatcher still routes `x.dim() == 4` into MTP (Line 669). This can incorrectly enable varlen addressing for a non-varlen layout.

🛠️ Suggested fix

```diff
   bool const has_cu_seqlens = cu_seqlens.has_value();
+  FLASHINFER_CHECK(!(has_cu_seqlens && x.dim() != 3),
+                   "cu_seqlens is only supported when x is 3D varlen layout");
@@
-  if (cu_seqlens.has_value()) {
+  if (is_varlen) {
     auto const& cs = cu_seqlens.value();
     CHECK_CUDA(cs);
     CHECK_DIM(1, cs);
```

Also applies to: 664-670

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@csrc/selective_state_update.cu` around lines 561-569, `cu_seqlens` is packed into `p.cu_seqlens` whenever present, which enables varlen addressing even for `x.dim() == 4` inputs; add a guard that rejects `cu_seqlens` when `x` is not 3D and only assign `p.cu_seqlens` when varlen mode is active; apply the same change to the symmetric packing site.
358-472: ⚠️ Potential issue | 🔴 Critical

Varlen mode is missing first-dimension token-count consistency checks.

In varlen branches, `dt`, `B`, `C`, optional `z`, and optional `out` are not checked against `x.size(0)` on their first dimension. This can allow undersized tensors and out-of-bounds access when indexed by token offsets.

🛠️ Suggested fix

```diff
 if (is_varlen) {
   CHECK_DIM(3, x);  // x: {total_tokens, nheads, dim}
+  auto const total_tokens = x.size(0);
   FLASHINFER_CHECK(x.size(1) == nheads, "x.size(1) must equal nheads");
   FLASHINFER_CHECK(x.size(2) == dim, "x.size(2) must equal dim");
@@
 if (is_varlen) {
   CHECK_DIM(3, dt);  // dt: {total_tokens, nheads, dim}
+  FLASHINFER_CHECK(dt.size(0) == x.size(0), "dt.size(0) must equal total_tokens");
@@
 if (is_varlen) {
   CHECK_DIM(3, B);  // B: {total_tokens, ngroups, dstate}
+  FLASHINFER_CHECK(B.size(0) == x.size(0), "B.size(0) must equal total_tokens");
@@
 if (is_varlen) {
   CHECK_DIM(3, C);  // C: {total_tokens, ngroups, dstate}
+  FLASHINFER_CHECK(C.size(0) == x.size(0), "C.size(0) must equal total_tokens");
@@
 if (is_varlen) {
   CHECK_DIM(3, z_tensor);  // z: {total_tokens, nheads, dim}
+  FLASHINFER_CHECK(z_tensor.size(0) == x.size(0), "z.size(0) must equal total_tokens");
@@
 if (is_varlen) {
   CHECK_DIM(3, output);  // out: {total_tokens, nheads, dim}
+  FLASHINFER_CHECK(output.size(0) == x.size(0), "out.size(0) must equal total_tokens");
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@csrc/selective_state_update.cu` around lines 358-472, the varlen branches fail to validate that `dt`, `B`, `C`, `z_tensor`, and `output` have their first dimension equal to `x.size(0)` (total tokens), which risks out-of-bounds access when indexing by token offsets; add `FLASHINFER_CHECK` assertions inside each `is_varlen` block in the same style as the existing checks.
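The whole family of checks above reduces to one invariant: every tensor indexed by flattened token position must share `x`'s leading dimension. A tensor-free sketch, with plain shape tuples standing in for tensor sizes (the helper is hypothetical, mirroring the suggested `FLASHINFER_CHECK`s):

```python
def check_varlen_leading_dims(x_shape, **other_shapes):
    """Raise if any token-indexed tensor disagrees with x on total_tokens.

    x_shape is (total_tokens, nheads, dim); other_shapes maps tensor
    names (dt, B, C, z, out) to their shape tuples, or None when the
    tensor is optional and absent. Illustrative helper only.
    """
    total_tokens = x_shape[0]
    for name, shape in other_shapes.items():
        if shape is not None and shape[0] != total_tokens:
            raise ValueError(
                f"{name}.size(0) ({shape[0]}) must equal x.size(0) ({total_tokens})"
            )


# consistent varlen layout: 10 packed tokens everywhere
check_varlen_leading_dims(
    (10, 4, 64), dt=(10, 4, 64), B=(10, 1, 16), C=(10, 1, 16), z=None
)
```

An undersized `B=(8, 1, 16)` would raise here, which is exactly the out-of-bounds case `bos + step` indexing would otherwise hit on device.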
76-87: ⚠️ Potential issue | 🔴 Critical

Add CUDA-device validation for `state_batch_indices` before pointer packing.

`state_batch_indices` is used to populate raw kernel pointers (Line 282 and Line 550), but `validate_state_batch_indices` never enforces CUDA residency. A host tensor here can be dereferenced from device code.

🛠️ Suggested fix

```diff
 inline void validate_state_batch_indices(Optional<TensorView> const& state_batch_indices,
                                          int64_t batch, int64_t max_seqlen = 1) {
   if (!state_batch_indices.has_value()) return;
   auto const& sbi = state_batch_indices.value();
+  CHECK_CUDA(sbi);
   FLASHINFER_CHECK(sbi.dim() == 1 || sbi.dim() == 2, "state_batch_indices must be 1D or 2D, got ",
                    sbi.dim(), "D");
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@csrc/selective_state_update.cu` around lines 76-87, `validate_state_batch_indices` checks shape/size but not device residency, so a host tensor can reach the pointer-packing sites; assert the tensor is on CUDA (e.g., `CHECK_CUDA(sbi)`) early, before any pointer-packing code consumes `state_batch_indices`.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@csrc/selective_state_update.cu`:
- Around line 346-367: When is_varlen is true, validate cu_seqlens length before
computing batch: check that cu_seqlens.has_value() and cu_seqlens->size(0) >= 2
(so batch = cu_seqlens.value().size(0) - 1 is non-negative) and emit a
FLASHINFER_CHECK with a clear message if not; update the same guard in the other
varlen block that computes batch (the later block around ntokens/offset
handling) to avoid deriving batch = -1 from an empty cu_seqlens. Ensure you
perform this check prior to assigning batch and before any subsequent uses of
batch or indexing into cu_seqlens.
In `@include/flashinfer/mamba/kernel_selective_state_update_stp.cuh`:
- Around line 121-136: The code computes
dst_state_batch/dst_state/dst_state_scale from dst_sbi but then only uses the
source-slot write guards, so if dst_state_batch_indices contains a padded slot
(pad_slot_id) the kernel can write through a negative/invalid cache index; fix
by computing a boolean dst_valid (e.g., dst_sbi != nullptr && dst_state_batch !=
pad_slot_id) and use that same dst_valid wherever writes to dst_state or
dst_state_scale occur (mirror the existing source-slot guard logic), and ensure
any arithmetic that constructs dst_state and dst_state_scale is only used when
dst_valid is true to avoid creating/using invalid pointers (apply same change
around the other occurrences mentioned: the blocks around lines with
dst_state/dst_state_scale at the other offsets).
---
Outside diff comments:
In `@flashinfer/mamba/selective_state_update.py`:
- Around line 274-282: The current logic in selective_state_update.py picks the
first non-None dtype into stateIndex_dtype without ensuring the other index
tensors match; update the block that sets stateIndex_dtype to validate that all
non-None tensors among state_batch_indices, dst_state_batch_indices, and
intermediate_state_indices share the same dtype (compare their .dtype to the
chosen stateIndex_dtype) and raise a clear ValueError if any mismatch is found,
so the CUDA path won't reinterpret tensors with the wrong element width;
alternatively, if you prefer automatic fixes, cast any mismatched tensors to the
chosen stateIndex_dtype before continuing, but be explicit about which approach
you take in the error/logic.
In `@include/flashinfer/mamba/kernel_selective_state_update_stp.cuh`:
- Around line 670-727: The code currently gates writeback only on the source
slot (read_state) causing producers to issue SM90/TMA writes when
dst_state_batch equals pad_slot_id; compute dst_state_batch (from
params.dst_state_batch_indices) early and reject padded destination slots by
changing the write enable to also require dst_state_batch != params.pad_slot_id
(e.g., set write_state = read_state && params.update_state && dst_state_batch !=
params.pad_slot_id), and use that write_state when instantiating/calling
producer_func_vertical/producer paths and before any scaled-state stores so no
SM90/TMA or state-scale writes occur for padded destination slots (apply same
guard to the horizontal/vertical SM90 paths and the other occurrences noted
around the dst-related blocks).
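The corrected write-enable condition can be modeled in plain Python (this is a stand-in for the kernel's boolean, not kernel code; the default `pad_slot_id` is illustrative):

```python
def write_enable(read_state, update_state, dst_state_batch, pad_slot_id=-1):
    # A write is legal only when the source slot was readable, state
    # updates are enabled, and the destination slot is not padding.
    return bool(read_state and update_state and dst_state_batch != pad_slot_id)
```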
---
Duplicate comments:
In `@benchmarks/routines/mamba.py`:
- Around line 296-307: When is_varlen is true, ssm_state_cache_size must be
grown to accommodate the varlen indices: ensure ssm_state_cache_size is set to
at least max(current_min, 2 * n_seqs * max_seqlen) before creating state_cache
and before any perm[...] slicing; update the calculation that sets
ssm_state_cache_size (used to allocate state_cache) to use max(384, batch_size *
10, 2 * n_seqs * max_seqlen) so the subsequent perm slices and reshape succeed,
and apply the same change in the later block that also computes
ssm_state_cache_size (the second occurrence around the other perm/reshape
usage).
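The suggested sizing rule, sketched as a Python helper (the function and parameter names are illustrative; the benchmark computes this inline):

```python
def ssm_state_cache_size(batch_size, n_seqs=0, max_seqlen=0, is_varlen=False):
    # Baseline sizing used by the non-varlen benchmark path.
    size = max(384, batch_size * 10)
    if is_varlen:
        # Varlen permutation indices can address up to 2 * n_seqs * max_seqlen
        # slots, so the cache must grow to cover them before perm slicing.
        size = max(size, 2 * n_seqs * max_seqlen)
    return size
```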
In `@csrc/selective_state_update.cu`:
- Around line 561-569: The code currently packs cu_seqlens into p.cu_seqlens
whenever present, which allows varlen addressing for inputs with x.dim()==4
(MTP); change the logic so cu_seqlens is only accepted and assigned when the
input is in varlen mode (i.e., not 4D). Concretely, in the block handling
cu_seqlens (the code that reads cs and sets p.cu_seqlens) add a guard that
rejects or ignores cu_seqlens if x.dim() == 4 (or check the varlen flag used by
the dispatcher) and only const_cast and assign to p.cu_seqlens when varlen is
true; apply the same change to the other symmetric packing site that currently
always assigns p.cu_seqlens.
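The intended gating can be expressed as a small Python model of the packing decision (a sketch, not the C++ code; the real site checks `x.dim()` on a torch tensor):

```python
def pack_cu_seqlens(x_dim, cu_seqlens):
    # cu_seqlens drives varlen token addressing, which only makes sense
    # for 3-D {total_tokens, nheads, dim} inputs; 4-D inputs use the MTP
    # layout and must not be combined with varlen offsets.
    if cu_seqlens is None:
        return None
    if x_dim == 4:
        raise ValueError("cu_seqlens is not supported with 4-D (MTP) inputs")
    return cu_seqlens
```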
- Around line 358-472: The varlen branches fail to validate that dt, B, C, z
(z_tensor) and out (output) have their first dimension equal to x.size(0)
(total_tokens), which risks OOB when indexing by token offsets; update the
checks inside each is_varlen block to assert dt.size(0) == x.size(0), B.size(0)
== x.size(0), C.size(0) == x.size(0), and if present z_tensor.size(0) ==
x.size(0) and output.size(0) == x.size(0) (use the same FLASHINFER_CHECK style
as other checks), referencing the existing symbols dt, B, C, z_tensor, output
and x.size(0)/total_tokens to locate where to add these assertions.
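The missing shape checks, modeled in Python with plain shape tuples standing in for torch sizes (the helper name is an assumption; the C++ version would use `FLASHINFER_CHECK`):

```python
def check_varlen_leading_dims(total_tokens, dt, B, C, z=None, out=None):
    # In varlen mode every per-token tensor must lead with total_tokens
    # rows, otherwise token-offset indexing reads out of bounds.
    named = {"dt": dt, "B": B, "C": C, "z": z, "out": out}
    for name, shape in named.items():
        if shape is None:
            continue  # z and out are optional
        if shape[0] != total_tokens:
            raise ValueError(f"{name}.size(0) ({shape[0]}) must equal "
                             f"x.size(0) ({total_tokens})")
```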
- Around line 76-87: validate_state_batch_indices currently only checks
shape/size but not device residency, so a host tensor can be used when packing
raw kernel pointers (e.g., where state_batch_indices is dereferenced to build
device pointers for kernels). Update validate_state_batch_indices to assert the
tensor is on CUDA (check sbi.is_cuda() or sbi.device().is_cuda()) before
returning; if not, raise a FLASHINFER_CHECK/appropriate error explaining it must
be a CUDA tensor. Keep the existing shape/size checks (sbi.dim(), sbi.size(...))
and apply this device check early (before any pointer-packing sites that consume
state_batch_indices).
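A Python model of the augmented validator, with attributes mimicking the torch tensor API (`is_cuda`, `dim`, `size`); the real code is C++ and raises via `FLASHINFER_CHECK`:

```python
def validate_state_batch_indices(sbi, batch):
    # Device residency first: a host tensor here would later be
    # dereferenced as a raw device pointer by the kernel launcher.
    if not sbi.is_cuda:
        raise ValueError("state_batch_indices must be a CUDA tensor")
    # Existing shape checks stay in place after the device check.
    if sbi.dim != 1 or sbi.size != batch:
        raise ValueError("state_batch_indices must be 1-D with size(0) == batch")
```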
In `@flashinfer/mamba/selective_state_update.py`:
- Around line 286-291: The varlen branch sets ntokens_mtp = cache_steps
unconditionally which lets the MTP kernel iterate cache_steps and silently drop
any varlen spans longer than cache_steps; before assigning ntokens_mtp when
is_varlen is true, check cu_seqlens (the input cumulative sequence lengths for
varlen batches) for any span length > cache_steps and raise an explicit error
(ValueError) if found; otherwise keep ntokens_mtp = cache_steps. Reference the
is_varlen branch, ntokens_mtp, cache_steps, and cu_seqlens to locate where to
add this validation.
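The span-length validation can be sketched as follows, treating `cu_seqlens` as a Python list of cumulative lengths (the helper name is an assumption):

```python
def resolve_ntokens_mtp(cu_seqlens, cache_steps):
    # Longest per-sequence span, from consecutive cumulative lengths.
    longest = max(b - a for a, b in zip(cu_seqlens, cu_seqlens[1:]))
    if longest > cache_steps:
        raise ValueError(
            f"varlen span of {longest} tokens exceeds cache_steps={cache_steps}; "
            "the MTP kernel would silently drop the extra tokens")
    return cache_steps
```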
In `@include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh`:
- Around line 411-432: The dst-slot write path still updates params.state and
params.state_scale even when state updates should be disabled; wrap the writes
inside the dst-state block (the loop writing into params.state and the
scaleState branch that writes into params.state_scale) with a guard that
respects the update flag (e.g., if constexpr (!disable_state_update) or the
existing update_state template/flag), so that when disable_state_update is true
no writes occur to dst_state_batch_indices/dst_idx -> params.state or
params.state_scale; reference dst_state_batch_indices, dst_idx, params.state,
params.state_scale, scaleState, sram.state, load_state_t and the existing
dst-slot write loop to locate the changes.
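The guarded writeback amounts to this Python model (a dict stands in for the state cache; in the kernel the guard is an `if constexpr` / flag around the SRAM-to-global store loop):

```python
def write_back_dst_state(state_cache, dst_idx, new_state, update_state):
    # With updates disabled, the dst slot must be left untouched; both
    # the state and state_scale stores sit behind the same guard.
    if update_state:
        state_cache[dst_idx] = new_state
    return state_cache
```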
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: ec58034c-444e-40c0-b640-dabfc9d0299a
📒 Files selected for processing (12)
- .gitignore
- benchmarks/routines/mamba.py
- csrc/flashinfer_mamba_binding.cu
- csrc/selective_state_update.cu
- csrc/selective_state_update_customize_config.jinja
- flashinfer/jit/mamba/selective_state_update.py
- flashinfer/mamba/selective_state_update.py
- include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh
- include/flashinfer/mamba/kernel_selective_state_update_stp.cuh
- include/flashinfer/mamba/selective_state_update.cuh
- tests/mamba/test_selective_state_update_varlen.py
- tests/mamba/triton_reference/selective_state_update_varlen.py
🚧 Files skipped from review as they are similar to previous changes (3)
- include/flashinfer/mamba/selective_state_update.cuh
- csrc/selective_state_update_customize_config.jinja
- .gitignore
bool const is_varlen = (x.dim() == 3 && cu_seqlens.has_value());
// Extract dimensions from input tensors
auto const batch = x.size(0);
auto const ntokens_mtp = x.size(1);
int64_t batch;
int64_t ntokens_mtp;

auto const state_cache_size = state.size(0);
auto const nheads = state.size(1);
auto const dim = state.size(2);
auto const dstate = state.size(3);
auto const ngroups = B.size(2);

FLASHINFER_CHECK(state_cache_size >= batch, "state.size(0) must be >= x.size(0)");
FLASHINFER_CHECK(nheads % ngroups == 0, "nheads must be divisible by ngroups");

// Check x shape and strides
CHECK_CUDA(x);
CHECK_DIM(4, x);
FLASHINFER_CHECK(x.size(2) == nheads, "x.size(2) must equal nheads");
FLASHINFER_CHECK(x.size(3) == dim, "x.size(3) must equal dim");
CHECK_LAST_DIM_CONTIGUOUS(x);
FLASHINFER_CHECK(x.stride(2) == dim, "x.stride(2) must equal dim, got ", x.stride(2),
                 " expected ", dim);
if (is_varlen) {
  CHECK_DIM(3, x);  // x: {total_tokens, nheads, dim}
  FLASHINFER_CHECK(x.size(1) == nheads, "x.size(1) must equal nheads");
  FLASHINFER_CHECK(x.size(2) == dim, "x.size(2) must equal dim");
  CHECK_LAST_DIM_CONTIGUOUS(x);
  FLASHINFER_CHECK(x.stride(1) == dim, "x.stride(1) must equal dim");
  batch = cu_seqlens.value().size(0) - 1;
  FLASHINFER_CHECK(cache_steps >= 1,
                   "cache_steps must be >= 1 in varlen mode (specifies max_seqlen)");
  ntokens_mtp = cache_steps;
Validate cu_seqlens length before deriving batch.
batch is computed as cu_seqlens.size(0) - 1 on Line 364 before enforcing a minimum length. If cu_seqlens is empty, batch becomes -1 and downstream checks can pass incorrectly.
🛠️ Suggested fix
if (is_varlen) {
+ auto const& cs = cu_seqlens.value();
+ CHECK_CUDA(cs);
+ CHECK_DIM(1, cs);
+ CHECK_CONTIGUOUS(cs);
+ FLASHINFER_CHECK(cs.size(0) >= 1, "cu_seqlens must have at least one element");
CHECK_DIM(3, x); // x: {total_tokens, nheads, dim}
@@
- batch = cu_seqlens.value().size(0) - 1;
+ batch = cs.size(0) - 1;
@@
- if (cu_seqlens.has_value()) {
- auto const& cs = cu_seqlens.value();
- CHECK_CUDA(cs);
- CHECK_DIM(1, cs);
- CHECK_CONTIGUOUS(cs);
+ if (is_varlen) {
+ auto const& cs = cu_seqlens.value();
FLASHINFER_CHECK(cs.size(0) == batch + 1, "cu_seqlens.size(0) must equal n_sequences + 1 (",
batch + 1, ")");

Also applies to: 561-567
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@csrc/selective_state_update.cu` around lines 346 - 367, When is_varlen is
true, validate cu_seqlens length before computing batch: check that
cu_seqlens.has_value() and cu_seqlens->size(0) >= 2 (so batch =
cu_seqlens.value().size(0) - 1 is non-negative) and emit a FLASHINFER_CHECK with
a clear message if not; update the same guard in the other varlen block that
computes batch (the later block around ntokens/offset handling) to avoid
deriving batch = -1 from an empty cu_seqlens. Ensure you perform this check
prior to assigning batch and before any subsequent uses of batch or indexing
into cu_seqlens.
auto const* __restrict__ dst_sbi =
    reinterpret_cast<stateIndex_t const*>(params.dst_state_batch_indices);
auto const dst_state_batch =
    dst_sbi ? static_cast<int64_t>(dst_sbi[batch * params.dst_state_batch_indices_stride_batch])
            : state_batch;
auto const state_ptr_offset = state_batch * params.state_stride_batch + head * DIM * DSTATE;
state += state_ptr_offset;
auto* __restrict__ dst_state = reinterpret_cast<state_t*>(params.state) +
                               dst_state_batch * params.state_stride_batch + head * DIM * DSTATE;
if constexpr (scaleState) {
  state_scale += state_batch * params.state_scale_stride_batch + head * DIM;
}
[[maybe_unused]] auto* __restrict__ dst_state_scale =
    scaleState ? reinterpret_cast<state_scale_t*>(params.state_scale) +
                     dst_state_batch * params.state_scale_stride_batch + head * DIM
               : nullptr;
Guard simple-kernel writeback when dst_state_batch_indices is padded.
dst_state / dst_state_scale are derived from dst_state_batch, but every write guard only checks the source slot. If a caller uses pad_slot_id in dst_state_batch_indices, this path writes through a negative cache index.
🛠️ Proposed fix
auto const dst_state_batch =
dst_sbi ? static_cast<int64_t>(dst_sbi[batch * params.dst_state_batch_indices_stride_batch])
: state_batch;
+ auto const dst_writable = dst_state_batch != params.pad_slot_id;
auto const state_ptr_offset = state_batch * params.state_stride_batch + head * DIM * DSTATE;
state += state_ptr_offset;
- auto* __restrict__ dst_state = reinterpret_cast<state_t*>(params.state) +
- dst_state_batch * params.state_stride_batch + head * DIM * DSTATE;
+ auto* __restrict__ dst_state =
+ dst_writable
+ ? reinterpret_cast<state_t*>(params.state) +
+ dst_state_batch * params.state_stride_batch + head * DIM * DSTATE
+ : nullptr;
@@
- if (!scaleState && params.update_state && state_batch != params.pad_slot_id) {
+ if (!scaleState && params.update_state && state_batch != params.pad_slot_id &&
+ dst_writable) {
*reinterpret_cast<load_state_t*>(&dst_state[d * DSTATE + i]) = rState;
}
@@
- if (params.update_state && state_batch != params.pad_slot_id) {
+ if (params.update_state && state_batch != params.pad_slot_id && dst_writable) {
@@
- if (params.update_state && state_batch != params.pad_slot_id) {
+ if (params.update_state && state_batch != params.pad_slot_id && dst_writable) {
for (int l = lane; l < rowsPerWarp; l += warpSize) {

Also applies to: 244-246, 254-270, 297-303
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@include/flashinfer/mamba/kernel_selective_state_update_stp.cuh` around lines
121 - 136, The code computes dst_state_batch/dst_state/dst_state_scale from
dst_sbi but then only uses the source-slot write guards, so if
dst_state_batch_indices contains a padded slot (pad_slot_id) the kernel can
write through a negative/invalid cache index; fix by computing a boolean
dst_valid (e.g., dst_sbi != nullptr && dst_state_batch != pad_slot_id) and use
that same dst_valid wherever writes to dst_state or dst_state_scale occur
(mirror the existing source-slot guard logic), and ensure any arithmetic that
constructs dst_state and dst_state_scale is only used when dst_valid is true to
avoid creating/using invalid pointers (apply same change around the other
occurrences mentioned: the blocks around lines with dst_state/dst_state_scale at
the other offsets).
/bot run
[FAILED] Pipeline #46539983: 13/20 passed
…state support to selective_state_update
force-pushed from fd5eed6 to b6e179b
/bot run
[SUCCESS] Pipeline #46618531: 14/20 passed
tests clean
📌 Description
vLLM uses a different scheme for speculative decoding and prefix caching than SGLang and TRT-LLM, namely:

- `dst_state_batch_indices` - telling the kernel where in the state tensor to store the newly computed state
- `cu_seqlens` - allowing for a varying number of tokens per sequence when speculative decoding is enabled
- `num_accepted_tokens` - used to decide from which index in the state tensor to read the initial cached state per sequence in speculative decoding

This PR adds support for all of these, while keeping support for previous variants, and without hurting performance.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used my preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
API
Tests
Chores