Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions python/tvm/relax/frontend/torch/fx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,23 +660,29 @@ def create_convert_map(
"triu": self._tril_triu(relax.op.triu),
# binary
"add": self._binary_op(relax.op.add, operator.add),
"and_": self._binary_op(relax.op.bitwise_and, operator.and_),
"eq": self._binary_op(relax.op.equal, operator.eq),
"floordiv": self._binary_op(relax.op.floor_divide, operator.floordiv),
"ge": self._binary_op(relax.op.greater_equal, operator.ge),
"gt": self._binary_op(relax.op.greater, operator.gt),
"iadd": self._binary_op(relax.op.add, operator.add),
"le": self._binary_op(relax.op.less_equal, operator.le),
"lshift": self._binary_op(relax.op.left_shift, operator.lshift),
"lt": self._binary_op(relax.op.less, operator.lt),
"matmul": self._binary_op(
partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul
),
"max": self._binary_op(relax.op.maximum, max),
"min": self._binary_op(relax.op.minimum, min),
"mod": self._binary_op(relax.op.mod, operator.mod),
"mul": self._binary_op(relax.op.multiply, operator.mul),
"ne": self._binary_op(relax.op.not_equal, operator.ne),
"pow": self._binary_op(relax.op.power, operator.pow),
"or_": self._binary_op(relax.op.bitwise_or, operator.or_),
"rshift": self._binary_op(relax.op.right_shift, operator.rshift),
"sub": self._binary_op(relax.op.subtract, operator.sub),
"truediv": self._binary_op(relax.op.divide, operator.truediv),
"xor": self._binary_op(relax.op.bitwise_xor, operator.xor),
# neural network
"adaptive_avg_pool2d": self._adaptive_avg_pool2d,
"addmm": self._addmm,
Expand Down
228 changes: 228 additions & 0 deletions tests/python/relax/test_frontend_from_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1485,6 +1485,8 @@ def main(
def test_binary():
input_info1 = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")]
input_info2 = [([1, 3, 10, 10], "float32")]
input_info3 = [([1, 3, 10, 10], "int32"), ([1, 3, 10, 10], "int32")]
input_info4 = [([1, 3, 10, 10], "int32")]

# Add
class Add1(Module):
Expand Down Expand Up @@ -1962,6 +1964,211 @@ def main(
verify_model(Ne1(), input_info1, {}, expected23)
verify_model(Ne2(), input_info2, {}, expected24)

# Lshift
class LShift1(Module):
def forward(self, lhs, rhs):
return lhs << rhs

@tvm.script.ir_module
class expected25:
@R.function
def main(
lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
rhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.left_shift(lhs_1, rhs_1)
gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
R.output(gv)

return gv

class LShift2(Module):
def forward(self, lhs):
return lhs << 1

@tvm.script.ir_module
class expected26:
@R.function
def main(
lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.left_shift(lhs_1, R.const(1))
gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
R.output(gv)

return gv

verify_model(LShift1(), input_info3, {}, expected25)
verify_model(LShift2(), input_info4, {}, expected26)

# Rshift
class RShift1(Module):
def forward(self, lhs, rhs):
return lhs >> rhs

@tvm.script.ir_module
class expected27:
@R.function
def main(
lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
rhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.right_shift(lhs_1, rhs_1)
gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
R.output(gv)

return gv

class RShift2(Module):
def forward(self, lhs):
return lhs >> 1

@tvm.script.ir_module
class expected28:
@R.function
def main(
lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.right_shift(lhs_1, R.const(1))
gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
R.output(gv)

return gv

verify_model(RShift1(), input_info3, {}, expected27)
verify_model(RShift2(), input_info4, {}, expected28)

# Bitwise and
class BitwiseAnd1(Module):
def forward(self, lhs, rhs):
return lhs & rhs

@tvm.script.ir_module
class expected29:
@R.function
def main(
lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
rhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.bitwise_and(lhs_1, rhs_1)
gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
R.output(gv)

return gv

class BitwiseAnd2(Module):
def forward(self, lhs):
return lhs & 1

@tvm.script.ir_module
class expected30:
@R.function
def main(
lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.bitwise_and(lhs_1, R.const(1))
gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
R.output(gv)

return gv

verify_model(BitwiseAnd1(), input_info3, {}, expected29)
verify_model(BitwiseAnd2(), input_info4, {}, expected30)

# Bitwise or
class BitwiseOr1(Module):
def forward(self, lhs, rhs):
return lhs | rhs

@tvm.script.ir_module
class expected31:
@R.function
def main(
lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
rhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.bitwise_or(lhs_1, rhs_1)
gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
R.output(gv)

return gv

class BitwiseOr2(Module):
def forward(self, lhs):
return lhs | 1

@tvm.script.ir_module
class expected32:
@R.function
def main(
lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.bitwise_or(lhs_1, R.const(1))
gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
R.output(gv)

return gv

verify_model(BitwiseOr1(), input_info3, {}, expected31)
verify_model(BitwiseOr2(), input_info4, {}, expected32)

# Bitwise xor
class BitwiseXor1(Module):
def forward(self, lhs, rhs):
return lhs ^ rhs

@tvm.script.ir_module
class expected33:
@R.function
def main(
lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
rhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.bitwise_xor(lhs_1, rhs_1)
gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
R.output(gv)

return gv

class BitwiseXor2(Module):
def forward(self, lhs):
return lhs ^ 1

@tvm.script.ir_module
class expected34:
@R.function
def main(
lhs_1: R.Tensor((1, 3, 10, 10), dtype="int32"),
) -> R.Tensor((1, 3, 10, 10), dtype="int32"):
# block 0
with R.dataflow():
lv: R.Tensor((1, 3, 10, 10), dtype="int32") = R.bitwise_xor(lhs_1, R.const(1))
gv: R.Tensor((1, 3, 10, 10), dtype="int32") = lv
R.output(gv)

return gv

verify_model(BitwiseXor1(), input_info3, {}, expected33)
verify_model(BitwiseXor2(), input_info4, {}, expected34)


def test_size():
input_info = [([1, 3, 10, 10], "float32")]
Expand Down Expand Up @@ -3745,6 +3952,27 @@ def main(
verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")], {}, Expected1)


def test_min():
class Min(Module):
def forward(self, x, y):
return torch.min(x, y)

@I.ir_module
class Expected1:
@R.function
def main(
inp_0: R.Tensor((256, 256), dtype="float32"),
inp_1: R.Tensor((256, 256), dtype="float32"),
) -> R.Tensor((256, 256), dtype="float32"):
with R.dataflow():
lv: R.Tensor((256, 256), dtype="float32") = R.minimum(inp_0, inp_1)
gv: R.Tensor((256, 256), dtype="float32") = lv
R.output(gv)
return gv

verify_model(Min(), [([256, 256], "float32"), ([256, 256], "float32")], {}, Expected1)


def test_attention():
@I.ir_module
class Expected1:
Expand Down
Loading