Skip to content

Commit 864fc25

Browse files
author
Ashutosh Parkhi
committed
code review: tests for new passes, clean up of relay_to_tir for cmsis-nn
Change-Id: I78dafb2be49afcd5b816c1d2828246c213edca85
1 parent 66eed5c commit 864fc25

File tree

9 files changed

+604
-166
lines changed

9 files changed

+604
-166
lines changed

python/tvm/relay/op/contrib/cmsisnn.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
from ...dataflow_pattern import is_constant, is_op, wildcard
2424
from .register import register_pattern_table
2525

26+
tvm._ffi._init_api("relay.ext.cmsisnn.transform", __name__)
27+
2628

2729
def enabled():
2830
return bool(tvm.get_global_func("relay.ext.cmsisnn", True))
@@ -47,8 +49,6 @@ def partition_for_cmsisnn(mod, params=None, **opts):
4749
if params:
4850
mod["main"] = bind_params_by_name(mod["main"], params)
4951

50-
tvm._ffi._init_api("relay.ext.cmsisnn.transform", __name__)
51-
5252
seq = tvm.transform.Sequential(
5353
[
5454
transform.InferType(),
@@ -91,7 +91,7 @@ def qnn_conv2d_pattern():
9191
"""Create pattern for qnn.conv2D with optional fused relu."""
9292
qnn_conv2d = is_op("qnn.conv2d")(
9393
wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
94-
).has_attr({"kernel_layout": "HWIO"})
94+
)
9595
bias_add = is_op("nn.bias_add")(qnn_conv2d, is_constant())
9696
req = is_op("qnn.requantize")(
9797
qnn_conv2d | bias_add, is_constant(), is_constant(), is_constant(), is_constant()
@@ -123,8 +123,7 @@ def check_qnn_conv2d(pattern):
123123
kernel_zp = [kernel_zp] if kernel_zp.ndim == 0 else kernel_zp
124124

125125
return (
126-
conv2d.attrs.kernel_layout == "HWIO"
127-
and conv2d.attrs.out_dtype == "int32"
126+
conv2d.attrs.out_dtype == "int32"
128127
and conv2d.attrs.padding[2] == 0
129128
and conv2d.attrs.padding[3] == 0
130129
and conv2d_input.checked_type.dtype == "int8"

src/relay/backend/contrib/cmsisnn/extract_constants.cc

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@
1717
* specific language governing permissions and limitations
1818
* under the License.
1919
*/
20+
/*!
21+
* \file extract_constants.cc
22+
* \brief Pushes out constants within partitioned functions all the way up to main()
23+
*/
24+
2025
#include <tvm/relay/attrs/nn.h>
2126
#include <tvm/relay/expr_functor.h>
2227
#include <tvm/relay/transform.h>
@@ -30,44 +35,47 @@ namespace relay {
3035
namespace contrib {
3136
namespace cmsisnn {
3237

38+
/*!
39+
* \brief This Mutator finds all functions with constants. Constants are replaced with function
40+
* parameter variables. Constants are pushed all the way up to main().
41+
*/
3342
class ExtractConstantsMutator : public MixedModeMutator {
3443
public:
35-
explicit ExtractConstantsMutator(IRModule& mod) : mod_(mod) {}
44+
explicit ExtractConstantsMutator(const IRModule& mod) : mod_(mod) {}
3645

3746
private:
3847
String gen_var_name() { return "tvm_var_extract_const_" + std::to_string(var_count_++); }
3948

40-
Expr VisitExpr_(const FunctionNode* func) final {
41-
Function final_func = GetRef<Function>(func);
42-
++func_nesting_level_;
49+
Expr VisitExpr_(const FunctionNode* function) final {
50+
Function func = GetRef<Function>(function);
51+
function_to_constants_.Set(func, Array<Constant>{});
52+
functions_.push_back(func);
4353
auto new_body = VisitExpr(func->body);
44-
--func_nesting_level_;
45-
if (!new_body.same_as(func->body)) {
46-
final_func = Function(FreeVars(new_body), new_body, func->ret_type,
47-
FreeTypeVars(new_body, mod_), func->attrs);
48-
function_to_constants_.Set(GetRef<Function>(func), constants_within_function_);
49-
constants_within_function_.clear();
54+
functions_.pop_back();
55+
if (function_to_constants_[func].size()) {
56+
func = Function(FreeVars(new_body), new_body, func->ret_type, FreeTypeVars(new_body, mod_),
57+
func->attrs);
5058
}
51-
return final_func;
59+
return func;
5260
}
5361

5462
Expr Rewrite_(const CallNode* call, const Expr& post) final {
5563
Expr final_call = post;
5664
auto* post_call = post.as<CallNode>();
57-
if (post_call == nullptr) {
58-
return final_call;
59-
}
6065

6166
// Replace Constant arguments with Vars for ML Operators
6267
// Perform this for non-main Call Nodes only
63-
if (func_nesting_level_ && call->op.as<OpNode>()) {
68+
if (!functions_.empty() && call->op.as<OpNode>()) {
6469
Array<Expr> new_args;
6570
for (auto& arg : post_call->args) {
6671
auto* const_arg = arg.as<ConstantNode>();
6772
if (const_arg && !const_arg->is_scalar()) {
6873
Var var_arg = Var(gen_var_name(), const_arg->tensor_type());
6974
new_args.push_back(var_arg);
70-
constants_within_function_.push_back(GetRef<Constant>(const_arg));
75+
const Function& last_func = functions_.back();
76+
Array<Constant> fconstants(function_to_constants_[last_func]);
77+
fconstants.push_back(GetRef<Constant>(const_arg));
78+
function_to_constants_.Set(last_func, fconstants);
7179
} else {
7280
new_args.push_back(arg);
7381
}
@@ -94,17 +102,21 @@ class ExtractConstantsMutator : public MixedModeMutator {
94102

95103
// Since the constants are kicked out of the local partitioned functions
96104
// a new call to local function is needed
105+
// Also, pass on the constants to the callee of this function to support nested functions
97106
if (auto* func_node = call->op.as<FunctionNode>()) {
98107
Function func = GetRef<Function>(func_node);
99108
auto new_func = VisitExpr(func);
100109
if (!new_func.same_as(func)) {
101110
Array<Expr> new_args = post_call->args;
102111
ICHECK(function_to_constants_.find(func) != function_to_constants_.end());
112+
const Function& last_func = functions_.back();
113+
Array<Constant> fconstants(function_to_constants_[last_func]);
103114
for (auto constant : function_to_constants_.at(func)) {
104-
constants_within_function_.push_back(constant);
115+
fconstants.push_back(constant);
105116
Var var_arg = Var(gen_var_name(), constant->tensor_type());
106117
new_args.push_back(var_arg);
107118
}
119+
function_to_constants_.Set(last_func, fconstants);
108120
final_call = Call(new_func, new_args);
109121
}
110122
}
@@ -117,16 +129,14 @@ class ExtractConstantsMutator : public MixedModeMutator {
117129
IRModule mod_;
118130
/* \brief Maintains mapping of original function to the replaced constants */
119131
Map<Function, Array<Constant>> function_to_constants_;
120-
/* \brief Constants being kicked out of a function during the function visit */
121-
Array<Constant> constants_within_function_;
132+
/* \brief Stack of functions to determine scope while filling up function_to_constants_ */
133+
Array<Function> functions_;
122134
/* \brief Keeps track of variables being created */
123135
int var_count_ = 0;
124-
/* \brief Keeps track of function scope */
125-
int func_nesting_level_ = 0;
126136
};
127137

128138
/*! \brief Moves all constants out of the partitioned function into main() */
129-
IRModule ExtractConstants(IRModule mod) {
139+
IRModule ExtractConstants(const IRModule& mod) {
130140
String func_name;
131141
Function func;
132142

@@ -150,7 +160,7 @@ transform::Pass ExtractConstantsFromPartitionedFunction() {
150160
}
151161

152162
TVM_REGISTER_GLOBAL("relay.ext.cmsisnn.transform.ExtractConstantsFromPartitionedFunction")
153-
.set_body_typed([]() { return ExtractConstantsFromPartitionedFunction(); });
163+
.set_body_typed(ExtractConstantsFromPartitionedFunction);
154164

155165
} // namespace cmsisnn
156166
} // namespace contrib

src/relay/backend/contrib/cmsisnn/generate_constants.cc

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,24 +17,35 @@
1717
* specific language governing permissions and limitations
1818
* under the License.
1919
*/
20+
/*!
21+
* \file generate_constants.cc
22+
* \brief Generates quantization parameters needed by CMSIS-NN
23+
*/
24+
2025
#include <tvm/relay/attrs/nn.h>
2126
#include <tvm/relay/attrs/transform.h>
2227
#include <tvm/relay/expr_functor.h>
2328
#include <tvm/relay/transform.h>
2429
#include <tvm/runtime/ndarray.h>
2530

31+
#include "../../../op/make_op.h"
2632
#include "../../../qnn/utils.h"
2733
#include "../../../transforms/pattern_utils.h"
2834

2935
namespace tvm {
3036
namespace relay {
31-
Expr MakeTranspose(Expr data, Array<Integer> axes);
3237
namespace contrib {
3338
namespace cmsisnn {
3439

40+
/*!
41+
* \brief This Mutator will find all partitioned functions meant for CMSIS-NN Conv2D.
42+
* It will substitute original Conv2D's weight zero point and original Requantize's input zero point
43+
* with CMSIS-NN's quantization parameters.
44+
* https://github.com/tensorflow/tflite-micro/blob/0f40100fc60276e9f345c23282de3baf19a78059/tensorflow/lite/kernels/internal/quantization_util.cc#L53
45+
*/
3546
class GenerateConstantsMutator : public MixedModeMutator {
3647
public:
37-
explicit GenerateConstantsMutator(IRModule& mod) : mod_(mod) {}
48+
explicit GenerateConstantsMutator(const IRModule& mod) : mod_(mod) {}
3849

3950
private:
4051
/*! * \brief Converts Kernel layout from HWIO to OHWI to align to CMSIS-NN requirements */
@@ -52,8 +63,15 @@ class GenerateConstantsMutator : public MixedModeMutator {
5263
attrs->out_dtype = std::move(conv2d_attrs->out_dtype);
5364
*new_attrs = tvm::Attrs{attrs};
5465

66+
std::string kernel_layout = conv2d_attrs->kernel_layout.c_str();
67+
int pos_o = kernel_layout.find("O");
68+
int pos_h = kernel_layout.find("H");
69+
int pos_w = kernel_layout.find("W");
70+
int pos_i = kernel_layout.find("I");
71+
5572
IRModule kernel_module;
56-
auto func_body = MakeTranspose(kernel_expr, {Integer(3), Integer(0), Integer(1), Integer(2)});
73+
auto func_body = MakeTranspose(
74+
kernel_expr, {Integer(pos_o), Integer(pos_h), Integer(pos_w), Integer(pos_i)});
5775
auto kernel_func =
5876
Function(FreeVars(func_body), func_body, Type(), FreeTypeVars(func_body, kernel_module));
5977
GlobalVar kernel_var("main");
@@ -158,9 +176,6 @@ class GenerateConstantsMutator : public MixedModeMutator {
158176
Expr Rewrite_(const CallNode* call, const Expr& post) final {
159177
Expr final_call = post;
160178
auto* post_call = post.as<CallNode>();
161-
if (post_call == nullptr) {
162-
return final_call;
163-
}
164179

165180
auto* global_var = call->op.as<GlobalVarNode>();
166181
if (global_var) {
@@ -196,7 +211,7 @@ class GenerateConstantsMutator : public MixedModeMutator {
196211
IRModule mod_;
197212
};
198213

199-
IRModule GenerateConstants(IRModule mod) {
214+
IRModule GenerateConstants(const IRModule& mod) {
200215
String func_name;
201216
Function func;
202217

@@ -220,9 +235,8 @@ transform::Pass GenerateCMSISNNConstants() {
220235
return tvm::transform::CreateModulePass(pass_func, 0, "GenerateCMSISNNConstants", {});
221236
}
222237

223-
TVM_REGISTER_GLOBAL("relay.ext.cmsisnn.transform.GenerateCMSISNNConstants").set_body_typed([]() {
224-
return GenerateCMSISNNConstants();
225-
});
238+
TVM_REGISTER_GLOBAL("relay.ext.cmsisnn.transform.GenerateCMSISNNConstants")
239+
.set_body_typed(GenerateCMSISNNConstants);
226240

227241
} // namespace cmsisnn
228242
} // namespace contrib

0 commit comments

Comments
 (0)