diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc
new file mode 100644
index 0000000000000..9b447d5fdb59a
--- /dev/null
+++ b/onnxruntime/core/providers/webgpu/math/matmul.cc
@@ -0,0 +1,228 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/webgpu/math/matmul.h"
+#include "core/common/inlined_containers.h"
+#include "core/providers/cpu/tensor/utils.h"
+#include "core/providers/webgpu/shader_helper.h"
+#include "core/providers/webgpu/webgpu_supported_types.h"
+
+#include "core/providers/webgpu/data_transfer.h"
+namespace onnxruntime {
+namespace webgpu {
+
+ONNX_OPERATOR_VERSIONED_KERNEL_EX(
+    MatMul,
+    kOnnxDomain,
+    1, 12,
+    kWebGpuExecutionProvider,
+    (*KernelDefBuilder::Create())
+        .TypeConstraint("T", WebGpuSupportedFloatTypes()),
+    MatMul);
+
+ONNX_OPERATOR_KERNEL_EX(
+    MatMul,
+    kOnnxDomain,
+    13,
+    kWebGpuExecutionProvider,
+    (*KernelDefBuilder::Create())
+        .TypeConstraint("T", WebGpuSupportedFloatTypes()),
+    MatMul);
+
+static std::string CalcResult(int64_t components, int64_t a_components, int64_t output_number) {
+  std::ostringstream oss;
+  oss << "var a_data: a_value_t;\n";
+  for (int i = 0; i < a_components; ++i) {
+    oss << "let b_data" << i << " = b[(b_offset + (k + " << i << ") * uniforms.N + col) / " << components << "];\n";
+  }
+  for (int i = 0; i < output_number; ++i) {
+    oss << "a_data = a[(a_offset + (row + " << i << ") * uniforms.K + k) / " << a_components << "];\n";
+
+    for (int j = 0; j < a_components; j++) {
+      oss << "values[" << i << "] = fma(b_value_t(a_data" << (a_components == 1 ? "" : "[" + std::to_string(j) + "]") << "), b_data" << j << ", values[" << i << "]);\n";
+    }
+  }
+  return oss.str();
+}
+
+Status MatMulNaiveProgram::GenerateShaderCode(ShaderHelper& shader) const {
+  const auto& a = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias |
+                                           ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+  const auto& b = shader.AddInput("b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias |
+                                           ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+
+  std::string process_bias;
+  if (has_bias_) {
+    shader.AddInput("bias", ShaderUsage::UseUniform);
+    process_bias = "value += output_value_t(bias[row + i]);";
+  }
+
+  const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform |
+                                                      ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
+  const auto& batch_dims = shader.AddIndices("batch_dims");
+
+  int a_components = a.NumComponents();
+  int components = b.NumComponents();  // components of N
+
+  shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
+                            << "let col = (global_idx % (uniforms.N / " << components << ")) * " << components << ";\n"
+                            << "var index1 = global_idx / (uniforms.N / " << components << ");\n"
+                            << "let stride1 = uniforms.M / " << output_number_ << ";\n"
+                            << "let row = (index1 % stride1) * " << output_number_ << ";\n"
+                            << "let batch = index1 / stride1;\n";
+  if (output_rank_ != 2) {
+    shader.MainFunctionBody() << "let batch_indices = " << batch_dims.OffsetToIndices("batch") << ";\n";
+  }
+  shader.MainFunctionBody() << "var a_indices: a_indices_t;\n"
+                            << ConvertOutputBatchIndicesToInputBatchIndices("a", a, a.Rank() - 2, batch_dims.Rank(), "batch_indices")
+                            << a.IndicesSet("a_indices", a.Rank() - 2, 0) << "\n"
+                            << a.IndicesSet("a_indices", a.Rank() - 1, 0) << "\n"
a_offset = " << a.IndicesToOffset("a_indices") << "*" << a_components << ";\n" + << "var b_indices: b_indices_t;\n" + << ConvertOutputBatchIndicesToInputBatchIndices("b", b, b.Rank() - 2, batch_dims.Rank(), "batch_indices") + << b.IndicesSet("b_indices", b.Rank() - 2, 0) << "\n" + << b.IndicesSet("b_indices", b.Rank() - 1, 0) << "\n" + << "let b_offset = " << b.IndicesToOffset("b_indices") << " * " << components << ";\n" + << "var values: array;\n" + << "for (var k: u32 = 0u; k < uniforms.K; k = k + " << a_components << ") {\n" + << CalcResult(components, a_components, output_number_) << "\n" + << "}\n" + << "for (var i = 0u; i < " << output_number_ << "u; i++) {\n" + << " var value = values[i];\n" + << process_bias << "\n" + << " let cur_indices = output_indices_t(batch, row + i, col/ " << components << ");\n" + << " let offset = " << output.IndicesToOffset("cur_indices") << ";\n" + << output.SetByOffset("offset", "value") + << "}\n"; + + return Status::OK(); +} + +Status MatMul::ComputeInternal(ComputeContext& context) const { + // calculate output shape + MatMulComputeHelper helper; + const auto* a = context.Input(0); + const auto* b = context.Input(1); + + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b->Shape())); + auto* output_tensor = context.Output(0, helper.OutputShape()); + bool has_bias = context.InputCount() > 2; + + if (helper.N() < 8 && helper.K() < 8) { // call MatMulNaiveProgram + + const uint32_t m = narrow(helper.M()); // left matrix first dimension + const uint32_t n = narrow(helper.N()); // right matrix second dimension + const uint32_t k = narrow(helper.K()); // right matrix first dimension + + const auto components = GetMaxComponents(n); + const auto a_components = GetMaxComponents(k); + + const auto output_number = GetMaxComponents(m); + uint32_t output_size = narrow(helper.OutputShape().Size() / components / output_number); + + const size_t output_rank = helper.OutputShape().NumDimensions(); + TensorShape outer_dims = output_rank > 2 ? helper.OutputShape().Slice(0, output_rank - 2) : TensorShape({}); + const int64_t batch_size = outer_dims.Size(); + + const int64_t a_rows = a->Shape().NumDimensions() > 1 ? 
+    const int64_t a_rows = a->Shape().NumDimensions() > 1 ? a->Shape()[a->Shape().NumDimensions() - 2] : 1;
+    TensorShape output_shape_shader({batch_size, a_rows, helper.N() / components});
+
+    MatMulNaiveProgram program{output_rank, output_number, has_bias};
+
+    program
+        .CacheHint(std::to_string(components), std::to_string(a_components), std::to_string(output_number))
+        .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_components},
+                    {b, ProgramTensorMetadataDependency::TypeAndRank, components}});
+
+    if (has_bias) {
+      const auto* bias = context.Input<Tensor>(2);
+      program.AddInput({bias, ProgramTensorMetadataDependency::Rank, 1});
+    }
+    program
+        .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::None, output_shape_shader, components}})
+        .SetDispatchGroupSize((output_size + 63) / 64)  // Integer ceiling division
+        .AddIndices(outer_dims)
+        .AddUniformVariables({{output_size}, {m}, {n}, {k}});
+
+    return context.RunProgram(program);
+  }
+
+  int64_t batchA = a->Shape().SizeToDimension(a->Shape().NumDimensions() - 2);
+  int64_t batchB = b->Shape().SizeToDimension(b->Shape().NumDimensions() - 2);
+
+  TensorShape a_shape = a->Shape();
+  TensorShape b_shape = b->Shape();
+  TensorShape output_shape = helper.OutputShape();
+
+  const int64_t dim_output_outer = output_shape[output_shape.NumDimensions() - 2];
+  // check if A is a batch of vectors (batch is not 1, M is 1) and B is a matrix (batch is 1)
+  if (batchA != 1 && dim_output_outer == 1 && batchB == 1) {
+    // optimization for batched vector matrix multiplication
+    // dimensions of A: [1,`batchA`,K]
+    TensorShapeVector dims_a = {1, batchA, helper.K()};
+    // dimensions of B: [1,K,N]
+    TensorShapeVector dims_b = {1, helper.K(), helper.N()};
+
+    a_shape = TensorShape(dims_a);
+    b_shape = TensorShape(dims_b);
+    output_shape = {1, batchA, helper.N()};
+  }
+
+  // helpful dimension variables
+  TensorShape outer_dims_a = a_shape.NumDimensions() > 2
+                                 ? a_shape.Slice(0, a_shape.NumDimensions() - 2)
+                                 : TensorShape({});
+
+  TensorShape outer_dims_b = b_shape.NumDimensions() > 2
+                                 ? b_shape.Slice(0, b_shape.NumDimensions() - 2)
+                                 : TensorShape({});
+
+  TensorShape outer_dims = output_shape.NumDimensions() > 2
+                               ? output_shape.Slice(0, output_shape.NumDimensions() - 2)
+                               : TensorShape({});
+
+  const int64_t batch_size = outer_dims.Size();
+
+  // Get dimensions for matrix multiplication from TensorShape
+  const int32_t dim_a_outer = narrow<int32_t>(a_shape[a_shape.NumDimensions() - 2]);  // left matrix rows (M)
+  const int32_t dim_inner = narrow<int32_t>(a_shape[a_shape.NumDimensions() - 1]);    // shared dimension (K)
+  const int32_t dim_b_outer = narrow<int32_t>(b_shape[b_shape.NumDimensions() - 1]);  // right matrix columns (N)
+
+  const bool is_vec4 = dim_inner % 4 == 0 && dim_b_outer % 4 == 0;
+
+  InlinedVector<int64_t> elements_per_thread = dim_a_outer <= 8
+                                                   ? InlinedVector<int64_t>({4, 1, 1})
+                                                   : InlinedVector<int64_t>({4, 4, 1});
+
+  const uint32_t dispatch_x = narrow<uint32_t>((dim_b_outer + MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0] - 1) /
+                                               (MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0]));
+  const uint32_t dispatch_y = narrow<uint32_t>((dim_a_outer + MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1] - 1) /
+                                               (MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1]));
+  const uint32_t dispatch_z = narrow<uint32_t>((static_cast<uint32_t>(batch_size) + MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2] - 1) /
+                                               (MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2]));
+
+  const int components = is_vec4 ? 4 : 1;
+  const TensorShape a_shape_temp = CreateMatMulIntermediateShape(outer_dims_a, dim_a_outer, dim_inner, components);
+  const TensorShape b_shape_temp = CreateMatMulIntermediateShape(outer_dims_b, dim_inner, dim_b_outer, components);
+  const TensorShape output_shape_temp = TensorShape({batch_size, dim_a_outer, dim_b_outer / components});
+
+  MatMulProgram program{has_bias, is_vec4, elements_per_thread};
+  program
+      .CacheHint(absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4))
+      .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_shape_temp, components},
+                  {b, ProgramTensorMetadataDependency::TypeAndRank, b_shape_temp, components}})
+      .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::Rank, output_shape_temp, components}})
+      .AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}})
+      .AddIndices(outer_dims)
+      .SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z)
+      .SetWorkgroupSize(MATMUL_PACKED_WORKGROUP_SIZE_X, MATMUL_PACKED_WORKGROUP_SIZE_Y, MATMUL_PACKED_WORKGROUP_SIZE_Z);
+
+  if (has_bias) {
+    const auto* bias = context.Input<Tensor>(2);
+    program.AddInput({bias, ProgramTensorMetadataDependency::Rank, 1});
+  }
+  return context.RunProgram(program);
+}
+
+}  // namespace webgpu
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webgpu/math/matmul.h b/onnxruntime/core/providers/webgpu/math/matmul.h
new file mode 100644
index 0000000000000..789e824383189
--- /dev/null
+++ b/onnxruntime/core/providers/webgpu/math/matmul.h
@@ -0,0 +1,47 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/providers/webgpu/webgpu_kernel.h"
+#include "core/providers/webgpu/program.h"
+#include "core/providers/cpu/math/matmul_helper.h"
+#include "core/providers/webgpu/math/matmul_utils.h"
+#include "core/providers/webgpu/math/matmul_packed.h"
+#include "core/providers/webgpu/webgpu_utils.h"
+
+namespace onnxruntime {
+namespace webgpu {
+
+class MatMul final : public WebGpuKernel {
+ public:
+  MatMul(const OpKernelInfo& info) : WebGpuKernel{info} {}
+
+  Status ComputeInternal(ComputeContext& context) const override;
+
+  constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_X = 8;
+  constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_Y = 8;
+  constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_Z = 1;
+};
+
+class MatMulNaiveProgram final : public Program<MatMulNaiveProgram> {
+ public:
+  MatMulNaiveProgram(const size_t output_rank, int64_t output_number, bool has_bias)
+      : Program{"MatMulNaive"}, output_rank_(output_rank), output_number_(output_number), has_bias_{has_bias} {
+  }
+
+  Status GenerateShaderCode(ShaderHelper& sh) const override;
+
+  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32},
+                                          {"M", ProgramUniformVariableDataType::Uint32},
+                                          {"N", ProgramUniformVariableDataType::Uint32},
+                                          {"K", ProgramUniformVariableDataType::Uint32});
+
+ private:
+  const size_t output_rank_;
+  const int64_t output_number_;
+  const bool has_bias_;
+};
+
+}  // namespace webgpu
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc
new file mode 100644
index 0000000000000..2e5cff923f442
--- /dev/null
+++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc
@@ -0,0 +1,303 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+ +#include "core/providers/webgpu/math/matmul_packed.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/webgpu_utils.h" + +namespace onnxruntime { +namespace webgpu { + +void MatMulProgram::MatMulReadWriteFnSource(ShaderHelper& shader, + const ShaderVariableHelper& a, + const ShaderVariableHelper& b, + const ShaderVariableHelper& output, + const ShaderIndicesHelper& batch_dims) const { + int components = is_vec4_ ? 4 : 1; + const std::string data_type = "a_element_t"; + const std::string type_string = MakeScalarOrVectorType(components, data_type); + + // Add the mm_readA function + shader.AdditionalImplementation() + << "fn mm_readA(batch: i32, row: i32, colIn: i32, batch_indices: batch_dims_indices_t) -> " << type_string << " {\n" + << " var value = " << type_string << "(0.0);\n" + << " let col = colIn * " << components << ";\n" + << " if(row < uniforms.dim_a_outer && col < uniforms.dim_inner) {\n" + << " var a_indices: a_indices_t;\n" + << ConvertOutputBatchIndicesToInputBatchIndices("a", a, a.Rank() - 2, batch_dims.Rank(), "batch_indices") + << a.IndicesSet("a_indices", a.Rank() - 2, "u32(row)") << "\n" + << a.IndicesSet("a_indices", a.Rank() - 1, "u32(colIn)") << "\n" + << " value = " << a.GetByIndices("a_indices") << ";\n" + << " }\n" + << " return value;\n" + << "}\n\n"; + + // Add the mm_readB function + shader.AdditionalImplementation() + << "fn mm_readB(batch: i32, row: i32, colIn: i32, batch_indices: batch_dims_indices_t) -> " << type_string << " {\n" + << " var value = " << type_string << "(0.0);\n" + << " let col = colIn * " << components << ";\n" + << " if(row < uniforms.dim_inner && col < uniforms.dim_b_outer) {\n" + << " var b_indices: b_indices_t;\n" + << ConvertOutputBatchIndicesToInputBatchIndices("b", b, b.Rank() - 2, batch_dims.Rank(), "batch_indices") + << b.IndicesSet("b_indices", b.Rank() - 2, "u32(row)") << "\n" + << b.IndicesSet("b_indices", b.Rank() - 1, "u32(colIn)") << "\n" + << " value = " << b.GetByIndices("b_indices") << ";\n" + << " }\n" + << " return value;\n" + << "}\n\n"; + + // Add the mm_write function + shader.AdditionalImplementation() + << "fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: " << type_string << ") {\n" + << " let col = colIn * " << components << ";\n" + << " if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer) {\n" + << " var value = valueIn;\n" + << " let coords = vec3(batch, row, colIn);\n"; + + if (has_bias_) { + shader.AdditionalImplementation() << " value = value + " << type_string << "(bias[row]);\n"; + } + + shader.AdditionalImplementation() + << output.SetByIndices("vec3(coords)", "value") << "\n" + << " }\n" + << "}\n\n"; +} + +Status MatMulProgram::MakeMatMulPackedVec4Source(ShaderHelper& shader, + const ShaderIndicesHelper& batch_dims, + const InlinedVector& elements_per_thread, + uint32_t workgroup_size_x, + uint32_t workgroup_size_y) { + // elements per thread + const auto elements_per_thread_x = elements_per_thread[0]; + const auto elements_per_thread_y = elements_per_thread[1]; + const decltype(elements_per_thread_x) tile_inner = 32; + + const auto tile_a_outer = workgroup_size_y * elements_per_thread_y; + const auto tile_b_outer = workgroup_size_x * elements_per_thread_x; + const auto tile_a_width = tile_inner; + + const auto tile_a_height = tile_a_outer; + const auto inner_elements_size = tile_a_width / workgroup_size_x; + const auto row_per_thread_b = tile_inner / workgroup_size_y; + + const std::string 
data_type = "a_element_t"; + + if (!((inner_elements_size == 3 || inner_elements_size == 4) && + tile_a_width % workgroup_size_x == 0 && + tile_inner % workgroup_size_y == 0 && + elements_per_thread_x == 4)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid matrix multiplication configuration inner_elements_size: ", inner_elements_size, + " must be 3 or 4. tile_a_width: ", tile_a_width, " must be divisible by WorkgroupSizeX: ", + workgroup_size_x, ". tile_inner: ", tile_inner, " must be divisible by WorkgroupSizeY: ", + workgroup_size_y, ". elements_per_thread_x: ", elements_per_thread_x, " must be 4."); + } + + shader.AdditionalImplementation() + << "var mm_Asub: array, " << tile_a_width / inner_elements_size << ">, " << tile_a_height << ">;\n" + << "var mm_Bsub: array, " << tile_b_outer / elements_per_thread_x << ">, " << tile_inner << ">;\n" + << "const rowPerThread = " << elements_per_thread_y << ";\n" + << "const colPerThread = " << elements_per_thread_x << ";\n" + << "const innerElementSize = " << inner_elements_size << ";\n" + << "const tileInner = " << tile_inner << ";\n"; + + shader.MainFunctionBody() + << " let localRow = i32(local_id.y);\n" + << " let tileRow = localRow * rowPerThread;\n" + << " let tileCol = i32(local_id.x);\n" + << " let globalRow = i32(global_id.y) * rowPerThread;\n" + << " let globalCol = i32(global_id.x);\n" + << " let batch = i32(global_id.z);\n" + << " let batchIndices = " << batch_dims.OffsetToIndices("u32(batch)") << ";\n" + << " let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n" + << " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n" + << " var kStart = 0;\n" + << " var acc: array, rowPerThread>;\n"; + + // Loop over shared dimension. + shader.MainFunctionBody() + << " let tileRowB = localRow * " << row_per_thread_b << ";\n" + << " for (var t = 0; t < num_tiles; t = t + 1) {\n"; + + // Load one tile of A into local memory. + shader.MainFunctionBody() + << " for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {\n" + << " let inputRow = tileRow + innerRow;\n" + << " let inputCol = tileCol;\n" + << " mm_Asub[inputRow][inputCol] = mm_readA(batch, globalRow + innerRow, kStart / innerElementSize + inputCol, batchIndices);\n" + << " }\n"; + + // Load one tile of B into local memory. + shader.MainFunctionBody() + << " for (var innerRow = 0; innerRow < " << row_per_thread_b << "; innerRow = innerRow + 1) {\n" + << " let inputRow = tileRowB + innerRow;\n" + << " let inputCol = tileCol;\n" + << " mm_Bsub[inputRow][inputCol] = mm_readB(batch, kStart + inputRow, globalCol, batchIndices);\n" + << " }\n" + << " kStart = kStart + tileInner;\n" + << " workgroupBarrier();\n"; + + // Compute acc values for a single thread. + shader.MainFunctionBody() + << " for (var k = 0; k < tileInner / innerElementSize; k = k + 1) {\n" + << " let BCached0 = mm_Bsub[k * innerElementSize][tileCol];\n" + << " let BCached1 = mm_Bsub[k * innerElementSize + 1][tileCol];\n" + << " let BCached2 = mm_Bsub[k * innerElementSize + 2][tileCol];\n"; + + if (inner_elements_size != 3) { + shader.MainFunctionBody() << " let BCached3 = mm_Bsub[k * innerElementSize + 3][tileCol];\n"; + } + + shader.MainFunctionBody() + << " for (var i = 0; i < rowPerThread; i = i + 1) {\n" + << " let ACached = mm_Asub[tileRow + i][k];\n" + << " acc[i] = BCached0 * ACached.x + acc[i];\n" + << " acc[i] = BCached1 * ACached.y + acc[i];\n" + << " acc[i] = BCached2 * ACached.z + acc[i];\n" + << " " << (inner_elements_size == 3 ? 
"" : "acc[i] = BCached3 * ACached.w + acc[i];") << "\n" + << " }\n"; + + shader.MainFunctionBody() << " workgroupBarrier();\n" + << " }\n"; // main for loop + + // Write the results to the output buffer + shader.MainFunctionBody() + << " for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {\n" + << " mm_write(batch, globalRow + innerRow, globalCol, acc[innerRow]);\n" + << " }\n" + << "}\n"; + + return Status::OK(); +} + +Status MatMulProgram::MakeMatMulPackedSource(ShaderHelper& shader, const ShaderIndicesHelper& batch_dims, + const InlinedVector& elements_per_thread, + uint32_t workgroup_size_x, + uint32_t workgroup_size_y) { + const auto elements_per_thread_x = elements_per_thread[0]; + const auto elements_per_thread_y = elements_per_thread[1]; + const decltype(elements_per_thread_x) tile_inner = 32; + + const auto tile_a_outer = workgroup_size_y * elements_per_thread_y; + const auto tile_b_outer = workgroup_size_x * elements_per_thread_x; + const auto tile_a_width = tile_inner; + const auto tile_a_height = tile_a_outer; + + if (!(tile_a_height % workgroup_size_y == 0 && tile_a_width % workgroup_size_x == 0 && tile_inner % workgroup_size_y == 0)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "tile_a_height: ", tile_a_height, " must be divisible by WorkgroupSizeY: ", workgroup_size_y, + ", tile_a_width: ", tile_a_width, " must be divisible by WorkgroupSizeX: ", workgroup_size_x, + ", tile_inner: ", tile_inner, " must be divisible by WorkgroupSizeY: ", workgroup_size_y); + } + + const std::string data_type = "a_element_t"; + + const auto row_per_thread_a = tile_a_height / workgroup_size_y; + const auto col_per_thread_a = tile_a_width / workgroup_size_x; + const auto row_per_thread_b = tile_inner / workgroup_size_y; + + shader.AdditionalImplementation() + << "var mm_Asub: array, " << tile_a_height << ">;\n" + << "var mm_Bsub: array, " << tile_inner << ">;\n" + << "const rowPerThread = " << elements_per_thread_y << ";\n" + << "const colPerThread = " << elements_per_thread_x << ";\n" + << "const tileInner = " << tile_inner << ";\n"; + + shader.MainFunctionBody() << " let batch = i32(global_id.z);\n" + << " let batchIndices = " << batch_dims.OffsetToIndices("u32(batch)") << ";\n" + << " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n" + << " var kStart = 0;\n" + << " var acc: array, rowPerThread>;\n"; + + shader.MainFunctionBody() + << "let tileRow = i32(local_id.y) * rowPerThread;\n" + << "let tileCol = i32(local_id.x) * colPerThread;\n" + << "let globalRow = i32(global_id.y) * rowPerThread;\n" + << "let globalCol = i32(global_id.x) * colPerThread;\n" + << "let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n" + << "let tileRowA = i32(local_id.y) * " << row_per_thread_a << ";\n" + << "let tileColA = i32(local_id.x) * " << col_per_thread_a << ";\n" + << "let tileRowB = i32(local_id.y) * " << row_per_thread_b << ";\n"; + + // Loop over shared dimension. + shader.MainFunctionBody() + << "for (var t = 0; t < num_tiles; t = t + 1) {\n"; + + // Load one tile of A into local memory. 
+  shader.MainFunctionBody()
+      << "    for (var innerRow = 0; innerRow < " << row_per_thread_a << "; innerRow = innerRow + 1) {\n"
+      << "      for (var innerCol = 0; innerCol < " << col_per_thread_a << "; innerCol = innerCol + 1) {\n"
+      << "        let inputRow = tileRowA + innerRow;\n"
+      << "        let inputCol = tileColA + innerCol;\n"
+      << "        mm_Asub[inputRow][inputCol] = mm_readA(batch, globalRowStart + inputRow, kStart + inputCol, batchIndices);\n"
+      << "      }\n"
+      << "    }\n";
+
+  // Load one tile of B into local memory.
+  shader.MainFunctionBody()
+      << "    for (var innerRow = 0; innerRow < " << row_per_thread_b << "; innerRow = innerRow + 1) {\n"
+      << "      for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) {\n"
+      << "        let inputRow = tileRowB + innerRow;\n"
+      << "        let inputCol = tileCol + innerCol;\n"
+      << "        mm_Bsub[inputRow][inputCol] = mm_readB(batch, kStart + inputRow, globalCol + innerCol, batchIndices);\n"
+      << "      }\n"
+      << "    }\n"
+      << "    kStart = kStart + tileInner;\n"
+      << "    workgroupBarrier();\n";
+
+  // Compute acc values for a single thread.
+  shader.MainFunctionBody()
+      << "    var BCached: array<" << data_type << ", colPerThread>;\n"
+      << "    for (var k = 0; k < tileInner; k = k + 1) {\n"
+      << "      for (var inner = 0; inner < colPerThread; inner = inner + 1) {\n"
+      << "        BCached[inner] = mm_Bsub[k][tileCol + inner];\n"
+      << "      }\n"
+      << "      for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {\n"
+      << "        let ACached = mm_Asub[tileRow + innerRow][k];\n"
+      << "        for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) {\n"
+      << "          acc[innerRow][innerCol] = acc[innerRow][innerCol] + ACached * BCached[innerCol];\n"
+      << "        }\n"
+      << "      }\n"
+      << "    }\n"
+      << "    workgroupBarrier();\n"
+      << "  }\n";
+
+  // Write the results to the output buffer
+  shader.MainFunctionBody()
+      << "  for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {\n"
+      << "    for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) {\n"
+      << "      mm_write(batch, globalRow + innerRow, globalCol + innerCol, acc[innerRow][innerCol]);\n"
+      << "    }\n"
+      << "  }\n";
+
+  return Status::OK();
+}
+
+Status MatMulProgram::GenerateShaderCode(ShaderHelper& shader) const {
+  const auto& a = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+  const auto& b = shader.AddInput("b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
+  const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
+  const auto& batch_dims = shader.AddIndices("batch_dims", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
+
+  if (has_bias_) {
+    shader.AddInput("bias", ShaderUsage::UseUniform);
+  }
+
+  // declare the read and write functions
+  MatMulReadWriteFnSource(shader, a, b, output, batch_dims);
+
+  // generate the main function
+  if (is_vec4_) {
+    ORT_RETURN_IF_ERROR(MakeMatMulPackedVec4Source(shader, batch_dims, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY()));
+  } else {
+    ORT_RETURN_IF_ERROR(MakeMatMulPackedSource(shader, batch_dims, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY()));
+  }
+  return Status::OK();
+}
+
+}  // namespace webgpu
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.h b/onnxruntime/core/providers/webgpu/math/matmul_packed.h
new file mode 100644
index 0000000000000..ea76468944066
--- /dev/null
+++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.h
@@ -0,0 +1,45 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/providers/webgpu/webgpu_kernel.h"
+#include "core/providers/webgpu/program.h"
+#include "core/providers/webgpu/shader_helper.h"
+#include "core/providers/webgpu/math/matmul_utils.h"
+
+namespace onnxruntime {
+namespace webgpu {
+class MatMulProgram final : public Program<MatMulProgram> {
+ public:
+  MatMulProgram(bool bias, bool is_vec4, const gsl::span<const int64_t>& elements_per_thread) : Program{"MatMul"},
+                                                                                                has_bias_{bias},
+                                                                                                is_vec4_{is_vec4},
+                                                                                                elements_per_thread_(elements_per_thread.begin(), elements_per_thread.end()) {}
+
+  Status GenerateShaderCode(ShaderHelper& sh) const override;
+  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"dim_a_outer", ProgramUniformVariableDataType::Int32},
+                                          {"dim_b_outer", ProgramUniformVariableDataType::Int32},
+                                          {"dim_inner", ProgramUniformVariableDataType::Int32});
+
+  static Status MakeMatMulPackedVec4Source(ShaderHelper& shader,
+                                           const ShaderIndicesHelper& batch_dims,
+                                           const InlinedVector<int64_t>& elements_per_thread,
+                                           uint32_t workgroup_size_x,
+                                           uint32_t workgroup_size_y);
+  static Status MakeMatMulPackedSource(ShaderHelper& shader,
+                                       const ShaderIndicesHelper& batch_dims,
+                                       const InlinedVector<int64_t>& elements_per_thread,
+                                       uint32_t workgroup_size_x,
+                                       uint32_t workgroup_size_y);
+
+ private:
+  const bool has_bias_;
+  const bool is_vec4_;
+  const InlinedVector<int64_t> elements_per_thread_;
+
+  void MatMulReadWriteFnSource(ShaderHelper& shader, const ShaderVariableHelper& a, const ShaderVariableHelper& b, const ShaderVariableHelper& output, const ShaderIndicesHelper& batch_dims) const;
+};
+
+}  // namespace webgpu
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webgpu/math/matmul_utils.h b/onnxruntime/core/providers/webgpu/math/matmul_utils.h
new file mode 100644
index 0000000000000..bcd9c1b24a9bf
--- /dev/null
+++ b/onnxruntime/core/providers/webgpu/math/matmul_utils.h
@@ -0,0 +1,40 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/common/inlined_containers.h"
+#include "core/providers/cpu/tensor/utils.h"
+#include "core/providers/webgpu/shader_helper.h"
+
+namespace onnxruntime {
+namespace webgpu {
+
+// Helper that creates a new TensorShape for the intermediate result of MatMul.
+// The new shape is created by appending the two dimensions dim1 and dim2 / components to the original shape.
+inline TensorShape CreateMatMulIntermediateShape(const TensorShape& shape, const int64_t dim1, const int64_t dim2, const int components) {
+  TensorShapeVector shape_vec = shape.AsShapeVector();
+  shape_vec.push_back(dim1);
+  shape_vec.push_back(dim2 / components);
+  return TensorShape(shape_vec);
+}
+
+// Helper that converts output batch indices to input batch indices using only the rank and
+// the shape information in uniforms.
+inline std::string ConvertOutputBatchIndicesToInputBatchIndices(const std::string& name, const ShaderVariableHelper& input, int input_batch_rank, int output_batch_rank, const std::string& batch_indices) {
+  std::ostringstream oss;
+ name + "_shape"; + const std::string input_indices = name + "_indices"; + int extending_input_rank = output_batch_rank - input_batch_rank; + for (int i = 0; i < input_batch_rank; ++i) { + oss << "if (" << GetElementAt(input_shape, i, input.Rank()) << " != 1) {\n" + << input.IndicesSet(input_indices, i, GetElementAt(batch_indices, i + extending_input_rank, output_batch_rank)) << "\n" + << "} else {\n" + << input.IndicesSet(input_indices, i, 0) << "\n" + << "}\n"; + } + return oss.str(); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 874607988773b..15166df54e40c 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -621,8 +621,8 @@ std::unique_ptr RegisterKernels() { // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.h b/onnxruntime/core/providers/webgpu/webgpu_utils.h index eb25a9bd5386e..5f6f18f34b7f5 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_utils.h +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.h @@ -29,5 +29,20 @@ inline std::string SumVector(std::string x, int components) { } } +inline std::string MakeScalarOrVectorType(int components, std::string_view data_type) { + switch (components) { + case 1: + return std::string{data_type}; + case 2: + return MakeStringWithClassicLocale("vec2<", data_type, ">"); + case 3: + return MakeStringWithClassicLocale("vec3<", data_type, ">"); + case 4: + return MakeStringWithClassicLocale("vec4<", data_type, ">"); + default: + ORT_THROW("Unsupported number of components: ", components); + } +} + } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/math/matmul_test.cc b/onnxruntime/test/providers/cpu/math/matmul_test.cc index dd8cbed15e5ef..504e645738344 100644 --- a/onnxruntime/test/providers/cpu/math/matmul_test.cc +++ b/onnxruntime/test/providers/cpu/math/matmul_test.cc @@ -158,6 +158,112 @@ std::vector> GenerateTestCases() { // clang-format on })}); +#ifdef USE_WEBGPU + test_cases.push_back( + {"test 3D tensors with M = 1", + {6, 1, 8}, + {1, 8, 3}, + {6, 1, 3}, + real_expected_vals({ + // clang-format off + 420, 448, 476, + 1092, 1184, 1276, + 1764, 1920, 2076, + 2436, 2656, 2876, + 3108, 3392, 3676, + 3780, 4128, 4476, + // clang-format on + })}); + + test_cases.push_back( + {"test 4D tensors with M = 1", + {2, 3, 1, 8}, + {1, 1, 8, 3}, + {2, 3, 1, 3}, + real_expected_vals({420, 448, 476, 1092, 1184, 1276, 1764, 1920, 2076, 2436, 2656, 2876, 3108, 3392, 3676, 3780, 4128, 4476})}); + + test_cases.push_back( + {"test 4D tensors", + {2, 3, 4, 3}, + {2, 3, 3, 5}, + {2, 3, 4, 5}, + real_expected_vals({ + // clang-format off + 25, 28, 31, 34, 37, 70, 82, 94, 106, 118, 115, 136, 157, 178, 199, 160, 190, 220, + 250, 280, 790, 829, 868, 907, 946, 970, 1018, 1066, 1114, 1162, 1150, 1207, 1264, + 1321, 1378, 1330, 1396, 1462, 1528, 1594, 2635, 2710, 2785, 2860, 2935, 2950, 3034, + 3118, 3202, 3286, 3265, 3358, 3451, 3544, 3637, 3580, 3682, 3784, 3886, 3988, 5560, + 5671, 5782, 5893, 6004, 6010, 6130, 6250, 6370, 6490, 6460, 6589, 6718, 6847, 6976, + 6910, 7048, 7186, 7324, 7462, 9565, 9712, 9859, 10006, 10153, 10150, 
+           6910, 7048, 7186, 7324, 7462, 9565, 9712, 9859, 10006, 10153, 10150, 10306, 10462,
+           10618, 10774, 10735, 10900, 11065, 11230, 11395, 11320, 11494, 11668, 11842, 12016,
+           14650, 14833, 15016, 15199, 15382, 15370, 15562, 15754, 15946, 16138, 16090, 16291,
+           16492, 16693, 16894, 16810, 17020, 17230, 17440, 17650
+           // clang-format on
+       })});
+
+  // Test case: multiplies 2D broadcasted to 4D tensors
+  test_cases.push_back(
+      {"test 2D broadcasted to 4D tensors",
+       {2, 4},
+       {4, 3, 2, 4, 2},
+       {4, 3, 2, 2, 2},
+       real_expected_vals({
+           // clang-format off
+           28, 34, 76, 98, 76, 82, 252, 274, 124, 130, 428, 450, 172, 178, 604, 626,
+           220, 226, 780, 802, 268, 274, 956, 978, 316, 322, 1132, 1154, 364, 370,
+           1308, 1330, 412, 418, 1484, 1506, 460, 466, 1660, 1682, 508, 514, 1836,
+           1858, 556, 562, 2012, 2034, 604, 610, 2188, 2210, 652, 658, 2364, 2386,
+           700, 706, 2540, 2562, 748, 754, 2716, 2738, 796, 802, 2892, 2914, 844,
+           850, 3068, 3090, 892, 898, 3244, 3266, 940, 946, 3420, 3442, 988, 994,
+           3596, 3618, 1036, 1042, 3772, 3794, 1084, 1090, 3948, 3970, 1132, 1138,
+           4124, 4146
+           // clang-format on
+       })});
+
+  // Test case: multiplies 4D broadcasted to 5D tensors
+  test_cases.push_back(
+      {"test 4D broadcasted to 5D tensors",
+       {3, 1, 2, 4},
+       {4, 3, 2, 4, 2},
+       {4, 3, 2, 2, 2},
+       real_expected_vals({
+           // clang-format off
+           28, 34, 76, 98, 76, 82, 252, 274, 732, 770, 1036, 1090, 1036, 1074, 1468,
+           1522, 2460, 2530, 3020, 3106, 3020, 3090, 3708, 3794, 316, 322, 1132,
+           1154, 364, 370, 1308, 1330, 2556, 2594, 3628, 3682, 2860, 2898, 4060,
+           4114, 5820, 5890, 7148, 7234, 6380, 6450, 7836, 7922, 604, 610, 2188,
+           2210, 652, 658, 2364, 2386, 4380, 4418, 6220, 6274, 4684, 4722, 6652,
+           6706, 9180, 9250, 11276, 11362, 9740, 9810, 11964, 12050, 892, 898, 3244,
+           3266, 940, 946, 3420, 3442, 6204, 6242, 8812, 8866, 6508, 6546, 9244,
+           9298, 12540, 12610, 15404, 15490, 13100, 13170, 16092, 16178
+
+           // clang-format on
+       })});
+
+  // Test case: same ranks different broadcast small 1
+  test_cases.push_back(
+      {"test same ranks different broadcast small 1",
+       {2, 1, 2, 2},
+       {1, 2, 2, 1},
+       {2, 2, 2, 1},
+       real_expected_vals({1, 3, 3, 13, 5, 7, 23, 33})});
+
+  // Test case: same ranks different broadcast larger 0
+  test_cases.push_back(
+      {"test same ranks different broadcast larger 0",
+       {1, 2, 2, 8},
+       {2, 1, 8, 1},
+       {2, 2, 2, 1},
+       real_expected_vals({140, 364, 588, 812, 364, 1100, 1836, 2572})});
+
+  // Test case: same ranks different broadcast larger 1
+  test_cases.push_back(
+      {"test same ranks different broadcast larger 1",
+       {2, 1, 2, 8},
+       {1, 2, 8, 1},
+       {2, 2, 2, 1},
+       real_expected_vals({140, 364, 364, 1100, 588, 812, 1836, 2572})});
+#endif
   return test_cases;
 }
 
@@ -189,6 +295,17 @@ void RunMatMulTest(int32_t opset_version, bool is_a_constant, bool is_b_constant
     excluded_providers.insert(kNnapiExecutionProvider);
   }
 
+  // TODO: Change the MatMulNaive shader to support these test cases on WebGPU.
+  std::unordered_set<std::string> webgpu_excluded_test_cases{
+      "test left 1D",
+      "test right 1D",
+      "test 2D empty input"};
+
+  // if the test is in webgpu_excluded_test_cases, add WebGPU to excluded_providers
+  if (webgpu_excluded_test_cases.find(t.name) != webgpu_excluded_test_cases.end()) {
+    excluded_providers.insert(kWebGpuExecutionProvider);
+  }
+
   test.ConfigExcludeEps(excluded_providers)
       .Config(run_with_tunable_op)
       .RunWithConfig();
@@ -234,10 +351,18 @@ TEST(MathOpTest, MatMulDoubleType) {
 }
 
 TEST(MathOpTest, MatMulInt32Type) {
+  // WebGPU does not support int32 MatMul
+  if (DefaultWebGpuExecutionProvider().get() != nullptr) {
+    GTEST_SKIP() << "Skipping because of the following error: Webgpu does not support int32 matmul";
+  }
   RunMatMulTest<int32_t>(9);
 }
 
 TEST(MathOpTest, MatMulUint32Type) {
+  // WebGPU does not support uint32 MatMul
+  if (DefaultWebGpuExecutionProvider().get() != nullptr) {
+    GTEST_SKIP() << "Skipping because of the following error: Webgpu does not support uint32 matmul";
+  }
   RunMatMulTest<uint32_t>(9);
 }
 
@@ -263,16 +388,22 @@ void RunMatMulZeroKTest() {
   // No special case is implemented.
   test.ConfigExcludeEps({kCoreMLExecutionProvider, kNnapiExecutionProvider,
                          kDmlExecutionProvider, kDnnlExecutionProvider, kQnnExecutionProvider,
-                         kOpenVINOExecutionProvider})
+                         kOpenVINOExecutionProvider, kWebGpuExecutionProvider})
       .Config(run_with_tunable_op)
      .RunWithConfig();
 }
 
 TEST(MathOpTest, MatMulZeroKFloatType) {
+  if (DefaultWebGpuExecutionProvider().get() != nullptr) {
+    GTEST_SKIP() << "Skipping because of the following error: Webgpu does not support zero-sized tensor";
+  }
   RunMatMulZeroKTest<float>();
 }
 
 TEST(MathOpTest, MatMulZeroKInt32Type) {
+  if (DefaultWebGpuExecutionProvider().get() != nullptr) {
+    GTEST_SKIP() << "Skipping because of the following error: Webgpu does not support zero-sized tensor";
+  }
   RunMatMulZeroKTest<int32_t>();
 }