Skip to content

Commit 98e830b

Browse files
committed
[SCAN] Enable fix point analysis for scan
1 parent d114dfc commit 98e830b

File tree

12 files changed

+523
-91
lines changed

12 files changed

+523
-91
lines changed

include/tvm/schedule.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ enum AttachType : int {
2626
kNone = 0,
2727
kRoot = 1,
2828
kInline = 2,
29-
kScope = 3
29+
kScope = 3,
30+
kScanUpdate = 4
3031
};
3132

3233
/*! \brief IterVar type */

include/tvm/tensor.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,8 @@ class OperationNode : public FunctionBaseNode {
175175
virtual Type output_dtype(size_t i) const = 0;
176176
/*! \return shape of i-th output */
177177
virtual Array<Expr> output_shape(size_t i) const = 0;
178+
179+
static constexpr const char* _type_key = "Operation";
178180
};
179181

180182
// Implementations of inline functions

python/tvm/build.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ def build(sch,
6363
arg_list.append(x)
6464
else:
6565
raise ValueError("args must be Tensor, Buffer or Var")
66-
# lowering
66+
# normalize schedule first
67+
sch.normalize()
6768
bounds = schedule.InferBound(sch)
6869
stmt = schedule.ScheduleOps(sch, bounds)
6970
stmt = ir_pass.StorageFlatten(stmt, binds)

src/api/api_schedule.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ TVM_REGISTER_API(_schedule_AutoInlineElemWise)
3434
REGISTER_SCHEDULE_PASS1(InferBound);
3535
REGISTER_SCHEDULE_PASS1(CreateReadGraph);
3636
REGISTER_SCHEDULE_PASS2(PostDFSOrder);
37+
REGISTER_SCHEDULE_PASS1(ScanGetBody);
38+
REGISTER_SCHEDULE_PASS1(CreateAttachPath);
39+
REGISTER_SCHEDULE_PASS2(ScanFixPointAnalysis);
3740
REGISTER_SCHEDULE_PASS2(ScheduleOps);
3841

3942
} // namespace schedule

src/schedule/bound.cc

Lines changed: 36 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
12
/*!
23
* Copyright (c) 2016 by Contributors
34
* \file bound.cc
@@ -277,10 +278,12 @@ void BoundProp(const Operation& op,
277278
}
278279
}
279280

281+
280282
// Given the bound of output of op
281283
// Pass the bound to the related axis in op.
282284
void GatherOpBound(const ScanOpNode* scan,
283285
const Operation& op,
286+
const FeedGraph& fg,
284287
const std::unordered_map<Tensor, TensorDom>& tmap,
285288
std::unordered_map<IterVar, Range>* rmap) {
286289
CHECK(!rmap->count(scan->scan_axis));
@@ -299,21 +302,29 @@ void GatherOpBound(const ScanOpNode* scan,
299302
Range r = arith::Union(time_dom).cover_range(sdom);
300303
(*rmap)[scan->scan_axis] = Range::make_with_min_extent(
301304
sdom->min, ir::Simplify(r->extent + r->min - sdom->min));
305+
Array<Operation> body = ScanGetBody_(scan, fg);
306+
Map<IterVar, Expr> fix_pt = ScanFixPointAnalysis(op, body);
302307
// Update for spatial axis.
303308
size_t sp_idx = 0;
304309
for (size_t i = 0; i < output.size(); ++i) {
310+
const TensorDom& d = tmap.at(output[i]);
305311
for (size_t k = 0; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
306312
IterVar sp_ax = scan->spatial_axis_[sp_idx];
307313
CHECK(!rmap->count(sp_ax));
308-
// In default, we always need all spatial axis
309-
// Unless that axis only refers back to itself as a fixed point.
310-
// TODO(tqchen): Add fix point detection.
311-
(*rmap)[sp_ax] = sp_ax->dom;
314+
CHECK(fix_pt.count(sp_ax));
315+
if (fix_pt[sp_ax].as<ir::IntImm>()->value) {
316+
// fix point, we can slice it.
317+
(*rmap)[sp_ax] = arith::Union(d.data[k + 1]).cover_range(sp_ax->dom);
318+
} else {
319+
// not a fix point, need to include everything.
320+
(*rmap)[sp_ax] = sp_ax->dom;
321+
}
312322
}
313323
}
314324
}
315325

316326
void GatherOpBound(const Operation& op,
327+
const FeedGraph& fg,
317328
const std::unordered_map<Tensor, TensorDom>& tmap,
318329
std::unordered_map<IterVar, Range>* rmap) {
319330
if (op.as<ComputeOpNode>()) {
@@ -329,7 +340,7 @@ void GatherOpBound(const Operation& op,
329340
(*rmap)[compute->reduce_axis[i]] = compute->reduce_axis[i]->dom;
330341
}
331342
} else if (op.as<ScanOpNode>()) {
332-
GatherOpBound(op.as<ScanOpNode>(), op, tmap, rmap);
343+
GatherOpBound(op.as<ScanOpNode>(), op, fg, tmap, rmap);
333344
} else if (op.as<PlaceholderOpNode>()) {
334345
// dp nothing
335346
} else {
@@ -347,31 +358,26 @@ inline bool ScopeRelax(const IterVar& iv, const std::string& scope) {
347358
return StorageScope::make(scope).rank <= ThreadScope::make(iv->thread_tag).rank;
348359
}
349360

350-
// The map beteen tensor and operation it feeds ti
351-
using FeedGraph = std::unordered_map<Tensor, std::vector<Operation> >;
352-
353-
// AttachPath maps op-> a list of IterVar
354-
// That represents the loop nest op sits in from inner most to outermost
355-
using AttachPath = Map<Operation, Array<IterVar> >;
356-
357-
358361
void InferRootBound(const Stage& stage,
359362
const FeedGraph& feed_graph,
360363
const AttachPath& attach_path,
361364
std::unordered_map<IterVar, Range>* rmap) {
362365
if (stage->attach_type == kInline) return;
363-
if (stage->attach_type == kRoot || stage->attach_type == kNone) {
366+
if (stage->is_output ||
367+
stage->attach_type == kRoot ||
368+
stage->attach_type == kNone) {
364369
for (auto iv : OutputRelatedIterVars(stage->op)) {
365370
CHECK(iv->dom.defined());
366371
CHECK(!rmap->count(iv));
367372
(*rmap)[iv] = iv->dom;
368373
}
369374
return;
370375
}
371-
// Infer root bounds for the attached node.
372-
CHECK_EQ(stage->attach_type, kScope);
373-
Stage parent = stage->attach_stage;
374-
CHECK(parent.defined());
376+
// parent stage, if any
377+
Stage parent;
378+
if (stage->attach_type == kScope || stage->attach_type == kScanUpdate) {
379+
parent = stage->attach_stage;
380+
}
375381

376382
// The tensor domain.
377383
std::unordered_map<Tensor, TensorDom> tmap;
@@ -385,7 +391,7 @@ void InferRootBound(const Stage& stage,
385391
auto it = feed_graph.find(t);
386392
if (it != feed_graph.end()) {
387393
for (const Operation& op : it->second) {
388-
if (op != parent->op) {
394+
if (!parent.defined() || op != parent->op) {
389395
consumers.insert(op);
390396
} else {
391397
direct_consume_by_parent = true;
@@ -406,14 +412,17 @@ void InferRootBound(const Stage& stage,
406412
}
407413

408414
if (direct_consume_by_parent) {
415+
// parent stage if exist
416+
Stage parent = stage->attach_stage;
409417
// Bound inference logics in parent.
410418
std::unordered_map<IterVar, IntSet> up_state;
411419
bool fix_value = true;
412420
for (auto iv : parent->leaf_iter_vars) {
413421
Range vrange = rmap->at(iv);
414422
CHECK(is_zero(vrange->min))
415423
<< "InferBound requires every leaf iter var's min equals 0, "
416-
<< "call schedule.normalize to achieve this.";
424+
<< " call schedule.normalize to achieve this. "
425+
<< " stage=" << parent;
417426
// special optimization to remove trivial loop
418427
if (is_one(vrange->extent)) {
419428
up_state[iv] = IntSet::single_point(vrange->min);
@@ -464,8 +473,10 @@ void InferRootBound(const Stage& stage,
464473
for (const Operation& op : consumers) {
465474
std::unordered_map<const Variable*, IntSet> dom_map;
466475
bool found = false;
476+
Array<IterVar> attach = attach_path.at(stage->op);
477+
467478
for (IterVar iv : attach_path.at(op)) {
468-
if (iv == stage->attach_ivar) {
479+
if (attach.size() != 0 && iv == attach[0]) {
469480
found = true; break;
470481
}
471482
Range vrange = rmap->at(iv);
@@ -474,7 +485,7 @@ void InferRootBound(const Stage& stage,
474485
<< "call schedule.normalize to achieve this.";
475486
relax_set[iv->var.get()] = IntSet::range(vrange);
476487
}
477-
CHECK(found)
488+
CHECK(found || attach.size() == 0)
478489
<< "Invalid Schedule, cannot find the producer " << stage->op
479490
<< " along the loop nest specified by compute_at of consumer " << op;
480491
for (auto iv : OutputRelatedIterVars(op)) {
@@ -483,50 +494,15 @@ void InferRootBound(const Stage& stage,
483494
}
484495
BoundProp(op, dom_map, &tmap);
485496
}
486-
GatherOpBound(stage->op, tmap, rmap);
497+
GatherOpBound(stage->op, feed_graph, tmap, rmap);
487498
}
488499

489-
FeedGraph CreateFeedGraph(const Schedule& sch) {
500+
Map<IterVar, Range> InferBound(const Schedule& sch) {
490501
Array<Operation> roots;
491502
for (Operation op : sch->outputs) {
492503
roots.push_back(sch->stage_map[op]->op);
493504
}
494-
auto g = CreateReadGraph(roots);
495-
FeedGraph fg;
496-
for (auto kv : g) {
497-
for (Tensor t : kv.second) {
498-
fg[t].push_back(kv.first);
499-
}
500-
}
501-
return fg;
502-
}
503-
504-
// Create AttachPath that maps op-> a list of IterVar
505-
// That represents the loop nest op sits in from inner most to outermost
506-
AttachPath CreateAttachPath(const Schedule& sch) {
507-
AttachPath ret;
508-
for (Stage stage : sch->stages) {
509-
Array<IterVar> path;
510-
for (Stage s = stage; s->attach_type == kScope;) {
511-
IterVar attach_ivar = s->attach_ivar;
512-
s = s->attach_stage;
513-
bool start_attach = false;
514-
for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) {
515-
IterVar iv = s->leaf_iter_vars[i - 1];
516-
if (iv == attach_ivar) start_attach = true;
517-
if (start_attach) path.push_back(iv);
518-
}
519-
CHECK(start_attach)
520-
<< "Invalid Schedule: cannot find attach point " << attach_ivar
521-
<< " in the schedule of " << s->op;
522-
}
523-
ret.Set(stage->op, path);
524-
}
525-
return ret;
526-
}
527-
528-
Map<IterVar, Range> InferBound(const Schedule& sch) {
529-
FeedGraph feed_graph = CreateFeedGraph(sch);
505+
FeedGraph feed_graph = CreateFeedGraph(CreateReadGraph(roots));
530506
AttachPath attach_path = CreateAttachPath(sch);
531507

532508
std::unordered_map<IterVar, Range> ret;

0 commit comments

Comments
 (0)