Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/ContribOperators.md
Original file line number Diff line number Diff line change
Expand Up @@ -4613,7 +4613,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Whether to use sparse mixer</dd>
</dl>

#### Inputs (7 - 14)
#### Inputs (7 - 15)

<dl>
<dt><tt>input</tt> : T</dt>
Expand Down Expand Up @@ -4644,6 +4644,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>2D tensor with shape (num_experts, hidden_size / pack_size), or 3D tensor with shape (num_experts, hidden_size, inter_size / block_size / pack_size) when block_size is provided.</dd>
<dt><tt>fc3_zero_points</tt> (optional) : T1</dt>
<dd>2D optional tensor with shape (num_experts, inter_size / pack_size), or 3D optional tensor with shape (num_experts, inter_size, hidden_size / block_size / pack_size) when block_size is provided.</dd>
<dt><tt>router_weights</tt> (optional) : T</dt>
<dd>2D optional tensor with shape (num_tokens, num_experts). When provided, router_probs is used only for Top-K expert selection, and router_weights is used for aggregating expert outputs (the values at the selected expert indices are gathered and used as mixing weights). This enables DeepSeek-style noaux_tc routing where different tensors are used for selection and aggregation. When not provided, router_probs is used for both selection and aggregation (backward compatible).</dd>
</dl>

#### Outputs
Expand Down
4 changes: 2 additions & 2 deletions docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ Do not modify directly.*
|QLinearSigmoid|*in* X:**T**<br> *in* X_scale:**tensor(float)**<br> *in* X_zero_point:**T**<br> *in* Y_scale:**tensor(float)**<br> *in* Y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearSoftmax|*in* X:**T**<br> *in* X_scale:**tensor(float)**<br> *in* x_zero_point:**T**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearWhere|*in* condition:**B**<br> *in* X:**T**<br> *in* x_scale:**TF**<br> *in* x_zero_point:**T**<br> *in* Y:**T**<br> *in* y_scale:**TF**<br> *in* y_zero_point:**T**<br> *in* z_scale:**TF**<br> *in* z_zero_point:**T**<br> *out* Z:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QMoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T1**<br> *in* fc1_scales:**T2**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T1**<br> *in* fc2_scales:**T2**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T1**<br> *in* fc3_scales:**T2**<br> *in* fc3_experts_bias:**T**<br> *in* fc1_zero_points:**T1**<br> *in* fc2_zero_points:**T1**<br> *in* fc3_zero_points:**T1**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)<br/> **T1** = tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
|QMoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T1**<br> *in* fc1_scales:**T2**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T1**<br> *in* fc2_scales:**T2**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T1**<br> *in* fc3_scales:**T2**<br> *in* fc3_experts_bias:**T**<br> *in* fc1_zero_points:**T1**<br> *in* fc2_zero_points:**T1**<br> *in* fc3_zero_points:**T1**<br> *in* router_weights:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)<br/> **T1** = tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)|
|QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|Range|*in* start:**T**<br> *in* limit:**T**<br> *in* delta:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
Expand Down Expand Up @@ -1023,7 +1023,7 @@ Do not modify directly.*
|PackedMultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* token_offset:**M**<br> *in* cumulative_sequence_length:**M**<br> *in* attention_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|PagedAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* key_cache:**T**<br> *in* value_cache:**T**<br> *in* cumulative_sequence_length:**S**<br> *in* past_seqlens:**S**<br> *in* block_table:**S**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* key_cache_out:**T**<br> *out* value_cache_out:**T**|1+|**S** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)|
|QAttention|*in* input:**T1**<br> *in* weight:**T2**<br> *in* bias:**T3**<br> *in* input_scale:**T3**<br> *in* weight_scale:**T3**<br> *in* mask_index:**T4**<br> *in* input_zero_point:**T1**<br> *in* weight_zero_point:**T2**<br> *in* past:**T3**<br> *out* output:**T3**<br> *out* present:**T3**|1+|**T1** = tensor(int8)<br/> **T2** = tensor(int8)<br/> **T3** = tensor(float), tensor(float16)<br/> **T4** = tensor(int32)|
|QMoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T1**<br> *in* fc1_scales:**T2**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T1**<br> *in* fc2_scales:**T2**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T1**<br> *in* fc3_scales:**T2**<br> *in* fc3_experts_bias:**T**<br> *in* fc1_zero_points:**T1**<br> *in* fc2_zero_points:**T1**<br> *in* fc3_zero_points:**T1**<br> *out* output:**T**|1+|**T** = tensor(bfloat16), tensor(float16)<br/> **T1** = tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(float16)|
|QMoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T1**<br> *in* fc1_scales:**T2**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T1**<br> *in* fc2_scales:**T2**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T1**<br> *in* fc3_scales:**T2**<br> *in* fc3_experts_bias:**T**<br> *in* fc1_zero_points:**T1**<br> *in* fc2_zero_points:**T1**<br> *in* fc3_zero_points:**T1**<br> *in* router_weights:**T**<br> *out* output:**T**|1+|**T** = tensor(bfloat16), tensor(float16)<br/> **T1** = tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(float16)|
|QOrderedAttention|*in* input:**Q**<br> *in* scale_input:**S**<br> *in* scale_Q_gemm:**S**<br> *in* scale_K_gemm:**S**<br> *in* scale_V_gemm:**S**<br> *in* Q_weight:**Q**<br> *in* K_weight:**Q**<br> *in* V_weight:**Q**<br> *in* scale_Q_weight:**S**<br> *in* scale_K_weight:**S**<br> *in* scale_V_weight:**S**<br> *in* Q_bias:**S**<br> *in* K_bias:**S**<br> *in* V_bias:**S**<br> *in* scale_QKT_gemm:**S**<br> *in* scale_QKT_softmax:**S**<br> *in* scale_values_gemm:**S**<br> *in* mask_index:**G**<br> *in* past:**Q**<br> *in* attention_bias:**S**<br> *out* output:**Q**|1+|**G** = tensor(int32)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
|QOrderedGelu|*in* X:**Q**<br> *in* scale_X:**S**<br> *in* scale_Y:**S**<br> *out* Y:**Q**|1+|**Q** = tensor(int8)<br/> **S** = tensor(float)|
|QOrderedLayerNormalization|*in* X:**Q**<br> *in* scale_X:**S**<br> *in* scale:**F**<br> *in* B:**F**<br> *in* scale_Y:**S**<br> *out* Y:**Q**|1+|**F** = tensor(float), tensor(float16)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
Expand Down
87 changes: 73 additions & 14 deletions onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,7 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
const auto* fc1_zero_points = context->Input<Tensor>(11);
const auto* fc2_zero_points = context->Input<Tensor>(12);
const auto* fc3_zero_points = context->Input<Tensor>(13);
const auto* router_weights = context->Input<Tensor>(14);

const TensorShape* fc1_shape_ptr = packed_fc1_ ? &fc1_shape_ : (fc1_experts_weights ? &fc1_experts_weights->Shape() : nullptr);
const TensorShape* fc2_shape_ptr = packed_fc2_ ? &fc2_shape_ : (fc2_experts_weights ? &fc2_experts_weights->Shape() : nullptr);
Expand Down Expand Up @@ -708,6 +709,28 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
router_logits_float = reinterpret_cast<const float*>(router_probs->Data<T>());
}

// Handle optional router_weights input for separate selection/aggregation tensors
const bool has_router_weights = (router_weights != nullptr);
IAllocatorUniquePtr<float> router_weights_float_buffer;
const float* router_weights_float = nullptr;
if (has_router_weights) {
const auto& rw_shape = router_weights->Shape();
if (rw_shape.NumDimensions() != 2 || rw_shape[0] != num_tokens || rw_shape[1] != num_experts) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'router_weights' is expected to have shape (",
num_tokens, ", ", num_experts, "), got ", rw_shape);
}
if constexpr (std::is_same_v<T, MLFloat16>) {
router_weights_float_buffer = IAllocator::MakeUniquePtr<float>(allocator, static_cast<size_t>(num_tokens * num_experts));
router_weights_float = router_weights_float_buffer.get();
MlasConvertHalfToFloatBuffer(reinterpret_cast<const MLFloat16*>(router_weights->Data<T>()),
const_cast<float*>(router_weights_float),
static_cast<size_t>(num_tokens * num_experts));
} else {
router_weights_float = reinterpret_cast<const float*>(router_weights->Data<T>());
}
}

auto route_expert_ptr = IAllocator::MakeUniquePtr<int>(allocator, static_cast<size_t>(num_tokens * k_));
int* route_expert = route_expert_ptr.get();
auto route_scale_ptr = IAllocator::MakeUniquePtr<float>(allocator, static_cast<size_t>(num_tokens * k_));
Expand Down Expand Up @@ -743,22 +766,58 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
std::partial_sort(sorted_logits.begin(), sorted_logits.begin() + static_cast<std::ptrdiff_t>(k_),
sorted_logits.end(), std::greater<>());

float max_logit = sorted_logits[0].first;
if (has_router_weights) {
// When router_weights is provided, use it for aggregation weights instead of softmax of router_probs.
// Gather weights from router_weights at the selected expert indices.
// Note: top_k_exp is reused here as a scratch buffer for the gathered weights.
const float* weights_row = router_weights_float + i * num_experts;
if (normalize_routing_weights_) {
float weight_sum = 0.0f;
for (size_t j = 0; j < narrow<size_t>(k_); ++j) {
int64_t expert_idx = sorted_logits[j].second;
top_k_exp[j] = weights_row[expert_idx];
weight_sum += top_k_exp[j];
}
const float inv_weight_sum = (weight_sum == 0.0f) ? 0.0f : (1.0f / weight_sum);
for (size_t j = 0; j < narrow<size_t>(k_); ++j) {
int64_t expert_idx = sorted_logits[j].second;
int64_t route_idx = i * k_ + narrow<int64_t>(j);
route_expert[route_idx] = narrow<int>(expert_idx);
route_scale[route_idx] = top_k_exp[j] * inv_weight_sum;
if (route_scale[route_idx] > 1e-8f) {
local_expert_token_map[static_cast<size_t>(expert_idx)].push_back(route_idx);
}
}
} else {
for (size_t j = 0; j < narrow<size_t>(k_); ++j) {
int64_t expert_idx = sorted_logits[j].second;
int64_t route_idx = i * k_ + narrow<int64_t>(j);
route_expert[route_idx] = narrow<int>(expert_idx);
route_scale[route_idx] = weights_row[expert_idx];
if (route_scale[route_idx] > 1e-8f) {
local_expert_token_map[static_cast<size_t>(expert_idx)].push_back(route_idx);
}
}
}
} else {
// Default path: compute softmax weights from router_probs for aggregation.
float max_logit = sorted_logits[0].first;

float sum_exp = 0.0f;
for (size_t j = 0; j < narrow<size_t>(k_); ++j) {
top_k_exp[j] = std::exp(sorted_logits[j].first - max_logit);
sum_exp += top_k_exp[j];
}
float sum_exp = 0.0f;
for (size_t j = 0; j < narrow<size_t>(k_); ++j) {
top_k_exp[j] = std::exp(sorted_logits[j].first - max_logit);
sum_exp += top_k_exp[j];
}

const float inv_sum = (sum_exp == 0.0f) ? 0.0f : (1.0f / sum_exp);
for (size_t j = 0; j < narrow<size_t>(k_); ++j) {
int64_t expert_idx = sorted_logits[j].second;
int64_t route_idx = i * k_ + narrow<int64_t>(j);
route_expert[route_idx] = narrow<int>(expert_idx);
route_scale[route_idx] = top_k_exp[j] * inv_sum;
if (route_scale[route_idx] > 1e-8f) { // Use small threshold to avoid zero weights
local_expert_token_map[static_cast<size_t>(expert_idx)].push_back(route_idx);
const float inv_sum = (sum_exp == 0.0f) ? 0.0f : (1.0f / sum_exp);
for (size_t j = 0; j < narrow<size_t>(k_); ++j) {
int64_t expert_idx = sorted_logits[j].second;
int64_t route_idx = i * k_ + narrow<int64_t>(j);
route_expert[route_idx] = narrow<int>(expert_idx);
route_scale[route_idx] = top_k_exp[j] * inv_sum;
if (route_scale[route_idx] > 1e-8f) { // Use small threshold to avoid zero weights
local_expert_token_map[static_cast<size_t>(expert_idx)].push_back(route_idx);
}
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,11 @@ Status QMoE<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* fc1_zero_points = context->Input<Tensor>(11);
const Tensor* fc2_zero_points = context->Input<Tensor>(12);
const Tensor* fc3_zero_points = context->Input<Tensor>(13);
const Tensor* router_weights = context->Input<Tensor>(14);
ORT_ENFORCE(fc1_zero_points == nullptr && fc2_zero_points == nullptr && fc3_zero_points == nullptr,
"Zero points are not yet implemented on CUDA for QMoE.");
ORT_ENFORCE(router_weights == nullptr,
"Separate router_weights is not yet implemented on CUDA for QMoE.");

MoEParameters moe_params;
ORT_RETURN_IF_ERROR(::onnxruntime::contrib::moe_helper::CheckInputs<Tensor>(
Expand Down
6 changes: 6 additions & 0 deletions onnxruntime/contrib_ops/webgpu/moe/qmoe.cc
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ Status QMoE::ComputeInternal(ComputeContext& context) const {
const Tensor* fc1_zero_points = context.Input<Tensor>(11);
const Tensor* fc2_zero_points = context.Input<Tensor>(12);
const Tensor* fc3_zero_points = context.Input<Tensor>(13);
const Tensor* router_weights = context.Input<Tensor>(14);

MoEParameters moe_params;

Expand All @@ -192,6 +193,11 @@ Status QMoE::ComputeInternal(ComputeContext& context) const {
"zero_points for QMoE are not yet supported on WebGPU.");
}

if (router_weights) {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"Separate router_weights is not yet implemented on WebGPU for QMoE.");
}

ORT_RETURN_IF_ERROR(::onnxruntime::contrib::moe_helper::CheckInputs<Tensor>(
moe_params, hidden_state, router_logits,
fc1_experts_weights, fc1_experts_bias_optional, fc1_scales, fc1_zero_points,
Expand Down
10 changes: 10 additions & 0 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1569,6 +1569,16 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"3D optional tensor with shape (num_experts, inter_size, hidden_size / block_size / pack_size) when block_size is provided.",
"T1",
OpSchema::Optional)
.Input(14,
"router_weights",
"2D optional tensor with shape (num_tokens, num_experts). "
"When provided, router_probs is used only for Top-K expert selection, and router_weights is used "
"for aggregating expert outputs (the values at the selected expert indices are gathered and used as "
"mixing weights). This enables DeepSeek-style noaux_tc routing where different tensors are used for "
"selection and aggregation. When not provided, router_probs is used for both selection and aggregation "
"(backward compatible).",
"T",
OpSchema::Optional)
.Output(0,
"output",
"output tensor with same shape of input",
Expand Down
Loading
Loading