@@ -1056,6 +1056,79 @@ def after_cache_write(a: T.handle, b: T.handle) -> None:
             self, block, write_buffer_index, storage_scope
         )

+    @type_checked
+    def reindex(self, block: BlockRV, buffer_index: int, buffer_index_type: str) -> BlockRV:
+        """Create a block that reads/writes a buffer region into a read/write cache with
+        reindexing. The layout of the cache is determined by the iterators of the block that
+        reads/writes the buffer. It requires:
+        1) There is only one block that reads/writes the target buffer
+        2) There is only one buffer load/store of this buffer in the block
+
+        Parameters
+        ----------
+        block : BlockRV
+            The block that accesses the target buffer
+        buffer_index : int
+            The index of the buffer in block's read or write region
+        buffer_index_type : str
+            Type of the buffer index, "read" or "write"
+
+        Returns
+        -------
+        reindex_block : BlockRV
+            The block of the reindex stage
+
+        Examples
+        --------
+
+        Before reindex, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def before_reindex(
+                A: T.Buffer[(128, 128), "float32"],
+                B: T.Buffer[(128, 128), "float32"]
+            ) -> None:
+                for i, j in T.grid(128, 128):
+                    with T.block("B"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        B[vi, vj] = A[vj, vi] * 2.0
+
+        Create the schedule and do reindex:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_reindex)
+            block = sch.get_block("B")
+            sch.reindex(block, 0, "read")
+
+        After applying reindex, the IR becomes:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def after_reindex(
+                A: T.Buffer[(128, 128), "float32"],
+                B: T.Buffer[(128, 128), "float32"]
+            ) -> None:
+                A_reindex = T.alloc_buffer((128, 128), "float32")
+                for i, j in T.grid(128, 128):
+                    with T.block("A_reindex"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        A_reindex[vi, vj] = A[vj, vi]
+                for i, j in T.grid(128, 128):
+                    with T.block("B"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        B[vi, vj] = A_reindex[vi, vj] * 2.0
+
+        """
+        assert buffer_index_type in ["read", "write"], "Invalid buffer_index_type"
+        buffer_index_type_enum = 0 if buffer_index_type == "read" else 1
+        return _ffi_api.ScheduleReIndex(  # type: ignore # pylint: disable=no-member
+            self, block, buffer_index, buffer_index_type_enum
+        )
+
     ########## Schedule: Compute location ##########

     @type_checked
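For reference, a minimal end-to-end usage sketch of the new primitive (not part of the diff above). It reuses the before_reindex PrimFunc from the docstring; the inspection calls at the end are only one possible way to examine the result, and the expected block name "A_reindex" is taken from the docstring's after_reindex example.

.. code-block:: python

    # A minimal sketch, assuming a TVM build that includes this patch.
    from tvm import tir
    from tvm.script import tir as T


    @T.prim_func
    def before_reindex(
        A: T.Buffer[(128, 128), "float32"],
        B: T.Buffer[(128, 128), "float32"],
    ) -> None:
        for i, j in T.grid(128, 128):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vj, vi] * 2.0


    sch = tir.Schedule(before_reindex)
    block = sch.get_block("B")
    # Reindex the block's first read buffer (A); the returned BlockRV refers
    # to the newly created cache stage.
    reindex_block = sch.reindex(block, 0, "read")
    # The new stage ("A_reindex" in the docstring example) can be scheduled
    # further like any other block.
    print(sch.get(reindex_block).name_hint)
    print(sch.mod.script())  # inspect the transformed IRModule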