diff --git a/onnxruntime/contrib_ops/webgpu/bert/bias_add.cc b/onnxruntime/contrib_ops/webgpu/bert/bias_add.cc index 65c14e8cb0bdd..e822f8764b63f 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/bias_add.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/bias_add.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_utils.h" #include "core/providers/webgpu/webgpu_supported_types.h" #include "contrib_ops/webgpu/bert/bias_add.h" #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" @@ -34,15 +35,6 @@ Status BiasAddProgram::GenerateShaderCode(ShaderHelper& shader) const { return Status::OK(); } -static int64_t GetMaxComponents(int64_t size) { - if (size % 4 == 0) { - return 4; - } else if (size % 2 == 0) { - return 2; - } - return 1; -} - Status BiasAdd::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { const auto* input = context.Input(0); const auto* bias = context.Input(1); diff --git a/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc index d5d4632c01e2a..61f701f7911a7 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_utils.h" #include "core/providers/webgpu/webgpu_supported_types.h" #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" #include "contrib_ops/webgpu/bert/skip_layer_norm.h" @@ -10,28 +11,6 @@ namespace onnxruntime { namespace contrib { namespace webgpu { -static uint32_t GetMaxComponents(int size) { - if (size % 4 == 0) { - return 4; - } else if (size % 2 == 0) { - return 2; - } - return 1; -} - -static std::string SumVector(std::string x, int components) { - switch (components) { - case 1: - return x; - case 2: - return "(" + x + ".x + " + x + ".y" + ")"; - case 4: - return "(" + x + ".x + " + x + ".y + " + x + ".w + " + x + ".z" + ")"; - default: - ORT_THROW("Unsupported number of components: ", components); - } -} - Status SkipLayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); shader.AddInput("skip", ShaderUsage::UseUniform); diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index b4e47b9186265..be105a0fd4374 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -10,23 +10,13 @@ #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/webgpu_utils.h" namespace onnxruntime { namespace contrib { namespace webgpu { namespace { -// Put it to a common place? -uint32_t GetMaxComponents(uint32_t size) { - // we cannot use vec3 type since it has alignment of 16 bytes - if (size % 4 == 0) { - return 4; - } else if (size % 2 == 0) { - return 2; - } - - return 1; -} std::string QuantizedDataType(int components) { switch (components) { diff --git a/onnxruntime/core/providers/webgpu/math/softmax.cc b/onnxruntime/core/providers/webgpu/math/softmax.cc index d06fc5a57eb8c..6a6cfe154b91c 100644 --- a/onnxruntime/core/providers/webgpu/math/softmax.cc +++ b/onnxruntime/core/providers/webgpu/math/softmax.cc @@ -11,6 +11,7 @@ #include "core/providers/webgpu/shader_variable.h" #include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/webgpu_utils.h" namespace onnxruntime { namespace webgpu { @@ -56,28 +57,6 @@ static std::string MaxVector(const std::string& name, int components) { } } -static std::string SumVector(const std::string& x, int components) { - switch (components) { - case 1: - return x; - case 2: - return "(" + x + ".x + " + x + ".y" + ")"; - case 4: - return "(" + x + ".x + " + x + ".y + " + x + ".w + " + x + ".z" + ")"; - default: - ORT_THROW("Unsupported number of components: ", components); - } -} - -static int GetMaxComponents(int64_t size) { - if (size % 4 == 0) { - return 4; - } else if (size % 2 == 0) { - return 2; - } - return 1; -} - Status SoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { // Add input and output variables const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); diff --git a/onnxruntime/core/providers/webgpu/nn/layer_norm.cc b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc index 28ad686909a47..cf2939555057a 100644 --- a/onnxruntime/core/providers/webgpu/nn/layer_norm.cc +++ b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc @@ -4,20 +4,12 @@ #include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/webgpu_utils.h" #include "core/providers/webgpu/nn/layer_norm.h" namespace onnxruntime { namespace webgpu { -static int GetMaxComponents(int64_t size) { - if (size % 4 == 0) { - return 4; - } else if (size % 2 == 0) { - return 2; - } - return 1; -} - static size_t NormalizeAxis(int64_t axis, size_t tensor_rank) { int64_t rank = static_cast(tensor_rank); if (axis < -rank && axis >= rank) { @@ -26,19 +18,6 @@ static size_t NormalizeAxis(int64_t axis, size_t tensor_rank) { return onnxruntime::narrow(axis < 0 ? axis + rank : axis); } -static std::string SumVector(std::string x, int components) { - switch (components) { - case 1: - return x; - case 2: - return "(" + x + ".x + " + x + ".y" + ")"; - case 4: - return "(" + x + ".x + " + x + ".y + " + x + ".w + " + x + ".z" + ")"; - default: - ORT_THROW("Unsupported number of components: ", components); - } -} - Status LayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); shader.AddInput("scale", ShaderUsage::UseUniform); diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.h b/onnxruntime/core/providers/webgpu/webgpu_utils.h index 4f9018646905d..eb25a9bd5386e 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_utils.h +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.h @@ -7,7 +7,7 @@ namespace onnxruntime { namespace webgpu { -inline int64_t GetMaxComponents(int64_t size) { +inline int GetMaxComponents(int64_t size) { if (size % 4 == 0) { return 4; } else if (size % 2 == 0) { @@ -16,5 +16,18 @@ inline int64_t GetMaxComponents(int64_t size) { return 1; } +inline std::string SumVector(std::string x, int components) { + switch (components) { + case 1: + return x; + case 2: + return "(" + x + ".x + " + x + ".y" + ")"; + case 4: + return "(" + x + ".x + " + x + ".y + " + x + ".z + " + x + ".w" + ")"; + default: + ORT_THROW("Unsupported number of components: ", components); + } +} + } // namespace webgpu -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime