1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h
@@ -24,6 +24,7 @@ class MoEBaseCPU {
protected:
MoEBaseCPU(const OpKernelInfo& op_kernel_info) {
ORT_ENFORCE(op_kernel_info.GetAttr<int64_t>("k", &k_).IsOK());
ORT_ENFORCE(k_ > 0, "k must be positive, got: ", k_);

std::string activation_type_str;
ORT_ENFORCE(op_kernel_info.GetAttr<std::string>("activation_type", &activation_type_str).IsOK());
145 changes: 118 additions & 27 deletions onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc
@@ -64,10 +64,12 @@ bool CanUseMlasQ4Gemm(int64_t expert_weight_bits, int64_t block_size,
return false;
}

// Disable direct MLAS Q4 GEMM for block-wise quantization to avoid double conversion errors
// Use the traditional dequantization path, which has been fixed to use correct scale indexing
if (block_size == 64) {
out_qtype = BlkQ4Sym64;
return false; // Force traditional path
} else if (block_size == 128) {
out_qtype = BlkQ4Sym128;
return false; // Force traditional path
} else if (block_size == 0) {
out_qtype = BlkQ4Sym;
} else {
@@ -202,8 +204,16 @@ void DequantizeBlockWithMlas(const uint8_t* quantized_data,

for (int64_t block_start = 0; block_start < cols; block_start += block_size) {
const int64_t block_end = std::min(block_start + block_size, cols);
const int64_t block_idx = std::min(block_start / block_size, blocks_per_row - 1);
const int64_t block_idx = block_start / block_size;
const int64_t scale_idx = r * blocks_per_row + block_idx;

// Validate scale index bounds
const int64_t max_scale_idx = rows * blocks_per_row;
if (scale_idx < 0 || scale_idx >= max_scale_idx) {
// Skip this block if scale index is invalid
continue;
}

const float scale = static_cast<float>(scales[scale_idx]);

int64_t c = block_start;
@@ -255,8 +265,16 @@ void DequantizeBlockWithMlas(const uint8_t* quantized_data,
if (block_size > 0) {
for (int64_t block_start = 0; block_start < cols; block_start += block_size) {
const int64_t block_end = std::min(block_start + block_size, cols);
const int64_t block_idx = std::min(block_start / block_size, blocks_per_row - 1);
const int64_t block_idx = block_start / block_size;
const int64_t scale_idx = r * blocks_per_row + block_idx;

// Validate scale index bounds for 8-bit case
const int64_t max_scale_idx = rows * blocks_per_row;
if (scale_idx < 0 || scale_idx >= max_scale_idx) {
// Skip this block if scale index is invalid
continue;
}

const float scale = static_cast<float>(scales[scale_idx]);

for (c = block_start; c + 4 <= block_end; c += 4) {
@@ -295,8 +313,16 @@ void DequantizeBlockWithMlas(const uint8_t* quantized_data,
if (block_size > 0) {
for (int64_t block_start = 0; block_start < cols; block_start += block_size) {
const int64_t block_end = std::min(block_start + block_size, cols);
const int64_t block_idx = std::min(block_start / block_size, blocks_per_row - 1);
const int64_t block_idx = block_start / block_size;
const int64_t scale_idx = r * blocks_per_row + block_idx;

// Validate scale index bounds for 4-bit case
const int64_t max_scale_idx = rows * blocks_per_row;
if (scale_idx < 0 || scale_idx >= max_scale_idx) {
// Skip this block if scale index is invalid
continue;
}

const float scale = static_cast<float>(scales[scale_idx]);

for (int64_t c = block_start; c < block_end; c += 2) {
@@ -419,8 +445,18 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {

const int max_threads = tp ? concurrency::ThreadPool::DegreeOfParallelism(tp) : 1;
const int64_t thread_divisor = std::max(1, max_threads * 4);
const int64_t min_work_per_thread = std::max(int64_t{32}, static_cast<int64_t>(num_tokens / thread_divisor));
const int optimal_routing_threads = (tp == nullptr || num_tokens < min_work_per_thread) ? 1 : std::min(static_cast<int>(num_tokens / std::max(int64_t{1}, min_work_per_thread)), max_threads);

// For decoding (small num_tokens), use more aggressive parallelization
// For prefill (large num_tokens), ensure sufficient work per thread
int optimal_routing_threads;
if (num_tokens <= 4) {
// Small token counts (decoding): use up to 4 threads for better latency
optimal_routing_threads = (tp == nullptr) ? 1 : std::min(4, max_threads);
} else {
// Larger token counts: ensure minimum work per thread to avoid overhead
const int64_t min_work_per_thread = std::max(int64_t{8}, static_cast<int64_t>(num_tokens / thread_divisor));
optimal_routing_threads = (tp == nullptr || num_tokens < min_work_per_thread) ? 1 : std::min(static_cast<int>(num_tokens / min_work_per_thread), max_threads);
}
const int num_routing_threads = std::max(1, optimal_routing_threads);

std::vector<std::vector<std::vector<int64_t>>> thread_local_expert_token_maps(num_routing_threads);
Expand Down Expand Up @@ -516,23 +552,46 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
max_tokens_per_expert = std::max(max_tokens_per_expert, tokens.size());
}

const auto align_size = [](size_t size) -> size_t {
return (size + 63) & ~63;
};

const size_t A1_size = align_size(static_cast<size_t>(max_tokens_per_expert) * static_cast<size_t>(hidden_size));
const size_t C1_size = align_size(static_cast<size_t>(max_tokens_per_expert) * static_cast<size_t>(fc1_out_features));
const size_t A2_size = align_size(static_cast<size_t>(max_tokens_per_expert) * static_cast<size_t>(inter_size));
const size_t C2_size = align_size(static_cast<size_t>(max_tokens_per_expert) * static_cast<size_t>(hidden_size));
const size_t B1_dequant_size = align_size(static_cast<size_t>(fc1_out_features) * static_cast<size_t>(hidden_size));
const size_t B2_dequant_size = align_size(static_cast<size_t>(hidden_size) * static_cast<size_t>(inter_size));
// Use consistent buffer sizes - no alignment needed for float arrays
const size_t A1_size = static_cast<size_t>(max_tokens_per_expert) * static_cast<size_t>(hidden_size);
const size_t C1_size = static_cast<size_t>(max_tokens_per_expert) * static_cast<size_t>(fc1_out_features);
const size_t A2_size = static_cast<size_t>(max_tokens_per_expert) * static_cast<size_t>(inter_size);
const size_t C2_size = static_cast<size_t>(max_tokens_per_expert) * static_cast<size_t>(hidden_size);
const size_t B1_dequant_size = static_cast<size_t>(fc1_out_features) * static_cast<size_t>(hidden_size);
const size_t B2_dequant_size = static_cast<size_t>(hidden_size) * static_cast<size_t>(inter_size);

const size_t workspace_elements_per_thread = A1_size + C1_size + A2_size + C2_size +
B1_dequant_size + B2_dequant_size;

auto workspace_ptr = IAllocator::MakeUniquePtr<float>(allocator, static_cast<size_t>(num_expert_threads) * workspace_elements_per_thread);
float* workspace = workspace_ptr.get();

// Only zero-initialize the dequantization buffers that need it, not the entire workspace
// A1, C1, A2, C2 don't need initialization since they're always fully overwritten
const size_t dequant_buffers_size = B1_dequant_size + B2_dequant_size;
const size_t workspace_data_size = A1_size + C1_size + A2_size + C2_size;

// Zero only the dequantization buffers for each thread
// Use parallel initialization for large buffers to improve performance
if (dequant_buffers_size > 0) {
const size_t total_dequant_size = static_cast<size_t>(num_expert_threads) * dequant_buffers_size;
const size_t parallel_threshold = 64 * 1024; // 64KB threshold

if (total_dequant_size > parallel_threshold && tp != nullptr && num_expert_threads > 1) {
// Parallel initialization for large buffers
concurrency::ThreadPool::TrySimpleParallelFor(tp, num_expert_threads, [&](std::ptrdiff_t t) {
float* thread_dequant_start = workspace + static_cast<size_t>(t) * workspace_elements_per_thread + workspace_data_size;
std::memset(thread_dequant_start, 0, dequant_buffers_size * sizeof(float));
});
} else {
// Sequential initialization for smaller buffers
for (int t = 0; t < num_expert_threads; ++t) {
float* thread_dequant_start = workspace + static_cast<size_t>(t) * workspace_elements_per_thread + workspace_data_size;
std::memset(thread_dequant_start, 0, dequant_buffers_size * sizeof(float));
}
}
}

auto bias_conversion_buffers_ptr = IAllocator::MakeUniquePtr<float>(allocator,
static_cast<size_t>(num_expert_threads) * (static_cast<size_t>(fc1_out_features) + static_cast<size_t>(hidden_size)));
float* bias_conversion_buffers = bias_conversion_buffers_ptr.get();
Expand Down Expand Up @@ -566,6 +625,12 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
}
}

// Adjust thread count based on total work to avoid thread overhead for small workloads
// These thresholds are based on empirical performance testing:
// - < 48 tokens: Single thread is most efficient due to low overhead
// - 48-191 tokens: Cap at 2 threads to balance parallelism vs overhead
// - 192-511 tokens: Cap at 4 threads for good CPU utilization
// - >= 512 tokens: Use full calculated thread count for maximum parallelism
if (total_work < 48) {
num_expert_threads = 1;
} else if (total_work < 192) {
Expand Down Expand Up @@ -601,6 +666,13 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {

const int64_t num_expert_tokens = static_cast<int64_t>(routes.size());

// Validate that the number of tokens doesn't exceed our allocation
if (static_cast<size_t>(num_expert_tokens) > max_tokens_per_expert) {
LOGS_DEFAULT(ERROR) << "Expert " << expert_idx << " has " << num_expert_tokens
<< " tokens but workspace allocated for max " << max_tokens_per_expert;
continue;
}

float* A1 = thread_workspace;
float* C1 = A1 + A1_size;
float* A2 = C1 + C1_size;
@@ -617,7 +689,8 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
const int64_t end_idx = std::min(start_idx + dynamic_block_size, num_expert_tokens);

for (int64_t i = start_idx; i < end_idx; ++i) {
const int64_t token_idx = routes[static_cast<size_t>(i)] / k_;
const int64_t route_idx = routes[static_cast<size_t>(i)];
const int64_t token_idx = route_idx / k_;
const float* src = input_float + token_idx * hidden_size;
float* dst = A1 + i * hidden_size;

@@ -626,7 +699,11 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
});
} else {
for (int64_t i = 0; i < num_expert_tokens; ++i) {
const int64_t token_idx = routes[static_cast<size_t>(i)] / k_;
const int64_t route_idx = routes[static_cast<size_t>(i)];
const int64_t token_idx = route_idx / k_;
if (token_idx >= num_tokens) {
continue; // Skip out-of-bounds token indices
}
const float* src = input_float + token_idx * hidden_size;
float* dst = A1 + i * hidden_size;

@@ -694,8 +771,12 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
if constexpr (std::is_same_v<T, MLFloat16>) {
MlasConvertHalfToFloatBuffer(reinterpret_cast<const MLFloat16*>(B1_bias), fc1_bias_float, static_cast<size_t>(fc1_out_features));
} else {
for (int64_t i = 0; i < fc1_out_features; ++i) {
fc1_bias_float[i] = static_cast<float>(B1_bias[i]);
if (ShouldUseMemcpy(fc1_out_features)) {
std::memcpy(fc1_bias_float, B1_bias, static_cast<size_t>(fc1_out_features) * sizeof(float));
} else {
for (int64_t i = 0; i < fc1_out_features; ++i) {
fc1_bias_float[i] = static_cast<float>(B1_bias[i]);
}
}
}
}
@@ -716,6 +797,10 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
}

// Traditional approach: dequantize + regular GEMM
// Use parallel dequantization when:
// 1. num_dequant_blocks > 1: Multiple blocks to parallelize across
// 2. fc1_out_features >= 32: Sufficient work per thread to justify overhead
// (32 features * hidden_size elements = substantial work per block)
if (num_dequant_blocks > 1 && fc1_out_features >= 32) {
concurrency::ThreadPool::TrySimpleParallelFor(tp, static_cast<int>(num_dequant_blocks), [&](std::ptrdiff_t block_idx) {
const int64_t start_row = block_idx * dequant_block_size;
@@ -780,7 +865,7 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {

fc1_gemm_done:

const int64_t activation_threshold = std::max(int64_t{4}, 256 / std::max(int64_t{1}, inter_size));
const int64_t activation_threshold = std::max(int64_t{4}, 256 / inter_size);
Contributor:

How do you choose the magic numbers (4, 256) here?

Contributor Author:

256 is chosen because it fits well in the L1 cache, which helps CPU cache efficiency. The 4 comes from inter_size: it is the minimum number of tokens required before parallel processing is considered, e.g. inter_size = 64 gives 256 / 64 = 4.

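To make the exchange above concrete, here is a minimal standalone sketch (not part of this diff) of how the threshold formula on the changed line behaves. The inter_size values are illustrative assumptions; the idea is that activation is parallelized only once an expert has roughly 256 activation elements of work, with a floor of 4 tokens.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  // Same formula as in the diff: std::max(int64_t{4}, 256 / inter_size).
  // The inter_size values below are made up for illustration only.
  for (int64_t inter_size : {32, 64, 128, 512, 2048}) {
    const int64_t activation_threshold = std::max(int64_t{4}, 256 / inter_size);
    // activation_threshold is a token count; tokens * inter_size is the
    // approximate number of activation elements processed before going parallel.
    std::printf("inter_size=%5lld -> activation_threshold=%lld (~%lld elements)\n",
                static_cast<long long>(inter_size),
                static_cast<long long>(activation_threshold),
                static_cast<long long>(activation_threshold * inter_size));
  }
  return 0;
}
```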
if (num_expert_tokens >= activation_threshold && tp != nullptr) {
const int64_t activation_block_size = std::max(int64_t{1}, std::min(int64_t{64}, activation_threshold));
const int64_t num_activation_blocks = (num_expert_tokens + activation_block_size - 1) / activation_block_size;
@@ -857,9 +942,7 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
if constexpr (std::is_same_v<T, MLFloat16>) {
MlasConvertHalfToFloatBuffer(reinterpret_cast<const MLFloat16*>(B2_bias), fc2_bias_float, static_cast<size_t>(hidden_size));
} else {
for (int64_t i = 0; i < hidden_size; ++i) {
fc2_bias_float[i] = static_cast<float>(B2_bias[i]);
}
std::memcpy(fc2_bias_float, B2_bias, static_cast<size_t>(hidden_size) * sizeof(float));
}
}

@@ -880,6 +963,10 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
}

// Traditional approach: dequantize + regular GEMM
// Use parallel dequantization when:
// 1. num_fc2_dequant_blocks > 1: Multiple blocks to parallelize across
// 2. hidden_size >= 32: Sufficient work per thread to justify overhead
// (32 features * inter_size elements = substantial work per block)
if (num_fc2_dequant_blocks > 1 && hidden_size >= 32) {
concurrency::ThreadPool::TrySimpleParallelFor(tp, static_cast<int>(num_fc2_dequant_blocks), [&](std::ptrdiff_t block_idx) {
const int64_t start_row = block_idx * fc2_dequant_block_size;
@@ -932,13 +1019,17 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
for (int64_t i = 0; i < num_expert_tokens; ++i) {
const int64_t route_idx = routes[static_cast<size_t>(i)];
const int64_t token_idx = route_idx / k_;
if (token_idx >= num_tokens || route_idx >= num_tokens * k_) {
continue; // Skip out-of-bounds indices
}
const float weight = route_scale[route_idx];

if (token_idx < 0 || token_idx >= num_tokens) continue;

const size_t buffer_offset = static_cast<size_t>(token_idx) * static_cast<size_t>(hidden_size);
if (buffer_offset + static_cast<size_t>(hidden_size) > output_buffer_size) continue;

// Simplified thread buffer validation
if (thread_id < 0 || thread_id >= num_expert_threads) continue;

float* dest = thread_local_outputs + static_cast<size_t>(thread_id) * output_buffer_size + buffer_offset;
const float* src = C2 + i * hidden_size;
