
Enable fp16/bf16/f32 support for selective_state_update (mamba)#2366

Merged
yzh119 merged 51 commits into flashinfer-ai:main from ishovkun:main on Jan 19, 2026
Conversation

@ishovkun
Contributor

@ishovkun ishovkun commented Jan 16, 2026

📌 Description

This PR adds support for multiple dtypes for the `state` parameter of `selective_state_update` (Mamba).
Specifically, the following state types are supported:

  • fp32
  • bf16
  • fp16

The data types of the other parameters remain fixed. Additionally, this version improves performance over the current one; see the benchmark below:

(Benchmark plot: runtime vs. batch size, NVIDIA B200)

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Refactor

    • Broadened numeric-type support and dispatch for selective state updates; reworked kernels, shared buffering, alignment checks, and vectorized loads to improve warp utilization and DIM/DSTATE-driven execution; added a staged producer–consumer path and clearer error messages for unsupported combinations.
  • Tests

    • Added explicit state-type test coverage, introduced a state-type test parameter, and switched to stricter comparisons for more precise failure diagnostics.

✏️ Tip: You can customize this high-level summary in your review settings.

Igor Shovkun and others added 30 commits December 22, 2025 10:55
make input tensors always have batch and nheads as dimensions.

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This ensures pointers and strides meet vectorization requirements on
pre-Hopper architectures (SM < 90) to prevent illegal memory accesses.
Fixes:
flashinfer-ai#2301 (comment)
ngroups = 8
delta_softplus = True
input_dtype = torch.bfloat16
weight_dtype = torch.float32
Collaborator

Does this mean we discard the unittest for weight_dtype == torch.bfloat16?

Contributor Author

Supported combos include:
(state=bfloat16, input=bfloat16, weight=bfloat16, matrixA=float32) ← the old one
(state=bfloat16, input=bfloat16, weight=float32, matrixA=float32)
(state=float16, input=bfloat16, weight=float32, matrixA=float32)
(state=float32, input=bfloat16, weight=float32, matrixA=float32)
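The allowed combinations above amount to a small validation table. A minimal Python sketch of that check (purely illustrative — `validate_dtypes` and the string encoding are hypothetical, not FlashInfer API):

```python
# Hypothetical illustration of the supported (state, input, weight, matrixA)
# dtype combinations listed above -- not actual FlashInfer code.
SUPPORTED_COMBOS = {
    ("bfloat16", "bfloat16", "bfloat16", "float32"),  # the old combination
    ("bfloat16", "bfloat16", "float32", "float32"),
    ("float16", "bfloat16", "float32", "float32"),
    ("float32", "bfloat16", "float32", "float32"),
}

def validate_dtypes(state, inp, weight, matrix_a):
    """Raise TypeError for unsupported combos, mirroring the kernel's dispatch error."""
    combo = (state, inp, weight, matrix_a)
    if combo not in SUPPORTED_COMBOS:
        raise TypeError(f"Unsupported dtype combination: {combo}. "
                        f"Supported combos include: {sorted(SUPPORTED_COMBOS)}")
    return combo
```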

@yzh119
Collaborator

yzh119 commented Jan 16, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !246 has been created, and the CI pipeline #41921386 is currently running. I'll report back once the pipeline job completes.

List all supported dtype combinations instead of just one example.
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
include/flashinfer/mamba/selective_state_update.cuh (1)

526-541: Add alignment checks for vectorized x/z loads.
Line 168 and Line 180 use PackedAligned loads for x/z, but dispatch only checks state, B, and C. If x or z pointers/strides ever lose float4 alignment, the reinterpret_cast loads are undefined. Consider validating x (and z when non-null) similarly, and mirror the checks in the SM90 path.

🛠️ Proposed checks (pre‑Hopper block; mirror in SM90 block)
       constexpr size_t state_vec_size = stateLoadSize * sizeof(state_t);
       constexpr size_t vec_size = sizeof(float4);  // PackedAligned<input_t> uses float4 size

+      FLASHINFER_CHECK(reinterpret_cast<uintptr_t>(params.x) % vec_size == 0,
+                       "x pointer must be aligned to ", vec_size, " bytes");
+      FLASHINFER_CHECK((params.x_stride_batch * sizeof(input_t)) % vec_size == 0,
+                       "x batch stride must be aligned to ", vec_size, " bytes");
+      if (params.z) {
+        FLASHINFER_CHECK(reinterpret_cast<uintptr_t>(params.z) % vec_size == 0,
+                         "z pointer must be aligned to ", vec_size, " bytes");
+        FLASHINFER_CHECK((params.z_stride_batch * sizeof(input_t)) % vec_size == 0,
+                         "z batch stride must be aligned to ", vec_size, " bytes");
+      }

Also applies to: 588-595

🤖 Fix all issues with AI agents
In `@include/flashinfer/mamba/selective_state_update.cuh`:
- Around lines 231-241: the current store loop computes d = warp * rowsPerWarp + lane and only checks d < DIM, which allows lanes >= rowsPerWarp to write into the next warp's rows when DIM is larger than rowsPerWarp (e.g., DIM=64, rowsPerWarp=16), causing overlapping/racy writes. In the kernel's sram/out/z handling (variables: warp, rowsPerWarp, lane, DIM, sram.out, sram.z), add an additional guard requiring lane < rowsPerWarp before performing the per-row work and the convertAndStore to output (using the same output index expression: output[batch * params.out_stride_batch + head * DIM + d]), so only the lanes that own rows perform the stores.
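The overlap described above can be reproduced with a small index simulation (pure Python stand-in for the CUDA indexing; names mirror the kernel's warp/lane/DIM but the function is illustrative):

```python
def owning_threads(dim, rows_per_warp, n_warps=4, warp_size=32, guard=True):
    """Return, for each row d, the list of (warp, lane) pairs that would store it."""
    writers = {}
    for warp in range(n_warps):
        for lane in range(warp_size):
            d = warp * rows_per_warp + lane
            # Without the extra guard, lanes >= rows_per_warp in warp w
            # compute the same d as lanes in warp w+1, so stores collide.
            if d < dim and (not guard or lane < rows_per_warp):
                writers.setdefault(d, []).append((warp, lane))
    return writers

unguarded = owning_threads(dim=64, rows_per_warp=16, guard=False)  # racy: rows have 2 writers
guarded = owning_threads(dim=64, rows_per_warp=16, guard=True)     # each row has 1 writer
```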
🧹 Nitpick comments (1)
include/flashinfer/mamba/selective_state_update.cuh (1)

278-333: Add a brief rationale for the uint32_t packed path.
The uint32_t packing and the sizeof(state_t) == sizeof(input_t) split are key perf choices; a short comment on bank-width/throughput rationale (and any alternatives) would help future reviewers. As per coding guidelines, please add a short explanatory note.

The output write loop was only using one lane per warp, leaving most
threads idle. Change to iterate with lane stride so all 32 lanes
participate in writing output values.
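The lane-stride pattern this commit describes can be sketched as follows (Python stand-in for the CUDA loop; illustrative only):

```python
WARP_SIZE = 32

def rows_written_per_lane(dim):
    """Each of the 32 lanes iterates d = lane, lane + 32, lane + 64, ... < dim,
    so all lanes participate instead of one lane writing every row."""
    return {lane: list(range(lane, dim, WARP_SIZE)) for lane in range(WARP_SIZE)}

work = rows_written_per_lane(64)
# Every lane writes dim / 32 rows; together they cover 0..dim-1 exactly once.
```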
Support fp16/fp32 state with bf16 weights in addition to existing
fp32 weight combinations. This enables more flexible memory/precision
tradeoffs for Mamba models.
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
include/flashinfer/mamba/selective_state_update.cuh (1)

528-546: Add alignment checks for x/z vectorized loads.

Both kernels use PackedAligned<input_t> for x/z, but only B/C/state are validated. Misaligned x or z (or their strides) can cause undefined vector loads. Consider validating these pointers/strides (or falling back to scalar loads).

🛠️ Proposed fix (checks in both dispatch blocks)
       constexpr size_t state_vec_size = stateLoadSize * sizeof(state_t);
       constexpr size_t vec_size = sizeof(float4);  // PackedAligned<input_t> uses float4 size

+      FLASHINFER_CHECK(reinterpret_cast<uintptr_t>(params.x) % vec_size == 0,
+                       "x pointer must be aligned to ", vec_size, " bytes");
+      FLASHINFER_CHECK((params.x_stride_batch * sizeof(input_t)) % vec_size == 0,
+                       "x batch stride must be aligned to ", vec_size, " bytes");
+      if (params.z) {
+        FLASHINFER_CHECK(reinterpret_cast<uintptr_t>(params.z) % vec_size == 0,
+                         "z pointer must be aligned to ", vec_size, " bytes");
+        FLASHINFER_CHECK((params.z_stride_batch * sizeof(input_t)) % vec_size == 0,
+                         "z batch stride must be aligned to ", vec_size, " bytes");
+      }
       FLASHINFER_CHECK(reinterpret_cast<uintptr_t>(params.state) % state_vec_size == 0,
                        "state pointer must be aligned to ", state_vec_size, " bytes");
       FLASHINFER_CHECK(reinterpret_cast<uintptr_t>(params.B) % vec_size == 0,
                        "B pointer must be aligned to ", vec_size, " bytes");
       FLASHINFER_CHECK(reinterpret_cast<uintptr_t>(params.C) % vec_size == 0,
                        "C pointer must be aligned to ", vec_size, " bytes");
@@
       // B and C use PackedAligned<input_t> which uses float4 size
       constexpr size_t vec_size = sizeof(float4);

+      FLASHINFER_CHECK(reinterpret_cast<uintptr_t>(params.x) % vec_size == 0,
+                       "x pointer must be aligned to ", vec_size, " bytes");
+      FLASHINFER_CHECK((params.x_stride_batch * sizeof(input_t)) % vec_size == 0,
+                       "x batch stride must be aligned to ", vec_size, " bytes");
+      if (params.z) {
+        FLASHINFER_CHECK(reinterpret_cast<uintptr_t>(params.z) % vec_size == 0,
+                         "z pointer must be aligned to ", vec_size, " bytes");
+        FLASHINFER_CHECK((params.z_stride_batch * sizeof(input_t)) % vec_size == 0,
+                         "z batch stride must be aligned to ", vec_size, " bytes");
+      }
       FLASHINFER_CHECK(reinterpret_cast<uintptr_t>(params.B) % vec_size == 0,
                        "B pointer must be aligned to ", vec_size, " bytes");
       FLASHINFER_CHECK(reinterpret_cast<uintptr_t>(params.C) % vec_size == 0,
                        "C pointer must be aligned to ", vec_size, " bytes");

Also applies to: 584-598

🧹 Nitpick comments (1)
include/flashinfer/mamba/selective_state_update.cuh (1)

280-336: Document the 32‑bit bank‑width packing path.

This is a performance-sensitive optimization; a short comment explaining why the 32‑bit packing is chosen (and possible alternatives) will help future reviewers.

📝 Suggested comment
-      constexpr auto bankSize = sizeof(uint32_t);
+      // Use 32-bit packing to align with shared-memory bank width for 16-bit values.
+      // Alternative: PackedAligned vector loads if we retune tiles later.
+      constexpr auto bankSize = sizeof(uint32_t);
As per coding guidelines, ...

The selective state update kernel uses vectorized loads for x and z
tensors but was missing runtime alignment validation. Also refactors
alignment checks to use sizeof(load_*_t) directly instead of hardcoded
float4 size for consistency.
@ishovkun
Contributor Author

Resolved

@ishovkun ishovkun closed this Jan 17, 2026
@ishovkun ishovkun reopened this Jan 17, 2026
@flashinfer-bot
Collaborator

[FAILED] Pipeline #41921386: 9/20 passed

@yzh119
Collaborator

yzh119 commented Jan 18, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !246 has been updated with latest changes, and the CI pipeline #41982617 is currently running. I'll report back once the pipeline job completes.

@yzh119
Collaborator

yzh119 commented Jan 18, 2026

Hi @ishovkun, all unittests failed for CUDA < 13 with errors such as:

FAILED tests/mamba/test_selective_state_update.py::test_selective_state_update[state_dtype2-64-64-8-1] - RuntimeError: Ninja build failed. Ninja output:
  ninja: Entering directory `/root/.cache/flashinfer/0.6.1/100a/cached_ops/mamba_selective_state_update_sm90'
  [1/2] /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output /root/.cache/flashinfer/0.6.1/100a/cached_ops/mamba_selective_state_update_sm90/selective_state_update.cuda.o.d
  -DPy_LIMITED_API=0x03090000 -D_GLIBCXX_USE_CXX11_ABI=1 -isystem /opt/conda/envs/py312/include/python3.12 -isystem /usr/local/cuda/include -isystem /usr/local/cuda/include/cccl -isystem
  /opt/conda/envs/py312/lib/python3.12/site-packages/tvm_ffi/include -isystem /opt/conda/envs/py312/lib/python3.12/site-packages/tvm_ffi/include -isystem /workspace/flashinfer/include -isystem
  /workspace/flashinfer/csrc -isystem /workspace/flashinfer/3rdparty/cutlass/include -isystem /workspace/flashinfer/3rdparty/cutlass/tools/util/include -isystem /workspace/flashinfer/3rdparty/spdlog/include
  --compiler-options=-fPIC --expt-relaxed-constexpr -static-global-template-stub=false -DFLASHINFER_ENABLE_FP8_E8M0 -DFLASHINFER_ENABLE_FP4_E2M1 -std=c++17 --threads=1 -use_fast_math -DFLASHINFER_ENABLE_F16
  -DFLASHINFER_ENABLE_BF16 -DFLASHINFER_ENABLE_FP8_E4M3 -DFLASHINFER_ENABLE_FP8_E5M2 -DNDEBUG -O3 -gencode=arch=compute_100a,code=sm_100a -DFLASHINFER_ENABLE_FP8_E8M0 -DFLASHINFER_ENABLE_FP4_E2M1
  -DFLASHINFER_MAMBA_ENABLE_SM90 -c /workspace/flashinfer/csrc/selective_state_update.cu -o /root/.cache/flashinfer/0.6.1/100a/cached_ops/mamba_selective_state_update_sm90/selective_state_update.cuda.o
  FAILED: [code=255] /root/.cache/flashinfer/0.6.1/100a/cached_ops/mamba_selective_state_update_sm90/selective_state_update.cuda.o
  /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output /root/.cache/flashinfer/0.6.1/100a/cached_ops/mamba_selective_state_update_sm90/selective_state_update.cuda.o.d
  -DPy_LIMITED_API=0x03090000 -D_GLIBCXX_USE_CXX11_ABI=1 -isystem /opt/conda/envs/py312/include/python3.12 -isystem /usr/local/cuda/include -isystem /usr/local/cuda/include/cccl -isystem
  /opt/conda/envs/py312/lib/python3.12/site-packages/tvm_ffi/include -isystem /opt/conda/envs/py312/lib/python3.12/site-packages/tvm_ffi/include -isystem /workspace/flashinfer/include -isystem
  /workspace/flashinfer/csrc -isystem /workspace/flashinfer/3rdparty/cutlass/include -isystem /workspace/flashinfer/3rdparty/cutlass/tools/util/include -isystem /workspace/flashinfer/3rdparty/spdlog/include
  --compiler-options=-fPIC --expt-relaxed-constexpr -static-global-template-stub=false -DFLASHINFER_ENABLE_FP8_E8M0 -DFLASHINFER_ENABLE_FP4_E2M1 -std=c++17 --threads=1 -use_fast_math -DFLASHINFER_ENABLE_F16
  -DFLASHINFER_ENABLE_BF16 -DFLASHINFER_ENABLE_FP8_E4M3 -DFLASHINFER_ENABLE_FP8_E5M2 -DNDEBUG -O3 -gencode=arch=compute_100a,code=sm_100a -DFLASHINFER_ENABLE_FP8_E8M0 -DFLASHINFER_ENABLE_FP4_E2M1
  -DFLASHINFER_MAMBA_ENABLE_SM90 -c /workspace/flashinfer/csrc/selective_state_update.cu -o /root/.cache/flashinfer/0.6.1/100a/cached_ops/mamba_selective_state_update_sm90/selective_state_update.cuda.o
  ptxas error   : Entry function
  '_ZN10flashinfer5mamba56selective_state_update_kernel_producer_consumer_verticalI13__nv_bfloat16fffLi128ELi256ELi4ELi16ELi3EEEvNS0_26SelectiveStateUpdateParamsE14CUtensorMap_st' uses too much shared data
  (0xc880 bytes, 0xc000 max)
  ptxas error   : Entry function
  '_ZN10flashinfer5mamba56selective_state_update_kernel_producer_consumer_verticalI13__nv_bfloat16fffLi64ELi256ELi4ELi16ELi3EEEvNS0_26SelectiveStateUpdateParamsE14CUtensorMap_st' uses too much shared data
  (0xc680 bytes, 0xc000 max)

I added a commit to your code 65ee5c2 to use dynamic shared memory, does it look good to you?
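For context on the ptxas errors above: 0xc000 bytes is the 48 KiB per-block static shared-memory cap, and both kernel instantiations exceed it; opting into dynamic shared memory (raising the limit via `cudaFuncAttributeMaxDynamicSharedMemorySize`) is the standard way to request more than 48 KiB on architectures that support it. A quick check of the numbers from the log:

```python
static_limit = 0xc000  # 49152 bytes = 48 KiB static shared-memory cap
needed_128 = 0xc880    # 51328 bytes requested by the DIM=128 instantiation
needed_64 = 0xc680     # 50816 bytes requested by the DIM=64 instantiation

overflow = needed_128 - static_limit  # bytes over the static cap
```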

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@include/flashinfer/mamba/selective_state_update.cuh`:
- Around lines 640-645: the call to cudaFuncSetAttribute(scan_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) does not check for CUDA errors. Wrap it with the project's error-check macro (FLASHINFER_CUDA_CALL) so failures are reported immediately: locate the block that defines sram_t and smem_size (using SharedStorage<input_t,...>) and replace the direct cudaFuncSetAttribute invocation with FLASHINFER_CUDA_CALL(...) around that call, passing the same arguments (scan_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) for consistent error handling.
🧹 Nitpick comments (2)
include/flashinfer/mamba/selective_state_update.cuh (2)

65-92: Add compile-time guards for vector-load sizing assumptions.
getVectorLoadSizeForFullUtilization can return 0 if future DSTATE < warpSize, and there’s no compile‑time guard against invalid vector sizes. A static_assert would prevent accidental PackedAligned<...,0> or out-of-range loads if new DSTATE values are introduced.

💡 Suggested guard
template <typename T, int DSTATE>
inline constexpr auto getVectorLoadSizeForFullUtilization() -> unsigned {
  static_assert(sizeof(float4) >= sizeof(T));
  constexpr unsigned maxHardwareLoadSize = sizeof(float4) / sizeof(T);
  constexpr unsigned warpSize = 32;
  constexpr unsigned maxLogicalLoadSize = (unsigned)DSTATE / warpSize;
+ static_assert(maxLogicalLoadSize > 0, "DSTATE must be >= warpSize");
  return maxHardwareLoadSize < maxLogicalLoadSize ? maxHardwareLoadSize : maxLogicalLoadSize;
}

280-313: Add a short rationale for the 32‑bit packed shared‑memory loads.
The uint32_t reinterpret and bank-size math are performance‑critical; a one‑liner explaining the bank‑width optimization will help future maintainers.

💬 Suggested comment
-      constexpr auto bankSize = sizeof(uint32_t);
+      // Use 32-bit packed loads to match shared-memory bank width and reduce bank conflicts.
+      constexpr auto bankSize = sizeof(uint32_t);

As per coding guidelines, add brief rationale for performance‑critical choices.

@ishovkun
Contributor Author

> Hi @ishovkun, all unittests failed for CUDA < 13 with errors such as: […]
>
> I added a commit to your code 65ee5c2 to use dynamic shared memory, does it look good to you?

It looks fantastic! I think CodeRabbit is right, though, and we need to add error handling to the dynamic shmem request.

@yzh119 yzh119 merged commit c80d679 into flashinfer-ai:main Jan 19, 2026
22 checks passed
yzh119 added a commit that referenced this pull request Jan 20, 2026
<!-- .github/pull_request_template.md -->

## 📌 Description

The commit
1938c5c
added to #2366 caused a hang when running the unittests, because the
`FLASHINFER_CUDA_CALL` macro expects the calling scope to return a
`cudaError_t`, which conflicts with the structure of the lambda
functions inside `invokeSelectiveStateUpdate` of the Mamba kernel. This
PR fixes the issue by adding another macro, `FLASHINFER_CUDA_CHECK`.

## 🔍 Related Issues

#2366 

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

* **Refactor**
* Improved CUDA error handling mechanisms to provide clearer error
messages and enhanced diagnostics during kernel configuration and
execution.

<sub>✏️ Tip: You can customize this high-level summary in your review
settings.</sub>

<!-- end of auto-generated comment: release notes by coderabbit.ai -->