14 changes: 7 additions & 7 deletions onnxruntime/python/onnxruntime_pybind_mlvalue.cc
@@ -196,9 +196,9 @@ void CudaToCpuMemCpy(void* dst, const void* src, size_t num_bytes) {
GetProviderInfo_CUDA().cudaMemcpy_DeviceToHost(dst, src, num_bytes);
}

-const std::unordered_map<OrtDevice, MemCpyFunc>* GetCudaToHostMemCpyFunction() {
+const std::unordered_map<OrtDevice, MemCpyFunc>* GetCudaToHostMemCpyFunction(const OrtDevice& device) {
static std::unordered_map<OrtDevice, MemCpyFunc> map{
-      {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, 0}, CudaToCpuMemCpy},
+      {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, device.Id()}, CudaToCpuMemCpy},
};

return &map;
@@ -256,7 +256,7 @@ void MIGraphXToCpuMemCpy(void* dst, const void* src, size_t num_bytes) {

const std::unordered_map<OrtDevice, MemCpyFunc>* GetMIGraphXToHostMemCpyFunction(const OrtDevice& device) {
static std::unordered_map<OrtDevice, MemCpyFunc> map{
-      {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, 0}, MIGraphXToCpuMemCpy},
+      {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, device.Id()}, MIGraphXToCpuMemCpy},
};

return &map;
@@ -374,9 +374,9 @@ void DmlToCpuMemCpy(void* dst, const void* src, size_t num_bytes) {
D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
}

-const std::unordered_map<OrtDevice, MemCpyFunc>* GetDmlToHostMemCpyFunction() {
+const std::unordered_map<OrtDevice, MemCpyFunc>* GetDmlToHostMemCpyFunction(const OrtDevice& device) {
static std::unordered_map<OrtDevice, MemCpyFunc> map{
-      {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::MICROSOFT, 0}, DmlToCpuMemCpy},
+      {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::MICROSOFT, device.Id()}, DmlToCpuMemCpy},
};

return &map;
@@ -444,9 +444,9 @@ void RocmToCpuMemCpy(void* dst, const void* src, size_t num_bytes) {
GetProviderInfo_ROCM().rocmMemcpy_DeviceToHost(dst, src, num_bytes);
}

-const std::unordered_map<OrtDevice, MemCpyFunc>* GetRocmToHostMemCpyFunction() {
+const std::unordered_map<OrtDevice, MemCpyFunc>* GetRocmToHostMemCpyFunction(const OrtDevice& device) {
static std::unordered_map<OrtDevice, MemCpyFunc> map{
-      {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, 0}, RocmToCpuMemCpy},
+      {OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, device.Id()}, RocmToCpuMemCpy},
};

return &map;
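Taken together, the hunks above change each provider's device-to-host copy lookup so that the map entry is keyed on the id of the device the caller passes in, instead of a hard-coded device 0. Below is a minimal, self-contained sketch of that pattern; the OrtDevice, hash, and copy function here are simplified stand-ins for illustration, not ONNX Runtime's actual definitions.

```cpp
#include <cstddef>
#include <cstring>
#include <functional>
#include <unordered_map>

// Stand-in types (assumption: greatly simplified compared to ORT's OrtDevice).
using MemCpyFunc = void (*)(void* dst, const void* src, size_t num_bytes);

struct OrtDevice {
  int type;       // e.g. GPU
  int vendor_id;  // e.g. the NVIDIA vendor id
  int device_id;
  int Id() const { return device_id; }
  bool operator==(const OrtDevice& o) const {
    return type == o.type && vendor_id == o.vendor_id && device_id == o.device_id;
  }
};

struct OrtDeviceHash {
  size_t operator()(const OrtDevice& d) const {
    return std::hash<int>()(d.type) ^ (std::hash<int>()(d.vendor_id) << 1) ^
           (std::hash<int>()(d.device_id) << 2);
  }
};

// Placeholder for GetProviderInfo_CUDA().cudaMemcpy_DeviceToHost(...).
void FakeCudaToCpuMemCpy(void* dst, const void* src, size_t num_bytes) {
  std::memcpy(dst, src, num_bytes);
}

// The entry is now built from the caller's device id instead of a literal 0,
// so a tensor living on GPU 1 resolves to a key with device_id == 1.
// As in the original, the map is a function-local static, built on first use.
const std::unordered_map<OrtDevice, MemCpyFunc, OrtDeviceHash>*
GetCudaToHostMemCpyFunction(const OrtDevice& device) {
  static std::unordered_map<OrtDevice, MemCpyFunc, OrtDeviceHash> map{
      {OrtDevice{/*type=*/1, /*vendor_id=*/0x10DE, device.Id()}, FakeCudaToCpuMemCpy},
  };
  return &map;
}
```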
8 changes: 4 additions & 4 deletions onnxruntime/python/onnxruntime_pybind_mlvalue.h
@@ -74,7 +74,7 @@ void CpuToCudaMemCpy(void* dst, const void* src, size_t num_bytes);

void CudaToCpuMemCpy(void* dst, const void* src, size_t num_bytes);

-const std::unordered_map<OrtDevice, MemCpyFunc>* GetCudaToHostMemCpyFunction();
+const std::unordered_map<OrtDevice, MemCpyFunc>* GetCudaToHostMemCpyFunction(const OrtDevice&);

bool IsCudaDeviceIdValid(const onnxruntime::logging::Logger& logger, int id);

@@ -92,7 +92,7 @@ void CpuToDmlMemCpy(void* dst, const void* src, size_t num_bytes);

void DmlToCpuMemCpy(void* dst, const void* src, size_t num_bytes);

-const std::unordered_map<OrtDevice, MemCpyFunc>* GetDmlToHostMemCpyFunction();
+const std::unordered_map<OrtDevice, MemCpyFunc>* GetDmlToHostMemCpyFunction(const OrtDevice&);

#endif

@@ -102,7 +102,7 @@ void CpuToMIGraphXMemCpy(void* dst, const void* src, size_t num_bytes);

void MIGraphXToCpuMemCpy(void* dst, const void* src, size_t num_bytes);

-const std::unordered_map<OrtDevice, MemCpyFunc>* GetMIGraphXToHostMemCpyFunction();
+const std::unordered_map<OrtDevice, MemCpyFunc>* GetMIGraphXToHostMemCpyFunction(const OrtDevice&);

AllocatorPtr GetMIGraphXAllocator(OrtDevice::DeviceId id);

@@ -132,7 +132,7 @@ void CpuToRocmMemCpy(void* dst, const void* src, size_t num_bytes);

void RocmToCpuMemCpy(void* dst, const void* src, size_t num_bytes);

-const std::unordered_map<OrtDevice, MemCpyFunc>* GetRocmToHostMemCpyFunction();
+const std::unordered_map<OrtDevice, MemCpyFunc>* GetRocmToHostMemCpyFunction(const OrtDevice&);

#endif

10 changes: 5 additions & 5 deletions onnxruntime/python/onnxruntime_pybind_ortvalue.cc
@@ -421,17 +421,17 @@
// Converts Tensor into a numpy array
.def("numpy", [](const OrtValue* ml_value) -> py::object {
ORT_ENFORCE(ml_value->IsTensor(), "Only OrtValues that are Tensors are convertible to Numpy objects");

+      const auto& device = ml_value->Get<Tensor>().Location().device;

[CI annotations on line 424 of onnxruntime/python/onnxruntime_pybind_ortvalue.cc: "'device': local variable is initialized but not referenced"; the warning is treated as an error in the GitHub Actions jobs build_x64_release_vitisai, build_x64_release, and build_x64_debug, failing those builds.]
#ifdef USE_CUDA
-      py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetCudaToHostMemCpyFunction());
+      py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetCudaToHostMemCpyFunction(device));
#elif USE_ROCM
-      py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetRocmToHostMemCpyFunction());
+      py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetRocmToHostMemCpyFunction(device));
#elif USE_CANN
py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetCannToHostMemCpyFunction());
#elif USE_DML
-      py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetDmlToHostMemCpyFunction());
+      py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetDmlToHostMemCpyFunction(device));
#elif USE_MIGRAPHX
-      py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetMIGraphXToHostMemCpyFunction());
+      py::object obj = GetPyObjFromTensor(*ml_value, nullptr, GetMIGraphXToHostMemCpyFunction(device));
#else
py::object obj = GetPyObjFromTensor(*ml_value, nullptr, nullptr);
#endif
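The call-site change above forwards the tensor's own device into the per-provider lookup. How GetPyObjFromTensor consumes the returned map is not visible in this diff, so the sketch below is a hypothetical illustration of that consumption, again using simplified stand-in types rather than ONNX Runtime's.

```cpp
#include <cstddef>
#include <cstring>
#include <functional>
#include <unordered_map>

// Stand-ins for illustration only; see the earlier caveat about simplified types.
using MemCpyFunc = void (*)(void*, const void*, size_t);

struct OrtDevice {
  int type, vendor_id, device_id;
  bool operator==(const OrtDevice& o) const {
    return type == o.type && vendor_id == o.vendor_id && device_id == o.device_id;
  }
};
struct OrtDeviceHash {
  size_t operator()(const OrtDevice& d) const {
    return std::hash<int>()(d.type * 31 + d.vendor_id) ^ std::hash<int>()(d.device_id);
  }
};
using CopyMap = std::unordered_map<OrtDevice, MemCpyFunc, OrtDeviceHash>;

// Assumed consumption pattern: the copy function registered for the tensor's
// exact device (vendor and id) is used, with a plain memcpy fallback otherwise.
void CopyToHost(const CopyMap* copy_map, const OrtDevice& device,
                void* dst, const void* src, size_t bytes) {
  if (copy_map != nullptr) {
    auto it = copy_map->find(device);
    if (it != copy_map->end()) {
      it->second(dst, src, bytes);
      return;
    }
  }
  std::memcpy(dst, src, bytes);  // CPU or unknown-device fallback
}
```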
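The CI annotations embedded above flag the new `device` local as unreferenced in builds where none of the GPU execution provider macros are defined, since every use of it sits inside the #ifdef chain. Below is a minimal repro sketch of that situation, assuming it is the failing configuration; it is not the PR's eventual fix.

```cpp
// Minimal repro sketch (assumption: mirrors the failing configuration). With no
// GPU EP macro defined, the only read of `device` is compiled out, leaving it
// initialized but never referenced; MSVC reports this as C4189, and
// warnings-as-errors then fails the build.
struct OrtDevice { int id; };
struct Tensor { OrtDevice device; };

int DeviceIdForHostCopy(const Tensor& t) {
  const auto& device = t.device;  // referenced only inside the #ifdef below
#ifdef USE_CUDA
  return device.id;
#else
  return -1;  // CPU-only build path: `device` ends up unused here
#endif
}
```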