Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions tests/kernels/attention/test_flashinfer_trtllm_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,3 +455,38 @@ def test_flashinfer_trtllm_prefill_with_baseline(
torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol),
f"{torch.max(torch.abs(output - output_trtllm))}",
)


def test_trtllm_attention_rejects_num_kv_heads_1() -> None:
    """Verify can_use_trtllm_attention always rejects num_kv_heads=1 (MQA).

    With a single KV head the KV-cache strides become degenerate
    (stride_heads == stride_batch), and CUDA's cuTensorMapEncodeTiled cannot
    build a TMA descriptor for a 4D tensor with such singleton dimensions.
    The helper must therefore report the configuration as unsupported.
    """
    from vllm.utils.flashinfer import can_use_trtllm_attention

    # Any MQA configuration (num_kv_heads=1) must be rejected outright.
    for qo_heads in (64, 32):
        assert not can_use_trtllm_attention(
            num_qo_heads=qo_heads, num_kv_heads=1
        ), "can_use_trtllm_attention should return False for num_kv_heads=1"

    # A GQA configuration (num_kv_heads > 1) is allowed to be either accepted
    # or rejected here: on platforms without TRTLLM support (e.g. non-Blackwell
    # hardware) it legitimately returns False.
    supported_with_kv8 = can_use_trtllm_attention(num_qo_heads=64, num_kv_heads=8)
    supported_with_kv1 = can_use_trtllm_attention(num_qo_heads=64, num_kv_heads=1)

    # Invariant: whenever the platform accepts num_kv_heads=8, the MQA case
    # must still come back rejected.
    if supported_with_kv8:
        assert not supported_with_kv1, (
            "If TRTLLM is supported for num_kv_heads=8, "
            "it must be rejected for num_kv_heads=1"
        )
22 changes: 21 additions & 1 deletion vllm/utils/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,18 @@ def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
if force_use_trtllm_attention() is False:
return False
has_trtllm = supports_trtllm_attention()
return has_trtllm and (num_qo_heads % num_kv_heads == 0)
# num_kv_heads=1 is not supported due to TMA descriptor building limitations.
# When num_kv_heads=1, the KV cache strides become degenerate (stride_heads ==
# stride_batch), which causes CUDA's cuTensorMapEncodeTiled to fail because
# TMA descriptors cannot handle degenerate 4D tensors with singleton dimensions.
# See: https://fburl.com/352mrydz
if has_trtllm and num_kv_heads == 1:
logger.warning_once(
"TRTLLM attention does not support num_kv_heads=1. "
"This configuration causes TMA descriptor building to fail due to "
"degenerate tensor strides. Falling back to FlashInfer attention."
)
return has_trtllm and (num_qo_heads % num_kv_heads == 0) and (num_kv_heads != 1)
Comment on lines +308 to +319
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The logic to handle num_kv_heads=1 is correct, but its implementation could be simplified for better readability and maintainability. The current structure separates the warning log from the return logic, making it slightly convoluted. By using an early return for the num_kv_heads == 1 case, we can make the function's control flow more direct and easier to follow.

    # num_kv_heads=1 is not supported due to TMA descriptor building limitations.
    # When num_kv_heads=1, the KV cache strides become degenerate (stride_heads ==
    # stride_batch), which causes CUDA's cuTensorMapEncodeTiled to fail because
    # TMA descriptors cannot handle degenerate 4D tensors with singleton dimensions.
    # See: https://fburl.com/352mrydz
    if num_kv_heads == 1:
        if has_trtllm:
            logger.warning_once(
                "TRTLLM attention does not support num_kv_heads=1. "
                "This configuration causes TMA descriptor building to fail due to "
                "degenerate tensor strides. Falling back to FlashInfer attention."
            )
        return False

    return has_trtllm and (num_qo_heads % num_kv_heads == 0)



def use_trtllm_attention(
Expand Down Expand Up @@ -355,6 +366,15 @@ def use_trtllm_attention(
)
return False

# num_kv_heads=1 is not supported
if num_kv_heads == 1:
if force_use_trtllm:
logger.warning_once(
"TRTLLM attention does not support num_kv_heads=1, "
"but --attention-config.use_trtllm_attention is set to 1"
)
return False

if has_spec and not is_prefill:
# Speculative decoding requires TRTLLM attention for decodes
logger.info_once("Using TRTLLM attention (enabled for speculative decoding).")
Expand Down