Skip to content

Commit 936b864

Browse files
committed
[TIR] Add schedule primitive TransformBlockLayout
1 parent 4a769c1 commit 936b864

File tree

15 files changed

+632
-26
lines changed

15 files changed

+632
-26
lines changed

include/tvm/tir/schedule/schedule.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,17 @@ class ScheduleNode : public runtime::Object {
545545
virtual void TransformLayout(const BlockRV& block_rv, int buffer_index,
546546
BufferIndexType buffer_index_type, const IndexMap& index_map) = 0;
547547

548+
/*!
549+
* \brief Apply a transformation represented by IndexMap to block
550+
* \details The block iters and the block body are transformed by the given index_map.
551+
* Outer loops corresponding to each new block iter are regenerated.
552+
* The index_map is required to be bijective affine since we need its inverse mapping.
553+
* \param self The state of the schedule
554+
* \param block_sref The block sref that refers to the block to be transformed
555+
* \param affine_index_map The transformation to apply.
556+
*/
557+
virtual void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) = 0;
558+
548559
/*!
549560
* \brief Set the axis separator of a buffer, where the buffer is specified by a block and a read
550561
* or write index

python/tvm/tir/schedule/schedule.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2286,6 +2286,67 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) ->
22862286
self, block, buffer_index, buffer_index_type_enum, axis_separators
22872287
)
22882288

2289+
@type_checked
2290+
def transform_block_layout(
2291+
self,
2292+
block: BlockRV,
2293+
index_map: Union[IndexMap, Callable],
2294+
) -> None:
2295+
"""Apply a transformation represented by IndexMap to block
2296+
2297+
Parameters
2298+
----------
2299+
block_rv : BlockRV
2300+
The block to be transformed
2301+
2302+
index_map : Union[IndexMap, Callable]
2303+
The transformation to apply.
2304+
2305+
Examples
2306+
--------
2307+
2308+
Before transform_block_layout, in TensorIR, the IR is:
2309+
2310+
.. code-block:: python
2311+
2312+
@T.prim_func
2313+
def before_transform_block_layout(
2314+
A: T.Buffer[(16, 16), "float32"],
2315+
B: T.Buffer[(16, 16), "float32"]
2316+
) -> None:
2317+
for i, j in T.grid(16, 16):
2318+
with T.block("B"):
2319+
vi, vj = T.axis.remap("SS", [i, j])
2320+
B[vi, vj] = A[vi, vj] * 2.0
2321+
2322+
Create the schedule and do transform_block_layout:
2323+
2324+
.. code-block:: python
2325+
2326+
sch = tir.Schedule(before_transform_block_layout)
2327+
sch.transform_block_layout(sch.get_block("B"), lambda i, j: (i * 16 + j,))
2328+
print(sch.mod["main"].script())
2329+
2330+
After applying transform_block_layout, the IR becomes:
2331+
2332+
.. code-block:: python
2333+
2334+
@T.prim_func
2335+
def after_transform_block_layout(
2336+
A: T.Buffer[(16, 16), "float32"],
2337+
B: T.Buffer[(16, 16), "float32"]
2338+
) -> None:
2339+
for i in range(256):
2340+
with T.block("B"):
2341+
vi, = T.axis.remap("S", [i])
2342+
B[vi // 16, vi % 16] = A[vi // 16, vi % 16] * 2.0
2343+
"""
2344+
if callable(index_map):
2345+
index_map = IndexMap.from_func(index_map)
2346+
_ffi_api.ScheduleTransformBlockLayout( # type: ignore # pylint: disable=no-member
2347+
self, block, index_map
2348+
)
2349+
22892350
@type_checked
22902351
def set_axis_separator(
22912352
self,

src/tir/schedule/analysis.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,16 @@ bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize,
277277
std::unordered_set<const VarNode*>* data_par_vars,
278278
std::unordered_set<const VarNode*>* reduce_vars);
279279

280+
/******** Loop properties ********/
281+
/*!
282+
* \brief Check the loop starts with zero.
283+
* \param self The schedule state
284+
* \param loop_sref The StmtSRef that points to the loop to be checked
285+
* \param analyzer The arithmetic analyzer
286+
* \throw ScheduleError If the loop doesn't starts with zero.
287+
*/
288+
void CheckLoopStartsWithZero(const ScheduleState& self, const StmtSRef& loop_sref, arith::Analyzer* analyzer);
289+
280290
/******** Block-loop relation ********/
281291

282292
/*!

src/tir/schedule/analysis/analysis.cc

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,33 @@ bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize,
686686
return has_block_vars_of_other_types;
687687
}
688688

689+
/******** Loop properties ********/
690+
691+
void CheckLoopStartsWithZero(const ScheduleState& self, const StmtSRef& loop_sref, arith::Analyzer* analyzer) {
692+
class LoopNotStartWithZeroError : public ScheduleError {
693+
public:
694+
explicit LoopNotStartWithZeroError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {}
695+
696+
String FastErrorString() const final {
697+
return "ScheduleError: The primitive only supports loop starting with 0";
698+
}
699+
700+
String DetailRenderTemplate() const final {
701+
return "The loop {0} does not start with 0, which is not supported";
702+
}
703+
704+
IRModule mod() const final { return mod_; }
705+
Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; }
706+
707+
IRModule mod_;
708+
For loop_;
709+
};
710+
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
711+
if (!analyzer->CanProve(loop->min == 0)) {
712+
throw LoopNotStartWithZeroError(self->mod, GetRef<For>(loop));
713+
}
714+
}
715+
689716
/******** Block-loop relation ********/
690717

691718
Array<StmtSRef> GetChildBlockSRefOnSRefTree(const ScheduleState& self,

src/tir/schedule/concrete_schedule.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,14 @@ void ConcreteScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_i
693693
TVM_TIR_SCHEDULE_END("transform_layout", this->error_render_level_);
694694
}
695695

696+
void ConcreteScheduleNode::TransformBlockLayout(const BlockRV& block_rv,
697+
const IndexMap& index_map) {
698+
TVM_TIR_SCHEDULE_BEGIN();
699+
tir::TransformBlockLayout(state_, this->GetSRef(block_rv), index_map);
700+
this->state_->DebugVerify();
701+
TVM_TIR_SCHEDULE_END("transform_block_layout", this->error_render_level_);
702+
}
703+
696704
void ConcreteScheduleNode::SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
697705
BufferIndexType buffer_index_type,
698706
const Array<IntImm>& axis_separators) {

src/tir/schedule/concrete_schedule.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ class ConcreteScheduleNode : public ScheduleNode {
134134
/******** Schedule: Layout transformation ********/
135135
void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type,
136136
const IndexMap& index_map) override;
137+
void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) override;
137138
void SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
138139
BufferIndexType buffer_index_type,
139140
const Array<IntImm>& axis_separators) override;

src/tir/schedule/primitive.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,18 @@ TVM_DLL void Unannotate(ScheduleState self, const StmtSRef& sref, const String&
442442
TVM_DLL void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
443443
BufferIndexType buffer_index_type, const IndexMap& index_map);
444444

445+
/*!
446+
* \brief Apply a transformation represented by IndexMap to block
447+
* \details The block iters and the block body are transformed by the given index_map.
448+
* Outer loops corresponding to each new block iter are regenerated.
449+
* The index_map is required to be bijective affine since we need its inverse mapping.
450+
* \param self The state of the schedule
451+
* \param block_sref The block sref that refers to the block to be transformed
452+
* \param affine_index_map The transformation to apply.
453+
*/
454+
TVM_DLL void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref,
455+
const IndexMap& index_map);
456+
445457
/******** Schedule: Misc ********/
446458

447459
} // namespace tir

0 commit comments

Comments
 (0)