Skip to content

Commit 2b8c430

Browse files
committed
wip MultiLevelTiling refactor
1 parent 7c21a9f commit 2b8c430

File tree

6 files changed

+568
-399
lines changed

6 files changed

+568
-399
lines changed

include/tvm/meta_schedule/schedule_rule.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,14 @@ class ScheduleRule : public runtime::ObjectRef {
150150
Optional<Array<Integer>> vector_load_lens, //
151151
Optional<Map<String, ObjectRef>> reuse_read, //
152152
Optional<Map<String, ObjectRef>> reuse_write);
153+
154+
TVM_DLL static ScheduleRule MultiLevelTilingVNNI(String structure, //
155+
Optional<Array<String>> tile_binds, //
156+
Optional<Integer> max_innermost_factor, //
157+
Optional<Array<Integer>> vector_load_lens, //
158+
Optional<Map<String, ObjectRef>> reuse_read, //
159+
Optional<Map<String, ObjectRef>> reuse_write);
160+
153161
/*!
154162
* \brief Create a rule: add-rfactor to some blocks if needed
155163
* \param max_jobs_per_core The maximum number of jobs to be launched per CPU core. It sets the

python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,48 @@ def __init__(
8282
reuse_read.as_dict() if reuse_read is not None else None,
8383
reuse_write.as_dict() if reuse_write is not None else None,
8484
)
85+
86+
87+
@register_object("meta_schedule.MultiLevelTilingVNNI")
88+
class MultiLevelTilingVNNI(ScheduleRule):
89+
"""Multi-level tiling with reuse (VNNI variant).
90+
91+
Parameters
92+
----------
93+
structure : str
94+
The tiling structure. Recommended:
95+
- 'SSRSRS' on CPU
96+
- 'SSSRRSRS' on GPU
97+
tile_binds : Optional[List[str]]
98+
For each level of tiles, which thread axis it is bound to. Recommended:
99+
- None on CPU
100+
- [blockIdx.x, vthread.x, threadIdx.x] on GPU
101+
max_innermost_factor : Optional[int]
102+
The maximum size of the innermost factor. None means no limit
103+
vector_load_lens : Optional[List[int]]
104+
The lengths of vector lanes in vectorized cooperative fetching.
105+
None means disable vectorization
106+
reuse_read : Optional[ReuseType]
107+
Data reuse configuration for reading. None means no reuse.
108+
reuse_write : Optional[ReuseType]
109+
Data reuse configuration for writing. None means no reuse.
110+
"""
111+
112+
def __init__(
113+
self,
114+
structure: str,
115+
tile_binds: Optional[List[str]] = None,
116+
max_innermost_factor: Optional[int] = None,
117+
vector_load_lens: Optional[List[int]] = None,
118+
reuse_read: Optional[ReuseType] = None,
119+
reuse_write: Optional[ReuseType] = None,
120+
) -> None:
121+
self.__init_handle_by_constructor__(
122+
_ffi_api.ScheduleRuleMultiLevelTilingVNNI, # type: ignore # pylint: disable=no-member
123+
structure,
124+
tile_binds,
125+
max_innermost_factor,
126+
vector_load_lens,
127+
reuse_read.as_dict() if reuse_read is not None else None,
128+
reuse_write.as_dict() if reuse_write is not None else None,
129+
)

python/tvm/meta_schedule/tune.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def _sch_rules() -> List[ScheduleRule]:
185185
disallow_op=["tir.exp"],
186186
),
187187
M.AddRFactor(max_jobs_per_core=16, max_innermost_factor=64),
188-
M.MultiLevelTiling(
188+
M.MultiLevelTilingVNNI(
189189
structure="SSRSRS",
190190
tile_binds=None,
191191
max_innermost_factor=64,

0 commit comments

Comments
 (0)