Commit ffbe491
[MetaSchedule][M4a] Schedule Rule: Multi-Level-Tiling (#10043)
* multi level tiling
* remove tensor core related code
* pylint
* fix

Co-authored-by: Junru Shao <[email protected]>
1 parent 92cd754 commit ffbe491

10 files changed: +898 −4 lines changed

include/tvm/meta_schedule/schedule_rule.h

Lines changed: 2 additions & 4 deletions
@@ -137,19 +137,17 @@ class ScheduleRule : public runtime::ObjectRef {
    * \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended:
    * - NullOpt on CPU
    * - [blockIdx.x, vthread.x, threadIdx.x] on GPU
-   * \param use_tensor_core Whether to apply tensor core wmma intrinsic for the computation
    * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
-   * \param vector_load_max_len The length of vector lane in vectorized cooperative fetching.
+   * \param vector_load_lens The lengths of vector lanes in vectorized cooperative fetching.
    *        NullOpt means disable vectorization
    * \param reuse_read Data reuse configuration for reading. NullOpt means no reuse.
    * \param reuse_write Data reuse configuration for writing. NullOpt means no reuse.
    * \return The schedule rule created
    */
   TVM_DLL static ScheduleRule MultiLevelTiling(String structure,                              //
                                                Optional<Array<String>> tile_binds,            //
-                                               bool use_tensor_core,                          //
                                                Optional<Integer> max_innermost_factor,        //
-                                               Optional<Integer> vector_load_max_len,         //
+                                               Optional<Array<Integer>> vector_load_lens,     //
                                                Optional<Map<String, ObjectRef>> reuse_read,   //
                                                Optional<Map<String, ObjectRef>> reuse_write);
   /*!
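
For a concrete view of the interface change, here is a hedged Python sketch (not part of the diff; argument values are illustrative) using the binding this commit adds later. The scalar cap vector_load_max_len is replaced by an explicit list of candidate vector lengths, and in the structure string 'S' denotes a spatial tile level while 'R' denotes a reduction tile level.

from tvm.meta_schedule.schedule_rule import MultiLevelTiling

# Illustrative values only: mirrors the new C++ signature via the Python class.
rule = MultiLevelTiling(
    structure="SSSRRSRS",  # S = spatial tile level, R = reduction tile level
    tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"],
    max_innermost_factor=64,        # cap on the innermost tile size
    vector_load_lens=[1, 2, 3, 4],  # replaces the single vector_load_max_len cap
)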

include/tvm/tir/stmt.h

Lines changed: 14 additions & 0 deletions
@@ -1364,6 +1364,20 @@ constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint";
 /*! \brief Mark the tiling structure of blocks that are applied by rule Multi-Level-Tiling */
 constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_structure";
 
+/*!
+ * \brief Mark that the loop should be further split and bound to environment threads to enable
+ * cooperative fetching.
+ */
+constexpr const char* meta_schedule_cooperative_fetch = "meta_schedule.cooperative_fetch";
+
+/*! \brief The allowed lower bound (inclusive) of thread extent in thread bindings */
+constexpr const char* meta_schedule_thread_extent_low_inclusive =
+    "meta_schedule.thread_extent_low_inclusive";
+
+/*! \brief The allowed upper bound (inclusive) of thread extent in thread bindings */
+constexpr const char* meta_schedule_thread_extent_high_inclusive =
+    "meta_schedule.thread_extent_high_inclusive";
+
 /*! \brief Mark the block whose producer needs to be applied by rule Random-Compute-Location */
 constexpr const char* meta_schedule_random_compute_producer =
     "meta_schedule.random_compute_producer";

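For intuition about the new annotation keys, here is a minimal sketch of attaching one to a loop. It is not part of this commit; the toy workload is invented, and it assumes mainline TVM's te.create_prim_func and tir.Schedule.annotate primitives are available.

import tvm
from tvm import te, tir

# Toy elementwise workload, assumed purely for illustration.
A = te.placeholder((128, 128), name="A")
B = te.compute((128, 128), lambda i, j: A[i, j] * 2.0, name="B")

sch = tir.Schedule(te.create_prim_func([A, B]))
i, _ = sch.get_loops(sch.get_block("B"))
# Mark the loop so a later pass can split it further and bind it to
# environment threads for cooperative fetching.
sch.annotate(i, ann_key="meta_schedule.cooperative_fetch", ann_val=1)
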
python/tvm/meta_schedule/schedule_rule/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@
 from .add_rfactor import AddRFactor
 from .auto_inline import AutoInline
 from .cross_thread_reduction import CrossThreadReduction
+from .multi_level_tiling import MultiLevelTiling, ReuseType
 from .parallel_vectorize_unroll import ParallelizeVectorizeUnroll
 from .random_compute_location import RandomComputeLocation
 from .schedule_rule import PyScheduleRule, ScheduleRule
python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py

Lines changed: 84 additions & 0 deletions

@@ -0,0 +1,84 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Multi-level tiling with reuse."""
+from typing import Any, Dict, List, NamedTuple, Optional
+
+from tvm._ffi import register_object
+
+from .. import _ffi_api
+from .schedule_rule import ScheduleRule
+
+
+class ReuseType(NamedTuple):
+    """Reuse type."""
+
+    req: str
+    levels: List[int]
+    scope: str
+
+    def as_dict(self) -> Dict[str, Any]:
+        """Return the dict representation of the reuse type."""
+        return {
+            "req": self.req,
+            "levels": self.levels,
+            "scope": self.scope,
+        }
+
+
+@register_object("meta_schedule.MultiLevelTiling")
+class MultiLevelTiling(ScheduleRule):
+    """Multi-level tiling with reuse.
+
+    Parameters
+    ----------
+    structure : str
+        The tiling structure. Recommended:
+        - 'SSRSRS' on CPU
+        - 'SSSRRSRS' on GPU
+    tile_binds : Optional[List[str]]
+        For each level of tiles, which thread axis it is bound to. Recommended:
+        - None on CPU
+        - [blockIdx.x, vthread.x, threadIdx.x] on GPU
+    max_innermost_factor : Optional[int]
+        The maximum size of the innermost factor. None means no limit
+    vector_load_lens : Optional[List[int]]
+        The lengths of vector lanes in vectorized cooperative fetching.
+        None means disable vectorization
+    reuse_read : Optional[ReuseType]
+        Data reuse configuration for reading. None means no reuse.
+    reuse_write : Optional[ReuseType]
+        Data reuse configuration for writing. None means no reuse.
+    """
+
+    def __init__(
+        self,
+        structure: str,
+        tile_binds: Optional[List[str]] = None,
+        max_innermost_factor: Optional[int] = None,
+        vector_load_lens: Optional[List[int]] = None,
+        reuse_read: Optional[ReuseType] = None,
+        reuse_write: Optional[ReuseType] = None,
+    ) -> None:
+        self.__init_handle_by_constructor__(
+            _ffi_api.ScheduleRuleMultiLevelTiling,  # type: ignore # pylint: disable=no-member
+            structure,
+            tile_binds,
+            max_innermost_factor,
+            vector_load_lens,
+            reuse_read.as_dict() if reuse_read is not None else None,
+            reuse_write.as_dict() if reuse_write is not None else None,
+        )
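
A quick usage note, not in the diff itself: since ReuseType is a plain NamedTuple, its dict form can be checked directly; the values below match the CUDA defaults added further down in this commit.

from tvm.meta_schedule.schedule_rule import ReuseType

reuse = ReuseType(req="must", levels=[4], scope="shared")
# as_dict() yields the mapping the C++ constructor consumes.
assert reuse.as_dict() == {"req": "must", "levels": [4], "scope": "shared"}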

python/tvm/meta_schedule/testing/schedule_rule.py

Lines changed: 37 additions & 0 deletions
@@ -19,8 +19,10 @@
     AddRFactor,
     AutoInline,
     CrossThreadReduction,
+    MultiLevelTiling,
     ParallelizeVectorizeUnroll,
     RandomComputeLocation,
+    ReuseType,
     ScheduleRule,
 )
 from tvm.target import Target
@@ -65,6 +67,41 @@ def cross_thread_reduction(target: Target) -> ScheduleRule:
     raise NotImplementedError(f"{target.kind.name} is not supported")
 
 
+def multi_level_tiling(target: Target) -> ScheduleRule:
+    """Default schedule rules for multi-level tiling with reuse"""
+    if target.kind.name == "llvm":
+        return MultiLevelTiling(
+            structure="SSRSRS",
+            tile_binds=None,
+            max_innermost_factor=64,
+            vector_load_lens=None,
+            reuse_read=None,
+            reuse_write=ReuseType(
+                req="may",
+                levels=[1, 2],
+                scope="global",
+            ),
+        )
+    if target.kind.name == "cuda":
+        return MultiLevelTiling(
+            structure="SSSRRSRS",
+            tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"],
+            max_innermost_factor=64,
+            vector_load_lens=[1, 2, 3, 4],
+            reuse_read=ReuseType(
+                req="must",
+                levels=[4],
+                scope="shared",
+            ),
+            reuse_write=ReuseType(
+                req="must",
+                levels=[3],
+                scope="local",
+            ),
+        )
+    raise NotImplementedError(f"{target.kind.name} is not supported")
+
+
 def random_compute_location(target: Target) -> ScheduleRule:
     """Default schedule rules for random-compute-location"""
     if target.kind.name == "llvm":
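
As a usage sketch (not in the diff), the helper dispatches on the target kind, so fetching a default rule is a single call with a tvm.target.Target:

from tvm.meta_schedule.testing.schedule_rule import multi_level_tiling
from tvm.target import Target

cuda_rule = multi_level_tiling(Target("cuda"))  # SSSRRSRS, shared/local reuse
cpu_rule = multi_level_tiling(Target("llvm"))   # SSRSRS, global write reuse
# Any other target kind raises NotImplementedError, per the function above.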
