Commit 66eed5c

Author: Ashutosh Parkhi
Code generation for Conv2D via CMSIS-NN
Change-Id: I0a2279965a0b505f809ffcf8b955f64db8f4aff0
1 parent: 151696f

File tree: 11 files changed, +1198 −45 lines

python/tvm/relay/op/contrib/cmsisnn.py

Lines changed: 64 additions & 20 deletions

@@ -47,42 +47,93 @@ def partition_for_cmsisnn(mod, params=None, **opts):
     if params:
         mod["main"] = bind_params_by_name(mod["main"], params)
 
+    tvm._ffi._init_api("relay.ext.cmsisnn.transform", __name__)
+
     seq = tvm.transform.Sequential(
         [
             transform.InferType(),
             transform.MergeComposite(pattern_table()),
             transform.AnnotateTarget("cmsisnn"),
-            transform.MergeCompilerRegions(),
             transform.PartitionGraph(),
+            GenerateCMSISNNConstants(),
+            ExtractConstantsFromPartitionedFunction(),
+            transform.InferType(),
         ]
     )
-
     return seq(mod)
 
 
 @register_pattern_table("cmsisnn")
 def pattern_table():
     """Get the cmsisnn compiler pattern table."""
 
-    def softmax_pattern():
+    def qnn_softmax_pattern():
+        """Create pattern for quantized softmax"""
         pattern = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
         pattern = is_op("nn.softmax")(pattern)
         pattern = is_op("qnn.quantize")(pattern, is_constant(), is_constant())
         return pattern
 
-    def check_quantized_softmax(extract):
+    def check_qnn_softmax(pattern):
         """Check if softmax is supported by CMSIS-NN."""
-        dequantize_call = extract.args[0].args[0]
-        scale = extract.args[1].data.numpy().item(0)
-        zero_point = extract.args[2].data.numpy().item(0)
+        dequantize_call = pattern.args[0].args[0]
+        scale = pattern.args[1].data.numpy().item(0)
+        zero_point = pattern.args[2].data.numpy().item(0)
 
         # check for dtypes of quantize and dequantize
         return (
             (scale == 1.0 / 256 and zero_point == -128)
-            and extract.attrs.out_dtype == "int8"
+            and pattern.attrs.out_dtype == "int8"
             and dequantize_call.args[0].checked_type.dtype == "int8"
         )
 
+    def qnn_conv2d_pattern():
+        """Create pattern for qnn.conv2D with optional fused relu."""
+        qnn_conv2d = is_op("qnn.conv2d")(
+            wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
+        ).has_attr({"kernel_layout": "HWIO"})
+        bias_add = is_op("nn.bias_add")(qnn_conv2d, is_constant())
+        req = is_op("qnn.requantize")(
+            qnn_conv2d | bias_add, is_constant(), is_constant(), is_constant(), is_constant()
+        )
+        clip_or_req = req.optional(is_op("clip"))
+        return clip_or_req
+
+    def check_qnn_conv2d(pattern):
+        """Check if the Conv2D is supported by CMSIS-NN."""
+        if str(pattern.op.name) == "clip":
+            relu = pattern
+            requantize = relu.args[0]
+        else:
+            requantize = pattern
+        requantize_input = requantize.args[0]
+        bias_add = None
+        bias_dtype = "int32"
+        if str(requantize_input.op.name) == "nn.bias_add":
+            bias_add = requantize_input
+            conv2d = bias_add.args[0]
+            bias_dtype = bias_add.args[1].checked_type.dtype
+        else:
+            conv2d = requantize_input
+        conv2d_input = conv2d.args[0]
+        conv2d_weight = conv2d.args[1]
+
+        # kernel zero_point should be 0
+        kernel_zp = conv2d.args[3].data.numpy()
+        kernel_zp = [kernel_zp] if kernel_zp.ndim == 0 else kernel_zp
+
+        return (
+            conv2d.attrs.kernel_layout == "HWIO"
+            and conv2d.attrs.out_dtype == "int32"
+            and conv2d.attrs.padding[2] == 0
+            and conv2d.attrs.padding[3] == 0
+            and conv2d_input.checked_type.dtype == "int8"
+            and conv2d_weight.checked_type.dtype == "int8"
+            and pattern.checked_type.dtype == "int8"
+            and bias_dtype == "int32"
+            and all([zp == 0 for zp in kernel_zp])
+        )
+
     def binary_op_pattern(op):
         """Matches QNN binary operation"""
         return is_op(f"qnn.{op}")(
@@ -96,23 +147,16 @@ def binary_op_pattern(op):
             is_constant(),
         )
 
-    def check_quantized_binary_op(extract):
+    def check_qnn_binary_op(extract):
         """Check if multiply is supported by CMSIS-NN."""
         return (
             extract.args[0].checked_type.dtype == "int8"
             and extract.args[1].checked_type.dtype == "int8"
         )
 
     return [
-        ("cmsisnn.quantized_softmax", softmax_pattern(), check_quantized_softmax),
-        (
-            "cmsisnn.quantized_mul",
-            binary_op_pattern("mul"),
-            check_quantized_binary_op,
-        ),
-        (
-            "cmsisnn.quantized_add",
-            binary_op_pattern("add"),
-            check_quantized_binary_op,
-        ),
+        ("cmsisnn.qnn_softmax", qnn_softmax_pattern(), check_qnn_softmax),
+        ("cmsisnn.qnn_conv2d", qnn_conv2d_pattern(), check_qnn_conv2d),
+        ("cmsisnn.qnn_mul", binary_op_pattern("mul"), check_qnn_binary_op),
+        ("cmsisnn.qnn_add", binary_op_pattern("add"), check_qnn_binary_op),
     ]
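For orientation, here is a minimal sketch of a Relay module shaped the way the new cmsisnn.qnn_conv2d pattern expects (qnn.conv2d → nn.bias_add → qnn.requantize → clip), partitioned with the function added in this file. The shapes, scales, and zero points below are illustrative assumptions rather than values from this commit, and it requires a TVM build that includes the CMSIS-NN contrib sources.

import numpy as np
import tvm
from tvm import relay
from tvm.relay.op.contrib import cmsisnn

# int8 NHWC input with an HWIO kernel: the dtype/layout combination check_qnn_conv2d accepts
data = relay.var("data", shape=(1, 28, 28, 3), dtype="int8")
weight = relay.const(np.zeros((3, 3, 3, 16), dtype="int8"))

conv = relay.qnn.op.conv2d(
    data,
    weight,
    input_zero_point=relay.const(0, "int32"),
    kernel_zero_point=relay.const(0, "int32"),  # must be 0 for the pattern check to pass
    input_scale=relay.const(0.0078125, "float32"),
    kernel_scale=relay.const(0.0078125, "float32"),
    kernel_size=(3, 3),
    channels=16,
    data_layout="NHWC",
    kernel_layout="HWIO",
    out_dtype="int32",
)
bias = relay.nn.bias_add(conv, relay.const(np.zeros((16,), dtype="int32")), axis=3)
out = relay.qnn.op.requantize(
    bias,
    input_scale=relay.const(0.0078125, "float32"),
    input_zero_point=relay.const(0, "int32"),
    output_scale=relay.const(0.0625, "float32"),
    output_zero_point=relay.const(-128, "int32"),
    out_dtype="int8",
)
out = relay.clip(out, a_min=-128.0, a_max=127.0)  # the optional fused relu in the pattern

mod = tvm.IRModule.from_expr(relay.Function([data], out))
mod = cmsisnn.partition_for_cmsisnn(mod)
print(mod)  # the conv2d chain should now live in a function annotated Compiler="cmsisnn"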

src/relay/backend/contrib/cmsisnn/codegen_cmsisnn.cc

Lines changed: 5 additions & 4 deletions

@@ -21,6 +21,9 @@
 #include <tvm/runtime/registry.h>
 
 namespace tvm {
+namespace codegen {
+runtime::Module CMSISNNModuleNodeCreate(IRModule mod);
+}  // namespace codegen
 namespace relay {
 namespace contrib {
 namespace cmsisnn {
@@ -33,14 +36,12 @@ runtime::Module CompileCMSISNN(const ObjectRef& ref) {
   auto func_name = relay_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
   GlobalVar var = GlobalVar(func_name.value());
   relay_mod->Add(var, relay_func);
-  relay_mod = transform::InferType()(relay_mod);
 
-  Array<transform::Pass> pass_seqs{transform::InferType(), RelayToTIR()};
+  Array<transform::Pass> pass_seqs{RelayToTIR()};
   transform::Sequential seq(pass_seqs);
   IRModule tir_mod = seq(relay_mod);
 
-  const auto* pf = runtime::Registry::Get("runtime.CMSISNNModuleNodeCreate");
-  return (*pf)(tir_mod);
+  return tvm::codegen::CMSISNNModuleNodeCreate(tir_mod);
 }
 
 TVM_REGISTER_GLOBAL("relay.ext.cmsisnn").set_body_typed(CompileCMSISNN);
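The "relay.ext.cmsisnn" global registered here is the hook that the standard BYOC build flow invokes for functions annotated with Compiler="cmsisnn". A small sanity-check sketch from Python (assuming a TVM build that compiles in these CMSIS-NN sources; the check itself is not part of this commit):

import tvm

# Returns None instead of raising when the codegen was not compiled in.
compile_hook = tvm.get_global_func("relay.ext.cmsisnn", allow_missing=True)
print("CMSIS-NN codegen available:", compile_hook is not None)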
Lines changed: 158 additions & 0 deletions (new file)

@@ -0,0 +1,158 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/ndarray.h>
+
+#include "../../../qnn/utils.h"
+#include "../../../transforms/pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+class ExtractConstantsMutator : public MixedModeMutator {
+ public:
+  explicit ExtractConstantsMutator(IRModule& mod) : mod_(mod) {}
+
+ private:
+  String gen_var_name() { return "tvm_var_extract_const_" + std::to_string(var_count_++); }
+
+  Expr VisitExpr_(const FunctionNode* func) final {
+    Function final_func = GetRef<Function>(func);
+    ++func_nesting_level_;
+    auto new_body = VisitExpr(func->body);
+    --func_nesting_level_;
+    if (!new_body.same_as(func->body)) {
+      final_func = Function(FreeVars(new_body), new_body, func->ret_type,
+                            FreeTypeVars(new_body, mod_), func->attrs);
+      function_to_constants_.Set(GetRef<Function>(func), constants_within_function_);
+      constants_within_function_.clear();
+    }
+    return final_func;
+  }
+
+  Expr Rewrite_(const CallNode* call, const Expr& post) final {
+    Expr final_call = post;
+    auto* post_call = post.as<CallNode>();
+    if (post_call == nullptr) {
+      return final_call;
+    }
+
+    // Replace Constant arguments with Vars for ML Operators
+    // Perform this for non-main Call Nodes only
+    if (func_nesting_level_ && call->op.as<OpNode>()) {
+      Array<Expr> new_args;
+      for (auto& arg : post_call->args) {
+        auto* const_arg = arg.as<ConstantNode>();
+        if (const_arg && !const_arg->is_scalar()) {
+          Var var_arg = Var(gen_var_name(), const_arg->tensor_type());
+          new_args.push_back(var_arg);
+          constants_within_function_.push_back(GetRef<Constant>(const_arg));
+        } else {
+          new_args.push_back(arg);
+        }
+      }
+      final_call = Call(call->op, new_args, call->attrs, {});
+    }
+
+    // Since the constants are kicked out of partitioned functions
+    // a new call to global function is needed
+    if (auto* glob_var_node = post_call->op.as<GlobalVarNode>()) {
+      auto glob_var = GetRef<GlobalVar>(glob_var_node);
+      auto glob_func = Downcast<Function>(mod_->Lookup(glob_var));
+      auto new_glob_func = VisitExpr(glob_func);
+      if (!new_glob_func.same_as(glob_func)) {
+        mod_->Update(glob_var, Downcast<Function>(new_glob_func));
+        Array<Expr> new_args = post_call->args;
+        ICHECK(function_to_constants_.find(glob_func) != function_to_constants_.end());
+        for (auto constant : function_to_constants_.at(glob_func)) {
+          new_args.push_back(constant);
+        }
+        final_call = Call(glob_var, new_args);
+      }
+    }
+
+    // Since the constants are kicked out of the local partitioned functions
+    // a new call to local function is needed
+    if (auto* func_node = call->op.as<FunctionNode>()) {
+      Function func = GetRef<Function>(func_node);
+      auto new_func = VisitExpr(func);
+      if (!new_func.same_as(func)) {
+        Array<Expr> new_args = post_call->args;
+        ICHECK(function_to_constants_.find(func) != function_to_constants_.end());
+        for (auto constant : function_to_constants_.at(func)) {
+          constants_within_function_.push_back(constant);
+          Var var_arg = Var(gen_var_name(), constant->tensor_type());
+          new_args.push_back(var_arg);
+        }
+        final_call = Call(new_func, new_args);
+      }
+    }
+
+    return final_call;
+  }
+
+ private:
+  /* \brief Updated module where all calls have replaced constants with new variables */
+  IRModule mod_;
+  /* \brief Maintains mapping of original function to the replaced constants */
+  Map<Function, Array<Constant>> function_to_constants_;
+  /* \brief Constants being kicked out of a function during the function visit */
+  Array<Constant> constants_within_function_;
+  /* \brief Keeps track of variables being created */
+  int var_count_ = 0;
+  /* \brief Keeps track of function scope */
+  int func_nesting_level_ = 0;
+};
+
+/*! * \brief Kicks out all constants out of the partitioned function into main() */
+IRModule ExtractConstants(IRModule mod) {
+  String func_name;
+  Function func;
+
+  auto extract_constants = ExtractConstantsMutator(mod);
+  Function main_func = Downcast<Function>(mod->Lookup("main"));
+  auto new_main_body = extract_constants.VisitExpr(main_func->body);
+  if (!new_main_body.same_as(main_func->body)) {
+    auto main_var = mod->GetGlobalVar("main");
+    auto new_main_func = Function(main_func->params, new_main_body, main_func->ret_type,
+                                  main_func->type_params, main_func->attrs);
+    mod->Update(main_var, new_main_func);
+  }
+  return mod;
+}
+
+transform::Pass ExtractConstantsFromPartitionedFunction() {
+  runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
+      [=](IRModule m, transform::PassContext pc) { return ExtractConstants(m); };
+  return tvm::transform::CreateModulePass(pass_func, 0, "ExtractConstantsFromPartitionedFunction",
+                                          {});
+}
+
+TVM_REGISTER_GLOBAL("relay.ext.cmsisnn.transform.ExtractConstantsFromPartitionedFunction")
+    .set_body_typed([]() { return ExtractConstantsFromPartitionedFunction(); });
+
+}  // namespace cmsisnn
+}  // namespace contrib
+}  // namespace relay
+}  // namespace tvm
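To see the effect of this pass on its own, here is a hedged Python sketch: it mirrors the way partition_for_cmsisnn binds the C++-registered pass via tvm._ffi._init_api and then applies it to an already-partitioned module. The variable partitioned_mod is an assumption (any IRModule that has gone through MergeComposite/AnnotateTarget/PartitionGraph for "cmsisnn", such as the conv2d sketch earlier), and running the snippet requires a TVM build with CMSIS-NN enabled; nothing here is prescribed by the commit beyond the pass name.

import tvm
from tvm import relay

# Binds "relay.ext.cmsisnn.transform.*" globals (including the pass registered above)
# into this module's namespace, exactly as partition_for_cmsisnn does.
tvm._ffi._init_api("relay.ext.cmsisnn.transform", __name__)

# `partitioned_mod` is assumed to be a module already partitioned for "cmsisnn".
extracted = ExtractConstantsFromPartitionedFunction()(partitioned_mod)
extracted = relay.transform.InferType()(extracted)

# Non-scalar constants that were embedded inside the partitioned function should
# now appear as its parameters, with main() passing the constants at the call site.
print(extracted)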
