Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
96997f7
1st cut at webgpu LinearAttention
guschmue Mar 11, 2026
23e744d
ut is now passing
guschmue Mar 11, 2026
b72a150
merge LinearAttentionRecurrentProgram and LinearAttentionRecurrentPro…
guschmue Mar 12, 2026
d8f43fd
webgpu support for rmsnorm
guschmue Mar 13, 2026
1baf195
add empty attention
guschmue Mar 13, 2026
c53ceaf
draft implementation of onnx Attention operator that maps to webgpu c…
guschmue Mar 13, 2026
c8de2f4
CausalConv1DWithState
guschmue Mar 13, 2026
0948d56
add int64_t to concat/webgpu
guschmue Mar 14, 2026
231c92c
webgpu support for onnx rotary embeddings
guschmue Mar 15, 2026
e69b579
webgpu reshape to opset 25
guschmue Mar 15, 2026
56fe4ac
keep int64_t for concat on cpu
guschmue Mar 16, 2026
8e09ff3
webgpu LinearAttention
guschmue Mar 17, 2026
95400be
allow for decay [B,H,T]
guschmue Mar 18, 2026
c598455
Merge branch 'main' into gs/wgpu-lattn
guschmue Mar 18, 2026
6c0c736
guard for head-dim_k
guschmue Mar 24, 2026
28a3ee3
qwen3.5 now shows correct results
guschmue Mar 24, 2026
3f80587
remove chunk and group from signature
guschmue Mar 25, 2026
4111b95
update to latest signature proposal
guschmue Mar 28, 2026
f7711c4
opt: make use of vec4
guschmue Mar 29, 2026
52cee10
rename to causal_conv_with_state_op_test
guschmue Mar 29, 2026
a1e9827
ut looks for the registered ops
guschmue Mar 29, 2026
5c80e30
lintrunner -a
guschmue Mar 30, 2026
81251ff
move shader generation to .wgsl.template
guschmue Mar 31, 2026
9ebcd61
optimize number of barriers
guschmue Mar 31, 2026
079a33b
sync with #27842
guschmue Mar 31, 2026
085b274
update_rule need to stay std::string
guschmue Mar 31, 2026
6474e3e
add support for inverse GQA, needed for Qwen3.5-4/9B
guschmue Apr 2, 2026
15e4010
fix issue in Expand that shows with Qwen3.5 embeddings
guschmue Apr 5, 2026
f9ff1a5
Merge branch 'main' into gs/wgpu-lattn
guschmue Apr 5, 2026
72bfe8f
Merge branch 'main' into gs/wgpu-lattn
guschmue Apr 6, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 155 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "contrib_ops/webgpu/bert/causal_conv_with_state.h"

#include <string>
#include <vector>

#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"

// Using-declarations for the specific WebGPU-provider names this file needs,
// instead of a namespace using-directive (cpplint build/namespaces).
using onnxruntime::webgpu::ProgramTensorMetadataDependency;
using onnxruntime::webgpu::ShaderUsage;
using onnxruntime::webgpu::WebGpuSupportedFloatTypes;
using onnxruntime::webgpu::WORKGROUP_SIZE;

namespace onnxruntime {
namespace contrib {
namespace webgpu {

// Maps the "activation" attribute string to its enum value.
// "none" and the empty string select no activation; "silu" and its alias
// "swish" select SiLU. Any other value is a model error and throws.
CausalConvActivation ParseCausalConvActivation(const std::string& activation_str) {
  if (activation_str.empty() || activation_str == "none") {
    return CausalConvActivation::None;
  }
  if (activation_str == "silu" || activation_str == "swish") {
    return CausalConvActivation::Silu;
  }
  ORT_THROW("Unknown activation for CausalConvWithState: ", activation_str);
}

// =============================================================================
// CausalConvWithState Implementation
// =============================================================================

// Registers CausalConvWithState (com.microsoft domain, opset version 1) with
// the WebGPU execution provider, constrained to WebGPU-supported float types.
ONNX_OPERATOR_KERNEL_EX(
    CausalConvWithState,
    kMSDomain,
    1,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", WebGpuSupportedFloatTypes()),
    CausalConvWithState);

// Reads and validates the op attributes.
// The op schema also defines an `ndim` attribute (default 1); this kernel
// implements only 1-D convolution, so any other value is rejected up front
// rather than silently running with incorrect semantics.
CausalConvWithState::CausalConvWithState(const OpKernelInfo& info)
    : WebGpuKernel(info) {
  const int64_t ndim = info.GetAttrOrDefault<int64_t>("ndim", 1);
  ORT_ENFORCE(ndim == 1,
              "CausalConvWithState WebGPU kernel only supports ndim=1, but got ndim=", ndim);

  const std::string activation_str = info.GetAttrOrDefault<std::string>("activation", "none");
  activation_ = ParseCausalConvActivation(activation_str);
}

// Declares the shader's I/O bindings and instantiates the WGSL template with
// this program's compile-time flags.
// NOTE: AddInput/AddOutput order here is the binding order and must match the
// AddInput/AddOutput order used in CausalConvWithState::ComputeInternal.
Status CausalConvWithStateProgram::GenerateShaderCode(ShaderHelper& shader) const {
  shader.AddInput("input", ShaderUsage::UseElementTypeAlias);
  shader.AddInput("weight", ShaderUsage::UseUniform);

  // Optional inputs are declared only when present; the template guards all
  // accesses behind the matching #param flags below.
  if (has_bias_) {
    shader.AddInput("bias", ShaderUsage::UseUniform);
  }
  if (has_conv_state_) {
    shader.AddInput("conv_state", ShaderUsage::UseUniform);
  }

  shader.AddOutput("output", ShaderUsage::UseUniform);
  shader.AddOutput("present_state", ShaderUsage::UseUniform);

  return WGSL_TEMPLATE_APPLY(shader, "bert/causal_conv_with_state.wgsl.template",
                             WGSL_TEMPLATE_PARAMETER(has_bias, has_bias_),
                             WGSL_TEMPLATE_PARAMETER(has_conv_state, has_conv_state_),
                             WGSL_TEMPLATE_PARAMETER(use_silu, activation_ == CausalConvActivation::Silu));
}

// Validates shapes, allocates both outputs, and dispatches one shader
// invocation per output element of the depthwise causal convolution.
//
// Inputs:
//   0: input       (B, D, L)    activations
//   1: weight      (D, 1, K)    one K-tap filter per channel (depthwise)
//   2: bias        (D,)         optional
//   3: conv_state  (B, D, K-1)  optional history preceding input position 0
// Outputs:
//   0: output        (B, D, L)
//   1: present_state (B, D, K-1) — last K-1 columns of [conv_state, input]
Status CausalConvWithState::ComputeInternal(ComputeContext& context) const {
  const Tensor* input = context.Input(0);       // (B, D, L)
  const Tensor* weight = context.Input(1);      // (D, 1, K)
  const Tensor* bias = context.Input(2);        // optional (D,)
  const Tensor* conv_state = context.Input(3);  // optional (B, D, K-1) — past_state

  ORT_RETURN_IF(input == nullptr, "Input tensor must not be null");
  ORT_RETURN_IF(weight == nullptr, "Weight tensor must not be null");

  const auto& input_shape = input->Shape();
  const auto& weight_shape = weight->Shape();

  ORT_RETURN_IF(input_shape.NumDimensions() != 3,
                "Input must be 3D (batch_size, channels, length)");
  ORT_RETURN_IF(weight_shape.NumDimensions() != 3,
                "Weight must be 3D (channels, 1, kernel_size)");

  const int batch_size = static_cast<int>(input_shape[0]);
  const int channels = static_cast<int>(input_shape[1]);
  const int input_length = static_cast<int>(input_shape[2]);
  const int kernel_size = static_cast<int>(weight_shape[2]);
  const int state_length = kernel_size - 1;

  ORT_RETURN_IF(static_cast<int>(weight_shape[0]) != channels,
                "Weight first dim must match input channels");
  ORT_RETURN_IF(static_cast<int>(weight_shape[1]) != 1,
                "Weight second dim must be 1 for depthwise convolution");

  if (bias != nullptr) {
    ORT_RETURN_IF(bias->Shape().NumDimensions() != 1,
                  "Bias must be 1D");
    ORT_RETURN_IF(static_cast<int>(bias->Shape()[0]) != channels,
                  "Bias size must match channels");
  }

  if (conv_state != nullptr) {
    ORT_RETURN_IF(conv_state->Shape().NumDimensions() != 3,
                  "conv_state must be 3D (batch_size, channels, kernel_size - 1)");
    ORT_RETURN_IF(static_cast<int>(conv_state->Shape()[0]) != batch_size,
                  "conv_state batch_size must match input");
    ORT_RETURN_IF(static_cast<int>(conv_state->Shape()[1]) != channels,
                  "conv_state channels must match input");
    ORT_RETURN_IF(static_cast<int>(conv_state->Shape()[2]) != state_length,
                  "conv_state last dim must be kernel_size - 1");
  }

  const bool has_bias = (bias != nullptr);
  const bool has_conv_state = (conv_state != nullptr);

  // Allocate outputs.
  // Output 0: (B, D, L)
  Tensor* output = context.Output(0, input_shape);

  // Output 1: present_state (B, D, K-1)
  std::vector<int64_t> state_dims{batch_size, channels, state_length};
  Tensor* present_state = context.Output(1, TensorShape(state_dims));

  if (input_length == 0) {
    // Nothing to convolve, but present_state must still be well-defined:
    // with an empty input the virtual sequence [conv_state, input] is just
    // conv_state, so carry the past state forward unchanged.
    if (has_conv_state && state_length > 0) {
      ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*conv_state, *present_state));
    }
    // NOTE(review): when no conv_state is provided, present_state is left
    // unwritten for empty inputs; add a zero-fill here if callers can hit
    // that combination.
    return Status::OK();
  }

  // Create and run the shader program: one invocation per output element.
  CausalConvWithStateProgram program{activation_, has_bias, has_conv_state, kernel_size};

  // Compute the element count in int64 (TensorShape::Size) before narrowing,
  // so the intermediate B*D*L product cannot overflow int.
  const uint32_t output_size = static_cast<uint32_t>(input_shape.Size());

  program.AddInput({input, ProgramTensorMetadataDependency::Type})
      .AddInput({weight, ProgramTensorMetadataDependency::None});

  if (has_bias) {
    program.AddInput({bias, ProgramTensorMetadataDependency::None});
  }
  if (has_conv_state) {
    program.AddInput({conv_state, ProgramTensorMetadataDependency::None});
  }

  // Binding order must match GenerateShaderCode.
  program.AddOutput({output, ProgramTensorMetadataDependency::None})
      .AddOutput({present_state, ProgramTensorMetadataDependency::None})
      .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
      .AddUniformVariable({static_cast<uint32_t>(batch_size)})
      .AddUniformVariable({static_cast<uint32_t>(channels)})
      .AddUniformVariable({static_cast<uint32_t>(input_length)})
      .AddUniformVariable({static_cast<uint32_t>(kernel_size)})
      .AddUniformVariable({static_cast<uint32_t>(state_length)})
      .AddUniformVariable({output_size});

  return context.RunProgram(program);
}

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
66 changes: 66 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <string>

#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/webgpu_kernel.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

// Using-declarations for just the provider names this header needs.
// A blanket using-directive in a header would leak the whole
// onnxruntime::webgpu namespace into every includer (cpplint build/namespaces).
using onnxruntime::webgpu::ComputeContext;
using onnxruntime::webgpu::Program;
using onnxruntime::webgpu::ProgramUniformVariableDataType;
using onnxruntime::webgpu::ProgramUniformVariableDefinition;
using onnxruntime::webgpu::ShaderHelper;
using onnxruntime::webgpu::WebGpuKernel;

// Activation mode for CausalConvWithState, parsed from the "activation"
// string attribute.
enum class CausalConvActivation {
  None,
  Silu,
};

// Maps the attribute string to the enum: "none"/"" -> None,
// "silu"/"swish" -> Silu; throws for any other value.
CausalConvActivation ParseCausalConvActivation(const std::string& activation_str);

// Program for CausalConvWithState.
// Compile-time shader specialization (activation choice, presence of the
// optional bias/conv_state inputs) is captured at construction; runtime
// shapes are passed as uniforms.
class CausalConvWithStateProgram final : public Program<CausalConvWithStateProgram> {
 public:
  CausalConvWithStateProgram(CausalConvActivation activation, bool has_bias, bool has_conv_state,
                             int kernel_size)
      : Program{"CausalConvWithState"},
        activation_(activation),
        has_bias_(has_bias),
        has_conv_state_(has_conv_state),
        kernel_size_(kernel_size) {}

  // Instantiates bert/causal_conv_with_state.wgsl.template using the flags
  // captured above.
  Status GenerateShaderCode(ShaderHelper& sh) const override;

  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
      {"batch_size", ProgramUniformVariableDataType::Uint32},
      {"channels", ProgramUniformVariableDataType::Uint32},
      {"input_length", ProgramUniformVariableDataType::Uint32},
      {"kernel_size", ProgramUniformVariableDataType::Uint32},
      {"state_length", ProgramUniformVariableDataType::Uint32},
      {"output_size", ProgramUniformVariableDataType::Uint32});

 private:
  CausalConvActivation activation_;
  bool has_bias_;
  bool has_conv_state_;
  // Not read by GenerateShaderCode; kept for potential future use, hence the
  // [[maybe_unused]] marker.
  [[maybe_unused]] int kernel_size_;
};

// WebGPU kernel for the contrib op CausalConvWithState: a depthwise causal
// 1-D convolution that consumes an optional past state and emits the state
// needed for the next step.
class CausalConvWithState final : public WebGpuKernel {
 public:
  // `explicit` prevents accidental implicit conversion from OpKernelInfo;
  // kernels are always constructed explicitly by the kernel registry.
  explicit CausalConvWithState(const OpKernelInfo& info);
  Status ComputeInternal(ComputeContext& context) const override;

 private:
  // Activation applied to each convolution result (None or Silu).
  CausalConvActivation activation_;
};

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Build-time template parameters:
//   has_bias       - a (D,) bias input is bound
//   has_conv_state - a (B, D, K-1) past-state input is bound
//   use_silu       - apply SiLU activation to each output element
#param has_bias
#param has_conv_state
#param use_silu

#use guardAgainstOutOfBoundsWorkgroupSizes

#if use_silu
// SiLU (a.k.a. swish): x * sigmoid(x), written as x / (1 + e^-x).
fn silu(x: input_element_t) -> input_element_t {
  return x / (1.0 + exp(-x));
}
#endif

$MAIN {
  guardAgainstOutOfBoundsWorkgroupSizes(uniforms.output_size);

  let batch_size = uniforms.batch_size;
  let channels = uniforms.channels;
  let input_length = uniforms.input_length;
  let kernel_size = uniforms.kernel_size;
  let state_length = uniforms.state_length;  // = kernel_size - 1

  // One invocation per output element; decompose the flat index into
  // (batch, channel, position) for the (B, D, L) layout.
  let pos = global_idx % input_length;
  let bc_idx = global_idx / input_length;
  let batch_idx = bc_idx / channels;
  let channel_idx = bc_idx % channels;

  // Perform depthwise causal convolution for this (batch, channel, pos).
  // The convolution window looks back kernel_size-1 positions.
  // With conv_state providing the history before position 0, the
  // "virtual" input is: [conv_state[0..state_length-1], input[0..L-1]]
  //
  // For output position pos:
  //   output[pos] = sum_{j=0}^{kernel_size-1} weight[j] * virtual_input[pos + j]
  // where virtual_input is state_length positions of conv_state
  // followed by input_length positions of input.

  var acc: input_element_t = 0.0;

  // Weight layout: (D, 1, K) -> channel_idx * kernel_size + j
  let weight_base = channel_idx * kernel_size;

  for (var j: u32 = 0; j < kernel_size; j = j + 1) {
    // virtual_pos is the position in the concatenated [conv_state, input]
    let virtual_pos = pos + j;

    var val: input_element_t = 0.0;

#if has_conv_state
    if (virtual_pos < state_length) {
      // Read from conv_state: (B, D, state_length)
      let state_idx = (batch_idx * channels + channel_idx) * state_length + virtual_pos;
      val = conv_state[state_idx];
    } else {
      // Read from input: (B, D, L)
      let input_pos = virtual_pos - state_length;
      let input_idx = (batch_idx * channels + channel_idx) * input_length + input_pos;
      val = input[input_idx];
    }
#else
    // No conv_state: pad with zeros for positions before the input
    if (virtual_pos >= state_length) {
      let input_pos = virtual_pos - state_length;
      let input_idx = (batch_idx * channels + channel_idx) * input_length + input_pos;
      val = input[input_idx];
    }
#endif

    let w = weight[weight_base + j];
    acc = acc + val * w;
  }

#if has_bias
  acc = acc + bias[channel_idx];
#endif

#if use_silu
  acc = silu(acc);
#endif

  // Write output: (B, D, L)
  let out_idx = (batch_idx * channels + channel_idx) * input_length + pos;
  output[out_idx] = acc;

  // Write present_state: the last (kernel_size - 1) elements from the
  // virtual input [conv_state, input]. We only write present_state once
  // per (batch, channel), using the thread at pos == 0.
  if (pos == 0u) {
    for (var s: u32 = 0; s < state_length; s = s + 1) {
      var state_val: input_element_t = 0.0;
      // total_len = state_length + input_length
      // We want virtual_input[total_len - state_length + s] = virtual_input[input_length + s]
      let vp = input_length + s;

#if has_conv_state
      // vp < state_length happens when input_length < state_length: part of
      // the old state is carried over unchanged into the new state.
      if (vp < state_length) {
        let si = (batch_idx * channels + channel_idx) * state_length + vp;
        state_val = conv_state[si];
      } else {
        let ip = vp - state_length;
        let ii = (batch_idx * channels + channel_idx) * input_length + ip;
        state_val = input[ii];
      }
#else
      if (vp >= state_length) {
        let ip = vp - state_length;
        let ii = (batch_idx * channels + channel_idx) * input_length + ip;
        state_val = input[ii];
      }
#endif

      let ps_idx = (batch_idx * channels + channel_idx) * state_length + s;
      present_state[ps_idx] = state_val;
    }
  }
} // MAIN
Loading
Loading