Skip to content

Commit df429c5

Browse files
vinx13Siyuan Feng
andauthored
[TIR] Allow TransformLayout with non-inversible index map (#14095)
* [TIR] Allow TransformLayout with non-inversible index map TransformLayout requires the index map to have an inverse map that can be calculated by the analyzer in order to check whether padding is added. However, this check doesn't work in all cases because of the limitations of the affine analysis, which can only handle a set of supported patterns. In some cases, even if the index map doesn't introduce padding, the schedule primitive throws `TransformationIntroducesPaddingError` because it fails to calculate the inverse index map. It is safe to allow the buffer to be padded without providing pad_value because the original loop extent is not changed and the padded region is not accessed. This PR changes the behavior of `TransformLayout` to allow a non-invertible index map. Previous discussion: https://discuss.tvm.apache.org/t/conflict-free-shared-memory-permutation-in-tensorir/13959/9 * add assume_injective_transform option * Apply suggestions from code review Co-authored-by: Siyuan Feng <[email protected]> --------- Co-authored-by: Siyuan Feng <[email protected]>
1 parent bc92a3f commit df429c5

File tree

10 files changed

+104
-32
lines changed

10 files changed

+104
-32
lines changed

include/tvm/tir/schedule/schedule.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -642,10 +642,17 @@ class ScheduleNode : public runtime::Object {
642642
* Algebraic simplifications, branch elimination, and other
643643
* optimizations may assume that this precondition is met, and
644644
* may result in incorrect results being returned.
645+
*
646+
* \param assume_injective_transform If set to true, the schedule primitive will assume the
647+
* index_map is injective and skip checking overlapping of the mapped indices. This can be useful
648+
* for complicated index_map that the analysis does not cover. It is the caller's responsibility
649+
* to ensure the index map is injective, otherwise, the correctness of the schedule is not
650+
* guaranteed.
645651
*/
646652
virtual void TransformLayout(const BlockRV& block_rv, int buffer_index,
647653
BufferIndexType buffer_index_type, const IndexMap& index_map,
648-
const Optional<IndexMap>& pad_value = NullOpt) = 0;
654+
const Optional<IndexMap>& pad_value = NullOpt,
655+
bool assume_injective_transform = false) = 0;
649656

650657
/*!
651658
* \brief Apply a transformation represented by IndexMap to block

python/tvm/tir/schedule/schedule.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2575,7 +2575,6 @@ def _normalize_buffer_arg(
25752575
buffer: Union[Tuple[str, int], int, str, Buffer],
25762576
required_buffer_type=None,
25772577
) -> Tuple[str, int, Buffer]:
2578-
25792578
block_obj: Block = self.get(block)
25802579
block_name = block_obj.name_hint
25812580

@@ -2645,6 +2644,8 @@ def transform_layout(
26452644
buffer: Union[Tuple[str, int], str, Buffer],
26462645
index_map: Union[IndexMap, Callable],
26472646
pad_value: Optional[Union[int, float, PrimExpr, IndexMap, Callable]] = None,
2647+
*,
2648+
assume_injective_transform: bool = False,
26482649
) -> None:
26492650
"""Apply a transformation represented by IndexMap to buffer
26502651
@@ -2711,6 +2712,13 @@ def transform_layout(
27112712
value to be present in the padding in terms of the
27122713
transformed index.
27132714
2715+
assume_injective_transform : bool
2716+
2717+
If set to true, the schedule primitive will assume the index_map is injective and skip
2718+
checking overlapping of the mapped indices. This can be useful for complicated index_map
2719+
that the analysis does not cover. It is the caller's responsibility to ensure the
2720+
index map is injective, otherwise, the correctness of the schedule is not guaranteed.
2721+
27142722
Examples
27152723
--------
27162724
Before transform_layout, in TensorIR, the IR is:
@@ -2787,7 +2795,13 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) ->
27872795

27882796
buffer_index_type_enum = 0 if buffer_index_type == "read" else 1
27892797
_ffi_api.ScheduleTransformLayout( # type: ignore # pylint: disable=no-member
2790-
self, block, buffer_index, buffer_index_type_enum, index_map, pad_value
2798+
self,
2799+
block,
2800+
buffer_index,
2801+
buffer_index_type_enum,
2802+
index_map,
2803+
pad_value,
2804+
assume_injective_transform,
27912805
)
27922806
if axis_separators:
27932807
_ffi_api.ScheduleSetAxisSeparator( # type: ignore # pylint: disable=no-member

src/tir/schedule/concrete_schedule.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -800,14 +800,15 @@ void ConcreteScheduleNode::Unannotate(const BlockRV& block_rv, const String& ann
800800
void ConcreteScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_index,
801801
BufferIndexType buffer_index_type,
802802
const IndexMap& index_map,
803-
const Optional<IndexMap>& pad_value) {
803+
const Optional<IndexMap>& pad_value,
804+
bool assume_injective_transform) {
804805
TVM_TIR_SCHEDULE_BEGIN();
805806
auto f_subst = [&](const Var& var) -> Optional<PrimExpr> {
806807
return Downcast<Optional<PrimExpr>>(symbol_table_.Get(var));
807808
};
808809
auto new_index_map = Substitute(index_map, f_subst);
809810
tir::TransformLayout(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type,
810-
new_index_map, pad_value);
811+
new_index_map, pad_value, assume_injective_transform);
811812
this->state_->DebugVerify();
812813
TVM_TIR_SCHEDULE_END("transform_layout", this->error_render_level_);
813814
}

src/tir/schedule/concrete_schedule.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,8 @@ class ConcreteScheduleNode : public ScheduleNode {
148148
void Unannotate(const BlockRV& block_rv, const String& ann_key) override;
149149
/******** Schedule: Layout transformation ********/
150150
void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type,
151-
const IndexMap& index_map, const Optional<IndexMap>& pad_value) override;
151+
const IndexMap& index_map, const Optional<IndexMap>& pad_value,
152+
bool assume_injective_transform = false) override;
152153
void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) override;
153154
void SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
154155
BufferIndexType buffer_index_type,

src/tir/schedule/primitive.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -501,10 +501,15 @@ TVM_DLL void Unannotate(ScheduleState self, const StmtSRef& sref, const String&
501501
* \param buffer_index_type The type of the buffer index, kRead or kWrite.
502502
* \param index_map The transformation to apply.
503503
* \param pad_value The value to write into padding introduced by the transformation.
504+
* \param assume_injective_transform If set to true, the schedule primitive will assume the
505+
* index_map is injective and skip checking overlapping of the mapped indices. This can be useful
506+
* for complicated index_map that the analysis does not cover. It is the caller's responsibility
507+
* to ensure the index map is injective, otherwise, the correctness of the schedule is not
508+
* guaranteed.
504509
*/
505510
TVM_DLL void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
506511
BufferIndexType buffer_index_type, const IndexMap& index_map,
507-
const Optional<IndexMap>& pad_value);
512+
const Optional<IndexMap>& pad_value, bool assume_injective_transform);
508513

509514
/*!
510515
* \brief Apply a transformation represented by IndexMap to block

src/tir/schedule/primitive/layout_transformation.cc

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -753,10 +753,12 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {
753753
*/
754754
static std::pair<Stmt, Map<Block, Block>> Rewrite(
755755
const Block& scope_stmt, const Buffer& old_buffer, const Buffer& new_buffer,
756-
const IndexMap& index_map, const IndexMap& inverse, const PrimExpr& padding_predicate,
757-
const Optional<IndexMap>& pad_value) {
758-
auto plan = TransformLayoutPlanner::Plan(scope_stmt, old_buffer, new_buffer, index_map, inverse,
759-
padding_predicate, pad_value);
756+
const IndexMap& index_map, const Optional<IndexMap>& opt_inverse,
757+
const PrimExpr& padding_predicate, const Optional<IndexMap>& pad_value) {
758+
auto plan = pad_value.defined() ? TransformLayoutPlanner::Plan(
759+
scope_stmt, old_buffer, new_buffer, index_map,
760+
opt_inverse.value(), padding_predicate, pad_value)
761+
: TransformLayoutPlanner::NoPaddingRequired();
760762

761763
arith::Analyzer analyzer;
762764
TransformLayoutRewriter rewriter(old_buffer, new_buffer, index_map, plan, &analyzer);
@@ -1119,7 +1121,7 @@ IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Array<PrimExpr>&
11191121

11201122
void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
11211123
BufferIndexType buffer_index_type, const IndexMap& index_map_orig,
1122-
const Optional<IndexMap>& pad_value) {
1124+
const Optional<IndexMap>& pad_value, bool assume_injective_transform) {
11231125
// Step 1: Input handling and error checking
11241126
const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref);
11251127
Buffer old_buffer =
@@ -1147,13 +1149,17 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_
11471149
: GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
11481150
const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref);
11491151

1150-
auto [inverse, padding_predicate] = [&]() {
1151-
Array<Range> region;
1152-
for (const auto& dim : old_buffer->shape) {
1153-
region.push_back(Range::FromMinExtent(make_zero(dim.dtype()), dim));
1154-
}
1155-
return index_map.NonSurjectiveInverse(region);
1156-
}();
1152+
Optional<IndexMap> opt_inverse = NullOpt;
1153+
PrimExpr padding_predicate = Bool(false);
1154+
if (!assume_injective_transform) {
1155+
std::tie(opt_inverse, padding_predicate) = [&]() {
1156+
Array<Range> region;
1157+
for (const auto& dim : old_buffer->shape) {
1158+
region.push_back(Range::FromMinExtent(make_zero(dim.dtype()), dim));
1159+
}
1160+
return index_map.NonSurjectiveInverse(region);
1161+
}();
1162+
}
11571163

11581164
bool has_padding = !is_zero(padding_predicate);
11591165
if (has_padding && !pad_value.defined()) {
@@ -1168,7 +1174,7 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_
11681174
// alloc_buffers.
11691175
auto [new_stmt, block_sref_reuse] =
11701176
TransformLayoutRewriter::Rewrite(GetRef<Block>(scope_block), old_buffer, new_buffer,
1171-
index_map, inverse, padding_predicate, pad_value);
1177+
index_map, opt_inverse, padding_predicate, pad_value);
11721178
Block new_scope_block = Downcast<Block>(new_stmt);
11731179

11741180
// Step 4: Rewrite buffer_map of the PrimFunc if necessary.
@@ -1511,20 +1517,21 @@ struct TransformLayoutTraits : public UnpackedInstTraits<TransformLayoutTraits>
15111517

15121518
private:
15131519
static constexpr size_t kNumInputs = 2;
1514-
static constexpr size_t kNumAttrs = 3;
1520+
static constexpr size_t kNumAttrs = 4;
15151521
static constexpr size_t kNumDecisions = 0;
15161522

15171523
static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, IndexMap index_map,
15181524
Integer buffer_index, Integer buffer_index_type,
1519-
Optional<IndexMap> pad_value) {
1525+
Optional<IndexMap> pad_value,
1526+
Bool assume_injective_transform) {
15201527
return sch->TransformLayout(block_rv, buffer_index.IntValue(),
15211528
static_cast<BufferIndexType>(buffer_index_type->value), index_map,
1522-
pad_value);
1529+
pad_value, assume_injective_transform.operator bool());
15231530
}
15241531

15251532
static String UnpackedAsPython(Array<String> outputs, String block_rv, IndexMap index_map,
15261533
Integer buffer_index, Integer buffer_index_type,
1527-
Optional<IndexMap> pad_value) {
1534+
Optional<IndexMap> pad_value, Bool assume_injective_transform) {
15281535
PythonAPICall py("transform_layout");
15291536
py.Input("block", block_rv);
15301537

@@ -1534,6 +1541,7 @@ struct TransformLayoutTraits : public UnpackedInstTraits<TransformLayoutTraits>
15341541
py.Input("buffer", os.str());
15351542
py.Input("index_map", index_map->ToPythonString());
15361543
py.Input("pad_value", pad_value ? pad_value.value()->ToPythonString() : "None");
1544+
py.Input("assume_injective_transform", assume_injective_transform.operator bool());
15371545

15381546
return py.Str();
15391547
}
@@ -1549,6 +1557,7 @@ struct TransformLayoutTraits : public UnpackedInstTraits<TransformLayoutTraits>
15491557
} else {
15501558
attrs_record.push_back(attrs[2]);
15511559
}
1560+
attrs_record.push_back(attrs[3]);
15521561
return std::move(attrs_record);
15531562
}
15541563

@@ -1562,6 +1571,7 @@ struct TransformLayoutTraits : public UnpackedInstTraits<TransformLayoutTraits>
15621571
} else {
15631572
attrs.push_back(attrs_record[2]);
15641573
}
1574+
attrs.push_back(attrs_record[3]);
15651575
return attrs;
15661576
}
15671577

src/tir/schedule/schedule.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,10 +253,10 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnannotate")
253253
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTransformLayout")
254254
.set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index,
255255
int buffer_index_type, const IndexMap& index_map,
256-
const Optional<IndexMap>& pad_value) {
256+
const Optional<IndexMap>& pad_value, bool assume_injective_transform) {
257257
return self->TransformLayout(block_rv, buffer_index,
258258
static_cast<BufferIndexType>(buffer_index_type), index_map,
259-
pad_value);
259+
pad_value, assume_injective_transform);
260260
});
261261
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTransformBlockLayout")
262262
.set_body_method<Schedule>(&ScheduleNode::TransformBlockLayout);

src/tir/schedule/traced_schedule.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -523,15 +523,18 @@ void TracedScheduleNode::Unannotate(const BlockRV& block_rv, const String& ann_k
523523
void TracedScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_index,
524524
BufferIndexType buffer_index_type,
525525
const IndexMap& index_map,
526-
const Optional<IndexMap>& pad_value) {
526+
const Optional<IndexMap>& pad_value,
527+
bool assume_injective_transform) {
527528
ConcreteScheduleNode::TransformLayout(block_rv, buffer_index, buffer_index_type, index_map,
528-
pad_value);
529+
pad_value, assume_injective_transform);
529530
static const InstructionKind& kind = InstructionKind::Get("TransformLayout");
530531
trace_->Append(
531532
/*inst=*/Instruction(
532533
/*kind=*/kind,
533534
/*inputs=*/{block_rv, index_map},
534-
/*attrs=*/{Integer(buffer_index), Integer(buffer_index_type), pad_value},
535+
/*attrs=*/
536+
{Integer(buffer_index), Integer(buffer_index_type), pad_value,
537+
Bool(assume_injective_transform)},
535538
/*outputs=*/{}));
536539
}
537540

src/tir/schedule/traced_schedule.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,8 @@ class TracedScheduleNode : public ConcreteScheduleNode {
107107
void Unannotate(const BlockRV& block_rv, const String& ann_key) override;
108108
/******** Schedule: Layout transformation ********/
109109
void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type,
110-
const IndexMap& index_map, const Optional<IndexMap>& pad_value) override;
110+
const IndexMap& index_map, const Optional<IndexMap>& pad_value,
111+
bool assume_injective_transform) override;
111112
void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) override;
112113
void SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
113114
BufferIndexType buffer_index_type,

tests/python/unittest/test_tir_schedule_transform_layout.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -477,11 +477,19 @@ class BasePaddingCompare(tvm.testing.CompareBeforeAfter):
477477

478478
index_map = tvm.testing.parameter(lambda i: [i // 4, i % 4])
479479

480+
assume_injective_transform = tvm.testing.parameter(False)
481+
480482
@pytest.fixture
481-
def transform(self, pad_value, transformed_buffer, index_map):
483+
def transform(self, pad_value, transformed_buffer, index_map, assume_injective_transform):
482484
def transform(mod):
483485
sch = tir.Schedule(mod)
484-
sch.transform_layout("block", transformed_buffer, index_map, pad_value=pad_value)
486+
sch.transform_layout(
487+
"block",
488+
transformed_buffer,
489+
index_map,
490+
pad_value=pad_value,
491+
assume_injective_transform=assume_injective_transform,
492+
)
485493
return sch.mod
486494

487495
return transform
@@ -578,6 +586,28 @@ def before():
578586
expected = tvm.tir.schedule.schedule.ScheduleError
579587

580588

589+
class TestImplicitPaddingAssumeInjective(BasePaddingCompare):
590+
"""When pad_value is None and assume_injective_transform is set, the buffer can be implicitly
591+
padded. The padded region is not accessed because the original loop extent is not changed.
592+
"""
593+
594+
assume_injective_transform = tvm.testing.parameter(True)
595+
596+
def before():
597+
A = T.alloc_buffer(14, "int32")
598+
for i in T.serial(14):
599+
with T.block("block"):
600+
vi = T.axis.remap("S", [i])
601+
A[vi] = 0
602+
603+
def expected():
604+
A = T.alloc_buffer([4, 4], "int32")
605+
for i in T.serial(14):
606+
with T.block("block"):
607+
vi = T.axis.remap("S", [i])
608+
A[vi // 4, vi % 4] = 0
609+
610+
581611
class TestErrorOnWrongPaddingType(BasePaddingCompare):
582612
"""The padding must have the same dtype as the buffer"""
583613

0 commit comments

Comments
 (0)