55from collections import namedtuple
66from dataclasses import dataclass , field
77from typing import Any , Callable , Dict , List , Optional , Set , Tuple , Union
8+ from unittest .mock import patch
89
910import sympy
1011import torch
12+ import torch ._export
1113from executorch .exir .dynamic_shape import DynamicMemoryPlanningMode
1214from executorch .exir .emit import emit_program , EmitterOutput
1315from executorch .exir .error import ExportError , ExportErrorType , InternalError
3234from executorch .exir .schema import Program
3335from executorch .exir .serialize import serialize_to_flatbuffer
3436from executorch .exir .tracer import (
37+ _default_decomposition_table ,
3538 dispatch_trace ,
3639 dynamo_trace ,
3740 ExirDynamoConfig ,
4851from torch ._dynamo .eval_frame import Constraint
4952from torch ._export import CallSpec , export , ExportGraphSignature
5053from torch ._export .exported_program import ExportedProgram
54+ from torch ._export .passes import ReplaceViewOpsWithViewCopyOpsPass
5155from torch ._export .passes .add_runtime_assertions_for_constraints_pass import (
5256 InputDim ,
5357 RangeConstraint ,
5660from torch .fx ._compatibility import compatibility
5761from torch .fx .experimental .proxy_tensor import make_fx
5862from torch .fx .experimental .symbolic_shapes import ShapeEnv
63+ from torch .fx .graph import _PyTreeCodeGen , _PyTreeInfo
5964from torch .utils import _pytree as pytree
6065
6166
6267Val = Any
6368
6469
70+ def _unlift (gm , inp_pos_to_param_buffer_name , in_spec , out_spec , state_dict ):
71+ count = 0
72+ # Step 1: make lifted params as get_attr
73+ for node in gm .graph .nodes :
74+ if node .op == "placeholder" :
75+ if count in inp_pos_to_param_buffer_name :
76+ with gm .graph .inserting_after (node ):
77+ getattr_node = gm .graph .get_attr (
78+ inp_pos_to_param_buffer_name [count ]
79+ )
80+ node .replace_all_uses_with (getattr_node )
81+ metadata = node .meta
82+ gm .graph .erase_node (node )
83+ getattr_node .meta = metadata
84+ count += 1
85+
86+ # Step 2: Fix the input/output of the graph now that we deleted
87+ # some args.
88+ gm .graph .lint ()
89+ names = [f"arg_{ i } " for i in range (len (in_spec .children_specs ))]
90+ gm .graph ._codegen = _PyTreeCodeGen (
91+ _PyTreeInfo (
92+ names ,
93+ in_spec ,
94+ out_spec ,
95+ )
96+ )
97+ gm .recompile ()
98+
99+ # Step 3: Find state references in HigherOrderOps and recursively
100+ # fix them.
101+ for node in gm .graph .nodes :
102+ if node .op == "call_function" and node .target == torch .ops .cond :
103+ pred , true_graph , false_graph , operands = node .args
104+ true_gm = getattr (gm , true_graph .name )
105+ false_gm = getattr (gm , false_graph .name )
106+ inp_pos_to_param_buffer_name_for_submod = {}
107+ real_operands = []
108+ for ix , operand in enumerate (operands ):
109+ if operand .target in inp_pos_to_param_buffer_name .values ():
110+ inp_pos_to_param_buffer_name_for_submod [ix ] = operand .target
111+ true_gm .register_buffer (operand .target , state_dict [operand .target ])
112+ false_gm .register_buffer (operand .target , state_dict [operand .target ])
113+ else :
114+ real_operands .append (operand )
115+ node .args = (pred , true_graph , false_graph , real_operands )
116+
117+ _ , in_spec = pytree .tree_flatten (real_operands )
118+
119+ _unlift (
120+ true_gm ,
121+ inp_pos_to_param_buffer_name_for_submod ,
122+ in_spec ,
123+ None ,
124+ state_dict ,
125+ )
126+ _unlift (
127+ false_gm ,
128+ inp_pos_to_param_buffer_name_for_submod ,
129+ in_spec ,
130+ None ,
131+ state_dict ,
132+ )
133+ if node .op == "call_function" and node .target .__name__ == "map_impl" :
134+ body_graph , num_mapped , * operands = node .args
135+ body_gm = getattr (gm , body_graph .name )
136+ inp_pos_to_buffer_name_for_submod = {}
137+ real_operands = []
138+ for ix , operand in enumerate (operands ):
139+ if operand .target in inp_pos_to_param_buffer_name .values ():
140+ inp_pos_to_buffer_name_for_submod [ix ] = operand .target
141+ body_gm .register_buffer (operand .target , state_dict [operand .target ])
142+ else :
143+ real_operands .append (operand )
144+ node .args = (body_graph , num_mapped , * real_operands )
145+
146+ _ , in_spec = pytree .tree_flatten (real_operands )
147+
148+ _unlift (
149+ body_gm , inp_pos_to_buffer_name_for_submod , in_spec , None , state_dict
150+ )
151+ gm .graph .lint ()
152+ gm .graph .eliminate_dead_code ()
153+ gm .recompile ()
154+ return gm
155+
156+
157+ def unlift_exported_program_lifted_states (
158+ ep : torch ._export .exported_program .ExportedProgram ,
159+ ):
160+ new_gm = copy .deepcopy (ep .graph_module )
161+
162+ # TODO Fix the period in params/buffers names later
163+ # maybe a pass to replace graph signature with fixed names
164+ param_buffer_name_to_corrected_name = {}
165+
166+ for name , stuff in ep .state_dict .items ():
167+ if name in ep .graph_signature .buffers :
168+ if "." in name :
169+ new_gm .register_buffer (name .replace ("." , "_" ), stuff )
170+ param_buffer_name_to_corrected_name [name ] = name .replace ("." , "_" )
171+ else :
172+ new_gm .register_buffer (name , stuff )
173+ elif name in ep .graph_signature .parameters :
174+ if "." in name :
175+ new_gm .register_parameter (name .replace ("." , "_" ), stuff )
176+ param_buffer_name_to_corrected_name [name ] = name .replace ("." , "_" )
177+ else :
178+ new_gm .register_parameter (name , stuff )
179+ else :
180+ raise AssertionError ("encountered not registered param/buffer" )
181+
182+ count = 0
183+ inp_pos_to_param_buffer_name = {}
184+ for node in new_gm .graph .nodes :
185+ if node .op == "placeholder" :
186+ if node .name in ep .graph_signature .inputs_to_buffers :
187+ buffer_name = ep .graph_signature .inputs_to_buffers [node .name ]
188+ if buffer_name in param_buffer_name_to_corrected_name :
189+ inp_pos_to_param_buffer_name [
190+ count
191+ ] = param_buffer_name_to_corrected_name [buffer_name ]
192+ else :
193+ inp_pos_to_param_buffer_name [count ] = buffer_name
194+ if node .name in ep .graph_signature .inputs_to_parameters :
195+ param_name = ep .graph_signature .inputs_to_parameters [node .name ]
196+ if param_name in param_buffer_name_to_corrected_name :
197+ inp_pos_to_param_buffer_name [
198+ count
199+ ] = param_buffer_name_to_corrected_name [param_name ]
200+ else :
201+ inp_pos_to_param_buffer_name [count ] = param_name
202+ count += 1
203+ new_gm = _unlift (
204+ new_gm ,
205+ inp_pos_to_param_buffer_name ,
206+ ep .call_spec .in_spec ,
207+ ep .call_spec .out_spec ,
208+ ep .state_dict ,
209+ )
210+ return new_gm
211+
212+
65213@compatibility (is_backward_compatible = False )
66214@dataclass
67215class CaptureConfig :
@@ -70,6 +218,7 @@ class CaptureConfig:
70218 enable_dynamic_shape : bool = False
71219 enable_aot : bool = False
72220 _dynamo_config : "ExirDynamoConfig" = ExirDynamoConfig ()
221+ _unflatten : bool = False
73222
74223
75224@compatibility (is_backward_compatible = False )
@@ -469,8 +618,15 @@ def capture(
469618 "Functionalization is required for enable_aot." ,
470619 )
471620
472- ep = export (f , args , _add_runtime_assertions = False , constraints = constraints )
473- return ep # pyre-ignore
621+ # TODO remove this later
622+ with patch ("torch._export.DECOMP_TABLE" , _default_decomposition_table ()):
623+ ep = export (
624+ f , args , _add_runtime_assertions = False , constraints = constraints
625+ )
626+ ep = ep .transform (ReplaceViewOpsWithViewCopyOpsPass ())
627+ if not config ._unflatten :
628+ return ep # pyre-ignore
629+ graph_module = unlift_exported_program_lifted_states (ep )
474630
475631 elif config .enable_dynamic_shape :
476632 if not config ._dynamo_config .dynamic_shapes :
0 commit comments