Skip to content

Commit b961ad1

Browse files
committed
Breakpoint, expose the transformed axes for use in TE scheduling.
Final step, exposing the axes generated in .transform_layout for use in TE scheduling.
2 parents 9868fd5 + 5e18278 commit b961ad1

File tree

6 files changed

+455
-13
lines changed

6 files changed

+455
-13
lines changed

include/tvm/te/operation.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ class ComputeOp : public Operation {
265265
Array<IterVar> axis, Array<PrimExpr> body);
266266

267267
TVM_DEFINE_OBJECT_REF_METHODS(ComputeOp, Operation, ComputeOpNode);
268+
TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeOpNode);
268269
};
269270

270271
/*!

include/tvm/te/schedule.h

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -272,10 +272,14 @@ class Stage : public ObjectRef {
272272
* Expressions should be in terms of the variables given in
273273
* initial_indices.
274274
*
275+
* \param out_iter_vars An optional output location for the updated
276+
* loop iteration variables.
277+
*
275278
* \return reference to self
276279
*/
277280
TVM_DLL Stage& transform_layout(const Array<Var>& initial_indices,
278-
const Array<PrimExpr>& final_indices);
281+
const Array<PrimExpr>& final_indices,
282+
Array<IterVar>* out_iter_vars = nullptr);
279283
/*! \brief Defines separators between groups of axes.
280284
*
281285
* Used to define `BufferNode::axis_separators`, which has
@@ -494,9 +498,27 @@ class StageNode : public Object {
494498
* while origin_op remains fixed.
495499
*/
496500
Operation origin_op;
497-
/*! \brief All the nodes in the iter var */
501+
/*! \brief All the nodes in the iter var
502+
*
503+
* Each element of all_iter_vars represents an iteration variable
504+
* that may appear within this stage's computation. Any element
505+
* of `all_iter_vars` that is in `leaf_iter_vars` represents a
506+
* variable that is directly defined and usable within the stage's
507+
* computation. All other elements of `all_iter_vars` represent
508+
* variables whose value must be computed from the variables in
509+
* `leaf_iter_vars`. (e.g. Suppose index k has been split by
510+
* ``ko, ki = s.split(k, factor=4)``. ko and ki will appear in
511+
* `leaf_iter_vars`, while k will not, and must be computed as
512+
* `4*ko + ki`.)
513+
*/
498514
Array<IterVar> all_iter_vars;
499-
/*! \brief The current active leaf iter vars in the stage. */
515+
/*! \brief The current active leaf iter vars in the stage.
516+
*
517+
* Each element of leaf_iter_vars will either be replaced with the
518+
* bound index (e.g. threadIdx.x), or will be expanded into a loop
519+
* over the variable's extent. `leaf_iter_vars` is a subset of
520+
* `all_iter_vars`.
521+
*/
500522
Array<IterVar> leaf_iter_vars;
501523
/*!
502524
* \brief Specify threads to be launched at the stage.
@@ -809,6 +831,36 @@ class Singleton : public IterVarRelation {
809831
TVM_DEFINE_OBJECT_REF_METHODS(Singleton, IterVarRelation, SingletonNode);
810832
};
811833

834+
/*!
835+
* \brief Transform iterator according to some arbitrary expression.
836+
*
* Relates a group of iteration variables (`original_variables`) to a
* second group (`transformed_variables`) through a pair of index maps.
* The message-passing routines shown alongside this class use
* `forward_transformation` to go from values/ranges of the original
* variables to those of the transformed variables, and
* `inverse_transformation` to go from values of the transformed
* variables back to the original variables.
*/
837+
class TransformNode : public IterVarRelationNode {
838+
public:
839+
/*! \brief Iteration variables on the input side of forward_transformation. */
Array<IterVar> original_variables;
840+
/*! \brief Iteration variables on the output side of forward_transformation.
 *
 * The callers visible here check that the number of outputs of
 * forward_transformation equals the size of this array.
 */
Array<IterVar> transformed_variables;
841+
/*! \brief Map applied to indices/ranges of original_variables to obtain
 * those of transformed_variables.
 */
IndexMap forward_transformation;
842+
/*! \brief Map applied to indices of transformed_variables to obtain those
 * of original_variables.
 */
IndexMap inverse_transformation;
843+
844+
/*! \brief Visit each of the four fields under its own member name. */
void VisitAttrs(AttrVisitor* v) {
845+
v->Visit("original_variables", &original_variables);
846+
v->Visit("transformed_variables", &transformed_variables);
847+
v->Visit("forward_transformation", &forward_transformation);
848+
v->Visit("inverse_transformation", &inverse_transformation);
849+
}
850+
851+
static constexpr const char* _type_key = "Transform";
852+
TVM_DECLARE_FINAL_OBJECT_INFO(TransformNode, IterVarRelationNode);
853+
};
854+
855+
/*!
 * \brief Reference class for TransformNode, an IterVarRelation that
 * relates one group of iteration variables to another through a pair
 * of index maps.
 */
class Transform : public IterVarRelation {
856+
public:
857+
/*!
 * \brief Construct a Transform relation.
 *
 * NOTE(review): the constructor body is not visible here; the
 * parameters presumably populate the TransformNode fields of the same
 * names -- confirm against the definition.
 *
 * \param original_variables The iteration variables before the transformation.
 * \param transformed_variables The iteration variables after the transformation.
 * \param forward_transformation Map from original to transformed indices.
 * \param inverse_transformation Map from transformed to original indices.
 */
TVM_DLL explicit Transform(Array<IterVar> original_variables,
858+
Array<IterVar> transformed_variables, IndexMap forward_transformation,
859+
IndexMap inverse_transformation);
860+
861+
TVM_DEFINE_OBJECT_REF_METHODS(Transform, IterVarRelation, TransformNode);
862+
};
863+
812864
/*! \brief Container for specialization conditions. */
813865
class SpecializedConditionNode : public Object {
814866
public:

python/tvm/te/schedule.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -527,9 +527,15 @@ def transform_layout(self, mapping_function: Callable[..., List[tvm.tir.PrimExpr
527527
"""Defines the layout transformation for the current stage's tensor.
528528
529529
The map from initial_indices to final_indices must be an
530-
invertible affine transformation.
530+
invertible affine transformation. This method may be called
531+
more than once for a given tensor, in which case each
532+
transformation is applied sequentially.
531533
532-
This method may be called more than once for a given tensor, in which case each
534+
If the stage is a ComputeOp, then the iteration order of the
535+
compute stage is rewritten to be a row-major traversal of the
536+
tensor, and the new loop iteration variables are returned.
537+
For all other stages, the loop iteration order is unmodified,
538+
and the return value is None.
533539
534540
Parameters
535541
----------
@@ -543,6 +549,17 @@ def transform_layout(self, mapping_function: Callable[..., List[tvm.tir.PrimExpr
543549
the current stage's tensor, using the post-transformation
544550
layout.
545551
552+
Returns
553+
-------
554+
new_iter_vars : Optional[List[tvm.tir.IterVar]]
555+
556+
If the stage is a ComputeOp, then the return will be the
557+
updated loop iteration variables over the data array, in
558+
the same order as the output values from the
559+
`mapping_function`.
560+
561+
Otherwise, the return value is None.
562+
546563
Examples
547564
--------
548565
.. code-block:: python
@@ -557,15 +574,29 @@ def transform_layout(self, mapping_function: Callable[..., List[tvm.tir.PrimExpr
557574
558575
.. code-block:: python
559576
560-
# ``A`` is a tensor whose compute definition is in format,
561-
# and should be transformed such that the last index is
562-
# split, with the slower-chan index of the split placed at the
563-
# slowest changing dimension.
577+
# ``A`` is a tensor whose compute definition is in an
578+
# arbitrary format, and should be transformed such that
579+
# the last index is split, with the slower-changing index
580+
# of the split placed at the slowest changing dimension.
564581
565582
s[A].transform_layout(
566583
lambda *indices, i: [i//4, *indices, i%4]
567584
)
568585
586+
.. code-block:: python
587+
588+
# ``B`` is a tensor defined by te.compute to be a copy of
589+
# ``A``, and should be transformed such that ``B``'s layout
590+
# is a transpose of ``A``'s layout. The loop iteration
591+
# that computes ``B`` will correspond to ``B``'s memory
592+
# layout.
593+
594+
A = te.placeholder([n,m])
595+
B = te.compute(A.shape, lambda i,j: A[i,j])
596+
s = te.create_schedule(B.op)
597+
598+
s[B].transform_layout(lambda i,j: [j,i])
599+
569600
"""
570601

571602
args = []
@@ -626,9 +657,10 @@ def transform_layout(self, mapping_function: Callable[..., List[tvm.tir.PrimExpr
626657
"Instead received {val} of type {type(val)}."
627658
)
628659

629-
_ffi_api.StageTransformLayout(self, initial_indices, final_indices)
660+
new_iter_vars = _ffi_api.StageTransformLayout(self, initial_indices, final_indices)
630661
_ffi_api.StageSetAxisSeparators(self, axis_separators)
631662

663+
return new_iter_vars or None
632664

633665

634666
@tvm._ffi.register_object

src/te/schedule/message_passing.cc

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,22 @@ void PassUpThreadBinding(const Stage& stage, std::unordered_map<IterVar, bool>*
7979
} else if (const RebaseNode* s = rel.as<RebaseNode>()) {
8080
state[s->parent] = state[s->rebased];
8181
} else if (rel.as<SingletonNode>()) {
82+
} else if (const TransformNode* s = rel.as<TransformNode>()) {
83+
// Currently, this marks all original iter vars as deriving from
84+
// a thread bind if any of the transformed variables are bound,
85+
// even if the inverse expression for that iter var doesn't
86+
// depend on the bound variable.
87+
88+
// TODO(Lunderberg): For each original variable, check
89+
// whether any variable in the inverse expression for it has a
90+
// thread binding.
91+
bool is_thread_binding = false;
92+
for (const auto& iter_var : s->transformed_variables) {
93+
is_thread_binding = is_thread_binding || state[iter_var];
94+
}
95+
for (const auto& iter_var : s->original_variables) {
96+
state[iter_var] = is_thread_binding;
97+
}
8298
} else {
8399
LOG(FATAL) << "unknown relation type";
84100
}
@@ -157,6 +173,29 @@ void PassDownDomain(const Stage& stage, std::unordered_map<IterVar, Range>* p_st
157173
Update(p_state, r->rebased, Range::FromMinExtent(0, state.at(r->parent)->extent), actx);
158174
} else if (const SingletonNode* s = rel.as<SingletonNode>()) {
159175
Update(p_state, s->iter, Range::FromMinExtent(0, 1), actx);
176+
} else if (const TransformNode* s = rel.as<TransformNode>()) {
177+
bool missing_originals = false;
178+
for (const auto& iter_var : s->original_variables) {
179+
if (!state.count(iter_var)) {
180+
ICHECK(allow_missing);
181+
missing_originals = true;
182+
}
183+
}
184+
if (missing_originals) {
185+
continue;
186+
}
187+
188+
Array<Range> original_ranges;
189+
for (const auto& iter_var : s->original_variables) {
190+
original_ranges.push_back(state[iter_var]);
191+
}
192+
Array<Range> updated_ranges = s->forward_transformation->MapRanges(original_ranges);
193+
194+
ICHECK_EQ(updated_ranges.size(), s->transformed_variables.size());
195+
for (size_t i = 0; i < updated_ranges.size(); i++) {
196+
Update(p_state, s->transformed_variables[i], updated_ranges[i], actx);
197+
}
198+
160199
} else {
161200
LOG(FATAL) << "unknown relation type";
162201
}
@@ -225,6 +264,39 @@ void PassUpIndex(const Stage& stage, const Map<IterVar, Range>& dom_map,
225264
state[s->parent] = value;
226265
}
227266
} else if (rel.as<SingletonNode>()) {
267+
} else if (const TransformNode* s = rel.as<TransformNode>()) {
268+
bool missing_transformed = false;
269+
for (const auto& iter_var : s->transformed_variables) {
270+
if (!state.count(iter_var)) {
271+
// for (const auto& kv : state) {
272+
// std::cout << "Looking for " << tvm::PrettyPrint(iter_var) << std::endl;
273+
// std::cout << "State contains " << tvm::PrettyPrint(kv.first) << " -> "
274+
// << tvm::PrettyPrint(kv.second) << std::endl;
275+
// }
276+
// TODO: Decide whether to have this check, for similarity
277+
// with other handlers. In this case, the indices may
278+
// already be in terms of the pre-transformed variables, so
279+
// no need to untransform them?
280+
281+
// ICHECK(allow_missing);
282+
missing_transformed = true;
283+
}
284+
}
285+
if (missing_transformed) {
286+
continue;
287+
}
288+
289+
Array<PrimExpr> transformed_indices;
290+
for (const auto& iter_var : s->transformed_variables) {
291+
transformed_indices.push_back(state[iter_var]);
292+
}
293+
Array<PrimExpr> original_indices = s->inverse_transformation->MapIndices(transformed_indices);
294+
295+
ICHECK_EQ(original_indices.size(), s->original_variables.size());
296+
for (size_t i = 0; i < original_indices.size(); i++) {
297+
state[s->original_variables[i]] = original_indices[i];
298+
}
299+
228300
} else {
229301
LOG(FATAL) << "unknown relation type";
230302
}
@@ -270,6 +342,28 @@ void PassDownIndex(const Stage& stage, const Map<IterVar, Range>& dom_map,
270342
state[s->rebased] = value;
271343
} else if (const SingletonNode* s = rel.as<SingletonNode>()) {
272344
state[s->iter] = make_zero(s->iter->var.dtype());
345+
} else if (const TransformNode* s = rel.as<TransformNode>()) {
346+
bool missing_originals = false;
347+
for (const auto& iter_var : s->original_variables) {
348+
if (!state.count(iter_var)) {
349+
ICHECK(allow_missing);
350+
missing_originals = true;
351+
}
352+
}
353+
if (missing_originals) {
354+
continue;
355+
}
356+
357+
Array<PrimExpr> original_indices;
358+
for (const auto& iter_var : s->original_variables) {
359+
original_indices.push_back(state[iter_var]);
360+
}
361+
Array<PrimExpr> transformed_indices = s->forward_transformation->MapIndices(original_indices);
362+
363+
ICHECK_EQ(transformed_indices.size(), s->transformed_variables.size());
364+
for (size_t i = 0; i < transformed_indices.size(); i++) {
365+
state[s->transformed_variables[i]] = transformed_indices[i];
366+
}
273367
} else {
274368
LOG(FATAL) << "unknown relation type";
275369
}
@@ -351,6 +445,26 @@ void PassUpDomain(const RebaseNode* s, const std::unordered_map<IterVar, Range>&
351445
*parent = arith::EvalSet(s->rebased->var + parent_min, {{s->rebased, rebased}});
352446
}
353447

448+
/*!
 * \brief Compute the integer sets of a Transform relation's original
 * variables from the integer sets of its transformed variables.
 *
 * Each original variable is expressed in terms of the transformed
 * variables via `s->inverse_transformation`, and that expression is
 * evaluated over `transformed_domains` with arith::EvalSet.
 *
 * \param s The transform relation.
 * \param dom_map Not referenced in this function body.
 * \param transformed_domains Integer set for each transformed variable.
 *
 * \return One IntSet per entry of `s->original_variables`, in the same order.
 */
Array<IntSet> PassUpDomain(const TransformNode* s,
449+
const std::unordered_map<IterVar, Range>& dom_map,
450+
const Map<IterVar, IntSet>& transformed_domains) {
451+
Array<IntSet> output;
452+
453+
// Use the transformed variables themselves as symbolic indices.
Array<PrimExpr> transformed_indices;
454+
for (const auto& iter_var : s->transformed_variables) {
455+
transformed_indices.push_back(iter_var->var);
456+
}
457+
458+
// Despite the name, these are the expressions for the *original*
// variables, written in terms of the transformed variables.
Array<PrimExpr> transformed_exprs = s->inverse_transformation->MapIndices(transformed_indices);
459+
460+
// The inverse map must produce exactly one expression per original variable.
ICHECK_EQ(transformed_exprs.size(), s->original_variables.size());
461+
for (size_t i = 0; i < transformed_exprs.size(); i++) {
462+
output.push_back(arith::EvalSet(transformed_exprs[i], transformed_domains));
463+
}
464+
465+
return output;
466+
}
467+
354468
void PassUpDomain(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
355469
std::unordered_map<IterVar, IntSet>* p_state) {
356470
auto& state = *p_state;
@@ -370,6 +484,16 @@ void PassUpDomain(const Stage& stage, const std::unordered_map<IterVar, Range>&
370484
PassUpDomain(r, dom_map, state.at(r->rebased), &parent);
371485
state[r->parent] = parent;
372486
} else if (rel.as<SingletonNode>()) {
487+
} else if (const TransformNode* r = rel.as<TransformNode>()) {
488+
Map<IterVar, IntSet> transformed_domains;
489+
for (const auto& var : r->transformed_variables) {
490+
transformed_domains.Set(var, state.at(var));
491+
}
492+
auto original_ranges = PassUpDomain(r, dom_map, transformed_domains);
493+
ICHECK_EQ(original_ranges.size(), r->original_variables.size());
494+
for (size_t i = 0; i < original_ranges.size(); i++) {
495+
state[r->original_variables[i]] = original_ranges[i];
496+
}
373497
} else {
374498
LOG(FATAL) << "unknown relation type";
375499
}
@@ -509,6 +633,22 @@ void PassUpBoundCheck(const Stage& s, const Map<IterVar, Range>& dom_map,
509633
state[s->parent] = state.at(s->rebased);
510634
} else if (rel.as<SingletonNode>()) {
511635
// nop
636+
} else if (const TransformNode* s = rel.as<TransformNode>()) {
637+
// Currently, this marks all original iter vars as requiring
638+
// bounds checks if any of the transformed variables require
639+
// bounds checks, even if the inverse expression for that iter
640+
// var doesn't depend on the bound variable.
641+
642+
// TODO(Lunderberg): For each original variable, check
643+
// whether any variable in the inverse expression for it
644+
// requires bounds checking.
645+
bool needs_bounds_check = false;
646+
for (const auto& iter_var : s->transformed_variables) {
647+
needs_bounds_check = needs_bounds_check || state[iter_var];
648+
}
649+
for (const auto& iter_var : s->original_variables) {
650+
state[iter_var] = needs_bounds_check;
651+
}
512652
} else {
513653
LOG(FATAL) << "unknown relation type";
514654
}

0 commit comments

Comments
 (0)