Skip to content

Commit 7b7677f

Browse files
authored
[TIR] Enhance and fix tensorize schedule for some cases (#16560)
* support tensorize with simplified and call expr * replace stmt simplifier with primfunc simplifier * lint fix * lint:remove white space * lint: remove white space * cpp lint fix * lint: resolve include * clang format lint fix
1 parent 657880c commit 7b7677f

File tree

6 files changed

+159
-6
lines changed

6 files changed

+159
-6
lines changed

src/tir/schedule/ir_comparator.cc

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,30 @@ bool TensorizeComparator::VisitExpr(const PrimExpr& n, const PrimExpr& other) {
8383
return equal;
8484
}
8585

86+
bool TensorizeComparator::VisitExpr_(const CallNode* op, const PrimExpr& other) {
87+
const auto* rhs = other.as<CallNode>();
88+
if (!rhs->op.same_as(op->op)) return false;
89+
if (op->dtype.code() != rhs->dtype.code()) {
90+
if (assert_mode_) {
91+
std::ostringstream os;
92+
os << "CallNode data type codes do not match: op->dtype.code()=" << op->dtype.code()
93+
<< " vs rhs->dtype.code()=" << rhs->dtype.code();
94+
EmitError(os.str());
95+
}
96+
return false;
97+
}
98+
if (!CompareArray(op->args, rhs->args, &TensorizeComparator::VisitExpr)) {
99+
if (assert_mode_) {
100+
std::ostringstream os;
101+
os << "CallNode iter_values do not match: op->iter_values=" << op->args
102+
<< " vs rhs->iter_values=" << rhs->args;
103+
EmitError(os.str());
104+
}
105+
return false;
106+
}
107+
return true;
108+
}
109+
86110
bool TensorizeComparator::VisitStmt_(const ForNode* op, const Stmt& other) {
87111
const auto* rhs = other.as<ForNode>();
88112
if (!DefEqual(op->loop_var, rhs->loop_var)) {

src/tir/schedule/ir_comparator.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class TensorizeComparator : public ExprComparator, public StmtComparator {
4646
bool VisitExpr(const PrimExpr& n, const PrimExpr& other) override;
4747
bool VisitStmt(const Stmt& n, const Stmt& other) override;
4848

49+
// Compares a call expression with `other`: operator identity, dtype code and
// arguments are checked (see the definition in ir_comparator.cc).
bool VisitExpr_(const CallNode* op, const PrimExpr& other) override;
4950
bool VisitStmt_(const ForNode* op, const Stmt& other) override;
5051
bool VisitStmt_(const SeqStmtNode* op, const Stmt& other) override;
5152
bool VisitStmt_(const BufferStoreNode* op, const Stmt& other) override;

src/tir/schedule/primitive/blockize_tensorize.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
#include <functional>
2222

23+
#include "../../transforms/simplify.h"
2324
#include "../ir_comparator.h"
2425
#include "../utils.h"
2526

@@ -755,7 +756,9 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int
755756
<< GetRef<Stmt>(sref->stmt);
756757
throw;
757758
}
758-
PrimFunc intrin_desc = intrin->desc;
759+
760+
arith::Analyzer analyzer;
761+
PrimFunc intrin_desc = Simplify(intrin->desc, &analyzer);
759762
PrimFunc intrin_impl = DeepCopy(intrin->impl);
760763

761764
int index_dtype_bits = -1;

src/tir/transforms/simplify.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
* \file simplify.cc
2222
* \brief Statement simplifier based on analyzer
2323
*/
24+
25+
#include "../../tir/transforms/simplify.h"
26+
2427
#include <tvm/arith/analyzer.h>
2528
#include <tvm/runtime/registry.h>
2629
#include <tvm/tir/analysis.h>
@@ -339,6 +342,11 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
339342
} // namespace arith
340343

341344
namespace tir {
345+
346+
/*!
 * \brief Simplify a PrimFunc by delegating to arith::StmtSimplifier::Apply.
 * \param func The function to simplify; it is moved into the simplifier.
 * \param analyzer The analyzer pointer forwarded unchanged to the simplifier.
 * \return The PrimFunc produced by arith::StmtSimplifier::Apply.
 */
PrimFunc Simplify(PrimFunc func, arith::Analyzer* analyzer) {
  PrimFunc simplified = arith::StmtSimplifier::Apply(std::move(func), analyzer);
  return simplified;
}
349+
342350
namespace transform {
343351

344352
Pass Simplify() {

src/tir/transforms/simplify.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,16 @@
2525
#define TVM_TIR_TRANSFORMS_SIMPLIFY_H_
2626

2727
#include <tvm/arith/analyzer.h>
28-
#include <tvm/tir/stmt.h>
28+
#include <tvm/tir/function.h>
2929

3030
namespace tvm {
3131
namespace tir {
3232

33-
/* \brief Simplifies the statement
33+
/*! \brief Simplifies the PrimFunc
3434
*
35-
* Applies the same behavior as the tir.transform.Simplify pass, but
36-
* on a single statement, usable as a subroutine in other passes.
35+
* Applies the same behavior as the tir.transform.Simplify pass.
3736
*/
38-
Stmt Simplify(Stmt stmt, arith::Analyzer* analyzer);
37+
// \param func The PrimFunc to simplify.
// \param analyzer The analyzer handed to the simplifier.
// \return The simplified PrimFunc.
PrimFunc Simplify(PrimFunc func, arith::Analyzer* analyzer);
3938

4039
} // namespace tir
4140
} // namespace tvm

tests/python/tir-schedule/test_tir_schedule_tensorize.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -836,6 +836,124 @@ def tensorized_matmul_int64_shape(
836836
assert_structural_equal_ignore_global_symbol(s.mod["main"], tensorized_matmul_int64_shape)
837837
verify_trace_roundtrip(sch=s, mod=matmul_int64_shape)
838838

839+
def _tir_packed_int_to_int_to_float(storage_nbit: int):
    """Build a converter that unpacks one ``nbit``-wide field from a packed integer.

    The returned ``f_convert(nbit, val, pos, dtype)`` asserts that ``val`` has
    dtype ``"int" + str(storage_nbit)``, shifts ``val`` right by ``pos * nbit``
    bits, masks the low ``nbit`` bits, applies a left/right shift pair of
    ``32 - nbit`` bits (presumably to sign-extend the field) and casts the
    result to ``dtype``.
    """
    storage_dtype = f"int{storage_nbit}"

    def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
        assert val.dtype == storage_dtype
        field_mask = tir.const((1 << nbit) - 1, "int32")
        bit_offset = pos.astype("int32") * tir.const(nbit, "int32")
        unextended = (val >> bit_offset) & field_mask
        # The shift width is computed against 32, not against storage_nbit.
        shifted_up = unextended << tir.const(32 - nbit, "int32")
        return tir.Cast(dtype, shifted_up >> tir.const(32 - nbit, "int32"))

    return f_convert
849+
850+
@T.prim_func
def decode_i4s_to_f16_desc(compressed: T.handle, decompressed: T.handle) -> None:
    # Description of the intrinsic: one int32 word is expanded into eight
    # float16 values through _tir_packed_int_to_int_to_float(32) with nbit=4.
    Compressed = T.match_buffer(compressed, [1], dtype="int32", scope="local")
    Decompressed = T.match_buffer(decompressed, [8], dtype="float16", scope="local")

    with T.block("root"):
        T.reads(Compressed[0:1])
        T.writes(Decompressed[0:8])
        for i in T.grid(8):
            with T.block("decode"):
                vi = T.axis.remap("S", [i])
                # The index expressions are deliberately left unsimplified.
                Decompressed[vi] = _tir_packed_int_to_int_to_float(32)(
                    4, Compressed[vi // 8], vi % 8, dtype="float16"
                )
881+
882+
@T.prim_func
def decode_i4s_to_f16_impl(compressed: T.handle, decompressed: T.handle) -> None:
    # Implementation of the intrinsic: the buffers match the description and
    # the body is a single extern call to "test_decode_i4s_to_f16".
    Compressed = T.match_buffer(compressed, [1], dtype="int32", scope="local")
    Decompressed = T.match_buffer(decompressed, [8], dtype="float16", scope="local")

    with T.block("root"):
        T.reads(Compressed[0:1])
        T.writes(Decompressed[0:8])
        T.call_extern(
            "handle", "test_decode_i4s_to_f16", Compressed.data, Decompressed.data, 8
        )
911+
912+
# Register the description/implementation pair under the name used by the test.
tir.TensorIntrin.register(
    "test_decode_i4s_to_f16_intrin",
    decode_i4s_to_f16_desc,
    decode_i4s_to_f16_impl,
)
913+
914+
def test_tensorize_arith_simplification():
    """Tensorize the innermost loop of a decode kernel with the registered intrinsic.

    The workload is tensorized with "test_decode_i4s_to_f16_intrin", the
    resulting module is compared structurally with the expected function, and
    the schedule trace is round-tripped.
    """
    # fmt: off
    @T.prim_func
    def decode_i4s_to_int32_to_f16():
        B_decode_local = T.alloc_buffer((16384, 16384), "float16", scope="local")
        B_local = T.alloc_buffer((16384, 2048), "int32", scope="local")
        for ax0_0 in T.thread_binding(8192, thread="blockIdx.x"):
            for ax0_1 in T.thread_binding(2, thread="threadIdx.y"):
                for ax1_0 in range(32):
                    for ax1_1 in T.thread_binding(64, thread="threadIdx.x"):
                        for ax0, ax1 in T.grid(1, 8):
                            with T.block("B_decode_local"):
                                v0 = T.axis.spatial(16384, ax0_0 * 2 + ax0_1 + ax0)
                                v1 = T.axis.spatial(16384, ax1_0 * 512 + ax1_1 * 8 + ax1)
                                T.reads(B_local[v0, v1 // 8])
                                T.writes(B_decode_local[v0, v1])
                                B_decode_local[v0, v1] = T.Cast("float16", T.shift_right(T.shift_left(T.bitwise_and(T.shift_right(B_local[v0, v1 // 8], v1 % 8 * 4), 15), 28), 28))

    @T.prim_func
    def tensorized_decode_i4s_to_int32_to_f16():
        B_decode_local = T.alloc_buffer((16384, 16384), "float16", scope="local")
        B_local = T.alloc_buffer((16384, 2048), "int32", scope="local")
        for ax0_0 in T.thread_binding(8192, thread="blockIdx.x"):
            for ax0_1 in T.thread_binding(2, thread="threadIdx.y"):
                for ax1_0 in range(32):
                    for ax1_1 in T.thread_binding(64, thread="threadIdx.x"):
                        for ax0 in range(1):
                            with T.block("B_decode_local_o"):
                                v0_o = T.axis.spatial(16384, ax0_0 * 2 + ax0_1 + ax0)
                                v1_o = T.axis.spatial(2048, ax1_0 * 64 + ax1_1)
                                T.reads(B_local[v0_o, v1_o])
                                T.writes(B_decode_local[v0_o, v1_o * 8:v1_o * 8 + 8])
                                Compressed = T.match_buffer(B_local[v0_o, v1_o], (1,), "int32", scope="local")
                                Decompressed = T.match_buffer(B_decode_local[v0_o, v1_o * 8:v1_o * 8 + 8], (8,), "float16", scope="local")
                                T.call_extern("handle", "test_decode_i4s_to_f16", Compressed.data, Decompressed.data, 8)
    # fmt: on

    sch = tir.Schedule(decode_i4s_to_int32_to_f16, debug_mask="all")
    decode_block = sch.get_block("B_decode_local")
    # The last loop returned for the block is the one handed to tensorize.
    innermost_loop = sch.get_loops(decode_block)[-1]
    sch.tensorize(innermost_loop, "test_decode_i4s_to_f16_intrin")
    assert_structural_equal_ignore_global_symbol(
        sch.mod["main"], tensorized_decode_i4s_to_int32_to_f16
    )
    verify_trace_roundtrip(sch=sch, mod=decode_i4s_to_int32_to_f16)
956+
839957

840958
if __name__ == "__main__":
841959
tvm.testing.main()

0 commit comments

Comments
 (0)