
Conversation

@fs-eire (Contributor) commented Jul 4, 2025

Description

Support smooth softmax in the non-FlashAttention (non-FA) GroupQueryAttention (GQA) implementation for the WebGPU EP.

This change depends on:

  • microsoft#25269

Work items:

  • support smooth softmax
  • support bias
  • support head sink (per-head smooth softmax); see the numerics sketch below

The following will not be included in this PR:

  • support for FlashAttention
  • support sliding window
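
For context, here is a hedged sketch of the numerics as I understand them (not code from this PR): smooth softmax normalizes each attention row as exp(x_i) / (1 + Σ_j exp(x_j)), i.e. an implicit zero logit participates in the denominator, and the head-sink variant replaces that zero with a learnable per-head logit. The function name and signature below are illustrative only.

```cpp
// Reference sketch (illustrative, not the PR's kernel code): row softmax with an
// optional sink term in the denominator. Assumes a non-empty row of logits.
#include <algorithm>
#include <cmath>
#include <vector>

std::vector<float> RowSoftmaxWithSink(const std::vector<float>& logits,
                                      bool use_sink_term,  // smooth softmax or head sink enabled
                                      float sink_logit) {  // 0.0f for plain smooth softmax
  float row_max = *std::max_element(logits.begin(), logits.end());
  if (use_sink_term) {
    row_max = std::max(row_max, sink_logit);
  }
  // The sink contributes exp(sink_logit - row_max) to the denominator but has no
  // corresponding output element, so the returned row can sum to less than 1.
  float denom = use_sink_term ? std::exp(sink_logit - row_max) : 0.0f;
  for (float v : logits) denom += std::exp(v - row_max);
  std::vector<float> probs(logits.size());
  for (size_t i = 0; i < logits.size(); ++i) {
    probs[i] = std::exp(logits[i] - row_max) / denom;
  }
  return probs;
}
```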

@fs-eire fs-eire marked this pull request as draft July 4, 2025 02:01
@fs-eire fs-eire changed the title from "[webgpu] support smooth softmax for non-FA GQA implementation" to "[WIP][webgpu] support smooth softmax for non-FA GQA implementation" on Jul 4, 2025
@fs-eire fs-eire force-pushed the fs-eire/webgpu-smooth-softmax branch 5 times, most recently from c201d49 to 5845d0e on July 5, 2025 19:11
@fs-eire fs-eire changed the title from "[WIP][webgpu] support smooth softmax for non-FA GQA implementation" to "[webgpu] support smooth softmax for non-FA GQA implementation" on Jul 5, 2025
@fs-eire fs-eire marked this pull request as ready for review July 5, 2025 19:11
@fs-eire fs-eire force-pushed the fs-eire/webgpu-smooth-softmax branch from 5845d0e to 3a7b54f on July 5, 2025 19:13
@fs-eire fs-eire requested a review from Copilot July 5, 2025 19:13
Copilot AI (Contributor) left a comment

Pull Request Overview

Adds support for smooth softmax, attention bias, and a per-head “sink” value in the non-FlashAttention path of GroupQueryAttention.

  • Introduce three new optional inputs (position_ids, attention_bias, head_sink) in both WebGPU and CPU GQA kernels
  • Wire attention_bias and head_sink through ComputeInternal, ApplyAttention, and the softmax shader
  • Update program constructors and shader generation to respect use_smooth_softmax, has_seqlen_k, and has_head_sink flags
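
To make the flag wiring concrete, here is a hypothetical sketch of how such flags could gate the emitted softmax normalization. The constructor flag names and the InPlaceSoftmaxProgram / GenerateShaderCode names come from this review; the helper function, variable names, and WGSL fragments below are assumptions for illustration, not the PR's actual shader code.

```cpp
// Illustrative only: emit the normalization step of an in-place softmax shader,
// switching on the new flags. The real code uses the WebGPU EP's ShaderHelper;
// plain string building stands in here.
#include <sstream>
#include <string>

std::string EmitSoftmaxNormalization(bool use_smooth_softmax, bool has_head_sink) {
  std::ostringstream wgsl;
  // Assume earlier shader code has computed row_max (max logit of the row) and
  // exp_sum (sum of exp(logit - row_max) over the row).
  wgsl << "var denom: f32 = exp_sum;\n";
  if (has_head_sink) {
    // Per-head sink: a learnable per-head logit joins the normalization term.
    wgsl << "denom += exp(head_sink[head_idx] - row_max);\n";
  } else if (use_smooth_softmax) {
    // Smooth softmax: an implicit zero logit joins the normalization term.
    wgsl << "denom += exp(-row_max);\n";
  }
  wgsl << "x[offset + i] = exp(x[offset + i] - row_max) / denom;\n";
  return wgsl.str();
}
```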

Reviewed Changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc | Added new inputs and custom-input checks; passed attention_bias and head_sink into ComputeInternal and ApplyAttention |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc | Noted TODO for smooth softmax/head_sink support in FlashAttention |
| onnxruntime/contrib_ops/webgpu/bert/attention_common.h | Extended the ApplyAttention signature to include head_sink |
| onnxruntime/contrib_ops/webgpu/bert/attention.h | Modified program constructors to take use_smooth_softmax, has_seqlen_k, and has_head_sink flags |
| onnxruntime/contrib_ops/webgpu/bert/attention.cc | Enhanced InPlaceSoftmaxProgram shader generation to handle the new flags |
| onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h | Added validation logic for the new head_sink tensor |
| onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc | Loaded the head_sink input and invoked validation in the CPU kernel |
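
The last two rows of the table describe loading and validating the new head_sink input on the CPU side. A minimal sketch of what such validation might look like, assuming head_sink is an optional 1-D tensor of length num_heads; the function name, namespace, and error message are illustrative, not copied from group_query_attention_helper.h.

```cpp
// Illustrative sketch only: validate an optional head_sink input against the
// expected shape [num_heads]. The PR's real checks live in
// onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h.
#include "core/common/common.h"
#include "core/framework/tensor.h"

namespace contrib_gqa_sketch {

onnxruntime::common::Status CheckHeadSink(const onnxruntime::Tensor* head_sink, int64_t num_heads) {
  if (head_sink == nullptr) {
    return onnxruntime::common::Status::OK();  // head_sink is optional
  }
  const auto& shape = head_sink->Shape();
  if (shape.NumDimensions() != 1 || shape[0] != num_heads) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "head_sink must be a 1-D tensor with shape [num_heads].");
  }
  return onnxruntime::common::Status::OK();
}

}  // namespace contrib_gqa_sketch
```
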
Comments suppressed due to low confidence (3)

onnxruntime/contrib_ops/webgpu/bert/attention.cc:226

  • The new smooth_softmax and head_sink branches in the shader code introduce distinct behavior; please add unit or integration tests to cover both code paths in WebGPU and CPU implementations.
  `Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {`

onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc:155

  • The operator schema and user-facing documentation need updating to expose the new optional inputs (position_ids, attention_bias, and head_sink) and explain their semantics.
  `const Tensor* position_ids = context.Input<Tensor>(9);  // TODO: support sliding window`

onnxruntime/contrib_ops/webgpu/bert/attention.h:72

  • [nitpick] The term head_sink may be unfamiliar to new readers; consider renaming it to something more descriptive or adding a comment in the header to clarify its intended effect on the softmax.
  `InPlaceSoftmaxProgram(int work_group_size, int components, bool use_smooth_softmax, bool has_seqlen_k, bool has_head_sink)`

@guschmue guschmue added the ep:WebGPU (ort-web webgpu provider) label on Jul 7, 2025
@fs-eire fs-eire merged commit 6d28e2d into main Jul 7, 2025
91 checks passed
@fs-eire fs-eire deleted the fs-eire/webgpu-smooth-softmax branch July 7, 2025 23:59
daijh pushed a commit to daijh/onnxruntime that referenced this pull request Jul 10, 2025
qti-yuduo pushed a commit to CodeLinaro/onnxruntime that referenced this pull request Aug 8, 2025
sanketkaleoss pushed a commit to sanketkaleoss/onnxruntime that referenced this pull request Aug 11, 2025