
Add NVFP4 KV cache quantization support for SM100 #2702

Merged

aleozlx merged 6 commits into flashinfer-ai:main from sychen52:nvfp4-kv-cache-sm100 on Mar 19, 2026

Conversation

Contributor

@sychen52 sychen52 commented Mar 6, 2026

📌 Description

Add NVFP4 KV cache prefill and decode kernels to FlashInfer.

🔍 Related Issues

#2458

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features
    • NVFP4 (4-bit) KV-cache quantization with per-block FP8 scales and end-to-end support in decode/generation paths.
    • New CLI option to control KV-cache dtype (auto/fp8/nvfp4) and threading of per-block KV scales through APIs.
  • Tests
    • Expanded tests to validate NVFP4 paths and adjusted tolerances for quantized KV scenarios.
  • Chores
    • Added runtime debug/print helpers and enhanced KV-dtype and scaling telemetry.

Contributor

coderabbitai bot commented Mar 6, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds NVFP4 (4-bit) KV-cache quantization end-to-end: new KVFP4QuantizeUtil, thread per-block KV scales through Python decode/benchmark/test code, and extend CUDA FMHA launcher/kernel interfaces to accept and use key/value block-scale pointers.

Changes

  • KV FP4 utility & export — flashinfer/testing/kvfp4.py, flashinfer/testing/__init__.py
    New KVFP4QuantizeUtil (batched_quantize/dequantize, quantize_paged_kv_cache) and re-export in testing/__init__.py.
  • Decode & runtime plumbing — flashinfer/decode.py, flashinfer/mla.py
    Added optional kv_block_scales argument across batch/paged decode paths and threaded key/value block scales through paged_run, trtllm_gen, and MLA callsites.
  • Benchmarks & routines — benchmarks/bench_trtllm_fmha.py, benchmarks/routines/attention.py, benchmarks/routines/flashinfer_benchmark_utils.py
    CLI --kv_cache_dtype added; nvfp4 path quantizes KV via KVFP4QuantizeUtil, prepares k/v scales and kv_block_scales, and forwards dtype/scales to wrapper and TB/s accounting.
  • CUDA launcher & kernel interfaces — csrc/trtllm_fmha_kernel_launcher.cu
    Extended trtllm_paged_attention_launcher/decode/context signatures to accept k/v block-scale pointers; runner params gain kSfBasePtr/vSfBasePtr/kvSfScalePtr; FP4 KV validation and stride adjustments added.
  • FMHA headers & params — include/flashinfer/trtllm/fmha/kernelParams.h, include/flashinfer/trtllm/fmha/fmhaRunnerParams.h, include/flashinfer/trtllm/fmha/fmhaRunner.cuh, include/flashinfer/trtllm/fmha/fmhaKernels.cuh
    Added debug-print helpers, expanded KV dtype acceptance (E2M1), new print() helpers, updated TMA/descriptor signatures (bitsPerElt, unpack4b, storeTransformedKvInTmem), and adjusted KV layout/stride logic.
  • Tests — tests/attention/test_trtllm_gen_attention.py
    create_kv_cache now returns kv_block_scales; nvfp4 path added using KVFP4QuantizeUtil; callers, skips, tolerances, and mismatch allowances updated for nvfp4.
  • Misc / mapping — flashinfer/testing/__init__.py, benchmarks/...
    Added dtype mapping "nvfp4" -> torch.uint8; IO accounting updated to handle single-tensor or (k,v) tuple KV caches; wrapper signatures extended to accept scales/block-scales.

Sequence Diagram(s)

sequenceDiagram
    participant User as User Code
    participant PyAPI as flashinfer.decode
    participant Quant as KVFP4QuantizeUtil
    participant Wrapper as Bench/Wrapper
    participant Launcher as CUDA Launcher
    participant Kernel as FMHA Kernel

    User->>PyAPI: batch_decode_with_paged_kv_cache(..., kv_cache?, kv_block_scales?)
    alt nvfp4 quantization path
        PyAPI->>Quant: quantize_paged_kv_cache(k_cache, v_cache)
        Quant-->>PyAPI: quantized_kv, kv_block_scales, k_scale, v_scale
        PyAPI->>Wrapper: run(..., kv_data_type=uint8, k_scale, v_scale, kv_block_scales)
    else non-nvfp4 path
        PyAPI->>Wrapper: run(..., kv_data_type=kv_cache.dtype, k_scale, v_scale)
    end
    Wrapper->>Launcher: trtllm_paged_attention_decode(..., k_block_scales_ptr, v_block_scales_ptr)
    Launcher->>Kernel: invoke kernel with scale pointers
    Kernel-->>User: attention outputs

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested labels

run-ci

Suggested reviewers

  • yzh119
  • aleozlx
  • cyx-6
  • nvmbreughe
  • jiahanc
  • jimmyzho
  • bkryu
  • Anerudhan

Poem

🐰
I nibbled blocks of tiny scales,
Packed four-bit crumbs in winding trails,
From Python hop to CUDA lair,
I threaded scales with tender care—
Now attention hums through quantized vales.

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 38.24%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)

  • Title check — ✅ Passed: the PR title clearly and concisely summarizes the main change: adding NVFP4 KV cache quantization support for SM100 GPUs. It is specific, relevant, and directly reflects the core functionality being introduced across all modified files.
  • Description check — ✅ Passed: the PR description provides a concise title and related issue link, meeting basic requirements, though it lacks detailed implementation context.


@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances FlashInfer's capabilities by introducing support for NVFP4 Key-Value cache, primarily targeting NVIDIA's SM100 (Blackwell) architecture. This addition aims to improve memory efficiency and potentially performance for large language models by allowing KV caches to be stored in a compact 4-bit floating-point format. The changes span the Python API, C++ kernels, and testing infrastructure, ensuring a robust and well-integrated implementation of this new data type for paged attention operations.

Highlights

  • NVFP4 KV Cache Support: Introduced support for NVFP4 (NVIDIA FP4) Key-Value cache, enabling more memory-efficient attention mechanisms, particularly for SM100 architectures.
  • Quantization Utility: Added a new KVFP4QuantizeUtil module for pure-PyTorch NVFP4 quantization and dequantization, including two-level scaling (global FP32 + per-block FP8).
  • TRT-LLM Integration: Integrated NVFP4 KV cache into the trtllm_batch_decode_with_kv_cache function, allowing the TRT-LLM backend to utilize this new data type.
  • Benchmarking and Testing: Updated benchmarks (bench_trtllm_fmha.py, routines/attention.py) and tests (test_trtllm_gen_attention.py) to include NVFP4 KV cache, ensuring performance and correctness.
  • C++ Kernel Enhancements: Modified C++ kernels (trtllm_fmha_kernel_launcher.cu, fmhaKernels.cuh, fmhaRunner.cuh, kernelParams.h) to handle NVFP4 data types, block scales, and associated memory access patterns.
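To make the two-level scaling concrete, here is an illustrative pure-PyTorch sketch of per-block E2M1 quantization. It is not the repository's KVFP4QuantizeUtil: the real utility also carries a global FP32 scale and stores block scales as FP8 E4M3, while this sketch keeps block scales in FP32, and the function names are invented for the example.

```python
import torch

# Representable E2M1 (NVFP4) values: 1 sign bit, 2 exponent bits, 1 mantissa bit.
E2M1_VALUES = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
)

def quantize_nvfp4_blocks(x: torch.Tensor, block: int = 16):
    """Per-block scaling: map each block's max magnitude onto 6.0 (the
    largest E2M1 value), then snap every element to the nearest code."""
    xb = x.reshape(-1, block)
    scales = xb.abs().amax(dim=1, keepdim=True) / 6.0
    scales = scales.clamp(min=2.0 ** -9)  # keep tiny scales nonzero
    scaled = xb / scales
    # Nearest-value rounding against the E2M1 code book.
    idx = (scaled.unsqueeze(-1) - E2M1_VALUES).abs().argmin(dim=-1)
    return E2M1_VALUES[idx], scales

def dequantize_nvfp4_blocks(codes: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    return (codes * scales).reshape(-1)

torch.manual_seed(0)
x = torch.randn(256)
codes, scales = quantize_nvfp4_blocks(x)
x_hat = dequantize_nvfp4_blocks(codes, scales)
```

Because the widest gap in the E2M1 code book (between 4.0 and 6.0) has a half-width of 1.0 in scaled space, the per-element reconstruction error is bounded by one block scale, which is why the tests below relax tolerances for nvfp4.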

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • benchmarks/bench_trtllm_fmha.py
    • Imported flashinfer.decode and KVFP4QuantizeUtil for NVFP4 handling.
    • Modified KV cache initialization to support NVFP4 quantization and scale preparation.
    • Updated bench_trtllm_fmha_wrapper to pass k_scale_val, v_scale_val, and kv_block_scales to the attention wrapper.
    • Adjusted I/O calculation for NVFP4 KV cache, considering packed uint8 format.
    • Added --kv_cache_dtype argument to the benchmark script for specifying KV cache data type.
  • benchmarks/routines/attention.py
    • Imported flashinfer.decode and KVFP4QuantizeUtil.
    • Introduced is_nvfp4_kv flag to identify NVFP4 KV cache.
    • Modified head_dim_vo assignment to handle None values.
    • Added logic to filter unsupported backends for NVFP4 KV cache (only trtllm-native is supported).
    • Implemented NVFP4 KV cache quantization and scale preparation before running backend wrappers.
    • Passed kv_block_scales to trtllm_batch_decode_with_kv_cache calls.
    • Updated bandwidth calculation for NVFP4, using torch.uint8 as a proxy dtype for packed FP4 values.
  • benchmarks/routines/flashinfer_benchmark_utils.py
    • Added 'nvfp4' as a supported data type string, returning 'nvfp4' string for kv_dtype.
  • csrc/trtllm_fmha_kernel_launcher.cu
    • Modified trtllm_paged_attention_launcher signature to include k_block_scales_ptr and v_block_scales_ptr.
    • Updated runner_params to include kSfBasePtr and vSfBasePtr for block scales.
    • Adjusted trtllm_paged_attention_decode function signature to accept key_block_scales and value_block_scales.
    • Added logic to determine is_fp4_kv and stride_idx_factor for 4-bit KV cache.
    • Extracted k_block_scales_ptr and v_block_scales_ptr from optional TensorView arguments.
    • Passed new block scale pointers to the trtllm_paged_attention_launcher call.
  • flashinfer/decode.py
    • Added kv_block_scales parameter to BatchDecodeWithPagedKVCacheWrapper.run and trtllm_batch_decode_with_kv_cache.
    • Implemented logic to unpack kv_block_scales into key_block_scales and value_block_scales.
    • Adjusted output tensor out_head_dim calculation to use query's head dimension when NVFP4 KV cache is used.
    • Added key_block_scales and value_block_scales to the arguments passed to the C++ _paged_run function.
    • Introduced is_nvfp4_kvcache flag to identify NVFP4 KV cache based on dtype and presence of block scales.
    • Added logic to unpack kv_block_scales for NVFP4 KV cache, handling single tensor or tuple inputs.
    • Ensured k/v_block_scales are float8_e4m3fn dtype for NVFP4.
    • Transposed k_block_scales and v_block_scales if kv_layout is 'NHD' for NVFP4 KV cache.
  • flashinfer/jit/cubin_loader.py
    • Added an early return cubin statement in load_cubin.
    • Added a conditional return b'' if FLASHINFER_DISABLE_CUBIN_DOWNLOAD environment variable is set in get_cubin.
  • flashinfer/mla.py
    • Added None placeholders for key_block_scales and value_block_scales in trtllm_batch_decode_with_kv_cache_mla.
  • flashinfer/testing/__init__.py
    • Exported KVFP4QuantizeUtil from flashinfer.testing.kvfp4.
  • flashinfer/testing/kvfp4.py
    • Added new file kvfp4.py containing KVFP4QuantizeUtil class.
    • Implemented batched_quantize for NVFP4 quantization with two-level scaling (global FP32 + block FP8 E4M3).
    • Implemented batched_dequantize for NVFP4 dequantization.
    • Implemented quantize_paged_kv_cache to quantize K/V caches in HND layout and adjust scales for FP8 compute.
  • include/flashinfer/trtllm/fmha/fmhaKernels.cuh
    • Included <cfloat> header.
    • Added SAM_DEBUG macro and extensive debug print statements for kernel launch parameters and meta-information.
    • Added debug print statements for numCtasX tracing.
  • include/flashinfer/trtllm/fmha/fmhaRunner.cuh
    • Expanded supported mDtypeKv types to include DATA_TYPE_E2M1 (NVFP4).
  • include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
    • Added a print() method to TllmGenFmhaRunnerParams for debugging all member variables.
  • include/flashinfer/trtllm/fmha/kernelParams.h
    • Added SAM_DEBUG macro and a printKernelParams() method for debugging.
    • Modified getDevicePtrs to use bitsPerElt instead of bytesPerElt for pointer calculations.
    • Updated makeTmaShapeStrideKv to accept storeTransformedKvInTmem parameter.
    • Adjusted TMA descriptor shape calculation for DATA_TYPE_E2M1 based on storeTransformedKvInTmem.
    • Modified buildNdTmaDescriptor to accept unpack4b parameter for DATA_TYPE_E2M1.
    • Added logic to determine storeTransformedKvInTmem and swizzleKv for K/V TMA descriptors.
    • Updated buildNdTmaDescriptor calls for K and V TMA descriptors with new swizzleKv and unpack4b parameters.
    • Added debug print statements for TMA shapes, strides, and tile shapes.
    • Changed params.mSumOfSeqLensKv to a fixed value of 64 (likely for debugging purposes).
  • tests/attention/test_trtllm_gen_attention.py
    • Imported KVFP4QuantizeUtil.
    • Modified create_kv_cache to handle 'nvfp4' kv_dtype, quantizing BF16 data and returning kv_block_scales.
    • Updated _test_trtllm_batch_prefill to unpack kv_block_scales (though not used in prefill).
    • Added NVFP4-specific constraints to _test_trtllm_batch_decode (e.g., requires trtllm-gen backend, FP8 query/output, HND layout, no sink/speculative decoding).
    • Updated _test_trtllm_batch_decode to pass kv_block_scales to the attention function.
    • Adjusted rtol, atol, and allowed_mismatch_rate for NVFP4 KV cache due to increased quantization error.
    • Extended kv_dtype parameterization to include 'nvfp4' in test cases.
Activity
  • The pull request introduces NVFP4 KV cache support, with significant changes across core logic, benchmarks, and tests.
  • A new Python utility KVFP4QuantizeUtil was added for NVFP4 quantization and dequantization.
  • C++ kernels were updated to handle the new NVFP4 data type and associated scaling factors.
  • Benchmarking and testing infrastructure were adapted to validate the new NVFP4 functionality.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces support for NVFP4 KV cache, primarily targeting SM100 architectures. The changes span across benchmark files, the core Python API, C++ CUDA kernel launchers, and tests. A new testing utility for NVFP4 quantization is also added.

My review has identified a few critical issues that appear to be leftover debugging code. These changes disable CUBIN checksum validation and downloading, and hardcode a value for mSumOfSeqLensKv, which will likely break functionality. These need to be addressed before merging.

Otherwise, the implementation for NVFP4 support seems consistent and well-integrated into the existing architecture.

Note: Security Review did not run due to the size of the PR.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 9

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
flashinfer/jit/cubin_loader.py (1)

184-196: ⚠️ Potential issue | 🔴 Critical

Restore checksum validation in load_cubin.

Line 186 returns before the FLASHINFER_CUBIN_CHECKSUM_DISABLED check and the SHA-256 verification run, so every cached cubin is now accepted even when it is stale or corrupted. That breaks this helper’s contract and can hand the wrong binary to the native loader.

Suggested fix
     try:
         with open(cubin_path, mode="rb") as f:
             cubin = f.read()
-            return cubin
             if os.getenv("FLASHINFER_CUBIN_CHECKSUM_DISABLED"):
                 return cubin
             m = hashlib.sha256()
             m.update(cubin)
             actual_sha = m.hexdigest()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/jit/cubin_loader.py` around lines 184 - 196, In load_cubin, remove
the premature "return cubin" that short-circuits checksum logic: read the file
into cubin, then if os.getenv("FLASHINFER_CUBIN_CHECKSUM_DISABLED") return
cubin; otherwise compute the SHA-256 (using hashlib.update on cubin ->
actual_sha) and compare to the provided sha256 argument; if they match return
cubin, else log the mismatch via logger.warning (including expected and actual)
and raise a ValueError (or otherwise fail) so a corrupted/stale cubin is not
silently accepted. Ensure you reference the existing symbols cubin_path, cubin,
sha256, FLASHINFER_CUBIN_CHECKSUM_DISABLED, and logger when making the change.
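The corrected control flow can be sketched in isolation with stdlib only. The function name and argument names here are illustrative, not the actual load_cubin signature; the point is that the env-var opt-out is the only path that skips the hash comparison.

```python
import hashlib
import os

def read_verified(path: str, expected_sha256: str) -> bytes:
    """Read a cached binary and verify its SHA-256 unless explicitly disabled."""
    with open(path, "rb") as f:
        data = f.read()
    # The opt-out is checked after the read; there is no unconditional
    # early return that would bypass verification.
    if os.getenv("FLASHINFER_CUBIN_CHECKSUM_DISABLED"):
        return data
    actual = hashlib.sha256(data).hexdigest()
    if actual != expected_sha256:
        raise ValueError(
            f"checksum mismatch for {path}: expected {expected_sha256}, got {actual}"
        )
    return data
```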
flashinfer/decode.py (2)

2299-2338: ⚠️ Potential issue | 🟠 Major

Reject invalid NVFP4 combinations before backend dispatch.

is_nvfp4_kvcache only becomes true when kv_block_scales is present, so a uint8 KV cache without scales currently falls through into the normal path. The same packed-KV call can also still resolve to xqa when backend="auto" on non-SM100 devices. Both cases end up sending NVFP4 data to a backend that cannot decode it.

Suggested fix
-    is_nvfp4_kvcache = (
-        k_cache.dtype == torch.uint8
-        and v_cache.dtype == torch.uint8
-        and kv_block_scales is not None
-    )
+    is_nvfp4_kvcache = (
+        k_cache.dtype == torch.uint8 and v_cache.dtype == torch.uint8
+    )
@@
+    if is_nvfp4_kvcache and kv_block_scales is None:
+        raise ValueError("kv_block_scales is required for NVFP4 KV cache")
+
     if backend == "auto":
         backend = (
             "trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa"
         )
+
+    if is_nvfp4_kvcache and backend != "trtllm-gen":
+        raise ValueError("NVFP4 KV cache is only supported by the trtllm-gen backend")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/decode.py` around lines 2299 - 2338, The code currently only sets
is_nvfp4_kvcache when kv_block_scales is present, allowing uint8 k_cache/v_cache
without scales or auto-resolving backend="auto" to xqa to proceed and send NVFP4
data to an incompatible backend; update decode.py to validate NVFP4-packed KV
cache before backend dispatch by (1) adding a check that if
k_cache.dtype==torch.uint8 and v_cache.dtype==torch.uint8 then kv_block_scales
must be non-None and raise a ValueError if missing (refer to variables
is_nvfp4_kvcache, k_cache, v_cache, kv_block_scales), and (2) ensure that NVFP4
output/packed-KV is rejected when backend resolves to "xqa" (whether backend was
explicitly "xqa" or chosen via backend == "auto" using
get_compute_capability(query.device)) so the existing xqa checks for
out_dtype/FP4Tensor also cover packed uint8 KV cache scenarios.

2025-2076: ⚠️ Potential issue | 🔴 Critical

Fix the fake op registration to match the real custom op exactly.

The _fake_paged_run function violates the torch.compile fake-tensor contract. It must be registered under the same op name and have an identical signature as paged_run:

  1. Change decorator from @register_fake_op(f"flashinfer::{uri}_paged_run") to @register_fake_op(f"flashinfer::{uri}_ragged_run")
  2. Add missing parameters: scale_q, scale_k, scale_v (after sm_scale), and workspace_size (after token_pos_in_items_len)

Without this fix, shape and dtype inference during torch.compile will fail for this decode path.

Suggested fix
-    `@register_fake_op`(f"flashinfer::{uri}_paged_run")
+    `@register_fake_op`(f"flashinfer::{uri}_ragged_run")
     def _fake_paged_run(
         float_workspace_buffer: torch.Tensor,
         int_workspace_buffer: torch.Tensor,
         plan_info_vec: List[int],
         q: torch.Tensor,
@@
         maybe_max_item_len_ptr: Optional[torch.Tensor],
         logits_soft_cap: float,
         sm_scale: float,
+        scale_q: Optional[torch.Tensor],
+        scale_k: Optional[torch.Tensor],
+        scale_v: Optional[torch.Tensor],
         rope_scale: float,
         rope_theta: float,
         token_pos_in_items_len: int,
+        workspace_size: int,
         paged_kv_cache: Optional[torch.Tensor] = None,
         num_qo_heads: Optional[int] = None,
         num_kv_heads: Optional[int] = None,
         block_tables: Optional[torch.Tensor] = None,

Also applies to: line 2109 area (the other fake op at lines 2108-2147 has the same issues).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/decode.py` around lines 2025 - 2076, The fake-op registration for
_fake_paged_run must exactly mirror the real custom op paged_run: update the
decorator from `@register_fake_op`(f"flashinfer::{uri}_paged_run") to
`@register_fake_op`(f"flashinfer::{uri}_ragged_run") and add the missing
parameters to the fake function signature—insert scale_q, scale_k, scale_v
immediately after sm_scale, and add workspace_size after
token_pos_in_items_len—to match paged_run; apply the same decorator and
signature fixes to the second fake op variant around the 2108–2147 area so both
fake registrations fully match paged_run.
🧹 Nitpick comments (1)
flashinfer/testing/kvfp4.py (1)

35-43: Module-level CUDA tensors created at import time.

E2M1_VALUES and E2M1_BOUNDS are created when the module is imported. If CUDA is available but not initialized, or if there are multiple CUDA devices, this could cause issues. Consider lazy initialization or making these tensors on-demand.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/testing/kvfp4.py` around lines 35 - 43, E2M1_VALUES and
E2M1_BOUNDS are created as CUDA tensors at module import, which can fail if CUDA
isn't initialized or when multiple devices exist; change them to be created
on-demand by replacing the module-level tensors with factory accessors (e.g.,
get_e2m1_values(device=None) and get_e2m1_bounds(device=None)) that take an
optional device, compute torch.tensor(...) inside the function and return
.to(device) (defaulting to CPU or a resolved _device) so no CUDA tensors are
allocated at import time; update all uses of E2M1_VALUES and E2M1_BOUNDS to call
the new getters.
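One way to realize the suggested lazy accessors is a cached factory keyed on the device string, so importing the module never touches CUDA. A minimal sketch (the accessor name mirrors the suggestion but is illustrative):

```python
import functools
import torch

@functools.lru_cache(maxsize=None)
def get_e2m1_values(device: str = "cpu") -> torch.Tensor:
    # Built on first call and cached per device string; no tensor is
    # allocated at import time, and repeat calls return the same object.
    return torch.tensor(
        [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], device=device
    )
```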
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/routines/flashinfer_benchmark_utils.py`:
- Around line 272-273: The function dtype_str_to_torch_dtype currently returns
the string "nvfp4" for dtype_str == "nvfp4", breaking the expected return type
of torch.dtype; change this to return torch.uint8 (the carrier dtype for nvfp4)
and add an inline comment in dtype_str_to_torch_dtype explaining that nvfp4 is
represented/carried as torch.uint8 so callers should interpret it specially, or
alternatively update the function docstring and return type annotation to
document the nvfp4 special-case if you prefer explicit typing; ensure all code
paths return a torch.dtype.

In `@flashinfer/decode.py`:
- Around line 1285-1292: kv_block_scales are left in NHD layout while
k_cache/v_cache are converted to HND for the trtllm-gen wrapper, causing
mismatched layouts for the NVFP4 kernel; when kv_block_scales is not None,
transpose/permute key_block_scales and value_block_scales the same way you
transform k_cache/v_cache (i.e., convert their NHD ordering to HND ordering) so
they match the packed KV pages, and apply the same change to the second
occurrence handling the same variables later
(key_block_scales/value_block_scales and kv_block_scales).

In `@flashinfer/jit/cubin_loader.py`:
- Around line 239-246: The get_cubin() function currently has an unconditional
return b"" that short-circuits the cache-miss path and prevents the
FLASHINFER_DISABLE_CUBIN_DOWNLOAD check and download logic from running; remove
the early "return b\"\"" so the function evaluates
os.getenv("FLASHINFER_DISABLE_CUBIN_DOWNLOAD") and, only if that env var is set,
logs the error about cubin_path and returns b""; otherwise proceed with the
existing download/cache retrieval logic in cubin_loader.py to restore cold-start
fetches used by callers like gemm/core.py, fused_moe.py, attention/modules.py,
and moe_utils.py.

In `@flashinfer/testing/kvfp4.py`:
- Around line 201-204: The division by 6.0 on k_blk_scales and v_blk_scales can
underflow when converted to torch.float8_e4m3fn; clamp the scaled block scales
to the FP8 minimum positive value before casting. Update the block where
k_blk_scales and v_blk_scales are adjusted (symbols: k_blk_scales, v_blk_scales,
float8_e4m3fn) to compute the divided values, clamp them with a conservative min
like 2**-9 (~0.00195) using .clamp(min=min_val), then cast to
torch.float8_e4m3fn; this preserves small-but-nonzero scales and avoids
underflow while keeping the intended FP8 adjustment.
- Around line 104-107: block_scales can be zero causing division by zero when
computing x_scaled = reshaped / (block_scales_fixed * global_scale); fix it by
ensuring block_scales_fixed never contains zeros (e.g., clamp_min or replace
zeros with a small epsilon like 1e-6) before the division so that
block_scales_fixed (derived from block_scales/block_max) is safe; update the
code that constructs block_scales_fixed (and any use of block_scales or
block_max) to apply the clamp/eps consistently so x_scaled and downstream
computations never divide by zero.

In `@include/flashinfer/trtllm/fmha/fmhaKernels.cuh`:
- Around line 486-489: The stray fflush(stdout) after the commented debug printf
in fmhaKernels.cuh should not execute unconditionally; either remove the
fflush(stdout) or wrap it in the same SAM_DEBUG guard used for other debug
prints (same pattern around the printf near the multiCtasKv debug block), so
update the code around the multiCtasKv debug section to put fflush(stdout)
inside the SAM_DEBUG macro (or delete it) to match the debug gating used
elsewhere.
- Around line 430-432: Remove or guard the leftover unconditional fflush(stdout)
in fmhaKernels.cuh: locate the orphaned fflush(stdout) call (near the commented
printf in the fmha kernel trace region) and either delete it or move it inside
the SAM_DEBUG conditional so it only runs when SAM_DEBUG is enabled; ensure no
other unconditional debug I/O remains in the same function/kernel after the
change.

In `@include/flashinfer/trtllm/fmha/kernelParams.h`:
- Around line 922-923: Replace the debug hardcoded assignment to
params.mSumOfSeqLensKv with the original dynamic value: restore the use of
options.mSumOfSeqLensKv (i.e., set params.mSumOfSeqLensKv =
options.mSumOfSeqLensKv) and remove the fixed "64" literal; if needed, add a
defensive check that options.mSumOfSeqLensKv is valid before assigning to
params.mSumOfSeqLensKv to avoid regressions.

In `@tests/attention/test_trtllm_gen_attention.py`:
- Around line 1064-1070: The test guard currently excludes NVFP4 KV paths so the
new wrapper code never gets exercised; remove the `kv_dtype != "nvfp4"`
condition (or adjust the conditional logic) so the branch runs when
`kv_block_scales` can be provided, and ensure `wrapper_trtllm_gen.run(...)`
forwards the `kv_block_scales` argument into
`BatchDecodeWithPagedKVCacheWrapper.run(...)` (the wrapper's NVFP4 behavior is
gated on `kv_block_scales` being non-None), referencing
`wrapper_trtllm_gen.run`, `BatchDecodeWithPagedKVCacheWrapper.run`, and the
`kv_block_scales` parameter to locate and fix the code.

---

Outside diff comments:
In `@flashinfer/decode.py`:
- Around line 2299-2338: The code currently only sets is_nvfp4_kvcache when
kv_block_scales is present, allowing uint8 k_cache/v_cache without scales or
auto-resolving backend="auto" to xqa to proceed and send NVFP4 data to an
incompatible backend; update decode.py to validate NVFP4-packed KV cache before
backend dispatch by (1) adding a check that if k_cache.dtype==torch.uint8 and
v_cache.dtype==torch.uint8 then kv_block_scales must be non-None and raise a
ValueError if missing (refer to variables is_nvfp4_kvcache, k_cache, v_cache,
kv_block_scales), and (2) ensure that NVFP4 output/packed-KV is rejected when
backend resolves to "xqa" (whether backend was explicitly "xqa" or chosen via
backend == "auto" using get_compute_capability(query.device)) so the existing
xqa checks for out_dtype/FP4Tensor also cover packed uint8 KV cache scenarios.
- Around line 2025-2076: The fake-op registration for _fake_paged_run must
exactly mirror the real custom op paged_run: update the decorator from
`@register_fake_op`(f"flashinfer::{uri}_paged_run") to
`@register_fake_op`(f"flashinfer::{uri}_ragged_run") and add the missing
parameters to the fake function signature—insert scale_q, scale_k, scale_v
immediately after sm_scale, and add workspace_size after
token_pos_in_items_len—to match paged_run; apply the same decorator and
signature fixes to the second fake op variant around the 2108–2147 area so both
fake registrations fully match paged_run.

In `@flashinfer/jit/cubin_loader.py`:
- Around line 184-196: In load_cubin, remove the premature "return cubin" that
short-circuits checksum logic: read the file into cubin, then if
os.getenv("FLASHINFER_CUBIN_CHECKSUM_DISABLED") return cubin; otherwise compute
the SHA-256 (using hashlib.update on cubin -> actual_sha) and compare to the
provided sha256 argument; if they match return cubin, else log the mismatch via
logger.warning (including expected and actual) and raise a ValueError (or
otherwise fail) so a corrupted/stale cubin is not silently accepted. Ensure you
reference the existing symbols cubin_path, cubin, sha256,
FLASHINFER_CUBIN_CHECKSUM_DISABLED, and logger when making the change.
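For reference, the flow the prompt above describes can be sketched as follows (the helper name and signature are illustrative, not the actual `load_cubin` in `flashinfer/jit/cubin_loader.py`):

```python
import hashlib
import os


def load_cubin_checked(cubin_path: str, sha256: str) -> bytes:
    """Read a cubin and verify its SHA-256 unless verification is disabled."""
    with open(cubin_path, "rb") as f:
        cubin = f.read()
    # Opt-out path: skip verification only when the env var is set.
    if os.getenv("FLASHINFER_CUBIN_CHECKSUM_DISABLED"):
        return cubin
    actual_sha = hashlib.sha256(cubin).hexdigest()
    if actual_sha != sha256:
        # Fail loudly rather than silently accepting a corrupted/stale cubin.
        raise ValueError(
            f"cubin checksum mismatch for {cubin_path}: "
            f"expected {sha256}, got {actual_sha}"
        )
    return cubin
```

The point is ordering: the early `return cubin` must only happen on the env-var path, so the checksum comparison is always reached otherwise.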

---

Nitpick comments:
In `@flashinfer/testing/kvfp4.py`:
- Around line 35-43: E2M1_VALUES and E2M1_BOUNDS are created as CUDA tensors at
module import, which can fail if CUDA isn't initialized or when multiple devices
exist; change them to be created on-demand by replacing the module-level tensors
with factory accessors (e.g., get_e2m1_values(device=None) and
get_e2m1_bounds(device=None)) that take an optional device, compute
torch.tensor(...) inside the function and return .to(device) (defaulting to CPU
or a resolved _device) so no CUDA tensors are allocated at import time; update
all uses of E2M1_VALUES and E2M1_BOUNDS to call the new getters.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 2da26834-27b1-49c2-be23-cb0a19ace74e

📥 Commits

Reviewing files that changed from the base of the PR and between 124a2d3 and 4de4aa5.

📒 Files selected for processing (14)
  • benchmarks/bench_trtllm_fmha.py
  • benchmarks/routines/attention.py
  • benchmarks/routines/flashinfer_benchmark_utils.py
  • csrc/trtllm_fmha_kernel_launcher.cu
  • flashinfer/decode.py
  • flashinfer/jit/cubin_loader.py
  • flashinfer/mla.py
  • flashinfer/testing/__init__.py
  • flashinfer/testing/kvfp4.py
  • include/flashinfer/trtllm/fmha/fmhaKernels.cuh
  • include/flashinfer/trtllm/fmha/fmhaRunner.cuh
  • include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
  • include/flashinfer/trtllm/fmha/kernelParams.h
  • tests/attention/test_trtllm_gen_attention.py

@sychen52 sychen52 force-pushed the nvfp4-kv-cache-sm100 branch 3 times, most recently from d885d8a to e5f8507 on March 6, 2026 01:17
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/decode.py (1)

2328-2366: ⚠️ Potential issue | 🟠 Major

Do not route NVFP4 KV cache to XQA.

is_nvfp4_kvcache is recognized above, but only the trtllm-gen branch forwards k_block_scales / v_block_scales. With backend="auto" on SM120/121, or backend="xqa" explicitly, this silently ignores the new scale tensors and dispatches a packed uint8 KV cache to an unsupported backend.

Suggested fix
     if backend == "auto":
-        backend = (
-            "trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa"
-        )
+        cc_major = get_compute_capability(query.device)[0]
+        if is_nvfp4_kvcache:
+            if cc_major != 10:
+                raise ValueError(
+                    "NVFP4 KV cache is only supported by the trtllm-gen backend on SM100/SM103."
+                )
+            backend = "trtllm-gen"
+        else:
+            backend = "trtllm-gen" if cc_major == 10 else "xqa"

     if backend == "xqa":
+        if is_nvfp4_kvcache:
+            raise ValueError("xqa backend does not support NVFP4 KV cache.")
         # xqa backend doesn't support nvfp4 output
♻️ Duplicate comments (3)
include/flashinfer/trtllm/fmha/kernelParams.h (1)

921-923: ⚠️ Potential issue | 🔴 Critical

Restore the real KV length sum.

Line 923 hardcodes mSumOfSeqLensKv to 64. Any batch whose total KV length differs from that will build incorrect kernel params and break the generated launch configuration.

Suggested fix
     params.mSumOfSeqLensQ = options.mSumOfSeqLensQ;
-    // params.mSumOfSeqLensKv = options.mSumOfSeqLensKv;
-    params.mSumOfSeqLensKv = 64;
+    params.mSumOfSeqLensKv = options.mSumOfSeqLensKv;
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/trtllm/fmha/kernelParams.h` around lines 921 - 923,
Replace the hardcoded constant with the actual KV length sum: change the
assignment that sets params.mSumOfSeqLensKv to use options.mSumOfSeqLensKv
instead of the literal 64 so the kernel params reflect the true batch KV length
(look for params.mSumOfSeqLensKv and options.mSumOfSeqLensKv in the
kernelParams.h code).
flashinfer/decode.py (1)

1302-1306: ⚠️ Potential issue | 🟠 Major

Transpose the block-scale tensors with the KV pages.

On the NHD wrapper path, k_cache and v_cache are converted to HND at Lines 1305-1306, but key_block_scales and value_block_scales stay in NHD order. The direct trtllm_batch_decode_with_kv_cache path below already transposes them, so the wrapper currently feeds mismatched layouts into the NVFP4 kernel.

Suggested fix
         if self._backend == "trtllm-gen" and self._kv_layout == "NHD":
             # For NHD: [..., N, H, D] -> HND: [..., H, N, D]
             k_cache = k_cache.transpose(-3, -2)
             v_cache = v_cache.transpose(-3, -2)
+            if key_block_scales is not None:
+                key_block_scales = key_block_scales.transpose(-3, -2)
+            if value_block_scales is not None:
+                value_block_scales = value_block_scales.transpose(-3, -2)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/decode.py` around lines 1302 - 1306, The NHD-to-HND transpose only
moves k_cache and v_cache but leaves key_block_scales and value_block_scales in
NHD order, causing layout mismatch for the trtllm-gen NVFP4 kernel; update the
same conditional in decode.py (the branch that checks self._backend ==
"trtllm-gen" and self._kv_layout == "NHD") to also transpose key_block_scales
and value_block_scales with the same axes (use transpose(-3, -2) like
k_cache/v_cache) so their layout matches before calling
trtllm_batch_decode_with_kv_cache.
flashinfer/testing/kvfp4.py (1)

118-126: ⚠️ Potential issue | 🟠 Major

Avoid 0/0 in all-zero blocks.

When block_max is zero, Line 126 divides by block_scales_fixed * global_scale == 0 and the quantized block becomes NaN. Use a clamped copy only for the divisor path.

Suggested fix
-        block_scales_fixed = block_scales.unsqueeze(-1)
+        block_scales_fixed = block_scales.clamp_min(1e-6).unsqueeze(-1)
         x_scaled = reshaped / (block_scales_fixed * global_scale)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/testing/kvfp4.py` around lines 118 - 126, The division can produce
0/0 for all-zero blocks when computing x_scaled = reshaped / (block_scales_fixed
* global_scale); update the code that computes x_scaled to use a safe divisor:
compute safe_divisor = (block_scales_fixed * global_scale).clone() and clamp it
to a small positive eps (e.g. 1e-8) before dividing, then restore zeros for
all-zero blocks (detect via block_max == 0) by setting x_scaled[...] = 0 where
block_max.squeeze(-1) == 0; change references to
block_scales_fixed/global_scale/block_max and x_scaled accordingly so only the
divisor path uses the clamped copy.
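A minimal sketch of the safe-divisor technique the prompt describes (tensor names follow the review; the function itself is illustrative, not the actual `flashinfer/testing/kvfp4.py` code):

```python
import torch


def safe_block_scale_divide(reshaped, block_scales, global_scale, eps=1e-8):
    """Divide by block scales without producing NaN for all-zero blocks."""
    block_scales_fixed = block_scales.unsqueeze(-1)
    # Clamp only the divisor; the stored scales themselves stay unchanged.
    safe_divisor = (block_scales_fixed * global_scale).clamp_min(eps)
    x_scaled = reshaped / safe_divisor
    # Where the block scale was zero, the whole block was zero: restore 0.0.
    return torch.where(
        (block_scales_fixed == 0).expand_as(x_scaled),
        torch.zeros_like(x_scaled),
        x_scaled,
    )
```

An all-zero block then quantizes to exact zeros instead of `0/0 = NaN`, while nonzero blocks are scaled as before.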
🧹 Nitpick comments (1)
benchmarks/bench_trtllm_fmha.py (1)

123-139: Consider clarifying variable naming for scale computation.

The k_scale_val variable initially holds the query inverse scale (q_inv_scale) before being multiplied by the K global scale (k_gs). While the final computation is correct for bmm1_scale = q_scale * k_scale * sm_scale, the intermediate naming is confusing.

Consider renaming for clarity:

♻️ Suggested clarification
     if kv_cache_dtype == "nvfp4":
         # NVFP4 KV requires FP8 query
         if q.dtype != torch.float8_e4m3fn:
             q, q_inv_scale = to_float8(q)
-            k_scale_val = (
+            q_scale_val = (
                 q_inv_scale.item()
                 if isinstance(q_inv_scale, torch.Tensor)
                 else q_inv_scale
             )
         else:
-            k_scale_val = 1.0
+            q_scale_val = 1.0

         kv_cache, kv_block_scales, k_gs, v_gs = (
             KVFP4QuantizeUtil.quantize_paged_kv_cache(kv_cache[:, 0], kv_cache[:, 1])
         )
-        k_scale_val *= k_gs
+        # Combined scale for bmm1: q_scale * k_scale
+        k_scale_val = q_scale_val * k_gs
         v_scale_val = v_gs
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_trtllm_fmha.py` around lines 123 - 139, The scale variable
naming is confusing: k_scale_val is initially set from the query inverse scale
(q_inv_scale) and later multiplied by the K global scale (k_gs). Update the
variables in the nvfp4 branch to make intent explicit: capture the query inverse
scale into a clearly named variable (e.g., q_inv_scale_val or q_scale_val) when
calling to_float8(q), keep the K/V global scales returned from
KVFP4QuantizeUtil.quantize_paged_kv_cache as k_gs and v_gs, then compute the
final k_scale_val by multiplying the query-scale variable with k_gs and set
v_scale_val = v_gs; adjust references to q_inv_scale, k_scale_val and
v_scale_val accordingly so the intermediate meaning is clear.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/bench_trtllm_fmha.py`:
- Around line 210-213: The IO calculation misses element sizes when kv_cache is
a tuple: update the tuple branch that computes io so each tensor in kv_cache
multiplies its numel() by element_size() (e.g. replace kv_cache[0].numel() +
kv_cache[1].numel() with kv_cache[0].numel() * kv_cache[0].element_size() +
kv_cache[1].numel() * kv_cache[1].element_size()) so io = q.numel() *
q.element_size() + <those products>; keep the existing scalar-tensor branch
as-is.

In `@flashinfer/testing/kvfp4.py`:
- Around line 29-60: The E2M1 lookup tensors (E2M1_VALUES, E2M1_BOUNDS) are
pinned to a single device at import time which causes device-mismatch errors
when quantize/dequantize are called with tensors on a different GPU; change the
code so these tables are created or moved to the input tensor's device at
runtime inside the quantize and dequantize functions: (1) stop constructing
E2M1_VALUES and E2M1_BOUNDS with device=_device at module import (construct them
on CPU or keep their data as Python lists/constants); (2) inside the functions
that use them (the quantize and dequantize routines that perform comparisons
with E2M1_BOUNDS and indexing into E2M1_VALUES), call
.to(input_tensor.device).to(dtype=torch.float32) (or recreate the tensors on
input_tensor.device) before any comparisons or indexing so both operands share
device and dtype; (3) ensure you reference the same symbols E2M1_VALUES and
E2M1_BOUNDS so callers are unchanged.

---

Duplicate comments:
In `@flashinfer/decode.py`:
- Around line 1302-1306: The NHD-to-HND transpose only moves k_cache and v_cache
but leaves key_block_scales and value_block_scales in NHD order, causing layout
mismatch for the trtllm-gen NVFP4 kernel; update the same conditional in
decode.py (the branch that checks self._backend == "trtllm-gen" and
self._kv_layout == "NHD") to also transpose key_block_scales and
value_block_scales with the same axes (use transpose(-3, -2) like
k_cache/v_cache) so their layout matches before calling
trtllm_batch_decode_with_kv_cache.

In `@flashinfer/testing/kvfp4.py`:
- Around line 118-126: The division can produce 0/0 for all-zero blocks when
computing x_scaled = reshaped / (block_scales_fixed * global_scale); update the
code that computes x_scaled to use a safe divisor: compute safe_divisor =
(block_scales_fixed * global_scale).clone() and clamp it to a small positive eps
(e.g. 1e-8) before dividing, then restore zeros for all-zero blocks (detect via
block_max == 0) by setting x_scaled[...] = 0 where block_max.squeeze(-1) == 0;
change references to block_scales_fixed/global_scale/block_max and x_scaled
accordingly so only the divisor path uses the clamped copy.

In `@include/flashinfer/trtllm/fmha/kernelParams.h`:
- Around line 921-923: Replace the hardcoded constant with the actual KV length
sum: change the assignment that sets params.mSumOfSeqLensKv to use
options.mSumOfSeqLensKv instead of the literal 64 so the kernel params reflect
the true batch KV length (look for params.mSumOfSeqLensKv and
options.mSumOfSeqLensKv in the kernelParams.h code).

---

Nitpick comments:
In `@benchmarks/bench_trtllm_fmha.py`:
- Around line 123-139: The scale variable naming is confusing: k_scale_val is
initially set from the query inverse scale (q_inv_scale) and later multiplied by
the K global scale (k_gs). Update the variables in the nvfp4 branch to make
intent explicit: capture the query inverse scale into a clearly named variable
(e.g., q_inv_scale_val or q_scale_val) when calling to_float8(q), keep the K/V
global scales returned from KVFP4QuantizeUtil.quantize_paged_kv_cache as k_gs
and v_gs, then compute the final k_scale_val by multiplying the query-scale
variable with k_gs and set v_scale_val = v_gs; adjust references to q_inv_scale,
k_scale_val and v_scale_val accordingly so the intermediate meaning is clear.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 4d6d49a9-286b-46c4-97af-ab7e9772f862

📥 Commits

Reviewing files that changed from the base of the PR and between 4de4aa5 and 770f36b.

📒 Files selected for processing (10)
  • benchmarks/bench_trtllm_fmha.py
  • benchmarks/routines/attention.py
  • benchmarks/routines/flashinfer_benchmark_utils.py
  • csrc/trtllm_fmha_kernel_launcher.cu
  • flashinfer/decode.py
  • flashinfer/mla.py
  • flashinfer/testing/__init__.py
  • flashinfer/testing/kvfp4.py
  • include/flashinfer/trtllm/fmha/kernelParams.h
  • tests/attention/test_trtllm_gen_attention.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • benchmarks/routines/flashinfer_benchmark_utils.py

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
include/flashinfer/trtllm/fmha/kernelParams.h (1)

557-570: ⚠️ Potential issue | 🔴 Critical

Packed-QKV with mixed Q/KV dtypes computes wrong K/V offsets.

When isPackedQkv(mQkvLayout) is true and mDataTypeQ != mDataTypeKv, the offsets computed at lines 565–570 are incorrect. getDevicePtrs() receives only bitsPerElt based on mDataTypeKv (line 705), but the packed buffer stores Q first with its own element width. The K offset should account for Q's byte count, not KV's. There is no upstream validation preventing this combination—the code at line 743 explicitly handles mixed dtypes (transformsKv), including the E4M3/E2M1 case (lines 747–749). Either add a guard to reject packed QKV when datatypes differ, or pass separate element widths for Q and K/V to getDevicePtrs().

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/trtllm/fmha/kernelParams.h` around lines 557 - 570,
getDevicePtrs currently computes K/V offsets using a single bitsPerElt (passed
from mDataTypeKv) which is wrong when isPackedQkv(mQkvLayout) is true but
mDataTypeQ != mDataTypeKv; update the code so offsets use Q's element width for
advancing past Q and KV's element width for K/V: either add a guard in
getDevicePtrs/isPackedQkv to reject mixed datatypes (check mDataTypeQ vs
mDataTypeKv and error out) or change getDevicePtrs signature to accept two
element-width parameters (bitsPerEltQ, bitsPerEltKv) and use bitsPerEltQ when
computing the offset from qkvPtr to kPtr and then use bitsPerEltKv for the
offset to vPtr; modify call sites that pass bitsPerElt (e.g., where
getDevicePtrs is invoked) to supply both widths and update references to
mNumHeadsQ, mNumHeadsKv, mHeadDimQk, and qkvPtr accordingly.
flashinfer/decode.py (1)

2301-2331: ⚠️ Potential issue | 🟠 Major

Fail fast before packed NVFP4 KV can fall into XQA.

NVFP4 detection only turns on when kv_block_scales is present. A packed uint8 KV cache without scales is silently treated as ordinary UINT8 input, and backend="auto" can still select XQA on non-SM100 devices even though the XQA binding has no block-scale parameters (csrc/flashinfer_xqa_binding.cu:31-39). Require scales for packed UINT8 caches and reject non-trtllm-gen dispatch up front.

Suggested fix
-    is_nvfp4_kvcache = (
-        k_cache.dtype == torch.uint8
-        and v_cache.dtype == torch.uint8
-        and kv_block_scales is not None
-    )
+    is_packed_uint8_kvcache = k_cache.dtype == torch.uint8 and v_cache.dtype == torch.uint8
+    cc_major = get_compute_capability(query.device)[0]
+    if is_packed_uint8_kvcache:
+        if kv_block_scales is None:
+            raise ValueError("kv_block_scales must be provided for NVFP4 KV cache")
+        if cc_major != 10:
+            raise ValueError("NVFP4 KV cache is only supported on SM100/SM103")
+        if backend not in ("auto", "trtllm-gen"):
+            raise ValueError("NVFP4 KV cache is only supported by the trtllm-gen backend")
+    is_nvfp4_kvcache = is_packed_uint8_kvcache and kv_block_scales is not None
@@
-    if backend == "auto":
-        backend = (
-            "trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa"
-        )
+    if backend == "auto":
+        backend = "trtllm-gen" if cc_major == 10 else "xqa"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/decode.py` around lines 2301 - 2331, Detect packed NVFP4 KV by
checking k_cache.dtype == v_cache.dtype == torch.uint8 even when kv_block_scales
is None, and if such a packed UINT8 KV is found require scales or force/reject
non-trtllm-gen dispatch: if kv_block_scales is None and both k_cache and v_cache
are uint8, then if backend == "auto" set backend = "trtllm-gen", else if backend
!= "trtllm-gen" raise a clear ValueError that kv_block_scales are required for
packed NVFP4 KV or backend must be "trtllm-gen"; keep the existing
is_nvfp4_kvcache logic (which should remain gated on kv_block_scales) and only
treat k_block_scales/v_block_scales when kv_block_scales is provided.
♻️ Duplicate comments (3)
tests/attention/test_trtllm_gen_attention.py (1)

1064-1087: ⚠️ Potential issue | 🟠 Major

Keep wrapper coverage enabled for NVFP4 KV.

The direct call above forwards kv_block_scales, but this guard still skips every kv_dtype == "nvfp4" wrapper case and the wrapper invocation below still omits kv_block_scales. That leaves the new wrapper plumbing in flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper.run() unexercised.

Suggested fix
     if (
         o_dtype != "nvfp4"
-        and kv_dtype != "nvfp4"
         and backend == "trtllm-gen"
         and q_len_per_req
         is not None  # only test for the case all requests have the same q_len
-    ):  # wrapper api does not support fp4 output/kv yet.
+    ):  # wrapper api does not support fp4 output yet.
@@
         output_wrapper = wrapper_trtllm_gen.run(
             q_input,
             kv_cache,
             q_scale=q_scale,
             k_scale=k_scale,
             v_scale=v_scale / o_scale,
+            kv_block_scales=kv_block_scales,
             enable_pdl=enable_pdl,
             sinks=(sink if enable_sink else None),
             q_len_per_req=q_len_per_req,
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_trtllm_gen_attention.py` around lines 1064 - 1087, The
test currently skips wrapper coverage when kv_dtype == "nvfp4" and the
wrapper.run() call omits kv_block_scales; update the condition so NVFP4 KV cases
are allowed through (remove or relax the kv_dtype != "nvfp4" check) and call
wrapper_trtllm_gen.run(...) including the kv_block_scales argument (pass the
existing kv_block_scales variable) so BatchDecodeWithPagedKVCacheWrapper.run()
is exercised; ensure plan(...) is still called on wrapper_trtllm_gen and
q_len_per_req logic is preserved.
flashinfer/testing/kvfp4.py (1)

29-30: ⚠️ Potential issue | 🟠 Major

Move the E2M1 tables off the import-time device.

E2M1_VALUES and E2M1_BOUNDS are pinned to whichever CUDA device is current when this module is imported. batched_quantize() and batched_dequantize() later use them with tensor.device, so calls on cuda:1 (or CPU on a CUDA-enabled host) will fail with a device-mismatch error.

Suggested fix
-# Put constants directly on CUDA if available
-_device = "cuda" if torch.cuda.is_available() else "cpu"
 # E2M1 format: 1 sign bit + 2 exponent bits + 1 mantissa bit = 4 bits
 # 16 possible values: 0x0-0xF
 # Negative values: 0x8-0xF (sign bit = 1)
 # Positive values: 0x0-0x7 (sign bit = 0)
 E2M1_VALUES = torch.tensor(
@@
-    dtype=torch.float32,
-    device=_device,
+    dtype=torch.float32,
 )
 # Boundaries for rounding to nearest E2M1 value (only for positive values)
 E2M1_BOUNDS = torch.tensor(
-    [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5], dtype=torch.float32, device=_device
+    [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5], dtype=torch.float32
 )
@@
         b, m, n = tensor.shape
         device = tensor.device
+        e2m1_bounds = E2M1_BOUNDS.to(device=device)
@@
-        magnitude_bits = torch.sum(abs_vals.unsqueeze(-1) >= E2M1_BOUNDS, dim=-1).to(
+        magnitude_bits = torch.sum(abs_vals.unsqueeze(-1) >= e2m1_bounds, dim=-1).to(
             torch.uint8
         )
@@
         b, m, n_half = quant_tensor.shape
         n = n_half * 2
+        e2m1_values = E2M1_VALUES.to(device=quant_tensor.device)
@@
-        float_vals = E2M1_VALUES[fp4_vals.long()]
+        float_vals = e2m1_values[fp4_vals.long()]

Also applies to: 35-60

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/testing/kvfp4.py` around lines 29 - 30, E2M1_VALUES and
E2M1_BOUNDS are created at import using the current _device and so get pinned to
the import-time CUDA device; change them to be created on CPU (or unpinned) and
move to the call-site device inside batched_quantize and batched_dequantize:
keep the global symbols E2M1_VALUES and E2M1_BOUNDS but construct them on CPU
(remove use of _device for their creation) and in
batched_quantize()/batched_dequantize() call .to(tensor.device,
non_blocking=True) (or create device-specific cached copies) before using them
to avoid device-mismatch errors.
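The on-demand accessor pattern the prompt describes might look like this (accessor names and the per-device cache are assumptions; the E2M1 value set and the bounds match the tables quoted in the review):

```python
import torch

# Tables are built lazily and cached per device, so importing the module
# allocates nothing on any GPU.
_E2M1_TABLE_CACHE = {}


def get_e2m1_values(device="cpu"):
    """16-entry E2M1 decode table (sign bit 1 selects the negative half)."""
    key = ("values", str(device))
    if key not in _E2M1_TABLE_CACHE:
        _E2M1_TABLE_CACHE[key] = torch.tensor(
            [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
             -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
            dtype=torch.float32,
        ).to(device)
    return _E2M1_TABLE_CACHE[key]


def get_e2m1_bounds(device="cpu"):
    """Rounding boundaries between adjacent positive E2M1 values."""
    key = ("bounds", str(device))
    if key not in _E2M1_TABLE_CACHE:
        _E2M1_TABLE_CACHE[key] = torch.tensor(
            [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5], dtype=torch.float32
        ).to(device)
    return _E2M1_TABLE_CACHE[key]
```

Callers inside `batched_quantize`/`batched_dequantize` would then request the tables with the input tensor's device, so `cuda:1` and CPU inputs both work.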
flashinfer/decode.py (1)

1285-1292: ⚠️ Potential issue | 🟠 Major

Normalize kv_block_scales exactly like the KV cache.

The wrapper path still assumes tensor-form scales always have two K/V planes, and it never transposes them on the NHD→HND trtllm-gen path. The direct API in this same file already handles both cases, so wrapper calls can either fail during unbind() or feed misaligned scale tensors into the kernel.

Suggested fix
         if kv_block_scales is not None:
             if isinstance(kv_block_scales, tuple):
                 key_block_scales, value_block_scales = kv_block_scales
             else:
-                key_block_scales, value_block_scales = kv_block_scales.unbind(dim=1)
+                if kv_block_scales.shape[1] == 1:
+                    key_block_scales, value_block_scales = kv_block_scales, kv_block_scales
+                else:
+                    assert kv_block_scales.shape[1] == 2, (
+                        "When kv_block_scales is a single tensor, the second dimension must be 1 or 2"
+                    )
+                    key_block_scales, value_block_scales = kv_block_scales.unbind(dim=1)
@@
         if self._backend == "trtllm-gen" and self._kv_layout == "NHD":
             # For NHD: [..., N, H, D] -> HND: [..., H, N, D]
             k_cache = k_cache.transpose(-3, -2)
             v_cache = v_cache.transpose(-3, -2)
+            if key_block_scales is not None:
+                key_block_scales = key_block_scales.transpose(-3, -2)
+            if value_block_scales is not None:
+                value_block_scales = value_block_scales.transpose(-3, -2)

Also applies to: 1302-1306

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/decode.py` around lines 1285 - 1292, The kv_block_scales handling
currently assumes tensor-form scales have two K/V planes and skips the NHD→HND
transpose used by the KV cache, which can cause unbind() failures or misaligned
scales; update the unpacking for kv_block_scales to mirror the KV cache
normalization used elsewhere in this file: accept a tuple as (key_block_scales,
value_block_scales) or a single tensor and split it robustly (not just
unbind(dim=1)), and when the tensor is in NHD layout perform the same transpose
to HND before splitting; apply the same fix at both occurrences that set
key_block_scales/value_block_scales (the block using kv_block_scales and the
later one at the other occurrence), reusing the direct-API normalization logic
rather than assuming shape.
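Pulled out as a standalone helper, the normalization the prompt asks for could be sketched like this (the function name is illustrative; shapes follow the diff above, with dim 1 of a single tensor holding the K/V planes):

```python
import torch


def normalize_kv_block_scales(kv_block_scales, backend="trtllm-gen", kv_layout="HND"):
    """Split block scales into K/V parts and match the KV-cache layout."""
    if isinstance(kv_block_scales, tuple):
        key_block_scales, value_block_scales = kv_block_scales
    elif kv_block_scales.shape[1] == 1:
        # A single shared plane serves both K and V.
        key_block_scales = value_block_scales = kv_block_scales
    else:
        assert kv_block_scales.shape[1] == 2, (
            "When kv_block_scales is a single tensor, dim 1 must be 1 or 2"
        )
        key_block_scales, value_block_scales = kv_block_scales.unbind(dim=1)
    if backend == "trtllm-gen" and kv_layout == "NHD":
        # [..., N, H, D] -> [..., H, N, D], same transpose as k_cache/v_cache.
        key_block_scales = key_block_scales.transpose(-3, -2)
        value_block_scales = value_block_scales.transpose(-3, -2)
    return key_block_scales, value_block_scales
```

Keeping this in one place means the wrapper path and the direct `trtllm_batch_decode_with_kv_cache` path cannot drift apart again.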
🧹 Nitpick comments (3)
csrc/trtllm_fmha_kernel_launcher.cu (1)

370-375: Inconsistent layout comment.

The comment at line 370 says "Assume NHD layout" but the decode path at line 269 was updated to "Assume HND layout". Both compute strides identically, so one comment is incorrect. Please align the comments to reflect the actual layout.

📝 Proposed fix
-  // Assume NHD layout: [..., H, N, D]
+  // Assume HND layout: [..., H, N, D]
   int page_size = key_cache.size(-2);
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_fmha_kernel_launcher.cu` around lines 370 - 375, The comment
above the key_cache stride calculations is inconsistent with the decode path;
change the layout comment to match the decode path's "Assume HND layout" (or
whichever layout is actually used) so it aligns with how strides are computed
for key_cache; update the comment near the declarations of page_size,
num_kv_heads, kv_stride_keys_values, kv_stride_heads, and kv_stride_batch to
read "Assume HND layout: [..., H, N, D]" (or update the decode path comment
instead if HND is wrong) so both locations consistently describe the same
layout.
benchmarks/bench_trtllm_fmha.py (2)

127-133: Redundant tensor check and unreachable branch.

to_float8 always returns a tensor for q_inv_scale, so the isinstance check at line 129 is always True, making lines 132-133 unreachable dead code. Simplify this logic:

♻️ Proposed fix
         if q.dtype != torch.float8_e4m3fn:
             q, q_inv_scale = to_float8(q)
-            k_scale_val = (
-                q_inv_scale.item()
-                if isinstance(q_inv_scale, torch.Tensor)
-                else q_inv_scale
-            )
+            k_scale_val = q_inv_scale.item()
         else:
             k_scale_val = 1.0
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_trtllm_fmha.py` around lines 127 - 133, The code sets
k_scale_val with an unnecessary isinstance check against torch.Tensor because
to_float8 always returns a tensor; remove the conditional and unreachable else
branch and simply extract the scalar via q_inv_scale.item() (or use
float(q_inv_scale)) so k_scale_val is assigned from q_inv_scale.item() directly;
update the assignment near the k_scale_val variable where q_inv_scale is
produced and remove the now-dead else branch.

140-141: Discarded FP8 scale value.

The to_float8 function returns a scale value that is discarded here. For accurate benchmarking with FP8, consider whether k_scale/v_scale should be set similarly to the nvfp4 path.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_trtllm_fmha.py` around lines 140 - 141, The FP8 scale
returned by to_float8 is being discarded; update the kv_cache handling so you
capture the returned scale (e.g., kv_cache, kv_scale = to_float8(kv_cache)) and
set the appropriate k_scale and v_scale variables (or assign both to kv_scale)
when kv_cache_dtype.startswith("fp8") and q_dtype != "fp8", matching how the
nvfp4 branch sets k_scale/v_scale so FP8 quantization scales are used
consistently in benchmarking.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/routines/attention.py`:
- Around line 600-606: The NVFP4 branch currently force-casts q to FP8
(torch.float8_e4m3fn) which mislabels reported q_dtype; instead either validate
and reject non-FP8 inputs or propagate an effective_q_dtype used by metrics: in
the is_nvfp4_kv branch (around checks using q and
KVFP4QuantizeUtil.quantize_paged_kv_cache) raise a clear error when q.dtype !=
torch.float8_e4m3fn, or set effective_q_dtype = torch.float8_e4m3fn and ensure
that downstream metric/charging code reads effective_q_dtype rather than the
original q.dtype so the benchmark records the actual query format.
- Around line 397-404: The current NVFP4 KV cache handling in the is_nvfp4_kv
block improperly force-adds "trtllm-native" after compute-capability filtering
(so functions like filter_backends_by_compute_capability / subsequent calls to
flashinfer.decode.trtllm_batch_decode_with_kv_cache may run on an unsupported
GPU or a different backend labeled as trtllm-native). Fix it by not appending
"trtllm-native" unconditionally; instead compute the intersection of allowed
backends and ["trtllm-native"], and if empty log/print a clear message and skip
NVFP4-specific benchmarking (or raise/continue) so only truly supported backends
remain in the backends list; update the block that references is_nvfp4_kv and
the backends list manipulation (the backends variable and the NVFP4 handling
code) to enforce capability-filtering rather than bypassing it.

In `@include/flashinfer/trtllm/fmha/kernelParams.h`:
- Around line 218-241: The pointer dump helper in kernelParams.h is missing
ptrSkipSoftmaxStats, so add a printf entry for ptrSkipSoftmaxStats (matching the
style of the other prints, e.g., printf("ptrSkipSoftmaxStats: %p\n",
(void*)ptrSkipSoftmaxStats);) in the same location and in definition order
alongside the other ptr* prints (refer to existing symbols like ptrSoftmaxStats,
ptrPartialStats, ptrSeqLensKv) so the dump is complete for skip-softmax
debugging.
- Line 260: The printf uses %ld for an int64_t field (mNumHiddenEltsO) which is
non-portable; include <cinttypes> and change the print to use the PRId64 macro
(e.g., printf("mNumHiddenEltsO: %" PRId64 "\n", mNumHiddenEltsO)) so the
formatting is correct across LP64/LLP64 targets — update the include and the
printf call in kernelParams.h where mNumHiddenEltsO is printed.

---

Outside diff comments:
In `@flashinfer/decode.py`:
- Around line 2301-2331: Detect packed NVFP4 KV by checking k_cache.dtype ==
v_cache.dtype == torch.uint8 even when kv_block_scales is None, and if such a
packed UINT8 KV is found require scales or force/reject non-trtllm-gen dispatch:
if kv_block_scales is None and both k_cache and v_cache are uint8, then if
backend == "auto" set backend = "trtllm-gen", else if backend != "trtllm-gen"
raise a clear ValueError that kv_block_scales are required for packed NVFP4 KV
or backend must be "trtllm-gen"; keep the existing is_nvfp4_kvcache logic (which
should remain gated on kv_block_scales) and only treat
k_block_scales/v_block_scales when kv_block_scales is provided.
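The dispatch rule above can be sketched in plain Python, with dtype strings standing in for torch dtypes (the function name and signature are hypothetical, not the actual `flashinfer.decode` API):

```python
def resolve_backend(k_dtype, v_dtype, kv_block_scales, backend):
    """Backend dispatch for possibly-packed NVFP4 KV caches (sketch).

    Packed NVFP4 KV is stored as uint8; without per-block scales it can
    only be served by the "trtllm-gen" backend, so force or reject
    other dispatch choices accordingly.
    """
    packed_uint8 = k_dtype == "uint8" and v_dtype == "uint8"
    if kv_block_scales is None and packed_uint8:
        if backend == "auto":
            return "trtllm-gen"
        if backend != "trtllm-gen":
            raise ValueError(
                "kv_block_scales are required for packed NVFP4 KV, "
                "or backend must be 'trtllm-gen'"
            )
    return backend
```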

In `@include/flashinfer/trtllm/fmha/kernelParams.h`:
- Around line 557-570: getDevicePtrs currently computes K/V offsets using a
single bitsPerElt (passed from mDataTypeKv) which is wrong when
isPackedQkv(mQkvLayout) is true but mDataTypeQ != mDataTypeKv; update the code
so offsets use Q's element width for advancing past Q and KV's element width for
K/V: either add a guard in getDevicePtrs/isPackedQkv to reject mixed datatypes
(check mDataTypeQ vs mDataTypeKv and error out) or change getDevicePtrs
signature to accept two element-width parameters (bitsPerEltQ, bitsPerEltKv) and
use bitsPerEltQ when computing the offset from qkvPtr to kPtr and then use
bitsPerEltKv for the offset to vPtr; modify call sites that pass bitsPerElt
(e.g., where getDevicePtrs is invoked) to supply both widths and update
references to mNumHeadsQ, mNumHeadsKv, mHeadDimQk, and qkvPtr accordingly.

---

Duplicate comments:
In `@flashinfer/decode.py`:
- Around line 1285-1292: The kv_block_scales handling currently assumes
tensor-form scales have two K/V planes and skips the NHD→HND transpose used by
the KV cache, which can cause unbind() failures or misaligned scales; update the
unpacking for kv_block_scales to mirror the KV cache normalization used
elsewhere in this file: accept a tuple as (key_block_scales, value_block_scales)
or a single tensor and split it robustly (not just unbind(dim=1)), and when the
tensor is in NHD layout perform the same transpose to HND before splitting;
apply the same fix at both occurrences that set
key_block_scales/value_block_scales (the block using kv_block_scales and the
later one at the other occurrence), reusing the direct-API normalization logic
rather than assuming shape.
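The normalization asked for above can be sketched with NumPy arrays standing in for torch tensors (`np.swapaxes` plays the role of `transpose(-3, -2)`; the function name is hypothetical):

```python
import numpy as np


def normalize_block_scales(kv_block_scales, layout="HND"):
    """Split kv_block_scales into (key, value) planes (sketch).

    Accepts either a (key, value) tuple or a single array with the K/V
    axis at dim 1; NHD inputs get the same NHD->HND transpose as the KV
    cache so scales stay aligned with the pages they describe.
    """
    if isinstance(kv_block_scales, tuple):
        k_scales, v_scales = kv_block_scales
    else:
        k_scales, v_scales = kv_block_scales[:, 0], kv_block_scales[:, 1]
    if layout == "NHD":
        # [..., N, H, D] -> [..., H, N, D], mirroring the KV cache transpose
        k_scales = np.swapaxes(k_scales, -3, -2)
        v_scales = np.swapaxes(v_scales, -3, -2)
    return k_scales, v_scales
```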

In `@flashinfer/testing/kvfp4.py`:
- Around line 29-30: E2M1_VALUES and E2M1_BOUNDS are created at import using the
current _device and so get pinned to the import-time CUDA device; change them to
be created on CPU (or unpinned) and move to the call-site device inside
batched_quantize and batched_dequantize: keep the global symbols E2M1_VALUES and
E2M1_BOUNDS but construct them on CPU (remove use of _device for their creation)
and in batched_quantize()/batched_dequantize() call .to(tensor.device,
non_blocking=True) (or create device-specific cached copies) before using them
to avoid device-mismatch errors.
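The per-device caching idea can be shown generically without torch; device strings stand in for torch devices, and the cache dict models the `.to(tensor.device)` copies (names are illustrative, not the module's actual symbols):

```python
# Module-level table built once, device-neutral (CPU in the real code).
# E2M1 (4-bit: sign, 2 exponent, 1 mantissa) representable magnitudes:
E2M1_VALUES_CPU = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

_device_cache = {}


def e2m1_values_for(device: str):
    """Return the lookup table for `device`, creating the copy lazily.

    Stands in for caching `table.to(tensor.device)` per device instead
    of pinning the table to whichever CUDA device was current at import.
    """
    if device not in _device_cache:
        _device_cache[device] = list(E2M1_VALUES_CPU)  # models .to(device)
    return _device_cache[device]
```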

In `@tests/attention/test_trtllm_gen_attention.py`:
- Around line 1064-1087: The test currently skips wrapper coverage when kv_dtype
== "nvfp4" and the wrapper.run() call omits kv_block_scales; update the
condition so NVFP4 KV cases are allowed through (remove or relax the kv_dtype !=
"nvfp4" check) and call wrapper_trtllm_gen.run(...) including the
kv_block_scales argument (pass the existing kv_block_scales variable) so
BatchDecodeWithPagedKVCacheWrapper.run() is exercised; ensure plan(...) is still
called on wrapper_trtllm_gen and q_len_per_req logic is preserved.

---

Nitpick comments:
In `@benchmarks/bench_trtllm_fmha.py`:
- Around line 127-133: The code sets k_scale_val with an unnecessary isinstance
check against torch.Tensor because to_float8 always returns a tensor; remove the
conditional and unreachable else branch and simply extract the scalar via
q_inv_scale.item() (or use float(q_inv_scale)) so k_scale_val is assigned from
q_inv_scale.item() directly; update the assignment near the k_scale_val variable
where q_inv_scale is produced and remove the now-dead else branch.
- Around line 140-141: The FP8 scale returned by to_float8 is being discarded;
update the kv_cache handling so you capture the returned scale (e.g., kv_cache,
kv_scale = to_float8(kv_cache)) and set the appropriate k_scale and v_scale
variables (or assign both to kv_scale) when kv_cache_dtype.startswith("fp8") and
q_dtype != "fp8", matching how the nvfp4 branch sets k_scale/v_scale so FP8
quantization scales are used consistently in benchmarking.

In `@csrc/trtllm_fmha_kernel_launcher.cu`:
- Around line 370-375: The comment above the key_cache stride calculations is
inconsistent with the decode path; change the layout comment to match the decode
path's "Assume HND layout" (or whichever layout is actually used) so it aligns
with how strides are computed for key_cache; update the comment near the
declarations of page_size, num_kv_heads, kv_stride_keys_values, kv_stride_heads,
and kv_stride_batch to read "Assume HND layout: [..., H, N, D]" (or update the
decode path comment instead if HND is wrong) so both locations consistently
describe the same layout.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 00eff0e7-12ca-4018-8c34-c52eb1b24732

📥 Commits

Reviewing files that changed from the base of the PR and between 770f36b and e5f8507.

📒 Files selected for processing (10)
  • benchmarks/bench_trtllm_fmha.py
  • benchmarks/routines/attention.py
  • benchmarks/routines/flashinfer_benchmark_utils.py
  • csrc/trtllm_fmha_kernel_launcher.cu
  • flashinfer/decode.py
  • flashinfer/mla.py
  • flashinfer/testing/__init__.py
  • flashinfer/testing/kvfp4.py
  • include/flashinfer/trtllm/fmha/kernelParams.h
  • tests/attention/test_trtllm_gen_attention.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • flashinfer/mla.py

@sychen52 sychen52 force-pushed the nvfp4-kv-cache-sm100 branch from e5f8507 to 17f2353 Compare March 6, 2026 19:24
Contributor

@coderabbitai coderabbitai bot left a comment


♻️ Duplicate comments (5)
benchmarks/routines/attention.py (2)

600-606: ⚠️ Potential issue | 🟠 Major

Don't silently benchmark FP8 queries under a non-FP8 request.

This branch force-casts q to FP8 for NVFP4, but the benchmark still carries the original requested q_dtype through perf accounting and result reporting. Either reject non-FP8 inputs here or thread an explicit effective query dtype through the metric path.

Suggested fix
     if is_nvfp4_kv:
         # NVFP4 KV requires FP8 query
-        if q.dtype != torch.float8_e4m3fn:
-            q = q.to(torch.float8_e4m3fn)
+        if q_dtype != torch.float8_e4m3fn:
+            print("[ERROR] NVFP4 KV cache requires --q_dtype fp8.")
+            return res
         kv_cache_nvfp4, kv_block_scales, k_scale, v_scale = (
             KVFP4QuantizeUtil.quantize_paged_kv_cache(kv_cache[:, 0], kv_cache[:, 1])
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/routines/attention.py` around lines 600 - 606, The code branch
under is_nvfp4_kv force-casts q to torch.float8_e4m3fn but leaves the original
requested q dtype recorded in benchmarks; either validate and raise if q.dtype
is not FP8 or propagate the effective query dtype into the metric/reporting
path. Update the block around
is_nvfp4_kv/q/KVFP4QuantizeUtil.quantize_paged_kv_cache to (a) check q.dtype and
raise a clear error if non-FP8 inputs are not supported, or (b) set an explicit
effective_q_dtype variable (e.g. effective_q_dtype = torch.float8_e4m3fn) after
casting and ensure all perf accounting and result reporting use
effective_q_dtype instead of the original q.dtype so metrics reflect the actual
dtype used.

397-404: ⚠️ Potential issue | 🟠 Major

Don't re-add trtllm-native after capability filtering.

If compute-capability filtering removed trtllm-native, appending it back here can benchmark NVFP4 under an unsupported device/backend combination. The later call still uses backend="auto", so this can also run a different backend under the trtllm-native label.

Suggested fix
     # NVFP4 KV cache only works with trtllm-native (direct API with kv_block_scales)
     if is_nvfp4_kv:
-        unsupported = [b for b in backends if b != "trtllm-native"]
-        for b in unsupported:
-            print(f"[INFO] {b} backend does not support NVFP4 KV cache. Skipping.")
-            backends.remove(b)
-        if "trtllm-native" not in backends:
-            backends.append("trtllm-native")
+        backends = [b for b in backends if b == "trtllm-native"]
+        if not backends:
+            print("[ERROR] NVFP4 KV cache is not supported on this device/backend set.")
+            return res
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/routines/attention.py` around lines 397 - 404, The current block
handling is_nvfp4_kv wrongly re-appends "trtllm-native" after filtering, which
can override compute-capability decisions; modify the logic in the is_nvfp4_kv
section so you do not re-add "trtllm-native" if it was removed by capability
filtering — either remove the line that appends "trtllm-native" or conditionally
append only if it was present in the original backends list and still supported;
update references to backends and the is_nvfp4_kv check accordingly to preserve
correct filtering and avoid mislabeling.
flashinfer/decode.py (1)

1285-1306: ⚠️ Potential issue | 🟠 Major

Transpose kv_block_scales with the KV cache on the NHD trtllm-gen path.

This branch converts k_cache and v_cache from NHD to HND, but key_block_scales and value_block_scales stay in NHD order. NVFP4 then feeds HND-packed KV pages with mismatched scale tensors.

Suggested fix
         # Convert NHD layout to HND for trtllm-gen backend
         if self._backend == "trtllm-gen" and self._kv_layout == "NHD":
             # For NHD: [..., N, H, D] -> HND: [..., H, N, D]
             k_cache = k_cache.transpose(-3, -2)
             v_cache = v_cache.transpose(-3, -2)
+            if key_block_scales is not None:
+                key_block_scales = key_block_scales.transpose(-3, -2)
+            if value_block_scales is not None:
+                value_block_scales = value_block_scales.transpose(-3, -2)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/decode.py` around lines 1285 - 1306, The NHD->HND conversion for
k_cache and v_cache in the trtllm-gen path doesn't apply the same transpose to
kv_block_scales, causing mismatched ordering; update the block under "if
self._backend == 'trtllm-gen' and self._kv_layout == 'NHD':" to also transpose
key_block_scales and value_block_scales (when present) using the same axes as
k_cache/v_cache (i.e., transpose(-3, -2)); handle both the tuple and unbound
tensor forms of kv_block_scales (respecting the earlier unpack logic where
kv_block_scales may be None, a tuple, or produced by .unbind(dim=1)) so the
scales are reordered to HND to match k_cache/v_cache.
benchmarks/bench_trtllm_fmha.py (2)

210-213: ⚠️ Potential issue | 🟡 Minor

Account for bytes, not elements, in the NVFP4 IO path.

The tuple branch adds raw numel() for the packed KV tensors and also skips kv_block_scales, so the printed GB/s is understated for NVFP4.

Suggested fix
     if isinstance(kv_cache, tuple):
-        io = q.numel() * q.element_size() + kv_cache[0].numel() + kv_cache[1].numel()
+        io = (
+            q.numel() * q.element_size()
+            + kv_cache[0].numel() * kv_cache[0].element_size()
+            + kv_cache[1].numel() * kv_cache[1].element_size()
+        )
+        if kv_block_scales is not None:
+            if isinstance(kv_block_scales, tuple):
+                io += sum(t.numel() * t.element_size() for t in kv_block_scales)
+            else:
+                io += kv_block_scales.numel() * kv_block_scales.element_size()
     else:
         io = q.numel() * q.element_size() + kv_cache.numel() * kv_cache.element_size()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_trtllm_fmha.py` around lines 210 - 213, The IO calculation
currently under the NVFP4 tuple branch uses raw element counts and omits
kv_block_scales, understating bytes; update the tuple branch that computes io
(using kv_cache and q) to sum byte sizes by multiplying each tensor's numel() by
its element_size() (e.g., kv_cache[0].numel()*kv_cache[0].element_size() and
kv_cache[1].numel()*kv_cache[1].element_size()) and also add the byte size of
any kv_block_scales tensor(s) (numel()*element_size()) so the NVFP4 GB/s report
uses bytes, not elements.

123-139: ⚠️ Potential issue | 🟠 Major

NVFP4 silently changes the benchmarked query dtype.

When kv_cache_dtype == "nvfp4", this path converts non-FP8 queries to FP8 before the run. That means the benchmark no longer measures the workload requested by q_dtype, and the CLI currently gives the caller no way to see or control that change.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_trtllm_fmha.py` around lines 123 - 139, The code currently
mutates the benchmark query `q` to FP8 when `kv_cache_dtype == "nvfp4"` (via
to_float8), which silently changes the measured dtype; instead, preserve the
original `q` and create a separate FP8 copy only for NVFP4-specific operations:
call to_float8(q) and store result in a new variable (e.g., q_fp8 and
q_fp8_inv_scale) and use q_fp8 when computing `k_scale_val`, `kv_cache`
quantization (KVFP4QuantizeUtil.quantize_paged_kv_cache) and any NVFP4 paths,
leaving `q` and reported `q_dtype` unchanged; also add a clear guard or CLI flag
to either: (a) error if the requested q_dtype is not FP8 and kv_cache_dtype is
nvfp4, or (b) emit a warning/info that an FP8 copy was used for NVFP4 operations
and document the flag to opt-in—implement the change around the block
referencing kv_cache_dtype, q, to_float8, and
KVFP4QuantizeUtil.quantize_paged_kv_cache.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@benchmarks/bench_trtllm_fmha.py`:
- Around line 210-213: The IO calculation currently under the NVFP4 tuple branch
uses raw element counts and omits kv_block_scales, understating bytes; update
the tuple branch that computes io (using kv_cache and q) to sum byte sizes by
multiplying each tensor's numel() by its element_size() (e.g.,
kv_cache[0].numel()*kv_cache[0].element_size() and
kv_cache[1].numel()*kv_cache[1].element_size()) and also add the byte size of
any kv_block_scales tensor(s) (numel()*element_size()) so the NVFP4 GB/s report
uses bytes, not elements.
- Around line 123-139: The code currently mutates the benchmark query `q` to FP8
when `kv_cache_dtype == "nvfp4"` (via to_float8), which silently changes the
measured dtype; instead, preserve the original `q` and create a separate FP8
copy only for NVFP4-specific operations: call to_float8(q) and store result in a
new variable (e.g., q_fp8 and q_fp8_inv_scale) and use q_fp8 when computing
`k_scale_val`, `kv_cache` quantization
(KVFP4QuantizeUtil.quantize_paged_kv_cache) and any NVFP4 paths, leaving `q` and
reported `q_dtype` unchanged; also add a clear guard or CLI flag to either: (a)
error if the requested q_dtype is not FP8 and kv_cache_dtype is nvfp4, or (b)
emit a warning/info that an FP8 copy was used for NVFP4 operations and document
the flag to opt-in—implement the change around the block referencing
kv_cache_dtype, q, to_float8, and KVFP4QuantizeUtil.quantize_paged_kv_cache.

In `@benchmarks/routines/attention.py`:
- Around line 600-606: The code branch under is_nvfp4_kv force-casts q to
torch.float8_e4m3fn but leaves the original requested q dtype recorded in
benchmarks; either validate and raise if q.dtype is not FP8 or propagate the
effective query dtype into the metric/reporting path. Update the block around
is_nvfp4_kv/q/KVFP4QuantizeUtil.quantize_paged_kv_cache to (a) check q.dtype and
raise a clear error if non-FP8 inputs are not supported, or (b) set an explicit
effective_q_dtype variable (e.g. effective_q_dtype = torch.float8_e4m3fn) after
casting and ensure all perf accounting and result reporting use
effective_q_dtype instead of the original q.dtype so metrics reflect the actual
dtype used.
- Around line 397-404: The current block handling is_nvfp4_kv wrongly re-appends
"trtllm-native" after filtering, which can override compute-capability
decisions; modify the logic in the is_nvfp4_kv section so you do not re-add
"trtllm-native" if it was removed by capability filtering — either remove the
line that appends "trtllm-native" or conditionally append only if it was present
in the original backends list and still supported; update references to backends
and the is_nvfp4_kv check accordingly to preserve correct filtering and avoid
mislabeling.

In `@flashinfer/decode.py`:
- Around line 1285-1306: The NHD->HND conversion for k_cache and v_cache in the
trtllm-gen path doesn't apply the same transpose to kv_block_scales, causing
mismatched ordering; update the block under "if self._backend == 'trtllm-gen'
and self._kv_layout == 'NHD':" to also transpose key_block_scales and
value_block_scales (when present) using the same axes as k_cache/v_cache (i.e.,
transpose(-3, -2)); handle both the tuple and unbound tensor forms of
kv_block_scales (respecting the earlier unpack logic where kv_block_scales may
be None, a tuple, or produced by .unbind(dim=1)) so the scales are reordered to
HND to match k_cache/v_cache.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: ed7f2473-d203-439a-8f76-18f3eb207067

📥 Commits

Reviewing files that changed from the base of the PR and between e5f8507 and 17f2353.

📒 Files selected for processing (5)
  • benchmarks/bench_trtllm_fmha.py
  • benchmarks/routines/attention.py
  • benchmarks/routines/flashinfer_benchmark_utils.py
  • flashinfer/decode.py
  • flashinfer/mla.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • benchmarks/routines/flashinfer_benchmark_utils.py

@sychen52 sychen52 force-pushed the nvfp4-kv-cache-sm100 branch from 17f2353 to 77e4f6d Compare March 7, 2026 22:31
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (2)
benchmarks/bench_trtllm_fmha.py (1)

210-213: ⚠️ Potential issue | 🟡 Minor

Fix tuple KV IO accounting to include element sizes.

Line 211 currently sums only element counts for tuple KV caches, so bandwidth is misreported.

Suggested fix
     if isinstance(kv_cache, tuple):
-        io = q.numel() * q.element_size() + kv_cache[0].numel() + kv_cache[1].numel()
+        io = (
+            q.numel() * q.element_size()
+            + kv_cache[0].numel() * kv_cache[0].element_size()
+            + kv_cache[1].numel() * kv_cache[1].element_size()
+        )
     else:
         io = q.numel() * q.element_size() + kv_cache.numel() * kv_cache.element_size()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_trtllm_fmha.py` around lines 210 - 213, The IO calculation
for tuple KV caches underestimates bandwidth because it adds only numel() for
each tensor; update the tuple branch in the io computation (where kv_cache is
checked and io is assigned) to multiply each tensor's numel() by its
element_size() (i.e., use kv_cache[0].numel() * kv_cache[0].element_size() +
kv_cache[1].numel() * kv_cache[1].element_size()) and keep the q contribution as
q.numel() * q.element_size().
include/flashinfer/trtllm/fmha/kernelParams.h (1)

261-261: ⚠️ Potential issue | 🟡 Minor

Use portable formatting for int64_t in debug print.

Line 261 uses %ld for mNumHiddenEltsO (int64_t), which is not portable across LP64/LLP64 targets. Prefer PRId64.

#!/bin/bash
# Verify current formatting usage around mNumHiddenEltsO in kernelParams.h
rg -n "mNumHiddenEltsO|%ld|PRId64" include/flashinfer/trtllm/fmha/kernelParams.h -C2
Suggested fix
+#include <cinttypes>
 ...
-    printf("mNumHiddenEltsO: %ld\n", mNumHiddenEltsO);
+    printf("mNumHiddenEltsO: %" PRId64 "\n", mNumHiddenEltsO);
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/trtllm/fmha/kernelParams.h` at line 261, Replace the
non-portable printf usage that prints mNumHiddenEltsO (an int64_t) with a
portable int64_t format: include <inttypes.h> if not already included and change
the printf to use "%" PRId64 and cast mNumHiddenEltsO to (int64_t) (e.g.
printf("mNumHiddenEltsO: %" PRId64 "\n", (int64_t)mNumHiddenEltsO);) so the
output is correct across LP64/LLP64 targets.
🧹 Nitpick comments (1)
tests/attention/test_trtllm_gen_attention.py (1)

1040-1050: Consider expanding documentation for NVFP4 tolerance thresholds.

The NVFP4 KV cache tolerances (rtol=0.3, atol=0.3, allowed_mismatch_rate=0.05) are intentional design choices that reflect FP4's inherent 4-bit quantization limitations. While the existing comment ("NVFP4 KV cache has significant quantization error") explains the relaxation, consider adding more specifics:

  • Why a 5% mismatch rate is appropriate for this quantization format
  • Whether these thresholds have been validated against expected FP4 precision characteristics
  • Any references to NVFP4 specification or prior analysis

This helps future maintainers understand whether the bounds can be tightened if FP4 implementation is refined.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_trtllm_gen_attention.py` around lines 1040 - 1050,
Expand the comment block around the NVFP4 tolerance assignments (where kv_dtype
is checked and rtol, atol, allowed_mismatch_rate are set) to document why the
NVFP4 values were chosen: state that FP4’s 4-bit quantization causes significant
rounding/representation error, justify the 0.3 rtol/atol and 0.05 mismatch rate
by referencing any validation/benchmarks or expected FP4 precision
characteristics, note whether these thresholds were empirically validated (and
how) and include a pointer to NVFP4 specification or analysis notes so future
maintainers can reassess tightening rtol/atol or allowed_mismatch_rate for
kv_dtype == "nvfp4".
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/decode.py`:
- Around line 1288-1292: The code assumes kv_block_scales can be unbound along
dim=1 into key_block_scales and value_block_scales but crashes when
kv_block_scales is a tensor with size(1)==1 (shared-KV); update the branch in
flashinfer/decode.py so that if kv_block_scales is not a tuple and
kv_block_scales.size(1)==1 you assign key_block_scales = value_block_scales =
kv_block_scales.squeeze(1) (or equivalent indexing), otherwise keep the current
kv_block_scales.unbind(dim=1) behavior; reference variables kv_block_scales,
key_block_scales, value_block_scales in the fix.
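The shared-KV branch can be sketched with NumPy (squeeze vs. split along the K/V axis; `split_kv_scales` is a hypothetical name, and slicing at dim 1 stands in for torch's `unbind(dim=1)`):

```python
import numpy as np


def split_kv_scales(kv_block_scales):
    """Unpack scales that carry either one shared plane or two K/V planes."""
    if kv_block_scales.shape[1] == 1:
        # Shared-KV: the single plane serves both key and value scales.
        shared = np.squeeze(kv_block_scales, axis=1)
        return shared, shared
    # Two planes: index out K and V (equivalent of unbind(dim=1)).
    return kv_block_scales[:, 0], kv_block_scales[:, 1]
```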

---

Duplicate comments:
In `@benchmarks/bench_trtllm_fmha.py`:
- Around line 210-213: The IO calculation for tuple KV caches underestimates
bandwidth because it adds only numel() for each tensor; update the tuple branch
in the io computation (where kv_cache is checked and io is assigned) to multiply
each tensor's numel() by its element_size() (i.e., use kv_cache[0].numel() *
kv_cache[0].element_size() + kv_cache[1].numel() * kv_cache[1].element_size())
and keep the q contribution as q.numel() * q.element_size().

In `@include/flashinfer/trtllm/fmha/kernelParams.h`:
- Line 261: Replace the non-portable printf usage that prints mNumHiddenEltsO
(an int64_t) with a portable int64_t format: include <inttypes.h> if not already
included and change the printf to use "%" PRId64 and cast mNumHiddenEltsO to
(int64_t) (e.g. printf("mNumHiddenEltsO: %" PRId64 "\n",
(int64_t)mNumHiddenEltsO);) so the output is correct across LP64/LLP64 targets.

---

Nitpick comments:
In `@tests/attention/test_trtllm_gen_attention.py`:
- Around line 1040-1050: Expand the comment block around the NVFP4 tolerance
assignments (where kv_dtype is checked and rtol, atol, allowed_mismatch_rate are
set) to document why the NVFP4 values were chosen: state that FP4’s 4-bit
quantization causes significant rounding/representation error, justify the 0.3
rtol/atol and 0.05 mismatch rate by referencing any validation/benchmarks or
expected FP4 precision characteristics, note whether these thresholds were
empirically validated (and how) and include a pointer to NVFP4 specification or
analysis notes so future maintainers can reassess tightening rtol/atol or
allowed_mismatch_rate for kv_dtype == "nvfp4".

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: b9e84230-fc54-4037-9f54-3d2d505543c7

📥 Commits

Reviewing files that changed from the base of the PR and between 17f2353 and 77e4f6d.

📒 Files selected for processing (11)
  • benchmarks/bench_trtllm_fmha.py
  • benchmarks/routines/attention.py
  • benchmarks/routines/flashinfer_benchmark_utils.py
  • csrc/trtllm_fmha_kernel_launcher.cu
  • flashinfer/decode.py
  • flashinfer/mla.py
  • flashinfer/testing/__init__.py
  • flashinfer/testing/kvfp4.py
  • include/flashinfer/trtllm/fmha/fmhaKernels.cuh
  • include/flashinfer/trtllm/fmha/kernelParams.h
  • tests/attention/test_trtllm_gen_attention.py
🚧 Files skipped from review as they are similar to previous changes (4)
  • include/flashinfer/trtllm/fmha/fmhaKernels.cuh
  • benchmarks/routines/attention.py
  • flashinfer/testing/__init__.py
  • flashinfer/mla.py

@samuellees
Copy link
Collaborator

cc @PerkzZheng @Tom-Zheng for viz and review~

@sychen52 sychen52 force-pushed the nvfp4-kv-cache-sm100 branch from 77e4f6d to ba3b178 Compare March 10, 2026 04:14
@samuellees
Copy link
Collaborator

/bot run

@flashinfer-bot
Copy link
Collaborator

GitLab MR !417 has been updated with latest changes, and the CI pipeline #46162422 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Copy link
Collaborator

[FAILED] Pipeline #46162422: 9/20 passed

@sychen52 sychen52 force-pushed the nvfp4-kv-cache-sm100 branch from ba94862 to ad06df9 Compare March 15, 2026 16:21
@sychen52
Copy link
Contributor Author

/bot run

@flashinfer-bot
Copy link
Collaborator

@sychen52 is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

@yzh119
Copy link
Collaborator

yzh119 commented Mar 15, 2026

/bot run

@flashinfer-bot
Copy link
Collaborator

GitLab MR !417 has been updated with latest changes, and the CI pipeline #46194361 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Copy link
Collaborator

[SUCCESS] Pipeline #46194361: 9/20 passed

@kahyunnam kahyunnam enabled auto-merge (squash) March 16, 2026 17:24
Collaborator

@nv-yunzheq nv-yunzheq left a comment

Approve

auto-merge was automatically disabled March 17, 2026 19:51

Head branch was pushed to by a user without write access

@sychen52 sychen52 force-pushed the nvfp4-kv-cache-sm100 branch from ad06df9 to ebac227 Compare March 17, 2026 19:51
PerkzZheng and others added 6 commits March 18, 2026 10:29
address comments

small fix

Add support for trtllm gen nvfp4 kv cache (none-interleave)

Disable checksum and download

Add unit test

Add unit test

fix
Signed-off-by: Shiyang Chen <shiychen@nvidia.com>

fix2

fix3

fix4

remove SAM_DEBUG

temp

temp

temp
cleanup debug prints in cpp

try to add none HND support

remove trtllm-native logic and small changes according to comments

temp

add support for prefill/context kernel as well.

add kv_block_scales size into io

update based on comments
@sychen52 sychen52 force-pushed the nvfp4-kv-cache-sm100 branch from ebac227 to de36999 Compare March 18, 2026 17:44
@aleozlx aleozlx enabled auto-merge (squash) March 18, 2026 17:55
@aleozlx aleozlx merged commit fc4e70f into flashinfer-ai:main Mar 19, 2026
28 of 29 checks passed
