55from collections import namedtuple
66from dataclasses import dataclass , field
77from typing import Any , Callable , Dict , List , Optional , Set , Tuple , Union
8+ from unittest .mock import patch
89
910import sympy
1011import torch
12+ import torch ._export
1113from executorch .exir .dynamic_shape import DynamicMemoryPlanningMode
1214from executorch .exir .emit import emit_program , EmitterOutput
1315from executorch .exir .error import ExportError , ExportErrorType , InternalError
2527from executorch .exir .schema import Program
2628from executorch .exir .serialize import serialize_to_flatbuffer
2729from executorch .exir .tracer import (
30+ _default_decomposition_table ,
2831 dispatch_trace ,
2932 dynamo_trace ,
3033 ExirDynamoConfig ,
4144from torch ._dynamo .eval_frame import Constraint
4245from torch ._export import CallSpec , export , ExportGraphSignature
4346from torch ._export .exported_program import ExportedProgram
47+ from torch ._export .passes import ReplaceViewOpsWithViewCopyOpsPass
4448from torch ._export .passes .add_runtime_assertions_for_constraints_pass import (
4549 InputDim ,
4650 RangeConstraint ,
4953from torch .fx ._compatibility import compatibility
5054from torch .fx .experimental .proxy_tensor import make_fx
5155from torch .fx .experimental .symbolic_shapes import ShapeEnv
56+ from torch .fx .graph import _PyTreeCodeGen , _PyTreeInfo
5257from torch .utils import _pytree as pytree
5358
5459
5560Val = Any
5661
5762
63+ def _unlift (gm , inp_pos_to_param_buffer_name , in_spec , out_spec , state_dict ):
64+ count = 0
65+ # Step 1: make lifted params as get_attr
66+ for node in gm .graph .nodes :
67+ if node .op == "placeholder" :
68+ if count in inp_pos_to_param_buffer_name :
69+ with gm .graph .inserting_after (node ):
70+ getattr_node = gm .graph .get_attr (
71+ inp_pos_to_param_buffer_name [count ]
72+ )
73+ node .replace_all_uses_with (getattr_node )
74+ metadata = node .meta
75+ gm .graph .erase_node (node )
76+ getattr_node .meta = metadata
77+ count += 1
78+
79+ # Step 2: Fix the input/output of the graph now that we deleted
80+ # some args.
81+ gm .graph .lint ()
82+ names = [f"arg_{ i } " for i in range (len (in_spec .children_specs ))]
83+ gm .graph ._codegen = _PyTreeCodeGen (
84+ _PyTreeInfo (
85+ names ,
86+ in_spec ,
87+ out_spec ,
88+ )
89+ )
90+ gm .recompile ()
91+
92+ # Step 3: Find state references in HigherOrderOps and recursively
93+ # fix them.
94+ for node in gm .graph .nodes :
95+ if node .op == "call_function" and node .target == torch .ops .cond :
96+ pred , true_graph , false_graph , operands = node .args
97+ true_gm = getattr (gm , true_graph .name )
98+ false_gm = getattr (gm , false_graph .name )
99+ inp_pos_to_param_buffer_name_for_submod = {}
100+ real_operands = []
101+ for ix , operand in enumerate (operands ):
102+ if operand .target in inp_pos_to_param_buffer_name .values ():
103+ inp_pos_to_param_buffer_name_for_submod [ix ] = operand .target
104+ true_gm .register_buffer (operand .target , state_dict [operand .target ])
105+ false_gm .register_buffer (operand .target , state_dict [operand .target ])
106+ else :
107+ real_operands .append (operand )
108+ node .args = (pred , true_graph , false_graph , real_operands )
109+
110+ _ , in_spec = pytree .tree_flatten (real_operands )
111+
112+ _unlift (
113+ true_gm ,
114+ inp_pos_to_param_buffer_name_for_submod ,
115+ in_spec ,
116+ None ,
117+ state_dict ,
118+ )
119+ _unlift (
120+ false_gm ,
121+ inp_pos_to_param_buffer_name_for_submod ,
122+ in_spec ,
123+ None ,
124+ state_dict ,
125+ )
126+ if node .op == "call_function" and node .target .__name__ == "map_impl" :
127+ body_graph , num_mapped , * operands = node .args
128+ body_gm = getattr (gm , body_graph .name )
129+ inp_pos_to_buffer_name_for_submod = {}
130+ real_operands = []
131+ for ix , operand in enumerate (operands ):
132+ if operand .target in inp_pos_to_param_buffer_name .values ():
133+ inp_pos_to_buffer_name_for_submod [ix ] = operand .target
134+ body_gm .register_buffer (operand .target , state_dict [operand .target ])
135+ else :
136+ real_operands .append (operand )
137+ node .args = (body_graph , num_mapped , * real_operands )
138+
139+ _ , in_spec = pytree .tree_flatten (real_operands )
140+
141+ _unlift (
142+ body_gm , inp_pos_to_buffer_name_for_submod , in_spec , None , state_dict
143+ )
144+ gm .graph .lint ()
145+ gm .graph .eliminate_dead_code ()
146+ gm .recompile ()
147+ return gm
148+
149+
150+ def unlift_exported_program_lifted_states (
151+ ep : torch ._export .exported_program .ExportedProgram ,
152+ ):
153+ new_gm = copy .deepcopy (ep .graph_module )
154+
155+ # TODO Fix the period in params/buffers names later
156+ # maybe a pass to replace graph signature with fixed names
157+ param_buffer_name_to_corrected_name = {}
158+
159+ for name , stuff in ep .state_dict .items ():
160+ if name in ep .graph_signature .buffers :
161+ if "." in name :
162+ new_gm .register_buffer (name .replace ("." , "_" ), stuff )
163+ param_buffer_name_to_corrected_name [name ] = name .replace ("." , "_" )
164+ else :
165+ new_gm .register_buffer (name , stuff )
166+ elif name in ep .graph_signature .parameters :
167+ if "." in name :
168+ new_gm .register_parameter (name .replace ("." , "_" ), stuff )
169+ param_buffer_name_to_corrected_name [name ] = name .replace ("." , "_" )
170+ else :
171+ new_gm .register_parameter (name , stuff )
172+ else :
173+ raise AssertionError ("encountered not registered param/buffer" )
174+
175+ count = 0
176+ inp_pos_to_param_buffer_name = {}
177+ for node in new_gm .graph .nodes :
178+ if node .op == "placeholder" :
179+ if node .name in ep .graph_signature .inputs_to_buffers :
180+ buffer_name = ep .graph_signature .inputs_to_buffers [node .name ]
181+ if buffer_name in param_buffer_name_to_corrected_name :
182+ inp_pos_to_param_buffer_name [
183+ count
184+ ] = param_buffer_name_to_corrected_name [buffer_name ]
185+ else :
186+ inp_pos_to_param_buffer_name [count ] = buffer_name
187+ if node .name in ep .graph_signature .inputs_to_parameters :
188+ param_name = ep .graph_signature .inputs_to_parameters [node .name ]
189+ if param_name in param_buffer_name_to_corrected_name :
190+ inp_pos_to_param_buffer_name [
191+ count
192+ ] = param_buffer_name_to_corrected_name [param_name ]
193+ else :
194+ inp_pos_to_param_buffer_name [count ] = param_name
195+ count += 1
196+ new_gm = _unlift (
197+ new_gm ,
198+ inp_pos_to_param_buffer_name ,
199+ ep .call_spec .in_spec ,
200+ ep .call_spec .out_spec ,
201+ ep .state_dict ,
202+ )
203+ return new_gm
204+
205+
58206@compatibility (is_backward_compatible = False )
59207@dataclass
60208class CaptureConfig :
@@ -63,6 +211,7 @@ class CaptureConfig:
63211 enable_dynamic_shape : bool = False
64212 enable_aot : bool = False
65213 _dynamo_config : "ExirDynamoConfig" = ExirDynamoConfig ()
214+ _unlift : bool = False
66215
67216
68217@compatibility (is_backward_compatible = False )
@@ -400,8 +549,15 @@ def capture(
400549 "Functionalization is required for enable_aot." ,
401550 )
402551
403- ep = export (f , args , _add_runtime_assertions = False , constraints = constraints )
404- return ep # pyre-ignore
552+ # TODO remove this later
553+ with patch ("torch._export.DECOMP_TABLE" , _default_decomposition_table ()):
554+ ep = export (
555+ f , args , _add_runtime_assertions = False , constraints = constraints
556+ )
557+ ep = ep .transform (ReplaceViewOpsWithViewCopyOpsPass ())
558+ if not config ._unlift :
559+ return ep # pyre-ignore
560+ graph_module = unlift_exported_program_lifted_states (ep )
405561
406562 elif config .enable_dynamic_shape :
407563 if not config ._dynamo_config .dynamic_shapes :
0 commit comments