Skip to content

Commit 5a1df3c

Browse files
committed
add assume_injective_transform option
1 parent 98484b9 commit 5a1df3c

File tree

10 files changed

+136
-38
lines changed

10 files changed

+136
-38
lines changed

include/tvm/tir/schedule/schedule.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -642,10 +642,17 @@ class ScheduleNode : public runtime::Object {
642642
* Algebraic symplifications, branch elimination, and other
643643
* optimizations may assume that this precondition is met, and
644644
* may result in incorrect results being returned.
645+
*
646+
* \param assume_injective_transform If set to true, the schedule primitive will assume the
647+
* index_map is injective and skip checking overlapping of the mapped indices. This can be useful
648+
* for complicated index_map that the analysis does not cover. It is the callers' responsibility
649+
* to ensure the index map is injective, otherwise, the correctness of the schedule is not
650+
* guaranteed.
645651
*/
646652
virtual void TransformLayout(const BlockRV& block_rv, int buffer_index,
647653
BufferIndexType buffer_index_type, const IndexMap& index_map,
648-
const Optional<IndexMap>& pad_value = NullOpt) = 0;
654+
const Optional<IndexMap>& pad_value = NullOpt,
655+
bool assume_injective_transform = false) = 0;
649656

650657
/*!
651658
* \brief Apply a transformation represented by IndexMap to block

python/tvm/tir/schedule/schedule.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2575,7 +2575,6 @@ def _normalize_buffer_arg(
25752575
buffer: Union[Tuple[str, int], int, str, Buffer],
25762576
required_buffer_type=None,
25772577
) -> Tuple[str, int, Buffer]:
2578-
25792578
block_obj: Block = self.get(block)
25802579
block_name = block_obj.name_hint
25812580

@@ -2645,6 +2644,8 @@ def transform_layout(
26452644
buffer: Union[Tuple[str, int], str, Buffer],
26462645
index_map: Union[IndexMap, Callable],
26472646
pad_value: Optional[Union[int, float, PrimExpr, IndexMap, Callable]] = None,
2647+
*,
2648+
assume_injective_transform=False,
26482649
) -> None:
26492650
"""Apply a transformation represented by IndexMap to buffer
26502651
@@ -2711,6 +2712,13 @@ def transform_layout(
27112712
value to be present in the padding in terms of the
27122713
transformed index.
27132714
2715+
assume_injective_transform : bool
2716+
2717+
If set to true, the schedule primitive will assume the index_map is injective and skip
2718+
checking overlapping of the mapped indices. This can be useful for complicated index_map
2719+
that the analysis does not cover. It is the callers' responsibility to ensure the
2720+
index map is injective, otherwise, the correctness of the schedule is not guaranteed.
2721+
27142722
Examples
27152723
--------
27162724
Before transform_layout, in TensorIR, the IR is:
@@ -2787,7 +2795,13 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) ->
27872795

27882796
buffer_index_type_enum = 0 if buffer_index_type == "read" else 1
27892797
_ffi_api.ScheduleTransformLayout( # type: ignore # pylint: disable=no-member
2790-
self, block, buffer_index, buffer_index_type_enum, index_map, pad_value
2798+
self,
2799+
block,
2800+
buffer_index,
2801+
buffer_index_type_enum,
2802+
index_map,
2803+
pad_value,
2804+
assume_injective_transform,
27912805
)
27922806
if axis_separators:
27932807
_ffi_api.ScheduleSetAxisSeparator( # type: ignore # pylint: disable=no-member

src/tir/schedule/concrete_schedule.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -800,14 +800,15 @@ void ConcreteScheduleNode::Unannotate(const BlockRV& block_rv, const String& ann
800800
void ConcreteScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_index,
801801
BufferIndexType buffer_index_type,
802802
const IndexMap& index_map,
803-
const Optional<IndexMap>& pad_value) {
803+
const Optional<IndexMap>& pad_value,
804+
bool assume_injective_transform) {
804805
TVM_TIR_SCHEDULE_BEGIN();
805806
auto f_subst = [&](const Var& var) -> Optional<PrimExpr> {
806807
return Downcast<Optional<PrimExpr>>(symbol_table_.Get(var));
807808
};
808809
auto new_index_map = Substitute(index_map, f_subst);
809810
tir::TransformLayout(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type,
810-
new_index_map, pad_value);
811+
new_index_map, pad_value, assume_injective_transform);
811812
this->state_->DebugVerify();
812813
TVM_TIR_SCHEDULE_END("transform_layout", this->error_render_level_);
813814
}

src/tir/schedule/concrete_schedule.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,8 @@ class ConcreteScheduleNode : public ScheduleNode {
148148
void Unannotate(const BlockRV& block_rv, const String& ann_key) override;
149149
/******** Schedule: Layout transformation ********/
150150
void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type,
151-
const IndexMap& index_map, const Optional<IndexMap>& pad_value) override;
151+
const IndexMap& index_map, const Optional<IndexMap>& pad_value,
152+
bool assume_injective_transform = false) override;
152153
void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) override;
153154
void SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
154155
BufferIndexType buffer_index_type,

src/tir/schedule/primitive.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -501,10 +501,15 @@ TVM_DLL void Unannotate(ScheduleState self, const StmtSRef& sref, const String&
501501
* \param buffer_index_type The type of the buffer index, kRead or kWrite.
502502
* \param index_map The transformation to apply.
503503
* \param pad_value The value to write into padding introduced by the transformation.
504+
* \param assume_injective_transform If set to true, the schedule primitive will assume the
505+
* index_map is injective and skip checking overlapping of the mapped indices. This can be useful
506+
* for complicated index_map that the analysis does not cover. It is the callers' responsibility
507+
* to ensure the index map is injective, otherwise, the correctness of the schedule is not
508+
* guaranteed.
504509
*/
505510
TVM_DLL void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
506511
BufferIndexType buffer_index_type, const IndexMap& index_map,
507-
const Optional<IndexMap>& pad_value);
512+
const Optional<IndexMap>& pad_value, bool assume_injective_transform);
508513

509514
/*!
510515
* \brief Apply a transformation represented by IndexMap to block

src/tir/schedule/primitive/layout_transformation.cc

Lines changed: 72 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -92,17 +92,11 @@ class TransformLayoutPlanner : private StmtExprVisitor {
9292
std::variant<ProloguePlan, ReplacementPlan, EpiloguePlan, NoPaddingRequired>;
9393

9494
static TransformPlan Plan(Block block, Buffer old_buffer, Buffer new_buffer, IndexMap index_map,
95+
IndexMap inverse, PrimExpr padding_predicate,
9596
Optional<IndexMap> pad_value) {
9697
ICHECK(!pad_value.defined() || pad_value.value()->final_indices.size() == 1)
9798
<< "Internal error: Should be caught by ScheduleError checks prior to this point";
9899
TransformLayoutPlanner visitor(old_buffer);
99-
auto [inverse, padding_predicate] = [&]() {
100-
Array<Range> region;
101-
for (const auto& dim : old_buffer->shape) {
102-
region.push_back(Range::FromMinExtent(make_zero(dim.dtype()), dim));
103-
}
104-
return index_map.NonSurjectiveInverse(region);
105-
}();
106100
visitor(block);
107101
return visitor.Finalize(new_buffer, index_map, inverse, padding_predicate, pad_value);
108102
}
@@ -754,18 +748,17 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {
754748
* \param old_buffer The target buffer before transformation
755749
* \param new_buffer The new buffer after transformation
756750
* \param index_map The transformation applied to the buffer
757-
* \param pad_value The value to be used for padding
758751
* \return The new AST rooting at the original parent scope and the map from the old block to the
759752
* new block
760753
*/
761-
static std::pair<Stmt, Map<Block, Block>> Rewrite(const Block& scope_stmt,
762-
const Buffer& old_buffer,
763-
const Buffer& new_buffer,
764-
const IndexMap& index_map,
765-
const Optional<IndexMap>& pad_value) {
766-
auto plan = pad_value.defined() ? TransformLayoutPlanner::Plan(scope_stmt, old_buffer,
767-
new_buffer, index_map, pad_value)
768-
: TransformLayoutPlanner::NoPaddingRequired{};
754+
static std::pair<Stmt, Map<Block, Block>> Rewrite(
755+
const Block& scope_stmt, const Buffer& old_buffer, const Buffer& new_buffer,
756+
const IndexMap& index_map, const Optional<IndexMap>& opt_inverse,
757+
const PrimExpr& padding_predicate, const Optional<IndexMap>& pad_value) {
758+
auto plan = pad_value.defined() ? TransformLayoutPlanner::Plan(
759+
scope_stmt, old_buffer, new_buffer, index_map,
760+
opt_inverse.value(), padding_predicate, pad_value)
761+
: TransformLayoutPlanner::NoPaddingRequired();
769762

770763
arith::Analyzer analyzer;
771764
TransformLayoutRewriter rewriter(old_buffer, new_buffer, index_map, plan, &analyzer);
@@ -1058,6 +1051,40 @@ class TransformationPaddingExpressionError : public ScheduleError {
10581051
BufferLoad illegal_load_;
10591052
};
10601053

1054+
class TransformationIntroducesPaddingError : public ScheduleError {
1055+
public:
1056+
TransformationIntroducesPaddingError(IRModule mod, Buffer buffer, IndexMap index_map,
1057+
PrimExpr padding_predicate)
1058+
: mod_(std::move(mod)),
1059+
buffer_(std::move(buffer)),
1060+
index_map_(std::move(index_map)),
1061+
padding_predicate_(std::move(padding_predicate)) {}
1062+
1063+
String FastErrorString() const final {
1064+
std::ostringstream ss;
1065+
ss << "ScheduleError: Transformation would introduce padding at " << padding_predicate_ << ".";
1066+
return ss.str();
1067+
}
1068+
1069+
String DetailRenderTemplate() const final {
1070+
auto new_shape = index_map_->MapShape(buffer_->shape);
1071+
std::ostringstream os;
1072+
os << "The transformation " << index_map_ << " applied on buffer " << buffer_->name
1073+
<< " of shape " << buffer_->shape << " would result in shape " << new_shape
1074+
<< ". However, this would introduce padding wherever " << padding_predicate_ << " is true.";
1075+
return os.str();
1076+
}
1077+
1078+
IRModule mod() const final { return mod_; }
1079+
Array<ObjectRef> LocationsOfInterest() const final { return {}; }
1080+
1081+
private:
1082+
IRModule mod_;
1083+
Buffer buffer_;
1084+
IndexMap index_map_;
1085+
PrimExpr padding_predicate_;
1086+
};
1087+
10611088
// Make the dtypes of indices in IndexMap be the same as the dtype of the buffer shape, to avoid
10621089
// dtype-mismatch issues later.
10631090
IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Array<PrimExpr>& args) {
@@ -1094,7 +1121,7 @@ IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Array<PrimExpr>&
10941121

10951122
void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
10961123
BufferIndexType buffer_index_type, const IndexMap& index_map_orig,
1097-
const Optional<IndexMap>& pad_value) {
1124+
const Optional<IndexMap>& pad_value, bool assume_injective_transform) {
10981125
// Step 1: Input handling and error checking
10991126
const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref);
11001127
Buffer old_buffer =
@@ -1122,14 +1149,32 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_
11221149
: GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
11231150
const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref);
11241151

1152+
Optional<IndexMap> opt_inverse = NullOpt;
1153+
PrimExpr padding_predicate = Bool(true);
1154+
if (!assume_injective_transform) {
1155+
std::tie(opt_inverse, padding_predicate) = [&]() {
1156+
Array<Range> region;
1157+
for (const auto& dim : old_buffer->shape) {
1158+
region.push_back(Range::FromMinExtent(make_zero(dim.dtype()), dim));
1159+
}
1160+
return index_map.NonSurjectiveInverse(region);
1161+
}();
1162+
}
1163+
1164+
bool has_padding = !is_zero(padding_predicate);
1165+
if (has_padding && !pad_value.defined()) {
1166+
throw TransformationIntroducesPaddingError(self->mod, old_buffer, index_map, padding_predicate);
1167+
}
1168+
11251169
// Step 2: Infer the shape of the new buffer
11261170
Buffer new_buffer = old_buffer;
11271171
new_buffer.CopyOnWrite()->shape = index_map->MapShape(old_buffer->shape);
11281172

11291173
// Step 3: Rewrite BufferLoad/BufferStore access indices, block read/write regions, and block
11301174
// alloc_buffers.
1131-
auto [new_stmt, block_sref_reuse] = TransformLayoutRewriter::Rewrite(
1132-
GetRef<Block>(scope_block), old_buffer, new_buffer, index_map, pad_value);
1175+
auto [new_stmt, block_sref_reuse] =
1176+
TransformLayoutRewriter::Rewrite(GetRef<Block>(scope_block), old_buffer, new_buffer,
1177+
index_map, opt_inverse, padding_predicate, pad_value);
11331178
Block new_scope_block = Downcast<Block>(new_stmt);
11341179

11351180
// Step 4: Rewrite buffer_map of the PrimFunc if necessary.
@@ -1472,20 +1517,21 @@ struct TransformLayoutTraits : public UnpackedInstTraits<TransformLayoutTraits>
14721517

14731518
private:
14741519
static constexpr size_t kNumInputs = 2;
1475-
static constexpr size_t kNumAttrs = 3;
1520+
static constexpr size_t kNumAttrs = 4;
14761521
static constexpr size_t kNumDecisions = 0;
14771522

14781523
static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, IndexMap index_map,
14791524
Integer buffer_index, Integer buffer_index_type,
1480-
Optional<IndexMap> pad_value) {
1525+
Optional<IndexMap> pad_value,
1526+
Bool assume_injective_transform) {
14811527
return sch->TransformLayout(block_rv, buffer_index.IntValue(),
14821528
static_cast<BufferIndexType>(buffer_index_type->value), index_map,
1483-
pad_value);
1529+
pad_value, assume_injective_transform.operator bool());
14841530
}
14851531

14861532
static String UnpackedAsPython(Array<String> outputs, String block_rv, IndexMap index_map,
14871533
Integer buffer_index, Integer buffer_index_type,
1488-
Optional<IndexMap> pad_value) {
1534+
Optional<IndexMap> pad_value, Bool assume_injective_transform) {
14891535
PythonAPICall py("transform_layout");
14901536
py.Input("block", block_rv);
14911537

@@ -1495,6 +1541,7 @@ struct TransformLayoutTraits : public UnpackedInstTraits<TransformLayoutTraits>
14951541
py.Input("buffer", os.str());
14961542
py.Input("index_map", index_map->ToPythonString());
14971543
py.Input("pad_value", pad_value ? pad_value.value()->ToPythonString() : "None");
1544+
py.Input("assume_injective_transform", assume_injective_transform.operator bool());
14981545

14991546
return py.Str();
15001547
}
@@ -1510,6 +1557,7 @@ struct TransformLayoutTraits : public UnpackedInstTraits<TransformLayoutTraits>
15101557
} else {
15111558
attrs_record.push_back(attrs[2]);
15121559
}
1560+
attrs_record.push_back(attrs[3]);
15131561
return std::move(attrs_record);
15141562
}
15151563

@@ -1523,6 +1571,7 @@ struct TransformLayoutTraits : public UnpackedInstTraits<TransformLayoutTraits>
15231571
} else {
15241572
attrs.push_back(attrs_record[2]);
15251573
}
1574+
attrs.push_back(attrs_record[3]);
15261575
return attrs;
15271576
}
15281577

src/tir/schedule/schedule.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,10 +253,10 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnannotate")
253253
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTransformLayout")
254254
.set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index,
255255
int buffer_index_type, const IndexMap& index_map,
256-
const Optional<IndexMap>& pad_value) {
256+
const Optional<IndexMap>& pad_value, bool assume_injective_transform) {
257257
return self->TransformLayout(block_rv, buffer_index,
258258
static_cast<BufferIndexType>(buffer_index_type), index_map,
259-
pad_value);
259+
pad_value, assume_injective_transform);
260260
});
261261
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTransformBlockLayout")
262262
.set_body_method<Schedule>(&ScheduleNode::TransformBlockLayout);

src/tir/schedule/traced_schedule.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -523,15 +523,18 @@ void TracedScheduleNode::Unannotate(const BlockRV& block_rv, const String& ann_k
523523
void TracedScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_index,
524524
BufferIndexType buffer_index_type,
525525
const IndexMap& index_map,
526-
const Optional<IndexMap>& pad_value) {
526+
const Optional<IndexMap>& pad_value,
527+
bool assume_injective_transform) {
527528
ConcreteScheduleNode::TransformLayout(block_rv, buffer_index, buffer_index_type, index_map,
528529
pad_value);
529530
static const InstructionKind& kind = InstructionKind::Get("TransformLayout");
530531
trace_->Append(
531532
/*inst=*/Instruction(
532533
/*kind=*/kind,
533534
/*inputs=*/{block_rv, index_map},
534-
/*attrs=*/{Integer(buffer_index), Integer(buffer_index_type), pad_value},
535+
/*attrs=*/
536+
{Integer(buffer_index), Integer(buffer_index_type), pad_value,
537+
Bool(assume_injective_transform)},
535538
/*outputs=*/{}));
536539
}
537540

src/tir/schedule/traced_schedule.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,8 @@ class TracedScheduleNode : public ConcreteScheduleNode {
107107
void Unannotate(const BlockRV& block_rv, const String& ann_key) override;
108108
/******** Schedule: Layout transformation ********/
109109
void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type,
110-
const IndexMap& index_map, const Optional<IndexMap>& pad_value) override;
110+
const IndexMap& index_map, const Optional<IndexMap>& pad_value,
111+
bool assume_injective_transform) override;
111112
void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) override;
112113
void SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
113114
BufferIndexType buffer_index_type,

tests/python/unittest/test_tir_schedule_transform_layout.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,8 @@ class BasePaddingCompare(tvm.testing.CompareBeforeAfter):
477477

478478
index_map = tvm.testing.parameter(lambda i: [i // 4, i % 4])
479479

480+
assume_injective_transform = tvm.testing.parameter(False)
481+
480482
@pytest.fixture
481483
def transform(self, pad_value, transformed_buffer, index_map):
482484
def transform(mod):
@@ -565,11 +567,26 @@ def expected():
565567
A[i // 4, i % 4] = 0
566568

567569

568-
class TestImplicitPadding(BasePaddingCompare):
569-
"""When pad_value is None, the buffer can be implicitly padded. The padded region is not
570-
accessed because the original loop extent is not changed.
570+
class TestErrorIfPaddingForbidden(BasePaddingCompare):
571+
"""Unless padding is explicitly enabled, should raise error"""
572+
573+
def before():
574+
A = T.alloc_buffer(14, "int32")
575+
for i in T.serial(14):
576+
with T.block("block"):
577+
vi = T.axis.remap("S", [i])
578+
A[vi] = 0
579+
580+
expected = tvm.tir.schedule.schedule.ScheduleError
581+
582+
583+
class TestImplicitPaddingAssumeInjective(BasePaddingCompare):
584+
"""When pad_value is None and assume_injective_transform is set, the buffer can be implicitly
585+
padded. The padded region is not accessed because the original loop extent is not changed.
571586
"""
572587

588+
assume_injective_transform = tvm.testing.parameter(True)
589+
573590
def before():
574591
A = T.alloc_buffer(14, "int32")
575592
for i in T.serial(14):

0 commit comments

Comments
 (0)