4 changes: 3 additions & 1 deletion docs/ContribOperators.md
@@ -4613,7 +4613,7 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>Whether to use sparse mixer</dd>
</dl>

-#### Inputs (7 - 14)
+#### Inputs (7 - 15)

<dl>
<dt><tt>input</tt> : T</dt>
@@ -4644,6 +4644,8 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>2D tensor with shape (num_experts, hidden_size / pack_size), or 3D tensor with shape (num_experts, hidden_size, inter_size / block_size / pack_size) when block_size is provided.</dd>
<dt><tt>fc3_zero_points</tt> (optional) : T1</dt>
<dd>2D optional tensor with shape (num_experts, inter_size / pack_size), or 3D optional tensor with shape (num_experts, inter_size, hidden_size / block_size / pack_size) when block_size is provided.</dd>
+<dt><tt>router_weights</tt> (optional) : T</dt>
+<dd>2D optional tensor with shape (num_tokens, num_experts). When provided, router_probs is used only for Top-K expert selection, and router_weights is used for aggregating expert outputs (the values at the selected expert indices are gathered and used as mixing weights). This enables DeepSeek-style noaux_tc routing where different tensors are used for selection and aggregation. When not provided, router_probs is used for both selection and aggregation (backward compatible).</dd>
</dl>

#### Outputs
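To make the new selection/aggregation split concrete, below is a minimal standalone sketch of the documented semantics (plain C++17; `RouteToken` and the values in `main` are illustrative assumptions, not code from this PR): `router_probs` ranks the experts, while `router_weights` supplies the mixing weight gathered at each selected index.

```cpp
// Sketch: top-k selection by router_probs, aggregation weights gathered
// from router_weights -- the contract described for input 14 above.
#include <algorithm>
#include <cstdio>
#include <functional>
#include <utility>
#include <vector>

// For one token: pick the k experts with the highest router_probs, but take
// each selected expert's mixing weight from router_weights instead.
std::vector<std::pair<int, float>> RouteToken(const std::vector<float>& probs_row,
                                              const std::vector<float>& weights_row,
                                              size_t k) {
  std::vector<std::pair<float, int>> ranked;  // (selection score, expert id)
  for (int e = 0; e < static_cast<int>(probs_row.size()); ++e) {
    ranked.emplace_back(probs_row[e], e);
  }
  std::partial_sort(ranked.begin(), ranked.begin() + k, ranked.end(), std::greater<>());

  std::vector<std::pair<int, float>> routes;  // (expert id, aggregation weight)
  for (size_t j = 0; j < k; ++j) {
    const int expert = ranked[j].second;
    routes.emplace_back(expert, weights_row[expert]);  // gather, no softmax
  }
  return routes;
}

int main() {
  // Selection scores and aggregation weights deliberately disagree: expert 2
  // wins selection with 0.9, yet its mixing weight is the gathered 0.25.
  const std::vector<float> probs_row{0.1f, 0.3f, 0.9f, 0.2f};
  const std::vector<float> weights_row{0.5f, 0.6f, 0.25f, 0.7f};
  for (const auto& [expert, w] : RouteToken(probs_row, weights_row, 2)) {
    std::printf("expert %d weight %.2f\n", expert, w);
  }
  return 0;  // prints: expert 2 weight 0.25 / expert 1 weight 0.60
}
```

In the default path (no `router_weights`), the top-k `router_probs` values are instead softmaxed to produce the mixing weights, which is the pre-existing behavior.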
4 changes: 2 additions & 2 deletions docs/OperatorKernels.md
@@ -602,7 +602,7 @@ Do not modify directly.*
|QLinearSigmoid|*in* X:**T**<br> *in* X_scale:**tensor(float)**<br> *in* X_zero_point:**T**<br> *in* Y_scale:**tensor(float)**<br> *in* Y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearSoftmax|*in* X:**T**<br> *in* X_scale:**tensor(float)**<br> *in* x_zero_point:**T**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearWhere|*in* condition:**B**<br> *in* X:**T**<br> *in* x_scale:**TF**<br> *in* x_zero_point:**T**<br> *in* Y:**T**<br> *in* y_scale:**TF**<br> *in* y_zero_point:**T**<br> *in* z_scale:**TF**<br> *in* z_zero_point:**T**<br> *out* Z:**T**|1+|**T** = tensor(int8), tensor(uint8)|
-|QMoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T1**<br> *in* fc1_scales:**T2**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T1**<br> *in* fc2_scales:**T2**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T1**<br> *in* fc3_scales:**T2**<br> *in* fc3_experts_bias:**T**<br> *in* fc1_zero_points:**T1**<br> *in* fc2_zero_points:**T1**<br> *in* fc3_zero_points:**T1**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)<br/> **T1** = tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
+|QMoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T1**<br> *in* fc1_scales:**T2**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T1**<br> *in* fc2_scales:**T2**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T1**<br> *in* fc3_scales:**T2**<br> *in* fc3_experts_bias:**T**<br> *in* fc1_zero_points:**T1**<br> *in* fc2_zero_points:**T1**<br> *in* fc3_zero_points:**T1**<br> *in* router_weights:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)<br/> **T1** = tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)|
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)|
|QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|Range|*in* start:**T**<br> *in* limit:**T**<br> *in* delta:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
@@ -1023,7 +1023,7 @@ Do not modify directly.*
|PackedMultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* token_offset:**M**<br> *in* cumulative_sequence_length:**M**<br> *in* attention_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|PagedAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* key_cache:**T**<br> *in* value_cache:**T**<br> *in* cumulative_sequence_length:**S**<br> *in* past_seqlens:**S**<br> *in* block_table:**S**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* key_cache_out:**T**<br> *out* value_cache_out:**T**|1+|**S** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)|
|QAttention|*in* input:**T1**<br> *in* weight:**T2**<br> *in* bias:**T3**<br> *in* input_scale:**T3**<br> *in* weight_scale:**T3**<br> *in* mask_index:**T4**<br> *in* input_zero_point:**T1**<br> *in* weight_zero_point:**T2**<br> *in* past:**T3**<br> *out* output:**T3**<br> *out* present:**T3**|1+|**T1** = tensor(int8)<br/> **T2** = tensor(int8)<br/> **T3** = tensor(float), tensor(float16)<br/> **T4** = tensor(int32)|
-|QMoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T1**<br> *in* fc1_scales:**T2**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T1**<br> *in* fc2_scales:**T2**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T1**<br> *in* fc3_scales:**T2**<br> *in* fc3_experts_bias:**T**<br> *in* fc1_zero_points:**T1**<br> *in* fc2_zero_points:**T1**<br> *in* fc3_zero_points:**T1**<br> *out* output:**T**|1+|**T** = tensor(bfloat16), tensor(float16)<br/> **T1** = tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(float16)|
+|QMoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T1**<br> *in* fc1_scales:**T2**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T1**<br> *in* fc2_scales:**T2**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T1**<br> *in* fc3_scales:**T2**<br> *in* fc3_experts_bias:**T**<br> *in* fc1_zero_points:**T1**<br> *in* fc2_zero_points:**T1**<br> *in* fc3_zero_points:**T1**<br> *in* router_weights:**T**<br> *out* output:**T**|1+|**T** = tensor(bfloat16), tensor(float16)<br/> **T1** = tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(float16)|
|QOrderedAttention|*in* input:**Q**<br> *in* scale_input:**S**<br> *in* scale_Q_gemm:**S**<br> *in* scale_K_gemm:**S**<br> *in* scale_V_gemm:**S**<br> *in* Q_weight:**Q**<br> *in* K_weight:**Q**<br> *in* V_weight:**Q**<br> *in* scale_Q_weight:**S**<br> *in* scale_K_weight:**S**<br> *in* scale_V_weight:**S**<br> *in* Q_bias:**S**<br> *in* K_bias:**S**<br> *in* V_bias:**S**<br> *in* scale_QKT_gemm:**S**<br> *in* scale_QKT_softmax:**S**<br> *in* scale_values_gemm:**S**<br> *in* mask_index:**G**<br> *in* past:**Q**<br> *in* attention_bias:**S**<br> *out* output:**Q**|1+|**G** = tensor(int32)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
|QOrderedGelu|*in* X:**Q**<br> *in* scale_X:**S**<br> *in* scale_Y:**S**<br> *out* Y:**Q**|1+|**Q** = tensor(int8)<br/> **S** = tensor(float)|
|QOrderedLayerNormalization|*in* X:**Q**<br> *in* scale_X:**S**<br> *in* scale:**F**<br> *in* B:**F**<br> *in* scale_Y:**S**<br> *out* Y:**Q**|1+|**F** = tensor(float), tensor(float16)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
87 changes: 73 additions & 14 deletions onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc
@@ -660,6 +660,7 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
const auto* fc1_zero_points = context->Input<Tensor>(11);
const auto* fc2_zero_points = context->Input<Tensor>(12);
const auto* fc3_zero_points = context->Input<Tensor>(13);
+ const auto* router_weights = context->Input<Tensor>(14);

const TensorShape* fc1_shape_ptr = packed_fc1_ ? &fc1_shape_ : (fc1_experts_weights ? &fc1_experts_weights->Shape() : nullptr);
const TensorShape* fc2_shape_ptr = packed_fc2_ ? &fc2_shape_ : (fc2_experts_weights ? &fc2_experts_weights->Shape() : nullptr);
@@ -708,6 +709,28 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
router_logits_float = reinterpret_cast<const float*>(router_probs->Data<T>());
}

+ // Handle optional router_weights input for separate selection/aggregation tensors
+ const bool has_router_weights = (router_weights != nullptr);
+ IAllocatorUniquePtr<float> router_weights_float_buffer;
+ const float* router_weights_float = nullptr;
+ if (has_router_weights) {
+   const auto& rw_shape = router_weights->Shape();
+   if (rw_shape.NumDimensions() != 2 || rw_shape[0] != num_tokens || rw_shape[1] != num_experts) {
+     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                            "Input 'router_weights' is expected to have shape (",
+                            num_tokens, ", ", num_experts, "), got ", rw_shape);
+   }
+   if constexpr (std::is_same_v<T, MLFloat16>) {
+     router_weights_float_buffer = IAllocator::MakeUniquePtr<float>(allocator, static_cast<size_t>(num_tokens * num_experts));
+     router_weights_float = router_weights_float_buffer.get();
+     MlasConvertHalfToFloatBuffer(reinterpret_cast<const MLFloat16*>(router_weights->Data<T>()),
+                                  const_cast<float*>(router_weights_float),
+                                  static_cast<size_t>(num_tokens * num_experts));
+   } else {
+     router_weights_float = reinterpret_cast<const float*>(router_weights->Data<T>());
+   }
+ }

auto route_expert_ptr = IAllocator::MakeUniquePtr<int>(allocator, static_cast<size_t>(num_tokens * k_));
int* route_expert = route_expert_ptr.get();
auto route_scale_ptr = IAllocator::MakeUniquePtr<float>(allocator, static_cast<size_t>(num_tokens * k_));
@@ -743,22 +766,58 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
std::partial_sort(sorted_logits.begin(), sorted_logits.begin() + static_cast<std::ptrdiff_t>(k_),
sorted_logits.end(), std::greater<>());

- float max_logit = sorted_logits[0].first;
+ if (has_router_weights) {
+   // When router_weights is provided, use it for aggregation weights instead of the softmax of router_probs.
+   // Gather weights from router_weights at the selected expert indices.
+   // Note: top_k_exp is reused here as a scratch buffer for the gathered weights.
+   const float* weights_row = router_weights_float + i * num_experts;
+   if (normalize_routing_weights_) {
+     float weight_sum = 0.0f;
+     for (size_t j = 0; j < narrow<size_t>(k_); ++j) {
+       int64_t expert_idx = sorted_logits[j].second;
+       top_k_exp[j] = weights_row[expert_idx];
+       weight_sum += top_k_exp[j];
+     }
+     const float inv_weight_sum = (weight_sum == 0.0f) ? 0.0f : (1.0f / weight_sum);
+     for (size_t j = 0; j < narrow<size_t>(k_); ++j) {
+       int64_t expert_idx = sorted_logits[j].second;
+       int64_t route_idx = i * k_ + narrow<int64_t>(j);
+       route_expert[route_idx] = narrow<int>(expert_idx);
+       route_scale[route_idx] = top_k_exp[j] * inv_weight_sum;
+       if (route_scale[route_idx] > 1e-8f) {
+         local_expert_token_map[static_cast<size_t>(expert_idx)].push_back(route_idx);
+       }
+     }
+   } else {
+     for (size_t j = 0; j < narrow<size_t>(k_); ++j) {
+       int64_t expert_idx = sorted_logits[j].second;
+       int64_t route_idx = i * k_ + narrow<int64_t>(j);
+       route_expert[route_idx] = narrow<int>(expert_idx);
+       route_scale[route_idx] = weights_row[expert_idx];
+       if (route_scale[route_idx] > 1e-8f) {
+         local_expert_token_map[static_cast<size_t>(expert_idx)].push_back(route_idx);
+       }
+     }
+   }
+ } else {
+   // Default path: compute softmax weights from router_probs for aggregation.
+   float max_logit = sorted_logits[0].first;

- float sum_exp = 0.0f;
- for (size_t j = 0; j < narrow<size_t>(k_); ++j) {
-   top_k_exp[j] = std::exp(sorted_logits[j].first - max_logit);
-   sum_exp += top_k_exp[j];
- }
+   float sum_exp = 0.0f;
+   for (size_t j = 0; j < narrow<size_t>(k_); ++j) {
+     top_k_exp[j] = std::exp(sorted_logits[j].first - max_logit);
+     sum_exp += top_k_exp[j];
+   }

- const float inv_sum = (sum_exp == 0.0f) ? 0.0f : (1.0f / sum_exp);
- for (size_t j = 0; j < narrow<size_t>(k_); ++j) {
-   int64_t expert_idx = sorted_logits[j].second;
-   int64_t route_idx = i * k_ + narrow<int64_t>(j);
-   route_expert[route_idx] = narrow<int>(expert_idx);
-   route_scale[route_idx] = top_k_exp[j] * inv_sum;
-   if (route_scale[route_idx] > 1e-8f) {  // Use small threshold to avoid zero weights
-     local_expert_token_map[static_cast<size_t>(expert_idx)].push_back(route_idx);
+   const float inv_sum = (sum_exp == 0.0f) ? 0.0f : (1.0f / sum_exp);
+   for (size_t j = 0; j < narrow<size_t>(k_); ++j) {
+     int64_t expert_idx = sorted_logits[j].second;
+     int64_t route_idx = i * k_ + narrow<int64_t>(j);
+     route_expert[route_idx] = narrow<int>(expert_idx);
+     route_scale[route_idx] = top_k_exp[j] * inv_sum;
+     if (route_scale[route_idx] > 1e-8f) {  // Use small threshold to avoid zero weights
+       local_expert_token_map[static_cast<size_t>(expert_idx)].push_back(route_idx);
+     }
}
}
}
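The two aggregation branches the CPU kernel now takes when `router_weights` is present differ only in whether the gathered weights are rescaled to sum to one. A small self-contained sketch of just that arithmetic, reusing the kernel's zero-sum guard (names and values here are illustrative, not kernel code):

```cpp
// Sketch of the normalized vs. raw aggregation scales in the new CPU path.
#include <cstdio>
#include <vector>

// `gathered` holds router_weights values at the k selected expert indices.
std::vector<float> AggregationScales(const std::vector<float>& gathered, bool normalize) {
  std::vector<float> scales = gathered;
  if (normalize) {
    float sum = 0.0f;
    for (float w : gathered) sum += w;
    // Same guard as the kernel: a zero sum yields all-zero scales
    // instead of dividing by zero.
    const float inv = (sum == 0.0f) ? 0.0f : (1.0f / sum);
    for (float& s : scales) s *= inv;
  }
  return scales;
}

int main() {
  const std::vector<float> gathered{0.6f, 0.2f};  // weights at the top-2 experts
  for (bool normalize : {true, false}) {
    std::printf("normalize_routing_weights=%d:", static_cast<int>(normalize));
    for (float s : AggregationScales(gathered, normalize)) std::printf(" %.2f", s);
    std::printf("\n");
  }
  return 0;  // prints 0.75 0.25 (normalized) and 0.60 0.20 (raw)
}
```

In both branches the kernel then drops routes whose scale is at or below the 1e-8f threshold, so experts with a (near-)zero mixing weight receive no tokens.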
3 changes: 3 additions & 0 deletions onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc
@@ -149,8 +149,11 @@ Status QMoE<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* fc1_zero_points = context->Input<Tensor>(11);
const Tensor* fc2_zero_points = context->Input<Tensor>(12);
const Tensor* fc3_zero_points = context->Input<Tensor>(13);
+ const Tensor* router_weights = context->Input<Tensor>(14);
ORT_ENFORCE(fc1_zero_points == nullptr && fc2_zero_points == nullptr && fc3_zero_points == nullptr,
"Zero points are not yet implemented on CUDA for QMoE.");
+ ORT_ENFORCE(router_weights == nullptr,
+             "Separate router_weights is not yet implemented on CUDA for QMoE.");

MoEParameters moe_params;
ORT_RETURN_IF_ERROR(::onnxruntime::contrib::moe_helper::CheckInputs<Tensor>(
6 changes: 6 additions & 0 deletions onnxruntime/contrib_ops/webgpu/moe/qmoe.cc
@@ -184,6 +184,7 @@ Status QMoE::ComputeInternal(ComputeContext& context) const {
const Tensor* fc1_zero_points = context.Input<Tensor>(11);
const Tensor* fc2_zero_points = context.Input<Tensor>(12);
const Tensor* fc3_zero_points = context.Input<Tensor>(13);
+ const Tensor* router_weights = context.Input<Tensor>(14);

MoEParameters moe_params;

@@ -192,6 +193,11 @@ Status QMoE::ComputeInternal(ComputeContext& context) const {
"zero_points for QMoE are not yet supported on WebGPU.");
}

+ if (router_weights) {
+   return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
+                          "Separate router_weights is not yet implemented on WebGPU for QMoE.");
+ }

ORT_RETURN_IF_ERROR(::onnxruntime::contrib::moe_helper::CheckInputs<Tensor>(
moe_params, hidden_state, router_logits,
fc1_experts_weights, fc1_experts_bias_optional, fc1_scales, fc1_zero_points,
10 changes: 10 additions & 0 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -1569,6 +1569,16 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"3D optional tensor with shape (num_experts, inter_size, hidden_size / block_size / pack_size) when block_size is provided.",
"T1",
OpSchema::Optional)
+ .Input(14,
+        "router_weights",
+        "2D optional tensor with shape (num_tokens, num_experts). "
+        "When provided, router_probs is used only for Top-K expert selection, and router_weights is used "
+        "for aggregating expert outputs (the values at the selected expert indices are gathered and used as "
+        "mixing weights). This enables DeepSeek-style noaux_tc routing where different tensors are used for "
+        "selection and aggregation. When not provided, router_probs is used for both selection and aggregation "
+        "(backward compatible).",
+        "T",
+        OpSchema::Optional)
.Output(0,
"output",
"output tensor with same shape of input",
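The backward compatibility promised by the schema text rests on ONNX Runtime's optional-input contract: a kernel's `Input<Tensor>(i)` returns nullptr when slot i is not bound, which is exactly what the CPU/CUDA/WebGPU changes above test for. A simplified sketch of that contract (`FakeKernelContext` is a stand-in, not the real `OpKernelContext`):

```cpp
// Simplified illustration of why a trailing optional input is backward
// compatible: unbound slots surface as nullptr, so old models that bind
// only inputs 0..13 fall through to the router_probs-only path.
#include <cstdio>
#include <vector>

struct Tensor { /* payload elided */ };

struct FakeKernelContext {
  std::vector<const Tensor*> inputs;  // absent trailing slots are simply not present
  const Tensor* Input(size_t i) const { return i < inputs.size() ? inputs[i] : nullptr; }
};

int main() {
  Tensor router_probs;
  // An older QMoE model binds at most the first 14 inputs (indices 0..13).
  FakeKernelContext old_model{{&router_probs /* ...remaining bound inputs... */}};
  const bool has_router_weights = old_model.Input(14) != nullptr;
  std::printf("router_weights bound: %s\n", has_router_weights ? "yes" : "no");  // "no"
  return 0;
}
```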