@@ -26,11 +26,12 @@ void quantized_relu_per_tensor_out(
2626    const  int64_t  out_multiplier,
2727    const  int64_t  out_shift,
2828    Tensor& output) {
29-   const  uint8_t  _in_zero_point = static_cast <uint8_t >(in_zero_point);
30-   const  uint8_t  _out_zero_point = static_cast <uint8_t >(out_zero_point);
31-   const  int32_t  _out_multiplier = static_cast <int32_t >(out_multiplier);
32-   const  int32_t  _out_shift = static_cast <int32_t >(out_shift);
29+     const  int32_t  _out_multiplier = static_cast <int32_t >(out_multiplier);
30+     const  int32_t  _out_shift = static_cast <int32_t >(out_shift);
31+ 
3332  if  (input.scalar_type () == executorch::aten::ScalarType::Byte) {
33+     const  uint8_t  _in_zero_point = static_cast <uint8_t >(in_zero_point);
34+     const  uint8_t  _out_zero_point = static_cast <uint8_t >(out_zero_point);
3435    const  uint8_t * p_in = input.const_data_ptr <uint8_t >();
3536    uint8_t * p_out = output.mutable_data_ptr <uint8_t >();
3637
@@ -48,6 +49,8 @@ void quantized_relu_per_tensor_out(
4849    ET_CHECK_MSG (ret_val == 0 , " An internal error occured"  );
4950
5051  } else  if  (input.scalar_type () == executorch::aten::ScalarType::Char) {
52+     const  int8_t  _in_zero_point = static_cast <int8_t >(in_zero_point);
53+     const  int8_t  _out_zero_point = static_cast <int8_t >(out_zero_point);
5154    const  int8_t * p_in = input.const_data_ptr <int8_t >();
5255    int8_t * p_out = output.mutable_data_ptr <int8_t >();
5356
@@ -72,28 +75,6 @@ void quantized_relu_per_tensor_out(
7275  }
7376}
7477
75- void  quantized_relu_per_tensor_out (
76-     KernelRuntimeContext& ctx,
77-     const  Tensor& input,
78-     const  Tensor& in_zero_point,
79-     const  int64_t  out_zero_point,
80-     const  Tensor& out_multiplier,
81-     const  Tensor& out_shift,
82-     Tensor& output) {
83-   int8_t  _in_zero_point = in_zero_point.const_data_ptr <int8_t >()[0 ];
84-   int32_t  _out_multiplier = out_multiplier.const_data_ptr <int32_t >()[0 ];
85-   int32_t  _out_shift = out_shift.const_data_ptr <int32_t >()[0 ];
86- 
87-   quantized_relu_per_tensor_out (
88-       ctx,
89-       input,
90-       _in_zero_point,
91-       out_zero_point,
92-       _out_multiplier,
93-       _out_shift,
94-       output);
95- }
96- 
9778} //  namespace native
9879} //  namespace HiFi
9980} //  namespace impl
0 commit comments