Commit 9d732d0
[TensorIR][Primitive] New schedule primitive reindex_cache_read/write (#14161)
# Motivation
Currently, we have the schedule primitives `cache_read`/`cache_write`, which allocate cache buffers and create cache stages that copy data from the buffer being accessed to the cache buffer. However, `cache_read`/`cache_write` do not support customized indices. Consider the following block:
```python
from tvm.script import tir as T


@T.prim_func
def func(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (129, 129))
    B = T.match_buffer(b, (128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi + 1, vj + 1] * 2.0
```
After `cache_read("B", 0, "shared")`, we get:
```python
# from tvm.script import tir as T
@T.prim_func
def main(A: T.Buffer((129, 129), "float32"), B: T.Buffer((128, 128), "float32")):
    # with T.block("root"):
    A_shared = T.alloc_buffer((129, 129), scope="shared")
    for ax0, ax1 in T.grid(129, 129):
        with T.block("A_shared"):
            v0, v1 = T.axis.remap("SS", [ax0, ax1])
            T.reads(A[v0, v1])
            T.writes(A_shared[v0, v1])
            A_shared[v0, v1] = A[v0, v1]
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A_shared[vi + 1, vj + 1])
            T.writes(B[vi, vj])
            B[vi, vj] = A_shared[vi + 1, vj + 1] * T.float32(2)
```
Here `A_shared` is accessed with the same indices (`vi + 1, vj + 1`) as in the original block. This is inflexible, especially when we want to apply a layout transformation while copying data from the original buffer to the cache buffer (for example, in MMA tensorization and in FlashAttention).
This PR proposes a new interface that lets us customize the indices used to access the cache buffer; it is expressive enough to describe transposing and blocking.
# Proposed API
Below is the proposed interface of `reindex_cache_read` (`reindex_cache_write` has a similar interface):
```python
def reindex_cache_read(
    self,
    block: Union[BlockRV, str],
    read_buffer_index: int,
    storage_scope: str,
    index_map: Union[IndexMap, Callable],
) -> BlockRV:
    ...
```
Here `block`, `read_buffer_index`, and `storage_scope` have the same meaning as in `cache_read`. The additional argument `index_map` specifies which indices to use to access the cache buffer, in the form of an index map from the current block itervars to the target indices. Suppose the block has itervars `vi, vj` and the user wants to access the cache buffer with the customized indices `[vi // 16, vj // 16, vi % 16, vj % 16]`; the user should then set `index_map` to `lambda vi, vj: (vi // 16, vj // 16, vi % 16, vj % 16)`.
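To make the blocking map above concrete, here is a small pure-Python sketch that enumerates a 128x128 itervar domain and reports the extent each target index reaches. The helper `mapped_extents` is hypothetical and only illustrates what the lambda computes; it is not how TVM infers the cache buffer shape.

```python
from itertools import product
from math import prod


def mapped_extents(index_map, extents):
    """Brute-force the extent of every target index over the itervar domain.

    Hypothetical helper for illustration only; not part of TVM.
    Returns (shape, injective), where `injective` is True when no two
    itervar points are mapped to the same target indices.
    """
    dims = None
    seen = set()
    for point in product(*(range(e) for e in extents)):
        target = tuple(index_map(*point))
        seen.add(target)
        if dims is None:
            dims = [0] * len(target)
        for d, idx in enumerate(target):
            dims[d] = max(dims[d], idx + 1)
    return tuple(dims), len(seen) == prod(extents)


# The blocking map from the text: 16x16 tiles of a 128x128 domain.
blocked = lambda vi, vj: (vi // 16, vj // 16, vi % 16, vj % 16)
shape, injective = mapped_extents(blocked, (128, 128))
print(shape, injective)
```

Under these assumptions the mapped indices span 8 x 8 tiles of 16 x 16 elements, and every `(vi, vj)` point lands on a distinct cache location.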
# Example
By applying `reindex_cache_read("B", 0, "shared", lambda i, j: (j, i))` to `func`, we get:
```python
@T.prim_func
def main(A: T.Buffer((129, 129), "float32"), B: T.Buffer((128, 128), "float32")):
    # with T.block("root"):
    A_shared = T.alloc_buffer((128, 128), scope="shared")
    for i, j in T.grid(128, 128):
        with T.block("A_shared"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A[vi + 1, vj + 1])
            T.writes(A_shared[vj, vi])
            A_shared[vj, vi] = A[vi + 1, vj + 1]
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(A_shared[vj, vi])
            T.writes(B[vi, vj])
            B[vi, vj] = A_shared[vj, vi] * T.float32(2)
```
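As a sanity check on the transformed program, the following pure-Python sketch emulates both versions with nested lists and confirms they compute the same `B`. It mirrors the loop structure of the TVMScript above but does not use TVM; the function names `run_reference` and `run_with_reindexed_cache` are made up for this illustration.

```python
def run_reference(A, n):
    """Original block: B[vi, vj] = A[vi + 1, vj + 1] * 2.0."""
    return [[A[vi + 1][vj + 1] * 2.0 for vj in range(n)] for vi in range(n)]


def run_with_reindexed_cache(A, n):
    """Emulates the program produced with index_map = lambda i, j: (j, i)."""
    # Cache stage: read the single point A[vi + 1, vj + 1] and store it
    # at the transposed position A_shared[vj, vi].
    A_shared = [[0.0] * n for _ in range(n)]
    for vi in range(n):
        for vj in range(n):
            A_shared[vj][vi] = A[vi + 1][vj + 1]
    # Consumer block: read the cache through the same mapped indices.
    B = [[0.0] * n for _ in range(n)]
    for vi in range(n):
        for vj in range(n):
            B[vi][vj] = A_shared[vj][vi] * 2.0
    return A_shared, B


n = 128
# Arbitrary (129, 129) input with distinct values per element.
A = [[float(r * (n + 1) + c) for c in range(n + 1)] for r in range(n + 1)]
A_shared, B = run_with_reindexed_cache(A, n)
assert B == run_reference(A, n)
```

Note that the cache buffer is (128, 128) rather than (129, 129): only the points actually read by block `B` are copied.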
# Notes
Unlike `cache_read`/`cache_write`, which can cache a rectangular region, `reindex_cache_read`/`reindex_cache_write` only allow caching a single point, but this is enough to cover most use cases.
The cache stage block follows the original order of the loops and block itervars in the block. If a block itervar does not appear in the buffer access region, it and its corresponding loop variables are omitted. Users can then use the `transform_block_layout` primitive to reorder the block itervars and surrounding loops of the cache read/write block.
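The omission rule can be pictured with a tiny pure-Python sketch. It assumes a hypothetical matmul-like block with itervars `vi, vj, vk` that reads `A[vi, vk]`; the helper `cache_stage_itervars` is invented for illustration and is not TVM's implementation.

```python
def cache_stage_itervars(block_itervars, access_vars):
    """Keep, in their original order, only the itervars used by the access.

    block_itervars: list of (name, extent) pairs in block order.
    access_vars: names of the itervars that appear in the buffer access.
    Hypothetical helper for illustration only.
    """
    return [(name, extent) for name, extent in block_itervars if name in access_vars]


# Hypothetical block with three itervars that reads A[vi, vk].
matmul_itervars = [("vi", 128), ("vj", 128), ("vk", 128)]
kept = cache_stage_itervars(matmul_itervars, {"vi", "vk"})
print(kept)
```

In this sketch `vj` is dropped because it does not appear in the access to `A`, while `vi` and `vk` keep their relative order.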
# Relations to Existing Schedule Primitives
- Relation with `reindex`
  - `reindex` only supports the special case of `reindex_cache_read`/`reindex_cache_write` where `index_map` is the identity map, and `reindex` does not have a `storage_scope` field.
- Relation with `transform_layout`
  - `transform_layout` is designed to transform the layout of intermediate buffers rather than input buffers, and does not have a `storage_scope` field.
- Relation with `cache_read`/`cache_write`
  - `cache_read`/`cache_write` do not support customized indices.

1 parent e59d1ef commit 9d732d0 (#14161)
File tree
10 files changed, +1149 -31 lines changed
- include/tvm/tir/schedule
- python/tvm/tir/schedule
- src/tir/schedule
  - primitive
- tests/python/unittest