Merged
5 changes: 5 additions & 0 deletions onnxruntime/core/providers/webgpu/math/gemm.cc
@@ -3,6 +3,7 @@

#include "core/providers/webgpu/math/gemm.h"
#include "core/providers/webgpu/math/gemm_packed.h"
#include "core/providers/webgpu/vendor/intel/math/gemm.h"

#include <vector>

@@ -147,6 +148,10 @@ Status Gemm::ComputeInternal(ComputeContext& context) const {
return context.RunProgram(program);
}

+ if (intel::CanApplyGemmIntel(context, M, N, K, transA_, transB_)) {
+ return intel::ApplyGemmIntel(A, B, C, transA_, transB_, alpha_, beta_, context);
+ }
+
return ApplyGemmPacked(A, B, C, transA_, transB_, alpha_, beta_, context);
}

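Taken together, this hunk slots a vendor-specialized path in front of the generic fallback: when intel::CanApplyGemmIntel accepts the shape, the Intel subgroup kernel handles the GEMM; otherwise behavior is unchanged. A minimal standalone sketch of the gating pattern (the predicate body and cutoff below are hypothetical; the real heuristics live in CanApplySubgroup, which this diff only calls):

#include <cstdint>
#include <iostream>

// Hypothetical stand-in for intel::CanApplyGemmIntel. The real predicate also
// receives the ComputeContext (adapter vendor, subgroup support); this sketch
// keeps only the shape arguments.
static bool CanApplyGemmIntelSketch(int64_t M, int64_t N, int64_t K,
                                    bool transA, bool transB) {
  return !transA && !transB && M * N * K >= (int64_t{1} << 20);  // made-up cutoff
}

int main() {
  if (CanApplyGemmIntelSketch(512, 512, 256, false, false)) {
    std::cout << "intel::ApplyGemmIntel path\n";  // vendor kernel
  } else {
    std::cout << "ApplyGemmPacked fallback\n";    // pre-existing generic path
  }
}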
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/webgpu/math/gemm_packed.cc
@@ -31,7 +31,7 @@ Status GemmProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& a = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
const auto& b = shader.AddInput("b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);

- MatMulReadFnSource(shader, a, b, nullptr, transA_, transB_, is_vec4_);
+ MatMulReadFnSource(shader, a, b, nullptr, transA_, transB_);
}
if (is_vec4_) {
ORT_RETURN_IF_ERROR(MakeMatMulPackedVec4Source(shader, elements_per_thread, WorkgroupSizeX(), WorkgroupSizeY(), data_type, nullptr, transA_, transB_, alpha_, need_handle_matmul_, output_components_, /*tile_inner*/ 32, need_split_k, split_dim_inner_));
@@ -45,7 +45,7 @@ Status GemmProgram::GenerateShaderCode(ShaderHelper& shader) const {
}

const ProgramVariableDataType output_var_type = this->Outputs()[0].var_type;
- MatMulWriteFnSource(shader, output, c, /* is_gemm = */ true, c_components_, output_components_, c_is_scalar_, /*activation_snippet*/ "", /*is_channels_last*/ false, need_split_k, output_var_type);
+ MatMulWriteFnSource(shader, output, c, /* is_gemm = */ true, c_components_, c_is_scalar_, /*activation_snippet*/ "", /*is_channels_last*/ false, need_split_k, output_var_type);

return Status::OK();
}
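Both call-site edits mirror one interface change in gemm_utils: the generated read/write functions now derive component widths from the shader variables themselves (a.NumComponents(), output.NumComponents()) instead of receiving an is_vec4 flag or an output_components argument, so A and B may use different widths. A rough sketch of what the type-string construction then produces, assuming MakeScalarOrVectorType maps a component count to the matching WGSL type name (inferred from its call sites; the helper's definition is not part of this diff):

#include <iostream>
#include <string>

// Assumed behavior of MakeScalarOrVectorType, inferred from usage in the diff:
// 1 component -> the scalar type itself, N components -> "vecN<type>".
static std::string MakeScalarOrVectorTypeSketch(int components, const std::string& t) {
  return components == 1 ? t : "vec" + std::to_string(components) + "<" + t + ">";
}

int main() {
  // With the old bool, A and B shared one width; now each picks its own:
  std::cout << MakeScalarOrVectorTypeSketch(1, "output_element_t") << "\n";  // output_element_t
  std::cout << MakeScalarOrVectorTypeSketch(4, "output_element_t") << "\n";  // vec4<output_element_t>
}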
65 changes: 33 additions & 32 deletions onnxruntime/core/providers/webgpu/math/gemm_utils.cc
@@ -49,7 +49,7 @@ void HandleMaybeBiasForMatMul(ShaderHelper& shader,
shader.AdditionalImplementation() << " value = value + output_value_t(" << (is_channels_last ? bias->GetByOffset("colIn") : bias->GetByOffset("row")) << ");\n";
}
shader.AdditionalImplementation() << " " << activation_snippet << "\n"
<< output.SetByIndices("coords", "value") << "\n";
<< " " << output.SetByIndices("coords", "value") << "\n";
}

void HandleMatMulWithSplitK(
@@ -127,60 +127,61 @@ void MatMulReadFnSource(ShaderHelper& shader,
const ShaderVariableHelper& b,
const ShaderIndicesHelper* batch_dims,
bool transA,
- bool transB,
- bool is_vec4) {
- int components = is_vec4 ? 4 : 1;
+ bool transB) {
+ const int a_components = a.NumComponents();
const std::string data_type = "output_element_t";
- const std::string type_string = MakeScalarOrVectorType(components, data_type);
+ std::string type_string = MakeScalarOrVectorType(a_components, data_type);

shader.AdditionalImplementation()
<< "fn mm_readA(batch: i32, row: i32, colIn: i32 "
<< (batch_dims
? ", batch_indices: batch_dims_indices_t"
: "")
<< ") -> " << type_string << " {\n "
<< " var value = " << type_string << "(0);\n"
<< " let col = colIn * " << components << ";\n";
<< ") -> " << type_string << " {\n"
<< " var value = " << type_string << "(0);\n"
<< " let col = colIn * " << a_components << ";\n";
if (transA) {
shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_a_outer)) {\n";
shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_a_outer)) {\n";
} else {
shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_inner)) {\n";
shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_inner)) {\n";
}
shader.AdditionalImplementation() << " var a_indices: a_indices_t;\n";
shader.AdditionalImplementation() << " var a_indices: a_indices_t;\n";

if (batch_dims) {
shader.AdditionalImplementation() << ConvertOutputBatchIndicesToInputBatchIndices("a", a, a.Rank() - 2, batch_dims ? batch_dims->Rank() : 0, " batch_indices ") << "\n";
shader.AdditionalImplementation() << ConvertOutputBatchIndicesToInputBatchIndices("a", a, a.Rank() - 2, batch_dims ? batch_dims->Rank() : 0, " batch_indices ");
}
shader.AdditionalImplementation() << a.IndicesSet("a_indices", a.Rank() - 2, "u32(row)") << "\n"
<< a.IndicesSet("a_indices", a.Rank() - 1, "u32(colIn)") << "\n"
<< " value = " << a.GetByIndices("a_indices") << ";\n"
<< " }\n"
<< " return value;\n"
shader.AdditionalImplementation() << " " << a.IndicesSet("a_indices", a.Rank() - 2, "u32(row)") << "\n"
<< " " << a.IndicesSet("a_indices", a.Rank() - 1, "u32(colIn)") << "\n"
<< " value = " << a.GetByIndices("a_indices") << ";\n"
<< " }\n"
<< " return value;\n"
<< "}\n\n";

+ // Add the mm_readB function
+ const int b_components = b.NumComponents();
+ type_string = MakeScalarOrVectorType(b_components, data_type);
shader.AdditionalImplementation()
<< "fn mm_readB(batch: i32, row: i32, colIn: i32 "
<< (batch_dims
? ", batch_indices: batch_dims_indices_t"
: "")
<< ") -> " << type_string << " {\n "
<< " var value = " << type_string << "(0);\n"
<< " let col = colIn * " << components << ";\n";
<< ") -> " << type_string << " {\n"
<< " var value = " << type_string << "(0);\n"
<< " let col = colIn * " << b_components << ";\n";

if (transB) {
shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_b_outer) && col < i32(uniforms.dim_inner)) {\n";
shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_b_outer) && col < i32(uniforms.dim_inner)) {\n";
} else {
shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_b_outer)) {\n";
shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_b_outer)) {\n";
}

shader.AdditionalImplementation() << " var b_indices: b_indices_t;\n"
shader.AdditionalImplementation() << " var b_indices: b_indices_t;\n"
<< ConvertOutputBatchIndicesToInputBatchIndices("b", b, b.Rank() - 2, batch_dims ? batch_dims->Rank() : 0, "batch_indices")
<< b.IndicesSet("b_indices", b.Rank() - 2, "u32(row)") << "\n"
<< b.IndicesSet("b_indices", b.Rank() - 1, "u32(colIn)") << "\n"
<< " value = " << b.GetByIndices("b_indices") << ";\n"
<< " }\n"
<< " return value;\n"
<< " " << b.IndicesSet("b_indices", b.Rank() - 2, "u32(row)") << "\n"
<< " " << b.IndicesSet("b_indices", b.Rank() - 1, "u32(colIn)") << "\n"
<< " value = " << b.GetByIndices("b_indices") << ";\n"
<< " }\n"
<< " return value;\n"
<< "}\n\n";
}
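Piecing the stream inserts above back together: for a 2-D a with four components, no batch dims, and transA == false, the emitted mm_readA reads roughly as the WGSL below (held in a C++ raw string for reference; the IndicesSet/GetByIndices expansions are sketched, since their exact output depends on the shader variable helpers):

// Reconstruction for illustration only; not part of the PR.
constexpr const char* kMmReadASketch = R"wgsl(
fn mm_readA(batch: i32, row: i32, colIn: i32 ) -> vec4<output_element_t> {
  var value = vec4<output_element_t>(0);
  let col = colIn * 4;
  if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_inner)) {
    var a_indices: a_indices_t;
    a_indices[0] = u32(row);    // a.IndicesSet(..., a.Rank() - 2, "u32(row)")
    a_indices[1] = u32(colIn);  // a.IndicesSet(..., a.Rank() - 1, "u32(colIn)")
    value = read_a(a_indices);  // stands in for a.GetByIndices("a_indices")
  }
  return value;
}
)wgsl";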

@@ -189,19 +190,19 @@ void MatMulWriteFnSource(ShaderHelper& shader,
const ShaderVariableHelper* bias,
bool is_gemm,
int c_components,
- int output_components,
bool c_is_scalar,
std::string activation_snippet,
bool is_channels_last,
bool use_split_k,
ProgramVariableDataType output_variable_type) {
+ const int output_components = output.NumComponents();
shader.AdditionalImplementation()
<< "fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: output_value_t) { \n";
<< "fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: output_value_t) {\n";

shader.AdditionalImplementation() << " let col = colIn * " << output_components << ";\n";

shader.AdditionalImplementation() << "if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_b_outer)) { \n"
<< " var value = valueIn; \n";
shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_b_outer)) {\n"
<< " var value = valueIn;\n";

if (use_split_k) {
// Set output when MatMul is performed with Split-K.
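The write side is symmetric: output_components is now read from the output variable (the const int output_components = output.NumComponents(); line above) instead of being passed by every caller. The emitted mm_write prologue, reconstructed the same way from the literals above:

// Reconstruction for illustration only; the bias/activation/store body that
// follows depends on the flags (is_gemm, use_split_k, ...) and is elided.
constexpr const char* kMmWritePrologueSketch = R"wgsl(
fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: output_value_t) {
  let col = colIn * 4;  // 4 == output.NumComponents() in the vec4 case
  if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_b_outer)) {
    var value = valueIn;
    // ... bias add, activation, and store follow, per the flags above
  }
}
)wgsl";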
4 changes: 1 addition & 3 deletions onnxruntime/core/providers/webgpu/math/gemm_utils.h
@@ -13,15 +13,13 @@ void MatMulReadFnSource(ShaderHelper& shader,
const ShaderVariableHelper& b,
const ShaderIndicesHelper* batch_dims,
bool transA,
- bool transB,
- bool is_vec4);
+ bool transB);

void MatMulWriteFnSource(ShaderHelper& shader,
const ShaderVariableHelper& output,
const ShaderVariableHelper* bias,
bool is_gemm,
int c_components,
- int output_components,
bool c_is_scalar,
std::string activation_snippet = "",
bool is_channels_last = false,
5 changes: 5 additions & 0 deletions onnxruntime/core/providers/webgpu/math/matmul.cc
@@ -8,6 +8,7 @@
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/webgpu/nn/fuse_utils.h"
#include "core/providers/webgpu/data_transfer.h"
#include "core/providers/webgpu/vendor/intel/math/matmul.h"
#include "core/providers/webgpu/webgpu_utils.h"

namespace onnxruntime {
@@ -163,6 +164,10 @@ Status MatMul::ComputeInternal(ComputeContext& context) const {
inputs.push_back(bias);
}

+ if (intel::CanApplyMatMulIntel(context, helper.M(), helper.N(), helper.K())) {
+ return intel::ApplyMatMulIntel(context, Activation(), inputs, output_tensor);
+ }
+
return ComputeMatMul(&context, Activation(), inputs, output_tensor, false);
}

6 changes: 3 additions & 3 deletions onnxruntime/core/providers/webgpu/math/matmul_packed.cc
@@ -33,8 +33,8 @@ Status MatMulProgram::GenerateShaderCode(ShaderHelper& shader) const {
std::string apply_activation = GetActivationSnippet(activation_, "output_value_t", "output_element_t");
ProgramVariableDataType output_var_type = this->Outputs()[0].var_type;
// declare the read and write functions
MatMulReadFnSource(shader, a, b, &batch_dims, /*transA = */ false, /*transB = */ false, is_vec4_);
MatMulWriteFnSource(shader, output, bias, /* is_gemm = */ false, 1, is_vec4_ ? 4 : 1, false, apply_activation, is_channels_last_, need_split_k, output_var_type);
MatMulReadFnSource(shader, a, b, &batch_dims, /*transA = */ false, /*transB = */ false);
MatMulWriteFnSource(shader, output, bias, /* is_gemm = */ false, 1, false, apply_activation, is_channels_last_, need_split_k, output_var_type);
std::string data_type = "a_element_t";
// generate the main function
if (is_vec4_) {
@@ -65,7 +65,7 @@ Status MatMulFillBiasOrZeroBeforeSplitKProgram::GenerateShaderCode(ShaderHelper&
// `use_split_k` is true only when we do the actual MatMul with Split-K.
const uint32_t bias_components = output_components_;
MatMulWriteFnSource(
- shader, output, bias, is_gemm_, bias_components, output_components_, bias_is_scalar_,
+ shader, output, bias, is_gemm_, bias_components, bias_is_scalar_,
/*activation_snippet*/ "", /*is_channels_last*/ true, /*use_split_k*/ false);

shader.MainFunctionBody() << " let output_components = " << output_components_ << ";\n";
121 changes: 121 additions & 0 deletions onnxruntime/core/providers/webgpu/vendor/intel/math/gemm.cc
@@ -0,0 +1,121 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/webgpu/vendor/intel/math/gemm.h"
#include "core/providers/webgpu/vendor/intel/math/gemm_subgroup.h"
#include "core/providers/webgpu/math/gemm_utils.h"

namespace onnxruntime {
namespace webgpu {
namespace intel {

Status GemmSubgroupProgram::GenerateShaderCode(ShaderHelper& shader) const {
const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform |
ShaderUsage::UseValueTypeAlias |
ShaderUsage::UseElementTypeAlias);

if (need_handle_matmul_) {
const auto& a = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias |
ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
const auto& b = shader.AddInput("b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias |
ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);

MatMulReadFnSource(shader, a, b, nullptr, transA_, transB_);
}

ORT_RETURN_IF_ERROR(MakeMatMulSubgroupSource(shader, elements_per_thread_, nullptr, is_vec4_, transA_, transB_,
alpha_, need_handle_matmul_));
const ShaderVariableHelper* c = nullptr;
if (need_handle_bias_) {
c = &shader.AddInput("c", ShaderUsage::UseUniform);
}
MatMulWriteFnSource(shader, output, c, true, c_components_, c_is_scalar_);

return Status::OK();
}

bool CanApplyGemmIntel(const ComputeContext& context, int64_t M, int64_t N, int64_t K, bool transA, bool transB) {
return CanApplySubgroup(context, M, N, K, transA, transB);
}

Status ApplyGemmIntel(const Tensor* a,
const Tensor* b,
const Tensor* c,
bool transA,
bool transB,
float alpha,
float beta,
ComputeContext& context) {
const auto& a_shape = a->Shape();
const auto& b_shape = b->Shape();

uint32_t M = onnxruntime::narrow<uint32_t>(transA ? a_shape[1] : a_shape[0]);
uint32_t K = onnxruntime::narrow<uint32_t>(transA ? a_shape[0] : a_shape[1]);
uint32_t N = onnxruntime::narrow<uint32_t>(transB ? b_shape[0] : b_shape[1]);

std::vector<int64_t> output_dims{M, N};

[GitHub Actions / Optional Lint C++: check warning on line 56 of onnxruntime/core/providers/webgpu/vendor/intel/math/gemm.cc]
[cpplint] reported by reviewdog 🐶: Add #include <vector> for vector<> [build/include_what_you_use] [4]
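// Editor's note: suggested resolution of the cpplint finding above, sketched
// as a comment so the listing stays verbatim. Only the <vector> line is
// mandated by the warning; its placement among the existing includes is a guess:
//
//   #include "core/providers/webgpu/vendor/intel/math/gemm.h"
//   #include "core/providers/webgpu/vendor/intel/math/gemm_subgroup.h"
//   #include "core/providers/webgpu/math/gemm_utils.h"
//
//   #include <vector>  // for std::vector<int64_t> output_dims at line 56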
auto* y = context.Output(0, output_dims);
int64_t output_size = y->Shape().Size();

if (output_size == 0) {
return Status::OK();
}

// WebGPU doesn't support binding a zero-sized buffer, so we need to check if A or B is empty.
bool need_handle_matmul = a_shape.Size() > 0 && b_shape.Size() > 0;
bool need_handle_bias = c && beta;

const bool is_vec4 = b_shape[1] % 4 == 0;
// Components for A, B
int a_components = 1;
int b_components = is_vec4 ? 4 : 1;
// Components for Y
int output_components = (is_vec4 && N % 4 == 0) ? 4 : 1;
// Components for C.
int c_components = 1;

bool c_is_scalar = false;
if (need_handle_bias) {
const auto& c_shape = c->Shape();
int64_t c_last_dim = c_shape[c_shape.NumDimensions() - 1];
// `C` in GEMM might be broadcast to the output, and broadcasting requires the components to be consistent.
// So we use vec4 for C when its last dimension is N, and the output is also a vec4.
c_components = (c_last_dim == N && output_components == 4) ? 4 : 1;
c_is_scalar = c_shape.Size() == 1;
}

InlinedVector<int64_t> elements_per_thread = InlinedVector<int64_t>({4, intel::ElementsPerThreadY(is_vec4, M), 1});
const uint32_t dispatch_x = narrow<uint32_t>((N + kSubgroupLogicalWorkGroupSizeX * elements_per_thread[0] - 1) /
(kSubgroupLogicalWorkGroupSizeX * elements_per_thread[0]));
const uint32_t dispatch_y = narrow<uint32_t>((M + kSubgroupLogicalWorkGroupSizeY * elements_per_thread[1] - 1) /
(kSubgroupLogicalWorkGroupSizeY * elements_per_thread[1]));

GemmSubgroupProgram program{transA, transB, alpha, need_handle_bias, need_handle_matmul, c_components, c_is_scalar,
is_vec4, elements_per_thread};

if (need_handle_matmul) {
program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_components},
{b, ProgramTensorMetadataDependency::TypeAndRank, b_components}});
}

if (need_handle_bias) {
program.AddInput({c, ProgramTensorMetadataDependency::TypeAndRank, c_components});
}

program.CacheHint(alpha, transA, transB, c_is_scalar, absl::StrJoin(elements_per_thread, "-"))
.AddOutputs({{y, ProgramTensorMetadataDependency::TypeAndRank, output_components}})
.SetDispatchGroupSize(dispatch_x, dispatch_y, 1)
.SetWorkgroupSize(kSubgroupLogicalWorkGroupSizeX * kSubgroupLogicalWorkGroupSizeY, 1, 1)
.AddUniformVariables({{alpha},
{beta},
{M}, /* dim_a_outer */
{N}, /* dim_b_outer */
{K}} /*dim_inner */
);

return context.RunProgram(program);
}

} // namespace intel
} // namespace webgpu
} // namespace onnxruntime
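As a sanity check on the dispatch arithmetic in ApplyGemmIntel: each workgroup covers kSubgroupLogicalWorkGroupSizeX * elements_per_thread[0] output columns and kSubgroupLogicalWorkGroupSizeY * elements_per_thread[1] output rows, and the two ceiling divisions size the grid to cover M x N. A worked example with hypothetical tile constants (the real kSubgroupLogicalWorkGroupSizeX/Y live in gemm_subgroup.h, which this excerpt does not include):

#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical values; only elements_per_thread[0] == 4 is fixed by the PR.
  const uint32_t kX = 8, kY = 8;      // kSubgroupLogicalWorkGroupSizeX/Y, assumed
  const uint32_t elems_x = 4;         // elements_per_thread[0] (from the PR)
  const uint32_t elems_y = 4;         // intel::ElementsPerThreadY(...), assumed
  const uint32_t M = 500, N = 333;
  const uint32_t tile_x = kX * elems_x;                   // 32 columns per workgroup
  const uint32_t tile_y = kY * elems_y;                   // 32 rows per workgroup
  const uint32_t dispatch_x = (N + tile_x - 1) / tile_x;  // ceil(333 / 32) = 11
  const uint32_t dispatch_y = (M + tile_y - 1) / tile_y;  // ceil(500 / 32) = 16
  std::cout << dispatch_x << " x " << dispatch_y << " workgroups\n";  // 11 x 16
}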