webgpu support for LinearAttention and CausalConvWithState #27896
Conversation
| kv_num_heads_ = static_cast<int>(info.GetAttr<int64_t>("kv_num_heads")); | ||
| } | ||
|
|
||
|
|
There was a problem hiding this comment.
| const Tensor* past_state = context.Input(3); // optional | ||
| const Tensor* decay = context.Input(4); // optional | ||
| const Tensor* beta = context.Input(5); // optional |
There was a problem hiding this comment.
| const Tensor* past_state = context.Input(3); // optional | |
| const Tensor* decay = context.Input(4); // optional | |
| const Tensor* beta = context.Input(5); // optional | |
| const Tensor* past_state = context.Input(3); // optional | |
| const Tensor* decay = context.Input(4); // optional | |
| const Tensor* beta = context.Input(5); // optional |
| auto& query_shape = getInputShape(ctx, 0); | ||
| auto& value_shape = getInputShape(ctx, 2); | ||
| TensorShapeProto state_shape; | ||
| *state_shape.add_dim() = query_shape.dim(0); // B |
There was a problem hiding this comment.
| *state_shape.add_dim() = query_shape.dim(0); // B | |
| *state_shape.add_dim() = query_shape.dim(0); // B |
| } | ||
| })); | ||
|
|
||
|
|
There was a problem hiding this comment.
| #define WEBGPU_CONCAT_VERSIONED_KERNEL(start, end) \ | ||
| ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ | ||
| Concat, \ | ||
| kOnnxDomain, \ | ||
| start, \ | ||
| end, \ | ||
| kWebGpuExecutionProvider, \ | ||
| (*KernelDefBuilder::Create()) \ |
There was a problem hiding this comment.
| #define WEBGPU_CONCAT_VERSIONED_KERNEL(start, end) \ | |
| ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ | |
| Concat, \ | |
| kOnnxDomain, \ | |
| start, \ | |
| end, \ | |
| kWebGpuExecutionProvider, \ | |
| (*KernelDefBuilder::Create()) \ | |
| #define WEBGPU_CONCAT_VERSIONED_KERNEL(start, end) \ | |
| ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ | |
| Concat, \ | |
| kOnnxDomain, \ | |
| start, \ | |
| end, \ | |
| kWebGpuExecutionProvider, \ | |
| (*KernelDefBuilder::Create()) \ |
|
|
||
| using namespace onnxruntime::test; | ||
|
|
||
|
|
There was a problem hiding this comment.
| const std::vector<float>* decay, | ||
| const std::vector<float>* beta, | ||
| std::vector<float>& output, | ||
| std::vector<float>& final_state) { |
There was a problem hiding this comment.
| std::vector<float>& final_state) { | |
| std::vector<float>& final_state) { | |
| int bht = batch_size * num_heads * seq_length; | |
| bool decay_broadcast_dk = (decay != nullptr && static_cast<int>(decay->size()) == bht); |
| int bht = batch_size * num_heads * seq_length; | ||
| bool decay_broadcast_dk = (decay != nullptr && static_cast<int>(decay->size()) == bht); | ||
|
|
||
| // State: (B, H, dk, dv) |
There was a problem hiding this comment.
| int bht = batch_size * num_heads * seq_length; | |
| bool decay_broadcast_dk = (decay != nullptr && static_cast<int>(decay->size()) == bht); | |
| // State: (B, H, dk, dv) | |
| // State: (B, H, dk, dv) |
|
|
||
| // Convert data from 4D (B,H,T,D) layout to 3D packed (B,T,H*D) layout | ||
| std::vector<float> PackBHTD_to_BTHD(const std::vector<float>& data_4d, | ||
| int B, int H, int T, int D) { |
There was a problem hiding this comment.
| int B, int H, int T, int D) { | |
| int B, int H, int T, int D) { |
|
|
||
| // Convert decay/beta from (B,H,T) layout to (B,T,H) layout | ||
| std::vector<float> TransposeBHT_to_BTH(const std::vector<float>& data, | ||
| int B, int H, int T) { |
There was a problem hiding this comment.
| int B, int H, int T) { | |
| int B, int H, int T) { |
There was a problem hiding this comment.
Pull request overview
Adds WebGPU execution provider coverage for new/updated LLM building blocks (notably LinearAttention and CausalConvWithState) and wires up several ONNX-domain LLM ops needed for Qwen3.5-style graphs, alongside new reference-based tests.
Changes:
- Add WebGPU contrib kernels for `LinearAttention` and `CausalConvWithState`, plus schema registration in the MS opset.
- Add WebGPU kernels for ONNX-domain `Attention`, `RotaryEmbedding`, and `RMSNormalization`, and update WebGPU kernel registrations for newer `Reshape`/`Transpose` opsets.
- Add extensive correctness tests (reference implementations) for `LinearAttention` and `CausalConvWithState`, plus an `int64` `Concat` test.
Reviewed changes
Copilot reviewed 20 out of 21 changed files in this pull request and generated 12 comments.
Show a summary per file
| File | Description |
|---|---|
| onnxruntime/test/providers/cpu/tensor/concat_op_test.cc | Adds an int64 Concat test case. |
| onnxruntime/test/contrib_ops/linear_attention_op_test.cc | New reference-based test suite for LinearAttention across update rules and shapes. |
| onnxruntime/test/contrib_ops/causal_conv_with_state_op_test.cc | New reference-based test suite for CausalConvWithState (fp32/fp16, state continuity, etc.). |
| onnxruntime/core/providers/webgpu/webgpu_supported_types.h | Adds an additional supported-type list including int64/uint64 (currently unused). |
| onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc | Registers new kernels and updates versioned kernel declarations/registrations for opset changes. |
| onnxruntime/core/providers/webgpu/tensor/transpose.cc | Updates WebGPU Transpose kernel registration for opset 23/24 split. |
| onnxruntime/core/providers/webgpu/tensor/reshape.cc | Updates WebGPU Reshape kernel registration for opset 21–25. |
| onnxruntime/core/providers/webgpu/tensor/concat.cc | Formatting-only macro alignment / namespace close fix. |
| onnxruntime/core/providers/webgpu/nn/rms_norm.h | Declares WebGPU RMSNorm kernel wrapper. |
| onnxruntime/core/providers/webgpu/nn/rms_norm.cc | Implements WebGPU RMSNormalization via LayerNormProgram in simplified mode. |
| onnxruntime/core/providers/webgpu/llm/rotary_embedding.h | Declares ONNX-domain WebGPU RotaryEmbedding kernel wrapper. |
| onnxruntime/core/providers/webgpu/llm/rotary_embedding.cc | Implements ONNX-domain RotaryEmbedding using existing contrib shader programs. |
| onnxruntime/core/providers/webgpu/llm/attention.h | Declares ONNX-domain WebGPU Attention kernel wrapper. |
| onnxruntime/core/providers/webgpu/llm/attention.cc | Implements ONNX-domain Attention on top of existing contrib WebGPU attention kernels. |
| onnxruntime/core/graph/contrib_ops/ms_opset.h | Registers new MS-domain schemas in opset v1 list. |
| onnxruntime/core/graph/contrib_ops/bert_defs.cc | Adds MS-domain schemas + shape inference for LinearAttention and CausalConvWithState. |
| onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc | Registers WebGPU contrib kernels for LinearAttention and CausalConvWithState. |
| onnxruntime/contrib_ops/webgpu/bert/linear_attention.h | Declares LinearAttentionProgram and WebGPU kernel wrapper. |
| onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc | Implements LinearAttention WGSL generation + kernel host-side validation/dispatch. |
| onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.h | Declares CausalConvWithStateProgram and WebGPU kernel wrapper. |
| onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.cc | Implements CausalConvWithState WGSL generation + kernel host-side validation/dispatch. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| TensorShapeVector inv_std_dev_dim; | ||
| for (size_t i = 0; i < x_shape.NumDimensions(); ++i) { | ||
| if (i < axis) { | ||
| inv_std_dev_dim.push_back(x_shape[i]); | ||
| } else { | ||
| inv_std_dev_dim.push_back(1); | ||
| } | ||
| } | ||
| TensorShape inv_std_dev_shape(inv_std_dev_dim); | ||
| auto* inv_std_dev = context.Output(1, inv_std_dev_shape); |
There was a problem hiding this comment.
inv_std_dev output is optional for RMSNormalization, but this code unconditionally calls context.Output(1, ...). If the node only has 1 output, this will be out-of-range and can crash/fail. Please guard with if (context.OutputCount() > 1) before requesting/producing output 1.
| TensorShapeVector inv_std_dev_dim; | |
| for (size_t i = 0; i < x_shape.NumDimensions(); ++i) { | |
| if (i < axis) { | |
| inv_std_dev_dim.push_back(x_shape[i]); | |
| } else { | |
| inv_std_dev_dim.push_back(1); | |
| } | |
| } | |
| TensorShape inv_std_dev_shape(inv_std_dev_dim); | |
| auto* inv_std_dev = context.Output(1, inv_std_dev_shape); | |
| Tensor* inv_std_dev = nullptr; | |
| if (context.OutputCount() > 1) { | |
| TensorShapeVector inv_std_dev_dim; | |
| for (size_t i = 0; i < x_shape.NumDimensions(); ++i) { | |
| if (i < axis) { | |
| inv_std_dev_dim.push_back(x_shape[i]); | |
| } else { | |
| inv_std_dev_dim.push_back(1); | |
| } | |
| } | |
| TensorShape inv_std_dev_shape(inv_std_dev_dim); | |
| inv_std_dev = context.Output(1, inv_std_dev_shape); | |
| } |
| update_rule_ == LinearAttentionUpdateRule::GatedDelta); | ||
| ORT_RETURN_IF(needs_decay && decay == nullptr, "decay input required for gated/gated_delta update rules"); | ||
| ORT_RETURN_IF(needs_beta && beta == nullptr, "beta input required for delta/gated_delta update rules"); | ||
|
|
There was a problem hiding this comment.
The implementation doesn't validate decay/beta shapes beyond nullptr checks, but the shader assumes specific packed layouts. In particular, the schema allows beta to be (B, T, 1) but the shader reads it as (B, T, H), which would cause out-of-bounds reads. Please either implement broadcasting for the (B,T,1) case (and validate dimensions), or explicitly reject unsupported shapes with a clear error.
| // Validate decay/beta shapes. The shader expects (B, T, H) where H == num_heads. | |
| if (needs_decay && decay != nullptr) { | |
| const auto& decay_shape = decay->Shape(); | |
| ORT_RETURN_IF(decay_shape.NumDimensions() != 3, | |
| "decay must have shape (batch_size, seq_length, num_heads); ", | |
| "broadcasted form (B, T, 1) is not currently supported"); | |
| ORT_RETURN_IF(static_cast<int64_t>(batch_size) != decay_shape[0] || | |
| static_cast<int64_t>(seq_length) != decay_shape[1], | |
| "decay shape mismatch: expected batch_size=", batch_size, | |
| " and seq_length=", seq_length, | |
| " but got (", decay_shape[0], ", ", decay_shape[1], ", ", decay_shape[2], ")"); | |
| ORT_RETURN_IF(static_cast<int64_t>(num_heads) != decay_shape[2], | |
| "decay last dimension must equal num_heads (", num_heads, | |
| "); broadcasted form (B, T, 1) is not currently supported, got ", | |
| decay_shape[2]); | |
| } | |
| if (needs_beta && beta != nullptr) { | |
| const auto& beta_shape = beta->Shape(); | |
| ORT_RETURN_IF(beta_shape.NumDimensions() != 3, | |
| "beta must have shape (batch_size, seq_length, num_heads); ", | |
| "broadcasted form (B, T, 1) is not currently supported"); | |
| ORT_RETURN_IF(static_cast<int64_t>(batch_size) != beta_shape[0] || | |
| static_cast<int64_t>(seq_length) != beta_shape[1], | |
| "beta shape mismatch: expected batch_size=", batch_size, | |
| " and seq_length=", seq_length, | |
| " but got (", beta_shape[0], ", ", beta_shape[1], ", ", beta_shape[2], ")"); | |
| ORT_RETURN_IF(static_cast<int64_t>(num_heads) != beta_shape[2], | |
| "beta last dimension must equal num_heads (", num_heads, | |
| "); broadcasted form (B, T, 1) is not currently supported, got ", | |
| beta_shape[2]); | |
| } |
| // Allocate outputs | ||
| // Output 0: (B, D, L) | ||
| Tensor* output = context.Output(0, input_shape); | ||
|
|
||
| // Output 1: present_state (B, D, K-1) | ||
| std::vector<int64_t> state_dims{batch_size, channels, state_length}; | ||
| Tensor* present_state = context.Output(1, TensorShape(state_dims)); | ||
|
|
||
| if (input_length == 0) { | ||
| return Status::OK(); | ||
| } |
There was a problem hiding this comment.
When input_length == 0, the code returns early after allocating present_state but never writes/initializes it. present_state should still be well-defined (typically equal to past_state if provided, otherwise zeros) even for an empty input. Please handle the zero-length case by populating present_state appropriately before returning.
| 1. Input parsing: Handles both 3D (B, S, hidden) and 4D (B, N, S, H) input formats per the ONNX spec | ||
| 2. MHA vs GQA: Detects whether q_num_heads == kv_num_heads (MHA) or q_num_heads > kv_num_heads (GQA) and configures WebgpuAttentionParameters accordingly | ||
| 3. Flash attention: Used when available (no output_qk needed, subgroups feature present, no bias) | ||
| 4. 3D→BNSH conversion: For 3D inputs, uses TransferBSDToBNSH to convert to the BNSH format expected by the attention kernels | ||
| 5. 4D output: Computes in BSD layout (as the shader outputs), then transposes back to BNSH for 4D output format | ||
| 6. Attention mask: Reshapes 2D/3D masks to 4D for the shader's broadcasting logic; boolean masks return NOT_SUPPORTED | ||
|
|
||
| Remaining failures fall into known limitation categories: | ||
| Boolean masks (2) — not yet supported on WebGPU | ||
| SoftCap (2) — not yet wired through to the shader | ||
| GQA output (3) — output stride mismatch for GQA with different kv_num_heads | ||
| QK matmul output (5) — the output_qk output needs additional work | ||
| Present without past (2) — present key/value output without past input needs handling | ||
| is_causal (1) — causal masking interaction | ||
|
|
||
| [ PASSED ] 24 tests. | ||
| [ FAILED ] 15 tests, listed below: | ||
| [ FAILED ] AttentionTest.Attention4DAttnMaskBoolAllFalse | ||
| [ FAILED ] AttentionTest.Attention4DAttnMaskBoolAllFalseDecodeWithPast | ||
| [ FAILED ] AttentionTest.Attention4DSoftCap | ||
| [ FAILED ] AttentionTest.Attention4DSoftCapFloat16 | ||
| [ FAILED ] AttentionTest.Attention4DAttnMaskBool | ||
| [ FAILED ] AttentionTest.Attention4DAttnIsCausal | ||
| [ FAILED ] AttentionTest.Attention3DGqaAttn | ||
| [ FAILED ] AttentionTest.Attention3DGqaSelfAttnCausal | ||
| [ FAILED ] AttentionTest.Attention4DGqaAttnMask | ||
| [ FAILED ] AttentionTest.Attention4DWithPastAndPresentQkMatmul | ||
| [ FAILED ] AttentionTest.Attention3DWithPastAndPresentQkMatmul | ||
| [ FAILED ] AttentionTest.Attention4DWithMask3DPastAndPresentQkMatmul | ||
| [ FAILED ] AttentionTest.Attention4DWithMask3DPastAndPresentQkMatmulCausal | ||
| [ FAILED ] AttentionTest.TestAttention4DWithPastAndPresentQkMatmulBias4DMaskCausal | ||
| [ FAILED ] AttentionTest.AttentionNoPastWithPresentOutput |
There was a problem hiding this comment.
This large block comment includes a snapshot of passing/failing test names and counts. It will go stale quickly and makes the production kernel harder to maintain. Consider moving this information to the PR description or a tracking issue (and keep only a brief comment describing current limitations in code).
| 1. Input parsing: Handles both 3D (B, S, hidden) and 4D (B, N, S, H) input formats per the ONNX spec | |
| 2. MHA vs GQA: Detects whether q_num_heads == kv_num_heads (MHA) or q_num_heads > kv_num_heads (GQA) and configures WebgpuAttentionParameters accordingly | |
| 3. Flash attention: Used when available (no output_qk needed, subgroups feature present, no bias) | |
| 4. 3D→BNSH conversion: For 3D inputs, uses TransferBSDToBNSH to convert to the BNSH format expected by the attention kernels | |
| 5. 4D output: Computes in BSD layout (as the shader outputs), then transposes back to BNSH for 4D output format | |
| 6. Attention mask: Reshapes 2D/3D masks to 4D for the shader's broadcasting logic; boolean masks return NOT_SUPPORTED | |
| Remaining failures fall into known limitation categories: | |
| Boolean masks (2) — not yet supported on WebGPU | |
| SoftCap (2) — not yet wired through to the shader | |
| GQA output (3) — output stride mismatch for GQA with different kv_num_heads | |
| QK matmul output (5) — the output_qk output needs additional work | |
| Present without past (2) — present key/value output without past input needs handling | |
| is_causal (1) — causal masking interaction | |
| [ PASSED ] 24 tests. | |
| [ FAILED ] 15 tests, listed below: | |
| [ FAILED ] AttentionTest.Attention4DAttnMaskBoolAllFalse | |
| [ FAILED ] AttentionTest.Attention4DAttnMaskBoolAllFalseDecodeWithPast | |
| [ FAILED ] AttentionTest.Attention4DSoftCap | |
| [ FAILED ] AttentionTest.Attention4DSoftCapFloat16 | |
| [ FAILED ] AttentionTest.Attention4DAttnMaskBool | |
| [ FAILED ] AttentionTest.Attention4DAttnIsCausal | |
| [ FAILED ] AttentionTest.Attention3DGqaAttn | |
| [ FAILED ] AttentionTest.Attention3DGqaSelfAttnCausal | |
| [ FAILED ] AttentionTest.Attention4DGqaAttnMask | |
| [ FAILED ] AttentionTest.Attention4DWithPastAndPresentQkMatmul | |
| [ FAILED ] AttentionTest.Attention3DWithPastAndPresentQkMatmul | |
| [ FAILED ] AttentionTest.Attention4DWithMask3DPastAndPresentQkMatmul | |
| [ FAILED ] AttentionTest.Attention4DWithMask3DPastAndPresentQkMatmulCausal | |
| [ FAILED ] AttentionTest.TestAttention4DWithPastAndPresentQkMatmulBias4DMaskCausal | |
| [ FAILED ] AttentionTest.AttentionNoPastWithPresentOutput | |
| 1. Input parsing: Handles both 3D (B, S, hidden) and 4D (B, N, S, H) input formats per the ONNX spec. | |
| 2. MHA vs GQA: Detects whether q_num_heads == kv_num_heads (MHA) or q_num_heads > kv_num_heads (GQA) and configures | |
| WebgpuAttentionParameters accordingly. | |
| 3. Flash attention: Used when available (no output_qk needed, subgroups feature present, no bias). | |
| 4. 3D→BNSH conversion: For 3D inputs, uses TransferBSDToBNSH to convert to the BNSH format expected by the attention kernels. | |
| 5. 4D output: Computes in BSD layout (as the shader outputs), then transposes back to BNSH for 4D output format. | |
| 6. Attention mask: Reshapes 2D/3D masks to 4D for the shader's broadcasting logic; boolean masks currently return | |
| NOT_SUPPORTED. | |
| Known current limitations (see associated PR or tracking issue for detailed test coverage status): | |
| - Boolean attention masks on WebGPU are not yet supported. | |
| - SoftCap is not yet wired through to the shader. | |
| - GQA output with differing q_num_heads and kv_num_heads has an output stride/layout mismatch. | |
| - The optional QK matmul (output_qk) output path requires additional work. | |
| - Present-only key/value outputs (present without past) are not fully handled. | |
| - Some is_causal configurations require additional handling of causal masking interactions. |
| TEST(ConcatOpTest, Concat1D_int64) { | ||
| // webgpu ep will fail for 0x1122334455667788 | ||
| const int64_t val = 0x11223344; | ||
| OpTester test("Concat"); |
There was a problem hiding this comment.
This int64 Concat test uses a 32-bit-sized constant (0x11223344), so it doesn't actually exercise 64-bit value handling. If WebGPU currently fails for larger int64 values, it's better to keep a true 64-bit test value (e.g. > 2^32) and explicitly exclude the WebGPU EP for this test until fixed, rather than weakening the test coverage.
| *state_shape.add_dim() = query_shape.dim(0); // B | ||
| state_shape.add_dim()->set_dim_value(kv_num_heads); // H_kv | ||
| // d_k = query.dim(2) / q_num_heads | ||
| if (query_shape.dim(2).has_dim_value()) { | ||
| state_shape.add_dim()->set_dim_value(query_shape.dim(2).dim_value() / q_num_heads); | ||
| } else { | ||
| state_shape.add_dim(); | ||
| } | ||
| // d_v = value.dim(2) / kv_num_heads | ||
| if (value_shape.dim(2).has_dim_value()) { | ||
| state_shape.add_dim()->set_dim_value(value_shape.dim(2).dim_value() / kv_num_heads); | ||
| } else { |
There was a problem hiding this comment.
Similarly, present_state shape inference computes d_k = query.dim(2) / q_num_heads and d_v = value.dim(2) / kv_num_heads without verifying divisibility. Please guard these divisions (mod == 0) to avoid emitting incorrect concrete dimension values in the inferred shape.
|
|
||
| static size_t NormalizeAxis(int64_t axis, size_t tensor_rank) { | ||
| int64_t rank = static_cast<int64_t>(tensor_rank); | ||
| if (axis < -rank && axis >= rank) { |
There was a problem hiding this comment.
NormalizeAxis range check is incorrect: axis < -rank && axis >= rank can never be true, so invalid axis values won't be rejected and may lead to overflow in the normalization/casts. This should be axis < -rank || axis >= rank (same logic as elsewhere in the codebase).
| if (axis < -rank && axis >= rank) { | |
| if (axis < -rank || axis >= rank) { |
| const Tensor* query = context.Input(0); | ||
| const Tensor* key = context.Input(1); | ||
| const Tensor* value = context.Input(2); | ||
| const Tensor* past_state = context.Input(3); // optional | ||
| const Tensor* decay = context.Input(4); // optional | ||
| const Tensor* beta = context.Input(5); // optional | ||
|
|
||
| // Validate 3D packed inputs | ||
| const auto& q_shape = query->Shape(); | ||
| ORT_RETURN_IF(q_shape.NumDimensions() != 3, "query must be 3D (B, T, H_q*d_k)"); | ||
|
|
||
| const int batch_size = static_cast<int>(q_shape[0]); | ||
| const int seq_length = static_cast<int>(q_shape[1]); | ||
| const int q_packed_dim = static_cast<int>(q_shape[2]); | ||
| const int num_heads = kv_num_heads_; | ||
|
|
||
| ORT_RETURN_IF(q_num_heads_ != kv_num_heads_, | ||
| "GQA (q_num_heads != kv_num_heads) is not yet supported"); | ||
|
|
||
| const int head_dim_k = q_packed_dim / q_num_heads_; | ||
| ORT_RETURN_IF(q_packed_dim != head_dim_k * q_num_heads_, | ||
| "query packed dim must be divisible by q_num_heads"); | ||
|
|
||
| const int v_packed_dim = static_cast<int>(value->Shape()[2]); | ||
| const int head_dim_v = v_packed_dim / kv_num_heads_; | ||
| ORT_RETURN_IF(v_packed_dim != head_dim_v * kv_num_heads_, | ||
| "value packed dim must be divisible by kv_num_heads"); |
There was a problem hiding this comment.
ComputeInternal derives head_dim_k from query but never validates that key has the expected 3D packed shape (B, T, H*dk) or that its last dimension matches q_num_heads_*head_dim_k. The shader indexes key using packed_dk based on the query shape, so a mismatched key shape can lead to out-of-bounds reads and incorrect results. Add explicit checks for key rank and dimensions (and batch/seq match) before launching the program.
| // Workgroup size = head_dim_k (one thread per dk row) | ||
| // Ensure it's a power of 2 for tree reduction (round up) | ||
| uint32_t workgroup_size = 1; | ||
| while (workgroup_size < static_cast<uint32_t>(head_dim_k)) { | ||
| workgroup_size *= 2; | ||
| } | ||
| // Cap at GPU limits | ||
| workgroup_size = std::min(workgroup_size, static_cast<uint32_t>(256)); |
There was a problem hiding this comment.
Workgroup size is capped to 256 even when head_dim_k is larger. The shader maps local_idx to a dk row and reductions assume full dk coverage, so capping below head_dim_k silently drops rows and produces incorrect results. Either enforce/validate head_dim_k <= max_workgroup_size_x (and return NOT_IMPLEMENTED/INVALID_ARGUMENT when exceeded) or redesign the algorithm to handle dk > workgroup_size via tiling/multiple workgroups.
| // Workgroup size = head_dim_k (one thread per dk row) | |
| // Ensure it's a power of 2 for tree reduction (round up) | |
| uint32_t workgroup_size = 1; | |
| while (workgroup_size < static_cast<uint32_t>(head_dim_k)) { | |
| workgroup_size *= 2; | |
| } | |
| // Cap at GPU limits | |
| workgroup_size = std::min(workgroup_size, static_cast<uint32_t>(256)); | |
| // Validate that head_dim_k does not exceed the maximum supported workgroup size. | |
| // The shader maps one thread to each dk row and relies on full dk coverage. | |
| const uint32_t kMaxWorkgroupSizeX = 256; | |
| if (static_cast<uint32_t>(head_dim_k) > kMaxWorkgroupSizeX) { | |
| return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, | |
| "WebGPU LinearAttention currently requires head_dim_k <= ", | |
| kMaxWorkgroupSizeX, | |
| "; got head_dim_k = ", | |
| head_dim_k, | |
| ". Consider reducing head_dim_k or updating the kernel implementation."); | |
| } | |
| // Workgroup size = head_dim_k (one thread per dk row) | |
| // Ensure it's a power of 2 for tree reduction (round up) | |
| uint32_t workgroup_size = 1; | |
| while (workgroup_size < static_cast<uint32_t>(head_dim_k)) { | |
| workgroup_size *= 2; | |
| } | |
| // Cap at GPU limits (head_dim_k is already validated to be <= kMaxWorkgroupSizeX) | |
| workgroup_size = std::min(workgroup_size, kMaxWorkgroupSizeX); |
| CausalConvWithState); | ||
|
|
||
| CausalConvWithState::CausalConvWithState(const OpKernelInfo& info) | ||
| : WebGpuKernel(info) { |
There was a problem hiding this comment.
The op schema defines an ndim attribute (default 1) but the WebGPU kernel ignores it entirely (constructor only reads activation). Please read/validate ndim and return a clear NOT_IMPLEMENTED/INVALID_ARGUMENT error for ndim != 1 (or implement higher dimensions), otherwise models exporting ndim=2/3 will silently run with incorrect semantics.
| : WebGpuKernel(info) { | |
| : WebGpuKernel(info) { | |
| // Validate supported dimensionality. | |
| const int64_t ndim = info.GetAttrOrDefault<int64_t>("ndim", 1); | |
| if (ndim != 1) { | |
| ORT_THROW("CausalConvWithState WebGPU kernel only supports ndim=1, but got ndim=", ndim); | |
| } |
| const std::vector<float>& query, // (B, q_num_heads, T, dk) | ||
| const std::vector<float>& key, // (B, n_k_heads, T, dk) | ||
| const std::vector<float>& value, // (B, kv_num_heads, T, dv) | ||
| const std::vector<float>* initial_state, // (B, kv_num_heads, dk, dv) | ||
| const std::vector<float>* decay, // (B, kv_num_heads, T[, dk]) | ||
| const std::vector<float>* beta, // (B, kv_num_heads, T) | ||
| std::vector<float>& output, // (B, kv_num_heads, T, dv) | ||
| std::vector<float>& final_state) { // (B, kv_num_heads, dk, dv) |
There was a problem hiding this comment.
| const std::vector<float>& query, // (B, q_num_heads, T, dk) | |
| const std::vector<float>& key, // (B, n_k_heads, T, dk) | |
| const std::vector<float>& value, // (B, kv_num_heads, T, dv) | |
| const std::vector<float>* initial_state, // (B, kv_num_heads, dk, dv) | |
| const std::vector<float>* decay, // (B, kv_num_heads, T[, dk]) | |
| const std::vector<float>* beta, // (B, kv_num_heads, T) | |
| std::vector<float>& output, // (B, kv_num_heads, T, dv) | |
| std::vector<float>& final_state) { // (B, kv_num_heads, dk, dv) | |
| const std::vector<float>& query, // (B, q_num_heads, T, dk) | |
| const std::vector<float>& key, // (B, n_k_heads, T, dk) | |
| const std::vector<float>& value, // (B, kv_num_heads, T, dv) | |
| const std::vector<float>* initial_state, // (B, kv_num_heads, dk, dv) | |
| const std::vector<float>* decay, // (B, kv_num_heads, T[, dk]) | |
| const std::vector<float>* beta, // (B, kv_num_heads, T) | |
| std::vector<float>& output, // (B, kv_num_heads, T, dv) | |
| std::vector<float>& final_state) { // (B, kv_num_heads, dk, dv) |
moved to here: #27996