
Improved simple mamba SSU kernel #2962

Open
ishovkun wants to merge 143 commits into flashinfer-ai:main from ishovkun:ssu_mtp_horizontal_async

Conversation

@ishovkun
Contributor

@ishovkun ishovkun commented Apr 2, 2026

📌 Description

This PR upgrades the SSU MTP "simple" kernel with cp.async state prefetching, vectorized loads, and a consolidated state write path, delivering substantial performance improvements, particularly in the latency-bound regime. The async_horizontal kernel was a temporary development vehicle used during implementation and benchmarking; once the optimizations were validated, the simple kernel was replaced with the improved version and the temporary kernel was removed.


Key changes

Async state prefetch (cp.async → double-buffered smem)

  • Replace direct global loads of state_in with cp.async into a double-buffered shared memory staging area (state_in[STATE_STAGES]).
  • First pass is prefetched during the load phase; subsequent passes are pipelined at the end of each pass loop iteration.
  • Extracted into a reusable cp_async_state_cooperative helper function.
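The double-buffered schedule described above can be modeled on the host. This is a minimal C++ sketch of the stage rotation only (names such as `simulate_pipeline` are illustrative, not from the PR): in the real kernel the copies are `cp.async` transfers into `state_in[STATE_STAGES]` shared-memory buffers, and the "wait" is a pipeline commit/wait, not a plain assignment.

```cpp
#include <array>
#include <cstddef>
#include <vector>

// Hypothetical host-side model of the double-buffered prefetch schedule:
// stage (pass % STATE_STAGES) is consumed while the other stage is being filled.
constexpr int STATE_STAGES = 2;

std::vector<int> simulate_pipeline(const std::vector<int>& passes) {
  std::array<int, STATE_STAGES> smem{};  // stand-in for the smem staging buffers
  std::vector<int> consumed;
  if (passes.empty()) return consumed;
  smem[0] = passes[0];  // first pass is prefetched during the load phase
  for (std::size_t p = 0; p < passes.size(); ++p) {
    int cur = static_cast<int>(p % STATE_STAGES);
    if (p + 1 < passes.size())  // issue the next pass's copy before consuming
      smem[(p + 1) % STATE_STAGES] = passes[p + 1];
    consumed.push_back(smem[cur]);  // "wait" on the current stage, then consume it
  }
  return consumed;
}
```

Each pass's data arrives one iteration ahead of its consumption, which is what lets the global-memory latency overlap with compute.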

Vectorized loads

  • Use PackedAligned for vectorized loads of B, C, and x tensors in the load path, improving memory access efficiency when padding is not active.
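The idea behind PackedAligned can be sketched as follows. This is a simplified stand-in (the PR's actual struct derives its alignment via largestPow2Divisor and is loaded on-device as a single wide instruction); `load_pack` is a hypothetical helper used here only to illustrate the access pattern.

```cpp
#include <cstring>

// Minimal sketch of a PackedAligned-style wrapper: an over-aligned N-element
// pack so the compiler can emit one wide (e.g. 128-bit) load per access.
template <typename T, int N = 16 / static_cast<int>(sizeof(T))>  // 16 B = float4 width
struct alignas(N * sizeof(T)) PackedAligned {
  T val[N];
  static constexpr int count = N;
};

// Illustrative load path: copy one pack from a suitably aligned pointer.
// On-device this would compile to a single vectorized global load.
template <typename T, int N>
PackedAligned<T, N> load_pack(const T* ptr) {
  PackedAligned<T, N> p;
  std::memcpy(&p, ptr, sizeof(p));
  return p;
}
```

The alignment requirement is why this path applies only when padding is not active: a padded row can break the 16-byte alignment the vector load assumes.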

State write path consolidation

  • Precompute per-step state_dst_slots[] during the load phase, eliminating redundant index recomputation per pass/dd.
  • Replace three separate state-write branches (intermediate states, per-step dst indices, final state) with a single unified dst_slot != SKIP path.
  • Deduplicate encode-scale computation (was computed up to 3× per step).
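A host-side sketch of the consolidated write path (function and parameter names are illustrative, not the kernel's): the three former destinations collapse into one precomputed slot per step, with a sentinel for steps that write nothing.

```cpp
#include <vector>

// Sentinel meaning "this step writes no state" (value assumed for illustration).
constexpr int SKIP = -1;

// Precompute one destination slot per step, mirroring the three old branches:
//  1. explicit per-step dst indices, 2. intermediate-state cache slots,
//  3. final-state write to the batch's state slot on the last step only.
std::vector<int> precompute_dst_slots(int seq_len, int state_batch,
                                      const std::vector<int>* dst_indices,
                                      int icache_idx, int cache_steps,
                                      bool update_state) {
  std::vector<int> slots(seq_len, SKIP);
  for (int step = 0; step < seq_len; ++step) {
    if (dst_indices)
      slots[step] = (*dst_indices)[step];
    else if (icache_idx >= 0)
      slots[step] = icache_idx * cache_steps + step;
    else if (step == seq_len - 1 && update_state)
      slots[step] = state_batch;  // final state only
  }
  return slots;
}

// The per-step write in the pass loop then reduces to a single branch:
//   if (slots[step] != SKIP) write_state(slots[step], ...);
```

Because the slots are computed once during the load phase, the pass loop no longer recomputes indices per pass/dd, and the encode scale can be computed once per written step.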

OOB handling cleanup

  • Remove upfront shared memory zero-fill padding. Instead, zero OOB padding columns directly in registers at load time. This eliminates an extra __syncthreads() barrier.
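The register-side OOB zeroing can be sketched like this (a host-side model with assumed names; in the kernel the "row" is a per-thread register pack and `valid_cols`/`padded_cols` correspond to DSTATE and its padded width):

```cpp
#include <vector>

// Instead of zero-filling shared memory up front (which costs an extra
// __syncthreads() before any thread may read), each thread zeroes the
// out-of-bounds tail of its own loaded values, in registers, at load time.
std::vector<float> load_row_with_oob_zero(const std::vector<float>& row,
                                          int valid_cols, int padded_cols) {
  std::vector<float> regs(padded_cols);
  for (int c = 0; c < padded_cols; ++c)
    regs[c] = (c < valid_cols) ? row[c] : 0.0f;  // zero OOB lanes locally
  return regs;
}
```

No cross-thread coordination is needed because each thread only touches values it loaded itself.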

Latency hiding

  • Hoist A and D global loads before the barrier to overlap with smem wait.
  • Move dst_slot prefetch earlier to hide LDS latency.
  • Use mul_f32x2 for state decode scale.
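For context, mul_f32x2 performs a paired single-precision multiply on a 64-bit register pair (one instruction for two FP32 lanes). A scalar C++ model of its semantics, assuming a simple two-lane struct:

```cpp
// Two-lane FP32 pair; the real kernel operates on a 64-bit register pair.
struct f32x2 {
  float x, y;
};

// Scalar model of a paired FP32 multiply: one conceptual instruction
// scales two state values by the decode scale at once.
inline f32x2 mul_f32x2(f32x2 a, f32x2 b) { return {a.x * b.x, a.y * b.y}; }
```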

Varlen + scaled-state support

  • Remove guards that blocked the async_horizontal path from running with cu_seqlens or scaled (quantized) state.
  • Refactor smem layout: replace sub-tile-major BANK_CYCLE_ELEMS scheme with a simpler DSTATE_PAD (128-byte aligned) wide tile.

Validation

  • Add mutual exclusion check: intermediate_states_buffer and dst_state_batch_indices cannot both be provided.

Benchmarking

  • Add bench_ssu_sweep_sol.py — SOL (speed-of-light) benchmark script for SSU MTP mode.
  (Benchmark plot: sol_vs_batch_size_mtp6_bf16_NVIDIA_B200)

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • SM100-optimized selective-state-update path with new algorithms: vertical, horizontal, and async-horizontal.
    • Comprehensive benchmarks that measure runtime and "speed-of-light" (SOL) estimates with automatic plotting.
  • Improvements

    • Improved dtype/format conversion, alignment checks, and configurable tensor-map out-of-bounds fill.
    • Input validation for mutually exclusive options, optional stochastic rounding, and support for larger state widths (dstate=96).
  • Tests

    • Expanded tests covering SM100 capability, new algorithms, varlen/pad-slot cases, and varied ngroups.

ishovkun and others added 30 commits January 28, 2026 21:29
Move the test input generation helper from
test_selective_state_update.py
to a new test_utils.py module for reuse across tests. The refactored
function adds support for multi-token mode, intermediate state buffers,
and configurable state cache strides.
struct

- Add helper functions for tensor validation and dtype checks
- Move output tensor to Optional and update checks accordingly
- Add state_stride_batch and update_state fields to
  SelectiveStateUpdateParams
- Refactor kernel param usage for clarity and consistency
Extract dispatchDimDstate and dispatchRatio helpers to simplify
kernel dispatch code and reduce duplication.
- Add kernel and dispatcher support for int32/int64 state_batch_indices
- Update tests to cover int32 indices
- Fix test_utils to use int64 slot_idx by default
  Support int32 and int64 state_batch_indices in selective_state_update

- Remove int32 type check to allow both int32 and int64 index types
- Add stateIndex_t template parameter to kernels for index type dispatch
- Extract kernel implementations to new selective_state_update_stp.cuh
- Remove unused TMA helper functions from create_tensor_map.cuh
- Add comprehensive MTP (multi-token prediction) test suite
checks

- Add common.cuh with kernel dispatch helpers and alignment checks
- Split and rename kernel_selective_state_update_stp.cuh, add
  kernel_selective_state_update_mtp.cuh
- Refactor Python selective_state_update to clarify dimension handling
- Add test for dtype mismatch between state_batch_indices and
  intermediate_state_indices
- Update test_utils to generate int64 intermediate_slot_idx by default
- Remove redundant input type check in
  validate_intermediate_state_indices
Always define state_batch_idx (either from state_batch_indices or pid_b)
to mirror the CUDA kernel's state_batch variable. This allows the
intermediate state caching logic to use a simple check of
`state_batch_idx != pad_slot_id` without requiring an extra
HAS_STATE_BATCH_INDICES guard, matching the CUDA kernel behavior.

addresses:
flashinfer-ai#2444 (comment)
- Add test_chunk_scan_combined.py comparing CUTLASS CuTe DSL
  Blackwell implementation against Triton reference
- Move selective_state_update_triton.py into triton_reference/ package
- Add Triton reference implementations for Mamba2 SSD kernels:
  - ssd_combined.py (main entry point)
  - ssd_chunk_scan.py, ssd_chunk_state.py, ssd_state_passing.py
  - ssd_bmm.py, softplus.py (utilities)
# Conflicts:
#	tests/mamba/selective_state_update_triton.py
#	tests/mamba/test_selective_state_update_mtp.py
#	tests/mamba/test_selective_state_update_stp.py
- Move dtype dispatch and instantiation to codegen via Jinja templates
- Generate config and instantiation files per dtype combination
- Update Python JIT logic to build/load kernels for specific dtypes
- Remove C++ dtype dispatch helpers from selective_state_update.cu
- Update kernel launcher comment for clarity on consumer warps
Support explicit algorithm choice (auto/simple/vertical/horizontal)
for selective_state_update and MTP kernels. Update kernel signatures,
Python bindings, and JIT module generation to include algorithm and
compile-time shape parameters (dim, dstate, ntokens_mtp). Refactor
dispatch logic for SM90/SM100 architectures.
… .cu files

The config.inc defines DIM, DSTATE, NTOKENS_MTP as constexpr globals
that the header's function templates rely on. With the previous order
(header first, config second), NVCC's lenient two-phase lookup masked
the issue, but a fresh JIT compilation after cache clearing would fail
with 'identifier DIM/DSTATE is undefined' errors.

clang-format is disabled for these includes because it reorders them
alphabetically, which breaks compilation.

AI-assisted
Assign each of the 4 consumer warps a single tensor to load (x, B, z, C)
instead of warps 0 and 1 each loading two tensors sequentially. This
maximizes memory-level parallelism during the load phase.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Replace cartesian-product fixture parametrization with explicit rows:
one base case plus one row per parameter deviation. Cuts the test count
from ~200+ (MTP) and ~144+ (STP) down to ~26 and ~15 respectively.

AI-assisted

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Parametrize tests to run with all supported algorithms
- Update test logic to pass algorithm argument through
- Improve test output messages to include algorithm name
- Add utility to detect available algorithms based on GPU arch
@coderabbitai
Contributor

coderabbitai bot commented Apr 2, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough


Replaces the monolithic MTP SSU with three modular MTP kernels (simple, horizontal, vertical), adds SM100-specialized JIT/module support and dispatch, introduces MTP utilities and new benchmarks, updates conversion/alignment helpers and STP/MTP dispatch/validation, and expands tests to cover algorithms, dstate, and pad/varlen cases.

Changes

Cohort / File(s) Summary
Benchmarks
benchmarks/bench_ssu_sweep_mtp.py, benchmarks/bench_ssu_sweep_sol.py
Added two CLI benchmark scripts: runtime sweep vs Triton and SOL (% speed-of-light) analysis, with philox/stochastic options, NCU mode, plotting and CSV/DataFrame outputs.
New MTP Kernels
include/flashinfer/mamba/kernel_selective_state_update_mtp_simple.cuh, include/flashinfer/mamba/kernel_selective_state_update_mtp_horizontal.cuh, include/flashinfer/mamba/kernel_selective_state_update_mtp_vertical.cuh
Added three MTP kernel implementations (simple, horizontal, vertical) with distinct shared-memory layouts, warp-role coordination, cp.async/TMA usage, and optional stochastic rounding / intermediate-state writes.
MTP Dispatch & Common Utilities
include/flashinfer/mamba/invoke_selective_state_update_mtp.cuh, include/flashinfer/mamba/ssu_mtp_common.cuh
New host-side invokeSelectiveStateUpdateMTP dispatch, algorithm selection logic (including auto/coercions and async_horizontal→simple), alignment/validation, CTAS dispatch helper, parity barriers, xor-swizzle, SIMD and stochastic-rounding/block-scale helpers.
Removed Legacy MTP
include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh
Removed previous monolithic MTP kernel and its host wrapper; replaced by new modular kernels plus invoke header.
JIT / Module Gen
flashinfer/jit/core.py, flashinfer/jit/mamba/__init__.py, flashinfer/jit/mamba/selective_state_update.py
Exported new SM100 generator gen_selective_state_update_sm100_module, added SM100 NVCC defines in SM100 path, and removed forced -lineinfo from generic generator.
Python API / Dispatch
flashinfer/mamba/selective_state_update.py, include/flashinfer/mamba/selective_state_update.cuh, include/flashinfer/mamba/kernel_selective_state_update_stp.cuh
Dispatch now selects SM100 module for sm_major>=10; maps async_horizontal to simple; switched MTP include to new invoke header; refined alignment checks and expanded dispatchRatio specializations.
Conversion & Alignment Helpers
include/flashinfer/mamba/common.cuh, include/flashinfer/mamba/conversion.cuh, include/flashinfer/mamba/create_tensor_map.cuh
Added largestPow2Divisor and tightened PackedAligned alignment + static_assert; added toFloat2 overloads (half/bf16/int16/pointer forms); buildNdDescriptor gains optional oobFill parameter.
STP tweaks
include/flashinfer/mamba/kernel_selective_state_update_stp.cuh
Adjusted alignment validation to use alignof(load_state_t) and expanded dispatchRatio integer sequence.
Tests
tests/mamba/test_selective_state_update_mtp.py, tests/mamba/test_selective_state_update_stp.py, tests/mamba/test_selective_state_update_varlen.py
Expanded coverage: SM100 gating, algorithm parameterization over simple/vertical/horizontal, added dstate=96 and fp16 cases, new pad-slot tests, new ngroups parameterization, and updated varlen tests to accept algorithm.
Minor comment
flashinfer/jit/core.py
Clarified a comment to “useful for ncu source correlation” (no behavioral change).
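The largestPow2Divisor helper mentioned in the conversion/alignment cohort above can be sketched as a lowest-set-bit computation (the actual signature in common.cuh may differ; this assumes a positive integer argument):

```cpp
// Largest power of two dividing n, i.e. n's lowest set bit.
// Used to derive the strongest alignment guarantee for an N * sizeof(T)
// pack, e.g. for PackedAligned's alignas() specifier. Assumes n > 0.
constexpr int largestPow2Divisor(int n) { return n & -n; }
```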

Sequence Diagram(s)

sequenceDiagram
    participant Host as Host/Python
    participant JIT as JIT/ModuleGen
    participant Dispatch as invokeSelectiveStateUpdateMTP
    participant Kernel as Kernel (simple/vertical/horizontal)
    participant GPU as GPU Memory

    Host->>JIT: Ensure module built/loaded (select SM major)
    Host->>Dispatch: Call selective_state_update(params, algorithm)
    Dispatch->>Dispatch: Validate inputs, map algorithm (async_horizontal→simple), choose auto
    Dispatch->>Dispatch: Build TMA descriptors & alignment checks
    Dispatch->>Kernel: Launch chosen kernel with params & descriptors
    Kernel->>GPU: Issue cp.async / TMA loads, compute recurrence, write intermediate/final state & outputs
    Kernel-->>Dispatch: Kernel completes
    Dispatch-->>Host: Return status/result

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

  • #2301: Overlapping changes to selective_state_update Mamba kernels, JIT/module bindings, and tests — high code-level overlap.
  • #2387: SM100-specific selective_state_update JIT/module and kernel additions — related SM100 work.
  • #2645: MTP/STP refactor adding per-block state_scale and Philox stochastic rounding — directly related kernel and template changes.

Suggested labels

run-ci, ready

Suggested reviewers

  • jiahanc
  • kahyunnam
  • IwakuraRein
  • jimmyzho
  • cyx-6
  • bkryu
  • nvmbreughe
  • yzh119

Poem

🐰 Hopping through kernels, three in a row,
Simple, vertical, horizontal—watch them go.
Benchmarks blink and JIT paints the light,
SM100 hums through day and night.
A rabbit cheers: updates swift and bright!

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): docstring coverage is 79.17%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)

  • Title check (✅ Passed): the title 'Improved simple mamba SSU kernel' is specific and highlights the main optimization target, accurately reflecting the core focus on improving the simple kernel with cp.async prefetching and vectorized loads.
  • Description check (✅ Passed): the PR description is comprehensive and includes all required sections from the template: a detailed description of changes with context, related changes summary, pre-commit checklist completion status (all checked), and test completion status with detailed reviewer notes.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for NVIDIA Blackwell (SM100) architectures in the Mamba selective state update implementation, adding specialized vertical and horizontal MTP kernels that leverage TMA and f32x2 SIMD instructions. The update also includes a refactored SM80+ simple MTP kernel using cp.async, improved alignment logic, and new benchmarking scripts for performance and Speed-of-Light analysis. Feedback was provided regarding an opportunity to improve the maintainability of the CTAS_PER_HEAD dispatch logic by replacing the series of if constexpr blocks with a more structured mechanism.


// Dispatch to the largest instantiated CTAS_PER_HEAD <= ctas_per_head.
// Use if constexpr to avoid compiling invalid template instantiations.
if constexpr (DIM / 4 >= kRowsPerPass) {

medium

The dispatch logic for CTAS_PER_HEAD uses a series of if constexpr blocks that return early. While functional, this can be simplified using a constexpr array or a more structured dispatch mechanism to improve maintainability and readability as more configurations are added.
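A sketch of the structured alternative the reviewer suggests, with hypothetical names and candidate values (the real instantiated CTAS_PER_HEAD set lives in the dispatch code): select the largest instantiated value not exceeding the request from a descending table, instead of a chain of early-returning if constexpr blocks.

```cpp
// Descending table of instantiated CTAS_PER_HEAD configurations (values assumed).
constexpr int kCtasCandidates[] = {8, 4, 2, 1};

// Pick the largest instantiated CTAS_PER_HEAD <= requested.
// Adding a new configuration then means editing one table entry,
// not inserting another if constexpr branch.
constexpr int selectCtasPerHead(int requested) {
  for (int c : kCtasCandidates)
    if (c <= requested) return c;
  return 1;  // smallest configuration as the fallback
}
```

In the template-dispatch setting this scalar selection would still need to map onto compile-time instantiations (e.g. via a fold over an integer sequence), but the candidate set becomes data rather than control flow.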


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/mamba/test_selective_state_update_varlen.py (1)

258-271: ⚠️ Potential issue | 🟡 Minor

Keep auto-dispatch coverage in the varlen suite.

Restricting these cu_seqlens cases to "simple" means they no longer exercise the public auto path, so a regression in the cu_seqlens -> simple fallback would slip through. Please keep "auto" here as well, and mirror that in the other varlen parametrizations below.

💡 Suggested parametrization
-    @pytest.mark.parametrize("algorithm", ["simple"])
+    @pytest.mark.parametrize("algorithm", ["auto", "simple"])

Based on learnings: In flashinfer-ai/flashinfer MTP paths, the vertical and horizontal selective_state_update kernels do not support varlen (cu_seqlens), and the invoker falls back to SSUAlgorithm::kSimple in auto mode when params.cu_seqlens is set.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/mamba/test_selective_state_update_varlen.py` around lines 258 - 271,
The test currently parametrizes only "simple" for the algorithm in
test_varlen_uniform which prevents exercising the public auto-dispatch path;
update the parametrization to include "auto" (i.e., ["simple", "auto"]) for
test_varlen_uniform and mirror that change in the other varlen-related
parametrizations in this file so the auto fallback (cu_seqlens -> simple) is
covered; locate the algorithm param in the test_varlen_uniform function and the
other varlen test parameter blocks and add "auto" to the algorithm list.
🧹 Nitpick comments (4)
include/flashinfer/mamba/ssu_mtp_common.cuh (1)

21-23: Make this header self-contained.

This file uses uint32_t/int64_t, fabsf/fmaxf, and std::numeric_limits but only includes <cuda/barrier> plus conversion.cuh. Please include the standard headers directly so this doesn't depend on include order.

♻️ Suggested include set
 #include <cuda/barrier>
 
+#include <cmath>
+#include <cstdint>
+#include <limits>
+
 #include "conversion.cuh"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/mamba/ssu_mtp_common.cuh` around lines 21 - 23, The header
ssu_mtp_common.cuh is not self-contained: it uses uint32_t and int64_t,
fabsf/fmaxf, and std::numeric_limits but only includes <cuda/barrier> and
"conversion.cuh"; add direct standard includes (e.g., <cstdint>, <cmath>, and
<limits>) at the top of ssu_mtp_common.cuh so symbols used by functions/types in
this file (uint32_t, int64_t, fabsf, fmaxf, std::numeric_limits) are defined
regardless of include order.
include/flashinfer/mamba/common.cuh (1)

41-46: Guard zero-width PackedAligned instantiations.

getVectorLoadSizeForFullUtilization() can bottom out at 0 for small DSTATE values, so PackedAligned<T, 0> currently fails wherever the compiler happens to complain first. A local static_assert here would make that failure explicit.

♻️ Suggested guard
 template <typename T, int N = sizeof(float4) / sizeof(T)>
 struct alignas(largestPow2Divisor(N * sizeof(T))) PackedAligned {
+  static_assert(N > 0, "PackedAligned requires at least one element");
   T val[N];
 
   static constexpr int count = N;
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/mamba/common.cuh` around lines 41 - 46, Add a compile-time
guard to prevent zero-width instantiations of the PackedAligned template: inside
the template struct PackedAligned (templated on T and N) add a static_assert
that N > 0 with a clear message (e.g. "PackedAligned instantiated with N == 0;
ensure getVectorLoadSizeForFullUtilization() returns >0") so that attempts to
instantiate PackedAligned<T,0> fail with an explicit error rather than a
confusing template/ABI failure; refer to PackedAligned, its template parameter
N, and largestPow2Divisor/getVectorLoadSizeForFullUtilization when making this
change.
tests/mamba/test_selective_state_update_mtp.py (2)

56-66: Please put auto back into the MTP algorithm matrix.

This PR changes the dispatcher, but the fixture now exercises only explicit kernels. Without an auto case, the batch-based routing and fallback logic can regress without any MTP test failing.

Proposed test matrix tweak
     @pytest.fixture(
         autouse=True,
         params=[
+            "auto",
             "simple",
             pytest.param("vertical", marks=_requires_sm100),
             pytest.param("horizontal", marks=_requires_sm100),
         ],
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/mamba/test_selective_state_update_mtp.py` around lines 56 - 66, The
test fixture _algorithm currently only parametrizes "simple", "vertical", and
"horizontal" causing MTP to miss the 'auto' routing path; modify the params list
in the _algorithm pytest.fixture to include "auto" (e.g., add "auto" into the
params array alongside "simple", "vertical", and "horizontal") so the test
matrix exercises the batch-based routing/fallback logic that relies on the
automatic kernel selection.

629-635: The new MTP ngroups sweep still misses three dispatch ratios.

With nheads=64, this matrix only covers HEADS_PER_GROUP values 64, 32, 16, and 8. The runtime also instantiates 4, 2, and 1, so half of the vertical/horizontal specializations are still untested.

Proposed coverage extension
     _NGROUPS_PARAMS = (
         # (batch, nheads, dim, dstate, cache_steps, state_dtype,    weight_dtype,   use_out_tensor, ngroups)
         (  64,    64,     64,  128,    4,           torch.bfloat16, torch.float32,  True,           1),
         (  64,    64,     64,  128,    4,           torch.bfloat16, torch.float32,  True,           2),
         (  64,    64,     64,  128,    4,           torch.bfloat16, torch.float32,  True,           4),
         (  64,    64,     64,  128,    4,           torch.bfloat16, torch.float32,  True,           8),
+        (  64,    64,     64,  128,    4,           torch.bfloat16, torch.float32,  True,          16),
+        (  64,    64,     64,  128,    4,           torch.bfloat16, torch.float32,  True,          32),
+        (  64,    64,     64,  128,    4,           torch.bfloat16, torch.float32,  True,          64),
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/mamba/test_selective_state_update_mtp.py` around lines 629 - 635, The
_NGROUPS_PARAMS sweep only includes ngroups = 1,2,4,8 (HEADS_PER_GROUP =
64,32,16,8) and misses the cases where the runtime instantiates ngroups =
16,32,64 (HEADS_PER_GROUP = 4,2,1); update the _NGROUPS_PARAMS tuple to also
include entries for ngroups=16, ngroups=32, and ngroups=64 using the same
parameter pattern (batch=64, nheads=64, dim=64, dstate=128, cache_steps=4,
state_dtype=torch.bfloat16, weight_dtype=torch.float32, use_out_tensor=True) so
the test covers those dispatch ratios.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/bench_ssu_sweep_mtp.py`:
- Around line 351-359: The parsing currently allows bf16-philox-* and
f32-philox-* even though the MTP path only supports stochastic rounding for fp16
state; update parse_dtype_spec to reject any philox-specs whose base is not
"f16" by either restricting the regex to only match "f16-philox-(\d+)" or by
checking m.group(1) and raising a ValueError if base != "f16"; also update the
error message (and the branch that checks _dtype_name_to_torch) to clearly state
that only "f16" or "f16-philox-<rounds>" are supported for MTP/stochastic
rounding so callers get a clear failure instead of a later runtime error.

In `@benchmarks/bench_ssu_sweep_sol.py`:
- Around line 265-269: The --ncu branch directly calls kernel_fn(**kwargs) so
unsupported kernels raise RuntimeError and abort; wrap the kernel_fn call (and
the subsequent torch.cuda.synchronize()) in a try/except that catches
RuntimeError, prints or logs an "unsupported kernel" message (matching the timed
path behavior), and returns 0.0 on failure, otherwise proceed to synchronize,
print the "Single invocation done (ncu mode)" message and return success; update
the ncu branch around kernel_fn and torch.cuda.synchronize() accordingly.

In `@include/flashinfer/mamba/invoke_selective_state_update_mtp.cuh`:
- Around line 151-163: The initial check uses kHorizontalDimAlignment
(NUM_COMPUTE_WARPS_PER_GROUP * ROWS_PER_WARP) allowing DIM values that are
multiples of 16 but the kernel template requires DIM % TMA_STATE_ROWS == 0;
update the validation in the horizontal launcher (the FLASHINFER_CHECK that
currently tests DIM % kHorizontalDimAlignment == 0) to instead verify DIM %
TMA_STATE_ROWS == 0 (or add an additional check against TMA_STATE_ROWS) so
mismatched DIM (e.g. 48, 80) are rejected before kernel launch; reference
kHorizontalDimAlignment, TMA_STATE_ROWS, and the existing FLASHINFER_CHECK lines
to locate where to change the condition.

In `@include/flashinfer/mamba/kernel_selective_state_update_mtp_horizontal.cuh`:
- Around line 108-145: The IS_PAD branch currently skips the initial
cde::cp_async_bulk_tensor_4d_global_to_shared calls for sram.B, sram.C and
sram.x but the compute/epilogue path still reads sram.B/C/x, causing use of
uninitialized shared memory; fix by ensuring the TMA transactions for B/C/x are
still issued when IS_PAD (or alternatively skip the compute/epilogue for padded
rows). Concretely, in the lane==0 / !IS_PAD block replace the conditional with
logic that issues cde::cp_async_bulk_tensor_4d_global_to_shared for sram.B,
sram.C and the x loop even when IS_PAD is true (but you can supply zeroed host
buffers if needed), and make sure the first barrier_arrive_tx call for
sram.bar_state_in_full[slot] uses bytes = bytesBCX + bytesChunk for the (h==0 &&
tl==0) case so the producer byte count matches the B/C/X transfers;
alternatively, gate the compute/epilogue reads of sram.B/sram.C/sram.x to skip
them entirely when IS_PAD.

In `@include/flashinfer/mamba/kernel_selective_state_update_mtp_simple.cuh`:
- Around line 229-259: The unified dst-slot logic changed semantics: ensure
writes honor params.update_state and preserve final-state writes to params.state
when callers expect to cache intermediates. In the dst_state_batch_indices
branch (dst_state_batch_indices) skip writes when params.update_state is false
(set sram.state_dst_slots[step]=SKIP) so per-step indices don't cause writes if
update_state is disabled; in the params.intermediate_states branch, special-case
the last step (step == seq_len-1) to route the destination to state_batch when
params.update_state is true (i.e., set sram.state_dst_slots[step]=state_batch)
instead of the icache slot, otherwise use icache_idx * params.cache_steps +
step; apply the same fix at the other occurrences noted (lines ~312-322 and
~436-474) referencing sram.state_dst_slots, intermediate_state_indices,
dst_state_batch_indices, params.update_state, params.intermediate_states,
state_batch, and SKIP.

In `@include/flashinfer/mamba/kernel_selective_state_update_mtp_vertical.cuh`:
- Around line 192-198: The code silently drops tail state columns when DSTATE
isn’t divisible by warpSize because stateValuesPerThread = DSTATE / warpSize
uses integer division; add an explicit compile-time guard to prevent this misuse
by inserting a static_assert that DSTATE % warpSize == 0 (with a helpful
message) near the current declarations (around
stateValuesPerThread/packed_state_t) so callers must satisfy the warp-aligned
state width or refactor later to handle remainders.

---

Outside diff comments:
In `@tests/mamba/test_selective_state_update_varlen.py`:
- Around line 258-271: The test currently parametrizes only "simple" for the
algorithm in test_varlen_uniform which prevents exercising the public
auto-dispatch path; update the parametrization to include "auto" (i.e.,
["simple", "auto"]) for test_varlen_uniform and mirror that change in the other
varlen-related parametrizations in this file so the auto fallback (cu_seqlens ->
simple) is covered; locate the algorithm param in the test_varlen_uniform
function and the other varlen test parameter blocks and add "auto" to the
algorithm list.

---

Nitpick comments:
In `@include/flashinfer/mamba/common.cuh`:
- Around line 41-46: Add a compile-time guard to prevent zero-width
instantiations of the PackedAligned template: inside the template struct
PackedAligned (templated on T and N) add a static_assert that N > 0 with a clear
message (e.g. "PackedAligned instantiated with N == 0; ensure
getVectorLoadSizeForFullUtilization() returns >0") so that attempts to
instantiate PackedAligned<T,0> fail with an explicit error rather than a
confusing template/ABI failure; refer to PackedAligned, its template parameter
N, and largestPow2Divisor/getVectorLoadSizeForFullUtilization when making this
change.

In `@include/flashinfer/mamba/ssu_mtp_common.cuh`:
- Around line 21-23: The header ssu_mtp_common.cuh is not self-contained: it
uses uint32_t and int64_t, fabsf/fmaxf, and std::numeric_limits but only
includes <cuda/barrier> and "conversion.cuh"; add direct standard includes
(e.g., <cstdint>, <cmath>, and <limits>) at the top of ssu_mtp_common.cuh so
symbols used by functions/types in this file (uint32_t, int64_t, fabsf, fmaxf,
std::numeric_limits) are defined regardless of include order.

In `@tests/mamba/test_selective_state_update_mtp.py`:
- Around line 56-66: The test fixture _algorithm currently only parametrizes
"simple", "vertical", and "horizontal" causing MTP to miss the 'auto' routing
path; modify the params list in the _algorithm pytest.fixture to include "auto"
(e.g., add "auto" into the params array alongside "simple", "vertical", and
"horizontal") so the test matrix exercises the batch-based routing/fallback
logic that relies on the automatic kernel selection.
- Around line 629-635: The _NGROUPS_PARAMS sweep only includes ngroups = 1,2,4,8
(HEADS_PER_GROUP = 64,32,16,8) and misses the cases where the runtime
instantiates ngroups = 16,32,64 (HEADS_PER_GROUP = 4,2,1); update the
_NGROUPS_PARAMS tuple to also include entries for ngroups=16, ngroups=32, and
ngroups=64 using the same parameter pattern (batch=64, nheads=64, dim=64,
dstate=128, cache_steps=4, state_dtype=torch.bfloat16,
weight_dtype=torch.float32, use_out_tensor=True) so the test covers those
dispatch ratios.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: b38ee3bd-b368-46ee-a850-939f416f74bf

📥 Commits

Reviewing files that changed from the base of the PR and between 637209a and 2fbb60a.

📒 Files selected for processing (20)
  • benchmarks/bench_ssu_sweep_mtp.py
  • benchmarks/bench_ssu_sweep_sol.py
  • flashinfer/jit/core.py
  • flashinfer/jit/mamba/__init__.py
  • flashinfer/jit/mamba/selective_state_update.py
  • flashinfer/mamba/selective_state_update.py
  • include/flashinfer/mamba/common.cuh
  • include/flashinfer/mamba/conversion.cuh
  • include/flashinfer/mamba/create_tensor_map.cuh
  • include/flashinfer/mamba/invoke_selective_state_update_mtp.cuh
  • include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh
  • include/flashinfer/mamba/kernel_selective_state_update_mtp_horizontal.cuh
  • include/flashinfer/mamba/kernel_selective_state_update_mtp_simple.cuh
  • include/flashinfer/mamba/kernel_selective_state_update_mtp_vertical.cuh
  • include/flashinfer/mamba/kernel_selective_state_update_stp.cuh
  • include/flashinfer/mamba/selective_state_update.cuh
  • include/flashinfer/mamba/ssu_mtp_common.cuh
  • tests/mamba/test_selective_state_update_mtp.py
  • tests/mamba/test_selective_state_update_stp.py
  • tests/mamba/test_selective_state_update_varlen.py
💤 Files with no reviewable changes (1)
  • include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh

Comment on lines +351 to +359
    m = re.match(r"^(bf16|f16|f32)-philox-(\d+)$", spec)
    if m:
        base, rounds = m.group(1), int(m.group(2))
        return spec, _dtype_name_to_torch[base], rounds
    if spec not in _dtype_name_to_torch:
        raise ValueError(
            f"Unknown dtype spec '{spec}'. "
            "Expected bf16, f16, f32, or <dtype>-philox-<rounds>"
        )
Contributor

@coderabbitai coderabbitai bot Apr 2, 2026

⚠️ Potential issue | 🟠 Major

Reject non-f16-philox-* specs here.

parse_dtype_spec() currently accepts bf16-philox-* and f32-philox-*, but the MTP path only supports stochastic rounding for fp16 state. Those CLI values will fail instead of being benchmarked.

🔧 Proposed fix
-    m = re.match(r"^(bf16|f16|f32)-philox-(\d+)$", spec)
+    m = re.match(r"^(f16)-philox-(\d+)$", spec)
-            "Expected bf16, f16, f32, or <dtype>-philox-<rounds>"
+            "Expected bf16, f16, f32, or f16-philox-<rounds>"
Based on learnings, the `static_assert` in `include/flashinfer/mamba/invoke_selective_state_update_mtp.cuh` intentionally restricts stochastic rounding to fp16 state (`std::is_same_v`).
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
    m = re.match(r"^(bf16|f16|f32)-philox-(\d+)$", spec)
    if m:
        base, rounds = m.group(1), int(m.group(2))
        return spec, _dtype_name_to_torch[base], rounds
    if spec not in _dtype_name_to_torch:
        raise ValueError(
            f"Unknown dtype spec '{spec}'. "
            "Expected bf16, f16, f32, or <dtype>-philox-<rounds>"
        )
    m = re.match(r"^(f16)-philox-(\d+)$", spec)
    if m:
        base, rounds = m.group(1), int(m.group(2))
        return spec, _dtype_name_to_torch[base], rounds
    if spec not in _dtype_name_to_torch:
        raise ValueError(
            f"Unknown dtype spec '{spec}'. "
            "Expected bf16, f16, f32, or f16-philox-<rounds>"
        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_ssu_sweep_mtp.py` around lines 351 - 359, The parsing
currently allows bf16-philox-* and f32-philox-* even though the MTP path only
supports stochastic rounding for fp16 state; update parse_dtype_spec to reject
any philox-specs whose base is not "f16" by either restricting the regex to only
match "f16-philox-(\d+)" or by checking m.group(1) and raising a ValueError if
base != "f16"; also update the error message (and the branch that checks
_dtype_name_to_torch) to clearly state that only "f16" or "f16-philox-<rounds>"
are supported for MTP/stochastic rounding so callers get a clear failure instead
of a later runtime error.
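A dependency-free sketch of the suggested fix follows: philox specs are accepted only with an `f16` base, matching the fp16-only `static_assert` the review cites. The string values in `_dtype_name_to_torch` and the return shape for plain specs are assumptions; the real benchmark maps to torch dtypes.

```python
import re

# Stand-in mapping; the real script maps to torch dtypes.
_dtype_name_to_torch = {"bf16": "bfloat16", "f16": "float16", "f32": "float32"}


def parse_dtype_spec(spec):
    # Only f16 may carry a philox-<rounds> suffix: the MTP path supports
    # stochastic rounding solely for fp16 state.
    m = re.match(r"^(f16)-philox-(\d+)$", spec)
    if m:
        base, rounds = m.group(1), int(m.group(2))
        return spec, _dtype_name_to_torch[base], rounds
    if spec not in _dtype_name_to_torch:
        raise ValueError(
            f"Unknown dtype spec '{spec}'. "
            "Expected bf16, f16, f32, or f16-philox-<rounds>"
        )
    # Plain spec: no stochastic-rounding rounds (return shape assumed).
    return spec, _dtype_name_to_torch[spec], None
```

With this regex, `bf16-philox-4` and `f32-philox-4` fall through to the membership check and raise a clear `ValueError` at parse time instead of failing later in the kernel launch.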

Contributor Author

This is a benchmarking script not a prod file. WONTFIX

Comment on lines +265 to +269
    if ncu:
        kernel_fn(**kwargs)
        torch.cuda.synchronize()
        print("    Single invocation done (ncu mode)")
        return 0.0
Contributor

@coderabbitai coderabbitai bot Apr 2, 2026

⚠️ Potential issue | 🟡 Minor

--ncu mode currently aborts on unsupported kernels.

The timed path catches RuntimeError and keeps the sweep going, but the single-invocation branch calls the kernel directly. On non-SM100 hardware, the first vertical/horizontal launch will terminate the script instead of being reported as unsupported.

Suggested fix
     if ncu:
-        kernel_fn(**kwargs)
-        torch.cuda.synchronize()
+        try:
+            kernel_fn(**kwargs)
+            torch.cuda.synchronize()
+        except RuntimeError as e:
+            print(f"    Kernel failed: {e}")
+            return float("inf")
         print("    Single invocation done (ncu mode)")
         return 0.0
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_ssu_sweep_sol.py` around lines 265 - 269, The --ncu branch
directly calls kernel_fn(**kwargs) so unsupported kernels raise RuntimeError and
abort; wrap the kernel_fn call (and the subsequent torch.cuda.synchronize()) in
a try/except that catches RuntimeError, prints or logs an "unsupported kernel"
message (matching the timed path behavior), and returns 0.0 on failure,
otherwise proceed to synchronize, print the "Single invocation done (ncu mode)"
message and return success; update the ncu branch around kernel_fn and
torch.cuda.synchronize() accordingly.
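The guarded branch can be sketched in isolation. This follows the suggested diff above (returning `float("inf")` on failure); the helper name `run_once_ncu` is hypothetical, and `torch.cuda.synchronize()` is elided so the sketch runs without a GPU.

```python
def run_once_ncu(kernel_fn, **kwargs):
    # Single invocation for ncu profiling: report unsupported kernels
    # instead of letting the RuntimeError abort the whole sweep.
    try:
        kernel_fn(**kwargs)
        # torch.cuda.synchronize() would go here in the real benchmark.
    except RuntimeError as e:
        print(f"    Kernel failed: {e}")
        return float("inf")
    print("    Single invocation done (ncu mode)")
    return 0.0
```

On non-SM100 hardware the first vertical/horizontal launch then yields an `inf` entry in the results instead of terminating the script.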

Contributor Author

leave me alone this is a benchmarking script

Contributor

@coderabbitai coderabbitai bot left a comment

🧹 Nitpick comments (2)
include/flashinfer/mamba/invoke_selective_state_update_mtp.cuh (2)

6-8: Consider removing unused <iostream> include.

The <iostream> header is included but there's no visible usage of stream objects (std::cout, std::cerr, etc.) in this file. If FLASHINFER_CHECK internally handles stream output, that should be encapsulated in its own header.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/mamba/invoke_selective_state_update_mtp.cuh` around lines
6 - 8, Remove the unused `#include <iostream>` from
invoke_selective_state_update_mtp.cuh to avoid unnecessary dependency; keep
<algorithm> and <type_traits>, and if the FLASHINFER_CHECK macro (or any other
macro/function used in this file) requires stream facilities, include its
defining header (or <iosfwd> if only forward declarations are needed) instead of
<iostream> so that the file only depends on the headers that actually provide
the symbols it uses.

265-268: Use std::max for consistency.

Line 267 uses unqualified max while also using std::clamp. For consistency and clarity in host code, prefer std::max.

♻️ Suggested fix
-    int const ctas_per_head = std::clamp(target_ctas / max(total_tiles, 1), 1, kMaxCtas);
+    int const ctas_per_head = std::clamp(target_ctas / std::max(total_tiles, 1), 1, kMaxCtas);
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/mamba/invoke_selective_state_update_mtp.cuh` around lines
265 - 268, The expression computing ctas_per_head uses unqualified max; replace
the call to max(target_ctas / max(total_tiles, 1), 1, kMaxCtas) by using
std::max for the inner max (i.e., std::max(total_tiles, 1) or
std::max(target_ctas / std::max(total_tiles, 1), 1) as appropriate) so it
matches the std::clamp usage and avoids ADL surprises; update the expression
computing ctas_per_head (symbols: kMaxCtas, target_ctas, ctas_per_head,
total_tiles) and ensure <algorithm> is included where this translation unit is
compiled.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@include/flashinfer/mamba/invoke_selective_state_update_mtp.cuh`:
- Around line 6-8: Remove the unused `#include <iostream>` from
invoke_selective_state_update_mtp.cuh to avoid unnecessary dependency; keep
<algorithm> and <type_traits>, and if the FLASHINFER_CHECK macro (or any other
macro/function used in this file) requires stream facilities, include its
defining header (or <iosfwd> if only forward declarations are needed) instead of
<iostream> so that the file only depends on the headers that actually provide
the symbols it uses.
- Around line 265-268: The expression computing ctas_per_head uses unqualified
max; replace the call to max(target_ctas / max(total_tiles, 1), 1, kMaxCtas) by
using std::max for the inner max (i.e., std::max(total_tiles, 1) or
std::max(target_ctas / std::max(total_tiles, 1), 1) as appropriate) so it
matches the std::clamp usage and avoids ADL surprises; update the expression
computing ctas_per_head (symbols: kMaxCtas, target_ctas, ctas_per_head,
total_tiles) and ensure <algorithm> is included where this translation unit is
compiled.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: cdce1ba3-0bce-449f-a005-e00098bdc6e4

📥 Commits

Reviewing files that changed from the base of the PR and between 2fbb60a and 5cf114b.

📒 Files selected for processing (1)
  • include/flashinfer/mamba/invoke_selective_state_update_mtp.cuh

The alignment requirement depends on `TMA_STATE_ROWS` (2 * ROWS_PER_PASS), not on `NUM_COMPUTE_WARPS_PER_GROUP * ROWS_PER_WARP`. Move the check after `TMA_STATE_ROWS` is defined so the correct value is used and reported in the error message.
Contributor

@coderabbitai coderabbitai bot left a comment

🧹 Nitpick comments (1)
include/flashinfer/mamba/invoke_selective_state_update_mtp.cuh (1)

263-263: Use std::max instead of unqualified max in host code.

The unqualified max call relies on CUDA headers or ADL to resolve. Since <algorithm> is already included, explicitly using std::max is more portable and avoids potential ambiguity with CUDA's device-only max function.

🔧 Suggested fix
-    int const ctas_per_head = std::clamp(target_ctas / max(total_tiles, 1), 1, kMaxCtas);
+    int const ctas_per_head = std::clamp(target_ctas / std::max(total_tiles, 1), 1, kMaxCtas);
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/mamba/invoke_selective_state_update_mtp.cuh` at line 263,
The expression computing ctas_per_head calls an unqualified max (in the clamp
call) which can resolve to CUDA device-only overloads; change the unqualified
max to std::max in the ctas_per_head initialization (the line assigning int
const ctas_per_head) so the host-side std::max from <algorithm> is used, i.e.,
replace max(total_tiles, 1) with std::max(total_tiles, 1).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@include/flashinfer/mamba/invoke_selective_state_update_mtp.cuh`:
- Line 263: The expression computing ctas_per_head calls an unqualified max (in
the clamp call) which can resolve to CUDA device-only overloads; change the
unqualified max to std::max in the ctas_per_head initialization (the line
assigning int const ctas_per_head) so the host-side std::max from <algorithm> is
used, i.e., replace max(total_tiles, 1) with std::max(total_tiles, 1).

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 703e6145-15e1-43bc-9e94-1cd8f71b5e67

📥 Commits

Reviewing files that changed from the base of the PR and between 5151ddb and 41c9e36.

📒 Files selected for processing (1)
  • include/flashinfer/mamba/invoke_selective_state_update_mtp.cuh

Pad slots previously skipped all TMA loads including B/C/x, causing
invalid output for those batch entries. Fix all three kernel variants
(simple, vertical, horizontal) to always load B/C/x while still
skipping state loads for pad slots (using zero state in registers
instead).

Add TestSelectiveStateUpdateMTPPadSlots test class to verify that pad
slots produce correct output with ~25% of batch entries marked as pad.
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@include/flashinfer/mamba/kernel_selective_state_update_mtp_vertical.cuh`:
- Around line 24-38: The header kernel_selective_state_update_mtp_vertical.cuh
is missing a header guard; add a single-line `#pragma once` near the top of the
file (immediately after the license/comment block and before the first `#include`)
so the declarations in this header (e.g., types and kernels declared alongside
includes like cooperative_groups and ssu_mtp_common.cuh) are not redefined on
multiple inclusion.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 906509aa-307a-4cad-89c5-5a70bd16085c

📥 Commits

Reviewing files that changed from the base of the PR and between 41c9e36 and ba6fc7c.

📒 Files selected for processing (4)
  • include/flashinfer/mamba/kernel_selective_state_update_mtp_horizontal.cuh
  • include/flashinfer/mamba/kernel_selective_state_update_mtp_simple.cuh
  • include/flashinfer/mamba/kernel_selective_state_update_mtp_vertical.cuh
  • tests/mamba/test_selective_state_update_mtp.py

@bkryu bkryu added the run-ci label Apr 2, 2026
@bkryu
Collaborator

bkryu commented Apr 2, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !497 has been created, and the CI pipeline #47565080 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #47565080: 7/20 passed

3 participants