From 8fba36163c5d2f8aa3dde8ec999731e7aff19e49 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Fri, 17 Apr 2026 09:43:13 +0800 Subject: [PATCH 1/8] webgpu: Refactor SubgroupMatrixMatMulNBits to vendor-agnostic config + add NVIDIA 16x16x16 Refactor subgroup matrix MatMulNBits support from vendor-specific (Apple/Intel) to a vendor-agnostic config-based approach. Any GPU reporting a matching subgroup matrix config from Dawn is now automatically supported. Key changes: - Replace vendor-specific config table with SupportedSubgroupMatrixConfig struct containing {componentType, resultComponentType, M, N, K, subgroupMinSize, subgroupMaxSize, needsPrepack}. No architecture or backendType required. - Remove vendor_ member from SubgroupMatrixMatMulNBitsProgram. Shader selection is now driven by config dimensions (8x8x8, 8x16x16, 16x16x16). - Remove vendor gate in matmul_nbits.cc call site. - Rename shader templates: _apple -> _8x8x8, _intel -> _8x16x16. - Add new 16x16x16 shader template for NVIDIA Blackwell (RTX 5080). - 4 subgroups x 32 lanes = 128 threads per workgroup - 64x64 tile with 16x16 subgroup matrices - Bounds-checked output via scratch buffer for partial M tiles - Fix prepack shader OOB reads: add scalar fallback with zero-fill for partial blocks where M is not a multiple of kSgMatM. - Prioritize larger configs (16x16x16 > 8x16x16 > 8x8x8) when multiple match. Verified on NVIDIA RTX 5080 (Blackwell, Vulkan backend): - Correctness: model-qa.py with phi4-graph-prune produces identical output to D3D12 baseline - Prefill (phi4, l=1024): - D3D12 DP4A baseline: 3,006 tps - Vulkan DP4A baseline: 6,155 tps - Vulkan tensor core (this change): 6,759 tps (+10% vs Vulkan DP4A, +125% vs D3D12) - NVIDIA reports ChromiumExperimentalSubgroupMatrix with F16/F16 16x16x16 config --- .../webgpu/quantization/matmul_nbits.cc | 6 +- .../subgroup_matrix_matmul_nbits.cc | 137 ++++++---- .../subgroup_matrix_matmul_nbits.h | 6 +- ...matrix_matmul_nbits_16x16x16.wgsl.template | 252 ++++++++++++++++++ ...matrix_matmul_nbits_8x16x16.wgsl.template} | 0 ...p_matrix_matmul_nbits_8x8x8.wgsl.template} | 0 ..._matrix_matmul_nbits_prepack.wgsl.template | 27 +- 7 files changed, 369 insertions(+), 59 deletions(-) create mode 100644 onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_16x16x16.wgsl.template rename onnxruntime/contrib_ops/webgpu/quantization/{subgroup_matrix_matmul_nbits_intel.wgsl.template => subgroup_matrix_matmul_nbits_8x16x16.wgsl.template} (100%) rename onnxruntime/contrib_ops/webgpu/quantization/{subgroup_matrix_matmul_nbits_apple.wgsl.template => subgroup_matrix_matmul_nbits_8x8x8.wgsl.template} (100%) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index 4eb8f5b9a17dd..b4e3991344089 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -172,7 +172,7 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context * @return Status indicating whether the operation was successful or if an error occurred. * * @note Special optimizations are considered: - * - Subgroup matrix multiplication for eligible Apple/Intel GPUs. + * - Subgroup matrix multiplication for GPUs with supported configs. * - DP4A-based multiplication on FP32-only GPUs for specific dimensions and conditions. * - A wide tile program is used when block size, component count, and other criteria are met. 
* - Otherwise, a default matmul program is used. @@ -227,8 +227,8 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, #if !defined(__wasm__) int32_t subgroup_matrix_config_index = -1; - // apple|intel - Experimental dawn support for subgroup matrix matmul. - if ((M >= kMinMForTileOptimization && !has_weight_idx_indirect) && (context.AdapterInfo().vendor == std::string_view{"apple"} || context.AdapterInfo().vendor == std::string_view{"intel"}) && + // Experimental dawn support for subgroup matrix matmul (vendor-agnostic). + if ((M >= kMinMForTileOptimization && !has_weight_idx_indirect) && CanApplySubgroupMatrixMatMulNBits(context, accuracy_level, block_size, batch_count, N, K, static_cast(nbits), y->DataType() == DataTypeImpl::GetType(), subgroup_matrix_config_index)) { return ApplySubgroupMatrixMatMulNBits(a, b, scales, zero_points, bias, M, N, K, static_cast(nbits), zero_blocks_per_col, subgroup_matrix_config_index, context, y, weight_index, weight_index_indirect); } diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc index 15aa74c6dcd57..a381574756301 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc @@ -2,7 +2,6 @@ // Licensed under the MIT License. #if !defined(__wasm__) -#include #include "contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h" #include "contrib_ops/webgpu/quantization/matmul_nbits_common.h" @@ -45,25 +44,45 @@ static_assert(ValidateComponentTypeName<4>({wgpu::SubgroupMatrixComponentType::F wgpu::SubgroupMatrixComponentType::I32}), "The elements' sequence of ComponentTypeName array do not match wgpu::SubgroupMatrixComponentType"); -// std::tuple -static const std::tuple - intel_supported_subgroup_matrix_configs[] = { - {"xe-2lpg", wgpu::BackendType::Vulkan, wgpu::SubgroupMatrixComponentType::F16, wgpu::SubgroupMatrixComponentType::F16, 8, 16, 16, 16, 32}, - {"xe-3lpg", wgpu::BackendType::Vulkan, wgpu::SubgroupMatrixComponentType::F16, wgpu::SubgroupMatrixComponentType::F16, 8, 16, 16, 16, 32}}; +// Vendor-agnostic subgroup matrix config: {componentType, resultComponentType, M, N, K, subgroupMinSize, subgroupMaxSize, needsPrepack} +// Any GPU reporting a matching config from wgpu::AdapterPropertiesSubgroupMatrixConfigs is supported. +struct SupportedSubgroupMatrixConfig { + wgpu::SubgroupMatrixComponentType componentType; + wgpu::SubgroupMatrixComponentType resultComponentType; + uint32_t M; + uint32_t N; + uint32_t K; + uint32_t subgroupMinSize; + uint32_t subgroupMaxSize; + bool needsPrepack; // Whether input A needs layout optimization for subgroupMatrixLoad +}; + +static const SupportedSubgroupMatrixConfig supported_subgroup_matrix_configs[] = { + // 16x16x16 config (NVIDIA Blackwell, subgroup size 32) + {wgpu::SubgroupMatrixComponentType::F16, wgpu::SubgroupMatrixComponentType::F16, 16, 16, 16, 32, 32, true}, + // 8x16x16 configs + // 8x16x16 config (Intel Xe2/Xe3, subgroup size 16-32) + {wgpu::SubgroupMatrixComponentType::F16, wgpu::SubgroupMatrixComponentType::F16, 8, 16, 16, 16, 32, true}, + // 8x16x16 config (subgroup size 32) + {wgpu::SubgroupMatrixComponentType::F16, wgpu::SubgroupMatrixComponentType::F16, 8, 16, 16, 32, 32, true}, + // 8x8x8 config (Apple M-series, etc.) 
+ {wgpu::SubgroupMatrixComponentType::F16, wgpu::SubgroupMatrixComponentType::F16, 8, 8, 8, 32, 32, false}, +}; -bool IsSubgroupMatrixConfigSupportedOnIntel(onnxruntime::webgpu::ComputeContext& context, int32_t& config_index) { +bool IsSubgroupMatrixConfigSupported(onnxruntime::webgpu::ComputeContext& context, int32_t& config_index) { const wgpu::AdapterInfo& adapter_info = context.AdapterInfo(); const wgpu::AdapterPropertiesSubgroupMatrixConfigs& subgroup_matrix_configs = context.SubgroupMatrixConfigs(); int32_t index = 0; - for (auto& supported_config : intel_supported_subgroup_matrix_configs) { + for (const auto& supported_config : supported_subgroup_matrix_configs) { for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) { - auto& subgroup_matrix_config = subgroup_matrix_configs.configs[i]; - auto&& config = std::make_tuple(adapter_info.architecture, adapter_info.backendType, - subgroup_matrix_config.componentType, subgroup_matrix_config.resultComponentType, - subgroup_matrix_config.M, subgroup_matrix_config.N, subgroup_matrix_config.K, - adapter_info.subgroupMinSize, adapter_info.subgroupMaxSize); - if (config == supported_config) { + const auto& device_config = subgroup_matrix_configs.configs[i]; + if (device_config.componentType == supported_config.componentType && + device_config.resultComponentType == supported_config.resultComponentType && + device_config.M == supported_config.M && + device_config.N == supported_config.N && + device_config.K == supported_config.K && + adapter_info.subgroupMinSize >= supported_config.subgroupMinSize && + adapter_info.subgroupMaxSize >= supported_config.subgroupMaxSize) { config_index = index; return true; } @@ -117,31 +136,52 @@ Status PrepackProgram::GenerateShaderCode(ShaderHelper& shader) const { WGSL_TEMPLATE_PARAMETER(sg_mat_m, m_)); } -Status GenerateShaderCodeOnIntel(ShaderHelper& shader, +Status GenerateShaderCode16x16x16(ShaderHelper& shader, + const ShaderVariableHelper& b, + const ShaderVariableHelper& scales_b, + const ShaderVariableHelper& output, + uint32_t nbits, int32_t config_index, bool has_zero_points, bool has_bias, bool has_weight_idx, bool has_weight_idx_indirect) { + const auto& config = supported_subgroup_matrix_configs[config_index]; + return WGSL_TEMPLATE_APPLY(shader, "quantization/subgroup_matrix_matmul_nbits_16x16x16.wgsl.template", + WGSL_TEMPLATE_PARAMETER(has_bias, has_bias), + WGSL_TEMPLATE_PARAMETER(has_weight_idx, has_weight_idx), + WGSL_TEMPLATE_PARAMETER(has_weight_idx_indirect, has_weight_idx_indirect), + WGSL_TEMPLATE_PARAMETER(has_zero_points, has_zero_points), + WGSL_TEMPLATE_PARAMETER(n_bits, nbits), + WGSL_TEMPLATE_PARAMETER(output_type_i32, false), + WGSL_TEMPLATE_PARAMETER(sg_mat_k, config.K), + WGSL_TEMPLATE_PARAMETER(sg_mat_m, config.M), + WGSL_TEMPLATE_PARAMETER(sg_mat_n, config.N), + WGSL_TEMPLATE_VARIABLE(input_b, b), + WGSL_TEMPLATE_VARIABLE(output, output), + WGSL_TEMPLATE_VARIABLE(scales_b, scales_b)); +} + +Status GenerateShaderCode8x16x16(ShaderHelper& shader, const ShaderVariableHelper& b, const ShaderVariableHelper& scales_b, const ShaderVariableHelper& output, uint32_t nbits, int32_t config_index, bool has_zero_points, bool has_bias, bool has_weight_idx, bool has_weight_idx_indirect) { - auto& config = intel_supported_subgroup_matrix_configs[config_index]; - return WGSL_TEMPLATE_APPLY(shader, "quantization/subgroup_matrix_matmul_nbits_intel.wgsl.template", + const auto& config = supported_subgroup_matrix_configs[config_index]; + return WGSL_TEMPLATE_APPLY(shader, 
"quantization/subgroup_matrix_matmul_nbits_8x16x16.wgsl.template", WGSL_TEMPLATE_PARAMETER(has_bias, has_bias), WGSL_TEMPLATE_PARAMETER(has_weight_idx, has_weight_idx), WGSL_TEMPLATE_PARAMETER(has_weight_idx_indirect, has_weight_idx_indirect), WGSL_TEMPLATE_PARAMETER(has_zero_points, has_zero_points), WGSL_TEMPLATE_PARAMETER(n_bits, nbits), WGSL_TEMPLATE_PARAMETER(output_type_i32, false), - WGSL_TEMPLATE_PARAMETER(sg_mat_k, std::get<6>(config)), - WGSL_TEMPLATE_PARAMETER(sg_mat_m, std::get<4>(config)), - WGSL_TEMPLATE_PARAMETER(sg_mat_n, std::get<5>(config)), + WGSL_TEMPLATE_PARAMETER(sg_mat_k, config.K), + WGSL_TEMPLATE_PARAMETER(sg_mat_m, config.M), + WGSL_TEMPLATE_PARAMETER(sg_mat_n, config.N), WGSL_TEMPLATE_VARIABLE(input_b, b), WGSL_TEMPLATE_VARIABLE(output, output), WGSL_TEMPLATE_VARIABLE(scales_b, scales_b)); } -Status GenerateShaderCodeOnApple(ShaderHelper& shader, const ShaderVariableHelper& a, const ShaderVariableHelper& b, - const ShaderVariableHelper& scales_b, - const ShaderVariableHelper& output, uint32_t nbits, bool has_zero_points, bool has_bias, bool has_weight_idx, bool has_weight_idx_indirect) { - return WGSL_TEMPLATE_APPLY(shader, "quantization/subgroup_matrix_matmul_nbits_apple.wgsl.template", +Status GenerateShaderCode8x8x8(ShaderHelper& shader, const ShaderVariableHelper& a, const ShaderVariableHelper& b, + const ShaderVariableHelper& scales_b, + const ShaderVariableHelper& output, uint32_t nbits, bool has_zero_points, bool has_bias, bool has_weight_idx, bool has_weight_idx_indirect) { + return WGSL_TEMPLATE_APPLY(shader, "quantization/subgroup_matrix_matmul_nbits_8x8x8.wgsl.template", WGSL_TEMPLATE_PARAMETER(has_bias, has_bias), WGSL_TEMPLATE_PARAMETER(has_weight_idx, has_weight_idx), WGSL_TEMPLATE_PARAMETER(has_weight_idx_indirect, has_weight_idx_indirect), @@ -169,13 +209,16 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader } const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); - if (!vendor_.compare("apple")) { - return GenerateShaderCodeOnApple(shader, a, b, scales_b, output, nbits_, has_zero_points_, has_bias_, has_weight_idx_, has_weight_idx_indirect_); - } else if (!vendor_.compare("intel")) { - return GenerateShaderCodeOnIntel(shader, b, scales_b, output, nbits_, config_index_, has_zero_points_, has_bias_, has_weight_idx_, has_weight_idx_indirect_); + const auto& config = supported_subgroup_matrix_configs[config_index_]; + if (config.M == 8 && config.N == 8 && config.K == 8) { + return GenerateShaderCode8x8x8(shader, a, b, scales_b, output, nbits_, has_zero_points_, has_bias_, has_weight_idx_, has_weight_idx_indirect_); + } else if (config.M == 8 && config.N == 16 && config.K == 16) { + return GenerateShaderCode8x16x16(shader, b, scales_b, output, nbits_, config_index_, has_zero_points_, has_bias_, has_weight_idx_, has_weight_idx_indirect_); + } else if (config.M == 16 && config.N == 16 && config.K == 16) { + return GenerateShaderCode16x16x16(shader, b, scales_b, output, nbits_, config_index_, has_zero_points_, has_bias_, has_weight_idx_, has_weight_idx_indirect_); } else { return Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::NOT_IMPLEMENTED, - "onnxruntime does not support subgroup matrix on this vendor."); + "Unsupported subgroup matrix config dimensions."); } } @@ -193,10 +236,10 @@ Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Te const Tensor* weight_index_indirect) { // If applicable, layout optimization of input 
matrix A(MxK) can be used for SubgroupMatrixLoad. Tensor a_prepack; - if (context.AdapterInfo().vendor == std::string_view{"intel"}) { - const auto& config = intel_supported_subgroup_matrix_configs[config_index]; - const auto m = std::get<4>(config); - const auto k = std::get<6>(config); + const auto& config = supported_subgroup_matrix_configs[config_index]; + if (config.needsPrepack) { + const auto m = config.M; + const auto k = config.K; // Optimize the layout of input matrix A(MxK) for SubgroupMatrixLoad. PrepackProgram prepack_program{m, k}; @@ -228,10 +271,15 @@ Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Te const bool has_bias = bias != nullptr; const bool has_weight_idx_indirect = weight_index_indirect != nullptr; const bool has_weight_idx = weight_index > 0 || has_weight_idx_indirect; - SubgroupMatrixMatMulNBitsProgram mul_program{nbits, config_index, context.AdapterInfo().vendor, has_zero_points, has_bias, has_weight_idx, has_weight_idx_indirect}; - if (context.AdapterInfo().vendor == std::string_view{"intel"}) { + SubgroupMatrixMatMulNBitsProgram mul_program{nbits, config_index, has_zero_points, has_bias, has_weight_idx, has_weight_idx_indirect}; + if (config.M == 8 && config.N == 16 && config.K == 16) { + // 8x16x16 config: 8 subgroups, 256 threads, 64x64 tiles tile_size_a = 64; work_group_size = 256; + } else if (config.M == 16 && config.N == 16 && config.K == 16) { + // 16x16x16 config: 4 subgroups, 128 threads, 64x64 tiles + tile_size_a = 64; + work_group_size = 128; } mul_program.SetWorkgroupSize(work_group_size); mul_program.SetDispatchGroupSize( @@ -271,15 +319,14 @@ bool CanApplySubgroupMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& cont bool has_subgroup_matrix = context.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); if (has_subgroup_matrix) { - if (context.AdapterInfo().vendor == std::string_view{"apple"}) { - // For now SubgroupMatrixMatMulNBits is only supported for accuracy level 4, because with Fp16 there are - // some precision issues with subgroupMatrixMultiplyAccumulate. It is possible to support higher accuracy - // by setting compute_precision to Fp32, but that will be slower. For 1K token prefill FP16 Phi 3.5 is around 5s, - // FP32 is around 7s. - has_subgroup_matrix = accuracy_level == 4; - } else if (context.AdapterInfo().vendor == std::string_view{"intel"}) { - // Intel subgroup matrix config is f16-only. - has_subgroup_matrix = is_fp16 && IsSubgroupMatrixConfigSupportedOnIntel(context, config_index); + // Check if the adapter reports a subgroup matrix config we support. + has_subgroup_matrix = IsSubgroupMatrixConfigSupported(context, config_index); + if (has_subgroup_matrix) { + const auto& config = supported_subgroup_matrix_configs[config_index]; + // F16 component type requires FP16 output and accuracy level 4 (FP16 precision). 
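+      // With FP16 accumulation, subgroupMatrixMultiplyAccumulate has known precision
+      // issues at higher accuracy levels; supporting them would require FP32 compute
+      // precision, which is slower.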
+ if (config.componentType == wgpu::SubgroupMatrixComponentType::F16) { + has_subgroup_matrix = is_fp16 && accuracy_level == 4; + } } } diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h index 4ee800cafeb0d..1b0c1d336545c 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h @@ -5,8 +5,6 @@ #if !defined(__wasm__) -#include - #include "core/providers/webgpu/program.h" #include "core/providers/webgpu/compute_context.h" #include "core/providers/webgpu/program.h" @@ -21,11 +19,10 @@ using namespace onnxruntime::webgpu; class SubgroupMatrixMatMulNBitsProgram final : public Program { public: - SubgroupMatrixMatMulNBitsProgram(uint32_t nbits, int32_t config_index, const wgpu::StringView& vendor, bool has_zero_points, bool has_bias, bool has_weight_idx, bool has_weight_idx_indirect) + SubgroupMatrixMatMulNBitsProgram(uint32_t nbits, int32_t config_index, bool has_zero_points, bool has_bias, bool has_weight_idx, bool has_weight_idx_indirect) : Program{"SubgroupMatrixMatMulNBits"}, nbits_(nbits), config_index_(config_index), - vendor_(vendor), has_zero_points_(has_zero_points), has_bias_(has_bias), has_weight_idx_{has_weight_idx}, @@ -41,7 +38,6 @@ class SubgroupMatrixMatMulNBitsProgram final : public Program tile_b: array; + +// Scratch space for matmul results [kTileM, kTileN], row-major +// Used for bounds-checked output writes (both bias and non-bias paths). +var scratch: array; + +#if n_bits == 4 +// Dequantize 4-bit quantized B weights and store into tile_b [kTileN, kTileK] workgroup memory. +// 128 threads load kTileN(64) x kTileK(32). 2 threads per N, 16 K elements per thread. +// k_chunk_idx: which 16-element chunk of kTileK this thread handles (0..1). 
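+// Nibble layout per packed u32 (8 consecutive K elements of one N row), assuming the
+// usual MatMulNBits packing where the low nibble of each byte is the even K element:
+//   byte0 = (k1 << 4) | k0, byte1 = (k3 << 4) | k2, ...
+//   weights_lo dequantizes k0, k2, k4, k6 and weights_hi dequantizes k1, k3, k5, k7,
+//   and they are written back to tile_b in order k0, k1, ..., k7 (interleaved below).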
+fn dequant_b_to_tile(n_base: u32, k_idx: u32, n_idx: u32, k_chunk_idx: u32) { + let global_n = n_base + n_idx; + if (global_n >= uniforms.N) { + return; + } + let k_offset = k_chunk_idx * 16; +#if has_weight_idx +#if has_weight_idx_indirect + let actual_weight_idx = weight_index_indirect[uniforms.weight_idx]; +#else + let actual_weight_idx = uniforms.weight_idx; +#endif + let weight_offset = actual_weight_idx * uniforms.K * uniforms.N; + let scale_offset = actual_weight_idx * uniforms.N * (uniforms.K / kQuantizationBlockSize); +#else + const weight_offset : u32 = 0; + const scale_offset : u32 = 0; +#endif + let packed_idx = u32((global_n * uniforms.K + k_idx + k_offset) / 8) + weight_offset / 8; + let scale_idx = (global_n * uniforms.K + k_idx + k_offset) / kQuantizationBlockSize + scale_offset; + let scale = f16(scales_b.getByOffset(scale_idx)); + let zero = mm_read_zero( + global_n, (k_idx + k_offset) / kQuantizationBlockSize, uniforms.N, uniforms.zero_blocks_per_col); + // Each thread loads 2 packed u32 words = 16 elements + for (var step: u32 = 0; step < 2; step++) { + let packed_weights = input_b.getByOffset(packed_idx + step); + let weights_lo = + (vec4(unpack4xU8(packed_weights & 0x0F0F0F0Fu)) - vec4(zero)) * scale; + let weights_hi = + (vec4(unpack4xU8((packed_weights >> 4) & 0x0F0F0F0Fu)) - vec4(zero)) * scale; + let tile_base = n_idx * kTileK + k_offset + step * 8; + tile_b[tile_base] = weights_lo[0]; + tile_b[tile_base + 1] = weights_hi[0]; + tile_b[tile_base + 2] = weights_lo[1]; + tile_b[tile_base + 3] = weights_hi[1]; + tile_b[tile_base + 4] = weights_lo[2]; + tile_b[tile_base + 5] = weights_hi[2]; + tile_b[tile_base + 6] = weights_lo[3]; + tile_b[tile_base + 7] = weights_hi[3]; + } +} +#endif + +#if n_bits == 8 +// Dequantize 8-bit quantized B weights and store into tile_b [kTileN, kTileK] workgroup memory. +// 128 threads load kTileN(64) x kTileK(32). 2 threads per N, 16 K elements per thread. +// k_chunk_idx: which 16-element chunk of kTileK this thread handles (0..1). 
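+// Byte layout per packed vec2<u32> (8 consecutive K elements of one N row):
+//   packed_weights[0] holds bytes k0..k3 and packed_weights[1] holds bytes k4..k7,
+//   so weights_lo / weights_hi are written to tile_b sequentially (no nibble interleave).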
+fn dequant_b_to_tile(n_base: u32, k_idx: u32, n_idx: u32, k_chunk_idx: u32) { + let global_n = n_base + n_idx; + if (global_n >= uniforms.N) { + return; + } + let k_offset = k_chunk_idx * 16; +#if has_weight_idx +#if has_weight_idx_indirect + let actual_weight_idx = weight_index_indirect[uniforms.weight_idx]; +#else + let actual_weight_idx = uniforms.weight_idx; +#endif + let weight_offset = actual_weight_idx * uniforms.K * uniforms.N; + let scale_offset = actual_weight_idx * uniforms.N * (uniforms.K / kQuantizationBlockSize); +#else + const weight_offset : u32 = 0; + const scale_offset : u32 = 0; +#endif + let packed_idx = u32((global_n * uniforms.K + k_idx + k_offset) / 8) + weight_offset / 8; + let scale_idx = (global_n * uniforms.K + k_idx + k_offset) / kQuantizationBlockSize + scale_offset; + let scale = f16(scales_b.getByOffset(scale_idx)); + let zero = mm_read_zero( + global_n, (k_idx + k_offset) / kQuantizationBlockSize, uniforms.N, uniforms.zero_blocks_per_col); + // Each thread loads 2 packed vec2 words = 16 elements + for (var step: u32 = 0; step < 2; step++) { + let packed_weights = input_b.getByOffset(packed_idx + step); + let weights_lo = (vec4(unpack4xU8(packed_weights[0])) - vec4(zero)) * scale; + let weights_hi = (vec4(unpack4xU8(packed_weights[1])) - vec4(zero)) * scale; + let tile_base = n_idx * kTileK + k_offset + step * 8; + tile_b[tile_base] = weights_lo[0]; + tile_b[tile_base + 1] = weights_lo[1]; + tile_b[tile_base + 2] = weights_lo[2]; + tile_b[tile_base + 3] = weights_lo[3]; + tile_b[tile_base + 4] = weights_hi[0]; + tile_b[tile_base + 5] = weights_hi[1]; + tile_b[tile_base + 6] = weights_hi[2]; + tile_b[tile_base + 7] = weights_hi[3]; + } +} +#endif + +$MAIN { + let global_base_a = workgroup_id.y * kTileM; + let global_base_b = workgroup_id.x * kTileN; + + let sg_idx = u32(local_idx / sg_size); + let sg_mat_count_k = uniforms.K / kSgMatK; + let sg_mat_idx = (workgroup_id.y * kSgMatCountM + sg_idx) * sg_mat_count_k; + + var sg_mat_offset_a = sg_mat_idx * kSgMatSizeLeft; + + var sg_mat_c0: subgroup_matrix_result; + var sg_mat_c1: subgroup_matrix_result; + var sg_mat_c2: subgroup_matrix_result; + var sg_mat_c3: subgroup_matrix_result; + for (var k_idx: u32 = 0; k_idx < uniforms.K; k_idx += kTileK) { + // Load Phase: 128 threads, 2 threads per N row, each handles 16 K elements + dequant_b_to_tile(global_base_b, k_idx, local_idx / 2, local_idx % 2); + workgroupBarrier(); + + for (var sg_mat_k_idx: u32 = 0; sg_mat_k_idx < kTileK; sg_mat_k_idx += kSgMatK) + { + // Load A from global memory (prepacked layout). + // Syntax: subgroupMatrixLoad src_ptr, src_offset, is_col_major, src_stride + var sg_mat_a0: subgroup_matrix_left = + subgroupMatrixLoad>( + &input_a, sg_mat_offset_a, false, kSgMatK); + sg_mat_offset_a += kSgMatSizeLeft; + + // Load B from shared local memory. + // tile_b [kTileN, kTileK] is stored as column major. 
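+        // Element (k, n) of the B sub-tile lives at tile_b[n * kTileK + k] (the layout
+        // written by dequant_b_to_tile); with is_col_major = true and stride kTileK,
+        // subgroupMatrixLoad reads it as a K x N right-hand matrix. The four loads
+        // below select the four adjacent kSgMatN-column bands of the 64-column tile.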
+ var sg_mat_b0: subgroup_matrix_right = + subgroupMatrixLoad>( + &tile_b, sg_mat_k_idx, true, kTileK); + var sg_mat_b1: subgroup_matrix_right = + subgroupMatrixLoad>( + &tile_b, sg_mat_k_idx + kSgMatStrideN, true, kTileK); + var sg_mat_b2: subgroup_matrix_right = + subgroupMatrixLoad>( + &tile_b, sg_mat_k_idx + 2 * kSgMatStrideN, true, kTileK); + var sg_mat_b3: subgroup_matrix_right = + subgroupMatrixLoad>( + &tile_b, sg_mat_k_idx + 3 * kSgMatStrideN, true, kTileK); + + // Compute Phase + // Syntax: subgroupMatrixMultiplyAccumulate left, right, accumulate -> accumulate + sg_mat_c0 = subgroupMatrixMultiplyAccumulate(sg_mat_a0, sg_mat_b0, sg_mat_c0); + sg_mat_c1 = subgroupMatrixMultiplyAccumulate(sg_mat_a0, sg_mat_b1, sg_mat_c1); + sg_mat_c2 = subgroupMatrixMultiplyAccumulate(sg_mat_a0, sg_mat_b2, sg_mat_c2); + sg_mat_c3 = subgroupMatrixMultiplyAccumulate(sg_mat_a0, sg_mat_b3, sg_mat_c3); + } + workgroupBarrier(); + } + + // Write out +#if has_bias + // Store results to scratch workgroup memory, then add bias and write to output. + // scratch layout: [kTileM, kTileN] row-major + let scratch_m_base = sg_idx * kSgMatM; + subgroupMatrixStore(&scratch, scratch_m_base * kTileN, sg_mat_c0, false, kTileN); + subgroupMatrixStore(&scratch, scratch_m_base * kTileN + kSgMatN, sg_mat_c1, false, kTileN); + subgroupMatrixStore(&scratch, scratch_m_base * kTileN + 2 * kSgMatN, sg_mat_c2, false, kTileN); + subgroupMatrixStore(&scratch, scratch_m_base * kTileN + 3 * kSgMatN, sg_mat_c3, false, kTileN); + workgroupBarrier(); + + // 128 threads write 64x64 = 4096 elements. Each thread handles 32 elements. + // Thread mapping: m = local_idx / 2, n_base = (local_idx % 2) * 32 + let out_m = local_idx / 2; + let out_n_base = (local_idx % 2) * 32; + let global_m = global_base_a + out_m; + if (global_m < uniforms.M) { + let global_n_base = global_base_b + out_n_base; + let scratch_base = out_m * kTileN + out_n_base; + let out_base = global_m * uniforms.N + global_n_base; +#if has_weight_idx_indirect + let bias_offset = weight_index_indirect[uniforms.weight_idx] * uniforms.N; +#elif has_weight_idx + let bias_offset = uniforms.weight_idx * uniforms.N; +#else + const bias_offset: u32 = 0; +#endif + for (var i: u32 = 0; i < 32; i++) { + if (global_n_base + i < uniforms.N) { + let val = output_element_t(scratch[scratch_base + i]) + + bias[bias_offset + global_n_base + i]; + output.setByOffset(out_base + i, val); + } + } + } +#else + // Non-bias path: store results directly to output with bounds checking. + // Use scratch to handle M bounds - subgroupMatrixStore to scratch first, + // then scalar copy only valid rows. 
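+  // subgroupMatrixStore always writes a full kSgMatM x kSgMatN tile and cannot be
+  // masked per row, so partial M tiles are staged in workgroup memory first.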
+ let scratch_m_base = sg_idx * kSgMatM; + subgroupMatrixStore(&scratch, scratch_m_base * kTileN, sg_mat_c0, false, kTileN); + subgroupMatrixStore(&scratch, scratch_m_base * kTileN + kSgMatN, sg_mat_c1, false, kTileN); + subgroupMatrixStore(&scratch, scratch_m_base * kTileN + 2 * kSgMatN, sg_mat_c2, false, kTileN); + subgroupMatrixStore(&scratch, scratch_m_base * kTileN + 3 * kSgMatN, sg_mat_c3, false, kTileN); + workgroupBarrier(); + + let out_m = local_idx / 2; + let out_n_base = (local_idx % 2) * 32; + let global_m = global_base_a + out_m; + if (global_m < uniforms.M) { + let global_n_base = global_base_b + out_n_base; + let scratch_base = out_m * kTileN + out_n_base; + let out_base = global_m * uniforms.N + global_n_base; + for (var i: u32 = 0; i < 32; i++) { + if (global_n_base + i < uniforms.N) { + output.setByOffset(out_base + i, output_element_t(scratch[scratch_base + i])); + } + } + } +#endif +} // MAIN diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_intel.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_8x16x16.wgsl.template similarity index 100% rename from onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_intel.wgsl.template rename to onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_8x16x16.wgsl.template diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_apple.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_8x8x8.wgsl.template similarity index 100% rename from onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_apple.wgsl.template rename to onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_8x8x8.wgsl.template diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_prepack.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_prepack.wgsl.template index 3b8f9884da927..e3e65b667ff6e 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_prepack.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_prepack.wgsl.template @@ -1,9 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -// Intel SubgroupMatrix prepack kernel +// SubgroupMatrix prepack kernel // Rearranges input matrix A(MxK) so that each subgroup matrix (sg_mat_m x sg_mat_k) // has its elements laid out contiguously in memory for subgroupMatrixLoad. +// OOB rows (beyond M) are zero-filled to avoid undefined behavior. #param sg_mat_k #param sg_mat_m @@ -14,11 +15,25 @@ const kSgMatK: u32 = u32(sg_mat_k); $MAIN { let M = uniforms.M; let K = uniforms.K; - let in_offset = workgroup_id.x * kSgMatM * K + workgroup_id.y * kSgMatK; + let row_base = workgroup_id.x * kSgMatM; + let in_offset = row_base * K + workgroup_id.y * kSgMatK; let out_offset = (workgroup_id.x * K / kSgMatK + workgroup_id.y) * kSgMatM * kSgMatK; - // Syntax: subgroupMatrixLoad src_ptr, src_offset, is_col_major, src_stride - var mat: subgroup_matrix_left = - subgroupMatrixLoad>(&input_a, in_offset, false, uniforms.K); - subgroupMatrixStore(&output_a, out_offset, mat, false, kSgMatK); + if (row_base + kSgMatM <= M) { + // All rows in this block are within bounds - use fast subgroupMatrixLoad. 
+ var mat: subgroup_matrix_left = + subgroupMatrixLoad>(&input_a, in_offset, false, K); + subgroupMatrixStore(&output_a, out_offset, mat, false, kSgMatK); + } else { + // Partial block: some rows are OOB. Use scalar copy with zero-fill. + for (var r: u32 = local_idx; r < kSgMatM * kSgMatK; r += workgroup_size_x) { + let row = r / kSgMatK; + let col = r % kSgMatK; + if (row_base + row < M) { + output_a[out_offset + r] = input_a[in_offset + row * K + col]; + } else { + output_a[out_offset + r] = f16(0.0); + } + } + } } // MAIN From b201c8d5a9cb49a7182f0b35a0f2f8409b9dde14 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Fri, 17 Apr 2026 13:58:01 +0800 Subject: [PATCH 2/8] webgpu: Fix 16x16x16 shader output - fast path for full blocks, safe barrier placement - Use fast subgroupMatrixStore directly to output for full M blocks (sg_m_base + kSgMatM <= M), avoiding scratch overhead for the common case. - Use scratch + scalar write only for partial M blocks at the boundary. - Move workgroupBarrier outside the if/else to avoid divergent barrier (WGSL disallows workgroupBarrier in non-uniform control flow). - Make scratch array unconditional (needed for both bias and non-bias paths). This fixes the Invalid ShaderModule crash that occurred when the barrier was inside a branch that different subgroups could take different sides of. --- ...matrix_matmul_nbits_16x16x16.wgsl.template | 49 ++++++++++++------- 1 file changed, 31 insertions(+), 18 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_16x16x16.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_16x16x16.wgsl.template index 1bf6769e7d5b4..7c69b71c29580 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_16x16x16.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_16x16x16.wgsl.template @@ -225,26 +225,39 @@ $MAIN { } } #else - // Non-bias path: store results directly to output with bounds checking. - // Use scratch to handle M bounds - subgroupMatrixStore to scratch first, - // then scalar copy only valid rows. - let scratch_m_base = sg_idx * kSgMatM; - subgroupMatrixStore(&scratch, scratch_m_base * kTileN, sg_mat_c0, false, kTileN); - subgroupMatrixStore(&scratch, scratch_m_base * kTileN + kSgMatN, sg_mat_c1, false, kTileN); - subgroupMatrixStore(&scratch, scratch_m_base * kTileN + 2 * kSgMatN, sg_mat_c2, false, kTileN); - subgroupMatrixStore(&scratch, scratch_m_base * kTileN + 3 * kSgMatN, sg_mat_c3, false, kTileN); + // Non-bias path: write results to output. + let sg_m_base = global_base_a + sg_idx * kSgMatM; + let full_block = sg_m_base + kSgMatM <= uniforms.M; + + if (full_block) { + // All rows in this block are within bounds - use fast subgroupMatrixStore to output. + let sg_mat_offset_c = sg_m_base * uniforms.N + global_base_b; + subgroupMatrixStore(&output, sg_mat_offset_c, sg_mat_c0, false, uniforms.N); + subgroupMatrixStore(&output, sg_mat_offset_c + kSgMatN, sg_mat_c1, false, uniforms.N); + subgroupMatrixStore(&output, sg_mat_offset_c + 2 * kSgMatN, sg_mat_c2, false, uniforms.N); + subgroupMatrixStore(&output, sg_mat_offset_c + 3 * kSgMatN, sg_mat_c3, false, uniforms.N); + } else { + // Partial block: store to scratch for bounds-checked scalar write. 
+ let scratch_m_base = sg_idx * kSgMatM; + subgroupMatrixStore(&scratch, scratch_m_base * kTileN, sg_mat_c0, false, kTileN); + subgroupMatrixStore(&scratch, scratch_m_base * kTileN + kSgMatN, sg_mat_c1, false, kTileN); + subgroupMatrixStore(&scratch, scratch_m_base * kTileN + 2 * kSgMatN, sg_mat_c2, false, kTileN); + subgroupMatrixStore(&scratch, scratch_m_base * kTileN + 3 * kSgMatN, sg_mat_c3, false, kTileN); + } workgroupBarrier(); - let out_m = local_idx / 2; - let out_n_base = (local_idx % 2) * 32; - let global_m = global_base_a + out_m; - if (global_m < uniforms.M) { - let global_n_base = global_base_b + out_n_base; - let scratch_base = out_m * kTileN + out_n_base; - let out_base = global_m * uniforms.N + global_n_base; - for (var i: u32 = 0; i < 32; i++) { - if (global_n_base + i < uniforms.N) { - output.setByOffset(out_base + i, output_element_t(scratch[scratch_base + i])); + if (!full_block) { + let out_m = local_idx / 2; + let out_n_base = (local_idx % 2) * 32; + let global_m = global_base_a + out_m; + if (global_m < uniforms.M) { + let global_n_base = global_base_b + out_n_base; + let scratch_base = out_m * kTileN + out_n_base; + let out_base = global_m * uniforms.N + global_n_base; + for (var i: u32 = 0; i < 32; i++) { + if (global_n_base + i < uniforms.N) { + output.setByOffset(out_base + i, output_element_t(scratch[scratch_base + i])); + } } } } From bec0fdbd62d50601cd1ffe33a08e2c8cd737d92e Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Mon, 20 Apr 2026 21:00:32 +0800 Subject: [PATCH 3/8] webgpu: Add 128x128 tile shader for 16x16x16 subgroup matrix config Use larger 128x128 tiles (vs 64x64) for NVIDIA Blackwell 16x16x16 config to improve prefill throughput. Key changes: - New WGSL template with 2x2 subgroup grid, each handling 64x64 subtile - Load A directly from prepacked global memory (no shared memory) - Dequant B to shared memory with padded stride (SHMEM_STRIDE=40) - Update dispatch to ceil(N/128) x ceil(M/128) --- .../subgroup_matrix_matmul_nbits.cc | 14 +- ...ix_matmul_nbits_16x16x16_128.wgsl.template | 738 ++++++++++++++++++ 2 files changed, 746 insertions(+), 6 deletions(-) create mode 100644 onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_16x16x16_128.wgsl.template diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc index a381574756301..a1151ce753110 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc @@ -58,7 +58,7 @@ struct SupportedSubgroupMatrixConfig { }; static const SupportedSubgroupMatrixConfig supported_subgroup_matrix_configs[] = { - // 16x16x16 config (NVIDIA Blackwell, subgroup size 32) + // 16x16x16 config with 128x128 tiles (NVIDIA Blackwell, subgroup size 32) - highest priority {wgpu::SubgroupMatrixComponentType::F16, wgpu::SubgroupMatrixComponentType::F16, 16, 16, 16, 32, 32, true}, // 8x16x16 configs // 8x16x16 config (Intel Xe2/Xe3, subgroup size 16-32) @@ -142,7 +142,8 @@ Status GenerateShaderCode16x16x16(ShaderHelper& shader, const ShaderVariableHelper& output, uint32_t nbits, int32_t config_index, bool has_zero_points, bool has_bias, bool has_weight_idx, bool has_weight_idx_indirect) { const auto& config = supported_subgroup_matrix_configs[config_index]; - return WGSL_TEMPLATE_APPLY(shader, "quantization/subgroup_matrix_matmul_nbits_16x16x16.wgsl.template", + // Use 128x128 tile shader for 
16x16x16 config (index 0) + return WGSL_TEMPLATE_APPLY(shader, "quantization/subgroup_matrix_matmul_nbits_16x16x16_128.wgsl.template", WGSL_TEMPLATE_PARAMETER(has_bias, has_bias), WGSL_TEMPLATE_PARAMETER(has_weight_idx, has_weight_idx), WGSL_TEMPLATE_PARAMETER(has_weight_idx_indirect, has_weight_idx_indirect), @@ -263,8 +264,8 @@ Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Te } uint32_t tile_size_a = 32; + uint32_t tile_size_b = 64; uint32_t work_group_size = 128; - constexpr uint32_t kTileSizeB = 64; constexpr uint32_t kU32Components = 4; TensorShape y_shape{1, M, N}; const bool has_zero_points = zero_points != nullptr; @@ -277,13 +278,14 @@ Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Te tile_size_a = 64; work_group_size = 256; } else if (config.M == 16 && config.N == 16 && config.K == 16) { - // 16x16x16 config: 4 subgroups, 128 threads, 64x64 tiles - tile_size_a = 64; + // 16x16x16 config: 4 subgroups, 128 threads, 128x128 tiles + tile_size_a = 128; + tile_size_b = 128; work_group_size = 128; } mul_program.SetWorkgroupSize(work_group_size); mul_program.SetDispatchGroupSize( - (N + kTileSizeB - 1) / kTileSizeB, + (N + tile_size_b - 1) / tile_size_b, (M + tile_size_a - 1) / tile_size_a, 1); mul_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, 1}, {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(nbits == 4 ? kU32Components : 2 * kU32Components)}, diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_16x16x16_128.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_16x16x16_128.wgsl.template new file mode 100644 index 0000000000000..0c49b880e2576 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_16x16x16_128.wgsl.template @@ -0,0 +1,738 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +// SubgroupMatrix matmul_nbits kernel for 16x16x16 config with 128x128 tiles +// Uses subgroupMatrixLoad/Store/MultiplyAccumulate for hardware tensor core acceleration +// +// A: prepacked layout (from PrepackProgram) - elements arranged for subgroupMatrixLoad +// B: 4-bit or 8-bit quantized, dequantized into SLM (buf_b) with padded stride +// +// Workgroup: 128 threads = 4 subgroups x 32 lanes, arranged 2x2 +// Tile: BM(128) x BN(128) +// Each subgroup handles WM(64) x WN(64) = 4x4 = 16 coopmat tiles of 16x16 +// K dimension iterated in BK(32) blocks +// +// Key optimizations vs 64x64 shader: +// - 4x larger tile (128x128 vs 64x64) reduces global memory traffic +// - A loaded directly from global memory (prepacked), 4 rows per subgroup per k-step +// - B dequantized to shared memory with padded stride (SHMEM_STRIDE=40) for bank conflict avoidance +// - B tiles loaded once per k-step, reused across 4 A rows + +#param n_bits +#param has_zero_points +#param has_bias +#param has_weight_idx +#param has_weight_idx_indirect +#param sg_mat_k +#param sg_mat_m +#param sg_mat_n + +#use .getByOffset .setByOffset + +#include "quantization/matmul_nbits_zero_pt.wgsl.template" + +const kQuantizationBlockSize: u32 = 32; + +// Tile parameters matching CM1 (cooperative matrix) layout +const BM: u32 = 128; // Tile size along M +const BN: u32 = 128; // Tile size along N +const BK: u32 = 32; // Tile size along K +const kSgMatM: u32 = u32(sg_mat_m); // 16 +const kSgMatN: u32 = u32(sg_mat_n); // 16 +const kSgMatK: u32 = u32(sg_mat_k); // 16 +const WM: u32 = 64; // Per-subgroup M (BM/2) +const WN: u32 = 64; // Per-subgroup N (BN/2) +const kSgMatSizeLeft: u32 = kSgMatM * kSgMatK; // Elements per left (A) subgroup matrix +const kSgMatCountM: u32 = BM / kSgMatM; // 8 subgroup matrices per tile along M + +// Padded shared memory stride to avoid bank conflicts during cooperative matrix loads +const SHMEM_STRIDE: u32 = 40; // BK + 8 padding (matches Vulkan CM1's BK/2+4 in vec2 units) + +// Shared memory for dequantized B weights [BN, BK] with padding +var buf_b: array; // 128 * 40 = 5120 elements = 10240 bytes + +// Small staging buffer for edge-tile output (one 16x16 tile per subgroup) +var coopmat_stage: array; // 16*16*4 = 1024 elements = 2048 bytes +// Total shared memory: 10240 + 2048 = 12288 bytes + +#if n_bits == 4 +// Dequantize 4-bit quantized B weights into buf_b with padded stride. +// First half: rows 0..63, Second half: rows 64..127 +// 128 threads, 2 threads per N row, each handles 16 K elements. 
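+// Thread mapping: local_idx / 2 selects n_idx_0 in 0..63 and local_idx % 2 selects the
+// 16-element K chunk; each invocation also fills row n_idx_0 + 64, so 128 threads cover
+// the full 128 x BK B tile per k_block. Rows are stored at stride SHMEM_STRIDE
+// (40 = BK + 8), staggering row starts across shared-memory banks.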
+fn dequant_b_to_buf(n_base: u32, k_idx: u32, n_idx_0: u32, k_chunk: u32) { +#if has_weight_idx +#if has_weight_idx_indirect + let actual_weight_idx = weight_index_indirect[uniforms.weight_idx]; +#else + let actual_weight_idx = uniforms.weight_idx; +#endif + let weight_offset = actual_weight_idx * uniforms.K * uniforms.N; + let scale_offset = actual_weight_idx * uniforms.N * (uniforms.K / kQuantizationBlockSize); +#else + const weight_offset : u32 = 0; + const scale_offset : u32 = 0; +#endif + let k_offset = k_chunk * 16u; + + // First half: rows 0..63 + { + let global_n = n_base + n_idx_0; + let buf_base = n_idx_0 * SHMEM_STRIDE + k_offset; + if (global_n < uniforms.N) { + let packed_idx = u32((global_n * uniforms.K + k_idx + k_offset) / 8) + weight_offset / 8; + let scale_idx = (global_n * uniforms.K + k_idx + k_offset) / kQuantizationBlockSize + scale_offset; + let scale = f16(scales_b.getByOffset(scale_idx)); + let zero = mm_read_zero( + global_n, (k_idx + k_offset) / kQuantizationBlockSize, uniforms.N, uniforms.zero_blocks_per_col); + for (var step: u32 = 0; step < 2u; step++) { + let packed_weights = input_b.getByOffset(packed_idx + step); + let w_lo = (vec4(unpack4xU8(packed_weights & 0x0F0F0F0Fu)) - vec4(zero)) * scale; + let w_hi = (vec4(unpack4xU8((packed_weights >> 4u) & 0x0F0F0F0Fu)) - vec4(zero)) * scale; + let b = buf_base + step * 8u; + buf_b[b] = w_lo[0]; buf_b[b + 1] = w_hi[0]; + buf_b[b + 2] = w_lo[1]; buf_b[b + 3] = w_hi[1]; + buf_b[b + 4] = w_lo[2]; buf_b[b + 5] = w_hi[2]; + buf_b[b + 6] = w_lo[3]; buf_b[b + 7] = w_hi[3]; + } + } else { + for (var i: u32 = 0; i < 16u; i++) { + buf_b[buf_base + i] = f16(0.0); + } + } + } + // Second half: rows 64..127 + { + let n_idx_1 = n_idx_0 + 64u; + let global_n = n_base + n_idx_1; + let buf_base = n_idx_1 * SHMEM_STRIDE + k_offset; + if (global_n < uniforms.N) { + let packed_idx = u32((global_n * uniforms.K + k_idx + k_offset) / 8) + weight_offset / 8; + let scale_idx = (global_n * uniforms.K + k_idx + k_offset) / kQuantizationBlockSize + scale_offset; + let scale = f16(scales_b.getByOffset(scale_idx)); + let zero = mm_read_zero( + global_n, (k_idx + k_offset) / kQuantizationBlockSize, uniforms.N, uniforms.zero_blocks_per_col); + for (var step: u32 = 0; step < 2u; step++) { + let packed_weights = input_b.getByOffset(packed_idx + step); + let w_lo = (vec4(unpack4xU8(packed_weights & 0x0F0F0F0Fu)) - vec4(zero)) * scale; + let w_hi = (vec4(unpack4xU8((packed_weights >> 4u) & 0x0F0F0F0Fu)) - vec4(zero)) * scale; + let b = buf_base + step * 8u; + buf_b[b] = w_lo[0]; buf_b[b + 1] = w_hi[0]; + buf_b[b + 2] = w_lo[1]; buf_b[b + 3] = w_hi[1]; + buf_b[b + 4] = w_lo[2]; buf_b[b + 5] = w_hi[2]; + buf_b[b + 6] = w_lo[3]; buf_b[b + 7] = w_hi[3]; + } + } else { + for (var i: u32 = 0; i < 16u; i++) { + buf_b[buf_base + i] = f16(0.0); + } + } + } +} +#endif + +#if n_bits == 8 +// Dequantize 8-bit quantized B weights into buf_b with padded stride. 
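+// Same thread mapping as the 4-bit variant above; bytes are written sequentially
+// (no nibble interleave).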
+fn dequant_b_to_buf(n_base: u32, k_idx: u32, n_idx_0: u32, k_chunk: u32) { +#if has_weight_idx +#if has_weight_idx_indirect + let actual_weight_idx = weight_index_indirect[uniforms.weight_idx]; +#else + let actual_weight_idx = uniforms.weight_idx; +#endif + let weight_offset = actual_weight_idx * uniforms.K * uniforms.N; + let scale_offset = actual_weight_idx * uniforms.N * (uniforms.K / kQuantizationBlockSize); +#else + const weight_offset : u32 = 0; + const scale_offset : u32 = 0; +#endif + let k_offset = k_chunk * 16u; + + // First half: rows 0..63 + { + let global_n = n_base + n_idx_0; + let buf_base = n_idx_0 * SHMEM_STRIDE + k_offset; + if (global_n < uniforms.N) { + let packed_idx = u32((global_n * uniforms.K + k_idx + k_offset) / 8) + weight_offset / 8; + let scale_idx = (global_n * uniforms.K + k_idx + k_offset) / kQuantizationBlockSize + scale_offset; + let scale = f16(scales_b.getByOffset(scale_idx)); + let zero = mm_read_zero( + global_n, (k_idx + k_offset) / kQuantizationBlockSize, uniforms.N, uniforms.zero_blocks_per_col); + for (var step: u32 = 0; step < 2u; step++) { + let packed_weights = input_b.getByOffset(packed_idx + step); + let w_lo = (vec4(unpack4xU8(packed_weights[0])) - vec4(zero)) * scale; + let w_hi = (vec4(unpack4xU8(packed_weights[1])) - vec4(zero)) * scale; + let b = buf_base + step * 8u; + buf_b[b] = w_lo[0]; buf_b[b + 1] = w_lo[1]; + buf_b[b + 2] = w_lo[2]; buf_b[b + 3] = w_lo[3]; + buf_b[b + 4] = w_hi[0]; buf_b[b + 5] = w_hi[1]; + buf_b[b + 6] = w_hi[2]; buf_b[b + 7] = w_hi[3]; + } + } else { + for (var i: u32 = 0; i < 16u; i++) { + buf_b[buf_base + i] = f16(0.0); + } + } + } + // Second half: rows 64..127 + { + let n_idx_1 = n_idx_0 + 64u; + let global_n = n_base + n_idx_1; + let buf_base = n_idx_1 * SHMEM_STRIDE + k_offset; + if (global_n < uniforms.N) { + let packed_idx = u32((global_n * uniforms.K + k_idx + k_offset) / 8) + weight_offset / 8; + let scale_idx = (global_n * uniforms.K + k_idx + k_offset) / kQuantizationBlockSize + scale_offset; + let scale = f16(scales_b.getByOffset(scale_idx)); + let zero = mm_read_zero( + global_n, (k_idx + k_offset) / kQuantizationBlockSize, uniforms.N, uniforms.zero_blocks_per_col); + for (var step: u32 = 0; step < 2u; step++) { + let packed_weights = input_b.getByOffset(packed_idx + step); + let w_lo = (vec4(unpack4xU8(packed_weights[0])) - vec4(zero)) * scale; + let w_hi = (vec4(unpack4xU8(packed_weights[1])) - vec4(zero)) * scale; + let b = buf_base + step * 8u; + buf_b[b] = w_lo[0]; buf_b[b + 1] = w_lo[1]; + buf_b[b + 2] = w_lo[2]; buf_b[b + 3] = w_lo[3]; + buf_b[b + 4] = w_hi[0]; buf_b[b + 5] = w_hi[1]; + buf_b[b + 6] = w_hi[2]; buf_b[b + 7] = w_hi[3]; + } + } else { + for (var i: u32 = 0; i < 16u; i++) { + buf_b[buf_base + i] = f16(0.0); + } + } + } +} +#endif + +$MAIN { + let global_base_m = workgroup_id.y * BM; + let global_base_n = workgroup_id.x * BN; + + let sg_idx = u32(local_idx / sg_size); + let sg_tid = local_idx % sg_size; + let warp_r = sg_idx % 2u; // 0 or 1 (row in 2x2 grid) + let warp_c = sg_idx / 2u; // 0 or 1 (col in 2x2 grid) + + // Each subgroup's A starts at a different M offset in the prepacked layout + // Subgroup sg_idx handles rows: warp_r * WM .. 
warp_r * WM + WM (within the BM tile) + // But in the prepacked layout, the subgroups are numbered 0..7 (kSgMatCountM = BM/kSgMatM = 8) + // Subgroups 0,1 handle rows 0..63 (warp_r=0), subgroups 2,3 handle rows 64..127 (warp_r=1) + // Within each 64-row band, we have 4 coopmat rows of 16 + + // For prepacked A: sg_mat index = (workgroup_y * kSgMatCountM + first_coopmat_row_in_warp) * sg_mat_count_k + let sg_mat_count_k = uniforms.K / kSgMatK; + + // 16 accumulator tiles (4x4 grid of 16x16) + var c00: subgroup_matrix_result; + var c01: subgroup_matrix_result; + var c02: subgroup_matrix_result; + var c03: subgroup_matrix_result; + var c10: subgroup_matrix_result; + var c11: subgroup_matrix_result; + var c12: subgroup_matrix_result; + var c13: subgroup_matrix_result; + var c20: subgroup_matrix_result; + var c21: subgroup_matrix_result; + var c22: subgroup_matrix_result; + var c23: subgroup_matrix_result; + var c30: subgroup_matrix_result; + var c31: subgroup_matrix_result; + var c32: subgroup_matrix_result; + var c33: subgroup_matrix_result; + + // Precompute A offsets for the 4 coopmat rows this subgroup handles + // In prepacked layout, each subgroup matrix tile is kSgMatSizeLeft contiguous elements + // The 4 coopmat rows are at (warp_r * 4 + 0..3) in the M dimension + var sg_mat_offset_a0 = (workgroup_id.y * kSgMatCountM + warp_r * 4u + 0u) * sg_mat_count_k * kSgMatSizeLeft; + var sg_mat_offset_a1 = (workgroup_id.y * kSgMatCountM + warp_r * 4u + 1u) * sg_mat_count_k * kSgMatSizeLeft; + var sg_mat_offset_a2 = (workgroup_id.y * kSgMatCountM + warp_r * 4u + 2u) * sg_mat_count_k * kSgMatSizeLeft; + var sg_mat_offset_a3 = (workgroup_id.y * kSgMatCountM + warp_r * 4u + 3u) * sg_mat_count_k * kSgMatSizeLeft; + + // Main K-loop + for (var k_block: u32 = 0; k_block < uniforms.K; k_block += BK) { + // Load B (Q4_K/Q8 weights) to buf_b with padded stride + dequant_b_to_buf(global_base_n, k_block, local_idx / 2u, local_idx % 2u); + workgroupBarrier(); + + // Compute: load B once per k-step, iterate 4 A rows + for (var k: u32 = 0; k < BK; k += kSgMatK) { + let b_base = (warp_c * WN) * SHMEM_STRIDE + k; + + // Load 4 B tiles from shared memory (column-major, padded stride) + var b0: subgroup_matrix_right = + subgroupMatrixLoad>(&buf_b, b_base, true, SHMEM_STRIDE); + var b1: subgroup_matrix_right = + subgroupMatrixLoad>(&buf_b, b_base + 16u * SHMEM_STRIDE, true, SHMEM_STRIDE); + var b2: subgroup_matrix_right = + subgroupMatrixLoad>(&buf_b, b_base + 32u * SHMEM_STRIDE, true, SHMEM_STRIDE); + var b3: subgroup_matrix_right = + subgroupMatrixLoad>(&buf_b, b_base + 48u * SHMEM_STRIDE, true, SHMEM_STRIDE); + + // A row 0: load from global prepacked memory + var a0: subgroup_matrix_left = + subgroupMatrixLoad>( + &input_a, sg_mat_offset_a0, false, kSgMatK); + sg_mat_offset_a0 += kSgMatSizeLeft; + c00 = subgroupMatrixMultiplyAccumulate(a0, b0, c00); + c01 = subgroupMatrixMultiplyAccumulate(a0, b1, c01); + c02 = subgroupMatrixMultiplyAccumulate(a0, b2, c02); + c03 = subgroupMatrixMultiplyAccumulate(a0, b3, c03); + + // A row 1 + var a1: subgroup_matrix_left = + subgroupMatrixLoad>( + &input_a, sg_mat_offset_a1, false, kSgMatK); + sg_mat_offset_a1 += kSgMatSizeLeft; + c10 = subgroupMatrixMultiplyAccumulate(a1, b0, c10); + c11 = subgroupMatrixMultiplyAccumulate(a1, b1, c11); + c12 = subgroupMatrixMultiplyAccumulate(a1, b2, c12); + c13 = subgroupMatrixMultiplyAccumulate(a1, b3, c13); + + // A row 2 + var a2: subgroup_matrix_left = + subgroupMatrixLoad>( + &input_a, sg_mat_offset_a2, false, kSgMatK); + 
sg_mat_offset_a2 += kSgMatSizeLeft; + c20 = subgroupMatrixMultiplyAccumulate(a2, b0, c20); + c21 = subgroupMatrixMultiplyAccumulate(a2, b1, c21); + c22 = subgroupMatrixMultiplyAccumulate(a2, b2, c22); + c23 = subgroupMatrixMultiplyAccumulate(a2, b3, c23); + + // A row 3 + var a3: subgroup_matrix_left = + subgroupMatrixLoad>( + &input_a, sg_mat_offset_a3, false, kSgMatK); + sg_mat_offset_a3 += kSgMatSizeLeft; + c30 = subgroupMatrixMultiplyAccumulate(a3, b0, c30); + c31 = subgroupMatrixMultiplyAccumulate(a3, b1, c31); + c32 = subgroupMatrixMultiplyAccumulate(a3, b2, c32); + c33 = subgroupMatrixMultiplyAccumulate(a3, b3, c33); + } + workgroupBarrier(); + } + + // Write out results + let dr = global_base_m + warp_r * WM; + let dc = global_base_n + warp_c * WN; + let stride = uniforms.N; + + // Workgroup-uniform bounds check (based on workgroup_id, not local_idx) + let tile_m_in = global_base_m + BM <= uniforms.M; + let tile_n_in = global_base_n + BN <= uniforms.N; + let tile_in_bounds = tile_m_in && tile_n_in; + +#if has_bias + // Bias path: store all 16 tiles via coopmat_stage, add bias, write to output + let stage_base = sg_idx * kSgMatM * kSgMatN; +#if has_weight_idx_indirect + let bias_offset = weight_index_indirect[uniforms.weight_idx] * uniforms.N; +#elif has_weight_idx + let bias_offset = uniforms.weight_idx * uniforms.N; +#else + const bias_offset: u32 = 0; +#endif + + // Helper macro: store one coopmat tile via stage, add bias, write to output + // Process all 16 tiles (4 rows x 4 cols) + // Row 0 + subgroupMatrixStore(&coopmat_stage, stage_base, c00, false, kSgMatN); + workgroupBarrier(); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + i / kSgMatN; + let c = dc + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + let val = output_element_t(coopmat_stage[stage_base + i]) + bias[bias_offset + c]; + output.setByOffset(r * stride + c, val); + } + } + workgroupBarrier(); + + subgroupMatrixStore(&coopmat_stage, stage_base, c01, false, kSgMatN); + workgroupBarrier(); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + i / kSgMatN; + let c = dc + 16u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + let val = output_element_t(coopmat_stage[stage_base + i]) + bias[bias_offset + c]; + output.setByOffset(r * stride + c, val); + } + } + workgroupBarrier(); + + subgroupMatrixStore(&coopmat_stage, stage_base, c02, false, kSgMatN); + workgroupBarrier(); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + i / kSgMatN; + let c = dc + 32u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + let val = output_element_t(coopmat_stage[stage_base + i]) + bias[bias_offset + c]; + output.setByOffset(r * stride + c, val); + } + } + workgroupBarrier(); + + subgroupMatrixStore(&coopmat_stage, stage_base, c03, false, kSgMatN); + workgroupBarrier(); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + i / kSgMatN; + let c = dc + 48u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + let val = output_element_t(coopmat_stage[stage_base + i]) + bias[bias_offset + c]; + output.setByOffset(r * stride + c, val); + } + } + workgroupBarrier(); + + // Row 1 + subgroupMatrixStore(&coopmat_stage, stage_base, c10, false, kSgMatN); + workgroupBarrier(); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 16u + i / kSgMatN; + let c = dc + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + let val = output_element_t(coopmat_stage[stage_base + 
i]) + bias[bias_offset + c]; + output.setByOffset(r * stride + c, val); + } + } + workgroupBarrier(); + + subgroupMatrixStore(&coopmat_stage, stage_base, c11, false, kSgMatN); + workgroupBarrier(); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 16u + i / kSgMatN; + let c = dc + 16u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + let val = output_element_t(coopmat_stage[stage_base + i]) + bias[bias_offset + c]; + output.setByOffset(r * stride + c, val); + } + } + workgroupBarrier(); + + subgroupMatrixStore(&coopmat_stage, stage_base, c12, false, kSgMatN); + workgroupBarrier(); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 16u + i / kSgMatN; + let c = dc + 32u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + let val = output_element_t(coopmat_stage[stage_base + i]) + bias[bias_offset + c]; + output.setByOffset(r * stride + c, val); + } + } + workgroupBarrier(); + + subgroupMatrixStore(&coopmat_stage, stage_base, c13, false, kSgMatN); + workgroupBarrier(); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 16u + i / kSgMatN; + let c = dc + 48u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + let val = output_element_t(coopmat_stage[stage_base + i]) + bias[bias_offset + c]; + output.setByOffset(r * stride + c, val); + } + } + workgroupBarrier(); + + // Row 2 + subgroupMatrixStore(&coopmat_stage, stage_base, c20, false, kSgMatN); + workgroupBarrier(); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 32u + i / kSgMatN; + let c = dc + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + let val = output_element_t(coopmat_stage[stage_base + i]) + bias[bias_offset + c]; + output.setByOffset(r * stride + c, val); + } + } + workgroupBarrier(); + + subgroupMatrixStore(&coopmat_stage, stage_base, c21, false, kSgMatN); + workgroupBarrier(); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 32u + i / kSgMatN; + let c = dc + 16u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + let val = output_element_t(coopmat_stage[stage_base + i]) + bias[bias_offset + c]; + output.setByOffset(r * stride + c, val); + } + } + workgroupBarrier(); + + subgroupMatrixStore(&coopmat_stage, stage_base, c22, false, kSgMatN); + workgroupBarrier(); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 32u + i / kSgMatN; + let c = dc + 32u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + let val = output_element_t(coopmat_stage[stage_base + i]) + bias[bias_offset + c]; + output.setByOffset(r * stride + c, val); + } + } + workgroupBarrier(); + + subgroupMatrixStore(&coopmat_stage, stage_base, c23, false, kSgMatN); + workgroupBarrier(); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 32u + i / kSgMatN; + let c = dc + 48u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + let val = output_element_t(coopmat_stage[stage_base + i]) + bias[bias_offset + c]; + output.setByOffset(r * stride + c, val); + } + } + workgroupBarrier(); + + // Row 3 + subgroupMatrixStore(&coopmat_stage, stage_base, c30, false, kSgMatN); + workgroupBarrier(); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 48u + i / kSgMatN; + let c = dc + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + let val = output_element_t(coopmat_stage[stage_base + i]) + bias[bias_offset + c]; + output.setByOffset(r * stride + c, val); + } + } + 
workgroupBarrier(); + + subgroupMatrixStore(&coopmat_stage, stage_base, c31, false, kSgMatN); + workgroupBarrier(); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 48u + i / kSgMatN; + let c = dc + 16u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + let val = output_element_t(coopmat_stage[stage_base + i]) + bias[bias_offset + c]; + output.setByOffset(r * stride + c, val); + } + } + workgroupBarrier(); + + subgroupMatrixStore(&coopmat_stage, stage_base, c32, false, kSgMatN); + workgroupBarrier(); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 48u + i / kSgMatN; + let c = dc + 32u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + let val = output_element_t(coopmat_stage[stage_base + i]) + bias[bias_offset + c]; + output.setByOffset(r * stride + c, val); + } + } + workgroupBarrier(); + + subgroupMatrixStore(&coopmat_stage, stage_base, c33, false, kSgMatN); + workgroupBarrier(); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 48u + i / kSgMatN; + let c = dc + 48u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + let val = output_element_t(coopmat_stage[stage_base + i]) + bias[bias_offset + c]; + output.setByOffset(r * stride + c, val); + } + } +#else + // Non-bias path + if (tile_in_bounds) { + // Fast path: entire 128x128 tile is in bounds - direct store to global + let base = dr * stride + dc; + subgroupMatrixStore(&output, base, c00, false, stride); + subgroupMatrixStore(&output, base + 16u, c01, false, stride); + subgroupMatrixStore(&output, base + 32u, c02, false, stride); + subgroupMatrixStore(&output, base + 48u, c03, false, stride); + + subgroupMatrixStore(&output, base + 16u * stride, c10, false, stride); + subgroupMatrixStore(&output, base + 16u * stride + 16u, c11, false, stride); + subgroupMatrixStore(&output, base + 16u * stride + 32u, c12, false, stride); + subgroupMatrixStore(&output, base + 16u * stride + 48u, c13, false, stride); + + subgroupMatrixStore(&output, base + 32u * stride, c20, false, stride); + subgroupMatrixStore(&output, base + 32u * stride + 16u, c21, false, stride); + subgroupMatrixStore(&output, base + 32u * stride + 32u, c22, false, stride); + subgroupMatrixStore(&output, base + 32u * stride + 48u, c23, false, stride); + + subgroupMatrixStore(&output, base + 48u * stride, c30, false, stride); + subgroupMatrixStore(&output, base + 48u * stride + 16u, c31, false, stride); + subgroupMatrixStore(&output, base + 48u * stride + 32u, c32, false, stride); + subgroupMatrixStore(&output, base + 48u * stride + 48u, c33, false, stride); + } else { + // Edge tile: store via coopmat_stage with bounds checking + let stage_base = sg_idx * kSgMatM * kSgMatN; + + // Row 0 + subgroupMatrixStore(&coopmat_stage, stage_base, c00, false, kSgMatN); + workgroupBarrier(); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + i / kSgMatN; + let c = dc + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); + } + } + workgroupBarrier(); + + subgroupMatrixStore(&coopmat_stage, stage_base, c01, false, kSgMatN); + workgroupBarrier(); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + i / kSgMatN; + let c = dc + 16u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); + } + } + workgroupBarrier(); + + 
subgroupMatrixStore(&coopmat_stage, stage_base, c02, false, kSgMatN); + workgroupBarrier(); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + i / kSgMatN; + let c = dc + 32u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); + } + } + workgroupBarrier(); + + subgroupMatrixStore(&coopmat_stage, stage_base, c03, false, kSgMatN); + workgroupBarrier(); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + i / kSgMatN; + let c = dc + 48u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); + } + } + workgroupBarrier(); + + // Row 1 + subgroupMatrixStore(&coopmat_stage, stage_base, c10, false, kSgMatN); + workgroupBarrier(); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 16u + i / kSgMatN; + let c = dc + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); + } + } + workgroupBarrier(); + + subgroupMatrixStore(&coopmat_stage, stage_base, c11, false, kSgMatN); + workgroupBarrier(); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 16u + i / kSgMatN; + let c = dc + 16u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); + } + } + workgroupBarrier(); + + subgroupMatrixStore(&coopmat_stage, stage_base, c12, false, kSgMatN); + workgroupBarrier(); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 16u + i / kSgMatN; + let c = dc + 32u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); + } + } + workgroupBarrier(); + + subgroupMatrixStore(&coopmat_stage, stage_base, c13, false, kSgMatN); + workgroupBarrier(); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 16u + i / kSgMatN; + let c = dc + 48u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); + } + } + workgroupBarrier(); + + // Row 2 + subgroupMatrixStore(&coopmat_stage, stage_base, c20, false, kSgMatN); + workgroupBarrier(); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 32u + i / kSgMatN; + let c = dc + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); + } + } + workgroupBarrier(); + + subgroupMatrixStore(&coopmat_stage, stage_base, c21, false, kSgMatN); + workgroupBarrier(); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 32u + i / kSgMatN; + let c = dc + 16u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); + } + } + workgroupBarrier(); + + subgroupMatrixStore(&coopmat_stage, stage_base, c22, false, kSgMatN); + workgroupBarrier(); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 32u + i / kSgMatN; + let c = dc + 32u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); + } + } + workgroupBarrier(); + + subgroupMatrixStore(&coopmat_stage, 
stage_base, c23, false, kSgMatN); + workgroupBarrier(); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 32u + i / kSgMatN; + let c = dc + 48u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); + } + } + workgroupBarrier(); + + // Row 3 + subgroupMatrixStore(&coopmat_stage, stage_base, c30, false, kSgMatN); + workgroupBarrier(); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 48u + i / kSgMatN; + let c = dc + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); + } + } + workgroupBarrier(); + + subgroupMatrixStore(&coopmat_stage, stage_base, c31, false, kSgMatN); + workgroupBarrier(); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 48u + i / kSgMatN; + let c = dc + 16u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); + } + } + workgroupBarrier(); + + subgroupMatrixStore(&coopmat_stage, stage_base, c32, false, kSgMatN); + workgroupBarrier(); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 48u + i / kSgMatN; + let c = dc + 32u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); + } + } + workgroupBarrier(); + + subgroupMatrixStore(&coopmat_stage, stage_base, c33, false, kSgMatN); + workgroupBarrier(); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 48u + i / kSgMatN; + let c = dc + 48u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); + } + } + } +#endif +} // MAIN From 12fde715673c52d35ee140450285d1cd976bc572 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Wed, 22 Apr 2026 17:37:05 +0800 Subject: [PATCH 4/8] webgpu: Improve subgroup matrix config matching and add F32 8x8x8 support - Use exact subgroup size matching (==) instead of range (>=) - Add F32 8x8x8 config for Apple parity - Pass is_fp16 into IsSubgroupMatrixConfigSupported to correctly skip F16 configs when output is F32 - Simplify accuracy_level check to apply uniformly - Fix missing closing brace in CanApplySubgroupMatrixMatMulNBits --- .../subgroup_matrix_matmul_nbits.cc | 29 +- ...matrix_matmul_nbits_16x16x16.wgsl.template | 265 ------------------ 2 files changed, 17 insertions(+), 277 deletions(-) delete mode 100644 onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_16x16x16.wgsl.template diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc index a1151ce753110..af02a614f1eed 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc @@ -58,22 +58,25 @@ struct SupportedSubgroupMatrixConfig { }; static const SupportedSubgroupMatrixConfig supported_subgroup_matrix_configs[] = { - // 16x16x16 config with 128x128 tiles (NVIDIA Blackwell, subgroup size 32) - highest priority + // 16x16x16 config with 128x128 tiles (NVIDIA Blackwell, subgroup size 32) {wgpu::SubgroupMatrixComponentType::F16, wgpu::SubgroupMatrixComponentType::F16, 16, 16, 16, 32, 32, true}, - // 8x16x16 
configs // 8x16x16 config (Intel Xe2/Xe3, subgroup size 16-32) {wgpu::SubgroupMatrixComponentType::F16, wgpu::SubgroupMatrixComponentType::F16, 8, 16, 16, 16, 32, true}, - // 8x16x16 config (subgroup size 32) - {wgpu::SubgroupMatrixComponentType::F16, wgpu::SubgroupMatrixComponentType::F16, 8, 16, 16, 32, 32, true}, // 8x8x8 config (Apple M-series, etc.) {wgpu::SubgroupMatrixComponentType::F16, wgpu::SubgroupMatrixComponentType::F16, 8, 8, 8, 32, 32, false}, + {wgpu::SubgroupMatrixComponentType::F32, wgpu::SubgroupMatrixComponentType::F32, 8, 8, 8, 32, 32, false}, }; -bool IsSubgroupMatrixConfigSupported(onnxruntime::webgpu::ComputeContext& context, int32_t& config_index) { +bool IsSubgroupMatrixConfigSupported(onnxruntime::webgpu::ComputeContext& context, bool is_fp16, int32_t& config_index) { const wgpu::AdapterInfo& adapter_info = context.AdapterInfo(); const wgpu::AdapterPropertiesSubgroupMatrixConfigs& subgroup_matrix_configs = context.SubgroupMatrixConfigs(); int32_t index = 0; for (const auto& supported_config : supported_subgroup_matrix_configs) { + // F16 configs require FP16 output; skip them when output is F32. + if (supported_config.componentType == wgpu::SubgroupMatrixComponentType::F16 && !is_fp16) { + index++; + continue; + } for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) { const auto& device_config = subgroup_matrix_configs.configs[i]; if (device_config.componentType == supported_config.componentType && @@ -81,8 +84,8 @@ bool IsSubgroupMatrixConfigSupported(onnxruntime::webgpu::ComputeContext& contex device_config.M == supported_config.M && device_config.N == supported_config.N && device_config.K == supported_config.K && - adapter_info.subgroupMinSize >= supported_config.subgroupMinSize && - adapter_info.subgroupMaxSize >= supported_config.subgroupMaxSize) { + adapter_info.subgroupMinSize == supported_config.subgroupMinSize && + adapter_info.subgroupMaxSize == supported_config.subgroupMaxSize) { config_index = index; return true; } @@ -322,12 +325,14 @@ bool CanApplySubgroupMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& cont bool has_subgroup_matrix = context.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); if (has_subgroup_matrix) { // Check if the adapter reports a subgroup matrix config we support. - has_subgroup_matrix = IsSubgroupMatrixConfigSupported(context, config_index); + has_subgroup_matrix = IsSubgroupMatrixConfigSupported(context, is_fp16, config_index); if (has_subgroup_matrix) { - const auto& config = supported_subgroup_matrix_configs[config_index]; - // F16 component type requires FP16 output and accuracy level 4 (FP16 precision). - if (config.componentType == wgpu::SubgroupMatrixComponentType::F16) { - has_subgroup_matrix = is_fp16 && accuracy_level == 4; + if (context.AdapterInfo().vendor == std::string_view{"apple"}) { + // For now SubgroupMatrixMatMulNBits is only supported for accuracy level 4, because with Fp16 there are + // some precision issues with subgroupMatrixMultiplyAccumulate. It is possible to support higher accuracy + // by setting compute_precision to Fp32, but that will be slower. For 1K token prefill FP16 Phi 3.5 is around 5s, + // FP32 is around 7s. 
+ has_subgroup_matrix = accuracy_level == 4; } } } diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_16x16x16.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_16x16x16.wgsl.template deleted file mode 100644 index 7c69b71c29580..0000000000000 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_16x16x16.wgsl.template +++ /dev/null @@ -1,265 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// SubgroupMatrix matmul_nbits kernel for 16x16x16 config (NVIDIA Blackwell) -// Uses subgroupMatrixLoad/Store/MultiplyAccumulate for hardware tensor core acceleration -// -// A: prepacked layout (from PrepackProgram) - elements arranged for subgroupMatrixLoad -// B: 4-bit or 8-bit quantized, dequantized into SLM (tile_b) -// -// Workgroup: 128 threads = 4 subgroups x 32 lanes -// Tile: kTileM(64) x kTileN(64) -// Each subgroup handles kSgMatM(16) M rows x 64 N columns (4 x kSgMatN N blocks) -// K dimension iterated in kTileK(32) blocks - -#param n_bits -#param has_zero_points -#param has_bias -#param has_weight_idx -#param has_weight_idx_indirect -#param sg_mat_k -#param sg_mat_m -#param sg_mat_n - -#use .getByOffset .setByOffset - -#include "quantization/matmul_nbits_zero_pt.wgsl.template" - -const kQuantizationBlockSize: u32 = 32; - -const kTileM: u32 = 64; // Tile size along M (output rows per workgroup) -const kTileN: u32 = 64; // Tile size along N (output columns per workgroup) -const kTileK: u32 = 32; // Tile size along K (reduction dimension per iteration) -const kSgMatM: u32 = u32(sg_mat_m); // Subgroup matrix M dimension (rows) -const kSgMatN: u32 = u32(sg_mat_n); // Subgroup matrix N dimension (columns) -const kSgMatK: u32 = u32(sg_mat_k); // Subgroup matrix K dimension (reduction) -const kSgMatSizeLeft: u32 = kSgMatM * kSgMatK; // Elements per left (A) subgroup matrix -const kSgMatStrideN: u32 = kSgMatN * kTileK; // Stride between N blocks in tile_b -const kSgMatCountM: u32 = kTileM / kSgMatM; // Number of subgroup matrices per tile along M - -// Shared local memory for dequantized B weights [kTileN, kTileK], column-major -var tile_b: array; - -// Scratch space for matmul results [kTileM, kTileN], row-major -// Used for bounds-checked output writes (both bias and non-bias paths). -var scratch: array; - -#if n_bits == 4 -// Dequantize 4-bit quantized B weights and store into tile_b [kTileN, kTileK] workgroup memory. -// 128 threads load kTileN(64) x kTileK(32). 2 threads per N, 16 K elements per thread. -// k_chunk_idx: which 16-element chunk of kTileK this thread handles (0..1). 
-fn dequant_b_to_tile(n_base: u32, k_idx: u32, n_idx: u32, k_chunk_idx: u32) { - let global_n = n_base + n_idx; - if (global_n >= uniforms.N) { - return; - } - let k_offset = k_chunk_idx * 16; -#if has_weight_idx -#if has_weight_idx_indirect - let actual_weight_idx = weight_index_indirect[uniforms.weight_idx]; -#else - let actual_weight_idx = uniforms.weight_idx; -#endif - let weight_offset = actual_weight_idx * uniforms.K * uniforms.N; - let scale_offset = actual_weight_idx * uniforms.N * (uniforms.K / kQuantizationBlockSize); -#else - const weight_offset : u32 = 0; - const scale_offset : u32 = 0; -#endif - let packed_idx = u32((global_n * uniforms.K + k_idx + k_offset) / 8) + weight_offset / 8; - let scale_idx = (global_n * uniforms.K + k_idx + k_offset) / kQuantizationBlockSize + scale_offset; - let scale = f16(scales_b.getByOffset(scale_idx)); - let zero = mm_read_zero( - global_n, (k_idx + k_offset) / kQuantizationBlockSize, uniforms.N, uniforms.zero_blocks_per_col); - // Each thread loads 2 packed u32 words = 16 elements - for (var step: u32 = 0; step < 2; step++) { - let packed_weights = input_b.getByOffset(packed_idx + step); - let weights_lo = - (vec4(unpack4xU8(packed_weights & 0x0F0F0F0Fu)) - vec4(zero)) * scale; - let weights_hi = - (vec4(unpack4xU8((packed_weights >> 4) & 0x0F0F0F0Fu)) - vec4(zero)) * scale; - let tile_base = n_idx * kTileK + k_offset + step * 8; - tile_b[tile_base] = weights_lo[0]; - tile_b[tile_base + 1] = weights_hi[0]; - tile_b[tile_base + 2] = weights_lo[1]; - tile_b[tile_base + 3] = weights_hi[1]; - tile_b[tile_base + 4] = weights_lo[2]; - tile_b[tile_base + 5] = weights_hi[2]; - tile_b[tile_base + 6] = weights_lo[3]; - tile_b[tile_base + 7] = weights_hi[3]; - } -} -#endif - -#if n_bits == 8 -// Dequantize 8-bit quantized B weights and store into tile_b [kTileN, kTileK] workgroup memory. -// 128 threads load kTileN(64) x kTileK(32). 2 threads per N, 16 K elements per thread. -// k_chunk_idx: which 16-element chunk of kTileK this thread handles (0..1). 
-fn dequant_b_to_tile(n_base: u32, k_idx: u32, n_idx: u32, k_chunk_idx: u32) { - let global_n = n_base + n_idx; - if (global_n >= uniforms.N) { - return; - } - let k_offset = k_chunk_idx * 16; -#if has_weight_idx -#if has_weight_idx_indirect - let actual_weight_idx = weight_index_indirect[uniforms.weight_idx]; -#else - let actual_weight_idx = uniforms.weight_idx; -#endif - let weight_offset = actual_weight_idx * uniforms.K * uniforms.N; - let scale_offset = actual_weight_idx * uniforms.N * (uniforms.K / kQuantizationBlockSize); -#else - const weight_offset : u32 = 0; - const scale_offset : u32 = 0; -#endif - let packed_idx = u32((global_n * uniforms.K + k_idx + k_offset) / 8) + weight_offset / 8; - let scale_idx = (global_n * uniforms.K + k_idx + k_offset) / kQuantizationBlockSize + scale_offset; - let scale = f16(scales_b.getByOffset(scale_idx)); - let zero = mm_read_zero( - global_n, (k_idx + k_offset) / kQuantizationBlockSize, uniforms.N, uniforms.zero_blocks_per_col); - // Each thread loads 2 packed vec2 words = 16 elements - for (var step: u32 = 0; step < 2; step++) { - let packed_weights = input_b.getByOffset(packed_idx + step); - let weights_lo = (vec4(unpack4xU8(packed_weights[0])) - vec4(zero)) * scale; - let weights_hi = (vec4(unpack4xU8(packed_weights[1])) - vec4(zero)) * scale; - let tile_base = n_idx * kTileK + k_offset + step * 8; - tile_b[tile_base] = weights_lo[0]; - tile_b[tile_base + 1] = weights_lo[1]; - tile_b[tile_base + 2] = weights_lo[2]; - tile_b[tile_base + 3] = weights_lo[3]; - tile_b[tile_base + 4] = weights_hi[0]; - tile_b[tile_base + 5] = weights_hi[1]; - tile_b[tile_base + 6] = weights_hi[2]; - tile_b[tile_base + 7] = weights_hi[3]; - } -} -#endif - -$MAIN { - let global_base_a = workgroup_id.y * kTileM; - let global_base_b = workgroup_id.x * kTileN; - - let sg_idx = u32(local_idx / sg_size); - let sg_mat_count_k = uniforms.K / kSgMatK; - let sg_mat_idx = (workgroup_id.y * kSgMatCountM + sg_idx) * sg_mat_count_k; - - var sg_mat_offset_a = sg_mat_idx * kSgMatSizeLeft; - - var sg_mat_c0: subgroup_matrix_result; - var sg_mat_c1: subgroup_matrix_result; - var sg_mat_c2: subgroup_matrix_result; - var sg_mat_c3: subgroup_matrix_result; - for (var k_idx: u32 = 0; k_idx < uniforms.K; k_idx += kTileK) { - // Load Phase: 128 threads, 2 threads per N row, each handles 16 K elements - dequant_b_to_tile(global_base_b, k_idx, local_idx / 2, local_idx % 2); - workgroupBarrier(); - - for (var sg_mat_k_idx: u32 = 0; sg_mat_k_idx < kTileK; sg_mat_k_idx += kSgMatK) - { - // Load A from global memory (prepacked layout). - // Syntax: subgroupMatrixLoad src_ptr, src_offset, is_col_major, src_stride - var sg_mat_a0: subgroup_matrix_left = - subgroupMatrixLoad>( - &input_a, sg_mat_offset_a, false, kSgMatK); - sg_mat_offset_a += kSgMatSizeLeft; - - // Load B from shared local memory. - // tile_b [kTileN, kTileK] is stored as column major. 
- var sg_mat_b0: subgroup_matrix_right = - subgroupMatrixLoad>( - &tile_b, sg_mat_k_idx, true, kTileK); - var sg_mat_b1: subgroup_matrix_right = - subgroupMatrixLoad>( - &tile_b, sg_mat_k_idx + kSgMatStrideN, true, kTileK); - var sg_mat_b2: subgroup_matrix_right = - subgroupMatrixLoad>( - &tile_b, sg_mat_k_idx + 2 * kSgMatStrideN, true, kTileK); - var sg_mat_b3: subgroup_matrix_right = - subgroupMatrixLoad>( - &tile_b, sg_mat_k_idx + 3 * kSgMatStrideN, true, kTileK); - - // Compute Phase - // Syntax: subgroupMatrixMultiplyAccumulate left, right, accumulate -> accumulate - sg_mat_c0 = subgroupMatrixMultiplyAccumulate(sg_mat_a0, sg_mat_b0, sg_mat_c0); - sg_mat_c1 = subgroupMatrixMultiplyAccumulate(sg_mat_a0, sg_mat_b1, sg_mat_c1); - sg_mat_c2 = subgroupMatrixMultiplyAccumulate(sg_mat_a0, sg_mat_b2, sg_mat_c2); - sg_mat_c3 = subgroupMatrixMultiplyAccumulate(sg_mat_a0, sg_mat_b3, sg_mat_c3); - } - workgroupBarrier(); - } - - // Write out -#if has_bias - // Store results to scratch workgroup memory, then add bias and write to output. - // scratch layout: [kTileM, kTileN] row-major - let scratch_m_base = sg_idx * kSgMatM; - subgroupMatrixStore(&scratch, scratch_m_base * kTileN, sg_mat_c0, false, kTileN); - subgroupMatrixStore(&scratch, scratch_m_base * kTileN + kSgMatN, sg_mat_c1, false, kTileN); - subgroupMatrixStore(&scratch, scratch_m_base * kTileN + 2 * kSgMatN, sg_mat_c2, false, kTileN); - subgroupMatrixStore(&scratch, scratch_m_base * kTileN + 3 * kSgMatN, sg_mat_c3, false, kTileN); - workgroupBarrier(); - - // 128 threads write 64x64 = 4096 elements. Each thread handles 32 elements. - // Thread mapping: m = local_idx / 2, n_base = (local_idx % 2) * 32 - let out_m = local_idx / 2; - let out_n_base = (local_idx % 2) * 32; - let global_m = global_base_a + out_m; - if (global_m < uniforms.M) { - let global_n_base = global_base_b + out_n_base; - let scratch_base = out_m * kTileN + out_n_base; - let out_base = global_m * uniforms.N + global_n_base; -#if has_weight_idx_indirect - let bias_offset = weight_index_indirect[uniforms.weight_idx] * uniforms.N; -#elif has_weight_idx - let bias_offset = uniforms.weight_idx * uniforms.N; -#else - const bias_offset: u32 = 0; -#endif - for (var i: u32 = 0; i < 32; i++) { - if (global_n_base + i < uniforms.N) { - let val = output_element_t(scratch[scratch_base + i]) - + bias[bias_offset + global_n_base + i]; - output.setByOffset(out_base + i, val); - } - } - } -#else - // Non-bias path: write results to output. - let sg_m_base = global_base_a + sg_idx * kSgMatM; - let full_block = sg_m_base + kSgMatM <= uniforms.M; - - if (full_block) { - // All rows in this block are within bounds - use fast subgroupMatrixStore to output. - let sg_mat_offset_c = sg_m_base * uniforms.N + global_base_b; - subgroupMatrixStore(&output, sg_mat_offset_c, sg_mat_c0, false, uniforms.N); - subgroupMatrixStore(&output, sg_mat_offset_c + kSgMatN, sg_mat_c1, false, uniforms.N); - subgroupMatrixStore(&output, sg_mat_offset_c + 2 * kSgMatN, sg_mat_c2, false, uniforms.N); - subgroupMatrixStore(&output, sg_mat_offset_c + 3 * kSgMatN, sg_mat_c3, false, uniforms.N); - } else { - // Partial block: store to scratch for bounds-checked scalar write. 
- let scratch_m_base = sg_idx * kSgMatM; - subgroupMatrixStore(&scratch, scratch_m_base * kTileN, sg_mat_c0, false, kTileN); - subgroupMatrixStore(&scratch, scratch_m_base * kTileN + kSgMatN, sg_mat_c1, false, kTileN); - subgroupMatrixStore(&scratch, scratch_m_base * kTileN + 2 * kSgMatN, sg_mat_c2, false, kTileN); - subgroupMatrixStore(&scratch, scratch_m_base * kTileN + 3 * kSgMatN, sg_mat_c3, false, kTileN); - } - workgroupBarrier(); - - if (!full_block) { - let out_m = local_idx / 2; - let out_n_base = (local_idx % 2) * 32; - let global_m = global_base_a + out_m; - if (global_m < uniforms.M) { - let global_n_base = global_base_b + out_n_base; - let scratch_base = out_m * kTileN + out_n_base; - let out_base = global_m * uniforms.N + global_n_base; - for (var i: u32 = 0; i < 32; i++) { - if (global_n_base + i < uniforms.N) { - output.setByOffset(out_base + i, output_element_t(scratch[scratch_base + i])); - } - } - } - } -#endif -} // MAIN From 95818212e1f01cbd8cd86261bbbc881befb8a267 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Thu, 23 Apr 2026 11:05:13 +0800 Subject: [PATCH 5/8] webgpu: Fix prepack buffer OOB in subgroup matrix matmul MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The prepack buffer was sized to ceil(M/sg_mat_m)*sg_mat_m rows, but the matmul shader dispatches workgroups covering ceil(M/tile_size_a)*tile_size_a rows. When M < tile_size_a (e.g. M=32 with 128x128 tiles), subgroups in the matmul shader would read past the end of the prepack buffer, causing a device-lost error. Fix: move tile size computation before prepack allocation and pad the prepack buffer to the workgroup tile size. Also remove unnecessary zero-fill in the prepack shader for OOB rows — the matmul shader already bounds-checks before storing output, so fully OOB prepack blocks can skip entirely. --- .../subgroup_matrix_matmul_nbits.cc | 36 ++++++++++--------- ..._matrix_matmul_nbits_prepack.wgsl.template | 13 ++++--- 2 files changed, 28 insertions(+), 21 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc index af02a614f1eed..415d21fffb1fe 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc @@ -238,9 +238,24 @@ Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Te Tensor* y, const uint32_t weight_index, const Tensor* weight_index_indirect) { + // Determine tile sizes first (needed for prepack padding). + const auto& config = supported_subgroup_matrix_configs[config_index]; + uint32_t tile_size_a = 32; + uint32_t tile_size_b = 64; + uint32_t work_group_size = 128; + if (config.M == 8 && config.N == 16 && config.K == 16) { + // 8x16x16 config: 8 subgroups, 256 threads, 64x64 tiles + tile_size_a = 64; + work_group_size = 256; + } else if (config.M == 16 && config.N == 16 && config.K == 16) { + // 16x16x16 config: 4 subgroups, 128 threads, 128x128 tiles + tile_size_a = 128; + tile_size_b = 128; + work_group_size = 128; + } + // If applicable, layout optimization of input matrix A(MxK) can be used for SubgroupMatrixLoad. 
Tensor a_prepack; - const auto& config = supported_subgroup_matrix_configs[config_index]; if (config.needsPrepack) { const auto m = config.M; const auto k = config.K; @@ -250,13 +265,15 @@ Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Te constexpr uint32_t kSubgroupSize = 32; prepack_program.SetWorkgroupSize(kSubgroupSize); - const auto dispatch_group_size_x = (M + m - 1) / m; + // Pad M to workgroup tile size so all subgroups read valid prepacked data. + const uint32_t padded_M = ((M + tile_size_a - 1) / tile_size_a) * tile_size_a; + const auto dispatch_group_size_x = padded_M / m; ORT_ENFORCE(K % k == 0, "K must be a multiple of ", k); const auto dispatch_group_size_y = K / k; // Each workgroup will process one subgroup matrix of size m x k. prepack_program.SetDispatchGroupSize(dispatch_group_size_x, dispatch_group_size_y, 1); - TensorShape a_prepack_shape{dispatch_group_size_x * m, K}; + TensorShape a_prepack_shape{padded_M, K}; a_prepack = context.CreateGPUTensor(a->DataType(), a_prepack_shape); prepack_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, 1}}) .AddOutputs({{&a_prepack, ProgramTensorMetadataDependency::Rank, a_prepack.Shape(), 1}}) @@ -266,9 +283,6 @@ Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Te a = &a_prepack; } - uint32_t tile_size_a = 32; - uint32_t tile_size_b = 64; - uint32_t work_group_size = 128; constexpr uint32_t kU32Components = 4; TensorShape y_shape{1, M, N}; const bool has_zero_points = zero_points != nullptr; @@ -276,16 +290,6 @@ Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Te const bool has_weight_idx_indirect = weight_index_indirect != nullptr; const bool has_weight_idx = weight_index > 0 || has_weight_idx_indirect; SubgroupMatrixMatMulNBitsProgram mul_program{nbits, config_index, has_zero_points, has_bias, has_weight_idx, has_weight_idx_indirect}; - if (config.M == 8 && config.N == 16 && config.K == 16) { - // 8x16x16 config: 8 subgroups, 256 threads, 64x64 tiles - tile_size_a = 64; - work_group_size = 256; - } else if (config.M == 16 && config.N == 16 && config.K == 16) { - // 16x16x16 config: 4 subgroups, 128 threads, 128x128 tiles - tile_size_a = 128; - tile_size_b = 128; - work_group_size = 128; - } mul_program.SetWorkgroupSize(work_group_size); mul_program.SetDispatchGroupSize( (N + tile_size_b - 1) / tile_size_b, diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_prepack.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_prepack.wgsl.template index e3e65b667ff6e..6c4d5e0376f37 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_prepack.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_prepack.wgsl.template @@ -4,7 +4,11 @@ // SubgroupMatrix prepack kernel // Rearranges input matrix A(MxK) so that each subgroup matrix (sg_mat_m x sg_mat_k) // has its elements laid out contiguously in memory for subgroupMatrixLoad. -// OOB rows (beyond M) are zero-filled to avoid undefined behavior. +// +// The prepack buffer is padded to the workgroup tile size (which may exceed M). +// Fully OOB blocks (row_base >= M) are skipped entirely, and partial blocks +// only copy in-bounds rows. The padding region is left uninitialized since the +// matmul shader bounds-checks (r < M) before storing any output. 
#param sg_mat_k #param sg_mat_m @@ -24,16 +28,15 @@ $MAIN { var mat: subgroup_matrix_left = subgroupMatrixLoad>(&input_a, in_offset, false, K); subgroupMatrixStore(&output_a, out_offset, mat, false, kSgMatK); - } else { - // Partial block: some rows are OOB. Use scalar copy with zero-fill. + } else if (row_base < M) { + // Partial block: some rows are OOB. Use scalar copy for in-bounds rows only. for (var r: u32 = local_idx; r < kSgMatM * kSgMatK; r += workgroup_size_x) { let row = r / kSgMatK; let col = r % kSgMatK; if (row_base + row < M) { output_a[out_offset + r] = input_a[in_offset + row * K + col]; - } else { - output_a[out_offset + r] = f16(0.0); } } } + // Fully OOB blocks (row_base >= M): skip entirely. } // MAIN From a3f7051d5f52e3557878d67267086f91b2c8fa9b Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Thu, 23 Apr 2026 14:08:22 +0800 Subject: [PATCH 6/8] Add test coverage for SubgroupMatrixMatMulNBits 16x16x16 path Add M=100, N=256, K=128, block_size=32 test cases to Float16_4b_Accuracy0 and Float16_4b_Accuracy4. These dimensions meet SubgroupMatrix constraints (block_size=32, N%64==0, K%32==0) and M=100 is non-128-aligned, exercising the prepack buffer padding fix from commit 95818212e1. --- onnxruntime/test/contrib_ops/matmul_4bits_test.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index b463aa3a6c363..ff9ab9e3552d9 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -465,6 +465,7 @@ TEST(MatMulNBits, Float16_4b_Accuracy0) { TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); } TEST(MatMulNBits, Float16_4b_Accuracy4) { @@ -495,6 +496,7 @@ TEST(MatMulNBits, Float16_4b_Accuracy4) { TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); // See PR #27412 for details on the following test case, // which is added to cover a specific failure case in the past. From d372b3336673b3261bd6e816adf3fead081dff02 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Thu, 23 Apr 2026 17:27:39 +0800 Subject: [PATCH 7/8] Remove unnecessary workgroupBarrier in 16x16x16 edge-tile output, add N=192 test Remove 63 workgroupBarrier() calls from the bias and non-bias edge-tile output paths in the 16x16x16_128 shader. These barriers are unnecessary because subgroupMatrixStore and subsequent scalar reads from coopmat_stage execute within the same subgroup (lockstep), and each subgroup writes to its own non-overlapping stage_base region. Also add N=192 (partial N tile, not divisible by 128) test cases and remove a duplicate #include in subgroup_matrix_matmul_nbits.h. 
--- .../subgroup_matrix_matmul_nbits.h | 1 - ...ix_matmul_nbits_16x16x16_128.wgsl.template | 67 ++----------------- .../test/contrib_ops/matmul_4bits_test.cc | 2 + 3 files changed, 7 insertions(+), 63 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h index 1b0c1d336545c..810bda950b169 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h @@ -5,7 +5,6 @@ #if !defined(__wasm__) -#include "core/providers/webgpu/program.h" #include "core/providers/webgpu/compute_context.h" #include "core/providers/webgpu/program.h" #include "core/providers/webgpu/shader_helper.h" diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_16x16x16_128.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_16x16x16_128.wgsl.template index 0c49b880e2576..839139950da18 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_16x16x16_128.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_16x16x16_128.wgsl.template @@ -333,9 +333,11 @@ $MAIN { // Helper macro: store one coopmat tile via stage, add bias, write to output // Process all 16 tiles (4 rows x 4 cols) + // No workgroupBarrier needed: subgroupMatrixStore and scalar reads are within the same + // subgroup (lockstep execution), and each subgroup uses its own non-overlapping stage_base. + // Row 0 subgroupMatrixStore(&coopmat_stage, stage_base, c00, false, kSgMatN); - workgroupBarrier(); for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { let r = dr + i / kSgMatN; let c = dc + i % kSgMatN; @@ -344,10 +346,8 @@ $MAIN { output.setByOffset(r * stride + c, val); } } - workgroupBarrier(); subgroupMatrixStore(&coopmat_stage, stage_base, c01, false, kSgMatN); - workgroupBarrier(); for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { let r = dr + i / kSgMatN; let c = dc + 16u + i % kSgMatN; @@ -356,10 +356,8 @@ $MAIN { output.setByOffset(r * stride + c, val); } } - workgroupBarrier(); subgroupMatrixStore(&coopmat_stage, stage_base, c02, false, kSgMatN); - workgroupBarrier(); for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { let r = dr + i / kSgMatN; let c = dc + 32u + i % kSgMatN; @@ -368,10 +366,8 @@ $MAIN { output.setByOffset(r * stride + c, val); } } - workgroupBarrier(); subgroupMatrixStore(&coopmat_stage, stage_base, c03, false, kSgMatN); - workgroupBarrier(); for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { let r = dr + i / kSgMatN; let c = dc + 48u + i % kSgMatN; @@ -380,11 +376,9 @@ $MAIN { output.setByOffset(r * stride + c, val); } } - workgroupBarrier(); // Row 1 subgroupMatrixStore(&coopmat_stage, stage_base, c10, false, kSgMatN); - workgroupBarrier(); for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { let r = dr + 16u + i / kSgMatN; let c = dc + i % kSgMatN; @@ -393,10 +387,8 @@ $MAIN { output.setByOffset(r * stride + c, val); } } - workgroupBarrier(); subgroupMatrixStore(&coopmat_stage, stage_base, c11, false, kSgMatN); - workgroupBarrier(); for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { let r = dr + 16u + i / kSgMatN; let c = dc + 16u + i % kSgMatN; @@ -405,10 +397,8 @@ $MAIN { output.setByOffset(r * stride + c, val); } } - workgroupBarrier(); subgroupMatrixStore(&coopmat_stage, stage_base, c12, 
false, kSgMatN); - workgroupBarrier(); for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { let r = dr + 16u + i / kSgMatN; let c = dc + 32u + i % kSgMatN; @@ -417,10 +407,8 @@ $MAIN { output.setByOffset(r * stride + c, val); } } - workgroupBarrier(); subgroupMatrixStore(&coopmat_stage, stage_base, c13, false, kSgMatN); - workgroupBarrier(); for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { let r = dr + 16u + i / kSgMatN; let c = dc + 48u + i % kSgMatN; @@ -429,11 +417,9 @@ $MAIN { output.setByOffset(r * stride + c, val); } } - workgroupBarrier(); // Row 2 subgroupMatrixStore(&coopmat_stage, stage_base, c20, false, kSgMatN); - workgroupBarrier(); for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { let r = dr + 32u + i / kSgMatN; let c = dc + i % kSgMatN; @@ -442,10 +428,8 @@ $MAIN { output.setByOffset(r * stride + c, val); } } - workgroupBarrier(); subgroupMatrixStore(&coopmat_stage, stage_base, c21, false, kSgMatN); - workgroupBarrier(); for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { let r = dr + 32u + i / kSgMatN; let c = dc + 16u + i % kSgMatN; @@ -454,10 +438,8 @@ $MAIN { output.setByOffset(r * stride + c, val); } } - workgroupBarrier(); subgroupMatrixStore(&coopmat_stage, stage_base, c22, false, kSgMatN); - workgroupBarrier(); for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { let r = dr + 32u + i / kSgMatN; let c = dc + 32u + i % kSgMatN; @@ -466,10 +448,8 @@ $MAIN { output.setByOffset(r * stride + c, val); } } - workgroupBarrier(); subgroupMatrixStore(&coopmat_stage, stage_base, c23, false, kSgMatN); - workgroupBarrier(); for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { let r = dr + 32u + i / kSgMatN; let c = dc + 48u + i % kSgMatN; @@ -478,11 +458,9 @@ $MAIN { output.setByOffset(r * stride + c, val); } } - workgroupBarrier(); // Row 3 subgroupMatrixStore(&coopmat_stage, stage_base, c30, false, kSgMatN); - workgroupBarrier(); for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { let r = dr + 48u + i / kSgMatN; let c = dc + i % kSgMatN; @@ -491,10 +469,8 @@ $MAIN { output.setByOffset(r * stride + c, val); } } - workgroupBarrier(); subgroupMatrixStore(&coopmat_stage, stage_base, c31, false, kSgMatN); - workgroupBarrier(); for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { let r = dr + 48u + i / kSgMatN; let c = dc + 16u + i % kSgMatN; @@ -503,10 +479,8 @@ $MAIN { output.setByOffset(r * stride + c, val); } } - workgroupBarrier(); subgroupMatrixStore(&coopmat_stage, stage_base, c32, false, kSgMatN); - workgroupBarrier(); for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { let r = dr + 48u + i / kSgMatN; let c = dc + 32u + i % kSgMatN; @@ -515,10 +489,8 @@ $MAIN { output.setByOffset(r * stride + c, val); } } - workgroupBarrier(); subgroupMatrixStore(&coopmat_stage, stage_base, c33, false, kSgMatN); - workgroupBarrier(); for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { let r = dr + 48u + i / kSgMatN; let c = dc + 48u + i % kSgMatN; @@ -553,11 +525,12 @@ $MAIN { subgroupMatrixStore(&output, base + 48u * stride + 48u, c33, false, stride); } else { // Edge tile: store via coopmat_stage with bounds checking + // No workgroupBarrier needed: subgroupMatrixStore and scalar reads are within the same + // subgroup (lockstep execution), and each subgroup uses its own non-overlapping stage_base. 
let stage_base = sg_idx * kSgMatM * kSgMatN; // Row 0 subgroupMatrixStore(&coopmat_stage, stage_base, c00, false, kSgMatN); - workgroupBarrier(); for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { let r = dr + i / kSgMatN; let c = dc + i % kSgMatN; @@ -565,10 +538,8 @@ $MAIN { output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); } } - workgroupBarrier(); subgroupMatrixStore(&coopmat_stage, stage_base, c01, false, kSgMatN); - workgroupBarrier(); for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { let r = dr + i / kSgMatN; let c = dc + 16u + i % kSgMatN; @@ -576,10 +547,8 @@ $MAIN { output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); } } - workgroupBarrier(); subgroupMatrixStore(&coopmat_stage, stage_base, c02, false, kSgMatN); - workgroupBarrier(); for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { let r = dr + i / kSgMatN; let c = dc + 32u + i % kSgMatN; @@ -587,10 +556,8 @@ $MAIN { output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); } } - workgroupBarrier(); subgroupMatrixStore(&coopmat_stage, stage_base, c03, false, kSgMatN); - workgroupBarrier(); for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { let r = dr + i / kSgMatN; let c = dc + 48u + i % kSgMatN; @@ -598,11 +565,9 @@ $MAIN { output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); } } - workgroupBarrier(); // Row 1 subgroupMatrixStore(&coopmat_stage, stage_base, c10, false, kSgMatN); - workgroupBarrier(); for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { let r = dr + 16u + i / kSgMatN; let c = dc + i % kSgMatN; @@ -610,10 +575,8 @@ $MAIN { output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); } } - workgroupBarrier(); subgroupMatrixStore(&coopmat_stage, stage_base, c11, false, kSgMatN); - workgroupBarrier(); for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { let r = dr + 16u + i / kSgMatN; let c = dc + 16u + i % kSgMatN; @@ -621,10 +584,8 @@ $MAIN { output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); } } - workgroupBarrier(); subgroupMatrixStore(&coopmat_stage, stage_base, c12, false, kSgMatN); - workgroupBarrier(); for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { let r = dr + 16u + i / kSgMatN; let c = dc + 32u + i % kSgMatN; @@ -632,10 +593,8 @@ $MAIN { output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); } } - workgroupBarrier(); subgroupMatrixStore(&coopmat_stage, stage_base, c13, false, kSgMatN); - workgroupBarrier(); for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { let r = dr + 16u + i / kSgMatN; let c = dc + 48u + i % kSgMatN; @@ -643,11 +602,9 @@ $MAIN { output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); } } - workgroupBarrier(); // Row 2 subgroupMatrixStore(&coopmat_stage, stage_base, c20, false, kSgMatN); - workgroupBarrier(); for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { let r = dr + 32u + i / kSgMatN; let c = dc + i % kSgMatN; @@ -655,10 +612,8 @@ $MAIN { output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); } } - workgroupBarrier(); subgroupMatrixStore(&coopmat_stage, stage_base, c21, false, kSgMatN); - workgroupBarrier(); for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { let r = dr + 32u + i / kSgMatN; let c = dc + 16u + i % kSgMatN; @@ -666,10 +621,8 @@ $MAIN { output.setByOffset(r * 
stride + c, output_element_t(coopmat_stage[stage_base + i])); } } - workgroupBarrier(); subgroupMatrixStore(&coopmat_stage, stage_base, c22, false, kSgMatN); - workgroupBarrier(); for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { let r = dr + 32u + i / kSgMatN; let c = dc + 32u + i % kSgMatN; @@ -677,10 +630,8 @@ $MAIN { output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); } } - workgroupBarrier(); subgroupMatrixStore(&coopmat_stage, stage_base, c23, false, kSgMatN); - workgroupBarrier(); for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { let r = dr + 32u + i / kSgMatN; let c = dc + 48u + i % kSgMatN; @@ -688,11 +639,9 @@ $MAIN { output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); } } - workgroupBarrier(); // Row 3 subgroupMatrixStore(&coopmat_stage, stage_base, c30, false, kSgMatN); - workgroupBarrier(); for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { let r = dr + 48u + i / kSgMatN; let c = dc + i % kSgMatN; @@ -700,10 +649,8 @@ $MAIN { output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); } } - workgroupBarrier(); subgroupMatrixStore(&coopmat_stage, stage_base, c31, false, kSgMatN); - workgroupBarrier(); for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { let r = dr + 48u + i / kSgMatN; let c = dc + 16u + i % kSgMatN; @@ -711,10 +658,8 @@ $MAIN { output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); } } - workgroupBarrier(); subgroupMatrixStore(&coopmat_stage, stage_base, c32, false, kSgMatN); - workgroupBarrier(); for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { let r = dr + 48u + i / kSgMatN; let c = dc + 32u + i % kSgMatN; @@ -722,10 +667,8 @@ $MAIN { output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); } } - workgroupBarrier(); subgroupMatrixStore(&coopmat_stage, stage_base, c33, false, kSgMatN); - workgroupBarrier(); for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { let r = dr + 48u + i / kSgMatN; let c = dc + 48u + i % kSgMatN; diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index ff9ab9e3552d9..435995268c437 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -466,6 +466,7 @@ TEST(MatMulNBits, Float16_4b_Accuracy0) { TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); } TEST(MatMulNBits, Float16_4b_Accuracy4) { @@ -497,6 +498,7 @@ TEST(MatMulNBits, Float16_4b_Accuracy4) { TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); // See PR #27412 for details on the following test case, // which is added to cover a specific failure case in the past. 
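A quick numeric check of the tile and padding arithmetic exercised by the new M=100 and N=192 test cases above. The constants follow the 16x16x16 path in ApplySubgroupMatrixMatMulNBits and the prepack padding introduced in PATCH 5/8; the snippet is a standalone sketch for illustration (the dispatch along M is taken from that commit message), not code from the patches themselves.

    // Standalone sketch of the 16x16x16 dispatch/padding math, checked against the
    // M=100 / N=192 / K=128 test dimensions added in this series.
    #include <cassert>
    #include <cstdint>
    #include <cstdio>

    int main() {
      const uint32_t M = 100, N = 192, K = 128;
      const uint32_t sg_mat_m = 16, sg_mat_k = 16;          // config.M / config.K
      const uint32_t tile_size_a = 128, tile_size_b = 128;  // 16x16x16 config: 128x128 tiles

      // Prepack buffer is padded to the workgroup tile size (PATCH 5/8).
      const uint32_t padded_M = ((M + tile_size_a - 1) / tile_size_a) * tile_size_a;  // 128
      const uint32_t prepack_x = padded_M / sg_mat_m;                                 // 8
      assert(K % sg_mat_k == 0);
      const uint32_t prepack_y = K / sg_mat_k;                                        // 8

      // Matmul workgroups cover ceil(N/tile_size_b) x ceil(M/tile_size_a) tiles
      // (per the PATCH 5/8 commit message for the M dimension).
      const uint32_t mul_x = (N + tile_size_b - 1) / tile_size_b;  // 2 -> N=192 hits the edge-tile path
      const uint32_t mul_y = (M + tile_size_a - 1) / tile_size_a;  // 1 -> rows 100..127 are padding

      std::printf("padded_M=%u prepack=(%u,%u) matmul=(%u,%u)\n",
                  (unsigned)padded_M, (unsigned)prepack_x, (unsigned)prepack_y,
                  (unsigned)mul_x, (unsigned)mul_y);
    }

With these numbers, M=100 forces the prepack buffer to 128 padded rows (the OOB-read fix), and N=192 forces one full and one partial N tile (the bounds-checked edge-tile store path).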
From f061947b7a72060b42f50932cd135c8b57e3ef14 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Thu, 23 Apr 2026 17:38:50 +0800 Subject: [PATCH 8/8] Address PR #28109 review comments - Skip F32 subgroup matrix configs when output is FP16 (symmetric with the existing F16-skip-when-F32 filter) - Fix misleading prepack comment: not all matmul shaders bounds-check edge tiles, so don't promise they do --- .../webgpu/quantization/subgroup_matrix_matmul_nbits.cc | 4 +++- .../subgroup_matrix_matmul_nbits_prepack.wgsl.template | 5 +++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc index 415d21fffb1fe..cdc0f1ded3e45 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc @@ -73,7 +73,9 @@ bool IsSubgroupMatrixConfigSupported(onnxruntime::webgpu::ComputeContext& contex int32_t index = 0; for (const auto& supported_config : supported_subgroup_matrix_configs) { // F16 configs require FP16 output; skip them when output is F32. - if (supported_config.componentType == wgpu::SubgroupMatrixComponentType::F16 && !is_fp16) { + // F32 configs require FP32 output; skip them when output is FP16. + if ((supported_config.componentType == wgpu::SubgroupMatrixComponentType::F16 && !is_fp16) || + (supported_config.componentType == wgpu::SubgroupMatrixComponentType::F32 && is_fp16)) { index++; continue; } diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_prepack.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_prepack.wgsl.template index 6c4d5e0376f37..0172dc223007b 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_prepack.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_prepack.wgsl.template @@ -7,8 +7,9 @@ // // The prepack buffer is padded to the workgroup tile size (which may exceed M). // Fully OOB blocks (row_base >= M) are skipped entirely, and partial blocks -// only copy in-bounds rows. The padding region is left uninitialized since the -// matmul shader bounds-checks (r < M) before storing any output. +// only copy in-bounds rows. Padding corresponding to rows >= M may therefore +// remain uninitialized, so downstream shaders must not rely on it unless they +// handle edge tiles explicitly. #param sg_mat_k #param sg_mat_m
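Taken together, the series leaves the subgroup matrix config selection rule as follows. The sketch below restates it with stand-in types in place of the wgpu/Dawn structs so it compiles on its own; it is only a reader's summary of IsSubgroupMatrixConfigSupported after PATCH 4/8 and 8/8, not code from the patches.

    // Condensed restatement of the final matching rule: filter supported configs by output
    // precision, then require an exact match on component types, MxNxK, and adapter subgroup
    // sizes; the first hit wins, so larger configs take priority.
    #include <cstdint>
    #include <vector>

    enum class ComponentType { F16, F32 };

    struct Config {  // mirrors the fields used by the matching rule (needsPrepack omitted);
                     // subgroup sizes are unused for device-reported entries
      ComponentType componentType, resultComponentType;
      uint32_t M, N, K;
      uint32_t subgroupMinSize, subgroupMaxSize;
    };

    bool IsConfigSupported(const std::vector<Config>& supported,
                           const std::vector<Config>& device_configs,
                           uint32_t adapter_sg_min, uint32_t adapter_sg_max,
                           bool is_fp16, int32_t& config_index) {
      int32_t index = 0;
      for (const auto& s : supported) {
        // Output-type filter: F16 configs need FP16 output, F32 configs need FP32 output.
        if ((s.componentType == ComponentType::F16 && !is_fp16) ||
            (s.componentType == ComponentType::F32 && is_fp16)) {
          index++;
          continue;
        }
        for (const auto& d : device_configs) {
          // Exact match on component types and MxNxK; exact (==) match on adapter subgroup sizes.
          if (d.componentType == s.componentType &&
              d.resultComponentType == s.resultComponentType &&
              d.M == s.M && d.N == s.N && d.K == s.K &&
              adapter_sg_min == s.subgroupMinSize &&
              adapter_sg_max == s.subgroupMaxSize) {
            config_index = index;  // first hit wins -> larger configs take priority
            return true;
          }
        }
        index++;
      }
      return false;
    }

    int main() {
      // Priority-ordered table from the patch: F16 16x16x16, F16 8x16x16, F16 8x8x8, F32 8x8x8.
      std::vector<Config> supported = {
          {ComponentType::F16, ComponentType::F16, 16, 16, 16, 32, 32},
          {ComponentType::F16, ComponentType::F16, 8, 16, 16, 16, 32},
          {ComponentType::F16, ComponentType::F16, 8, 8, 8, 32, 32},
          {ComponentType::F32, ComponentType::F32, 8, 8, 8, 32, 32},
      };
      // An adapter reporting an F16/F16 16x16x16 config with subgroup size 32 (the NVIDIA case
      // in this series) selects index 0 for FP16 output; with is_fp16=false nothing matches here.
      std::vector<Config> device = {{ComponentType::F16, ComponentType::F16, 16, 16, 16, 0, 0}};
      int32_t idx = -1;
      bool ok = IsConfigSupported(supported, device, 32, 32, /*is_fp16=*/true, idx);  // ok, idx==0
      (void)ok;
    }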