From acabc516aab6b9298b6a493f70d71721a8778d30 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Fri, 13 Feb 2026 18:06:05 -0800 Subject: [PATCH 1/6] add weight_index_indirect to nbitmm --- .../quantization/dp4a_matmul.wgsl.template | 38 +++++++++++++++---- .../webgpu/quantization/dp4a_matmul_nbits.cc | 26 ++++++++++--- .../webgpu/quantization/dp4a_matmul_nbits.h | 11 ++++-- .../dp4a_matmul_small_m.wgsl.template | 21 ++++++++-- .../webgpu/quantization/matmul_nbits.cc | 30 +++++++++++---- .../webgpu/quantization/matmul_nbits.h | 13 ++++--- .../quantization/matmul_nbits.wgsl.template | 15 ++++++-- .../matmul_nbits_wide_tile.wgsl.template | 29 ++++++++++++-- .../subgroup_matrix_matmul_nbits.cc | 19 +++++++--- .../subgroup_matrix_matmul_nbits.h | 9 +++-- ...up_matrix_matmul_nbits_apple.wgsl.template | 23 +++++++++-- 11 files changed, 184 insertions(+), 50 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul.wgsl.template index 6d22e6743707b..cbe471baf504f 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul.wgsl.template @@ -7,6 +7,7 @@ #param has_zero_points #param is_qualcomm #param has_weight_idx +#param has_weight_idx_indirect #use .getByOffset .setByOffset @@ -78,9 +79,15 @@ fn loadSHMA(batch:u32, a_global_base:u32, kidx_v:u32, row: u32, col: u32) return; } #if has_weight_idx - let b_weight_offset = uniforms.weight_idx * uniforms.N * uniforms.K16; +#if has_weight_idx_indirect + let actual_weight_idx = weight_index_indirect[uniforms.weight_idx]; +#else + let actual_weight_idx = uniforms.weight_idx; +#endif + let b_weight_offset = actual_weight_idx * uniforms.N * uniforms.K16; let b_value = b.getByOffset(b_weight_offset + b_global * uniforms.K16 + kidx_v + col); #else + const actual_weight_idx : u32 = 0; let b_value = b.getByOffset(b_global * uniforms.K16+kidx_v + col); #endif let block_idx = kidx_v/(block_size/16); @@ -89,7 +96,7 @@ fn loadSHMA(batch:u32, a_global_base:u32, kidx_v:u32, row: u32, col: u32) if (col == 0) { // kidx_v - each kidx_v covers 16 values of k - let b_scale_offset = uniforms.weight_idx * uniforms.N * (uniforms.K/block_size); + let b_scale_offset = actual_weight_idx * uniforms.N * (uniforms.K/block_size); scale_B[row] = scales_b.getByOffset(b_scale_offset + b_global*(uniforms.K/block_size) + block_idx); } } @@ -105,9 +112,15 @@ fn loadSHMA(batch:u32, a_global_base:u32, kidx_v:u32, row: u32, col: u32) } #if has_weight_idx - let b_weight_offset = uniforms.weight_idx * uniforms.N * uniforms.K16; +#if has_weight_idx_indirect + let actual_weight_idx = weight_index_indirect[uniforms.weight_idx]; +#else + let actual_weight_idx = uniforms.weight_idx; +#endif + let b_weight_offset = actual_weight_idx * uniforms.N * uniforms.K16; let b_value = b.getByOffset(b_weight_offset + b_global * uniforms.K16 + kidx_v + col); #else + const actual_weight_idx : u32 = 0; const b_weight_offset : u32 = 0; let b_value = b.getByOffset(b_global * uniforms.K16 + kidx_v + col); #endif @@ -116,7 +129,7 @@ fn loadSHMA(batch:u32, a_global_base:u32, kidx_v:u32, row: u32, col: u32) { // kidx_v - each kidx_v covers 16 values of k let block_idx = kidx_v/(block_size/16); - let b_scale_offset = uniforms.weight_idx * uniforms.N * (uniforms.K/block_size); + let b_scale_offset = actual_weight_idx * uniforms.N * (uniforms.K/block_size); scale_B[row] = scales_b.getByOffset(b_scale_offset + b_global*(uniforms.K/block_size) + block_idx); #if has_zero_points zeroes[row] = mm_read_zero(b_global, block_idx, uniforms.N, uniforms.zero_blocks_per_col); @@ -134,15 +147,21 @@ fn loadSHMA(batch:u32, a_global_base:u32, kidx_v:u32, row: u32, col: u32) return; } #if has_weight_idx - let b_weight_offset = uniforms.weight_idx * uniforms.N * uniforms.K16; +#if has_weight_idx_indirect + let actual_weight_idx = weight_index_indirect[uniforms.weight_idx]; +#else + let actual_weight_idx = uniforms.weight_idx; +#endif + let b_weight_offset = actual_weight_idx * uniforms.N * uniforms.K16; let b_value = b.getByOffset(b_weight_offset + b_global * uniforms.K16 + kidx_v + col); #else + const actual_weight_idx : u32 = 0; const b_weight_offset : u32 = 0; let b_value = b.getByOffset(b_global * uniforms.K16 + kidx_v + col); #endif tile_B[col][row] = DequantizedFrom2BitsTo8Bits(b_value); let block_idx = kidx_v/(block_size/16); - let b_scale_offset = uniforms.weight_idx * uniforms.N * (uniforms.K/block_size); + let b_scale_offset = actual_weight_idx * uniforms.N * (uniforms.K/block_size); scale_B[row] = scales_b.getByOffset(b_scale_offset + b_global*(uniforms.K/block_size) + block_idx); } #endif @@ -387,7 +406,12 @@ $MAIN { let output_idx = (batch * uniforms.M * uniforms.N + a_global * uniforms.N + b_global)/4; #if has_bias #if has_weight_idx - let b_bias_offset = uniforms.weight_idx * uniforms.N; +#if has_weight_idx_indirect + let actual_weight_idx_bias = weight_index_indirect[uniforms.weight_idx]; +#else + let actual_weight_idx_bias = uniforms.weight_idx; +#endif + let b_bias_offset = actual_weight_idx_bias * uniforms.N; #else const b_bias_offset : u32 = 0; #endif diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc index 45a0379fa26a5..900974a0b4a11 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc @@ -30,11 +30,15 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { if (has_bias_) { shader.AddInput("bias", ShaderUsage::UseUniform); } + if (has_weight_idx_indirect_) { + shader.AddInput("weight_index_indirect", ShaderUsage::UseUniform); + } const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); return WGSL_TEMPLATE_APPLY(shader, "quantization/dp4a_matmul.wgsl.template", WGSL_TEMPLATE_PARAMETER(block_size, block_size_), WGSL_TEMPLATE_PARAMETER(has_bias, has_bias_), WGSL_TEMPLATE_PARAMETER(has_weight_idx, has_weight_idx_), + WGSL_TEMPLATE_PARAMETER(has_weight_idx_indirect, has_weight_idx_indirect_), WGSL_TEMPLATE_PARAMETER(has_zero_points, has_zero_points_), WGSL_TEMPLATE_PARAMETER(is_qualcomm, is_qualcomm_), WGSL_TEMPLATE_PARAMETER(n_bits, nbits_), @@ -58,6 +62,9 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co if (has_bias_) { shader.AddInput("bias", ShaderUsage::UseUniform); } + if (has_weight_idx_indirect_) { + shader.AddInput("weight_index_indirect", ShaderUsage::UseUniform); + } const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); ORT_ENFORCE(WorkgroupSizeX() % tile_size_k_vec_ == 0 && tile_size_k_vec_ % 4 == 0, "tile_size_k_vec_ must evenly divide workgroup size X and be divisible by 4"); @@ -67,6 +74,7 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co return WGSL_TEMPLATE_APPLY(shader, "quantization/dp4a_matmul_small_m.wgsl.template", WGSL_TEMPLATE_PARAMETER(has_bias, has_bias_), WGSL_TEMPLATE_PARAMETER(has_weight_idx, has_weight_idx_), + WGSL_TEMPLATE_PARAMETER(has_weight_idx_indirect, has_weight_idx_indirect_), WGSL_TEMPLATE_PARAMETER(has_zero_points, has_zero_points_), WGSL_TEMPLATE_PARAMETER(n_bits, nbits_), WGSL_TEMPLATE_PARAMETER(output_type_i32, true), @@ -93,7 +101,8 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor uint32_t nbits, onnxruntime::webgpu::ComputeContext& context, Tensor* y, - const uint32_t weight_index) { + const uint32_t weight_index, + const Tensor* weight_index_indirect) { constexpr uint32_t kVec4Components = 4; constexpr uint32_t kVec2Components = 2; constexpr uint32_t kU32Components = 4; @@ -116,6 +125,7 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor const bool has_zero_points = zero_points != nullptr; const bool has_bias = bias != nullptr; const bool has_weight_idx = weight_index != 0; + const bool has_weight_idx_indirect = weight_index_indirect != nullptr; const bool single_scale_weights = (block_size == K * N); if (M < min_M_for_tile_optimization) { uint32_t tile_size_k_vec = 16; @@ -126,7 +136,7 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor tile_size_n = 4; } const uint32_t b_components = (nbits == 2 ? kVec2Components : kVec4Components); - DP4AMatMulNBitsSmallMProgram mul_program{tile_size_k_vec, tile_size_n, nbits, has_zero_points, has_bias, has_weight_idx, single_scale_weights}; + DP4AMatMulNBitsSmallMProgram mul_program{tile_size_k_vec, tile_size_n, nbits, has_zero_points, has_bias, has_weight_idx, has_weight_idx_indirect, single_scale_weights}; uint32_t num_N_tile = (N + tile_size_n - 1) / tile_size_n; mul_program.SetWorkgroupSize(128); mul_program.SetDispatchGroupSize(batch_count * M * num_N_tile); @@ -136,13 +146,16 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor {scales, ProgramTensorMetadataDependency::TypeAndRank, 1}}) .AddUniformVariables({batch_count, M, N, K, K / 16, K / 32, block_size, num_N_tile, zero_blocks_per_col, weight_index}) .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, 1}) - .CacheHint(nbits, tile_size_k_vec, tile_size_n, has_zero_points, single_scale_weights, has_bias, has_weight_idx); + .CacheHint(nbits, tile_size_k_vec, tile_size_n, has_zero_points, single_scale_weights, has_bias, has_weight_idx, has_weight_idx_indirect); if (has_zero_points) { mul_program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); } if (has_bias) { mul_program.AddInput({bias, ProgramTensorMetadataDependency::None}); } + if (has_weight_idx_indirect) { + mul_program.AddInput({weight_index_indirect, ProgramTensorMetadataDependency::None}); + } return context.RunProgram(mul_program); } @@ -151,7 +164,7 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor uint32_t num_M_tile = (M + kTileSize - 1) / kTileSize; uint32_t num_N_tile = (N + kTileSize - 1) / kTileSize; bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"}; - DP4AMatMulNBitsProgram mul_program{block_size, nbits, has_zero_points, has_bias, has_weight_idx, is_qualcomm}; + DP4AMatMulNBitsProgram mul_program{block_size, nbits, has_zero_points, has_bias, has_weight_idx, has_weight_idx_indirect, is_qualcomm}; mul_program.SetWorkgroupSize(256); mul_program.SetDispatchGroupSize(batch_count * num_M_tile * num_N_tile); mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast(kVec4Components)}, @@ -169,13 +182,16 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor {zero_blocks_per_col}, {weight_index}}) .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, static_cast(kVec4Components)}) - .CacheHint("Block" + std::to_string(block_size), nbits, has_zero_points, is_qualcomm, has_bias, has_weight_idx); + .CacheHint("Block" + std::to_string(block_size), nbits, has_zero_points, is_qualcomm, has_bias, has_weight_idx, has_weight_idx_indirect); if (has_zero_points) { mul_program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); } if (has_bias) { mul_program.AddInput({bias, ProgramTensorMetadataDependency::None}); } + if (has_weight_idx_indirect) { + mul_program.AddInput({weight_index_indirect, ProgramTensorMetadataDependency::None}); + } return context.RunProgram(mul_program); } diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h index 143a234864e2e..297cf3c3e9042 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h @@ -23,12 +23,13 @@ class DP4AMatMulNBitsProgram final : public Program { public: DP4AMatMulNBitsProgram(uint32_t block_size, uint32_t nbits, bool has_zero_points, bool has_bias, - bool has_weight_idx, bool is_qualcomm) : Program{"DP4AMatMulNBits"}, + bool has_weight_idx, bool has_weight_idx_indirect, bool is_qualcomm) : Program{"DP4AMatMulNBits"}, block_size_(block_size), nbits_(nbits), has_bias_(has_bias), has_zero_points_(has_zero_points), has_weight_idx_(has_weight_idx), + has_weight_idx_indirect_(has_weight_idx_indirect), is_qualcomm_(is_qualcomm) {} Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( @@ -49,6 +50,7 @@ class DP4AMatMulNBitsProgram final : public Program { bool has_bias_; bool has_zero_points_; bool has_weight_idx_; + bool has_weight_idx_indirect_; bool is_qualcomm_; }; @@ -56,13 +58,14 @@ class DP4AMatMulNBitsSmallMProgram final : public ProgramShape(), b_shape, false, true)); const bool has_bias = bias != nullptr; const bool has_weight_idx = weight_index > 0; + const bool has_weight_idx_indirect = weight_index_indirect != nullptr; const bool has_zero_points = zero_points != nullptr; if (has_zero_points) { ORT_ENFORCE(zero_points->DataType() == DataTypeImpl::GetType(), "Currently, only uint8 is supported for zero points, but got ", zero_points->DataType()); @@ -216,7 +226,7 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, // apple|intel - Experimental dawn support for subgroup matrix matmul. if (M >= kMinMForTileOptimization && (context.AdapterInfo().vendor == std::string_view{"apple"} || context.AdapterInfo().vendor == std::string_view{"intel"}) && CanApplySubgroupMatrixMatMulNBits(context, accuracy_level, block_size, batch_count, N, K, subgroup_matrix_config_index)) { - return ApplySubgroupMatrixMatMulNBits(a, b, scales, zero_points, bias, M, N, K, static_cast(nbits), zero_blocks_per_col, subgroup_matrix_config_index, context, y, weight_index); + return ApplySubgroupMatrixMatMulNBits(a, b, scales, zero_points, bias, M, N, K, static_cast(nbits), zero_blocks_per_col, subgroup_matrix_config_index, context, y, weight_index, weight_index_indirect); } #endif @@ -225,7 +235,7 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, if ((M >= kMinMForTileOptimization || y->DataType() == DataTypeImpl::GetType() || context.AdapterInfo().vendor == std::string_view{"qualcomm"}) && !(has_zero_points && nbits == 2) && CanApplyDP4AMatrixMatMulNBits(context, accuracy_level, block_size, N, K, components_a)) { - return ApplyDP4AMatrixMatMulNBits(a, b, scales, zero_points, bias, batch_count, M, N, K, block_size, zero_blocks_per_col, kMinMForTileOptimization, static_cast(nbits), context, y, weight_index); + return ApplyDP4AMatrixMatMulNBits(a, b, scales, zero_points, bias, batch_count, M, N, K, block_size, zero_blocks_per_col, kMinMForTileOptimization, static_cast(nbits), context, y, weight_index, weight_index_indirect); } // WideTileProgram @@ -246,7 +256,7 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, const uint32_t num_N_tile = CeilDiv(N, tile_n); const uint32_t num_M_tile = CeilDiv(M, tile_m); - MatMulNBitsWideTileProgram program{has_zero_points, has_bias, has_weight_idx, tile_m, tile_n, static_cast(nbits)}; + MatMulNBitsWideTileProgram program{has_zero_points, has_bias, has_weight_idx, has_weight_idx_indirect, tile_m, tile_n, static_cast(nbits)}; program.SetWorkgroupSize(workgroup_size); program.SetDispatchGroupSize(num_N_tile, num_M_tile, batch_count); @@ -271,6 +281,9 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, if (has_bias) { program.AddInput({bias, ProgramTensorMetadataDependency::None}); } + if (has_weight_idx_indirect) { + program.AddInput({weight_index_indirect, ProgramTensorMetadataDependency::None}); + } program.AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, onnxruntime::narrow(components)}); @@ -284,7 +297,7 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, {num_N_tile}, {num_M_tile}, {weight_index}}); - program.CacheHint(nbits, has_zero_points, has_bias, has_weight_idx); + program.CacheHint(nbits, has_zero_points, has_bias, has_weight_idx, has_weight_idx_indirect); return context.RunProgram(program); } @@ -295,7 +308,7 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, uint32_t components_b_with_u32 = components_b * kU32Components; uint32_t num_N_tile = (N + tile_size - 1) / tile_size; uint32_t K_of_b = (n_blocks_per_col * blob_size) / components_b_with_u32; - MatMulNBitsProgram program{tile_size, static_cast(nbits), has_zero_points, has_bias, has_weight_idx, single_scale_weights}; + MatMulNBitsProgram program{tile_size, static_cast(nbits), has_zero_points, has_bias, has_weight_idx, has_weight_idx_indirect, single_scale_weights}; program.SetWorkgroupSize(workgroup_size); program.SetDispatchGroupSize((N + tile_size - 1) / tile_size, M, batch_count); program @@ -314,13 +327,16 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, {num_N_tile}, {batch_count}, {weight_index}}) - .CacheHint(nbits, has_zero_points, single_scale_weights, has_bias, has_weight_idx); + .CacheHint(nbits, has_zero_points, single_scale_weights, has_bias, has_weight_idx, has_weight_idx_indirect); if (has_zero_points) { program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); } if (has_bias) { program.AddInput({bias, ProgramTensorMetadataDependency::None}); } + if (has_weight_idx_indirect) { + program.AddInput({weight_index_indirect, ProgramTensorMetadataDependency::None}); + } return context.RunProgram(program); } diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h index ccd1ef6f1355c..70ddf2f818627 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h @@ -14,8 +14,8 @@ using namespace onnxruntime::webgpu; class MatMulNBitsWideTileProgram final : public Program { public: - MatMulNBitsWideTileProgram(bool has_zero_points, bool has_bias, bool has_weight_idx, uint32_t tile_m, uint32_t tile_n, uint32_t nbits) - : Program{"MatMulNBitsWideTile"}, has_zero_points_{has_zero_points}, has_bias_{has_bias}, has_weight_idx_{has_weight_idx}, tile_m_(tile_m), tile_n_(tile_n), nbits_(nbits) {} + MatMulNBitsWideTileProgram(bool has_zero_points, bool has_bias, bool has_weight_idx, bool has_weight_idx_indirect, uint32_t tile_m, uint32_t tile_n, uint32_t nbits) + : Program{"MatMulNBitsWideTile"}, has_zero_points_{has_zero_points}, has_bias_{has_bias}, has_weight_idx_{has_weight_idx}, has_weight_idx_indirect_{has_weight_idx_indirect}, tile_m_(tile_m), tile_n_(tile_n), nbits_(nbits) {} Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"Batch", ProgramUniformVariableDataType::Uint32}, @@ -33,6 +33,7 @@ class MatMulNBitsWideTileProgram final : public Program { public: - MatMulNBitsProgram(uint32_t tile_size, uint32_t nbits, bool has_zero_points, bool has_bias, bool has_weight_idx, bool single_scale_weights) - : Program{"MatMulNBits"}, tile_size_(tile_size), nbits_(nbits), has_zero_points_(has_zero_points), has_bias_(has_bias), has_weight_idx_{has_weight_idx}, single_scale_weights_(single_scale_weights) {} + MatMulNBitsProgram(uint32_t tile_size, uint32_t nbits, bool has_zero_points, bool has_bias, bool has_weight_idx, bool has_weight_idx_indirect, bool single_scale_weights) + : Program{"MatMulNBits"}, tile_size_(tile_size), nbits_(nbits), has_zero_points_(has_zero_points), has_bias_(has_bias), has_weight_idx_{has_weight_idx}, has_weight_idx_indirect_{has_weight_idx_indirect}, single_scale_weights_(single_scale_weights) {} Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( {"M", ProgramUniformVariableDataType::Uint32}, @@ -62,6 +63,7 @@ class MatMulNBitsProgram final : public Program { bool has_zero_points_; bool has_bias_; bool has_weight_idx_; + bool has_weight_idx_indirect_; bool single_scale_weights_; }; @@ -89,7 +91,8 @@ class MatMulNBits final : public WebGpuKernel { Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, const Tensor* zero_points, const Tensor* bias, int64_t K_op, int64_t N_op, int64_t block_size_op, int64_t accuracy_level, int64_t bits_op, - onnxruntime::webgpu::ComputeContext& context, Tensor* y, const uint32_t weight_index = 0); + onnxruntime::webgpu::ComputeContext& context, Tensor* y, const uint32_t weight_index = 0, + const Tensor* weight_index_indirect = nullptr); } // namespace webgpu } // namespace contrib diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.wgsl.template index 6a66d2eb402e5..15325181974ad 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.wgsl.template @@ -13,6 +13,7 @@ #param tile_size #param has_bias #param has_weight_idx +#param has_weight_idx_indirect #use .getByOffset .setByOffset @@ -35,18 +36,24 @@ fn loadSHMA(batch: u32, a_global: u32, kidx: u32, col: u32) $MAIN { let batch = workgroup_idx / (uniforms.M * uniforms.num_N_tile); #if has_weight_idx - let b_base_offset = uniforms.weight_idx * uniforms.K_of_b * uniforms.N; +#if has_weight_idx_indirect + let actual_weight_idx = weight_index_indirect[uniforms.weight_idx]; +#else + let actual_weight_idx = uniforms.weight_idx; +#endif + let b_base_offset = actual_weight_idx * uniforms.K_of_b * uniforms.N; #if single_scale_weights - let b_scale_offset = uniforms.weight_idx; + let b_scale_offset = actual_weight_idx; #else - let b_scale_offset = uniforms.weight_idx * uniforms.N * uniforms.blocks_per_col; + let b_scale_offset = actual_weight_idx * uniforms.N * uniforms.blocks_per_col; #endif #else const b_base_offset : u32 = 0; const b_scale_offset : u32 = 0; + const actual_weight_idx : u32 = 0; #endif #if has_bias - let b_bias_offset = uniforms.weight_idx * uniforms.N; + let b_bias_offset = actual_weight_idx * uniforms.N; #endif let a_global = (workgroup_idx / uniforms.num_N_tile) % uniforms.M; let b_global_base = (workgroup_idx % uniforms.num_N_tile) * tile_size; diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_wide_tile.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_wide_tile.wgsl.template index b95d4bd49c6d8..c8ddc9f2590d1 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_wide_tile.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_wide_tile.wgsl.template @@ -4,6 +4,7 @@ #param has_zero_points #param has_bias #param has_weight_idx +#param has_weight_idx_indirect #param nbits #param tile_m #param tile_n @@ -71,7 +72,12 @@ fn load_scale(row : u32, block_idx : u32) -> output_element_t { if (row < uniforms.N && block_idx < uniforms.n_blocks_per_col) { let offset = row * uniforms.n_blocks_per_col + block_idx; #if has_weight_idx - let b_scale_offset = uniforms.weight_idx * uniforms.N * uniforms.n_blocks_per_col; +#if has_weight_idx_indirect + let actual_weight_idx = weight_index_indirect[uniforms.weight_idx]; +#else + let actual_weight_idx = uniforms.weight_idx; +#endif + let b_scale_offset = actual_weight_idx * uniforms.N * uniforms.n_blocks_per_col; return scales.getByOffset(offset + b_scale_offset); #else return scales.getByOffset(offset); @@ -91,7 +97,12 @@ fn write_output(batch : u32, row : u32, col : u32, value : output_element_t) { fn load_b(row : u32, block_idx : u32) -> vec4 { if (row < uniforms.N && block_idx < uniforms.K_of_b) { #if has_weight_idx - let b_offset = uniforms.weight_idx * uniforms.K_of_b * uniforms.N; +#if has_weight_idx_indirect + let actual_weight_idx = weight_index_indirect[uniforms.weight_idx]; +#else + let actual_weight_idx = uniforms.weight_idx; +#endif + let b_offset = actual_weight_idx * uniforms.K_of_b * uniforms.N; let offset = row * uniforms.K_of_b + block_idx + b_offset; #else let offset = row * uniforms.K_of_b + block_idx; @@ -125,7 +136,12 @@ fn dequantize(packed_data : u32, fn load_b(row : u32, block_idx : u32) -> array, 4> { if (row < uniforms.N) { #if has_weight_idx - let b_offset = uniforms.weight_idx * uniforms.K_of_b * uniforms.N; +#if has_weight_idx_indirect + let actual_weight_idx = weight_index_indirect[uniforms.weight_idx]; +#else + let actual_weight_idx = uniforms.weight_idx; +#endif + let b_offset = actual_weight_idx * uniforms.K_of_b * uniforms.N; let offset = 2 * block_idx + b_offset; #else let offset = 2 * block_idx; @@ -209,7 +225,12 @@ $MAIN { // Write the results. #if has_bias #if has_weight_idx - let b_bias_offset = uniforms.weight_idx * uniforms.N; + #if has_weight_idx_indirect + let actual_weight_idx_bias = weight_index_indirect[uniforms.weight_idx]; + #else + let actual_weight_idx_bias = uniforms.weight_idx; + #endif + let b_bias_offset = actual_weight_idx_bias * uniforms.N; let bias_value = bias[b_bias_offset + col + local_idx]; #else let bias_value = bias[col + local_idx]; diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc index 50aa4de4749bb..e43682ea498a2 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc @@ -280,10 +280,11 @@ Status GenerateShaderCodeOnIntel(ShaderHelper& shader, const ShaderVariableHelpe Status GenerateShaderCodeOnApple(ShaderHelper& shader, const ShaderVariableHelper& a, const ShaderVariableHelper& b, const ShaderVariableHelper& scales_b, - const ShaderVariableHelper& output, uint32_t nbits, bool has_zero_points, bool has_bias, bool has_weight_idx) { + const ShaderVariableHelper& output, uint32_t nbits, bool has_zero_points, bool has_bias, bool has_weight_idx, bool has_weight_idx_indirect) { return WGSL_TEMPLATE_APPLY(shader, "quantization/subgroup_matrix_matmul_nbits_apple.wgsl.template", WGSL_TEMPLATE_PARAMETER(has_bias, has_bias), WGSL_TEMPLATE_PARAMETER(has_weight_idx, has_weight_idx), + WGSL_TEMPLATE_PARAMETER(has_weight_idx_indirect, has_weight_idx_indirect), WGSL_TEMPLATE_PARAMETER(has_zero_points, has_zero_points), WGSL_TEMPLATE_PARAMETER(n_bits, nbits), WGSL_TEMPLATE_PARAMETER(output_type_i32, false), @@ -303,11 +304,14 @@ Status SubgroupMatrixMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader if (has_bias_) { shader.AddInput("bias", ShaderUsage::UseUniform); } + if (has_weight_idx_indirect_) { + shader.AddInput("weight_index_indirect", ShaderUsage::UseUniform); + } const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); // TODO: add support for bias to the shader for Intel. In the meantime, use the shader for Metal if (!vendor_.compare("apple") || has_bias_) { - return GenerateShaderCodeOnApple(shader, a, b, scales_b, output, nbits_, has_zero_points_, has_bias_, has_weight_idx_); + return GenerateShaderCodeOnApple(shader, a, b, scales_b, output, nbits_, has_zero_points_, has_bias_, has_weight_idx_, has_weight_idx_indirect_); } else if (!vendor_.compare("intel")) { return GenerateShaderCodeOnIntel(shader, b, scales_b, nbits_, config_index_, has_zero_points_); } else { @@ -326,7 +330,8 @@ Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Te int32_t config_index, onnxruntime::webgpu::ComputeContext& context, Tensor* y, - const uint32_t weight_index) { + const uint32_t weight_index, + const Tensor* weight_index_indirect) { // If applicable, layout optimization of input matrix A(MxK) can be used for SubgroupMatrixLoad. Tensor a_prepack; if (context.AdapterInfo().vendor == std::string_view{"intel"}) { @@ -364,7 +369,8 @@ Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Te const bool has_zero_points = zero_points != nullptr; const bool has_bias = bias != nullptr; const bool has_weight_idx = weight_index > 0; - SubgroupMatrixMatMulNBitsProgram mul_program{nbits, config_index, context.AdapterInfo().vendor, has_zero_points, has_bias, has_weight_idx}; + const bool has_weight_idx_indirect = weight_index_indirect != nullptr; + SubgroupMatrixMatMulNBitsProgram mul_program{nbits, config_index, context.AdapterInfo().vendor, has_zero_points, has_bias, has_weight_idx, has_weight_idx_indirect}; if (context.AdapterInfo().vendor == std::string_view{"intel"}) { tile_size_a = 64; work_group_size = 256; @@ -378,13 +384,16 @@ Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Te {scales, ProgramTensorMetadataDependency::TypeAndRank, 1}}) .AddUniformVariables({{M}, {N}, {K}, {zero_blocks_per_col}, {weight_index}}) .AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, y_shape, 1}) - .CacheHint(nbits, has_zero_points, has_bias, has_weight_idx); + .CacheHint(nbits, has_zero_points, has_bias, has_weight_idx, has_weight_idx_indirect); if (has_zero_points) { mul_program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4}); } if (bias) { mul_program.AddInput({bias, ProgramTensorMetadataDependency::None}); } + if (has_weight_idx_indirect) { + mul_program.AddInput({weight_index_indirect, ProgramTensorMetadataDependency::None}); + } return context.RunProgram(mul_program); } diff --git a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h index cb9bd8a599f54..9836d2ad04d44 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.h @@ -21,14 +21,15 @@ using namespace onnxruntime::webgpu; class SubgroupMatrixMatMulNBitsProgram final : public Program { public: - SubgroupMatrixMatMulNBitsProgram(uint32_t nbits, int32_t config_index, const wgpu::StringView& vendor, bool has_zero_points, bool has_bias, bool has_weight_idx) + SubgroupMatrixMatMulNBitsProgram(uint32_t nbits, int32_t config_index, const wgpu::StringView& vendor, bool has_zero_points, bool has_bias, bool has_weight_idx, bool has_weight_idx_indirect) : Program{"SubgroupMatrixMatMulNBits"}, nbits_(nbits), config_index_(config_index), vendor_(vendor), has_zero_points_(has_zero_points), has_bias_(has_bias), - has_weight_idx_{has_weight_idx} {}; + has_weight_idx_{has_weight_idx}, + has_weight_idx_indirect_{has_weight_idx_indirect} {}; Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( {"M", ProgramUniformVariableDataType::Uint32}, @@ -44,6 +45,7 @@ class SubgroupMatrixMatMulNBitsProgram final : public Program 0 && row < u32(row_limit)) { let col2 = col + 1; #if has_bias +#if has_weight_idx_indirect + let col_base = offset % uniforms.N + weight_index_indirect[uniforms.weight_idx] * uniforms.N; +#else let col_base = offset % uniforms.N + uniforms.weight_idx * uniforms.N; +#endif output.setByOffset(offset + row * uniforms.N + col, output_element_t(scratch[src_slot][0][row * 8 + col]) + bias[col_base + col]); output.setByOffset(offset + row * uniforms.N + col + 8, output_element_t(scratch[src_slot][1][row * 8 + col]) + bias[col_base + col + 8]); From 41701c8dc63a80ecb21413bad8c64f5e7166cd27 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Wed, 18 Feb 2026 14:53:19 -0800 Subject: [PATCH 2/6] optimized qmoe code pass for 1 token --- .../webgpu/moe/final_mix.wgsl.template | 4 +- .../webgpu/moe/final_mix_1token.wgsl.template | 19 +++++ .../contrib_ops/webgpu/moe/gate.wgsl.template | 9 ++- .../webgpu/moe/gate_1token.wgsl.template | 79 +++++++++++++++++++ .../webgpu/quantization/matmul_nbits.cc | 2 +- 5 files changed, 106 insertions(+), 7 deletions(-) create mode 100644 onnxruntime/contrib_ops/webgpu/moe/final_mix_1token.wgsl.template create mode 100644 onnxruntime/contrib_ops/webgpu/moe/gate_1token.wgsl.template diff --git a/onnxruntime/contrib_ops/webgpu/moe/final_mix.wgsl.template b/onnxruntime/contrib_ops/webgpu/moe/final_mix.wgsl.template index 80887b845f915..231470a39603e 100644 --- a/onnxruntime/contrib_ops/webgpu/moe/final_mix.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/moe/final_mix.wgsl.template @@ -5,13 +5,13 @@ // in: router_values [num_tokens, num_experts] // in: expert_tokens [used_by], mapping token idx to original token index // out: output -// uniform: used_by, hidden_size, num_experts, expert_idx +// uniform: hidden_size, expert_idx, token_offset $MAIN { let token_idx = expert_tokens[workgroup_idx]; let step = uniforms.hidden_size / workgroup_size_x; let wg_offset = local_idx * step; - // token_idx is the offset into hidden state while fc2_outputs is for the chunk and + // token_idx is the offset into hidden state while fc2_outputs is for the chunk so // we need to substract uniforms.token_offset let router_value_offset = (token_idx - uniforms.token_offset) * uniforms.num_experts + uniforms.expert_idx; let router_value = router_values[router_value_offset]; diff --git a/onnxruntime/contrib_ops/webgpu/moe/final_mix_1token.wgsl.template b/onnxruntime/contrib_ops/webgpu/moe/final_mix_1token.wgsl.template new file mode 100644 index 0000000000000..7965b7c8bf0c5 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/moe/final_mix_1token.wgsl.template @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// in: fc2_outputs [used_by, inter_size] +// in: router_values [num_tokens, num_experts] +// in: indirect_experts +// out: output +// uniform: hidden_size, expert_idx + +$MAIN { + let expert_idx = indirect_experts[uniforms.expert_idx]; + let steps = uniforms.hidden_size / workgroup_size_x; + let router_value = router_values[expert_idx]; + let offset = local_idx * steps; + for (var i = 0u; i < steps; i++) { + let weight = fc2_outputs[offset + i]; + output[offset + i] += router_value * weight; + } +} diff --git a/onnxruntime/contrib_ops/webgpu/moe/gate.wgsl.template b/onnxruntime/contrib_ops/webgpu/moe/gate.wgsl.template index 6e0d4c7299793..034a7d5305e09 100644 --- a/onnxruntime/contrib_ops/webgpu/moe/gate.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/moe/gate.wgsl.template @@ -5,7 +5,8 @@ // MOE gate shader // // called with expert as local_idx and token_idx as workgroup_idx -// in: router_values [num_tokens, num_experts], per expert float we multiply final results with +// in: router_logits [num_tokens, num_experts], per expert float we multiply final results with +// out: topk_values [num_tokens, num_experts], number of tokens assigned to each expert // out: gate_counts [num_experts], number of tokens assigned to each expert // out: gate_hidden [num_experts, num_tokens], token_idx assigned to each expert // uniform: rows(num_tokens), cols(num_experts), token_offset @@ -21,7 +22,7 @@ const MAX_FLOAT: f16 = 65504.0; const MAX_FLOAT: f32 = 3.4028234663852886e+38; #endif -var shared_vals: array; +var shared_vals: array; var shared_idxs: array; $MAIN { @@ -32,14 +33,14 @@ $MAIN { let cols = uniforms.cols; let output_base = row * cols; - var max_val: hidden_state_element_t = -MAX_FLOAT; + var max_val: router_logits_element_t = -MAX_FLOAT; var max_idx: u32 = 0u; if (global_idx < cols) { atomicStore(&tokencount_for_expert[global_idx], 0u); } if (local_idx < cols) { - max_val = hidden_state[(row + uniforms.token_offset) * cols + local_idx]; + max_val = router_logits[(row + uniforms.token_offset) * cols + local_idx]; max_idx = local_idx; } shared_vals[local_idx] = max_val; diff --git a/onnxruntime/contrib_ops/webgpu/moe/gate_1token.wgsl.template b/onnxruntime/contrib_ops/webgpu/moe/gate_1token.wgsl.template new file mode 100644 index 0000000000000..200445cc3a89c --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/moe/gate_1token.wgsl.template @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// +// MOE 1 token gate shader +// +// called with expert as local_idx +// input: router_logits +// output: topk_values +// output: indirect_experts + +#param is_fp16 +#param k + +const K: u32 = k; +#if is_fp16 +const MAX_FLOAT: f16 = 65504.0; +#else +const MAX_FLOAT: f32 = 3.4028234663852886e+38; +#endif + +var shared_vals: array; +var shared_idxs: array; + +$MAIN { + let row = workgroup_idx; + if (row >= uniforms.rows) { + return; + } + let cols = uniforms.cols; + let output_base = row * cols; + + var max_val: router_logits_element_t = -MAX_FLOAT; + var max_idx: u32 = 0u; + + if (local_idx < cols) { + max_val = router_logits[row * cols + local_idx]; + max_idx = local_idx; + } + shared_vals[local_idx] = max_val; + shared_idxs[local_idx] = max_idx; + topk_values[output_base + local_idx] = topk_values_value_t(0); + workgroupBarrier(); + + // K is small, use a simple bubble sort + for (var i = 0u; i < workgroup_size_x - 1u; i++) { + for (var j = 0u; j < workgroup_size_x - 1u - i; j++) { + if (local_idx == j && local_idx < cols && (local_idx + 1u) < cols) { + // Compare adjacent elements and swap if needed (descending order) + if (shared_vals[local_idx] < shared_vals[local_idx + 1u]) { + let temp_val = shared_vals[local_idx]; + let temp_idx = shared_idxs[local_idx]; + shared_vals[local_idx] = shared_vals[local_idx + 1u]; + shared_idxs[local_idx] = shared_idxs[local_idx + 1u]; + shared_vals[local_idx + 1u] = temp_val; + shared_idxs[local_idx + 1u] = temp_idx; + } + } + workgroupBarrier(); + } + } + if (local_idx < K) { + // found the top K experts for token, write to output + let expert_idx = shared_idxs[local_idx]; + let expert_base = expert_idx * uniforms.rows; + } + if (local_idx == 0u) { + // softmax + var sum : f32 = 0.0; + for (var i = 0u; i < K; i++) { + sum += exp(f32(shared_vals[i])); + } + for (var i = 0u; i < K; i++) { + let expert_idx = shared_idxs[i]; + topk_values[output_base + expert_idx] = topk_values_value_t(exp(f32(shared_vals[i])) / sum); + indirect_experts[i] = expert_idx; + } + } +} // MAIN diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index 4162ce29800d0..956c2b0c9380c 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -191,8 +191,8 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); const bool has_bias = bias != nullptr; - const bool has_weight_idx = weight_index > 0; const bool has_weight_idx_indirect = weight_index_indirect != nullptr; + const bool has_weight_idx = weight_index > 0 || has_weight_idx_indirect; const bool has_zero_points = zero_points != nullptr; if (has_zero_points) { ORT_ENFORCE(zero_points->DataType() == DataTypeImpl::GetType(), "Currently, only uint8 is supported for zero points, but got ", zero_points->DataType()); From 7e4b875b68f2e2b79d607cfd825df8e984af6d79 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Wed, 18 Feb 2026 14:56:24 -0800 Subject: [PATCH 3/6] missing file --- .../webgpu/quantization/dp4a_matmul_nbits.h | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h index 297cf3c3e9042..e27dd3bc39d5a 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h @@ -24,13 +24,13 @@ class DP4AMatMulNBitsProgram final : public Program { DP4AMatMulNBitsProgram(uint32_t block_size, uint32_t nbits, bool has_zero_points, bool has_bias, bool has_weight_idx, bool has_weight_idx_indirect, bool is_qualcomm) : Program{"DP4AMatMulNBits"}, - block_size_(block_size), - nbits_(nbits), - has_bias_(has_bias), - has_zero_points_(has_zero_points), - has_weight_idx_(has_weight_idx), - has_weight_idx_indirect_(has_weight_idx_indirect), - is_qualcomm_(is_qualcomm) {} + block_size_(block_size), + nbits_(nbits), + has_bias_(has_bias), + has_zero_points_(has_zero_points), + has_weight_idx_(has_weight_idx), + has_weight_idx_indirect_(has_weight_idx_indirect), + is_qualcomm_(is_qualcomm) {} Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( {"batch_count", ProgramUniformVariableDataType::Uint32}, @@ -59,14 +59,14 @@ class DP4AMatMulNBitsSmallMProgram final : public Program Date: Wed, 18 Feb 2026 15:01:55 -0800 Subject: [PATCH 4/6] add missing file --- onnxruntime/contrib_ops/webgpu/moe/qmoe.cc | 127 +++++++++++++++++++-- 1 file changed, 119 insertions(+), 8 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/moe/qmoe.cc b/onnxruntime/contrib_ops/webgpu/moe/qmoe.cc index c67cf8e37be69..39a1d1230ddf1 100755 --- a/onnxruntime/contrib_ops/webgpu/moe/qmoe.cc +++ b/onnxruntime/contrib_ops/webgpu/moe/qmoe.cc @@ -22,7 +22,7 @@ class GateProgram final : public Program { GateProgram(int k, bool is_fp16) : Program{"QmoeGate"}, k_{k}, is_fp16_{is_fp16} {}; Status GenerateShaderCode(ShaderHelper& shader) const override { - shader.AddInput("hidden_state", ShaderUsage::UseElementTypeAlias); + shader.AddInput("router_logits", ShaderUsage::UseElementTypeAlias); shader.AddOutput("topk_values"); shader.AddOutput("hiddenstate_for_expert"); shader.AddOutput("tokencount_for_expert"); @@ -42,6 +42,29 @@ class GateProgram final : public Program { bool is_fp16_; }; +class Gate1TokenProgram final : public Program { + public: + Gate1TokenProgram(int k, bool is_fp16) : Program{"QmoeGate1Token"}, k_{k}, is_fp16_{is_fp16} {}; + + Status GenerateShaderCode(ShaderHelper& shader) const override { + shader.AddInput("router_logits", ShaderUsage::UseElementTypeAlias); + shader.AddOutput("topk_values"); + shader.AddOutput("indirect_experts"); + + return WGSL_TEMPLATE_APPLY(shader, "moe/gate_1token.wgsl.template", + WGSL_TEMPLATE_PARAMETER(is_fp16, is_fp16_), + WGSL_TEMPLATE_PARAMETER(k, k_)); + }; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"rows", ProgramUniformVariableDataType::Uint32}, + {"cols", ProgramUniformVariableDataType::Uint32}); + + private: + int k_; + bool is_fp16_; +}; + class HiddenStateGatherProgram final : public Program { public: HiddenStateGatherProgram() : Program{"QmoeHiddenStateGather"} {}; @@ -115,7 +138,6 @@ class QMoEFinalMixProgram final : public Program { } WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( - {"used_by", ProgramUniformVariableDataType::Uint32}, {"hidden_size", ProgramUniformVariableDataType::Uint32}, {"num_experts", ProgramUniformVariableDataType::Uint32}, {"expert_idx", ProgramUniformVariableDataType::Uint32}, @@ -124,6 +146,26 @@ class QMoEFinalMixProgram final : public Program { private: }; +class QMoEFinalMix1TokenProgram final : public Program { + public: + QMoEFinalMix1TokenProgram() : Program{"QMoEFinalMix1TokenProgram"} {} + + Status GenerateShaderCode(ShaderHelper& shader) const override { + shader.AddInput("fc2_outputs", ShaderUsage::UseElementTypeAlias); + shader.AddInput("router_values", ShaderUsage::UseElementTypeAlias); + shader.AddInput("indirect_experts", ShaderUsage::UseElementTypeAlias); + shader.AddOutput("output", ShaderUsage::UseElementTypeAlias); + + return WGSL_TEMPLATE_APPLY(shader, "moe/final_mix_1token.wgsl.template"); + } + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"hidden_size", ProgramUniformVariableDataType::Uint32}, + {"expert_idx", ProgramUniformVariableDataType::Uint32}); + + private: +}; + Status QMoE::ComputeInternal(ComputeContext& context) const { const Tensor* hidden_state = context.Input(0); const Tensor* router_logits = context.Input(1); @@ -168,7 +210,7 @@ Status QMoE::ComputeInternal(ComputeContext& context) const { } // process tokens in chunks of max_tokens to put some cap on memory usage - const int max_tokens = 512; + const int max_tokens = 2 * 1024; const uint32_t num_experts = static_cast(moe_params.num_experts); const uint32_t hidden_size = static_cast(moe_params.hidden_size); @@ -197,6 +239,78 @@ Status QMoE::ComputeInternal(ComputeContext& context) const { .AddUniformVariables({static_cast(total_output_size)}); ORT_RETURN_IF_ERROR(context.RunProgram(zero)); + if (moe_params.num_rows == 1) { + // Optimized code path for 1 token to avoid gpu -> cpu copy + + const int num_tokens = 1; + TensorShape gate_value_shape({num_tokens, num_experts}); + TensorShape indirect_experts_shape({k_}); + + Tensor router_values = context.CreateGPUTensor(dtype, gate_value_shape); + Tensor indirect_experts = context.CreateGPUTensor(dtype_uint32, indirect_experts_shape); + + Gate1TokenProgram gate{k_, is_fp16}; + gate + .AddInputs({{router_logits, ProgramTensorMetadataDependency::Type}}) + .AddOutput({&router_values, ProgramTensorMetadataDependency::None}) + .AddOutput({&indirect_experts, ProgramTensorMetadataDependency::None}) + .SetWorkgroupSize(num_experts) + .SetDispatchGroupSize(static_cast(num_tokens)) + .AddUniformVariables({static_cast(num_tokens), num_experts}) + .CacheHint(k_, is_fp16 ? "fp16" : "fp32"); + + ORT_RETURN_IF_ERROR(context.RunProgram(gate)); + + for (uint32_t expert_idx = 0; expert_idx < static_cast(k_); expert_idx++) { + TensorShape fc1_output_shape({num_tokens, fc1_output_size}); + Tensor fc1_outputs = context.CreateGPUTensor(dtype, fc1_output_shape); + TensorShape fc1_activated_shape({num_tokens, moe_params.inter_size}); + Tensor fc1_activated = context.CreateGPUTensor(dtype, fc1_activated_shape); + TensorShape fc2_output_shape({num_tokens, N_fc2}); + Tensor fc2_outputs = context.CreateGPUTensor(dtype, fc2_output_shape); + + status = ApplyMatMulNBits(hidden_state, fc1_experts_weights, fc1_scales, nullptr, fc1_experts_bias_optional, + K_fc1, N_fc1, block_size_fc1, accuracy_level, expert_weight_bits_, context, + &fc1_outputs, expert_idx, &indirect_experts); + ORT_RETURN_IF_ERROR(status); + + if (is_swiglu) { + SwigLuProgram swiglu; + swiglu + .AddInputs({{&fc1_outputs, ProgramTensorMetadataDependency::Type, 2}}) + .AddOutput({&fc1_activated, ProgramTensorMetadataDependency::None}) + .SetWorkgroupSize(128) + .SetDispatchGroupSize(((num_tokens * static_cast(moe_params.inter_size)) + 127) / 128) + .AddUniformVariables({static_cast(num_tokens), + static_cast(moe_params.inter_size), + activation_alpha_, + activation_beta_, + swiglu_limit_}); + ORT_RETURN_IF_ERROR(context.RunProgram(swiglu)); + } else { + ORT_THROW("only swiglu is supported for WebGPU."); + } + + status = ApplyMatMulNBits(&fc1_activated, fc2_experts_weights, fc2_scales, nullptr, fc2_experts_bias_optional, + K_fc2, N_fc2, block_size_fc2, accuracy_level, expert_weight_bits_, context, + &fc2_outputs, expert_idx, &indirect_experts); + ORT_RETURN_IF_ERROR(status); + + QMoEFinalMix1TokenProgram final_mix; + final_mix + .AddInputs({{&fc2_outputs, ProgramTensorMetadataDependency::Type}}) + .AddInputs({{&router_values, ProgramTensorMetadataDependency::Type}}) + .AddInputs({{&indirect_experts, ProgramTensorMetadataDependency::Type}}) + .AddOutput({output_tensor, ProgramTensorMetadataDependency::None}) + .SetDispatchGroupSize(1) + .AddUniformVariables({hidden_size, expert_idx}); + + ORT_RETURN_IF_ERROR(context.RunProgram(final_mix)); + } + return Status::OK(); + } + + // path for num_tokens > 1 // process tokens in chunks of max_tokens to put some cap on memory usage for (int token_offset = 0; token_offset < moe_params.num_rows; token_offset += max_tokens) { // @@ -226,9 +340,7 @@ Status QMoE::ComputeInternal(ComputeContext& context) const { .AddOutput({&gate_counts, ProgramTensorMetadataDependency::None, ProgramOutput::Atomic}) .SetWorkgroupSize(num_experts) .SetDispatchGroupSize(static_cast(num_tokens)) - .AddUniformVariables({static_cast(num_tokens), - num_experts, - static_cast(token_offset)}) + .AddUniformVariables({static_cast(num_tokens), num_experts, static_cast(token_offset)}) .CacheHint(k_, is_fp16 ? "fp16" : "fp32"); ORT_RETURN_IF_ERROR(context.RunProgram(gate)); @@ -318,8 +430,7 @@ Status QMoE::ComputeInternal(ComputeContext& context) const { .AddInputs({{&expert_tokens, ProgramTensorMetadataDependency::Type}}) .AddOutput({output_tensor, ProgramTensorMetadataDependency::None}) .SetDispatchGroupSize(used_by) - .AddUniformVariables({used_by, - hidden_size, + .AddUniformVariables({hidden_size, num_experts, expert_idx, static_cast(token_offset)}); From 89c648a5615ef8ba10a092f4796082e8f24cb5fe Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Thu, 19 Feb 2026 08:21:02 -0800 Subject: [PATCH 5/6] Update onnxruntime/contrib_ops/webgpu/moe/gate_1token.wgsl.template Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxruntime/contrib_ops/webgpu/moe/gate_1token.wgsl.template | 5 ----- 1 file changed, 5 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/moe/gate_1token.wgsl.template b/onnxruntime/contrib_ops/webgpu/moe/gate_1token.wgsl.template index 200445cc3a89c..afcb93ade702b 100644 --- a/onnxruntime/contrib_ops/webgpu/moe/gate_1token.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/moe/gate_1token.wgsl.template @@ -59,11 +59,6 @@ $MAIN { workgroupBarrier(); } } - if (local_idx < K) { - // found the top K experts for token, write to output - let expert_idx = shared_idxs[local_idx]; - let expert_base = expert_idx * uniforms.rows; - } if (local_idx == 0u) { // softmax var sum : f32 = 0.0; From 957e65fca698d5a13bdc8eb260bb0128d155c74a Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Thu, 19 Feb 2026 08:22:51 -0800 Subject: [PATCH 6/6] Update onnxruntime/contrib_ops/webgpu/moe/final_mix.wgsl.template Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxruntime/contrib_ops/webgpu/moe/final_mix.wgsl.template | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/webgpu/moe/final_mix.wgsl.template b/onnxruntime/contrib_ops/webgpu/moe/final_mix.wgsl.template index 231470a39603e..c8ac409ca2932 100644 --- a/onnxruntime/contrib_ops/webgpu/moe/final_mix.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/moe/final_mix.wgsl.template @@ -5,7 +5,7 @@ // in: router_values [num_tokens, num_experts] // in: expert_tokens [used_by], mapping token idx to original token index // out: output -// uniform: hidden_size, expert_idx, token_offset +// uniform: hidden_size, num_experts, expert_idx, token_offset $MAIN { let token_idx = expert_tokens[workgroup_idx];