From 10d490c212692b5af9dd0ea0d9372baf135ef05a Mon Sep 17 00:00:00 2001 From: "wenqin.yang" Date: Fri, 16 Jan 2026 14:49:21 +0800 Subject: [PATCH] support vec1 for im2col --- .../core/providers/webgpu/nn/im2col_matmul.cc | 17 ++++---- .../core/providers/webgpu/nn/im2col_matmul.h | 3 ++ .../webgpu/nn/im2col_matmul.wgsl.template | 39 +++++++++++++------ 3 files changed, 38 insertions(+), 21 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/nn/im2col_matmul.cc b/onnxruntime/core/providers/webgpu/nn/im2col_matmul.cc index cfea39e1464d3..b2158f549c7c1 100644 --- a/onnxruntime/core/providers/webgpu/nn/im2col_matmul.cc +++ b/onnxruntime/core/providers/webgpu/nn/im2col_matmul.cc @@ -71,12 +71,14 @@ Status Im2ColMatMulProgram::GenerateShaderCode(ShaderHelper& shader) const { ORT_ENFORCE(tile_m_ == 16 || tile_m_ == 32, "tile_m must be 16 or 32."); ORT_ENFORCE(tile_n_ == 64, "tile_n must be 64."); + ORT_ENFORCE(vec_size_ == 1 || vec_size_ == 4, "vec_size must be 4 or 1."); return WGSL_TEMPLATE_APPLY(shader, "nn/im2col_matmul.wgsl.template", WGSL_TEMPLATE_PARAMETER(has_bias, has_bias_), WGSL_TEMPLATE_PARAMETER(tile_m, tile_m_), WGSL_TEMPLATE_PARAMETER(tile_n, tile_n_), WGSL_TEMPLATE_PARAMETER(use_subgroup, use_subgroup_), + WGSL_TEMPLATE_PARAMETER(vec_size, vec_size_), WGSL_TEMPLATE_VARIABLE(output, output), WGSL_TEMPLATE_VARIABLE(src, src), WGSL_TEMPLATE_VARIABLE(weight, weight)); @@ -145,7 +147,8 @@ Status ApplyIm2ColMatMulProgram(ComputeContext& context, // Ensure the subgroup size must be greater than or equal to `tile_m` to safely enable `use_subgroup`. // If the status of this condition is uncertain, the feature must be disabled. const bool use_subgroup = false; - Im2ColMatMulProgram im2col_mm_program{has_bias, tile_m, tile_n, use_subgroup}; + const uint32_t vec_size = channel_input % 4 == 0 ? 4 : 1; + Im2ColMatMulProgram im2col_mm_program{has_bias, tile_m, tile_n, vec_size, use_subgroup}; im2col_mm_program.SetWorkgroupSize(workgroup_size); const uint32_t M_tiles = CeilDiv(im2col_m, tile_m); @@ -154,10 +157,10 @@ Status ApplyIm2ColMatMulProgram(ComputeContext& context, im2col_mm_program.AddInput({src, ProgramTensorMetadataDependency::TypeAndRank, - 4}); + static_cast(vec_size)}); im2col_mm_program.AddInput({&ohwi_weight, ProgramTensorMetadataDependency::TypeAndRank, - 4}); + static_cast(vec_size)}); if (has_bias) { im2col_mm_program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank}); @@ -181,7 +184,7 @@ Status ApplyIm2ColMatMulProgram(ComputeContext& context, {dilations}, {pads}, {strides}}); - im2col_mm_program.CacheHint(has_bias, tile_m, tile_n, use_subgroup); + im2col_mm_program.CacheHint(has_bias, tile_m, tile_n, vec_size, use_subgroup); return context.RunProgram(im2col_mm_program); } @@ -212,12 +215,6 @@ bool CanApplyIm2ColMatMulProgram(ComputeContextBase& context, return false; } - // TODO: Support channel input vec1 - const uint32_t channel_input = onnxruntime::narrow(weight_shape[1]); - if (channel_input % 4 != 0) { - return false; - } - return true; } diff --git a/onnxruntime/core/providers/webgpu/nn/im2col_matmul.h b/onnxruntime/core/providers/webgpu/nn/im2col_matmul.h index ed24100879520..f881f1cb006ed 100644 --- a/onnxruntime/core/providers/webgpu/nn/im2col_matmul.h +++ b/onnxruntime/core/providers/webgpu/nn/im2col_matmul.h @@ -39,10 +39,12 @@ class Im2ColMatMulProgram final : public Program { Im2ColMatMulProgram(bool has_bias, uint32_t tile_m, uint32_t tile_n, + uint32_t vec_size, bool use_subgroup) : Program("Im2ColMatMul"), has_bias_(has_bias), tile_m_(tile_m), tile_n_(tile_n), + vec_size_(vec_size), use_subgroup_(use_subgroup) {} Status GenerateShaderCode(ShaderHelper& shader) const override; @@ -71,6 +73,7 @@ class Im2ColMatMulProgram final : public Program { uint32_t tile_m_; uint32_t tile_n_; + uint32_t vec_size_; bool use_subgroup_; }; diff --git a/onnxruntime/core/providers/webgpu/nn/im2col_matmul.wgsl.template b/onnxruntime/core/providers/webgpu/nn/im2col_matmul.wgsl.template index 2f64525469561..ee937dd9a62e9 100644 --- a/onnxruntime/core/providers/webgpu/nn/im2col_matmul.wgsl.template +++ b/onnxruntime/core/providers/webgpu/nn/im2col_matmul.wgsl.template @@ -5,25 +5,26 @@ #param tile_m #param tile_n #param use_subgroup +#param vec_size #use .getByOffset .setByOffset // im2col access for src: [N, H_i, W_i, C_i / 4] (vec4-packed NHWC) // Conceptual Matrix Shape: N * (H_o * W_o) x (K_h * K_w * C_i / 4) fn load_src(batch : u32, m : u32, k_packed_idx : u32) -> src_value_t { - if (batch >= uniforms.batch || m >= uniforms.im2col_m || k_packed_idx * 4 >= uniforms.im2col_k) { + if (batch >= uniforms.batch || m >= uniforms.im2col_m || k_packed_idx * vec_size >= uniforms.im2col_k) { return src_value_t(); } - let channel_i_v4 = uniforms.channel_i / 4; + let channel_i_vec = uniforms.channel_i / vec_size; // 1. Decompose M index (H_o * W_o) into (h_idx, w_idx) let h_idx = m / uniforms.output_w; // Output H index (H_o) let w_idx = m % uniforms.output_w; // Output W index (W_o) // 2. Decompose K index into (k_h, k_w, c_i_v4_idx) - let c_i_v4_idx = k_packed_idx % channel_i_v4; - let k_h_w_idx = k_packed_idx / channel_i_v4; + let c_i_v4_idx = k_packed_idx % channel_i_vec; + let k_h_w_idx = k_packed_idx / channel_i_vec; let k_h = k_h_w_idx / uniforms.kernel_w; // Kernel Row let k_w = k_h_w_idx % uniforms.kernel_w; // Kernel Column @@ -33,7 +34,7 @@ fn load_src(batch : u32, m : u32, k_packed_idx : u32) -> src_value_t { // 4. Calculate the coordinate in the original input tensor let src_h_coord : i32 = i32(src_h_coord_padded) - i32(uniforms.pads.x); - let src_w_coord : i32 = i32(src_w_coord_padded) - i32(uniforms.pads.z); + let src_w_coord : i32 = i32(src_w_coord_padded) - i32(uniforms.pads.y); // 5. Check for padding/out-of-bounds if (src_h_coord < 0 || src_h_coord >= i32(uniforms.src_h) || @@ -42,17 +43,17 @@ fn load_src(batch : u32, m : u32, k_packed_idx : u32) -> src_value_t { } // 6. Calculate final NHWC/vec4 index - let src_idx = batch * uniforms.src_h * uniforms.src_w * channel_i_v4 + - u32(src_h_coord) * uniforms.src_w * channel_i_v4 + - u32(src_w_coord) * channel_i_v4 + + let src_idx = batch * uniforms.src_h * uniforms.src_w * channel_i_vec + + u32(src_h_coord) * uniforms.src_w * channel_i_vec + + u32(src_w_coord) * channel_i_vec + c_i_v4_idx; return src.getByOffset(src_idx); } // weight shape: [Co, K_h, K_w, C_i / 4] (vec4-packed CoHWCi) fn load_weight(n : u32, k_packed_idx : u32) -> weight_value_t { - if (n < uniforms.im2col_n && k_packed_idx < uniforms.im2col_k / 4) { - let weight_idx = n * uniforms.im2col_k / 4 + + if (n < uniforms.im2col_n && k_packed_idx < uniforms.im2col_k / vec_size) { + let weight_idx = n * uniforms.im2col_k / vec_size + k_packed_idx; return weight.getByOffset(weight_idx); } @@ -80,7 +81,7 @@ fn write_output(batch : u32, m : u32, n : u32, value : output_element_t) { const TILE_M_SIZE : u32 = tile_m; const TILE_N_SIZE : u32 = tile_n; -const TILE_K_VEC_SIZE : u32 = 4; +const TILE_K_VEC_SIZE : u32 = 16 / vec_size; var src_tile : array, TILE_K_VEC_SIZE>; var weight_tile : array, TILE_K_VEC_SIZE>; @@ -92,20 +93,32 @@ $MAIN { var results : array; for (var k_idx = 0u; k_idx < uniforms.K_tiles; k_idx++) { +#if vec_size != 4 + for (var src_m = 0u; src_m < TILE_M_SIZE; src_m += 4u) { + let load_src_m = src_m + local_idx / 16; + let load_src_k = local_idx % 16; +#else for (var src_m = 0u; src_m < TILE_M_SIZE; src_m += 16u) { // Loads a 16x4 vec of src into the workgroup memory. let load_src_m = src_m + local_idx / 4; let load_src_k = local_idx % 4; +#endif src_tile[load_src_k][load_src_m] = load_src(batch, m_global_base + load_src_m, k_idx * TILE_K_VEC_SIZE + load_src_k); } +#if vec_size != 4 + for (var weight_n = 0u; weight_n < TILE_N_SIZE; weight_n += 4u) { + let load_weight_n = weight_n + local_idx / 16; + let load_weight_k = local_idx % 16; +#else for (var weight_n = 0u; weight_n < TILE_N_SIZE; weight_n += 16u) { // Loads a 16x4 vec of weight into the workgroup memory. let load_weight_n = weight_n + local_idx / 4; let load_weight_k = local_idx % 4; +#endif weight_tile[load_weight_k][load_weight_n] = load_weight(n_global_base + load_weight_n, k_idx * TILE_K_VEC_SIZE + load_weight_k); @@ -121,7 +134,11 @@ $MAIN { } #else for (var m_idx = 0u; m_idx < TILE_M_SIZE; m_idx++) { +#if vec_size != 4 + results[m_idx] += output_element_t(weight_data * src_tile[inner_k_idx][m_idx]); +#else results[m_idx] += output_element_t(dot(weight_data, src_tile[inner_k_idx][m_idx])); +#endif } #endif }