
MTP for mamba #2444

Merged
yzh119 merged 15 commits into flashinfer-ai:main from ishovkun:main
Feb 3, 2026

Conversation

@ishovkun
Contributor

@ishovkun ishovkun commented Jan 30, 2026

📌 Description

This contribution implements the following changes:

  • Multi-token prediction for mamba
  • Handling state and intermediate state cache that are non-contiguous in batch dimension
  • Handling int32 and int64 cache indices
  • More checks for dtypes and consistency of dtypes
  • Unified template dispatch functions (that hopefully will be replaced with jinja templates in the future).
  • NO Eagle3 yet
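For intuition, the recurrence that multi-token prediction runs per head can be sketched in plain NumPy. This is a simplified reference, not the kernel: it ignores z-gating, the D skip connection, dt bias/softplus, and padded slots, and the shapes are illustrative.

```python
import numpy as np

def selective_state_update_ref(state, x, dt, A, B, C, intermediate_states=None):
    """Per-head reference. state: (dim, dstate); x, dt: (T, dim);
    A: (dim, dstate); B, C: (T, dstate). Returns (outputs (T, dim), final state)."""
    T, dim = x.shape
    ys = np.empty((T, dim), dtype=np.float32)
    for t in range(T):
        dA = np.exp(A * dt[t][:, None])              # discretized decay
        dB = B[t][None, :] * dt[t][:, None]          # discretized input
        state = state * dA + dB * x[t][:, None]      # SSM recurrence
        if intermediate_states is not None:
            intermediate_states[t] = state           # cache per-step state (MTP)
        ys[t] = (state * C[t][None, :]).sum(axis=1)  # readout
    return ys, state
```

Caching each per-step state is what lets speculative decoding roll the state back to any accepted token instead of recomputing the prefix.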

The new kernel achieves 2.78x the throughput of SGLang's Triton implementation:

(benchmark figure: mtp_b200_02)

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or using your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Sorry that it's so big. Please let me know if I'm missing any important tests.

Summary by CodeRabbit

  • New Features

    • Multi-token prediction support (cache_steps) and new public parameters: disable_state_update, intermediate_states_buffer, intermediate_state_indices, and optional out tensor; updated docstrings/shapes.
  • Validation & Errors

    • Centralized, stricter input validation and clearer human-readable error messages for unsupported configurations.
  • Tests

    • Added extensive single-token and multi-token test suites and test utilities; removed an older test file.


Commit messages

- Move the test input generation helper from test_selective_state_update.py to a new test_utils.py module for reuse across tests. The refactored function adds support for multi-token mode, intermediate state buffers, and configurable state cache strides.
- struct
  - Add helper functions for tensor validation and dtype checks
  - Move output tensor to Optional and update checks accordingly
  - Add state_stride_batch and update_state fields to SelectiveStateUpdateParams
  - Refactor kernel param usage for clarity and consistency
- Extract dispatchDimDstate and dispatchRatio helpers to simplify kernel dispatch code and reduce duplication.
  - Add kernel and dispatcher support for int32/int64 state_batch_indices
  - Update tests to cover int32 indices
  - Fix test_utils to use int64 slot_idx by default
- Support int32 and int64 state_batch_indices in selective_state_update
  - Remove int32 type check to allow both int32 and int64 index types
  - Add stateIndex_t template parameter to kernels for index type dispatch
  - Extract kernel implementations to new selective_state_update_stp.cuh
  - Remove unused TMA helper functions from create_tensor_map.cuh
  - Add comprehensive MTP (multi-token prediction) test suite
- checks
  - Add common.cuh with kernel dispatch helpers and alignment checks
  - Split and rename kernel_selective_state_update_stp.cuh, add kernel_selective_state_update_mtp.cuh
  - Refactor Python selective_state_update to clarify dimension handling
  - Add test for dtype mismatch between state_batch_indices and intermediate_state_indices
  - Update test_utils to generate int64 intermediate_slot_idx by default
  - Remove redundant input type check in validate_intermediate_state_indices
@coderabbitai
Contributor

coderabbitai bot commented Jan 30, 2026

📝 Walkthrough

Walkthrough

This PR adds multi‑token prediction (MTP) support and richer validation/dispatch for selective_state_update: new STP/MTP CUDA kernels, dtype-driven compile-time dispatch, expanded parameter structs, Python/C++ bindings for intermediate-state caching and disable_state_update, and extensive tests/utilities.

Changes

Cohort / File(s): Summary

- C++ binding (csrc/flashinfer_mamba_binding.cu): Updated selective_state_update signature: moved output, added disable_state_update, intermediate_states_buffer, intermediate_state_indices, cache_steps.
- Dispatch & validation (csrc/selective_state_update.cu): Centralized tensor validation, dtype-code mapping and compile-time dispatcher; STP vs MTP dispatch paths; support for optional intermediate buffers and detailed error messages.
- Python API & bindings (flashinfer/mamba/selective_state_update.py): Public API extended with out, disable_state_update, intermediate_states_buffer, intermediate_state_indices, cache_steps; MTP shape augmentation and propagation to internal bindings (including meta/fake paths).
- Core headers & params (include/flashinfer/mamba/selective_state_update.cuh): Refactored params structs (SelectiveStateUpdateParams, SelectiveStateMTPParams), replaced inline kernels with external includes, added AllowedDims/AllowedDstates/AllowedNtokens.
- Common CUDA utilities (include/flashinfer/mamba/common.cuh): New packed load helpers, warp reductions, softplus helpers, alignment checks, and dimension/dstate/token dispatch helpers (dispatchDimDstate*, dispatchRatio).
- TMA / tensor map builder (include/flashinfer/mamba/create_tensor_map.cuh): New buildNdDescriptor with runtime validation and exception-based diagnostics for cuTensorMap descriptor creation.
- STP kernels & launcher (include/flashinfer/mamba/kernel_selective_state_update_stp.cuh): New single-token kernel variants: simple, SM90 vertical, SM100 horizontal producer/consumer paths and unified invokeSelectiveStateUpdate launcher.
- MTP kernels & launcher (include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh): New multi-token kernel and invokeSelectiveStateUpdateMTP launcher with per-token accumulation and optional intermediate-state caching.
- Triton reference (tests/mamba/selective_state_update_triton.py): Expanded Triton reference to support multi-token loops, intermediate caching, retrieve_parent_token, and new kernel args.
- Tests & utils (tests/mamba/test_utils.py, tests/mamba/test_selective_state_update_stp.py, tests/mamba/test_selective_state_update_mtp.py): Added test utilities and extensive STP/MTP test suites covering out, disable_state_update, intermediate buffers, dtype/index edge cases.
- Removed tests (tests/mamba/test_selective_state_update.py): Deleted prior monolithic test file (replaced by new, more granular tests).
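The dispatch helpers listed for common.cuh follow a familiar pattern: map a runtime (dim, dstate) pair onto a fixed set of compiled specializations, rejecting unsupported shapes up front. A hypothetical Python sketch of that idea follows; the ALLOWED_DIMS/ALLOWED_DSTATES values here are made up, and the real sets live in AllowedDims/AllowedDstates in the header.

```python
# Illustrative specialization table; real values live in AllowedDims /
# AllowedDstates inside selective_state_update.cuh.
ALLOWED_DIMS = (64, 128)
ALLOWED_DSTATES = (64, 128, 256)

def dispatch_dim_dstate(dim, dstate, kernels):
    """Pick the kernel specialized for (dim, dstate), mirroring the
    reject-or-select behavior a dispatchDimDstate-style helper would have."""
    if dim not in ALLOWED_DIMS or dstate not in ALLOWED_DSTATES:
        raise ValueError(f"unsupported (dim, dstate)=({dim}, {dstate})")
    return kernels[(dim, dstate)]
```

In the C++ version the same selection happens at compile time via templates, so each (dim, dstate) pair gets a fully specialized kernel rather than a dictionary lookup.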

Sequence Diagram(s)

sequenceDiagram
    participant PyAPI as Python API
    participant CPP as C++ Binding
    participant Dispatch as Dispatcher (csrc)
    participant KernelMgr as Kernel Launcher (STP/MTP)
    participant CUDA as CUDA Device

    PyAPI->>CPP: selective_state_update(..., cache_steps, out, intermediate_states_buffer, ...)
    CPP->>Dispatch: call C++ selective_state_update with all params
    Dispatch->>Dispatch: validate tensors & dtypes
    Dispatch->>Dispatch: choose STP or MTP (cache_steps)
    alt STP (single-token)
        Dispatch->>KernelMgr: invokeSelectiveStateUpdate(params)
        KernelMgr->>CUDA: launch STP kernel variant (simple/SM90/SM100)
    else MTP (multi-token)
        Dispatch->>KernelMgr: invokeSelectiveStateUpdateMTP(params)
        KernelMgr->>CUDA: launch MTP kernel (per-token loop, optional intermediate state caching)
    end
    CUDA-->>Dispatch: write outputs, state, intermediate buffers
    Dispatch-->>CPP: return results
    CPP-->>PyAPI: return output tensor
sequenceDiagram
    participant PyFront as Public Python
    participant ShapeProc as Shape/Dim Normalization
    participant TorchLib as Torch meta/fake
    participant Backend as C++/CUDA

    PyFront->>ShapeProc: selective_state_update(..., cache_steps, ...)
    ShapeProc->>ShapeProc: derive is_mtp, expand/squeeze dims, normalize A/D/dt_bias
    ShapeProc->>TorchLib: call _selective_state_update (meta/fake path)
    TorchLib->>Backend: compiled binding -> C++ implementation
    Backend->>Backend: dtype dispatch & validation -> select kernel launcher
    Backend->>CUDA: launch kernels (STP/MTP)
    CUDA-->>Backend: results
    Backend-->>PyFront: final tensor

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

  • PR #2387: Adds SM100 horizontal producer-consumer kernel and updates SM100 vs SM90 dispatch logic for selective_state_update.
  • PR #2392: Introduces optional user-provided out parameter and threads it through bindings.
  • PR #2301: Broad refactor of selective_state_update kernels, dispatch machinery, and Python/C++ bindings overlapping MTP/caching changes.

Suggested labels

v0.6.2

Suggested reviewers

  • djmmoss
  • cyx-6
  • bkryu
  • nvmbreughe
  • jimmyzho
  • jiahanc
  • yzh119
  • kahyunnam


🚥 Pre-merge checks | ✅ 1 | ❌ 2
❌ Failed checks (1 warning, 1 inconclusive)

- Docstring Coverage (⚠️ Warning): Docstring coverage is 29.79%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
- Title check (❓ Inconclusive): The PR title 'MTP for mamba' is vague and overly generic; it does not convey what MTP is or the significance of the changes without prior context. Resolution: expand the title to be more specific, such as 'Add multi-token prediction support for mamba backend'.

✅ Passed checks (1 passed)

- Description check (✅ Passed): The PR description addresses the required template sections: it provides a clear description of changes, mentions related improvements, and confirms that pre-commit checks and tests have been completed. However, the description could be more structured to match the template exactly.


@gemini-code-assist
Contributor

Summary of Changes

Hello @ishovkun, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the Mamba selective state update operation by introducing multi-token prediction capabilities, which are vital for efficient speculative decoding. The changes involve a comprehensive refactoring of the CUDA kernels and Python bindings to support processing multiple tokens concurrently. It also includes robust handling for various memory layouts, data types, and indexing schemes for state and intermediate caches, ensuring greater flexibility and stability.

Highlights

  • Multi-Token Prediction (MTP) for Mamba: Introduced support for processing multiple tokens simultaneously in the Mamba selective state update operation, crucial for speculative decoding.
  • Non-Contiguous State and Intermediate State Cache Handling: Implemented mechanisms to correctly handle state and intermediate state caches that are non-contiguous in the batch dimension, improving memory flexibility.
  • Support for int32 and int64 Cache Indices: Added robust support for both 32-bit and 64-bit integer types for cache indices, enhancing compatibility and addressing potential overflow issues for very large caches.
  • Enhanced Dtype Consistency Checks: Integrated more comprehensive checks to ensure data type consistency across various input tensors, preventing common runtime errors.
  • Unified Template Dispatch Functions: Refactored the kernel dispatch logic using unified template functions, improving code organization and maintainability for different data type combinations.




Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces multi-token prediction (MTP) for Mamba, a significant feature enhancement. The changes are extensive and well-structured, including robust validation, modern C++ dispatching mechanisms, and architecture-specific optimizations for CUDA kernels. The test suite is comprehensive, covering single-token, multi-token, and various edge cases, which provides confidence in the correctness of the implementation. I have one suggestion to improve type safety in the C++ code by using const for read-only data pointers.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🤖 Fix all issues with AI agents
In `@csrc/selective_state_update.cu`:
- Around line 161-174: The allowed_dtype_combos list currently only allows
bfloat16_code at tuple position 1 (the input_code), which blocks float16 inputs;
update allowed_dtype_combos to also include the same combinations with
float16_code in that second position so float16 inputs are accepted.
Specifically, for each existing tuple where the second element is bfloat16_code,
add a corresponding tuple with that element replaced by float16_code, preserving
the other elements (including both int32_code and int64_code variants) so the
dtype permutations match the upstream Mamba selective scan flexibility.

In `@include/flashinfer/mamba/kernel_selective_state_update_stp.cuh`:
- Around line 169-170: The simple kernel currently always writes back state with
the line that stores rState into state[d * DSTATE + i] when state_batch !=
params.pad_slot_id; change that write to also check params.update_state so it
only writes when state updates are enabled (i.e. guard the store with
params.update_state && state_batch != params.pad_slot_id), matching the
producer-consumer kernels' behavior (see params.update_state usage) and
preserving the pad_slot_id check; ensure you apply this to the same cast/store
using load_state_t and the state/DSTATE indexing so semantics remain identical
aside from honoring update_state.

In `@tests/mamba/selective_state_update_triton.py`:
- Around line 263-277: Remove the extra HAS_STATE_BATCH_INDICES guard so caching
mirrors the CUDA kernel: when CACHE_INTERMEDIATE_STATES is true and
state_batch_idx != pad_slot_id, always compute cache_ptr_base (using
intermediate_states_buffer, cache_idx, cache_steps, nheads, dim, dstate,
current_step_idx, pid_h) and cache_ptrs (using offs_m, offs_n) and call
tl.store(state.to(cache_ptrs.dtype.element_ty), mask=mask). Delete the enclosing
"if HAS_STATE_BATCH_INDICES:" condition around that caching block so the logic
only checks CACHE_INTERMEDIATE_STATES and state_batch_idx != pad_slot_id.
🧹 Nitpick comments (16)
include/flashinfer/mamba/create_tensor_map.cuh (2)

64-72: Consider validating tileShapes[0] for boxDim limit consistency.

The validation for tileShapes[ii] (ii > 0) checks against 256, but tileShapes[0] is assigned directly without limit checking. While the first dimension typically has different constraints (up to 256 bytes for the box extent), consider adding a similar bounds check for consistency, or add a comment explaining why the first dimension doesn't need the same validation.

Additionally, the error handling pattern mixes std::cerr with FLASHINFER_CHECK(false). Consider using the message directly in FLASHINFER_CHECK for consistency:

   boxDim[0] = tileShapes[0];
   for (size_t ii = 1; ii < tileShapes.size(); ++ii) {
-    if (tileShapes[ii] > 256) {
-      std::cerr << "buildNdTmaDescriptor: boxDim too large " << tileShapes[ii] << std::endl;
-      FLASHINFER_CHECK(false);
-    } else {
-      boxDim[ii] = tileShapes[ii];
-    }
+    FLASHINFER_CHECK(tileShapes[ii] <= 256, "buildNdTmaDescriptor: boxDim too large ", tileShapes[ii]);
+    boxDim[ii] = tileShapes[ii];
   }

86-121: Detailed error reporting is helpful for debugging TMA issues.

The comprehensive error reporting with shapes, strides, tile dimensions, and swizzle type will be valuable for diagnosing TMA descriptor creation failures. However, errorString from cuGetErrorString is retrieved but never used in the error message.

🔧 Proposed fix to include the error string
   if (result != CUDA_SUCCESS) {
     char const* errorString;
     cuGetErrorString(result, &errorString);
     std::stringstream ss;
-    ss << "Error: Failed to initialize the TMA descriptor " << result << std::endl;
+    ss << "Error: Failed to initialize the TMA descriptor (code=" << result 
+       << ", " << (errorString ? errorString : "unknown error") << ")" << std::endl;
include/flashinfer/mamba/common.cuh (2)

29-29: Redefinition of warpSize shadows CUDA built-in.

The constant warpSize shadows the CUDA built-in warpSize variable. While functionally equivalent on current hardware, this could cause confusion or issues if the built-in is referenced elsewhere. Consider renaming to kWarpSize or similar to avoid shadowing.

-constexpr unsigned warpSize = 32;
+constexpr unsigned kWarpSize = 32;

And update references accordingly (lines 59, 64).


43-50: Minor: Remove trailing semicolon after function body.

The semicolon after the closing brace of make_zeros() is unnecessary.

   return ret;
-};
+}
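The walkthrough mentions that common.cuh also ships softplus helpers; they are not shown in this review. For reference, a standard numerically stable softplus looks like the sketch below. This is plain Python, not the actual device code, which would use CUDA intrinsics such as __expf.

```python
import math

def softplus(x: float, threshold: float = 20.0) -> float:
    # softplus(x) = log(1 + e^x); for large x this is ~x, and computing
    # exp(x) directly would overflow, hence the linear shortcut.
    if x > threshold:
        return x
    return math.log1p(math.exp(x))
```

The threshold cutoff is the usual trick (also used by the Triton reference's softplus) to avoid overflow while keeping the result exact to floating-point precision.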
flashinfer/mamba/selective_state_update.py (1)

136-138: Clarify MTP mode semantics in comment.

The condition cache_steps >= 1 means that even cache_steps=1 triggers MTP mode with 4D tensors. The comment says "more than 1 token" but the condition includes cache_steps=1. Consider clarifying:

-    # Determine if we're in multi-token mode (more than 1 token)
-    is_mtp = cache_steps >= 1
+    # Determine if we're in multi-token mode (cache_steps provided)
+    # Note: cache_steps >= 1 triggers 4D tensor handling even for single token
+    is_mtp = cache_steps >= 1
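The cache_steps >= 1 semantics can be illustrated with a hypothetical shape-normalization helper. The real function handles several tensors and both squeeze and expand paths; the name, shapes, and single-tensor signature here are assumptions for illustration only.

```python
import numpy as np

def normalize_x(x, cache_steps=None):
    # cache_steps >= 1 selects the 4D multi-token path, even for one token.
    is_mtp = cache_steps is not None and cache_steps >= 1
    if is_mtp and x.ndim == 3:
        x = x[:, np.newaxis]  # (batch, nheads, dim) -> (batch, 1, nheads, dim)
    return x, is_mtp
```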
tests/mamba/test_selective_state_update_stp.py (1)

284-286: Prefix unused variable with underscore.

Static analysis correctly identifies that state_ref is unpacked but unused in this test method.

     def test_output_correctness(self, inputs, reference_output, use_out_tensor):
         """Test that kernel output matches reference but state is not updated."""
-        y_ref, state_ref = reference_output
+        y_ref, _state_ref = reference_output
tests/mamba/test_selective_state_update_mtp.py (2)

291-293: Prefix unused variable with underscore.

     def test_output_correctness(self, inputs, reference_output, use_out_tensor):
         """Test that kernel output matches reference but state is not updated."""
-        y_ref, state_ref = reference_output
+        y_ref, _state_ref = reference_output

443-445: Prefix unused variable with underscore.

     def test_output_correctness(self, inputs, reference_output, use_out_tensor):
         """Test that kernel output matches and intermediate states are cached correctly."""
-        y_ref, state_ref, intermediate_states_ref = reference_output
+        y_ref, _state_ref, intermediate_states_ref = reference_output
include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh (3)

82-120: Warp assignment is tightly coupled to numWarps=4.

The loading logic hardcodes warp indices 0-3 for loading x, B, z, and C respectively. If numWarps template parameter is changed from 4, this code will break or leave some data unloaded. Consider adding a static_assert to enforce this assumption.

🛠️ Proposed fix to enforce the assumption
+  static_assert(numWarps == 4, "Loading logic assumes exactly 4 warps");
+
   if (warp == 0) {  // Load x: gmem -> smem

152-159: Add comment explaining the packed element calculation strategy.

This is a performance-critical hot path with non-trivial logic for computing packedSramLdInputElements. A brief comment explaining why this optimization reduces LSU load would help future reviewers understand the design choice.

As per coding guidelines: "For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers."


237-248: Potential performance concern: intermediate state writes inside the inner loop.

Writing to intermediate_states on every MTP step (line 237-248) inside the dimension loop could be a performance bottleneck. Each iteration writes the full DSTATE elements per dimension row. Consider whether buffering and batching these writes is feasible for better memory throughput.

csrc/selective_state_update.cu (1)

312-320: Uninitialized out_stride_batch when output is not provided.

When out is not provided, p.out_stride_batch is set to 0 (line 315-316), but p.output is also nullptr (line 329). The kernel should handle this case, but it would be cleaner to explicitly document that the kernel checks for nullptr output before using the stride.

include/flashinfer/mamba/selective_state_update.cuh (1)

72-73: Unusual include placement after namespace close.

Including kernel headers after the namespace closure is unconventional. While it works because those headers likely open the same namespace, consider moving these includes to the top of the file for consistency with standard C++ practices.

include/flashinfer/mamba/kernel_selective_state_update_stp.cuh (1)

16-17: Remove commented-out include guards.

These commented-out lines appear to be leftover from refactoring. They should be removed to keep the code clean.

🧹 Proposed fix
-// `#ifndef` FLASHINFER_MAMBA_SELECTIVE_STATE_UPDATE_CUH_
-// `#define` FLASHINFER_MAMBA_SELECTIVE_STATE_UPDATE_CUH_
-
 `#include` <cooperative_groups.h>
tests/mamba/selective_state_update_triton.py (2)

87-88: Unused batch parameter in kernel signature.

The batch parameter on line 87 is never used inside the kernel (the batch index is obtained via tl.program_id(axis=1)). Consider removing it from the signature to avoid confusion.

🧹 Proposed fix
     # Matrix dimensions
-    batch,
     T,
     nheads,

And update the kernel call accordingly.


286-286: Remove unused noqa directive.

The # noqa: SIM113 comment is flagged as unused by ruff. The manual loop counter increment is intentional for the pointer arithmetic pattern, but this specific rule isn't enabled.

🧹 Proposed fix
-        current_step_idx += 1  # noqa: SIM113
+        current_step_idx += 1

@bkryu
Collaborator

bkryu commented Jan 30, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !278 has been created, and the CI pipeline #42916798 is currently running. I'll report back once the pipeline job completes.

@yzh119
Collaborator

yzh119 commented Jan 30, 2026

@flashinfer-bot run

Always define state_batch_idx (either from state_batch_indices or pid_b)
to mirror the CUDA kernel's state_batch variable. This allows the
intermediate state caching logic to use a simple check of
`state_batch_idx != pad_slot_id` without requiring an extra
HAS_STATE_BATCH_INDICES guard, matching the CUDA kernel behavior.

addresses:
flashinfer-ai#2444 (comment)
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/mamba/selective_state_update_triton.py (1)

68-90: ⚠️ Potential issue | 🟡 Minor

Rename unused batch parameter to _batch to silence Ruff ARG001 warning.

The parameter is not used within the kernel function; the kernel operates using pid_b = tl.program_id(axis=1) instead. Renaming it explicitly indicates it's intentionally unused.

Minimal fix
-    batch,
+    _batch,
🤖 Fix all issues with AI agents
In `@tests/mamba/selective_state_update_triton.py`:
- Line 284: The trailing noqa directive on the increment line is unnecessary;
remove the inline comment "# noqa: SIM113" from the statement that updates
current_step_idx (the line containing "current_step_idx += 1") so the code
increments the variable without the unused linter suppression.
- Around line 189-276: The cache write can index out-of-bounds when
intermediate_state_indices contains pad_slot_id (e.g., -1); before computing
cache_ptr_base and calling tl.store you must gate the write on a valid cache
index — e.g., check CACHE_INTERMEDIATE_STATES and that cache_idx != pad_slot_id
and cache_idx >= 0 (and optionally state_batch_idx != pad_slot_id) — so modify
the block that computes cache_ptr_base and calls tl.store (references:
intermediate_state_indices_ptr, cache_idx, pad_slot_id,
CACHE_INTERMEDIATE_STATES, state_batch_idx, intermediate_states_buffer,
cache_ptrs, current_step_idx, tl.store) to skip stores for invalid/padded cache
indices.
- Line 403: Replace the lambda assigned to grid with a named function to satisfy
Ruff E731: define a function (e.g., def grid(META):) that takes the META
argument and returns (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads),
keeping the same captured variables (dim, batch, nheads) and use that function
name wherever grid was used.
🧹 Nitpick comments (2)
include/flashinfer/mamba/kernel_selective_state_update_stp.cuh (2)

641-733: Hoist dA computation out of inner loops.

A_value and dt_value are invariant per thread; recomputing __expf per element adds avoidable work.

♻️ Suggested refactor
@@
-  // #pragma unroll 1
+  auto const dA = __expf(A_value * dt_value);
+  // #pragma unroll 1
@@
-          auto const dA = __expf(A_value * dt_value);
           auto const dB = B_value * dt_value;
@@
-          auto const dA = __expf(A_value * dt_value);
           auto const dB = B_value * dt_value;

910-915: Document the numWarps = 4 heuristic for the simple path.

A short rationale for this tuning choice (and alternatives) will help future perf work.

📝 Suggested comment
-      constexpr int numWarps = 4;
+      constexpr int numWarps = 4;  // tuned for occupancy vs. register pressure; consider retuning for new GPUs
As per coding guidelines, for performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers.

Comment on lines +189 to +276
cache_idx = -1
if CACHE_INTERMEDIATE_STATES:
    if HAS_INTERMEDIATE_STATE_INDICES:
        intermediate_state_idx = tl.load(intermediate_state_indices_ptr + pid_b).to(
            tl.int64
        )
        cache_idx = intermediate_state_idx
    elif HAS_STATE_BATCH_INDICES:
        cache_idx = state_batch_idx
    else:
        cache_idx = pid_b

mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
if HAS_STATE_BATCH_INDICES:
    mask &= state_batch_idx != pad_slot_id
tl.store(state_ptrs, state, mask=mask)
out = tl.sum(state * C[None, :], axis=1)
if HAS_D:
    out += x * D
if HAS_Z:
    out *= z * tl.sigmoid(z)
tl.store(out_ptrs, out, mask=offs_m < dim)
current_step_idx = 0
for _ in range(T):
    if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK:
        if current_step_idx != 0 and cache_idx >= 0:
            parent_ptr = (
                retrieve_parent_token_ptr
                + pid_b * stride_retrieve_parent_token_batch
                + current_step_idx * stride_retrieve_parent_token_T
            )
            parent_step_idx = tl.load(parent_ptr).to(tl.int32)

            if parent_step_idx >= 0 and parent_step_idx < T:
                step_offset = parent_step_idx * nheads * dim * dstate
                cache_ptr = (
                    intermediate_states_buffer
                    + cache_idx * cache_steps * nheads * dim * dstate
                    + step_offset
                    + pid_h * dim * dstate
                    + offs_m[:, None] * dstate
                    + offs_n[None, :]
                )
                state = tl.load(cache_ptr, mask=mask, other=0.0).to(tl.float32)

    x_ptrs = x_ptr + offs_m * stride_x_dim
    dt_ptrs = dt_ptr + offs_m * stride_dt_dim
    B_ptrs = B_ptr + offs_n * stride_B_dstate
    C_ptrs = C_ptr + offs_n * stride_C_dstate
    if HAS_Z:
        z_ptrs = z_ptr + offs_m * stride_z_dim
    out_ptrs = out_ptr + offs_m * stride_out_dim

    x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if not TIE_HDIM:
        dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        if HAS_DT_BIAS:
            dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        if DT_SOFTPLUS:
            dt = softplus(dt)
        A = tl.load(
            A_ptrs,
            mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
            other=0.0,
        ).to(tl.float32)
        dA = tl.exp(A * dt[:, None])
    else:
        dt = tl.load(dt_ptr).to(tl.float32)
        if HAS_DT_BIAS:
            dt += tl.load(dt_bias_ptr).to(tl.float32)
        if DT_SOFTPLUS:
            dt = softplus(dt)
        A = tl.load(A_ptr).to(tl.float32)
        dA = tl.exp(A * dt)  # scalar, not a matrix

    B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    if HAS_D:
        D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if HAS_Z:
        z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)

    dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
    state = state * dA + dB * x[:, None]

    if CACHE_INTERMEDIATE_STATES:
        if state_batch_idx != pad_slot_id:
            cache_ptr_base = (
                intermediate_states_buffer
                + cache_idx * cache_steps * nheads * dim * dstate
                + current_step_idx * nheads * dim * dstate
                + pid_h * dim * dstate
            )
            cache_ptrs = cache_ptr_base + (
                offs_m[:, None] * dstate + offs_n[None, :]
            )
            tl.store(cache_ptrs, state.to(cache_ptrs.dtype.element_ty), mask=mask)


⚠️ Potential issue | 🟠 Major

Guard cache writes when intermediate_state_indices may contain pad/invalid values.

If intermediate_state_indices uses pad_slot_id (e.g., -1) for padded rows, the current write path can compute a negative base and write out-of-bounds. Consider gating on cache_idx != pad_slot_id (or assert on the host side) before storing.

🛡️ Suggested guard
-        if state_batch_idx != pad_slot_id:
+        if state_batch_idx != pad_slot_id and cache_idx != pad_slot_id:
🤖 Prompt for AI Agents
In `@tests/mamba/selective_state_update_triton.py` around lines 189 - 276, The
cache write can index out-of-bounds when intermediate_state_indices contains
pad_slot_id (e.g., -1); before computing cache_ptr_base and calling tl.store you
must gate the write on a valid cache index — e.g., check
CACHE_INTERMEDIATE_STATES and that cache_idx != pad_slot_id and cache_idx >= 0
(and optionally state_batch_idx != pad_slot_id) — so modify the block that
computes cache_ptr_base and calls tl.store (references:
intermediate_state_indices_ptr, cache_idx, pad_slot_id,
CACHE_INTERMEDIATE_STATES, state_batch_idx, intermediate_states_buffer,
cache_ptrs, current_step_idx, tl.store) to skip stores for invalid/padded cache
indices.
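On the host side, the suggested device guard could be complemented by validating the index tensors before kernel launch. A hedged sketch (the function name and signature are illustrative, not the PR's API):

```python
import numpy as np

def validate_cache_indices(indices, num_slots, pad_slot_id=-1):
    """Reject cache indices that would produce out-of-bounds stores.

    Entries equal to pad_slot_id mark padded rows and are allowed;
    every other entry must be a valid slot in [0, num_slots).
    """
    idx = np.asarray(indices)
    real = idx[idx != pad_slot_id]
    if real.size and (real.min() < 0 or real.max() >= num_slots):
        raise ValueError(
            f"cache indices must be in [0, {num_slots}) or equal to "
            f"pad_slot_id={pad_slot_id}"
        )
```

The device-side `cache_idx != pad_slot_id` check still belongs in the kernel; a host check like this just fails fast on genuinely invalid indices instead of silently corrupting memory.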

@flashinfer-bot
Collaborator

[FAILED] Pipeline #42916798: 3/20 passed

@yzh119
Collaborator

yzh119 commented Feb 1, 2026

@flashinfer-bot run

@yzh119 yzh119 merged commit b7404d0 into flashinfer-ai:main Feb 3, 2026
49 checks passed
raayandhar pushed a commit to raayandhar/flashinfer that referenced this pull request Feb 5, 2026