From ce56a3abdd6c405004c01f5011cbe0734817172c Mon Sep 17 00:00:00 2001 From: xieofxie Date: Thu, 31 Jul 2025 16:05:26 +0800 Subject: [PATCH 01/10] add session_id_ to LogEvaluationStart/Stop, LogSessionCreationStart (#25590) ### Description use session id to track them with LogSessionCreation if we call Run in different threads, we could differentiate them with thread id given Run is not async ### Motivation and Context --------- Co-authored-by: hualxie --- onnxruntime/core/platform/telemetry.cc | 9 ++++++--- onnxruntime/core/platform/telemetry.h | 6 +++--- onnxruntime/core/platform/windows/telemetry.cc | 13 ++++++++----- onnxruntime/core/platform/windows/telemetry.h | 6 +++--- onnxruntime/core/session/inference_session.cc | 6 +++--- 5 files changed, 23 insertions(+), 17 deletions(-) diff --git a/onnxruntime/core/platform/telemetry.cc b/onnxruntime/core/platform/telemetry.cc index 9cf89a04f031c..6cbbdd4e0a7ef 100644 --- a/onnxruntime/core/platform/telemetry.cc +++ b/onnxruntime/core/platform/telemetry.cc @@ -40,13 +40,16 @@ void Telemetry::SetLanguageProjection(uint32_t projection) const { void Telemetry::LogProcessInfo() const { } -void Telemetry::LogSessionCreationStart() const { +void Telemetry::LogSessionCreationStart(uint32_t session_id) const { + ORT_UNUSED_PARAMETER(session_id); } -void Telemetry::LogEvaluationStop() const { +void Telemetry::LogEvaluationStop(uint32_t session_id) const { + ORT_UNUSED_PARAMETER(session_id); } -void Telemetry::LogEvaluationStart() const { +void Telemetry::LogEvaluationStart(uint32_t session_id) const { + ORT_UNUSED_PARAMETER(session_id); } void Telemetry::LogSessionCreation(uint32_t session_id, int64_t ir_version, const std::string& model_producer_name, diff --git a/onnxruntime/core/platform/telemetry.h b/onnxruntime/core/platform/telemetry.h index cb7a6176e5aec..b60345e1b8a80 100644 --- a/onnxruntime/core/platform/telemetry.h +++ b/onnxruntime/core/platform/telemetry.h @@ -48,11 +48,11 @@ class Telemetry { virtual void LogProcessInfo() const; - virtual void LogSessionCreationStart() const; + virtual void LogSessionCreationStart(uint32_t session_id) const; - virtual void LogEvaluationStop() const; + virtual void LogEvaluationStop(uint32_t session_id) const; - virtual void LogEvaluationStart() const; + virtual void LogEvaluationStart(uint32_t session_id) const; virtual void LogSessionCreation(uint32_t session_id, int64_t ir_version, const std::string& model_producer_name, const std::string& model_producer_version, const std::string& model_domain, diff --git a/onnxruntime/core/platform/windows/telemetry.cc b/onnxruntime/core/platform/windows/telemetry.cc index 44ef44a3f5aff..2e5d334856278 100644 --- a/onnxruntime/core/platform/windows/telemetry.cc +++ b/onnxruntime/core/platform/windows/telemetry.cc @@ -194,7 +194,7 @@ void WindowsTelemetry::LogProcessInfo() const { process_info_logged = true; } -void WindowsTelemetry::LogSessionCreationStart() const { +void WindowsTelemetry::LogSessionCreationStart(uint32_t session_id) const { if (global_register_count_ == 0 || enabled_ == false) return; @@ -203,23 +203,26 @@ void WindowsTelemetry::LogSessionCreationStart() const { TraceLoggingBool(true, "UTCReplace_AppSessionGuid"), TelemetryPrivacyDataTag(PDT_ProductAndServiceUsage), TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), + TraceLoggingUInt32(session_id, "sessionId"), TraceLoggingLevel(WINEVENT_LEVEL_INFO)); } -void WindowsTelemetry::LogEvaluationStop() const { +void WindowsTelemetry::LogEvaluationStop(uint32_t session_id) const { if (global_register_count_ 
== 0 || enabled_ == false) return; TraceLoggingWrite(telemetry_provider_handle, - "EvaluationStop"); + "EvaluationStop", + TraceLoggingUInt32(session_id, "sessionId")); } -void WindowsTelemetry::LogEvaluationStart() const { +void WindowsTelemetry::LogEvaluationStart(uint32_t session_id) const { if (global_register_count_ == 0 || enabled_ == false) return; TraceLoggingWrite(telemetry_provider_handle, - "EvaluationStart"); + "EvaluationStart", + TraceLoggingUInt32(session_id, "sessionId")); } void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_version, const std::string& model_producer_name, diff --git a/onnxruntime/core/platform/windows/telemetry.h b/onnxruntime/core/platform/windows/telemetry.h index 7281063d50c2e..261d14a7fed8c 100644 --- a/onnxruntime/core/platform/windows/telemetry.h +++ b/onnxruntime/core/platform/windows/telemetry.h @@ -41,11 +41,11 @@ class WindowsTelemetry : public Telemetry { void LogProcessInfo() const override; - void LogSessionCreationStart() const override; + void LogSessionCreationStart(uint32_t session_id) const override; - void LogEvaluationStop() const override; + void LogEvaluationStop(uint32_t session_id) const override; - void LogEvaluationStart() const override; + void LogEvaluationStart(uint32_t session_id) const override; void LogSessionCreation(uint32_t session_id, int64_t ir_version, const std::string& model_producer_name, const std::string& model_producer_version, const std::string& model_domain, diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 112fd84c5ed45..e3291cdce62c5 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -2041,7 +2041,7 @@ common::Status InferenceSession::Initialize() { ORT_TRY { LOGS(*session_logger_, INFO) << "Initializing session."; const Env& env = Env::Default(); - env.GetTelemetryProvider().LogSessionCreationStart(); + env.GetTelemetryProvider().LogSessionCreationStart(session_id_); bool have_cpu_ep = false; @@ -2980,7 +2980,7 @@ Status InferenceSession::Run(const RunOptions& run_options, } // log evaluation start to trace logging provider - env.GetTelemetryProvider().LogEvaluationStart(); + env.GetTelemetryProvider().LogEvaluationStart(session_id_); ORT_RETURN_IF_ERROR_SESSIONID_(ValidateInputs(feed_names, feeds)); ORT_RETURN_IF_ERROR_SESSIONID_(ValidateOutputs(output_names, p_fetches)); @@ -3133,7 +3133,7 @@ Status InferenceSession::Run(const RunOptions& run_options, } // log evaluation stop to trace logging provider - env.GetTelemetryProvider().LogEvaluationStop(); + env.GetTelemetryProvider().LogEvaluationStop(session_id_); // send out profiling events (optional) if (session_profiler_.IsEnabled()) { From 04f0fff1daf7f0591babcb57d35afce27cc12f54 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 5 Aug 2025 08:07:49 -0700 Subject: [PATCH 02/10] [build] fix WebAssembly build on macOS/arm64 (#25653) ### Description fix WebAssembly build on macOS/arm64 by disable appending "-Donnxruntime_USE_KLEIDIAI=ON" to the cmake_args KleidiAI should not be enabled for WebAssembly build. --- tools/ci_build/build.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index d22c8587a82b5..1d2f889ff1494 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -888,7 +888,7 @@ def generate_build_tree( # * Leave disabled if "no_kleidiai" argument was specified. 
# * Enable if the target is Android and args.android_abi contains arm64* # * Enable for a Windows cross compile build if compile target is an Arm one. - # * Finally enable if platform.machine contains "arm64". This should cover the following cases: + # * Finally enable if platform.machine contains "arm64" and not a WebAssembly build. This should cover the following cases: # * Linux on Arm # * MacOs (case must be ignored) # * TODO Delegate responsibility for Onnxruntime_USE_KLEIDIAI = ON to CMake logic @@ -896,7 +896,7 @@ def generate_build_tree( if ( (args.android and "arm64" in args.android_abi.lower()) or (is_windows() and (args.arm64 or args.arm64ec or args.arm) and platform.architecture()[0] != "AMD64") - or ("arm64" in platform.machine().lower()) + or ("arm64" in platform.machine().lower() and not args.build_wasm) ): cmake_args += ["-Donnxruntime_USE_KLEIDIAI=ON"] From 55aa5247a382d2a562d6661b15e7ac6f16ae4229 Mon Sep 17 00:00:00 2001 From: Akshay Sonawane <111780983+apsonawane@users.noreply.github.com> Date: Wed, 10 Sep 2025 16:54:49 -0700 Subject: [PATCH 03/10] [CPU] MoE Kernel (#25958) CPU MoE Kernel ``` name: SwigluMoEBlock, quant_bits: 0, dtype: FP32, batch: 1, seq_len: 16, max_diff: 2.682209014892578e-07 .name: SwigluMoEBlock, quant_bits: 0, dtype: FP32, batch: 1, seq_len: 32, max_diff: 2.980232238769531e-07 .name: SwigluMoEBlock, quant_bits: 0, dtype: FP32, batch: 2, seq_len: 16, max_diff: 2.980232238769531e-07 .name: SwigluMoEBlock, quant_bits: 0, dtype: FP32, batch: 2, seq_len: 32, max_diff: 4.172325134277344e-07 .MoE CPU kernel time: 15.721677541732786 ms . ---------------------------------------------------------------------- Ran 5 tests in 30.217s ``` --- docs/OperatorKernels.md | 1 + .../contrib_ops/cpu/cpu_contrib_kernels.cc | 2 + .../contrib_ops/cpu/moe/moe_base_cpu.h | 2 +- onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc | 605 ++++++++++++++++++ onnxruntime/contrib_ops/cpu/moe/moe_cpu.h | 53 ++ onnxruntime/contrib_ops/cpu/moe/moe_utils.cc | 12 +- onnxruntime/test/contrib_ops/moe_test.cc | 91 +++ .../test/python/transformers/test_moe_cpu.py | 473 ++++++++++++++ .../test/python/transformers/test_qmoe_cpu.py | 498 ++++++++++---- 9 files changed, 1613 insertions(+), 124 deletions(-) create mode 100644 onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc create mode 100644 onnxruntime/contrib_ops/cpu/moe/moe_cpu.h create mode 100644 onnxruntime/test/python/transformers/test_moe_cpu.py diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index b8385a3352278..c7844b4120d97 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -549,6 +549,7 @@ Do not modify directly.* |MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br> **T2** = tensor(int8), tensor(uint8)<br> **T3** = tensor(float)|
 |MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br> **T2** = tensor(uint8)<br> **T3** = tensor(float), tensor(float16), tensor(uint8)<br> **T4** = tensor(int32)|
 |MaxpoolWithMask|*in* X:**T**<br> *in* M:**tensor(int32)**<br> *out* Y:**T**|1+|**T** = tensor(float)|
+|MoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_weights:**T**<br> *in* fc2_experts_bias:**T**<br> *in* fc3_experts_weights:**T**<br> *in* fc3_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float)|
 |MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* attention_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *in* cache_indirection:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* qk:**QK**|1+|**T** = tensor(float)|
 |MurmurHash3|*in* X:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)<br> **T2** = tensor(int32), tensor(uint32)|
 |NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br>
**Tid** = tensor(int64)| diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index 34410a5f42630..d959d11e3fd43 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -108,6 +108,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QGemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, QMoE); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QMoE); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MoE); // ******** End: Quantization ******************* // #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED @@ -275,6 +276,7 @@ Status RegisterQuantizationKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h b/onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h index eae96c186d471..84580b310f6b3 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h +++ b/onnxruntime/contrib_ops/cpu/moe/moe_base_cpu.h @@ -6,7 +6,7 @@ #include "core/common/common.h" #include "core/framework/tensor_shape.h" #include "core/framework/op_kernel.h" -#include "contrib_ops/cpu/moe/moe_helper.h" +#include "moe_helper.h" #include namespace onnxruntime { diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc new file mode 100644 index 0000000000000..73be099181aa9 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc @@ -0,0 +1,605 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cpu/moe/moe_cpu.h" +#include "contrib_ops/cpu/moe/moe_utils.h" +#include "contrib_ops/cpu/moe/moe_helper.h" +#include "core/framework/op_kernel.h" +#include "core/providers/common.h" +#include "core/providers/cpu/math/gemm_helper.h" +#include "core/util/math_cpuonly.h" +#include "core/mlas/inc/mlas.h" +#include "core/framework/float16.h" +#include "core/framework/allocator.h" +#include "core/platform/threadpool.h" + +#include +#include +#include + +namespace onnxruntime { +namespace contrib { + +template +MoE::MoE(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info), MoEBaseCPU(op_kernel_info) { + if (activation_type_ == ActivationType::SwiGLU && swiglu_fusion_ != 1) { + ORT_THROW("CPU MoE only supports interleaved SwiGLU format. 
Please set swiglu_fusion=1."); + } +} + +template +Status MoE::Compute(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + const Tensor* router_probs = context->Input(1); + const Tensor* fc1_experts_weights = context->Input(2); + const Tensor* fc1_experts_bias = context->Input(3); + const Tensor* fc2_experts_weights = context->Input(4); + const Tensor* fc2_experts_bias = context->Input(5); + const Tensor* fc3_experts_weights = context->Input(6); + const Tensor* fc3_experts_bias = context->Input(7); + + // FC3 not supported + if (fc3_experts_weights != nullptr || fc3_experts_bias != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "FC3 is not implemented for CPU MoE."); + } + + MoEParameters moe_params; + ORT_RETURN_IF_ERROR(moe_helper::CheckInputs( + moe_params, input, router_probs, + fc1_experts_weights, fc1_experts_bias, nullptr, + fc2_experts_weights, fc2_experts_bias, nullptr, + fc3_experts_weights, fc3_experts_bias, nullptr, + 1, + activation_type_ == ActivationType::SwiGLU)); + + Tensor* output = context->Output(0, input->Shape()); + + return ComputeMoE(context, input, router_probs, fc1_experts_weights, fc1_experts_bias, + fc2_experts_weights, fc2_experts_bias, output); +} + +template +Status MoE::ComputeMoE(const OpKernelContext* context, + const Tensor* input, + const Tensor* router_probs, + const Tensor* fc1_experts_weights, + const Tensor* fc1_experts_bias, + const Tensor* fc2_experts_weights, + const Tensor* fc2_experts_bias, + Tensor* output) const { + const auto& input_shape = input->Shape(); + const auto& router_shape = router_probs->Shape(); + const auto& fc2_shape = fc2_experts_weights->Shape(); + + const int64_t num_tokens = input_shape.Size() / input_shape[input_shape.NumDimensions() - 1]; + const int64_t hidden_size = input_shape[input_shape.NumDimensions() - 1]; + const int64_t num_experts = router_shape[1]; + const int64_t inter_size = (fc2_shape[1] * fc2_shape[2]) / hidden_size; + const bool is_swiglu = activation_type_ == ActivationType::SwiGLU; + const int64_t fc1_output_size = is_swiglu ? (inter_size * 2) : inter_size; + + const T* input_data = input->Data(); + const T* router_data = router_probs->Data(); + const T* fc1_weights_data = fc1_experts_weights->Data(); + const T* fc1_bias_data = fc1_experts_bias ? fc1_experts_bias->Data() : nullptr; + const T* fc2_weights_data = fc2_experts_weights->Data(); + const T* fc2_bias_data = fc2_experts_bias ? 
fc2_experts_bias->Data() : nullptr; + T* output_data = output->MutableData(); + + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + + const T* input_data_to_use = input_data; + IAllocatorUniquePtr input_data_copy_ptr; + if (normalize_routing_weights_) { + input_data_copy_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * hidden_size)); + T* input_data_copy = input_data_copy_ptr.get(); + std::copy(input_data, input_data + (num_tokens * hidden_size), input_data_copy); + input_data_to_use = input_data_copy; + } + + std::fill_n(output_data, output->Shape().Size(), T{}); + + IAllocatorUniquePtr router_logits_float_buffer; + const float* router_logits_float = nullptr; + if constexpr (std::is_same_v) { + router_logits_float_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * num_experts)); + router_logits_float = router_logits_float_buffer.get(); + MlasConvertHalfToFloatBuffer(reinterpret_cast(router_data), const_cast(router_logits_float), static_cast(num_tokens * num_experts)); + } else { + router_logits_float = reinterpret_cast(router_data); + } + + auto route_expert_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * k_)); + int* route_expert = route_expert_ptr.get(); + auto route_scale_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * k_)); + float* route_scale = route_scale_ptr.get(); + + auto* tp = context->GetOperatorThreadPool(); + int num_routing_threads = 1; + if (tp != nullptr && num_tokens >= 1024) { + int max_threads = concurrency::ThreadPool::DegreeOfParallelism(tp); + num_routing_threads = std::min(static_cast(num_tokens / 512), max_threads); + num_routing_threads = std::max(1, num_routing_threads); + } + + std::vector>> thread_local_expert_token_maps(num_routing_threads); + for (auto& map : thread_local_expert_token_maps) { + map.resize(static_cast(num_experts)); + for (auto& expert_map : map) { + expert_map.reserve(static_cast(std::max(static_cast(1), num_tokens / num_experts / num_routing_threads * 2))); + } + } + + concurrency::ThreadPool::TrySimpleParallelFor(tp, num_routing_threads, [&](std::ptrdiff_t thread_id) { + auto work = concurrency::ThreadPool::PartitionWork(static_cast(thread_id), num_routing_threads, static_cast(num_tokens)); + auto& local_expert_token_map = thread_local_expert_token_maps[thread_id]; + + std::vector> sorted_logits(static_cast(num_experts)); + std::vector full_softmax(static_cast(num_experts)); + + for (int64_t i = work.start; i < work.end; ++i) { + const float* logits = router_logits_float + i * num_experts; + + float max_logit = logits[0]; + for (int64_t j = 1; j < num_experts; ++j) { + max_logit = std::max(max_logit, logits[j]); + } + + float sum_exp = 0.0f; + for (int64_t j = 0; j < num_experts; ++j) { + full_softmax[static_cast(j)] = std::exp(logits[j] - max_logit); + sum_exp += full_softmax[static_cast(j)]; + } + + const float inv_sum_exp = 1.0f / sum_exp; + for (int64_t j = 0; j < num_experts; ++j) { + full_softmax[static_cast(j)] *= inv_sum_exp; + sorted_logits[static_cast(j)] = {full_softmax[static_cast(j)], j}; + } + + std::partial_sort(sorted_logits.begin(), sorted_logits.begin() + static_cast(k_), sorted_logits.end(), std::greater<>()); + + if (normalize_routing_weights_) { + float top_k_sum = 0.0f; + for (int64_t j = 0; j < k_; ++j) { + top_k_sum += sorted_logits[static_cast(j)].first; + } + const float inv_top_k_sum = 1.0f / top_k_sum; + + for (int64_t j = 0; j < k_; ++j) { + int64_t expert_idx = 
sorted_logits[static_cast(j)].second; + int64_t route_idx = i * k_ + j; + float normalized_weight = sorted_logits[static_cast(j)].first * inv_top_k_sum; + + route_expert[route_idx] = static_cast(expert_idx); + route_scale[route_idx] = normalized_weight; + if (normalized_weight > 0.0f) { + local_expert_token_map[static_cast(expert_idx)].push_back(route_idx); + } + } + } else { + for (int64_t j = 0; j < k_; ++j) { + int64_t expert_idx = sorted_logits[static_cast(j)].second; + int64_t route_idx = i * k_ + j; + float weight = sorted_logits[static_cast(j)].first; + + route_expert[route_idx] = static_cast(expert_idx); + route_scale[route_idx] = weight; + if (weight > 0.0f) { + local_expert_token_map[static_cast(expert_idx)].push_back(route_idx); + } + } + } + } + }); + + std::vector> expert_token_map(static_cast(num_experts)); + + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + size_t total_tokens_for_expert = 0; + for (int t = 0; t < num_routing_threads; ++t) { + total_tokens_for_expert += thread_local_expert_token_maps[t][static_cast(expert_idx)].size(); + } + expert_token_map[static_cast(expert_idx)].reserve(total_tokens_for_expert); + } + + for (int t = 0; t < num_routing_threads; ++t) { + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + auto& local_tokens = thread_local_expert_token_maps[t][static_cast(expert_idx)]; + if (!local_tokens.empty()) { + auto& expert_map = expert_token_map[static_cast(expert_idx)]; + expert_map.insert(expert_map.end(), + std::make_move_iterator(local_tokens.begin()), + std::make_move_iterator(local_tokens.end())); + } + } + } + + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + auto& expert_map = expert_token_map[static_cast(expert_idx)]; + if (!expert_map.empty()) { + std::sort(expert_map.begin(), expert_map.end()); + } + } + + IAllocatorUniquePtr input_float_buffer; + const float* input_float; + if constexpr (std::is_same_v) { + input_float_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * hidden_size)); + input_float = input_float_buffer.get(); + MlasConvertHalfToFloatBuffer(reinterpret_cast(input_data_to_use), const_cast(input_float), static_cast(num_tokens * hidden_size)); + } else { + input_float = reinterpret_cast(input_data_to_use); + } + + int num_expert_threads = 1; + if (tp != nullptr) { + int total_active_experts = 0; + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + if (!expert_token_map[static_cast(expert_idx)].empty()) { + total_active_experts++; + } + } + + if (total_active_experts > 0) { + int max_threads = concurrency::ThreadPool::DegreeOfParallelism(tp); + num_expert_threads = std::min(total_active_experts, max_threads); + num_expert_threads = std::min(num_expert_threads, 8); + } + } + + // Calculate maximum possible tokens per expert for buffer sizing + int64_t max_tokens_per_expert = 0; + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + const auto& routes = expert_token_map[static_cast(expert_idx)]; + max_tokens_per_expert = std::max(max_tokens_per_expert, static_cast(routes.size())); + } + + // Thread-local buffer pool for expert processing + struct ThreadLocalBuffers { + IAllocatorUniquePtr A1_buffer; + IAllocatorUniquePtr batch_weights_buffer; + IAllocatorUniquePtr token_ids_buffer; + IAllocatorUniquePtr A1_t_buffer; + IAllocatorUniquePtr C2_buffer; + // Additional buffers for ProcessExpertBatch to avoid repeated allocations + IAllocatorUniquePtr fc1_output_buffer; + IAllocatorUniquePtr 
activation_output_buffer; + int64_t current_capacity = 0; + int64_t current_fc1_capacity = 0; + int64_t current_activation_capacity = 0; + + void EnsureCapacity(AllocatorPtr& allocator, int64_t required_tokens, int64_t hidden_size, + int64_t fc1_output_size, int64_t inter_size) { + if (required_tokens > current_capacity) { + // Use high watermark approach - allocate more than needed for future reuse + int64_t new_capacity = std::max(required_tokens * 2, current_capacity + 512); + + A1_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(new_capacity * hidden_size)); + batch_weights_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(new_capacity)); + token_ids_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(new_capacity)); + A1_t_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(new_capacity * hidden_size)); + C2_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(new_capacity * hidden_size)); + + current_capacity = new_capacity; + } + + // Ensure ProcessExpertBatch buffers have sufficient capacity + int64_t required_fc1_capacity = required_tokens * fc1_output_size; + int64_t required_activation_capacity = required_tokens * inter_size; + + if (required_fc1_capacity > current_fc1_capacity) { + int64_t new_fc1_capacity = std::max(required_fc1_capacity * 2, current_fc1_capacity + (512 * fc1_output_size)); + fc1_output_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(new_fc1_capacity)); + current_fc1_capacity = new_fc1_capacity; + } + + if (required_activation_capacity > current_activation_capacity) { + int64_t new_activation_capacity = std::max(required_activation_capacity * 2, current_activation_capacity + (512 * inter_size)); + activation_output_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(new_activation_capacity)); + current_activation_capacity = new_activation_capacity; + } + } + }; + + // Pre-allocate thread-local buffer pools + std::vector thread_buffers(num_expert_threads); + for (int i = 0; i < num_expert_threads; ++i) { + thread_buffers[i].EnsureCapacity(allocator, max_tokens_per_expert, hidden_size, fc1_output_size, inter_size); + } + + const size_t output_buffer_size = static_cast(output->Shape().Size()); + auto thread_local_outputs_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_expert_threads) * output_buffer_size); + float* thread_local_outputs = thread_local_outputs_ptr.get(); + + // Initialize thread-local outputs with vectorized operation + std::fill_n(thread_local_outputs, static_cast(num_expert_threads) * output_buffer_size, 0.0f); + + // Optimized expert processing with thread-local buffer reuse + concurrency::ThreadPool::TrySimpleParallelFor(tp, num_expert_threads, [&](std::ptrdiff_t thread_id_pd) { + int thread_id = static_cast(thread_id_pd); + auto work = concurrency::ThreadPool::PartitionWork(thread_id, num_expert_threads, static_cast(num_experts)); + + float* local_output = thread_local_outputs + static_cast(thread_id) * output_buffer_size; + ThreadLocalBuffers& buffers = thread_buffers[thread_id]; + + for (int64_t expert_idx = work.start; expert_idx < work.end; ++expert_idx) { + const auto& routes = expert_token_map[static_cast(expert_idx)]; + if (routes.empty()) continue; + + const int64_t num_expert_tokens = static_cast(routes.size()); + + // Ensure thread-local buffers have sufficient capacity + buffers.EnsureCapacity(allocator, num_expert_tokens, hidden_size, fc1_output_size, inter_size); + + // Use pre-allocated buffers from thread-local pool + float* A1 = buffers.A1_buffer.get(); + float* 
batch_weights = buffers.batch_weights_buffer.get(); + int64_t* token_ids = buffers.token_ids_buffer.get(); + T* A1_t = buffers.A1_t_buffer.get(); + T* C2 = buffers.C2_buffer.get(); + T* fc1_output = buffers.fc1_output_buffer.get(); + T* activation_output = buffers.activation_output_buffer.get(); + + // Optimized data gathering with better memory access patterns + for (int64_t r = 0; r < num_expert_tokens; ++r) { + int64_t route_idx = routes[static_cast(r)]; + int64_t token = route_idx / k_; + + token_ids[r] = token; + batch_weights[r] = route_scale[route_idx]; + + // Use SIMD-friendly copy for better performance + const float* src = input_float + token * hidden_size; + float* dst = A1 + static_cast(r) * static_cast(hidden_size); + std::copy(src, src + hidden_size, dst); + } + + const T* fc1_expert_weights = fc1_weights_data + expert_idx * fc1_output_size * hidden_size; + const T* fc1_expert_bias = fc1_bias_data ? fc1_bias_data + expert_idx * fc1_output_size : nullptr; + const T* fc2_expert_weights = fc2_weights_data + expert_idx * hidden_size * inter_size; + const T* fc2_expert_bias = fc2_bias_data ? fc2_bias_data + expert_idx * hidden_size : nullptr; + + // Convert input to T only when needed for computation + for (size_t i = 0; i < static_cast(num_expert_tokens * hidden_size); ++i) { + A1_t[i] = static_cast(A1[i]); + } + + ORT_IGNORE_RETURN_VALUE(ProcessExpertBatch(A1_t, token_ids, batch_weights, + num_expert_tokens, expert_idx, + fc1_expert_weights, fc1_expert_bias, + fc2_expert_weights, fc2_expert_bias, + C2, hidden_size, inter_size, + fc1_output, activation_output)); + + // Optimized output accumulation with vectorized operations + for (int64_t r = 0; r < num_expert_tokens; ++r) { + int64_t token = token_ids[r]; + const T* expert_output_t = C2 + static_cast(r) * static_cast(hidden_size); + float w = batch_weights[r]; + float* dest = local_output + static_cast(token) * static_cast(hidden_size); + + // Use explicit loop for better vectorization opportunities + for (int64_t j = 0; j < hidden_size; ++j) { + dest[j] += w * static_cast(expert_output_t[j]); + } + } + } + }); + + auto accumulate = [&](float* buffer) { + std::fill_n(buffer, output_buffer_size, 0.0f); + + for (size_t j = 0; j < output_buffer_size; ++j) { + double sum = 0.0; + double c = 0.0; + + for (int i = 0; i < num_expert_threads; ++i) { + const size_t thread_offset = static_cast(i) * output_buffer_size; + double y = static_cast(thread_local_outputs[thread_offset + j]) - c; + double t = sum + y; + c = (t - sum) - y; + sum = t; + } + buffer[j] = static_cast(sum); + } + }; + + if constexpr (std::is_same_v) { + auto final_output_float_ptr = IAllocator::MakeUniquePtr(allocator, output_buffer_size); + float* final_output_float = final_output_float_ptr.get(); + accumulate(final_output_float); + + MlasConvertFloatToHalfBuffer(final_output_float, + reinterpret_cast(output->MutableData()), + static_cast(output_buffer_size)); + } else { + auto final_output_float_ptr = IAllocator::MakeUniquePtr(allocator, output_buffer_size); + float* final_output_float = final_output_float_ptr.get(); + accumulate(final_output_float); + + float* out_ptr = reinterpret_cast(output->MutableData()); + memcpy(out_ptr, final_output_float, output_buffer_size * sizeof(float)); + } + return Status::OK(); +} +template +Status MoE::ProcessExpertBatch(const T* input_tokens, + const int64_t* token_expert_ids, + const float* token_weights, + int64_t batch_size, + int64_t expert_id, + const T* fc1_weights, + const T* fc1_bias, + const T* fc2_weights, + const T* 
fc2_bias, + T* output_buffer, + int64_t hidden_size, + int64_t inter_size, + T* fc1_output_buffer, + T* activation_output_buffer) const { + const bool is_swiglu = activation_type_ == ActivationType::SwiGLU; + const int64_t fc1_output_size = is_swiglu ? (inter_size * 2) : inter_size; + + constexpr int64_t stack_threshold = 1024; + const bool use_stack = (batch_size * fc1_output_size) <= stack_threshold; + + std::vector fc1_output_vec; + std::vector activation_output_vec; + T* fc1_output; + T* activation_output; + + if (use_stack) { + fc1_output_vec.resize(static_cast(batch_size * fc1_output_size)); + activation_output_vec.resize(static_cast(batch_size * inter_size)); + fc1_output = fc1_output_vec.data(); + activation_output = activation_output_vec.data(); + } else { + fc1_output_vec.resize(static_cast(batch_size * fc1_output_size)); + activation_output_vec.resize(static_cast(batch_size * inter_size)); + fc1_output = fc1_output_vec.data(); + activation_output = activation_output_vec.data(); + } + + ORT_RETURN_IF_ERROR(ComputeGEMM(input_tokens, fc1_weights, fc1_output, + batch_size, hidden_size, fc1_output_size, true)); + + if (fc1_bias) { + for (int64_t batch = 0; batch < batch_size; ++batch) { + T* batch_output = fc1_output + batch * fc1_output_size; + // Explicit loop for better vectorization + for (int64_t i = 0; i < fc1_output_size; ++i) { + batch_output[i] = static_cast(static_cast(batch_output[i]) + + static_cast(fc1_bias[i])); + } + } + } + + if (is_swiglu) { + for (int64_t batch = 0; batch < batch_size; ++batch) { + ApplySwiGLUVectorized(fc1_output + batch * fc1_output_size, + activation_output + batch * inter_size, + inter_size); + } + } else { + ApplyActivationVectorized(fc1_output, batch_size * fc1_output_size); + std::copy(fc1_output, fc1_output + (batch_size * fc1_output_size), activation_output); + } + + ORT_RETURN_IF_ERROR(ComputeGEMM(activation_output, fc2_weights, output_buffer, + batch_size, inter_size, hidden_size, true)); + + if (fc2_bias) { + for (int64_t batch = 0; batch < batch_size; ++batch) { + T* batch_output = output_buffer + batch * hidden_size; + for (int64_t i = 0; i < hidden_size; ++i) { + batch_output[i] = static_cast(static_cast(batch_output[i]) + + static_cast(fc2_bias[i])); + } + } + } + + return Status::OK(); +} + +template <> +Status MoE::ComputeGEMM(const float* A, const float* B, float* C, + int64_t M, int64_t K, int64_t N, bool transpose_B) const { + MLAS_SGEMM_DATA_PARAMS params; + params.A = A; + params.lda = static_cast(K); + params.alpha = 1.0f; + params.beta = 0.0f; + params.C = C; + params.ldc = static_cast(N); + params.B = B; + + if (transpose_B) { + params.ldb = static_cast(K); + MlasGemm(CblasNoTrans, CblasTrans, static_cast(M), static_cast(N), static_cast(K), params, nullptr); + } else { + params.ldb = static_cast(N); + MlasGemm(CblasNoTrans, CblasNoTrans, static_cast(M), static_cast(N), static_cast(K), params, nullptr); + } + + return Status::OK(); +} + +template <> +Status MoE::ComputeGEMM(const MLFloat16* A, const MLFloat16* B, MLFloat16* C, + int64_t M, int64_t K, int64_t N, bool transpose_B) const { + MLAS_HALF_GEMM_DATA_PARAMS params; + params.A = A; + params.lda = static_cast(K); + params.C = C; + params.ldc = static_cast(N); + params.AIsfp32 = false; + params.BIsfp32 = false; + params.B = B; + + if (transpose_B) { + params.ldb = static_cast(K); + } else { + params.ldb = static_cast(N); + } + + MlasHalfGemmBatch(static_cast(M), static_cast(N), static_cast(K), 1, ¶ms, nullptr); + return Status::OK(); +} + +template +void 
MoE::ApplyActivationVectorized(T* data, int64_t size) const { + for (int64_t i = 0; i < size; ++i) { + float val = static_cast(data[i]); + data[i] = static_cast(ApplyActivation(val, activation_type_)); + } +} + +template +void MoE::ApplySwiGLUVectorized(const T* input, T* output, int64_t size) const { + for (int64_t i = 0; i < size; ++i) { + float gate = static_cast(input[2 * i]); + float linear = static_cast(input[2 * i + 1]); + + gate = std::min(gate, swiglu_limit_); + linear = std::clamp(linear, -swiglu_limit_, swiglu_limit_); + + float sigmoid_arg = activation_alpha_ * gate; + float sigmoid_out; + if (sigmoid_arg > 0) { + float exp_neg = std::exp(-sigmoid_arg); + sigmoid_out = 1.0f / (1.0f + exp_neg); + } else { + float exp_pos = std::exp(sigmoid_arg); + sigmoid_out = exp_pos / (1.0f + exp_pos); + } + + float swish_out = gate * sigmoid_out; + output[i] = static_cast(swish_out * (linear + activation_beta_)); + } +} + +template <> +void MoE::ApplySwiGLUVectorized(const float* input, float* output, int64_t size) const { + ApplySwiGLUActivation(input, output, size, true, + activation_alpha_, activation_beta_, swiglu_limit_); +} + +template class MoE; +template class MoE; + +#define REGISTER_KERNEL_TYPED(type) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + MoE, kMSDomain, 1, type, kCpuExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .MayInplace(0, 0) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + MoE); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_cpu.h b/onnxruntime/contrib_ops/cpu/moe/moe_cpu.h new file mode 100644 index 0000000000000..60d8217015b5b --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/moe/moe_cpu.h @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "contrib_ops/cpu/moe/moe_base_cpu.h" + +namespace onnxruntime { +namespace contrib { + +template +class MoE final : public OpKernel, public MoEBaseCPU { + public: + explicit MoE(const OpKernelInfo& op_kernel_info); + Status Compute(OpKernelContext* context) const override; + + private: + Status ComputeMoE(const OpKernelContext* context, + const Tensor* input, + const Tensor* router_probs, + const Tensor* fc1_experts_weights, + const Tensor* fc1_experts_bias, + const Tensor* fc2_experts_weights, + const Tensor* fc2_experts_bias, + Tensor* output) const; + + Status ProcessExpertBatch(const T* input_tokens, + const int64_t* token_expert_ids, + const float* token_weights, + int64_t num_tokens, + int64_t expert_id, + const T* fc1_weights, + const T* fc1_bias, + const T* fc2_weights, + const T* fc2_bias, + T* output_buffer, + int64_t hidden_size, + int64_t inter_size, + T* fc1_output_buffer, + T* activation_output_buffer) const; + + Status ComputeGEMM(const T* A, const T* B, T* C, + int64_t M, int64_t K, int64_t N, + bool transpose_B = false) const; + + void ApplyActivationVectorized(T* data, int64_t size) const; + void ApplySwiGLUVectorized(const T* input, T* output, int64_t size) const; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_utils.cc b/onnxruntime/contrib_ops/cpu/moe/moe_utils.cc index 2c59210bfabd4..5a3c5d1dd0364 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_utils.cc +++ b/onnxruntime/contrib_ops/cpu/moe/moe_utils.cc @@ -37,10 +37,18 @@ void ApplySwiGLUActivation(const float* input_data, float* output_data, int64_t gate_val = std::min(gate_val, clamp_limit); linear_val = std::clamp(linear_val, -clamp_limit, clamp_limit); + // Use numerically stable sigmoid computation (matches CUDA kernel behavior) float sigmoid_arg = activation_alpha * gate_val; - float sigmoid_out = 1.0f / (1.0f + std::exp(-sigmoid_arg)); - float swish_out = gate_val * sigmoid_out; + float sigmoid_out; + if (sigmoid_arg > 0) { + float exp_neg = std::exp(-sigmoid_arg); + sigmoid_out = 1.0f / (1.0f + exp_neg); + } else { + float exp_pos = std::exp(sigmoid_arg); + sigmoid_out = exp_pos / (1.0f + exp_pos); + } + float swish_out = gate_val * sigmoid_out; output_data[i] = swish_out * (linear_val + activation_beta); } } else { diff --git a/onnxruntime/test/contrib_ops/moe_test.cc b/onnxruntime/test/contrib_ops/moe_test.cc index 0690b8894eb7a..ab740ea38fb74 100644 --- a/onnxruntime/test/contrib_ops/moe_test.cc +++ b/onnxruntime/test/contrib_ops/moe_test.cc @@ -1690,6 +1690,97 @@ TEST(MoETest, QMoETest_CPU_SwiGLU_Int8) { #endif } +// Test for CPU MoE implementation +static void RunMoECpuTest(const std::vector& input, const std::vector& router_probs, + const std::vector& fc1_experts_weights, const std::vector& fc2_experts_weights, + const std::vector& fc3_experts_weights, const std::vector& fc1_experts_bias, + const std::vector& fc2_experts_bias, const std::vector& output_data, int num_rows, + int num_experts, int hidden_size, int inter_size, std::string activation_type, + int normalize_routing_weights = 1, int top_k = 1) { + OpTester tester("MoE", 1, onnxruntime::kMSDomain); + tester.AddAttribute("k", static_cast(top_k)); + tester.AddAttribute("activation_type", activation_type); + tester.AddAttribute("normalize_routing_weights", static_cast(normalize_routing_weights)); + + bool is_swiglu = (activation_type == "swiglu"); + + if (is_swiglu) { + 
tester.AddAttribute("swiglu_fusion", static_cast(1)); + tester.AddAttribute("activation_beta", 1.0f); + } + + std::vector input_dims = {num_rows, hidden_size}; + std::vector router_probs_dims = {num_rows, num_experts}; + + int64_t fc1_output_size = is_swiglu ? (2 * inter_size) : inter_size; + + std::vector fc1_experts_weights_dims = {num_experts, hidden_size, fc1_output_size}; + std::vector fc2_experts_weights_dims = {num_experts, inter_size, hidden_size}; + std::vector fc3_experts_weights_dims = fc1_experts_weights_dims; + std::vector fc1_experts_bias_dims = {num_experts, fc1_output_size}; + std::vector fc2_experts_bias_dims = {num_experts, hidden_size}; + std::vector output_dims = {num_rows, hidden_size}; + + tester.AddInput("input", input_dims, input); + tester.AddInput("router_probs", router_probs_dims, router_probs); + tester.AddInput("fc1_experts_weights", fc1_experts_weights_dims, fc1_experts_weights); + if (!fc1_experts_bias.empty()) { + tester.AddInput("fc1_experts_bias", fc1_experts_bias_dims, fc1_experts_bias); + } else { + tester.AddOptionalInputEdge(); + } + tester.AddInput("fc2_experts_weights", fc2_experts_weights_dims, fc2_experts_weights); + if (!fc2_experts_bias.empty()) { + tester.AddInput("fc2_experts_bias", fc2_experts_bias_dims, fc2_experts_bias); + } else { + tester.AddOptionalInputEdge(); + } + if (!fc3_experts_weights.empty()) { + tester.AddInput("fc3_experts_weights", fc3_experts_weights_dims, fc3_experts_weights); + } else { + tester.AddOptionalInputEdge(); + } + tester.AddOptionalInputEdge(); // fc3_experts_bias + + tester.AddOutput("output", output_dims, output_data); + tester.SetOutputTolerance(0.05f); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(MoETest, MoECpuTest_BasicSwiGLU) { + int num_rows = 2; + int num_experts = 2; + int hidden_size = 4; + int inter_size = 8; + + // Simple test data + const std::vector input = { + 1.0f, 2.0f, 3.0f, 4.0f, + 5.0f, 6.0f, 7.0f, 8.0f}; + + const std::vector router_probs = { + 0.8f, 0.2f, + 0.3f, 0.7f}; + + const std::vector fc1_experts_weights(num_experts * hidden_size * (2 * inter_size), 0.1f); + + const std::vector fc2_experts_weights(num_experts * inter_size * hidden_size, 0.1f); + + const std::vector fc3_experts_weights = {}; // No FC3 + const std::vector fc1_experts_bias = {}; // No bias + const std::vector fc2_experts_bias = {}; // No bias + + const std::vector output_data = { + 1.169694f, 1.169694f, 1.169694f, 1.169694f, + 6.970291f, 6.970291f, 6.970291f, 6.970291f}; + + RunMoECpuTest(input, router_probs, fc1_experts_weights, fc2_experts_weights, + fc3_experts_weights, fc1_experts_bias, fc2_experts_bias, output_data, + num_rows, num_experts, hidden_size, inter_size, "swiglu"); +} #endif } // namespace test diff --git a/onnxruntime/test/python/transformers/test_moe_cpu.py b/onnxruntime/test/python/transformers/test_moe_cpu.py new file mode 100644 index 0000000000000..d6cbcc64733d4 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_moe_cpu.py @@ -0,0 +1,473 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- +# +# Regular MoE CPU kernel testing implementation - SwiGLU Interleaved Only +# +# This file tests the non-quantized MoE CPU implementation with SwiGLU +# activation in interleaved format and validates parity between +# PyTorch reference implementation and ONNX Runtime CPU kernel. +# +# Based on the CUDA test structure for consistency. +# -------------------------------------------------------------------------- + +import itertools +import time +import unittest + +import numpy +import torch +import torch.nn as nn +import torch.nn.functional as F +from onnx import TensorProto, helper +from parameterized import parameterized + +from onnxruntime import InferenceSession, SessionOptions + +# Device and provider settings for CPU +device = torch.device("cpu") +ort_provider = ["CPUExecutionProvider"] + +torch.manual_seed(42) +numpy.random.seed(42) + +onnx_to_torch_type_map = { + TensorProto.FLOAT16: torch.float16, + TensorProto.FLOAT: torch.float, + TensorProto.BFLOAT16: torch.bfloat16, + TensorProto.UINT8: torch.uint8, +} + +ort_to_numpy_type_map = { + TensorProto.FLOAT16: numpy.float16, + TensorProto.FLOAT: numpy.float32, + TensorProto.UINT8: numpy.uint8, +} + +ort_dtype_name_map = { + TensorProto.FLOAT16: "FP16", + TensorProto.FLOAT: "FP32", + TensorProto.BFLOAT16: "BF16", +} + + +class SwigluMoeConfig: + def __init__( + self, + hidden_size=2048, + intermediate_size=2048, + num_experts_per_token=2, + num_local_experts=8, + ): + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_experts_per_token = num_experts_per_token + self.num_local_experts = num_local_experts + + +def swiglu(x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0): + x_glu = x[..., ::2] + x_linear = x[..., 1::2] + + if limit is not None: + x_glu = x_glu.clamp(max=limit) + x_linear = x_linear.clamp(min=-limit, max=limit) + + y = x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + 1) + return y + + +class SwigluMlp(nn.Module): + def __init__(self, config): + super().__init__() + self.intermediate_size = config.intermediate_size + self.hidden_dim = config.hidden_size + self.w1 = nn.Linear(self.hidden_dim, 2 * self.intermediate_size, bias=True) + self.w2 = nn.Linear(self.intermediate_size, self.hidden_dim, bias=True) + + def forward(self, x): + x1 = self.w1(x) + y = swiglu(x1) + y = self.w2(y) + return y + + +def make_onnx_intializer(name: str, tensor: torch.Tensor, shape, onnx_dtype): + torch_dtype = onnx_to_torch_type_map[onnx_dtype] + if torch_dtype == torch.bfloat16: + numpy_vals_uint16 = tensor.to(torch.bfloat16).cpu().view(torch.uint16).numpy() + initializer = helper.make_tensor( + name=name, + data_type=TensorProto.BFLOAT16, + dims=shape, + vals=numpy_vals_uint16.tobytes(), + raw=True, + ) + else: + initializer = helper.make_tensor( + name=name, + data_type=onnx_dtype, + dims=shape, + vals=tensor.flatten().detach().cpu().numpy().astype(numpy.uint8).tolist() + if onnx_dtype == TensorProto.UINT8 + else tensor.detach().to(torch_dtype).flatten().tolist(), + raw=False, + ) + return initializer + + +def create_swiglu_moe_onnx_graph( + num_tokens: int, + num_experts: int, + hidden_size: int, + inter_size: int, + topk: int, + onnx_dtype: int, + quant_bits: int, + fc1_experts_weights: torch.Tensor, + fc1_experts_bias: torch.Tensor, + fc2_experts_weights: torch.Tensor, + fc2_experts_bias: torch.Tensor, + fc1_experts_weight_scale: torch.Tensor = None, + fc2_experts_weight_scale: torch.Tensor = None, +): + use_quant = 
quant_bits > 0 + op_name = "QMoE" if use_quant else "MoE" + + inputs = ( + [ + "input", + "router_probs", + "fc1_experts_weights", + "fc1_experts_weight_scale", + "fc1_experts_bias", + "fc2_experts_weights", + "fc2_experts_weight_scale", + "fc2_experts_bias", + ] + if use_quant + else [ + "input", + "router_probs", + "fc1_experts_weights", + "fc1_experts_bias", + "fc2_experts_weights", + "fc2_experts_bias", + ] + ) + + nodes = [ + helper.make_node( + op_name, + inputs, + ["output"], + "MoE_0", + k=topk, + normalize_routing_weights=1, + activation_type="swiglu", + swiglu_fusion=1, + activation_alpha=1.702, + activation_beta=1.0, + domain="com.microsoft", + ), + ] + + if use_quant: + nodes[0].attribute.extend([helper.make_attribute("expert_weight_bits", quant_bits)]) + + components = 2 if quant_bits == 4 else 1 + fc1_weight_shape = [num_experts, 2 * inter_size, hidden_size // components] + fc1_bias_shape = [num_experts, 2 * inter_size] + fc1_experts_weight_scale_shape = [num_experts, 2 * inter_size] + + fc2_weight_shape = [num_experts, hidden_size, inter_size // components] + fc2_bias_shape = [num_experts, hidden_size] + fc2_experts_weight_scale_shape = [num_experts, hidden_size] + + weight_onnx_type = TensorProto.UINT8 if use_quant else onnx_dtype + + torch_dtype = onnx_to_torch_type_map[onnx_dtype] + weight_torch_dtype = onnx_to_torch_type_map[weight_onnx_type] + + initializers = [ + make_onnx_intializer( + "fc1_experts_weights", fc1_experts_weights.to(weight_torch_dtype), fc1_weight_shape, weight_onnx_type + ), + make_onnx_intializer("fc1_experts_bias", fc1_experts_bias.to(torch_dtype), fc1_bias_shape, onnx_dtype), + make_onnx_intializer( + "fc2_experts_weights", fc2_experts_weights.to(weight_torch_dtype), fc2_weight_shape, weight_onnx_type + ), + make_onnx_intializer("fc2_experts_bias", fc2_experts_bias.to(torch_dtype), fc2_bias_shape, onnx_dtype), + ] + + if use_quant: + initializers.extend( + [ + make_onnx_intializer( + "fc1_experts_weight_scale", + fc1_experts_weight_scale.to(torch_dtype), + fc1_experts_weight_scale_shape, + onnx_dtype, + ), + make_onnx_intializer( + "fc2_experts_weight_scale", + fc2_experts_weight_scale.to(torch_dtype), + fc2_experts_weight_scale_shape, + onnx_dtype, + ), + ] + ) + + graph_inputs = [ + helper.make_tensor_value_info("input", onnx_dtype, [num_tokens, hidden_size]), + ] + + graph_inputs.append( + helper.make_tensor_value_info( + "router_probs", + onnx_dtype, + [num_tokens, num_experts], + ) + ) + + graph_outputs = [ + helper.make_tensor_value_info("output", onnx_dtype, [num_tokens, hidden_size]), + ] + + graph = helper.make_graph( + nodes, + "MoE_Graph", + graph_inputs, + graph_outputs, + initializers, + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +class SparseMoeBlockORTHelper(nn.Module): + def __init__(self, quant_bits=0, onnx_dtype=None): + super().__init__() + self.quant_bits = quant_bits + if onnx_dtype is None: + self.onnx_dtype = TensorProto.FLOAT16 if self.quant_bits > 0 else TensorProto.FLOAT + else: + self.onnx_dtype = onnx_dtype + self.np_type = numpy.float16 if self.onnx_dtype == TensorProto.FLOAT16 else numpy.float32 + + def create_ort_session(self, moe_onnx_graph): + sess_options = SessionOptions() + sess_options.log_severity_level = 2 + + try: + ort_session = InferenceSession(moe_onnx_graph, sess_options, providers=ort_provider) + except Exception as e: + print(f"Failed to create ONNX Runtime session with provider {ort_provider}: {e}") + print("Skipping ONNX Runtime execution for this test case.") + 
return None + + return ort_session + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + pass + + def ort_forward(self, hidden_states: torch.Tensor, enable_performance_test=False) -> torch.Tensor: + if self.ort_sess is None: + return None + + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states_flat = hidden_states.view(-1, hidden_dim) + router_logits = self.gate(hidden_states_flat) + + torch_dtype = onnx_to_torch_type_map[self.onnx_dtype] + + tensors = { + "input": hidden_states_flat.clone().to(device=device, dtype=torch_dtype), + "router_probs": router_logits.clone().to(device=device, dtype=torch_dtype), + "output": torch.zeros_like(hidden_states_flat, device=device, dtype=torch_dtype), + } + + ort_inputs = { + "input": tensors["input"].detach().cpu().numpy(), + "router_probs": tensors["router_probs"].detach().cpu().numpy(), + } + + if enable_performance_test: + repeat = 1000 + s = time.time() + for _ in range(repeat): + self.ort_sess.run(None, ort_inputs) + e = time.time() + print(f"MoE CPU kernel time: {(e - s) / repeat * 1000} ms") + ort_outputs = self.ort_sess.run(None, ort_inputs) + else: + ort_outputs = self.ort_sess.run(None, ort_inputs) + + output_tensor = torch.from_numpy(ort_outputs[0]).to(device) + + return output_tensor.reshape(batch_size, sequence_length, hidden_dim) + + def parity_check(self): + hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim).to(device) + torch_output = self.forward(hidden_state) + ort_output = self.ort_forward(hidden_state) + + dtype_str = ort_dtype_name_map[self.onnx_dtype] + + ort_dtype_quant_bits_tolerance_map = { + "FP32:0": (5e-3, 1e-3), + } + + atol, rtol = ort_dtype_quant_bits_tolerance_map[f"{dtype_str}:{self.quant_bits}"] + if ort_output is not None: + print( + f"name: {self.__class__.__name__}, quant_bits: {self.quant_bits}, dtype: {dtype_str}," + f" batch: {self.batch_size}, seq_len: {self.sequence_length}," + f" max_diff: {(torch_output.cpu() - ort_output.cpu()).abs().max()}" + ) + torch.testing.assert_close( + ort_output.cpu().to(torch.float32), torch_output.cpu().to(torch.float32), rtol=rtol, atol=atol + ) + + def benchmark_ort(self): + hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim).to(device) + self.ort_forward(hidden_state, enable_performance_test=True) + + +class SwigluMoEBlock(SparseMoeBlockORTHelper): + def __init__( + self, config: SwigluMoeConfig, batch_size: int, sequence_length: int, quant_bits: int = 0, onnx_dtype=None + ): + super().__init__(quant_bits, onnx_dtype=onnx_dtype) + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_token + + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=True) + + self.experts = nn.ModuleList([SwigluMlp(config) for _ in range(self.num_experts)]) + + fc1_w_list, fc2_w_list = [], [] + fc1_b_list, fc2_b_list = [], [] + + for expert in self.experts: + w1_weight = expert.w1.weight.data.clone() + w2_weight = expert.w2.weight.data.clone() + w1_bias = expert.w1.bias.data.clone() + w2_bias = expert.w2.bias.data.clone() + + fc1_w_list.append(w1_weight) + fc2_w_list.append(w2_weight) + fc1_b_list.append(w1_bias) + fc2_b_list.append(w2_bias) + + fc1_experts_weights = torch.stack(fc1_w_list, dim=0) + fc2_experts_weights = torch.stack(fc2_w_list, dim=0) + fc1_experts_bias = torch.stack(fc1_b_list, dim=0) + fc2_experts_bias = torch.stack(fc2_b_list, dim=0) + + self.batch_size = 
batch_size + self.sequence_length = sequence_length + + self.moe_onnx_graph = create_swiglu_moe_onnx_graph( + num_tokens=self.batch_size * self.sequence_length, + num_experts=self.num_experts, + hidden_size=self.hidden_dim, + inter_size=self.ffn_dim, + topk=self.top_k, + onnx_dtype=self.onnx_dtype, + quant_bits=self.quant_bits, + fc1_experts_weights=fc1_experts_weights, + fc1_experts_bias=fc1_experts_bias, + fc2_experts_weights=fc2_experts_weights, + fc2_experts_bias=fc2_experts_bias, + ) + + self.ort_sess = self.create_ort_session(self.moe_onnx_graph) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + router_logits = self.gate(hidden_states) + + # Compute full softmax over all experts (same as CUDA) + full_probs = F.softmax(router_logits, dim=-1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(full_probs, self.top_k, dim=-1) + + # For normalize_routing_weights=1: normalize by sum of top-k values (same as CUDA) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.shape[0] == 0: + continue + + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states + + +swiglu_test_cases = list( + itertools.product( + [1, 2], # batch_size + [16, 32], # sequence_length + [0], # quant_bits (CPU kernel only supports float32) + ) +) + +perf_test_cases = list( + itertools.product( + [1], # batch_size + [128], # sequence_length + [0], # quant_bits (CPU kernel only supports float32) + ) +) + + +class TestSwigluMoECPU(unittest.TestCase): + @parameterized.expand(swiglu_test_cases) + def test_swiglu_moe_parity(self, batch_size, sequence_length, quant_bits): + config = SwigluMoeConfig(hidden_size=64, intermediate_size=256, num_experts_per_token=2, num_local_experts=4) + moe = SwigluMoEBlock(config, batch_size, sequence_length, quant_bits) + moe.to(device) + moe.parity_check() + + +class TestSwigluMoECPUPerf(unittest.TestCase): + @parameterized.expand(perf_test_cases) + def test_swiglu_moe_perf(self, batch_size, sequence_length, quant_bits): + hidden_size = 1024 + intermediate_size = 2048 + num_experts_per_token = 4 + num_local_experts = 16 + config = SwigluMoeConfig( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_experts_per_token=num_experts_per_token, + num_local_experts=num_local_experts, + ) + moe = SwigluMoEBlock(config, batch_size, sequence_length, quant_bits) + moe.to(device) + moe.benchmark_ort() + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_qmoe_cpu.py b/onnxruntime/test/python/transformers/test_qmoe_cpu.py index efaaca29a01b6..403becbe0616a 100644 --- 
a/onnxruntime/test/python/transformers/test_qmoe_cpu.py +++ b/onnxruntime/test/python/transformers/test_qmoe_cpu.py @@ -254,7 +254,8 @@ def create_cpu_moe_onnx_graph( inter_size = intermediate_size topk = top_k - use_quant = True + # Only override use_quant for backward compatibility if not explicitly set + # use_quant = True # This line was causing issues for regular MoE tests if fc1_scales is None and use_quant: return None @@ -263,30 +264,52 @@ def create_cpu_moe_onnx_graph( if not has_onnx: return None - assert fc1_experts_weights.dtype == torch.uint8, "FC1 weights must be uint8 for QMoE" - assert fc2_experts_weights.dtype == torch.uint8, "FC2 weights must be uint8 for QMoE" - assert fc1_scales is not None, "FC1 scales must be provided for QMoE" - assert fc2_scales is not None, "FC2 scales must be provided for QMoE" - assert fc1_scales.dtype == torch.float16, "FC1 scales must be float16 for QMoE" - assert fc2_scales.dtype == torch.float16, "FC2 scales must be float16 for QMoE" + if use_quant: + # Assertions only apply to quantized MoE + assert fc1_experts_weights.dtype == torch.uint8, "FC1 weights must be uint8 for QMoE" + assert fc2_experts_weights.dtype == torch.uint8, "FC2 weights must be uint8 for QMoE" + assert fc1_scales is not None, "FC1 scales must be provided for QMoE" + assert fc2_scales is not None, "FC2 scales must be provided for QMoE" + assert fc1_scales.dtype == torch.float16, "FC1 scales must be float16 for QMoE" + assert fc2_scales.dtype == torch.float16, "FC2 scales must be float16 for QMoE" if not has_onnx: return None - op_name = "QMoE" - inputs = [ - "input", - "router_probs", - "fc1_experts_weights", - "fc1_scales", - "", - "fc2_experts_weights", - "fc2_scales", - "", - ] + # Set operator name and inputs based on quantization mode + if use_quant: + op_name = "QMoE" + inputs = [ + "input", + "router_probs", + "fc1_experts_weights", + "fc1_scales", + "", + "fc2_experts_weights", + "fc2_scales", + "", + ] + else: + # For regular (non-quantized) MoE, use different operator and input layout + op_name = "MoE" # Regular MoE operator + inputs = [ + "input", + "router_probs", + "fc1_experts_weights", + "fc1_experts_bias" if fc1_bias is not None else "", # fc1_bias as input 3 + "fc2_experts_weights", + "fc2_experts_bias" if fc2_bias is not None else "", # fc2_bias as input 5 + "", # fc3_experts_weights (not used) + "", # fc3_experts_bias (not used) + ] activation = "swiglu" if use_swiglu else "silu" + # Set normalization behavior based on operator type: + # - QMoE: Raw logits passed, needs normalization in C++ kernel + # - Regular MoE: Pre-computed probabilities passed, no additional normalization needed + normalize_routing = 1 if use_quant else 0 + nodes = [ helper.make_node( op_name, @@ -294,13 +317,14 @@ def create_cpu_moe_onnx_graph( ["output"], "MoE_0", k=topk, - normalize_routing_weights=1, # Use proper routing normalization to match PyTorch behavior + normalize_routing_weights=normalize_routing, activation_type=activation, # Add new attributes with backwards-compatible default values - swiglu_fusion=1 if (use_swiglu and swiglu_interleaved) else 0, # 1 = fused and interleaved + swiglu_fusion=1 if use_swiglu else 0, # 1 if using SwiGLU activation swiglu_limit=7.0, activation_alpha=1.702, activation_beta=1.0, + swiglu_interleaved=1 if swiglu_interleaved else 0, # Enable this attribute domain="com.microsoft", ), ] @@ -339,79 +363,106 @@ def create_cpu_moe_onnx_graph( fc1_scale_size = num_experts * (2 * inter_size if use_swiglu else inter_size) fc2_scale_size = 
num_experts * hidden_size - # Handle scale tensors - fc1_scales and fc2_scales are guaranteed to be not None due to earlier assertions - # Handle different possible scale tensor structures for fc1_scales - if len(fc1_scales.shape) == 4: - # 4D case: [num_experts, inter_size, hidden_size, 1] - extract first scale per expert per output - if use_swiglu: - fc1_scale_tensor = fc1_scales.to(torch_dtype)[:, : 2 * inter_size, 0, 0].flatten().detach().cpu().numpy() + # Handle scale tensors based on quantization mode + if use_quant: + # Handle different possible scale tensor structures for fc1_scales + if len(fc1_scales.shape) == 4: + # 4D case: [num_experts, inter_size, hidden_size, 1] - extract first scale per expert per output + if use_swiglu: + fc1_scale_tensor = ( + fc1_scales.to(torch_dtype)[:, : 2 * inter_size, 0, 0].flatten().detach().cpu().numpy() + ) + else: + fc1_scale_tensor = fc1_scales.to(torch_dtype)[:, :inter_size, 0, 0].flatten().detach().cpu().numpy() + elif len(fc1_scales.shape) == 2: + # 2D case: already flattened, just ensure correct size + fc1_scale_tensor = fc1_scales.to(torch_dtype).flatten().detach().cpu().numpy() + if use_swiglu and fc1_scale_tensor.size == num_experts * inter_size: + # For SwiGLU, duplicate the scales to cover both gate and value components + fc1_scale_tensor = numpy.tile(fc1_scale_tensor.reshape(num_experts, inter_size), (1, 2)).flatten() + elif fc1_scale_tensor.size > fc1_scale_size: + # Truncate to expected size + fc1_scale_tensor = fc1_scale_tensor[:fc1_scale_size] else: - fc1_scale_tensor = fc1_scales.to(torch_dtype)[:, :inter_size, 0, 0].flatten().detach().cpu().numpy() - elif len(fc1_scales.shape) == 2: - # 2D case: already flattened, just ensure correct size - fc1_scale_tensor = fc1_scales.to(torch_dtype).flatten().detach().cpu().numpy() - if use_swiglu and fc1_scale_tensor.size == num_experts * inter_size: - # For SwiGLU, duplicate the scales to cover both gate and value components - fc1_scale_tensor = numpy.tile(fc1_scale_tensor.reshape(num_experts, inter_size), (1, 2)).flatten() - elif fc1_scale_tensor.size > fc1_scale_size: - # Truncate to expected size - fc1_scale_tensor = fc1_scale_tensor[:fc1_scale_size] - else: - # Other cases: flatten and truncate/pad as needed - fc1_scale_tensor = fc1_scales.to(torch_dtype).flatten().detach().cpu().numpy() - if fc1_scale_tensor.size > fc1_scale_size: - fc1_scale_tensor = fc1_scale_tensor[:fc1_scale_size] - elif fc1_scale_tensor.size < fc1_scale_size: - # Pad with ones if too small - pad_size = fc1_scale_size - fc1_scale_tensor.size - fc1_scale_tensor = numpy.concatenate([fc1_scale_tensor, numpy.ones(pad_size, dtype=fc1_scale_tensor.dtype)]) - - # Process scale tensor for proper shape - fc1_scale_data_list = fc1_scale_tensor.tolist() - fc1_scale_data = fc1_scale_data_list - - # Handle different possible scale tensor structures for fc2_scales - if len(fc2_scales.shape) == 4: - # 4D case: [num_experts, hidden_size, inter_size, 1] - extract first scale per expert per output - fc2_scale_tensor = fc2_scales.to(torch_dtype)[:, :hidden_size, 0, 0].flatten().detach().cpu().numpy() - elif len(fc2_scales.shape) == 2: - # 2D case: already flattened, just ensure correct size - fc2_scale_tensor = fc2_scales.to(torch_dtype).flatten().detach().cpu().numpy() - if fc2_scale_tensor.size > fc2_scale_size: - # Truncate to expected size - fc2_scale_tensor = fc2_scale_tensor[:fc2_scale_size] + # Other cases: flatten and truncate/pad as needed + fc1_scale_tensor = fc1_scales.to(torch_dtype).flatten().detach().cpu().numpy() + if 
fc1_scale_tensor.size > fc1_scale_size: + fc1_scale_tensor = fc1_scale_tensor[:fc1_scale_size] + elif fc1_scale_tensor.size < fc1_scale_size: + # Pad with ones if too small + pad_size = fc1_scale_size - fc1_scale_tensor.size + fc1_scale_tensor = numpy.concatenate( + [fc1_scale_tensor, numpy.ones(pad_size, dtype=fc1_scale_tensor.dtype)] + ) + + # Process scale tensor for proper shape + fc1_scale_data_list = fc1_scale_tensor.tolist() + fc1_scale_data = fc1_scale_data_list + + # Handle different possible scale tensor structures for fc2_scales + if len(fc2_scales.shape) == 4: + # 4D case: [num_experts, hidden_size, inter_size, 1] - extract first scale per expert per output + fc2_scale_tensor = fc2_scales.to(torch_dtype)[:, :hidden_size, 0, 0].flatten().detach().cpu().numpy() + elif len(fc2_scales.shape) == 2: + # 2D case: already flattened, just ensure correct size + fc2_scale_tensor = fc2_scales.to(torch_dtype).flatten().detach().cpu().numpy() + if fc2_scale_tensor.size > fc2_scale_size: + # Truncate to expected size + fc2_scale_tensor = fc2_scale_tensor[:fc2_scale_size] + else: + # Other cases: flatten and truncate/pad as needed + fc2_scale_tensor = fc2_scales.to(torch_dtype).flatten().detach().cpu().numpy() + if fc2_scale_tensor.size > fc2_scale_size: + fc2_scale_tensor = fc2_scale_tensor[:fc2_scale_size] + elif fc2_scale_tensor.size < fc2_scale_size: + # Pad with ones if too small + pad_size = fc2_scale_size - fc2_scale_tensor.size + fc2_scale_tensor = numpy.concatenate( + [fc2_scale_tensor, numpy.ones(pad_size, dtype=fc2_scale_tensor.dtype)] + ) + + # Process scale tensor for proper shape + fc2_scale_data_list = fc2_scale_tensor.tolist() + fc2_scale_data = fc2_scale_data_list + + initializers.extend( + [ + helper.make_tensor( + "fc1_scales", + onnx_dtype, + fc1_scale_shape, + fc1_scale_data, + raw=False, + ), + helper.make_tensor( + "fc2_scales", + onnx_dtype, + fc2_scale_shape, + fc2_scale_data, + raw=False, + ), + ] + ) else: - # Other cases: flatten and truncate/pad as needed - fc2_scale_tensor = fc2_scales.to(torch_dtype).flatten().detach().cpu().numpy() - if fc2_scale_tensor.size > fc2_scale_size: - fc2_scale_tensor = fc2_scale_tensor[:fc2_scale_size] - elif fc2_scale_tensor.size < fc2_scale_size: - # Pad with ones if too small - pad_size = fc2_scale_size - fc2_scale_tensor.size - fc2_scale_tensor = numpy.concatenate([fc2_scale_tensor, numpy.ones(pad_size, dtype=fc2_scale_tensor.dtype)]) - - # Process scale tensor for proper shape - fc2_scale_data_list = fc2_scale_tensor.tolist() - fc2_scale_data = fc2_scale_data_list - - initializers.extend( - [ - helper.make_tensor( - "fc1_scales", - onnx_dtype, - fc1_scale_shape, - fc1_scale_data, - raw=False, - ), - helper.make_tensor( - "fc2_scales", - onnx_dtype, - fc2_scale_shape, - fc2_scale_data, - raw=False, - ), - ] - ) + # For non-quantized mode, add bias tensors if provided + if fc1_bias is not None: + initializers.append( + helper.make_tensor( + "fc1_experts_bias", + onnx_dtype, + list(fc1_bias.shape), + fc1_bias.flatten().detach().cpu().numpy().astype(ort_to_numpy_type_map[onnx_dtype]).tolist(), + ) + ) + if fc2_bias is not None: + initializers.append( + helper.make_tensor( + "fc2_experts_bias", + onnx_dtype, + list(fc2_bias.shape), + fc2_bias.flatten().detach().cpu().numpy().astype(ort_to_numpy_type_map[onnx_dtype]).tolist(), + ) + ) graph_inputs = [ helper.make_tensor_value_info("input", onnx_dtype, [sequence_length, hidden_size]), @@ -619,17 +670,54 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def 
ort_forward(self, hidden_states: torch.Tensor, enable_performance_test=False) -> torch.Tensor: if self.ort_sess is None: + print(f"ERROR: ORT session is None for {self.__class__.__name__}") return None batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states_flat = hidden_states.view(-1, hidden_dim) router_logits = self.gate(hidden_states_flat) + # Different routing logic for QMoE vs regular MoE: + # - QMoE expects raw logits (does its own softmax internally) + # - Regular MoE expects pre-computed routing probabilities + if hasattr(self, "quant_bits") and self.quant_bits > 0: + # QMoE: Pass raw logits directly (QMoE does softmax internally) + router_input = router_logits + # print("DEBUG: Using QMoE routing (raw logits)") + else: + # Regular MoE: Apply the same routing logic as PyTorch reference + # This converts raw logits to proper routing probabilities + routing_weights, selected_experts = masked_sampling_omp_inference( + router_logits, + top_k=self.top_k, + jitter_eps=self.router_jitter_noise, + training=False, + ) + + # IMPORTANT: The routing weights from masked_sampling_omp_inference sum to top_k, + # but ONNX Runtime expects normalized probabilities that sum to 1.0 + # Normalize the routing weights per token + routing_weights = routing_weights / routing_weights.sum(dim=1, keepdim=True) + + # Create proper router probabilities tensor that matches PyTorch routing + router_input = torch.zeros_like(router_logits) + for i in range(router_logits.shape[0]): # For each token + for j in range(self.top_k): # For each top-k expert + expert_idx = selected_experts[i, j] + router_input[i, expert_idx] = routing_weights[i, j] + + # print("DEBUG: Using regular MoE routing (processed probabilities)") + + # print(f"DEBUG: router_input stats: mean={router_input.mean():.6f}, std={router_input.std():.6f}") + # print( + # f"DEBUG: hidden_states_flat stats: mean={hidden_states_flat.mean():.6f}, std={hidden_states_flat.std():.6f}" + # ) + torch_dtype = onnx_to_torch_type_map[self.onnx_dtype] tensors = { "input": hidden_states_flat.clone().to(device=device, dtype=torch_dtype), - "router_probs": router_logits.clone().to(device=device, dtype=torch_dtype), + "router_probs": router_input.clone().to(device=device, dtype=torch_dtype), "output": torch.zeros_like(hidden_states_flat, device=device, dtype=torch_dtype), } @@ -656,10 +744,14 @@ def ort_forward(self, hidden_states: torch.Tensor, enable_performance_test=False buffer_ptr=tensor.data_ptr(), ) + # print("DEBUG: About to run ORT inference...") + iobinding.synchronize_inputs() self.ort_sess.run_with_iobinding(iobinding) iobinding.synchronize_outputs() + # print("DEBUG: ORT inference completed successfully") + if enable_performance_test: repeat = 100 s = time.time() @@ -691,28 +783,10 @@ def recreate_onnx_model(self): w2_scale, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight, is_4_bit) if self.use_swiglu: - if self.swiglu_interleaved: - pass - else: - w3_scale, pre_qweight3, w3_qdq = quant_dequant(self.experts[i].w3.weight, is_4_bit) - - gate_weights = pre_qweight1 - value_weights = pre_qweight3 - gate_scales = w1_scale - value_scales = w3_scale - - pre_qweight1 = torch.cat([gate_weights, value_weights], dim=0) - w1_scale = torch.cat([gate_scales, value_scales], dim=0) - - if self.swiglu_interleaved: - self.experts[i].w1.weight = nn.Parameter(w1_qdq.contiguous().clone()) - - else: - intermediate_size = self.experts[i].w1.weight.shape[0] - gate_dequant = w1_qdq[:intermediate_size].contiguous().clone() - value_dequant = 
w1_qdq[intermediate_size:].contiguous().clone() - self.experts[i].w1.weight.data = gate_dequant - self.experts[i].w3.weight.data = value_dequant + # For SwiGLU, CPU kernel now always expects interleaved format + # SwigluMlp weights are already in interleaved format [gate_0, linear_0, gate_1, linear_1, ...] + # No conversion needed - both CPU and CUDA use interleaved format + self.experts[i].w1.weight = nn.Parameter(w1_qdq.contiguous().clone()) else: self.experts[i].w1.weight.data = w1_qdq.contiguous().clone() @@ -754,7 +828,7 @@ def recreate_onnx_model(self): use_swiglu=self.use_swiglu, use_quant=True, # Always use QMoE quant_bits=self.quant_bits, - swiglu_interleaved=self.swiglu_interleaved if hasattr(self, "swiglu_interleaved") else False, + swiglu_interleaved=True, # CPU kernel now always expects interleaved format ) except Exception: self.moe_onnx_graph = None @@ -1043,10 +1117,17 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class TestPhiQMoECPU(unittest.TestCase): @parameterized.expand(phi3_test_cases) def test_phi3_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): - torch.manual_seed(42) - numpy.random.seed(42) + # Create unique seed based on test parameters to ensure different inputs for each test + base_seed = 2000 # Different base seed from other tests + param_hash = hash((batch_size, sequence_length, quant_bits)) + unique_seed = base_seed + abs(param_hash) % 1000 - test_config = f"batch_size={batch_size}, sequence_length={sequence_length}, quant_bits={quant_bits}" + torch.manual_seed(unique_seed) + numpy.random.seed(unique_seed) + + test_config = ( + f"batch_size={batch_size}, sequence_length={sequence_length}, quant_bits={quant_bits}, seed={unique_seed}" + ) print(f"Running Phi3 QMoE test: {test_config}") config = PhiMoEConfig(hidden_size=128, intermediate_size=256, num_local_experts=4, num_experts_per_tok=2) @@ -1086,10 +1167,17 @@ def test_phi3_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): class TestSwigluQMoECPU(unittest.TestCase): @parameterized.expand(swiglu_test_cases) def test_swiglu_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): - torch.manual_seed(42) - numpy.random.seed(42) + # Create unique seed based on test parameters to ensure different inputs for each test + base_seed = 1000 # Different base seed from regular MoE tests + param_hash = hash((batch_size, sequence_length, quant_bits)) + unique_seed = base_seed + abs(param_hash) % 1000 + + torch.manual_seed(unique_seed) + numpy.random.seed(unique_seed) - test_config = f"batch_size={batch_size}, sequence_length={sequence_length}, quant_bits={quant_bits}" + test_config = ( + f"batch_size={batch_size}, sequence_length={sequence_length}, quant_bits={quant_bits}, seed={unique_seed}" + ) print(f"Running SwiGLU test: {test_config}") config = SwigluMoeConfig(hidden_size=128, intermediate_size=256, num_local_experts=4, num_experts_per_token=2) @@ -1114,5 +1202,173 @@ def test_swiglu_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): swiglu_moe.parity_check() +@unittest.skipIf(True, "Skipping QMoE CPU benchmark tests") +class TestQMoESwiGLUBenchmark(unittest.TestCase): + """Benchmark tests for QMoE SwiGLU performance measurement.""" + + def test_qmoe_swiglu_throughput_benchmark(self): + """Comprehensive throughput benchmark for QMoE SwiGLU across different configurations.""" + if disable_cpu_qmoe_tests: + self.skipTest("QMoE CPU tests disabled") + + print("\n=== QMoE SwiGLU Throughput Benchmark ===") + + # Test configurations: (name, hidden_size, 
intermediate_size, num_experts, top_k, quant_bits) + configs = [ + ("Medium-4bit", 2880, 2880, 32, 4, 4), + ("Medium-8bit", 2880, 2880, 32, 4, 8), + ] + + batch_size = 1 + sequence_length = 512 + num_runs = 30 + + results = [] + + for config_name, hidden_size, intermediate_size, num_experts, top_k, quant_bits in configs: + torch.manual_seed(42) + numpy.random.seed(42) + + print(f"\nTesting {config_name}:") + print(f" Hidden: {hidden_size}, Intermediate: {intermediate_size}") + print(f" Experts: {num_experts}, Top-K: {top_k}, Quant: {quant_bits}-bit") + + try: + # Create config and model + config = PhiMoEConfig( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_local_experts=num_experts, + num_experts_per_tok=top_k, + ) + + qmoe_swiglu = PhiMoESparseMoeBlock( + config, + batch_size=batch_size, + sequence_length=sequence_length, + quant_bits=quant_bits, + onnx_dtype=TensorProto.FLOAT, + ) + + # Create test input with fixed sequence length to match ONNX model + full_hidden_states = torch.randn(batch_size, sequence_length, hidden_size).to(torch.float32) + + # For TTFT simulation, we'll measure single forward pass time + # This represents the time to process one token in autoregressive generation + + # Initialize variables + torch_output = None + ort_output = None + + # Warm up with full context + for _ in range(3): + _ = qmoe_swiglu.forward(full_hidden_states) + + # Benchmark PyTorch TTFT (Time to First Token) + # Measure time for a single forward pass (represents token generation time) + torch.manual_seed(42) + + start_time = time.time() + for _ in range(num_runs): + torch_output = qmoe_swiglu.forward(full_hidden_states) + end_time = time.time() + torch_ttft_ms = (end_time - start_time) / num_runs * 1000 + + # Calculate tokens per second (throughput) + # For sequence generation, this represents the rate at which we can generate tokens + torch_tokens_per_sec = 1000.0 / torch_ttft_ms # 1 token / (time_ms / 1000) + + print(f" PyTorch TTFT: {torch_ttft_ms:.3f} ms (per token generation time)") + print(f" PyTorch Throughput: {torch_tokens_per_sec:.1f} tokens/sec") + + # Benchmark ONNX Runtime + ort_ttft_ms = 0 + ort_tokens_per_sec = 0 + speedup = 0 + throughput_ratio = 0 + max_diff = 0 + + model_updated = qmoe_swiglu.recreate_onnx_model() + if model_updated and qmoe_swiglu.ort_sess is not None: + # Warm up ORT with full context + for _ in range(3): + _ = qmoe_swiglu.ort_forward(full_hidden_states) + + torch.manual_seed(42) + + # Measure ONNX Runtime TTFT (Time to First Token) + start_time = time.time() + for _ in range(num_runs): + ort_output = qmoe_swiglu.ort_forward(full_hidden_states) + end_time = time.time() + ort_ttft_ms = (end_time - start_time) / num_runs * 1000 + + # Calculate tokens per second for ONNX Runtime + ort_tokens_per_sec = 1000.0 / ort_ttft_ms # 1 token / (time_ms / 1000) + + speedup = torch_ttft_ms / ort_ttft_ms if ort_ttft_ms > 0 else 0 + throughput_ratio = ort_tokens_per_sec / torch_tokens_per_sec if torch_tokens_per_sec > 0 else 0 + + print(f" ONNX RT TTFT: {ort_ttft_ms:.3f} ms (per token generation time)") + print(f" ONNX RT Throughput: {ort_tokens_per_sec:.1f} tokens/sec") + print(f" TTFT Speedup: {speedup:.2f}x") + print(f" Throughput Gain: {throughput_ratio:.2f}x") + else: + print(" ONNX RT: Not available") + + # Calculate max difference if both outputs available + if torch_output is not None and ort_output is not None: + max_diff = (torch_output.cpu() - ort_output.cpu()).abs().max().item() + print(f" Max diff: {max_diff:.6f}") + + results.append( + 
{ + "config": config_name, + "torch_ttft_ms": torch_ttft_ms, + "torch_tokens_per_sec": torch_tokens_per_sec, + "ort_ttft_ms": ort_ttft_ms, + "ort_tokens_per_sec": ort_tokens_per_sec, + "speedup": speedup, + "throughput_ratio": throughput_ratio, + "max_diff": max_diff, + } + ) + + except Exception as e: + print(f" Error: {e}") + continue + + # Summary + print("\n=== Token Generation Time & Throughput Summary ===") + print( + f"{'Config':<15} {'PT Time':<10} {'PT tok/s':<10} {'ORT Time':<11} {'ORT tok/s':<11} {'Time Gain':<10} {'Throughput':<11} {'Max Diff':<10}" + ) + print("-" * 105) + for result in results: + config = result["config"] + torch_ttft = result["torch_ttft_ms"] + torch_tps = result["torch_tokens_per_sec"] + ort_ttft = result["ort_ttft_ms"] + ort_tps = result["ort_tokens_per_sec"] + speedup = result["speedup"] + throughput_ratio = result["throughput_ratio"] + max_diff = result["max_diff"] + + ort_ttft_str = f"{ort_ttft:.3f}" if ort_ttft > 0 else "N/A" + ort_tps_str = f"{ort_tps:.1f}" if ort_tps > 0 else "N/A" + speedup_str = f"{speedup:.2f}x" if speedup > 0 else "N/A" + throughput_str = f"{throughput_ratio:.2f}x" if throughput_ratio > 0 else "N/A" + + print( + f"{config:<15} {torch_ttft:<10.3f} {torch_tps:<10.1f} {ort_ttft_str:<11} {ort_tps_str:<11} {speedup_str:<10} {throughput_str:<11} {max_diff:<10.6f}" + ) + + print("\nNotes:") + print("- Time: Token generation time in ms (lower is better)") + print("- tok/s: Tokens per second throughput (higher is better)") + print("- Time Gain: ORT speedup for latency (higher is better)") + print("- Throughput: ORT throughput improvement (higher is better)") + + if __name__ == "__main__": unittest.main() From 919b894a340b58c80dd47ba70290cb578318ac0b Mon Sep 17 00:00:00 2001 From: Akshay Sonawane <111780983+apsonawane@users.noreply.github.com> Date: Mon, 15 Sep 2025 08:32:55 -0700 Subject: [PATCH 04/10] [CPU] Block-wise QMoE kernel for CPU (#26009) This PR adds block-wise quant kernel for QMoE CPU --- onnxruntime/contrib_ops/cpu/moe/moe_helper.h | 63 +- .../cpu/moe/moe_quantization_cpu.cc | 936 +++++++++++++++--- .../cuda/collective/sharded_moe.cc | 3 +- onnxruntime/contrib_ops/cuda/moe/moe.cc | 3 +- .../cuda/quantization/moe_quantization.cc | 3 +- .../test/python/transformers/test_qmoe_cpu.py | 487 +++++++-- 6 files changed, 1268 insertions(+), 227 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_helper.h b/onnxruntime/contrib_ops/cpu/moe/moe_helper.h index e494719464d20..39249f842e632 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_helper.h +++ b/onnxruntime/contrib_ops/cpu/moe/moe_helper.h @@ -49,7 +49,8 @@ Status CheckInputs(MoEParameters& parameters, const Tensor* fc3_experts_bias, // optional const Tensor* fc3_experts_scales, // required for qMoE; NULL for MOE const int64_t pack_size, // number of weights packed together (like 2 for uint4 packed to uint8) - const bool is_fused_swiglu) { + const bool is_fused_swiglu, + const int64_t block_size = 0) { // block size for block-wise quantization // Check dimensions of input to avoid input_dims index out of range. CHECK_TENSOR_SHAPE will verify each tensor later. 
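  // Editorial note (illustrative, not part of the original change): with the new block_size
  // parameter the scale tensors may arrive in one of two layouts. For example, with
  // hidden_size = 3072 and block_size = 128:
  //   row-wise   : fc1_experts_scales is {num_experts, fc1_inter_size} or {num_experts, fc1_inter_size, 1}
  //   block-wise : fc1_experts_scales is {num_experts, fc1_inter_size, ceil(3072 / 128) = 24}
  // The checks below take the block-wise branch only when block_size > 0 and the scales are
  // 3-D with a last dimension greater than 1; everything else is validated as row-wise.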
ASSERT_TENSOR_2D_OR_3D(input); ASSERT_TENSOR_3D(fc1_experts_weights); @@ -90,9 +91,63 @@ Status CheckInputs(MoEParameters& parameters, CHECK_TENSOR_SHAPE(fc2_experts_bias, num_experts, hidden_size); CHECK_TENSOR_SHAPE(fc3_experts_bias, num_experts, inter_size); - CHECK_TENSOR_SHAPE(fc1_experts_scales, num_experts, fc1_inter_size); - CHECK_TENSOR_SHAPE(fc2_experts_scales, num_experts, hidden_size); - CHECK_TENSOR_SHAPE(fc3_experts_scales, num_experts, inter_size); + // Validate scale tensors: Handle both row-wise and block-wise quantization flexibly + // First, detect the actual quantization method from the tensor shapes + bool is_row_wise_quantization = true; + if (fc1_experts_scales != nullptr) { + const auto& fc1_scales_dims = fc1_experts_scales->Shape().GetDims(); + if (fc1_scales_dims.size() == 3 && fc1_scales_dims[2] > 1) { + is_row_wise_quantization = false; + } + } + + if (block_size > 0 && !is_row_wise_quantization) { + // Block-wise quantization: 3D scale tensors + // For block-wise quantization, we calculate the number of blocks using ceiling division + // to handle cases where the dimension is not perfectly divisible by block_size + const int64_t fc1_blocks_per_row = (hidden_size + block_size - 1) / block_size; + const int64_t fc2_blocks_per_row = (inter_size + block_size - 1) / block_size; + const int64_t fc3_blocks_per_row = (hidden_size + block_size - 1) / block_size; + + CHECK_TENSOR_SHAPE(fc1_experts_scales, num_experts, fc1_inter_size, fc1_blocks_per_row); + CHECK_TENSOR_SHAPE(fc2_experts_scales, num_experts, hidden_size, fc2_blocks_per_row); + CHECK_TENSOR_SHAPE(fc3_experts_scales, num_experts, inter_size, fc3_blocks_per_row); + } else { + // Row-wise quantization: 2D scale tensors or 3D with last dimension = 1 + // Handle both {num_experts, features} and {num_experts, features, 1} shapes + if (fc1_experts_scales != nullptr) { + const auto& fc1_scales_dims = fc1_experts_scales->Shape().GetDims(); + if (fc1_scales_dims.size() == 2) { + CHECK_TENSOR_SHAPE(fc1_experts_scales, num_experts, fc1_inter_size); + } else if (fc1_scales_dims.size() == 3) { + CHECK_TENSOR_SHAPE(fc1_experts_scales, num_experts, fc1_inter_size, 1); + } else { + ORT_THROW("fc1_experts_scales must be 2D or 3D tensor"); + } + } + + if (fc2_experts_scales != nullptr) { + const auto& fc2_scales_dims = fc2_experts_scales->Shape().GetDims(); + if (fc2_scales_dims.size() == 2) { + CHECK_TENSOR_SHAPE(fc2_experts_scales, num_experts, hidden_size); + } else if (fc2_scales_dims.size() == 3) { + CHECK_TENSOR_SHAPE(fc2_experts_scales, num_experts, hidden_size, 1); + } else { + ORT_THROW("fc2_experts_scales must be 2D or 3D tensor"); + } + } + + if (fc3_experts_scales != nullptr) { + const auto& fc3_scales_dims = fc3_experts_scales->Shape().GetDims(); + if (fc3_scales_dims.size() == 2) { + CHECK_TENSOR_SHAPE(fc3_experts_scales, num_experts, inter_size); + } else if (fc3_scales_dims.size() == 3) { + CHECK_TENSOR_SHAPE(fc3_experts_scales, num_experts, inter_size, 1); + } else { + ORT_THROW("fc3_experts_scales must be 2D or 3D tensor"); + } + } + } if (fc3_experts_weights == nullptr) { ORT_ENFORCE(fc3_experts_bias == nullptr && fc3_experts_scales == nullptr); diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc index 5c6c3b919b572..8195c9438d408 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc @@ -2,12 +2,16 @@ // Licensed under the MIT License. 
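// Editorial overview (not part of the original patch): the block-wise QMoE CPU path below has
// two execution strategies. For 4-bit weights with block_size 0, 64, or 128, expert weights can
// be repacked into MLAS Q4 format and run through MlasQ4GemmBatch directly (see CanUseMlasQ4Gemm
// and DirectQ4Gemm). Otherwise the weights are dequantized to float per expert (row-wise or
// block-wise, 4-bit or 8-bit) and multiplied with a regular MlasGemm call.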
#include "contrib_ops/cpu/moe/moe_quantization_cpu.h" - #include "core/framework/allocator.h" #include "core/framework/float16.h" #include "core/mlas/inc/mlas.h" +#include "core/mlas/inc/mlas_q4.h" #include "core/platform/threadpool.h" #include "core/providers/cpu/math/gemm_helper.h" +#include "core/providers/cpu/activation/activations.h" +#include "core/common/safeint.h" +#include "core/framework/tensor_type_and_shape.h" +#include "core/util/math.h" #include "contrib_ops/cpu/moe/moe_utils.h" #include "contrib_ops/cpu/moe/moe_helper.h" @@ -17,44 +21,325 @@ #include #include +namespace { +inline int64_t GetOptimalBlockSize(int64_t total_elements, int num_threads) { + if (total_elements <= 0 || num_threads <= 0) return 64; + const int64_t l1_cache_elements = 8192; // ~32KB / 4 bytes per float + const int64_t divisor = std::max(1, num_threads > 1 ? 4 : 2); + const int64_t base_block_size = l1_cache_elements / divisor; + const int64_t max_block = std::max(int64_t{32}, total_elements / std::max(int64_t{1}, int64_t{4})); + return std::clamp(base_block_size, int64_t{32}, std::min(int64_t{512}, max_block)); +} + +inline int64_t GetUnrollFactor(int64_t vector_size) { + if (vector_size <= 0) return 2; + if (vector_size >= 512) return 16; + if (vector_size >= 128) return 8; + if (vector_size >= 32) return 4; + return 2; +} + +inline bool ShouldUseMemcpy(int64_t size) { + return size >= 64; +} + +inline int64_t GetDequantBlockSize(int64_t features, int64_t total_work) { + if (features <= 0 || total_work <= 0) return 16; + const int64_t target_block_size = std::max(int64_t{16}, features / std::max(int64_t{1}, int64_t{8})); + const int64_t work_based_size = std::max(int64_t{16}, total_work / std::max(int64_t{1}, int64_t{4})); + return std::min(target_block_size, work_based_size); +} + +bool CanUseMlasQ4Dequant(int64_t num_bits, int64_t block_size) { + if (num_bits != 4) { + return false; + } + + return true; +} + +bool CanUseMlasQ4Gemm(int64_t expert_weight_bits, int64_t block_size, + int64_t rows, int64_t cols, MLAS_BLK_QUANT_TYPE& out_qtype) { + if (expert_weight_bits != 4) { + return false; + } + + if (block_size == 64) { + out_qtype = BlkQ4Sym64; + } else if (block_size == 128) { + out_qtype = BlkQ4Sym128; + } else if (block_size == 0) { + out_qtype = BlkQ4Sym; + } else { + return false; + } + + size_t expected_size = MlasQ4GemmPackBSize(out_qtype, static_cast(cols), static_cast(rows)); + return expected_size > 0; +} + +} // namespace + namespace onnxruntime { namespace contrib { -// Helper function to dequantize weights. Supports 4-bit and 8-bit symmetric quantization. -// The source quantized weights are stored as a row-major representation of the transposed -// logical weight matrix (W^T). This function dequantizes it into a float row-major W^T matrix. 
template -void DequantizeBlock(const uint8_t* quantized_data, - const TScale* scales, - int64_t /*block_size*/, - int64_t num_bits, - int64_t rows, - int64_t cols, - float* dequantized_data) { +void DequantizeBlockWithMlas(const uint8_t* quantized_data, + const TScale* scales, + int64_t block_size, + int64_t num_bits, + int64_t rows, + int64_t cols, + float* dequantized_data, + MLAS_THREADPOOL* thread_pool); + +template +Status ConvertToMlasQ4Format(const uint8_t* quantized_data, + const TScale* scales, + int64_t block_size, + int64_t num_bits, + int64_t rows, + int64_t cols, + MLAS_BLK_QUANT_TYPE qtype, + AllocatorPtr allocator, + IAllocatorUniquePtr& mlas_packed_buffer) { + if (num_bits != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Only 4-bit quantization supported for MLAS Q4 format conversion"); + } + + auto temp_float_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(rows * cols)); + float* temp_float = temp_float_buffer.get(); + + DequantizeBlockWithMlas(quantized_data, scales, block_size, num_bits, rows, cols, temp_float, nullptr); + + size_t packed_size = MlasQ4GemmPackBSize(qtype, static_cast(cols), static_cast(rows)); + if (packed_size == 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "MLAS Q4 packing not supported for this configuration"); + } + + mlas_packed_buffer = IAllocator::MakeUniquePtr(allocator, packed_size); + MlasQ4GemmPackB(qtype, mlas_packed_buffer.get(), temp_float, static_cast(cols), static_cast(rows), static_cast(cols)); + + return Status::OK(); +} + +Status DirectQ4Gemm(const float* A, + const uint8_t* mlas_packed_B, + const float* bias, + float* C, + int64_t M, + int64_t N, + int64_t K, + MLAS_BLK_QUANT_TYPE qtype, + MLAS_THREADPOOL* thread_pool) { + MLAS_Q4_GEMM_DATA_PARAMS params; + params.A = A; + params.lda = static_cast(K); + params.B = mlas_packed_B; + params.Bias = bias; + params.C = C; + params.ldc = static_cast(N); + params.OutputProcessor = nullptr; + + MlasQ4GemmBatch(qtype, static_cast(M), static_cast(N), static_cast(K), 1, ¶ms, thread_pool); + return Status::OK(); +} + +template +void DequantizeBlockWithMlas(const uint8_t* quantized_data, + const TScale* scales, + int64_t block_size, + int64_t num_bits, + int64_t rows, + int64_t cols, + float* dequantized_data, + MLAS_THREADPOOL* thread_pool) { const float zero_point = num_bits == 8 ? 128.0f : 8.0f; - if (num_bits == 8) { - for (int64_t r = 0; r < rows; ++r) { - const float scale = static_cast(scales[r]); - for (int64_t c = 0; c < cols; ++c) { - // Symmetric quantization: dequantized_value = scale * (quantized_value - zero_point) - dequantized_data[r * cols + c] = scale * (static_cast(quantized_data[r * cols + c]) - zero_point); + const int64_t blocks_per_row = (block_size > 0) ? 
((cols + block_size - 1) / block_size) : 1; + + if (CanUseMlasQ4Dequant(num_bits, block_size)) { + const int64_t packed_cols = (cols + 1) / 2; + + if (block_size == 0) { + for (int64_t r = 0; r < rows; ++r) { + const uint8_t* row_data = quantized_data + r * packed_cols; + float* row_output = dequantized_data + r * cols; + const float scale = static_cast(scales[r]); + + int64_t c = 0; + for (; c + 8 <= cols; c += 8) { + const uint8_t packed_val0 = row_data[(c + 0) / 2]; + const uint8_t packed_val1 = row_data[(c + 2) / 2]; + const uint8_t packed_val2 = row_data[(c + 4) / 2]; + const uint8_t packed_val3 = row_data[(c + 6) / 2]; + + row_output[c + 0] = scale * (static_cast(packed_val0 & 0x0F) - zero_point); + row_output[c + 1] = scale * (static_cast(packed_val0 >> 4) - zero_point); + row_output[c + 2] = scale * (static_cast(packed_val1 & 0x0F) - zero_point); + row_output[c + 3] = scale * (static_cast(packed_val1 >> 4) - zero_point); + row_output[c + 4] = scale * (static_cast(packed_val2 & 0x0F) - zero_point); + row_output[c + 5] = scale * (static_cast(packed_val2 >> 4) - zero_point); + row_output[c + 6] = scale * (static_cast(packed_val3 & 0x0F) - zero_point); + row_output[c + 7] = scale * (static_cast(packed_val3 >> 4) - zero_point); + } + + for (; c < cols; c += 2) { + const uint8_t packed_val = row_data[c / 2]; + const uint8_t val0 = packed_val & 0x0F; + const uint8_t val1 = packed_val >> 4; + + row_output[c] = scale * (static_cast(val0) - zero_point); + if (c + 1 < cols) { + row_output[c + 1] = scale * (static_cast(val1) - zero_point); + } + } } + return; + } else { + for (int64_t r = 0; r < rows; ++r) { + const uint8_t* row_data = quantized_data + r * packed_cols; + float* row_output = dequantized_data + r * cols; + + for (int64_t block_start = 0; block_start < cols; block_start += block_size) { + const int64_t block_end = std::min(block_start + block_size, cols); + const int64_t block_idx = std::min(block_start / block_size, blocks_per_row - 1); + const int64_t scale_idx = r * blocks_per_row + block_idx; + const float scale = static_cast(scales[scale_idx]); + + int64_t c = block_start; + for (; c + 4 <= block_end; c += 4) { + const uint8_t packed_val0 = row_data[(c + 0) / 2]; + const uint8_t packed_val1 = row_data[(c + 2) / 2]; + + row_output[c + 0] = scale * (static_cast(packed_val0 & 0x0F) - zero_point); + row_output[c + 1] = scale * (static_cast(packed_val0 >> 4) - zero_point); + row_output[c + 2] = scale * (static_cast(packed_val1 & 0x0F) - zero_point); + row_output[c + 3] = scale * (static_cast(packed_val1 >> 4) - zero_point); + } + + for (; c < block_end; c += 2) { + const uint8_t packed_val = row_data[c / 2]; + const uint8_t val0 = packed_val & 0x0F; + const uint8_t val1 = packed_val >> 4; + + row_output[c] = scale * (static_cast(val0) - zero_point); + if (c + 1 < block_end) { + row_output[c + 1] = scale * (static_cast(val1) - zero_point); + } + } + } + } + return; } - } else if (num_bits == 4) { - const int64_t packed_cols = (cols + 1) / 2; + } + + if (num_bits == 8 && block_size == 0) { for (int64_t r = 0; r < rows; ++r) { const float scale = static_cast(scales[r]); - for (int64_t c = 0; c < cols; ++c) { - const uint8_t packed_val = quantized_data[r * packed_cols + c / 2]; - // Unpack the 4-bit value. Low nibble for even columns, high nibble for odd columns. - const uint8_t quantized_val = (c % 2 == 0) ? 
(packed_val & 0x0F) : (packed_val >> 4); - // Symmetric quantization: dequantized_value = scale * (quantized_value - zero_point) - dequantized_data[r * cols + c] = scale * (static_cast(quantized_val) - zero_point); + const uint8_t zero_pt = static_cast(zero_point); + + MlasDequantizeLinear( + quantized_data + r * cols, + dequantized_data + r * cols, + static_cast(cols), + scale, + zero_pt); + } + } else { + if (num_bits == 8) { + for (int64_t r = 0; r < rows; ++r) { + const uint8_t* row_data = quantized_data + r * cols; + float* row_output = dequantized_data + r * cols; + + int64_t c = 0; + if (block_size > 0) { + for (int64_t block_start = 0; block_start < cols; block_start += block_size) { + const int64_t block_end = std::min(block_start + block_size, cols); + const int64_t block_idx = std::min(block_start / block_size, blocks_per_row - 1); + const int64_t scale_idx = r * blocks_per_row + block_idx; + const float scale = static_cast(scales[scale_idx]); + + for (c = block_start; c + 4 <= block_end; c += 4) { + row_output[c] = scale * (static_cast(row_data[c]) - zero_point); + row_output[c + 1] = scale * (static_cast(row_data[c + 1]) - zero_point); + row_output[c + 2] = scale * (static_cast(row_data[c + 2]) - zero_point); + row_output[c + 3] = scale * (static_cast(row_data[c + 3]) - zero_point); + } + for (; c < block_end; ++c) { + row_output[c] = scale * (static_cast(row_data[c]) - zero_point); + } + } + } else { + const float scale = static_cast(scales[r]); + for (; c + 8 <= cols; c += 8) { + row_output[c] = scale * (static_cast(row_data[c]) - zero_point); + row_output[c + 1] = scale * (static_cast(row_data[c + 1]) - zero_point); + row_output[c + 2] = scale * (static_cast(row_data[c + 2]) - zero_point); + row_output[c + 3] = scale * (static_cast(row_data[c + 3]) - zero_point); + row_output[c + 4] = scale * (static_cast(row_data[c + 4]) - zero_point); + row_output[c + 5] = scale * (static_cast(row_data[c + 5]) - zero_point); + row_output[c + 6] = scale * (static_cast(row_data[c + 6]) - zero_point); + row_output[c + 7] = scale * (static_cast(row_data[c + 7]) - zero_point); + } + for (; c < cols; ++c) { + row_output[c] = scale * (static_cast(row_data[c]) - zero_point); + } + } + } + } else if (num_bits == 4) { + const int64_t packed_cols = (cols + 1) / 2; + for (int64_t r = 0; r < rows; ++r) { + const uint8_t* row_data = quantized_data + r * packed_cols; + float* row_output = dequantized_data + r * cols; + + if (block_size > 0) { + for (int64_t block_start = 0; block_start < cols; block_start += block_size) { + const int64_t block_end = std::min(block_start + block_size, cols); + const int64_t block_idx = std::min(block_start / block_size, blocks_per_row - 1); + const int64_t scale_idx = r * blocks_per_row + block_idx; + const float scale = static_cast(scales[scale_idx]); + + for (int64_t c = block_start; c < block_end; c += 2) { + const uint8_t packed_val = row_data[c / 2]; + const uint8_t val0 = packed_val & 0x0F; + const uint8_t val1 = packed_val >> 4; + + row_output[c] = scale * (static_cast(val0) - zero_point); + if (c + 1 < block_end) { + row_output[c + 1] = scale * (static_cast(val1) - zero_point); + } + } + } + } else { + const float scale = static_cast(scales[r]); + for (int64_t c = 0; c < cols; c += 2) { + const uint8_t packed_val = row_data[c / 2]; + const uint8_t val0 = packed_val & 0x0F; + const uint8_t val1 = packed_val >> 4; + + row_output[c] = scale * (static_cast(val0) - zero_point); + if (c + 1 < cols) { + row_output[c + 1] = scale * (static_cast(val1) - zero_point); + } + 
} + } } } } } +template +void DequantizeBlock(const uint8_t* quantized_data, + const TScale* scales, + int64_t block_size, + int64_t num_bits, + int64_t rows, + int64_t cols, + float* dequantized_data, + MLAS_THREADPOOL* thread_pool = nullptr) { + DequantizeBlockWithMlas(quantized_data, scales, block_size, num_bits, rows, cols, dequantized_data, thread_pool); +} + template QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info), @@ -63,11 +348,15 @@ QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info) ORT_ENFORCE(expert_weight_bits_ == 4 || expert_weight_bits_ == 8, "Attribute 'expert_weight_bits' must be 4 or 8."); block_size_ = op_kernel_info.GetAttrOrDefault("block_size", 0); + + if (block_size_ > 0) { + ORT_ENFORCE(block_size_ >= 16, "block_size must be >= 16 when provided."); + ORT_ENFORCE((block_size_ & (block_size_ - 1)) == 0, "block_size must be a power of 2."); + } } template Status QMoECPU::Compute(OpKernelContext* context) const { - // --- 1. Get Inputs and Attributes --- const auto* input = context->Input(0); const auto* router_probs = context->Input(1); const auto* fc1_experts_weights = context->Input(2); @@ -87,7 +376,8 @@ Status QMoECPU::Compute(OpKernelContext* context) const { fc2_experts_weights, fc2_experts_bias, fc2_scales, fc3_experts_weights, fc3_experts_bias, fc3_scales, expert_weight_bits_ == 4 ? 2 : 1, - true)); + true, + block_size_)); if (fc3_experts_weights || fc3_experts_bias || fc3_scales) { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "FC3 gating is not yet implemented on CPU for QMoE"); @@ -109,19 +399,17 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const size_t output_buffer_size = static_cast(output->Shape().Size()); const T* input_data = input->Data(); - const T* router_probs_data = router_probs->Data(); - // --- 2. Routing Logic: Assign tokens to experts --- IAllocatorUniquePtr router_logits_float_buffer; const float* router_logits_float; if constexpr (std::is_same_v) { router_logits_float_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * num_experts)); router_logits_float = router_logits_float_buffer.get(); - MlasConvertHalfToFloatBuffer(reinterpret_cast(router_probs_data), + MlasConvertHalfToFloatBuffer(reinterpret_cast(router_probs->Data()), const_cast(router_logits_float), static_cast(num_tokens * num_experts)); } else { - router_logits_float = reinterpret_cast(router_probs_data); + router_logits_float = reinterpret_cast(router_probs->Data()); } auto route_expert_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * k_)); @@ -129,36 +417,37 @@ Status QMoECPU::Compute(OpKernelContext* context) const { auto route_scale_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_tokens * k_)); float* route_scale = route_scale_ptr.get(); - // Parallelize the routing logic to improve performance for large token batches. - // Minor performance regression for single-token decoding is an acceptable trade-off - int num_routing_threads = (tp == nullptr || num_tokens < 4096) ? 1 : std::min(static_cast(num_tokens), concurrency::ThreadPool::DegreeOfParallelism(tp)); + const int max_threads = tp ? concurrency::ThreadPool::DegreeOfParallelism(tp) : 1; + const int64_t thread_divisor = std::max(1, max_threads * 4); + const int64_t min_work_per_thread = std::max(int64_t{32}, static_cast(num_tokens / thread_divisor)); + const int optimal_routing_threads = (tp == nullptr || num_tokens < min_work_per_thread) ? 
1 : std::min(static_cast(num_tokens / std::max(int64_t{1}, min_work_per_thread)), max_threads); + const int num_routing_threads = std::max(1, optimal_routing_threads); std::vector>> thread_local_expert_token_maps(num_routing_threads); for (auto& map : thread_local_expert_token_maps) { map.resize(static_cast(num_experts)); + for (auto& expert_tokens : map) { + expert_tokens.reserve(32); + } } concurrency::ThreadPool::TrySimpleParallelFor(tp, num_routing_threads, [&](std::ptrdiff_t thread_id) { auto work = concurrency::ThreadPool::PartitionWork(static_cast(thread_id), num_routing_threads, static_cast(num_tokens)); auto& local_expert_token_map = thread_local_expert_token_maps[thread_id]; - // Pre-allocate buffers for this thread to reuse, avoiding allocations inside the loop. std::vector> sorted_logits(static_cast(num_experts)); std::vector top_k_exp(static_cast(k_)); for (int64_t i = work.start; i < work.end; ++i) { const float* logits = router_logits_float + i * num_experts; + for (int64_t j = 0; j < num_experts; ++j) { sorted_logits[static_cast(j)] = {logits[j], j}; } - std::partial_sort(sorted_logits.begin(), sorted_logits.begin() + static_cast(k_), sorted_logits.end(), std::greater<>()); + std::partial_sort(sorted_logits.begin(), sorted_logits.begin() + static_cast(k_), + sorted_logits.end(), std::greater<>()); - float max_logit = -std::numeric_limits::infinity(); - for (int64_t j = 0; j < k_; ++j) { - if (sorted_logits[static_cast(j)].first > max_logit) { - max_logit = sorted_logits[static_cast(j)].first; - } - } + float max_logit = sorted_logits[0].first; float sum_exp = 0.0f; for (int64_t j = 0; j < k_; ++j) { @@ -166,20 +455,19 @@ Status QMoECPU::Compute(OpKernelContext* context) const { sum_exp += top_k_exp[static_cast(j)]; } - float scale = (sum_exp == 0.0f) ? 0.0f : (1.0f / sum_exp); + const float inv_sum = (sum_exp == 0.0f) ? 0.0f : (1.0f / sum_exp); for (int64_t j = 0; j < k_; ++j) { int64_t expert_idx = sorted_logits[static_cast(j)].second; int64_t route_idx = i * k_ + j; route_expert[route_idx] = static_cast(expert_idx); - route_scale[route_idx] = top_k_exp[static_cast(j)] * scale; - if (route_scale[route_idx] > 0.0f) { + route_scale[route_idx] = top_k_exp[static_cast(j)] * inv_sum; + if (route_scale[route_idx] > 1e-8f) { // Use small threshold to avoid zero weights local_expert_token_map[static_cast(expert_idx)].push_back(route_idx); } } } }); - // Merge the maps from each thread into a single global map. std::vector> expert_token_map(static_cast(num_experts)); for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { size_t total_tokens_for_expert = 0; @@ -187,18 +475,17 @@ Status QMoECPU::Compute(OpKernelContext* context) const { total_tokens_for_expert += thread_local_expert_token_maps[t][static_cast(expert_idx)].size(); } expert_token_map[static_cast(expert_idx)].reserve(total_tokens_for_expert); - } - for (int t = 0; t < num_routing_threads; ++t) { - for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + for (int t = 0; t < num_routing_threads; ++t) { auto& local_tokens = thread_local_expert_token_maps[t][static_cast(expert_idx)]; if (!local_tokens.empty()) { - expert_token_map[static_cast(expert_idx)].insert(expert_token_map[static_cast(expert_idx)].end(), local_tokens.begin(), local_tokens.end()); + expert_token_map[static_cast(expert_idx)].insert( + expert_token_map[static_cast(expert_idx)].end(), + local_tokens.begin(), local_tokens.end()); } } } - // --- 3. 
Parallel Expert Computation --- IAllocatorUniquePtr input_float_buffer; const float* input_float; if constexpr (std::is_same_v) { @@ -211,118 +498,434 @@ Status QMoECPU::Compute(OpKernelContext* context) const { input_float = reinterpret_cast(input_data); } - int num_expert_threads = (tp == nullptr) ? 1 : std::min(static_cast(num_experts), concurrency::ThreadPool::DegreeOfParallelism(tp)); + const int max_expert_threads = tp ? concurrency::ThreadPool::DegreeOfParallelism(tp) : 1; + const int64_t total_expert_work = std::accumulate(expert_token_map.begin(), expert_token_map.end(), 0LL, + [](int64_t sum, const std::vector& tokens) { return sum + static_cast(tokens.size()); }); + const int64_t expert_thread_divisor = std::max(1, max_expert_threads * 8); + const int64_t min_expert_work_per_thread = std::max(int64_t{16}, total_expert_work / expert_thread_divisor); + + int num_expert_threads = (tp == nullptr || total_expert_work < min_expert_work_per_thread) ? 1 : std::min(static_cast(total_expert_work / std::max(int64_t{1}, min_expert_work_per_thread)), std::min(static_cast(num_experts), max_expert_threads)); if (num_expert_threads == 0) num_expert_threads = 1; + auto thread_local_outputs_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_expert_threads) * output_buffer_size); float* thread_local_outputs = thread_local_outputs_ptr.get(); - memset(thread_local_outputs, 0, static_cast(num_expert_threads) * output_buffer_size * sizeof(float)); + std::memset(thread_local_outputs, 0, static_cast(num_expert_threads) * output_buffer_size * sizeof(float)); - // Pre-calculate workspace size per thread to avoid allocations inside the loop size_t max_tokens_per_expert = 0; for (const auto& tokens : expert_token_map) { - if (tokens.size() > max_tokens_per_expert) { - max_tokens_per_expert = tokens.size(); - } + max_tokens_per_expert = std::max(max_tokens_per_expert, tokens.size()); } - const size_t A1_size = static_cast(max_tokens_per_expert * hidden_size); - const size_t C1_size = static_cast(max_tokens_per_expert * fc1_out_features); - const size_t A2_size = static_cast(max_tokens_per_expert * inter_size); - const size_t C2_size = static_cast(max_tokens_per_expert * hidden_size); - const size_t B1_dequant_size = static_cast(fc1_out_features * hidden_size); - const size_t B2_dequant_size = static_cast(hidden_size * inter_size); - const size_t bias1_size = static_cast(fc1_out_features); - const size_t bias2_size = static_cast(hidden_size); + const auto align_size = [](size_t size) -> size_t { + return (size + 63) & ~63; + }; + + const size_t A1_size = align_size(static_cast(max_tokens_per_expert) * static_cast(hidden_size)); + const size_t C1_size = align_size(static_cast(max_tokens_per_expert) * static_cast(fc1_out_features)); + const size_t A2_size = align_size(static_cast(max_tokens_per_expert) * static_cast(inter_size)); + const size_t C2_size = align_size(static_cast(max_tokens_per_expert) * static_cast(hidden_size)); + const size_t B1_dequant_size = align_size(static_cast(fc1_out_features) * static_cast(hidden_size)); + const size_t B2_dequant_size = align_size(static_cast(hidden_size) * static_cast(inter_size)); + + const size_t workspace_elements_per_thread = A1_size + C1_size + A2_size + C2_size + + B1_dequant_size + B2_dequant_size; - const size_t workspace_elements_per_thread = A1_size + C1_size + A2_size + C2_size + B1_dequant_size + B2_dequant_size + bias1_size + bias2_size; auto workspace_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_expert_threads) * 
workspace_elements_per_thread); float* workspace = workspace_ptr.get(); + auto bias_conversion_buffers_ptr = IAllocator::MakeUniquePtr(allocator, + static_cast(num_expert_threads) * (static_cast(fc1_out_features) + static_cast(hidden_size))); + float* bias_conversion_buffers = bias_conversion_buffers_ptr.get(); + + const auto& fc1_scales_dims = fc1_scales->Shape().GetDims(); + const auto& fc2_scales_dims = fc2_scales->Shape().GetDims(); + const bool is_fc1_block_wise = (fc1_scales_dims.size() == 3 && fc1_scales_dims[2] > 1); + const bool is_fc2_block_wise = (fc2_scales_dims.size() == 3 && fc2_scales_dims[2] > 1); + + const uint8_t* fc1_weights_data = fc1_experts_weights->Data(); + const uint8_t* fc2_weights_data = fc2_experts_weights->Data(); + const T* fc1_scales_data = fc1_scales->Data(); + const T* fc2_scales_data = fc2_scales->Data(); + const T* fc1_bias_data = fc1_experts_bias ? fc1_experts_bias->Data() : nullptr; + const T* fc2_bias_data = fc2_experts_bias ? fc2_experts_bias->Data() : nullptr; + + const int64_t pack_unit = (8 / expert_weight_bits_); + const int64_t fc1_packed_cols = (hidden_size + pack_unit - 1) / pack_unit; + const int64_t fc2_packed_cols = (inter_size + pack_unit - 1) / pack_unit; + const bool has_fc1_bias = (fc1_bias_data != nullptr); + const bool has_fc2_bias = (fc2_bias_data != nullptr); + + std::vector> expert_workload; + size_t total_work = 0; + + for (int64_t i = 0; i < num_experts; ++i) { + const size_t token_count = expert_token_map[static_cast(i)].size(); + if (token_count > 0) { + expert_workload.emplace_back(i, token_count); + total_work += token_count; + } + } + + if (total_work < 48) { + num_expert_threads = 1; + } else if (total_work < 192) { + num_expert_threads = std::min(num_expert_threads, 2); + } else if (total_work < 512) { + num_expert_threads = std::min(num_expert_threads, 4); + } + + std::sort(expert_workload.begin(), expert_workload.end(), + [](const auto& a, const auto& b) { return a.second > b.second; }); + + std::vector> expert_batches(num_expert_threads); + size_t thread_idx = 0; + for (const auto& work : expert_workload) { + expert_batches[thread_idx].push_back(work.first); + thread_idx = (thread_idx + 1) % static_cast(num_expert_threads); + } + concurrency::ThreadPool::TrySimpleParallelFor(tp, num_expert_threads, [&](std::ptrdiff_t thread_id_pd) { - int thread_id = static_cast(thread_id_pd); - auto work = concurrency::ThreadPool::PartitionWork(thread_id, num_expert_threads, static_cast(num_experts)); + const int thread_id = static_cast(thread_id_pd); + const auto& expert_batch = expert_batches[static_cast(thread_id)]; float* thread_workspace = workspace + static_cast(thread_id) * workspace_elements_per_thread; - for (int64_t expert_idx = work.start; expert_idx < work.end; ++expert_idx) { + float* thread_bias1_buffer = bias_conversion_buffers + static_cast(thread_id) * (static_cast(fc1_out_features) + static_cast(hidden_size)); + float* thread_bias2_buffer = thread_bias1_buffer + static_cast(fc1_out_features); + + for (int64_t expert_idx : expert_batch) { const auto& routes = expert_token_map[static_cast(expert_idx)]; if (routes.empty()) { continue; } - const int64_t num_expert_tokens = routes.size(); + const int64_t num_expert_tokens = static_cast(routes.size()); - // Partition the workspace for the current expert float* A1 = thread_workspace; - float* C1 = A1 + num_expert_tokens * hidden_size; - float* A2 = C1 + num_expert_tokens * fc1_out_features; - float* C2 = A2 + num_expert_tokens * inter_size; - float* B1_dequant = C2 + 
num_expert_tokens * hidden_size; - float* B2_dequant = B1_dequant + fc1_out_features * hidden_size; - float* bias1_float = B2_dequant + hidden_size * inter_size; - float* bias2_float = bias1_float + fc1_out_features; - - // --- Gather input tokens for the current expert --- - for (int64_t i = 0; i < num_expert_tokens; ++i) { - const int64_t token_idx = routes[static_cast(i)] / k_; - memcpy(A1 + i * hidden_size, - input_float + token_idx * hidden_size, - static_cast(hidden_size) * sizeof(float)); + float* C1 = A1 + A1_size; + float* A2 = C1 + C1_size; + float* C2 = A2 + A2_size; + float* B1_dequant = C2 + C2_size; + float* B2_dequant = B1_dequant + B1_dequant_size; + + const int64_t dynamic_block_size = GetOptimalBlockSize(num_expert_tokens, tp ? concurrency::ThreadPool::DegreeOfParallelism(tp) : 1); + const int64_t num_blocks = (num_expert_tokens + dynamic_block_size - 1) / dynamic_block_size; + + if (num_expert_tokens >= 8 && num_blocks > 1 && tp != nullptr) { + concurrency::ThreadPool::TrySimpleParallelFor(tp, static_cast(num_blocks), [&](std::ptrdiff_t block_idx) { + const int64_t start_idx = block_idx * dynamic_block_size; + const int64_t end_idx = std::min(start_idx + dynamic_block_size, num_expert_tokens); + + for (int64_t i = start_idx; i < end_idx; ++i) { + const int64_t token_idx = routes[static_cast(i)] / k_; + const float* src = input_float + token_idx * hidden_size; + float* dst = A1 + i * hidden_size; + + std::memcpy(dst, src, static_cast(hidden_size) * sizeof(float)); + } + }); + } else { + for (int64_t i = 0; i < num_expert_tokens; ++i) { + const int64_t token_idx = routes[static_cast(i)] / k_; + const float* src = input_float + token_idx * hidden_size; + float* dst = A1 + i * hidden_size; + + if (ShouldUseMemcpy(hidden_size)) { + std::memcpy(dst, src, static_cast(hidden_size) * sizeof(float)); + } else { + const int64_t unroll_factor = GetUnrollFactor(hidden_size); + int64_t j = 0; + for (; j + unroll_factor <= hidden_size; j += unroll_factor) { + for (int64_t k = 0; k < unroll_factor; ++k) { + dst[j + k] = src[j + k]; + } + } + for (; j < hidden_size; ++j) { + dst[j] = src[j]; + } + } + } + } + + const T* fc1_scales_ptr; + + if (is_fc1_block_wise) { + const int64_t fc1_blocks_per_row = fc1_scales_dims[2]; + fc1_scales_ptr = fc1_scales_data + expert_idx * fc1_out_features * fc1_blocks_per_row; + } else { + fc1_scales_ptr = fc1_scales_data + expert_idx * fc1_out_features; } - // --- FC1 GEMM (X * W1^T) --- - DequantizeBlock(fc1_experts_weights->Data() + expert_idx * fc1_out_features * (hidden_size / (8 / expert_weight_bits_)), - fc1_scales->Data() + expert_idx * fc1_out_features * (block_size_ > 0 ? hidden_size / block_size_ : 1), - block_size_, expert_weight_bits_, - fc1_out_features, hidden_size, B1_dequant); + const int64_t dequant_block_size = GetDequantBlockSize(fc1_out_features, num_expert_tokens); + const int64_t num_dequant_blocks = (fc1_out_features + dequant_block_size - 1) / dequant_block_size; + + const size_t m = static_cast(num_expert_tokens); + const size_t n = static_cast(fc1_out_features); + const size_t k = static_cast(hidden_size); + + MLAS_BLK_QUANT_TYPE q_type; + bool use_direct_q4_gemm = CanUseMlasQ4Gemm(expert_weight_bits_, is_fc1_block_wise ? 
block_size_ : 0, + fc1_out_features, hidden_size, q_type); + bool fc1_used_direct_q4 = false; + bool fc1_bias_handled_by_q4_gemm = false; + + if (use_direct_q4_gemm) { + IAllocatorUniquePtr mlas_packed_fc1; + Status convert_status = ConvertToMlasQ4Format( + fc1_weights_data + expert_idx * fc1_out_features * fc1_packed_cols, + fc1_scales_ptr, + is_fc1_block_wise ? block_size_ : 0, + expert_weight_bits_, + fc1_out_features, + hidden_size, + q_type, + allocator, + mlas_packed_fc1); + + if (convert_status.IsOK()) { + float* fc1_bias_float = nullptr; + IAllocatorUniquePtr fc1_bias_buffer; + + if (has_fc1_bias) { + const T* B1_bias = fc1_bias_data + expert_idx * fc1_out_features; + fc1_bias_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(fc1_out_features)); + fc1_bias_float = fc1_bias_buffer.get(); + + if constexpr (std::is_same_v) { + MlasConvertHalfToFloatBuffer(reinterpret_cast(B1_bias), fc1_bias_float, static_cast(fc1_out_features)); + } else { + for (int64_t i = 0; i < fc1_out_features; ++i) { + fc1_bias_float[i] = static_cast(B1_bias[i]); + } + } + } + + Status gemm_status = DirectQ4Gemm(A1, mlas_packed_fc1.get(), fc1_bias_float, C1, + num_expert_tokens, fc1_out_features, hidden_size, q_type, tp); + + if (gemm_status.IsOK()) { + fc1_used_direct_q4 = true; +#ifdef ONNXRUNTIME_ENABLE_VERBOSE_LOGGING + LOGS_DEFAULT(VERBOSE) << "QMoE: Using direct MLAS Q4 GEMM for FC1 expert " << expert_idx + << " (M=" << num_expert_tokens << ", N=" << fc1_out_features << ", K=" << hidden_size << ")"; +#endif + goto fc1_gemm_done; + } + } + // If direct Q4 GEMM failed, fall back to traditional approach + } + + // Traditional approach: dequantize + regular GEMM + if (num_dequant_blocks > 1 && fc1_out_features >= 32) { + concurrency::ThreadPool::TrySimpleParallelFor(tp, static_cast(num_dequant_blocks), [&](std::ptrdiff_t block_idx) { + const int64_t start_row = block_idx * dequant_block_size; + const int64_t end_row = std::min(start_row + dequant_block_size, fc1_out_features); + const auto offset = expert_idx * fc1_out_features * fc1_packed_cols + start_row * fc1_packed_cols; + DequantizeBlock(fc1_weights_data + offset, + fc1_scales_ptr + (is_fc1_block_wise ? start_row * fc1_scales_dims[2] : start_row), + is_fc1_block_wise ? block_size_ : 0, expert_weight_bits_, + end_row - start_row, hidden_size, B1_dequant + start_row * hidden_size, tp); + }); + } else { + DequantizeBlock(fc1_weights_data + expert_idx * fc1_out_features * fc1_packed_cols, + fc1_scales_ptr, + is_fc1_block_wise ? block_size_ : 0, expert_weight_bits_, + fc1_out_features, hidden_size, B1_dequant, tp); + } MlasGemm(CblasNoTrans, CblasTrans, - static_cast(num_expert_tokens), static_cast(fc1_out_features), static_cast(hidden_size), - 1.0f, A1, static_cast(hidden_size), - B1_dequant, static_cast(hidden_size), - 0.0f, C1, static_cast(fc1_out_features), - nullptr); - - const T* B1_bias = (fc1_experts_bias) ? 
fc1_experts_bias->Data() + expert_idx * fc1_out_features : nullptr; - if (B1_bias) { + m, n, k, + 1.0f, A1, k, + B1_dequant, k, + 0.0f, C1, n, + tp); + + fc1_bias_handled_by_q4_gemm = fc1_used_direct_q4 && has_fc1_bias; + if (has_fc1_bias && !fc1_bias_handled_by_q4_gemm) { + const T* B1_bias = fc1_bias_data + expert_idx * fc1_out_features; if constexpr (std::is_same_v) { - MlasConvertHalfToFloatBuffer(reinterpret_cast(B1_bias), bias1_float, static_cast(fc1_out_features)); + MlasConvertHalfToFloatBuffer(reinterpret_cast(B1_bias), thread_bias1_buffer, static_cast(fc1_out_features)); } else { - memcpy(bias1_float, B1_bias, static_cast(fc1_out_features) * sizeof(float)); + if (ShouldUseMemcpy(fc1_out_features)) { + std::memcpy(thread_bias1_buffer, B1_bias, static_cast(fc1_out_features) * sizeof(float)); + } else { + const int64_t unroll_factor = GetUnrollFactor(fc1_out_features); + int64_t j = 0; + for (; j + unroll_factor <= fc1_out_features; j += unroll_factor) { + for (int64_t k = 0; k < unroll_factor; ++k) { + thread_bias1_buffer[j + k] = static_cast(B1_bias[j + k]); + } + } + for (; j < fc1_out_features; ++j) { + thread_bias1_buffer[j] = static_cast(B1_bias[j]); + } + } } + for (int64_t i = 0; i < num_expert_tokens; ++i) { - for (int64_t j = 0; j < fc1_out_features; ++j) { - C1[i * fc1_out_features + j] += bias1_float[j]; + float* C1_row = C1 + i * fc1_out_features; + const int64_t unroll_factor = GetUnrollFactor(fc1_out_features); + + int64_t j = 0; + for (; j + unroll_factor <= fc1_out_features; j += unroll_factor) { + for (int64_t k = 0; k < unroll_factor; ++k) { + C1_row[j + k] += thread_bias1_buffer[j + k]; + } + } + for (; j < fc1_out_features; ++j) { + C1_row[j] += thread_bias1_buffer[j]; } } } - // --- Activation --- - for (int64_t i = 0; i < num_expert_tokens; ++i) { - const float* C1_token = C1 + i * fc1_out_features; - float* A2_token = A2 + i * inter_size; - ApplySwiGLUActivation(C1_token, A2_token, inter_size, true, activation_alpha_, activation_beta_, swiglu_limit_); + fc1_gemm_done: + + const int64_t activation_threshold = std::max(int64_t{4}, 256 / std::max(int64_t{1}, inter_size)); + if (num_expert_tokens >= activation_threshold && tp != nullptr) { + const int64_t activation_block_size = std::max(int64_t{1}, std::min(int64_t{64}, activation_threshold)); + const int64_t num_activation_blocks = (num_expert_tokens + activation_block_size - 1) / activation_block_size; + + if (num_activation_blocks > 1) { + concurrency::ThreadPool::TrySimpleParallelFor(tp, static_cast(num_activation_blocks), [&](std::ptrdiff_t block_idx) { + const int64_t start_token = block_idx * activation_block_size; + const int64_t end_token = std::min(start_token + activation_block_size, num_expert_tokens); + + for (int64_t i = start_token; i < end_token; ++i) { + const float* C1_token = C1 + i * fc1_out_features; + float* A2_token = A2 + i * inter_size; + ApplySwiGLUActivation(C1_token, A2_token, inter_size, true, activation_alpha_, activation_beta_, swiglu_limit_); + } + }); + } else { + for (int64_t i = 0; i < num_expert_tokens; ++i) { + const float* C1_token = C1 + i * fc1_out_features; + float* A2_token = A2 + i * inter_size; + ApplySwiGLUActivation(C1_token, A2_token, inter_size, true, activation_alpha_, activation_beta_, swiglu_limit_); + } + } + } else { + for (int64_t i = 0; i < num_expert_tokens; ++i) { + const float* C1_token = C1 + i * fc1_out_features; + float* A2_token = A2 + i * inter_size; + ApplySwiGLUActivation(C1_token, A2_token, inter_size, true, activation_alpha_, activation_beta_, 
swiglu_limit_); + } + } + + const T* fc2_scales_ptr; + + if (is_fc2_block_wise) { + const int64_t fc2_blocks_per_row = fc2_scales_dims[2]; + fc2_scales_ptr = fc2_scales_data + expert_idx * hidden_size * fc2_blocks_per_row; + } else { + fc2_scales_ptr = fc2_scales_data + expert_idx * hidden_size; } - // --- FC2 GEMM (A2 * W2^T) --- - DequantizeBlock(fc2_experts_weights->Data() + expert_idx * hidden_size * (inter_size / (8 / expert_weight_bits_)), - fc2_scales->Data() + expert_idx * hidden_size * (block_size_ > 0 ? inter_size / block_size_ : 1), - block_size_, expert_weight_bits_, - hidden_size, inter_size, B2_dequant); + const int64_t fc2_dequant_block_size = GetDequantBlockSize(hidden_size, num_expert_tokens); + const int64_t num_fc2_dequant_blocks = (hidden_size + fc2_dequant_block_size - 1) / fc2_dequant_block_size; + + const size_t m2 = static_cast(num_expert_tokens); + const size_t n2 = static_cast(hidden_size); + const size_t k2 = static_cast(inter_size); + + MLAS_BLK_QUANT_TYPE q_type2; + bool use_direct_q4_gemm_fc2 = CanUseMlasQ4Gemm(expert_weight_bits_, is_fc2_block_wise ? block_size_ : 0, + hidden_size, inter_size, q_type2); + bool fc2_used_direct_q4 = false; + + if (use_direct_q4_gemm_fc2) { + IAllocatorUniquePtr mlas_packed_fc2; + Status convert_status = ConvertToMlasQ4Format( + fc2_weights_data + expert_idx * hidden_size * fc2_packed_cols, + fc2_scales_ptr, + is_fc2_block_wise ? block_size_ : 0, + expert_weight_bits_, + hidden_size, + inter_size, + q_type2, + allocator, + mlas_packed_fc2); + + if (convert_status.IsOK()) { + float* fc2_bias_float = nullptr; + IAllocatorUniquePtr fc2_bias_buffer; + + if (has_fc2_bias) { + const T* B2_bias = fc2_bias_data + expert_idx * hidden_size; + fc2_bias_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(hidden_size)); + fc2_bias_float = fc2_bias_buffer.get(); + + if constexpr (std::is_same_v) { + MlasConvertHalfToFloatBuffer(reinterpret_cast(B2_bias), fc2_bias_float, static_cast(hidden_size)); + } else { + for (int64_t i = 0; i < hidden_size; ++i) { + fc2_bias_float[i] = static_cast(B2_bias[i]); + } + } + } + + Status gemm_status = DirectQ4Gemm(A2, mlas_packed_fc2.get(), fc2_bias_float, C2, + num_expert_tokens, hidden_size, inter_size, q_type2, tp); + + if (gemm_status.IsOK()) { + fc2_used_direct_q4 = true; +#ifdef ONNXRUNTIME_ENABLE_VERBOSE_LOGGING + LOGS_DEFAULT(VERBOSE) << "QMoE: Using direct MLAS Q4 GEMM for FC2 expert " << expert_idx + << " (M=" << num_expert_tokens << ", N=" << hidden_size << ", K=" << inter_size << ")"; +#endif + goto fc2_gemm_done; + } + } + + // If direct Q4 GEMM failed, fall back to traditional approach + } + + // Traditional approach: dequantize + regular GEMM + if (num_fc2_dequant_blocks > 1 && hidden_size >= 32) { + concurrency::ThreadPool::TrySimpleParallelFor(tp, static_cast(num_fc2_dequant_blocks), [&](std::ptrdiff_t block_idx) { + const int64_t start_row = block_idx * fc2_dequant_block_size; + const int64_t end_row = std::min(start_row + fc2_dequant_block_size, hidden_size); + const auto offset = expert_idx * hidden_size * fc2_packed_cols + start_row * fc2_packed_cols; + DequantizeBlock(fc2_weights_data + offset, + fc2_scales_ptr + (is_fc2_block_wise ? start_row * fc2_scales_dims[2] : start_row), + is_fc2_block_wise ? block_size_ : 0, expert_weight_bits_, + end_row - start_row, inter_size, B2_dequant + start_row * inter_size, tp); + }); + } else { + DequantizeBlock(fc2_weights_data + expert_idx * hidden_size * fc2_packed_cols, + fc2_scales_ptr, + is_fc2_block_wise ? 
block_size_ : 0, expert_weight_bits_, + hidden_size, inter_size, B2_dequant, tp); + } MlasGemm(CblasNoTrans, CblasTrans, - static_cast(num_expert_tokens), static_cast(hidden_size), static_cast(inter_size), - 1.0f, A2, static_cast(inter_size), - B2_dequant, static_cast(inter_size), - 0.0f, C2, static_cast(hidden_size), - nullptr); - - const T* B2_bias = (fc2_experts_bias) ? fc2_experts_bias->Data() + expert_idx * hidden_size : nullptr; - if (B2_bias) { + m2, n2, k2, + 1.0f, A2, k2, + B2_dequant, k2, + 0.0f, C2, n2, + tp); + + fc2_gemm_done: + + bool fc2_bias_handled_by_q4_gemm = fc2_used_direct_q4 && has_fc2_bias; + if (has_fc2_bias && !fc2_bias_handled_by_q4_gemm) { + const T* B2_bias = fc2_bias_data + expert_idx * hidden_size; if constexpr (std::is_same_v) { - MlasConvertHalfToFloatBuffer(reinterpret_cast(B2_bias), bias2_float, static_cast(hidden_size)); + MlasConvertHalfToFloatBuffer(reinterpret_cast(B2_bias), thread_bias2_buffer, static_cast(hidden_size)); } else { - memcpy(bias2_float, B2_bias, static_cast(hidden_size) * sizeof(float)); + if (ShouldUseMemcpy(hidden_size)) { + std::memcpy(thread_bias2_buffer, B2_bias, static_cast(hidden_size) * sizeof(float)); + } else { + const int64_t unroll_factor = GetUnrollFactor(hidden_size); + int64_t j = 0; + for (; j + unroll_factor <= hidden_size; j += unroll_factor) { + for (int64_t k = 0; k < unroll_factor; ++k) { + thread_bias2_buffer[j + k] = static_cast(B2_bias[j + k]); + } + } + for (; j < hidden_size; ++j) { + thread_bias2_buffer[j] = static_cast(B2_bias[j]); + } + } } } @@ -331,28 +934,89 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const int64_t token_idx = route_idx / k_; const float weight = route_scale[route_idx]; + if (token_idx < 0 || token_idx >= num_tokens) continue; + const size_t buffer_offset = static_cast(token_idx) * static_cast(hidden_size); - if (buffer_offset + static_cast(hidden_size) > output_buffer_size) { - // Skip this token to prevent buffer overflow - continue; - } + if (buffer_offset + static_cast(hidden_size) > output_buffer_size) continue; float* dest = thread_local_outputs + static_cast(thread_id) * output_buffer_size + buffer_offset; const float* src = C2 + i * hidden_size; - for (int64_t j = 0; j < hidden_size; ++j) { - dest[j] += weight * (src[j] + (B2_bias ? bias2_float[j] : 0.0f)); + + if (has_fc2_bias && !fc2_bias_handled_by_q4_gemm) { + const int64_t unroll_factor = GetUnrollFactor(hidden_size); + int64_t j = 0; + for (; j + unroll_factor <= hidden_size; j += unroll_factor) { + for (int64_t k = 0; k < unroll_factor; ++k) { + dest[j + k] += weight * (src[j + k] + thread_bias2_buffer[j + k]); + } + } + for (; j < hidden_size; ++j) { + dest[j] += weight * (src[j] + thread_bias2_buffer[j]); + } + } else { + const int64_t unroll_factor = GetUnrollFactor(hidden_size); + int64_t j = 0; + for (; j + unroll_factor <= hidden_size; j += unroll_factor) { + for (int64_t k = 0; k < unroll_factor; ++k) { + dest[j + k] += weight * src[j + k]; + } + } + for (; j < hidden_size; ++j) { + dest[j] += weight * src[j]; + } } } } }); - // --- 4. 
Final Reduction (accumulate expert outputs to a float buffer) --- auto accumulate = [&](float* buffer) { - memset(buffer, 0, output_buffer_size * sizeof(float)); - for (int i = 0; i < num_expert_threads; ++i) { - const size_t thread_offset = static_cast(i) * output_buffer_size; - for (size_t j = 0; j < output_buffer_size; ++j) { - buffer[j] += thread_local_outputs[thread_offset + j]; + std::memset(buffer, 0, output_buffer_size * sizeof(float)); + + const int max_acc_threads = tp ? concurrency::ThreadPool::DegreeOfParallelism(tp) : 1; + const size_t acc_thread_divisor = std::max(size_t{1}, static_cast(max_acc_threads) * 8); + const size_t min_elements_per_thread = std::max(size_t{32}, output_buffer_size / acc_thread_divisor); + const int optimal_acc_threads = (tp == nullptr || output_buffer_size < min_elements_per_thread) ? 1 : std::min(static_cast(output_buffer_size / std::max(size_t{1}, min_elements_per_thread)), max_acc_threads); + const int num_acc_threads = std::max(1, optimal_acc_threads); + + if (num_acc_threads > 1) { + concurrency::ThreadPool::TrySimpleParallelFor(tp, num_acc_threads, [&](std::ptrdiff_t acc_thread_id) { + const size_t elements_per_thread = output_buffer_size / static_cast(num_acc_threads); + const size_t start_idx = static_cast(acc_thread_id) * elements_per_thread; + const size_t end_idx = (acc_thread_id == num_acc_threads - 1) ? output_buffer_size : start_idx + elements_per_thread; + + for (int i = 0; i < num_expert_threads; ++i) { + const size_t thread_offset = static_cast(i) * output_buffer_size; + const float* src = thread_local_outputs + thread_offset + start_idx; + float* dst = buffer + start_idx; + + size_t j = 0; + const size_t chunk_size = end_idx - start_idx; + const int64_t unroll_factor = GetUnrollFactor(static_cast(chunk_size)); + for (; j + static_cast(unroll_factor) <= chunk_size; j += static_cast(unroll_factor)) { + for (int64_t k = 0; k < unroll_factor; ++k) { + dst[j + static_cast(k)] += src[j + static_cast(k)]; + } + } + for (; j < chunk_size; ++j) { + dst[j] += src[j]; + } + } + }); + } else { + for (int i = 0; i < num_expert_threads; ++i) { + const size_t thread_offset = static_cast(i) * output_buffer_size; + const float* src = thread_local_outputs + thread_offset; + + size_t j = 0; + const int64_t unroll_factor = GetUnrollFactor(static_cast(output_buffer_size)); + for (; j + static_cast(unroll_factor) <= output_buffer_size; j += static_cast(unroll_factor)) { + for (int64_t k = 0; k < unroll_factor; ++k) { + buffer[j + static_cast(k)] += src[j + static_cast(k)]; + } + } + for (; j < output_buffer_size; ++j) { + buffer[j] += src[j]; + } } } }; @@ -362,18 +1026,16 @@ Status QMoECPU::Compute(OpKernelContext* context) const { float* final_output_float = final_output_float_ptr.get(); accumulate(final_output_float); - // --- 5. 
Convert final float buffer to output type T --- MlasConvertFloatToHalfBuffer(final_output_float, reinterpret_cast(output->MutableData()), static_cast(output_buffer_size)); - } else { // T is float + } else { accumulate(output->MutableData()); } return Status::OK(); } -// Explicit template instantiation template QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info); template Status QMoECPU::Compute(OpKernelContext* context) const; template QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info); diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc index 93d802ca05b42..167b2af946183 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc +++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc @@ -77,7 +77,8 @@ Status ShardedMoE::ComputeInternal(OpKernelContext* context) const { fc2_experts_weights, fc2_experts_bias_optional, nullptr, fc3_experts_weights_optional, fc3_experts_bias_optional, nullptr, 1, // no quantization so pack size is 1 - activation_type_ == ort_fastertransformer::ActivationType::SwiGLU)); + activation_type_ == ort_fastertransformer::ActivationType::SwiGLU, + 0)); // no block-wise quantization for sharded MoE ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0, "num_experts should be divisible by world_size"); diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc index a5b9d483d5ad1..e5a064d59e360 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe.cc +++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc @@ -45,7 +45,8 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { fc2_experts_weights, fc2_experts_bias_optional, nullptr, fc3_experts_weights_optional, fc3_experts_bias_optional, nullptr, 1, // no quantization so pack size is 1 - activation_type_ == ort_fastertransformer::ActivationType::SwiGLU)); + activation_type_ == ort_fastertransformer::ActivationType::SwiGLU, + 0)); // no block-wise quantization for regular MoE using CudaT = typename OrtToCudaType::type; auto stream = context->GetComputeStream(); diff --git a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc index dcf32bb3c5ae4..931b8ac09aa49 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc @@ -150,7 +150,8 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const { fc2_experts_weights, fc2_experts_bias_optional, fc2_scales, fc3_experts_weights_optional, fc3_experts_bias_optional, fc3_scales_optional, expert_weight_bits_ == 4 ? 
2 : 1, - activation_type_ == ort_fastertransformer::ActivationType::SwiGLU)); + activation_type_ == ort_fastertransformer::ActivationType::SwiGLU, + 0)); // CUDA doesn't support block-wise quantization yet #if defined(__GNUC__) #pragma GCC diagnostic push diff --git a/onnxruntime/test/python/transformers/test_qmoe_cpu.py b/onnxruntime/test/python/transformers/test_qmoe_cpu.py index 403becbe0616a..0292111b16962 100644 --- a/onnxruntime/test/python/transformers/test_qmoe_cpu.py +++ b/onnxruntime/test/python/transformers/test_qmoe_cpu.py @@ -128,6 +128,148 @@ def quant_dequant(weights, is_4_bit_quantization: bool = True): # Calculate scale like C++ implementation abs_max = weights.abs().max(dim=-1, keepdim=True)[0] + + # Set minimum scale to avoid division by zero + scale = torch.clamp(abs_max, min=1e-6) + + # Quantization ranges for symmetric quantization + if is_4_bit_quantization: + qmin, qmax = -8, 7 + zero_point = 8 # Offset to make values unsigned + else: + qmin, qmax = -128, 127 + zero_point = 128 # Offset to make values unsigned + + # Quantize using double precision division and C-like rounding (half away from zero) + scaled = weights.double() / scale.double() + sign = torch.sign(scaled) + abs_scaled = torch.abs(scaled) + quant_rounded = torch.floor(abs_scaled + 0.5) + quantized = torch.clamp((sign * quant_rounded).to(torch.int32), qmin, qmax).to(weights.dtype) + + # Convert to unsigned and pack for storage + if is_4_bit_quantization: + # Convert to unsigned 4-bit and pack into uint8 + unsigned_quantized = (quantized + zero_point).to(torch.uint8) + + # Pack two 4-bit values into one uint8 + packed_size = (weights.shape[-1] + 1) // 2 + packed_quantized = torch.zeros((*weights.shape[:-1], packed_size), dtype=torch.uint8, device=weights.device) + + for i in range(0, weights.shape[-1], 2): + val1 = unsigned_quantized[..., i] + val2 = unsigned_quantized[..., i + 1] if i + 1 < weights.shape[-1] else torch.zeros_like(val1) + packed_quantized[..., i // 2] = (val1 & 0xF) | ((val2 & 0xF) << 4) + + quantized_storage = packed_quantized + else: + # 8-bit: convert to unsigned uint8 + quantized_storage = (quantized + zero_point).to(torch.uint8) + + # Dequantize for verification (use float32 scale for higher precision) + dequantized = quantized.to(torch.float32) * scale + + return scale.squeeze(-1).to(torch.float32), quantized_storage, dequantized + + +def quant_dequant_blockwise(weights, block_size, is_4_bit_quantization: bool = True): + """ + Block-wise quantization and dequantization for testing purposes. + This function uses symmetric quantization centered around 0 (no zero-point). 
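+
+    As a minimal sketch of the per-block math implemented below (scale = abs_max / 7
+    for 4-bit, round half away from zero, offset by a zero point of 8 before packing),
+    using made-up values:
+
+        block = torch.tensor([0.6, -1.0, 0.25, 0.7])
+        scale = block.abs().max() / 7.0              # ~0.1429
+        scaled = block / scale
+        codes = torch.sign(scaled) * torch.floor(scaled.abs() + 0.5)
+        # codes == [4., -7., 2., 5.]; adding the zero point 8 gives [12, 1, 10, 13],
+        # which pack pairwise into the bytes 0x1C and 0xDA.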
+ + Args: + weights: Input tensor of shape [rows, cols] + block_size: Size of each quantization block + is_4_bit_quantization: Whether to use 4-bit (True) or 8-bit (False) quantization + + Returns: + scales: Scale tensor of shape [rows, num_blocks] + quantized: Quantized tensor + dequantized: Dequantized tensor for verification + """ + rows, cols = weights.shape + num_blocks = (cols + block_size - 1) // block_size + + # Handle edge case of all-zero weights tensor + if torch.all(weights == 0): + scales = torch.zeros((rows, num_blocks), dtype=torch.float16, device=weights.device) + if is_4_bit_quantization: + packed_size = (cols + 1) // 2 + quantized = torch.zeros((rows, packed_size), dtype=torch.uint8, device=weights.device) + else: + quantized = torch.zeros((rows, cols), dtype=torch.uint8, device=weights.device) + dequantized = torch.zeros_like(weights) + return scales, quantized, dequantized + + # Initialize output tensors; use float32 for scales to reduce precision loss + scales = torch.zeros((rows, num_blocks), dtype=torch.float32, device=weights.device) + dequantized = torch.zeros_like(weights) + + # Quantization ranges and zero point + if is_4_bit_quantization: + qmin, qmax = -8, 7 + zero_point = 8 + packed_size = (cols + 1) // 2 + quantized = torch.zeros((rows, packed_size), dtype=torch.uint8, device=weights.device) + else: + qmin, qmax = -128, 127 + zero_point = 128 + quantized = torch.zeros((rows, cols), dtype=torch.uint8, device=weights.device) + + # Process each block with higher-precision math to match C++ behavior + for row in range(rows): + for block_idx in range(num_blocks): + start_col = block_idx * block_size + end_col = min(start_col + block_size, cols) + + # Get block data + block_data = weights[row, start_col:end_col] + + # Calculate absolute max and ensure small epsilon to avoid div-by-zero + abs_max = block_data.abs().max() + abs_max = torch.clamp(abs_max, min=1e-8) + + # Compute scale consistent with C++: use 7.0 for 4-bit positive max, 127.0 for 8-bit + if is_4_bit_quantization: + # Use higher precision then keep as float32 for scale + scale = (abs_max.double() / 7.0).float() + 1e-12 + else: + scale = (abs_max.double() / 127.0).float() + 1e-12 + + scales[row, block_idx] = scale.to(torch.float32) + + if scale == 0: + continue + + # Quantize using double precision for the division to reduce rounding error + scaled = block_data.double() / scale.double() + # Emulate C's round() behavior (round half away from zero) to match C++ implementation + sign = torch.sign(scaled) + abs_scaled = torch.abs(scaled) + quant_rounded = torch.floor(abs_scaled + 0.5) + quantized_block = (sign * quant_rounded).clamp(qmin, qmax).to(torch.int32) + + # Pack for 4-bit or store directly for 8-bit + if is_4_bit_quantization: + for i in range(0, end_col - start_col, 2): + col_idx = start_col + i + packed_idx = col_idx // 2 + + val1 = int(quantized_block[i]) + zero_point + val2 = int(quantized_block[i + 1]) + zero_point if i + 1 < len(quantized_block) else zero_point + + # Pack two 4-bit values into one uint8 + packed_val = (val1 & 0xF) | ((val2 & 0xF) << 4) + quantized[row, packed_idx] = packed_val + else: + quantized_vals = (quantized_block + zero_point).to(torch.uint8) + quantized[row, start_col:end_col] = quantized_vals + + # Dequantize for verification (signed quantized values multiplied by scale) + signed = quantized_block.to(torch.float32) + dequantized[row, start_col:end_col] = signed * scale + + return scales, quantized, dequantized abs_max = torch.clamp(abs_max, min=1e-8) # More 
conservative clamping for better precision if is_4_bit_quantization: @@ -247,6 +389,7 @@ def create_cpu_moe_onnx_graph( use_quant=False, quant_bits=4, swiglu_interleaved=False, + block_size=0, # New parameter for block-wise quantization ): if not has_onnx: return None @@ -264,14 +407,13 @@ def create_cpu_moe_onnx_graph( if not has_onnx: return None - if use_quant: - # Assertions only apply to quantized MoE - assert fc1_experts_weights.dtype == torch.uint8, "FC1 weights must be uint8 for QMoE" - assert fc2_experts_weights.dtype == torch.uint8, "FC2 weights must be uint8 for QMoE" - assert fc1_scales is not None, "FC1 scales must be provided for QMoE" - assert fc2_scales is not None, "FC2 scales must be provided for QMoE" - assert fc1_scales.dtype == torch.float16, "FC1 scales must be float16 for QMoE" - assert fc2_scales.dtype == torch.float16, "FC2 scales must be float16 for QMoE" + assert fc1_experts_weights.dtype == torch.uint8, "FC1 weights must be uint8 for QMoE" + assert fc2_experts_weights.dtype == torch.uint8, "FC2 weights must be uint8 for QMoE" + assert fc1_scales is not None, "FC1 scales must be provided for QMoE" + assert fc2_scales is not None, "FC2 scales must be provided for QMoE" + # Accept float16 or float32 scales; tests may produce float32 for better precision + assert fc1_scales.dtype in (torch.float16, torch.float32), "FC1 scales must be float16 or float32 for QMoE" + assert fc2_scales.dtype in (torch.float16, torch.float32), "FC2 scales must be float16 or float32 for QMoE" if not has_onnx: return None @@ -332,6 +474,10 @@ def create_cpu_moe_onnx_graph( if use_quant: nodes[0].attribute.extend([helper.make_attribute("expert_weight_bits", quant_bits)]) + # Add block_size attribute for block-wise quantization + if block_size > 0: + nodes[0].attribute.extend([helper.make_attribute("block_size", block_size)]) + # Weights are store in column major order. Need pack 2 int4 values into uint8. # Use the actual tensor shapes instead of calculating them to avoid size mismatches fc1_shape = list(fc1_experts_weights.shape) @@ -342,30 +488,59 @@ def create_cpu_moe_onnx_graph( weight_numpy_type = numpy.uint8 if use_quant else ort_to_numpy_type_map[onnx_dtype] weight_onnx_type = TensorProto.UINT8 if use_quant else onnx_dtype + # Use raw bytes from C-contiguous numpy arrays to ensure the exact memory layout + # of the packed uint8 weight tensors is preserved when writing the ONNX initializer. 
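+    # A small sanity sketch of that layout assumption (hypothetical values, not part
+    # of the graph construction): for a C-contiguous uint8 array the raw bytes come
+    # out in row-major order with the nibble order preserved, e.g.
+    #   numpy.ascontiguousarray(numpy.array([[0x1C, 0xDA]], dtype=numpy.uint8)).tobytes()
+    #   == b"\x1c\xda"  # low nibble = even-indexed 4-bit value, high nibble = odd-indexed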
+ fc1_np = fc1_experts_weights.detach().cpu().numpy().astype(weight_numpy_type) + fc2_np = fc2_experts_weights.detach().cpu().numpy().astype(weight_numpy_type) + fc1_np = numpy.ascontiguousarray(fc1_np) + fc2_np = numpy.ascontiguousarray(fc2_np) + initializers = [ helper.make_tensor( "fc1_experts_weights", weight_onnx_type, fc1_shape, - fc1_experts_weights.flatten().detach().cpu().numpy().astype(weight_numpy_type).tolist(), + fc1_np.tobytes(), + raw=True, ), helper.make_tensor( "fc2_experts_weights", weight_onnx_type, fc2_shape, - fc2_experts_weights.flatten().detach().cpu().numpy().astype(weight_numpy_type).tolist(), + fc2_np.tobytes(), + raw=True, ), ] - fc1_scale_shape = [num_experts, 2 * inter_size if use_swiglu else inter_size] - fc2_scale_shape = [num_experts, hidden_size] + # Calculate scale tensor shapes based on block_size + if block_size > 0: + # Block-wise quantization: 3D scale tensors + fc1_blocks_per_row = (hidden_size + block_size - 1) // block_size + fc2_blocks_per_row = (inter_size + block_size - 1) // block_size - fc1_scale_size = num_experts * (2 * inter_size if use_swiglu else inter_size) - fc2_scale_size = num_experts * hidden_size + fc1_scale_shape = [num_experts, 2 * inter_size if use_swiglu else inter_size, fc1_blocks_per_row] + fc2_scale_shape = [num_experts, hidden_size, fc2_blocks_per_row] - # Handle scale tensors based on quantization mode - if use_quant: - # Handle different possible scale tensor structures for fc1_scales + fc1_scale_size = num_experts * (2 * inter_size if use_swiglu else inter_size) * fc1_blocks_per_row + fc2_scale_size = num_experts * hidden_size * fc2_blocks_per_row + else: + # Row-wise quantization: 2D scale tensors + fc1_scale_shape = [num_experts, 2 * inter_size if use_swiglu else inter_size] + fc2_scale_shape = [num_experts, hidden_size] + + fc1_scale_size = num_experts * (2 * inter_size if use_swiglu else inter_size) + fc2_scale_size = num_experts * hidden_size + + # Handle scale tensors - fc1_scales and fc2_scales are guaranteed to be not None due to earlier assertions + # Process scale tensors based on whether block-wise quantization is used + if block_size > 0: + # For block-wise quantization, the scales are already in the correct 3D shape + # [num_experts, output_features, num_blocks] from quant_dequant_blockwise + # Convert scales to the selected ONNX dtype (prefer float32 for higher precision) + fc1_scale_tensor = fc1_scales.to(torch_dtype).flatten().detach().cpu().numpy() + fc2_scale_tensor = fc2_scales.to(torch_dtype).flatten().detach().cpu().numpy() + else: + # For row-wise quantization, handle different possible scale tensor structures for fc1_scales if len(fc1_scales.shape) == 4: # 4D case: [num_experts, inter_size, hidden_size, 1] - extract first scale per expert per output if use_swiglu: @@ -395,10 +570,6 @@ def create_cpu_moe_onnx_graph( [fc1_scale_tensor, numpy.ones(pad_size, dtype=fc1_scale_tensor.dtype)] ) - # Process scale tensor for proper shape - fc1_scale_data_list = fc1_scale_tensor.tolist() - fc1_scale_data = fc1_scale_data_list - # Handle different possible scale tensor structures for fc2_scales if len(fc2_scales.shape) == 4: # 4D case: [num_experts, hidden_size, inter_size, 1] - extract first scale per expert per output @@ -421,48 +592,30 @@ def create_cpu_moe_onnx_graph( [fc2_scale_tensor, numpy.ones(pad_size, dtype=fc2_scale_tensor.dtype)] ) - # Process scale tensor for proper shape - fc2_scale_data_list = fc2_scale_tensor.tolist() - fc2_scale_data = fc2_scale_data_list - - initializers.extend( - [ - 
helper.make_tensor( - "fc1_scales", - onnx_dtype, - fc1_scale_shape, - fc1_scale_data, - raw=False, - ), - helper.make_tensor( - "fc2_scales", - onnx_dtype, - fc2_scale_shape, - fc2_scale_data, - raw=False, - ), - ] - ) - else: - # For non-quantized mode, add bias tensors if provided - if fc1_bias is not None: - initializers.append( - helper.make_tensor( - "fc1_experts_bias", - onnx_dtype, - list(fc1_bias.shape), - fc1_bias.flatten().detach().cpu().numpy().astype(ort_to_numpy_type_map[onnx_dtype]).tolist(), - ) - ) - if fc2_bias is not None: - initializers.append( - helper.make_tensor( - "fc2_experts_bias", - onnx_dtype, - list(fc2_bias.shape), - fc2_bias.flatten().detach().cpu().numpy().astype(ort_to_numpy_type_map[onnx_dtype]).tolist(), - ) - ) + # Process scale tensors for proper data format + fc1_scale_data_list = fc1_scale_tensor.tolist() + fc1_scale_data = fc1_scale_data_list + fc2_scale_data_list = fc2_scale_tensor.tolist() + fc2_scale_data = fc2_scale_data_list + + initializers.extend( + [ + helper.make_tensor( + "fc1_scales", + onnx_dtype, + fc1_scale_shape, + fc1_scale_data, + raw=False, + ), + helper.make_tensor( + "fc2_scales", + onnx_dtype, + fc2_scale_shape, + fc2_scale_data, + raw=False, + ), + ] + ) graph_inputs = [ helper.make_tensor_value_info("input", onnx_dtype, [sequence_length, hidden_size]), @@ -645,10 +798,7 @@ class SparseMoeBlockORTHelper(nn.Module): def __init__(self, quant_bits=0, onnx_dtype=None): super().__init__() self.quant_bits = quant_bits - if onnx_dtype is None: - self.onnx_dtype = TensorProto.FLOAT16 if self.quant_bits > 0 else TensorProto.FLOAT - else: - self.onnx_dtype = onnx_dtype + self.onnx_dtype = onnx_dtype self.np_type = numpy.float16 if self.onnx_dtype == TensorProto.FLOAT16 else numpy.float32 def create_ort_session(self, moe_onnx_graph): @@ -717,8 +867,8 @@ def ort_forward(self, hidden_states: torch.Tensor, enable_performance_test=False tensors = { "input": hidden_states_flat.clone().to(device=device, dtype=torch_dtype), - "router_probs": router_input.clone().to(device=device, dtype=torch_dtype), - "output": torch.zeros_like(hidden_states_flat, device=device, dtype=torch_dtype), + "router_probs": router_logits.clone().to(device=device, dtype=torch_dtype), + "output": torch.zeros((batch_size * sequence_length, hidden_dim), device=device, dtype=torch_dtype), } try: @@ -779,14 +929,47 @@ def recreate_onnx_model(self): is_4_bit = self.quant_bits == 4 for i in range(self.num_experts): - w1_scale, pre_qweight1, w1_qdq = quant_dequant(self.experts[i].w1.weight, is_4_bit) - w2_scale, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight, is_4_bit) + if self.block_size > 0: + # Use block-wise quantization + w1_scale, pre_qweight1, w1_qdq = quant_dequant_blockwise( + self.experts[i].w1.weight, self.block_size, is_4_bit + ) + w2_scale, pre_qweight2, w2_qdq = quant_dequant_blockwise( + self.experts[i].w2.weight, self.block_size, is_4_bit + ) + else: + # Use row-wise quantization + w1_scale, pre_qweight1, w1_qdq = quant_dequant(self.experts[i].w1.weight, is_4_bit) + w2_scale, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight, is_4_bit) if self.use_swiglu: - # For SwiGLU, CPU kernel now always expects interleaved format - # SwigluMlp weights are already in interleaved format [gate_0, linear_0, gate_1, linear_1, ...] 
- # No conversion needed - both CPU and CUDA use interleaved format - self.experts[i].w1.weight = nn.Parameter(w1_qdq.contiguous().clone()) + if self.swiglu_interleaved: + pass + else: + if self.block_size > 0: + w3_scale, pre_qweight3, w3_qdq = quant_dequant_blockwise( + self.experts[i].w3.weight, self.block_size, is_4_bit + ) + else: + w3_scale, pre_qweight3, w3_qdq = quant_dequant(self.experts[i].w3.weight, is_4_bit) + + gate_weights = pre_qweight1 + value_weights = pre_qweight3 + gate_scales = w1_scale + value_scales = w3_scale + + pre_qweight1 = torch.cat([gate_weights, value_weights], dim=0) + w1_scale = torch.cat([gate_scales, value_scales], dim=0) + + if self.swiglu_interleaved: + self.experts[i].w1.weight = nn.Parameter(w1_qdq.contiguous().clone()) + + else: + intermediate_size = self.experts[i].w1.weight.shape[0] + gate_dequant = w1_qdq[:intermediate_size].contiguous().clone() + value_dequant = w1_qdq[intermediate_size:].contiguous().clone() + self.experts[i].w1.weight.data = gate_dequant + self.experts[i].w3.weight.data = value_dequant else: self.experts[i].w1.weight.data = w1_qdq.contiguous().clone() @@ -828,7 +1011,8 @@ def recreate_onnx_model(self): use_swiglu=self.use_swiglu, use_quant=True, # Always use QMoE quant_bits=self.quant_bits, - swiglu_interleaved=True, # CPU kernel now always expects interleaved format + swiglu_interleaved=self.swiglu_interleaved if hasattr(self, "swiglu_interleaved") else False, + block_size=self.block_size, # Add block_size for block-wise quantization ) except Exception: self.moe_onnx_graph = None @@ -877,6 +1061,45 @@ def parity_check(self): print(f"Parity check - {act_type} {self.quant_bits}-bit: max_diff = {max_diff:.6f}") + # Diagnostic dump: when differences are large, show the index and nearby values + if max_diff > 1e-3: + diff = (torch_output.cpu() - ort_output.cpu()).abs() + idx = torch.argmax(diff) + flat_idx = int(idx) + # Derive coordinates (batch, seq, hidden) from flattened index + total_elems = torch_output.numel() + # Work in flattened [batch, seq, hidden] ordering + hidden_dim = self.hidden_dim + seq = self.sequence_length + # Clamp to safe bounds + flat_idx = min(flat_idx, total_elems - 1) + i = flat_idx // (hidden_dim) + j = i // seq + k = flat_idx % hidden_dim + print( + f"Diagnostic - max diff at flat_idx={flat_idx} -> sample (batch_idx={j}, seq_idx={i % seq}, hidden_idx={k})" + ) + print("Torch sample:", torch_output.cpu().reshape(-1, hidden_dim)[i, k].item()) + print("ORT sample:", ort_output.cpu().reshape(-1, hidden_dim)[i, k].item()) + # Print routing and per-expert contributions for this token from the PyTorch reference + try: + hidden_states_flat = hidden_state.view(-1, hidden_dim) + token_vec = hidden_states_flat[i : i + 1] + gate_logits = self.gate(token_vec) + topk_vals, topk_experts = torch.topk(gate_logits, self.top_k, dim=-1) + topk_soft = F.softmax(topk_vals, dim=1) + print("Gate logits:", gate_logits.detach().cpu().numpy()) + print("Selected experts:", topk_experts.detach().cpu().numpy()) + print("Routing weights:", topk_soft.detach().cpu().numpy()) + # Compute per-expert contributions for selected experts + for idx_e, e in enumerate(topk_experts[0].tolist()): + expert_layer = self.experts[e] + expert_out = expert_layer(token_vec) + contrib = expert_out[0, k].item() * topk_soft[0, idx_e].item() + print(f"Expert {e} contrib at hidden {k}: {contrib}") + except Exception as _: + pass + ort_dtype_quant_bits_tolerance_map = { "FP32:0": (5e-3, 1e-3), "FP16:0": (5e-2, 1e-3), @@ -917,7 +1140,13 @@ def 
small_test_cases(): class SwigluMoEBlock(SparseMoeBlockORTHelper): def __init__( - self, config: SwigluMoeConfig, batch_size: int, sequence_length: int, quant_bits: int = 0, onnx_dtype=None + self, + config: SwigluMoeConfig, + batch_size: int, + sequence_length: int, + quant_bits: int = 0, + onnx_dtype=None, + block_size: int = 0, ): super().__init__(quant_bits, onnx_dtype=onnx_dtype) self.hidden_dim = config.hidden_size @@ -926,6 +1155,7 @@ def __init__( self.top_k = config.num_experts_per_token self.use_swiglu = True self.swiglu_interleaved = True + self.block_size = block_size # Store block_size for QMoE use_quant = self.quant_bits > 0 self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=True) @@ -995,7 +1225,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class PhiMoESparseMoeBlock(SparseMoeBlockORTHelper): def __init__( - self, config: PhiMoEConfig, batch_size: int, sequence_length: int, quant_bits: int = 0, onnx_dtype=None + self, + config: PhiMoEConfig, + batch_size: int, + sequence_length: int, + quant_bits: int = 0, + onnx_dtype=None, + block_size: int = 0, ): super().__init__(quant_bits, onnx_dtype=onnx_dtype) self.hidden_dim = config.hidden_size @@ -1005,6 +1241,7 @@ def __init__( self.router_jitter_noise = config.router_jitter_noise self.use_swiglu = True self.swiglu_interleaved = True + self.block_size = block_size # Store block_size for QMoE use_quant = self.quant_bits > 0 self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=True) @@ -1024,8 +1261,14 @@ def __init__( else: is_4_bit = self.quant_bits == 4 - scale1, pre_qweight1, w1_qdq = quant_dequant(expert.w1.weight, is_4_bit) - scale2, pre_qweight2, w2_qdq = quant_dequant(expert.w2.weight, is_4_bit) + if self.block_size > 0: + # Use block-wise quantization + scale1, pre_qweight1, w1_qdq = quant_dequant_blockwise(expert.w1.weight, self.block_size, is_4_bit) + scale2, pre_qweight2, w2_qdq = quant_dequant_blockwise(expert.w2.weight, self.block_size, is_4_bit) + else: + # Use row-wise quantization + scale1, pre_qweight1, w1_qdq = quant_dequant(expert.w1.weight, is_4_bit) + scale2, pre_qweight2, w2_qdq = quant_dequant(expert.w2.weight, is_4_bit) expert.w1.weight.data = w1_qdq expert.w2.weight.data = w2_qdq @@ -1064,6 +1307,7 @@ def __init__( use_quant=use_quant, quant_bits=self.quant_bits, swiglu_interleaved=self.swiglu_interleaved, + block_size=self.block_size, # Add block_size for block-wise quantization ) self.ort_sess = self.create_ort_session(self.moe_onnx_graph) if self.moe_onnx_graph else None @@ -1075,9 +1319,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.view(-1, hidden_dim) router_logits = self.gate(hidden_states) - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # Match CPU implementation: select top-k experts by logits, then softmax over those logits + routing_weights_vals, selected_experts = torch.topk(router_logits, self.top_k, dim=-1) + routing_weights = F.softmax(routing_weights_vals, dim=1, dtype=torch.float) routing_weights = routing_weights.to(hidden_states.dtype) final_hidden_states = torch.zeros( @@ -1112,6 +1356,14 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: (2, 16, 8), ] +# Define test cases for block-wise quantization +phi3_blockwise_test_cases = [ + (1, 32, 4, 32), # batch_size, sequence_length, quant_bits, block_size + 
(1, 32, 8, 64), + (2, 16, 4, 32), + (2, 16, 8, 64), +] + @unittest.skipIf(disable_cpu_qmoe_tests, "Skipping qMoE cpu tests") class TestPhiQMoECPU(unittest.TestCase): @@ -1152,6 +1404,37 @@ def test_phi3_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): phi3_moe.parity_check() + @parameterized.expand(phi3_blockwise_test_cases) + def test_phi3_qmoe_blockwise_parity_cpu(self, batch_size, sequence_length, quant_bits, block_size): + torch.manual_seed(42) + numpy.random.seed(42) + + test_config = f"batch_size={batch_size}, sequence_length={sequence_length}, quant_bits={quant_bits}, block_size={block_size}" + print(f"Running Phi3 QMoE block-wise test: {test_config}") + + config = PhiMoEConfig(hidden_size=128, intermediate_size=256, num_local_experts=4, num_experts_per_tok=2) + + phi3_moe = PhiMoESparseMoeBlock( + config, + batch_size=batch_size, + sequence_length=sequence_length, + quant_bits=quant_bits, + onnx_dtype=TensorProto.FLOAT, + block_size=block_size, # Enable block-wise quantization + ) + + hidden_states = torch.randn(batch_size, sequence_length, config.hidden_size).to(torch.float32) + + torch_result = phi3_moe.forward(hidden_states) + + # Verify output shape and basic properties + expected_shape = (batch_size, sequence_length, config.hidden_size) + self.assertEqual(torch_result.shape, expected_shape) + self.assertFalse(torch.isnan(torch_result).any()) + self.assertFalse(torch.isinf(torch_result).any()) + + phi3_moe.parity_check() + disable_cpu_qmoe_tests = False @@ -1162,6 +1445,14 @@ def test_phi3_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): (2, 16, 8), ] +# Define test cases for block-wise quantization +swiglu_blockwise_test_cases = [ + (1, 32, 4, 32), # batch_size, sequence_length, quant_bits, block_size + (1, 32, 8, 64), + (2, 16, 4, 32), + (2, 16, 8, 64), +] + @unittest.skipIf(disable_cpu_qmoe_tests, "Skipping qMoE cpu tests") class TestSwigluQMoECPU(unittest.TestCase): @@ -1201,6 +1492,36 @@ def test_swiglu_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): swiglu_moe.parity_check() + @parameterized.expand(swiglu_blockwise_test_cases) + def test_swiglu_qmoe_blockwise_parity_cpu(self, batch_size, sequence_length, quant_bits, block_size): + torch.manual_seed(42) + numpy.random.seed(42) + + test_config = f"batch_size={batch_size}, sequence_length={sequence_length}, quant_bits={quant_bits}, block_size={block_size}" + print(f"Running SwiGLU block-wise test: {test_config}") + + config = SwigluMoeConfig(hidden_size=128, intermediate_size=256, num_local_experts=4, num_experts_per_token=2) + + swiglu_moe = SwigluMoEBlock( + config, + batch_size=batch_size, + sequence_length=sequence_length, + quant_bits=quant_bits, + onnx_dtype=TensorProto.FLOAT, + block_size=block_size, # Enable block-wise quantization + ) + + hidden_states = torch.randn(batch_size, sequence_length, config.hidden_size).to(torch.float32) + + torch_result = swiglu_moe.forward(hidden_states) + + expected_shape = (batch_size, sequence_length, config.hidden_size) + self.assertEqual(torch_result.shape, expected_shape) + self.assertFalse(torch.isnan(torch_result).any()) + self.assertFalse(torch.isinf(torch_result).any()) + + swiglu_moe.parity_check() + @unittest.skipIf(True, "Skipping QMoE CPU benchmark tests") class TestQMoESwiGLUBenchmark(unittest.TestCase): From cdce7f121abbe1179241618b05f8aa1bce6cbb38 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Wed, 24 Sep 2025 10:50:15 -0700 Subject: [PATCH 05/10] [C#] Implement missing APIs (#26101) This pull request adds new APIs and 
updates existing ones to improve memory and device information handling in the ONNX Runtime C# bindings. The most significant changes introduce methods for fetching memory info and device info for session inputs/outputs, and add support for shared allocators and synchronization streams. There are also several updates and renamings for LoraAdapter delegates and related APIs. ### Memory and Device Info APIs * Added `GetMemoryInfosForInputs`, `GetMemoryInfosForOutputs`, and `GetEpDeviceForInputs` methods to `InferenceSession.shared.cs` to fetch memory info and device info for session inputs/outputs. These methods utilize new native delegates for retrieving memory and device information. * Introduced native delegates in `NativeMethods.shared.cs` for `OrtSessionGetMemoryInfoForInputs`, `OrtSessionGetMemoryInfoForOutputs`, and `OrtSessionGetEpDeviceForInputs`, and wired them up in the static constructor. [[1]](diffhunk://#diff-f9f2aaafc076365917de8ab96628da427d9dd0fd6a214fb9c266733f90d6fc73R530-R532) [[2]](diffhunk://#diff-f9f2aaafc076365917de8ab96628da427d9dd0fd6a214fb9c266733f90d6fc73R1312-R1335) ### Shared Allocator and Synchronization Stream Support * Added delegates and static fields for creating, getting, and releasing shared allocators, as well as for creating and managing synchronization streams (`OrtCreateSharedAllocator`, `OrtGetSharedAllocator`, `OrtReleaseSharedAllocator`, `OrtCreateSyncStreamForEpDevice`, `OrtSyncStream_GetHandle`, `OrtReleaseSyncStream`). * Added delegate for copying tensors (`OrtCopyTensors`). ### LoraAdapter API Updates * Renamed LoraAdapter-related delegates to use the `Ort` prefix (`OrtCreateLoraAdapter`, `OrtCreateLoraAdapterFromArray`, `OrtReleaseLoraAdapter`) and updated their usage throughout the codebase. [[1]](diffhunk://#diff-f9f2aaafc076365917de8ab96628da427d9dd0fd6a214fb9c266733f90d6fc73L699-R710) [[2]](diffhunk://#diff-f9f2aaafc076365917de8ab96628da427d9dd0fd6a214fb9c266733f90d6fc73L1561-R1672) [[3]](diffhunk://#diff-f9f2aaafc076365917de8ab96628da427d9dd0fd6a214fb9c266733f90d6fc73L1578-R1695) ### MemoryInfo Enhancements * Added new delegates for creating memory info with more parameters (`OrtCreateMemoryInfoV2`), and for querying device memory type and vendor ID (`OrtMemoryInfoGetDeviceMemType`, `OrtMemoryInfoGetVendorId`). 
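As a rough usage sketch of the new public session queries described above (an existing `InferenceSession` named `session` is assumed, and `OrtMemoryInfo.Name` / `OrtEpDevice.EpName` are used here only for printing):

    using System;
    using Microsoft.ML.OnnxRuntime;

    static void PrintInputPlacement(InferenceSession session)
    {
        // Memory info per input, returned in InputNames order; dispose when done.
        using var inputMemInfos = session.GetMemoryInfosForInputs();
        foreach (var memInfo in inputMemInfos)
        {
            Console.WriteLine(memInfo.Name);
        }

        // Entries are null for inputs that do not have an associated device.
        foreach (var epDevice in session.GetEpDeviceForInputs())
        {
            Console.WriteLine(epDevice?.EpName ?? "n/a");
        }
    }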
[[1]](diffhunk://#diff-f9f2aaafc076365917de8ab96628da427d9dd0fd6a214fb9c266733f90d6fc73R594-R596) [[2]](diffhunk://#diff-f9f2aaafc076365917de8ab96628da427d9dd0fd6a214fb9c266733f90d6fc73R1804-R1817) [[3]](diffhunk://#diff-f9f2aaafc076365917de8ab96628da427d9dd0fd6a214fb9c266733f90d6fc73R1866-R1877) ### Minor API Documentation Update * Clarified the lifetime of allocators in the documentation, noting they can be explicitly unregistered.### Description ### Motivation and Context --- .../InferenceSession.shared.cs | 76 +++++ .../NativeMethods.shared.cs | 287 ++++++++++++++---- .../NativeOnnxValueHelper.shared.cs | 14 +- .../OrtAllocator.shared.cs | 78 ++++- .../Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs | 139 ++++++++- .../OrtEpDevice.shared.cs | 81 +++++ .../OrtKeyValuePairs.shared.cs | 5 + .../OrtLoraAdapter.shared.cs | 6 +- .../InferenceTest.cs | 151 +++------ .../OrtAutoEpTests.cs | 21 +- .../OrtEnvTests.cs | 284 ++++++++++++++++- .../InferenceTest.netcore.cs | 55 ++-- .../python/onnxruntime_pybind_state.cc | 3 + .../python/onnxruntime_test_python_autoep.py | 2 + 14 files changed, 981 insertions(+), 221 deletions(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs index 792f0ddd0f777..79e6dbbb11c89 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.shared.cs @@ -214,6 +214,82 @@ public IReadOnlyDictionary OverridableInitializerMetadata } } + /// + /// Fetches memory info for all inputs in the same order as their names. + /// (See InputNames property). + /// + /// A disposable readonly collection of OrtMemoryInfo + public IDisposableReadOnlyCollection GetMemoryInfosForInputs() + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetInputCount(_nativeHandle, out UIntPtr numInputs)); + + if(numInputs == UIntPtr.Zero) + { + return new DisposableList(); + } + + var memoryInfoArray = new IntPtr[(ulong)numInputs]; + + NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetMemoryInfoForInputs(_nativeHandle, + memoryInfoArray, numInputs)); + + return new DisposableList( + memoryInfoArray.Select(static ptr => new OrtMemoryInfo(ptr, /* owned= */ false))); + } + + /// + /// Fetches memory info for all outputs in the same order as their names. + /// (See OutputNames property). + /// + /// A disposable readonly collection of OrtMemoryInfo + public IDisposableReadOnlyCollection GetMemoryInfosForOutputs() + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOutputCount(_nativeHandle, + out UIntPtr numOutputs)); + + if(numOutputs == UIntPtr.Zero) + { + return new DisposableList(); + } + + var memoryInfoArray = new IntPtr[(ulong)numOutputs]; + + NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetMemoryInfoForOutputs(_nativeHandle, + memoryInfoArray, numOutputs)); + return new DisposableList( + memoryInfoArray.Select(static ptr => new OrtMemoryInfo(ptr, /* owned= */ false))); + } + + /// + /// Fetches OrtEpDevice instances for all inputs in the same order as their input names. + /// For inputs that do not have a device, the corresponding entry in the returned list is null. + /// See InputNames property. 
+ /// + /// IReadOnlyList + public IReadOnlyList GetEpDeviceForInputs() + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetInputCount(_nativeHandle, + out UIntPtr numInputs)); + + if (numInputs == UIntPtr.Zero) + { + // OrtSessionGetEpDeviceForInputs expects numInputs > 0, otherwise it is an invalid arg. + return []; + } + + var epDevicesForInputs = new IntPtr[(ulong)numInputs]; + + NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetEpDeviceForInputs(_nativeHandle, + epDevicesForInputs, numInputs)); + + // Some entries in epDevicesForInputs can be IntPtr.Zero, indicating the input does not + // have a device; return null for those entries. + return epDevicesForInputs + .Select(static ptr => ptr == IntPtr.Zero ? null : new OrtEpDevice(ptr)) + .ToList() + .AsReadOnly(); + } + /// /// Runs the loaded model for the given inputs, and fetches all the outputs. /// diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index 53880308da261..b97adfbd564d5 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -500,7 +500,7 @@ static NativeMethods() OrtCreateEnvWithGlobalThreadPools = (DOrtCreateEnvWithGlobalThreadPools)Marshal.GetDelegateForFunctionPointer(api_.CreateEnvWithGlobalThreadPools, typeof(DOrtCreateEnvWithGlobalThreadPools)); OrtCreateEnvWithCustomLoggerAndGlobalThreadPools = (DOrtCreateEnvWithCustomLoggerAndGlobalThreadPools)Marshal.GetDelegateForFunctionPointer(api_.CreateEnvWithCustomLoggerAndGlobalThreadPools, typeof(DOrtCreateEnvWithCustomLoggerAndGlobalThreadPools)); OrtReleaseEnv = (DOrtReleaseEnv)Marshal.GetDelegateForFunctionPointer(api_.ReleaseEnv, typeof(DOrtReleaseEnv)); - + OrtEnableTelemetryEvents = (DOrtEnableTelemetryEvents)Marshal.GetDelegateForFunctionPointer(api_.EnableTelemetryEvents, typeof(DOrtEnableTelemetryEvents)); OrtDisableTelemetryEvents = (DOrtDisableTelemetryEvents)Marshal.GetDelegateForFunctionPointer(api_.DisableTelemetryEvents, typeof(DOrtDisableTelemetryEvents)); @@ -527,6 +527,9 @@ static NativeMethods() OrtSessionGetInputTypeInfo = (DOrtSessionGetInputTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.SessionGetInputTypeInfo, typeof(DOrtSessionGetInputTypeInfo)); OrtSessionGetOutputTypeInfo = (DOrtSessionGetOutputTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.SessionGetOutputTypeInfo, typeof(DOrtSessionGetOutputTypeInfo)); OrtSessionGetOverridableInitializerTypeInfo = (DOrtSessionGetOverridableInitializerTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.SessionGetOverridableInitializerTypeInfo, typeof(DOrtSessionGetOverridableInitializerTypeInfo)); + OrtSessionGetMemoryInfoForInputs = (DOrtSessionGetMemoryInfoForInputs)Marshal.GetDelegateForFunctionPointer(api_.SessionGetMemoryInfoForInputs, typeof(DOrtSessionGetMemoryInfoForInputs)); + OrtSessionGetMemoryInfoForOutputs = (DOrtSessionGetMemoryInfoForOutputs)Marshal.GetDelegateForFunctionPointer(api_.SessionGetMemoryInfoForOutputs, typeof(DOrtSessionGetMemoryInfoForOutputs)); + OrtSessionGetEpDeviceForInputs = (DOrtSessionGetEpDeviceForInputs)Marshal.GetDelegateForFunctionPointer(api_.SessionGetEpDeviceForInputs, typeof(DOrtSessionGetEpDeviceForInputs)); OrtReleaseTypeInfo = (DOrtReleaseTypeInfo)Marshal.GetDelegateForFunctionPointer(api_.ReleaseTypeInfo, typeof(DOrtReleaseTypeInfo)); OrtReleaseSession = (DOrtReleaseSession)Marshal.GetDelegateForFunctionPointer(api_.ReleaseSession, typeof(DOrtReleaseSession)); 
OrtSessionGetProfilingStartTimeNs = (DOrtSessionGetProfilingStartTimeNs)Marshal.GetDelegateForFunctionPointer(api_.SessionGetProfilingStartTimeNs, typeof(DOrtSessionGetProfilingStartTimeNs)); @@ -588,6 +591,9 @@ static NativeMethods() OrtMemoryInfoGetMemType = (DOrtMemoryInfoGetMemType)Marshal.GetDelegateForFunctionPointer(api_.MemoryInfoGetMemType, typeof(DOrtMemoryInfoGetMemType)); OrtMemoryInfoGetType = (DOrtMemoryInfoGetType)Marshal.GetDelegateForFunctionPointer(api_.MemoryInfoGetType, typeof(DOrtMemoryInfoGetType)); OrtGetAllocatorWithDefaultOptions = (DOrtGetAllocatorWithDefaultOptions)Marshal.GetDelegateForFunctionPointer(api_.GetAllocatorWithDefaultOptions, typeof(DOrtGetAllocatorWithDefaultOptions)); + OrtCreateMemoryInfoV2 = (DOrtCreateMemoryInfoV2)Marshal.GetDelegateForFunctionPointer(api_.CreateMemoryInfo_V2, typeof(DOrtCreateMemoryInfoV2)); + OrtMemoryInfoGetDeviceMemType = (DOrtMemoryInfoGetDeviceMemType)Marshal.GetDelegateForFunctionPointer(api_.MemoryInfoGetDeviceMemType, typeof(DOrtMemoryInfoGetDeviceMemType)); + OrtMemoryInfoGetVendorId = (DOrtMemoryInfoGetVendorId)Marshal.GetDelegateForFunctionPointer(api_.MemoryInfoGetVendorId, typeof(DOrtMemoryInfoGetVendorId)); OrtCreateAllocator = (DOrtCreateAllocator)Marshal.GetDelegateForFunctionPointer(api_.CreateAllocator, typeof(DOrtCreateAllocator)); OrtReleaseAllocator = (DOrtReleaseAllocator)Marshal.GetDelegateForFunctionPointer(api_.ReleaseAllocator, typeof(DOrtReleaseAllocator)); OrtAllocatorAlloc = (DOrtAllocatorAlloc)Marshal.GetDelegateForFunctionPointer(api_.AllocatorAlloc, typeof(DOrtAllocatorAlloc)); @@ -610,6 +616,7 @@ static NativeMethods() OrtTensorAt = (DOrtTensorAt)Marshal.GetDelegateForFunctionPointer(api_.TensorAt, typeof(DOrtTensorAt)); OrtCreateAndRegisterAllocator = (DOrtCreateAndRegisterAllocator)Marshal.GetDelegateForFunctionPointer(api_.CreateAndRegisterAllocator, typeof(DOrtCreateAndRegisterAllocator)); + OrtUnregisterAllocator = (DOrtUnregisterAllocator)Marshal.GetDelegateForFunctionPointer(api_.UnregisterAllocator, typeof(DOrtUnregisterAllocator)); OrtSetLanguageProjection = (DOrtSetLanguageProjection)Marshal.GetDelegateForFunctionPointer(api_.SetLanguageProjection, typeof(DOrtSetLanguageProjection)); OrtHasValue = (DOrtHasValue)Marshal.GetDelegateForFunctionPointer(api_.HasValue, typeof(DOrtHasValue)); @@ -696,11 +703,11 @@ static NativeMethods() OrtReleaseROCMProviderOptions = (DOrtReleaseROCMProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.ReleaseROCMProviderOptions, typeof(DOrtReleaseROCMProviderOptions)); OrtCreateAndRegisterAllocatorV2 = (DCreateAndRegisterAllocatorV2)Marshal.GetDelegateForFunctionPointer(api_.CreateAndRegisterAllocatorV2, typeof(DCreateAndRegisterAllocatorV2)); OrtRunAsync = (DOrtRunAsync)Marshal.GetDelegateForFunctionPointer(api_.RunAsync, typeof(DOrtRunAsync)); - CreateLoraAdapter = (DCreateLoraAdapter)Marshal.GetDelegateForFunctionPointer(api_.CreateLoraAdapter, - typeof(DCreateLoraAdapter)); - CreateLoraAdapterFromArray = (DCreateLoraAdapterFromArray)Marshal.GetDelegateForFunctionPointer (api_.CreateLoraAdapterFromArray, typeof(DCreateLoraAdapterFromArray)); - ReleaseLoraAdapter = (DReleaseLoraAdapter)Marshal.GetDelegateForFunctionPointer(api_.ReleaseLoraAdapter, - typeof(DReleaseLoraAdapter)); + OrtCreateLoraAdapter = (DOrtCreateLoraAdapter)Marshal.GetDelegateForFunctionPointer(api_.CreateLoraAdapter, + typeof(DOrtCreateLoraAdapter)); + OrtCreateLoraAdapterFromArray = 
(DOrtCreateLoraAdapterFromArray)Marshal.GetDelegateForFunctionPointer(api_.CreateLoraAdapterFromArray, typeof(DOrtCreateLoraAdapterFromArray)); + OrtReleaseLoraAdapter = (DOrtReleaseLoraAdapter)Marshal.GetDelegateForFunctionPointer(api_.ReleaseLoraAdapter, + typeof(DOrtReleaseLoraAdapter)); OrtRunOptionsAddActiveLoraAdapter = (DOrtRunOptionsAddActiveLoraAdapter)Marshal.GetDelegateForFunctionPointer( api_.RunOptionsAddActiveLoraAdapter, typeof(DOrtRunOptionsAddActiveLoraAdapter)); @@ -759,12 +766,15 @@ static NativeMethods() OrtEpDevice_Device = (DOrtEpDevice_Device)Marshal.GetDelegateForFunctionPointer( api_.EpDevice_Device, typeof(DOrtEpDevice_Device)); - OrtRegisterExecutionProviderLibrary = + OrtEpDevice_MemoryInfo = (DOrtEpDevice_MemoryInfo)Marshal.GetDelegateForFunctionPointer( + api_.EpDevice_MemoryInfo, typeof(DOrtEpDevice_MemoryInfo)); + + OrtRegisterExecutionProviderLibrary = (DOrtRegisterExecutionProviderLibrary)Marshal.GetDelegateForFunctionPointer( api_.RegisterExecutionProviderLibrary, typeof(DOrtRegisterExecutionProviderLibrary)); - OrtUnregisterExecutionProviderLibrary = + OrtUnregisterExecutionProviderLibrary = (DOrtUnregisterExecutionProviderLibrary)Marshal.GetDelegateForFunctionPointer( api_.UnregisterExecutionProviderLibrary, typeof(DOrtUnregisterExecutionProviderLibrary)); @@ -773,12 +783,12 @@ static NativeMethods() api_.GetEpDevices, typeof(DOrtGetEpDevices)); - OrtSessionOptionsAppendExecutionProvider_V2 = + OrtSessionOptionsAppendExecutionProvider_V2 = (DOrtSessionOptionsAppendExecutionProvider_V2)Marshal.GetDelegateForFunctionPointer( api_.SessionOptionsAppendExecutionProvider_V2, typeof(DOrtSessionOptionsAppendExecutionProvider_V2)); - OrtSessionOptionsSetEpSelectionPolicy = + OrtSessionOptionsSetEpSelectionPolicy = (DSessionOptionsSetEpSelectionPolicy)Marshal.GetDelegateForFunctionPointer( api_.SessionOptionsSetEpSelectionPolicy, typeof(DSessionOptionsSetEpSelectionPolicy)); @@ -817,6 +827,40 @@ static NativeMethods() api_.CreateExternalInitializerInfo, typeof(DOrtCreateExternalInitializerInfo)); + OrtCreateSharedAllocator = + (DOrtCreateSharedAllocator)Marshal.GetDelegateForFunctionPointer( + api_.CreateSharedAllocator, + typeof(DOrtCreateSharedAllocator)); + + OrtGetSharedAllocator = + (DOrtGetSharedAllocator)Marshal.GetDelegateForFunctionPointer( + api_.GetSharedAllocator, + typeof(DOrtGetSharedAllocator)); + + OrtReleaseSharedAllocator = + (DOrtReleaseSharedAllocator)Marshal.GetDelegateForFunctionPointer( + api_.ReleaseSharedAllocator, + typeof(DOrtReleaseSharedAllocator)); + + OrtCreateSyncStreamForEpDevice = + (DOrtCreateSyncStreamForEpDevice)Marshal.GetDelegateForFunctionPointer( + api_.CreateSyncStreamForEpDevice, + typeof(DOrtCreateSyncStreamForEpDevice)); + + OrtSyncStream_GetHandle = + (DOrtSyncStream_GetHandle)Marshal.GetDelegateForFunctionPointer( + api_.SyncStream_GetHandle, + typeof(DOrtSyncStream_GetHandle)); + + OrtReleaseSyncStream = + (DOrtReleaseSyncStream)Marshal.GetDelegateForFunctionPointer( + api_.ReleaseSyncStream, + typeof(DOrtReleaseSyncStream)); + + OrtCopyTensors = + (DOrtCopyTensors)Marshal.GetDelegateForFunctionPointer( + api_.CopyTensors, + typeof(DOrtCopyTensors)); } internal class NativeLib @@ -839,7 +883,7 @@ internal class NativeLib public static extern ref OrtApiBase OrtGetApiBase(); #endif -#region Runtime / Environment API + #region Runtime / Environment API [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* OrtStatus* */ DOrtCreateEnv( @@ -896,9 +940,59 @@ internal class NativeLib public delegate 
IntPtr /* OrtStatus* */ DOrtUpdateEnvWithCustomLogLevel(IntPtr /*(OrtEnv*)*/ env, OrtLoggingLevel custom_log_level); public static DOrtUpdateEnvWithCustomLogLevel OrtUpdateEnvWithCustomLogLevel; -#endregion Runtime / Environment API + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DCreateAndRegisterAllocatorV2( + IntPtr /* OrtEnv* */ environment, + IntPtr /*const char* */ provderType, + IntPtr /* const OrtMemoryInfo* */ memInfo, + IntPtr /* const OrtArenaCfg* */ arenaCfg, + IntPtr[] /* const char* const* */ providerOptionsKeys, + IntPtr[] /* const char* const* */ providerOptionsValues, + UIntPtr /* size_t */ numKeys); + public static DCreateAndRegisterAllocatorV2 OrtCreateAndRegisterAllocatorV2; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtCreateSharedAllocator( + IntPtr /* OrtEnv* */ ortEnv, + IntPtr /* OrtEpDevice* */ epDevice, + OrtDeviceMemoryType deviceMemoryType, + OrtAllocatorType allocatorType, + IntPtr /* const OrtKeyValuePairs* */ allocatorOptions, + out IntPtr /* OrtAllocator** */ allocator); + + public static DOrtCreateSharedAllocator OrtCreateSharedAllocator; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtGetSharedAllocator( + IntPtr /*(OrtEnv*)*/ env, + IntPtr /*(const OrtMemoryInfo*)*/ memInfo, + out IntPtr /* OrtAllocator** */ allocator); + + public static DOrtGetSharedAllocator OrtGetSharedAllocator; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtReleaseSharedAllocator( + IntPtr /*(OrtEnv*)*/ env, + IntPtr /* const OrtEpDevice* */ epDevice, + OrtDeviceMemoryType deviceMemoryType); + + public static DOrtReleaseSharedAllocator OrtReleaseSharedAllocator; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtCopyTensors( + IntPtr /* const OrtEnv* */ env, + IntPtr[] /* const OrtValue* const* */ srcTensors, + IntPtr[] /* OrtValue* const* */ dstTensors, + IntPtr /* OrtSynStream* */ stream, + UIntPtr /* size_t */ numTensors + ); + + public static DOrtCopyTensors OrtCopyTensors; -#region Provider Options API + + #endregion Runtime / Environment API + + #region Provider Options API /// /// Creates native OrtTensorRTProviderOptions instance @@ -1032,9 +1126,9 @@ internal class NativeLib public delegate void DOrtReleaseROCMProviderOptions(IntPtr /*(OrtROCMProviderOptions*)*/ rocmProviderOptionsInstance); public static DOrtReleaseROCMProviderOptions OrtReleaseROCMProviderOptions; -#endregion + #endregion -#region Status API + #region Status API [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate ErrorCode DOrtGetErrorCode(IntPtr /*(OrtStatus*)*/ status); public static DOrtGetErrorCode OrtGetErrorCode; @@ -1049,12 +1143,12 @@ internal class NativeLib public delegate void DOrtReleaseStatus(IntPtr /*(OrtStatus*)*/ statusPtr); public static DOrtReleaseStatus OrtReleaseStatus; -#endregion Status API + #endregion Status API -#region InferenceSession API + #region InferenceSession API [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* OrtStatus* */ DOrtCreateStatus( - uint /* OrtErrorCode */ code, + uint /* OrtErrorCode */ code, byte[] /* const char* */ msg); public static DOrtCreateStatus OrtCreateStatus; @@ -1216,6 +1310,30 @@ IntPtr[] outputValues /* An array of output value pointers. 
Array must be alloca public static DOrtSessionGetOverridableInitializerTypeInfo OrtSessionGetOverridableInitializerTypeInfo; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionGetMemoryInfoForInputs( + IntPtr /*(const OrtSession*)*/ session, + IntPtr[] /* const OrtMemoryInfo** */ inputsMemoryInfos, + UIntPtr /* size_t */ numInputs); + + public static DOrtSessionGetMemoryInfoForInputs OrtSessionGetMemoryInfoForInputs; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionGetMemoryInfoForOutputs( + IntPtr /*(const OrtSession*)*/ session, + IntPtr[] /* OrtMemoryInfo** */ outputsMemoryInfos, + UIntPtr /* size_t */ numOutputs); + + public static DOrtSessionGetMemoryInfoForOutputs OrtSessionGetMemoryInfoForOutputs; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionGetEpDeviceForInputs( + IntPtr /*(const OrtSession*)*/ session, + IntPtr[] /* const OrtDevice** */ devices, + UIntPtr /* size_t */ numInputs); + + public static DOrtSessionGetEpDeviceForInputs OrtSessionGetEpDeviceForInputs; + // release the typeinfo using OrtReleaseTypeInfo [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate void DOrtReleaseTypeInfo(IntPtr /*(OrtTypeInfo*)*/ session); @@ -1231,17 +1349,6 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca out UIntPtr /*(ulong* out)*/ startTime); public static DOrtSessionGetProfilingStartTimeNs OrtSessionGetProfilingStartTimeNs; - [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(ONNStatus*)*/ DCreateAndRegisterAllocatorV2( - IntPtr /* (OrtEnv*) */ environment, - IntPtr /*(char*)*/ provider_type, - IntPtr /*(OrtMemoryInfo*)*/ mem_info, - IntPtr /*(OrtArenaCfg*)*/ arena_cfg, - IntPtr /*(char**)*/ provider_options_keys, - IntPtr /*(char**)*/ provider_options_values, - UIntPtr /*(size_t)*/ num_keys); - public static DCreateAndRegisterAllocatorV2 OrtCreateAndRegisterAllocatorV2; - [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(ONNStatus*)*/ DOrtRunAsync( IntPtr /*(OrtSession*)*/ session, @@ -1256,9 +1363,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca IntPtr /*(void*)*/ user_data); public static DOrtRunAsync OrtRunAsync; -#endregion InferenceSession API + #endregion InferenceSession API -#region SessionOptions API + #region SessionOptions API [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtCreateSessionOptions(out IntPtr /*(OrtSessionOptions**)*/ sessionOptions); @@ -1546,9 +1653,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca public static DSessionOptionsAppendExecutionProvider SessionOptionsAppendExecutionProvider; -#endregion + #endregion -#region LoraAdapter API + #region LoraAdapter API /// /// Memory maps the adapter file, wraps it into the adapter object /// and returns it. @@ -1558,12 +1665,12 @@ IntPtr[] outputValues /* An array of output value pointers. 
Array must be alloca /// New LoraAdapter object /// [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/ DCreateLoraAdapter( + public delegate IntPtr /*(OrtStatus*)*/ DOrtCreateLoraAdapter( byte[] adapter_path, // This takes const ORTCHAR_T* use GetPlatformSerializedString IntPtr /* OrtAllocator */ allocator, // optional out IntPtr lora_adapter ); - public static DCreateLoraAdapter CreateLoraAdapter; + public static DOrtCreateLoraAdapter OrtCreateLoraAdapter; /// /// Creates LoraAdapter instance from a byte array that must @@ -1575,22 +1682,22 @@ out IntPtr lora_adapter /// resulting LoraAdapter instance /// [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate IntPtr /*(OrtStatus*)*/ DCreateLoraAdapterFromArray( + public delegate IntPtr /*(OrtStatus*)*/ DOrtCreateLoraAdapterFromArray( byte[] bytes, UIntPtr size, IntPtr /* OrtAllocator */ allocator, // optional out IntPtr lora_adapter ); - public static DCreateLoraAdapterFromArray CreateLoraAdapterFromArray; + public static DOrtCreateLoraAdapterFromArray OrtCreateLoraAdapterFromArray; [UnmanagedFunctionPointer(CallingConvention.Winapi)] - public delegate void DReleaseLoraAdapter(IntPtr /* OrtLoraAdapter* */ lora_adapter); - public static DReleaseLoraAdapter ReleaseLoraAdapter; + public delegate void DOrtReleaseLoraAdapter(IntPtr /* OrtLoraAdapter* */ lora_adapter); + public static DOrtReleaseLoraAdapter OrtReleaseLoraAdapter; -#endregion + #endregion -#region RunOptions API + #region RunOptions API [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtCreateRunOptions(out IntPtr /* OrtRunOptions** */ runOptions); @@ -1653,9 +1760,9 @@ out IntPtr lora_adapter byte[] /* const char* */ configValue); public static DOrtAddRunConfigEntry OrtAddRunConfigEntry; -#endregion + #endregion -#region ThreadingOptions API + #region ThreadingOptions API [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtCreateThreadingOptions(out IntPtr /* OrtCreateThreadingOptions** */ threadingOptions); @@ -1680,9 +1787,9 @@ out IntPtr lora_adapter [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtThreadingOptionsSetGlobalSpinControl(IntPtr /* OrtThreadingOptions* */ threadingOptions, int allowSpinning); public static DOrtThreadingOptionsSetGlobalSpinControl OrtThreadingOptionsSetGlobalSpinControl; -#endregion + #endregion -#region Allocator / MemoryInfo API + #region Allocator / MemoryInfo API [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* (OrtStatus*)*/ DOrtCreateMemoryInfo( @@ -1695,6 +1802,20 @@ out IntPtr lora_adapter public static DOrtCreateMemoryInfo OrtCreateMemoryInfo; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* (OrtStatus*)*/ DOrtCreateMemoryInfoV2( + byte[] /*(const char*) */ name, + OrtMemoryInfoDeviceType memInfoDeviceType, + UInt32 /* uint32_t */ vendorId, + Int32 /* int32_t */ deviceId, + OrtDeviceMemoryType deviceMemoryType, + UIntPtr /* size_t */ alignment, + OrtAllocatorType allocatorType, + out IntPtr /*(OrtMemoryInfo*)*/ allocatorInfo // memory ownership transferred to caller + ); + + public static DOrtCreateMemoryInfoV2 OrtCreateMemoryInfoV2; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* (OrtStatus*)*/ DOrtCreateCpuMemoryInfo( OrtAllocatorType allocatorType, @@ -1743,6 +1864,18 @@ out IntPtr lora_adapter public static DOrtMemoryInfoGetType 
OrtMemoryInfoGetType; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate OrtDeviceMemoryType DOrtMemoryInfoGetDeviceMemType( + IntPtr /*(const OrtMemoryInfo* ptr)*/ memoryInfo); + + public static DOrtMemoryInfoGetDeviceMemType OrtMemoryInfoGetDeviceMemType; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate UInt32 DOrtMemoryInfoGetVendorId( + IntPtr /*(const OrtMemoryInfo* ptr)*/ memoryInfo); + + public static DOrtMemoryInfoGetVendorId OrtMemoryInfoGetVendorId; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtGetAllocatorWithDefaultOptions(out IntPtr /*(OrtAllocator**)*/ allocator); @@ -1819,9 +1952,9 @@ out IntPtr lora_adapter public static DOrtAllocatorFree OrtAllocatorFree; -#endregion Allocator / MemoryInfo API + #endregion Allocator / MemoryInfo API -#region IoBinding API + #region IoBinding API /// /// Create OrtIoBinding instance that is used to bind memory that is allocated @@ -1985,7 +2118,8 @@ out IntPtr lora_adapter /// /// Creates an allocator instance and registers it with the env to enable /// sharing between multiple sessions that use the same env instance. - /// Lifetime of the created allocator will be valid for the duration of the environment. + /// Lifetime of the created allocator will be valid for the duration of the environment + /// or until it is explicitly unregistered by UnregisterAllocator. /// Returns an error if an allocator with the same OrtMemoryInfo is already registered. /// /// Native OrtEnv instance @@ -1999,6 +2133,20 @@ out IntPtr lora_adapter public static DOrtCreateAndRegisterAllocator OrtCreateAndRegisterAllocator; + + /// + /// Unregisters an allocator that was previously registered with the env using + /// or . + /// + /// valid env + /// meminfo used for registering the allocator + /// + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtUnregisterAllocator(IntPtr /*(OrtEnv*)*/ env, + IntPtr /*(const OrtMemoryInfo*)*/ memInfo); + + public static DOrtUnregisterAllocator OrtUnregisterAllocator; + /// /// Set the language projection for collecting telemetry data when Env is created /// @@ -2009,9 +2157,9 @@ out IntPtr lora_adapter public static DOrtSetLanguageProjection OrtSetLanguageProjection; -#endregion IoBinding API + #endregion IoBinding API -#region ModelMetadata API + #region ModelMetadata API /// /// Gets the ModelMetadata associated with an InferenceSession @@ -2129,9 +2277,9 @@ out IntPtr lora_adapter public static DOrtReleaseModelMetadata OrtReleaseModelMetadata; -#endregion ModelMetadata API + #endregion ModelMetadata API -#region OrtValue API + #region OrtValue API [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtHasValue(IntPtr /*(OrtValue*)*/ value, out IntPtr /*(int*)*/ hasValue); @@ -2397,9 +2545,9 @@ out IntPtr lora_adapter public static DOrtReleaseValue OrtReleaseValue; -#endregion + #endregion -#region Compile API + #region Compile API #if NETSTANDARD2_0 [UnmanagedFunctionPointer(CallingConvention.Winapi)] @@ -2473,9 +2621,10 @@ out IntPtr /* OrtExternalInitializerInfo** */ newExternalInfo public static DOrtExternalInitializerInfo_GetFilePath OrtExternalInitializerInfo_GetFilePath; public static DOrtExternalInitializerInfo_GetFileOffset OrtExternalInitializerInfo_GetFileOffset; public static DOrtExternalInitializerInfo_GetByteSize OrtExternalInitializerInfo_GetByteSize; -#endregion -#region Auto EP API related + #endregion + + #region Auto 
EP API related // // OrtKeyValuePairs @@ -2582,12 +2731,36 @@ public delegate void DOrtRemoveKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps, public delegate IntPtr /* const OrtHardwareDevice* */ DOrtEpDevice_Device( IntPtr /* const OrtEpDevice* */ ep_device); + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* const OrtMemoryInfo* */ DOrtEpDevice_MemoryInfo( + IntPtr /* const OrtEpDevice* */ ep_device, OrtDeviceMemoryType deviceMemoryType); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtCreateSyncStreamForEpDevice( + IntPtr /* const OrtEpDevice* */ epDevice, + IntPtr /* const OrtKeyValuePairs* */ streamOptions, + out IntPtr /* OrtSyncStream** */ stream + ); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* void* */ DOrtSyncStream_GetHandle( + IntPtr /* OrtSyncStream* */ stream + ); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtReleaseSyncStream( + IntPtr /* OrtSyncStream* */ stream + ); public static DOrtEpDevice_EpName OrtEpDevice_EpName; public static DOrtEpDevice_EpVendor OrtEpDevice_EpVendor; public static DOrtEpDevice_EpMetadata OrtEpDevice_EpMetadata; public static DOrtEpDevice_EpOptions OrtEpDevice_EpOptions; public static DOrtEpDevice_Device OrtEpDevice_Device; + public static DOrtEpDevice_MemoryInfo OrtEpDevice_MemoryInfo; + public static DOrtCreateSyncStreamForEpDevice OrtCreateSyncStreamForEpDevice; + public static DOrtSyncStream_GetHandle OrtSyncStream_GetHandle; + public static DOrtReleaseSyncStream OrtReleaseSyncStream; // // Auto Selection EP registration and selection customization @@ -2763,7 +2936,7 @@ public delegate IntPtr DOrtEpSelectionDelegate( public static DOrtReleasePrepackedWeightsContainer OrtReleasePrepackedWeightsContainer; -#endregion + #endregion } // class NativeMethods // onnxruntime-extensions helpers to make usage simpler. 
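The registrations above only wire up the native entry points; the managed surface that consumes them (`OrtEnv.CreateSharedAllocator`, `OrtEpDevice.CreateSyncStream`, `OrtEnv.CopyTensors`) is added further down in this patch. A self-contained usage sketch, with the device selection and allocator options as placeholder assumptions rather than recommendations, could look like:

```csharp
// Usage sketch only (not part of the patch): exercises the shared-allocator and
// CopyTensors bindings through the managed wrappers added later in this change.
// Picking the first OrtEpDevice is a placeholder; a real caller would select the
// device belonging to the EP it registered.
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using System.Collections.Generic;
using System.Linq;

internal static class SharedAllocatorCopyTensorsExample
{
    internal static void Run()
    {
        var env = OrtEnv.Instance();
        var epDevice = env.GetEpDevices().First();

        // Shared allocator owned by the environment for this device's default memory.
        var options = new Dictionary<string, string>
        {
            { "arena.initial_chunk_size_bytes", "25600" }
        };
        using var allocator = env.CreateSharedAllocator(epDevice, OrtDeviceMemoryType.DEFAULT,
                                                        OrtAllocatorType.DeviceAllocator, options);
        try
        {
            long[] dims = { 3, 2 };
            float[] data = { 1, 2, 3, 4, 5, 6 };

            // Source tensor on the CPU, destination tensor allocated on the device.
            using var src = OrtValue.CreateTensorValueFromMemory(data, dims);
            using var dst = OrtValue.CreateAllocatedTensorValue(allocator, TensorElementType.Float, dims);

            // Optional sync stream for the device; null stream options are allowed.
            using var stream = epDevice.CreateSyncStream(null);
            env.CopyTensors(new[] { src }, new[] { dst }, stream);
        }
        finally
        {
            env.ReleaseSharedAllocator(epDevice, OrtDeviceMemoryType.DEFAULT);
        }
    }
}
```

This mirrors the flow exercised by `TestCopyTensors` later in the patch: allocate on the device through the env-owned shared allocator, copy host tensors to it, and release the shared allocator when done.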
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs index 4611428ea12ef..f5dc253195ab1 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.shared.cs @@ -373,23 +373,21 @@ internal static void Update(Dictionary providerOptions, IntPtr handle, Func updateFunc) { - var keyStrings = providerOptions.Keys.ToArray(); - var valStrings = providerOptions.Values.ToArray(); - MarshaledStringArray keys = default; MarshaledStringArray values = default; try { - keys = new MarshaledStringArray(keyStrings); - values = new MarshaledStringArray(valStrings); + keys = new MarshaledStringArray(providerOptions.Keys); + values = new MarshaledStringArray(providerOptions.Values); - var nativeKeys = new IntPtr[keyStrings.Length]; + var nativeKeys = new IntPtr[providerOptions.Count]; keys.Fill(nativeKeys); - var nativeVals = new IntPtr[valStrings.Length]; + var nativeVals = new IntPtr[providerOptions.Count]; values.Fill(nativeVals); - NativeApiStatus.VerifySuccess(updateFunc(handle, nativeKeys, nativeVals, (UIntPtr)providerOptions.Count)); + NativeApiStatus.VerifySuccess(updateFunc(handle, nativeKeys, nativeVals, + (UIntPtr)providerOptions.Count)); } finally { diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtAllocator.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtAllocator.shared.cs index 3f918fc2ad6c8..c189cc1856252 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtAllocator.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtAllocator.shared.cs @@ -3,6 +3,7 @@ using Microsoft.ML.OnnxRuntime.Tensors; using System; +using System.Reflection; using System.Runtime.InteropServices; using System.Text; @@ -28,6 +29,28 @@ public enum OrtMemType Default = 0, // the default allocator for execution provider } + /// + /// See documentation for OrtDeviceMemoryType in C API + /// This matches OrtDevice::MemoryType values + /// + public enum OrtDeviceMemoryType + { + DEFAULT = 0, /// Device memory + HOST_ACCESSIBLE = 5, /// Shared/pinned memory for transferring between CPU and the device + } + + /// + /// See documentation for OrtMemoryInfoDeviceType in C API + /// This mimics OrtDevice type constants so they can be returned in the API + /// + public enum OrtMemoryInfoDeviceType + { + CPU = 0, + GPU = 1, + FPGA = 2, + NPU = 3, + } + /// /// This class encapsulates arena configuration information that will be used to define the behavior /// of an arena based allocator @@ -103,7 +126,8 @@ public class OrtMemoryInfo : SafeHandle private static OrtMemoryInfo CreateCpuMemoryInfo() { // Returns OrtMemoryInfo instance that needs to be disposed - NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateCpuMemoryInfo(OrtAllocatorType.DeviceAllocator, OrtMemType.Cpu, out IntPtr memoryInfo)); + NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateCpuMemoryInfo(OrtAllocatorType.DeviceAllocator, + OrtMemType.Cpu, out IntPtr memoryInfo)); return new OrtMemoryInfo(memoryInfo, true); } @@ -203,6 +227,26 @@ public OrtMemoryInfo(byte[] utf8AllocatorName, OrtAllocatorType allocatorType, i public OrtMemoryInfo(string allocatorName, OrtAllocatorType allocatorType, int deviceId, OrtMemType memoryType) : this(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(allocatorName), allocatorType, deviceId, memoryType) { + + } + + /// + /// Creates an instance of OrtMemoryInfo using OrtCreateMemoryInfoV2 + /// + /// In this overload this is an arbitrary name + /// 
Device Type + /// Vendor Id + /// Device Id + /// Device Memory Type + /// Alignment is required or 0 + /// Allocator Type + public OrtMemoryInfo(string allocatorName, OrtMemoryInfoDeviceType deviceType, uint vendorId, + int deviceId, OrtDeviceMemoryType deviceMemoryType, ulong alignment, OrtAllocatorType allocatorType) + : base(IntPtr.Zero, true) + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateMemoryInfoV2( + NativeOnnxValueHelper.StringToZeroTerminatedUtf8(allocatorName), + deviceType, vendorId, deviceId, deviceMemoryType, (UIntPtr)alignment, allocatorType, out handle)); } /// @@ -252,6 +296,24 @@ public OrtAllocatorType GetAllocatorType() return allocatorType; } + /// + /// Return the device memory type associated with this memory info + /// + /// OrtDeviceMemoryType for the device + public OrtDeviceMemoryType GetDeviceMemoryType() + { + return NativeMethods.OrtMemoryInfoGetDeviceMemType(handle); + } + + /// + /// Fetches vendor ID + /// + /// uint32_t + public uint GetVendorId() + { + return NativeMethods.OrtMemoryInfoGetVendorId(handle); + } + /// /// Overrides System.Object.Equals(object) /// @@ -493,12 +555,6 @@ internal IntPtr Pointer } } - /// - /// Overrides SafeHandle.IsInvalid - /// - /// returns true if handle is equal to Zero - public override bool IsInvalid { get { return handle == IntPtr.Zero; } } - /// /// Internal constructor wraps existing native allocators /// @@ -560,6 +616,14 @@ internal void FreeMemory(IntPtr allocation) } #region SafeHandle + + /// + /// Overrides SafeHandle.IsInvalid + /// + /// returns true if handle is equal to Zero + public override bool IsInvalid { get { return handle == IntPtr.Zero; } } + + /// /// Overrides SafeHandle.ReleaseHandle() to properly dispose of /// the native instance of OrtAllocator diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs index 052d5899b52c0..6fcff438c5cf3 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs @@ -329,14 +329,115 @@ public void DisableTelemetryEvents() } /// - /// Create and register an allocator to the OrtEnv instance - /// so as to enable sharing across all sessions using the OrtEnv instance + /// Create and register an allocator to the OrtEnv instance. + /// This API enhance CreateAndRegisterAllocator that it can create an allocator with specific type, not just CPU allocator + /// Enables sharing the allocator between multiple sessions that use the same env instance. + /// Lifetime of the created allocator will be valid for the duration of the environment. + /// so as to enable sharing across all sessions using the OrtEnv instance. /// OrtMemoryInfo instance to be used for allocator creation /// OrtArenaCfg instance that will be used to define the behavior of the arena based allocator /// public void CreateAndRegisterAllocator(OrtMemoryInfo memInfo, OrtArenaCfg arenaCfg) { - NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateAndRegisterAllocator(Handle, memInfo.Pointer, arenaCfg.Pointer)); + NativeApiStatus.VerifySuccess( + NativeMethods.OrtCreateAndRegisterAllocator(Handle, memInfo.Pointer, arenaCfg.Pointer)); + } + + /// + /// Create and register an allocator to the OrtEnv instance. + /// Use UnregisterAllocator to unregister it. 
+ /// + /// + /// + /// + /// + public void CreateAndRegisterAllocator(string providerType, OrtMemoryInfo memInfo, OrtArenaCfg arenaCfg, + IReadOnlyDictionary provider_options) + { + MarshaledStringArray marshalledKeys = default; + MarshaledStringArray marshalledValues = default; + var keysPtrs = new IntPtr[provider_options.Count]; + var valuesPtrs = new IntPtr[provider_options.Count]; + + try + { + marshalledKeys = new MarshaledStringArray(provider_options.Keys); + marshalledValues = new MarshaledStringArray(provider_options.Values); + marshalledKeys.Fill(keysPtrs); + marshalledValues.Fill(valuesPtrs); + using var marshalledProviderType = new MarshaledString(providerType); + + NativeApiStatus.VerifySuccess( + NativeMethods.OrtCreateAndRegisterAllocatorV2(Handle, marshalledProviderType.Value, + memInfo.Pointer, arenaCfg.Pointer, + keysPtrs, valuesPtrs, + (UIntPtr)provider_options.Count)); + } + finally + { + marshalledValues.Dispose(); + marshalledKeys.Dispose(); + } + } + + /// + /// Unregister a custom allocator previously registered with the OrtEnv instance + /// using CreateAndRegisterAllocator + /// The memory info instance should correspond the one that is used for registration + /// + /// The memory info instance should correspond the one that is used for registration + public void UnregisterAllocator(OrtMemoryInfo memInfo) + { + NativeApiStatus.VerifySuccess( + NativeMethods.OrtUnregisterAllocator(Handle, memInfo.Pointer)); + } + + /// + /// Creates shared allocator owned by the OrtEnv instance. + /// + /// + /// + /// + /// allocator specific options + /// OrtAllocator instance + public OrtAllocator CreateSharedAllocator(OrtEpDevice epDevice, OrtDeviceMemoryType deviceMemoryType, + OrtAllocatorType ortAllocatorType, IReadOnlyDictionary allocatorOptions) + { + using var keyValueOptions = new OrtKeyValuePairs(allocatorOptions); + NativeApiStatus.VerifySuccess( + NativeMethods.OrtCreateSharedAllocator(Handle, epDevice.Handle, deviceMemoryType, + ortAllocatorType, keyValueOptions.Handle, out IntPtr allocatorHandle)); + return new OrtAllocator(allocatorHandle, /* owned= */ false); + } + + /// + /// Returns a shared allocator owned by the OrtEnv instance if such exists + /// (was previously created). If no such allocator exists, the API returns null. + /// + /// + /// OrtAllocator instance or null if the requested allocator does not exist + public OrtAllocator GetSharedAllocator(OrtMemoryInfo memoryInfo) + { + NativeApiStatus.VerifySuccess( + NativeMethods.OrtGetSharedAllocator(Handle, memoryInfo.Pointer, out IntPtr allocatorHandle)); + if (allocatorHandle == IntPtr.Zero) + { + return null; + } + return new OrtAllocator(allocatorHandle, /* owned= */ false); + } + + /// + /// Release a shared allocator from the OrtEnv for the OrtEpDevice and memory type. + /// This will release the shared allocator for the given OrtEpDevice and memory type. + /// If no shared allocator exists, this is a no-op. + /// + /// + /// + public void ReleaseSharedAllocator(OrtEpDevice epDevice, OrtDeviceMemoryType deviceMemoryType) + { + NativeApiStatus.VerifySuccess( + NativeMethods.OrtReleaseSharedAllocator(Handle, epDevice.Handle, deviceMemoryType)); } /// @@ -477,7 +578,37 @@ public IReadOnlyList GetEpDevices() } return epDevices.AsReadOnly(); - } + } + + /// + /// Copies data from source OrtValue tensors to destination OrtValue tensors. + /// The tensors may reside on difference devices if such are supported + /// by the registered execution providers. 
+ /// + /// Source OrtValues + /// pre-allocated OrtValues + /// optional stream or null + /// + public void CopyTensors(IReadOnlyList srcValues, IReadOnlyList dstValues, + OrtSyncStream stream) + { + IntPtr streamHandle = stream != null ? stream.Handle : IntPtr.Zero; + IntPtr[] srcPtrs = new IntPtr[srcValues.Count]; + IntPtr[] dstPtrs = new IntPtr[dstValues.Count]; + + for (int i = 0; i < srcPtrs.Length; i++) + { + if (srcValues[i] == null) + throw new ArgumentNullException($"srcValues[{i}]"); + if (dstValues[i] == null) + throw new ArgumentNullException($"dstValues[{i}]"); + srcPtrs[i] = srcValues[i].Handle; + dstPtrs[i] = dstValues[i].Handle; + } + + NativeApiStatus.VerifySuccess( + NativeMethods.OrtCopyTensors(handle, srcPtrs, dstPtrs, streamHandle, (UIntPtr)srcPtrs.Length)); + } #endregion diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEpDevice.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEpDevice.shared.cs index 0318e08519128..9e59754374464 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEpDevice.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEpDevice.shared.cs @@ -2,10 +2,51 @@ // Licensed under the MIT License. using System; +using System.Collections.Generic; using System.Runtime.InteropServices; namespace Microsoft.ML.OnnxRuntime { + /// + /// Represents a synchronization primitive for stream operations. + /// + public class OrtSyncStream : SafeHandle + { + internal OrtSyncStream(IntPtr streamHandle) + : base(IntPtr.Zero, true) // Provide required arguments to SafeHandle constructor + { + handle = streamHandle; + } + + /// + /// Fetch sync stream handle for possible use + /// in session options. + /// + /// Opaque stream handle + public IntPtr GetHandle() + { + return NativeMethods.OrtSyncStream_GetHandle(handle); + } + + internal IntPtr Handle => handle; + + /// + /// Implements SafeHandle interface + /// + public override bool IsInvalid => handle == IntPtr.Zero; + + /// + /// Implements SafeHandle interface to release native handle + /// + /// always true + protected override bool ReleaseHandle() + { + NativeMethods.OrtReleaseSyncStream(handle); + handle = IntPtr.Zero; + return true; + } + } + /// /// Represents the combination of an execution provider and a hardware device /// that the execution provider can utilize. @@ -81,6 +122,46 @@ public OrtHardwareDevice HardwareDevice } } + /// + /// The OrtMemoryInfo instance describing the memory characteristics of the device. + /// + /// memory type requested + /// + public OrtMemoryInfo GetMemoryInfo(OrtDeviceMemoryType deviceMemoryType) + { + IntPtr memoryInfoPtr = NativeMethods.OrtEpDevice_MemoryInfo(_handle, deviceMemoryType); + return new OrtMemoryInfo(memoryInfoPtr, /* owned= */ false); + } + + /// + /// Creates a synchronization stream for operations on this device. + /// Can be used to implement async operations on the device such as + /// CopyTensors. 
+ /// + /// stream options can be null + /// + public OrtSyncStream CreateSyncStream(IReadOnlyDictionary streamOptions) + { + OrtKeyValuePairs options = null; + IntPtr optionsHandle = IntPtr.Zero; + try + { + if (streamOptions != null) + { + options = new OrtKeyValuePairs(streamOptions); + optionsHandle = options.Handle; + } + + NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSyncStreamForEpDevice(_handle, + optionsHandle, out IntPtr syncStream)); + return new OrtSyncStream(syncStream); + } + finally + { + options?.Dispose(); + } + } + private readonly IntPtr _handle; } } \ No newline at end of file diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtKeyValuePairs.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtKeyValuePairs.shared.cs index 6a8d1037d9017..50fd1965231e1 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtKeyValuePairs.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtKeyValuePairs.shared.cs @@ -169,6 +169,11 @@ private Dictionary GetLatest() return dict; } + /// + /// Native handle to the OrtKeyValuePairs instance. + /// + internal IntPtr Handle => handle; + /// /// Indicates whether the native handle is invalid. /// diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtLoraAdapter.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtLoraAdapter.shared.cs index e2249b4c47fec..f1c03faccf16f 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtLoraAdapter.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtLoraAdapter.shared.cs @@ -23,7 +23,7 @@ public static OrtLoraAdapter Create(string adapterPath, OrtAllocator ortAllocato { var platformPath = NativeOnnxValueHelper.GetPlatformSerializedString(adapterPath); var allocatorHandle = (ortAllocator != null) ? ortAllocator.Pointer : IntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.CreateLoraAdapter(platformPath, allocatorHandle, + NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateLoraAdapter(platformPath, allocatorHandle, out IntPtr adapterHandle)); return new OrtLoraAdapter(adapterHandle); } @@ -38,7 +38,7 @@ public static OrtLoraAdapter Create(string adapterPath, OrtAllocator ortAllocato public static OrtLoraAdapter Create(byte[] bytes, OrtAllocator ortAllocator) { var allocatorHandle = (ortAllocator != null) ? 
ortAllocator.Pointer : IntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.CreateLoraAdapterFromArray(bytes, + NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateLoraAdapterFromArray(bytes, new UIntPtr((uint)bytes.Length), allocatorHandle, out IntPtr adapterHandle)); return new OrtLoraAdapter(adapterHandle); } @@ -71,7 +71,7 @@ internal IntPtr Handle /// always returns true protected override bool ReleaseHandle() { - NativeMethods.ReleaseLoraAdapter(handle); + NativeMethods.OrtReleaseLoraAdapter(handle); handle = IntPtr.Zero; return true; } diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs index 0a39d965979ca..73613541f8362 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs @@ -4,13 +4,10 @@ using Microsoft.ML.OnnxRuntime.Tensors; using System; using System.Collections.Generic; +using System.IO; using System.Linq; -using System.Linq.Expressions; -using System.Runtime.CompilerServices; using System.Runtime.InteropServices; -using System.Text.RegularExpressions; -using System.Threading; using System.Threading.Tasks; using Xunit; using Xunit.Abstractions; @@ -837,7 +834,7 @@ private async Task TestMultiThreads() Assert.Equal(res, expectedOut, (IEqualityComparer)new FloatComparer()); } })); - }; + } await Task.WhenAll(tasks); session.Dispose(); } @@ -1694,37 +1691,52 @@ private void TestInferenceSessionWithByteArray() void TestCPUAllocatorInternal(InferenceSession session) + { int device_id = 0; - using (var info_cpu = new OrtMemoryInfo(OrtMemoryInfo.allocatorCPU, OrtAllocatorType.ArenaAllocator, device_id, OrtMemType.Default)) - { - Assert.Equal("Cpu", info_cpu.Name); - Assert.Equal(device_id, info_cpu.Id); - Assert.Equal(OrtAllocatorType.ArenaAllocator, info_cpu.GetAllocatorType()); - Assert.Equal(OrtMemType.Default, info_cpu.GetMemoryType()); - - using (var allocator = new OrtAllocator(session, info_cpu)) - { - var alloc_info = allocator.Info; - // Allocator type returned may be different on x86 so we don't compare. - Assert.Equal(info_cpu.Name, alloc_info.Name); - Assert.Equal(info_cpu.GetMemoryType(), alloc_info.GetMemoryType()); - Assert.Equal(info_cpu.Id, alloc_info.Id); - - uint size = 1024; - OrtMemoryAllocation chunk = allocator.Allocate(size); - Assert.Equal(chunk.Size, size); - var chunk_info = chunk.Info; - // Allocator type returned may be different on x86 so we don't compare. - Assert.Equal(chunk_info.Name, alloc_info.Name); - Assert.Equal(chunk_info.GetMemoryType(), alloc_info.GetMemoryType()); - Assert.Equal(chunk_info.Id, alloc_info.Id); - chunk.Dispose(); - alloc_info.Dispose(); - } - } + using var info_cpu = new OrtMemoryInfo(OrtMemoryInfo.allocatorCPU, + OrtAllocatorType.ArenaAllocator, device_id, OrtMemType.Default); + Assert.Equal("Cpu", info_cpu.Name); + Assert.Equal(device_id, info_cpu.Id); + Assert.Equal(OrtAllocatorType.ArenaAllocator, info_cpu.GetAllocatorType()); + Assert.Equal(OrtMemType.Default, info_cpu.GetMemoryType()); + var deviceMemoryType = info_cpu.GetDeviceMemoryType(); + Assert.Equal(OrtDeviceMemoryType.DEFAULT, deviceMemoryType); + Assert.Equal(0U, info_cpu.GetVendorId()); + + using var allocator = new OrtAllocator(session, info_cpu); + using var alloc_info = allocator.Info; + // Allocator type returned may be different on x86 so we don't compare. 
+ Assert.Equal(info_cpu.Name, alloc_info.Name); + Assert.Equal(info_cpu.GetMemoryType(), alloc_info.GetMemoryType()); + Assert.Equal(info_cpu.Id, alloc_info.Id); + + uint size = 1024; + using OrtMemoryAllocation chunk = allocator.Allocate(size); + Assert.Equal(chunk.Size, size); + var chunk_info = chunk.Info; + // Allocator type returned may be different on x86 so we don't compare. + Assert.Equal(chunk_info.Name, alloc_info.Name); + Assert.Equal(chunk_info.GetMemoryType(), alloc_info.GetMemoryType()); + Assert.Equal(chunk_info.Id, alloc_info.Id); + } + + [Fact(DisplayName = "TestMemoryInfoCreateV2")] + void TestMemoryInfoCreateV2() + { + const int device_id = 0; + const uint vendor_id = 1234U; + using var info_cpu = new OrtMemoryInfo("Test_CPU", OrtMemoryInfoDeviceType.CPU, vendor_id, device_id, + OrtDeviceMemoryType.DEFAULT, 0, OrtAllocatorType.DeviceAllocator); + Assert.Equal("Test_CPU", info_cpu.Name); + Assert.Equal(device_id, info_cpu.Id); + Assert.Equal(OrtAllocatorType.DeviceAllocator, info_cpu.GetAllocatorType()); + Assert.Equal(OrtMemType.Default, info_cpu.GetMemoryType()); + Assert.Equal(OrtDeviceMemoryType.DEFAULT, info_cpu.GetDeviceMemoryType()); + Assert.Equal(vendor_id, info_cpu.GetVendorId()); } + #if USE_CUDA void TestCUDAAllocatorInternal(InferenceSession session) { @@ -1896,81 +1908,6 @@ private void TestSharingOfInitializerAndItsPrepackedVersion() } } - [Fact(DisplayName = "TestSharedAllocatorUsingCreateAndRegisterAllocator")] - private void TestSharedAllocatorUsingCreateAndRegisterAllocator() - { - var model = TestDataLoader.LoadModelFromEmbeddedResource("mul_1.onnx"); - - using (var memInfo = new OrtMemoryInfo(OrtMemoryInfo.allocatorCPU, - OrtAllocatorType.ArenaAllocator, 0, OrtMemType.Default)) - using (var arenaCfg = new OrtArenaCfg(0, -1, -1, -1)) - { - var env = OrtEnv.Instance(); - // Create and register the arena based allocator - env.CreateAndRegisterAllocator(memInfo, arenaCfg); - - using (var sessionOptions = new SessionOptions()) - { - // Key must match kOrtSessionOptionsConfigUseEnvAllocators in onnxruntime_session_options_config_keys.h - sessionOptions.AddSessionConfigEntry("session.use_env_allocators", "1"); - - // Create two sessions to share the allocator - // Create a third session that DOES NOT use the allocator in the environment - using (var session1 = new InferenceSession(model, sessionOptions)) - using (var session2 = new InferenceSession(model, sessionOptions)) - using (var session3 = new InferenceSession(model)) // Use the default SessionOptions instance - { - // Input data - var inputDims = new long[] { 3, 2 }; - var input = new float[] { 1.0F, 2.0F, 3.0F, 4.0F, 5.0F, 6.0F }; - - // Output data - int[] outputDims = { 3, 2 }; - float[] output = { 1.0F, 4.0F, 9.0F, 16.0F, 25.0F, 36.0F }; - - // Run inference on all three models - var inputMeta = session1.InputMetadata; - var container = new List(); - - foreach (var name in inputMeta.Keys) - { - Assert.Equal(typeof(float), inputMeta[name].ElementType); - Assert.True(inputMeta[name].IsTensor); - var tensor = new DenseTensor(input, inputMeta[name].Dimensions); - container.Add(NamedOnnxValue.CreateFromTensor(name, tensor)); - } - - // Run inference with named inputs and outputs created with in Run() - using (var results = session1.Run(container)) // results is an IReadOnlyList container - { - foreach (var r in results) - { - ValidateRunResultData(r.AsTensor(), output, outputDims); - } - } - - // Run inference with named inputs and outputs created with in Run() - using (var results = 
session2.Run(container)) // results is an IReadOnlyList container - { - foreach (var r in results) - { - ValidateRunResultData(r.AsTensor(), output, outputDims); - } - } - - // Run inference with named inputs and outputs created with in Run() - using (var results = session3.Run(container)) // results is an IReadOnlyList container - { - foreach (var r in results) - { - ValidateRunResultData(r.AsTensor(), output, outputDims); - } - } - } - } - } - } - internal static Tuple, float[]> OpenSessionSqueezeNet(int? deviceId = null) { var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs index 9368f9d8bc298..1be0b6e9530ed 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs @@ -60,6 +60,8 @@ public void GetEpDevices() Assert.NotNull(metadata); var options = ep_device.EpOptions; Assert.NotNull(options); + var memInfo = ep_device.GetMemoryInfo(OrtDeviceMemoryType.DEFAULT); + Assert.NotNull(memInfo); ReadHardwareDeviceValues(ep_device.HardwareDevice); } } @@ -77,14 +79,17 @@ public void RegisterUnregisterLibrary() // register. shouldn't throw ortEnvInstance.RegisterExecutionProviderLibrary(epName, libFullPath); - - // check OrtEpDevice was found - var epDevices = ortEnvInstance.GetEpDevices(); - var found = epDevices.Any(d => string.Equals(epName, d.EpName, StringComparison.OrdinalIgnoreCase)); - Assert.True(found); - - // unregister - ortEnvInstance.UnregisterExecutionProviderLibrary(epName); + try + { + // check OrtEpDevice was found + var epDevices = ortEnvInstance.GetEpDevices(); + var found = epDevices.Any(d => string.Equals(epName, d.EpName, StringComparison.OrdinalIgnoreCase)); + Assert.True(found); + } + finally + { // unregister + ortEnvInstance.UnregisterExecutionProviderLibrary(epName); + } } } diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs index 229d683c162fd..ae4fb0cf164cd 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs @@ -1,7 +1,17 @@ -using System; +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +using Microsoft.ML.OnnxRuntime.Tensors; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Linq.Expressions; +using System.Runtime.ConstrainedExecution; +using System.Runtime.InteropServices; using Xunit; - namespace Microsoft.ML.OnnxRuntime.Tests { /// @@ -212,5 +222,275 @@ public void TestEnvWithCustomLoggerAndThredingOptions() } } } + + [Collection("Ort Inference Tests")] + public class OrtEnvSharedAllocatorsTests + { + private void ValidateRunResultData(Tensor resultTensor, float[] expectedOutput, int[] expectedDimensions) + { + Assert.Equal(expectedDimensions.Length, resultTensor.Rank); + + var resultDimensions = resultTensor.Dimensions; + for (int i = 0; i < expectedDimensions.Length; i++) + { + Assert.Equal(expectedDimensions[i], resultDimensions[i]); + } + + var resultArray = new float[resultTensor.Length]; + for (int i = 0; i < resultTensor.Length; i++) + { + resultArray[i] = resultTensor.GetValue(i); + } + Assert.Equal(expectedOutput.Length, resultArray.Length); + Assert.Equal(expectedOutput, resultArray, new FloatComparer()); + } + + [Fact(DisplayName = "TestSharedAllocatorUsingCreateAndRegisterAllocator")] + private void TestSharedAllocatorUsingCreateAndRegisterAllocator() + { + var model = TestDataLoader.LoadModelFromEmbeddedResource("mul_1.onnx"); + + using var memInfo = new OrtMemoryInfo(OrtMemoryInfo.allocatorCPU, + OrtAllocatorType.ArenaAllocator, 0, OrtMemType.Default); + using var arenaCfg = new OrtArenaCfg(0, -1, -1, -1); + var env = OrtEnv.Instance(); + // Create and register the arena based allocator + env.CreateAndRegisterAllocator(memInfo, arenaCfg); + try + { + using var sessionOptions = new SessionOptions(); + // Key must match kOrtSessionOptionsConfigUseEnvAllocators in onnxruntime_session_options_config_keys.h + sessionOptions.AddSessionConfigEntry("session.use_env_allocators", "1"); + + // Create two sessions to share the allocator + // Create a third session that DOES NOT use the allocator in the environment + using var session1 = new InferenceSession(model, sessionOptions); + using var session2 = new InferenceSession(model, sessionOptions); + using var session3 = new InferenceSession(model); // Use the default SessionOptions instance + // Input data + var inputDims = new long[] { 3, 2 }; + var input = new float[] { 1.0F, 2.0F, 3.0F, 4.0F, 5.0F, 6.0F }; + + // Output data + int[] outputDims = { 3, 2 }; + float[] output = { 1.0F, 4.0F, 9.0F, 16.0F, 25.0F, 36.0F }; + + // Run inference on all three models + var inputMeta = session1.InputMetadata; + var container = new List(); + + foreach (var name in inputMeta.Keys) + { + Assert.Equal(typeof(float), inputMeta[name].ElementType); + Assert.True(inputMeta[name].IsTensor); + var tensor = new DenseTensor(input, inputMeta[name].Dimensions); + container.Add(NamedOnnxValue.CreateFromTensor(name, tensor)); + } + + // Run inference with named inputs and outputs created with in Run() + using var results = session1.Run(container); // results is an IReadOnlyList container + foreach (var r in results) + { + ValidateRunResultData(r.AsTensor(), output, outputDims); + } + + // Run inference with named inputs and outputs created with in Run() + using var results2 = session2.Run(container); // results is an IReadOnlyList container + foreach (var r in results2) + { + ValidateRunResultData(r.AsTensor(), output, outputDims); + } + + // Run inference with named inputs and outputs created with in Run() + using var results3 = session3.Run(container); // 
results is an IReadOnlyList container + foreach (var r in results3) + { + ValidateRunResultData(r.AsTensor(), output, outputDims); + } + } + finally + { + // Unregister the allocator + env.UnregisterAllocator(memInfo); + } + } + + [Fact(DisplayName = "TestSharedAllocatorUsingCreateAndRegisterAllocatorV2")] + private void TestSharedAllocatorUsingCreateAndRegisterAllocatorV2() + { + var model = TestDataLoader.LoadModelFromEmbeddedResource("mul_1.onnx"); + + using var memInfo = new OrtMemoryInfo(OrtMemoryInfo.allocatorCPU, + OrtAllocatorType.ArenaAllocator, 0, OrtMemType.Default); + using var arenaCfg = new OrtArenaCfg(0, -1, -1, -1); + var env = OrtEnv.Instance(); + + // Fill in with two arbitrary key-value pairs + var options = new Dictionary() { + { "key1", "value1" }, + { "key2", "value2" } + }; + + // Simply execute CreateAndRegisterAllocatorV2 to verify that C# API works as expected + env.CreateAndRegisterAllocator("CPUExecutionProvider", memInfo, arenaCfg, options); + try + { + using var sessionOptions = new SessionOptions(); + // Key must match kOrtSessionOptionsConfigUseEnvAllocators in onnxruntime_session_options_config_keys.h + sessionOptions.AddSessionConfigEntry("session.use_env_allocators", "1"); + using var session = new InferenceSession(model, sessionOptions); + } + finally + { + // Unregister the allocator + env.UnregisterAllocator(memInfo); + } + } + [Fact(DisplayName = "TestCreateGetReleaseSharedAllocator")] + private void TestCreateGetReleaseSharedAllocator() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + var env = OrtEnv.Instance(); + string libFullPath = Path.Combine(Directory.GetCurrentDirectory(), "example_plugin_ep.dll"); + Assert.True(File.Exists(libFullPath), $"Expected lib {libFullPath} does not exist."); + + // example plugin ep uses the registration name as the ep name + const string epName = "csharp_ep"; + + env.RegisterExecutionProviderLibrary(epName, libFullPath); + try + { + // Find OrtEpDevice for the example EP + OrtEpDevice epDevice = null; + var epDevices = env.GetEpDevices(); + foreach (var d in epDevices) + { + if (string.Equals(epName, d.EpName, StringComparison.OrdinalIgnoreCase)) + { + epDevice = d; + } + } + Assert.NotNull(epDevice); + + using var epMemoryInfo = epDevice.GetMemoryInfo(OrtDeviceMemoryType.DEFAULT); + + var options = new Dictionary() { + { "arena.initial_chunk_size_bytes", "25600" }, + }; + + // Strictly speaking the allocator is owned by the env + // but we want to dispose the C# object anyway + using var sharedAllocator = env.CreateSharedAllocator(epDevice, + OrtDeviceMemoryType.DEFAULT, + OrtAllocatorType.DeviceAllocator, + options); + + try + { + using var getAllocator = env.GetSharedAllocator(epMemoryInfo); + Assert.NotNull(getAllocator); + } + finally + { + // ReleaseSharedAllocator is a no-op if the allocator was created with CreateAndRegisterAllocator + env.ReleaseSharedAllocator(epDevice, OrtDeviceMemoryType.DEFAULT); + } + } + finally + { + env.UnregisterExecutionProviderLibrary(epName); + } + } + } + + [Fact(DisplayName = "TestCopyTensors")] + void TestCopyTensors() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + var env = OrtEnv.Instance(); + string libFullPath = Path.Combine(Directory.GetCurrentDirectory(), "example_plugin_ep.dll"); + Assert.True(File.Exists(libFullPath), $"Expected lib {libFullPath} does not exist."); + + // example plugin ep uses the registration name as the ep name + const string epName = "csharp_ep"; + + env.RegisterExecutionProviderLibrary(epName, 
libFullPath); + try + { + // Find the example device + OrtEpDevice epDevice = null; + var epDevices = env.GetEpDevices(); + foreach (var d in epDevices) + { + if (string.Equals(epName, d.EpName, StringComparison.OrdinalIgnoreCase)) + { + epDevice = d; + } + } + Assert.NotNull(epDevice); + + using var syncStream = epDevice.CreateSyncStream(null); + Assert.NotNull(syncStream); + // This returned Zero for example EP + // therefore do not assert for zero. + var streamHandle = syncStream.GetHandle(); + // Assert.NotEqual(IntPtr.Zero, streamHandle); + + var inputDims = new long[] { 3, 2 }; + float[] inputData1 = [1.0F, 2.0F, 3.0F, 4.0F, 5.0F, 6.0F]; + long[] inputData2 = [1, 2, 3, 4, 5, 6]; + + // Create source OrtValues on CPU on top of inputData + using var inputList = new DisposableListTest(2) + { + OrtValue.CreateTensorValueFromMemory(inputData1, inputDims), + OrtValue.CreateTensorValueFromMemory(inputData2, inputDims) + }; + + using var epMemoryInfo = epDevice.GetMemoryInfo(OrtDeviceMemoryType.DEFAULT); + var options = new Dictionary() { + { "arena.initial_chunk_size_bytes", "25600" }, + }; + + // Strictly speaking the allocator is owned by the env + // but we want to dispose the C# object anyway + using var sharedAllocator = env.CreateSharedAllocator(epDevice, + OrtDeviceMemoryType.DEFAULT, + OrtAllocatorType.DeviceAllocator, + options); + try + { + // Create destination empty OrtValues on the example EP device + using var outputList = new DisposableListTest(2) + { + OrtValue.CreateAllocatedTensorValue(sharedAllocator, + TensorElementType.Float, inputDims), + OrtValue.CreateAllocatedTensorValue(sharedAllocator, + TensorElementType.Int64, inputDims) + }; + + env.CopyTensors(inputList, outputList, syncStream); + + // Assert.Equal data on inputList and outputList + Assert.Equal(inputList[0].GetTensorDataAsSpan(), + outputList[0].GetTensorDataAsSpan()); + Assert.Equal(inputList[1].GetTensorDataAsSpan(), + outputList[1].GetTensorDataAsSpan()); + } + finally + { + // Unregister from the env + env.ReleaseSharedAllocator(epDevice, OrtDeviceMemoryType.DEFAULT); + } + } + finally + { + env.UnregisterExecutionProviderLibrary(epName); + } + } + } + } } diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs index eab4a3d412898..89dbce05326b5 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs @@ -42,33 +42,38 @@ public partial class InferenceTest public void CanCreateAndDisposeSessionWithModelPath() { string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "squeezenet.onnx"); - using (var session = new InferenceSession(modelPath)) + using var session = new InferenceSession(modelPath); + Assert.NotNull(session); + Assert.NotNull(session.InputMetadata); + Assert.Single(session.InputMetadata); // 1 input nodeMeta + Assert.True(session.InputMetadata.ContainsKey("data_0")); // input nodeMeta name + Assert.Equal(typeof(float), session.InputMetadata["data_0"].ElementType); + Assert.True(session.InputMetadata["data_0"].IsTensor); + var expectedInputDimensions = new int[] { 1, 3, 224, 224 }; + Assert.Equal(expectedInputDimensions.Length, session.InputMetadata["data_0"].Dimensions.Length); + for (int i = 0; i < expectedInputDimensions.Length; i++) { - Assert.NotNull(session); - Assert.NotNull(session.InputMetadata); - Assert.Single(session.InputMetadata); // 1 
input nodeMeta - Assert.True(session.InputMetadata.ContainsKey("data_0")); // input nodeMeta name - Assert.Equal(typeof(float), session.InputMetadata["data_0"].ElementType); - Assert.True(session.InputMetadata["data_0"].IsTensor); - var expectedInputDimensions = new int[] { 1, 3, 224, 224 }; - Assert.Equal(expectedInputDimensions.Length, session.InputMetadata["data_0"].Dimensions.Length); - for (int i = 0; i < expectedInputDimensions.Length; i++) - { - Assert.Equal(expectedInputDimensions[i], session.InputMetadata["data_0"].Dimensions[i]); - } + Assert.Equal(expectedInputDimensions[i], session.InputMetadata["data_0"].Dimensions[i]); + } - Assert.NotNull(session.OutputMetadata); - Assert.Single(session.OutputMetadata); // 1 output nodeMeta - Assert.True(session.OutputMetadata.ContainsKey("softmaxout_1")); // output nodeMeta name - Assert.Equal(typeof(float), session.OutputMetadata["softmaxout_1"].ElementType); - Assert.True(session.OutputMetadata["softmaxout_1"].IsTensor); - var expectedOutputDimensions = new int[] { 1, 1000, 1, 1 }; - Assert.Equal(expectedOutputDimensions.Length, session.OutputMetadata["softmaxout_1"].Dimensions.Length); - for (int i = 0; i < expectedOutputDimensions.Length; i++) - { - Assert.Equal(expectedOutputDimensions[i], session.OutputMetadata["softmaxout_1"].Dimensions[i]); - } + Assert.NotNull(session.OutputMetadata); + Assert.Single(session.OutputMetadata); // 1 output nodeMeta + Assert.True(session.OutputMetadata.ContainsKey("softmaxout_1")); // output nodeMeta name + Assert.Equal(typeof(float), session.OutputMetadata["softmaxout_1"].ElementType); + Assert.True(session.OutputMetadata["softmaxout_1"].IsTensor); + var expectedOutputDimensions = new int[] { 1, 1000, 1, 1 }; + Assert.Equal(expectedOutputDimensions.Length, session.OutputMetadata["softmaxout_1"].Dimensions.Length); + for (int i = 0; i < expectedOutputDimensions.Length; i++) + { + Assert.Equal(expectedOutputDimensions[i], session.OutputMetadata["softmaxout_1"].Dimensions[i]); } + + using var inputsMemoryInfos = session.GetMemoryInfosForInputs(); + Assert.Equal(session.InputNames.Count, inputsMemoryInfos.Count); + using var outputsMemoryInfos = session.GetMemoryInfosForOutputs(); + Assert.Equal(session.OutputNames.Count, outputsMemoryInfos.Count); + var inputsEpDevices = session.GetEpDeviceForInputs(); + Assert.Equal(session.InputNames.Count, inputsEpDevices.Count); } #if NET8_0_OR_GREATER @@ -154,7 +159,7 @@ public void InferenceSessionDisposedDotnetTensors() { Assert.Equal(typeof(float), inputMeta[name].ElementType); Assert.True(inputMeta[name].IsTensor); - var tensor = SystemNumericsTensors.Tensor.Create(inputData, inputMeta[name].Dimensions.Select(x => (nint) x).ToArray()); + var tensor = SystemNumericsTensors.Tensor.Create(inputData, inputMeta[name].Dimensions.Select(x => (nint)x).ToArray()); inputOrtValues.Add(new DisposableTestPair(name, OrtValue.CreateTensorValueFromSystemNumericsTensorObject(tensor))); } diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index c17acc9ffff3a..479898beae83e 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1889,6 +1889,9 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra py::class_ py_sync_stream(m, "OrtSyncStream", R"pbdoc(Represents a synchronization stream for model inference.)pbdoc"); + py_sync_stream.def("get_handle", [](OrtSyncStream* stream) -> uintptr_t { + Ort::UnownedSyncStream ort_stream(stream); + 
return reinterpret_cast<uintptr_t>(ort_stream.GetHandle()); }, R"pbdoc(SyncStream handle that can be converted to a string and added to SessionOptions)pbdoc");
   py::class_ py_ep_device(m, "OrtEpDevice", R"pbdoc(Represents a hardware device that an execution provider supports
diff --git a/onnxruntime/test/python/onnxruntime_test_python_autoep.py b/onnxruntime/test/python/onnxruntime_test_python_autoep.py
index d6281d165c053..d66951bd66f3d 100644
--- a/onnxruntime/test/python/onnxruntime_test_python_autoep.py
+++ b/onnxruntime/test/python/onnxruntime_test_python_autoep.py
@@ -232,6 +232,8 @@ def test_example_plugin_ep_devices(self):
         test_sync_stream = test_ep_device.create_sync_stream()
         self.assertIsNotNone(test_sync_stream)
+        stream_handle = test_sync_stream.get_handle()
+        self.assertIsNotNone(stream_handle)
         del test_sync_stream

         # Add EP plugin's OrtEpDevice to the SessionOptions.

From 25183d5a13e9f14a57365f4692e47a186aa45943 Mon Sep 17 00:00:00 2001
From: Adrian Lizarraga
Date: Wed, 24 Sep 2025 14:50:12 -0700
Subject: [PATCH 06/10] Regenerate test model with ONNX IR < 12 (#26149)

### Description
- Regenerates the `input_propagated_to_output.onnx` model used in [this unit test](https://github.com/microsoft/onnxruntime/blob/35dcab5088118117acc6086c9b6dd6dd92c7060f/onnxruntime/test/shared_lib/test_inference.cc#L497-L506) so that it uses an ONNX IR version compatible with ONNX 1.18.0 (i.e., IR version < 12).
- Adds script `input_propagated_to_output.py` that can be used to regenerate the `input_propagated_to_output.onnx` model.
- Embeds missing weight values that are needed to run the existing `test_dangling_input_segment_ids.py` script.

### Motivation and Context
The main branch is using ONNX 1.19. However, this unit test also needs to pass in the `rel-1.23.1` branch, which is still using ONNX 1.18.0. So, by downgrading the model's IR version, the unit test can run in both branches.
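To make the IR-version constraint concrete, here is a minimal C++ sketch of the underlying idea: clamping a model's `ir_version` field so that a build pinned to ONNX 1.18.0 (which requires IR version < 12) can load it. This is illustration only rather than what the patch does (the checked-in model is regenerated by the new `input_propagated_to_output.py` script), and it assumes that the ONNX protobuf C++ headers (`onnx/onnx_pb.h`) are available, that the model's opset imports are already compatible with ONNX 1.18, and that the file names below are placeholders.

```cpp
#include <fstream>
#include <iostream>

#include "onnx/onnx_pb.h"  // assumption: ONNX protobuf headers are on the include path

int main() {
  onnx::ModelProto model;

  // Placeholder input path for this sketch.
  std::ifstream in("input_propagated_to_output.onnx", std::ios::binary);
  if (!model.ParseFromIstream(&in)) {
    std::cerr << "Failed to parse model\n";
    return 1;
  }

  // ONNX 1.18.0 only accepts IR versions < 12, so clamp anything newer to 11.
  // Opset imports are assumed to already be compatible.
  if (model.ir_version() >= 12) {
    model.set_ir_version(11);
  }

  // Placeholder output path; the real test model in this patch is produced by the Python script.
  std::ofstream out("input_propagated_to_output.ir11.onnx", std::ios::binary);
  if (!model.SerializeToOstream(&out)) {
    std::cerr << "Failed to write model\n";
    return 1;
  }
  return 0;
}
```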
See original PR that added the test models: https://github.com/microsoft/onnxruntime/pull/26021 --- .../testdata/input_propagated_to_output.onnx | Bin 854 -> 685 bytes .../testdata/input_propagated_to_output.py | 113 ++++++++++++++++++ .../test_dangling_input_segment_ids.py | 96 +++++++++++---- 3 files changed, 186 insertions(+), 23 deletions(-) create mode 100644 onnxruntime/test/testdata/input_propagated_to_output.py diff --git a/onnxruntime/test/testdata/input_propagated_to_output.onnx b/onnxruntime/test/testdata/input_propagated_to_output.onnx index feeab10556cb06cfb9fc59c9e03d84a6b41f55f7..28d805ce878681a032633b1061b1836c6d1ac5de 100644 GIT binary patch delta 70 zcmcb{ww9HZgWYNo>qOS&(o&gu1*IkN1x5JdX0^hgU4zW+eFsoygVhT#U=6SMTrF&@g@_Gn~RHaFbgm`F*zA}@tg`-w2v{M zYv10Ng8L@^HJ{wXC^6ZD(Qfl1Mpj0~tjQl4Wr3st6Nt Date: Thu, 25 Sep 2025 14:21:12 -0600 Subject: [PATCH 07/10] [CPU] Fix compilation errors because of unused variables (#26147) This PR fixes few unused variables --- onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc | 16 +- .../cpu/moe/moe_quantization_cpu.cc | 138 +++++++++--------- 2 files changed, 77 insertions(+), 77 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc index 73be099181aa9..a23ea07ac1cb8 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc +++ b/onnxruntime/contrib_ops/cpu/moe/moe_cpu.cc @@ -12,6 +12,7 @@ #include "core/framework/float16.h" #include "core/framework/allocator.h" #include "core/platform/threadpool.h" +#include "core/common/narrow.h" #include #include @@ -120,7 +121,7 @@ Status MoE::ComputeMoE(const OpKernelContext* context, int num_routing_threads = 1; if (tp != nullptr && num_tokens >= 1024) { int max_threads = concurrency::ThreadPool::DegreeOfParallelism(tp); - num_routing_threads = std::min(static_cast(num_tokens / 512), max_threads); + num_routing_threads = std::min(narrow(num_tokens / 512), max_threads); num_routing_threads = std::max(1, num_routing_threads); } @@ -133,7 +134,7 @@ Status MoE::ComputeMoE(const OpKernelContext* context, } concurrency::ThreadPool::TrySimpleParallelFor(tp, num_routing_threads, [&](std::ptrdiff_t thread_id) { - auto work = concurrency::ThreadPool::PartitionWork(static_cast(thread_id), num_routing_threads, static_cast(num_tokens)); + auto work = concurrency::ThreadPool::PartitionWork(narrow(thread_id), num_routing_threads, static_cast(num_tokens)); auto& local_expert_token_map = thread_local_expert_token_maps[thread_id]; std::vector> sorted_logits(static_cast(num_experts)); @@ -173,7 +174,7 @@ Status MoE::ComputeMoE(const OpKernelContext* context, int64_t route_idx = i * k_ + j; float normalized_weight = sorted_logits[static_cast(j)].first * inv_top_k_sum; - route_expert[route_idx] = static_cast(expert_idx); + route_expert[route_idx] = narrow(expert_idx); route_scale[route_idx] = normalized_weight; if (normalized_weight > 0.0f) { local_expert_token_map[static_cast(expert_idx)].push_back(route_idx); @@ -185,7 +186,7 @@ Status MoE::ComputeMoE(const OpKernelContext* context, int64_t route_idx = i * k_ + j; float weight = sorted_logits[static_cast(j)].first; - route_expert[route_idx] = static_cast(expert_idx); + route_expert[route_idx] = narrow(expert_idx); route_scale[route_idx] = weight; if (weight > 0.0f) { local_expert_token_map[static_cast(expert_idx)].push_back(route_idx); @@ -319,7 +320,7 @@ Status MoE::ComputeMoE(const OpKernelContext* context, // Optimized expert processing with thread-local buffer reuse concurrency::ThreadPool::TrySimpleParallelFor(tp, 
num_expert_threads, [&](std::ptrdiff_t thread_id_pd) { - int thread_id = static_cast(thread_id_pd); + int thread_id = narrow(thread_id_pd); auto work = concurrency::ThreadPool::PartitionWork(thread_id, num_expert_threads, static_cast(num_experts)); float* local_output = thread_local_outputs + static_cast(thread_id) * output_buffer_size; @@ -440,6 +441,11 @@ Status MoE::ProcessExpertBatch(const T* input_tokens, int64_t inter_size, T* fc1_output_buffer, T* activation_output_buffer) const { + ORT_UNUSED_PARAMETER(token_expert_ids); + ORT_UNUSED_PARAMETER(token_weights); + ORT_UNUSED_PARAMETER(expert_id); + ORT_UNUSED_PARAMETER(fc1_output_buffer); + ORT_UNUSED_PARAMETER(activation_output_buffer); const bool is_swiglu = activation_type_ == ActivationType::SwiGLU; const int64_t fc1_output_size = is_swiglu ? (inter_size * 2) : inter_size; diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc index 8195c9438d408..8a3c3f6d9f37a 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc @@ -10,6 +10,7 @@ #include "core/providers/cpu/math/gemm_helper.h" #include "core/providers/cpu/activation/activations.h" #include "core/common/safeint.h" +#include "core/common/narrow.h" #include "core/framework/tensor_type_and_shape.h" #include "core/util/math.h" #include "contrib_ops/cpu/moe/moe_utils.h" @@ -50,7 +51,7 @@ inline int64_t GetDequantBlockSize(int64_t features, int64_t total_work) { return std::min(target_block_size, work_based_size); } -bool CanUseMlasQ4Dequant(int64_t num_bits, int64_t block_size) { +bool CanUseMlasQ4Dequant(int64_t num_bits) { if (num_bits != 4) { return false; } @@ -154,10 +155,11 @@ void DequantizeBlockWithMlas(const uint8_t* quantized_data, int64_t cols, float* dequantized_data, MLAS_THREADPOOL* thread_pool) { + ORT_UNUSED_PARAMETER(thread_pool); const float zero_point = num_bits == 8 ? 128.0f : 8.0f; const int64_t blocks_per_row = (block_size > 0) ? ((cols + block_size - 1) / block_size) : 1; - if (CanUseMlasQ4Dequant(num_bits, block_size)) { + if (CanUseMlasQ4Dequant(num_bits)) { const int64_t packed_cols = (cols + 1) / 2; if (block_size == 0) { @@ -420,7 +422,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const int max_threads = tp ? concurrency::ThreadPool::DegreeOfParallelism(tp) : 1; const int64_t thread_divisor = std::max(1, max_threads * 4); const int64_t min_work_per_thread = std::max(int64_t{32}, static_cast(num_tokens / thread_divisor)); - const int optimal_routing_threads = (tp == nullptr || num_tokens < min_work_per_thread) ? 1 : std::min(static_cast(num_tokens / std::max(int64_t{1}, min_work_per_thread)), max_threads); + const int optimal_routing_threads = (tp == nullptr || num_tokens < min_work_per_thread) ? 
1 : std::min(narrow(num_tokens / std::max(int64_t{1}, min_work_per_thread)), max_threads); const int num_routing_threads = std::max(1, optimal_routing_threads); std::vector>> thread_local_expert_token_maps(num_routing_threads); @@ -432,7 +434,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { } concurrency::ThreadPool::TrySimpleParallelFor(tp, num_routing_threads, [&](std::ptrdiff_t thread_id) { - auto work = concurrency::ThreadPool::PartitionWork(static_cast(thread_id), num_routing_threads, static_cast(num_tokens)); + auto work = concurrency::ThreadPool::PartitionWork(narrow(thread_id), num_routing_threads, static_cast(num_tokens)); auto& local_expert_token_map = thread_local_expert_token_maps[thread_id]; std::vector> sorted_logits(static_cast(num_experts)); @@ -441,8 +443,8 @@ Status QMoECPU::Compute(OpKernelContext* context) const { for (int64_t i = work.start; i < work.end; ++i) { const float* logits = router_logits_float + i * num_experts; - for (int64_t j = 0; j < num_experts; ++j) { - sorted_logits[static_cast(j)] = {logits[j], j}; + for (size_t j = 0; j < narrow(num_experts); ++j) { + sorted_logits[j] = {logits[j], j}; } std::partial_sort(sorted_logits.begin(), sorted_logits.begin() + static_cast(k_), sorted_logits.end(), std::greater<>()); @@ -450,17 +452,17 @@ Status QMoECPU::Compute(OpKernelContext* context) const { float max_logit = sorted_logits[0].first; float sum_exp = 0.0f; - for (int64_t j = 0; j < k_; ++j) { - top_k_exp[static_cast(j)] = std::exp(sorted_logits[static_cast(j)].first - max_logit); - sum_exp += top_k_exp[static_cast(j)]; + for (size_t j = 0; j < narrow(k_); ++j) { + top_k_exp[j] = std::exp(sorted_logits[j].first - max_logit); + sum_exp += top_k_exp[j]; } const float inv_sum = (sum_exp == 0.0f) ? 0.0f : (1.0f / sum_exp); - for (int64_t j = 0; j < k_; ++j) { - int64_t expert_idx = sorted_logits[static_cast(j)].second; - int64_t route_idx = i * k_ + j; - route_expert[route_idx] = static_cast(expert_idx); - route_scale[route_idx] = top_k_exp[static_cast(j)] * inv_sum; + for (size_t j = 0; j < narrow(k_); ++j) { + int64_t expert_idx = sorted_logits[j].second; + int64_t route_idx = i * k_ + narrow(j); + route_expert[route_idx] = narrow(expert_idx); + route_scale[route_idx] = top_k_exp[j] * inv_sum; if (route_scale[route_idx] > 1e-8f) { // Use small threshold to avoid zero weights local_expert_token_map[static_cast(expert_idx)].push_back(route_idx); } @@ -504,7 +506,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const int64_t expert_thread_divisor = std::max(1, max_expert_threads * 8); const int64_t min_expert_work_per_thread = std::max(int64_t{16}, total_expert_work / expert_thread_divisor); - int num_expert_threads = (tp == nullptr || total_expert_work < min_expert_work_per_thread) ? 1 : std::min(static_cast(total_expert_work / std::max(int64_t{1}, min_expert_work_per_thread)), std::min(static_cast(num_experts), max_expert_threads)); + int num_expert_threads = (tp == nullptr || total_expert_work < min_expert_work_per_thread) ? 
1 : std::min(narrow(total_expert_work / std::max(int64_t{1}, min_expert_work_per_thread)), std::min(narrow(num_experts), max_expert_threads)); if (num_expert_threads == 0) num_expert_threads = 1; auto thread_local_outputs_ptr = IAllocator::MakeUniquePtr(allocator, static_cast(num_expert_threads) * output_buffer_size); @@ -585,7 +587,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { } concurrency::ThreadPool::TrySimpleParallelFor(tp, num_expert_threads, [&](std::ptrdiff_t thread_id_pd) { - const int thread_id = static_cast(thread_id_pd); + const int thread_id = narrow(thread_id_pd); const auto& expert_batch = expert_batches[static_cast(thread_id)]; float* thread_workspace = workspace + static_cast(thread_id) * workspace_elements_per_thread; @@ -612,7 +614,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const int64_t num_blocks = (num_expert_tokens + dynamic_block_size - 1) / dynamic_block_size; if (num_expert_tokens >= 8 && num_blocks > 1 && tp != nullptr) { - concurrency::ThreadPool::TrySimpleParallelFor(tp, static_cast(num_blocks), [&](std::ptrdiff_t block_idx) { + concurrency::ThreadPool::TrySimpleParallelFor(tp, narrow(num_blocks), [&](std::ptrdiff_t block_idx) { const int64_t start_idx = block_idx * dynamic_block_size; const int64_t end_idx = std::min(start_idx + dynamic_block_size, num_expert_tokens); @@ -633,14 +635,14 @@ Status QMoECPU::Compute(OpKernelContext* context) const { if (ShouldUseMemcpy(hidden_size)) { std::memcpy(dst, src, static_cast(hidden_size) * sizeof(float)); } else { - const int64_t unroll_factor = GetUnrollFactor(hidden_size); - int64_t j = 0; - for (; j + unroll_factor <= hidden_size; j += unroll_factor) { - for (int64_t k = 0; k < unroll_factor; ++k) { + const size_t unroll_factor = narrow(GetUnrollFactor(hidden_size)); + size_t j = 0; + for (; j + unroll_factor <= narrow(hidden_size); j += unroll_factor) { + for (size_t k = 0; k < unroll_factor; ++k) { dst[j + k] = src[j + k]; } } - for (; j < hidden_size; ++j) { + for (; j < narrow(hidden_size); ++j) { dst[j] = src[j]; } } @@ -705,10 +707,6 @@ Status QMoECPU::Compute(OpKernelContext* context) const { if (gemm_status.IsOK()) { fc1_used_direct_q4 = true; -#ifdef ONNXRUNTIME_ENABLE_VERBOSE_LOGGING - LOGS_DEFAULT(VERBOSE) << "QMoE: Using direct MLAS Q4 GEMM for FC1 expert " << expert_idx - << " (M=" << num_expert_tokens << ", N=" << fc1_out_features << ", K=" << hidden_size << ")"; -#endif goto fc1_gemm_done; } } @@ -717,7 +715,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { // Traditional approach: dequantize + regular GEMM if (num_dequant_blocks > 1 && fc1_out_features >= 32) { - concurrency::ThreadPool::TrySimpleParallelFor(tp, static_cast(num_dequant_blocks), [&](std::ptrdiff_t block_idx) { + concurrency::ThreadPool::TrySimpleParallelFor(tp, narrow(num_dequant_blocks), [&](std::ptrdiff_t block_idx) { const int64_t start_row = block_idx * dequant_block_size; const int64_t end_row = std::min(start_row + dequant_block_size, fc1_out_features); const auto offset = expert_idx * fc1_out_features * fc1_packed_cols + start_row * fc1_packed_cols; @@ -749,14 +747,14 @@ Status QMoECPU::Compute(OpKernelContext* context) const { if (ShouldUseMemcpy(fc1_out_features)) { std::memcpy(thread_bias1_buffer, B1_bias, static_cast(fc1_out_features) * sizeof(float)); } else { - const int64_t unroll_factor = GetUnrollFactor(fc1_out_features); - int64_t j = 0; - for (; j + unroll_factor <= fc1_out_features; j += unroll_factor) { - for (int64_t k = 0; k < unroll_factor; ++k) { - 
thread_bias1_buffer[j + k] = static_cast(B1_bias[j + k]); + const size_t unroll_factor = static_cast(GetUnrollFactor(fc1_out_features)); + size_t j = 0; + for (; j + unroll_factor <= static_cast(fc1_out_features); j += unroll_factor) { + for (size_t loop_k = 0; loop_k < unroll_factor; ++loop_k) { + thread_bias1_buffer[j + loop_k] = static_cast(B1_bias[j + loop_k]); } } - for (; j < fc1_out_features; ++j) { + for (; j < static_cast(fc1_out_features); ++j) { thread_bias1_buffer[j] = static_cast(B1_bias[j]); } } @@ -764,15 +762,15 @@ Status QMoECPU::Compute(OpKernelContext* context) const { for (int64_t i = 0; i < num_expert_tokens; ++i) { float* C1_row = C1 + i * fc1_out_features; - const int64_t unroll_factor = GetUnrollFactor(fc1_out_features); + const size_t unroll_factor = static_cast(GetUnrollFactor(fc1_out_features)); - int64_t j = 0; - for (; j + unroll_factor <= fc1_out_features; j += unroll_factor) { - for (int64_t k = 0; k < unroll_factor; ++k) { - C1_row[j + k] += thread_bias1_buffer[j + k]; + size_t j = 0; + for (; j + unroll_factor <= static_cast(fc1_out_features); j += unroll_factor) { + for (size_t loop_k = 0; loop_k < unroll_factor; ++loop_k) { + C1_row[j + loop_k] += thread_bias1_buffer[j + loop_k]; } } - for (; j < fc1_out_features; ++j) { + for (; j < static_cast(fc1_out_features); ++j) { C1_row[j] += thread_bias1_buffer[j]; } } @@ -786,7 +784,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const int64_t num_activation_blocks = (num_expert_tokens + activation_block_size - 1) / activation_block_size; if (num_activation_blocks > 1) { - concurrency::ThreadPool::TrySimpleParallelFor(tp, static_cast(num_activation_blocks), [&](std::ptrdiff_t block_idx) { + concurrency::ThreadPool::TrySimpleParallelFor(tp, narrow(num_activation_blocks), [&](std::ptrdiff_t block_idx) { const int64_t start_token = block_idx * activation_block_size; const int64_t end_token = std::min(start_token + activation_block_size, num_expert_tokens); @@ -868,10 +866,6 @@ Status QMoECPU::Compute(OpKernelContext* context) const { if (gemm_status.IsOK()) { fc2_used_direct_q4 = true; -#ifdef ONNXRUNTIME_ENABLE_VERBOSE_LOGGING - LOGS_DEFAULT(VERBOSE) << "QMoE: Using direct MLAS Q4 GEMM for FC2 expert " << expert_idx - << " (M=" << num_expert_tokens << ", N=" << hidden_size << ", K=" << inter_size << ")"; -#endif goto fc2_gemm_done; } } @@ -881,7 +875,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { // Traditional approach: dequantize + regular GEMM if (num_fc2_dequant_blocks > 1 && hidden_size >= 32) { - concurrency::ThreadPool::TrySimpleParallelFor(tp, static_cast(num_fc2_dequant_blocks), [&](std::ptrdiff_t block_idx) { + concurrency::ThreadPool::TrySimpleParallelFor(tp, narrow(num_fc2_dequant_blocks), [&](std::ptrdiff_t block_idx) { const int64_t start_row = block_idx * fc2_dequant_block_size; const int64_t end_row = std::min(start_row + fc2_dequant_block_size, hidden_size); const auto offset = expert_idx * hidden_size * fc2_packed_cols + start_row * fc2_packed_cols; @@ -915,14 +909,14 @@ Status QMoECPU::Compute(OpKernelContext* context) const { if (ShouldUseMemcpy(hidden_size)) { std::memcpy(thread_bias2_buffer, B2_bias, static_cast(hidden_size) * sizeof(float)); } else { - const int64_t unroll_factor = GetUnrollFactor(hidden_size); - int64_t j = 0; - for (; j + unroll_factor <= hidden_size; j += unroll_factor) { - for (int64_t k = 0; k < unroll_factor; ++k) { - thread_bias2_buffer[j + k] = static_cast(B2_bias[j + k]); + const size_t unroll_factor = 
narrow(GetUnrollFactor(hidden_size)); + size_t j = 0; + for (; j + unroll_factor <= narrow(hidden_size); j += unroll_factor) { + for (size_t loop_k = 0; loop_k < unroll_factor; ++loop_k) { + thread_bias2_buffer[j + loop_k] = static_cast(B2_bias[j + loop_k]); } } - for (; j < hidden_size; ++j) { + for (; j < narrow(hidden_size); ++j) { thread_bias2_buffer[j] = static_cast(B2_bias[j]); } } @@ -943,25 +937,25 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const float* src = C2 + i * hidden_size; if (has_fc2_bias && !fc2_bias_handled_by_q4_gemm) { - const int64_t unroll_factor = GetUnrollFactor(hidden_size); - int64_t j = 0; - for (; j + unroll_factor <= hidden_size; j += unroll_factor) { - for (int64_t k = 0; k < unroll_factor; ++k) { - dest[j + k] += weight * (src[j + k] + thread_bias2_buffer[j + k]); + const size_t unroll_factor = narrow(GetUnrollFactor(hidden_size)); + size_t j = 0; + for (; j + unroll_factor <= narrow(hidden_size); j += unroll_factor) { + for (size_t loop_k = 0; loop_k < unroll_factor; ++loop_k) { + dest[j + loop_k] += weight * (src[j + loop_k] + thread_bias2_buffer[j + loop_k]); } } - for (; j < hidden_size; ++j) { + for (; j < narrow(hidden_size); ++j) { dest[j] += weight * (src[j] + thread_bias2_buffer[j]); } } else { - const int64_t unroll_factor = GetUnrollFactor(hidden_size); - int64_t j = 0; - for (; j + unroll_factor <= hidden_size; j += unroll_factor) { - for (int64_t k = 0; k < unroll_factor; ++k) { - dest[j + k] += weight * src[j + k]; + const size_t unroll_factor = narrow(GetUnrollFactor(hidden_size)); + size_t j = 0; + for (; j + unroll_factor <= narrow(hidden_size); j += unroll_factor) { + for (size_t loop_k = 0; loop_k < unroll_factor; ++loop_k) { + dest[j + loop_k] += weight * src[j + loop_k]; } } - for (; j < hidden_size; ++j) { + for (; j < narrow(hidden_size); ++j) { dest[j] += weight * src[j]; } } @@ -975,7 +969,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const int max_acc_threads = tp ? concurrency::ThreadPool::DegreeOfParallelism(tp) : 1; const size_t acc_thread_divisor = std::max(size_t{1}, static_cast(max_acc_threads) * 8); const size_t min_elements_per_thread = std::max(size_t{32}, output_buffer_size / acc_thread_divisor); - const int optimal_acc_threads = (tp == nullptr || output_buffer_size < min_elements_per_thread) ? 1 : std::min(static_cast(output_buffer_size / std::max(size_t{1}, min_elements_per_thread)), max_acc_threads); + const int optimal_acc_threads = (tp == nullptr || output_buffer_size < min_elements_per_thread) ? 
1 : std::min(narrow(output_buffer_size / std::max(size_t{1}, min_elements_per_thread)), max_acc_threads); const int num_acc_threads = std::max(1, optimal_acc_threads); if (num_acc_threads > 1) { @@ -991,10 +985,10 @@ Status QMoECPU::Compute(OpKernelContext* context) const { size_t j = 0; const size_t chunk_size = end_idx - start_idx; - const int64_t unroll_factor = GetUnrollFactor(static_cast(chunk_size)); - for (; j + static_cast(unroll_factor) <= chunk_size; j += static_cast(unroll_factor)) { - for (int64_t k = 0; k < unroll_factor; ++k) { - dst[j + static_cast(k)] += src[j + static_cast(k)]; + const size_t unroll_factor = static_cast(GetUnrollFactor(static_cast(chunk_size))); + for (; j + unroll_factor <= chunk_size; j += unroll_factor) { + for (size_t loop_k = 0; loop_k < unroll_factor; ++loop_k) { + dst[j + loop_k] += src[j + loop_k]; } } for (; j < chunk_size; ++j) { @@ -1008,10 +1002,10 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const float* src = thread_local_outputs + thread_offset; size_t j = 0; - const int64_t unroll_factor = GetUnrollFactor(static_cast(output_buffer_size)); - for (; j + static_cast(unroll_factor) <= output_buffer_size; j += static_cast(unroll_factor)) { - for (int64_t k = 0; k < unroll_factor; ++k) { - buffer[j + static_cast(k)] += src[j + static_cast(k)]; + const size_t unroll_factor = narrow(GetUnrollFactor(narrow(output_buffer_size))); + for (; j + unroll_factor <= output_buffer_size; j += unroll_factor) { + for (size_t loop_k = 0; loop_k < unroll_factor; ++loop_k) { + buffer[j + loop_k] += src[j + loop_k]; } } for (; j < output_buffer_size; ++j) { From 5b2f588c731440aff504b89dcb653eee885215d3 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Fri, 26 Sep 2025 01:24:00 -0700 Subject: [PATCH 08/10] [EP ABI] Check if nodes specified in GetCapability() have already been assigned (#26156) ### Description Fixes segfault in `PluginExecutionProvider::GetCapability()` when the underlying `OrtEp` tries to claim nodes that have already been assigned to another EP. ### Motivation and Context Should log a warning (instead of crashing or throwing an exception) when a plugin EP tries to claim a node that is already assigned to another EP. --------- Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> --- .../ep_plugin_provider_interfaces.cc | 65 ++++- .../test/framework/ep_plugin_provider_test.cc | 231 ++++++++++++++++++ 2 files changed, 285 insertions(+), 11 deletions(-) diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index c8829423fbe26..55245420db37a 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -3,6 +3,7 @@ #include "core/session/plugin_ep/ep_plugin_provider_interfaces.h" +#include #include #include #include @@ -117,6 +118,17 @@ static OrtDevice GetOrtDeviceForPluginEp(gsl::span ep_ return device_memory_info != nullptr ? device_memory_info->device : OrtDevice(); } +static const Node* FindFirstNodeAssignedToOtherEP(const std::string& ep_type, + gsl::span ep_nodes) { + auto node_iter = std::find_if(ep_nodes.begin(), ep_nodes.end(), + [&ep_type](const EpNode* node) -> bool { + const auto& node_ep_type = node->GetInternalNode().GetExecutionProviderType(); + return !node_ep_type.empty() && node_ep_type != ep_type; + }); + + return node_iter != ep_nodes.end() ? 
&(*node_iter)->GetInternalNode() : nullptr; +} + PluginExecutionProvider::PluginExecutionProvider(UniqueOrtEp ep, const OrtSessionOptions& session_options, OrtEpFactory& ep_factory, gsl::span ep_devices, @@ -158,9 +170,11 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie ORT_UNUSED_PARAMETER(resource_accountant); // TODO: Add support? Not used by prioritized EPs ORT_UNUSED_PARAMETER(kernel_lookup); // TODO: Add support? Not used by prioritized EPs, so probably not needed? + const logging::Logger& logger = GetLogger() != nullptr ? *GetLogger() : logging::LoggingManager::DefaultLogger(); + std::unique_ptr ep_graph = nullptr; if (Status status = EpGraph::Create(graph_viewer, ep_graph); !status.IsOK()) { - LOGS_DEFAULT(ERROR) << "Failed to create OrtGraph: " << status.ToString(); + LOGS(logger, ERROR) << "Failed to create OrtGraph for " << Type() << ": " << status.ToString(); return {}; } @@ -168,7 +182,7 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie Status status = ToStatusAndRelease(ort_ep_->GetCapability(ort_ep_.get(), ep_graph->ToExternal(), &api_graph_support_info)); if (!status.IsOK()) { - LOGS_DEFAULT(ERROR) << "OrtEp::GetCapability() failed with error: " << status.ToString(); + LOGS(logger, ERROR) << "OrtEp::GetCapability() for " << Type() << " failed with error: " << status.ToString(); return {}; } @@ -182,12 +196,39 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie // Create ComputeCapability instances from OrtEpGraphSupportInfo::NodeGrouping instances. for (const OrtEpGraphSupportInfo::NodeGrouping& node_grouping : api_graph_support_info.node_groupings) { + // Skip this node grouping if any node has already been assigned to another EP. + if (const Node* node_for_other_ep = FindFirstNodeAssignedToOtherEP(Type(), node_grouping.nodes); + node_for_other_ep != nullptr) { + LOGS(logger, WARNING) << "OrtEp::GetCapability() specified nodes that cannot be assigned to " << Type() << ". " + << "Found one or more nodes that were already assigned to a different EP named '" + << node_for_other_ep->GetExecutionProviderType() << "'. Ex: " + << node_for_other_ep->OpType() << " node with name '" + << node_for_other_ep->Name() << "'."; + continue; + } + if (node_grouping.kind == OrtEpGraphSupportInfo::NodeGroupingKind::kSingleAssignedNode) { + if (node_grouping.nodes.size() != 1) { + // The EpGraphSupportInfo_AddSingleNode() C API should already return an error if the EP tries to provide + // an invalid node. However, we check here too just in case this changes. + LOGS(logger, ERROR) << "OrtEp::GetCapability() for " << Type() << " did not specify exactly one valid node " + << "when calling EpGraphSupportInfo_AddSingleNode()."; + return {}; + } + auto indexed_sub_graph = std::make_unique(); indexed_sub_graph->nodes.push_back(node_grouping.nodes[0]->GetInternalNode().Index()); result.push_back(std::make_unique(std::move(indexed_sub_graph))); } else if (node_grouping.kind == OrtEpGraphSupportInfo::NodeGroupingKind::kFusedNode) { + if (node_grouping.nodes.empty()) { + // The EpGraphSupportInfo_AddNodesToFuse() C API should already return an error if the EP tries to provide + // an empty array of nodes from OrtEp::GetCapability(). However, we check here too just in case this changes. 
+ LOGS(logger, ERROR) << "OrtEp::GetCapability() for " << Type() << " set an empty array of nodes " + << "when specifying supported nodes."; + return {}; + } + std::unordered_set node_set; node_set.reserve(node_grouping.nodes.size()); @@ -207,27 +248,29 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie this->Type(), this->Type(), /*node_unit_map*/ nullptr, node_grouping.fusion_options.drop_constant_initializers); - if (capabilities.size() > 1) { - LOGS_DEFAULT(ERROR) << "OrtEp::GetCapability() set nodes that cannot be fused together. " - << "Please ensure that the nodes provided to EpGraphSupportInfo_AddFusedNodes() do not " + if (capabilities.size() != 1) { + LOGS(logger, ERROR) << "OrtEp::GetCapability() for " << Type() << " set nodes that cannot be fused together. " + << "Please ensure that the nodes provided to EpGraphSupportInfo_AddNodesToFuse() do not " << "have an unsupported node in any path between two of the supported nodes."; return {}; } - // Enforce that the nodes in node_set match the nodes in capabilities[0] + // Log an error if the nodes in node_set do not match the nodes in capabilities[0]. We expect this to always + // be true because we've already checked that the EP did not try to claim nodes already assigned to another EP. // TODO(adrianlizarraga): This check can be removed when we stop using utils::CreateSupportedPartitions() above. std::vector& capability_node_indices = capabilities[0]->sub_graph->nodes; std::unordered_set capability_node_indices_set(capability_node_indices.begin(), capability_node_indices.end()); - ORT_ENFORCE(node_set.size() == capability_node_indices_set.size()); - ORT_ENFORCE(std::all_of(node_set.begin(), node_set.end(), [&capability_node_indices_set](const Node* node) { - return capability_node_indices_set.count(node->Index()) != 0; - })); + if (node_set.size() != capability_node_indices_set.size()) { + LOGS(logger, ERROR) << "OrtEp::GetCapability() for " << Type() + << " set nodes that cannot all be fused together."; + return {}; + } result.push_back(std::move(capabilities[0])); } else { - LOGS_DEFAULT(ERROR) << "PluginExecutionProvider::GetCapability() has invalid NodeGroupingKind: " + LOGS(logger, ERROR) << "PluginExecutionProvider::GetCapability() has invalid NodeGroupingKind: " << static_cast(node_grouping.kind); return {}; } diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc index 35f7d06fb0912..30595d5ce97b2 100644 --- a/onnxruntime/test/framework/ep_plugin_provider_test.cc +++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc @@ -3,9 +3,14 @@ #include "core/session/plugin_ep/ep_plugin_provider_interfaces.h" +#include #include "gsl/gsl" #include "gtest/gtest.h" +#include "core/common/logging/sinks/file_sink.h" +#include "core/graph/graph_viewer.h" +#include "core/graph/model.h" +#include "core/optimizer/graph_optimizer_registry.h" #include "core/session/abi_devices.h" #include "core/session/onnxruntime_cxx_api.h" #include "test/util/include/asserts.h" @@ -23,6 +28,14 @@ struct ApiPtrs { const gsl::not_null ep_api; }; +static void CheckStringInFile(const PathString& filename, const std::string& look_for) { + std::ifstream ifs{filename}; + std::string content(std::istreambuf_iterator{ifs}, + std::istreambuf_iterator{}); + + EXPECT_NE(content.find(look_for), std::string::npos); +} + // Normally, a plugin EP would be implemented in a separate library. 
// The `test_plugin_ep` namespace contains a local implementation intended for unit testing. namespace test_plugin_ep { @@ -114,6 +127,10 @@ MakeTestOrtEpResult MakeTestOrtEp(std::vector ep_devices = { return result; } +class MockKernelLookup : public IExecutionProvider::IKernelLookup { + const KernelCreateInfo* LookUpKernel(const Node& /*node*/) const override { return nullptr; } +}; + } // namespace test_plugin_ep TEST(PluginExecutionProviderTest, GetPreferredLayout) { @@ -317,4 +334,218 @@ TEST(PluginExecutionProviderTest, InferOrtDeviceFromDeviceMemoryInfo) { #endif // !defined(ORT_NO_EXCEPTIONS) } +static void LoadModelAndAssignNodesToEp(const ORTCHAR_T* model_path, + const char* ep_name, + const std::unordered_set& ep_node_names, + /*out*/ std::shared_ptr& model) { + ASSERT_STATUS_OK(Model::Load(model_path, model, nullptr, + DefaultLoggingManager().DefaultLogger())); + + Graph& graph = model->MainGraph(); + + for (Node& node : graph.Nodes()) { + if (ep_node_names.count(node.Name()) > 0) { + node.SetExecutionProviderType(ep_name); + } + } +} + +static OrtStatus* ORT_API_CALL GetCapabilityTakeAllNodesOneGroup(OrtEp* this_ptr, const OrtGraph* graph, + OrtEpGraphSupportInfo* graph_support_info) noexcept { + auto* this_ep = static_cast(this_ptr); + + size_t num_nodes = 0; + if (OrtStatus* st = this_ep->ort_api->Graph_GetNumNodes(graph, &num_nodes); st != nullptr) { + return st; + } + + std::vector nodes(num_nodes); + if (OrtStatus* st = this_ep->ort_api->Graph_GetNodes(graph, nodes.data(), nodes.size()); st != nullptr) { + return st; + } + + if (OrtStatus* st = this_ep->ep_api->EpGraphSupportInfo_AddNodesToFuse(graph_support_info, + nodes.data(), nodes.size(), nullptr); + st != nullptr) { + return st; + } + + return nullptr; +} + +static OrtStatus* ORT_API_CALL GetCapabilityTakeAllNodesTwoGroups(OrtEp* this_ptr, const OrtGraph* graph, + OrtEpGraphSupportInfo* graph_support_info) noexcept { + auto* this_ep = static_cast(this_ptr); + + size_t num_nodes = 0; + if (OrtStatus* st = this_ep->ort_api->Graph_GetNumNodes(graph, &num_nodes); st != nullptr) { + return st; + } + + std::vector nodes(num_nodes); + if (OrtStatus* st = this_ep->ort_api->Graph_GetNodes(graph, nodes.data(), nodes.size()); st != nullptr) { + return st; + } + + // Expect at least 2 nodes. If not, this is really a testing/setup error. 
+ if (num_nodes < 2) { + return this_ep->ort_api->CreateStatus(OrtErrorCode::ORT_FAIL, + "Expected at least two nodes in call to GetCapability"); + } + + std::vector node_group1; + std::vector node_group2; + + for (size_t i = 0; i < num_nodes; i++) { + if (i < num_nodes / 2) { + node_group1.push_back(nodes[i]); + } else { + node_group2.push_back(nodes[i]); + } + } + + if (OrtStatus* st = this_ep->ep_api->EpGraphSupportInfo_AddNodesToFuse(graph_support_info, + node_group1.data(), node_group1.size(), + nullptr); + st != nullptr) { + return st; + } + + if (OrtStatus* st = this_ep->ep_api->EpGraphSupportInfo_AddNodesToFuse(graph_support_info, + node_group2.data(), node_group2.size(), + nullptr); + st != nullptr) { + return st; + } + + return nullptr; +} + +static OrtStatus* ORT_API_CALL GetCapabilityTakeSingleNode(OrtEp* this_ptr, const OrtGraph* graph, + OrtEpGraphSupportInfo* graph_support_info) noexcept { + auto* this_ep = static_cast(this_ptr); + + size_t num_nodes = 0; + if (OrtStatus* st = this_ep->ort_api->Graph_GetNumNodes(graph, &num_nodes); st != nullptr) { + return st; + } + + std::vector nodes(num_nodes); + if (OrtStatus* st = this_ep->ort_api->Graph_GetNodes(graph, nodes.data(), nodes.size()); st != nullptr) { + return st; + } + + // Take only the first node using EpGraphSupportInfo_AddSingleNode(). + if (OrtStatus* st = this_ep->ep_api->EpGraphSupportInfo_AddSingleNode(graph_support_info, nodes[0]); + st != nullptr) { + return st; + } + + return nullptr; +} + +// Tests that GetCapability() doesn't crash if a plugin EP tries to claim a mix of unassigned nodes and +// nodes that are already assigned to another EP. +TEST(PluginExecutionProviderTest, GetCapability_ClaimNodesAssignedToOtherEP) { + std::filesystem::path log_file = ORT_TSTR("log_get_capability.txt"); + + // Helper function that loads a model (Add -> Mul -> Add) and assigns some or all of the nodes to another EP. + // Then, IExecutionProvider::GetCapability() is called to test the expected behavior. + auto run_test = [&log_file](IExecutionProvider& ep, + const std::unordered_set& nodes_for_other_ep, + const std::unordered_set& nodes_for_this_ep, + const char* expected_log_string) { + std::shared_ptr model; + ASSERT_NO_FATAL_FAILURE(LoadModelAndAssignNodesToEp(ORT_TSTR("testdata/add_mul_add.onnx"), + "OtherEp", nodes_for_other_ep, model)); + + std::filesystem::remove(log_file); + + // Call IExecutionProvider::GetCapability and check results + logs. + { + logging::LoggingManager log_manager{std::make_unique(log_file, false, false), + logging::Severity::kWARNING, false, + logging::LoggingManager::InstanceType::Temporal}; + auto file_logger = log_manager.CreateLogger("FileLogger"); + ep.SetLogger(file_logger.get()); // Make EP log to a file. + + GraphViewer graph_viewer(model->MainGraph()); + auto compute_capabilities = ep.GetCapability(graph_viewer, + test_plugin_ep::MockKernelLookup{}, + GraphOptimizerRegistry(nullptr, nullptr, file_logger.get()), + nullptr); + + ASSERT_EQ(compute_capabilities.size(), nodes_for_this_ep.empty() ? 
0 : 1); + + if (compute_capabilities.size() == 1) { + ASSERT_EQ(compute_capabilities[0]->sub_graph->nodes.size(), nodes_for_this_ep.size()); + + for (NodeIndex node_index : compute_capabilities[0]->sub_graph->nodes) { + const Node* node = graph_viewer.GetNode(node_index); + ASSERT_NE(node, nullptr); + EXPECT_EQ(nodes_for_this_ep.count(node->Name()), 1); + } + } + } + + ASSERT_TRUE(std::filesystem::exists(log_file)); + EXPECT_NO_FATAL_FAILURE(CheckStringInFile(log_file, expected_log_string)); + }; + + constexpr std::array node_names = {"add_0", "mul_0", "add_1"}; + + auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(); + + // Load a model and assign all of its nodes to another EP named 'OtherEp'. + // The plugin EP tries to claim all nodes in a single group via EpGraphSupportInfo_AddNodesToFuse. + // IExecutionProvider::GetCapability() should return an empty result and log a warning. + ort_ep->GetCapability = GetCapabilityTakeAllNodesOneGroup; + std::unordered_set nodes_for_other_ep = {"add_0", "mul_0", "add_1"}; + std::unordered_set nodes_for_this_ep; + run_test(*ep, nodes_for_other_ep, nodes_for_this_ep, + "Found one or more nodes that were already assigned to a different EP named 'OtherEp'"); + + // Load a model and assign only one node to another EP named 'OtherEp'. + // The plugin EP tries to claim all nodes in a single group. + // IExecutionProvider::GetCapability() should return an empty result and log a warning. + ort_ep->GetCapability = GetCapabilityTakeAllNodesOneGroup; + for (const char* node_name : node_names) { + nodes_for_other_ep = std::unordered_set{node_name}; + nodes_for_this_ep = std::unordered_set{}; + run_test(*ep, nodes_for_other_ep, nodes_for_this_ep, + "Found one or more nodes that were already assigned to a different EP named 'OtherEp'"); + } + + // Load a model and assign only the last Add node to another EP named 'OtherEp'. + // The plugin EP tries to claim all nodes in the following 2 groups: (add_0), (mul_0, add_1). + // IExecutionProvider::GetCapability() will only return (add_0) because the second group has a node + // that was assigned to 'OtherEp'. + ort_ep->GetCapability = GetCapabilityTakeAllNodesTwoGroups; + nodes_for_other_ep = std::unordered_set{"add_1"}; + nodes_for_this_ep = std::unordered_set{"add_0"}; + run_test(*ep, nodes_for_other_ep, nodes_for_this_ep, + "Found one or more nodes that were already assigned to a different EP named 'OtherEp'"); + + // Load a model and assign only the first Add node to another EP named 'OtherEp'. + // The plugin EP tries to claim all nodes in the following 2 groups: (add_0), (mul_0, add_1). + // IExecutionProvider::GetCapability() will only return (mul_0, add_1) because the first group has a node + // that was assigned to 'OtherEp'. + ort_ep->GetCapability = GetCapabilityTakeAllNodesTwoGroups; + nodes_for_other_ep = std::unordered_set{"add_0"}; + nodes_for_this_ep = std::unordered_set{"mul_0", "add_1"}; + run_test(*ep, nodes_for_other_ep, nodes_for_this_ep, + "Found one or more nodes that were already assigned to a different EP named 'OtherEp'"); + + // Load a model and assign the first Add node to another EP named 'OtherEp'. + // The plugin EP will try to take only the first Add node with a single call to EpGraphSupportInfo_AddSingleNode. + // IExecutionProvider::GetCapability() will return an empty result and log a warning. 
+ ort_ep->GetCapability = GetCapabilityTakeSingleNode; + nodes_for_other_ep = std::unordered_set{"add_0"}; + nodes_for_this_ep = std::unordered_set{}; + run_test(*ep, nodes_for_other_ep, nodes_for_this_ep, + "Found one or more nodes that were already assigned to a different EP named 'OtherEp'"); + + std::filesystem::remove(log_file); +} + } // namespace onnxruntime::test From 906259961c78cbac6c10d12e504f46fc10466345 Mon Sep 17 00:00:00 2001 From: quic-tirupath Date: Fri, 26 Sep 2025 11:55:57 -0700 Subject: [PATCH 09/10] [QNN EP] Add dynamic option to set HTP performance mode (#26135) ### Description Add a new EP Dynamic option to set HTP performance mode after session creation. --------- Co-authored-by: quic-ashwshan --- .../onnxruntime_session_options_config_keys.h | 7 +++++++ .../core/providers/qnn/qnn_execution_provider.cc | 11 +++++++++++ .../test/providers/qnn/qnn_ep_context_test.cc | 15 +++++++++++++++ 3 files changed, 33 insertions(+) diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 7eb5f7659a365..64a434e2fe301 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -408,3 +408,10 @@ static const char* const kOrtSessionOptionsDisableModelCompile = "session.disabl // Note: UNSUPPORTED models always fail regardless of this setting. static const char* const kOrtSessionOptionsFailOnSuboptimalCompiledModel = "session.fail_on_suboptimal_compiled_model"; + +// THIS OPTION IS NOT A REGULAR SESSION OPTION SINCE IT CAN BE MODIFIED AT ANY TIME +// Meant to be used with SetEpDynamicOptions +// options for HTP performance mode: "burst", "balanced", "default", "high_performance", +// "high_power_saver", "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver", +// "sustained_high_performance". Default to "default". 
+static const char* const kOrtEpDynamicOptionsQnnHtpPerformanceMode = "ep.dynamic.qnn_htp_performance_mode"; diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 3acb3347acee1..4a6545a0e6f0a 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -1576,6 +1576,17 @@ Status QNNExecutionProvider::SetEpDynamicOptions(gsl::span ke LOGS_DEFAULT(ERROR) << "Invalid EP Workload Type: " << value; return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid EP Workload Type."); } + } else if (key == kOrtEpDynamicOptionsQnnHtpPerformanceMode) { + auto backend_type = qnn_backend_manager_->GetQnnBackendType(); + if (qnn::QnnBackendType::HTP != backend_type && qnn::QnnBackendType::DSP != backend_type) { + return Status::OK(); + } + qnn::HtpPerformanceMode htp_performance_mode = qnn::HtpPerformanceMode::kHtpDefault; + ParseHtpPerformanceMode(value, htp_performance_mode); + if (GetPerThreadContext().IsHtpPowerConfigIdValid()) { + ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetHtpPowerConfig(GetPerThreadContext().GetHtpPowerConfigId(), + htp_performance_mode)); + } } else { LOGS_DEFAULT(ERROR) << "EP Dynamic Option \"" << key << "\" is not currently supported."; return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported EP Dynamic Option"); diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index 1c8cc6f78fe63..a2f1b9b56538b 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -2076,6 +2076,21 @@ TEST_F(QnnHTPBackendTests, QnnEpDynamicOptions) { } catch (const std::exception& e) { EXPECT_STREQ("Unsupported EP Dynamic Option", e.what()); } + + const char* const htp_perf_mode_type[] = {"ep.dynamic.qnn_htp_performance_mode"}; + const char* const eps_type[] = {"extreme_power_saver"}; + const char* const shp_type[] = {"sustained_high_performance"}; + session.SetEpDynamicOptions(htp_perf_mode_type, shp_type, 1); + ort_output = session.Run(Ort::RunOptions{}, input_names_c.data(), ort_inputs.data(), ort_inputs.size(), + output_names_c.data(), 1); + + session.SetEpDynamicOptions(htp_perf_mode_type, eps_type, 1); + ort_output = session.Run(Ort::RunOptions{}, input_names_c.data(), ort_inputs.data(), ort_inputs.size(), + output_names_c.data(), 1); + + session.SetEpDynamicOptions(htp_perf_mode_type, shp_type, 1); + ort_output = session.Run(Ort::RunOptions{}, input_names_c.data(), ort_inputs.data(), ort_inputs.size(), + output_names_c.data(), 1); } // Implementation of OrtOutStreamWriteFunc that writes the compiled model to a file. 
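For reference, the new HTP performance-mode option is driven entirely through the existing `SetEpDynamicOptions` API, exactly as in the `qnn_ep_context_test.cc` additions above. The sketch below is a minimal illustration of that usage; the QNN provider options (in particular the `backend_path` value), the model path, and the surrounding `Run` calls are assumptions for the example, not values taken from this patch.

```cpp
#include "onnxruntime_cxx_api.h"

int main() {
  Ort::Env env;
  Ort::SessionOptions so;
  // Assumption: register the QNN EP with an HTP backend library; adjust backend_path for your platform.
  so.AppendExecutionProvider("QNN", {{"backend_path", "QnnHtp.dll"}});

  // Placeholder model path.
  Ort::Session session(env, ORT_TSTR("model.onnx"), so);

  // The performance mode can be changed at any time after session creation.
  const char* const keys[] = {"ep.dynamic.qnn_htp_performance_mode"};

  const char* const burst[] = {"burst"};
  session.SetEpDynamicOptions(keys, burst, 1);
  // ... prepare inputs and call session.Run(...) while in burst mode ...

  const char* const default_mode[] = {"default"};
  session.SetEpDynamicOptions(keys, default_mode, 1);
  // ... subsequent session.Run(...) calls use the default HTP performance mode ...

  return 0;
}
```

Per the `qnn_execution_provider.cc` change above, the option is a no-op when the active QNN backend is not HTP or DSP.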
From 41b4da6e3eccdff573b7aea619067a9d9a72ac28 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Fri, 26 Sep 2025 14:16:12 -0700 Subject: [PATCH 10/10] Re-enable inference tests that test the I/O memory info C APIs --- onnxruntime/test/shared_lib/test_inference.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 5199730ae323d..8c2928670934a 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -494,7 +494,7 @@ INSTANTIATE_TEST_SUITE_P(CApiTestWithProviders, CApiTestWithProvider, ::testing::Values(0, 1, 2, 3, 4)); -TEST(CApiTest, DISABLED_TestInputPassThroughToOutput) { +TEST(CApiTest, TestInputPassThroughToOutput) { const ORTCHAR_T* model_uri = TSTR("testdata/input_propagated_to_output.onnx"); Ort::Session session(*ort_env, model_uri, Ort::SessionOptions{}); auto inputs_meminfos = session.GetMemoryInfoForInputs(); @@ -505,7 +505,7 @@ TEST(CApiTest, DISABLED_TestInputPassThroughToOutput) { ASSERT_EQ(7U, outputs_meminfos.size()); } -TEST(CApiTest, DISABLED_TestDanglingInput) { +TEST(CApiTest, TestDanglingInput) { // Here we test an issue with segments_ids that is an input not consumed by anything // This kind of model is unlikely to be used in practice but we want to make sure it works const ORTCHAR_T* model_uri = TSTR("testdata/test_dangling_input_segment_ids.onnx");
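The two re-enabled tests exercise the per-input/output memory info C APIs through the C++ wrapper. As a standalone illustration of the same calls, here is a minimal sketch; the model path is a placeholder, and the count comparisons mirror what the tests assert (one memory info per declared input/output). `GetMemoryInfoForOutputs` is assumed to parallel the inputs variant, as the tests' `outputs_meminfos` usage suggests.

```cpp
#include <iostream>

#include "onnxruntime_cxx_api.h"

int main() {
  Ort::Env env;
  // Placeholder model path; the tests use testdata/input_propagated_to_output.onnx and
  // testdata/test_dangling_input_segment_ids.onnx.
  Ort::Session session(env, ORT_TSTR("model.onnx"), Ort::SessionOptions{});

  // One OrtMemoryInfo per model input/output, as asserted by the re-enabled tests.
  auto inputs_meminfos = session.GetMemoryInfoForInputs();
  auto outputs_meminfos = session.GetMemoryInfoForOutputs();

  std::cout << "inputs: " << inputs_meminfos.size()
            << " (declared: " << session.GetInputCount() << ")\n";
  std::cout << "outputs: " << outputs_meminfos.size()
            << " (declared: " << session.GetOutputCount() << ")\n";
  return 0;
}
```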