diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 932607287571..65390d803229 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -328,6 +328,10 @@ def create_convert_map( # binary "add.Tensor": self._binary_op(relax.op.add, operator.add), "add_.Tensor": self._binary_op(relax.op.add, operator.add), + "bitwise_or_.Scalar": self._binary_op(relax.op.bitwise_or, operator.or_), + "bitwise_or.Scalar": self._binary_op(relax.op.bitwise_or, operator.or_), + "bitwise_or_.Tensor": self._binary_op(relax.op.bitwise_or, operator.or_), + "bitwise_or.Tensor": self._binary_op(relax.op.bitwise_or, operator.or_), "div.Tensor": self._binary_op(relax.op.divide, operator.truediv), "eq.Scalar": self._binary_op(relax.op.equal, operator.eq), "eq.Tensor": self._binary_op(relax.op.equal, operator.eq), diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 5a34befb9296..8bf6ea82cf88 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -674,6 +674,8 @@ def create_convert_map( # binary "add": self._binary_op(relax.op.add, operator.add), "and_": self._binary_op(relax.op.bitwise_and, operator.and_), + "bitwise_or_": self._binary_op(relax.op.bitwise_or, operator.or_), + "bitwise_or": self._binary_op(relax.op.bitwise_or, operator.or_), "eq": self._binary_op(relax.op.equal, operator.eq), "floordiv": self._binary_op(relax.op.floor_divide, operator.floordiv), "ge": self._binary_op(relax.op.greater_equal, operator.ge), diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 80c0bd5fb4f5..c8ed33eea793 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -845,6 +845,8 @@ def main( operator_binary_1 = [ (operator.add, R.add), (torch.ops.aten.add_, R.add), + (torch.ops.aten.bitwise_or, R.bitwise_or), + (torch.ops.aten.bitwise_or_, R.bitwise_or), (operator.sub, R.subtract), (operator.mul, R.multiply), (torch.ops.aten.mul_, R.multiply), diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index c52255638072..4c6e5475b22b 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -1769,6 +1769,8 @@ def main( operator_binary_3 = [ + (torch.ops.aten.bitwise_or_, R.bitwise_or), + (torch.ops.aten.bitwise_or, R.bitwise_or), (operator.lshift, R.left_shift), (operator.rshift, R.right_shift), (operator.and_, R.bitwise_and),