diff --git a/cmake/onnxruntime_session.cmake b/cmake/onnxruntime_session.cmake index 3ec3c6ee1d5ae..f81a7a9726b76 100644 --- a/cmake/onnxruntime_session.cmake +++ b/cmake/onnxruntime_session.cmake @@ -5,6 +5,8 @@ file(GLOB onnxruntime_session_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_INCLUDE_DIR}/core/session/*.h" "${ONNXRUNTIME_ROOT}/core/session/*.h" "${ONNXRUNTIME_ROOT}/core/session/*.cc" + "${ONNXRUNTIME_ROOT}/core/session/plugin_ep/*.h" + "${ONNXRUNTIME_ROOT}/core/session/plugin_ep/*.cc" ) if (onnxruntime_ENABLE_TRAINING_APIS) @@ -22,7 +24,7 @@ endif() # which is not enabled for any minimal builds. if (onnxruntime_MINIMAL_BUILD) file(GLOB autoep_srcs - "${ONNXRUNTIME_ROOT}/core/session/ep_*.*" + "${ONNXRUNTIME_ROOT}/core/session/plugin_ep/*.*" ) set(onnxruntime_session_src_exclude diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Exceptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Exceptions.shared.cs index 098a18b7444cf..2467475b6b189 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Exceptions.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Exceptions.shared.cs @@ -23,8 +23,8 @@ internal enum ErrorCode ModelLoaded = 8, NotImplemented = 9, InvalidGraph = 10, - ShapeInferenceNotRegistered = 11, - RequirementNotRegistered = 12, + ShapeInferenceNotRegistered = 11, // TODO: should be ORT_EP_FAIL + RequirementNotRegistered = 12, // TODO: should be ORT_MODEL_LOAD_CANCELED } /// diff --git a/include/onnxruntime/core/common/status.h b/include/onnxruntime/core/common/status.h index da9735aa4e418..8cf6420f2d0f7 100644 --- a/include/onnxruntime/core/common/status.h +++ b/include/onnxruntime/core/common/status.h @@ -46,6 +46,7 @@ enum StatusCode { EP_FAIL = 11, MODEL_LOAD_CANCELED = 12, MODEL_REQUIRES_COMPILATION = 13, + NOT_FOUND = 14, }; constexpr const char* StatusCodeToString(StatusCode status) noexcept { @@ -78,6 +79,8 @@ constexpr const char* StatusCodeToString(StatusCode status) noexcept { return "MODEL_LOAD_CANCELED"; case 
StatusCode::MODEL_REQUIRES_COMPILATION: return "MODEL_REQUIRES_COMPILATION"; + case StatusCode::NOT_FOUND: + return "NOT_FOUND"; default: return "GENERAL ERROR"; } @@ -114,6 +117,8 @@ constexpr HRESULT StatusCodeToHRESULT(StatusCode status) noexcept { return HRESULT_FROM_WIN32(ERROR_CANCELLED); case StatusCode::MODEL_REQUIRES_COMPILATION: return HRESULT_FROM_WIN32(ERROR_NOT_SUPPORTED); + case StatusCode::NOT_FOUND: + return HRESULT_FROM_WIN32(ERROR_NOT_FOUND); default: return E_FAIL; } diff --git a/include/onnxruntime/core/session/environment.h b/include/onnxruntime/core/session/environment.h index 7e49275e59b8b..306f81df38e48 100644 --- a/include/onnxruntime/core/session/environment.h +++ b/include/onnxruntime/core/session/environment.h @@ -20,7 +20,7 @@ #include "core/platform/threadpool.h" #include "core/session/abi_devices.h" -#include "core/session/ep_library.h" +#include "core/session/plugin_ep/ep_library.h" #include "core/session/onnxruntime_c_api.h" struct OrtThreadingOptions; diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 2f0e4aa7ce108..cf5ad29b03801 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -264,6 +264,7 @@ typedef enum OrtErrorCode { ORT_EP_FAIL, ORT_MODEL_LOAD_CANCELED, ORT_MODEL_REQUIRES_COMPILATION, + ORT_NOT_FOUND, } OrtErrorCode; typedef enum OrtOpAttrType { @@ -5846,14 +5847,13 @@ struct OrtApi { /** \brief Returns an OrtGraph that contains a subset of nodes in the source OrtGraph. * - * Note: - * The lifetime of "dst_graph" is tied to that of "src_graph", as they both internally reference + * \note The lifetime of "dst_graph" is tied to that of "src_graph", as they both internally reference * the same underlying graph. * * \param[in] src_graph The source OrtGraph instance. * \param[in] nodes A subset of the nodes/OrtNodes in 'graph'. * \param[in] num_nodes Number of nodes. 
- * \param[out] dst_sub_graph An OrtGraph created from a given set of nodes. Must be released by calling ReleaseGraph. + * \param[out] dst_graph An OrtGraph created from a given set of nodes. Must be released by calling ReleaseGraph. * * \snippet{doc} snippets.dox OrtStatus Return Value * @@ -6032,6 +6032,11 @@ struct OrtApi { * Typical usage sets this to the result of Node_GetNumAttributes(). An error status is * returned if `num_attributes` is less than the number of node attributes. * + * \note ONNX Runtime automatically sets optional (unset) attributes to their default values if the default value + * is a constant expression that does not depend on other tensor/model characteristics. Conv's 'kernel_shape' + * attribute is an example of an optional attribute that does not have a constant default value. This function + * does not provide any unset optional attributes without a constant default value. + * * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.23. @@ -6043,14 +6048,22 @@ struct OrtApi { * * \param[in] node The OrtNode instance. * \param[in] attribute_name The name of the attribute - * \param[out] attribute Output the attribute if its name matches 'attribute_name', otherwise output nullptr. + * \param[out] attribute Output parameter set to the OrtOpAttr instance if an attribute by the given name exists. + * For an unset optional attribute, `attribute` is set to NULL and a non-error status is + * returned. For an invalid attribute name, `attribute` is set to NULL and an error status with + * code ORT_NOT_FOUND is returned. + * + * \note ONNX Runtime automatically sets optional (unset) attributes to their default values if the default value + * is a constant expression that does not depend on other tensor/model characteristics. Conv's 'kernel_shape' + * attribute is an example of an optional attribute that does not have a constant default value. 
This function + * does not provide any unset optional attributes without a constant default value. * * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.23. */ ORT_API2_STATUS(Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, - _Outptr_ const OrtOpAttr** attribute); + _Outptr_result_maybenull_ const OrtOpAttr** attribute); /** \brief Get the attribute type as OrtOpAttrType from an OrtOpAttr. * diff --git a/java/src/main/java/ai/onnxruntime/OrtException.java b/java/src/main/java/ai/onnxruntime/OrtException.java index 5ec58ea137124..06c3d3cbc770c 100644 --- a/java/src/main/java/ai/onnxruntime/OrtException.java +++ b/java/src/main/java/ai/onnxruntime/OrtException.java @@ -81,11 +81,17 @@ public enum OrtErrorCode { /** The ONNX graph is invalid. */ ORT_INVALID_GRAPH(10), /** The ORT execution provider failed. */ - ORT_EP_FAIL(11); + ORT_EP_FAIL(11), + /** Model load was canceled. */ + ORT_MODEL_LOAD_CANCELED(12), + /** Model requires compilation. */ + ORT_MODEL_REQUIRES_COMPILATION(13), + /** Item was not found. 
*/ + ORT_NOT_FOUND(14); private final int value; - private static final OrtErrorCode[] values = new OrtErrorCode[12]; + private static final OrtErrorCode[] values = new OrtErrorCode[15]; static { for (OrtErrorCode ot : OrtErrorCode.values()) { diff --git a/java/src/main/native/OrtJniUtil.c b/java/src/main/native/OrtJniUtil.c index fe19015d642f0..5d8efd7b476cb 100644 --- a/java/src/main/native/OrtJniUtil.c +++ b/java/src/main/native/OrtJniUtil.c @@ -1051,6 +1051,12 @@ jint convertErrorCode(OrtErrorCode code) { return 10; case ORT_EP_FAIL: return 11; + case ORT_MODEL_LOAD_CANCELED: + return 12; + case ORT_MODEL_REQUIRES_COMPILATION: + return 13; + case ORT_NOT_FOUND: + return 14; default: return -1; // Unknown error code } diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc index e2bb3b508ca7c..85a2cbaea0e44 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+#include "core/common/cpuid_info.h" // for CPUIDInfo::GetCPUIDInfo().HasArm_SME() #include "core/common/narrow.h" #include "core/common/safeint.h" #include "core/mlas/inc/mlas.h" @@ -10,6 +11,7 @@ #include "core/util/math_cpuonly.h" #include "core/util/qmath.h" +#include #include #include @@ -169,43 +171,40 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase { // only pack Matrix B if (input_idx == GetBIdx()) { const Tensor* b_zp_constant_tensor{nullptr}; - bool b_quantization_is_asymmetric = false; + bool b_quantization_might_be_asymmetric = false; - // zero point tensor could be provided as a direct input to the kernel and not as a constant so this - // test is not sufficient const OrtValue* b_zp; if (Info().TryGetConstantInput(IN_B_ZERO_POINT, &b_zp)) { b_zp_constant_tensor = &b_zp->Get(); } - // MlasDynamicQgemm requires symmetric quantization for B, so no zero point should exist or it should - // have a zero value - if (b_zp_constant_tensor != nullptr) { // Covers the case where tensor is not a constant - const auto& shape = b_zp_constant_tensor->Shape(); - const auto* zp_data = static_cast(b_zp_constant_tensor->DataRaw()); - size_t zp_size = static_cast(shape.Size()); - // MlasDynamicQgemm requires symmetric quantization: zp must be scalar 0 or 1D all-zero - if ((shape.NumDimensions() == 0) && (zp_data[0] == 0)) { - b_quantization_is_asymmetric = false; - } else if (shape.NumDimensions() == 1) { - b_quantization_is_asymmetric = false; - for (size_t i = 0; i < zp_size; ++i) { - if (zp_data[i] != 0) { - b_quantization_is_asymmetric = true; - break; - } - } - } else { - // Unsupported higher-rank zp tensor - b_quantization_is_asymmetric = true; - } + // MlasDynamicQgemm requires symmetric quantization for B, so the B zero point value should either be all zeros + // or not provided. + if (b_zp_constant_tensor != nullptr) { + // B zero point is constant. Check if it is all zeros. 
+ assert(b_zp_constant_tensor->IsDataType<uint8_t>() || b_zp_constant_tensor->IsDataType<int8_t>()); + const auto* zp_bytes = static_cast<const std::byte*>(b_zp_constant_tensor->DataRaw()); + const size_t zp_size_in_bytes = b_zp_constant_tensor->SizeInBytes(); + b_quantization_might_be_asymmetric = std::any_of(zp_bytes, zp_bytes + zp_size_in_bytes, + [](std::byte v) { return v != std::byte{0}; }); + } else { + // B zero point input is not constant. If it exists, we can't assume symmetric quantization. + const auto input_defs = Info().node().InputDefs(); + const bool b_zp_input_exists = input_defs.size() > IN_B_ZERO_POINT && input_defs[IN_B_ZERO_POINT]->Exists(); + b_quantization_might_be_asymmetric = b_zp_input_exists; } // MlasDynamicQgemm requires scale data to be available at packing stage const Tensor* b_scale_tensor = nullptr; const bool b_scale_available = Info().TryGetConstantInput(IN_B_SCALE, &b_scale_tensor); - can_use_dynamic_quant_mlas_ = (!b_quantization_is_asymmetric && b_scale_available); + can_use_dynamic_quant_mlas_ = (!b_quantization_might_be_asymmetric && b_scale_available); + + // Currently, MlasDynamicQGemmBatch() and associated functions require SME or else they are no-ops. + // We check that here too before attempting to use them. + if (!CPUIDInfo::GetCPUIDInfo().HasArm_SME()) { + can_use_dynamic_quant_mlas_ = false; + } // Only handle the common case of a 2D weight matrix. Additional matrices // could be handled by stacking the packed buffers. 
diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index 4ceadb6191a9b..493fbff897af8 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -87,6 +87,24 @@ static void ConvertNodeArgsToValueInfos(const EpGraph* ep_graph, } } +#if !defined(ORT_MINIMAL_BUILD) +static bool IsOptionalAttribute(const Node& node, const std::string& attr_name) { + const ONNX_NAMESPACE::OpSchema* op_schema = node.Op(); + if (op_schema == nullptr) { + return false; + } + + auto attr_schema_iter = op_schema->attributes().find(attr_name); + if (attr_schema_iter == op_schema->attributes().end()) { + return false; // Not an attribute for this operator type. + } + + const ONNX_NAMESPACE::OpSchema::Attribute& attr_schema = attr_schema_iter->second; + + return !attr_schema.required; +} +#endif // !defined(ORT_MINIMAL_BUILD) + // // EpNode // @@ -268,13 +286,20 @@ gsl::span EpNode::GetOutputsSpan() const { return outputs_; } -const OrtOpAttr* EpNode::GetAttribute(const std::string& name) const { +const OrtOpAttr* EpNode::GetAttribute(const std::string& name, bool& is_unset_optional_attr) const { auto iter = attributes_map_.find(name); - if (iter == attributes_map_.end()) { - return nullptr; - } else { + if (iter != attributes_map_.end()) { + is_unset_optional_attr = false; return reinterpret_cast(iter->second.get()); } + +#if !defined(ORT_MINIMAL_BUILD) + is_unset_optional_attr = IsOptionalAttribute(node_, name); +#else + // This is not properly set in a minimal build because it does not have access to the operator schema. 
+ is_unset_optional_attr = false; +#endif // !defined(ORT_MINIMAL_BUILD) + return nullptr; } const std::string& EpNode::GetEpName() const { diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h index 243bdc2944ffb..e61bb4d62dba6 100644 --- a/onnxruntime/core/graph/ep_api_types.h +++ b/onnxruntime/core/graph/ep_api_types.h @@ -209,8 +209,9 @@ struct EpNode : public OrtNode { // Helper that returns this node's outputs as a span of EpValueInfo pointers. gsl::span GetOutputsSpan() const; - // Helper that gets the node's attributes by name. - const OrtOpAttr* GetAttribute(const std::string& name) const; + // Helper that gets the node's attributes by name. If the attribute is not set, returns NULL and sets the + // output parameter `is_unset_optional_attr` to true if this is an unset optional attribute. + const OrtOpAttr* GetAttribute(const std::string& name, bool& is_unset_optional_attr) const; // Helper that gets the execution provider name that this node is assigned to run on. 
const std::string& GetEpName() const; diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index 450a8bad09392..2b553aecbca6c 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -16,10 +16,10 @@ #include "core/session/abi_session_options_impl.h" #include "core/session/allocator_adapters.h" #include "core/session/inference_session.h" -#include "core/session/ep_factory_internal.h" -#include "core/session/ep_library_internal.h" -#include "core/session/ep_library_plugin.h" -#include "core/session/ep_library_provider_bridge.h" +#include "core/session/plugin_ep/ep_factory_internal.h" +#include "core/session/plugin_ep/ep_library_internal.h" +#include "core/session/plugin_ep/ep_library_plugin.h" +#include "core/session/plugin_ep/ep_library_provider_bridge.h" #include "core/session/ort_apis.h" #include "core/session/utils.h" diff --git a/onnxruntime/core/session/ep_library_internal.cc b/onnxruntime/core/session/ep_library_internal.cc deleted file mode 100644 index 986ccb1fa17fc..0000000000000 --- a/onnxruntime/core/session/ep_library_internal.cc +++ /dev/null @@ -1,281 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/session/ep_library_internal.h" - -#include "core/framework/error_code_helper.h" -#include "core/framework/ortmemoryinfo.h" -#include "core/framework/session_options.h" -#include "core/providers/cpu/cpu_execution_provider.h" -#include "core/session/abi_devices.h" -#include "core/session/abi_logger.h" -#include "core/session/abi_session_options_impl.h" -#include "core/session/ep_api.h" -#include "core/session/ort_apis.h" - -#if defined(USE_DML) -#include "core/providers/dml/dml_provider_factory_creator.h" -#endif - -#if defined(USE_WEBGPU) -#include "core/providers/webgpu/webgpu_provider_factory_creator.h" -#endif - -namespace onnxruntime { - -class CpuEpFactory : public EpFactoryInternalImpl { - public: - CpuEpFactory() : EpFactoryInternalImpl(kCpuExecutionProvider, "Microsoft", OrtDevice::VendorIds::MICROSOFT) { - } - - private: - OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, - const OrtHardwareDevice* const* devices, - size_t num_devices, - OrtEpDevice** ep_devices, - size_t max_ep_devices, - size_t* p_num_ep_devices) noexcept override { - size_t& num_ep_devices = *p_num_ep_devices; - for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { - const OrtHardwareDevice& device = *devices[i]; - if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) { - ORT_API_RETURN_IF_ERROR( - OrtExecutionProviderApi::CreateEpDevice(&ep_factory, &device, nullptr, nullptr, - &ep_devices[num_ep_devices++])); - } - } - - return nullptr; - } - - OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, - const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, - size_t num_devices, - const OrtSessionOptions* session_options, - const OrtLogger* session_logger, - std::unique_ptr* ep) noexcept override { - if (num_devices != 1) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, - "CPU EP factory currently only supports one device at a time."); - } - - CPUExecutionProviderInfo 
epi{session_options->value.enable_cpu_mem_arena}; - *ep = std::make_unique(epi); - (*ep)->SetLogger(session_logger->ToInternal()); - - return nullptr; - } -}; - -std::unique_ptr EpLibraryInternal::CreateCpuEp() { - auto cpu_factory_impl = std::make_unique(); - auto internal_factory = std::make_unique(std::move(cpu_factory_impl)); - return std::make_unique(std::move(internal_factory)); -} - -#if defined(USE_DML) -class DmlEpFactory : public EpFactoryInternalImpl { - public: - DmlEpFactory() : EpFactoryInternalImpl(kDmlExecutionProvider, "Microsoft", OrtDevice::VendorIds::MICROSOFT) { - } - - private: - OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, - const OrtHardwareDevice* const* devices, - size_t num_devices, - OrtEpDevice** ep_devices, - size_t max_ep_devices, - size_t* p_num_ep_devices) noexcept override { - size_t& num_ep_devices = *p_num_ep_devices; - num_ep_devices = 0; - - for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { - const OrtHardwareDevice& device = *devices[i]; - if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { - std::unique_ptr ep_options; - - // TODO: Should we ignore a user provided 'device_id' when they select an OrtEpDevice as that is - // associated with a specific device. - // How would we know what options should not allow user overrides if set in OrtEpDevice? 
- int32_t device_id = 0; // If no device_id was found default to 0 - if (auto it = device.metadata.Entries().find("DxgiAdapterNumber"); it != device.metadata.Entries().end()) { - ep_options = std::make_unique(); - device_id = std::stoi(it->second); - } - - ep_options->Add("device_id", std::to_string(device_id)); - - auto* api_status = OrtExecutionProviderApi::CreateEpDevice(&ep_factory, - &device, nullptr, ep_options.get(), - &ep_devices[num_ep_devices]); - - if (device_memory_infos.size() < device_id + 1) { - device_memory_infos.resize(device_id + 1); - device_allocators.resize(device_id + 1); - } - - if (device_memory_infos[device_id] == nullptr) { - // Create memory info for the device if it doesn't already exist - device_memory_infos[device_id] = std::make_unique( - "DML", OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::MICROSOFT, - narrow(device_id))); - } - - // This is what we need to add once CreateAllocator is implemented to create a shared allocator for the device. 
- // OrtExecutionProviderApi::EpDevice_AddAllocatorInfo(ep_devices[num_ep_devices], - // device_memory_infos[device_id].get()); - - if (api_status != nullptr) { - return api_status; - } - - ++num_ep_devices; - } - } - - return nullptr; - } - - OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, - const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, - size_t num_devices, - const OrtSessionOptions* session_options, - const OrtLogger* session_logger, - std::unique_ptr* ep) noexcept override { - *ep = nullptr; - - if (num_devices != 1) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, - "DML EP factory currently only supports one device at a time."); - } - - auto ep_options = GetOptionsFromSessionOptions(session_options->value); - auto dml_ep_factory = DMLProviderFactoryCreator::CreateFromProviderOptions(session_options->value.config_options, - ep_options); - - *ep = dml_ep_factory->CreateProvider(); - (*ep)->SetLogger(session_logger->ToInternal()); - - return nullptr; - } - - OrtStatus* CreateAllocator(const OrtMemoryInfo* /*memory_info*/, - const OrtKeyValuePairs* /*allocator_options*/, - OrtAllocator** allocator) noexcept override { - // TODO: This needs to create an allocator for the specific device so it's available as a shared allocator. That - // requires pulling lots of things out of the DML EP to get the D3D12 device and create a - // BucketizedBufferAllocator. See providers\dml\DmlExecutionProvider\src\ExecutionProvider.cpp - //*allocator = device_allocators[memory_info->device.Id()].get(); - *allocator = nullptr; - return nullptr; - } - - OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) override { - // TODO: Wrap the IDataTransfer implementation so we can copy to device using OrtApi CopyTensors. 
- *data_transfer = nullptr; - return nullptr; - } - - std::vector> device_memory_infos; // memory info for each device - std::vector> device_allocators; // allocators for each device -}; - -std::unique_ptr EpLibraryInternal::CreateDmlEp() { - auto dml_factory_impl = std::make_unique(); - auto internal_factory = std::make_unique(std::move(dml_factory_impl)); - return std::make_unique(std::move(internal_factory)); -} -#endif - -#if defined(USE_WEBGPU) -class WebGpuEpFactory : public EpFactoryInternalImpl { - public: - WebGpuEpFactory() : EpFactoryInternalImpl(kWebGpuExecutionProvider, "Microsoft", OrtDevice::VendorIds::MICROSOFT) { - } - - private: - OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, - const OrtHardwareDevice* const* devices, - size_t num_devices, - OrtEpDevice** ep_devices, - size_t max_ep_devices, - size_t* p_num_ep_devices) noexcept override { - size_t& num_ep_devices = *p_num_ep_devices; - num_ep_devices = 0; - - for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { - const OrtHardwareDevice& device = *devices[i]; - if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { - // TODO: any metadata or options to add? 
- ORT_API_RETURN_IF_ERROR(OrtExecutionProviderApi::CreateEpDevice(&ep_factory, - &device, nullptr, nullptr, - &ep_devices[num_ep_devices++])); - } - } - - return nullptr; - } - - OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, - const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, - size_t num_devices, - const OrtSessionOptions* session_options, - const OrtLogger* session_logger, - std::unique_ptr* ep) noexcept override { - *ep = nullptr; - - if (num_devices != 1) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, - "WebGPU EP factory currently only supports one device at a time."); - } - - auto webgpu_ep_factory = WebGpuProviderFactoryCreator::Create(session_options->value.config_options); - *ep = webgpu_ep_factory->CreateProvider(); - (*ep)->SetLogger(session_logger->ToInternal()); - - return nullptr; - } - - /* TODO: Implement CreateAllocator and CreateDataTransfer to support shared allocators and data transfer outside of - an InferenceSession. - OrtStatus* CreateAllocator(const OrtMemoryInfo* memory_info, - const OrtKeyValuePairs* allocator_options, - OrtAllocator** allocator) noexcept override { - *allocator = device_allocators[memory_info->device.Id()].get(); - } - - OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) override { - // TODO: Wrap the IDataTransfer implementation so we can copy to device using OrtApi CopyTensors. 
- *data_transfer = nullptr; - return nullptr; - } - */ -}; - -std::unique_ptr EpLibraryInternal::CreateWebGpuEp() { - auto webgpu_factory_impl = std::make_unique(); - auto internal_factory = std::make_unique(std::move(webgpu_factory_impl)); - return std::make_unique(std::move(internal_factory)); -} -#endif - -std::vector> EpLibraryInternal::CreateInternalEps() { - std::vector> internal_eps; - internal_eps.reserve(4); - - // CPU EP - internal_eps.push_back(CreateCpuEp()); - -#if defined(USE_WEBGPU) - internal_eps.push_back(CreateWebGpuEp()); -#endif - -#if defined(USE_DML) - internal_eps.push_back(CreateDmlEp()); -#endif - - return internal_eps; -} - -} // namespace onnxruntime diff --git a/onnxruntime/core/session/ep_library_provider_bridge.cc b/onnxruntime/core/session/ep_library_provider_bridge.cc deleted file mode 100644 index ae553891beaa7..0000000000000 --- a/onnxruntime/core/session/ep_library_provider_bridge.cc +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/session/ep_library_provider_bridge.h" - -#include "core/common/status.h" -#include "core/framework/error_code_helper.h" -#include "core/framework/session_options.h" -#include "core/providers/cuda/cuda_provider_options.h" -#include "core/providers/shared_library/provider_host_api.h" -#include "core/session/abi_devices.h" -#include "core/session/abi_session_options_impl.h" -#include "core/session/ep_factory_internal.h" - -namespace onnxruntime { -class ProviderBridgeEpFactory : public EpFactoryInternalImpl { - public: - ProviderBridgeEpFactory(OrtEpFactory& ep_factory, ProviderLibrary& provider_library) - : EpFactoryInternalImpl(ep_factory.GetName(&ep_factory), - ep_factory.GetVendor(&ep_factory), - ep_factory.GetVendorId(&ep_factory)), - ep_factory_{ep_factory}, - provider_library_{provider_library} { - } - - private: - OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, - const OrtHardwareDevice* const* devices, - size_t num_devices, - OrtEpDevice** ep_devices, - size_t max_ep_devices, - size_t* num_ep_devices) noexcept override { - ORT_API_RETURN_IF_ERROR(ep_factory_.GetSupportedDevices(&ep_factory_, devices, num_devices, ep_devices, - max_ep_devices, num_ep_devices)); - - // add the EpFactoryInternal layer back in so that we can redirect to CreateIExecutionProvider. 
- for (size_t i = 0; i < *num_ep_devices; ++i) { - auto* ep_device = ep_devices[i]; - if (ep_device) { - ep_device->ep_factory = &ep_factory; - } - } - - return nullptr; - } - - OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* devices, - const OrtKeyValuePairs* const* ep_metadata_pairs, - size_t num_devices, - const OrtSessionOptions* session_options, - const OrtLogger* session_logger, - std::unique_ptr* ep) noexcept override { - // get the provider specific options - auto ep_options = GetOptionsFromSessionOptions(session_options->value); - auto& provider = provider_library_.Get(); - - auto status = provider.CreateIExecutionProvider(devices, ep_metadata_pairs, num_devices, - ep_options, *session_options, *session_logger, *ep); - - return ToOrtStatus(status); - } - - OrtStatus* CreateAllocator(const OrtMemoryInfo* memory_info, - const OrtKeyValuePairs* allocator_options, - OrtAllocator** allocator) noexcept override { - return ep_factory_.CreateAllocator(&ep_factory_, memory_info, allocator_options, allocator); - } - - void ReleaseAllocator(OrtAllocator* allocator) noexcept override { - ep_factory_.ReleaseAllocator(&ep_factory_, allocator); - } - - OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) override { - return ep_factory_.CreateDataTransfer(&ep_factory_, data_transfer); - } - - bool IsStreamAware() const noexcept override { - return ep_factory_.IsStreamAware(&ep_factory_); - } - - OrtStatus* CreateSyncStreamForDevice(const OrtMemoryDevice* device, - const OrtKeyValuePairs* stream_options, - OrtSyncStreamImpl** stream) noexcept override { - return ep_factory_.CreateSyncStreamForDevice(&ep_factory_, device, stream_options, stream); - } - - OrtEpFactory& ep_factory_; // OrtEpFactory from the provider bridge EP - ProviderLibrary& provider_library_; // ProviderLibrary from the provider bridge EP -}; - -Status EpLibraryProviderBridge::Load() { - std::lock_guard lock{mutex_}; - - if (!factories_.empty()) 
{ - // already loaded - return Status::OK(); - } - - // if we have been unloaded we can't just be reloaded. - if (!ep_library_plugin_ || !provider_library_) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "EpLibraryProviderBridge has been unloaded. " - "Please create a new instance using LoadPluginOrProviderBridge."); - } - - // wrap the EpLibraryPlugin factories that were created via calling CreateEpFactories in the library. - // use GetSupportedDevices from the library's factory. - // to do this we need to capture `factory` and plug it in to is_supported_fn and create_fn. - // we also need to update any returned OrtEpDevice instances to swap the wrapper EpFactoryInternal in so that we can - // call Provider::CreateIExecutionProvider in EpFactoryInternal::CreateIExecutionProvider. - for (const auto& factory : ep_library_plugin_->GetFactories()) { - auto factory_impl = std::make_unique(*factory, *provider_library_); - auto internal_factory = std::make_unique(std::move(factory_impl)); - - factory_ptrs_.push_back(internal_factory.get()); - internal_factory_ptrs_.push_back(internal_factory.get()); - factories_.push_back(std::move(internal_factory)); - } - - return Status::OK(); -} - -Status EpLibraryProviderBridge::Unload() { - std::lock_guard lock{mutex_}; - - internal_factory_ptrs_.clear(); - factory_ptrs_.clear(); - factories_.clear(); - - // we loaded ep_library_plugin_ after provider_library_ in LoadPluginOrProviderBridge so do the reverse order here. 
- ORT_RETURN_IF_ERROR(ep_library_plugin_->Unload()); - ep_library_plugin_ = nullptr; - - provider_library_->Unload(); - provider_library_ = nullptr; - - return Status::OK(); -} - -} // namespace onnxruntime diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 27f81b18be0c9..ae9a86aa923fc 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -38,8 +38,8 @@ #include "core/session/allocator_adapters.h" #include "core/session/compile_api.h" #include "core/session/environment.h" -#include "core/session/ep_api.h" -#include "core/session/ep_library_internal.h" +#include "core/session/plugin_ep/ep_api.h" +#include "core/session/plugin_ep/ep_library_internal.h" #include "core/session/inference_session.h" #include "core/session/inference_session_utils.h" #include "core/session/IOBinding.h" @@ -2993,7 +2993,8 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributes, _In_ const OrtNode* node, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, _Outptr_ const OrtOpAttr** attribute) { +ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, + _Outptr_result_maybenull_ const OrtOpAttr** attribute) { API_IMPL_BEGIN if (attribute == nullptr) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'attribute' argument is NULL"); @@ -3004,14 +3005,16 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributeByName, _In_ const OrtNode* node, return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "node is a ModelEditorNode which doesn't support Node_GetAttributeByName."); } - *attribute = ep_node->GetAttribute(attribute_name); + bool is_unset_optional_attr = false; + *attribute = ep_node->GetAttribute(attribute_name, is_unset_optional_attr); - if (*attribute) { + if (*attribute || is_unset_optional_attr) { return nullptr; } else { - return 
OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Attribute does not exist."); + std::ostringstream oss; + oss << "Node attribute does not exist: " << attribute_name; + return OrtApis::CreateStatus(OrtErrorCode::ORT_NOT_FOUND, oss.str().c_str()); } - API_IMPL_END } diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index d2f22397bf82c..9636c41938a2b 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -678,7 +678,7 @@ ORT_API_STATUS_IMPL(Node_GetNumAttributes, _In_ const OrtNode* node, _Out_ size_ ORT_API_STATUS_IMPL(Node_GetAttributes, _In_ const OrtNode* node, _Out_writes_(num_attributes) const OrtOpAttr** attributes, _In_ size_t num_attributes); ORT_API_STATUS_IMPL(Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, - _Outptr_ const OrtOpAttr** attribute); + _Outptr_result_maybenull_ const OrtOpAttr** attribute); ORT_API_STATUS_IMPL(OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOpAttrType* type); ORT_API_STATUS_IMPL(OpAttr_GetName, _In_ const OrtOpAttr* attribute, _Outptr_ const char** name); ORT_API_STATUS_IMPL(Node_GetNumSubgraphs, _In_ const OrtNode* node, _Out_ size_t* num_subgraphs); diff --git a/onnxruntime/core/session/ep_api.cc b/onnxruntime/core/session/plugin_ep/ep_api.cc similarity index 99% rename from onnxruntime/core/session/ep_api.cc rename to onnxruntime/core/session/plugin_ep/ep_api.cc index 8fd1fc198374f..cae0b086af66c 100644 --- a/onnxruntime/core/session/ep_api.cc +++ b/onnxruntime/core/session/plugin_ep/ep_api.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "core/session/ep_api.h" +#include "core/session/plugin_ep/ep_api.h" #include #include diff --git a/onnxruntime/core/session/ep_api.h b/onnxruntime/core/session/plugin_ep/ep_api.h similarity index 100% rename from onnxruntime/core/session/ep_api.h rename to onnxruntime/core/session/plugin_ep/ep_api.h diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_cpu.cc b/onnxruntime/core/session/plugin_ep/ep_factory_cpu.cc new file mode 100644 index 0000000000000..7e6d0dd2ae5df --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_cpu.cc @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/session/plugin_ep/ep_factory_cpu.h" + +#include "core/framework/error_code_helper.h" +#include "core/graph/constants.h" +#include "core/providers/cpu/cpu_execution_provider.h" +#include "core/session/abi_devices.h" +#include "core/session/abi_logger.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/plugin_ep/ep_api.h" +#include "core/session/plugin_ep/ep_factory_internal.h" +#include "core/session/ort_apis.h" + +namespace onnxruntime { + +OrtStatus* CpuEpFactory::GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept { + size_t& num_ep_devices = *p_num_ep_devices; + num_ep_devices = 0; + + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + const OrtHardwareDevice& device = *devices[i]; + if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) { + ORT_API_RETURN_IF_ERROR( + OrtExecutionProviderApi::CreateEpDevice(&ep_factory, &device, nullptr, nullptr, + &ep_devices[num_ep_devices++])); + } + } + + return nullptr; +} + +OrtStatus* CpuEpFactory::CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, 
+ size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept { + if (num_devices != 1) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "CPU EP factory currently only supports one device at a time."); + } + + CPUExecutionProviderInfo epi{session_options->value.enable_cpu_mem_arena}; + *ep = std::make_unique(epi); + (*ep)->SetLogger(session_logger->ToInternal()); + + return nullptr; +} +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_cpu.h b/onnxruntime/core/session/plugin_ep/ep_factory_cpu.h new file mode 100644 index 0000000000000..fba9bac976bb2 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_cpu.h @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include "core/session/plugin_ep/ep_factory_internal_impl.h" + +#include "core/graph/constants.h" + +namespace onnxruntime { + +class CpuEpFactory : public EpFactoryInternalImpl { + public: + CpuEpFactory() : EpFactoryInternalImpl(kCpuExecutionProvider, "Microsoft", OrtDevice::VendorIds::MICROSOFT) { + } + + private: + OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept override; + + OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* devices, + const OrtKeyValuePairs* const* ep_metadata_pairs, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept override; +}; +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_dml.cc b/onnxruntime/core/session/plugin_ep/ep_factory_dml.cc new file mode 100644 index 0000000000000..2f12ffa394537 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_dml.cc @@ -0,0 +1,113 @@ 
+// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if defined(USE_DML) + +#include "core/session/plugin_ep/ep_factory_dml.h" + +#include "core/framework/error_code_helper.h" +#include "core/providers/dml/dml_provider_factory_creator.h" +#include "core/session/abi_devices.h" +#include "core/session/abi_logger.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/plugin_ep/ep_api.h" +#include "core/session/plugin_ep/ep_factory_internal.h" +#include "core/session/ort_apis.h" + +namespace onnxruntime { + +OrtStatus* DmlEpFactory::GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept { + size_t& num_ep_devices = *p_num_ep_devices; + num_ep_devices = 0; + + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + const OrtHardwareDevice& device = *devices[i]; + if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + auto ep_options = std::make_unique(); + + // TODO: Should we ignore a user provided 'device_id' when they select an OrtEpDevice as that is + // associated with a specific device. + // How would we know what options should not allow user overrides if set in OrtEpDevice? 
+ int32_t device_id = 0; // If no device_id was found default to 0 + if (auto it = device.metadata.Entries().find("DxgiAdapterNumber"); it != device.metadata.Entries().end()) { + device_id = std::stoi(it->second); + } + + ep_options->Add("device_id", std::to_string(device_id)); + + auto* api_status = OrtExecutionProviderApi::CreateEpDevice(&ep_factory, + &device, nullptr, ep_options.get(), + &ep_devices[num_ep_devices]); + + if (device_memory_infos.size() < device_id + 1) { + device_memory_infos.resize(device_id + 1); + device_allocators.resize(device_id + 1); + } + + if (device_memory_infos[device_id] == nullptr) { + // Create memory info for the device if it doesn't already exist + device_memory_infos[device_id] = std::make_unique( + "DML", OrtAllocatorType::OrtDeviceAllocator, + OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::MICROSOFT, + narrow(device_id))); + } + + // This is what we need to add once CreateAllocator is implemented to create a shared allocator for the device. 
+ // OrtExecutionProviderApi::EpDevice_AddAllocatorInfo(ep_devices[num_ep_devices], + // device_memory_infos[device_id].get()); + + if (api_status != nullptr) { + return api_status; + } + + ++num_ep_devices; + } + } + + return nullptr; +} + +OrtStatus* DmlEpFactory::CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept { + *ep = nullptr; + + if (num_devices != 1) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "DML EP factory currently only supports one device at a time."); + } + + auto ep_options = GetOptionsFromSessionOptions(session_options->value); + auto dml_ep_factory = DMLProviderFactoryCreator::CreateFromProviderOptions(session_options->value.config_options, + ep_options); + + *ep = dml_ep_factory->CreateProvider(); + (*ep)->SetLogger(session_logger->ToInternal()); + + return nullptr; +} + +/* +// TODO: This needs to create an allocator for the specific device so it's available as a shared allocator. That +// requires pulling lots of things out of the DML EP to get the D3D12 device and create a +// BucketizedBufferAllocator. See providers\dml\DmlExecutionProvider\src\ExecutionProvider.cpp +OrtStatus* DmlEpFactory::CreateAllocator(const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* allocator_options, + OrtAllocator** allocator) noexcept { +} + +// TODO: Wrap the IDataTransfer implementation so we can copy to device using OrtApi CopyTensors. 
+OrtStatus* DmlEpFactory::CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) noexcept { +} +*/ +} // namespace onnxruntime + +#endif // USE_DML diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_dml.h b/onnxruntime/core/session/plugin_ep/ep_factory_dml.h new file mode 100644 index 0000000000000..1cdd172901942 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_dml.h @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if defined(USE_DML) + +#include "core/session/plugin_ep/ep_factory_internal_impl.h" + +#include "core/graph/constants.h" + +namespace onnxruntime { + +class DmlEpFactory : public EpFactoryInternalImpl { + public: + DmlEpFactory() : EpFactoryInternalImpl(kDmlExecutionProvider, "Microsoft", OrtDevice::VendorIds::MICROSOFT) { + } + + private: + OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept override; + + OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* devices, + const OrtKeyValuePairs* const* ep_metadata_pairs, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept override; + + std::vector> device_memory_infos; // memory info for each device + std::vector> device_allocators; // allocators for each device +}; + +} // namespace onnxruntime + +#endif // USE_DML diff --git a/onnxruntime/core/session/ep_factory_internal.cc b/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc similarity index 58% rename from onnxruntime/core/session/ep_factory_internal.cc rename to onnxruntime/core/session/plugin_ep/ep_factory_internal.cc index 9804aa6a5c42d..3610b0f797a46 100644 --- a/onnxruntime/core/session/ep_factory_internal.cc +++ 
b/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc @@ -1,18 +1,16 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/session/ep_factory_internal.h" +#include "core/session/plugin_ep/ep_factory_internal.h" #include "core/framework/error_code_helper.h" #include "core/session/abi_devices.h" #include "core/session/abi_session_options_impl.h" -#include "core/session/ep_api_utils.h" +#include "core/session/plugin_ep/forward_to_factory_impl.h" #include "core/session/ort_apis.h" -#include "onnxruntime_config.h" // for ORT_VERSION namespace onnxruntime { - -using Forward = ForwardToFactory; +using Forward = ForwardToFactoryImpl; EpFactoryInternal::EpFactoryInternal(std::unique_ptr impl) : impl_{std::move(impl)} { @@ -32,38 +30,6 @@ EpFactoryInternal::EpFactoryInternal(std::unique_ptr impl OrtEpFactory::CreateSyncStreamForDevice = Forward::CreateSyncStreamForDevice; } -const char* EpFactoryInternal::GetVersion() const noexcept { - return ORT_VERSION; -} - -OrtStatus* EpFactoryInternal::CreateEp(const OrtHardwareDevice* const* /*devices*/, - const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, - size_t /*num_devices*/, - const OrtSessionOptions* /*api_session_options*/, - const OrtLogger* /*api_logger*/, - OrtEp** /*ep*/) { - ORT_THROW("Internal error. CreateIExecutionProvider should be used for EpFactoryInternal."); -} - -// Prior to addition to SessionOptions the EP options do not have a prefix. -// They are prefixed with 'ep..' when added to SessionOptions. -// -// Use this function to get the options without the prefix from SessionOptions. -// Required by the option parsing for multiple existing EPs. 
-ProviderOptions EpFactoryInternalImpl::GetOptionsFromSessionOptions(const SessionOptions& session_options) const { - const std::string option_prefix = OrtSessionOptions::GetProviderOptionPrefix(GetName()); - ProviderOptions ep_options; - - for (const auto& [key, value] : session_options.config_options.configurations) { - if (key.find(option_prefix) == 0) { - // remove the prefix and add - ep_options[key.substr(option_prefix.length())] = value; - } - } - - return ep_options; -} - InternalExecutionProviderFactory::InternalExecutionProviderFactory(EpFactoryInternal& ep_factory, gsl::span ep_devices) : ep_factory_{ep_factory} { diff --git a/onnxruntime/core/session/ep_factory_internal.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h similarity index 50% rename from onnxruntime/core/session/ep_factory_internal.h rename to onnxruntime/core/session/plugin_ep/ep_factory_internal.h index ae450efa394e8..0e34fef0ff74c 100644 --- a/onnxruntime/core/session/ep_factory_internal.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h @@ -7,85 +7,16 @@ #include #include "core/common/common.h" -#include "core/framework/execution_provider.h" #include "core/providers/providers.h" #include "core/session/onnxruntime_c_api.h" #include "core/session/ort_apis.h" +#include "core/session/plugin_ep/ep_factory_internal_impl.h" + +#include "onnxruntime_config.h" // for ORT_VERSION namespace onnxruntime { -class EpFactoryInternal; -class EpLibraryInternal; struct SessionOptions; - -// class with virtual methods that are implemented for each internal EP -class EpFactoryInternalImpl { - public: - EpFactoryInternalImpl(const std::string& ep_name, const std::string& vendor, uint32_t vendor_id) - : ep_name_(ep_name), vendor_(vendor), vendor_id_(vendor_id) { - } - - const char* GetName() const noexcept { return ep_name_.c_str(); } - const char* GetVendor() const noexcept { return vendor_.c_str(); } - uint32_t GetVendorId() const noexcept { return vendor_id_; } - const char* 
GetVersion() const noexcept; - - virtual OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, - _In_reads_(num_devices) const OrtHardwareDevice* const* devices, - _In_ size_t num_devices, - _Inout_ OrtEpDevice** ep_devices, - _In_ size_t max_ep_devices, - _Out_ size_t* num_ep_devices) noexcept = 0; - - virtual OrtStatus* CreateIExecutionProvider(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, - _In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs, - _In_ size_t num_devices, - _In_ const OrtSessionOptions* session_options, - _In_ const OrtLogger* logger, - _Out_ std::unique_ptr* ep) = 0; - - virtual OrtStatus* CreateAllocator(_In_ const OrtMemoryInfo* /*memory_info*/, - _In_opt_ const OrtKeyValuePairs* /*allocator_options*/, - _Outptr_ OrtAllocator** allocator) noexcept { - // default implementation does not add OrtMemoryInfo to OrtEpDevice instances returned - // so this should never be called - *allocator = nullptr; - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "CreateAllocator is not implemented for this EP factory."); - } - - virtual void ReleaseAllocator(_In_ OrtAllocator* /*allocator*/) noexcept { - // we don't create any allocators so we don't need to release any - } - - virtual OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) { - *data_transfer = nullptr; - return nullptr; // Default implementation does nothing - } - - virtual bool IsStreamAware() const { - return false; - } - - virtual OrtStatus* CreateSyncStreamForDevice(_In_ const OrtMemoryDevice* /*memory_device*/, - _In_opt_ const OrtKeyValuePairs* /*stream_options*/, - _Outptr_result_maybenull_ OrtSyncStreamImpl** stream) { - *stream = nullptr; - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, - "CreateSyncStreamForDevice is not implemented for this EP factory."); - } - - // Function ORT calls to release an EP instance. 
- void ReleaseEp(OrtEp* ep); - - virtual ~EpFactoryInternalImpl() = default; - - protected: - ProviderOptions GetOptionsFromSessionOptions(const SessionOptions& session_options) const; - - private: - const std::string ep_name_; // EP name library was registered with - const std::string vendor_; // EP vendor name - const uint32_t vendor_id_; // EP vendor ID -}; +class EpFactoryInternalImpl; // this class can't have any virtual methods as they break using it as an OrtEpFactory* in OrtEpDevice. class EpFactoryInternal : public OrtEpFactory { @@ -95,7 +26,7 @@ class EpFactoryInternal : public OrtEpFactory { const char* GetName() const noexcept { return impl_->GetName(); } const char* GetVendor() const noexcept { return impl_->GetVendor(); } uint32_t GetVendorId() const noexcept { return impl_->GetVendorId(); } - const char* GetVersion() const noexcept; + const char* GetVersion() const noexcept { return ORT_VERSION; } OrtStatus* GetSupportedDevices(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, _In_ size_t num_devices, @@ -106,11 +37,14 @@ class EpFactoryInternal : public OrtEpFactory { } // we don't implement this. CreateIExecutionProvider should be used. - OrtStatus* CreateEp(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, - _In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs, - _In_ size_t num_devices, - _In_ const OrtSessionOptions* session_options, - _In_ const OrtLogger* logger, _Out_ OrtEp** ep); + OrtStatus* CreateEp(_In_reads_(num_devices) const OrtHardwareDevice* const* /*devices*/, + _In_reads_(num_devices) const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, + _In_ size_t /*num_devices*/, + _In_ const OrtSessionOptions* /*session_options*/, + _In_ const OrtLogger* /*logger*/, + _Out_ OrtEp** /*ep*/) { + ORT_THROW("Internal error. CreateIExecutionProvider should be used for EpFactoryInternal."); + } // same input args as CreateEp in case we need something from device or ep_metadata_pairs in the future. 
OrtStatus* CreateIExecutionProvider(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, @@ -132,24 +66,23 @@ class EpFactoryInternal : public OrtEpFactory { return impl_->ReleaseAllocator(allocator); } - OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) { + OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) noexcept { return impl_->CreateDataTransfer(data_transfer); } - bool IsStreamAware() const { + bool IsStreamAware() const noexcept { return impl_->IsStreamAware(); } OrtStatus* CreateSyncStreamForDevice(_In_ const OrtMemoryDevice* memory_device, _In_opt_ const OrtKeyValuePairs* stream_options, - _Outptr_result_maybenull_ OrtSyncStreamImpl** stream) { + _Outptr_result_maybenull_ OrtSyncStreamImpl** stream) noexcept { return impl_->CreateSyncStreamForDevice(memory_device, stream_options, stream); } // Function ORT calls to release an EP instance. - void ReleaseEp(OrtEp* /*ep*/) { + void ReleaseEp(OrtEp* /*ep*/) noexcept { // we never create an OrtEp so we should never be trying to release one - ORT_THROW("Internal error. No ReleaseEp call is required for EpFactoryInternal."); } private: diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.cc b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.cc new file mode 100644 index 0000000000000..e61804d842859 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.cc @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/session/plugin_ep/ep_factory_internal_impl.h" + +#include "core/framework/error_code_helper.h" +#include "core/session/abi_devices.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/plugin_ep/ep_factory_internal.h" + +namespace onnxruntime { + +// Prior to addition to SessionOptions the EP options do not have a prefix. +// They are prefixed with 'ep..' 
when added to SessionOptions. +// +// Use this function to get the options without the prefix from SessionOptions. +// Required by the option parsing for multiple existing EPs. +ProviderOptions EpFactoryInternalImpl::GetOptionsFromSessionOptions(const SessionOptions& session_options) const { + const std::string option_prefix = OrtSessionOptions::GetProviderOptionPrefix(GetName()); + ProviderOptions ep_options; + + for (const auto& [key, value] : session_options.config_options.configurations) { + if (key.find(option_prefix) == 0) { + // remove the prefix and add + ep_options[key.substr(option_prefix.length())] = value; + } + } + + return ep_options; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h new file mode 100644 index 0000000000000..bd0b76b21511f --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include + +#include "core/framework/execution_provider.h" +#include "core/framework/provider_options.h" +#include "core/session/onnxruntime_c_api.h" +#include "core/session/ort_apis.h" + +namespace onnxruntime { +class EpFactoryInternal; +struct SessionOptions; + +// class with virtual methods that are implemented for each internal EP +class EpFactoryInternalImpl { + public: + EpFactoryInternalImpl(const std::string& ep_name, const std::string& vendor, uint32_t vendor_id) + : ep_name_(ep_name), vendor_(vendor), vendor_id_(vendor_id) { + } + + const char* GetName() const noexcept { return ep_name_.c_str(); } + const char* GetVendor() const noexcept { return vendor_.c_str(); } + uint32_t GetVendorId() const noexcept { return vendor_id_; } + const char* GetVersion() const noexcept; + + virtual OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, + _In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_ size_t num_devices, + _Inout_ OrtEpDevice** ep_devices, + _In_ size_t max_ep_devices, + _Out_ size_t* num_ep_devices) noexcept = 0; + + virtual OrtStatus* CreateIExecutionProvider(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs, + _In_ size_t num_devices, + _In_ const OrtSessionOptions* session_options, + _In_ const OrtLogger* logger, + _Out_ std::unique_ptr* ep) = 0; + + virtual OrtStatus* CreateAllocator(_In_ const OrtMemoryInfo* /*memory_info*/, + _In_opt_ const OrtKeyValuePairs* /*allocator_options*/, + _Outptr_ OrtAllocator** allocator) noexcept { + // default implementation does not add OrtMemoryInfo to OrtEpDevice instances returned + // so this should never be called + *allocator = nullptr; + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "CreateAllocator is not implemented for this EP factory."); + } + + virtual void ReleaseAllocator(_In_ OrtAllocator* /*allocator*/) noexcept { + // we don't create any allocators so we 
don't need to release any + } + + virtual OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) noexcept { + *data_transfer = nullptr; + return nullptr; // Default implementation does nothing + } + + virtual bool IsStreamAware() const noexcept { + return false; + } + + virtual OrtStatus* CreateSyncStreamForDevice(_In_ const OrtMemoryDevice* /*memory_device*/, + _In_opt_ const OrtKeyValuePairs* /*stream_options*/, + _Outptr_result_maybenull_ OrtSyncStreamImpl** stream) noexcept { + *stream = nullptr; + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, + "CreateSyncStreamForDevice is not implemented for this EP factory."); + } + + // Function ORT calls to release an EP instance. + void ReleaseEp(OrtEp* ep); + + virtual ~EpFactoryInternalImpl() = default; + + protected: + ProviderOptions GetOptionsFromSessionOptions(const SessionOptions& session_options) const; + + private: + const std::string ep_name_; // EP name library was registered with + const std::string vendor_; // EP vendor name + const uint32_t vendor_id_; // EP vendor ID +}; +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc new file mode 100644 index 0000000000000..d6e51a44c1c69 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/session/plugin_ep/ep_factory_provider_bridge.h" + +#include "core/providers/shared_library/provider_host_api.h" + +namespace onnxruntime { +OrtStatus* ProviderBridgeEpFactory::GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* num_ep_devices) noexcept { + ORT_API_RETURN_IF_ERROR(ep_factory_.GetSupportedDevices(&ep_factory_, devices, num_devices, ep_devices, + max_ep_devices, num_ep_devices)); + + // add the EpFactoryInternal layer back in so that we can redirect to CreateIExecutionProvider. + for (size_t i = 0; i < *num_ep_devices; ++i) { + auto* ep_device = ep_devices[i]; + if (ep_device) { + ep_device->ep_factory = &ep_factory; + } + } + + return nullptr; +} + +OrtStatus* ProviderBridgeEpFactory::CreateIExecutionProvider(const OrtHardwareDevice* const* devices, + const OrtKeyValuePairs* const* ep_metadata_pairs, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept { + // get the provider specific options + auto ep_options = GetOptionsFromSessionOptions(session_options->value); + auto& provider = provider_library_.Get(); + + auto status = provider.CreateIExecutionProvider(devices, ep_metadata_pairs, num_devices, + ep_options, *session_options, *session_logger, *ep); + + return ToOrtStatus(status); +} +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h new file mode 100644 index 0000000000000..437af62dc2c0c --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/framework/error_code_helper.h" +#include "core/session/abi_devices.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/plugin_ep/ep_factory_internal.h" +#include "core/session/provider_bridge_library.h" + +namespace onnxruntime { +class ProviderBridgeEpFactory : public EpFactoryInternalImpl { + public: + ProviderBridgeEpFactory(OrtEpFactory& ep_factory, ProviderLibrary& provider_library) + : EpFactoryInternalImpl(ep_factory.GetName(&ep_factory), + ep_factory.GetVendor(&ep_factory), + ep_factory.GetVendorId(&ep_factory)), + ep_factory_{ep_factory}, + provider_library_{provider_library} { + } + + private: + OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* num_ep_devices) noexcept override; + + OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* devices, + const OrtKeyValuePairs* const* ep_metadata_pairs, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept override; + + OrtStatus* CreateAllocator(const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* allocator_options, + OrtAllocator** allocator) noexcept override { + return ep_factory_.CreateAllocator(&ep_factory_, memory_info, allocator_options, allocator); + } + + void ReleaseAllocator(OrtAllocator* allocator) noexcept override { + ep_factory_.ReleaseAllocator(&ep_factory_, allocator); + } + + OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) noexcept override { + return ep_factory_.CreateDataTransfer(&ep_factory_, data_transfer); + } + + bool IsStreamAware() const noexcept override { + return ep_factory_.IsStreamAware(&ep_factory_); + } + + OrtStatus* CreateSyncStreamForDevice(const OrtMemoryDevice* device, + const OrtKeyValuePairs* stream_options, + OrtSyncStreamImpl** 
stream) noexcept override { + return ep_factory_.CreateSyncStreamForDevice(&ep_factory_, device, stream_options, stream); + } + + OrtEpFactory& ep_factory_; // OrtEpFactory from the provider bridge EP + ProviderLibrary& provider_library_; // ProviderLibrary from the provider bridge EP +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_webgpu.cc b/onnxruntime/core/session/plugin_ep/ep_factory_webgpu.cc new file mode 100644 index 0000000000000..0f955e0bab248 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_webgpu.cc @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if defined(USE_WEBGPU) +#include "core/session/plugin_ep/ep_factory_webgpu.h" + +#include "core/framework/error_code_helper.h" +#include "core/providers/webgpu/webgpu_provider_factory_creator.h" +#include "core/session/abi_devices.h" +#include "core/session/abi_logger.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/plugin_ep/ep_api.h" +#include "core/session/plugin_ep/ep_factory_internal.h" +#include "core/session/ort_apis.h" + +namespace onnxruntime { + +OrtStatus* WebGpuEpFactory::GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept { + size_t& num_ep_devices = *p_num_ep_devices; + num_ep_devices = 0; + + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + const OrtHardwareDevice& device = *devices[i]; + if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + // TODO: any metadata or options to add? 
+ ORT_API_RETURN_IF_ERROR(OrtExecutionProviderApi::CreateEpDevice(&ep_factory, + &device, nullptr, nullptr, + &ep_devices[num_ep_devices++])); + } + } + + return nullptr; +} + +OrtStatus* WebGpuEpFactory::CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept { + *ep = nullptr; + + if (num_devices != 1) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "WebGPU EP factory currently only supports one device at a time."); + } + + auto webgpu_ep_factory = WebGpuProviderFactoryCreator::Create(session_options->value.config_options); + *ep = webgpu_ep_factory->CreateProvider(); + (*ep)->SetLogger(session_logger->ToInternal()); + + return nullptr; +} + +/* TODO: Implement CreateAllocator and CreateDataTransfer to support shared allocators and data transfer outside of + an InferenceSession. +OrtStatus* WebGpuEpFactory::CreateAllocator(const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* allocator_options, + OrtAllocator** allocator) noexcept override { + *allocator = device_allocators[memory_info->device.Id()].get(); +} + +OrtStatus* WebGpuEpFactory::CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) override { + // TODO: Wrap the IDataTransfer implementation so we can copy to device using OrtApi CopyTensors. + *data_transfer = nullptr; + return nullptr; +} +*/ +} // namespace onnxruntime + +#endif // USE_WEBGPU diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_webgpu.h b/onnxruntime/core/session/plugin_ep/ep_factory_webgpu.h new file mode 100644 index 0000000000000..06ecfa744bbda --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_webgpu.h @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#if defined(USE_WEBGPU) +#include "core/session/plugin_ep/ep_factory_internal_impl.h" + +#include "core/graph/constants.h" + +namespace onnxruntime { + +class WebGpuEpFactory : public EpFactoryInternalImpl { + public: + WebGpuEpFactory() : EpFactoryInternalImpl(kWebGpuExecutionProvider, "Microsoft", OrtDevice::VendorIds::MICROSOFT) { + } + + private: + OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept override; + + OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* devices, + const OrtKeyValuePairs* const* ep_metadata_pairs, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept override; +}; +} // namespace onnxruntime + +#endif // USE_WEBGPU diff --git a/onnxruntime/core/session/ep_library.h b/onnxruntime/core/session/plugin_ep/ep_library.h similarity index 100% rename from onnxruntime/core/session/ep_library.h rename to onnxruntime/core/session/plugin_ep/ep_library.h diff --git a/onnxruntime/core/session/plugin_ep/ep_library_internal.cc b/onnxruntime/core/session/plugin_ep/ep_library_internal.cc new file mode 100644 index 0000000000000..d4015e0bbd366 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_library_internal.cc @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/session/plugin_ep/ep_library_internal.h" +#include "core/session/plugin_ep/ep_factory_cpu.h" +#include "core/session/plugin_ep/ep_factory_dml.h" +#include "core/session/plugin_ep/ep_factory_webgpu.h" + +namespace onnxruntime { + +std::unique_ptr EpLibraryInternal::CreateCpuEp() { + auto cpu_factory_impl = std::make_unique(); + auto internal_factory = std::make_unique(std::move(cpu_factory_impl)); + return std::make_unique(std::move(internal_factory)); +} + +#if defined(USE_DML) + +std::unique_ptr EpLibraryInternal::CreateDmlEp() { + auto dml_factory_impl = std::make_unique(); + auto internal_factory = std::make_unique(std::move(dml_factory_impl)); + return std::make_unique(std::move(internal_factory)); +} +#endif + +#if defined(USE_WEBGPU) +std::unique_ptr EpLibraryInternal::CreateWebGpuEp() { + auto webgpu_factory_impl = std::make_unique(); + auto internal_factory = std::make_unique(std::move(webgpu_factory_impl)); + return std::make_unique(std::move(internal_factory)); +} +#endif + +std::vector> EpLibraryInternal::CreateInternalEps() { + std::vector> internal_eps; + internal_eps.reserve(4); + + // CPU EP + internal_eps.push_back(CreateCpuEp()); + +#if defined(USE_WEBGPU) + internal_eps.push_back(CreateWebGpuEp()); +#endif + +#if defined(USE_DML) + internal_eps.push_back(CreateDmlEp()); +#endif + + return internal_eps; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/session/ep_library_internal.h b/onnxruntime/core/session/plugin_ep/ep_library_internal.h similarity index 94% rename from onnxruntime/core/session/ep_library_internal.h rename to onnxruntime/core/session/plugin_ep/ep_library_internal.h index ab529edc2507f..1587f01360e26 100644 --- a/onnxruntime/core/session/ep_library_internal.h +++ b/onnxruntime/core/session/plugin_ep/ep_library_internal.h @@ -4,8 +4,8 @@ #pragma once #include "core/common/common.h" -#include "core/session/ep_library.h" -#include "core/session/ep_factory_internal.h" +#include 
"core/session/plugin_ep/ep_library.h" +#include "core/session/plugin_ep/ep_factory_internal.h" #include "core/session/onnxruntime_c_api.h" #include "core/session/provider_bridge_library.h" diff --git a/onnxruntime/core/session/ep_library_plugin.cc b/onnxruntime/core/session/plugin_ep/ep_library_plugin.cc similarity index 98% rename from onnxruntime/core/session/ep_library_plugin.cc rename to onnxruntime/core/session/plugin_ep/ep_library_plugin.cc index 32ddd8a765b4c..ebfa364f4f1df 100644 --- a/onnxruntime/core/session/ep_library_plugin.cc +++ b/onnxruntime/core/session/plugin_ep/ep_library_plugin.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/session/ep_library_plugin.h" +#include "core/session/plugin_ep/ep_library_plugin.h" #include "core/common/logging/logging.h" #include "core/framework/error_code_helper.h" diff --git a/onnxruntime/core/session/ep_library_plugin.h b/onnxruntime/core/session/plugin_ep/ep_library_plugin.h similarity index 96% rename from onnxruntime/core/session/ep_library_plugin.h rename to onnxruntime/core/session/plugin_ep/ep_library_plugin.h index e2b02ccc654da..e044e91b61e37 100644 --- a/onnxruntime/core/session/ep_library_plugin.h +++ b/onnxruntime/core/session/plugin_ep/ep_library_plugin.h @@ -6,7 +6,7 @@ #include #include -#include "core/session/ep_library.h" +#include "core/session/plugin_ep/ep_library.h" namespace onnxruntime { /// diff --git a/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.cc b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.cc new file mode 100644 index 0000000000000..06cf54aea4071 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.cc @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/session/plugin_ep/ep_library_provider_bridge.h" + +#include "core/session/plugin_ep/ep_factory_provider_bridge.h" + +namespace onnxruntime { +Status EpLibraryProviderBridge::Load() { + std::lock_guard lock{mutex_}; + + if (!factories_.empty()) { + // already loaded + return Status::OK(); + } + + // if we have been unloaded we can't just be reloaded. + if (!ep_library_plugin_ || !provider_library_) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "EpLibraryProviderBridge has been unloaded. " + "Please create a new instance using LoadPluginOrProviderBridge."); + } + + // wrap the EpLibraryPlugin factories that were created via calling CreateEpFactories in the library. + // use GetSupportedDevices from the library's factory. + // to do this we need to capture `factory` and plug it in to is_supported_fn and create_fn. + // we also need to update any returned OrtEpDevice instances to swap the wrapper EpFactoryInternal in so that we can + // call Provider::CreateIExecutionProvider in EpFactoryInternal::CreateIExecutionProvider. + for (const auto& factory : ep_library_plugin_->GetFactories()) { + auto factory_impl = std::make_unique(*factory, *provider_library_); + auto internal_factory = std::make_unique(std::move(factory_impl)); + + factory_ptrs_.push_back(internal_factory.get()); + internal_factory_ptrs_.push_back(internal_factory.get()); + factories_.push_back(std::move(internal_factory)); + } + + return Status::OK(); +} + +Status EpLibraryProviderBridge::Unload() { + std::lock_guard lock{mutex_}; + + internal_factory_ptrs_.clear(); + factory_ptrs_.clear(); + factories_.clear(); + + // we loaded ep_library_plugin_ after provider_library_ in LoadPluginOrProviderBridge so do the reverse order here. 
+ ORT_RETURN_IF_ERROR(ep_library_plugin_->Unload()); + ep_library_plugin_ = nullptr; + + provider_library_->Unload(); + provider_library_ = nullptr; + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/session/ep_library_provider_bridge.h b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h similarity index 95% rename from onnxruntime/core/session/ep_library_provider_bridge.h rename to onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h index 0717ccd957de7..c7e8ebefc3785 100644 --- a/onnxruntime/core/session/ep_library_provider_bridge.h +++ b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h @@ -5,8 +5,8 @@ #include #include -#include "core/session/ep_library.h" -#include "core/session/ep_factory_internal.h" +#include "core/session/plugin_ep/ep_library.h" +#include "core/session/plugin_ep/ep_factory_internal.h" #include "core/session/provider_bridge_library.h" namespace onnxruntime { diff --git a/onnxruntime/core/session/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc similarity index 99% rename from onnxruntime/core/session/ep_plugin_provider_interfaces.cc rename to onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index c7d7ea2e8a4ec..2aac1e1c21cc7 100644 --- a/onnxruntime/core/session/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "core/session/ep_plugin_provider_interfaces.h" +#include "core/session/plugin_ep/ep_plugin_provider_interfaces.h" #include #include diff --git a/onnxruntime/core/session/ep_plugin_provider_interfaces.h b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h similarity index 100% rename from onnxruntime/core/session/ep_plugin_provider_interfaces.h rename to onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h diff --git a/onnxruntime/core/session/ep_api_utils.h b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h similarity index 99% rename from onnxruntime/core/session/ep_api_utils.h rename to onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h index 77528565eced7..67b22779395ec 100644 --- a/onnxruntime/core/session/ep_api_utils.h +++ b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h @@ -7,7 +7,7 @@ namespace onnxruntime { // helper to forward a call from the C API to an instance of the factory implementation. // used by EpFactoryInternal and EpFactoryProviderBridge. 
template -struct ForwardToFactory { +struct ForwardToFactoryImpl { static const char* ORT_API_CALL GetFactoryName(const OrtEpFactory* this_ptr) noexcept { return static_cast(this_ptr)->GetName(); } diff --git a/onnxruntime/core/session/provider_policy_context.cc b/onnxruntime/core/session/provider_policy_context.cc index 211bf8b2d15a4..6bcbda0f13b92 100644 --- a/onnxruntime/core/session/provider_policy_context.cc +++ b/onnxruntime/core/session/provider_policy_context.cc @@ -11,8 +11,8 @@ #include "core/framework/error_code_helper.h" #include "core/session/abi_devices.h" #include "core/session/abi_logger.h" -#include "core/session/ep_factory_internal.h" -#include "core/session/ep_plugin_provider_interfaces.h" +#include "core/session/plugin_ep/ep_factory_internal.h" +#include "core/session/plugin_ep/ep_plugin_provider_interfaces.h" #include "core/session/inference_session.h" #include "core/session/inference_session_utils.h" #include "core/session/onnxruntime_c_api.h" diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 69039beb49363..f90ace95d6e58 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -19,10 +19,10 @@ #include "core/session/ort_env.h" #if !defined(ORT_MINIMAL_BUILD) -#include "core/session/ep_factory_internal.h" -#include "core/session/ep_plugin_provider_interfaces.h" -#include "core/session/ep_library_plugin.h" -#include "core/session/ep_library_provider_bridge.h" +#include "core/session/plugin_ep/ep_factory_internal.h" +#include "core/session/plugin_ep/ep_plugin_provider_interfaces.h" +#include "core/session/plugin_ep/ep_library_plugin.h" +#include "core/session/plugin_ep/ep_library_provider_bridge.h" #include "core/session/model_compilation_options.h" #include "core/session/provider_policy_context.h" #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/python/onnxruntime_pybind_exceptions.cc b/onnxruntime/python/onnxruntime_pybind_exceptions.cc index 
8f3b97c8c7786..6b3062205b52e 100644 --- a/onnxruntime/python/onnxruntime_pybind_exceptions.cc +++ b/onnxruntime/python/onnxruntime_pybind_exceptions.cc @@ -37,6 +37,7 @@ void RegisterExceptions(pybind11::module& m) { pybind11::register_exception(m, "EPFail"); pybind11::register_exception(m, "ModelLoadCanceled"); pybind11::register_exception(m, "ModelRequiresCompilation"); + pybind11::register_exception(m, "NotFound"); } void OrtPybindThrowIfError(onnxruntime::common::Status status) { @@ -67,6 +68,8 @@ void OrtPybindThrowIfError(onnxruntime::common::Status status) { throw ModelLoadCanceled(std::move(msg)); case onnxruntime::common::StatusCode::MODEL_REQUIRES_COMPILATION: throw ModelRequiresCompilation(std::move(msg)); + case onnxruntime::common::StatusCode::NOT_FOUND: + throw NotFound(std::move(msg)); default: throw std::runtime_error(std::move(msg)); } diff --git a/onnxruntime/python/onnxruntime_pybind_exceptions.h b/onnxruntime/python/onnxruntime_pybind_exceptions.h index 86bc4a5da8d46..7680c06c59d79 100644 --- a/onnxruntime/python/onnxruntime_pybind_exceptions.h +++ b/onnxruntime/python/onnxruntime_pybind_exceptions.h @@ -50,6 +50,9 @@ struct ModelLoadCanceled : std::runtime_error { struct ModelRequiresCompilation : std::runtime_error { explicit ModelRequiresCompilation(const std::string& what) : std::runtime_error(what) {} }; +struct NotFound : std::runtime_error { + explicit NotFound(const std::string& what) : std::runtime_error(what) {} +}; void RegisterExceptions(pybind11::module& m); diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index ec4d8c6330c8d..acf0681cf8752 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -46,7 +46,7 @@ #if !defined(ORT_MINIMAL_BUILD) #include "core/session/abi_devices.h" -#include "core/session/ep_factory_internal.h" +#include "core/session/plugin_ep/ep_factory_internal.h" #include 
"core/session/provider_policy_context.h" #include "core/session/utils.h" #endif diff --git a/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc b/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc index c9a7116bf8052..2918e4baf86a4 100644 --- a/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc +++ b/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc @@ -82,21 +82,48 @@ static void CalculateDynamicQuantizeMatMul(const int64_t M, const int64_t N, con } } +struct TestDynamicQuantizeMatMulOptions { + bool is_matrix_b_constant = true; + + bool per_column = false; + + bool is_scale_constant = false; + + bool has_zp = true; + bool is_zp_constant = false; + bool is_zp_zero = false; + + bool has_bias = false; + bool is_bias_constant = false; + + bool empty_input = false; +}; + template -void TestDynamicQuantizeMatMul(bool is_matrix_b_constant, - bool per_column = false, - bool has_zp = true, - bool has_bias = false, - bool empty_input = false) { +void TestDynamicQuantizeMatMul(const TestDynamicQuantizeMatMulOptions& opts) { + static_assert(std::is_same_v || std::is_same_v); + + SCOPED_TRACE(MakeString( + "b data type:", (std::is_same_v ? "uint8" : "int8"), + ", is_matrix_b_constant:", opts.is_matrix_b_constant, + ", per_column:", opts.per_column, + ", is_scale_constant:", opts.is_scale_constant, + ", has_zp:", opts.has_zp, + ", is_zp_constant:", opts.is_zp_constant, + ", is_zp_zero:", opts.is_zp_zero, + ", has_bias:", opts.has_bias, + ", is_bias_constant:", opts.is_bias_constant, + ", empty_input:", opts.empty_input)); + // create rand inputs RandomValueGenerator random{1668426375}; - int64_t M = empty_input ? 1 : 4; + int64_t M = opts.empty_input ? 1 : 4; int64_t N = 128; int64_t K = 128; - std::vector A_dims{empty_input ? 0 : M, K}; + std::vector A_dims{opts.empty_input ? 0 : M, K}; std::vector B_dims{K, N}; - std::vector Y_dims{empty_input ? 0 : M, K}; + std::vector Y_dims{opts.empty_input ? 
0 : M, K}; std::vector A_data = random.Uniform(A_dims, -1.0f, 1.0f); std::vector B_data; std::vector tmp_B_data = random.Uniform(B_dims, @@ -106,101 +133,120 @@ void TestDynamicQuantizeMatMul(bool is_matrix_b_constant, return static_cast(v); }); - int64_t b_scale_zp_size = per_column ? B_dims.back() : 1; + int64_t b_scale_zp_size = opts.per_column ? B_dims.back() : 1; std::vector B_scale = random.Uniform(AsSpan({b_scale_zp_size}), -0.1f, 0.1f); std::vector B_zero_point(b_scale_zp_size); - std::for_each(B_zero_point.begin(), - B_zero_point.end(), - [&random](T& zp) { - zp = static_cast(random.Uniform(std::array{1}, - std::numeric_limits::min(), - std::numeric_limits::max())[0]); - }); + if (!opts.is_zp_zero) { + std::for_each(B_zero_point.begin(), + B_zero_point.end(), + [&random](T& zp) { + zp = static_cast(random.Uniform(std::array{1}, + std::numeric_limits::min(), + std::numeric_limits::max())[0]); + }); + } std::vector Bias = random.Uniform(AsSpan({B_dims.back()}), -0.1f, 0.1f); OpTester test("DynamicQuantizeMatMul", 1, onnxruntime::kMSDomain); test.AddInput("A", A_dims, A_data); - test.AddInput("B", B_dims, B_data, is_matrix_b_constant); - test.AddInput("b_scale", {b_scale_zp_size}, B_scale); + test.AddInput("B", B_dims, B_data, opts.is_matrix_b_constant); + test.AddInput("b_scale", {b_scale_zp_size}, B_scale, opts.is_scale_constant); - if (has_zp) { - test.AddInput("b_zero_point", {b_scale_zp_size}, B_zero_point); + if (opts.has_zp) { + test.AddInput("b_zero_point", {b_scale_zp_size}, B_zero_point, opts.is_zp_constant); } else { test.AddOptionalInputEdge(); } - if (has_bias) { - test.AddInput("bias", {B_dims.back()}, Bias); + if (opts.has_bias) { + test.AddInput("bias", {B_dims.back()}, Bias, opts.is_bias_constant); } else { test.AddOptionalInputEdge(); } std::vector Y_data(M * N); CalculateDynamicQuantizeMatMul(M, N, K, A_data, B_data, B_scale, B_zero_point, Bias, Y_data, - per_column, has_zp, has_bias); + opts.per_column, opts.has_zp, opts.has_bias); 
test.AddOutput("Y", Y_dims, Y_data); test.SetOutputRelErr("Y", 0.02f); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); } -template -void RunDynamicQuantizeMatMulTest() { - TestDynamicQuantizeMatMul(false, /*is_matrix_b_constant*/ - false, /*per_column*/ - HasZeroPoint, /*has_zp*/ - HasBias /*has_bias*/ - ); - - TestDynamicQuantizeMatMul(true, /*is_matrix_b_constant*/ - false, /*per_column*/ - HasZeroPoint, /*has_zp*/ - HasBias /*has_bias*/ - ); - - TestDynamicQuantizeMatMul(false, /*is_matrix_b_constant*/ - true, /*per_column*/ - HasZeroPoint, /*has_zp*/ - HasBias /*has_bias*/ - ); - - TestDynamicQuantizeMatMul(true, /*is_matrix_b_constant*/ - true, /*per_column*/ - HasZeroPoint, /*has_zp*/ - HasBias /*has_bias*/ - ); +template +void TestDynamicQuantizeMatMul(bool is_matrix_b_constant, + bool per_column = false, + bool has_zp = true, + bool has_bias = false, + bool empty_input = false) { + TestDynamicQuantizeMatMulOptions opts{}; + opts.is_matrix_b_constant = is_matrix_b_constant; + opts.per_column = per_column; + opts.has_zp = has_zp; + opts.has_bias = has_bias; + opts.empty_input = empty_input; + + TestDynamicQuantizeMatMul(opts); } -TEST(DynamicQuantizeMatMul, HasZeroPoint_NoBias_test_S8) { - RunDynamicQuantizeMatMulTest(); +template +void RunDynamicQuantizeMatMulTest() { + for (bool is_matrix_b_constant : {false, true}) { + for (bool per_column : {false, true}) { + for (bool has_zp : {false, true}) { + for (bool has_bias : {false, true}) { + TestDynamicQuantizeMatMul(is_matrix_b_constant, + per_column, + has_zp, + has_bias); + } + } + } + } } -TEST(DynamicQuantizeMatMul, HasZeroPoint_NoBias_test_U8) { - RunDynamicQuantizeMatMulTest(); +TEST(DynamicQuantizeMatMul, Int8) { + RunDynamicQuantizeMatMulTest(); } -TEST(DynamicQuantizeMatMul, NoZeroPoint_HasBias_test_S8) { - RunDynamicQuantizeMatMulTest(); +TEST(DynamicQuantizeMatMul, UInt8) { + RunDynamicQuantizeMatMulTest(); } -TEST(DynamicQuantizeMatMul, 
NoZeroPoint_HasBias_test_U8) { - RunDynamicQuantizeMatMulTest(); -} +TEST(DynamicQuantizeMatMul, WithConstantBInputs) { + TestDynamicQuantizeMatMulOptions base_opts{}; + base_opts.is_matrix_b_constant = true; + base_opts.is_scale_constant = true; + base_opts.is_zp_constant = true; -TEST(DynamicQuantizeMatMul, NoZeroPoint_NoBias_test_S8) { - RunDynamicQuantizeMatMulTest(); -} + { + // no zp + auto opts = base_opts; + opts.has_zp = false; -TEST(DynamicQuantizeMatMul, NoZeroPoint_NoBias_test_U8) { - RunDynamicQuantizeMatMulTest(); -} + TestDynamicQuantizeMatMul(opts); + TestDynamicQuantizeMatMul(opts); + } -TEST(DynamicQuantizeMatMul, HasZeroPoint_HasBias_test_S8) { - RunDynamicQuantizeMatMulTest(); -} + { + // zp that is zero (symmetric quantization) + auto opts = base_opts; + opts.has_zp = true; + opts.is_zp_zero = true; -TEST(DynamicQuantizeMatMul, HasZeroPoint_HasBias_test_U8) { - RunDynamicQuantizeMatMulTest(); + TestDynamicQuantizeMatMul(opts); + TestDynamicQuantizeMatMul(opts); + } + + { + // zp that is non-zero + auto opts = base_opts; + opts.has_zp = true; + opts.is_zp_zero = false; + + TestDynamicQuantizeMatMul(opts); + TestDynamicQuantizeMatMul(opts); + } } TEST(DynamicQuantizeMatMul, UInt8_test_with_empty_input) { diff --git a/onnxruntime/test/ep_graph/test_ep_graph.cc b/onnxruntime/test/ep_graph/test_ep_graph.cc index 45314f8f39eea..bdbc60c1a0c48 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph.cc @@ -87,6 +87,92 @@ TEST(EpGraphTest, Check3LayerNestedSubgraphV2) { CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph()); } +TEST(EpGraphTest, GetAttributeByName) { + // Load model with a single Conv that has no explicit attributes set. 
+ auto test_graph = TestGraph::Load(ORT_TSTR("testdata/conv_default_attrs.onnx")); + ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; + + // + // Pre-check + // + + // Original Conv has no explicit attributes but Graph::Resolve() fills in default values for + // 'auto_pad' and 'group'. The other optional attributes (i.e. dilations, kernel_shape, pads, strides) do not + // have statically computable default values, so will not be filled in by Graph::Resolve(). + const OrtGraph& ort_graph = test_graph->GetOrtGraph(); + const OrtApi& ort_api = Ort::GetApi(); + + size_t num_nodes = 0; + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNumNodes(&ort_graph, &num_nodes)); + ASSERT_EQ(num_nodes, 1); + + std::vector nodes(num_nodes); + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNodes(&ort_graph, nodes.data(), nodes.size())); + + const OrtNode* conv_node = nodes[0]; + const char* op_type = nullptr; + ASSERT_ORTSTATUS_OK(ort_api.Node_GetOperatorType(conv_node, &op_type)); + ASSERT_STREQ(op_type, "Conv"); + + size_t num_attrs = 0; + ASSERT_ORTSTATUS_OK(ort_api.Node_GetNumAttributes(conv_node, &num_attrs)); + ASSERT_EQ(num_attrs, 2); + + std::vector attrs(num_attrs); + ASSERT_ORTSTATUS_OK(ort_api.Node_GetAttributes(conv_node, attrs.data(), attrs.size())); + for (const OrtOpAttr* attr : attrs) { + const char* attr_name_cstr = nullptr; + ASSERT_ORTSTATUS_OK(ort_api.OpAttr_GetName(attr, &attr_name_cstr)); + std::string_view attr_name = attr_name_cstr; + ASSERT_TRUE(attr_name == "auto_pad" || attr_name == "group"); // Only 'auto_pad' and 'group' have been set + } + + // + // Test 1: Get optional attribute that is not set (e.g., dilations). Should not get an error. + // + { + const OrtOpAttr* attr = nullptr; + Ort::Status status{ort_api.Node_GetAttributeByName(conv_node, "dilations", &attr)}; + ASSERT_TRUE(status.IsOK()); + ASSERT_EQ(attr, nullptr); + } + + // + // Test 2: Get attribute that does not exist in operator schema. Should get a ORT_NOT_FOUND error. 
+ // + { + const OrtOpAttr* attr = nullptr; + Ort::Status status{ort_api.Node_GetAttributeByName(conv_node, "_does_not_exist_", &attr)}; + ASSERT_FALSE(status.IsOK()); + ASSERT_EQ(status.GetErrorCode(), ORT_NOT_FOUND); + ASSERT_EQ(attr, nullptr); + } + + // + // Test 3: Get attribute that is known to be set. + // + { + const OrtOpAttr* attr = nullptr; + ASSERT_ORTSTATUS_OK(ort_api.Node_GetAttributeByName(conv_node, "auto_pad", &attr)); + ASSERT_NE(attr, nullptr); + + OrtOpAttrType attr_type = ORT_OP_ATTR_UNDEFINED; + ASSERT_ORTSTATUS_OK(ort_api.OpAttr_GetType(attr, &attr_type)); + ASSERT_EQ(attr_type, ORT_OP_ATTR_STRING); + + std::string auto_pad_val; + + // First call to ReadOpAttr gets the total byte size. Second call reads the data. + size_t total_attr_bytes = 0; + Ort::Status status2{ort_api.ReadOpAttr(attr, attr_type, nullptr, 0, &total_attr_bytes)}; + auto_pad_val.resize(total_attr_bytes); + + ASSERT_ORTSTATUS_OK(ort_api.ReadOpAttr(attr, attr_type, auto_pad_val.data(), total_attr_bytes, + &total_attr_bytes)); + ASSERT_EQ(auto_pad_val, "NOTSET"); + } +} + // Check correctness of an OrtGraph that has external initializers. TEST(EpGraphTest, CheckModelExternalInitializers) { auto test_graph = TestGraph::Load(ORT_TSTR("testdata/conv_qdq_external_ini.onnx")); diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc index 4c5dcd2bd7580..35f7d06fb0912 100644 --- a/onnxruntime/test/framework/ep_plugin_provider_test.cc +++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "core/session/ep_plugin_provider_interfaces.h" +#include "core/session/plugin_ep/ep_plugin_provider_interfaces.h" #include "gsl/gsl" #include "gtest/gtest.h" diff --git a/onnxruntime/test/testdata/conv_default_attrs.onnx b/onnxruntime/test/testdata/conv_default_attrs.onnx new file mode 100644 index 0000000000000..fc7ee58dee15e Binary files /dev/null and b/onnxruntime/test/testdata/conv_default_attrs.onnx differ diff --git a/onnxruntime/test/testdata/make_conv_default_attrs.py b/onnxruntime/test/testdata/make_conv_default_attrs.py new file mode 100644 index 0000000000000..fc092bf8b25fb --- /dev/null +++ b/onnxruntime/test/testdata/make_conv_default_attrs.py @@ -0,0 +1,36 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import numpy as np +import onnx + + +def main(): + inp_shape = (1, 2, 8, 8) + input_0 = onnx.helper.make_tensor_value_info("input_0", onnx.TensorProto.FLOAT, inp_shape) + output_0 = onnx.helper.make_tensor_value_info("output_0", onnx.TensorProto.FLOAT, None) + + weight_data = [ + [[[-1.5, 0.0], [0.2, 1.5]], [[-1.5, 0.0], [0.2, 1.5]]], + [[[-1.0, 0.0], [0.1333, 1.0]], [[-1.0, 0.0], [0.1333, 1.0]]], + ] + weight = onnx.numpy_helper.from_array(np.array(weight_data, dtype=np.float32), "weight") + bias = onnx.numpy_helper.from_array(np.array([0.0, 0.0], dtype=np.float32), "bias") + conv_node = onnx.helper.make_node("Conv", ["input_0", "weight", "bias"], ["output_0"], name="Conv0") + graph = onnx.helper.make_graph( + [conv_node], + "Convf32", + [input_0], + [output_0], + initializer=[weight, bias], + ) + opset_imports = [onnx.helper.make_opsetid("", 21)] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + model = onnx.shape_inference.infer_shapes(model) + + onnx.checker.check_model(model, True) + onnx.save_model(model, "conv_default_attrs.onnx") + + +if __name__ == "__main__": + main() diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml 
b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml index c1b83c5e579dc..f4a62208059c8 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml @@ -316,7 +316,21 @@ stages: MACHINE_POOL: 'onnxruntime-qnn-windows-vs-2022-arm64' QNN_SDK: ${{ parameters.qnn_sdk_version }} BUILD_PY_PARAMETERS: ${{ parameters.build_py_parameters }} - is1ES: true + PYTHON_VERSION: '3.11' + + - template: ../templates/py-win-arm64-qnn.yml + parameters: + MACHINE_POOL: 'onnxruntime-qnn-windows-vs-2022-arm64' + QNN_SDK: ${{ parameters.qnn_sdk_version }} + BUILD_PY_PARAMETERS: ${{ parameters.build_py_parameters }} + PYTHON_VERSION: '3.12' + + - template: ../templates/py-win-arm64-qnn.yml + parameters: + MACHINE_POOL: 'onnxruntime-qnn-windows-vs-2022-arm64' + QNN_SDK: ${{ parameters.qnn_sdk_version }} + BUILD_PY_PARAMETERS: ${{ parameters.build_py_parameters }} + PYTHON_VERSION: '3.13' - ${{ if eq(parameters.enable_windows_arm64ec_qnn, true) }}: - stage: Python_Packaging_Windows_arm64ec_QNN @@ -327,7 +341,6 @@ stages: MACHINE_POOL: 'Onnxruntime-QNNEP-Windows-2022-CPU' QNN_SDK: ${{ parameters.qnn_sdk_version }} BUILD_PY_PARAMETERS: ${{ parameters.build_py_parameters }} - is1ES: true - ${{ if eq(parameters.enable_windows_x64_qnn, true) }}: - stage: Python_Packaging_Windows_x64_QNN diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml index 761c551e9f4d9..3c2ef4741f049 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml @@ -4,6 +4,10 @@ parameters: type: string default: 'onnxruntime-qnn-windows-vs-2022-arm64' +- name: PYTHON_VERSION + type: string + default: '3.11' + - name: QNN_SDK displayName: QNN SDK Version type: string @@ -19,13 +23,8 @@ parameters: type: string 
default: '' -- name: is1ES - displayName: 'Whether the pipeline is running in 1ES' - type: boolean - default: false - jobs: -- job: Win_py_arm64_qnn_Wheels +- job: Win_py_arm64_qnn_Wheels_${{ replace(parameters.PYTHON_VERSION,'.','_') }} timeoutInMinutes: 210 workspace: clean: all @@ -48,41 +47,21 @@ jobs: outputs: - output: pipelineArtifact targetPath: $(Build.ArtifactStagingDirectory) - artifactName: onnxruntime_qnn_arm64_$(PythonVersion) - - strategy: - matrix: - Python311_arm64: - PythonVersion: '3.11.0' - LocalPythonDir: 'C:\Python\Python311' - Python312_arm64: - PythonVersion: '3.12.6' - LocalPythonDir: 'C:\Python\Python312' - Python313_arm64: - PythonVersion: '3.13.2' - LocalPythonDir: 'C:\Python\Python313' + artifactName: onnxruntime_qnn_arm64_${{ parameters.PYTHON_VERSION }} + variables: GRADLE_OPTS: '-Dorg.gradle.daemon=false' VSGenerator: 'Visual Studio 17 2022' steps: - checkout: self clean: true - submodules: recursive + submodules: none - template: telemetry-steps.yml - - script: | - MKDIR $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64 - XCOPY /s /y /h /e /c /q "$(LocalPythonDir)\*.*" $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64\ - COPY NUL $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64.complete - DIR $(Agent.ToolsDirectory)\Python - DIR $(Agent.ToolsDirectory)\Python\$(PythonVersion) - DIR $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64 - displayName: Copy python $(PythonVersion) version to agent tools directory - - task: UsePythonVersion@0 inputs: - versionSpec: $(PythonVersion) + versionSpec: ${{ parameters.PYTHON_VERSION }} addToPath: true architecture: 'arm64' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml index 74cae38393ea6..c8d37457a1034 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml @@ -19,11 
+19,6 @@ parameters: type: string default: '' -- name: is1ES - displayName: 'Whether the pipeline is running in 1ES' - type: boolean - default: false - jobs: - job: Win_py_x64_qnn_Wheels timeoutInMinutes: 210 diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml index 7ebf5394e4530..66d1cd1687d99 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml @@ -61,16 +61,6 @@ jobs: # because the python bindings also use the USE__PROVIDER_INTERFACE preprocessor macros. ExtraQnnBuildArgs: '--enable_generic_interface --build_wheel' steps: - - - script: | - MKDIR $(Agent.ToolsDirectory)\Python\3.11.0\arm64 - XCOPY /s /y /h /e /c /q "C:\Python\Python311\*.*" $(Agent.ToolsDirectory)\Python\3.11.0\arm64\ - COPY NUL $(Agent.ToolsDirectory)\Python\3.11.0\arm64.complete - DIR $(Agent.ToolsDirectory)\Python - DIR $(Agent.ToolsDirectory)\Python\3.11.0 - DIR $(Agent.ToolsDirectory)\Python\3.11.0\arm64 - displayName: Copy python 3.11.0 version to agent tools directory - - task: UsePythonVersion@0 inputs: versionSpec: '3.x'