webgpu: support head_sink in flash attention#27410

Merged
guschmue merged 5 commits into main from gs/wgpu-fa-head-sink on Feb 25, 2026

Conversation

@guschmue
Contributor

This enables flash attention for gpt-oss, which requires head_sink support.

@guschmue guschmue added the ep:WebGPU ort-web webgpu provider label Feb 20, 2026
@guschmue guschmue marked this pull request as ready for review February 20, 2026 18:26

Copilot AI left a comment


Pull request overview

This PR enables flash attention support for head_sink in WebGPU, specifically to support gpt-oss models. The changes remove the restriction that prevented flash attention from being used when head_sink is present, and thread the head_sink parameter through the entire flash attention call chain.
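
For context (the PR itself doesn't restate the math, but this is the standard attention-sink formulation used by gpt-oss): head_sink supplies one learned logit $s_h$ per head that joins the softmax normalization without contributing a value vector. With attention logits $\ell_i$ (the scaled query–key dot products), the attention weights become

$$
m = \max\big(s_h,\ \max_i \ell_i\big), \qquad \alpha_i = \frac{e^{\ell_i - m}}{e^{s_h - m} + \sum_j e^{\ell_j - m}}.
$$

In a streaming flash-attention kernel this means the running max can be seeded with $s_h$ and the running sum with the sink's own term, which matches the shader changes listed below.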

Changes:

  • Removed the head_sink nullptr check that was blocking flash attention usage with head_sink
  • Added a head_sink parameter throughout the flash attention implementation, with a backward-compatible default value
  • Updated the WGSL shader templates to handle the head_sink logic for correct attention computation (see the sketch after this list)
  • Fixed a numerical precision issue with explicit f32 casting in the exponential calculations
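
A minimal WGSL sketch of the shader-side points, assuming hypothetical bindings (head_sink, qk, out_sum) and uniforms (num_heads, seq_len) rather than the template's actual identifiers: the running max is seeded from the per-head sink so the sink's exp term is carried through the online softmax, and exp arguments are kept in f32 so f16 builds don't lose precision.

```wgsl
// Hedged sketch of head_sink-aware online softmax; binding and uniform
// names are illustrative, not the template's actual identifiers.
enable f16;

@group(0) @binding(0) var<storage, read> head_sink : array<f32>;
@group(0) @binding(1) var<storage, read> qk : array<f16>;  // attention logits
@group(0) @binding(2) var<storage, read_write> out_sum : array<f32>;

struct Uniforms { num_heads : u32, seq_len : u32 }
@group(0) @binding(3) var<uniform> uniforms : Uniforms;

@compute @workgroup_size(1)
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
  let batch_head_idx = gid.x;
  let head_idx = batch_head_idx % uniforms.num_heads;

  // With a sink, the running max starts at the sink logit instead of -inf,
  // and the running sum starts at the sink's own softmax term, exp(0) = 1.
  var previous_max : f32 = head_sink[head_idx];
  var previous_sum : f32 = 1.0;

  for (var k = 0u; k < uniforms.seq_len; k++) {
    // Explicit f32 cast so exp() doesn't run on a low-precision f16 value.
    let logit = f32(qk[batch_head_idx * uniforms.seq_len + k]);
    let new_max = max(previous_max, logit);
    // Rescaling by exp(previous_max - new_max) keeps the sink's term (and
    // all earlier terms) correctly normalized as the max grows.
    previous_sum = previous_sum * exp(previous_max - new_max) + exp(logit - new_max);
    previous_max = new_max;
  }
  out_sum[batch_head_idx] = previous_sum;
}
```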

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated no comments.

| File | Description |
| --- | --- |
| onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc | Removed the head_sink nullptr check blocking flash attention; added the head_sink parameter to ApplyFlashAttention calls |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention.h | Added the head_sink parameter to the ApplyFlashAttention signature and FlashAttentionDecodeSplitVxProgram; added a num_heads uniform variable |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc | Threaded the head_sink parameter through the flash attention implementation; added shader input and uniform handling for head_sink |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template | Added head_sink support with conditional initialization of previous_max; fixed numerical precision with explicit f32 casting |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template | Added head_sink support in the global max/sum calculation, using head_idx derived from batch_head_idx (see the sketch below) |
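
For the decode split-Vx pass, here is a hedged sketch of how the global max/sum calculation can fold the sink in exactly once per (batch, head); again, the binding names and the num_splits uniform are assumptions for illustration, not the template's actual code.

```wgsl
// Hedged sketch of the split-Vx global reduction with head_sink; names are
// illustrative. One invocation handles one flattened (batch, head) index.
@group(0) @binding(0) var<storage, read> head_sink : array<f32>;
@group(0) @binding(1) var<storage, read> split_max : array<f32>;  // per-split running max
@group(0) @binding(2) var<storage, read> split_sum : array<f32>;  // per-split running sum
@group(0) @binding(3) var<storage, read_write> out_sum : array<f32>;

struct Uniforms { num_heads : u32, num_splits : u32 }
@group(0) @binding(4) var<uniform> uniforms : Uniforms;

@compute @workgroup_size(1)
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
  let batch_head_idx = gid.x;
  // head_sink is indexed per head, so recover the head from the flattened
  // batch * num_heads + head index.
  let head_idx = batch_head_idx % uniforms.num_heads;
  let base = batch_head_idx * uniforms.num_splits;

  // The sink participates in the global max across all splits.
  var g_max = head_sink[head_idx];
  for (var s = 0u; s < uniforms.num_splits; s++) {
    g_max = max(g_max, split_max[base + s]);
  }
  // The sink contributes exp(sink - g_max) to the denominator, but no value.
  var g_sum = exp(head_sink[head_idx] - g_max);
  for (var s = 0u; s < uniforms.num_splits; s++) {
    g_sum += split_sum[base + s] * exp(split_max[base + s] - g_max);
  }
  out_sum[batch_head_idx] = g_sum;
}
```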


fs-eire previously approved these changes Feb 25, 2026
@guschmue guschmue enabled auto-merge (squash) February 25, 2026 00:16
@guschmue guschmue merged commit bb3866c into main Feb 25, 2026
94 of 95 checks passed
@guschmue guschmue deleted the gs/wgpu-fa-head-sink branch February 25, 2026 08:05
