-
Notifications
You must be signed in to change notification settings - Fork 3.9k
+rotemb, +rmsnorm, reshape->opset-25, transpose->opset-24 #27752
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
146 changes: 146 additions & 0 deletions
146
onnxruntime/core/providers/webgpu/llm/rotary_embedding.cc
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,146 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #include "core/providers/webgpu/webgpu_supported_types.h" | ||
| #include "core/providers/webgpu/llm/rotary_embedding.h" | ||
| #include "contrib_ops/webgpu/bert/rotary_embedding.h" | ||
| #include "core/providers/webgpu/generator/range.h" | ||
|
|
||
| namespace onnxruntime { | ||
| namespace webgpu { | ||
|
|
||
| ONNX_OPERATOR_KERNEL_EX( | ||
| RotaryEmbedding, | ||
| kOnnxDomain, | ||
| 23, | ||
| kWebGpuExecutionProvider, | ||
| (*KernelDefBuilder::Create()) | ||
| .TypeConstraint("T", WebGpuSupportedFloatTypes()) | ||
| .TypeConstraint("M", DataTypeImpl::GetTensorType<int64_t>()), | ||
| RotaryEmbedding); | ||
|
|
||
| RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : WebGpuKernel(info) { | ||
| rotary_embedding_dim_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0)); | ||
| num_heads_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("num_heads", 0)); | ||
| interleaved_ = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1); | ||
| } | ||
|
|
||
| Status RotaryEmbedding::ComputeInternal(ComputeContext& context) const { | ||
| // ONNX inputs: X(0), cos_cache(1), sin_cache(2), position_ids(3, optional) | ||
| const auto* input = context.Input<Tensor>(0); | ||
| const auto* cos_cache = context.Input<Tensor>(1); | ||
| const auto* sin_cache = context.Input<Tensor>(2); | ||
| const auto* position_ids = context.Input<Tensor>(3); // optional | ||
|
|
||
| const auto input_shape = input->Shape(); | ||
| auto* output = context.Output(0, input_shape); | ||
|
|
||
| const auto batch_size = onnxruntime::narrow<uint32_t>(input_shape[0]); | ||
| const auto batch_stride = onnxruntime::narrow<uint32_t>(input_shape.SizeFromDimension(1)); | ||
| const auto sequence_length = onnxruntime::narrow<uint32_t>(input_shape[input_shape.NumDimensions() - 2]); | ||
| const auto hidden_size = batch_stride / sequence_length; | ||
| const auto half_rotary_embedding_dim = onnxruntime::narrow<uint32_t>(cos_cache->Shape()[cos_cache->Shape().NumDimensions() - 1]); | ||
|
|
||
| // Compute head_size: when rotary_embedding_dim is not set, head_size = rotary_dim (= 2 * half). | ||
| // When rotary_embedding_dim is set, derive head_size from the 4D input shape or num_heads attribute. | ||
| uint32_t head_size; | ||
| if (rotary_embedding_dim_ == 0) { | ||
| head_size = half_rotary_embedding_dim * 2; | ||
| } else if (input_shape.NumDimensions() == 4) { | ||
| // 4D input: [batch, num_heads, seq, head_size] | ||
| head_size = onnxruntime::narrow<uint32_t>(input_shape[3]); | ||
| } else { | ||
| ORT_ENFORCE(num_heads_ > 0, | ||
| "Attribute 'num_heads' must be provided when 'rotary_embedding_dim' is specified " | ||
| "and input is not rank-4 (batch, num_heads, sequence, head)."); | ||
| head_size = hidden_size / num_heads_; | ||
| } | ||
|
|
||
| const TensorShape global_shape({batch_size, | ||
| sequence_length, | ||
| hidden_size / head_size, | ||
| head_size - half_rotary_embedding_dim}); | ||
|
|
||
| const auto rank = global_shape.NumDimensions(); | ||
| std::vector<uint32_t> global_dims(rank); | ||
| std::vector<uint32_t> global_strides(rank); | ||
| for (size_t j = 0; j < rank; ++j) { | ||
| global_dims[j] = onnxruntime::narrow<uint32_t>(global_shape[j]); | ||
| global_strides[j] = onnxruntime::narrow<uint32_t>(global_shape.SizeFromDimension(j + 1)); | ||
| } | ||
|
|
||
| const auto output_size = onnxruntime::narrow<const uint32_t>(global_shape.Size()); | ||
| const auto input_output_strides = | ||
| input_shape.NumDimensions() == 3 | ||
| ? std::vector<uint32_t>({batch_stride, hidden_size, head_size, 1}) | ||
| : (input_shape.NumDimensions() == 4 | ||
| ? std::vector<uint32_t>({batch_stride, head_size, sequence_length * head_size, 1}) | ||
| : std::vector<uint32_t>({})); | ||
|
Check warning on line 78 in onnxruntime/core/providers/webgpu/llm/rotary_embedding.cc
|
||
|
|
||
| // The contrib RotaryEmbeddingProgram expects inputs in order: | ||
| // input(0), position_ids(1), cos_cache(2), sin_cache(3) | ||
| // The ONNX op has: X(0), cos_cache(1), sin_cache(2), position_ids(3, optional) | ||
|
|
||
| if (position_ids != nullptr) { | ||
| // position_ids provided: cos/sin cache is 2D (max_pos, D/2) | ||
| contrib::webgpu::RotaryEmbeddingProgram program{interleaved_}; | ||
|
guschmue marked this conversation as resolved.
|
||
| program | ||
| .CacheHint(interleaved_) | ||
| .AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank}, | ||
| {position_ids, ProgramTensorMetadataDependency::Rank}, | ||
| {cos_cache, ProgramTensorMetadataDependency::Rank}, | ||
| {sin_cache, ProgramTensorMetadataDependency::Rank}}) | ||
| .AddOutput({output, ProgramTensorMetadataDependency::None}) | ||
| .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) | ||
| .AddUniformVariables({{1.0f}, | ||
| {gsl::make_span(global_dims)}, | ||
| {gsl::make_span(global_strides)}, | ||
| {gsl::make_span(input_output_strides)}}) | ||
| .AddIndices(TensorShape{1, 1}); | ||
| return context.RunProgram(program); | ||
| } | ||
|
|
||
| // position_ids NOT provided: cos/sin cache is 3D (B, S, D/2) | ||
| // Reshape to 2D (B*S, D/2) and generate sequential position_ids. | ||
| const auto total_seq = batch_size * sequence_length; | ||
| const TensorShape cache_2d_shape({static_cast<int64_t>(total_seq), | ||
| static_cast<int64_t>(half_rotary_embedding_dim)}); | ||
|
|
||
| // Generate position_ids [0, 1, ..., B*S-1] reshaped as (B, S) on GPU using RangeProgram | ||
| const TensorShape pos_ids_shape({static_cast<int64_t>(batch_size), | ||
| static_cast<int64_t>(sequence_length)}); | ||
| Tensor pos_ids_tensor = context.CreateGPUTensor(DataTypeImpl::GetType<int64_t>(), pos_ids_shape); | ||
| { | ||
| RangeProgram range_program{ONNX_NAMESPACE::TensorProto_DataType_INT64}; | ||
| int32_t start_i32 = 0; | ||
| int32_t delta_i32 = 1; | ||
| range_program | ||
| .AddOutput({&pos_ids_tensor, ProgramTensorMetadataDependency::Type}) | ||
| .SetDispatchGroupSize((total_seq + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) | ||
| .AddUniformVariables({ | ||
| total_seq, | ||
| std::bit_cast<uint32_t>(start_i32), | ||
| std::bit_cast<uint32_t>(delta_i32), | ||
| }); | ||
| ORT_RETURN_IF_ERROR(context.RunProgram(range_program)); | ||
| } | ||
|
|
||
| contrib::webgpu::RotaryEmbeddingProgram program{interleaved_}; | ||
| program | ||
| .CacheHint(interleaved_) | ||
| .AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank}, | ||
| {&pos_ids_tensor, ProgramTensorMetadataDependency::Rank}, | ||
| {cos_cache, ProgramTensorMetadataDependency::Rank, cache_2d_shape, 1}, | ||
| {sin_cache, ProgramTensorMetadataDependency::Rank, cache_2d_shape, 1}}) | ||
| .AddOutput({output, ProgramTensorMetadataDependency::None}) | ||
| .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) | ||
| .AddUniformVariables({{1.0f}, | ||
| {gsl::make_span(global_dims)}, | ||
| {gsl::make_span(global_strides)}, | ||
| {gsl::make_span(input_output_strides)}}) | ||
| .AddIndices(TensorShape{1, 1}); | ||
| return context.RunProgram(program); | ||
| } | ||
|
|
||
| } // namespace webgpu | ||
| } // namespace onnxruntime | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #pragma once | ||
|
|
||
| #include "core/providers/webgpu/webgpu_kernel.h" | ||
|
|
||
| namespace onnxruntime { | ||
| namespace webgpu { | ||
|
|
||
| class RotaryEmbedding final : public WebGpuKernel { | ||
| public: | ||
| RotaryEmbedding(const OpKernelInfo& info); | ||
| Status ComputeInternal(ComputeContext& context) const override; | ||
|
|
||
| private: | ||
| int num_heads_; | ||
| int rotary_embedding_dim_; | ||
| bool interleaved_; | ||
| }; | ||
|
|
||
| } // namespace webgpu | ||
| } // namespace onnxruntime |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,121 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #include "core/providers/webgpu/shader_helper.h" | ||
| #include "core/providers/webgpu/webgpu_supported_types.h" | ||
| #include "core/providers/webgpu/webgpu_utils.h" | ||
| #include "core/providers/webgpu/nn/rms_norm.h" | ||
| #include "core/providers/webgpu/nn/layer_norm.h" | ||
|
|
||
| namespace onnxruntime { | ||
| namespace webgpu { | ||
|
|
||
| static size_t NormalizeAxis(int64_t axis, size_t tensor_rank) { | ||
| int64_t rank = static_cast<int64_t>(tensor_rank); | ||
| if (axis < -rank && axis >= rank) { | ||
| ORT_THROW("invalid axis: ", axis); | ||
| } | ||
| return onnxruntime::narrow<size_t>(axis < 0 ? axis + rank : axis); | ||
| } | ||
|
|
||
| static TensorShape GetOverrideShape(const TensorShape& shape, int components) { | ||
| TensorShape override_shape{shape.Size() / components}; | ||
| return override_shape; | ||
| } | ||
|
|
||
| Status RMSNorm::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { | ||
| const auto* x = context.Input(0); | ||
| const auto* scale = context.Input(1); | ||
|
|
||
| const auto x_shape = x->Shape(); | ||
|
|
||
| const size_t axis = NormalizeAxis(axis_, x_shape.NumDimensions()); | ||
| const uint32_t norm_count = onnxruntime::narrow<uint32_t>(x_shape.SizeToDimension(axis)); | ||
| const int64_t norm_size = x_shape.SizeFromDimension(axis); | ||
| const int components = GetMaxComponents(norm_size); | ||
| const uint32_t norm_size_vectorized = onnxruntime::narrow<uint32_t>((norm_size + components - 1) / components); | ||
|
|
||
| const auto& scale_shape = scale->Shape(); | ||
| const auto scale_size = scale_shape.Size(); | ||
| if (scale_shape.NumDimensions() > x_shape.NumDimensions()) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, | ||
| "Scale and (optional) bias must match X.shape[axis:] or be NumPy-broadcastable to it." | ||
| " Scale/Bias rank cannot exceed Input rank."); | ||
| } | ||
| if (scale_size != norm_size) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, | ||
| "Size of X.shape()[axis:] == ", norm_size, | ||
| ". Size of scale must match this. Got scale size of ", | ||
| scale_size); | ||
| } | ||
|
|
||
| // RMSNormalization outputs: Y (index 0), InvStdDev (index 1, optional) | ||
| auto* y = context.Output(0, x_shape); | ||
|
|
||
| TensorShapeVector inv_std_dev_dim; | ||
| for (size_t i = 0; i < x_shape.NumDimensions(); ++i) { | ||
| if (i < axis) { | ||
| inv_std_dev_dim.push_back(x_shape[i]); | ||
| } else { | ||
| inv_std_dev_dim.push_back(1); | ||
| } | ||
| } | ||
| TensorShape inv_std_dev_shape(inv_std_dev_dim); | ||
| auto* inv_std_dev = context.Output(1, inv_std_dev_shape); | ||
|
|
||
| if (x_shape.Size() == 0) { | ||
| return Status::OK(); | ||
| } | ||
|
|
||
| // Check if we should use split norm dimension optimization | ||
| const bool split_norm_dim = norm_size % 512 == 0 && norm_count == 1; | ||
|
|
||
| // Reuse LayerNormProgram with simplified=true, has_bias=false, no mean output | ||
| LayerNormProgram program{/*has_bias=*/false, /*simplified=*/true, /*has_mean_output=*/false, | ||
| /*has_inv_std_dev_output=*/inv_std_dev != nullptr, split_norm_dim}; | ||
|
|
||
| program.CacheHint(components, /*simplified=*/true, split_norm_dim) | ||
| .AddInputs({{x, ProgramTensorMetadataDependency::Type, GetOverrideShape(x->Shape(), components), components}}) | ||
| .AddInputs( | ||
| {{scale, ProgramTensorMetadataDependency::Type, GetOverrideShape(scale->Shape(), components), components}}) | ||
| .AddOutputs({{y, ProgramTensorMetadataDependency::None, GetOverrideShape(y->Shape(), components), components}}) | ||
| .AddUniformVariables({ | ||
| {static_cast<uint32_t>(components)}, | ||
| }) | ||
| .AddUniformVariables({ | ||
| {static_cast<uint32_t>(norm_count)}, | ||
| }) | ||
| .AddUniformVariables({ | ||
| {static_cast<uint32_t>(norm_size)}, | ||
| }) | ||
| .AddUniformVariables({ | ||
| {static_cast<uint32_t>(norm_size_vectorized)}, | ||
| }) | ||
| .AddUniformVariables({ | ||
| {static_cast<float>(epsilon_)}, | ||
| }); | ||
|
|
||
| if (split_norm_dim) { | ||
| const uint32_t workgroup_size_x = 128; | ||
| const uint32_t dispatch_size_x = onnxruntime::narrow<uint32_t>(norm_size / (workgroup_size_x * components)); | ||
| program.SetDispatchGroupSize(dispatch_size_x, 1, 1) | ||
| .SetWorkgroupSize(workgroup_size_x); | ||
| } else { | ||
| program.SetDispatchGroupSize(norm_count); | ||
| } | ||
|
|
||
| if (inv_std_dev != nullptr) { | ||
| program.AddOutputs({{inv_std_dev, ProgramTensorMetadataDependency::None}}); | ||
| } | ||
|
|
||
| return context.RunProgram(program); | ||
| } | ||
|
|
||
| ONNX_OPERATOR_KERNEL_EX(RMSNormalization, kOnnxDomain, 23, kWebGpuExecutionProvider, | ||
| (*KernelDefBuilder::Create()) | ||
| .TypeConstraint("T", WebGpuSupportedFloatTypes()) | ||
| .TypeConstraint("V", WebGpuSupportedFloatTypes()), | ||
| RMSNorm); | ||
|
|
||
| } // namespace webgpu | ||
| } // namespace onnxruntime |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #pragma once | ||
|
|
||
| #include "core/providers/webgpu/webgpu_kernel.h" | ||
|
|
||
| namespace onnxruntime { | ||
| namespace webgpu { | ||
|
|
||
| class RMSNorm final : public WebGpuKernel { | ||
| public: | ||
| RMSNorm(const OpKernelInfo& info) : WebGpuKernel(info) { | ||
| info.GetAttrOrDefault<int64_t>("axis", &axis_, -1); | ||
| info.GetAttrOrDefault<float>("epsilon", &epsilon_, 1e-05f); | ||
| } | ||
|
|
||
| Status ComputeInternal(ComputeContext& context) const override; | ||
|
|
||
| private: | ||
| int64_t axis_; | ||
| float epsilon_; | ||
| }; | ||
|
|
||
| } // namespace webgpu | ||
| } // namespace onnxruntime |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.