
Commit e16de6d

delock authored, with co-authors loadams and tjruwase
[CPU] add fp16 support to shm inference_all_reduce (#5669)
This PR adds FP16 support to DeepSpeed SHM inference_all_reduce. Previously only FP32 and BF16 were supported. This aligns with PyTorch CPU support for the FP16 datatype.

---------

Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
1 parent d89e8cd commit e16de6d

3 files changed: +72 -70 lines

csrc/cpu/comm/shm.cpp (+60 -68)
@@ -143,13 +143,19 @@ inline __m256i cvt_fp32_to_bf16(const __m512 src)
     return _mm512_cvtusepi32_epi16(t_value);
 }
 
-void reduce_2_bf16_buffers_iio(int num_elements, void* in0, void* in1, void* out)
-    __attribute__((target("avx512bw")));
+__m512 cvt_fp16_to_fp32(const __m256i src) __attribute__((target("avx512bw")));
+inline __m512 cvt_fp16_to_fp32(const __m256i src) { return _mm512_cvtph_ps(src); }
+
+inline __m256i cvt_fp32_to_fp16(const __m512 src) __attribute__((target("avx512bw")));
+inline __m256i cvt_fp32_to_fp16(const __m512 src)
+{
+    return _mm512_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+}
 
 void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
     __attribute__((target("avx512bw")));
 
-void reduce_2_fp32_buffers_iio(int num_elements, void* in0, void* in1, void* out)
+void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
     __attribute__((target("avx512bw")));
 
 void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
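
These two helpers widen 16 packed fp16 values to fp32 and narrow them back with round-to-nearest-even. A minimal standalone sketch of that round trip, assuming a GCC/Clang-style target attribute and an AVX-512-capable CPU (fp16_scale_demo is an illustrative name, not part of shm.cpp):

#include <immintrin.h>
#include <cstdint>

// Sketch only: load 16 fp16 values (raw bits), widen to fp32, do the arithmetic
// in fp32, then narrow back to fp16, just as cvt_fp16_to_fp32/cvt_fp32_to_fp16 do.
__attribute__((target("avx512bw"))) void fp16_scale_demo(const uint16_t* in, uint16_t* out)
{
    __m256i h = _mm256_loadu_si256((const __m256i*)in);  // 16 x fp16
    __m512 f = _mm512_cvtph_ps(h);                        // widen to 16 x fp32
    f = _mm512_mul_ps(f, _mm512_set1_ps(2.0f));           // arithmetic happens in fp32
    _mm256_storeu_si256((__m256i*)out,
                        _mm512_cvtps_ph(f, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)));
}
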
@@ -164,26 +170,13 @@ void reduce_all_buffers(int start_elements,
 {
     switch (scalar_type) {
         case c10::ScalarType::BFloat16:
-            if (world_size == 2) {
-                // add the other buffer to to_buffer
-                reduce_2_bf16_buffers_iio(num_elements,
-                                          buffers[1 - to_buffer_idx] + start_elements * 2,
-                                          to_buffer + start_elements * 2,
-                                          to_buffer + start_elements * 2);
-            } else {
-                reduce_bf16_buffers(start_elements, num_elements, to_buffer, buffers);
-            }
+            reduce_bf16_buffers(start_elements, num_elements, to_buffer, buffers);
+            break;
+        case c10::ScalarType::Half:
+            reduce_fp16_buffers(start_elements, num_elements, to_buffer, buffers);
             break;
         case c10::ScalarType::Float:
-            if (world_size == 2) {
-                reduce_2_fp32_buffers_iio(num_elements,
-                                          buffers[1 - to_buffer_idx] + start_elements * 4,
-                                          to_buffer + start_elements * 4,
-                                          to_buffer + start_elements * 4);
-            } else {
-                assert(world_size > 2);
-                reduce_fp32_buffers(start_elements, num_elements, to_buffer, buffers);
-            }
+            reduce_fp32_buffers(start_elements, num_elements, to_buffer, buffers);
             break;
         default: assert(!"Should not get here");
     }
@@ -197,8 +190,8 @@ void reduce_all_buffers(int start_elements,
 
 // Reduce functions down below use vectorized algorithm, the number of bytes processed each
 // iteration depends on vector length. 256bit vector ==> 32 bytes, 512bit vector ==> 64 bytes
-// If you change implementation of reduce_2_bf16_buffers_iio or reduce_2_fp32_buffers_iio, check
-// whether this number needs to be changed
+// If you change implementation of reduce_bf16_buffers, etc. , check whether this number needs
+// to be changed
 #define VECTOR_LENGTH_IN_BYTES 32
 
 void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
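
VECTOR_LENGTH_IN_BYTES fixes how many bytes each vector iteration consumes, and every reduce kernel splits its range into a SIMD main part plus a scalar remainder. A worked example of that split, with illustrative values:

// Sketch only: 256-bit loads move 32 bytes per iteration, so 2-byte fp16/bf16
// elements are processed 16 at a time.
constexpr int kVectorBytes = 32;                            // VECTOR_LENGTH_IN_BYTES
constexpr int kElementSize = 2;                             // fp16 / bf16
constexpr int kVectorLength = kVectorBytes / kElementSize;  // 16 elements per iteration
constexpr int kNumElements = 100;                           // example buffer size
constexpr int kMainElements = kNumElements - kNumElements % kVectorLength;  // 96, SIMD loop
constexpr int kRemainElements = kNumElements % kVectorLength;               // 4, scalar tail
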
@@ -227,10 +220,9 @@ void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer,
             case 6: CVT_ADD_BF16(5);
             case 5: CVT_ADD_BF16(4);
             case 4: CVT_ADD_BF16(3);
-            case 3:
-                CVT_ADD_BF16(2);
-                CVT_ADD_BF16(1);
-                break;
+            case 3: CVT_ADD_BF16(2);
+            case 2: CVT_ADD_BF16(1);
+            case 1: break;
             default:
                 for (int j = 1; j < world_size; j++) {
                     auto in_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[j] + i)));
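
This change, like the fp32 counterpart further below, relies on deliberate case fall-through: for world_size W between 1 and 16, execution enters at case W and falls through every lower case, so buffers W-1 down to 1 are each accumulated exactly once on top of buffers[0]; larger world sizes take the default loop. A minimal scalar sketch of the same idiom (illustrative names, not taken from shm.cpp):

// Sum per_rank_vals[0..world_size) using the unrolled fall-through pattern.
float reduce_fallthrough(const float* per_rank_vals, int world_size)
{
    float acc = per_rank_vals[0];
    switch (world_size) {
        case 4: acc += per_rank_vals[3];  // fall through
        case 3: acc += per_rank_vals[2];  // fall through
        case 2: acc += per_rank_vals[1];  // fall through
        case 1: break;
        default:  // world sizes beyond the unrolled cases
            for (int j = 1; j < world_size; j++) acc += per_rank_vals[j];
    }
    return acc;
}
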
@@ -251,7 +243,13 @@ void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer,
     }
 }
 
-void reduce_2_bf16_buffers_iio(int num_elements, void* in0, void* in1, void* out)
+#define CVT_ADD_FP16(x) \
+    do { \
+        auto in##x##_val = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[x] + i))); \
+        inout_val = _mm512_add_ps(inout_val, in##x##_val); \
+    } while (0)
+
+void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
 {
     const int element_size = 2;
     const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size;
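
For readability, one step of the unrolled accumulation expands as follows (preprocessor sketch; buffers, i, and inout_val come from the surrounding reduce_fp16_buffers loop shown below, and the trailing semicolon is supplied at the call site):

// CVT_ADD_FP16(3) expands to:
do {
    auto in3_val = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[3] + i)));
    inout_val = _mm512_add_ps(inout_val, in3_val);
} while (0)
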
@@ -260,19 +258,41 @@ void reduce_2_bf16_buffers_iio(int num_elements, void* in0, void* in1, void* out
 
     // process aligned part
 #pragma omp parallel for
-    for (int i = 0; i < main_elements * element_size; i += VECTOR_LENGTH_IN_BYTES) {
-        auto in0_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)((char*)in0 + i)));
-        auto in1_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)((char*)in1 + i)));
-        auto out_val = _mm512_add_ps(in0_val, in1_val);
-        _mm256_storeu_si256((__m256i*)((char*)out + i), cvt_fp32_to_bf16(out_val));
+    for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size;
+         i += VECTOR_LENGTH_IN_BYTES) {
+        auto inout_val = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[0] + i)));
+        switch (world_size) {
+            case 16: CVT_ADD_FP16(15);
+            case 15: CVT_ADD_FP16(14);
+            case 14: CVT_ADD_FP16(13);
+            case 13: CVT_ADD_FP16(12);
+            case 12: CVT_ADD_FP16(11);
+            case 11: CVT_ADD_FP16(10);
+            case 10: CVT_ADD_FP16(9);
+            case 9: CVT_ADD_FP16(8);
+            case 8: CVT_ADD_FP16(7);
+            case 7: CVT_ADD_FP16(6);
+            case 6: CVT_ADD_FP16(5);
+            case 5: CVT_ADD_FP16(4);
+            case 4: CVT_ADD_FP16(3);
+            case 3: CVT_ADD_FP16(2);
+            case 2: CVT_ADD_FP16(1);
+            case 1: break;
+            default:
+                for (int j = 1; j < world_size; j++) {
+                    auto in_val = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[j] + i)));
+                    inout_val = _mm512_add_ps(inout_val, in_val);
+                }
+        }
+        _mm256_storeu_si256((__m256i*)(to_buffer + i), cvt_fp32_to_fp16(inout_val));
     }
 
     // process remaining part
-    int i = main_elements * element_size;
+    int i = (start_elements + main_elements) * element_size;
     while (remain_elements > 0) {
-        float in0_val = *((at::BFloat16*)((char*)in0 + i));
-        float in1_val = *((at::BFloat16*)((char*)in1 + i));
-        *((at::BFloat16*)((char*)out + i)) = in0_val + in1_val;
+        float val = 0.0f;
+        for (int j = 0; j < world_size; j++) { val += *(at::Half*)(buffers[j] + i); }
+        *(at::Half*)(to_buffer + i) = val;
         remain_elements--;
         i += element_size;
     }
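
Semantically, the vectorized loop plus the scalar tail above compute an elementwise sum over all ranks' fp16 buffers, accumulating in fp32 and storing the result back as fp16. A minimal scalar reference under that reading (scalar_allreduce_fp16 is an illustrative name, not part of shm.cpp; c10::Half is the type shm.cpp spells at::Half):

#include <c10/util/Half.h>  // c10::Half, from the PyTorch headers shm.cpp builds against

// Scalar reference for reduce_fp16_buffers over [start_elements, start_elements + num_elements):
// read each rank's fp16 value, sum in fp32, write the fp16 result into to_buffer.
void scalar_allreduce_fp16(int start_elements, int num_elements, char* to_buffer,
                           char** buffers, int world_size)
{
    const int element_size = 2;  // bytes per fp16 element
    for (int e = start_elements; e < start_elements + num_elements; e++) {
        float val = 0.0f;
        for (int j = 0; j < world_size; j++) {
            val += *(c10::Half*)(buffers[j] + e * element_size);
        }
        *(c10::Half*)(to_buffer + e * element_size) = val;
    }
}
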
@@ -310,10 +330,9 @@ void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer,
             case 6: CVT_ADD_F32(5);
             case 5: CVT_ADD_F32(4);
             case 4: CVT_ADD_F32(3);
-            case 3:
-                CVT_ADD_F32(2);
-                CVT_ADD_F32(1);
-                break;
+            case 3: CVT_ADD_F32(2);
+            case 2: CVT_ADD_F32(1);
+            case 1: break;
             default:
                 for (int j = 1; j < world_size; j++) {
                     auto in_val = _mm256_loadu_ps((float*)(buffers[j] + i));
@@ -334,33 +353,6 @@ void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer,
     }
 }
 
-void reduce_2_fp32_buffers_iio(int num_elements, void* in0, void* in1, void* out)
-{
-    const int element_size = 4;
-    const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size;
-    int main_elements = num_elements - (num_elements % vector_length);
-    int remain_elements = num_elements % vector_length;
-
-    // process aligned part
-#pragma omp parallel for
-    for (int i = 0; i < main_elements * element_size; i += VECTOR_LENGTH_IN_BYTES) {
-        auto in0_val = _mm256_loadu_ps((float*)((char*)in0 + i));
-        auto in1_val = _mm256_loadu_ps((float*)((char*)in1 + i));
-        auto out_val = _mm256_add_ps(in0_val, in1_val);
-        _mm256_storeu_ps((float*)((char*)out + i), out_val);
-    }
-
-    // process remaining part
-    int i = main_elements * element_size;
-    while (remain_elements > 0) {
-        float in0_val = *((float*)((char*)in0 + i));
-        float in1_val = *((float*)((char*)in1 + i));
-        *((float*)((char*)out + i)) = in0_val + in1_val;
-        remain_elements--;
-        i += element_size;
-    }
-}
-
 static bool is_initialized = 0;
 static int world_rank;
 

csrc/cpu/comm/shm_interface.cpp (+1)
@@ -79,6 +79,7 @@ int inference_all_reduce(torch::Tensor& data, py::object op)
 
     switch (data.scalar_type()) {
         case c10::ScalarType::BFloat16: data_size = numel * 2; break;
+        case c10::ScalarType::Half: data_size = numel * 2; break;
        case c10::ScalarType::Float: data_size = numel * 4; break;
         default: data_type_fallback = true;
     }

tests/unit/comm/test_dist.py (+11 -2)
@@ -127,13 +127,22 @@ def test(self):
         assert torch.all(x == result)
 
 
+@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
 class TestDistInferenceAllReduce(DistributedTest):
-    world_size = 4
+    device_count = get_accelerator().device_count()
+    if device_count >= 4:
+        world_size = [1, 2, 4]
+    elif device_count >= 2:
+        world_size = [1, 2]
+    else:
+        world_size = [1]
 
-    def test(self):
+    def test(self, dtype):
         x = torch.ones(1, 3).to(get_accelerator().device_name()) * (dist.get_rank() + 1)
         sum_of_ranks = (dist.get_world_size() * (dist.get_world_size() + 1)) // 2
         result = torch.ones(1, 3).to(get_accelerator().device_name()) * sum_of_ranks
+        result = result.to(dtype)
+        x = x.to(dtype)
         dist.inference_all_reduce(x)
         assert torch.all(x == result)
 
