[webgpu] support smooth softmax for non-FA GQA implementation #25285
Conversation
Pull Request Overview
Adds support for smooth softmax, attention bias, and a per-head “sink” value in the non-FlashAttention path of GroupQueryAttention.
- Introduce three new optional inputs (`position_ids`, `attention_bias`, `head_sink`) in both the WebGPU and CPU GQA kernels
- Wire `attention_bias` and `head_sink` through `ComputeInternal`, `ApplyAttention`, and the softmax shader
- Update program constructors and shader generation to respect the `use_smooth_softmax`, `has_seqlen_k`, and `has_head_sink` flags (a reference sketch of the softmax math follows this list)
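For orientation, here is a minimal standalone sketch of the softmax variants being wired up, assuming the common formulation in which smooth softmax adds an implicit zero logit to the denominator and `head_sink` supplies a learned per-head value in its place; this is only an illustration of the math, not the ORT kernel code, and `SoftmaxRow` is a hypothetical helper.

```cpp
// Reference sketch (illustrative only, not ORT kernel code).
// Assumption: smooth softmax adds an implicit zero logit to the denominator,
// and a per-head sink value replaces that implicit zero when head_sink is present.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> SoftmaxRow(const std::vector<float>& scores,   // non-empty row of logits
                              bool use_smooth_softmax,
                              const float* head_sink /* nullptr if absent */) {
  const bool has_extra = use_smooth_softmax || head_sink != nullptr;
  const float extra_logit = head_sink != nullptr ? *head_sink : 0.0f;

  // Stabilize with a max shift; the sink/zero logit participates in the max too.
  float max_val = *std::max_element(scores.begin(), scores.end());
  if (has_extra) max_val = std::max(max_val, extra_logit);

  std::vector<float> probs(scores.size());
  float denom = has_extra ? std::exp(extra_logit - max_val) : 0.0f;  // sink term
  for (std::size_t i = 0; i < scores.size(); ++i) {
    probs[i] = std::exp(scores[i] - max_val);
    denom += probs[i];
  }
  for (float& p : probs) {
    p /= denom;  // with a sink term, the row probabilities sum to less than 1
  }
  return probs;
}
```

With a sink term in the denominator, the per-row attention weights sum to less than 1, which is the behavior the head sink relies on.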
Reviewed Changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 1 comment.
Summary per file:
| File | Description |
|---|---|
| onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc | Added new inputs and custom-input checks; passed `attention_bias` and `head_sink` into `ComputeInternal` and `ApplyAttention` |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc | Noted TODO for smooth softmax / `head_sink` support in FlashAttention |
| onnxruntime/contrib_ops/webgpu/bert/attention_common.h | Extended the `ApplyAttention` signature to include `head_sink` |
| onnxruntime/contrib_ops/webgpu/bert/attention.h | Modified program constructors to take the `use_smooth_softmax`, `has_seqlen_k`, and `has_head_sink` flags |
| onnxruntime/contrib_ops/webgpu/bert/attention.cc | Enhanced `InPlaceSoftmaxProgram` generation to handle the new flags in the shader |
| onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h | Added validation logic for the new `head_sink` tensor |
| onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc | Loaded the `head_sink` input and invoked validation in the CPU kernel |
Comments suppressed due to low confidence (3)
onnxruntime/contrib_ops/webgpu/bert/attention.cc:226
- The new `smooth_softmax` and `head_sink` branches in the shader code introduce distinct behavior; please add unit or integration tests to cover both code paths in the WebGPU and CPU implementations.
`Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {`
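A natural home for such coverage would be ORT's existing GQA test suites; as a sketch only, the standalone check below shows the kind of property a test could assert for the smooth-softmax path (the function name and tolerance are made up for illustration).

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Standalone sketch of the check (not an ORT unit test): with identical scores,
// the smooth-softmax probabilities must sum to strictly less than 1, because the
// implicit sink logit absorbs part of the probability mass.
void CheckSmoothSoftmaxLeavesMassForSink() {
  const std::vector<float> scores = {0.5f, 1.5f, -0.25f};

  float standard_denom = 0.0f;
  for (float s : scores) standard_denom += std::exp(s);
  const float smooth_denom = standard_denom + 1.0f;  // + exp(0) for the implicit sink logit

  float standard_sum = 0.0f;
  float smooth_sum = 0.0f;
  for (float s : scores) {
    standard_sum += std::exp(s) / standard_denom;
    smooth_sum += std::exp(s) / smooth_denom;
  }

  assert(std::abs(standard_sum - 1.0f) < 1e-5f);  // standard softmax normalizes to 1
  assert(smooth_sum < 1.0f);                      // smooth softmax leaves mass for the sink
}
```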
onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc:155
- The operator schema and user-facing documentation need updating to expose the new optional inputs (`position_ids`, `attention_bias`, and `head_sink`) and explain their semantics.
`const Tensor* position_ids = context.Input<Tensor>(9);  // TODO: support sliding window`
onnxruntime/contrib_ops/webgpu/bert/attention.h:72
- [nitpick] The term `head_sink` may be unfamiliar to new readers; consider renaming it to something more descriptive or adding a comment in the header to clarify its intended effect on the softmax.
`InPlaceSoftmaxProgram(int work_group_size, int components, bool use_smooth_softmax, bool has_seqlen_k, bool has_head_sink)`
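To make the flags concrete, here is a hedged sketch of how they could gate the generated shader text; it does not use ORT's `ShaderHelper` API, and the WGSL variable names (`x`, `sum`, `max_value`, `head_sink`, `head_idx`, `seq_len`) are hypothetical placeholders, not the names used in the PR.

```cpp
#include <sstream>
#include <string>

// Illustration only: the real code generates WGSL through ORT's ShaderHelper.
// This sketch assumes earlier shader code already computed max_value and
// accumulated sum = sum_i exp(x[i] - max_value); the epilogue adds the sink
// term to the denominator and normalizes in place.
std::string SoftmaxEpilogueWgsl(bool use_smooth_softmax, bool has_head_sink) {
  std::ostringstream wgsl;
  if (has_head_sink) {
    // Per-head sink: its exp() joins the denominator but produces no output weight.
    wgsl << "sum += exp(head_sink[head_idx] - max_value);\n";
  } else if (use_smooth_softmax) {
    // Plain smooth softmax: an implicit zero logit, i.e. exp(0 - max_value).
    wgsl << "sum += exp(-max_value);\n";
  }
  wgsl << "for (var i = 0u; i < seq_len; i++) {\n"
          "  x[i] = exp(x[i] - max_value) / sum;\n"
          "}\n";
  return wgsl.str();
}
```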
Description

Support smooth softmax for the non-FA GQA implementation.

This change depends on:
- microsoft#25269

Work items:
- [x] support smooth softmax
- [x] support bias
- [x] support head sink (per-head smooth softmax)

The following will not be included in this PR:
- support for FlashAttention
- support sliding window
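For readers unfamiliar with the terms, the formulas below spell out one common reading of "smooth softmax" and "head sink" (an interpretation, not text taken from this PR): with logits $x_j$, a stabilizing shift $m$, and a per-head sink value $s_h$,

$$
\mathrm{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}, \qquad
\mathrm{smooth}(x)_i = \frac{e^{x_i - m}}{e^{-m} + \sum_j e^{x_j - m}}, \qquad
\mathrm{sink}(x)_i = \frac{e^{x_i - m}}{e^{s_h - m} + \sum_j e^{x_j - m}}.
$$

Under this reading, the smooth variant is simply the head-sink variant with $s_h = 0$ for every head.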