diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 122cb5396f5c1..755bd0c60452f 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -656,7 +656,7 @@ Status PrepareQKV(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtte std::vector matmul_inputs = {input, weights, bias}; // Call MatMul: packed_qkv = input * weights + bias - ORT_RETURN_IF_ERROR(onnxruntime::webgpu::ComputeMatMul(&context, Activation(), matmul_inputs, &packed_qkv, true)); + ORT_RETURN_IF_ERROR(onnxruntime::webgpu::ComputeMatMul(&context, Activation(), matmul_inputs, &packed_qkv)); // Output Q, K, V in BSD format return SplitPackedQKV(context, parameters, &packed_qkv, q, k, v, parameters.hidden_size_); diff --git a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc index 45c462209e30b..79a4f1f73902b 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc @@ -113,8 +113,8 @@ Status ApplyGemmPacked(const Tensor* a, const SplitKConfig& split_k_config = context.GetSplitKConfig(); // Currently we require the components for Y must also be a multiple of 4 when Split-K is used. const bool output_is_vec4 = output_components == 4; - // The parameter `is_channel_last` is not used for GEMM. - const bool need_split_k = split_k_config.UseSplitK(is_vec4 && output_is_vec4, ActivationKind::None, /*batch_size*/ 1, /*is_gemm*/ true, /*is_channels_last*/ true, M, N, K); + // We need to use `true` as `is_channels_last` to meet the requirement in `UseSplitK`. + const bool need_split_k = split_k_config.UseSplitK(is_vec4 && output_is_vec4, ActivationKind::None, /*batch_size*/ 1, M, N, K); if (need_split_k) { const Tensor* bias = nullptr; uint32_t output_components_in_fill_bias_program = 4; diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc index 7bbc84d2c3c87..af488f2c23a30 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul.cc @@ -168,7 +168,7 @@ Status MatMul::ComputeInternal(ComputeContext& context) const { return intel::ApplyMatMulIntel(context, Activation(), inputs, output_tensor); } - return ComputeMatMul(&context, Activation(), inputs, output_tensor, false); + return ComputeMatMul(&context, Activation(), inputs, output_tensor); } Status ComputeMatMul(ComputeContext* context, @@ -248,11 +248,14 @@ Status ComputeMatMul(ComputeContext* context, // Current Split-K implementation relies on atomic operations, which are not deterministic. if (!context->KernelContext().GetUseDeterministicCompute()) { const SplitKConfig& split_k_config = context->GetSplitKConfig(); - const bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, /*is_gemm*/ false, is_channels_last, dim_a_outer, dim_b_outer, dim_inner); + const bool need_split_k = split_k_config.UseSplitK(is_vec4, activation.activation_kind_, batch_size, dim_a_outer, dim_b_outer, dim_inner, is_channels_last); if (need_split_k) { ORT_ENFORCE(batch_size == 1, "Split-K MatMul only supports batch_size == 1."); - ORT_ENFORCE(is_vec4, "Split-K MatMul only supports bias in vec4 format."); - ORT_ENFORCE(is_channels_last, "Split-K MatMul only supports channels-last format."); + ORT_ENFORCE(is_vec4, "Split-K MatMul requires vec4 packing."); + + if (has_bias) { + ORT_ENFORCE(is_channels_last, "Split-K MatMul only supports channels-last format."); + } // Initialize `output_tensor` with 0 or bias before MatMulProgram with Split-K enabled. const auto fill_bias_program = CreateMatMulFillBiasOrZeroBeforeSplitKProgram(bias, output_tensor, /*is_gemm*/ false, /*beta*/ 1.0f, /*bias_components*/ 4, output_shape_temp); diff --git a/onnxruntime/core/providers/webgpu/math/matmul.h b/onnxruntime/core/providers/webgpu/math/matmul.h index 3135c1f8a2457..d15e36ffa3d85 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.h +++ b/onnxruntime/core/providers/webgpu/math/matmul.h @@ -14,7 +14,7 @@ namespace onnxruntime { namespace webgpu { -Status ComputeMatMul(ComputeContext* context, const Activation& activation, std::vector& inputs, Tensor* output, bool is_channels_last, +Status ComputeMatMul(ComputeContext* context, const Activation& activation, std::vector& inputs, Tensor* output, bool is_channels_last = true, const TensorShape& input_a_reshape = TensorShape(), const TensorShape& input_b_reshape = TensorShape()); diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.cc b/onnxruntime/core/providers/webgpu/webgpu_utils.cc index e48c8c833c7e3..5127801ca8451 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_utils.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.cc @@ -74,11 +74,10 @@ bool SplitKConfig::UseSplitK( bool is_vec4, ActivationKind activation_kind, uint64_t batch_size, - bool is_gemm, - bool is_channels_last, uint32_t dim_a_outer, uint32_t dim_b_outer, - uint32_t dim_inner) const { + uint32_t dim_inner, + bool is_channels_last) const { if (!enable_split_k_) { return false; } @@ -89,9 +88,13 @@ bool SplitKConfig::UseSplitK( use_split_k &= activation_kind == ActivationKind::None; use_split_k &= is_vec4; use_split_k &= batch_size == 1; - // Now `is_channels_last` is only supported because we only generate vec4 shaders in - // `MatMulFillBiasOrZeroBeforeSplitKProgram` when `is_gemm` is false. - use_split_k &= (is_channels_last || is_gemm); + + // `is_channels_last` should only affect Split-K gating when bias is applied in the non-GEMM + // MatMul/Conv|MatMul path. For GEMM and for MatMul or Conv|MatMul without bias, we need to + // use `true` as `is_channels_last` to make `UseSplitK` ignore `is_channels_last`. + // When `is_channels_last` has a valid value here, it is required to be true because we only + // generate `vec4` shaders in `MatMulFillBiasOrZeroBeforeSplitKProgram`. + use_split_k &= is_channels_last; // Split-K works best when `dim_inner` is relatively large compared with `dim_a_outer` and // `dim_b_outer`. Currently we use the factor between `(dim_a_outer * dim_b_outer)` and diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.h b/onnxruntime/core/providers/webgpu/webgpu_utils.h index 6c72fd07938d5..cbceaf2be120d 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_utils.h +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.h @@ -106,9 +106,8 @@ class SplitKConfig { explicit SplitKConfig(const wgpu::AdapterInfo& adapter_info); bool UseSplitK( - bool is_vec4, ActivationKind activation_kind, uint64_t batch_size, bool is_gemm, - bool is_channels_last, uint32_t dim_a_outer, - uint32_t dim_b_outer, uint32_t dim_inner) const; + bool is_vec4, ActivationKind activation_kind, uint64_t batch_size, + uint32_t dim_a_outer, uint32_t dim_b_outer, uint32_t dim_inner, bool is_channels_last = true) const; uint32_t GetSplitDimInner() const;