@@ -34,37 +34,6 @@ class ExportedProgramImporter(BaseFXGraphImporter):
3434
3535 from torch import fx
3636
37- def create_input_vars (
38- self , exported_program : torch .export .ExportedProgram
39- ) -> Tuple [List [relax .Var ], List [relax .Var ]]:
40- """Create relax input vars."""
41- parameters_buffers_constants = []
42- user_inputs = []
43- for spec in exported_program .graph_signature .input_specs :
44- name_hint = spec .arg .name
45- if spec .kind is torch .export .graph_signature .InputKind .CONSTANT_TENSOR :
46- shape = exported_program .tensor_constants [spec .target ].shape
47- torch_dtype = exported_program .tensor_constants [spec .target ].dtype
48- elif spec .kind is torch .export .graph_signature .InputKind .USER_INPUT :
49- for node in exported_program .graph .find_nodes (op = "placeholder" , target = spec .target ):
50- if node .name == name_hint :
51- shape = node .meta ["tensor_meta" ].shape
52- torch_dtype = node .meta ["tensor_meta" ].dtype
53- break
54- else :
55- # PARAMETER or BUFFER
56- shape = exported_program .state_dict [spec .target ].shape
57- torch_dtype = exported_program .state_dict [spec .target ].dtype
58-
59- dtype = self ._convert_data_type (torch_dtype )
60- relax_var = relax .Var (name_hint , relax .TensorStructInfo (shape , dtype ))
61- if spec .kind is torch .export .graph_signature .InputKind .USER_INPUT :
62- user_inputs .append (relax_var )
63- else :
64- parameters_buffers_constants .append (relax_var )
65-
66- return parameters_buffers_constants , user_inputs
67-
6837 ########## Unary Ops ##########
6938
7039 def _hardtanh (self , node : fx .Node ) -> relax .Expr :
@@ -178,6 +147,8 @@ def _slice(self, node: fx.Node) -> relax.Var:
178147 stride = [node .args [4 ] if len (node .args ) > 4 else 1 ]
179148 return self .block_builder .emit (relax .op .strided_slice (x , axes , begin , end , stride ))
180149
150+ ########## Others ##########
151+
181152 def create_convert_map (
182153 self ,
183154 ) -> Dict [str , Callable [[fx .Node ], relax .Var ]]:
@@ -293,6 +264,37 @@ def create_convert_map(
293264 "getitem" : self ._getitem ,
294265 }
295266
267+ def create_input_vars (
268+ self , exported_program : torch .export .ExportedProgram
269+ ) -> Tuple [Dict [str , relax .Var ], Dict [str , relax .Var ]]:
270+ """Create relax input vars."""
271+ parameters_buffers_constants = OrderedDict ()
272+ user_inputs = OrderedDict ()
273+ for spec in exported_program .graph_signature .input_specs :
274+ name_hint = spec .arg .name
275+ if spec .kind is torch .export .graph_signature .InputKind .CONSTANT_TENSOR :
276+ shape = exported_program .tensor_constants [spec .target ].shape
277+ torch_dtype = exported_program .tensor_constants [spec .target ].dtype
278+ elif spec .kind is torch .export .graph_signature .InputKind .USER_INPUT :
279+ for node in exported_program .graph .find_nodes (op = "placeholder" , target = spec .target ):
280+ if node .name == name_hint :
281+ shape = node .meta ["tensor_meta" ].shape
282+ torch_dtype = node .meta ["tensor_meta" ].dtype
283+ break
284+ else :
285+ # PARAMETER or BUFFER
286+ shape = exported_program .state_dict [spec .target ].shape
287+ torch_dtype = exported_program .state_dict [spec .target ].dtype
288+
289+ dtype = self ._convert_data_type (torch_dtype )
290+ relax_var = relax .Var (name_hint , relax .TensorStructInfo (shape , dtype ))
291+ if spec .kind is torch .export .graph_signature .InputKind .USER_INPUT :
292+ user_inputs [name_hint ] = relax_var
293+ else :
294+ parameters_buffers_constants [name_hint ] = relax_var
295+
296+ return parameters_buffers_constants , user_inputs
297+
296298 def from_exported_program (
297299 self ,
298300 exported_program : torch .export .ExportedProgram ,
@@ -305,7 +307,8 @@ def from_exported_program(
305307
306308 # Create input variables.
307309 parameter_buffer_constant_vars , user_input_vars = self .create_input_vars (exported_program )
308- inputs_vars = parameter_buffer_constant_vars + user_input_vars
310+ inputs_vars = user_input_vars .copy ()
311+ inputs_vars .update (parameter_buffer_constant_vars )
309312
310313 # Initialize the block builder with a function and a dataflow block.
311314 self .block_builder = relax .BlockBuilder ()
@@ -314,7 +317,7 @@ def from_exported_program(
314317
315318 nodes : List [fx .Node ] = exported_program .graph .nodes
316319 with self .block_builder .function (
317- name = func_name , params = inputs_vars .copy (), attrs = func_attrs
320+ name = func_name , params = list ( inputs_vars . values ()) .copy (), attrs = func_attrs
318321 ):
319322 output = None
320323 with self .block_builder .dataflow ():
@@ -325,7 +328,7 @@ def from_exported_program(
325328 # Ignore sym input
326329 continue
327330
328- self .env [node ] = inputs_vars . pop ( 0 )
331+ self .env [node ] = inputs_vars [ node . name ]
329332 elif node .op == "output" :
330333 args = self .retrieve_args (node )
331334 assert len (args ) == 1
0 commit comments