-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[WebGPU EP] Implements CumSum Operator #24047
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
17 commits
Select commit
Hold shift + click to select a range
6be1e87
cum sum impl
prathikr 981e18f
Merge remote-tracking branch 'origin' into prathikrao/cumsum-webgpu-ep
prathikr 8737179
Merge remote-tracking branch 'origin' into prathikrao/cumsum-webgpu-ep
prathikr a703c82
address satya comments
prathikr 774309d
remove extra streams
prathikr fb79b18
do everything in MainFunctionBody
prathikr db71f46
int64_t to int
prathikr 507a41d
output_indices_t
prathikr 31ba47a
let to var
prathikr 1b75762
cast
prathikr 555684b
cast
prathikr 199900d
final cast
prathikr b7e122f
move cast
prathikr 2debcd6
another cast
prathikr cd1d914
format
prathikr 47abdf5
output_value_t
prathikr 3eda700
yulong comments
prathikr File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,98 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #include "core/providers/webgpu/math/cum_sum.h" | ||
| #include "core/providers/webgpu/shader_helper.h" | ||
| #include "core/providers/webgpu/webgpu_supported_types.h" | ||
|
|
||
| namespace onnxruntime { | ||
| namespace webgpu { | ||
|
|
||
| ONNX_OPERATOR_VERSIONED_KERNEL_EX( | ||
| CumSum, | ||
| kOnnxDomain, | ||
| 11, 13, | ||
| kWebGpuExecutionProvider, | ||
| (*KernelDefBuilder::Create()) | ||
| .TypeConstraint("T", WebGpuSupportedFloatTypes()) | ||
| .TypeConstraint("T2", {DataTypeImpl::GetTensorType<int32_t>(), | ||
| DataTypeImpl::GetTensorType<int64_t>()}) | ||
| .InputMemoryType(OrtMemTypeCPU, 1), | ||
| CumSum); | ||
|
|
||
| ONNX_OPERATOR_KERNEL_EX( | ||
| CumSum, | ||
| kOnnxDomain, | ||
| 14, | ||
| kWebGpuExecutionProvider, | ||
| (*KernelDefBuilder::Create()) | ||
| .TypeConstraint("T", WebGpuSupportedFloatTypes()) | ||
| .TypeConstraint("T2", {DataTypeImpl::GetTensorType<int32_t>(), | ||
| DataTypeImpl::GetTensorType<int64_t>()}) | ||
| .InputMemoryType(OrtMemTypeCPU, 1), | ||
| CumSum); | ||
|
|
||
| Status CumSumProgram::GenerateShaderCode(ShaderHelper& shader) const { | ||
| const ShaderVariableHelper& input = shader.AddInput("input", ShaderUsage::UseUniform); | ||
| const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); | ||
|
|
||
| shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") | ||
| << "var input_indices = " << input.OffsetToIndices("global_idx") << ";\n" | ||
| << "var sum : output_value_t = 0;\n" | ||
| << "var first : i32 = 0;\n" | ||
| << "if (uniforms.reverse == 1) {\n" | ||
| << " first = i32(" + input.IndicesGet("input_indices", "uniforms.axis") + ");\n" | ||
| << " if (uniforms.exclusive == 1) { first += 1; }\n" | ||
| << "}\n\n" | ||
| << "var last : i32 = 0;\n" | ||
| << "if (uniforms.reverse == 1) {\n" | ||
| << " last = i32(" << GetElementAt("uniforms.input_shape", "uniforms.axis", input.Rank()) << ");\n" | ||
| << "} else {\n" | ||
| << " last = i32(" + input.IndicesGet("input_indices", "uniforms.axis") + ");\n" | ||
| << " if (uniforms.exclusive == 0) { last += 1; }\n" | ||
| << "}\n\n" | ||
| << "for (var i : i32 = first; i < last; i++) {\n" | ||
| << " " << input.IndicesSet("input_indices", "uniforms.axis", "u32(i)") << ";\n" | ||
| << " sum = sum + " << input.GetByIndices("input_indices") << ";\n" | ||
| << "}\n" | ||
| << output.SetByOffset("global_idx", "sum"); | ||
|
|
||
| return Status::OK(); | ||
| } | ||
|
|
||
| Status CumSum::ComputeInternal(ComputeContext& context) const { | ||
| const auto* input_tensor = context.Input(0); | ||
| const TensorShape& input_shape = input_tensor->Shape(); | ||
| int64_t input_rank = input_shape.NumDimensions(); | ||
|
|
||
| const auto* axis_tensor = context.Input(1); | ||
| const auto* axis_data = axis_tensor->Data<int>(); | ||
| int64_t axis = static_cast<int64_t>(axis_data[0]); | ||
|
|
||
| ORT_ENFORCE(-input_rank <= axis && axis < input_rank, "Axes attribute must be within range -input_rank <= axis < input_rank."); | ||
| // Handle negative axis | ||
| if (axis < 0) { | ||
| axis += input_rank; | ||
| } | ||
|
|
||
| auto* output_tensor = context.Output(0, input_shape); | ||
| int64_t output_size = output_tensor->Shape().Size(); | ||
|
|
||
| if (output_size == 0) { | ||
| return Status::OK(); | ||
| } | ||
|
|
||
| CumSumProgram program{}; | ||
| program | ||
| .AddInput({input_tensor}) | ||
| .AddOutput({output_tensor, ProgramTensorMetadataDependency::TypeAndRank}) | ||
| .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) | ||
| .AddUniformVariables({{static_cast<uint32_t>(output_size)}, | ||
| {static_cast<uint32_t>(axis)}, | ||
| {static_cast<uint32_t>(exclusive_)}, | ||
| {static_cast<uint32_t>(reverse_)}}); | ||
| return context.RunProgram(program); | ||
| } | ||
|
|
||
| } // namespace webgpu | ||
| } // namespace onnxruntime | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #pragma once | ||
|
|
||
| #include "core/providers/webgpu/webgpu_kernel.h" | ||
| #include "core/providers/webgpu/program.h" | ||
|
|
||
| namespace onnxruntime { | ||
| namespace webgpu { | ||
|
|
||
| class CumSumProgram final : public Program<CumSumProgram> { | ||
| public: | ||
| CumSumProgram() : Program{"CumSum"} {} | ||
|
|
||
| Status GenerateShaderCode(ShaderHelper& sh) const override; | ||
|
|
||
| WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32}, | ||
| {"axis", ProgramUniformVariableDataType::Uint32}, | ||
| {"exclusive", ProgramUniformVariableDataType::Uint32}, | ||
| {"reverse", ProgramUniformVariableDataType::Uint32}); | ||
| }; | ||
|
|
||
| class CumSum final : public WebGpuKernel { | ||
| public: | ||
| CumSum(const OpKernelInfo& info) : WebGpuKernel(info) { | ||
| exclusive_ = info.GetAttrOrDefault<int64_t>("exclusive", 0); | ||
| reverse_ = info.GetAttrOrDefault<int64_t>("reverse", 0); | ||
| } | ||
|
|
||
| Status ComputeInternal(ComputeContext& context) const override; | ||
|
|
||
| private: | ||
| int64_t exclusive_; | ||
| int64_t reverse_; | ||
| }; | ||
|
|
||
| } // namespace webgpu | ||
| } // namespace onnxruntime |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.