base: main
To implement attention-sink #12504
Conversation
Summary of Changes
Hello @haotang1995, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces an implementation of the 'attention sink' technique to improve the efficiency of large language model reasoning in sglang, particularly for long input sequences. The core changes involve modifying the attention metadata handling to use a fixed-window page table for sequences longer than 8192 tokens. Although the implementation aims to replicate successful results from other frameworks, the current changes lead to unexpected model behavior, specifically repetitive output, which the author is seeking help to resolve.
Code Review
This pull request introduces an implementation of attention sink to support efficient long-context reasoning. The core logic change is in python/sglang/srt/layers/attention/flashattention_backend.py to manipulate the page_table and sequence lengths passed to the attention kernel. While the approach is generally correct, there's a subtle bug in how the sequence length is handled for the attention sink case, which likely causes the incorrect model behavior described. I've provided a fix and some suggestions to improve code clarity and robustness.
if forward_batch.seq_lens_cpu.max().item() > 8192:
    metadata.page_table = torch.cat([
        forward_batch.req_to_token_pool.req_to_token[forward_batch.req_pool_indices, : 1024],
        forward_batch.req_to_token_pool.req_to_token[forward_batch.req_pool_indices, forward_batch.seq_lens_cpu.max().item() - 7168: forward_batch.seq_lens_cpu.max().item()]
    ], dim=1)
else:
    metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
        forward_batch.req_pool_indices, : metadata.max_seq_len_k
    ]
The logic for handling the attention sink is mostly correct, but there is a subtle issue. When forward_batch.seq_lens_cpu.max().item() > 8192, you correctly construct a page_table of size 8192. However, metadata.cache_seqlens_int32 is derived from seqlens_in_batch, which is clamped from the original sequence length. While this also results in 8192, it is less direct and couples this logic to the clamp operation on line 447.
A more direct and robust fix, which also improves clarity, is to explicitly set metadata.cache_seqlens_int32 to the size of the attention window when the sink is active. This also avoids clamping seqlens_in_batch, which could have unintended side effects.
I'd also recommend using named constants for the magic numbers 8192, 1024, and 7168 to improve readability and maintainability.
ATTENTION_SINK_SIZE = 1024
ATTENTION_WINDOW_SIZE = 7168
MAX_ATTENTION_SEQ_LEN = ATTENTION_SINK_SIZE + ATTENTION_WINDOW_SIZE

if forward_batch.seq_lens_cpu.max().item() > MAX_ATTENTION_SEQ_LEN:
    metadata.page_table = torch.cat([
        forward_batch.req_to_token_pool.req_to_token[forward_batch.req_pool_indices, :ATTENTION_SINK_SIZE],
        forward_batch.req_to_token_pool.req_to_token[forward_batch.req_pool_indices, forward_batch.seq_lens_cpu.max().item() - ATTENTION_WINDOW_SIZE: forward_batch.seq_lens_cpu.max().item()]
    ], dim=1)
    # Explicitly set the sequence length for the attention kernel to the attention window size.
    seqlens_in_batch = seqlens_in_batch.new_full(seqlens_in_batch.shape, MAX_ATTENTION_SEQ_LEN)
else:
    metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
        forward_batch.req_pool_indices, : metadata.max_seq_len_k
    ]
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)

    self.forward_metadata_spec_decode_expand = metadata_expand
else:
    # Normal Decode
    seqlens_in_batch = torch.clamp(seqlens_in_batch, max=8192)
The current implementation clamps seqlens_in_batch to a maximum of 8192. However, for the attention sink mechanism to work correctly with paged attention, the cache_seqlens passed to the attention kernel should reflect the size of the constructed page_table, which is 1024 + 7168 = 8192. The actual sequence length can be much larger, and clamping it here may lead to incorrect behavior in other parts of the system that rely on the original sequence length. A better approach is to set the sequence length for the attention kernel directly to the size of the attention window when the sink is active, without modifying the main seqlens_in_batch variable.
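For concreteness, here is a minimal sketch (not code from this PR) of what that would look like, reusing the variable names from the diff above; only the sink branch overrides the kernel's cache_seqlens, and seqlens_in_batch itself is left untouched:

# Sketch only: give the kernel its own sequence lengths when the sink window is active,
# leaving seqlens_in_batch unchanged for the rest of the system.
if forward_batch.seq_lens_cpu.max().item() > 8192:
    metadata.cache_seqlens_int32 = torch.full_like(
        seqlens_in_batch, 8192, dtype=torch.int32
    )
else:
    metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)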
This is a good callout. For now, let's make it work for page size = 1 and ignore this message.
Sorry, I was very busy recently. I just took a brief look, and this PR seems very easy to understand. I will do some research on attention sink and get back to you.
if forward_batch.seq_lens_cpu.max().item() > 8192:
    metadata.page_table = torch.cat([
        forward_batch.req_to_token_pool.req_to_token[forward_batch.req_pool_indices, : 1024],
        forward_batch.req_to_token_pool.req_to_token[forward_batch.req_pool_indices, forward_batch.seq_lens_cpu.max().item() - 7168: forward_batch.seq_lens_cpu.max().item()]
This seems a bit suspicious: we may need to keep the last 7168 tokens of each sequence individually, rather than slicing every row with the batch-wide max length. Maybe you can start with a for loop, as sketched below.
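A rough per-request sketch of that for-loop idea (illustrative only; the constant names and the handling of short rows are my assumptions, not code from this PR):

# Sketch: build the sink + recent-window page table row by row, using each
# request's own length instead of the batch-wide max.
SINK, WINDOW = 1024, 7168
CAP = SINK + WINDOW  # 8192 page-table slots per row
req_to_token = forward_batch.req_to_token_pool.req_to_token
rows = []
for i, seq_len in enumerate(forward_batch.seq_lens_cpu.tolist()):
    row = req_to_token[forward_batch.req_pool_indices[i]]
    if seq_len > CAP:
        # Keep the sink pages plus the last WINDOW pages of *this* request.
        rows.append(torch.cat([row[:SINK], row[seq_len - WINDOW : seq_len]]))
    else:
        # Short requests keep their normal prefix; slicing to CAP keeps all rows the
        # same width, and entries past seq_len are ignored given the right cache_seqlens.
        rows.append(row[:CAP])
metadata.page_table = torch.stack(rows, dim=0)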
Got it, I will test a for-loop version as well then. But the fact that the code does not work even with batch-size=1 makes it feel like there are other problems besides this one. Thanks!
Motivation
To implement attention sink with customized window sizes (first-window-size=1024 and last-window-size=7168) for efficient reasoning. When implemented with huggingface transformers, it saves memory and compute without hurting downstream task performance (such as pass@1 on AIME25 for math reasoning) for various LLMs such as Qwen3. I would like to implement this in sglang as well, but my attempt fails as discussed below. I could not figure out why, and it would be great if someone could help check it out. Any suggestions and advice are welcome.
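For reference, a small illustrative helper (not part of this PR; the function and its names are hypothetical) showing which cached positions a token attends to under these window sizes:

# Illustrative only: which KV positions a query at position `pos` attends to under
# attention sink with a 1024-token sink and a 7168-token sliding window.
def attended_positions(pos, sink_size=1024, window_size=7168):
    if pos < sink_size + window_size:
        # Within the budget: plain causal attention over everything so far.
        return list(range(pos + 1))
    # Beyond the budget: always keep the first `sink_size` tokens plus the most
    # recent `window_size` tokens; everything in between is evicted from attention.
    return list(range(sink_size)) + list(range(pos + 1 - window_size, pos + 1))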
Modifications
Following the blog post on the FA3 implementation in sglang, I changed the logic of init_forward_metadata so that the maximum sequence length becomes 8192 and the page table becomes the concatenation of the first window and the last window.
Accuracy Tests
The model behaves differently from what is expected (i.e., from the huggingface transformers implementation): it does not generate reasonable sentences after 8K tokens and instead repeats single words like \n\nSo \n\nSo \n\nSo \n\nSo \n\nSo \n\nSo. I attached a script, test_attention_sink.sh, for testing. To launch servers, I run … To submit a request, I run …