Skip to content

Commit dffc310

Browse files
authored
[CMSIS-NN] Fixed the case with repeating operands in the QNN binary ops (#11732)
1 parent 2ffd955 commit dffc310

File tree

7 files changed

+165
-5
lines changed

7 files changed

+165
-5
lines changed

python/tvm/relay/op/contrib/cmsisnn.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,6 @@ def qnn_max_pool2d_pattern():
223223
def check_qnn_max_pool2d(pattern):
224224
"""Check if max pool2d is supported by CMSIS-NN."""
225225
output = pattern
226-
input_op = None
227226

228227
if str(pattern.op.name) == "clip":
229228
pooling = pattern.args[0]

src/relay/backend/contrib/cmsisnn/extract_constants.cc

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,18 @@ class ExtractConstantsMutator : public MixedModeMutator {
164164
function_signature.push_back(arg);
165165
} else {
166166
if (arg.as<VarNode>()) {
167-
function_signature.push_back(arg);
167+
// Only push if it's not already present, as multiple consumers of any input var
168+
// will appear only once in the function signature.
169+
bool found_in_existing_signature = false;
170+
for (auto& sign : function_signature) {
171+
if (arg.same_as(sign)) {
172+
found_in_existing_signature = true;
173+
break;
174+
}
175+
}
176+
if (!found_in_existing_signature) {
177+
function_signature.push_back(arg);
178+
}
168179
}
169180
new_args.push_back(arg);
170181
}

src/relay/backend/contrib/cmsisnn/relay_to_tir.cc

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,12 @@ class RelayToTIRVisitor : public MixedModeMutator {
556556

557557
BufferCreator buffer_creator;
558558
tir::Var input_0 = buffer_creator.CreateBufferVar("input_0", DataType::Handle(8));
559-
tir::Var input_1 = buffer_creator.CreateBufferVar("input_1", DataType::Handle(8));
559+
tir::Var input_1;
560+
if (mul_call->args[0].same_as(mul_call->args[1])) {
561+
input_1 = input_0;
562+
} else {
563+
input_1 = buffer_creator.CreateBufferVar("input_1", DataType::Handle(8));
564+
}
560565
tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(8));
561566

562567
tvm::Array<PrimExpr> args = {
@@ -626,7 +631,12 @@ class RelayToTIRVisitor : public MixedModeMutator {
626631

627632
BufferCreator buffer_creator;
628633
tir::Var input_0 = buffer_creator.CreateBufferVar("input_0", DataType::Handle(8));
629-
tir::Var input_1 = buffer_creator.CreateBufferVar("input_1", DataType::Handle(8));
634+
tir::Var input_1;
635+
if (add_call->args[0].same_as(add_call->args[1])) {
636+
input_1 = input_0;
637+
} else {
638+
input_1 = buffer_creator.CreateBufferVar("input_1", DataType::Handle(8));
639+
}
630640
tir::Var output = buffer_creator.CreateBufferVar("output", DataType::Handle(8));
631641

632642
tvm::Array<PrimExpr> args = {

src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,12 @@ class ScalarToTensorConstantMutator : public MixedModeMutator {
179179
auto new_body = VisitExpr(func->body);
180180
Function new_func = WithFields(func, FreeVars(new_body), new_body, func->ret_type,
181181
FreeTypeVars(new_body, mod_), func->attrs);
182+
183+
// Updating new_func parameters could result in deduplication of the function parameters.
184+
// Call arguments need to be aligned to the number of arguments expected by new_func.
185+
if (new_args[0].same_as(new_args[1])) {
186+
new_args.erase(new_args.begin());
187+
}
182188
return Call(new_func, new_args);
183189
}
184190

tests/python/contrib/test_cmsisnn/test_binary_ops.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def make_model(
101101
def test_op_int8(
102102
op, relu_type, input_0_scale, input_0_zero_point, input_1_scale, input_1_zero_point
103103
):
104-
"""Tests QNN Conv2D operator for CMSIS-NN"""
104+
"""Tests QNN binary operator for CMSIS-NN"""
105105
interface_api = "c"
106106
use_unpacked_api = True
107107
test_runner = AOT_USMP_CORSTONE300_RUNNER
@@ -145,6 +145,65 @@ def test_op_int8(
145145
)
146146

147147

148+
@skip_if_no_reference_system
149+
@tvm.testing.requires_cmsisnn
150+
@pytest.mark.parametrize("op", [relay.qnn.op.mul, relay.qnn.op.add])
151+
@pytest.mark.parametrize("relu_type", ["RELU", "NONE"])
152+
def test_same_input_to_binary_op(op, relu_type):
153+
"""Tests QNN binary operator for CMSIS-NN where both inputs are the same"""
154+
interface_api = "c"
155+
use_unpacked_api = True
156+
test_runner = AOT_USMP_CORSTONE300_RUNNER
157+
158+
dtype = "int8"
159+
shape = [1, 16, 16, 3]
160+
input_ = generate_variable("input")
161+
input_scale = 0.256
162+
input_zero_point = 33
163+
164+
model = make_model(
165+
op,
166+
input_,
167+
input_,
168+
input_scale,
169+
input_zero_point,
170+
input_scale,
171+
input_zero_point,
172+
relu_type,
173+
)
174+
orig_mod = make_module(model)
175+
176+
cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)
177+
178+
# validate pattern matching
179+
assert_partitioned_function(orig_mod, cmsisnn_mod)
180+
181+
# Check that the number of internal function parameters is 1
182+
cmsisnn_global_func = cmsisnn_mod["tvmgen_default_cmsis_nn_main_0"]
183+
assert (
184+
isinstance(cmsisnn_global_func.body, tvm.relay.expr.Call)
185+
and len(cmsisnn_global_func.body.args) == 1
186+
), "Composite function for the binary op should have only 1 parameter."
187+
188+
# validate the output
189+
in_min, in_max = get_range_for_dtype_str(dtype)
190+
inputs = {
191+
"input": np.random.randint(in_min, high=in_max, size=shape, dtype=dtype),
192+
}
193+
output_list = generate_ref_data(orig_mod["main"], inputs)
194+
compile_and_run(
195+
AOTTestModel(
196+
module=cmsisnn_mod,
197+
inputs=inputs,
198+
outputs=output_list,
199+
output_tolerance=1,
200+
),
201+
test_runner,
202+
interface_api,
203+
use_unpacked_api,
204+
)
205+
206+
148207
def parameterize_for_constant_inputs(test):
149208
"""Generates parameters in such a way so that at least one of the inputs is a constant,
150209
both can't be variables, both can't be scalars.

tests/python/contrib/test_cmsisnn/test_extract_constants.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,40 @@ def test_nested_function():
116116
relay.transform.InferType()(mod)
117117

118118

119+
@tvm.testing.requires_cmsisnn
120+
def test_internal_function_with_duplicate_arguments():
121+
"""Tests the ExtractConstants pass when a composite function
122+
is present within a global function and has repeating arguments
123+
to one of the binary ops.
124+
"""
125+
input0 = relay.var("input0", shape=(8, 8))
126+
binary_op0 = input0 + input0
127+
binary_op1 = binary_op0 * relay.const(5.0, "float32")
128+
local_func = relay.Function([input0], binary_op1, relay.TensorType((8, 8), "float32"))
129+
local_func = set_composite_func_attr(local_func, "cmsis-nn")
130+
131+
arg = relay.var("arg", shape=(8, 8))
132+
call_local_func = relay.Call(local_func, [arg])
133+
extern_func = relay.Function([arg], call_local_func, relay.TensorType((8, 8), "float32"))
134+
135+
global_arg = relay.var("global_var", shape=(8, 8))
136+
global_var = relay.GlobalVar("external_function")
137+
extern_func = set_external_func_attr(extern_func, "cmsis-nn", global_var.name_hint)
138+
call_extern_func = relay.Call(global_var, [global_arg])
139+
main_func = relay.Function([global_arg], call_extern_func, relay.TensorType((8, 8), "float32"))
140+
main_var = relay.GlobalVar("main")
141+
142+
mod = tvm.IRModule()
143+
mod[global_var] = extern_func
144+
mod[main_var] = main_func
145+
146+
mod = ExtractConstantsFromPartitionedFunction()(mod)
147+
constant_verifier = CheckFunctionsForConstants()
148+
constant_verifier.visit_function(mod[global_var])
149+
constant_verifier.check_num_constants()
150+
relay.transform.InferType()(mod)
151+
152+
119153
@tvm.testing.requires_cmsisnn
120154
def test_multiple_functions():
121155
"""Tests the pass ExternConstants when global function

tests/python/contrib/test_cmsisnn/test_scalar_to_tensor_constant.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,47 @@ def test_all_primary_operands_tensor_constants():
256256
assert tvm.ir.structural_equal(mod[global_var].body, new_mod[global_var].body)
257257

258258

259+
@tvm.testing.requires_cmsisnn
260+
def test_duplicate_constant_arguments():
261+
"""Tests the pass when repeating operands are arguments to the binary op"""
262+
dtype = "int8"
263+
shape = (1, 3, 3, 32)
264+
operand0 = generate_variable("operand0", shape, dtype)
265+
operand1 = generate_variable("operand1", shape, dtype)
266+
binary_op = make_binary_op(
267+
relay.qnn.op.add,
268+
operand0,
269+
operand0,
270+
input_0_scale=0.0128,
271+
input_0_zero_point=32,
272+
input_1_scale=0.256,
273+
input_1_zero_point=-64,
274+
)
275+
276+
local_func = relay.Function([operand0, operand1], binary_op, relay.TensorType(shape, dtype))
277+
local_func = set_composite_func_attr(local_func, "cmsis-nn.qnn_add")
278+
279+
rng = np.random.default_rng(12345)
280+
arg0 = relay.const(rng.integers(-128, high=127, size=shape, dtype=dtype))
281+
call_local_func = relay.Call(local_func, [arg0, arg0])
282+
extern_func = relay.Function([], call_local_func, relay.TensorType(shape, dtype))
283+
284+
global_var = relay.GlobalVar("external_function")
285+
extern_func = set_external_func_attr(extern_func, "cmsis-nn", global_var.name_hint)
286+
call_extern_func = relay.Call(global_var, [])
287+
main_func = relay.Function([], call_extern_func, relay.TensorType(shape, dtype))
288+
main_var = relay.GlobalVar("main")
289+
290+
mod = tvm.IRModule()
291+
mod[global_var] = extern_func
292+
mod[main_var] = main_func
293+
294+
mod = relay.transform.InferType()(mod)
295+
mod = ScalarToTensorConstants()(mod)
296+
new_mod = relay.transform.InferType()(mod)
297+
assert tvm.ir.structural_equal(mod[global_var].body, new_mod[global_var].body)
298+
299+
259300
@tvm.testing.requires_cmsisnn
260301
def test_non_cmsisnn_ext_func():
261302
"""Non CMSISNN functions should not be altered."""

0 commit comments

Comments
 (0)