Skip to content

webgpu support for LinearAttention and CausalConvWithState #27896

Closed
guschmue wants to merge 30 commits intomainfrom
gs/wgpu-lattn
Closed

webgpu support for LinearAttention and CausalConvWithState #27896
guschmue wants to merge 30 commits intomainfrom
gs/wgpu-lattn

Conversation

@guschmue
Copy link
Copy Markdown
Contributor

@guschmue guschmue commented Mar 29, 2026

moved to here: #27996

@guschmue guschmue added the ep:WebGPU ort-web webgpu provider label Mar 29, 2026
@guschmue guschmue changed the title Gs/wgpu lattn webgpu support for LinearAttention and CausalConvWithState Mar 29, 2026
Comment thread onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc Fixed
Comment thread onnxruntime/test/contrib_ops/causal_conv_with_state_op_test.cc Fixed
Comment thread onnxruntime/test/contrib_ops/linear_attention_op_test.cc Fixed
Copy link
Copy Markdown
Contributor

@github-actions github-actions Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can commit the suggested changes from lintrunner.

kv_num_heads_ = static_cast<int>(info.GetAttr<int64_t>("kv_num_heads"));
}


Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

Comment on lines +351 to +353
const Tensor* past_state = context.Input(3); // optional
const Tensor* decay = context.Input(4); // optional
const Tensor* beta = context.Input(5); // optional
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const Tensor* past_state = context.Input(3); // optional
const Tensor* decay = context.Input(4); // optional
const Tensor* beta = context.Input(5); // optional
const Tensor* past_state = context.Input(3); // optional
const Tensor* decay = context.Input(4); // optional
const Tensor* beta = context.Input(5); // optional

auto& query_shape = getInputShape(ctx, 0);
auto& value_shape = getInputShape(ctx, 2);
TensorShapeProto state_shape;
*state_shape.add_dim() = query_shape.dim(0); // B
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
*state_shape.add_dim() = query_shape.dim(0); // B
*state_shape.add_dim() = query_shape.dim(0); // B

}
}));


Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

Comment on lines +15 to +22
#define WEBGPU_CONCAT_VERSIONED_KERNEL(start, end) \
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
Concat, \
kOnnxDomain, \
start, \
end, \
kWebGpuExecutionProvider, \
(*KernelDefBuilder::Create()) \
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#define WEBGPU_CONCAT_VERSIONED_KERNEL(start, end) \
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
Concat, \
kOnnxDomain, \
start, \
end, \
kWebGpuExecutionProvider, \
(*KernelDefBuilder::Create()) \
#define WEBGPU_CONCAT_VERSIONED_KERNEL(start, end) \
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
Concat, \
kOnnxDomain, \
start, \
end, \
kWebGpuExecutionProvider, \
(*KernelDefBuilder::Create()) \


using namespace onnxruntime::test;


Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

const std::vector<float>* decay,
const std::vector<float>* beta,
std::vector<float>& output,
std::vector<float>& final_state) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
std::vector<float>& final_state) {
std::vector<float>& final_state) {
int bht = batch_size * num_heads * seq_length;
bool decay_broadcast_dk = (decay != nullptr && static_cast<int>(decay->size()) == bht);

Comment on lines +36 to +39
int bht = batch_size * num_heads * seq_length;
bool decay_broadcast_dk = (decay != nullptr && static_cast<int>(decay->size()) == bht);

// State: (B, H, dk, dv)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
int bht = batch_size * num_heads * seq_length;
bool decay_broadcast_dk = (decay != nullptr && static_cast<int>(decay->size()) == bht);
// State: (B, H, dk, dv)
// State: (B, H, dk, dv)


// Convert data from 4D (B,H,T,D) layout to 3D packed (B,T,H*D) layout
std::vector<float> PackBHTD_to_BTHD(const std::vector<float>& data_4d,
int B, int H, int T, int D) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
int B, int H, int T, int D) {
int B, int H, int T, int D) {


// Convert decay/beta from (B,H,T) layout to (B,T,H) layout
std::vector<float> TransposeBHT_to_BTH(const std::vector<float>& data,
int B, int H, int T) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
int B, int H, int T) {
int B, int H, int T) {

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds WebGPU execution provider coverage for new/updated LLM building blocks (notably LinearAttention and CausalConvWithState) and wires up several ONNX-domain LLM ops needed for Qwen3.5-style graphs, alongside new reference-based tests.

Changes:

  • Add WebGPU contrib kernels for LinearAttention and CausalConvWithState, plus schema registration in MS opset.
  • Add WebGPU kernels for ONNX-domain Attention, RotaryEmbedding, and RMSNormalization, and update WebGPU kernel registrations for newer Reshape/Transpose opsets.
  • Add extensive correctness tests (reference implementations) for LinearAttention and CausalConvWithState, plus an int64 Concat test.

Reviewed changes

Copilot reviewed 20 out of 21 changed files in this pull request and generated 12 comments.

Show a summary per file
File Description
onnxruntime/test/providers/cpu/tensor/concat_op_test.cc Adds an int64 Concat test case.
onnxruntime/test/contrib_ops/linear_attention_op_test.cc New reference-based test suite for LinearAttention across update rules and shapes.
onnxruntime/test/contrib_ops/causal_conv_with_state_op_test.cc New reference-based test suite for CausalConvWithState (fp32/fp16, state continuity, etc.).
onnxruntime/core/providers/webgpu/webgpu_supported_types.h Adds an additional supported-type list including int64/uint64 (currently unused).
onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc Registers new kernels and updates versioned kernel declarations/registrations for opset changes.
onnxruntime/core/providers/webgpu/tensor/transpose.cc Updates WebGPU Transpose kernel registration for opset 23/24 split.
onnxruntime/core/providers/webgpu/tensor/reshape.cc Updates WebGPU Reshape kernel registration for opset 21–25.
onnxruntime/core/providers/webgpu/tensor/concat.cc Formatting-only macro alignment / namespace close fix.
onnxruntime/core/providers/webgpu/nn/rms_norm.h Declares WebGPU RMSNorm kernel wrapper.
onnxruntime/core/providers/webgpu/nn/rms_norm.cc Implements WebGPU RMSNormalization via LayerNormProgram in simplified mode.
onnxruntime/core/providers/webgpu/llm/rotary_embedding.h Declares ONNX-domain WebGPU RotaryEmbedding kernel wrapper.
onnxruntime/core/providers/webgpu/llm/rotary_embedding.cc Implements ONNX-domain RotaryEmbedding using existing contrib shader programs.
onnxruntime/core/providers/webgpu/llm/attention.h Declares ONNX-domain WebGPU Attention kernel wrapper.
onnxruntime/core/providers/webgpu/llm/attention.cc Implements ONNX-domain Attention on top of existing contrib WebGPU attention kernels.
onnxruntime/core/graph/contrib_ops/ms_opset.h Registers new MS-domain schemas in opset v1 list.
onnxruntime/core/graph/contrib_ops/bert_defs.cc Adds MS-domain schemas + shape inference for LinearAttention and CausalConvWithState.
onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc Registers WebGPU contrib kernels for LinearAttention and CausalConvWithState.
onnxruntime/contrib_ops/webgpu/bert/linear_attention.h Declares LinearAttentionProgram and WebGPU kernel wrapper.
onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc Implements LinearAttention WGSL generation + kernel host-side validation/dispatch.
onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.h Declares CausalConvWithStateProgram and WebGPU kernel wrapper.
onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.cc Implements CausalConvWithState WGSL generation + kernel host-side validation/dispatch.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +55 to +64
TensorShapeVector inv_std_dev_dim;
for (size_t i = 0; i < x_shape.NumDimensions(); ++i) {
if (i < axis) {
inv_std_dev_dim.push_back(x_shape[i]);
} else {
inv_std_dev_dim.push_back(1);
}
}
TensorShape inv_std_dev_shape(inv_std_dev_dim);
auto* inv_std_dev = context.Output(1, inv_std_dev_shape);
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

inv_std_dev output is optional for RMSNormalization, but this code unconditionally calls context.Output(1, ...). If the node only has 1 output, this will be out-of-range and can crash/fail. Please guard with if (context.OutputCount() > 1) before requesting/producing output 1.

Suggested change
TensorShapeVector inv_std_dev_dim;
for (size_t i = 0; i < x_shape.NumDimensions(); ++i) {
if (i < axis) {
inv_std_dev_dim.push_back(x_shape[i]);
} else {
inv_std_dev_dim.push_back(1);
}
}
TensorShape inv_std_dev_shape(inv_std_dev_dim);
auto* inv_std_dev = context.Output(1, inv_std_dev_shape);
Tensor* inv_std_dev = nullptr;
if (context.OutputCount() > 1) {
TensorShapeVector inv_std_dev_dim;
for (size_t i = 0; i < x_shape.NumDimensions(); ++i) {
if (i < axis) {
inv_std_dev_dim.push_back(x_shape[i]);
} else {
inv_std_dev_dim.push_back(1);
}
}
TensorShape inv_std_dev_shape(inv_std_dev_dim);
inv_std_dev = context.Output(1, inv_std_dev_shape);
}

Copilot uses AI. Check for mistakes.
update_rule_ == LinearAttentionUpdateRule::GatedDelta);
ORT_RETURN_IF(needs_decay && decay == nullptr, "decay input required for gated/gated_delta update rules");
ORT_RETURN_IF(needs_beta && beta == nullptr, "beta input required for delta/gated_delta update rules");

Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The implementation doesn't validate decay/beta shapes beyond nullptr checks, but the shader assumes specific packed layouts. In particular, the schema allows beta to be (B, T, 1) but the shader reads it as (B, T, H), which would cause out-of-bounds reads. Please either implement broadcasting for the (B,T,1) case (and validate dimensions), or explicitly reject unsupported shapes with a clear error.

Suggested change
// Validate decay/beta shapes. The shader expects (B, T, H) where H == num_heads.
if (needs_decay && decay != nullptr) {
const auto& decay_shape = decay->Shape();
ORT_RETURN_IF(decay_shape.NumDimensions() != 3,
"decay must have shape (batch_size, seq_length, num_heads); ",
"broadcasted form (B, T, 1) is not currently supported");
ORT_RETURN_IF(static_cast<int64_t>(batch_size) != decay_shape[0] ||
static_cast<int64_t>(seq_length) != decay_shape[1],
"decay shape mismatch: expected batch_size=", batch_size,
" and seq_length=", seq_length,
" but got (", decay_shape[0], ", ", decay_shape[1], ", ", decay_shape[2], ")");
ORT_RETURN_IF(static_cast<int64_t>(num_heads) != decay_shape[2],
"decay last dimension must equal num_heads (", num_heads,
"); broadcasted form (B, T, 1) is not currently supported, got ",
decay_shape[2]);
}
if (needs_beta && beta != nullptr) {
const auto& beta_shape = beta->Shape();
ORT_RETURN_IF(beta_shape.NumDimensions() != 3,
"beta must have shape (batch_size, seq_length, num_heads); ",
"broadcasted form (B, T, 1) is not currently supported");
ORT_RETURN_IF(static_cast<int64_t>(batch_size) != beta_shape[0] ||
static_cast<int64_t>(seq_length) != beta_shape[1],
"beta shape mismatch: expected batch_size=", batch_size,
" and seq_length=", seq_length,
" but got (", beta_shape[0], ", ", beta_shape[1], ", ", beta_shape[2], ")");
ORT_RETURN_IF(static_cast<int64_t>(num_heads) != beta_shape[2],
"beta last dimension must equal num_heads (", num_heads,
"); broadcasted form (B, T, 1) is not currently supported, got ",
beta_shape[2]);
}

Copilot uses AI. Check for mistakes.
Comment on lines +260 to +270
// Allocate outputs
// Output 0: (B, D, L)
Tensor* output = context.Output(0, input_shape);

// Output 1: present_state (B, D, K-1)
std::vector<int64_t> state_dims{batch_size, channels, state_length};
Tensor* present_state = context.Output(1, TensorShape(state_dims));

if (input_length == 0) {
return Status::OK();
}
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When input_length == 0, the code returns early after allocating present_state but never writes/initializes it. present_state should still be well-defined (typically equal to past_state if provided, otherwise zeros) even for an empty input. Please handle the zero-length case by populating present_state appropriately before returning.

Copilot uses AI. Check for mistakes.
Comment on lines +14 to +45
1. Input parsing: Handles both 3D (B, S, hidden) and 4D (B, N, S, H) input formats per the ONNX spec
2. MHA vs GQA: Detects whether q_num_heads == kv_num_heads (MHA) or q_num_heads > kv_num_heads (GQA) and configures WebgpuAttentionParameters accordingly
3. Flash attention: Used when available (no output_qk needed, subgroups feature present, no bias)
4. 3D→BNSH conversion: For 3D inputs, uses TransferBSDToBNSH to convert to the BNSH format expected by the attention kernels
5. 4D output: Computes in BSD layout (as the shader outputs), then transposes back to BNSH for 4D output format
6. Attention mask: Reshapes 2D/3D masks to 4D for the shader's broadcasting logic; boolean masks return NOT_SUPPORTED

Remaining failures fall into known limitation categories:
Boolean masks (2) — not yet supported on WebGPU
SoftCap (2) — not yet wired through to the shader
GQA output (3) — output stride mismatch for GQA with different kv_num_heads
QK matmul output (5) — the output_qk output needs additional work
Present without past (2) — present key/value output without past input needs handling
is_causal (1) — causal masking interaction

[ PASSED ] 24 tests.
[ FAILED ] 15 tests, listed below:
[ FAILED ] AttentionTest.Attention4DAttnMaskBoolAllFalse
[ FAILED ] AttentionTest.Attention4DAttnMaskBoolAllFalseDecodeWithPast
[ FAILED ] AttentionTest.Attention4DSoftCap
[ FAILED ] AttentionTest.Attention4DSoftCapFloat16
[ FAILED ] AttentionTest.Attention4DAttnMaskBool
[ FAILED ] AttentionTest.Attention4DAttnIsCausal
[ FAILED ] AttentionTest.Attention3DGqaAttn
[ FAILED ] AttentionTest.Attention3DGqaSelfAttnCausal
[ FAILED ] AttentionTest.Attention4DGqaAttnMask
[ FAILED ] AttentionTest.Attention4DWithPastAndPresentQkMatmul
[ FAILED ] AttentionTest.Attention3DWithPastAndPresentQkMatmul
[ FAILED ] AttentionTest.Attention4DWithMask3DPastAndPresentQkMatmul
[ FAILED ] AttentionTest.Attention4DWithMask3DPastAndPresentQkMatmulCausal
[ FAILED ] AttentionTest.TestAttention4DWithPastAndPresentQkMatmulBias4DMaskCausal
[ FAILED ] AttentionTest.AttentionNoPastWithPresentOutput
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This large block comment includes a snapshot of passing/failing test names and counts. It will go stale quickly and makes the production kernel harder to maintain. Consider moving this information to the PR description or a tracking issue (and keep only a brief comment describing current limitations in code).

Suggested change
1. Input parsing: Handles both 3D (B, S, hidden) and 4D (B, N, S, H) input formats per the ONNX spec
2. MHA vs GQA: Detects whether q_num_heads == kv_num_heads (MHA) or q_num_heads > kv_num_heads (GQA) and configures WebgpuAttentionParameters accordingly
3. Flash attention: Used when available (no output_qk needed, subgroups feature present, no bias)
4. 3D→BNSH conversion: For 3D inputs, uses TransferBSDToBNSH to convert to the BNSH format expected by the attention kernels
5. 4D output: Computes in BSD layout (as the shader outputs), then transposes back to BNSH for 4D output format
6. Attention mask: Reshapes 2D/3D masks to 4D for the shader's broadcasting logic; boolean masks return NOT_SUPPORTED
Remaining failures fall into known limitation categories:
Boolean masks (2) — not yet supported on WebGPU
SoftCap (2) — not yet wired through to the shader
GQA output (3) — output stride mismatch for GQA with different kv_num_heads
QK matmul output (5) — the output_qk output needs additional work
Present without past (2) — present key/value output without past input needs handling
is_causal (1) — causal masking interaction
[ PASSED ] 24 tests.
[ FAILED ] 15 tests, listed below:
[ FAILED ] AttentionTest.Attention4DAttnMaskBoolAllFalse
[ FAILED ] AttentionTest.Attention4DAttnMaskBoolAllFalseDecodeWithPast
[ FAILED ] AttentionTest.Attention4DSoftCap
[ FAILED ] AttentionTest.Attention4DSoftCapFloat16
[ FAILED ] AttentionTest.Attention4DAttnMaskBool
[ FAILED ] AttentionTest.Attention4DAttnIsCausal
[ FAILED ] AttentionTest.Attention3DGqaAttn
[ FAILED ] AttentionTest.Attention3DGqaSelfAttnCausal
[ FAILED ] AttentionTest.Attention4DGqaAttnMask
[ FAILED ] AttentionTest.Attention4DWithPastAndPresentQkMatmul
[ FAILED ] AttentionTest.Attention3DWithPastAndPresentQkMatmul
[ FAILED ] AttentionTest.Attention4DWithMask3DPastAndPresentQkMatmul
[ FAILED ] AttentionTest.Attention4DWithMask3DPastAndPresentQkMatmulCausal
[ FAILED ] AttentionTest.TestAttention4DWithPastAndPresentQkMatmulBias4DMaskCausal
[ FAILED ] AttentionTest.AttentionNoPastWithPresentOutput
1. Input parsing: Handles both 3D (B, S, hidden) and 4D (B, N, S, H) input formats per the ONNX spec.
2. MHA vs GQA: Detects whether q_num_heads == kv_num_heads (MHA) or q_num_heads > kv_num_heads (GQA) and configures
WebgpuAttentionParameters accordingly.
3. Flash attention: Used when available (no output_qk needed, subgroups feature present, no bias).
4. 3D→BNSH conversion: For 3D inputs, uses TransferBSDToBNSH to convert to the BNSH format expected by the attention kernels.
5. 4D output: Computes in BSD layout (as the shader outputs), then transposes back to BNSH for 4D output format.
6. Attention mask: Reshapes 2D/3D masks to 4D for the shader's broadcasting logic; boolean masks currently return
NOT_SUPPORTED.
Known current limitations (see associated PR or tracking issue for detailed test coverage status):
- Boolean attention masks on WebGPU are not yet supported.
- SoftCap is not yet wired through to the shader.
- GQA output with differing q_num_heads and kv_num_heads has an output stride/layout mismatch.
- The optional QK matmul (output_qk) output path requires additional work.
- Present-only key/value outputs (present without past) are not fully handled.
- Some is_causal configurations require additional handling of causal masking interactions.

Copilot uses AI. Check for mistakes.
Comment on lines +43 to +46
TEST(ConcatOpTest, Concat1D_int64) {
// webgpu ep will fail for 0x1122334455667788
const int64_t val = 0x11223344;
OpTester test("Concat");
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This int64 Concat test uses a 32-bit-sized constant (0x11223344), so it doesn't actually exercise 64-bit value handling. If WebGPU currently fails for larger int64 values, it's better to keep a true 64-bit test value (e.g. > 2^32) and explicitly exclude the WebGPU EP for this test until fixed, rather than weakening the test coverage.

Copilot uses AI. Check for mistakes.
Comment on lines +2431 to +2442
*state_shape.add_dim() = query_shape.dim(0); // B
state_shape.add_dim()->set_dim_value(kv_num_heads); // H_kv
// d_k = query.dim(2) / q_num_heads
if (query_shape.dim(2).has_dim_value()) {
state_shape.add_dim()->set_dim_value(query_shape.dim(2).dim_value() / q_num_heads);
} else {
state_shape.add_dim();
}
// d_v = value.dim(2) / kv_num_heads
if (value_shape.dim(2).has_dim_value()) {
state_shape.add_dim()->set_dim_value(value_shape.dim(2).dim_value() / kv_num_heads);
} else {
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly, present_state shape inference computes d_k = query.dim(2) / q_num_heads and d_v = value.dim(2) / kv_num_heads without verifying divisibility. Please guard these divisions (mod == 0) to avoid emitting incorrect concrete dimension values in the inferred shape.

Copilot uses AI. Check for mistakes.

static size_t NormalizeAxis(int64_t axis, size_t tensor_rank) {
int64_t rank = static_cast<int64_t>(tensor_rank);
if (axis < -rank && axis >= rank) {
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NormalizeAxis range check is incorrect: axis < -rank && axis >= rank can never be true, so invalid axis values won't be rejected and may lead to overflow in the normalization/casts. This should be axis < -rank || axis >= rank (same logic as elsewhere in the codebase).

Suggested change
if (axis < -rank && axis >= rank) {
if (axis < -rank || axis >= rank) {

Copilot uses AI. Check for mistakes.
Comment on lines +348 to +374
const Tensor* query = context.Input(0);
const Tensor* key = context.Input(1);
const Tensor* value = context.Input(2);
const Tensor* past_state = context.Input(3); // optional
const Tensor* decay = context.Input(4); // optional
const Tensor* beta = context.Input(5); // optional

// Validate 3D packed inputs
const auto& q_shape = query->Shape();
ORT_RETURN_IF(q_shape.NumDimensions() != 3, "query must be 3D (B, T, H_q*d_k)");

const int batch_size = static_cast<int>(q_shape[0]);
const int seq_length = static_cast<int>(q_shape[1]);
const int q_packed_dim = static_cast<int>(q_shape[2]);
const int num_heads = kv_num_heads_;

ORT_RETURN_IF(q_num_heads_ != kv_num_heads_,
"GQA (q_num_heads != kv_num_heads) is not yet supported");

const int head_dim_k = q_packed_dim / q_num_heads_;
ORT_RETURN_IF(q_packed_dim != head_dim_k * q_num_heads_,
"query packed dim must be divisible by q_num_heads");

const int v_packed_dim = static_cast<int>(value->Shape()[2]);
const int head_dim_v = v_packed_dim / kv_num_heads_;
ORT_RETURN_IF(v_packed_dim != head_dim_v * kv_num_heads_,
"value packed dim must be divisible by kv_num_heads");
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ComputeInternal derives head_dim_k from query but never validates that key has the expected 3D packed shape (B, T, H*dk) or that its last dimension matches q_num_heads_*head_dim_k. The shader indexes key using packed_dk based on the query shape, so a mismatched key shape can lead to out-of-bounds reads and incorrect results. Add explicit checks for key rank and dimensions (and batch/seq match) before launching the program.

Copilot uses AI. Check for mistakes.
Comment on lines +408 to +415
// Workgroup size = head_dim_k (one thread per dk row)
// Ensure it's a power of 2 for tree reduction (round up)
uint32_t workgroup_size = 1;
while (workgroup_size < static_cast<uint32_t>(head_dim_k)) {
workgroup_size *= 2;
}
// Cap at GPU limits
workgroup_size = std::min(workgroup_size, static_cast<uint32_t>(256));
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Workgroup size is capped to 256 even when head_dim_k is larger. The shader maps local_idx to a dk row and reductions assume full dk coverage, so capping below head_dim_k silently drops rows and produces incorrect results. Either enforce/validate head_dim_k <= max_workgroup_size_x (and return NOT_IMPLEMENTED/INVALID_ARGUMENT when exceeded) or redesign the algorithm to handle dk > workgroup_size via tiling/multiple workgroups.

Suggested change
// Workgroup size = head_dim_k (one thread per dk row)
// Ensure it's a power of 2 for tree reduction (round up)
uint32_t workgroup_size = 1;
while (workgroup_size < static_cast<uint32_t>(head_dim_k)) {
workgroup_size *= 2;
}
// Cap at GPU limits
workgroup_size = std::min(workgroup_size, static_cast<uint32_t>(256));
// Validate that head_dim_k does not exceed the maximum supported workgroup size.
// The shader maps one thread to each dk row and relies on full dk coverage.
const uint32_t kMaxWorkgroupSizeX = 256;
if (static_cast<uint32_t>(head_dim_k) > kMaxWorkgroupSizeX) {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"WebGPU LinearAttention currently requires head_dim_k <= ",
kMaxWorkgroupSizeX,
"; got head_dim_k = ",
head_dim_k,
". Consider reducing head_dim_k or updating the kernel implementation.");
}
// Workgroup size = head_dim_k (one thread per dk row)
// Ensure it's a power of 2 for tree reduction (round up)
uint32_t workgroup_size = 1;
while (workgroup_size < static_cast<uint32_t>(head_dim_k)) {
workgroup_size *= 2;
}
// Cap at GPU limits (head_dim_k is already validated to be <= kMaxWorkgroupSizeX)
workgroup_size = std::min(workgroup_size, kMaxWorkgroupSizeX);

Copilot uses AI. Check for mistakes.
CausalConvWithState);

CausalConvWithState::CausalConvWithState(const OpKernelInfo& info)
: WebGpuKernel(info) {
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The op schema defines an ndim attribute (default 1) but the WebGPU kernel ignores it entirely (constructor only reads activation). Please read/validate ndim and return a clear NOT_IMPLEMENTED/INVALID_ARGUMENT error for ndim != 1 (or implement higher dimensions), otherwise models exporting ndim=2/3 will silently run with incorrect semantics.

Suggested change
: WebGpuKernel(info) {
: WebGpuKernel(info) {
// Validate supported dimensionality.
const int64_t ndim = info.GetAttrOrDefault<int64_t>("ndim", 1);
if (ndim != 1) {
ORT_THROW("CausalConvWithState WebGPU kernel only supports ndim=1, but got ndim=", ndim);
}

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Contributor

@github-actions github-actions Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can commit the suggested changes from lintrunner.

Comment on lines +141 to +148
const std::vector<float>& query, // (B, q_num_heads, T, dk)
const std::vector<float>& key, // (B, n_k_heads, T, dk)
const std::vector<float>& value, // (B, kv_num_heads, T, dv)
const std::vector<float>* initial_state, // (B, kv_num_heads, dk, dv)
const std::vector<float>* decay, // (B, kv_num_heads, T[, dk])
const std::vector<float>* beta, // (B, kv_num_heads, T)
std::vector<float>& output, // (B, kv_num_heads, T, dv)
std::vector<float>& final_state) { // (B, kv_num_heads, dk, dv)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
const std::vector<float>& query, // (B, q_num_heads, T, dk)
const std::vector<float>& key, // (B, n_k_heads, T, dk)
const std::vector<float>& value, // (B, kv_num_heads, T, dv)
const std::vector<float>* initial_state, // (B, kv_num_heads, dk, dv)
const std::vector<float>* decay, // (B, kv_num_heads, T[, dk])
const std::vector<float>* beta, // (B, kv_num_heads, T)
std::vector<float>& output, // (B, kv_num_heads, T, dv)
std::vector<float>& final_state) { // (B, kv_num_heads, dk, dv)
const std::vector<float>& query, // (B, q_num_heads, T, dk)
const std::vector<float>& key, // (B, n_k_heads, T, dk)
const std::vector<float>& value, // (B, kv_num_heads, T, dv)
const std::vector<float>* initial_state, // (B, kv_num_heads, dk, dv)
const std::vector<float>* decay, // (B, kv_num_heads, T[, dk])
const std::vector<float>* beta, // (B, kv_num_heads, T)
std::vector<float>& output, // (B, kv_num_heads, T, dv)
std::vector<float>& final_state) { // (B, kv_num_heads, dk, dv)

@guschmue guschmue closed this Apr 7, 2026
@guschmue guschmue deleted the gs/wgpu-lattn branch April 27, 2026 22:35
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ep:WebGPU ort-web webgpu provider

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants