feat: add xqa backend and completes NHD/HND coverage for trtllm-gen/xqa backend #2001
Conversation
Signed-off-by: Qidi Sang <[email protected]>
Summary of Changes: Hello @qsang-nv, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces a new XQA (eXtended Query Attention) backend to optimize batch decoding operations within the FlashInfer library.
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.
Walkthrough: Threads runtime PDL control and stride-based paged-KV handling through Python APIs, C++ bindings, and CUDA kernels; adds kv_layout (HND/NHD) with normalization paths; extends tensor maps and kernel/launcher signatures to accept enable_pdl and kv_stride parameters; updates tests to cover layouts and PDL permutations.
Sequence Diagram(s)
sequenceDiagram
autonumber
participant Py as Python API
participant Bind as C++ Binding
participant Host as Launch/Host
participant CU as CUDA Kernel
Py->>Bind: xqa_batch_decode_with_kv_cache(q, k_cache, v_cache, kv_layout, enable_pdl)
Bind->>Bind: resolve enable_pdl if None (device_support_pdl)
alt kv_layout == "HND"
Bind->>Bind: transpose K/V (HND → NHD) for xqa path
end
Bind->>Bind: extract kv_stride_page/token/head (in elements)
Bind->>Host: call launch*(..., enable_pdl, kv_stride_page, kv_stride_token, kv_stride_head, stream)
Host->>Host: makeLaunchConfig(..., enable_pdl)
Host->>Host: makeTensorMapForPagedKVCache(..., stride_page, stride_token, stride_head)
Host->>CU: kernel(..., stride_page_in_heads, stride_token_in_heads, stride_head_in_heads)
CU->>CU: stride-based addressing selects page/head/token
CU-->>Host: outputs
Host-->>Bind: return tensor
Bind-->>Py: result
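To make the flow above concrete, here is a small usage sketch (not code from the PR) of the xqa_batch_decode_with_kv_cache entry point with the kv_layout and enable_pdl knobs the walkthrough describes; the module path, tensor shapes, dtypes, and scale values are illustrative assumptions.

```python
# Hedged sketch: calling the new xqa batch-decode path. Shapes and values are assumptions.
import torch
import flashinfer

batch_size, num_q_heads, num_kv_heads, head_dim = 4, 32, 8, 128
page_size, pages_per_seq = 16, 8
num_pages = batch_size * pages_per_seq

query = torch.randn(batch_size, num_q_heads, head_dim, dtype=torch.bfloat16, device="cuda")
# Paged KV cache, NHD layout: [num_pages, 2, page_size, num_kv_heads, head_dim]
kv_cache = torch.randn(
    num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.bfloat16, device="cuda"
)
block_tables = torch.arange(num_pages, dtype=torch.int32, device="cuda").reshape(
    batch_size, pages_per_seq
)
seq_lens = torch.full((batch_size,), 100, dtype=torch.int32, device="cuda")
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")

out = flashinfer.decode.xqa_batch_decode_with_kv_cache(
    query,
    kv_cache,
    workspace_buffer,
    block_tables,
    seq_lens,
    max_seq_len=page_size * pages_per_seq,
    bmm1_scale=1.0 / head_dim**0.5,
    bmm2_scale=1.0,
    kv_layout="NHD",   # "HND" caches are normalized internally per the walkthrough
    enable_pdl=None,   # None -> resolved from device_support_pdl at runtime
)
```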
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning), ✅ Passed checks (2 passed)
Code Review
This pull request introduces a new xqa backend for decoding, which is enabled for specific GPU architectures. The changes are a good addition, but I've identified a few issues. There's a bug in the compute capability check that restricts the new backend from running on all supported GPUs. I also found a critical data type mismatch for the semaphore buffer being passed to the new xqa functions, which could lead to runtime errors. Additionally, there are some hardcoded values in scale calculations that should be replaced with function arguments to improve code maintainability.
flashinfer/decode.py
Outdated
    xqa(
        query,
        k_cache.reshape(-1, head_dim),
        v_cache.reshape(-1, head_dim),
        block_tables,
        seq_lens,
        out,
        workspace_0,
        workspace_1,
        num_kv_heads,
        page_size,
        sinks=sinks,
        q_scale=q_scale_value,
        kv_scale=torch.tensor(
            [kv_scale_value], dtype=torch.float32, device=query.device
        ),
        sliding_win_size=window_left + 1 if window_left >= 0 else 0,
        sm_count=sm_count,
    )
The xqa function expects the semaphores argument to be a torch.uint32 tensor, but workspace_1 has a torch.uint8 dtype because it's a chunk of workspace_buffer. This type mismatch will cause issues in the xqa kernel. You should reinterpret the tensor view to the correct dtype before passing it.
Suggested change:
    xqa(
        query,
        k_cache.reshape(-1, head_dim),
        v_cache.reshape(-1, head_dim),
        block_tables,
        seq_lens,
        out,
        workspace_0,
        workspace_1.view(torch.uint32),
        num_kv_heads,
        page_size,
        sinks=sinks,
        q_scale=q_scale_value,
        kv_scale=torch.tensor(
            [kv_scale_value], dtype=torch.float32, device=query.device
        ),
        sliding_win_size=window_left + 1 if window_left >= 0 else 0,
        sm_count=sm_count,
    )
flashinfer/decode.py
Outdated
    xqa_mla(
        query,
        kv_cache,
        kv_cache,
        block_tables,
        seq_lens,
        out,
        workspace_0,
        workspace_1,
        block_size,
        q_scale=q_scale_value,
        kv_scale=torch.tensor(
            [kv_scale_value], dtype=torch.float32, device=query.device
        ),
        sm_count=sm_count,
    )
The xqa_mla function expects the semaphores argument to be a torch.uint32 tensor, but workspace_1 has a torch.uint8 dtype because it's a chunk of workspace_buffer. This type mismatch will cause issues in the xqa_mla kernel. You should reinterpret the tensor view to the correct dtype before passing it.
Suggested change:
    xqa_mla(
        query,
        kv_cache,
        kv_cache,
        block_tables,
        seq_lens,
        out,
        workspace_0,
        workspace_1.view(torch.uint32),
        block_size,
        q_scale=q_scale_value,
        kv_scale=torch.tensor(
            [kv_scale_value], dtype=torch.float32, device=query.device
        ),
        sm_count=sm_count,
    )
flashinfer/decode.py
Outdated
    )
    # To decide if using xqa to decode
    if (
        get_compute_capability(torch.device(device="cuda"))[0] in [9, 12]
The compute capability check for using the xqa backend is missing support for SM100 (compute capability 10) and SM110 (compute capability 11). The xqa kernel supports SM90, SM100, SM110, and SM120. The condition should be updated to include 10 and 11 to enable the backend on all supported architectures.
Suggested change:
-       get_compute_capability(torch.device(device="cuda"))[0] in [9, 12]
+       get_compute_capability(torch.device(device="cuda"))[0] in [9, 10, 11, 12]
flashinfer/decode.py
Outdated
    if bmm1_scale_log2_tensor is not None and bmm2_scale_tensor is not None:
        kv_scale_value = bmm2_scale_tensor.item()
        q_scale_value = (
            bmm1_scale_log2_tensor.item()
            / kv_scale_value
            * ((128 + 64) ** 0.5)
            * math.log2(math.e)
        )
    else:
        kv_scale_value = bmm2_scale if bmm2_scale is not None else 1.0
        q_scale_value = (
            bmm1_scale / kv_scale_value * ((128 + 64) ** 0.5)
            if bmm1_scale is not None
            else 1.0
        )
The scale calculation for q_scale_value uses a hardcoded value (128 + 64). This corresponds to qk_nope_head_dim + qk_rope_head_dim. Using the function arguments qk_nope_head_dim and qk_rope_head_dim will improve code readability and maintainability.
Suggested change:
    if bmm1_scale_log2_tensor is not None and bmm2_scale_tensor is not None:
        kv_scale_value = bmm2_scale_tensor.item()
        q_scale_value = (
            bmm1_scale_log2_tensor.item()
            / kv_scale_value
            * ((qk_nope_head_dim + qk_rope_head_dim) ** 0.5)
            * math.log2(math.e)
        )
    else:
        kv_scale_value = bmm2_scale if bmm2_scale is not None else 1.0
        q_scale_value = (
            bmm1_scale / kv_scale_value * ((qk_nope_head_dim + qk_rope_head_dim) ** 0.5)
            if bmm1_scale is not None
            else 1.0
        )
Code Review
This pull request introduces a new XQA backend for decoding, which is a valuable addition. The implementation is mostly correct, but I've identified a few issues that need to be addressed. These include a missing compute capability in a conditional check, a critical data type mismatch for semaphore buffers that could lead to runtime errors, and the use of hardcoded values where function parameters should be used for better maintainability. Please see the detailed comments for specific suggestions.
flashinfer/decode.py
Outdated
    xqa(
        query,
        k_cache.reshape(-1, head_dim),
        v_cache.reshape(-1, head_dim),
        block_tables,
        seq_lens,
        out,
        workspace_0,
        workspace_1,
        num_kv_heads,
        page_size,
        sinks=sinks,
        q_scale=q_scale_value,
        kv_scale=torch.tensor(
            [kv_scale_value], dtype=torch.float32, device=query.device
        ),
        sliding_win_size=window_left + 1 if window_left >= 0 else 0,
        sm_count=sm_count,
    )
The semaphores parameter of the xqa function expects a torch.uint32 tensor, but it's being passed workspace_1, which is a chunk of a torch.uint8 tensor. This type mismatch will likely cause a runtime error or incorrect behavior. You should view the tensor as torch.uint32 before passing it to the function.
Suggested change:
    xqa(
        query,
        k_cache.reshape(-1, head_dim),
        v_cache.reshape(-1, head_dim),
        block_tables,
        seq_lens,
        out,
        workspace_0,
        workspace_1.view(torch.uint32),
        num_kv_heads,
        page_size,
        sinks=sinks,
        q_scale=q_scale_value,
        kv_scale=torch.tensor(
            [kv_scale_value], dtype=torch.float32, device=query.device
        ),
        sliding_win_size=window_left + 1 if window_left >= 0 else 0,
        sm_count=sm_count,
    )
flashinfer/decode.py
Outdated
    xqa_mla(
        query,
        kv_cache,
        kv_cache,
        block_tables,
        seq_lens,
        out,
        workspace_0,
        workspace_1,
        block_size,
        q_scale=q_scale_value,
        kv_scale=torch.tensor(
            [kv_scale_value], dtype=torch.float32, device=query.device
        ),
        sm_count=sm_count,
    )
The semaphores parameter of the xqa_mla function expects a torch.uint32 tensor, but it's being passed workspace_1, which is a chunk of a torch.uint8 tensor. This type mismatch will likely cause a runtime error or incorrect behavior. You should view the tensor as torch.uint32 before passing it to the function.
Suggested change:
    xqa_mla(
        query,
        kv_cache,
        kv_cache,
        block_tables,
        seq_lens,
        out,
        workspace_0,
        workspace_1.view(torch.uint32),
        block_size,
        q_scale=q_scale_value,
        kv_scale=torch.tensor(
            [kv_scale_value], dtype=torch.float32, device=query.device
        ),
        sm_count=sm_count,
    )
flashinfer/decode.py
Outdated
    if (
        get_compute_capability(torch.device(device="cuda"))[0] in [9, 12]
        and out_dtype != "nvfp4"
        and query.dtype in [torch.float16, torch.bfloat16]
    ):
The check for compute capability is missing support for SM100 (compute capability 10). The xqa implementation in flashinfer/xqa.py indicates support for SM90, SM100, and SM120, which correspond to compute capabilities 9, 10, and 12. This condition should be updated to include 10 to enable the XQA path on SM100 GPUs.
Suggested change:
    if (
-       get_compute_capability(torch.device(device="cuda"))[0] in [9, 12]
+       get_compute_capability(torch.device(device="cuda"))[0] in [9, 10, 12]
        and out_dtype != "nvfp4"
        and query.dtype in [torch.float16, torch.bfloat16]
    ):
flashinfer/decode.py
Outdated
    if bmm1_scale_log2_tensor is not None and bmm2_scale_tensor is not None:
        kv_scale_value = bmm2_scale_tensor.item()
        q_scale_value = (
            bmm1_scale_log2_tensor.item()
            / kv_scale_value
            * ((128 + 64) ** 0.5)
            * math.log2(math.e)
        )
    else:
        kv_scale_value = bmm2_scale if bmm2_scale is not None else 1.0
        q_scale_value = (
            bmm1_scale / kv_scale_value * ((128 + 64) ** 0.5)
            if bmm1_scale is not None
            else 1.0
        )
The head dimension is hardcoded as (128 + 64). This makes the code less maintainable. You should use the function parameters qk_nope_head_dim and qk_rope_head_dim instead, which are available in this function's scope.
Suggested change:
    if bmm1_scale_log2_tensor is not None and bmm2_scale_tensor is not None:
        kv_scale_value = bmm2_scale_tensor.item()
        q_scale_value = (
            bmm1_scale_log2_tensor.item()
            / kv_scale_value
            * ((qk_nope_head_dim + qk_rope_head_dim) ** 0.5)
            * math.log2(math.e)
        )
    else:
        kv_scale_value = bmm2_scale if bmm2_scale is not None else 1.0
        q_scale_value = (
            bmm1_scale / kv_scale_value * ((qk_nope_head_dim + qk_rope_head_dim) ** 0.5)
            if bmm1_scale is not None
            else 1.0
        )
Actionable comments posted: 2
📜 Review details
📒 Files selected for processing (1)
flashinfer/decode.py(3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/decode.py (2)
flashinfer/utils.py (1)
get_compute_capability (251-254)
include/flashinfer/trtllm/common.h (1)
device (83-90)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (1)
flashinfer/decode.py (1)
24-24: LGTM - Import statement is correct. The xqa and xqa_mla imports are properly used in the conditional decode paths added below.
Hi @qsang-nv, currently test_trtllm_gen_attention.py and test_trtllm_gen_mla.py check for SM 100f and skip the tests otherwise.
The current changes to decode.py will only impact the SM90 and SM120f cases, so there are currently no unit tests exercising this change. Can you add them?
Signed-off-by: Qidi Sang <[email protected]>
Signed-off-by: Qidi Sang <[email protected]>
Actionable comments posted: 3
♻️ Duplicate comments (2)
flashinfer/decode.py (2)
2518-2521: Use the query tensor's device when checking compute capability. This condition still probes cuda:0 unconditionally. On multi-GPU systems, a request running on another device will pick the wrong capability and skip the XQA-MLA fast path. The earlier review already flagged this; please switch to query.device.
- get_compute_capability(torch.device(device="cuda"))[0] == 12
+ get_compute_capability(query.device)[0] == 12
2345-2368: Fix the semaphore buffer dtype passed to XQA. xqa expects its semaphore buffer to be torch.uint32, but workspace_1 remains torch.int8 after torch.chunk. Reinterpreting without converting will trip kernel-side type checks or yield garbage. Cast the second chunk to the correct view before passing it along.
  workspace_0, workspace_1 = torch.chunk(workspace_buffer, 2, dim=0)
+ semaphores = workspace_1.view(torch.uint32)
@@
- workspace_1,
+ semaphores,
🧹 Nitpick comments (1)
flashinfer/xqa.py (1)
276-276: Inconsistent default value for cached function parameter.
get_xqa_module_mla has enable_pdl: bool = True with a default value, while get_xqa_module at line 40 has enable_pdl: bool without a default. For functions decorated with @functools.cache, it's better to be consistent: either both should have defaults or neither should. Cached functions typically avoid defaults to ensure all parameters are explicit in the cache key. Consider removing the default value for consistency:
- enable_pdl: bool = True,
+ enable_pdl: bool,
📜 Review details
📒 Files selected for processing (6)
flashinfer/aot.py (2 hunks)
flashinfer/decode.py (4 hunks)
flashinfer/jit/xqa.py (6 hunks)
flashinfer/xqa.py (10 hunks)
tests/attention/test_xqa.py (5 hunks)
tests/attention/test_xqa_batch_decode.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
flashinfer/jit/xqa.py (1)
flashinfer/jit/core.py (1)
gen_jit_spec(287-353)
flashinfer/xqa.py (2)
flashinfer/jit/core.py (1)
build_and_load (272-284)
flashinfer/utils.py (2)
register_custom_op (272-281)
register_custom_op (291-310)
tests/attention/test_xqa_batch_decode.py (3)
tests/test_helpers/sink_attention_reference.py (1)
sink_attention_unified (39-402)
flashinfer/utils.py (1)
get_compute_capability (251-254)
flashinfer/decode.py (4)
BatchDecodeWithPagedKVCacheWrapper (582-1411)
use_tensor_cores (780-781)
use_tensor_cores (1577-1578)
xqa_batch_decode_with_kv_cache (2254-2377)
flashinfer/decode.py (2)
flashinfer/xqa.py (4)
xqa (56-93)
xqa (124-265)
xqa_mla (292-321)
xqa_mla (348-457)
flashinfer/utils.py (3)
device_support_pdl (568-572)
get_device_sm_count (595-596)
get_compute_capability (251-254)
🪛 Ruff (0.14.2)
tests/attention/test_xqa_batch_decode.py
277-277: Unpacked variable in_kv_lens is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
flashinfer/decode.py
2260-2260: Unused function argument: max_seq_len
(ARG001)
2266-2266: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (1)
flashinfer/xqa.py (1)
53-53: LGTM: Registration names correctly include enable_pdl. Including enable_pdl in both the custom op and fake op registration names (lines 53, 96, 289, 324) is correct and essential. This ensures that PDL-enabled and PDL-disabled kernels are registered as distinct operations, preventing cache collisions and ensuring the correct module is selected at runtime.
Also applies to: 96-96, 289-289, 324-324
Actionable comments posted: 1
♻️ Duplicate comments (2)
flashinfer/decode.py (2)
2361-2379: Critical: Thread the enable_pdl flag to the XQA kernel. The enable_pdl flag is computed at line 2320 but never passed to the xqa call. This causes the kernel to always use its default value (True), ignoring the caller's intent and device capabilities. Note that xqa_mla at line 2559 correctly passes this parameter. Apply this fix:
      sliding_win_size=window_left + 1 if window_left >= 0 else 0,
      sm_count=sm_count,
+     enable_pdl=enable_pdl,
  )
2521-2526: Fix device handling in the compute capability check. Line 2522 uses torch.device(device="cuda"), which defaults to cuda:0, causing incorrect capability detection on multi-GPU systems when the query tensor resides on a different GPU. Apply this fix:
  if (
-     get_compute_capability(torch.device(device="cuda"))[0] == 12
+     get_compute_capability(query.device)[0] == 12
      and query.dtype == torch.float8_e4m3fn
🧹 Nitpick comments (1)
flashinfer/decode.py (1)
2264-2264: Consider whether max_seq_len should be used. Static analysis flags that max_seq_len is unused in this function. If the parameter is genuinely not needed by the xqa kernel (which infers max_seq_len from block_tables and page_size), consider removing it from the signature to avoid confusion. If it's intended for future use or validation, add a comment explaining why it's currently unused.
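If the argument is dropped, the bound this nitpick refers to can be recovered from the page table alone; the helper below is hypothetical (an assumption about the paging math, not code from the PR):

```python
import torch


def inferred_max_seq_len(block_tables: torch.Tensor, page_size: int) -> int:
    # block_tables: [batch_size, max_pages_per_seq]; each sequence fits within its pages,
    # so pages-per-sequence times page_size upper-bounds any sequence length.
    return block_tables.shape[1] * page_size
```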
📜 Review details
📒 Files selected for processing (1)
flashinfer/decode.py(3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/decode.py (2)
flashinfer/xqa.py (4)
xqa (56-93)
xqa (124-265)
xqa_mla (292-321)
xqa_mla (348-457)
flashinfer/utils.py (3)
device_support_pdl (568-572)
get_device_sm_count (595-596)
get_compute_capability (251-254)
🪛 Ruff (0.14.2)
flashinfer/decode.py
2264-2264: Unused function argument: max_seq_len
(ARG001)
2270-2270: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
Signed-off-by: Qidi Sang <[email protected]>
Actionable comments posted: 1
♻️ Duplicate comments (1)
flashinfer/decode.py (1)
2349-2379: Pass the right buffers into xqa and honor the caller's output/PDL choices. We're still hitting the same blockers that earlier reviews called out:
- out defaults to None, but xqa requires a real tensor; callers that rely on the optional output will crash.
- The semaphore chunk stays uint8; the kernel expects a uint32 view, so dtype checks fail.
- enable_pdl never reaches xqa, forcing PDL on even when the caller disables it or the device doesn't support it.
Please fix all three together.
  workspace_0, workspace_1 = torch.chunk(workspace_buffer, 2, dim=0)
  kv_scale_value = bmm2_scale
  q_scale_value = bmm1_scale / kv_scale_value * (head_dim**0.5)
+ semaphores = workspace_1.view(torch.uint32)
+ if out is None:
+     out = torch.empty_like(query)
+
  k_cache_new = k_cache.reshape(-1, head_dim).contiguous()
  v_cache_new = v_cache.reshape(-1, head_dim).contiguous()
  query_new = query.unsqueeze(1).contiguous()
  seq_lens_new = seq_lens.unsqueeze(1).contiguous()
  sinks_new = (
@@
      workspace_0,
-     workspace_1,
+     semaphores,
      num_kv_heads,
      page_size,
      sinks=sinks_new,
      q_scale=q_scale_value,
      kv_scale=torch.tensor(
          [kv_scale_value], dtype=torch.float32, device=query.device
      ),
      sliding_win_size=window_left + 1 if window_left >= 0 else 0,
      sm_count=sm_count,
+     enable_pdl=enable_pdl,
  )
📜 Review details
📒 Files selected for processing (3)
csrc/xqa/mha_sm90.cu (1 hunks)
flashinfer/decode.py (3 hunks)
tests/attention/test_xqa_batch_decode.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/attention/test_xqa_batch_decode.py (3)
tests/test_helpers/sink_attention_reference.py (1)
sink_attention_unified (39-402)
flashinfer/utils.py (1)
get_compute_capability (251-254)
flashinfer/decode.py (1)
xqa_batch_decode_with_kv_cache(2258-2381)
flashinfer/decode.py (2)
flashinfer/xqa.py (2)
xqa (56-93)
xqa (124-265)
flashinfer/utils.py (2)
device_support_pdl (568-572)
get_device_sm_count (595-596)
🪛 Ruff (0.14.2)
tests/attention/test_xqa_batch_decode.py
277-277: Unpacked variable in_kv_lens is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
flashinfer/decode.py
2264-2264: Unused function argument: max_seq_len
(ARG001)
2270-2270: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (1)
csrc/xqa/mha_sm90.cu (1)
1969-1974: Clamping behavior verified; the refactor is safe. The "WithDup" naming is a deliberate design pattern used across similar load functions in the codebase (loadShmColWiseVecWithDup, loadShmRowWiseVecWithDup). The clamping behavior is intentional: when thread indices exceed bound, they clamp to the last valid index and duplicate its value. The semantics are correct:
- bound = headGrpSize - 1 (passed at line 2743) represents the last valid index (0-indexed, inclusive)
- Accessing gmemVec[(headGrpSize - 1) * GmmaAccCoreMat::cols + j] is valid per the attentionSinksVec initialization (lines 1228–1231)
- The refactor from reinterpret_cast to explicit element-wise copying improves type safety without changing behavior
tests/attention/test_xqa_batch_decode.py
    # Create separate K and V caches with NHD layout
    max_seq_len = torch.max(seq_lens).item()
    num_tokens = max_seq_len * batch_size
    num_pages = (num_tokens + page_size - 1) // page_size
    ref_kv_dtype_torch = DTYPE_MAP[ref_kv_dtype]
    if kv_dtype != "fp8":
        assert kv_dtype == ref_kv_dtype, (
            "kv_dtype and ref_kv_dtype must be the same for non-fp8 kv_cache"
        )

    # NHD layout: [num_pages, page_size, num_kv_heads, head_dim]
    k_cache = torch.randn(
        num_pages,
        page_size,
        num_kv_heads,
        head_dim,
        dtype=ref_kv_dtype_torch,
        device=GPU_DEVICE,
    )
    v_cache = torch.randn(
        num_pages,
        page_size,
        num_kv_heads,
        head_dim,
        dtype=ref_kv_dtype_torch,
        device=GPU_DEVICE,
    )

    # Convert K and V separately to fp8 if needed
    if kv_dtype == "fp8":
        k_cache, k_scale = to_float8(k_cache / 4.0)
        v_cache, v_scale = to_float8(v_cache / 4.0)
        # use high precision and fake-quantization for reference to avoid precision/functional issue
        ref_kv_cache = torch.stack(
            [
                k_cache.to(ref_kv_dtype_torch) * k_scale,
                v_cache.to(ref_kv_dtype_torch) * v_scale,
            ],
            dim=1,
        )
    else:
        k_scale = v_scale = 1.0
        ref_kv_cache = torch.stack([k_cache, v_cache], dim=1)
    # Combine K and V into interleaved format for the API
    kv_cache = torch.stack([k_cache, v_cache], dim=1)

    return kv_cache, k_scale, v_scale, ref_kv_cache
Ensure the synthetic KV cache actually covers every page ID we emit
num_pages is computed from max_seq_len * batch_size, which can be smaller than the sum of ceil(seq_len_i / page_size) across the batch. Example: page_size=16, two requests with seq_lens=[17, 17] ⇒ max_seq_len * batch_size = 34 so we allocate 3 pages, but the page tables legitimately reference 4 distinct pages. That produces out-of-range page IDs (crashing both the reference path and the kernel under test). Please size the cache by the total pages needed.
- max_seq_len = torch.max(seq_lens).item()
- num_tokens = max_seq_len * batch_size
- num_pages = (num_tokens + page_size - 1) // page_size
+ page_per_seq = (seq_lens + page_size - 1) // page_size
+ num_pages = int(page_per_seq.sum().item())
+ max_seq_len = int(seq_lens.max().item())🤖 Prompt for AI Agents
In tests/attention/test_xqa_batch_decode.py around lines 69 to 115, the
synthetic KV cache is undersized because num_pages is computed from max_seq_len
* batch_size rather than the actual total pages referenced by all sequences;
compute num_pages as the sum over batch of ceil(seq_lens[i] / page_size) (i.e.,
total_pages = sum((seq_lens + page_size - 1)//page_size)) and use that value to
allocate k_cache, v_cache, and any dependent quantities so every page ID emitted
by the page tables is covered.
Dismissing to remove the "Request Change" tag after concerns are addressed
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
csrc/xqa/mha.cu (1)
1406-1410: Add a runtime enable_pdl guard to the PDL intrinsics and plumb the parameter through kernel_mha. The kernel_mha function lacks an enable_pdl parameter. Both launchMHA and launchMHAFlashInfer possess this flag but fail to pass it to kernel_mha, causing unconditional PDL intrinsic execution at lines 1406-1410 when ENABLE_PDL is defined, regardless of the runtime enable_pdl value.
Required changes:
- Add a bool enable_pdl parameter to the kernel_mha signature (mha.cu:2455)
- Guard the PDL intrinsics with the runtime flag at lines 1406-1410:
  #if ENABLE_PDL
+   if (enable_pdl) {
      preExit();
      acqBulk();
+   }
  #endif
- Pass enable_pdl from launchMHA to kernel_mha (lines 2577, 2627)
- Pass enable_pdl from launchMHAFlashInfer to kernel_mha (lines 2706, 2730)
csrc/xqa/mha_sm90.cu (1)
1254-1260: Add a runtime guard for PDL intrinsics or remove the compile-time conditional compilation. The review comment is accurate. PDL intrinsics (preExit() and acqBulk()) in both mha_sm90.cu (lines 1255, 1258, 1286, 1307, 1343, 1412) and mha.cu (lines 1407–1408) are currently gated only by compile-time ENABLE_PDL macros. The enable_pdl runtime parameter exists at the host level (passed to makeLaunchConfig()) but is not passed as a kernel parameter, so kernels cannot access it to conditionally execute these intrinsics.
This creates a runtime/compile-time divergence: launching with enable_pdl=false has no effect if the code was compiled with ENABLE_PDL defined. To resolve:
- Add bool enable_pdl as a kernel parameter to all affected kernels (kernel_mha_impl in mha.cu, kernel_mha in both files)
- Thread it through the launch paths
- Wrap intrinsic calls: if (enable_pdl) { preExit(); acqBulk(); }
flashinfer/xqa.py (1)
99-118: Missing enable_pdl parameter in the fake op signature. The _fake_xqa function is missing the enable_pdl: bool parameter that was added to the corresponding custom op at line 73. Fake ops must match the signature of their real counterparts for proper torch.compile integration. Apply this diff:
  def _fake_xqa(
      run_sm90_fp8_mha: bool,
      sm_count: int,
      num_kv_heads: int,
      sliding_win_size: int,
      q_scale: float,
      output: torch.Tensor,
      q: torch.Tensor,
      sinks: Optional[torch.Tensor],
      k_cache: torch.Tensor,
      v_cache: torch.Tensor,
      page_table: torch.Tensor,
      max_seq_len: int,
      seq_lens: torch.Tensor,
      batch_size: int,
      kv_scale: torch.Tensor,
      semaphores: torch.Tensor,
      workspace_buffer: torch.Tensor,
+     enable_pdl: bool,
  ) -> None:
      pass
🧹 Nitpick comments (3)
csrc/xqa/mha_sm90.cu (2)
3039-3077: Host launcher: enable_pdl passed to launch config (good). Passing enable_pdl into makeLaunchConfig is correct. Once the kernel accepts enable_pdl, also pass it as a kernel argument.
3179-3215: FlashInfer launcher: enable_pdl passed (good). Same note: forward enable_pdl into the kernel parameter when added.
csrc/xqa/mha.cu (1)
2593-2594: Host launch: enable_pdl wired to launch config (good). Once the kernel accepts enable_pdl, also append it to the cudaLaunchKernelEx argument list and the kernel signature.
Also applies to: 2700-2701
📜 Review details
📒 Files selected for processing (8)
csrc/flashinfer_xqa_binding.cu (2 hunks)
csrc/xqa/defines.h (1 hunks)
csrc/xqa/mha.cu (5 hunks)
csrc/xqa/mha.h (6 hunks)
csrc/xqa/mha_sm90.cu (5 hunks)
csrc/xqa/mla_sm120.cu (4 hunks)
csrc/xqa/xqa_wrapper.cu (4 hunks)
flashinfer/xqa.py (14 hunks)
🧰 Additional context used
🧬 Code graph analysis (5)
csrc/xqa/mla_sm120.cu (1)
csrc/xqa/hostUtils.h (1)
makeLaunchConfig(4-12)
csrc/flashinfer_xqa_binding.cu (1)
csrc/xqa/mha_sm90.cu (2)
scratch(506-513)scratch(506-506)
flashinfer/xqa.py (1)
flashinfer/utils.py (1)
device_support_pdl(568-572)
csrc/xqa/mha.cu (1)
csrc/xqa/hostUtils.h (1)
makeLaunchConfig(4-12)
csrc/xqa/mha_sm90.cu (1)
csrc/xqa/hostUtils.h (1)
makeLaunchConfig(4-12)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (12)
csrc/xqa/xqa_wrapper.cu (2)
31-48: MLA wrapper: enable_pdl correctly threaded. The parameter is appended and forwarded to launchMLAFlashInfer before stream. Looks good.
Please confirm all FFI callers are updated to pass the new trailing bool to xqa_wrapper_mla.
67-95: MHA wrapper: enable_pdl correctly threaded. The parameter is appended and propagated to the MHA launcher selection. No issues spotted.
Please confirm Python/TVM bindings and tests now pass enable_pdl to xqa_wrapper.
csrc/flashinfer_xqa_binding.cu (2)
20-31: Binding prototypes match the wrapper signature change. Signatures include the trailing enable_pdl; exported names are unchanged. Looks good.
Ensure downstream TVM packed-call sites are regenerated to include the new bool (last arg).
34-53: Binding for xqa_wrapper: OK. Prototype and export aligned with wrapper changes.
csrc/xqa/mha.cu (1)
92-102: Add SM 1210 path: OK. Including CUDA_ARCH==1210 alongside 860/890/1200 is consistent with the existing pattern.
Please confirm the performance tuning (preferedKHeadPartBytes=64, cacheVTileSeqLen=32) is valid for 12.1 as well.
csrc/xqa/mla_sm120.cu (2)
1727-1727: LGTM! Clean transition from compile-time to runtime configuration. The function signature now accepts enable_pdl as a runtime parameter, and it's correctly propagated to makeLaunchConfig. This replaces the previous compile-time ENABLE_PDL != 0 check with dynamic configuration. Also applies to: 1765-1765
1864-1864: LGTM! Consistent with the launchMLA changes. The same pattern is applied to launchMLAFlashInfer, threading enable_pdl through to the launch configuration. Also applies to: 1888-1888
flashinfer/xqa.py (4)
29-29: LGTM! Appropriate import for hardware capability detection. The device_support_pdl import is used to automatically detect whether the device supports PDL based on compute capability (SM90+).
73-73: LGTM! Parameter correctly threaded through the custom op. The enable_pdl parameter is properly added to the custom op signature and propagated to the underlying wrapper. Also applies to: 93-93
411-413: Same terminology verification needed for xqa_mla. The same PDL acronym expansion issue applies here. Ensure consistency with the verified terminology.
141-141: LGTM! Parameter correctly integrated with smart defaulting. The enable_pdl parameter is properly added as optional, defaults to hardware capability detection via device_support_pdl, and is correctly propagated through the call chain (see the sketch below). Also applies to: 215-215, 271-271
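The "smart defaulting" praised in the comment above boils down to a few lines; this is a sketch under the assumption that device_support_pdl is importable from flashinfer.utils as the review's code graph indicates, not a copy of the PR's code:

```python
from typing import Optional

import torch
from flashinfer.utils import device_support_pdl  # SM90+ capability check per the review


def resolve_enable_pdl(q: torch.Tensor, enable_pdl: Optional[bool]) -> bool:
    # None means "decide from the hardware"; an explicit bool is passed through unchanged.
    return device_support_pdl(q.device) if enable_pdl is None else enable_pdl
```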
csrc/xqa/mha.h (1)
131-131: LGTM! Consistent parameter addition across all launch functions. All six launch function declarations correctly add the bool enable_pdl parameter in a consistent position (before the cudaStream_t stream parameter). This aligns with the implementation changes in the corresponding .cu files. Also applies to: 150-150, 192-192, 211-212, 234-234, 252-252
csrc/xqa/defines.h
    #ifndef ENABLE_PDL
    #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
    #if __CUDA_ARCH__ == 900
    #define ENABLE_PDL 2
    #else
    #define ENABLE_PDL 1
    #endif
    #else
    /* default for host or older architectures */
    #define ENABLE_PDL 0
    #endif
    #endif
Runtime-vs-compile-time PDL mismatch
ENABLE_PDL defaults to 1/2 for SM≥90 at compile time, but kernels still use this macro while the host now passes a runtime enable_pdl. Host compilation sees ENABLE_PDL==0 (no CUDA_ARCH), so kernels may execute preExit/acqBulk even when enable_pdl=false at launch. This is inconsistent and can lead to invalid usage of programmatic stream serialization.
Thread enable_pdl into the kernels and guard PDL intrinsics with the runtime flag (while keeping the arch guards). See follow-up diffs in kernel files below.
🤖 Prompt for AI Agents
In csrc/xqa/defines.h around lines 131-142, the current ENABLE_PDL macro
selection based solely on __CUDA_ARCH__ causes kernels to unconditionally
compile PDL intrinsics even though the host uses a runtime enable_pdl flag;
update the headers and usages so that kernels still respect the arch guards but
also check a passed-in runtime boolean (e.g., enable_pdl) before invoking PDL
intrinsics: keep the existing __CUDA_ARCH__ checks to determine PDL availability
at compile time, expose a runtime enable_pdl parameter into kernels (thread it
through kernel arguments or capture in device lambdas), and wrap all calls to
preExit/acqBulk (and other PDL intrinsics) with a combined condition that
requires both compile-time availability and the runtime flag so that when
enable_pdl is false at launch no PDL intrinsics execute on device.
csrc/xqa/mha_sm90.cu
    uint32_t const clampedIdx = mha::min(i * nbThrdsPerInstNBase + idx, bound);
    uint32_t const baseOffset = clampedIdx * GmmaAccCoreMat::cols;
    #pragma unroll
    for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) {
      ret[i][j] = gmemVec[baseOffset + j];
    }
  }
Out‑of‑bounds read in loadGmemColWiseVecWithDup for attention sinks
gmemVec points to a buffer of size headGrpSize (see finalizeAndWriteOut_sync passing attentionSinksVec[0]), but this code multiplies the index by GmmaAccCoreMat::cols and reads baseOffset+j, which can exceed headGrpSize. We should load a single sink value per head and duplicate it across columns, without advancing memory by cols.
Apply this fix:
- uint32_t const clampedIdx = mha::min(i * nbThrdsPerInstNBase + idx, bound);
- uint32_t const baseOffset = clampedIdx * GmmaAccCoreMat::cols;
-#pragma unroll
- for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) {
- ret[i][j] = gmemVec[baseOffset + j];
- }
+ uint32_t const clampedIdx = mha::min(i * nbThrdsPerInstNBase + idx, bound);
+#pragma unroll
+ for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) {
+ // Duplicate the same head sink across the 2 columns
+ ret[i][j] = gmemVec[clampedIdx];
    kv_scale: Optional[torch.Tensor] = None,
    sliding_win_size: int = 0,
    sm_count: Optional[int] = None,
    enable_pdl: Optional[bool] = None,
Do we want to use the decorator for the support check?
@backend_requirement ?
Yes using backend_requirement would be preferred. There are lots of functions with this flag, we can clean them up in a standalone PR.
/bot run
Actionable comments posted: 0
♻️ Duplicate comments (6)
flashinfer/decode.py (6)
2152-2163: Critical: Squeeze the singleton dimension when unpacking kv_cache. When kv_cache.shape[1] == 1 at line 2156, both k_cache and v_cache are assigned to the same tensor without squeezing the singleton dimension. This preserves shape [num_pages, 1, ...] instead of the expected [num_pages, ...]. Downstream operations at lines 2200-2205 that transpose dimensions will operate on the wrong axes. Apply this fix:
  if isinstance(kv_cache, tuple):
      k_cache, v_cache = kv_cache
  else:
      if kv_cache.shape[1] == 1:
-         k_cache, v_cache = kv_cache, kv_cache
+         k_cache = kv_cache.squeeze(1)
+         v_cache = k_cache
      else:
          assert kv_cache.shape[1] == 2, (
              "When kv_cache is a single tensor, the second dimension must be 1 or 2"
          )
          # NOTE(Zihao): unbind transforms [num_pages, 2, ...] to ([num_pages, ...], [num_pages, ...])
          # it doesn't change underlying storage
          k_cache, v_cache = kv_cache.unbind(dim=1)
2165-2168: Add architecture validation for the xqa backend. The auto backend selection defaults to xqa for any non-SM100 architecture, but xqa requires SM90+ (as enforced in xqa.py). This could cause runtime failures on older GPUs like SM80. Consider adding an architecture check:
  if backend == "auto":
-     backend = (
-         "trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa"
-     )
+     major = get_compute_capability(query.device)[0]
+     if major == 10:
+         backend = "trtllm-gen"
+     elif major >= 9:
+         backend = "xqa"
+     else:
+         raise GPUArchitectureError(
+             f"No compatible backend for SM{major}0. Requires SM90+ or SM100+."
+         )
2334-2343: Critical: Fix type hints, allocate output, and remove the unused parameter. Multiple issues in the function signature and parameter handling:
- Line 2341: enable_pdl: bool = None violates PEP 484. Must be Optional[bool] = None.
- Line 2334: the max_seq_len parameter is declared but never used in the function body.
- Output allocation is missing: when out=None, the function passes None directly to xqa at line 2450, but xqa expects a non-Optional torch.Tensor. This will cause a runtime error.
Apply this fix:
  def xqa_batch_decode_with_kv_cache(
      query: torch.Tensor,
      kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
      workspace_buffer: torch.Tensor,
      block_tables: torch.Tensor,
      seq_lens: torch.Tensor,
-     max_seq_len: int,
      bmm1_scale: float,
      bmm2_scale: float,
      window_left: int = -1,
      out: Optional[torch.Tensor] = None,
      sinks: Optional[torch.Tensor] = None,
      kv_layout: str = "NHD",
-     enable_pdl: bool = None,
+     enable_pdl: Optional[bool] = None,
      q_len_per_req: Optional[int] = 1,
  ) -> torch.Tensor:
      ...
      enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl
      assert q_len_per_req == 1, "xqa not support speculative decoding yet"
      ...
+     # Allocate output if not provided
+     if out is None:
+         out = torch.empty_like(query)
+
      xqa(
          query_new,
2434-2434: Critical: workspace_1 dtype mismatch with xqa expectations. Line 2434 uses torch.chunk to split workspace_buffer (uint8) into two halves, but at line 2452 workspace_1 is passed as the semaphores parameter to xqa, which expects a torch.uint32 tensor. This type mismatch will cause incorrect behavior or runtime errors. You cannot safely reinterpret uint8 as uint32 with .view(torch.uint32) because the buffer may not be properly aligned or sized. Allocate a proper uint32 semaphores tensor:
- workspace_0, workspace_1 = torch.chunk(workspace_buffer, 2, dim=0)
+ # Reserve first half for workspace_0 (uint8), allocate separate uint32 semaphores
+ workspace_0_size = workspace_buffer.numel() // 2
+ workspace_0 = workspace_buffer[:workspace_0_size]
+
+ # Allocate semaphores as uint32 (estimate size based on xqa requirements)
+ semaphores = torch.zeros(
+     workspace_buffer.numel() // 8, dtype=torch.uint32, device=query.device
+ )
+
  kv_scale_value = bmm2_scale
  q_scale_value = bmm1_scale / kv_scale_value * (head_dim**0.5)
  query_new = query.unsqueeze(1).contiguous()
  seq_lens_new = seq_lens.unsqueeze(1).contiguous()
  sinks_new = (
      sinks.reshape(num_kv_heads, -1).contiguous() if sinks is not None else None
  )
  xqa(
      query_new,
      k_cache,
      v_cache,
      block_tables,
      seq_lens_new,
      out,
      workspace_0,
-     workspace_1,
+     semaphores,
      num_kv_heads,
      page_size,
      sinks=sinks_new,
      q_scale=q_scale_value,
      kv_scale=torch.tensor(
          [kv_scale_value], dtype=torch.float32, device=query.device
      ),
      sliding_win_size=window_left + 1 if window_left >= 0 else 0,
      kv_layout=kv_layout,
      sm_count=sm_count,
  )
2444-2463: Critical: Pass the enable_pdl parameter to xqa. The enable_pdl flag is computed at line 2396 but never passed to the xqa call. The kernel will always use its default value (True), ignoring the caller's choice and the device support check. Add the parameter to the xqa call:
  xqa(
      query_new,
      k_cache,
      v_cache,
      block_tables,
      seq_lens_new,
      out,
      workspace_0,
      workspace_1,
      num_kv_heads,
      page_size,
      sinks=sinks_new,
      q_scale=q_scale_value,
      kv_scale=torch.tensor(
          [kv_scale_value], dtype=torch.float32, device=query.device
      ),
      sliding_win_size=window_left + 1 if window_left >= 0 else 0,
      kv_layout=kv_layout,
      sm_count=sm_count,
+     enable_pdl=enable_pdl,
  )
2400-2411: Critical: Squeeze the singleton dimension before assigning kv_cache. When kv_cache.shape[1] == 1, line 2404 assigns both k_cache and v_cache to the same tensor without squeezing the singleton dimension. This means k_cache has shape [num_pages, 1, ...] instead of the expected [num_pages, ...]. Downstream shape extraction at lines 2423-2432 will read the wrong dimensions. Squeeze the singleton dimension first:
  if isinstance(kv_cache, tuple):
      k_cache, v_cache = kv_cache
  else:
      if kv_cache.shape[1] == 1:
-         k_cache, v_cache = kv_cache, kv_cache
+         k_cache = kv_cache.squeeze(1)
+         v_cache = k_cache
      else:
          assert kv_cache.shape[1] == 2, (
              "When kv_cache is a single tensor, the second dimension must be 1 or 2"
          )
          # NOTE(Zihao): unbind transforms [num_pages, 2, ...] to ([num_pages, ...], [num_pages, ...])
          # it doesn't change underlying storage
          k_cache, v_cache = kv_cache.unbind(dim=1)
🧹 Nitpick comments (1)
tests/attention/test_trtllm_gen_attention.py (1)
231-237: Remove unnecessary no-op rearrange operations. In the NHD layout branch, einops.rearrange(k_page, "p h d -> p h d") and the corresponding v_page rearrange are no-ops since the source and target patterns are identical. These can be removed for clarity and a slight performance improvement. Apply this diff:
  else:  # NHD layout
      # NHD layout: [page_size, num_kv_heads, head_dim]
      if j == pages_i - 1:
          k_page = k_page[:last_len_i, :, :]
          v_page = v_page[:last_len_i, :, :]
-     k_list.append(einops.rearrange(k_page, "p h d -> p h d"))
-     v_list.append(einops.rearrange(v_page, "p h d -> p h d"))
+     k_list.append(k_page)
+     v_list.append(v_page)
📜 Review details
📒 Files selected for processing (2)
flashinfer/decode.py (9 hunks)
tests/attention/test_trtllm_gen_attention.py (23 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/decode.py (3)
flashinfer/xqa.py (2)
xqa (55-94)
xqa (125-283)
flashinfer/utils.py (7)
get_compute_capability (251-254)
FP4Tensor (599-669)
get_device_sm_count (595-596)
ceil_div (575-586)
round_up (589-591)
check_shape_dtype_device (518-536)
device_support_pdl (568-572)
csrc/trtllm_fmha_kernel_launcher.cu (2)
trtllm_paged_attention_decode (197-265)
trtllm_paged_attention_decode (197-204)
🪛 Ruff (0.14.2)
flashinfer/decode.py
2173-2173: Avoid specifying long messages outside the exception class
(TRY003)
2175-2175: Avoid specifying long messages outside the exception class
(TRY003)
2218-2218: Consider (*query.shape[:-1], ceil_div(query.shape[-1], 2)) instead of concatenation
Replace with (*query.shape[:-1], ceil_div(query.shape[-1], 2))
(RUF005)
2241-2241: Avoid specifying long messages outside the exception class
(TRY003)
2264-2268: Avoid specifying long messages outside the exception class
(TRY003)
2279-2279: Avoid specifying long messages outside the exception class
(TRY003)
2282-2282: Avoid specifying long messages outside the exception class
(TRY003)
2324-2324: Avoid specifying long messages outside the exception class
(TRY003)
2334-2334: Unused function argument: max_seq_len
(ARG001)
2341-2341: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (8)
flashinfer/decode.py (1)
2200-2205: Verify layout conversion is applied after squeeze. The HND→NHD transpose at lines 2204-2205 assumes the input has shape [..., H, N, D]. This depends on the kv_cache unpacking at lines 2152-2163 correctly removing any singleton dimensions. If the squeeze issue at line 2156 is not fixed, the transpose will operate on the wrong axes.
Once the squeeze issue is resolved, this code will be correct.
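A minimal sketch of the unpack-then-normalize order this comment is checking, assuming per-page caches shaped [num_pages, H, N, D] for HND and [num_pages, N, H, D] for NHD (a sketch, not the repository's implementation):

```python
import torch


def normalize_kv(kv_cache: torch.Tensor, kv_layout: str):
    # kv_cache: [num_pages, 1 or 2, ...] -- squeeze/unbind first, then fix the layout.
    if kv_cache.shape[1] == 1:
        k_cache = kv_cache.squeeze(1)
        v_cache = k_cache
    else:
        k_cache, v_cache = kv_cache.unbind(dim=1)
    if kv_layout == "HND":
        # [num_pages, num_kv_heads, page_size, head_dim] -> [num_pages, page_size, num_kv_heads, head_dim]
        k_cache = k_cache.transpose(-3, -2)
        v_cache = v_cache.transpose(-3, -2)
    return k_cache, v_cache
```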
tests/attention/test_trtllm_gen_attention.py (7)
93-149: LGTM! Layout-aware KV cache generation is correct.The function properly creates KV caches with the correct dimension ordering for both HND and NHD layouts.
342-342: LGTM! Test parameterization properly extended.The addition of
kv_layoutandbackendparameters ensures comprehensive coverage across different layout configurations and backend implementations.Also applies to: 571-571, 897-898
645-662: LGTM! Backend-specific GPU checks are appropriate.The GPU architecture requirements and feature limitations for each backend are correctly validated with clear skip messages.
801-804: LGTM! Backend-specific assertions are correctly applied.The workspace buffer zero-check and wrapper testing are appropriately limited to the trtllm-gen backend where these requirements apply.
Also applies to: 842-845
819-821: LGTM! Tolerance adjustment for xqa+fp8 is reasonable.The relaxed tolerances account for the precision characteristics of the xqa backend when using fp8 KV cache.
415-415: LGTM! Layout parameter properly threaded through test chain.The
kv_layoutparameter is consistently passed through all helper functions and API calls, ensuring layout coherence across cache creation, flattening, and attention operations.Also applies to: 470-470, 506-506, 690-690, 763-763, 796-796
976-976: Verify if these HND-only restrictions are still accurate.Multiple comments state "trtllm-gen only support HND", but the PR objectives indicate that trtllm-gen now supports both HND and NHD layouts. These specialized tests (batch_size=1, head_dim=256, long sequences) restrict
kv_layoutto["HND"]only.If trtllm-gen indeed supports both layouts now, consider:
- Updating these comments to clarify that the HND restriction is specific to these test scenarios (not a general limitation)
- Or extending these tests to include NHD if the backend fully supports it
Run the following to check if there are any backend-specific layout restrictions documented elsewhere:
Also applies to: 1031-1031, 1097-1097
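If the backend does handle both layouts, a possible sketch of the change (test name is hypothetical; only the parametrization differs from the current tests):

```python
import pytest


@pytest.mark.parametrize("kv_layout", ["HND", "NHD"])  # was: ["HND"] only
def test_long_sequence_decode(kv_layout):
    # Build the KV cache in `kv_layout` and run the decode path as the existing test does.
    ...
```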
[FAILED] Pipeline #37716350: 1/17 passed
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/prefill.py (1)
3398-3401: Validate kv_layout and enforce seq_lens dtype. Add early kv_layout validation (consistent with other entry points) and ensure seq_lens is int32 as expected by the kernel.
```diff
     if enable_pdl is None:
         enable_pdl = device_support_pdl(query.device)
+    _check_kv_layout(kv_layout)
+    if seq_lens.dtype != torch.int32:
+        seq_lens = seq_lens.to(torch.int32)
```
♻️ Duplicate comments (4)
csrc/trtllm_fmha_kernel_launcher.cu (1)
297-297: Same layout naming issue as the decode path. This comment has the same naming inconsistency as line 235. Apply the same clarification here for consistency.
flashinfer/prefill.py (1)
3338-3339: Default kv_layout differs from rest of prefill APIs; add validation. Several APIs here default to "NHD", but this function defaults to "HND" and doesn't validate kv_layout. Please align defaults or document the deliberate difference and add `_check_kv_layout(kv_layout)`.
flashinfer/decode.py (2)
2441-2442: Major: seq_lens must be uint32 with beam axis. Cast to uint32 to match the xqa wrapper’s contract.
```diff
-    seq_lens_new = seq_lens.unsqueeze(1).contiguous()
+    seq_lens_new = seq_lens.to(torch.uint32).unsqueeze(1).contiguous()
```
2403-2411: Critical: squeeze single-channel KV tensor before inferring shape. When kv_cache has shape [P, 1, ...], shape extraction misreads page_size/num_heads. Squeeze first.
```diff
-    if kv_cache.shape[1] == 1:
-        k_cache, v_cache = kv_cache, kv_cache
+    if kv_cache.shape[1] == 1:
+        # [num_pages, 1, ...] -> [num_pages, ...]
+        k_cache = kv_cache.squeeze(1)
+        v_cache = k_cache
```
🧹 Nitpick comments (5)
csrc/trtllm_fmha_kernel_launcher.cu (1)
235-237: Clarify the layout naming convention. The comment states "NHD layout: [..., H, N, D]", but in standard tensor notation, NHD typically refers to (N, H, D) where N is the sequence/token dimension, H is heads, and D is the dimension. The layout shown as [..., H, N, D] would conventionally be called HND layout (heads first).
This naming inconsistency could confuse developers familiar with standard conventions and lead to integration errors when interfacing with other systems or documentation.
Consider either:
- Using standard naming (change "NHD" to "HND" in the comment), or
- Adding a clarification that explains this codebase's specific naming convention
```diff
- // Assume NHD layout: [..., H, N, D]
+ // Assume HND layout: [..., H, N, D]
```
Or:
```diff
- // Assume NHD layout: [..., H, N, D]
+ // Assume layout: [..., H, N, D] (referred to as NHD in this codebase)
```
flashinfer/prefill.py (1)
3338-3341: Type the sinks parameter correctly. Kernels expect a single optional tensor for sinks, not a list. Update the annotation to avoid misuse downstream.
```diff
-    sinks: Optional[List[torch.Tensor]] = None,
+    sinks: Optional[torch.Tensor] = None,
```
flashinfer/decode.py (3)
2062-2080: API polish: sinks typing, kv_layout default consistency, Optional usage.
- sinks should be Optional[torch.Tensor], not list.
- kv_layout defaults to "HND" here while most decode/prefill APIs default to "NHD". Align or document.
- Keep Optional types explicit.
```diff
-    sinks: Optional[List[torch.Tensor]] = None,
-    kv_layout: str = "HND",
-    enable_pdl: Optional[bool] = None,
+    sinks: Optional[torch.Tensor] = None,
+    kv_layout: str = "NHD",
+    enable_pdl: Optional[bool] = None,
```
Also add `_check_kv_layout(kv_layout)` early in the function.
2327-2343: Fix type hints and add kv_layout validation in xqa path.
- enable_pdl should be Optional[bool]
- q_len_per_req should be int (you assert == 1)
- Validate kv_layout upfront
```diff
 def xqa_batch_decode_with_kv_cache(
     query: torch.Tensor,
     kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
     workspace_buffer: torch.Tensor,
     block_tables: torch.Tensor,
     seq_lens: torch.Tensor,
     max_seq_len: int,
     bmm1_scale: float,
     bmm2_scale: float,
     window_left: int = -1,
     out: Optional[torch.Tensor] = None,
     sinks: Optional[torch.Tensor] = None,
     kv_layout: str = "NHD",
-    enable_pdl: bool = None,
-    q_len_per_req: Optional[int] = 1,
+    enable_pdl: Optional[bool] = None,
+    q_len_per_req: int = 1,
 ) -> torch.Tensor:
+    _check_kv_layout(kv_layout)
```
1873-1890: Type hint nit: enable_pdl should be Optional. Signature shows `enable_pdl: bool = None`. Make it `Optional[bool]` for correctness.
```diff
-    enable_pdl: bool = None,
+    enable_pdl: Optional[bool] = None,
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- csrc/trtllm_fmha_kernel_launcher.cu (2 hunks)
- flashinfer/decode.py (10 hunks)
- flashinfer/prefill.py (6 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/decode.py (3)
flashinfer/xqa.py (2)
xqa (55-94), xqa (125-283)
flashinfer/utils.py (7)
get_compute_capability (251-254), FP4Tensor (599-669), get_device_sm_count (595-596), ceil_div (575-586), round_up (589-591), check_shape_dtype_device (518-536), device_support_pdl (568-572)
csrc/trtllm_fmha_kernel_launcher.cu (2)
trtllm_paged_attention_decode (197-265), trtllm_paged_attention_decode (197-204)
🔇 Additional comments (3)
flashinfer/prefill.py (1)
2091-2096: Good: normalize layout for trtllm-gen. The NHD→HND transpose here avoids kernel-side conditionals and matches trtllm-gen expectations. Nice and minimal.
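A small sketch of what the normalization amounts to (toy shapes, not the prefill.py code): `transpose()` only swaps strides, so the HND view shares storage with the NHD tensor and no copy is made.

```python
import torch

paged_kv_nhd = torch.randn(8, 16, 4, 64)        # [pages, N, H, D]
paged_kv_hnd = paged_kv_nhd.transpose(-3, -2)   # [pages, H, N, D] view

assert paged_kv_hnd.shape == (8, 4, 16, 64)
assert paged_kv_hnd.data_ptr() == paged_kv_nhd.data_ptr()  # same storage, new strides
```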
flashinfer/decode.py (2)
1239-1243: Good: normalize layout for trtllm-gen. The NHD→HND transpose here mirrors prefill and keeps decode kernels layout-stable.
2201-2206: Good: normalize layout for trtllm-gen path. The NHD→HND transpose keeps the paged cache consistent with the trtllm-gen decode assumption.
Actionable comments posted: 1
♻️ Duplicate comments (5)
flashinfer/decode.py (5)
2139-2144: Doc nit: correct architecture name. The docstring incorrectly states "sm_90 and sm_91 (ampere architecture)". SM90/SM91 are Hopper architecture; Ampere corresponds to SM80/SM86.
Apply this fix:
backend : str = "auto" The implementation backend, could be ``auto``/``xqa`` or ``trtllm-gen``. Defaults to ``auto``. When set to ``auto``, the backend will be chosen based on the device architecture and kernel availability. For sm_100 and sm_103 (blackwell architecture), ``auto`` will choose ``trtllm-gen`` backend. - For sm_90 and sm_91 (ampere architecture), ``auto`` will choose ``xqa`` backend. + For sm_90 and sm_91 (Hopper architecture), ``auto`` will choose ``xqa`` backend.
2341-2342: Fix type hints: use explicit Optional and remove Optional from non-optional parameter.
- `enable_pdl: bool = None` should be `enable_pdl: Optional[bool] = None` per PEP 484
- `q_len_per_req: Optional[int] = 1` should be `q_len_per_req: int = 1` (the value is always an int, never None)

Apply this diff:
kv_layout: str = "NHD", - enable_pdl: bool = None, - q_len_per_req: Optional[int] = 1, + enable_pdl: Optional[bool] = None, + q_len_per_req: int = 1, ) -> torch.Tensor:As per coding guidelines.
2434-2436: Allocate uint32 semaphores buffer correctly. The code slices `workspace_u8` (uint8 dtype) for semaphores, but the xqa kernel expects a uint32 tensor. This dtype mismatch will cause incorrect behavior or runtime errors.
Apply this fix to properly allocate semaphores:
```diff
     workspace_u8 = workspace_buffer.view(torch.uint8)
-    semaphore = workspace_u8[: 8 * 1024 * 1024]  # reserve 8MB for semaphore
-    scratch = workspace_u8[8 * 1024 * 1024 :]
+    # Reserve 8MB for semaphores as uint32 (8MB / 4 bytes = 2M entries)
+    sem_bytes = 8 * 1024 * 1024
+    semaphores = workspace_u8[:sem_bytes].view(torch.uint32)
+    scratch = workspace_u8[sem_bytes:]
```
Based on learnings.
2165-2168: Guard unsupported GPU architectures in auto backend selection. The current logic sends any non-SM10 device to xqa, but xqa requires SM90+ (see xqa.py:227). Devices with compute capability < 9.0 will fail at runtime. Additionally, SM11 and SM12 should be explicitly handled.
Apply this fix:
if backend == "auto": - backend = ( - "trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa" - ) + major = get_compute_capability(query.device)[0] + if major == 10: + backend = "trtllm-gen" + elif major in [9, 11, 12]: + backend = "xqa" + else: + raise GPUArchitectureError( + f"No compatible decode backend for SM{major}0. " + f"Requires SM90+, SM100+, or SM110+." + )
2446-2465: Critical: Allocate output tensor and pass enable_pdl parameter. Two critical issues:

1. Output allocation: `out` is Optional at line 2338 but xqa requires a non-Optional tensor. When `out=None`, passing None to xqa will cause a runtime error.
2. Missing enable_pdl: The `enable_pdl` flag is computed at line 2396 but never passed to the xqa call, so the kernel always uses its default behavior instead of respecting the caller's choice.

Apply this fix:
```diff
+    # Allocate output if not provided
+    if out is None:
+        out = torch.empty_like(query)
+
     xqa(
         query_new,
         k_cache,
         v_cache,
         block_tables,
         seq_lens_new,
         out,
         scratch,
         semaphores,
         num_kv_heads,
         page_size,
         sinks=sinks_new,
         q_scale=q_scale_value,
         kv_scale=torch.tensor(
             [kv_scale_value], dtype=torch.float32, device=query.device
         ),
         sliding_win_size=window_left + 1 if window_left >= 0 else 0,
         kv_layout=kv_layout,
         sm_count=sm_count,
+        enable_pdl=enable_pdl,
     )
```
Based on learnings.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- flashinfer/decode.py (9 hunks)
🔇 Additional comments (5)
flashinfer/decode.py (5)
24-24: LGTM: Import of xqa module. The import is correctly placed and necessary for the new xqa backend integration.
1238-1242: LGTM: Layout conversion for trtllm-gen backend. The NHD→HND conversion via `transpose(-3, -2)` correctly swaps the sequence and head dimensions as needed by the trtllm-gen backend. The conditional ensures it only applies when necessary.
2170-2199: LGTM: XQA backend routing with proper guards. The xqa backend path correctly validates unsupported features (nvfp4 output, scale factors) and allocates output tensors as needed before routing to `xqa_batch_decode_with_kv_cache`. The parameter forwarding is complete.
2200-2324: LGTM: TRTLLM-gen backend with layout conversion and nvfp4 support. The trtllm-gen path correctly:
- Converts NHD→HND layout when necessary
- Handles nvfp4 output with proper shape calculations and scale factor management
- Falls back to standard output types with appropriate validation
- Passes all parameters including enable_pdl to the kernel
2438-2444: LGTM: Scale computation and input preparation. The scale calculations correctly compute `q_scale_value = bmm1_scale / kv_scale_value * (head_dim**0.5)`, and the input tensors are properly reshaped with `unsqueeze(1)` and made contiguous for the xqa kernel.
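A quick numeric check of that factoring (values are arbitrary): the original `bmm1_scale` is recovered as `q_scale * kv_scale / sqrt(head_dim)`, which is what the xqa kernel effectively applies.

```python
import math

head_dim = 128
bmm1_scale = 1.0 / math.sqrt(head_dim)  # typical softmax scaling
kv_scale_value = 0.5

q_scale_value = bmm1_scale / kv_scale_value * (head_dim ** 0.5)
assert math.isclose(q_scale_value * kv_scale_value / math.sqrt(head_dim), bmm1_scale)
```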
/bot run
Actionable comments posted: 0
🧹 Nitpick comments (1)
tests/attention/test_xqa.py (1)
236-275: Consider extracting a helper for cache creation to reduce duplication. The cache tensor creation logic is duplicated between NHD and HND branches. While functional, extracting this into a helper function could improve maintainability.
Example refactor:
```python
def create_cache_tensors(layout, total_num_pages, tokens_per_page, nb_k_heads, valid_elems_per_head, dtype, device):
    if layout == "NHD":
        shape = (total_num_pages, tokens_per_page, nb_k_heads, valid_elems_per_head)
    else:  # HND
        shape = (total_num_pages, nb_k_heads, tokens_per_page, valid_elems_per_head)
    cache_k = torch.zeros(*shape, dtype=dtype, device=device)
    cache_v = torch.zeros(*shape, dtype=dtype, device=device)
    return cache_k, cache_v


cache_k_heads, cache_v_heads = create_cache_tensors(
    kv_layout, total_num_pages, tokens_per_page, nb_k_heads, valid_elems_per_head, input_type, "cuda"
)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- csrc/xqa/mla_sm120.cu (10 hunks)
- tests/attention/test_xqa.py (12 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
csrc/xqa/mla_sm120.cu (2)
csrc/xqa/hostUtils.h (1)
makeLaunchConfig (4-12)
csrc/xqa/tensorMap.cpp (2)
makeTensorMapForPagedKVCache (75-108), makeTensorMapForPagedKVCache (75-79)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (6)
tests/attention/test_xqa.py (3)
43-60: LGTM! Clean implementation of dual KV cache layout support. The `kv_layout` parameter addition to `CacheSeq` correctly handles both NHD and HND layouts with appropriate tensor indexing. The default value maintains backward compatibility.
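A hedged sketch of what layout-aware indexing has to do, loosely modeled on a CacheSeq-style accessor; the helper name and shapes here are illustrative, not the test's code:

```python
import torch


def head_token_slice(cache_page: torch.Tensor, head: int, token: int, kv_layout: str) -> torch.Tensor:
    """Hypothetical accessor: fetch one head-vector for one token from a single page."""
    if kv_layout == "NHD":   # page shape [tokens, heads, head_dim]
        return cache_page[token, head]
    else:                    # "HND": page shape [heads, tokens, head_dim]
        return cache_page[head, token]


page_nhd = torch.randn(16, 4, 64)
page_hnd = page_nhd.transpose(0, 1).contiguous()
assert torch.equal(
    head_token_slice(page_nhd, 2, 5, "NHD"),
    head_token_slice(page_hnd, 2, 5, "HND"),
)
```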
301-324: LGTM! Correct handling of both cache layouts. The `cache_head_at` helper correctly implements layout-aware cache indexing for both NHD and HND formats. While it duplicates some logic from `CacheSeq.__getitem__`, the direct cache access is more efficient for the padding zeroing use case.
446-616: LGTM! MLA test correctly uses NHD layout. The `enable_pdl` parameter addition and cache initialization properly reflect that MLA only supports NHD layout. The `CacheSeq` construction correctly relies on the default `kv_layout="NHD"`, which is appropriate for MLA.
csrc/xqa/mla_sm120.cu (3)
84-102: LGTM! Good addition of bounds checking for paged KV cache. The `nbPages` parameter addition to `KVTilePartLoader` enables proper validation of page indices (used at line 144 for bound checks). The simplified `baseOffset` calculation using `idxReq * cacheList.maxNbPagesPerSeq` is cleaner and more maintainable.
1651-1723: LGTM! Clean integration of layout-flexible KV cache support. The additions to `launchMLA` properly thread through:

- Separate K/V cache pool pointers for explicit memory management
- `enable_pdl` for programmatic dependent launch control
- Stride parameters (`kv_stride_page`, `kv_stride_token`, `kv_stride_head`) enabling both NHD and HND layouts

The stride parameters are correctly propagated to `makeTensorMapForPagedKVCache` for both K and V tensor maps, enabling the flexible cache layouts tested in the Python layer.
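To illustrate why stride parameters are enough to cover both layouts, a small sketch (not the launcher code) of the per-page / per-token / per-head element strides a host-side caller could derive and pass down:

```python
import torch

pages, tokens, heads, dim = 8, 16, 4, 64

k_nhd = torch.randn(pages, tokens, heads, dim)   # [P, N, H, D]
k_hnd = torch.randn(pages, heads, tokens, dim)   # [P, H, N, D]

# layout   stride_page        stride_token   stride_head
# NHD      tokens*heads*dim   heads*dim      dim
# HND      heads*tokens*dim   dim            tokens*dim
assert k_nhd.stride() == (tokens * heads * dim, heads * dim, dim, 1)
assert k_hnd.stride() == (heads * tokens * dim, tokens * dim, dim, 1)
```

The same tensor-map construction then works for either layout, since only the stride values change.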
1773-1831: LGTM! Consistent layout support in FlashInfer variant. `launchMLAFlashInfer` mirrors the `launchMLA` changes with the same KV cache pool and stride parameter additions. The consistent API design between both launch functions maintains a clean interface.
[SUCCESS] Pipeline #37735466: 13/17 passed
## 📌 Description

In #2001, XQA decode kernels became available through `trtllm_batch_decode_with_kv_cache` on SM90 and SM120. Current PR adds the ability to benchmark through the microbenchmark.

Example microbenchmark command and outputs before and after:

```
### Before current PR:

## SM90 (H200)
$ python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 trtllm-gen-native cudnn --page_size 32 --batch_size 1 --s_qo 1 --s_kv 8192 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --q_dtype bfloat16 --kv_dtype bfloat16 --refcheck --use_cupti
[WARNING] trtllm-gen-native for routine BatchDecodeWithPagedKVCacheWrapper is not supported on compute capability 9.0. Skipping.
[PERF] fa2 :: median time 0.035 ms; std 0.002 ms; achieved tflops 7.721 TFLOPs/sec; achieved tb_per_sec 0.966 TB/sec
[PERF] cudnn :: median time 0.020 ms; std 0.000 ms; achieved tflops 13.519 TFLOPs/sec; achieved tb_per_sec 1.692 TB/sec

## SM120 (RTX 5090)
$ python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 trtllm-gen-native cudnn --page_size 32 --batch_size 1 --s_qo 1 --s_kv 8192 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --q_dtype bfloat16 --kv_dtype bfloat16 --refcheck --use_cupti
[WARNING] trtllm-gen-native for routine BatchDecodeWithPagedKVCacheWrapper is not supported on compute capability 12.0. Skipping.
[PERF] fa2 :: median time 0.033 ms; std 0.001 ms; achieved tflops 8.204 TFLOPs/sec; achieved tb_per_sec 1.027 TB/sec
[PERF] cudnn :: median time 0.030 ms; std 0.000 ms; achieved tflops 8.943 TFLOPs/sec; achieved tb_per_sec 1.119 TB/sec

### After current PR:

## SM90 (H200)
$ python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 trtllm-gen-native cudnn --page_size 32 --batch_size 1 --s_qo 1 --s_kv 8192 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --q_dtype bfloat16 --kv_dtype bfloat16 --refcheck --use_cupti
[PERF] fa2 :: median time 0.035 ms; std 0.002 ms; achieved tflops 7.721 TFLOPs/sec; achieved tb_per_sec 0.966 TB/sec
[PERF] trtllm-gen-nati:: median time 0.019 ms; std 0.002 ms; achieved tflops 13.820 TFLOPs/sec; achieved tb_per_sec 1.729 TB/sec
[PERF] cudnn :: median time 0.020 ms; std 0.000 ms; achieved tflops 13.574 TFLOPs/sec; achieved tb_per_sec 1.698 TB/sec

## SM120 (RTX 5090)
$ python3 flashinfer_benchmark.py --routine BatchDecodeWithPagedKVCacheWrapper --backends fa2 trtllm-gen-native cudnn --page_size 32 --batch_size 1 --s_qo 1 --s_kv 8192 --num_qo_heads 64 --num_kv_heads 8 --head_dim_qk 128 --head_dim_vo 128 --q_dtype bfloat16 --kv_dtype bfloat16 --refcheck --use_cupti
[PERF] fa2 :: median time 0.033 ms; std 0.001 ms; achieved tflops 8.121 TFLOPs/sec; achieved tb_per_sec 1.016 TB/sec
[PERF] trtllm-gen-nati:: median time 0.034 ms; std 0.001 ms; achieved tflops 7.903 TFLOPs/sec; achieved tb_per_sec 0.989 TB/sec
[PERF] cudnn :: median time 0.030 ms; std 0.001 ms; achieved tflops 9.020 TFLOPs/sec; achieved tb_per_sec 1.129 TB/sec
```

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

## Summary by CodeRabbit

* **Chores**
  * Standardized backend identifier to "trtllm-native" and expanded its support across benchmark routines and utilities.
  * Argument parsing now canonicalizes deprecated backend aliases and emits a deprecation warning when encountered.
* **Documentation**
  * README and tool-facing messages updated to use the canonical backend name and include contextual notes about the change.
📌 Description
Expose xqa backend to trtllm attention interface, and improve layout coverage of trtllm-gen and xqa backends.
Now both the trtllm-gen and xqa backends support the NHD and HND kv-cache layouts.
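As a hedged sketch of how the layout choice shows up at the call site, the snippet below only builds illustrative inputs; the final call is left commented out because it needs a supported GPU and the built kernels, and the argument names follow the `xqa_batch_decode_with_kv_cache` signature shown in the review above (shapes and dtypes are assumptions, not the documented contract):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
batch, kv_heads, qo_heads, head_dim = 2, 4, 32, 128
page_size, pages_per_seq = 16, 8
num_pages = batch * pages_per_seq
kv_layout = "NHD"  # or "HND"

q = torch.randn(batch, qo_heads, head_dim, dtype=torch.bfloat16, device=device)
kv_shape = (
    (num_pages, page_size, kv_heads, head_dim) if kv_layout == "NHD"
    else (num_pages, kv_heads, page_size, head_dim)
)
k_cache = torch.randn(*kv_shape, dtype=torch.bfloat16, device=device)
v_cache = torch.randn(*kv_shape, dtype=torch.bfloat16, device=device)
block_tables = torch.arange(num_pages, dtype=torch.int32, device=device).reshape(batch, pages_per_seq)
seq_lens = torch.full((batch,), page_size * pages_per_seq, dtype=torch.int32, device=device)
workspace = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device=device)

# out = flashinfer.decode.xqa_batch_decode_with_kv_cache(
#     q, (k_cache, v_cache), workspace, block_tables, seq_lens,
#     max_seq_len=page_size * pages_per_seq,
#     bmm1_scale=1.0 / head_dim**0.5, bmm2_scale=1.0,
#     kv_layout=kv_layout,
# )
```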
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed.
- All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Improvements
Tests