diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 4a8c67e2215ec..3045648d17cd2 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -1063,11 +1063,25 @@ using UnownedAllocator = detail::AllocatorImpl>; /** \brief Wrapper around ::OrtSyncStream * */ -struct SyncStream : detail::Base { - explicit SyncStream(std::nullptr_t) {} ///< Create an empty SyncStream object, must be assigned a valid one to be used - explicit SyncStream(OrtSyncStream* p) : Base{p} {} ///< Take ownership of a pointer created by C API - void* GetHandle() const; ///< Wraps SyncStream_GetHandle + +namespace detail { +template +struct SyncStreamImpl : Base { + using B = Base; + using B::B; + // For some reason this is not a const method on the stream + void* GetHandle(); ///< Wraps SyncStream_GetHandle }; +} // namespace detail + +struct SyncStream : detail::SyncStreamImpl { + ///< Create an empty SyncStream object, must be assigned a valid one to be used + explicit SyncStream(std::nullptr_t) {} + ///< Take ownership of a pointer created by C API + explicit SyncStream(OrtSyncStream* p) : SyncStreamImpl{p} {} +}; + +using UnownedSyncStream = detail::SyncStreamImpl>; namespace detail { template diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 9c42bf34b5b0f..cb6448ad12a81 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -669,9 +669,12 @@ inline void KeyValuePairs::Remove(const char* key) { GetApi().RemoveKeyValuePair(this->p_, key); } -inline void* SyncStream::GetHandle() const { +namespace detail { +template +inline void* SyncStreamImpl::GetHandle() { return GetApi().SyncStream_GetHandle(this->p_); } +} // namespace detail namespace detail { template diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index c0fe171f76037..8b019f60d3e99 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -31,14 +31,17 @@ OrtAllocatorType, # noqa: F401 OrtArenaCfg, # noqa: F401 OrtCompileApiFlags, # noqa: F401 + OrtDeviceMemoryType, # noqa: F401 OrtEpDevice, # noqa: F401 OrtExecutionProviderDevicePolicy, # noqa: F401 OrtExternalInitializerInfo, # noqa: F401 OrtHardwareDevice, # noqa: F401 OrtHardwareDeviceType, # noqa: F401 OrtMemoryInfo, # noqa: F401 + OrtMemoryInfoDeviceType, # noqa: F401 OrtMemType, # noqa: F401 OrtSparseFormat, # noqa: F401 + OrtSyncStream, # noqa: F401 RunOptions, # noqa: F401 SessionIOBinding, # noqa: F401 SessionOptions, # noqa: F401 @@ -78,6 +81,7 @@ OrtDevice, # noqa: F401 OrtValue, # noqa: F401 SparseTensor, # noqa: F401 + copy_tensors, # noqa: F401 ) # TODO: thiagofc: Temporary experimental namespace for new PyTorch front-end diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index 35abad5760c32..4c3313046457c 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -199,6 +199,18 @@ def get_modelmeta(self) -> onnxruntime.ModelMetadata: "Return the metadata. See :class:`onnxruntime.ModelMetadata`." return self._model_meta + def get_input_memory_infos(self) -> Sequence[onnxruntime.MemoryInfo]: + "Return the memory info for the inputs." + return self._input_meminfos + + def get_output_memory_infos(self) -> Sequence[onnxruntime.MemoryInfo]: + "Return the memory info for the outputs." + return self._output_meminfos + + def get_input_epdevices(self) -> Sequence[onnxruntime.OrtEpDevice]: + "Return the execution providers for the inputs." + return self._input_epdevices + def get_providers(self) -> Sequence[str]: "Return list of registered execution providers." return self._providers @@ -576,6 +588,9 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi self._inputs_meta = self._sess.inputs_meta self._outputs_meta = self._sess.outputs_meta self._overridable_initializers = self._sess.overridable_initializers + self._input_meminfos = self._sess.input_meminfos + self._output_meminfos = self._sess.output_meminfos + self._input_epdevices = self._sess.input_epdevices self._model_meta = self._sess.model_meta self._providers = self._sess.get_providers() self._provider_options = self._sess.get_provider_options() @@ -589,6 +604,9 @@ def _reset_session(self, providers, provider_options) -> None: self._inputs_meta = None self._outputs_meta = None self._overridable_initializers = None + self._input_meminfos = None + self._output_meminfos = None + self._input_epdevices = None self._model_meta = None self._providers = None self._provider_options = None @@ -1134,6 +1152,15 @@ def update_inplace(self, np_arr) -> None: self._ortvalue.update_inplace(np_arr) +def copy_tensors(src: Sequence[OrtValue], dst: Sequence[OrtValue], stream=None) -> None: + """ + Copy tensor data from source OrtValue sequence to destination OrtValue sequence. + """ + c_sources = [s._get_c_value() for s in src] + c_dsts = [d._get_c_value() for d in dst] + C.copy_tensors(c_sources, c_dsts, stream) + + class OrtDevice: """ A data structure that exposes the underlying C++ OrtDevice @@ -1146,6 +1173,7 @@ def __init__(self, c_ort_device): if isinstance(c_ort_device, C.OrtDevice): self._ort_device = c_ort_device else: + # An end user won't hit this error raise ValueError( "`Provided object` needs to be of type `onnxruntime.capi.onnxruntime_pybind11_state.OrtDevice`" ) @@ -1188,6 +1216,9 @@ def device_type(self): def device_vendor_id(self): return self._ort_device.vendor_id() + def device_mem_type(self): + return self._ort_device.mem_type() + class SparseTensor: """ diff --git a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc index 1fe7ab0884f9c..d74663ddb63d7 100644 --- a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc @@ -333,7 +333,7 @@ void addOrtValueMethods(pybind11::module& m) { }) #endif // Get a pointer to Tensor data - .def("data_ptr", [](OrtValue* ml_value) -> int64_t { + .def("data_ptr", [](OrtValue* ml_value) -> uintptr_t { // TODO: Assumes that the OrtValue is a Tensor, make this generic to handle non-Tensors ORT_ENFORCE(ml_value->IsTensor(), "Only OrtValues that are Tensors are currently supported"); @@ -344,7 +344,7 @@ void addOrtValueMethods(pybind11::module& m) { } // Should cover x86 and x64 platforms - return reinterpret_cast(tensor->MutableDataRaw()); + return reinterpret_cast(tensor->MutableDataRaw()); }) .def("device_name", [](const OrtValue* ort_value) -> std::string { if (ort_value->IsTensor()) { diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 2392278344d68..b444ee5e7eb94 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -22,6 +22,7 @@ #include "core/framework/data_transfer_utils.h" #include "core/framework/data_types_internal.h" #include "core/framework/error_code_helper.h" +#include "core/framework/plugin_ep_stream.h" #include "core/framework/provider_options_utils.h" #include "core/framework/random_seed.h" #include "core/framework/sparse_tensor.h" @@ -1587,6 +1588,18 @@ void addGlobalMethods(py::module& m) { }, R"pbdoc("Validate a compiled model's compatibility information for one or more EP devices.)pbdoc"); + m.def( + "copy_tensors", + [](const std::vector& src, const std::vector& dest, py::object& py_arg) { + const OrtEnv* ort_env = GetOrtEnv(); + OrtSyncStream* stream = nullptr; + if (!py_arg.is_none()) { + stream = py_arg.cast(); + } + Ort::ThrowOnError(Ort::GetApi().CopyTensors(ort_env, src.data(), dest.data(), stream, src.size())); + }, + R"pbdoc("Copy tensors from sources to destinations using specified stream handle (or None))pbdoc"); + #if defined(USE_OPENVINO) || defined(USE_OPENVINO_PROVIDER_INTERFACE) m.def( "get_available_openvino_device_ids", []() -> std::vector { @@ -1788,6 +1801,16 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra .value("CPU", OrtMemTypeCPU) .value("DEFAULT", OrtMemTypeDefault); + py::enum_(m, "OrtMemoryInfoDeviceType") + .value("CPU", OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU) + .value("GPU", OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU) + .value("NPU", OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_NPU) + .value("FPGA", OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_FPGA); + + py::enum_(m, "OrtDeviceMemoryType") + .value("DEFAULT", OrtDeviceMemoryType_DEFAULT) + .value("HOST_ACCESSIBLE", OrtDeviceMemoryType_HOST_ACCESSIBLE); + py::class_ device(m, "OrtDevice", R"pbdoc(ONNXRuntime device information.)pbdoc"); device.def(py::init()) .def(py::init([](OrtDevice::DeviceType type, @@ -1816,6 +1839,7 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra .def("device_id", &OrtDevice::Id, R"pbdoc(Device Id.)pbdoc") .def("device_type", &OrtDevice::Type, R"pbdoc(Device Type.)pbdoc") .def("vendor_id", &OrtDevice::Vendor, R"pbdoc(Vendor Id.)pbdoc") + .def("mem_type", &OrtDevice::MemType, R"pbdoc(Device Memory Type.)pbdoc") // generic device types that are typically used with a vendor id. .def_static("cpu", []() { return OrtDevice::CPU; }) .def_static("gpu", []() { return OrtDevice::GPU; }) @@ -1866,36 +1890,55 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra }, R"pbdoc(Hardware device's metadata as string key/value pairs.)pbdoc"); + py::class_ py_sync_stream(m, "OrtSyncStream", + R"pbdoc(Represents a synchronization stream for model inference.)pbdoc"); + py::class_ py_ep_device(m, "OrtEpDevice", R"pbdoc(Represents a hardware device that an execution provider supports for model inference.)pbdoc"); py_ep_device.def_property_readonly( "ep_name", - [](OrtEpDevice* ep_device) -> std::string { return ep_device->ep_name; }, + [](const OrtEpDevice* ep_device) -> std::string { return ep_device->ep_name; }, R"pbdoc(The execution provider's name.)pbdoc") .def_property_readonly( "ep_vendor", - [](OrtEpDevice* ep_device) -> std::string { return ep_device->ep_vendor; }, + [](const OrtEpDevice* ep_device) -> std::string { return ep_device->ep_vendor; }, R"pbdoc(The execution provider's vendor name.)pbdoc") .def_property_readonly( "ep_metadata", - [](OrtEpDevice* ep_device) -> std::map { + [](const OrtEpDevice* ep_device) -> std::map { return ep_device->ep_metadata.Entries(); }, R"pbdoc(The execution provider's additional metadata for the OrtHardwareDevice.)pbdoc") .def_property_readonly( "ep_options", - [](OrtEpDevice* ep_device) -> std::map { + [](const OrtEpDevice* ep_device) -> std::map { return ep_device->ep_options.Entries(); }, R"pbdoc(The execution provider's options used to configure the provider to use the OrtHardwareDevice.)pbdoc") .def_property_readonly( "device", - [](OrtEpDevice* ep_device) -> const OrtHardwareDevice& { + [](const OrtEpDevice* ep_device) -> const OrtHardwareDevice& { return *ep_device->device; }, R"pbdoc(The OrtHardwareDevice instance for the OrtEpDevice.)pbdoc", - py::return_value_policy::reference_internal); + py::return_value_policy::reference_internal) + .def( + "memory_info", + [](const OrtEpDevice* ep_device, OrtDeviceMemoryType memory_type) -> const OrtMemoryInfo* { + Ort::ConstEpDevice ep_dev(ep_device); + return static_cast(ep_dev.GetMemoryInfo(memory_type)); + }, + R"pbdoc(The OrtMemoryInfo instance for the OrtEpDevice specific to the device memory type.)pbdoc", + py::return_value_policy::reference_internal) + .def( + "create_sync_stream", + [](const OrtEpDevice* ep_device) -> std::unique_ptr { + Ort::ConstEpDevice ep_dev(ep_device); + Ort::SyncStream stream = ep_dev.CreateSyncStream(); + return std::unique_ptr(stream.release()); + }, + R"pbdoc(The OrtSyncStream instance for the OrtEpDevice.)pbdoc"); py::class_ ort_arena_cfg_binding(m, "OrtArenaCfg"); // Note: Doesn't expose initial_growth_chunk_sizes_bytes/max_power_of_two_extend_bytes option. @@ -1941,25 +1984,28 @@ for model inference.)pbdoc"); .def_readwrite("max_power_of_two_extend_bytes", &OrtArenaCfg::max_power_of_two_extend_bytes); py::class_ ort_memory_info_binding(m, "OrtMemoryInfo"); - ort_memory_info_binding.def(py::init([](const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) { - if (strcmp(name, onnxruntime::CPU) == 0) { - return std::make_unique(onnxruntime::CPU, type, OrtDevice(), mem_type); - } else if (strcmp(name, onnxruntime::CUDA) == 0) { - return std::make_unique( - onnxruntime::CUDA, type, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, - static_cast(id)), - mem_type); - } else if (strcmp(name, onnxruntime::CUDA_PINNED) == 0) { - return std::make_unique( - onnxruntime::CUDA_PINNED, type, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA, - static_cast(id)), - mem_type); - } else { - throw std::runtime_error("Specified device is not supported."); - } - })); + ort_memory_info_binding.def( + py::init([](const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) { + Ort::MemoryInfo result(name, type, id, mem_type); + return std::unique_ptr(result.release()); + })) + .def_static( + "create_v2", + [](const char* name, OrtMemoryInfoDeviceType device_type, uint32_t vendor_id, + int32_t device_id, OrtDeviceMemoryType device_mem_type, size_t alignment, OrtAllocatorType type) { + Ort::MemoryInfo result(name, device_type, vendor_id, device_id, device_mem_type, alignment, type); + return std::unique_ptr(result.release()); + }, + R"pbdoc(Create an OrtMemoryInfo instance using CreateMemoryInfo_V2())pbdoc") + .def_property_readonly("name", [](const OrtMemoryInfo* mem_info) -> std::string { return mem_info->name; }, R"pbdoc(Arbitrary name supplied by the user)pbdoc") + .def_property_readonly("device_id", [](const OrtMemoryInfo* mem_info) -> int { return mem_info->device.Id(); }, R"pbdoc(Device Id.)pbdoc") + .def_property_readonly("mem_type", [](const OrtMemoryInfo* mem_info) -> OrtMemType { return mem_info->mem_type; }, R"pbdoc(OrtMemoryInfo memory type.)pbdoc") + .def_property_readonly("allocator_type", [](const OrtMemoryInfo* mem_info) -> OrtAllocatorType { return mem_info->alloc_type; }, R"pbdoc(Allocator type)pbdoc") + .def_property_readonly("device_mem_type", [](const OrtMemoryInfo* mem_info) -> OrtDeviceMemoryType { + auto mem_type = mem_info->device.MemType(); + return (mem_type == OrtDevice::MemType::DEFAULT) ? + OrtDeviceMemoryType_DEFAULT: OrtDeviceMemoryType_HOST_ACCESSIBLE ; }, R"pbdoc(Device memory type (Device or Host accessible).)pbdoc") + .def_property_readonly("device_vendor_id", [](const OrtMemoryInfo* mem_info) -> uint32_t { return mem_info->device.Vendor(); }); py::class_ sess(m, "SessionOptions", R"pbdoc(Configuration information for a session.)pbdoc"); @@ -2656,6 +2702,33 @@ including arg name, arg type (contains both type and shape).)pbdoc") auto res = sess->GetSessionHandle()->GetModelMetadata(); OrtPybindThrowIfError(res.first); return *(res.second); }, py::return_value_policy::reference_internal) + .def_property_readonly("input_meminfos", [](const PyInferenceSession* sess) -> py::list { + Ort::ConstSession session(reinterpret_cast(sess->GetSessionHandle())); + auto inputs_mem_info = session.GetMemoryInfoForInputs(); + py::list result; + for (const auto& info : inputs_mem_info) { + const auto* p_info = static_cast(info); + result.append(py::cast(p_info, py::return_value_policy::reference)); + } + return result; }) + .def_property_readonly("output_meminfos", [](const PyInferenceSession* sess) -> py::list { + Ort::ConstSession session(reinterpret_cast(sess->GetSessionHandle())); + auto outputs_mem_info = session.GetMemoryInfoForOutputs(); + py::list result; + for (const auto& info : outputs_mem_info) { + const auto* p_info = static_cast(info); + result.append(py::cast(p_info, py::return_value_policy::reference)); + } + return result; }) + .def_property_readonly("input_epdevices", [](const PyInferenceSession* sess) -> py::list { + Ort::ConstSession session(reinterpret_cast(sess->GetSessionHandle())); + auto ep_devices = session.GetEpDeviceForInputs(); + py::list result; + for (const auto& device : ep_devices) { + const auto* p_device = static_cast(device); + result.append(py::cast(p_device, py::return_value_policy::reference)); + } + return result; }) .def("run_with_iobinding", [](PyInferenceSession* sess, SessionIOBinding& io_binding, RunOptions* run_options = nullptr) -> void { Status status; diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index 1820664e1d604..b85030b46e94d 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -689,15 +689,30 @@ def test_run_model_with_optional_sequence_input(self): def test_run_model(self): sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=available_providers) x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) - input_name = sess.get_inputs()[0].name - self.assertEqual(input_name, "X") - input_shape = sess.get_inputs()[0].shape - self.assertEqual(input_shape, [3, 2]) - output_name = sess.get_outputs()[0].name - self.assertEqual(output_name, "Y") - output_shape = sess.get_outputs()[0].shape - self.assertEqual(output_shape, [3, 2]) - res = sess.run([output_name], {input_name: x}) + + inputs = sess.get_inputs() + self.assertEqual(len(inputs), 1) + self.assertEqual(inputs[0].name, "X") + self.assertEqual(inputs[0].shape, [3, 2]) + + input_meminfos = sess.get_input_memory_infos() + self.assertEqual(len(input_meminfos), 1) + self.assertIsNotNone(input_meminfos[0]) + + input_epdevices = sess.get_input_epdevices() + # The entry my be None (null) but it should be present + self.assertEqual(len(input_epdevices), 1) + + outputs = sess.get_outputs() + self.assertEqual(len(outputs), 1) + self.assertEqual(outputs[0].name, "Y") + self.assertEqual(outputs[0].shape, [3, 2]) + + output_meminfos = sess.get_output_memory_infos() + self.assertEqual(len(output_meminfos), 1) + self.assertIsNotNone(output_meminfos[0]) + + res = sess.run([outputs[0].name], {inputs[0].name: x}) output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) @@ -1584,6 +1599,44 @@ def test_run_model_with_cuda_copy_stream(self): for _iteration in range(100000): session.run(output_names=["output"], input_feed={"shape": shape}) + def test_ort_device(self): + cpu_device = onnxrt.OrtDevice.make("cpu", 0) + self.assertEqual(cpu_device.device_id(), 0) + self.assertEqual(cpu_device.device_type(), 0) + self.assertEqual(cpu_device.device_vendor_id(), 0) + self.assertEqual(cpu_device.device_mem_type(), 0) + + def test_ort_memory_info(self): + cpu_memory_info = onnxrt.OrtMemoryInfo( + "Cpu", + onnxrt.OrtAllocatorType.ORT_ARENA_ALLOCATOR, + 0, + onnxrt.OrtMemType.DEFAULT, + ) + self.assertEqual(cpu_memory_info.name, "Cpu") + self.assertEqual(cpu_memory_info.device_id, 0) + self.assertEqual(cpu_memory_info.mem_type, onnxrt.OrtMemType.DEFAULT) + self.assertEqual(cpu_memory_info.allocator_type, onnxrt.OrtAllocatorType.ORT_ARENA_ALLOCATOR) + self.assertEqual(cpu_memory_info.device_mem_type, onnxrt.OrtDeviceMemoryType.DEFAULT) + self.assertEqual(cpu_memory_info.device_vendor_id, 0) + + def test_ort_memory_info_create_v2(self): + cpu_memory_info = onnxrt.OrtMemoryInfo.create_v2( + "Test", + onnxrt.OrtMemoryInfoDeviceType.CPU, + 0, # vendor_id + 0, # device_id + onnxrt.OrtDeviceMemoryType.DEFAULT, + 128, # alignment + onnxrt.OrtAllocatorType.ORT_ARENA_ALLOCATOR, + ) + self.assertEqual(cpu_memory_info.name, "Test") + self.assertEqual(cpu_memory_info.device_id, 0) + self.assertEqual(cpu_memory_info.mem_type, onnxrt.OrtMemType.DEFAULT) + self.assertEqual(cpu_memory_info.allocator_type, onnxrt.OrtAllocatorType.ORT_ARENA_ALLOCATOR) + self.assertEqual(cpu_memory_info.device_mem_type, onnxrt.OrtDeviceMemoryType.DEFAULT) + self.assertEqual(cpu_memory_info.device_vendor_id, 0) + def test_shared_allocator_using_create_and_register_allocator(self): # Create and register an arena based allocator diff --git a/onnxruntime/test/python/onnxruntime_test_python_autoep.py b/onnxruntime/test/python/onnxruntime_test_python_autoep.py index cb31627a87c48..d6281d165c053 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_autoep.py +++ b/onnxruntime/test/python/onnxruntime_test_python_autoep.py @@ -226,6 +226,14 @@ def test_example_plugin_ep_devices(self): hw_metadata = hw_device.metadata self.assertGreater(len(hw_metadata), 0) # Should have at least SPDRP_HARDWAREID on Windows + test_mem_info = test_ep_device.memory_info(onnxrt.OrtDeviceMemoryType.DEFAULT) + self.assertIsNotNone(test_mem_info) + del test_mem_info + + test_sync_stream = test_ep_device.create_sync_stream() + self.assertIsNotNone(test_sync_stream) + del test_sync_stream + # Add EP plugin's OrtEpDevice to the SessionOptions. sess_options = onnxrt.SessionOptions() sess_options.add_provider_for_devices([test_ep_device], {"opt1": "val1"}) @@ -282,6 +290,55 @@ def test_example_plugin_ep_data_transfer(self): self.unregister_execution_provider_library(ep_name) + def test_copy_tensors(self): + """ + Test global api copy_tensors between OrtValue objects + using EP plug-in data transfer + """ + if sys.platform != "win32": + self.skipTest("Skipping test because device discovery is only supported on Windows") + + ep_lib_path = "example_plugin_ep.dll" + try: + ep_lib_path = get_name("example_plugin_ep.dll") + except FileNotFoundError: + self.skipTest(f"Skipping test because EP library '{ep_lib_path}' cannot be found") + + ep_name = "example_ep" + self.register_execution_provider_library(ep_name, os.path.realpath(ep_lib_path)) + + # Generate 2 numpy arrays + a = np.random.rand(3, 2).astype(np.float32) + b = np.random.rand(3, 2).astype(np.float32) + + # Create OrtValue from numpy arrays on EP device + # the example EP pretends to use GPU memory, so we place it there + a_device = onnxrt.OrtValue.ortvalue_from_numpy(a, "gpu", 0, 0xBE57) + b_device = onnxrt.OrtValue.ortvalue_from_numpy(b, "gpu", 0, 0xBE57) + + # Create destination ort values with the same shape on CPU + a_cpu_copy = onnxrt.OrtValue.ortvalue_from_shape_and_type(a.shape, a.dtype) + b_cpu_copy = onnxrt.OrtValue.ortvalue_from_shape_and_type(b.shape, b.dtype) + + # source list + src_list = [a_device, b_device] + dst_list = [a_cpu_copy, b_cpu_copy] + # Passing None for stream as we copy between CPU + # Test None because it is allowed + onnxrt.copy_tensors(src_list, dst_list, None) + + # Release the OrtValue on the EP device + # before the EP library is unregistered + del src_list + del a_device + del b_device + + # Verify the contents + np.testing.assert_array_equal(a, a_cpu_copy.numpy()) + np.testing.assert_array_equal(b, b_cpu_copy.numpy()) + + self.unregister_execution_provider_library(ep_name) + if __name__ == "__main__": unittest.main(verbosity=1)