Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions onnxruntime/core/providers/webgpu/math/gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "core/providers/webgpu/math/gemm.h"
#include "core/providers/webgpu/math/gemm_packed.h"
#include "core/providers/webgpu/vendor/intel/math/gemm_intel.h"

#include <vector>

Expand Down Expand Up @@ -147,6 +148,10 @@ Status Gemm::ComputeInternal(ComputeContext& context) const {
return context.RunProgram(program);
}

if (CanApplyGemmIntel(context, M, N, K, transA_, transB_)) {
return ApplyGemmIntel(A, B, C, transA_, transB_, alpha_, beta_, context);
}

return ApplyGemmPacked(A, B, C, transA_, transB_, alpha_, beta_, context);
}

Expand Down
6 changes: 6 additions & 0 deletions onnxruntime/core/providers/webgpu/math/matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/webgpu/nn/fuse_utils.h"
#include "core/providers/webgpu/data_transfer.h"
#include "core/providers/webgpu/vendor/intel/math/matmul_intel.h"

namespace onnxruntime {
namespace webgpu {
Expand Down Expand Up @@ -161,6 +162,11 @@ Status MatMul::ComputeInternal(ComputeContext& context) const {
const auto* bias = context.Input(2);
inputs.push_back(bias);
}

if (CanApplyMatMulIntel(context, helper.M(), helper.N(), helper.K())) {
return ApplyMatMulIntel(context, Activation(), inputs, output_tensor);
}

auto program = CreateMatMulProgram(Activation(), inputs, output_tensor, false);

return context.RunProgram(program);
Expand Down
133 changes: 133 additions & 0 deletions onnxruntime/core/providers/webgpu/vendor/intel/math/gemm_intel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/webgpu/vendor/intel/math/gemm_intel.h"
#include "core/providers/webgpu/vendor/intel/math/gemm_subgroup.h"
#include "core/providers/webgpu/vendor/intel/math/gemm_utils_intel.h"

namespace onnxruntime {
namespace webgpu {

Status GemmSubgroupProgram::GenerateShaderCode(ShaderHelper& shader) const {
const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform |
ShaderUsage::UseValueTypeAlias |
ShaderUsage::UseElementTypeAlias);
const std::string data_type = "output_element_t";

Check warning on line 15 in onnxruntime/core/providers/webgpu/vendor/intel/math/gemm_intel.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webgpu/vendor/intel/math/gemm_intel.cc:15: Add #include <string> for string [build/include_what_you_use] [4]

if (need_handle_matmul_) {
const auto& a = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias |
ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
const auto& b = shader.AddInput("b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias |
ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);

MatMulReadFnSourceIntel(shader, a, b, nullptr, transA_, transB_, is_vec4_, true);
}
if (is_vec4_) {
ORT_RETURN_IF_ERROR(MakeMatMulSubgroupVec4Source(shader, elements_per_thread_, data_type, nullptr, transA_,
transB_, alpha_, need_handle_matmul_));
} else {
ORT_RETURN_IF_ERROR(MakeMatMulSubgroupSource(shader, elements_per_thread_, data_type, nullptr, transA_, transB_,
alpha_, need_handle_matmul_));
}
MatMulWriteFnSourceIntel(shader, output, need_handle_bias_, true, c_components_, output_components_, c_is_scalar_);

return Status::OK();
}

bool CanApplyGemmIntel(const ComputeContext& context, int64_t M, int64_t N, int64_t K, bool transA, bool transB) {
if (context.AdapterInfo().vendor == std::string_view{"intel"}) {
bool use_subgroup = (context.AdapterInfo().architecture == std::string_view{"xe-2lpg"} ||
context.AdapterInfo().architecture == std::string_view{"xe-2hpg"}) &&
M > 16 && N > 768 && K >= 32 && !transA && !transB;
return use_subgroup;
}

return false;
}

Status ApplyGemmIntel(const Tensor* a,
const Tensor* b,
const Tensor* c,
bool transA,
bool transB,
float alpha,
float beta,
ComputeContext& context) {
const auto& a_shape = a->Shape();
const auto& b_shape = b->Shape();

uint32_t M = onnxruntime::narrow<uint32_t>(transA ? a_shape[1] : a_shape[0]);
uint32_t K = onnxruntime::narrow<uint32_t>(transA ? a_shape[0] : a_shape[1]);
uint32_t N = onnxruntime::narrow<uint32_t>(transB ? b_shape[0] : b_shape[1]);

std::vector<int64_t> output_dims{M, N};

Check warning on line 63 in onnxruntime/core/providers/webgpu/vendor/intel/math/gemm_intel.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <vector> for vector<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webgpu/vendor/intel/math/gemm_intel.cc:63: Add #include <vector> for vector<> [build/include_what_you_use] [4]
auto* y = context.Output(0, output_dims);
int64_t output_size = y->Shape().Size();

if (output_size == 0) {
return Status::OK();
}

// WebGPU doesn't support binding a zero-sized buffer, so we need to check if A or B is empty.
bool need_handle_matmul = a_shape.Size() > 0 && b_shape.Size() > 0;
bool need_handle_bias = c && beta;

bool use_subgroup = M > 16 && N > 768 && K >= 32 && !transA && !transB;
if (use_subgroup) {
const bool is_vec4 = b_shape[1] % 4 == 0;
// Components for A, B
int a_components = 1;
int b_components = is_vec4 ? 4 : 1;
// Components for Y
int output_components = (is_vec4 && N % 4 == 0) ? 4 : 1;
// Components for C.
int c_components = 1;

bool c_is_scalar = false;
if (need_handle_bias) {
const auto& c_shape = c->Shape();
int64_t c_last_dim = c_shape[c_shape.NumDimensions() - 1];
// `C` in GEMM might be broadcast to the output, and broadcasting requires the components to be consistent.
// So we use vec4 for C when its last dimension is N, and the output is also a vec4.
c_components = (c_last_dim == N && output_components == 4) ? 4 : 1;
c_is_scalar = c_shape.Size() == 1;
}

const int64_t elements_per_thread_y = is_vec4 ? (M <= 8 ? 1 : (M <= 16 ? 2 : (M <= 32 ? 4 : 8))) : 4;
InlinedVector<int64_t> elements_per_thread = InlinedVector<int64_t>({4, elements_per_thread_y, 1});
const uint32_t dispatch_x = narrow<uint32_t>((N + kSubgroupLogicalWorkGroupSizeX * elements_per_thread[0] - 1) /
(kSubgroupLogicalWorkGroupSizeX * elements_per_thread[0]));
const uint32_t dispatch_y = narrow<uint32_t>((M + kSubgroupLogicalWorkGroupSizeY * elements_per_thread[1] - 1) /
(kSubgroupLogicalWorkGroupSizeY * elements_per_thread[1]));

GemmSubgroupProgram program{transA, transB, alpha, need_handle_bias, need_handle_matmul, c_components, c_is_scalar,
output_components, is_vec4, elements_per_thread};

if (need_handle_matmul) {
program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_components},
{b, ProgramTensorMetadataDependency::TypeAndRank, b_components}});
}

if (need_handle_bias) {
program.AddInput({c, ProgramTensorMetadataDependency::TypeAndRank, c_components});
}

program.CacheHint(alpha, transA, transB, c_is_scalar, is_vec4, absl::StrJoin(elements_per_thread, "-"))
.AddOutputs({{y, ProgramTensorMetadataDependency::TypeAndRank, output_components}})
.SetDispatchGroupSize(dispatch_x, dispatch_y, 1)
.SetWorkgroupSize(kSubgroupLogicalWorkGroupSizeX * kSubgroupLogicalWorkGroupSizeY, 1, 1)
.AddUniformVariables({{alpha},
{beta},
{M}, /* dim_a_outer */
{N}, /* dim_b_outer */
{K}} /*dim_inner */
);

return context.RunProgram(program);
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, "Should not reach here");
}
}

} // namespace webgpu
} // namespace onnxruntime
64 changes: 64 additions & 0 deletions onnxruntime/core/providers/webgpu/vendor/intel/math/gemm_intel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/webgpu_kernel.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/program.h"

namespace onnxruntime {
namespace webgpu {

class GemmSubgroupProgram final : public Program<GemmSubgroupProgram> {
public:
GemmSubgroupProgram(bool transA, bool transB, float alpha, bool need_handle_bias, bool need_handle_matmul,
int c_components, bool c_is_scalar, int output_components, bool is_vec4,
const gsl::span<int64_t>& elements_per_thread)
: Program{"GemmSubgroup"},
transA_{transA},
transB_{transB},
alpha_{alpha},
need_handle_bias_{need_handle_bias},
need_handle_matmul_{need_handle_matmul},
c_components_(c_components),
c_is_scalar_(c_is_scalar),
output_components_(output_components),
is_vec4_(is_vec4),
elements_per_thread_(elements_per_thread.begin(), elements_per_thread.end()) {}

Status GenerateShaderCode(ShaderHelper& sh) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
{"alpha", ProgramUniformVariableDataType::Float32},
{"beta", ProgramUniformVariableDataType::Float32},
{"dim_a_outer", ProgramUniformVariableDataType::Uint32},
{"dim_b_outer", ProgramUniformVariableDataType::Uint32},
{"dim_inner", ProgramUniformVariableDataType::Uint32});

private:
bool transA_;
bool transB_;
float alpha_;
bool need_handle_bias_;
bool need_handle_matmul_;
int c_components_;
bool c_is_scalar_ = false;
int output_components_;
bool is_vec4_ = false;
const InlinedVector<int64_t> elements_per_thread_;
};

bool CanApplyGemmIntel(const ComputeContext& context, int64_t M, int64_t N, int64_t K, bool transA, bool transB);

Status ApplyGemmIntel(const Tensor* a,
const Tensor* b,
const Tensor* c,
bool transA,
bool transB,
float alpha,
float beta,
ComputeContext& context);

} // namespace webgpu
} // namespace onnxruntime
Loading
Loading