Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions onnxruntime/core/providers/webgpu/shader_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -491,16 +491,29 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector<int>& sha
ss << ",";
}

auto alignment = (data_type == ProgramUniformVariableDataType::Float16 && length > 4) ? "@align(16) " : "";
ss << "\n " << alignment << name << ": ";
// The actual variable type for the uniform variable depends on the data type (T) and length (N).
//
// For T in [i32, u32, f32]:
// - If N == 1, the type is simply i32, u32, or f32.
// - If 1 < N <= 4, the type is vecN<i32>, vecN<u32>, or vecN<f32> where N is the length.
// - If N > 4, the type is array<vec4<T>, ceil(N / 4)>.
//
// For T = f16:
// - If N == 1 or N == 2, the type is u32.
// - If 2 < N <= 8, the type is vecX<u32> where X is ceil(N / 2).
// - If N > 8, the type is array<vec4<u32>, X> where X is ceil(N / 8).
//
// Note: Using f16 type in uniforms is not generally supported on all devices. We use a u32 variable to represent
// 2 f16 values.

if (data_type == ProgramUniformVariableDataType::Float16) {
data_type = ProgramUniformVariableDataType::Uint32; // f16 is represented as u32
length = (length + 1) / 2; // each u32 can hold 2 f16 values
}
ss << "\n " << name << ": ";
if (length > 4) {
if (data_type == ProgramUniformVariableDataType::Float16) {
size_t array_size = (length + 7) / 8;
ss << "array<mat2x4<" << data_type << ">, " << array_size << ">";
} else {
size_t array_size = (length + 3) / 4;
ss << "array<vec4<" << data_type << ">, " << array_size << ">";
}
size_t array_size = (length + 3) / 4;
ss << "array<vec4<" << data_type << ">, " << array_size << ">";
} else if (length > 1) {
ss << "vec" << length << "<" << data_type << ">";
} else {
Expand Down
34 changes: 25 additions & 9 deletions onnxruntime/core/providers/webgpu/shader_variable.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,34 @@ template <typename TIdx,
typename TRank,
typename = std::enable_if_t<std::is_same_v<TRank, int> || std::is_same_v<TRank, size_t>>>
std::string GetElementAt(std::string_view var, const TIdx& idx, TRank rank, bool is_f16 = false) {
// "std::string::rfind(str, 0) == 0" is equivalent to "std::string::starts_with(str)" before C++20.
if (var.rfind("uniforms.", 0) == 0) {
if (rank > 4) {
if constexpr (std::is_integral_v<TIdx>) {
if (is_f16) {
return MakeStringWithClassicLocale(var, "[", idx / 8, "][", (idx % 8) / 4, "][", (idx % 8) % 4, "]");
if (var.starts_with("uniforms.")) {
if (is_f16) {
if (rank > 8) {
// array<vec4<u32>, N>
if constexpr (std::is_integral_v<TIdx>) {
return MakeStringWithClassicLocale("bitcast<vec2<f16>>(", var, "[", idx / 8, "][", (idx % 8) / 2, "])[", (idx % 8) % 2, "]");
} else {
return MakeStringWithClassicLocale(var, "[", idx / 4, "][", idx % 4, "]");
return MakeStringWithClassicLocale("bitcast<vec2<f16>>(", var, "[(", idx, ") / 8][((", idx, ") % 8) / 2])[((", idx, ") % 8) % 2]");
}
} else if (rank > 2) {
// vecN<u32>
if constexpr (std::is_integral_v<TIdx>) {
return MakeStringWithClassicLocale("bitcast<vec2<f16>>(", var, "[", idx / 2, "])[", idx % 2, "]");
} else {
return MakeStringWithClassicLocale("bitcast<vec2<f16>>(", var, "[(", idx, ") / 2])[(", idx, ") % 2]");
}
} else {
if (is_f16) {
return MakeStringWithClassicLocale(var, "[(", idx, ") / 8][(", idx, ") % 8 / 4][(", idx, ") % 8 % 4]");
// u32
if constexpr (std::is_integral_v<TIdx>) {
return MakeStringWithClassicLocale("bitcast<vec2<f16>>(", var, ")[", idx % 2, "]");
} else {
return MakeStringWithClassicLocale("bitcast<vec2<f16>>(", var, ")[(", idx, ") % 2]");
}
}
} else {
if (rank > 4) {
if constexpr (std::is_integral_v<TIdx>) {
return MakeStringWithClassicLocale(var, "[", idx / 4, "][", idx % 4, "]");
} else {
return MakeStringWithClassicLocale(var, "[(", idx, ") / 4][(", idx, ") % 4]");
}
Expand Down
65 changes: 48 additions & 17 deletions onnxruntime/core/providers/webgpu/webgpu_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -373,26 +373,57 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) {
continue;
}

bool is_f16 = uniform.data_type == ProgramUniformVariableDataType::Float16;

size_t element_size = ProgramUniformVariableDataTypeSize[static_cast<int>(uniform.data_type)];
// Calculate the size and alignment of the uniform variable.
//
// https://www.w3.org/TR/WGSL/#alignof
size_t base_alignment = is_f16
? (length > 4 ? 16 : length > 2 ? 8
: length * element_size)
: (length > 2 ? 16 : length * element_size);
size_t struct_size = is_f16 && length <= 4 ? length * element_size : 16;

current_offset = (current_offset + base_alignment - 1) / base_alignment * base_alignment;
//
// For f16:
// - length > 8 : array<vec4<u32>, N> (align 16) (size 16 * N, N = ceil(length / 8))
// - length == 7 or 8: vec4<u32> (align 16) (size 16)
// - length == 5 or 6: vec3<u32> (align 16) (size 12)
// - length == 3 or 4: vec2<u32> (align 8) (size 8)
// - length == 1 or 2: u32 (align 4) (size 4)
//
// For other types (i32, u32, f32):
// - length > 4 : array<vec4<T>, N> (align 16) (size 16 * N, N = ceil(length / 4))
// - length == 4 : vec4<T> (align 16) (size 16)
// - length == 3 : vec3<T> (align 16) (size 12)
// - length == 2 : vec2<T> (align 8) (size 8)
// - length == 1 : T (align 4) (size 4)
//

const bool is_f16 = uniform.data_type == ProgramUniformVariableDataType::Float16;

size_t variable_alignment = 4; // default alignment for scalar types
size_t variable_size = 4; // default size for scalar types

if (is_f16) {
if (length > 6) {
variable_alignment = 16;
variable_size = 16 * ((length + 7) / 8);
} else if (length > 4) {
variable_alignment = 16;
variable_size = 12;
} else if (length > 2) {
variable_alignment = 8;
variable_size = 8;
}
} else {
if (length > 3) {
variable_alignment = 16;
variable_size = 16 * ((length + 3) / 4);
} else if (length > 2) {
variable_alignment = 16;
variable_size = 12;
} else if (length > 1) {
variable_alignment = 8;
variable_size = 8;
}
}
current_offset = (current_offset + variable_alignment - 1) / variable_alignment * variable_alignment;
uniform_and_offsets.emplace_back(uniform, current_offset);

// For non-float16 type, when length > 4, the uniform variable is of type array<vec4<i32|u32|f32>,N>, where
// N = ceil(data.length / 4) and SizeOf(vec4<i32|u32|f32>) = 16. The total byte length is N * SizeOf(vec4<i32|u32|f32>).
// For float16 type, when length > 4, the uniform variable is of type array<mat2x4<f16>,N>, where
// N = ceil(data.length / 8) and SizeOf(mat2x4<f16>) = 16. The total byte length is N * SizeOf(mat2x4<f16>).
size_t element_per_struct = is_f16 ? 8 : 4;
current_offset +=
length > 4 ? (length + element_per_struct - 1) / element_per_struct * struct_size : length * element_size;
current_offset += variable_size;
}

// Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set
Expand Down
Loading