[Bugfix] Disable TRTLLM attention when KV transfer is enabled #33192
Changes from all commits: 1236562, 2061e21, 0c0a874, 1e9d4d7
```diff
@@ -573,6 +573,20 @@ def __init__(
         # try to use fp8 q if kv cache is fp8, and will fall back to model dtype
         # if TRTLLM attention kernel is not used when building attn metadata
         can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)
+
+        # TRTLLM attention requires strictly contiguous KV cache tensors.
+        # When KV transfer (P/D disaggregation) is enabled, the KV cache may be
+        # permuted into non-contiguous views, which causes assertion failures.
+        self._kv_transfer_enabled = vllm_config.kv_transfer_config is not None
+        if can_use_trtllm and self._kv_transfer_enabled:
+            logger.info_once(
+                "TRTLLM attention is disabled because KV transfer "
+                "(P/D disaggregation) is enabled. TRTLLM attention requires "
+                "strictly contiguous KV cache tensors which may not be "
+                "guaranteed with KV transfer."
+            )
+            can_use_trtllm = False
```
**Comment on lines +581 to +588** (Contributor):

This block of code disables TRTLLM attention when KV transfer is enabled. It's crucial to ensure that this disabling mechanism is robust and doesn't inadvertently affect other scenarios where TRTLLM attention could be beneficial. Consider adding a more specific check to ensure TRTLLM is only disabled when the non-contiguous KV cache issue is present, perhaps by checking the specific KV connector being used.
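The gate added in the hunk above can be sketched in isolation. This is a minimal stand-in, not vLLM's actual classes: `VllmConfig` and `resolve_can_use_trtllm` here are hypothetical simplifications, and only the `kv_transfer_config is not None` check mirrors the diff.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class VllmConfig:
    # None means KV transfer (P/D disaggregation) is disabled.
    kv_transfer_config: Optional[Any] = None


def resolve_can_use_trtllm(can_use_trtllm: bool, vllm_config: VllmConfig) -> bool:
    """Mirror the PR's gate: TRTLLM attention requires strictly contiguous
    KV cache tensors, which KV transfer cannot guarantee."""
    kv_transfer_enabled = vllm_config.kv_transfer_config is not None
    if can_use_trtllm and kv_transfer_enabled:
        # In the PR this branch also logs via logger.info_once before disabling.
        return False
    return can_use_trtllm
```

With this sketch, hardware support alone keeps TRTLLM on, but enabling any KV transfer config turns it off, regardless of which connector is configured (which is exactly the breadth the reviewer is questioning).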
```diff
         if (
             can_use_trtllm
             and not vllm_config.attention_config.disable_flashinfer_q_quantization
```
```diff
@@ -822,6 +836,9 @@ def build(
             has_sinks=self.has_sinks,
             has_spec=uses_spec_reorder,
         )
+        # KV transfer requires non-contiguous KV cache views, incompatible with TRTLLM
+        if self._kv_transfer_enabled:
+            prefill_use_trtllm = False
```
**Comment on lines +839 to +841** (Contributor):

This check disables TRTLLM attention for prefill when KV transfer is enabled. It's important to verify that this disabling mechanism is correctly applied and doesn't inadvertently affect other scenarios where TRTLLM attention could be beneficial. Ensure that this condition is specific to the KV transfer configuration that causes the non-contiguous KV cache issue.
```diff
         decode_use_trtllm = (
             self.use_trtllm_decode_attention and self.dcp_world_size <= 1
         )
```
**Comment** (NickLucche):

We should probably use `get_kv_cache_layout` to check that the layout is != `NHD`, which would ensure the logical and physical layouts match, but would also disable heterogeneous TP deployments. @ZhanqiuHu Can you verify this is the case, so that if one starts an instance on B200 with `VLLM_KV_CACHE_LAYOUT=NHD` it doesn't crash?
**Reply** (ZhanqiuHu):

Hi @NickLucche, I tested using `get_kv_cache_layout() != "NHD"` instead of checking `kv_transfer_config`. However, on B200, even with `VLLM_KV_CACHE_LAYOUT=NHD`, the layout is still forced to `HND` (via `_KV_CACHE_LAYOUT_OVERRIDE`), and it triggers the check to disable TRTLLM.

Notes:
- `get_kv_cache_layout()` in `utils.py`: the override takes priority over the env var
- `selector.py` sets the override: it calls `set_kv_cache_layout()` to set `_KV_CACHE_LAYOUT_OVERRIDE`
- `FlashInfer.get_required_kv_cache_layout()` returns `HND` on B200
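The override-priority behavior described in the notes can be sketched as follows. This is a simplified, hypothetical reconstruction of the resolution order (override first, then env var), not vLLM's actual `utils.py` code:

```python
import os
from typing import Optional

# Process-wide override set by the attention backend selector; a stand-in
# for vllm's _KV_CACHE_LAYOUT_OVERRIDE.
_KV_CACHE_LAYOUT_OVERRIDE: Optional[str] = None


def set_kv_cache_layout(layout: str) -> None:
    """Called by the backend selector when a backend requires a layout."""
    global _KV_CACHE_LAYOUT_OVERRIDE
    _KV_CACHE_LAYOUT_OVERRIDE = layout


def get_kv_cache_layout() -> str:
    # The override takes priority over the VLLM_KV_CACHE_LAYOUT env var.
    # This is why requesting NHD via the env var still yields HND when the
    # backend (e.g. FlashInfer on B200) forces HND through the override.
    if _KV_CACHE_LAYOUT_OVERRIDE is not None:
        return _KV_CACHE_LAYOUT_OVERRIDE
    return os.environ.get("VLLM_KV_CACHE_LAYOUT", "NHD")
```

Under this model, `VLLM_KV_CACHE_LAYOUT=NHD` is honored only until the selector installs an override, which matches the observation that B200 still reports `HND`.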
get_kv_cache_layout()in utils.py - override takes priority over env varselector.pysetting override - callsset_kv_cache_layout()to set_KV_CACHE_LAYOUT_OVERRIDEFlashInfer.get_required_kv_cache_layout()- returns HND on B200There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
okay so FI is forcing HND. Let's go with your original approach to stay on the safe side