Commit 9bd3598

Address @areusch code review
1 parent 8158823 commit 9bd3598

File tree

3 files changed: 271 additions & 189 deletions

python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py

Lines changed: 148 additions & 66 deletions
@@ -22,30 +22,46 @@
 and kernel layout OIHW.
 """
 
-from collections import namedtuple
+from dataclasses import dataclass
 from itertools import chain
 import textwrap
-from typing import Iterator, Tuple
+from typing import Iterator, Optional, Tuple
 
-SMLAInstruction = namedtuple("SMLAInstruction", ["instruction", "tensor_var", "kernel_var"])
 
+@dataclass
+class SMLAInstruction:
+    """Represents a single SMLA instruction and the operands it acts on."""
 
-def _get_c_function_name(split_size, dimensions, offsets, x_strides):
+    instruction: str
+    tensor_var: str
+    kernel_var: str
+
+    def call_with_acle(self, accumulator_var: str) -> str:
+        return (
+            f"{accumulator_var} = __{self.instruction}"
+            f"({self.tensor_var}, {self.kernel_var}, {accumulator_var});"
+        )
+
+    def has_same_operands(self, other: "SMLAInstruction") -> bool:
+        return self.tensor_var == other.tensor_var and self.kernel_var == other.kernel_var
+
+
+def _get_c_function_name(num_outputs, dimensions, offsets, x_strides):
     """Generates a C function name for tensordot.
 
     We do not need a suffix, as the generated function will have an #include guard. Unlike other
     microTVM operators, _get_c_function_name is never called externally.
     """
     tensor_w, kernel_h, kernel_w = dimensions
     return (
-        f"tensordot_opt_x{split_size}_int16_w{tensor_w}_"
+        f"tensordot_opt_x{num_outputs}_int16_w{tensor_w}_"
         + f"{kernel_h}x{kernel_w}_"
         + "".join(map(str, offsets))
-        + (f"_{x_strides[0]}_{x_strides[1]}" if split_size > 1 else "")
+        + (f"_{x_strides[0]}_{x_strides[1]}" if num_outputs > 1 else "")
     )
 
 
-def _init_biased_accumulators(split_size):
+def _init_biased_accumulators(num_outputs):
     """Generates code to load the bias into the accumulators.
 
     Addition is commutative, so we could add the bias before, during, or after performing our
@@ -55,37 +71,102 @@ def _init_biased_accumulators(split_size):
     trick to set sum_i to zero for "free"). However, doing it at the beginning frees up a register,
     so we'll do it first.
     """
-    assignments = map(lambda x: f"sum_{x:x} = *bias", range(split_size))
+    assignments = map(lambda x: f"sum_{x:x} = *bias", range(num_outputs))
     joined_assignments = ", ".join(assignments)
-    return f"int {joined_assignments};"
-
+    return f"int32_t {joined_assignments};"
+
+
+def _get_tensor_halfwords(dimensions, offset, num_outputs, in_stride) -> Iterator[Optional[Tuple]]:
+    """Gets the data that will be stored in memory at the tensor pointer.
+
+    Returns an Iterator of Optional[Tuple], while skipping over word-aligned pairs of unrelated
+    halfwords. The returned iterator is as short as possible while having even length and containing
+    all relevant tensor data. Tuples in the returned Iterator represent a (y, x) offset from the
+    top-left tensor position being used in this convolution. We need to be aware of the None values
+    so our code is correctly word-aligned.
+
+    One consequence of these requirements: each row in the tensor is broken into word-aligned pairs
+    of halfwords (which are later combined into full words). See the examples below:
+
+    A simple 3x3 depthwise convolution computing one output and with in_stride = 1. Note that each
+    row is padded with None at the end to make the rows word-aligned.
+    >>> _get_tensor_halfwords((48, 3, 3), 0, 1, 1)  # doctest: +NORMALIZE_WHITESPACE
+    [(0, 0), (0, 1), (0, 2), None,
+     (1, 0), (1, 1), (1, 2), None,
+     (2, 0), (2, 1), (2, 2), None]
+
+    If the tensor width is odd, padding alternates before/after every row.
+    >>> _get_tensor_halfwords((49, 3, 3), 0, 1, 1)  # doctest: +NORMALIZE_WHITESPACE
+    [(0, 0), (0, 1), (0, 2), None,
+     None, (1, 0), (1, 1), (1, 2),
+     (2, 0), (2, 1), (2, 2), None]
+
+    If we are computing multiple outputs, more tensor data becomes relevant.
+    >>> _get_tensor_halfwords((48, 3, 3), 0, 2, 1)  # doctest: +NORMALIZE_WHITESPACE
+    [(0, 0), (0, 1), (0, 2), (0, 3),
+     (1, 0), (1, 1), (1, 2), (1, 3),
+     (2, 0), (2, 1), (2, 2), (2, 3)]
+
+    Setting in_stride > 1 also makes more tensor data relevant, and setting offset=1 means tensor
+    data starts one position after the tensor pointer.
+
+    >>> _get_tensor_halfwords((49, 3, 3), 1, 2, 2)  # doctest: +NORMALIZE_WHITESPACE
+    [None, (0, 0), (0, 1), (0, 2), (0, 3), (0, 4),
+     (1, 0), (1, 1), (1, 2), (1, 3), (1, 4), None,
+     None, (2, 0), (2, 1), (2, 2), (2, 3), (2, 4)]
+    """
 
-def _get_tensor_halfwords(dimensions, offset, split_size, in_stride) -> Iterator:
     tensor_w, kernel_h, kernel_w = dimensions
+    max_x_val = (num_outputs - 1) * in_stride + kernel_w
+    halfwords = []
 
-    split_max = (split_size - 1) * in_stride
     for y in range(kernel_h):
-        if y * tensor_w % 2 + offset == 1:
-            yield None
-        for x in range(kernel_w + split_max):
-            yield (y, x)
-        if (y * tensor_w + kernel_w + split_max + offset) % 2 == 1:
-            yield None
+        # If needed, pad so the beginning of the row is word-aligned
+        if (y * tensor_w + offset) % 2 == 1:
+            halfwords.append(None)
+
+        for x in range(max_x_val):
+            halfwords.append((y, x))
 
+        # If needed, pad so the row length is word aligned
+        if (y * tensor_w + offset + max_x_val) % 2 == 1:
+            halfwords.append(None)
+    return halfwords
 
-def _get_kernel_halfwords(dimensions, offset) -> Iterator:
+
+def _get_kernel_halfwords(dimensions, offset) -> Iterator[Optional[Tuple]]:
+    """Gets the data that will be stored in memory at the kernel pointer.
+
+    Returns an Iterator of Optional[Tuple]. The returned iterator is as short as possible while
+    having even length and containing all kernel data. Tuples in the returned Iterator represent
+    a (y, x) position in the kernel, while None values represent other, irrelevant data. We need
+    to be aware of the None values so our code is correctly word-aligned.
+
+    >>> _get_kernel_halfwords((96, 3, 3), 0)
+    [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2), None]
+    >>> _get_kernel_halfwords((48, 1, 4), 1)
+    [None, (0, 0), (0, 1), (0, 2), (0, 3), None]
+    """
     _, kernel_h, kernel_w = dimensions
+    halfwords = []
+
+    # Kernel data starts `offset` places after the pointer value
     if offset == 1:
-        yield None
+        halfwords.append(None)
+
     for y in range(kernel_h):
         for x in range(kernel_w):
-            yield (y, x)
+            halfwords.append((y, x))
+
+    # Make sure the returned iterator has even length by padding with an "unknown" value. We want
+    # even length as this corresponds to an integer number of int32 words.
     if (kernel_h * kernel_w + offset) % 2 == 1:
-        yield None
+        halfwords.append(None)
+    return halfwords
 
 
 def _get_int16_alias(position) -> str:
-    if not position:
+    if position is None:
         return "unknown"
     y, x = position
     return f"y{y:0>2x}_x{x:0>2x}"
@@ -96,17 +177,17 @@ def _load_tensor_vars(halfwords, tensor_w) -> Iterator[str]:
     offset = int(not bool(halfwords[0]))
 
     for i in range(0, len(halfwords), 2):
-        var_name = "__".join(map(_get_int16_alias, halfwords[i : i + 2]))
+        var_name = f"{_get_int16_alias(halfwords[i])}__{_get_int16_alias(halfwords[i+1])}"
         y, x = halfwords[i + 1] or halfwords[i]
         tensor_index = (y * tensor_w + x + offset) // 2
-        yield f"int tensor__{var_name} = tensor[{tensor_index}];"
+        yield f"int32_t tensor__{var_name} = tensor[{tensor_index}];"
 
 
 def _load_kernel_vars(halfwords) -> Iterator[str]:
     assert len(halfwords) % 2 == 0
     for i in range(0, len(halfwords), 2):
-        var_name = "__".join(map(_get_int16_alias, halfwords[i : i + 2]))
-        yield f"int kernel__{var_name} = kernel[{i // 2}];"
+        var_name = f"{_get_int16_alias(halfwords[i])}__{_get_int16_alias(halfwords[i+1])}"
+        yield f"int32_t kernel__{var_name} = kernel[{i // 2}];"
 
 
 def _get_draft_macs(
@@ -147,16 +228,17 @@ def _apply_simd_optimizations(instruction_tuples) -> Iterator[SMLAInstruction]:
     curr_tuple = next(instruction_tuples, None)
     while curr_tuple:
         next_tuple = next(instruction_tuples, None)
-        if not next_tuple:
+        if next_tuple is None:
             yield curr_tuple
             break
 
-        if curr_tuple[1:] == next_tuple[1:]:
-            if set([curr_tuple[0], next_tuple[0]]) == set(["smlatt", "smlabb"]):
-                yield SMLAInstruction("smlad", *curr_tuple[1:])
+        if curr_tuple.has_same_operands(next_tuple):
+            instructions = sorted([curr_tuple.instruction, next_tuple.instruction])
+            if instructions == ["smlabb", "smlatt"]:
+                yield SMLAInstruction("smlad", curr_tuple.tensor_var, curr_tuple.kernel_var)
                 next_tuple = next(instruction_tuples, None)
-            elif set([curr_tuple[0], next_tuple[0]]) == set(["smlatb", "smlabt"]):
-                yield SMLAInstruction("smladx", *curr_tuple[1:])
+            elif instructions == ["smlabt", "smlatb"]:
+                yield SMLAInstruction("smladx", curr_tuple.tensor_var, curr_tuple.kernel_var)
                 next_tuple = next(instruction_tuples, None)
             else:
                 yield curr_tuple
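A quick standalone check of the fusion rule above, with SMLAInstruction copied from the first hunk: a smlabb/smlatt pair acting on the same operands collapses into a single smlad, which multiply-accumulates both halfword lanes in one instruction.

```python
from dataclasses import dataclass


@dataclass
class SMLAInstruction:
    """Copied from the first hunk of this diff."""

    instruction: str
    tensor_var: str
    kernel_var: str

    def call_with_acle(self, accumulator_var: str) -> str:
        return (
            f"{accumulator_var} = __{self.instruction}"
            f"({self.tensor_var}, {self.kernel_var}, {accumulator_var});"
        )

    def has_same_operands(self, other: "SMLAInstruction") -> bool:
        return self.tensor_var == other.tensor_var and self.kernel_var == other.kernel_var


curr = SMLAInstruction("smlabb", "tensor__y00_x00__y00_x01", "kernel__y00_x00__y00_x01")
nxt = SMLAInstruction("smlatt", "tensor__y00_x00__y00_x01", "kernel__y00_x00__y00_x01")

# The same test _apply_simd_optimizations performs on adjacent instructions.
if curr.has_same_operands(nxt) and sorted([curr.instruction, nxt.instruction]) == ["smlabb", "smlatt"]:
    fused = SMLAInstruction("smlad", curr.tensor_var, curr.kernel_var)
    print(fused.call_with_acle("sum_0"))
# -> sum_0 = __smlad(tensor__y00_x00__y00_x01, kernel__y00_x00__y00_x01, sum_0);
```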
@@ -173,15 +255,15 @@ def _expand_instruction_tuples(instruction_tuples, index) -> Iterator[str]:
     calls to instruction aliases instead of as a single `asm` block.
     """
 
-    for instruction, op1, op2 in instruction_tuples:
-        assert "smla" in instruction
+    for smla_instruction in instruction_tuples:
+        assert "smla" in smla_instruction.instruction
 
         # We call the instruction using the Arm C Language Extensions. Using ACLE gives better
         # cross-compiler compatibility than using __builtin functions.
-        yield f"sum_{index} = __{instruction}({op1}, {op2}, sum_{index});"
+        yield smla_instruction.call_with_acle(f"sum_{index}")
 
 
-def _requantize_sums(num_sums, requantize_shift, output_zero_point) -> Iterator[str]:
+def _requantize_sums(num_outputs, requantize_shift, output_zero_point) -> Iterator[str]:
     """Generates code to requantize the accumulator values.
 
     The generated code does not use floating point instructions, as it simulates floating point
@@ -202,14 +284,14 @@ def _requantize_sums(num_sums, requantize_shift, output_zero_point) -> Iterator[
     instruction each.
     """
 
-    yield "int scale_val = *scale;"
-    for i in range(num_sums):
-        yield f"int requant_{i} = (sum_{i} * (long long) scale_val) >> {requantize_shift - 1};"
+    yield "int32_t scale_val = *scale;"
+    for i in range(num_outputs):
+        yield f"int32_t requant_{i} = (sum_{i} * (int64_t) scale_val) >> {requantize_shift - 1};"
         yield f"requant_{i} = (requant_{i} + 1) >> 1;"
         yield f"requant_{i} = __ssat(requant_{i} + {output_zero_point}, 8);"
 
 
-def _write_sums_to_memory(num_sums, offset, stride) -> Iterator[str]:
+def _write_sums_to_memory(num_outputs, offset, stride) -> Iterator[str]:
     """Generates code to write the requantized sums to memory.
 
     Note - halfword packing here *does* help. It seems
@@ -222,27 +304,27 @@ def _write_sums_to_memory(num_sums, offset, stride) -> Iterator[str]:
     """
 
     if stride > 1:
-        for i in range(num_sums):
-            yield f"((short*) output)[{i * stride + offset}] = (short) requant_{i};"
+        for i in range(num_outputs):
+            yield f"((int16_t*) output)[{i * stride + offset}] = (int16_t) requant_{i};"
 
     else:
-        num_halfwords = (num_sums - offset) // 2
-        for i in range(num_halfwords):
+        num_packed = (num_outputs - offset) // 2
+        for i in range(num_packed):
             index = 2 * i + offset
-            yield f"int packed_res_{i} = requant_{index} + (requant_{index + 1} << 16);"
+            yield f"int32_t packed_res_{i} = requant_{index} + (requant_{index + 1} << 16);"
 
         if offset == 1:
-            yield "((short*) output)[1] = (short) requant_0;"
+            yield "((int16_t*) output)[1] = (int16_t) requant_0;"
 
-        for i in range(num_halfwords):
+        for i in range(num_packed):
             yield f"output[{offset + i}] = packed_res_{i};"
 
-        if (offset + num_sums) % 2 == 1:
-            yield f"((short*) output)[{num_halfwords * 2}] = (short) requant_{num_halfwords * 2};"
+        if (offset + num_outputs) % 2 == 1:
+            yield f"((int16_t*) output)[{num_packed * 2}] = (int16_t) requant_{num_packed * 2};"
 
 
 def tensordot_int16_impl(
-    split_size: int,
+    num_outputs: int,
     dimensions: Tuple[int, int, int],
     offsets: Tuple[int, int, int],
     x_strides: Tuple[int, int],
@@ -257,11 +339,11 @@ def tensordot_int16_impl(
 
     Parameters
     ----------
-    split_size: int
-        The number of tensordot values to compute in this function. Computing more than one at once
-        makes us much faster by reducing how often overlapping data is loaded. However, setting this
-        too high causes us to run out of registers and need to store data on the stack. We should
-        autotune this, but split_size=2 is usually OK.
+    num_outputs: int
+        The number of tensordot outputs to compute per function call. Computing more than one at
+        once makes us much faster by reducing how often overlapping data is loaded. However, setting
+        this too high causes us to run out of registers and need to store data on the stack. We
+        should autotune this, but num_outputs=2 is usually OK.
 
     dimensions: Tuple[int, int, int]
         The dimensions of each tensordot operation. dimensions[1] and dimensions[2] are the height
@@ -275,7 +357,7 @@ def tensordot_int16_impl(
 
     x_strides: Tuple[int, int]
         The distance (in halfwords) between the start of each input tensor, and where to write each
-        output result respectively. Only used when split_size > 1.
+        output result respectively. Only used when num_outputs > 1.
 
     requantize_shift: int
         The distance to right shift after multiplying by the requantization scale. Defaults to 33,
@@ -290,14 +372,14 @@ def tensordot_int16_impl(
     func_name, func_code: Tuple[str, str]
         The name and source code of the generated function.
     """
-    function_name = _get_c_function_name(split_size, dimensions, offsets, x_strides)
+    function_name = _get_c_function_name(num_outputs, dimensions, offsets, x_strides)
     tensor_w, kernel_h, kernel_w = dimensions
     tensor_offset, kernel_offset, output_offset = offsets
     assert tensor_offset < 2 and kernel_offset < 2 and output_offset < 2
     in_stride, out_stride = x_strides
 
-    tensor_halfwords = list(_get_tensor_halfwords(dimensions, tensor_offset, split_size, in_stride))
-    kernel_halfwords = list(_get_kernel_halfwords(dimensions, kernel_offset))
+    tensor_halfwords = _get_tensor_halfwords(dimensions, tensor_offset, num_outputs, in_stride)
+    kernel_halfwords = _get_kernel_halfwords(dimensions, kernel_offset)
     load_tensor_lines = _load_tensor_vars(tensor_halfwords, tensor_w)
     load_kernel_lines = _load_kernel_vars(kernel_halfwords)
 
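For reference, a standalone copy of the renamed _get_c_function_name from the first hunk shows the kind of name function_name receives here; the argument values below are illustrative:

```python
def _get_c_function_name(num_outputs, dimensions, offsets, x_strides):
    # Copied from this diff: builds the #include-guarded C function name.
    tensor_w, kernel_h, kernel_w = dimensions
    return (
        f"tensordot_opt_x{num_outputs}_int16_w{tensor_w}_"
        + f"{kernel_h}x{kernel_w}_"
        + "".join(map(str, offsets))
        + (f"_{x_strides[0]}_{x_strides[1]}" if num_outputs > 1 else "")
    )


# Two outputs per call, 48-halfword-wide tensor, 3x3 kernel, all offsets zero:
print(_get_c_function_name(2, (48, 3, 3), (0, 0, 0), (1, 1)))
# -> tensordot_opt_x2_int16_w48_3x3_000_1_1
```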
@@ -308,11 +390,11 @@ def gen_single_loop_macs(index):
         draft_macs_iter = _apply_simd_optimizations(draft_macs_iter)
         return _expand_instruction_tuples(draft_macs_iter, index)
 
-    multiply_acc_lines = chain.from_iterable(gen_single_loop_macs(i) for i in range(split_size))
+    multiply_acc_lines = chain.from_iterable(gen_single_loop_macs(i) for i in range(num_outputs))
     requantize_lines = _requantize_sums(
-        split_size, requantize_shift=requantize_shift, output_zero_point=output_zero_point
+        num_outputs, requantize_shift=requantize_shift, output_zero_point=output_zero_point
     )
-    write_out_lines = _write_sums_to_memory(split_size, output_offset, out_stride)
+    write_out_lines = _write_sums_to_memory(num_outputs, output_offset, out_stride)
 
     def insert_lines(lines):
         return ("\n" + " " * 10).join(lines)
@@ -325,10 +407,10 @@ def insert_lines(lines):
         #ifndef {function_name.upper()}_EXISTS
         #define {function_name.upper()}_EXISTS
         #include <arm_acle.h>
-        __attribute__((always_inline)) static inline int {function_name}(
-            int *output, int *tensor, int *kernel, int *bias, int *scale
+        __attribute__((always_inline)) static inline int32_t {function_name}(
+            int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale
         ) {{
-            {_init_biased_accumulators(split_size)}
+            {_init_biased_accumulators(num_outputs)}
 
         {insert_lines(load_tensor_lines)}
 
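Finally, a small sketch of the halfword packing _write_sums_to_memory emits when stride == 1. The values below are invented, but the packing expression is the one the generated C assigns to packed_res_0, letting two requantized results share a single aligned int32 store:

```python
# Two requantized outputs, already saturated to int8 range by __ssat.
requant_0, requant_1 = 17, 42

# The exact expression the generated C emits for packed_res_0:
packed_res_0 = requant_0 + (requant_1 << 16)
print(hex(packed_res_0))  # -> 0x2a0011: requant_0 in the low halfword, requant_1 in the high
```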