Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions src/tir/schedule/ir_comparator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,30 @@ bool TensorizeComparator::VisitExpr(const PrimExpr& n, const PrimExpr& other) {
return equal;
}

/*!
 * \brief Structurally compare two Call expressions during tensorize matching.
 * \param op The lhs call node (from the block being tensorized).
 * \param other The rhs expression (from the intrin description).
 * \return True if the calls invoke the same op with matching dtype code and
 *         pairwise-equal arguments.
 */
bool TensorizeComparator::VisitExpr_(const CallNode* op, const PrimExpr& other) {
  const auto* rhs = other.as<CallNode>();
  // `other` is not guaranteed to be a Call; bail out before dereferencing.
  if (rhs == nullptr) {
    if (assert_mode_) {
      EmitError("Expression is not a CallNode on the rhs");
    }
    return false;
  }
  if (!rhs->op.same_as(op->op)) return false;
  // Only the type code is compared here; bits/lanes are checked elsewhere via
  // the argument comparison below.
  if (op->dtype.code() != rhs->dtype.code()) {
    if (assert_mode_) {
      std::ostringstream os;
      os << "CallNode data type codes do not match: op->dtype.code()=" << op->dtype.code()
         << " vs rhs->dtype.code()=" << rhs->dtype.code();
      EmitError(os.str());
    }
    return false;
  }
  if (!CompareArray(op->args, rhs->args, &TensorizeComparator::VisitExpr)) {
    if (assert_mode_) {
      std::ostringstream os;
      // Message previously said "iter_values"; these are call arguments.
      os << "CallNode args do not match: op->args=" << op->args
         << " vs rhs->args=" << rhs->args;
      EmitError(os.str());
    }
    return false;
  }
  return true;
}

bool TensorizeComparator::VisitStmt_(const ForNode* op, const Stmt& other) {
const auto* rhs = other.as<ForNode>();
if (!DefEqual(op->loop_var, rhs->loop_var)) {
Expand Down
1 change: 1 addition & 0 deletions src/tir/schedule/ir_comparator.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class TensorizeComparator : public ExprComparator, public StmtComparator {
bool VisitExpr(const PrimExpr& n, const PrimExpr& other) override;
bool VisitStmt(const Stmt& n, const Stmt& other) override;

bool VisitExpr_(const CallNode* op, const PrimExpr& other) override;
bool VisitStmt_(const ForNode* op, const Stmt& other) override;
bool VisitStmt_(const SeqStmtNode* op, const Stmt& other) override;
bool VisitStmt_(const BufferStoreNode* op, const Stmt& other) override;
Expand Down
27 changes: 26 additions & 1 deletion src/tir/schedule/primitive/blockize_tensorize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include <functional>

#include "../../transforms/simplify.h"
#include "../ir_comparator.h"
#include "../utils.h"

Expand Down Expand Up @@ -738,6 +739,28 @@ StmtSRef Blockize(ScheduleState self, const Array<StmtSRef>& blocks, bool preser
return result;
}

/*!
 * \brief Simplifies the arithmetic inside a tensor intrin description.
 *
 * The description PrimFunc may contain expressions (e.g. sign-extension
 * shift patterns) that the workload has already had simplified away, which
 * would make structural matching fail.  Running the simplifier over each
 * block of the description first lets the two sides match.
 */
class TensorIntrinSimplifier : public arith::IRMutatorWithAnalyzer {
 public:
  /*!
   * \brief Return a copy of `func` whose block bodies have been simplified.
   * \param func The tensor intrin description to simplify.
   * \param analyzer The analyzer supplying simplification context.
   * \return The simplified PrimFunc.
   */
  static PrimFunc Apply(PrimFunc func, arith::Analyzer* analyzer) {
    TensorIntrinSimplifier simplifier(analyzer);
    func.CopyOnWrite()->body = simplifier(func->body);
    return func;
  }

 private:
  explicit TensorIntrinSimplifier(arith::Analyzer* analyzer) : IRMutatorWithAnalyzer(analyzer) {}

  using Parent = IRMutatorWithAnalyzer;
  using Parent::VisitExpr_;
  using Parent::VisitStmt;
  using Parent::VisitStmt_;

  // Delegate each block to the full statement simplifier so that the
  // rewrite rules (including extensions) are applied to its body.
  Stmt VisitStmt_(const BlockNode* block) final {
    Block block_ref = GetRef<Block>(block);
    return tvm::tir::Simplify(block_ref, analyzer_);
  }
};

void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& intrin,
bool preserve_unit_iters) {
// Step 1: Blockize the subtree rooted at the given loop if needed
Expand All @@ -755,7 +778,9 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int
<< GetRef<Stmt>(sref->stmt);
throw;
}
PrimFunc intrin_desc = intrin->desc;

arith::Analyzer analyzer;
PrimFunc intrin_desc = TensorIntrinSimplifier::Apply(intrin->desc, &analyzer);
PrimFunc intrin_impl = DeepCopy(intrin->impl);

int index_dtype_bits = -1;
Expand Down
23 changes: 23 additions & 0 deletions src/tir/transforms/simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

#include <optional>

#include "simplify.h"
#include "../../arith/ir_mutator_with_analyzer.h"
#include "../../tir/analysis/control_flow_graph.h"
#include "../../tir/analysis/var_use_def_analysis.h"
Expand Down Expand Up @@ -162,6 +163,23 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
return func;
}

/*!
 * \brief Simplify a statement, optionally with an explicit config.
 * \param stmt The statement to simplify (consumed).
 * \param analyzer The analyzer whose rewrite simplifier is used.
 * \param config_opt Optional config; defaults to the attrs-default config.
 * \return The simplified statement.
 */
static Stmt Apply(Stmt stmt, Analyzer* analyzer, Optional<SimplifyConfig> config_opt = NullOpt) {
  SimplifyConfig config = config_opt.value_or(AttrsWithDefaultValues<arith::SimplifyConfig>());
  analyzer->rewrite_simplify.SetEnabledExtensions(config->GetEnabledExtensions());

  // The control-flow graph is only needed when the config requests
  // propagation of known values; skip the construction cost otherwise.
  bool needs_touch_pattern = config->propagate_knowns_to_prove_conditional ||
                             config->propagate_knowns_to_simplify_expressions;
  std::optional<ControlFlowGraph> touch_pattern;
  if (needs_touch_pattern) {
    touch_pattern.emplace(stmt);
  }

  auto used_in_buffer_def = CollectVarsUsedInBufferDefinition(stmt);
  StmtSimplifier mutator(analyzer, config, std::move(touch_pattern),
                         std::move(used_in_buffer_def));
  return mutator.Simplify(std::move(stmt));
}

private:
explicit StmtSimplifier(Analyzer* analyzer, SimplifyConfig config,
std::optional<ControlFlowGraph> touch_pattern,
Expand Down Expand Up @@ -339,6 +357,11 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
} // namespace arith

namespace tir {

/*!
 * \brief Simplify `stmt` with the default simplify configuration.
 * \param stmt The statement to simplify (consumed).
 * \param analyzer The analyzer providing the simplification context.
 * \return The simplified statement.
 */
Stmt Simplify(Stmt stmt, arith::Analyzer* analyzer) {
  Stmt simplified = arith::StmtSimplifier::Apply(std::move(stmt), analyzer);
  return simplified;
}

namespace transform {

Pass Simplify() {
Expand Down
118 changes: 118 additions & 0 deletions tests/python/tir-schedule/test_tir_schedule_tensorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,124 @@ def tensorized_matmul_int64_shape(
assert_structural_equal_ignore_global_symbol(s.mod["main"], tensorized_matmul_int64_shape)
verify_trace_roundtrip(sch=s, mod=matmul_int64_shape)

def _tir_packed_int_to_int_to_float(storage_nbit: int):
    """Build a converter that unpacks signed sub-word integers to float.

    The returned closure extracts the ``pos``-th ``nbit``-wide field from a
    ``storage_nbit``-wide integer word ``val``, sign-extends it, and casts it
    to ``dtype``.
    """
    storage_dtype = f"int{storage_nbit}"

    def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
        assert val.dtype == storage_dtype
        bit_mask = tir.const((1 << nbit) - 1, "int32")
        shift_amount = pos.astype("int32") * tir.const(nbit, "int32")
        raw_bits = (val >> shift_amount) & bit_mask
        # Shift left then arithmetic-shift right to sign-extend the field.
        pad = tir.const(32 - nbit, "int32")
        return tir.Cast(dtype, (raw_bits << pad) >> pad)

    return f_convert

@T.prim_func
def decode_i4s_to_f16_desc(compressed: T.handle, decompressed: T.handle) -> None:
    # Tensor-intrin *description*: the semantic reference that `tensorize`
    # pattern-matches against.  Decodes eight packed 4-bit signed values
    # (one int32 word) into eight float16 elements.
    Compressed = T.match_buffer(
        compressed,
        [
            1,
        ],
        dtype="int32",
        scope="local",
    )
    Decompressed = T.match_buffer(
        decompressed,
        [
            8,
        ],
        dtype="float16",
        scope="local",
    )

    with T.block("root"):
        T.reads(Compressed[0:1])
        T.writes(Decompressed[0:8])
        for i in T.grid(8):
            with T.block("decode"):
                vi = T.axis.remap("S", [i])
                # Extract the vi-th 4-bit field from the int32 word and
                # sign-extend it before casting to float16.
                Decompressed[vi] = _tir_packed_int_to_int_to_float(32)(
                    4,
                    Compressed[vi // 8],
                    vi % 8,
                    dtype="float16",
                )

@T.prim_func
def decode_i4s_to_f16_impl(compressed: T.handle, decompressed: T.handle) -> None:
    # Tensor-intrin *implementation*: the code substituted in place of the
    # matched block — a single extern call handling all eight elements.
    Compressed = T.match_buffer(
        compressed,
        [
            1,
        ],
        dtype="int32",
        scope="local",
    )
    Decompressed = T.match_buffer(
        decompressed,
        [
            8,
        ],
        dtype="float16",
        scope="local",
    )

    with T.block("root"):
        T.reads(Compressed[0:1])
        T.writes(Decompressed[0:8])
        # Hardware/extern routine performing the whole 8-element decode.
        T.call_extern(
            "handle",
            "test_decode_i4s_to_f16",
            Compressed.data,
            Decompressed.data,
            8,
        )

# Register the intrinsic once at import time so schedules can refer to it by name.
tir.TensorIntrin.register("test_decode_i4s_to_f16_intrin", decode_i4s_to_f16_desc, decode_i4s_to_f16_impl)

def test_tensorize_arith_simplification():
    # Regression test: the workload writes the decode with explicit
    # shift/mask arithmetic; tensorize must simplify the intrin description's
    # equivalent expression so the two match structurally.
    # fmt: off
    @T.prim_func
    def decode_i4s_to_int32_to_f16():
        B_decode_local = T.alloc_buffer((16384, 16384), "float16", scope="local")
        B_local = T.alloc_buffer((16384, 2048), "int32", scope="local")
        for ax0_0 in T.thread_binding(8192, thread="blockIdx.x"):
            for ax0_1 in T.thread_binding(2, thread="threadIdx.y"):
                for ax1_0 in range(32):
                    for ax1_1 in T.thread_binding(64, thread="threadIdx.x"):
                        for ax0, ax1 in T.grid(1, 8):
                            with T.block("B_decode_local"):
                                v0 = T.axis.spatial(16384, ax0_0 * 2 + ax0_1 + ax0)
                                v1 = T.axis.spatial(16384, ax1_0 * 512 + ax1_1 * 8 + ax1)
                                T.reads(B_local[v0, v1 // 8])
                                T.writes(B_decode_local[v0, v1])
                                B_decode_local[v0, v1] = T.Cast("float16", T.shift_right(T.shift_left(T.bitwise_and(T.shift_right(B_local[v0, v1 // 8], v1 % 8 * 4), 15), 28), 28))

    # Expected module after tensorizing the innermost loop with the intrin.
    @T.prim_func
    def tensorized_decode_i4s_to_int32_to_f16():
        B_decode_local = T.alloc_buffer((16384, 16384), "float16", scope="local")
        B_local = T.alloc_buffer((16384, 2048), "int32", scope="local")
        for ax0_0 in T.thread_binding(8192, thread="blockIdx.x"):
            for ax0_1 in T.thread_binding(2, thread="threadIdx.y"):
                for ax1_0 in range(32):
                    for ax1_1 in T.thread_binding(64, thread="threadIdx.x"):
                        for ax0 in range(1):
                            with T.block("B_decode_local_o"):
                                v0_o = T.axis.spatial(16384, ax0_0 * 2 + ax0_1 + ax0)
                                v1_o = T.axis.spatial(2048, ax1_0 * 64 + ax1_1)
                                T.reads(B_local[v0_o, v1_o])
                                T.writes(B_decode_local[v0_o, v1_o * 8:v1_o * 8 + 8])
                                Compressed = T.match_buffer(B_local[v0_o, v1_o], (1,), "int32", scope="local")
                                Decompressed = T.match_buffer(B_decode_local[v0_o, v1_o * 8:v1_o * 8 + 8], (8,), "float16", scope="local")
                                T.call_extern("handle", "test_decode_i4s_to_f16", Compressed.data, Decompressed.data, 8)
    # fmt: on

    s = tir.Schedule(decode_i4s_to_int32_to_f16, debug_mask="all")
    update = s.get_block("B_decode_local")
    # Tensorize over the innermost (8-element) loop.
    ii = s.get_loops(update)[-1]
    s.tensorize(ii, "test_decode_i4s_to_f16_intrin")
    assert_structural_equal_ignore_global_symbol(s.mod["main"], tensorized_decode_i4s_to_int32_to_f16)
    verify_trace_roundtrip(sch=s, mod=decode_i4s_to_int32_to_f16)


# Allow running this test file directly as a script.
if __name__ == "__main__":
    tvm.testing.main()