From ca92086370cfa2283f840aeb479fd1ab89d74c2c Mon Sep 17 00:00:00 2001 From: Kavin-mcw Date: Tue, 22 Apr 2025 12:01:24 +0530 Subject: [PATCH 1/3] Fix bug of inplace binary_op --- .../tvm/relax/frontend/torch/fx_translator.py | 43 ++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 548320bd854e..513d71ca895e 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -132,6 +132,47 @@ def convert(node: fx.Node) -> relax.Var: return convert + ########## Binary Ops ############## + + def _binary_op_inplace(self, relax_op: Callable, intrinsic_op: Callable) -> Callable: + from torch import fx + + def convert(node: fx.Node) -> relax.Var: + def promote_binary_op_args(lhs, rhs): + if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr): + return lhs, rhs + elif isinstance(lhs, relax.Expr): + assert isinstance(lhs.struct_info, relax.TensorStructInfo) + return lhs, relax.const(rhs, lhs.struct_info.dtype) + elif isinstance(rhs, relax.Expr): + assert isinstance(rhs.struct_info, relax.TensorStructInfo) + return relax.const(lhs, rhs.struct_info.dtype), rhs + else: + assert False + + def call_binary_op(op, lhs, rhs): + lhs, rhs = promote_binary_op_args(lhs, rhs) + return self.block_builder.emit(op(lhs, rhs)) + + lhs, rhs = self.retrieve_args(node) + if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): + output = call_binary_op(relax_op, lhs, rhs) + self.env[node.args[0]] = output + return output + elif isinstance(lhs, relax.expr.Constant): + output = call_binary_op(relax_op, lhs, relax.const(rhs, dtype=lhs.struct_info.dtype)) + self.env[node.args[0]] = output + return output + elif isinstance(rhs, relax.expr.Constant): + output = call_binary_op(relax_op, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs) + self.env[node.args[0]] = output + return output + output = intrinsic_op(lhs, rhs) + self.env[node.args[0]] = output + return output + + return convert + ########## Neural Network ########## def _adaptive_avg_pool2d_module(self, node: fx.Node) -> relax.Var: @@ -679,7 +720,7 @@ def create_convert_map( # binary "add": self._binary_op(relax.op.add, operator.add), "and_": self._binary_op(relax.op.bitwise_and, operator.and_), - "bitwise_or_": self._binary_op(relax.op.bitwise_or, operator.or_), + "bitwise_or_": self._binary_op_inplace(relax.op.bitwise_or, operator.or_), "bitwise_or": self._binary_op(relax.op.bitwise_or, operator.or_), "eq": self._binary_op(relax.op.equal, operator.eq), "floordiv": self._binary_op(relax.op.floor_divide, operator.floordiv), From acedd69838d08e0fe77a74fa55a71e3dc0a1801a Mon Sep 17 00:00:00 2001 From: Kavin-mcw Date: Tue, 22 Apr 2025 12:11:07 +0530 Subject: [PATCH 2/3] format changes --- python/tvm/relax/frontend/torch/fx_translator.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 513d71ca895e..680f07c1337f 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -159,14 +159,17 @@ def call_binary_op(op, lhs, rhs): output = call_binary_op(relax_op, lhs, rhs) self.env[node.args[0]] = output return output + elif isinstance(lhs, relax.expr.Constant): output = call_binary_op(relax_op, lhs, relax.const(rhs, dtype=lhs.struct_info.dtype)) self.env[node.args[0]] = output return output + elif isinstance(rhs, relax.expr.Constant): output = call_binary_op(relax_op, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs) self.env[node.args[0]] = output return output + output = intrinsic_op(lhs, rhs) self.env[node.args[0]] = output return output From 70be6173cc49abc983481c0929f5ff4e7b5b4f27 Mon Sep 17 00:00:00 2001 From: Kavin-mcw Date: Tue, 22 Apr 2025 12:30:00 +0530 Subject: [PATCH 3/3] Fix lint issue --- python/tvm/relax/frontend/torch/fx_translator.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 680f07c1337f..c3bf8f045410 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -161,12 +161,16 @@ def call_binary_op(op, lhs, rhs): return output elif isinstance(lhs, relax.expr.Constant): - output = call_binary_op(relax_op, lhs, relax.const(rhs, dtype=lhs.struct_info.dtype)) + output = call_binary_op( + relax_op, lhs, relax.const(rhs, dtype=lhs.struct_info.dtype) + ) self.env[node.args[0]] = output return output elif isinstance(rhs, relax.expr.Constant): - output = call_binary_op(relax_op, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs) + output = call_binary_op( + relax_op, relax.const(lhs, dtype=rhs.struct_info.dtype), rhs + ) self.env[node.args[0]] = output return output