Skip to content

Commit 30883fd

Browse files
committed
chenge validation logic
1 parent 8468fa7 commit 30883fd

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

exir/backend/backend_api.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ def _insert_lowered_submodule(
201201
is_submodule: bool,
202202
toplevel_input_specs_to_delete: Dict[str, InputSpec],
203203
toplevel_output_specs_to_delete: Dict[str, OutputSpec],
204+
validate_program: bool = True,
204205
):
205206
owning_graph_module = call_submodule_node.graph.owning_module
206207
# call delegate args should only use user_inputs
@@ -275,6 +276,7 @@ def generate_debug_handle(ep: ExportedProgram) -> int:
275276
call_delegate_node,
276277
toplevel_input_specs_to_delete,
277278
toplevel_output_specs_to_delete,
279+
validate_program,
278280
)
279281

280282

@@ -353,10 +355,6 @@ def _partition_and_lower_one_graph_module(
353355
toplevel_output_specs_to_delete,
354356
)
355357

356-
# perform validation here to make sure all the delegated submodules are gone
357-
# validate inside _insert_lowered_submodule will break multi-method scenario
358-
if not is_submodule:
359-
owning_program._validate()
360358
return tagged_graph_module
361359

362360

@@ -661,13 +659,11 @@ def lower_all_submodules_to_backend(
661659
is_submodule,
662660
toplevel_input_specs_to_delete,
663661
toplevel_output_specs_to_delete,
662+
# validate only when all submodules are processed
663+
validate_program=call_submodule_node
664+
== list_of_call_submodule_nodes[-1],
664665
)
665666

666-
# perform validation here to make sure all the delegated submodules are gone
667-
# validate inside _insert_lowered_submodule will break multi-method scenario
668-
if not is_submodule:
669-
owning_program._validate()
670-
671667

672668
@dataclass
673669
class MethodProgramsPartitionerSpec:

exir/lowered_backend_module.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -862,6 +862,7 @@ def _unsafe_adjust_original_program( # noqa: C901
862862
call_delegate_node: torch.fx.Node,
863863
input_specs_to_delete: Dict[str, InputSpec],
864864
output_specs_to_delete: Dict[str, OutputSpec],
865+
validate_program: bool,
865866
) -> None:
866867
"""
867868
Directly modify the original exported program's signature and state dict
@@ -958,3 +959,6 @@ def _unsafe_adjust_original_program( # noqa: C901
958959
if user_idx > idx:
959960
user.args = (user.args[0], user_idx - (len(getitem_idxs) - i))
960961
break
962+
963+
if validate_program:
964+
original_program._validate()

0 commit comments

Comments
 (0)