Mamba SSU: better automatic kernel selection + algorithm selection optionally exposed to the user. #2591
Conversation
Move the test input generation helper from test_selective_state_update.py to a new test_utils.py module for reuse across tests. The refactored function adds support for multi-token mode, intermediate state buffers, and configurable state cache strides.
struct
- Add helper functions for tensor validation and dtype checks
- Move output tensor to Optional and update checks accordingly
- Add state_stride_batch and update_state fields to SelectiveStateUpdateParams
- Refactor kernel param usage for clarity and consistency
Extract dispatchDimDstate and dispatchRatio helpers to simplify kernel dispatch code and reduce duplication.
- Add kernel and dispatcher support for int32/int64 state_batch_indices
- Update tests to cover int32 indices
- Fix test_utils to use int64 slot_idx by default

Support int32 and int64 state_batch_indices in selective_state_update
- Remove int32 type check to allow both int32 and int64 index types
- Add stateIndex_t template parameter to kernels for index type dispatch
- Extract kernel implementations to new selective_state_update_stp.cuh
- Remove unused TMA helper functions from create_tensor_map.cuh
- Add comprehensive MTP (multi-token prediction) test suite
checks
- Add common.cuh with kernel dispatch helpers and alignment checks
- Split and rename kernel_selective_state_update_stp.cuh, add kernel_selective_state_update_mtp.cuh
- Refactor Python selective_state_update to clarify dimension handling
- Add test for dtype mismatch between state_batch_indices and intermediate_state_indices
- Update test_utils to generate int64 intermediate_slot_idx by default
- Remove redundant input type check in validate_intermediate_state_indices
Always define state_batch_idx (either from state_batch_indices or pid_b) to mirror the CUDA kernel's state_batch variable. This allows the intermediate state caching logic to use a simple check of `state_batch_idx != pad_slot_id` without requiring an extra HAS_STATE_BATCH_INDICES guard, matching the CUDA kernel behavior. addresses: flashinfer-ai#2444 (comment)
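In plain Python, the control flow described above can be sketched as follows (a toy sketch, not the Triton kernel itself; `resolve_state_batch_idx`, `should_write_intermediate_state`, and `PAD_SLOT_ID` are hypothetical names for illustration):

```python
PAD_SLOT_ID = -1  # hypothetical sentinel standing in for pad_slot_id

def resolve_state_batch_idx(pid_b, state_batch_indices=None):
    # Mirror the CUDA kernel's state_batch variable: use the index table
    # when one is provided, otherwise fall back to the program/batch id.
    if state_batch_indices is not None:
        return state_batch_indices[pid_b]
    return pid_b

def should_write_intermediate_state(state_batch_idx):
    # With state_batch_idx always defined, a single comparison replaces
    # the extra HAS_STATE_BATCH_INDICES guard.
    return state_batch_idx != PAD_SLOT_ID
```

Because `state_batch_idx` is defined on both paths, the caching logic never needs to know whether an index table was supplied.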
- Add test_chunk_scan_combined.py comparing CUTLASS CuTe DSL Blackwell implementation against Triton reference
- Move selective_state_update_triton.py into triton_reference/ package
- Add Triton reference implementations for Mamba2 SSD kernels:
  - ssd_combined.py (main entry point)
  - ssd_chunk_scan.py, ssd_chunk_state.py, ssd_state_passing.py
  - ssd_bmm.py, softplus.py (utilities)
# Conflicts:
#   tests/mamba/selective_state_update_triton.py
#   tests/mamba/test_selective_state_update_mtp.py
#   tests/mamba/test_selective_state_update_stp.py
- Move dtype dispatch and instantiation to codegen via Jinja templates
- Generate config and instantiation files per dtype combination
- Update Python JIT logic to build/load kernels for specific dtypes
- Remove C++ dtype dispatch helpers from selective_state_update.cu
- Update kernel launcher comment for clarity on consumer warps
Support explicit algorithm choice (auto/simple/vertical/horizontal) for selective_state_update and MTP kernels. Update kernel signatures, Python bindings, and JIT module generation to include algorithm and compile-time shape parameters (dim, dstate, ntokens_mtp). Refactor dispatch logic for SM90/SM100 architectures.
… .cu files

The config.inc defines DIM, DSTATE, NTOKENS_MTP as constexpr globals that the header's function templates rely on. With the previous order (header first, config second), NVCC's lenient two-phase lookup masked the issue, but a fresh JIT compilation after cache clearing would fail with 'identifier DIM/DSTATE is undefined' errors. clang-format is disabled for these includes because it reorders them alphabetically, which breaks compilation.

AI-assisted
Assign each of the 4 consumer warps a single tensor to load (x, B, z, C) instead of warps 0 and 1 each loading two tensors sequentially. This maximizes memory-level parallelism during the load phase. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
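The assignment change can be sketched as a toy mapping (the exact warp-to-tensor pairing under the old scheme is an assumption for illustration; `old_assignment` and `new_assignment` are hypothetical names):

```python
TENSORS = ("x", "B", "z", "C")

def old_assignment(warp_id):
    # Old scheme (assumed pairing): warps 0 and 1 each load two tensors
    # sequentially, so at most two loads are in flight at a time.
    return {0: ("x", "B"), 1: ("z", "C")}.get(warp_id, ())

def new_assignment(warp_id):
    # New scheme: each of the 4 consumer warps owns exactly one tensor,
    # so all four loads can be issued concurrently.
    return (TENSORS[warp_id],)
```

Both schemes cover the same four tensors; only the degree of memory-level parallelism during the load phase differs.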
Replace cartesian-product fixture parametrization with explicit rows: one base case plus one row per parameter deviation. Cuts the test count from ~200+ (MTP) and ~144+ (STP) down to ~26 and ~15 respectively. AI-assisted Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
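The size reduction can be illustrated with a toy parameter grid (axes and values here are made up and far smaller than the real test matrices):

```python
from itertools import product

# Hypothetical axes; the base value is listed first in each.
axes = {
    "batch": [64, 1, 128],
    "dstate": [128, 64],
    "state_dtype": ["bf16", "fp32"],
    "use_out": [True, False],
}

# Cartesian-product style: every combination becomes a test case.
cartesian = list(product(*axes.values()))

# Explicit-row style: one base case plus one row per single-parameter deviation.
base = tuple(v[0] for v in axes.values())
rows = [base]
for i, values in enumerate(axes.values()):
    for v in values[1:]:
        row = list(base)
        row[i] = v
        rows.append(tuple(row))

print(len(cartesian), len(rows))  # prints: 24 6
```

Each parameter still gets exercised at every value, but combinations of multiple deviations are dropped, which is where the bulk of the cartesian-product count comes from.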
- Parametrize tests to run with all supported algorithms
- Update test logic to pass algorithm argument through
- Improve test output messages to include algorithm name
- Add utility to detect available algorithms based on GPU arch
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/mamba/selective_state_update.py (1)
Line 163: ⚠️ Potential issue | 🟡 Minor: Missing runtime guard for MTP + vertical/horizontal algorithm.

The docstring says "MTP mode only supports 'auto' or 'simple'" but there is no enforcement — passing `algorithm="vertical"` with `cache_steps >= 1` silently falls through to the C++ kernel, which may fail or produce wrong results.

🛡️ Proposed guard

```diff
 is_mtp = cache_steps >= 1
+
+if is_mtp and algorithm not in ("auto", "simple"):
+    raise ValueError(
+        f"MTP mode only supports 'auto' or 'simple' algorithm, got '{algorithm}'"
+    )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/mamba/selective_state_update.py` at line 163, The code sets is_mtp = cache_steps >= 1 but does not enforce the docstring constraint that MTP only supports algorithm "auto" or "simple"; add a runtime guard where is_mtp is computed (using cache_steps and the function parameter algorithm) that checks if is_mtp is True and algorithm not in ("auto", "simple") and raises a clear ValueError (or TypeError) explaining that MTP mode only supports "auto" or "simple"; reference the existing symbols is_mtp, cache_steps, and algorithm so the check is colocated with the current is_mtp logic (e.g., right after is_mtp = cache_steps >= 1).
🧹 Nitpick comments (1)
flashinfer/mamba/selective_state_update.py (1)
Lines 220-229: Consider moving the error message to reduce line length (Ruff TRY003).

The hardcoded algorithm mapping (auto=0, simple=1, vertical=2, horizontal=3) correctly matches the `SSUAlgorithm` enum in `include/flashinfer/mamba/selective_state_update.cuh`, so no mismatch risk exists. However, line 229's inline ValueError message triggers Ruff TRY003. Consider extracting it:

♻️ Optional: silence TRY003

```diff
+_VALID_ALGORITHMS = ("auto", "simple", "vertical", "horizontal")
+
 if algorithm == "auto":
     algorithm_int = 0
 elif algorithm == "simple":
     algorithm_int = 1
 elif algorithm == "vertical":
     algorithm_int = 2
 elif algorithm == "horizontal":
     algorithm_int = 3
 else:
-    raise ValueError(f"Unknown algorithm: {algorithm}")
+    raise ValueError(f"Unknown algorithm '{algorithm}'. Valid options: {_VALID_ALGORITHMS}")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/mamba/selective_state_update.py` around lines 220 - 229, The inline ValueError message in the algorithm mapping block makes the line too long (Ruff TRY003); extract the message into a local variable before raising to shorten the raise line: compute the error string (e.g., err = f"Unknown algorithm: {algorithm}") and then raise ValueError(err) in the else branch that sets algorithm_int for inputs "auto"/"simple"/"vertical"/"horizontal" (references: variable algorithm, algorithm_int, and the ValueError raise).
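For comparison, the same mapping can also be collapsed into a single dict lookup. A self-contained sketch (the helper name `algorithm_to_int` is hypothetical; the strings and 0-3 numbering come from the review above):

```python
_ALGORITHM_TO_INT = {"auto": 0, "simple": 1, "vertical": 2, "horizontal": 3}

def algorithm_to_int(algorithm: str) -> int:
    # Raise with the full list of valid options while keeping the raise line short.
    try:
        return _ALGORITHM_TO_INT[algorithm]
    except KeyError:
        msg = f"Unknown algorithm '{algorithm}'. Valid options: {tuple(_ALGORITHM_TO_INT)}"
        raise ValueError(msg) from None
```

This keeps the name/integer correspondence in one place, so adding a fifth algorithm later only touches the dict.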
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Outside diff comments:
In `@flashinfer/mamba/selective_state_update.py`:
- Line 163: The code sets is_mtp = cache_steps >= 1 but does not enforce the
docstring constraint that MTP only supports algorithm "auto" or "simple"; add a
runtime guard where is_mtp is computed (using cache_steps and the function
parameter algorithm) that checks if is_mtp is True and algorithm not in ("auto",
"simple") and raises a clear ValueError (or TypeError) explaining that MTP mode
only supports "auto" or "simple"; reference the existing symbols is_mtp,
cache_steps, and algorithm so the check is colocated with the current is_mtp
logic (e.g., right after is_mtp = cache_steps >= 1).
---
Nitpick comments:
In `@flashinfer/mamba/selective_state_update.py`:
- Around line 220-229: The inline ValueError message in the algorithm mapping
block makes the line too long (Ruff TRY003); extract the message into a local
variable before raising to shorten the raise line: compute the error string
(e.g., err = f"Unknown algorithm: {algorithm}") and then raise ValueError(err)
in the else branch that sets algorithm_int for inputs
"auto"/"simple"/"vertical"/"horizontal" (references: variable algorithm,
algorithm_int, and the ValueError raise).
/bot run
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
benchmarks/routines/mamba.py (1)
Lines 18-21: ⚠️ Potential issue | 🟡 Minor: Stale comment still references the old file path.

Lines 19-20 say `tests/mamba/selective_state_update_triton.py`, but the canonical location is now `tests/mamba/triton_reference/selective_state_update.py`.

🔧 Proposed fix

```diff
-# Imported from tests/mamba/selective_state_update_triton.py to avoid code
-# duplication. See that file for the canonical Triton kernel source.
+# Imported from tests/mamba/triton_reference/selective_state_update.py to avoid code
+# duplication. See that file for the canonical Triton kernel source.
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@benchmarks/routines/mamba.py` around lines 18 - 21, Update the stale comment that references the old canonical path: find the string 'tests/mamba/selective_state_update_triton.py' in the header comment of benchmarks/routines/mamba.py (the Triton reference implementation note) and replace it with the new canonical path 'tests/mamba/triton_reference/selective_state_update.py' so the comment points to the correct file.

tests/mamba/test_selective_state_update_mtp.py (1)
Lines 87-104: ⚠️ Potential issue | 🟡 Minor: MTP `run_kernel` is missing the `algorithm` parameter — algorithm selection is never tested for the MTP path.

All STP subclasses parametrize over `_get_algorithms()`, but the entire MTP suite always falls through to the default ("auto"). Since the public API exposes the `algorithm` parameter, this creates a coverage gap.

🔧 Suggested fix

```diff
-    def run_kernel(self, inputs, out=None, disable_state_update=False):
+    def run_kernel(self, inputs, out=None, disable_state_update=False, algorithm="auto"):
         """Run the flashinfer kernel and return output."""
         return flashinfer.mamba.selective_state_update(
             inputs["state_cache"],
             inputs["x"],
             inputs["dt"],
             inputs["A"],
             inputs["B"],
             inputs["C"],
             D=inputs["D"],
             z=inputs.get("z"),
             dt_bias=inputs["dt_bias"],
             dt_softplus=True,
             state_batch_indices=inputs["slot_idx"],
             pad_slot_id=-1,
             out=out,
             disable_state_update=disable_state_update,
+            algorithm=algorithm,
         )
```

Then add `@pytest.mark.parametrize("algorithm", _get_algorithms())` to `test_output_correctness` methods (importing `_get_algorithms` from the STP module), mirroring the STP approach.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/mamba/test_selective_state_update_mtp.py` around lines 87 - 104, The MTP test helper run_kernel calls flashinfer.mamba.selective_state_update without passing the public algorithm parameter, so algorithm selection isn't exercised; update run_kernel to accept an algorithm argument and forward it into selective_state_update (referencing run_kernel and flashinfer.mamba.selective_state_update), and then add pytest parametrization to the MTP test methods by importing _get_algorithms() from the STP module and decorating test_output_correctness with `@pytest.mark.parametrize`("algorithm", _get_algorithms()) so the MTP path mirrors the STP coverage (also ensure test_output_correctness references the new algorithm parameter when invoking run_kernel).
🧹 Nitpick comments (3)
flashinfer/aot.py (1)
Lines 549-597: AOT pre-build adds up to 1,512 SSU modules — verify build-time/artifact-size budget.

The cartesian product of `_ssu_dtype_combos` (12) × `_ssu_dims` (3) × `_ssu_dstates` (3) × `_ssu_ntokens` (7) produces 756 `JitSpec` objects in the base loop, and an identical additional 756 in the `has_sm90 or has_sm100` branch — totalling up to 1,512 SSU AOT compilation units.

While Jinja templates reduce per-module JIT latency at inference time, these all get compiled up-front during `compile_and_package_modules`. This meaningfully increases:
- AOT build wall-time (especially in CI pipelines with bounded parallelism)
- Pre-built package artifact size on disk
Consider whether the full grid is needed or whether a subset (e.g., fewer `ntokens` values or pruning uncommon `dstates`/`dims` combos) is sufficient for AOT coverage, with remaining cases served by JIT.

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@flashinfer/aot.py` around lines 549 - 597, The current AOT grid generates up to 1,512 SSU modules by taking the cartesian product of _ssu_dtype_combos, _ssu_dims, _ssu_dstates, and _ssu_ntokens and appending a duplicate set for SM90/100 via gen_selective_state_update_module and gen_selective_state_update_sm90_module; restrict this by making the grid configurable (e.g., an AOT_SSUS_PRESET or MAX_SSU_MODULES env/option) and replace the full product with a filtered/subsampled list when packaging (for example: reduce _ssu_ntokens, drop uncommon dim/dstate combos, or randomly/sample-select a subset) so compile_and_package_modules only receives the chosen subset; ensure the selection logic references the same symbols (_ssu_dtype_combos, _ssu_dims, _ssu_dstates, _ssu_ntokens, gen_selective_state_update_module, gen_selective_state_update_sm90_module) and document the new config flag.

tests/mamba/test_selective_state_update_mtp.py (2)
Lines 402-409: Mutable class attributes require `ClassVar` annotation (Ruff RUF012).

`_INTERMEDIATE_PARAMS`, `_NGROUPS_PARAMS`, and `_LARGE_BATCH_PARAMS` are mutable lists defined as class attributes. Ruff RUF012 flags this. The simplest fix is to move them to module level, consistent with how `_BASE_PARAMS` is defined.

🔧 Proposed fix (for all three occurrences)

Move `_INTERMEDIATE_PARAMS`, `_NGROUPS_PARAMS`, and `_LARGE_BATCH_PARAMS` out of their respective class bodies to module level (alongside `_BASE_PARAMS`), then reference them from the class:

```diff
+# Module level
+_INTERMEDIATE_PARAMS = [...]
+
 class TestSelectiveStateUpdateMTPWithIntermediateStates(TestSelectiveStateUpdateMTP):
-    _INTERMEDIATE_PARAMS = [...]
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@tests/mamba/test_selective_state_update_mtp.py` around lines 402 - 409, Move mutable class attributes to module level: take the lists named _INTERMEDIATE_PARAMS, _NGROUPS_PARAMS, and _LARGE_BATCH_PARAMS out of their test class bodies and define them alongside _BASE_PARAMS at module scope, then update the test classes to reference these module-level names (keep names unchanged). This removes mutable class attributes and satisfies Ruff RUF012 while preserving existing usage in functions/tests that access _INTERMEDIATE_PARAMS, _NGROUPS_PARAMS, and _LARGE_BATCH_PARAMS.
Lines 106-168: `assert_outputs_match`, `assert_states_match`, and `_print_mismatch_details` are copy-pasted from `test_selective_state_update_stp.py`.

These three methods are byte-for-byte identical in both test files. Extracting them into a shared mixin or a base class in `tests/mamba/utils.py` would eliminate the duplication and make future tolerance changes apply consistently.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/mamba/test_selective_state_update_mtp.py` around lines 106 - 168, The three duplicated methods assert_outputs_match, assert_states_match, and _print_mismatch_details should be moved into a shared test utility (e.g., create a TestOutputStateAssertions mixin or base class in the test utilities module) and the current test class should import/derive from that utility instead of keeping copy-pasted implementations; extract the implementations exactly, place them as methods on the new mixin (keeping ATOL/RTOL usage), update the test class to use the mixin (remove the local methods), and run tests to ensure no API/behavior changes.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/aot.py`:
- Around line 588-599: gen_gdn_prefill_sm90_module currently hardcodes
sm90a_nvcc_flags which produces SM90a-only PTX and fails on SM100-only systems;
update gen_gdn_prefill_sm90_module to use the same dynamic NVCC flag resolution
as gen_selective_state_update_sm90_module (i.e., call
get_nvcc_flags_list(supported_major_versions=[9,10,11,12]) or equivalent) so it
emits compatible flags for SM100, or alternatively add a new
gen_gdn_prefill_sm100_module that uses SM100-specific flags and select between
the two at build-time; modify the call site (where gen_gdn_prefill_sm90_module
is appended) to pick the dynamic-version or the new gen_gdn_prefill_sm100_module
accordingly.
In `@tests/mamba/test_selective_state_update_mtp.py`:
- Line 281: The local variable returned from self.make_reference_output is
unpacked into y_ref and state_ref but state_ref is never used; rename state_ref
to _state_ref (e.g., y_ref, _state_ref = self.make_reference_output(inputs)) to
indicate intentional discard and satisfy Ruff RUF059; update any nearby similar
unused unpackings in the same test function if present.
- Line 431: The tuple assignment on the call to make_reference_output unpacks
y_ref, state_ref, intermediate_states_ref but one or more of these unpacked
variables are unused; either unpack only the needed values (e.g., replace with
just y_ref = self.make_reference_output(inputs) if the function returns a single
needed value) or replace unused names with underscores (e.g., _, state_ref, _ or
y_ref, _, _ ) so there are no unused variables; update the call site to use the
chosen names and ensure any subsequent code references the correct symbol
(make_reference_output, y_ref, state_ref, intermediate_states_ref).
In `@tests/mamba/triton_reference/selective_state_update.py`:
- Line 14: Remove the unnecessary noqa directive on the import of softplus:
delete the trailing " # noqa: F401" from the line that reads "from .softplus
import softplus" because softplus is actually used (see kernel references to
softplus), so the F401 suppression is a no-op; simply leave the plain import
"from .softplus import softplus".
- Around line 20-24: The current dynamic import assumes
_ilu.spec_from_file_location returns a valid spec and that _sp_spec.loader
exists; guard against both returning None by checking _sp_spec and
_sp_spec.loader after calling _ilu.spec_from_file_location(_softplus_path) and
before calling _ilu.module_from_spec and _sp_spec.loader.exec_module. If either
is None, raise a clear import error (or fallback) with contextual details (file
path and which value was None) so that softplus assignment (softplus =
_sp_mod.softplus) only runs when the module was actually loaded successfully.
---
Outside diff comments:
In `@benchmarks/routines/mamba.py`:
- Around line 18-21: Update the stale comment that references the old canonical
path: find the string 'tests/mamba/selective_state_update_triton.py' in the
header comment of benchmarks/routines/mamba.py (the Triton reference
implementation note) and replace it with the new canonical path
'tests/mamba/triton_reference/selective_state_update.py' so the comment points
to the correct file.
In `@tests/mamba/test_selective_state_update_mtp.py`:
- Around line 87-104: The MTP test helper run_kernel calls
flashinfer.mamba.selective_state_update without passing the public algorithm
parameter, so algorithm selection isn't exercised; update run_kernel to accept
an algorithm argument and forward it into selective_state_update (referencing
run_kernel and flashinfer.mamba.selective_state_update), and then add pytest
parametrization to the MTP test methods by importing _get_algorithms() from the
STP module and decorating test_output_correctness with
`@pytest.mark.parametrize`("algorithm", _get_algorithms()) so the MTP path mirrors
the STP coverage (also ensure test_output_correctness references the new
algorithm parameter when invoking run_kernel).
---
Duplicate comments:
In `@csrc/selective_state_update.cu`:
- Around line 277-278: The code performs static_cast<SSUAlgorithm>(algorithm)
without validating that the incoming integer 'algorithm' maps to a defined
SSUAlgorithm value; update the call site before invoking
invokeSelectiveStateUpdate by checking that 'algorithm' is within the
SSUAlgorithm enum's valid range (or use a helper/isValid function), and on
invalid input return an error (or throw/log) rather than performing the cast;
keep the subsequent call to invokeSelectiveStateUpdate<input_t, weight_t,
matrixA_t, state_t, stateIndex_t>(p, algo, stream) unchanged once 'algo' is
guaranteed valid.
---
Nitpick comments:
In `@flashinfer/aot.py`:
- Around line 549-597: The current AOT grid generates up to 1,512 SSU modules by
taking the cartesian product of _ssu_dtype_combos, _ssu_dims, _ssu_dstates, and
_ssu_ntokens and appending a duplicate set for SM90/100 via
gen_selective_state_update_module and gen_selective_state_update_sm90_module;
restrict this by making the grid configurable (e.g., an AOT_SSUS_PRESET or
MAX_SSU_MODULES env/option) and replace the full product with a
filtered/subsampled list when packaging (for example: reduce _ssu_ntokens, drop
uncommon dim/dstate combos, or randomly/sample-select a subset) so
compile_and_package_modules only receives the chosen subset; ensure the
selection logic references the same symbols (_ssu_dtype_combos, _ssu_dims,
_ssu_dstates, _ssu_ntokens, gen_selective_state_update_module,
gen_selective_state_update_sm90_module) and document the new config flag.
In `@tests/mamba/test_selective_state_update_mtp.py`:
- Around line 402-409: Move mutable class attributes to module level: take the
lists named _INTERMEDIATE_PARAMS, _NGROUPS_PARAMS, and _LARGE_BATCH_PARAMS out
of their test class bodies and define them alongside _BASE_PARAMS at module
scope, then update the test classes to reference these module-level names (keep
names unchanged). This removes mutable class attributes and satisfies Ruff
RUF012 while preserving existing usage in functions/tests that access
_INTERMEDIATE_PARAMS, _NGROUPS_PARAMS, and _LARGE_BATCH_PARAMS.
- Around line 106-168: The three duplicated methods assert_outputs_match,
assert_states_match, and _print_mismatch_details should be moved into a shared
test utility (e.g., create a TestOutputStateAssertions mixin or base class in
the test utilities module) and the current test class should import/derive from
that utility instead of keeping copy-pasted implementations; extract the
implementations exactly, place them as methods on the new mixin (keeping
ATOL/RTOL usage), update the test class to use the mixin (remove the local
methods), and run tests to ensure no API/behavior changes.
Simplify import logic by adding tests/mamba to sys.path and using standard package imports for the Triton reference kernel and softplus.
🧹 Nitpick comments (1)
tests/mamba/test_selective_state_update_mtp.py (1)
Lines 402-409: Move class-level param lists to module level to silence Ruff RUF012.

`_INTERMEDIATE_PARAMS`, `_NGROUPS_PARAMS`, and `_LARGE_BATCH_PARAMS` are mutable list literals assigned directly as class attributes, which Ruff RUF012 flags as unsafe. `_BASE_PARAMS` already follows the correct module-level pattern — making these consistent with it also removes the static-analysis noise.

♻️ Suggested refactor (shown for `_INTERMEDIATE_PARAMS`; apply the same pattern to the other two)

```diff
+# fmt: off
+_INTERMEDIATE_PARAMS = [
+    # (batch, nheads, dim, dstate, cache_steps, state_dtype, weight_dtype, use_out_tensor)
+    ( 64, 64, 64, 128, 4, torch.bfloat16, torch.float32, True ),   # base
+    ( 64, 64, 64,  64, 4, torch.bfloat16, torch.float32, True ),   # dstate=64
+    ( 64, 64, 64, 128, 2, torch.bfloat16, torch.float32, True ),   # cache_steps=2
+    ( 64, 64, 64, 128, 8, torch.bfloat16, torch.float32, True ),   # cache_steps=8
+    ( 64, 64, 64, 128, 4, torch.bfloat16, torch.float32, False),   # use_out_tensor=False
+]
+# fmt: on

 class TestSelectiveStateUpdateMTPWithIntermediateStates(TestSelectiveStateUpdateMTP):
     ...
-    # fmt: off
-    _INTERMEDIATE_PARAMS = [
-        ...
-    ]
-    # fmt: on

 @pytest.mark.parametrize(
     "batch,nheads,dim,dstate,cache_steps,state_dtype,weight_dtype,use_out_tensor",
     _INTERMEDIATE_PARAMS,
 )
```

Also applies to: 602-608, 665-669
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/mamba/test_selective_state_update_mtp.py` around lines 402 - 409, Ruff flags mutable list literals assigned as class attributes (_INTERMEDIATE_PARAMS, _NGROUPS_PARAMS, _LARGE_BATCH_PARAMS); move these parameter lists to module-level constants (mirroring _BASE_PARAMS) and have the test class reference those module-level names (or convert them to immutable tuples) so the class no longer defines mutable literals directly; apply the same change for the other two param groups referenced in the file.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@tests/mamba/test_selective_state_update_mtp.py`:
- Around line 402-409: Ruff flags mutable list literals assigned as class
attributes (_INTERMEDIATE_PARAMS, _NGROUPS_PARAMS, _LARGE_BATCH_PARAMS); move
these parameter lists to module-level constants (mirroring _BASE_PARAMS) and
have the test class reference those module-level names (or convert them to
immutable tuples) so the class no longer defines mutable literals directly;
apply the same change for the other two param groups referenced in the file.
[FAILED] Pipeline #44350474: 9/20 passed
🧹 Nitpick comments (1)
flashinfer/aot.py (1)
Lines 549-600: SSU AOT registration is correct; note the significantly expanded build matrix.

The `if has_sm90 or has_sm100:` guard correctly enables `gen_selective_state_update_sm90_module` for SM100-only builds — that function uses `get_nvcc_flags_list(supported_major_versions=[9, 10, 11, 12])` and handles SM100 PTX generation properly. `gen_gdn_prefill_sm90_module()` at line 600 remains correctly guarded by its own `if has_sm90:` check (unchanged behavior).

The build matrix is now 12 dtype-combos × 3 dims × 3 dstates × 7 ntokens = 756 base + up to 756 SM90 = 1,512 new SSU modules for AOT packages. While the Jinja templates reduce per-kernel compilation time, the absolute build time for pre-built packages will grow substantially. Package builders should account for this.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/aot.py` around lines 549 - 600, SSU AOT registration is correct as-is; no code changes required — keep the gen_selective_state_update_module and gen_selective_state_update_sm90_module registration loops, the has_sm90 or has_sm100 guard, the gen_trtllm_utils_module append, and the separate has_sm90 guard for gen_gdn_prefill_sm90_module; simply note the expanded build matrix (gen_selective_state_update_module, gen_selective_state_update_sm90_module, gen_trtllm_utils_module, gen_gdn_prefill_sm90_module) and ensure package builders are aware of the increased compile/time cost when producing pre-built AOT packages.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@flashinfer/aot.py`:
- Around line 549-600: SSU AOT registration is correct as-is; no code changes
required — keep the gen_selective_state_update_module and
gen_selective_state_update_sm90_module registration loops, the has_sm90 or
has_sm100 guard, the gen_trtllm_utils_module append, and the separate has_sm90
guard for gen_gdn_prefill_sm90_module; simply note the expanded build matrix
(gen_selective_state_update_module, gen_selective_state_update_sm90_module,
gen_trtllm_utils_module, gen_gdn_prefill_sm90_module) and ensure package
builders are aware of the increased compile/time cost when producing pre-built
AOT packages.
only used by major frameworks.
📌 Description
This PR does several things:
- `selective_state_update` function (JIT is fast now)

Background
This PR changes the behavior of the function. Now, an optional string `algorithm` can be passed to the kernel. The default value 'auto' allows the user not to think about the internals of the function. Optionally, the user can specify the kernel that they want. This adjustment allowed me to make use of the recent mamba benchmarks. The sweep is shown below:

One can see that the new benchmark now correctly shows the speed difference between the reference Triton and the current implementation, as opposed to my previous PR. Clearly, I previously messed up the measurements at small batch sizes.
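As a rough illustration of what an 'auto' heuristic of this kind might look like, here is a standalone sketch; the function name, the batch threshold, and the exact dtype rules below are illustrative assumptions, not FlashInfer's actual selection logic:

```python
def pick_algorithm(batch, state_dtype, requested="auto"):
    # Honor an explicit user choice.
    if requested != "auto":
        return requested
    SMALL_BATCH = 16  # hypothetical cutoff
    if batch <= SMALL_BATCH:
        # Small problems: the 'simple' algorithm with several CTAs per tile.
        return "simple"
    if state_dtype in ("bf16", "fp16"):
        # The 'horizontal' layout only for bf16/fp16 states...
        return "horizontal"
    # ...else fall back to 'vertical'.
    return "vertical"
```

The point of the string parameter is exactly this: `"auto"` routes through a heuristic like the one above, while any other value bypasses it.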
Kernel Selection
This PR improves the kernel selection in the following ways:
- `simple` algorithm with several CTAs per tile.
- `horizontal` algorithm only for bf16/fp16 states, else fall back to the `vertical`.

🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- `pre-commit install`.
- `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- unittest, etc.).

Reviewer Notes
Please check how I handled jinja templates as it's my first time using those. Also, please check whether I accidentally deleted any important tests.
Summary by CodeRabbit
New Features
Bug Fixes
Chores
Tests