Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion cmake/onnxruntime_providers_webgpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,6 @@
endif()
endif()

target_compile_features(onnxruntime_providers_webgpu PRIVATE cxx_std_20)
add_dependencies(onnxruntime_providers_webgpu onnx ${onnxruntime_EXTERNAL_DEPENDENCIES})

if (onnxruntime_WGSL_TEMPLATE)
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/webgpu/bert/bias_add.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Status BiasAddProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< " let value = " << input.GetByOffset("global_idx")
<< " + " << bias.GetByOffset("global_idx % uniforms.channels")
<< " + " << residual.GetByOffset("global_idx") << ";\n"
<< " " + output.SetByOffset("global_idx", "value");
<< " " << output.SetByOffset("global_idx", "value");

return Status::OK();
}
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ Status RotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< " let j = i + select(half_rotary_emb_dim, 1, " << interleaved_str << ");\n"
<< " let re = " << input.GetByOffset("i") << " * " << cos_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << " - " << input.GetByOffset("j") << " * " << sin_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << ";\n"
<< " " << output.SetByOffset("i", "re") << "\n"
<< " let im = " << input.GetByOffset("i") << " * " << sin_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << " + " << input.GetByOffset("j") + " * " << cos_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << ";\n"
<< " let im = " << input.GetByOffset("i") << " * " << sin_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << " + " << input.GetByOffset("j") << " * " << cos_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << ";\n"
<< " " << output.SetByOffset("j", "im") << "\n"
<< " } else { \n"
" let k = dot(bsnh, uniforms.input_output_stride) + half_rotary_emb_dim;\n"
Expand Down
5 changes: 3 additions & 2 deletions onnxruntime/core/providers/webgpu/math/einsum.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "core/providers/webgpu/math/einsum.h"

#include <algorithm>
#include <cctype>
#include <regex>
#include <set>
#include <vector>
Expand All @@ -24,7 +25,7 @@ static const std::regex lhs_pattern("(([a-zA-Z]|\\.\\.\\.)*,)*([a-zA-Z]|\\.\\.\\
// Helper function to remove all whitespaces in a given string.
std::string RemoveAllWhitespace(const std::string& str) {
std::string result = str;
result.erase(std::remove_if(result.begin(), result.end(), ::isspace), result.end());
std::erase_if(result, [](unsigned char c) { return std::isspace(c); });
return result;
Comment thread
fs-eire marked this conversation as resolved.
}

Expand Down Expand Up @@ -318,7 +319,7 @@ Status EinsumProgram::GenerateShaderCode(ShaderHelper& shader) const {
symbol));

// Check if we've already processed this symbol to avoid duplicate loop generation
if (uniform_symbol_set.find(symbol) == uniform_symbol_set.end()) {
if (!uniform_symbol_set.contains(symbol)) {
// Add symbol to tracked set to prevent duplicate processing
uniform_symbol_set.insert(symbol);

Expand Down
12 changes: 4 additions & 8 deletions onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

#include <utility>
#include <cstring>
#include <limits>

#include "core/providers/webgpu/math/unary_elementwise_ops.h"
Expand Down Expand Up @@ -194,10 +195,6 @@ class Clip final : public UnaryElementwise {
"Clip",
std::is_same_v<T, MLFloat16> ? ClipF16Impl : ClipImpl,
"", ShaderUsage::UseElementTypeAlias} {}
#if defined(__GNUC__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif

Status ConfigureProgram(const ComputeContext& context, UnaryElementwiseProgram& program) const override {
const auto* clip_min_tensor = context.Input<Tensor>(1);
Expand All @@ -209,7 +206,9 @@ class Clip final : public UnaryElementwise {
: std::numeric_limits<T>::max()};
if constexpr (std::is_same_v<T, MLFloat16>) {
// F16: stores span<f16, 2> as a single float
float encoded_value = *reinterpret_cast<const float*>(attr);
float encoded_value;
static_assert(sizeof(encoded_value) == 2 * sizeof(MLFloat16));
std::memcpy(&encoded_value, attr, sizeof(encoded_value));
program.AddUniformVariable({encoded_value});
Comment thread
fs-eire marked this conversation as resolved.
} else {
static_assert(sizeof(T) == sizeof(float), "T must be f32, i32 or u32");
Expand All @@ -218,9 +217,6 @@ class Clip final : public UnaryElementwise {
}
return Status::OK();
}
#if defined(__GNUC__)
#pragma GCC diagnostic pop
#endif

// uniforms.attr is a f32 value. It is encoded as a float for 2 f16 values.
// bitcast<vec2<f16>>(uniforms.attr)[0] is clip_min, bitcast<vec2<f16>>(uniforms.attr)[1] is clip_max
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/webgpu/nn/pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ Status Pool<PoolType, is_nhwc>::ComputeInternal(ComputeContext& context) const {
Tensor* Y = context.Output(0, output_shape);

std::vector<uint32_t> kernel_strides(kernel_shape.size());
ORT_ENFORCE(kernel_shape.size() > 0, "kernel_shape must have at least one element.");
ORT_ENFORCE(!kernel_shape.empty(), "kernel_shape must have at least one element.");
// Calculate the kernel element strides for each dimension in reverse order. For example:
// kernel_shape = [3, 2], kernel_strides = [2, 1]
// kernel_shape = [2, 3, 2], kernel_strides = [6, 2, 1]
Expand Down
206 changes: 53 additions & 153 deletions onnxruntime/core/providers/webgpu/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#pragma once

#include <array>
#include <string>
#include <vector>
#include <iosfwd>
Expand Down Expand Up @@ -164,28 +165,19 @@ enum class ProgramTensorMetadataDependency : int {
};
OStringStream& operator<<(OStringStream& os, ProgramTensorMetadataDependency);

#if defined(__GNUC__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif

inline ProgramTensorMetadataDependency operator|(ProgramTensorMetadataDependency a, ProgramTensorMetadataDependency b) {
return (ProgramTensorMetadataDependency)((int&)a | (int&)b);
return static_cast<ProgramTensorMetadataDependency>(static_cast<int>(a) | static_cast<int>(b));
}
inline ProgramTensorMetadataDependency operator&(ProgramTensorMetadataDependency a, ProgramTensorMetadataDependency b) {
return (ProgramTensorMetadataDependency)((int&)a & (int&)b);
return static_cast<ProgramTensorMetadataDependency>(static_cast<int>(a) & static_cast<int>(b));
}
inline ProgramTensorMetadataDependency& operator|=(ProgramTensorMetadataDependency& a, ProgramTensorMetadataDependency b) {
return (ProgramTensorMetadataDependency&)((int&)a |= (int&)b);
return a = a | b;
}
inline ProgramTensorMetadataDependency& operator&=(ProgramTensorMetadataDependency& a, ProgramTensorMetadataDependency b) {
return (ProgramTensorMetadataDependency&)((int&)a &= (int&)b);
return a = a & b;
}

#if defined(__GNUC__)
#pragma GCC diagnostic pop
#endif

constexpr SafeInt<uint32_t> WORKGROUP_SIZE = 64;

// data type of variable
Expand Down Expand Up @@ -417,133 +409,55 @@ class ProgramWrapper : public ProgramBase {
ProgramWrapper(Args&&... args) : ProgramBase{std::forward<Args>(args)...} {}
};

#if defined(ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK)
#error "macro ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK is already defined"
#endif

#define ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(identifier, element_type) \
private: \
template <typename U> \
static auto test_has_##identifier(int) -> decltype(U::identifier, std::true_type{}); /* checks if member exists */ \
template <typename...> \
static auto test_has_##identifier(...) -> std::false_type; \
\
template <typename U, /* The following type check uses SFINAE */ \
typename = std::enable_if_t< /* to ensure the specific member: */ \
is_const_std_array<decltype(U::identifier)>::value && /* - is a const std::array */ \
std::is_const_v<decltype(U::identifier)> && /* - has "const" modifier */ \
!std::is_member_pointer_v<decltype(&U::identifier)>>> /* - is static */ \
static auto test_has_##identifier##_with_correct_type(int) -> std::true_type; \
template <typename...> \
static auto test_has_##identifier##_with_correct_type(...) -> std::false_type; \
\
public: \
static constexpr bool has_##identifier = decltype(test_has_##identifier<T>(0))::value; \
static constexpr bool has_##identifier##_with_correct_type = decltype(test_has_##identifier##_with_correct_type<T>(0))::value

// the following template class checks whether the type is a const std::array
template <typename T>
struct is_const_std_array : std::false_type {};
template <typename T, size_t N>
struct is_const_std_array<const std::array<T, N>> : std::true_type {};

// the following template class checks whether certain static members exist in the derived class (SFINAE)
template <typename T>
class DerivedProgramClassTypeCheck {
ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(constants, ProgramConstant);
ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(overridable_constants, ProgramOverridableConstantDefinition);
ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(uniform_variables, ProgramUniformVariableDefinition);
};

// compile-time tests for the type check
//
// TODO: move this to test folder
namespace test {
// The following variable templates check whether certain static members exist in the derived class.
// Uses std::void_t with decltype(T::member) for SFINAE-based detection of named static data members.

template <typename T, typename = void>
inline constexpr bool has_member_constants = false;
template <typename T>
class TestTypeCheck {
ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(a, int);
};

struct TestClass_Empty {};
static_assert(!TestTypeCheck<TestClass_Empty>::has_a);
static_assert(!TestTypeCheck<TestClass_Empty>::has_a_with_correct_type);

struct TestClass_NotArray_0 {
int b;
};
static_assert(!TestTypeCheck<TestClass_NotArray_0>::has_a);
static_assert(!TestTypeCheck<TestClass_NotArray_0>::has_a_with_correct_type);
inline constexpr bool has_member_constants<T, std::void_t<decltype(T::constants)>> = true;

struct TestClass_NotArray_1 {
int a;
};
static_assert(TestTypeCheck<TestClass_NotArray_1>::has_a);
static_assert(!TestTypeCheck<TestClass_NotArray_1>::has_a_with_correct_type);

struct TestClass_NotArray_2 {
const int a;
};
static_assert(TestTypeCheck<TestClass_NotArray_2>::has_a);
static_assert(!TestTypeCheck<TestClass_NotArray_2>::has_a_with_correct_type);

struct TestClass_NotStdArray_0 {
const int a[2];
};
static_assert(TestTypeCheck<TestClass_NotStdArray_0>::has_a);
static_assert(!TestTypeCheck<TestClass_NotStdArray_0>::has_a_with_correct_type);

struct TestClass_NotStdArray_1 {
static constexpr int a[] = {0};
};
static_assert(TestTypeCheck<TestClass_NotStdArray_1>::has_a);
static_assert(!TestTypeCheck<TestClass_NotStdArray_1>::has_a_with_correct_type);

struct TestClass_NotStdArray_2 {
static int a[];
};
static_assert(TestTypeCheck<TestClass_NotStdArray_2>::has_a);
static_assert(!TestTypeCheck<TestClass_NotStdArray_2>::has_a_with_correct_type);

struct TestClass_NotStdArray_3 {
static const int a[];
};
static_assert(TestTypeCheck<TestClass_NotStdArray_3>::has_a);
static_assert(!TestTypeCheck<TestClass_NotStdArray_3>::has_a_with_correct_type);
template <typename T, typename = void>
inline constexpr bool has_member_overridable_constants = false;
template <typename T>
inline constexpr bool has_member_overridable_constants<T, std::void_t<decltype(T::overridable_constants)>> = true;

struct TestClass_StdArray_0 {
std::array<int, 1> a = {1};
};
static_assert(TestTypeCheck<TestClass_StdArray_0>::has_a);
static_assert(!TestTypeCheck<TestClass_StdArray_0>::has_a_with_correct_type);
template <typename T, typename = void>
inline constexpr bool has_member_uniform_variables = false;
template <typename T>
inline constexpr bool has_member_uniform_variables<T, std::void_t<decltype(T::uniform_variables)>> = true;

struct TestClass_StdArray_1 {
static constexpr std::array<int, 2> a = {1, 2};
};
static_assert(TestTypeCheck<TestClass_StdArray_1>::has_a);
static_assert(TestTypeCheck<TestClass_StdArray_1>::has_a_with_correct_type);
// C++20 concepts for checking whether the member has the correct type (static const std::array).

struct TestClass_StdArray_2 {
static const std::array<int, 3> a;
template <typename T>
concept has_constants_correct_type = requires {
T::constants;
requires is_const_std_array<decltype(T::constants)>::value;
requires std::is_const_v<decltype(T::constants)>;
requires !std::is_member_pointer_v<decltype(&T::constants)>;
};
Comment thread
fs-eire marked this conversation as resolved.
static_assert(TestTypeCheck<TestClass_StdArray_2>::has_a);
static_assert(TestTypeCheck<TestClass_StdArray_2>::has_a_with_correct_type);

struct TestClass_StdArray_3 {
static constexpr const std::array<int, 4> a = {1, 2, 3, 4};
template <typename T>
concept has_overridable_constants_correct_type = requires {
T::overridable_constants;
requires is_const_std_array<decltype(T::overridable_constants)>::value;
requires std::is_const_v<decltype(T::overridable_constants)>;
requires !std::is_member_pointer_v<decltype(&T::overridable_constants)>;
};
static_assert(TestTypeCheck<TestClass_StdArray_3>::has_a);
static_assert(TestTypeCheck<TestClass_StdArray_3>::has_a_with_correct_type);

struct TestClass_StdArray_4 {
static std::array<int, 5> a;
template <typename T>
concept has_uniform_variables_correct_type = requires {
T::uniform_variables;
requires is_const_std_array<decltype(T::uniform_variables)>::value;
requires std::is_const_v<decltype(T::uniform_variables)>;
requires !std::is_member_pointer_v<decltype(&T::uniform_variables)>;
};
static_assert(TestTypeCheck<TestClass_StdArray_4>::has_a);
static_assert(!TestTypeCheck<TestClass_StdArray_4>::has_a_with_correct_type);

} // namespace test

#undef ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK

} // namespace details

Expand All @@ -555,40 +469,40 @@ class Program : public details::ProgramWrapper {

static ProgramMetadata GetMetadata() {
ProgramMetadata metadata;
if constexpr (details::DerivedProgramClassTypeCheck<T>::has_constants) {
constexpr const ProgramConstant* ptr = T::constants.data();
constexpr size_t len = T::constants.size();

static_assert(details::DerivedProgramClassTypeCheck<T>::has_constants_with_correct_type,
if constexpr (details::has_member_constants<T>) {
static_assert(details::has_constants_correct_type<T>,
"Derived class of \"Program\" has member \"constants\" but its type is incorrect. "
"Please use macro WEBGPU_PROGRAM_DEFINE_CONSTANTS() or WEBGPU_PROGRAM_EXTEND_CONSTANTS() to declare constants.");
Comment thread
fs-eire marked this conversation as resolved.

constexpr const ProgramConstant* ptr = T::constants.data();
constexpr size_t len = T::constants.size();

metadata.constants = {ptr, len};
} else {
metadata.constants = {};
}

if constexpr (details::DerivedProgramClassTypeCheck<T>::has_overridable_constants) {
constexpr const ProgramOverridableConstantDefinition* ptr = T::overridable_constants.data();
constexpr size_t len = T::overridable_constants.size();

static_assert(details::DerivedProgramClassTypeCheck<T>::has_overridable_constants_with_correct_type,
if constexpr (details::has_member_overridable_constants<T>) {
static_assert(details::has_overridable_constants_correct_type<T>,
"Derived class of \"Program\" has member \"overridable_constants\" but its type is incorrect. "
"Please use macro WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS() or WEBGPU_PROGRAM_EXTEND_OVERRIDABLE_CONSTANTS() to declare overridable constants.");

constexpr const ProgramOverridableConstantDefinition* ptr = T::overridable_constants.data();
constexpr size_t len = T::overridable_constants.size();

metadata.overridable_constants = {ptr, len};
} else {
metadata.overridable_constants = {};
}

if constexpr (details::DerivedProgramClassTypeCheck<T>::has_uniform_variables) {
constexpr const ProgramUniformVariableDefinition* ptr = T::uniform_variables.data();
constexpr size_t len = T::uniform_variables.size();

static_assert(details::DerivedProgramClassTypeCheck<T>::has_uniform_variables_with_correct_type,
if constexpr (details::has_member_uniform_variables<T>) {
static_assert(details::has_uniform_variables_correct_type<T>,
"Derived class of \"Program\" has member \"uniform_variables\" but its type is incorrect. "
"Please use macro WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES() or WEBGPU_PROGRAM_EXTEND_UNIFORM_VARIABLES() to declare uniform variables.");

constexpr const ProgramUniformVariableDefinition* ptr = T::uniform_variables.data();
constexpr size_t len = T::uniform_variables.size();

metadata.uniform_variables = {ptr, len};
} else {
metadata.uniform_variables = {};
Expand All @@ -599,20 +513,6 @@ class Program : public details::ProgramWrapper {
};

namespace details {
// helper function to convert a C-style array to std::array
//
// This is basically the same as std::to_array in C++20.
//
template <typename T, size_t N, size_t... Idx>
constexpr auto _to_std_array_impl(T (&arr)[N], std::index_sequence<Idx...>) -> std::array<std::remove_cv_t<T>, N> {
return {{arr[Idx]...}};
}

template <typename T, size_t N>
constexpr auto _to_std_array(T (&arr)[N]) -> std::array<std::remove_cv_t<T>, N> {
return _to_std_array_impl(arr, std::make_index_sequence<N>{});
}

// helper function to concatenate a std::array and a C-style array to a std::array
//
template <typename T, size_t L, size_t... IdxL, size_t R, size_t... IdxR>
Expand All @@ -632,7 +532,7 @@ constexpr std::array<std::remove_cv_t<T>, L + R> _concat2(const std::array<T, L>
#define WEBGPU_PROGRAM_DEFINE_(identifier, T, ...) \
static constexpr const T identifier##_own[] = {__VA_ARGS__}; \
static constexpr const auto identifier = \
onnxruntime::webgpu::details::_to_std_array(identifier##_own)
std::to_array(identifier##_own)

#define WEBGPU_PROGRAM_EXTEND_(identifier, T, BASE, ...) \
static constexpr const T identifier##_own[] = {__VA_ARGS__}; \
Expand Down
Loading
Loading