Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 1 addition & 9 deletions onnxruntime/contrib_ops/webgpu/bert/bias_add.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_utils.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "contrib_ops/webgpu/bert/bias_add.h"
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
Expand Down Expand Up @@ -34,15 +35,6 @@ Status BiasAddProgram::GenerateShaderCode(ShaderHelper& shader) const {
return Status::OK();
}

// Returns the largest of 4, 2 or 1 that evenly divides `size`.
// A width of 3 is never produced.
static int64_t GetMaxComponents(int64_t size) {
  const bool divisible_by_four = (size % 4) == 0;
  const bool divisible_by_two = (size % 2) == 0;
  return divisible_by_four ? 4 : (divisible_by_two ? 2 : 1);
}

Status BiasAdd::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
const auto* input = context.Input(0);
const auto* bias = context.Input(1);
Expand Down
23 changes: 1 addition & 22 deletions onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_utils.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
#include "contrib_ops/webgpu/bert/skip_layer_norm.h"
Expand All @@ -10,28 +11,6 @@ namespace onnxruntime {
namespace contrib {
namespace webgpu {

// Chooses how many components (4, 2 or 1) `size` can be split into evenly.
static uint32_t GetMaxComponents(int size) {
  // Odd sizes cannot be grouped at all.
  if (size % 2 != 0) {
    return 1;
  }
  // Even sizes: prefer groups of four when possible, otherwise pairs.
  return (size % 4 == 0) ? 4u : 2u;
}

// Builds a shader expression string that adds up the lanes of the vector
// expression `x`.
//   components == 1: `x` is returned unchanged.
//   components == 2: "(x.x + x.y)"
//   components == 4: "(x.x + x.y + x.w + x.z)"  (note: w is listed before z)
// Any other component count is reported through ORT_THROW.
static std::string SumVector(std::string x, int components) {
  if (components == 1) {
    return x;
  }
  if (components == 2) {
    return "(" + x + ".x + " + x + ".y" + ")";
  }
  if (components == 4) {
    return "(" + x + ".x + " + x + ".y + " + x + ".w + " + x + ".z" + ")";
  }
  ORT_THROW("Unsupported number of components: ", components);
}

Status SkipLayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
shader.AddInput("skip", ShaderUsage::UseUniform);
Expand Down
12 changes: 1 addition & 11 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,13 @@
#include "core/providers/cpu/math/matmul_helper.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/webgpu/webgpu_utils.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

namespace {
// TODO: move this helper to a common place?
// Returns the widest component count (4, 2 or 1) that evenly divides `size`.
uint32_t GetMaxComponents(uint32_t size) {
  // we cannot use vec3 type since it has alignment of 16 bytes,
  // so only the remainders modulo 4 matter here.
  switch (size % 4) {
    case 0:
      return 4;
    case 2:
      // Even, but not a multiple of four.
      return 2;
    default:
      // Remainder 1 or 3: odd size.
      return 1;
  }
}

std::string QuantizedDataType(int components) {
switch (components) {
Expand Down
23 changes: 1 addition & 22 deletions onnxruntime/core/providers/webgpu/math/softmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "core/providers/webgpu/shader_variable.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/webgpu/webgpu_utils.h"
namespace onnxruntime {
namespace webgpu {

Expand Down Expand Up @@ -56,28 +57,6 @@ static std::string MaxVector(const std::string& name, int components) {
}
}

// Produces a shader expression string that sums every lane of the vector
// expression `x`. Supported lane counts are 1 (returns `x` as-is), 2 and 4;
// anything else is reported through ORT_THROW.
// For four lanes the terms are emitted in x, y, w, z order.
static std::string SumVector(const std::string& x, int components) {
  switch (components) {
    case 1:
      return x;
    case 2: {
      std::string expr = "(";
      expr += x;
      expr += ".x + ";
      expr += x;
      expr += ".y";
      expr += ")";
      return expr;
    }
    case 4: {
      std::string expr = "(";
      expr += x;
      expr += ".x + ";
      expr += x;
      expr += ".y + ";
      expr += x;
      expr += ".w + ";
      expr += x;
      expr += ".z";
      expr += ")";
      return expr;
    }
    default:
      ORT_THROW("Unsupported number of components: ", components);
  }
}

// Returns 4 when `size` is a multiple of four, 2 when it is merely even,
// and 1 otherwise.
static int GetMaxComponents(int64_t size) {
  int components = 1;
  if (size % 2 == 0) {
    components = 2;
  }
  // Checked last so that a multiple of four overrides the "even" result.
  if (size % 4 == 0) {
    components = 4;
  }
  return components;
}

Status SoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
// Add input and output variables
const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
Expand Down
23 changes: 1 addition & 22 deletions onnxruntime/core/providers/webgpu/nn/layer_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,12 @@

#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/webgpu/webgpu_utils.h"
#include "core/providers/webgpu/nn/layer_norm.h"

namespace onnxruntime {
namespace webgpu {

// Finds the largest candidate among 4 and 2 that evenly divides `size`,
// falling back to 1 when neither does.
static int GetMaxComponents(int64_t size) {
  // Candidates are tried from widest to narrowest: 4, then 2.
  for (int candidate = 4; candidate > 1; candidate /= 2) {
    if (size % candidate == 0) {
      return candidate;
    }
  }
  return 1;
}

static size_t NormalizeAxis(int64_t axis, size_t tensor_rank) {
int64_t rank = static_cast<int64_t>(tensor_rank);
if (axis < -rank && axis >= rank) {
Expand All @@ -26,19 +18,6 @@ static size_t NormalizeAxis(int64_t axis, size_t tensor_rank) {
return onnxruntime::narrow<size_t>(axis < 0 ? axis + rank : axis);
}

// Returns a shader expression string that adds together the lanes of the
// vector expression `x`:
//   1 component  -> `x` itself
//   2 components -> "(x.x + x.y)"
//   4 components -> "(x.x + x.y + x.w + x.z)"
// Other component counts are reported through ORT_THROW.
static std::string SumVector(std::string x, int components) {
  // Appends a swizzle suffix (e.g. ".x") to the vector expression.
  const auto lane = [&x](const char* suffix) { return x + suffix; };

  switch (components) {
    case 1:
      return x;
    case 2:
      return "(" + lane(".x") + " + " + lane(".y") + ")";
    case 4:
      // Lane order x, y, w, z is kept exactly as emitted before.
      return "(" + lane(".x") + " + " + lane(".y") + " + " + lane(".w") + " + " + lane(".z") + ")";
    default:
      ORT_THROW("Unsupported number of components: ", components);
  }
}

Status LayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
shader.AddInput("scale", ShaderUsage::UseUniform);
Expand Down
17 changes: 15 additions & 2 deletions onnxruntime/core/providers/webgpu/webgpu_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
namespace onnxruntime {
namespace webgpu {

inline int64_t GetMaxComponents(int64_t size) {
inline int GetMaxComponents(int64_t size) {
if (size % 4 == 0) {
return 4;
} else if (size % 2 == 0) {
Expand All @@ -16,5 +16,18 @@
return 1;
}

inline std::string SumVector(std::string x, int components) {

Check warning on line 19 in onnxruntime/core/providers/webgpu/webgpu_utils.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webgpu/webgpu_utils.h:19: Add #include <string> for string [build/include_what_you_use] [4]
switch (components) {
case 1:
return x;
case 2:
return "(" + x + ".x + " + x + ".y" + ")";
case 4:
return "(" + x + ".x + " + x + ".y + " + x + ".z + " + x + ".w" + ")";
default:
ORT_THROW("Unsupported number of components: ", components);
}
}

} // namespace webgpu
} // namespace onnxruntime
} // namespace onnxruntime
Loading