diff --git a/backends/cadence/aot/functions.yaml b/backends/cadence/aot/functions.yaml
index 173538cc2ee..698cdd7c0e4 100644
--- a/backends/cadence/aot/functions.yaml
+++ b/backends/cadence/aot/functions.yaml
@@ -238,3 +238,13 @@
   kernels:
     - arg_meta: null
       kernel_name: impl::reference::quantized_conv_per_tensor_out
+
+- func: cadence::quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::reference::quantized_fully_connected_out
+
+- func: cadence::quantized_fully_connected.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::reference::quantized_fully_connected_per_tensor_out
diff --git a/backends/cadence/reference/operators/CMakeLists.txt b/backends/cadence/reference/operators/CMakeLists.txt
index ce926d86018..b2c16bf538c 100644
--- a/backends/cadence/reference/operators/CMakeLists.txt
+++ b/backends/cadence/reference/operators/CMakeLists.txt
@@ -87,6 +87,7 @@ add_library(
   "quantized_relu_out.cpp"
   "quantized_layer_norm.cpp"
   "quantize_per_tensor.cpp"
+  "quantized_fully_connected_out.cpp"
   "dequantize_per_tensor.cpp"
   "quantized_matmul_out.cpp"
   "im2row_out.cpp"
diff --git a/backends/cadence/reference/operators/quantized_fully_connected_out.cpp b/backends/cadence/reference/operators/quantized_fully_connected_out.cpp
new file mode 100644
index 00000000000..77b7dd94e9d
--- /dev/null
+++ b/backends/cadence/reference/operators/quantized_fully_connected_out.cpp
@@ -0,0 +1,97 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// Include paths are assumed: quantized_ops.h is expected to provide the
+// quantized_linear_<T> / quantized_linear_per_tensor_<T> templates and the
+// ET_FORALL_CADENCE_QUANTIZED_TYPES macro.
+#include <executorch/backends/cadence/reference/operators/operators.h>
+#include <executorch/backends/cadence/reference/operators/quantized_ops.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace impl {
+namespace reference {
+namespace native {
+
+using ::executorch::aten::optional;
+using ::executorch::aten::ScalarType;
+using ::executorch::aten::Tensor;
+using ::executorch::runtime::KernelRuntimeContext;
+
+// Per-channel variant: the weight zero point and the requantization
+// multiplier/shift parameters arrive as tensors.
+void quantized_fully_connected_out(
+    __ET_UNUSED KernelRuntimeContext& ctx,
+    const Tensor& in,
+    const Tensor& weight,
+    const Tensor& bias,
+    int64_t in_zero_point,
+    const Tensor& weight_zero_point_t,
+    const Tensor& out_multiplier,
+    const Tensor& out_shift,
+    int64_t out_zero_point,
+    __ET_UNUSED const optional<Tensor>& offset,
+    Tensor& out) {
+// Dispatch on the output dtype to the matching quantized_linear_
+// instantiation.
+#define typed_quantized_linear(ctype, dtype) \
+  case ScalarType::dtype: {                  \
+    quantized_linear_<ctype>(                \
+        in,                                  \
+        weight,                              \
+        bias,                                \
+        in_zero_point,                       \
+        weight_zero_point_t,                 \
+        out_multiplier,                      \
+        out_shift,                           \
+        out_zero_point,                      \
+        out);                                \
+    break;                                   \
+  }
+
+  ScalarType dtype = out.scalar_type();
+  switch (dtype) {
+    ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_linear);
+    default:
+      ET_DCHECK_MSG(
+          false, "Unhandled dtype %s", torch::executor::toString(dtype));
+  }
+#undef typed_quantized_linear
+}
+
+// Per-tensor variant: a single zero point and a single multiplier/shift pair
+// apply to the whole weight tensor.
+void quantized_fully_connected_per_tensor_out(
+    __ET_UNUSED KernelRuntimeContext& ctx,
+    const Tensor& in,
+    const Tensor& weight,
+    const Tensor& bias,
+    int64_t in_zero_point,
+    int64_t weight_zero_point,
+    int64_t out_multiplier,
+    int64_t out_shift,
+    int64_t out_zero_point,
+    __ET_UNUSED const optional<Tensor>& offset,
+    Tensor& out) {
+#define typed_quantized_linear(ctype, dtype) \
+  case ScalarType::dtype: {                  \
+    quantized_linear_per_tensor_<ctype>(     \
+        in,                                  \
+        weight,                              \
+        bias,                                \
+        in_zero_point,                       \
+        weight_zero_point,                   \
+        out_multiplier,                      \
+        out_shift,                           \
+        out_zero_point,                      \
+        out);                                \
+    break;                                   \
+  }
+
+  ScalarType dtype = out.scalar_type();
+  switch (dtype) {
+    ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_linear);
+    default:
+      ET_DCHECK_MSG(
+          false, "Unhandled dtype %s", torch::executor::toString(dtype));
+  }
+#undef typed_quantized_linear
+}
+
+}; // namespace native
+}; // namespace reference
+}; // namespace impl
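
Reviewer note: both kernels dispatch to `quantized_linear_` / `quantized_linear_per_tensor_` templates that are outside this diff. As a reference for what the per-tensor path computes, here is a minimal standalone sketch under stated assumptions: the helper names (`requantize`, `quantized_fc_per_tensor`) are hypothetical, and the actual Cadence template may use a different rounding/shift convention.

```cpp
#include <algorithm>
#include <cstdint>

// Fixed-point requantization: out_multiplier encodes the combined scale
// (in_scale * weight_scale / out_scale) as a Q31 value, and out_shift applies
// an extra power-of-two scaling. The shift sign convention and the truncating
// (non-rounding) arithmetic here are simplifying assumptions.
inline int32_t requantize(
    int32_t acc,
    int32_t out_multiplier,
    int32_t out_shift,
    int32_t out_zero_point) {
  int64_t v = static_cast<int64_t>(acc) * out_multiplier; // Q31 multiply
  v >>= 31;
  v = (out_shift >= 0) ? (v << out_shift) : (v >> -out_shift);
  return static_cast<int32_t>(v) + out_zero_point;
}

// Computes out[m, n] = clamp(requantize(acc(m, n))), where
//   acc(m, n) = bias[n] + sum_k (in[m,k] - in_zp) * (weight[n,k] - weight_zp)
// for row-major in [M, K], weight [N, K], out [M, N], with int32 accumulation.
void quantized_fc_per_tensor(
    const int8_t* in,
    const int8_t* weight,
    const int32_t* bias,
    int64_t M,
    int64_t N,
    int64_t K,
    int32_t in_zero_point,
    int32_t weight_zero_point,
    int32_t out_multiplier,
    int32_t out_shift,
    int32_t out_zero_point,
    int8_t* out) {
  for (int64_t m = 0; m < M; ++m) {
    for (int64_t n = 0; n < N; ++n) {
      int32_t acc = bias[n];
      for (int64_t k = 0; k < K; ++k) {
        acc += (static_cast<int32_t>(in[m * K + k]) - in_zero_point) *
            (static_cast<int32_t>(weight[n * K + k]) - weight_zero_point);
      }
      const int32_t q =
          requantize(acc, out_multiplier, out_shift, out_zero_point);
      out[m * N + n] = static_cast<int8_t>(std::clamp(q, -128, 127));
    }
  }
}
```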
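The Tensor-typed `quantized_fully_connected.out` overload plausibly carries per-output-channel quantization parameters; that reading of the schema is an assumption, not confirmed by this diff. A sketch of how the inner loop would differ, reusing the hypothetical `requantize` helper from the sketch above:

```cpp
// Per-channel sketch: weight_zero_point, out_multiplier, and out_shift are
// assumed to be length-N arrays indexed by output channel n.
void quantized_fc_per_channel(
    const int8_t* in,
    const int8_t* weight,
    const int32_t* bias,
    int64_t M,
    int64_t N,
    int64_t K,
    int32_t in_zero_point,
    const int32_t* weight_zero_point, // [N]
    const int32_t* out_multiplier, // [N]
    const int32_t* out_shift, // [N]
    int32_t out_zero_point,
    int8_t* out) {
  for (int64_t m = 0; m < M; ++m) {
    for (int64_t n = 0; n < N; ++n) {
      int32_t acc = bias[n];
      for (int64_t k = 0; k < K; ++k) {
        acc += (static_cast<int32_t>(in[m * K + k]) - in_zero_point) *
            (static_cast<int32_t>(weight[n * K + k]) - weight_zero_point[n]);
      }
      const int32_t q =
          requantize(acc, out_multiplier[n], out_shift[n], out_zero_point);
      out[m * N + n] = static_cast<int8_t>(std::clamp(q, -128, 127));
    }
  }
}
```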