Support NVFP4 KV cache decode on SM120 (#2520)
Conversation
Summary of Changes

This pull request integrates NVIDIA FP4 (NVFP4) support for Key-Value (KV) cache decoding on SM120 GPUs. This feature aims to significantly reduce the memory footprint of KV caches, enabling larger models or longer context windows with improved performance. The changes involve extensive modifications to CUDA kernels for efficient 4-bit data handling and corresponding updates to the Python API to expose this new capability to users.
📝 Walkthrough

Adds NVFP4 (4-bit) KV-cache support end-to-end: type-system converters, kernel unpack/scale paths, dual-grain GMEM/SMEM handling, new SF cache pointers threaded through C++ wrappers and Python APIs, and tests exercising NVFP4 KV paths with runtime SM120 checks.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Host as Host (Python)
    participant Wrapper as XQA Wrapper (C++)
    participant Kernel as MHA Kernel (CUDA)
    participant GMEM as Global Memory
    participant SMEM as Shared Memory
    Host->>Host: Quantize KV to NVFP4 and compute k_sf_cache / v_sf_cache
    Host->>Wrapper: xqa_wrapper(k_cache, v_cache, k_sf_cache, v_sf_cache, ...)
    Wrapper->>Wrapper: Optional cast to GMemCacheHeadSf*
    Wrapper->>Kernel: Launch kernel with KV + SF cache pointers
    Kernel->>GMEM: Load packed 4-bit K/V and SF arrays
    Kernel->>SMEM: Copy K/V and SF into SMEM buffers (grain-aware)
    Kernel->>Kernel: ldmatrix_4x_unpack_4b() -> convertKCacheWordToF16() -> applyF16ScalingFactor()
    Kernel->>Kernel: Compute attention using unpacked FP16/BF16
    Kernel->>GMEM: Write outputs
```
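The walkthrough above mentions packed 4-bit K/V data with two elements per uint8 container (`ElemsPerContainer == 2` in the kernel code). As a hedged illustration of that containerization idea only (not the kernel's actual byte layout), here is a pure-Python sketch of packing two 4-bit codes per byte:

```python
def pack_nibbles(vals):
    """Pack pairs of 4-bit codes (0..15) into bytes, low nibble first."""
    assert len(vals) % 2 == 0
    return bytes((vals[i] & 0xF) | ((vals[i + 1] & 0xF) << 4)
                 for i in range(0, len(vals), 2))

def unpack_nibbles(packed):
    """Inverse of pack_nibbles: recover the 4-bit codes from each byte."""
    out = []
    for b in packed:
        out.append(b & 0xF)         # low nibble = even-indexed element
        out.append((b >> 4) & 0xF)  # high nibble = odd-indexed element
    return out

codes = [1, 15, 0, 7, 12, 3]
packed = pack_nibbles(codes)
assert len(packed) == len(codes) // 2   # two logical elements per container byte
assert unpack_nibbles(packed) == codes  # lossless round trip
```

The nibble order (low vs. high first) is an assumption here; the real layout is defined by the CUDA unpack path (`ldmatrix_4x_unpack_4b`).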
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 1 passed | ❌ 2 failed (2 warnings)
Code Review
This pull request introduces support for NVFP4 KV cache on the SM120 architecture. The changes are extensive, modifying CUDA C++ kernels, helper utilities, and Python wrappers to handle the new 4-bit data type and its associated scaling factors. The refactoring to use ElemTypeConverter is a good design choice for managing different data types. However, I've identified a critical issue in the CUDA code involving an out-of-scope variable that would likely prevent compilation, and a logic error in a test case that could lead to incorrect test validation. My review includes suggestions to fix these issues.
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tests/attention/test_xqa_batch_decode.py (1)
107-175: ⚠️ Potential issue | 🟡 Minor — Use v_global_scale when dequantizing V cache.
The NVFP4 reference path currently uses k_global_scale for V, which will skew the reference when K/V scales differ.
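To see why the wrong global scale skews the reference: with block-scaled FP4, a dequantized value is code × block_scale × global_scale, so using K's global scale for V multiplies every V element by k_global_scale / v_global_scale. A minimal sketch with hypothetical names (not the test's actual `nvfp4_to_float`):

```python
def dequant(code, block_scale, global_scale):
    # Block-scaled dequantization: value = code * block_scale * global_scale
    return code * block_scale * global_scale

# Hypothetical scales where K and V differ
v_global_scale = 0.5
k_global_scale = 2.0
code, block_scale = 3.0, 0.25

correct = dequant(code, block_scale, v_global_scale)
wrong = dequant(code, block_scale, k_global_scale)  # the bug: K's scale used for V
# Every V element in the reference is off by k_global_scale / v_global_scale
assert wrong == correct * (k_global_scale / v_global_scale)
```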
🛠️ Suggested fix
```diff
- v_cache_dq = nvfp4_to_float(v_cache, v_scale, k_global_scale)
+ v_cache_dq = nvfp4_to_float(v_cache, v_scale, v_global_scale)
```

csrc/xqa/mha.cu (1)
2885-2893: ⚠️ Potential issue | 🟡 Minor — `launchMHA` is dead code and never invoked; stride calculation issue confirmed but latent.

The stride calculation difference is real: `launchMHA` divides by `validElemsPerHead` (lines 2890–2893), while `launchMHAFlashInfer` divides by `container_elems_per_head = validElemsPerHead / CacheElemConverter::ElemsPerContainer` (lines 2974–2979). For 4-bit KV cache (`ElemsPerContainer == 2`), this is a 2× difference.

However, `launchMHA` is never called from anywhere in the codebase: neither from `xqa_wrapper.cu` (which always dispatches to `launchMHAFlashInfer`) nor from any other entry point. This is dead code. The bug is real but unreachable.

If `launchMHA` is intentionally kept for backward compatibility or future use, apply the suggested fix to align stride calculations with `launchMHAFlashInfer`.

Suggested fix — align launchMHA stride calculation with launchMHAFlashInfer
```diff
 // Convert stride from elements to Heads
-uint32_t const stride_page_in_heads = static_cast<uint32_t>(kv_stride_page / validElemsPerHead);
-uint32_t const stride_token_in_heads = static_cast<uint32_t>(kv_stride_token / validElemsPerHead);
-uint32_t const stride_head_in_heads = static_cast<uint32_t>(kv_stride_head / validElemsPerHead);
+uint32_t const container_elems_per_head =
+    validElemsPerHead / CacheElemConverter::ElemsPerContainer;
+uint32_t const stride_page_in_heads =
+    static_cast<uint32_t>(kv_stride_page / container_elems_per_head);
+uint32_t const stride_token_in_heads =
+    static_cast<uint32_t>(kv_stride_token / container_elems_per_head);
+uint32_t const stride_head_in_heads =
+    static_cast<uint32_t>(kv_stride_head / container_elems_per_head);
```
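The 2× factor is easy to confirm numerically. A small Python sketch with assumed sizes (128 logical FP4 elements per head, two elements per uint8 container; the numbers are illustrative, not the kernel's actual configuration):

```python
valid_elems_per_head = 128   # logical FP4 elements per head (assumed)
elems_per_container = 2      # two 4-bit codes per uint8 container
kv_stride_token = 8 * 64     # containers per token: e.g. 8 heads x 64 bytes (assumed)

# launchMHA (latent bug for 4-bit): divides a container stride by logical elements
buggy = kv_stride_token // valid_elems_per_head

# launchMHAFlashInfer: divides by container elements per head
container_elems_per_head = valid_elems_per_head // elems_per_container
fixed = kv_stride_token // container_elems_per_head

assert fixed == 2 * buggy  # the 2x misaddressing factor for 4-bit KV cache
```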
🤖 Fix all issues with AI agents
In `@csrc/xqa/mha.cu`:
- Around line 1169-1185: The vSfPrefetch register buffer is populated but never
used; either remove the prefetch block or make the main consumer read from
vSfPrefetch. To fix: if removing, delete the Array2D<uint32_t,...> vSfPrefetch
declaration and the entire prefetch loop (the `#if` ENABLE_4BIT_KV_CACHE block
that fills vSfPrefetch); otherwise change the consumer loop that currently calls
vSf(rowIdx, colIdx) to index into vSfPrefetch using the same sfRowIdx
calculation (use sfRowIdxInSlice / sfRowIdx mapping produced in the prefetch
loop and nbSfPrefetchBuffers for column indexing) so the registers filled by
vSfPrefetch are consumed instead of reloading vSf. Ensure you keep
ENABLE_4BIT_KV_CACHE gating consistent and remove any now-unused variables if
you choose removal.
- Around line 2113-2121: In the loadPages lambda remove the accidental reference
to idxNextSMemVBuf and the unused dstSf: do not call getSmemVSfTile or declare
dstSf inside loadPages (this symbol is defined only in loadVTilePart), and
ensure the ENABLE_4BIT_KV_CACHE branch only declares/uses dst (the
smem.vCachePages reference) — delete the unused dstSf/idxNextSMemVBuf lines so
the lambda compiles when BEAM_WIDTH > 1; keep loadPages calling
loadPagesForBeamSearchAsync with dst unchanged.
In `@flashinfer/xqa.py`:
- Around line 306-311: The check that calls
get_compute_capability(torch.device(device="cuda")) queries the default CUDA
device and can be wrong for tensors on other GPUs; update the capability checks
in the NVFP4 path to call get_compute_capability with the actual tensor device
(e.g. use q.device or k_cache.device) instead of torch.device(device="cuda") so
the assertions around k_cache.dtype, k_sf_cache, and v_sf_cache validate the
correct GPU for the tensor in functions/blocks referencing
get_compute_capability and the NVFP4 KV path.
In `@tests/attention/test_xqa_batch_decode.py`:
- Around line 582-765: The test function test_xqa_batch_decode_nvfp4_kv binds an
unused variable in_kv_lens from generate_seq_lens_decode; remove the unused
binding by changing the assignment to only capture needed values (e.g., q_lens,
seq_lens = generate_seq_lens_decode(...)) or replace in_kv_lens with an
underscore (_) so the unused return value is ignored; update the call site of
generate_seq_lens_decode in test_xqa_batch_decode_nvfp4_kv and ensure no other
references to in_kv_lens remain.
🧹 Nitpick comments (1)
csrc/xqa/mha.h (1)
84-87: Inconsistent use of plain division vs. `exactDiv` for `PaddedCacheHeadSf`.

Line 64 uses `exactDiv(validElemsPerHead, CacheElemConverter::QuantVectorSize)` for `GMemCacheHeadSf`, but line 86 uses plain `/` for `PaddedCacheHeadSf`. Using `exactDiv` here would provide a compile-time safety net against future misconfigurations where `headElems` might not be exactly divisible by `QuantVectorSize`.

Suggested fix
```diff
 #if ENABLE_4BIT_KV_CACHE
 using PaddedCacheHeadSf =
-    Vec<CacheElemConverter::ScalingFactorType, headElems / CacheElemConverter::QuantVectorSize>;
+    Vec<CacheElemConverter::ScalingFactorType,
+        exactDiv(headElems, CacheElemConverter::QuantVectorSize)>;
 #endif
```
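The point of `exactDiv` is that it fails loudly when the division is not exact, instead of silently truncating. A runtime Python analogue of the compile-time C++ helper (assumed semantics, hypothetical name `exact_div`):

```python
def exact_div(a, b):
    """Divide a by b, raising if the division is not exact (mirrors exactDiv)."""
    q, r = divmod(a, b)
    if r != 0:
        raise ValueError(f"{a} is not exactly divisible by {b}")
    return q

assert exact_div(128, 16) == 8   # e.g. headElems / QuantVectorSize, exact

caught = False
try:
    exact_div(100, 16)           # a misconfiguration is caught, not truncated
except ValueError:
    caught = True
assert caught
```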
```python
if k_cache.dtype == torch.uint8:
    assert get_compute_capability(torch.device(device="cuda"))[0] in [12], (
        "XQA NVFP4 KV is only supported on SM120 GPUs"
    )
    assert k_sf_cache is not None, "K SF cache is required when NVFP4 KV is used"
    assert v_sf_cache is not None, "V SF cache is required when NVFP4 KV is used"
```
get_compute_capability uses default CUDA device instead of q.device.
Lines 307 and 313 call get_compute_capability(torch.device(device="cuda")) which queries the default CUDA device. If q resides on a non-default GPU in a multi-GPU setup, this could yield an incorrect capability check. The same pattern exists at line 300, so this is pre-existing, but worth noting since it gates the new NVFP4 path.
Suggested fix
```diff
-if k_cache.dtype == torch.uint8:
-    assert get_compute_capability(torch.device(device="cuda"))[0] in [12], (
+if k_cache.dtype == torch.uint8:
+    assert get_compute_capability(q.device)[0] in [12], (
         "XQA NVFP4 KV is only supported on SM120 GPUs"
     )
     assert k_sf_cache is not None, "K SF cache is required when NVFP4 KV is used"
     assert v_sf_cache is not None, "V SF cache is required when NVFP4 KV is used"
-if get_compute_capability(torch.device(device="cuda"))[0] not in [9, 10, 12]:
+if get_compute_capability(q.device)[0] not in [9, 10, 12]:
     raise RuntimeError("XQA is only supported on SM90, SM100, SM120 GPUs")
```

🤖 Prompt for AI Agents
In `@flashinfer/xqa.py` around lines 306 - 311, The check that calls
get_compute_capability(torch.device(device="cuda")) queries the default CUDA
device and can be wrong for tensors on other GPUs; update the capability checks
in the NVFP4 path to call get_compute_capability with the actual tensor device
(e.g. use q.device or k_cache.device) instead of torch.device(device="cuda") so
the assertions around k_cache.dtype, k_sf_cache, and v_sf_cache validate the
correct GPU for the tensor in functions/blocks referencing
get_compute_capability and the NVFP4 KV path.
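Independent of which device is queried, the gating logic under review reduces to a small predicate over the major compute capability. A pure-Python sketch with the device query stubbed out (hypothetical helper; in the real code the value would come from `get_compute_capability(q.device)`):

```python
def validate_xqa_kv(cc_major, kv_dtype, k_sf_cache, v_sf_cache):
    """Mirror of the review's intended checks (names hypothetical)."""
    if kv_dtype == "uint8":  # NVFP4-packed KV cache
        if cc_major != 12:
            raise AssertionError("XQA NVFP4 KV is only supported on SM120 GPUs")
        if k_sf_cache is None or v_sf_cache is None:
            raise AssertionError("K/V SF caches are required when NVFP4 KV is used")
    if cc_major not in (9, 10, 12):
        raise RuntimeError("XQA is only supported on SM90, SM100, SM120 GPUs")

validate_xqa_kv(12, "uint8", object(), object())  # SM120 with SFs: passes

rejected = False
try:
    validate_xqa_kv(9, "uint8", object(), object())  # NVFP4 on SM90: rejected
except AssertionError:
    rejected = True
assert rejected
```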
I have verified the functions through the SGLang Qwen model. The corresponding SGLang PR: sgl-project/sglang#18314
Signed-off-by: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com>
Force-pushed c8d4769 to 8618527
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
flashinfer/xqa.py (1)
154-160: ⚠️ Potential issue | 🟠 Major — Keep `k_sf_cache`/`v_sf_cache` optional in the public API to avoid breakage.

Adding them as required positional parameters breaks existing `xqa(...)` callers that do not use NVFP4 KV.

💡 Proposed fix
```diff
 def xqa(
     q: torch.Tensor,
     k_cache: torch.Tensor,
     v_cache: torch.Tensor,
-    k_sf_cache: Optional[torch.Tensor],
-    v_sf_cache: Optional[torch.Tensor],
+    k_sf_cache: Optional[torch.Tensor] = None,
+    v_sf_cache: Optional[torch.Tensor] = None,
     page_table: torch.Tensor,
     seq_lens: torch.Tensor,
     output: torch.Tensor,
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/xqa.py` around lines 154 - 160, The xqa function signature currently requires k_sf_cache and v_sf_cache as positional parameters which breaks callers that don't use NVFP4 KV; make these parameters optional by giving them default value None in the xqa(...) signature (keep their types as Optional[torch.Tensor]) and update the function body to handle None (e.g., skip NVFP4-specific logic when k_sf_cache or v_sf_cache is None); ensure the symbols to change are the xqa(...) function header and any internal branches that reference k_sf_cache / v_sf_cache so behavior remains unchanged for callers that pass them and is safe for callers that omit them.

csrc/xqa/mha.cu (1)
2884-2892: ⚠️ Potential issue | 🟠 Major — Use container-based stride conversion in `launchMHA` for 4-bit KV.

`launchMHAFlashInfer` already switched to container-aware stride normalization, but `launchMHA` still divides by `validElemsPerHead`. For containerized 4-bit cache this can misaddress KV pages/tokens/heads.

💡 Proposed fix
```diff
-uint32_t const stride_page_in_heads = static_cast<uint32_t>(kv_stride_page / validElemsPerHead);
-uint32_t const stride_token_in_heads = static_cast<uint32_t>(kv_stride_token / validElemsPerHead);
-uint32_t const stride_head_in_heads = static_cast<uint32_t>(kv_stride_head / validElemsPerHead);
+uint32_t const container_elems_per_head =
+    validElemsPerHead / CacheElemConverter::ElemsPerContainer;
+uint32_t const stride_page_in_heads =
+    static_cast<uint32_t>(kv_stride_page / container_elems_per_head);
+uint32_t const stride_token_in_heads =
+    static_cast<uint32_t>(kv_stride_token / container_elems_per_head);
+uint32_t const stride_head_in_heads =
+    static_cast<uint32_t>(kv_stride_head / container_elems_per_head);
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/xqa/mha.cu` around lines 2884 - 2892, In launchMHA the stride conversion still divides kv_stride_page/token/head by validElemsPerHead (producing stride_page_in_heads, stride_token_in_heads, stride_head_in_heads) which breaks addressing for containerized 4-bit KV; change the conversion to use the KVCacheList container-aware element size (the same approach used in launchMHAFlashInfer) — read the container's valid-elements-per-head or element-size helper from KVCacheList (constructed as cacheList) and use that value when normalizing kv_stride_page/token/head to head units so page/token/head strides are computed correctly for 4-bit containers.
♻️ Duplicate comments (4)
flashinfer/xqa.py (1)
306-314: ⚠️ Potential issue | 🟠 Major — Compute capability checks should use tensor device, not default CUDA device.

Line 307 and Line 313 query `torch.device("cuda")`, which can check the wrong GPU in multi-device runs.

💡 Proposed fix
```diff
-    if (
-        k_cache.dtype == torch.float8_e4m3fn
-        and get_compute_capability(torch.device(device="cuda"))[0] == 9
-    ):
+    device_cc_major = get_compute_capability(q.device)[0]
+    if k_cache.dtype == torch.float8_e4m3fn and device_cc_major == 9:
         run_sm90_fp8_mha = True
     else:
         run_sm90_fp8_mha = False
     if k_cache.dtype == torch.uint8:
-        assert get_compute_capability(torch.device(device="cuda"))[0] in [12], (
+        assert device_cc_major in [12], (
             "XQA NVFP4 KV is only supported on SM120 GPUs"
         )
         assert k_sf_cache is not None, "K SF cache is required when NVFP4 KV is used"
         assert v_sf_cache is not None, "V SF cache is required when NVFP4 KV is used"
-    if get_compute_capability(torch.device(device="cuda"))[0] not in [9, 10, 12]:
+    if device_cc_major not in [9, 10, 12]:
         raise RuntimeError("XQA is only supported on SM90, SM100, SM120/SM121 GPUs")
```

Verification script:

```bash
#!/bin/bash
set -euo pipefail
rg -n 'get_compute_capability\(torch\.device\(device="cuda"\)\)' flashinfer/xqa.py
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/xqa.py` around lines 306 - 314, The compute-capability checks call get_compute_capability(torch.device(device="cuda")) which can target the wrong GPU in multi-device runs; update those calls to use the actual tensor device (e.g., k_cache.device) instead. Specifically, replace the get_compute_capability(torch.device(device="cuda")) usages inside the k_cache.dtype == torch.uint8 branch and the subsequent supported-SM check with get_compute_capability(k_cache.device) (or the appropriate tensor like q_cache.device if q is the primary tensor) so the checks use the same CUDA device as the tensors being validated.

csrc/xqa/mha.cu (2)
2114-2116: ⚠️ Potential issue | 🔴 Critical — Undefined symbol in `loadPages` causes compilation failure.

`idxNextSMemVBuf` is not in scope in this lambda, so `getSmemVSfTile(idxNextSMemVBuf)` cannot compile when `BEAM_WIDTH > 1` and 4-bit KV is enabled.

💡 Proposed fix
```diff
 #if BEAM_WIDTH == 1
   ...
 #else
   auto& dst = smem.vCachePages[grpLoadV ? warpGrpIdx : warpIdx.x];
-#if ENABLE_4BIT_KV_CACHE
-  auto& dstSf = getSmemVSfTile(idxNextSMemVBuf);
-#endif
   loadPagesForBeamSearchAsync<grpLoadV ? gemm1WarpsPerGrp : 1U>(
       grpLoadV ? warpIdxInGrp : 0U, dst, cacheList, false, idxReq, idxPageBeg, nbPages);
 #endif
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/xqa/mha.cu` around lines 2114 - 2116, The lambda in loadPages references idxNextSMemVBuf which is out of scope when BEAM_WIDTH>1 with ENABLE_4BIT_KV_CACHE enabled; update the lambda to use a visible index or capture/pass idxNextSMemVBuf into the lambda and then call getSmemVSfTile with that captured variable (or compute the correct buffer index visible in loadPages), e.g., ensure the lambda captures idxNextSMemVBuf or replace it with the in-scope idx (used elsewhere in loadPages) before calling getSmemVSfTile to initialize dstSf.
1168-1184: ⚠️ Potential issue | 🟠 Major — Remove or consume `vSfPrefetch`; it is currently dead code.

The prefetch buffer is populated but not used in subsequent computation, adding avoidable overhead.
💡 Proposed fix (remove dead prefetch block)
```diff
-#if ENABLE_4BIT_KV_CACHE
-  // Prefetch buffer for SF
-  constexpr uint32_t nbSfPrefetchBuffers = exactDiv(warpTile.x, 16 * sizeof(uint32_t));
-  Array2D<uint32_t, exactDiv(cacheVTileSeqLen, 4), nbSfPrefetchBuffers> vSfPrefetch;
-#pragma unroll
-  for (uint32_t i = 0; i < vSfPrefetch.rows; i += 4) {
-#pragma unroll
-    for (uint32_t j = 0; j < nbSfPrefetchBuffers; j++) {
-      uint32_t const sfRowIdxInSlice = (laneId() % 4) * 4;
-      uint32_t const sfRowIdx = (i / 4) * 16 + i % 4 + sfRowIdxInSlice;
-      vSfPrefetch(i, j) = reinterpret_cast<const uint32_t&>(vSf.template at(sfRowIdx, j));
-    }
-  }
-#endif
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/xqa/mha.cu` around lines 1168 - 1184, Remove the dead prefetch buffer: delete the nbSfPrefetchBuffers declaration and the Array2D vSfPrefetch declaration plus the nested loops that populate it (the entire `#if` ENABLE_4BIT_KV_CACHE prefetch block that references nbSfPrefetchBuffers, vSfPrefetch, warpTile, cacheVTileSeqLen, vSf, and laneId()). After removal, ensure there are no remaining references to vSfPrefetch or nbSfPrefetchBuffers elsewhere; if a prefetch is actually needed, instead replace this block with a true consumer of the populated data where vSf is used.

tests/attention/test_xqa_batch_decode.py (1)
632-632: ⚠️ Potential issue | 🟡 Minor — Drop the unused `in_kv_lens` binding.

Line 632 binds `in_kv_lens` but it is not used in this test.

🧹 Proposed cleanup
```diff
-    q_lens, in_kv_lens, seq_lens = generate_seq_lens_decode(
+    q_lens, _in_kv_lens, seq_lens = generate_seq_lens_decode(
         batch_size, q_len_per_req, max_in_kv_len
     )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_xqa_batch_decode.py` at line 632, The test binds an unused variable from generate_seq_lens_decode; remove the unnecessary in_kv_lens binding and only capture the used values (e.g., q_lens and seq_lens) from the call to generate_seq_lens_decode so the test no longer declares an unused variable; update the assignment that currently reads "q_lens, in_kv_lens, seq_lens = generate_seq_lens_decode(...)" to only bind the needed symbols.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/decode.py`:
- Around line 2536-2548: Ensure NVFP4 scale-factor requirements are enforced and
fix the single-tensor branch error message: if the runtime/backend requires
NVFP4 KV scale-factor tensors, do not accept kv_cache_sf == None and raise a
clear ValueError; when kv_cache_sf is a tuple, validate it has exactly two
elements before unpacking into k_cache_sf and v_cache_sf; when kv_cache_sf is a
single tensor assert kv_cache_sf.shape[1] == 2 (change the message to state
"must be 2" rather than "1 or 2") and then unbind as before to assign
k_cache_sf, v_cache_sf.
In `@tests/attention/test_xqa_batch_decode.py`:
- Around line 164-173: The V cache is being dequantized with the K global scale
causing incorrect reference KV values; in the block handling kv_dtype == "nvfp4"
(variables and helpers: k_cache, v_cache, to_nvfp4, nvfp4_to_float, k_scale,
v_scale, k_global_scale, v_global_scale, ref_kv_cache) change the dequantization
call for v_cache to use v_global_scale instead of k_global_scale so v_cache_dq =
nvfp4_to_float(v_cache, v_scale, v_global_scale) before stacking into
ref_kv_cache.
---
Outside diff comments:
In `@csrc/xqa/mha.cu`:
- Around line 2884-2892: In launchMHA the stride conversion still divides
kv_stride_page/token/head by validElemsPerHead (producing stride_page_in_heads,
stride_token_in_heads, stride_head_in_heads) which breaks addressing for
containerized 4-bit KV; change the conversion to use the KVCacheList
container-aware element size (the same approach used in launchMHAFlashInfer) —
read the container's valid-elements-per-head or element-size helper from
KVCacheList (constructed as cacheList) and use that value when normalizing
kv_stride_page/token/head to head units so page/token/head strides are computed
correctly for 4-bit containers.
In `@flashinfer/xqa.py`:
- Around line 154-160: The xqa function signature currently requires k_sf_cache
and v_sf_cache as positional parameters which breaks callers that don't use
NVFP4 KV; make these parameters optional by giving them default value None in
the xqa(...) signature (keep their types as Optional[torch.Tensor]) and update
the function body to handle None (e.g., skip NVFP4-specific logic when
k_sf_cache or v_sf_cache is None); ensure the symbols to change are the xqa(...)
function header and any internal branches that reference k_sf_cache / v_sf_cache
so behavior remains unchanged for callers that pass them and is safe for callers
that omit them.
---
Duplicate comments:
In `@csrc/xqa/mha.cu`:
- Around line 2114-2116: The lambda in loadPages references idxNextSMemVBuf
which is out of scope when BEAM_WIDTH>1 with ENABLE_4BIT_KV_CACHE enabled;
update the lambda to use a visible index or capture/pass idxNextSMemVBuf into
the lambda and then call getSmemVSfTile with that captured variable (or compute
the correct buffer index visible in loadPages), e.g., ensure the lambda captures
idxNextSMemVBuf or replace it with the in-scope idx (used elsewhere in
loadPages) before calling getSmemVSfTile to initialize dstSf.
- Around line 1168-1184: Remove the dead prefetch buffer: delete the
nbSfPrefetchBuffers declaration and the Array2D vSfPrefetch declaration plus the
nested loops that populate it (the entire `#if` ENABLE_4BIT_KV_CACHE prefetch
block that references nbSfPrefetchBuffers, vSfPrefetch, warpTile,
cacheVTileSeqLen, vSf, and laneId()). After removal, ensure there are no
remaining references to vSfPrefetch or nbSfPrefetchBuffers elsewhere; if a
prefetch is actually needed, instead replace this block with a true consumer of
the populated data where vSf is used.
In `@flashinfer/xqa.py`:
- Around line 306-314: The compute-capability checks call
get_compute_capability(torch.device(device="cuda")) which can target the wrong
GPU in multi-device runs; update those calls to use the actual tensor device
(e.g., k_cache.device) instead. Specifically, replace the
get_compute_capability(torch.device(device="cuda")) usages inside the
k_cache.dtype == torch.uint8 branch and the subsequent supported-SM check with
get_compute_capability(k_cache.device) (or the appropriate tensor like
q_cache.device if q is the primary tensor) so the checks use the same CUDA
device as the tensors being validated.
In `@tests/attention/test_xqa_batch_decode.py`:
- Line 632: The test binds an unused variable from generate_seq_lens_decode;
remove the unnecessary in_kv_lens binding and only capture the used values
(e.g., q_lens and seq_lens) from the call to generate_seq_lens_decode so the
test no longer declares an unused variable; update the assignment that currently
reads "q_lens, in_kv_lens, seq_lens = generate_seq_lens_decode(...)" to only
bind the needed symbols.
ℹ️ Review info
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (12)
- csrc/flashinfer_xqa_binding.cu
- csrc/xqa/defines.h
- csrc/xqa/mha.cu
- csrc/xqa/mha.h
- csrc/xqa/mhaUtils.cuh
- csrc/xqa/utils.cuh
- csrc/xqa/xqa_wrapper.cu
- flashinfer/decode.py
- flashinfer/fp4_quantization.py
- flashinfer/jit/xqa.py
- flashinfer/xqa.py
- tests/attention/test_xqa_batch_decode.py
🚧 Files skipped from review as they are similar to previous changes (2)
- csrc/xqa/xqa_wrapper.cu
- csrc/xqa/defines.h
/bot run
@flashinfer-bot is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww
/bot run
yzh119 left a comment:
LGTM overall, left some comments.
flashinfer/xqa.py (Outdated)
> Data type should match query tensor or be torch.float8_e4m3fn, in which case xqa will run fp8 calculation.
> Should be the same data type as k_cache.
> k_sf_cache: Optional[torch.Tensor]
>     Optional scale factor cache tensor for the K cache. Use when NVFP4 KV is used.
Please also annotate the expected shape of scaling factors.
> Should be the same data type as k_cache.
> k_sf_cache: Optional[torch.Tensor]
>     Optional scale factor cache tensor for the K cache. Use when NVFP4 KV is used.
> v_sf_cache: Optional[torch.Tensor]
[FAILED] Pipeline #45302193: 6/20 passed
@yzh119 all comments are resolved; would you please take a look and trigger CI?
/bot run
[FAILED] Pipeline #45502683: 6/20 passed
♻️ Duplicate comments (1)
flashinfer/xqa.py (1)
308-313: ⚠️ Potential issue | 🟡 Minor — `get_compute_capability` uses default CUDA device instead of tensor device.

The capability checks use `torch.device(device="cuda")`, which queries the default CUDA device. In multi-GPU setups where tensors reside on non-default GPUs, this could yield incorrect results.

Suggested fix
```diff
     if k_cache.dtype == torch.uint8:
-        assert get_compute_capability(torch.device(device="cuda"))[0] in [12], (
+        assert get_compute_capability(k_cache.device)[0] in [12], (
             "XQA NVFP4 KV is only supported on SM120 GPUs"
         )
-    if get_compute_capability(torch.device(device="cuda"))[0] not in [9, 10, 12]:
+    if get_compute_capability(q.device)[0] not in [9, 10, 12]:
         raise RuntimeError("XQA is only supported on SM90, SM100, SM120/SM121 GPUs")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/xqa.py` around lines 308 - 313, The compute-capability check incorrectly queries the default CUDA device; change the get_compute_capability call to use the actual tensor/device where k_cache lives (e.g., pass torch.device(k_cache.device) or the existing local variable device that points to the tensor's device) instead of torch.device(device="cuda") so the SM check (in the block guarding k_cache.dtype == torch.uint8) uses the correct GPU for multi-GPU setups; update the get_compute_capability(...) call accordingly and keep the surrounding assertions (k_sf_cache, v_sf_cache) intact.
🧹 Nitpick comments (1)
tests/attention/test_xqa_batch_decode.py (1)
630-634: Prefix unused `in_kv_lens` with underscore.

Static analysis flagged `in_kv_lens` as unused. Prefix with underscore to indicate intentional discard.

Suggested fix
```diff
-    q_lens, in_kv_lens, seq_lens = generate_seq_lens_decode(
+    q_lens, _in_kv_lens, seq_lens = generate_seq_lens_decode(
         batch_size, q_len_per_req, max_in_kv_len
     )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_xqa_batch_decode.py` around lines 630 - 634, The variable in_kv_lens returned from generate_seq_lens_decode is unused and should be marked as intentionally discarded; change the assignment to unpack it into a prefixed-underscore name (e.g. _in_kv_lens) while keeping q_lens and seq_lens as-is so the call to generate_seq_lens_decode and uses of q_lens/seq_lens remain unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@flashinfer/xqa.py`:
- Around line 308-313: The compute-capability check incorrectly queries the
default CUDA device; change the get_compute_capability call to use the actual
tensor/device where k_cache lives (e.g., pass torch.device(k_cache.device) or
the existing local variable device that points to the tensor's device) instead
of torch.device(device="cuda") so the SM check (in the block guarding
k_cache.dtype == torch.uint8) uses the correct GPU for multi-GPU setups; update
the get_compute_capability(...) call accordingly and keep the surrounding
assertions (k_sf_cache, v_sf_cache) intact.
---
Nitpick comments:
In `@tests/attention/test_xqa_batch_decode.py`:
- Around line 630-634: The variable in_kv_lens returned from
generate_seq_lens_decode is unused and should be marked as intentionally
discarded; change the assignment to unpack it into a prefixed-underscore name
(e.g. _in_kv_lens) while keeping q_lens and seq_lens as-is so the call to
generate_seq_lens_decode and uses of q_lens/seq_lens remain unchanged.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 439dcace-8e6a-48ee-8fd5-7f0a747ab41f
📒 Files selected for processing (4)
- csrc/xqa/mha.cu
- flashinfer/xqa.py
- tests/attention/test_xqa.py
- tests/attention/test_xqa_batch_decode.py
/bot run
[SUCCESS] Pipeline #45680141: 10/20 passed
## 📌 Description

Supports NVFP4 KV cache.

## 🚀 Pull Request Checklist

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **New Features**
  * xqa and decode APIs now accept optional NVFP4 (4-bit) KV-cache scaling tensors (k_sf_cache/v_sf_cache / kv_cache_sf); torch.uint8 KV-cache dtype supported with runtime SM 12.x guard; docstrings updated.
* **Tests**
  * Added/extended tests exercising NVFP4 KV-cache quantization, scaling factors, uint8 KV-cache paths, and decode integration.

Signed-off-by: Tian Zheng <29906817+Tom-Zheng@users.noreply.github.com>
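For intuition on what the NVFP4 tensors hold: each uint8 byte packs two E2M1 values, and each 16-element block carries an FP8 E4M3 scale on top of a global FP32 scale. A pure-Python reference sketch of the dequantization (nibble order and the linear scale layout here are assumptions; the actual kernels use a swizzled scale-factor cache):

```python
# Pure-Python reference for NVFP4 (E2M1) dequantization with two-level
# scaling. Assumed for illustration: low nibble holds the even element,
# one block scale per 16 values, scales given as plain floats rather
# than packed FP8-E4M3 bytes.

E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def e2m1_to_float(code: int) -> float:
    """Decode one 4-bit E2M1 code: 1 sign, 2 exponent, 1 mantissa bit."""
    sign = -1.0 if code & 0x8 else 1.0
    return sign * E2M1_MAGNITUDES[code & 0x7]

def nvfp4_dequantize_ref(packed: bytes, block_scales, global_scale: float):
    """Unpack two FP4 values per byte, scaling each 16-element block."""
    out = []
    for byte in packed:
        for code in (byte & 0xF, byte >> 4):  # low nibble first (assumed)
            scale = block_scales[len(out) // 16] * global_scale
            out.append(e2m1_to_float(code) * scale)
    return out
```

For example, `nvfp4_dequantize_ref(bytes([0x21]), [1.0], 1.0)` yields `[0.5, 1.0]`, matching the `[M, K/2]` packed data and `[M, K/16]` block-scale shapes described by `nvfp4_kv_quantize` below.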
## 📌 Description

Fix API breaking changes for the 0.6.7 release.

## 🔍 Related Issues (Gated-by PRs)

https://github.com/flashinfer-ai/flashinfer/issues?q=state%3Aopen%20label%3Av0.6.7

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

**API changes review.** API changes since v0.6.6:

**PR #2520 + commit e35c19e (fixed to be compatible)**
- `xqa()`: added `k_sf_cache=None`, `v_sf_cache=None` as keyword-only params (after `*`). Backward-compatible.

**PR #2618 (has PR #2730 to fix it)**
- `gated_delta_rule_mtp()`: `disable_state_update: bool = True` → `Optional[bool] = None`. Still defaults to True at runtime but emits a deprecation warning; will flip to False in 0.7.0.

**PR #2775 (expected — cute DSL MoE cleanup)**
- `blockscaled_contiguous_grouped_gemm_nvfp4()`: entire `@flashinfer_api` decorated function deleted.
- `blockscaled_contiguous_grouped_gemm_swiglu_fusion_nvfp4()`: entire `@flashinfer_api` decorated function deleted.
- `blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4()`: `@flashinfer_api` decorator removed; added `enable_pdl: bool = True` param.
- `blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4()`: `@flashinfer_api` decorator removed; added `enable_pdl: bool = True` param.
- `CuteDslMoEWrapper.__init__()`: added `enable_pdl: bool = True` param. Backward-compatible.
- `cute_dsl_fused_moe_nvfp4()`: added `enable_pdl: bool = True` param. Backward-compatible.

**PR #2428**
- `rmsnorm_quant()`: `scale: float` → `scale: Union[float, torch.Tensor]`; return type `torch.Tensor` → `None`.
- `fused_add_rmsnorm_quant()`: `scale: float` → `scale: Union[float, torch.Tensor]`.

**Quantization functions (relocated, not removed)**

All quantization APIs (`fp4_quantize`, `block_scale_interleave`, `e2m1_and_ufp8sf_scale_to_float`, `shuffle_matrix_a`, `shuffle_matrix_sf_a`, `nvfp4_quantize`, `nvfp4_batched_quantize`, `scaled_fp4_grouped_quantize`, `mxfp4_quantize`, `mxfp4_dequantize`, `mxfp4_dequantize_host`, `mxfp8_quantize`, `mxfp8_dequantize_host`) were moved from `flashinfer/fp4_quantization.py` and `flashinfer/fp8_quantization.py` to `flashinfer/quantization/`. Signatures, `@flashinfer_api` decorators, and `__init__.py` exports are preserved. No breakage.

```diff
$ git diff v0.6.6 | grep -A20 "@flashinfer_api"
 @flashinfer_api
@@ -1215,6 +1227,9 @@ class BatchDecodeWithPagedKVCacheWrapper:
         sinks: Optional[torch.Tensor] = None,
         q_len_per_req: Optional[int] = 1,
         skip_softmax_threshold_scale_factor: Optional[float] = None,
+        kv_block_scales: Optional[
+            Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
+        ] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         r"""Compute batch decode attention between query and paged kv cache.
```
@@ -1273,6 +1288,15 @@ class BatchDecodeWithPagedKVCacheWrapper: enable_pdl = device_support_pdl(q.device) k_cache, v_cache = _unpack_paged_kv_cache(paged_kv_cache, self._kv_layout) + # Unpack kv_block_scales + key_block_scales = None + value_block_scales = None + if kv_block_scales is not None: + if isinstance(kv_block_scales, tuple): + key_block_scales, value_block_scales = kv_block_scales -- -@flashinfer_api -def fp4_quantize( - input: torch.Tensor, - global_scale: Optional[torch.Tensor] = None, - sf_vec_size: int = 16, - sf_use_ue8m0: bool = False, - is_sf_swizzled_layout: bool = True, - is_sf_8x4_layout: bool = False, - enable_pdl: Optional[bool] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Quantize input tensor to FP4 format. - - This function implements FP4 quantization that converts input tensors to a compressed FP4 format - with associated scale factors. It supports various input data types and scale factor layouts. - - Args: - input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized. - global_scale (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32. - sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. - sf_use_ue8m0 (bool, optional): Whether to use UE8M0 format for scale factors. Defaults to False. - is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True. -- -@flashinfer_api -def block_scale_interleave(unswizzled_sf: torch.Tensor) -> torch.Tensor: - """Swizzle block scale tensor for FP4 format. - - This function swizzles the block scale tensor to optimize memory access patterns - for FP4 operations. The output needs to be padded in the m dimension to be a multiple of 128. - - Args: - unswizzled_sf (torch.Tensor): Input tensor with dtype uint8 or bfloat16. - - Returns: - torch.Tensor: Swizzled tensor with the same shape as input. - - Raises: - AssertionError: If input dtype is not uint8 or bfloat16. 
- """ - # TODO(shuw): check input dtype is uint8 - assert ( - unswizzled_sf.dtype == torch.uint8 or unswizzled_sf.dtype == torch.bfloat16 - ), f"Input dtype must be uint8 or bfloat16, got {unswizzled_sf.dtype}" - -- -@flashinfer_api -def e2m1_and_ufp8sf_scale_to_float( - e2m1_tensor: torch.Tensor, - ufp8_scale_tensor: torch.Tensor, - global_scale_tensor: Optional[torch.Tensor] = None, - sf_vec_size: int = 16, - ufp8_type: int = 1, - is_sf_swizzled_layout: bool = True, -) -> torch.Tensor: - """Convert E2M1 format tensor and UFP8 scale factors to float tensor. - - This function performs dequantization by converting a packed FP4 tensor in E2M1 format - back to float values using the associated UFP8 scale factors and global scale. - - Args: - e2m1_tensor (torch.Tensor): Packed FP4 tensor in E2M1 format of shape [M, K/2] with dtype uint8. - ufp8_scale_tensor (torch.Tensor): Scale factors tensor in UFP8 format with dtype uint8. - global_scale_tensor (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32. - sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. - ufp8_type (int, optional): UFP8 scale factor type (0 for UE8M0, 1 for E4M3). Defaults to 1. - is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True. -- -@flashinfer_api -def shuffle_matrix_a(input_tensor: torch.Tensor, epilogue_tile_m: int) -> torch.Tensor: - """ - PyTorch equivalent of trtllm-gen `shuffleMatrixA` - """ - row_indices = get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m) - - return input_tensor[row_indices.to(input_tensor.device)] - - -@flashinfer_api -def shuffle_matrix_sf_a( - input_tensor: torch.Tensor, - epilogue_tile_m: int, - num_elts_per_sf: int = 16, -): - """ - Cuda implementation of trtllm-gen `shuffleMatrixSfA` but with a caveat. - `shuffleMatrixSfA` expects the input to be in 128x4 layout and then - apply the same shuffling in `shuffleMatrixA` and writes out in 128x4 - layout. 
- This function expects the input to be in linear layout. It's done this - way because the scaling factors in the NVFP4 checkpoints are quantized - and are in linear layout. - This function doesn't add padding. - """ - - row_indices = get_shuffle_matrix_sf_a_row_indices(input_tensor, epilogue_tile_m) - - w_shuffled = input_tensor[row_indices.to(input_tensor.device)] - -- -@flashinfer_api -def nvfp4_quantize( - a, - a_global_sf, - sfLayout=SfLayout.layout_128x4, - do_shuffle=False, - sf_vec_size=16, - enable_pdl=None, -): - """ - Quantize input tensor to NVFP4 format. - - Parameters: - a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16. - a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32. - sfLayout (SfLayout, optional): Scale factor layout. Defaults to SfLayout.layout_128x4. - do_shuffle (bool, optional): Whether to shuffle the scale factors. Defaults to False. Only TRTLLM backend needs to shuffle the tensor B scale factors. - sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. - enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch). - If None, automatically detects based on device capability. Defaults to None. - -- -@flashinfer_api -def mxfp4_quantize(a): - """ - Quantize input tensor to MXFP4 format. - - Parameters: - a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2) - - Scale factors tensor with shape determined by layout and sf_vec_size (uint8) - """ - a_global_sf = (448 * 6) / a.float().abs().nan_to_num().max() - a_fp4, a_sf = fp4_quantize(a.cuda(), a_global_sf.cuda(), 32, True, True) - return a_fp4, a_sf - - -@flashinfer_api -def mxfp4_dequantize(a_fp4, a_sf): - """ - Dequantize input tensor from MXFP4 format. 
- - Parameters: - a_fp4 (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2) - a_sf (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8) - - Returns: - torch.Tensor: Dequantized tensor of shape [M, K] with dtype float. - """ - return e2m1_and_ufp8sf_scale_to_float( - a_fp4.cpu().view(torch.uint8), - a_sf.cpu().view(torch.uint8).reshape(-1), - torch.tensor([1.0], device=a_fp4.device), - 32, - 0, - True, - ) - -- -@flashinfer_api -def mxfp4_dequantize_host( - weight: torch.Tensor, - scale: torch.Tensor, - group_size: int = 32, -) -> torch.Tensor: - """ - Dequantize input tensor from MXFP4 format on host. - - Parameters: - weight (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2) - scale (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8) - group_size (int, optional): Group size for dequantization. Defaults to 32. - - Returns: - torch.Tensor: Dequantized tensor of shape [M, K] with dtype float. - """ - # NOTE(Zihao): the cpu op should be decouplied from cuda ops because it's device independent, should refactor this in the future - major, minor = get_compute_capability( - torch.device("cuda:0") - ) # use any cuda device to get a compute capability -- -@flashinfer_api -def nvfp4_batched_quantize( - a, - a_global_sf, - sf_vec_size=16, -): - """ - Quantize batched input tensor to NVFP4 format. - - Parameters: - a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16. - a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32. - sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. 
- - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2 - - Scale factors tensor with shape determined by layout and sf_vec_size - """ - major, minor = get_compute_capability(a.device) - device_arch = f"{major * 10 + minor}" -- -@flashinfer_api -def scaled_fp4_grouped_quantize( - a, - mask, - a_global_sf, -): - """ - quantize batched input tensor to NVFP4 format with mask. - Parameters: - a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16. - a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32. - mask (torch.Tensor): Mask tensor to apply before quantization. - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2 - - Scale factors tensor with shape determined by layout and sf_vec_size - """ - major, minor = get_compute_capability(a.device) - device_arch = f"{major * 10 + minor}" - a_fp4, a_sf = get_fp4_quantization_module( - device_arch -- -@flashinfer_api -def mxfp8_quantize( - input: torch.Tensor, - is_sf_swizzled_layout: bool = True, - alignment: int = 32, - enable_pdl: Optional[bool] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Quantize input tensor to MxFP8 format. - - This function implements MxFP8 quantization that converts input tensors to a compressed MxFP8 format - with associated scale factors. It supports various input data types and scale factor layouts. - - Args: - input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized. - is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True. - alignment (int, optional): sfVecSize. Defaults to 32. - enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch). - If None, automatically detects based on device capability. Defaults to None. 
- Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - Quantized tensor of shape [M, K] with dtype FLOAT8_E4M3 -- -@flashinfer_api -def mxfp8_dequantize_host( - input: torch.Tensor, - scale_tensor: torch.Tensor, - is_sf_swizzled_layout: bool = True, -) -> torch.Tensor: - """Dequantize input tensor from MxFP8 format. - - This function performs dequantization by converting a packed FP8 tensor in MxFP8 format - back to float values using the associated scale factors. - - Args: - input (torch.Tensor): Packed FP8 tensor in MxFP8 format of shape [M, K] with dtype FLOAT8_E4M3. - scale_tensor (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size. - is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True. - - Returns: - torch.Tensor: Dequantized float tensor of shape [M, K] with dtype float32. - - """ - -- -@flashinfer_api def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4( a: torch.Tensor, b: torch.Tensor, @@ -323,6 +324,7 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4( vectorized_f32: bool = True, raster_along_m: bool = False, sm_count: Optional[int] = None, + enable_pdl: bool = True, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Blockscaled Contiguous Gather Grouped GEMM with SwiGLU Fusion for MoE workloads. @@ -423,7 +425,7 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4( major, minor = get_compute_capability(a.device) if major != 10: raise ValueError( - f"Blockscaled contiguous gather grouped GEMM with SwiGLU requires SM100 family (Blackwell: SM100, SM103, SM110). " + f"Blockscaled contiguous gather grouped GEMM with SwiGLU requires SM100 family (Blackwell: SM100, SM103). " f"Got SM{major}{minor}." 
) -- -@flashinfer_api -def blockscaled_contiguous_grouped_gemm_nvfp4( - a: torch.Tensor, - b: torch.Tensor, - a_scale: torch.Tensor, - b_scale: torch.Tensor, - alpha: torch.Tensor, - tile_idx_to_group_idx: torch.Tensor, - num_non_exiting_tiles: torch.Tensor, - out: Optional[torch.Tensor] = None, - *, - ab_dtype: str = "float4_e2m1fn", - sf_dtype: str = "float8_e4m3fn", - c_dtype: str = "bfloat16", - sf_vec_size: int = 16, - mma_tiler_mn: Tuple[int, int] = (128, 128), - cluster_shape_mn: Tuple[int, int] = (1, 1), - sm_count: Optional[int] = None, -) -> torch.Tensor: - """Blockscaled Contiguous Grouped GEMM for MoE workloads with NVFP4 quantization. - -- -@flashinfer_api def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4( a: torch.Tensor, b: torch.Tensor, @@ -272,6 +279,7 @@ def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4( cluster_shape_mn: Tuple[int, int] = (2, 1), raster_along_m: bool = False, sm_count: Optional[int] = None, + enable_pdl: bool = True, ) -> torch.Tensor: """Blockscaled Contiguous Grouped GEMM with Finalize Fusion for MoE workloads. @@ -298,7 +306,11 @@ def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4( expanded_idx = token_idx * topk + topk_idx. Invalid rows have -1. token_final_scales: Router scaling factors, shape (seq_len, topk), float32/bf16/fp16 out: Optional output tensor, shape (seq_len, n). Created if None. - This tensor is used for atomic accumulation, so it should be zero-initialized. + This tensor is used for atomic accumulation. If `out` is + provided, it must already be zero-initialized by the caller. + If `out` is None, this function allocates a zero-initialized + output tensor. 
Passing a non-zeroed `out` buffer will silently -- -@flashinfer_api -def blockscaled_contiguous_grouped_gemm_swiglu_fusion_nvfp4( - a: torch.Tensor, - b: torch.Tensor, - a_scale: torch.Tensor, - b_scale: torch.Tensor, - alpha: torch.Tensor, - tile_idx_to_group_idx: torch.Tensor, - num_non_exiting_tiles: torch.Tensor, - out: Optional[torch.Tensor] = None, - out_scale: Optional[torch.Tensor] = None, - global_scale: Optional[torch.Tensor] = None, - *, - ab_dtype: str = "float4_e2m1fn", - sf_dtype: str = "float8_e4m3fn", - c_dtype: str = "bfloat16", - sf_vec_size: int = 16, - mma_tiler_mn: Tuple[int, int] = (256, 128), - cluster_shape_mn: Tuple[int, int] = (2, 1), - vectorized_f32: bool = True, - sm_count: Optional[int] = None, -- @flashinfer_api def __init__( self, @@ -347,6 +355,7 @@ class CuteDslMoEWrapper: sf_vec_size: int = 16, output_dtype: torch.dtype = torch.bfloat16, device: str = "cuda", + enable_pdl: bool = True, ): """Initialize the MoE wrapper. @@ -363,6 +372,7 @@ class CuteDslMoEWrapper: sf_vec_size: Scale factor vector size. Default: 16. output_dtype: Output data type. Default: torch.bfloat16. device: Device for buffer allocation. Default: "cuda". + enable_pdl: Enable Programmatic Dependent Launch. Default: True. """ self.num_experts = num_experts self.top_k = top_k @@ -376,6 +386,7 @@ class CuteDslMoEWrapper: self.sf_vec_size = sf_vec_size -- @flashinfer_api @@ -550,9 +570,10 @@ class CuteDslMoEWrapper: f"num_tokens ({num_tokens}) exceeds max_num_tokens ({self.max_num_tokens})" ) - # Allocate output buffer if not using pre-allocated one + # Slice the pre-allocated buffer to the active batch so that + # _moe_core_impl only zeros num_tokens rows, not max_num_tokens. 
if self.use_cuda_graph: - moe_output = self._moe_output + moe_output = self._moe_output[:num_tokens] else: moe_output = torch.empty( (num_tokens, self.hidden_size), @@ -627,6 +648,7 @@ def _cute_dsl_fused_moe_nvfp4_impl( use_fused_finalize: bool = True, moe_output: Optional[torch.Tensor] = None, aux_stream: Optional[torch.cuda.Stream] = None, + enable_pdl: bool = True, ) -> torch.Tensor: """Internal implementation called by auto-tuner for functional API.""" -- @flashinfer_api def cute_dsl_fused_moe_nvfp4( x: torch.Tensor, @@ -678,9 +702,12 @@ def cute_dsl_fused_moe_nvfp4( use_fused_finalize: bool = True, moe_output: Optional[torch.Tensor] = None, aux_stream: Optional[torch.cuda.Stream] = None, + enable_pdl: bool = True, ) -> torch.Tensor: """Run fused MoE computation using CuteDSL NVFP4 kernels. + Supported architectures: SM100, SM103. + This is the simple functional API. For CUDA graph support, use `CuteDslMoEWrapper` instead. @@ -736,6 +763,7 @@ def cute_dsl_fused_moe_nvfp4( local_expert_offset=local_expert_offset, use_fused_finalize=use_fused_finalize, output_dtype=output_dtype, + enable_pdl=enable_pdl, -- @flashinfer_api def gated_delta_rule_decode_pretranspose( q: torch.Tensor, @@ -1002,8 +174,9 @@ def gated_delta_rule_decode_pretranspose( - State layout is v-major (K-last): [B, HV, V, K]. When state is bfloat16 and T in 1..4 with K=V=128, the gdn_decode_klast_bf16_state kernel is used (supports both the direct ``state`` path and the pool+indices path). - - pool+indices (``initial_state``/``initial_state_indices``) only supported - via the bf16 fast path; float32 state raises an error. + - pool+indices (``initial_state``/``initial_state_indices``) supported on + both the bf16 fast path (T in 1..4, K=V=128) and the float32 legacy path + (T=1). The float32 path also supports negative indices for padding. - Legacy path (float32 state, T=1): K and V must be multiples of 4. 
""" # Validate input shapes @@ -1069,13 +242,17 @@ def gated_delta_rule_decode_pretranspose( return_state = initial_state if use_pool else state return output, return_state - # Legacy path: T=1 only, float32 state (no pool+indices support) - assert not use_pool, ( -- @flashinfer_api def gated_delta_rule_mtp( q: torch.Tensor, @@ -2427,7 +489,7 @@ def gated_delta_rule_mtp( scale: Optional[float] = None, output: Optional[torch.Tensor] = None, intermediate_states_buffer: Optional[torch.Tensor] = None, - disable_state_update: bool = True, + disable_state_update: Optional[bool] = None, use_qk_l2norm: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -2463,8 +525,15 @@ def gated_delta_rule_mtp( intermediate_states_buffer (Optional[torch.Tensor]): Buffer for caching intermediate states, shape ``[pool_size, T, HV, V, K]``. If None, intermediate states are not cached. - disable_state_update (bool): - If True, the initial state is not updated. Default: ``True``. + disable_state_update (Optional[bool]): + If True, the initial state is not updated. Currently defaults to ``True``. + Please pass this argument explicitly — the default will change to ``False`` -- @flashinfer_api @@ -60,16 +120,14 @@ def rmsnorm( output: torch.Tensor Normalized tensor, 2D shape (batch_size, hidden_size) or 3D shape (batch_size, num_heads, hidden_size). """ - if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) if out is None: out = torch.empty_like(input) - _rmsnorm(out, input, weight, eps, enable_pdl) + _rmsnorm_impl(out, input, weight, eps, enable_pdl) return out @register_custom_op("flashinfer::rmsnorm", mutates_args=("out",)) -def _rmsnorm( +def _rmsnorm_impl( out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, @@ -78,11 +136,21 @@ def _rmsnorm( -- @flashinfer_api def fmha_v2_prefill_deepseek( query: torch.Tensor, @@ -3865,18 +4029,11 @@ def fmha_v2_prefill_deepseek( If return_lse is False, the output will be a single tensor. 
""" if not is_sm12x_supported(query.device): - major, minor = get_compute_capability(query.device) - if major == 12: - min_cuda = "13.0" if minor >= 1 else "12.8" - raise ValueError( - f"fmha_v2_prefill_deepseek requires CUDA >= {min_cuda} " - f"for SM12{minor}x GPUs." - ) raise ValueError("fmha_v2_prefill_deepseek is only supported on SM12x GPUs.") assert query.shape[3] == 192 and key.shape[3] == 192 and value.shape[3] == 128, ( "currently only support deepseek r1 192 query and 128 value" ) - module = get_trtllm_fmha_v2_module() + module = get_trtllm_fmha_v2_sm120_module() is_e4m3 = query.dtype == torch.float8_e4m3fn -- +@flashinfer_api +def trtllm_fmha_v2_prefill( + qkv: Union[ + torch.Tensor, + Tuple[torch.Tensor, torch.Tensor], + Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + ], + input_layout: str, + workspace_buffer: torch.Tensor, + seq_lens: torch.Tensor, + max_q_len: int, + max_kv_len: int, + bmm1_scale: float, + bmm2_scale: float, + batch_size: int, + cum_seq_lens_q: torch.Tensor, + cum_seq_lens_kv: torch.Tensor, + block_tables: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + out_dtype: Optional[Union[torch.dtype, str]] = None, + sinks: Optional[List[torch.Tensor]] = None, -- +@flashinfer_api +def fp4_quantize( + input: torch.Tensor, + global_scale: Optional[torch.Tensor] = None, + sf_vec_size: int = 16, + sf_use_ue8m0: bool = False, + is_sf_swizzled_layout: bool = True, + is_sf_8x4_layout: bool = False, + enable_pdl: Optional[bool] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Quantize input tensor to FP4 format. + + This function implements FP4 quantization that converts input tensors to a compressed FP4 format + with associated scale factors. It supports various input data types and scale factor layouts. + + Args: + input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized. + global_scale (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32. 
+ sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. + sf_use_ue8m0 (bool, optional): Whether to use UE8M0 format for scale factors. Defaults to False. + is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True. -- +@flashinfer_api +def block_scale_interleave(unswizzled_sf: torch.Tensor) -> torch.Tensor: + """Swizzle block scale tensor for FP4 format. + + This function swizzles the block scale tensor to optimize memory access patterns + for FP4 operations. The output needs to be padded in the m dimension to be a multiple of 128. + + Args: + unswizzled_sf (torch.Tensor): Input tensor with dtype uint8 or bfloat16. + + Returns: + torch.Tensor: Swizzled tensor with the same shape as input. + + Raises: + AssertionError: If input dtype is not uint8 or bfloat16. + """ + # TODO(shuw): check input dtype is uint8 + assert ( + unswizzled_sf.dtype == torch.uint8 or unswizzled_sf.dtype == torch.bfloat16 + ), f"Input dtype must be uint8 or bfloat16, got {unswizzled_sf.dtype}" + -- +@flashinfer_api +def e2m1_and_ufp8sf_scale_to_float( + e2m1_tensor: torch.Tensor, + ufp8_scale_tensor: torch.Tensor, + global_scale_tensor: Optional[torch.Tensor] = None, + sf_vec_size: int = 16, + ufp8_type: int = 1, + is_sf_swizzled_layout: bool = True, +) -> torch.Tensor: + """Convert E2M1 format tensor and UFP8 scale factors to float tensor. + + This function performs dequantization by converting a packed FP4 tensor in E2M1 format + back to float values using the associated UFP8 scale factors and global scale. + + Args: + e2m1_tensor (torch.Tensor): Packed FP4 tensor in E2M1 format of shape [M, K/2] with dtype uint8. + ufp8_scale_tensor (torch.Tensor): Scale factors tensor in UFP8 format with dtype uint8. + global_scale_tensor (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32. + sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. 
+ ufp8_type (int, optional): UFP8 scale factor type (0 for UE8M0, 1 for E4M3). Defaults to 1. + is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True. -- +@flashinfer_api +def shuffle_matrix_a(input_tensor: torch.Tensor, epilogue_tile_m: int) -> torch.Tensor: + """ + PyTorch equivalent of trtllm-gen `shuffleMatrixA` + """ + row_indices = get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m) + + return input_tensor[row_indices.to(input_tensor.device)] + + +@flashinfer_api +def shuffle_matrix_sf_a( + input_tensor: torch.Tensor, + epilogue_tile_m: int, + num_elts_per_sf: int = 16, +): + """ + Cuda implementation of trtllm-gen `shuffleMatrixSfA` but with a caveat. + `shuffleMatrixSfA` expects the input to be in 128x4 layout and then + apply the same shuffling in `shuffleMatrixA` and writes out in 128x4 + layout. + This function expects the input to be in linear layout. It's done this + way because the scaling factors in the NVFP4 checkpoints are quantized + and are in linear layout. + This function doesn't add padding. + """ + + row_indices = get_shuffle_matrix_sf_a_row_indices(input_tensor, epilogue_tile_m) + + w_shuffled = input_tensor[row_indices.to(input_tensor.device)] + -- +@flashinfer_api +def nvfp4_quantize( + a, + a_global_sf, + sfLayout=SfLayout.layout_128x4, + do_shuffle=False, + sf_vec_size=16, + enable_pdl=None, +): + """ + Quantize input tensor to NVFP4 format. + + Parameters: + a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16. + a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32. + sfLayout (SfLayout, optional): Scale factor layout. Defaults to SfLayout.layout_128x4. + do_shuffle (bool, optional): Whether to shuffle the scale factors. Defaults to False. Only TRTLLM backend needs to shuffle the tensor B scale factors. + sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. 
+        enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch).
+            If None, automatically detects based on device capability. Defaults to None.
+
--
+@flashinfer_api
+def mxfp4_quantize(
+    a: torch.Tensor,
+    backend: str = "cuda",
+    enable_pdl: Optional[bool] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to MXFP4 format.
+
+    Parameters:
+        a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16.
+        backend (str, optional): Backend to use for quantization.
+            - "cuda": Use CUDA kernel (default, stable)
+            - "cute-dsl": Use CuTe-DSL kernel (requires SM100+, **experimental**)
+        enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic
+            Dependent Launch). Only used when backend="cute-dsl".
+            If None, automatically detects based on device capability.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
--
+@flashinfer_api
+def mxfp4_dequantize(a_fp4, a_sf):
+    """
+    Dequantize input tensor from MXFP4 format.
+
+    Parameters:
+        a_fp4 (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
+        a_sf (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
+
+    Returns:
+        torch.Tensor: Dequantized tensor of shape [M, K] with dtype float.
+    """
+    return e2m1_and_ufp8sf_scale_to_float(
+        a_fp4.cpu().view(torch.uint8),
+        a_sf.cpu().view(torch.uint8).reshape(-1),
+        torch.tensor([1.0], device=a_fp4.device),
+        32,
+        0,
+        True,
+    )
+
--
+@flashinfer_api
+def mxfp4_dequantize_host(
+    weight: torch.Tensor,
+    scale: torch.Tensor,
+    group_size: int = 32,
+) -> torch.Tensor:
+    """
+    Dequantize input tensor from MXFP4 format on host.
+
+    Parameters:
+        weight (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2)
+        scale (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8)
+        group_size (int, optional): Group size for dequantization. Defaults to 32.
+
+    Returns:
+        torch.Tensor: Dequantized tensor of shape [M, K] with dtype float.
+    """
+    # NOTE(Zihao): the cpu op should be decoupled from cuda ops because it's device independent; should refactor this in the future
+    major, minor = get_compute_capability(
+        torch.device("cuda:0")
+    )  # use any cuda device to get a compute capability
--
+@flashinfer_api
+def nvfp4_batched_quantize(
+    a,
+    a_global_sf,
+    sf_vec_size=16,
+):
+    """
+    Quantize batched input tensor to NVFP4 format.
+
+    Parameters:
+        a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16.
+        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
+        sf_vec_size (int, optional): Scale factor vector size. Defaults to 16.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2
+            - Scale factors tensor with shape determined by layout and sf_vec_size
+    """
+    major, minor = get_compute_capability(a.device)
+    device_arch = f"{major * 10 + minor}"
--
+@flashinfer_api
+def nvfp4_quantize_paged_kv_cache(
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    kv_layout: str = "HND",
+    k_global_sf: Optional[torch.Tensor] = None,
+    v_global_sf: Optional[torch.Tensor] = None,
+) -> Tuple[
+    Tuple[torch.Tensor, torch.Tensor],
+    Tuple[torch.Tensor, torch.Tensor],
+    float,
+    float,
+]:
+    """Quantize paged KV cache to NVFP4 format for trtllm-gen MHA.
+
+    Quantizes BF16/FP16 K/V caches to NVFP4 with two-level scaling
+    (global FP32 + per-block FP8), and swizzles scale factors
+    for the SM100 trtllm-gen MHA kernel layout.
+
+    Args:
+        k_cache: Key cache tensor.
--
+@flashinfer_api
+def scaled_fp4_grouped_quantize(
+    a,
+    mask,
+    a_global_sf,
+):
+    """
+    Quantize batched input tensor to NVFP4 format with mask.
+    Parameters:
+        a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16.
+        a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32.
+        mask (torch.Tensor): Mask tensor to apply before quantization.
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2
+            - Scale factors tensor with shape determined by layout and sf_vec_size
+    """
+    major, minor = get_compute_capability(a.device)
+    device_arch = f"{major * 10 + minor}"
+    a_fp4, a_sf = get_fp4_quantization_module(
+        device_arch
--
+@flashinfer_api
+def nvfp4_kv_dequantize(
+    fp4_data: torch.Tensor,
+    block_scales: torch.Tensor,
+    global_scale: torch.Tensor,
+    output_dtype: torch.dtype = torch.bfloat16,
+) -> torch.Tensor:
+    """GPU dequantization of NVFP4 KV cache data with linear block scale layout.
+
+    Requires SM80+.
+
+    Args:
+        fp4_data (torch.Tensor): Packed FP4 data of shape ``[M, K/2]`` with dtype uint8.
+        block_scales (torch.Tensor): Per-block FP8 E4M3 scales of shape ``[M, K/16]``
+            with dtype uint8.
+        global_scale (torch.Tensor): Global scale factor of shape ``[1]`` with dtype float32,
+            on the same CUDA device as fp4_data.
+        output_dtype (torch.dtype): Output dtype, either ``torch.bfloat16`` or ``torch.float16``.
+
+    Returns:
+        torch.Tensor: Dequantized tensor of shape ``[M, K]`` with the specified output dtype.
--
+@flashinfer_api
+def nvfp4_kv_quantize(
+    input: torch.Tensor,
+    global_scale: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """GPU quantization to NVFP4 KV cache format with linear block scale layout.
+
+    Requires SM100+ (Blackwell) for the cvt.rn.satfinite.e2m1x2.f32 PTX instruction.
+
+    Args:
+        input (torch.Tensor): Input tensor of shape [M, K] with dtype bf16 or fp16.
+            K must be divisible by 16.
+        global_scale (torch.Tensor): Global scale factor of shape ``[1]`` with dtype float32,
+            on the same CUDA device as input.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]:
+            - fp4_output: Packed FP4 data of shape ``[M, K/2]`` with dtype uint8.
+            - block_scales: Per-block FP8 E4M3 scales of shape ``[M, K/16]`` with dtype uint8.
+    """
+    M, K = input.shape
--
+@flashinfer_api
+def mxfp8_quantize(
+    input: torch.Tensor,
+    is_sf_swizzled_layout: bool = True,
+    alignment: int = 32,
+    enable_pdl: Optional[bool] = None,
+    backend: Literal["cuda", "cute-dsl"] = "cuda",
+    sf_swizzle_layout: Optional[SfLayout] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Quantize input tensor to MxFP8 format.
+
+    This function implements MxFP8 quantization that converts input tensors to a compressed MxFP8 format
+    with associated scale factors. It supports various input data types and scale factor layouts.
+
+    Args:
+        input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized.
+        is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
+        alignment (int, optional): sfVecSize. Defaults to 32.
+        enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch).
+            If None, automatically detects based on device capability (SM >= 9.0). Defaults to None.
+        backend (Literal["cuda", "cute-dsl"], optional): Backend to use for quantization. Options are:
--
+@flashinfer_api
+def mxfp8_dequantize_host(
+    input: torch.Tensor,
+    scale_tensor: torch.Tensor,
+    is_sf_swizzled_layout: bool = True,
+    sf_swizzle_layout: Optional[SfLayout] = None,
+) -> torch.Tensor:
+    """Dequantize input tensor from MxFP8 format.
+
+    This function performs dequantization by converting a packed FP8 tensor in MxFP8 format
+    back to float values using the associated scale factors.
+
+    Args:
+        input (torch.Tensor): Packed FP8 tensor in MxFP8 format of shape [M, K] with dtype FLOAT8_E4M3.
+        scale_tensor (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size.
+        is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True.
+        sf_swizzle_layout (Optional[SfLayout], optional): Swizzle layout for scale factors.
+            If provided, it overrides is_sf_swizzled_layout. Defaults to None.
+            Available options are 1. SfLayout.layout_128x4; 2. SfLayout.layout_linear.
+
+    Returns:
--
+@flashinfer_api
+def mxfp4_quantize_cute_dsl(
+    input: torch.Tensor,
+    enable_pdl: bool | None = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to MXFP4 format using CuTe-DSL kernel.
+
+    This is a GPU implementation matching FlashInfer's mxfp4_quantize() behavior:
+    - Global scale computed as (448 * 6) / max(|input|)
+    - UE8M0 scale factors
+    - E2M1 output format (4-bit, 2 values per byte)
+    - Swizzled (128x4) scale factor layout
+
+    The kernel is compiled once per (K, dtype, pdl) combination and handles
+    varying M (batch size) at runtime without recompilation.
+
+    Args:
+        input: Input tensor of shape [M, K] with dtype fp16/bf16
+        enable_pdl: Whether to enable PDL (Programmatic Dependent Launch).
+            If None, automatically detects based on device capability (SM >= 9.0).
--
+@flashinfer_api
+def mxfp8_quantize_cute_dsl(
+    input: torch.Tensor,
+    is_sf_swizzled_layout: bool = True,
+    alignment: int = 32,
+    enable_pdl: bool | None = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to MXFP8 format using CuTe-DSL kernel.
+
+    This is a GPU implementation with dual-path optimization:
+    - LINEAR layout: SF-block based iteration (fast)
+    - SWIZZLED layout: Row-based iteration with padding fast path (optimized)
+
+    The kernel is compiled once per (K, dtype, pdl) combination and handles
+    varying M (batch size) at runtime without recompilation.
+
+    Args:
+        input: Input tensor of shape [M, K] with dtype fp16/bf16
+        is_sf_swizzled_layout: Whether to use 128x4 swizzled layout (True) or linear (False)
+        alignment: Alignment for K dimension (default 32, must be multiple of SF_VEC_SIZE)
```

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **Enhancements**
  * Normalization now accepts scale as either a float or tensor; passing a float emits a deprecation warning and is auto-converted for compatibility.
  * Attention/decoding API: cache-scale parameters are now optional keyword-only arguments with sensible defaults, simplifying common call patterns.
* **Tests**
  * Tests updated to match the adjusted attention/decoding call signature.
* **Chores**
  * Release version bumped to 0.6.7.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
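The `FLOAT4_E2M1X2` shapes in the docstrings above (packed FP4 of shape `[M, K/2]` with dtype uint8, i.e. two 4-bit E2M1 values per byte) can be illustrated with a small pure-Python sketch. This is not FlashInfer code: the helper names and the low-nibble-first packing order are assumptions for illustration only; the E2M1 value set ±{0, 0.5, 1, 1.5, 2, 3, 4, 6} is the standard FP4 E2M1 codebook.

```python
# Illustrative E2M1 (FP4) pack/unpack sketch. Code layout assumed here:
# bit 3 = sign, bits 2..0 index the magnitude table below.
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def e2m1_encode(x: float) -> int:
    """Round x to the nearest representable E2M1 value; return the 4-bit code."""
    sign = 0b1000 if x < 0 else 0
    mag = min(range(8), key=lambda c: abs(E2M1_MAGNITUDES[c] - abs(x)))
    return sign | mag

def e2m1_decode(code: int) -> float:
    mag = E2M1_MAGNITUDES[code & 0b0111]
    return -mag if code & 0b1000 else mag

def pack_e2m1x2(values):
    """Pack an even-length list of floats into K/2 bytes (element 2i in the low nibble)."""
    assert len(values) % 2 == 0, "need an even number of elements"
    return bytes(e2m1_encode(lo) | (e2m1_encode(hi) << 4)
                 for lo, hi in zip(values[::2], values[1::2]))

def unpack_e2m1x2(packed):
    out = []
    for b in packed:
        out.append(e2m1_decode(b & 0x0F))
        out.append(e2m1_decode(b >> 4))
    return out
```

For example, `unpack_e2m1x2(pack_e2m1x2([1.0, -2.0, 0.4, 6.0]))` rounds 0.4 up to the nearest representable value, 0.5, while the other three values round-trip exactly.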
📌 Description
Supports NVFP4 KV cache decode on SM120.
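As a rough model of the two-level scaling this feature relies on (a global FP32 scale plus one FP8 E4M3 scale per 16-element block, with values stored as E2M1), here is a hedged pure-Python sketch. It is not the FlashInfer implementation: the helper names are made up, and the E4M3 rounding of the block scale is omitted for simplicity.

```python
# All signed E2M1 (FP4) representable values.
E2M1_VALUES = sorted({s * m for s in (1.0, -1.0)
                      for m in (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)})

def quantize_block(block, global_scale):
    """Quantize one 16-element block; return (block_scale, quantized E2M1 values)."""
    assert len(block) == 16, "NVFP4 KV scale granularity is 16 elements"
    amax = max(abs(x) for x in block)
    # Choose the block scale so the largest magnitude maps onto 6.0, the E2M1 max.
    # (A real kernel would additionally round this scale to FP8 E4M3.)
    block_scale = (amax / 6.0) / global_scale if amax else 1.0
    scale = block_scale * global_scale
    q = [min(E2M1_VALUES, key=lambda v: abs(v - x / scale)) for x in block]
    return block_scale, q

def dequantize_block(block_scale, q, global_scale):
    """Reconstruct approximate values from E2M1 codes and the two scales."""
    return [v * block_scale * global_scale for v in q]
```

With this choice of scale, the reconstruction error per element stays bounded by roughly `amax / 6` (half the largest E2M1 step, times the combined scale), which is the trade-off that makes 4-bit KV storage usable.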
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Tests