@@ -83,25 +83,35 @@ class OpenvinoBackend final : public ::executorch::runtime::BackendInterface {
8383
8484 auto infer_request = execution_handle->infer_request ;
8585
86- // Assume first argument is the input tensor
87- auto input_tensor = args[0 ]->toTensor ();
88- ov::Shape input_shape (input_tensor.sizes ().begin (), input_tensor.sizes ().end ());
86+ size_t num_inputs = infer_request->get_compiled_model ().inputs ().size ();
87+ size_t num_outputs = infer_request->get_compiled_model ().outputs ().size ();
8988
90- // Convert input tensor to OpenVINO tensor
91- ov::element::Type ov_type = convert_to_openvino_type (input_tensor.scalar_type ());
92- ov::Tensor ov_input_tensor (ov_type, input_shape, input_tensor.mutable_data_ptr ());
89+ // Set inputs
90+ for (size_t i = 0 ; i < num_inputs; i++) {
91+ auto input_tensor = args[i]->toTensor ();
92+ ov::Shape input_shape (input_tensor.sizes ().begin (), input_tensor.sizes ().end ());
9393
94- // infer_request->set_tensor("input", ov_input_tensor);
95- infer_request->set_input_tensor (0 , ov_input_tensor);
94+ // Convert input tensor to OpenVINO tensor
95+ ov::element::Type ov_type = convert_to_openvino_type (input_tensor.scalar_type ());
96+ ov::Tensor ov_input_tensor (ov_type, input_shape, input_tensor.mutable_data_ptr ());
9697
97- // Execute the inference
98- infer_request->infer ();
98+ infer_request->set_input_tensor (i, ov_input_tensor);
99+ }
100+
101+ // Set outputs
102+ for (size_t i = 0 ; i < num_outputs; i++) {
103+ auto output_tensor = args[num_inputs+i]->toTensor ();
104+ ov::Shape output_shape (output_tensor.sizes ().begin (), output_tensor.sizes ().end ());
99105
100- // Retrieve and copy output
101- auto output_tensor = args[ 1 ]-> toTensor (); // Assume second argument is the output
102- ov::Tensor ov_output_tensor = infer_request-> get_output_tensor ( 0 ); // get_tensor("output" );
106+      // Convert output tensor to OpenVINO tensor
107+      ov::element::Type ov_type = convert_to_openvino_type (output_tensor.scalar_type ());
108+      ov::Tensor ov_output_tensor (ov_type, output_shape, output_tensor.mutable_data_ptr ());
103109
104- std::memcpy (output_tensor.mutable_data_ptr (), ov_output_tensor.data (), ov_output_tensor.get_byte_size ());
110+ infer_request->set_output_tensor (i, ov_output_tensor);
111+ }
112+
113+ // Execute the inference
114+ infer_request->infer ();
105115
106116 return Error::Ok;
107117 }
0 commit comments