@@ -91,11 +91,6 @@ inline void dtype_specialized_elementwise_fn_impl(
9191 CppTypeToScalarType<CTYPE_COMMON>::value) &&
9292 ...));
9393
94- std::array<const CTYPE_COMMON*, kNumInputs > inputs_data_ptrs = {
95- inputs.first ->template const_data_ptr <CTYPE_COMMON>()...};
96-
97- CTYPE_OUT* const data_out = out.mutable_data_ptr <CTYPE_OUT>();
98-
9994#ifdef ET_USE_PYTORCH_HEADERS
10095 if constexpr (can_use_vectorized<CTYPE_COMMON, Op, Args...>()) {
10196 const bool any_is_broadcasted =
@@ -109,6 +104,11 @@ inline void dtype_specialized_elementwise_fn_impl(
109104 out.numel(),
110105 ::executorch::extension::internal::GRAIN_SIZE,
111106 [&](const auto begin, const auto end) {
107+ std::array<const CTYPE_COMMON*, kNumInputs > inputs_data_ptrs = {
108+ inputs.first ->template const_data_ptr <CTYPE_COMMON>()...};
109+
110+ CTYPE_OUT* const data_out = out.mutable_data_ptr <CTYPE_OUT>();
111+
112112 const auto vectorized_begin =
113113 begin + (Vec::size () - begin % Vec::size ()) % Vec::size ();
114114 const auto vectorized_end = end - (end % Vec::size ());
@@ -152,6 +152,11 @@ inline void dtype_specialized_elementwise_fn_impl(
152152 out.numel(),
153153 ::executorch::extension::internal::GRAIN_SIZE,
154154 [&](const auto begin, const auto end) {
155+ std::array<const CTYPE_COMMON*, kNumInputs > inputs_data_ptrs = {
156+ inputs.first ->template const_data_ptr <CTYPE_COMMON>()...};
157+
158+ CTYPE_OUT* const data_out = out.mutable_data_ptr <CTYPE_OUT>();
159+
155160 const auto range =
156161 BroadcastIndexesRange<kNumInputs >(out, (*inputs.first )...);
157162 auto begin_it = range.begin ();
0 commit comments