Skip to content

Commit 6bc72bb

Browse files
Mousius and Ashutosh Parkhi
authored
[microTVM] Replace arm_nnsupportfunctions.h with arm_acle.h (#13363)
* [microTVM] Replace arm_nnsupportfunctions.h with arm_acle.h

  This attempts to replace the CMSIS-NN header with a more portable alternative and avoid dependence on CMSIS.

* Remove the CMSIS __STATIC_FORCEINLINE macro

* Replace more intrinsics with ACLE variants

* Use builtins for intrinsics missing in older GCC

* Re-use common_includes to propagate shared functions

  The packing definitions aren't implemented as ACLE intrinsics, nor is there a simple way to convince a C compiler to generate them.

* Properly align memory access

  Introduce `memcpy` to explain to the compiler that we're changing the alignment of `int16_t` to `int32_t`. What this appears to actually do is encourage the compiler to use three loads rather than one double load plus a regular load. The padded array is aligned as an `int16_t`; it isn't guaranteed to behave like an `int32_t`-aligned array. One of the side effects of the type punning from `int16_t*` to `int32_t*` is that we're effectively lying to the compiler that this is correctly aligned and that it can use instructions which load multiple `int32_t`s at the same time — this does not work 😿

Co-authored-by: Ashutosh Parkhi <[email protected]>
1 parent a435cbb commit 6bc72bb

File tree

5 files changed

+93
-42
lines changed

5 files changed

+93
-42
lines changed

python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def sum_impl(N, uniq_id):
101101
#ifdef __cplusplus
102102
extern "C"
103103
#endif // __cplusplus
104-
__STATIC_FORCEINLINE int32_t sum16_reset_{uniq_id}(
104+
__attribute__((always_inline)) static inline int32_t sum16_reset_{uniq_id}(
105105
int16_t *res) {{
106106
*res = (int16_t)0;
107107
return 0;
@@ -110,7 +110,7 @@ def sum_impl(N, uniq_id):
110110
#ifdef __cplusplus
111111
extern "C"
112112
#endif
113-
__STATIC_FORCEINLINE int32_t sum16_{N}_{uniq_id}(
113+
__attribute__((always_inline)) static inline int32_t sum16_{N}_{uniq_id}(
114114
int16_t *arr,
115115
int16_t *res16,
116116
long arr_offset,
@@ -129,7 +129,7 @@ def sum_impl(N, uniq_id):
129129
}}
130130
131131
for ( int i = 0; i < n / 2; ++ i ) {{
132-
res = __SMLAD(*p32, 0x00010001, res);
132+
res = __smlad(*p32, 0x00010001, res);
133133
++ p32;
134134
}}
135135

python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/common.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,42 @@
2424
#include <stdlib.h>
2525
#include <string.h>
2626
27-
#include <arm_nnsupportfunctions.h>
27+
#include <arm_acle.h>
2828
2929
#include <tvm/runtime/crt/error_codes.h>
3030
31+
32+
#ifndef ARM_CPU_INTRINSICS_EXIST
33+
#define ARM_CPU_INTRINSICS_EXIST
34+
__attribute__((always_inline)) uint32_t __ror(uint32_t op1, uint32_t op2)
35+
{
36+
op2 %= 32U;
37+
if (op2 == 0U)
38+
{
39+
return op1;
40+
}
41+
return (op1 >> op2) | (op1 << (32U - op2));
42+
}
43+
44+
#define __pkhbt(ARG1,ARG2,ARG3) \
45+
__extension__ \
46+
({ \
47+
uint32_t __RES, __ARG1 = (ARG1), __ARG2 = (ARG2); \
48+
__asm("pkhbt %0, %1, %2, lsl %3" : "=r" (__RES) : "r" (__ARG1), "r" (__ARG2), "I" (ARG3) ); \
49+
__RES; \
50+
})
51+
52+
#define __pkhtb(ARG1,ARG2,ARG3) \
53+
__extension__ \
54+
({ \
55+
uint32_t __RES, __ARG1 = (ARG1), __ARG2 = (ARG2); \
56+
if (ARG3 == 0) \
57+
__asm("pkhtb %0, %1, %2" : "=r" (__RES) : "r" (__ARG1), "r" (__ARG2) ); \
58+
else \
59+
__asm("pkhtb %0, %1, %2, asr %3" : "=r" (__RES) : "r" (__ARG1), "r" (__ARG2), "I" (ARG3) ); \
60+
__RES; \
61+
})
62+
#endif
3163
"""
3264

3365
MICRO_WORD_LENGTH_BITS = 32

python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py

Lines changed: 43 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -132,12 +132,30 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
132132
cc_code = (
133133
common.common_includes
134134
+ f"""
135+
#ifndef ARM_CPU_MPROFILE_READ_AND_PAD_EXISTS
136+
#define ARM_CPU_MPROFILE_READ_AND_PAD_EXISTS
137+
__attribute__((always_inline)) static inline const int8_t *read_and_pad(const int8_t *source, int32_t *out1, int32_t *out2)
138+
{{
139+
int32_t inA;
140+
memcpy(&inA, source, 4);
141+
source += 4;
142+
143+
int32_t inAbuf1 = __sxtb16(__ror((uint32_t)inA, 8));
144+
int32_t inAbuf2 = __sxtb16(inA);
145+
*out2 = (int32_t)(__pkhtb(inAbuf1, inAbuf2, 16));
146+
*out1 = (int32_t)(__pkhbt(inAbuf2, inAbuf1, 16));
147+
148+
return source;
149+
}}
150+
#endif
151+
"""
152+
+ f"""
135153
136154
137155
#ifdef __cplusplus
138156
extern "C"
139157
#endif
140-
__STATIC_FORCEINLINE int32_t gemm_{M}x{N}_body_rest_{uniq_id}(
158+
__attribute__((always_inline)) static inline int32_t gemm_{M}x{N}_body_rest_{uniq_id}(
141159
int K,
142160
int8_t *aa, int8_t *bb, int32_t *cc,
143161
int A_stride, int B_stride, int C_stride) {{
@@ -180,7 +198,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
180198
#ifdef __cplusplus
181199
extern "C"
182200
#endif
183-
__STATIC_FORCEINLINE int32_t gemm_{M}x{K}x{N}_body_loop_{uniq_id}(
201+
__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_body_loop_{uniq_id}(
184202
int8_t *aa, int8_t *bb, int32_t *cc,
185203
int A_stride, int B_stride, int C_stride) {{
186204
for (int i = 0; i < {M}; i++) {{
@@ -201,7 +219,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
201219
#ifdef __cplusplus
202220
extern "C"
203221
#endif
204-
__STATIC_FORCEINLINE int32_t gemm_{M}x{K}x{N}_body_{uniq_id}(
222+
__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_body_{uniq_id}(
205223
int8_t *aa, int8_t *bb, int32_t *cc,
206224
int A_stride, int B_stride, int C_stride) {{
207225
int16_t bb_pad[{bb_pad_size}];
@@ -226,7 +244,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
226244
int32_t *bb_ptr = (int32_t *) &bb_pad[j*{K}];
227245
int32_t sum = 0;
228246
for (int l = 0; l < 2 * ({K} / 4); l++) {{
229-
sum = __SMLAD(*aa_ptr, *bb_ptr, sum);
247+
sum = __smlad(*aa_ptr, *bb_ptr, sum);
230248
++ aa_ptr; ++ bb_ptr;
231249
}}
232250
// NOTE: this is the line where `*_body` differs from `*_update`. here
@@ -246,7 +264,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
246264
#ifdef __cplusplus
247265
extern "C"
248266
#endif
249-
__STATIC_FORCEINLINE int32_t gemm_{M}x{N}_update_rest_{uniq_id}(
267+
__attribute__((always_inline)) static inline int32_t gemm_{M}x{N}_update_rest_{uniq_id}(
250268
int K,
251269
int8_t *aa, int8_t *bb, int32_t *cc,
252270
int A_stride, int B_stride, int C_stride) {{
@@ -289,7 +307,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
289307
#ifdef __cplusplus
290308
extern "C"
291309
#endif
292-
__STATIC_FORCEINLINE int32_t gemm_{M}x{K}x{N}_update_loop_{uniq_id}(
310+
__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_update_loop_{uniq_id}(
293311
int8_t *aa, int8_t *bb, int32_t *cc,
294312
int A_stride, int B_stride, int C_stride) {{
295313
for (int i = 0; i < {M}; i++) {{
@@ -307,7 +325,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
307325
#ifdef __cplusplus
308326
extern "C"
309327
#endif
310-
__STATIC_FORCEINLINE int32_t gemm_{M}x{K}x{N}_update_{uniq_id}(
328+
__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_update_{uniq_id}(
311329
int8_t *aa, int8_t *bb, int32_t *cc,
312330
int A_stride, int B_stride, int C_stride) {{
313331
int16_t bb_pad[{bb_pad_size}];
@@ -332,7 +350,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
332350
int32_t *bb_ptr = (int32_t *) &bb_pad[j*{K}];
333351
int32_t sum = 0;
334352
for (int l = 0; l < 2 * ({K} / 4); l++) {{
335-
sum = __SMLAD(*aa_ptr, *bb_ptr, sum);
353+
sum = __smlad(*aa_ptr, *bb_ptr, sum);
336354
++ aa_ptr; ++ bb_ptr;
337355
}}
338356
cc[i*C_stride + j] += sum;
@@ -349,7 +367,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
349367
#ifdef __cplusplus
350368
extern "C"
351369
#endif
352-
__STATIC_FORCEINLINE int32_t gemm16_{M}x{N}_body_rest_{uniq_id}(
370+
__attribute__((always_inline)) static inline int32_t gemm16_{M}x{N}_body_rest_{uniq_id}(
353371
int K,
354372
int16_t *aa, int16_t *bb, int32_t *cc,
355373
int A_stride, int B_stride, int C_stride) {{
@@ -367,7 +385,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
367385
#ifdef __cplusplus
368386
extern "C"
369387
#endif
370-
__STATIC_FORCEINLINE int32_t gemm16_{M}x{K}x{N}_body_loop_{uniq_id}(
388+
__attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_body_loop_{uniq_id}(
371389
int16_t *aa, int16_t *bb, int32_t *cc,
372390
int A_stride, int B_stride, int C_stride) {{
373391
for (int i = 0; i < {M}; i++) {{
@@ -388,7 +406,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
388406
#ifdef __cplusplus
389407
extern "C"
390408
#endif
391-
__STATIC_FORCEINLINE int32_t gemm16_{M}x{K}x{N}_body_{uniq_id}(
409+
__attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_body_{uniq_id}(
392410
int16_t *aa, int16_t *bb, int32_t *cc,
393411
int A_stride, int B_stride, int C_stride) {{
394412
int32_t retcode = 0;
@@ -405,13 +423,14 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
405423
406424
for (int i = 0; i < {M}; i++) {{
407425
for (int j = 0; j < {N}; j++) {{
408-
int32_t *aa_ptr = (int32_t *) &aa[i*A_stride];
409-
int32_t *bb_ptr = (int32_t *) &bb[j*B_stride];
426+
int32_t aa_vector[{K} / 2];
427+
int32_t bb_vector[{K} / 2];
428+
memcpy(&aa_vector, &aa[i * A_stride], sizeof(aa_vector));
429+
memcpy(&bb_vector, &bb[j * B_stride], sizeof(bb_vector));
410430
411431
int32_t sum = 0;
412432
for (int l = 0; l < {K} / 2; l++) {{
413-
sum = __SMLAD(*aa_ptr, *bb_ptr, sum);
414-
++ aa_ptr; ++ bb_ptr;
433+
sum = __smlad(aa_vector[l], bb_vector[l], sum);
415434
}}
416435
// NOTE: this is the line where `*_body` differs from `*_update`. here
417436
// we're *setting* the result, instead of accumulating, because we know
@@ -430,7 +449,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
430449
#ifdef __cplusplus
431450
extern "C"
432451
#endif
433-
__STATIC_FORCEINLINE int32_t gemm16_{M}x{N}_update_rest_{uniq_id}(
452+
__attribute__((always_inline)) static inline int32_t gemm16_{M}x{N}_update_rest_{uniq_id}(
434453
int K,
435454
int16_t *aa, int16_t *bb, int32_t *cc,
436455
int A_stride, int B_stride, int C_stride) {{
@@ -448,7 +467,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
448467
#ifdef __cplusplus
449468
extern "C"
450469
#endif
451-
__STATIC_FORCEINLINE int32_t gemm16_{M}x{K}x{N}_update_loop_{uniq_id}(
470+
__attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_update_loop_{uniq_id}(
452471
int16_t *aa, int16_t *bb, int32_t *cc,
453472
int A_stride, int B_stride, int C_stride) {{
454473
for (int i = 0; i < {M}; i++) {{
@@ -466,7 +485,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
466485
#ifdef __cplusplus
467486
extern "C"
468487
#endif
469-
__STATIC_FORCEINLINE int32_t gemm16_{M}x{K}x{N}_update_{uniq_id}(
488+
__attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_update_{uniq_id}(
470489
int16_t *aa, int16_t *bb, int32_t *cc,
471490
int A_stride, int B_stride, int C_stride) {{
472491
int32_t retcode = 0;
@@ -478,13 +497,14 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
478497
479498
for (int i = 0; i < {M}; i++) {{
480499
for (int j = 0; j < {N}; j++) {{
481-
int32_t *aa_ptr = (int32_t *) &aa[i*A_stride];
482-
int32_t *bb_ptr = (int32_t *) &bb[j*B_stride];
500+
int32_t aa_vector[{K} / 2];
501+
int32_t bb_vector[{K} / 2];
502+
memcpy(&aa_vector, &aa[i * A_stride], sizeof(aa_vector));
503+
memcpy(&bb_vector, &bb[j * B_stride], sizeof(bb_vector));
483504
484505
int32_t sum = 0;
485506
for (int l = 0; l < {K} / 2; l++) {{
486-
sum = __SMLAD(*aa_ptr, *bb_ptr, sum);
487-
++ aa_ptr; ++ bb_ptr;
507+
sum = __smlad(aa_vector[l], bb_vector[l], sum);
488508
}}
489509
cc[i*C_stride + j] += sum;
490510
}}
@@ -500,7 +520,7 @@ def gemm_MxKxN_impl(M, K, N, uniq_id):
500520
#ifdef __cplusplus
501521
extern "C"
502522
#endif
503-
__STATIC_FORCEINLINE int32_t gemm_{M}x{K}x{N}_reset_{uniq_id}(int32_t *cc, int C_stride) {{
523+
__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_reset_{uniq_id}(int32_t *cc, int C_stride) {{
504524
for (int i = 0; i < {M}; i++) {{
505525
for (int j = 0; j < {N}; j++) {{
506526
cc[i*C_stride + j] = 0;

python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def max_impl(uniq_id):
9494
#ifdef __cplusplus
9595
extern "C"
9696
#endif
97-
__STATIC_FORCEINLINE int32_t max8_reset_{uniq_id}(
97+
__attribute__((always_inline)) static inline int32_t max8_reset_{uniq_id}(
9898
int8_t *res,
9999
int N) {{
100100
memset(res, (int8_t)-128, N * sizeof(*res));
@@ -104,7 +104,7 @@ def max_impl(uniq_id):
104104
#ifdef __cplusplus
105105
extern "C"
106106
#endif
107-
__STATIC_FORCEINLINE int32_t max8_loop_{uniq_id}(
107+
__attribute__((always_inline)) static inline int32_t max8_loop_{uniq_id}(
108108
int8_t *arg,
109109
int8_t *res,
110110
int N) {{
@@ -117,7 +117,7 @@ def max_impl(uniq_id):
117117
#ifdef __cplusplus
118118
extern "C"
119119
#endif
120-
__STATIC_FORCEINLINE int32_t max8_{uniq_id}(
120+
__attribute__((always_inline)) static inline int32_t max8_{uniq_id}(
121121
int8_t *arg,
122122
int8_t *res,
123123
int N) {{
@@ -146,8 +146,8 @@ def max_impl(uniq_id):
146146
for ( int i = 0; i < N / 4; ++ i ) {{
147147
int32_t arg32 = *parg32 ++;
148148
int32_t res32 = *pres32;
149-
__SSUB8(arg32, res32);
150-
res32 = __SEL(arg32, res32);
149+
__ssub8(arg32, res32);
150+
res32 = __sel(arg32, res32);
151151
*pres32 ++ = res32;
152152
}}
153153

python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/multi_channel_convolve.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import textwrap
2424

2525
from tvm import te, tir
26-
from .common import num_simd_lanes_per_word
26+
from .common import num_simd_lanes_per_word, common_includes
2727

2828

2929
def _get_func_name(in_dtype, tensor_w, channels, kernel_h, kernel_w, suffix):
@@ -107,10 +107,8 @@ def multi_channel_convolve_impl(in_dtype, *args) -> str:
107107
def _quad_int8_channel_convolve_impl(_tensor_h, tensor_w, channels, kernel_h, kernel_w, suffix):
108108
return textwrap.dedent(
109109
(
110-
f"""
111-
#include <stdint.h>
112-
#include <arm_nnsupportfunctions.h>
113-
110+
common_includes
111+
+ f"""
114112
// __SXTB16(_ROR(X, Y)) is combined into one assembly instruction
115113
116114
#define TVMGEN_QUAD_INT8_CHANNEL_REARRANGE_SUM_DSP( \
@@ -120,13 +118,13 @@ def _quad_int8_channel_convolve_impl(_tensor_h, tensor_w, channels, kernel_h, ke
120118
\
121119
uint32_t kernel_c3210 = *arranged_kernel++; \
122120
\
123-
uint32_t tensor_c20 = __SXTB16(tensor_c3210); \
124-
uint32_t kernel_c20 = __SXTB16(kernel_c3210); \
121+
uint32_t tensor_c20 = __sxtb16(tensor_c3210); \
122+
uint32_t kernel_c20 = __sxtb16(kernel_c3210); \
125123
sum_c0 = __builtin_arm_smlabb(tensor_c20, kernel_c20, sum_c0); \
126124
sum_c2 = __builtin_arm_smlatt(tensor_c20, kernel_c20, sum_c2); \
127125
\
128-
uint32_t tensor_c31 = __SXTB16(__ROR(tensor_c3210, 8)); \
129-
uint32_t kernel_c31 = __SXTB16(__ROR(kernel_c3210, 8)); \
126+
uint32_t tensor_c31 = __sxtb16(__ror(tensor_c3210, 8)); \
127+
uint32_t kernel_c31 = __sxtb16(__ror(kernel_c3210, 8)); \
130128
sum_c1 = __builtin_arm_smlabb(tensor_c31, kernel_c31, sum_c1); \
131129
sum_c3 = __builtin_arm_smlatt(tensor_c31, kernel_c31, sum_c3); \
132130
}}
@@ -172,7 +170,8 @@ def _quad_int8_channel_convolve_impl(_tensor_h, tensor_w, channels, kernel_h, ke
172170
def _dual_int16_channel_convolve_impl(_tensor_h, tensor_w, channels, kernel_h, kernel_w, suffix):
173171
return textwrap.dedent(
174172
(
175-
f"""
173+
common_includes
174+
+ f"""
176175
#include <stdint.h>
177176
178177
/* We do four channels at once to get this speed boost. */

0 commit comments

Comments
 (0)