Skip to content

Commit f98eb78

Browse files
committed
Re-use common_includes to propagate shared function
1 parent 205a477 commit f98eb78

File tree

2 files changed

+8
-24
lines changed

2 files changed

+8
-24
lines changed

python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/multi_channel_convolve.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import textwrap
2424

2525
from tvm import te, tir
26-
from .common import num_simd_lanes_per_word
26+
from .common import num_simd_lanes_per_word, common_includes
2727

2828

2929
def _get_func_name(in_dtype, tensor_w, channels, kernel_h, kernel_w, suffix):
@@ -107,10 +107,8 @@ def multi_channel_convolve_impl(in_dtype, *args) -> str:
107107
def _quad_int8_channel_convolve_impl(_tensor_h, tensor_w, channels, kernel_h, kernel_w, suffix):
108108
return textwrap.dedent(
109109
(
110-
f"""
111-
#include <stdint.h>
112-
#include <arm_acle.h>
113-
110+
common_includes
111+
+ f"""
114112
// __SXTB16(_ROR(X, Y)) is combined into one assembly instruction
115113
116114
#define TVMGEN_QUAD_INT8_CHANNEL_REARRANGE_SUM_DSP( \
@@ -172,7 +170,8 @@ def _quad_int8_channel_convolve_impl(_tensor_h, tensor_w, channels, kernel_h, ke
172170
def _dual_int16_channel_convolve_impl(_tensor_h, tensor_w, channels, kernel_h, kernel_w, suffix):
173171
return textwrap.dedent(
174172
(
175-
f"""
173+
common_includes
174+
+ f"""
176175
#include <stdint.h>
177176
178177
/* We do four channels at once to get this speed boost. */

python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
from tvm import te, tir
2525

26-
from .common import num_simd_lanes_per_word
26+
from .common import num_simd_lanes_per_word, common_includes
2727

2828

2929
def _get_func_name(in_dtype, tensor_h, jump, tensor_w, suffix):
@@ -123,23 +123,8 @@ def tensordot_impl(in_dtype: str, tensor_h: int, jump: int, tensor_w: int, suffi
123123
function_name = _get_func_name(in_dtype, tensor_h, jump, tensor_w, suffix)
124124
return textwrap.dedent(
125125
(
126-
f"""
127-
#include <stdint.h>
128-
#include <arm_acle.h>
129-
130-
#ifndef ARM_CPU_ROR_EXISTS
131-
#define ARM_CPU_ROR_EXISTS
132-
__attribute__((always_inline)) uint32_t __ror(uint32_t op1, uint32_t op2)
133-
{{
134-
op2 %= 32U;
135-
if (op2 == 0U)
136-
{{
137-
return op1;
138-
}}
139-
return (op1 >> op2) | (op1 << (32U - op2));
140-
}}
141-
#endif
142-
126+
common_includes
127+
+ f"""
143128
#ifdef __cplusplus
144129
extern "C"
145130
#endif

0 commit comments

Comments
 (0)