
Commit 39cb5a4

Clean up code to prepare for review

1 parent e445bda commit 39cb5a4

9 files changed: +233 -158 lines changed


python/tvm/relay/op/nn/_nn.py

Lines changed: 5 additions & 15 deletions
@@ -878,26 +878,16 @@ def convert_deformable_conv2d(attrs, inputs, tinfos, desired_layouts):


 # QNN ops
-@reg.register_alter_op_layout("qnn.conv2d")
-def alter_op_layout_requantize(attrs, inputs, tinfos, out_type):
-    """Alter the layout of a requantization op."""
-    return topi.nn.qnn.qnn_conv2d_alter_layout(attrs, inputs, tinfos, out_type)
-
-
-@reg.register_alter_op_layout("qnn.dense")
-def alter_op_layout_requantize(attrs, inputs, tinfos, out_type):
-    """Alter the layout of a requantization op."""
-    return topi.nn.qnn.qnn_dense_alter_layout(attrs, inputs, tinfos, out_type)
-
-
 @reg.register_alter_op_layout("add")
-def alter_op_layout_requantize(attrs, inputs, tinfos, out_type):
-    """Alter the layout of a requantization op."""
+def alter_op_layout_add(attrs, inputs, tinfos, out_type):
+    """Alter the layout of an add op. Useful for fusing the bias constant with an input zero point
+    constant in a previous quantized op. Only used when the previous op is a quantized op, which is
+    why it lives in topi.nn.qnn."""
     return topi.nn.qnn.qnn_add_alter_layout(attrs, inputs, tinfos, out_type)


 @reg.register_alter_op_layout("qnn.requantize")
-def alter_op_layout_requantize(attrs, inputs, tinfos, out_type):
+def alter_op_layout_qnn_requantize(attrs, inputs, tinfos, out_type):
     """Alter the layout of a requantization op."""
     return topi.nn.qnn.qnn_requantize_alter_layout(attrs, inputs, tinfos, out_type)
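
These hooks are not called directly; Relay's AlterOpLayout pass consults them as it rewrites a module. A minimal sketch of that flow is below. The toy module is illustrative and not part of this commit, and a hook that returns None simply leaves the op unchanged.

# Hypothetical sketch: run AlterOpLayout over a toy module so the registered "add" hook is
# consulted. The module is made up for illustration only.
import tvm
from tvm import relay

x = relay.var("x", shape=(1, 8), dtype="int32")
func = relay.Function([x], relay.add(x, relay.const(1, "int32")))
mod = tvm.IRModule.from_expr(func)

with tvm.transform.PassContext(opt_level=3):
    mod = relay.transform.AlterOpLayout()(mod)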

python/tvm/relay/qnn/strategy/arm_cpu.py

Lines changed: 17 additions & 6 deletions
@@ -14,23 +14,34 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Definition of Hexagon operator strategy."""
-# pylint: disable=unused-argument,wildcard-import,unused-wildcard-import
+"""Quantized operator strategy for Arm CPU. These schedules are only used if the qnn.Legalize pass
+is disabled. These schedules only work on fused operators with a bias, as this is a very common use
+case. Currently only regular/depthwise conv2d is supported, but qnn_dense should be added."""

 from tvm import topi
-from .generic import *
+from .generic import qnn_conv2d_strategy
 from ... import op as _op
 from ...op.strategy.generic import is_depthwise_conv2d

+
 @qnn_conv2d_strategy.register("arm_cpu")
-def qnn_conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
-    """qnn.conv2d strategy for Arm CPU"""
+def qnn_conv2d_strategy_arm_cpu(attrs, inputs, _out_type, target):
+    """qnn.conv2d strategy for Arm CPU. Currently, the schedules only support Cortex-M processors
+    with DSP - the qnn.Legalize pass should be run on all others."""
+
+    if not (target.features.has_dsp and "cortex-m" in target.mcpu):
+        raise RuntimeError(
+            "Quantized Arm schedules only exist for Cortex-M with DSP! "
+            "The qnn.Legalize pass should be run for other Arm processors."
+        )
+
     data = inputs[0]
     kernel = inputs[1]
     data_layout = attrs.data_layout
     kernel_layout = attrs.kernel_layout
     groups = attrs.groups
     strategy = _op.OpStrategy()
+
     if groups == 1:
         if data_layout == "NHWC" and kernel_layout == "OHWI":
             strategy.add_implementation(
@@ -46,6 +57,6 @@ def qnn_conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
                 name="qnn_depthwise_conv2d.arm_cpu",
             )
     else:
-        raise RuntimeError("Unsupported strategy for group qnn.conv2d")
+        raise RuntimeError("No Arm Cortex-M DSP strategy exists for generic group qnn.conv2d")

     return strategy
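
For reference, the guard at the top of this strategy can be reproduced on its own. A minimal sketch, assuming a Cortex-M55 C target; the exact target string is illustrative, and whether features.has_dsp is populated depends on how the target is parsed.

# Sketch only: build a target and evaluate the same gate the strategy uses.
import tvm

target = tvm.target.Target("c -keys=arm_cpu,cpu -mcpu=cortex-m55")
uses_dsp_schedules = target.features.has_dsp and "cortex-m" in target.mcpu
print(uses_dsp_schedules)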

python/tvm/topi/arm_cpu/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -30,4 +30,4 @@
 from .group_conv2d import *
 from .pooling import *
 from .dense import *
-from .qnn import *
+from .qnn import *

python/tvm/topi/arm_cpu/conv2d.py

Lines changed: 1 addition & 0 deletions
@@ -38,6 +38,7 @@
     conv2d_nhwc_dsp_schedule,
 )

+
 @autotvm.register_topi_compute("conv2d_nchw_spatial_pack.arm_cpu")
 def conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype):
     """Compute conv2d with NCHW layout"""

python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py

Lines changed: 77 additions & 76 deletions
@@ -15,21 +15,18 @@
 # specific language governing permissions and limitations
 # under the License.
 """Computes a "jumpy tensordot" operator, which can be used to tensorize many common operators
-including regular conv2d, depthwise conv2d, and grouped conv2d provided the data and kernel layouts
-are the optimal ones. When groups=1, the optimal data layout is NHWC and kernel layout is OHWI. When
-this is a depthwise convolution, the optimal data layout is NCHW and kernel layout is OIHW."""
+including regular conv2d, depthwise conv2d, and grouped conv2d for some data and kernel layouts.
+For regular convolution, use data layout NHWC and kernel layout OHWI. For depthwise convolution,
+use data layout NCHW and kernel layout OIHW."""

 from itertools import chain
 import textwrap
 from typing import Iterator, Tuple

-import numpy as np
-
-from tvm import te, tir
-
-
-def get_c_function_name(split_size, dimensions, offsets, x_strides):
-    """Gets the C function name of the tensordot function."""
+def _get_c_function_name(split_size, dimensions, offsets, x_strides):
+    """Gets the C function name of the tensordot function. We do not need a suffix, as the generated
+    function will have an #include guard. Unlike other microTVM operators, _get_c_function_name is
+    never called externally."""
     tensor_w, kernel_h, kernel_w = dimensions
     return (
         f"tensordot_opt_x{split_size}_int16_w{tensor_w}_"
@@ -42,7 +39,7 @@ def get_c_function_name(split_size, dimensions, offsets, x_strides):
 def _init_biased_accumulators(split_size):
     """Addition is commutative, so we could add the bias before, during, or after performing our
     multiply-accumulate operations. It "costs" one cycle either way - if done at the beginning we
-    can't use our SMULXY trick to set sum_i to zero for "free", and if done at the end it doesn't
+    can't use a SMULXY trick to set sum_i to zero for "free", and if done at the end it doesn't
     combine with anything. However, doing it at the beginning frees up a register/prevents needing
     to do a stack push/pop, so we'll do it first."""
     assignments = map(lambda x: f"sum_{x:x} = bias", range(split_size))
@@ -64,7 +61,7 @@ def _get_tensor_halfwords(dimensions, offset, split_size, in_stride) -> Iterator


 def _get_kernel_halfwords(dimensions, offset) -> Iterator:
-    _tensor_w, kernel_h, kernel_w = dimensions
+    _, kernel_h, kernel_w = dimensions
     if offset == 1:
         yield None
     for y in range(kernel_h):
@@ -100,22 +97,32 @@ def _load_kernel_vars(halfwords) -> Iterator[str]:


 def _get_draft_macs(kernel_dims, tensor_halfwords, kernel_halfwords, offset) -> Iterator[Tuple]:
+    """Generates a functional but un-optimized list of multiply-accumulate instructions that we will
+    optimize later. The tuples in the returned iterator are organized as:
+
+        (instruction, (arg1_y, arg1_x), (arg2_y, arg2_x))
+
+    We return an iterator so that optimizations may be done by iterator chaining."""
+
     def get_var(y, x, halfwords):
         i = halfwords.index((y, x))
         if i % 2 == 0:
             return f"{_get_int16_alias((y, x))}__{_get_int16_alias(halfwords[i + 1])}", "b"
-        else:
-            return f"{_get_int16_alias(halfwords[i - 1])}__{_get_int16_alias((y, x))}", "t"
+        return f"{_get_int16_alias(halfwords[i - 1])}__{_get_int16_alias((y, x))}", "t"

     kernel_h, kernel_w = kernel_dims
     for y in range(kernel_h):
         for x in range(kernel_w):
             tensor_var, tensor_half = get_var(y, x + offset, tensor_halfwords)
             kernel_var, kernel_half = get_var(y, x, kernel_halfwords)
-            yield f"smla{tensor_half}{kernel_half}", f"tensor__{tensor_var}", f"kernel__{kernel_var}"
+            instruction = f"smla{tensor_half}{kernel_half}"
+            yield instruction, f"tensor__{tensor_var}", f"kernel__{kernel_var}"


 def _apply_simd_optimizations(instruction_tuples) -> Iterator[Tuple]:
+    """Fuses single halfword MAC instructions into double halfword MAC instructions when possible.
+    The compiler cannot do this automatically, as calling __builtin_arm_smlaxy forces the SMLAxy
+    instruction to be used."""
     curr_tuple = next(instruction_tuples, None)
     while curr_tuple:
         next_tuple = next(instruction_tuples, None)
@@ -138,60 +145,23 @@ def _apply_simd_optimizations(instruction_tuples) -> Iterator[Tuple]:
         curr_tuple = next_tuple


-NO_ACC_PREFIX_CONVERSIONS = {
-    "smlad": "smuad",
-    "smladx": "smuadx",
-    "smlatt": "smultt",
-    "smlatb": "smultb",
-    "smlabt": "smulbt",
-    "smlabb": "smulbb",
-}
-
-
-# def _no_first_accumulate(instruction_tuples) -> Iterator[Tuple]:
-#     ins, op1, op2 = next(instruction_tuples)
-#     yield NO_ACC_PREFIX_CONVERSIONS[ins], op1, op2
-#     for instruction_tuple in instruction_tuples:
-#         yield instruction_tuple
-
-
 def _expand_instruction_tuples(instruction_tuples, index) -> Iterator[str]:
-    """Converts a series of (instruction, var1, var2) tuples into lines of C code. Should be simple,
-    but we need to work around a series of cryptic bugs while ensuring the compiler makes certain
-    optimizations.
-
-    1. Ideally, we would call __builtin_arm functions instead of including inline assembly, as this
-       is easier to read and more future proof. However:
-       a. Arm GCC apparently *forgot* to include `__builtin_arm_smlabt`, even though
-          `__builtin_arm_smlatt`, `__builtin_arm_smlatb`, `__builtin_arm_smlad` and so on all
-          exist. These work as expected on Clang - the issue is GCC only.
-
-       b. Calling `__builtin_arm_smlatt` (and `smlatb` and `smlabb`) works fine on real devices.
-          However, calling these builtins causes the Corstone300 simulator to freeze and stall. I
-          have no clue on why this is - wouldn't these builtins be compiled to assembly? - yet it
-          occurs consistently.
-
-
-    2. Ideally, the compiler would know that the first multiply instruction should *not* accumulate,
-       and would automatically replace it with an otherwise identical but non-accumulating
-       instruction. Doing this saves us one cycle, as we don't need to load a zero into sum_i.
-       However, the compiler (understandably) does not like overwriting instructions we explicitly
-       as for, so we must do this ourselves.
-
-    3. Ideally, since we're going to emit several lines of assembly code, we would do it in a single
-       `asm` block. However, we *want* the compiler to reorder the instructions and interleave them
-       with memory loads, and it can only do this if we specify the instructions as individual non-
-       volatile memory loads.
+    """Converts a series of (instruction, var1, var2) tuples into lines of C code. We want the
+    compiler to re-order these with the memory loads, so we generate them as a series of calls to
+    instruction aliases instead of as a single `asm` block.
     """
-    for instruction, op1, op2 in instruction_tuples:
-        if "smla" in instruction:
-            if instruction == "smlabt":
-                yield f"sum_{index} = __builtin_arm_smlatb({op2}, {op1}, sum_{index});"
-            else:
-                yield f"sum_{index} = __builtin_arm_{instruction}({op1}, {op2}, sum_{index});"

+    for instruction, op1, op2 in instruction_tuples:
+        assert "smla" in instruction
+
+        # Arm GCC does not have `__builtin_arm_smlabt`, even though `__builtin_arm_smlatt`,
+        # `__builtin_arm_smlatb`, `__builtin_arm_smlad` and so on all exist. Perhaps this is a
+        # choice, since we can just use `smlabt` with the argument order swapped instead? Note that
+        # `__builtin_arm_smlabt` exists on most compilers (e.g. Clang) - this is just a GCC thing.
+        if instruction == "smlabt":
+            yield f"sum_{index} = __builtin_arm_smlatb({op2}, {op1}, sum_{index});"
         else:
-            yield f'asm ("{instruction} %0, %1, %2" : "=r" (sum_{index}) : "r" ({op1}), "r" ({op2}));'
+            yield f"sum_{index} = __builtin_arm_{instruction}({op1}, {op2}, sum_{index});"


 def _requantize_sums(num_sums) -> Iterator[str]:
@@ -207,9 +177,9 @@ def _requantize_sums(num_sums) -> Iterator[str]:
     halfwords in a word, and rearranging it would take at least one cycle. Two SSAT operations is
     just as good.

-    Lastly, calling __builtin_arm_ssat is a little bit gross, but GCC and Clang are unreliable about
-    compiling other ways of writing this. Both the multiply + shift and shift + saturation combine
-    to one instruction each."""
+    Calling __builtin_arm_ssat directly is a little bit gross, but GCC and Clang are unreliable
+    about compiling other ways of writing this. Both the multiply + shift and shift + saturation
+    combine to one instruction each."""

     for i in range(num_sums):
         yield f"int requant_{i} = (sum_{i} * (long long) requant_scale) >> 32;"
@@ -237,7 +207,7 @@ def _write_sums_to_memory(num_sums, offset, stride) -> Iterator[str]:
         yield f"int packed_res_{i} = requant_{index} + (requant_{index + 1} << 16);"

     if offset == 1:
-        yield f"((short*) output)[1] = (short) requant_0;"
+        yield "((short*) output)[1] = (short) requant_0;"

     for i in range(num_halfwords):
         yield f"output[{offset + i}] = packed_res_{i};"
@@ -251,12 +221,39 @@ def tensordot_int16_impl(
     dimensions: Tuple[int, int, int],
     offsets: Tuple[int, int, int],
     x_strides: Tuple[int, int],
-) -> str:
-    """Code for a specialized version of tensordot, which computes `split_size` tensordot operations
+) -> Tuple[str, str]:
+    """Code for a quantized version of tensordot, which computes `split_size` tensordot operations
     at the same time. Only works with `int16`. The generated function takes as input pointers to the
-    output, tensor, and kernel, which must be word-aligned. However, the stride can be half a word.
+    output, tensor, and kernel, which must be word-aligned.
+
+    Parameters
+    ----------
+    split_size: int
+        The number of tensordot values to compute in this function. Computing more than one at once
+        makes us much faster by reducing how often overlapping data is loaded. However, setting this
+        too high causes us to run out of registers and need to store data on the stack. We should
+        autotune this, but split_size=2 is usually OK.
+
+    dimensions: Tuple[int, int, int]
+        The dimensions of each tensordot operation. dimensions[1] and dimensions[2] are the height
+        and width of the kernel, respectively. dimensions[0] is the width of the data tensor, which
+        is usually larger than the kernel.
+
+    offsets: Tuple[int, int, int]
+        Each value is 0 or 1, and represents how far after the given data, kernel, and output
+        pointers (respectively) we should start reading/writing. This prevents us from having to
+        check if each pointer is aligned or unaligned at runtime, making us faster.
+
+    x_strides: Tuple[int, int]
+        The distance (in halfwords) between the start of each input tensor, and where to write each
+        output result, respectively. Only used when split_size > 1.
+
+    Returns
+    -------
+    func_name, func_code: Tuple[str, str]
+        The name and source code of the generated function.
     """
-    function_name = get_c_function_name(split_size, dimensions, offsets, x_strides)
+    function_name = _get_c_function_name(split_size, dimensions, offsets, x_strides)
     tensor_w, kernel_h, kernel_w = dimensions
     tensor_offset, kernel_offset, output_offset = offsets
     assert tensor_offset < 2 and kernel_offset < 2 and output_offset < 2
@@ -281,11 +278,14 @@ def gen_single_loop_macs(index):
     def insert_lines(lines):
         return ("\n" + " " * 10).join(lines)

-    # __WEAK allows multiple copies of the function to overwrite themselves, saving flash
+    # It's very common for one model to have different layers that use identical tensordot
+    # functions. To prevent function re-definition errors, we need an #include guard. This is better
+    # than adding a random suffix, as it saves flash memory.
     code = textwrap.dedent(
         f"""
-        #include <arm_nnsupportfunctions.h>
-        __STATIC_FORCEINLINE __WEAK int {function_name}(
+        #ifndef {function_name.upper()}_EXISTS
+        #define {function_name.upper()}_EXISTS
+        __attribute__((always_inline)) static inline int {function_name}(
           int *output, int *tensor, int *kernel, int bias, int requant_scale
         ) {{
           {_init_biased_accumulators(split_size)}
@@ -301,6 +301,7 @@ def insert_lines(lines):
           {insert_lines(write_out_lines)}
           return 0;
         }}
+        #endif
         """
     )
     return (function_name, code)
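
To see what the generator produces end to end, the function can be called directly. A hedged usage sketch; the argument values below are made up for illustration, not taken from this commit.

# Sketch: generate one specialized int16 tensordot kernel and inspect the result.
from tvm.topi.arm_cpu.mprofile.dsp.micro_kernel.tensordot import tensordot_int16_impl

func_name, func_code = tensordot_int16_impl(
    split_size=2,            # compute two tensordot results per call
    dimensions=(48, 3, 3),   # tensor width, kernel height, kernel width
    offsets=(0, 0, 0),       # all pointers assumed word-aligned
    x_strides=(1, 1),        # input/output strides in halfwords
)
print(func_name)   # e.g. "tensordot_opt_x2_int16_w48_..."
print(func_code)   # C source wrapped in an "#ifndef ..._EXISTS" include guard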
