Enable fp16/bf16/f32 support for selective_state_update (mamba) #2366
yzh119 merged 51 commits into flashinfer-ai:main
Conversation
make input tensors always have batch and nheads as dimensions. Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This ensures pointers and strides meet vectorization requirements on pre-Hopper architectures (SM < 90), preventing illegal memory accesses. Fixes: flashinfer-ai#2301 (comment)
```python
ngroups = 8
delta_softplus = True
input_dtype = torch.bfloat16
weight_dtype = torch.float32
```
Does this mean we discard the unittest for weight_dtype == torch.bfloat16?
Supported combos include:
(state=bfloat16, input=bfloat16, weight=bfloat16, matrixA=float32) ← the old one
(state=bfloat16, input=bfloat16, weight=float32, matrixA=float32)
(state=float16, input=bfloat16, weight=float32, matrixA=float32)
(state=float32, input=bfloat16, weight=float32, matrixA=float32)
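For reference, the combination list above can be written as a small validation sketch. The set and helper name here are hypothetical illustrations mirroring the list, not flashinfer's actual dispatch code:

```python
# Supported (state, input, weight, matrixA) dtype combinations, as listed above.
# Purely an illustrative table; the real check lives in the CUDA dispatch logic.
SUPPORTED_COMBOS = {
    ("bfloat16", "bfloat16", "bfloat16", "float32"),  # the pre-existing combination
    ("bfloat16", "bfloat16", "float32", "float32"),
    ("float16", "bfloat16", "float32", "float32"),
    ("float32", "bfloat16", "float32", "float32"),
}

def check_dtype_combo(state: str, input_: str, weight: str, matrix_a: str) -> bool:
    """Return True if this (state, input, weight, A) combination is supported.

    Hypothetical helper: only the state dtype varies; input stays bfloat16
    and weight/A dtypes are fixed per the table above.
    """
    return (state, input_, weight, matrix_a) in SUPPORTED_COMBOS
```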
/bot run
List all supported dtype combinations instead of just one example.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
include/flashinfer/mamba/selective_state_update.cuh (1)
526-541: Add alignment checks for vectorized x/z loads.
Line 168 and Line 180 use PackedAligned loads for x/z, but dispatch only checks state, B, and C. If x or z pointers/strides ever lose float4 alignment, the reinterpret_cast loads are undefined. Consider validating x (and z when non-null) similarly, and mirror the checks in the SM90 path.
🛠️ Proposed checks (pre‑Hopper block; mirror in SM90 block)
```diff
 constexpr size_t state_vec_size = stateLoadSize * sizeof(state_t);
 constexpr size_t vec_size = sizeof(float4);  // PackedAligned<input_t> uses float4 size
+FLASHINFER_CHECK(reinterpret_cast<uintptr_t>(params.x) % vec_size == 0,
+                 "x pointer must be aligned to ", vec_size, " bytes");
+FLASHINFER_CHECK((params.x_stride_batch * sizeof(input_t)) % vec_size == 0,
+                 "x batch stride must be aligned to ", vec_size, " bytes");
+if (params.z) {
+  FLASHINFER_CHECK(reinterpret_cast<uintptr_t>(params.z) % vec_size == 0,
+                   "z pointer must be aligned to ", vec_size, " bytes");
+  FLASHINFER_CHECK((params.z_stride_batch * sizeof(input_t)) % vec_size == 0,
+                   "z batch stride must be aligned to ", vec_size, " bytes");
+}
```

Also applies to: 588-595
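The alignment rule behind these checks can be modeled host-side. The 16-byte figure is `sizeof(float4)`; the function below is an illustrative sketch of the condition the proposed `FLASHINFER_CHECK`s enforce, not flashinfer code:

```python
FLOAT4_BYTES = 16  # sizeof(float4): vectorized loads require 16-byte alignment

def float4_aligned(base_ptr: int, stride_elems: int, itemsize: int) -> bool:
    """True iff both conditions the proposed checks enforce hold:
    the base pointer and the batch stride (converted to bytes) must
    each be a multiple of the 16-byte vector width, otherwise a
    reinterpret_cast float4 load is undefined behavior."""
    stride_bytes = stride_elems * itemsize
    return base_ptr % FLOAT4_BYTES == 0 and stride_bytes % FLOAT4_BYTES == 0
```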
🤖 Fix all issues with AI agents
In `@include/flashinfer/mamba/selective_state_update.cuh`:
- Around lines 231-241: The store loop computes d = warp * rowsPerWarp + lane and only checks d < DIM, which allows lanes >= rowsPerWarp to write into the next warp's rows when DIM equals warp * rowsPerWarp (e.g., DIM=64, rowsPerWarp=16), causing overlapping, racy writes. In the kernel's sram/out/z handling (variables: warp, rowsPerWarp, lane, DIM, sram.out, sram.z), add an additional guard requiring lane < rowsPerWarp before performing the per-row work and the convertAndStore to output (using the same output index expression: output[batch * params.out_stride_batch + head * DIM + d]) so only the lanes that own rows perform the stores.
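The race described above can be reproduced with a quick index simulation, using the numbers from the comment (DIM=64, rowsPerWarp=16, four warps of 32 lanes). This is a host-side model of the indexing only, not the kernel:

```python
WARP_SIZE = 32

def store_indices(dim: int, rows_per_warp: int, n_warps: int, guard: bool):
    """Simulate which output index d each (warp, lane) pair writes."""
    writes = []
    for warp in range(n_warps):
        for lane in range(WARP_SIZE):
            if guard and lane >= rows_per_warp:
                continue  # proposed fix: only lanes that own rows store
            d = warp * rows_per_warp + lane
            if d < dim:  # the original (insufficient) bound check
                writes.append(d)
    return writes

# Without the guard, lanes >= rowsPerWarp spill into the next warp's rows,
# so some indices are written more than once (overlapping, racy stores).
unguarded = store_indices(64, 16, 4, guard=False)
guarded = store_indices(64, 16, 4, guard=True)
```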
🧹 Nitpick comments (1)
include/flashinfer/mamba/selective_state_update.cuh (1)
278-333: Add a brief rationale for the uint32_t packed path.
The uint32_t packing and the sizeof(state_t) == sizeof(input_t) split are key perf choices; a short comment on the bank-width/throughput rationale (and any alternatives) would help future reviewers. As per coding guidelines, please add a short explanatory note.
The output write loop was only using one lane per warp, leaving most threads idle. Change to iterate with lane stride so all 32 lanes participate in writing output values.
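The lane-stride pattern described above can be sketched as follows (warp size 32; a host-side model of the indexing, not the kernel itself):

```python
WARP_SIZE = 32

def covered_indices(dim: int):
    """Each lane iterates d = lane, lane + 32, lane + 64, ... so all 32
    lanes participate and together they cover every output index exactly
    once, instead of one lane per warp doing all the work."""
    writes = []
    for lane in range(WARP_SIZE):
        for d in range(lane, dim, WARP_SIZE):
            writes.append(d)
    return writes
```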
Support fp16/fp32 state with bf16 weights in addition to existing fp32 weight combinations. This enables more flexible memory/precision tradeoffs for Mamba models.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
include/flashinfer/mamba/selective_state_update.cuh (1)
528-546: Add alignment checks for x/z vectorized loads.
Both kernels use PackedAligned<input_t> for x/z, but only B/C/state are validated. Misaligned x or z (or their strides) can cause undefined vector loads. Consider validating these pointers/strides (or falling back to scalar loads).
🛠️ Proposed fix (checks in both dispatch blocks)
```diff
 constexpr size_t state_vec_size = stateLoadSize * sizeof(state_t);
 constexpr size_t vec_size = sizeof(float4);  // PackedAligned<input_t> uses float4 size
+FLASHINFER_CHECK(reinterpret_cast<uintptr_t>(params.x) % vec_size == 0,
+                 "x pointer must be aligned to ", vec_size, " bytes");
+FLASHINFER_CHECK((params.x_stride_batch * sizeof(input_t)) % vec_size == 0,
+                 "x batch stride must be aligned to ", vec_size, " bytes");
+if (params.z) {
+  FLASHINFER_CHECK(reinterpret_cast<uintptr_t>(params.z) % vec_size == 0,
+                   "z pointer must be aligned to ", vec_size, " bytes");
+  FLASHINFER_CHECK((params.z_stride_batch * sizeof(input_t)) % vec_size == 0,
+                   "z batch stride must be aligned to ", vec_size, " bytes");
+}
 FLASHINFER_CHECK(reinterpret_cast<uintptr_t>(params.state) % state_vec_size == 0,
                  "state pointer must be aligned to ", state_vec_size, " bytes");
 FLASHINFER_CHECK(reinterpret_cast<uintptr_t>(params.B) % vec_size == 0,
                  "B pointer must be aligned to ", vec_size, " bytes");
 FLASHINFER_CHECK(reinterpret_cast<uintptr_t>(params.C) % vec_size == 0,
                  "C pointer must be aligned to ", vec_size, " bytes");
@@
 // B and C use PackedAligned<input_t> which uses float4 size
 constexpr size_t vec_size = sizeof(float4);
+FLASHINFER_CHECK(reinterpret_cast<uintptr_t>(params.x) % vec_size == 0,
+                 "x pointer must be aligned to ", vec_size, " bytes");
+FLASHINFER_CHECK((params.x_stride_batch * sizeof(input_t)) % vec_size == 0,
+                 "x batch stride must be aligned to ", vec_size, " bytes");
+if (params.z) {
+  FLASHINFER_CHECK(reinterpret_cast<uintptr_t>(params.z) % vec_size == 0,
+                   "z pointer must be aligned to ", vec_size, " bytes");
+  FLASHINFER_CHECK((params.z_stride_batch * sizeof(input_t)) % vec_size == 0,
+                   "z batch stride must be aligned to ", vec_size, " bytes");
+}
 FLASHINFER_CHECK(reinterpret_cast<uintptr_t>(params.B) % vec_size == 0,
                  "B pointer must be aligned to ", vec_size, " bytes");
 FLASHINFER_CHECK(reinterpret_cast<uintptr_t>(params.C) % vec_size == 0,
                  "C pointer must be aligned to ", vec_size, " bytes");
```

Also applies to: 584-598
🧹 Nitpick comments (1)
include/flashinfer/mamba/selective_state_update.cuh (1)
280-336: Document the 32‑bit bank‑width packing path.
This is a performance-sensitive optimization; a short comment explaining why the 32‑bit packing is chosen (and possible alternatives) will help future reviewers.
As per coding guidelines, ...
📝 Suggested comment

```diff
-constexpr auto bankSize = sizeof(uint32_t);
+// Use 32-bit packing to align with shared-memory bank width for 16-bit values.
+// Alternative: PackedAligned vector loads if we retune tiles later.
+constexpr auto bankSize = sizeof(uint32_t);
```
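The idea behind the `uint32_t` path: shared-memory banks are 4 bytes wide, so two adjacent 16-bit values can travel in one 32-bit transaction. A host-side model of that reinterpretation, purely illustrative:

```python
import struct

def pack_halves(lo: int, hi: int) -> int:
    """Pack two 16-bit values into one 32-bit word (little-endian),
    mirroring a reinterpret of two adjacent 16-bit elements as uint32_t."""
    return struct.unpack("<I", struct.pack("<HH", lo, hi))[0]

def unpack_halves(word: int):
    """Recover the two 16-bit halves from a packed 32-bit word."""
    return struct.unpack("<HH", struct.pack("<I", word))
```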
The selective state update kernel uses vectorized loads for x and z tensors but was missing runtime alignment validation. Also refactors alignment checks to use sizeof(load_*_t) directly instead of hardcoded float4 size for consistency.
Resolved
[FAILED] Pipeline #41921386: 9/20 passed
/bot run
Hi @ishovkun, all unittests failed for CUDA < 13 with errors such as: I added a commit to your code, 65ee5c2, to use dynamic shared memory; does it look good to you?
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@include/flashinfer/mamba/selective_state_update.cuh`:
- Around lines 640-645: The call to cudaFuncSetAttribute(scan_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) does not check for CUDA errors; wrap it with the project's error-check macro (FLASHINFER_CUDA_CALL) so failures are reported immediately. Locate the block that defines sram_t and smem_size (using SharedStorage<input_t, ...>) and replace the direct cudaFuncSetAttribute invocation with FLASHINFER_CUDA_CALL(...), passing the same arguments (scan_func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) to ensure consistent error handling.
🧹 Nitpick comments (2)
include/flashinfer/mamba/selective_state_update.cuh (2)
65-92: Add compile-time guards for vector-load sizing assumptions.
getVectorLoadSizeForFullUtilization can return 0 if a future DSTATE < warpSize, and there's no compile‑time guard against invalid vector sizes. A static_assert would prevent accidental PackedAligned<..., 0> or out-of-range loads if new DSTATE values are introduced.
💡 Suggested guard

```diff
 template <typename T, int DSTATE>
 inline constexpr auto getVectorLoadSizeForFullUtilization() -> unsigned {
   static_assert(sizeof(float4) >= sizeof(T));
   constexpr unsigned maxHardwareLoadSize = sizeof(float4) / sizeof(T);
   constexpr unsigned warpSize = 32;
   constexpr unsigned maxLogicalLoadSize = (unsigned)DSTATE / warpSize;
+  static_assert(maxLogicalLoadSize > 0, "DSTATE must be >= warpSize");
   return maxHardwareLoadSize < maxLogicalLoadSize ? maxHardwareLoadSize : maxLogicalLoadSize;
 }
```
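The sizing logic (and the suggested guard) translate directly to a host-side model, here with explicit byte sizes, purely for illustration:

```python
FLOAT4_BYTES = 16  # sizeof(float4)
WARP_SIZE = 32

def vector_load_size(itemsize: int, dstate: int) -> int:
    """Mirror of getVectorLoadSizeForFullUtilization: take the smaller of
    the hardware maximum (a float4's worth of elements) and the per-lane
    share of DSTATE across one warp."""
    max_hardware = FLOAT4_BYTES // itemsize
    max_logical = dstate // WARP_SIZE
    # The suggested guard: DSTATE below warpSize would yield a 0-wide load.
    assert max_logical > 0, "DSTATE must be >= warpSize"
    return min(max_hardware, max_logical)
```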
280-313: Add a short rationale for the 32‑bit packed shared‑memory loads.
The uint32_t reinterpret and the bank-size math are performance‑critical; a one‑liner explaining the bank‑width optimization will help future maintainers.
💬 Suggested comment

```diff
-constexpr auto bankSize = sizeof(uint32_t);
+// Use 32-bit packed loads to match shared-memory bank width and reduce bank conflicts.
+constexpr auto bankSize = sizeof(uint32_t);
```

As per coding guidelines, add a brief rationale for performance‑critical choices.
It looks fantastic! I think CodeRabbit is right, though: we need to add error handling to the request for dynamic shmem.
## 📌 Description

The commit 1938c5c we added to #2366 causes a hanging issue when running the unittests, because the `FLASHINFER_CUDA_CALL` macro expects the caller environment to return a `cudaError_t`, which contradicts the structure of the lambda functions inside invokeSelectiveStateUpdate of the mamba kernel. This PR fixes the issue by adding another macro, `FLASHINFER_CUDA_CHECK`.

## 🔍 Related Issues

#2366

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **Refactor**
  * Improved CUDA error handling mechanisms to provide clearer error messages and enhanced diagnostics during kernel configuration and execution.
📌 Description
This PR brings support for various dtypes for the state parameter of selective_state_update (mamba). Specifically, it supports fp16, bf16, and fp32 state types.
The data types of the other parameters are still fixed. Additionally, there are performance improvements over the current version; see current version:
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
Refactor
Tests