Skip to content

Commit 7d0e66a

Browse files
author
Siyuan Feng
committed
In this PR, I introduce a StmtFunctor ReGenerateDef for deep copy all definition nodes in PrimFunc (including Var, Buffer, and IterVar). This functor can create a new PrimFunc with the same behavior as the old one but contains different Nodes.
This Functor may help TIR fusion or inline multiple PrimFuncs
1 parent 62e0470 commit 7d0e66a

File tree

5 files changed

+398
-1
lines changed

5 files changed

+398
-1
lines changed

include/tvm/tir/stmt_functor.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,16 @@ inline T Substitute(T input, const std::unordered_map<const VarNode*, PrimExpr>&
413413
*/
414414
TVM_DLL void PreOrderVisit(const ObjectRef& stmt_or_expr,
415415
const std::function<bool(const ObjectRef&)>& fvisit);
416+
417+
class PrimFunc;
418+
/*!
419+
* \brief Re-generate the definition nodes for a TIR, including VarDef, BufferDef.
420+
* This pass works as a simple DeepCopy to duplicate a function with different Vars and
421+
* Buffers but the same behavior
422+
* \param func The input PrimFunc.
423+
* \return The new generated func.
424+
*/
425+
TVM_DLL PrimFunc ReGenerateDef(const PrimFunc& func);
416426
} // namespace tir
417427
} // namespace tvm
418428

python/tvm/tir/stmt_functor.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""Statement functor utilities for IR transformations"""
18+
from tvm.tir.function import PrimFunc
1819
from . import _ffi_api
1920

2021

@@ -75,3 +76,21 @@ def substitute(node, vmap):
7576
The result.
7677
"""
7778
return _ffi_api.Substitute(node, vmap) # type: ignore
79+
80+
81+
def regenerate_def(func: PrimFunc):
82+
"""Re-generate the definition nodes for a TIR, including VarDef, BufferDef.
83+
This pass works as a simple DeepCopy to duplicate a function with different Vars and
84+
Buffers but the same behavior
85+
86+
Parameters
87+
----------
88+
func: PrimFunc
89+
The input function
90+
91+
Returns
92+
-------
93+
result : PrimFunc
94+
The new generated func.
95+
"""
96+
return _ffi_api.ReGenerateDef(func) # type: ignore
Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,274 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file regenerate_def.cc
22+
* \brief This pass will regenerate (deep-copy) all defs, including VarDef, BufferDef
23+
*/
24+
25+
#include <tvm/tir/stmt_functor.h>
26+
#include <tvm/tir/transform.h>
27+
28+
#include "../ir/functor_common.h"
29+
30+
namespace tvm {
31+
namespace tir {
32+
33+
#define STMT_REGENERATE_VAR_DEF(NODE, FIELD) \
34+
Stmt VisitStmt_(const NODE* op) final { \
35+
Var new_var = this->ReDefineVar(op->FIELD); \
36+
Stmt stmt = StmtExprMutator::VisitStmt_(op); \
37+
op = stmt.as<NODE>(); \
38+
ICHECK(op != nullptr); \
39+
auto n = make_object<NODE>(*op); \
40+
n->FIELD = std::move(new_var); \
41+
return Stmt(n); \
42+
}
43+
44+
class DefRegenerator : public StmtExprMutator {
45+
public:
46+
static PrimFunc Transform(const PrimFunc& func) {
47+
DefRegenerator generator;
48+
// Redefine params
49+
Array<Var> params;
50+
for (const auto& param : func->params) {
51+
params.push_back(generator.ReDefineVar(param));
52+
}
53+
// Redefine buffers
54+
Map<tir::Var, Buffer> buffer_map;
55+
for (const auto& kv : func->buffer_map) {
56+
const Var& param = kv.first;
57+
const Buffer& buffer = kv.second;
58+
Var new_param = Downcast<Var>(generator.VisitExpr(param));
59+
Buffer new_buffer = generator.VisitBuffer(buffer, true);
60+
buffer_map.Set(new_param, new_buffer);
61+
}
62+
// Visit body
63+
Stmt body = generator(func->body);
64+
// Recreate function
65+
auto n = make_object<PrimFuncNode>(*func.get());
66+
n->params = std::move(params);
67+
n->buffer_map = std::move(buffer_map);
68+
n->body = std::move(body);
69+
return PrimFunc(n);
70+
}
71+
72+
private:
73+
Stmt operator()(Stmt stmt) {
74+
// overide StmtMutator::operator() to disable copy_on_write
75+
// Since this pass tries to explict create a new function rather than update the existing one
76+
allow_copy_on_write_ = false;
77+
return VisitStmt(stmt);
78+
}
79+
80+
Stmt VisitStmt(const Stmt& stmt) final {
81+
auto it = remap_.find(stmt);
82+
if (it != remap_.end()) {
83+
return Downcast<Stmt>((*it).second);
84+
} else {
85+
return StmtMutator::VisitStmt(stmt);
86+
}
87+
}
88+
89+
PrimExpr VisitExpr(const PrimExpr& expr) final {
90+
auto it = remap_.find(expr);
91+
if (it != remap_.end()) {
92+
return Downcast<PrimExpr>((*it).second);
93+
} else {
94+
return ExprMutator::VisitExpr(expr);
95+
}
96+
}
97+
98+
private:
99+
STMT_REGENERATE_VAR_DEF(LetStmtNode, var);
100+
STMT_REGENERATE_VAR_DEF(AllocateNode, buffer_var);
101+
STMT_REGENERATE_VAR_DEF(AllocateConstNode, buffer_var);
102+
STMT_REGENERATE_VAR_DEF(ForNode, loop_var);
103+
104+
Stmt VisitStmt_(const BlockNode* op) final {
105+
// Step 0. Re-define Itervars
106+
Array<IterVar> iter_vars = MutateArray(
107+
op->iter_vars, std::bind(&DefRegenerator::VisitIterVar, this, std::placeholders::_1));
108+
109+
// Step 1. Re-define buffers allocate under the block
110+
Array<Buffer> alloc_buffers = MutateArray(
111+
op->alloc_buffers,
112+
std::bind(&DefRegenerator::VisitBuffer, this, std::placeholders::_1, /*define=*/true));
113+
114+
// Step 2. Re-define match_buffers
115+
Array<MatchBufferRegion> match_buffers =
116+
MutateArray(op->match_buffers,
117+
std::bind(&DefRegenerator::VisitMatchBuffer, this, std::placeholders::_1));
118+
119+
// Step 3. ReDefine match_buffer
120+
Stmt stmt = StmtExprMutator::VisitStmt_(op);
121+
op = stmt.as<BlockNode>();
122+
ICHECK(op);
123+
124+
// Step 4. Revisit access region
125+
Array<BufferRegion> reads = MutateArray(
126+
op->reads, std::bind(&DefRegenerator::VisitBufferRegion, this, std::placeholders::_1));
127+
Array<BufferRegion> writes = MutateArray(
128+
op->writes, std::bind(&DefRegenerator::VisitBufferRegion, this, std::placeholders::_1));
129+
130+
// Step 5. Regenerate block. Since the defs are changed, we need to create a new block
131+
auto n = make_object<BlockNode>(*op);
132+
n->iter_vars = std::move(iter_vars);
133+
n->alloc_buffers = std::move(alloc_buffers);
134+
n->match_buffers = std::move(match_buffers);
135+
n->reads = std::move(reads);
136+
n->writes = std::move(writes);
137+
138+
return Stmt(n);
139+
}
140+
141+
Stmt VisitStmt_(const BufferStoreNode* op) final {
142+
Stmt stmt = StmtExprMutator::VisitStmt_(op);
143+
op = stmt.as<BufferStoreNode>();
144+
ICHECK(op != nullptr);
145+
auto it = remap_.find(op->buffer);
146+
if (it != remap_.end()) {
147+
auto n = make_object<BufferStoreNode>(*op);
148+
n->buffer = Downcast<Buffer>((*it).second);
149+
return BufferStore(n);
150+
} else {
151+
return stmt;
152+
}
153+
}
154+
155+
PrimExpr VisitExpr_(const BufferLoadNode* op) final {
156+
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
157+
op = expr.as<BufferLoadNode>();
158+
ICHECK(op != nullptr);
159+
auto it = remap_.find(op->buffer);
160+
if (it != remap_.end()) {
161+
auto n = make_object<BufferLoadNode>(*op);
162+
n->buffer = Downcast<Buffer>((*it).second);
163+
return BufferLoad(n);
164+
} else {
165+
return expr;
166+
}
167+
}
168+
169+
PrimExpr VisitExpr_(const LoadNode* op) final {
170+
LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
171+
return PrimExpr();
172+
}
173+
174+
Stmt VisitStmt_(const StoreNode* op) final {
175+
LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
176+
return Stmt();
177+
}
178+
179+
private:
180+
Var ReDefineVar(const Var& var) {
181+
Var new_var = Var(make_object<VarNode>(*var.get()));
182+
this->AddDefRemap(var, new_var);
183+
return new_var;
184+
}
185+
186+
template <typename T>
187+
void AddDefRemap(const T& source, const T& target) {
188+
ICHECK(remap_.count(source) == 0);
189+
remap_.Set(source, target);
190+
}
191+
192+
Buffer VisitBuffer(const Buffer& buffer, bool define = false) {
193+
auto it = remap_.find(buffer);
194+
if (it != remap_.end()) {
195+
return Downcast<Buffer>((*it).second);
196+
}
197+
ICHECK(define);
198+
199+
auto redefine_if_is_var = [this](const PrimExpr& expr) -> PrimExpr {
200+
if (const VarNode* var = expr.as<VarNode>()) {
201+
return ReDefineVar(GetRef<Var>(var));
202+
} else {
203+
return VisitExpr(expr);
204+
}
205+
};
206+
207+
// update data
208+
Var data = Downcast<Var>(redefine_if_is_var(buffer->data));
209+
// update shape
210+
Array<PrimExpr> shape = MutateArray(buffer->shape, redefine_if_is_var);
211+
// update strides
212+
Array<PrimExpr> strides = MutateArray(buffer->strides, redefine_if_is_var);
213+
// update elem_offset
214+
PrimExpr elem_offset = redefine_if_is_var(buffer->elem_offset);
215+
216+
auto n = make_object<BufferNode>(*buffer.get());
217+
n->data = std::move(data);
218+
n->shape = std::move(shape);
219+
n->strides = std::move(strides);
220+
n->elem_offset = std::move(elem_offset);
221+
Buffer new_buffer(n);
222+
this->AddDefRemap(buffer, new_buffer);
223+
return new_buffer;
224+
}
225+
226+
IterVar VisitIterVar(const IterVar& iter_var) {
227+
auto it = remap_.find(iter_var);
228+
if (it != remap_.end()) {
229+
return Downcast<IterVar>((*it).second);
230+
}
231+
PrimExpr min = VisitExpr(iter_var->dom->min);
232+
PrimExpr extent = VisitExpr(iter_var->dom->extent);
233+
IterVar new_iter_var(Range(min, extent), ReDefineVar(iter_var->var), iter_var->iter_type,
234+
iter_var->thread_tag);
235+
this->AddDefRemap(iter_var, new_iter_var);
236+
return new_iter_var;
237+
}
238+
239+
MatchBufferRegion VisitMatchBuffer(const MatchBufferRegion& match_buffer) {
240+
Buffer buffer = VisitBuffer(match_buffer->buffer, /*define=*/true);
241+
BufferRegion region = VisitBufferRegion(match_buffer->source);
242+
return MatchBufferRegion(std::move(buffer), std::move(region));
243+
}
244+
245+
Range VisitRange(const Range& range) {
246+
PrimExpr min = VisitExpr(range->min);
247+
PrimExpr extent = VisitExpr(range->extent);
248+
if (min.same_as(range->min) && extent.same_as(range->extent)) {
249+
return range;
250+
} else {
251+
return Range::FromMinExtent(std::move(min), std::move(extent));
252+
}
253+
}
254+
255+
BufferRegion VisitBufferRegion(const BufferRegion& buffer_region) {
256+
Buffer buffer = VisitBuffer(buffer_region->buffer);
257+
Array<Range> region = MutateArray(
258+
buffer_region->region, std::bind(&DefRegenerator::VisitRange, this, std::placeholders::_1));
259+
if (buffer.same_as(buffer_region->buffer) && region.same_as(buffer_region->region)) {
260+
return buffer_region;
261+
} else {
262+
return BufferRegion(std::move(buffer), std::move(region));
263+
}
264+
}
265+
266+
Map<ObjectRef, ObjectRef> remap_;
267+
};
268+
269+
PrimFunc ReGenerateDef(const PrimFunc& func) { return DefRegenerator::Transform(func); }
270+
271+
TVM_REGISTER_GLOBAL("tir.ReGenerateDef").set_body_typed(ReGenerateDef);
272+
273+
} // namespace tir
274+
} // namespace tvm

src/tir/transforms/remove_no_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
namespace tvm {
3434
namespace tir {
3535

36-
// Mark the statment of each stage.
36+
// Mark the statement of each stage.
3737
class NoOpRemover : public StmtMutator {
3838
public:
3939
Stmt VisitStmt_(const LetStmtNode* op) final {

0 commit comments

Comments
 (0)