From 1d4b41fe21b4278fa994d1ee72df8eff601b7e8b Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Sat, 5 Sep 2020 01:56:43 +0000 Subject: [PATCH 1/6] fixing adam copy fp16-param-add more compile flags for cpu_adam --- csrc/adam/cpu_adam.cpp | 526 ++++++++++++++++++++++++++++----------- csrc/includes/cpu_adam.h | 9 +- setup.py | 6 +- tests/unit/adam_test.py | 10 +- 4 files changed, 398 insertions(+), 153 deletions(-) mode change 100644 => 100755 csrc/adam/cpu_adam.cpp mode change 100644 => 100755 csrc/includes/cpu_adam.h diff --git a/csrc/adam/cpu_adam.cpp b/csrc/adam/cpu_adam.cpp old mode 100644 new mode 100755 index 2d3b521bea1c..ab5b6c0b7054 --- a/csrc/adam/cpu_adam.cpp +++ b/csrc/adam/cpu_adam.cpp @@ -49,54 +49,60 @@ void Adam_Optimizer::Step(float* _params, size_t tile = 0; -#pragma omp parallel for - for (size_t i = 0; i < _param_size; i += SIMD_WIDTH) { - __m512 grad_4 = _mm512_loadu_ps(grads + i); - - __m512 momntum_4 = _mm512_loadu_ps(_exp_avg + i); - __m512 varianc_4 = _mm512_loadu_ps(_exp_avg_sq + i); - - __m512 param_4 = _mm512_loadu_ps(_params + i); - - if (_weight_decay > 0) { - __m512 weight_decay4 = _mm512_set1_ps(_weight_decay); - grad_4 = _mm512_fmadd_ps(param_4, weight_decay4, grad_4); - } + for (size_t t = 0; t < _param_size; t += TILE) + { + size_t copy_size = TILE; + if((t + TILE) > _param_size)copy_size = _param_size - t; + size_t offset = copy_size + t; + #pragma omp parallel for + for(size_t i = t; i < offset;i += SIMD_WIDTH) + { + __m512 grad_4 = _mm512_loadu_ps(grads + i); + + __m512 momntum_4 = _mm512_loadu_ps(_exp_avg + i); + __m512 varianc_4 = _mm512_loadu_ps(_exp_avg_sq + i); + + __m512 param_4 = _mm512_loadu_ps(_params + i); + + if (_weight_decay > 0) { + __m512 weight_decay4 = _mm512_set1_ps(_weight_decay); + grad_4 = _mm512_fmadd_ps(param_4, weight_decay4, grad_4); + } - momntum_4 = _mm512_mul_ps(momntum_4, betta1_4); - momntum_4 = _mm512_fmadd_ps(grad_4, betta1_minus1_4, momntum_4); + momntum_4 = _mm512_mul_ps(momntum_4, betta1_4); + momntum_4 = _mm512_fmadd_ps(grad_4, betta1_minus1_4, momntum_4); - varianc_4 = _mm512_mul_ps(varianc_4, betta2_4); - grad_4 = _mm512_mul_ps(grad_4, grad_4); - varianc_4 = _mm512_fmadd_ps(grad_4, betta2_minus1_4, varianc_4); + varianc_4 = _mm512_mul_ps(varianc_4, betta2_4); + grad_4 = _mm512_mul_ps(grad_4, grad_4); + varianc_4 = _mm512_fmadd_ps(grad_4, betta2_minus1_4, varianc_4); - grad_4 = _mm512_sqrt_ps(varianc_4) / bias2_sqrt; - grad_4 = _mm512_add_ps(grad_4, eps_4); - grad_4 = _mm512_div_ps(momntum_4, grad_4); + grad_4 = _mm512_sqrt_ps(varianc_4) / bias2_sqrt; + grad_4 = _mm512_add_ps(grad_4, eps_4); + grad_4 = _mm512_div_ps(momntum_4, grad_4); - param_4 = _mm512_fmadd_ps(grad_4, step_size_4, param_4); + param_4 = _mm512_fmadd_ps(grad_4, step_size_4, param_4); - _mm512_storeu_ps(_params + i, param_4); - _mm512_storeu_ps(_exp_avg + i, momntum_4); - _mm512_storeu_ps(_exp_avg_sq + i, varianc_4); - if (dev_params) { - if ((i + SIMD_WIDTH) % TILE == 0) { - size_t offset = tile * TILE; -#pragma omp parallel for - for (size_t j = 0; j < TILE; j += 4) { - _doubled_buffer[buf_index][j] = (__half)_params[offset + j]; - _doubled_buffer[buf_index][j + 1] = (__half)_params[offset + j + 1]; - _doubled_buffer[buf_index][j + 2] = (__half)_params[offset + j + 2]; - _doubled_buffer[buf_index][j + 3] = (__half)_params[offset + j + 3]; - } - CUDA_CHECK(cudaMemcpyAsync(dev_params + tile * TILE, - _doubled_buffer[buf_index], - TILE * sizeof(__half), - cudaMemcpyHostToDevice, - Context::Instance().GetCurrentStream())); - buf_index = !buf_index; - 
tile++; + _mm512_storeu_ps(_params + i, param_4); + _mm512_storeu_ps(_exp_avg + i, momntum_4); + _mm512_storeu_ps(_exp_avg_sq + i, varianc_4); + } + if (dev_params) + { + #pragma omp parallel for + for (size_t j = 0; j < copy_size; j += 4) + { + _doubled_buffer[buf_index][j] = (__half)_params[t + j]; + _doubled_buffer[buf_index][j + 1] = (__half)_params[t + j + 1]; + _doubled_buffer[buf_index][j + 2] = (__half)_params[t + j + 2]; + _doubled_buffer[buf_index][j + 3] = (__half)_params[t + j + 3]; } + + CUDA_CHECK(cudaMemcpyAsync(dev_params + t, + _doubled_buffer[buf_index], + copy_size * sizeof(__half), + cudaMemcpyHostToDevice, + Context::Instance().GetCurrentStream())); + buf_index = !buf_index; } } } @@ -134,113 +140,120 @@ void Adam_Optimizer::Step_4(float* _params, __m512 bias2_sqrt = _mm512_sqrt_ps(bias_correction2_4); + for (size_t t = 0; t < _param_size; t += TILE) + { + size_t copy_size = TILE; + if((t + TILE) > _param_size)copy_size = _param_size - t; + size_t offset = copy_size + t; #pragma omp parallel for - for (size_t i = 0; i < _param_size; i += (SIMD_WIDTH << 2)) { - __m512 grad_4[4]; - grad_4[0] = _mm512_loadu_ps(grads + i); - grad_4[1] = _mm512_loadu_ps(grads + i + SIMD_WIDTH); - grad_4[2] = _mm512_loadu_ps(grads + i + (SIMD_WIDTH << 1)); - grad_4[3] = _mm512_loadu_ps(grads + i + SIMD_WIDTH * 3); - - __m512 momntum_4[2]; - momntum_4[0] = _mm512_loadu_ps(_exp_avg + i); - momntum_4[1] = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH); - momntum_4[2] = _mm512_loadu_ps(_exp_avg + i + (SIMD_WIDTH << 1)); - momntum_4[3] = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH * 3); - - __m512 varianc_4[2]; - varianc_4[0] = _mm512_loadu_ps(_exp_avg_sq + i); - varianc_4[1] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH); - varianc_4[2] = _mm512_loadu_ps(_exp_avg_sq + i + (SIMD_WIDTH << 1)); - varianc_4[3] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH * 3); - - __m512 param_4[2]; - param_4[0] = _mm512_loadu_ps(_params + i); - param_4[1] = _mm512_loadu_ps(_params + i + SIMD_WIDTH); - param_4[2] = _mm512_loadu_ps(_params + i + (SIMD_WIDTH << 1)); - param_4[3] = _mm512_loadu_ps(_params + i + SIMD_WIDTH * 3); - - if (_weight_decay > 0) { - __m512 weight_decay4 = _mm512_set1_ps(_weight_decay); - grad_4[0] = _mm512_fmadd_ps(param_4[0], weight_decay4, grad_4[0]); - grad_4[1] = _mm512_fmadd_ps(param_4[1], weight_decay4, grad_4[1]); - grad_4[2] = _mm512_fmadd_ps(param_4[2], weight_decay4, grad_4[2]); - grad_4[3] = _mm512_fmadd_ps(param_4[3], weight_decay4, grad_4[3]); + for(size_t i = t; i < offset;i += (SIMD_WIDTH << 2)) + { + __m512 grad_4[4]; + grad_4[0] = _mm512_loadu_ps(grads + i); + grad_4[1] = _mm512_loadu_ps(grads + i + SIMD_WIDTH); + grad_4[2] = _mm512_loadu_ps(grads + i + (SIMD_WIDTH << 1)); + grad_4[3] = _mm512_loadu_ps(grads + i + SIMD_WIDTH * 3); + + __m512 momntum_4[2]; + momntum_4[0] = _mm512_loadu_ps(_exp_avg + i); + momntum_4[1] = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH); + momntum_4[2] = _mm512_loadu_ps(_exp_avg + i + (SIMD_WIDTH << 1)); + momntum_4[3] = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH * 3); + + __m512 varianc_4[2]; + varianc_4[0] = _mm512_loadu_ps(_exp_avg_sq + i); + varianc_4[1] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH); + varianc_4[2] = _mm512_loadu_ps(_exp_avg_sq + i + (SIMD_WIDTH << 1)); + varianc_4[3] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH * 3); + + __m512 param_4[2]; + param_4[0] = _mm512_loadu_ps(_params + i); + param_4[1] = _mm512_loadu_ps(_params + i + SIMD_WIDTH); + param_4[2] = _mm512_loadu_ps(_params + i + (SIMD_WIDTH << 1)); + param_4[3] = 
_mm512_loadu_ps(_params + i + SIMD_WIDTH * 3); + + if (_weight_decay > 0) { + __m512 weight_decay4 = _mm512_set1_ps(_weight_decay); + grad_4[0] = _mm512_fmadd_ps(param_4[0], weight_decay4, grad_4[0]); + grad_4[1] = _mm512_fmadd_ps(param_4[1], weight_decay4, grad_4[1]); + grad_4[2] = _mm512_fmadd_ps(param_4[2], weight_decay4, grad_4[2]); + grad_4[3] = _mm512_fmadd_ps(param_4[3], weight_decay4, grad_4[3]); + } + + momntum_4[0] = _mm512_mul_ps(momntum_4[0], betta1_4); + momntum_4[0] = _mm512_fmadd_ps(grad_4[0], betta1_minus1_4, momntum_4[0]); + momntum_4[1] = _mm512_mul_ps(momntum_4[1], betta1_4); + momntum_4[1] = _mm512_fmadd_ps(grad_4[1], betta1_minus1_4, momntum_4[1]); + momntum_4[2] = _mm512_mul_ps(momntum_4[2], betta1_4); + momntum_4[2] = _mm512_fmadd_ps(grad_4[2], betta1_minus1_4, momntum_4[2]); + momntum_4[3] = _mm512_mul_ps(momntum_4[3], betta1_4); + momntum_4[3] = _mm512_fmadd_ps(grad_4[3], betta1_minus1_4, momntum_4[3]); + + varianc_4[0] = _mm512_mul_ps(varianc_4[0], betta2_4); + varianc_4[1] = _mm512_mul_ps(varianc_4[1], betta2_4); + varianc_4[2] = _mm512_mul_ps(varianc_4[2], betta2_4); + varianc_4[3] = _mm512_mul_ps(varianc_4[3], betta2_4); + grad_4[0] = _mm512_mul_ps(grad_4[0], grad_4[0]); + grad_4[1] = _mm512_mul_ps(grad_4[1], grad_4[1]); + grad_4[2] = _mm512_mul_ps(grad_4[2], grad_4[2]); + grad_4[3] = _mm512_mul_ps(grad_4[3], grad_4[3]); + varianc_4[0] = _mm512_fmadd_ps(grad_4[0], betta2_minus1_4, varianc_4[0]); + varianc_4[1] = _mm512_fmadd_ps(grad_4[1], betta2_minus1_4, varianc_4[1]); + varianc_4[2] = _mm512_fmadd_ps(grad_4[2], betta2_minus1_4, varianc_4[2]); + varianc_4[3] = _mm512_fmadd_ps(grad_4[3], betta2_minus1_4, varianc_4[3]); + + grad_4[0] = _mm512_sqrt_ps(varianc_4[0]) / bias2_sqrt; + grad_4[1] = _mm512_sqrt_ps(varianc_4[1]) / bias2_sqrt; + grad_4[2] = _mm512_sqrt_ps(varianc_4[2]) / bias2_sqrt; + grad_4[3] = _mm512_sqrt_ps(varianc_4[3]) / bias2_sqrt; + + grad_4[0] = _mm512_add_ps(grad_4[0], eps_4); + grad_4[1] = _mm512_add_ps(grad_4[1], eps_4); + grad_4[2] = _mm512_add_ps(grad_4[2], eps_4); + grad_4[3] = _mm512_add_ps(grad_4[3], eps_4); + grad_4[0] = _mm512_div_ps(momntum_4[0], grad_4[0]); + grad_4[1] = _mm512_div_ps(momntum_4[1], grad_4[1]); + grad_4[2] = _mm512_div_ps(momntum_4[2], grad_4[2]); + grad_4[3] = _mm512_div_ps(momntum_4[3], grad_4[3]); + + param_4[0] = _mm512_fmadd_ps(grad_4[0], step_size_4, param_4[0]); + param_4[1] = _mm512_fmadd_ps(grad_4[1], step_size_4, param_4[1]); + param_4[2] = _mm512_fmadd_ps(grad_4[2], step_size_4, param_4[2]); + param_4[3] = _mm512_fmadd_ps(grad_4[3], step_size_4, param_4[3]); + + _mm512_storeu_ps(_params + i, param_4[0]); + _mm512_storeu_ps(_params + i + SIMD_WIDTH, param_4[1]); + _mm512_storeu_ps(_params + i + (SIMD_WIDTH << 1), param_4[2]); + _mm512_storeu_ps(_params + i + SIMD_WIDTH * 3, param_4[3]); + + _mm512_storeu_ps(_exp_avg + i, momntum_4[0]); + _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH, momntum_4[1]); + _mm512_storeu_ps(_exp_avg + i + (SIMD_WIDTH << 1), momntum_4[2]); + _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH * 3, momntum_4[3]); + + _mm512_storeu_ps(_exp_avg_sq + i, varianc_4[0]); + _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH, varianc_4[1]); + _mm512_storeu_ps(_exp_avg_sq + i + (SIMD_WIDTH << 1), varianc_4[2]); + _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH * 3, varianc_4[3]); } - momntum_4[0] = _mm512_mul_ps(momntum_4[0], betta1_4); - momntum_4[0] = _mm512_fmadd_ps(grad_4[0], betta1_minus1_4, momntum_4[0]); - momntum_4[1] = _mm512_mul_ps(momntum_4[1], betta1_4); - momntum_4[1] = _mm512_fmadd_ps(grad_4[1], 
betta1_minus1_4, momntum_4[1]); - momntum_4[2] = _mm512_mul_ps(momntum_4[2], betta1_4); - momntum_4[2] = _mm512_fmadd_ps(grad_4[2], betta1_minus1_4, momntum_4[2]); - momntum_4[3] = _mm512_mul_ps(momntum_4[3], betta1_4); - momntum_4[3] = _mm512_fmadd_ps(grad_4[3], betta1_minus1_4, momntum_4[3]); - - varianc_4[0] = _mm512_mul_ps(varianc_4[0], betta2_4); - varianc_4[1] = _mm512_mul_ps(varianc_4[1], betta2_4); - varianc_4[2] = _mm512_mul_ps(varianc_4[2], betta2_4); - varianc_4[3] = _mm512_mul_ps(varianc_4[3], betta2_4); - grad_4[0] = _mm512_mul_ps(grad_4[0], grad_4[0]); - grad_4[1] = _mm512_mul_ps(grad_4[1], grad_4[1]); - grad_4[2] = _mm512_mul_ps(grad_4[2], grad_4[2]); - grad_4[3] = _mm512_mul_ps(grad_4[3], grad_4[3]); - varianc_4[0] = _mm512_fmadd_ps(grad_4[0], betta2_minus1_4, varianc_4[0]); - varianc_4[1] = _mm512_fmadd_ps(grad_4[1], betta2_minus1_4, varianc_4[1]); - varianc_4[2] = _mm512_fmadd_ps(grad_4[2], betta2_minus1_4, varianc_4[2]); - varianc_4[3] = _mm512_fmadd_ps(grad_4[3], betta2_minus1_4, varianc_4[3]); - - grad_4[0] = _mm512_sqrt_ps(varianc_4[0]) / bias2_sqrt; - grad_4[1] = _mm512_sqrt_ps(varianc_4[1]) / bias2_sqrt; - grad_4[2] = _mm512_sqrt_ps(varianc_4[2]) / bias2_sqrt; - grad_4[3] = _mm512_sqrt_ps(varianc_4[3]) / bias2_sqrt; - - grad_4[0] = _mm512_add_ps(grad_4[0], eps_4); - grad_4[1] = _mm512_add_ps(grad_4[1], eps_4); - grad_4[2] = _mm512_add_ps(grad_4[2], eps_4); - grad_4[3] = _mm512_add_ps(grad_4[3], eps_4); - grad_4[0] = _mm512_div_ps(momntum_4[0], grad_4[0]); - grad_4[1] = _mm512_div_ps(momntum_4[1], grad_4[1]); - grad_4[2] = _mm512_div_ps(momntum_4[2], grad_4[2]); - grad_4[3] = _mm512_div_ps(momntum_4[3], grad_4[3]); - - param_4[0] = _mm512_fmadd_ps(grad_4[0], step_size_4, param_4[0]); - param_4[1] = _mm512_fmadd_ps(grad_4[1], step_size_4, param_4[1]); - param_4[2] = _mm512_fmadd_ps(grad_4[2], step_size_4, param_4[2]); - param_4[3] = _mm512_fmadd_ps(grad_4[3], step_size_4, param_4[3]); - - _mm512_storeu_ps(_params + i, param_4[0]); - _mm512_storeu_ps(_params + i + SIMD_WIDTH, param_4[1]); - _mm512_storeu_ps(_params + i + (SIMD_WIDTH << 1), param_4[2]); - _mm512_storeu_ps(_params + i + SIMD_WIDTH * 3, param_4[3]); - - _mm512_storeu_ps(_exp_avg + i, momntum_4[0]); - _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH, momntum_4[1]); - _mm512_storeu_ps(_exp_avg + i + (SIMD_WIDTH << 1), momntum_4[2]); - _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH * 3, momntum_4[3]); - - _mm512_storeu_ps(_exp_avg_sq + i, varianc_4[0]); - _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH, varianc_4[1]); - _mm512_storeu_ps(_exp_avg_sq + i + (SIMD_WIDTH << 1), varianc_4[2]); - _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH * 3, varianc_4[3]); - if (dev_params) { - if ((i + (SIMD_WIDTH << 2)) % TILE == 0) { - size_t offset = tile * TILE; -#pragma omp parallel for - for (size_t j = 0; j < TILE; j += 4) { - _doubled_buffer[buf_index][j] = (__half)_params[offset + j]; - _doubled_buffer[buf_index][j + 1] = (__half)_params[offset + j + 1]; - _doubled_buffer[buf_index][j + 2] = (__half)_params[offset + j + 2]; - _doubled_buffer[buf_index][j + 3] = (__half)_params[offset + j + 3]; - } - CUDA_CHECK(cudaMemcpyAsync(dev_params + tile * TILE, - _doubled_buffer[buf_index], - TILE * sizeof(__half), - cudaMemcpyHostToDevice, - Context::Instance().GetCurrentStream())); - buf_index = !buf_index; - tile++; + if (dev_params) + { + #pragma omp parallel for + for (size_t j = 0; j < copy_size; j += 4) + { + _doubled_buffer[buf_index][j] = (__half)_params[t + j]; + _doubled_buffer[buf_index][j + 1] = (__half)_params[t + j + 1]; + 
_doubled_buffer[buf_index][j + 2] = (__half)_params[t + j + 2]; + _doubled_buffer[buf_index][j + 3] = (__half)_params[t + j + 3]; } + + CUDA_CHECK(cudaMemcpyAsync(dev_params + t, + _doubled_buffer[buf_index], + copy_size * sizeof(__half), + cudaMemcpyHostToDevice, + Context::Instance().GetCurrentStream())); + buf_index = !buf_index; } } } @@ -262,6 +275,225 @@ int create_adam_optimizer(int optimizer_id, return 0; } +void Adam_Optimizer::Step_8(float *_params, + float *grads, + float *_exp_avg, + float *_exp_avg_sq, + size_t _param_size, + __half* dev_params) +{ + + _betta1_t *= _betta1; + _betta2_t *= _betta2; + + __m512 betta1_4 = _mm512_set1_ps(_betta1); + __m512 betta2_4 = _mm512_set1_ps(_betta2); + + bool buf_index = 0; + + float betta1_minus1 = 1 - _betta1; + float betta2_minus1 = 1 - _betta2; + __m512 betta1_minus1_4 = _mm512_set1_ps(betta1_minus1); + __m512 betta2_minus1_4 = _mm512_set1_ps(betta2_minus1); + + float bias_correction1 = 1 - _betta1_t; + float bias_correction2 = 1 - _betta2_t; + //__m512 bias_correction1_4 = _mm512_set1_ps(bias_correction1); + __m512 bias_correction2_4 = _mm512_set1_ps(bias_correction2); + + __m512 eps_4 = _mm512_set1_ps(_eps); + + float step_size = -1 * _alpha / bias_correction1; + __m512 step_size_4 = _mm512_set1_ps(step_size); + + __m512 bias2_sqrt = _mm512_sqrt_ps(bias_correction2_4); + + for (size_t t = 0; t < _param_size; t += TILE) + { + size_t copy_size = TILE; + if((t + TILE) > _param_size)copy_size = _param_size - t; + size_t offset = copy_size + t; +#pragma omp parallel for + for(size_t i = t; i < offset;i += (SIMD_WIDTH << 3)) + { + __m512 grad_4[8]; + grad_4[0] = _mm512_loadu_ps(grads + i); + grad_4[1] = _mm512_loadu_ps(grads + i + SIMD_WIDTH); + grad_4[2] = _mm512_loadu_ps(grads + i + (SIMD_WIDTH<<1)); + grad_4[3] = _mm512_loadu_ps(grads + i + SIMD_WIDTH * 3); + grad_4[4] = _mm512_loadu_ps(grads + i + (SIMD_WIDTH<<2)); + grad_4[5] = _mm512_loadu_ps(grads + i + SIMD_WIDTH * 5); + grad_4[6] = _mm512_loadu_ps(grads + i + SIMD_WIDTH * 6); + grad_4[7] = _mm512_loadu_ps(grads + i + SIMD_WIDTH * 7); + + __m512 momntum_4[8]; + momntum_4[0] = _mm512_loadu_ps(_exp_avg + i); + momntum_4[1] = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH); + momntum_4[2] = _mm512_loadu_ps(_exp_avg + i + (SIMD_WIDTH<<1)); + momntum_4[3] = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH * 3); + momntum_4[4] = _mm512_loadu_ps(_exp_avg + i + (SIMD_WIDTH<<2)); + momntum_4[5] = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH * 5); + momntum_4[6] = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH * 6); + momntum_4[7] = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH * 7); + + __m512 varianc_4[8]; + varianc_4[0] = _mm512_loadu_ps(_exp_avg_sq + i); + varianc_4[1] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH); + varianc_4[2] = _mm512_loadu_ps(_exp_avg_sq + i + (SIMD_WIDTH<<1)); + varianc_4[3] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH * 3); + varianc_4[5] = _mm512_loadu_ps(_exp_avg_sq + i + (SIMD_WIDTH<<2)); + varianc_4[6] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH * 5); + varianc_4[7] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH * 6); + varianc_4[8] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH * 7); + + __m512 param_4[8]; + param_4[0] = _mm512_loadu_ps(_params + i); + param_4[1] = _mm512_loadu_ps(_params + i + SIMD_WIDTH); + param_4[2] = _mm512_loadu_ps(_params + i + (SIMD_WIDTH<<1)); + param_4[3] = _mm512_loadu_ps(_params + i + SIMD_WIDTH * 3); + param_4[4] = _mm512_loadu_ps(_params + i + (SIMD_WIDTH<<2)); + param_4[5] = _mm512_loadu_ps(_params + i + SIMD_WIDTH * 5); + param_4[6] = 
_mm512_loadu_ps(_params + i + SIMD_WIDTH * 6); + param_4[7] = _mm512_loadu_ps(_params + i + SIMD_WIDTH * 7); + + if(_weight_decay > 0) + { + __m512 weight_decay4 = _mm512_set1_ps(_weight_decay); + grad_4[0] = _mm512_fmadd_ps(param_4[0], weight_decay4, grad_4[0]); + grad_4[1] = _mm512_fmadd_ps(param_4[1], weight_decay4, grad_4[1]); + grad_4[2] = _mm512_fmadd_ps(param_4[2], weight_decay4, grad_4[2]); + grad_4[3] = _mm512_fmadd_ps(param_4[3], weight_decay4, grad_4[3]); + grad_4[4] = _mm512_fmadd_ps(param_4[4], weight_decay4, grad_4[4]); + grad_4[5] = _mm512_fmadd_ps(param_4[5], weight_decay4, grad_4[5]); + grad_4[6] = _mm512_fmadd_ps(param_4[6], weight_decay4, grad_4[6]); + grad_4[7] = _mm512_fmadd_ps(param_4[7], weight_decay4, grad_4[7]); + } + + momntum_4[0] = _mm512_mul_ps(momntum_4[0], betta1_4); + momntum_4[0] = _mm512_fmadd_ps(grad_4[0], betta1_minus1_4, momntum_4[0]); + momntum_4[1] = _mm512_mul_ps(momntum_4[1], betta1_4); + momntum_4[1] = _mm512_fmadd_ps(grad_4[1], betta1_minus1_4, momntum_4[1]); + momntum_4[2] = _mm512_mul_ps(momntum_4[2], betta1_4); + momntum_4[2] = _mm512_fmadd_ps(grad_4[2], betta1_minus1_4, momntum_4[2]); + momntum_4[3] = _mm512_mul_ps(momntum_4[3], betta1_4); + momntum_4[3] = _mm512_fmadd_ps(grad_4[3], betta1_minus1_4, momntum_4[3]); + momntum_4[4] = _mm512_mul_ps(momntum_4[4], betta1_4); + momntum_4[4] = _mm512_fmadd_ps(grad_4[4], betta1_minus1_4, momntum_4[4]); + momntum_4[5] = _mm512_mul_ps(momntum_4[5], betta1_4); + momntum_4[5] = _mm512_fmadd_ps(grad_4[5], betta1_minus1_4, momntum_4[5]); + momntum_4[6] = _mm512_mul_ps(momntum_4[6], betta1_4); + momntum_4[6] = _mm512_fmadd_ps(grad_4[6], betta1_minus1_4, momntum_4[6]); + momntum_4[7] = _mm512_mul_ps(momntum_4[7], betta1_4); + momntum_4[7] = _mm512_fmadd_ps(grad_4[7], betta1_minus1_4, momntum_4[7]); + + varianc_4[0] = _mm512_mul_ps(varianc_4[0], betta2_4); + varianc_4[1] = _mm512_mul_ps(varianc_4[1], betta2_4); + varianc_4[2] = _mm512_mul_ps(varianc_4[2], betta2_4); + varianc_4[3] = _mm512_mul_ps(varianc_4[3], betta2_4); + varianc_4[4] = _mm512_mul_ps(varianc_4[4], betta2_4); + varianc_4[5] = _mm512_mul_ps(varianc_4[5], betta2_4); + varianc_4[6] = _mm512_mul_ps(varianc_4[6], betta2_4); + varianc_4[7] = _mm512_mul_ps(varianc_4[7], betta2_4); + grad_4[0] = _mm512_mul_ps(grad_4[0], grad_4[0]); + grad_4[1] = _mm512_mul_ps(grad_4[1], grad_4[1]); + grad_4[2] = _mm512_mul_ps(grad_4[2], grad_4[2]); + grad_4[3] = _mm512_mul_ps(grad_4[3], grad_4[3]); + grad_4[4] = _mm512_mul_ps(grad_4[4], grad_4[4]); + grad_4[5] = _mm512_mul_ps(grad_4[5], grad_4[5]); + grad_4[6] = _mm512_mul_ps(grad_4[6], grad_4[6]); + grad_4[7] = _mm512_mul_ps(grad_4[7], grad_4[7]); + varianc_4[0] = _mm512_fmadd_ps(grad_4[0], betta2_minus1_4, varianc_4[0]); + varianc_4[1] = _mm512_fmadd_ps(grad_4[1], betta2_minus1_4, varianc_4[1]); + varianc_4[2] = _mm512_fmadd_ps(grad_4[2], betta2_minus1_4, varianc_4[2]); + varianc_4[3] = _mm512_fmadd_ps(grad_4[3], betta2_minus1_4, varianc_4[3]); + varianc_4[4] = _mm512_fmadd_ps(grad_4[4], betta2_minus1_4, varianc_4[4]); + varianc_4[5] = _mm512_fmadd_ps(grad_4[5], betta2_minus1_4, varianc_4[5]); + varianc_4[6] = _mm512_fmadd_ps(grad_4[6], betta2_minus1_4, varianc_4[6]); + varianc_4[7] = _mm512_fmadd_ps(grad_4[7], betta2_minus1_4, varianc_4[7]); + + grad_4[0] = _mm512_sqrt_ps(varianc_4[0]) / bias2_sqrt; + grad_4[1] = _mm512_sqrt_ps(varianc_4[1]) / bias2_sqrt; + grad_4[2] = _mm512_sqrt_ps(varianc_4[2]) / bias2_sqrt; + grad_4[3] = _mm512_sqrt_ps(varianc_4[3]) / bias2_sqrt; + grad_4[4] = _mm512_sqrt_ps(varianc_4[4]) / 
bias2_sqrt; + grad_4[5] = _mm512_sqrt_ps(varianc_4[5]) / bias2_sqrt; + grad_4[6] = _mm512_sqrt_ps(varianc_4[6]) / bias2_sqrt; + grad_4[7] = _mm512_sqrt_ps(varianc_4[7]) / bias2_sqrt; + + grad_4[0] = _mm512_add_ps(grad_4[0], eps_4); + grad_4[1] = _mm512_add_ps(grad_4[1], eps_4); + grad_4[2] = _mm512_add_ps(grad_4[2], eps_4); + grad_4[3] = _mm512_add_ps(grad_4[3], eps_4); + grad_4[4] = _mm512_add_ps(grad_4[4], eps_4); + grad_4[5] = _mm512_add_ps(grad_4[5], eps_4); + grad_4[6] = _mm512_add_ps(grad_4[6], eps_4); + grad_4[7] = _mm512_add_ps(grad_4[7], eps_4); + grad_4[0] = _mm512_div_ps(momntum_4[0], grad_4[0]); + grad_4[1] = _mm512_div_ps(momntum_4[1], grad_4[1]); + grad_4[2] = _mm512_div_ps(momntum_4[2], grad_4[2]); + grad_4[3] = _mm512_div_ps(momntum_4[3], grad_4[3]); + grad_4[4] = _mm512_div_ps(momntum_4[4], grad_4[4]); + grad_4[5] = _mm512_div_ps(momntum_4[5], grad_4[5]); + grad_4[6] = _mm512_div_ps(momntum_4[6], grad_4[6]); + grad_4[7] = _mm512_div_ps(momntum_4[7], grad_4[7]); + + param_4[0] = _mm512_fmadd_ps(grad_4[0], step_size_4, param_4[0]); + param_4[1] = _mm512_fmadd_ps(grad_4[1], step_size_4, param_4[1]); + param_4[2] = _mm512_fmadd_ps(grad_4[2], step_size_4, param_4[2]); + param_4[3] = _mm512_fmadd_ps(grad_4[3], step_size_4, param_4[3]); + param_4[4] = _mm512_fmadd_ps(grad_4[4], step_size_4, param_4[4]); + param_4[5] = _mm512_fmadd_ps(grad_4[5], step_size_4, param_4[5]); + param_4[6] = _mm512_fmadd_ps(grad_4[6], step_size_4, param_4[6]); + param_4[7] = _mm512_fmadd_ps(grad_4[7], step_size_4, param_4[7]); + + _mm512_storeu_ps(_params + i, param_4[0]); + _mm512_storeu_ps(_params + i + SIMD_WIDTH, param_4[1]); + _mm512_storeu_ps(_params + i + (SIMD_WIDTH<<1), param_4[2]); + _mm512_storeu_ps(_params + i + SIMD_WIDTH * 3, param_4[3]); + _mm512_storeu_ps(_params + i + (SIMD_WIDTH<<2), param_4[4]); + _mm512_storeu_ps(_params + i + SIMD_WIDTH * 5, param_4[5]); + _mm512_storeu_ps(_params + i + SIMD_WIDTH * 6, param_4[6]); + _mm512_storeu_ps(_params + i + SIMD_WIDTH * 7, param_4[7]); + + _mm512_storeu_ps(_exp_avg + i, momntum_4[0]); + _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH, momntum_4[1]); + _mm512_storeu_ps(_exp_avg + i + (SIMD_WIDTH<<1), momntum_4[2]); + _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH * 3, momntum_4[3]); + _mm512_storeu_ps(_exp_avg + i + (SIMD_WIDTH<<2), momntum_4[4]); + _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH * 5, momntum_4[5]); + _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH * 6, momntum_4[6]); + _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH * 7, momntum_4[7]); + + _mm512_storeu_ps(_exp_avg_sq + i, varianc_4[0]); + _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH, varianc_4[1]); + _mm512_storeu_ps(_exp_avg_sq + i + (SIMD_WIDTH<<1), varianc_4[2]); + _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH * 3, varianc_4[3]); + _mm512_storeu_ps(_exp_avg_sq + i + (SIMD_WIDTH<<2), varianc_4[4]); + _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH * 5, varianc_4[5]); + _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH * 6, varianc_4[6]); + _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH * 7, varianc_4[7]); + } + if (dev_params) + { + #pragma omp parallel for + for (size_t j = 0; j < copy_size; j += 4) + { + _doubled_buffer[buf_index][j] = (__half)_params[t + j]; + _doubled_buffer[buf_index][j + 1] = (__half)_params[t + j + 1]; + _doubled_buffer[buf_index][j + 2] = (__half)_params[t + j + 2]; + _doubled_buffer[buf_index][j + 3] = (__half)_params[t + j + 3]; + } + + CUDA_CHECK(cudaMemcpyAsync(dev_params + t, + _doubled_buffer[buf_index], + copy_size * sizeof(__half), + cudaMemcpyHostToDevice, + 
Context::Instance().GetCurrentStream())); + buf_index = !buf_index; + } + } +} + int ds_adam_step(int optimizer_id, torch::Tensor& params, torch::Tensor& grads, @@ -281,7 +513,7 @@ int ds_adam_step(int optimizer_id, std::shared_ptr opt = std::static_pointer_cast(s_optimizers[optimizer_id]); - opt->Step_4(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0)); + opt->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0)); return 0; } @@ -308,7 +540,7 @@ int ds_adam_step_plus_copy(int optimizer_id, std::shared_ptr opt = std::static_pointer_cast(s_optimizers[optimizer_id]); - opt->Step_4( + opt->Step_8( params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0), gpu_params_ptr); return 0; diff --git a/csrc/includes/cpu_adam.h b/csrc/includes/cpu_adam.h old mode 100644 new mode 100755 index d1a2aa088731..c74c31c72f16 --- a/csrc/includes/cpu_adam.h +++ b/csrc/includes/cpu_adam.h @@ -18,7 +18,8 @@ } #define SIMD_WIDTH 16 -#define TILE 1048576 + +#define TILE (1024 * 1024 * 1024) class Adam_Optimizer { public: @@ -55,6 +56,12 @@ class Adam_Optimizer { float* _exp_avg_sa, size_t param_size, __half* dev_param = nullptr); + void Step_8(float *_params, + float *grads, + float *_exp_avg, + float *_exp_avg_sq, + size_t _param_size, + __half* dev_params = nullptr); private: float _alpha; diff --git a/setup.py b/setup.py index ae5010a65b11..77b7241b582b 100755 --- a/setup.py +++ b/setup.py @@ -123,11 +123,15 @@ def fetch_requirements(path): sources=[ 'csrc/adam/cpu_adam.cpp', ], - include_dirs=['csrc/includes'], + include_dirs=['csrc/includes', + '/usr/local/cuda/include'], extra_compile_args={ 'cxx': ['-O3', '-std=c++14', + '-L/usr/local/cuda/lib64', + '-lcudart', + '-lcublas', '-g', '-Wno-reorder', '-march=native', diff --git a/tests/unit/adam_test.py b/tests/unit/adam_test.py index 0fedfeb18e42..51c588dda962 100755 --- a/tests/unit/adam_test.py +++ b/tests/unit/adam_test.py @@ -3,16 +3,18 @@ import time device = 'cpu' -model_size = 1 * 1024**3 +model_size = 10 * 1024**3 param = torch.nn.Parameter(torch.ones(model_size, device=device)) +param_fp16 = torch.nn.Parameter(torch.ones(model_size, dtype=torch.half, device='cuda:0')) + optimizer = DeepSpeedCPUAdam([param]) #torch.set_num_threads(128) param.grad = torch.ones(model_size, device=device) avg = 0 -for i in range(100): +for i in range(10): start = time.time() - optimizer.step() + optimizer.step(fp16_param_groups=[param_fp16]) stop = time.time() avg += (stop - start) param.grad = torch.ones(model_size, device=device) * 2 -print("Elapsed Time is ", avg / 100) +print("Elapsed Time is ", avg / 10) From 59ffc1a9092f4a47744676b727d522c2886260d0 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Sat, 5 Sep 2020 01:58:55 +0000 Subject: [PATCH 2/6] run precommit --- csrc/adam/cpu_adam.cpp | 102 +++++++++++++++++---------------------- csrc/includes/cpu_adam.h | 10 ++-- setup.py | 38 +++++++-------- tests/unit/adam_test.py | 4 +- 4 files changed, 71 insertions(+), 83 deletions(-) mode change 100755 => 100644 csrc/adam/cpu_adam.cpp mode change 100755 => 100644 csrc/includes/cpu_adam.h diff --git a/csrc/adam/cpu_adam.cpp b/csrc/adam/cpu_adam.cpp old mode 100755 new mode 100644 index ab5b6c0b7054..98ac32e04a2f --- a/csrc/adam/cpu_adam.cpp +++ b/csrc/adam/cpu_adam.cpp @@ -49,14 +49,12 @@ void Adam_Optimizer::Step(float* _params, size_t tile = 0; - for (size_t t = 0; t < _param_size; t += TILE) - { + for (size_t t = 0; t < _param_size; t += TILE) { size_t copy_size = TILE; - if((t + TILE) > _param_size)copy_size = 
_param_size - t; + if ((t + TILE) > _param_size) copy_size = _param_size - t; size_t offset = copy_size + t; - #pragma omp parallel for - for(size_t i = t; i < offset;i += SIMD_WIDTH) - { +#pragma omp parallel for + for (size_t i = t; i < offset; i += SIMD_WIDTH) { __m512 grad_4 = _mm512_loadu_ps(grads + i); __m512 momntum_4 = _mm512_loadu_ps(_exp_avg + i); @@ -86,17 +84,15 @@ void Adam_Optimizer::Step(float* _params, _mm512_storeu_ps(_exp_avg + i, momntum_4); _mm512_storeu_ps(_exp_avg_sq + i, varianc_4); } - if (dev_params) - { - #pragma omp parallel for - for (size_t j = 0; j < copy_size; j += 4) - { + if (dev_params) { +#pragma omp parallel for + for (size_t j = 0; j < copy_size; j += 4) { _doubled_buffer[buf_index][j] = (__half)_params[t + j]; _doubled_buffer[buf_index][j + 1] = (__half)_params[t + j + 1]; _doubled_buffer[buf_index][j + 2] = (__half)_params[t + j + 2]; _doubled_buffer[buf_index][j + 3] = (__half)_params[t + j + 3]; } - + CUDA_CHECK(cudaMemcpyAsync(dev_params + t, _doubled_buffer[buf_index], copy_size * sizeof(__half), @@ -140,14 +136,12 @@ void Adam_Optimizer::Step_4(float* _params, __m512 bias2_sqrt = _mm512_sqrt_ps(bias_correction2_4); - for (size_t t = 0; t < _param_size; t += TILE) - { + for (size_t t = 0; t < _param_size; t += TILE) { size_t copy_size = TILE; - if((t + TILE) > _param_size)copy_size = _param_size - t; + if ((t + TILE) > _param_size) copy_size = _param_size - t; size_t offset = copy_size + t; #pragma omp parallel for - for(size_t i = t; i < offset;i += (SIMD_WIDTH << 2)) - { + for (size_t i = t; i < offset; i += (SIMD_WIDTH << 2)) { __m512 grad_4[4]; grad_4[0] = _mm512_loadu_ps(grads + i); grad_4[1] = _mm512_loadu_ps(grads + i + SIMD_WIDTH); @@ -237,17 +231,15 @@ void Adam_Optimizer::Step_4(float* _params, _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH * 3, varianc_4[3]); } - if (dev_params) - { - #pragma omp parallel for - for (size_t j = 0; j < copy_size; j += 4) - { + if (dev_params) { +#pragma omp parallel for + for (size_t j = 0; j < copy_size; j += 4) { _doubled_buffer[buf_index][j] = (__half)_params[t + j]; _doubled_buffer[buf_index][j + 1] = (__half)_params[t + j + 1]; _doubled_buffer[buf_index][j + 2] = (__half)_params[t + j + 2]; _doubled_buffer[buf_index][j + 3] = (__half)_params[t + j + 3]; } - + CUDA_CHECK(cudaMemcpyAsync(dev_params + t, _doubled_buffer[buf_index], copy_size * sizeof(__half), @@ -275,14 +267,13 @@ int create_adam_optimizer(int optimizer_id, return 0; } -void Adam_Optimizer::Step_8(float *_params, - float *grads, - float *_exp_avg, - float *_exp_avg_sq, - size_t _param_size, +void Adam_Optimizer::Step_8(float* _params, + float* grads, + float* _exp_avg, + float* _exp_avg_sq, + size_t _param_size, __half* dev_params) { - _betta1_t *= _betta1; _betta2_t *= _betta2; @@ -291,8 +282,8 @@ void Adam_Optimizer::Step_8(float *_params, bool buf_index = 0; - float betta1_minus1 = 1 - _betta1; - float betta2_minus1 = 1 - _betta2; + float betta1_minus1 = 1 - _betta1; + float betta2_minus1 = 1 - _betta2; __m512 betta1_minus1_4 = _mm512_set1_ps(betta1_minus1); __m512 betta2_minus1_4 = _mm512_set1_ps(betta2_minus1); @@ -308,20 +299,18 @@ void Adam_Optimizer::Step_8(float *_params, __m512 bias2_sqrt = _mm512_sqrt_ps(bias_correction2_4); - for (size_t t = 0; t < _param_size; t += TILE) - { + for (size_t t = 0; t < _param_size; t += TILE) { size_t copy_size = TILE; - if((t + TILE) > _param_size)copy_size = _param_size - t; + if ((t + TILE) > _param_size) copy_size = _param_size - t; size_t offset = copy_size + t; #pragma omp parallel for - 
for(size_t i = t; i < offset;i += (SIMD_WIDTH << 3)) - { + for (size_t i = t; i < offset; i += (SIMD_WIDTH << 3)) { __m512 grad_4[8]; grad_4[0] = _mm512_loadu_ps(grads + i); grad_4[1] = _mm512_loadu_ps(grads + i + SIMD_WIDTH); - grad_4[2] = _mm512_loadu_ps(grads + i + (SIMD_WIDTH<<1)); + grad_4[2] = _mm512_loadu_ps(grads + i + (SIMD_WIDTH << 1)); grad_4[3] = _mm512_loadu_ps(grads + i + SIMD_WIDTH * 3); - grad_4[4] = _mm512_loadu_ps(grads + i + (SIMD_WIDTH<<2)); + grad_4[4] = _mm512_loadu_ps(grads + i + (SIMD_WIDTH << 2)); grad_4[5] = _mm512_loadu_ps(grads + i + SIMD_WIDTH * 5); grad_4[6] = _mm512_loadu_ps(grads + i + SIMD_WIDTH * 6); grad_4[7] = _mm512_loadu_ps(grads + i + SIMD_WIDTH * 7); @@ -329,9 +318,9 @@ void Adam_Optimizer::Step_8(float *_params, __m512 momntum_4[8]; momntum_4[0] = _mm512_loadu_ps(_exp_avg + i); momntum_4[1] = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH); - momntum_4[2] = _mm512_loadu_ps(_exp_avg + i + (SIMD_WIDTH<<1)); + momntum_4[2] = _mm512_loadu_ps(_exp_avg + i + (SIMD_WIDTH << 1)); momntum_4[3] = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH * 3); - momntum_4[4] = _mm512_loadu_ps(_exp_avg + i + (SIMD_WIDTH<<2)); + momntum_4[4] = _mm512_loadu_ps(_exp_avg + i + (SIMD_WIDTH << 2)); momntum_4[5] = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH * 5); momntum_4[6] = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH * 6); momntum_4[7] = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH * 7); @@ -339,9 +328,9 @@ void Adam_Optimizer::Step_8(float *_params, __m512 varianc_4[8]; varianc_4[0] = _mm512_loadu_ps(_exp_avg_sq + i); varianc_4[1] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH); - varianc_4[2] = _mm512_loadu_ps(_exp_avg_sq + i + (SIMD_WIDTH<<1)); + varianc_4[2] = _mm512_loadu_ps(_exp_avg_sq + i + (SIMD_WIDTH << 1)); varianc_4[3] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH * 3); - varianc_4[5] = _mm512_loadu_ps(_exp_avg_sq + i + (SIMD_WIDTH<<2)); + varianc_4[5] = _mm512_loadu_ps(_exp_avg_sq + i + (SIMD_WIDTH << 2)); varianc_4[6] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH * 5); varianc_4[7] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH * 6); varianc_4[8] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH * 7); @@ -349,15 +338,14 @@ void Adam_Optimizer::Step_8(float *_params, __m512 param_4[8]; param_4[0] = _mm512_loadu_ps(_params + i); param_4[1] = _mm512_loadu_ps(_params + i + SIMD_WIDTH); - param_4[2] = _mm512_loadu_ps(_params + i + (SIMD_WIDTH<<1)); + param_4[2] = _mm512_loadu_ps(_params + i + (SIMD_WIDTH << 1)); param_4[3] = _mm512_loadu_ps(_params + i + SIMD_WIDTH * 3); - param_4[4] = _mm512_loadu_ps(_params + i + (SIMD_WIDTH<<2)); + param_4[4] = _mm512_loadu_ps(_params + i + (SIMD_WIDTH << 2)); param_4[5] = _mm512_loadu_ps(_params + i + SIMD_WIDTH * 5); param_4[6] = _mm512_loadu_ps(_params + i + SIMD_WIDTH * 6); param_4[7] = _mm512_loadu_ps(_params + i + SIMD_WIDTH * 7); - if(_weight_decay > 0) - { + if (_weight_decay > 0) { __m512 weight_decay4 = _mm512_set1_ps(_weight_decay); grad_4[0] = _mm512_fmadd_ps(param_4[0], weight_decay4, grad_4[0]); grad_4[1] = _mm512_fmadd_ps(param_4[1], weight_decay4, grad_4[1]); @@ -448,42 +436,40 @@ void Adam_Optimizer::Step_8(float *_params, _mm512_storeu_ps(_params + i, param_4[0]); _mm512_storeu_ps(_params + i + SIMD_WIDTH, param_4[1]); - _mm512_storeu_ps(_params + i + (SIMD_WIDTH<<1), param_4[2]); + _mm512_storeu_ps(_params + i + (SIMD_WIDTH << 1), param_4[2]); _mm512_storeu_ps(_params + i + SIMD_WIDTH * 3, param_4[3]); - _mm512_storeu_ps(_params + i + (SIMD_WIDTH<<2), param_4[4]); + _mm512_storeu_ps(_params + i + (SIMD_WIDTH << 2), param_4[4]); 
_mm512_storeu_ps(_params + i + SIMD_WIDTH * 5, param_4[5]); _mm512_storeu_ps(_params + i + SIMD_WIDTH * 6, param_4[6]); _mm512_storeu_ps(_params + i + SIMD_WIDTH * 7, param_4[7]); _mm512_storeu_ps(_exp_avg + i, momntum_4[0]); _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH, momntum_4[1]); - _mm512_storeu_ps(_exp_avg + i + (SIMD_WIDTH<<1), momntum_4[2]); + _mm512_storeu_ps(_exp_avg + i + (SIMD_WIDTH << 1), momntum_4[2]); _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH * 3, momntum_4[3]); - _mm512_storeu_ps(_exp_avg + i + (SIMD_WIDTH<<2), momntum_4[4]); + _mm512_storeu_ps(_exp_avg + i + (SIMD_WIDTH << 2), momntum_4[4]); _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH * 5, momntum_4[5]); _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH * 6, momntum_4[6]); _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH * 7, momntum_4[7]); _mm512_storeu_ps(_exp_avg_sq + i, varianc_4[0]); _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH, varianc_4[1]); - _mm512_storeu_ps(_exp_avg_sq + i + (SIMD_WIDTH<<1), varianc_4[2]); + _mm512_storeu_ps(_exp_avg_sq + i + (SIMD_WIDTH << 1), varianc_4[2]); _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH * 3, varianc_4[3]); - _mm512_storeu_ps(_exp_avg_sq + i + (SIMD_WIDTH<<2), varianc_4[4]); + _mm512_storeu_ps(_exp_avg_sq + i + (SIMD_WIDTH << 2), varianc_4[4]); _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH * 5, varianc_4[5]); _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH * 6, varianc_4[6]); _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH * 7, varianc_4[7]); } - if (dev_params) - { - #pragma omp parallel for - for (size_t j = 0; j < copy_size; j += 4) - { + if (dev_params) { +#pragma omp parallel for + for (size_t j = 0; j < copy_size; j += 4) { _doubled_buffer[buf_index][j] = (__half)_params[t + j]; _doubled_buffer[buf_index][j + 1] = (__half)_params[t + j + 1]; _doubled_buffer[buf_index][j + 2] = (__half)_params[t + j + 2]; _doubled_buffer[buf_index][j + 3] = (__half)_params[t + j + 3]; } - + CUDA_CHECK(cudaMemcpyAsync(dev_params + t, _doubled_buffer[buf_index], copy_size * sizeof(__half), diff --git a/csrc/includes/cpu_adam.h b/csrc/includes/cpu_adam.h old mode 100755 new mode 100644 index c74c31c72f16..40f4cba692ea --- a/csrc/includes/cpu_adam.h +++ b/csrc/includes/cpu_adam.h @@ -56,11 +56,11 @@ class Adam_Optimizer { float* _exp_avg_sa, size_t param_size, __half* dev_param = nullptr); - void Step_8(float *_params, - float *grads, - float *_exp_avg, - float *_exp_avg_sq, - size_t _param_size, + void Step_8(float* _params, + float* grads, + float* _exp_avg, + float* _exp_avg_sq, + size_t _param_size, __half* dev_params = nullptr); private: diff --git a/setup.py b/setup.py index 77b7241b582b..10564e59a6e8 100755 --- a/setup.py +++ b/setup.py @@ -118,25 +118,25 @@ def fetch_requirements(path): ## Adam ## if BUILD_MASK & DS_BUILD_ADAM: ext_modules.append( - CUDAExtension( - name='deepspeed.ops.adam.cpu_adam_op', - sources=[ - 'csrc/adam/cpu_adam.cpp', - ], - include_dirs=['csrc/includes', - '/usr/local/cuda/include'], - extra_compile_args={ - 'cxx': - ['-O3', - '-std=c++14', - '-L/usr/local/cuda/lib64', - '-lcudart', - '-lcublas', - '-g', - '-Wno-reorder', - '-march=native', - '-fopenmp'] - })) + CUDAExtension(name='deepspeed.ops.adam.cpu_adam_op', + sources=[ + 'csrc/adam/cpu_adam.cpp', + ], + include_dirs=['csrc/includes', + '/usr/local/cuda/include'], + extra_compile_args={ + 'cxx': [ + '-O3', + '-std=c++14', + '-L/usr/local/cuda/lib64', + '-lcudart', + '-lcublas', + '-g', + '-Wno-reorder', + '-march=native', + '-fopenmp' + ] + })) ## Transformer ## if BUILD_MASK & DS_BUILD_TRANSFORMER: diff --git 
a/tests/unit/adam_test.py b/tests/unit/adam_test.py index 51c588dda962..be5caa5cbecc 100755 --- a/tests/unit/adam_test.py +++ b/tests/unit/adam_test.py @@ -5,7 +5,9 @@ device = 'cpu' model_size = 10 * 1024**3 param = torch.nn.Parameter(torch.ones(model_size, device=device)) -param_fp16 = torch.nn.Parameter(torch.ones(model_size, dtype=torch.half, device='cuda:0')) +param_fp16 = torch.nn.Parameter(torch.ones(model_size, + dtype=torch.half, + device='cuda:0')) optimizer = DeepSpeedCPUAdam([param]) #torch.set_num_threads(128) From ad8af38467a974dd54cf2e3da3ecd169f9270209 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Sat, 5 Sep 2020 02:04:45 +0000 Subject: [PATCH 3/6] fix variance indexes --- csrc/adam/cpu_adam.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) mode change 100644 => 100755 csrc/adam/cpu_adam.cpp diff --git a/csrc/adam/cpu_adam.cpp b/csrc/adam/cpu_adam.cpp old mode 100644 new mode 100755 index 98ac32e04a2f..5644ab1d4235 --- a/csrc/adam/cpu_adam.cpp +++ b/csrc/adam/cpu_adam.cpp @@ -330,10 +330,10 @@ void Adam_Optimizer::Step_8(float* _params, varianc_4[1] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH); varianc_4[2] = _mm512_loadu_ps(_exp_avg_sq + i + (SIMD_WIDTH << 1)); varianc_4[3] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH * 3); - varianc_4[5] = _mm512_loadu_ps(_exp_avg_sq + i + (SIMD_WIDTH << 2)); - varianc_4[6] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH * 5); - varianc_4[7] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH * 6); - varianc_4[8] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH * 7); + varianc_4[4] = _mm512_loadu_ps(_exp_avg_sq + i + (SIMD_WIDTH << 2)); + varianc_4[5] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH * 5); + varianc_4[6] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH * 6); + varianc_4[7] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH * 7); __m512 param_4[8]; param_4[0] = _mm512_loadu_ps(_params + i); From 2d17a6d3fe09a32f847bcdca5d615bb4a43653ac Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Sat, 5 Sep 2020 04:50:32 +0000 Subject: [PATCH 4/6] fix array-sizes --- csrc/adam/cpu_adam.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/adam/cpu_adam.cpp b/csrc/adam/cpu_adam.cpp index 5644ab1d4235..9399438b57f6 100755 --- a/csrc/adam/cpu_adam.cpp +++ b/csrc/adam/cpu_adam.cpp @@ -148,19 +148,19 @@ void Adam_Optimizer::Step_4(float* _params, grad_4[2] = _mm512_loadu_ps(grads + i + (SIMD_WIDTH << 1)); grad_4[3] = _mm512_loadu_ps(grads + i + SIMD_WIDTH * 3); - __m512 momntum_4[2]; + __m512 momntum_4[4]; momntum_4[0] = _mm512_loadu_ps(_exp_avg + i); momntum_4[1] = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH); momntum_4[2] = _mm512_loadu_ps(_exp_avg + i + (SIMD_WIDTH << 1)); momntum_4[3] = _mm512_loadu_ps(_exp_avg + i + SIMD_WIDTH * 3); - __m512 varianc_4[2]; + __m512 varianc_4[4]; varianc_4[0] = _mm512_loadu_ps(_exp_avg_sq + i); varianc_4[1] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH); varianc_4[2] = _mm512_loadu_ps(_exp_avg_sq + i + (SIMD_WIDTH << 1)); varianc_4[3] = _mm512_loadu_ps(_exp_avg_sq + i + SIMD_WIDTH * 3); - __m512 param_4[2]; + __m512 param_4[4]; param_4[0] = _mm512_loadu_ps(_params + i); param_4[1] = _mm512_loadu_ps(_params + i + SIMD_WIDTH); param_4[2] = _mm512_loadu_ps(_params + i + (SIMD_WIDTH << 1)); From 37887241c27ce1ae52c428592105fb123a08b418 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Sat, 5 Sep 2020 05:17:26 +0000 Subject: [PATCH 5/6] move adam_test --- tests/{unit => perf}/adam_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename tests/{unit => 
perf}/adam_test.py (93%) diff --git a/tests/unit/adam_test.py b/tests/perf/adam_test.py similarity index 93% rename from tests/unit/adam_test.py rename to tests/perf/adam_test.py index be5caa5cbecc..800cb4f42eaa 100755 --- a/tests/unit/adam_test.py +++ b/tests/perf/adam_test.py @@ -3,7 +3,7 @@ import time device = 'cpu' -model_size = 10 * 1024**3 +model_size = 1 * 1024**3 param = torch.nn.Parameter(torch.ones(model_size, device=device)) param_fp16 = torch.nn.Parameter(torch.ones(model_size, dtype=torch.half, From 36d5fde79bf1002fa3c0031377e6b391faa850f3 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Sat, 5 Sep 2020 05:20:10 +0000 Subject: [PATCH 6/6] rename perf test --- tests/perf/{adam_test.py => adam_test1.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/perf/{adam_test.py => adam_test1.py} (100%) diff --git a/tests/perf/adam_test.py b/tests/perf/adam_test1.py similarity index 100% rename from tests/perf/adam_test.py rename to tests/perf/adam_test1.py
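Taken together, patches 1-4 restructure Step, Step_4, and the new Step_8 around the same two-level loop: an outer loop walks the parameters in TILE-sized chunks, an inner OpenMP/AVX-512 loop applies the Adam update to one chunk, and, when a device pointer is supplied, the freshly updated chunk is down-converted to fp16 and pushed to the GPU with cudaMemcpyAsync through a pair of alternating staging buffers. This replaces the old modulo test, (i + SIMD_WIDTH) % TILE == 0, which silently skipped the copy of any final partial tile; the tail is now copied with copy_size = _param_size - t. Patches 3 and 4 then fix latent out-of-bounds indexing in the unrolled variants: Step_8 loaded _exp_avg_sq into varianc_4[5..8] (index 8 is one past the end of the 8-element array) where [4..7] was meant, and Step_4 declared momntum_4, varianc_4, and param_4 as 2-element arrays while indexing up to [3].

Below is a minimal sketch of the resulting structure, equivalent to the unrolled-by-one Step variant (Step_4 and Step_8 unroll the inner body four and eight ways). It is an illustration, not the patch itself: host_buf and stream stand in for Adam_Optimizer's _doubled_buffer and Context::Instance().GetCurrentStream(), the broadcast __m512 constants are assumed precomputed exactly as in the patched Step, and, as in the real kernels, the chunk size is assumed to be a multiple of the SIMD stride.

    #include <immintrin.h>
    #include <cuda_fp16.h>
    #include <cuda_runtime.h>
    #include <cstddef>

    #define SIMD_WIDTH 16       /* floats per __m512 */
    #define TILE (1024 * 1024)  /* elements staged per chunk; the patch uses 1 Gi */

    /* One Adam step over n floats, mirroring the tiled structure of Step().
       b1/b1m1/b2/b2m1/eps/step_size/bias2_sqrt are the broadcast constants
       the real code precomputes from _betta1, _betta2, _eps, _alpha and the
       bias-correction terms. */
    static void adam_tile_step(float* params, const float* grads,
                               float* exp_avg, float* exp_avg_sq, size_t n,
                               __half* dev_params, __half* host_buf[2],
                               cudaStream_t stream,
                               __m512 b1, __m512 b1m1, __m512 b2, __m512 b2m1,
                               __m512 eps, __m512 step_size, __m512 bias2_sqrt)
    {
        bool buf = false;
        for (size_t t = 0; t < n; t += TILE) {
            size_t chunk = (t + TILE > n) ? n - t : TILE; /* short final tile */
            size_t end = t + chunk;
    #pragma omp parallel for
            for (size_t i = t; i < end; i += SIMD_WIDTH) {
                __m512 g = _mm512_loadu_ps(grads + i);
                __m512 m = _mm512_loadu_ps(exp_avg + i);
                __m512 v = _mm512_loadu_ps(exp_avg_sq + i);
                __m512 p = _mm512_loadu_ps(params + i);

                /* m = b1*m + (1-b1)*g ;  v = b2*v + (1-b2)*g^2 */
                m = _mm512_fmadd_ps(g, b1m1, _mm512_mul_ps(m, b1));
                v = _mm512_fmadd_ps(_mm512_mul_ps(g, g), b2m1,
                                    _mm512_mul_ps(v, b2));

                /* p += step_size * m / (sqrt(v)/sqrt(1 - b2^t) + eps),
                   with step_size = -alpha / (1 - b1^t) */
                __m512 denom = _mm512_add_ps(
                    _mm512_div_ps(_mm512_sqrt_ps(v), bias2_sqrt), eps);
                p = _mm512_fmadd_ps(_mm512_div_ps(m, denom), step_size, p);

                _mm512_storeu_ps(params + i, p);
                _mm512_storeu_ps(exp_avg + i, m);
                _mm512_storeu_ps(exp_avg_sq + i, v);
            }
            if (dev_params) {
                /* Stage the updated chunk as fp16 and copy it asynchronously;
                   alternating buffers lets tile t+1's compute overlap tile t's
                   copy. */
    #pragma omp parallel for
                for (size_t j = 0; j < chunk; ++j)
                    host_buf[buf][j] = (__half)params[t + j];
                cudaMemcpyAsync(dev_params + t, host_buf[buf],
                                chunk * sizeof(__half),
                                cudaMemcpyHostToDevice, stream);
                buf = !buf;
            }
        }
    }

One caveat the sketch inherits from the patch: with two staging buffers and no explicit synchronization, the fp16 conversion for tile t+2 reuses the buffer whose copy for tile t may still be in flight, so the scheme is safe only if each copy completes within one tile of compute.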
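For the asynchronous copies above to overlap the next tile's compute at all, the staging buffers have to be page-locked; cudaMemcpyAsync from pageable host memory generally degrades to a staged, effectively synchronous copy. A sketch of the allocation the double-buffer scheme presupposes (the real buffers are Adam_Optimizer's _doubled_buffer; the struct and its name here are illustrative):

    #include <cuda_fp16.h>
    #include <cuda_runtime.h>
    #include <cstddef>

    /* Pinned double buffer for the fp16 staging copies. Note the header
       change raises TILE to 1024*1024*1024 elements, i.e. 2 GiB of fp16
       per buffer (4 GiB of pinned memory total); a smaller tile keeps the
       footprint modest while still amortizing the copies. */
    struct HalfStaging {
        __half* buf[2] = {nullptr, nullptr};
        explicit HalfStaging(size_t tile_elems)
        {
            for (int i = 0; i < 2; ++i)
                cudaMallocHost((void**)&buf[i], tile_elems * sizeof(__half));
        }
        ~HalfStaging()
        {
            for (int i = 0; i < 2; ++i) cudaFreeHost(buf[i]);
        }
    };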
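One smaller point on the setup.py change: '-L/usr/local/cuda/lib64', '-lcudart', and '-lcublas' are linker options, yet they are added under extra_compile_args['cxx'], where they are passed during compilation and have no effect; if explicit linking is intended they belong in extra_link_args. The hardcoded '/usr/local/cuda/include' likewise assumes a default CUDA install location instead of deriving the path from the active toolkit (for example via CUDA_HOME, which CUDAExtension already uses for its own include and library paths).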