Selective State Update kernel (mamba) #2301
Conversation
📝 Walkthrough
Adds a new selective_state_update feature: CUDA headers/kernels (simple + SM90 CUtensorMap path), a host launcher and TVM FFI binding, Python JIT glue (base + SM90), a Triton reference implementation, pytest validations, and AOT/JIT integration.
Changes
Sequence Diagram(s)
sequenceDiagram
participant User as Python User
participant API as flashinfer.mamba API
participant JIT as JIT Compiler/Loader
participant Binding as TVM FFI Binding
participant Launcher as Host Launcher
participant GPU as GPU Device
User->>API: selective_state_update(state, x, dt, A, B, C, D, ...)
API->>JIT: get_selective_state_update_module(device)
alt cached
JIT-->>API: module
else build
JIT->>JIT: gen_jit_spec() -> nvcc compile/link
JIT-->>API: build_and_load()
end
API->>Binding: prepare tensors & call binding
Binding->>Launcher: selective_state_update(...)
Launcher->>Launcher: validate inputs, pack SelectiveStateUpdateParams
Launcher->>Launcher: dtype-key dispatch → choose kernel specialization
opt advanced_path_requires_TMA
Launcher->>GPU: cuDriver create CUtensorMap
end
Launcher->>GPU: launch kernel (simple or producer-consumer)
GPU-->>Launcher: kernel complete
Launcher-->>Binding: return
Binding-->>API: done
API-->>User: output tensor
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Suggested reviewers
🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Summary of Changes
Hello @ishovkun, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the performance of Mamba layer operations by integrating specialized CUDA kernels for the selective_state_update operation.
Highlights
Code Review
This pull request introduces an efficient selective_state_update kernel for Mamba, with specialized versions for pre-Hopper and Hopper+ architectures. The implementation is well-structured, with a clear separation between the Python API, TVM-FFI bindings, and the CUDA kernels. The code includes extensive validation checks and a comprehensive test suite against a Triton reference implementation.
However, I've identified a few issues. The most critical is that the CUDA kernels have hardcoded dimension sizes (DSTATE=128, DIM=64), which severely limits their use cases. The Python wrapper also lacks the dimension unsqueezing logic present in the reference implementation, making it less user-friendly. There are also some minor issues like a documentation error, unreachable code, and an inconsistency in precision between the two kernel versions.
Addressing these points will significantly improve the robustness and usability of this new feature.
Actionable comments posted: 2
🤖 Fix all issues with AI Agents
In @include/flashinfer/mamba/selective_state_update.cuh:
- Around line 493-497: The kernel launch currently checks cudaGetLastError() and
prints it via printf but does not propagate it; change the code that calls
cudaGetLastError()/cudaGetErrorString() to capture the cudaError_t (err) and
propagate it out of the function (e.g., return err or throw a runtime_error)
instead of only printing, and consider calling cudaDeviceSynchronize() before
checking errors if you need to catch asynchronous failures; update the function
signature and all callers of this kernel-launching function accordingly (look
for cudaGetLastError, cudaGetErrorString, and the kernel launch site in
selective_state_update.cuh) so callers can handle the error.
In @tests/mamba/selective_state_update_triton.py:
- Around line 297-302: The tie_hdim expression can raise AttributeError when
dt_bias is None because it calls dt_bias.stride(-1); update the tie_hdim boolean
to first ensure dt_bias is not None (e.g., include "and dt_bias is not None and
dt_bias.stride(-1) == 0") or otherwise short-circuit the check so
dt_bias.stride(...) is only evaluated when dt_bias exists; modify the line
computing tie_hdim that references A.stride, dt.stride and dt_bias.stride to
include this None check.
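The short-circuit fix above can be sketched in plain Python. The `Stub` class below is a hypothetical stand-in for a torch tensor's `stride()` interface, and the exact conjunction chosen (treating a missing `dt_bias` as tied) is illustrative, not the PR's final code:

```python
class Stub:
    """Minimal stand-in for a tensor exposing stride(i)."""

    def __init__(self, strides):
        self._strides = strides

    def stride(self, i):
        return self._strides[i]


def compute_tie_hdim(A, dt, dt_bias):
    # dt_bias.stride(-1) is only evaluated when dt_bias is not None,
    # so a missing bias no longer raises AttributeError.
    return (
        A.stride(-1) == 0
        and dt.stride(-1) == 0
        and (dt_bias is None or dt_bias.stride(-1) == 0)
    )


print(compute_tie_hdim(Stub([0]), Stub([0]), None))       # no crash with dt_bias=None
print(compute_tie_hdim(Stub([0]), Stub([0]), Stub([1])))  # False: bias is not broadcast
```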
🧹 Nitpick comments (8)
include/flashinfer/mamba/create_tensor_map.cuh (1)
18-24: Minor: Redundant error output in gpuAssert. The function outputs the error message via both fprintf(stderr, ...) and std::cout, which is redundant and uses inconsistent streams (stderr vs stdout).
🔎 Suggested fix

```diff
 static inline void gpuAssert(cudaError_t code, const char* file, int line, bool abort = true) {
   if (code != cudaSuccess) {
     fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
-    std::cout << "GPU assert failed" << std::endl;
     if (abort) exit(code);
   }
 }
```

include/flashinfer/mamba/selective_state_update.cuh (3)
102-111: Missing VectorizedLoadTraits specializations for other data types. Only __nv_bfloat16 is specialized. This will cause a compilation error if any other type combination is used with the simple kernel, since the primary template is empty. This is fine if bfloat16-only support is intentional for now, but consider adding a static_assert to the primary template to provide a clear error message.
🔎 Suggested improvement

```diff
 template <typename input_t, typename weight_t, typename state_t>
-struct VectorizedLoadTraits {};
+struct VectorizedLoadTraits {
+  static_assert(sizeof(input_t) == 0,
+                "VectorizedLoadTraits not specialized for this type combination");
+};
```
164-166: Consider: Inconsistent exp implementation between kernels. The simple kernel uses __expf(A_value * dt_value) (line 165), while the producer-consumer kernel uses fast_exp(A_value * dt_value) (line 419). fast_exp uses PTX ex2.approx.f32, which may have slightly different precision characteristics than __expf. If numerical consistency between GPU paths matters, consider using the same implementation in both kernels.
456-457: Hardcoded dimension constraints limit kernel flexibility. The kernel only supports DSTATE=128 (line 457) and DIM=64 for Hopper (line 475). While these are validated with FLASHINFER_CHECK, consider adding a brief comment explaining why these specific values are chosen (e.g., shared memory constraints, warp alignment) for future maintainers. Based on coding guidelines, hot-path algorithmic choices should be documented. Also applies to: 474-476
tests/mamba/test_selective_state_update.py (2)
10-12: Remove unused matrixA_dtype parameter. The matrixA_dtype parameter is passed to create_test_inputs but never used in the function body. Either remove it from the signature or use it when creating tensor A if float32 is not always desired.
🔎 Proposed fix

```diff
 def create_test_inputs(
-    batch_size, nheads, dim, dstate, ngroups, input_dtype, weight_dtype, matrixA_dtype
+    batch_size, nheads, dim, dstate, ngroups, input_dtype, weight_dtype
 ):
```

And update the call site accordingly:

```diff
 inputs = create_test_inputs(
     batch,
     nheads,
     dim,
     dstate,
     ngroups,
     input_dtype,
     weight_dtype,
-    matrixA_dtype,
 )
```
68-87: Consider adding GPU capability check for test skip conditions. Per coding guidelines, tests should use flashinfer.utils functions to skip on unsupported GPU architectures. The PR mentions a Hopper+ specific kernel path, so you may want to add appropriate skip conditions.
🔎 Example

```python
import pytest
from flashinfer.utils import get_compute_capability

# At module level or as a fixture
@pytest.fixture(autouse=True)
def check_cuda():
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")

# Or as a decorator if a specific SM version is required:
# @pytest.mark.skipif(get_compute_capability()[0] < 8, reason="Requires SM80+")
```

csrc/selective_state_update.cu (2)
119-120: Simplify the assertion for unsupported z parameter. The current assertion syntax FLASHINFER_CHECK(!z.has_value() && "z is not supported yet") is unconventional. The string literal always evaluates to true, making the condition equivalent to FLASHINFER_CHECK(!z.has_value()). Consider using the standard two-argument form with an explicit message.
🔎 Proposed fix

```diff
-  // if(z.has_value())
-  FLASHINFER_CHECK(!z.has_value() && "z is not supported yet");
+  FLASHINFER_CHECK(!z.has_value(), "z is not supported yet");
```
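The same pitfall has a direct analog in Python's assert statement, which makes for a quick illustration of why the two-argument form matters (illustrative snippet, not project code):

```python
def check_wrong(cond):
    # The string literal is always truthy, so this is just `assert cond`
    # and the intended message is lost from the failure report.
    assert cond and "z is not supported yet"


def check_right(cond):
    # Two-part form: the message survives into the AssertionError.
    assert cond, "z is not supported yet"


for fn in (check_wrong, check_right):
    try:
        fn(False)
    except AssertionError as e:
        print(f"{fn.__name__}: {str(e)!r}")
# check_wrong reports an empty message; check_right reports the full one.
```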
156-158: Dead code: z pointer copy is unreachable. Since line 120 throws an error if z.has_value(), this block is currently unreachable. Consider removing it until z support is implemented, or leave it with a comment indicating it's a placeholder for future support.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (13)
csrc/flashinfer_mamba_binding.cu
csrc/selective_state_update.cu
flashinfer/__init__.py
flashinfer/jit/mamba/__init__.py
flashinfer/jit/mamba/selective_state_update.py
flashinfer/mamba/__init__.py
flashinfer/mamba/selective_state_update.py
include/flashinfer/mamba/conversion.cuh
include/flashinfer/mamba/create_tensor_map.cuh
include/flashinfer/mamba/selective_state_update.cuh
tests/mamba/__init__.py
tests/mamba/selective_state_update_triton.py
tests/mamba/test_selective_state_update.py
🧰 Additional context used
📓 Path-based instructions (6)
tests/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
tests/**/*.py: Test implementations should use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, is_sm100a_supported, etc.) to skip tests on unsupported GPU architectures
For testing with mpirun on multi-GPU systems, use the pattern: mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function
Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes - tests/conftest.py provides auto-skipping for OOM tests as a safety net but should not be relied upon
Files:
tests/mamba/__init__.py
tests/mamba/selective_state_update_triton.py
tests/mamba/test_selective_state_update.py
flashinfer/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
flashinfer/**/*.py: Use @functools.cache decorator on Python API functions to implement module-level caching and avoid recompilation
Use @flashinfer_api decorator for debugging API calls, enable via FLASHINFER_LOGLEVEL environment variable (0=off, 1=basic, 3=detailed, 5=with stats)
Files:
flashinfer/jit/mamba/selective_state_update.py
flashinfer/__init__.py
flashinfer/mamba/selective_state_update.py
flashinfer/jit/mamba/__init__.py
flashinfer/mamba/__init__.py
flashinfer/jit/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
flashinfer/jit/**/*.py: JIT module generators in flashinfer/jit/ must follow the pattern: compute URI → create directory → (optional) render Jinja template → copy sources → return JitSpec
Use gen_jit_spec() function to return a properly configured JitSpec from module generators with appropriate sources and extra_cuda_cflags
Specify supported_major_versions in JitSpec to restrict kernel compilation to supported GPU architectures (e.g., SM versions 9, 10, 11, 12 for Hopper/newer)
Files:
flashinfer/jit/mamba/selective_state_update.py
flashinfer/jit/mamba/__init__.py
flashinfer/__init__.py
📄 CodeRabbit inference engine (CLAUDE.md)
Export new operations in flashinfer/__init__.py to make them available as public API
Files:
flashinfer/__init__.py
csrc/**/*.cu
📄 CodeRabbit inference engine (CLAUDE.md)
Framework bindings and PyTorch tensor handling should be implemented in csrc/ via TVM-FFI, not in include/ headers
Files:
csrc/selective_state_update.cu
csrc/flashinfer_mamba_binding.cu
include/**/*.cuh
📄 CodeRabbit inference engine (CLAUDE.md)
include/**/*.cuh: Torch headers MUST NOT be included in files within the include/ directory - keep framework-agnostic CUDA kernels that accept raw pointers
Kernel code in include/flashinfer/ is automatically picked up by JIT compilation on changes - no pip reinstall needed
For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers
Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly
Files:
include/flashinfer/mamba/conversion.cuh
include/flashinfer/mamba/create_tensor_map.cuh
include/flashinfer/mamba/selective_state_update.cuh
🧠 Learnings (14)
📓 Common learnings
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/__init__.py : Export new operations in `flashinfer/__init__.py` to make them available as public API
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes - no pip reinstall needed
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/__init__.py : Export new operations in `flashinfer/__init__.py` to make them available as public API
Applied to files:
tests/mamba/__init__.py
flashinfer/__init__.py
flashinfer/mamba/selective_state_update.py
flashinfer/jit/mamba/__init__.py
flashinfer/mamba/__init__.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/**/*.py : Use `flashinfer_api` decorator for debugging API calls, enable via `FLASHINFER_LOGLEVEL` environment variable (0=off, 1=basic, 3=detailed, 5=with stats)
Applied to files:
tests/mamba/__init__.py
flashinfer/__init__.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/**/*.py : Use `functools.cache` decorator on Python API functions to implement module-level caching and avoid recompilation
Applied to files:
tests/mamba/__init__.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/aot.py : Register new operations in `flashinfer/aot.py` by calling the `gen_*_module()` function for AOT (Ahead-Of-Time) pre-compilation support
Applied to files:
tests/mamba/__init__.py
flashinfer/jit/mamba/selective_state_update.py
flashinfer/mamba/selective_state_update.py
flashinfer/jit/mamba/__init__.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : Use `gen_jit_spec()` function to return a properly configured JitSpec from module generators with appropriate `sources` and `extra_cuda_cflags`
Applied to files:
flashinfer/jit/mamba/selective_state_update.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : JIT module generators in `flashinfer/jit/` must follow the pattern: compute URI → create directory → (optional) render Jinja template → copy sources → return JitSpec
Applied to files:
flashinfer/jit/mamba/selective_state_update.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes - no pip reinstall needed
Applied to files:
flashinfer/jit/mamba/selective_state_update.py
csrc/flashinfer_mamba_binding.cu
include/flashinfer/mamba/conversion.cuh
include/flashinfer/mamba/selective_state_update.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : Specify `supported_major_versions` in JitSpec to restrict kernel compilation to supported GPU architectures (e.g., SM versions 9, 10, 11, 12 for Hopper/newer)
Applied to files:
flashinfer/jit/mamba/selective_state_update.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*_jit_binding.cu : Create TVM-FFI bindings in files matching the pattern `csrc/*_jit_binding.cu` using the `TVM_FFI_DLL_EXPORT_TYPED_FUNC(name, func)` macro to expose C++ functions
Applied to files:
csrc/flashinfer_mamba_binding.cu
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*.cu : Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers
Applied to files:
csrc/flashinfer_mamba_binding.cu
include/flashinfer/mamba/create_tensor_map.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly
Applied to files:
include/flashinfer/mamba/conversion.cuh
include/flashinfer/mamba/selective_state_update.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Torch headers MUST NOT be included in files within the `include/` directory - keep framework-agnostic CUDA kernels that accept raw pointers
Applied to files:
include/flashinfer/mamba/conversion.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers
Applied to files:
include/flashinfer/mamba/selective_state_update.cuh
🧬 Code graph analysis (7)
flashinfer/jit/mamba/selective_state_update.py (1)
- flashinfer/jit/core.py (2): JitSpec (216-315), gen_jit_spec (318-384)

csrc/selective_state_update.cu (1)
- csrc/tvm_ffi_utils.h (2): get_stream (294-296), encode_dlpack_dtype (30-32)

csrc/flashinfer_mamba_binding.cu (2)
- csrc/selective_state_update.cu (2): selective_state_update (30-204), selective_state_update (30-33)
- flashinfer/mamba/selective_state_update.py (1): selective_state_update (38-104)

flashinfer/mamba/selective_state_update.py (5)
- flashinfer/api_logging.py (1): flashinfer_api (464-565)
- flashinfer/jit/mamba/selective_state_update.py (1): gen_selective_state_update_module (21-34)
- flashinfer/jit/core.py (1): build_and_load (303-315)
- csrc/flashinfer_mamba_binding.cu (1): selective_state_update (23-26)
- csrc/selective_state_update.cu (2): selective_state_update (30-204), selective_state_update (30-33)

flashinfer/jit/mamba/__init__.py (3)
- csrc/flashinfer_mamba_binding.cu (1): selective_state_update (23-26)
- csrc/selective_state_update.cu (2): selective_state_update (30-204), selective_state_update (30-33)
- flashinfer/jit/mamba/selective_state_update.py (1): gen_selective_state_update_module (21-34)

flashinfer/mamba/__init__.py (3)
- csrc/flashinfer_mamba_binding.cu (1): selective_state_update (23-26)
- csrc/selective_state_update.cu (2): selective_state_update (30-204), selective_state_update (30-33)
- flashinfer/mamba/selective_state_update.py (1): selective_state_update (38-104)

tests/mamba/test_selective_state_update.py (3)
- csrc/flashinfer_mamba_binding.cu (1): selective_state_update (23-26)
- flashinfer/mamba/selective_state_update.py (1): selective_state_update (38-104)
- flashinfer/logits_processor/types.py (1): size (132-136)
🪛 Ruff (0.14.10)
tests/mamba/selective_state_update_triton.py
70-70: Unused function argument: batch
(ARG001)
71-71: Unused function argument: nheads
(ARG001)
284-284: Do not assign a lambda expression, use a def
Rewrite grid as a def
(E731)
tests/mamba/test_selective_state_update.py
11-11: Unused function argument: matrixA_dtype
(ARG001)
🔇 Additional comments (17)
tests/mamba/__init__.py (1)
1-15: LGTM! Standard license header for the new test package.
include/flashinfer/mamba/conversion.cuh (1)
1-31: LGTM! Clean, framework-agnostic type conversion utilities. The design correctly keeps Torch headers out of include/ as per coding guidelines.
include/flashinfer/mamba/create_tensor_map.cuh (1)
44-86: LGTM with minor note. Well-structured TMA tensor map creation utility. The type mapping via an if constexpr chain with a static_assert fallback is idiomatic.
Note: The <iostream> include is only needed for the error paths. Consider using fprintf(stderr, ...) consistently for error output to avoid pulling in iostream overhead, though this is minor for a header used in kernel code.
flashinfer/__init__.py (1)
156-156: LGTM! Correctly exports the new mamba module as a public API, following the project's coding guidelines.
flashinfer/mamba/__init__.py (1)
1-19: LGTM! Clean package initialization that properly exports selective_state_update as the public API for the mamba module.
flashinfer/jit/mamba/selective_state_update.py (2)
22-25: No action needed. The flag naming in selective_state_update.py is correct and intentional. The mamba kernel headers (conversion.cuh) use short-form macros ENABLE_BF16 and ENABLE_FP8. The -DENABLE_BF16 and -DENABLE_FP8 flags defined in selective_state_update.py correctly match these expectations. The longer FLASHINFER_ENABLE_* variants from gen_jit_spec defaults are for other modules and safely coexist with the mamba-specific flags when both are compiled in via extra_cuda_cflags.
Likely an incorrect or invalid review comment.
21-34: The review comment is based on a coding guideline that does not match the actual implementation. The gen_jit_spec() function and JitSpec class do not have a supported_major_versions parameter.
The correct pattern used elsewhere in the codebase (e.g., flashinfer/jit/mla.py, flashinfer/jit/comm.py) is:

```python
nvcc_flags = current_compilation_context.get_nvcc_flags_list(
    supported_major_versions=[10, 11]
)
return gen_jit_spec(
    "module_name",
    sources,
    extra_cuda_cflags=nvcc_flags,
)
```

The current implementation in selective_state_update.py is valid as-is. If architecture-specific restrictions are needed for BF16/FP8 features, the module would need to follow the pattern above using current_compilation_context.get_nvcc_flags_list() rather than a non-existent supported_major_versions parameter.
Likely an incorrect or invalid review comment.
flashinfer/jit/mamba/__init__.py (1)
1-19: LGTM! The module initialization correctly imports and exports gen_selective_state_update_module, following the established pattern for JIT module generators in the codebase.
csrc/flashinfer_mamba_binding.cu (1)
1-34: LGTM! The TVM-FFI binding correctly declares the forward reference to the implementation in selective_state_update.cu and exports it using the standard macro. The signature matches the implementation, and the namespace organization follows project conventions. Based on learnings, this correctly implements framework bindings in csrc/ via TVM-FFI.
tests/mamba/test_selective_state_update.py (1)
100-209: LGTM! The test correctly validates both output tensors and updated state tensors against the Triton reference implementation. The detailed mismatch analysis with element-wise comparison will be valuable for debugging numerical issues.
csrc/selective_state_update.cu (2)
30-105: LGTM! The tensor validation is comprehensive, correctly checking shapes, strides, and broadcasting patterns for all inputs. The contiguity and stride requirements are well-documented with inline comments.
191-203: The throw std::runtime_error statement is not unreachable. Based on usage patterns in the codebase (e.g., csrc/tvm_ffi_utils.h), TVM_FFI_ICHECK(false) does not throw or abort execution: it logs the error message and allows execution to continue. The explicit throw is necessary to actually terminate execution with an exception when an unsupported dtype combination is encountered.
tests/mamba/selective_state_update_triton.py (2)
24-36: LGTM! The version-aware softplus implementation correctly handles Triton 3.x API differences. Using tl.math.log1p(tl.exp(dt)) for earlier versions is a good numerical stability choice.
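Why log1p helps can be seen with a small stdlib-only sketch, independent of the Triton kernel (the function names below are illustrative):

```python
import math


def softplus_naive(x):
    return math.log(1.0 + math.exp(x))


def softplus_log1p(x):
    return math.log1p(math.exp(x))


# For very negative x, softplus(x) ~= exp(x). The naive form rounds
# 1 + exp(x) to exactly 1.0 in float64 and returns 0.0, while log1p
# preserves the tiny value.
x = -40.0
print(softplus_naive(x))   # 0.0
print(softplus_log1p(x))   # ~4.25e-18
```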
42-208: LGTM! The Triton kernel implementation follows the reference Mamba implementation correctly. The heuristics for constexpr parameters, padding support via pad_slot_id, and optional tensor handling are well-structured.
flashinfer/mamba/selective_state_update.py (3)
27-34: LGTM! The module caching correctly uses @functools.cache as per coding guidelines, with a clear docstring explaining both in-memory and disk caching behavior.
37-104: LGTM! The public API correctly applies @flashinfer_api for debugging support. The docstring is comprehensive, documenting all parameters including optional tensors and their expected shapes. The implementation correctly creates the output tensor and delegates to the internal function.
107-159: LGTM! The custom op registration correctly specifies mutates_args=("state", "output") to indicate in-place mutation, which is essential for proper torch.compile() graph capture. The fake op implementation correctly returns None for meta tensor propagation.
make input tensors always have batch and nheads as dimensions. Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In @flashinfer/mamba/selective_state_update.py:
- Line 57: Update the docstring in selective_state_update.py to say the first
dimension of the state tensor is state_cache_size (not batch); clarify that the
state shape is (state_cache_size, dim, dstate) or (state_cache_size, nheads,
dim, dstate) and note that state_cache_size may differ from batch when using
state_batch_indices, matching the C++ usage where state.size(0) is assigned to
state_cache_size.
- Line 45: The parameter D is annotated as torch.Tensor but the code checks it
for None; change the function signature to use Optional[torch.Tensor] for D and
add "from typing import Optional" to imports, leaving the existing None check at
line 100 in place so callers passing None are handled correctly; ensure any
usages in functions/methods that consume D handle the Optional type (e.g.,
conditional logic already present).
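The suggested annotation change follows the usual Optional pattern. A minimal stdlib-only sketch (the function name and list-based arithmetic are hypothetical stand-ins for the real tensor API):

```python
from typing import Optional


def apply_skip(x: list, D: Optional[list] = None) -> list:
    """Toy stand-in for the D (skip-connection) handling in the real API."""
    out = [v * 2.0 for v in x]  # placeholder for the kernel output
    if D is not None:           # the runtime None check the annotation now matches
        out = [o + d * v for o, d, v in zip(out, D, x)]
    return out


print(apply_skip([1.0, 2.0]))                  # [2.0, 4.0]
print(apply_skip([1.0, 2.0], D=[0.5, 0.5]))    # [2.5, 5.0]
```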
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
flashinfer/mamba/selective_state_update.py
🧰 Additional context used
📓 Path-based instructions (1)
flashinfer/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
flashinfer/**/*.py: Use @functools.cache decorator on Python API functions to implement module-level caching and avoid recompilation
Use @flashinfer_api decorator for debugging API calls, enable via FLASHINFER_LOGLEVEL environment variable (0=off, 1=basic, 3=detailed, 5=with stats)
Files:
flashinfer/mamba/selective_state_update.py
🧠 Learnings (3)
📓 Common learnings
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : Specify `supported_major_versions` in JitSpec to restrict kernel compilation to supported GPU architectures (e.g., SM versions 9, 10, 11, 12 for Hopper/newer)
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*_jit_binding.cu : Create TVM-FFI bindings in files matching the pattern `csrc/*_jit_binding.cu` using the `TVM_FFI_DLL_EXPORT_TYPED_FUNC(name, func)` macro to expose C++ functions
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/aot.py : Register new operations in `flashinfer/aot.py` by calling the `gen_*_module()` function for AOT (Ahead-Of-Time) pre-compilation support
Applied to files:
flashinfer/mamba/selective_state_update.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/__init__.py : Export new operations in `flashinfer/__init__.py` to make them available as public API
Applied to files:
flashinfer/mamba/selective_state_update.py
🧬 Code graph analysis (1)
flashinfer/mamba/selective_state_update.py (4)
- flashinfer/jit/mamba/selective_state_update.py (1): gen_selective_state_update_module (21-34)
- flashinfer/jit/core.py (1): build_and_load (303-315)
- csrc/flashinfer_mamba_binding.cu (1): selective_state_update (23-26)
- csrc/selective_state_update.cu (2): selective_state_update (30-204), selective_state_update (30-33)
🔇 Additional comments (6)
flashinfer/mamba/selective_state_update.py (6)
17-24: LGTM! The imports are well-organized and include all necessary dependencies for JIT compilation, custom op registration, and API logging.
27-34: LGTM! The function correctly applies @functools.cache for module-level caching as per coding guidelines. The JIT system handles both in-memory and disk-based caching appropriately.
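The caching behavior described here can be sketched with a tiny stdlib example (get_module and build_count are illustrative, not the real flashinfer symbols):

```python
import functools

build_count = 0


@functools.cache
def get_module():
    """Stand-in for a JIT build-and-load step that should run once per process."""
    global build_count
    build_count += 1
    return {"name": "selective_state_update"}


m1 = get_module()
m2 = get_module()
print(build_count)  # 1: the second call hit the in-memory cache
print(m1 is m2)     # True: the same cached object is returned
```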
88-105: LGTM! The dimension handling logic correctly addresses the past review feedback by automatically unsqueezing tensors to match expected dimensions. This allows the API to accept both 2D/3D inputs (without nheads) and 3D/4D inputs (with nheads), improving usability. The C++ backend will validate the final shapes and provide clear error messages for incompatible dimensions.
106-122: LGTM! The output tensor is efficiently allocated using torch.empty_like(x), and all parameters are correctly passed to the internal operation. The mutation of output is properly declared in the custom op registration.
125-158: LGTM! The internal operation is correctly registered with mutates_args=("state", "output") to declare side effects. The implementation properly delegates to the JIT-compiled module for efficient execution.
161-178: LGTM! The fake operation is correctly implemented for torch.compile() support. Since this operation mutates existing tensors without changing their shapes, the empty pass implementation is appropriate for meta tensor propagation.
Actionable comments posted: 0
🧹 Nitpick comments (2)
csrc/selective_state_update.cu (2)
114-115: Use consistent validation macro and remove incomplete comment. Line 114 contains an incomplete comment (// if(z.has_value())), and line 115 uses an inconsistent pattern. For consistency with the rest of this file, use the standard FLASHINFER_CHECK macro:
♻️ Proposed refactoring

```diff
-  // if(z.has_value())
-  FLASHINFER_CHECK(!z.has_value() && "z is not supported yet");
+  FLASHINFER_CHECK(!z.has_value(), "z is not supported yet");
```
151-153: Remove dead code. Lines 151-153 are unreachable because line 115 already validates that `z` must not have a value (raising an error if it does). This conditional will never execute.
♻️ Proposed refactoring

```diff
  if (dt_bias) {
    p.dt_bias = dt_bias.value().data_ptr();
  }
- if (z) {
-   p.z = z.value().data_ptr();
- }
  p.A = A.data_ptr();
```
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
csrc/selective_state_update.cu
🧰 Additional context used
📓 Path-based instructions (1)
csrc/**/*.cu
📄 CodeRabbit inference engine (CLAUDE.md)
Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers
Files:
csrc/selective_state_update.cu
🧠 Learnings (2)
📓 Common learnings
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : Specify `supported_major_versions` in JitSpec to restrict kernel compilation to supported GPU architectures (e.g., SM versions 9, 10, 11, 12 for Hopper/newer)
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*_jit_binding.cu : Create TVM-FFI bindings in files matching the pattern `csrc/*_jit_binding.cu` using the `TVM_FFI_DLL_EXPORT_TYPED_FUNC(name, func)` macro to expose C++ functions
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
csrc/selective_state_update.cu
🧬 Code graph analysis (1)
csrc/selective_state_update.cu (1)
csrc/tvm_ffi_utils.h (2)
  get_stream (294-296), encode_dlpack_dtype (30-32)
🔇 Additional comments (4)
csrc/selective_state_update.cu (4)
16-23: LGTM! The includes and namespace setup follow the expected patterns for a CUDA implementation file with TVM-FFI bindings.
25-36: LGTM! The dimension extraction and divisibility check correctly establish the tensor geometry required for the selective state update operation.
159-161: LGTM! The device guard and stream retrieval correctly establish the CUDA execution context for the kernel launch.
179-179: The dtype code constants `bfloat16_code` and `float32_code` are properly defined in `csrc/tvm_ffi_utils.h` (lines 53-54) and correctly included in `selective_state_update.cu` via `#include "tvm_ffi_utils.h"`. No issues found.
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Actionable comments posted: 0
🧹 Nitpick comments (1)
flashinfer/mamba/selective_state_update.py (1)
46-47: Consider validating the unsupported `z` parameter early for clearer error messages. The C++ implementation explicitly rejects the `z` parameter with `FLASHINFER_CHECK(!z.has_value() && "z is not supported yet")`. While the error will be caught, users will receive a C++ exception rather than a clear Python-level error message. Consider adding early validation for better UX.
♻️ Proposed validation
Add this check before line 106:

```python
if z is not None:
    raise NotImplementedError("z parameter is not supported yet")
```

Alternatively, update the docstring to explicitly note that `z` is currently unsupported.
Also applies to: 116-116
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
flashinfer/mamba/selective_state_update.py
🧰 Additional context used
📓 Path-based instructions (1)
flashinfer/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
flashinfer/**/*.py: Use the `@functools.cache` decorator on Python API functions to implement module-level caching and avoid recompilation
Use the `@flashinfer_api` decorator for debugging API calls; enable via the `FLASHINFER_LOGLEVEL` environment variable (0=off, 1=basic, 3=detailed, 5=with stats)
Files:
flashinfer/mamba/selective_state_update.py
🧠 Learnings (3)
📓 Common learnings
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : Specify `supported_major_versions` in JitSpec to restrict kernel compilation to supported GPU architectures (e.g., SM versions 9, 10, 11, 12 for Hopper/newer)
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*_jit_binding.cu : Create TVM-FFI bindings in files matching the pattern `csrc/*_jit_binding.cu` using the `TVM_FFI_DLL_EXPORT_TYPED_FUNC(name, func)` macro to expose C++ functions
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/aot.py : Register new operations in `flashinfer/aot.py` by calling the `gen_*_module()` function for AOT (Ahead-Of-Time) pre-compilation support
Applied to files:
flashinfer/mamba/selective_state_update.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/__init__.py : Export new operations in `flashinfer/__init__.py` to make them available as public API
Applied to files:
flashinfer/mamba/selective_state_update.py
🧬 Code graph analysis (1)
flashinfer/mamba/selective_state_update.py (5)
flashinfer/api_logging.py (1)
  flashinfer_api (464-565)
flashinfer/jit/mamba/selective_state_update.py (1)
  gen_selective_state_update_module (21-34)
flashinfer/jit/core.py (1)
  build_and_load (303-315)
csrc/flashinfer_mamba_binding.cu (1)
  selective_state_update (23-26)
csrc/selective_state_update.cu (2)
  selective_state_update (25-197), selective_state_update (25-28)
🔇 Additional comments (4)
flashinfer/mamba/selective_state_update.py (4)
27-34: LGTM: Proper use of module-level caching. The `@functools.cache` decorator correctly implements module-level caching per the coding guidelines, ensuring the JIT module is built and loaded only once.
88-105: LGTM: Dimension expansion logic correctly implemented. The dimension expansion logic properly handles tensors with fewer dimensions, expanding them to match the C++ backend's expectations. This addresses the concern raised in previous reviews.
125-158: LGTM: Correct custom op registration for torch.compile() support. The custom op is properly registered with torch.library, correctly specifying `mutates_args=("state", "output")` and delegating to the JIT-compiled module.
161-178: LGTM: Correct fake op implementation for meta tensor propagation. The fake op correctly mirrors the custom op signature and uses a `pass` implementation, which is appropriate for torch.compile()'s shape and dtype inference without actual computation.
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In @include/flashinfer/mamba/selective_state_update.cuh:
- Around line 63-68: Add a brief comment above the __device__ __forceinline__
float fast_exp(float x) function that documents the performance/precision
tradeoff: note that it uses PTX ex2.approx.f32 for a faster but lower-precision
exp2 approximation (multiplying x by log2_E) and that this is chosen because
this is a performance-critical hot path; mention potential alternatives (e.g.,
using standard expf, a higher-precision polynomial/fast-math approximation, or a
compiler intrinsic) and when to switch to them for correctness over speed.
- Around line 491-495: The kernel launch currently checks cudaGetLastError() and
prints the error with printf but does not propagate it; replace the printf
branch with a call to FLASHINFER_CHECK(err) (using the existing cudaError_t err
from cudaGetLastError()) so the error is surfaced to callers; remove or replace
the printf("Kernel launch failed: %s\n", cudaGetErrorString(err)) with
FLASHINFER_CHECK(err) to ensure failures are propagated rather than silently
continuing.
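The first fix above concerns `fast_exp`, which rewrites exp(x) as 2^(x·log2 e) so it can use the PTX `ex2.approx.f32` instruction. A hedged pure-Python check of that algebra (done here in full precision; the hardware instruction additionally trades precision for speed):

```python
import math

LOG2_E = math.log2(math.e)  # the log2_E multiplier mentioned in the fix

def exp_via_exp2(x: float) -> float:
    """exp(x) rewritten as 2**(x * log2(e)), the same identity the
    ex2.approx.f32 fast path relies on."""
    return 2.0 ** (x * LOG2_E)

# The identity holds to within floating-point rounding for typical inputs.
for x in (-3.0, -0.5, 0.0, 1.0, 2.5):
    assert abs(exp_via_exp2(x) - math.exp(x)) < 1e-12 * max(1.0, math.exp(x))
```

On the GPU the remaining error comes from the reduced-precision `ex2.approx` itself, which is exactly the tradeoff the requested comment should document.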
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
include/flashinfer/mamba/selective_state_update.cuh
🧰 Additional context used
📓 Path-based instructions (1)
include/**/*.cuh
📄 CodeRabbit inference engine (CLAUDE.md)
include/**/*.cuh: Torch headers MUST NOT be included in files within the `include/` directory - keep framework-agnostic CUDA kernels that accept raw pointers
Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes - no pip reinstall needed
For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers
Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly
Files:
include/flashinfer/mamba/selective_state_update.cuh
🧠 Learnings (5)
📓 Common learnings
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : Specify `supported_major_versions` in JitSpec to restrict kernel compilation to supported GPU architectures (e.g., SM versions 9, 10, 11, 12 for Hopper/newer)
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*_jit_binding.cu : Create TVM-FFI bindings in files matching the pattern `csrc/*_jit_binding.cu` using the `TVM_FFI_DLL_EXPORT_TYPED_FUNC(name, func)` macro to expose C++ functions
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers
Applied to files:
include/flashinfer/mamba/selective_state_update.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes - no pip reinstall needed
Applied to files:
include/flashinfer/mamba/selective_state_update.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly
Applied to files:
include/flashinfer/mamba/selective_state_update.cuh
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
include/flashinfer/mamba/selective_state_update.cuh
🔇 Additional comments (8)
include/flashinfer/mamba/selective_state_update.cuh (8)
1-35: LGTM! The header follows coding guidelines by remaining framework-agnostic (no Torch dependencies) and includes appropriate CUDA cooperative groups, barriers, and flashinfer utilities.
36-56: LGTM! The parameter struct is well-documented with tensor shape annotations for each pointer, and provides sensible defaults for optional parameters.
102-111: Note: Limited dtype support. Currently only `__nv_bfloat16` is specialized for vectorized loads. If other dtype combinations (fp16, fp32) are expected, they'll need additional specializations.
113-236: LGTM with a known inconsistency. The simple kernel correctly implements vectorized state updates with warp-level parallelism and optional z-gating. The use of `fast_exp` at line 164 is consistent with the commit message indicating it was modified to use fast_exp.
Note: A past review flagged the `fast_exp` vs `__expf` inconsistency between kernels, which remains unresolved.
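For orientation, the per-element recurrence that both kernel variants implement can be sketched in plain Python. This is a simplified scalar reference for a single (batch, head, dim) element, with the optional softplus/dt_bias and z-gating steps omitted, so treat it as an illustration rather than the kernels' exact math:

```python
import math

def selective_state_update_ref(state, x, dt, A, B, C, D):
    """Simplified scalar reference: `state` holds dstate entries for one
    (batch, head, dim) element; B and C are per-dstate vectors."""
    dA = math.exp(A * dt)  # state decay factor
    new_state = [s * dA + b * (dt * x) for s, b in zip(state, B)]
    out = sum(s * c for s, c in zip(new_state, C)) + D * x
    return new_state, out

# With A = 0 the decay is exp(0) = 1, so the update is purely additive.
state, out = selective_state_update_ref(
    state=[1.0, 2.0], x=3.0, dt=0.5, A=0.0, B=[1.0, 1.0], C=[1.0, 0.0], D=2.0
)
assert state == [2.5, 3.5]
assert abs(out - 8.5) < 1e-9
```

The kernels vectorize exactly this update over the dstate axis and parallelize it over batch, heads, and dim.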
238-252: LGTM! The shared storage is correctly aligned for TMA (Tensor Memory Accelerator) and provides appropriate barrier synchronization for the multi-stage producer-consumer pattern.
254-450: LGTM! The producer-consumer kernel correctly implements a multi-stage vertical tiling pattern with TMA for Hopper+ GPUs. The barrier synchronization and async copy patterns are properly structured.
484-488: LGTM! The alignment check for TMA (lines 484-485) and the dimension divisibility check (line 488) are good defensive-programming practices that catch invalid configurations before kernel launch.
454-473: Hardcoded dimensions severely limit kernel usability. The launcher hardcodes `DSTATE=128` (line 454) and `DIM=64` (line 472) for the Hopper kernel path, causing failures for any other dimension values. This is a severe limitation that prevents the kernel from being used in general scenarios.
As noted in past reviews, consider making dimensions generic via template parameters or dispatch macros to support arbitrary `dstate` and `dim` values.
⛔ Skipped due to learnings
Learnt from: raayandhar Repo: flashinfer-ai/flashinfer PR: 2070 File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160 Timestamp: 2025-11-12T03:35:17.583Z Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In @include/flashinfer/mamba/selective_state_update.cuh:
- Line 409: Hoist the dA computation out of the innermost i-loop: compute dA =
__expf(A_value * dt_value) once after dt_value and A_value are set (before
entering the loops) instead of recomputing it inside the loop where the line
"auto const dA = __expf(A_value * dt_value);" currently resides; update any uses
inside the loop to reference this precomputed dA and remove the redundant
computation in the innermost loop.
🧹 Nitpick comments (1)
include/flashinfer/mamba/selective_state_update.cuh (1)
222-222: Consider simplifying the negation syntax. Using `__expf(-z_value)` is more idiomatic than `__expf(0.f - z_value)`.
♻️ Suggested simplification

```diff
- float sig_z = __fdividef(1.f, (1.f + __expf(0.f - z_value)));
+ float sig_z = __fdividef(1.f, (1.f + __expf(-z_value)));
```

Apply the same change at both locations (lines 222 and 435).
Also applies to: 435-435
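For reference, the `sig_z` expression being simplified is the sigmoid inside SiLU gating, out ← out · z · σ(z), applied when the optional z tensor is present. A hedged Python sketch of the same arithmetic (names are illustrative):

```python
import math

def sigmoid(z: float) -> float:
    # Matches __fdividef(1.f, 1.f + __expf(-z)) in the kernel, full precision here.
    return 1.0 / (1.0 + math.exp(-z))

def silu_gate(out: float, z: float) -> float:
    """Gate the output with SiLU(z) = z * sigmoid(z)."""
    return out * z * sigmoid(z)

assert silu_gate(4.0, 0.0) == 0.0                    # z = 0 zeroes the gate
assert abs(silu_gate(1.0, 100.0) - 100.0) < 1e-6     # large z: gate ≈ z
```

Whether the kernel computes `exp(-z)` or `exp(0 - z)` is purely stylistic; the value is identical, which is why the suggestion is a nitpick.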
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
include/flashinfer/mamba/selective_state_update.cuh
🧰 Additional context used
📓 Path-based instructions (1)
include/**/*.cuh
📄 CodeRabbit inference engine (CLAUDE.md)
include/**/*.cuh: Torch headers MUST NOT be included in files within the `include/` directory - keep framework-agnostic CUDA kernels that accept raw pointers
Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes - no pip reinstall needed
For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers
Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly
Files:
include/flashinfer/mamba/selective_state_update.cuh
🧠 Learnings (7)
📓 Common learnings
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : Specify `supported_major_versions` in JitSpec to restrict kernel compilation to supported GPU architectures (e.g., SM versions 9, 10, 11, 12 for Hopper/newer)
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*_jit_binding.cu : Create TVM-FFI bindings in files matching the pattern `csrc/*_jit_binding.cu` using the `TVM_FFI_DLL_EXPORT_TYPED_FUNC(name, func)` macro to expose C++ functions
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes - no pip reinstall needed
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers
Applied to files:
include/flashinfer/mamba/selective_state_update.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes - no pip reinstall needed
Applied to files:
include/flashinfer/mamba/selective_state_update.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly
Applied to files:
include/flashinfer/mamba/selective_state_update.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Torch headers MUST NOT be included in files within the `include/` directory - keep framework-agnostic CUDA kernels that accept raw pointers
Applied to files:
include/flashinfer/mamba/selective_state_update.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*.cu : Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers
Applied to files:
include/flashinfer/mamba/selective_state_update.cuh
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
include/flashinfer/mamba/selective_state_update.cuh
Removes an unnecessary `toFloat` conversion on `d_value` in the vertical producer-consumer kernel, as `d_value` is already a float type. Fixes: flashinfer-ai#2301 (comment)
Updates JIT flags to use generic compute_90 PTX for forward compatibility with SM100. Adds runtime dispatch logic to select the optimized SM90 kernel on devices with compute capability >= 9.0. Addresses flashinfer-ai#2301 (comment)
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In @flashinfer/jit/mamba/selective_state_update.py:
- Around line 36-52: The SM90-specific JIT spec in
gen_selective_state_update_sm90_module currently lacks an explicit
supported_major_versions restriction; update the gen_jit_spec(...) call in
gen_selective_state_update_sm90_module to include
supported_major_versions=[9,10] (Hopper and future SM100) so the module only
compiles for Hopper+ GPUs, and optionally change nvcc_flags construction to use
iterable unpacking (e.g., nvcc_flags = ["-gencode=...", "-DENABLE_BF16",
"-DFLASHINFER_MAMBA_ENABLE_SM90", *common_nvcc_flags]) for cleaner list
composition.
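The iterable-unpacking suggestion (Ruff RUF005) can be illustrated as follows; `common_nvcc_flags` and the flag values are placeholders, not the module's actual flags:

```python
common_nvcc_flags = ["-O3", "--use_fast_math"]  # illustrative placeholder

# List concatenation (the form Ruff RUF005 flags):
flags_concat = ["-DENABLE_BF16", "-DFLASHINFER_MAMBA_ENABLE_SM90"] + common_nvcc_flags

# Iterable unpacking (the suggested form):
flags_unpacked = ["-DENABLE_BF16", "-DFLASHINFER_MAMBA_ENABLE_SM90", *common_nvcc_flags]

# Both produce the same list; unpacking avoids an intermediate concatenation
# and reads as a single literal.
assert flags_concat == flags_unpacked
```

The two forms are behaviorally identical for lists, so this is a readability change only.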
🧹 Nitpick comments (1)
include/flashinfer/mamba/selective_state_update.cuh (1)
497-497: Add an error message to the dimension divisibility check. The assertion at line 497 should include a descriptive message to help users understand the constraint.
♻️ Proposed fix to add error message
```diff
- FLASHINFER_CHECK(params.dim % rowsPerStage == 0);
+ FLASHINFER_CHECK(params.dim % rowsPerStage == 0,
+                  "dim must be divisible by ", rowsPerStage, " for SM90+ kernel");
```
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
flashinfer/jit/mamba/selective_state_update.py
flashinfer/mamba/selective_state_update.py
include/flashinfer/mamba/selective_state_update.cuh
🧰 Additional context used
📓 Path-based instructions (3)
flashinfer/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
flashinfer/**/*.py: Use the `@functools.cache` decorator on Python API functions to implement module-level caching and avoid recompilation
Use the `@flashinfer_api` decorator for debugging API calls; enable via the `FLASHINFER_LOGLEVEL` environment variable (0=off, 1=basic, 3=detailed, 5=with stats)
Files:
flashinfer/jit/mamba/selective_state_update.pyflashinfer/mamba/selective_state_update.py
flashinfer/jit/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
flashinfer/jit/**/*.py: JIT module generators in `flashinfer/jit/` must follow the pattern: compute URI → create directory → (optional) render Jinja template → copy sources → return JitSpec
Use the `gen_jit_spec()` function to return a properly configured JitSpec from module generators with appropriate `sources` and `extra_cuda_cflags`
Specify `supported_major_versions` in JitSpec to restrict kernel compilation to supported GPU architectures (e.g., SM versions 9, 10, 11, 12 for Hopper/newer)
Files:
flashinfer/jit/mamba/selective_state_update.py
include/**/*.cuh
📄 CodeRabbit inference engine (CLAUDE.md)
include/**/*.cuh: Torch headers MUST NOT be included in files within the `include/` directory - keep framework-agnostic CUDA kernels that accept raw pointers
Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes - no pip reinstall needed
For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers
Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly
Files:
include/flashinfer/mamba/selective_state_update.cuh
🧠 Learnings (10)
📓 Common learnings
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : Specify `supported_major_versions` in JitSpec to restrict kernel compilation to supported GPU architectures (e.g., SM versions 9, 10, 11, 12 for Hopper/newer)
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*_jit_binding.cu : Create TVM-FFI bindings in files matching the pattern `csrc/*_jit_binding.cu` using the `TVM_FFI_DLL_EXPORT_TYPED_FUNC(name, func)` macro to expose C++ functions
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes - no pip reinstall needed
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : Use `gen_jit_spec()` function to return a properly configured JitSpec from module generators with appropriate `sources` and `extra_cuda_cflags`
Applied to files:
flashinfer/jit/mamba/selective_state_update.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : JIT module generators in `flashinfer/jit/` must follow the pattern: compute URI → create directory → (optional) render Jinja template → copy sources → return JitSpec
Applied to files:
flashinfer/jit/mamba/selective_state_update.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes - no pip reinstall needed
Applied to files:
flashinfer/jit/mamba/selective_state_update.py
include/flashinfer/mamba/selective_state_update.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : Specify `supported_major_versions` in JitSpec to restrict kernel compilation to supported GPU architectures (e.g., SM versions 9, 10, 11, 12 for Hopper/newer)
Applied to files:
flashinfer/jit/mamba/selective_state_update.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/aot.py : Register new operations in `flashinfer/aot.py` by calling the `gen_*_module()` function for AOT (Ahead-Of-Time) pre-compilation support
Applied to files:
flashinfer/jit/mamba/selective_state_update.py
flashinfer/mamba/selective_state_update.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/__init__.py : Export new operations in `flashinfer/__init__.py` to make them available as public API
Applied to files:
flashinfer/mamba/selective_state_update.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers
Applied to files:
include/flashinfer/mamba/selective_state_update.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly
Applied to files:
include/flashinfer/mamba/selective_state_update.cuh
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
include/flashinfer/mamba/selective_state_update.cuh
🧬 Code graph analysis (2)
flashinfer/jit/mamba/selective_state_update.py (1)
flashinfer/jit/core.py (2)
  JitSpec (216-397), gen_jit_spec (400-466)
flashinfer/mamba/selective_state_update.py (3)
flashinfer/utils.py (1)
  get_compute_capability (258-261)
csrc/selective_state_update.cu (2)
  selective_state_update (25-209), selective_state_update (25-28)
csrc/flashinfer_mamba_binding.cu (1)
  selective_state_update (23-26)
🪛 Ruff (0.14.10)
flashinfer/jit/mamba/selective_state_update.py
39-43: Consider iterable unpacking instead of concatenation
Replace with iterable unpacking
(RUF005)
🔇 Additional comments (12)
flashinfer/jit/mamba/selective_state_update.py (1)
21-33: LGTM! The base module configuration is correct. The base module uses appropriate flags for pre-Hopper GPUs and doesn't need architecture restrictions.
include/flashinfer/mamba/selective_state_update.cuh (7)
36-101: LGTM! The parameter struct and helper utilities are well-designed. The `SelectiveStateUpdateParams` struct efficiently packs all kernel parameters, and the helper functions follow standard CUDA patterns for warp-level operations and type conversions.
103-226: LGTM! The simple kernel implementation is correct. The pre-Hopper kernel uses appropriate shared memory, vectorized loads, and warp-level synchronization. The handling of pad_slot_id for skipping padded entries is correct.
228-296: LGTM! The producer-consumer synchronization setup is correct. The multi-stage shared storage with CUDA barriers is properly configured. The 128-byte alignment for TMA operations and barrier initialization with correct arrival counts follow SM90+ best practices.
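The staged handoff reviewed above can be mimicked on the CPU with a ring of `numStages` slots guarded by semaphores. This is only a loose analogy for the SMEM stages and arrival barriers (it does not model TMA or warp specialization), and all names are illustrative:

```python
import threading

NUM_STAGES = 3          # ring of stage buffers, like the kernel's numStages
NUM_ITEMS = 10
stages = [None] * NUM_STAGES
empty = threading.Semaphore(NUM_STAGES)   # producer waits for a free stage
full = threading.Semaphore(0)             # consumer waits for a filled stage
results = []

def producer():
    for i in range(NUM_ITEMS):
        empty.acquire()                   # like waiting on the stage's "consumed" barrier
        stages[i % NUM_STAGES] = i        # stand-in for the async load into the stage
        full.release()                    # like arriving on the "produced" barrier

def consumer():
    for i in range(NUM_ITEMS):
        full.acquire()
        results.append(stages[i % NUM_STAGES] * 2)  # stand-in for the compute step
        empty.release()

t1 = threading.Thread(target=producer)
t2 = threading.Thread(target=consumer)
t1.start()
t2.start()
t1.join()
t2.join()
assert results == [2 * i for i in range(NUM_ITEMS)]
```

The semaphore pair keeps the producer at most `NUM_STAGES` items ahead, which is the same bounded-pipeline property the kernel's per-stage barriers enforce.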
299-335: LGTM! The producer warp's TMA pipeline is correctly implemented. The multi-stage TMA pipeline with barrier synchronization follows the correct pattern for overlapping loads and stores. The handling of pad_slot_id and barrier arrival tracking is correct.
336-436: LGTM! The consumer warp computation is correctly synchronized. The consumer warps properly coordinate with the producer through barriers, and the computation logic matches the simple kernel. The parallel loading of inputs across warps is an effective optimization.
440-488: LGTM! The pre-Hopper launcher path has comprehensive alignment checks. The alignment validation for vectorized loads is thorough, checking pointers, strides, and sizes. The dispatch logic correctly handles supported dstate values.
490-544: LGTM! The SM90+ launcher path has proper TMA setup. This path correctly creates the tensor map for TMA operations and dispatches based on dim and dstate. The 128-byte alignment requirement is properly enforced.
flashinfer/mamba/selective_state_update.py (4)
30-46: LGTM! Module loading with proper caching and device-aware selection.The use of
@functools.cachefollows coding guidelines for module-level caching, and the compute capability check correctly selects between base and SM90 modules.
49-134: LGTM! Public API with proper dimension normalization. The function correctly handles both 2D (no head dimension) and 3D (with head dimension) inputs by normalizing to 3D/4D shapes. The use of the `@flashinfer_api` decorator enables debug logging per coding guidelines.
137-170: LGTM! Custom op registration for torch.compile support. The registration correctly declares mutation of the `state` and `output` tensors, and properly dispatches to the device-appropriate JIT module.
173-190: LGTM! Fake op implementation for meta tensor support. The fake op correctly provides a no-op implementation for torch.compile meta tensor propagation, following PyTorch's custom op pattern.
Include the required divisor and kernel variant in the error message for easier debugging when dim is not properly aligned. Address: flashinfer-ai#2301 (review)
This enables architecture filtering via supported_major_versions and removes reliance on hardcoded nvcc flags. Future architectures are now supported automatically. address flashinfer-ai#2301 (comment)
Actionable comments posted: 0
🧹 Nitpick comments (3)
include/flashinfer/mamba/selective_state_update.cuh (3)
16-35: Consider adding high-level documentation. This header lacks documentation explaining what selective state update does, the mathematical formulation, and how the two kernel variants differ (simple vs. producer-consumer). High-level comments would help future reviewers understand the algorithm at a glance.
Based on coding guidelines, performance-critical hot paths benefit from comments explaining special algorithmic choices.
244-438: Consider adding more inline comments for the producer-consumer algorithm. This kernel implements a sophisticated multi-stage producer-consumer pipeline with TMA. While the implementation appears correct, the complex synchronization logic (barriers, staging, TMA operations) would benefit from more inline comments explaining:
- Why `rowsPerStage = 4 * numConsumers` and `numStages = 3` were chosen
- The producer's writeback-then-load pattern (lines 310-329)
- The barrier arrival semantics and byte counting for TMA
Based on coding guidelines, performance-critical hot paths should document special algorithmic choices for future reviewers.
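The staging arithmetic behind such a pipeline can be sketched on the CPU. This toy model is illustrative only, not the CUDA kernel: a bounded ring of in-flight stage buffers forces the producer to wait (here, the consumer drains the oldest stage) before a buffer is reused, which is exactly what the barriers enforce on the GPU. The constants mirror those discussed above:

```python
from collections import deque

NUM_STAGES = 3       # number of in-flight stage buffers (ring size)
ROWS_PER_STAGE = 4   # rows handed off per stage

def pipeline(total_rows):
    """Simulate a producer filling stage buffers and a consumer draining
    them; at most NUM_STAGES stages are in flight at any time."""
    in_flight = deque()
    consumed = []
    for row_start in range(0, total_rows, ROWS_PER_STAGE):
        if len(in_flight) == NUM_STAGES:
            # Producer "waits on the barrier": oldest stage must be
            # drained by the consumer before its buffer is reused.
            consumed.extend(in_flight.popleft())
        end = min(row_start + ROWS_PER_STAGE, total_rows)
        in_flight.append(list(range(row_start, end)))
    while in_flight:  # drain the tail
        consumed.extend(in_flight.popleft())
    return consumed

print(pipeline(10))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```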
489-544: Note: Limited dim support in SM90+ path. The SM90+ kernel only supports `dim=64` or `dim=128` (lines 533-542), while the pre-Hopper kernel supports any dim that fits in the grid. This is a conscious design trade-off for compile-time optimization, but it may limit flexibility. Consider documenting this constraint in the function/struct documentation, or adding support for additional dim values if needed. The restriction is enforced with a clear error message, but users might benefit from knowing the supported values upfront.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- flashinfer/jit/mamba/selective_state_update.py
- include/flashinfer/mamba/selective_state_update.cuh
🚧 Files skipped from review as they are similar to previous changes (1)
- flashinfer/jit/mamba/selective_state_update.py
🧰 Additional context used
📓 Path-based instructions (1)
include/**/*.cuh
📄 CodeRabbit inference engine (CLAUDE.md)
include/**/*.cuh: Torch headers MUST NOT be included in files within the `include/` directory - keep framework-agnostic CUDA kernels that accept raw pointers
Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes - no pip reinstall needed
For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers
Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly
Files:
include/flashinfer/mamba/selective_state_update.cuh
🧠 Learnings (6)
📓 Common learnings
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : Specify `supported_major_versions` in JitSpec to restrict kernel compilation to supported GPU architectures (e.g., SM versions 9, 10, 11, 12 for Hopper/newer)
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*_jit_binding.cu : Create TVM-FFI bindings in files matching the pattern `csrc/*_jit_binding.cu` using the `TVM_FFI_DLL_EXPORT_TYPED_FUNC(name, func)` macro to expose C++ functions
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes - no pip reinstall needed
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers
Applied to files:
include/flashinfer/mamba/selective_state_update.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly
Applied to files:
include/flashinfer/mamba/selective_state_update.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes - no pip reinstall needed
Applied to files:
include/flashinfer/mamba/selective_state_update.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Torch headers MUST NOT be included in files within the `include/` directory - keep framework-agnostic CUDA kernels that accept raw pointers
Applied to files:
include/flashinfer/mamba/selective_state_update.cuh
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
include/flashinfer/mamba/selective_state_update.cuh
🔇 Additional comments (7)
include/flashinfer/mamba/selective_state_update.cuh (7)
36-56: LGTM! The parameter struct is well-designed with clear shape documentation and framework-agnostic void pointers. Default initialization prevents uninitialized usage.
58-90: LGTM! The utility functions are correctly implemented. `thresholded_softplus` properly guards against overflow in `softplus`, and the warp reduction follows the standard pattern.
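The overflow guard in a thresholded softplus is easy to illustrate in plain Python. The threshold value below is an assumption for the sketch, not necessarily the one the kernel uses:

```python
import math

def thresholded_softplus(x: float, threshold: float = 20.0) -> float:
    """Softplus log(1 + exp(x)) with an overflow guard: for large x,
    exp(x) would overflow, but softplus(x) ~= x there, so return x
    directly. The 20.0 threshold is an illustrative choice."""
    if x > threshold:
        return x
    return math.log1p(math.exp(x))

print(thresholded_softplus(0.0))     # ln(2) ~= 0.6931
print(thresholded_softplus(1000.0))  # 1000.0, no overflow
```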
92-101: LGTM! The vectorized load traits are correctly defined for bfloat16. The chunk_size calculation is accurate: `sizeof(input) / sizeof(__nv_bfloat16)` = `sizeof(float2) / sizeof(__nv_bfloat16)` = 8 / 2 = 4 elements.
Note: Future dtype support will require additional trait specializations.
103-226: LGTM! The simple kernel implementation is correct:
- Proper bounds checking and pad slot handling
- Correct state update math: `new_state = state * exp(A*dt) + B*dt*x`
- Efficient vectorized loads and warp reduction
- Correct SiLU gating when z is present
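The recurrence listed above can be written as a small NumPy reference. Shapes here are a simplified single-sequence, single-head layout chosen for illustration; the real kernel handles batched, headed, strided tensors:

```python
import numpy as np

def selective_state_update_ref(state, x, dt, A, B, C, D, z=None):
    """Sketch of the per-step recurrence: new_state = state * exp(A*dt)
    + B*dt*x, then y = sum(state * C) + D*x, with optional SiLU gating.
    Shapes: state, A (dim, dstate); x, dt, D, z (dim,); B, C (dstate,)."""
    dA = np.exp(A * dt[:, None])                   # (dim, dstate)
    dB = B[None, :] * dt[:, None]                  # (dim, dstate)
    state[:] = state * dA + dB * x[:, None]        # in-place state update
    y = (state * C[None, :]).sum(axis=-1) + D * x  # contract over dstate
    if z is not None:
        y = y * (z / (1.0 + np.exp(-z)))           # SiLU gating
    return y

dim, dstate = 4, 8
rng = np.random.default_rng(0)
state = rng.standard_normal((dim, dstate))
y = selective_state_update_ref(
    state,
    x=rng.standard_normal(dim),
    dt=rng.random(dim),
    A=-rng.random((dim, dstate)),
    B=rng.standard_normal(dstate),
    C=rng.standard_normal(dstate),
    D=rng.standard_normal(dim),
)
print(y.shape)  # (4,)
```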
228-242: LGTM! The shared storage is correctly designed for producer-consumer synchronization with proper 128-byte alignment for TMA operations.
451-464: LGTM! The alignment checks are thorough and provide helpful error messages indicating the required alignment and kernel variant, exactly as mentioned in the PR objectives.
440-545: Verify the supported dim and dstate values meet current requirements. The launcher restricts:
- SM90+ path: dim ∈ {64, 128}, dstate ∈ {64, 128, 256}
- Pre-Hopper path: dstate ∈ {64, 128, 256}, any dim
Confirm these cover all Nemotron/Mamba model configurations currently in use.
If additional configurations are needed in the future, the switch statements at lines 475-487 and 518-542 will need extension.
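The dispatch constraint can be modeled as a small validation helper. The function and kernel names below are hypothetical; only the supported `(dim, dstate)` sets come from the review:

```python
# Mirrors the launcher's runtime check against compile-time kernel
# instantiations: only certain (dim, dstate) pairs exist on each path.
SM90_DIMS = {64, 128}
SUPPORTED_DSTATES = {64, 128, 256}

def select_kernel(dim: int, dstate: int, sm90: bool) -> str:
    if dstate not in SUPPORTED_DSTATES:
        raise ValueError(
            f"unsupported dstate={dstate}; expected one of {sorted(SUPPORTED_DSTATES)}")
    if sm90:
        if dim not in SM90_DIMS:
            raise ValueError(
                f"SM90 path supports dim in {sorted(SM90_DIMS)}, got {dim}")
        return f"producer_consumer_kernel<{dim}, {dstate}>"
    return f"simple_kernel<{dstate}>"  # pre-Hopper: any dim

print(select_kernel(128, 128, sm90=True))   # producer_consumer_kernel<128, 128>
print(select_kernel(96, 64, sm90=False))    # simple_kernel<64>
```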
The latter can display the number and location of mismatched elements.
Actionable comments posted: 0
🧹 Nitpick comments (4)
tests/mamba/test_selective_state_update.py (4)
1-7: Missing GPU architecture skip logic. Per coding guidelines, test implementations should use `flashinfer.utils` functions to skip tests on unsupported GPU architectures. The selective_state_update kernel has SM90-optimized paths and requires specific GPU support.
Consider adding architecture checks:
```python
import pytest

import flashinfer
from flashinfer.utils import get_compute_capability

# At the start of each test or as a module-level skip
@pytest.fixture(autouse=True)
def skip_unsupported_gpu():
    cc = get_compute_capability()
    if cc < (8, 0):
        pytest.skip("selective_state_update requires SM80+")
```
85-93: Consider expanding parameter coverage. The test only covers `delta_softplus=True` and `ngroups=8`. Consider adding:
- `delta_softplus=[True, False]` to test both code paths
- Additional `ngroups` values to verify head grouping logic
150-185: Consider using `torch.testing.assert_allclose` for consistency. The `test_selective_state_update_with_z` function uses `torch.testing.assert_allclose` directly, which provides better error messages. Consider using the same approach here instead of manual `torch.allclose` + print statements for consistency across tests.
♻️ Suggested refactor
```diff
-    atol = 1e-3
-    rtol = 1e-2
-    outputs_match = torch.allclose(y_ref, y_test, atol=atol, rtol=rtol)
-
-    if outputs_match:
-        print(f"✓ Outputs match within tolerance (atol={atol}, rtol={rtol})")
-    else:
-        print(f"✗ Outputs do NOT match within tolerance (atol={atol}, rtol={rtol})")
-
-    # Detailed comparison using numpy testing
-    y_ref_np = y_ref.detach().cpu().float().numpy()
-    y_test_np = y_test.detach().cpu().float().numpy()
-    # ... detailed mismatch analysis ...
-
-    assert outputs_match
+    atol = 1e-3
+    rtol = 1e-2
+    torch.testing.assert_allclose(y_ref, y_test, atol=atol, rtol=rtol)
```
188-227: Same suggestion: use `torch.testing.assert_allclose` for state comparison. Apply the same simplification for the state tensor comparison to match the pattern in `test_selective_state_update_with_z`.
♻️ Suggested refactor
```diff
-    # Check if states match within tolerance
-    states_match = torch.allclose(
-        state_ref[inputs["slot_idx"]],
-        state[inputs["slot_idx"]],
-        atol=atol,
-        rtol=rtol,
-    )
-
-    if states_match:
-        print(f"✓ States match within tolerance (atol={atol}, rtol={rtol})")
-    else:
-        # ... detailed mismatch analysis ...
-
-    assert states_match
+    torch.testing.assert_allclose(
+        state_ref[inputs["slot_idx"]],
+        state[inputs["slot_idx"]],
+        atol=atol,
+        rtol=rtol,
+    )
```
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tests/mamba/test_selective_state_update.py
🧰 Additional context used
📓 Path-based instructions (1)
tests/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
tests/**/*.py: Test implementations should use `flashinfer.utils` functions (`get_compute_capability`, `is_sm90a_supported`, `is_sm100a_supported`, etc.) to skip tests on unsupported GPU architectures
For testing with `mpirun` on multi-GPU systems, use the pattern: `mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function`
Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes - `tests/conftest.py` provides auto-skipping for OOM tests as a safety net but should not be relied upon
Files:
tests/mamba/test_selective_state_update.py
🧠 Learnings (1)
📓 Common learnings
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*.cu : Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers
🧬 Code graph analysis (1)
tests/mamba/test_selective_state_update.py (4)
tests/mamba/selective_state_update_triton.py (1)
- `selective_state_update_triton` (211-356)

flashinfer/mamba/selective_state_update.py (1)
- `selective_state_update` (50-134)

csrc/selective_state_update.cu (2)
- `selective_state_update` (25-209)
- `selective_state_update` (25-28)

csrc/flashinfer_mamba_binding.cu (1)
- `selective_state_update` (23-26)
🔇 Additional comments (2)
tests/mamba/test_selective_state_update.py (2)
10-82: LGTM! The test input generator is well-structured with:
- Reproducible seeding
- Correct strided tensor creation matching kernel expectations
- Proper shape and dtype handling

The use of `torch.rand` (uniform) for A instead of `torch.randn` (normal) appears intentional for numerical stability, aligning with the commit note about "choice of the A tensor."
230-294: LGTM! Clean test implementation that:
- Covers the `z` tensor code path specifically
- Uses `torch.testing.assert_allclose` for concise assertions with informative error messages
- Tests both output and state tensor correctness
Actionable comments posted: 0
🧹 Nitpick comments (1)
include/flashinfer/mamba/conversion.cuh (1)
7-9: Unused FP8 include — consider removing or implementing. `cuda_fp8.h` is conditionally included but no FP8 conversion functions (`toFloat(__nv_fp8_e4m3)`, `convertAndStore(__nv_fp8_e4m3*, float)`, etc.) are provided. This appears to be dead code.
Additionally, the macro name `ENABLE_FP8` is inconsistent with `FLASHINFER_ENABLE_BF16` — consider using `FLASHINFER_ENABLE_FP8` for consistency.
♻️ Suggested fix: remove unused include or add FP8 functions
If FP8 support is not needed now:

```diff
-#ifdef ENABLE_FP8
-#include <cuda_fp8.h>
-#endif
```

Or rename for consistency:

```diff
-#ifdef ENABLE_FP8
+#ifdef FLASHINFER_ENABLE_FP8
 #include <cuda_fp8.h>
 #endif
```
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
include/flashinfer/mamba/conversion.cuh
🧰 Additional context used
📓 Path-based instructions (1)
include/**/*.cuh
📄 CodeRabbit inference engine (CLAUDE.md)
include/**/*.cuh: Torch headers MUST NOT be included in files within the `include/` directory - keep framework-agnostic CUDA kernels that accept raw pointers
Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes - no pip reinstall needed
For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers
Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly
Files:
include/flashinfer/mamba/conversion.cuh
🧠 Learnings (3)
📓 Common learnings
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*.cu : Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes - no pip reinstall needed
Applied to files:
include/flashinfer/mamba/conversion.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Torch headers MUST NOT be included in files within the `include/` directory - keep framework-agnostic CUDA kernels that accept raw pointers
Applied to files:
include/flashinfer/mamba/conversion.cuh
🔇 Additional comments (2)
include/flashinfer/mamba/conversion.cuh (2)
11-19: LGTM! The `toFloat` overloads correctly use CUDA intrinsics (`__half2float`, `__bfloat162float`) and the BF16 variant is properly guarded. The identity passthrough for `float` is appropriate.

21-31: LGTM! The `convertAndStore` overloads correctly use CUDA intrinsics for float-to-half and float-to-bfloat16 conversions. The BF16 variant is properly guarded with `FLASHINFER_ENABLE_BF16`.
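The narrow-on-store / widen-on-load pattern these overloads implement can be mimicked on the CPU. NumPy has no bfloat16, so float16 stands in here purely to illustrate the rounding a narrowing store introduces; the function names echo the header but this is an illustrative analog, not the CUDA code:

```python
import numpy as np

def convert_and_store(buf: np.ndarray, i: int, value: float) -> None:
    buf[i] = np.float16(value)  # narrowing store: rounds to nearest float16

def to_float(buf: np.ndarray, i: int) -> float:
    return float(buf[i])        # widening load is exact

buf = np.zeros(4, dtype=np.float16)
convert_and_store(buf, 0, 0.1)
print(to_float(buf, 0))  # close to 0.1 but not exact: 0.1 is not representable in 16 bits
```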
/bot run
yzh119
left a comment
LGTM, thanks for this contribution!
I'll have a follow up soon with more perf for Blackwell!
@ishovkun sounds great! Do you expect to land it in this PR, or in a follow up one?
[FAILED] Pipeline #41601989: 17/20 passed
<!-- .github/pull_request_template.md --> ## 📌 Description This PR brings support for various dtypes for the state parameter of `selective_state_update` (mamba). Specifically, this supports the following state types: - fp32 - bf16 - fp16. The data types of other parameters are still fixed. Additionally, there are performance improvements over the current version, see [current version](#2301): <img width="3000" height="4500" alt="runtime_vs_batch_size_NVIDIA_B200" src="https://github.com/user-attachments/assets/21eb7fe0-1655-4f6b-aa62-b88d1d8505e8" /> ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Refactor** * Broadened numeric-type support and dispatch for selective state updates; reworked kernels, shared buffering, alignment and vectorized loads to improve warp utilization and DIM/DSTATE-driven execution; added a staged producer–consumer path and clearer error messaging for unsupported combinations. * **Tests** * Added explicit state-type test coverage, introduced a state-type test parameter, and switched to stricter comparisons for more precise failure diagnostics. 
<sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Igor Shovkun <ishovkun@viking-dvt-151.nvidia.com> Co-authored-by: Igor Shovkun <ishovkun@viking-cr-196.nvidia.com> Co-authored-by: Igor Shovkun <ishovkun@viking-prod-258.nvidia.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: yzh119 <zihaoy@nvidia.com>
…tionally exposed to the user. (#2591) <!-- .github/pull_request_template.md --> ## 📌 Description This PR does several things: - Improves automatic kernel selection based on the arch, state_dtype, and the batch size (see image below). - Slightly improves performance at small batch sizes by launching several CTAs per tile. - Adds jinja templates for the `selective_state_update` function (jit is fast now) - Reduces the number of meaningless parameter combinations in the tests (test are still fast) ## Background This PR changes changes the behavior of the function. Now, an optional string `algorithm` can be passed to the kernel. The default value 'auto' allows the user not to think about the internals of the function. Optionally, the user can specify the kernel that they want. This adjustment allowed me to make use of the recent mamba benchmarks. The sweep is shown below: <!-- Link any related issues here --> <img width="3600" height="3000" alt="ssu_speedup_vs_batch_size_NVIDIA_B200" src="https://github.com/user-attachments/assets/11325924-64e2-48e7-8fee-7244cdd7a893" /> One can see, that the new benchmark now correctly shows the speed difference between the reference Triton and the current implementation as opposed to my previous [PR](#2301). Clearly, I previously messed up the measurements at small batch sizes. ## Kernel Selection This PR improves the kernel selection in the following ways: - If the problem size is too small, use the `simple` algorithm with several CTAs per tile. - If on Blackwell, use the `horizontal` algorithm only for bf16/fp16 states, else fall back to the `vertical`. ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. 
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes Please check how I handled jinja templates as it's my first time using those. Also, please check whether I accidentally deleted any important tests. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Runtime-selectable algorithm for selective state update: auto (default), simple, vertical, horizontal. * **Bug Fixes** * Added runtime validation to ensure index/dtype consistency across execution paths. * **Chores** * JIT/module generation reworked to produce specialized builds per dtype/dimension and target architectures. * Public API unified to select appropriate compiled module based on device and data. * **Tests** * Expanded and parameterized tests covering algorithms, dtypes, tiling, intermediate states, and large batches. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
📌 Description
This PR implements `selective_state_update` kernels for the Mamba layer (e.g. in the Nemotron model). This implementation is necessary as the default Triton implementation in TRT-LLM is inefficient. Specifically, the PR implements two kernels:
The speedup yield of the CUDA code is summarized in the image below.

🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Tests
Chores