[Kernels][FI] Skip trtllm attention when num_kv_heads=1 #30842
vllm-bot merged 2 commits into vllm-project:main
Conversation
Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
Code Review
This pull request addresses a runtime error that occurs when using TRTLLM attention with num_kv_heads=1. The fix involves adding checks to disable this configuration and fall back to FlashInfer's native attention, which is the correct approach. A corresponding test case has been added to ensure this behavior is enforced. The changes are logical and well-implemented. I have one suggestion to improve the clarity and maintainability of the logic in can_use_trtllm_attention.
# num_kv_heads=1 is not supported due to TMA descriptor building limitations.
# When num_kv_heads=1, the KV cache strides become degenerate (stride_heads ==
# stride_batch), which causes CUDA's cuTensorMapEncodeTiled to fail because
# TMA descriptors cannot handle degenerate 4D tensors with singleton dimensions.
# See: https://fburl.com/352mrydz
if has_trtllm and num_kv_heads == 1:
    logger.warning_once(
        "TRTLLM attention does not support num_kv_heads=1. "
        "This configuration causes TMA descriptor building to fail due to "
        "degenerate tensor strides. Falling back to FlashInfer attention."
    )
return has_trtllm and (num_qo_heads % num_kv_heads == 0) and (num_kv_heads != 1)
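The degenerate-stride condition the comment describes can be illustrated with a small, self-contained sketch (plain Python, no vLLM or CUDA dependencies; the `(num_blocks, num_kv_heads, block_size, head_dim)` layout is an assumption for illustration, not necessarily the exact cache layout used here):

```python
def contiguous_strides(shape):
    """Element strides of a C-contiguous tensor with the given shape."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides

# Assumed KV-cache layout: (num_blocks, num_kv_heads, block_size, head_dim)
mha = contiguous_strides((16, 8, 32, 128))  # num_kv_heads=8
mqa = contiguous_strides((16, 1, 32, 128))  # num_kv_heads=1

# With num_kv_heads > 1 the block stride and head stride differ ...
assert mha[0] != mha[1]
# ... but with num_kv_heads=1 they collapse to the same value,
# the "degenerate" case that trips up TMA descriptor building.
assert mqa[0] == mqa[1]
```

With the singleton head dimension, the outer two strides are indistinguishable, which is consistent with the `stride_heads == stride_batch` failure mode described above.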
The logic to handle num_kv_heads=1 is correct, but its implementation could be simplified for better readability and maintainability. The current structure separates the warning log from the return logic, making it slightly convoluted. By using an early return for the num_kv_heads == 1 case, we can make the function's control flow more direct and easier to follow.
# num_kv_heads=1 is not supported due to TMA descriptor building limitations.
# When num_kv_heads=1, the KV cache strides become degenerate (stride_heads ==
# stride_batch), which causes CUDA's cuTensorMapEncodeTiled to fail because
# TMA descriptors cannot handle degenerate 4D tensors with singleton dimensions.
# See: https://fburl.com/352mrydz
if num_kv_heads == 1:
if has_trtllm:
logger.warning_once(
"TRTLLM attention does not support num_kv_heads=1. "
"This configuration causes TMA descriptor building to fail due to "
"degenerate tensor strides. Falling back to FlashInfer attention."
)
return False
    return has_trtllm and (num_qo_heads % num_kv_heads == 0)
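As a self-contained illustration of the suggested control flow (the signature is simplified from the actual vLLM helper, and a plain stdlib logger stands in for vLLM's `logger.warning_once`):

```python
import logging

logger = logging.getLogger(__name__)

def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int,
                             has_trtllm: bool = True) -> bool:
    """Simplified sketch of the early-return gating logic suggested above."""
    if num_kv_heads == 1:
        # MQA: bail out before the divisibility check; warn only when the
        # TRTLLM path would otherwise have been taken.
        if has_trtllm:
            logger.warning(
                "TRTLLM attention does not support num_kv_heads=1; "
                "falling back to FlashInfer attention."
            )
        return False
    return has_trtllm and (num_qo_heads % num_kv_heads == 0)

# MQA always falls back; GQA is allowed only when heads divide evenly.
assert not can_use_trtllm_attention(8, 1)
assert can_use_trtllm_attention(8, 2)
assert not can_use_trtllm_attention(7, 2)
```

The early return keeps the warning and the decision in one branch, so the final `return` no longer needs the trailing `num_kv_heads != 1` clause.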
Or I wonder whether we can support num_kv_heads=1 in the FlashInfer trtllm kernel, cc: @yzh119
We ran into the problem on a smaller debug model run; a normal-sized model probably wouldn't hit these issues. cc: @pavanimajety @mgoin if the change makes sense to you.
@yeqcharlotte Could you file a FlashInfer GitHub issue (https://github.com/flashinfer-ai/flashinfer) with repro steps so that we can investigate? Our expectation is that the trtllm attention kernel should support num_kv_heads=1 (namely, MQA). We have tested various MQA configurations and they worked for us, so we want to fix this.
@yeqcharlotte Did you see any log for the failed TMA descriptor? From this line in FlashInfer -
@pavanimajety included it in the summary section:
…#30842) Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
Looks like a misconfiguration of the TMA descriptor; the stride 1969767282 looks suspicious to me.
@yeqcharlotte This PR broke GPT-OSS TP8 functionally (see #30919). Is it possible to revert this PR for now and apply this patch locally in your dev branch, until the FlashInfer team finds the root cause so that we can apply a narrower check? cc @mgoin
…#30842) Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com> Signed-off-by: Ubuntu <mjtaheri68@gmail.com>
…-project#30842)" This reverts commit a100152. This PR causes GPT-OSS-120B TP8 to have a functional issue (NotImplementedError: FlashInfer backend currently does not support attention sinks). Signed-off-by: shyeh25 <206795756+shyeh25@users.noreply.github.com>
…#30842) Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com> Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
Purpose
We got the following error when running a small model on Blackwell.
Detailed errors
We confirmed that the dummy run failed when num_kv_heads=1, which causes a stride of 0. In those cases we can fall back to FlashInfer's native attention instead of trtllm. Existing test cases ignore this anyway.
Test Plan
Test Result
Our internal run succeeded on this. The rest will depend on CI.
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.