Skip to content

File tree

6 files changed

+137
-2
lines changed

6 files changed

+137
-2
lines changed

include/tvm/tir/op.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -701,6 +701,7 @@ TVM_DECLARE_INTRIN_UNARY(rsqrt);
701701
TVM_DECLARE_INTRIN_UNARY(log);
702702
TVM_DECLARE_INTRIN_UNARY(log2);
703703
TVM_DECLARE_INTRIN_UNARY(log10);
704+
TVM_DECLARE_INTRIN_UNARY(log1p);
704705
TVM_DECLARE_INTRIN_UNARY(popcount);
705706
TVM_DECLARE_INTRIN_UNARY(tan);
706707
TVM_DECLARE_INTRIN_UNARY(cos);

python/tvm/tir/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,15 @@
5050
from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace
5151
from .op import tvm_stack_alloca, tvm_stack_make_shape, tvm_stack_make_array
5252
from .op import tvm_tuple, tvm_struct_get, tvm_struct_set
53-
from .op import assume, undef
53+
from .op import address_of, lookup_param, assume, undef
54+
from .op import infinity, reinterpret
5455
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
5556
from .op import sin, sinh, asin, asinh
5657
from .op import cos, cosh, acos, acosh
5758
from .op import tan, tanh, atan, atan2, atanh
5859
from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot
5960
from .op import trunc, abs, round, nextafter, nearbyint, power, popcount, fmod, if_then_else
60-
from .op import likely, isnan, isfinite, isinf, copysign
61+
from .op import likely, isnan, isnullptr, isfinite, isinf, copysign
6162
from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, ceildiv
6263
from .op import comm_reducer, min, max, sum
6364
from .op import q_multiply_shift

python/tvm/tir/op.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,44 @@ def tvm_struct_set(arr, index, field, value):
466466
return call_intrin("handle", "tir.tvm_struct_set", arr, index, field, value)
467467

468468

469+
def address_of(buffer_load, span=None):
470+
"""Returns the address of an element in the buffer
471+
472+
Parameters
473+
----------
474+
buffer_load: BufferLoad
475+
The buffer load.
476+
477+
span : Optional[Span]
478+
The location of this operator in the source code.
479+
480+
Returns
481+
-------
482+
call : PrimExpr
483+
The call expression.
484+
"""
485+
return call_intrin("handle", "tir.address_of", buffer_load, span=span)
486+
487+
488+
def lookup_param(param_name, span=None):
489+
"""Returns the param by name
490+
491+
Parameters
492+
----------
493+
param_name : str
494+
The name of param.
495+
496+
span : Optional[Span]
497+
The location of this operator in the source code.
498+
499+
Returns
500+
-------
501+
call : PrimExpr
502+
The call expression.
503+
"""
504+
return call_intrin("handle", "tir.lookup_param", param_name, span=span)
505+
506+
469507
def ret(val):
470508
"""Create a tir return expression
471509
@@ -610,6 +648,47 @@ def max_value(dtype: str, span: Optional[Span] = None) -> Any:
610648
return _ffi_api.max_value(dtype, span) # type: ignore
611649

612650

651+
def infinity(dtype: str, span: Optional[Span] = None) -> Any:
652+
"""infinity value of dtype
653+
654+
Parameters
655+
----------
656+
dtype : str
657+
The data type.
658+
659+
span : Optional[Span]
660+
The location of this operator in the source code.
661+
662+
Returns
663+
-------
664+
value : tvm.Expr
665+
The infinity value of dtype.
666+
"""
667+
return _ffi_api.infinity(dtype, span) # type: ignore
668+
669+
670+
def reinterpret(dtype, value) -> Any:
671+
"""infinity value of dtype
672+
673+
Parameters
674+
----------
675+
dtype : str
676+
The data type.
677+
678+
value : PrimExpr
679+
The input value.
680+
681+
span : Optional[Span]
682+
The location of this operator in the source code.
683+
684+
Returns
685+
-------
686+
value : tvm.Expr
687+
The reinterpret cast value of dtype.
688+
"""
689+
return call_intrin(dtype, "tir.reinterpret", value)
690+
691+
613692
def exp(x):
614693
"""Take exponential of input x.
615694
@@ -1253,6 +1332,25 @@ def isnan(x, span=None):
12531332
return _ffi_api.isnan(x, span) # type: ignore
12541333

12551334

1335+
def isnullptr(x, span=None):
1336+
"""Check if input value is nullptr.
1337+
1338+
Parameters
1339+
----------
1340+
x : PrimExpr
1341+
Input argument.
1342+
1343+
span : Optional[Span]
1344+
The location of this operator in the source code.
1345+
1346+
Returns
1347+
-------
1348+
y : PrimExpr
1349+
The result.
1350+
"""
1351+
return call_intrin("bool", "tir.isnullptr", x, span=span) # type: ignore
1352+
1353+
12561354
def isfinite(x, span=None):
12571355
"""Check if input value is finite.
12581356

src/tir/op/op.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -929,6 +929,8 @@ TVM_REGISTER_GLOBAL("tir.min_value").set_body_typed(min_value);
929929

930930
TVM_REGISTER_GLOBAL("tir.max_value").set_body_typed(max_value);
931931

932+
TVM_REGISTER_GLOBAL("tir.infinity").set_body_typed(infinity);
933+
932934
TVM_REGISTER_GLOBAL("tir.abs").set_body_typed(tvm::abs);
933935

934936
TVM_REGISTER_GLOBAL("tir.likely").set_body_typed(tvm::likely);

tests/python/unittest/test_tir_nodes.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,12 @@ def test_divide_by_zero():
301301
pass
302302

303303

304+
def test_infinity():
305+
assert str(tvm.tir.infinity("float16")) == "inff16"
306+
assert str(tvm.tir.infinity("float32")) == "inff32"
307+
assert str(tvm.tir.infinity("float64")) == "inff64"
308+
309+
304310
def test_isnan():
305311
x = te.var("x", "float32")
306312
assert str(tvm.tir.isnan(x)) == "@tir.isnan(x: float32, dtype=bool)"

tests/python/unittest/test_tir_op_types.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,29 @@ def test_tir_op_tvm_struct_set():
3939
assert expr.op.name == "tir.tvm_struct_set"
4040

4141

42+
def test_tir_op_address_of():
43+
buffer = tir.decl_buffer((128), "float32")
44+
expr = tir.address_of(buffer[0])
45+
assert expr.op.name == "tir.address_of"
46+
47+
48+
def test_tir_op_lookup_param():
49+
expr = tir.lookup_param("p0")
50+
assert expr.op.name == "tir.lookup_param"
51+
52+
53+
def test_tir_op_reinterpret():
54+
x = tir.Var("x", dtype="int32")
55+
expr = tir.reinterpret("float32", x)
56+
assert expr.op.name == "tir.reinterpret"
57+
58+
59+
def test_tir_op_isnullptr():
60+
x = tir.Var("x", dtype="int32")
61+
expr = tir.isnullptr(x)
62+
assert expr.op.name == "tir.isnullptr"
63+
64+
4265
def test_tir_op_call_assume():
4366
x = tir.Var("x", dtype="int32")
4467
expr = tir.assume(cond=x)
@@ -60,6 +83,10 @@ def test_tir_op_call_likely():
6083
test_tir_op_tvm_tuple()
6184
test_tir_op_tvm_struct_get()
6285
test_tir_op_tvm_struct_set()
86+
test_tir_op_address_of()
87+
test_tir_op_lookup_param()
88+
test_tir_op_reinterpret()
89+
test_tir_op_isnullptr()
6390
test_tir_op_call_assume()
6491
test_tir_op_call_undef()
6592
test_tir_op_call_likely()

0 commit comments

Comments
 (0)