diff --git a/src/transform/unroll_loop.cc b/src/transform/unroll_loop.cc
new file mode 100644
index 000000000..229749ab0
--- /dev/null
+++ b/src/transform/unroll_loop.cc
@@ -0,0 +1,353 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Loop unrolling as in Halide pipeline.
+ * \file unroll_loop.cc
+ */
+// Unrolls the loop as in Halide pipeline.
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+#include "runtime/thread_storage_scope.h"
+#include "tir/transforms/ir_utils.h"
+
+namespace tvm {
+namespace tl {
+
+using namespace tir;
+
+// Options consumed by the tl.UnrollLoop pass. Field descriptions and default
+// values are declared in RegisterReflection below.
+struct UnrollLoopConfigNode
+    : public AttrsNodeReflAdapter {
+  int auto_max_step;
+  int auto_max_depth;
+  int auto_max_extent;
+  int explicit_unroll;
+  int unroll_local_access;
+
+  static void RegisterReflection() {
+    namespace refl = tvm::ffi::reflection;
+    refl::ObjectDef()
+        .def_ro("auto_max_step", &UnrollLoopConfigNode::auto_max_step,
+                "Threshold of number of steps in the loop to be automatically "
+                "unrolled",
+                refl::DefaultValue(0))
+        .def_ro("auto_max_depth", &UnrollLoopConfigNode::auto_max_depth,
+                "The maximum nested level of loops that can be automatically "
+                "unrolled.",
+                refl::DefaultValue(8))
+        .def_ro("auto_max_extent", &UnrollLoopConfigNode::auto_max_extent,
+                "The maximum extent` of loop that will be unrolled.",
+                refl::DefaultValue(0))
+        .def_ro(
+            "explicit_unroll", &UnrollLoopConfigNode::explicit_unroll,
+            "Whether to explicitly unroll the loop instead of setting a pragma",
+            refl::DefaultValue(true))
+        .def_ro(
+            "unroll_local_access", &UnrollLoopConfigNode::unroll_local_access,
+            "Whether to always unroll local access", refl::DefaultValue(false));
+  }
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.transform.UnrollLoopConfig",
+                                    UnrollLoopConfigNode, BaseAttrsNode);
+};
+
+// Reference type for UnrollLoopConfigNode.
+class UnrollLoopConfig : public Attrs {
+public:
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(UnrollLoopConfig, Attrs,
+                                                UnrollLoopConfigNode);
+};
+
+TVM_FFI_STATIC_INIT_BLOCK() { UnrollLoopConfigNode::RegisterReflection(); }
+
+TVM_REGISTER_PASS_CONFIG_OPTION("tl.UnrollLoop", UnrollLoopConfig);
+
+// Expression visitor that records every Var it visits in the set supplied by
+// the caller. LoopUnroller runs it over the indices of buffer accesses whose
+// storage scope is local or warp.
+class VarLocalAccessMarker : public ExprVisitor {
+public:
+  explicit VarLocalAccessMarker(std::unordered_set *var_touched_local)
+      : var_touched_local_(var_touched_local) {}
+
+  void VisitExpr_(const VarNode *op) final {
+    var_touched_local_->insert(ffi::GetRef(op));
+  }
+
+private:
+  std::unordered_set *var_touched_local_;
+};
+
+// Statement mutator that decides, loop by loop, whether a For is expanded in
+// place (Unroll), re-tagged as ForKind::kUnrolled, or returned unchanged.
+// If a loop var is used as indices to a local memory, it must be unrolled so
+// the local memory access can be turned into register access (only applied
+// when unroll_local_access is set).
+class LoopUnroller : public StmtExprMutator {
+public:
+  explicit LoopUnroller(int auto_max_step, int auto_max_depth,
+                        int auto_max_extent, bool explicit_unroll,
+                        bool unroll_local_access)
+      : auto_max_step_(auto_max_step), auto_max_depth_(auto_max_depth),
+        auto_max_extent_(auto_max_extent), explicit_unroll_(explicit_unroll),
+        unroll_local_access_(unroll_local_access) {}
+
+  // Scoped overrides: both pragmas replace the matching setting while the body
+  // is visited and restore it afterwards. The AttrStmt node itself is dropped;
+  // only the mutated body is returned.
+  Stmt VisitStmt_(const AttrStmtNode *op) final {
+    if (op->attr_key == "pragma_auto_unroll_max_step") {
+      int value = static_cast(Downcast(op->value)->value);
+      std::swap(value, auto_max_step_);
+      Stmt ret = this->VisitStmt(op->body);
+      std::swap(value, auto_max_step_);
+      return ret;
+    } else if (op->attr_key == "pragma_unroll_explicit") {
+      bool explicit_unroll = Downcast(op->value)->value;
+      std::swap(explicit_unroll, explicit_unroll_);
+      Stmt ret = this->VisitStmt(op->body);
+      std::swap(explicit_unroll, explicit_unroll_);
+      return ret;
+    } else {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+  }
+
+  Stmt VisitStmt_(const ForNode *op) final {
+    // Post order so we can collect more information
+    Stmt stmt = StmtExprMutator::VisitStmt_(op);
+    op = stmt.as();
+    int value = GetExtent(op);
+    // condition for auto unroll
+    bool auto_unroll =
+        (op->kind == ForKind::kSerial && value >= 0 &&
+         normal_loop_depth_ == 0 && unroll_depth_ <= auto_max_depth_);
+
+    auto_unroll = auto_unroll && (value * step_count_ <= auto_max_step_ ||
+                                  value <= auto_max_extent_);
+
+    if (op->kind == ForKind::kUnrolled) {
+      if (explicit_unroll_) {
+        ICHECK_GE(value, 0)
+            << "Cannot unroll non-constant loop " << explicit_unroll_;
+      }
+      auto_unroll = true;
+    }
+
+    // If a loop var is used as indices to a local memory, it must be unrolled
+    // so the local memory access can be turned into register access.
+    if (this->var_touched_local_.count(op->loop_var) && value > 0 &&
+        unroll_local_access_) {
+      auto_unroll = true;
+    }
+
+    // A kUnrolled loop with a non-constant extent reaches this point with
+    // value == -1 when explicit unrolling is off. Multiplying step_count_ by
+    // -1 would turn it negative, which could let an enclosing serial loop pass
+    // the auto_max_step_ test, so such a loop is accounted as a normal loop.
+    if (auto_unroll && value >= 0) {
+      step_count_ *= value;
+      unroll_depth_ += 1;
+    } else {
+      normal_loop_depth_ += 1;
+    }
+
+    if ((auto_unroll && explicit_unroll_) ||
+        // unroll loops with extent = 1, no matter how many steps in body
+        (0 <= value && value <= auto_max_extent_ && auto_max_extent_ == 1)) {
+      return Unroll(op);
+    } else {
+      if (auto_unroll) {
+        if (op->kind != ForKind::kUnrolled) {
+          auto n = CopyOnWrite(op);
+          n->kind = ForKind::kUnrolled;
+          return For(n);
+        }
+      }
+      return stmt;
+    }
+  }
+
+  // Returns the load unchanged. With unroll_local_access_ set, the Vars in the
+  // indices of a local- or warp-scope load are added to var_touched_local_.
+  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
+    if (unroll_local_access_) {
+      auto storage_scope =
+          runtime::StorageScope::Create(GetPtrStorageScope(op->buffer->data));
+      if (storage_scope.rank == runtime::StorageRank::kLocal ||
+          storage_scope.rank == runtime::StorageRank::kWarp) {
+        VarLocalAccessMarker marker(&var_touched_local_);
+        for (PrimExpr e : op->indices) {
+          marker(e);
+        }
+      }
+    }
+    return ffi::GetRef(op);
+  }
+
+  // Each store counts as one step; index Vars are recorded as for loads.
+  Stmt VisitStmt_(const BufferStoreNode *op) final {
+    ++step_count_;
+    if (unroll_local_access_) {
+      auto storage_scope =
+          runtime::StorageScope::Create(GetPtrStorageScope(op->buffer->data));
+      if (storage_scope.rank == runtime::StorageRank::kLocal ||
+          storage_scope.rank == runtime::StorageRank::kWarp) {
+        VarLocalAccessMarker marker(&var_touched_local_);
+        for (PrimExpr e : op->indices) {
+          marker(e);
+        }
+      }
+    }
+    return StmtExprMutator::VisitStmt_(op);
+  }
+
+  // Each Evaluate counts as one step.
+  Stmt VisitStmt_(const EvaluateNode *op) final {
+    ++step_count_;
+    return StmtExprMutator::VisitStmt_(op);
+  }
+
+  // Visits each child with fresh counters, then merges: step counts are
+  // summed, loop depths take the maximum over the children.
+  Stmt VisitStmt_(const SeqStmtNode *op) final {
+    auto fmutate = [this](const Stmt &s) {
+      int step_count = step_count_;
+      int unroll_depth = unroll_depth_;
+      int normal_loop_depth = normal_loop_depth_;
+      step_count_ = 0;
+      unroll_depth_ = 0;
+      normal_loop_depth_ = 0;
+      Stmt ret = this->VisitStmt(s);
+      step_count_ += step_count;
+      normal_loop_depth_ = std::max(normal_loop_depth, normal_loop_depth_);
+      unroll_depth_ = std::max(unroll_depth_, unroll_depth);
+      return ret;
+    };
+    return StmtExprMutator::VisitSeqStmt_(op, false, fmutate);
+  }
+
+  // Expands the loop into a flattened sequence of `value` copies of the body,
+  // substituting loop_var with min + i in copy i. Requires a constant extent.
+  Stmt Unroll(const ForNode *op) {
+    int value = GetExtent(op);
+    // For loop must have a constant integer extent
+    ICHECK_NE(value, -1) << "loop doesn't have a constant integer extent";
+    if (value == 0)
+      return Evaluate(0);
+    Stmt body = op->body;
+    ffi::Map vmap;
+    ffi::Array unrolled;
+    for (int i = 0; i < value; ++i) {
+      vmap.Set(op->loop_var, op->min + make_const(op->loop_var.dtype(), i));
+      Stmt step = Substitute(body, vmap);
+      unrolled.push_back(step);
+    }
+    return SeqStmt::Flatten(unrolled);
+  }
+
+private:
+  // returns the extent of the loop if it's a constant integer, otherwise return
+  // -1
+  int GetExtent(const ForNode *op) {
+    // constant folding.
+    PrimExpr extent = analyzer_.Simplify(op->extent);
+    const IntImmNode *v1 = extent.as();
+    int value = -1;
+    // integers that do not fit in int32_t are treated as symbolic,
+    // as it's impossible to unroll such large loops
+    if (v1 != nullptr && v1->value <= std::numeric_limits::max()) {
+      value = static_cast(v1->value);
+    }
+    return value;
+  }
+
+  // maximum number of step to perform auto unroll.
+  int auto_max_step_;
+  // maximum nesting depth of unrolled loops allowed for auto unroll
+  int auto_max_depth_;
+  // max extent of loop to auto unroll
+  // this does not count the total steps, only count the number of loops
+  int auto_max_extent_;
+  // true: expand loops in place; false: only tag them as ForKind::kUnrolled
+  bool explicit_unroll_;
+  // Whether to unroll loops to local access.
+  bool unroll_local_access_{false};
+  // Number of normal loops in scope
+  int normal_loop_depth_{0};
+  // number of unrolled cases in current scope.
+  int unroll_depth_{0};
+  // Number of total steps unrolled
+  int step_count_{0};
+  // set of indices touched during visit local memory
+  std::unordered_set var_touched_local_;
+  // analyzer
+  arith::Analyzer analyzer_;
+};
+
+// Applies LoopUnroller to stmt. ConvertSSA is run on the result only when the
+// mutator returned a different statement.
+Stmt UnrollLoop(Stmt stmt, UnrollLoopConfig cfg) {
+  Stmt ret = LoopUnroller(cfg->auto_max_step, cfg->auto_max_depth,
+                          cfg->auto_max_extent, cfg->explicit_unroll,
+                          cfg->unroll_local_access)(stmt);
+  if (!ret.same_as(stmt)) {
+    return ConvertSSA(ret);
+  } else {
+    return ret;
+  }
+}
+
+namespace transform {
+
+using namespace tir::transform;
+
+// Builds the PrimFunc pass. The "tl.UnrollLoop" entry of the PassContext is
+// used as configuration; default attribute values are used when it is unset.
+Pass UnrollLoop() {
+  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+    auto *n = f.CopyOnWrite();
+    auto cfg = ctx->GetConfig("tl.UnrollLoop");
+    if (!cfg.defined()) {
+      cfg = AttrsWithDefaultValues();
+    }
+    n->body = tl::UnrollLoop(f->body, cfg.value());
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tl.UnrollLoop", {});
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def("tl.transform.UnrollLoop", UnrollLoop);
+}
+
+} // namespace transform
+
+} // namespace tl
+} // namespace tvm
diff --git a/testing/python/language/test_tilelang_language_unroll.py b/testing/python/language/test_tilelang_language_unroll.py
index 665e57737..2adb63855 100644
--- a/testing/python/language/test_tilelang_language_unroll.py
+++ b/testing/python/language/test_tilelang_language_unroll.py
@@ -1,3 +1,4 @@
+import tilelang
 import tilelang.testing
 from tilelang import tvm as tvm
 from tilelang import language as T
@@ -33,5 +34,22 @@ def main(A_ptr: T.handle):
     assert "#pragma unroll 4" in kernel.get_kernel_source()
 
 
+def test_unroll_with_extent_only():
+    """Test T.unroll with only extent parameter."""
+
+    @tilelang.jit
+    def unroll_kernel():
+        out = T.empty((512,), dtype=T.float32)
+        with T.Kernel(1, threads=512):
+            tid = T.get_thread_binding()
+            for i in T.unroll(tid % 32):
+                out[i] = i
+        return out
+
+    kernel = unroll_kernel.compile()
+    source = kernel.get_kernel_source()
+    assert "#pragma unroll" in source
+
+
 if __name__ == "__main__":
     tilelang.testing.main()
diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py
index 43eb0eca4..e518350dc 100644
--- a/tilelang/engine/phase.py
+++ b/tilelang/engine/phase.py
@@ -241,7 +241,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     mod = tir.transform.Simplify()(mod)
     mod = tilelang.transform.VectorizeLoop(enable_vectorize=allow_vectorize(pass_ctx=pass_ctx))(mod)
     mod = tilelang.transform.StorageRewrite()(mod)
-    mod = tir.transform.UnrollLoop()(mod)
+    mod = tilelang.transform.UnrollLoop()(mod)
     mod = tir.transform.RenormalizeSplitPattern()(mod)
     mod = tir.transform.Simplify()(mod)
     mod = tir.transform.RemoveNoOp()(mod)
diff --git a/tilelang/transform/__init__.py b/tilelang/transform/__init__.py
index ddea1a96e..98a5d2bc9 100644
--- a/tilelang/transform/__init__.py
+++ b/tilelang/transform/__init__.py
@@ -510,3 +510,21 @@ def LayoutReducer():
         The transform pass object produced by the FFI backend.
     """
     return _ffi_api.LayoutReducer()  # type: ignore
+
+
+def UnrollLoop():
+    """Unroll loops as in Halide pipeline.
+
+    This pass unrolls loops based on configuration options including:
+    - auto_max_step: Threshold of number of steps to be automatically unrolled
+    - auto_max_depth: Maximum nested level of loops that can be automatically unrolled
+    - auto_max_extent: Maximum extent of loop that will be unrolled
+    - explicit_unroll: Whether to explicitly unroll instead of setting a pragma
+    - unroll_local_access: Whether to always unroll local access
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.UnrollLoop()  # type: ignore