diff --git a/exir/lowered_backend_module.py b/exir/lowered_backend_module.py index b025cef233d..f6a995b5732 100644 --- a/exir/lowered_backend_module.py +++ b/exir/lowered_backend_module.py @@ -454,17 +454,17 @@ def _get_new_signature( # noqa: C901 new_state_dict = {} new_constants = {} - input_tensor_node_to_sig = { - input_spec.arg.name: input_spec - for input_spec in old_signature.input_specs - if isinstance(input_spec.arg, TensorArgument) - } + placeholder_nodes = [ + node.name for node in original_program.graph.nodes if node.op == "placeholder" + ] + assert len(placeholder_nodes) == len(old_signature.input_specs) + input_node_to_sig = dict(zip(placeholder_nodes, old_signature.input_specs)) for node in gm.graph.nodes: is_tagged = tag is None or node.meta.get("delegation_tag", None) == tag if node.op == "placeholder": - if node.name not in input_tensor_node_to_sig: + if node.name not in input_node_to_sig: assert tag is not None input_specs.append( InputSpec( @@ -475,7 +475,7 @@ def _get_new_signature( # noqa: C901 ) continue - orig_input_spec = input_tensor_node_to_sig[node.name] + orig_input_spec = input_node_to_sig[node.name] if not isinstance(orig_input_spec.arg, TensorArgument): input_specs.append(orig_input_spec)