3030#include " ../../../transforms/pattern_utils.h"
3131#include " buffer_size.h"
3232#include " compiler_attrs.h"
33+ #include " compute_luts.h"
3334#include " convolutions.h"
3435
3536namespace tvm {
@@ -89,11 +90,17 @@ class RelayToTIRVisitor : public MixedModeMutator {
8990 private:
9091 inline IntImm ToArg (int32_t value) { return IntImm (DataType::Int (32 ), value); }
9192
92- void CreatePrimFuncForExtern (const GlobalVar& global_var, Array<tir::Var> func_signature,
93- const Map<tir::Var, tir::Buffer>& buffer_map,
94- tvm::Array<PrimExpr> call_extern_args,
95- PrimExpr context_buffer_var = PrimExpr(),
96- int context_buffer_size = 0, int num_bits = 8) {
93+ // struct used to allocated const NDArray
94+ struct tir_input_constant_buffers {
95+ tir::Var buffer_var;
96+ tvm::runtime::NDArray ndarray;
97+ };
98+
99+ void CreatePrimFuncForExtern (
100+ const GlobalVar& global_var, Array<tir::Var> func_signature,
101+ const Map<tir::Var, tir::Buffer>& buffer_map, tvm::Array<PrimExpr> call_extern_args,
102+ PrimExpr context_buffer_var = PrimExpr(), int context_buffer_size = 0, int num_bits = 8,
103+ std::vector<tir_input_constant_buffers> context_const_buffer_vars = {}) {
97104 Map<String, ObjectRef> dict_attrs;
98105 dict_attrs.Set (tvm::attr::kGlobalSymbol , global_var->name_hint );
99106 dict_attrs.Set (tvm::attr::kTarget , target_);
@@ -107,8 +114,22 @@ class RelayToTIRVisitor : public MixedModeMutator {
107114 {context_buffer_size}, tir::const_true (), body);
108115 }
109116
117+ for (int i = 0 ; i < static_cast <int >(context_const_buffer_vars.size ()); i++) {
118+ int bits = context_const_buffer_vars[i].ndarray .DataType ().bits ();
119+
120+ Array<PrimExpr> extents;
121+ for (int shape : context_const_buffer_vars[i].ndarray .Shape ()) {
122+ extents.push_back (PrimExpr (shape));
123+ }
124+
125+ body = tir::AllocateConst (Downcast<tir::Var>(context_const_buffer_vars[i].buffer_var ),
126+ DataType::Int (bits), extents, context_const_buffer_vars[i].ndarray ,
127+ body);
128+ }
129+
110130 tir::PrimFunc replacement_func (func_signature, body, VoidType (), buffer_map,
111131 DictAttrs (dict_attrs));
132+
112133 ir_module_->Add (global_var, replacement_func);
113134 }
114135
@@ -505,6 +526,7 @@ class RelayToTIRVisitor : public MixedModeMutator {
505526 const CallNode* softmax_call = quantize_call->args [0 ].as <CallNode>();
506527 const CallNode* dequant_call = softmax_call->args [0 ].as <CallNode>();
507528 const float quant_scale = GetScalarFromConstant<float >(dequant_call->args [1 ]);
529+ const auto bit_width = quantize_call->type_as <TensorTypeNode>()->dtype .bits ();
508530
509531 // assuming layout as NHWC
510532 auto shape = quantize_call->type_as <TensorTypeNode>()->shape ;
@@ -517,36 +539,107 @@ class RelayToTIRVisitor : public MixedModeMutator {
517539
518540 // calculate multiplier and shift for CMSIS-NN softmax API
519541 // Note: TensorFlow Lite Micro assumptions
520- // Output zero point and scale are fixed to -128 and 1 / 256
542+ // Output zero point and scale are fixed to -128 and 1 / 256 in the case of an int8 operator
543+ // or to 0 and 1 / 32768 in the case of an int16 operator
521544 // kScaledDiffIntegerBits, kInputBits, kBeta are described on the following github page
522545 // https://github.com/tensorflow/tflite-micro/blob/d97cd0908d8cf5021e9d86f05a49888bee28c2a4/tensorflow/lite/micro/kernels/softmax_common.cc#L47
523- double beta_multiplier = (kBeta * quant_scale * (1 << (31 - kInputBits )));
524- beta_multiplier = std::min<double >(beta_multiplier, (1ll << 31 ) - 1.0 );
525- auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift (beta_multiplier);
526- int32_t mult = std::get<0 >(mult_shift_pair);
527- int32_t shift = std::get<1 >(mult_shift_pair);
528- int32_t diff_min = (1 << kScaledDiffIntegerBits ) - 1 ;
529- diff_min <<= (31 - kScaledDiffIntegerBits );
530- diff_min >>= shift;
531- diff_min *= -1 ;
546+
547+ int32_t mult;
548+ int32_t shift;
549+ int32_t diff_min = 0 ;
550+
551+ std::vector<tir_input_constant_buffers> softmax_params (2 );
552+ Device dev{DLDeviceType::kDLCPU , 0 };
553+
554+ if (bit_width == 8 ) {
555+ double beta_multiplier = (kBeta * quant_scale * (1 << (31 - kInputBits )));
556+ beta_multiplier = std::min<double >(beta_multiplier, (1ll << 31 ) - 1.0 );
557+ auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift (beta_multiplier);
558+ mult = std::get<0 >(mult_shift_pair);
559+ shift = std::get<1 >(mult_shift_pair);
560+ diff_min = (1 << kScaledDiffIntegerBits ) - 1 ;
561+ diff_min <<= (31 - kScaledDiffIntegerBits );
562+ diff_min >>= shift;
563+ diff_min *= -1 ;
564+ } else { // bit_width == 16
565+ double scale_beta_rescale = quant_scale * kBeta / (10.0 / 65535.0 );
566+ auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift (scale_beta_rescale);
567+ mult = std::get<0 >(mult_shift_pair);
568+ shift = std::get<1 >(mult_shift_pair);
569+
570+ const int kLUTEntries = 513 ;
571+ int16_t softmax_s16_exp_lut[kLUTEntries ];
572+ int16_t softmax_s16_one_by_one_lut[kLUTEntries ];
573+
574+ const int range_int16 =
575+ std::numeric_limits<int16_t >::max () - std::numeric_limits<int16_t >::min ();
576+ int exp_zero_point = std::numeric_limits<int16_t >::max ();
577+ float exp_scale = 10 .0f / range_int16;
578+
579+ int one_by_one_zero_point = std::numeric_limits<int16_t >::min ();
580+ float one_by_one_scale = 1 .0f / range_int16;
581+
582+ int lut_value_zero_point = 0 ;
583+ float lut_value_scale = 2 .0f / range_int16;
584+
585+ CalculateLUTInt16 (
586+ exp_zero_point, exp_scale, lut_value_zero_point, lut_value_scale,
587+ [](float key) { return std::exp (key); }, kLUTEntries , softmax_s16_exp_lut);
588+ CalculateLUTInt16 (
589+ one_by_one_zero_point, one_by_one_scale, lut_value_zero_point, lut_value_scale,
590+ [](float key) { return 1 .0f / (1 .0f + key); }, kLUTEntries , softmax_s16_one_by_one_lut);
591+
592+ // first LUT
593+ softmax_params[0 ].buffer_var =
594+ tir::Var (" exp_lut" , PointerType (PrimType (DataType::Int (bit_width)), " global.workspace" ));
595+ softmax_params[0 ].ndarray =
596+ runtime::NDArray::Empty ({kLUTEntries }, DataType::Int (bit_width), dev);
597+ softmax_params[0 ].ndarray .CopyFromBytes (softmax_s16_exp_lut, sizeof (int16_t ) * kLUTEntries );
598+
599+ // second LUT
600+ softmax_params[1 ].buffer_var = tir::Var (
601+ " one_by_one_lut" , PointerType (PrimType (DataType::Int (bit_width)), " global.workspace" ));
602+ softmax_params[1 ].ndarray =
603+ runtime::NDArray::Empty ({kLUTEntries }, DataType::Int (bit_width), dev);
604+ softmax_params[1 ].ndarray .CopyFromBytes (softmax_s16_one_by_one_lut,
605+ sizeof (int16_t ) * kLUTEntries );
606+ }
532607
533608 BufferCreator buffer_creator;
534- tir::Var in_var = buffer_creator.CreateBufferVar (" input" , DataType::Handle (8 ));
535- tir::Var out_var = buffer_creator.CreateBufferVar (" output" , DataType::Handle (8 ));
609+ tir::Var in_var = buffer_creator.CreateBufferVar (" input" , DataType::Handle (bit_width));
610+ tir::Var out_var = buffer_creator.CreateBufferVar (" output" , DataType::Handle (bit_width));
611+
612+ if (bit_width == 8 ) {
613+ tvm::Array<PrimExpr> args = {
614+ tir::StringImm (" arm_softmax_s" + std::to_string (bit_width)),
615+ in_var,
616+ ToArg (num_rows),
617+ ToArg (row_size),
618+ ToArg (mult),
619+ ToArg (shift),
620+ ToArg (diff_min),
621+ out_var,
622+ };
536623
537- tvm::Array<PrimExpr> args = {
538- tir::StringImm (" arm_softmax_s8" ),
539- in_var,
540- ToArg (num_rows),
541- ToArg (row_size),
542- ToArg (mult),
543- ToArg (shift),
544- ToArg (diff_min),
545- out_var,
546- };
624+ CreatePrimFuncForExtern (global_var, buffer_creator.GetPrimFuncParams (),
625+ buffer_creator.GetBufferMap (), args);
626+ } else { // bit_width == 16
627+ tvm::Array<PrimExpr> args = {
628+ tir::StringImm (" arm_softmax_s" + std::to_string (bit_width)),
629+ in_var,
630+ ToArg (num_rows),
631+ ToArg (row_size),
632+ ToArg (mult),
633+ ToArg (shift),
634+ softmax_params[0 ].buffer_var ,
635+ softmax_params[1 ].buffer_var ,
636+ out_var,
637+ };
547638
548- CreatePrimFuncForExtern (global_var, buffer_creator.GetPrimFuncParams (),
549- buffer_creator.GetBufferMap (), args);
639+ CreatePrimFuncForExtern (global_var, buffer_creator.GetPrimFuncParams (),
640+ buffer_creator.GetBufferMap (), args, PrimExpr (), 0 , 16 ,
641+ softmax_params);
642+ }
550643 }
551644
552645 struct BinaryElementwiseClipPattern {
0 commit comments