From cece80d54c73efde16203dac7270f22dbac125bf Mon Sep 17 00:00:00 2001
From: Dmitri Smirnov
Date: Thu, 4 Sep 2025 13:35:29 -0700
Subject: [PATCH 1/9] Begin

---
 onnxruntime/__init__.py                       |  1 +
 .../onnxruntime_inference_collection.py       |  3 ++
 .../python/onnxruntime_pybind_state.cc        | 31 ++++++++++++++-----
 .../test/python/onnxruntime_test_python.py    | 21 +++++++++++++
 .../python/onnxruntime_test_python_autoep.py  |  3 ++
 5 files changed, 52 insertions(+), 7 deletions(-)

diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py
index 38f034cf6d266..31267263e2384 100644
--- a/onnxruntime/__init__.py
+++ b/onnxruntime/__init__.py
@@ -31,6 +31,7 @@
     OrtAllocatorType,  # noqa: F401
     OrtArenaCfg,  # noqa: F401
     OrtCompileApiFlags,  # noqa: F401
+    OrtDeviceMemoryType,  # noqa: F401
     OrtEpDevice,  # noqa: F401
     OrtExecutionProviderDevicePolicy,  # noqa: F401
     OrtHardwareDevice,  # noqa: F401
diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py
index 8c8ba214eb714..8e380a40ceeeb 100644
--- a/onnxruntime/python/onnxruntime_inference_collection.py
+++ b/onnxruntime/python/onnxruntime_inference_collection.py
@@ -1116,6 +1116,9 @@ def device_type(self):
     def device_vendor_id(self):
         return self._ort_device.vendor_id()
 
+    def device_mem_type(self):
+        return self._ort_device.mem_type()
+
 
 class SparseTensor:
     """
diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc
index 27c76f7f5c482..63f5aa04bbcf3 100644
--- a/onnxruntime/python/onnxruntime_pybind_state.cc
+++ b/onnxruntime/python/onnxruntime_pybind_state.cc
@@ -1787,6 +1787,10 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra
       .value("CPU", OrtMemTypeCPU)
       .value("DEFAULT", OrtMemTypeDefault);
 
+  py::enum_<OrtDeviceMemoryType>(m, "OrtDeviceMemoryType")
+      .value("DEFAULT", OrtDeviceMemoryType_DEFAULT)
+      .value("HOST_ACCESSIBLE", OrtDeviceMemoryType_HOST_ACCESSIBLE);
+
   py::class_<OrtDevice> device(m, "OrtDevice", R"pbdoc(ONNXRuntime device information.)pbdoc");
   device.def(py::init())
       .def(py::init([](OrtDevice::DeviceType type,
@@ -1815,6 +1819,7 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra
       .def("device_id", &OrtDevice::Id, R"pbdoc(Device Id.)pbdoc")
       .def("device_type", &OrtDevice::Type, R"pbdoc(Device Type.)pbdoc")
       .def("vendor_id", &OrtDevice::Vendor, R"pbdoc(Vendor Id.)pbdoc")
+      .def("mem_type", &OrtDevice::MemType, R"pbdoc(Device Memory Type.)pbdoc")
       // generic device types that are typically used with a vendor id.
.def_static("cpu", []() { return OrtDevice::CPU; }) .def_static("gpu", []() { return OrtDevice::GPU; }) @@ -1870,30 +1875,37 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra for model inference.)pbdoc"); py_ep_device.def_property_readonly( "ep_name", - [](OrtEpDevice* ep_device) -> std::string { return ep_device->ep_name; }, + [](const OrtEpDevice* ep_device) -> std::string { return ep_device->ep_name; }, R"pbdoc(The execution provider's name.)pbdoc") .def_property_readonly( "ep_vendor", - [](OrtEpDevice* ep_device) -> std::string { return ep_device->ep_vendor; }, + [](const OrtEpDevice* ep_device) -> std::string { return ep_device->ep_vendor; }, R"pbdoc(The execution provider's vendor name.)pbdoc") .def_property_readonly( "ep_metadata", - [](OrtEpDevice* ep_device) -> std::map { + [](const OrtEpDevice* ep_device) -> std::map { return ep_device->ep_metadata.Entries(); }, R"pbdoc(The execution provider's additional metadata for the OrtHardwareDevice.)pbdoc") .def_property_readonly( "ep_options", - [](OrtEpDevice* ep_device) -> std::map { + [](const OrtEpDevice* ep_device) -> std::map { return ep_device->ep_options.Entries(); }, R"pbdoc(The execution provider's options used to configure the provider to use the OrtHardwareDevice.)pbdoc") .def_property_readonly( "device", - [](OrtEpDevice* ep_device) -> const OrtHardwareDevice& { + [](const OrtEpDevice* ep_device) -> const OrtHardwareDevice& { return *ep_device->device; }, R"pbdoc(The OrtHardwareDevice instance for the OrtEpDevice.)pbdoc", + py::return_value_policy::reference_internal) + .def( + "memory_info", + [](const OrtEpDevice* ep_device, OrtDeviceMemoryType memory_type) -> const OrtMemoryInfo* { + return Ort::GetApi().EpDevice_MemoryInfo(ep_device, memory_type); + }, + R"pbdoc(The OrtMemoryInfo instance for the OrtEpDevice specific to the device memory type.)pbdoc", py::return_value_policy::reference_internal); py::class_ ort_arena_cfg_binding(m, "OrtArenaCfg"); @@ -1957,8 +1969,13 @@ for model inference.)pbdoc"); mem_type); } else { throw std::runtime_error("Specified device is not supported."); - } - })); + } })) + .def_property_readonly("name", [](const OrtMemoryInfo* mem_info) -> const char* { return mem_info->name; }) + .def_property_readonly("id", [](const OrtMemoryInfo* mem_info) -> int { return mem_info->device.Id(); }, R"pbdoc(Device Id.)pbdoc") + .def_property_readonly("mem_type", [](const OrtMemoryInfo* mem_info) -> OrtMemType { return mem_info->mem_type; }, R"pbdoc(OrtMemoryInfo memory type.)pbdoc") + .def_property_readonly("allocator_type", [](const OrtMemoryInfo* mem_info) -> OrtAllocatorType { return mem_info->alloc_type; }) + .def_property_readonly("device_mem_type", [](const OrtMemoryInfo* mem_info) -> int { return mem_info->device.MemType(); }, R"pbdoc(Device memory type.)pbdoc") + .def_property_readonly("vendor_id", [](const OrtMemoryInfo* mem_info) -> uint32_t { return mem_info->device.Vendor(); }); py::class_ sess(m, "SessionOptions", R"pbdoc(Configuration information for a session.)pbdoc"); diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index 1820664e1d604..86e5bd2bae70f 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -1584,6 +1584,27 @@ def test_run_model_with_cuda_copy_stream(self): for _iteration in range(100000): session.run(output_names=["output"], input_feed={"shape": shape}) + def test_ort_device(self): + cpu_device = 
onnxrt.OrtDevice.make("cpu", 0) + self.assertEqual(cpu_device.device_id(), 0) + self.assertEqual(cpu_device.device_type(), 0) + self.assertEqual(cpu_device.device_vendor_id(), 0) + self.assertEqual(cpu_device.device_mem_type(), 0) + + def test_ort_memory_info(self): + cpu_memory_info = onnxrt.OrtMemoryInfo( + "Cpu", + onnxrt.OrtAllocatorType.ORT_ARENA_ALLOCATOR, + 0, + onnxrt.OrtMemType.DEFAULT, + ) + self.assertEqual(cpu_memory_info.name, "Cpu") + self.assertEqual(cpu_memory_info.id, 0) + self.assertEqual(cpu_memory_info.mem_type, onnxrt.OrtMemType.DEFAULT) + self.assertEqual(cpu_memory_info.allocator_type, onnxrt.OrtAllocatorType.ORT_ARENA_ALLOCATOR) + self.assertEqual(cpu_memory_info.device_mem_type, 0) + self.assertEqual(cpu_memory_info.vendor_id, 0) + def test_shared_allocator_using_create_and_register_allocator(self): # Create and register an arena based allocator diff --git a/onnxruntime/test/python/onnxruntime_test_python_autoep.py b/onnxruntime/test/python/onnxruntime_test_python_autoep.py index cb31627a87c48..49ff7b488bdbd 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_autoep.py +++ b/onnxruntime/test/python/onnxruntime_test_python_autoep.py @@ -226,6 +226,9 @@ def test_example_plugin_ep_devices(self): hw_metadata = hw_device.metadata self.assertGreater(len(hw_metadata), 0) # Should have at least SPDRP_HARDWAREID on Windows + test_mem_info = test_ep_device.memory_info(onnxrt.OrtDeviceMemoryType.DEFAULT) + self.assertIsNotNone(test_mem_info) + # Add EP plugin's OrtEpDevice to the SessionOptions. sess_options = onnxrt.SessionOptions() sess_options.add_provider_for_devices([test_ep_device], {"opt1": "val1"}) From c053030c05e712fc91b49d1f3e1db743df577c92 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Thu, 4 Sep 2025 16:30:30 -0700 Subject: [PATCH 2/9] Make name an std::string in OrtMemoryInfo --- .../core/framework/ortmemoryinfo.h | 10 ++--- onnxruntime/core/framework/allocator.cc | 43 ++++++++++++++----- onnxruntime/core/framework/bfc_arena.cc | 2 +- onnxruntime/core/session/environment.cc | 2 +- onnxruntime/core/session/lora_adapters.cc | 4 +- .../test/framework/TestAllocatorManager.cc | 2 +- onnxruntime/test/framework/allocator_test.cc | 2 +- onnxruntime/test/framework/tensor_test.cc | 8 ++-- 8 files changed, 48 insertions(+), 25 deletions(-) diff --git a/include/onnxruntime/core/framework/ortmemoryinfo.h b/include/onnxruntime/core/framework/ortmemoryinfo.h index d930b2289170d..5453b2526b160 100644 --- a/include/onnxruntime/core/framework/ortmemoryinfo.h +++ b/include/onnxruntime/core/framework/ortmemoryinfo.h @@ -13,13 +13,13 @@ struct OrtMemoryInfo { OrtMemoryInfo() = default; // to allow default construction of Tensor // use string for name, so we could have customized allocator in execution provider. 
-  const char* name = nullptr;
+  std::string name;
   OrtMemType mem_type = OrtMemTypeDefault;
   OrtAllocatorType alloc_type = OrtInvalidAllocator;
   OrtDevice device;
 
-  constexpr OrtMemoryInfo(const char* name_, OrtAllocatorType type_, OrtDevice device_ = OrtDevice(),
-                          OrtMemType mem_type_ = OrtMemTypeDefault)
+  OrtMemoryInfo(const char* name_, OrtAllocatorType type_, OrtDevice device_ = OrtDevice(),
+                OrtMemType mem_type_ = OrtMemTypeDefault)
 #if ((defined(__GNUC__) && __GNUC__ > 4) || defined(__clang__))
       // this causes a spurious error in CentOS gcc 4.8 build so disable if GCC version < 5
       __attribute__((nonnull))
@@ -39,7 +39,7 @@ struct OrtMemoryInfo {
     if (device != other.device)
       return device < other.device;
 
-    return strcmp(name, other.name) < 0;
+    return name < other.name;
   }
 
   // This is to make OrtMemoryInfo a valid key in hash tables
@@ -68,7 +68,7 @@ inline bool operator==(const OrtMemoryInfo& left, const OrtMemoryInfo& other) {
   return left.mem_type == other.mem_type &&
          left.alloc_type == other.alloc_type &&
         left.device == other.device &&
-         strcmp(left.name, other.name) == 0;
+         left.name == other.name;
 }
 
 inline bool operator!=(const OrtMemoryInfo& lhs, const OrtMemoryInfo& rhs) { return !(lhs == rhs); }
diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc
index e1b9d1294fb9e..91b5b811a3529 100644
--- a/onnxruntime/core/framework/allocator.cc
+++ b/onnxruntime/core/framework/allocator.cc
@@ -6,6 +6,7 @@
 #include "core/common/safeint.h"
 #include "core/common/status.h"
 #include "core/framework/allocator.h"
+#include "core/framework/error_code_helper.h"
 #include "core/mlas/inc/mlas.h"
 #include "core/framework/utils.h"
 #include "core/session/ort_apis.h"
@@ -185,22 +186,32 @@ std::ostream& operator<<(std::ostream& out, const OrtMemoryInfo& info) { return
 #endif
 ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtAllocatorType type, int id1,
                     enum OrtMemType mem_type1, _Outptr_ OrtMemoryInfo** out) {
+  API_IMPL_BEGIN
+
+  if (name1 == nullptr) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "MemoryInfo name cannot be null.");
+  }
+
+  if (out == nullptr) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Output memory info cannot be null.");
+  }
+
   auto device_id = static_cast<OrtDevice::DeviceId>(id1);
   if (strcmp(name1, onnxruntime::CPU) == 0) {
     *out = new OrtMemoryInfo(onnxruntime::CPU, type, OrtDevice(), mem_type1);
   } else if (strcmp(name1, onnxruntime::CUDA) == 0) {
     *out = new OrtMemoryInfo(
-        name1, type,
+        onnxruntime::CUDA, type,
         OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, device_id),
         mem_type1);
   } else if (strcmp(name1, onnxruntime::OpenVINO_GPU) == 0) {
     *out = new OrtMemoryInfo(
-        name1, type,
+        onnxruntime::OpenVINO_GPU, type,
         OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::INTEL, device_id),
         mem_type1);
   } else if (strcmp(name1, onnxruntime::HIP) == 0) {
     *out = new OrtMemoryInfo(
-        name1, type,
+        onnxruntime::HIP, type,
         OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, device_id),
         mem_type1);
   } else if (strcmp(name1, onnxruntime::WEBGPU_BUFFER) == 0 ||
@@ -212,38 +223,39 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA
   } else if (strcmp(name1, onnxruntime::DML) == 0) {
     *out = new OrtMemoryInfo(
-        name1, type,
+        onnxruntime::DML, type,
         OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::MICROSOFT, device_id),
         mem_type1);
   } else if (strcmp(name1, onnxruntime::OpenVINO_RT_NPU) == 0) {
     *out = new OrtMemoryInfo(
-        name1, type,
+        onnxruntime::OpenVINO_RT_NPU, type,
         OrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::INTEL, device_id),
         mem_type1);
   } else if (strcmp(name1, onnxruntime::CUDA_PINNED) == 0) {
     *out = new OrtMemoryInfo(
-        name1, type,
+        onnxruntime::CUDA_PINNED, type,
         OrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA, device_id),
         mem_type1);
   } else if (strcmp(name1, onnxruntime::HIP_PINNED) == 0) {
     *out = new OrtMemoryInfo(
-        name1, type,
+        onnxruntime::HIP_PINNED, type,
         OrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::AMD, device_id),
         mem_type1);
   } else if (strcmp(name1, onnxruntime::QNN_HTP_SHARED) == 0) {
     *out = new OrtMemoryInfo(
-        name1, type,
+        onnxruntime::QNN_HTP_SHARED, type,
         OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::QUALCOMM, device_id),
         mem_type1);
   } else if (strcmp(name1, onnxruntime::CPU_ALIGNED_4K) == 0) {
     *out = new OrtMemoryInfo(
-        name1, type,
+        onnxruntime::CPU_ALIGNED_4K, type,
         OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, device_id,
                   onnxruntime::kAlloc4KAlignment),
         mem_type1);
   } else {
     return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Specified device is not supported. Try CreateMemoryInfo_V2.");
   }
+  API_IMPL_END
 
   return nullptr;
 }
@@ -251,6 +263,16 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo_V2, _In_ const char* name, _In_ en
                     _In_ uint32_t vendor_id, _In_ int32_t device_id, _In_ enum OrtDeviceMemoryType mem_type,
                     _In_ size_t alignment, enum OrtAllocatorType type,
                     _Outptr_ OrtMemoryInfo** out) {
+  API_IMPL_BEGIN
+
+  if (name == nullptr) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "MemoryInfo name cannot be null.");
+  }
+
+  if (out == nullptr) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Output memory info cannot be null.");
+  }
+
   // map the public enum values to internal OrtDevice values
   OrtDevice::MemoryType mt = mem_type == OrtDeviceMemoryType_DEFAULT ? OrtDevice::MemType::DEFAULT
                                                                      : OrtDevice::MemType::HOST_ACCESSIBLE;
@@ -275,6 +297,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo_V2, _In_ const char* name, _In_ en
   *out = new OrtMemoryInfo(name, type, OrtDevice{dt, mt, vendor_id, narrow<OrtDevice::DeviceId>(device_id), alignment},
                            mem_type == OrtDeviceMemoryType_DEFAULT ? OrtMemTypeDefault : OrtMemTypeCPU);
+  API_IMPL_END
 
   return nullptr;
 }
@@ -283,7 +306,7 @@ ORT_API(void, OrtApis::ReleaseMemoryInfo, _Frees_ptr_opt_ OrtMemoryInfo* p) { de
 #pragma warning(pop)
 #endif
 ORT_API_STATUS_IMPL(OrtApis::MemoryInfoGetName, _In_ const OrtMemoryInfo* ptr, _Out_ const char** out) {
-  *out = ptr->name;
+  *out = ptr->name.c_str();
   return nullptr;
 }
diff --git a/onnxruntime/core/framework/bfc_arena.cc b/onnxruntime/core/framework/bfc_arena.cc
index e0b50cd04173e..3a5af42d03cdd 100644
--- a/onnxruntime/core/framework/bfc_arena.cc
+++ b/onnxruntime/core/framework/bfc_arena.cc
@@ -13,7 +13,7 @@ BFCArena::BFCArena(std::unique_ptr<IAllocator> resource_allocator,
                    int max_dead_bytes_per_chunk,
                    int initial_growth_chunk_size_bytes,
                    int64_t max_power_of_two_extend_bytes)
-    : IAllocator(OrtMemoryInfo(resource_allocator->Info().name,
+    : IAllocator(OrtMemoryInfo(resource_allocator->Info().name.c_str(),
                                OrtAllocatorType::OrtArenaAllocator,
                                resource_allocator->Info().device,
                                resource_allocator->Info().mem_type)),
diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc
index 39b785c327d56..9c40eb75780ee 100644
--- a/onnxruntime/core/session/environment.cc
+++ b/onnxruntime/core/session/environment.cc
@@ -79,7 +79,7 @@ static bool AreOrtMemoryInfosEquivalent(
     bool ignore_alignment = false) {
   return left.mem_type == right.mem_type &&
          (ignore_alignment ? left.device.EqualIgnoringAlignment(right.device) : left.device == right.device) &&
-         (!match_name || strcmp(left.name, right.name) == 0);
+         (!match_name || left.name == right.name);
 }
 
 std::vector<AllocatorPtr>::const_iterator FindExistingAllocator(const std::vector<AllocatorPtr>& allocators,
diff --git a/onnxruntime/core/session/lora_adapters.cc b/onnxruntime/core/session/lora_adapters.cc
index 85ea958981e2c..d99ebaf1c7efb 100644
--- a/onnxruntime/core/session/lora_adapters.cc
+++ b/onnxruntime/core/session/lora_adapters.cc
@@ -53,11 +53,11 @@ void LoraAdapter::MemoryMap(const std::filesystem::path& file_path) {
 
 static std::unique_ptr<IDataTransfer> GetDataTransfer(const OrtMemoryInfo& mem_info) {
   std::unique_ptr<IDataTransfer> data_transfer;
 
-  if (strcmp(mem_info.name, onnxruntime::CPU) == 0) {
+  if (mem_info.name == onnxruntime::CPU) {
     return data_transfer;
   }
 
-  if (strcmp(mem_info.name, onnxruntime::CUDA) == 0) {
+  if (mem_info.name == onnxruntime::CUDA) {
 #if defined(USE_CUDA) || defined(USE_CUDA_PROVIDER_INTERFACE)
     auto* cuda_provider_info = TryGetProviderInfo_CUDA();
     if (cuda_provider_info != nullptr) {
diff --git a/onnxruntime/test/framework/TestAllocatorManager.cc b/onnxruntime/test/framework/TestAllocatorManager.cc
index 30f2686cd62f5..6440a805cdc59 100644
--- a/onnxruntime/test/framework/TestAllocatorManager.cc
+++ b/onnxruntime/test/framework/TestAllocatorManager.cc
@@ -10,7 +10,7 @@ namespace test {
 class DummyArena : public IAllocator {
  public:
   explicit DummyArena(std::unique_ptr<IAllocator> resource_allocator)
-      : IAllocator(OrtMemoryInfo(resource_allocator->Info().name,
+      : IAllocator(OrtMemoryInfo(resource_allocator->Info().name.c_str(),
                                  OrtAllocatorType::OrtDeviceAllocator,
                                  resource_allocator->Info().device,
                                  resource_allocator->Info().mem_type)),
diff --git a/onnxruntime/test/framework/allocator_test.cc b/onnxruntime/test/framework/allocator_test.cc
index 3efba6f1b6e52..445e023746aaa 100644
--- a/onnxruntime/test/framework/allocator_test.cc
+++ b/onnxruntime/test/framework/allocator_test.cc
@@ -13,7 +13,7 @@ namespace test {
 TEST(AllocatorTest, CPUAllocatorTest) {
   auto cpu_arena = TestCPUExecutionProvider()->CreatePreferredAllocators()[0];
 
-  ASSERT_STREQ(cpu_arena->Info().name, CPU);
+  ASSERT_STREQ(cpu_arena->Info().name.c_str(), CPU);
   EXPECT_EQ(cpu_arena->Info().device.Id(), 0);
 
   const auto expected_allocator_type = DoesCpuAllocatorSupportArenaUsage()
diff --git a/onnxruntime/test/framework/tensor_test.cc b/onnxruntime/test/framework/tensor_test.cc
index 2ac1a93013932..f08675271de21 100644
--- a/onnxruntime/test/framework/tensor_test.cc
+++ b/onnxruntime/test/framework/tensor_test.cc
@@ -29,7 +29,7 @@ void CPUTensorTest(std::vector<int64_t> dims, const int offset_elements = 0) {
   EXPECT_EQ(shape.GetDims(), tensor_shape.GetDims());
   EXPECT_EQ(t.DataType(), DataTypeImpl::GetType<T>());
   auto& location = t.Location();
-  EXPECT_STREQ(location.name, CPU);
+  EXPECT_STREQ(location.name.c_str(), CPU);
   EXPECT_EQ(location.device.Id(), 0);
 
   const T* t_data = t.Data<T>();
@@ -47,7 +47,7 @@ void CPUTensorTest(std::vector<int64_t> dims, const int offset_elements = 0) {
   EXPECT_EQ(shape.GetDims(), tensor_shape.GetDims());
   EXPECT_EQ(new_t.DataType(), DataTypeImpl::GetType<T>());
   auto& new_location = new_t.Location();
-  ASSERT_STREQ(new_location.name, CPU);
+  ASSERT_STREQ(new_location.name.c_str(), CPU);
   EXPECT_EQ(new_location.device.Id(), 0);
 }
 }
@@ -135,7 +135,7 @@ TEST(TensorTest, EmptyTensorTest) {
   EXPECT_TRUE(!data);
 
   auto& location = t.Location();
-  ASSERT_STREQ(location.name, CPU);
+  ASSERT_STREQ(location.name.c_str(), CPU);
   EXPECT_EQ(location.device.Id(), 0);
 
   const auto expected_allocator_type = DoesCpuAllocatorSupportArenaUsage()
@@ -160,7 +160,7 @@ TEST(TensorTest, StringTensorTest) {
   EXPECT_EQ(shape, tensor_shape);
   EXPECT_EQ(t.DataType(), DataTypeImpl::GetType<std::string>());
   auto& location = t.Location();
-  ASSERT_STREQ(location.name, CPU);
+  ASSERT_EQ(location.name, CPU);
   EXPECT_EQ(location.device.Id(), 0);
 
   std::string* new_data = t.MutableData<std::string>();

From 9d8e5ce1cf114dbb145be10e1ee7f6cbb3574bb6 Mon Sep 17 00:00:00 2001
From: Dmitri Smirnov
Date: Thu, 4 Sep 2025 16:31:51 -0700
Subject: [PATCH 3/9] Implement OrtMemoryInfo interfaces

---
 onnxruntime/__init__.py                       |  1 +
 .../python/onnxruntime_pybind_state.cc        | 50 ++++++++++---------
 .../test/python/onnxruntime_test_python.py    | 23 +++++++--
 3 files changed, 48 insertions(+), 26 deletions(-)

diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py
index 31267263e2384..91d54ecae8c28 100644
--- a/onnxruntime/__init__.py
+++ b/onnxruntime/__init__.py
@@ -37,6 +37,7 @@
     OrtHardwareDevice,  # noqa: F401
     OrtHardwareDeviceType,  # noqa: F401
     OrtMemoryInfo,  # noqa: F401
+    OrtMemoryInfoDeviceType,  # noqa: F401
     OrtMemType,  # noqa: F401
     OrtSparseFormat,  # noqa: F401
     RunOptions,  # noqa: F401
diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc
index 63f5aa04bbcf3..06a3a738c00db 100644
--- a/onnxruntime/python/onnxruntime_pybind_state.cc
+++ b/onnxruntime/python/onnxruntime_pybind_state.cc
@@ -1787,6 +1787,12 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra
       .value("CPU", OrtMemTypeCPU)
       .value("DEFAULT", OrtMemTypeDefault);
 
+  py::enum_<OrtMemoryInfoDeviceType>(m, "OrtMemoryInfoDeviceType")
+      .value("CPU", OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU)
+      .value("GPU", OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU)
+      .value("NPU", OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_NPU)
+      .value("FPGA", OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_FPGA);
+
   py::enum_<OrtDeviceMemoryType>(m, "OrtDeviceMemoryType")
       .value("DEFAULT", OrtDeviceMemoryType_DEFAULT)
       .value("HOST_ACCESSIBLE", OrtDeviceMemoryType_HOST_ACCESSIBLE);
@@ -1952,30 +1958,28 @@ for model inference.)pbdoc");
.def_readwrite("max_power_of_two_extend_bytes", &OrtArenaCfg::max_power_of_two_extend_bytes); py::class_ ort_memory_info_binding(m, "OrtMemoryInfo"); - ort_memory_info_binding.def(py::init([](const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) { - if (strcmp(name, onnxruntime::CPU) == 0) { - return std::make_unique(onnxruntime::CPU, type, OrtDevice(), mem_type); - } else if (strcmp(name, onnxruntime::CUDA) == 0) { - return std::make_unique( - onnxruntime::CUDA, type, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, - static_cast(id)), - mem_type); - } else if (strcmp(name, onnxruntime::CUDA_PINNED) == 0) { - return std::make_unique( - onnxruntime::CUDA_PINNED, type, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA, - static_cast(id)), - mem_type); - } else { - throw std::runtime_error("Specified device is not supported."); - } })) - .def_property_readonly("name", [](const OrtMemoryInfo* mem_info) -> const char* { return mem_info->name; }) - .def_property_readonly("id", [](const OrtMemoryInfo* mem_info) -> int { return mem_info->device.Id(); }, R"pbdoc(Device Id.)pbdoc") + ort_memory_info_binding.def( + py::init([](const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) { + Ort::MemoryInfo result(name, type, id, mem_type); + return std::unique_ptr(result.release()); + })) + .def_static( + "create_v2", + [](const char* name, OrtMemoryInfoDeviceType device_type, uint32_t vendor_id, + int32_t device_id, OrtDeviceMemoryType device_mem_type, size_t alignment, OrtAllocatorType type) { + Ort::MemoryInfo result(name, device_type, vendor_id, device_id, device_mem_type, alignment, type); + return std::unique_ptr(result.release()); + }, + R"pbdoc(Create an OrtMemoryInfo instance using CreateMemoryInfo_V2())pbdoc") + .def_property_readonly("name", [](const OrtMemoryInfo* mem_info) -> std::string { return mem_info->name; }, R"pbdoc(Arbitrary name supplied by the user)pbdoc") + .def_property_readonly("device_id", [](const OrtMemoryInfo* mem_info) -> int { return mem_info->device.Id(); }, R"pbdoc(Device Id.)pbdoc") .def_property_readonly("mem_type", [](const OrtMemoryInfo* mem_info) -> OrtMemType { return mem_info->mem_type; }, R"pbdoc(OrtMemoryInfo memory type.)pbdoc") - .def_property_readonly("allocator_type", [](const OrtMemoryInfo* mem_info) -> OrtAllocatorType { return mem_info->alloc_type; }) - .def_property_readonly("device_mem_type", [](const OrtMemoryInfo* mem_info) -> int { return mem_info->device.MemType(); }, R"pbdoc(Device memory type.)pbdoc") - .def_property_readonly("vendor_id", [](const OrtMemoryInfo* mem_info) -> uint32_t { return mem_info->device.Vendor(); }); + .def_property_readonly("allocator_type", [](const OrtMemoryInfo* mem_info) -> OrtAllocatorType { return mem_info->alloc_type; }, R"pbdoc(Allocator type)pbdoc") + .def_property_readonly("device_mem_type", [](const OrtMemoryInfo* mem_info) -> OrtDeviceMemoryType { + auto mem_type = mem_info->device.MemType(); + return (mem_type == OrtDevice::MemType::DEFAULT) ? 
+            OrtDeviceMemoryType_DEFAULT: OrtDeviceMemoryType_HOST_ACCESSIBLE ; }, R"pbdoc(Device memory type (Device or Host accessible).)pbdoc")
+      .def_property_readonly("device_vendor_id", [](const OrtMemoryInfo* mem_info) -> uint32_t { return mem_info->device.Vendor(); });
 
   py::class_<PySessionOptions> sess(m, "SessionOptions", R"pbdoc(Configuration information for a session.)pbdoc");
diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py
index 86e5bd2bae70f..6d16430775925 100644
--- a/onnxruntime/test/python/onnxruntime_test_python.py
+++ b/onnxruntime/test/python/onnxruntime_test_python.py
@@ -1599,11 +1599,28 @@ def test_ort_memory_info(self):
             onnxrt.OrtMemType.DEFAULT,
         )
         self.assertEqual(cpu_memory_info.name, "Cpu")
-        self.assertEqual(cpu_memory_info.id, 0)
+        self.assertEqual(cpu_memory_info.device_id, 0)
         self.assertEqual(cpu_memory_info.mem_type, onnxrt.OrtMemType.DEFAULT)
         self.assertEqual(cpu_memory_info.allocator_type, onnxrt.OrtAllocatorType.ORT_ARENA_ALLOCATOR)
-        self.assertEqual(cpu_memory_info.device_mem_type, 0)
-        self.assertEqual(cpu_memory_info.vendor_id, 0)
+        self.assertEqual(cpu_memory_info.device_mem_type, onnxrt.OrtDeviceMemoryType.DEFAULT)
+        self.assertEqual(cpu_memory_info.device_vendor_id, 0)
+
+    def test_ort_memory_info_create_v2(self):
+        cpu_memory_info = onnxrt.OrtMemoryInfo.create_v2(
+            "Test",
+            onnxrt.OrtMemoryInfoDeviceType.CPU,
+            0,  # vendor_id
+            0,  # device_id
+            onnxrt.OrtDeviceMemoryType.DEFAULT,
+            128,  # alignment
+            onnxrt.OrtAllocatorType.ORT_ARENA_ALLOCATOR,
+        )
+        self.assertEqual(cpu_memory_info.name, "Test")
+        self.assertEqual(cpu_memory_info.device_id, 0)
+        self.assertEqual(cpu_memory_info.mem_type, onnxrt.OrtMemType.DEFAULT)
+        self.assertEqual(cpu_memory_info.allocator_type, onnxrt.OrtAllocatorType.ORT_ARENA_ALLOCATOR)
+        self.assertEqual(cpu_memory_info.device_mem_type, onnxrt.OrtDeviceMemoryType.DEFAULT)
+        self.assertEqual(cpu_memory_info.device_vendor_id, 0)
 
     def test_shared_allocator_using_create_and_register_allocator(self):
         # Create and register an arena based allocator

From d58cecf224f736504fe98884f7d5e9ec97a1b2df Mon Sep 17 00:00:00 2001
From: Dmitri Smirnov
Date: Fri, 5 Sep 2025 16:17:45 -0700
Subject: [PATCH 4/9] Passing None to copy_tensors does not work.

Also: AttributeError: 'InferenceSession' object has no attribute
'inputs_meminfo'
---
 .../core/session/onnxruntime_cxx_api.h        | 22 ++++++--
 .../core/session/onnxruntime_cxx_inline.h     |  5 +-
 onnxruntime/__init__.py                       |  2 +
 .../python/onnxruntime_pybind_ortvalue.cc     |  4 +-
 .../python/onnxruntime_pybind_state.cc        | 55 +++++++++++++++++-
 .../test/python/onnxruntime_test_python.py    | 56 ++++++++++++++++---
 .../python/onnxruntime_test_python_autoep.py  |  3 +
 7 files changed, 130 insertions(+), 17 deletions(-)

diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
index af0f5046a3f9f..56036b53d0e58 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -1053,11 +1053,25 @@ using UnownedAllocator = detail::AllocatorImpl<detail::Unowned<OrtAllocator>>;
 
 /** \brief Wrapper around ::OrtSyncStream
  *
  */
-struct SyncStream : detail::Base<OrtSyncStream> {
-  explicit SyncStream(std::nullptr_t) {}              ///< Create an empty SyncStream object, must be assigned a valid one to be used
-  explicit SyncStream(OrtSyncStream* p) : Base{p} {}  ///< Take ownership of a pointer created by C API
-  void* GetHandle() const;                            ///< Wraps SyncStream_GetHandle
+
+namespace detail {
+template <typename T>
+struct SyncStreamImpl : Base<T> {
+  using B = Base<T>;
+  using B::B;
+  // For some reason this is not a const method on the stream
+  void* GetHandle();  ///< Wraps SyncStream_GetHandle
 };
+}  // namespace detail
+
+struct SyncStream : detail::SyncStreamImpl<OrtSyncStream> {
+  ///< Create an empty SyncStream object, must be assigned a valid one to be used
+  explicit SyncStream(std::nullptr_t) {}
+  ///< Take ownership of a pointer created by C API
+  explicit SyncStream(OrtSyncStream* p) : SyncStreamImpl{p} {}
+};
+
+using UnownedSyncStream = detail::SyncStreamImpl<detail::Unowned<OrtSyncStream>>;
 
 namespace detail {
 template <typename T>
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
index 61a261b046693..d01d370ffc230 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -651,9 +651,12 @@ inline void KeyValuePairs::Remove(const char* key) {
   GetApi().RemoveKeyValuePair(this->p_, key);
 }
 
-inline void* SyncStream::GetHandle() const {
+namespace detail {
+template <typename T>
+inline void* SyncStreamImpl<T>::GetHandle() {
   return GetApi().SyncStream_GetHandle(this->p_);
 }
+}  // namespace detail
 
 namespace detail {
 template <typename T>
diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py
index 91d54ecae8c28..c7ac2d5cedf7b 100644
--- a/onnxruntime/__init__.py
+++ b/onnxruntime/__init__.py
@@ -40,9 +40,11 @@
     OrtMemoryInfoDeviceType,  # noqa: F401
     OrtMemType,  # noqa: F401
     OrtSparseFormat,  # noqa: F401
+    OrtSyncStream,  # noqa: F401
     RunOptions,  # noqa: F401
     SessionIOBinding,  # noqa: F401
     SessionOptions,  # noqa: F401
+    copy_tensors,  # noqa: F401
     create_and_register_allocator,  # noqa: F401
     create_and_register_allocator_v2,  # noqa: F401
     disable_telemetry_events,  # noqa: F401
diff --git a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc
index 1fe7ab0884f9c..d74663ddb63d7 100644
--- a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc
+++ b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc
@@ -333,7 +333,7 @@ void addOrtValueMethods(pybind11::module& m) {
       })
 #endif
       // Get a pointer to Tensor data
-      .def("data_ptr", [](OrtValue* ml_value) -> int64_t {
+      .def("data_ptr", [](OrtValue* ml_value) -> uintptr_t {
        // TODO: Assumes that the OrtValue is a Tensor, make this generic to handle non-Tensors
         ORT_ENFORCE(ml_value->IsTensor(), "Only OrtValues that are Tensors are currently supported");
 
@@ -344,7 +344,7 @@ void addOrtValueMethods(pybind11::module& m) {
         }
 
         // Should cover x86 and x64 platforms
-        return reinterpret_cast<int64_t>(tensor->MutableDataRaw());
+        return reinterpret_cast<uintptr_t>(tensor->MutableDataRaw());
       })
       .def("device_name", [](const OrtValue* ort_value) -> std::string {
         if (ort_value->IsTensor()) {
diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc
index 06a3a738c00db..cf2e5b4d1edeb 100644
--- a/onnxruntime/python/onnxruntime_pybind_state.cc
+++ b/onnxruntime/python/onnxruntime_pybind_state.cc
@@ -25,6 +25,7 @@
 #include "core/framework/data_transfer_utils.h"
 #include "core/framework/data_types_internal.h"
 #include "core/framework/error_code_helper.h"
+#include "core/framework/plugin_ep_stream.h"
 #include "core/framework/provider_options_utils.h"
 #include "core/framework/random_seed.h"
 #include "core/framework/sparse_tensor.h"
@@ -1586,6 +1587,17 @@ void addGlobalMethods(py::module& m) {
       },
       R"pbdoc("Validate a compiled model's compatibility information for one or more EP devices.)pbdoc");
 
+  m.def(
+      "copy_tensors",
+      [](const std::vector<const OrtValue*>& src, const std::vector<OrtValue*>& dest, OrtSyncStream* stream) {
+        const OrtEnv* ort_env = GetOrtEnv();
+        auto status = Ort::Status(Ort::GetApi().CopyTensors(ort_env, src.data(), dest.data(), stream, src.size()));
+        if (!status.IsOK()) {
+          throw Ort::Exception(status.GetErrorMessage(), status.GetErrorCode());
+        }
+      },
+      R"pbdoc("Copy tensors from sources to destinations using specified stream handle)pbdoc");
+
 #if defined(USE_OPENVINO) || defined(USE_OPENVINO_PROVIDER_INTERFACE)
   m.def(
       "get_available_openvino_device_ids", []() -> std::vector<std::string> {
@@ -1876,6 +1888,9 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra
       },
       R"pbdoc(Hardware device's metadata as string key/value pairs.)pbdoc");
 
+  py::class_<OrtSyncStream> py_sync_stream(m, "OrtSyncStream",
+                                           R"pbdoc(Represents a synchronization stream for model inference.)pbdoc");
+
   py::class_<OrtEpDevice> py_ep_device(m, "OrtEpDevice",
                                        R"pbdoc(Represents a hardware device that an execution provider supports
 for model inference.)pbdoc");
@@ -1909,9 +1924,19 @@ for model inference.)pbdoc");
       .def(
           "memory_info",
           [](const OrtEpDevice* ep_device, OrtDeviceMemoryType memory_type) -> const OrtMemoryInfo* {
-            return Ort::GetApi().EpDevice_MemoryInfo(ep_device, memory_type);
+            Ort::ConstEpDevice ep_dev(ep_device);
+            return static_cast<const OrtMemoryInfo*>(ep_dev.GetMemoryInfo(memory_type));
           },
           R"pbdoc(The OrtMemoryInfo instance for the OrtEpDevice specific to the device memory type.)pbdoc",
+          py::return_value_policy::reference_internal)
+      .def(
+          "create_sync_stream",
+          [](const OrtEpDevice* ep_device) -> std::unique_ptr<OrtSyncStream> {
+            Ort::ConstEpDevice ep_dev(ep_device);
+            Ort::SyncStream stream = ep_dev.CreateSyncStream();
+            return std::unique_ptr<OrtSyncStream>(stream.release());
+          },
+          R"pbdoc(The OrtSyncStream instance for the OrtEpDevice.)pbdoc",
           py::return_value_policy::reference_internal);
 
   py::class_<OrtArenaCfg> ort_arena_cfg_binding(m, "OrtArenaCfg");
@@ -2676,6 +2701,34 @@ including arg name, arg type (contains both type and shape).)pbdoc")
             auto res = sess->GetSessionHandle()->GetModelMetadata();
             OrtPybindThrowIfError(res.first);
             return *(res.second); }, py::return_value_policy::reference_internal)
+      .def_property_readonly("inputs_meminfo", [](const PyInferenceSession* sess) -> std::vector<const OrtMemoryInfo*> {
+            Ort::ConstSession session(reinterpret_cast<const OrtSession*>(sess->GetSessionHandle()));
+            auto inputs_mem_info = session.GetMemoryInfoForInputs();
+            std::vector<const OrtMemoryInfo*> result;
+            result.reserve(inputs_mem_info.size());
+            for (const auto& info : inputs_mem_info) {
+              result.push_back(info);
+            }
+            return result; }, py::return_value_policy::reference_internal)
+      .def_property_readonly("outputs_meminfo", [](const PyInferenceSession* sess) -> std::vector<const OrtMemoryInfo*> {
+            Ort::ConstSession session(reinterpret_cast<const OrtSession*>(sess->GetSessionHandle()));
+            auto outputs_mem_info = session.GetMemoryInfoForOutputs();
+            std::vector<const OrtMemoryInfo*> result;
+            result.reserve(outputs_mem_info.size());
+            for (const auto& info : outputs_mem_info) {
+              result.push_back(info);
+            }
+            return result; }, py::return_value_policy::reference_internal)
+      .def_property_readonly("inputs_epdevices", [](const PyInferenceSession* sess) -> std::vector<const OrtEpDevice*> {
+            Ort::ConstSession session(reinterpret_cast<const OrtSession*>(sess->GetSessionHandle()));
+            auto ep_devices = session.GetEpDeviceForInputs();
+            std::vector<const OrtEpDevice*> result;
+            result.reserve(ep_devices.size());
+            for (const auto& device : ep_devices) {
+              result.push_back(device);
+            }
+            return result; }, py::return_value_policy::reference_internal)
+
       .def("run_with_iobinding", [](PyInferenceSession* sess, SessionIOBinding& io_binding, RunOptions* run_options = nullptr) -> void {
         Status status;
diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py
index 6d16430775925..6c1f927fc3d26 100644
--- a/onnxruntime/test/python/onnxruntime_test_python.py
+++ b/onnxruntime/test/python/onnxruntime_test_python.py
@@ -252,6 +252,29 @@ def test_get_providers(self):
         sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers())
         self.assertTrue("CPUExecutionProvider" in sess.get_providers())
 
+    def test_copy_tensors(self):
+        # Generate 2 numpy arrays
+        a = np.random.rand(3, 2).astype(np.float32)
+        b = np.random.rand(3, 2).astype(np.float32)
+
+        # Create OrtValue from numpy arrays
+        a_ort = onnxrt.OrtValue.ortvalue_from_numpy(a)
+        b_ort = onnxrt.OrtValue.ortvalue_from_numpy(b)
+
+        # Create destination ort values with the same shape
+        a_ort_copy = onnxrt.OrtValue.ortvalue_from_shape_and_type(a.shape, a.dtype)
+        b_ort_copy = onnxrt.OrtValue.ortvalue_from_shape_and_type(b.shape, b.dtype)
+
+        # source list
+        src_list = [a_ort, b_ort]
+        dst_list = [a_ort_copy, b_ort_copy]
+        # Passing None for stream as we copy between CPU
+        onnxrt.copy_tensors(src_list, dst_list, None)
+
+        # Verify the contents
+        np.testing.assert_array_equal(a, a_ort_copy.numpy())
+        np.testing.assert_array_equal(b, b_ort_copy.numpy())
+
     def test_enabling_and_disabling_telemetry(self):
         onnxrt.disable_telemetry_events()
@@ -689,15 +712,30 @@ def test_run_model_with_optional_sequence_input(self):
     def test_run_model(self):
         sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=available_providers)
         x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)
-        input_name = sess.get_inputs()[0].name
-        self.assertEqual(input_name, "X")
-        input_shape = sess.get_inputs()[0].shape
-        self.assertEqual(input_shape, [3, 2])
-        output_name = sess.get_outputs()[0].name
-        self.assertEqual(output_name, "Y")
-        output_shape = sess.get_outputs()[0].shape
-        self.assertEqual(output_shape, [3, 2])
-        res = sess.run([output_name], {input_name: x})
+
+        inputs = sess.get_inputs()
+        self.assertEqual(len(inputs), 1)
+        self.assertEqual(inputs[0].name, "X")
+        self.assertEqual(inputs[0].shape, [3, 2])
+
+        inputs_meminfo = sess.inputs_meminfo
+        self.assertEqual(len(inputs_meminfo), 1)
+        self.assertIsNotNone(inputs_meminfo[0])
+
+        inputs_epdevices = sess.inputs_epdevices
+        self.assertEqual(len(inputs_epdevices), 1)
+        self.assertIsNotNone(inputs_epdevices[0])
+
+        outputs_meminfo = sess.outputs_meminfo
+        self.assertEqual(len(outputs_meminfo), 1)
+        self.assertIsNotNone(outputs_meminfo[0])
+
+        outputs = sess.get_outputs()
+        self.assertEqual(len(outputs), 1)
+        self.assertEqual(outputs[0].name, "Y")
+        self.assertEqual(outputs[0].shape, [3, 2])
+
+        res = sess.run([outputs[0].name], {inputs[0].name: x})
         output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32)
         np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08)
diff --git a/onnxruntime/test/python/onnxruntime_test_python_autoep.py b/onnxruntime/test/python/onnxruntime_test_python_autoep.py
index 49ff7b488bdbd..3ff9a4c551119 100644
--- a/onnxruntime/test/python/onnxruntime_test_python_autoep.py
+++ b/onnxruntime/test/python/onnxruntime_test_python_autoep.py
@@ -229,6 +229,9 @@ def test_example_plugin_ep_devices(self):
         test_mem_info = test_ep_device.memory_info(onnxrt.OrtDeviceMemoryType.DEFAULT)
         self.assertIsNotNone(test_mem_info)
 
+        test_sync_stream = test_ep_device.create_sync_stream()
+        self.assertIsNotNone(test_sync_stream)
+
         # Add EP plugin's OrtEpDevice to the SessionOptions.
         sess_options = onnxrt.SessionOptions()
         sess_options.add_provider_for_devices([test_ep_device], {"opt1": "val1"})

From 37d3ee50cce785a0ab6bd8a6b9be4ad7c993e015 Mon Sep 17 00:00:00 2001
From: Dmitri Smirnov
Date: Mon, 8 Sep 2025 15:59:44 -0700
Subject: [PATCH 5/9] Address a bug where the number of meminfos requested was
 always for inputs

---
 onnxruntime/core/session/onnxruntime_c_api.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
index f3e2a8ce7ba7b..33e848fec8522 100644
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
@@ -3654,7 +3654,7 @@ OrtStatus* GetInputOutputMemoryInfo(const OrtSession* ort_session,
 
   InlinedVector<const OrtMemoryInfo*> mem_info;
   ORT_API_RETURN_IF_STATUS_NOT_OK(
-      session->GetInputOutputMemoryInfo(InferenceSession::SessionInputOutputType::kInput, mem_info));
+      session->GetInputOutputMemoryInfo(type, mem_info));
 
   auto num_found = mem_info.size();
   if (num_found > num_values) {

From 3d416f9ef54ed5c20e110b188f24e5eecdc95a60 Mon Sep 17 00:00:00 2001
From: Dmitri Smirnov
Date: Mon, 8 Sep 2025 18:15:16 -0700
Subject: [PATCH 6/9] Two issues: copy_tensors fails since there is no data
 transfer to copy from CPU to CPU, and lintrunner complains OrtSyncStream is
 undefined.
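
The Python-level call this enables, as a sketch (it mirrors the CPU-to-CPU
test added later in this series; stream=None is the case being fixed):

    import numpy as np
    import onnxruntime as onnxrt

    a = np.ones((3, 2), np.float32)
    src = onnxrt.OrtValue.ortvalue_from_numpy(a)
    dst = onnxrt.OrtValue.ortvalue_from_shape_and_type(a.shape, a.dtype)
    # CPU-to-CPU copy: no sync stream is needed, so None must be accepted
    onnxrt.copy_tensors([src], [dst], None)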
---
 onnxruntime/__init__.py                       |  2 +-
 .../onnxruntime_inference_collection.py       | 28 +++++++++++++++++++
 .../python/onnxruntime_pybind_state.cc        | 14 +++++-----
 .../test/python/onnxruntime_test_python.py    | 13 +++++----
 .../python/onnxruntime_test_python_autoep.py  |  2 ++
 5 files changed, 45 insertions(+), 14 deletions(-)

diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py
index c7ac2d5cedf7b..00831f8865df7 100644
--- a/onnxruntime/__init__.py
+++ b/onnxruntime/__init__.py
@@ -44,7 +44,6 @@
     RunOptions,  # noqa: F401
     SessionIOBinding,  # noqa: F401
     SessionOptions,  # noqa: F401
-    copy_tensors,  # noqa: F401
     create_and_register_allocator,  # noqa: F401
     create_and_register_allocator_v2,  # noqa: F401
     disable_telemetry_events,  # noqa: F401
@@ -81,6 +80,7 @@
     OrtDevice,  # noqa: F401
     OrtValue,  # noqa: F401
     SparseTensor,  # noqa: F401
+    copy_tensors,  # noqa: F401
 )

 # TODO: thiagofc: Temporary experimental namespace for new PyTorch front-end
diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py
index 8e380a40ceeeb..ca883f59f071b 100644
--- a/onnxruntime/python/onnxruntime_inference_collection.py
+++ b/onnxruntime/python/onnxruntime_inference_collection.py
@@ -199,6 +199,18 @@ def get_modelmeta(self) -> onnxruntime.ModelMetadata:
         "Return the metadata. See :class:`onnxruntime.ModelMetadata`."
         return self._model_meta
 
+    def get_inputs_memory_info(self) -> Sequence[onnxruntime.OrtMemoryInfo]:
+        "Return the memory info for the inputs."
+        return self._inputs_meminfo
+
+    def get_outputs_memory_info(self) -> Sequence[onnxruntime.OrtMemoryInfo]:
+        "Return the memory info for the outputs."
+        return self._outputs_meminfo
+
+    def get_inputs_epdevices(self) -> Sequence[onnxruntime.OrtEpDevice]:
+        "Return the execution providers for the inputs."
+        return self._inputs_epdevices
+
     def get_providers(self) -> Sequence[str]:
         "Return list of registered execution providers."
         return self._providers
@@ -576,6 +588,9 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi
         self._inputs_meta = self._sess.inputs_meta
         self._outputs_meta = self._sess.outputs_meta
         self._overridable_initializers = self._sess.overridable_initializers
+        self._inputs_meminfo = self._sess.inputs_meminfo
+        self._outputs_meminfo = self._sess.outputs_meminfo
+        self._inputs_epdevices = self._sess.inputs_epdevices
         self._model_meta = self._sess.model_meta
         self._providers = self._sess.get_providers()
         self._provider_options = self._sess.get_provider_options()
@@ -589,6 +604,9 @@ def _reset_session(self, providers, provider_options) -> None:
         self._inputs_meta = None
         self._outputs_meta = None
         self._overridable_initializers = None
+        self._inputs_meminfo = None
+        self._outputs_meminfo = None
+        self._inputs_epdevices = None
         self._model_meta = None
         self._providers = None
         self._provider_options = None
@@ -1062,6 +1080,15 @@ def update_inplace(self, np_arr) -> None:
         self._ortvalue.update_inplace(np_arr)
 
 
+def copy_tensors(src: Sequence[OrtValue], dst: Sequence[OrtValue], stream: OrtSyncStream = None) -> None:
+    """
+    Copy tensor data from source OrtValue sequence to destination OrtValue sequence.
+ """ + c_sources = [s._get_c_value() for s in src] + c_dsts = [d._get_c_value() for d in dst] + C.copy_tensors(c_sources, c_dsts, stream) + + class OrtDevice: """ A data structure that exposes the underlying C++ OrtDevice @@ -1074,6 +1101,7 @@ def __init__(self, c_ort_device): if isinstance(c_ort_device, C.OrtDevice): self._ort_device = c_ort_device else: + # An end user won't hit this error raise ValueError( "`Provided object` needs to be of type `onnxruntime.capi.onnxruntime_pybind11_state.OrtDevice`" ) diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index cf2e5b4d1edeb..8da0f8fedd49f 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1589,14 +1589,15 @@ void addGlobalMethods(py::module& m) { m.def( "copy_tensors", - [](const std::vector& src, const std::vector& dest, OrtSyncStream* stream) { + [](const std::vector& src, const std::vector& dest, py::object& py_arg) { const OrtEnv* ort_env = GetOrtEnv(); - auto status = Ort::Status(Ort::GetApi().CopyTensors(ort_env, src.data(), dest.data(), stream, src.size())); - if (!status.IsOK()) { - throw Ort::Exception(status.GetErrorMessage(), status.GetErrorCode()); + OrtSyncStream* stream = nullptr; + if (!py_arg.is_none()) { + stream = py_arg.cast(); } + Ort::ThrowOnError(Ort::GetApi().CopyTensors(ort_env, src.data(), dest.data(), stream, src.size())); }, - R"pbdoc("Copy tensors from sources to destinations using specified stream handle)pbdoc"); + R"pbdoc("Copy tensors from sources to destinations using specified stream handle (or None))pbdoc"); #if defined(USE_OPENVINO) || defined(USE_OPENVINO_PROVIDER_INTERFACE) m.def( @@ -1936,8 +1937,7 @@ for model inference.)pbdoc"); Ort::SyncStream stream = ep_dev.CreateSyncStream(); return std::unique_ptr(stream.release()); }, - R"pbdoc(The OrtSyncStream instance for the OrtEpDevice.)pbdoc", - py::return_value_policy::reference_internal); + R"pbdoc(The OrtSyncStream instance for the OrtEpDevice.)pbdoc"); py::class_ ort_arena_cfg_binding(m, "OrtArenaCfg"); // Note: Doesn't expose initial_growth_chunk_sizes_bytes/max_power_of_two_extend_bytes option. diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index 6c1f927fc3d26..0f7d86b736a17 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -269,6 +269,7 @@ def test_copy_tensors(self): src_list = [a_ort, b_ort] dst_list = [a_ort_copy, b_ort_copy] # Passing None for stream as we copy between CPU + # Test None as allowed. 
         onnxrt.copy_tensors(src_list, dst_list, None)
 
         # Verify the contents
@@ -718,23 +719,23 @@ def test_run_model(self):
         self.assertEqual(inputs[0].name, "X")
         self.assertEqual(inputs[0].shape, [3, 2])
 
-        inputs_meminfo = sess.inputs_meminfo
+        inputs_meminfo = sess.get_inputs_memory_info()
         self.assertEqual(len(inputs_meminfo), 1)
         self.assertIsNotNone(inputs_meminfo[0])
 
-        inputs_epdevices = sess.inputs_epdevices
+        inputs_epdevices = sess.get_inputs_epdevices()
         self.assertEqual(len(inputs_epdevices), 1)
         self.assertIsNotNone(inputs_epdevices[0])
 
-        outputs_meminfo = sess.outputs_meminfo
-        self.assertEqual(len(outputs_meminfo), 1)
-        self.assertIsNotNone(outputs_meminfo[0])
-
         outputs = sess.get_outputs()
         self.assertEqual(len(outputs), 1)
         self.assertEqual(outputs[0].name, "Y")
         self.assertEqual(outputs[0].shape, [3, 2])
 
+        outputs_meminfo = sess.get_outputs_memory_info()
+        self.assertEqual(len(outputs_meminfo), 1)
+        self.assertIsNotNone(outputs_meminfo[0])
+
         res = sess.run([outputs[0].name], {inputs[0].name: x})
         output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32)
         np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08)
diff --git a/onnxruntime/test/python/onnxruntime_test_python_autoep.py b/onnxruntime/test/python/onnxruntime_test_python_autoep.py
index 3ff9a4c551119..ad7ebd80ac279 100644
--- a/onnxruntime/test/python/onnxruntime_test_python_autoep.py
+++ b/onnxruntime/test/python/onnxruntime_test_python_autoep.py
@@ -228,9 +228,11 @@ def test_example_plugin_ep_devices(self):
 
         test_mem_info = test_ep_device.memory_info(onnxrt.OrtDeviceMemoryType.DEFAULT)
         self.assertIsNotNone(test_mem_info)
+        del test_mem_info
 
         test_sync_stream = test_ep_device.create_sync_stream()
         self.assertIsNotNone(test_sync_stream)
+        del test_sync_stream
 
         # Add EP plugin's OrtEpDevice to the SessionOptions.
         sess_options = onnxrt.SessionOptions()

From f061e7526146b8d5de70c56a8972a69e1f5245c3 Mon Sep 17 00:00:00 2001
From: Dmitri Smirnov
Date: Tue, 9 Sep 2025 12:10:31 -0700
Subject: [PATCH 7/9] Test copy_tensors

---
 .../onnxruntime_inference_collection.py       |  2 +-
 .../test/python/onnxruntime_test_python.py    | 24 ---------
 .../python/onnxruntime_test_python_autoep.py  | 49 +++++++++++++++++++
 3 files changed, 50 insertions(+), 25 deletions(-)

diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py
index ca883f59f071b..df00d42cf3d1f 100644
--- a/onnxruntime/python/onnxruntime_inference_collection.py
+++ b/onnxruntime/python/onnxruntime_inference_collection.py
@@ -1080,7 +1080,7 @@ def update_inplace(self, np_arr) -> None:
         self._ortvalue.update_inplace(np_arr)
 
 
-def copy_tensors(src: Sequence[OrtValue], dst: Sequence[OrtValue], stream: OrtSyncStream = None) -> None:
+def copy_tensors(src: Sequence[OrtValue], dst: Sequence[OrtValue], stream=None) -> None:
     """
     Copy tensor data from source OrtValue sequence to destination OrtValue sequence.
""" diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index 0f7d86b736a17..30a84b22253a5 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -252,30 +252,6 @@ def test_get_providers(self): sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers()) self.assertTrue("CPUExecutionProvider" in sess.get_providers()) - def test_copy_tensors(self): - # Generate 2 numpy arrays - a = np.random.rand(3, 2).astype(np.float32) - b = np.random.rand(3, 2).astype(np.float32) - - # Create OrtValue from numpy arrays - a_ort = onnxrt.OrtValue.ortvalue_from_numpy(a) - b_ort = onnxrt.OrtValue.ortvalue_from_numpy(b) - - # Create destination ort values with the same shape - a_ort_copy = onnxrt.OrtValue.ortvalue_from_shape_and_type(a.shape, a.dtype) - b_ort_copy = onnxrt.OrtValue.ortvalue_from_shape_and_type(b.shape, b.dtype) - - # source list - src_list = [a_ort, b_ort] - dst_list = [a_ort_copy, b_ort_copy] - # Passing None for stream as we copy between CPU - # Test None as allowed. - onnxrt.copy_tensors(src_list, dst_list, None) - - # Verify the contents - np.testing.assert_array_equal(a, a_ort_copy.numpy()) - np.testing.assert_array_equal(b, b_ort_copy.numpy()) - def test_enabling_and_disabling_telemetry(self): onnxrt.disable_telemetry_events() diff --git a/onnxruntime/test/python/onnxruntime_test_python_autoep.py b/onnxruntime/test/python/onnxruntime_test_python_autoep.py index ad7ebd80ac279..ae0e5aedafa48 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_autoep.py +++ b/onnxruntime/test/python/onnxruntime_test_python_autoep.py @@ -290,6 +290,55 @@ def test_example_plugin_ep_data_transfer(self): self.unregister_execution_provider_library(ep_name) + def test_copy_tensors(self): + """ + Test global api copy_tensors between OrtValue objects + using EP plug-in data transfoer + """ + if sys.platform != "win32": + self.skipTest("Skipping test because device discovery is only supported on Windows") + + ep_lib_path = "example_plugin_ep.dll" + try: + ep_lib_path = get_name("example_plugin_ep.dll") + except FileNotFoundError: + self.skipTest(f"Skipping test because EP library '{ep_lib_path}' cannot be found") + + ep_name = "example_ep" + self.register_execution_provider_library(ep_name, os.path.realpath(ep_lib_path)) + + # Generate 2 numpy arrays + a = np.random.rand(3, 2).astype(np.float32) + b = np.random.rand(3, 2).astype(np.float32) + + # Create OrtValue from numpy arrays on EP device + # the example EP pretends to use GPU memory, so we place it there + a_device = onnxrt.OrtValue.ortvalue_from_numpy(a, "gpu", 0, 0xBE57) + b_device = onnxrt.OrtValue.ortvalue_from_numpy(b, "gpu", 0, 0xBE57) + + # Create destination ort values with the same shape on CPU + a_cpu_copy = onnxrt.OrtValue.ortvalue_from_shape_and_type(a.shape, a.dtype) + b_cpu_copy = onnxrt.OrtValue.ortvalue_from_shape_and_type(b.shape, b.dtype) + + # source list + src_list = [a_device, b_device] + dst_list = [a_cpu_copy, b_cpu_copy] + # Passing None for stream as we copy between CPU + # Test None because it is allowed + onnxrt.copy_tensors(src_list, dst_list, None) + + # Release the OrtValue on the EP device + # before the EP library is unregistered + del src_list + del a_device + del b_device + + # Verify the contents + np.testing.assert_array_equal(a, a_cpu_copy.numpy()) + np.testing.assert_array_equal(b, b_cpu_copy.numpy()) + + 
+        self.unregister_execution_provider_library(ep_name)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=1)

From 0660951dc9942384d4f505376c7194b4b042919c Mon Sep 17 00:00:00 2001
From: Dmitri Smirnov
Date: Thu, 11 Sep 2025 13:58:36 -0700
Subject: [PATCH 8/9] Address co-pilot comments

---
 onnxruntime/python/onnxruntime_pybind_state.cc            | 2 +-
 onnxruntime/test/python/onnxruntime_test_python_autoep.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc
index 67a1fcd564e07..47eb676e76a0e 100644
--- a/onnxruntime/python/onnxruntime_pybind_state.cc
+++ b/onnxruntime/python/onnxruntime_pybind_state.cc
@@ -1999,7 +1999,7 @@ for model inference.)pbdoc");
       .def_property_readonly("mem_type", [](const OrtMemoryInfo* mem_info) -> OrtMemType { return mem_info->mem_type; }, R"pbdoc(OrtMemoryInfo memory type.)pbdoc")
       .def_property_readonly("allocator_type", [](const OrtMemoryInfo* mem_info) -> OrtAllocatorType { return mem_info->alloc_type; }, R"pbdoc(Allocator type)pbdoc")
       .def_property_readonly("device_mem_type", [](const OrtMemoryInfo* mem_info) -> OrtDeviceMemoryType {
-          auto mem_type = mem_info->device.MemType();
+        auto mem_type = mem_info->device.MemType();
         return (mem_type == OrtDevice::MemType::DEFAULT) ?
             OrtDeviceMemoryType_DEFAULT: OrtDeviceMemoryType_HOST_ACCESSIBLE ; }, R"pbdoc(Device memory type (Device or Host accessible).)pbdoc")
       .def_property_readonly("device_vendor_id", [](const OrtMemoryInfo* mem_info) -> uint32_t { return mem_info->device.Vendor(); });
diff --git a/onnxruntime/test/python/onnxruntime_test_python_autoep.py b/onnxruntime/test/python/onnxruntime_test_python_autoep.py
index ae0e5aedafa48..d6281d165c053 100644
--- a/onnxruntime/test/python/onnxruntime_test_python_autoep.py
+++ b/onnxruntime/test/python/onnxruntime_test_python_autoep.py
@@ -293,7 +293,7 @@ def test_example_plugin_ep_data_transfer(self):
     def test_copy_tensors(self):
         """
         Test global api copy_tensors between OrtValue objects
-        using EP plug-in data transfoer
+        using EP plug-in data transfer
         """
         if sys.platform != "win32":
             self.skipTest("Skipping test because device discovery is only supported on Windows")

From 5c7e4626918c1aa1e86d09c2a6d3731cc4787a4b Mon Sep 17 00:00:00 2001
From: Dmitri Smirnov
Date: Tue, 16 Sep 2025 14:50:49 -0700
Subject: [PATCH 9/9] Address review comments

---
 .../onnxruntime_inference_collection.py       | 24 +++++++-------
 .../python/onnxruntime_pybind_state.cc        | 31 +++++++++----------
 .../test/python/onnxruntime_test_python.py    | 18 +++++------
 3 files changed, 36 insertions(+), 37 deletions(-)

diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py
index 7846f8d2e78ef..4c3313046457c 100644
--- a/onnxruntime/python/onnxruntime_inference_collection.py
+++ b/onnxruntime/python/onnxruntime_inference_collection.py
@@ -199,17 +199,17 @@ def get_modelmeta(self) -> onnxruntime.ModelMetadata:
         "Return the metadata. See :class:`onnxruntime.ModelMetadata`."
         return self._model_meta
 
-    def get_inputs_memory_info(self) -> Sequence[onnxruntime.OrtMemoryInfo]:
+    def get_input_memory_infos(self) -> Sequence[onnxruntime.OrtMemoryInfo]:
         "Return the memory info for the inputs."
-        return self._inputs_meminfo
+        return self._input_meminfos
 
-    def get_outputs_memory_info(self) -> Sequence[onnxruntime.OrtMemoryInfo]:
+    def get_output_memory_infos(self) -> Sequence[onnxruntime.OrtMemoryInfo]:
         "Return the memory info for the outputs."
-        return self._outputs_meminfo
+        return self._output_meminfos
 
-    def get_inputs_epdevices(self) -> Sequence[onnxruntime.OrtEpDevice]:
+    def get_input_epdevices(self) -> Sequence[onnxruntime.OrtEpDevice]:
         "Return the execution providers for the inputs."
-        return self._inputs_epdevices
+        return self._input_epdevices
 
     def get_providers(self) -> Sequence[str]:
         "Return list of registered execution providers."
         return self._providers
@@ -588,9 +588,9 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi
         self._inputs_meta = self._sess.inputs_meta
         self._outputs_meta = self._sess.outputs_meta
         self._overridable_initializers = self._sess.overridable_initializers
-        self._inputs_meminfo = self._sess.inputs_meminfo
-        self._outputs_meminfo = self._sess.outputs_meminfo
-        self._inputs_epdevices = self._sess.inputs_epdevices
+        self._input_meminfos = self._sess.input_meminfos
+        self._output_meminfos = self._sess.output_meminfos
+        self._input_epdevices = self._sess.input_epdevices
         self._model_meta = self._sess.model_meta
         self._providers = self._sess.get_providers()
         self._provider_options = self._sess.get_provider_options()
@@ -604,9 +604,9 @@ def _reset_session(self, providers, provider_options) -> None:
         self._inputs_meta = None
         self._outputs_meta = None
         self._overridable_initializers = None
-        self._inputs_meminfo = None
-        self._outputs_meminfo = None
-        self._inputs_epdevices = None
+        self._input_meminfos = None
+        self._output_meminfos = None
+        self._input_epdevices = None
         self._model_meta = None
         self._providers = None
         self._provider_options = None
diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc
index b5e9b3fa4ee9e..b444ee5e7eb94 100644
--- a/onnxruntime/python/onnxruntime_pybind_state.cc
+++ b/onnxruntime/python/onnxruntime_pybind_state.cc
@@ -2702,34 +2702,33 @@ including arg name, arg type (contains both type and shape).)pbdoc")
             auto res = sess->GetSessionHandle()->GetModelMetadata();
             OrtPybindThrowIfError(res.first);
             return *(res.second); }, py::return_value_policy::reference_internal)
-      .def_property_readonly("inputs_meminfo", [](const PyInferenceSession* sess) -> std::vector<const OrtMemoryInfo*> {
+      .def_property_readonly("input_meminfos", [](const PyInferenceSession* sess) -> py::list {
             Ort::ConstSession session(reinterpret_cast<const OrtSession*>(sess->GetSessionHandle()));
             auto inputs_mem_info = session.GetMemoryInfoForInputs();
-            std::vector<const OrtMemoryInfo*> result;
-            result.reserve(inputs_mem_info.size());
+            py::list result;
             for (const auto& info : inputs_mem_info) {
-              result.push_back(info);
+              const auto* p_info = static_cast<const OrtMemoryInfo*>(info);
+              result.append(py::cast(p_info, py::return_value_policy::reference));
             }
-            return result; }, py::return_value_policy::reference_internal)
-      .def_property_readonly("outputs_meminfo", [](const PyInferenceSession* sess) -> std::vector<const OrtMemoryInfo*> {
+            return result; })
+      .def_property_readonly("output_meminfos", [](const PyInferenceSession* sess) -> py::list {
             Ort::ConstSession session(reinterpret_cast<const OrtSession*>(sess->GetSessionHandle()));
             auto outputs_mem_info = session.GetMemoryInfoForOutputs();
-            std::vector<const OrtMemoryInfo*> result;
-            result.reserve(outputs_mem_info.size());
+            py::list result;
             for (const auto& info : outputs_mem_info) {
-              result.push_back(info);
+              const auto* p_info = static_cast<const OrtMemoryInfo*>(info);
+              result.append(py::cast(p_info, py::return_value_policy::reference));
             }
-            return result; }, py::return_value_policy::reference_internal)
-      .def_property_readonly("inputs_epdevices", [](const PyInferenceSession* sess) -> std::vector<const OrtEpDevice*> {
+            return result; })
+      .def_property_readonly("input_epdevices",
+                             [](const PyInferenceSession* sess) -> py::list {
             Ort::ConstSession session(reinterpret_cast<const OrtSession*>(sess->GetSessionHandle()));
             auto ep_devices = session.GetEpDeviceForInputs();
-            std::vector<const OrtEpDevice*> result;
-            result.reserve(ep_devices.size());
+            py::list result;
             for (const auto& device : ep_devices) {
-              result.push_back(device);
+              const auto* p_device = static_cast<const OrtEpDevice*>(device);
+              result.append(py::cast(p_device, py::return_value_policy::reference));
             }
-            return result; }, py::return_value_policy::reference_internal)
-
+            return result; })
       .def("run_with_iobinding", [](PyInferenceSession* sess, SessionIOBinding& io_binding, RunOptions* run_options = nullptr) -> void {
         Status status;
diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py
index 30a84b22253a5..b85030b46e94d 100644
--- a/onnxruntime/test/python/onnxruntime_test_python.py
+++ b/onnxruntime/test/python/onnxruntime_test_python.py
@@ -695,22 +695,22 @@ def test_run_model(self):
         self.assertEqual(inputs[0].name, "X")
         self.assertEqual(inputs[0].shape, [3, 2])
 
-        inputs_meminfo = sess.get_inputs_memory_info()
-        self.assertEqual(len(inputs_meminfo), 1)
-        self.assertIsNotNone(inputs_meminfo[0])
+        input_meminfos = sess.get_input_memory_infos()
+        self.assertEqual(len(input_meminfos), 1)
+        self.assertIsNotNone(input_meminfos[0])
 
-        inputs_epdevices = sess.get_inputs_epdevices()
-        self.assertEqual(len(inputs_epdevices), 1)
-        self.assertIsNotNone(inputs_epdevices[0])
+        input_epdevices = sess.get_input_epdevices()
+        # The entry may be None (null) but it should be present
+        self.assertEqual(len(input_epdevices), 1)
 
         outputs = sess.get_outputs()
         self.assertEqual(len(outputs), 1)
         self.assertEqual(outputs[0].name, "Y")
         self.assertEqual(outputs[0].shape, [3, 2])
 
-        outputs_meminfo = sess.get_outputs_memory_info()
-        self.assertEqual(len(outputs_meminfo), 1)
-        self.assertIsNotNone(outputs_meminfo[0])
+        output_meminfos = sess.get_output_memory_infos()
+        self.assertEqual(len(output_meminfos), 1)
+        self.assertIsNotNone(output_meminfos[0])
 
         res = sess.run([outputs[0].name], {inputs[0].name: x})
         output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32)
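
A minimal end-to-end sketch of the introspection surface as it stands after
PATCH 9 (assumes the mul_1.onnx model from the test data; all property and
method names are as bound in this series):

    import onnxruntime as onnxrt

    sess = onnxrt.InferenceSession("mul_1.onnx", providers=["CPUExecutionProvider"])
    # Per-input/per-output OrtMemoryInfo, as exercised by test_run_model
    for mi in sess.get_input_memory_infos():
        print(mi.name, mi.device_id, mi.device_vendor_id, mi.device_mem_type)
    for mi in sess.get_output_memory_infos():
        print(mi.mem_type, mi.allocator_type)
    # Per-input OrtEpDevice; an entry may be None on plain CPU sessions
    print(len(sess.get_input_epdevices()))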