diff --git a/onnxruntime/core/providers/webgpu/llm/rotary_embedding.cc b/onnxruntime/core/providers/webgpu/llm/rotary_embedding.cc
new file mode 100644
index 0000000000000..ee46c76f1ea54
--- /dev/null
+++ b/onnxruntime/core/providers/webgpu/llm/rotary_embedding.cc
@@ -0,0 +1,146 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/webgpu/webgpu_supported_types.h"
+#include "core/providers/webgpu/llm/rotary_embedding.h"
+#include "contrib_ops/webgpu/bert/rotary_embedding.h"
+#include "core/providers/webgpu/generator/range.h"
+
+namespace onnxruntime {
+namespace webgpu {
+
+ONNX_OPERATOR_KERNEL_EX(
+    RotaryEmbedding,
+    kOnnxDomain,
+    23,
+    kWebGpuExecutionProvider,
+    (*KernelDefBuilder::Create())
+        .TypeConstraint("T", WebGpuSupportedFloatTypes())
+        .TypeConstraint("M", DataTypeImpl::GetTensorType<int64_t>()),
+    RotaryEmbedding);
+
+RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : WebGpuKernel(info) {
+  rotary_embedding_dim_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0));
+  num_heads_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("num_heads", 0));
+  interleaved_ = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1);
+}
+
+Status RotaryEmbedding::ComputeInternal(ComputeContext& context) const {
+  // ONNX inputs: X(0), cos_cache(1), sin_cache(2), position_ids(3, optional)
+  const auto* input = context.Input<Tensor>(0);
+  const auto* cos_cache = context.Input<Tensor>(1);
+  const auto* sin_cache = context.Input<Tensor>(2);
+  const auto* position_ids = context.Input<Tensor>(3);  // optional
+
+  const auto input_shape = input->Shape();
+  auto* output = context.Output(0, input_shape);
+
+  const auto batch_size = onnxruntime::narrow<uint32_t>(input_shape[0]);
+  const auto batch_stride = onnxruntime::narrow<uint32_t>(input_shape.SizeFromDimension(1));
+  const auto sequence_length = onnxruntime::narrow<uint32_t>(input_shape[input_shape.NumDimensions() - 2]);
+  const auto hidden_size = batch_stride / sequence_length;
+  const auto half_rotary_embedding_dim =
+      onnxruntime::narrow<uint32_t>(cos_cache->Shape()[cos_cache->Shape().NumDimensions() - 1]);
+
+  // Compute head_size: when rotary_embedding_dim is not set, head_size = rotary_dim (= 2 * half).
+  // When rotary_embedding_dim is set, derive head_size from the 4D input shape or num_heads attribute.
+  uint32_t head_size;
+  if (rotary_embedding_dim_ == 0) {
+    head_size = half_rotary_embedding_dim * 2;
+  } else if (input_shape.NumDimensions() == 4) {
+    // 4D input: [batch, num_heads, seq, head_size]
+    head_size = onnxruntime::narrow<uint32_t>(input_shape[3]);
+  } else {
+    ORT_ENFORCE(num_heads_ > 0,
+                "Attribute 'num_heads' must be provided when 'rotary_embedding_dim' is specified "
+                "and input is not rank-4 (batch, num_heads, sequence, head).");
+    head_size = hidden_size / num_heads_;
+  }
+
+  const TensorShape global_shape({batch_size,
+                                  sequence_length,
+                                  hidden_size / head_size,
+                                  head_size - half_rotary_embedding_dim});
+
+  const auto rank = global_shape.NumDimensions();
+  std::vector<uint32_t> global_dims(rank);
+  std::vector<uint32_t> global_strides(rank);
+  for (size_t j = 0; j < rank; ++j) {
+    global_dims[j] = onnxruntime::narrow<uint32_t>(global_shape[j]);
+    global_strides[j] = onnxruntime::narrow<uint32_t>(global_shape.SizeFromDimension(j + 1));
+  }
+
+  const auto output_size = onnxruntime::narrow<uint32_t>(global_shape.Size());
+  const auto input_output_strides =
+      input_shape.NumDimensions() == 3
+          ? std::vector<uint32_t>({batch_stride, hidden_size, head_size, 1})
+          : (input_shape.NumDimensions() == 4
+                 ? std::vector<uint32_t>({batch_stride, head_size, sequence_length * head_size, 1})
+                 : std::vector<uint32_t>({}));
+
+  // The contrib RotaryEmbeddingProgram expects inputs in order:
+  //   input(0), position_ids(1), cos_cache(2), sin_cache(3)
+  // The ONNX op has: X(0), cos_cache(1), sin_cache(2), position_ids(3, optional)
+
+  if (position_ids != nullptr) {
+    // position_ids provided: cos/sin cache is 2D (max_pos, D/2)
+    contrib::webgpu::RotaryEmbeddingProgram program{interleaved_};
+    program
+        .CacheHint(interleaved_)
+        .AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank},
+                    {position_ids, ProgramTensorMetadataDependency::Rank},
+                    {cos_cache, ProgramTensorMetadataDependency::Rank},
+                    {sin_cache, ProgramTensorMetadataDependency::Rank}})
+        .AddOutput({output, ProgramTensorMetadataDependency::None})
+        .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
+        .AddUniformVariables({{1.0f},
+                              {gsl::make_span(global_dims)},
+                              {gsl::make_span(global_strides)},
+                              {gsl::make_span(input_output_strides)}})
+        .AddIndices(TensorShape{1, 1});
+    return context.RunProgram(program);
+  }
+
+  // position_ids NOT provided: cos/sin cache is 3D (B, S, D/2)
+  // Reshape to 2D (B*S, D/2) and generate sequential position_ids.
+  const auto total_seq = batch_size * sequence_length;
+  const TensorShape cache_2d_shape({static_cast<int64_t>(total_seq),
+                                    static_cast<int64_t>(half_rotary_embedding_dim)});
+
+  // Generate position_ids [0, 1, ..., B*S-1] reshaped as (B, S) on GPU using RangeProgram
+  const TensorShape pos_ids_shape({static_cast<int64_t>(batch_size),
+                                   static_cast<int64_t>(sequence_length)});
+  Tensor pos_ids_tensor = context.CreateGPUTensor(DataTypeImpl::GetType<int64_t>(), pos_ids_shape);
+  {
+    RangeProgram range_program{ONNX_NAMESPACE::TensorProto_DataType_INT64};
+    int32_t start_i32 = 0;
+    int32_t delta_i32 = 1;
+    range_program
+        .AddOutput({&pos_ids_tensor, ProgramTensorMetadataDependency::Type})
+        .SetDispatchGroupSize((total_seq + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
+        .AddUniformVariables({
+            total_seq,
+            std::bit_cast<uint32_t>(start_i32),
+            std::bit_cast<uint32_t>(delta_i32),
+        });
+    ORT_RETURN_IF_ERROR(context.RunProgram(range_program));
+  }
+
+  contrib::webgpu::RotaryEmbeddingProgram program{interleaved_};
+  program
+      .CacheHint(interleaved_)
+      .AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank},
+                  {&pos_ids_tensor, ProgramTensorMetadataDependency::Rank},
+                  {cos_cache, ProgramTensorMetadataDependency::Rank, cache_2d_shape, 1},
+                  {sin_cache, ProgramTensorMetadataDependency::Rank, cache_2d_shape, 1}})
+      .AddOutput({output, ProgramTensorMetadataDependency::None})
+      .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
+      .AddUniformVariables({{1.0f},
+                            {gsl::make_span(global_dims)},
+                            {gsl::make_span(global_strides)},
+                            {gsl::make_span(input_output_strides)}})
+      .AddIndices(TensorShape{1, 1});
+  return context.RunProgram(program);
+}
+
+}  // namespace webgpu
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webgpu/llm/rotary_embedding.h b/onnxruntime/core/providers/webgpu/llm/rotary_embedding.h
new file mode 100644
index 0000000000000..6a3f60e8b75e3
--- /dev/null
+++ b/onnxruntime/core/providers/webgpu/llm/rotary_embedding.h
@@ -0,0 +1,23 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/providers/webgpu/webgpu_kernel.h"
+
+namespace onnxruntime {
+namespace webgpu {
+
+class RotaryEmbedding final : public WebGpuKernel {
+ public:
+  RotaryEmbedding(const OpKernelInfo& info);
+  Status ComputeInternal(ComputeContext& context) const override;
+
+ private:
+  int num_heads_;
+  int rotary_embedding_dim_;
+  bool interleaved_;
+};
+
+}  // namespace webgpu
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webgpu/nn/rms_norm.cc b/onnxruntime/core/providers/webgpu/nn/rms_norm.cc
new file mode 100644
index 0000000000000..250b1153beb8b
--- /dev/null
+++ b/onnxruntime/core/providers/webgpu/nn/rms_norm.cc
@@ -0,0 +1,121 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/webgpu/shader_helper.h"
+#include "core/providers/webgpu/webgpu_supported_types.h"
+#include "core/providers/webgpu/webgpu_utils.h"
+#include "core/providers/webgpu/nn/rms_norm.h"
+#include "core/providers/webgpu/nn/layer_norm.h"
+
+namespace onnxruntime {
+namespace webgpu {
+
+// Maps a possibly-negative `axis` into [0, tensor_rank); throws on out-of-range values.
+static size_t NormalizeAxis(int64_t axis, size_t tensor_rank) {
+  int64_t rank = static_cast<int64_t>(tensor_rank);
+  // NOTE: must be ||, not && — an axis cannot be both below -rank and at/above rank,
+  // so the && form would never reject an invalid axis.
+  if (axis < -rank || axis >= rank) {
+    ORT_THROW("invalid axis: ", axis);
+  }
+  return onnxruntime::narrow<size_t>(axis < 0 ? axis + rank : axis);
+}
+
+static TensorShape GetOverrideShape(const TensorShape& shape, int components) {
+  TensorShape override_shape{shape.Size() / components};
+  return override_shape;
+}
+
+Status RMSNorm::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
+  const auto* x = context.Input<Tensor>(0);
+  const auto* scale = context.Input<Tensor>(1);
+
+  const auto x_shape = x->Shape();
+
+  const size_t axis = NormalizeAxis(axis_, x_shape.NumDimensions());
+  const uint32_t norm_count = onnxruntime::narrow<uint32_t>(x_shape.SizeToDimension(axis));
+  const int64_t norm_size = x_shape.SizeFromDimension(axis);
+  const int components = GetMaxComponents(norm_size);
+  const uint32_t norm_size_vectorized = onnxruntime::narrow<uint32_t>((norm_size + components - 1) / components);
+
+  const auto& scale_shape = scale->Shape();
+  const auto scale_size = scale_shape.Size();
+  if (scale_shape.NumDimensions() > x_shape.NumDimensions()) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "Scale and (optional) bias must match X.shape[axis:] or be NumPy-broadcastable to it."
+                           " Scale/Bias rank cannot exceed Input rank.");
+  }
+  if (scale_size != norm_size) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "Size of X.shape()[axis:] == ", norm_size,
+                           ". Size of scale must match this. Got scale size of ",
+                           scale_size);
+  }
+
+  // RMSNormalization outputs: Y (index 0), InvStdDev (index 1, optional)
+  auto* y = context.Output(0, x_shape);
+
+  TensorShapeVector inv_std_dev_dim;
+  for (size_t i = 0; i < x_shape.NumDimensions(); ++i) {
+    if (i < axis) {
+      inv_std_dev_dim.push_back(x_shape[i]);
+    } else {
+      inv_std_dev_dim.push_back(1);
+    }
+  }
+  TensorShape inv_std_dev_shape(inv_std_dev_dim);
+  auto* inv_std_dev = context.Output(1, inv_std_dev_shape);
+
+  if (x_shape.Size() == 0) {
+    return Status::OK();
+  }
+
+  // Check if we should use split norm dimension optimization
+  const bool split_norm_dim = norm_size % 512 == 0 && norm_count == 1;
+
+  // Reuse LayerNormProgram with simplified=true, has_bias=false, no mean output
+  LayerNormProgram program{/*has_bias=*/false, /*simplified=*/true, /*has_mean_output=*/false,
+                           /*has_inv_std_dev_output=*/inv_std_dev != nullptr, split_norm_dim};
+
+  program.CacheHint(components, /*simplified=*/true, split_norm_dim)
+      .AddInputs({{x, ProgramTensorMetadataDependency::Type, GetOverrideShape(x->Shape(), components), components}})
+      .AddInputs(
+          {{scale, ProgramTensorMetadataDependency::Type, GetOverrideShape(scale->Shape(), components), components}})
+      .AddOutputs({{y, ProgramTensorMetadataDependency::None, GetOverrideShape(y->Shape(), components), components}})
+      .AddUniformVariables({
+          {static_cast<uint32_t>(components)},
+      })
+      .AddUniformVariables({
+          {static_cast<uint32_t>(norm_count)},
+      })
+      .AddUniformVariables({
+          {static_cast<uint32_t>(norm_size)},
+      })
+      .AddUniformVariables({
+          {static_cast<uint32_t>(norm_size_vectorized)},
+      })
+      .AddUniformVariables({
+          {static_cast<float>(epsilon_)},
+      });
+
+  if (split_norm_dim) {
+    const uint32_t workgroup_size_x = 128;
+    const uint32_t dispatch_size_x = onnxruntime::narrow<uint32_t>(norm_size / (workgroup_size_x * components));
+    program.SetDispatchGroupSize(dispatch_size_x, 1, 1)
+        .SetWorkgroupSize(workgroup_size_x);
+  } else {
+    program.SetDispatchGroupSize(norm_count);
+  }
+
+  if (inv_std_dev != nullptr) {
+    program.AddOutputs({{inv_std_dev, ProgramTensorMetadataDependency::None}});
+  }
+
+  return context.RunProgram(program);
+}
+
+ONNX_OPERATOR_KERNEL_EX(RMSNormalization, kOnnxDomain, 23, kWebGpuExecutionProvider,
+                        (*KernelDefBuilder::Create())
+                            .TypeConstraint("T", WebGpuSupportedFloatTypes())
+                            .TypeConstraint("V", WebGpuSupportedFloatTypes()),
+                        RMSNorm);
+
+}  // namespace webgpu
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webgpu/nn/rms_norm.h b/onnxruntime/core/providers/webgpu/nn/rms_norm.h
new file mode 100644
index 0000000000000..47da51f6df4a1
--- /dev/null
+++ b/onnxruntime/core/providers/webgpu/nn/rms_norm.h
@@ -0,0 +1,26 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/providers/webgpu/webgpu_kernel.h"
+
+namespace onnxruntime {
+namespace webgpu {
+
+class RMSNorm final : public WebGpuKernel {
+ public:
+  RMSNorm(const OpKernelInfo& info) : WebGpuKernel(info) {
+    info.GetAttrOrDefault<int64_t>("axis", &axis_, -1);
+    info.GetAttrOrDefault("epsilon", &epsilon_, 1e-05f);
+  }
+
+  Status ComputeInternal(ComputeContext& context) const override;
+
+ private:
+  int64_t axis_;
+  float epsilon_;
+};
+
+}  // namespace webgpu
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webgpu/tensor/reshape.cc b/onnxruntime/core/providers/webgpu/tensor/reshape.cc
index 9ede015a0c99c..26546d59220fa 100644
--- a/onnxruntime/core/providers/webgpu/tensor/reshape.cc
+++ b/onnxruntime/core/providers/webgpu/tensor/reshape.cc
@@ -11,7 +11,31 @@ namespace webgpu {
 ONNX_OPERATOR_KERNEL_EX(
     Reshape,
     kOnnxDomain,
-    21,
+    25,
+    kWebGpuExecutionProvider,
+    (*KernelDefBuilder::Create())
+        .TypeConstraint("T", WebGpuSupportedNumberTypes())
+        .TypeConstraint("shape", DataTypeImpl::GetTensorType<int64_t>())
+        .Alias(0, 0)
+        .InputMemoryType(OrtMemTypeCPU, 1),
+    Reshape);
+
+ONNX_OPERATOR_VERSIONED_KERNEL_EX(
+    Reshape,
+    kOnnxDomain,
+    23, 24,
+    kWebGpuExecutionProvider,
+    (*KernelDefBuilder::Create())
+        .TypeConstraint("T", WebGpuSupportedNumberTypes())
+        .TypeConstraint("shape", DataTypeImpl::GetTensorType<int64_t>())
+        .Alias(0, 0)
+        .InputMemoryType(OrtMemTypeCPU, 1),
+    Reshape);
+
+ONNX_OPERATOR_VERSIONED_KERNEL_EX(
+    Reshape,
+    kOnnxDomain,
+    21, 22,
     kWebGpuExecutionProvider,
     (*KernelDefBuilder::Create())
         .TypeConstraint("T", WebGpuSupportedNumberTypes())
diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.cc b/onnxruntime/core/providers/webgpu/tensor/transpose.cc
index 5cc09501ab378..4f45305666c32 100644
--- a/onnxruntime/core/providers/webgpu/tensor/transpose.cc
+++ b/onnxruntime/core/providers/webgpu/tensor/transpose.cc
@@ -63,10 +63,19 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
         .TypeConstraint("T", WebGpuSupportedNumberTypes()),
     Transpose);
 
+ONNX_OPERATOR_VERSIONED_KERNEL_EX(
+    Transpose,
+    kOnnxDomain,
+    23, 23,
+    kWebGpuExecutionProvider,
+    (*KernelDefBuilder::Create())
+        .TypeConstraint("T", WebGpuSupportedNumberTypes()),
+    Transpose);
+
 ONNX_OPERATOR_KERNEL_EX(
     Transpose,
     kOnnxDomain,
-    23,
+    24,
     kWebGpuExecutionProvider,
     (*KernelDefBuilder::Create())
         .TypeConstraint("T", WebGpuSupportedNumberTypes()),
diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
index 3a802a996d95c..84aa4d137ad82 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
+++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
@@ -248,7 +248,9 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxD
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Reshape);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, 18, Reshape);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, Reshape);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, Reshape);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, 22, Reshape);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, 24, Reshape);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 25, Reshape);
 
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, Identity);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Identity);
@@ -281,7 +283,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16,
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, Transpose);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 20, Transpose);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, 22, Transpose);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, Transpose);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, 23, Transpose);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 24, Transpose);
 
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, DepthToSpace);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, DepthToSpace);
@@ -391,8 +394,13 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxD
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Tile);
 
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 17, LayerNormalization);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, RMSNormalization);
+
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, RotaryEmbedding);
+
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 21, LpNormalization);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 22, LpNormalization);
+
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 5, InstanceNormalization);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 1, 5, InstanceNormalization);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 21, InstanceNormalization);
@@ -557,7 +565,9 @@ std::unique_ptr<KernelRegistry> RegisterKernels(bool enable_graph_capture, bool
         BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Reshape)>,
         BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, 18, Reshape)>,
         BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, Reshape)>,
-        BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, Reshape)>,
+        BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, 22, Reshape)>,
+        BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, 24, Reshape)>,
+        BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 25, Reshape)>,
 
         BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, Identity)>,
         BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Identity)>,
@@ -634,7 +644,8 @@ std::unique_ptr<KernelRegistry> RegisterKernels(bool enable_graph_capture, bool
         BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, Transpose)>,
         BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 20, Transpose)>,
         BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, 22, Transpose)>,
-        BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, Transpose)>,
+        BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, 23, Transpose)>,
+        BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 24, Transpose)>,
 
         BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, DepthToSpace)>,
         BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, DepthToSpace)>,
@@ -743,6 +754,9 @@ std::unique_ptr<KernelRegistry> RegisterKernels(bool enable_graph_capture, bool
         BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Tile)>,
         BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 17, LayerNormalization)>,
+        BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, RMSNormalization)>,
+
+        BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, RotaryEmbedding)>,
         BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 21, LpNormalization)>,
         BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 22, LpNormalization)>,
 