Skip to content
136 changes: 105 additions & 31 deletions onnxruntime/contrib_ops/cpu/moe/moe_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,58 +35,100 @@ struct MoEParameters {
};
namespace moe_helper {

// Shape-validation helpers shared by the MoE CheckInputs overloads.
// Each macro early-returns an INVALID_ARGUMENT status from the enclosing
// function when validation fails, and is a no-op when shape_ptr is null
// (optional inputs are passed as null). Bodies are wrapped in do { } while (0)
// so every macro invocation behaves as exactly one statement — safe inside an
// unbraced if/else at the call site (avoids the dangling-else hazard).

// Fails unless shape_ptr (when provided) has exactly `dim` dimensions.
#define ASSERT_SHAPE_DIMENSION(shape_ptr, dim, name)                             \
  do {                                                                           \
    if (shape_ptr != nullptr && shape_ptr->NumDimensions() != dim) {             \
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input '", name,     \
                             "' is expected to have ", dim, " dimensions, got ", \
                             shape_ptr->NumDimensions());                        \
    }                                                                            \
  } while (0)

// Convenience wrapper for the 3-D expert weight tensors.
#define ASSERT_SHAPE_3D(shape_ptr, name) ASSERT_SHAPE_DIMENSION(shape_ptr, 3, name)

// Fails unless shape_ptr (when provided) exactly equals the shape built from
// the trailing dimension arguments via make_shape(...).
#define CHECK_SHAPE(shape_ptr, name, ...)                                        \
  do {                                                                           \
    if (shape_ptr != nullptr) {                                                  \
      const TensorShape& expected_shape = make_shape(__VA_ARGS__);               \
      if (*shape_ptr != expected_shape) {                                        \
        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input '", name,   \
                               "' is expected to have shape ", expected_shape,   \
                               ", got ", *shape_ptr);                            \
      }                                                                          \
    }                                                                            \
  } while (0)

template <typename Tensor>
Status CheckInputs(MoEParameters& parameters,
const Tensor* input, // required
const Tensor* router_probs, // required
const Tensor* fc1_experts_weights, // required
const Tensor* fc1_experts_bias, // optional
const Tensor* fc1_experts_scales, // required for qMoE; NULL for MOE
const Tensor* fc1_zero_points, // optional, for qMoE
const Tensor* fc2_experts_weights, // required
const Tensor* fc2_experts_bias, // optional
const Tensor* fc2_experts_scales, // required for qMoE; NULL for MOE
const Tensor* fc2_zero_points, // optional, for qMoE
const Tensor* fc3_experts_weights, // optional
const Tensor* fc3_experts_bias, // optional
const Tensor* fc3_experts_scales, // required for qMoE; NULL for MOE
const Tensor* fc3_zero_points, // optional, for qMoE
const int64_t pack_size, // number of weights packed together (like 2 for uint4 packed to uint8)
const Tensor* input, // required
const Tensor* router_probs, // required
const TensorShape* fc1_experts_weights_shape, // required
const Tensor* fc1_experts_bias, // optional
const Tensor* fc1_experts_scales, // required for qMoE; NULL for MOE
const Tensor* fc1_zero_points, // optional, for qMoE
const TensorShape* fc2_experts_weights_shape, // required
const Tensor* fc2_experts_bias, // optional
const Tensor* fc2_experts_scales, // required for qMoE; NULL for MOE
const Tensor* fc2_zero_points, // optional, for qMoE
const TensorShape* fc3_experts_weights_shape, // optional
const Tensor* fc3_experts_bias, // optional
const Tensor* fc3_experts_scales, // required for qMoE; NULL for MOE
const Tensor* fc3_zero_points, // optional, for qMoE
const int64_t pack_size, // number of weights packed together (like 2 for uint4 packed to uint8)
const bool is_fused_swiglu,
const int64_t block_size = 0) { // block size for block-wise quantization
// Check dimensions of input to avoid input_dims index out of range. CHECK_TENSOR_SHAPE will verify each tensor later.
// Required inputs
if (input == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'input' is required.");
}
ASSERT_TENSOR_2D_OR_3D(input);
ASSERT_TENSOR_3D(fc1_experts_weights);
ASSERT_TENSOR_3D(fc2_experts_weights);

if (router_probs == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'router_probs' is required.");
}
ASSERT_TENSOR_2D(router_probs);

if (fc1_experts_weights_shape == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'fc1_experts_weights' is required.");
}
ASSERT_SHAPE_3D(fc1_experts_weights_shape, "fc1_experts_weights");

if (fc2_experts_weights_shape == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'fc2_experts_weights' is required.");
}
ASSERT_SHAPE_3D(fc2_experts_weights_shape, "fc2_experts_weights");

const auto& input_dims = input->Shape().GetDims();
const auto& router_probs_dims = router_probs->Shape().GetDims();
const auto& fc1_experts_weights_dims = fc1_experts_weights->Shape().GetDims();
const auto& fc2_experts_weights_dims = fc2_experts_weights->Shape().GetDims();

int64_t num_rows = input_dims.size() == 2 ? input_dims[0] : input_dims[0] * input_dims[1];
int64_t hidden_size = input_dims[input_dims.size() - 1];
int64_t local_num_experts = fc1_experts_weights_dims[0];
int64_t num_experts = router_probs_dims[1];
int64_t inter_size = (fc2_experts_weights_dims[1] * fc2_experts_weights_dims[2] * pack_size) / hidden_size;

const bool legacy_shape = (hidden_size != inter_size && fc2_experts_weights_dims[1] == inter_size) ||
(hidden_size == inter_size && is_fused_swiglu && fc1_experts_weights_dims[1] == hidden_size);
int64_t local_num_experts = fc1_experts_weights_shape->GetDims()[0];

int64_t inter_size = (fc2_experts_weights_shape->GetDims()[1] *
fc2_experts_weights_shape->GetDims()[2] * pack_size) /
hidden_size;

bool legacy_shape = false;
const auto& fc2_experts_weights_dims = fc2_experts_weights_shape->GetDims();
const auto& fc1_experts_weights_dims = fc1_experts_weights_shape->GetDims();
legacy_shape = (hidden_size != inter_size && fc2_experts_weights_dims[1] == inter_size) ||
(hidden_size == inter_size && is_fused_swiglu && fc1_experts_weights_dims[1] == hidden_size);

// Fused swiglu doubles the output dimension of FC1 since it fused two GEMMs into one.
const int64_t fc1_inter_size = is_fused_swiglu ? (inter_size + inter_size) : inter_size;
const int64_t zp_pack_size = pack_size; // Zero points packing (1 for 8-bit, 2 for 4-bit)

if (legacy_shape) {
// legacy shape does not match column major memory layout. This is for backward compatibility.
CHECK_TENSOR_SHAPE(fc1_experts_weights, num_experts, hidden_size, fc1_inter_size / pack_size);
CHECK_TENSOR_SHAPE(fc2_experts_weights, num_experts, inter_size, hidden_size / pack_size);
CHECK_TENSOR_SHAPE(fc3_experts_weights, num_experts, hidden_size, inter_size / pack_size);
CHECK_SHAPE(fc1_experts_weights_shape, "fc1_experts_weights", num_experts, hidden_size, fc1_inter_size / pack_size);
CHECK_SHAPE(fc2_experts_weights_shape, "fc2_experts_weights", num_experts, inter_size, hidden_size / pack_size);
CHECK_SHAPE(fc3_experts_weights_shape, "fc3_experts_weights", num_experts, hidden_size, inter_size / pack_size);
} else {
CHECK_TENSOR_SHAPE(fc1_experts_weights, num_experts, fc1_inter_size, hidden_size / pack_size);
CHECK_TENSOR_SHAPE(fc2_experts_weights, num_experts, hidden_size, inter_size / pack_size);
CHECK_TENSOR_SHAPE(fc3_experts_weights, num_experts, inter_size, hidden_size / pack_size);
CHECK_SHAPE(fc1_experts_weights_shape, "fc1_experts_weights", num_experts, fc1_inter_size, hidden_size / pack_size);
CHECK_SHAPE(fc2_experts_weights_shape, "fc2_experts_weights", num_experts, hidden_size, inter_size / pack_size);
CHECK_SHAPE(fc3_experts_weights_shape, "fc3_experts_weights", num_experts, inter_size, hidden_size / pack_size);
}

CHECK_TENSOR_SHAPE(router_probs, num_rows, num_experts);
Expand Down Expand Up @@ -168,9 +210,11 @@ Status CheckInputs(MoEParameters& parameters,
}
}

if (fc3_experts_weights == nullptr) {
if (fc3_experts_weights_shape == nullptr) {
// If fc3 weights are not provided, ensure no other fc3 parameters are provided
ORT_ENFORCE(fc3_experts_bias == nullptr && fc3_experts_scales == nullptr && fc3_zero_points == nullptr);
} else {
// If fc3 weights are provided, ensure scales logic is consistent
ORT_ENFORCE(fc1_experts_scales == nullptr || fc3_experts_scales != nullptr); // MOE no scale, or qMOE need scales
}

Expand Down Expand Up @@ -200,6 +244,36 @@ Status CheckInputs(MoEParameters& parameters,
return Status::OK();
}

// Tensor-based adapter: extracts the (optional) weight tensor shapes and
// forwards everything to the shape-based CheckInputs overload above, which
// performs all of the actual validation. Keeping this thin avoids duplicating
// the shape logic for callers that hold Tensor pointers.
template <typename Tensor>
Status CheckInputs(MoEParameters& parameters,
                   const Tensor* input,               // required
                   const Tensor* router_probs,        // required
                   const Tensor* fc1_experts_weights, // required
                   const Tensor* fc1_experts_bias,    // optional
                   const Tensor* fc1_experts_scales,  // required for qMoE; NULL for MOE
                   const Tensor* fc1_zero_points,     // optional, for qMoE
                   const Tensor* fc2_experts_weights, // required
                   const Tensor* fc2_experts_bias,    // optional
                   const Tensor* fc2_experts_scales,  // required for qMoE; NULL for MOE
                   const Tensor* fc2_zero_points,     // optional, for qMoE
                   const Tensor* fc3_experts_weights, // optional
                   const Tensor* fc3_experts_bias,    // optional
                   const Tensor* fc3_experts_scales,  // required for qMoE; NULL for MOE
                   const Tensor* fc3_zero_points,     // optional, for qMoE
                   const int64_t pack_size,           // number of weights packed together (like 2 for uint4 packed to uint8)
                   const bool is_fused_swiglu,
                   const int64_t block_size = 0) {    // block size for block-wise quantization
  // Null tensors (optional inputs) map to null shapes.
  const auto shape_or_null = [](const Tensor* tensor) -> const TensorShape* {
    return (tensor == nullptr) ? nullptr : &tensor->Shape();
  };

  return CheckInputs(parameters, input, router_probs,
                     shape_or_null(fc1_experts_weights), fc1_experts_bias, fc1_experts_scales, fc1_zero_points,
                     shape_or_null(fc2_experts_weights), fc2_experts_bias, fc2_experts_scales, fc2_zero_points,
                     shape_or_null(fc3_experts_weights), fc3_experts_bias, fc3_experts_scales, fc3_zero_points,
                     pack_size, is_fused_swiglu, block_size);
}

} // namespace moe_helper
} // namespace contrib
} // namespace onnxruntime
Loading
Loading