14 changes: 12 additions & 2 deletions python/tvm/relay/op/contrib/cmsisnn.py
@@ -86,11 +86,21 @@ def check_qnn_softmax(pattern):
zero_point = pattern.args[2].data.numpy().item(0)

# check the quantization parameters and dtypes of quantize and dequantize
return (
if (
(scale == 1.0 / 256 and zero_point == -128)
and pattern.attrs.out_dtype == "int8"
and dequantize_call.args[0].checked_type.dtype == "int8"
)
):
return True

if (
(scale == 1.0 / 32768 and zero_point == 0)
and pattern.attrs.out_dtype == "int16"
and dequantize_call.args[0].checked_type.dtype == "int16"
):
return True

return False

def qnn_conv2d_pattern(with_pad):
"""Create pattern for qnn.conv2D with optional pad and/or optional fused relu."""
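The fixed quantization parameters checked in check_qnn_softmax follow TensorFlow Lite Micro's softmax conventions. Under the affine scheme real = scale * (q - zero_point), a scale of 1/256 with zero point -128 maps the int8 range onto [0, 255/256], and a scale of 1/32768 with zero point 0 maps the int16 range onto [-1, 32767/32768], covering the softmax codomain. A standalone C++ sketch (an illustration, not part of the patch) that prints both ranges:

#include <cstdint>
#include <cstdio>
#include <limits>

// Dequantize under the affine scheme: real = scale * (q - zero_point).
float Dequantize(int32_t q, float scale, int32_t zero_point) {
  return scale * static_cast<float>(q - zero_point);
}

int main() {
  // int8 softmax output: scale = 1/256, zero_point = -128 -> [0, 255/256]
  std::printf("int8  range: [%f, %f]\n",
              Dequantize(std::numeric_limits<int8_t>::min(), 1.0f / 256, -128),
              Dequantize(std::numeric_limits<int8_t>::max(), 1.0f / 256, -128));
  // int16 softmax output: scale = 1/32768, zero_point = 0 -> [-1, 32767/32768]
  std::printf("int16 range: [%f, %f]\n",
              Dequantize(std::numeric_limits<int16_t>::min(), 1.0f / 32768, 0),
              Dequantize(std::numeric_limits<int16_t>::max(), 1.0f / 32768, 0));
  return 0;
}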
76 changes: 76 additions & 0 deletions src/relay/backend/contrib/cmsisnn/compute_luts.cc
@@ -0,0 +1,76 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/relay/backend/contrib/cmsisnn/compute_luts.cc
 * \brief Creates LUTs in different bit formats to accelerate operator computations.
*/

#include "compute_luts.h"

#include <algorithm>
#include <cmath>
#include <limits>

namespace tvm {
namespace relay {
namespace contrib {
namespace cmsisnn {

void CalculateLUTInt16(int key_zero_point, float key_scale, int value_zero_point, float value_scale,
float (*func)(float), const int steps, int16_t* lut) {
const float value_min = static_cast<float>(std::numeric_limits<int16_t>::min());
const float value_max = static_cast<float>(std::numeric_limits<int16_t>::max());
const float key_min_deq = key_scale * (std::numeric_limits<int16_t>::min() - key_zero_point);
const float key_max_deq = key_scale * (std::numeric_limits<int16_t>::max() - key_zero_point);
const float value_min_deq =
value_scale * (std::numeric_limits<int16_t>::min() - value_zero_point);
const float value_max_deq =
value_scale * (std::numeric_limits<int16_t>::max() - value_zero_point);

const float step_size_deq = (key_max_deq - key_min_deq) / (steps - 1);
const float half_step_size_deq = step_size_deq / 2;

const float value_inv_quantizing =
(std::numeric_limits<int16_t>::max() - std::numeric_limits<int16_t>::min() + 1) /
(value_max_deq - value_min_deq);

  for (int i = 0; i < steps - 1; i++) {
    // Dequantized function values at the current key, at the midpoint to the
    // next key, and at the next key.
    float value_deq = func(key_min_deq + i * step_size_deq);
    float mid_value_deq = func(key_min_deq + i * step_size_deq + half_step_size_deq);
    float next_value_deq = func(key_min_deq + (i + 1) * step_size_deq);

    // Requantize, then compare the true midpoint against the linearly
    // interpolated midpoint; half of that error is folded back into the
    // stored entry to reduce the worst-case interpolation error.
    float value = std::round(value_deq * value_inv_quantizing);
    float mid_value = std::round(mid_value_deq * value_inv_quantizing);
    float next_value = std::round(next_value_deq * value_inv_quantizing);
    float mid_interp_value = std::round((value + next_value) / 2);

    float mid_err = mid_interp_value - mid_value;
    float bias = std::round(mid_err / 2);

    lut[i] = static_cast<int16_t>(std::max(std::min(value - bias, value_max), value_min));
  }

  // The final entry corresponds to the maximum key.
  lut[steps - 1] = static_cast<int16_t>(std::max(
      std::min(std::round(func(key_max_deq) * value_inv_quantizing), value_max), value_min));
}

} // namespace cmsisnn
} // namespace contrib
} // namespace relay
} // namespace tvm
55 changes: 55 additions & 0 deletions src/relay/backend/contrib/cmsisnn/compute_luts.h
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/relay/backend/contrib/cmsisnn/compute_luts.h
 * \brief CMSIS-NN LUT calculation functions
*/

#ifndef TVM_RELAY_BACKEND_CONTRIB_CMSISNN_COMPUTE_LUTS_H_
#define TVM_RELAY_BACKEND_CONTRIB_CMSISNN_COMPUTE_LUTS_H_

#include <cstdint>

namespace tvm {
namespace relay {
namespace contrib {
namespace cmsisnn {

/*!
 * \brief Populates an int16 LUT from the quantization parameters of its keys and values and the
 * transformation function that maps dequantized keys to values
 *
 * \param key_zero_point - zero point of the table's keys
 * \param key_scale - scale of the table's keys
 * \param value_zero_point - zero point of the table's values
 * \param value_scale - scale of the table's values
 * \param func - pointer to the transformation function realized by the LUT
 * \param steps - total number of entries in the table
 * \param lut - int16_t array that receives the computed LUT values
 */
void CalculateLUTInt16(int key_zero_point, float key_scale, int value_zero_point, float value_scale,
float (*func)(float), const int steps, int16_t* lut);

} // namespace cmsisnn
} // namespace contrib
} // namespace relay
} // namespace tvm

#endif // TVM_RELAY_BACKEND_CONTRIB_CMSISNN_COMPUTE_LUTS_H_
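As a usage illustration of the function declared above, here is a standalone sketch (not part of the patch) that builds the exponential LUT with the same quantization parameters the softmax lowering in relay_to_tir.cc uses, where keys dequantize to roughly [-10, 0] and values to [-1, 1):

#include <cmath>
#include <cstdint>
#include <cstdio>

#include "compute_luts.h"

using tvm::relay::contrib::cmsisnn::CalculateLUTInt16;

int main() {
  constexpr int kLUTEntries = 513;
  int16_t exp_lut[kLUTEntries];

  // Keys: zero_point = 32767, scale = 10/65535 -> dequantized keys span ~[-10, 0].
  // Values: zero_point = 0, scale = 2/65535 -> dequantized values span ~[-1, 1].
  CalculateLUTInt16(
      32767, 10.0f / 65535, 0, 2.0f / 65535,
      [](float key) { return std::exp(key); }, kLUTEntries, exp_lut);

  // exp(0) = 1 lands at the top of the value range, so the last entry
  // saturates near the int16 maximum.
  std::printf("last entry: %d\n", exp_lut[kLUTEntries - 1]);
  return 0;
}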
151 changes: 122 additions & 29 deletions src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
@@ -30,6 +30,7 @@
#include "../../../transforms/pattern_utils.h"
#include "buffer_size.h"
#include "compiler_attrs.h"
#include "compute_luts.h"
#include "convolutions.h"

namespace tvm {
@@ -89,11 +90,17 @@ class RelayToTIRVisitor : public MixedModeMutator {
private:
inline IntImm ToArg(int32_t value) { return IntImm(DataType::Int(32), value); }

void CreatePrimFuncForExtern(const GlobalVar& global_var, Array<tir::Var> func_signature,
const Map<tir::Var, tir::Buffer>& buffer_map,
tvm::Array<PrimExpr> call_extern_args,
PrimExpr context_buffer_var = PrimExpr(),
int context_buffer_size = 0, int num_bits = 8) {
// Struct used to allocate constant NDArrays embedded into the generated PrimFunc
struct tir_input_constant_buffers {
tir::Var buffer_var;
tvm::runtime::NDArray ndarray;
};

void CreatePrimFuncForExtern(
const GlobalVar& global_var, Array<tir::Var> func_signature,
const Map<tir::Var, tir::Buffer>& buffer_map, tvm::Array<PrimExpr> call_extern_args,
PrimExpr context_buffer_var = PrimExpr(), int context_buffer_size = 0, int num_bits = 8,
std::vector<tir_input_constant_buffers> context_const_buffer_vars = {}) {
Map<String, ObjectRef> dict_attrs;
dict_attrs.Set(tvm::attr::kGlobalSymbol, global_var->name_hint);
dict_attrs.Set(tvm::attr::kTarget, target_);
@@ -107,8 +114,22 @@ class RelayToTIRVisitor : public MixedModeMutator {
{context_buffer_size}, tir::const_true(), body);
}

for (int i = 0; i < static_cast<int>(context_const_buffer_vars.size()); i++) {
int bits = context_const_buffer_vars[i].ndarray.DataType().bits();

Array<PrimExpr> extents;
for (int shape : context_const_buffer_vars[i].ndarray.Shape()) {
extents.push_back(PrimExpr(shape));
}

body = tir::AllocateConst(Downcast<tir::Var>(context_const_buffer_vars[i].buffer_var),
DataType::Int(bits), extents, context_const_buffer_vars[i].ndarray,
body);
}

tir::PrimFunc replacement_func(func_signature, body, VoidType(), buffer_map,
DictAttrs(dict_attrs));

ir_module_->Add(global_var, replacement_func);
}

@@ -505,6 +526,7 @@ class RelayToTIRVisitor : public MixedModeMutator {
const CallNode* softmax_call = quantize_call->args[0].as<CallNode>();
const CallNode* dequant_call = softmax_call->args[0].as<CallNode>();
const float quant_scale = GetScalarFromConstant<float>(dequant_call->args[1]);
const auto bit_width = quantize_call->type_as<TensorTypeNode>()->dtype.bits();

// assuming layout as NHWC
auto shape = quantize_call->type_as<TensorTypeNode>()->shape;
@@ -517,36 +539,107 @@

// calculate multiplier and shift for CMSIS-NN softmax API
// Note: TensorFlow Lite Micro assumptions
// Output zero point and scale are fixed to -128 and 1 / 256
// Output zero point and scale are fixed to -128 and 1 / 256 in the case of an int8 operator
// or to 0 and 1 / 32768 in the case of an int16 operator
// kScaledDiffIntegerBits, kInputBits, kBeta are described on the following github page
// https://github.com/tensorflow/tflite-micro/blob/d97cd0908d8cf5021e9d86f05a49888bee28c2a4/tensorflow/lite/micro/kernels/softmax_common.cc#L47
double beta_multiplier = (kBeta * quant_scale * (1 << (31 - kInputBits)));
beta_multiplier = std::min<double>(beta_multiplier, (1ll << 31) - 1.0);
auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift(beta_multiplier);
int32_t mult = std::get<0>(mult_shift_pair);
int32_t shift = std::get<1>(mult_shift_pair);
int32_t diff_min = (1 << kScaledDiffIntegerBits) - 1;
diff_min <<= (31 - kScaledDiffIntegerBits);
diff_min >>= shift;
diff_min *= -1;

int32_t mult;
int32_t shift;
int32_t diff_min = 0;

std::vector<tir_input_constant_buffers> softmax_params(2);
Device dev{DLDeviceType::kDLCPU, 0};

if (bit_width == 8) {
double beta_multiplier = (kBeta * quant_scale * (1 << (31 - kInputBits)));
beta_multiplier = std::min<double>(beta_multiplier, (1ll << 31) - 1.0);
auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift(beta_multiplier);
mult = std::get<0>(mult_shift_pair);
shift = std::get<1>(mult_shift_pair);
diff_min = (1 << kScaledDiffIntegerBits) - 1;
diff_min <<= (31 - kScaledDiffIntegerBits);
diff_min >>= shift;
diff_min *= -1;
} else { // bit_width == 16
double scale_beta_rescale = quant_scale * kBeta / (10.0 / 65535.0);
auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift(scale_beta_rescale);
mult = std::get<0>(mult_shift_pair);
shift = std::get<1>(mult_shift_pair);

const int kLUTEntries = 513;
int16_t softmax_s16_exp_lut[kLUTEntries];
int16_t softmax_s16_one_by_one_lut[kLUTEntries];

const int range_int16 =
std::numeric_limits<int16_t>::max() - std::numeric_limits<int16_t>::min();
int exp_zero_point = std::numeric_limits<int16_t>::max();
float exp_scale = 10.0f / range_int16;

int one_by_one_zero_point = std::numeric_limits<int16_t>::min();
float one_by_one_scale = 1.0f / range_int16;

int lut_value_zero_point = 0;
float lut_value_scale = 2.0f / range_int16;

CalculateLUTInt16(
exp_zero_point, exp_scale, lut_value_zero_point, lut_value_scale,
[](float key) { return std::exp(key); }, kLUTEntries, softmax_s16_exp_lut);
CalculateLUTInt16(
one_by_one_zero_point, one_by_one_scale, lut_value_zero_point, lut_value_scale,
[](float key) { return 1.0f / (1.0f + key); }, kLUTEntries, softmax_s16_one_by_one_lut);

// first LUT
softmax_params[0].buffer_var =
tir::Var("exp_lut", PointerType(PrimType(DataType::Int(bit_width)), "global.workspace"));
softmax_params[0].ndarray =
runtime::NDArray::Empty({kLUTEntries}, DataType::Int(bit_width), dev);
softmax_params[0].ndarray.CopyFromBytes(softmax_s16_exp_lut, sizeof(int16_t) * kLUTEntries);

// second LUT
softmax_params[1].buffer_var = tir::Var(
"one_by_one_lut", PointerType(PrimType(DataType::Int(bit_width)), "global.workspace"));
softmax_params[1].ndarray =
runtime::NDArray::Empty({kLUTEntries}, DataType::Int(bit_width), dev);
softmax_params[1].ndarray.CopyFromBytes(softmax_s16_one_by_one_lut,
sizeof(int16_t) * kLUTEntries);
}

BufferCreator buffer_creator;
tir::Var in_var = buffer_creator.CreateBufferVar("input", DataType::Handle(8));
tir::Var out_var = buffer_creator.CreateBufferVar("output", DataType::Handle(8));
tir::Var in_var = buffer_creator.CreateBufferVar("input", DataType::Handle(bit_width));
tir::Var out_var = buffer_creator.CreateBufferVar("output", DataType::Handle(bit_width));

if (bit_width == 8) {
tvm::Array<PrimExpr> args = {
tir::StringImm("arm_softmax_s" + std::to_string(bit_width)),
in_var,
ToArg(num_rows),
ToArg(row_size),
ToArg(mult),
ToArg(shift),
ToArg(diff_min),
out_var,
};

tvm::Array<PrimExpr> args = {
tir::StringImm("arm_softmax_s8"),
in_var,
ToArg(num_rows),
ToArg(row_size),
ToArg(mult),
ToArg(shift),
ToArg(diff_min),
out_var,
};
CreatePrimFuncForExtern(global_var, buffer_creator.GetPrimFuncParams(),
buffer_creator.GetBufferMap(), args);
} else { // bit_width == 16
tvm::Array<PrimExpr> args = {
tir::StringImm("arm_softmax_s" + std::to_string(bit_width)),
in_var,
ToArg(num_rows),
ToArg(row_size),
ToArg(mult),
ToArg(shift),
softmax_params[0].buffer_var,
softmax_params[1].buffer_var,
out_var,
};

CreatePrimFuncForExtern(global_var, buffer_creator.GetPrimFuncParams(),
buffer_creator.GetBufferMap(), args);
CreatePrimFuncForExtern(global_var, buffer_creator.GetPrimFuncParams(),
buffer_creator.GetBufferMap(), args, PrimExpr(), 0, 16,
softmax_params);
}
}

struct BinaryElementwiseClipPattern {
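For background on the mult/shift pair passed to the CMSIS-NN softmax calls: tvm::relay::qnn::GetFixedPointMultiplierShift is assumed here to perform the usual TFLite-style decomposition of a positive real multiplier into a Q31 significand and a power-of-two shift. A minimal sketch of that decomposition (an illustration under that assumption, not TVM's implementation):

#include <cmath>
#include <cstdint>
#include <utility>

// Decompose value so that value ~= multiplier * 2^(shift - 31), with the
// Q31 multiplier in [2^30, 2^31). Illustrative only; assumed to mirror
// tvm::relay::qnn::GetFixedPointMultiplierShift.
std::pair<int32_t, int32_t> FixedPointMultiplierShift(double value) {
  if (value == 0.0) return {0, 0};
  int shift = 0;
  const double significand = std::frexp(value, &shift);  // in [0.5, 1)
  int64_t multiplier = static_cast<int64_t>(std::round(significand * (1LL << 31)));
  if (multiplier == (1LL << 31)) {  // rounding reached 2^31: renormalize
    multiplier /= 2;
    ++shift;
  }
  return {static_cast<int32_t>(multiplier), shift};
}

In the int8 branch, beta_multiplier run through this decomposition yields the mult and shift arguments of arm_softmax_s8, and diff_min is then derived from shift, mirroring the TFLM reference linked in the comment above.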