Skip to content

Commit

Permalink
Torch FX graph outputs fix
Browse files Browse the repository at this point in the history
  • Loading branch information
cavusmustafa committed Mar 31, 2023
1 parent 035dc66 commit e00427a
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def fx_openvino(subgraph, example_inputs):
compiled_model = partitioner.make_partitions(model)

def _call(*args):
res = execute(compiled_model, *example_inputs, executor="openvino")
res = execute(compiled_model, *args, executor="openvino")
return res
return _call
except Exception as e:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,14 @@ def openvino_execute(gm: GraphModule, *args, executor_parameters=None, partition
res = compiled(ov_inputs)

results1 = [res[out] for out in compiled.outputs]
results = torch.from_numpy(np.array(results1, dtype=np.float32))
return results

if len(results1) == 1:
results = torch.from_numpy(results1[0])
return results
else:
for i in range(len(results1)):
results1[i] = torch.from_numpy(results1[i])
return results1


class OpenVINOGraphModule(torch.nn.Module):
Expand Down

0 comments on commit e00427a

Please sign in to comment.