
#include <c10/util/irange.h>
#include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
+ #include <executorch/kernels/portable/cpu/util/elementwise_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>
#include <cinttypes>
@@ -23,93 +24,6 @@ using ScalarType = executorch::aten::ScalarType;

namespace {

- double exp_overload(double d) {
-   return exp(d);
- }
-
- float exp_overload(float f) {
-   return expf(f);
- }
-
- /**
-  * In-place element-wise sigmoid function, i.e., f(x) = 1 / (1 + e^{-x})
-  */
- // TODO: T146333648, refactor this as a common helper function
- template <typename CTYPE_OUT>
- void sigmoid_tensor(Tensor& out) {
-   CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
-   for (const auto i : c10::irange(out.numel())) {
-     out_data[i] = 1.0 / (1.0 + exp_overload(-out_data[i]));
-   }
- }
-
- /**
-  * Element-wise multiplication of the first half of `in` along the specified
-  * dimension and `out`, overwriting `out`.
-  */
- template <typename CTYPE_IN, typename CTYPE_OUT>
- void mul_tensors(const Tensor& in, int64_t dim, Tensor& out) {
-   size_t num_values = static_cast<size_t>(in.size(dim)) / 2;
-   size_t dim_length_in = static_cast<size_t>(in.size(dim));
-   size_t dim_length_out = static_cast<size_t>(out.size(dim));
-   size_t leading_dims = getLeadingDims(in, dim);
-   size_t trailing_dims = getTrailingDims(in, dim);
-
-   const CTYPE_IN* input_data_base = in.const_data_ptr<CTYPE_IN>();
-   CTYPE_OUT* output_data_base = out.mutable_data_ptr<CTYPE_OUT>();
-
-   for (const auto i : c10::irange(leading_dims)) {
-     const CTYPE_IN* input_data =
-         input_data_base + i * dim_length_in * trailing_dims;
-     CTYPE_OUT* output_data =
-         output_data_base + i * dim_length_out * trailing_dims;
-     for ([[maybe_unused]] const auto j : c10::irange(num_values)) {
-       for (const auto k : c10::irange(trailing_dims)) {
-         output_data[k] = static_cast<CTYPE_OUT>(input_data[k]) * output_data[k];
-       }
-       input_data += trailing_dims;
-       output_data += trailing_dims;
-     }
-   }
- }
-
- /**
-  * Slice the tensor in the given dim, from start to end. Assumes tensors in
-  * and out have the same shape and dtype, that dim is non-negative, and that
-  * start and end are valid non-negative numbers.
-  */
- template <typename CTYPE_IN, typename CTYPE_OUT>
- void slice_tensor(
-     const Tensor& in,
-     int64_t dim,
-     int64_t start,
-     int64_t end,
-     Tensor& out) {
-   size_t num_values = static_cast<size_t>(end - start);
-   size_t dim_length_in = static_cast<size_t>(in.size(dim));
-   size_t dim_length_out = static_cast<size_t>(out.size(dim));
-   size_t non_negative_start = static_cast<size_t>(start);
-   size_t leading_dims = getLeadingDims(in, dim);
-   size_t trailing_dims = getTrailingDims(in, dim);
-
-   const CTYPE_IN* input_data_base = in.const_data_ptr<CTYPE_IN>();
-   CTYPE_OUT* output_data_base = out.mutable_data_ptr<CTYPE_OUT>();
-
-   for (const auto i : c10::irange(leading_dims)) {
-     const CTYPE_IN* input_data = input_data_base +
-         (i * dim_length_in + non_negative_start) * trailing_dims;
-     CTYPE_OUT* output_data =
-         output_data_base + i * dim_length_out * trailing_dims;
-     for ([[maybe_unused]] const auto j : c10::irange(num_values)) {
-       for (const auto k : c10::irange(trailing_dims)) {
-         output_data[k] = static_cast<CTYPE_OUT>(input_data[k]);
-       }
-       input_data += trailing_dims;
-       output_data += trailing_dims;
-     }
-   }
- }
-
/**
 * Applies the gated linear unit function
 *
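Taken together, the three removed helpers computed GLU directly: slice_tensor copied the second half of the input into out, sigmoid_tensor applied the logistic function in place, and mul_tensors then multiplied by the first half. A minimal standalone sketch of that composition for a contiguous 1-D input (hypothetical reference code, not part of this diff):

#include <cmath>
#include <cstddef>

// Reference sketch only: for a 1-D input of even length n, GLU splits it
// into a = in[0..n/2) and gate b = in[n/2..n), then writes
// out[i] = a[i] * sigmoid(b[i]).
void glu_1d_reference(const float* in, std::size_t n, float* out) {
  const std::size_t half = n / 2;
  for (std::size_t i = 0; i < half; ++i) {
    const float gate = 1.0f / (1.0f + std::exp(-in[half + i]));
    out[i] = in[i] * gate;
  }
}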
@@ -120,11 +34,63 @@ void slice_tensor(
 * 2. The output shall be in float types (Float, Double)
 */
template <typename CTYPE_IN, typename CTYPE_OUT>
- Tensor& glu_out_tensor(const Tensor& self, int64_t dim, Tensor& out) {
+ Tensor& glu_out_tensor(
+     KernelRuntimeContext& ctx,
+     const Tensor& self,
+     int64_t dim,
+     Tensor& out) {
  const auto self_size = self.size(dim);
-   slice_tensor<CTYPE_IN, CTYPE_OUT>(self, dim, self_size / 2, self_size, out);
-   sigmoid_tensor<CTYPE_OUT>(out);
-   mul_tensors<CTYPE_IN, CTYPE_OUT>(self, dim, out);
+   ET_KERNEL_CHECK(
+       ctx,
+       self.dim() <= static_cast<ssize_t>(kTensorDimensionLimit),
+       InvalidArgument,
+       out);
+   std::array<executorch::aten::SizesType, kTensorDimensionLimit> half_sizes;
+   std::copy(self.sizes().begin(), self.sizes().end(), half_sizes.begin());
+   half_sizes[dim] /= 2;
+   TensorImpl first_half_impl(
+       self.scalar_type(),
+       self.dim(),
+       half_sizes.data(),
+       self.mutable_data_ptr(),
+       const_cast<executorch::aten::DimOrderType*>(self.dim_order().data()),
+       const_cast<executorch::aten::StridesType*>(self.strides().data()),
+       self.shape_dynamism());
+   TensorImpl second_half_impl(
+       self.scalar_type(),
+       self.dim(),
+       half_sizes.data(),
+       reinterpret_cast<char*>(self.mutable_data_ptr()) +
+           self.strides()[dim] * self_size / 2 * self.element_size(),
+       const_cast<executorch::aten::DimOrderType*>(self.dim_order().data()),
+       const_cast<executorch::aten::StridesType*>(self.strides().data()),
+       self.shape_dynamism());
+   Tensor first_half(&first_half_impl);
+   Tensor second_half(&second_half_impl);
+   ScalarType compute_type =
+       executorch::runtime::isFloatingType(self.scalar_type())
+       ? self.scalar_type()
+       : ScalarType::Float;
+   // @lint-ignore CLANGTIDY facebook-hte-CArray
+   static constexpr const char op_name[] = "glu.out";
+   ET_SWITCH_FLOATHBF16_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+     utils::apply_bitensor_elementwise_fn<
+         CTYPE_COMPUTE,
+         op_name,
+         utils::SupportedTensorDtypes::FLOATHBF16>(
+         [](const auto val_a, const auto val_b) -> CTYPE_COMPUTE {
+           // TODO: rewrite this to be vectorization-capable.
+           const auto one = static_cast<decltype(val_a)>(1.0);
+           return val_a * (one / (one + std::exp(-val_b)));
+         },
+         ctx,
+         first_half,
+         utils::SupportedTensorDtypes::FLOATHBF16,
+         second_half,
+         utils::SupportedTensorDtypes::FLOATHBF16,
+         out,
+         utils::internal::SupportNoncontiguousTensors());
+   });
  return out;
}
} // namespace
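The key move in the new body is constructing two TensorImpl views that alias the input's storage: half_sizes[dim] is halved, and the second view's base pointer is advanced by strides()[dim] * self_size / 2 elements. A small sketch of that byte-offset arithmetic (hypothetical helper, assuming strides are expressed in elements as shown in the code above):

#include <cstddef>
#include <cstdint>

// Hypothetical illustration of the offset used for second_half_impl:
// advance the base pointer by (sizes[dim] / 2) steps along `dim`, where
// one step is strides[dim] elements of element_size bytes each.
inline char* second_half_base(
    char* base,
    const int64_t* sizes,
    const int64_t* strides,
    std::size_t dim,
    std::size_t element_size) {
  const int64_t offset_elems = strides[dim] * (sizes[dim] / 2);
  return base + offset_elems * element_size;
}

// Example: a contiguous float tensor of shape [4, 6] split along dim 1 has
// strides {6, 1}, so the gate half starts 1 * 3 * 4 = 12 bytes past base.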
@@ -158,7 +124,7 @@ Tensor& glu_out(

  ET_SWITCH_FLOATHBF16_TYPES(in_dtype, ctx, "glu", CTYPE_IN, [&]() {
    ET_SWITCH_FLOATHBF16_TYPES(out.scalar_type(), ctx, "glu", CTYPE_OUT, [&]() {
-       glu_out_tensor<CTYPE_IN, CTYPE_OUT>(self, non_negative_dim, out);
+       glu_out_tensor<CTYPE_IN, CTYPE_OUT>(ctx, self, non_negative_dim, out);
    });
  });

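As a quick sanity check on the semantics that both the old and new paths implement, a self-contained numeric example (plain C++, no ExecuTorch dependencies):

#include <cmath>
#include <cstdio>

int main() {
  // GLU over a length-4 input along its only dimension:
  // linear half a = {1, 2}, gate half b = {0, 0}.
  // sigmoid(0) = 0.5, so the expected output is {0.5, 1.0}.
  const float in[4] = {1.0f, 2.0f, 0.0f, 0.0f};
  float out[2];
  for (int i = 0; i < 2; ++i) {
    out[i] = in[i] / (1.0f + std::exp(-in[2 + i]));  // a * sigmoid(b)
  }
  std::printf("%f %f\n", out[0], out[1]);  // prints 0.500000 1.000000
  return 0;
}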