@@ -79,6 +79,22 @@ void PassUpThreadBinding(const Stage& stage, std::unordered_map<IterVar, bool>*
7979 } else if (const RebaseNode* s = rel.as <RebaseNode>()) {
8080 state[s->parent ] = state[s->rebased ];
8181 } else if (rel.as <SingletonNode>()) {
82+ } else if (const TransformNode* s = rel.as <TransformNode>()) {
83+ // Currently, this marks all original iter vars as deriving from
84+ // a thread bind if any of the transformed variables are bound,
85+ // even if the inverse expression for that iter var doesn't
86+ // depend on the bound variable.
87+
88+ // TODO(Lunderberg): For each of original variable, check
89+ // whether any variable in the inverse expression for it has a
90+ // thread binding.
91+ bool is_thread_binding = false ;
92+ for (const auto & iter_var : s->transformed_variables ) {
93+ is_thread_binding = is_thread_binding || state[iter_var];
94+ }
95+ for (const auto & iter_var : s->original_variables ) {
96+ state[iter_var] = is_thread_binding;
97+ }
8298 } else {
8399 LOG (FATAL) << " unknown relation type" ;
84100 }
@@ -157,6 +173,29 @@ void PassDownDomain(const Stage& stage, std::unordered_map<IterVar, Range>* p_st
157173 Update (p_state, r->rebased , Range::FromMinExtent (0 , state.at (r->parent )->extent ), actx);
158174 } else if (const SingletonNode* s = rel.as <SingletonNode>()) {
159175 Update (p_state, s->iter , Range::FromMinExtent (0 , 1 ), actx);
176+ } else if (const TransformNode* s = rel.as <TransformNode>()) {
177+ bool missing_originals = false ;
178+ for (const auto & iter_var : s->original_variables ) {
179+ if (!state.count (iter_var)) {
180+ ICHECK (allow_missing);
181+ missing_originals = true ;
182+ }
183+ }
184+ if (missing_originals) {
185+ continue ;
186+ }
187+
188+ Array<Range> original_ranges;
189+ for (const auto & iter_var : s->original_variables ) {
190+ original_ranges.push_back (state[iter_var]);
191+ }
192+ Array<Range> updated_ranges = s->forward_transformation ->MapRanges (original_ranges);
193+
194+ ICHECK_EQ (updated_ranges.size (), s->transformed_variables .size ());
195+ for (size_t i = 0 ; i < updated_ranges.size (); i++) {
196+ Update (p_state, s->transformed_variables [i], updated_ranges[i], actx);
197+ }
198+
160199 } else {
161200 LOG (FATAL) << " unknown relation type" ;
162201 }
@@ -225,6 +264,39 @@ void PassUpIndex(const Stage& stage, const Map<IterVar, Range>& dom_map,
225264 state[s->parent ] = value;
226265 }
227266 } else if (rel.as <SingletonNode>()) {
267+ } else if (const TransformNode* s = rel.as <TransformNode>()) {
268+ bool missing_transformed = false ;
269+ for (const auto & iter_var : s->transformed_variables ) {
270+ if (!state.count (iter_var)) {
271+ // for (const auto& kv : state) {
272+ // std::cout << "Looking for " << tvm::PrettyPrint(iter_var) << std::endl;
273+ // std::cout << "State contains " << tvm::PrettyPrint(kv.first) << " -> "
274+ // << tvm::PrettyPrint(kv.second) << std::endl;
275+ // }
276+ // TODO: Decide whether to have this check, for similarity
277+ // with other handlers. In this case, the indices may
278+ // already be in terms of the pre-transformed variables, so
279+ // no need to untransform them?
280+
281+ // ICHECK(allow_missing);
282+ missing_transformed = true ;
283+ }
284+ }
285+ if (missing_transformed) {
286+ continue ;
287+ }
288+
289+ Array<PrimExpr> transformed_indices;
290+ for (const auto & iter_var : s->transformed_variables ) {
291+ transformed_indices.push_back (state[iter_var]);
292+ }
293+ Array<PrimExpr> original_indices = s->inverse_transformation ->MapIndices (transformed_indices);
294+
295+ ICHECK_EQ (original_indices.size (), s->original_variables .size ());
296+ for (size_t i = 0 ; i < original_indices.size (); i++) {
297+ state[s->original_variables [i]] = original_indices[i];
298+ }
299+
228300 } else {
229301 LOG (FATAL) << " unknown relation type" ;
230302 }
@@ -270,6 +342,28 @@ void PassDownIndex(const Stage& stage, const Map<IterVar, Range>& dom_map,
270342 state[s->rebased ] = value;
271343 } else if (const SingletonNode* s = rel.as <SingletonNode>()) {
272344 state[s->iter ] = make_zero (s->iter ->var .dtype ());
345+ } else if (const TransformNode* s = rel.as <TransformNode>()) {
346+ bool missing_originals = false ;
347+ for (const auto & iter_var : s->original_variables ) {
348+ if (!state.count (iter_var)) {
349+ ICHECK (allow_missing);
350+ missing_originals = true ;
351+ }
352+ }
353+ if (missing_originals) {
354+ continue ;
355+ }
356+
357+ Array<PrimExpr> original_indices;
358+ for (const auto & iter_var : s->original_variables ) {
359+ original_indices.push_back (state[iter_var]);
360+ }
361+ Array<PrimExpr> transformed_indices = s->forward_transformation ->MapIndices (original_indices);
362+
363+ ICHECK_EQ (transformed_indices.size (), s->transformed_variables .size ());
364+ for (size_t i = 0 ; i < transformed_indices.size (); i++) {
365+ state[s->transformed_variables [i]] = transformed_indices[i];
366+ }
273367 } else {
274368 LOG (FATAL) << " unknown relation type" ;
275369 }
@@ -351,6 +445,26 @@ void PassUpDomain(const RebaseNode* s, const std::unordered_map<IterVar, Range>&
351445 *parent = arith::EvalSet (s->rebased ->var + parent_min, {{s->rebased , rebased}});
352446}
353447
448+ Array<IntSet> PassUpDomain (const TransformNode* s,
449+ const std::unordered_map<IterVar, Range>& dom_map,
450+ const Map<IterVar, IntSet>& transformed_domains) {
451+ Array<IntSet> output;
452+
453+ Array<PrimExpr> transformed_indices;
454+ for (const auto & iter_var : s->transformed_variables ) {
455+ transformed_indices.push_back (iter_var->var );
456+ }
457+
458+ Array<PrimExpr> transformed_exprs = s->inverse_transformation ->MapIndices (transformed_indices);
459+
460+ ICHECK_EQ (transformed_exprs.size (), s->original_variables .size ());
461+ for (size_t i = 0 ; i < transformed_exprs.size (); i++) {
462+ output.push_back (arith::EvalSet (transformed_exprs[i], transformed_domains));
463+ }
464+
465+ return output;
466+ }
467+
354468void PassUpDomain (const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
355469 std::unordered_map<IterVar, IntSet>* p_state) {
356470 auto & state = *p_state;
@@ -370,6 +484,16 @@ void PassUpDomain(const Stage& stage, const std::unordered_map<IterVar, Range>&
370484 PassUpDomain (r, dom_map, state.at (r->rebased ), &parent);
371485 state[r->parent ] = parent;
372486 } else if (rel.as <SingletonNode>()) {
487+ } else if (const TransformNode* r = rel.as <TransformNode>()) {
488+ Map<IterVar, IntSet> transformed_domains;
489+ for (const auto & var : r->transformed_variables ) {
490+ transformed_domains.Set (var, state.at (var));
491+ }
492+ auto original_ranges = PassUpDomain (r, dom_map, transformed_domains);
493+ ICHECK_EQ (original_ranges.size (), r->original_variables .size ());
494+ for (size_t i = 0 ; i < original_ranges.size (); i++) {
495+ state[r->original_variables [i]] = original_ranges[i];
496+ }
373497 } else {
374498 LOG (FATAL) << " unknown relation type" ;
375499 }
@@ -509,6 +633,22 @@ void PassUpBoundCheck(const Stage& s, const Map<IterVar, Range>& dom_map,
509633 state[s->parent ] = state.at (s->rebased );
510634 } else if (rel.as <SingletonNode>()) {
511635 // nop
636+ } else if (const TransformNode* s = rel.as <TransformNode>()) {
637+ // Currently, this marks all original iter vars as requiring
638+ // bounds checks if any of the transformed variables require
639+ // bounds checks, even if the inverse expression for that iter
640+ // var doesn't depend on the bound variable.
641+
642+ // TODO(Lunderberg): For each of original variable, check
643+ // whether any variable in the inverse expression for it
644+ // requires bounds checking.
645+ bool needs_bounds_check = false ;
646+ for (const auto & iter_var : s->transformed_variables ) {
647+ needs_bounds_check = needs_bounds_check || state[iter_var];
648+ }
649+ for (const auto & iter_var : s->original_variables ) {
650+ state[iter_var] = needs_bounds_check;
651+ }
512652 } else {
513653 LOG (FATAL) << " unknown relation type" ;
514654 }
0 commit comments