Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 158 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "contrib_ops/webgpu/bert/causal_conv_with_state.h"

#include <string>
#include <vector>

#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"

using namespace onnxruntime::webgpu;

Check warning on line 10 in onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5] Raw Output: onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.cc:10: Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]

namespace onnxruntime {
namespace contrib {
namespace webgpu {

// Maps the "activation" attribute string to its enum value.
// "silu"/"swish" -> Silu; "none" or empty -> None; anything else -> Invalid.
CausalConvActivation ParseCausalConvActivation(const std::string& activation_str) {
  const bool is_silu = (activation_str == "silu") || (activation_str == "swish");
  if (is_silu) {
    return CausalConvActivation::Silu;
  }
  const bool is_none = activation_str.empty() || (activation_str == "none");
  return is_none ? CausalConvActivation::None : CausalConvActivation::Invalid;
}
Comment thread
guschmue marked this conversation as resolved.

// =============================================================================
// CausalConvWithState Implementation
// =============================================================================

// Register CausalConvWithState (com.microsoft domain, opset 1) with the WebGPU
// execution provider, constraining "T" to the provider's supported float types.
ONNX_OPERATOR_KERNEL_EX(
    CausalConvWithState,
    kMSDomain,
    1,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", WebGpuSupportedFloatTypes()),
    CausalConvWithState);

// Reads the "activation" attribute (default "none") and the required "ndim"
// attribute. Fails construction if "ndim" is missing; an unrecognized
// activation is deferred to ComputeInternal, which rejects Invalid.
// (Also strips stray CI-annotation text that was pasted into this body.)
CausalConvWithState::CausalConvWithState(const OpKernelInfo& info)
    : WebGpuKernel(info) {
  std::string activation_str = info.GetAttrOrDefault<std::string>("activation", "none");
  activation_ = ParseCausalConvActivation(activation_str);
  ORT_ENFORCE(info.GetAttr<int64_t>("ndim", &ndim_).IsOK(), "Attribute 'ndim' is required");
}

// Declares the shader's bindings and instantiates the WGSL template.
// Binding declaration order here must mirror the AddInput/AddOutput order
// used when the program is dispatched in ComputeInternal.
Status CausalConvWithStateProgram::GenerateShaderCode(ShaderHelper& sh) const {
  sh.AddInput("input", ShaderUsage::UseElementTypeAlias);
  sh.AddInput("weight", ShaderUsage::UseUniform);
  if (has_bias_) {
    sh.AddInput("bias", ShaderUsage::UseUniform);
  }
  if (has_conv_state_) {
    sh.AddInput("conv_state", ShaderUsage::UseUniform);
  }
  sh.AddOutput("output", ShaderUsage::UseUniform);
  sh.AddOutput("present_state", ShaderUsage::UseUniform);

  // Compile-time switches baked into the generated WGSL.
  const bool silu_enabled = activation_ == CausalConvActivation::Silu;
  return WGSL_TEMPLATE_APPLY(sh, "bert/causal_conv_with_state.wgsl.template",
                             WGSL_TEMPLATE_PARAMETER(has_bias, has_bias_),
                             WGSL_TEMPLATE_PARAMETER(has_conv_state, has_conv_state_),
                             WGSL_TEMPLATE_PARAMETER(use_silu, silu_enabled));
}

// Runs a depthwise causal 1-D convolution with a rolling state:
//   output (B, D, L) = causal_conv(input, weight[, bias])[, activation]
//   present_state (B, D, K-1) = last K-1 columns of [conv_state, input]
// Validates shapes, allocates both outputs, and dispatches one shader thread
// per output element.
Status CausalConvWithState::ComputeInternal(ComputeContext& context) const {
  const Tensor* input = context.Input(0);       // (B, D, L)
  const Tensor* weight = context.Input(1);      // (D, 1, K) depthwise filters
  const Tensor* bias = context.Input(2);        // optional (D,)
  const Tensor* conv_state = context.Input(3);  // optional (B, D, K-1) — past state

  ORT_RETURN_IF(activation_ == CausalConvActivation::Invalid, "Invalid activation type");
  ORT_RETURN_IF(ndim_ != 1, "Only 1D convolution is supported");
  const auto& input_shape = input->Shape();
  const auto& weight_shape = weight->Shape();

  ORT_RETURN_IF(input_shape.NumDimensions() != 3,
                "Input must be 3D (batch_size, channels, length)");
  ORT_RETURN_IF(weight_shape.NumDimensions() != 3,
                "Weight must be 3D (channels, 1, kernel_size)");

  const int64_t batch_size = input_shape[0];
  const int64_t channels = input_shape[1];
  const int64_t input_length = input_shape[2];
  const int64_t kernel_size = weight_shape[2];
  const int64_t state_length = kernel_size - 1;  // history carried between steps

  ORT_RETURN_IF(weight_shape[0] != channels, "Weight first dim must match input channels");
  ORT_RETURN_IF(weight_shape[1] != 1, "Weight second dim must be 1 for depthwise convolution");

  if (bias != nullptr) {
    ORT_RETURN_IF(bias->Shape().NumDimensions() != 1, "Bias must be 1D");
    ORT_RETURN_IF(bias->Shape()[0] != channels, "Bias size must match channels");
  }

  if (conv_state != nullptr) {
    ORT_RETURN_IF(conv_state->Shape().NumDimensions() != 3,
                  "conv_state must be 3D (batch_size, channels, kernel_size - 1)");
    ORT_RETURN_IF(conv_state->Shape()[0] != batch_size,
                  "conv_state batch_size must match input");
    ORT_RETURN_IF(conv_state->Shape()[1] != channels,
                  "conv_state channels must match input");
    ORT_RETURN_IF(conv_state->Shape()[2] != state_length,
                  "conv_state last dim must be kernel_size - 1");
  }

  const bool has_bias = (bias != nullptr);
  const bool has_conv_state = (conv_state != nullptr);

  // Allocate outputs.
  // Output 0: (B, D, L)
  Tensor* output = context.Output(0, input_shape);

  // Output 1: present_state (B, D, K-1)
  std::vector<int64_t> state_dims{batch_size, channels, state_length};
  Tensor* present_state = context.Output(1, TensorShape(state_dims));

  // Empty input: nothing to convolve, but present_state must still be
  // produced — carry the past state forward, or zero-fill when there is none.
  // BUGFIX: return in BOTH branches. Previously only the zero-fill branch
  // returned, so the conv_state path fell through and dispatched the shader
  // on an empty input.
  if (input_shape.Size() == 0) {
    if (has_conv_state) {
      ORT_RETURN_IF_ERROR(context.CopyTensor(*conv_state, *present_state));
    } else {
      context.FillZero(*present_state);
    }
    return Status::OK();
  }

  // Create and run the shader program.
  CausalConvWithStateProgram program{activation_, has_bias, has_conv_state};

  const uint32_t output_size = static_cast<uint32_t>(batch_size * channels * input_length);

  // These flags/sizes change the generated WGSL, so they must be in the cache key.
  program.CacheHint(has_bias, has_conv_state, kernel_size, static_cast<int>(activation_));

  // Input order must match the AddInput order in GenerateShaderCode.
  program.AddInput({input, ProgramTensorMetadataDependency::Type})
      .AddInput({weight, ProgramTensorMetadataDependency::None});

  if (has_bias) {
    program.AddInput({bias, ProgramTensorMetadataDependency::None});
  }
  if (has_conv_state) {
    program.AddInput({conv_state, ProgramTensorMetadataDependency::None});
  }

  program.AddOutput({output, ProgramTensorMetadataDependency::None})
      .AddOutput({present_state, ProgramTensorMetadataDependency::None})
      .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
      .AddUniformVariable({static_cast<uint32_t>(batch_size)})
      .AddUniformVariable({static_cast<uint32_t>(channels)})
      .AddUniformVariable({static_cast<uint32_t>(input_length)})
      .AddUniformVariable({static_cast<uint32_t>(kernel_size)})
      .AddUniformVariable({static_cast<uint32_t>(state_length)})
      .AddUniformVariable({output_size});

  return context.RunProgram(program);
}

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
65 changes: 65 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/causal_conv_with_state.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <string>

#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/webgpu_kernel.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

// Pull in only the WebGPU-provider names this header uses; a header-scope
// using-directive would leak the whole namespace into every includer
// (cpplint build/namespaces).
using onnxruntime::webgpu::ComputeContext;
using onnxruntime::webgpu::Program;
using onnxruntime::webgpu::ProgramUniformVariableDataType;
using onnxruntime::webgpu::ShaderHelper;
using onnxruntime::webgpu::WebGpuKernel;

// Activation applied to the convolution result.
enum class CausalConvActivation {
  Invalid,  // unrecognized "activation" attribute; rejected at compute time
  None,     // identity
  Silu      // x * sigmoid(x); also accepts the alias "swish"
};

// Maps the "activation" attribute string to the enum:
// "silu"/"swish" -> Silu; "none" or empty -> None; anything else -> Invalid.
CausalConvActivation ParseCausalConvActivation(const std::string& activation_str);

// Shader program for CausalConvWithState. Generates a depthwise causal 1-D
// convolution kernel from a WGSL template, specialized at shader-build time
// on whether bias / conv_state inputs exist and on the activation.
class CausalConvWithStateProgram final : public Program<CausalConvWithStateProgram> {
 public:
  CausalConvWithStateProgram(CausalConvActivation activation, bool has_bias, bool has_conv_state)
      : Program{"CausalConvWithState"},
        activation_(activation),
        has_bias_(has_bias),
        has_conv_state_(has_conv_state) {}

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  // Uniforms are bound in this exact order by ComputeInternal.
  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
      {"batch_size", ProgramUniformVariableDataType::Uint32},
      {"channels", ProgramUniformVariableDataType::Uint32},
      {"input_length", ProgramUniformVariableDataType::Uint32},
      {"kernel_size", ProgramUniformVariableDataType::Uint32},
      {"state_length", ProgramUniformVariableDataType::Uint32},   // = kernel_size - 1
      {"output_size", ProgramUniformVariableDataType::Uint32});   // batch * channels * length

 private:
  CausalConvActivation activation_;  // selects the template's use_silu switch
  bool has_bias_;                    // optional bias input present
  bool has_conv_state_;              // optional past-state input present
};

// WebGPU kernel for the com.microsoft.CausalConvWithState op: a depthwise
// causal 1-D convolution that consumes an optional past state and emits the
// rolling (kernel_size - 1) state for the next step.
class CausalConvWithState final : public WebGpuKernel {
 public:
  // `explicit` prevents accidental implicit conversion from OpKernelInfo.
  explicit CausalConvWithState(const OpKernelInfo& info);
  Status ComputeInternal(ComputeContext& context) const override;

 private:
  // Brace-initialized so the members are never read uninitialized; the
  // constructor always overwrites both (ORT_ENFORCE on "ndim").
  CausalConvActivation activation_{CausalConvActivation::Invalid};
  int64_t ndim_{0};  // "ndim" attribute; only 1 is supported
};

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#param has_bias
#param has_conv_state
#param use_silu

#use guardAgainstOutOfBoundsWorkgroupSizes

#if use_silu
// SiLU / swish activation: x * sigmoid(x), written as x / (1 + e^-x).
// input_element_t is the template's element type alias (f32 or f16).
fn silu(x: input_element_t) -> input_element_t {
  return x / (1.0 + exp(-x));
}
#endif

$MAIN {
  guardAgainstOutOfBoundsWorkgroupSizes(uniforms.output_size);

  let batch_size = uniforms.batch_size;
  let channels = uniforms.channels;
  let input_length = uniforms.input_length;
  let kernel_size = uniforms.kernel_size;
  let state_length = uniforms.state_length; // = kernel_size - 1

  // One thread per output element; decompose the flat thread index into
  // (batch_idx, channel_idx, pos).
  let pos = global_idx % input_length;
  let bc_idx = global_idx / input_length;
  let batch_idx = bc_idx / channels;
  let channel_idx = bc_idx % channels;

  // Perform depthwise causal convolution for this (batch, channel, pos).
  // The convolution window looks back kernel_size-1 positions.
  // With conv_state providing the history before position 0, the
  // "virtual" input is: [conv_state[0..state_length-1], input[0..L-1]]
  //
  // For output position pos:
  //   output[pos] = sum_{j=0}^{kernel_size-1} weight[j] * virtual_input[pos + j]
  // where virtual_input is state_length positions of conv_state
  // followed by input_length positions of input.

  var acc: input_element_t = 0.0;

  // Weight layout: (D, 1, K) -> channel_idx * kernel_size + j
  let weight_base = channel_idx * kernel_size;

  for (var j: u32 = 0; j < kernel_size; j = j + 1) {
    // virtual_pos is the position in the concatenated [conv_state, input]
    let virtual_pos = pos + j;

    var val: input_element_t = 0.0;

#if has_conv_state
    if (virtual_pos < state_length) {
      // Read history from conv_state: (B, D, state_length)
      let state_idx = (batch_idx * channels + channel_idx) * state_length + virtual_pos;
      val = conv_state[state_idx];
    } else {
      // Read from input: (B, D, L)
      let input_pos = virtual_pos - state_length;
      let input_idx = (batch_idx * channels + channel_idx) * input_length + input_pos;
      val = input[input_idx];
    }
#else
    // No conv_state: zero-pad positions before the start of the input,
    // leaving val at its 0.0 initialization.
    if (virtual_pos >= state_length) {
      let input_pos = virtual_pos - state_length;
      let input_idx = (batch_idx * channels + channel_idx) * input_length + input_pos;
      val = input[input_idx];
    }
#endif

    let w = weight[weight_base + j];
    acc = acc + val * w;
  }

#if has_bias
  acc = acc + bias[channel_idx];
#endif

#if use_silu
  acc = silu(acc);
#endif

  // Write output: (B, D, L)
  let out_idx = (batch_idx * channels + channel_idx) * input_length + pos;
  output[out_idx] = acc;

  // Write present_state: the last (kernel_size - 1) elements from the
  // virtual input [conv_state, input]. Only the thread at pos == 0 writes,
  // so each (batch, channel) slice is written exactly once.
  // NOTE(review): when input_length < state_length, vp = input_length + s can
  // fall back into the conv_state region (the `vp < state_length` branch) —
  // presumably intended for short chunks; confirm against the reference kernel.
  if (pos == 0u) {
    for (var s: u32 = 0; s < state_length; s = s + 1) {
      var state_val: input_element_t = 0.0;
      // total_len = state_length + input_length
      // We want virtual_input[total_len - state_length + s] = virtual_input[input_length + s]
      let vp = input_length + s;

#if has_conv_state
      if (vp < state_length) {
        let si = (batch_idx * channels + channel_idx) * state_length + vp;
        state_val = conv_state[si];
      } else {
        let ip = vp - state_length;
        let ii = (batch_idx * channels + channel_idx) * input_length + ip;
        state_val = input[ii];
      }
#else
      // Without conv_state, positions before the input are zeros.
      if (vp >= state_length) {
        let ip = vp - state_length;
        let ii = (batch_idx * channels + channel_idx) * input_length + ip;
        state_val = input[ii];
      }
#endif

      let ps_idx = (batch_idx * channels + channel_idx) * state_length + s;
      present_state[ps_idx] = state_val;
    }
  }
} // MAIN
Loading
Loading