[Bugfix] [Kernel] Triton attention kernels: mask out V blocks that fall outside sliding window#30887
Conversation
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Code Review
This pull request effectively addresses a critical bug in the 3D Triton attention kernel by correctly masking out V blocks that fall outside the sliding window. This prevents potential NaN corruption in the output, as described in the PR purpose. The implementation uses tl.where for conditional masking, which is an appropriate and efficient approach within Triton kernels. The logic for determining if a V block is within the sliding window appears correct and consistent with how attention scores (S) are handled. This fix significantly improves the correctness and stability of the 3D kernel.
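To illustrate why the V blocks themselves must be masked (and not just the attention scores), here is a minimal NumPy sketch. This is illustrative only, not the kernel code: the real kernel selects by block position with `tl.where`, while this toy version selects wherever the probability is zero. The key point is that a zero attention weight does not neutralize a garbage V value, because 0 × NaN is still NaN.

```python
import numpy as np

# Attention probability for an out-of-window position is zero,
# but if the V block was read from uninitialized memory and holds NaN,
# the weighted sum is still poisoned: 0 * NaN == NaN.
p = np.array([0.5, 0.5, 0.0])      # last position already masked in S
v = np.array([1.0, 2.0, np.nan])   # garbage V value outside the window
print(p @ v)                       # nan -- output corrupted despite the S mask

# The fix selects the V values themselves, analogous to tl.where in Triton:
v_safe = np.where(p > 0.0, v, 0.0)
print(p @ v_safe)                  # 1.5 -- correct
```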
|
While this does fix that curl request for me, I'm still getting a lot of repeated "!!!!" generations (token id 0) using the gsm8k lm_eval dataset on my A5500. So I believe this helps, but does not fix all cases where this kind of thing has been reported with Triton attention. If you want to reproduce what I'm seeing (which matches what was reported in #29539), install lm_eval, then spin up gpt-oss-20b and run the gsm8k dataset test:
1. Serve gpt-oss-20b in vLLM with TRITON_ATTN
2. Run lm_eval on gsm8k
3. Grep the outputs for repeated "!!!!"
I had 56 different samples with this infinite token-id-0 repeated generation when testing this change. For reference, I get zero when testing the change from #30650.
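For anyone else wanting to try this, a rough sketch of the reproduction steps above. The exact flag names and model path are my best guess, so verify them against the vLLM and lm_eval docs before running:

```shell
# Sketch only -- env var and flag names are assumptions, check the docs.
pip install lm_eval

# 1. Serve gpt-oss-20b with the Triton attention backend
VLLM_ATTENTION_BACKEND=TRITON_ATTN vllm serve openai/gpt-oss-20b &

# 2. Run the gsm8k task against the local OpenAI-compatible server
lm_eval --model local-completions \
        --model_args base_url=http://localhost:8000/v1/completions,model=openai/gpt-oss-20b \
        --tasks gsm8k --output_path results/

# 3. Grep the saved generations for the repeated "!!!!" pattern
grep -rc '!!!!' results/
```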
My output from multiple runs of this yesterday and today:
Thanks @bbrowning - looking into it |
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Oh right, the issue is that the fix also needs to be applied to the 2D kernel. I think @bbrowning mentioned on Slack that he had also observed this for the 2D kernel. While the 2D kernel prunes tiles that fall fully outside the sliding window, a tile can contain multiple KV blocks, some of which may fall outside the window while others are inside. So pruning tiles is not sufficient; we also need to mask out the blocks within those tiles that fall outside the window. I ran your reproducer above and debugged this as follows: I hope this issue is now properly fixed.
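The tile-versus-block distinction can be shown outside Triton. In this NumPy sketch (block and tile sizes are made up for illustration), a tile of four KV blocks straddles the sliding-window boundary: the tile-level check keeps the whole tile, but the first block inside it is fully outside the window and still needs per-block masking.

```python
import numpy as np

BLOCK = 16               # assumed KV block size (illustrative)
BLOCKS_PER_TILE = 4      # assumed tile width in blocks (illustrative)
query_pos = 100
window = 30              # keys in (query_pos - window, query_pos] are visible

# Start/end positions of the KV blocks in one tile covering keys 48..111
block_starts = np.arange(48, 48 + BLOCK * BLOCKS_PER_TILE, BLOCK)
block_ends = block_starts + BLOCK - 1

# Tile-level pruning keeps the tile, because part of it is in the window
tile_in_window = block_ends.max() > query_pos - window
print(tile_in_window)    # True

# Block-level mask: only blocks whose last key is inside the window survive;
# the first block (keys 48..63) is fully outside and must be masked out
block_in_window = block_ends > query_pos - window
print(block_in_window)
```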
I can confirm this latest fix to both the 2D and 3D Triton kernels removes all the infinite generations in both the manual curl test case and the gsm8k eval. Thank you for figuring this out!
@Isotr0py Could you help review this since you've been working on these kernels recently? |
…ll outside sliding window (vllm-project#30887) Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Thanks for the fix. Could you confirm whether this patch has been included in the latest ARM image? I pulled the latest image, but I still hit the same issue with gpt-oss-120b: a response that should be around ~300 tokens keeps generating until it reaches the max_tokens limit. This looks similar to what's described in the link. When I enable --enforce-eager, the issue disappears.
It’s hard for me to reliably reproduce this issue because it only occurs when the system is under heavy load with many concurrent requests. On a server that hasn’t handled any prior requests, the output is normal. |
Which attention backend are you using? This fix relates specifically to the Triton backend.
Sorry for the very late reply. I’m not very familiar with vLLM, but my understanding is that it was started with the default settings. I’m not sure whether Triton is being used. Code is here: |
Purpose
There is currently a bug in the Triton attention kernels where we don't correctly mask out V blocks that fall outside the sliding window. On main, we may read garbage blocks (which can even contain NaN values) that corrupt the output. This PR resolves it.
Potentially fixes:
Test Plan
Server:
Client:
On main the above hangs after the first request.
Test Result
With this PR, the test no longer hangs.
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.