Skip to content

Commit 6d48a37

Browse files
committed
[TIR] Add schedule primitive ReIndex
1 parent 2252f95 commit 6d48a37

File tree

12 files changed

+833
-0
lines changed

12 files changed

+833
-0
lines changed

include/tvm/tir/schedule/schedule.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,19 @@ class ScheduleNode : public runtime::Object {
364364
*/
365365
virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
366366
const String& storage_scope) = 0;
367+
/*!
368+
* \brief Create a block that read/write a buffer region into a read/write cache with reindexing.
369+
* The layout of the cache will be the same as by the iterators of the block that reads/writes the
370+
* buffer. It requires:
371+
* 1) There is only one block who reads/writes the target buffer
372+
* 2) There is only one buffer load/store of this buffer in the block
373+
* \param block_rv The block operates on the target buffer.
374+
* \param buffer_index The index of the buffer in block's read or write region.
375+
* \param buffer_index_type The type of the buffer index, kRead or kWrite.
376+
* \return The reindex stage block.
377+
*/
378+
virtual BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
379+
BufferIndexType buffer_index_type) = 0;
367380
/******** Schedule: Compute location ********/
368381
/*!
369382
* \brief Move a producer block under the specific loop, and regenerate the

python/tvm/tir/schedule/schedule.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1056,6 +1056,79 @@ def after_cache_write(a: T.handle, b: T.handle) -> None:
10561056
self, block, write_buffer_index, storage_scope
10571057
)
10581058

1059+
@type_checked
1060+
def reindex(self, block: BlockRV, buffer_index: int, buffer_index_type: str) -> BlockRV:
1061+
"""Create a block that read/write a buffer region into a read/write cache with reindexing.
1062+
The layout of the cache will be the same as by the iterators of the block that reads/writes
1063+
the buffer. It requires:
1064+
1) There is only one block who reads/writes the target buffer
1065+
2) There is only one buffer load/store of this buffer in the block
1066+
1067+
Parameters
1068+
----------
1069+
block: BlockRV
1070+
The block that accesses the target buffer
1071+
buffer_index: int
1072+
The index of the buffer in block's read or write region
1073+
buffer_index_type : str
1074+
Type of the buffer index, "read" or "write"
1075+
1076+
Returns
1077+
-------
1078+
reindex_block : BlockRV
1079+
The block of the reindex stage
1080+
1081+
Examples
1082+
--------
1083+
1084+
Before transform_layout, in TensorIR, the IR is:
1085+
1086+
.. code-block:: python
1087+
1088+
@T.prim_func
1089+
def before_reindex(
1090+
A: T.Buffer[(128, 128), "float32"],
1091+
B: T.Buffer[(128, 128), "float32"]
1092+
) -> None:
1093+
for i, j in T.grid(128, 128):
1094+
with T.block("B"):
1095+
vi, vj = T.axis.remap("SS", [i, j])
1096+
B[vi, vj] = A[vj, vi] * 2.0
1097+
1098+
Create the schedule and do transform_layout:
1099+
1100+
.. code-block:: python
1101+
1102+
sch = tir.Schedule(before_reindex)
1103+
block = sch.get_block("B")
1104+
sch.reindex(block, 0, "read)
1105+
1106+
After applying reindex, the IR becomes:
1107+
1108+
.. code-block:: python
1109+
1110+
@T.prim_func
1111+
def after_reindex(
1112+
A: T.Buffer[(128, 128), "float32"],
1113+
B: T.Buffer[(128, 128), "float32"]
1114+
) -> None:
1115+
A_reindex = T.alloc_buffer((128, 128), "float32")
1116+
for i, j in T.grid(128, 128):
1117+
with T.block("A_reindex"):
1118+
vi, vj = T.axis.remap("SS", [i, j])
1119+
A_reindex[vi, vj] = A[vj, vi]
1120+
for i, j in T.grid(128, 128):
1121+
with T.block("B"):
1122+
vi, vj = T.axis.remap("SS", [i, j])
1123+
B[vi, vj] = A_reindex[vi, vj] * 2.0
1124+
1125+
"""
1126+
assert buffer_index_type in ["read", "write"], "Invalid buffer_index_type"
1127+
buffer_index_type_enum = 0 if buffer_index_type == "read" else 1
1128+
return _ffi_api.ScheduleReIndex( # type: ignore # pylint: disable=no-member
1129+
self, block, buffer_index, buffer_index_type_enum
1130+
)
1131+
10591132
########## Schedule: Compute location ##########
10601133

10611134
@type_checked

src/tir/schedule/concrete_schedule.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,16 @@ BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buff
511511
return CreateRV<BlockRV>(result);
512512
}
513513

514+
BlockRV ConcreteScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index,
515+
BufferIndexType buffer_index_type) {
516+
StmtSRef result{nullptr};
517+
TVM_TIR_SCHEDULE_BEGIN();
518+
result = tir::ReIndex(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type);
519+
TVM_TIR_SCHEDULE_END("reindex", this->error_render_level_);
520+
this->state_->DebugVerify();
521+
return CreateRV<BlockRV>(result);
522+
}
523+
514524
/******** Schedule: Compute location ********/
515525

516526
void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,

src/tir/schedule/concrete_schedule.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@ class ConcreteScheduleNode : public ScheduleNode {
109109
const String& storage_scope) override;
110110
BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
111111
const String& storage_scope) override;
112+
BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
113+
BufferIndexType buffer_index_type) override;
112114
/******** Schedule: Compute location ********/
113115
void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) override;
114116
void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,

src/tir/schedule/primitive.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,21 @@ TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int r
253253
*/
254254
TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
255255
const String& storage_scope);
256+
/*!
257+
*!
258+
* \brief Create a block that read/write a buffer region into a read/write cache with reindexing.
259+
* The layout of the cache will be the same as by the iterators of the block that reads/writes the
260+
* buffer. It requires:
261+
* 1) There is only one block who reads/writes the target buffer
262+
* 2) There is only one buffer load/store of this buffer in the block
263+
* \param self The state of the schedule
264+
* \param block_rv The block operates on the target buffer.
265+
* \param buffer_index The index of the buffer in block's read or write region.
266+
* \param buffer_index_type The type of the buffer index, kRead or kWrite.
267+
* \return The reindex stage block.
268+
*/
269+
TVM_DLL StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
270+
BufferIndexType buffer_index_type);
256271
/******** Schedule: Compute location ********/
257272
/*!
258273
* \brief Move a producer block under the specific loop, and regenerate the

0 commit comments

Comments
 (0)