146 changes: 146 additions & 0 deletions onnxruntime/core/providers/webgpu/llm/rotary_embedding.cc
@@ -0,0 +1,146 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/webgpu/llm/rotary_embedding.h"
#include "contrib_ops/webgpu/bert/rotary_embedding.h"
#include "core/providers/webgpu/generator/range.h"

namespace onnxruntime {
namespace webgpu {

ONNX_OPERATOR_KERNEL_EX(
RotaryEmbedding,
kOnnxDomain,
23,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", WebGpuSupportedFloatTypes())
.TypeConstraint("M", DataTypeImpl::GetTensorType<int64_t>()),
RotaryEmbedding);

RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : WebGpuKernel(info) {
rotary_embedding_dim_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0));
num_heads_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("num_heads", 0));
interleaved_ = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1);
}

Status RotaryEmbedding::ComputeInternal(ComputeContext& context) const {
// ONNX inputs: X(0), cos_cache(1), sin_cache(2), position_ids(3, optional)
const auto* input = context.Input<Tensor>(0);
const auto* cos_cache = context.Input<Tensor>(1);
const auto* sin_cache = context.Input<Tensor>(2);
const auto* position_ids = context.Input<Tensor>(3); // optional

const auto input_shape = input->Shape();
auto* output = context.Output(0, input_shape);

const auto batch_size = onnxruntime::narrow<uint32_t>(input_shape[0]);
const auto batch_stride = onnxruntime::narrow<uint32_t>(input_shape.SizeFromDimension(1));
const auto sequence_length = onnxruntime::narrow<uint32_t>(input_shape[input_shape.NumDimensions() - 2]);
const auto hidden_size = batch_stride / sequence_length;
const auto half_rotary_embedding_dim = onnxruntime::narrow<uint32_t>(cos_cache->Shape()[cos_cache->Shape().NumDimensions() - 1]);

// Compute head_size: when rotary_embedding_dim is not set, head_size = rotary_dim (= 2 * half).
// When rotary_embedding_dim is set, derive head_size from the 4D input shape or num_heads attribute.
uint32_t head_size;
if (rotary_embedding_dim_ == 0) {
head_size = half_rotary_embedding_dim * 2;
} else if (input_shape.NumDimensions() == 4) {
// 4D input: [batch, num_heads, seq, head_size]
head_size = onnxruntime::narrow<uint32_t>(input_shape[3]);
} else {
ORT_ENFORCE(num_heads_ > 0,
"Attribute 'num_heads' must be provided when 'rotary_embedding_dim' is specified "
"and input is not rank-4 (batch, num_heads, sequence, head).");
head_size = hidden_size / num_heads_;
}
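// Example (illustrative shapes): a 3D input [2, 8, 512] with num_heads = 8 gives
// head_size = 512 / 8 = 64; a 4D input [2, 8, 16, 64] yields head_size = 64 directly.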

const TensorShape global_shape({batch_size,
sequence_length,
hidden_size / head_size,
head_size - half_rotary_embedding_dim});

const auto rank = global_shape.NumDimensions();
std::vector<uint32_t> global_dims(rank);
std::vector<uint32_t> global_strides(rank);
for (size_t j = 0; j < rank; ++j) {
global_dims[j] = onnxruntime::narrow<uint32_t>(global_shape[j]);
global_strides[j] = onnxruntime::narrow<uint32_t>(global_shape.SizeFromDimension(j + 1));
}

const auto output_size = onnxruntime::narrow<uint32_t>(global_shape.Size());
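// Map the logical (batch, sequence, head, element) index onto the flat
// input/output layout: 3D input is [B, S, hidden], 4D input is
// [B, num_heads, S, head_size], hence the different stride orders below.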
const auto input_output_strides =
input_shape.NumDimensions() == 3
? std::vector<uint32_t>({batch_stride, hidden_size, head_size, 1})
: (input_shape.NumDimensions() == 4
? std::vector<uint32_t>({batch_stride, head_size, sequence_length * head_size, 1})
: std::vector<uint32_t>({}));

// The contrib RotaryEmbeddingProgram expects inputs in order:
// input(0), position_ids(1), cos_cache(2), sin_cache(3)
// The ONNX op has: X(0), cos_cache(1), sin_cache(2), position_ids(3, optional)

if (position_ids != nullptr) {
// position_ids provided: cos/sin cache is 2D (max_pos, D/2)
contrib::webgpu::RotaryEmbeddingProgram program{interleaved_};
program
.CacheHint(interleaved_)
.AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank},
{position_ids, ProgramTensorMetadataDependency::Rank},
{cos_cache, ProgramTensorMetadataDependency::Rank},
{sin_cache, ProgramTensorMetadataDependency::Rank}})
.AddOutput({output, ProgramTensorMetadataDependency::None})
.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
.AddUniformVariables({{1.0f},
{gsl::make_span(global_dims)},
{gsl::make_span(global_strides)},
{gsl::make_span(input_output_strides)}})
.AddIndices(TensorShape{1, 1});
return context.RunProgram(program);
}

// position_ids NOT provided: cos/sin cache is 3D (B, S, D/2)
// Reshape to 2D (B*S, D/2) and generate sequential position_ids.
const auto total_seq = batch_size * sequence_length;
const TensorShape cache_2d_shape({static_cast<int64_t>(total_seq),
static_cast<int64_t>(half_rotary_embedding_dim)});

// Generate position_ids [0, 1, ..., B*S-1] reshaped as (B, S) on GPU using RangeProgram
const TensorShape pos_ids_shape({static_cast<int64_t>(batch_size),
static_cast<int64_t>(sequence_length)});
Tensor pos_ids_tensor = context.CreateGPUTensor(DataTypeImpl::GetType<int64_t>(), pos_ids_shape);
{
RangeProgram range_program{ONNX_NAMESPACE::TensorProto_DataType_INT64};
int32_t start_i32 = 0;
int32_t delta_i32 = 1;
range_program
.AddOutput({&pos_ids_tensor, ProgramTensorMetadataDependency::Type})
.SetDispatchGroupSize((total_seq + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
.AddUniformVariables({
total_seq,
std::bit_cast<uint32_t>(start_i32),
std::bit_cast<uint32_t>(delta_i32),
});
ORT_RETURN_IF_ERROR(context.RunProgram(range_program));
}

contrib::webgpu::RotaryEmbeddingProgram program{interleaved_};
program
.CacheHint(interleaved_)
.AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank},
{&pos_ids_tensor, ProgramTensorMetadataDependency::Rank},
{cos_cache, ProgramTensorMetadataDependency::Rank, cache_2d_shape, 1},
{sin_cache, ProgramTensorMetadataDependency::Rank, cache_2d_shape, 1}})
.AddOutput({output, ProgramTensorMetadataDependency::None})
.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
.AddUniformVariables({{1.0f},
{gsl::make_span(global_dims)},
{gsl::make_span(global_strides)},
{gsl::make_span(input_output_strides)}})
.AddIndices(TensorShape{1, 1});
return context.RunProgram(program);
}

} // namespace webgpu
} // namespace onnxruntime
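For reference, the rotation that the reused contrib RotaryEmbeddingProgram applies can be sketched on the CPU as below. This is an illustrative scalar version of the standard rotary-embedding math, not the WGSL shader this kernel dispatches; the helper name and loop structure are assumptions for exposition only.

#include <cstddef>
#include <vector>

// Rotate one head vector `x` (length head_size) in place, using cos/sin caches
// laid out as (position, rotary_dim / 2). Elements at index >= rotary_dim pass
// through unchanged, matching the rotary_embedding_dim < head_size case.
void ApplyRotarySketch(std::vector<float>& x,
                       const std::vector<float>& cos_cache,
                       const std::vector<float>& sin_cache,
                       std::size_t position, std::size_t rotary_dim,
                       bool interleaved) {
  const std::size_t half = rotary_dim / 2;
  for (std::size_t i = 0; i < half; ++i) {
    const float c = cos_cache[position * half + i];
    const float s = sin_cache[position * half + i];
    // interleaved pairs adjacent elements (0,1),(2,3),...; otherwise the
    // rotated span is split in two and element i pairs with i + half.
    const std::size_t i0 = interleaved ? 2 * i : i;
    const std::size_t i1 = interleaved ? 2 * i + 1 : i + half;
    const float x0 = x[i0];
    const float x1 = x[i1];
    x[i0] = x0 * c - x1 * s;
    x[i1] = x0 * s + x1 * c;
  }
}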
23 changes: 23 additions & 0 deletions onnxruntime/core/providers/webgpu/llm/rotary_embedding.h
@@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/webgpu_kernel.h"

namespace onnxruntime {
namespace webgpu {

class RotaryEmbedding final : public WebGpuKernel {
public:
RotaryEmbedding(const OpKernelInfo& info);
Status ComputeInternal(ComputeContext& context) const override;

private:
int num_heads_;
int rotary_embedding_dim_;
bool interleaved_;
};

} // namespace webgpu
} // namespace onnxruntime
121 changes: 121 additions & 0 deletions onnxruntime/core/providers/webgpu/nn/rms_norm.cc
@@ -0,0 +1,121 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/webgpu/webgpu_utils.h"
#include "core/providers/webgpu/nn/rms_norm.h"
#include "core/providers/webgpu/nn/layer_norm.h"

namespace onnxruntime {
namespace webgpu {

static size_t NormalizeAxis(int64_t axis, size_t tensor_rank) {
int64_t rank = static_cast<int64_t>(tensor_rank);
if (axis < -rank || axis >= rank) {
ORT_THROW("invalid axis: ", axis);
}
return onnxruntime::narrow<size_t>(axis < 0 ? axis + rank : axis);
}

static TensorShape GetOverrideShape(const TensorShape& shape, int components) {
TensorShape override_shape{shape.Size() / components};
return override_shape;
}

Status RMSNorm::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
const auto* x = context.Input(0);
const auto* scale = context.Input(1);

const auto x_shape = x->Shape();

const size_t axis = NormalizeAxis(axis_, x_shape.NumDimensions());
const uint32_t norm_count = onnxruntime::narrow<uint32_t>(x_shape.SizeToDimension(axis));
const int64_t norm_size = x_shape.SizeFromDimension(axis);
const int components = GetMaxComponents(norm_size);
const uint32_t norm_size_vectorized = onnxruntime::narrow<uint32_t>((norm_size + components - 1) / components);
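// Example (illustrative sizes): if norm_size = 4096 and GetMaxComponents selects
// 4-wide vectors, then components = 4 and norm_size_vectorized = 4096 / 4 = 1024.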

const auto& scale_shape = scale->Shape();
const auto scale_size = scale_shape.Size();
if (scale_shape.NumDimensions() > x_shape.NumDimensions()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Scale and (optional) bias must match X.shape[axis:] or be NumPy-broadcastable to it."
" Scale/Bias rank cannot exceed Input rank.");
}
if (scale_size != norm_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Size of X.shape()[axis:] == ", norm_size,
". Size of scale must match this. Got scale size of ",
scale_size);
}

// RMSNormalization outputs: Y (index 0), InvStdDev (index 1, optional)
auto* y = context.Output(0, x_shape);

TensorShapeVector inv_std_dev_dim;
for (size_t i = 0; i < x_shape.NumDimensions(); ++i) {
if (i < axis) {
inv_std_dev_dim.push_back(x_shape[i]);
} else {
inv_std_dev_dim.push_back(1);
}
}
TensorShape inv_std_dev_shape(inv_std_dev_dim);
auto* inv_std_dev = context.Output(1, inv_std_dev_shape);

if (x_shape.Size() == 0) {
return Status::OK();
}

// Check if we should use split norm dimension optimization
const bool split_norm_dim = norm_size % 512 == 0 && norm_count == 1;

// Reuse LayerNormProgram with simplified=true, has_bias=false, no mean output
LayerNormProgram program{/*has_bias=*/false, /*simplified=*/true, /*has_mean_output=*/false,
/*has_inv_std_dev_output=*/inv_std_dev != nullptr, split_norm_dim};

program.CacheHint(components, /*simplified=*/true, split_norm_dim)
.AddInputs({{x, ProgramTensorMetadataDependency::Type, GetOverrideShape(x->Shape(), components), components}})
.AddInputs(
{{scale, ProgramTensorMetadataDependency::Type, GetOverrideShape(scale->Shape(), components), components}})
.AddOutputs({{y, ProgramTensorMetadataDependency::None, GetOverrideShape(y->Shape(), components), components}})
.AddUniformVariables({{static_cast<uint32_t>(components)},
{norm_count},
{static_cast<uint32_t>(norm_size)},
{norm_size_vectorized},
{epsilon_}});

if (split_norm_dim) {
const uint32_t workgroup_size_x = 128;
const uint32_t dispatch_size_x = onnxruntime::narrow<uint32_t>(norm_size / (workgroup_size_x * components));
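// Example (illustrative sizes): norm_size = 4096 with components = 4 gives
// dispatch_size_x = 4096 / (128 * 4) = 8 workgroups along x.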
program.SetDispatchGroupSize(dispatch_size_x, 1, 1)
.SetWorkgroupSize(workgroup_size_x);
} else {
program.SetDispatchGroupSize(norm_count);
}

if (inv_std_dev != nullptr) {
program.AddOutputs({{inv_std_dev, ProgramTensorMetadataDependency::None}});
}

return context.RunProgram(program);
}

ONNX_OPERATOR_KERNEL_EX(RMSNormalization, kOnnxDomain, 23, kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", WebGpuSupportedFloatTypes())
.TypeConstraint("V", WebGpuSupportedFloatTypes()),
RMSNorm);

} // namespace webgpu
} // namespace onnxruntime
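For reference, what this kernel requests from LayerNormProgram with simplified=true amounts to RMS normalization, y = x * scale / sqrt(mean(x^2) + epsilon), applied independently to each of the norm_count rows of length norm_size. Below is a minimal scalar sketch of that formula; it is illustrative only (the function name is hypothetical, and this is not the dispatched WGSL).

#include <cmath>
#include <cstddef>

// Normalize one row of length norm_size; optionally report 1/RMS, which
// corresponds to the optional InvStdDev output (index 1).
void RmsNormRowSketch(const float* x, const float* scale, float* y,
                      std::size_t norm_size, float epsilon,
                      float* inv_std_dev = nullptr) {
  double sum_sq = 0.0;
  for (std::size_t i = 0; i < norm_size; ++i) {
    sum_sq += static_cast<double>(x[i]) * x[i];
  }
  const float inv_rms =
      1.0f / std::sqrt(static_cast<float>(sum_sq / norm_size) + epsilon);
  for (std::size_t i = 0; i < norm_size; ++i) {
    y[i] = x[i] * inv_rms * scale[i];
  }
  if (inv_std_dev != nullptr) {
    *inv_std_dev = inv_rms;
  }
}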
26 changes: 26 additions & 0 deletions onnxruntime/core/providers/webgpu/nn/rms_norm.h
@@ -0,0 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/webgpu_kernel.h"

namespace onnxruntime {
namespace webgpu {

class RMSNorm final : public WebGpuKernel {
public:
RMSNorm(const OpKernelInfo& info) : WebGpuKernel(info) {
info.GetAttrOrDefault<int64_t>("axis", &axis_, -1);
info.GetAttrOrDefault<float>("epsilon", &epsilon_, 1e-05f);
}

Status ComputeInternal(ComputeContext& context) const override;

private:
int64_t axis_;
float epsilon_;
};

} // namespace webgpu
} // namespace onnxruntime
26 changes: 25 additions & 1 deletion onnxruntime/core/providers/webgpu/tensor/reshape.cc
@@ -11,7 +11,31 @@ namespace webgpu {
ONNX_OPERATOR_KERNEL_EX(
Reshape,
kOnnxDomain,
-  21,
+  25,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", WebGpuSupportedNumberTypes())
.TypeConstraint("shape", DataTypeImpl::GetTensorType<int64_t>())
.Alias(0, 0)
.InputMemoryType(OrtMemTypeCPU, 1),
Reshape);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Reshape,
kOnnxDomain,
23, 24,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", WebGpuSupportedNumberTypes())
.TypeConstraint("shape", DataTypeImpl::GetTensorType<int64_t>())
.Alias(0, 0)
.InputMemoryType(OrtMemTypeCPU, 1),
Reshape);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Reshape,
kOnnxDomain,
21, 22,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", WebGpuSupportedNumberTypes())
11 changes: 10 additions & 1 deletion onnxruntime/core/providers/webgpu/tensor/transpose.cc
@@ -63,10 +63,19 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
.TypeConstraint("T", WebGpuSupportedNumberTypes()),
Transpose);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Transpose,
kOnnxDomain,
23, 23,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", WebGpuSupportedNumberTypes()),
Transpose);

ONNX_OPERATOR_KERNEL_EX(
Transpose,
kOnnxDomain,
-  23,
+  24,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", WebGpuSupportedNumberTypes()),