
Commit 77b71fc

[CMSIS-NN] Support for Softmax Int16 operator (#15407)
* Support for int16 Softmax in CMSIS-NN
* Supporting integration test

1 parent: 8b37d4d

6 files changed, 364 additions(+), 31 deletions(-)

python/tvm/relay/op/contrib/cmsisnn.py

Lines changed: 12 additions & 2 deletions

@@ -86,11 +86,21 @@ def check_qnn_softmax(pattern):
     zero_point = pattern.args[2].data.numpy().item(0)
 
     # check for dtypes of quantize and dequantize
-    return (
+    if (
         (scale == 1.0 / 256 and zero_point == -128)
         and pattern.attrs.out_dtype == "int8"
         and dequantize_call.args[0].checked_type.dtype == "int8"
-    )
+    ):
+        return True
+
+    if (
+        (scale == 1.0 / 32768 and zero_point == 0)
+        and pattern.attrs.out_dtype == "int16"
+        and dequantize_call.args[0].checked_type.dtype == "int16"
+    ):
+        return True
+
+    return False
 
 def qnn_conv2d_pattern(with_pad):
     """Create pattern for qnn.conv2D with optional pad and/or optional fused relu."""
src/relay/backend/contrib/cmsisnn/compute_luts.cc (new file)

Lines changed: 76 additions & 0 deletions

@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relay/backend/contrib/cmsisnn/compute_luts.cc
+ * \brief Creates LUTs for operators in different bit formats for accelerating computations.
+ */
+
+#include "compute_luts.h"
+
+#include <algorithm>
+#include <cmath>
+#include <limits>
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+void CalculateLUTInt16(int key_zero_point, float key_scale, int value_zero_point, float value_scale,
+                       float (*func)(float), const int steps, int16_t* lut) {
+  const float value_min = static_cast<float>(std::numeric_limits<int16_t>::min());
+  const float value_max = static_cast<float>(std::numeric_limits<int16_t>::max());
+  const float key_min_deq = key_scale * (std::numeric_limits<int16_t>::min() - key_zero_point);
+  const float key_max_deq = key_scale * (std::numeric_limits<int16_t>::max() - key_zero_point);
+  const float value_min_deq =
+      value_scale * (std::numeric_limits<int16_t>::min() - value_zero_point);
+  const float value_max_deq =
+      value_scale * (std::numeric_limits<int16_t>::max() - value_zero_point);
+
+  const float step_size_deq = (key_max_deq - key_min_deq) / (steps - 1);
+  const float half_step_size_deq = step_size_deq / 2;
+
+  const float value_inv_quantizing =
+      (std::numeric_limits<int16_t>::max() - std::numeric_limits<int16_t>::min() + 1) /
+      (value_max_deq - value_min_deq);
+
+  for (int i = 0; i < steps - 1; i++) {
+    float value_deq = func(key_min_deq + i * step_size_deq);
+    float mid_value_deq = func(key_min_deq + i * step_size_deq + half_step_size_deq);
+    float next_value_deq = func(key_min_deq + (i + 1) * step_size_deq);
+
+    float value = std::round(value_deq * value_inv_quantizing);
+    float mid_value = std::round(mid_value_deq * value_inv_quantizing);
+    float next_value = std::round(next_value_deq * value_inv_quantizing);
+    float mid_iterp_value = std::round((value + next_value) / 2);
+
+    float mid_err = mid_iterp_value - mid_value;
+    float bias = std::round(mid_err / 2);
+
+    lut[i] = static_cast<int16_t>(std::max(std::min(value - bias, value_max), value_min));
+  }
+
+  lut[steps - 1] = static_cast<int16_t>(
+      std::max(std::min(func(value_max_deq) * value_inv_quantizing, value_max), value_min));
+}
+
+}  // namespace cmsisnn
+}  // namespace contrib
+}  // namespace relay
+}  // namespace tvm
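In effect, each stored entry carries a small interpolation-bias correction: the function is also evaluated at the midpoint of every step, and half of the linear-interpolation error at that midpoint is folded back into the entry. With $s_v$ the dequantized size of one value step (the reciprocal of value_inv_quantizing), $k_i$ the $i$-th dequantized key and $\Delta$ the key step size, the loop computes

$$v_i=\operatorname{round}\!\left(\frac{f(k_i)}{s_v}\right),\quad e_i=\operatorname{round}\!\left(\frac{v_i+v_{i+1}}{2}\right)-\operatorname{round}\!\left(\frac{f(k_i+\Delta/2)}{s_v}\right),\quad \mathrm{lut}[i]=\operatorname{clip}_{[-32768,\,32767]}\!\left(v_i-\operatorname{round}\!\left(\frac{e_i}{2}\right)\right)$$

so that linear interpolation between adjacent entries stays close to $f$ at both the grid points and the midpoints.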
src/relay/backend/contrib/cmsisnn/compute_luts.h (new file)

Lines changed: 55 additions & 0 deletions

@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/contrib/cmsisnn/compute_luts.h
+ * \brief CMSIS-NN LUTs calculation functions
+ */
+
+#ifndef TVM_RELAY_BACKEND_CONTRIB_CMSISNN_COMPUTE_LUTS_H_
+#define TVM_RELAY_BACKEND_CONTRIB_CMSISNN_COMPUTE_LUTS_H_
+
+#include <cstdint>
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+/*!
+ * \brief Populates an int16 LUT based on the quantization parameters of its keys, values and
+ * respective transformation function
+ *
+ * \param key_zero_point - zero point of the table's keys
+ * \param key_scale - scale of the table's keys
+ * \param value_zero_point - zero point of the table's values
+ * \param value_scale - scale of the table's values
+ * \param func - function pointer of the transformation performed by the LUT
+ * \param steps - number of total values inside the table
+ * \param lut - int16_t array storing the values of the LUT
+ */
+void CalculateLUTInt16(int key_zero_point, float key_scale, int value_zero_point, float value_scale,
+                       float (*func)(float), const int steps, int16_t* lut);
+
+}  // namespace cmsisnn
+}  // namespace contrib
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_BACKEND_CONTRIB_CMSISNN_COMPUTE_LUTS_H_
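A minimal usage sketch (assuming compute_luts.h is on the include path; the parameters mirror the exp LUT configuration that relay_to_tir.cc sets up below, with keys dequantizing to [-10, 0] and values to roughly [-1, 1]):

#include <cmath>
#include <cstdint>
#include <limits>

#include "compute_luts.h"

int main() {
  constexpr int kLUTEntries = 513;
  int16_t exp_lut[kLUTEntries];
  const int range_int16 =
      std::numeric_limits<int16_t>::max() - std::numeric_limits<int16_t>::min();  // 65535

  // Keys: zero point 32767, scale 10/65535 -> dequantized domain [-10, 0].
  // Values: zero point 0, scale 2/65535 -> dequantized range ~[-1, 1].
  tvm::relay::contrib::cmsisnn::CalculateLUTInt16(
      std::numeric_limits<int16_t>::max(), 10.0f / range_int16, 0, 2.0f / range_int16,
      [](float key) { return std::exp(key); }, kLUTEntries, exp_lut);

  // exp_lut[512] now holds exp(0) = 1.0 quantized, i.e. clamped close to 32767.
  return 0;
}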

src/relay/backend/contrib/cmsisnn/relay_to_tir.cc

Lines changed: 122 additions & 29 deletions

@@ -30,6 +30,7 @@
 #include "../../../transforms/pattern_utils.h"
 #include "buffer_size.h"
 #include "compiler_attrs.h"
+#include "compute_luts.h"
 #include "convolutions.h"
 
 namespace tvm {
@@ -89,11 +90,17 @@ class RelayToTIRVisitor : public MixedModeMutator {
  private:
  inline IntImm ToArg(int32_t value) { return IntImm(DataType::Int(32), value); }
 
-  void CreatePrimFuncForExtern(const GlobalVar& global_var, Array<tir::Var> func_signature,
-                               const Map<tir::Var, tir::Buffer>& buffer_map,
-                               tvm::Array<PrimExpr> call_extern_args,
-                               PrimExpr context_buffer_var = PrimExpr(),
-                               int context_buffer_size = 0, int num_bits = 8) {
+  // struct used to allocate const NDArrays
+  struct tir_input_constant_buffers {
+    tir::Var buffer_var;
+    tvm::runtime::NDArray ndarray;
+  };
+
+  void CreatePrimFuncForExtern(
+      const GlobalVar& global_var, Array<tir::Var> func_signature,
+      const Map<tir::Var, tir::Buffer>& buffer_map, tvm::Array<PrimExpr> call_extern_args,
+      PrimExpr context_buffer_var = PrimExpr(), int context_buffer_size = 0, int num_bits = 8,
+      std::vector<tir_input_constant_buffers> context_const_buffer_vars = {}) {
     Map<String, ObjectRef> dict_attrs;
     dict_attrs.Set(tvm::attr::kGlobalSymbol, global_var->name_hint);
     dict_attrs.Set(tvm::attr::kTarget, target_);
@@ -107,8 +114,22 @@ class RelayToTIRVisitor : public MixedModeMutator {
                              {context_buffer_size}, tir::const_true(), body);
     }
 
+    for (int i = 0; i < static_cast<int>(context_const_buffer_vars.size()); i++) {
+      int bits = context_const_buffer_vars[i].ndarray.DataType().bits();
+
+      Array<PrimExpr> extents;
+      for (int shape : context_const_buffer_vars[i].ndarray.Shape()) {
+        extents.push_back(PrimExpr(shape));
+      }
+
+      body = tir::AllocateConst(Downcast<tir::Var>(context_const_buffer_vars[i].buffer_var),
+                                DataType::Int(bits), extents, context_const_buffer_vars[i].ndarray,
+                                body);
+    }
+
     tir::PrimFunc replacement_func(func_signature, body, VoidType(), buffer_map,
                                    DictAttrs(dict_attrs));
+
     ir_module_->Add(global_var, replacement_func);
   }
 
@@ -505,6 +526,7 @@ class RelayToTIRVisitor : public MixedModeMutator {
     const CallNode* softmax_call = quantize_call->args[0].as<CallNode>();
     const CallNode* dequant_call = softmax_call->args[0].as<CallNode>();
     const float quant_scale = GetScalarFromConstant<float>(dequant_call->args[1]);
+    const auto bit_width = quantize_call->type_as<TensorTypeNode>()->dtype.bits();
 
     // assuming layout as NHWC
     auto shape = quantize_call->type_as<TensorTypeNode>()->shape;
@@ -517,36 +539,107 @@
 
     // calculate multiplier and shift for CMSIS-NN softmax API
     // Note: TensorFlow Lite Micro assumptions
-    // Output zero point and scale are fixed to -128 and 1 / 256
+    // Output zero point and scale are fixed to -128 and 1 / 256 in the case of an int8 operator
+    // or to 0 and 1 / 32768 in the case of an int16 operator
     // kScaledDiffIntegerBits, kInputBits, kBeta are described on the following github page
     // https://github.com/tensorflow/tflite-micro/blob/d97cd0908d8cf5021e9d86f05a49888bee28c2a4/tensorflow/lite/micro/kernels/softmax_common.cc#L47
-    double beta_multiplier = (kBeta * quant_scale * (1 << (31 - kInputBits)));
-    beta_multiplier = std::min<double>(beta_multiplier, (1ll << 31) - 1.0);
-    auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift(beta_multiplier);
-    int32_t mult = std::get<0>(mult_shift_pair);
-    int32_t shift = std::get<1>(mult_shift_pair);
-    int32_t diff_min = (1 << kScaledDiffIntegerBits) - 1;
-    diff_min <<= (31 - kScaledDiffIntegerBits);
-    diff_min >>= shift;
-    diff_min *= -1;
+
+    int32_t mult;
+    int32_t shift;
+    int32_t diff_min = 0;
+
+    std::vector<tir_input_constant_buffers> softmax_params(2);
+    Device dev{DLDeviceType::kDLCPU, 0};
+
+    if (bit_width == 8) {
+      double beta_multiplier = (kBeta * quant_scale * (1 << (31 - kInputBits)));
+      beta_multiplier = std::min<double>(beta_multiplier, (1ll << 31) - 1.0);
+      auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift(beta_multiplier);
+      mult = std::get<0>(mult_shift_pair);
+      shift = std::get<1>(mult_shift_pair);
+      diff_min = (1 << kScaledDiffIntegerBits) - 1;
+      diff_min <<= (31 - kScaledDiffIntegerBits);
+      diff_min >>= shift;
+      diff_min *= -1;
+    } else {  // bit_width == 16
+      double scale_beta_rescale = quant_scale * kBeta / (10.0 / 65535.0);
+      auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift(scale_beta_rescale);
+      mult = std::get<0>(mult_shift_pair);
+      shift = std::get<1>(mult_shift_pair);
+
+      const int kLUTEntries = 513;
+      int16_t softmax_s16_exp_lut[kLUTEntries];
+      int16_t softmax_s16_one_by_one_lut[kLUTEntries];
+
+      const int range_int16 =
+          std::numeric_limits<int16_t>::max() - std::numeric_limits<int16_t>::min();
+      int exp_zero_point = std::numeric_limits<int16_t>::max();
+      float exp_scale = 10.0f / range_int16;
+
+      int one_by_one_zero_point = std::numeric_limits<int16_t>::min();
+      float one_by_one_scale = 1.0f / range_int16;
+
+      int lut_value_zero_point = 0;
+      float lut_value_scale = 2.0f / range_int16;
+
+      CalculateLUTInt16(
+          exp_zero_point, exp_scale, lut_value_zero_point, lut_value_scale,
+          [](float key) { return std::exp(key); }, kLUTEntries, softmax_s16_exp_lut);
+      CalculateLUTInt16(
+          one_by_one_zero_point, one_by_one_scale, lut_value_zero_point, lut_value_scale,
+          [](float key) { return 1.0f / (1.0f + key); }, kLUTEntries, softmax_s16_one_by_one_lut);
+
+      // first LUT
+      softmax_params[0].buffer_var =
+          tir::Var("exp_lut", PointerType(PrimType(DataType::Int(bit_width)), "global.workspace"));
+      softmax_params[0].ndarray =
+          runtime::NDArray::Empty({kLUTEntries}, DataType::Int(bit_width), dev);
+      softmax_params[0].ndarray.CopyFromBytes(softmax_s16_exp_lut, sizeof(int16_t) * kLUTEntries);
+
+      // second LUT
+      softmax_params[1].buffer_var = tir::Var(
+          "one_by_one_lut", PointerType(PrimType(DataType::Int(bit_width)), "global.workspace"));
+      softmax_params[1].ndarray =
+          runtime::NDArray::Empty({kLUTEntries}, DataType::Int(bit_width), dev);
+      softmax_params[1].ndarray.CopyFromBytes(softmax_s16_one_by_one_lut,
+                                              sizeof(int16_t) * kLUTEntries);
+    }
 
     BufferCreator buffer_creator;
-    tir::Var in_var = buffer_creator.CreateBufferVar("input", DataType::Handle(8));
-    tir::Var out_var = buffer_creator.CreateBufferVar("output", DataType::Handle(8));
+    tir::Var in_var = buffer_creator.CreateBufferVar("input", DataType::Handle(bit_width));
+    tir::Var out_var = buffer_creator.CreateBufferVar("output", DataType::Handle(bit_width));
+
+    if (bit_width == 8) {
+      tvm::Array<PrimExpr> args = {
+          tir::StringImm("arm_softmax_s" + std::to_string(bit_width)),
+          in_var,
+          ToArg(num_rows),
+          ToArg(row_size),
+          ToArg(mult),
+          ToArg(shift),
+          ToArg(diff_min),
+          out_var,
+      };
 
-    tvm::Array<PrimExpr> args = {
-        tir::StringImm("arm_softmax_s8"),
-        in_var,
-        ToArg(num_rows),
-        ToArg(row_size),
-        ToArg(mult),
-        ToArg(shift),
-        ToArg(diff_min),
-        out_var,
-    };
+      CreatePrimFuncForExtern(global_var, buffer_creator.GetPrimFuncParams(),
+                              buffer_creator.GetBufferMap(), args);
+    } else {  // bit_width == 16
+      tvm::Array<PrimExpr> args = {
+          tir::StringImm("arm_softmax_s" + std::to_string(bit_width)),
+          in_var,
+          ToArg(num_rows),
+          ToArg(row_size),
+          ToArg(mult),
+          ToArg(shift),
+          softmax_params[0].buffer_var,
+          softmax_params[1].buffer_var,
+          out_var,
+      };
 
-    CreatePrimFuncForExtern(global_var, buffer_creator.GetPrimFuncParams(),
-                            buffer_creator.GetBufferMap(), args);
+      CreatePrimFuncForExtern(global_var, buffer_creator.GetPrimFuncParams(),
+                              buffer_creator.GetBufferMap(), args, PrimExpr(), 0, 16,
+                              softmax_params);
+    }
   }
 
   struct BinaryElementwiseClipPattern {
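Both branches reduce a positive floating-point multiplier (beta_multiplier for int8, scale_beta_rescale for int16, whose 10.0 / 65535.0 divisor matches the exp LUT's key scale, so scaled input differences land in the table's key domain) to an integer multiplier/shift pair via tvm::relay::qnn::GetFixedPointMultiplierShift; the int16 branch additionally bakes the two 513-entry tables consumed by arm_softmax_s16 into the PrimFunc as constants. A minimal sketch of the standard Q31 decomposition (an assumption mirroring TFLite's QuantizeMultiplier scheme, not TVM's exact implementation):

#include <cmath>
#include <cstdint>
#include <utility>

// Decompose d > 0 as d ~= mult * 2^(shift - 31), with mult a Q31 value
// in [2^30, 2^31). The runtime then applies the multiplier with a
// saturating high multiply plus a rounding right shift.
std::pair<int32_t, int32_t> ToFixedPoint(double d) {
  if (d == 0.0) return {0, 0};
  int shift = 0;
  const double mantissa = std::frexp(d, &shift);  // d = mantissa * 2^shift, mantissa in [0.5, 1)
  int64_t mult = static_cast<int64_t>(std::round(mantissa * (1ll << 31)));
  if (mult == (1ll << 31)) {  // mantissa rounded all the way up to 1.0
    mult /= 2;
    ++shift;
  }
  return {static_cast<int32_t>(mult), shift};
}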
