1+
12/* !
23 * Copyright (c) 2016 by Contributors
34 * \file bound.cc
@@ -277,10 +278,12 @@ void BoundProp(const Operation& op,
277278 }
278279}
279280
281+
280282// Given the bound of output of op
281283// Pass the bound to the related axis in op.
282284void GatherOpBound (const ScanOpNode* scan,
283285 const Operation& op,
286+ const FeedGraph& fg,
284287 const std::unordered_map<Tensor, TensorDom>& tmap,
285288 std::unordered_map<IterVar, Range>* rmap) {
286289 CHECK (!rmap->count (scan->scan_axis ));
@@ -299,21 +302,29 @@ void GatherOpBound(const ScanOpNode* scan,
299302 Range r = arith::Union (time_dom).cover_range (sdom);
300303 (*rmap)[scan->scan_axis ] = Range::make_with_min_extent (
301304 sdom->min , ir::Simplify (r->extent + r->min - sdom->min ));
305+ Array<Operation> body = ScanGetBody_ (scan, fg);
306+ Map<IterVar, Expr> fix_pt = ScanFixPointAnalysis (op, body);
302307 // Update for spatial axis.
303308 size_t sp_idx = 0 ;
304309 for (size_t i = 0 ; i < output.size (); ++i) {
310+ const TensorDom& d = tmap.at (output[i]);
305311 for (size_t k = 0 ; k < scan->update [i]->shape .size (); ++k, ++sp_idx) {
306312 IterVar sp_ax = scan->spatial_axis_ [sp_idx];
307313 CHECK (!rmap->count (sp_ax));
308- // In default, we always need all spatial axis
309- // Unless that axis only refers back to itself as a fixed point.
310- // TODO(tqchen): Add fix point detection.
311- (*rmap)[sp_ax] = sp_ax->dom ;
314+ CHECK (fix_pt.count (sp_ax));
315+ if (fix_pt[sp_ax].as <ir::IntImm>()->value ) {
316+ // fix point, we can slice it.
317+ (*rmap)[sp_ax] = arith::Union (d.data [k + 1 ]).cover_range (sp_ax->dom );
318+ } else {
319+ // not a fix point, need to include everything.
320+ (*rmap)[sp_ax] = sp_ax->dom ;
321+ }
312322 }
313323 }
314324}
315325
316326void GatherOpBound (const Operation& op,
327+ const FeedGraph& fg,
317328 const std::unordered_map<Tensor, TensorDom>& tmap,
318329 std::unordered_map<IterVar, Range>* rmap) {
319330 if (op.as <ComputeOpNode>()) {
@@ -329,7 +340,7 @@ void GatherOpBound(const Operation& op,
329340 (*rmap)[compute->reduce_axis [i]] = compute->reduce_axis [i]->dom ;
330341 }
331342 } else if (op.as <ScanOpNode>()) {
332- GatherOpBound (op.as <ScanOpNode>(), op, tmap, rmap);
343+ GatherOpBound (op.as <ScanOpNode>(), op, fg, tmap, rmap);
333344 } else if (op.as <PlaceholderOpNode>()) {
334345 // dp nothing
335346 } else {
@@ -347,31 +358,26 @@ inline bool ScopeRelax(const IterVar& iv, const std::string& scope) {
347358 return StorageScope::make (scope).rank <= ThreadScope::make (iv->thread_tag ).rank ;
348359}
349360
350- // The map beteen tensor and operation it feeds ti
351- using FeedGraph = std::unordered_map<Tensor, std::vector<Operation> >;
352-
353- // AttachPath maps op-> a list of IterVar
354- // That represents the loop nest op sits in from inner most to outermost
355- using AttachPath = Map<Operation, Array<IterVar> >;
356-
357-
358361void InferRootBound (const Stage& stage,
359362 const FeedGraph& feed_graph,
360363 const AttachPath& attach_path,
361364 std::unordered_map<IterVar, Range>* rmap) {
362365 if (stage->attach_type == kInline ) return ;
363- if (stage->attach_type == kRoot || stage->attach_type == kNone ) {
366+ if (stage->is_output ||
367+ stage->attach_type == kRoot ||
368+ stage->attach_type == kNone ) {
364369 for (auto iv : OutputRelatedIterVars (stage->op )) {
365370 CHECK (iv->dom .defined ());
366371 CHECK (!rmap->count (iv));
367372 (*rmap)[iv] = iv->dom ;
368373 }
369374 return ;
370375 }
371- // Infer root bounds for the attached node.
372- CHECK_EQ (stage->attach_type , kScope );
373- Stage parent = stage->attach_stage ;
374- CHECK (parent.defined ());
376+ // parent stage, if any
377+ Stage parent;
378+ if (stage->attach_type == kScope || stage->attach_type == kScanUpdate ) {
379+ parent = stage->attach_stage ;
380+ }
375381
376382 // The tensor domain.
377383 std::unordered_map<Tensor, TensorDom> tmap;
@@ -385,7 +391,7 @@ void InferRootBound(const Stage& stage,
385391 auto it = feed_graph.find (t);
386392 if (it != feed_graph.end ()) {
387393 for (const Operation& op : it->second ) {
388- if (op != parent->op ) {
394+ if (!parent. defined () || op != parent->op ) {
389395 consumers.insert (op);
390396 } else {
391397 direct_consume_by_parent = true ;
@@ -406,14 +412,17 @@ void InferRootBound(const Stage& stage,
406412 }
407413
408414 if (direct_consume_by_parent) {
415+ // parent stage if exist
416+ Stage parent = stage->attach_stage ;
409417 // Bound inference logics in parent.
410418 std::unordered_map<IterVar, IntSet> up_state;
411419 bool fix_value = true ;
412420 for (auto iv : parent->leaf_iter_vars ) {
413421 Range vrange = rmap->at (iv);
414422 CHECK (is_zero (vrange->min ))
415423 << " InferBound requires every leaf iter var's min equals 0, "
416- << " call schedule.normalize to achieve this." ;
424+ << " call schedule.normalize to achieve this. "
425+ << " stage=" << parent;
417426 // special optimization to remove trivial loop
418427 if (is_one (vrange->extent )) {
419428 up_state[iv] = IntSet::single_point (vrange->min );
@@ -464,8 +473,10 @@ void InferRootBound(const Stage& stage,
464473 for (const Operation& op : consumers) {
465474 std::unordered_map<const Variable*, IntSet> dom_map;
466475 bool found = false ;
476+ Array<IterVar> attach = attach_path.at (stage->op );
477+
467478 for (IterVar iv : attach_path.at (op)) {
468- if (iv == stage-> attach_ivar ) {
479+ if (attach. size () != 0 && iv == attach[ 0 ] ) {
469480 found = true ; break ;
470481 }
471482 Range vrange = rmap->at (iv);
@@ -474,7 +485,7 @@ void InferRootBound(const Stage& stage,
474485 << " call schedule.normalize to achieve this." ;
475486 relax_set[iv->var .get ()] = IntSet::range (vrange);
476487 }
477- CHECK (found)
488+ CHECK (found || attach. size () == 0 )
478489 << " Invalid Schedule, cannot find the producer " << stage->op
479490 << " along the loop nest specified by compute_at of consumer " << op;
480491 for (auto iv : OutputRelatedIterVars (op)) {
@@ -483,50 +494,15 @@ void InferRootBound(const Stage& stage,
483494 }
484495 BoundProp (op, dom_map, &tmap);
485496 }
486- GatherOpBound (stage->op , tmap, rmap);
497+ GatherOpBound (stage->op , feed_graph, tmap, rmap);
487498}
488499
489- FeedGraph CreateFeedGraph (const Schedule& sch) {
500+ Map<IterVar, Range> InferBound (const Schedule& sch) {
490501 Array<Operation> roots;
491502 for (Operation op : sch->outputs ) {
492503 roots.push_back (sch->stage_map [op]->op );
493504 }
494- auto g = CreateReadGraph (roots);
495- FeedGraph fg;
496- for (auto kv : g) {
497- for (Tensor t : kv.second ) {
498- fg[t].push_back (kv.first );
499- }
500- }
501- return fg;
502- }
503-
504- // Create AttachPath that maps op-> a list of IterVar
505- // That represents the loop nest op sits in from inner most to outermost
506- AttachPath CreateAttachPath (const Schedule& sch) {
507- AttachPath ret;
508- for (Stage stage : sch->stages ) {
509- Array<IterVar> path;
510- for (Stage s = stage; s->attach_type == kScope ;) {
511- IterVar attach_ivar = s->attach_ivar ;
512- s = s->attach_stage ;
513- bool start_attach = false ;
514- for (size_t i = s->leaf_iter_vars .size (); i != 0 ; --i) {
515- IterVar iv = s->leaf_iter_vars [i - 1 ];
516- if (iv == attach_ivar) start_attach = true ;
517- if (start_attach) path.push_back (iv);
518- }
519- CHECK (start_attach)
520- << " Invalid Schedule: cannot find attach point " << attach_ivar
521- << " in the schedule of " << s->op ;
522- }
523- ret.Set (stage->op , path);
524- }
525- return ret;
526- }
527-
528- Map<IterVar, Range> InferBound (const Schedule& sch) {
529- FeedGraph feed_graph = CreateFeedGraph (sch);
505+ FeedGraph feed_graph = CreateFeedGraph (CreateReadGraph (roots));
530506 AttachPath attach_path = CreateAttachPath (sch);
531507
532508 std::unordered_map<IterVar, Range> ret;
0 commit comments