@@ -145,10 +145,11 @@ def add_submodule(self, module: torch.nn.Module) -> None:
145145 self .update_subpath (module , new_module_name )
146146 # self.written = True # not mark as written as graph break may happen
147147
def add_subparam(self, param: torch.nn.Parameter) -> str:
    """Register *param* on the root module under a generated name.

    Args:
        param: the externally-created parameter to expose on ``self.root``
            so that generated graph code can reference it via ``get_attr``.

    Returns:
        The attribute name under which the parameter is reachable on
        ``self.root``.
    """
    # Idempotence guard: re-adding a known parameter must return the
    # existing name instead of registering a second alias on root
    # (the old code also overwrote the subparam_paths entry, orphaning
    # the first registration).
    if param in self.subparam_paths:
        return self.subparam_paths[param]
    # Names are sequential; len() of the path map is the next free index.
    new_param_name = "external_param__" + str(len(self.subparam_paths))
    self.root.register_parameter(new_param_name, param)
    self.subparam_paths[param] = new_param_name
    return new_param_name
153154 def as_node_args_kwargs (
154155 self , args : list [Any ], kwargs : dict [str , Any ]
@@ -172,6 +173,11 @@ def as_fx_node(arg: Any) -> NodeArgs:
172173 if isinstance (arg , slice ):
173174 return slice (as_fx_node (arg .start ), as_fx_node (arg .stop ),
174175 as_fx_node (arg .step ))
176+ if isinstance (arg , np .ndarray ):
177+ param_name = self .add_subparam (
178+ torch .nn .Parameter (torch .tensor (arg ), requires_grad = False ))
179+ return self .fx_graph .create_node ("get_attr" , param_name , (), {})
180+
175181 var = self .objects .get (arg ,
176182 allow_unexist_const = True ,
177183 fx_graph = self .fx_graph )
@@ -192,6 +198,9 @@ def as_fx_node(arg: Any) -> NodeArgs:
192198 else :
193199 # TODO: record all operation in SymInt or SymFloat
194200 pass
201+
202+ if f"{ type (arg ).__module__ } .{ type (arg ).__qualname__ } " == "torch.tensortype" : # torch.LongTensor
203+ return f"torch.{ arg .__name__ } "
195204 return var .as_fx_node ()
196205
197206 if isinstance (args , torch .Tensor ):
@@ -225,6 +234,19 @@ def record_function(self,
225234 add_partial_var : bool = True ,
226235 inplace_ref : Any = None ,
227236 force_new_value : bool = False ) -> None :
237+ if hasattr (func , '__self__' ) and isinstance (
238+ func .__self__ , torch .autograd .grad_mode .no_grad ):
239+ if func .__name__ == '__enter__' :
240+ target_state = False
241+ elif func .__name__ == '__exit__' :
242+ target_state = func .__self__ .prev
243+ else :
244+ raise ValueError (func )
245+ args = [
246+ target_state ,
247+ ]
248+ func = torch ._C ._set_grad_enabled
249+ kwargs = {}
228250 pargs , pkwargs = self .as_node_args_kwargs (args , kwargs )
229251 if func in fx_graph_inplace_functions :
230252 scalar = None
@@ -268,6 +290,8 @@ def record_function(self,
268290 func = func_dict [func ]
269291 if func in math2torch :
270292 func = math2torch [func ]
293+ if func == torch .from_numpy :
294+ func = torch .tensor
271295
272296 self .written = True
273297 scalar2tensor : dict [Callable [..., Any ], Callable [..., Any ]] = {
@@ -1360,7 +1384,6 @@ def make_sub_var(value: Any, fx_node: torch.fx.Node) -> None:
13601384
13611385 self .state .inplace_update_objs .clear ()
13621386 self .state .partial_var .clear ()
1363- print ("clear partial var" )
13641387 self .state .written = False
13651388 self .state .unmark_calling_func ()
13661389 # print('process last instruction done')
def is_builtin_func(self, func: Callable[..., Any]) -> bool:
    """Return True if *func* is one of the Python builtins handled natively."""
    traced_builtins = (dict, tuple, set, list, hasattr, slice, range,
                       len, type)
    return func in traced_builtins
def is_numpy_constant_func(self, func: Callable[..., Any]) -> bool:
    """Return True if *func* is a deterministic ("constant") numpy callable.

    Functions from ``numpy.random`` are excluded: their results differ
    between calls, so they cannot be folded into the captured graph.
    """
    # NOTE: removed leftover debug `print(dir(func))` from the original.
    # `__module__` may be None for some callables; guard before `in`.
    module = getattr(func, '__module__', None)
    if module is not None and 'numpy' in module and 'random' not in module:
        return True
    # ufuncs (np.add, np.sin, ...) are deterministic element-wise ops.
    if type(func) == np.ufunc:
        return True
    return False
14211453 def get_live_objs (self , pc : int = - 1 ) -> list [tuple [str , Any ]]:
14221454 if pc == - 1 :
14231455 pc = self .frame .f_lasti // 2
@@ -1603,6 +1635,8 @@ def set_if_inplace_return() -> None:
16031635 return
16041636 elif len (args ) > 0 and isinstance (args [0 ], torch .nn .ModuleList ):
16051637 return
1638+ elif self .is_numpy_constant_func (func ):
1639+ return
16061640 elif self .has_unknown_arg (args , kwargs ):
16071641 print (
16081642 f"func is { func } , { is_user_defined_func (func )} , args: { args } , kwargs:{ kwargs } "
@@ -1789,7 +1823,9 @@ def SETUP_FINALLY(self, _inst: Instruction) -> None:
17891823 pass
17901824
def SETUP_WITH(self, _inst: Instruction) -> None:
    """Handle SETUP_WITH: trace entry into a ``torch.no_grad()`` block.

    Only ``no_grad`` context managers are traced (via ``call_function``
    on their ``__enter__``); any other manager on the value stack is
    left untouched for the interpreter to run.
    """
    mgr = get_value_stack_from_top(self.frame, 0)
    # isinstance (not `type(...) ==`) so no_grad subclasses trace the same.
    if isinstance(mgr, torch.autograd.grad_mode.no_grad):
        self.call_function(mgr.__enter__, [], {})
17931829
17941830 # def WITH_EXCEPT_START(self, _inst: Instruction) -> None:
17951831 # pass
@@ -1873,9 +1909,9 @@ def LOAD_ATTR(self, inst: Instruction) -> None:
18731909 if inst .argval in obj_var .modified_attrs :
18741910 return
18751911 need_guard_check = obj_var .need_guard_check
1876- if obj == self .state .varargs and inst .argval in dir (tuple ):
1912+ if id ( obj ) == id ( self .state .varargs ) and inst .argval in dir (tuple ):
18771913 need_guard_check = False
1878- if obj == self .state .varkw and inst .argval in dir (dict ):
1914+ if id ( obj ) == id ( self .state .varkw ) and inst .argval in dir (dict ):
18791915 need_guard_check = False
18801916 if config .get_config ('dynshape' ) and isinstance (
18811917 obj , torch .Tensor ) and inst .argval == 'shape' :
@@ -1957,7 +1993,8 @@ def CALL_FUNCTION_KW(self, inst: Instruction) -> None:
19571993 '__self__' ) and func .__self__ is not None and not isinstance (
19581994 func .__self__ , ModuleType ):
19591995 args = [func .__self__ ] + list (args )
1960- # print(f"function kw: {func}, type: {type(func)},args:{args}, kwargs:{kwargs}")
1996+ for i , obj in enumerate (itertools .chain (args , kwargs .values ())):
1997+ self .state .fetch_function_parameters (obj )
19611998 self .call_function (func , args , kwargs )
19621999
19632000 def CALL_FUNCTION_EX (self , inst : Instruction ) -> None :
@@ -1973,6 +2010,9 @@ def CALL_FUNCTION_EX(self, inst: Instruction) -> None:
19732010 '__self__' ) and func .__self__ is not None and not isinstance (
19742011 func .__self__ , ModuleType ):
19752012 args = [func .__self__ ] + list (args )
2013+ if not isinstance (args , torch .Tensor ): # call(*x)
2014+ for i , obj in enumerate (itertools .chain (args , kwargs .values ())):
2015+ self .state .fetch_function_parameters (obj )
19762016 self .call_function (func , args , kwargs )
19772017
19782018 def STORE_FAST (self , inst : Instruction ) -> None :
@@ -2076,19 +2116,15 @@ def IMPORT_FROM(self, inst: Instruction) -> None:
20762116 pass
20772117
def UNPACK_SEQUENCE(self, inst: Instruction) -> None:
    """Handle UNPACK_SEQUENCE: pre-register one partial variable per
    element about to be pushed by the unpack.

    ``inst.argval`` is the element count encoded in the instruction, so
    no peek at the runtime sequence is needed.
    """
    count = inst.argval
    # Each element needs its own PartialVar instance (no shared objects),
    # hence a comprehension rather than list multiplication.
    placeholders = [
        PartialVar(node=None,
                   need_guard_check=False,
                   extract_code_at_start=[],
                   make_var_fn=vs.make_var_from_value)
        for _ in range(count)
    ]
    self.state.set_partial_var({-1: placeholders})
20922128
20932129 def UNPACK_EX (self , inst : Instruction ) -> None :
20942130 seq = get_value_stack_from_top (self .frame , 0 )
0 commit comments