Improved simple mamba SSU kernel #2962

ishovkun wants to merge 143 commits into flashinfer-ai:main.

Conversation
Move the test input generation helper from test_selective_state_update.py to a new test_utils.py module for reuse across tests. The refactored function adds support for multi-token mode, intermediate state buffers, and configurable state cache strides.
- Add helper functions for tensor validation and dtype checks
- Move the output tensor to Optional and update checks accordingly
- Add state_stride_batch and update_state fields to the SelectiveStateUpdateParams struct
- Refactor kernel param usage for clarity and consistency
Extract dispatchDimDstate and dispatchRatio helpers to simplify kernel dispatch code and reduce duplication.
Support int32 and int64 state_batch_indices in selective_state_update
- Add kernel and dispatcher support for int32/int64 state_batch_indices
- Update tests to cover int32 indices
- Fix test_utils to use int64 slot_idx by default
- Remove the int32-only type check to allow both int32 and int64 index types
- Add a stateIndex_t template parameter to kernels for index-type dispatch
- Extract kernel implementations to new selective_state_update_stp.cuh
- Remove unused TMA helper functions from create_tensor_map.cuh
- Add comprehensive MTP (multi-token prediction) test suite
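The index-dtype dispatch described above can be sketched on the host side as a small lookup from the index tensor's dtype to the kernel instantiation templated on the matching stateIndex_t. This is an illustrative sketch only; the dispatch table and function name below are hypothetical, not the library's actual API.

```python
# Hypothetical host-side dispatch from index dtype to kernel instantiation.
# The real dispatch lives in C++; this mirrors the idea in plain Python.
_INDEX_KERNELS = {
    "int32": "selective_state_update_kernel<int32_t>",
    "int64": "selective_state_update_kernel<int64_t>",
}

def dispatch_index_kernel(index_dtype: str) -> str:
    """Select the kernel variant for a state_batch_indices dtype name."""
    try:
        return _INDEX_KERNELS[index_dtype]
    except KeyError:
        raise TypeError(
            f"state_batch_indices must be int32 or int64, got {index_dtype}"
        ) from None
```

Any other index dtype is rejected up front rather than failing inside the kernel.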
- Add common.cuh with kernel dispatch helpers and alignment checks
- Split and rename kernel_selective_state_update_stp.cuh; add kernel_selective_state_update_mtp.cuh
- Refactor Python selective_state_update to clarify dimension handling
- Add a test for dtype mismatch between state_batch_indices and intermediate_state_indices
- Update test_utils to generate int64 intermediate_slot_idx by default
- Remove the redundant input type check in validate_intermediate_state_indices
Always define state_batch_idx (either from state_batch_indices or pid_b) to mirror the CUDA kernel's state_batch variable. This allows the intermediate state caching logic to use a simple check of `state_batch_idx != pad_slot_id` without requiring an extra HAS_STATE_BATCH_INDICES guard, matching the CUDA kernel behavior. addresses: flashinfer-ai#2444 (comment)
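The unified slot-selection logic described above can be sketched in plain Python. The names (state_batch_indices, pid_b, pad_slot_id) mirror the Triton kernel's variables, but this is an illustrative model, not the kernel itself; the pad sentinel value is an assumption for the example.

```python
# Sketch of always defining state_batch_idx, mirroring the CUDA kernel's
# state_batch variable. PAD_SLOT_ID is an assumed sentinel for this example.
PAD_SLOT_ID = -1

def resolve_state_batch_idx(pid_b, state_batch_indices=None):
    """Fall back to the program/batch id when no index tensor is supplied,
    so downstream logic never needs a HAS_STATE_BATCH_INDICES guard."""
    if state_batch_indices is not None:
        return state_batch_indices[pid_b]
    return pid_b

def should_cache_intermediate(state_batch_idx, pad_slot_id=PAD_SLOT_ID):
    # A single comparison replaces the extra guard flag.
    return state_batch_idx != pad_slot_id
```

With this shape, the intermediate-state caching path is one comparison regardless of whether an index tensor was passed.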
- Add test_chunk_scan_combined.py comparing the CUTLASS CuTe DSL Blackwell implementation against the Triton reference
- Move selective_state_update_triton.py into the triton_reference/ package
- Add Triton reference implementations for Mamba2 SSD kernels:
  - ssd_combined.py (main entry point)
  - ssd_chunk_scan.py, ssd_chunk_state.py, ssd_state_passing.py
  - ssd_bmm.py, softplus.py (utilities)
# Conflicts:
#   tests/mamba/selective_state_update_triton.py
#   tests/mamba/test_selective_state_update_mtp.py
#   tests/mamba/test_selective_state_update_stp.py
- Move dtype dispatch and instantiation to codegen via Jinja templates
- Generate config and instantiation files per dtype combination
- Update Python JIT logic to build/load kernels for specific dtypes
- Remove C++ dtype dispatch helpers from selective_state_update.cu
- Update the kernel launcher comment for clarity on consumer warps
Support explicit algorithm choice (auto/simple/vertical/horizontal) for selective_state_update and MTP kernels. Update kernel signatures, Python bindings, and JIT module generation to include algorithm and compile-time shape parameters (dim, dstate, ntokens_mtp). Refactor dispatch logic for SM90/SM100 architectures.
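The explicit algorithm parameter described above implies a small routing function on the host: explicit choices are validated against hardware, and auto applies the fallback rules discussed elsewhere in this review (varlen forces the simple kernel; vertical/horizontal require SM100). This is a hedged sketch under those assumptions, with a hypothetical function name and a hypothetical default choice for auto on SM100; it is not the library's actual dispatch code.

```python
# Illustrative host-side routing for the algorithm parameter.
def choose_algorithm(algorithm: str, sm_major: int, has_cu_seqlens: bool) -> str:
    supported = {"auto", "simple", "vertical", "horizontal"}
    if algorithm not in supported:
        raise ValueError(f"unknown algorithm {algorithm!r}")
    if algorithm == "auto":
        # Vertical/horizontal kernels do not support varlen (cu_seqlens),
        # so auto falls back to the simple kernel; pre-SM100 also gets simple.
        if has_cu_seqlens or sm_major < 10:
            return "simple"
        return "vertical"  # assumed default for the sketch
    if algorithm in ("vertical", "horizontal") and sm_major < 10:
        raise RuntimeError(f"{algorithm} requires SM100 (Blackwell)")
    return algorithm
```

Explicit requests fail loudly on unsupported hardware instead of silently falling back, which keeps benchmark results honest.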
… .cu files

The config.inc defines DIM, DSTATE, NTOKENS_MTP as constexpr globals that the header's function templates rely on. With the previous order (header first, config second), NVCC's lenient two-phase lookup masked the issue, but a fresh JIT compilation after cache clearing would fail with 'identifier DIM/DSTATE is undefined' errors. clang-format is disabled for these includes because it reorders them alphabetically, which breaks compilation.

AI-assisted
Assign each of the 4 consumer warps a single tensor to load (x, B, z, C) instead of warps 0 and 1 each loading two tensors sequentially. This maximizes memory-level parallelism during the load phase. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Replace cartesian-product fixture parametrization with explicit rows: one base case plus one row per parameter deviation. Cuts the test count from ~200+ (MTP) and ~144+ (STP) down to ~26 and ~15 respectively. AI-assisted Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
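The reduction described above can be sketched numerically: enumerating one base case plus one row per single-parameter deviation grows linearly in the number of parameter values, while a cartesian product grows multiplicatively. The parameter names and values below are illustrative, not the actual test suite's.

```python
# Sketch of explicit-row parametrization vs. a cartesian product.
import itertools

base = {"batch": 64, "nheads": 64, "dstate": 128, "dtype": "bf16"}
deviations = {
    "batch": [1, 8],
    "nheads": [32],
    "dstate": [64, 256],
    "dtype": ["f16", "f32"],
}

def explicit_rows(base, deviations):
    """One base row, plus one row per single-parameter deviation."""
    rows = [dict(base)]
    for key, values in deviations.items():
        for v in values:
            row = dict(base)
            row[key] = v
            rows.append(row)
    return rows

rows = explicit_rows(base, deviations)
cartesian = list(itertools.product(
    [base["batch"], *deviations["batch"]],
    [base["nheads"], *deviations["nheads"]],
    [base["dstate"], *deviations["dstate"]],
    [base["dtype"], *deviations["dtype"]],
))
# 1 + (2 + 1 + 2 + 2) = 8 explicit rows vs. 3 * 2 * 3 * 3 = 54 combinations.
```

Each deviation row still exercises one axis of variation against the well-tested base, which is what makes the smaller matrix acceptable.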
- Parametrize tests to run with all supported algorithms
- Update test logic to pass the algorithm argument through
- Improve test output messages to include the algorithm name
- Add a utility to detect available algorithms based on GPU arch
Note: Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the settings. Use the following commands to manage reviews:
📝 Walkthrough

Replaces the monolithic MTP SSU with three modular MTP kernels (simple, horizontal, vertical), adds SM100-specialized JIT/module support and dispatch, introduces MTP utilities and new benchmarks, updates conversion/alignment helpers and STP/MTP dispatch/validation, and expands tests to cover algorithms, dstate, and pad/varlen cases.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Host as Host/Python
    participant JIT as JIT/ModuleGen
    participant Dispatch as invokeSelectiveStateUpdateMTP
    participant Kernel as Kernel (simple/vertical/horizontal)
    participant GPU as GPU Memory
    Host->>JIT: Ensure module built/loaded (select SM major)
    Host->>Dispatch: Call selective_state_update(params, algorithm)
    Dispatch->>Dispatch: Validate inputs, map algorithm (async_horizontal→simple), choose auto
    Dispatch->>Dispatch: Build TMA descriptors & alignment checks
    Dispatch->>Kernel: Launch chosen kernel with params & descriptors
    Kernel->>GPU: Issue cp.async / TMA loads, compute recurrence, write intermediate/final state & outputs
    Kernel-->>Dispatch: Kernel completes
    Dispatch-->>Host: Return status/result
```
Estimated code review effort
🎯 5 (Critical) | ⏱️ ~120 minutes
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Code Review
This pull request introduces support for NVIDIA Blackwell (SM100) architectures in the Mamba selective state update implementation, adding specialized vertical and horizontal MTP kernels that leverage TMA and f32x2 SIMD instructions. The update also includes a refactored SM80+ simple MTP kernel using cp.async, improved alignment logic, and new benchmarking scripts for performance and Speed-of-Light analysis. Feedback was provided regarding an opportunity to improve the maintainability of the CTAS_PER_HEAD dispatch logic by replacing the series of if constexpr blocks with a more structured mechanism.
```cpp
// Dispatch to the largest instantiated CTAS_PER_HEAD <= ctas_per_head.
// Use if constexpr to avoid compiling invalid template instantiations.
if constexpr (DIM / 4 >= kRowsPerPass) {
```
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/mamba/test_selective_state_update_varlen.py (1)
258-271: ⚠️ Potential issue | 🟡 Minor — Keep auto-dispatch coverage in the varlen suite.

Restricting these `cu_seqlens` cases to `"simple"` means they no longer exercise the public auto path, so a regression in the `cu_seqlens -> simple` fallback would slip through. Please keep `"auto"` here as well, and mirror that in the other varlen parametrizations below.

💡 Suggested parametrization

```diff
-@pytest.mark.parametrize("algorithm", ["simple"])
+@pytest.mark.parametrize("algorithm", ["auto", "simple"])
```

Based on learnings: In flashinfer-ai/flashinfer MTP paths, the vertical and horizontal selective_state_update kernels do not support varlen (`cu_seqlens`), and the invoker falls back to `SSUAlgorithm::kSimple` in auto mode when `params.cu_seqlens` is set.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/mamba/test_selective_state_update_varlen.py` around lines 258 - 271, The test currently parametrizes only "simple" for the algorithm in test_varlen_uniform which prevents exercising the public auto-dispatch path; update the parametrization to include "auto" (i.e., ["simple", "auto"]) for test_varlen_uniform and mirror that change in the other varlen-related parametrizations in this file so the auto fallback (cu_seqlens -> simple) is covered; locate the algorithm param in the test_varlen_uniform function and the other varlen test parameter blocks and add "auto" to the algorithm list.
🧹 Nitpick comments (4)
include/flashinfer/mamba/ssu_mtp_common.cuh (1)
21-23: Make this header self-contained.

This file uses `uint32_t`/`int64_t`, `fabsf`/`fmaxf`, and `std::numeric_limits` but only includes `<cuda/barrier>` plus `conversion.cuh`. Please include the standard headers directly so this doesn't depend on include order.

♻️ Suggested include set

```diff
 #include <cuda/barrier>
+#include <cmath>
+#include <cstdint>
+#include <limits>
+
 #include "conversion.cuh"
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/mamba/ssu_mtp_common.cuh` around lines 21 - 23, The header ssu_mtp_common.cuh is not self-contained: it uses uint32_t and int64_t, fabsf/fmaxf, and std::numeric_limits but only includes <cuda/barrier> and "conversion.cuh"; add direct standard includes (e.g., <cstdint>, <cmath>, and <limits>) at the top of ssu_mtp_common.cuh so symbols used by functions/types in this file (uint32_t, int64_t, fabsf, fmaxf, std::numeric_limits) are defined regardless of include order.

include/flashinfer/mamba/common.cuh (1)
41-46: Guard zero-width `PackedAligned` instantiations.

`getVectorLoadSizeForFullUtilization()` can bottom out at `0` for small `DSTATE` values, so `PackedAligned<T, 0>` currently fails wherever the compiler happens to complain first. A local `static_assert` here would make that failure explicit.

♻️ Suggested guard

```diff
 template <typename T, int N = sizeof(float4) / sizeof(T)>
 struct alignas(largestPow2Divisor(N * sizeof(T))) PackedAligned {
+  static_assert(N > 0, "PackedAligned requires at least one element");
   T val[N];
   static constexpr int count = N;
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/mamba/common.cuh` around lines 41 - 46, Add a compile-time guard to prevent zero-width instantiations of the PackedAligned template: inside the template struct PackedAligned (templated on T and N) add a static_assert that N > 0 with a clear message (e.g. "PackedAligned instantiated with N == 0; ensure getVectorLoadSizeForFullUtilization() returns >0") so that attempts to instantiate PackedAligned<T,0> fail with an explicit error rather than a confusing template/ABI failure; refer to PackedAligned, its template parameter N, and largestPow2Divisor/getVectorLoadSizeForFullUtilization when making this change.

tests/mamba/test_selective_state_update_mtp.py (2)
56-66: Please put `auto` back into the MTP algorithm matrix.

This PR changes the dispatcher, but the fixture now exercises only explicit kernels. Without an `auto` case, the batch-based routing and fallback logic can regress without any MTP test failing.

Proposed test matrix tweak

```diff
 @pytest.fixture(
     autouse=True,
     params=[
+        "auto",
         "simple",
         pytest.param("vertical", marks=_requires_sm100),
         pytest.param("horizontal", marks=_requires_sm100),
     ],
 )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/mamba/test_selective_state_update_mtp.py` around lines 56 - 66, The test fixture _algorithm currently only parametrizes "simple", "vertical", and "horizontal" causing MTP to miss the 'auto' routing path; modify the params list in the _algorithm pytest.fixture to include "auto" (e.g., add "auto" into the params array alongside "simple", "vertical", and "horizontal") so the test matrix exercises the batch-based routing/fallback logic that relies on the automatic kernel selection.
629-635: The new MTP `ngroups` sweep still misses three dispatch ratios.

With `nheads=64`, this matrix only covers `HEADS_PER_GROUP` values 64, 32, 16, and 8. The runtime also instantiates 4, 2, and 1, so half of the vertical/horizontal specializations are still untested.

Proposed coverage extension

```diff
 _NGROUPS_PARAMS = (
     # (batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype, use_out_tensor, ngroups)
     (64, 64, 64, 128, 4, torch.bfloat16, torch.float32, True, 1),
     (64, 64, 64, 128, 4, torch.bfloat16, torch.float32, True, 2),
     (64, 64, 64, 128, 4, torch.bfloat16, torch.float32, True, 4),
     (64, 64, 64, 128, 4, torch.bfloat16, torch.float32, True, 8),
+    (64, 64, 64, 128, 4, torch.bfloat16, torch.float32, True, 16),
+    (64, 64, 64, 128, 4, torch.bfloat16, torch.float32, True, 32),
+    (64, 64, 64, 128, 4, torch.bfloat16, torch.float32, True, 64),
 )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/mamba/test_selective_state_update_mtp.py` around lines 629 - 635, The _NGROUPS_PARAMS sweep only includes ngroups = 1,2,4,8 (HEADS_PER_GROUP = 64,32,16,8) and misses the cases where the runtime instantiates ngroups = 16,32,64 (HEADS_PER_GROUP = 4,2,1); update the _NGROUPS_PARAMS tuple to also include entries for ngroups=16, ngroups=32, and ngroups=64 using the same parameter pattern (batch=64, nheads=64, dim=64, dstate=128, cache_steps=4, state_dtype=torch.bfloat16, weight_dtype=torch.float32, use_out_tensor=True) so the test covers those dispatch ratios.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/bench_ssu_sweep_mtp.py`:
- Around line 351-359: The parsing currently allows bf16-philox-* and
f32-philox-* even though the MTP path only supports stochastic rounding for fp16
state; update parse_dtype_spec to reject any philox-specs whose base is not
"f16" by either restricting the regex to only match "f16-philox-(\d+)" or by
checking m.group(1) and raising a ValueError if base != "f16"; also update the
error message (and the branch that checks _dtype_name_to_torch) to clearly state
that only "f16" or "f16-philox-<rounds>" are supported for MTP/stochastic
rounding so callers get a clear failure instead of a later runtime error.
In `@benchmarks/bench_ssu_sweep_sol.py`:
- Around line 265-269: The --ncu branch directly calls kernel_fn(**kwargs) so
unsupported kernels raise RuntimeError and abort; wrap the kernel_fn call (and
the subsequent torch.cuda.synchronize()) in a try/except that catches
RuntimeError, prints or logs an "unsupported kernel" message (matching the timed
path behavior), and returns 0.0 on failure, otherwise proceed to synchronize,
print the "Single invocation done (ncu mode)" message and return success; update
the ncu branch around kernel_fn and torch.cuda.synchronize() accordingly.
In `@include/flashinfer/mamba/invoke_selective_state_update_mtp.cuh`:
- Around line 151-163: The initial check uses kHorizontalDimAlignment
(NUM_COMPUTE_WARPS_PER_GROUP * ROWS_PER_WARP) allowing DIM values that are
multiples of 16 but the kernel template requires DIM % TMA_STATE_ROWS == 0;
update the validation in the horizontal launcher (the FLASHINFER_CHECK that
currently tests DIM % kHorizontalDimAlignment == 0) to instead verify DIM %
TMA_STATE_ROWS == 0 (or add an additional check against TMA_STATE_ROWS) so
mismatched DIM (e.g. 48, 80) are rejected before kernel launch; reference
kHorizontalDimAlignment, TMA_STATE_ROWS, and the existing FLASHINFER_CHECK lines
to locate where to change the condition.
In `@include/flashinfer/mamba/kernel_selective_state_update_mtp_horizontal.cuh`:
- Around line 108-145: The IS_PAD branch currently skips the initial
cde::cp_async_bulk_tensor_4d_global_to_shared calls for sram.B, sram.C and
sram.x but the compute/epilogue path still reads sram.B/C/x, causing use of
uninitialized shared memory; fix by ensuring the TMA transactions for B/C/x are
still issued when IS_PAD (or alternatively skip the compute/epilogue for padded
rows). Concretely, in the lane==0 / !IS_PAD block replace the conditional with
logic that issues cde::cp_async_bulk_tensor_4d_global_to_shared for sram.B,
sram.C and the x loop even when IS_PAD is true (but you can supply zeroed host
buffers if needed), and make sure the first barrier_arrive_tx call for
sram.bar_state_in_full[slot] uses bytes = bytesBCX + bytesChunk for the (h==0 &&
tl==0) case so the producer byte count matches the B/C/X transfers;
alternatively, gate the compute/epilogue reads of sram.B/sram.C/sram.x to skip
them entirely when IS_PAD.
In `@include/flashinfer/mamba/kernel_selective_state_update_mtp_simple.cuh`:
- Around line 229-259: The unified dst-slot logic changed semantics: ensure
writes honor params.update_state and preserve final-state writes to params.state
when callers expect to cache intermediates. In the dst_state_batch_indices
branch (dst_state_batch_indices) skip writes when params.update_state is false
(set sram.state_dst_slots[step]=SKIP) so per-step indices don't cause writes if
update_state is disabled; in the params.intermediate_states branch, special-case
the last step (step == seq_len-1) to route the destination to state_batch when
params.update_state is true (i.e., set sram.state_dst_slots[step]=state_batch)
instead of the icache slot, otherwise use icache_idx * params.cache_steps +
step; apply the same fix at the other occurrences noted (lines ~312-322 and
~436-474) referencing sram.state_dst_slots, intermediate_state_indices,
dst_state_batch_indices, params.update_state, params.intermediate_states,
state_batch, and SKIP.
In `@include/flashinfer/mamba/kernel_selective_state_update_mtp_vertical.cuh`:
- Around line 192-198: The code silently drops tail state columns when DSTATE
isn’t divisible by warpSize because stateValuesPerThread = DSTATE / warpSize
uses integer division; add an explicit compile-time guard to prevent this misuse
by inserting a static_assert that DSTATE % warpSize == 0 (with a helpful
message) near the current declarations (around
stateValuesPerThread/packed_state_t) so callers must satisfy the warp-aligned
state width or refactor later to handle remainders.
---
Outside diff comments:
In `@tests/mamba/test_selective_state_update_varlen.py`:
- Around line 258-271: The test currently parametrizes only "simple" for the
algorithm in test_varlen_uniform which prevents exercising the public
auto-dispatch path; update the parametrization to include "auto" (i.e.,
["simple", "auto"]) for test_varlen_uniform and mirror that change in the other
varlen-related parametrizations in this file so the auto fallback (cu_seqlens ->
simple) is covered; locate the algorithm param in the test_varlen_uniform
function and the other varlen test parameter blocks and add "auto" to the
algorithm list.
---
Nitpick comments:
In `@include/flashinfer/mamba/common.cuh`:
- Around line 41-46: Add a compile-time guard to prevent zero-width
instantiations of the PackedAligned template: inside the template struct
PackedAligned (templated on T and N) add a static_assert that N > 0 with a clear
message (e.g. "PackedAligned instantiated with N == 0; ensure
getVectorLoadSizeForFullUtilization() returns >0") so that attempts to
instantiate PackedAligned<T,0> fail with an explicit error rather than a
confusing template/ABI failure; refer to PackedAligned, its template parameter
N, and largestPow2Divisor/getVectorLoadSizeForFullUtilization when making this
change.
In `@include/flashinfer/mamba/ssu_mtp_common.cuh`:
- Around line 21-23: The header ssu_mtp_common.cuh is not self-contained: it
uses uint32_t and int64_t, fabsf/fmaxf, and std::numeric_limits but only
includes <cuda/barrier> and "conversion.cuh"; add direct standard includes
(e.g., <cstdint>, <cmath>, and <limits>) at the top of ssu_mtp_common.cuh so
symbols used by functions/types in this file (uint32_t, int64_t, fabsf, fmaxf,
std::numeric_limits) are defined regardless of include order.
In `@tests/mamba/test_selective_state_update_mtp.py`:
- Around line 56-66: The test fixture _algorithm currently only parametrizes
"simple", "vertical", and "horizontal" causing MTP to miss the 'auto' routing
path; modify the params list in the _algorithm pytest.fixture to include "auto"
(e.g., add "auto" into the params array alongside "simple", "vertical", and
"horizontal") so the test matrix exercises the batch-based routing/fallback
logic that relies on the automatic kernel selection.
- Around line 629-635: The _NGROUPS_PARAMS sweep only includes ngroups = 1,2,4,8
(HEADS_PER_GROUP = 64,32,16,8) and misses the cases where the runtime
instantiates ngroups = 16,32,64 (HEADS_PER_GROUP = 4,2,1); update the
_NGROUPS_PARAMS tuple to also include entries for ngroups=16, ngroups=32, and
ngroups=64 using the same parameter pattern (batch=64, nheads=64, dim=64,
dstate=128, cache_steps=4, state_dtype=torch.bfloat16,
weight_dtype=torch.float32, use_out_tensor=True) so the test covers those
dispatch ratios.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: b38ee3bd-b368-46ee-a850-939f416f74bf
📒 Files selected for processing (20)
- benchmarks/bench_ssu_sweep_mtp.py
- benchmarks/bench_ssu_sweep_sol.py
- flashinfer/jit/core.py
- flashinfer/jit/mamba/__init__.py
- flashinfer/jit/mamba/selective_state_update.py
- flashinfer/mamba/selective_state_update.py
- include/flashinfer/mamba/common.cuh
- include/flashinfer/mamba/conversion.cuh
- include/flashinfer/mamba/create_tensor_map.cuh
- include/flashinfer/mamba/invoke_selective_state_update_mtp.cuh
- include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh
- include/flashinfer/mamba/kernel_selective_state_update_mtp_horizontal.cuh
- include/flashinfer/mamba/kernel_selective_state_update_mtp_simple.cuh
- include/flashinfer/mamba/kernel_selective_state_update_mtp_vertical.cuh
- include/flashinfer/mamba/kernel_selective_state_update_stp.cuh
- include/flashinfer/mamba/selective_state_update.cuh
- include/flashinfer/mamba/ssu_mtp_common.cuh
- tests/mamba/test_selective_state_update_mtp.py
- tests/mamba/test_selective_state_update_stp.py
- tests/mamba/test_selective_state_update_varlen.py
💤 Files with no reviewable changes (1)
- include/flashinfer/mamba/kernel_selective_state_update_mtp.cuh
```python
m = re.match(r"^(bf16|f16|f32)-philox-(\d+)$", spec)
if m:
    base, rounds = m.group(1), int(m.group(2))
    return spec, _dtype_name_to_torch[base], rounds
if spec not in _dtype_name_to_torch:
    raise ValueError(
        f"Unknown dtype spec '{spec}'. "
        "Expected bf16, f16, f32, or <dtype>-philox-<rounds>"
    )
```
Reject non-f16-philox-* specs here.
parse_dtype_spec() currently accepts bf16-philox-* and f32-philox-*, but the MTP path only supports stochastic rounding for fp16 state. Those CLI values will fail instead of being benchmarked.
🔧 Proposed fix

```diff
-    m = re.match(r"^(bf16|f16|f32)-philox-(\d+)$", spec)
+    m = re.match(r"^(f16)-philox-(\d+)$", spec)
```

```diff
-        "Expected bf16, f16, f32, or <dtype>-philox-<rounds>"
+        "Expected bf16, f16, f32, or f16-philox-<rounds>"
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
-m = re.match(r"^(bf16|f16|f32)-philox-(\d+)$", spec)
+m = re.match(r"^(f16)-philox-(\d+)$", spec)
 if m:
     base, rounds = m.group(1), int(m.group(2))
     return spec, _dtype_name_to_torch[base], rounds
 if spec not in _dtype_name_to_torch:
     raise ValueError(
         f"Unknown dtype spec '{spec}'. "
-        "Expected bf16, f16, f32, or <dtype>-philox-<rounds>"
+        "Expected bf16, f16, f32, or f16-philox-<rounds>"
     )
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@benchmarks/bench_ssu_sweep_mtp.py` around lines 351 - 359, The parsing
currently allows bf16-philox-* and f32-philox-* even though the MTP path only
supports stochastic rounding for fp16 state; update parse_dtype_spec to reject
any philox-specs whose base is not "f16" by either restricting the regex to only
match "f16-philox-(\d+)" or by checking m.group(1) and raising a ValueError if
base != "f16"; also update the error message (and the branch that checks
_dtype_name_to_torch) to clearly state that only "f16" or "f16-philox-<rounds>"
are supported for MTP/stochastic rounding so callers get a clear failure instead
of a later runtime error.
This is a benchmarking script not a prod file. WONTFIX
```python
if ncu:
    kernel_fn(**kwargs)
    torch.cuda.synchronize()
    print("  Single invocation done (ncu mode)")
    return 0.0
```
--ncu mode currently aborts on unsupported kernels.
The timed path catches RuntimeError and keeps the sweep going, but the single-invocation branch calls the kernel directly. On non-SM100 hardware, the first vertical/horizontal launch will terminate the script instead of being reported as unsupported.
Suggested fix

```diff
 if ncu:
-    kernel_fn(**kwargs)
-    torch.cuda.synchronize()
+    try:
+        kernel_fn(**kwargs)
+        torch.cuda.synchronize()
+    except RuntimeError as e:
+        print(f"  Kernel failed: {e}")
+        return float("inf")
     print("  Single invocation done (ncu mode)")
     return 0.0
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@benchmarks/bench_ssu_sweep_sol.py` around lines 265 - 269, The --ncu branch
directly calls kernel_fn(**kwargs) so unsupported kernels raise RuntimeError and
abort; wrap the kernel_fn call (and the subsequent torch.cuda.synchronize()) in
a try/except that catches RuntimeError, prints or logs an "unsupported kernel"
message (matching the timed path behavior), and returns 0.0 on failure,
otherwise proceed to synchronize, print the "Single invocation done (ncu mode)"
message and return success; update the ncu branch around kernel_fn and
torch.cuda.synchronize() accordingly.
leave me alone this is a benchmarking script
include/flashinfer/mamba/kernel_selective_state_update_mtp_horizontal.cuh
🧹 Nitpick comments (2)
include/flashinfer/mamba/invoke_selective_state_update_mtp.cuh (2)
6-8: Consider removing the unused `<iostream>` include.

The `<iostream>` header is included but there's no visible usage of stream objects (`std::cout`, `std::cerr`, etc.) in this file. If `FLASHINFER_CHECK` internally handles stream output, that should be encapsulated in its own header.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/mamba/invoke_selective_state_update_mtp.cuh` around lines 6 - 8, Remove the unused `#include` <iostream> from invoke_selective_state_update_mtp.cuh to avoid unnecessary dependency; keep <algorithm> and <type_traits>, and if the FLASHINFER_CHECK macro (or any other macro/function used in this file) requires stream facilities, include its defining header (or <iosfwd> if only forward declarations are needed) instead of <iostream> so that the file only depends on the headers that actually provide the symbols it uses.
265-268: Use `std::max` for consistency.

Line 267 uses unqualified `max` while also using `std::clamp`. For consistency and clarity in host code, prefer `std::max`.

♻️ Suggested fix

```diff
-  int const ctas_per_head = std::clamp(target_ctas / max(total_tiles, 1), 1, kMaxCtas);
+  int const ctas_per_head = std::clamp(target_ctas / std::max(total_tiles, 1), 1, kMaxCtas);
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/mamba/invoke_selective_state_update_mtp.cuh` around lines 265 - 268, The expression computing ctas_per_head uses unqualified max; replace the call to max(target_ctas / max(total_tiles, 1), 1, kMaxCtas) by using std::max for the inner max (i.e., std::max(total_tiles, 1) or std::max(target_ctas / std::max(total_tiles, 1), 1) as appropriate) so it matches the std::clamp usage and avoids ADL surprises; update the expression computing ctas_per_head (symbols: kMaxCtas, target_ctas, ctas_per_head, total_tiles) and ensure <algorithm> is included where this translation unit is compiled.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@include/flashinfer/mamba/invoke_selective_state_update_mtp.cuh`:
- Around line 6-8: Remove the unused `#include` <iostream> from
invoke_selective_state_update_mtp.cuh to avoid unnecessary dependency; keep
<algorithm> and <type_traits>, and if the FLASHINFER_CHECK macro (or any other
macro/function used in this file) requires stream facilities, include its
defining header (or <iosfwd> if only forward declarations are needed) instead of
<iostream> so that the file only depends on the headers that actually provide
the symbols it uses.
- Around line 265-268: The expression computing ctas_per_head uses unqualified
max; replace the call to max(target_ctas / max(total_tiles, 1), 1, kMaxCtas) by
using std::max for the inner max (i.e., std::max(total_tiles, 1) or
std::max(target_ctas / std::max(total_tiles, 1), 1) as appropriate) so it
matches the std::clamp usage and avoids ADL surprises; update the expression
computing ctas_per_head (symbols: kMaxCtas, target_ctas, ctas_per_head,
total_tiles) and ensure <algorithm> is included where this translation unit is
compiled.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: cdce1ba3-0bce-449f-a005-e00098bdc6e4
📒 Files selected for processing (1)
include/flashinfer/mamba/invoke_selective_state_update_mtp.cuh
The alignment requirement depends on `TMA_STATE_ROWS` (2 * ROWS_PER_PASS), not on `NUM_COMPUTE_WARPS_PER_GROUP * ROWS_PER_WARP`. Move the check after `TMA_STATE_ROWS` is defined so the correct value is used and reported in the error message.
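The corrected check described above can be modeled in a few lines: the valid DIM values are multiples of TMA_STATE_ROWS = 2 * ROWS_PER_PASS, not of the warp-derived alignment. The concrete value of ROWS_PER_PASS below is an assumption for the example, not the kernel's actual constant.

```python
# Illustrative version of the corrected alignment check.
# ROWS_PER_PASS = 16 is an assumed value for this sketch.
ROWS_PER_PASS = 16

def check_dim_alignment(dim, rows_per_pass=ROWS_PER_PASS):
    """Validate DIM against TMA_STATE_ROWS (2 * ROWS_PER_PASS) and return it,
    so the error message reports the value actually required by the kernel."""
    tma_state_rows = 2 * rows_per_pass
    if dim % tma_state_rows != 0:
        raise ValueError(
            f"DIM={dim} must be a multiple of TMA_STATE_ROWS={tma_state_rows}"
        )
    return tma_state_rows
```

Under this model, values such as DIM=48 that pass a multiple-of-16 check are rejected before launch, matching the review's concern.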
🧹 Nitpick comments (1)
include/flashinfer/mamba/invoke_selective_state_update_mtp.cuh (1)
263: Use `std::max` instead of unqualified `max` in host code.

The unqualified `max` call relies on CUDA headers or ADL to resolve. Since `<algorithm>` is already included, explicitly using `std::max` is more portable and avoids potential ambiguity with CUDA's device-only `max` function.

🔧 Suggested fix

```diff
-  int const ctas_per_head = std::clamp(target_ctas / max(total_tiles, 1), 1, kMaxCtas);
+  int const ctas_per_head = std::clamp(target_ctas / std::max(total_tiles, 1), 1, kMaxCtas);
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/mamba/invoke_selective_state_update_mtp.cuh` at line 263, The expression computing ctas_per_head calls an unqualified max (in the clamp call) which can resolve to CUDA device-only overloads; change the unqualified max to std::max in the ctas_per_head initialization (the line assigning int const ctas_per_head) so the host-side std::max from <algorithm> is used, i.e., replace max(total_tiles, 1) with std::max(total_tiles, 1).
📒 Files selected for processing (1)
include/flashinfer/mamba/invoke_selective_state_update_mtp.cuh
Pad slots previously skipped all TMA loads including B/C/x, causing invalid output for those batch entries. Fix all three kernel variants (simple, vertical, horizontal) to always load B/C/x while still skipping state loads for pad slots (using zero state in registers instead). Add TestSelectiveStateUpdateMTPPadSlots test class to verify that pad slots produce correct output with ~25% of batch entries marked as pad.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@include/flashinfer/mamba/kernel_selective_state_update_mtp_vertical.cuh`:
- Around lines 24-38: The header kernel_selective_state_update_mtp_vertical.cuh is missing a header guard; add a single-line `#pragma once` near the top of the file (immediately after the license/comment block and before the first `#include`) so the declarations in this header (e.g., types and kernels declared alongside includes like cooperative_groups and ssu_mtp_common.cuh) are not redefined on multiple inclusion.
📒 Files selected for processing (4)
include/flashinfer/mamba/kernel_selective_state_update_mtp_horizontal.cuh
include/flashinfer/mamba/kernel_selective_state_update_mtp_simple.cuh
include/flashinfer/mamba/kernel_selective_state_update_mtp_vertical.cuh
tests/mamba/test_selective_state_update_mtp.py
/bot run

[FAILED] Pipeline #47565080: 7/20 passed
📌 Description
This PR upgrades the SSU MTP "simple" kernel with cp.async state prefetching, vectorized loads, and a consolidated state write path, delivering drastic performance improvements, particularly in the latency spectrum. The `async_horizontal` kernel was a temporary development vehicle used during implementation and benchmarking; once the optimizations were validated, the simple kernel was replaced with the improved version and the temporary kernel was removed.

Key changes
Async state prefetch (cp.async → double-buffered smem)
- Prefetches `state_in` with `cp.async` into a double-buffered shared memory staging area (`state_in[STATE_STAGES]`).
- Adds a `cp_async_state_cooperative` helper function.

Vectorized loads
- Uses `PackedAligned` for vectorized loads of the B, C, and x tensors in the load path, improving memory access efficiency when padding is not active.

State write path consolidation
- Precomputes `state_dst_slots[]` during the load phase, eliminating redundant index recomputation per pass/dd.
- Writes state through a single `dst_slot != SKIP` path.

OOB handling cleanup
- Simplifies out-of-bounds handling around the `__syncthreads__` barrier.

Latency hiding
- Moves the `dst_slot` prefetch earlier to hide LDS latency.
- Uses `mul_f32x2` for the state decode scale.

Varlen + scaled-state support
- Keeps the optimizations working with `cu_seqlens` (varlen) and scaled (quantized) state.
- Replaces the `BANK_CYCLE_ELEMS` scheme with a simpler `DSTATE_PAD` (128-byte aligned) wide tile.

Validation
- Validates that `intermediate_states_buffer` and `dst_state_batch_indices` cannot both be provided.

Benchmarking
- `bench_ssu_sweep_sol.py`: SOL (speed-of-light) benchmark script for SSU MTP mode.
🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- Installed the hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- New or updated tests cover these changes (unittest, etc.).

Reviewer Notes