# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Computes a "jumpy tensordot" operator, which can be used to tensorize many common operators
-including regular conv2d, depthwise conv2d, and grouped conv2d for some data and kernel layouts.
-When for regular convolution, use data laout HHWC and kernel layout OHWI. For depthwise convolution,
-use data layout data layout is NCHW and kernel layout OIHW."""
+"""Generates optimized code to compute a tensor dot product on ARMv7E-M.
+
+This function can be used to tensorize many common operators including regular conv2d, depthwise
+conv2d, and grouped conv2d for some data and kernel layouts. For regular convolution, use data
+layout HHWC and kernel layout OHWI. For depthwise convolution, use data layout NCHW and kernel
+layout OIHW.
+"""

from itertools import chain
import textwrap
from typing import Iterator, Tuple


def _get_c_function_name(split_size, dimensions, offsets, x_strides):
-    """Gets the C function name of the tensordot function. We do not need a suffix, as the generated
-    function will have an #include guard. Unlike other microTVM operators, _get_c_function_name is
-    never called externally."""
+    """Generates a C function name for tensordot.
+
+    We do not need a suffix, as the generated function will have an #include guard. Unlike other
+    microTVM operators, _get_c_function_name is never called externally.
+    """
    tensor_w, kernel_h, kernel_w = dimensions
    return (
        f"tensordot_opt_x{split_size}_int16_w{tensor_w}_"
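# Illustrative only (not part of the diff): with split_size=2 and tensor_w=48 the generated
# name begins with "tensordot_opt_x2_int16_w48_"; the remaining suffix is presumably built
# from the kernel dimensions, offsets, and strides on the lines not shown in this hunk.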
@@ -38,12 +43,16 @@ def _get_c_function_name(split_size, dimensions, offsets, x_strides):


def _init_biased_accumulators(split_size):
-    """Addition is commutative, so we could add the bias before, during, or after performing our
-    multiply-accumulate operations. It "costs" one cycle either way - if done at the beginning we
-    can't use a SMULXY trick to set sum_i to zero for "free", and if done at the end it doesn't
-    combine with anything. However, doing it at the beginning frees up a register/prevents needing
-    to do a stack push/pop, so we'll do it first."""
-    assignments = map(lambda x: f"sum_{x:x} = bias", range(split_size))
+    """Generates code to load the bias into the accumulators.
+
+    Addition is commutative, so we could add the bias before, during, or after performing our
+    multiply-accumulate operations. Where we add the bias does not change the overflow behavior.
+
+    Doing the bias add takes one cycle either way (if done at the beginning we can't use a SMULXY
+    trick to set sum_i to zero for "free"). However, doing it at the beginning frees up a register,
+    so we'll do it first.
+    """
+    assignments = map(lambda x: f"sum_{x:x} = *bias", range(split_size))
    joined_assignments = ", ".join(assignments)
    return f"int {joined_assignments};"

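# A minimal, runnable sketch (not part of the diff) of what the updated helper generates:
# with the new pointer-typed bias, split_size=2 produces a single C declaration that seeds
# both accumulators from *bias.
def _init_biased_accumulators_sketch(split_size):
    assignments = map(lambda x: f"sum_{x:x} = *bias", range(split_size))
    return f"int {', '.join(assignments)};"

assert _init_biased_accumulators_sketch(2) == "int sum_0 = *bias, sum_1 = *bias;"
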
@@ -98,12 +107,15 @@ def _load_kernel_vars(halfwords) -> Iterator[str]:


def _get_draft_macs(kernel_dims, tensor_halfwords, kernel_halfwords, offset) -> Iterator[Tuple]:
-    """Generates a functional but un-optimized list of multiply-accumulate instructions that we will
-    optimize later. The tuples in the returned iterator are organized as:
+    """Generates an un-optimized list of multiply-accumulate instructions.
+
+    We will optimize these into SIMD instructions later. The tuples in the returned iterator are
+    organized as:

        (instruction, (arg1_y, arg1_x), (arg2_y, arg2_x))

-    We return an iterator so that optimizations may be done by iterator chaining."""
+    We return an iterator so that optimizations may be done by iterator chaining.
+    """

    def get_var(y, x, halfwords):
        i = halfwords.index((y, x))
@@ -121,9 +133,12 @@ def get_var(y, x, halfwords):


def _apply_simd_optimizations(instruction_tuples) -> Iterator[Tuple]:
-    """Fuses single halfword MAC instructions into double halfword MAC instructions when possible.
+    """When possible, fuses single MACs into SIMD MAC instructions.
+
    The compiler cannot do this automatically, as calling __builtin_arm_smlaxy forces the SMLAxy
-    instruction to be used."""
+    instruction to be used. This function takes as input an iterator of (instruction, var1, var2)
+    tuples, and returns an iterator of (instruction, var1, var2) tuples.
+    """
    curr_tuple = next(instruction_tuples, None)
    while curr_tuple:
        next_tuple = next(instruction_tuples, None)
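# A simplified sketch (not the actual fusion logic) of the kind of rewrite this pass performs:
# a bottom-half MAC followed by a top-half MAC on the same packed words can be replaced by one
# dual-halfword MAC. The mnemonics are ARM DSP instruction names; treating the tuples this way
# is an assumption for illustration, and the real code also handles other pairings.
def _fuse_pair_sketch(first, second):
    if first[0] == "smlabb" and second[0] == "smlatt" and first[1:] == second[1:]:
        return ("smlad",) + tuple(first[1:])
    return None

assert _fuse_pair_sketch(("smlabb", "t_0", "k_0"), ("smlatt", "t_0", "k_0")) == ("smlad", "t_0", "k_0")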
@@ -147,9 +162,10 @@ def _apply_simd_optimizations(instruction_tuples) -> Iterator[Tuple]:


def _expand_instruction_tuples(instruction_tuples, index) -> Iterator[str]:
-    """Converts a series of (instruction, var1, var2) tuples into lines of C code. We want the
-    compiler to re-order these with the memory loads, so we generate them as a series of calls to
-    instruction aliases instead of as a single `asm` block.
+    """Converts a series of (instruction, var1, var2) tuples into lines of C code.
+
+    We want the compiler to re-order these with the memory loads, so we generate them as a series of
+    calls to instruction aliases instead of as a single `asm` block.
    """

    for instruction, op1, op2 in instruction_tuples:
@@ -165,10 +181,13 @@ def _expand_instruction_tuples(instruction_tuples, index) -> Iterator[str]:
        yield f"sum_{index} = __builtin_arm_{instruction}({op1}, {op2}, sum_{index});"


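# For illustration (operand names are invented): the tuple ("smlad", "tensor__a", "kernel__a")
# at index 0 expands to the line below, which the compiler is then free to interleave with the
# surrounding memory loads:
#     sum_0 = __builtin_arm_smlad(tensor__a, kernel__a, sum_0);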
-def _requantize_sums(num_sums) -> Iterator[str]:
-    """Simulates multiplying by the float32 requantization scale by doing a int64 multiply + shift,
-    which is much faster. The bias is added at the beginning, so we can skip doing it now. The shift
-    is hard-coded, as this saves a few cycles without hurting accuracy in "most" cases.
+def _requantize_sums(num_sums, requantize_shift, output_zero_point) -> Iterator[str]:
+    """Generates code to requantize the accumulator values.
+
+    The generated code does not use floating point instructions, as it simulates floating point
+    multiplication with an int64 multiply + shift. The bias is added at the beginning, so we can
+    skip doing it now. The shift is hard-coded, as this saves a few cycles without hurting accuracy
+    in "most" cases.

    It's *possible* we could save one more cycle here by pre-multiplying the bias with the
    requantize multiplier, and then doing the bias addition and shift in the same cycle (via <op2>).
@@ -180,22 +199,27 @@ def _requantize_sums(num_sums) -> Iterator[str]:

    Calling __builtin_arm_ssat directly is a little bit gross, but GCC and Clang are unreliable
    about compiling other ways of writing this. Both the multiply + shift and shift + saturation
-    combine to one instruction each."""
+    combine to one instruction each.
+    """

+    yield "int scale_val = *scale;"
    for i in range(num_sums):
-        yield f"int requant_{i} = (sum_{i} * (long long) requant_scale) >> 32;"
+        yield f"int requant_{i} = (sum_{i} * (long long) scale_val) >> {requantize_shift - 1};"
        yield f"requant_{i} = (requant_{i} + 1) >> 1;"
-        yield f"requant_{i} = __builtin_arm_ssat(requant_{i} - 128, 8);"
+        yield f"requant_{i} = __builtin_arm_ssat(requant_{i} + {output_zero_point}, 8);"


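# A runnable sketch (not part of the diff) of the fixed-point math the generated C performs with
# the new parameters. "scale_q" stands in for whatever fixed-point value the caller stores behind
# the *scale pointer (an assumption); the clamp mirrors __builtin_arm_ssat(x, 8).
def _requantize_sketch(acc, scale_q, requantize_shift=33, output_zero_point=-128):
    requant = (acc * scale_q) >> (requantize_shift - 1)  # int64 multiply + all but one bit of shift
    requant = (requant + 1) >> 1                         # last shift bit, with round-to-nearest
    return max(-128, min(127, requant + output_zero_point))  # add zero point, saturate to int8
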
def _write_sums_to_memory(num_sums, offset, stride) -> Iterator[str]:
-    """Writes the requantized sums to memory. Note - halfword packing here *does* help. It seems
+    """Generates code to write the requantized sums to memory.
+
+    Note - halfword packing here *does* help. It seems
    like it wouldn't, as doing two pipelined int16 stores takes two cycles - the same as halfword
    packing plus a pipelined int32 store. We still do the int16 stores when there is an output
    stride, though.

    However, this lets the compiler re-order instructions to better preserve memory, as it doesn't
-    like breaking apart the store instructions (as this messes up pipelining)."""
+    like breaking apart the store instructions (as this messes up pipelining).
+    """

    if stride > 1:
        for i in range(num_sums):
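# Illustrative only (the exact generated string is an assumption): with an output stride the
# writer falls back to per-value int16 stores along the lines of
#     ((short*) output)[i * stride] = requant_i;
# while the unstrided path packs two requantized values into one word and issues a single
# int32 store, as the docstring above describes.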
@@ -222,10 +246,14 @@ def tensordot_int16_impl(
    dimensions: Tuple[int, int, int],
    offsets: Tuple[int, int, int],
    x_strides: Tuple[int, int],
+    requantize_shift: int = 33,
+    output_zero_point: int = -128,
) -> Tuple[str, str]:
-    """Code for a quantized version of tensordot, which computes `split_size` tensordot operations
-    at the same time. Only works with `int16`. The generated function takes as input pointers to the
-    output, tensor, and kernel, which must be word-aligned.
+    """Generates code to compute a tensor dot product with requantization.
+
+    The generated function takes pointers to the output, tensor, and kernel as input. All pointers
+    must be word aligned. Only works with `int16` data type. The generated code is optimized for the
+    ARMv7E-M architecture.

    Parameters
    ----------
@@ -249,6 +277,14 @@ def tensordot_int16_impl(
        The distance (in halfwords) between the start of each input tensor, and where to write each
        output result respectively. Only used when split_size > 1.

+    requantize_shift: int
+        The distance to right shift after multiplying by the requantization scale. Defaults to 33,
+        as this lets us skip a shift operation.
+
+    output_zero_point: int
+        The output zero point, which will be added after scale multiplication but before
+        clipping. Defaults to -128, as most models use this value.
+

    Returns
    func_name, func_code: Tuple[str, str]
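# Hypothetical usage of the updated signature (argument values are illustrative; the leading
# split_size parameter is inferred from the rest of this diff rather than shown here):
#     func_name, func_code = tensordot_int16_impl(
#         split_size=2,
#         dimensions=(48, 3, 3),
#         offsets=(0, 0, 0),
#         x_strides=(1, 1),
#         requantize_shift=33,
#         output_zero_point=-128,
#     )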
@@ -273,7 +309,9 @@ def gen_single_loop_macs(index):
        return _expand_instruction_tuples(draft_macs_iter, index)

    multiply_acc_lines = chain.from_iterable(gen_single_loop_macs(i) for i in range(split_size))
-    requantize_lines = _requantize_sums(split_size)
+    requantize_lines = _requantize_sums(
+        split_size, requantize_shift=requantize_shift, output_zero_point=output_zero_point
+    )
    write_out_lines = _write_sums_to_memory(split_size, output_offset, out_stride)

    def insert_lines(lines):
@@ -287,7 +325,7 @@ def insert_lines(lines):
        #ifndef {function_name.upper()}_EXISTS
        #define {function_name.upper()}_EXISTS
        __attribute__((always_inline)) static inline int {function_name}(
-            int *output, int *tensor, int *kernel, int bias, int requant_scale
+            int *output, int *tensor, int *kernel, int *bias, int *scale
        ) {{
            {_init_biased_accumulators(split_size)}

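# For illustration (the function name shown is made up): with the pointer-typed bias and scale
# arguments, the emitted prototype now reads:
#     __attribute__((always_inline)) static inline int tensordot_opt_x2_int16_w48_...(
#         int *output, int *tensor, int *kernel, int *bias, int *scale
#     ) {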