@@ -119,23 +119,6 @@ def _retrieve_args(self, node):
119119 else :
120120 return node
121121
122- @staticmethod
123- def _promote_binary_op_args (lhs , rhs ):
124- if isinstance (lhs , relax .Expr ) and isinstance (rhs , relax .Expr ):
125- return lhs , rhs
126- elif isinstance (lhs , relax .Expr ):
127- assert isinstance (lhs .struct_info , relax .TensorStructInfo )
128- return lhs , relax .const (rhs , lhs .struct_info .dtype )
129- elif isinstance (rhs , relax .Expr ):
130- assert isinstance (rhs .struct_info , relax .TensorStructInfo )
131- return relax .const (lhs , rhs .struct_info .dtype ), rhs
132- else :
133- assert False
134-
135- def _call_binary_op (self , op , lhs , rhs ):
136- lhs , rhs = TorchFXImporter ._promote_binary_op_args (lhs , rhs )
137- return self .block_builder .emit (op (lhs , rhs ))
138-
139122 ########## Unary Ops ##########
140123
141124 def _unary_op (self , op : Callable ) -> Callable :
@@ -246,17 +229,29 @@ def _binary_op(self, relax_op: Callable, intrinsic_op: Callable) -> Callable:
246229 from torch import fx
247230
248231 def convert (node : fx .Node ) -> relax .Var :
232+ def promote_binary_op_args (lhs , rhs ):
233+ if isinstance (lhs , relax .Expr ) and isinstance (rhs , relax .Expr ):
234+ return lhs , rhs
235+ elif isinstance (lhs , relax .Expr ):
236+ assert isinstance (lhs .struct_info , relax .TensorStructInfo )
237+ return lhs , relax .const (rhs , lhs .struct_info .dtype )
238+ elif isinstance (rhs , relax .Expr ):
239+ assert isinstance (rhs .struct_info , relax .TensorStructInfo )
240+ return relax .const (lhs , rhs .struct_info .dtype ), rhs
241+ else :
242+ assert False
243+
244+ def call_binary_op (op , lhs , rhs ):
245+ lhs , rhs = promote_binary_op_args (lhs , rhs )
246+ return self .block_builder .emit (op (lhs , rhs ))
247+
249248 lhs , rhs = self .retrieve_args (node )
250249 if isinstance (lhs , relax .Var ) or isinstance (rhs , relax .Var ):
251- return self . _call_binary_op (relax_op , lhs , rhs )
250+ return call_binary_op (relax_op , lhs , rhs )
252251 elif isinstance (lhs , relax .expr .Constant ):
253- return self ._call_binary_op (
254- relax_op , lhs , relax .const (rhs , dtype = lhs .struct_info .dtype )
255- )
252+ return call_binary_op (relax_op , lhs , relax .const (rhs , dtype = lhs .struct_info .dtype ))
256253 elif isinstance (rhs , relax .expr .Constant ):
257- return self ._call_binary_op (
258- relax_op , relax .const (lhs , dtype = rhs .struct_info .dtype ), rhs
259- )
254+ return call_binary_op (relax_op , relax .const (lhs , dtype = rhs .struct_info .dtype ), rhs )
260255 return intrinsic_op (lhs , rhs )
261256
262257 return convert
0 commit comments