Skip to content

Commit 3b141b9

Browse files
[TIR] Expose bitwise ops to python (#13945)
Expose `bitwise_and`, `bitwise_or`, `bitwise_not`, `bitwise_xor` to python
1 parent 256bad7 commit 3b141b9

File tree

4 files changed

+107
-0
lines changed

4 files changed

+107
-0
lines changed

python/tvm/script/ir_builder/tir/ir.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1616,6 +1616,10 @@ def wrapped(*args, **kwargs):
16161616
atan = _op_wrapper(_tir_op.atan)
16171617
atan2 = _op_wrapper(_tir_op.atan2)
16181618
atanh = _op_wrapper(_tir_op.atanh)
1619+
bitwise_and = _op_wrapper(_tir_op.bitwise_and)
1620+
bitwise_not = _op_wrapper(_tir_op.bitwise_not)
1621+
bitwise_or = _op_wrapper(_tir_op.bitwise_or)
1622+
bitwise_xor = _op_wrapper(_tir_op.bitwise_xor)
16191623
ceil = _op_wrapper(_tir_op.ceil)
16201624
clz = _op_wrapper(_tir_op.clz)
16211625
copysign = _op_wrapper(_tir_op.copysign)
@@ -1866,6 +1870,10 @@ def wrapped(*args, **kwargs):
18661870
"atan",
18671871
"atan2",
18681872
"atanh",
1873+
"bitwise_and",
1874+
"bitwise_not",
1875+
"bitwise_or",
1876+
"bitwise_xor",
18691877
"ceil",
18701878
"clz",
18711879
"copysign",

python/tvm/tir/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
from .op import sin, sinh, asin, asinh
6969
from .op import cos, cosh, acos, acosh
7070
from .op import tan, tanh, atan, atan2, atanh
71+
from .op import bitwise_and, bitwise_not, bitwise_or, bitwise_xor
7172
from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot
7273
from .op import trunc, abs, round, nextafter, nearbyint, power, pow, popcount, fmod, if_then_else
7374
from .op import likely, isnan, isnullptr, isfinite, isinf, copysign

python/tvm/tir/op.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1997,6 +1997,91 @@ def abs(x, span=None):
19971997
return _ffi_api.abs(x, span) # type: ignore
19981998

19991999

2000+
def bitwise_and(x, y, span=None):
2001+
"""Take bitwise and of two values
2002+
2003+
Parameters
2004+
----------
2005+
x : PrimExpr
2006+
Left operand
2007+
2008+
y : PrimExpr
2009+
Right operand
2010+
2011+
span : Optional[Span]
2012+
The location of this operator in the source code.
2013+
2014+
Returns
2015+
-------
2016+
res : PrimExpr
2017+
The result.
2018+
"""
2019+
return _ffi_api.bitwise_and(x, y, span)
2020+
2021+
2022+
def bitwise_not(x, span=None):
2023+
"""Take bitwise not of input value
2024+
2025+
Parameters
2026+
----------
2027+
x : PrimExpr
2028+
Input operand
2029+
2030+
span : Optional[Span]
2031+
The location of this operator in the source code.
2032+
2033+
Returns
2034+
-------
2035+
res : PrimExpr
2036+
The result.
2037+
"""
2038+
return _ffi_api.bitwise_not(x, span)
2039+
2040+
2041+
def bitwise_or(x, y, span=None):
2042+
"""Take bitwise or of two values
2043+
2044+
Parameters
2045+
----------
2046+
x : PrimExpr
2047+
Left operand
2048+
2049+
y : PrimExpr
2050+
Right operand
2051+
2052+
span : Optional[Span]
2053+
The location of this operator in the source code.
2054+
2055+
Returns
2056+
-------
2057+
res : PrimExpr
2058+
The result.
2059+
"""
2060+
return _ffi_api.bitwise_or(x, y, span)
2061+
2062+
2063+
def bitwise_xor(x, y, span=None):
2064+
"""Take bitwise xor of two values
2065+
2066+
Parameters
2067+
----------
2068+
x : PrimExpr
2069+
Left operand
2070+
2071+
y : PrimExpr
2072+
Right operand
2073+
2074+
span : Optional[Span]
2075+
The location of this operator in the source code.
2076+
2077+
Returns
2078+
-------
2079+
res : PrimExpr
2080+
The result.
2081+
"""
2082+
return _ffi_api.bitwise_xor(x, y, span)
2083+
2084+
20002085
def round(x, span=None):
20012086
"""Round elements of the array to the nearest integer.
20022087

tests/python/unittest/test_tir_op_types.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,19 @@ def test_tir_op_shift_right():
279279
assert expr.op.name == "tir.shift_right"
280280

281281

282+
def test_tir_op_bitwise():
283+
x = tir.Var("x", dtype="int32")
284+
y = tir.Var("y", dtype="int32")
285+
expr = tir.bitwise_and(x, y)
286+
assert expr.op.name == "tir.bitwise_and"
287+
expr = tir.bitwise_or(x, y)
288+
assert expr.op.name == "tir.bitwise_or"
289+
expr = tir.bitwise_not(x)
290+
assert expr.op.name == "tir.bitwise_not"
291+
expr = tir.bitwise_xor(x, y)
292+
assert expr.op.name == "tir.bitwise_xor"
293+
294+
282295
def test_tir_op_TVMBackendAllocWorkspace():
283296
expr = tir.TVMBackendAllocWorkspace(0, 1, 2, 3, 4)
284297
assert expr.op.name == "tir.TVMBackendAllocWorkspace"

0 commit comments

Comments
 (0)