@@ -33,7 +33,8 @@ namespace executorch::cpublas::internal {
 constexpr auto kF32RegisterPairsPerIteration = 4;
 constexpr auto kF32RegistersPerIteration = kF32RegisterPairsPerIteration * 2;
 constexpr auto kF32ElementsPerRegister = vec::Vectorized<float>::size();
-constexpr auto kF32ElementsPerIteration = kF32RegistersPerIteration * kF32ElementsPerRegister;
+constexpr auto kF32ElementsPerIteration =
+    kF32RegistersPerIteration * kF32ElementsPerRegister;
 
 namespace {
 template <typename T>
@@ -58,8 +59,8 @@ constexpr int IntegerLog2(T n, int p = 0) {
  * copies of the Software, and to permit persons to whom the Software is
  * furnished to do so, subject to the following conditions:
  *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
  *
  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
@@ -74,9 +75,7 @@ float reduce(vec::Vectorized<float> x) {
 #if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE)
   return vaddvq_f32(x);
 #else
-  return vec::vec_reduce_all<float>(
-      std::plus<vec::Vectorized<float>>(),
-      x);
+  return vec::vec_reduce_all<float>(std::plus<vec::Vectorized<float>>(), x);
 #endif
 }
 
@@ -86,12 +85,13 @@ float reduce(vec::Vectorized<float> x) {
 // required notice.
 float reduce(vec::VectorizedN<float, kF32RegistersPerIteration>& x) {
   int offset = kF32RegistersPerIteration;
-  c10::ForcedUnroll<IntegerLog2(kF32RegistersPerIteration)>{}([&offset, &x](auto idx) {
-    offset /= 2;
-    for (const auto i : c10::irange(offset)) {
-      x[i] = x[i] + x[offset + i];
-    }
-  });
+  c10::ForcedUnroll<IntegerLog2(kF32RegistersPerIteration)>{}(
+      [&offset, &x](auto idx) {
+        offset /= 2;
+        for (const auto i : c10::irange(offset)) {
+          x[i] = x[i] + x[offset + i];
+        }
+      });
   return reduce(x[0]);
 }
 
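The hunk above reshapes a log2-depth tree reduction: each unrolled step halves `offset` and folds the upper half of the accumulator registers into the lower half. For reference, here is a scalar sketch of that pattern (illustrative only; a plain `std::array<float, N>` stands in for `vec::VectorizedN<float, kF32RegistersPerIteration>`):

```cpp
// Scalar sketch of the tree reduction above; not the PR's code.
// Each pass halves the active range and adds the upper half into the lower
// half, so a power-of-two register count N is fully reduced in log2(N) passes.
#include <array>
#include <cstddef>

template <std::size_t N>
float tree_reduce(std::array<float, N> x) {
  static_assert(N != 0 && (N & (N - 1)) == 0, "N must be a power of two");
  std::size_t offset = N;
  while (offset > 1) {
    offset /= 2;
    for (std::size_t i = 0; i < offset; ++i) {
      x[i] = x[i] + x[offset + i];
    }
  }
  return x[0];  // the real code finishes with a horizontal reduce of x[0]
}
```

With kF32RegistersPerIteration = 8, this is three passes (8 -> 4 -> 2 -> 1), which is exactly what `ForcedUnroll<IntegerLog2(kF32RegistersPerIteration)>` unrolls at compile time.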
@@ -102,16 +102,20 @@ float reduce(vec::VectorizedN<float, kF32RegistersPerIteration>& x) {
 // We would have to write a separate SVE-specific path to use SVE
 // BFDOT. Deferring that for now to get the NEON/ASIMD BFDOT path
 // working.
-#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && defined(__clang__) && __clang_major__ > 15
+#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && \
+    defined(__clang__) && __clang_major__ > 15
 // https://godbolt.org/z/z8P4Yncra
 #define COMPILER_SUPPORTS_BF16_TARGET 1
-#elif defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && !defined(__clang__) && defined(__GNUC__) && __GNUC__ >= 10
+#elif defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && \
+    !defined(__clang__) && defined(__GNUC__) && __GNUC__ >= 10
 // https://gcc.gnu.org/gcc-10/changes.html
 // https://godbolt.org/z/cdGG7vn8o
 #define COMPILER_SUPPORTS_BF16_TARGET 1
-#else // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && defined(__clang__) && __clang_major__ > 15
+#else // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) &&
+      // defined(__clang__) && __clang_major__ > 15
 #define COMPILER_SUPPORTS_BF16_TARGET 0
-#endif // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) && defined(__clang__) && __clang_major__ > 15
+#endif // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) &&
+       // defined(__clang__) && __clang_major__ > 15
 
 #if COMPILER_SUPPORTS_BF16_TARGET
 #define TARGET_ARM_BF16_ATTRIBUTE __attribute__((target("arch=armv8.2-a+bf16")))
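These macros gate use of `__attribute__((target("arch=armv8.2-a+bf16")))`, which lets one translation unit carry a BF16 kernel even when the baseline target is plain ARMv8; the caller still needs a runtime CPU check (the PR uses `cpuinfo_has_arm_bf16()` further down). A minimal sketch of that compile-time/runtime split, with placeholder names that are not part of the PR:

```cpp
// Minimal sketch of "compile one function for a newer target, dispatch at
// runtime". Placeholder names throughout; aarch64-only, and assumes a
// compiler new enough to accept the bf16 target attribute (which is what
// COMPILER_SUPPORTS_BF16_TARGET probes for above).
#if defined(__aarch64__)
#define SKETCH_BF16_ATTRIBUTE __attribute__((target("arch=armv8.2-a+bf16")))

SKETCH_BF16_ATTRIBUTE float dot_bf16_capable(const float* a, const float* b, int n) {
  float s = 0.f;
  for (int i = 0; i < n; ++i) s += a[i] * b[i];  // may use armv8.2-a+bf16 codegen
  return s;
}

float dot_baseline(const float* a, const float* b, int n) {
  float s = 0.f;
  for (int i = 0; i < n; ++i) s += a[i] * b[i];
  return s;
}

bool cpu_has_bf16() { return false; }  // stub; the PR queries cpuinfo_has_arm_bf16()

float dot_dispatch(const float* a, const float* b, int n) {
  // The attribute only changes code generation inside dot_bf16_capable;
  // running it on a CPU without BF16 support could still trap, hence the gate.
  return cpu_has_bf16() ? dot_bf16_capable(a, b, n) : dot_baseline(a, b, n);
}
#endif // defined(__aarch64__)
```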
@@ -128,25 +132,25 @@ dot_with_fp32_arith_main_inner_loop_bfdot(
   // bfloat16x8_t. I suspect a bug or incomplete
   // __attribute__((target)) implementation. Intrinsics should be fine
   // because we're using vbfdotq_f32 below anyway.
-  const auto temp_vec1 = vld1q_bf16(
-      reinterpret_cast<const bfloat16_t*>(
-          &vec1[registerPairIndex * vec::Vectorized<BFloat16>::size()]));
-  const auto temp_vec2 = vld1q_bf16(
-      reinterpret_cast<const bfloat16_t*>(
-          &vec2[registerPairIndex * vec::Vectorized<BFloat16>::size()]));
+  const auto temp_vec1 = vld1q_bf16(reinterpret_cast<const bfloat16_t*>(
+      &vec1[registerPairIndex * vec::Vectorized<BFloat16>::size()]));
+  const auto temp_vec2 = vld1q_bf16(reinterpret_cast<const bfloat16_t*>(
+      &vec2[registerPairIndex * vec::Vectorized<BFloat16>::size()]));
   sum[registerPairIndex] =
-    vbfdotq_f32(sum[registerPairIndex], temp_vec1, temp_vec2);
+      vbfdotq_f32(sum[registerPairIndex], temp_vec1, temp_vec2);
 }
 
-TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE
-void dot_with_fp32_arith_vectorized_tail_inner_loop_bfdot(
+TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void
+dot_with_fp32_arith_vectorized_tail_inner_loop_bfdot(
     const at::BFloat16* vec1,
     const at::BFloat16* vec2,
     vec::Vectorized<float>* tail_sum,
     int idx) {
   // See NOTE[Intrinsics in bfdot variant] above.
-  const auto temp_vec1 = vld1q_bf16(reinterpret_cast<const bfloat16_t*>(&vec1[idx]));
-  const auto temp_vec2 = vld1q_bf16(reinterpret_cast<const bfloat16_t*>(&vec2[idx]));
+  const auto temp_vec1 =
+      vld1q_bf16(reinterpret_cast<const bfloat16_t*>(&vec1[idx]));
+  const auto temp_vec2 =
+      vld1q_bf16(reinterpret_cast<const bfloat16_t*>(&vec2[idx]));
   *tail_sum = vbfdotq_f32(*tail_sum, temp_vec1, temp_vec2);
 }
 
@@ -156,14 +160,17 @@ void dot_with_fp32_arith_vectorized_tail_inner_loop_bfdot(
 
 namespace {
 
-[[maybe_unused]] std::pair<vec::Vectorized<float>, vec::Vectorized<float>> fmadd(
+[[maybe_unused]] std::pair<vec::Vectorized<float>, vec::Vectorized<float>>
+fmadd(
     const vec::Vectorized<c10::BFloat16>& a,
     const vec::Vectorized<c10::BFloat16>& b,
     const vec::Vectorized<float>& acc_low,
     const vec::Vectorized<float>& acc_high) {
   const auto [a_float_low, a_float_high] = convert_bfloat16_float(a);
   const auto [b_float_low, b_float_high] = convert_bfloat16_float(b);
-  return std::make_pair(fmadd(a_float_low, b_float_low, acc_low), fmadd(a_float_high, b_float_high, acc_high));
+  return std::make_pair(
+      fmadd(a_float_low, b_float_low, acc_low),
+      fmadd(a_float_high, b_float_high, acc_high));
 }
 
 [[maybe_unused]] vec::Vectorized<float> fmadd(
@@ -172,21 +179,28 @@ namespace {
     const vec::Vectorized<c10::BFloat16>& b) {
   const auto [a_float_low, a_float_high] = convert_bfloat16_float(a);
   const auto [b_float_low, b_float_high] = convert_bfloat16_float(b);
-  return fmadd(a_float_high, b_float_high, fmadd(a_float_low, b_float_low, acc));
+  return fmadd(
+      a_float_high, b_float_high, fmadd(a_float_low, b_float_low, acc));
 }
 } // namespace
 
 template <typename T>
 C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop_no_bfdot(
-  const T* vec1,
-  const T* vec2,
-  vec::VectorizedN<float, kF32RegistersPerIteration>& sum,
-  int registerPairIndex) {
+    const T* vec1,
+    const T* vec2,
+    vec::VectorizedN<float, kF32RegistersPerIteration>& sum,
+    int registerPairIndex) {
   static_assert(std::is_same_v<T, BFloat16>);
-  const auto temp_vec1 = vec::Vectorized<T>::loadu(&vec1[registerPairIndex * vec::Vectorized<T>::size()]);
-  const auto temp_vec2 = vec::Vectorized<T>::loadu(&vec2[registerPairIndex * vec::Vectorized<T>::size()]);
-
-  const auto [result_low, result_high] = fmadd(temp_vec1, temp_vec2, sum[2 * registerPairIndex], sum[2 * registerPairIndex + 1]);
+  const auto temp_vec1 = vec::Vectorized<T>::loadu(
+      &vec1[registerPairIndex * vec::Vectorized<T>::size()]);
+  const auto temp_vec2 = vec::Vectorized<T>::loadu(
+      &vec2[registerPairIndex * vec::Vectorized<T>::size()]);
+
+  const auto [result_low, result_high] = fmadd(
+      temp_vec1,
+      temp_vec2,
+      sum[2 * registerPairIndex],
+      sum[2 * registerPairIndex + 1]);
   sum[2 * registerPairIndex] = result_low;
   sum[2 * registerPairIndex + 1] = result_high;
 }
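The inner loop above relies on a fixed accumulator layout: each `Vectorized<BFloat16>` load holds twice as many elements as a `Vectorized<float>`, so register pair `k` widens into a low and a high float half that feed accumulators `2 * k` and `2 * k + 1`. A rough sketch of that layout with plain arrays (the lane count `W` and the helper below are assumptions for illustration, not the PR's code):

```cpp
// Layout sketch only, not the PR's code. W stands in for
// vec::Vectorized<float>::size() and kPairs for kF32RegisterPairsPerIteration.
#include <array>
#include <cstddef>

constexpr std::size_t W = 4;       // assumed float lanes per vector register
constexpr std::size_t kPairs = 4;  // bf16 loads per fully unrolled iteration

// 2 * kPairs float accumulators, mirroring
// vec::VectorizedN<float, kF32RegistersPerIteration>.
using Acc = std::array<std::array<float, W>, 2 * kPairs>;

// products holds the 2*W already-widened element-wise products of load k:
// the first W go to the "low" accumulator, the next W to the "high" one.
void accumulate_pair(Acc& sum, const std::array<float, 2 * W>& products, std::size_t k) {
  for (std::size_t i = 0; i < W; ++i) {
    sum[2 * k][i] += products[i];          // fmadd(a_low, b_low, acc_low)
    sum[2 * k + 1][i] += products[W + i];  // fmadd(a_high, b_high, acc_high)
  }
}
```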
@@ -203,19 +217,19 @@ C10_ALWAYS_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop_no_bfdot(
 }
 
 template <typename T>
-C10_ALWAYS_INLINE auto
-dot_with_fp32_arith_main_loop_no_bfdot(
+C10_ALWAYS_INLINE auto dot_with_fp32_arith_main_loop_no_bfdot(
     const T* vec1,
     const T* vec2,
     int64_t len) {
   vec::VectorizedN<float, kF32RegistersPerIteration> sum(0);
   const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
-  for (int j = 0; j < len_aligned ; j += kF32ElementsPerIteration) {
+  for (int j = 0; j < len_aligned; j += kF32ElementsPerIteration) {
     const auto* vec1_ = vec1 + j;
     const auto* vec2_ = vec2 + j;
-    c10::ForcedUnroll<kF32RegisterPairsPerIteration>{}([vec1_, vec2_, &sum](auto k) C10_ALWAYS_INLINE_ATTRIBUTE {
-      dot_with_fp32_arith_main_inner_loop_no_bfdot(vec1_, vec2_, sum, k);
-    });
+    c10::ForcedUnroll<kF32RegisterPairsPerIteration>{}(
+        [vec1_, vec2_, &sum](auto k) C10_ALWAYS_INLINE_ATTRIBUTE {
+          dot_with_fp32_arith_main_inner_loop_no_bfdot(vec1_, vec2_, sum, k);
+        });
   }
   return reduce(sum);
 }
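The loop bound above, `len & ~(kF32ElementsPerIteration - 1)`, rounds `len` down to a whole number of fully unrolled iterations; everything past that point is left to the two-tier tail handling defined later in the diff. The trick only works because the block size is a power of two, which is also what the `static_assert` further down checks for the vector register size. A small illustration (not from the PR):

```cpp
// Illustration of the round-down mask used for the main-loop bound.
// For a power-of-two block size B, (B - 1) has only the low bits set, so
// len & ~(B - 1) clears them, yielding the largest multiple of B <= len.
#include <cstdint>

constexpr int64_t round_down_to_block(int64_t len, int64_t block) {
  return len & ~(block - 1);  // valid only when block is a power of two
}

static_assert(round_down_to_block(100, 32) == 96);
static_assert(round_down_to_block(31, 32) == 0);
static_assert(round_down_to_block(64, 32) == 64);
```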
@@ -224,7 +238,8 @@ dot_with_fp32_arith_main_loop_no_bfdot(
 template <int n>
 struct ForcedUnrollTargetBFloat16 {
   template <typename Func>
-  TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void operator()(const Func& f) const {
+  TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void operator()(
+      const Func& f) const {
     ForcedUnrollTargetBFloat16<n - 1>{}(f);
     f(n - 1);
   }
@@ -233,7 +248,8 @@ struct ForcedUnrollTargetBFloat16 {
 template <>
 struct ForcedUnrollTargetBFloat16<1> {
   template <typename Func>
-  TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void operator()(const Func& f) const {
+  TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void operator()(
+      const Func& f) const {
     f(0);
   }
 };
@@ -245,20 +261,22 @@ dot_with_fp32_arith_main_loop_bfdot(
     int64_t len) {
   vec::VectorizedN<float, kF32RegistersPerIteration> sum(0);
   const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
-  for (int j = 0; j < len_aligned ; j += kF32ElementsPerIteration) {
+  for (int j = 0; j < len_aligned; j += kF32ElementsPerIteration) {
     const auto* vec1_ = vec1 + j;
     const auto* vec2_ = vec2 + j;
-    ForcedUnrollTargetBFloat16<kF32RegisterPairsPerIteration>{}([vec1_, vec2_, &sum](auto k)
-        C10_ALWAYS_INLINE_ATTRIBUTE TARGET_ARM_BF16_ATTRIBUTE {
-      dot_with_fp32_arith_main_inner_loop_bfdot(vec1_, vec2_, sum, k);
-    });
+    ForcedUnrollTargetBFloat16<kF32RegisterPairsPerIteration>{}(
+        [vec1_, vec2_, &sum](auto k)
+            C10_ALWAYS_INLINE_ATTRIBUTE TARGET_ARM_BF16_ATTRIBUTE {
+              dot_with_fp32_arith_main_inner_loop_bfdot(vec1_, vec2_, sum, k);
+            });
   }
   return reduce(sum);
 }
 #endif // COMPILER_SUPPORTS_BF16_TARGET
 
 static_assert(
-    (vec::Vectorized<BFloat16>::size() & (vec::Vectorized<BFloat16>::size() - 1)) == 0,
+    (vec::Vectorized<BFloat16>::size() &
+     (vec::Vectorized<BFloat16>::size() - 1)) == 0,
     "Below code expects power-of-2 vector register size!");
 
 // NOTE [GCC code duplication]: The first attempt at landing BFDOT support with
@@ -267,31 +285,35 @@ static_assert(
 // function. We can work around this by duplicating the code into the
 // bfdot and non-bfdot callsites. The code is in this macro to avoid
 // actual copy/paste.
-#define DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY(bfdot_suffix) \
-  /* First-tier tail fixup: make sure we handle workloads that can */ \
-  /* benefit from vectorization, but don't fit into our fully unrolled */ \
-  /* loop above. */ \
-  vec::Vectorized<float> tail_sum(0); \
-  const auto len_aligned = len & ~(kF32ElementsPerIteration - 1); \
+#define DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY(bfdot_suffix) \
+  /* First-tier tail fixup: make sure we handle workloads that can */ \
+  /* benefit from vectorization, but don't fit into our fully unrolled */ \
+  /* loop above. */ \
+  vec::Vectorized<float> tail_sum(0); \
+  const auto len_aligned = len & ~(kF32ElementsPerIteration - 1); \
   const auto len_aligned_vec = len & ~(vec::Vectorized<BFloat16>::size() - 1); \
-  for (int j = len_aligned; j < len_aligned_vec; j += vec::Vectorized<BFloat16>::size()) { \
-    dot_with_fp32_arith_vectorized_tail_inner_loop##bfdot_suffix(vec1, vec2, &tail_sum, j); \
-  } \
-  reduced_sum += reduce(tail_sum); \
-  \
-  /* Second-tier tail fixup: handle all workloads. */ \
-  for (const auto j : c10::irange(len_aligned_vec, len)) { \
-    /* Attempting to use Half here caused multiple test failures; */ \
-    /* using float to unbreak. (Suspect we need a scalar FMA.) */ \
-    float x1 = vec1[j]; \
-    float x2 = vec2[j]; \
-    reduced_sum += x1 * x2; \
-  } \
+  for (int j = len_aligned; j < len_aligned_vec; \
+       j += vec::Vectorized<BFloat16>::size()) { \
+    dot_with_fp32_arith_vectorized_tail_inner_loop##bfdot_suffix( \
+        vec1, vec2, &tail_sum, j); \
+  } \
+  reduced_sum += reduce(tail_sum); \
+  \
+  /* Second-tier tail fixup: handle all workloads. */ \
+  for (const auto j : c10::irange(len_aligned_vec, len)) { \
+    /* Attempting to use Half here caused multiple test failures; */ \
+    /* using float to unbreak. (Suspect we need a scalar FMA.) */ \
+    float x1 = vec1[j]; \
+    float x2 = vec2[j]; \
+    reduced_sum += x1 * x2; \
+  } \
   return reduced_sum
 
 #if COMPILER_SUPPORTS_BF16_TARGET
-TARGET_ARM_BF16_ATTRIBUTE float
-dot_with_fp32_arith_bfdot(const BFloat16* vec1, const BFloat16* vec2, int64_t len) {
+TARGET_ARM_BF16_ATTRIBUTE float dot_with_fp32_arith_bfdot(
+    const BFloat16* vec1,
+    const BFloat16* vec2,
+    int64_t len) {
   auto reduced_sum = dot_with_fp32_arith_main_loop_bfdot(vec1, vec2, len);
   DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY(_bfdot);
 }
@@ -307,7 +329,10 @@ dot_with_fp32_arith_no_bfdot(const T* vec1, const T* vec2, int64_t len) {
 
 } // namespace
 
-float bf16_dot_with_fp32_arith(const at::BFloat16* vec1, const at::BFloat16* vec2, int64_t len) {
+float bf16_dot_with_fp32_arith(
+    const at::BFloat16* vec1,
+    const at::BFloat16* vec2,
+    int64_t len) {
 #if COMPILER_SUPPORTS_BF16_TARGET
   if (cpuinfo_has_arm_bf16()) {
     return dot_with_fp32_arith_bfdot(vec1, vec2, len);
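To close, a hedged scalar model of what `bf16_dot_with_fp32_arith` computes on either path (a sketch written for this note, not the PR's code; `bf16_bits_to_float` is a local helper, and the vectorized paths accumulate in a different order, so results need not match bit-for-bit): every BFloat16 operand is widened to fp32 and all multiplies and accumulations happen in fp32, which is the "fp32 arith" in the function name.

```cpp
// Scalar reference sketch, not the PR's code: bf16 inputs, fp32 arithmetic.
#include <cstdint>
#include <cstring>

// bfloat16 is the upper 16 bits of an IEEE-754 binary32, so widening is a
// lossless shift into the high half of a float's bit pattern.
static float bf16_bits_to_float(uint16_t bits) {
  uint32_t widened = static_cast<uint32_t>(bits) << 16;
  float out;
  std::memcpy(&out, &widened, sizeof(out));
  return out;
}

float bf16_dot_scalar_reference(const uint16_t* vec1, const uint16_t* vec2, int64_t len) {
  float sum = 0.0f;
  for (int64_t i = 0; i < len; ++i) {
    sum += bf16_bits_to_float(vec1[i]) * bf16_bits_to_float(vec2[i]);
  }
  return sum;
}
```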