Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/webgpu/moe/final_mix.wgsl.template
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
// in: router_values [num_tokens, num_experts]
// in: expert_tokens [used_by], mapping token idx to original token index
// out: output
// uniform: used_by, hidden_size, num_experts, expert_idx
// uniform: hidden_size, num_experts, expert_idx, token_offset

$MAIN {
let token_idx = expert_tokens[workgroup_idx];
let step = uniforms.hidden_size / workgroup_size_x;
let wg_offset = local_idx * step;
// token_idx is the offset into hidden state while fc2_outputs is for the chunk and
// token_idx is the offset into hidden state while fc2_outputs is for the chunk so
// we need to subtract uniforms.token_offset

Check warning on line 15 in onnxruntime/contrib_ops/webgpu/moe/final_mix.wgsl.template

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "substract" is a misspelling of "subtract" Raw Output: ./onnxruntime/contrib_ops/webgpu/moe/final_mix.wgsl.template:15:18: "substract" is a misspelling of "subtract"
let router_value_offset = (token_idx - uniforms.token_offset) * uniforms.num_experts + uniforms.expert_idx;
let router_value = router_values[router_value_offset];
let fc2_outputs_offset = workgroup_idx * uniforms.hidden_size + wg_offset;
Expand Down
19 changes: 19 additions & 0 deletions onnxruntime/contrib_ops/webgpu/moe/final_mix_1token.wgsl.template
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// in: fc2_outputs [1, hidden_size]
Comment thread
guschmue marked this conversation as resolved.
// in: router_values [num_tokens, num_experts]
// in: indirect_experts
// out: output
// uniform: hidden_size, expert_idx

$MAIN {
  // Accumulates one selected expert's FC2 result into the output row:
  //   output[i] += router_values[expert] * fc2_outputs[i]   for i in [0, hidden_size)
  //
  // uniforms.expert_idx is the slot in indirect_experts; the value stored
  // there is the actual expert id used to look up the routing weight.
  let expert_idx = indirect_experts[uniforms.expert_idx];
  let router_value = router_values[expert_idx];

  // Workgroup-strided loop: invocation `local_idx` handles elements
  // local_idx, local_idx + workgroup_size_x, ... This covers every element
  // exactly once even when hidden_size is not a multiple of workgroup_size_x
  // (a per-invocation chunk of hidden_size / workgroup_size_x would silently
  // drop the remainder). Expects to be dispatched as a single workgroup.
  for (var i = local_idx; i < uniforms.hidden_size; i += workgroup_size_x) {
    let fc2_value = fc2_outputs[i];
    output[i] += router_value * fc2_value;
  }
}
9 changes: 5 additions & 4 deletions onnxruntime/contrib_ops/webgpu/moe/gate.wgsl.template
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
// MOE gate shader
//
// called with expert as local_idx and token_idx as workgroup_idx
// in: router_values [num_tokens, num_experts], per expert float we multiply final results with
// in: router_logits [num_tokens, num_experts], raw router score per expert
// out: topk_values [num_tokens, num_experts], per expert float we multiply final results with
// out: gate_counts [num_experts], number of tokens assigned to each expert
// out: gate_hidden [num_experts, num_tokens], token_idx assigned to each expert
// uniform: rows(num_tokens), cols(num_experts), token_offset
Expand All @@ -21,7 +22,7 @@ const MAX_FLOAT: f16 = 65504.0;
const MAX_FLOAT: f32 = 3.4028234663852886e+38;
#endif

var<workgroup> shared_vals: array<hidden_state_element_t, workgroup_size_x>;
var<workgroup> shared_vals: array<router_logits_element_t, workgroup_size_x>;
var<workgroup> shared_idxs: array<u32, workgroup_size_x>;

$MAIN {
Expand All @@ -32,14 +33,14 @@ $MAIN {
let cols = uniforms.cols;
let output_base = row * cols;

var max_val: hidden_state_element_t = -MAX_FLOAT;
var max_val: router_logits_element_t = -MAX_FLOAT;
var max_idx: u32 = 0u;

if (global_idx < cols) {
atomicStore(&tokencount_for_expert[global_idx], 0u);
}
if (local_idx < cols) {
max_val = hidden_state[(row + uniforms.token_offset) * cols + local_idx];
max_val = router_logits[(row + uniforms.token_offset) * cols + local_idx];
max_idx = local_idx;
}
shared_vals[local_idx] = max_val;
Expand Down
74 changes: 74 additions & 0 deletions onnxruntime/contrib_ops/webgpu/moe/gate_1token.wgsl.template
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

//
// MOE 1 token gate shader
//
// called with expert as local_idx
// input: router_logits
// output: topk_values
// output: indirect_experts

// Template parameters supplied by the host program:
//   is_fp16 - selects the f16 or f32 definition of MAX_FLOAT below
//   k       - number of top experts selected per token
#param is_fp16
#param k

// K bounds the softmax / output loops in $MAIN.
const K: u32 = k;
// MAX_FLOAT is the largest finite value of the selected float type; $MAIN uses
// -MAX_FLOAT as the initial value for invocations that have no expert to load.
#if is_fp16
const MAX_FLOAT: f16 = 65504.0;
#else
const MAX_FLOAT: f32 = 3.4028234663852886e+38;
#endif

// Per-workgroup scratch: one (logit, expert index) pair per invocation,
// sorted in place by $MAIN.
var<workgroup> shared_vals: array<router_logits_element_t, workgroup_size_x>;
var<workgroup> shared_idxs: array<u32, workgroup_size_x>;

$MAIN {
  // Single-token MoE gate.
  //
  // One workgroup processes one row of router_logits; invocation `local_idx`
  // owns expert `local_idx`. The row is sorted in descending order in
  // workgroup memory, then invocation 0 computes a softmax over the top K
  // logits and writes:
  //   topk_values[row, expert] = softmax weight for the K selected experts, 0 elsewhere
  //   indirect_experts[i]      = expert id of the i-th largest logit, i in [0, K)
  let row = workgroup_idx;
  if (row >= uniforms.rows) {
    return;
  }
  let cols = uniforms.cols;
  let output_base = row * cols;

  // Invocations without an expert keep the -MAX_FLOAT sentinel.
  var max_val: router_logits_element_t = -MAX_FLOAT;
  var max_idx: u32 = 0u;

  if (local_idx < cols) {
    max_val = router_logits[row * cols + local_idx];
    max_idx = local_idx;
    // Clear this expert's weight; only the top-K entries are overwritten below.
    // Guarded so invocations past the last expert never write outside the row.
    topk_values[output_base + local_idx] = topk_values_value_t(0);
  }
  shared_vals[local_idx] = max_val;
  shared_idxs[local_idx] = max_idx;
  workgroupBarrier();
  // workgroupBarrier() only orders workgroup memory. The zero-fill above and
  // the top-K writes by invocation 0 below target the same storage buffer from
  // different invocations, so order them explicitly.
  storageBarrier();

  // The number of experts is small, so a simple bubble sort of the whole row
  // (descending) is sufficient. Each step lets exactly one invocation compare
  // and swap one adjacent pair, followed by a barrier.
  for (var i = 0u; i < workgroup_size_x - 1u; i++) {
    for (var j = 0u; j < workgroup_size_x - 1u - i; j++) {
      if (local_idx == j && local_idx < cols && (local_idx + 1u) < cols) {
        if (shared_vals[local_idx] < shared_vals[local_idx + 1u]) {
          let temp_val = shared_vals[local_idx];
          let temp_idx = shared_idxs[local_idx];
          shared_vals[local_idx] = shared_vals[local_idx + 1u];
          shared_idxs[local_idx] = shared_idxs[local_idx + 1u];
          shared_vals[local_idx + 1u] = temp_val;
          shared_idxs[local_idx + 1u] = temp_idx;
        }
      }
      workgroupBarrier();
    }
  }

  if (local_idx == 0u) {
    // Softmax over the top K logits. After the descending sort shared_vals[0]
    // is the maximum; subtracting it before exp() leaves the result unchanged
    // mathematically but keeps exp() from overflowing for large logits.
    let top_val = f32(shared_vals[0]);
    var sum: f32 = 0.0;
    for (var i = 0u; i < K; i++) {
      sum += exp(f32(shared_vals[i]) - top_val);
    }
    for (var i = 0u; i < K; i++) {
      let expert_idx = shared_idxs[i];
      topk_values[output_base + expert_idx] = topk_values_value_t(exp(f32(shared_vals[i]) - top_val) / sum);
      indirect_experts[i] = expert_idx;
    }
  }
} // MAIN
127 changes: 119 additions & 8 deletions onnxruntime/contrib_ops/webgpu/moe/qmoe.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class GateProgram final : public Program<GateProgram> {
GateProgram(int k, bool is_fp16) : Program<GateProgram>{"QmoeGate"}, k_{k}, is_fp16_{is_fp16} {};

Status GenerateShaderCode(ShaderHelper& shader) const override {
shader.AddInput("hidden_state", ShaderUsage::UseElementTypeAlias);
shader.AddInput("router_logits", ShaderUsage::UseElementTypeAlias);
shader.AddOutput("topk_values");
shader.AddOutput("hiddenstate_for_expert");
shader.AddOutput("tokencount_for_expert");
Expand All @@ -42,6 +42,29 @@ class GateProgram final : public Program<GateProgram> {
bool is_fp16_;
};

// WebGPU program for the single-token MoE gate (moe/gate_1token.wgsl.template).
// Reads `router_logits`; writes `topk_values` and `indirect_experts`.
// `k` and `is_fp16` are forwarded to the shader template as parameters.
class Gate1TokenProgram final : public Program<Gate1TokenProgram> {
 public:
  Gate1TokenProgram(int k, bool is_fp16)
      : Program<Gate1TokenProgram>{"QmoeGate1Token"},
        k_{k},
        is_fp16_{is_fp16} {}

  // Declares the shader bindings (order matters: it must match the order in
  // which the caller adds inputs/outputs) and instantiates the template.
  Status GenerateShaderCode(ShaderHelper& shader) const override {
    shader.AddInput("router_logits", ShaderUsage::UseElementTypeAlias);
    shader.AddOutput("topk_values");
    shader.AddOutput("indirect_experts");

    return WGSL_TEMPLATE_APPLY(shader, "moe/gate_1token.wgsl.template",
                               WGSL_TEMPLATE_PARAMETER(is_fp16, is_fp16_),
                               WGSL_TEMPLATE_PARAMETER(k, k_));
  }

  // Uniforms read by the shader as uniforms.rows / uniforms.cols.
  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
      {"rows", ProgramUniformVariableDataType::Uint32},
      {"cols", ProgramUniformVariableDataType::Uint32});

 private:
  int k_;         // forwarded as template parameter `k`
  bool is_fp16_;  // forwarded as template parameter `is_fp16`
};

class HiddenStateGatherProgram final : public Program<HiddenStateGatherProgram> {
public:
HiddenStateGatherProgram() : Program<HiddenStateGatherProgram>{"QmoeHiddenStateGather"} {};
Expand Down Expand Up @@ -115,7 +138,6 @@ class QMoEFinalMixProgram final : public Program<QMoEFinalMixProgram> {
}

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
{"used_by", ProgramUniformVariableDataType::Uint32},
{"hidden_size", ProgramUniformVariableDataType::Uint32},
{"num_experts", ProgramUniformVariableDataType::Uint32},
{"expert_idx", ProgramUniformVariableDataType::Uint32},
Expand All @@ -124,6 +146,26 @@ class QMoEFinalMixProgram final : public Program<QMoEFinalMixProgram> {
private:
};

// WebGPU program for the single-token final mix
// (moe/final_mix_1token.wgsl.template): scales one expert's `fc2_outputs` by
// its entry in `router_values` (looked up through `indirect_experts`) and
// accumulates the result into `output`.
class QMoEFinalMix1TokenProgram final : public Program<QMoEFinalMix1TokenProgram> {
 public:
  QMoEFinalMix1TokenProgram()
      : Program<QMoEFinalMix1TokenProgram>{"QMoEFinalMix1TokenProgram"} {}

  // Declares the shader bindings (order must match the caller's
  // AddInputs/AddOutput order) and instantiates the template.
  Status GenerateShaderCode(ShaderHelper& shader) const override {
    shader.AddInput("fc2_outputs", ShaderUsage::UseElementTypeAlias);
    shader.AddInput("router_values", ShaderUsage::UseElementTypeAlias);
    shader.AddInput("indirect_experts", ShaderUsage::UseElementTypeAlias);
    shader.AddOutput("output", ShaderUsage::UseElementTypeAlias);

    return WGSL_TEMPLATE_APPLY(shader, "moe/final_mix_1token.wgsl.template");
  }

  // Uniforms read by the shader as uniforms.hidden_size / uniforms.expert_idx.
  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
      {"hidden_size", ProgramUniformVariableDataType::Uint32},
      {"expert_idx", ProgramUniformVariableDataType::Uint32});
};

Status QMoE::ComputeInternal(ComputeContext& context) const {
const Tensor* hidden_state = context.Input<Tensor>(0);
const Tensor* router_logits = context.Input<Tensor>(1);
Expand Down Expand Up @@ -168,7 +210,7 @@ Status QMoE::ComputeInternal(ComputeContext& context) const {
}

// process tokens in chunks of max_tokens to put some cap on memory usage
const int max_tokens = 512;
const int max_tokens = 2 * 1024;

const uint32_t num_experts = static_cast<uint32_t>(moe_params.num_experts);
const uint32_t hidden_size = static_cast<uint32_t>(moe_params.hidden_size);
Expand Down Expand Up @@ -197,6 +239,78 @@ Status QMoE::ComputeInternal(ComputeContext& context) const {
.AddUniformVariables({static_cast<uint32_t>(total_output_size)});
ORT_RETURN_IF_ERROR(context.RunProgram(zero));

if (moe_params.num_rows == 1) {
// Optimized code path for 1 token to avoid gpu -> cpu copy

const int num_tokens = 1;
TensorShape gate_value_shape({num_tokens, num_experts});
TensorShape indirect_experts_shape({k_});

Tensor router_values = context.CreateGPUTensor(dtype, gate_value_shape);
Tensor indirect_experts = context.CreateGPUTensor(dtype_uint32, indirect_experts_shape);

Gate1TokenProgram gate{k_, is_fp16};
gate
.AddInputs({{router_logits, ProgramTensorMetadataDependency::Type}})
.AddOutput({&router_values, ProgramTensorMetadataDependency::None})
.AddOutput({&indirect_experts, ProgramTensorMetadataDependency::None})
.SetWorkgroupSize(num_experts)
.SetDispatchGroupSize(static_cast<uint32_t>(num_tokens))
.AddUniformVariables({static_cast<uint32_t>(num_tokens), num_experts})
.CacheHint(k_, is_fp16 ? "fp16" : "fp32");

ORT_RETURN_IF_ERROR(context.RunProgram(gate));

for (uint32_t expert_idx = 0; expert_idx < static_cast<uint32_t>(k_); expert_idx++) {
TensorShape fc1_output_shape({num_tokens, fc1_output_size});
Tensor fc1_outputs = context.CreateGPUTensor(dtype, fc1_output_shape);
TensorShape fc1_activated_shape({num_tokens, moe_params.inter_size});
Tensor fc1_activated = context.CreateGPUTensor(dtype, fc1_activated_shape);
TensorShape fc2_output_shape({num_tokens, N_fc2});
Tensor fc2_outputs = context.CreateGPUTensor(dtype, fc2_output_shape);

status = ApplyMatMulNBits(hidden_state, fc1_experts_weights, fc1_scales, nullptr, fc1_experts_bias_optional,
K_fc1, N_fc1, block_size_fc1, accuracy_level, expert_weight_bits_, context,
&fc1_outputs, expert_idx, &indirect_experts);
ORT_RETURN_IF_ERROR(status);

if (is_swiglu) {
SwigLuProgram swiglu;
swiglu
.AddInputs({{&fc1_outputs, ProgramTensorMetadataDependency::Type, 2}})
.AddOutput({&fc1_activated, ProgramTensorMetadataDependency::None})
.SetWorkgroupSize(128)
.SetDispatchGroupSize(((num_tokens * static_cast<uint32_t>(moe_params.inter_size)) + 127) / 128)
.AddUniformVariables({static_cast<uint32_t>(num_tokens),
static_cast<uint32_t>(moe_params.inter_size),
activation_alpha_,
activation_beta_,
swiglu_limit_});
ORT_RETURN_IF_ERROR(context.RunProgram(swiglu));
} else {
ORT_THROW("only swiglu is supported for WebGPU.");
}

status = ApplyMatMulNBits(&fc1_activated, fc2_experts_weights, fc2_scales, nullptr, fc2_experts_bias_optional,
K_fc2, N_fc2, block_size_fc2, accuracy_level, expert_weight_bits_, context,
&fc2_outputs, expert_idx, &indirect_experts);
ORT_RETURN_IF_ERROR(status);

QMoEFinalMix1TokenProgram final_mix;
final_mix
.AddInputs({{&fc2_outputs, ProgramTensorMetadataDependency::Type}})
.AddInputs({{&router_values, ProgramTensorMetadataDependency::Type}})
.AddInputs({{&indirect_experts, ProgramTensorMetadataDependency::Type}})
.AddOutput({output_tensor, ProgramTensorMetadataDependency::None})
.SetDispatchGroupSize(1)
.AddUniformVariables({hidden_size, expert_idx});
Comment thread
guschmue marked this conversation as resolved.

ORT_RETURN_IF_ERROR(context.RunProgram(final_mix));
}
return Status::OK();
Comment thread
guschmue marked this conversation as resolved.
}

// path for num_tokens > 1
// process tokens in chunks of max_tokens to put some cap on memory usage
for (int token_offset = 0; token_offset < moe_params.num_rows; token_offset += max_tokens) {
//
Expand Down Expand Up @@ -226,9 +340,7 @@ Status QMoE::ComputeInternal(ComputeContext& context) const {
.AddOutput({&gate_counts, ProgramTensorMetadataDependency::None, ProgramOutput::Atomic})
.SetWorkgroupSize(num_experts)
.SetDispatchGroupSize(static_cast<uint32_t>(num_tokens))
.AddUniformVariables({static_cast<uint32_t>(num_tokens),
num_experts,
static_cast<uint32_t>(token_offset)})
.AddUniformVariables({static_cast<uint32_t>(num_tokens), num_experts, static_cast<uint32_t>(token_offset)})
.CacheHint(k_, is_fp16 ? "fp16" : "fp32");

ORT_RETURN_IF_ERROR(context.RunProgram(gate));
Expand Down Expand Up @@ -318,8 +430,7 @@ Status QMoE::ComputeInternal(ComputeContext& context) const {
.AddInputs({{&expert_tokens, ProgramTensorMetadataDependency::Type}})
.AddOutput({output_tensor, ProgramTensorMetadataDependency::None})
.SetDispatchGroupSize(used_by)
.AddUniformVariables({used_by,
hidden_size,
.AddUniformVariables({hidden_size,
num_experts,
expert_idx,
static_cast<uint32_t>(token_offset)});
Expand Down
Loading
Loading