Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions backends/cadence/aot/functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,11 @@
- arg_meta: null
kernel_name: impl::reference::quantized_relu_out

# Per-tensor variant of quantized_relu: zero points, multiplier and shift are
# scalar ints instead of tensors; dispatches to the reference kernel below.
- func: cadence::quantized_relu.per_tensor_out(Tensor X, int X_zero_point, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::reference::quantized_relu_per_tensor_out

- func: cadence::quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
Expand Down
54 changes: 54 additions & 0 deletions backends/cadence/reference/operators/quantized_relu_out.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
*/

#include <executorch/backends/cadence/reference/kernels/kernels.h>
#include <executorch/backends/cadence/reference/operators/operators.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace impl {
Expand Down Expand Up @@ -75,6 +76,59 @@ void quantized_relu_out(
}
}

// Typed implementation of the per-tensor quantized ReLU.
//
// Clamps each quantized input element to its zero point (the quantized
// representation of 0.0), then requantizes the clamped value into the
// output's quantization domain using the fixed-point pair
// (out_multiplier, out_shift).
//
// Args:
//   ctx: kernel runtime context (unused here).
//   input: quantized input tensor with element type T.
//   in_zero_point: zero point of the input's quantization.
//   out_zero_point: zero point of the output's quantization.
//   out_multiplier: fixed-point multiplier, interpreted as
//       out_multiplier / 2^31.
//   out_shift: power-of-two exponent applied on top of the multiplier.
//   output: pre-allocated output tensor with element type T.
template <typename T>
void quantized_relu_per_tensor_out_(
    __ET_UNUSED KernelRuntimeContext& ctx,
    const Tensor& input,
    const int64_t in_zero_point,
    const int64_t out_zero_point,
    const int64_t out_multiplier,
    const int64_t out_shift,
    Tensor& output) {
  const T* __restrict__ in = input.const_data_ptr<T>();
  T* __restrict__ out = output.mutable_data_ptr<T>();

  // Reconstruct the requantization scale:
  //   out_scale = out_multiplier / 2^31 * 2^out_shift.
  // The previous form `-out_multiplier * 1.0 / (1 << 31)` shifted into the
  // sign bit of a 32-bit int — undefined behavior — and relied on the result
  // wrapping to INT_MIN, with the unary minus compensating the sign. Use the
  // exact constant 2^31 instead.
  const float out_scale =
      out_multiplier * 1.0 / 2147483648.0 * pow(2, out_shift);

  for (size_t i = 0, e = input.numel(); i < e; ++i) {
    // ReLU in the quantized domain: in_zero_point encodes real 0.0, so any
    // value at or below it maps to 0 before requantization.
    const float temp = in[i] > in_zero_point ? (in[i] - in_zero_point) : 0;
    out[i] = kernels::quantize<T>(temp, out_scale, out_zero_point);
  }
}

// Dtype-dispatching entry point for the per-tensor quantized ReLU.
//
// Inspects the input tensor's scalar type and forwards to the matching
// typed implementation; any dtype outside the supported quantized set
// trips a debug check.
void quantized_relu_per_tensor_out(
    KernelRuntimeContext& ctx,
    const Tensor& input,
    const int64_t in_zero_point,
    const int64_t out_zero_point,
    const int64_t out_multiplier,
    const int64_t out_shift,
    Tensor& output) {
// One switch case per supported quantized dtype, expanded via the
// ET_FORALL_CADENCE_QUANTIZED_TYPES X-macro below.
#define QUANTIZED_RELU_DISPATCH_CASE(ctype, dtype)  \
  case executorch::aten::ScalarType::dtype: {       \
    quantized_relu_per_tensor_out_<ctype>(          \
        ctx,                                        \
        input,                                      \
        in_zero_point,                              \
        out_zero_point,                             \
        out_multiplier,                             \
        out_shift,                                  \
        output);                                    \
    break;                                          \
  }

  const executorch::aten::ScalarType input_dtype = input.scalar_type();
  switch (input_dtype) {
    ET_FORALL_CADENCE_QUANTIZED_TYPES(QUANTIZED_RELU_DISPATCH_CASE)
    default:
      ET_DCHECK_MSG(
          false, "Unhandled dtype %s", torch::executor::toString(input_dtype));
  }

#undef QUANTIZED_RELU_DISPATCH_CASE
}

}; // namespace native
}; // namespace reference
}; // namespace impl
Loading