diff --git a/onnxruntime/core/providers/webgpu/math/gemm.cc b/onnxruntime/core/providers/webgpu/math/gemm.cc
index 4fb512001381a..b722430049877 100644
--- a/onnxruntime/core/providers/webgpu/math/gemm.cc
+++ b/onnxruntime/core/providers/webgpu/math/gemm.cc
@@ -3,6 +3,7 @@
 
 #include "core/providers/webgpu/math/gemm.h"
 #include "core/providers/webgpu/math/gemm_packed.h"
+#include "core/providers/webgpu/vendor/intel/math/gemm.h"
 
 #include
@@ -147,6 +148,10 @@ Status Gemm::ComputeInternal(ComputeContext& context) const {
     return context.RunProgram(program);
   }
 
+  if (intel::CanApplyGemmIntel(context, M, N, K, transA_, transB_)) {
+    return intel::ApplyGemmIntel(A, B, C, transA_, transB_, alpha_, beta_, context);
+  }
+
   return ApplyGemmPacked(A, B, C, transA_, transB_, alpha_, beta_, context);
 }
diff --git a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc
index 1a0ad7a843ec4..023a671420d89 100644
--- a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc
+++ b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc
@@ -31,7 +31,7 @@ Status GemmProgram::GenerateShaderCode(ShaderHelper& shader) const {
     const auto& a = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
     const auto& b = shader.AddInput("b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
 
-    MatMulReadFnSource(shader, a, b, nullptr, transA_, transB_, is_vec4_);
+    MatMulReadFnSource(shader, a, b, nullptr, transA_, transB_);
   }
 
   if (is_vec4_) {
     ORT_RETURN_IF_ERROR(MakeMatMulPackedVec4Source(shader, elements_per_thread, WorkgroupSizeX(), WorkgroupSizeY(), data_type, nullptr, transA_, transB_, alpha_, need_handle_matmul_, output_components_, /*tile_inner*/ 32, need_split_k, split_dim_inner_));
@@ -45,7 +45,7 @@ Status GemmProgram::GenerateShaderCode(ShaderHelper& shader) const {
   }
 
   const ProgramVariableDataType output_var_type = this->Outputs()[0].var_type;
-  MatMulWriteFnSource(shader, output, c, /* is_gemm = */ true, c_components_, output_components_, c_is_scalar_, /*activation_snippet*/ "", /*is_channels_last*/ false, need_split_k, output_var_type);
+  MatMulWriteFnSource(shader, output, c, /* is_gemm = */ true, c_components_, c_is_scalar_, /*activation_snippet*/ "", /*is_channels_last*/ false, need_split_k, output_var_type);
 
   return Status::OK();
 }
diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc
index ba7e9290f8455..0228fb25d1d26 100644
--- a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc
+++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc
@@ -49,7 +49,7 @@ void HandleMaybeBiasForMatMul(ShaderHelper& shader,
     shader.AdditionalImplementation() << "    value = value + output_value_t(" << (is_channels_last ? bias->GetByOffset("colIn") : bias->GetByOffset("row")) << ");\n";
   }
   shader.AdditionalImplementation() << "  " << activation_snippet << "\n"
-                                    << output.SetByIndices("coords", "value") << "\n";
+                                    << "  " << output.SetByIndices("coords", "value") << "\n";
 }
 
 void HandleMatMulWithSplitK(
@@ -127,60 +127,61 @@ void MatMulReadFnSource(ShaderHelper& shader,
                         const ShaderVariableHelper& b,
                         const ShaderIndicesHelper* batch_dims,
                         bool transA,
-                        bool transB,
-                        bool is_vec4) {
-  int components = is_vec4 ? 4 : 1;
+                        bool transB) {
+  const int a_components = a.NumComponents();
   const std::string data_type = "output_element_t";
-  const std::string type_string = MakeScalarOrVectorType(components, data_type);
+  std::string type_string = MakeScalarOrVectorType(a_components, data_type);
 
   shader.AdditionalImplementation()
       << "fn mm_readA(batch: i32, row: i32, colIn: i32 "
       << (batch_dims ? ", batch_indices: batch_dims_indices_t" : "")
-      << ") -> " << type_string << " {\n "
-      << " var value = " << type_string << "(0);\n"
-      << " let col = colIn * " << components << ";\n";
+      << ") -> " << type_string << " {\n"
+      << "  var value = " << type_string << "(0);\n"
+      << "  let col = colIn * " << a_components << ";\n";
   if (transA) {
-    shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_a_outer)) {\n";
+    shader.AdditionalImplementation() << "  if(row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_a_outer)) {\n";
   } else {
-    shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_inner)) {\n";
+    shader.AdditionalImplementation() << "  if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_inner)) {\n";
   }
-  shader.AdditionalImplementation() << " var a_indices: a_indices_t;\n";
+  shader.AdditionalImplementation() << "    var a_indices: a_indices_t;\n";
   if (batch_dims) {
-    shader.AdditionalImplementation() << ConvertOutputBatchIndicesToInputBatchIndices("a", a, a.Rank() - 2, batch_dims ? batch_dims->Rank() : 0, " batch_indices ") << "\n";
+    shader.AdditionalImplementation() << ConvertOutputBatchIndicesToInputBatchIndices("a", a, a.Rank() - 2, batch_dims ? batch_dims->Rank() : 0, " batch_indices ");
   }
-  shader.AdditionalImplementation() << a.IndicesSet("a_indices", a.Rank() - 2, "u32(row)") << "\n"
-                                    << a.IndicesSet("a_indices", a.Rank() - 1, "u32(colIn)") << "\n"
-                                    << " value = " << a.GetByIndices("a_indices") << ";\n"
-                                    << " }\n"
-                                    << " return value;\n"
+  shader.AdditionalImplementation() << "    " << a.IndicesSet("a_indices", a.Rank() - 2, "u32(row)") << "\n"
+                                    << "    " << a.IndicesSet("a_indices", a.Rank() - 1, "u32(colIn)") << "\n"
+                                    << "    value = " << a.GetByIndices("a_indices") << ";\n"
+                                    << "  }\n"
+                                    << "  return value;\n"
                                     << "}\n\n";
 
   // Add the mm_readB function
+  const int b_components = b.NumComponents();
+  type_string = MakeScalarOrVectorType(b_components, data_type);
   shader.AdditionalImplementation()
       << "fn mm_readB(batch: i32, row: i32, colIn: i32 "
       << (batch_dims ? ", batch_indices: batch_dims_indices_t" : "")
-      << ") -> " << type_string << " {\n "
-      << " var value = " << type_string << "(0);\n"
-      << " let col = colIn * " << components << ";\n";
+      << ") -> " << type_string << " {\n"
+      << "  var value = " << type_string << "(0);\n"
+      << "  let col = colIn * " << b_components << ";\n";
   if (transB) {
-    shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_b_outer) && col < i32(uniforms.dim_inner)) {\n";
+    shader.AdditionalImplementation() << "  if(row < i32(uniforms.dim_b_outer) && col < i32(uniforms.dim_inner)) {\n";
   } else {
-    shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_b_outer)) {\n";
+    shader.AdditionalImplementation() << "  if(row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_b_outer)) {\n";
   }
-  shader.AdditionalImplementation() << " var b_indices: b_indices_t;\n"
+  shader.AdditionalImplementation() << "    var b_indices: b_indices_t;\n"
                                     << ConvertOutputBatchIndicesToInputBatchIndices("b", b, b.Rank() - 2, batch_dims ? batch_dims->Rank() : 0, "batch_indices")
-                                    << b.IndicesSet("b_indices", b.Rank() - 2, "u32(row)") << "\n"
-                                    << b.IndicesSet("b_indices", b.Rank() - 1, "u32(colIn)") << "\n"
-                                    << " value = " << b.GetByIndices("b_indices") << ";\n"
-                                    << " }\n"
-                                    << " return value;\n"
+                                    << "    " << b.IndicesSet("b_indices", b.Rank() - 2, "u32(row)") << "\n"
+                                    << "    " << b.IndicesSet("b_indices", b.Rank() - 1, "u32(colIn)") << "\n"
+                                    << "    value = " << b.GetByIndices("b_indices") << ";\n"
+                                    << "  }\n"
+                                    << "  return value;\n"
                                     << "}\n\n";
 }
@@ -189,19 +190,19 @@ void MatMulWriteFnSource(ShaderHelper& shader,
                          const ShaderVariableHelper* bias,
                          bool is_gemm,
                          int c_components,
-                         int output_components,
                          bool c_is_scalar,
                          std::string activation_snippet,
                          bool is_channels_last,
                          bool use_split_k,
                          ProgramVariableDataType output_variable_type) {
+  const int output_components = output.NumComponents();
   shader.AdditionalImplementation()
-      << "fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: output_value_t) { \n";
+      << "fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: output_value_t) {\n";
   shader.AdditionalImplementation() << "  let col = colIn * " << output_components << ";\n";
-  shader.AdditionalImplementation() << "if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_b_outer)) { \n"
-                                    << "  var value = valueIn; \n";
+  shader.AdditionalImplementation() << "  if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_b_outer)) {\n"
+                                    << "    var value = valueIn;\n";
 
   if (use_split_k) {
     // Set output when MatMul is performed with Split-K.
diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.h b/onnxruntime/core/providers/webgpu/math/gemm_utils.h
index e001544f9e50d..49c3fbb8640a5 100644
--- a/onnxruntime/core/providers/webgpu/math/gemm_utils.h
+++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.h
@@ -13,15 +13,13 @@ void MatMulReadFnSource(ShaderHelper& shader,
                         const ShaderVariableHelper& b,
                         const ShaderIndicesHelper* batch_dims,
                         bool transA,
-                        bool transB,
-                        bool is_vec4);
+                        bool transB);
 
 void MatMulWriteFnSource(ShaderHelper& shader,
                          const ShaderVariableHelper& output,
                          const ShaderVariableHelper* bias,
                          bool is_gemm,
                          int c_components,
-                         int output_components,
                          bool c_is_scalar,
                          std::string activation_snippet = "",
                          bool is_channels_last = false,
diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc
index ba32365bf9d88..b9afbc9bfecab 100644
--- a/onnxruntime/core/providers/webgpu/math/matmul.cc
+++ b/onnxruntime/core/providers/webgpu/math/matmul.cc
@@ -8,6 +8,7 @@
 #include "core/providers/webgpu/webgpu_supported_types.h"
 #include "core/providers/webgpu/nn/fuse_utils.h"
 #include "core/providers/webgpu/data_transfer.h"
+#include "core/providers/webgpu/vendor/intel/math/matmul.h"
 #include "core/providers/webgpu/webgpu_utils.h"
 
 namespace onnxruntime {
@@ -163,6 +164,10 @@ Status MatMul::ComputeInternal(ComputeContext& context) const {
     inputs.push_back(bias);
   }
 
+  if (intel::CanApplyMatMulIntel(context, helper.M(), helper.N(), helper.K())) {
+    return intel::ApplyMatMulIntel(context, Activation(), inputs, output_tensor);
+  }
+
   return ComputeMatMul(&context, Activation(), inputs, output_tensor, false);
 }
diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc
index e97e0fd6f1058..fb137f4755ed9 100644
--- a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc
+++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc
@@ -33,8 +33,8 @@ Status MatMulProgram::GenerateShaderCode(ShaderHelper& shader) const {
   std::string apply_activation = GetActivationSnippet(activation_, "output_value_t", "output_element_t");
   ProgramVariableDataType output_var_type = this->Outputs()[0].var_type;
   // declare the read and write functions
-  MatMulReadFnSource(shader, a, b, &batch_dims, /*transA = */ false, /*transB = */ false, is_vec4_);
-  MatMulWriteFnSource(shader, output, bias, /* is_gemm = */ false, 1, is_vec4_ ? 4 : 1, false, apply_activation, is_channels_last_, need_split_k, output_var_type);
+  MatMulReadFnSource(shader, a, b, &batch_dims, /*transA = */ false, /*transB = */ false);
+  MatMulWriteFnSource(shader, output, bias, /* is_gemm = */ false, 1, false, apply_activation, is_channels_last_, need_split_k, output_var_type);
   std::string data_type = "a_element_t";
   // generate the main function
   if (is_vec4_) {
@@ -65,7 +65,7 @@ Status MatMulFillBiasOrZeroBeforeSplitKProgram::GenerateShaderCode(ShaderHelper&
   // `use_split_k` is true only when we do the actual MatMul with Split-K.
   const uint32_t bias_components = output_components_;
   MatMulWriteFnSource(
-      shader, output, bias, is_gemm_, bias_components, output_components_, bias_is_scalar_,
+      shader, output, bias, is_gemm_, bias_components, bias_is_scalar_,
       /*activation_snippet*/ "", /*is_channels_last*/ true, /*use_split_k*/ false);
 
   shader.MainFunctionBody() << " let output_components = " << output_components_ << ";\n";
diff --git a/onnxruntime/core/providers/webgpu/vendor/intel/math/gemm.cc b/onnxruntime/core/providers/webgpu/vendor/intel/math/gemm.cc
new file mode 100644
index 0000000000000..699487b4c2270
--- /dev/null
+++ b/onnxruntime/core/providers/webgpu/vendor/intel/math/gemm.cc
@@ -0,0 +1,121 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
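+//
+// GEMM for Intel GPUs using subgroup intrinsics. Each workgroup runs
+// kSubgroupLogicalWorkGroupSizeX * kSubgroupLogicalWorkGroupSizeY = 256
+// threads and covers an output tile of 32 * elements_per_thread[0] = 128
+// columns by 8 * elements_per_thread[1] rows. Worked example of the dispatch
+// arithmetic below: for M = 512, N = 1024 with is_vec4 true,
+// ElementsPerThreadY(true, 512) = 8, so dispatch_x = ceil(1024 / 128) = 8
+// workgroups in X and dispatch_y = ceil(512 / (8 * 8)) = 8 in Y.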
+ +#include "core/providers/webgpu/vendor/intel/math/gemm.h" +#include "core/providers/webgpu/vendor/intel/math/gemm_subgroup.h" +#include "core/providers/webgpu/math/gemm_utils.h" + +namespace onnxruntime { +namespace webgpu { +namespace intel { + +Status GemmSubgroupProgram::GenerateShaderCode(ShaderHelper& shader) const { + const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | + ShaderUsage::UseValueTypeAlias | + ShaderUsage::UseElementTypeAlias); + + if (need_handle_matmul_) { + const auto& a = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | + ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto& b = shader.AddInput("b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | + ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + + MatMulReadFnSource(shader, a, b, nullptr, transA_, transB_); + } + + ORT_RETURN_IF_ERROR(MakeMatMulSubgroupSource(shader, elements_per_thread_, nullptr, is_vec4_, transA_, transB_, + alpha_, need_handle_matmul_)); + const ShaderVariableHelper* c = nullptr; + if (need_handle_bias_) { + c = &shader.AddInput("c", ShaderUsage::UseUniform); + } + MatMulWriteFnSource(shader, output, c, true, c_components_, c_is_scalar_); + + return Status::OK(); +} + +bool CanApplyGemmIntel(const ComputeContext& context, int64_t M, int64_t N, int64_t K, bool transA, bool transB) { + return CanApplySubgroup(context, M, N, K, transA, transB); +} + +Status ApplyGemmIntel(const Tensor* a, + const Tensor* b, + const Tensor* c, + bool transA, + bool transB, + float alpha, + float beta, + ComputeContext& context) { + const auto& a_shape = a->Shape(); + const auto& b_shape = b->Shape(); + + uint32_t M = onnxruntime::narrow(transA ? a_shape[1] : a_shape[0]); + uint32_t K = onnxruntime::narrow(transA ? a_shape[0] : a_shape[1]); + uint32_t N = onnxruntime::narrow(transB ? b_shape[0] : b_shape[1]); + + std::vector output_dims{M, N}; + auto* y = context.Output(0, output_dims); + int64_t output_size = y->Shape().Size(); + + if (output_size == 0) { + return Status::OK(); + } + + // WebGPU doesn't support binding a zero-sized buffer, so we need to check if A or B is empty. + bool need_handle_matmul = a_shape.Size() > 0 && b_shape.Size() > 0; + bool need_handle_bias = c && beta; + + const bool is_vec4 = b_shape[1] % 4 == 0; + // Components for A, B + int a_components = 1; + int b_components = is_vec4 ? 4 : 1; + // Components for Y + int output_components = (is_vec4 && N % 4 == 0) ? 4 : 1; + // Components for C. + int c_components = 1; + + bool c_is_scalar = false; + if (need_handle_bias) { + const auto& c_shape = c->Shape(); + int64_t c_last_dim = c_shape[c_shape.NumDimensions() - 1]; + // `C` in GEMM might be broadcast to the output, and broadcasting requires the components to be consistent. + // So we use vec4 for C when its last dimension is N, and the output is also a vec4. + c_components = (c_last_dim == N && output_components == 4) ? 
4 : 1; + c_is_scalar = c_shape.Size() == 1; + } + + InlinedVector elements_per_thread = InlinedVector({4, intel::ElementsPerThreadY(is_vec4, M), 1}); + const uint32_t dispatch_x = narrow((N + kSubgroupLogicalWorkGroupSizeX * elements_per_thread[0] - 1) / + (kSubgroupLogicalWorkGroupSizeX * elements_per_thread[0])); + const uint32_t dispatch_y = narrow((M + kSubgroupLogicalWorkGroupSizeY * elements_per_thread[1] - 1) / + (kSubgroupLogicalWorkGroupSizeY * elements_per_thread[1])); + + GemmSubgroupProgram program{transA, transB, alpha, need_handle_bias, need_handle_matmul, c_components, c_is_scalar, + is_vec4, elements_per_thread}; + + if (need_handle_matmul) { + program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_components}, + {b, ProgramTensorMetadataDependency::TypeAndRank, b_components}}); + } + + if (need_handle_bias) { + program.AddInput({c, ProgramTensorMetadataDependency::TypeAndRank, c_components}); + } + + program.CacheHint(alpha, transA, transB, c_is_scalar, absl::StrJoin(elements_per_thread, "-")) + .AddOutputs({{y, ProgramTensorMetadataDependency::TypeAndRank, output_components}}) + .SetDispatchGroupSize(dispatch_x, dispatch_y, 1) + .SetWorkgroupSize(kSubgroupLogicalWorkGroupSizeX * kSubgroupLogicalWorkGroupSizeY, 1, 1) + .AddUniformVariables({{alpha}, + {beta}, + {M}, /* dim_a_outer */ + {N}, /* dim_b_outer */ + {K}} /*dim_inner */ + ); + + return context.RunProgram(program); +} + +} // namespace intel +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/vendor/intel/math/gemm.h b/onnxruntime/core/providers/webgpu/vendor/intel/math/gemm.h new file mode 100644 index 0000000000000..1e6ac6a7e7514 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/vendor/intel/math/gemm.h @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
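+//
+// Declares the subgroup GEMM program and its entry points. A minimal sketch of
+// the intended call pattern, mirroring the gemm.cc hunk above (`context`, `A`,
+// `B`, `C`, `M`, `N`, `K` and the attributes come from the calling kernel):
+//
+//   if (intel::CanApplyGemmIntel(context, M, N, K, transA_, transB_)) {
+//     return intel::ApplyGemmIntel(A, B, C, transA_, transB_, alpha_, beta_, context);
+//   }
+//   // otherwise fall back to ApplyGemmPacked(...)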
+
+#pragma once
+
+#include "core/providers/webgpu/webgpu_kernel.h"
+#include "core/providers/webgpu/shader_helper.h"
+#include "core/providers/webgpu/program.h"
+
+namespace onnxruntime {
+namespace webgpu {
+namespace intel {
+
+class GemmSubgroupProgram final : public Program<GemmSubgroupProgram> {
+ public:
+  GemmSubgroupProgram(bool transA, bool transB, float alpha, bool need_handle_bias, bool need_handle_matmul,
+                      int c_components, bool c_is_scalar, bool is_vec4,
+                      const gsl::span<const int64_t>& elements_per_thread)
+      : Program{"GemmSubgroup"},
+        transA_{transA},
+        transB_{transB},
+        alpha_{alpha},
+        need_handle_bias_{need_handle_bias},
+        need_handle_matmul_{need_handle_matmul},
+        c_components_(c_components),
+        c_is_scalar_(c_is_scalar),
+        is_vec4_(is_vec4),
+        elements_per_thread_(elements_per_thread.begin(), elements_per_thread.end()) {}
+
+  Status GenerateShaderCode(ShaderHelper& sh) const override;
+
+  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
+      {"alpha", ProgramUniformVariableDataType::Float32},
+      {"beta", ProgramUniformVariableDataType::Float32},
+      {"dim_a_outer", ProgramUniformVariableDataType::Uint32},
+      {"dim_b_outer", ProgramUniformVariableDataType::Uint32},
+      {"dim_inner", ProgramUniformVariableDataType::Uint32});
+
+ private:
+  bool transA_;
+  bool transB_;
+  float alpha_;
+  bool need_handle_bias_;
+  bool need_handle_matmul_;
+  int c_components_;
+  bool c_is_scalar_ = false;
+  bool is_vec4_ = false;
+  const InlinedVector<int64_t> elements_per_thread_;
+};
+
+bool CanApplyGemmIntel(const ComputeContext& context, int64_t M, int64_t N, int64_t K, bool transA, bool transB);
+
+Status ApplyGemmIntel(const Tensor* a,
+                      const Tensor* b,
+                      const Tensor* c,
+                      bool transA,
+                      bool transB,
+                      float alpha,
+                      float beta,
+                      ComputeContext& context);
+
+}  // namespace intel
+}  // namespace webgpu
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webgpu/vendor/intel/math/gemm_subgroup.cc b/onnxruntime/core/providers/webgpu/vendor/intel/math/gemm_subgroup.cc
new file mode 100644
index 0000000000000..a6baf8dfb0239
--- /dev/null
+++ b/onnxruntime/core/providers/webgpu/vendor/intel/math/gemm_subgroup.cc
@@ -0,0 +1,183 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/webgpu/webgpu_utils.h"
+#include "core/providers/webgpu/compute_context.h"
+#include "core/providers/webgpu/string_macros.h"
+#include "core/providers/webgpu/vendor/intel/math/gemm_subgroup.h"
+
+namespace onnxruntime {
+namespace webgpu {
+namespace intel {
+
+namespace {
+
+// Load one element of A per output row handled by this thread.
+std::string LoadAStr(const ShaderIndicesHelper* batch_dims, int64_t elements_per_thread_y) {
+  SS(load_a_ss, 128);
+  for (int64_t i = 0; i < elements_per_thread_y; i++) {
+    load_a_ss << "      a_val_" << i << " = mm_readA(batch, globalRowStart + " << i << ", aCol"
+              << (batch_dims ? ", batchIndices" : "") << ");\n";
+  }
+  return SS_GET(load_a_ss);
+}
+
+// Load one tile of B into local memory.
+std::string LoadBStr(const ShaderIndicesHelper* batch_dims, int64_t tile_b_outer, bool is_vec4) {
+  SS(load_b_ss, 256);
+  load_b_ss << "    let loadRowsPerThread = " << kSubgroupLogicalWorkGroupSizeX / kSubgroupLogicalWorkGroupSizeY << ";\n"
+            << "    for (var innerRow = 0; innerRow < loadRowsPerThread; innerRow++) {\n"
+            << "      let inputRow = loadRowsPerThread * localRow + innerRow;\n"
+            << "      let inputCol = tileCol;\n";
+  if (is_vec4) {
+    load_b_ss << "      mm_Bsub[inputRow][inputCol] = mm_readB(batch, kStart + inputRow, globalColStart"
+              << (batch_dims ? ", batchIndices" : "") << ");\n";
+  } else {
+    for (int j = 0; j < tile_b_outer; j += kSubgroupLogicalWorkGroupSizeX) {
+      load_b_ss << "      mm_Bsub[inputRow][inputCol + " << j << "] = mm_readB(batch, kStart + inputRow, globalColStart + "
+                << j << (batch_dims ? ", batchIndices" : "") << ");\n";
+    }
+  }
+  load_b_ss << "    }\n"
+            << "    workgroupBarrier();\n";
+
+  return SS_GET(load_b_ss);
+}
+
+std::string LoadBCacheStr(bool is_vec4, uint32_t offset) {
+  SS(b_cache_ss, 256);
+  if (is_vec4) {
+    b_cache_ss << "BCache = mm_Bsub[" << offset << "][tileCol];\n";
+  } else {
+    b_cache_ss << "BCache = vec4(mm_Bsub[" << offset << "][tileCol], "
+               << "mm_Bsub[" << offset << "][tileCol + " << kSubgroupLogicalWorkGroupSizeX << "], "
+               << "mm_Bsub[" << offset << "][tileCol + " << 2 * kSubgroupLogicalWorkGroupSizeX << "], "
+               << "mm_Bsub[" << offset << "][tileCol + " << 3 * kSubgroupLogicalWorkGroupSizeX << "]);\n";
+  }
+  return SS_GET(b_cache_ss);
+}
+
+std::string CalculateAccStr(const ShaderIndicesHelper* batch_dims, int64_t elements_per_thread_y, bool is_vec4) {
+  SS(cal_acc_ss, 1024);
+
+  // key: simd size; value: the offset rows of mm_Bsub.
+  std::map<uint32_t, std::vector<uint32_t>> simd_map = {
+      {32, {0}},
+      {16, {0, 16}},
+      {8, {0, 8, 16, 24}}};
+  for (const auto& [simd, offsets] : simd_map) {
+    cal_acc_ss << "    if (sg_size == " << simd << ") {\n";
+    for (uint32_t offset : offsets) {
+      cal_acc_ss << LoadAStr(batch_dims, elements_per_thread_y)
+                 << "      aCol += " << simd << ";\n";
+      for (uint32_t sg_idx = 0; sg_idx < simd; sg_idx++) {
+        cal_acc_ss << "      " << LoadBCacheStr(is_vec4, sg_idx + offset);
+        for (uint32_t i = 0; i < elements_per_thread_y; i++) {
+          cal_acc_ss << "      acc_" << i << " += subgroupBroadcast(a_val_" << i << ", " << sg_idx << ") * BCache;\n";
+        }
+      }
+    }
+    cal_acc_ss << "    }\n";
+  }
+
+  return SS_GET(cal_acc_ss);
+}
+
+}  // namespace
+
+bool CanApplySubgroup(const ComputeContext& context, int64_t M, int64_t N, int64_t K, bool transA, bool transB) {
+  if (context.AdapterInfo().vendor == std::string_view{"intel"}) {
+    bool use_subgroup = context.HasFeature(wgpu::FeatureName::Subgroups) &&
+                        M >= 64 && N >= 512 && K >= 32 && !transA && !transB;
+    return use_subgroup;
+  }
+
+  return false;
+}
+
+int64_t ElementsPerThreadY(bool is_vec4, uint32_t M) {
+  return is_vec4 ? (M <= 8 ? 1 : (M <= 16 ? 2 : (M <= 32 ? 4 : 8))) : 4;
+}
+
+Status MakeMatMulSubgroupSource(ShaderHelper& shader,
+                                const InlinedVector<int64_t>& elements_per_thread,
+                                const ShaderIndicesHelper* batch_dims,
+                                bool is_vec4,
+                                bool transpose_a,
+                                bool transpose_b,
+                                float alpha,
+                                bool need_handle_matmul) {
+  ORT_UNUSED_PARAMETER(transpose_a);
+  ORT_UNUSED_PARAMETER(transpose_b);
+
+  // elements per thread
+  const auto elements_per_thread_x = elements_per_thread[0];
+  const auto elements_per_thread_y = elements_per_thread[1];
+
+  const auto tile_a_outer = kSubgroupLogicalWorkGroupSizeY * elements_per_thread_y;
+  const auto tile_b_outer = kSubgroupLogicalWorkGroupSizeX * elements_per_thread_x;
+
+  shader.AdditionalImplementation()
+      << "var<workgroup> mm_Bsub: array<array<b_value_t, "
+      << (is_vec4 ? tile_b_outer / elements_per_thread_x : tile_b_outer) << ">, 32>;\n";
+
+  shader.MainFunctionBody()
+      << "  let workgroupIdXStride = (uniforms.dim_b_outer - 1) / " << tile_b_outer << " + 1;\n"
+      << "  let workgroupIdYStride = (uniforms.dim_a_outer - 1) / " << tile_a_outer << " + 1;\n"
+      << "  let batch = i32(workgroup_idx / (workgroupIdXStride * workgroupIdYStride));\n"
+      << "  let workgroupIdXY = workgroup_idx % (workgroupIdXStride * workgroupIdYStride);\n"
+      << "  let workgroupIdX = workgroupIdXY % workgroupIdXStride;\n"
+      << "  let workgroupIdY = workgroupIdXY / workgroupIdXStride;\n"
+      << "  let tileRow = i32(local_id.x / " << kSubgroupLogicalWorkGroupSizeX << ") * " << elements_per_thread_y << ";\n"
+      << "  let tileCol = i32(local_id.x % " << kSubgroupLogicalWorkGroupSizeX << ");\n"
+      << "  let localRow = i32(local_id.x / " << kSubgroupLogicalWorkGroupSizeX << ");\n"
+      << (nullptr != batch_dims ? "  let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "")
+      << "  let globalRowStart = i32(workgroupIdY) * " << tile_a_outer << " + tileRow;\n"
+      << "  let globalColStart = i32(workgroupIdX) * " << (is_vec4 ? tile_b_outer / elements_per_thread_x : tile_b_outer) << " + tileCol;\n"
+      << "  let numTiles = (uniforms.dim_inner - 1) / 32 + 1;\n"
+      << "  var kStart = 0;\n"
+      << "  var aCol = 0;\n"
+      << "  var BCache: vec4<output_element_t>;\n";
+
+  for (uint32_t i = 0; i < elements_per_thread_y; i++) {
+    shader.MainFunctionBody() << "  var acc_" << i << " = vec4<output_element_t>(0);\n"
+                              << "  var a_val_" << i << " = a_value_t(0);\n";
+  }
+
+  if (need_handle_matmul) {
+    shader.MainFunctionBody() << "  for (var t = 0; t < i32(numTiles); t++) {\n"
+                              << LoadBStr(batch_dims, tile_b_outer, is_vec4)
+                              << "    aCol = kStart + tileCol % i32(sg_size);\n"
+                              << CalculateAccStr(batch_dims, elements_per_thread_y, is_vec4)
+                              << "    kStart = kStart + 32;\n"
+                              << "    workgroupBarrier();\n"
+                              << "  }\n";  // main for loop
+
+    // Calculate alpha * acc
+    if (alpha != 1.0f) {
+      for (uint32_t i = 0; i < elements_per_thread_y; i++) {
+        shader.MainFunctionBody() << "  acc_" << i << " *= output_element_t(uniforms.alpha);\n";
+      }
+    }
+  }
+
+  // Write the results to the output buffer
+  if (is_vec4) {
+    for (uint32_t i = 0; i < elements_per_thread_y; i++) {
+      shader.MainFunctionBody() << "  mm_write(batch, globalRowStart + " << i
+                                << ", globalColStart, acc_" << i << ");\n";
+    }
+  } else {
+    for (uint32_t i = 0; i < elements_per_thread_y; i++) {
+      for (uint32_t j = 0; j < elements_per_thread_x; j++) {
+        shader.MainFunctionBody() << "  "
+                                  << "mm_write(batch, globalRowStart + " << i << ", globalColStart + "
+                                  << j * kSubgroupLogicalWorkGroupSizeX << ", acc_" << i << "[" << j << "]);\n";
+      }
+    }
+  }
+
+  return Status::OK();
+}
+
+}  // namespace intel
+}  // namespace webgpu
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webgpu/vendor/intel/math/gemm_subgroup.h b/onnxruntime/core/providers/webgpu/vendor/intel/math/gemm_subgroup.h
new file mode 100644
index 0000000000000..89dca023d3e1b
--- /dev/null
+++ b/onnxruntime/core/providers/webgpu/vendor/intel/math/gemm_subgroup.h
@@ -0,0 +1,31 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/providers/webgpu/shader_helper.h"
+
+namespace onnxruntime {
+namespace webgpu {
+namespace intel {
+
+const uint32_t kSubgroupLogicalWorkGroupSizeX = 32;
+const uint32_t kSubgroupLogicalWorkGroupSizeY = 8;
+const uint32_t kSubgroupLogicalWorkGroupSizeZ = 1;
+
+bool CanApplySubgroup(const ComputeContext& context, int64_t M, int64_t N, int64_t K, bool transA = false, bool transB = false);
+
+int64_t ElementsPerThreadY(bool is_vec4, uint32_t M);
+
+Status MakeMatMulSubgroupSource(ShaderHelper& shader,
+                                const InlinedVector<int64_t>& elements_per_thread,
+                                const ShaderIndicesHelper* batch_dims,
+                                bool is_vec4,
+                                bool transpose_a = false,
+                                bool transpose_b = false,
+                                float alpha = 1.0f,
+                                bool need_handle_matmul = true);
+
+}  // namespace intel
+}  // namespace webgpu
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webgpu/vendor/intel/math/matmul.cc b/onnxruntime/core/providers/webgpu/vendor/intel/math/matmul.cc
new file mode 100644
index 0000000000000..20874522daa20
--- /dev/null
+++ b/onnxruntime/core/providers/webgpu/vendor/intel/math/matmul.cc
@@ -0,0 +1,135 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
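+//
+// MatMul for Intel GPUs, built on the same subgroup kernel as gemm.cc. Input
+// shapes are normalized to 3-D [batch, M, K] x [batch, K, N] first; a batch of
+// row vectors times a single matrix (batchA != 1, M == 1, batchB == 1) is
+// folded into one [1, batchA, K] x [1, K, N] multiply, so the whole batch is
+// tiled as a single [batchA, N] output instead of batchA one-row multiplies.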
+ +#include "core/providers/cpu/math/matmul_helper.h" +#include "core/providers/webgpu/webgpu_utils.h" +#include "core/providers/webgpu/math/matmul_utils.h" +#include "core/providers/webgpu/vendor/intel/math/gemm_subgroup.h" +#include "core/providers/webgpu/math/gemm_utils.h" +#include "core/providers/webgpu/vendor/intel/math/matmul.h" + +namespace onnxruntime { +namespace webgpu { +namespace intel { + +Status MatMulSubgroupProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& a = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | + ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto& b = shader.AddInput("b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | + ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | + ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto& batch_dims = shader.AddIndices("batch_dims", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + + const ShaderVariableHelper* bias = nullptr; + if (has_bias_) { + bias = &shader.AddInput("bias", ShaderUsage::UseUniform); + } + std::string apply_activation = GetActivationSnippet(activation_, "output_value_t", "output_element_t"); + // declare the read and write functions + MatMulReadFnSource(shader, a, b, &batch_dims, /*transA = */ false, /*transB = */ false); + MatMulWriteFnSource(shader, output, bias, /* is_gemm = */ false, 1, + false, apply_activation, /*is_channels_last = */ false); + // generate the main function + ORT_RETURN_IF_ERROR(MakeMatMulSubgroupSource(shader, elements_per_thread_, &batch_dims, is_vec4_)); + return Status::OK(); +} + +bool CanApplyMatMulIntel(const ComputeContext& context, int64_t M, int64_t N, int64_t K) { + return CanApplySubgroup(context, M, N, K); +} + +Status ApplyMatMulIntel(ComputeContext& context, + const Activation& activation, + std::vector& inputs, + Tensor* output) { + const auto* a = inputs[0]; + const auto* b = inputs[1]; + bool has_bias = inputs.size() > 2; + TensorShape a_shape = a->Shape(); + TensorShape b_shape = b->Shape(); + + MatMulComputeHelper helper; + ORT_THROW_IF_ERROR(helper.Compute(a_shape, b_shape)); + int64_t batchA = a_shape.SizeToDimension(a_shape.NumDimensions() - 2); + int64_t batchB = b_shape.SizeToDimension(b_shape.NumDimensions() - 2); + + TensorShape output_shape = helper.OutputShape(); + + const int64_t dim_output_outer = output_shape[output_shape.NumDimensions() - 2]; + // check if A is batch of vector (bach is not 1, M is 1) and B is a matrix (batch is 1) + if (batchA != 1 && dim_output_outer == 1 && batchB == 1) { + // optimization for batched vector matrix multiplication + // dimensions of A: [1,`batchA`,K] + TensorShapeVector dims_a = {1, batchA, helper.K()}; + // dimensions of B: [1,K,N] + TensorShapeVector dims_b = {1, helper.K(), helper.N()}; + + a_shape = TensorShape(dims_a); + b_shape = TensorShape(dims_b); + output_shape = {1, batchA, helper.N()}; + } + + // helpful dimension variables + TensorShape outer_dims_a = a_shape.NumDimensions() > 2 + ? a_shape.Slice(0, a_shape.NumDimensions() - 2) + : TensorShape({}); + + TensorShape outer_dims_b = b_shape.NumDimensions() > 2 + ? b_shape.Slice(0, b_shape.NumDimensions() - 2) + : TensorShape({}); + + TensorShape outer_dims = output_shape.NumDimensions() > 2 + ? 
output_shape.Slice(0, output_shape.NumDimensions() - 2) + : TensorShape({}); + + const int64_t batch_size = outer_dims.Size(); + + // Get dimensions for matrix multiplication from TensorShape + const uint32_t dim_a_outer = narrow(a_shape[a_shape.NumDimensions() - 2]); // left matrix second dimension + const uint32_t dim_inner = narrow(a_shape[a_shape.NumDimensions() - 1]); // left matrix first dimension + const uint32_t dim_b_outer = narrow(b_shape[b_shape.NumDimensions() - 1]); // right matrix first dimension + + // Always access A with 1-component when using subgroup. + const bool is_vec4 = dim_b_outer % 4 == 0; + InlinedVector elements_per_thread = InlinedVector({4, intel::ElementsPerThreadY(is_vec4, dim_a_outer), 1}); + + const uint32_t dispatch_x = narrow((dim_b_outer + kSubgroupLogicalWorkGroupSizeX * elements_per_thread[0] - 1) / + (kSubgroupLogicalWorkGroupSizeX * elements_per_thread[0])); + const uint32_t dispatch_y = narrow((dim_a_outer + kSubgroupLogicalWorkGroupSizeY * elements_per_thread[1] - 1) / + (kSubgroupLogicalWorkGroupSizeY * elements_per_thread[1])); + const uint32_t dispatch_z = narrow((static_cast(batch_size) + + kSubgroupLogicalWorkGroupSizeZ * elements_per_thread[2] - 1) / + (kSubgroupLogicalWorkGroupSizeZ * elements_per_thread[2])); + + const int components = is_vec4 ? 4 : 1; + const int a_components = 1; + const int b_components = components; + const TensorShape a_shape_temp = CreateMatMulIntermediateShape(outer_dims_a, dim_a_outer, dim_inner, a_components); + const TensorShape b_shape_temp = CreateMatMulIntermediateShape(outer_dims_b, dim_inner, dim_b_outer, b_components); + const TensorShape output_shape_temp = TensorShape({batch_size, dim_a_outer, dim_b_outer / components}); + + MatMulSubgroupProgram program{activation, has_bias, is_vec4, elements_per_thread}; + program + .CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-")) + .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_shape_temp, a_components}, + {b, ProgramTensorMetadataDependency::TypeAndRank, b_shape_temp, b_components}}) + .AddOutputs({{output, ProgramTensorMetadataDependency::Rank, output_shape_temp, components}}) + .AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}}) + .AddIndices(outer_dims) + .SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z) + .SetWorkgroupSize(kSubgroupLogicalWorkGroupSizeX * kSubgroupLogicalWorkGroupSizeY, 1, 1); + + if (has_bias) { + auto bias_components = 1; + const auto* bias = inputs[2]; + TensorShape reduced_bias_shape = ReduceShapeByComponents(bias->Shape(), bias_components); + program.AddInput({bias, ProgramTensorMetadataDependency::Rank, reduced_bias_shape, bias_components}); + } + + return context.RunProgram(program); +} + +} // namespace intel +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/vendor/intel/math/matmul.h b/onnxruntime/core/providers/webgpu/vendor/intel/math/matmul.h new file mode 100644 index 0000000000000..2a8333e3e912b --- /dev/null +++ b/onnxruntime/core/providers/webgpu/vendor/intel/math/matmul.h @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/nn/fuse_utils.h" + +namespace onnxruntime { +namespace webgpu { +namespace intel { + +class MatMulSubgroupProgram final : public Program { + public: + MatMulSubgroupProgram(const Activation& activation, + bool bias, + bool is_vec4, + const gsl::span& elements_per_thread) + : Program{"MatMulSubgroup"}, + activation_(activation), + has_bias_{bias}, + is_vec4_{is_vec4}, + elements_per_thread_(elements_per_thread.begin(), elements_per_thread.end()) {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"dim_a_outer", ProgramUniformVariableDataType::Uint32}, + {"dim_b_outer", ProgramUniformVariableDataType::Uint32}, + {"dim_inner", ProgramUniformVariableDataType::Uint32}); + + private: + const Activation activation_; + const bool has_bias_; + const bool is_vec4_; + const InlinedVector elements_per_thread_; +}; + +bool CanApplyMatMulIntel(const ComputeContext& context, int64_t M, int64_t N, int64_t K); + +Status ApplyMatMulIntel(ComputeContext& context, + const Activation& activation, + std::vector& inputs, + Tensor* output); + +} // namespace intel +} // namespace webgpu +} // namespace onnxruntime
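+
+// Illustration of the accumulation emitted by CalculateAccStr in
+// gemm_subgroup.cc, shown for the vec4 path. Each lane holds one A element
+// (lane L loads A[row][kStart + L]) and the lanes reuse each other's value
+// via subgroupBroadcast, e.g.:
+//
+//   BCache = mm_Bsub[0][tileCol];                      // B tile row 0
+//   acc_0 += subgroupBroadcast(a_val_0, 0) * BCache;   // lane 0's A element
+//   BCache = mm_Bsub[1][tileCol];                      // B tile row 1
+//   acc_0 += subgroupBroadcast(a_val_0, 1) * BCache;   // lane 1's A element
+//   ...
+//
+// so a single mm_readA per lane feeds sg_size multiply-accumulates without
+// staging A through workgroup memory; only the B tile lives in mm_Bsub.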