
Conversation

@haotang1995

Motivation

To implement attention sink with customized window sizes (first-window-size = 1024 and last-window-size = 7168) for efficient reasoning. When implemented with huggingface transformers, it saves memory and compute without hurting downstream task performance (such as pass@1 on AIME25 for math reasoning) for various LLMs such as Qwen3. I would like to implement this in sglang as well, but failed as discussed below. I could not figure out why, and it would be great if someone could help check it out. Any suggestions and advice are welcome.
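For reference, here is a minimal sketch (not the huggingface or sglang implementation, just an illustration of the scheme above) of which KV-cache positions a decode step attends to under this sink + sliding-window configuration:

    FIRST_WINDOW = 1024   # "attention sink" prefix that is always kept
    LAST_WINDOW = 7168    # most recent tokens that are kept

    def kept_positions(seq_len: int) -> list[int]:
        """Token positions a decode step attends to under the sink scheme."""
        if seq_len <= FIRST_WINDOW + LAST_WINDOW:
            # Short sequences: full attention, nothing is evicted.
            return list(range(seq_len))
        # Long sequences: the first 1024 positions plus the last 7168 positions.
        return list(range(FIRST_WINDOW)) + list(range(seq_len - LAST_WINDOW, seq_len))

    assert len(kept_positions(10_000)) == FIRST_WINDOW + LAST_WINDOW  # always 8192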

Modifications

Following the blog post on the FA3 implementation in sglang, I changed the logic of init_forward_metadata so that the maximum sequence length becomes 8192 and the page table becomes the concatenation of the first window and the last window.

Accuracy Tests

The model behaves differently from what is expected (i.e., from the huggingface transformers implementation). It no longer generates reasonable sentences after 8K tokens and instead repeats single words, like \n\nSo \n\nSo \n\nSo \n\nSo \n\nSo \n\nSo \n\nSo \n\nSo \n\nSo \n\nSo.

I attached a script, test_attention_sink.sh, for testing. To launch the server, I run

python -m sglang.launch_server --model Qwen/Qwen3-8B --attention-backend fa3 --disable-cuda-graph

To submit a request, I run

curl http://localhost:30000/v1/chat/completions   -H "Content-Type: application/json"   -d '{
    "messages": [
        {"role": "user", "content": "There are $n$ values of $x$ in the interval $0<x<2\\pi$ where $f(x)=\\sin(7\\pi\\cdot\\sin(5x))=0$. For $t$ of these $n$ values of $x$, the graph of $y=f(x)$ is tangent to the $x$-axis. Find $n+t$. Put your answer in \\boxed{}. For multiple-choice questions, only put the choice letter in the box."}
    ],
    "max_tokens": 16384
}'
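Equivalently, the request can be sent from Python (a small sketch using the requests library against the same OpenAI-compatible endpoint; only the URL and payload from the curl command above are assumed):

    import requests

    # Same payload as the curl command above, sent to the sglang server's
    # OpenAI-compatible /v1/chat/completions endpoint.
    payload = {
        "messages": [
            {
                "role": "user",
                "content": (
                    "There are $n$ values of $x$ in the interval $0<x<2\\pi$ where "
                    "$f(x)=\\sin(7\\pi\\cdot\\sin(5x))=0$. For $t$ of these $n$ values "
                    "of $x$, the graph of $y=f(x)$ is tangent to the $x$-axis. "
                    "Find $n+t$. Put your answer in \\boxed{}. For multiple-choice "
                    "questions, only put the choice letter in the box."
                ),
            }
        ],
        "max_tokens": 16384,
    }
    resp = requests.post("http://localhost:30000/v1/chat/completions", json=payload)
    print(resp.json()["choices"][0]["message"]["content"])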

@gemini-code-assist
Contributor

Summary of Changes

Hello @haotang1995, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces an implementation of the 'attention sink' technique to improve the efficiency of large language model reasoning in sglang, particularly for long input sequences. The core changes involve modifying the attention metadata handling to use a fixed-window page table for sequences longer than 8192 tokens. Although the implementation aims to replicate successful results from other frameworks, the current changes lead to unexpected model behavior, specifically repetitive output, which the author is seeking help to resolve.

Highlights

  • Attention Sink Implementation Attempt: This pull request attempts to integrate the 'attention sink' mechanism into the sglang framework, aiming for more efficient reasoning with large language models by managing attention over long sequences. The author followed a blog post on FA3 implementation as a guide.
  • Modified init_forward_metadata for Decode Mode: The init_forward_metadata function in flashattention_backend.py has been updated to support custom window sizes (1024 for the first window and 7168 for the last window) for the page table when the sequence length exceeds 8192 tokens in decode mode. This includes clamping seqlens_in_batch and max_seq_len_k to 8192.
  • Batch Size Restriction: A new assertion assert batch_size == 1 has been added for decode mode when the sequence length exceeds 8192, indicating that the attention sink logic, as currently implemented, is only supported for single-request decoding.
  • Test Script for Reproduction: A new shell script, test_attention_sink.sh, has been added to provide a reproducible example for testing the attention sink implementation and demonstrating the observed incorrect model behavior (repetition of single words) after 8K tokens.

Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request introduces an implementation of attention sink to support efficient long-context reasoning. The core logic change is in python/sglang/srt/layers/attention/flashattention_backend.py to manipulate the page_table and sequence lengths passed to the attention kernel. While the approach is generally correct, there's a subtle bug in how the sequence length is handled for the attention sink case, which likely causes the incorrect model behavior described. I've provided a fix and some suggestions to improve code clarity and robustness.

Comment on lines 457 to 465
if forward_batch.seq_lens_cpu.max().item() > 8192:
    metadata.page_table = torch.cat([
        forward_batch.req_to_token_pool.req_to_token[forward_batch.req_pool_indices, :1024],
        forward_batch.req_to_token_pool.req_to_token[forward_batch.req_pool_indices, forward_batch.seq_lens_cpu.max().item() - 7168 : forward_batch.seq_lens_cpu.max().item()],
    ], dim=1)
else:
    metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
        forward_batch.req_pool_indices, : metadata.max_seq_len_k
    ]
Contributor


critical

The logic for handling attention sink is mostly correct, but there's a subtle issue. When forward_batch.seq_lens_cpu.max().item() > 8192, you correctly construct a page_table of size 8192. However, metadata.cache_seqlens_int32 is derived from seqlens_in_batch, which is clamped from the original sequence length. While this also results in 8192, it's less direct and couples this logic to the clamp operation on line 447.

A more direct and robust way to fix the issue and improve clarity is to explicitly set metadata.cache_seqlens_int32 to the correct size of the attention window when the sink is active. This also avoids the unintended side effects that clamping seqlens_in_batch could have elsewhere.

I'd also recommend using named constants for the magic numbers 8192, 1024, and 7168 to improve readability and maintainability.

                ATTENTION_SINK_SIZE = 1024
                ATTENTION_WINDOW_SIZE = 7168
                MAX_ATTENTION_SEQ_LEN = ATTENTION_SINK_SIZE + ATTENTION_WINDOW_SIZE

                if forward_batch.seq_lens_cpu.max().item() > MAX_ATTENTION_SEQ_LEN:
                    metadata.page_table = torch.cat([
                        forward_batch.req_to_token_pool.req_to_token[forward_batch.req_pool_indices, :ATTENTION_SINK_SIZE],
                        forward_batch.req_to_token_pool.req_to_token[forward_batch.req_pool_indices, forward_batch.seq_lens_cpu.max().item() - ATTENTION_WINDOW_SIZE: forward_batch.seq_lens_cpu.max().item()]
                    ], dim=1)
                    # Explicitly set the sequence length for the attention kernel to the attention window size.
                    seqlens_in_batch = seqlens_in_batch.new_full(seqlens_in_batch.shape, MAX_ATTENTION_SEQ_LEN)
                else:
                    metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                        forward_batch.req_pool_indices, : metadata.max_seq_len_k
                    ]
                metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)

    self.forward_metadata_spec_decode_expand = metadata_expand
else:
    # Normal Decode
    seqlens_in_batch = torch.clamp(seqlens_in_batch, max=8192)
Contributor


high

The current implementation clamps seqlens_in_batch to a maximum of 8192. For the attention sink mechanism to work correctly with paged attention, the cache_seqlens passed to the attention kernel should reflect the size of the constructed page_table, which is 1024 + 7168 = 8192; the actual sequence length, however, can be much larger, and clamping it here might break other parts of the system that rely on the original value. A better approach is to set the sequence length for the attention kernel directly to the size of the attention window when the sink is active, without modifying the main seqlens_in_batch variable.
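A minimal sketch of what this comment describes (only seqlens_in_batch, forward_batch, and metadata come from the quoted code; kernel_seqlens and the constant name are hypothetical): derive the kernel-facing lengths separately instead of clamping seqlens_in_batch itself.

    # Sketch only: keep seqlens_in_batch untouched and compute the
    # kernel-facing lengths from the size of the sink page table.
    MAX_ATTENTION_SEQ_LEN = 1024 + 7168  # sink + sliding window = 8192

    if forward_batch.seq_lens_cpu.max().item() > MAX_ATTENTION_SEQ_LEN:
        # The page table built for the sink case has exactly 8192 slots,
        # so the attention kernel should see 8192, not the true length.
        kernel_seqlens = torch.full_like(seqlens_in_batch, MAX_ATTENTION_SEQ_LEN)
    else:
        kernel_seqlens = seqlens_in_batch
    metadata.cache_seqlens_int32 = kernel_seqlens.to(torch.int32)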

Collaborator


This is a good callout. For now, let's make it work for page size = 1 and ignore this message.

@hebiao064
Collaborator

Sorry, I have been very busy recently. I just took a brief look, and this PR seems very easy to understand.

I will do some research on attention sink and get back to you.

if forward_batch.seq_lens_cpu.max().item() > 8192:
    metadata.page_table = torch.cat([
        forward_batch.req_to_token_pool.req_to_token[forward_batch.req_pool_indices, :1024],
        forward_batch.req_to_token_pool.req_to_token[forward_batch.req_pool_indices, forward_batch.seq_lens_cpu.max().item() - 7168 : forward_batch.seq_lens_cpu.max().item()]
Collaborator


This seems a bit suspicious; we may need to keep the last 7168 tokens for each sequence.

Maybe you can start with a for loop.
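A rough sketch of the per-sequence loop suggested here (page size = 1 and the sink path, i.e. every request longer than 8192 tokens, are assumed; names mirror the snippet quoted above):

    # Build one page-table row per request so each sequence keeps its own
    # last 7168 tokens instead of slicing by the batch-wide max length.
    # Assumes page size = 1 and seq_len > 8192 for every request, so all
    # rows have the same length (1024 + 7168) and can be stacked.
    req_to_token = forward_batch.req_to_token_pool.req_to_token
    rows = []
    for i, req_idx in enumerate(forward_batch.req_pool_indices.tolist()):
        seq_len = forward_batch.seq_lens_cpu[i].item()
        rows.append(torch.cat([
            req_to_token[req_idx, :1024],
            req_to_token[req_idx, seq_len - 7168 : seq_len],
        ]))
    metadata.page_table = torch.stack(rows, dim=0)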

Author


Got it, I will test the for-loop approach as well. But the fact that the code does not work even with batch-size=1 suggests there are other problems besides this one. Thanks.
