diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index 4eb8f5b9a17dd..b4e3991344089 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -172,7 +172,7 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context * @return Status indicating whether the operation was successful or if an error occurred. * * @note Special optimizations are considered: - * - Subgroup matrix multiplication for eligible Apple/Intel GPUs. + * - Subgroup matrix multiplication for GPUs with supported configs. * - DP4A-based multiplication on FP32-only GPUs for specific dimensions and conditions. * - A wide tile program is used when block size, component count, and other criteria are met. * - Otherwise, a default matmul program is used. @@ -227,8 +227,8 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, #if !defined(__wasm__) int32_t subgroup_matrix_config_index = -1; - // apple|intel - Experimental dawn support for subgroup matrix matmul. - if ((M >= kMinMForTileOptimization && !has_weight_idx_indirect) && (context.AdapterInfo().vendor == std::string_view{"apple"} || context.AdapterInfo().vendor == std::string_view{"intel"}) && + // Experimental dawn support for subgroup matrix matmul (vendor-agnostic). + if ((M >= kMinMForTileOptimization && !has_weight_idx_indirect) && CanApplySubgroupMatrixMatMulNBits(context, accuracy_level, block_size, batch_count, N, K, static_cast(nbits), y->DataType() == DataTypeImpl::GetType(), subgroup_matrix_config_index)) { return ApplySubgroupMatrixMatMulNBits(a, b, scales, zero_points, bias, M, N, K, static_cast(nbits), zero_blocks_per_col, subgroup_matrix_config_index, context, y, weight_index, weight_index_indirect); } diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc index 15aa74c6dcd57..cdc0f1ded3e45 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc @@ -2,7 +2,6 @@ // Licensed under the MIT License. #if !defined(__wasm__) -#include #include "contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h" #include "contrib_ops/webgpu/quantization/matmul_nbits_common.h" @@ -45,25 +44,50 @@ static_assert(ValidateComponentTypeName<4>({wgpu::SubgroupMatrixComponentType::F wgpu::SubgroupMatrixComponentType::I32}), "The elements' sequence of ComponentTypeName array do not match wgpu::SubgroupMatrixComponentType"); -// std::tuple -static const std::tuple - intel_supported_subgroup_matrix_configs[] = { - {"xe-2lpg", wgpu::BackendType::Vulkan, wgpu::SubgroupMatrixComponentType::F16, wgpu::SubgroupMatrixComponentType::F16, 8, 16, 16, 16, 32}, - {"xe-3lpg", wgpu::BackendType::Vulkan, wgpu::SubgroupMatrixComponentType::F16, wgpu::SubgroupMatrixComponentType::F16, 8, 16, 16, 16, 32}}; +// Vendor-agnostic subgroup matrix config: {componentType, resultComponentType, M, N, K, subgroupMinSize, subgroupMaxSize, needsPrepack} +// Any GPU reporting a matching config from wgpu::AdapterPropertiesSubgroupMatrixConfigs is supported. 
+struct SupportedSubgroupMatrixConfig { + wgpu::SubgroupMatrixComponentType componentType; + wgpu::SubgroupMatrixComponentType resultComponentType; + uint32_t M; + uint32_t N; + uint32_t K; + uint32_t subgroupMinSize; + uint32_t subgroupMaxSize; + bool needsPrepack; // Whether input A needs layout optimization for subgroupMatrixLoad +}; + +static const SupportedSubgroupMatrixConfig supported_subgroup_matrix_configs[] = { + // 16x16x16 config with 128x128 tiles (NVIDIA Blackwell, subgroup size 32) + {wgpu::SubgroupMatrixComponentType::F16, wgpu::SubgroupMatrixComponentType::F16, 16, 16, 16, 32, 32, true}, + // 8x16x16 config (Intel Xe2/Xe3, subgroup size 16-32) + {wgpu::SubgroupMatrixComponentType::F16, wgpu::SubgroupMatrixComponentType::F16, 8, 16, 16, 16, 32, true}, + // 8x8x8 config (Apple M-series, etc.) + {wgpu::SubgroupMatrixComponentType::F16, wgpu::SubgroupMatrixComponentType::F16, 8, 8, 8, 32, 32, false}, + {wgpu::SubgroupMatrixComponentType::F32, wgpu::SubgroupMatrixComponentType::F32, 8, 8, 8, 32, 32, false}, +}; -bool IsSubgroupMatrixConfigSupportedOnIntel(onnxruntime::webgpu::ComputeContext& context, int32_t& config_index) { +bool IsSubgroupMatrixConfigSupported(onnxruntime::webgpu::ComputeContext& context, bool is_fp16, int32_t& config_index) { const wgpu::AdapterInfo& adapter_info = context.AdapterInfo(); const wgpu::AdapterPropertiesSubgroupMatrixConfigs& subgroup_matrix_configs = context.SubgroupMatrixConfigs(); int32_t index = 0; - for (auto& supported_config : intel_supported_subgroup_matrix_configs) { + for (const auto& supported_config : supported_subgroup_matrix_configs) { + // F16 configs require FP16 output; skip them when output is F32. + // F32 configs require FP32 output; skip them when output is FP16. + if ((supported_config.componentType == wgpu::SubgroupMatrixComponentType::F16 && !is_fp16) || + (supported_config.componentType == wgpu::SubgroupMatrixComponentType::F32 && is_fp16)) { + index++; + continue; + } for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) { - auto& subgroup_matrix_config = subgroup_matrix_configs.configs[i]; - auto&& config = std::make_tuple(adapter_info.architecture, adapter_info.backendType, - subgroup_matrix_config.componentType, subgroup_matrix_config.resultComponentType, - subgroup_matrix_config.M, subgroup_matrix_config.N, subgroup_matrix_config.K, - adapter_info.subgroupMinSize, adapter_info.subgroupMaxSize); - if (config == supported_config) { + const auto& device_config = subgroup_matrix_configs.configs[i]; + if (device_config.componentType == supported_config.componentType && + device_config.resultComponentType == supported_config.resultComponentType && + device_config.M == supported_config.M && + device_config.N == supported_config.N && + device_config.K == supported_config.K && + adapter_info.subgroupMinSize == supported_config.subgroupMinSize && + adapter_info.subgroupMaxSize == supported_config.subgroupMaxSize) { config_index = index; return true; } @@ -117,31 +141,53 @@ Status PrepackProgram::GenerateShaderCode(ShaderHelper& shader) const { WGSL_TEMPLATE_PARAMETER(sg_mat_m, m_)); } -Status GenerateShaderCodeOnIntel(ShaderHelper& shader, +Status GenerateShaderCode16x16x16(ShaderHelper& shader, + const ShaderVariableHelper& b, + const ShaderVariableHelper& scales_b, + const ShaderVariableHelper& output, + uint32_t nbits, int32_t config_index, bool has_zero_points, bool has_bias, bool has_weight_idx, bool has_weight_idx_indirect) { + const auto& config = supported_subgroup_matrix_configs[config_index]; + // Use 
128x128 tile shader for 16x16x16 config (index 0) + return WGSL_TEMPLATE_APPLY(shader, "quantization/subgroup_matrix_matmul_nbits_16x16x16_128.wgsl.template", + WGSL_TEMPLATE_PARAMETER(has_bias, has_bias), + WGSL_TEMPLATE_PARAMETER(has_weight_idx, has_weight_idx), + WGSL_TEMPLATE_PARAMETER(has_weight_idx_indirect, has_weight_idx_indirect), + WGSL_TEMPLATE_PARAMETER(has_zero_points, has_zero_points), + WGSL_TEMPLATE_PARAMETER(n_bits, nbits), + WGSL_TEMPLATE_PARAMETER(output_type_i32, false), + WGSL_TEMPLATE_PARAMETER(sg_mat_k, config.K), + WGSL_TEMPLATE_PARAMETER(sg_mat_m, config.M), + WGSL_TEMPLATE_PARAMETER(sg_mat_n, config.N), + WGSL_TEMPLATE_VARIABLE(input_b, b), + WGSL_TEMPLATE_VARIABLE(output, output), + WGSL_TEMPLATE_VARIABLE(scales_b, scales_b)); +} + +Status GenerateShaderCode8x16x16(ShaderHelper& shader, const ShaderVariableHelper& b, const ShaderVariableHelper& scales_b, const ShaderVariableHelper& output, uint32_t nbits, int32_t config_index, bool has_zero_points, bool has_bias, bool has_weight_idx, bool has_weight_idx_indirect) { - auto& config = intel_supported_subgroup_matrix_configs[config_index]; - return WGSL_TEMPLATE_APPLY(shader, "quantization/subgroup_matrix_matmul_nbits_intel.wgsl.template", + const auto& config = supported_subgroup_matrix_configs[config_index]; + return WGSL_TEMPLATE_APPLY(shader, "quantization/subgroup_matrix_matmul_nbits_8x16x16.wgsl.template", WGSL_TEMPLATE_PARAMETER(has_bias, has_bias), WGSL_TEMPLATE_PARAMETER(has_weight_idx, has_weight_idx), WGSL_TEMPLATE_PARAMETER(has_weight_idx_indirect, has_weight_idx_indirect), WGSL_TEMPLATE_PARAMETER(has_zero_points, has_zero_points), WGSL_TEMPLATE_PARAMETER(n_bits, nbits), WGSL_TEMPLATE_PARAMETER(output_type_i32, false), - WGSL_TEMPLATE_PARAMETER(sg_mat_k, std::get<6>(config)), - WGSL_TEMPLATE_PARAMETER(sg_mat_m, std::get<4>(config)), - WGSL_TEMPLATE_PARAMETER(sg_mat_n, std::get<5>(config)), + WGSL_TEMPLATE_PARAMETER(sg_mat_k, config.K), + WGSL_TEMPLATE_PARAMETER(sg_mat_m, config.M), + WGSL_TEMPLATE_PARAMETER(sg_mat_n, config.N), WGSL_TEMPLATE_VARIABLE(input_b, b), WGSL_TEMPLATE_VARIABLE(output, output), WGSL_TEMPLATE_VARIABLE(scales_b, scales_b)); } -Status GenerateShaderCodeOnApple(ShaderHelper& shader, const ShaderVariableHelper& a, const ShaderVariableHelper& b, - const ShaderVariableHelper& scales_b, - const ShaderVariableHelper& output, uint32_t nbits, bool has_zero_points, bool has_bias, bool has_weight_idx, bool has_weight_idx_indirect) { - return WGSL_TEMPLATE_APPLY(shader, "quantization/subgroup_matrix_matmul_nbits_apple.wgsl.template", +Status GenerateShaderCode8x8x8(ShaderHelper& shader, const ShaderVariableHelper& a, const ShaderVariableHelper& b, + const ShaderVariableHelper& scales_b, + const ShaderVariableHelper& output, uint32_t nbits, bool has_zero_points, bool has_bias, bool has_weight_idx, bool has_weight_idx_indirect) { + return WGSL_TEMPLATE_APPLY(shader, "quantization/subgroup_matrix_matmul_nbits_8x8x8.wgsl.template", WGSL_TEMPLATE_PARAMETER(has_bias, has_bias), WGSL_TEMPLATE_PARAMETER(has_weight_idx, has_weight_idx), WGSL_TEMPLATE_PARAMETER(has_weight_idx_indirect, has_weight_idx_indirect), @@ -169,13 +215,16 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader } const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); - if (!vendor_.compare("apple")) { - return GenerateShaderCodeOnApple(shader, a, b, scales_b, output, nbits_, has_zero_points_, has_bias_, has_weight_idx_, has_weight_idx_indirect_); - } 
else if (!vendor_.compare("intel")) { - return GenerateShaderCodeOnIntel(shader, b, scales_b, output, nbits_, config_index_, has_zero_points_, has_bias_, has_weight_idx_, has_weight_idx_indirect_); + const auto& config = supported_subgroup_matrix_configs[config_index_]; + if (config.M == 8 && config.N == 8 && config.K == 8) { + return GenerateShaderCode8x8x8(shader, a, b, scales_b, output, nbits_, has_zero_points_, has_bias_, has_weight_idx_, has_weight_idx_indirect_); + } else if (config.M == 8 && config.N == 16 && config.K == 16) { + return GenerateShaderCode8x16x16(shader, b, scales_b, output, nbits_, config_index_, has_zero_points_, has_bias_, has_weight_idx_, has_weight_idx_indirect_); + } else if (config.M == 16 && config.N == 16 && config.K == 16) { + return GenerateShaderCode16x16x16(shader, b, scales_b, output, nbits_, config_index_, has_zero_points_, has_bias_, has_weight_idx_, has_weight_idx_indirect_); } else { return Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::NOT_IMPLEMENTED, - "onnxruntime does not support subgroup matrix on this vendor."); + "Unsupported subgroup matrix config dimensions."); } } @@ -191,25 +240,42 @@ Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Te Tensor* y, const uint32_t weight_index, const Tensor* weight_index_indirect) { + // Determine tile sizes first (needed for prepack padding). + const auto& config = supported_subgroup_matrix_configs[config_index]; + uint32_t tile_size_a = 32; + uint32_t tile_size_b = 64; + uint32_t work_group_size = 128; + if (config.M == 8 && config.N == 16 && config.K == 16) { + // 8x16x16 config: 8 subgroups, 256 threads, 64x64 tiles + tile_size_a = 64; + work_group_size = 256; + } else if (config.M == 16 && config.N == 16 && config.K == 16) { + // 16x16x16 config: 4 subgroups, 128 threads, 128x128 tiles + tile_size_a = 128; + tile_size_b = 128; + work_group_size = 128; + } + // If applicable, layout optimization of input matrix A(MxK) can be used for SubgroupMatrixLoad. Tensor a_prepack; - if (context.AdapterInfo().vendor == std::string_view{"intel"}) { - const auto& config = intel_supported_subgroup_matrix_configs[config_index]; - const auto m = std::get<4>(config); - const auto k = std::get<6>(config); + if (config.needsPrepack) { + const auto m = config.M; + const auto k = config.K; // Optimize the layout of input matrix A(MxK) for SubgroupMatrixLoad. PrepackProgram prepack_program{m, k}; constexpr uint32_t kSubgroupSize = 32; prepack_program.SetWorkgroupSize(kSubgroupSize); - const auto dispatch_group_size_x = (M + m - 1) / m; + // Pad M to workgroup tile size so all subgroups read valid prepacked data. + const uint32_t padded_M = ((M + tile_size_a - 1) / tile_size_a) * tile_size_a; + const auto dispatch_group_size_x = padded_M / m; ORT_ENFORCE(K % k == 0, "K must be a multiple of ", k); const auto dispatch_group_size_y = K / k; // Each workgroup will process one subgroup matrix of size m x k. 
prepack_program.SetDispatchGroupSize(dispatch_group_size_x, dispatch_group_size_y, 1); - TensorShape a_prepack_shape{dispatch_group_size_x * m, K}; + TensorShape a_prepack_shape{padded_M, K}; a_prepack = context.CreateGPUTensor(a->DataType(), a_prepack_shape); prepack_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, 1}}) .AddOutputs({{&a_prepack, ProgramTensorMetadataDependency::Rank, a_prepack.Shape(), 1}}) @@ -219,23 +285,16 @@ Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Te a = &a_prepack; } - uint32_t tile_size_a = 32; - uint32_t work_group_size = 128; - constexpr uint32_t kTileSizeB = 64; constexpr uint32_t kU32Components = 4; TensorShape y_shape{1, M, N}; const bool has_zero_points = zero_points != nullptr; const bool has_bias = bias != nullptr; const bool has_weight_idx_indirect = weight_index_indirect != nullptr; const bool has_weight_idx = weight_index > 0 || has_weight_idx_indirect; - SubgroupMatrixMatMulNBitsProgram mul_program{nbits, config_index, context.AdapterInfo().vendor, has_zero_points, has_bias, has_weight_idx, has_weight_idx_indirect}; - if (context.AdapterInfo().vendor == std::string_view{"intel"}) { - tile_size_a = 64; - work_group_size = 256; - } + SubgroupMatrixMatMulNBitsProgram mul_program{nbits, config_index, has_zero_points, has_bias, has_weight_idx, has_weight_idx_indirect}; mul_program.SetWorkgroupSize(work_group_size); mul_program.SetDispatchGroupSize( - (N + kTileSizeB - 1) / kTileSizeB, + (N + tile_size_b - 1) / tile_size_b, (M + tile_size_a - 1) / tile_size_a, 1); mul_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, 1}, {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast(nbits == 4 ? kU32Components : 2 * kU32Components)}, @@ -271,15 +330,16 @@ bool CanApplySubgroupMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& cont bool has_subgroup_matrix = context.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); if (has_subgroup_matrix) { - if (context.AdapterInfo().vendor == std::string_view{"apple"}) { - // For now SubgroupMatrixMatMulNBits is only supported for accuracy level 4, because with Fp16 there are - // some precision issues with subgroupMatrixMultiplyAccumulate. It is possible to support higher accuracy - // by setting compute_precision to Fp32, but that will be slower. For 1K token prefill FP16 Phi 3.5 is around 5s, - // FP32 is around 7s. - has_subgroup_matrix = accuracy_level == 4; - } else if (context.AdapterInfo().vendor == std::string_view{"intel"}) { - // Intel subgroup matrix config is f16-only. - has_subgroup_matrix = is_fp16 && IsSubgroupMatrixConfigSupportedOnIntel(context, config_index); + // Check if the adapter reports a subgroup matrix config we support. + has_subgroup_matrix = IsSubgroupMatrixConfigSupported(context, is_fp16, config_index); + if (has_subgroup_matrix) { + if (context.AdapterInfo().vendor == std::string_view{"apple"}) { + // For now SubgroupMatrixMatMulNBits is only supported for accuracy level 4, because with Fp16 there are + // some precision issues with subgroupMatrixMultiplyAccumulate. It is possible to support higher accuracy + // by setting compute_precision to Fp32, but that will be slower. For 1K token prefill FP16 Phi 3.5 is around 5s, + // FP32 is around 7s. 
+ has_subgroup_matrix = accuracy_level == 4; + } } } diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h index 4ee800cafeb0d..810bda950b169 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h @@ -5,9 +5,6 @@ #if !defined(__wasm__) -#include - -#include "core/providers/webgpu/program.h" #include "core/providers/webgpu/compute_context.h" #include "core/providers/webgpu/program.h" #include "core/providers/webgpu/shader_helper.h" @@ -21,11 +18,10 @@ using namespace onnxruntime::webgpu; class SubgroupMatrixMatMulNBitsProgram final : public Program { public: - SubgroupMatrixMatMulNBitsProgram(uint32_t nbits, int32_t config_index, const wgpu::StringView& vendor, bool has_zero_points, bool has_bias, bool has_weight_idx, bool has_weight_idx_indirect) + SubgroupMatrixMatMulNBitsProgram(uint32_t nbits, int32_t config_index, bool has_zero_points, bool has_bias, bool has_weight_idx, bool has_weight_idx_indirect) : Program{"SubgroupMatrixMatMulNBits"}, nbits_(nbits), config_index_(config_index), - vendor_(vendor), has_zero_points_(has_zero_points), has_bias_(has_bias), has_weight_idx_{has_weight_idx}, @@ -41,7 +37,6 @@ class SubgroupMatrixMatMulNBitsProgram final : public Program buf_b: array; // 128 * 40 = 5120 elements = 10240 bytes + +// Small staging buffer for edge-tile output (one 16x16 tile per subgroup) +var coopmat_stage: array; // 16*16*4 = 1024 elements = 2048 bytes +// Total shared memory: 10240 + 2048 = 12288 bytes + +#if n_bits == 4 +// Dequantize 4-bit quantized B weights into buf_b with padded stride. +// First half: rows 0..63, Second half: rows 64..127 +// 128 threads, 2 threads per N row, each handles 16 K elements. 
+fn dequant_b_to_buf(n_base: u32, k_idx: u32, n_idx_0: u32, k_chunk: u32) { +#if has_weight_idx +#if has_weight_idx_indirect + let actual_weight_idx = weight_index_indirect[uniforms.weight_idx]; +#else + let actual_weight_idx = uniforms.weight_idx; +#endif + let weight_offset = actual_weight_idx * uniforms.K * uniforms.N; + let scale_offset = actual_weight_idx * uniforms.N * (uniforms.K / kQuantizationBlockSize); +#else + const weight_offset : u32 = 0; + const scale_offset : u32 = 0; +#endif + let k_offset = k_chunk * 16u; + + // First half: rows 0..63 + { + let global_n = n_base + n_idx_0; + let buf_base = n_idx_0 * SHMEM_STRIDE + k_offset; + if (global_n < uniforms.N) { + let packed_idx = u32((global_n * uniforms.K + k_idx + k_offset) / 8) + weight_offset / 8; + let scale_idx = (global_n * uniforms.K + k_idx + k_offset) / kQuantizationBlockSize + scale_offset; + let scale = f16(scales_b.getByOffset(scale_idx)); + let zero = mm_read_zero( + global_n, (k_idx + k_offset) / kQuantizationBlockSize, uniforms.N, uniforms.zero_blocks_per_col); + for (var step: u32 = 0; step < 2u; step++) { + let packed_weights = input_b.getByOffset(packed_idx + step); + let w_lo = (vec4(unpack4xU8(packed_weights & 0x0F0F0F0Fu)) - vec4(zero)) * scale; + let w_hi = (vec4(unpack4xU8((packed_weights >> 4u) & 0x0F0F0F0Fu)) - vec4(zero)) * scale; + let b = buf_base + step * 8u; + buf_b[b] = w_lo[0]; buf_b[b + 1] = w_hi[0]; + buf_b[b + 2] = w_lo[1]; buf_b[b + 3] = w_hi[1]; + buf_b[b + 4] = w_lo[2]; buf_b[b + 5] = w_hi[2]; + buf_b[b + 6] = w_lo[3]; buf_b[b + 7] = w_hi[3]; + } + } else { + for (var i: u32 = 0; i < 16u; i++) { + buf_b[buf_base + i] = f16(0.0); + } + } + } + // Second half: rows 64..127 + { + let n_idx_1 = n_idx_0 + 64u; + let global_n = n_base + n_idx_1; + let buf_base = n_idx_1 * SHMEM_STRIDE + k_offset; + if (global_n < uniforms.N) { + let packed_idx = u32((global_n * uniforms.K + k_idx + k_offset) / 8) + weight_offset / 8; + let scale_idx = (global_n * uniforms.K + k_idx + k_offset) / kQuantizationBlockSize + scale_offset; + let scale = f16(scales_b.getByOffset(scale_idx)); + let zero = mm_read_zero( + global_n, (k_idx + k_offset) / kQuantizationBlockSize, uniforms.N, uniforms.zero_blocks_per_col); + for (var step: u32 = 0; step < 2u; step++) { + let packed_weights = input_b.getByOffset(packed_idx + step); + let w_lo = (vec4(unpack4xU8(packed_weights & 0x0F0F0F0Fu)) - vec4(zero)) * scale; + let w_hi = (vec4(unpack4xU8((packed_weights >> 4u) & 0x0F0F0F0Fu)) - vec4(zero)) * scale; + let b = buf_base + step * 8u; + buf_b[b] = w_lo[0]; buf_b[b + 1] = w_hi[0]; + buf_b[b + 2] = w_lo[1]; buf_b[b + 3] = w_hi[1]; + buf_b[b + 4] = w_lo[2]; buf_b[b + 5] = w_hi[2]; + buf_b[b + 6] = w_lo[3]; buf_b[b + 7] = w_hi[3]; + } + } else { + for (var i: u32 = 0; i < 16u; i++) { + buf_b[buf_base + i] = f16(0.0); + } + } + } +} +#endif + +#if n_bits == 8 +// Dequantize 8-bit quantized B weights into buf_b with padded stride. 
+fn dequant_b_to_buf(n_base: u32, k_idx: u32, n_idx_0: u32, k_chunk: u32) { +#if has_weight_idx +#if has_weight_idx_indirect + let actual_weight_idx = weight_index_indirect[uniforms.weight_idx]; +#else + let actual_weight_idx = uniforms.weight_idx; +#endif + let weight_offset = actual_weight_idx * uniforms.K * uniforms.N; + let scale_offset = actual_weight_idx * uniforms.N * (uniforms.K / kQuantizationBlockSize); +#else + const weight_offset : u32 = 0; + const scale_offset : u32 = 0; +#endif + let k_offset = k_chunk * 16u; + + // First half: rows 0..63 + { + let global_n = n_base + n_idx_0; + let buf_base = n_idx_0 * SHMEM_STRIDE + k_offset; + if (global_n < uniforms.N) { + let packed_idx = u32((global_n * uniforms.K + k_idx + k_offset) / 8) + weight_offset / 8; + let scale_idx = (global_n * uniforms.K + k_idx + k_offset) / kQuantizationBlockSize + scale_offset; + let scale = f16(scales_b.getByOffset(scale_idx)); + let zero = mm_read_zero( + global_n, (k_idx + k_offset) / kQuantizationBlockSize, uniforms.N, uniforms.zero_blocks_per_col); + for (var step: u32 = 0; step < 2u; step++) { + let packed_weights = input_b.getByOffset(packed_idx + step); + let w_lo = (vec4(unpack4xU8(packed_weights[0])) - vec4(zero)) * scale; + let w_hi = (vec4(unpack4xU8(packed_weights[1])) - vec4(zero)) * scale; + let b = buf_base + step * 8u; + buf_b[b] = w_lo[0]; buf_b[b + 1] = w_lo[1]; + buf_b[b + 2] = w_lo[2]; buf_b[b + 3] = w_lo[3]; + buf_b[b + 4] = w_hi[0]; buf_b[b + 5] = w_hi[1]; + buf_b[b + 6] = w_hi[2]; buf_b[b + 7] = w_hi[3]; + } + } else { + for (var i: u32 = 0; i < 16u; i++) { + buf_b[buf_base + i] = f16(0.0); + } + } + } + // Second half: rows 64..127 + { + let n_idx_1 = n_idx_0 + 64u; + let global_n = n_base + n_idx_1; + let buf_base = n_idx_1 * SHMEM_STRIDE + k_offset; + if (global_n < uniforms.N) { + let packed_idx = u32((global_n * uniforms.K + k_idx + k_offset) / 8) + weight_offset / 8; + let scale_idx = (global_n * uniforms.K + k_idx + k_offset) / kQuantizationBlockSize + scale_offset; + let scale = f16(scales_b.getByOffset(scale_idx)); + let zero = mm_read_zero( + global_n, (k_idx + k_offset) / kQuantizationBlockSize, uniforms.N, uniforms.zero_blocks_per_col); + for (var step: u32 = 0; step < 2u; step++) { + let packed_weights = input_b.getByOffset(packed_idx + step); + let w_lo = (vec4(unpack4xU8(packed_weights[0])) - vec4(zero)) * scale; + let w_hi = (vec4(unpack4xU8(packed_weights[1])) - vec4(zero)) * scale; + let b = buf_base + step * 8u; + buf_b[b] = w_lo[0]; buf_b[b + 1] = w_lo[1]; + buf_b[b + 2] = w_lo[2]; buf_b[b + 3] = w_lo[3]; + buf_b[b + 4] = w_hi[0]; buf_b[b + 5] = w_hi[1]; + buf_b[b + 6] = w_hi[2]; buf_b[b + 7] = w_hi[3]; + } + } else { + for (var i: u32 = 0; i < 16u; i++) { + buf_b[buf_base + i] = f16(0.0); + } + } + } +} +#endif + +$MAIN { + let global_base_m = workgroup_id.y * BM; + let global_base_n = workgroup_id.x * BN; + + let sg_idx = u32(local_idx / sg_size); + let sg_tid = local_idx % sg_size; + let warp_r = sg_idx % 2u; // 0 or 1 (row in 2x2 grid) + let warp_c = sg_idx / 2u; // 0 or 1 (col in 2x2 grid) + + // Each subgroup's A starts at a different M offset in the prepacked layout + // Subgroup sg_idx handles rows: warp_r * WM .. 
warp_r * WM + WM (within the BM tile) + // But in the prepacked layout, the subgroups are numbered 0..7 (kSgMatCountM = BM/kSgMatM = 8) + // Subgroups 0,1 handle rows 0..63 (warp_r=0), subgroups 2,3 handle rows 64..127 (warp_r=1) + // Within each 64-row band, we have 4 coopmat rows of 16 + + // For prepacked A: sg_mat index = (workgroup_y * kSgMatCountM + first_coopmat_row_in_warp) * sg_mat_count_k + let sg_mat_count_k = uniforms.K / kSgMatK; + + // 16 accumulator tiles (4x4 grid of 16x16) + var c00: subgroup_matrix_result; + var c01: subgroup_matrix_result; + var c02: subgroup_matrix_result; + var c03: subgroup_matrix_result; + var c10: subgroup_matrix_result; + var c11: subgroup_matrix_result; + var c12: subgroup_matrix_result; + var c13: subgroup_matrix_result; + var c20: subgroup_matrix_result; + var c21: subgroup_matrix_result; + var c22: subgroup_matrix_result; + var c23: subgroup_matrix_result; + var c30: subgroup_matrix_result; + var c31: subgroup_matrix_result; + var c32: subgroup_matrix_result; + var c33: subgroup_matrix_result; + + // Precompute A offsets for the 4 coopmat rows this subgroup handles + // In prepacked layout, each subgroup matrix tile is kSgMatSizeLeft contiguous elements + // The 4 coopmat rows are at (warp_r * 4 + 0..3) in the M dimension + var sg_mat_offset_a0 = (workgroup_id.y * kSgMatCountM + warp_r * 4u + 0u) * sg_mat_count_k * kSgMatSizeLeft; + var sg_mat_offset_a1 = (workgroup_id.y * kSgMatCountM + warp_r * 4u + 1u) * sg_mat_count_k * kSgMatSizeLeft; + var sg_mat_offset_a2 = (workgroup_id.y * kSgMatCountM + warp_r * 4u + 2u) * sg_mat_count_k * kSgMatSizeLeft; + var sg_mat_offset_a3 = (workgroup_id.y * kSgMatCountM + warp_r * 4u + 3u) * sg_mat_count_k * kSgMatSizeLeft; + + // Main K-loop + for (var k_block: u32 = 0; k_block < uniforms.K; k_block += BK) { + // Load B (Q4_K/Q8 weights) to buf_b with padded stride + dequant_b_to_buf(global_base_n, k_block, local_idx / 2u, local_idx % 2u); + workgroupBarrier(); + + // Compute: load B once per k-step, iterate 4 A rows + for (var k: u32 = 0; k < BK; k += kSgMatK) { + let b_base = (warp_c * WN) * SHMEM_STRIDE + k; + + // Load 4 B tiles from shared memory (column-major, padded stride) + var b0: subgroup_matrix_right = + subgroupMatrixLoad>(&buf_b, b_base, true, SHMEM_STRIDE); + var b1: subgroup_matrix_right = + subgroupMatrixLoad>(&buf_b, b_base + 16u * SHMEM_STRIDE, true, SHMEM_STRIDE); + var b2: subgroup_matrix_right = + subgroupMatrixLoad>(&buf_b, b_base + 32u * SHMEM_STRIDE, true, SHMEM_STRIDE); + var b3: subgroup_matrix_right = + subgroupMatrixLoad>(&buf_b, b_base + 48u * SHMEM_STRIDE, true, SHMEM_STRIDE); + + // A row 0: load from global prepacked memory + var a0: subgroup_matrix_left = + subgroupMatrixLoad>( + &input_a, sg_mat_offset_a0, false, kSgMatK); + sg_mat_offset_a0 += kSgMatSizeLeft; + c00 = subgroupMatrixMultiplyAccumulate(a0, b0, c00); + c01 = subgroupMatrixMultiplyAccumulate(a0, b1, c01); + c02 = subgroupMatrixMultiplyAccumulate(a0, b2, c02); + c03 = subgroupMatrixMultiplyAccumulate(a0, b3, c03); + + // A row 1 + var a1: subgroup_matrix_left = + subgroupMatrixLoad>( + &input_a, sg_mat_offset_a1, false, kSgMatK); + sg_mat_offset_a1 += kSgMatSizeLeft; + c10 = subgroupMatrixMultiplyAccumulate(a1, b0, c10); + c11 = subgroupMatrixMultiplyAccumulate(a1, b1, c11); + c12 = subgroupMatrixMultiplyAccumulate(a1, b2, c12); + c13 = subgroupMatrixMultiplyAccumulate(a1, b3, c13); + + // A row 2 + var a2: subgroup_matrix_left = + subgroupMatrixLoad>( + &input_a, sg_mat_offset_a2, false, kSgMatK); + 
sg_mat_offset_a2 += kSgMatSizeLeft; + c20 = subgroupMatrixMultiplyAccumulate(a2, b0, c20); + c21 = subgroupMatrixMultiplyAccumulate(a2, b1, c21); + c22 = subgroupMatrixMultiplyAccumulate(a2, b2, c22); + c23 = subgroupMatrixMultiplyAccumulate(a2, b3, c23); + + // A row 3 + var a3: subgroup_matrix_left = + subgroupMatrixLoad>( + &input_a, sg_mat_offset_a3, false, kSgMatK); + sg_mat_offset_a3 += kSgMatSizeLeft; + c30 = subgroupMatrixMultiplyAccumulate(a3, b0, c30); + c31 = subgroupMatrixMultiplyAccumulate(a3, b1, c31); + c32 = subgroupMatrixMultiplyAccumulate(a3, b2, c32); + c33 = subgroupMatrixMultiplyAccumulate(a3, b3, c33); + } + workgroupBarrier(); + } + + // Write out results + let dr = global_base_m + warp_r * WM; + let dc = global_base_n + warp_c * WN; + let stride = uniforms.N; + + // Workgroup-uniform bounds check (based on workgroup_id, not local_idx) + let tile_m_in = global_base_m + BM <= uniforms.M; + let tile_n_in = global_base_n + BN <= uniforms.N; + let tile_in_bounds = tile_m_in && tile_n_in; + +#if has_bias + // Bias path: store all 16 tiles via coopmat_stage, add bias, write to output + let stage_base = sg_idx * kSgMatM * kSgMatN; +#if has_weight_idx_indirect + let bias_offset = weight_index_indirect[uniforms.weight_idx] * uniforms.N; +#elif has_weight_idx + let bias_offset = uniforms.weight_idx * uniforms.N; +#else + const bias_offset: u32 = 0; +#endif + + // Helper macro: store one coopmat tile via stage, add bias, write to output + // Process all 16 tiles (4 rows x 4 cols) + // No workgroupBarrier needed: subgroupMatrixStore and scalar reads are within the same + // subgroup (lockstep execution), and each subgroup uses its own non-overlapping stage_base. + + // Row 0 + subgroupMatrixStore(&coopmat_stage, stage_base, c00, false, kSgMatN); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + i / kSgMatN; + let c = dc + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + let val = output_element_t(coopmat_stage[stage_base + i]) + bias[bias_offset + c]; + output.setByOffset(r * stride + c, val); + } + } + + subgroupMatrixStore(&coopmat_stage, stage_base, c01, false, kSgMatN); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + i / kSgMatN; + let c = dc + 16u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + let val = output_element_t(coopmat_stage[stage_base + i]) + bias[bias_offset + c]; + output.setByOffset(r * stride + c, val); + } + } + + subgroupMatrixStore(&coopmat_stage, stage_base, c02, false, kSgMatN); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + i / kSgMatN; + let c = dc + 32u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + let val = output_element_t(coopmat_stage[stage_base + i]) + bias[bias_offset + c]; + output.setByOffset(r * stride + c, val); + } + } + + subgroupMatrixStore(&coopmat_stage, stage_base, c03, false, kSgMatN); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + i / kSgMatN; + let c = dc + 48u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + let val = output_element_t(coopmat_stage[stage_base + i]) + bias[bias_offset + c]; + output.setByOffset(r * stride + c, val); + } + } + + // Row 1 + subgroupMatrixStore(&coopmat_stage, stage_base, c10, false, kSgMatN); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 16u + i / kSgMatN; + let c = dc + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + let val = output_element_t(coopmat_stage[stage_base + i]) + 
bias[bias_offset + c]; + output.setByOffset(r * stride + c, val); + } + } + + subgroupMatrixStore(&coopmat_stage, stage_base, c11, false, kSgMatN); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 16u + i / kSgMatN; + let c = dc + 16u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + let val = output_element_t(coopmat_stage[stage_base + i]) + bias[bias_offset + c]; + output.setByOffset(r * stride + c, val); + } + } + + subgroupMatrixStore(&coopmat_stage, stage_base, c12, false, kSgMatN); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 16u + i / kSgMatN; + let c = dc + 32u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + let val = output_element_t(coopmat_stage[stage_base + i]) + bias[bias_offset + c]; + output.setByOffset(r * stride + c, val); + } + } + + subgroupMatrixStore(&coopmat_stage, stage_base, c13, false, kSgMatN); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 16u + i / kSgMatN; + let c = dc + 48u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + let val = output_element_t(coopmat_stage[stage_base + i]) + bias[bias_offset + c]; + output.setByOffset(r * stride + c, val); + } + } + + // Row 2 + subgroupMatrixStore(&coopmat_stage, stage_base, c20, false, kSgMatN); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 32u + i / kSgMatN; + let c = dc + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + let val = output_element_t(coopmat_stage[stage_base + i]) + bias[bias_offset + c]; + output.setByOffset(r * stride + c, val); + } + } + + subgroupMatrixStore(&coopmat_stage, stage_base, c21, false, kSgMatN); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 32u + i / kSgMatN; + let c = dc + 16u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + let val = output_element_t(coopmat_stage[stage_base + i]) + bias[bias_offset + c]; + output.setByOffset(r * stride + c, val); + } + } + + subgroupMatrixStore(&coopmat_stage, stage_base, c22, false, kSgMatN); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 32u + i / kSgMatN; + let c = dc + 32u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + let val = output_element_t(coopmat_stage[stage_base + i]) + bias[bias_offset + c]; + output.setByOffset(r * stride + c, val); + } + } + + subgroupMatrixStore(&coopmat_stage, stage_base, c23, false, kSgMatN); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 32u + i / kSgMatN; + let c = dc + 48u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + let val = output_element_t(coopmat_stage[stage_base + i]) + bias[bias_offset + c]; + output.setByOffset(r * stride + c, val); + } + } + + // Row 3 + subgroupMatrixStore(&coopmat_stage, stage_base, c30, false, kSgMatN); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 48u + i / kSgMatN; + let c = dc + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + let val = output_element_t(coopmat_stage[stage_base + i]) + bias[bias_offset + c]; + output.setByOffset(r * stride + c, val); + } + } + + subgroupMatrixStore(&coopmat_stage, stage_base, c31, false, kSgMatN); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 48u + i / kSgMatN; + let c = dc + 16u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + let val = output_element_t(coopmat_stage[stage_base + i]) + bias[bias_offset + c]; + output.setByOffset(r * stride + c, 
val); + } + } + + subgroupMatrixStore(&coopmat_stage, stage_base, c32, false, kSgMatN); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 48u + i / kSgMatN; + let c = dc + 32u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + let val = output_element_t(coopmat_stage[stage_base + i]) + bias[bias_offset + c]; + output.setByOffset(r * stride + c, val); + } + } + + subgroupMatrixStore(&coopmat_stage, stage_base, c33, false, kSgMatN); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 48u + i / kSgMatN; + let c = dc + 48u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + let val = output_element_t(coopmat_stage[stage_base + i]) + bias[bias_offset + c]; + output.setByOffset(r * stride + c, val); + } + } +#else + // Non-bias path + if (tile_in_bounds) { + // Fast path: entire 128x128 tile is in bounds - direct store to global + let base = dr * stride + dc; + subgroupMatrixStore(&output, base, c00, false, stride); + subgroupMatrixStore(&output, base + 16u, c01, false, stride); + subgroupMatrixStore(&output, base + 32u, c02, false, stride); + subgroupMatrixStore(&output, base + 48u, c03, false, stride); + + subgroupMatrixStore(&output, base + 16u * stride, c10, false, stride); + subgroupMatrixStore(&output, base + 16u * stride + 16u, c11, false, stride); + subgroupMatrixStore(&output, base + 16u * stride + 32u, c12, false, stride); + subgroupMatrixStore(&output, base + 16u * stride + 48u, c13, false, stride); + + subgroupMatrixStore(&output, base + 32u * stride, c20, false, stride); + subgroupMatrixStore(&output, base + 32u * stride + 16u, c21, false, stride); + subgroupMatrixStore(&output, base + 32u * stride + 32u, c22, false, stride); + subgroupMatrixStore(&output, base + 32u * stride + 48u, c23, false, stride); + + subgroupMatrixStore(&output, base + 48u * stride, c30, false, stride); + subgroupMatrixStore(&output, base + 48u * stride + 16u, c31, false, stride); + subgroupMatrixStore(&output, base + 48u * stride + 32u, c32, false, stride); + subgroupMatrixStore(&output, base + 48u * stride + 48u, c33, false, stride); + } else { + // Edge tile: store via coopmat_stage with bounds checking + // No workgroupBarrier needed: subgroupMatrixStore and scalar reads are within the same + // subgroup (lockstep execution), and each subgroup uses its own non-overlapping stage_base. 
+ let stage_base = sg_idx * kSgMatM * kSgMatN; + + // Row 0 + subgroupMatrixStore(&coopmat_stage, stage_base, c00, false, kSgMatN); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + i / kSgMatN; + let c = dc + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); + } + } + + subgroupMatrixStore(&coopmat_stage, stage_base, c01, false, kSgMatN); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + i / kSgMatN; + let c = dc + 16u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); + } + } + + subgroupMatrixStore(&coopmat_stage, stage_base, c02, false, kSgMatN); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + i / kSgMatN; + let c = dc + 32u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); + } + } + + subgroupMatrixStore(&coopmat_stage, stage_base, c03, false, kSgMatN); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + i / kSgMatN; + let c = dc + 48u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); + } + } + + // Row 1 + subgroupMatrixStore(&coopmat_stage, stage_base, c10, false, kSgMatN); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 16u + i / kSgMatN; + let c = dc + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); + } + } + + subgroupMatrixStore(&coopmat_stage, stage_base, c11, false, kSgMatN); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 16u + i / kSgMatN; + let c = dc + 16u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); + } + } + + subgroupMatrixStore(&coopmat_stage, stage_base, c12, false, kSgMatN); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 16u + i / kSgMatN; + let c = dc + 32u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); + } + } + + subgroupMatrixStore(&coopmat_stage, stage_base, c13, false, kSgMatN); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 16u + i / kSgMatN; + let c = dc + 48u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); + } + } + + // Row 2 + subgroupMatrixStore(&coopmat_stage, stage_base, c20, false, kSgMatN); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 32u + i / kSgMatN; + let c = dc + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); + } + } + + subgroupMatrixStore(&coopmat_stage, stage_base, c21, false, kSgMatN); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 32u + i / kSgMatN; + let c = dc + 16u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); + } + } + + subgroupMatrixStore(&coopmat_stage, stage_base, c22, 
false, kSgMatN); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 32u + i / kSgMatN; + let c = dc + 32u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); + } + } + + subgroupMatrixStore(&coopmat_stage, stage_base, c23, false, kSgMatN); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 32u + i / kSgMatN; + let c = dc + 48u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); + } + } + + // Row 3 + subgroupMatrixStore(&coopmat_stage, stage_base, c30, false, kSgMatN); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 48u + i / kSgMatN; + let c = dc + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); + } + } + + subgroupMatrixStore(&coopmat_stage, stage_base, c31, false, kSgMatN); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 48u + i / kSgMatN; + let c = dc + 16u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); + } + } + + subgroupMatrixStore(&coopmat_stage, stage_base, c32, false, kSgMatN); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 48u + i / kSgMatN; + let c = dc + 32u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); + } + } + + subgroupMatrixStore(&coopmat_stage, stage_base, c33, false, kSgMatN); + for (var i: u32 = sg_tid; i < kSgMatM * kSgMatN; i += sg_size) { + let r = dr + 48u + i / kSgMatN; + let c = dc + 48u + i % kSgMatN; + if (r < uniforms.M && c < uniforms.N) { + output.setByOffset(r * stride + c, output_element_t(coopmat_stage[stage_base + i])); + } + } + } +#endif +} // MAIN diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_intel.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_8x16x16.wgsl.template similarity index 100% rename from onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_intel.wgsl.template rename to onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_8x16x16.wgsl.template diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_apple.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_8x8x8.wgsl.template similarity index 100% rename from onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_apple.wgsl.template rename to onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_8x8x8.wgsl.template diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_prepack.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_prepack.wgsl.template index 3b8f9884da927..0172dc223007b 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_prepack.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits_prepack.wgsl.template @@ -1,9 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-// Intel SubgroupMatrix prepack kernel +// SubgroupMatrix prepack kernel // Rearranges input matrix A(MxK) so that each subgroup matrix (sg_mat_m x sg_mat_k) // has its elements laid out contiguously in memory for subgroupMatrixLoad. +// +// The prepack buffer is padded to the workgroup tile size (which may exceed M). +// Fully OOB blocks (row_base >= M) are skipped entirely, and partial blocks +// only copy in-bounds rows. Padding corresponding to rows >= M may therefore +// remain uninitialized, so downstream shaders must not rely on it unless they +// handle edge tiles explicitly. #param sg_mat_k #param sg_mat_m @@ -14,11 +20,24 @@ const kSgMatK: u32 = u32(sg_mat_k); $MAIN { let M = uniforms.M; let K = uniforms.K; - let in_offset = workgroup_id.x * kSgMatM * K + workgroup_id.y * kSgMatK; + let row_base = workgroup_id.x * kSgMatM; + let in_offset = row_base * K + workgroup_id.y * kSgMatK; let out_offset = (workgroup_id.x * K / kSgMatK + workgroup_id.y) * kSgMatM * kSgMatK; - // Syntax: subgroupMatrixLoad src_ptr, src_offset, is_col_major, src_stride - var mat: subgroup_matrix_left = - subgroupMatrixLoad>(&input_a, in_offset, false, uniforms.K); - subgroupMatrixStore(&output_a, out_offset, mat, false, kSgMatK); + if (row_base + kSgMatM <= M) { + // All rows in this block are within bounds - use fast subgroupMatrixLoad. + var mat: subgroup_matrix_left = + subgroupMatrixLoad>(&input_a, in_offset, false, K); + subgroupMatrixStore(&output_a, out_offset, mat, false, kSgMatK); + } else if (row_base < M) { + // Partial block: some rows are OOB. Use scalar copy for in-bounds rows only. + for (var r: u32 = local_idx; r < kSgMatM * kSgMatK; r += workgroup_size_x) { + let row = r / kSgMatK; + let col = r % kSgMatK; + if (row_base + row < M) { + output_a[out_offset + r] = input_a[in_offset + row * K + col]; + } + } + } + // Fully OOB blocks (row_base >= M): skip entirely. } // MAIN diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index b463aa3a6c363..435995268c437 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -465,6 +465,8 @@ TEST(MatMulNBits, Float16_4b_Accuracy0) { TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); } TEST(MatMulNBits, Float16_4b_Accuracy4) { @@ -495,6 +497,8 @@ TEST(MatMulNBits, Float16_4b_Accuracy4) { TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); + TestMatMulNBitsTyped(); // See PR #27412 for details on the following test case, // which is added to cover a specific failure case in the past.
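
Reviewer note (illustration only, not part of the patch): the prepack pass introduced above rearranges input A (M x K, row-major) so that each sg_mat_m x sg_mat_k subgroup-matrix tile is stored contiguously, with M padded up to the workgroup tile size (tile_size_a). Below is a minimal CPU-side sketch of that layout under those assumptions; the helper name PrepackA, the float element type, and the zero-filled padding are hypothetical (per the template comment, the GPU kernel may leave padded rows uninitialized).

#include <cstddef>
#include <cstdint>
#include <vector>

// CPU reference of the prepack layout: tiles are ordered row-major over
// (tile_m, tile_k), and each tile stores its sg_mat_m x sg_mat_k elements
// contiguously in row-major order, matching out_offset in the prepack shader.
// Assumes K is a multiple of sg_mat_k and padded_m is a multiple of sg_mat_m.
std::vector<float> PrepackA(const std::vector<float>& a, uint32_t M, uint32_t K,
                            uint32_t sg_mat_m, uint32_t sg_mat_k, uint32_t padded_m) {
  std::vector<float> out(static_cast<size_t>(padded_m) * K, 0.0f);
  const uint32_t tiles_k = K / sg_mat_k;
  for (uint32_t tm = 0; tm < padded_m / sg_mat_m; ++tm) {
    for (uint32_t tk = 0; tk < tiles_k; ++tk) {
      // Base offset of this tile in the prepacked buffer.
      const size_t out_base =
          (static_cast<size_t>(tm) * tiles_k + tk) * sg_mat_m * sg_mat_k;
      for (uint32_t r = 0; r < sg_mat_m; ++r) {
        const uint32_t row = tm * sg_mat_m + r;
        if (row >= M) continue;  // rows beyond M stay zero-filled (padding)
        for (uint32_t c = 0; c < sg_mat_k; ++c) {
          out[out_base + static_cast<size_t>(r) * sg_mat_k + c] =
              a[static_cast<size_t>(row) * K + static_cast<size_t>(tk) * sg_mat_k + c];
        }
      }
    }
  }
  return out;
}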
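
Worked example for the dispatch sizing in ApplySubgroupMatrixMatMulNBits: with the 16x16x16 config (sg_mat_m = sg_mat_k = 16, tile_size_a = tile_size_b = 128, work_group_size = 128) and M = 300, N = 4096, K = 4096, the prepack pass pads M to padded_M = ((300 + 127) / 128) * 128 = 384 and dispatches (384 / 16) x (4096 / 16) = 24 x 256 workgroups; the matmul pass then dispatches ((4096 + 127) / 128) x ((300 + 127) / 128) = 32 x 3 workgroups, with each workgroup's four subgroups covering one 128x128 output tile.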