
Commit 9d732d0

[TensorIR][Primitive] New schedule primitive reindex_cache_read/write (#14161)
# Motivation

Currently, we have the schedule primitives `cache_read`/`cache_write`, which allocate cache buffers and create cache stages that copy data from the buffer being accessed into the cache buffer. However, `cache_read`/`cache_write` do not support customized indices. For the following block:

```python
@T.prim_func
def func(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (129, 129))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi + 1, vj + 1] * 2.0
```

after `cache_read("B", 0, "shared")`, we get:

```python
# from tvm.script import tir as T
@T.prim_func
def main(A: T.Buffer((129, 129), "float32"), B: T.Buffer((128, 128), "float32")):
    # with T.block("root"):
    A_shared = T.alloc_buffer((129, 129), scope="shared")
    for ax0, ax1 in T.grid(129, 129):
        with T.block("A_shared"):
            v0, v1 = T.axis.remap("SS", [ax0, ax1])
            T.reads(A[v0, v1])
            T.writes(A_shared[v0, v1])
            A_shared[v0, v1] = A[v0, v1]
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A_shared[vi + 1, vj + 1])
            T.writes(B[vi, vj])
            B[vi, vj] = A_shared[vi + 1, vj + 1] * T.float32(2)
```

where `A_shared` is accessed with the same indices (`vi + 1, vj + 1`) as in the original block. This is not flexible, especially when we want to perform a layout transformation while copying data from the original buffer to the cache buffer (as in MMA tensorization and FlashAttention). This PR proposes a new interface that lets the user customize the indices used to access the cache buffer, which is expressive enough to describe transposing and blocking.

# Proposed API

Below is the proposed interface of `reindex_cache_read` (`reindex_cache_write` has a similar interface):

```python
def reindex_cache_read(
    self,
    block: Union[BlockRV, str],
    read_buffer_index: int,
    storage_scope: str,
    index_map: Union[IndexMap, Callable],
) -> BlockRV:
    ...
```

Here `block`, `read_buffer_index`, and `storage_scope` have the same meaning as in `cache_read`. The additional argument `index_map` specifies which indices to use when accessing the cache buffer, in the form of an index map from the current block itervars to the target indices. For example, if the block has itervars `vi, vj` and the user wants to access the cache buffer with the customized indices `[vi // 16, vj // 16, vi % 16, vj % 16]`, the `index_map` argument should be set to `lambda vi, vj: (vi // 16, vj // 16, vi % 16, vj % 16)` (a complete sketch of this blocked layout is given after this commit message).

# Example

By applying `reindex_cache_read("B", 0, "shared", lambda i, j: (j, i))` to `func`, we get:

```python
@T.prim_func
def main(A: T.Buffer((129, 129), "float32"), B: T.Buffer((128, 128), "float32")):
    # with T.block("root"):
    A_shared = T.alloc_buffer((128, 128), scope="shared")
    for i, j in T.grid(128, 128):
        with T.block("A_shared"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A[vi + 1, vj + 1])
            T.writes(A_shared[vj, vi])
            A_shared[vj, vi] = A[vi + 1, vj + 1]
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A_shared[vj, vi])
            T.writes(B[vi, vj])
            B[vi, vj] = A_shared[vj, vi] * T.float32(2)
```

# Notes

Unlike `cache_read`/`cache_write`, which allow caching a rectangular region, `reindex_cache_read`/`reindex_cache_write` only allow the accessed region to be a single point; this is enough to cover most use cases. The cache stage block follows the original order of loops and block itervars in the block. If a block itervar does not appear in the buffer access region, it and its corresponding loop variables are omitted. The user can then use the `transform_block_layout` primitive to reorder the block itervars and surrounding loops of the cache read/write block.

# Relations to Existing Schedule Primitives

- Relation with `reindex`
  - `reindex` only covers the special case of `reindex_cache_read`/`reindex_cache_write` where `index_map` is the identity map, and `reindex` does not have a `storage_scope` field.
- Relation with `transform_layout`
  - `transform_layout` transforms the layout of an existing (input or intermediate) buffer in place rather than creating a cache stage, and it does not have a `storage_scope` field.
- Relation with `cache_read`/`cache_write`
  - `cache_read`/`cache_write` do not support customized indices.
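To make the blocked-layout case described under "Proposed API" concrete, here is a minimal hedged sketch driven through the Python `Schedule` API. The function name `blocked_func`, the 16x16 tile size, and the `"shared"` scope are illustrative assumptions, not code from this commit:

```python
from tvm import tir
from tvm.script import tir as T


@T.prim_func
def blocked_func(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0


sch = tir.Schedule(blocked_func)
# Cache the read of A (read_buffer_index=0 of block "B") in shared memory,
# addressing the cache buffer as 16x16 tiles.
sch.reindex_cache_read(
    "B", 0, "shared", lambda vi, vj: (vi // 16, vj // 16, vi % 16, vj % 16)
)
print(sch.mod["main"].script())
```

The printed module should contain a cache stage that copies `A` into a tiled shared buffer, which is the kind of layout the motivation above refers to for MMA tensorization.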
1 parent e59d1ef commit 9d732d0

File tree

10 files changed: +1149 −31 lines

include/tvm/tir/schedule/schedule.h

Lines changed: 28 additions & 0 deletions

@@ -405,6 +405,34 @@ class ScheduleNode : public runtime::Object {
   virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
                              const String& storage_scope,
                              const Array<BlockRV> consumer_blocks = {}) = 0;
+  /*!
+   * \brief Create a block that reads a buffer region into a read cache. It requires:
+   * 1) There is at most one block that writes the buffer in the scope.
+   * 2) The scope block has the stage-pipeline property.
+   * Compared to cache_read, the indices used to access the allocated cache buffer are customized by the user.
+   * \param block_rv The consumer block of the target buffer.
+   * \param read_buffer_index The index of the buffer in block's read region.
+   * \param storage_scope The target storage scope.
+   * \param index_map User-defined indices to access the allocated cache buffer, mapped from the
+   * block iter vars.
+   * \return The cache stage block.
+   */
+  virtual BlockRV ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index,
+                                   const String& storage_scope, const IndexMap& index_map) = 0;
+  /*!
+   * \brief Create a block that writes a buffer region into a write cache. It requires:
+   * 1) There is only one block that writes the target buffer.
+   * 2) The scope block has the stage-pipeline property.
+   * Compared to cache_write, the indices used to access the allocated cache buffer are customized by the user.
+   * \param block_rv The producer block of the buffer.
+   * \param write_buffer_index The index of the buffer in block's write region.
+   * \param storage_scope The target storage scope.
+   * \param index_map User-defined indices to access the allocated cache buffer, mapped from the
+   * block iter vars.
+   * \return The cache stage block.
+   */
+  virtual BlockRV ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index,
+                                    const String& storage_scope, const IndexMap& index_map) = 0;
   /*!
    * \brief Create 2 blocks that read&write a buffer region into a read/write cache.
    * It requires the target block both read & write the target buffer.
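These two declarations are exposed through the Python `Schedule` methods shown in the next file. As a quick orientation for the write-side primitive, here is a hedged Python sketch; the `add_one` function and the transposed `"shared"` layout are illustrative assumptions, not taken from this diff:

```python
from tvm import tir
from tvm.script import tir as T


@T.prim_func
def add_one(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] + 1.0


sch = tir.Schedule(add_one)
# Cache the write of B (write_buffer_index=0 of block "B") in shared memory,
# storing the cache buffer transposed relative to B.
cache_block = sch.reindex_cache_write("B", 0, "shared", lambda vi, vj: (vj, vi))
# The returned BlockRV refers to the newly created cache stage, which can be
# scheduled further like any other block.
print(sch.mod["main"].script())
```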

python/tvm/tir/schedule/schedule.py

Lines changed: 193 additions & 2 deletions

@@ -1115,7 +1115,7 @@ def cache_write(
         block: Union[BlockRV, str],
         write_buffer_index: Union[int, str, Buffer],
         storage_scope: str,
-        consumer_blocks=None,
+        consumer_blocks: Optional[List[Union[BlockRV, str]]] = None,
     ) -> BlockRV:
         """Create a block that reads a buffer region into a write cache. It requires:
@@ -1203,6 +1203,197 @@ def after_cache_write(a: T.handle, b: T.handle) -> None:
             self, block, write_buffer_index, storage_scope, consumer_blocks
         )
 
+    @type_checked
+    def reindex_cache_read(
+        self,
+        block: Union[BlockRV, str],
+        read_buffer_index: int,
+        storage_scope: str,
+        index_map: Union[IndexMap, Callable],
+    ) -> BlockRV:
+        """Create a block that reads a buffer region into a read cache using customized
+        indices specified by an index map. The read region of the buffer must be a single point.
+
+        The cache stage block follows the original order of loops and block itervars in the
+        block. If a block itervar does not appear in the buffer access region, it and its
+        corresponding loop variables will be omitted. The user can then use the
+        `transform_block_layout` primitive to reorder the block itervars and surrounding loops
+        of the cache read/write block.
+
+        Unlike `cache_read`, `reindex_cache_read` only supports a single consumer; please use
+        `cache_read` when there are multiple consumers.
+
+        Parameters
+        ----------
+        block : Union[BlockRV, str]
+            The consumer block of the target buffer.
+        read_buffer_index : int
+            The index of the buffer in block's read region.
+        storage_scope : str
+            The target storage scope.
+        index_map : Union[IndexMap, Callable]
+            User-defined indices to access the allocated cache buffer, mapped from the block
+            iter vars.
+
+        Returns
+        -------
+        cached_block : BlockRV
+            The block of the cache stage.
+
+        Examples
+        --------
+        Before reindex_cache_read, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def before_reindex_cache_read(a: T.handle, b: T.handle) -> None:
+                A = T.match_buffer(a, (128, 128))
+                B = T.match_buffer(b, (128, 128))
+                for i, j in T.grid(128, 128):
+                    with T.block("B"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        B[vi, vj] = A[vi, vj] * 2.0
+
+        Create the schedule and do reindex_cache_read:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_reindex_cache_read)
+            block_b = sch.get_block("B")
+            sch.reindex_cache_read(block_b, 0, "local", lambda vi, vj: (vj, vi))
+            print(sch.mod["main"].script())
+
+        After applying reindex_cache_read, the IR becomes:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def after_reindex_cache_read(a: T.handle, b: T.handle) -> None:
+                A = T.match_buffer(a, (128, 128))
+                B = T.match_buffer(b, (128, 128))
+                A_local = T.alloc_buffer((128, 128), scope="local")
+                for i, j in T.grid(128, 128):
+                    with T.block("A_local"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        A_local[vj, vi] = A[vi, vj]
+                for i, j in T.grid(128, 128):
+                    with T.block("B"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        B[vi, vj] = A_local[vj, vi] * 2.0
+
+        See Also
+        --------
+        reindex_cache_write
+        transform_block_layout
+        transform_layout
+        cache_read
+        reindex
+        """
+        # Convert any string block names into Block RVs.
+        block = self._normalize_block_arg(block)
+
+        if callable(index_map):
+            index_map = IndexMap.from_func(index_map)
+        return _ffi_api.ScheduleReindexCacheRead(  # type: ignore # pylint: disable=no-member
+            self, block, read_buffer_index, storage_scope, index_map
+        )
+
+    @type_checked
+    def reindex_cache_write(
+        self,
+        block: Union[BlockRV, str],
+        write_buffer_index: int,
+        storage_scope: str,
+        index_map: Union[Callable, IndexMap],
+    ) -> BlockRV:
+        r"""Create a block that writes a buffer region into a write cache using customized
+        indices specified by an index map. The write region of the buffer must be a single point.
+
+        The cache stage block follows the original order of loops and block itervars in the
+        block. If a block itervar does not appear in the buffer access region, it and its
+        corresponding loop variables will be omitted. The user can then use the
+        `transform_block_layout` primitive to reorder the block itervars and surrounding loops
+        of the cache read/write block.
+
+        Unlike `cache_write`, `reindex_cache_write` only supports a single consumer; please use
+        `cache_write` when there are multiple consumers.
+
+        Parameters
+        ----------
+        block : Union[BlockRV, str]
+            The producer block of the target buffer.
+        write_buffer_index : int
+            The index of the buffer in block's write region.
+        storage_scope : str
+            The target storage scope.
+        index_map : Union[Callable, IndexMap]
+            User-defined indices to access the allocated cache buffer, mapped from the block
+            iter vars.
+
+        Returns
+        -------
+        cached_block : BlockRV
+            The block of the cache stage.
+
+        Examples
+        --------
+        Before reindex_cache_write, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def before_reindex_cache_write(a: T.handle, b: T.handle) -> None:
+                A = T.match_buffer(a, (128, 128))
+                B = T.match_buffer(b, (128, 128))
+                for i, j in T.grid(128, 128):
+                    with T.block("B"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        B[vi, vj] = A[vi, vj] * 2.0
+
+        Create the schedule and do reindex_cache_write:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_reindex_cache_write)
+            block_b = sch.get_block("B")
+            sch.reindex_cache_write(block_b, 0, "local", lambda vi, vj: (vi // 2, vi % 2, vj))
+            print(sch.mod["main"].script())
+
+        After applying reindex_cache_write, the IR becomes:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def after_reindex_cache_write(a: T.handle, b: T.handle) -> None:
+                A = T.match_buffer(a, (128, 128))
+                B = T.match_buffer(b, (128, 128))
+                B_local = T.alloc_buffer((64, 2, 128), scope="local")
+                for i, j in T.grid(128, 128):
+                    with T.block("B_local"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        B_local[vi // 2, vi % 2, vj] = A[vi, vj] * 2.0
+                for i, j in T.grid(128, 128):
+                    with T.block("B"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        B[vi, vj] = B_local[vi // 2, vi % 2, vj]
+
+        See Also
+        --------
+        reindex_cache_read
+        transform_block_layout
+        transform_layout
+        cache_write
+        reindex
+        """
+        # Convert any string block names into Block RVs.
+        block = self._normalize_block_arg(block)
+
+        if callable(index_map):
+            index_map = IndexMap.from_func(index_map)
+        return _ffi_api.ScheduleReindexCacheWrite(  # type: ignore # pylint: disable=no-member
+            self, block, write_buffer_index, storage_scope, index_map
+        )
+
     @type_checked
     def cache_inplace(
         self,
@@ -1439,7 +1630,7 @@ def before_reindex(
                         vi, vj = T.axis.remap("SS", [i, j])
                         B[vi, vj] = A[vj, vi] * 2.0
 
-        Create the schedule and do transform_layout:
+        Create the schedule and do reindex:
 
         .. code-block:: python
 
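The docstrings above note that `transform_block_layout` can reorder the itervars and surrounding loops of the cache stage afterwards. Below is a hedged sketch of that workflow, assuming an illustrative `elemwise` function and index maps chosen for demonstration:

```python
from tvm import tir
from tvm.script import tir as T


@T.prim_func
def elemwise(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0


sch = tir.Schedule(elemwise)
# The cache stage keeps the original itervar order (vi, vj) of block "B",
# even though it stores the data transposed.
cache_block = sch.reindex_cache_read("B", 0, "local", lambda vi, vj: (vj, vi))
# Reorder the itervars (and surrounding loops) of the cache stage itself.
sch.transform_block_layout(cache_block, lambda vi, vj: (vj, vi))
print(sch.mod["main"].script())
```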

src/tir/schedule/concrete_schedule.cc

Lines changed: 24 additions & 0 deletions

@@ -568,6 +568,30 @@ BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buff
   return CreateRV<BlockRV>(result);
 }
 
+BlockRV ConcreteScheduleNode::ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index,
+                                               const String& storage_scope,
+                                               const IndexMap& index_map) {
+  StmtSRef result{nullptr};
+  TVM_TIR_SCHEDULE_BEGIN();
+  result = tir::ReindexCacheRead(state_, this->GetSRef(block_rv), read_buffer_index, storage_scope,
+                                 index_map);
+  TVM_TIR_SCHEDULE_END("reverse-cache-read", this->error_render_level_);
+  this->state_->DebugVerify();
+  return CreateRV<BlockRV>(result);
+}
+
+BlockRV ConcreteScheduleNode::ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index,
+                                                const String& storage_scope,
+                                                const IndexMap& index_map) {
+  StmtSRef result{nullptr};
+  TVM_TIR_SCHEDULE_BEGIN();
+  result = tir::ReindexCacheWrite(state_, this->GetSRef(block_rv), write_buffer_index,
+                                  storage_scope, index_map);
+  TVM_TIR_SCHEDULE_END("reverse-cache-write", this->error_render_level_);
+  this->state_->DebugVerify();
+  return CreateRV<BlockRV>(result);
+}
+
 Array<BlockRV> ConcreteScheduleNode::CacheInplace(const BlockRV& block_rv, int write_buffer_index,
                                                   const String& storage_scope) {
   Array<StmtSRef> results;

src/tir/schedule/concrete_schedule.h

Lines changed: 4 additions & 0 deletions

@@ -116,6 +116,10 @@ class ConcreteScheduleNode : public ScheduleNode {
                      const Array<BlockRV> consumer_blocks = {}) override;
   BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope,
                      const Array<BlockRV> consumer_blocks = {}) override;
+  BlockRV ReindexCacheRead(const BlockRV& block_rv, int read_buffer_index,
+                           const String& storage_scope, const IndexMap& index_map) override;
+  BlockRV ReindexCacheWrite(const BlockRV& block_rv, int write_buffer_index,
+                            const String& storage_scope, const IndexMap& index_map) override;
   Array<BlockRV> CacheInplace(const BlockRV& block_rv, int read_buffer_index,
                               const String& storage_scope) override;
   Array<BlockRV> CacheIndex(const BlockRV& block_rv, const String& storage_scope,

src/tir/schedule/primitive.h

Lines changed: 33 additions & 0 deletions

@@ -269,6 +269,39 @@ TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int r
 TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
                             const String& storage_scope,
                             const Array<StmtSRef> consumer_blocks = {});
+/*!
+ * \brief Create a block that reads a buffer region into a read cache. It requires:
+ * 1) There is at most one block that writes the buffer in the scope.
+ * 2) The scope block has the stage-pipeline property.
+ * Compared to cache_read, the indices used to access the allocated cache buffer are customized by the user.
+ * \param self The state of the schedule.
+ * \param block_sref The consumer block of the target buffer.
+ * \param read_buffer_index The index of the buffer in block's read region.
+ * \param storage_scope The target storage scope.
+ * \param index_map User-defined indices to access the allocated cache buffer, mapped from the
+ * block iter vars.
+ * \return The cache stage block.
+ */
+TVM_DLL StmtSRef ReindexCacheRead(ScheduleState self, const StmtSRef& block_sref,
+                                  int read_buffer_index, const String& storage_scope,
+                                  const IndexMap& index_map);
+/*!
+ * \brief Create a block that writes a buffer region into a write cache. It requires:
+ * 1) There is only one block that writes the target buffer.
+ * 2) The scope block has the stage-pipeline property.
+ * Compared to cache_write, the indices used to access the allocated cache buffer are customized by the user.
+ * \param self The state of the schedule.
+ * \param block_sref The producer block of the buffer.
+ * \param write_buffer_index The index of the buffer in block's write region.
+ * \param storage_scope The target storage scope.
+ * \param index_map User-defined indices to access the allocated cache buffer, mapped from the
+ * block iter vars.
+ * \return The cache stage block.
+ */
+TVM_DLL StmtSRef ReindexCacheWrite(ScheduleState self, const StmtSRef& block_sref,
+                                   int write_buffer_index, const String& storage_scope,
+                                   const IndexMap& index_map);
+
 /*!
  *!
  * \brief Create 2 blocks that read&write a buffer region into a read/write cache.
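For completeness, here is a hedged end-to-end sanity check of the read-side primitive, written in the style of a unit test; the `scale` function, the LLVM target, and the tolerance are illustrative assumptions rather than code from this commit's test files:

```python
import numpy as np
import tvm
from tvm import tir
from tvm.script import tir as T


@T.prim_func
def scale(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0


sch = tir.Schedule(scale)
# Read A through a transposed local cache; the computed result is unchanged.
sch.reindex_cache_read("B", 0, "local", lambda vi, vj: (vj, vi))

func = tvm.build(sch.mod["main"], target="llvm")
a_np = np.random.rand(128, 128).astype("float32")
a_nd = tvm.nd.array(a_np)
b_nd = tvm.nd.array(np.zeros((128, 128), dtype="float32"))
func(a_nd, b_nd)
np.testing.assert_allclose(b_nd.numpy(), a_np * 2.0, rtol=1e-5)
```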
