Skip to content

Commit 7956eff

Browse files
Siyuan FengjunrushaozxybazhspectrometerHBHMasterJH5574
committed
[MetaSchedule] disallow_dynamic_loop
Co-authored-by: Junru Shao <[email protected]> Co-authored-by: Xiyou Zhou <[email protected]> Co-authored-by: Bohan Hou <[email protected]> Co-authored-by: Ruihang Lai <[email protected]> Co-authored-by: Hongyi Jin <[email protected]> Co-authored-by: Wuwei Lin <[email protected]>
1 parent 14d0187 commit 7956eff

File tree

4 files changed

+217
-0
lines changed

4 files changed

+217
-0
lines changed

python/tvm/meta_schedule/postproc/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,5 @@
1616
# under the License.
1717
"""The tvm.meta_schedule.postproc package."""
1818
from .postproc import Postproc, PyPostproc
19+
from .disallow_dynamic_loop import DisallowDynamicLoop
1920
from .verify_gpu_code import VerifyGPUCode
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""A postprocessor that checks if the IRModule has any loop with non-constant extent"""
18+
19+
from tvm._ffi.registry import register_object
20+
from .. import _ffi_api
21+
from .postproc import Postproc
22+
23+
24+
@register_object("meta_schedule.DisallowDynamicLoop")
25+
class DisallowDynamicLoop(Postproc):
26+
"""A postprocessor that checks if the IRModule has any loop with non-constant extent"""
27+
28+
def __init__(self) -> None:
29+
self.__init_handle_by_constructor__(
30+
_ffi_api.PostprocDisallowDynamicLoop, # type: ignore # pylint: disable=no-member
31+
)
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
#include "../utils.h"
20+
21+
namespace tvm {
22+
namespace tir {
23+
24+
/*! \brief Check if the loop is dynamic. */
25+
struct DynamicExtentFinder : private StmtVisitor {
26+
public:
27+
static bool Find(const IRModule& mod) {
28+
DynamicExtentFinder finder;
29+
for (const auto& kv : mod->functions) {
30+
const BaseFunc& func = kv.second;
31+
if (const auto* prim_func = func.as<PrimFuncNode>()) {
32+
finder(prim_func->body);
33+
if (finder.found_) {
34+
return true;
35+
}
36+
}
37+
}
38+
return false;
39+
}
40+
41+
private:
42+
void VisitStmt_(const ForNode* loop) final {
43+
if (!loop->extent->IsInstance<IntImmNode>()) {
44+
found_ = true;
45+
} else {
46+
StmtVisitor::VisitStmt_(loop);
47+
}
48+
}
49+
50+
void VisitStmt(const Stmt& stmt) final {
51+
if (!found_) {
52+
StmtVisitor::VisitStmt(stmt);
53+
}
54+
}
55+
56+
bool found_ = false;
57+
};
58+
59+
} // namespace tir
60+
61+
namespace meta_schedule {
62+
63+
/*! \brief Check if the IRModule has any loop with non-constant extent. */
64+
class DisallowDynamicLoopNode : public PostprocNode {
65+
public:
66+
// Inherited from PostprocNode
67+
void InitializeWithTuneContext(const TuneContext& context) final {}
68+
// Inherited from PostprocNode
69+
bool Apply(const tir::Schedule& sch) final { return !tir::DynamicExtentFinder::Find(sch->mod()); }
70+
71+
static constexpr const char* _type_key = "meta_schedule.DisallowDynamicLoop";
72+
TVM_DECLARE_FINAL_OBJECT_INFO(DisallowDynamicLoopNode, PostprocNode);
73+
};
74+
75+
Postproc Postproc::DisallowDynamicLoop() {
76+
ObjectPtr<DisallowDynamicLoopNode> n = make_object<DisallowDynamicLoopNode>();
77+
return Postproc(n);
78+
}
79+
80+
TVM_REGISTER_NODE_TYPE(DisallowDynamicLoopNode);
81+
TVM_REGISTER_GLOBAL("meta_schedule.PostprocDisallowDynamicLoop")
82+
.set_body_typed(Postproc::DisallowDynamicLoop);
83+
84+
} // namespace meta_schedule
85+
} // namespace tvm
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
18+
19+
import tvm
20+
from tvm import tir
21+
from tvm.meta_schedule import TuneContext
22+
from tvm.meta_schedule.postproc import DisallowDynamicLoop
23+
from tvm.script import tir as T
24+
from tvm.target import Target
25+
26+
27+
def _target() -> Target:
28+
return Target("cuda", host="llvm")
29+
30+
31+
def _create_context(mod, target) -> TuneContext:
32+
ctx = TuneContext(
33+
mod=mod,
34+
target=target,
35+
postprocs=[
36+
DisallowDynamicLoop(),
37+
],
38+
task_name="test",
39+
)
40+
for rule in ctx.postprocs:
41+
rule.initialize_with_tune_context(ctx)
42+
return ctx
43+
44+
45+
# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
46+
# fmt: off
47+
48+
@tvm.script.ir_module
49+
class Matmul:
50+
@T.prim_func
51+
def main(a: T.handle, b: T.handle, c: T.handle) -> None:
52+
T.func_attr({"global_symbol": "main"})
53+
A = T.match_buffer(a, (1024, 1024), "float32")
54+
B = T.match_buffer(b, (1024, 1024), "float32")
55+
C = T.match_buffer(c, (1024, 1024), "float32")
56+
for i, j, k in T.grid(1024, 1024, 1024):
57+
with T.block("matmul"):
58+
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
59+
with T.init():
60+
C[vi, vj] = 0.0
61+
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
62+
63+
64+
@tvm.script.ir_module
65+
class DynamicLoop:
66+
@T.prim_func
67+
def main(a: T.handle, b: T.handle, c: T.handle) -> None:
68+
T.func_attr({"global_symbol": "main"})
69+
A = T.match_buffer(a, (1024, 1024), "float32")
70+
B = T.match_buffer(b, (1024, 1024), "float32")
71+
C = T.match_buffer(c, (1024, 1024), "float32")
72+
for i, j in T.grid(1024, 1024):
73+
for k in T.serial(0, i):
74+
with T.block("matmul"):
75+
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
76+
with T.init():
77+
C[vi, vj] = 0.0
78+
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
79+
80+
# fmt: on
81+
# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
82+
83+
84+
def test_postproc_disallow_dynamic_loops():
85+
mod = Matmul
86+
ctx = _create_context(mod, target=_target())
87+
sch = tir.Schedule(mod, debug_mask="all")
88+
assert ctx.postprocs[0].apply(sch)
89+
90+
91+
def test_postproc_disallow_dynamic_loops_fail():
92+
mod = DynamicLoop
93+
ctx = _create_context(mod, target=_target())
94+
sch = tir.Schedule(mod, debug_mask="all")
95+
assert not ctx.postprocs[0].apply(sch)
96+
97+
98+
if __name__ == "__main__":
99+
test_postproc_disallow_dynamic_loops()
100+
test_postproc_disallow_dynamic_loops_fail()

0 commit comments

Comments
 (0)