diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 548320bd854e..c3bf8f045410 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -132,6 +132,54 @@ def convert(node: fx.Node) -> relax.Var: return convert + ########## Binary Ops ############## + + def _binary_op_inplace(self, relax_op: Callable, intrinsic_op: Callable) -> Callable: + from torch import fx + + def convert(node: fx.Node) -> relax.Var: + def promote_binary_op_args(lhs, rhs): + if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr): + return lhs, rhs + elif isinstance(lhs, relax.Expr): + assert isinstance(lhs.struct_info, relax.TensorStructInfo) + return lhs, relax.const(rhs, lhs.struct_info.dtype) + elif isinstance(rhs, relax.Expr): + assert isinstance(rhs.struct_info, relax.TensorStructInfo) + return relax.const(lhs, rhs.struct_info.dtype), rhs + else: + assert False + + def call_binary_op(op, lhs, rhs): + lhs, rhs = promote_binary_op_args(lhs, rhs) + return self.block_builder.emit(op(lhs, rhs)) + + lhs, rhs = self.retrieve_args(node) + if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): + output = call_binary_op(relax_op, lhs, rhs) + self.env[node.args[0]] = output + return output + + elif isinstance(lhs, relax.expr.Constant): + output = call_binary_op( + relax_op, lhs, relax.const(rhs, dtype=lhs.struct_info.dtype) + ) + self.env[node.args[0]] = output + return output + + elif isinstance(rhs, relax.expr.Constant): + output = call_binary_op( + relax_op, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs + ) + self.env[node.args[0]] = output + return output + + output = intrinsic_op(lhs, rhs) + self.env[node.args[0]] = output + return output + + return convert + ########## Neural Network ########## def _adaptive_avg_pool2d_module(self, node: fx.Node) -> relax.Var: @@ -679,7 +727,7 @@ def create_convert_map( # binary "add": self._binary_op(relax.op.add, operator.add), "and_": self._binary_op(relax.op.bitwise_and, operator.and_), - "bitwise_or_": self._binary_op(relax.op.bitwise_or, operator.or_), + "bitwise_or_": self._binary_op_inplace(relax.op.bitwise_or, operator.or_), "bitwise_or": self._binary_op(relax.op.bitwise_or, operator.or_), "eq": self._binary_op(relax.op.equal, operator.eq), "floordiv": self._binary_op(relax.op.floor_divide, operator.floordiv),