Skip to content

Commit 6846484

Browse files
masahi, Siyuan Feng, spectrometerHBH, jinhongyii, MasterJH5574
authored
[Metaschedule] Auto tensorization for CPU / GPU dot product (#11088)
* [Metaschedule] Auto-tensorization for CPU / GPU dot product Co-authored-by: Siyuan Feng <[email protected]> Co-authored-by: Bohan Hou <[email protected]> Co-authored-by: Hongyi Jin <[email protected]> Co-authored-by: Ruihang Lai <[email protected]> Co-authored-by: Wuwei Lin <[email protected]> * doc update * add vnni conv2d test * add dp4a test * adding tests for rewrite_tensorize * add rewrite_tensorize test * add missing pydoc * black * more doc * adding auto tensorize integration test * add dp4a test * fix target name * fix dtype in test * skip bert test * replace hard-coded llvm intrinsic id in test with look up * remove unnecessary include, add doc for the rest of params * update postproc.h * update doc * fix shape in te matmul workload * fix newline in cppdoc Co-authored-by: Siyuan Feng <[email protected]> Co-authored-by: Bohan Hou <[email protected]> Co-authored-by: Hongyi Jin <[email protected]> Co-authored-by: Ruihang Lai <[email protected]> Co-authored-by: Wuwei Lin <[email protected]>
1 parent 4dc47df commit 6846484

File tree

15 files changed

+1457
-27
lines changed

15 files changed

+1457
-27
lines changed

include/tvm/meta_schedule/postproc.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,12 @@ class Postproc : public runtime::ObjectRef {
149149
*/
150150
TVM_DLL static Postproc RewriteUnboundBlock(int max_threadblock);
151151
/*!
152-
* \brief Create a postprocessor that tensorize Tensor Core related components
152+
* \brief Create a postprocessor that applies tensorization to annotated blocks
153+
* \param vectorize_init_loop Whether or not vectorize the initialization loop produced by
154+
* DecomposeReduction
153155
* \return The postprocessor created.
154156
*/
155-
TVM_DLL static Postproc RewriteTensorCore();
157+
TVM_DLL static Postproc RewriteTensorize(bool vectorize_init_loop = false);
156158

157159
/*!
158160
* \brief Creates a postprocessor that verifies if the GPU code is correct

include/tvm/meta_schedule/schedule_rule.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,29 @@ class ScheduleRule : public runtime::ObjectRef {
150150
Optional<Array<Integer>> vector_load_lens, //
151151
Optional<Map<String, ObjectRef>> reuse_read, //
152152
Optional<Map<String, ObjectRef>> reuse_write);
153+
154+
/*!
155+
* \brief Extension of MultiLevelTiling for auto-tensorizing with a single intrinsic.
156+
* \param intrin_name The name of a tensor intrinsic, must be registerd via
157+
* TensorIntrin.register(...) beforehand
158+
* \param structure The tiling structure. Recommended:
159+
* - 'SSRSRS' on CPU
160+
* - 'SSSRRSRS' on GPU
161+
* \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended:
162+
* - NullOpt on CPU
163+
* - [blockIdx.x, vthread.x, threadIdx.x] on GPU
164+
* \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
165+
* \param vector_load_lens The length of vector lane in vectorized cooperative fetching.
166+
* NullOpt means disable vectorization
167+
* \param reuse_read Data reuse configuration for reading. NullOpt means no reuse.
168+
* \param reuse_write Data reuse configuration for writing. NullOpt means no reuse.
169+
* \return The schedule rule created
170+
*/
171+
TVM_DLL static ScheduleRule MultiLevelTilingWithIntrin(
172+
String intrin_name, String structure, Optional<Array<String>> tile_binds,
173+
Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
174+
Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write);
175+
153176
/*!
154177
* \brief Create a rule: add-rfactor to some blocks if needed
155178
* \param max_jobs_per_core The maximum number of jobs to be launched per CPU core. It sets the

include/tvm/tir/stmt.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1509,6 +1509,11 @@ constexpr const char* meta_schedule_unroll_explicit = "meta_schedule.unroll_expl
15091509
/*! \brief Mark auto-unroll setting on the block. */
15101510
constexpr const char* meta_schedule_unroll_implicit = "meta_schedule.unroll_implicit";
15111511

1512+
/*!
1513+
* \brief Mark that a block should be further rewritten using tensorization.
1514+
*/
1515+
constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensorize";
1516+
15121517
/*!
15131518
* \brief Check if attr_key is a pragma key extension
15141519
* \param attr_key The attr key to be compared

python/tvm/meta_schedule/postproc/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,4 @@
2222
from .rewrite_reduction_block import RewriteReductionBlock
2323
from .rewrite_unbound_block import RewriteUnboundBlock
2424
from .verify_gpu_code import VerifyGPUCode
25+
from .rewrite_tensorize import RewriteTensorize
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""A postprocessor that tensorize related components."""
18+
19+
from tvm._ffi.registry import register_object
20+
from .. import _ffi_api
21+
from .postproc import Postproc
22+
23+
24+
@register_object("meta_schedule.RewriteTensorize")
class RewriteTensorize(Postproc):
    """A postprocessor that applies tensorization to annotated blocks.

    Parameters
    ----------
    vectorize_init_loop : bool
        Whether or not to vectorize the initialization loop produced by
        DecomposeReduction.
    """

    def __init__(self, vectorize_init_loop: bool = False) -> None:
        # Construct the underlying C++ RewriteTensorize postprocessor via FFI.
        self.__init_handle_by_constructor__(
            _ffi_api.PostprocRewriteTensorize,  # type: ignore # pylint: disable=no-member
            vectorize_init_loop,
        )

python/tvm/meta_schedule/schedule_rule/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from .add_rfactor import AddRFactor
2323
from .auto_inline import AutoInline
2424
from .cross_thread_reduction import CrossThreadReduction
25-
from .multi_level_tiling import MultiLevelTiling, ReuseType
25+
from .multi_level_tiling import MultiLevelTiling, MultiLevelTilingWithIntrin, ReuseType
2626
from .parallel_vectorize_unroll import ParallelizeVectorizeUnroll
2727
from .random_compute_location import RandomComputeLocation
2828
from .schedule_rule import PyScheduleRule, ScheduleRule

python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,52 @@ def __init__(
8282
reuse_read.as_dict() if reuse_read is not None else None,
8383
reuse_write.as_dict() if reuse_write is not None else None,
8484
)
85+
86+
87+
@register_object("meta_schedule.MultiLevelTilingWithIntrin")
class MultiLevelTilingWithIntrin(ScheduleRule):
    """Extension of MultiLevelTiling for auto-tensorizing with a single intrinsic.

    Parameters
    ----------
    intrin_name : str
        The name of a tensor intrinsic, must be registered via
        TensorIntrin.register(...) beforehand
    structure : str
        The tiling structure. Recommended:
        - 'SSRSRS' on CPU
        - 'SSSRRSRS' on GPU
    tile_binds : Optional[List[str]]
        For each level of tiles, which thread axis it is bound to. Recommended:
        - None on CPU
        - [blockIdx.x, vthread.x, threadIdx.x] on GPU
    max_innermost_factor : Optional[int]
        The maximum size of the innermost factor. None means no limit
    vector_load_lens : Optional[List[int]]
        The length of vector lane in vectorized cooperative fetching.
        None means disable vectorization
    reuse_read : Optional[ReuseType]
        Data reuse configuration for reading. None means no reuse.
    reuse_write : Optional[ReuseType]
        Data reuse configuration for writing. None means no reuse.
    """

    def __init__(
        self,
        intrin_name: str,
        structure: str,
        tile_binds: Optional[List[str]] = None,
        max_innermost_factor: Optional[int] = None,
        vector_load_lens: Optional[List[int]] = None,
        reuse_read: Optional[ReuseType] = None,
        reuse_write: Optional[ReuseType] = None,
    ) -> None:
        # The FFI constructor expects plain dicts for the reuse configs.
        self.__init_handle_by_constructor__(
            _ffi_api.ScheduleRuleMultiLevelTilingWithIntrin,  # type: ignore # pylint: disable=no-member
            intrin_name,
            structure,
            tile_binds,
            max_innermost_factor,
            vector_load_lens,
            reuse_read.as_dict() if reuse_read is not None else None,
            reuse_write.as_dict() if reuse_write is not None else None,
        )

python/tvm/meta_schedule/testing/te_workload.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -607,7 +607,7 @@ def f_compute(i, j):
607607

608608
def matmul_relu(n: int, m: int, k: int) -> Tuple[te.Tensor, te.Tensor, te.Tensor]:
609609
a = te.placeholder((n, k), name="A")
610-
b = te.placeholder((m, k), name="B")
610+
b = te.placeholder((k, m), name="B")
611611
k = te.reduce_axis((0, k), name="k")
612612
c = te.compute(
613613
(n, m),
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
#include <tvm/meta_schedule/postproc.h>
20+
21+
#include <algorithm>
22+
23+
#include "../utils.h"
24+
25+
namespace tvm {
26+
namespace meta_schedule {
27+
28+
using tir::BlockRV;
29+
using tir::LoopRV;
30+
31+
/*!
 * \brief Tensorize (or vectorize) each block in a PrimFunc that carries the
 *  `meta_schedule.auto_tensorize` annotation.
 * \param sch The schedule to rewrite.
 * \param func_name The name of the PrimFunc within the schedule's module.
 * \param func The PrimFunc whose body is scanned for annotated blocks.
 * \param vectorize_init_loop Whether to vectorize the initialization loop
 *  produced by DecomposeReduction (blocks whose name contains "init").
 */
void ApplyTensorization(const tir::Schedule& sch, const String& func_name,
                        const tir::PrimFuncNode* func, bool vectorize_init_loop) {
  // Collect (block name, action) pairs first and apply them afterwards:
  // mutating the schedule while PostOrderVisit is walking the body would
  // invalidate the traversal.
  std::vector<std::pair<std::string, std::function<void(tir::BlockRV)>>> jobs;

  tir::PostOrderVisit(func->body, [=, &jobs](const ObjectRef& obj) {
    if (const auto* block = obj.as<tir::BlockNode>()) {
      tir::StmtSRef block_sref = sch->GetSRef(block);
      if (Optional<String> intrin_name =
              tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize)) {
        std::string block_name = block_sref->StmtAs<tir::BlockNode>()->name_hint;
        if (block_name.find("init") == std::string::npos) {
          // Main (non-init) block: tensorize with the annotated intrinsic.
          // Tensorize may legitimately fail (e.g. shape mismatch after
          // tiling), so a failure is logged rather than fatal.
          jobs.emplace_back(block_name, [sch, intrin_name](tir::BlockRV block) {
            try {
              sch->Tensorize(block, intrin_name.value());
            } catch (const std::exception& e) {
              LOG(WARNING) << "Tensorize failed with error " << e.what();
            }
          });
        } else if (vectorize_init_loop) {
          // Init block produced by DecomposeReduction: vectorize its single
          // innermost loop instead of tensorizing.
          jobs.emplace_back(block_name, [sch](tir::BlockRV block) {
            Array<BlockRV> child_blocks = sch->GetChildBlocks(block);
            ICHECK(child_blocks.size() == 1);
            Array<LoopRV> init_loops = sch->GetLoops(child_blocks[0]);
            ICHECK(init_loops.size() == 1);
            sch->Vectorize(init_loops[0]);
          });
        }
      }
    }
  });

  // const reference: avoids copying the std::string + std::function pair on
  // every iteration (the original `for (auto kv : jobs)` copied both).
  for (const auto& kv : jobs) {
    tir::BlockRV block = sch->GetBlock(kv.first, func_name);
    sch->Unannotate(block, tir::attr::meta_schedule_auto_tensorize);
    kv.second(block);
  }
}
68+
69+
/*! \brief A postprocessor that applies tensorization to blocks annotated with
 *  `meta_schedule.auto_tensorize`. */
class RewriteTensorizeNode : public PostprocNode {
 public:
  // This postprocessor needs no state from the tuning context.
  void InitializeWithTuneContext(const TuneContext& context) final {}

  bool Apply(const tir::Schedule& sch) final;

  void VisitAttrs(tvm::AttrVisitor* v) {}

  /*! \brief Whether to vectorize the initialization loop produced by DecomposeReduction. */
  bool vectorize_init_loop = false;

  static constexpr const char* _type_key = "meta_schedule.RewriteTensorize";
  TVM_DECLARE_FINAL_OBJECT_INFO(RewriteTensorizeNode, PostprocNode);
};
82+
83+
bool RewriteTensorizeNode::Apply(const tir::Schedule& sch) {
  // Walk every function in the module and rewrite each PrimFunc; non-PrimFunc
  // entries (e.g. relay functions) are ignored.
  for (const auto& [global_var, base_func] : sch->mod()->functions) {
    const tir::PrimFuncNode* prim_func = base_func.as<tir::PrimFuncNode>();
    if (prim_func != nullptr) {
      ApplyTensorization(sch, global_var->name_hint, prim_func, vectorize_init_loop);
    }
  }
  return true;
}
93+
94+
Postproc Postproc::RewriteTensorize(bool vectorize_init_loop) {
  // Factory: build the node, record the flag, and wrap it in the ref type.
  auto node = make_object<RewriteTensorizeNode>();
  node->vectorize_init_loop = vectorize_init_loop;
  return Postproc(node);
}
99+
100+
// Register the node type and expose the constructor to the FFI so the Python
// `meta_schedule.postproc.RewriteTensorize` wrapper can instantiate it.
TVM_REGISTER_NODE_TYPE(RewriteTensorizeNode);
TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteTensorize")
    .set_body_typed(Postproc::RewriteTensorize);
103+
104+
} // namespace meta_schedule
105+
} // namespace tvm

src/meta_schedule/schedule_rule/multi_level_tiling.cc

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -260,28 +260,9 @@ ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional<Array<Str
260260
Optional<Array<Integer>> vector_load_lens,
261261
Optional<Map<String, ObjectRef>> reuse_read,
262262
Optional<Map<String, ObjectRef>> reuse_write) {
263-
ObjectPtr<MultiLevelTilingNode> n = make_object<MultiLevelTilingNode>();
264-
n->structure = structure;
265-
n->tile_binds = tile_binds.value_or({});
266-
n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value;
267-
n->vector_load_lens = vector_load_lens.defined()
268-
? support::AsVector<Integer, int>(vector_load_lens.value())
269-
: std::vector<int>();
270-
n->reuse_read_ = reuse_read.defined() ? ReuseConfig(reuse_read.value()) : ReuseConfig();
271-
n->reuse_write_ = reuse_write.defined() ? ReuseConfig(reuse_write.value()) : ReuseConfig();
272-
for (int i = 0, len = structure.size(); i < len; ++i) {
273-
char c = structure.data()[i];
274-
if (c == 'S') {
275-
n->s_indices_.push_back(i);
276-
} else if (c == 'R') {
277-
n->r_indices_.push_back(i);
278-
} else {
279-
LOG(FATAL) << "ValueError: Invalid tiling structure: " << structure;
280-
}
281-
}
282-
n->thread_warp_size_ = -1;
283-
n->max_threads_per_block_ = -1;
284-
return ScheduleRule(n);
263+
auto node = MultiLevelTilingInitCommon<MultiLevelTilingNode>(
264+
structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write);
265+
return ScheduleRule(node);
285266
}
286267

287268
TVM_REGISTER_NODE_TYPE(MultiLevelTilingNode);

0 commit comments

Comments
 (0)