Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 228 additions & 0 deletions onnxruntime/core/providers/webgpu/math/matmul.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/webgpu/math/matmul.h"
#include "core/common/inlined_containers.h"
#include "core/providers/cpu/tensor/utils.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"

#include "core/providers/webgpu/data_transfer.h"
namespace onnxruntime {
namespace webgpu {

// Register MatMul on the WebGPU execution provider:
//  - opset [1, 12] via the versioned kernel, and
//  - opset 13+ via the current kernel.
// "T" is constrained to the floating-point types WebGPU supports.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    MatMul,
    kOnnxDomain,
    1, 12,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", WebGpuSupportedFloatTypes()),
    MatMul);

ONNX_OPERATOR_KERNEL_EX(
    MatMul,
    kOnnxDomain,
    13,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", WebGpuSupportedFloatTypes()),
    MatMul);

// Builds the WGSL statements for one step of the naive MatMul inner loop:
// loads `a_components` packed elements of B, then, for each of the
// `output_number` output rows, loads one packed element of A and accumulates
// the partial products into `values[i]` via fma().
static std::string CalcResult(int64_t components, int64_t a_components, int64_t output_number) {
  std::string code = "var a_data: a_value_t;\n";
  for (int64_t i = 0; i < a_components; ++i) {
    const std::string idx = std::to_string(i);
    code += "let b_data" + idx + " = b[(b_offset + (k + " + idx +
            ") * uniforms.N + col) / " + std::to_string(components) + "];\n";
  }
  for (int64_t row = 0; row < output_number; ++row) {
    const std::string row_str = std::to_string(row);
    code += "a_data = a[(a_offset + (row + " + row_str + ") * uniforms.K + k) / " +
            std::to_string(a_components) + "];\n";
    for (int64_t j = 0; j < a_components; ++j) {
      // Scalar A needs no component subscript; packed A indexes component j.
      const std::string element = a_components == 1 ? "" : "[" + std::to_string(j) + "]";
      code += "values[" + row_str + "] = fma(b_value_t(a_data" + element +
              "), b_data" + std::to_string(j) + ", values[" + row_str + "]);\n";
    }
  }
  return code;
}

// Emits the WGSL for the naive MatMul shader. Each invocation computes
// `output_number_` consecutive output rows and one packed (vectorized) output
// column of a single batch element.
// NOTE: the source as pasted contained non-C++ CI-annotation text inside this
// function body (a reviewdog/cpplint message); it has been removed here.
Status MatMulNaiveProgram::GenerateShaderCode(ShaderHelper& shader) const {
  // Declare inputs; the type aliases (a_value_t, b_value_t, ...) are used by
  // the generated WGSL below.
  const auto& a = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias |
                                           ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
  const auto& b = shader.AddInput("b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias |
                                           ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);

  // WGSL snippet that adds the bias to each output value; empty when the
  // optional third input is absent.
  std::string process_bias;
  if (has_bias_) {
    shader.AddInput("bias", ShaderUsage::UseUniform);
    process_bias = "value += output_value_t(bias[row + i]);";
  }

  const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform |
                                                      ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
  const auto& batch_dims = shader.AddIndices("batch_dims");

  int a_components = a.NumComponents();
  int components = b.NumComponents();  // components of N

  // Map global_idx to (batch, row, col): columns are vectorized by
  // `components`, rows are grouped by `output_number_`.
  shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
                            << "let col = (global_idx % (uniforms.N / " << components << ")) * " << components << ";\n"
                            << "var index1 = global_idx / (uniforms.N / " << components << ");\n"
                            << "let stride1 = uniforms.M / " << output_number_ << ";\n"
                            << "let row = (index1 % stride1) * " << output_number_ << ";\n"
                            << "let batch = index1 / stride1;\n";
  if (output_rank_ != 2) {
    // Only needed when there are batch dimensions to broadcast over.
    shader.MainFunctionBody() << "let batch_indices = " << batch_dims.OffsetToIndices("batch") << ";\n";
  }
  // Resolve the (possibly broadcast) batch offsets into A and B, accumulate
  // partial products over K, then write `output_number_` results.
  shader.MainFunctionBody() << "var a_indices: a_indices_t;\n"
                            << ConvertOutputBatchIndicesToInputBatchIndices("a", a, a.Rank() - 2, batch_dims.Rank(), "batch_indices")
                            << a.IndicesSet("a_indices", a.Rank() - 2, 0) << "\n"
                            << a.IndicesSet("a_indices", a.Rank() - 1, 0) << "\n"
                            << "let a_offset = " << a.IndicesToOffset("a_indices") << "*" << a_components << ";\n"
                            << "var b_indices: b_indices_t;\n"
                            << ConvertOutputBatchIndicesToInputBatchIndices("b", b, b.Rank() - 2, batch_dims.Rank(), "batch_indices")
                            << b.IndicesSet("b_indices", b.Rank() - 2, 0) << "\n"
                            << b.IndicesSet("b_indices", b.Rank() - 1, 0) << "\n"
                            << "let b_offset = " << b.IndicesToOffset("b_indices") << " * " << components << ";\n"
                            << "var values: array<output_value_t, " << output_number_ << ">;\n"
                            << "for (var k: u32 = 0u; k < uniforms.K; k = k + " << a_components << ") {\n"
                            << CalcResult(components, a_components, output_number_) << "\n"
                            << "}\n"
                            << "for (var i = 0u; i < " << output_number_ << "u; i++) {\n"
                            << " var value = values[i];\n"
                            << process_bias << "\n"
                            << " let cur_indices = output_indices_t(batch, row + i, col/ " << components << ");\n"
                            << " let offset = " << output.IndicesToOffset("cur_indices") << ";\n"
                            << output.SetByOffset("offset", "value")
                            << "}\n";

  return Status::OK();
}

// Runs MatMul on WebGPU. Dispatches one of two shader programs:
//  - MatMulNaiveProgram when both N and K are small (< 8), or
//  - the tiled/packed MatMulProgram otherwise.
Status MatMul::ComputeInternal(ComputeContext& context) const {
  // calculate output shape
  MatMulComputeHelper helper;
  const auto* a = context.Input(0);
  const auto* b = context.Input(1);

  ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b->Shape()));
  auto* output_tensor = context.Output(0, helper.OutputShape());
  bool has_bias = context.InputCount() > 2;

  if (helper.N() < 8 && helper.K() < 8) {  // call MatMulNaiveProgram

    const uint32_t m = narrow<uint32_t>(helper.M());  // rows of A (and of the output)
    const uint32_t n = narrow<uint32_t>(helper.N());  // columns of B (and of the output)
    const uint32_t k = narrow<uint32_t>(helper.K());  // shared inner dimension (columns of A == rows of B)

    // Vectorization widths along N and K, and how many output rows each
    // shader invocation produces.
    const auto components = GetMaxComponents(n);
    const auto a_components = GetMaxComponents(k);

    const auto output_number = GetMaxComponents(m);
    uint32_t output_size = narrow<uint32_t>(helper.OutputShape().Size() / components / output_number);

    const size_t output_rank = helper.OutputShape().NumDimensions();
    TensorShape outer_dims = output_rank > 2 ? helper.OutputShape().Slice(0, output_rank - 2) : TensorShape({});
    const int64_t batch_size = outer_dims.Size();

    // Output shape as seen by the shader: N is divided by the vectorization
    // width because the output is stored in packed components.
    const int64_t a_rows = a->Shape().NumDimensions() > 1 ? a->Shape()[a->Shape().NumDimensions() - 2] : 1;
    TensorShape output_shape_shader({batch_size, a_rows, helper.N() / components});

    MatMulNaiveProgram program{output_rank, output_number, has_bias};

    program
        .CacheHint(std::to_string(components), std::to_string(a_components), std::to_string(output_number))
        .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_components},
                    {b, ProgramTensorMetadataDependency::TypeAndRank, components}});

    if (has_bias) {
      const auto* bias = context.Input(2);
      program.AddInput({bias, ProgramTensorMetadataDependency::Rank, 1});
    }
    program
        .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::None, output_shape_shader, components}})
        .SetDispatchGroupSize((output_size + 63) / 64)  // Integer ceiling division
        .AddIndices(outer_dims)
        .AddUniformVariables({{output_size}, {m}, {n}, {k}});

    return context.RunProgram(program);
  }

  // NOTE(review): SizeToDimension(NumDimensions() - 2) assumes both inputs
  // are at least 2-D on this path — confirm 1-D inputs cannot reach here
  // (size_t underflow otherwise).
  int64_t batchA = a->Shape().SizeToDimension(a->Shape().NumDimensions() - 2);
  int64_t batchB = b->Shape().SizeToDimension(b->Shape().NumDimensions() - 2);

  TensorShape a_shape = a->Shape();
  TensorShape b_shape = b->Shape();
  TensorShape output_shape = helper.OutputShape();

  const int64_t dim_output_outer = output_shape[output_shape.NumDimensions() - 2];
  // check if A is batch of vector (batch is not 1, M is 1) and B is a matrix (batch is 1)
  if (batchA != 1 && dim_output_outer == 1 && batchB == 1) {
    // optimization for batched vector matrix multiplication: fold the batch
    // of A's row vectors into a single [batchA, K] matrix
    // dimensions of A: [1,`batchA`,K]
    TensorShapeVector dims_a = {1, batchA, helper.K()};
    // dimensions of B: [1,K,N]
    TensorShapeVector dims_b = {1, helper.K(), helper.N()};

    a_shape = TensorShape(dims_a);
    b_shape = TensorShape(dims_b);
    output_shape = {1, batchA, helper.N()};
  }

  // helpful dimension variables (batch dims of A, B and the output)
  TensorShape outer_dims_a = a_shape.NumDimensions() > 2
                                 ? a_shape.Slice(0, a_shape.NumDimensions() - 2)
                                 : TensorShape({});

  TensorShape outer_dims_b = b_shape.NumDimensions() > 2
                                 ? b_shape.Slice(0, b_shape.NumDimensions() - 2)
                                 : TensorShape({});

  TensorShape outer_dims = output_shape.NumDimensions() > 2
                               ? output_shape.Slice(0, output_shape.NumDimensions() - 2)
                               : TensorShape({});

  const int64_t batch_size = outer_dims.Size();

  // Get dimensions for matrix multiplication from TensorShape
  const int32_t dim_a_outer = narrow<int32_t>(a_shape[a_shape.NumDimensions() - 2]);  // rows of A (M)
  const int32_t dim_inner = narrow<int32_t>(a_shape[a_shape.NumDimensions() - 1]);    // shared inner dimension (K)
  const int32_t dim_b_outer = narrow<int32_t>(b_shape[b_shape.NumDimensions() - 1]);  // columns of B (N)

  // vec4 loads/stores are possible only when both K and N are multiples of 4.
  const bool is_vec4 = dim_inner % 4 == 0 && dim_b_outer % 4 == 0;

  // Elements computed per thread along (x, y, z); fewer rows per thread when
  // the output has few rows.
  InlinedVector<int64_t> elements_per_thread = dim_a_outer <= 8
                                                   ? InlinedVector<int64_t>({4, 1, 1})
                                                   : InlinedVector<int64_t>({4, 4, 1});

  // Ceiling division of each output dimension by the work one workgroup covers.
  const uint32_t dispatch_x = narrow<uint32_t>((dim_b_outer + MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0] - 1) /
                                               (MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0]));
  const uint32_t dispatch_y = narrow<uint32_t>((dim_a_outer + MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1] - 1) /
                                               (MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1]));
  const uint32_t dispatch_z = narrow<uint32_t>((static_cast<uint32_t>(batch_size) + MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2] - 1) /
                                               (MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2]));

  // Intermediate (packed) shapes handed to the shader.
  const int components = is_vec4 ? 4 : 1;
  const TensorShape a_shape_temp = CreateMatMulIntermediateShape(outer_dims_a, dim_a_outer, dim_inner, components);
  const TensorShape b_shape_temp = CreateMatMulIntermediateShape(outer_dims_b, dim_inner, dim_b_outer, components);
  const TensorShape output_shape_temp = TensorShape({batch_size, dim_a_outer, dim_b_outer / components});

  MatMulProgram program{has_bias, is_vec4, elements_per_thread};
  program
      .CacheHint(absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4))
      .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_shape_temp, components},
                  {b, ProgramTensorMetadataDependency::TypeAndRank, b_shape_temp, components}})
      .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::Rank, output_shape_temp, components}})
      .AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}})
      .AddIndices(outer_dims)
      .SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z)
      .SetWorkgroupSize(MATMUL_PACKED_WORKGROUP_SIZE_X, MATMUL_PACKED_WORKGROUP_SIZE_Y, MATMUL_PACKED_WORKGROUP_SIZE_Z);

  if (has_bias) {
    const auto* bias = context.Input(2);
    program.AddInput({bias, ProgramTensorMetadataDependency::Rank, 1});
  }
  return context.RunProgram(program);
}

} // namespace webgpu
} // namespace onnxruntime
47 changes: 47 additions & 0 deletions onnxruntime/core/providers/webgpu/math/matmul.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/webgpu_kernel.h"
#include "core/providers/webgpu/program.h"
#include "core/providers/cpu/math/matmul_helper.h"
#include "core/providers/webgpu/math/matmul_utils.h"
#include "core/providers/webgpu/math/matmul_packed.h"
#include "core/providers/webgpu/webgpu_utils.h"

namespace onnxruntime {
namespace webgpu {

// WebGPU kernel implementing the ONNX MatMul operator.
class MatMul final : public WebGpuKernel {
 public:
  // `explicit` prevents accidental implicit conversion from OpKernelInfo;
  // kernel registration constructs this directly, so callers are unaffected.
  explicit MatMul(const OpKernelInfo& info) : WebGpuKernel{info} {}

  Status ComputeInternal(ComputeContext& context) const override;

  // Workgroup dimensions used by the tiled (packed) MatMul shader dispatch.
  constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_X = 8;
  constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_Y = 8;
  constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_Z = 1;
};

// Shader program for the naive (non-tiled) MatMul path, used when both N and
// K are small. One invocation computes `output_number` output rows of one
// vectorized output column.
class MatMulNaiveProgram final : public Program<MatMulNaiveProgram> {
 public:
  // output_rank:   rank of the output tensor (2 means no batch dimensions).
  // output_number: number of output rows produced per shader invocation.
  // has_bias:      whether the optional third (bias) input is present.
  MatMulNaiveProgram(const size_t output_rank, int64_t output_number, bool has_bias)
      : Program{"MatMulNaive"}, output_rank_(output_rank), output_number_(output_number), has_bias_{has_bias} {
  }

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  // Uniforms consumed by the generated shader: total packed output element
  // count, plus the M/N/K matrix dimensions.
  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32},
                                          {"M", ProgramUniformVariableDataType::Uint32},
                                          {"N", ProgramUniformVariableDataType::Uint32},
                                          {"K", ProgramUniformVariableDataType::Uint32});

 private:
  const size_t output_rank_;     // rank of the output tensor
  const int64_t output_number_;  // output rows per invocation
  const bool has_bias_;          // true when a bias input was supplied
};

} // namespace webgpu
} // namespace onnxruntime
Loading
Loading