@@ -143,13 +143,19 @@ inline __m256i cvt_fp32_to_bf16(const __m512 src)
143
143
return _mm512_cvtusepi32_epi16 (t_value);
144
144
}
145
145
146
- void reduce_2_bf16_buffers_iio (int num_elements, void * in0, void * in1, void * out)
147
- __attribute__((target(" avx512bw" )));
146
+ __m512 cvt_fp16_to_fp32 (const __m256i src) __attribute__((target(" avx512bw" )));
147
+ inline __m512 cvt_fp16_to_fp32 (const __m256i src) { return _mm512_cvtph_ps (src); }
148
+
149
+ inline __m256i cvt_fp32_to_fp16 (const __m512 src) __attribute__((target(" avx512bw" )));
150
+ inline __m256i cvt_fp32_to_fp16 (const __m512 src)
151
+ {
152
+ return _mm512_cvtps_ph (src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
153
+ }
148
154
149
155
void reduce_bf16_buffers (int start_elements, int num_elements, char * to_buffer, char ** buffers)
150
156
__attribute__((target(" avx512bw" )));
151
157
152
- void reduce_2_fp32_buffers_iio (int num_elements, void * in0, void * in1, void * out )
158
+ void reduce_fp16_buffers (int start_elements, int num_elements, char * to_buffer, char ** buffers )
153
159
__attribute__((target(" avx512bw" )));
154
160
155
161
void reduce_fp32_buffers (int start_elements, int num_elements, char * to_buffer, char ** buffers)
@@ -164,26 +170,13 @@ void reduce_all_buffers(int start_elements,
164
170
{
165
171
switch (scalar_type) {
166
172
case c10::ScalarType::BFloat16:
167
- if (world_size == 2 ) {
168
- // add the other buffer to to_buffer
169
- reduce_2_bf16_buffers_iio (num_elements,
170
- buffers[1 - to_buffer_idx] + start_elements * 2 ,
171
- to_buffer + start_elements * 2 ,
172
- to_buffer + start_elements * 2 );
173
- } else {
174
- reduce_bf16_buffers (start_elements, num_elements, to_buffer, buffers);
175
- }
173
+ reduce_bf16_buffers (start_elements, num_elements, to_buffer, buffers);
174
+ break ;
175
+ case c10::ScalarType::Half:
176
+ reduce_fp16_buffers (start_elements, num_elements, to_buffer, buffers);
176
177
break ;
177
178
case c10::ScalarType::Float:
178
- if (world_size == 2 ) {
179
- reduce_2_fp32_buffers_iio (num_elements,
180
- buffers[1 - to_buffer_idx] + start_elements * 4 ,
181
- to_buffer + start_elements * 4 ,
182
- to_buffer + start_elements * 4 );
183
- } else {
184
- assert (world_size > 2 );
185
- reduce_fp32_buffers (start_elements, num_elements, to_buffer, buffers);
186
- }
179
+ reduce_fp32_buffers (start_elements, num_elements, to_buffer, buffers);
187
180
break ;
188
181
default : assert (!" Should not get here" );
189
182
}
@@ -197,8 +190,8 @@ void reduce_all_buffers(int start_elements,
197
190
198
191
// Reduce functions down below use vectorized algorithm, the number of bytes processed each
199
192
// iteration depends on vector length. 256bit vector ==> 32 bytes, 512bit vector ==> 64 bytes
200
- // If you change implementation of reduce_2_bf16_buffers_iio or reduce_2_fp32_buffers_iio , check
201
- // whether this number needs to be changed
193
+ // If you change implementation of reduce_bf16_buffers, etc. , check whether this number needs
194
+ // to be changed
202
195
#define VECTOR_LENGTH_IN_BYTES 32
203
196
204
197
void reduce_bf16_buffers (int start_elements, int num_elements, char * to_buffer, char ** buffers)
@@ -227,10 +220,9 @@ void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer,
227
220
case 6 : CVT_ADD_BF16 (5 );
228
221
case 5 : CVT_ADD_BF16 (4 );
229
222
case 4 : CVT_ADD_BF16 (3 );
230
- case 3 :
231
- CVT_ADD_BF16 (2 );
232
- CVT_ADD_BF16 (1 );
233
- break ;
223
+ case 3 : CVT_ADD_BF16 (2 );
224
+ case 2 : CVT_ADD_BF16 (1 );
225
+ case 1 : break ;
234
226
default :
235
227
for (int j = 1 ; j < world_size; j++) {
236
228
auto in_val = cvt_bf16_to_fp32 (_mm256_loadu_si256 ((__m256i*)(buffers[j] + i)));
@@ -251,7 +243,13 @@ void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer,
251
243
}
252
244
}
253
245
254
- void reduce_2_bf16_buffers_iio (int num_elements, void * in0, void * in1, void * out)
246
+ #define CVT_ADD_FP16 (x ) \
247
+ do { \
248
+ auto in##x##_val = cvt_fp16_to_fp32 (_mm256_loadu_si256 ((__m256i*)(buffers[x] + i))); \
249
+ inout_val = _mm512_add_ps (inout_val, in##x##_val); \
250
+ } while (0 )
251
+
252
+ void reduce_fp16_buffers (int start_elements, int num_elements, char * to_buffer, char ** buffers)
255
253
{
256
254
const int element_size = 2 ;
257
255
const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size;
@@ -260,19 +258,41 @@ void reduce_2_bf16_buffers_iio(int num_elements, void* in0, void* in1, void* out
260
258
261
259
// process aligned part
262
260
#pragma omp parallel for
263
- for (int i = 0 ; i < main_elements * element_size; i += VECTOR_LENGTH_IN_BYTES) {
264
- auto in0_val = cvt_bf16_to_fp32 (_mm256_loadu_si256 ((__m256i*)((char *)in0 + i)));
265
- auto in1_val = cvt_bf16_to_fp32 (_mm256_loadu_si256 ((__m256i*)((char *)in1 + i)));
266
- auto out_val = _mm512_add_ps (in0_val, in1_val);
267
- _mm256_storeu_si256 ((__m256i*)((char *)out + i), cvt_fp32_to_bf16 (out_val));
261
+ for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size;
262
+ i += VECTOR_LENGTH_IN_BYTES) {
263
+ auto inout_val = cvt_fp16_to_fp32 (_mm256_loadu_si256 ((__m256i*)(buffers[0 ] + i)));
264
+ switch (world_size) {
265
+ case 16 : CVT_ADD_FP16 (15 );
266
+ case 15 : CVT_ADD_FP16 (14 );
267
+ case 14 : CVT_ADD_FP16 (13 );
268
+ case 13 : CVT_ADD_FP16 (12 );
269
+ case 12 : CVT_ADD_FP16 (11 );
270
+ case 11 : CVT_ADD_FP16 (10 );
271
+ case 10 : CVT_ADD_FP16 (9 );
272
+ case 9 : CVT_ADD_FP16 (8 );
273
+ case 8 : CVT_ADD_FP16 (7 );
274
+ case 7 : CVT_ADD_FP16 (6 );
275
+ case 6 : CVT_ADD_FP16 (5 );
276
+ case 5 : CVT_ADD_FP16 (4 );
277
+ case 4 : CVT_ADD_FP16 (3 );
278
+ case 3 : CVT_ADD_FP16 (2 );
279
+ case 2 : CVT_ADD_FP16 (1 );
280
+ case 1 : break ;
281
+ default :
282
+ for (int j = 1 ; j < world_size; j++) {
283
+ auto in_val = cvt_fp16_to_fp32 (_mm256_loadu_si256 ((__m256i*)(buffers[j] + i)));
284
+ inout_val = _mm512_add_ps (inout_val, in_val);
285
+ }
286
+ }
287
+ _mm256_storeu_si256 ((__m256i*)(to_buffer + i), cvt_fp32_to_fp16 (inout_val));
268
288
}
269
289
270
290
// process remaining part
271
- int i = main_elements * element_size;
291
+ int i = (start_elements + main_elements) * element_size;
272
292
while (remain_elements > 0 ) {
273
- float in0_val = *((at::BFloat16*)(( char *)in0 + i)) ;
274
- float in1_val = *(( at::BFloat16 *)(( char *)in1 + i));
275
- *(( at::BFloat16 *)(( char *)out + i)) = in0_val + in1_val ;
293
+ float val = 0 . 0f ;
294
+ for ( int j = 0 ; j < world_size; j++) { val += *( at::Half *)(buffers[j] + i); }
295
+ *(at::Half *)(to_buffer + i) = val ;
276
296
remain_elements--;
277
297
i += element_size;
278
298
}
@@ -310,10 +330,9 @@ void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer,
310
330
case 6 : CVT_ADD_F32 (5 );
311
331
case 5 : CVT_ADD_F32 (4 );
312
332
case 4 : CVT_ADD_F32 (3 );
313
- case 3 :
314
- CVT_ADD_F32 (2 );
315
- CVT_ADD_F32 (1 );
316
- break ;
333
+ case 3 : CVT_ADD_F32 (2 );
334
+ case 2 : CVT_ADD_F32 (1 );
335
+ case 1 : break ;
317
336
default :
318
337
for (int j = 1 ; j < world_size; j++) {
319
338
auto in_val = _mm256_loadu_ps ((float *)(buffers[j] + i));
@@ -334,33 +353,6 @@ void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer,
334
353
}
335
354
}
336
355
337
- void reduce_2_fp32_buffers_iio (int num_elements, void * in0, void * in1, void * out)
338
- {
339
- const int element_size = 4 ;
340
- const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size;
341
- int main_elements = num_elements - (num_elements % vector_length);
342
- int remain_elements = num_elements % vector_length;
343
-
344
- // process aligned part
345
- #pragma omp parallel for
346
- for (int i = 0 ; i < main_elements * element_size; i += VECTOR_LENGTH_IN_BYTES) {
347
- auto in0_val = _mm256_loadu_ps ((float *)((char *)in0 + i));
348
- auto in1_val = _mm256_loadu_ps ((float *)((char *)in1 + i));
349
- auto out_val = _mm256_add_ps (in0_val, in1_val);
350
- _mm256_storeu_ps ((float *)((char *)out + i), out_val);
351
- }
352
-
353
- // process remaining part
354
- int i = main_elements * element_size;
355
- while (remain_elements > 0 ) {
356
- float in0_val = *((float *)((char *)in0 + i));
357
- float in1_val = *((float *)((char *)in1 + i));
358
- *((float *)((char *)out + i)) = in0_val + in1_val;
359
- remain_elements--;
360
- i += element_size;
361
- }
362
- }
363
-
364
356
static bool is_initialized = 0 ;
365
357
static int world_rank;
366
358
0 commit comments