Skip to content
54 changes: 40 additions & 14 deletions onnxruntime/contrib_ops/cpu/moe/moe_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,37 +56,63 @@ Status CheckInputs(MoEParameters& parameters,
const int64_t block_size = 0) { // block size for block-wise quantization
// Check dimensions of input to avoid input_dims index out of range. CHECK_TENSOR_SHAPE will verify each tensor later.
ASSERT_TENSOR_2D_OR_3D(input);
ASSERT_TENSOR_3D(fc1_experts_weights);
ASSERT_TENSOR_3D(fc2_experts_weights);
if (fc1_experts_weights) ASSERT_TENSOR_3D(fc1_experts_weights);
if (fc2_experts_weights) ASSERT_TENSOR_3D(fc2_experts_weights);
ASSERT_TENSOR_2D(router_probs);

const auto& input_dims = input->Shape().GetDims();
const auto& router_probs_dims = router_probs->Shape().GetDims();
const auto& fc1_experts_weights_dims = fc1_experts_weights->Shape().GetDims();
const auto& fc2_experts_weights_dims = fc2_experts_weights->Shape().GetDims();

int64_t num_rows = input_dims.size() == 2 ? input_dims[0] : input_dims[0] * input_dims[1];
int64_t hidden_size = input_dims[input_dims.size() - 1];
int64_t local_num_experts = fc1_experts_weights_dims[0];
int64_t num_experts = router_probs_dims[1];
int64_t inter_size = (fc2_experts_weights_dims[1] * fc2_experts_weights_dims[2] * pack_size) / hidden_size;

const bool legacy_shape = (hidden_size != inter_size && fc2_experts_weights_dims[1] == inter_size) ||
(hidden_size == inter_size && is_fused_swiglu && fc1_experts_weights_dims[1] == hidden_size);
int64_t local_num_experts;
if (fc1_experts_weights != nullptr) {
local_num_experts = fc1_experts_weights->Shape().GetDims()[0];
} else if (fc1_experts_scales != nullptr) {
local_num_experts = fc1_experts_scales->Shape().GetDims()[0];
} else {
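// Fallback when neither fc1_experts_weights nor fc1_experts_scales is provided (e.g. non-quantized MoE
// without weights): use num_experts taken from router_probs. This should not happen in current code paths.
// TODO: confirm whether this branch is reachable when only a bias is provided.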
local_num_experts = num_experts;
Comment thread
tianleiwu marked this conversation as resolved.
Outdated
}

int64_t inter_size;
if (fc2_experts_weights != nullptr) {
const auto& dims = fc2_experts_weights->Shape().GetDims();
inter_size = (dims[1] * dims[2] * pack_size) / hidden_size;
} else if (fc3_experts_scales != nullptr) {
inter_size = fc3_experts_scales->Shape().GetDims()[1];
} else if (fc1_experts_scales != nullptr) {
int64_t fc1_inter_size = fc1_experts_scales->Shape().GetDims()[1];
inter_size = is_fused_swiglu ? fc1_inter_size / 2 : fc1_inter_size;
} else {
// Should not happen for valid QMoE calls
inter_size = 0;
Comment thread
tianleiwu marked this conversation as resolved.
Outdated
}

bool legacy_shape = false;
if (fc2_experts_weights != nullptr && fc1_experts_weights != nullptr) {
const auto& fc2_experts_weights_dims = fc2_experts_weights->Shape().GetDims();
const auto& fc1_experts_weights_dims = fc1_experts_weights->Shape().GetDims();
legacy_shape = (hidden_size != inter_size && fc2_experts_weights_dims[1] == inter_size) ||
(hidden_size == inter_size && is_fused_swiglu && fc1_experts_weights_dims[1] == hidden_size);
}

// Fused swiglu doubles the output dimension of FC1 because it fuses two GEMMs into one.
const int64_t fc1_inter_size = is_fused_swiglu ? (inter_size + inter_size) : inter_size;
const int64_t zp_pack_size = pack_size; // Zero points packing (1 for 8-bit, 2 for 4-bit)

if (legacy_shape) {
// The legacy shape does not match the column-major memory layout; it is accepted for backward compatibility.
CHECK_TENSOR_SHAPE(fc1_experts_weights, num_experts, hidden_size, fc1_inter_size / pack_size);
CHECK_TENSOR_SHAPE(fc2_experts_weights, num_experts, inter_size, hidden_size / pack_size);
CHECK_TENSOR_SHAPE(fc3_experts_weights, num_experts, hidden_size, inter_size / pack_size);
if (fc1_experts_weights) CHECK_TENSOR_SHAPE(fc1_experts_weights, num_experts, hidden_size, fc1_inter_size / pack_size);
if (fc2_experts_weights) CHECK_TENSOR_SHAPE(fc2_experts_weights, num_experts, inter_size, hidden_size / pack_size);
if (fc3_experts_weights) CHECK_TENSOR_SHAPE(fc3_experts_weights, num_experts, hidden_size, inter_size / pack_size);
} else {
CHECK_TENSOR_SHAPE(fc1_experts_weights, num_experts, fc1_inter_size, hidden_size / pack_size);
CHECK_TENSOR_SHAPE(fc2_experts_weights, num_experts, hidden_size, inter_size / pack_size);
CHECK_TENSOR_SHAPE(fc3_experts_weights, num_experts, inter_size, hidden_size / pack_size);
if (fc1_experts_weights) CHECK_TENSOR_SHAPE(fc1_experts_weights, num_experts, fc1_inter_size, hidden_size / pack_size);
if (fc2_experts_weights) CHECK_TENSOR_SHAPE(fc2_experts_weights, num_experts, hidden_size, inter_size / pack_size);
if (fc3_experts_weights) CHECK_TENSOR_SHAPE(fc3_experts_weights, num_experts, inter_size, hidden_size / pack_size);
}

CHECK_TENSOR_SHAPE(router_probs, num_rows, num_experts);
Expand Down
Loading
Loading