Commit 236eea0

[CMSIS-NN] Removed redundant arguments to CMSIS-NN wrapper function (#11431)
Removed input_scale and filter_scale from the CMSIS-NN wrapper function. These are not needed by the CMSIS-NN API that is called from the generated C wrapper function for Conv2D.
1 parent 2f9d9b4 commit 236eea0
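Background on why these arguments are redundant at run time: TVM folds the effective scale input_scale * filter_scale / output_scale into the per-channel multiplier and shift constants referenced by the qnn.requantize comment in the diff below, and those are the only quantization buffers the CMSIS-NN structs consume. Below is a minimal sketch of that folding, assuming the standard gemmlowp-style fixed-point scheme; it is illustrative code, not TVM or CMSIS-NN source.

#include <cmath>
#include <cstdint>
#include <utility>

// Illustrative only: fold the float scales into a fixed-point multiplier/shift
// pair (applied per output channel when the filter is per-channel quantized).
// After this step the raw scales are never needed by the generated wrapper.
std::pair<int32_t, int32_t> FoldScales(double input_scale, double filter_scale,
                                       double output_scale) {
  double effective_scale = input_scale * filter_scale / output_scale;
  int shift = 0;
  double significand = std::frexp(effective_scale, &shift);  // significand in [0.5, 1)
  int64_t multiplier = std::llround(significand * (1LL << 31));
  if (multiplier == (1LL << 31)) {  // rounding pushed it up to 2^31: renormalize
    multiplier /= 2;
    ++shift;
  }
  return {static_cast<int32_t>(multiplier), shift};
}

With the multiplier/shift pair computed at compile time, passing the raw float scales into the wrapper only adds arguments that the underlying CMSIS-NN call never reads.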

File tree

2 files changed (+121, -4 lines changed):
  src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
  tests/python/contrib/test_cmsisnn/test_conv2d.py

src/relay/backend/contrib/cmsisnn/relay_to_tir.cc

Lines changed: 26 additions & 3 deletions
@@ -141,18 +141,24 @@ class RelayToTIRVisitor : public MixedModeMutator {
     // %3 = qnn.requantize(%2, %input_scale_const_4, %cmsisnn_shift_const_5,
     //                     %output_scale_scalar, %output_zero_point_scalar)
     // clip(%3, a_min=%min_scalar, a_max=%max_scalar)
+    // Position of scales in the global function for Conv2D
+    const int filter_scale_pos = 3;
+    const int input_scale_pos = bias_add_call ? 5 : 4;
     BufferCreator buffer_creator;
     tir::Var input = buffer_creator.CreateBufferVar("input", DataType::Handle(8));
     tir::Var filter = buffer_creator.CreateBufferVar("filter", DataType::Handle(8));
     tir::Var multiplier = buffer_creator.CreateBufferVar("multiplier", DataType::Handle(32));
-    tir::Var filter_scale = buffer_creator.CreateBufferVar("filter_scale", DataType::Handle(32));
     if (bias_add_call) {
       buffer_creator.CreateBufferVar("bias", DataType::Handle(32));
     }
-    tir::Var input_scale = buffer_creator.CreateBufferVar("input_scale", DataType::Handle(32));
     tir::Var shift = buffer_creator.CreateBufferVar("shift", DataType::Handle(32));
     tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(8));

+    // Relay function contains input_scale and filter_scale as function parameters at the following
+    // locations in the global partitioned function for Conv2D
+    skip_call_args_.insert(filter_scale_pos);
+    skip_call_args_.insert(input_scale_pos);
+
     // Individual arguments to the structs arguments of the CMSIS-NN API are filled into call_extern
     // https://github.com/ARM-software/CMSIS_5/blob/def6f800f95661eb3451d317f7d0dde504f6020d/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_wrapper_s8.c#L50

@@ -742,11 +748,25 @@ class RelayToTIRVisitor : public MixedModeMutator {
                               GetRef<Function>(func));
       }

+      // Drop out the redundant arguments, and the arg_types from the global function call
       Array<Expr> args;
+      Array<Type> arg_types;
+      auto* func_type = new_global_var->checked_type_.as<FuncTypeNode>();
+      int arg_id = -1;
       for (const auto& arg : call->args) {
+        ++arg_id;
+        if (std::find(skip_call_args_.begin(), skip_call_args_.end(), arg_id) !=
+            skip_call_args_.end()) {
+          continue;
+        }
         args.push_back(VisitExpr(arg));
+        arg_types.push_back(func_type->arg_types[arg_id]);
       }
-
+      if (arg_types.size() != func_type->arg_types.size()) {
+        new_global_var->checked_type_ =
+            FuncType(arg_types, func_type->ret_type, {}, func_type->type_constraints);
+      }
+      skip_call_args_.clear();
       return Call(new_global_var, args, call->attrs, call->type_args, call->span);
     }
   }

@@ -757,7 +777,10 @@ class RelayToTIRVisitor : public MixedModeMutator {
   static constexpr int32_t kScaledDiffIntegerBits = 5;
   static constexpr int32_t kInputBits = 5;
   static constexpr double kBeta = 1.0;
+  /*! \brief Unique id for context buffer needed by CMSIS-NN layers. */
   int32_t context_buffer_id_;
+  /*! \brief Skip arguments in the call to global partitioned function. */
+  std::unordered_set<int32_t> skip_call_args_;
   IRModule ir_module_;
   Target target_;
 };
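The second hunk is where the call to the partitioned function is rebuilt: arguments whose positions were recorded in skip_call_args_ are dropped, and the function type is shrunk to match. A stripped-down sketch of the same pattern, with plain ints standing in for relay::Expr (hypothetical helper, not TVM source):

#include <cstdint>
#include <unordered_set>
#include <vector>

// Hypothetical helper mirroring the loop above: keep every argument whose
// position is not in skip_call_args, i.e. drop filter_scale (position 3) and
// input_scale (position 4, or 5 when a bias_add is present).
std::vector<int> DropSkippedArgs(const std::vector<int>& args,
                                 const std::unordered_set<int32_t>& skip_call_args) {
  std::vector<int> kept;
  for (int32_t arg_id = 0; arg_id < static_cast<int32_t>(args.size()); ++arg_id) {
    if (skip_call_args.count(arg_id) != 0) continue;  // redundant scale constant
    kept.push_back(args[arg_id]);
  }
  return kept;
}

Clearing skip_call_args_ once the call has been rebuilt keeps the recorded positions from leaking into the next partitioned operator, which may lay out its arguments differently.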

tests/python/contrib/test_cmsisnn/test_conv2d.py

Lines changed: 95 additions & 1 deletion
@@ -23,7 +23,7 @@
 from tvm import relay
 from tvm.relay.op.contrib import cmsisnn

-from tvm.testing.aot import generate_ref_data, AOTTestModel, compile_and_run
+from tvm.testing.aot import generate_ref_data, AOTTestModel, compile_models, compile_and_run

 from tvm.micro.testing.aot_test_utils import AOT_USMP_CORSTONE300_RUNNER
 from utils import (

@@ -119,6 +119,100 @@ def make_model(
     return last_op, params


+@tvm.testing.requires_cmsisnn
+@pytest.mark.parametrize("padding", ["SAME", "VALID"])
+@pytest.mark.parametrize("enable_bias", [True, False])
+@pytest.mark.parametrize(
+    "input_zero_point, input_scale, kernel_scale, out_channels",
+    [(10, 0.0128, [0.11, 0.22], 2)],
+)
+def test_conv2d_number_primfunc_args(
+    padding,
+    enable_bias,
+    input_zero_point,
+    input_scale,
+    kernel_scale,
+    out_channels,
+):
+    interface_api = "c"
+    use_unpacked_api = True
+    test_runner = AOT_USMP_CORSTONE300_RUNNER
+
+    ifm_shape = (1, 64, 100, 4)
+    kernel_size = (3, 3)
+    strides = (1, 1)
+    dilation = (1, 1)
+    dtype = "int8"
+    groups = 1
+    weight_format = "HWIO"
+    kernel_h = kernel_size[0]
+    kernel_w = kernel_size[1]
+    kernel_shape = (kernel_h, kernel_w, ifm_shape[3] // groups, out_channels)
+    kernel_zero_point = 0
+    in_min, in_max = get_range_for_dtype_str(dtype)
+    relu_type = "RELU"
+
+    output_scale, output_zero_point = get_conv2d_qnn_params(
+        kernel_shape,
+        input_scale,
+        input_zero_point,
+        kernel_scale,
+        kernel_zero_point,
+        dtype,
+        dtype,
+        dtype,
+    )
+
+    model, params = make_model(
+        ifm_shape,
+        kernel_shape,
+        input_zero_point,
+        input_scale,
+        kernel_zero_point,
+        kernel_scale,
+        output_zero_point,
+        output_scale,
+        padding,
+        strides,
+        dilation,
+        groups,
+        dtype,
+        dtype,
+        out_channels,
+        weight_format,
+        enable_bias,
+        relu_type,
+    )
+    orig_mod = make_module(model)
+    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod, params)
+
+    # validate pattern matching
+    assert_partitioned_function(orig_mod, cmsisnn_mod)
+
+    # compile the model
+    rng = np.random.default_rng(12345)
+    inputs = {"input": rng.integers(in_min, high=in_max, size=ifm_shape, dtype=dtype)}
+    output_list = generate_ref_data(orig_mod["main"], inputs, params)
+
+    compiled_models = compile_models(
+        AOTTestModel(module=cmsisnn_mod, inputs=inputs, outputs=output_list, params=params),
+        interface_api,
+        use_unpacked_api,
+    )
+
+    # validate number of TIR primfunc args
+    expected_num_params = 6 if enable_bias else 5
+    cmsisnn_tir_mod = None
+    for target, mod in compiled_models[0].executor_factory.lowered_ir_mods.items():
+        if "cmsis-nn" == target.kind.name:
+            cmsisnn_tir_mod = mod
+
+    cmsisnn_func = cmsisnn_tir_mod["tvmgen_default_cmsis_nn_main_0"]
+    assert (
+        len(cmsisnn_func.params) == expected_num_params
+    ), "Generated unexpected number of function arguments"
+
+
 @tvm.testing.requires_cmsisnn
 @pytest.mark.parametrize("padding", ["SAME", "VALID"])
 @pytest.mark.parametrize("relu_type", ["RELU"])
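The parameter counts asserted by test_conv2d_number_primfunc_args follow directly from the BufferCreator calls left in relay_to_tir.cc: input, filter, multiplier, an optional bias, shift and output. Purely as a visual aid, and using a hypothetical C signature (the real artifact is a TIR PrimFunc whose name and ordering come from the codegen), the remaining buffer parameters look like this:

#include <cstdint>

// Hypothetical signature, for illustration only.
void conv2d_primfunc(const int8_t* input, const int8_t* filter,
                     const int32_t* multiplier, const int32_t* bias,  // only when bias_add is present
                     const int32_t* shift, int8_t* output);
// => 6 parameters with bias, 5 without, matching expected_num_params in the test.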
