Commit 6e4e8fa

Cortex_m backend: Add SharedQspecQuantizer + maximum/minimum (#15872)
This patch reworks how ops without rescaling are handled. Instead of leaving them unannotated and relying on retracing to give them the correct dtype, they must now be annotated by the SharedQspecQuantizer. This simplifies ensuring that a non-rescaling op with multiple inputs uses the same scale on every input, and it also makes it easier to tell which dtype each op will receive before folding.

Signed-off-by: Adrian Lundell <[email protected]>
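The SharedQspecQuantizer itself is not among the files shown below, but the pattern the message describes (tying the inputs and output of a non-rescaling op to one set of quantization parameters) is the standard shared-qspec idiom from torch.ao. A minimal sketch of that idiom, for orientation only: the function name `annotate_non_rescaling_op` and the `input_act_qspec` parameter are illustrative, not from this patch.

```python
# Illustrative sketch of the shared-qspec pattern the commit message
# describes; names below are hypothetical, not from this patch.
import torch
from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
    SharedQuantizationSpec,
)


def annotate_non_rescaling_op(node: torch.fx.Node, input_act_qspec) -> None:
    """Annotate a two-input non-rescaling op (e.g. maximum/minimum) so
    both inputs and the output share one scale/zero-point."""
    input1, input2 = node.args[0], node.args[1]
    # Everything points back at the (input1 -> node) edge, so the observers
    # collapse to a single set of qparams before folding.
    shared = SharedQuantizationSpec((input1, node))
    node.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map={input1: input_act_qspec, input2: shared},
        output_qspec=shared,
        _annotated=True,
    )
```

Because the op never rescales, sharing qparams this way also makes the post-folding dtype of each value readable straight off the annotation, which is the second benefit the message mentions.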
1 parent 9952aef commit 6e4e8fa

11 files changed: +893 −101 lines


backends/cortex_m/CMakeLists.txt

Lines changed: 2 additions & 0 deletions

```diff
@@ -58,6 +58,8 @@ set(_cortex_m_kernels__srcs
     ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_add.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_linear.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_mul.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_minimum.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_maximum.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_transpose.cpp
 )
```

backends/cortex_m/ops/cortex_m_ops_common.h

Lines changed: 18 additions & 7 deletions

```diff
@@ -18,11 +18,19 @@
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <executorch/runtime/platform/assert.h>
 
+#include <limits>
+#include <optional>
+
+extern "C" {
+#include "arm_nn_types.h"
+}
+
 using Tensor = torch::executor::Tensor;
 using ScalarType = executorch::aten::ScalarType;
 using Scalar = torch::executor::Scalar;
 using Error = executorch::runtime::Error;
 using IntArrayRef = executorch::aten::ArrayRef<int64_t>;
+using KernelRuntimeContext = torch::executor::KernelRuntimeContext;
 
 // From arm_nn_math_types.h
 #define ARM_NN_Q31_MAX ((int32_t)(0x7FFFFFFFL))
@@ -34,7 +42,8 @@ inline void validate_cmsis_nn_tensor_requirements(
     const Tensor& input2,
     Tensor& output,
     ScalarType expected_dtype = ScalarType::Char,
-    bool require_channels_last = false) {
+    bool require_channels_last = false,
+    bool require_same_sizes = true) {
   // Basic dtype validation
   ET_CHECK_MSG(
       input1.scalar_type() == expected_dtype,
@@ -51,12 +60,14 @@ inline void validate_cmsis_nn_tensor_requirements(
       "Output dtype must be %hhd, got %hhd",
       expected_dtype,
       output.scalar_type());
-  ET_CHECK_MSG(
-      input1.sizes() == input2.sizes(),
-      "Input1 and Input2 must have the same sizes");
-  ET_CHECK_MSG(
-      output.sizes() == input1.sizes(),
-      "Output must have the same sizes as inputs");
+  if (require_same_sizes) {
+    ET_CHECK_MSG(
+        input1.sizes() == input2.sizes(),
+        "Input1 and Input2 must have the same sizes");
+    ET_CHECK_MSG(
+        output.sizes() == input1.sizes(),
+        "Output must have the same sizes as inputs");
+  }
 
   // Dim order consistency
   ET_CHECK_MSG(
```
backends/cortex_m/ops/op_maximum.cpp (new file)

Lines changed: 102 additions & 0 deletions

```cpp
/*
 * Copyright 2025 Arm Limited and/or its affiliates.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "cortex_m_ops_common.h"

// Include CMSIS-NN headers with C linkage
extern "C" {
#include "arm_nnfunctions.h"
}

namespace cortex_m {
namespace native {

using KernelRuntimeContext = torch::executor::KernelRuntimeContext;

Tensor& maximum_out(
    KernelRuntimeContext& context,
    const Tensor& input1,
    const Tensor& input2,
    Tensor& out) {
  validate_cmsis_nn_tensor_requirements(
      input1,
      input2,
      out,
      ScalarType::Char,
      /*require_channels_last=*/false,
      /*require_same_sizes=*/false);

  auto resize_error = resize_to_broadcast_target_size(input1, input2, out);
  if (resize_error != Error::Ok) {
    ET_LOG(Error, "maximum_out: broadcast shape mismatch between inputs");
    context.fail(resize_error);
    return out;
  }

  const int8_t* input1_data = input1.const_data_ptr<int8_t>();
  const int8_t* input2_data = input2.const_data_ptr<int8_t>();
  int8_t* output_data = out.mutable_data_ptr<int8_t>();

  // Create CMSIS-NN dims directly from tensor sizes
  const auto input1_rank = input1.dim();
  const auto input1_sizes = input1.sizes();
  const cmsis_nn_dims input1_dims{
      static_cast<int32_t>(
          input1_rank >= 4 ? input1_sizes[input1_rank - 4] : 1),
      static_cast<int32_t>(
          input1_rank >= 3 ? input1_sizes[input1_rank - 3] : 1),
      static_cast<int32_t>(
          input1_rank >= 2 ? input1_sizes[input1_rank - 2] : 1),
      static_cast<int32_t>(
          input1_rank >= 1 ? input1_sizes[input1_rank - 1] : 1)};

  const auto input2_rank = input2.dim();
  const auto input2_sizes = input2.sizes();
  const cmsis_nn_dims input2_dims{
      static_cast<int32_t>(
          input2_rank >= 4 ? input2_sizes[input2_rank - 4] : 1),
      static_cast<int32_t>(
          input2_rank >= 3 ? input2_sizes[input2_rank - 3] : 1),
      static_cast<int32_t>(
          input2_rank >= 2 ? input2_sizes[input2_rank - 2] : 1),
      static_cast<int32_t>(
          input2_rank >= 1 ? input2_sizes[input2_rank - 1] : 1)};

  const auto output_rank = out.dim();
  const auto output_sizes = out.sizes();
  const cmsis_nn_dims output_dims{
      static_cast<int32_t>(
          output_rank >= 4 ? output_sizes[output_rank - 4] : 1),
      static_cast<int32_t>(
          output_rank >= 3 ? output_sizes[output_rank - 3] : 1),
      static_cast<int32_t>(
          output_rank >= 2 ? output_sizes[output_rank - 2] : 1),
      static_cast<int32_t>(
          output_rank >= 1 ? output_sizes[output_rank - 1] : 1)};

  const arm_cmsis_nn_status status = arm_maximum_s8(
      /* ctx */ nullptr,
      input1_data,
      &input1_dims,
      input2_data,
      &input2_dims,
      output_data,
      &output_dims);

  if (status != ARM_CMSIS_NN_SUCCESS) {
    ET_LOG(
        Error,
        "maximum_out: arm_maximum_s8 failed with status [%d]",
        static_cast<int>(status));
    context.fail(Error::Internal);
  }

  return out;
}

} // namespace native
} // namespace cortex_m
```
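Both kernels right-align the tensor shape into CMSIS-NN's fixed rank-4 `cmsis_nn_dims` struct, padding missing leading dimensions with 1. A quick model of that mapping; `to_cmsis_nn_dims` is only an illustration of the C++ above, not code from this patch:

```python
# Models the dims construction in op_maximum.cpp / op_minimum.cpp:
# sizes are right-aligned into (n, h, w, c) and leading slots default to 1.
def to_cmsis_nn_dims(sizes: list[int]) -> tuple[int, ...]:
    padded = [1] * max(0, 4 - len(sizes)) + list(sizes[-4:])
    return tuple(padded)  # (n, h, w, c)

assert to_cmsis_nn_dims([10]) == (1, 1, 1, 10)         # rank 1 -> c only
assert to_cmsis_nn_dims([3, 10]) == (1, 1, 3, 10)      # rank 2 -> (w, c)
assert to_cmsis_nn_dims([2, 3, 4, 5]) == (2, 3, 4, 5)  # rank 4 -> as-is
```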
backends/cortex_m/ops/op_minimum.cpp (new file)

Lines changed: 104 additions & 0 deletions

```cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 * Copyright 2025 Arm Limited and/or its affiliates.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "cortex_m_ops_common.h"

// Include CMSIS-NN headers with C linkage
extern "C" {
#include "arm_nnfunctions.h"
}

namespace cortex_m {
namespace native {

using KernelRuntimeContext = torch::executor::KernelRuntimeContext;

Tensor& minimum_out(
    KernelRuntimeContext& context,
    const Tensor& input1,
    const Tensor& input2,
    Tensor& out) {
  validate_cmsis_nn_tensor_requirements(
      input1,
      input2,
      out,
      ScalarType::Char,
      /*require_channels_last=*/false,
      /*require_same_sizes=*/false);

  auto resize_error = resize_to_broadcast_target_size(input1, input2, out);
  if (resize_error != Error::Ok) {
    ET_LOG(Error, "minimum_out: broadcast shape mismatch between inputs");
    context.fail(resize_error);
    return out;
  }

  const int8_t* input1_data = input1.const_data_ptr<int8_t>();
  const int8_t* input2_data = input2.const_data_ptr<int8_t>();
  int8_t* output_data = out.mutable_data_ptr<int8_t>();

  // Create CMSIS-NN dims directly from tensor sizes
  const auto input1_rank = input1.dim();
  const auto input1_sizes = input1.sizes();
  const cmsis_nn_dims input1_dims{
      static_cast<int32_t>(
          input1_rank >= 4 ? input1_sizes[input1_rank - 4] : 1),
      static_cast<int32_t>(
          input1_rank >= 3 ? input1_sizes[input1_rank - 3] : 1),
      static_cast<int32_t>(
          input1_rank >= 2 ? input1_sizes[input1_rank - 2] : 1),
      static_cast<int32_t>(
          input1_rank >= 1 ? input1_sizes[input1_rank - 1] : 1)};

  const auto input2_rank = input2.dim();
  const auto input2_sizes = input2.sizes();
  const cmsis_nn_dims input2_dims{
      static_cast<int32_t>(
          input2_rank >= 4 ? input2_sizes[input2_rank - 4] : 1),
      static_cast<int32_t>(
          input2_rank >= 3 ? input2_sizes[input2_rank - 3] : 1),
      static_cast<int32_t>(
          input2_rank >= 2 ? input2_sizes[input2_rank - 2] : 1),
      static_cast<int32_t>(
          input2_rank >= 1 ? input2_sizes[input2_rank - 1] : 1)};

  const auto output_rank = out.dim();
  const auto output_sizes = out.sizes();
  const cmsis_nn_dims output_dims{
      static_cast<int32_t>(
          output_rank >= 4 ? output_sizes[output_rank - 4] : 1),
      static_cast<int32_t>(
          output_rank >= 3 ? output_sizes[output_rank - 3] : 1),
      static_cast<int32_t>(
          output_rank >= 2 ? output_sizes[output_rank - 2] : 1),
      static_cast<int32_t>(
          output_rank >= 1 ? output_sizes[output_rank - 1] : 1)};

  const arm_cmsis_nn_status status = arm_minimum_s8(
      /* ctx */ nullptr,
      input1_data,
      &input1_dims,
      input2_data,
      &input2_dims,
      output_data,
      &output_dims);

  if (status != ARM_CMSIS_NN_SUCCESS) {
    ET_LOG(
        Error,
        "minimum_out: arm_minimum_s8 failed with status [%d]",
        static_cast<int>(status));
    context.fail(Error::Internal);
  }

  return out;
}

} // namespace native
} // namespace cortex_m
```

backends/cortex_m/ops/operators.py

Lines changed: 41 additions & 0 deletions

```diff
@@ -238,6 +238,47 @@ def quantized_mul_impl(
     return result
 
 
+# ===================================================================
+# MINIMUM/MAXIMUM OPERATION DEFINITIONS
+# ===================================================================
+lib.define("minimum(Tensor self, Tensor other) -> Tensor")
+lib.define("minimum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)")
+
+
+@register_fake("cortex_m::minimum")
+def minimum_meta(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
+    assert self.dtype == other.dtype, (
+        "Cortex-M minimum: dtype mismatch — "
+        f"got self.dtype={self.dtype}, other.dtype={other.dtype}"
+    )
+    broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape)
+    return torch.empty(broadcasted_shape, dtype=self.dtype, device=self.device)
+
+
+@impl(lib, "minimum", "CompositeExplicitAutograd")
+def minimum_impl(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
+    return torch.minimum(self, other)
+
+
+lib.define("maximum(Tensor self, Tensor other) -> Tensor")
+lib.define("maximum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)")
+
+
+@register_fake("cortex_m::maximum")
+def maximum_meta(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
+    assert self.dtype == other.dtype, (
+        "Cortex-M maximum: dtype mismatch — "
+        f"got self.dtype={self.dtype}, other.dtype={other.dtype}"
+    )
+    broadcasted_shape = torch.broadcast_shapes(self.shape, other.shape)
+    return torch.empty(broadcasted_shape, dtype=self.dtype, device=self.device)
+
+
+@impl(lib, "maximum", "CompositeExplicitAutograd")
+def maximum_impl(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
+    return torch.maximum(self, other)
+
+
 # ===================================================================
 # QUANTIZED LINEAR OPERATION DEFINITION
 # ===================================================================
```
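Once this module is imported, the new ops are reachable through `torch.ops.cortex_m` like any other `torch.library` registration. A small smoke test of assumed usage (not from the patch; the import path is inferred from the repo layout):

```python
import torch
# Assumes executorch.backends.cortex_m.ops.operators has been imported,
# so the lib.define / @impl / @register_fake calls above have run.

a = torch.randint(-128, 128, (2, 3), dtype=torch.int8)
b = torch.randint(-128, 128, (3,), dtype=torch.int8)

out_min = torch.ops.cortex_m.minimum(a, b)  # broadcasts (3,) against (2, 3)
out_max = torch.ops.cortex_m.maximum(a, b)
assert out_min.shape == (2, 3) and out_min.dtype == torch.int8
assert torch.equal(out_max, torch.maximum(a, b))
```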

backends/cortex_m/ops/operators.yaml

Lines changed: 12 additions & 0 deletions

```diff
@@ -29,6 +29,18 @@
   - arg_meta: null
     kernel_name: cortex_m::quantized_mul_out
 
+- func: cortex_m::minimum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: cortex_m::minimum_out
+
+- func: cortex_m::maximum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: cortex_m::maximum_out
+
 - func: cortex_m::quantized_linear.out(Tensor input, Tensor weights, Tensor? bias, Tensor? kernel_sum, Scalar input_offset, Scalar filter_offset, Scalar output_offset, int[] requantize_multipliers, int[] requantize_shifts, Scalar activation_max, Scalar activation_min, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels:
```

backends/cortex_m/passes/quantized_op_fusion_pass.py

Lines changed: 16 additions & 0 deletions

```diff
@@ -101,6 +101,18 @@ def _get_mul_replacement(self, args, meta):
 
         return exir_ops.edge.cortex_m.quantized_mul.default, args
 
+    def _get_minimum_replacement(self, args, meta):
+        if args[0].data.dtype != torch.int8:
+            return exir_ops.edge.aten.minimum.default, args
+
+        return exir_ops.edge.cortex_m.minimum.default, args
+
+    def _get_maximum_replacement(self, args, meta):
+        if args[0].data.dtype != torch.int8:
+            return exir_ops.edge.aten.maximum.default, args
+
+        return exir_ops.edge.cortex_m.maximum.default, args
+
     def _get_permute_replacement(self, args, meta):
         if args[0].data.dtype != torch.int8:
             return exir_ops.edge.aten.permute_copy.default, args
@@ -123,6 +135,10 @@ def call_operator(
                 op, args = self._get_add_replacement(args, meta)
             case exir_ops.edge.aten.mul.Tensor:
                 op, args = self._get_mul_replacement(args, meta)
+            case exir_ops.edge.aten.minimum.default:
+                op, args = self._get_minimum_replacement(args, meta)
+            case exir_ops.edge.aten.maximum.default:
+                op, args = self._get_maximum_replacement(args, meta)
             case exir_ops.edge.aten.permute_copy.default:
                 op, args = self._get_permute_replacement(args, meta)
             case _:
```