Skip to content

Commit a64d1f1

Browse files
authored
[TIR] Make T.reinterpret nop when dtype is the same (#16879)
* [TIR] Make T.reinterpret nop when dtype is the same * fix scalable vec handling
1 parent 64911ab commit a64d1f1

File tree

4 files changed

+31
-5
lines changed

4 files changed

+31
-5
lines changed

python/tvm/tir/op.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1789,7 +1789,7 @@ def infinity(dtype: str, span: Optional[Span] = None) -> Any:
17891789
return _ffi_api.infinity(dtype, span) # type: ignore
17901790

17911791

1792-
def reinterpret(dtype, value) -> Any:
1792+
def reinterpret(dtype, value, span: Optional[Span] = None) -> Any:
17931793
"""infinity value of dtype
17941794
17951795
Parameters
@@ -1808,7 +1808,7 @@ def reinterpret(dtype, value) -> Any:
18081808
value : tvm.Expr
18091809
The reinterpret cast value of dtype.
18101810
"""
1811-
return call_intrin(dtype, "tir.reinterpret", value)
1811+
return _ffi_api.reinterpret(dtype, value, span) # type: ignore
18121812

18131813

18141814
def exp(x):

src/tir/op/op.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -409,8 +409,10 @@ PrimExpr cast(const DataType& t, PrimExpr value, Span span) {
409409
// reinterpret
410410
PrimExpr reinterpret(const DataType& t, PrimExpr value, Span span) {
411411
if (value.dtype() == t) return value;
412-
ICHECK(value.dtype().bits() * value.dtype().lanes() == t.bits() * t.lanes())
413-
<< "Bitcast requires size match " << t << " vs " << value.dtype();
412+
if (!t.is_scalable_vector() && !value.dtype().is_scalable_vector()) {
413+
ICHECK(value.dtype().bits() * value.dtype().lanes() == t.bits() * t.lanes())
414+
<< "Bitcast requires size match " << t << " vs " << value.dtype();
415+
}
414416
return tir::Call(t, tir::builtin::reinterpret(), {value}, span);
415417
}
416418

@@ -1083,6 +1085,8 @@ TVM_REGISTER_GLOBAL("tir.trunc").set_body_typed(tvm::trunc);
10831085

10841086
TVM_REGISTER_GLOBAL("tir._cast").set_body_typed(tvm::cast);
10851087

1088+
TVM_REGISTER_GLOBAL("tir.reinterpret").set_body_typed(tvm::reinterpret);
1089+
10861090
// operator overloading, smarter than make
10871091
#define REGISTER_MAKE_BINARY_OP(Node, Func) \
10881092
TVM_REGISTER_GLOBAL("tir." #Node).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { \

tests/python/codegen/test_target_codegen_cuda.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1120,7 +1120,7 @@ def test_invalid_reinterpret():
11201120
@T.prim_func
11211121
def func(A: T.Buffer((4,), "uint32"), B: T.Buffer((4,), "uint8")) -> None:
11221122
for tx in T.thread_binding(4, "threadIdx.x"):
1123-
B[tx] = T.reinterpret("uint8", A[tx])
1123+
B[tx] = T.call_intrin("uint8", "tir.reinterpret", A[tx])
11241124

11251125
with pytest.raises(tvm.error.TVMError):
11261126
tvm.build(func, target="cuda")

tests/python/tvmscript/test_tvmscript_parser_tir.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,5 +449,27 @@ def func(a_handle: T.handle, b_handle: T.handle):
449449
tvm.ir.assert_structural_equal(func.struct_info, expected)
450450

451451

452+
def test_reinterpret_nop():
453+
"""Test builtin reinterpret op"""
454+
455+
@T.prim_func
456+
def func(A: T.Buffer((32,), "float32"), B: T.Buffer((32,), "float32")) -> None:
457+
T.func_attr({"global_symbol": "main"})
458+
for i in T.serial(0, 32):
459+
with T.block():
460+
vi = T.axis.remap("S", [i])
461+
B[vi] = T.reinterpret("float32", A[vi])
462+
463+
@T.prim_func
464+
def expected(A: T.Buffer((32,), "float32"), B: T.Buffer((32,), "float32")) -> None:
465+
T.func_attr({"global_symbol": "main"})
466+
for i in T.serial(0, 32):
467+
with T.block():
468+
vi = T.axis.remap("S", [i])
469+
B[vi] = A[vi]
470+
471+
tvm.ir.assert_structural_equal(func, expected)
472+
473+
452474
if __name__ == "__main__":
453475
tvm.testing.main()

0 commit comments

Comments
 (0)