diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc
index 590e1ef3cdbdb..c01554197005d 100644
--- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc
+++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc
@@ -196,9 +196,9 @@ void CudaToCpuMemCpy(void* dst, const void* src, size_t num_bytes) {
   GetProviderInfo_CUDA().cudaMemcpy_DeviceToHost(dst, src, num_bytes);
 }
 
-const std::unordered_map<OrtDevice, MemCpyFunc>* GetCudaToHostMemCpyFunction() {
+const std::unordered_map<OrtDevice, MemCpyFunc>* GetCudaToHostMemCpyFunction(const OrtDevice& device) {
   static std::unordered_map<OrtDevice, MemCpyFunc> map{
-      {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, 0}, CudaToCpuMemCpy},
+      {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, device.Id()}, CudaToCpuMemCpy},
   };
 
   return &map;
@@ -256,7 +256,7 @@ void MIGraphXToCpuMemCpy(void* dst, const void* src, size_t num_bytes) {
 
 const std::unordered_map<OrtDevice, MemCpyFunc>* GetMIGraphXToHostMemCpyFunction(const OrtDevice& device) {
   static std::unordered_map<OrtDevice, MemCpyFunc> map{
-      {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, 0}, MIGraphXToCpuMemCpy},
+      {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, device.Id()}, MIGraphXToCpuMemCpy},
   };
 
   return &map;
@@ -374,9 +374,9 @@ void DmlToCpuMemCpy(void* dst, const void* src, size_t num_bytes) {
                                D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
 }
 
-const std::unordered_map<OrtDevice, MemCpyFunc>* GetDmlToHostMemCpyFunction() {
+const std::unordered_map<OrtDevice, MemCpyFunc>* GetDmlToHostMemCpyFunction(const OrtDevice& device) {
   static std::unordered_map<OrtDevice, MemCpyFunc> map{
-      {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::MICROSOFT, 0}, DmlToCpuMemCpy},
+      {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::MICROSOFT, device.Id()}, DmlToCpuMemCpy},
   };
 
   return &map;
@@ -444,9 +444,9 @@ void RocmToCpuMemCpy(void* dst, const void* src, size_t num_bytes) {
   GetProviderInfo_ROCM().rocmMemcpy_DeviceToHost(dst, src, num_bytes);
 }
 
-const std::unordered_map<OrtDevice, MemCpyFunc>* GetRocmToHostMemCpyFunction() {
+const std::unordered_map<OrtDevice, MemCpyFunc>* GetRocmToHostMemCpyFunction(const OrtDevice& device) {
   static std::unordered_map<OrtDevice, MemCpyFunc> map{
-      {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, 0}, RocmToCpuMemCpy},
+      {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, device.Id()}, RocmToCpuMemCpy},
   };
 
   return &map;
diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.h b/onnxruntime/python/onnxruntime_pybind_mlvalue.h
index 7b65c0aae45c1..eba783d826212 100644
--- a/onnxruntime/python/onnxruntime_pybind_mlvalue.h
+++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.h
@@ -74,7 +74,7 @@ void CpuToCudaMemCpy(void* dst, const void* src, size_t num_bytes);
 
 void CudaToCpuMemCpy(void* dst, const void* src, size_t num_bytes);
 
-const std::unordered_map<OrtDevice, MemCpyFunc>* GetCudaToHostMemCpyFunction();
+const std::unordered_map<OrtDevice, MemCpyFunc>* GetCudaToHostMemCpyFunction(const OrtDevice&);
 
 bool IsCudaDeviceIdValid(const onnxruntime::logging::Logger& logger, int id);
 
@@ -92,7 +92,7 @@ void CpuToDmlMemCpy(void* dst, const void* src, size_t num_bytes);
 
 void DmlToCpuMemCpy(void* dst, const void* src, size_t num_bytes);
 
-const std::unordered_map<OrtDevice, MemCpyFunc>* GetDmlToHostMemCpyFunction();
+const std::unordered_map<OrtDevice, MemCpyFunc>* GetDmlToHostMemCpyFunction(const OrtDevice&);
 
 #endif
 
@@ -102,7 +102,7 @@ void CpuToMIGraphXMemCpy(void* dst, const void* src, size_t num_bytes);
 
 void MIGraphXToCpuMemCpy(void* dst, const void* src, size_t num_bytes);
 
-const std::unordered_map<OrtDevice, MemCpyFunc>* GetMIGraphXToHostMemCpyFunction();
+const std::unordered_map<OrtDevice, MemCpyFunc>* GetMIGraphXToHostMemCpyFunction(const OrtDevice&);
 
 AllocatorPtr GetMIGraphXAllocator(OrtDevice::DeviceId id);
 
@@ -132,7 +132,7 @@ void CpuToRocmMemCpy(void* dst, const void* src, size_t num_bytes);
 
 void RocmToCpuMemCpy(void* dst, const void* src, size_t num_bytes);
 
-const std::unordered_map<OrtDevice, MemCpyFunc>* GetRocmToHostMemCpyFunction();
+const std::unordered_map<OrtDevice, MemCpyFunc>* GetRocmToHostMemCpyFunction(const OrtDevice&);
 
 #endif
diff --git a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc
index 7234543eb14de..1d1ae1047d328 100644
--- a/onnxruntime/python/onnxruntime_pybind_ortvalue.cc
+++ b/onnxruntime/python/onnxruntime_pybind_ortvalue.cc
@@ -421,17 +421,17 @@ void addOrtValueMethods(pybind11::module& m) {
       // Converts Tensor into a numpy array
       .def("numpy", [](const OrtValue* ml_value) -> py::object {
         ORT_ENFORCE(ml_value->IsTensor(), "Only OrtValues that are Tensors are convertible to Numpy objects");
-
+        const auto& device = ml_value->Get<Tensor>().Location().device;
 #ifdef USE_CUDA
-        py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetCudaToHostMemCpyFunction());
+        py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetCudaToHostMemCpyFunction(device));
 #elif USE_ROCM
-        py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetRocmToHostMemCpyFunction());
+        py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetRocmToHostMemCpyFunction(device));
 #elif USE_CANN
         py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetCannToHostMemCpyFunction());
 #elif USE_DML
-        py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetDmlToHostMemCpyFunction());
+        py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetDmlToHostMemCpyFunction(device));
 #elif USE_MIGRAPHX
-        py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetMIGraphXToHostMemCpyFunction());
+        py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetMIGraphXToHostMemCpyFunction(device));
 #else
         py::object obj = GetPyObjFromTensor(*ml_value, nullptr, nullptr);
 #endif