-
Notifications
You must be signed in to change notification settings - Fork 3.9k
webgpu support for qwen3.5 #27996
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
webgpu support for qwen3.5 #27996
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,158 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #include "contrib_ops/webgpu/bert/causal_conv_with_state.h" | ||
|
|
||
| #include "core/providers/webgpu/shader_helper.h" | ||
| #include "core/providers/webgpu/webgpu_supported_types.h" | ||
| #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" | ||
|
|
||
| using namespace onnxruntime::webgpu; | ||
|
Check warning on line 10 in onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.cc
|
||
|
|
||
| namespace onnxruntime { | ||
| namespace contrib { | ||
| namespace webgpu { | ||
|
|
||
| CausalConvActivation ParseCausalConvActivation(const std::string& activation_str) { | ||
| if (activation_str == "silu" || activation_str == "swish") { | ||
| return CausalConvActivation::Silu; | ||
| } else if (activation_str == "none" || activation_str.empty()) { | ||
| return CausalConvActivation::None; | ||
| } | ||
| return CausalConvActivation::Invalid; | ||
| } | ||
|
|
||
| // ============================================================================= | ||
| // CausalConvWithState Implementation | ||
| // ============================================================================= | ||
|
|
||
// Registers CausalConvWithState with the WebGPU execution provider for the
// com.microsoft (kMSDomain) domain, opset version 1+, with the "T" type
// constraint limited to the float types the WebGPU EP supports.
ONNX_OPERATOR_KERNEL_EX(
    CausalConvWithState,
    kMSDomain,
    1,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", WebGpuSupportedFloatTypes()),
    CausalConvWithState);
|
|
||
| CausalConvWithState::CausalConvWithState(const OpKernelInfo& info) | ||
| : WebGpuKernel(info) { | ||
| std::string activation_str = info.GetAttrOrDefault<std::string>("activation", "none"); | ||
|
Check warning on line 40 in onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.cc
|
||
| activation_ = ParseCausalConvActivation(activation_str); | ||
| ORT_ENFORCE(info.GetAttr<int64_t>("ndim", &ndim_).IsOK(), "Attribute 'ndim' is required"); | ||
| } | ||
|
|
||
// Declares the shader's I/O bindings and instantiates the WGSL template with
// the compile-time flags that specialize this program.
// NOTE: the AddInput/AddOutput order here defines binding slots and must stay
// in sync with the AddInput/AddOutput calls in ComputeInternal.
Status CausalConvWithStateProgram::GenerateShaderCode(ShaderHelper& shader) const {
  shader.AddInput("input", ShaderUsage::UseElementTypeAlias);
  shader.AddInput("weight", ShaderUsage::UseUniform);

  // Optional inputs are only declared when present so the binding layout
  // matches what ComputeInternal actually binds.
  if (has_bias_) {
    shader.AddInput("bias", ShaderUsage::UseUniform);
  }
  if (has_conv_state_) {
    shader.AddInput("conv_state", ShaderUsage::UseUniform);
  }

  shader.AddOutput("output", ShaderUsage::UseUniform);
  shader.AddOutput("present_state", ShaderUsage::UseUniform);

  // has_bias / has_conv_state / use_silu select the #if branches inside the
  // template (see bert/causal_conv_with_state.wgsl.template).
  return WGSL_TEMPLATE_APPLY(shader, "bert/causal_conv_with_state.wgsl.template",
                             WGSL_TEMPLATE_PARAMETER(has_bias, has_bias_),
                             WGSL_TEMPLATE_PARAMETER(has_conv_state, has_conv_state_),
                             WGSL_TEMPLATE_PARAMETER(use_silu, activation_ == CausalConvActivation::Silu));
}
|
|
||
| Status CausalConvWithState::ComputeInternal(ComputeContext& context) const { | ||
| const Tensor* input = context.Input(0); // (B, D, L) | ||
| const Tensor* weight = context.Input(1); // (D, 1, K) | ||
| const Tensor* bias = context.Input(2); // optional (D,) | ||
| const Tensor* conv_state = context.Input(3); // optional (B, D, K-1) — past_state | ||
|
|
||
| ORT_RETURN_IF(activation_ == CausalConvActivation::Invalid, "Invalid activation type"); | ||
| ORT_RETURN_IF(ndim_ != 1, "Only 1D convolution is supported"); | ||
| const auto& input_shape = input->Shape(); | ||
| const auto& weight_shape = weight->Shape(); | ||
|
|
||
| ORT_RETURN_IF(input_shape.NumDimensions() != 3, | ||
| "Input must be 3D (batch_size, channels, length)"); | ||
| ORT_RETURN_IF(weight_shape.NumDimensions() != 3, | ||
| "Weight must be 3D (channels, 1, kernel_size)"); | ||
|
|
||
| const int64_t batch_size = input_shape[0]; | ||
| const int64_t channels = input_shape[1]; | ||
| const int64_t input_length = input_shape[2]; | ||
| const int64_t kernel_size = weight_shape[2]; | ||
| const int64_t state_length = kernel_size - 1; | ||
|
|
||
|
guschmue marked this conversation as resolved.
|
||
| ORT_RETURN_IF(weight_shape[0] != channels, "Weight first dim must match input channels"); | ||
| ORT_RETURN_IF(weight_shape[1] != 1, "Weight second dim must be 1 for depthwise convolution"); | ||
|
|
||
| if (bias != nullptr) { | ||
| ORT_RETURN_IF(bias->Shape().NumDimensions() != 1, "Bias must be 1D"); | ||
| ORT_RETURN_IF(bias->Shape()[0] != channels, "Bias size must match channels"); | ||
| } | ||
|
|
||
| if (conv_state != nullptr) { | ||
| ORT_RETURN_IF(conv_state->Shape().NumDimensions() != 3, | ||
| "conv_state must be 3D (batch_size, channels, kernel_size - 1)"); | ||
| ORT_RETURN_IF(conv_state->Shape()[0] != batch_size, | ||
| "conv_state batch_size must match input"); | ||
| ORT_RETURN_IF(conv_state->Shape()[1] != channels, | ||
| "conv_state channels must match input"); | ||
| ORT_RETURN_IF(conv_state->Shape()[2] != state_length, | ||
| "conv_state last dim must be kernel_size - 1"); | ||
| } | ||
|
|
||
| const bool has_bias = (bias != nullptr); | ||
| const bool has_conv_state = (conv_state != nullptr); | ||
|
|
||
| // Allocate outputs | ||
| // Output 0: (B, D, L) | ||
| Tensor* output = context.Output(0, input_shape); | ||
|
guschmue marked this conversation as resolved.
|
||
|
|
||
| // Output 1: present_state (B, D, K-1) | ||
| std::vector<int64_t> state_dims{batch_size, channels, state_length}; | ||
|
Check warning on line 114 in onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.cc
|
||
| Tensor* present_state = context.Output(1, TensorShape(state_dims)); | ||
|
|
||
| if (input_shape.Size() == 0) { | ||
| if (has_conv_state) { | ||
| ORT_RETURN_IF_ERROR(context.CopyTensor(*conv_state, *present_state)); | ||
| } else { | ||
| context.FillZero(*present_state); | ||
| return Status::OK(); | ||
| } | ||
| } | ||
|
|
||
| // Create and run the shader program | ||
| CausalConvWithStateProgram program{activation_, has_bias, has_conv_state}; | ||
|
|
||
| uint32_t output_size = static_cast<uint32_t>(batch_size * channels * input_length); | ||
|
|
||
| program.CacheHint(has_bias, has_conv_state, kernel_size, static_cast<int>(activation_)); | ||
|
|
||
| program.AddInput({input, ProgramTensorMetadataDependency::Type}) | ||
| .AddInput({weight, ProgramTensorMetadataDependency::None}); | ||
|
|
||
| if (has_bias) { | ||
| program.AddInput({bias, ProgramTensorMetadataDependency::None}); | ||
| } | ||
| if (has_conv_state) { | ||
| program.AddInput({conv_state, ProgramTensorMetadataDependency::None}); | ||
| } | ||
|
|
||
| program.AddOutput({output, ProgramTensorMetadataDependency::None}) | ||
| .AddOutput({present_state, ProgramTensorMetadataDependency::None}) | ||
| .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) | ||
| .AddUniformVariable({static_cast<uint32_t>(batch_size)}) | ||
|
guschmue marked this conversation as resolved.
|
||
| .AddUniformVariable({static_cast<uint32_t>(channels)}) | ||
| .AddUniformVariable({static_cast<uint32_t>(input_length)}) | ||
| .AddUniformVariable({static_cast<uint32_t>(kernel_size)}) | ||
| .AddUniformVariable({static_cast<uint32_t>(state_length)}) | ||
| .AddUniformVariable({output_size}); | ||
|
|
||
| return context.RunProgram(program); | ||
| } | ||
|
|
||
| } // namespace webgpu | ||
| } // namespace contrib | ||
| } // namespace onnxruntime | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,65 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <string> | ||
|
|
||
| #include "core/providers/webgpu/program.h" | ||
| #include "core/providers/webgpu/webgpu_kernel.h" | ||
|
|
||
| namespace onnxruntime { | ||
| namespace contrib { | ||
| namespace webgpu { | ||
|
|
||
| using namespace onnxruntime::webgpu; | ||
|
Check warning on line 15 in onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.h
|
||
| using onnxruntime::webgpu::ComputeContext; | ||
|
|
||
// Activation applied element-wise to the convolution output of
// CausalConvWithState.
enum class CausalConvActivation {
  Invalid,  // unrecognized attribute string; rejected at compute time
  None,     // identity — no activation applied
  Silu      // SiLU: x * sigmoid(x); the "swish" spelling also maps here
};

// Maps the "activation" node-attribute string ("silu"/"swish"/"none"/"") to
// the enum; any other value yields Invalid.
CausalConvActivation ParseCausalConvActivation(const std::string& activation_str);
|
|
||
// Shader program for CausalConvWithState. The three constructor flags are
// compile-time specializations baked into the generated WGSL (see
// GenerateShaderCode), so they must also participate in the cache hint.
class CausalConvWithStateProgram final : public Program<CausalConvWithStateProgram> {
 public:
  CausalConvWithStateProgram(CausalConvActivation activation, bool has_bias, bool has_conv_state)
      : Program{"CausalConvWithState"},
        activation_(activation),
        has_bias_(has_bias),
        has_conv_state_(has_conv_state) {}

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  // Uniform declaration order defines the layout; AddUniformVariable calls in
  // ComputeInternal must follow this same order.
  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
      {"batch_size", ProgramUniformVariableDataType::Uint32},
      {"channels", ProgramUniformVariableDataType::Uint32},
      {"input_length", ProgramUniformVariableDataType::Uint32},
      {"kernel_size", ProgramUniformVariableDataType::Uint32},
      {"state_length", ProgramUniformVariableDataType::Uint32},
      {"output_size", ProgramUniformVariableDataType::Uint32});

 private:
  CausalConvActivation activation_;  // activation baked into the shader
  bool has_bias_;                    // optional bias input is bound
  bool has_conv_state_;              // optional past-state input is bound
};
|
|
||
| // Kernel for CausalConvWithState | ||
| class CausalConvWithState final : public WebGpuKernel { | ||
| public: | ||
| CausalConvWithState(const OpKernelInfo& info); | ||
| Status ComputeInternal(ComputeContext& context) const override; | ||
|
|
||
| private: | ||
| CausalConvActivation activation_; | ||
| int64_t ndim_; | ||
| }; | ||
|
|
||
| } // namespace webgpu | ||
| } // namespace contrib | ||
| } // namespace onnxruntime | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,118 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Depthwise causal 1-D convolution with state carry-over.
// One invocation per output element of the (B, D, L) result.
//
// Template parameters (compile-time specialization):
//   has_bias       — a 1-D bias input is bound and added to each output.
//   has_conv_state — a past-state input supplies history before position 0;
//                    otherwise the history is implicitly zero-padded.
//   use_silu       — apply SiLU (x * sigmoid(x)) to each output element.
#param has_bias
#param has_conv_state
#param use_silu

#use guardAgainstOutOfBoundsWorkgroupSizes

#if use_silu
// SiLU / swish activation: x * sigmoid(x), written as x / (1 + e^-x).
fn silu(x: input_element_t) -> input_element_t {
  return x / (1.0 + exp(-x));
}
#endif

$MAIN {
  // Threads beyond output_size (= B * D * L) exit immediately.
  guardAgainstOutOfBoundsWorkgroupSizes(uniforms.output_size);

  let batch_size = uniforms.batch_size;
  let channels = uniforms.channels;
  let input_length = uniforms.input_length;
  let kernel_size = uniforms.kernel_size;
  let state_length = uniforms.state_length; // = kernel_size - 1

  // Decompose the flat index into (batch, channel, position) for a
  // row-major (B, D, L) layout.
  let pos = global_idx % input_length;
  let bc_idx = global_idx / input_length;
  let batch_idx = bc_idx / channels;
  let channel_idx = bc_idx % channels;

  // Perform depthwise causal convolution for this (batch, channel, pos).
  // The convolution window looks back kernel_size-1 positions.
  // With conv_state providing the history before position 0, the
  // "virtual" input is: [conv_state[0..state_length-1], input[0..L-1]]
  //
  // For output position pos:
  //   output[pos] = sum_{j=0}^{kernel_size-1} weight[j] * virtual_input[pos + j]
  // where virtual_input is state_length positions of conv_state
  // followed by input_length positions of input.

  var acc: input_element_t = 0.0;

  // Weight layout: (D, 1, K) -> channel_idx * kernel_size + j
  let weight_base = channel_idx * kernel_size;

  for (var j: u32 = 0; j < kernel_size; j = j + 1) {
    // virtual_pos is the position in the concatenated [conv_state, input]
    let virtual_pos = pos + j;

    var val: input_element_t = 0.0;

#if has_conv_state
    if (virtual_pos < state_length) {
      // Read from conv_state: (B, D, state_length)
      let state_idx = (batch_idx * channels + channel_idx) * state_length + virtual_pos;
      val = conv_state[state_idx];
    } else {
      // Read from input: (B, D, L)
      let input_pos = virtual_pos - state_length;
      let input_idx = (batch_idx * channels + channel_idx) * input_length + input_pos;
      val = input[input_idx];
    }
#else
    // No conv_state: pad with zeros for positions before the input
    // (val stays 0.0 when virtual_pos falls in the missing history).
    if (virtual_pos >= state_length) {
      let input_pos = virtual_pos - state_length;
      let input_idx = (batch_idx * channels + channel_idx) * input_length + input_pos;
      val = input[input_idx];
    }
#endif

    let w = weight[weight_base + j];
    acc = acc + val * w;
  }

#if has_bias
  acc = acc + bias[channel_idx];
#endif

#if use_silu
  acc = silu(acc);
#endif

  // Write output: (B, D, L)
  let out_idx = (batch_idx * channels + channel_idx) * input_length + pos;
  output[out_idx] = acc;

  // Write present_state: the last (kernel_size - 1) elements from the
  // virtual input [conv_state, input]. We only write present_state once
  // per (batch, channel), using the thread at pos == 0.
  // NOTE(review): this serial loop per (batch, channel) presumably stays cheap
  // because state_length is small (kernel_size - 1) — confirm for large K.
  if (pos == 0u) {
    for (var s: u32 = 0; s < state_length; s = s + 1) {
      var state_val: input_element_t = 0.0;
      // total_len = state_length + input_length
      // We want virtual_input[total_len - state_length + s] = virtual_input[input_length + s]
      let vp = input_length + s;

#if has_conv_state
      // When input_length < state_length the tail of the old state is still
      // part of the new state, so vp may land inside conv_state.
      if (vp < state_length) {
        let si = (batch_idx * channels + channel_idx) * state_length + vp;
        state_val = conv_state[si];
      } else {
        let ip = vp - state_length;
        let ii = (batch_idx * channels + channel_idx) * input_length + ip;
        state_val = input[ii];
      }
#else
      // Without conv_state the missing history is zero; only positions that
      // map into the input are read (state_val stays 0.0 otherwise).
      if (vp >= state_length) {
        let ip = vp - state_length;
        let ii = (batch_idx * channels + channel_idx) * input_length + ip;
        state_val = input[ii];
      }
#endif

      let ps_idx = (batch_idx * channels + channel_idx) * state_length + s;
      present_state[ps_idx] = state_val;
    }
  }
} // MAIN
Uh oh!
There was an error while loading. Please reload this page.