Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 8 additions & 21 deletions onnxruntime/core/providers/openvino/ov_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,29 +126,16 @@ class OVInferRequest {
OVTensorPtr GetTensor(const std::string& name);
std::string GetInputTensorName(uint32_t index);

// Set tensor described param_info and ort_ptr. Overrides shape in param_info with shape_override. Call infer req tensor if ort_ptr is last set.
// Set tensor call infer req tensor if ort_ptr differs from last set ptr.
void SetTensor(const std::string& name, const ov::element::Type& type, const ov::Shape& shape, void* ort_ptr) {
auto& cached_binding = bindings_cache_[name];
if (cached_binding.ort_ptr != ort_ptr) {
auto tensor_ptr = std::make_shared<ov::Tensor>(type, shape, const_cast<void*>(ort_ptr));
SetTensor(name, tensor_ptr);
cached_binding = {tensor_ptr, ort_ptr};
} else if (ort_ptr == nullptr) {
// a null ort_ptr is expected for a tensor that has 0 elements.
// for example, a tensor of shape=[1, 8, 0, 64], which is valid.
// So, we check to see if at least one shape entry is 0.
auto contains_zero = [](const ov::Shape& shape) {
for (auto& s : shape)
if (s == 0) return true;
return false;
};
if (contains_zero(shape)) {
// if there are zero elements (i.e. at least one shape entry is 0),
// then create and set the tensor anyway.
auto tensor_ptr = std::make_shared<ov::Tensor>(type, shape);
SetTensor(name, tensor_ptr);
cached_binding = {tensor_ptr, ort_ptr};
}
if (cached_binding.ort_ptr != ort_ptr ||
!cached_binding.tensor_ptr ||
cached_binding.tensor_ptr->get_shape() != shape) {
cached_binding.tensor_ptr.reset();
auto ov_tensor = std::make_shared<ov::Tensor>(type, shape, const_cast<void*>(ort_ptr));
ovInfReq.set_tensor(name, *ov_tensor);
cached_binding = {ov_tensor, ort_ptr};
}
}

Expand Down
Loading