diff --git a/3rdparty/tvm b/3rdparty/tvm index a64a5926a..1fc7578cd 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit a64a5926a6e59f5417ef2501f9d88b467337cf6a +Subproject commit 1fc7578cd1ff934455b07597508b5a67d7cb5a73 diff --git a/src/transform/inject_assumes.cc b/src/transform/inject_assumes.cc new file mode 100644 index 000000000..a2ddfc4a0 --- /dev/null +++ b/src/transform/inject_assumes.cc @@ -0,0 +1,164 @@ + +#include "tvm/arith/analyzer.h" +#include "tvm/ffi/optional.h" +#include "tvm/ir/expr.h" +#include "tvm/ir/transform.h" +#include "tvm/node/structural_hash.h" +#include "tvm/tir/builtin.h" +#include "tvm/tir/expr.h" +#include "tvm/tir/stmt.h" +#include "tvm/tir/stmt_functor.h" +#include "tvm/tir/transform.h" +#include + +namespace tvm::tl { +using namespace tir; + +class AssumeInjector : public tvm::tir::StmtExprMutator { + using Base = tvm::tir::StmtExprMutator; + +public: + AssumeInjector(PrimFunc f) : f(f) {} + static PrimFunc Substitute(PrimFunc f) { + auto injector = AssumeInjector(f); + f.CopyOnWrite()->body = injector(f->body); + return f; + } + +private: + struct AssertCreator { + struct Item { + PrimExpr expr; + std::vector buffers; + }; + tvm::StructuralHash sh; + tvm::StructuralEqual se; + // grouped by expr, since the amount of varidic shape symbols is usualy much + // smaller than buffer + std::vector items; + // hash => index in items + std::unordered_map> buckets; + void addExpr(PrimExpr e, Buffer buffer) { + size_t h = sh(e); + auto &bucket = buckets[h]; + auto it = std::find_if(bucket.begin(), bucket.end(), [&](size_t y) { + return se(e, items[y].expr, true); + }); + if (it == bucket.end()) { + auto index = items.size(); + items.push_back({e, {buffer}}); + bucket.push_back(index); + } else { + items[*it].buffers.push_back(buffer); + } + } + void addBuffer(Buffer buf) { + for (auto shape : buf->shape) { + if (shape->IsInstance()) + continue; + addExpr(shape, buf); + } + } + Stmt build(Stmt body) { + auto analyzer = arith::Analyzer{}; + for (const auto &e : items) { + auto simplified = analyzer.Simplify(GT(e.expr, 0)); + std::stringstream ss; + ss << "Buffer shape should be greater than 0: shape `" << e.expr + << "` from buffer "; + for (size_t i = 0; i < e.buffers.size(); i++) { + if (i) + ss << ", "; + ss << "`" << e.buffers[i]->name << "`"; + } + body = AttrStmt(simplified, tir::attr::tilelang_assume, + StringImm(ss.str()), body); + } + return body; + } + }; + Stmt VisitStmt_(const DeclBufferNode *op) final { + auto body = VisitStmt(op->body); + AssertCreator c; + c.addBuffer(op->buffer); + return DeclBuffer(op->buffer, c.build(body), op->span); + } + std::optional getAssumeExpr(Stmt stmt) { + auto eval = stmt.as(); + if (!eval) + return std::nullopt; + auto call = eval->value.as(); + if (!call) + return std::nullopt; + if (!call->op.same_as(builtin::assume())) + return std::nullopt; + return call->args[0]; + } + Stmt VisitStmt_(const SeqStmtNode *op) final { + struct AssumeGroup { + std::optional e; + std::vector stmts; + }; + std::vector groups = {AssumeGroup{std::nullopt, {}}}; + for (auto i = 0; i < op->seq.size(); i++) { + auto stmt = VisitStmt(op->seq[i]); + if (auto e = getAssumeExpr(stmt)) { + groups.push_back(AssumeGroup{*e, {}}); + } else { + groups.back().stmts.push_back(stmt); + } + } + for (size_t i = groups.size(); i--;) { + auto &g = groups[i]; + if (g.e) { + Stmt body = g.stmts.size() == 1 ? g.stmts[0] : SeqStmt(g.stmts); + std::stringstream ss; + ss << "Assume: " << *(g.e); + AttrStmt attr = AttrStmt(*g.e, tir::attr::tilelang_assume, + StringImm(ss.str()), body); + groups[i - 1].stmts.push_back(attr); + } else { + ICHECK(i == 0) << "only the first group can have no assume"; + } + } + return groups[0].stmts.size() == 1 ? groups[0].stmts[0] + : SeqStmt(groups[0].stmts); + // return SeqStmt(groups[0].stmts); + } + Stmt VisitStmt_(const BlockNode *op) final { + auto body = VisitStmt(op->body); + AssertCreator c; + if (root_node) { + for (auto item : f->buffer_map) { + c.addBuffer(item.second); + } + } + for (auto item : op->alloc_buffers) { + c.addBuffer(item); + } + for (auto item : op->match_buffers) { + c.addBuffer(item->buffer); + } + return Block(op->iter_vars, op->reads, op->writes, op->name_hint, + c.build(body), op->init, op->alloc_buffers, op->match_buffers, + op->annotations, op->span); + } + PrimFunc f; + bool root_node{true}; +}; + +using namespace tir::transform; + +tvm::transform::Pass InjectAssumes() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + return AssumeInjector::Substitute(f); + }; + return CreatePrimFuncPass(pass_func, 0, "tl.InjectAssumes", {}); +} + +TVM_FFI_STATIC_INIT_BLOCK({ + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tl.transform.InjectAssumes", InjectAssumes); +}); + +} // namespace tvm::tl diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index f865b0085..646cb66c1 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -87,6 +87,8 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: # Legalize the frontend IR to make it compatible with TVM mod = tilelang.transform.FrontendLegalize()(mod) + # Inject assumes to speedup tvm prover + mod = tilelang.transform.InjectAssumes()(mod) # Simplify the IR expressions mod = tir.transform.Simplify()(mod) # Set layouts for reducers diff --git a/tilelang/transform/__init__.py b/tilelang/transform/__init__.py index d61e29189..da8cf51d9 100644 --- a/tilelang/transform/__init__.py +++ b/tilelang/transform/__init__.py @@ -79,6 +79,17 @@ def FrontendLegalize(): return _ffi_api.FrontendLegalize() # type: ignore +def InjectAssumes(): + """Inject Assumes + + Returns: + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.InjectAssumes() + + def LowerHopperIntrin(): """LowerHopperIntrin