Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions onnxruntime/core/providers/webgpu/nn/pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,8 @@ Status PoolProgram::GenerateShaderCode(ShaderHelper& shader) const {

constexpr const size_t kStringInitialSize = 128;
if (is_max_pool_) {
std::string f16_min = "f16(-65504)";

SS(f32_min_ss, kStringInitialSize);
f32_min_ss << "f32(" << std::numeric_limits<float>::lowest() << ")";
std::string f32_min = SS_GET(f32_min_ss);

SS(var_decl_ss, kStringInitialSize);
var_decl_ss << " var value = " << (is_float16_ ? f16_min : f32_min) << ";\n";
var_decl_ss << " var value = " << (is_float16_ ? "-65504.0h" : "-3.4028234663852886e+38f") << ";\n";
var_decl_code = SS_GET(var_decl_ss);

sampling_code = " value = max(value, x_val);\n";
Expand Down
24 changes: 11 additions & 13 deletions onnxruntime/core/providers/webgpu/program.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,18 @@ ProgramUniformVariableValue::ProgramUniformVariableValue(ProgramUniformVariableD
memcpy(data.data(), ptr, length * element_byte_size);
}

std::ostream& operator<<(std::ostream& os, ProgramUniformVariableDataType type) {
os << ProgramUniformVariableDataTypeName[std::underlying_type<decltype(type)>::type(type)];
return os;
}
// Defines `StreamType& operator<<(StreamType& os, EnumType type)`.
//
// The generated operator converts `type` to its underlying integer type, uses that value as an
// index into `EnumNameArray`, writes the selected element to `os`, and returns `os`.
// The index is not range-checked, so `EnumNameArray` needs an entry for every enumerator value
// that is streamed.
#define DEFINE_ENUM_STREAM_OP(StreamType, EnumType, EnumNameArray) \
StreamType& operator<<(StreamType& os, EnumType type) { \
os << EnumNameArray[std::underlying_type<decltype(type)>::type(type)]; \
return os; \
}

std::ostream& operator<<(std::ostream& os, ProgramConstantDataType type) {
os << ProgramConstantDataTypeName[std::underlying_type<decltype(type)>::type(type)];
return os;
}
DEFINE_ENUM_STREAM_OP(std::ostream, ProgramUniformVariableDataType, ProgramUniformVariableDataTypeName)
DEFINE_ENUM_STREAM_OP(OStringStream, ProgramUniformVariableDataType, ProgramUniformVariableDataTypeName)
DEFINE_ENUM_STREAM_OP(std::ostream, ProgramConstantDataType, ProgramConstantDataTypeName)
DEFINE_ENUM_STREAM_OP(OStringStream, ProgramConstantDataType, ProgramConstantDataTypeName)

std::ostream& operator<<(std::ostream& os, ProgramTensorMetadataDependency dep) {
OStringStream& operator<<(OStringStream& os, ProgramTensorMetadataDependency dep) {
bool first = true;
if ((dep & ProgramTensorMetadataDependency::Type) == ProgramTensorMetadataDependency::Type) {
os << "Type";
Expand Down Expand Up @@ -109,10 +110,7 @@ constexpr std::string_view ProgramVariableDataTypeName[] = {
"i4x8", // Int4x8
};

std::ostream& operator<<(std::ostream& os, ProgramVariableDataType type) {
os << ProgramVariableDataTypeName[std::underlying_type<decltype(type)>::type(type)];
return os;
}
DEFINE_ENUM_STREAM_OP(OStringStream, ProgramVariableDataType, ProgramVariableDataTypeName)
#endif

int NumberOfComponents(ProgramVariableDataType type) {
Expand Down
8 changes: 6 additions & 2 deletions onnxruntime/core/providers/webgpu/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
#include "core/common/safeint.h"
#include "core/framework/tensor.h"

#include "core/providers/webgpu/string_utils.h"

namespace onnxruntime {
namespace webgpu {
class ShaderHelper;
Expand All @@ -37,6 +39,7 @@ enum class ProgramUniformVariableDataType {
Int32,
};
std::ostream& operator<<(std::ostream& os, ProgramUniformVariableDataType);
OStringStream& operator<<(OStringStream& os, ProgramUniformVariableDataType);

constexpr size_t ProgramUniformVariableDataTypeSize[] = {sizeof(float), sizeof(uint16_t), sizeof(uint32_t), sizeof(int32_t)};

Expand Down Expand Up @@ -80,6 +83,7 @@ enum class ProgramConstantDataType {
Bool
};
std::ostream& operator<<(std::ostream& os, ProgramConstantDataType);
OStringStream& operator<<(OStringStream& os, ProgramConstantDataType);

constexpr std::string_view ProgramConstantDataTypeName[] = {"f32", "f16", "u32", "i32", "bool"};

Expand Down Expand Up @@ -158,7 +162,7 @@ enum class ProgramTensorMetadataDependency : int {
TypeAndRank = Type | Rank,
TypeAndShape = Type | Shape,
};
std::ostream& operator<<(std::ostream& os, ProgramTensorMetadataDependency);
OStringStream& operator<<(OStringStream& os, ProgramTensorMetadataDependency);

#if defined(__GNUC__)
#pragma GCC diagnostic push
Expand Down Expand Up @@ -216,7 +220,7 @@ enum class ProgramVariableDataType {
// if you add a new type here, you also need to update ProgramVariableDataTypeName
};
#ifndef NDEBUG
std::ostream& operator<<(std::ostream& os, ProgramVariableDataType);
OStringStream& operator<<(OStringStream& os, ProgramVariableDataType);
#endif

int NumberOfComponents(ProgramVariableDataType type);
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/webgpu/program_cache_key.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace webgpu {

namespace {
// append the info of an input or output to the cachekey
void AppendTensorInfo(std::ostream& ss,
void AppendTensorInfo(OStringStream& ss,
const TensorShape& tensor_shape,
ProgramVariableDataType var_type,
ProgramTensorMetadataDependency dependency,
Expand Down
12 changes: 5 additions & 7 deletions onnxruntime/core/providers/webgpu/shader_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
dispatch_group_size_z_{dispatch_group_size_z},
program_{program},
program_metadata_{program_metadata},
additional_implementation_ss_{&additional_implementation_},
body_ss_{&body_} {}
additional_implementation_ss_{kStringInitialSizeShaderSourceCodeAdditionalImplementation},
body_ss_{kStringInitialSizeShaderSourceCodeMain} {}

Status ShaderHelper::Init() {
// dispatch group size is normalized so no need to validate it here
Expand All @@ -59,8 +59,6 @@
// init body string stream
bool is_1d_dispatch = dispatch_group_size_y_ == 1 && dispatch_group_size_z_ == 1;
bool use_indirect_dispatch = program_.IndirectDispatchTensor() != nullptr;
body_.reserve(4096);
additional_implementation_.reserve(1024);

// append header for main function so it is ready for user to append main function body
body_ss_ << "@compute @workgroup_size(workgroup_size_x, workgroup_size_y, workgroup_size_z)\n"
Expand Down Expand Up @@ -384,7 +382,7 @@
return Status::OK();
}

Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector<int>& shape_uniform_ranks) const {
Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector<int>& shape_uniform_ranks) {

Check warning on line 385 in onnxruntime/core/providers/webgpu/shader_helper.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <vector> for vector<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webgpu/shader_helper.cc:385: Add #include <vector> for vector<> [build/include_what_you_use] [4]
SS(ss, kStringInitialSizeShaderSourceCode);

//
Expand Down Expand Up @@ -633,12 +631,12 @@
//
// Additional Implementation
//
ss << additional_implementation_;
ss << SS_GET(additional_implementation_ss_);

//
// Main Function Body
//
ss << body_;
ss << SS_GET(body_ss_);
ss << "\n"
"}\n";

Expand Down
6 changes: 2 additions & 4 deletions onnxruntime/core/providers/webgpu/shader_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@

private:
template <typename ConstantType> // ConstantType is one of {ProgramConstant, ProgramOverridableConstantValue, ProgramOverridableConstantDefinition}
void WriteConstantValue(std::ostream& ss, const ConstantType& constant) const {
void WriteConstantValue(OStringStream& ss, const ConstantType& constant) const {
switch (constant.type) {
case ProgramConstantDataType::Float16:
ss << constant.f16.ToFloat();
Expand Down Expand Up @@ -156,7 +156,7 @@
// \param code The generated full WGSL source code.
// \param shape_uniform_ranks The ranks for variables that need a uniform for the shape.
//
Status GenerateSourceCode(std::string& code, std::vector<int>& shape_uniform_ranks) const;
Status GenerateSourceCode(std::string& code, std::vector<int>& shape_uniform_ranks);

Check warning on line 159 in onnxruntime/core/providers/webgpu/shader_helper.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webgpu/shader_helper.h:159: Add #include <string> for string [build/include_what_you_use] [4]
friend class ProgramManager;

const WebGpuContext& webgpu_context_;
Expand All @@ -175,9 +175,7 @@
std::vector<std::unique_ptr<ShaderVariableHelper>> input_vars_;
std::vector<std::unique_ptr<ShaderVariableHelper>> output_vars_;
std::vector<std::unique_ptr<ShaderIndicesHelper>> indices_vars_;
std::string additional_implementation_;
OStringStream additional_implementation_ss_;
std::string body_;
OStringStream body_ss_;
};

Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/webgpu/shader_variable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ ShaderVariableHelper::ShaderVariableHelper(std::string_view name, ProgramVariabl
ORT_ENFORCE(num_components_ > 0, "Invalid number of components for variable ", name_);
}

void ShaderIndicesHelper::Impl(std::ostream& ss) const {
void ShaderIndicesHelper::Impl(OStringStream& ss) const {
// Start generating code

const std::string shape = (usage_ & ShaderUsage::UseUniform) ? "uniforms." + name_ + "_shape" : name_ + "_shape";
Expand Down Expand Up @@ -249,7 +249,7 @@ void ShaderIndicesHelper::Impl(std::ostream& ss) const {
}
}

void ShaderVariableHelper::Impl(std::ostream& ss) const {
void ShaderVariableHelper::Impl(OStringStream& ss) const {
ShaderIndicesHelper::Impl(ss);

// Implementation of "fn set_{name}"
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/webgpu/shader_variable.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class ShaderIndicesHelper {
protected:
ORT_DISALLOW_COPY_AND_ASSIGNMENT(ShaderIndicesHelper);

void Impl(std::ostream& ss) const;
void Impl(OStringStream& ss) const;

std::string_view IndicesType() const;

Expand Down Expand Up @@ -197,7 +197,7 @@ class ShaderVariableHelper : public ShaderIndicesHelper {
private:
ORT_DISALLOW_COPY_AND_ASSIGNMENT(ShaderVariableHelper);

void Impl(std::ostream& ss) const;
void Impl(OStringStream& ss) const;

std::string GetByOffsetImpl(std::string_view offset) const;
std::string SetByOffsetImpl(std::string_view offset, std::string_view value) const;
Expand Down
7 changes: 2 additions & 5 deletions onnxruntime/core/providers/webgpu/string_macros.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,10 @@
#include "core/providers/webgpu/string_utils.h"

// macro "SS" - declare an ostream variable and its string buffer
#define SS(ss, reserve_size) \
std::string ss##_str; \
ss##_str.reserve(reserve_size); \
::onnxruntime::webgpu::OStringStream ss(&ss##_str)
// Declares a variable named `ss` of type ::onnxruntime::webgpu::OStringStream, passing
// `reserve_size` to its constructor.
#define SS(ss, reserve_size) ::onnxruntime::webgpu::OStringStream ss(reserve_size)

// macro "SS_GET" - get the string from the ostream
#define SS_GET(ss) ss##_str
// Expands to `std::move(ss).str()`, i.e. the rvalue-qualified `str()` of the stream.
// NOTE: the string is taken from an rvalue stream, so treat `ss` as consumed after SS_GET;
// do not expect a second SS_GET on the same stream to return the same contents.
#define SS_GET(ss) (std::move(ss).str())

// macro "SS_APPEND" - use function call style to append to the ostream
#define SS_APPEND(ss, ...) ::onnxruntime::webgpu::detail::OStringStreamAppend(ss, __VA_ARGS__)
95 changes: 83 additions & 12 deletions onnxruntime/core/providers/webgpu/string_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@

#include "core/common/make_string.h"

#include <array>

Check warning on line 8 in onnxruntime/core/providers/webgpu/string_utils.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Found C++ system header after other header. Should be: string_utils.h, c system, c++ system, other. [build/include_order] [4] Raw Output: onnxruntime/core/providers/webgpu/string_utils.h:8: Found C++ system header after other header. Should be: string_utils.h, c system, c++ system, other. [build/include_order] [4]
#include <charconv>
#include <cstddef>
#include <string>
#include <string_view>
#include <system_error>
#include <type_traits>
#include <utility>

Check warning on line 9 in onnxruntime/core/providers/webgpu/string_utils.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Found C++ system header after other header. Should be: string_utils.h, c system, c++ system, other. [build/include_order] [4] Raw Output: onnxruntime/core/providers/webgpu/string_utils.h:9: Found C++ system header after other header. Should be: string_utils.h, c system, c++ system, other. [build/include_order] [4]

#ifdef _MSC_VER
#pragma warning(push)
// C4702: unreachable code
#pragma warning(disable : 4702)
#endif // _MSC_VER

#include <absl/strings/internal/ostringstream.h>

#ifdef _MSC_VER
#pragma warning(pop)
#endif // _MSC_VER
Expand All @@ -22,32 +23,102 @@

constexpr const size_t kStringInitialSizeSetByOffsetImpl = 128;
constexpr const size_t kStringInitialSizeGetByOffsetImpl = 128;
constexpr const size_t kStringInitialSizeShaderSourceCode = 2048;
#ifndef NDEBUG
constexpr const size_t kStringInitialSizeShaderSourceCode = 4096;
constexpr const size_t kStringInitialSizeShaderSourceCodeAdditionalImplementation = 1024;
constexpr const size_t kStringInitialSizeShaderSourceCodeMain = 3068;
constexpr const size_t kStringInitialSizeCacheKey = 512;
#else
constexpr const size_t kStringInitialSizeCacheKey = 256;
#endif

using OStringStream = absl::strings_internal::OStringStream;
namespace detail {

// A simpler and faster ostringstream implementation than absl::strings_internal::OStringStream
//
// This FastOStringStream class is intended to be used in very performance critical paths. It does
// not inherit from std::ostream so that it can avoid the following overheads:
// - locale handling and formatting
// - state management (e.g. error handling, badbit, EOF, I/O sync)
// - unnecessary heap allocations
// - virtual function calls
//
// This class is mainly used for generating shader source code and program cache keys.
//
// Append-only string builder that offers a small `operator<<` subset without deriving from
// std::ostream.
//
// - The constructor reserves `reserve_size` bytes for the internal buffer.
// - `str()` is rvalue-qualified: call it as `std::move(stream).str()`. It moves the accumulated
//   text out of the stream, so the stream should not be read again afterwards.
// - Numbers are formatted with std::to_chars, which is locale-independent. Floating point values
//   use the plain `std::to_chars(first, last, value)` overload (shortest representation that
//   round-trips), which is NOT the same text a default-configured std::ostream would produce.
class FastOStringStream {
 public:
  // Creates an empty stream whose buffer has capacity for at least `reserve_size` characters.
  explicit FastOStringStream(size_t reserve_size) {
    str_.reserve(reserve_size);
  }

  // Moves the accumulated string out of the stream. Only callable on an rvalue stream.
  std::string str() && {
    return std::move(str_);
  }

  // String types

  // Appends a NUL-terminated C string. `s` must not be null.
  FastOStringStream& operator<<(const char* s) {
    str_.append(s);
    return *this;
  }

  FastOStringStream& operator<<(const std::string& s) {
    str_.append(s);
    return *this;
  }

  FastOStringStream& operator<<(std::string_view s) {
    str_.append(s);
    return *this;
  }

  // Character

  // Appends a single character as-is.
  FastOStringStream& operator<<(char c) {
    str_.push_back(c);
    return *this;
  }

  // Integer types

  // Appends the base-10 text of any integral type other than plain `char`.
  // Note that `signed char` / `unsigned char` (and therefore int8_t / uint8_t) are written as
  // numbers here, not as characters.
  template <typename T>
  std::enable_if_t<std::is_integral_v<T> && !std::is_same_v<T, char>, FastOStringStream&>
  operator<<(T value) {
    // std::to_chars(char*, char*, bool) is deleted; fail with a readable message instead.
    static_assert(!std::is_same_v<T, bool>,
                  "FastOStringStream cannot format bool; stream an explicit string or integer instead.");
    return AppendNumber<32>(value);
  }

  // Floating point types

  // Appends the std::to_chars text of a floating point value.
  template <typename T>
  std::enable_if_t<std::is_floating_point_v<T>, FastOStringStream&>
  operator<<(T value) {
    return AppendNumber<64>(value);
  }

 private:
  // Formats `value` into a stack buffer of `BufferSize` chars and appends the produced range.
  // If std::to_chars reports an error the buffer contents are unspecified, so nothing is
  // appended in that case rather than copying unspecified bytes into the output.
  template <size_t BufferSize, typename T>
  FastOStringStream& AppendNumber(T value) {
    std::array<char, BufferSize> buffer;
    const auto [end, ec] = std::to_chars(buffer.data(), buffer.data() + buffer.size(), value);
    if (ec == std::errc{}) {
      str_.append(buffer.data(), static_cast<size_t>(end - buffer.data()));
    }
    return *this;
  }

  std::string str_;
};

} // namespace detail

using OStringStream = detail::FastOStringStream;

namespace detail {
inline void OStringStreamAppendImpl(std::ostream& /*ss*/) noexcept {

inline void OStringStreamAppendImpl(OStringStream& /*ss*/) noexcept {
}

template <typename T>
inline void OStringStreamAppendImpl(std::ostream& ss, const T& t) noexcept {
inline void OStringStreamAppendImpl(OStringStream& ss, const T& t) noexcept {
ss << t;
}

template <typename T, typename... Args>
inline void OStringStreamAppendImpl(std::ostream& ss, const T& t, const Args&... args) noexcept {
inline void OStringStreamAppendImpl(OStringStream& ss, const T& t, const Args&... args) noexcept {
OStringStreamAppendImpl(ss, t);
OStringStreamAppendImpl(ss, args...);
}

template <typename... Args>
inline void OStringStreamAppend(std::ostream& ss, const Args&... args) {
inline void OStringStreamAppend(OStringStream& ss, const Args&... args) {
return OStringStreamAppendImpl(ss, ::onnxruntime::detail::if_char_array_make_ptr_t<Args const&>(args)...);
}

Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/webgpu/tensor/concat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ WEBGPU_CONCAT_VERSIONED_KERNEL(4, 10)
WEBGPU_CONCAT_VERSIONED_KERNEL(11, 12)
WEBGPU_CONCAT_KERNEL(13)

void AppendCalculateInputIndexFunction(std::ostream& os, size_t input_count) {
void AppendCalculateInputIndexFunction(OStringStream& os, size_t input_count) {
os << "fn calculate_input_index(global_idx: u32) -> u32 {\n"
<< " for (var i = 1u; i < " << input_count << "; i = i + 1u) {\n"
<< " if (global_idx < " << GetElementAt("uniforms.offsets", "i", input_count) << ") {\n"
Expand All @@ -49,7 +49,7 @@ void AppendCalculateInputIndexFunction(std::ostream& os, size_t input_count) {
<< "}\n";
}

void AppendAssignOutputDataFunction(std::ostream& os, gsl::span<const ShaderVariableHelper*> inputs, const ShaderVariableHelper& output, size_t axis, size_t input_count) {
void AppendAssignOutputDataFunction(OStringStream& os, gsl::span<const ShaderVariableHelper*> inputs, const ShaderVariableHelper& output, size_t axis, size_t input_count) {
os << "fn assign_output_data(global_idx: u32, input_index: u32) {\n";
for (size_t i = 0; i < inputs.size(); ++i) {
if (i == 0) {
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/webgpu/tensor/depth_to_space.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ WEBGPU_DEPTH_TO_SPACE_KERNEL(13, kOnnxDomain, false)
WEBGPU_DEPTH_TO_SPACE_VERSIONED_KERNEL(11, 12, kMSInternalNHWCDomain, true)
WEBGPU_DEPTH_TO_SPACE_KERNEL(13, kMSInternalNHWCDomain, true)

void AppendPermFunction(std::ostream& os, const ShaderVariableHelper& input, const int64_t* perm) {
void AppendPermFunction(OStringStream& os, const ShaderVariableHelper& input, const int64_t* perm) {
os << "fn perm(i: input_indices_t) -> input_indices_t {\n"
<< " var a: input_indices_t;\n";
for (int idx = 0; idx < input.Rank(); ++idx) {
Expand Down
Loading
Loading