Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/webgpu/bert/attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -656,7 +656,7 @@ Status PrepareQKV(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtte
std::vector<const Tensor*> matmul_inputs = {input, weights, bias};

// Call MatMul: packed_qkv = input * weights + bias
ORT_RETURN_IF_ERROR(onnxruntime::webgpu::ComputeMatMul(&context, Activation(), matmul_inputs, &packed_qkv, true));
ORT_RETURN_IF_ERROR(onnxruntime::webgpu::ComputeMatMul(&context, Activation(), matmul_inputs, &packed_qkv));

// Output Q, K, V in BSD format
return SplitPackedQKV(context, parameters, &packed_qkv, q, k, v, parameters.hidden_size_);
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/webgpu/math/gemm_packed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ Status ApplyGemmPacked(const Tensor* a,
// Currently we require the components for Y must also be a multiple of 4 when Split-K is used.
const bool output_is_vec4 = output_components == 4;
// The parameter `is_channel_last` is not used for GEMM.
const bool need_split_k = split_k_config.UseSplitK(is_vec4 && output_is_vec4, ActivationKind::None, /*batch_size*/ 1, /*is_gemm*/ true, /*is_channels_last*/ true, M, N, K);
const bool need_split_k = split_k_config.UseSplitK(is_vec4 && output_is_vec4, ActivationKind::None, /*batch_size*/ 1, need_handle_bias, /*is_gemm*/ true, /*is_channels_last*/ true, M, N, K);
if (need_split_k) {
const Tensor* bias = nullptr;
uint32_t output_components_in_fill_bias_program = 4;
Expand Down
9 changes: 6 additions & 3 deletions onnxruntime/core/providers/webgpu/math/matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ Status MatMul::ComputeInternal(ComputeContext& context) const {
return intel::ApplyMatMulIntel(context, Activation(), inputs, output_tensor);
}

return ComputeMatMul(&context, Activation(), inputs, output_tensor, false);
return ComputeMatMul(&context, Activation(), inputs, output_tensor);
}

Status ComputeMatMul(ComputeContext* context,
Expand Down Expand Up @@ -248,11 +248,14 @@ Status ComputeMatMul(ComputeContext* context,
// Current Split-K implementation relies on atomic operations, which are not deterministic.
if (!context->KernelContext().GetUseDeterministicCompute()) {
const SplitKConfig& split_k_config = context->GetSplitKConfig();
const bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, /*is_gemm*/ false, is_channels_last, dim_a_outer, dim_b_outer, dim_inner);
const bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, has_bias, /*is_gemm*/ false, is_channels_last, dim_a_outer, dim_b_outer, dim_inner);
if (need_split_k) {
ORT_ENFORCE(batch_size == 1, "Split-K MatMul only supports batch_size == 1.");
ORT_ENFORCE(is_vec4, "Split-K MatMul only supports bias in vec4 format.");
Comment thread
Jiawei-Shao marked this conversation as resolved.
Outdated
ORT_ENFORCE(is_channels_last, "Split-K MatMul only supports channels-last format.");

if (has_bias) {
ORT_ENFORCE(is_channels_last, "Split-K MatMul only supports channels-last format.");
Comment thread
Jiawei-Shao marked this conversation as resolved.
}

// Initialize `output_tensor` with 0 or bias before MatMulProgram with Split-K enabled.
const auto fill_bias_program = CreateMatMulFillBiasOrZeroBeforeSplitKProgram(bias, output_tensor, /*is_gemm*/ false, /*beta*/ 1.0f, /*bias_components*/ 4, output_shape_temp);
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/webgpu/math/matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
namespace onnxruntime {
namespace webgpu {

Status ComputeMatMul(ComputeContext* context, const Activation& activation, std::vector<const Tensor*>& inputs, Tensor* output, bool is_channels_last,
Status ComputeMatMul(ComputeContext* context, const Activation& activation, std::vector<const Tensor*>& inputs, Tensor* output, bool is_channels_last = true,
const TensorShape& input_a_reshape = TensorShape(),
const TensorShape& input_b_reshape = TensorShape());

Expand Down
12 changes: 9 additions & 3 deletions onnxruntime/core/providers/webgpu/webgpu_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ bool SplitKConfig::UseSplitK(
bool is_vec4,
ActivationKind activation_kind,
uint64_t batch_size,
bool has_bias,
bool is_gemm,
bool is_channels_last,
uint32_t dim_a_outer,
Expand All @@ -89,9 +90,14 @@ bool SplitKConfig::UseSplitK(
use_split_k &= activation_kind == ActivationKind::None;
use_split_k &= is_vec4;
use_split_k &= batch_size == 1;
// Now `is_channels_last` is only supported because we only generate vec4 shaders in
// `MatMulFillBiasOrZeroBeforeSplitKProgram` when `is_gemm` is false.
use_split_k &= (is_channels_last || is_gemm);

// Now we only need `is_channels_last` in `Conv|MatMul` with `bias`. We don't need to care about
// it in other places (`GEMM`, `MatMul` and `Conv|MatMul` without `bias`).
// When `is_channels_last` has valid value `is_channels_last` is required to be true because
// we only generate `vec4` shaders in `MatMulFillBiasOrZeroBeforeSplitKProgram`.
Comment thread
Jiawei-Shao marked this conversation as resolved.
Outdated
if (has_bias && !is_gemm) {
Comment thread
Jiawei-Shao marked this conversation as resolved.
Outdated
use_split_k &= is_channels_last;
}

// Split-K works best when `dim_inner` is relatively large compared with `dim_a_outer` and
// `dim_b_outer`. Currently we use the factor between `(dim_a_outer * dim_b_outer)` and
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/webgpu/webgpu_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ class SplitKConfig {
explicit SplitKConfig(const wgpu::AdapterInfo& adapter_info);

bool UseSplitK(
bool is_vec4, ActivationKind activation_kind, uint64_t batch_size, bool is_gemm,
bool is_channels_last, uint32_t dim_a_outer,
bool is_vec4, ActivationKind activation_kind, uint64_t batch_size, bool has_bias,
bool is_gemm, bool is_channels_last, uint32_t dim_a_outer,
uint32_t dim_b_outer, uint32_t dim_inner) const;

uint32_t GetSplitDimInner() const;
Expand Down
Loading