Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion cmake/onnxruntime_providers_webgpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,6 @@
endif()
endif()

target_compile_features(onnxruntime_providers_webgpu PRIVATE cxx_std_20)
add_dependencies(onnxruntime_providers_webgpu onnx ${onnxruntime_EXTERNAL_DEPENDENCIES})

if (onnxruntime_WGSL_TEMPLATE)
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/webgpu/bert/bias_add.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Status BiasAddProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< " let value = " << input.GetByOffset("global_idx")
<< " + " << bias.GetByOffset("global_idx % uniforms.channels")
<< " + " << residual.GetByOffset("global_idx") << ";\n"
<< " " + output.SetByOffset("global_idx", "value");
<< " " << output.SetByOffset("global_idx", "value");

return Status::OK();
}
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ Status RotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< " let j = i + select(half_rotary_emb_dim, 1, " << interleaved_str << ");\n"
<< " let re = " << input.GetByOffset("i") << " * " << cos_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << " - " << input.GetByOffset("j") << " * " << sin_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << ";\n"
<< " " << output.SetByOffset("i", "re") << "\n"
<< " let im = " << input.GetByOffset("i") << " * " << sin_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << " + " << input.GetByOffset("j") + " * " << cos_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << ";\n"
<< " let im = " << input.GetByOffset("i") << " * " << sin_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << " + " << input.GetByOffset("j") << " * " << cos_cache.GetByIndices("vec2<u32>(position_id, bsnh[3])") << ";\n"
<< " " << output.SetByOffset("j", "im") << "\n"
<< " } else { \n"
" let k = dot(bsnh, uniforms.input_output_stride) + half_rotary_emb_dim;\n"
Expand Down
5 changes: 3 additions & 2 deletions onnxruntime/core/providers/webgpu/math/einsum.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "core/providers/webgpu/math/einsum.h"

#include <algorithm>
#include <cctype>
#include <regex>
#include <set>
#include <vector>
Expand All @@ -24,7 +25,7 @@ static const std::regex lhs_pattern("(([a-zA-Z]|\\.\\.\\.)*,)*([a-zA-Z]|\\.\\.\\
// Returns a copy of `str` with every whitespace character removed.
//
// Characters are classified with std::isspace. Each character is converted to
// unsigned char before classification, because passing a negative char value
// to std::isspace is undefined behavior.
std::string RemoveAllWhitespace(const std::string& str) {
  std::string result = str;
  const auto is_whitespace = [](unsigned char c) { return std::isspace(c) != 0; };
  // Single erase-remove pass; the input string itself is left untouched.
  result.erase(std::remove_if(result.begin(), result.end(), is_whitespace), result.end());
  return result;
}

Expand Down Expand Up @@ -318,7 +319,7 @@ Status EinsumProgram::GenerateShaderCode(ShaderHelper& shader) const {
symbol));

// Check if we've already processed this symbol to avoid duplicate loop generation
if (uniform_symbol_set.find(symbol) == uniform_symbol_set.end()) {
if (!uniform_symbol_set.contains(symbol)) {
// Add symbol to tracked set to prevent duplicate processing
uniform_symbol_set.insert(symbol);

Expand Down
12 changes: 4 additions & 8 deletions onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

#include <utility>
#include <cstring>
#include <limits>

#include "core/providers/webgpu/math/unary_elementwise_ops.h"
Expand Down Expand Up @@ -194,10 +195,6 @@ class Clip final : public UnaryElementwise {
"Clip",
std::is_same_v<T, MLFloat16> ? ClipF16Impl : ClipImpl,
"", ShaderUsage::UseElementTypeAlias} {}
#if defined(__GNUC__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif

Status ConfigureProgram(const ComputeContext& context, UnaryElementwiseProgram& program) const override {
const auto* clip_min_tensor = context.Input<Tensor>(1);
Expand All @@ -209,7 +206,9 @@ class Clip final : public UnaryElementwise {
: std::numeric_limits<T>::max()};
if constexpr (std::is_same_v<T, MLFloat16>) {
// F16: stores span<f16, 2> as a single float
float encoded_value = *reinterpret_cast<const float*>(attr);
float encoded_value;
static_assert(sizeof(encoded_value) == 2 * sizeof(MLFloat16));
std::memcpy(&encoded_value, attr, sizeof(encoded_value));
program.AddUniformVariable({encoded_value});
Comment thread
fs-eire marked this conversation as resolved.
} else {
static_assert(sizeof(T) == sizeof(float), "T must be f32, i32 or u32");
Expand All @@ -218,9 +217,6 @@ class Clip final : public UnaryElementwise {
}
return Status::OK();
}
#if defined(__GNUC__)
#pragma GCC diagnostic pop
#endif

// uniforms.attr is a f32 value. It is encoded as a float for 2 f16 values.
// bitcast<vec2<f16>>(uniforms.attr)[0] is clip_min, bitcast<vec2<f16>>(uniforms.attr)[1] is clip_max
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/webgpu/nn/pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ Status Pool<PoolType, is_nhwc>::ComputeInternal(ComputeContext& context) const {
Tensor* Y = context.Output(0, output_shape);

std::vector<uint32_t> kernel_strides(kernel_shape.size());
ORT_ENFORCE(kernel_shape.size() > 0, "kernel_shape must have at least one element.");
ORT_ENFORCE(!kernel_shape.empty(), "kernel_shape must have at least one element.");
// Calculate the kernel element strides for each dimension in reverse order. For example:
// kernel_shape = [3, 2], kernel_strides = [2, 1]
// kernel_shape = [2, 3, 2], kernel_strides = [6, 2, 1]
Expand Down
206 changes: 53 additions & 153 deletions onnxruntime/core/providers/webgpu/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#pragma once

#include <array>
#include <string>
#include <vector>
#include <iosfwd>
Expand Down Expand Up @@ -164,28 +165,19 @@ enum class ProgramTensorMetadataDependency : int {
};
OStringStream& operator<<(OStringStream& os, ProgramTensorMetadataDependency);

#if defined(__GNUC__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif

// Bitwise OR of two ProgramTensorMetadataDependency flag values.
// Each operand is converted to int by value with static_cast (no reference-cast
// type punning), the ints are OR-ed, and the result is converted back to the enum.
inline ProgramTensorMetadataDependency operator|(ProgramTensorMetadataDependency a, ProgramTensorMetadataDependency b) {
  return static_cast<ProgramTensorMetadataDependency>(static_cast<int>(a) | static_cast<int>(b));
}
// Bitwise AND of two ProgramTensorMetadataDependency flag values.
// Each operand is converted to int by value with static_cast (no reference-cast
// type punning), the ints are AND-ed, and the result is converted back to the enum.
inline ProgramTensorMetadataDependency operator&(ProgramTensorMetadataDependency a, ProgramTensorMetadataDependency b) {
  return static_cast<ProgramTensorMetadataDependency>(static_cast<int>(a) & static_cast<int>(b));
}
// Compound OR-assignment: stores `a | b` into `a` and returns a reference to `a`.
// Implemented through the by-value operator| instead of casting `a` to `int&`.
inline ProgramTensorMetadataDependency& operator|=(ProgramTensorMetadataDependency& a, ProgramTensorMetadataDependency b) {
  a = a | b;
  return a;
}
// Compound AND-assignment: stores `a & b` into `a` and returns a reference to `a`.
// Implemented through the by-value operator& instead of casting `a` to `int&`.
inline ProgramTensorMetadataDependency& operator&=(ProgramTensorMetadataDependency& a, ProgramTensorMetadataDependency b) {
  a = a & b;
  return a;
}

#if defined(__GNUC__)
#pragma GCC diagnostic pop
#endif

constexpr SafeInt<uint32_t> WORKGROUP_SIZE = 64;

// data type of variable
Expand Down Expand Up @@ -417,133 +409,55 @@ class ProgramWrapper : public ProgramBase {
ProgramWrapper(Args&&... args) : ProgramBase{std::forward<Args>(args)...} {}
};

#if defined(ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK)
#error "macro ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK is already defined"
#endif

// Generates member-detection helpers. Must be expanded inside a class template whose
// type parameter is named `T`, because the expansion refers to `T` directly.
//
// For a member name `identifier`, the expansion declares:
//   - private overload sets test_has_<identifier> and test_has_<identifier>_with_correct_type.
//     In each set the (int) overload takes part only when its SFINAE condition holds;
//     the (...) overload is the fallback and yields std::false_type.
//   - public constant has_<identifier>: true when `T::identifier` is a valid expression.
//   - public constant has_<identifier>_with_correct_type: true when the member's declared
//     type is a const std::array and `&T::identifier` is not a pointer-to-member
//     (i.e. the member is static).
//
// Note: the `element_type` parameter is not referenced anywhere in the expansion.
// The expansion has no trailing semicolon; the use site supplies it.
#define ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(identifier, element_type) \
private: \
template <typename U> \
static auto test_has_##identifier(int) -> decltype(U::identifier, std::true_type{}); /* checks if member exists */ \
template <typename...> \
static auto test_has_##identifier(...) -> std::false_type; \
\
template <typename U, /* The following type check uses SFINAE */ \
typename = std::enable_if_t< /* to ensure the specific member: */ \
is_const_std_array<decltype(U::identifier)>::value && /* - is a const std::array */ \
std::is_const_v<decltype(U::identifier)> && /* - has "const" modifier */ \
!std::is_member_pointer_v<decltype(&U::identifier)>>> /* - is static */ \
static auto test_has_##identifier##_with_correct_type(int) -> std::true_type; \
template <typename...> \
static auto test_has_##identifier##_with_correct_type(...) -> std::false_type; \
\
public: \
static constexpr bool has_##identifier = decltype(test_has_##identifier<T>(0))::value; \
static constexpr bool has_##identifier##_with_correct_type = decltype(test_has_##identifier##_with_correct_type<T>(0))::value

// Type trait: `value` is true only when the queried type is exactly
// `const std::array<Element, Count>`. Every other type, including a
// non-const std::array, falls back to the primary template and yields false.
template <typename Candidate>
struct is_const_std_array : std::bool_constant<false> {};

template <typename Element, std::size_t Count>
struct is_const_std_array<const std::array<Element, Count>> : std::bool_constant<true> {};

// Compile-time member checks for a program class T (SFINAE based).
//
// Each macro line below expands to a pair of public constants for one member name:
//   has_<name>                   - T has a member called <name>
//   has_<name>_with_correct_type - that member is a static const std::array
// The names checked are `constants`, `overridable_constants` and `uniform_variables`.
// The second macro argument names the intended element type; the macro expansion
// does not use it.
template <typename T>
class DerivedProgramClassTypeCheck {
ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(constants, ProgramConstant);
ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(overridable_constants, ProgramOverridableConstantDefinition);
ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(uniform_variables, ProgramUniformVariableDefinition);
};

// compile-time tests for the type check
//
// TODO: move this to test folder
namespace test {
// The following variable templates check whether certain static members exist in the derived class.
// Uses std::void_t with decltype(T::member) for SFINAE-based detection of named static data members.

template <typename T, typename = void>
inline constexpr bool has_member_constants = false;
template <typename T>
class TestTypeCheck {
ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(a, int);
};

struct TestClass_Empty {};
static_assert(!TestTypeCheck<TestClass_Empty>::has_a);
static_assert(!TestTypeCheck<TestClass_Empty>::has_a_with_correct_type);

struct TestClass_NotArray_0 {
int b;
};
static_assert(!TestTypeCheck<TestClass_NotArray_0>::has_a);
static_assert(!TestTypeCheck<TestClass_NotArray_0>::has_a_with_correct_type);
inline constexpr bool has_member_constants<T, std::void_t<decltype(T::constants)>> = true;

struct TestClass_NotArray_1 {
int a;
};
static_assert(TestTypeCheck<TestClass_NotArray_1>::has_a);
static_assert(!TestTypeCheck<TestClass_NotArray_1>::has_a_with_correct_type);

struct TestClass_NotArray_2 {
const int a;
};
static_assert(TestTypeCheck<TestClass_NotArray_2>::has_a);
static_assert(!TestTypeCheck<TestClass_NotArray_2>::has_a_with_correct_type);

struct TestClass_NotStdArray_0 {
const int a[2];
};
static_assert(TestTypeCheck<TestClass_NotStdArray_0>::has_a);
static_assert(!TestTypeCheck<TestClass_NotStdArray_0>::has_a_with_correct_type);

struct TestClass_NotStdArray_1 {
static constexpr int a[] = {0};
};
static_assert(TestTypeCheck<TestClass_NotStdArray_1>::has_a);
static_assert(!TestTypeCheck<TestClass_NotStdArray_1>::has_a_with_correct_type);

struct TestClass_NotStdArray_2 {
static int a[];
};
static_assert(TestTypeCheck<TestClass_NotStdArray_2>::has_a);
static_assert(!TestTypeCheck<TestClass_NotStdArray_2>::has_a_with_correct_type);

struct TestClass_NotStdArray_3 {
static const int a[];
};
static_assert(TestTypeCheck<TestClass_NotStdArray_3>::has_a);
static_assert(!TestTypeCheck<TestClass_NotStdArray_3>::has_a_with_correct_type);
// Detection idiom: true when `Program::overridable_constants` is a valid
// expression inside decltype, false otherwise. The defaulted second parameter
// must stay `void` so that the std::void_t specialization below is selected.
template <typename Program, typename Enable = void>
inline constexpr bool has_member_overridable_constants{false};

template <typename Program>
inline constexpr bool has_member_overridable_constants<
    Program, std::void_t<decltype(Program::overridable_constants)>>{true};

struct TestClass_StdArray_0 {
std::array<int, 1> a = {1};
};
static_assert(TestTypeCheck<TestClass_StdArray_0>::has_a);
static_assert(!TestTypeCheck<TestClass_StdArray_0>::has_a_with_correct_type);
// Detection idiom: true when `Program::uniform_variables` is a valid
// expression inside decltype, false otherwise. The defaulted second parameter
// must stay `void` so that the std::void_t specialization below is selected.
template <typename Program, typename Enable = void>
inline constexpr bool has_member_uniform_variables{false};

template <typename Program>
inline constexpr bool has_member_uniform_variables<
    Program, std::void_t<decltype(Program::uniform_variables)>>{true};

struct TestClass_StdArray_1 {
static constexpr std::array<int, 2> a = {1, 2};
};
static_assert(TestTypeCheck<TestClass_StdArray_1>::has_a);
static_assert(TestTypeCheck<TestClass_StdArray_1>::has_a_with_correct_type);
// C++20 concepts for checking whether the member has the correct type (static const std::array).

struct TestClass_StdArray_2 {
static const std::array<int, 3> a;
// Satisfied when `T::constants` is a valid expression and every nested
// requirement below holds. If any of them cannot be formed for T, the concept
// is simply not satisfied (no hard error).
template <typename T>
concept has_constants_correct_type = requires {
T::constants;
// the member's declared type is a specialization of `const std::array`
requires is_const_std_array<decltype(T::constants)>::value;
// the member's declared type is const-qualified
requires std::is_const_v<decltype(T::constants)>;
// `&T::constants` is not a pointer-to-member, i.e. the member is static
requires !std::is_member_pointer_v<decltype(&T::constants)>;
};
Comment thread
fs-eire marked this conversation as resolved.
static_assert(TestTypeCheck<TestClass_StdArray_2>::has_a);
static_assert(TestTypeCheck<TestClass_StdArray_2>::has_a_with_correct_type);

struct TestClass_StdArray_3 {
static constexpr const std::array<int, 4> a = {1, 2, 3, 4};
// Satisfied when `T::overridable_constants` is a valid expression and every
// nested requirement below holds. If any of them cannot be formed for T, the
// concept is simply not satisfied (no hard error).
template <typename T>
concept has_overridable_constants_correct_type = requires {
T::overridable_constants;
// the member's declared type is a specialization of `const std::array`
requires is_const_std_array<decltype(T::overridable_constants)>::value;
// the member's declared type is const-qualified
requires std::is_const_v<decltype(T::overridable_constants)>;
// `&T::overridable_constants` is not a pointer-to-member, i.e. the member is static
requires !std::is_member_pointer_v<decltype(&T::overridable_constants)>;
};
static_assert(TestTypeCheck<TestClass_StdArray_3>::has_a);
static_assert(TestTypeCheck<TestClass_StdArray_3>::has_a_with_correct_type);

struct TestClass_StdArray_4 {
static std::array<int, 5> a;
// Satisfied when `T::uniform_variables` is a valid expression and every nested
// requirement below holds. If any of them cannot be formed for T, the concept
// is simply not satisfied (no hard error).
template <typename T>
concept has_uniform_variables_correct_type = requires {
T::uniform_variables;
// the member's declared type is a specialization of `const std::array`
requires is_const_std_array<decltype(T::uniform_variables)>::value;
// the member's declared type is const-qualified
requires std::is_const_v<decltype(T::uniform_variables)>;
// `&T::uniform_variables` is not a pointer-to-member, i.e. the member is static
requires !std::is_member_pointer_v<decltype(&T::uniform_variables)>;
};
static_assert(TestTypeCheck<TestClass_StdArray_4>::has_a);
static_assert(!TestTypeCheck<TestClass_StdArray_4>::has_a_with_correct_type);

} // namespace test

#undef ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK

} // namespace details

Expand All @@ -555,40 +469,40 @@ class Program : public details::ProgramWrapper {

static ProgramMetadata GetMetadata() {
ProgramMetadata metadata;
if constexpr (details::DerivedProgramClassTypeCheck<T>::has_constants) {
constexpr const ProgramConstant* ptr = T::constants.data();
constexpr size_t len = T::constants.size();

static_assert(details::DerivedProgramClassTypeCheck<T>::has_constants_with_correct_type,
if constexpr (details::has_member_constants<T>) {
static_assert(details::has_constants_correct_type<T>,
"Derived class of \"Program\" has member \"constants\" but its type is incorrect. "
"Please use macro WEBGPU_PROGRAM_DEFINE_CONSTANTS() or WEBGPU_PROGRAM_EXTEND_CONSTANTS() to declare constants.");
Comment thread
fs-eire marked this conversation as resolved.

constexpr const ProgramConstant* ptr = T::constants.data();
constexpr size_t len = T::constants.size();

metadata.constants = {ptr, len};
} else {
metadata.constants = {};
}

if constexpr (details::DerivedProgramClassTypeCheck<T>::has_overridable_constants) {
constexpr const ProgramOverridableConstantDefinition* ptr = T::overridable_constants.data();
constexpr size_t len = T::overridable_constants.size();

static_assert(details::DerivedProgramClassTypeCheck<T>::has_overridable_constants_with_correct_type,
if constexpr (details::has_member_overridable_constants<T>) {
static_assert(details::has_overridable_constants_correct_type<T>,
"Derived class of \"Program\" has member \"overridable_constants\" but its type is incorrect. "
"Please use macro WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS() or WEBGPU_PROGRAM_EXTEND_OVERRIDABLE_CONSTANTS() to declare overridable constants.");

constexpr const ProgramOverridableConstantDefinition* ptr = T::overridable_constants.data();
constexpr size_t len = T::overridable_constants.size();

metadata.overridable_constants = {ptr, len};
} else {
metadata.overridable_constants = {};
}

if constexpr (details::DerivedProgramClassTypeCheck<T>::has_uniform_variables) {
constexpr const ProgramUniformVariableDefinition* ptr = T::uniform_variables.data();
constexpr size_t len = T::uniform_variables.size();

static_assert(details::DerivedProgramClassTypeCheck<T>::has_uniform_variables_with_correct_type,
if constexpr (details::has_member_uniform_variables<T>) {
static_assert(details::has_uniform_variables_correct_type<T>,
"Derived class of \"Program\" has member \"uniform_variables\" but its type is incorrect. "
"Please use macro WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES() or WEBGPU_PROGRAM_EXTEND_UNIFORM_VARIABLES() to declare uniform variables.");

constexpr const ProgramUniformVariableDefinition* ptr = T::uniform_variables.data();
constexpr size_t len = T::uniform_variables.size();

metadata.uniform_variables = {ptr, len};
} else {
metadata.uniform_variables = {};
Expand All @@ -599,20 +513,6 @@ class Program : public details::ProgramWrapper {
};

namespace details {
// Copies a built-in array into a std::array of the same length.
//
// The element type of the result has its cv-qualifiers removed. Functionally
// this mirrors what std::to_array provides from C++20 onwards.
//
// Implementation detail: expands one element per index in `Positions`.
template <typename Elem, size_t Count, size_t... Positions>
constexpr std::array<std::remove_cv_t<Elem>, Count> _to_std_array_impl(Elem (&source)[Count],
                                                                       std::index_sequence<Positions...>) {
  return {{source[Positions]...}};
}

// Entry point: builds the index pack 0..Count-1 and forwards to the impl helper.
template <typename Elem, size_t Count>
constexpr std::array<std::remove_cv_t<Elem>, Count> _to_std_array(Elem (&source)[Count]) {
  using AllPositions = std::make_index_sequence<Count>;
  return _to_std_array_impl(source, AllPositions{});
}

// helper function to concatenate a std::array and a C-style array to a std::array
//
template <typename T, size_t L, size_t... IdxL, size_t R, size_t... IdxR>
Expand All @@ -632,7 +532,7 @@ constexpr std::array<std::remove_cv_t<T>, L + R> _concat2(const std::array<T, L>
// Declares two static constexpr members in the enclosing class:
//   <identifier>_own - a C-style array of T initialised from the variadic arguments
//   <identifier>     - the same elements as a const std::array, built with
//                      std::to_array (C++20, <array>)
// The expansion has no trailing semicolon; the use site supplies it.
#define WEBGPU_PROGRAM_DEFINE_(identifier, T, ...)             \
  static constexpr const T identifier##_own[] = {__VA_ARGS__}; \
  static constexpr const auto identifier =                     \
      std::to_array(identifier##_own)

#define WEBGPU_PROGRAM_EXTEND_(identifier, T, BASE, ...) \
static constexpr const T identifier##_own[] = {__VA_ARGS__}; \
Expand Down
Loading
Loading