Skip to content

Commit 47b95ca

Browse files
[Relax][FRONTEND][Pytorch] Add fmod support (#17893)
* Add fmod support
* Fix lint issue
1 parent b1d1cdc commit 47b95ca

File tree

5 files changed

+19
-0
lines changed

5 files changed

+19
-0
lines changed

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,20 @@ def call_binary_op(op, lhs, rhs):
409409

410410
return convert
411411

412+
def _fmod(self, node: fx.Node) -> relax.Var:
    """Convert a torch ``fmod`` call into a Relax ``mod`` op.

    Scalar operands are promoted to Relax constants using the dtype of the
    tensor operand, mirroring torch's scalar-broadcast behavior.

    NOTE(review): this maps fmod onto ``relax.op.mod`` — presumably Relax's
    ``mod`` uses truncation semantics matching ``torch.fmod`` (as opposed to
    ``floor_mod``); confirm against the Relax op documentation.

    Parameters
    ----------
    node : fx.Node
        The FX node whose first two args are the lhs/rhs operands; each may
        be a ``relax.Expr`` or a Python scalar.

    Returns
    -------
    relax.Var
        The emitted ``relax.op.mod(lhs, rhs)`` result.

    Raises
    ------
    TypeError
        If neither operand is a ``relax.Expr`` (a raise survives
        ``python -O``, unlike the previous ``assert False``).
    """
    args = self.retrieve_args(node)
    lhs, rhs = args[0], args[1]
    lhs_is_expr = isinstance(lhs, relax.Expr)
    rhs_is_expr = isinstance(rhs, relax.Expr)
    if not lhs_is_expr and not rhs_is_expr:
        raise TypeError(
            f"fmod expects at least one relax.Expr operand, got {node.args}"
        )
    # Promote the scalar side (if any) to a constant of the tensor's dtype.
    if not rhs_is_expr:
        rhs = relax.const(rhs, lhs.struct_info.dtype)
    elif not lhs_is_expr:
        lhs = relax.const(lhs, rhs.struct_info.dtype)
    # Single emit point replaces the original's duplicated emit call for the
    # both-Expr fast path; behavior is identical.
    return self.block_builder.emit(relax.op.mod(lhs, rhs))
412426
def _rsub(self, node: fx.Node) -> relax.Var:
413427
args = self.retrieve_args(node)
414428
lhs = args[0]

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,8 @@ def create_convert_map(
345345
"eq.Scalar": self._binary_op(relax.op.equal, operator.eq),
346346
"eq.Tensor": self._binary_op(relax.op.equal, operator.eq),
347347
"floor_divide.default": self._binary_op(relax.op.floor_divide, operator.floordiv),
348+
"fmod.Scalar": self._fmod,
349+
"fmod.Tensor": self._fmod,
348350
"logaddexp.default": self._binary_op(relax.op.log_add_exp, torch.logaddexp),
349351
"ge.Scalar": self._binary_op(relax.op.greater_equal, operator.ge),
350352
"ge.Tensor": self._binary_op(relax.op.greater_equal, operator.ge),

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -732,6 +732,7 @@ def create_convert_map(
732732
"bitwise_or": self._binary_op(relax.op.bitwise_or, operator.or_),
733733
"eq": self._binary_op(relax.op.equal, operator.eq),
734734
"floordiv": self._binary_op(relax.op.floor_divide, operator.floordiv),
735+
"fmod": self._fmod,
735736
"ge": self._binary_op(relax.op.greater_equal, operator.ge),
736737
"gt": self._binary_op(relax.op.greater, operator.gt),
737738
"iadd": self._binary_op(relax.op.add, operator.add),

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -852,6 +852,7 @@ def main(
852852
(torch.ops.aten.mul_, R.multiply),
853853
(operator.truediv, R.divide),
854854
(operator.floordiv, R.floor_divide),
855+
(torch.ops.aten.fmod, R.mod),
855856
(operator.pow, R.power),
856857
(operator.mod, R.floor_mod),
857858
(operator.and_, R.bitwise_and),

tests/python/relax/test_frontend_from_fx.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1653,6 +1653,7 @@ def main(
16531653
(operator.mul, R.multiply),
16541654
(operator.truediv, R.divide),
16551655
(operator.floordiv, R.floor_divide),
1656+
(torch.ops.aten.fmod, R.mod),
16561657
(operator.pow, R.power),
16571658
(operator.mod, R.floor_mod),
16581659
]

0 commit comments

Comments
 (0)