1010import  torch 
1111from  executorch .exir .pass_base  import  ExportPass , PassResult 
1212
13- from  .utils  import  copy_nn_module_stack 
13+ from  .utils  import  merge_decomposed_graph 
1414
1515
1616class  DecomposeWrapWithAutocast (ExportPass ):
@@ -52,7 +52,7 @@ def _replace(self, gm: torch.fx.GraphModule) -> None:
5252        graph  =  gm .graph 
5353        for  node  in  graph .nodes :
5454            if  isinstance (node .target , torch ._higher_order_ops .wrap .WrapWithAutocast ):
55-                 submod , submod_name  =  self ._get_submod (gm , node )
55+                 submod , _  =  self ._get_submod (gm , node )
5656                n_args  =  node .args 
5757                input_submod  =  n_args [4 ]
5858                decomposed_module  =  submod 
@@ -61,22 +61,13 @@ def _replace(self, gm: torch.fx.GraphModule) -> None:
6161                    # which ensures that reference to nodes are correctly updated in the new graph 
6262                    # remap = {"expand_1": node.args[5], "to_4": node.args[6]} 
6363                    remap  =  {n_args [i ].name : n_args [i ] for  i  in  range (5 , len (n_args ))}
64- 
65-                     for  decomposed_node  in  decomposed_module .graph .nodes :
66-                         copy_nn_module_stack (node , decomposed_node )
67-                         # no need to copy existent 'output' 
68-                         if  decomposed_node .op  ==  "output" :
69-                             self ._replace_output (node , decomposed_node , remap )
70-                         # no need to copy existent placeholders 
71-                         elif  decomposed_node .op  ==  "placeholder" :
72-                             # replace node map from string to graph node 
73-                             remap [decomposed_node ] =  remap .pop (decomposed_node .name )
74-                         else :
75-                             remap [decomposed_node ] =  graph .node_copy (
76-                                 decomposed_node ,
77-                                 arg_transform = lambda  x , remap = remap : remap [x ],
78-                             )
79- 
64+                     merge_decomposed_graph (
65+                         remap = remap ,
66+                         target_node = node ,
67+                         target_graph = graph ,
68+                         decomposed_graph_module = decomposed_module ,
69+                         output_processor = self ._replace_output ,
70+                     )
8071                    graph .erase_node (node )
8172
8273                graph .erase_node (input_submod )
0 commit comments