diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index a22d21d8d798b..bdeea726a2cf5 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -491,16 +491,29 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha ss << ","; } - auto alignment = (data_type == ProgramUniformVariableDataType::Float16 && length > 4) ? "@align(16) " : ""; - ss << "\n " << alignment << name << ": "; + // The actual variable type for the uniform variable depends on the data type (T) and length (N). + // + // For T in [i32, u32, f32]: + // - If N == 1, the type is simply i32, u32, or f32. + // - If 2 < N <= 4, the type is vecN, vecN, or vecN where N is the length. + // - If N > 4, the type is array, ceil(N / 4)>. + // + // For T is f16: + // - If N == 1 or N == 2, the type is u32. + // - If 2 < N <= 8, the type is vecX where X is ceil(N / 2). + // - If N > 8, the type is array, X> where X is ceil(N / 8). + // + // Note: Using f16 type in uniforms is not generally supported on all devices. We use a u32 variable to represent + // 2 f16 values. + + if (data_type == ProgramUniformVariableDataType::Float16) { + data_type = ProgramUniformVariableDataType::Uint32; // f16 is represented as u32 + length = (length + 1) / 2; // each u32 can hold 2 f16 values + } + ss << "\n " << name << ": "; if (length > 4) { - if (data_type == ProgramUniformVariableDataType::Float16) { - size_t array_size = (length + 7) / 8; - ss << "array, " << array_size << ">"; - } else { - size_t array_size = (length + 3) / 4; - ss << "array, " << array_size << ">"; - } + size_t array_size = (length + 3) / 4; + ss << "array, " << array_size << ">"; } else if (length > 1) { ss << "vec" << length << "<" << data_type << ">"; } else { diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index 2aba2a59d157f..78c98ab26f5b8 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -17,18 +17,34 @@ template || std::is_same_v>> std::string GetElementAt(std::string_view var, const TIdx& idx, TRank rank, bool is_f16 = false) { - // "std::string::rfind(str, 0) == 0" is equivalent to "std::string::starts_with(str)" before C++20. - if (var.rfind("uniforms.", 0) == 0) { - if (rank > 4) { - if constexpr (std::is_integral_v) { - if (is_f16) { - return MakeStringWithClassicLocale(var, "[", idx / 8, "][", (idx % 8) / 4, "][", (idx % 8) % 4, "]"); + if (var.starts_with("uniforms.")) { + if (is_f16) { + if (rank > 8) { + // array, N> + if constexpr (std::is_integral_v) { + return MakeStringWithClassicLocale("bitcast>(", var, "[", idx / 8, "][", (idx % 8) / 2, "])[", (idx % 8) % 2, "]"); } else { - return MakeStringWithClassicLocale(var, "[", idx / 4, "][", idx % 4, "]"); + return MakeStringWithClassicLocale("bitcast>(", var, "[(", idx, ") / 8][((", idx, ") % 8) / 2])[((", idx, ") % 8) % 2]"); + } + } else if (rank > 2) { + // vecN + if constexpr (std::is_integral_v) { + return MakeStringWithClassicLocale("bitcast>(", var, "[", idx / 2, "])[", idx % 2, "]"); + } else { + return MakeStringWithClassicLocale("bitcast>(", var, "[(", idx, ") / 2])[(", idx, ") % 2]"); } } else { - if (is_f16) { - return MakeStringWithClassicLocale(var, "[(", idx, ") / 8][(", idx, ") % 8 / 4][(", idx, ") % 8 % 4]"); + // u32 + if constexpr (std::is_integral_v) { + return MakeStringWithClassicLocale("bitcast>(", var, ")[", idx % 2, "]"); + } else { + return MakeStringWithClassicLocale("bitcast>(", var, ")[(", idx, ") % 2]"); + } + } + } else { + if (rank > 4) { + if constexpr (std::is_integral_v) { + return MakeStringWithClassicLocale(var, "[", idx / 4, "][", idx % 4, "]"); } else { return MakeStringWithClassicLocale(var, "[(", idx, ") / 4][(", idx, ") % 4]"); } diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 4bd79a627df22..a9557f7b9aa87 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -373,26 +373,57 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) { continue; } - bool is_f16 = uniform.data_type == ProgramUniformVariableDataType::Float16; - - size_t element_size = ProgramUniformVariableDataTypeSize[static_cast(uniform.data_type)]; + // Calculate the size and alignment of the uniform variable. + // // https://www.w3.org/TR/WGSL/#alignof - size_t base_alignment = is_f16 - ? (length > 4 ? 16 : length > 2 ? 8 - : length * element_size) - : (length > 2 ? 16 : length * element_size); - size_t struct_size = is_f16 && length <= 4 ? length * element_size : 16; - - current_offset = (current_offset + base_alignment - 1) / base_alignment * base_alignment; + // + // For f16: + // - length > 8 : array, N> (align 16) (size 16 * N, N = ceil(length / 8)) + // - length == 7 or 8: vec4 (align 16) (size 16) + // - length == 5 or 6: vec3 (align 16) (size 12) + // - length == 3 or 4: vec2 (align 8) (size 8) + // - length == 1 or 2: u32 (align 4) (size 4) + // + // For other types (i32, u32, f32): + // - length > 4 : array, N> (align 16) (size 16 * N, N = ceil(length / 4)) + // - length == 4 : vec4 (align 16) (size 16) + // - length == 3 : vec3 (align 16) (size 12) + // - length == 2 : vec2 (align 8) (size 8) + // - length == 1 : T (align 4) (size 4) + // + + const bool is_f16 = uniform.data_type == ProgramUniformVariableDataType::Float16; + + size_t variable_alignment = 4; // default alignment for scalar types + size_t variable_size = 4; // default size for scalar types + + if (is_f16) { + if (length > 6) { + variable_alignment = 16; + variable_size = 16 * ((length + 7) / 8); + } else if (length > 4) { + variable_alignment = 16; + variable_size = 12; + } else if (length > 2) { + variable_alignment = 8; + variable_size = 8; + } + } else { + if (length > 3) { + variable_alignment = 16; + variable_size = 16 * ((length + 3) / 4); + } else if (length > 2) { + variable_alignment = 16; + variable_size = 12; + } else if (length > 1) { + variable_alignment = 8; + variable_size = 8; + } + } + current_offset = (current_offset + variable_alignment - 1) / variable_alignment * variable_alignment; uniform_and_offsets.emplace_back(uniform, current_offset); - // For non-float16 type, when length > 4, the uniform variable is of type array,N>, where - // N = ceil(data.length / 4) and SizeOf(vec4) = 16. The total byte length is N * SizeOf(vec4). - // For float16 type, when length > 4, the uniform variable is of type array,N>, where - // N = ceil(data.length / 8) and SizeOf(mat2x4) = 16. The total byte length is N * SizeOf(mat2x4). - size_t element_per_struct = is_f16 ? 8 : 4; - current_offset += - length > 4 ? (length + element_per_struct - 1) / element_per_struct * struct_size : length * element_size; + current_offset += variable_size; } // Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set