Skip to content

Commit d2ee4ec

Browse files
authored
Add DisallowAsyncStridedMemCopy post processor to rem (#13720)
* [MetaScheduler] Add DisallowAsyncStridedMemCopy post processor to remove schedules that use async strided mem copies. * [MetaScheduler] Add test for DisallowAsyncStridedMemCopy
1 parent b2da945 commit d2ee4ec

File tree

6 files changed

+351
-4
lines changed

6 files changed

+351
-4
lines changed

include/tvm/meta_schedule/postproc.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,12 @@ class Postproc : public runtime::ObjectRef {
109109
* \return The postprocessor created
110110
*/
111111
TVM_DLL static Postproc DisallowDynamicLoop();
112+
/*!
113+
* \brief Create a postprocessor that checks if all async mem copies are not strided.
114+
* \param merge_async_commit_queue_scope Whether or not to merge async commit queue scope.
115+
* \return The postprocessor created
116+
*/
117+
TVM_DLL static Postproc DisallowAsyncStridedMemCopy(bool merge_async_commit_queue_scope = true);
112118
/*!
113119
* \brief Create a postprocessor that rewrites the cooperative fetch annotation to
114120
* actual vectorized cooperative fetching in loop bindings.

python/tvm/meta_schedule/postproc/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# under the License.
1717
"""The tvm.meta_schedule.postproc package."""
1818
from .disallow_dynamic_loop import DisallowDynamicLoop
19+
from .disallow_async_strided_mem_copy import DisallowAsyncStridedMemCopy
1920
from .postproc import Postproc, PyPostproc
2021
from .rewrite_cooperative_fetch import RewriteCooperativeFetch
2122
from .rewrite_layout import RewriteLayout
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""A postprocessor that checks if the IRModule has any strided memory copies"""
18+
19+
from tvm._ffi.registry import register_object
20+
from .. import _ffi_api
21+
from .postproc import Postproc
22+
23+
24+
@register_object("meta_schedule.DisallowAsyncStridedMemCopy")
25+
class DisallowAsyncStridedMemCopy(Postproc):
26+
"""A postprocessor that disallows schedules that use async strided mem copies.
27+
28+
Parameters
29+
----------
30+
merge_async_commit_queue_scope : bool
31+
Whether or not to merge the async commit queue scope.
32+
"""
33+
34+
def __init__(self, merge_async_commit_queue_scope=True) -> None:
35+
self.__init_handle_by_constructor__(
36+
_ffi_api.PostprocDisallowAsyncStridedMemCopy, # type: ignore # pylint: disable=no-member
37+
merge_async_commit_queue_scope,
38+
)
Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
#include "../utils.h"
20+
21+
namespace tvm {
22+
namespace tir {
23+
24+
/*! \brief Check if an IRModule has any async strided mem copies. */
25+
struct AsyncStridedMemCopyFinder : private StmtExprVisitor {
26+
public:
27+
static bool Find(const IRModule& mod) {
28+
AsyncStridedMemCopyFinder finder;
29+
for (const auto& kv : mod->functions) {
30+
if (const auto* prim_func = kv.second.as<PrimFuncNode>()) {
31+
finder(prim_func->body);
32+
if (finder.found_) {
33+
return true;
34+
}
35+
}
36+
}
37+
return false;
38+
}
39+
40+
private:
41+
void VisitStmt_(const ForNode* loop) final {
42+
if (!found_) {
43+
input_iters.Set(loop->loop_var, Range(loop->min, loop->extent));
44+
StmtExprVisitor::VisitStmt_(loop);
45+
}
46+
}
47+
48+
void VisitStmt_(const AttrStmtNode* attrStmt) final {
49+
if (!found_) {
50+
if (attrStmt->attr_key == tir::attr::async_commit_queue_scope) {
51+
auto async_scope = attrStmt->body.as<AttrStmtNode>();
52+
if (!async_scope) {
53+
StmtExprVisitor::VisitStmt_(attrStmt);
54+
}
55+
56+
auto for_loop = async_scope->body.as<ForNode>();
57+
if (!for_loop) {
58+
StmtExprVisitor::VisitStmt_(attrStmt);
59+
}
60+
61+
input_iters.Set(for_loop->loop_var, Range(for_loop->min, for_loop->extent));
62+
63+
auto bufferstorenode = for_loop->body.as<BufferStoreNode>();
64+
if (!bufferstorenode) {
65+
StmtExprVisitor::VisitStmt_(attrStmt);
66+
}
67+
68+
auto bufferloadnode = bufferstorenode->value.as<BufferLoadNode>();
69+
if (!bufferloadnode) {
70+
StmtExprVisitor::VisitStmt_(attrStmt);
71+
}
72+
73+
// get store buffer; assert it exists and is contiguous given it uses a single index
74+
auto bufferstore = bufferstorenode->buffer.as<BufferNode>();
75+
76+
// get load buffer; assert it exists and is contiguous given it uses a single index
77+
auto bufferload = bufferloadnode->buffer.as<BufferNode>();
78+
79+
if (!bufferstore || !bufferload) {
80+
StmtExprVisitor::VisitStmt_(attrStmt);
81+
}
82+
83+
// map loop variable to zero for the store index & simplify
84+
Array<PrimExpr> store_index = bufferstorenode->indices;
85+
86+
// Use DetectIterMap to detect whether store index is non-contiguous.
87+
arith::Analyzer analyzer;
88+
auto store_iter_map = DetectIterMap(store_index, input_iters, 1,
89+
arith::IterMapLevel::Surjective, &analyzer, false);
90+
if (!store_iter_map->errors.empty()) {
91+
found_ = true;
92+
}
93+
94+
// map loop variable to zero for the load index & simplify
95+
Array<PrimExpr> load_index = bufferloadnode->indices;
96+
97+
// Use DetectIterMap to detect whether load index is non-contiguous.
98+
auto load_iter_map = DetectIterMap(load_index, input_iters, 1,
99+
arith::IterMapLevel::Surjective, &analyzer, false);
100+
if (!load_iter_map->errors.empty()) {
101+
found_ = true;
102+
}
103+
}
104+
if (!found_) {
105+
StmtExprVisitor::VisitStmt_(attrStmt);
106+
}
107+
}
108+
}
109+
110+
bool found_ = false;
111+
Map<Var, Range> input_iters = Map<Var, Range>();
112+
};
113+
114+
} // namespace tir
115+
116+
namespace meta_schedule {
117+
118+
/*! \brief Check if the IRModule has any loop with non-constant extent. */
119+
class DisallowAsyncStridedMemCopyNode : public PostprocNode {
120+
public:
121+
// Inherited from PostprocNode
122+
void InitializeWithTuneContext(const TuneContext& context) final {}
123+
// Inherited from PostprocNode
124+
bool Apply(const tir::Schedule& sch) final {
125+
IRModule mod = sch->mod();
126+
for (const auto& kv : mod->functions) {
127+
const GlobalVar& g_var = kv.first;
128+
const BaseFunc& base_func = kv.second;
129+
if (const auto* prim_func = base_func.as<tir::PrimFuncNode>()) {
130+
IRModule lowered{nullptr};
131+
try {
132+
auto pass_list = Array<tvm::transform::Pass>();
133+
pass_list.push_back(tir::transform::LowerInitBlock());
134+
pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
135+
pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
136+
pass_list.push_back(tir::transform::CompactBufferAllocation());
137+
pass_list.push_back(tir::transform::LowerMatchBuffer());
138+
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
139+
pass_list.push_back(tir::transform::LowerOpaqueBlock());
140+
pass_list.push_back(tir::transform::FlattenBuffer());
141+
pass_list.push_back(tir::transform::BF16Legalize());
142+
pass_list.push_back(tir::transform::NarrowDataType(32));
143+
pass_list.push_back(tir::transform::Simplify());
144+
pass_list.push_back(tir::transform::InjectVirtualThread());
145+
pass_list.push_back(tir::transform::InjectDoubleBuffer());
146+
pass_list.push_back(tir::transform::VectorizeLoop(true));
147+
pass_list.push_back(tir::transform::StorageRewrite());
148+
transform::PassContext pass_ctx = transform::PassContext::Current();
149+
pass_ctx->config.Set("tir.merge_async_commit_queue_scope",
150+
Bool(merge_async_commit_queue_scope));
151+
tir::PrimFunc f = WithAttr(GetRef<tir::PrimFunc>(prim_func), "global_symbol",
152+
runtime::String(g_var->name_hint));
153+
IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(g_var->name_hint), f}}));
154+
lowered = tvm::transform::Sequential(pass_list)(std::move(mod));
155+
} catch (const dmlc::Error& e) {
156+
return false;
157+
}
158+
if (tir::AsyncStridedMemCopyFinder::Find(lowered)) {
159+
return false;
160+
}
161+
}
162+
}
163+
return true;
164+
}
165+
// Inherited from PostprocNode
166+
Postproc Clone() const {
167+
ObjectPtr<DisallowAsyncStridedMemCopyNode> n =
168+
make_object<DisallowAsyncStridedMemCopyNode>(*this);
169+
return Postproc(n);
170+
}
171+
172+
bool merge_async_commit_queue_scope = true;
173+
174+
static constexpr const char* _type_key = "meta_schedule.DisallowAsyncStridedMemCopy";
175+
TVM_DECLARE_FINAL_OBJECT_INFO(DisallowAsyncStridedMemCopyNode, PostprocNode);
176+
};
177+
178+
Postproc Postproc::DisallowAsyncStridedMemCopy(bool merge_async_commit_queue_scope) {
179+
ObjectPtr<DisallowAsyncStridedMemCopyNode> n = make_object<DisallowAsyncStridedMemCopyNode>();
180+
n->merge_async_commit_queue_scope = merge_async_commit_queue_scope;
181+
return Postproc(n);
182+
}
183+
184+
TVM_REGISTER_NODE_TYPE(DisallowAsyncStridedMemCopyNode);
185+
TVM_REGISTER_GLOBAL("meta_schedule.PostprocDisallowAsyncStridedMemCopy")
186+
.set_body_typed(Postproc::DisallowAsyncStridedMemCopy);
187+
188+
} // namespace meta_schedule
189+
} // namespace tvm

src/tir/transforms/lower_async_dma.cc

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,9 @@ class AsyncDMALowerer : public StmtExprMutator {
119119
return StmtExprMutator::VisitStmt_(op);
120120
}
121121

122+
// Add the current loop to the input iters mapping.
123+
input_iters.Set(for_loop->loop_var, Range(for_loop->min, for_loop->extent));
124+
122125
// 3) for loop contains buffer store with single index
123126
auto bufferstorenode = for_loop->body.as<BufferStoreNode>();
124127
if (!bufferstorenode || bufferstorenode->indices.size() != 1) {
@@ -156,8 +159,8 @@ class AsyncDMALowerer : public StmtExprMutator {
156159

157160
// Use DetectIterMap to detect whether store index is non-contiguous.
158161
arith::Analyzer analyzer;
159-
auto store_iter_map = DetectIterMap(store_index, input_iters, 1, arith::IterMapLevel::NoCheck,
160-
&analyzer, false);
162+
auto store_iter_map = DetectIterMap(store_index, input_iters, 1,
163+
arith::IterMapLevel::Surjective, &analyzer, false);
161164
if (!store_iter_map->errors.empty()) {
162165
LOG(FATAL)
163166
<< "Unable to lower async dma for non contiguous memory access with store index: "
@@ -173,8 +176,8 @@ class AsyncDMALowerer : public StmtExprMutator {
173176
Array<PrimExpr> load_index = bufferloadnode->indices;
174177

175178
// Use DetectIterMap to detect whether load index is non-contiguous.
176-
auto load_iter_map =
177-
DetectIterMap(load_index, input_iters, 1, arith::IterMapLevel::NoCheck, &analyzer, false);
179+
auto load_iter_map = DetectIterMap(load_index, input_iters, 1,
180+
arith::IterMapLevel::Surjective, &analyzer, false);
178181
if (!load_iter_map->errors.empty()) {
179182
LOG(FATAL) << "Unable to lower async dma for non contiguous memory access with load index: "
180183
<< load_index;
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
18+
19+
import tvm
20+
from tvm import meta_schedule as ms
21+
from tvm import tir
22+
from tvm.script import tir as T
23+
from tvm.target import Target
24+
25+
26+
def _target() -> Target:
27+
return Target("hexagon", host="llvm")
28+
29+
30+
def _create_context(mod, target) -> ms.TuneContext:
31+
ctx = ms.TuneContext(
32+
mod=mod,
33+
target=target,
34+
space_generator=ms.space_generator.PostOrderApply(
35+
sch_rules=[],
36+
postprocs=[
37+
ms.postproc.DisallowAsyncStridedMemCopy(),
38+
],
39+
mutator_probs={},
40+
),
41+
task_name="test",
42+
)
43+
return ctx
44+
45+
46+
# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
47+
# fmt: off
48+
49+
@tvm.script.ir_module
50+
class Matmul:
51+
@T.prim_func
52+
def main(a: T.handle, b: T.handle, c: T.handle) -> None:
53+
T.func_attr({"global_symbol": "main"})
54+
A = T.match_buffer(a, (1024, 1024), "float32")
55+
B = T.match_buffer(b, (1024, 1024), "float32")
56+
C = T.match_buffer(c, (1024, 1024), "float32")
57+
for i, j, k in T.grid(1024, 1024, 1024):
58+
with T.block("matmul"):
59+
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
60+
with T.init():
61+
C[vi, vj] = 0.0
62+
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
63+
64+
# fmt: on
65+
# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
66+
67+
68+
def test_postproc_disallow_async_strided_mem_copy_allows():
69+
mod = Matmul
70+
sch = tir.Schedule(mod, debug_mask="all")
71+
72+
matmul_block = sch.get_block("matmul")
73+
74+
loops = sch.get_loops(matmul_block)
75+
cache_read = sch.cache_read(matmul_block, 0, "global.vtcm")
76+
77+
sch.compute_at(cache_read, loops[1])
78+
79+
sch.annotate(loops[1], "software_pipeline_stage", [0, 1])
80+
sch.annotate(loops[1], "software_pipeline_order", [0, 1])
81+
sch.annotate(loops[1], "software_pipeline_async_stages", [0])
82+
83+
ctx = _create_context(sch.mod, target=_target())
84+
sch.mod.show()
85+
assert ctx.space_generator.postprocs[0].apply(sch)
86+
87+
88+
def test_postproc_disallow_async_strided_mem_copy_disallows():
89+
mod = Matmul
90+
sch = tir.Schedule(mod, debug_mask="all")
91+
92+
matmul_block = sch.get_block("matmul")
93+
94+
loops = sch.get_loops(matmul_block)
95+
# Make it a strided mem copy.
96+
cache_read = sch.cache_read(matmul_block, 1, "global.vtcm")
97+
98+
sch.compute_at(cache_read, loops[1])
99+
sch.annotate(loops[1], "software_pipeline_stage", [0, 1])
100+
sch.annotate(loops[1], "software_pipeline_order", [0, 1])
101+
sch.annotate(loops[1], "software_pipeline_async_stages", [0])
102+
103+
sch.mod.show()
104+
ctx = _create_context(sch.mod, target=_target())
105+
assert not ctx.space_generator.postprocs[0].apply(sch)
106+
107+
108+
if __name__ == "__main__":
109+
test_postproc_disallow_async_strided_mem_copy_allows()
110+
test_postproc_disallow_async_strided_mem_copy_disallows()

0 commit comments

Comments
 (0)