diff --git a/onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.cc b/onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.cc new file mode 100644 index 0000000000000..229b3777e62d1 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.cc @@ -0,0 +1,158 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/webgpu/bert/causal_conv_with_state.h" + +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" + +using namespace onnxruntime::webgpu; + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +CausalConvActivation ParseCausalConvActivation(const std::string& activation_str) { + if (activation_str == "silu" || activation_str == "swish") { + return CausalConvActivation::Silu; + } else if (activation_str == "none" || activation_str.empty()) { + return CausalConvActivation::None; + } + return CausalConvActivation::Invalid; +} + +// ============================================================================= +// CausalConvWithState Implementation +// ============================================================================= + +ONNX_OPERATOR_KERNEL_EX( + CausalConvWithState, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()), + CausalConvWithState); + +CausalConvWithState::CausalConvWithState(const OpKernelInfo& info) + : WebGpuKernel(info) { + std::string activation_str = info.GetAttrOrDefault("activation", "none"); + activation_ = ParseCausalConvActivation(activation_str); + ORT_ENFORCE(info.GetAttr("ndim", &ndim_).IsOK(), "Attribute 'ndim' is required"); +} + +Status CausalConvWithStateProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("input", ShaderUsage::UseElementTypeAlias); + shader.AddInput("weight", ShaderUsage::UseUniform); + + if (has_bias_) { + shader.AddInput("bias", ShaderUsage::UseUniform); + } + if (has_conv_state_) { + shader.AddInput("conv_state", ShaderUsage::UseUniform); + } + + shader.AddOutput("output", ShaderUsage::UseUniform); + shader.AddOutput("present_state", ShaderUsage::UseUniform); + + return WGSL_TEMPLATE_APPLY(shader, "bert/causal_conv_with_state.wgsl.template", + WGSL_TEMPLATE_PARAMETER(has_bias, has_bias_), + WGSL_TEMPLATE_PARAMETER(has_conv_state, has_conv_state_), + WGSL_TEMPLATE_PARAMETER(use_silu, activation_ == CausalConvActivation::Silu)); +} + +Status CausalConvWithState::ComputeInternal(ComputeContext& context) const { + const Tensor* input = context.Input(0); // (B, D, L) + const Tensor* weight = context.Input(1); // (D, 1, K) + const Tensor* bias = context.Input(2); // optional (D,) + const Tensor* conv_state = context.Input(3); // optional (B, D, K-1) — past_state + + ORT_RETURN_IF(activation_ == CausalConvActivation::Invalid, "Invalid activation type"); + ORT_RETURN_IF(ndim_ != 1, "Only 1D convolution is supported"); + const auto& input_shape = input->Shape(); + const auto& weight_shape = weight->Shape(); + + ORT_RETURN_IF(input_shape.NumDimensions() != 3, + "Input must be 3D (batch_size, channels, length)"); + ORT_RETURN_IF(weight_shape.NumDimensions() != 3, + "Weight must be 3D (channels, 1, kernel_size)"); + + const int64_t batch_size = input_shape[0]; + const int64_t channels = input_shape[1]; + const int64_t input_length = input_shape[2]; + const int64_t kernel_size = weight_shape[2]; + const int64_t 
state_length = kernel_size - 1; + + ORT_RETURN_IF(weight_shape[0] != channels, "Weight first dim must match input channels"); + ORT_RETURN_IF(weight_shape[1] != 1, "Weight second dim must be 1 for depthwise convolution"); + + if (bias != nullptr) { + ORT_RETURN_IF(bias->Shape().NumDimensions() != 1, "Bias must be 1D"); + ORT_RETURN_IF(bias->Shape()[0] != channels, "Bias size must match channels"); + } + + if (conv_state != nullptr) { + ORT_RETURN_IF(conv_state->Shape().NumDimensions() != 3, + "conv_state must be 3D (batch_size, channels, kernel_size - 1)"); + ORT_RETURN_IF(conv_state->Shape()[0] != batch_size, + "conv_state batch_size must match input"); + ORT_RETURN_IF(conv_state->Shape()[1] != channels, + "conv_state channels must match input"); + ORT_RETURN_IF(conv_state->Shape()[2] != state_length, + "conv_state last dim must be kernel_size - 1"); + } + + const bool has_bias = (bias != nullptr); + const bool has_conv_state = (conv_state != nullptr); + + // Allocate outputs + // Output 0: (B, D, L) + Tensor* output = context.Output(0, input_shape); + + // Output 1: present_state (B, D, K-1) + std::vector<int64_t> state_dims{batch_size, channels, state_length}; + Tensor* present_state = context.Output(1, TensorShape(state_dims)); + + if (input_shape.Size() == 0) { + if (has_conv_state) { + ORT_RETURN_IF_ERROR(context.CopyTensor(*conv_state, *present_state)); + } else { + context.FillZero(*present_state); + } + return Status::OK(); + } + + // Create and run the shader program + CausalConvWithStateProgram program{activation_, has_bias, has_conv_state}; + + uint32_t output_size = static_cast<uint32_t>(batch_size * channels * input_length); + + program.CacheHint(has_bias, has_conv_state, kernel_size, static_cast<int>(activation_)); + + program.AddInput({input, ProgramTensorMetadataDependency::Type}) + .AddInput({weight, ProgramTensorMetadataDependency::None}); + + if (has_bias) { + program.AddInput({bias, ProgramTensorMetadataDependency::None}); + } + if (has_conv_state) { + program.AddInput({conv_state, ProgramTensorMetadataDependency::None}); + } + + program.AddOutput({output, ProgramTensorMetadataDependency::None}) + .AddOutput({present_state, ProgramTensorMetadataDependency::None}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariable({static_cast<uint32_t>(batch_size)}) + .AddUniformVariable({static_cast<uint32_t>(channels)}) + .AddUniformVariable({static_cast<uint32_t>(input_length)}) + .AddUniformVariable({static_cast<uint32_t>(kernel_size)}) + .AddUniformVariable({static_cast<uint32_t>(state_length)}) + .AddUniformVariable({output_size}); + + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.h b/onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.h new file mode 100644 index 0000000000000..a87412bdc9070 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.h @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License.
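For reference, the computation performed by the kernel above can be sketched on the CPU as follows. This is a minimal sketch under the shape assumptions validated in ComputeInternal; the function name CausalConvWithStateRef and its std::vector-based signature are illustrative only and are not part of this change.

// Reference for the CausalConvWithState semantics (illustrative only).
// input: (B, D, L), weight: (D, 1, K) treated as (D, K), bias: (D) optional,
// conv_state: (B, D, K-1) optional. Produces output (B, D, L) and present_state (B, D, K-1).
#include <cmath>
#include <cstdint>
#include <vector>

void CausalConvWithStateRef(const std::vector<float>& input, const std::vector<float>& weight,
                            const std::vector<float>* bias, const std::vector<float>* conv_state,
                            int64_t B, int64_t D, int64_t L, int64_t K, bool silu,
                            std::vector<float>& output, std::vector<float>& present_state) {
  const int64_t S = K - 1;  // state length
  output.assign(static_cast<size_t>(B * D * L), 0.0f);
  present_state.assign(static_cast<size_t>(B * D * S), 0.0f);
  for (int64_t b = 0; b < B; ++b) {
    for (int64_t d = 0; d < D; ++d) {
      // Virtual input = [conv_state (S values, zeros if absent), input (L values)].
      std::vector<float> virt(static_cast<size_t>(S + L), 0.0f);
      for (int64_t s = 0; s < S; ++s) {
        virt[s] = conv_state ? (*conv_state)[(b * D + d) * S + s] : 0.0f;
      }
      for (int64_t t = 0; t < L; ++t) {
        virt[S + t] = input[(b * D + d) * L + t];
      }
      // Depthwise causal convolution over the virtual input.
      for (int64_t pos = 0; pos < L; ++pos) {
        float acc = bias ? (*bias)[d] : 0.0f;
        for (int64_t j = 0; j < K; ++j) {
          acc += weight[d * K + j] * virt[pos + j];
        }
        if (silu) acc = acc / (1.0f + std::exp(-acc));
        output[(b * D + d) * L + pos] = acc;
      }
      // present_state = last K-1 entries of the virtual input.
      for (int64_t s = 0; s < S; ++s) {
        present_state[(b * D + d) * S + s] = virt[L + s];
      }
    }
  }
}

When conv_state is absent the history is implicitly zero, which matches the #else branch of the shader template below.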
+ +#pragma once + +#include + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; +using onnxruntime::webgpu::ComputeContext; + +// Activation mode for CausalConvWithState +enum class CausalConvActivation { + Invalid, + None, + Silu +}; + +CausalConvActivation ParseCausalConvActivation(const std::string& activation_str); + +// Program for CausalConvWithState +class CausalConvWithStateProgram final : public Program { + public: + CausalConvWithStateProgram(CausalConvActivation activation, bool has_bias, bool has_conv_state) + : Program{"CausalConvWithState"}, + activation_(activation), + has_bias_(has_bias), + has_conv_state_(has_conv_state) {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"batch_size", ProgramUniformVariableDataType::Uint32}, + {"channels", ProgramUniformVariableDataType::Uint32}, + {"input_length", ProgramUniformVariableDataType::Uint32}, + {"kernel_size", ProgramUniformVariableDataType::Uint32}, + {"state_length", ProgramUniformVariableDataType::Uint32}, + {"output_size", ProgramUniformVariableDataType::Uint32}); + + private: + CausalConvActivation activation_; + bool has_bias_; + bool has_conv_state_; +}; + +// Kernel for CausalConvWithState +class CausalConvWithState final : public WebGpuKernel { + public: + CausalConvWithState(const OpKernelInfo& info); + Status ComputeInternal(ComputeContext& context) const override; + + private: + CausalConvActivation activation_; + int64_t ndim_; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.wgsl.template new file mode 100644 index 0000000000000..e109f167d27b1 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.wgsl.template @@ -0,0 +1,118 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#param has_bias +#param has_conv_state +#param use_silu + +#use guardAgainstOutOfBoundsWorkgroupSizes + +#if use_silu +fn silu(x: input_element_t) -> input_element_t { + return x / (1.0 + exp(-x)); +} +#endif + +$MAIN { + guardAgainstOutOfBoundsWorkgroupSizes(uniforms.output_size); + + let batch_size = uniforms.batch_size; + let channels = uniforms.channels; + let input_length = uniforms.input_length; + let kernel_size = uniforms.kernel_size; + let state_length = uniforms.state_length; // = kernel_size - 1 + + let pos = global_idx % input_length; + let bc_idx = global_idx / input_length; + let batch_idx = bc_idx / channels; + let channel_idx = bc_idx % channels; + + // Perform depthwise causal convolution for this (batch, channel, pos). + // The convolution window looks back kernel_size-1 positions. + // With conv_state providing the history before position 0, the + // "virtual" input is: [conv_state[0..state_length-1], input[0..L-1]] + // + // For output position pos: + // output[pos] = sum_{j=0}^{kernel_size-1} weight[j] * virtual_input[pos + j] + // where virtual_input is state_length positions of conv_state + // followed by input_length positions of input. 
+ + var acc: input_element_t = 0.0; + + // Weight layout: (D, 1, K) -> channel_idx * kernel_size + j + let weight_base = channel_idx * kernel_size; + + for (var j: u32 = 0; j < kernel_size; j = j + 1) { + // virtual_pos is the position in the concatenated [conv_state, input] + let virtual_pos = pos + j; + + var val: input_element_t = 0.0; + +#if has_conv_state + if (virtual_pos < state_length) { + // Read from conv_state: (B, D, state_length) + let state_idx = (batch_idx * channels + channel_idx) * state_length + virtual_pos; + val = conv_state[state_idx]; + } else { + // Read from input: (B, D, L) + let input_pos = virtual_pos - state_length; + let input_idx = (batch_idx * channels + channel_idx) * input_length + input_pos; + val = input[input_idx]; + } +#else + // No conv_state: pad with zeros for positions before the input + if (virtual_pos >= state_length) { + let input_pos = virtual_pos - state_length; + let input_idx = (batch_idx * channels + channel_idx) * input_length + input_pos; + val = input[input_idx]; + } +#endif + + let w = weight[weight_base + j]; + acc = acc + val * w; + } + +#if has_bias + acc = acc + bias[channel_idx]; +#endif + +#if use_silu + acc = silu(acc); +#endif + + // Write output: (B, D, L) + let out_idx = (batch_idx * channels + channel_idx) * input_length + pos; + output[out_idx] = acc; + + // Write present_state: the last (kernel_size - 1) elements from the + // virtual input [conv_state, input]. We only write present_state once + // per (batch, channel), using the thread at pos == 0. + if (pos == 0u) { + for (var s: u32 = 0; s < state_length; s = s + 1) { + var state_val: input_element_t = 0.0; + // total_len = state_length + input_length + // We want virtual_input[total_len - state_length + s] = virtual_input[input_length + s] + let vp = input_length + s; + +#if has_conv_state + if (vp < state_length) { + let si = (batch_idx * channels + channel_idx) * state_length + vp; + state_val = conv_state[si]; + } else { + let ip = vp - state_length; + let ii = (batch_idx * channels + channel_idx) * input_length + ip; + state_val = input[ii]; + } +#else + if (vp >= state_length) { + let ip = vp - state_length; + let ii = (batch_idx * channels + channel_idx) * input_length + ip; + state_val = input[ii]; + } +#endif + + let ps_idx = (batch_idx * channels + channel_idx) * state_length + s; + present_state[ps_idx] = state_val; + } + } +} // MAIN diff --git a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc new file mode 100644 index 0000000000000..aed71065a8354 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.cc @@ -0,0 +1,284 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
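The silu() helper in the conv shader above computes x / (1 + exp(-x)), which equals x * sigmoid(x); this identity is also why ParseCausalConvActivation accepts "swish" as an alias for "silu". A quick stand-alone check, illustrative only:

// silu(x) = x / (1 + exp(-x)) = x * sigmoid(x)
#include <cassert>
#include <cmath>

float silu(float x) { return x / (1.0f + std::exp(-x)); }
float sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }

int main() {
  for (float x : {-4.0f, -1.0f, 0.0f, 0.5f, 3.0f}) {
    assert(std::fabs(silu(x) - x * sigmoid(x)) < 1e-6f);
  }
  return 0;
}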
+ +#include "contrib_ops/webgpu/bert/linear_attention.h" + +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" + +using namespace onnxruntime::webgpu; + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +LinearAttentionUpdateRule ParseUpdateRule(const std::string& rule_str) { + if (rule_str == "linear") { + return LinearAttentionUpdateRule::Linear; + } else if (rule_str == "gated") { + return LinearAttentionUpdateRule::Gated; + } else if (rule_str == "delta") { + return LinearAttentionUpdateRule::Delta; + } else if (rule_str == "gated_delta") { + return LinearAttentionUpdateRule::GatedDelta; + } + return LinearAttentionUpdateRule::Invalid; +} + +// ============================================================================= +// LinearAttention Shader Implementation +// ============================================================================= +// +// Design overview: +// - Each workgroup handles one (batch, head, dv_tile) combination +// - Workgroup size = head_dim_k (dk): one thread per state row +// - Each thread maintains TILE_V columns of its state row in private memory +// - Tokens are processed sequentially; matrix ops are parallelized across threads +// - Reductions across dk (for S^T @ k and S^T @ q) use shared memory +// + +Status LinearAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { + const bool use_vec4 = (components_ == 4); + + // Map update rule to integer for template conditionals + int update_rule_int = 0; + switch (update_rule_) { + case LinearAttentionUpdateRule::Linear: + update_rule_int = 0; + break; + case LinearAttentionUpdateRule::Gated: + update_rule_int = 1; + break; + case LinearAttentionUpdateRule::Delta: + update_rule_int = 2; + break; + case LinearAttentionUpdateRule::GatedDelta: + update_rule_int = 3; + break; + case LinearAttentionUpdateRule::Invalid: + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Invalid update rule"); + } + + // Add inputs + shader.AddInput("query", ShaderUsage::UseUniform); + shader.AddInput("key", ShaderUsage::UseUniform); + shader.AddInput("value", ShaderUsage::UseUniform); + if (has_initial_state_) { + shader.AddInput("initial_state", ShaderUsage::UseUniform); + } + if (has_decay_) { + shader.AddInput("decay", ShaderUsage::UseUniform); + } + if (has_beta_) { + shader.AddInput("beta", ShaderUsage::UseUniform); + } + + // Add outputs - UseValueTypeAlias for vec4 writes, UseElementTypeAlias for scalar writes + shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + shader.AddOutput("present_state", ShaderUsage::UseUniform); + + return WGSL_TEMPLATE_APPLY(shader, "bert/linear_attention.wgsl.template", + WGSL_TEMPLATE_PARAMETER(decay_broadcast_dk, decay_broadcast_dk_), + WGSL_TEMPLATE_PARAMETER(has_initial_state, has_initial_state_), + WGSL_TEMPLATE_PARAMETER(tile_v, tile_v_), + WGSL_TEMPLATE_PARAMETER(update_rule, update_rule_int), + WGSL_TEMPLATE_PARAMETER(use_vec4, use_vec4)); +} + +// ============================================================================= +// LinearAttention Kernel Registration and Computation +// ============================================================================= + +ONNX_OPERATOR_KERNEL_EX( + LinearAttention, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()), + LinearAttention); + 
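As the design overview above notes, each workgroup owns one (batch, head, dv_tile) slice of the recurrent state. The sketch below mirrors the index decode the shader performs from the flat workgroup index; DecodeWorkgroup is a hypothetical helper used only for illustration.

// Flat workgroup index -> (batch, head, dv_tile), mirroring the shader's decode.
#include <cstdint>
#include <tuple>

std::tuple<uint32_t, uint32_t, uint32_t> DecodeWorkgroup(uint32_t workgroup_idx,
                                                         uint32_t num_heads,
                                                         uint32_t num_dv_tiles) {
  const uint32_t bh = workgroup_idx / num_dv_tiles;       // packed (batch, head)
  const uint32_t dv_tile_idx = workgroup_idx % num_dv_tiles;
  const uint32_t batch_idx = bh / num_heads;
  const uint32_t head_idx = bh % num_heads;
  return {batch_idx, head_idx, dv_tile_idx};
}
// Total dispatch = batch_size * kv_num_heads * num_dv_tiles workgroups, each with
// workgroup_size_x threads (one thread per row of the dk-by-dv state tile).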
+LinearAttention::LinearAttention(const OpKernelInfo& info) + : WebGpuKernel(info) { + std::string update_rule_str = info.GetAttr("update_rule"); + update_rule_ = ParseUpdateRule(update_rule_str); + scale_ = info.GetAttrOrDefault("scale", 0.0f); + q_num_heads_ = static_cast(info.GetAttr("q_num_heads")); + kv_num_heads_ = static_cast(info.GetAttr("kv_num_heads")); +} + +/* + 3D packed inputs: + query: (B, T, H_q * d_k) — packed query + key: (B, T, H_kv * d_k) — packed key + value: (B, T, H_kv * d_v) — packed value + past_state: (B, H_kv, d_k, d_v) — recurrent state (4D) + decay: (B, T, H_kv * d_k) or (B, T, H_kv) — decay gate (3D) + beta: (B, T, H_kv) or (B, T, 1) — update rate (3D) + + Outputs: + output: (B, T, H_q * d_v) — packed attention output + present_state: (B, H_kv, d_k, d_v) — updated recurrent state (4D) +*/ +Status LinearAttention::ComputeInternal(ComputeContext& context) const { + const Tensor* query = context.Input(0); + const Tensor* key = context.Input(1); + const Tensor* value = context.Input(2); + const Tensor* past_state = context.Input(3); // optional + const Tensor* decay = context.Input(4); // optional + const Tensor* beta = context.Input(5); // optional + + // Validate 3D packed inputs + const auto& q_shape = query->Shape(); + ORT_RETURN_IF(q_shape.NumDimensions() != 3, "query must be 3D (B, T, H_q*d_k)"); + const auto& k_shape = key->Shape(); + ORT_RETURN_IF(k_shape.NumDimensions() != 3, "key must be 3D (B, T, H_k*d_k)"); + const auto& v_shape = value->Shape(); + ORT_RETURN_IF(v_shape.NumDimensions() != 3, "value must be 3D (B, T, H_v*d_v)"); + + const int64_t batch_size = q_shape[0]; + const int64_t seq_length = q_shape[1]; + ORT_RETURN_IF(k_shape[0] != batch_size || k_shape[1] != seq_length, + "key batch/sequence dimensions must match query"); + ORT_RETURN_IF(v_shape[0] != batch_size || v_shape[1] != seq_length, + "value batch/sequence dimensions must match query"); + + const int64_t q_packed_dim = q_shape[2]; + ORT_RETURN_IF(q_num_heads_ <= 0 || q_packed_dim % q_num_heads_ != 0, + "query packed dim must be divisible by q_num_heads"); + const int64_t head_dim_k = q_packed_dim / q_num_heads_; + const int64_t k_packed_dim = k_shape[2]; + ORT_RETURN_IF(k_packed_dim % head_dim_k != 0, + "key packed dim must be divisible by query head dimension"); + const int64_t n_k_heads = k_packed_dim / head_dim_k; + const int64_t v_packed_dim = v_shape[2]; + const int64_t head_dim_v = v_packed_dim / kv_num_heads_; + ORT_RETURN_IF(v_packed_dim != head_dim_v * kv_num_heads_, + "value packed dim must be divisible by kv_num_heads"); + + // ==== GQA head mapping ==== + // Standard GQA: q_num_heads >= kv_num_heads, multiple Q heads per KV group. + // Inverse GQA: q_num_heads < kv_num_heads (e.g., Qwen3.5 9B: n_q=16, n_kv=32). + // Also n_k_heads may differ from both (K has its own head count). 
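To make the two head-mapping regimes above concrete, here is a small sketch; MakeHeadMap, QHeadFor and KHeadFor are hypothetical names, and the arithmetic mirrors the checks that follow and the indexing used in the shader.

// Standard GQA (q >= kv): heads_per_group = q/kv Q heads share one KV state.
// Inverse GQA (q < kv): heads_per_group = 0; several KV heads share one Q head,
// and each KV head owns its own output slot. K heads may be coarser still.
#include <cassert>
#include <cstdint>

struct HeadMap {
  int64_t heads_per_group;  // 0 signals inverse GQA
  int64_t kv_per_k_head;
};

HeadMap MakeHeadMap(int64_t q_heads, int64_t kv_heads, int64_t k_heads) {
  HeadMap m{};
  if (q_heads >= kv_heads) {
    assert(q_heads % kv_heads == 0);
    m.heads_per_group = q_heads / kv_heads;
  } else {
    assert(kv_heads % q_heads == 0);
    m.heads_per_group = 0;
  }
  assert(kv_heads % k_heads == 0);
  m.kv_per_k_head = kv_heads / k_heads;
  return m;
}

// For a given KV head: which Q head provides the query, and which K head provides the key?
int64_t QHeadFor(int64_t kv_head, const HeadMap& m, int64_t q_heads, int64_t kv_heads) {
  return m.heads_per_group > 0 ? kv_head * m.heads_per_group    // first Q head of the group
                               : kv_head * q_heads / kv_heads;  // shared Q head (inverse GQA)
}
int64_t KHeadFor(int64_t kv_head, const HeadMap& m) { return kv_head / m.kv_per_k_head; }

// Example (inverse GQA, as mentioned above): q = 16, kv = 32 -> heads_per_group = 0;
// KV heads 0 and 1 both read Q head 0 but write separate output slots.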
+ int64_t heads_per_group; // Q heads per KV group (0 if inverse GQA) + if (q_num_heads_ >= kv_num_heads_) { + ORT_RETURN_IF_NOT(q_num_heads_ % kv_num_heads_ == 0, + "q_num_heads must be divisible by kv_num_heads"); + heads_per_group = q_num_heads_ / kv_num_heads_; + } else { + ORT_RETURN_IF_NOT(kv_num_heads_ % q_num_heads_ == 0, + "kv_num_heads must be divisible by q_num_heads (inverse GQA)"); + heads_per_group = 0; // signals inverse GQA + } + + // K-to-KV head mapping: when n_k < kv_num_heads, multiple KV heads share one K head + ORT_RETURN_IF_NOT(kv_num_heads_ % n_k_heads == 0, + "kv_num_heads must be divisible by n_k_heads"); + int64_t kv_per_k_head = kv_num_heads_ / n_k_heads; + + // Validate update rule has required inputs + bool needs_decay = (update_rule_ == LinearAttentionUpdateRule::Gated || + update_rule_ == LinearAttentionUpdateRule::GatedDelta); + bool needs_beta = (update_rule_ == LinearAttentionUpdateRule::Delta || + update_rule_ == LinearAttentionUpdateRule::GatedDelta); + ORT_RETURN_IF(needs_decay && decay == nullptr, "decay input required for gated/gated_delta update rules"); + ORT_RETURN_IF(needs_beta && beta == nullptr, "beta input required for delta/gated_delta update rules"); + + // Compute scale: 0.0 means derive from d_k + float scale = scale_; + if (scale == 0.0f) { + scale = 1.0f / std::sqrt(static_cast(head_dim_k)); + } + + // Allocate outputs — output is 3D packed, state is 4D + // Output uses kv_num_heads (matches schema inference: output_dim == V_dim). + // For inverse GQA (q < kv): each KV head writes its own output slot. + // For standard/MHA (q >= kv): q == kv with this schema, so equivalent. + TensorShapeVector output_shape({batch_size, seq_length, kv_num_heads_ * head_dim_v}); + Tensor* output = context.Output(0, output_shape); + + TensorShapeVector state_shape({batch_size, kv_num_heads_, head_dim_k, head_dim_v}); + Tensor* present_state = context.Output(1, state_shape); + + // Vectorization: when head_dim_v is divisible by 4, use vec4 to pack 4 dv values + // per element. This replaces scalar TILE_V loops with native vec4 SIMD operations, + // reduces shared memory access overhead, and enables coalesced memory reads/writes. + const int components = (head_dim_v % 4 == 0 && head_dim_v >= 4) ? 4 : 1; + int tile_v = (components == 4) ? 
1 : 4; + if (components == 1 && head_dim_v <= 4) { + tile_v = onnxruntime::narrow(head_dim_v); + } + const int head_dim_v_vectorized = onnxruntime::narrow(head_dim_v) / components; + + constexpr uint32_t kMaxSupportedWorkgroupSize = 256; + ORT_RETURN_IF_NOT(head_dim_k <= static_cast(kMaxSupportedWorkgroupSize), + "LinearAttention WebGPU kernel requires head_dim_k <= ", + kMaxSupportedWorkgroupSize, + ", got ", + head_dim_k); + uint32_t workgroup_size = 1; + while (workgroup_size < static_cast(head_dim_k)) { + workgroup_size *= 2; + } + // Cap at GPU limits + workgroup_size = std::min(workgroup_size, kMaxSupportedWorkgroupSize); + + const int num_dv_tiles = (head_dim_v_vectorized + tile_v - 1) / tile_v; + const uint32_t num_workgroups = onnxruntime::narrow(batch_size * kv_num_heads_ * num_dv_tiles); + + bool has_initial_state = past_state != nullptr; + bool has_decay = decay != nullptr; + bool has_beta = beta != nullptr; + + // Detect whether decay is (B,T,H_kv) or (B,T,H_kv*dk) + bool decay_broadcast_dk = false; + if (has_decay) { + const auto& decay_shape = decay->Shape(); + // (B, T, H_kv) = 3D with last dim == num_heads + int decay_last_dim = static_cast(decay_shape[decay_shape.NumDimensions() - 1]); + if (decay_last_dim == kv_num_heads_) { + decay_broadcast_dk = true; + } + } + + LinearAttentionProgram program{update_rule_, has_initial_state, has_decay, has_beta, decay_broadcast_dk, tile_v, components}; + + program.AddInputs({{query, ProgramTensorMetadataDependency::TypeAndRank}, + {key, ProgramTensorMetadataDependency::TypeAndRank}, + {value, ProgramTensorMetadataDependency::TypeAndRank, components}}); + if (has_initial_state) { + program.AddInput({past_state, ProgramTensorMetadataDependency::TypeAndRank, components}); + } + if (has_decay) { + program.AddInput({decay, ProgramTensorMetadataDependency::TypeAndRank}); + } + if (has_beta) { + program.AddInput({beta, ProgramTensorMetadataDependency::TypeAndRank}); + } + + program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank, components}, + {present_state, ProgramTensorMetadataDependency::TypeAndRank, components}}); + + program.SetDispatchGroupSize(num_workgroups) + .SetWorkgroupSize(workgroup_size) + .CacheHint(std::to_string(static_cast(update_rule_)), + has_initial_state, has_decay, has_beta, decay_broadcast_dk, tile_v, components) + .AddUniformVariables({{static_cast(batch_size)}, + {static_cast(kv_num_heads_)}, + {static_cast(seq_length)}, + {static_cast(head_dim_k)}, + {static_cast(head_dim_v_vectorized)}, + {scale}, + {static_cast(num_dv_tiles)}, + {static_cast(heads_per_group)}, + {static_cast(kv_per_k_head)}, + {static_cast(q_num_heads_)}, + {static_cast(n_k_heads)}}); + + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h new file mode 100644 index 0000000000000..566b1668d3914 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.h @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
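A worked example of the tile/vectorization and workgroup-size selection in ComputeInternal above; a stand-alone sketch with example dimensions only.

#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t head_dim_k = 128, head_dim_v = 128;

  // vec4 when the v dimension allows it; otherwise scalar with up to 4 columns per thread.
  const int components = (head_dim_v % 4 == 0 && head_dim_v >= 4) ? 4 : 1;
  int tile_v = (components == 4) ? 1 : 4;
  if (components == 1 && head_dim_v <= 4) tile_v = static_cast<int>(head_dim_v);
  const int head_dim_v_vectorized = static_cast<int>(head_dim_v) / components;
  const int num_dv_tiles = (head_dim_v_vectorized + tile_v - 1) / tile_v;

  // Workgroup size = next power of two >= head_dim_k, capped at 256.
  uint32_t workgroup_size = 1;
  while (workgroup_size < static_cast<uint32_t>(head_dim_k)) workgroup_size *= 2;
  workgroup_size = std::min<uint32_t>(workgroup_size, 256);

  // head_dim_v = 128 -> components = 4, tile_v = 1 (one vec4 = 4 scalars per thread), 32 dv tiles;
  // head_dim_k = 128 -> workgroup of 128 threads, one per state row.
  std::cout << components << " " << tile_v << " " << num_dv_tiles << " " << workgroup_size << "\n";
  return 0;
}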
+ +#pragma once + +#include + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; +using onnxruntime::webgpu::ComputeContext; + +// Update rule enumeration +enum class LinearAttentionUpdateRule { + Invalid, + Linear, // S_t = S_{t-1} + k ⊗ v + Gated, // S_t = exp(g) * S_{t-1} + k ⊗ v + Delta, // S_t = S_{t-1} + β * k ⊗ (v - S^T k) + GatedDelta // S_t = exp(g) * S_{t-1} + β * k ⊗ (v - exp(g) * S^T k) +}; + +LinearAttentionUpdateRule ParseUpdateRule(const std::string& rule_str); + +// WebGPU program for the fused linear attention kernel. +// Each workgroup processes one (batch, head, dv_tile) combination. +// Threads within a workgroup (one per dk row) cooperate on reductions. +class LinearAttentionProgram final : public Program { + public: + LinearAttentionProgram(LinearAttentionUpdateRule update_rule, bool has_initial_state, + bool has_decay, bool has_beta, bool decay_broadcast_dk, int tile_v, int components) + : Program{"LinearAttention"}, + update_rule_(update_rule), + has_initial_state_(has_initial_state), + has_decay_(has_decay), + has_beta_(has_beta), + decay_broadcast_dk_(decay_broadcast_dk), + tile_v_(tile_v), + components_(components) {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"batch_size", ProgramUniformVariableDataType::Uint32}, + {"num_heads", ProgramUniformVariableDataType::Uint32}, + {"seq_length", ProgramUniformVariableDataType::Uint32}, + {"head_dim_k", ProgramUniformVariableDataType::Uint32}, + {"head_dim_v", ProgramUniformVariableDataType::Uint32}, + {"scale", ProgramUniformVariableDataType::Float32}, + {"num_dv_tiles", ProgramUniformVariableDataType::Uint32}, + {"heads_per_group", ProgramUniformVariableDataType::Uint32}, + {"kv_per_k_head", ProgramUniformVariableDataType::Uint32}, + {"q_num_heads", ProgramUniformVariableDataType::Uint32}, + {"n_k_heads", ProgramUniformVariableDataType::Uint32}); + + private: + LinearAttentionUpdateRule update_rule_; + bool has_initial_state_; + bool has_decay_; + bool has_beta_; + bool decay_broadcast_dk_; + int tile_v_; + int components_; +}; + +// Kernel for LinearAttention +class LinearAttention : public WebGpuKernel { + public: + LinearAttention(const OpKernelInfo& info); + Status ComputeInternal(ComputeContext& context) const override; + + private: + LinearAttentionUpdateRule update_rule_; + float scale_; + int q_num_heads_; + int kv_num_heads_; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/linear_attention.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.wgsl.template new file mode 100644 index 0000000000000..4b4c08165c8b9 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/linear_attention.wgsl.template @@ -0,0 +1,294 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
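The recurrences documented on the enum above can be written out per token for a single head as follows. StepGatedDelta is a hypothetical, unoptimized reference routine covering the most general rule; the other rules are the special cases noted at the end.

// S is dk x dv (row-major), k/q have length dk, v/out have length dv.
// Linear:      S += k ⊗ v
// Gated:       S = exp(g) * S + k ⊗ v
// Delta:       S += beta * k ⊗ (v - S^T k)
// GatedDelta:  S = exp(g) * S + beta * k ⊗ (v - exp(g) * S^T k)
// Output for every rule (after the update): out = scale * S^T q
#include <cmath>
#include <vector>

void StepGatedDelta(std::vector<float>& S, const std::vector<float>& q,
                    const std::vector<float>& k, const std::vector<float>& v,
                    float g, float beta, float scale, int dk, int dv,
                    std::vector<float>& out) {
  const float eg = std::exp(g);
  // retrieved = exp(g) * S^T k : what the decayed state currently predicts for this key.
  std::vector<float> retrieved(dv, 0.0f);
  for (int i = 0; i < dk; ++i)
    for (int j = 0; j < dv; ++j) retrieved[j] += eg * S[i * dv + j] * k[i];
  // S = exp(g) * S + beta * k ⊗ (v - retrieved)
  for (int i = 0; i < dk; ++i)
    for (int j = 0; j < dv; ++j)
      S[i * dv + j] = eg * S[i * dv + j] + beta * k[i] * (v[j] - retrieved[j]);
  // out = scale * S^T q
  out.assign(dv, 0.0f);
  for (int i = 0; i < dk; ++i)
    for (int j = 0; j < dv; ++j) out[j] += scale * S[i * dv + j] * q[i];
}
// Gated drops the correction term and adds k ⊗ v directly; Linear additionally sets g = 0;
// Delta keeps beta but sets g = 0.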
+ +// +// LinearAttention shader +// +// Design overview: +// - Each workgroup handles one (batch, head, dv_tile) combination +// - Workgroup size = head_dim_k (dk): one thread per state row +// - Each thread maintains TILE_V columns of its state row in private memory +// - Tokens are processed sequentially; matrix ops are parallelized across threads +// - Reductions across dk (for S^T @ k and S^T @ q) use shared memory +// - For delta/gated_delta: the S^T@k (retrieval) and S^T@q (output) reductions +// are fused into a single barrier tree, nearly halving barrier count per token. +// Key identity: output = scale * (S_old^T @ q + delta * (k^T @ q)) +// + +#param use_vec4 +#param has_initial_state +#param decay_broadcast_dk +#param tile_v + +// Update rule constants +#define UPDATE_LINEAR 0 +#define UPDATE_GATED 1 +#define UPDATE_DELTA 2 +#define UPDATE_GATED_DELTA 3 +#param update_rule + +// Type aliases: vtype is the element type for state and reductions. +// otype is the output storage type. +#if use_vec4 +alias vtype = vec4; +alias otype = output_value_t; +#else +alias vtype = f32; +alias otype = output_element_t; +#endif + +const TILE_V: u32 = tile_v; + +// Shared memory for parallel reduction across dk threads. +#if update_rule == UPDATE_DELTA || update_rule == UPDATE_GATED_DELTA +// Fused reduction: retrieved (S^T@k), pre_output (S^T@q), and kq_dot (k^T@q) +// are reduced in a single barrier tree. +var red_retrieved: array; +var red_preout: array; +var red_kq: array; +var broadcast_buf: array; +#else +// Output-only reduction for linear/gated. +var reduction_buf: array; +#endif + +$MAIN { + // Identify which (batch, head, dv_tile) this workgroup handles + let bh = workgroup_idx / uniforms.num_dv_tiles; + let dv_tile_idx = workgroup_idx % uniforms.num_dv_tiles; + let batch_idx = bh / uniforms.num_heads; + let head_idx = bh % uniforms.num_heads; + let dk_idx = local_idx; // thread index = row in state matrix + let dv_start = dv_tile_idx * TILE_V; + + // Precompute packed strides for 3D packed inputs (B, T, H*D) + // Q: (B, T, q_num_heads * dk), K: (B, T, n_k_heads * dk), + // V/output: (B, T, num_heads * dv) [schema: output_dim == V_dim] + let packed_dk_q = uniforms.q_num_heads * uniforms.head_dim_k; + let packed_dk_k = uniforms.n_k_heads * uniforms.head_dim_k; + let packed_dk_kv = uniforms.num_heads * uniforms.head_dim_k; + let packed_dv = uniforms.num_heads * uniforms.head_dim_v; + + // Initialize state tile in private memory + var state: array; + for (var j = 0u; j < TILE_V; j++) { + state[j] = vtype(0.0); + } + + // Load initial state if provided +#if has_initial_state + if (dk_idx < uniforms.head_dim_k) { + let state_base = ((batch_idx * uniforms.num_heads + head_idx) * uniforms.head_dim_k + dk_idx) * uniforms.head_dim_v + dv_start; + for (var j = 0u; j < TILE_V; j++) { + if (dv_start + j < uniforms.head_dim_v) { + state[j] = vtype(initial_state[state_base + j]); + } + } + } +#endif + + // Process each token sequentially + for (var t = 0u; t < uniforms.seq_length; t++) { + let bt_offset = batch_idx * uniforms.seq_length + t; + var k_val: f32 = 0.0; + if (dk_idx < uniforms.head_dim_k) { + let k_head_idx = head_idx / uniforms.kv_per_k_head; + let k_idx = bt_offset * packed_dk_k + k_head_idx * uniforms.head_dim_k + dk_idx; + k_val = f32(key[k_idx]); + } + + // Step 1: Apply decay (for gated and gated_delta modes) +#if update_rule == UPDATE_GATED || update_rule == UPDATE_GATED_DELTA + // Apply exponential decay: S *= exp(decay) +#if decay_broadcast_dk + let exp_g = 
exp(f32(decay[bt_offset * uniforms.num_heads + head_idx])); +#else + var exp_g: f32 = 1.0; + if (dk_idx < uniforms.head_dim_k) { + exp_g = exp(f32(decay[bt_offset * packed_dk_kv + head_idx * uniforms.head_dim_k + dk_idx])); + } +#endif + for (var j = 0u; j < TILE_V; j++) { + state[j] *= exp_g; + } +#endif + +#if update_rule == UPDATE_DELTA || update_rule == UPDATE_GATED_DELTA + // Determine Q head and output head for this KV head. + // Standard GQA/MHA (heads_per_group > 0): Q indexed by q_head = kv_head * hpg. + // Inverse GQA (heads_per_group == 0): multiple KV heads share one Q head; + // output indexed by KV head (each KV head has its own output slot). + var q_head_0: u32; + var out_head_0: u32; + if (uniforms.heads_per_group > 0u) { + q_head_0 = head_idx * uniforms.heads_per_group; + out_head_0 = q_head_0; + } else { + q_head_0 = head_idx * uniforms.q_num_heads / uniforms.num_heads; + out_head_0 = head_idx; + } + var q0_val: f32 = 0.0; + if (dk_idx < uniforms.head_dim_k) { + q0_val = f32(query[bt_offset * packed_dk_q + q_head_0 * uniforms.head_dim_k + dk_idx]); + } + + // Fused reduction: compute retrieved = S^T@k, pre_output = S^T@q_0, + // and kq_dot = k^T@q_0 in a single barrier tree. + // Then: output_0 = scale * (pre_output + delta * kq_dot) + for (var j = 0u; j < TILE_V; j++) { + red_retrieved[j * workgroup_size_x + dk_idx] = state[j] * k_val; + red_preout[j * workgroup_size_x + dk_idx] = state[j] * q0_val; + } + red_kq[dk_idx] = k_val * q0_val; + workgroupBarrier(); + + for (var stride = workgroup_size_x >> 1u; stride > 0u; stride = stride >> 1u) { + if (dk_idx < stride) { + for (var j = 0u; j < TILE_V; j++) { + red_retrieved[j * workgroup_size_x + dk_idx] += red_retrieved[j * workgroup_size_x + dk_idx + stride]; + red_preout[j * workgroup_size_x + dk_idx] += red_preout[j * workgroup_size_x + dk_idx + stride]; + } + red_kq[dk_idx] += red_kq[dk_idx + stride]; + } + workgroupBarrier(); + } + + // Thread 0: compute delta, broadcast it, and write output for out_head_0. + let v_base = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_start; + let beta_base = bt_offset * uniforms.num_heads + head_idx; + if (dk_idx == 0u) { + let beta_val = f32(beta[beta_base]); + let kq_dot = red_kq[0]; + for (var j = 0u; j < TILE_V; j++) { + if (dv_start + j < uniforms.head_dim_v) { + let retrieved = red_retrieved[j * workgroup_size_x]; + let pre_out = red_preout[j * workgroup_size_x]; + let v_val = vtype(value[v_base + j]); + let delta_j = beta_val * (v_val - retrieved); + broadcast_buf[j] = delta_j; + output[bt_offset * packed_dv + out_head_0 * uniforms.head_dim_v + dv_start + j] = otype((pre_out + delta_j * kq_dot) * uniforms.scale); + } else { + broadcast_buf[j] = vtype(0.0); + } + } + } + workgroupBarrier(); + + // All threads: update state with delta (S_new = S_old + k * delta) + for (var j = 0u; j < TILE_V; j++) { + state[j] += k_val * broadcast_buf[j]; + } + workgroupBarrier(); + + // Standard GQA: additional Q heads — output_g = scale * S_new^T @ q_g + // (For inverse GQA, heads_per_group == 0 so this loop is skipped.) 
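The reason the first Q head's output above needs no second reduction, while the remaining heads in the group (handled by the loop that follows) do, is the identity S_new^T q = S_old^T q + delta * (k^T q), which holds because S_new = S_old + k ⊗ delta. A stand-alone numeric check with illustrative values only:

#include <cassert>
#include <cmath>
#include <vector>

int main() {
  const int dk = 3, dv = 2;
  std::vector<float> S = {0.5f, -1.0f, 2.0f, 0.25f, -0.75f, 1.5f};  // dk x dv, row-major
  std::vector<float> k = {1.0f, -2.0f, 0.5f}, q = {0.3f, 0.7f, -1.1f};
  std::vector<float> delta = {0.4f, -0.6f};

  // Left side: (S + k ⊗ delta)^T q
  std::vector<float> lhs(dv, 0.0f);
  for (int i = 0; i < dk; ++i)
    for (int j = 0; j < dv; ++j) lhs[j] += (S[i * dv + j] + k[i] * delta[j]) * q[i];

  // Right side: S^T q + delta * (k^T q)
  float kq = 0.0f;
  for (int i = 0; i < dk; ++i) kq += k[i] * q[i];
  std::vector<float> rhs(dv, 0.0f);
  for (int i = 0; i < dk; ++i)
    for (int j = 0; j < dv; ++j) rhs[j] += S[i * dv + j] * q[i];
  for (int j = 0; j < dv; ++j) rhs[j] += delta[j] * kq;

  for (int j = 0; j < dv; ++j) assert(std::fabs(lhs[j] - rhs[j]) < 1e-5f);
  return 0;
}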
+ for (var qg = 1u; qg < uniforms.heads_per_group; qg++) { + let q_head_g = head_idx * uniforms.heads_per_group + qg; + var qg_val: f32 = 0.0; + if (dk_idx < uniforms.head_dim_k) { + qg_val = f32(query[bt_offset * packed_dk_q + q_head_g * uniforms.head_dim_k + dk_idx]); + } + for (var j = 0u; j < TILE_V; j++) { + red_preout[j * workgroup_size_x + dk_idx] = state[j] * qg_val; + } + workgroupBarrier(); + for (var stride = workgroup_size_x >> 1u; stride > 0u; stride = stride >> 1u) { + if (dk_idx < stride) { + for (var j = 0u; j < TILE_V; j++) { + red_preout[j * workgroup_size_x + dk_idx] += red_preout[j * workgroup_size_x + dk_idx + stride]; + } + } + workgroupBarrier(); + } + if (dk_idx == 0u) { + for (var j = 0u; j < TILE_V; j++) { + if (dv_start + j < uniforms.head_dim_v) { + output[bt_offset * packed_dv + q_head_g * uniforms.head_dim_v + dv_start + j] = otype(red_preout[j * workgroup_size_x] * uniforms.scale); + } + } + } + workgroupBarrier(); + } + +#else + // Linear/gated: S += k ⊗ v + let v_base = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_start; + for (var j = 0u; j < TILE_V; j++) { + if (dv_start + j < uniforms.head_dim_v) { + state[j] += k_val * vtype(value[v_base + j]); + } + } + + // Output = scale * S^T @ q + if (uniforms.heads_per_group > 0u) { + // Standard GQA / MHA: one output per Q head in group + for (var qg = 0u; qg < uniforms.heads_per_group; qg++) { + let q_head_idx = head_idx * uniforms.heads_per_group + qg; + var q_val_g: f32 = 0.0; + if (dk_idx < uniforms.head_dim_k) { + q_val_g = f32(query[bt_offset * packed_dk_q + q_head_idx * uniforms.head_dim_k + dk_idx]); + } + for (var j = 0u; j < TILE_V; j++) { + reduction_buf[j * workgroup_size_x + dk_idx] = state[j] * q_val_g; + } + workgroupBarrier(); + for (var stride = workgroup_size_x >> 1u; stride > 0u; stride = stride >> 1u) { + if (dk_idx < stride) { + for (var j = 0u; j < TILE_V; j++) { + reduction_buf[j * workgroup_size_x + dk_idx] += reduction_buf[j * workgroup_size_x + dk_idx + stride]; + } + } + workgroupBarrier(); + } + if (dk_idx == 0u) { + let out_base = bt_offset * packed_dv + q_head_idx * uniforms.head_dim_v + dv_start; + for (var j = 0u; j < TILE_V; j++) { + if (dv_start + j < uniforms.head_dim_v) { + output[out_base + j] = otype(reduction_buf[j * workgroup_size_x] * uniforms.scale); + } + } + } + workgroupBarrier(); + } + } else { + // Inverse GQA: one output per KV head, using shared Q + let q_head_inv = head_idx * uniforms.q_num_heads / uniforms.num_heads; + var q_val_inv: f32 = 0.0; + if (dk_idx < uniforms.head_dim_k) { + q_val_inv = f32(query[bt_offset * packed_dk_q + q_head_inv * uniforms.head_dim_k + dk_idx]); + } + for (var j = 0u; j < TILE_V; j++) { + reduction_buf[j * workgroup_size_x + dk_idx] = state[j] * q_val_inv; + } + workgroupBarrier(); + for (var stride = workgroup_size_x >> 1u; stride > 0u; stride = stride >> 1u) { + if (dk_idx < stride) { + for (var j = 0u; j < TILE_V; j++) { + reduction_buf[j * workgroup_size_x + dk_idx] += reduction_buf[j * workgroup_size_x + dk_idx + stride]; + } + } + workgroupBarrier(); + } + if (dk_idx == 0u) { + let out_base = bt_offset * packed_dv + head_idx * uniforms.head_dim_v + dv_start; + for (var j = 0u; j < TILE_V; j++) { + if (dv_start + j < uniforms.head_dim_v) { + output[out_base + j] = otype(reduction_buf[j * workgroup_size_x] * uniforms.scale); + } + } + } + workgroupBarrier(); + } +#endif + } // end token loop + + // Write present_state + if (dk_idx < uniforms.head_dim_k) { + let state_base = ((batch_idx * uniforms.num_heads + 
head_idx) * uniforms.head_dim_k + dk_idx) * uniforms.head_dim_v + dv_start; + for (var j = 0u; j < TILE_V; j++) { + if (dv_start + j < uniforms.head_dim_v) { + present_state[state_base + j] = otype(state[j]); + } + } + } +} // MAIN diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index 357eebee714d5..2fe5f4d533b7d 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -2,7 +2,9 @@ // Licensed under the MIT License. #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" +#include "contrib_ops/webgpu/bert/causal_conv_with_state.h" #include "contrib_ops/webgpu/bert/group_query_attention.h" +#include "contrib_ops/webgpu/bert/linear_attention.h" #include "core/framework/op_kernel.h" @@ -12,6 +14,7 @@ namespace webgpu { class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Attention); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasAdd); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, CausalConvWithState); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasSplitGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FastGelu); @@ -19,6 +22,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Fu class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GatherBlockQuantized); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Gelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GroupQueryAttention); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, LinearAttention); // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 16, LayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBits); @@ -43,12 +47,14 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry, bool enable BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index 38848e98509ba..632e04a36c7bf 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -212,6 +212,16 @@ class ComputeContext final : public ComputeContextBase { return op_kernel_.Info().GetDataTransferManager().CopyTensor(src, dst); } + // + // Fill a GPU tensor with zeros. 
+ // + inline void FillZero(Tensor& dst) { + webgpu_context_.EndComputePass(); + auto& command_encoder = webgpu_context_.GetCommandEncoder(); + WGPUBuffer buffer = reinterpret_cast<WGPUBuffer>(dst.MutableDataRaw()); + command_encoder.ClearBuffer(buffer, 0, dst.SizeInBytes()); + } + private: OpKernelContext& kernel_context_; }; diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.cc b/onnxruntime/core/providers/webgpu/tensor/expand.cc index 0d39b1ec9d35e..e565362a380c1 100644 --- a/onnxruntime/core/providers/webgpu/tensor/expand.cc +++ b/onnxruntime/core/providers/webgpu/tensor/expand.cc @@ -21,7 +21,7 @@ Status ExpandProgram::GenerateShaderCode(ShaderHelper& shader) const { // The last dims of input shape and output shape are all divisible by 4. shader.MainFunctionBody() << " let output_indices = " << output_indices.OffsetToIndices("global_idx * 4") << ";\n" << " let input_offset = " << input_indices.BroadcastedIndicesToOffset("output_indices", output_indices) << ";\n" - << output.SetByOffset("global_idx", input.GetByOffset("input_offset")); + << output.SetByOffset("global_idx", input.GetByOffset("input_offset / 4")); } else if (output_last_dim_divisible_by_4_) { // The last dim of output shape is divisible by 4, and the last dim of input shape is 1. shader.MainFunctionBody() << " let output_indices = " << output_indices.OffsetToIndices("global_idx * 4") << ";\n"
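On the vec4 Expand path changed above, the broadcasted offset is produced in scalar elements while the input buffer is read in vec4 units (as I read the surrounding code, both last dims are divisible by 4 on this path), hence the division by 4. A host-side analogue, illustrative only:

// Element offset vs. vec4 storage index.
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Eight scalars viewed as two vec4 "elements".
  std::vector<std::array<float, 4>> input_vec4 = {{{0, 1, 2, 3}}, {{4, 5, 6, 7}}};

  uint32_t element_offset = 4;                // scalar offset from broadcasted indices
  uint32_t vec4_offset = element_offset / 4;  // storage index when 4 components are packed
  assert(input_vec4[vec4_offset][0] == 4.0f); // reads the vec4 starting at scalar element 4
  return 0;
}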