@@ -132,6 +132,54 @@ def convert(node: fx.Node) -> relax.Var:
132132
133133 return convert
134134
135+ ########## Binary Ops ##############
136+
137+ def _binary_op_inplace (self , relax_op : Callable , intrinsic_op : Callable ) -> Callable :
138+ from torch import fx
139+
140+ def convert (node : fx .Node ) -> relax .Var :
141+ def promote_binary_op_args (lhs , rhs ):
142+ if isinstance (lhs , relax .Expr ) and isinstance (rhs , relax .Expr ):
143+ return lhs , rhs
144+ elif isinstance (lhs , relax .Expr ):
145+ assert isinstance (lhs .struct_info , relax .TensorStructInfo )
146+ return lhs , relax .const (rhs , lhs .struct_info .dtype )
147+ elif isinstance (rhs , relax .Expr ):
148+ assert isinstance (rhs .struct_info , relax .TensorStructInfo )
149+ return relax .const (lhs , rhs .struct_info .dtype ), rhs
150+ else :
151+ assert False
152+
153+ def call_binary_op (op , lhs , rhs ):
154+ lhs , rhs = promote_binary_op_args (lhs , rhs )
155+ return self .block_builder .emit (op (lhs , rhs ))
156+
157+ lhs , rhs = self .retrieve_args (node )
158+ if isinstance (lhs , relax .Var ) or isinstance (rhs , relax .Var ):
159+ output = call_binary_op (relax_op , lhs , rhs )
160+ self .env [node .args [0 ]] = output
161+ return output
162+
163+ elif isinstance (lhs , relax .expr .Constant ):
164+ output = call_binary_op (
165+ relax_op , lhs , relax .const (rhs , dtype = lhs .struct_info .dtype )
166+ )
167+ self .env [node .args [0 ]] = output
168+ return output
169+
170+ elif isinstance (rhs , relax .expr .Constant ):
171+ output = call_binary_op (
172+ relax_op , relax .const (lhs , dtype = rhs .struct_info .dtype ), rhs
173+ )
174+ self .env [node .args [0 ]] = output
175+ return output
176+
177+ output = intrinsic_op (lhs , rhs )
178+ self .env [node .args [0 ]] = output
179+ return output
180+
181+ return convert
182+
135183 ########## Neural Network ##########
136184
137185 def _adaptive_avg_pool2d_module (self , node : fx .Node ) -> relax .Var :
@@ -679,7 +727,7 @@ def create_convert_map(
679727 # binary
680728 "add" : self ._binary_op (relax .op .add , operator .add ),
681729 "and_" : self ._binary_op (relax .op .bitwise_and , operator .and_ ),
682- "bitwise_or_" : self ._binary_op (relax .op .bitwise_or , operator .or_ ),
730+ "bitwise_or_" : self ._binary_op_inplace (relax .op .bitwise_or , operator .or_ ),
683731 "bitwise_or" : self ._binary_op (relax .op .bitwise_or , operator .or_ ),
684732 "eq" : self ._binary_op (relax .op .equal , operator .eq ),
685733 "floordiv" : self ._binary_op (relax .op .floor_divide , operator .floordiv ),
0 commit comments