Skip to content

Commit

Permalink
Merge pull request #9 from nosovmik/mo_setPartialShape
Browse files Browse the repository at this point in the history
SetPartialShape in MO
  • Loading branch information
nosovmik authored Apr 16, 2021
2 parents aaedbdc + dfd675c commit 5aed606
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 8 deletions.
10 changes: 5 additions & 5 deletions model-optimizer/mo/front/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,11 +591,11 @@ def input_user_data_repack(graph: Graph, input_user_shapes: [None, list, dict, n
if node is None:
raise Error('Cannot find location {} in the graph'.format(input_name))
shape = None if isinstance(input_user_shapes, list) else input_user_shapes[input_name]
if input_name in input_user_data_types and input_user_data_types[input_name] is not None:
data_type = input_user_data_types[input_name]
_input_shapes.append({'node': node, 'shape': shape, 'data_type': data_type})
else:
_input_shapes.append({'node': node, 'shape': shape})
if input_name in input_user_data_types and input_user_data_types[input_name] is not None:
data_type = input_user_data_types[input_name]
_input_shapes.append({'node': node, 'shape': shape, 'data_type': data_type})
else:
_input_shapes.append({'node': node, 'shape': shape})
elif isinstance(input_user_shapes, np.ndarray):
model_inputs = inputModel.getInputs()
assert len(model_inputs) == 1
Expand Down
5 changes: 2 additions & 3 deletions model-optimizer/mo/pipeline/unified.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,8 @@ def moc_pipeline(argv: argparse.Namespace):
apply_replacements_list(graph, transforms)
user_shapes = graph.graph['user_shapes']
if len(user_shapes) > 0:
assert len(inputModel.getInputs()) == 1
assert len(user_shapes) == 1
inputModel.setPartialShape(user_shapes[0]['node'], PartialShape(user_shapes[0]['shape']))
for user_shape in user_shapes:
inputModel.setPartialShape(user_shape['node'], PartialShape(user_shape['shape']))
nGraphModel = fe.convert(inputModel)
network = function_to_cnn(nGraphModel)
graph.graph['network'] = network
Expand Down

0 comments on commit 5aed606

Please sign in to comment.