Skip to content
Merged
28 changes: 28 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,34 @@ class ScheduleNode : public runtime::Object {
virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope,
const Array<BlockRV> consumer_blocks = {}) = 0;
/*!
* \brief Create a block that reads a buffer region into a read cache. It requires:
* 1) There is at most one block who writes the buffer in the scope.
* 2) The scope block have stage-pipeline property.
* Compared to cache read, the indices to access allocated cache buffer is customized by user.
* \param block_rv The consumer block of the target buffer.
* \param read_buffer_index The index of the buffer in block's read region.
* \param storage_scope The target storage scope.
* \param index_map User defined indices to access allocated cache buffer, maps from block iter
* vars.
* \return The cache stage block.
*/
virtual BlockRV ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index,
const String& storage_scope, const IndexMap& index_map) = 0;
/*!
* \brief Create a block that writes a buffer region into a write cache. It requires:
* 1) There is only one block who writes the target buffer.
* 2) The scope block have stage-pipeline property.
* Compared to cache write, the indices to access allocated cache buffer is customized by user.
* \param block_rv The producer of the buffer
* \param write_buffer_index The index of the buffer in block's write region
* \param storage_scope The target storage scope
* \param index_map User defined indices to access allocated cache buffer, maps from block iter
* vars.
* \return The cache stage block.
*/
virtual BlockRV ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope, const IndexMap& index_map) = 0;
/*!
* \brief Create 2 blocks that read&write a buffer region into a read/write cache.
* It requires the the target block both read & write the target buffer.
Expand Down
197 changes: 194 additions & 3 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1115,7 +1115,7 @@ def cache_write(
block: Union[BlockRV, str],
write_buffer_index: Union[int, str, Buffer],
storage_scope: str,
consumer_blocks=None,
consumer_blocks: Optional[List[Union[BlockRV, str]]] = None,
) -> BlockRV:
"""Create a block that reads a buffer region into a write cache. It requires:

Expand Down Expand Up @@ -1203,6 +1203,197 @@ def after_cache_write(a: T.handle, b: T.handle) -> None:
self, block, write_buffer_index, storage_scope, consumer_blocks
)

@type_checked
def reindex_cache_read(
self,
block: Union[BlockRV, str],
read_buffer_index: int,
storage_scope: str,
index_map: Union[IndexMap, Callable],
) -> BlockRV:
"""Create a block that reads a buffer region into a read cache using customized
indices specified by index map. The read region of the buffer must be a single point.

The cache stage block follows the original order of loops and block itervars in the block.
If a block itervar does not appear in the buffer access region, it and its corresponding
loop variables will be omitted. User can then use `transform_block_layout` primitive to
reorder the block itervars and surrounding loops of the cache read/write block.

Unlike `cache_read`, `reindex_cache_read` only supports single consumer, please use
`cache_read` when there are multiple consumers.

Parameters
----------
block : BlockRV
The consumer block of the target buffer.
read_buffer_index: int
The index of the buffer in block's read region.
storage_scope: str
The target storage scope.
index_map: Union[IndexMap, Callable]
User defined indices to access allocated cache buffer, maps from block iter vars.

Returns
-------
cached_block : BlockRV
The block of the cache stage

Examples
--------
Before reindex_cache_read, in TensorIR, the IR is:

.. code-block:: python

@T.prim_func
def before_reindex_cache_read(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, (128, 128))
B = T.match_buffer(b, (128, 128))
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and reindex_cache_read:

.. code-block:: python

sch = tir.Schedule(before_cache_read)
block_b = sch.get_block("B")
sch.reindex_cache_read(block_b, 0, "local", lambda vi, vj: (vj, vi))
print(sch.mod["main"].script())

After applying reindex_cache_read, the IR becomes:

.. code-block:: python

@T.prim_func
def after_reindex_cache_read(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, (128, 128))
B = T.match_buffer(b, (128, 128))
A_local = T.alloc_buffer((128, 128), scope="local")
for i, j in T.grid(128, 128):
with T.block("A_local"):
vi, vj = T.axis.remap("SS", [i, j])
A_local[vj, vi] = A[vi, vj]
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A_local[vj, vi] * 2.0

See Also
--------
reindex_cache_write
transform_block_layout
transform_layout
cache_read
reindex
"""
# Convert any string block names into Block RVs.
block = self._normalize_block_arg(block)

if callable(index_map):
index_map = IndexMap.from_func(index_map)
return _ffi_api.ScheduleReindexCacheRead( # type: ignore # pylint: disable=no-member
self, block, read_buffer_index, storage_scope, index_map
)

@type_checked
def reindex_cache_write(
self,
block: Union[BlockRV, str],
write_buffer_index: int,
storage_scope: str,
index_map: Union[Callable, IndexMap],
) -> BlockRV:
r"""Create a block that reads a buffer region into a write cache using customized
indices specified by index map. The write region of the buffer must be a single point.

The cache stage block follows the original order of loops and block itervars in the block.
If a block itervar does not appear in the buffer access region, it and its corresponding
loop variables will be omitted. User can then use `transform_block_layout` primitive to
reorder the block itervars and surrounding loops of the cache read/write block.

Unlike `cache_write`, `reindex_cache_write` only supports single consumer, please use
`cache_write` when there are multiple consumers.

Parameters
----------
block : Union[BlockRV, str]
The consumer block of the target buffer.
write_buffer_index: int
The index of the buffer in block's write region.
storage_scope: str
The target storage scope.
index_map: Union[Callable, IndexMap]
User defined indices to access allocated cache buffer, maps from block iter vars.
consumer_blocks: Optional[List[Union[BlockRV, str]]]
An optional list of consumers that should read directly from the cache.
If not specified, all consumers will read from the original buffer.

Returns
-------
cached_block : BlockRV
The block of the cache stage

Examples
--------
Before reindex_cache_write, in TensorIR, the IR is:

.. code-block:: python

@T.prim_func
def before_reindex_cache_write(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, (128, 128))
B = T.match_buffer(b, (128, 128))
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and reindex_cache_write:

.. code-block:: python

sch = tir.Schedule(before_cache_write)
block_b = sch.get_block("B")
sch.reindex_cache_write(block_b, 0, "local", lambda vi, vj: (vi // 2, vi % 2, vj))
print(sch.mod["main"].script())

After applying reindex_cache_write, the IR becomes:

.. code-block:: python

@T.prim_func
def after_cache_write(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, (128, 128))
B = T.match_buffer(b, (64, 2, 128))
B_local = T.alloc_buffer((128, 128), scope="local")
for i, j in T.grid(128, 128):
with T.block("A_local"):
vi, vj = T.axis.remap("SS", [i, j])
B_local[vi % 2, vi // 2, vj] = A[vi, vj] * 2.0
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = B_local[vi % 2, vi // 2, vj]

See Also
--------
reindex_cache_read
transform_block_layout
transform_layout
cache_write
reindex
"""
# Convert any string block names into Block RVs.
block = self._normalize_block_arg(block)

if callable(index_map):
index_map = IndexMap.from_func(index_map)
return _ffi_api.ScheduleReindexCacheWrite( # type: ignore # pylint: disable=no-member
self, block, write_buffer_index, storage_scope, index_map
)

@type_checked
def cache_inplace(
self,
Expand Down Expand Up @@ -1425,7 +1616,7 @@ def reindex(
Examples
--------

Before transform_layout, in TensorIR, the IR is:
Before reindex, in TensorIR, the IR is:

.. code-block:: python

Expand All @@ -1439,7 +1630,7 @@ def before_reindex(
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vj, vi] * 2.0

Create the schedule and do transform_layout:
Create the schedule and do reindex:

.. code-block:: python

Expand Down
24 changes: 24 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,30 @@ BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buff
return CreateRV<BlockRV>(result);
}

BlockRV ConcreteScheduleNode::ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index,
const String& storage_scope,
const IndexMap& index_map) {
StmtSRef result{nullptr};
TVM_TIR_SCHEDULE_BEGIN();
result = tir::ReindexCacheRead(state_, this->GetSRef(block_rv), read_buffer_index, storage_scope,
index_map);
TVM_TIR_SCHEDULE_END("reverse-cache-read", this->error_render_level_);
this->state_->DebugVerify();
return CreateRV<BlockRV>(result);
}

BlockRV ConcreteScheduleNode::ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope,
const IndexMap& index_map) {
StmtSRef result{nullptr};
TVM_TIR_SCHEDULE_BEGIN();
result = tir::ReindexCacheWrite(state_, this->GetSRef(block_rv), write_buffer_index,
storage_scope, index_map);
TVM_TIR_SCHEDULE_END("reverse-cache-write", this->error_render_level_);
this->state_->DebugVerify();
return CreateRV<BlockRV>(result);
}

Array<BlockRV> ConcreteScheduleNode::CacheInplace(const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope) {
Array<StmtSRef> results;
Expand Down
4 changes: 4 additions & 0 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ class ConcreteScheduleNode : public ScheduleNode {
const Array<BlockRV> consumer_blocks = {}) override;
BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope,
const Array<BlockRV> consumer_blocks = {}) override;
BlockRV ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index,
const String& storage_scope, const IndexMap& index_map) override;
BlockRV ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope, const IndexMap& index_map) override;
Array<BlockRV> CacheInplace(const BlockRV& block_rv, int read_buffer_index,
const String& storage_scope) override;
Array<BlockRV> CacheIndex(const BlockRV& block_rv, const String& storage_scope,
Expand Down
33 changes: 33 additions & 0 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,39 @@ TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int r
TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
const String& storage_scope,
const Array<StmtSRef> consumer_blocks = {});
/*!
* \brief Create a block that reads a buffer region into a read cache. It requires:
* 1) There is at most one block who writes the buffer in the scope.
* 2) The scope block have stage-pipeline property.
* Compared to cache read, the indices to access allocated cache buffer is customized by user.
* \param self The state of the schedule
* \param block_sref The consumer block of the target buffer.
* \param read_buffer_index The index of the buffer in block's read region.
* \param storage_scope The target storage scope.
* \param index_map User defined indices to access allocated cache buffer, maps from block iter
* vars.
* \return The cache stage block.
*/
TVM_DLL StmtSRef ReindexCacheRead(ScheduleState self, const StmtSRef& block_sref,
int read_buffer_index, const String& storage_scope,
const IndexMap& index_map);
/*!
* \brief Create a block that writes a buffer region into a write cache. It requires:
* 1) There is only one block that writes the target buffer.
* 2) The scope block have stage-pipeline property.
* Compared to cache write, the indices to access allocated cache buffer is customized by user.
* \param self The state of the schedule
* \param block_sref The producer of the buffer
* \param write_buffer_index The index of the buffer in block's write region
* \param storage_scope The target storage scope
* \param index_map User defined indices to access allocated cache buffer, maps from block iter
* vars.
* \return The cache stage block.
*/
TVM_DLL StmtSRef ReindexCacheWrite(ScheduleState self, const StmtSRef& block_sref,
int write_buffer_index, const String& storage_scope,
const IndexMap& index_map);

/*!
*!
* \brief Create 2 blocks that read&write a buffer region into a read/write cache.
Expand Down
Loading