Skip to content

Commit 445d36e

Browse files
author
Siyuan Feng
committed
[TIR] StmtFunctor RenewDefs
In this PR, I introduce a StmtFunctor `RenewDefs` that deep-copies all definition nodes in a PrimFunc (including Var, Buffer, and IterVar). This functor creates a new PrimFunc with the same behavior as the old one but built from different nodes. It may help with TIR fusion or with inlining multiple PrimFuncs.
1 parent 62e0470 commit 445d36e

File tree

6 files changed

+480
-3
lines changed

6 files changed

+480
-3
lines changed

include/tvm/tir/stmt_functor.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#ifndef TVM_TIR_STMT_FUNCTOR_H_
2727
#define TVM_TIR_STMT_FUNCTOR_H_
2828

29+
#include <tvm/tir/function.h>
2930
#include <tvm/node/functor.h>
3031
#include <tvm/tir/expr.h>
3132
#include <tvm/tir/expr_functor.h>
@@ -413,6 +414,16 @@ inline T Substitute(T input, const std::unordered_map<const VarNode*, PrimExpr>&
413414
*/
414415
TVM_DLL void PreOrderVisit(const ObjectRef& stmt_or_expr,
415416
const std::function<bool(const ObjectRef&)>& fvisit);
417+
418+
class PrimFunc;
419+
/*!
420+
 * \brief Renew the definition nodes of a TIR PrimFunc, including Var, Buffer and IterVar.
421+
* This pass works as a simple DeepCopy to duplicate a function with different Vars and
422+
 * Buffers but the same behavior.
423+
* \param func The input PrimFunc.
424+
* \return The renewed func.
425+
*/
426+
TVM_DLL PrimFunc RenewDefs(const PrimFunc& func);
416427
} // namespace tir
417428
} // namespace tvm
418429

python/tvm/tir/stmt_functor.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""Statement functor utilities for IR transformations"""
18+
from .function import PrimFunc
1819
from . import _ffi_api
1920

2021

@@ -75,3 +76,21 @@ def substitute(node, vmap):
7576
The result.
7677
"""
7778
return _ffi_api.Substitute(node, vmap) # type: ignore
79+
80+
81+
def renew_defs(func: PrimFunc):
    """Regenerate the definition nodes (Var and Buffer definitions) of a PrimFunc.

    Acts as a simple deep copy: the returned function uses different Vars and
    Buffers than the input but keeps the same behavior.

    Parameters
    ----------
    func: PrimFunc
        The function whose definition nodes should be regenerated.

    Returns
    -------
    result : PrimFunc
        The newly generated function.
    """
    renewed = _ffi_api.RenewDefs(func)  # type: ignore
    return renewed

src/autotvm/feature_visitor.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,14 @@ void FeatureVisitor::VisitStmt_(const ForNode* op) {
6161

6262
// parallel axis, virtual thread
6363
void FeatureVisitor::VisitStmt_(const AttrStmtNode* op) {
64-
if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) {
64+
if (op->attr_key == tir::attr::thread_extent || op->attr_key == tir::attr::virtual_thread) {
6565
Var var = op->node.as<tir::IterVarNode>()->var;
6666
const auto* extent = op->value.as<IntImmNode>();
6767
ICHECK(extent);
6868

6969
std::string name = var.get()->name_hint;
7070
AnnotationType ann = kParallel;
71-
if (op->attr_key == attr::thread_extent) {
71+
if (op->attr_key == tir::attr::thread_extent) {
7272
if (name == "blockIdx.x")
7373
ann = kBlockX;
7474
else if (name == "blockIdx.y")

src/tir/transforms/remove_no_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
namespace tvm {
3434
namespace tir {
3535

36-
// Mark the statment of each stage.
36+
// Mark the statement of each stage.
3737
class NoOpRemover : public StmtMutator {
3838
public:
3939
Stmt VisitStmt_(const LetStmtNode* op) final {

src/tir/transforms/renew_defs.cc

Lines changed: 293 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,293 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file renew_defs.cc
22+
* \brief Renew the definition nodes for a TIR, including Var, Buffer and IterVar.
23+
*/
24+
25+
#include <tvm/tir/stmt_functor.h>
26+
#include <tvm/tir/transform.h>
27+
28+
#include "../ir/functor_common.h"
29+
30+
namespace tvm {
31+
namespace tir {
32+
33+
/*!
 * \brief Generate a `VisitStmt_` override for a statement node that defines a Var.
 *
 * The generated visitor first registers a fresh copy of `op->FIELD` through
 * `ReDefineVar`, then lets `StmtExprMutator` mutate the statement, and finally
 * returns a copy of the mutated node whose FIELD is the fresh Var.
 *
 * \param NODE  The statement node type (e.g. ForNode).
 * \param FIELD The member of NODE that holds the defined Var (e.g. loop_var).
 */
#define STMT_REGENERATE_VAR_DEF(NODE, FIELD)           \
  Stmt VisitStmt_(const NODE* op) final {              \
    /* Define the replacement before visiting, so   */ \
    /* the remap is in place for the base mutator.  */ \
    Var fresh_var = this->ReDefineVar(op->FIELD);      \
    Stmt mutated = StmtExprMutator::VisitStmt_(op);    \
    op = mutated.as<NODE>();                           \
    ICHECK(op != nullptr);                             \
    auto new_node = make_object<NODE>(*op);            \
    new_node->FIELD = std::move(fresh_var);            \
    return Stmt(new_node);                             \
  }
43+
44+
class RenewDefMutator : public StmtExprMutator {
45+
public:
46+
static PrimFunc Transform(const PrimFunc& func) {
47+
RenewDefMutator generator;
48+
// Redefine params
49+
Array<Var> params;
50+
for (const auto& param : func->params) {
51+
params.push_back(generator.ReDefineVar(param));
52+
}
53+
// Redefine buffers
54+
Map<tir::Var, Buffer> buffer_map;
55+
for (const auto& kv : func->buffer_map) {
56+
const Var& param = kv.first;
57+
const Buffer& buffer = kv.second;
58+
Var new_param = Downcast<Var>(generator.VisitExpr(param));
59+
Buffer new_buffer = generator.VisitBuffer(buffer, true);
60+
buffer_map.Set(new_param, new_buffer);
61+
}
62+
// Visit body
63+
Stmt body = generator(func->body);
64+
// Recreate function
65+
auto n = make_object<PrimFuncNode>(*func.get());
66+
n->params = std::move(params);
67+
n->buffer_map = std::move(buffer_map);
68+
n->body = std::move(body);
69+
return PrimFunc(n);
70+
}
71+
72+
private:
73+
Stmt operator()(Stmt stmt) {
74+
// override StmtMutator::operator() to disable copy_on_write
75+
// Since this pass tries to explict create a new function rather than update the existing one
76+
allow_copy_on_write_ = false;
77+
return VisitStmt(stmt);
78+
}
79+
80+
PrimExpr VisitExpr(const PrimExpr& expr) final {
81+
auto it = remap_.find(expr);
82+
if (it != remap_.end()) {
83+
return Downcast<PrimExpr>((*it).second);
84+
} else {
85+
return ExprMutator::VisitExpr(expr);
86+
}
87+
}
88+
89+
private:
90+
STMT_REGENERATE_VAR_DEF(LetStmtNode, var);
91+
STMT_REGENERATE_VAR_DEF(AllocateNode, buffer_var);
92+
STMT_REGENERATE_VAR_DEF(AllocateConstNode, buffer_var);
93+
STMT_REGENERATE_VAR_DEF(ForNode, loop_var);
94+
95+
Stmt VisitStmt_(const BlockNode* op) final {
96+
// Step 0. Re-define Itervars
97+
Array<IterVar> iter_vars = MutateArray(
98+
op->iter_vars, std::bind(&RenewDefMutator::VisitIterVar, this, std::placeholders::_1));
99+
100+
// Step 1. Re-define buffers allocate under the block
101+
Array<Buffer> alloc_buffers = MutateArray(
102+
op->alloc_buffers,
103+
std::bind(&RenewDefMutator::VisitBuffer, this, std::placeholders::_1, /*define=*/true));
104+
105+
// Step 2. Re-define match_buffers
106+
Array<MatchBufferRegion> match_buffers =
107+
MutateArray(op->match_buffers,
108+
std::bind(&RenewDefMutator::VisitMatchBuffer, this, std::placeholders::_1));
109+
110+
// Step 3. Visit body
111+
Stmt stmt = StmtExprMutator::VisitStmt_(op);
112+
op = stmt.as<BlockNode>();
113+
ICHECK(op);
114+
115+
// Step 4. Revisit access region
116+
Array<BufferRegion> reads = MutateArray(
117+
op->reads, std::bind(&RenewDefMutator::VisitBufferRegion, this, std::placeholders::_1));
118+
Array<BufferRegion> writes = MutateArray(
119+
op->writes, std::bind(&RenewDefMutator::VisitBufferRegion, this, std::placeholders::_1));
120+
121+
// Step 5. Regenerate block. Since the defs are changed, we need to create a new block
122+
auto n = make_object<BlockNode>(*op);
123+
n->iter_vars = std::move(iter_vars);
124+
n->alloc_buffers = std::move(alloc_buffers);
125+
n->match_buffers = std::move(match_buffers);
126+
n->reads = std::move(reads);
127+
n->writes = std::move(writes);
128+
129+
return Stmt(n);
130+
}
131+
132+
Stmt VisitStmt_(const BufferStoreNode* op) final {
133+
Stmt stmt = StmtExprMutator::VisitStmt_(op);
134+
op = stmt.as<BufferStoreNode>();
135+
ICHECK(op != nullptr);
136+
Buffer buffer = VisitDeclBuffer(op->buffer);
137+
if (buffer.same_as(op->buffer)) {
138+
return stmt;
139+
} else {
140+
auto n = make_object<BufferStoreNode>(*op);
141+
n->buffer = std::move(buffer);
142+
return BufferStore(n);
143+
}
144+
}
145+
146+
PrimExpr VisitExpr_(const BufferLoadNode* op) final {
147+
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
148+
op = expr.as<BufferLoadNode>();
149+
ICHECK(op != nullptr);
150+
Buffer buffer = VisitDeclBuffer(op->buffer);
151+
if (buffer.same_as(op->buffer)) {
152+
return expr;
153+
} else {
154+
auto n = make_object<BufferLoadNode>(*op);
155+
n->buffer = std::move(buffer);
156+
return BufferLoad(n);
157+
}
158+
}
159+
160+
PrimExpr VisitExpr_(const LoadNode* op) final {
161+
LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
162+
return PrimExpr();
163+
}
164+
165+
Stmt VisitStmt_(const StoreNode* op) final {
166+
LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
167+
return Stmt();
168+
}
169+
170+
private:
171+
Var ReDefineVar(const Var& var) {
172+
Var new_var = Var(make_object<VarNode>(*var.get()));
173+
this->AddDefRemap(var, new_var);
174+
return new_var;
175+
}
176+
177+
template <typename T>
178+
void AddDefRemap(const T& source, const T& target) {
179+
ICHECK(remap_.count(source) == 0);
180+
remap_.Set(source, target);
181+
}
182+
183+
Buffer VisitBuffer(const Buffer& buffer, bool define = false) {
184+
auto it = remap_.find(buffer);
185+
if (it != remap_.end()) {
186+
return Downcast<Buffer>((*it).second);
187+
}
188+
ICHECK(define);
189+
190+
auto redefine_if_is_var = [this](const PrimExpr& expr) -> PrimExpr {
191+
auto it = remap_.find(expr);
192+
if (it != remap_.end()) {
193+
return Downcast<PrimExpr>((*it).second);
194+
} else if (const VarNode* var = expr.as<VarNode>()) {
195+
return this->ReDefineVar(GetRef<Var>(var));
196+
} else {
197+
return ExprMutator::VisitExpr(expr);
198+
}
199+
};
200+
201+
// update data
202+
Var data = Downcast<Var>(redefine_if_is_var(buffer->data));
203+
// update shape
204+
Array<PrimExpr> shape = MutateArray(buffer->shape, redefine_if_is_var);
205+
// update strides
206+
Array<PrimExpr> strides = MutateArray(buffer->strides, redefine_if_is_var);
207+
// update elem_offset
208+
PrimExpr elem_offset = redefine_if_is_var(buffer->elem_offset);
209+
210+
auto n = make_object<BufferNode>(*buffer.get());
211+
n->data = std::move(data);
212+
n->shape = std::move(shape);
213+
n->strides = std::move(strides);
214+
n->elem_offset = std::move(elem_offset);
215+
Buffer new_buffer(n);
216+
this->AddDefRemap(buffer, new_buffer);
217+
return new_buffer;
218+
}
219+
220+
IterVar VisitIterVar(const IterVar& iter_var) {
221+
auto it = remap_.find(iter_var);
222+
if (it != remap_.end()) {
223+
return Downcast<IterVar>((*it).second);
224+
}
225+
PrimExpr min = VisitExpr(iter_var->dom->min);
226+
PrimExpr extent = VisitExpr(iter_var->dom->extent);
227+
IterVar new_iter_var(Range(min, extent), ReDefineVar(iter_var->var), iter_var->iter_type,
228+
iter_var->thread_tag);
229+
this->AddDefRemap(iter_var, new_iter_var);
230+
return new_iter_var;
231+
}
232+
233+
Buffer VisitDeclBuffer(const Buffer& buffer) {
234+
// Due to a recent PR, we can allow undefined buffer appearing in BufferLoad/Store.
235+
// We need to remap them but will not create new var
236+
auto it = remap_.find(buffer);
237+
if (it != remap_.end()) {
238+
return Downcast<Buffer>((*it).second);
239+
}
240+
Var data = Downcast<Var>(VisitExpr(buffer->data));
241+
Array<PrimExpr> shape = MutateArray(
242+
buffer->shape, std::bind(&RenewDefMutator::VisitExpr, this, std::placeholders::_1));
243+
Array<PrimExpr> strides = MutateArray(
244+
buffer->strides, std::bind(&RenewDefMutator::VisitExpr, this, std::placeholders::_1));
245+
PrimExpr elem_offset = VisitExpr(buffer->elem_offset);
246+
247+
auto n = make_object<BufferNode>(*buffer.get());
248+
n->data = std::move(data);
249+
n->shape = std::move(shape);
250+
n->strides = std::move(strides);
251+
n->elem_offset = std::move(elem_offset);
252+
Buffer new_buffer(n);
253+
this->AddDefRemap(buffer, new_buffer);
254+
return new_buffer;
255+
}
256+
257+
MatchBufferRegion VisitMatchBuffer(const MatchBufferRegion& match_buffer) {
258+
Buffer buffer = VisitBuffer(match_buffer->buffer, /*define=*/true);
259+
BufferRegion region = VisitBufferRegion(match_buffer->source);
260+
return MatchBufferRegion(std::move(buffer), std::move(region));
261+
}
262+
263+
Range VisitRange(const Range& range) {
264+
PrimExpr min = VisitExpr(range->min);
265+
PrimExpr extent = VisitExpr(range->extent);
266+
if (min.same_as(range->min) && extent.same_as(range->extent)) {
267+
return range;
268+
} else {
269+
return Range::FromMinExtent(std::move(min), std::move(extent));
270+
}
271+
}
272+
273+
BufferRegion VisitBufferRegion(const BufferRegion& buffer_region) {
274+
Buffer buffer = VisitBuffer(buffer_region->buffer);
275+
Array<Range> region =
276+
MutateArray(buffer_region->region,
277+
std::bind(&RenewDefMutator::VisitRange, this, std::placeholders::_1));
278+
if (buffer.same_as(buffer_region->buffer) && region.same_as(buffer_region->region)) {
279+
return buffer_region;
280+
} else {
281+
return BufferRegion(std::move(buffer), std::move(region));
282+
}
283+
}
284+
285+
Map<ObjectRef, ObjectRef> remap_;
286+
};
287+
288+
/*!
 * \brief Renew the definition nodes of a PrimFunc by delegating to RenewDefMutator.
 * \param func The input PrimFunc.
 * \return The renewed PrimFunc.
 */
PrimFunc RenewDefs(const PrimFunc& func) {
  PrimFunc renewed = RenewDefMutator::Transform(func);
  return renewed;
}
289+
290+
// Register RenewDefs as the typed global function "tir.RenewDefs"
// (presumably what the Python wrapper's `_ffi_api.RenewDefs` call resolves to — confirm).
TVM_REGISTER_GLOBAL("tir.RenewDefs").set_body_typed(RenewDefs);
291+
292+
} // namespace tir
293+
} // namespace tvm

0 commit comments

Comments
 (0)