
Selective State Update kernel (mamba)#2301

Merged
yzh119 merged 41 commits into flashinfer-ai:main from ishovkun:main
Jan 13, 2026

Conversation


@ishovkun ishovkun commented Jan 6, 2026

📌 Description

This PR implements selective_state_update kernels for the Mamba layer (used, e.g., in the Nemotron model). A native CUDA implementation is needed because the default Triton implementation in TRT-LLM is inefficient.

Specifically, the PR implements two kernels:

  • a simple pre-Hopper kernel with vectorized loads
  • a producer-consumer kernel for Hopper+ GPUs.
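For reference, both kernels compute the standard Mamba single-decode-step recurrence. Below is a minimal NumPy sketch under assumed shapes (names and layouts mirror the Triton reference implementation; this is an illustration, not the kernel code):

```python
import numpy as np

def softplus(v):
    return np.log1p(np.exp(v))

def selective_state_update_ref(state, x, dt, A, B, C, D=None,
                               dt_bias=None, dt_softplus=False):
    """One decode step: state <- state*exp(dt*A) + dt*B*x; out = C.state + D*x.
    state: (batch, nheads, dim, dstate); x, dt: (batch, nheads, dim)
    A: (nheads, dim, dstate); B, C: (batch, ngroups, dstate); D: (nheads, dim)
    """
    nheads, ngroups = state.shape[1], B.shape[1]
    if dt_bias is not None:
        dt = dt + dt_bias                          # dt_bias: (nheads, dim)
    if dt_softplus:
        dt = softplus(dt)
    dA = np.exp(dt[..., None] * A)                 # (batch, nheads, dim, dstate)
    Bh = np.repeat(B, nheads // ngroups, axis=1)   # expand groups to heads
    Ch = np.repeat(C, nheads // ngroups, axis=1)
    state = state * dA + dt[..., None] * Bh[:, :, None, :] * x[..., None]
    out = np.einsum("bhds,bhs->bhd", state, Ch)
    if D is not None:
        out = out + D * x
    return out, state
```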

The speedup of the CUDA kernels over the Triton baseline is summarized in the screenshot below.
[Screenshot: speedup summary, Jan 6, 2026]

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or using your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Public Mamba selective-state-update API: JIT-built CUDA backend with BF16 and SM90 variants, runtime backend selection/caching, optional dt-bias and z-scaling, state-batch indexing, dt-softplus, pad-slot handling, and multi-head batching with optimized device kernels.
  • Tests

    • End-to-end pytest suite validating correctness against a Triton reference across dtypes, batching, and devices.
  • Chores

    • Mamba package exposed at top level and integrated into AOT/JIT build flow.



coderabbitai bot commented Jan 6, 2026

📝 Walkthrough

Walkthrough

Adds a new selective_state_update feature: CUDA headers/kernels (simple + SM90 CUtensorMap path), a host launcher and TVM FFI binding, Python JIT glue (base + SM90), a Triton reference implementation, pytest validations, and AOT/JIT integration.

Changes

Cohort / File(s) Summary
CUDA headers & kernels
include/flashinfer/mamba/selective_state_update.cuh, include/flashinfer/mamba/create_tensor_map.cuh, include/flashinfer/mamba/conversion.cuh
New SelectiveStateUpdateParams, device helpers (softplus/exp/zero), BF16/FP16 conversion utilities, vectorized load traits, two kernel paths (simple and producer‑consumer vertical), SharedStorage, CUtensorMap/TMA utilities, alignment/capability checks, and host-side launcher templates.
CUDA launcher & FFI binding
csrc/selective_state_update.cu, csrc/flashinfer_mamba_binding.cu
New host selective_state_update implementation with input/stride validation, param packing, optional-input checks, dtype-key dispatch (BF16+floatA supported), device/stream setup, and export via TVM_FFI_DLL_EXPORT_TYPED_FUNC.
Python runtime & API
flashinfer/mamba/selective_state_update.py, flashinfer/mamba/__init__.py, flashinfer/__init__.py
JIT module loaders (base & SM90), runtime selection by device compute capability, public selective_state_update wrapper (shape normalization, output allocation), registered internal op and fake meta-op for torch.compile, and package exports.
JIT spec & build config
flashinfer/jit/mamba/selective_state_update.py, flashinfer/jit/mamba/__init__.py
JIT spec generators referencing selective_state_update.cu and flashinfer_mamba_binding.cu; SM90 variant adds nvcc flags via CompilationContext and defines -DFLASHINFER_MAMBA_ENABLE_SM90.
Triton reference & tests
tests/mamba/selective_state_update_triton.py, tests/mamba/test_selective_state_update.py, tests/mamba/__init__.py
New Triton kernel + Python wrapper as reference; pytest compares Triton vs CUDA outputs across shapes/dtypes and validates state updates; test module license header added.
AOT integration
flashinfer/aot.py
Imports selective_state_update JIT spec generators and appends the spec(s) to jit_specs in gen_all_modules (conditionally adds SM90 spec when available).
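The capability-keyed backend selection and caching described above can be sketched as follows (a hypothetical illustration; the real loader builds and caches a JitSpec-backed module via build_and_load(), not a string):

```python
import functools

# Hypothetical sketch: pick a kernel backend from the device compute
# capability and cache the choice, mirroring the runtime-selection pattern.
@functools.cache
def select_backend(compute_capability: tuple) -> str:
    if compute_capability >= (9, 0):   # Hopper+: TMA producer-consumer path
        return "sm90_producer_consumer"
    return "simple_vectorized"         # pre-Hopper: vectorized-load kernel
```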

Sequence Diagram(s)

sequenceDiagram
    participant User as Python User
    participant API as flashinfer.mamba API
    participant JIT as JIT Compiler/Loader
    participant Binding as TVM FFI Binding
    participant Launcher as Host Launcher
    participant GPU as GPU Device

    User->>API: selective_state_update(state, x, dt, A, B, C, D, ...)
    API->>JIT: get_selective_state_update_module(device)
    alt cached
        JIT-->>API: module
    else build
        JIT->>JIT: gen_jit_spec() -> nvcc compile/link
        JIT-->>API: build_and_load()
    end
    API->>Binding: prepare tensors & call binding
    Binding->>Launcher: selective_state_update(...)
    Launcher->>Launcher: validate inputs, pack SelectiveStateUpdateParams
    Launcher->>Launcher: dtype-key dispatch → choose kernel specialization
    opt advanced_path_requires_TMA
        Launcher->>GPU: cuDriver create CUtensorMap
    end
    Launcher->>GPU: launch kernel (simple or producer-consumer)
    GPU-->>Launcher: kernel complete
    Launcher-->>Binding: return
    Binding-->>API: done
    API-->>User: output tensor
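The dtype-key dispatch step in the host launcher can be sketched as follows (dtype names and the kernel table are illustrative; the actual launcher encodes DLPack dtypes and currently supports only the BF16-input/float-A combination):

```python
# Hypothetical sketch of dtype-key dispatch: encode the (input, weight, state)
# dtypes into a key and map it to a kernel specialization.
KERNELS = {
    ("bf16", "f32", "bf16"): "selective_state_update_kernel<bf16, float, bf16>",
}

def dispatch(input_dtype: str, weight_dtype: str, state_dtype: str) -> str:
    key = (input_dtype, weight_dtype, state_dtype)
    if key not in KERNELS:
        raise ValueError(f"unsupported dtype combination: {key}")
    return KERNELS[key]
```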

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • aleozlx
  • djmmoss
  • cyx-6
  • IwakuraRein
  • bkryu
  • nvmbreughe
  • kahyunnam
  • jimmyzho

Poem

I hop through kernels, nibble softplus light,
BF16 carrots crunch in CUDA night,
JIT threads hum, bindings stitch the way,
States update, outputs bloom — I leap and play,
—a rabbit celebrates: "selective_state_update!" 🐇✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: docstring coverage is 47.06%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.
✅ Passed checks (2 passed)
  • Title check ✅ Passed: the title 'Selective State Update kernel (mamba)' accurately describes the main feature added: a selective state update kernel implementation for Mamba layers.
  • Description check ✅ Passed: the description covers the purpose (replacing an inefficient Triton implementation), implementation details (two kernel variants), includes visual evidence of performance gains, and confirms pre-commit checks and testing completion. Minor gaps in issue linking, but overall comprehensive.


@gemini-code-assist

Summary of Changes

Hello @ishovkun, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the performance of Mamba layer operations by integrating specialized CUDA kernels for the selective_state_update function. By offering optimized implementations tailored for different GPU architectures, including advanced techniques like Tensor Memory Accelerator (TMA) for Hopper+, it aims to provide substantial speedups over existing solutions. The changes also include robust Python bindings and JIT compilation capabilities, making these high-performance kernels easily accessible and verifiable through comprehensive unit tests.

Highlights

  • Mamba Selective State Update Kernels: Introduced highly optimized CUDA kernels for the selective_state_update operation within the Mamba layer.
  • Performance Optimization: Addresses the inefficiency of the default Triton implementation in TRT-LLM, aiming for significant speedups.
  • Architecture-Specific Kernels: Implements two distinct kernels: a simple vectorized load kernel for pre-Hopper GPUs and a producer-consumer kernel utilizing TMA for Hopper+ architectures.
  • Python Bindings and JIT: Provides Python API and JIT compilation support for seamless integration and dynamic compilation of the CUDA kernels.
  • Comprehensive Testing: Includes a Triton-based reference implementation and unit tests to ensure the correctness and reliability of the new kernels.



@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces an efficient selective_state_update kernel for Mamba, with specialized versions for pre-Hopper and Hopper+ architectures. The implementation is well-structured, with a clear separation between the Python API, TVM-FFI bindings, and the CUDA kernels. The code includes extensive validation checks and a comprehensive test suite against a Triton reference implementation.

However, I've identified a few issues. The most critical is that the CUDA kernels have hardcoded dimension sizes (DSTATE=128, DIM=64), which severely limits their use cases. The Python wrapper also lacks the dimension unsqueezing logic present in the reference implementation, making it less user-friendly. There are also some minor issues like a documentation error, unreachable code, and an inconsistency in precision between the two kernel versions.

Addressing these points will significantly improve the robustness and usability of this new feature.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Fix all issues with AI Agents
In @include/flashinfer/mamba/selective_state_update.cuh:
- Around lines 493-497: the kernel launch checks cudaGetLastError() and prints the message via printf, but never propagates it. Capture the cudaError_t from cudaGetLastError()/cudaGetErrorString() and propagate it out of the function (e.g., return err or throw a runtime_error) instead of only printing it. If asynchronous execution failures must also be caught, call cudaDeviceSynchronize() before checking. Update the function signature and all callers of this kernel-launching function accordingly.
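A sketch of the suggested change, assuming a launcher shaped roughly like the one in the header (the function name, grid, block, and smem_bytes are placeholders, not the actual code):

```cuda
// Hypothetical launcher: propagate the launch error instead of only printing.
cudaError_t launch_selective_state_update(const SelectiveStateUpdateParams& params,
                                          dim3 grid, dim3 block, size_t smem_bytes,
                                          cudaStream_t stream) {
  selective_state_update_kernel<<<grid, block, smem_bytes, stream>>>(params);
  cudaError_t err = cudaGetLastError();  // catches launch-configuration errors
  if (err != cudaSuccess) {
    return err;  // caller decides how to handle it
  }
  // Note: execution errors are asynchronous; cudaStreamSynchronize(stream)
  // here would surface them too, at the cost of blocking the host.
  return cudaSuccess;
}
```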

In @tests/mamba/selective_state_update_triton.py:
- Around lines 297-302: the tie_hdim expression can raise AttributeError when dt_bias is None because it unconditionally calls dt_bias.stride(-1). Short-circuit the check so dt_bias.stride(...) is only evaluated when dt_bias exists, e.g. extend the boolean that combines the A.stride and dt.stride conditions with "and dt_bias is not None and dt_bias.stride(-1) == 0".
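The suggested short-circuit can be sketched in isolation (compute_tie_hdim is a hypothetical helper; stride mirrors torch.Tensor.stride):

```python
# Hypothetical None-safe version of the tie_hdim check: Python's `and`
# short-circuits, so dt_bias.stride(-1) is only evaluated when dt_bias exists.
def compute_tie_hdim(A, dt, dt_bias) -> bool:
    return (
        A.stride(-1) == 0
        and A.stride(-2) == 0
        and dt.stride(-1) == 0
        and dt_bias is not None
        and dt_bias.stride(-1) == 0
    )
```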
🧹 Nitpick comments (8)
include/flashinfer/mamba/create_tensor_map.cuh (1)

18-24: Minor: Redundant error output in gpuAssert.

The function outputs the error message via both fprintf(stderr, ...) and std::cout, which is redundant and uses inconsistent streams (stderr vs stdout).

🔎 Suggested fix
 static inline void gpuAssert(cudaError_t code, const char* file, int line, bool abort = true) {
   if (code != cudaSuccess) {
     fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
-    std::cout << "GPU assert failed" << std::endl;
     if (abort) exit(code);
   }
 }
include/flashinfer/mamba/selective_state_update.cuh (3)

102-111: Missing VectorizedLoadTraits specializations for other data types.

Only __nv_bfloat16 is specialized. This will cause a compilation error if any other type combination is used with the simple kernel, since the primary template is empty.

This is fine if bfloat16-only support is intentional for now, but consider adding a static_assert to the primary template to provide a clear error message.

🔎 Suggested improvement
 template <typename input_t, typename weight_t, typename state_t>
-struct VectorizedLoadTraits {};
+struct VectorizedLoadTraits {
+  static_assert(sizeof(input_t) == 0,
+                "VectorizedLoadTraits not specialized for this type combination");
+};

164-166: Consider: Inconsistent exp implementation between kernels.

The simple kernel uses __expf(A_value * dt_value) (line 165), while the producer-consumer kernel uses fast_exp(A_value * dt_value) (line 419). The fast_exp uses PTX ex2.approx.f32 which may have slightly different precision characteristics than __expf.

If numerical consistency between GPU paths matters, consider using the same implementation in both kernels.


456-457: Hardcoded dimension constraints limit kernel flexibility.

The kernel only supports DSTATE=128 (line 457) and DIM=64 for Hopper (line 475). While these are validated with FLASHINFER_CHECK, consider adding a brief comment explaining why these specific values are chosen (e.g., shared memory constraints, warp alignment) for future maintainers. Based on coding guidelines, hot-path algorithmic choices should be documented.

Also applies to: 474-476

tests/mamba/test_selective_state_update.py (2)

10-12: Remove unused matrixA_dtype parameter.

The matrixA_dtype parameter is passed to create_test_inputs but never used in the function body. Either remove it from the signature or use it when creating tensor A if float32 is not always desired.

🔎 Proposed fix
 def create_test_inputs(
-    batch_size, nheads, dim, dstate, ngroups, input_dtype, weight_dtype, matrixA_dtype
+    batch_size, nheads, dim, dstate, ngroups, input_dtype, weight_dtype
 ):

And update the call site accordingly:

     inputs = create_test_inputs(
         batch,
         nheads,
         dim,
         dstate,
         ngroups,
         input_dtype,
         weight_dtype,
-        matrixA_dtype,
     )

68-87: Consider adding GPU capability check for test skip conditions.

Per coding guidelines, tests should use flashinfer.utils functions to skip on unsupported GPU architectures. The PR mentions a Hopper+ specific kernel path, so you may want to add appropriate skip conditions.

🔎 Example
import pytest
from flashinfer.utils import get_compute_capability

# At module level or as a fixture
@pytest.fixture(autouse=True)
def check_cuda():
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")

# Or as a decorator if specific SM version is required:
# @pytest.mark.skipif(get_compute_capability()[0] < 8, reason="Requires SM80+")
csrc/selective_state_update.cu (2)

119-120: Simplify the assertion for unsupported z parameter.

The current assertion syntax FLASHINFER_CHECK(!z.has_value() && "z is not supported yet") is unconventional. The string literal always evaluates to true, making the condition equivalent to FLASHINFER_CHECK(!z.has_value()). Consider using the standard two-argument form with an explicit message.

🔎 Proposed fix
-  // if(z.has_value())
-  FLASHINFER_CHECK(!z.has_value() && "z is not supported yet");
+  FLASHINFER_CHECK(!z.has_value(), "z is not supported yet");

156-158: Dead code: z pointer copy is unreachable.

Since line 120 throws an error if z.has_value(), this block is currently unreachable. Consider removing it until z support is implemented, or leave it with a comment indicating it's a placeholder for future support.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between edb37cd and 8e80802.

📒 Files selected for processing (13)
  • csrc/flashinfer_mamba_binding.cu
  • csrc/selective_state_update.cu
  • flashinfer/__init__.py
  • flashinfer/jit/mamba/__init__.py
  • flashinfer/jit/mamba/selective_state_update.py
  • flashinfer/mamba/__init__.py
  • flashinfer/mamba/selective_state_update.py
  • include/flashinfer/mamba/conversion.cuh
  • include/flashinfer/mamba/create_tensor_map.cuh
  • include/flashinfer/mamba/selective_state_update.cuh
  • tests/mamba/__init__.py
  • tests/mamba/selective_state_update_triton.py
  • tests/mamba/test_selective_state_update.py
🧰 Additional context used
📓 Path-based instructions (6)
tests/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

tests/**/*.py: Test implementations should use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, is_sm100a_supported, etc.) to skip tests on unsupported GPU architectures
For testing with mpirun on multi-GPU systems, use the pattern: mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function
Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes - tests/conftest.py provides auto-skipping for OOM tests as a safety net but should not be relied upon

Files:

  • tests/mamba/__init__.py
  • tests/mamba/selective_state_update_triton.py
  • tests/mamba/test_selective_state_update.py
flashinfer/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

flashinfer/**/*.py: Use @functools.cache decorator on Python API functions to implement module-level caching and avoid recompilation
Use @flashinfer_api decorator for debugging API calls, enable via FLASHINFER_LOGLEVEL environment variable (0=off, 1=basic, 3=detailed, 5=with stats)

Files:

  • flashinfer/jit/mamba/selective_state_update.py
  • flashinfer/__init__.py
  • flashinfer/mamba/selective_state_update.py
  • flashinfer/jit/mamba/__init__.py
  • flashinfer/mamba/__init__.py
flashinfer/jit/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

flashinfer/jit/**/*.py: JIT module generators in flashinfer/jit/ must follow the pattern: compute URI → create directory → (optional) render Jinja template → copy sources → return JitSpec
Use gen_jit_spec() function to return a properly configured JitSpec from module generators with appropriate sources and extra_cuda_cflags
Specify supported_major_versions in JitSpec to restrict kernel compilation to supported GPU architectures (e.g., SM versions 9, 10, 11, 12 for Hopper/newer)

Files:

  • flashinfer/jit/mamba/selective_state_update.py
  • flashinfer/jit/mamba/__init__.py
flashinfer/__init__.py

📄 CodeRabbit inference engine (CLAUDE.md)

Export new operations in flashinfer/__init__.py to make them available as public API

Files:

  • flashinfer/__init__.py
csrc/**/*.cu

📄 CodeRabbit inference engine (CLAUDE.md)

Framework bindings and PyTorch tensor handling should be implemented in csrc/ via TVM-FFI, not in include/ headers

Files:

  • csrc/selective_state_update.cu
  • csrc/flashinfer_mamba_binding.cu
include/**/*.cuh

📄 CodeRabbit inference engine (CLAUDE.md)

include/**/*.cuh: Torch headers MUST NOT be included in files within the include/ directory - keep framework-agnostic CUDA kernels that accept raw pointers
Kernel code in include/flashinfer/ is automatically picked up by JIT compilation on changes - no pip reinstall needed
For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers
Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly

Files:

  • include/flashinfer/mamba/conversion.cuh
  • include/flashinfer/mamba/create_tensor_map.cuh
  • include/flashinfer/mamba/selective_state_update.cuh
🧠 Learnings (14)
📓 Common learnings
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/__init__.py : Export new operations in `flashinfer/__init__.py` to make them available as public API
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes - no pip reinstall needed
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/__init__.py : Export new operations in `flashinfer/__init__.py` to make them available as public API

Applied to files:

  • tests/mamba/__init__.py
  • flashinfer/__init__.py
  • flashinfer/mamba/selective_state_update.py
  • flashinfer/jit/mamba/__init__.py
  • flashinfer/mamba/__init__.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/**/*.py : Use `flashinfer_api` decorator for debugging API calls, enable via `FLASHINFER_LOGLEVEL` environment variable (0=off, 1=basic, 3=detailed, 5=with stats)

Applied to files:

  • tests/mamba/__init__.py
  • flashinfer/__init__.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/**/*.py : Use `functools.cache` decorator on Python API functions to implement module-level caching and avoid recompilation

Applied to files:

  • tests/mamba/__init__.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/aot.py : Register new operations in `flashinfer/aot.py` by calling the `gen_*_module()` function for AOT (Ahead-Of-Time) pre-compilation support

Applied to files:

  • tests/mamba/__init__.py
  • flashinfer/jit/mamba/selective_state_update.py
  • flashinfer/mamba/selective_state_update.py
  • flashinfer/jit/mamba/__init__.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : Use `gen_jit_spec()` function to return a properly configured JitSpec from module generators with appropriate `sources` and `extra_cuda_cflags`

Applied to files:

  • flashinfer/jit/mamba/selective_state_update.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : JIT module generators in `flashinfer/jit/` must follow the pattern: compute URI → create directory → (optional) render Jinja template → copy sources → return JitSpec

Applied to files:

  • flashinfer/jit/mamba/selective_state_update.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes - no pip reinstall needed

Applied to files:

  • flashinfer/jit/mamba/selective_state_update.py
  • csrc/flashinfer_mamba_binding.cu
  • include/flashinfer/mamba/conversion.cuh
  • include/flashinfer/mamba/selective_state_update.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : Specify `supported_major_versions` in JitSpec to restrict kernel compilation to supported GPU architectures (e.g., SM versions 9, 10, 11, 12 for Hopper/newer)

Applied to files:

  • flashinfer/jit/mamba/selective_state_update.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*_jit_binding.cu : Create TVM-FFI bindings in files matching the pattern `csrc/*_jit_binding.cu` using the `TVM_FFI_DLL_EXPORT_TYPED_FUNC(name, func)` macro to expose C++ functions

Applied to files:

  • csrc/flashinfer_mamba_binding.cu
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*.cu : Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers

Applied to files:

  • csrc/flashinfer_mamba_binding.cu
  • include/flashinfer/mamba/create_tensor_map.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly

Applied to files:

  • include/flashinfer/mamba/conversion.cuh
  • include/flashinfer/mamba/selective_state_update.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Torch headers MUST NOT be included in files within the `include/` directory - keep framework-agnostic CUDA kernels that accept raw pointers

Applied to files:

  • include/flashinfer/mamba/conversion.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers

Applied to files:

  • include/flashinfer/mamba/selective_state_update.cuh
🧬 Code graph analysis (7)
flashinfer/jit/mamba/selective_state_update.py (1)
flashinfer/jit/core.py (2)
  • JitSpec (216-315)
  • gen_jit_spec (318-384)
csrc/selective_state_update.cu (1)
csrc/tvm_ffi_utils.h (2)
  • get_stream (294-296)
  • encode_dlpack_dtype (30-32)
csrc/flashinfer_mamba_binding.cu (2)
csrc/selective_state_update.cu (2)
  • selective_state_update (30-204)
  • selective_state_update (30-33)
flashinfer/mamba/selective_state_update.py (1)
  • selective_state_update (38-104)
flashinfer/mamba/selective_state_update.py (5)
flashinfer/api_logging.py (1)
  • flashinfer_api (464-565)
flashinfer/jit/mamba/selective_state_update.py (1)
  • gen_selective_state_update_module (21-34)
flashinfer/jit/core.py (1)
  • build_and_load (303-315)
csrc/flashinfer_mamba_binding.cu (1)
  • selective_state_update (23-26)
csrc/selective_state_update.cu (2)
  • selective_state_update (30-204)
  • selective_state_update (30-33)
flashinfer/jit/mamba/__init__.py (3)
csrc/flashinfer_mamba_binding.cu (1)
  • selective_state_update (23-26)
csrc/selective_state_update.cu (2)
  • selective_state_update (30-204)
  • selective_state_update (30-33)
flashinfer/jit/mamba/selective_state_update.py (1)
  • gen_selective_state_update_module (21-34)
flashinfer/mamba/__init__.py (3)
csrc/flashinfer_mamba_binding.cu (1)
  • selective_state_update (23-26)
csrc/selective_state_update.cu (2)
  • selective_state_update (30-204)
  • selective_state_update (30-33)
flashinfer/mamba/selective_state_update.py (1)
  • selective_state_update (38-104)
tests/mamba/test_selective_state_update.py (3)
csrc/flashinfer_mamba_binding.cu (1)
  • selective_state_update (23-26)
flashinfer/mamba/selective_state_update.py (1)
  • selective_state_update (38-104)
flashinfer/logits_processor/types.py (1)
  • size (132-136)
🪛 Ruff (0.14.10)
tests/mamba/selective_state_update_triton.py

70-70: Unused function argument: batch

(ARG001)


71-71: Unused function argument: nheads

(ARG001)


284-284: Do not assign a lambda expression, use a def

Rewrite grid as a def

(E731)

tests/mamba/test_selective_state_update.py

11-11: Unused function argument: matrixA_dtype

(ARG001)

🔇 Additional comments (17)
tests/mamba/__init__.py (1)

1-15: LGTM!

Standard license header for the new test package.

include/flashinfer/mamba/conversion.cuh (1)

1-31: LGTM!

Clean, framework-agnostic type conversion utilities. The design correctly keeps Torch headers out of include/ as per coding guidelines.

include/flashinfer/mamba/create_tensor_map.cuh (1)

44-86: LGTM with minor note.

Well-structured TMA tensor map creation utility. The type mapping via if constexpr chain with static_assert fallback is idiomatic.

Note: The <iostream> include is only needed for the error paths. Consider using fprintf(stderr, ...) consistently for error output to avoid pulling in iostream overhead, though this is minor for a header used in kernel code.

flashinfer/__init__.py (1)

156-156: LGTM!

Correctly exports the new mamba module as a public API, following the project's coding guidelines.

flashinfer/mamba/__init__.py (1)

1-19: LGTM!

Clean package initialization that properly exports selective_state_update as the public API for the mamba module.

flashinfer/jit/mamba/selective_state_update.py (2)

22-25: No action needed. The flag naming in selective_state_update.py is correct and intentional.

The mamba kernel headers (conversion.cuh) use short-form macros ENABLE_BF16 and ENABLE_FP8. The -DENABLE_BF16 and -DENABLE_FP8 flags defined in selective_state_update.py correctly match these expectations. The longer FLASHINFER_ENABLE_* variants from gen_jit_spec defaults are for other modules and safely coexist with the mamba-specific flags when both are compiled in via extra_cuda_cflags.

Likely an incorrect or invalid review comment.


21-34: The review comment is based on a coding guideline that does not match the actual implementation. The gen_jit_spec() function and JitSpec class do not have a supported_major_versions parameter.

The correct pattern used elsewhere in the codebase (e.g., flashinfer/jit/mla.py, flashinfer/jit/comm.py) is:

nvcc_flags = current_compilation_context.get_nvcc_flags_list(
    supported_major_versions=[10, 11]
)
return gen_jit_spec(
    "module_name",
    sources,
    extra_cuda_cflags=nvcc_flags,
)

The current implementation in selective_state_update.py is valid as-is. If architecture-specific restrictions are needed for BF16/FP8 features, the module would need to follow the pattern above using current_compilation_context.get_nvcc_flags_list() rather than a non-existent supported_major_versions parameter.

Likely an incorrect or invalid review comment.

flashinfer/jit/mamba/__init__.py (1)

1-19: LGTM!

The module initialization correctly imports and exports gen_selective_state_update_module, following the established pattern for JIT module generators in the codebase.

csrc/flashinfer_mamba_binding.cu (1)

1-34: LGTM!

The TVM-FFI binding correctly declares the forward reference to the implementation in selective_state_update.cu and exports it using the standard macro. The signature matches the implementation, and the namespace organization follows project conventions. Based on learnings, this correctly implements framework bindings in csrc/ via TVM-FFI.

tests/mamba/test_selective_state_update.py (1)

100-209: LGTM!

The test correctly validates both output tensors and updated state tensors against the Triton reference implementation. The detailed mismatch analysis with element-wise comparison will be valuable for debugging numerical issues.

csrc/selective_state_update.cu (2)

30-105: LGTM!

The tensor validation is comprehensive, correctly checking shapes, strides, and broadcasting patterns for all inputs. The contiguity and stride requirements are well-documented with inline comments.


191-203: The throw std::runtime_error statement is not unreachable. Based on usage patterns in the codebase (e.g., csrc/tvm_ffi_utils.h), TVM_FFI_ICHECK(false) does not throw or abort execution—it logs the error message and allows execution to continue. The explicit throw is necessary to actually terminate execution with an exception when an unsupported dtype combination is encountered.

tests/mamba/selective_state_update_triton.py (2)

24-36: LGTM!

The version-aware softplus implementation correctly handles Triton 3.x API differences. Using tl.math.log1p(tl.exp(dt)) for earlier versions is a good numerical stability choice.
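The numerical point can be checked outside Triton; a pure-Python sketch of the log1p-based softplus identity (the large-`x` branch is an additional stability guard, not part of the kernel under review):

```python
import math

def softplus(x):
    # softplus(x) = log(1 + exp(x)), computed stably:
    # log1p avoids cancellation near exp(x) ~ 0, and for large x the
    # equivalent form x + log1p(exp(-x)) avoids overflow in exp(x).
    if x > 20.0:
        return x + math.log1p(math.exp(-x))
    return math.log1p(math.exp(x))
```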


42-208: LGTM!

The Triton kernel implementation follows the reference Mamba implementation correctly. The heuristics for constexpr parameters, padding support via pad_slot_id, and optional tensor handling are well-structured.

flashinfer/mamba/selective_state_update.py (3)

27-34: LGTM!

The module caching correctly uses @functools.cache as per coding guidelines, with a clear docstring explaining both in-memory and disk caching behavior.


37-104: LGTM!

The public API correctly applies @flashinfer_api for debugging support. The docstring is comprehensive, documenting all parameters including optional tensors and their expected shapes. The implementation correctly creates the output tensor and delegates to the internal function.


107-159: LGTM!

The custom op registration correctly specifies mutates_args=("state", "output") to indicate in-place mutation, which is essential for proper torch.compile() graph capture. The fake op implementation correctly returns None for meta tensor propagation.

make input tensors always have batch and nheads as dimensions.

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Fix all issues with AI agents
In @flashinfer/mamba/selective_state_update.py:
- Line 57: Update the docstring in selective_state_update.py to say the first
dimension of the state tensor is state_cache_size (not batch); clarify that the
state shape is (state_cache_size, dim, dstate) or (state_cache_size, nheads,
dim, dstate) and note that state_cache_size may differ from batch when using
state_batch_indices, matching the C++ usage where state.size(0) is assigned to
state_cache_size.
- Line 45: The parameter D is annotated as torch.Tensor but the code checks it
for None; change the function signature to use Optional[torch.Tensor] for D and
add "from typing import Optional" to imports, leaving the existing None check at
line 100 in place so callers passing None are handled correctly; ensure any
usages in functions/methods that consume D handle the Optional type (e.g.,
conditional logic already present).
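The suggested signature change is the standard `Optional` pattern; a minimal illustration with simplified types and no torch dependency (function and parameter names are hypothetical):

```python
from typing import Optional

def apply_skip(x: float, D: Optional[float] = None) -> float:
    # When D is provided it contributes a skip term; otherwise it is ignored,
    # mirroring the existing `if D is not None` check in the wrapper.
    if D is not None:
        return x + D * x
    return x
```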
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8e80802 and 8568d67.

📒 Files selected for processing (1)
  • flashinfer/mamba/selective_state_update.py
🧰 Additional context used
📓 Path-based instructions (1)
flashinfer/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

flashinfer/**/*.py: Use @functools.cache decorator on Python API functions to implement module-level caching and avoid recompilation
Use @flashinfer_api decorator for debugging API calls, enable via FLASHINFER_LOGLEVEL environment variable (0=off, 1=basic, 3=detailed, 5=with stats)

Files:

  • flashinfer/mamba/selective_state_update.py
🧠 Learnings (3)
📓 Common learnings
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : Specify `supported_major_versions` in JitSpec to restrict kernel compilation to supported GPU architectures (e.g., SM versions 9, 10, 11, 12 for Hopper/newer)
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*_jit_binding.cu : Create TVM-FFI bindings in files matching the pattern `csrc/*_jit_binding.cu` using the `TVM_FFI_DLL_EXPORT_TYPED_FUNC(name, func)` macro to expose C++ functions
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/aot.py : Register new operations in `flashinfer/aot.py` by calling the `gen_*_module()` function for AOT (Ahead-Of-Time) pre-compilation support

Applied to files:

  • flashinfer/mamba/selective_state_update.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/__init__.py : Export new operations in `flashinfer/__init__.py` to make them available as public API

Applied to files:

  • flashinfer/mamba/selective_state_update.py
🧬 Code graph analysis (1)
flashinfer/mamba/selective_state_update.py (4)
flashinfer/jit/mamba/selective_state_update.py (1)
  • gen_selective_state_update_module (21-34)
flashinfer/jit/core.py (1)
  • build_and_load (303-315)
csrc/flashinfer_mamba_binding.cu (1)
  • selective_state_update (23-26)
csrc/selective_state_update.cu (2)
  • selective_state_update (30-204)
  • selective_state_update (30-33)
🔇 Additional comments (6)
flashinfer/mamba/selective_state_update.py (6)

17-24: LGTM!

The imports are well-organized and include all necessary dependencies for JIT compilation, custom op registration, and API logging.


27-34: LGTM!

The function correctly applies @functools.cache for module-level caching as per coding guidelines. The JIT system handles both in-memory and disk-based caching appropriately.


88-105: LGTM!

The dimension handling logic correctly addresses the past review feedback by automatically unsqueezing tensors to match expected dimensions. This allows the API to accept both 2D/3D inputs (without nheads) and 3D/4D inputs (with nheads), improving usability. The C++ backend will validate the final shapes and provide clear error messages for incompatible dimensions.
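The unsqueezing behavior can be sketched on shape tuples alone (illustrative; the real code calls `unsqueeze` on torch tensors, and the C++ backend then validates the resulting ranks):

```python
def normalize_shape(shape, target_ndim):
    """Left-pad a shape with singleton dims until the expected rank is
    reached, mimicking repeated tensor.unsqueeze(0)."""
    if len(shape) > target_ndim:
        raise ValueError(f"rank {len(shape)} exceeds target {target_ndim}")
    return (1,) * (target_ndim - len(shape)) + tuple(shape)
```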


106-122: LGTM!

The output tensor is efficiently allocated using torch.empty_like(x), and all parameters are correctly passed to the internal operation. The mutation of output is properly declared in the custom op registration.


125-158: LGTM!

The internal operation is correctly registered with mutates_args=("state", "output") to declare side effects. The implementation properly delegates to the JIT-compiled module for efficient execution.


161-178: LGTM!

The fake operation is correctly implemented for torch.compile() support. Since this operation mutates existing tensors without changing their shapes, the empty pass implementation is appropriate for meta tensor propagation.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (2)
csrc/selective_state_update.cu (2)

114-115: Use consistent validation macro and remove incomplete comment.

Line 114 contains an incomplete comment (if(z.has_value())), and line 115 uses an inconsistent pattern. For consistency with the rest of this file, use the standard FLASHINFER_CHECK macro:

♻️ Proposed refactoring
-  // if(z.has_value())
-  FLASHINFER_CHECK(!z.has_value() && "z is not supported yet");
+  FLASHINFER_CHECK(!z.has_value(), "z is not supported yet");

151-153: Remove dead code.

Lines 151-153 are unreachable because line 115 already validates that z must not have a value (raises an error if it does). This conditional will never execute.

♻️ Proposed refactoring
   if (dt_bias) {
     p.dt_bias = dt_bias.value().data_ptr();
   }
-  if (z) {
-    p.z = z.value().data_ptr();
-  }
   p.A = A.data_ptr();
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8568d67 and 7bb48cc.

📒 Files selected for processing (1)
  • csrc/selective_state_update.cu
🧰 Additional context used
📓 Path-based instructions (1)
csrc/**/*.cu

📄 CodeRabbit inference engine (CLAUDE.md)

Framework bindings and PyTorch tensor handling should be implemented in csrc/ via TVM-FFI, not in include/ headers

Files:

  • csrc/selective_state_update.cu
🧠 Learnings (2)
📓 Common learnings
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : Specify `supported_major_versions` in JitSpec to restrict kernel compilation to supported GPU architectures (e.g., SM versions 9, 10, 11, 12 for Hopper/newer)
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*_jit_binding.cu : Create TVM-FFI bindings in files matching the pattern `csrc/*_jit_binding.cu` using the `TVM_FFI_DLL_EXPORT_TYPED_FUNC(name, func)` macro to expose C++ functions
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • csrc/selective_state_update.cu
🧬 Code graph analysis (1)
csrc/selective_state_update.cu (1)
csrc/tvm_ffi_utils.h (2)
  • get_stream (294-296)
  • encode_dlpack_dtype (30-32)
🔇 Additional comments (4)
csrc/selective_state_update.cu (4)

16-23: LGTM!

The includes and namespace setup follow the expected patterns for a CUDA implementation file with TVM-FFI bindings.


25-36: LGTM!

The dimension extraction and divisibility check correctly establish the tensor geometry required for the selective state update operation.


159-161: LGTM!

The device guard and stream retrieval correctly establish the CUDA execution context for the kernel launch.


179-179: The dtype code constants bfloat16_code and float32_code are properly defined in csrc/tvm_ffi_utils.h (lines 53-54) and correctly included in selective_state_update.cu via #include "tvm_ffi_utils.h". No issues found.

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
flashinfer/mamba/selective_state_update.py (1)

46-47: Consider validating unsupported z parameter early for clearer error messages.

The C++ implementation explicitly rejects the z parameter with FLASHINFER_CHECK(!z.has_value() && "z is not supported yet"). While the error will be caught, users will receive a C++ exception rather than a clear Python-level error message. Consider adding early validation for better UX.

♻️ Proposed validation

Add this check before line 106:

    if z is not None:
        raise NotImplementedError("z parameter is not supported yet")

Alternatively, update the docstring to explicitly note that z is currently unsupported.

Also applies to: 116-116

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7bb48cc and 5812938.

📒 Files selected for processing (1)
  • flashinfer/mamba/selective_state_update.py
🧰 Additional context used
📓 Path-based instructions (1)
flashinfer/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

flashinfer/**/*.py: Use @functools.cache decorator on Python API functions to implement module-level caching and avoid recompilation
Use @flashinfer_api decorator for debugging API calls, enable via FLASHINFER_LOGLEVEL environment variable (0=off, 1=basic, 3=detailed, 5=with stats)

Files:

  • flashinfer/mamba/selective_state_update.py
🧠 Learnings (3)
📓 Common learnings
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : Specify `supported_major_versions` in JitSpec to restrict kernel compilation to supported GPU architectures (e.g., SM versions 9, 10, 11, 12 for Hopper/newer)
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*_jit_binding.cu : Create TVM-FFI bindings in files matching the pattern `csrc/*_jit_binding.cu` using the `TVM_FFI_DLL_EXPORT_TYPED_FUNC(name, func)` macro to expose C++ functions
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/aot.py : Register new operations in `flashinfer/aot.py` by calling the `gen_*_module()` function for AOT (Ahead-Of-Time) pre-compilation support

Applied to files:

  • flashinfer/mamba/selective_state_update.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/__init__.py : Export new operations in `flashinfer/__init__.py` to make them available as public API

Applied to files:

  • flashinfer/mamba/selective_state_update.py
🧬 Code graph analysis (1)
flashinfer/mamba/selective_state_update.py (5)
flashinfer/api_logging.py (1)
  • flashinfer_api (464-565)
flashinfer/jit/mamba/selective_state_update.py (1)
  • gen_selective_state_update_module (21-34)
flashinfer/jit/core.py (1)
  • build_and_load (303-315)
csrc/flashinfer_mamba_binding.cu (1)
  • selective_state_update (23-26)
csrc/selective_state_update.cu (2)
  • selective_state_update (25-197)
  • selective_state_update (25-28)
🔇 Additional comments (4)
flashinfer/mamba/selective_state_update.py (4)

27-34: LGTM: Proper use of module-level caching.

The @functools.cache decorator correctly implements module-level caching as per coding guidelines, ensuring the JIT module is built and loaded only once.


88-105: LGTM: Dimension expansion logic correctly implemented.

The dimension expansion logic properly handles tensors with fewer dimensions, expanding them to match the C++ backend's expectations. This addresses the concern raised in previous reviews.


125-158: LGTM: Correct custom op registration for torch.compile() support.

The custom op is properly registered with torch.library, correctly specifying mutates_args=("state", "output") and delegating to the JIT-compiled module.


161-178: LGTM: Correct fake op implementation for meta tensor propagation.

The fake op correctly mirrors the custom op signature and uses a pass implementation, which is appropriate for torch.compile()'s shape and dtype inference without actual computation.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Fix all issues with AI agents
In @include/flashinfer/mamba/selective_state_update.cuh:
- Around line 63-68: Add a brief comment above the __device__ __forceinline__
float fast_exp(float x) function that documents the performance/precision
tradeoff: note that it uses PTX ex2.approx.f32 for a faster but lower-precision
exp2 approximation (multiplying x by log2_E) and that this is chosen because
this is a performance-critical hot path; mention potential alternatives (e.g.,
using standard expf, a higher-precision polynomial/fast-math approximation, or a
compiler intrinsic) and when to switch to them for correctness over speed.
- Around line 491-495: The kernel launch currently checks cudaGetLastError() and
prints the error with printf but does not propagate it; replace the printf
branch with a call to FLASHINFER_CHECK(err) (using the existing cudaError_t err
from cudaGetLastError()) so the error is surfaced to callers; remove or replace
the printf("Kernel launch failed: %s\n", cudaGetErrorString(err)) with
FLASHINFER_CHECK(err) to ensure failures are propagated rather than silently
continuing.
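The identity behind the `ex2.approx.f32` trick in `fast_exp` can be checked in plain Python: exp(x) = 2^(x · log2(e)). On the GPU, the approximation error comes from the reduced-precision exp2 unit, not from the identity itself:

```python
import math

LOG2_E = 1.4426950408889634  # log2(e), the constant fast_exp multiplies by

def exp_via_exp2(x):
    # exp(x) rewritten as 2**(x * log2(e)), the form ex2.approx.f32 evaluates.
    return 2.0 ** (x * LOG2_E)
```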
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5812938 and 9883eab.

📒 Files selected for processing (1)
  • include/flashinfer/mamba/selective_state_update.cuh
🧰 Additional context used
📓 Path-based instructions (1)
include/**/*.cuh

📄 CodeRabbit inference engine (CLAUDE.md)

include/**/*.cuh: Torch headers MUST NOT be included in files within the include/ directory - keep framework-agnostic CUDA kernels that accept raw pointers
Kernel code in include/flashinfer/ is automatically picked up by JIT compilation on changes - no pip reinstall needed
For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers
Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly

Files:

  • include/flashinfer/mamba/selective_state_update.cuh
🧠 Learnings (5)
📓 Common learnings
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : Specify `supported_major_versions` in JitSpec to restrict kernel compilation to supported GPU architectures (e.g., SM versions 9, 10, 11, 12 for Hopper/newer)
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*_jit_binding.cu : Create TVM-FFI bindings in files matching the pattern `csrc/*_jit_binding.cu` using the `TVM_FFI_DLL_EXPORT_TYPED_FUNC(name, func)` macro to expose C++ functions
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers

Applied to files:

  • include/flashinfer/mamba/selective_state_update.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes - no pip reinstall needed

Applied to files:

  • include/flashinfer/mamba/selective_state_update.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly

Applied to files:

  • include/flashinfer/mamba/selective_state_update.cuh
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • include/flashinfer/mamba/selective_state_update.cuh
🔇 Additional comments (8)
include/flashinfer/mamba/selective_state_update.cuh (8)

1-35: LGTM!

The header follows coding guidelines by remaining framework-agnostic (no Torch dependencies) and includes appropriate CUDA cooperative groups, barriers, and flashinfer utilities.


36-56: LGTM!

The parameter struct is well-documented with tensor shape annotations for each pointer, and provides sensible defaults for optional parameters.


102-111: Note: Limited dtype support.

Currently only __nv_bfloat16 is specialized for vectorized loads. If other dtype combinations (fp16, fp32) are expected, they'll need additional specializations.


113-236: LGTM with known inconsistency.

The simple kernel correctly implements vectorized state updates with warp-level parallelism and optional z-gating. The use of fast_exp at line 164 is consistent with the commit message indicating it was modified to use fast_exp.

Note: A past review flagged the fast_exp vs __expf inconsistency between kernels, which remains unresolved.


238-252: LGTM!

The shared storage is correctly aligned for TMA (Tensor Memory Accelerator) and provides appropriate barrier synchronization for the multi-stage producer-consumer pattern.


254-450: LGTM!

The producer-consumer kernel correctly implements a multi-stage vertical tiling pattern with TMA for Hopper+ GPUs. The barrier synchronization and async copy patterns are properly structured.


484-488: LGTM!

The alignment check for TMA (line 484-485) and dimension divisibility check (line 488) are good defensive programming practices that catch invalid configurations before kernel launch.


454-473: Hardcoded dimensions severely limit kernel usability.

The launcher hardcodes DSTATE=128 (line 454) and DIM=64 (line 472) for the Hopper kernel path, causing failures for any other dimension values. This is a severe limitation that prevents the kernel from being used in general scenarios.

As noted in past reviews, consider making dimensions generic via template parameters or dispatch macros to support arbitrary dstate and dim values.

⛔ Skipped due to learnings
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In @include/flashinfer/mamba/selective_state_update.cuh:
- Line 409: Hoist the dA computation out of the innermost i-loop: compute dA =
__expf(A_value * dt_value) once after dt_value and A_value are set (before
entering the loops) instead of recomputing it inside the loop where the line
"auto const dA = __expf(A_value * dt_value);" currently resides; update any uses
inside the loop to reference this precomputed dA and remove the redundant
computation in the innermost loop.
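The hoisting suggestion is plain loop-invariant code motion; a Python sketch of the same transformation (the variable names mirror the kernel, but values and shapes are hypothetical):

```python
import math

def update_state(state, dt_value, A_value, B, x_value):
    # dA depends only on per-row scalars, so compute it once here...
    dA = math.exp(A_value * dt_value)
    # ...instead of recomputing math.exp(A_value * dt_value) on every
    # iteration of the inner loop over dstate.
    return [s * dA + B[i] * dt_value * x_value for i, s in enumerate(state)]
```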
🧹 Nitpick comments (1)
include/flashinfer/mamba/selective_state_update.cuh (1)

222-222: Consider simplifying negation syntax.

Using __expf(-z_value) is more idiomatic than __expf(0.f - z_value).

♻️ Suggested simplification
-      float sig_z = __fdividef(1.f, (1.f + __expf(0.f - z_value)));
+      float sig_z = __fdividef(1.f, (1.f + __expf(-z_value)));

Apply the same change at both locations (lines 222 and 435).

Also applies to: 435-435

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9883eab and 5c65e53.

📒 Files selected for processing (1)
  • include/flashinfer/mamba/selective_state_update.cuh
🧰 Additional context used
📓 Path-based instructions (1)
include/**/*.cuh

📄 CodeRabbit inference engine (CLAUDE.md)

include/**/*.cuh: Torch headers MUST NOT be included in files within the include/ directory - keep framework-agnostic CUDA kernels that accept raw pointers
Kernel code in include/flashinfer/ is automatically picked up by JIT compilation on changes - no pip reinstall needed
For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers
Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly

Files:

  • include/flashinfer/mamba/selective_state_update.cuh
🧠 Learnings (7)
📓 Common learnings
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : Specify `supported_major_versions` in JitSpec to restrict kernel compilation to supported GPU architectures (e.g., SM versions 9, 10, 11, 12 for Hopper/newer)
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*_jit_binding.cu : Create TVM-FFI bindings in files matching the pattern `csrc/*_jit_binding.cu` using the `TVM_FFI_DLL_EXPORT_TYPED_FUNC(name, func)` macro to expose C++ functions
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes - no pip reinstall needed
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers

Applied to files:

  • include/flashinfer/mamba/selective_state_update.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes - no pip reinstall needed

Applied to files:

  • include/flashinfer/mamba/selective_state_update.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly

Applied to files:

  • include/flashinfer/mamba/selective_state_update.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Torch headers MUST NOT be included in files within the `include/` directory - keep framework-agnostic CUDA kernels that accept raw pointers

Applied to files:

  • include/flashinfer/mamba/selective_state_update.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*.cu : Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers

Applied to files:

  • include/flashinfer/mamba/selective_state_update.cuh
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • include/flashinfer/mamba/selective_state_update.cuh

Removes an unnecessary `toFloat` conversion on `d_value` in the vertical
producer-consumer kernel, as `d_value` is already a float type. Fixes:
flashinfer-ai#2301 (comment)
Updates JIT flags to use generic compute_90 PTX for forward
compatibility with SM100. Adds runtime dispatch logic to select the
optimized SM90 kernel on devices with compute capability >= 9.0.
Addresses
flashinfer-ai#2301 (comment)

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In @flashinfer/jit/mamba/selective_state_update.py:
- Around line 36-52: The SM90-specific JIT spec in
gen_selective_state_update_sm90_module currently lacks an explicit
supported_major_versions restriction; update the gen_jit_spec(...) call in
gen_selective_state_update_sm90_module to include
supported_major_versions=[9,10] (Hopper and future SM100) so the module only
compiles for Hopper+ GPUs, and optionally change nvcc_flags construction to use
iterable unpacking (e.g., nvcc_flags = ["-gencode=...", "-DENABLE_BF16",
"-DFLASHINFER_MAMBA_ENABLE_SM90", *common_nvcc_flags]) for cleaner list
composition.
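A minimal sketch of what that fix amounts to in Python; the `common_nvcc_flags` contents are placeholders and the surrounding `gen_jit_spec(...)` call is omitted, so treat the names as assumptions rather than the exact flashinfer API:

```python
# Placeholder contents standing in for the shared flag list used by the
# JIT module generators; the real values live in flashinfer's JIT code.
common_nvcc_flags = ["-O3", "--use_fast_math"]

# RUF005: compose the flag list with iterable unpacking instead of `+`
# concatenation, as the review suggests.
nvcc_flags = [
    "-DENABLE_BF16",
    "-DFLASHINFER_MAMBA_ENABLE_SM90",
    *common_nvcc_flags,
]

# The keyword the review asks to pass to gen_jit_spec(...) so the SM90
# module only compiles for Hopper (SM90) and SM100 devices.
supported_major_versions = [9, 10]
```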
🧹 Nitpick comments (1)
include/flashinfer/mamba/selective_state_update.cuh (1)

497-497: Add error message to dimension divisibility check.

The assertion at line 497 should include a descriptive message to help users understand the constraint.

♻️ Proposed fix to add error message
-      FLASHINFER_CHECK(params.dim % rowsPerStage == 0);
+      FLASHINFER_CHECK(params.dim % rowsPerStage == 0,
+                       "dim must be divisible by ", rowsPerStage, " for SM90+ kernel");
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d2a147a and 59c9f85.

📒 Files selected for processing (3)
  • flashinfer/jit/mamba/selective_state_update.py
  • flashinfer/mamba/selective_state_update.py
  • include/flashinfer/mamba/selective_state_update.cuh
🧰 Additional context used
📓 Path-based instructions (3)
flashinfer/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

flashinfer/**/*.py: Use @functools.cache decorator on Python API functions to implement module-level caching and avoid recompilation
Use @flashinfer_api decorator for debugging API calls, enable via FLASHINFER_LOGLEVEL environment variable (0=off, 1=basic, 3=detailed, 5=with stats)

Files:

  • flashinfer/jit/mamba/selective_state_update.py
  • flashinfer/mamba/selective_state_update.py
flashinfer/jit/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

flashinfer/jit/**/*.py: JIT module generators in flashinfer/jit/ must follow the pattern: compute URI → create directory → (optional) render Jinja template → copy sources → return JitSpec
Use gen_jit_spec() function to return a properly configured JitSpec from module generators with appropriate sources and extra_cuda_cflags
Specify supported_major_versions in JitSpec to restrict kernel compilation to supported GPU architectures (e.g., SM versions 9, 10, 11, 12 for Hopper/newer)

Files:

  • flashinfer/jit/mamba/selective_state_update.py
include/**/*.cuh

📄 CodeRabbit inference engine (CLAUDE.md)

include/**/*.cuh: Torch headers MUST NOT be included in files within the include/ directory - keep framework-agnostic CUDA kernels that accept raw pointers
Kernel code in include/flashinfer/ is automatically picked up by JIT compilation on changes - no pip reinstall needed
For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers
Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly

Files:

  • include/flashinfer/mamba/selective_state_update.cuh
🧠 Learnings (10)
📓 Common learnings
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : Specify `supported_major_versions` in JitSpec to restrict kernel compilation to supported GPU architectures (e.g., SM versions 9, 10, 11, 12 for Hopper/newer)
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*_jit_binding.cu : Create TVM-FFI bindings in files matching the pattern `csrc/*_jit_binding.cu` using the `TVM_FFI_DLL_EXPORT_TYPED_FUNC(name, func)` macro to expose C++ functions
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes - no pip reinstall needed
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : Use `gen_jit_spec()` function to return a properly configured JitSpec from module generators with appropriate `sources` and `extra_cuda_cflags`

Applied to files:

  • flashinfer/jit/mamba/selective_state_update.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : JIT module generators in `flashinfer/jit/` must follow the pattern: compute URI → create directory → (optional) render Jinja template → copy sources → return JitSpec

Applied to files:

  • flashinfer/jit/mamba/selective_state_update.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes - no pip reinstall needed

Applied to files:

  • flashinfer/jit/mamba/selective_state_update.py
  • include/flashinfer/mamba/selective_state_update.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : Specify `supported_major_versions` in JitSpec to restrict kernel compilation to supported GPU architectures (e.g., SM versions 9, 10, 11, 12 for Hopper/newer)

Applied to files:

  • flashinfer/jit/mamba/selective_state_update.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/aot.py : Register new operations in `flashinfer/aot.py` by calling the `gen_*_module()` function for AOT (Ahead-Of-Time) pre-compilation support

Applied to files:

  • flashinfer/jit/mamba/selective_state_update.py
  • flashinfer/mamba/selective_state_update.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/__init__.py : Export new operations in `flashinfer/__init__.py` to make them available as public API

Applied to files:

  • flashinfer/mamba/selective_state_update.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers

Applied to files:

  • include/flashinfer/mamba/selective_state_update.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly

Applied to files:

  • include/flashinfer/mamba/selective_state_update.cuh
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • include/flashinfer/mamba/selective_state_update.cuh
🧬 Code graph analysis (2)
flashinfer/jit/mamba/selective_state_update.py (1)
flashinfer/jit/core.py (2)
  • JitSpec (216-397)
  • gen_jit_spec (400-466)
flashinfer/mamba/selective_state_update.py (3)
flashinfer/utils.py (1)
  • get_compute_capability (258-261)
csrc/selective_state_update.cu (2)
  • selective_state_update (25-209)
  • selective_state_update (25-28)
csrc/flashinfer_mamba_binding.cu (1)
  • selective_state_update (23-26)
🪛 Ruff (0.14.10)
flashinfer/jit/mamba/selective_state_update.py

39-43: Consider iterable unpacking instead of concatenation

Replace with iterable unpacking

(RUF005)

🔇 Additional comments (12)
flashinfer/jit/mamba/selective_state_update.py (1)

21-33: LGTM! Base module configuration is correct.

The base module uses appropriate flags for pre-Hopper GPUs and doesn't need architecture restrictions.

include/flashinfer/mamba/selective_state_update.cuh (7)

36-101: LGTM! Parameter struct and helper utilities are well-designed.

The SelectiveStateUpdateParams struct efficiently packs all kernel parameters, and the helper functions follow standard CUDA patterns for warp-level operations and type conversions.


103-226: LGTM! Simple kernel implementation is correct.

The pre-Hopper kernel uses appropriate shared memory, vectorized loads, and warp-level synchronization. The handling of the pad_slot_id for skipping padded entries is correct.


228-296: LGTM! Producer-consumer synchronization setup is correct.

The multi-stage shared storage with CUDA barriers is properly configured. The 128-byte alignment for TMA operations and barrier initialization with correct arrival counts follow SM90+ best practices.


299-335: LGTM! Producer warp TMA pipeline is correctly implemented.

The multi-stage TMA pipeline with barrier synchronization follows the correct pattern for overlapping loads and stores. The handling of pad_slot_id and barrier arrival tracking is correct.


336-436: LGTM! Consumer warp computation is correctly synchronized.

The consumer warps properly coordinate with the producer through barriers, and the computation logic matches the simple kernel. The parallel loading of inputs across warps is an effective optimization.


440-488: LGTM! Pre-Hopper launcher path with comprehensive alignment checks.

The alignment validation for vectorized loads is thorough, checking pointers, strides, and sizes. The dispatch logic correctly handles supported dstate values.


490-544: LGTM! SM90+ launcher path with proper TMA setup.

The SM90+ path correctly creates the tensor map for TMA operations and dispatches based on dim and dstate. The 128-byte alignment requirement is properly enforced.

flashinfer/mamba/selective_state_update.py (4)

30-46: LGTM! Module loading with proper caching and device-aware selection.

The use of @functools.cache follows coding guidelines for module-level caching, and the compute capability check correctly selects between base and SM90 modules.


49-134: LGTM! Public API with proper dimension normalization.

The function correctly handles both 2D (no head dimension) and 3D (with head dimension) inputs by normalizing to 3D/4D shapes. The use of @flashinfer_api decorator enables debug logging per coding guidelines.


137-170: LGTM! Custom op registration for torch.compile support.

The registration correctly declares mutation of state and output tensors, and properly dispatches to the device-appropriate JIT module.


173-190: LGTM! Fake op implementation for meta tensor support.

The fake op correctly provides a no-op implementation for torch.compile meta tensor propagation, following PyTorch's custom op pattern.

Include the required divisor and kernel variant in the error message for
easier debugging when dim is not properly aligned. Address:
flashinfer-ai#2301 (review)
This enables architecture filtering via supported_major_versions and
removes reliance on hardcoded nvcc flags. Future architectures are now
supported automatically.

address
flashinfer-ai#2301 (comment)

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (3)
include/flashinfer/mamba/selective_state_update.cuh (3)

16-35: Consider adding high-level documentation.

This header lacks documentation explaining what selective state update does, the mathematical formulation, and how the two kernel variants differ (simple vs. producer-consumer). High-level comments would help future reviewers understand the algorithm at a glance.

Based on coding guidelines, performance-critical hot paths benefit from comments explaining special algorithmic choices.


244-438: Consider adding more inline comments for the producer-consumer algorithm.

This kernel implements a sophisticated multi-stage producer-consumer pipeline with TMA. While the implementation appears correct, the complex synchronization logic (barriers, staging, TMA operations) would benefit from more inline comments explaining:

  • Why rowsPerStage = 4 * numConsumers and numStages = 3 were chosen
  • The producer's writeback-then-load pattern (lines 310-329)
  • The barrier arrival semantics and byte counting for TMA

Based on coding guidelines, performance-critical hot paths should document special algorithmic choices for future reviewers.


489-544: Note: Limited dim support in SM90+ path.

The SM90+ kernel only supports dim=64 or dim=128 (lines 533-542), while the pre-Hopper kernel supports any dim that fits in the grid. This is a conscious design trade-off for compile-time optimization, but it may limit flexibility. Consider documenting this constraint in the function/struct documentation, or adding support for additional dim values if needed.

The restriction is enforced with a clear error message, but users might benefit from knowing the supported values upfront.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 59c9f85 and 95fa14d.

📒 Files selected for processing (2)
  • flashinfer/jit/mamba/selective_state_update.py
  • include/flashinfer/mamba/selective_state_update.cuh
🚧 Files skipped from review as they are similar to previous changes (1)
  • flashinfer/jit/mamba/selective_state_update.py
🧰 Additional context used
📓 Path-based instructions (1)
include/**/*.cuh

📄 CodeRabbit inference engine (CLAUDE.md)

include/**/*.cuh: Torch headers MUST NOT be included in files within the include/ directory - keep framework-agnostic CUDA kernels that accept raw pointers
Kernel code in include/flashinfer/ is automatically picked up by JIT compilation on changes - no pip reinstall needed
For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers
Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly

Files:

  • include/flashinfer/mamba/selective_state_update.cuh
🧠 Learnings (6)
📓 Common learnings
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : Specify `supported_major_versions` in JitSpec to restrict kernel compilation to supported GPU architectures (e.g., SM versions 9, 10, 11, 12 for Hopper/newer)
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*_jit_binding.cu : Create TVM-FFI bindings in files matching the pattern `csrc/*_jit_binding.cu` using the `TVM_FFI_DLL_EXPORT_TYPED_FUNC(name, func)` macro to expose C++ functions
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes - no pip reinstall needed
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers

Applied to files:

  • include/flashinfer/mamba/selective_state_update.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly

Applied to files:

  • include/flashinfer/mamba/selective_state_update.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes - no pip reinstall needed

Applied to files:

  • include/flashinfer/mamba/selective_state_update.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Torch headers MUST NOT be included in files within the `include/` directory - keep framework-agnostic CUDA kernels that accept raw pointers

Applied to files:

  • include/flashinfer/mamba/selective_state_update.cuh
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • include/flashinfer/mamba/selective_state_update.cuh
🔇 Additional comments (7)
include/flashinfer/mamba/selective_state_update.cuh (7)

36-56: LGTM!

The parameter struct is well-designed with clear shape documentation and framework-agnostic void pointers. Default initialization prevents uninitialized usage.


58-90: LGTM!

The utility functions are correctly implemented. The thresholded_softplus properly guards against overflow in softplus, and the warp reduction follows the standard pattern.
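As a hedged illustration of the overflow guard, a thresholded softplus can be sketched in scalar Python (the threshold value and exact formulation in the CUDA kernel are assumptions here):

```python
import math


def thresholded_softplus(x: float, threshold: float = 20.0) -> float:
    # softplus(x) = log(1 + exp(x)). For large x, exp(x) overflows, but
    # softplus(x) ~= x there, so return x directly past the threshold.
    if x > threshold:
        return x
    return math.log1p(math.exp(x))
```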


92-101: LGTM!

The vectorized load traits are correctly defined for bfloat16. The chunk_size calculation is accurate: sizeof(float2) / sizeof(__nv_bfloat16) = 8 / 2 = 4 elements per vectorized load.

Note: Future dtype support will require additional trait specializations.


103-226: LGTM!

The simple kernel implementation is correct:

  • Proper bounds checking and pad slot handling
  • Correct state update math: new_state = state * exp(A*dt) + B*dt*x
  • Efficient vectorized loads and warp reduction
  • Correct SiLU gating when z is present
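That recurrence can be written as a small single-token NumPy reference (the shapes and the optional D/z handling are assumptions for this sketch, not the kernel's exact memory layout):

```python
import numpy as np


def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None):
    # Reference for one head, matching the math described above:
    #   dA = exp(dt * A), state <- state * dA + dt * B * x,
    #   y = <state, C> (+ D * x), optionally gated by SiLU(z).
    # Assumed shapes: state (dim, dstate), x/dt (dim,),
    # A (dim, dstate), B/C (dstate,).
    dA = np.exp(dt[:, None] * A)                 # (dim, dstate)
    dBx = dt[:, None] * B[None, :] * x[:, None]  # (dim, dstate)
    state[:] = state * dA + dBx                  # in-place state update
    y = (state * C[None, :]).sum(axis=-1)        # (dim,)
    if D is not None:
        y = y + D * x
    if z is not None:
        y = y * (z / (1.0 + np.exp(-z)))         # SiLU gating
    return y
```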

228-242: LGTM!

The shared storage is correctly designed for producer-consumer synchronization with proper 128-byte alignment for TMA operations.


451-464: LGTM!

The alignment checks are thorough and provide helpful error messages indicating the required alignment and kernel variant, exactly as mentioned in the PR objectives.


440-545: Verify the supported dim and dstate values meet current requirements.

The launcher restricts:

  • SM90+ path: dim ∈ {64, 128}, dstate ∈ {64, 128, 256}
  • Pre-Hopper path: dstate ∈ {64, 128, 256}, any dim

Confirm these cover all Nemotron/Mamba model configurations currently in use.

If additional configurations are needed in the future, the switch statements at lines 475-487 and 518-542 will need extension.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (4)
tests/mamba/test_selective_state_update.py (4)

1-7: Missing GPU architecture skip logic.

Per coding guidelines, test implementations should use flashinfer.utils functions to skip tests on unsupported GPU architectures. The selective_state_update kernel has SM90-optimized paths and requires specific GPU support.

Consider adding architecture checks:

import torch
import flashinfer
from flashinfer.utils import get_compute_capability

# At the start of each test or as a module-level skip
@pytest.fixture(autouse=True)
def skip_unsupported_gpu():
    cc = get_compute_capability(torch.device("cuda"))
    if cc < (8, 0):
        pytest.skip("selective_state_update requires SM80+")

85-93: Consider expanding parameter coverage.

The test only covers delta_softplus=True and ngroups=8. Consider adding:

  • delta_softplus=[True, False] to test both code paths
  • Additional ngroups values to verify head grouping logic

150-185: Consider using torch.testing.assert_allclose for consistency.

The test_selective_state_update_with_z function uses torch.testing.assert_allclose directly, which provides better error messages. Consider using the same approach here instead of manual torch.allclose + print statements for consistency across tests.

♻️ Suggested refactor
-    atol = 1e-3
-    rtol = 1e-2
-    outputs_match = torch.allclose(y_ref, y_test, atol=atol, rtol=rtol)
-
-    if outputs_match:
-        print(f"✓ Outputs match within tolerance (atol={atol}, rtol={rtol})")
-    else:
-        print(f"✗ Outputs do NOT match within tolerance (atol={atol}, rtol={rtol})")
-
-        # Detailed comparison using numpy testing
-        y_ref_np = y_ref.detach().cpu().float().numpy()
-        y_test_np = y_test.detach().cpu().float().numpy()
-        # ... detailed mismatch analysis ...
-
-    assert outputs_match
+    atol = 1e-3
+    rtol = 1e-2
+    torch.testing.assert_allclose(y_ref, y_test, atol=atol, rtol=rtol)

188-227: Same suggestion: use torch.testing.assert_allclose for state comparison.

Apply the same simplification for the state tensor comparison to match the pattern in test_selective_state_update_with_z.

♻️ Suggested refactor
-    # Check if states match within tolerance
-    states_match = torch.allclose(
-        state_ref[inputs["slot_idx"]],
-        state[inputs["slot_idx"]],
-        atol=atol,
-        rtol=rtol,
-    )
-
-    if states_match:
-        print(f"✓ States match within tolerance (atol={atol}, rtol={rtol})")
-    else:
-        # ... detailed mismatch analysis ...
-
-    assert states_match
+    torch.testing.assert_allclose(
+        state_ref[inputs["slot_idx"]],
+        state[inputs["slot_idx"]],
+        atol=atol,
+        rtol=rtol,
+    )
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between baa55a1 and 0d9c71b.

📒 Files selected for processing (1)
  • tests/mamba/test_selective_state_update.py
🧰 Additional context used
📓 Path-based instructions (1)
tests/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

tests/**/*.py: Test implementations should use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, is_sm100a_supported, etc.) to skip tests on unsupported GPU architectures
For testing with mpirun on multi-GPU systems, use the pattern: mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function
Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes - tests/conftest.py provides auto-skipping for OOM tests as a safety net but should not be relied upon

Files:

  • tests/mamba/test_selective_state_update.py
🧠 Learnings (1)
📓 Common learnings
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*.cu : Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers
🧬 Code graph analysis (1)
tests/mamba/test_selective_state_update.py (4)
tests/mamba/selective_state_update_triton.py (1)
  • selective_state_update_triton (211-356)
flashinfer/mamba/selective_state_update.py (1)
  • selective_state_update (50-134)
csrc/selective_state_update.cu (2)
  • selective_state_update (25-209)
  • selective_state_update (25-28)
csrc/flashinfer_mamba_binding.cu (1)
  • selective_state_update (23-26)
🔇 Additional comments (2)
tests/mamba/test_selective_state_update.py (2)

10-82: LGTM!

The test input generator is well-structured with:

  • Reproducible seeding
  • Correct strided tensor creation matching kernel expectations
  • Proper shape and dtype handling

The use of torch.rand (uniform) for A instead of torch.randn (normal) appears intentional for numerical stability, aligning with the commit note about "choice of the A tensor."


230-294: LGTM!

Clean test implementation that:

  • Covers the z tensor code path specifically
  • Uses torch.testing.assert_allclose for concise assertions with informative error messages
  • Tests both output and state tensor correctness


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
include/flashinfer/mamba/conversion.cuh (1)

7-9: Unused FP8 include — consider removing or implementing.

cuda_fp8.h is conditionally included but no FP8 conversion functions (toFloat(__nv_fp8_e4m3), convertAndStore(__nv_fp8_e4m3*, float), etc.) are provided. This appears to be dead code.

Additionally, the macro name ENABLE_FP8 is inconsistent with FLASHINFER_ENABLE_BF16 — consider using FLASHINFER_ENABLE_FP8 for consistency.

♻️ Suggested fix: remove unused include or add FP8 functions

If FP8 support is not needed now:

-#ifdef ENABLE_FP8
-#include <cuda_fp8.h>
-#endif

Or rename for consistency:

-#ifdef ENABLE_FP8
+#ifdef FLASHINFER_ENABLE_FP8
 #include <cuda_fp8.h>
 #endif
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5cda288 and d99923c.

📒 Files selected for processing (1)
  • include/flashinfer/mamba/conversion.cuh
🧰 Additional context used
📓 Path-based instructions (1)
include/**/*.cuh

📄 CodeRabbit inference engine (CLAUDE.md)

include/**/*.cuh: Torch headers MUST NOT be included in files within the include/ directory - keep framework-agnostic CUDA kernels that accept raw pointers
Kernel code in include/flashinfer/ is automatically picked up by JIT compilation on changes - no pip reinstall needed
For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers
Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly

Files:

  • include/flashinfer/mamba/conversion.cuh
🧠 Learnings (3)
📓 Common learnings
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*.cu : Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes - no pip reinstall needed

Applied to files:

  • include/flashinfer/mamba/conversion.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Torch headers MUST NOT be included in files within the `include/` directory - keep framework-agnostic CUDA kernels that accept raw pointers

Applied to files:

  • include/flashinfer/mamba/conversion.cuh
🔇 Additional comments (2)
include/flashinfer/mamba/conversion.cuh (2)

11-19: LGTM!

The toFloat overloads correctly use CUDA intrinsics (__half2float, __bfloat162float) and the BF16 variant is properly guarded. The identity passthrough for float is appropriate.


21-31: LGTM!

The convertAndStore overloads correctly use CUDA intrinsics for float-to-half and float-to-bfloat16 conversions. The BF16 variant is properly guarded with FLASHINFER_ENABLE_BF16.

@yzh119

yzh119 commented Jan 13, 2026

/bot run

@flashinfer-bot

GitLab MR !228 has been updated with the latest changes, and the CI pipeline #41601989 is currently running. I'll report back once the pipeline job completes.


@yzh119 yzh119 left a comment


LGTM, thanks for this contribution!

@ishovkun

LGTM, thanks for this contribution!

I'll have a follow-up soon with more perf for Blackwell!

@yzh119

yzh119 commented Jan 13, 2026

@ishovkun sounds great! Do you expect to land it in this PR, or in a follow-up one?

@flashinfer-bot

[FAILED] Pipeline #41601989: 17/20 passed

@yzh119 yzh119 merged commit e6d3d82 into flashinfer-ai:main Jan 13, 2026
7 checks passed
yzh119 added a commit that referenced this pull request Jan 19, 2026
<!-- .github/pull_request_template.md -->

## 📌 Description

This PR brings support for various dtypes for the state parameter of
`selective_state_update` (mamba).
Specifically, this supports the following state types:
- fp32
- bf16
- fp16.

The data types of other parameters are still fixed. Additionally, there
are performance improvements over the [current version](#2301):

<img width="3000" height="4500" alt="runtime_vs_batch_size_NVIDIA_B200"
src="https://github.com/user-attachments/assets/21eb7fe0-1655-4f6b-aa62-b88d1d8505e8"
/>

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Refactor**
* Broadened numeric-type support and dispatch for selective state
updates; reworked kernels, shared buffering, alignment and vectorized
loads to improve warp utilization and DIM/DSTATE-driven execution; added
a staged producer–consumer path and clearer error messaging for
unsupported combinations.

* **Tests**
* Added explicit state-type test coverage, introduced a state-type test
parameter, and switched to stricter comparisons for more precise failure
diagnostics.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Igor Shovkun <ishovkun@viking-dvt-151.nvidia.com>
Co-authored-by: Igor Shovkun <ishovkun@viking-cr-196.nvidia.com>
Co-authored-by: Igor Shovkun <ishovkun@viking-prod-258.nvidia.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: yzh119 <zihaoy@nvidia.com>
yzh119 pushed a commit that referenced this pull request Feb 22, 2026
…tionally exposed to the user. (#2591)

<!-- .github/pull_request_template.md -->

## 📌 Description

This PR does several things:
- Improves automatic kernel selection based on the arch, state_dtype,
and the batch size (see image below).
- Slightly improves performance at small batch sizes by launching
several CTAs per tile.
- Adds jinja templates for the `selective_state_update` function (jit is
fast now)
- Reduces the number of meaningless parameter combinations in the tests
(tests are still fast)

## Background

This PR changes the behavior of the function: an optional string
`algorithm` can now be passed to the kernel. The default value `'auto'`
frees the user from thinking about the internals of the function, while
optionally the user can specify the kernel they want. This adjustment
allowed me to make use of the recent mamba benchmarks. The sweep is
shown below:

<!-- Link any related issues here -->
<img width="3600" height="3000"
alt="ssu_speedup_vs_batch_size_NVIDIA_B200"
src="https://github.com/user-attachments/assets/11325924-64e2-48e7-8fee-7244cdd7a893"
/>

One can see that the new benchmark now correctly shows the speed
difference between the reference Triton implementation and the current
implementation, as opposed to my previous
[PR](#2301). Clearly, I
previously messed up the measurements at small batch sizes.

## Kernel Selection

This PR improves the kernel selection in the following ways:
- If the problem size is too small, use the `simple` algorithm with
several CTAs per tile.
- On Blackwell, use the `horizontal` algorithm only for bf16/fp16
states; otherwise fall back to `vertical`.
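The selection rules above can be sketched as a small dispatch function. The batch-size threshold and the dtype strings here are illustrative assumptions, not the library's actual constants:

```python
def choose_algorithm(batch_size, state_dtype, is_blackwell, small_batch_threshold=16):
    """Hypothetical sketch of the 'auto' kernel dispatch described above."""
    if batch_size < small_batch_threshold:
        # Small problems: simple kernel, several CTAs per tile.
        return "simple"
    if is_blackwell and state_dtype in ("bf16", "fp16"):
        # Blackwell with 16-bit states prefers the horizontal kernel.
        return "horizontal"
    # Everything else falls back to the vertical kernel.
    return "vertical"
```

Passing `algorithm='auto'` would then resolve to one of the explicit choices via logic of this shape.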


## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

Please check how I handled jinja templates as it's my first time using
those. Also, please check whether I accidentally deleted any important
tests.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Runtime-selectable algorithm for selective state update: auto
(default), simple, vertical, horizontal.

* **Bug Fixes**
* Added runtime validation to ensure index/dtype consistency across
execution paths.

* **Chores**
* JIT/module generation reworked to produce specialized builds per
dtype/dimension and target architectures.
* Public API unified to select appropriate compiled module based on
device and data.

* **Tests**
* Expanded and parameterized tests covering algorithms, dtypes, tiling,
intermediate states, and large batches.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
murphymatt pushed a commit to fw-ai/flashinfer that referenced this pull request Feb 27, 2026
…tionally exposed to the user. (#2591)

<!-- .github/pull_request_template.md -->

## 📌 Description

This PR does several things:
- Improves automatic kernel selection based on the arch, state_dtype,
and the batch size (see image below).
- Slightly improves performance at small batch sizes by launching
several CTAs per tile.
- Adds jinja templates for the `selective_state_update` function (jit is
fast now)
- Reduces the number of meaningless parameter combinations in the tests
(test are still fast)

## Background

This PR changes changes the behavior of the function. Now, an optional
string `algorithm` can be passed to the kernel. The default value 'auto'
allows the user not to think about the internals of the function.
Optionally, the user can specify the kernel that they want. This
adjustment allowed me to make use of the recent mamba benchmarks. The
sweep is shown below:

<!-- Link any related issues here -->
<img width="3600" height="3000"
alt="ssu_speedup_vs_batch_size_NVIDIA_B200"
src="https://github.com/user-attachments/assets/11325924-64e2-48e7-8fee-7244cdd7a893"
/>

One can see, that the new benchmark now correctly shows the speed
difference between the reference Triton and the current implementation
as opposed to my previous
[PR](flashinfer-ai/flashinfer#2301). Clearly, I
previously messed up the measurements at small batch sizes.

## Kernel Selection

This PR improves the kernel selection in the following ways:
- If the problem size is too small, use the `simple` algorithm with
several CTAs per tile.
- If on Blackwell, use the `horizontal` algorithm only for bf16/fp16
states, else fall back to the `vertical`.


## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

Please check how I handled jinja templates as it's my first time using
those. Also, please check whether I accidentally deleted any important
tests.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Runtime-selectable algorithm for selective state update: auto
(default), simple, vertical, horizontal.

* **Bug Fixes**
* Added runtime validation to ensure index/dtype consistency across
execution paths.

* **Chores**
* JIT/module generation reworked to produce specialized builds per
dtype/dimension and target architectures.
* Public API unified to select appropriate compiled module based on
device and data.

* **Tests**
* Expanded and parameterized tests covering algorithms, dtypes, tiling,
intermediate states, and large batches.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>