 #include <executorch/runtime/kernel/kernel_runtime_context.h>
 #include <executorch/runtime/kernel/thread_parallel_interface.h>

+#ifdef ET_USE_PYTORCH_HEADERS
+#include <ATen/cpu/vec/vec.h>
+#endif // ET_USE_PYTORCH_HEADERS
+
 #include <array>
 #include <utility>

@@ -58,6 +62,38 @@ template <typename CTYPE_COMMON, typename Op, typename... Args>
 using op_call_result =
     std::invoke_result_t<Op, ignore_first_yield_second<Args, CTYPE_COMMON>...>;

+#ifdef ET_USE_PYTORCH_HEADERS
+template <typename T>
+struct is_vectorized : public std::false_type {};
+
+template <typename T>
+struct is_vectorized<at::vec::Vectorized<T>> : public std::true_type {};
+
+// TODO: can_use_vectorized and can_use_vectorized_impl are a failed
+// attempt to use SFINAE to detect whether our generic lambda argument
+// with deduced return type would compile if it was passed
+// Vectorized<CTYPE_COMMON> instead of CTYPE_COMMON. SFINAE does not
+// work that way (see
+// e.g. https://stackoverflow.com/questions/53344484/hard-error-when-using-stdinvoke-result-t-with-a-generic-lambda,
+// https://stackoverflow.com/questions/31368601/how-to-detect-if-a-generic-lambda-is-uncompilable-in-c-14);
+// if we really want to do it then we need to at least require that
+// our lambdas actively participate in being SFINAE-friendly, as in
+// https://stackoverflow.com/questions/76525790/detecting-if-a-generic-lambda-with-certain-arguments-is-invocable.
+template <typename CTYPE_COMMON, typename Op, typename Enable = void, typename... Args>
+struct can_use_vectorized_impl : std::false_type {};
+template <typename CTYPE_COMMON, typename Op, typename... Args>
+struct can_use_vectorized_impl<
+    CTYPE_COMMON,
+    Op,
+    typename std::void_t<decltype(
+        std::declval<std::invoke_result_t<
+            Op,
+            ignore_first_yield_second<Args, at::vec::Vectorized<CTYPE_COMMON>>...>>()
+            .store(std::declval<CTYPE_COMMON*>()))>,
+    Args...> : public std::true_type {};
+// std::bool_constant<is_vectorized<std::invoke_result_t<Op,
+//     ignore_first_yield_second<Args, at::vec::Vectorized<CTYPE_COMMON>>...>>::value> {};
+
+// Can I call a function of type Op with sizeof...(Args) arguments of type
+// at::vec::Vectorized<CTYPE_COMMON>?
+// This is not possible in C++17 as the code is currently set up; see TODO above.
+template <typename CTYPE_COMMON, typename Op, typename... Args>
+struct can_use_vectorized
+    : public can_use_vectorized_impl<CTYPE_COMMON, Op, void, Args...> {};
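+// Hypothetical illustration (example names are not from this change): a
+// generic lambda with a deduced return type such as
+//   auto add_fn = [](const auto& a, const auto& b) { return a + b; };
+// is invocable with at::vec::Vectorized<float> arguments and yields a
+// Vectorized<float>, which provides store(float*), so the partial
+// specialization above should be selected and
+// can_use_vectorized<float, decltype(add_fn), Args...>::value should be
+// true. A lambda that only compiles for scalar arguments fails the
+// void_t detection and falls back to the primary (false_type) template.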
+
+#endif // ET_USE_PYTORCH_HEADERS
+
 template <
     typename CTYPE_COMMON,
     typename CTYPE_OUT,
@@ -68,14 +104,72 @@ inline void dtype_specialized_elementwise_fn_impl(
     KernelRuntimeContext& ctx,
     const Tensor& out,
     Args... inputs) {
+  static_assert(
+      (std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
+       ...));
   constexpr auto kNumInputs = sizeof...(inputs);
-  ET_DCHECK(((inputs.first->element_size() == sizeof(CTYPE_COMMON)) && ...));
+  // All inputs must be of type CTYPE_COMMON.
+  ET_DCHECK(
+      ((inputs.first->scalar_type() ==
+        CppTypeToScalarType<CTYPE_COMMON>::value) &&
+       ...));

   std::array<const CTYPE_COMMON*, kNumInputs> inputs_data_ptrs = {
       inputs.first->template const_data_ptr<CTYPE_COMMON>()...};

   CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();

+#ifdef ET_USE_PYTORCH_HEADERS
+  if constexpr (can_use_vectorized<CTYPE_COMMON, Op, Args...>::value) {
+    const bool any_is_broadcasted =
+        !(torch::executor::internal::sizes_match_ignoring_leading_1s(
+              inputs.first->sizes(), out.sizes()) &&
+          ...);
+    if (!any_is_broadcasted) {
+      using Vec = at::vec::Vectorized<CTYPE_COMMON>;
+      ::executorch::extension::parallel_for(
+          0,
+          out.numel(),
+          ::executorch::extension::internal::GRAIN_SIZE,
+          [&](const auto begin, const auto end) {
+            const auto vectorized_begin =
+                begin + (Vec::size() - begin % Vec::size()) % Vec::size();
+            const auto vectorized_end = end - (end % Vec::size());
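+            // Worked example (hypothetical numbers): with begin = 3,
+            // end = 21 and Vec::size() == 8, vectorized_begin =
+            // 3 + (8 - 3 % 8) % 8 = 8 and vectorized_end = 21 - 21 % 8 = 16,
+            // so indices [3, 8) and [16, 21) take the scalar loops below and
+            // [8, 16) takes the vectorized loop.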
+            // Scalar prologue.
+            for (const auto idx : c10::irange(begin, vectorized_begin)) {
+              std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
+              for (const auto input_idx : c10::irange(kNumInputs)) {
+                loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
+              }
+              data_out[idx] = std::apply(compute_fun, loaded_inputs);
+            }
+
+            // Main vectorized loop.
+            for (auto idx = vectorized_begin; idx < vectorized_end;
+                 idx += Vec::size()) {
+              std::array<Vec, kNumInputs> loaded_vec_inputs;
+              for (const auto input_idx : c10::irange(kNumInputs)) {
+                loaded_vec_inputs[input_idx] =
+                    Vec::loadu(&inputs_data_ptrs[input_idx][idx]);
+              }
+              auto result_vec = std::apply(compute_fun, loaded_vec_inputs);
+              result_vec.store(&data_out[idx]);
+            }
+
+            // Scalar epilogue.
+            for (const auto idx : c10::irange(vectorized_end, end)) {
+              std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
+              for (const auto input_idx : c10::irange(kNumInputs)) {
+                loaded_inputs[input_idx] = inputs_data_ptrs[input_idx][idx];
+              }
+              data_out[idx] = std::apply(compute_fun, loaded_inputs);
+            }
+          });
+      return;
+    }
+  }
+#endif
+
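+  // Scalar fallback below: reached when the vectorized fast path above is
+  // not taken (op not vectorizable, ATen headers unavailable, or an input
+  // is broadcasted against the output shape).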
   ::executorch::extension::parallel_for(
       0,
       out.numel(),