Skip to content

Commit 6574e16

Browse files
authored
[MetaSchedule][Hexagon] Add postproc for verifying VTCM usage (#13538)
* add new postproc VerifyVTCMLimit * remove pass * add test * add doc, missing file * Add back VectorizeLoop in prereq lowering pass * fix lint
1 parent 965490e commit 6574e16

File tree

8 files changed

+288
-4
lines changed

8 files changed

+288
-4
lines changed

include/tvm/meta_schedule/postproc.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,11 @@ class Postproc : public runtime::ObjectRef {
144144
* \return The postprocessor created
145145
*/
146146
TVM_DLL static Postproc VerifyGPUCode();
147+
/*!
148+
* \brief Creates a postprocessor that verifies that the VTCM usage of a given schedule is within
* the limit taken from the "vtcm-capacity" attribute of the tuning target.
149+
* \return The postprocessor created
150+
*/
151+
TVM_DLL static Postproc VerifyVTCMLimit();
147152
/*!
148153
* \brief Creates a postprocessor that rewrites the layout of input tensor
149154
* \note Weight layout rewrite is supported so far, activation layout rewrite will be added.

include/tvm/tir/analysis.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,14 @@ TVM_DLL bool VerifyMemory(const PrimFunc& func);
169169
*/
170170
TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints);
171171

172+
/*!
173+
* \brief Verifies that the VTCM usage of the given prim_func is within the provided limit.
174+
* \param func The function to be checked.
175+
* \param limit The limit to check. A value that is not positive disables the check.
176+
* \return true if the VTCM usage is within the provided limit, or if the check is disabled.
177+
*/
178+
TVM_DLL bool VerifyVTCMLimit(const PrimFunc& func, Integer limit);
179+
172180
/*!
173181
* \brief Auto detect the block access region according to its body stmt
174182
* It will detect the access region as an array in order of appearance in AST

python/tvm/meta_schedule/postproc/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,4 @@
2424
from .rewrite_tensorize import RewriteTensorize
2525
from .rewrite_unbound_block import RewriteUnboundBlock
2626
from .verify_gpu_code import VerifyGPUCode
27+
from .verify_vtcm_limit import VerifyVTCMLimit
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""A postprocessor that verifies the VTCM usage of a given schedule."""
18+
19+
from tvm._ffi.registry import register_object
20+
from .. import _ffi_api
21+
from .postproc import Postproc
22+
23+
24+
@register_object("meta_schedule.VerifyVTCMLimit")
class VerifyVTCMLimit(Postproc):
    """Postprocessor that checks a schedule's VTCM usage against the limit.

    The underlying object is built by the ``PostprocVerifyVTCMLimit`` FFI
    constructor, which takes no arguments.
    """

    def __init__(self) -> None:
        constructor = _ffi_api.PostprocVerifyVTCMLimit  # type: ignore # pylint: disable=no-member
        self.__init_handle_by_constructor__(constructor)

src/meta_schedule/postproc/postproc.cc

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,9 @@ Array<Postproc> Postproc::DefaultCUDATensorCore() {
9494

9595
Array<Postproc> Postproc::DefaultHexagon() {
9696
return Array<Postproc>{
97-
Postproc::DisallowDynamicLoop(),
98-
Postproc::RewriteParallelVectorizeUnroll(),
99-
Postproc::RewriteReductionBlock(),
100-
Postproc::RewriteLayout(),
97+
Postproc::DisallowDynamicLoop(), Postproc::RewriteParallelVectorizeUnroll(),
98+
Postproc::RewriteReductionBlock(), Postproc::RewriteLayout(),
99+
Postproc::VerifyVTCMLimit(),
101100
};
102101
}
103102

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
#include <tvm/tir/transform.h>
20+
21+
#include "../utils.h"
22+
23+
namespace tvm {
24+
namespace meta_schedule {
25+
26+
class VerifyVTCMLimitNode : public PostprocNode {
27+
public:
28+
Integer vtcm_capacity;
29+
30+
void InitializeWithTuneContext(const TuneContext& context) final {
31+
ICHECK(context->target.defined());
32+
Target target = context->target.value();
33+
ICHECK(target->kind->name == "hexagon");
34+
// The value of 0 will disable VTCM verification.
35+
vtcm_capacity = target->GetAttr<Integer>("vtcm-capacity").value_or(0);
36+
}
37+
38+
bool Verify(const IRModule& mod) const {
39+
for (const auto& kv : mod->functions) {
40+
if (const auto* prim_func = kv.second.as<tir::PrimFuncNode>()) {
41+
if (!tir::VerifyVTCMLimit(GetRef<tir::PrimFunc>(prim_func), vtcm_capacity)) {
42+
return false;
43+
}
44+
}
45+
}
46+
return true;
47+
}
48+
49+
bool Apply(const tir::Schedule& sch) final {
50+
IRModule mod = sch->mod();
51+
for (const auto& kv : mod->functions) {
52+
const GlobalVar& g_var = kv.first;
53+
const BaseFunc& base_func = kv.second;
54+
if (const auto* prim_func = base_func.as<tir::PrimFuncNode>()) {
55+
IRModule lowered{nullptr};
56+
try {
57+
auto pass_list = Array<tvm::transform::Pass>();
58+
pass_list.push_back(tir::transform::LowerInitBlock());
59+
pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
60+
pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
61+
pass_list.push_back(tir::transform::CompactBufferAllocation());
62+
pass_list.push_back(tir::transform::LowerMatchBuffer());
63+
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
64+
pass_list.push_back(tir::transform::LowerOpaqueBlock());
65+
pass_list.push_back(tir::transform::FlattenBuffer());
66+
pass_list.push_back(tir::transform::Simplify());
67+
pass_list.push_back(tir::transform::VectorizeLoop(true));
68+
pass_list.push_back(tir::transform::StorageRewrite());
69+
transform::PassContext pass_ctx = transform::PassContext::Current();
70+
tir::PrimFunc f = WithAttr(GetRef<tir::PrimFunc>(prim_func), "global_symbol",
71+
runtime::String(g_var->name_hint));
72+
IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(g_var->name_hint), f}}));
73+
lowered = tvm::transform::Sequential(pass_list)(std::move(mod));
74+
} catch (const dmlc::Error& e) {
75+
return false;
76+
}
77+
if (!Verify(lowered)) {
78+
return false;
79+
}
80+
}
81+
}
82+
return true;
83+
}
84+
85+
Postproc Clone() const {
86+
ObjectPtr<VerifyVTCMLimitNode> n = make_object<VerifyVTCMLimitNode>(*this);
87+
return Postproc(n);
88+
}
89+
90+
static constexpr const char* _type_key = "meta_schedule.VerifyVTCMLimit";
91+
TVM_DECLARE_FINAL_OBJECT_INFO(VerifyVTCMLimitNode, PostprocNode);
92+
};
93+
94+
/*! \brief Create a postprocessor backed by a freshly constructed VerifyVTCMLimitNode. */
Postproc Postproc::VerifyVTCMLimit() {
  return Postproc(make_object<VerifyVTCMLimitNode>());
}
98+
99+
// Register the node type, and expose the factory under the global name
// "meta_schedule.PostprocVerifyVTCMLimit".
TVM_REGISTER_NODE_TYPE(VerifyVTCMLimitNode);
100+
TVM_REGISTER_GLOBAL("meta_schedule.PostprocVerifyVTCMLimit")
101+
.set_body_typed(Postproc::VerifyVTCMLimit);
102+
103+
} // namespace meta_schedule
104+
} // namespace tvm

src/tir/analysis/calculate_allocated_memory.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,15 @@ TVM_REGISTER_GLOBAL("tir.analysis.calculate_allocated_bytes").set_body_typed([](
8787
return CalculateAllocatedBytes(func);
8888
});
8989

90+
bool VerifyVTCMLimit(const PrimFunc& func, Integer limit) {
91+
auto sizes = CalculateAllocatedBytes(func);
92+
const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0);
93+
if (limit.IntValue() > 0 && vtcm_allocated.IntValue() > limit.IntValue()) {
94+
return false;
95+
}
96+
return true;
97+
}
98+
9099
namespace transform {
91100

92101
Pass VerifyVTCMLimit(const Integer& limit) {
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
18+
import tvm
19+
import tvm.testing
20+
from tvm import meta_schedule as ms
21+
from tvm import tir
22+
from tvm.script import tir as T
23+
24+
25+
def _create_context(mod, target) -> ms.TuneContext:
    """Build a TuneContext whose space generator carries only the VerifyVTCMLimit postproc."""
    space_generator = ms.space_generator.PostOrderApply(
        sch_rules=[],
        postprocs=[ms.postproc.VerifyVTCMLimit()],
        mutator_probs={},
    )
    return ms.TuneContext(
        mod=mod,
        target=target,
        space_generator=space_generator,
        task_name="test",
    )
36+
37+
38+
# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant
39+
# fmt: off
40+
41+
42+
# Test fixture: an IRModule with a single prim_func `main` that accumulates the products of
# two uint8 inputs (p0, p1) into the int32 output `conv2d_NCHWc_int8`. Both inputs are first
# copied into buffers allocated with scope "global.vtcm"; the copy loops sit inside a loop
# annotated with software_pipeline_* attributes. test_conv2d_vtcm expects the VerifyVTCMLimit
# postproc to reject this module at vtcm_capacity=70000 and to accept it at 75000.
@tvm.script.ir_module
43+
class Conv2dNCHWcVTCM:
44+
@T.prim_func
45+
def main(p0: T.Buffer[(T.int64(1), T.int64(2), T.int64(56), T.int64(56), T.int64(32)), "uint8"], p1: T.Buffer[(T.int64(2), T.int64(2), T.int64(3), T.int64(3), T.int64(8), T.int64(32), T.int64(4)), "uint8"], conv2d_NCHWc_int8: T.Buffer[(T.int64(1), T.int64(2), T.int64(54), T.int64(54), T.int64(32)), "int32"]):
46+
T.func_attr({"tir.noalias": True, "global_symbol": "main"})
47+
p0_global_vtcm = T.alloc_buffer([T.int64(1), T.int64(2), T.int64(56), T.int64(56), T.int64(32)], dtype="uint8", scope="global.vtcm")
48+
p1_global_vtcm = T.alloc_buffer([T.int64(2), T.int64(2), T.int64(3), T.int64(3), T.int64(8), T.int64(32), T.int64(4)], dtype="uint8", scope="global.vtcm")
49+
for n_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step":16, "pragma_unroll_explicit":1}):
50+
for oc_chunk_0, oh_0, ow_0, oc_block_0_0 in T.grid(T.int64(2), T.int64(2), T.int64(2), T.int64(1)):
51+
for oc_chunk_1_init, oh_1_init, ow_1_init, oc_chunk_2_init, oh_2_init, ow_2_init in T.grid(T.int64(1), T.int64(27), T.int64(3), T.int64(1), T.int64(1), T.int64(9)):
52+
with T.block("conv2d_NCHWc_int8_o_init"):
53+
v_n = T.axis.spatial(T.int64(1), T.int64(0))
54+
v_oc_chunk = T.axis.spatial(T.int64(2), oc_chunk_1_init + oc_chunk_2_init + oc_chunk_0)
55+
v_oh = T.axis.spatial(T.int64(54), oh_2_init + oh_0 * T.int64(27) + oh_1_init)
56+
v_ow = T.axis.spatial(T.int64(54), ow_0 * T.int64(27) + ow_1_init * T.int64(9) + ow_2_init)
57+
v_oc_block_o = T.axis.spatial(T.int64(1), T.int64(0))
58+
T.reads()
59+
T.writes(conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32)])
60+
for oc_block_1 in T.vectorized(T.int64(32)):
61+
with T.block("conv2d_NCHWc_int8_init"):
62+
v_oc_block_i_init = T.axis.spatial(T.int64(32), oc_block_1)
63+
T.reads()
64+
T.writes(conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i_init])
65+
conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i_init] = 0
66+
for kh_0_kw_0_ic_outer_0_ic_f_inner_0_ic_s_inner_0_0_fused in T.serial(T.int64(2), annotations={"software_pipeline_async_stages":[0], "software_pipeline_order":[0, 1, 2], "software_pipeline_stage":[0, 0, 1]}):
67+
for ax0_ax1_ax2_ax3_ax4_fused in T.serial(T.int64(26912)):
68+
with T.block("p0_global.vtcm"):
69+
v0 = T.axis.spatial(T.int64(1), T.int64(0))
70+
v1 = T.axis.spatial(T.int64(2), ax0_ax1_ax2_ax3_ax4_fused // T.int64(13456))
71+
v2 = T.axis.spatial(T.int64(56), oh_0 * T.int64(27) + ax0_ax1_ax2_ax3_ax4_fused % T.int64(13456) // T.int64(464))
72+
v3 = T.axis.spatial(T.int64(56), ow_0 * T.int64(27) + ax0_ax1_ax2_ax3_ax4_fused % T.int64(464) // T.int64(16))
73+
v4 = T.axis.spatial(T.int64(32), kh_0_kw_0_ic_outer_0_ic_f_inner_0_ic_s_inner_0_0_fused * T.int64(16) + ax0_ax1_ax2_ax3_ax4_fused % T.int64(16))
74+
T.reads(p0[v0, v1, v2, v3, v4])
75+
T.writes(p0_global_vtcm[v0, v1, v2, v3, v4])
76+
p0_global_vtcm[v0, v1, v2, v3, v4] = p0[v0, v1, v2, v3, v4]
77+
for ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused in T.serial(T.int64(9216)):
78+
with T.block("p1_global.vtcm"):
79+
v0 = T.axis.spatial(T.int64(2), oc_chunk_0)
80+
v1 = T.axis.spatial(T.int64(2), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused // T.int64(4608))
81+
v2 = T.axis.spatial(T.int64(3), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % T.int64(4608) // T.int64(1536))
82+
v3 = T.axis.spatial(T.int64(3), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % T.int64(1536) // T.int64(512))
83+
v4 = T.axis.spatial(T.int64(8), kh_0_kw_0_ic_outer_0_ic_f_inner_0_ic_s_inner_0_0_fused * T.int64(4) + ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % T.int64(512) // T.int64(128))
84+
v5 = T.axis.spatial(T.int64(32), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % T.int64(128) // T.int64(4))
85+
v6 = T.axis.spatial(T.int64(4), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % T.int64(4))
86+
T.reads(p1[v0, v1, v2, v3, v4, v5, v6])
87+
T.writes(p1_global_vtcm[v0, v1, v2, v3, v4, v5, v6])
88+
p1_global_vtcm[v0, v1, v2, v3, v4, v5, v6] = p1[v0, v1, v2, v3, v4, v5, v6]
89+
for n_1, oc_chunk_1, oh_1, ow_1, oc_block_0_1, kh_1, kw_1, ic_outer_1, ic_f_inner_1, ic_s_inner_0_1, n_2, oc_chunk_2, oh_2, ow_2, oc_block_0_2 in T.grid(T.int64(1), T.int64(1), T.int64(27), T.int64(3), T.int64(1), T.int64(3), T.int64(3), T.int64(2), T.int64(4), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(9), T.int64(1)):
90+
with T.block("conv2d_NCHWc_int8_o_update"):
91+
v_n = T.axis.spatial(T.int64(1), T.int64(0))
92+
v_oc_chunk = T.axis.spatial(T.int64(2), oc_chunk_1 + oc_chunk_2 + oc_chunk_0)
93+
v_oh = T.axis.spatial(T.int64(54), oh_2 + oh_0 * T.int64(27) + oh_1)
94+
v_ow = T.axis.spatial(T.int64(54), ow_0 * T.int64(27) + ow_1 * T.int64(9) + ow_2)
95+
v_oc_block_o = T.axis.spatial(T.int64(1), T.int64(0))
96+
v_kh, v_kw, v_ic_outer = T.axis.remap("RRR", [kh_1, kw_1, ic_outer_1])
97+
v_ic_f_inner = T.axis.reduce(T.int64(8), kh_0_kw_0_ic_outer_0_ic_f_inner_0_ic_s_inner_0_0_fused * T.int64(4) + ic_f_inner_1)
98+
v_ic_s_inner_o = T.axis.reduce(T.int64(1), T.int64(0))
99+
T.reads(conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32)], p0_global_vtcm[v_n, v_ic_outer, v_oh + v_kh, v_ow + v_kw, v_ic_f_inner * T.int64(4) : v_ic_f_inner * T.int64(4) + T.int64(4)], p1_global_vtcm[v_oc_chunk, v_ic_outer, v_kh, v_kw, v_ic_f_inner, T.int64(0) : T.int64(32), T.int64(0) : T.int64(4)])
100+
T.writes(conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32)])
101+
for oc_block_1, ic_s_inner_1 in T.grid(T.int64(32), T.int64(4)):
102+
with T.block("conv2d_NCHWc_int8"):
103+
v_oc_block_i, v_ic_s_inner_i = T.axis.remap("SR", [oc_block_1, ic_s_inner_1])
104+
T.reads(conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i], p0_global_vtcm[v_n, v_ic_outer, v_oh + v_kh, v_ow + v_kw, v_ic_f_inner * T.int64(4) + v_ic_s_inner_i], p1_global_vtcm[v_oc_chunk, v_ic_outer, v_kh, v_kw, v_ic_f_inner, v_oc_block_i, v_ic_s_inner_i])
105+
T.writes(conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i])
106+
T.block_attr({"meta_schedule.tiling_structure":"SRSRS"})
107+
conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i] = conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i] + T.Cast("int32", p0_global_vtcm[v_n, v_ic_outer, v_oh + v_kh, v_ow + v_kw, v_ic_f_inner * T.int64(4) + v_ic_s_inner_i]) * T.Cast("int32", p1_global_vtcm[v_oc_chunk, v_ic_outer, v_kh, v_kw, v_ic_f_inner, v_oc_block_i, v_ic_s_inner_i])
108+
109+
# fmt: on
110+
111+
112+
def test_conv2d_vtcm():
    """The postproc must reject the schedule at capacity 70000 and accept it at 75000."""

    def get_target(vtcm_cap):
        hexagon = tvm.target.hexagon("v68", vtcm_capacity=vtcm_cap)
        return tvm.target.Target(hexagon, host=hexagon)

    sch = tir.Schedule(Conv2dNCHWcVTCM, debug_mask="all")

    # With the smaller capacity the verification is expected to fail.
    ctx = _create_context(Conv2dNCHWcVTCM, target=get_target(70000))
    postproc = ctx.space_generator.postprocs[0]
    assert not postproc.apply(sch)

    # With the larger capacity the same schedule is expected to pass.
    ctx = _create_context(Conv2dNCHWcVTCM, target=get_target(75000))
    postproc = ctx.space_generator.postprocs[0]
    assert postproc.apply(sch)
124+
125+
126+
# Hand control to tvm.testing.main() when this file is executed as a script.
if __name__ == "__main__":
127+
tvm.testing.main()

0 commit comments

Comments
 (0)