feat: openai oss attention sink support with trtllm-gen backend #8825#8834
feat: openai oss attention sink support with trtllm-gen backend #8825#8834zhyncs merged 24 commits intosgl-project:mainfrom
Conversation
Co-authored-by: averyhuang <averyh@nvidia.com>
There was a problem hiding this comment.
Summary of Changes
Hello @yyihuang, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces the capability to utilize attention sinks within the TensorRT-LLM (TRTLLM) generation backend. This enhancement allows for more fine-grained control over the attention mechanism, potentially improving model performance or stability, especially for long contexts, by integrating a new 'sink' parameter into the attention computation.
Highlights
- Attention Sink Integration: Implemented support for attention sink functionality within the
trtllm_mhabackend by extracting ansk(sink) parameter from keyword arguments and passing it to the underlying TensorRT-LLM attention kernel calls during both decode and extend phases. - API Extension for Attention Parameters: Modified the
forward_decodeandforward_extendmethods intrtllm_mha_backend.pyto accept arbitrary keyword arguments (**kwargs), enabling the flexible passing of new attention-related parameters like the attention sink.
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | /gemini review |
Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary |
Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | /gemini help |
Displays a list of available commands. |
Customization
To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments or fill out our survey to provide feedback.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
-
Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution. ↩
There was a problem hiding this comment.
Code Review
This pull request adds support for attention sink in the trtllm_mha_backend. The changes look good and correctly pass the attention_sink parameter to the underlying FlashInfer kernels for both decode and extend paths.
I have a few suggestions to improve code clarity and maintainability:
- Replace the magic string "sk" with a more descriptive key for retrieving the attention sink value.
- Add a missing return type hint to the
forward_extendfunction.
These are minor changes that will make the code easier to understand and maintain.
| else 1.0 | ||
| ) | ||
| # sink: additional value per head in the denominator of the softmax. | ||
| attention_sink = kwargs.get("sk", None) |
There was a problem hiding this comment.
Using the magic string "sk" to retrieve the attention sink value from kwargs can be hard to maintain and prone to errors. It's not immediately clear what "sk" stands for without prior knowledge.
To improve readability and maintainability, I suggest using a more descriptive key, like "attention_sink". It would be even better to define this key as a constant in a shared location and use it both here and at the call sites.
attention_sink = kwargs.get("attention_sink", None)| forward_batch: ForwardBatch, | ||
| save_kv_cache=True, | ||
| **kwargs, | ||
| ): |
| # TODO: bmm1_scale and bmm2_scale might require modification | ||
| # TODO: Change once quantization is supported | ||
| # sink: additional value per head in the denominator of the softmax. | ||
| attention_sink = kwargs.get("sk", None) |
…roject#8825 (sgl-project#8834) Co-authored-by: averyhuang <averyh@nvidia.com>
…roject#8825 (sgl-project#8834) Co-authored-by: averyhuang <averyh@nvidia.com>
Motivation
Add attention sinks. #8833
Modifications
Accuracy Test
throughput
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1000 --parallel 1000Output
lm_eval
lm_eval --model local-chat-completions --model_args model=gpt-oss,base_url=http://127.0.0.1:30000/v1/chat/completions,num_concurrent=128,timeout=999999,max_gen_toks=2048 --tasks gsm8k --batch_size 128 --apply_chat_template --num_fewshot 8TRTLLM_MHA:
python3 -m sglang.launch_server --model-path lmsys/gpt-oss-20b-bf16 --trust-remote-code --attention-backend trtllm_mha --enable-triton-kernel-moe --mem-fraction-static 0.7 --tp-size 8 --disable-cuda-graph --disable-hybrid-swa-memoryTRITON
python3 -m sglang.launch_server --model-path lmsys/gpt-oss-20b-bf16 --trust-remote-code --attention-backend triton --enable-triton-kernel-moe --mem-fraction-static 0.7 --tp 8lm_eval
Benchmark & Profiling
Checklist