Add NVFP4 KV cache quantization support for SM100#2702
aleozlx merged 6 commits into flashinfer-ai:main
Conversation
📝 Walkthrough

Adds NVFP4 (4-bit) KV-cache quantization end-to-end: a new KVFP4QuantizeUtil, per-block KV scales threaded through the Python decode/benchmark/test code, and extended CUDA FMHA launcher/kernel interfaces that accept and use key/value block-scale pointers.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User as User Code
    participant PyAPI as flashinfer.decode
    participant Quant as KVFP4QuantizeUtil
    participant Wrapper as Bench/Wrapper
    participant Launcher as CUDA Launcher
    participant Kernel as FMHA Kernel
    User->>PyAPI: batch_decode_with_paged_kv_cache(..., kv_cache?, kv_block_scales?)
    alt nvfp4 quantization path
        PyAPI->>Quant: quantize_paged_kv_cache(k_cache, v_cache)
        Quant-->>PyAPI: quantized_kv, kv_block_scales, k_scale, v_scale
        PyAPI->>Wrapper: run(..., kv_data_type=uint8, k_scale, v_scale, kv_block_scales)
    else non-nvfp4 path
        PyAPI->>Wrapper: run(..., kv_data_type=kv_cache.dtype, k_scale, v_scale)
    end
    Wrapper->>Launcher: trtllm_paged_attention_decode(..., k_block_scales_ptr, v_block_scales_ptr)
    Launcher->>Kernel: invoke kernel with scale pointers
    Kernel-->>User: attention outputs
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (warning)
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances FlashInfer by introducing support for an NVFP4 key-value cache, primarily targeting NVIDIA's SM100 (Blackwell) architecture. It aims to improve memory efficiency, and potentially performance, for large language models by allowing KV caches to be stored in a compact 4-bit floating-point format. The changes span the Python API, C++ kernels, and testing infrastructure, ensuring a robust, well-integrated implementation of this new data type for paged attention operations.
Code Review
This pull request introduces support for NVFP4 KV cache, primarily targeting SM100 architectures. The changes span across benchmark files, the core Python API, C++ CUDA kernel launchers, and tests. A new testing utility for NVFP4 quantization is also added.
My review has identified a few critical issues that appear to be leftover debugging code. These changes disable CUBIN checksum validation and downloading, and hardcode a value for mSumOfSeqLensKv, which will likely break functionality. These need to be addressed before merging.
Otherwise, the implementation for NVFP4 support seems consistent and well-integrated into the existing architecture.
Note: Security Review did not run due to the size of the PR.
Actionable comments posted: 9
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
flashinfer/jit/cubin_loader.py (1)
184-196: ⚠️ Potential issue | 🔴 Critical: Restore checksum validation in `load_cubin`.

Line 186 returns before the `FLASHINFER_CUBIN_CHECKSUM_DISABLED` check and the SHA-256 verification run, so every cached cubin is now accepted even when it is stale or corrupted. That breaks this helper's contract and can hand the wrong binary to the native loader.

Suggested fix
```diff
 try:
     with open(cubin_path, mode="rb") as f:
         cubin = f.read()
-    return cubin
     if os.getenv("FLASHINFER_CUBIN_CHECKSUM_DISABLED"):
         return cubin
     m = hashlib.sha256()
     m.update(cubin)
     actual_sha = m.hexdigest()
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/jit/cubin_loader.py` around lines 184 - 196, In load_cubin, remove the premature "return cubin" that short-circuits checksum logic: read the file into cubin, then if os.getenv("FLASHINFER_CUBIN_CHECKSUM_DISABLED") return cubin; otherwise compute the SHA-256 (using hashlib.update on cubin -> actual_sha) and compare to the provided sha256 argument; if they match return cubin, else log the mismatch via logger.warning (including expected and actual) and raise a ValueError (or otherwise fail) so a corrupted/stale cubin is not silently accepted. Ensure you reference the existing symbols cubin_path, cubin, sha256, FLASHINFER_CUBIN_CHECKSUM_DISABLED, and logger when making the change.

flashinfer/decode.py (2)
2299-2338: ⚠️ Potential issue | 🟠 Major: Reject invalid NVFP4 combinations before backend dispatch.

`is_nvfp4_kvcache` only becomes true when `kv_block_scales` is present, so a `uint8` KV cache without scales currently falls through into the normal path. The same packed-KV call can also still resolve to `xqa` when `backend="auto"` on non-SM100 devices. Both cases end up sending NVFP4 data to a backend that cannot decode it.

Suggested fix
```diff
-    is_nvfp4_kvcache = (
-        k_cache.dtype == torch.uint8
-        and v_cache.dtype == torch.uint8
-        and kv_block_scales is not None
-    )
+    is_nvfp4_kvcache = (
+        k_cache.dtype == torch.uint8 and v_cache.dtype == torch.uint8
+    )
@@
+    if is_nvfp4_kvcache and kv_block_scales is None:
+        raise ValueError("kv_block_scales is required for NVFP4 KV cache")
+
     if backend == "auto":
         backend = (
             "trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa"
         )
+
+    if is_nvfp4_kvcache and backend != "trtllm-gen":
+        raise ValueError("NVFP4 KV cache is only supported by the trtllm-gen backend")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/decode.py` around lines 2299 - 2338, The code currently only sets is_nvfp4_kvcache when kv_block_scales is present, allowing uint8 k_cache/v_cache without scales or auto-resolving backend="auto" to xqa to proceed and send NVFP4 data to an incompatible backend; update decode.py to validate NVFP4-packed KV cache before backend dispatch by (1) adding a check that if k_cache.dtype==torch.uint8 and v_cache.dtype==torch.uint8 then kv_block_scales must be non-None and raise a ValueError if missing (refer to variables is_nvfp4_kvcache, k_cache, v_cache, kv_block_scales), and (2) ensure that NVFP4 output/packed-KV is rejected when backend resolves to "xqa" (whether backend was explicitly "xqa" or chosen via backend == "auto" using get_compute_capability(query.device)) so the existing xqa checks for out_dtype/FP4Tensor also cover packed uint8 KV cache scenarios.
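The dispatch rules proposed above can be exercised in isolation. The helper below is a hypothetical, standalone distillation of the suggested guard; the name `resolve_decode_backend` and the flattened signature are illustrative, not the actual decode.py code:

```python
import torch

def resolve_decode_backend(k_dtype, v_dtype, kv_block_scales, backend, cc_major):
    """Hypothetical distillation of the suggested NVFP4 dispatch rules."""
    # Packed uint8 K and V caches are treated as NVFP4 candidates.
    is_nvfp4 = k_dtype == torch.uint8 and v_dtype == torch.uint8
    if is_nvfp4 and kv_block_scales is None:
        raise ValueError("kv_block_scales is required for NVFP4 KV cache")
    if backend == "auto":
        # SM100 (compute capability major 10) routes to trtllm-gen.
        backend = "trtllm-gen" if cc_major == 10 else "xqa"
    if is_nvfp4 and backend != "trtllm-gen":
        raise ValueError("NVFP4 KV cache is only supported by the trtllm-gen backend")
    return backend
```

With this shape of guard, a packed `uint8` cache without scales fails fast instead of silently taking the normal path, and NVFP4 data can never reach the `xqa` backend.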
2025-2076: ⚠️ Potential issue | 🔴 Critical: Fix the fake op registration to match the real custom op exactly.

The `_fake_paged_run` function violates the torch.compile fake-tensor contract. It must be registered under the same op name and have an identical signature as `paged_run`:

- Change the decorator from `@register_fake_op(f"flashinfer::{uri}_paged_run")` to `@register_fake_op(f"flashinfer::{uri}_ragged_run")`
- Add the missing parameters: `scale_q`, `scale_k`, `scale_v` (after `sm_scale`), and `workspace_size` (after `token_pos_in_items_len`)

Without this fix, shape and dtype inference during torch.compile will fail for this decode path.

Suggested fix
```diff
-    @register_fake_op(f"flashinfer::{uri}_paged_run")
+    @register_fake_op(f"flashinfer::{uri}_ragged_run")
     def _fake_paged_run(
         float_workspace_buffer: torch.Tensor,
         int_workspace_buffer: torch.Tensor,
         plan_info_vec: List[int],
         q: torch.Tensor,
@@
         maybe_max_item_len_ptr: Optional[torch.Tensor],
         logits_soft_cap: float,
         sm_scale: float,
+        scale_q: Optional[torch.Tensor],
+        scale_k: Optional[torch.Tensor],
+        scale_v: Optional[torch.Tensor],
         rope_scale: float,
         rope_theta: float,
         token_pos_in_items_len: int,
+        workspace_size: int,
         paged_kv_cache: Optional[torch.Tensor] = None,
         num_qo_heads: Optional[int] = None,
         num_kv_heads: Optional[int] = None,
         block_tables: Optional[torch.Tensor] = None,
```

Also applies to the line 2109 area (the other fake op at lines 2108-2147 has the same issues).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/decode.py` around lines 2025 - 2076, The fake-op registration for _fake_paged_run must exactly mirror the real custom op paged_run: update the decorator from `@register_fake_op`(f"flashinfer::{uri}_paged_run") to `@register_fake_op`(f"flashinfer::{uri}_ragged_run") and add the missing parameters to the fake function signature—insert scale_q, scale_k, scale_v immediately after sm_scale, and add workspace_size after token_pos_in_items_len—to match paged_run; apply the same decorator and signature fixes to the second fake op variant around the 2108–2147 area so both fake registrations fully match paged_run.
🧹 Nitpick comments (1)
flashinfer/testing/kvfp4.py (1)
35-43: Module-level CUDA tensors created at import time.

`E2M1_VALUES` and `E2M1_BOUNDS` are created when the module is imported. If CUDA is available but not initialized, or if there are multiple CUDA devices, this could cause issues. Consider lazy initialization or creating these tensors on demand.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/testing/kvfp4.py` around lines 35 - 43, E2M1_VALUES and E2M1_BOUNDS are created as CUDA tensors at module import, which can fail if CUDA isn't initialized or when multiple devices exist; change them to be created on-demand by replacing the module-level tensors with factory accessors (e.g., get_e2m1_values(device=None) and get_e2m1_bounds(device=None)) that take an optional device, compute torch.tensor(...) inside the function and return .to(device) (defaulting to CPU or a resolved _device) so no CUDA tensors are allocated at import time; update all uses of E2M1_VALUES and E2M1_BOUNDS to call the new getters.
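One common shape for the suggested fix: cache the lookup table per device instead of building it at import. The magnitude list below is the standard FP4 E2M1 value set; the getter name is illustrative, not the proposed API:

```python
import functools

import torch

@functools.lru_cache(maxsize=None)
def get_e2m1_values(device: str = "cpu") -> torch.Tensor:
    # Non-negative magnitudes representable in FP4 E2M1, built lazily and
    # memoized once per device, so importing the module never touches CUDA.
    return torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], device=device)
```

Callers pass the device of their input tensor (e.g. `get_e2m1_values(str(x.device))`), which also avoids cross-device comparisons on multi-GPU hosts.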
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/routines/flashinfer_benchmark_utils.py`:
- Around line 272-273: The function dtype_str_to_torch_dtype currently returns
the string "nvfp4" for dtype_str == "nvfp4", breaking the expected return type
of torch.dtype; change this to return torch.uint8 (the carrier dtype for nvfp4)
and add an inline comment in dtype_str_to_torch_dtype explaining that nvfp4 is
represented/carried as torch.uint8 so callers should interpret it specially, or
alternatively update the function docstring and return type annotation to
document the nvfp4 special-case if you prefer explicit typing; ensure all code
paths return a torch.dtype.
In `@flashinfer/decode.py`:
- Around line 1285-1292: kv_block_scales are left in NHD layout while
k_cache/v_cache are converted to HND for the trtllm-gen wrapper, causing
mismatched layouts for the NVFP4 kernel; when kv_block_scales is not None,
transpose/permute key_block_scales and value_block_scales the same way you
transform k_cache/v_cache (i.e., convert their NHD ordering to HND ordering) so
they match the packed KV pages, and apply the same change to the second
occurrence handling the same variables later
(key_block_scales/value_block_scales and kv_block_scales).
In `@flashinfer/jit/cubin_loader.py`:
- Around line 239-246: The get_cubin() function currently has an unconditional
return b"" that short-circuits the cache-miss path and prevents the
FLASHINFER_DISABLE_CUBIN_DOWNLOAD check and download logic from running; remove
the early "return b\"\"" so the function evaluates
os.getenv("FLASHINFER_DISABLE_CUBIN_DOWNLOAD") and, only if that env var is set,
logs the error about cubin_path and returns b""; otherwise proceed with the
existing download/cache retrieval logic in cubin_loader.py to restore cold-start
fetches used by callers like gemm/core.py, fused_moe.py, attention/modules.py,
and moe_utils.py.
In `@flashinfer/testing/kvfp4.py`:
- Around line 201-204: The division by 6.0 on k_blk_scales and v_blk_scales can
underflow when converted to torch.float8_e4m3fn; clamp the scaled block scales
to the FP8 minimum positive value before casting. Update the block where
k_blk_scales and v_blk_scales are adjusted (symbols: k_blk_scales, v_blk_scales,
float8_e4m3fn) to compute the divided values, clamp them with a conservative min
like 2**-9 (~0.00195) using .clamp(min=min_val), then cast to
torch.float8_e4m3fn; this preserves small-but-nonzero scales and avoids
underflow while keeping the intended FP8 adjustment.
- Around line 104-107: block_scales can be zero causing division by zero when
computing x_scaled = reshaped / (block_scales_fixed * global_scale); fix it by
ensuring block_scales_fixed never contains zeros (e.g., clamp_min or replace
zeros with a small epsilon like 1e-6) before the division so that
block_scales_fixed (derived from block_scales/block_max) is safe; update the
code that constructs block_scales_fixed (and any use of block_scales or
block_max) to apply the clamp/eps consistently so x_scaled and downstream
computations never divide by zero.
In `@include/flashinfer/trtllm/fmha/fmhaKernels.cuh`:
- Around line 486-489: The stray fflush(stdout) after the commented debug printf
in fmhaKernels.cuh should not execute unconditionally; either remove the
fflush(stdout) or wrap it in the same SAM_DEBUG guard used for other debug
prints (same pattern around the printf near the multiCtasKv debug block), so
update the code around the multiCtasKv debug section to put fflush(stdout)
inside the SAM_DEBUG macro (or delete it) to match the debug gating used
elsewhere.
- Around line 430-432: Remove or guard the leftover unconditional fflush(stdout)
in fmhaKernels.cuh: locate the orphaned fflush(stdout) call (near the commented
printf in the fmha kernel trace region) and either delete it or move it inside
the SAM_DEBUG conditional so it only runs when SAM_DEBUG is enabled; ensure no
other unconditional debug I/O remains in the same function/kernel after the
change.
In `@include/flashinfer/trtllm/fmha/kernelParams.h`:
- Around line 922-923: Replace the debug hardcoded assignment to
params.mSumOfSeqLensKv with the original dynamic value: restore the use of
options.mSumOfSeqLensKv (i.e., set params.mSumOfSeqLensKv =
options.mSumOfSeqLensKv) and remove the fixed "64" literal; if needed, add a
defensive check that options.mSumOfSeqLensKv is valid before assigning to
params.mSumOfSeqLensKv to avoid regressions.
In `@tests/attention/test_trtllm_gen_attention.py`:
- Around line 1064-1070: The test guard currently excludes NVFP4 KV paths so the
new wrapper code never gets exercised; remove the `kv_dtype != "nvfp4"`
condition (or adjust the conditional logic) so the branch runs when
`kv_block_scales` can be provided, and ensure `wrapper_trtllm_gen.run(...)`
forwards the `kv_block_scales` argument into
`BatchDecodeWithPagedKVCacheWrapper.run(...)` (the wrapper's NVFP4 behavior is
gated on `kv_block_scales` being non-None), referencing
`wrapper_trtllm_gen.run`, `BatchDecodeWithPagedKVCacheWrapper.run`, and the
`kv_block_scales` parameter to locate and fix the code.
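Several prompts above revolve around the same E2M1 block-quantization math: per-block scales, clamping so zero divisors cannot occur, and rounding against the E2M1 magnitude table. A minimal self-contained sketch of that scheme, using a block size of 16; all names are illustrative and this is not the flashinfer/testing/kvfp4.py implementation:

```python
import torch

# Non-negative magnitudes representable in FP4 E2M1.
E2M1 = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def quantize_blocks(x: torch.Tensor, eps: float = 1e-6):
    """Quantize (..., 16) blocks to E2M1 values with per-block scales."""
    block_max = x.abs().amax(dim=-1, keepdim=True)
    # Scale so the block maximum maps to 6.0, the largest E2M1 magnitude;
    # the clamp keeps all-zero blocks from producing 0/0 -> NaN.
    scale = (block_max / 6.0).clamp_min(eps)
    scaled = x / scale
    # Round each magnitude to the nearest representable E2M1 value.
    idx = (scaled.abs().unsqueeze(-1) - E2M1).abs().argmin(dim=-1)
    return E2M1[idx] * scaled.sign(), scale

def dequantize_blocks(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return q * scale
```

Round-tripping keeps the per-block error within one E2M1 quantization step of the block scale, and all-zero blocks come back as exact zeros rather than NaN.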
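The cubin checksum flow discussed in the cubin_loader.py comments above (read the cached file, honor the opt-out env var, then verify SHA-256) can be sketched standalone; `load_cubin_checked` is a hypothetical name, not the actual flashinfer/jit/cubin_loader.py helper:

```python
import hashlib
import os

def load_cubin_checked(cubin_path: str, expected_sha256: str) -> bytes:
    """Sketch of the intended flow: read, honor the disable flag, else verify."""
    with open(cubin_path, "rb") as f:
        cubin = f.read()
    if os.getenv("FLASHINFER_CUBIN_CHECKSUM_DISABLED"):
        return cubin  # explicit opt-out requested by the user
    actual = hashlib.sha256(cubin).hexdigest()
    if actual != expected_sha256:
        raise ValueError(
            f"cubin checksum mismatch: expected {expected_sha256}, got {actual}"
        )
    return cubin
```

The point of the ordering is that the early `return` in the reviewed code sat before both the env-var check and the hash, so neither ever ran; here a stale or corrupted file fails loudly unless the user disabled validation on purpose.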
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 2da26834-27b1-49c2-be23-cb0a19ace74e
📒 Files selected for processing (14)
- benchmarks/bench_trtllm_fmha.py
- benchmarks/routines/attention.py
- benchmarks/routines/flashinfer_benchmark_utils.py
- csrc/trtllm_fmha_kernel_launcher.cu
- flashinfer/decode.py
- flashinfer/jit/cubin_loader.py
- flashinfer/mla.py
- flashinfer/testing/__init__.py
- flashinfer/testing/kvfp4.py
- include/flashinfer/trtllm/fmha/fmhaKernels.cuh
- include/flashinfer/trtllm/fmha/fmhaRunner.cuh
- include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
- include/flashinfer/trtllm/fmha/kernelParams.h
- tests/attention/test_trtllm_gen_attention.py
Force-pushed d885d8a to e5f8507
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/decode.py (1)
2328-2366: ⚠️ Potential issue | 🟠 Major: Do not route NVFP4 KV cache to XQA.

`is_nvfp4_kvcache` is recognized above, but only the `trtllm-gen` branch forwards `k_block_scales`/`v_block_scales`. With `backend="auto"` on SM120/121, or `backend="xqa"` explicitly, this silently ignores the new scale tensors and dispatches a packed `uint8` KV cache to an unsupported backend.

Suggested fix
```diff
     if backend == "auto":
-        backend = (
-            "trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa"
-        )
+        cc_major = get_compute_capability(query.device)[0]
+        if is_nvfp4_kvcache:
+            if cc_major != 10:
+                raise ValueError(
+                    "NVFP4 KV cache is only supported by the trtllm-gen backend on SM100/SM103."
+                )
+            backend = "trtllm-gen"
+        else:
+            backend = "trtllm-gen" if cc_major == 10 else "xqa"
     if backend == "xqa":
+        if is_nvfp4_kvcache:
+            raise ValueError("xqa backend does not support NVFP4 KV cache.")
         # xqa backend doesn't support nvfp4 output
```
♻️ Duplicate comments (3)
include/flashinfer/trtllm/fmha/kernelParams.h (1)
921-923: ⚠️ Potential issue | 🔴 Critical: Restore the real KV length sum.

Line 923 hardcodes `mSumOfSeqLensKv` to `64`. Any batch whose total KV length differs from that will build incorrect kernel params and break the generated launch configuration.

Suggested fix
```diff
   params.mSumOfSeqLensQ = options.mSumOfSeqLensQ;
-  // params.mSumOfSeqLensKv = options.mSumOfSeqLensKv;
-  params.mSumOfSeqLensKv = 64;
+  params.mSumOfSeqLensKv = options.mSumOfSeqLensKv;
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/trtllm/fmha/kernelParams.h` around lines 921 - 923, Replace the hardcoded constant with the actual KV length sum: change the assignment that sets params.mSumOfSeqLensKv to use options.mSumOfSeqLensKv instead of the literal 64 so the kernel params reflect the true batch KV length (look for params.mSumOfSeqLensKv and options.mSumOfSeqLensKv in the kernelParams.h code).

flashinfer/decode.py (1)
1302-1306: ⚠️ Potential issue | 🟠 Major: Transpose the block-scale tensors with the KV pages.

On the NHD wrapper path, `k_cache` and `v_cache` are converted to HND at lines 1305-1306, but `key_block_scales` and `value_block_scales` stay in NHD order. The direct `trtllm_batch_decode_with_kv_cache` path below already transposes them, so the wrapper currently feeds mismatched layouts into the NVFP4 kernel.

Suggested fix
```diff
         if self._backend == "trtllm-gen" and self._kv_layout == "NHD":
             # For NHD: [..., N, H, D] -> HND: [..., H, N, D]
             k_cache = k_cache.transpose(-3, -2)
             v_cache = v_cache.transpose(-3, -2)
+            if key_block_scales is not None:
+                key_block_scales = key_block_scales.transpose(-3, -2)
+            if value_block_scales is not None:
+                value_block_scales = value_block_scales.transpose(-3, -2)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/decode.py` around lines 1302 - 1306, The NHD-to-HND transpose only moves k_cache and v_cache but leaves key_block_scales and value_block_scales in NHD order, causing layout mismatch for the trtllm-gen NVFP4 kernel; update the same conditional in decode.py (the branch that checks self._backend == "trtllm-gen" and self._kv_layout == "NHD") to also transpose key_block_scales and value_block_scales with the same axes (use transpose(-3, -2) like k_cache/v_cache) so their layout matches before calling trtllm_batch_decode_with_kv_cache.

flashinfer/testing/kvfp4.py (1)
118-126: ⚠️ Potential issue | 🟠 Major: Avoid `0/0` in all-zero blocks.

When `block_max` is zero, line 126 divides by `block_scales_fixed * global_scale == 0` and the quantized block becomes NaN. Use a clamped copy only for the divisor path.

Suggested fix
```diff
-    block_scales_fixed = block_scales.unsqueeze(-1)
+    block_scales_fixed = block_scales.clamp_min(1e-6).unsqueeze(-1)
     x_scaled = reshaped / (block_scales_fixed * global_scale)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/testing/kvfp4.py` around lines 118 - 126, The division can produce 0/0 for all-zero blocks when computing x_scaled = reshaped / (block_scales_fixed * global_scale); update the code that computes x_scaled to use a safe divisor: compute safe_divisor = (block_scales_fixed * global_scale).clone() and clamp it to a small positive eps (e.g. 1e-8) before dividing, then restore zeros for all-zero blocks (detect via block_max == 0) by setting x_scaled[...] = 0 where block_max.squeeze(-1) == 0; change references to block_scales_fixed/global_scale/block_max and x_scaled accordingly so only the divisor path uses the clamped copy.
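The failure mode and the fix are easy to demonstrate directly:

```python
import torch

block = torch.zeros(16)                  # an all-zero KV block
scale = block.abs().max() / 6.0          # -> tensor(0.)
naive = block / scale                    # 0/0 -> NaN everywhere
safe = block / scale.clamp_min(1e-6)     # stays exactly 0.0

assert torch.isnan(naive).all()
assert (safe == 0).all()
```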
🧹 Nitpick comments (1)
benchmarks/bench_trtllm_fmha.py (1)
123-139: Consider clarifying variable naming for scale computation.

The `k_scale_val` variable initially holds the query inverse scale (`q_inv_scale`) before being multiplied by the K global scale (`k_gs`). While the final computation is correct for `bmm1_scale = q_scale * k_scale * sm_scale`, the intermediate naming is confusing. Consider renaming for clarity:
♻️ Suggested clarification
```diff
     if kv_cache_dtype == "nvfp4":
         # NVFP4 KV requires FP8 query
         if q.dtype != torch.float8_e4m3fn:
             q, q_inv_scale = to_float8(q)
-            k_scale_val = (
+            q_scale_val = (
                 q_inv_scale.item()
                 if isinstance(q_inv_scale, torch.Tensor)
                 else q_inv_scale
             )
         else:
-            k_scale_val = 1.0
+            q_scale_val = 1.0
         kv_cache, kv_block_scales, k_gs, v_gs = (
             KVFP4QuantizeUtil.quantize_paged_kv_cache(kv_cache[:, 0], kv_cache[:, 1])
         )
-        k_scale_val *= k_gs
+        # Combined scale for bmm1: q_scale * k_scale
+        k_scale_val = q_scale_val * k_gs
         v_scale_val = v_gs
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_trtllm_fmha.py` around lines 123 - 139, The scale variable naming is confusing: k_scale_val is initially set from the query inverse scale (q_inv_scale) and later multiplied by the K global scale (k_gs). Update the variables in the nvfp4 branch to make intent explicit: capture the query inverse scale into a clearly named variable (e.g., q_inv_scale_val or q_scale_val) when calling to_float8(q), keep the K/V global scales returned from KVFP4QuantizeUtil.quantize_paged_kv_cache as k_gs and v_gs, then compute the final k_scale_val by multiplying the query-scale variable with k_gs and set v_scale_val = v_gs; adjust references to q_inv_scale, k_scale_val and v_scale_val accordingly so the intermediate meaning is clear.
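The scale composition at issue, with made-up magnitudes (none of these values come from the PR):

```python
q_scale_val = 0.5   # hypothetical query inverse scale from to_float8
k_gs = 2.0          # hypothetical K global scale from quantize_paged_kv_cache
sm_scale = 0.125    # 1/sqrt(head_dim) for head_dim = 64

# Final first-matmul scale composes all three factors:
# bmm1_scale = q_scale * k_scale * sm_scale
bmm1_scale = q_scale_val * k_gs * sm_scale
assert bmm1_scale == 0.125
```

Keeping the query factor under its own name until the final product makes it obvious which factor came from where, which is exactly what the rename above buys.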
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/bench_trtllm_fmha.py`:
- Around line 210-213: The IO calculation misses element sizes when kv_cache is
a tuple: update the tuple branch that computes io so each tensor in kv_cache
multiplies its numel() by element_size() (e.g. replace kv_cache[0].numel() +
kv_cache[1].numel() with kv_cache[0].numel() * kv_cache[0].element_size() +
kv_cache[1].numel() * kv_cache[1].element_size()) so io = q.numel() *
q.element_size() + <those products>; keep the existing scalar-tensor branch
as-is.
In `@flashinfer/testing/kvfp4.py`:
- Around line 29-60: The E2M1 lookup tensors (E2M1_VALUES, E2M1_BOUNDS) are
pinned to a single device at import time which causes device-mismatch errors
when quantize/dequantize are called with tensors on a different GPU; change the
code so these tables are created or moved to the input tensor's device at
runtime inside the quantize and dequantize functions: (1) stop constructing
E2M1_VALUES and E2M1_BOUNDS with device=_device at module import (construct them
on CPU or keep their data as Python lists/constants); (2) inside the functions
that use them (the quantize and dequantize routines that perform comparisons
with E2M1_BOUNDS and indexing into E2M1_VALUES), call
.to(input_tensor.device).to(dtype=torch.float32) (or recreate the tensors on
input_tensor.device) before any comparisons or indexing so both operands share
device and dtype; (3) ensure you reference the same symbols E2M1_VALUES and
E2M1_BOUNDS so callers are unchanged.
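The byte-accounting fix in the first prompt above can be checked on small tensors; `element_size()` is what makes mixed-dtype totals come out right:

```python
import torch

q = torch.randn(4, 8, dtype=torch.float16)   # 2 bytes/element
k = torch.zeros(4, 8, dtype=torch.uint8)     # 1 byte/element
v = torch.zeros(4, 8, dtype=torch.uint8)

# numel() alone would count 96 "elements" regardless of width; the IO volume
# in bytes needs each tensor's own element size.
io_bytes = (
    q.numel() * q.element_size()
    + k.numel() * k.element_size()
    + v.numel() * v.element_size()
)
assert io_bytes == 32 * 2 + 32 + 32  # 128 bytes
```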
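The NHD-to-HND fix discussed above applies the same axis swap to the packed KV pages and to their block scales; the shapes here are illustrative only:

```python
import torch

# NHD pages: [num_pages, page_size, num_heads, head_dim_packed]
k_nhd = torch.randn(2, 16, 4, 8)
scales_nhd = torch.randn(2, 16, 4, 2)  # hypothetical per-block scale layout

# HND swaps the sequence and head axes; the scales must follow the same
# permute or the kernel reads mismatched layouts.
k_hnd = k_nhd.transpose(-3, -2)
scales_hnd = scales_nhd.transpose(-3, -2)

assert k_hnd.shape == (2, 4, 16, 8)
assert scales_hnd.shape == (2, 4, 16, 2)
```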
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 4d6d49a9-286b-46c4-97af-ab7e9772f862
📒 Files selected for processing (10)
- benchmarks/bench_trtllm_fmha.py
- benchmarks/routines/attention.py
- benchmarks/routines/flashinfer_benchmark_utils.py
- csrc/trtllm_fmha_kernel_launcher.cu
- flashinfer/decode.py
- flashinfer/mla.py
- flashinfer/testing/__init__.py
- flashinfer/testing/kvfp4.py
- include/flashinfer/trtllm/fmha/kernelParams.h
- tests/attention/test_trtllm_gen_attention.py
🚧 Files skipped from review as they are similar to previous changes (1)
- benchmarks/routines/flashinfer_benchmark_utils.py
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
include/flashinfer/trtllm/fmha/kernelParams.h (1)
557-570: ⚠️ Potential issue | 🔴 Critical
Packed-QKV with mixed Q/KV dtypes computes wrong K/V offsets.
When `isPackedQkv(mQkvLayout)` is true and `mDataTypeQ != mDataTypeKv`, the offsets computed at lines 565–570 are incorrect. `getDevicePtrs()` receives only `bitsPerElt` based on `mDataTypeKv` (line 705), but the packed buffer stores Q first with its own element width. The K offset should account for Q's byte count, not KV's. There is no upstream validation preventing this combination; the code at line 743 explicitly handles mixed dtypes (`transformsKv`), including the E4M3/E2M1 case (lines 747–749). Either add a guard to reject packed QKV when datatypes differ, or pass separate element widths for Q and K/V to `getDevicePtrs()`.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/trtllm/fmha/kernelParams.h` around lines 557 - 570, getDevicePtrs currently computes K/V offsets using a single bitsPerElt (passed from mDataTypeKv) which is wrong when isPackedQkv(mQkvLayout) is true but mDataTypeQ != mDataTypeKv; update the code so offsets use Q's element width for advancing past Q and KV's element width for K/V: either add a guard in getDevicePtrs/isPackedQkv to reject mixed datatypes (check mDataTypeQ vs mDataTypeKv and error out) or change getDevicePtrs signature to accept two element-width parameters (bitsPerEltQ, bitsPerEltKv) and use bitsPerEltQ when computing the offset from qkvPtr to kPtr and then use bitsPerEltKv for the offset to vPtr; modify call sites that pass bitsPerElt (e.g., where getDevicePtrs is invoked) to supply both widths and update references to mNumHeadsQ, mNumHeadsKv, mHeadDimQk, and qkvPtr accordingly.
flashinfer/decode.py (1)
2301-2331: ⚠️ Potential issue | 🟠 Major
Fail fast before packed NVFP4 KV can fall into XQA.
NVFP4 detection only turns on when `kv_block_scales` is present. A packed `uint8` KV cache without scales is silently treated as ordinary UINT8 input, and `backend="auto"` can still select XQA on non-SM100 devices even though the XQA binding has no block-scale parameters (csrc/flashinfer_xqa_binding.cu:31-39). Require scales for packed UINT8 caches and reject non-`trtllm-gen` dispatch up front.
Suggested fix
```diff
-    is_nvfp4_kvcache = (
-        k_cache.dtype == torch.uint8
-        and v_cache.dtype == torch.uint8
-        and kv_block_scales is not None
-    )
+    is_packed_uint8_kvcache = k_cache.dtype == torch.uint8 and v_cache.dtype == torch.uint8
+    cc_major = get_compute_capability(query.device)[0]
+    if is_packed_uint8_kvcache:
+        if kv_block_scales is None:
+            raise ValueError("kv_block_scales must be provided for NVFP4 KV cache")
+        if cc_major != 10:
+            raise ValueError("NVFP4 KV cache is only supported on SM100/SM103")
+        if backend not in ("auto", "trtllm-gen"):
+            raise ValueError("NVFP4 KV cache is only supported by the trtllm-gen backend")
+    is_nvfp4_kvcache = is_packed_uint8_kvcache and kv_block_scales is not None
 @@
-    if backend == "auto":
-        backend = (
-            "trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa"
-        )
+    if backend == "auto":
+        backend = "trtllm-gen" if cc_major == 10 else "xqa"
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/decode.py` around lines 2301 - 2331, Detect packed NVFP4 KV by checking k_cache.dtype == v_cache.dtype == torch.uint8 even when kv_block_scales is None, and if such a packed UINT8 KV is found require scales or force/reject non-trtllm-gen dispatch: if kv_block_scales is None and both k_cache and v_cache are uint8, then if backend == "auto" set backend = "trtllm-gen", else if backend != "trtllm-gen" raise a clear ValueError that kv_block_scales are required for packed NVFP4 KV or backend must be "trtllm-gen"; keep the existing is_nvfp4_kvcache logic (which should remain gated on kv_block_scales) and only treat k_block_scales/v_block_scales when kv_block_scales is provided.
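The dispatch rules in the suggested fix can be exercised in isolation. A plain-Python sketch of the guard logic (function name and arguments are illustrative, not the shipped API):

```python
def resolve_backend(kv_is_packed_uint8, kv_block_scales, backend, cc_major):
    """Validate packed-uint8 (NVFP4) KV inputs before backend dispatch.

    Mirrors the fail-fast checks suggested above: packed uint8 KV requires
    block scales, an SM100-class device, and the trtllm-gen backend.
    """
    if kv_is_packed_uint8:
        if kv_block_scales is None:
            raise ValueError("kv_block_scales must be provided for NVFP4 KV cache")
        if cc_major != 10:
            raise ValueError("NVFP4 KV cache is only supported on SM100/SM103")
        if backend not in ("auto", "trtllm-gen"):
            raise ValueError("NVFP4 KV cache requires the trtllm-gen backend")
    if backend == "auto":
        backend = "trtllm-gen" if cc_major == 10 else "xqa"
    return backend

print(resolve_backend(True, object(), "auto", 10))  # trtllm-gen
print(resolve_backend(False, None, "auto", 9))      # xqa
```

The point of the ordering is that the packed-uint8 checks run before `"auto"` is resolved, so a scales-less packed cache can never silently reach XQA.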
♻️ Duplicate comments (3)
tests/attention/test_trtllm_gen_attention.py (1)
1064-1087: ⚠️ Potential issue | 🟠 Major
Keep wrapper coverage enabled for NVFP4 KV.
The direct call above forwards `kv_block_scales`, but this guard still skips every `kv_dtype == "nvfp4"` wrapper case and the wrapper invocation below still omits `kv_block_scales`. That leaves the new wrapper plumbing in `flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper.run()` unexercised.
Suggested fix
```diff
     if (
         o_dtype != "nvfp4"
-        and kv_dtype != "nvfp4"
         and backend == "trtllm-gen"
         and q_len_per_req is not None  # only test for the case all requests have the same q_len
-    ):  # wrapper api does not support fp4 output/kv yet.
+    ):  # wrapper api does not support fp4 output yet.
 @@
         output_wrapper = wrapper_trtllm_gen.run(
             q_input,
             kv_cache,
             q_scale=q_scale,
             k_scale=k_scale,
             v_scale=v_scale / o_scale,
+            kv_block_scales=kv_block_scales,
             enable_pdl=enable_pdl,
             sinks=(sink if enable_sink else None),
             q_len_per_req=q_len_per_req,
         )
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_trtllm_gen_attention.py` around lines 1064 - 1087, The test currently skips wrapper coverage when kv_dtype == "nvfp4" and the wrapper.run() call omits kv_block_scales; update the condition so NVFP4 KV cases are allowed through (remove or relax the kv_dtype != "nvfp4" check) and call wrapper_trtllm_gen.run(...) including the kv_block_scales argument (pass the existing kv_block_scales variable) so BatchDecodeWithPagedKVCacheWrapper.run() is exercised; ensure plan(...) is still called on wrapper_trtllm_gen and q_len_per_req logic is preserved.flashinfer/testing/kvfp4.py (1)
29-30: ⚠️ Potential issue | 🟠 Major
Move the E2M1 tables off the import-time device.
`E2M1_VALUES` and `E2M1_BOUNDS` are pinned to whichever CUDA device is current when this module is imported. `batched_quantize()` and `batched_dequantize()` later use them with `tensor.device`, so calls on `cuda:1` (or CPU on a CUDA-enabled host) will fail with a device-mismatch error.
Suggested fix
```diff
-# Put constants directly on CUDA if available
-_device = "cuda" if torch.cuda.is_available() else "cpu"
 # E2M1 format: 1 sign bit + 2 exponent bits + 1 mantissa bit = 4 bits
 # 16 possible values: 0x0-0xF
 # Negative values: 0x8-0xF (sign bit = 1)
 # Positive values: 0x0-0x7 (sign bit = 0)
 E2M1_VALUES = torch.tensor(
 @@
-    dtype=torch.float32,
-    device=_device,
+    dtype=torch.float32,
 )
 # Boundaries for rounding to nearest E2M1 value (only for positive values)
 E2M1_BOUNDS = torch.tensor(
-    [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5], dtype=torch.float32, device=_device
+    [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5], dtype=torch.float32
 )
 @@
     b, m, n = tensor.shape
     device = tensor.device
+    e2m1_bounds = E2M1_BOUNDS.to(device=device)
 @@
-    magnitude_bits = torch.sum(abs_vals.unsqueeze(-1) >= E2M1_BOUNDS, dim=-1).to(
+    magnitude_bits = torch.sum(abs_vals.unsqueeze(-1) >= e2m1_bounds, dim=-1).to(
         torch.uint8
     )
 @@
     b, m, n_half = quant_tensor.shape
     n = n_half * 2
+    e2m1_values = E2M1_VALUES.to(device=quant_tensor.device)
 @@
-    float_vals = E2M1_VALUES[fp4_vals.long()]
+    float_vals = e2m1_values[fp4_vals.long()]
```
Also applies to: 35-60
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/testing/kvfp4.py` around lines 29 - 30, E2M1_VALUES and E2M1_BOUNDS are created at import using the current _device and so get pinned to the import-time CUDA device; change them to be created on CPU (or unpinned) and move to the call-site device inside batched_quantize and batched_dequantize: keep the global symbols E2M1_VALUES and E2M1_BOUNDS but construct them on CPU (remove use of _device for their creation) and in batched_quantize()/batched_dequantize() call .to(tensor.device, non_blocking=True) (or create device-specific cached copies) before using them to avoid device-mismatch errors.flashinfer/decode.py (1)
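For intuition, the rounding behavior the two tables encode is independent of where they live. A framework-free sketch of E2M1 round-to-nearest using the same value/boundary tables; counting how many boundaries a magnitude reaches is equivalent to the module's `torch.sum(... >= E2M1_BOUNDS)` trick:

```python
import bisect

# Positive E2M1 magnitudes (codes 0x0-0x7); negatives mirror them via the sign bit.
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
# Rounding boundaries between adjacent representable magnitudes.
E2M1_BOUNDS = [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0]

def quantize_e2m1(x: float) -> float:
    """Round a scalar to the nearest representable E2M1 value."""
    # bisect_right counts boundaries <= abs(x), matching the >= comparison
    # (values exactly on a boundary round up).
    mag = bisect.bisect_right(E2M1_BOUNDS, abs(x))
    return (1.0 if x >= 0 else -1.0) * E2M1_VALUES[mag]

print([quantize_e2m1(v) for v in (0.2, 0.8, -2.6, 7.0)])  # [0.0, 1.0, -3.0, 6.0]
```

Because the tables are pure data, they can live on CPU and be moved with `.to(tensor.device)` at call time, which is exactly what the fix above does.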
1285-1292: ⚠️ Potential issue | 🟠 Major
Normalize `kv_block_scales` exactly like the KV cache.
The wrapper path still assumes tensor-form scales always have two K/V planes, and it never transposes them on the NHD→HND `trtllm-gen` path. The direct API in this same file already handles both cases, so wrapper calls can either fail during `unbind()` or feed misaligned scale tensors into the kernel.
Suggested fix
```diff
         if kv_block_scales is not None:
             if isinstance(kv_block_scales, tuple):
                 key_block_scales, value_block_scales = kv_block_scales
             else:
-                key_block_scales, value_block_scales = kv_block_scales.unbind(dim=1)
+                if kv_block_scales.shape[1] == 1:
+                    key_block_scales, value_block_scales = kv_block_scales, kv_block_scales
+                else:
+                    assert kv_block_scales.shape[1] == 2, (
+                        "When kv_block_scales is a single tensor, the second dimension must be 1 or 2"
+                    )
+                    key_block_scales, value_block_scales = kv_block_scales.unbind(dim=1)
 @@
         if self._backend == "trtllm-gen" and self._kv_layout == "NHD":
             # For NHD: [..., N, H, D] -> HND: [..., H, N, D]
             k_cache = k_cache.transpose(-3, -2)
             v_cache = v_cache.transpose(-3, -2)
+            if key_block_scales is not None:
+                key_block_scales = key_block_scales.transpose(-3, -2)
+            if value_block_scales is not None:
+                value_block_scales = value_block_scales.transpose(-3, -2)
```
Also applies to: 1302-1306
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/decode.py` around lines 1285 - 1292, The kv_block_scales handling currently assumes tensor-form scales have two K/V planes and skips the NHD→HND transpose used by the KV cache, which can cause unbind() failures or misaligned scales; update the unpacking for kv_block_scales to mirror the KV cache normalization used elsewhere in this file: accept a tuple as (key_block_scales, value_block_scales) or a single tensor and split it robustly (not just unbind(dim=1)), and when the tensor is in NHD layout perform the same transpose to HND before splitting; apply the same fix at both occurrences that set key_block_scales/value_block_scales (the block using kv_block_scales and the later one at the other occurrence), reusing the direct-API normalization logic rather than assuming shape.
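The accepted shapes can be captured in one small normalizer. A plain-Python sketch, with arrays modeled as nested lists and the K/V plane on axis 1 (the helper name is illustrative):

```python
def split_kv_block_scales(kv_block_scales):
    """Normalize kv_block_scales into (key_scales, value_scales).

    Accepts a (k, v) tuple, or a single stacked array whose second dimension
    is 1 (one plane shared by K and V) or 2 (separate K and V planes).
    """
    if kv_block_scales is None:
        return None, None
    if isinstance(kv_block_scales, tuple):
        return kv_block_scales
    num_planes = len(kv_block_scales[0])
    if num_planes == 1:
        return kv_block_scales, kv_block_scales  # shared scales
    assert num_planes == 2, "second dimension must be 1 or 2"
    # Equivalent of tensor.unbind(dim=1): peel the K and V planes apart.
    key_scales = [batch[0] for batch in kv_block_scales]
    value_scales = [batch[1] for batch in kv_block_scales]
    return key_scales, value_scales

stacked = [[["k0"], ["v0"]], [["k1"], ["v1"]]]  # shape [batch=2, planes=2, ...]
print(split_kv_block_scales(stacked))
```

Routing every entry point through one such normalizer keeps the wrapper and direct APIs from drifting apart again.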
🧹 Nitpick comments (3)
csrc/trtllm_fmha_kernel_launcher.cu (1)
370-375: Inconsistent layout comment.
The comment at line 370 says "Assume NHD layout" but the decode path at line 269 was updated to "Assume HND layout". Both compute strides identically, so one comment is incorrect. Please align the comments to reflect the actual layout.
📝 Proposed fix
```diff
-  // Assume NHD layout: [..., H, N, D]
+  // Assume HND layout: [..., H, N, D]
   int page_size = key_cache.size(-2);
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/trtllm_fmha_kernel_launcher.cu` around lines 370 - 375, The comment above the key_cache stride calculations is inconsistent with the decode path; change the layout comment to match the decode path's "Assume HND layout" (or whichever layout is actually used) so it aligns with how strides are computed for key_cache; update the comment near the declarations of page_size, num_kv_heads, kv_stride_keys_values, kv_stride_heads, and kv_stride_batch to read "Assume HND layout: [..., H, N, D]" (or update the decode path comment instead if HND is wrong) so both locations consistently describe the same layout.
benchmarks/bench_trtllm_fmha.py (2)
127-133: Redundant tensor check and unreachable branch.
`to_float8` always returns a tensor for `q_inv_scale`, so the `isinstance` check at line 129 is always `True`, making lines 132-133 unreachable dead code. Simplify this logic:
♻️ Proposed fix
```diff
     if q.dtype != torch.float8_e4m3fn:
         q, q_inv_scale = to_float8(q)
-        k_scale_val = (
-            q_inv_scale.item()
-            if isinstance(q_inv_scale, torch.Tensor)
-            else q_inv_scale
-        )
+        k_scale_val = q_inv_scale.item()
     else:
         k_scale_val = 1.0
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_trtllm_fmha.py` around lines 127 - 133, The code sets k_scale_val with an unnecessary isinstance check against torch.Tensor because to_float8 always returns a tensor; remove the conditional and unreachable else branch and simply extract the scalar via q_inv_scale.item() (or use float(q_inv_scale)) so k_scale_val is assigned from q_inv_scale.item() directly; update the assignment near the k_scale_val variable where q_inv_scale is produced and remove the now-dead else branch.
140-141: Discarded FP8 scale value.
The `to_float8` function returns a scale value that is discarded here. For accurate benchmarking with FP8, consider whether `k_scale`/`v_scale` should be set similarly to the nvfp4 path.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_trtllm_fmha.py` around lines 140 - 141, The FP8 scale returned by to_float8 is being discarded; update the kv_cache handling so you capture the returned scale (e.g., kv_cache, kv_scale = to_float8(kv_cache)) and set the appropriate k_scale and v_scale variables (or assign both to kv_scale) when kv_cache_dtype.startswith("fp8") and q_dtype != "fp8", matching how the nvfp4 branch sets k_scale/v_scale so FP8 quantization scales are used consistently in benchmarking.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@benchmarks/routines/attention.py`:
- Around line 600-606: The NVFP4 branch currently force-casts q to FP8
(torch.float8_e4m3fn) which mislabels reported q_dtype; instead either validate
and reject non-FP8 inputs or propagate an effective_q_dtype used by metrics: in
the is_nvfp4_kv branch (around checks using q and
KVFP4QuantizeUtil.quantize_paged_kv_cache) raise a clear error when q.dtype !=
torch.float8_e4m3fn, or set effective_q_dtype = torch.float8_e4m3fn and ensure
that downstream metric/charging code reads effective_q_dtype rather than the
original q.dtype so the benchmark records the actual query format.
- Around line 397-404: The current NVFP4 KV cache handling in the is_nvfp4_kv
block improperly force-adds "trtllm-native" after compute-capability filtering
(so functions like filter_backends_by_compute_capability / subsequent calls to
flashinfer.decode.trtllm_batch_decode_with_kv_cache may run on an unsupported
GPU or a different backend labeled as trtllm-native). Fix it by not appending
"trtllm-native" unconditionally; instead compute the intersection of allowed
backends and ["trtllm-native"], and if empty log/print a clear message and skip
NVFP4-specific benchmarking (or raise/continue) so only truly supported backends
remain in the backends list; update the block that references is_nvfp4_kv and
the backends list manipulation (the backends variable and the NVFP4 handling
code) to enforce capability-filtering rather than bypassing it.
In `@include/flashinfer/trtllm/fmha/kernelParams.h`:
- Around line 218-241: The pointer dump helper in kernelParams.h is missing
ptrSkipSoftmaxStats, so add a printf entry for ptrSkipSoftmaxStats (matching the
style of the other prints, e.g., printf("ptrSkipSoftmaxStats: %p\n",
(void*)ptrSkipSoftmaxStats);) in the same location and in definition order
alongside the other ptr* prints (refer to existing symbols like ptrSoftmaxStats,
ptrPartialStats, ptrSeqLensKv) so the dump is complete for skip-softmax
debugging.
- Line 260: The printf uses %ld for an int64_t field (mNumHiddenEltsO) which is
non-portable; include <cinttypes> and change the print to use the PRId64 macro
(e.g., printf("mNumHiddenEltsO: %" PRId64 "\n", mNumHiddenEltsO)) so the
formatting is correct across LP64/LLP64 targets — update the include and the
printf call in kernelParams.h where mNumHiddenEltsO is printed.
---
Outside diff comments:
In `@flashinfer/decode.py`:
- Around line 2301-2331: Detect packed NVFP4 KV by checking k_cache.dtype ==
v_cache.dtype == torch.uint8 even when kv_block_scales is None, and if such a
packed UINT8 KV is found require scales or force/reject non-trtllm-gen dispatch:
if kv_block_scales is None and both k_cache and v_cache are uint8, then if
backend == "auto" set backend = "trtllm-gen", else if backend != "trtllm-gen"
raise a clear ValueError that kv_block_scales are required for packed NVFP4 KV
or backend must be "trtllm-gen"; keep the existing is_nvfp4_kvcache logic (which
should remain gated on kv_block_scales) and only treat
k_block_scales/v_block_scales when kv_block_scales is provided.
In `@include/flashinfer/trtllm/fmha/kernelParams.h`:
- Around line 557-570: getDevicePtrs currently computes K/V offsets using a
single bitsPerElt (passed from mDataTypeKv) which is wrong when
isPackedQkv(mQkvLayout) is true but mDataTypeQ != mDataTypeKv; update the code
so offsets use Q's element width for advancing past Q and KV's element width for
K/V: either add a guard in getDevicePtrs/isPackedQkv to reject mixed datatypes
(check mDataTypeQ vs mDataTypeKv and error out) or change getDevicePtrs
signature to accept two element-width parameters (bitsPerEltQ, bitsPerEltKv) and
use bitsPerEltQ when computing the offset from qkvPtr to kPtr and then use
bitsPerEltKv for the offset to vPtr; modify call sites that pass bitsPerElt
(e.g., where getDevicePtrs is invoked) to supply both widths and update
references to mNumHeadsQ, mNumHeadsKv, mHeadDimQk, and qkvPtr accordingly.
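To see why a single element width misplaces K, the offset arithmetic for a packed [Q; K; V] buffer can be written out directly. A plain-Python sketch with illustrative sizes (the real computation lives in `getDevicePtrs()`):

```python
def packed_qkv_offsets(num_tokens, num_heads_q, num_heads_kv, head_dim,
                       bits_per_elt_q, bits_per_elt_kv):
    """Byte offsets of K and V inside a packed [Q; K; V] buffer.

    The span occupied by Q must be measured in Q's element width;
    only the K-to-V gap uses the KV width.
    """
    q_bytes = num_tokens * num_heads_q * head_dim * bits_per_elt_q // 8
    k_bytes = num_tokens * num_heads_kv * head_dim * bits_per_elt_kv // 8
    return q_bytes, q_bytes + k_bytes  # (offset of K, offset of V)

# FP8 (E4M3) queries with FP4 (E2M1) KV: 8-bit Q, 4-bit KV.
correct = packed_qkv_offsets(16, 8, 2, 128, 8, 4)
# Buggy variant: one width (KV's 4 bits) applied to Q's span as well.
buggy = packed_qkv_offsets(16, 8, 2, 128, 4, 4)
print(correct, buggy)  # the K offset is halved when Q's width is ignored
```

With 8-bit Q and 4-bit KV, using the KV width for Q's span puts the K pointer at half the correct byte offset, which is exactly the mixed-dtype failure mode the comment describes.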
---
Duplicate comments:
In `@flashinfer/decode.py`:
- Around line 1285-1292: The kv_block_scales handling currently assumes
tensor-form scales have two K/V planes and skips the NHD→HND transpose used by
the KV cache, which can cause unbind() failures or misaligned scales; update the
unpacking for kv_block_scales to mirror the KV cache normalization used
elsewhere in this file: accept a tuple as (key_block_scales, value_block_scales)
or a single tensor and split it robustly (not just unbind(dim=1)), and when the
tensor is in NHD layout perform the same transpose to HND before splitting;
apply the same fix at both occurrences that set
key_block_scales/value_block_scales (the block using kv_block_scales and the
later one at the other occurrence), reusing the direct-API normalization logic
rather than assuming shape.
In `@flashinfer/testing/kvfp4.py`:
- Around line 29-30: E2M1_VALUES and E2M1_BOUNDS are created at import using the
current _device and so get pinned to the import-time CUDA device; change them to
be created on CPU (or unpinned) and move to the call-site device inside
batched_quantize and batched_dequantize: keep the global symbols E2M1_VALUES and
E2M1_BOUNDS but construct them on CPU (remove use of _device for their creation)
and in batched_quantize()/batched_dequantize() call .to(tensor.device,
non_blocking=True) (or create device-specific cached copies) before using them
to avoid device-mismatch errors.
In `@tests/attention/test_trtllm_gen_attention.py`:
- Around line 1064-1087: The test currently skips wrapper coverage when kv_dtype
== "nvfp4" and the wrapper.run() call omits kv_block_scales; update the
condition so NVFP4 KV cases are allowed through (remove or relax the kv_dtype !=
"nvfp4" check) and call wrapper_trtllm_gen.run(...) including the
kv_block_scales argument (pass the existing kv_block_scales variable) so
BatchDecodeWithPagedKVCacheWrapper.run() is exercised; ensure plan(...) is still
called on wrapper_trtllm_gen and q_len_per_req logic is preserved.
---
Nitpick comments:
In `@benchmarks/bench_trtllm_fmha.py`:
- Around line 127-133: The code sets k_scale_val with an unnecessary isinstance
check against torch.Tensor because to_float8 always returns a tensor; remove the
conditional and unreachable else branch and simply extract the scalar via
q_inv_scale.item() (or use float(q_inv_scale)) so k_scale_val is assigned from
q_inv_scale.item() directly; update the assignment near the k_scale_val variable
where q_inv_scale is produced and remove the now-dead else branch.
- Around line 140-141: The FP8 scale returned by to_float8 is being discarded;
update the kv_cache handling so you capture the returned scale (e.g., kv_cache,
kv_scale = to_float8(kv_cache)) and set the appropriate k_scale and v_scale
variables (or assign both to kv_scale) when kv_cache_dtype.startswith("fp8") and
q_dtype != "fp8", matching how the nvfp4 branch sets k_scale/v_scale so FP8
quantization scales are used consistently in benchmarking.
In `@csrc/trtllm_fmha_kernel_launcher.cu`:
- Around line 370-375: The comment above the key_cache stride calculations is
inconsistent with the decode path; change the layout comment to match the decode
path's "Assume HND layout" (or whichever layout is actually used) so it aligns
with how strides are computed for key_cache; update the comment near the
declarations of page_size, num_kv_heads, kv_stride_keys_values, kv_stride_heads,
and kv_stride_batch to read "Assume HND layout: [..., H, N, D]" (or update the
decode path comment instead if HND is wrong) so both locations consistently
describe the same layout.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 00eff0e7-12ca-4018-8c34-c52eb1b24732
📒 Files selected for processing (10)
- benchmarks/bench_trtllm_fmha.py
- benchmarks/routines/attention.py
- benchmarks/routines/flashinfer_benchmark_utils.py
- csrc/trtllm_fmha_kernel_launcher.cu
- flashinfer/decode.py
- flashinfer/mla.py
- flashinfer/testing/__init__.py
- flashinfer/testing/kvfp4.py
- include/flashinfer/trtllm/fmha/kernelParams.h
- tests/attention/test_trtllm_gen_attention.py
🚧 Files skipped from review as they are similar to previous changes (1)
- flashinfer/mla.py
Force-pushed from e5f8507 to 17f2353
♻️ Duplicate comments (5)
benchmarks/routines/attention.py (2)
600-606: ⚠️ Potential issue | 🟠 Major
Don't silently benchmark FP8 queries under a non-FP8 request.
This branch force-casts `q` to FP8 for NVFP4, but the benchmark still carries the original requested `q_dtype` through perf accounting and result reporting. Either reject non-FP8 inputs here or thread an explicit effective query dtype through the metric path.
Suggested fix
```diff
     if is_nvfp4_kv:
         # NVFP4 KV requires FP8 query
-        if q.dtype != torch.float8_e4m3fn:
-            q = q.to(torch.float8_e4m3fn)
+        if q_dtype != torch.float8_e4m3fn:
+            print("[ERROR] NVFP4 KV cache requires --q_dtype fp8.")
+            return res
         kv_cache_nvfp4, kv_block_scales, k_scale, v_scale = (
             KVFP4QuantizeUtil.quantize_paged_kv_cache(kv_cache[:, 0], kv_cache[:, 1])
         )
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/routines/attention.py` around lines 600 - 606, The code branch under is_nvfp4_kv force-casts q to torch.float8_e4m3fn but leaves the original requested q dtype recorded in benchmarks; either validate and raise if q.dtype is not FP8 or propagate the effective query dtype into the metric/reporting path. Update the block around is_nvfp4_kv/q/KVFP4QuantizeUtil.quantize_paged_kv_cache to (a) check q.dtype and raise a clear error if non-FP8 inputs are not supported, or (b) set an explicit effective_q_dtype variable (e.g. effective_q_dtype = torch.float8_e4m3fn) after casting and ensure all perf accounting and result reporting use effective_q_dtype instead of the original q.dtype so metrics reflect the actual dtype used.
397-404: ⚠️ Potential issue | 🟠 Major
Don't re-add `trtllm-native` after capability filtering.
If compute-capability filtering removed `trtllm-native`, appending it back here can benchmark NVFP4 under an unsupported device/backend combination. The later call still uses `backend="auto"`, so this can also run a different backend under the `trtllm-native` label.
Suggested fix
```diff
     # NVFP4 KV cache only works with trtllm-native (direct API with kv_block_scales)
     if is_nvfp4_kv:
-        unsupported = [b for b in backends if b != "trtllm-native"]
-        for b in unsupported:
-            print(f"[INFO] {b} backend does not support NVFP4 KV cache. Skipping.")
-            backends.remove(b)
-        if "trtllm-native" not in backends:
-            backends.append("trtllm-native")
+        backends = [b for b in backends if b == "trtllm-native"]
+        if not backends:
+            print("[ERROR] NVFP4 KV cache is not supported on this device/backend set.")
+            return res
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/routines/attention.py` around lines 397 - 404, The current block handling is_nvfp4_kv wrongly re-appends "trtllm-native" after filtering, which can override compute-capability decisions; modify the logic in the is_nvfp4_kv section so you do not re-add "trtllm-native" if it was removed by capability filtering: either remove the line that appends "trtllm-native" or conditionally append only if it was present in the original backends list and still supported; update references to backends and the is_nvfp4_kv check accordingly to preserve correct filtering and avoid mislabeling.
flashinfer/decode.py (1)
1285-1306: ⚠️ Potential issue | 🟠 Major
Transpose `kv_block_scales` with the KV cache on the NHD `trtllm-gen` path.
This branch converts `k_cache` and `v_cache` from NHD to HND, but `key_block_scales` and `value_block_scales` stay in NHD order. NVFP4 then feeds HND-packed KV pages with mismatched scale tensors.
Suggested fix
```diff
         # Convert NHD layout to HND for trtllm-gen backend
         if self._backend == "trtllm-gen" and self._kv_layout == "NHD":
             # For NHD: [..., N, H, D] -> HND: [..., H, N, D]
             k_cache = k_cache.transpose(-3, -2)
             v_cache = v_cache.transpose(-3, -2)
+            if key_block_scales is not None:
+                key_block_scales = key_block_scales.transpose(-3, -2)
+            if value_block_scales is not None:
+                value_block_scales = value_block_scales.transpose(-3, -2)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/decode.py` around lines 1285 - 1306, The NHD->HND conversion for k_cache and v_cache in the trtllm-gen path doesn't apply the same transpose to kv_block_scales, causing mismatched ordering; update the block under "if self._backend == 'trtllm-gen' and self._kv_layout == 'NHD':" to also transpose key_block_scales and value_block_scales (when present) using the same axes as k_cache/v_cache (i.e., transpose(-3, -2)); handle both the tuple and unbound tensor forms of kv_block_scales (respecting the earlier unpack logic where kv_block_scales may be None, a tuple, or produced by .unbind(dim=1)) so the scales are reordered to HND to match k_cache/v_cache.
benchmarks/bench_trtllm_fmha.py (2)
210-213: ⚠️ Potential issue | 🟡 Minor
Account for bytes, not elements, in the NVFP4 IO path.
The tuple branch adds raw `numel()` for the packed KV tensors and also skips `kv_block_scales`, so the printed GB/s is understated for NVFP4.
Suggested fix
```diff
     if isinstance(kv_cache, tuple):
-        io = q.numel() * q.element_size() + kv_cache[0].numel() + kv_cache[1].numel()
+        io = (
+            q.numel() * q.element_size()
+            + kv_cache[0].numel() * kv_cache[0].element_size()
+            + kv_cache[1].numel() * kv_cache[1].element_size()
+        )
+        if kv_block_scales is not None:
+            if isinstance(kv_block_scales, tuple):
+                io += sum(t.numel() * t.element_size() for t in kv_block_scales)
+            else:
+                io += kv_block_scales.numel() * kv_block_scales.element_size()
     else:
         io = q.numel() * q.element_size() + kv_cache.numel() * kv_cache.element_size()
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_trtllm_fmha.py` around lines 210 - 213, The IO calculation currently under the NVFP4 tuple branch uses raw element counts and omits kv_block_scales, understating bytes; update the tuple branch that computes io (using kv_cache and q) to sum byte sizes by multiplying each tensor's numel() by its element_size() (e.g., kv_cache[0].numel()*kv_cache[0].element_size() and kv_cache[1].numel()*kv_cache[1].element_size()) and also add the byte size of any kv_block_scales tensor(s) (numel()*element_size()) so the NVFP4 GB/s report uses bytes, not elements.
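The difference matters precisely because NVFP4 mixes element widths. A minimal sketch of byte-accurate IO accounting over (numel, element_size) pairs (all counts are illustrative):

```python
def io_bytes(tensors):
    """Total bytes moved: sum of numel * element_size over all tensors."""
    return sum(numel * elt_size for numel, elt_size in tensors)

# Illustrative shapes: fp8 query (1 B/elt), two packed NVFP4 KV tensors stored
# as uint8 (1 B/elt, two fp4 values per byte), and fp32 block scales (4 B/elt).
q = (4096, 1)
kv_packed = [(65536, 1), (65536, 1)]
scales = (8192, 4)
print(io_bytes([q, *kv_packed, scales]))  # counting elements alone would miss the 4x on scales
```

Summing `numel()` alone would charge the fp32 scale tensor one byte per element and drop the query's width entirely, which is why the reported GB/s comes out low.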
123-139: ⚠️ Potential issue | 🟠 Major
NVFP4 silently changes the benchmarked query dtype.
When `kv_cache_dtype == "nvfp4"`, this path converts non-FP8 queries to FP8 before the run. That means the benchmark no longer measures the workload requested by `q_dtype`, and the CLI currently gives the caller no way to see or control that change.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_trtllm_fmha.py` around lines 123 - 139, The code currently mutates the benchmark query `q` to FP8 when `kv_cache_dtype == "nvfp4"` (via to_float8), which silently changes the measured dtype; instead, preserve the original `q` and create a separate FP8 copy only for NVFP4-specific operations: call to_float8(q) and store result in a new variable (e.g., q_fp8 and q_fp8_inv_scale) and use q_fp8 when computing `k_scale_val`, `kv_cache` quantization (KVFP4QuantizeUtil.quantize_paged_kv_cache) and any NVFP4 paths, leaving `q` and reported `q_dtype` unchanged; also add a clear guard or CLI flag to either: (a) error if the requested q_dtype is not FP8 and kv_cache_dtype is nvfp4, or (b) emit a warning/info that an FP8 copy was used for NVFP4 operations and document the flag to opt-in—implement the change around the block referencing kv_cache_dtype, q, to_float8, and KVFP4QuantizeUtil.quantize_paged_kv_cache.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@benchmarks/bench_trtllm_fmha.py`:
- Around line 210-213: The IO calculation currently under the NVFP4 tuple branch
uses raw element counts and omits kv_block_scales, understating bytes; update
the tuple branch that computes io (using kv_cache and q) to sum byte sizes by
multiplying each tensor's numel() by its element_size() (e.g.,
kv_cache[0].numel()*kv_cache[0].element_size() and
kv_cache[1].numel()*kv_cache[1].element_size()) and also add the byte size of
any kv_block_scales tensor(s) (numel()*element_size()) so the NVFP4 GB/s report
uses bytes, not elements.
- Around line 123-139: The code currently mutates the benchmark query `q` to FP8
when `kv_cache_dtype == "nvfp4"` (via to_float8), which silently changes the
measured dtype; instead, preserve the original `q` and create a separate FP8
copy only for NVFP4-specific operations: call to_float8(q) and store result in a
new variable (e.g., q_fp8 and q_fp8_inv_scale) and use q_fp8 when computing
`k_scale_val`, `kv_cache` quantization
(KVFP4QuantizeUtil.quantize_paged_kv_cache) and any NVFP4 paths, leaving `q` and
reported `q_dtype` unchanged; also add a clear guard or CLI flag to either: (a)
error if the requested q_dtype is not FP8 and kv_cache_dtype is nvfp4, or (b)
emit a warning/info that an FP8 copy was used for NVFP4 operations and document
the flag to opt-in—implement the change around the block referencing
kv_cache_dtype, q, to_float8, and KVFP4QuantizeUtil.quantize_paged_kv_cache.
In `@benchmarks/routines/attention.py`:
- Around line 600-606: The code branch under is_nvfp4_kv force-casts q to
torch.float8_e4m3fn but leaves the original requested q dtype recorded in
benchmarks; either validate and raise if q.dtype is not FP8 or propagate the
effective query dtype into the metric/reporting path. Update the block around
is_nvfp4_kv/q/KVFP4QuantizeUtil.quantize_paged_kv_cache to (a) check q.dtype and
raise a clear error if non-FP8 inputs are not supported, or (b) set an explicit
effective_q_dtype variable (e.g. effective_q_dtype = torch.float8_e4m3fn) after
casting and ensure all perf accounting and result reporting use
effective_q_dtype instead of the original q.dtype so metrics reflect the actual
dtype used.
- Around line 397-404: The current block handling is_nvfp4_kv wrongly re-appends
"trtllm-native" after filtering, which can override compute-capability
decisions; modify the logic in the is_nvfp4_kv section so you do not re-add
"trtllm-native" if it was removed by capability filtering — either remove the
line that appends "trtllm-native" or conditionally append only if it was present
in the original backends list and still supported; update references to backends
and the is_nvfp4_kv check accordingly to preserve correct filtering and avoid
mislabeling.
In `@flashinfer/decode.py`:
- Around line 1285-1306: The NHD->HND conversion for k_cache and v_cache in the
trtllm-gen path doesn't apply the same transpose to kv_block_scales, causing
mismatched ordering; update the block under "if self._backend == 'trtllm-gen'
and self._kv_layout == 'NHD':" to also transpose key_block_scales and
value_block_scales (when present) using the same axes as k_cache/v_cache (i.e.,
transpose(-3, -2)); handle both the tuple and unbound tensor forms of
kv_block_scales (respecting the earlier unpack logic where kv_block_scales may
be None, a tuple, or produced by .unbind(dim=1)) so the scales are reordered to
HND to match k_cache/v_cache.
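The axis bookkeeping can be illustrated with a toy tensor that only records its axis order (`MockTensor` and `scales_to_hnd` are illustrative stand-ins, not library code):

```python
class MockTensor:
    """Toy tensor that records only the order of its axes."""

    def __init__(self, axes):
        self.axes = list(axes)

    def transpose(self, a, b):
        axes = self.axes[:]
        axes[a], axes[b] = axes[b], axes[a]
        return MockTensor(axes)


def scales_to_hnd(key_scales, value_scales):
    # Apply the same NHD -> HND swap used for k_cache/v_cache, i.e.
    # transpose(-3, -2), to the per-block scales whenever they are present.
    if key_scales is not None:
        key_scales = key_scales.transpose(-3, -2)
    if value_scales is not None:
        value_scales = value_scales.transpose(-3, -2)
    return key_scales, value_scales
```

Running both the caches and the scales through the same swap keeps their orderings consistent.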
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: ed7f2473-d203-439a-8f76-18f3eb207067
📒 Files selected for processing (5)
- benchmarks/bench_trtllm_fmha.py
- benchmarks/routines/attention.py
- benchmarks/routines/flashinfer_benchmark_utils.py
- flashinfer/decode.py
- flashinfer/mla.py
🚧 Files skipped from review as they are similar to previous changes (1)
- benchmarks/routines/flashinfer_benchmark_utils.py
Force-pushed from 17f2353 to 77e4f6d
Actionable comments posted: 1
♻️ Duplicate comments (2)
benchmarks/bench_trtllm_fmha.py (1)
210-213: ⚠️ Potential issue | 🟡 Minor: Fix tuple KV IO accounting to include element sizes.
Line 211 currently sums only element counts for tuple KV caches, so bandwidth is misreported.
Suggested fix
```diff
 if isinstance(kv_cache, tuple):
-    io = q.numel() * q.element_size() + kv_cache[0].numel() + kv_cache[1].numel()
+    io = (
+        q.numel() * q.element_size()
+        + kv_cache[0].numel() * kv_cache[0].element_size()
+        + kv_cache[1].numel() * kv_cache[1].element_size()
+    )
 else:
     io = q.numel() * q.element_size() + kv_cache.numel() * kv_cache.element_size()
```

🤖 Prompt for AI Agents
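The corrected accounting can be exercised in isolation with mock tensors that expose only `numel()` and `element_size()` (a sketch, not the benchmark's types):

```python
from dataclasses import dataclass


@dataclass
class MockTensor:
    """Minimal tensor stand-in: element count plus bytes per element."""
    _numel: int
    _element_size: int

    def numel(self):
        return self._numel

    def element_size(self):
        return self._element_size


def attention_io_bytes(q, kv_cache):
    """Bytes moved for q plus the KV cache, counting element sizes on both paths."""
    if isinstance(kv_cache, tuple):
        return (
            q.numel() * q.element_size()
            + kv_cache[0].numel() * kv_cache[0].element_size()
            + kv_cache[1].numel() * kv_cache[1].element_size()
        )
    return q.numel() * q.element_size() + kv_cache.numel() * kv_cache.element_size()
```

With element sizes included, the tuple branch and the single-tensor branch now agree for equivalently sized caches.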
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_trtllm_fmha.py` around lines 210-213: the IO calculation for tuple KV caches underestimates bandwidth because it adds only numel() for each tensor; update the tuple branch in the io computation (where kv_cache is checked and io is assigned) to multiply each tensor's numel() by its element_size() (i.e., use kv_cache[0].numel() * kv_cache[0].element_size() + kv_cache[1].numel() * kv_cache[1].element_size()) and keep the q contribution as q.numel() * q.element_size().

include/flashinfer/trtllm/fmha/kernelParams.h (1)
261-261: ⚠️ Potential issue | 🟡 Minor: Use portable formatting for int64_t in debug print.

Line 261 uses %ld for mNumHiddenEltsO (int64_t), which is not portable across LP64/LLP64 targets. Prefer PRId64.

```bash
#!/bin/bash
# Verify current formatting usage around mNumHiddenEltsO in kernelParams.h
rg -n "mNumHiddenEltsO|%ld|PRId64" include/flashinfer/trtllm/fmha/kernelParams.h -C2
```

Suggested fix

```diff
+#include <cinttypes>
 ...
-  printf("mNumHiddenEltsO: %ld\n", mNumHiddenEltsO);
+  printf("mNumHiddenEltsO: %" PRId64 "\n", mNumHiddenEltsO);
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/trtllm/fmha/kernelParams.h` at line 261, Replace the non-portable printf usage that prints mNumHiddenEltsO (an int64_t) with a portable int64_t format: include <inttypes.h> if not already included and change the printf to use "%" PRId64 and cast mNumHiddenEltsO to (int64_t) (e.g. printf("mNumHiddenEltsO: %" PRId64 "\n", (int64_t)mNumHiddenEltsO);) so the output is correct across LP64/LLP64 targets.
🧹 Nitpick comments (1)
tests/attention/test_trtllm_gen_attention.py (1)
1040-1050: Consider expanding documentation for NVFP4 tolerance thresholds.

The NVFP4 KV cache tolerances (rtol=0.3, atol=0.3, allowed_mismatch_rate=0.05) are intentional design choices that reflect FP4's inherent 4-bit quantization limitations. While the existing comment ("NVFP4 KV cache has significant quantization error") explains the relaxation, consider adding more specifics:
- Why a 5% mismatch rate is appropriate for this quantization format
- Whether these thresholds have been validated against expected FP4 precision characteristics
- Any references to NVFP4 specification or prior analysis
This helps future maintainers understand whether the bounds can be tightened if FP4 implementation is refined.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/attention/test_trtllm_gen_attention.py` around lines 1040 - 1050, Expand the comment block around the NVFP4 tolerance assignments (where kv_dtype is checked and rtol, atol, allowed_mismatch_rate are set) to document why the NVFP4 values were chosen: state that FP4’s 4-bit quantization causes significant rounding/representation error, justify the 0.3 rtol/atol and 0.05 mismatch rate by referencing any validation/benchmarks or expected FP4 precision characteristics, note whether these thresholds were empirically validated (and how) and include a pointer to NVFP4 specification or analysis notes so future maintainers can reassess tightening rtol/atol or allowed_mismatch_rate for kv_dtype == "nvfp4".
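The relaxed check being documented amounts to tolerating a bounded fraction of out-of-tolerance elements; a pure-Python sketch (the real test operates on torch tensors, and the helper name is illustrative):

```python
def passes_with_mismatch_budget(ref, out, rtol, atol, allowed_mismatch_rate):
    """Element-wise |ref - out| <= atol + rtol * |ref|, with a small
    fraction of violations tolerated for 4-bit quantized KV caches."""
    mismatches = sum(
        1 for r, o in zip(ref, out) if abs(r - o) > atol + rtol * abs(r)
    )
    return mismatches <= allowed_mismatch_rate * len(ref)
```

Documenting the tolerances then reduces to justifying the per-element bound (rtol/atol) and the budget (allowed_mismatch_rate) separately.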
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/decode.py`:
- Around line 1288-1292: The code assumes kv_block_scales can be unbound along
dim=1 into key_block_scales and value_block_scales but crashes when
kv_block_scales is a tensor with size(1)==1 (shared-KV); update the branch in
flashinfer/decode.py so that if kv_block_scales is not a tuple and
kv_block_scales.size(1)==1 you assign key_block_scales = value_block_scales =
kv_block_scales.squeeze(1) (or equivalent indexing), otherwise keep the current
kv_block_scales.unbind(dim=1) behavior; reference variables kv_block_scales,
key_block_scales, value_block_scales in the fix.
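The branch above can be sketched with a minimal stand-in for the scales tensor (only size/squeeze/unbind are mocked; names and return values are illustrative):

```python
class MockScales:
    """Stand-in for a tensor of shape (num_pages, n, ...); n is dim 1."""

    def __init__(self, n):
        self.n = n

    def size(self, dim):
        assert dim == 1
        return self.n

    def squeeze(self, dim):
        return f"scales[:, {dim} removed]"

    def unbind(self, dim):
        return "k_scales", "v_scales"


def unpack_block_scales(kv_block_scales):
    if kv_block_scales is None:
        return None, None
    if isinstance(kv_block_scales, tuple):
        return kv_block_scales
    if kv_block_scales.size(1) == 1:
        # Shared K/V scales: unbind(dim=1) would yield a single tensor
        # and crash the 2-way unpack, so squeeze and reuse it for both.
        shared = kv_block_scales.squeeze(1)
        return shared, shared
    return kv_block_scales.unbind(dim=1)
```

The shared-KV case returns the same object for key and value scales, while the two-entry case keeps the existing unbind behavior.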
---
Duplicate comments:
In `@benchmarks/bench_trtllm_fmha.py`:
- Around line 210-213: The IO calculation for tuple KV caches underestimates
bandwidth because it adds only numel() for each tensor; update the tuple branch
in the io computation (where kv_cache is checked and io is assigned) to multiply
each tensor's numel() by its element_size() (i.e., use kv_cache[0].numel() *
kv_cache[0].element_size() + kv_cache[1].numel() * kv_cache[1].element_size())
and keep the q contribution as q.numel() * q.element_size().
In `@include/flashinfer/trtllm/fmha/kernelParams.h`:
- Line 261: Replace the non-portable printf usage that prints mNumHiddenEltsO
(an int64_t) with a portable int64_t format: include <inttypes.h> if not already
included and change the printf to use "%" PRId64 and cast mNumHiddenEltsO to
(int64_t) (e.g. printf("mNumHiddenEltsO: %" PRId64 "\n",
(int64_t)mNumHiddenEltsO);) so the output is correct across LP64/LLP64 targets.
---
Nitpick comments:
In `@tests/attention/test_trtllm_gen_attention.py`:
- Around line 1040-1050: Expand the comment block around the NVFP4 tolerance
assignments (where kv_dtype is checked and rtol, atol, allowed_mismatch_rate are
set) to document why the NVFP4 values were chosen: state that FP4’s 4-bit
quantization causes significant rounding/representation error, justify the 0.3
rtol/atol and 0.05 mismatch rate by referencing any validation/benchmarks or
expected FP4 precision characteristics, note whether these thresholds were
empirically validated (and how) and include a pointer to NVFP4 specification or
analysis notes so future maintainers can reassess tightening rtol/atol or
allowed_mismatch_rate for kv_dtype == "nvfp4".
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: b9e84230-fc54-4037-9f54-3d2d505543c7
📒 Files selected for processing (11)
- benchmarks/bench_trtllm_fmha.py
- benchmarks/routines/attention.py
- benchmarks/routines/flashinfer_benchmark_utils.py
- csrc/trtllm_fmha_kernel_launcher.cu
- flashinfer/decode.py
- flashinfer/mla.py
- flashinfer/testing/__init__.py
- flashinfer/testing/kvfp4.py
- include/flashinfer/trtllm/fmha/fmhaKernels.cuh
- include/flashinfer/trtllm/fmha/kernelParams.h
- tests/attention/test_trtllm_gen_attention.py
🚧 Files skipped from review as they are similar to previous changes (4)
- include/flashinfer/trtllm/fmha/fmhaKernels.cuh
- benchmarks/routines/attention.py
- flashinfer/testing/__init__.py
- flashinfer/mla.py
cc @PerkzZheng @Tom-Zheng for viz and review~
Force-pushed from 77e4f6d to ba3b178
/bot run
[FAILED] Pipeline #46162422: 9/20 passed
Force-pushed from ba94862 to ad06df9
/bot run
@sychen52 is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww
/bot run
[SUCCESS] Pipeline #46194361: 9/20 passed
Head branch was pushed to by a user without write access
Force-pushed from ad06df9 to ebac227
- address comments
- small fix
- Add support for trtllm gen nvfp4 kv cache (none-interleave)
- Disable checksum and download
- Add unit test
- Add unit test fix
- cleanup debug prints in cpp
- try to add none HND support
- remove trtllm-native logic and small changes according to comments
- temp
- add support for prefill/context kernel as well.
- add kv_block_scales size into io
- update based on comments
Force-pushed from ebac227 to de36999
📌 Description
add nvfp4 kv cache prefill and decode kernels to flashinfer
🔍 Related Issues
#2458
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed pre-commit by running pip install pre-commit (or used your preferred method).
- Installed the hooks with pre-commit install.
- Ran pre-commit run --all-files and fixed any reported issues.

🧪 Tests
- Tests added/updated (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit