diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 098c13f04e9d..32c98efa6963 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -1460,6 +1460,8 @@ def ret(val): ret : PrimExpr The return expression """ + + val = convert(val) return call_intrin(val.dtype, "tir.ret", val) @@ -1645,6 +1647,7 @@ def exp(x): y : PrimExpr The result. """ + x = convert(x) return call_intrin(x.dtype, "tir.exp", x) @@ -1661,6 +1664,7 @@ def exp2(x): y : PrimExpr The result. """ + x = convert(x) return call_intrin(x.dtype, "tir.exp2", x) @@ -1677,6 +1681,7 @@ def exp10(x): y : PrimExpr The result. """ + x = convert(x) return call_intrin(x.dtype, "tir.exp10", x) @@ -1693,6 +1698,7 @@ def erf(x): y : PrimExpr The result. """ + x = convert(x) return call_intrin(x.dtype, "tir.erf", x) @@ -1709,6 +1715,7 @@ def tanh(x): y : PrimExpr The result. """ + x = convert(x) return call_intrin(x.dtype, "tir.tanh", x) @@ -1725,6 +1732,7 @@ def sigmoid(x): y : PrimExpr The result. """ + x = convert(x) return call_intrin(x.dtype, "tir.sigmoid", x) @@ -1741,6 +1749,7 @@ def log(x): y : PrimExpr The result. """ + x = convert(x) return call_intrin(x.dtype, "tir.log", x) @@ -1757,6 +1766,7 @@ def log2(x): y : PrimExpr The result. """ + x = convert(x) return call_intrin(x.dtype, "tir.log2", x) @@ -1773,6 +1783,7 @@ def log10(x): y : PrimExpr The result. """ + x = convert(x) return call_intrin(x.dtype, "tir.log10", x) @@ -1789,6 +1800,7 @@ def log1p(x): y : PrimExpr The result. """ + x = convert(x) return call_intrin(x.dtype, "tir.log1p", x) @@ -1805,6 +1817,7 @@ def tan(x): y : PrimExpr The result. """ + x = convert(x) return call_intrin(x.dtype, "tir.tan", x) @@ -1821,6 +1834,7 @@ def cos(x): y : PrimExpr The result. """ + x = convert(x) return call_intrin(x.dtype, "tir.cos", x) @@ -1837,6 +1851,7 @@ def cosh(x): y : PrimExpr The result. """ + x = convert(x) return call_intrin(x.dtype, "tir.cosh", x) @@ -1853,6 +1868,7 @@ def acos(x): y : PrimExpr The result. """ + x = convert(x) return call_intrin(x.dtype, "tir.acos", x) @@ -1869,6 +1885,7 @@ def acosh(x): y : PrimExpr The result. """ + x = convert(x) return call_intrin(x.dtype, "tir.acosh", x) @@ -1885,6 +1902,7 @@ def sin(x): y : PrimExpr The result. """ + x = convert(x) return call_intrin(x.dtype, "tir.sin", x) @@ -1901,6 +1919,7 @@ def sinh(x): y : PrimExpr The result. """ + x = convert(x) return call_intrin(x.dtype, "tir.sinh", x) @@ -1917,6 +1936,7 @@ def asin(x): y : PrimExpr The result. """ + x = convert(x) return call_intrin(x.dtype, "tir.asin", x) @@ -1933,6 +1953,7 @@ def asinh(x): y : PrimExpr The result. """ + x = convert(x) return call_intrin(x.dtype, "tir.asinh", x) @@ -1949,6 +1970,7 @@ def atan(x): y : PrimExpr The result. """ + x = convert(x) return call_intrin(x.dtype, "tir.atan", x) @@ -1965,6 +1987,7 @@ def atanh(x): y : PrimExpr The result. """ + x = convert(x) return call_intrin(x.dtype, "tir.atanh", x) @@ -1984,6 +2007,8 @@ def atan2(x1, x2): y : PrimExpr The result. """ + x1 = convert(x1) + x2 = convert(x2) return call_intrin(x1.dtype, "tir.atan2", x1, x2) @@ -2000,6 +2025,7 @@ def sqrt(x): y : PrimExpr The result. """ + x = convert(x) return call_intrin(x.dtype, "tir.sqrt", x) @@ -2016,6 +2042,7 @@ def rsqrt(x): y : PrimExpr The result. """ + x = convert(x) return call_intrin(x.dtype, "tir.rsqrt", x) @@ -2261,6 +2288,8 @@ def nextafter(x1, x2): y : PrimExpr The result. """ + x1 = convert(x1) + x2 = convert(x2) return call_intrin(x1.dtype, "tir.nextafter", x1, x2) # type: ignore @@ -2280,6 +2309,8 @@ def hypot(x1, x2): y : PrimExpr The result. """ + x1 = convert(x1) + x2 = convert(x2) return call_intrin(x1.dtype, "tir.hypot", x1, x2) # type: ignore @@ -2299,6 +2330,8 @@ def copysign(x1, x2): y : PrimExpr The result. """ + x1 = convert(x1) + x2 = convert(x2) return call_intrin(x1.dtype, "tir.copysign", x1, x2) # type: ignore @@ -2318,6 +2351,8 @@ def ldexp(x1, x2): y : PrimExpr The result. """ + x1 = convert(x1) + x2 = convert(x2) return call_intrin(x1.dtype, "tir.ldexp", x1, x2) # type: ignore @@ -2474,6 +2509,7 @@ def popcount(x): y : PrimExpr The result. """ + x = convert(x) return call_intrin(x.dtype, "tir.popcount", x) @@ -2605,6 +2641,8 @@ def fmod(x, y): z : PrimExpr The result. """ + x = convert(x) + y = convert(y) return call_intrin(x.dtype, "tir.fmod", x, y) diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 58be4e14d0a8..d36641dfc28f 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3882,6 +3882,62 @@ def subroutine(): return mod +def return_zero(): + @T.prim_func + def func() -> T.int32: + T.ret(0) + + return func + + +def op_of_literal(): + op_list = [ + (T.exp, 0), + (T.exp2, 0), + (T.exp10, 0), + (T.erf, 0.0), + (T.tanh, 0.0), + (T.sigmoid, 0.0), + (T.log, 0.0), + (T.log2, 0.0), + (T.log1p, 0.0), + (T.tan, 0.0), + (T.cos, 0.0), + (T.acos, 0.0), + (T.acosh, 0.0), + (T.sin, 0.0), + (T.sinh, 0.0), + (T.asin, 0.0), + (T.asinh, 0.0), + (T.atan, 0.0), + (T.atanh, 0.0), + (T.atan2, (1.0, 0.0)), + (T.sqrt, 0.0), + (T.rsqrt, 1.0), + (T.nextafter, (0.0, 1.0)), + (T.hypot, (1.0, 1.0)), + (T.copysign, (1.0, 1.0)), + (T.popcount, 0), + (T.fmod, (1.0, 1.0)), + ] + + def make_ir_generator(op, arg): + def inner(): + call_expr = op(*arg) if isinstance(arg, tuple) else op(arg) + + @T.prim_func + def func(): + T.evaluate(call_expr) + + return func + + inner.__name__ = f"{op.__name__}_of_literal" + return inner + + for op, arg in op_list: + yield make_ir_generator(op, arg) + + ir_generator = tvm.testing.parameter( launch_env_thread, opt_gemm_normalize, @@ -3958,6 +4014,8 @@ def subroutine(): undefined_stride_in_decl_buffer, undefined_elem_offset_in_decl_buffer, subroutine_call_without_arguments, + return_zero, + *op_of_literal(), )