-
Notifications
You must be signed in to change notification settings - Fork 439
[Enhancement] Improve unroll loop functionality for dynamic extent and corresponding test case #1654
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Enhancement] Improve unroll loop functionality for dynamic extent and corresponding test case #1654
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,325 @@ | ||||||||||||||||||||||||||||
| /* | ||||||||||||||||||||||||||||
| * Licensed to the Apache Software Foundation (ASF) under one | ||||||||||||||||||||||||||||
| * or more contributor license agreements. See the NOTICE file | ||||||||||||||||||||||||||||
| * distributed with this work for additional information | ||||||||||||||||||||||||||||
| * regarding copyright ownership. The ASF licenses this file | ||||||||||||||||||||||||||||
| * to you under the Apache License, Version 2.0 (the | ||||||||||||||||||||||||||||
| * "License"); you may not use this file except in compliance | ||||||||||||||||||||||||||||
| * with the License. You may obtain a copy of the License at | ||||||||||||||||||||||||||||
| * | ||||||||||||||||||||||||||||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||||||||||||||||||||||||||||
| * | ||||||||||||||||||||||||||||
| * Unless required by applicable law or agreed to in writing, | ||||||||||||||||||||||||||||
| * software distributed under the License is distributed on an | ||||||||||||||||||||||||||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||||||||||||||||||||||||||||
| * KIND, either express or implied. See the License for the | ||||||||||||||||||||||||||||
| * specific language governing permissions and limitations | ||||||||||||||||||||||||||||
| * under the License. | ||||||||||||||||||||||||||||
| */ | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| /*! | ||||||||||||||||||||||||||||
| * Loop unrolling as in Halide pipeline. | ||||||||||||||||||||||||||||
| * \file unroll_loop.cc | ||||||||||||||||||||||||||||
| */ | ||||||||||||||||||||||||||||
| // Unrolls the loop as in Halide pipeline. | ||||||||||||||||||||||||||||
| #include <tvm/arith/analyzer.h> | ||||||||||||||||||||||||||||
| #include <tvm/ffi/function.h> | ||||||||||||||||||||||||||||
| #include <tvm/ffi/reflection/registry.h> | ||||||||||||||||||||||||||||
| #include <tvm/tir/expr.h> | ||||||||||||||||||||||||||||
| #include <tvm/tir/op.h> | ||||||||||||||||||||||||||||
| #include <tvm/tir/stmt_functor.h> | ||||||||||||||||||||||||||||
| #include <tvm/tir/transform.h> | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| #include <unordered_set> | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| #include "runtime/thread_storage_scope.h" | ||||||||||||||||||||||||||||
| #include "tir/transforms/ir_utils.h" | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| namespace tvm { | ||||||||||||||||||||||||||||
| namespace tl { | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| using namespace tir; | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| struct UnrollLoopConfigNode | ||||||||||||||||||||||||||||
| : public AttrsNodeReflAdapter<UnrollLoopConfigNode> { | ||||||||||||||||||||||||||||
| int auto_max_step; | ||||||||||||||||||||||||||||
| int auto_max_depth; | ||||||||||||||||||||||||||||
| int auto_max_extent; | ||||||||||||||||||||||||||||
| int explicit_unroll; | ||||||||||||||||||||||||||||
| int unroll_local_access; | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| static void RegisterReflection() { | ||||||||||||||||||||||||||||
| namespace refl = tvm::ffi::reflection; | ||||||||||||||||||||||||||||
| refl::ObjectDef<UnrollLoopConfigNode>() | ||||||||||||||||||||||||||||
| .def_ro("auto_max_step", &UnrollLoopConfigNode::auto_max_step, | ||||||||||||||||||||||||||||
| "Threshold of number of steps in the loop to be automatically " | ||||||||||||||||||||||||||||
| "unrolled", | ||||||||||||||||||||||||||||
| refl::DefaultValue(0)) | ||||||||||||||||||||||||||||
| .def_ro("auto_max_depth", &UnrollLoopConfigNode::auto_max_depth, | ||||||||||||||||||||||||||||
| "The maximum nested level of loops that can be automatically " | ||||||||||||||||||||||||||||
| "unrolled.", | ||||||||||||||||||||||||||||
| refl::DefaultValue(8)) | ||||||||||||||||||||||||||||
| .def_ro("auto_max_extent", &UnrollLoopConfigNode::auto_max_extent, | ||||||||||||||||||||||||||||
| "The maximum extent` of loop that will be unrolled.", | ||||||||||||||||||||||||||||
| refl::DefaultValue(0)) | ||||||||||||||||||||||||||||
| .def_ro( | ||||||||||||||||||||||||||||
| "explicit_unroll", &UnrollLoopConfigNode::explicit_unroll, | ||||||||||||||||||||||||||||
| "Whether to explicitly unroll the loop instead of setting a pragma", | ||||||||||||||||||||||||||||
| refl::DefaultValue(true)) | ||||||||||||||||||||||||||||
| .def_ro( | ||||||||||||||||||||||||||||
| "unroll_local_access", &UnrollLoopConfigNode::unroll_local_access, | ||||||||||||||||||||||||||||
| "Whether to always unroll local access", refl::DefaultValue(false)); | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.transform.UnrollLoopConfig", | ||||||||||||||||||||||||||||
| UnrollLoopConfigNode, BaseAttrsNode); | ||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| class UnrollLoopConfig : public Attrs { | ||||||||||||||||||||||||||||
| public: | ||||||||||||||||||||||||||||
| TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(UnrollLoopConfig, Attrs, | ||||||||||||||||||||||||||||
| UnrollLoopConfigNode); | ||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| TVM_FFI_STATIC_INIT_BLOCK() { UnrollLoopConfigNode::RegisterReflection(); } | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| TVM_REGISTER_PASS_CONFIG_OPTION("tl.UnrollLoop", UnrollLoopConfig); | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| class VarLocalAccessMarker : public ExprVisitor { | ||||||||||||||||||||||||||||
| public: | ||||||||||||||||||||||||||||
| explicit VarLocalAccessMarker(std::unordered_set<Var> *var_touched_local) | ||||||||||||||||||||||||||||
| : var_touched_local_(var_touched_local) {} | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| void VisitExpr_(const VarNode *op) final { | ||||||||||||||||||||||||||||
| var_touched_local_->insert(ffi::GetRef<Var>(op)); | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| private: | ||||||||||||||||||||||||||||
| std::unordered_set<Var> *var_touched_local_; | ||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| // The Visitor is used to check whether var is used as write index in a local | ||||||||||||||||||||||||||||
| // memory If a loop var is used as indices to a local memory, it must be | ||||||||||||||||||||||||||||
| // unrolled so the local memory access can be turned into register access. | ||||||||||||||||||||||||||||
| class LoopUnroller : public StmtExprMutator { | ||||||||||||||||||||||||||||
| public: | ||||||||||||||||||||||||||||
| explicit LoopUnroller(int auto_max_step, int auto_max_depth, | ||||||||||||||||||||||||||||
| int auto_max_extent, bool explicit_unroll, | ||||||||||||||||||||||||||||
| bool unroll_local_access) | ||||||||||||||||||||||||||||
| : auto_max_step_(auto_max_step), auto_max_depth_(auto_max_depth), | ||||||||||||||||||||||||||||
| auto_max_extent_(auto_max_extent), explicit_unroll_(explicit_unroll), | ||||||||||||||||||||||||||||
| unroll_local_access_(unroll_local_access) {} | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| Stmt VisitStmt_(const AttrStmtNode *op) final { | ||||||||||||||||||||||||||||
| if (op->attr_key == "pragma_auto_unroll_max_step") { | ||||||||||||||||||||||||||||
| int value = static_cast<int>(Downcast<Integer>(op->value)->value); | ||||||||||||||||||||||||||||
| std::swap(value, auto_max_step_); | ||||||||||||||||||||||||||||
| Stmt ret = this->VisitStmt(op->body); | ||||||||||||||||||||||||||||
| std::swap(value, auto_max_step_); | ||||||||||||||||||||||||||||
| return ret; | ||||||||||||||||||||||||||||
| } else if (op->attr_key == "pragma_unroll_explicit") { | ||||||||||||||||||||||||||||
| bool explicit_unroll = Downcast<Integer>(op->value)->value; | ||||||||||||||||||||||||||||
| LOG(INFO) << "explicit_unroll: " << explicit_unroll; | ||||||||||||||||||||||||||||
| std::swap(explicit_unroll, explicit_unroll_); | ||||||||||||||||||||||||||||
| Stmt ret = this->VisitStmt(op->body); | ||||||||||||||||||||||||||||
| std::swap(explicit_unroll, explicit_unroll_); | ||||||||||||||||||||||||||||
| return ret; | ||||||||||||||||||||||||||||
|
Comment on lines
+119
to
+125
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove debug logging statement.
🧹 Suggested fix } else if (op->attr_key == "pragma_unroll_explicit") {
bool explicit_unroll = Downcast<Integer>(op->value)->value;
- LOG(INFO) << "explicit_unroll: " << explicit_unroll;
std::swap(explicit_unroll, explicit_unroll_);
Stmt ret = this->VisitStmt(op->body);
std::swap(explicit_unroll, explicit_unroll_);📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||||
| return StmtExprMutator::VisitStmt_(op); | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| Stmt VisitStmt_(const ForNode *op) { | ||||||||||||||||||||||||||||
| // Post order so we can collect more information | ||||||||||||||||||||||||||||
| Stmt stmt = StmtExprMutator::VisitStmt_(op); | ||||||||||||||||||||||||||||
| op = stmt.as<ForNode>(); | ||||||||||||||||||||||||||||
| int value = GetExtent(op); | ||||||||||||||||||||||||||||
| // condition for auto unroll | ||||||||||||||||||||||||||||
| bool auto_unroll = | ||||||||||||||||||||||||||||
| (op->kind == ForKind::kSerial && value >= 0 && | ||||||||||||||||||||||||||||
| normal_loop_depth_ == 0 && unroll_depth_ <= auto_max_depth_); | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| auto_unroll = auto_unroll && (value * step_count_ <= auto_max_step_ || | ||||||||||||||||||||||||||||
| value <= auto_max_extent_); | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| if (op->kind == ForKind::kUnrolled) { | ||||||||||||||||||||||||||||
| if (explicit_unroll_) { | ||||||||||||||||||||||||||||
| ICHECK_GE(value, 0) | ||||||||||||||||||||||||||||
| << "Cannot unroll non-constant loop " << explicit_unroll_; | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
|
Comment on lines
+146
to
+148
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Confusing error message. The error message appends 📝 Suggested fix if (explicit_unroll_) {
ICHECK_GE(value, 0)
- << "Cannot unroll non-constant loop " << explicit_unroll_;
+ << "Cannot unroll non-constant loop with variable: " << op->loop_var;
}📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||
| auto_unroll = true; | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| // If a loop var is used as indices to a local memory, it must be unrolled | ||||||||||||||||||||||||||||
| // so the local memory access can be turned into register access. | ||||||||||||||||||||||||||||
| if (this->var_touched_local_.count(op->loop_var) && value > 0 && | ||||||||||||||||||||||||||||
| unroll_local_access_) { | ||||||||||||||||||||||||||||
| auto_unroll = true; | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| if (auto_unroll) { | ||||||||||||||||||||||||||||
| step_count_ *= value; | ||||||||||||||||||||||||||||
| unroll_depth_ += 1; | ||||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||||
| normal_loop_depth_ += 1; | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| if ((auto_unroll && explicit_unroll_) || | ||||||||||||||||||||||||||||
| // unroll loops with extent = 1, no matter how many steps in body | ||||||||||||||||||||||||||||
| (0 <= value && value <= auto_max_extent_ && auto_max_extent_ == 1)) { | ||||||||||||||||||||||||||||
| return Unroll(op); | ||||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||||
| if (auto_unroll) { | ||||||||||||||||||||||||||||
| if (op->kind != ForKind::kUnrolled) { | ||||||||||||||||||||||||||||
| auto n = CopyOnWrite(op); | ||||||||||||||||||||||||||||
| n->kind = ForKind::kUnrolled; | ||||||||||||||||||||||||||||
| return For(n); | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| return stmt; | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| PrimExpr VisitExpr_(const BufferLoadNode *op) final { | ||||||||||||||||||||||||||||
| if (unroll_local_access_) { | ||||||||||||||||||||||||||||
| auto storage_scope = | ||||||||||||||||||||||||||||
| runtime::StorageScope::Create(GetPtrStorageScope(op->buffer->data)); | ||||||||||||||||||||||||||||
| if (storage_scope.rank == runtime::StorageRank::kLocal || | ||||||||||||||||||||||||||||
| storage_scope.rank == runtime::StorageRank::kWarp) { | ||||||||||||||||||||||||||||
| VarLocalAccessMarker marker(&var_touched_local_); | ||||||||||||||||||||||||||||
| for (PrimExpr e : op->indices) { | ||||||||||||||||||||||||||||
| marker(e); | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| return ffi::GetRef<PrimExpr>(op); | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| Stmt VisitStmt_(const BufferStoreNode *op) final { | ||||||||||||||||||||||||||||
| ++step_count_; | ||||||||||||||||||||||||||||
| if (unroll_local_access_) { | ||||||||||||||||||||||||||||
| auto storage_scope = | ||||||||||||||||||||||||||||
| runtime::StorageScope::Create(GetPtrStorageScope(op->buffer->data)); | ||||||||||||||||||||||||||||
| if (storage_scope.rank == runtime::StorageRank::kLocal || | ||||||||||||||||||||||||||||
| storage_scope.rank == runtime::StorageRank::kWarp) { | ||||||||||||||||||||||||||||
| VarLocalAccessMarker marker(&var_touched_local_); | ||||||||||||||||||||||||||||
| for (PrimExpr e : op->indices) { | ||||||||||||||||||||||||||||
| marker(e); | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| return StmtExprMutator::VisitStmt_(op); | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| Stmt VisitStmt_(const EvaluateNode *op) final { | ||||||||||||||||||||||||||||
| ++step_count_; | ||||||||||||||||||||||||||||
| return StmtExprMutator::VisitStmt_(op); | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| Stmt VisitStmt_(const SeqStmtNode *op) final { | ||||||||||||||||||||||||||||
| auto fmutate = [this](const Stmt &s) { | ||||||||||||||||||||||||||||
| int step_count = step_count_; | ||||||||||||||||||||||||||||
| int unroll_depth = unroll_depth_; | ||||||||||||||||||||||||||||
| int normal_loop_depth = normal_loop_depth_; | ||||||||||||||||||||||||||||
| step_count_ = 0; | ||||||||||||||||||||||||||||
| unroll_depth_ = 0; | ||||||||||||||||||||||||||||
| normal_loop_depth_ = 0; | ||||||||||||||||||||||||||||
| Stmt ret = this->VisitStmt(s); | ||||||||||||||||||||||||||||
| step_count_ += step_count; | ||||||||||||||||||||||||||||
| normal_loop_depth_ = std::max(normal_loop_depth, normal_loop_depth_); | ||||||||||||||||||||||||||||
| unroll_depth_ = std::max(unroll_depth_, unroll_depth); | ||||||||||||||||||||||||||||
| return ret; | ||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||
| return StmtExprMutator::VisitSeqStmt_(op, false, fmutate); | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| Stmt Unroll(const ForNode *op) { | ||||||||||||||||||||||||||||
| int value = GetExtent(op); | ||||||||||||||||||||||||||||
| // For loop must have a constant integer extent | ||||||||||||||||||||||||||||
| ICHECK_NE(value, -1) << "loop doesn't have a constant integer extent"; | ||||||||||||||||||||||||||||
| if (value == 0) | ||||||||||||||||||||||||||||
| return Evaluate(0); | ||||||||||||||||||||||||||||
| Stmt body = op->body; | ||||||||||||||||||||||||||||
| ffi::Map<Var, PrimExpr> vmap; | ||||||||||||||||||||||||||||
| ffi::Array<Stmt> unrolled; | ||||||||||||||||||||||||||||
| for (int i = 0; i < value; ++i) { | ||||||||||||||||||||||||||||
| vmap.Set(op->loop_var, op->min + make_const(op->loop_var.dtype(), i)); | ||||||||||||||||||||||||||||
| Stmt step = Substitute(body, vmap); | ||||||||||||||||||||||||||||
| unrolled.push_back(step); | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| return SeqStmt::Flatten(unrolled); | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| private: | ||||||||||||||||||||||||||||
| // returns the extent of the loop if it's a constant integer, otherwise return | ||||||||||||||||||||||||||||
| // -1 | ||||||||||||||||||||||||||||
| int GetExtent(const ForNode *op) { | ||||||||||||||||||||||||||||
| // constant folding. | ||||||||||||||||||||||||||||
| PrimExpr extent = analyzer_.Simplify(op->extent); | ||||||||||||||||||||||||||||
| const IntImmNode *v1 = extent.as<IntImmNode>(); | ||||||||||||||||||||||||||||
| int value = -1; | ||||||||||||||||||||||||||||
| // integers that do not fit in int32_t are treated as symbolic, | ||||||||||||||||||||||||||||
| // as it's impossible to unroll such large loops | ||||||||||||||||||||||||||||
| if (v1 != nullptr && v1->value <= std::numeric_limits<int>::max()) { | ||||||||||||||||||||||||||||
| value = static_cast<int>(v1->value); | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| return value; | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| // maximum number of step to perform auto unroll. | ||||||||||||||||||||||||||||
| int auto_max_step_; | ||||||||||||||||||||||||||||
| int auto_max_depth_; | ||||||||||||||||||||||||||||
| // max extent of loop to auto unroll | ||||||||||||||||||||||||||||
| // this does not count the total steps, only count the number of loops | ||||||||||||||||||||||||||||
| int auto_max_extent_; | ||||||||||||||||||||||||||||
| bool explicit_unroll_; | ||||||||||||||||||||||||||||
| // Wether to unroll loops to local access. | ||||||||||||||||||||||||||||
| bool unroll_local_access_{false}; | ||||||||||||||||||||||||||||
|
Comment on lines
+275
to
+276
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Minor typo in comment. "Wether" should be "Whether". 📝 Suggested fix- // Wether to unroll loops to local access.
+ // Whether to unroll loops to local access.
bool unroll_local_access_{false};📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||
| // Number of normal loops in scope | ||||||||||||||||||||||||||||
| int normal_loop_depth_{0}; | ||||||||||||||||||||||||||||
| // number of unrolled cases in current scope. | ||||||||||||||||||||||||||||
| int unroll_depth_{0}; | ||||||||||||||||||||||||||||
| // Number of total steps unrolled | ||||||||||||||||||||||||||||
| int step_count_{0}; | ||||||||||||||||||||||||||||
| // set of indices touched during visit local memory | ||||||||||||||||||||||||||||
| std::unordered_set<Var> var_touched_local_; | ||||||||||||||||||||||||||||
| // analyzer | ||||||||||||||||||||||||||||
| arith::Analyzer analyzer_; | ||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| Stmt UnrollLoop(Stmt stmt, UnrollLoopConfig cfg) { | ||||||||||||||||||||||||||||
| Stmt ret = LoopUnroller(cfg->auto_max_step, cfg->auto_max_depth, | ||||||||||||||||||||||||||||
| cfg->auto_max_extent, cfg->explicit_unroll, | ||||||||||||||||||||||||||||
| cfg->unroll_local_access)(stmt); | ||||||||||||||||||||||||||||
| if (!ret.same_as(stmt)) { | ||||||||||||||||||||||||||||
| return ConvertSSA(ret); | ||||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||||
| return ret; | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| namespace transform { | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| using namespace tir::transform; | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| Pass UnrollLoop() { | ||||||||||||||||||||||||||||
| auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { | ||||||||||||||||||||||||||||
| auto *n = f.CopyOnWrite(); | ||||||||||||||||||||||||||||
| auto cfg = ctx->GetConfig<UnrollLoopConfig>("tl.UnrollLoop"); | ||||||||||||||||||||||||||||
| if (!cfg.defined()) { | ||||||||||||||||||||||||||||
| cfg = AttrsWithDefaultValues<UnrollLoopConfig>(); | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
| n->body = tl::UnrollLoop(f->body, cfg.value()); | ||||||||||||||||||||||||||||
| return f; | ||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||
| return CreatePrimFuncPass(pass_func, 0, "tl.UnrollLoop", {}); | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| TVM_FFI_STATIC_INIT_BLOCK() { | ||||||||||||||||||||||||||||
| namespace refl = tvm::ffi::reflection; | ||||||||||||||||||||||||||||
| refl::GlobalDef().def("tl.transform.UnrollLoop", UnrollLoop); | ||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| } // namespace transform | ||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||
| } // namespace tl | ||||||||||||||||||||||||||||
| } // namespace tvm | ||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor typo in documentation string.
There's an extra backtick character in the description string:
"The maximum extent\of loop..."should be"The maximum extent of loop..."`.📝 Suggested fix
.def_ro("auto_max_extent", &UnrollLoopConfigNode::auto_max_extent, - "The maximum extent` of loop that will be unrolled.", + "The maximum extent of loop that will be unrolled.", refl::DefaultValue(0))📝 Committable suggestion
🤖 Prompt for AI Agents