Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 0 additions & 19 deletions onnxruntime/contrib_ops/webgpu/moe/final_mix_1token.wgsl.template

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Fused FinalMix for 1-token MoE: processes all k experts in one dispatch.
// Each workgroup handles one of the k expert results.
Comment thread
qjia7 marked this conversation as resolved.
Outdated
// in: fc2_outputs [k, hidden_size] — concatenated fc2 results for all k experts
// in: router_values [1, num_experts] — softmax weights per expert
// in: indirect_experts [k] — which expert index each row corresponds to
// out: output [1, hidden_size] — accumulated weighted output
// uniform: hidden_size, k

$MAIN {
let out_idx = workgroup_idx * workgroup_size_x + local_idx;
if (out_idx >= uniforms.hidden_size) {
return;
}
var acc = output_element_t(0);
for (var i = 0u; i < uniforms.k; i++) {
let expert_idx = indirect_experts[i];
let router_value = router_values[expert_idx];
acc += router_value * fc2_outputs[i * uniforms.hidden_size + out_idx];
}
output[out_idx] = acc;
} // MAIN
137 changes: 72 additions & 65 deletions onnxruntime/contrib_ops/webgpu/moe/qmoe.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,44 +124,41 @@ class SwigLuProgram final : public Program<SwigLuProgram> {
private:
};

class QMoEFinalMixProgram final : public Program<QMoEFinalMixProgram> {
class FusedFinalMix1TokenProgram final : public Program<FusedFinalMix1TokenProgram> {
public:
QMoEFinalMixProgram() : Program<QMoEFinalMixProgram>{"QMoEFinalMix"} {}
FusedFinalMix1TokenProgram() : Program<FusedFinalMix1TokenProgram>{"QmoeFusedFinalMix1Token"} {}

Status GenerateShaderCode(ShaderHelper& shader) const override {
shader.AddInput("fc2_outputs", ShaderUsage::UseElementTypeAlias);
shader.AddInput("router_values", ShaderUsage::UseElementTypeAlias);
shader.AddInput("expert_tokens", ShaderUsage::UseElementTypeAlias);
shader.AddInput("indirect_experts", ShaderUsage::UseElementTypeAlias);
shader.AddOutput("output", ShaderUsage::UseElementTypeAlias);

return WGSL_TEMPLATE_APPLY(shader, "moe/final_mix.wgsl.template");
return WGSL_TEMPLATE_APPLY(shader, "moe/fused_final_mix_1token.wgsl.template");
}

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
{"hidden_size", ProgramUniformVariableDataType::Uint32},
{"num_experts", ProgramUniformVariableDataType::Uint32},
{"expert_idx", ProgramUniformVariableDataType::Uint32},
{"token_offset", ProgramUniformVariableDataType::Uint32});

private:
{"k", ProgramUniformVariableDataType::Uint32});
};

class QMoEFinalMix1TokenProgram final : public Program<QMoEFinalMix1TokenProgram> {
class QMoEFinalMixProgram final : public Program<QMoEFinalMixProgram> {
public:
QMoEFinalMix1TokenProgram() : Program<QMoEFinalMix1TokenProgram>{"QMoEFinalMix1TokenProgram"} {}
QMoEFinalMixProgram() : Program<QMoEFinalMixProgram>{"QMoEFinalMix"} {}

Status GenerateShaderCode(ShaderHelper& shader) const override {
shader.AddInput("fc2_outputs", ShaderUsage::UseElementTypeAlias);
shader.AddInput("router_values", ShaderUsage::UseElementTypeAlias);
shader.AddInput("indirect_experts", ShaderUsage::UseElementTypeAlias);
shader.AddInput("expert_tokens", ShaderUsage::UseElementTypeAlias);
shader.AddOutput("output", ShaderUsage::UseElementTypeAlias);

return WGSL_TEMPLATE_APPLY(shader, "moe/final_mix_1token.wgsl.template");
return WGSL_TEMPLATE_APPLY(shader, "moe/final_mix.wgsl.template");
}

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
{"hidden_size", ProgramUniformVariableDataType::Uint32},
{"expert_idx", ProgramUniformVariableDataType::Uint32});
{"num_experts", ProgramUniformVariableDataType::Uint32},
{"expert_idx", ProgramUniformVariableDataType::Uint32},
{"token_offset", ProgramUniformVariableDataType::Uint32});

private:
};
Expand Down Expand Up @@ -246,73 +243,83 @@ Status QMoE::ComputeInternal(ComputeContext& context) const {
ORT_RETURN_IF_ERROR(context.RunProgram(zero));

if (moe_params.num_rows == 1) {
// Optimized code path for 1 token to avoid gpu -> cpu copy
// Fused MoE path for 1 token: instead of looping k times with separate dispatches,
// run a single batched MatMulNBits with M=k where each row uses a different expert's
// weights via weight_index_indirect. A's single row is broadcast to all k rows.
// This reduces dispatches from 1 + k*4 = 17 to 5 (gate + fc1 + swiglu + fc2 + mix).

const int num_tokens = 1;
const uint32_t k = static_cast<uint32_t>(k_);
const uint32_t num_tokens = 1;
TensorShape gate_value_shape({num_tokens, num_experts});
TensorShape indirect_experts_shape({k_});
TensorShape indirect_experts_shape({k});

Tensor router_values = context.CreateGPUTensor(dtype, gate_value_shape);
Tensor indirect_experts = context.CreateGPUTensor(dtype_uint32, indirect_experts_shape);

// Step 1: Gate — select top-k experts
Gate1TokenProgram gate{k_, is_fp16};
gate
.AddInputs({{router_logits, ProgramTensorMetadataDependency::Type}})
.AddOutput({&router_values, ProgramTensorMetadataDependency::None})
.AddOutput({&indirect_experts, ProgramTensorMetadataDependency::None})
.SetWorkgroupSize(num_experts)
.SetDispatchGroupSize(static_cast<uint32_t>(num_tokens))
.AddUniformVariables({static_cast<uint32_t>(num_tokens), num_experts})
.SetDispatchGroupSize(num_tokens)
.AddUniformVariables({num_tokens, num_experts})
.CacheHint(k_, is_fp16 ? "fp16" : "fp32");

ORT_RETURN_IF_ERROR(context.RunProgram(gate));

for (uint32_t expert_idx = 0; expert_idx < static_cast<uint32_t>(k_); expert_idx++) {
TensorShape fc1_output_shape({num_tokens, fc1_output_size});
Tensor fc1_outputs = context.CreateGPUTensor(dtype, fc1_output_shape);
TensorShape fc1_activated_shape({num_tokens, moe_params.inter_size});
Tensor fc1_activated = context.CreateGPUTensor(dtype, fc1_activated_shape);
TensorShape fc2_output_shape({num_tokens, N_fc2});
Tensor fc2_outputs = context.CreateGPUTensor(dtype, fc2_output_shape);

status = ApplyMatMulNBits(hidden_state, fc1_experts_weights, fc1_scales, nullptr, fc1_experts_bias_optional,
K_fc1, N_fc1, block_size_fc1, accuracy_level, expert_weight_bits_, context,
&fc1_outputs, expert_idx, &indirect_experts);
ORT_RETURN_IF_ERROR(status);

if (is_swiglu) {
SwigLuProgram swiglu;
swiglu
.AddInputs({{&fc1_outputs, ProgramTensorMetadataDependency::Type, 2}})
.AddOutput({&fc1_activated, ProgramTensorMetadataDependency::None})
.SetWorkgroupSize(128)
.SetDispatchGroupSize(((num_tokens * static_cast<uint32_t>(moe_params.inter_size)) + 127) / 128)
.AddUniformVariables({static_cast<uint32_t>(num_tokens),
static_cast<uint32_t>(moe_params.inter_size),
activation_alpha_,
activation_beta_,
swiglu_limit_});
ORT_RETURN_IF_ERROR(context.RunProgram(swiglu));
} else {
ORT_THROW("only swiglu is supported for WebGPU.");
}
// Step 2: Batched fc1 MatMulNBits with M=k, per-row expert selection.
// A is (1, hidden_size) but dispatched with override_M=k; shader broadcasts A row 0.
TensorShape fc1_output_shape({static_cast<int64_t>(k), fc1_output_size});
Tensor fc1_outputs = context.CreateGPUTensor(dtype, fc1_output_shape);
status = ApplyMatMulNBits(hidden_state, fc1_experts_weights, fc1_scales, nullptr, fc1_experts_bias_optional,
K_fc1, N_fc1, block_size_fc1, accuracy_level, expert_weight_bits_, context,
&fc1_outputs, 0, &indirect_experts, /*override_M=*/k);
ORT_RETURN_IF_ERROR(status);

// Step 3: SwiGLU on all k rows at once
TensorShape fc1_activated_shape({static_cast<int64_t>(k), moe_params.inter_size});
Tensor fc1_activated = context.CreateGPUTensor(dtype, fc1_activated_shape);
if (is_swiglu) {
SwigLuProgram swiglu;
swiglu
.AddInputs({{&fc1_outputs, ProgramTensorMetadataDependency::Type, 2}})
.AddOutput({&fc1_activated, ProgramTensorMetadataDependency::None})
.SetWorkgroupSize(128)
.SetDispatchGroupSize(((k * static_cast<uint32_t>(moe_params.inter_size)) + 127) / 128)
.AddUniformVariables({k,
static_cast<uint32_t>(moe_params.inter_size),
activation_alpha_,
activation_beta_,
swiglu_limit_});
ORT_RETURN_IF_ERROR(context.RunProgram(swiglu));
} else {
ORT_THROW("only swiglu is supported for WebGPU.");
}

status = ApplyMatMulNBits(&fc1_activated, fc2_experts_weights, fc2_scales, nullptr, fc2_experts_bias_optional,
K_fc2, N_fc2, block_size_fc2, accuracy_level, expert_weight_bits_, context,
&fc2_outputs, expert_idx, &indirect_experts);
ORT_RETURN_IF_ERROR(status);
// Step 4: Batched fc2 MatMulNBits with M=k, per-row expert selection
// fc1_activated already has k rows (one per expert), no override_M needed.
TensorShape fc2_output_shape({static_cast<int64_t>(k), N_fc2});
Tensor fc2_outputs = context.CreateGPUTensor(dtype, fc2_output_shape);
status = ApplyMatMulNBits(&fc1_activated, fc2_experts_weights, fc2_scales, nullptr, fc2_experts_bias_optional,
K_fc2, N_fc2, block_size_fc2, accuracy_level, expert_weight_bits_, context,
&fc2_outputs, 0, &indirect_experts, /*override_M=*/0);
ORT_RETURN_IF_ERROR(status);

// Step 5: Fused FinalMix — accumulate all k expert results weighted by router_values
// Dispatch across hidden_size (not k) to avoid race: each thread accumulates all k experts.
const uint32_t mix_wg_size = 256;
FusedFinalMix1TokenProgram final_mix;
final_mix
Comment thread
qjia7 marked this conversation as resolved.
.AddInputs({{&fc2_outputs, ProgramTensorMetadataDependency::Type}})
.AddInputs({{&router_values, ProgramTensorMetadataDependency::Type}})
.AddInputs({{&indirect_experts, ProgramTensorMetadataDependency::Type}})
.AddOutput({output_tensor, ProgramTensorMetadataDependency::None})
.SetWorkgroupSize(mix_wg_size)
.SetDispatchGroupSize((hidden_size + mix_wg_size - 1) / mix_wg_size)
.AddUniformVariables({hidden_size, k});
ORT_RETURN_IF_ERROR(context.RunProgram(final_mix));

QMoEFinalMix1TokenProgram final_mix;
final_mix
.AddInputs({{&fc2_outputs, ProgramTensorMetadataDependency::Type}})
.AddInputs({{&router_values, ProgramTensorMetadataDependency::Type}})
.AddInputs({{&indirect_experts, ProgramTensorMetadataDependency::Type}})
.AddOutput({output_tensor, ProgramTensorMetadataDependency::None})
.SetDispatchGroupSize(1)
.AddUniformVariables({hidden_size, expert_idx});

ORT_RETURN_IF_ERROR(context.RunProgram(final_mix));
}
return Status::OK();
}

Expand Down
15 changes: 9 additions & 6 deletions onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co
ORT_ENFORCE(tile_size_ % sub_tile_count == 0, "tile_size_ must be divisible by sub_tile_count");

return WGSL_TEMPLATE_APPLY(shader, "quantization/dp4a_matmul_small_m.wgsl.template",
WGSL_TEMPLATE_PARAMETER(broadcast_a_row, broadcast_a_row_),
WGSL_TEMPLATE_PARAMETER(has_bias, has_bias_),
WGSL_TEMPLATE_PARAMETER(has_weight_idx, has_weight_idx_),
WGSL_TEMPLATE_PARAMETER(has_weight_idx_indirect, has_weight_idx_indirect_),
Expand All @@ -93,6 +94,7 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
const Tensor* zero_points, const Tensor* bias,
uint32_t batch_count,
uint32_t M,
uint32_t dispatch_M,
uint32_t N,
uint32_t K,
uint32_t block_size,
Expand Down Expand Up @@ -124,25 +126,26 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor

const bool has_zero_points = zero_points != nullptr;
const bool has_bias = bias != nullptr;
const bool has_weight_idx = weight_index != 0;
const bool has_weight_idx_indirect = weight_index_indirect != nullptr;
const bool has_weight_idx = weight_index != 0 || has_weight_idx_indirect;
const bool single_scale_weights = (block_size == K * N);
if (M < min_M_for_tile_optimization) {
if (has_weight_idx_indirect || M < min_M_for_tile_optimization) {
uint32_t tile_size_k_vec = 32;
uint32_t tile_size_n = 4;

const uint32_t b_components = (nbits == 2 ? kVec2Components : kVec4Components);
DP4AMatMulNBitsSmallMProgram mul_program{tile_size_k_vec, tile_size_n, nbits, has_zero_points, has_bias, has_weight_idx, has_weight_idx_indirect, single_scale_weights};
const bool broadcast_a = dispatch_M > M;
DP4AMatMulNBitsSmallMProgram mul_program{tile_size_k_vec, tile_size_n, nbits, has_zero_points, has_bias, has_weight_idx, has_weight_idx_indirect, single_scale_weights, broadcast_a};
uint32_t num_N_tile = (N + tile_size_n - 1) / tile_size_n;
mul_program.SetWorkgroupSize(128);
mul_program.SetDispatchGroupSize(batch_count * M * num_N_tile);
mul_program.SetDispatchGroupSize(batch_count * dispatch_M * num_N_tile);
mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(kVec4Components)},
{&a_scale, ProgramTensorMetadataDependency::TypeAndRank, 1},
{b, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(b_components * kU32Components)},
{scales, ProgramTensorMetadataDependency::TypeAndRank, 1}})
.AddUniformVariables({batch_count, M, N, K, K / 16, K / 32, block_size, num_N_tile, zero_blocks_per_col, weight_index})
.AddUniformVariables({batch_count, dispatch_M, N, K, K / 16, K / 32, block_size, num_N_tile, zero_blocks_per_col, weight_index})
.AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, 1})
.CacheHint(nbits, tile_size_k_vec, tile_size_n, has_zero_points, single_scale_weights, has_bias, has_weight_idx, has_weight_idx_indirect);
.CacheHint(nbits, tile_size_k_vec, tile_size_n, has_zero_points, single_scale_weights, has_bias, has_weight_idx, has_weight_idx_indirect, broadcast_a);
if (has_zero_points) {
mul_program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4});
}
Expand Down
22 changes: 13 additions & 9 deletions onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,17 @@ class DP4AMatMulNBitsSmallMProgram final : public Program<DP4AMatMulNBitsSmallMP
public:
DP4AMatMulNBitsSmallMProgram(uint32_t tile_size_k_vec, uint32_t tile_size, uint32_t nbits,
bool has_zero_points, bool has_bias,
bool has_weight_idx, bool has_weight_idx_indirect, bool single_scale_weights) : Program{"DP4AMatMulNBitsSmallMProgram"},
tile_size_k_vec_(tile_size_k_vec),
tile_size_(tile_size),
nbits_(nbits),
has_bias_(has_bias),
has_zero_points_(has_zero_points),
has_weight_idx_(has_weight_idx),
has_weight_idx_indirect_(has_weight_idx_indirect),
single_scale_weights_(single_scale_weights) {}
bool has_weight_idx, bool has_weight_idx_indirect, bool single_scale_weights,
bool broadcast_a_row = false) : Program{"DP4AMatMulNBitsSmallMProgram"},
tile_size_k_vec_(tile_size_k_vec),
tile_size_(tile_size),
nbits_(nbits),
has_bias_(has_bias),
has_zero_points_(has_zero_points),
has_weight_idx_(has_weight_idx),
has_weight_idx_indirect_(has_weight_idx_indirect),
single_scale_weights_(single_scale_weights),
broadcast_a_row_(broadcast_a_row) {}
Status GenerateShaderCode(ShaderHelper& sh) const override;
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
{"batch_count", ProgramUniformVariableDataType::Uint32},
Expand All @@ -89,12 +91,14 @@ class DP4AMatMulNBitsSmallMProgram final : public Program<DP4AMatMulNBitsSmallMP
bool has_weight_idx_;
bool has_weight_idx_indirect_;
bool single_scale_weights_;
bool broadcast_a_row_;
};

Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,
const Tensor* zero_points, const Tensor* bias,
uint32_t batch_count,
uint32_t M,
uint32_t dispatch_M,
uint32_t N,
uint32_t K,
uint32_t block_size,
Expand Down
Loading
Loading