Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
325 changes: 325 additions & 0 deletions src/transform/unroll_loop.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,325 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Loop unrolling as in Halide pipeline.
* \file unroll_loop.cc
*/
// Unrolls the loop as in Halide pipeline.
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <unordered_set>

#include "runtime/thread_storage_scope.h"
#include "tir/transforms/ir_utils.h"

namespace tvm {
namespace tl {

using namespace tir;

/*!
 * \brief Pass-config attributes for the tl.UnrollLoop pass.
 *
 * Every field is exposed read-only through FFI reflection; the default
 * values are the ones registered in RegisterReflection() below.
 */
struct UnrollLoopConfigNode
    : public AttrsNodeReflAdapter<UnrollLoopConfigNode> {
  // Threshold of number of steps in the loop to be automatically unrolled.
  int auto_max_step;
  // Maximum nested level of loops that can be automatically unrolled.
  int auto_max_depth;
  // Maximum extent of a loop that will be unrolled.
  int auto_max_extent;
  // Non-zero: explicitly unroll the loop instead of setting a pragma.
  // Stored as int, registered with a boolean default.
  int explicit_unroll;
  // Non-zero: always unroll loops that index local memory.
  // Stored as int, registered with a boolean default.
  int unroll_local_access;

  /*! \brief Register the fields, their descriptions and their defaults. */
  static void RegisterReflection() {
    namespace refl = tvm::ffi::reflection;
    refl::ObjectDef<UnrollLoopConfigNode>()
        .def_ro("auto_max_step", &UnrollLoopConfigNode::auto_max_step,
                "Threshold of number of steps in the loop to be automatically "
                "unrolled",
                refl::DefaultValue(0))
        .def_ro("auto_max_depth", &UnrollLoopConfigNode::auto_max_depth,
                "The maximum nested level of loops that can be automatically "
                "unrolled.",
                refl::DefaultValue(8))
        .def_ro("auto_max_extent", &UnrollLoopConfigNode::auto_max_extent,
                "The maximum extent of loop that will be unrolled.",
                refl::DefaultValue(0))
        .def_ro(
            "explicit_unroll", &UnrollLoopConfigNode::explicit_unroll,
            "Whether to explicitly unroll the loop instead of setting a pragma",
            refl::DefaultValue(true))
        .def_ro(
            "unroll_local_access", &UnrollLoopConfigNode::unroll_local_access,
            "Whether to always unroll local access", refl::DefaultValue(false));
  }
  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.transform.UnrollLoopConfig",
                                    UnrollLoopConfigNode, BaseAttrsNode);
};

/*!
 * \brief Reference type wrapping UnrollLoopConfigNode.
 *
 * All members come from the object-ref macro below, which is instantiated
 * for (UnrollLoopConfig, Attrs, UnrollLoopConfigNode).
 */
class UnrollLoopConfig : public Attrs {
public:
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(UnrollLoopConfig, Attrs,
UnrollLoopConfigNode);
};

// Run the reflection registration for the config node from a static-init
// block.
TVM_FFI_STATIC_INIT_BLOCK() { UnrollLoopConfigNode::RegisterReflection(); }

// Register "tl.UnrollLoop" as a pass-config key whose value type is
// UnrollLoopConfig; the pass reads this key from the PassContext.
TVM_REGISTER_PASS_CONFIG_OPTION("tl.UnrollLoop", UnrollLoopConfig);

class VarLocalAccessMarker : public ExprVisitor {
public:
explicit VarLocalAccessMarker(std::unordered_set<Var> *var_touched_local)
: var_touched_local_(var_touched_local) {}

void VisitExpr_(const VarNode *op) final {
var_touched_local_->insert(ffi::GetRef<Var>(op));
}

private:
std::unordered_set<Var> *var_touched_local_;
};

// LoopUnroller tracks which variables are used as indices into local (or
// warp) memory. If a loop var is used as such an index, the loop must be
// unrolled so the local memory access can be turned into register access.
class LoopUnroller : public StmtExprMutator {
public:
/*!
 * \brief Construct an unroller with its initial settings.
 * \param auto_max_step Initial step threshold for automatic unrolling.
 * \param auto_max_depth Maximum unroll nesting depth allowed for automatic
 *        unrolling.
 * \param auto_max_extent Extent at or below which a loop may be
 *        automatically unrolled.
 * \param explicit_unroll Whether selected loops are expanded in place rather
 *        than only marked as unrolled.
 * \param unroll_local_access Whether loops whose variable indexes local
 *        memory are always unrolled.
 */
explicit LoopUnroller(int auto_max_step, int auto_max_depth,
int auto_max_extent, bool explicit_unroll,
bool unroll_local_access)
: auto_max_step_(auto_max_step), auto_max_depth_(auto_max_depth),
auto_max_extent_(auto_max_extent), explicit_unroll_(explicit_unroll),
unroll_local_access_(unroll_local_access) {}

/*!
 * \brief Apply the scoped unroll pragmas carried by attribute statements.
 *
 * "pragma_auto_unroll_max_step" overrides auto_max_step_ and
 * "pragma_unroll_explicit" overrides explicit_unroll_. Each override is
 * active only while the attribute body is visited and the previous setting
 * is restored afterwards. For these two keys the mutated body is returned in
 * place of the attribute statement. Every other attribute is forwarded to
 * the base mutator.
 */
Stmt VisitStmt_(const AttrStmtNode *op) final {
  if (op->attr_key == "pragma_auto_unroll_max_step") {
    int value = static_cast<int>(Downcast<Integer>(op->value)->value);
    // Swap in the pragma value, visit, then swap the old value back.
    std::swap(value, auto_max_step_);
    Stmt ret = this->VisitStmt(op->body);
    std::swap(value, auto_max_step_);
    return ret;
  } else if (op->attr_key == "pragma_unroll_explicit") {
    bool explicit_unroll = Downcast<Integer>(op->value)->value != 0;
    // Same scoped override as above; no logging on this path.
    std::swap(explicit_unroll, explicit_unroll_);
    Stmt ret = this->VisitStmt(op->body);
    std::swap(explicit_unroll, explicit_unroll_);
    return ret;
  } else {
    return StmtExprMutator::VisitStmt_(op);
  }
}

/*!
 * \brief Decide whether a loop is unrolled, and how.
 *
 * The loop body is mutated first, so step_count_, unroll_depth_ and
 * normal_loop_depth_ already describe the body when the decision is made.
 * The result is one of:
 *  - the fully expanded body (see Unroll),
 *  - the same loop with its kind changed to ForKind::kUnrolled, or
 *  - the mutated loop, otherwise unchanged.
 */
Stmt VisitStmt_(const ForNode *op) final {
  // Post order so we can collect more information
  Stmt stmt = StmtExprMutator::VisitStmt_(op);
  op = stmt.as<ForNode>();
  int value = GetExtent(op);
  // condition for auto unroll: a serial loop with a known extent, no normal
  // loop recorded in the current scope, and within the allowed unroll depth.
  bool auto_unroll =
      (op->kind == ForKind::kSerial && value >= 0 &&
       normal_loop_depth_ == 0 && unroll_depth_ <= auto_max_depth_);

  // ... and small enough, either by total step count or by extent.
  auto_unroll = auto_unroll && (value * step_count_ <= auto_max_step_ ||
                                value <= auto_max_extent_);

  if (op->kind == ForKind::kUnrolled) {
    if (explicit_unroll_) {
      // Explicit expansion needs a constant extent; report which loop
      // violated that.
      ICHECK_GE(value, 0) << "Cannot unroll non-constant loop over "
                          << op->loop_var << " with extent " << op->extent;
    }
    auto_unroll = true;
  }

  // If a loop var is used as indices to a local memory, it must be unrolled
  // so the local memory access can be turned into register access.
  if (this->var_touched_local_.count(op->loop_var) && value > 0 &&
      unroll_local_access_) {
    auto_unroll = true;
  }

  // Update the counters consulted by enclosing loops.
  if (auto_unroll) {
    step_count_ *= value;
    unroll_depth_ += 1;
  } else {
    normal_loop_depth_ += 1;
  }

  if ((auto_unroll && explicit_unroll_) ||
      // unroll loops with extent = 1, no matter how many steps in body
      (0 <= value && value <= auto_max_extent_ && auto_max_extent_ == 1)) {
    return Unroll(op);
  } else {
    if (auto_unroll) {
      // Not expanding in place: only mark the loop as unrolled.
      if (op->kind != ForKind::kUnrolled) {
        auto n = CopyOnWrite(op);
        n->kind = ForKind::kUnrolled;
        return For(n);
      }
    }
    return stmt;
  }
}

/*!
 * \brief Record variables used to index loads from local or warp buffers.
 *
 * Marking happens only when unroll_local_access_ is set. The load itself is
 * returned as-is; it is not forwarded to the base mutator.
 */
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
  if (!unroll_local_access_) {
    return ffi::GetRef<PrimExpr>(op);
  }
  auto scope =
      runtime::StorageScope::Create(GetPtrStorageScope(op->buffer->data));
  const bool in_local = scope.rank == runtime::StorageRank::kLocal;
  const bool in_warp = scope.rank == runtime::StorageRank::kWarp;
  if (in_local || in_warp) {
    // Every variable appearing in any index expression is marked.
    VarLocalAccessMarker collect(&var_touched_local_);
    for (PrimExpr index : op->indices) {
      collect(index);
    }
  }
  return ffi::GetRef<PrimExpr>(op);
}

/*!
 * \brief Count a buffer store as one step and record variables used to
 *        index stores into local or warp buffers.
 *
 * Marking happens only when unroll_local_access_ is set. The store is then
 * handed to the base mutator.
 */
Stmt VisitStmt_(const BufferStoreNode *op) final {
  step_count_ += 1;
  if (!unroll_local_access_) {
    return StmtExprMutator::VisitStmt_(op);
  }
  auto scope =
      runtime::StorageScope::Create(GetPtrStorageScope(op->buffer->data));
  const bool in_local = scope.rank == runtime::StorageRank::kLocal;
  const bool in_warp = scope.rank == runtime::StorageRank::kWarp;
  if (in_local || in_warp) {
    // Every variable appearing in any index expression is marked.
    VarLocalAccessMarker collect(&var_touched_local_);
    for (PrimExpr index : op->indices) {
      collect(index);
    }
  }
  return StmtExprMutator::VisitStmt_(op);
}

/*! \brief Count an Evaluate statement as one step, then mutate it as usual. */
Stmt VisitStmt_(const EvaluateNode *op) final {
  step_count_ += 1;
  return StmtExprMutator::VisitStmt_(op);
}

/*!
 * \brief Visit each statement of a sequence with fresh counters, then merge.
 *
 * Each child starts from zeroed counters. After the child is visited, its
 * step count is added to the running total, while both depth counters keep
 * the maximum of the previous value and the child's value.
 */
Stmt VisitStmt_(const SeqStmtNode *op) final {
  auto visit_child = [this](const Stmt &child) {
    // Snapshot what earlier siblings accumulated.
    const int prev_steps = step_count_;
    const int prev_unroll_depth = unroll_depth_;
    const int prev_normal_depth = normal_loop_depth_;
    step_count_ = 0;
    unroll_depth_ = 0;
    normal_loop_depth_ = 0;
    Stmt mutated = this->VisitStmt(child);
    // Steps add up across siblings; depths take the maximum.
    step_count_ += prev_steps;
    normal_loop_depth_ = std::max(prev_normal_depth, normal_loop_depth_);
    unroll_depth_ = std::max(unroll_depth_, prev_unroll_depth);
    return mutated;
  };
  return StmtExprMutator::VisitSeqStmt_(op, false, visit_child);
}

/*!
 * \brief Expand a loop into a flattened sequence of its iterations.
 *
 * Iteration i is the loop body with the loop variable substituted by
 * (min + i). A zero-extent loop becomes Evaluate(0). The extent must be a
 * constant integer; an extent of -1 as reported by GetExtent is rejected.
 */
Stmt Unroll(const ForNode *op) {
  const int extent = GetExtent(op);
  // For loop must have a constant integer extent
  ICHECK_NE(extent, -1) << "loop doesn't have a constant integer extent";
  if (extent == 0) {
    return Evaluate(0);
  }
  ffi::Array<Stmt> copies;
  ffi::Map<Var, PrimExpr> binding;
  for (int iter = 0; iter < extent; ++iter) {
    // Rebind the loop variable to this iteration's concrete value.
    PrimExpr offset = make_const(op->loop_var.dtype(), iter);
    binding.Set(op->loop_var, op->min + offset);
    copies.push_back(Substitute(op->body, binding));
  }
  return SeqStmt::Flatten(copies);
}

private:
// returns the extent of the loop if it's a constant integer, otherwise return
// -1
int GetExtent(const ForNode *op) {
// constant folding.
PrimExpr extent = analyzer_.Simplify(op->extent);
const IntImmNode *v1 = extent.as<IntImmNode>();
int value = -1;
// integers that do not fit in int32_t are treated as symbolic,
// as it's impossible to unroll such large loops
if (v1 != nullptr && v1->value <= std::numeric_limits<int>::max()) {
value = static_cast<int>(v1->value);
}
return value;
}

// maximum number of step to perform auto unroll.
int auto_max_step_;
int auto_max_depth_;
// max extent of loop to auto unroll
// this does not count the total steps, only count the number of loops
int auto_max_extent_;
bool explicit_unroll_;
// Whether to unroll loops to local access.
bool unroll_local_access_{false};
Comment on lines +275 to +276
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Minor typo in comment.

"Wether" should be "Whether".

📝 Suggested fix
-  // Wether to unroll loops to local access.
+  // Whether to unroll loops to local access.
   bool unroll_local_access_{false};
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
// Wether to unroll loops to local access.
bool unroll_local_access_{false};
// Whether to unroll loops to local access.
bool unroll_local_access_{false};
🤖 Prompt for AI Agents
In @src/transform/unroll_loop.cc around lines 275 - 276, Fix the typo in the
comment above the unroll_local_access_ field: change "Wether to unroll loops to
local access." to "Whether to unroll loops to local access." so the comment
correctly reads "Whether to unroll loops to local access." and remains
immediately above the bool unroll_local_access_{false}; declaration.

// Number of normal loops in scope
int normal_loop_depth_{0};
// number of unrolled cases in current scope.
int unroll_depth_{0};
// Number of total steps unrolled
int step_count_{0};
// set of indices touched during visit local memory
std::unordered_set<Var> var_touched_local_;
// analyzer
arith::Analyzer analyzer_;
};

/*!
 * \brief Run LoopUnroller over a statement using the given configuration.
 * \param stmt The statement to transform.
 * \param cfg Unroll settings forwarded to the LoopUnroller constructor.
 * \return The input unchanged when nothing was rewritten; otherwise the
 *         rewritten statement after ConvertSSA.
 */
Stmt UnrollLoop(Stmt stmt, UnrollLoopConfig cfg) {
  LoopUnroller unroller(cfg->auto_max_step, cfg->auto_max_depth,
                        cfg->auto_max_extent, cfg->explicit_unroll,
                        cfg->unroll_local_access);
  Stmt rewritten = unroller(stmt);
  if (rewritten.same_as(stmt)) {
    return rewritten;
  }
  return ConvertSSA(rewritten);
}

namespace transform {

using namespace tir::transform;

/*!
 * \brief Create the tl.UnrollLoop PrimFunc pass.
 *
 * The pass reads its settings from the "tl.UnrollLoop" PassContext config
 * key and falls back to the registered defaults when the key is not set.
 */
Pass UnrollLoop() {
  auto run_on_func = [](PrimFunc func, IRModule mod, PassContext ctx) {
    auto *func_node = func.CopyOnWrite();
    auto config = ctx->GetConfig<UnrollLoopConfig>("tl.UnrollLoop");
    if (!config.defined()) {
      // No user configuration: use the defaults from reflection.
      config = AttrsWithDefaultValues<UnrollLoopConfig>();
    }
    func_node->body = tl::UnrollLoop(func->body, config.value());
    return func;
  };
  return CreatePrimFuncPass(run_on_func, 0, "tl.UnrollLoop", {});
}

// Register the pass factory under the global name
// "tl.transform.UnrollLoop".
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.UnrollLoop", UnrollLoop);
}

} // namespace transform

} // namespace tl
} // namespace tvm
18 changes: 18 additions & 0 deletions testing/python/language/test_tilelang_language_unroll.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import tilelang
import tilelang.testing
from tilelang import tvm as tvm
from tilelang import language as T
Expand Down Expand Up @@ -33,5 +34,22 @@ def main(A_ptr: T.handle):
assert "#pragma unroll 4" in kernel.get_kernel_source()


def test_unroll_with_extent_only():
    """Test T.unroll with only extent parameter.

    The extent passed to T.unroll is ``tid % 32``, which depends on the
    thread binding. The test compiles the kernel and checks that the
    generated source contains ``#pragma unroll``.
    """

    @tilelang.jit
    def unroll_kernel():
        out = T.empty((512,), dtype=T.float32)
        with T.Kernel(1, threads=512):
            tid = T.get_thread_binding()
            # The loop extent is an expression of the thread index.
            for i in T.unroll(tid % 32):
                out[i] = i
        return out

    kernel = unroll_kernel.compile()
    source = kernel.get_kernel_source()
    assert "#pragma unroll" in source


if __name__ == "__main__":
tilelang.testing.main()
2 changes: 1 addition & 1 deletion tilelang/engine/phase.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tir.transform.Simplify()(mod)
mod = tilelang.transform.VectorizeLoop(enable_vectorize=allow_vectorize(pass_ctx=pass_ctx))(mod)
mod = tilelang.transform.StorageRewrite()(mod)
mod = tir.transform.UnrollLoop()(mod)
mod = tilelang.transform.UnrollLoop()(mod)
mod = tir.transform.RenormalizeSplitPattern()(mod)
mod = tir.transform.Simplify()(mod)
mod = tir.transform.RemoveNoOp()(mod)
Expand Down
18 changes: 18 additions & 0 deletions tilelang/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,3 +510,21 @@ def LayoutReducer():
The transform pass object produced by the FFI backend.
"""
return _ffi_api.LayoutReducer() # type: ignore


def UnrollLoop():
    """Create the TileLang loop-unrolling pass (as in the Halide pipeline).

    The pass is configured through these options:

    - auto_max_step: threshold of number of steps to be automatically unrolled
    - auto_max_depth: maximum nested level of loops that can be
      automatically unrolled
    - auto_max_extent: maximum extent of loop that will be unrolled
    - explicit_unroll: whether to explicitly unroll instead of setting a pragma
    - unroll_local_access: whether to always unroll local access

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    unroll_pass = _ffi_api.UnrollLoop()  # type: ignore
    return unroll_pass
Loading