@@ -201,6 +201,7 @@ def _insert_lowered_submodule(
201201 is_submodule : bool ,
202202 toplevel_input_specs_to_delete : Dict [str , InputSpec ],
203203 toplevel_output_specs_to_delete : Dict [str , OutputSpec ],
204+ validate_program : bool = True ,
204205):
205206 owning_graph_module = call_submodule_node .graph .owning_module
206207 # call delegate args should only use user_inputs
@@ -275,6 +276,7 @@ def generate_debug_handle(ep: ExportedProgram) -> int:
275276 call_delegate_node ,
276277 toplevel_input_specs_to_delete ,
277278 toplevel_output_specs_to_delete ,
279+ validate_program ,
278280 )
279281
280282
@@ -353,10 +355,6 @@ def _partition_and_lower_one_graph_module(
353355 toplevel_output_specs_to_delete ,
354356 )
355357
356- # perform validation here to make sure all the delegated submodules are gone
357- # validate inside _insert_lowered_submodule will break multi-method scenario
358- if not is_submodule :
359- owning_program ._validate ()
360358 return tagged_graph_module
361359
362360
@@ -661,13 +659,11 @@ def lower_all_submodules_to_backend(
661659 is_submodule ,
662660 toplevel_input_specs_to_delete ,
663661 toplevel_output_specs_to_delete ,
662+ # validate only when all submodules are processed
663+ validate_program = call_submodule_node
664+ == list_of_call_submodule_nodes [- 1 ],
664665 )
665666
666- # perform validation here to make sure all the delegated submodules are gone
667- # validate inside _insert_lowered_submodule will break multi-method scenario
668- if not is_submodule :
669- owning_program ._validate ()
670-
671667
672668@dataclass
673669class MethodProgramsPartitionerSpec :
0 commit comments