-
Notifications
You must be signed in to change notification settings - Fork 3.9k
webgpu support for LinearAttention and CausalConvWithState #27896
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
96997f7
23e744d
b72a150
d8f43fd
1baf195
c53ceaf
c8de2f4
0948d56
231c92c
e69b579
56fe4ac
8e09ff3
95400be
c598455
6c0c736
28a3ee3
3f80587
4111b95
f7711c4
52cee10
a1e9827
5c80e30
81251ff
9ebcd61
079a33b
085b274
6474e3e
15e4010
f9ff1a5
72bfe8f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,155 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #include "contrib_ops/webgpu/bert/causal_conv_with_state.h" | ||
|
|
||
| #include "core/providers/webgpu/shader_helper.h" | ||
| #include "core/providers/webgpu/webgpu_supported_types.h" | ||
| #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" | ||
|
|
||
| using namespace onnxruntime::webgpu; | ||
|
Check warning on line 10 in onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.cc
|
||
|
|
||
| namespace onnxruntime { | ||
| namespace contrib { | ||
| namespace webgpu { | ||
|
|
||
| CausalConvActivation ParseCausalConvActivation(const std::string& activation_str) { | ||
| if (activation_str == "silu" || activation_str == "swish") { | ||
| return CausalConvActivation::Silu; | ||
| } else if (activation_str == "none" || activation_str.empty()) { | ||
| return CausalConvActivation::None; | ||
| } | ||
| ORT_THROW("Unknown activation for CausalConvWithState: ", activation_str); | ||
| } | ||
|
|
||
| // ============================================================================= | ||
| // CausalConvWithState Implementation | ||
| // ============================================================================= | ||
|
|
||
// Registers CausalConvWithState (com.microsoft domain, opset 1) with the
// WebGPU execution provider. "T" is constrained to the float types the
// WebGPU backend supports.
ONNX_OPERATOR_KERNEL_EX(
    CausalConvWithState,
    kMSDomain,
    1,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", WebGpuSupportedFloatTypes()),
    CausalConvWithState);
|
|
||
| CausalConvWithState::CausalConvWithState(const OpKernelInfo& info) | ||
| : WebGpuKernel(info) { | ||
| std::string activation_str = info.GetAttrOrDefault<std::string>("activation", "none"); | ||
|
Check warning on line 40 in onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.cc
|
||
| activation_ = ParseCausalConvActivation(activation_str); | ||
| } | ||
|
|
||
// Declares the shader's I/O bindings and instantiates the WGSL template with
// the compile-time flags that select the generated code paths.
Status CausalConvWithStateProgram::GenerateShaderCode(ShaderHelper& shader) const {
  shader.AddInput("input", ShaderUsage::UseElementTypeAlias);
  shader.AddInput("weight", ShaderUsage::UseUniform);

  // Optional bindings must be declared in the same order ComputeInternal adds
  // the corresponding tensors so binding indices line up.
  if (has_bias_) {
    shader.AddInput("bias", ShaderUsage::UseUniform);
  }
  if (has_conv_state_) {
    shader.AddInput("conv_state", ShaderUsage::UseUniform);
  }

  shader.AddOutput("output", ShaderUsage::UseUniform);
  shader.AddOutput("present_state", ShaderUsage::UseUniform);

  // has_bias / has_conv_state / use_silu drive the #if branches in the template.
  return WGSL_TEMPLATE_APPLY(shader, "bert/causal_conv_with_state.wgsl.template",
                             WGSL_TEMPLATE_PARAMETER(has_bias, has_bias_),
                             WGSL_TEMPLATE_PARAMETER(has_conv_state, has_conv_state_),
                             WGSL_TEMPLATE_PARAMETER(use_silu, activation_ == CausalConvActivation::Silu));
}
|
|
||
| Status CausalConvWithState::ComputeInternal(ComputeContext& context) const { | ||
| const Tensor* input = context.Input(0); // (B, D, L) | ||
| const Tensor* weight = context.Input(1); // (D, 1, K) | ||
| const Tensor* bias = context.Input(2); // optional (D,) | ||
| const Tensor* conv_state = context.Input(3); // optional (B, D, K-1) — past_state | ||
|
|
||
| ORT_RETURN_IF(input == nullptr, "Input tensor must not be null"); | ||
| ORT_RETURN_IF(weight == nullptr, "Weight tensor must not be null"); | ||
|
|
||
| const auto& input_shape = input->Shape(); | ||
| const auto& weight_shape = weight->Shape(); | ||
|
|
||
| ORT_RETURN_IF(input_shape.NumDimensions() != 3, | ||
| "Input must be 3D (batch_size, channels, length)"); | ||
| ORT_RETURN_IF(weight_shape.NumDimensions() != 3, | ||
| "Weight must be 3D (channels, 1, kernel_size)"); | ||
|
|
||
| const int batch_size = static_cast<int>(input_shape[0]); | ||
| const int channels = static_cast<int>(input_shape[1]); | ||
| const int input_length = static_cast<int>(input_shape[2]); | ||
| const int kernel_size = static_cast<int>(weight_shape[2]); | ||
| const int state_length = kernel_size - 1; | ||
|
|
||
| ORT_RETURN_IF(static_cast<int>(weight_shape[0]) != channels, | ||
| "Weight first dim must match input channels"); | ||
| ORT_RETURN_IF(static_cast<int>(weight_shape[1]) != 1, | ||
| "Weight second dim must be 1 for depthwise convolution"); | ||
|
|
||
| if (bias != nullptr) { | ||
| ORT_RETURN_IF(bias->Shape().NumDimensions() != 1, | ||
| "Bias must be 1D"); | ||
| ORT_RETURN_IF(static_cast<int>(bias->Shape()[0]) != channels, | ||
| "Bias size must match channels"); | ||
| } | ||
|
|
||
| if (conv_state != nullptr) { | ||
| ORT_RETURN_IF(conv_state->Shape().NumDimensions() != 3, | ||
| "conv_state must be 3D (batch_size, channels, kernel_size - 1)"); | ||
| ORT_RETURN_IF(static_cast<int>(conv_state->Shape()[0]) != batch_size, | ||
| "conv_state batch_size must match input"); | ||
| ORT_RETURN_IF(static_cast<int>(conv_state->Shape()[1]) != channels, | ||
| "conv_state channels must match input"); | ||
| ORT_RETURN_IF(static_cast<int>(conv_state->Shape()[2]) != state_length, | ||
| "conv_state last dim must be kernel_size - 1"); | ||
| } | ||
|
|
||
| const bool has_bias = (bias != nullptr); | ||
| const bool has_conv_state = (conv_state != nullptr); | ||
|
|
||
| // Allocate outputs | ||
| // Output 0: (B, D, L) | ||
| Tensor* output = context.Output(0, input_shape); | ||
|
|
||
| // Output 1: present_state (B, D, K-1) | ||
| std::vector<int64_t> state_dims{batch_size, channels, state_length}; | ||
|
Check warning on line 118 in onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.cc
|
||
| Tensor* present_state = context.Output(1, TensorShape(state_dims)); | ||
|
|
||
| if (input_length == 0) { | ||
| return Status::OK(); | ||
| } | ||
|
Comment on lines
+113
to
+123
|
||
|
|
||
| // Create and run the shader program | ||
| CausalConvWithStateProgram program{activation_, has_bias, has_conv_state, kernel_size}; | ||
|
|
||
| uint32_t output_size = static_cast<uint32_t>(batch_size * channels * input_length); | ||
|
|
||
| program.AddInput({input, ProgramTensorMetadataDependency::Type}) | ||
| .AddInput({weight, ProgramTensorMetadataDependency::None}); | ||
|
|
||
| if (has_bias) { | ||
| program.AddInput({bias, ProgramTensorMetadataDependency::None}); | ||
| } | ||
| if (has_conv_state) { | ||
| program.AddInput({conv_state, ProgramTensorMetadataDependency::None}); | ||
| } | ||
|
|
||
| program.AddOutput({output, ProgramTensorMetadataDependency::None}) | ||
| .AddOutput({present_state, ProgramTensorMetadataDependency::None}) | ||
| .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) | ||
| .AddUniformVariable({static_cast<uint32_t>(batch_size)}) | ||
| .AddUniformVariable({static_cast<uint32_t>(channels)}) | ||
| .AddUniformVariable({static_cast<uint32_t>(input_length)}) | ||
| .AddUniformVariable({static_cast<uint32_t>(kernel_size)}) | ||
| .AddUniformVariable({static_cast<uint32_t>(state_length)}) | ||
| .AddUniformVariable({output_size}); | ||
|
|
||
| return context.RunProgram(program); | ||
| } | ||
|
|
||
| } // namespace webgpu | ||
| } // namespace contrib | ||
| } // namespace onnxruntime | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,66 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <string> | ||
|
|
||
| #include "core/providers/webgpu/program.h" | ||
| #include "core/providers/webgpu/webgpu_kernel.h" | ||
|
|
||
| namespace onnxruntime { | ||
| namespace contrib { | ||
| namespace webgpu { | ||
|
|
||
| using namespace onnxruntime::webgpu; | ||
|
Check warning on line 15 in onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.h
|
||
| using onnxruntime::webgpu::ComputeContext; | ||
|
|
||
// Activation mode for CausalConvWithState, applied to each output element.
enum class CausalConvActivation {
  None,  // identity
  Silu,  // x * sigmoid(x); the "silu" and "swish" attribute values both map here
};

// Maps the "activation" attribute string ("none"/"", "silu", "swish") to the
// enum above; throws for any other value.
CausalConvActivation ParseCausalConvActivation(const std::string& activation_str);
|
|
||
// Program for CausalConvWithState. Each shader invocation computes one output
// element; the constructor flags select the bias / state-history / SiLU code
// paths at shader-generation time.
class CausalConvWithStateProgram final : public Program<CausalConvWithStateProgram> {
 public:
  CausalConvWithStateProgram(CausalConvActivation activation, bool has_bias, bool has_conv_state,
                             int kernel_size)
      : Program{"CausalConvWithState"},
        activation_(activation),
        has_bias_(has_bias),
        has_conv_state_(has_conv_state),
        kernel_size_(kernel_size) {}

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  // Order must match the AddUniformVariable calls in ComputeInternal.
  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
      {"batch_size", ProgramUniformVariableDataType::Uint32},
      {"channels", ProgramUniformVariableDataType::Uint32},
      {"input_length", ProgramUniformVariableDataType::Uint32},
      {"kernel_size", ProgramUniformVariableDataType::Uint32},
      {"state_length", ProgramUniformVariableDataType::Uint32},
      {"output_size", ProgramUniformVariableDataType::Uint32});

 private:
  CausalConvActivation activation_;   // selects the use_silu template branch
  bool has_bias_;                     // optional bias input present
  bool has_conv_state_;               // optional past-state input present
  [[maybe_unused]] int kernel_size_;  // stored but not read after construction
};
|
|
||
| // Kernel for CausalConvWithState | ||
| class CausalConvWithState final : public WebGpuKernel { | ||
| public: | ||
| CausalConvWithState(const OpKernelInfo& info); | ||
| Status ComputeInternal(ComputeContext& context) const override; | ||
|
|
||
| private: | ||
| CausalConvActivation activation_; | ||
| }; | ||
|
|
||
| } // namespace webgpu | ||
| } // namespace contrib | ||
| } // namespace onnxruntime | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,118 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
#param has_bias
#param has_conv_state
#param use_silu

#use guardAgainstOutOfBoundsWorkgroupSizes

#if use_silu
// SiLU (a.k.a. swish): x * sigmoid(x), written as x / (1 + e^-x).
fn silu(x: input_element_t) -> input_element_t {
  return x / (1.0 + exp(-x));
}
#endif

$MAIN {
  guardAgainstOutOfBoundsWorkgroupSizes(uniforms.output_size);

  let batch_size = uniforms.batch_size;  // not read below; bound for symmetry with the other uniforms
  let channels = uniforms.channels;
  let input_length = uniforms.input_length;
  let kernel_size = uniforms.kernel_size;
  let state_length = uniforms.state_length;  // = kernel_size - 1

  // Decompose the flat thread index into (batch, channel, position) for the
  // row-major (B, D, L) layout.
  let pos = global_idx % input_length;
  let bc_idx = global_idx / input_length;
  let batch_idx = bc_idx / channels;
  let channel_idx = bc_idx % channels;

  // Perform depthwise causal convolution for this (batch, channel, pos).
  // The convolution window looks back kernel_size-1 positions.
  // With conv_state providing the history before position 0, the
  // "virtual" input is: [conv_state[0..state_length-1], input[0..L-1]]
  //
  // For output position pos:
  //   output[pos] = sum_{j=0}^{kernel_size-1} weight[j] * virtual_input[pos + j]
  // where virtual_input is state_length positions of conv_state
  // followed by input_length positions of input.

  var acc: input_element_t = 0.0;

  // Weight layout: (D, 1, K) -> channel_idx * kernel_size + j
  let weight_base = channel_idx * kernel_size;

  for (var j: u32 = 0; j < kernel_size; j = j + 1) {
    // virtual_pos is the position in the concatenated [conv_state, input]
    let virtual_pos = pos + j;

    var val: input_element_t = 0.0;

#if has_conv_state
    if (virtual_pos < state_length) {
      // Read from conv_state: (B, D, state_length)
      let state_idx = (batch_idx * channels + channel_idx) * state_length + virtual_pos;
      val = conv_state[state_idx];
    } else {
      // Read from input: (B, D, L)
      let input_pos = virtual_pos - state_length;
      let input_idx = (batch_idx * channels + channel_idx) * input_length + input_pos;
      val = input[input_idx];
    }
#else
    // No conv_state: pad with zeros for positions before the input
    if (virtual_pos >= state_length) {
      let input_pos = virtual_pos - state_length;
      let input_idx = (batch_idx * channels + channel_idx) * input_length + input_pos;
      val = input[input_idx];
    }
#endif

    let w = weight[weight_base + j];
    acc = acc + val * w;
  }

#if has_bias
  acc = acc + bias[channel_idx];
#endif

#if use_silu
  acc = silu(acc);
#endif

  // Write output: (B, D, L)
  let out_idx = (batch_idx * channels + channel_idx) * input_length + pos;
  output[out_idx] = acc;

  // Write present_state: the last (kernel_size - 1) elements from the
  // virtual input [conv_state, input]. We only write present_state once
  // per (batch, channel), using the thread at pos == 0. The host guarantees
  // input_length >= 1, so such a thread always exists.
  if (pos == 0u) {
    for (var s: u32 = 0; s < state_length; s = s + 1) {
      var state_val: input_element_t = 0.0;
      // total_len = state_length + input_length
      // We want virtual_input[total_len - state_length + s] = virtual_input[input_length + s]
      let vp = input_length + s;

#if has_conv_state
      // When input_length < state_length, the tail still overlaps the old
      // state, so vp can fall in the conv_state region.
      if (vp < state_length) {
        let si = (batch_idx * channels + channel_idx) * state_length + vp;
        state_val = conv_state[si];
      } else {
        let ip = vp - state_length;
        let ii = (batch_idx * channels + channel_idx) * input_length + ip;
        state_val = input[ii];
      }
#else
      // Without conv_state the history before the input is implicitly zero.
      if (vp >= state_length) {
        let ip = vp - state_length;
        let ii = (batch_idx * channels + channel_idx) * input_length + ip;
        state_val = input[ii];
      }
#endif

      let ps_idx = (batch_idx * channels + channel_idx) * state_length + s;
      present_state[ps_idx] = state_val;
    }
  }
} // MAIN
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The op schema defines an `ndim` attribute (default 1) but the WebGPU kernel ignores it entirely (the constructor only reads `activation`). Please read and validate `ndim`, returning a clear NOT_IMPLEMENTED/INVALID_ARGUMENT error for `ndim != 1` (or implement higher dimensions); otherwise models exported with `ndim=2/3` will silently run with incorrect semantics.