@@ -391,81 +391,163 @@ void RelaxBufferRegions(const Map<Var, PrimExpr>& binding,
391391 * domain
392392 * \param provided The provided integer set to cover the required domain
393393 * \param required The required domain to be covered
394+ * \param dim_max The maximum index bound by the buffer shape
394395 * \param analyzer The arithmetic analyzer
395396 */
396- std::pair<Var, arith::IntSet> SolveBlockVarDomain (const arith::IntSet& provided,
397- const arith::IntSet& required,
398- arith::Analyzer* analyzer) {
397+ std::pair<Var, BlockVarDomainInfo> SolveBlockVarDomain (const arith::IntSet& provided,
398+ const arith::IntSet& required,
399+ PrimExpr dim_max,
400+ arith::Analyzer* analyzer) {
399401 PrimExpr provided_min = analyzer->Simplify (provided.min ());
400402 PrimExpr provided_max = analyzer->Simplify (provided.max ());
401403 PrimExpr required_min = analyzer->Simplify (required.min ());
402404 PrimExpr required_max = analyzer->Simplify (required.max ());
403- PrimExpr dom_min{ nullptr }, dom_max{ nullptr } ;
404- Var dom_var{ObjectPtr<VarNode>{ nullptr }} ;
405+ arith::IntSet var_dom, var_bound ;
406+ Optional<Var> var ;
405407 arith::PVar<Var> p_v;
406408 arith::PVar<PrimExpr> p_e;
407409 if ((p_v * p_e).Match (provided_min) || (p_e * p_v).Match (provided_min)) {
408410 PrimExpr e = p_e.Eval ();
409- dom_var = p_v.Eval ();
410- dom_min = floordiv (required_min, e);
411- dom_max = floordiv (required_max , e);
411+ var = p_v.Eval ();
412+ var_dom = arith::IntSet::Interval ( floordiv (required_min, e), floordiv (required_max, e) );
413+ var_bound = arith::IntSet::Interval ( 0 , floordiv (dim_max , e) );
412414 } else if (analyzer->CanProveEqual (provided_min, provided_max)) {
413415 if (p_v.Match (provided_min)) {
414- dom_var = p_v.Eval ();
415- dom_min = required_min;
416- dom_max = required_max ;
416+ var = p_v.Eval ();
417+ var_dom = arith::IntSet::Interval ( required_min, required_max) ;
418+ var_bound = arith::IntSet::Interval ( 0 , dim_max) ;
417419 } else {
418420 arith::PVar<PrimExpr> p_f;
419421 if ((floordiv (p_v, p_f)).Match (provided_min)) {
420422 // a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + fac - 1)
421423 PrimExpr fac = p_f.Eval ();
422424 if (analyzer->CanProveGreaterEqual (fac, 1 )) {
423- dom_var = p_v.Eval ();
424- dom_min = required_min * fac;
425- dom_max = analyzer->Simplify (required_max * fac + fac - 1 );
425+ var = p_v.Eval ();
426+ var_dom = arith::IntSet::Interval (required_min * fac,
427+ analyzer->Simplify (required_max * fac + fac - 1 ));
428+ var_bound = arith::IntSet::Interval (0 , analyzer->Simplify (dim_max * fac + fac - 1 ));
426429 }
427430 } else if ((floormod (p_v, p_f).Match (provided_min))) {
428431 // generally domain of (x % fac) enforce no constraints to domain of x
429- dom_var = p_v.Eval ();
430- return std::make_pair (dom_var, arith::IntSet::Nothing ());
432+ return {p_v.Eval (), BlockVarDomainInfo ()};
431433 }
432434 }
433435 }
434- ICHECK (dom_var .defined ()) << " ValueError: BufferRegion pattern match failed: " << provided_min;
435- return std::make_pair (dom_var, arith::IntSet::Interval (dom_min, dom_max)) ;
436+ ICHECK (var .defined ()) << " ValueError: BufferRegion pattern match failed: " << provided_min;
437+ return {var. value (), BlockVarDomainInfo{var_dom, var_bound}} ;
436438}
437439
438440/* !
439- * \brief Calculate and update the iteration domain info to fully cover the required domain
440- * \param provided The provided integer set to cover the required domain
441- * \param required The required domain to be covered
442- * \param required_bound The additional region bound of the required domain to be covered
441+ * \brief Calculate and update the iteration domain info to fully cover the required domain in
442+ * dimension-wise fashion. The region relation on each buffer dimension is independently estimated.
443+ * \param buffer The accessed buffer
444+ * \param provided_region The provided NDIntSet to cover the required domain
445+ * \param required_region The required NDIntSet domain to be covered
446+ * \param analyzer The arithmetic analyzer
443447 * \param iter_doms The result iteration domains to be updated
448+ */
449+ void UpdateBlockVarDomainDimwise (
450+ const BufferNode* buffer, const NDIntSet& provided_region, const NDIntSet& required_region,
451+ arith::Analyzer* analyzer, std::unordered_map<const VarNode*, BlockVarDomainInfo>* iter_doms) {
452+ size_t ndim = buffer->shape .size ();
453+ for (size_t i = 0 ; i < ndim; ++i) {
454+ arith::IntSet provided = provided_region[i];
455+ arith::IntSet required = required_region[i];
456+ PrimExpr dim_max = max (buffer->shape [i] - 1 , 0 );
457+
458+ if (provided.IsSinglePoint () && is_const_int (provided.min ())) {
459+ ICHECK (required.IsSinglePoint () && analyzer->CanProveEqual (provided.min (), required.min ()));
460+ continue ;
461+ }
462+
463+ auto [var, dom_info] = SolveBlockVarDomain (provided, required, dim_max, analyzer);
464+ auto it = iter_doms->find (var.get ());
465+ if (it != iter_doms->end ()) {
466+ it->second .Union (dom_info);
467+ } else {
468+ ICHECK (analyzer->CanProveEqual (provided.min (), required.min ()));
469+ ICHECK (analyzer->CanProveEqual (provided.max (), required.max ()));
470+ }
471+ }
472+ }
473+
474+ /* ! \brief Helper function to implement intset version of `InverseAffineIterMap`. */
475+ Map<Var, arith::IntSet> InverseAffineIterMap (const Array<arith::IterSumExpr>& iter_map,
476+ const NDIntSet& outputs, arith::Analyzer* analyzer) {
477+ Array<PrimExpr> min_point, max_point;
478+ min_point.reserve (outputs.size ());
479+ max_point.reserve (outputs.size ());
480+ for (const auto & intset : outputs) {
481+ ICHECK (intset.HasLowerBound () && intset.HasUpperBound ());
482+ min_point.push_back (intset.min ());
483+ max_point.push_back (intset.max ());
484+ }
485+ auto rev_min = InverseAffineIterMap (iter_map, min_point);
486+ auto rev_max = InverseAffineIterMap (iter_map, max_point);
487+ Map<Var, arith::IntSet> dom_map;
488+ for (const auto & kv : rev_min) {
489+ const Var& var = kv.first ;
490+ auto it = rev_max.find (var);
491+ ICHECK (it != rev_max.end ()); // InverseAffineIterMap's result vars are assumed stable
492+ const PrimExpr& rev_min_point = kv.second ;
493+ const PrimExpr& rev_max_point = (*it).second ;
494+ dom_map.Set (var,
495+ arith::IntSet::Interval (analyzer->Simplify (min (rev_min_point, rev_max_point)),
496+ analyzer->Simplify (max (rev_min_point, rev_max_point))));
497+ }
498+ return dom_map;
499+ }
500+
501+ /* !
502+ * \brief Calculate and update the iteration domain info to fully cover the required domain
503+ * with affine analysis. It requires bijective mapping of block var to provided region points.
504+ * \param buffer The accessed buffer
505+ * \param iter_vars The list of block vars to cover the required region
506+ * \param provided_region The provided NDIntSet to cover the required domain
507+ * \param required_region The required NDIntSet domain to be covered
444508 * \param analyzer The arithmetic analyzer
509+ * \param iter_doms The result iteration domains to be updated
510+ * \returns bool. Denotes whether update success
445511 */
446- void UpdateBlockVarDomain (const arith::IntSet& provided, const arith::IntSet& required,
447- const arith::IntSet& required_bound,
448- std::unordered_map<const VarNode*, BlockVarDomainInfo>* iter_doms,
449- arith::Analyzer* analyzer) {
450- if (provided.IsSinglePoint () && is_const_int (provided.min ())) {
451- ICHECK (required.IsSinglePoint () && analyzer->CanProveEqual (provided.min (), required.min ()));
452- ICHECK (required_bound.IsSinglePoint () &&
453- analyzer->CanProveEqual (provided.min (), required_bound.min ()));
454- return ;
512+ bool UpdateBlockVarDomainAffine (const BufferNode* buffer, const Array<IterVar>& iter_vars,
513+ const NDIntSet& provided_region, const NDIntSet& required_region,
514+ arith::Analyzer* analyzer,
515+ std::unordered_map<const VarNode*, BlockVarDomainInfo>* iter_doms) {
516+ // we only support single point provided region now, which could cover most cases
517+ for (const auto & intset : provided_region) {
518+ if (!intset.IsSinglePoint ()) return false ;
519+ }
520+ // calculate forward mapping (block vars -> provided region point)
521+ Map<Var, Range> dom_map;
522+ for (const IterVar& iter_var : iter_vars) {
523+ dom_map.Set (iter_var->var , iter_var->dom );
455524 }
456- auto var_with_dom = SolveBlockVarDomain (provided, required, analyzer);
457- auto var_with_bound = SolveBlockVarDomain (provided, required_bound, analyzer);
458- const Var& var = var_with_dom.first ;
459- const auto & var_dom = var_with_dom.second ;
460- const auto & var_bound = var_with_bound.second ;
461- ICHECK (var.same_as (var_with_bound.first ));
462- auto it = iter_doms->find (var.get ());
463- if (it != iter_doms->end ()) {
464- it->second .Union ({var_dom, var_bound});
465- } else {
466- ICHECK (analyzer->CanProveEqual (provided.min (), required.min ()));
467- ICHECK (analyzer->CanProveEqual (provided.max (), required.max ()));
525+ size_t ndim = buffer->shape .size ();
526+ Array<PrimExpr> provide_indices;
527+ provide_indices.reserve (ndim);
528+ for (size_t i = 0 ; i < ndim; ++i) {
529+ provide_indices.push_back (provided_region[i].min ());
530+ }
531+ auto res = arith::DetectIterMap (provide_indices, dom_map, const_true (),
532+ arith::IterMapLevel::Bijective, analyzer, false );
533+ if (res->indices .empty ()) {
534+ return false ;
468535 }
536+ // calculate backward mapping (required region point -> block vars)
537+ NDIntSet required_bound;
538+ for (size_t i = 0 ; i < ndim; ++i) {
539+ required_bound.push_back (
540+ arith::IntSet::Interval (make_zero (buffer->shape [i]->dtype ), max (buffer->shape [i] - 1 , 0 )));
541+ }
542+ Map<Var, arith::IntSet> var_dom = InverseAffineIterMap (res->indices , required_region, analyzer);
543+ Map<Var, arith::IntSet> var_bound = InverseAffineIterMap (res->indices , required_bound, analyzer);
544+ for (const auto & kv : var_dom) {
545+ const Var& var = kv.first ;
546+ auto it = var_bound.find (var);
547+ ICHECK (it != var_bound.end ()); // InverseAffineIterMap's result vars are assumed stable
548+ (*iter_doms)[var.get ()].Union (BlockVarDomainInfo{kv.second , (*it).second });
549+ }
550+ return true ;
469551}
470552
471553/* !
@@ -501,13 +583,10 @@ std::vector<BlockVarDomainInfo> CalculateBlockVarDomain(
501583 NDIntSet provided_region = support::NDIntSetUnion (many_provided_regions);
502584 ICHECK_EQ (provided_region.size (), buffer->shape .size ());
503585 ICHECK_EQ (required_region.size (), buffer->shape .size ());
504- // For each dimension, update the iteration domain
505- int ndim = buffer->shape .size ();
506- for (int i = 0 ; i < ndim; ++i) {
507- arith::IntSet provided = provided_region[i];
508- arith::IntSet required = required_region[i];
509- arith::IntSet required_bound = arith::IntSet::FromMinExtent (Integer (0 ), buffer->shape [i]);
510- UpdateBlockVarDomain (provided, required, required_bound, &iter_doms, analyzer);
586+ // Try update iter var domains with current required and provided region pair.
587+ if (!UpdateBlockVarDomainAffine (buffer, iter_vars, provided_region, required_region, analyzer,
588+ &iter_doms)) {
589+ UpdateBlockVarDomainDimwise (buffer, provided_region, required_region, analyzer, &iter_doms);
511590 }
512591 }
513592 // Union the iter var domains, put them in the same order of block vars, and return
0 commit comments