Merged
17 changes: 17 additions & 0 deletions vllm/v1/attention/backends/flashinfer.py
@@ -573,6 +573,20 @@ def __init__(
# try to use fp8 q if kv cache is fp8, and will fall back to model dtype
# if TRTLLM attention kernel is not used when building attn metadata
can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)

# TRTLLM attention requires strictly contiguous KV cache tensors.
# When KV transfer (P/D disaggregation) is enabled, the KV cache may be
# permuted into non-contiguous views, which causes assertion failures.
self._kv_transfer_enabled = vllm_config.kv_transfer_config is not None
if can_use_trtllm and self._kv_transfer_enabled:
Collaborator:
We should probably use get_kv_cache_layout to check that the layout is != NHD, which would ensure the logical and physical layouts match, but would also disable heterogeneous TP deployments.

@ZhanqiuHu Can you verify this is the case, so that if one starts an instance on B200 with VLLM_KV_CACHE_LAYOUT=NHD it doesn't crash?
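A minimal sketch of the layout-based guard being proposed, with get_kv_cache_layout() stubbed out (in vLLM it is resolved from VLLM_KV_CACHE_LAYOUT and backend overrides; the stub here just returns a fixed value for illustration):

```python
def get_kv_cache_layout() -> str:
    # Stub: vLLM resolves this from VLLM_KV_CACHE_LAYOUT and internal overrides.
    return "HND"

def resolve_can_use_trtllm(can_use_trtllm: bool) -> bool:
    # TRTLLM assumes the physical layout matches its logical layout; if the
    # resolved layout is not NHD, disable TRTLLM to stay on the safe side.
    if can_use_trtllm and get_kv_cache_layout() != "NHD":
        return False
    return can_use_trtllm
```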

Contributor Author @ZhanqiuHu · Jan 29, 2026:
Hi @NickLucche, I tested using get_kv_cache_layout() != "NHD" instead of checking kv_transfer_config.

However, on B200, even with VLLM_KV_CACHE_LAYOUT=NHD, the layout is still forced to HND (via _KV_CACHE_LAYOUT_OVERRIDE), which triggers the check and disables TRTLLM:

[Screenshot, 2026-01-29: log output showing the HND layout override]
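The underlying contiguity hazard is easy to reproduce outside vLLM. A sketch with NumPy (torch tensors behave the same way): a permuted view of the cache shares storage with the original but is no longer C-contiguous, so a kernel that asserts contiguity would fail without an explicit copy. The shapes below are illustrative, not vLLM's actual cache layout.

```python
import numpy as np

# A simplified (num_blocks, 2, block_size, num_heads, head_dim)-style cache.
cache = np.zeros((4, 2, 16, 8, 64), dtype=np.float16)
assert cache.flags["C_CONTIGUOUS"]

# Swapping the token and head axes (an NHD <-> HND style permutation) yields
# a view over the same memory that is NOT C-contiguous.
permuted = cache.transpose(0, 1, 3, 2, 4)
assert not permuted.flags["C_CONTIGUOUS"]

# A kernel requiring contiguous input would need an explicit copy first.
fixed = np.ascontiguousarray(permuted)
assert fixed.flags["C_CONTIGUOUS"]
```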

Collaborator:
Okay, so FI is forcing HND. Let's go with your original approach to stay on the safe side.

logger.info_once(
"TRTLLM attention is disabled because KV transfer "
"(P/D disaggregation) is enabled. TRTLLM attention requires "
"strictly contiguous KV cache tensors which may not be "
"guaranteed with KV transfer."
)
can_use_trtllm = False
Comment on lines +581 to +588
Contributor:
High Severity

This block disables TRTLLM attention whenever KV transfer is enabled. To avoid turning it off in scenarios where it would still work, consider a more specific check so TRTLLM is only disabled when the non-contiguous KV cache issue is actually present, for example by checking which KV connector is in use.
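One way to make the guard connector-specific, as suggested: consult the configured connector name. This is a hypothetical sketch; KVTransferConfig is stubbed as a dataclass with only the field consulted here, and the set of connectors treated as producing non-contiguous caches is an assumption for illustration, not derived from vLLM.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class KVTransferConfig:
    # Stub: the real vLLM config carries more fields than this.
    kv_connector: Optional[str] = None

# Hypothetical set of connectors assumed to permute the KV cache into
# non-contiguous views; a real fix would derive this from the connector API.
NON_CONTIGUOUS_CONNECTORS = {"NixlConnector"}

def should_disable_trtllm(cfg: Optional[KVTransferConfig]) -> bool:
    # Only disable TRTLLM when the configured connector is known to
    # produce non-contiguous KV cache views.
    if cfg is None or cfg.kv_connector is None:
        return False
    return cfg.kv_connector in NON_CONTIGUOUS_CONNECTORS
```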



get_cudagraph_support ignores KV transfer disabling TRTLLM

Medium Severity

The get_cudagraph_support() classmethod returns AttentionCGSupport.UNIFORM_BATCH when TRTLLM is hardware-supported, but doesn't account for KV transfer being enabled. With the new fix disabling TRTLLM when kv_transfer_config is set, this method will incorrectly report UNIFORM_BATCH support when the actual runtime support is only UNIFORM_SINGLE_TOKEN_DECODE. This causes incorrect cudagraph mode selection in gpu_model_runner.py, as the system believes it has TRTLLM-backed uniform batch support when it doesn't.
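A sketch of the reported fix, with the AttentionCGSupport levels stubbed as an enum. Names follow the report; the real classmethod lives on the backend and reads the vllm_config rather than taking boolean arguments, so this is an assumption-laden simplification.

```python
from enum import Enum, auto

# Stand-ins for vLLM's AttentionCGSupport levels (names follow the report).
class AttentionCGSupport(Enum):
    UNIFORM_SINGLE_TOKEN_DECODE = auto()
    UNIFORM_BATCH = auto()

def get_cudagraph_support(
    trtllm_hw_supported: bool, kv_transfer_enabled: bool
) -> AttentionCGSupport:
    # If KV transfer forces TRTLLM off at runtime, do not advertise
    # TRTLLM-backed uniform-batch support to the cudagraph dispatcher.
    if trtllm_hw_supported and not kv_transfer_enabled:
        return AttentionCGSupport.UNIFORM_BATCH
    return AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
```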


if (
can_use_trtllm
and not vllm_config.attention_config.disable_flashinfer_q_quantization
@@ -822,6 +836,9 @@ def build(
has_sinks=self.has_sinks,
has_spec=uses_spec_reorder,
)
# KV transfer may produce non-contiguous KV cache views, incompatible with TRTLLM
if self._kv_transfer_enabled:
prefill_use_trtllm = False
Comment on lines +839 to +841
Contributor:
High Severity

This check disables the TRTLLM prefill path whenever KV transfer is enabled. Verify the condition is specific to the KV transfer configurations that actually produce non-contiguous KV caches, so TRTLLM is not disabled unnecessarily.

decode_use_trtllm = (
self.use_trtllm_decode_attention and self.dcp_world_size <= 1
)
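Taken together, the two build-time gates resolve roughly as in this simplified sketch. Parameter names mirror the diff; the decode gating reproduces the use_trtllm_decode_attention and dcp_world_size condition shown above, and use_trtllm_decode_attention is assumed to already reflect the init-time disabling.

```python
def resolve_trtllm_flags(
    prefill_use_trtllm: bool,
    use_trtllm_decode_attention: bool,
    dcp_world_size: int,
    kv_transfer_enabled: bool,
) -> tuple[bool, bool]:
    # KV transfer may leave the cache in non-contiguous views, so the
    # TRTLLM prefill path is forced off when it is enabled.
    if kv_transfer_enabled:
        prefill_use_trtllm = False
    # Decode additionally requires no decode-context parallelism.
    decode_use_trtllm = use_trtllm_decode_attention and dcp_world_size <= 1
    return prefill_use_trtllm, decode_use_trtllm
```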