Conversation
Move the test input generation helper from test_selective_state_update.py to a new test_utils.py module for reuse across tests. The refactored function adds support for multi-token mode, intermediate state buffers, and configurable state cache strides.
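The refactored helper itself isn't shown here; as a rough sketch (all names hypothetical, NumPy standing in for torch), such a shared generator might look like this, switching between single-token and multi-token shapes via `cache_steps`:

```python
import numpy as np

def make_selective_state_update_inputs(
    batch=2, nheads=4, dim=8, dstate=16, cache_steps=None, seed=0
):
    """Hypothetical sketch of a shared test-input generator.

    cache_steps=None produces single-token (STP) 3D inputs; otherwise we add
    a token axis and an intermediate-state buffer, mirroring the multi-token
    (MTP) mode described in the PR.
    """
    rng = np.random.default_rng(seed)
    state = rng.standard_normal((batch, nheads, dim, dstate), dtype=np.float32)
    if cache_steps is None:  # STP: one token per request
        x = rng.standard_normal((batch, nheads, dim), dtype=np.float32)
        intermediate = None
    else:  # MTP: cache_steps tokens per request
        x = rng.standard_normal((batch, cache_steps, nheads, dim), dtype=np.float32)
        intermediate = np.zeros((batch, cache_steps, nheads, dim, dstate), dtype=np.float32)
    # int64 slot indices by default, as the refactored helper does
    slot_idx = np.arange(batch, dtype=np.int64)
    return {"state": state, "x": x, "intermediate": intermediate, "slot_idx": slot_idx}
```

The single `cache_steps` switch keeps STP and MTP tests on one code path instead of duplicating setup logic.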
- Add helper functions for tensor validation and dtype checks
- Move output tensor to Optional and update checks accordingly
- Add state_stride_batch and update_state fields to the SelectiveStateUpdateParams struct
- Refactor kernel param usage for clarity and consistency
Extract dispatchDimDstate and dispatchRatio helpers to simplify kernel dispatch code and reduce duplication.
- Add kernel and dispatcher support for int32/int64 state_batch_indices
- Update tests to cover int32 indices
- Fix test_utils to use int64 slot_idx by default

Support int32 and int64 state_batch_indices in selective_state_update:
- Remove int32 type check to allow both int32 and int64 index types
- Add stateIndex_t template parameter to kernels for index type dispatch
- Extract kernel implementations to new selective_state_update_stp.cuh
- Remove unused TMA helper functions from create_tensor_map.cuh
- Add comprehensive MTP (multi-token prediction) test suite
- Add common.cuh with kernel dispatch helpers and alignment checks
- Split and rename kernel_selective_state_update_stp.cuh, add kernel_selective_state_update_mtp.cuh
- Refactor Python selective_state_update to clarify dimension handling
- Add test for dtype mismatch between state_batch_indices and intermediate_state_indices
- Update test_utils to generate int64 intermediate_slot_idx by default
- Remove redundant input type check in validate_intermediate_state_indices
📝 Walkthrough

This PR adds multi-token prediction (MTP) support and richer validation/dispatch for selective_state_update: new STP/MTP CUDA kernels, dtype-driven compile-time dispatch, expanded parameter structs, Python/C++ bindings for intermediate-state caching and disable_state_update, and extensive tests/utilities.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant PyAPI as Python API
    participant CPP as C++ Binding
    participant Dispatch as Dispatcher (csrc)
    participant KernelMgr as Kernel Launcher (STP/MTP)
    participant CUDA as CUDA Device
    PyAPI->>CPP: selective_state_update(..., cache_steps, out, intermediate_states_buffer, ...)
    CPP->>Dispatch: call C++ selective_state_update with all params
    Dispatch->>Dispatch: validate tensors & dtypes
    Dispatch->>Dispatch: choose STP or MTP (cache_steps)
    alt STP (single-token)
        Dispatch->>KernelMgr: invokeSelectiveStateUpdate(params)
        KernelMgr->>CUDA: launch STP kernel variant (simple/SM90/SM100)
    else MTP (multi-token)
        Dispatch->>KernelMgr: invokeSelectiveStateUpdateMTP(params)
        KernelMgr->>CUDA: launch MTP kernel (per-token loop, optional intermediate state caching)
    end
    CUDA-->>Dispatch: write outputs, state, intermediate buffers
    Dispatch-->>CPP: return results
    CPP-->>PyAPI: return output tensor
```

```mermaid
sequenceDiagram
    participant PyFront as Public Python
    participant ShapeProc as Shape/Dim Normalization
    participant TorchLib as Torch meta/fake
    participant Backend as C++/CUDA
    PyFront->>ShapeProc: selective_state_update(..., cache_steps, ...)
    ShapeProc->>ShapeProc: derive is_mtp, expand/squeeze dims, normalize A/D/dt_bias
    ShapeProc->>TorchLib: call _selective_state_update (meta/fake path)
    TorchLib->>Backend: compiled binding -> C++ implementation
    Backend->>Backend: dtype dispatch & validation -> select kernel launcher
    Backend->>CUDA: launch kernels (STP/MTP)
    CUDA-->>Backend: results
    Backend-->>PyFront: final tensor
```
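The STP/MTP selection step in the diagrams can be sketched as follows (illustrative only, not the actual binding; the function name is made up):

```python
def select_kernel_path(cache_steps):
    """Illustrative sketch of the dispatch shown in the diagrams:
    cache_steps selects the kernel family."""
    if cache_steps is None:
        return "STP"  # single-token path, 3D inputs
    if cache_steps >= 1:
        return "MTP"  # multi-token path, 4D inputs + optional intermediate caching
    raise ValueError(f"invalid cache_steps: {cache_steps}")
```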
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 1 passed | ❌ 2 failed (1 warning, 1 inconclusive)
Summary of Changes

Hello @ishovkun, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly enhances the Mamba selective state update operation by introducing multi-token prediction capabilities, which are vital for efficient speculative decoding. The changes involve a comprehensive refactoring of the CUDA kernels and Python bindings to support processing multiple tokens concurrently. It also includes robust handling for various memory layouts, data types, and indexing schemes for state and intermediate caches, ensuring greater flexibility and stability.
Code Review
This pull request introduces multi-token prediction (MTP) for Mamba, a significant feature enhancement. The changes are extensive and well-structured, including robust validation, modern C++ dispatching mechanisms, and architecture-specific optimizations for CUDA kernels. The test suite is comprehensive, covering single-token, multi-token, and various edge cases, which provides confidence in the correctness of the implementation. I have one suggestion to improve type safety in the C++ code by using const for read-only data pointers.
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In `@csrc/selective_state_update.cu`:
- Around line 161-174: The allowed_dtype_combos list currently only allows
bfloat16_code at tuple position 1 (the input_code), which blocks float16 inputs;
update allowed_dtype_combos to also include the same combinations with
float16_code in that second position so float16 inputs are accepted.
Specifically, for each existing tuple where the second element is bfloat16_code,
add a corresponding tuple with that element replaced by float16_code, preserving
the other elements (including both int32_code and int64_code variants) so the
dtype permutations match the upstream Mamba selective scan flexibility.
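The fix this prompt describes can be sketched abstractly (the codes and tuple position here are made up for illustration, not the actual values in `selective_state_update.cu`):

```python
# Hypothetical dtype codes; position 1 of each combo tuple holds the input dtype.
bfloat16_code, float16_code, float32_code = 0, 1, 2
int32_code, int64_code = 10, 11

INPUT_POS = 1

allowed_dtype_combos = [
    (float32_code, bfloat16_code, int32_code),
    (float32_code, bfloat16_code, int64_code),
]

def expand_with_float16(combos, input_pos=INPUT_POS):
    """Return combos plus a float16 twin for each bfloat16-input combo,
    preserving the other elements (including int32/int64 variants)."""
    out = list(combos)
    for combo in combos:
        if combo[input_pos] == bfloat16_code:
            twin = list(combo)
            twin[input_pos] = float16_code
            out.append(tuple(twin))
    return out

expanded = expand_with_float16(allowed_dtype_combos)
```

The point is that float16 inputs become accepted wherever bfloat16 already was, without touching any other position in the combo tuples.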
In `@include/flashinfer/mamba/kernel_selective_state_update_stp.cuh`:
- Around line 169-170: The simple kernel currently always writes back state with
the line that stores rState into state[d * DSTATE + i] when state_batch !=
params.pad_slot_id; change that write to also check params.update_state so it
only writes when state updates are enabled (i.e. guard the store with
params.update_state && state_batch != params.pad_slot_id), matching the
producer-consumer kernels' behavior (see params.update_state usage) and
preserving the pad_slot_id check; ensure you apply this to the same cast/store
using load_state_t and the state/DSTATE indexing so semantics remain identical
aside from honoring update_state.
In `@tests/mamba/selective_state_update_triton.py`:
- Around line 263-277: Remove the extra HAS_STATE_BATCH_INDICES guard so caching
mirrors the CUDA kernel: when CACHE_INTERMEDIATE_STATES is true and
state_batch_idx != pad_slot_id, always compute cache_ptr_base (using
intermediate_states_buffer, cache_idx, cache_steps, nheads, dim, dstate,
current_step_idx, pid_h) and cache_ptrs (using offs_m, offs_n) and call
tl.store(state.to(cache_ptrs.dtype.element_ty), mask=mask). Delete the enclosing
"if HAS_STATE_BATCH_INDICES:" condition around that caching block so the logic
only checks CACHE_INTERMEDIATE_STATES and state_batch_idx != pad_slot_id.
🧹 Nitpick comments (16)
include/flashinfer/mamba/create_tensor_map.cuh (2)
**64-72: Consider validating `tileShapes[0]` for boxDim limit consistency.**

The validation for `tileShapes[ii]` (`ii > 0`) checks against 256, but `tileShapes[0]` is assigned directly without limit checking. While the first dimension typically has different constraints (up to 256 bytes for the box extent), consider adding a similar bounds check for consistency, or add a comment explaining why the first dimension doesn't need the same validation.

Additionally, the error handling pattern mixes `std::cerr` with `FLASHINFER_CHECK(false)`. Consider using the message directly in `FLASHINFER_CHECK` for consistency:

```diff
 boxDim[0] = tileShapes[0];
 for (size_t ii = 1; ii < tileShapes.size(); ++ii) {
-  if (tileShapes[ii] > 256) {
-    std::cerr << "buildNdTmaDescriptor: boxDim too large " << tileShapes[ii] << std::endl;
-    FLASHINFER_CHECK(false);
-  } else {
-    boxDim[ii] = tileShapes[ii];
-  }
+  FLASHINFER_CHECK(tileShapes[ii] <= 256, "buildNdTmaDescriptor: boxDim too large ", tileShapes[ii]);
+  boxDim[ii] = tileShapes[ii];
 }
```
**86-121: Detailed error reporting is helpful for debugging TMA issues.**

The comprehensive error reporting with shapes, strides, tile dimensions, and swizzle type will be valuable for diagnosing TMA descriptor creation failures. However, `errorString` from `cuGetErrorString` is retrieved but never used in the error message.

🔧 Proposed fix to include the error string

```diff
 if (result != CUDA_SUCCESS) {
   char const* errorString;
   cuGetErrorString(result, &errorString);
   std::stringstream ss;
-  ss << "Error: Failed to initialize the TMA descriptor " << result << std::endl;
+  ss << "Error: Failed to initialize the TMA descriptor (code=" << result
+     << ", " << (errorString ? errorString : "unknown error") << ")" << std::endl;
```

include/flashinfer/mamba/common.cuh (2)
**29: Redefinition of `warpSize` shadows CUDA built-in.**

The constant `warpSize` shadows the CUDA built-in `warpSize` variable. While functionally equivalent on current hardware, this could cause confusion or issues if the built-in is referenced elsewhere. Consider renaming to `kWarpSize` or similar to avoid shadowing.

```diff
-constexpr unsigned warpSize = 32;
+constexpr unsigned kWarpSize = 32;
```

And update references accordingly (lines 59, 64).

**43-50: Minor: Remove trailing semicolon after function body.**

The semicolon after the closing brace of `make_zeros()` is unnecessary.

```diff
   return ret;
-};
+}
```

flashinfer/mamba/selective_state_update.py (1)
**136-138: Clarify MTP mode semantics in comment.**

The condition `cache_steps >= 1` means that even `cache_steps=1` triggers MTP mode with 4D tensors. The comment says "more than 1 token" but the condition includes `cache_steps=1`. Consider clarifying:

```diff
-    # Determine if we're in multi-token mode (more than 1 token)
-    is_mtp = cache_steps >= 1
+    # Determine if we're in multi-token mode (cache_steps provided)
+    # Note: cache_steps >= 1 triggers 4D tensor handling even for single token
+    is_mtp = cache_steps >= 1
```

tests/mamba/test_selective_state_update_stp.py (1)
**284-286: Prefix unused variable with underscore.**

Static analysis correctly identifies that `state_ref` is unpacked but unused in this test method.

```diff
 def test_output_correctness(self, inputs, reference_output, use_out_tensor):
     """Test that kernel output matches reference but state is not updated."""
-    y_ref, state_ref = reference_output
+    y_ref, _state_ref = reference_output
```

tests/mamba/test_selective_state_update_mtp.py (2)
**291-293: Prefix unused variable with underscore.**

```diff
 def test_output_correctness(self, inputs, reference_output, use_out_tensor):
     """Test that kernel output matches reference but state is not updated."""
-    y_ref, state_ref = reference_output
+    y_ref, _state_ref = reference_output
```

**443-445: Prefix unused variable with underscore.**

```diff
 def test_output_correctness(self, inputs, reference_output, use_out_tensor):
     """Test that kernel output matches and intermediate states are cached correctly."""
-    y_ref, state_ref, intermediate_states_ref = reference_output
+    y_ref, _state_ref, intermediate_states_ref = reference_output
```

include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh (3)
**82-120: Warp assignment is tightly coupled to `numWarps=4`.**

The loading logic hardcodes warp indices 0-3 for loading x, B, z, and C respectively. If the `numWarps` template parameter is changed from 4, this code will break or leave some data unloaded. Consider adding a `static_assert` to enforce this assumption.

🛠️ Proposed fix to enforce the assumption

```diff
+  static_assert(numWarps == 4, "Loading logic assumes exactly 4 warps");
+
   if (warp == 0) {
     // Load x: gmem -> smem
```
**152-159: Add comment explaining the packed element calculation strategy.**

This is a performance-critical hot path with non-trivial logic for computing `packedSramLdInputElements`. A brief comment explaining why this optimization reduces LSU load would help future reviewers understand the design choice.

As per coding guidelines: "For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers."
**237-248: Potential performance concern: intermediate state writes inside the inner loop.**

Writing to `intermediate_states` on every MTP step (lines 237-248) inside the dimension loop could be a performance bottleneck. Each iteration writes the full `DSTATE` elements per dimension row. Consider whether buffering and batching these writes is feasible for better memory throughput.

csrc/selective_state_update.cu (1)
**312-320: Uninitialized `out_stride_batch` when output is not provided.**

When `out` is not provided, `p.out_stride_batch` is set to 0 (lines 315-316), but `p.output` is also nullptr (line 329). The kernel should handle this case, but it would be cleaner to explicitly document that the kernel checks for nullptr output before using the stride.

include/flashinfer/mamba/selective_state_update.cuh (1)
**72-73: Unusual include placement after namespace close.**

Including kernel headers after the namespace closure is unconventional. While it works because those headers likely open the same namespace, consider moving these includes to the top of the file for consistency with standard C++ practices.
include/flashinfer/mamba/kernel_selective_state_update_stp.cuh (1)
**16-17: Remove commented-out include guards.**

These commented-out lines appear to be leftover from refactoring. They should be removed to keep the code clean.

🧹 Proposed fix

```diff
-// #ifndef FLASHINFER_MAMBA_SELECTIVE_STATE_UPDATE_CUH_
-// #define FLASHINFER_MAMBA_SELECTIVE_STATE_UPDATE_CUH_
 #include <cooperative_groups.h>
```
**87-88: Unused `batch` parameter in kernel signature.**

The `batch` parameter on line 87 is never used inside the kernel (the batch index is obtained via `tl.program_id(axis=1)`). Consider removing it from the signature to avoid confusion.

🧹 Proposed fix

```diff
     # Matrix dimensions
-    batch,
     T,
     nheads,
```

And update the kernel call accordingly.
**286: Remove unused `noqa` directive.**

The `# noqa: SIM113` comment is flagged as unused by ruff. The manual loop counter increment is intentional for the pointer arithmetic pattern, but this specific rule isn't enabled.

🧹 Proposed fix

```diff
-            current_step_idx += 1  # noqa: SIM113
+            current_step_idx += 1
```
/bot run

@flashinfer-bot run
Always define state_batch_idx (either from state_batch_indices or pid_b) to mirror the CUDA kernel's state_batch variable. This allows the intermediate state caching logic to use a simple check of `state_batch_idx != pad_slot_id` without requiring an extra HAS_STATE_BATCH_INDICES guard, matching the CUDA kernel behavior. addresses: flashinfer-ai#2444 (comment)
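The pattern this commit describes can be mirrored in plain Python (illustrative only; the function names are made up, but the variables follow the Triton kernel's):

```python
def resolve_state_batch_idx(pid_b, state_batch_indices=None):
    """Mirror the CUDA kernel's state_batch: fall back to the program id
    when no index tensor is provided."""
    if state_batch_indices is not None:
        return state_batch_indices[pid_b]
    return pid_b

def should_cache(pid_b, pad_slot_id, cache_intermediate, state_batch_indices=None):
    """With state_batch_idx always defined, caching needs only one check:
    a valid (non-padded) slot — no extra HAS_STATE_BATCH_INDICES guard."""
    idx = resolve_state_batch_idx(pid_b, state_batch_indices)
    return cache_intermediate and idx != pad_slot_id
```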
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)

tests/mamba/selective_state_update_triton.py (1)

**68-90: ⚠️ Potential issue | 🟡 Minor — Rename unused `batch` parameter to `_batch` to silence Ruff ARG001 warning.**

The parameter is not used within the kernel function; the kernel operates using `pid_b = tl.program_id(axis=1)` instead. Renaming it explicitly indicates it's intentionally unused.

Minimal fix

```diff
-    batch,
+    _batch,
```
🤖 Fix all issues with AI agents
In `@tests/mamba/selective_state_update_triton.py`:
- Line 284: The trailing noqa directive on the increment line is unnecessary;
remove the inline comment "# noqa: SIM113" from the statement that updates
current_step_idx (the line containing "current_step_idx += 1") so the code
increments the variable without the unused linter suppression.
- Around line 189-276: The cache write can index out-of-bounds when
intermediate_state_indices contains pad_slot_id (e.g., -1); before computing
cache_ptr_base and calling tl.store you must gate the write on a valid cache
index — e.g., check CACHE_INTERMEDIATE_STATES and that cache_idx != pad_slot_id
and cache_idx >= 0 (and optionally state_batch_idx != pad_slot_id) — so modify
the block that computes cache_ptr_base and calls tl.store (references:
intermediate_state_indices_ptr, cache_idx, pad_slot_id,
CACHE_INTERMEDIATE_STATES, state_batch_idx, intermediate_states_buffer,
cache_ptrs, current_step_idx, tl.store) to skip stores for invalid/padded cache
indices.
- Line 403: Replace the lambda assigned to grid with a named function to satisfy
Ruff E731: define a function (e.g., def grid(META):) that takes the META
argument and returns (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads),
keeping the same captured variables (dim, batch, nheads) and use that function
name wherever grid was used.
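A minimal sketch of the named-function replacement the prompt asks for (example values; `cdiv` stands in for `triton.cdiv`):

```python
def cdiv(a, b):
    """Ceiling division, as triton.cdiv computes."""
    return -(-a // b)

dim, batch, nheads = 256, 4, 8  # example captured values

def grid(META):
    """Named replacement for the flagged lambda (satisfies Ruff E731),
    returning the same launch grid as the original."""
    return (cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads)
```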
🧹 Nitpick comments (2)
include/flashinfer/mamba/kernel_selective_state_update_stp.cuh (2)
**641-733: Hoist `dA` computation out of inner loops.**

`A_value` and `dt_value` are invariant per thread; recomputing `__expf` per element adds avoidable work.

♻️ Suggested refactor

```diff
@@
-  // #pragma unroll 1
+  auto const dA = __expf(A_value * dt_value);
+  // #pragma unroll 1
@@
-    auto const dA = __expf(A_value * dt_value);
     auto const dB = B_value * dt_value;
@@
-    auto const dA = __expf(A_value * dt_value);
     auto const dB = B_value * dt_value;
```
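The equivalence behind this hoist can be sanity-checked in NumPy (illustrative; `np.exp` stands in for `__expf`, and the per-element loop mimics the original kernel structure):

```python
import numpy as np

def update_naive(state, A_value, dt_value, dB_row, x_row):
    """Recompute exp(A*dt) for every element, as in the original loop."""
    out = state.copy()
    for i in range(state.shape[0]):
        dA = np.exp(A_value * dt_value)  # invariant, recomputed each iteration
        out[i] = out[i] * dA + dB_row[i] * x_row[i]
    return out

def update_hoisted(state, A_value, dt_value, dB_row, x_row):
    """Hoist the invariant exp out of the loop, as the review suggests."""
    dA = np.exp(A_value * dt_value)  # computed once per thread
    return state * dA + dB_row * x_row
```

Both produce identical results, so the hoist is purely a cost saving.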
**910-915: Document the `numWarps = 4` heuristic for the simple path.**

A short rationale for this tuning choice (and alternatives) will help future perf work. As per coding guidelines, for performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers.

📝 Suggested comment

```diff
-  constexpr int numWarps = 4;
+  constexpr int numWarps = 4;  // tuned for occupancy vs. register pressure; consider retuning for new GPUs
```
Excerpt from `tests/mamba/selective_state_update_triton.py` (diff context):

```python
cache_idx = -1
if CACHE_INTERMEDIATE_STATES:
    if HAS_INTERMEDIATE_STATE_INDICES:
        intermediate_state_idx = tl.load(intermediate_state_indices_ptr + pid_b).to(
            tl.int64
        )
        cache_idx = intermediate_state_idx
    elif HAS_STATE_BATCH_INDICES:
        cache_idx = state_batch_idx
    else:
        cache_idx = pid_b

mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
if HAS_STATE_BATCH_INDICES:
    mask &= state_batch_idx != pad_slot_id
tl.store(state_ptrs, state, mask=mask)
out = tl.sum(state * C[None, :], axis=1)
if HAS_D:
    out += x * D
if HAS_Z:
    out *= z * tl.sigmoid(z)
tl.store(out_ptrs, out, mask=offs_m < dim)

current_step_idx = 0
for _ in range(T):
    if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK:
        if current_step_idx != 0 and cache_idx >= 0:
            parent_ptr = (
                retrieve_parent_token_ptr
                + pid_b * stride_retrieve_parent_token_batch
                + current_step_idx * stride_retrieve_parent_token_T
            )
            parent_step_idx = tl.load(parent_ptr).to(tl.int32)

            if parent_step_idx >= 0 and parent_step_idx < T:
                step_offset = parent_step_idx * nheads * dim * dstate
                cache_ptr = (
                    intermediate_states_buffer
                    + cache_idx * cache_steps * nheads * dim * dstate
                    + step_offset
                    + pid_h * dim * dstate
                    + offs_m[:, None] * dstate
                    + offs_n[None, :]
                )
                state = tl.load(cache_ptr, mask=mask, other=0.0).to(tl.float32)

    x_ptrs = x_ptr + offs_m * stride_x_dim
    dt_ptrs = dt_ptr + offs_m * stride_dt_dim
    B_ptrs = B_ptr + offs_n * stride_B_dstate
    C_ptrs = C_ptr + offs_n * stride_C_dstate
    if HAS_Z:
        z_ptrs = z_ptr + offs_m * stride_z_dim
    out_ptrs = out_ptr + offs_m * stride_out_dim

    x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if not TIE_HDIM:
        dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        if HAS_DT_BIAS:
            dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        if DT_SOFTPLUS:
            dt = softplus(dt)
        A = tl.load(
            A_ptrs,
            mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
            other=0.0,
        ).to(tl.float32)
        dA = tl.exp(A * dt[:, None])
    else:
        dt = tl.load(dt_ptr).to(tl.float32)
        if HAS_DT_BIAS:
            dt += tl.load(dt_bias_ptr).to(tl.float32)
        if DT_SOFTPLUS:
            dt = softplus(dt)
        A = tl.load(A_ptr).to(tl.float32)
        dA = tl.exp(A * dt)  # scalar, not a matrix

    B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    if HAS_D:
        D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if HAS_Z:
        z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)

    dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
    state = state * dA + dB * x[:, None]

    if CACHE_INTERMEDIATE_STATES:
        if state_batch_idx != pad_slot_id:
            cache_ptr_base = (
                intermediate_states_buffer
                + cache_idx * cache_steps * nheads * dim * dstate
                + current_step_idx * nheads * dim * dstate
                + pid_h * dim * dstate
            )
            cache_ptrs = cache_ptr_base + (
                offs_m[:, None] * dstate + offs_n[None, :]
            )
            tl.store(cache_ptrs, state.to(cache_ptrs.dtype.element_ty), mask=mask)
```
**Guard cache writes when `intermediate_state_indices` may contain pad/invalid values.**

If `intermediate_state_indices` uses `pad_slot_id` (e.g., -1) for padded rows, the current write path can compute a negative base and write out-of-bounds. Consider gating on `cache_idx != pad_slot_id` (or assert on the host side) before storing.

🛡️ Suggested guard

```diff
-        if state_batch_idx != pad_slot_id:
+        if state_batch_idx != pad_slot_id and cache_idx != pad_slot_id:
```

🤖 Prompt for AI Agents
+ if state_batch_idx != pad_slot_id and cache_idx != pad_slot_id:🤖 Prompt for AI Agents
In `@tests/mamba/selective_state_update_triton.py` around lines 189 - 276, The
cache write can index out-of-bounds when intermediate_state_indices contains
pad_slot_id (e.g., -1); before computing cache_ptr_base and calling tl.store you
must gate the write on a valid cache index — e.g., check
CACHE_INTERMEDIATE_STATES and that cache_idx != pad_slot_id and cache_idx >= 0
(and optionally state_batch_idx != pad_slot_id) — so modify the block that
computes cache_ptr_base and calls tl.store (references:
intermediate_state_indices_ptr, cache_idx, pad_slot_id,
CACHE_INTERMEDIATE_STATES, state_batch_idx, intermediate_states_buffer,
cache_ptrs, current_step_idx, tl.store) to skip stores for invalid/padded cache
indices.
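A small NumPy sketch of the guarded write (buffer layout simplified to `(slots, steps, elems)`; names illustrative):

```python
import numpy as np

PAD_SLOT_ID = -1

def cache_intermediate_state(buffer, cache_idx, step, state):
    """Guarded write per the suggestion above: skip padded/invalid slots
    instead of letting a negative index wrap to the wrong row."""
    if cache_idx == PAD_SLOT_ID or cache_idx < 0:
        return False  # nothing written for padded rows
    buffer[cache_idx, step] = state
    return True
```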
[FAILED] Pipeline #42916798: 3/20 passed

@flashinfer-bot run
## 📌 Description

This contribution implements the following changes:

- Multi-token prediction for mamba
- Handling state and intermediate state cache that are non-contiguous in batch dimension
- Handling int32 and int64 cache indices
- More checks for dtypes and consistency of dtypes
- Unified template dispatch functions (that hopefully will be replaced with jinja templates in the future)
- NO Eagle3 yet

The new kernel yields 2.78x throughput compared to SGLang's Triton implementation:

<img width="3000" height="1500" alt="mtp_b200_02" src="https://github.com/user-attachments/assets/ff084385-e812-426f-b499-10d1fdf36692" />

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

Sorry that it's so big. Please, see if I'm missing any important tests.

## Summary by CodeRabbit

- **New Features**
  - Multi-token prediction support (cache_steps) and new public parameters: disable_state_update, intermediate_states_buffer, intermediate_state_indices, and optional out tensor; updated docstrings/shapes.
- **Validation & Errors**
  - Centralized, stricter input validation and clearer human-readable error messages for unsupported configurations.
- **Tests**
  - Added extensive single-token and multi-token test suites and test utilities; removed an older test file.
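For reviewers wanting a ground truth, the recurrence the kernels implement can be written as a compact NumPy reference (single head, non-TIE_HDIM path; a sketch for cross-checking, not the production reference implementation):

```python
import numpy as np

def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None):
    """NumPy reference for one selective-state-update step.

    state: (dim, dstate); x, dt: (dim,); A: (dim, dstate); B, C: (dstate,).
    Follows the non-TIE_HDIM path: dA = exp(A * dt), dB = B * dt.
    """
    dA = np.exp(A * dt[:, None])                 # (dim, dstate) decay
    dB = B[None, :] * dt[:, None]                # (dim, dstate) input weight
    new_state = state * dA + dB * x[:, None]     # state update
    out = (new_state * C[None, :]).sum(axis=1)   # readout, (dim,)
    if D is not None:
        out += x * D                             # skip connection
    if z is not None:
        out *= z * (1.0 / (1.0 + np.exp(-z)))    # z * sigmoid(z) gating
    return new_state, out
```

Running it per token (feeding `new_state` back in) reproduces the multi-token loop the MTP kernel performs on device.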