diff --git a/csrc/adam/cpu_adam.cpp b/csrc/adam/cpu_adam.cpp
old mode 100755
new mode 100644
index e6afa7afe3e8..2baab6177c2e
--- a/csrc/adam/cpu_adam.cpp
+++ b/csrc/adam/cpu_adam.cpp
@@ -1,5 +1,6 @@
 #include "cpu_adam.h"
 #include
+#include
 #include
 #include
 #include
@@ -10,13 +11,11 @@
 #include "cuda.h"
 #include "curand.h"
 #include "custom_cuda_layers.h"
-#include

 static std::unordered_map<int, std::shared_ptr<Adam_Optimizer>> s_optimizers;

 #define ROUND_DOWN(size, step) ((size) & ~((step)-1))
-
 // C++ interface

 void Adam_Optimizer::Step(float* _params,
@@ -43,8 +42,8 @@ void Adam_Optimizer::Step(float* _params,
     float bias_correction1 = 1 - _betta1_t;
     float bias_correction2 = 1 / sqrt(1 - _betta2_t);

-    //AVX_512 bias_correction1_4 = _mm512_set1_ps(bias_correction1);
-    AVX_512 bias2_sqrt ;
+    // AVX_512 bias_correction1_4 = _mm512_set1_ps(bias_correction1);
+    AVX_512 bias2_sqrt;
     bias2_sqrt.data = _mm512_set1_ps(bias_correction2);

     AVX_512 eps_4;
@@ -91,48 +90,44 @@ void Adam_Optimizer::Step(float* _params,
             grad_4.data = _mm512_div_ps(momntum_4.data, grad_4.data);

             param_4.data = _mm512_fmadd_ps(grad_4.data, step_size_4.data, param_4.data);

-            if (dev_params) {
-                for (size_t j = 0; j < SIMD_WIDTH; j += 4) {
-                    _doubled_buffer[_buf_index][(i - t) + (j << 2)] = (__half)param_4.data_f[(j << 2)];
-                    _doubled_buffer[_buf_index][(i - t) + (j << 2) + 1] = (__half)param_4.data_f[(j << 2) + 1];
-                    _doubled_buffer[_buf_index][(i - t) + (j << 2) + 2] = (__half)param_4.data_f[(j << 2) + 2];
-                    _doubled_buffer[_buf_index][(i - t) + (j << 2) + 3] = (__half)param_4.data_f[(j << 2) + 3];
-                }
-            }
+            _mm512_storeu_ps(_params + i, param_4.data);
+
+            if (dev_params) _mm512_storeu_ps(_doubled_buffer[_buf_index] + (i - t), param_4.data);
+
             _mm512_storeu_ps(_exp_avg + i, momntum_4.data);
             _mm512_storeu_ps(_exp_avg_sq + i, varianc_4.data);
         }
-        if (dev_params) {/*
-#pragma omp parallel for
-            for (size_t j = 0; j < copy_size; j += 4) {
-                _doubled_buffer[_buf_index][j] = (__half)_params[t + j];
-                _doubled_buffer[_buf_index][j + 1] = (__half)_params[t + j + 1];
-                _doubled_buffer[_buf_index][j + 2] = (__half)_params[t + j + 2];
-                _doubled_buffer[_buf_index][j + 3] = (__half)_params[t + j + 3];
-            }*/
-
-            CUDA_CHECK(cudaMemcpyAsync(dev_params + t,
-                                       _doubled_buffer[_buf_index],
-                                       copy_size * sizeof(__half),
-                                       cudaMemcpyHostToDevice,
-                                       Context::Instance().GetCurrentStream()));
+        if (dev_params) { /*
+#pragma omp parallel for
+             for (size_t j = 0; j < copy_size; j += 4) {
+                 _doubled_buffer[_buf_index][j] = (__half)_params[t + j];
+                 _doubled_buffer[_buf_index][j + 1] = (__half)_params[t + j + 1];
+                 _doubled_buffer[_buf_index][j + 2] = (__half)_params[t + j + 2];
+                 _doubled_buffer[_buf_index][j + 3] = (__half)_params[t + j + 3];
+             }
+
+             CUDA_CHECK(cudaMemcpyAsync(dev_params + t,
+                                        _doubled_buffer[_buf_index],
+                                        copy_size * sizeof(__half),
+                                        cudaMemcpyHostToDevice,
+                                        Context::Instance().GetCurrentStream()));*/
+            launch_param_update(_doubled_buffer[_buf_index],
+                                dev_params + t,
+                                copy_size,
+                                Context::Instance().GetCurrentStream());

             _buf_index = !_buf_index;
         }
     }
-    if(_param_size > rounded_size)
-    {
+    if (_param_size > rounded_size) {
 #pragma omp parallel for
-        for (size_t k = rounded_size; k < _param_size; k++)
-        {
+        for (size_t k = rounded_size; k < _param_size; k++) {
             float grad = grads[k];
             float param = _params[k];
             float momntum = _exp_avg[k];
             float varianc = _exp_avg_sq[k];
-            if (_weight_decay > 0) {
-                grad = param * _weight_decay + grad;
-            }
+            if (_weight_decay > 0) { grad = param * _weight_decay + grad; }
             momntum *= _betta1;
             momntum = grad * betta1_minus1 + momntum;
@@ -146,19 +141,17 @@ void Adam_Optimizer::Step(float* _params,
             grad = momntum / grad;
             param = grad * step_size + param;
-            if (dev_params)
-                _doubled_buffer[_buf_index][k - rounded_size] = (__half)param;
+            if (dev_params) _doubled_buffer[_buf_index][k - rounded_size] = (__half)param;

             _params[k] = param;
             _exp_avg[k] = momntum;
             _exp_avg_sq[k] = varianc;
         }
         if (dev_params) {
-            CUDA_CHECK(cudaMemcpyAsync(dev_params + rounded_size,
-                                       _doubled_buffer[_buf_index],
-                                       (_param_size - rounded_size) * sizeof(__half),
-                                       cudaMemcpyHostToDevice,
-                                       Context::Instance().GetCurrentStream()));
+            launch_param_update(_doubled_buffer[_buf_index],
+                                dev_params + rounded_size,
+                                (_param_size - rounded_size),
+                                Context::Instance().GetCurrentStream());
         }
     }
 }
@@ -187,8 +180,8 @@ void Adam_Optimizer::Step_4(float* _params,
     float bias_correction1 = 1 - _betta1_t;
     float bias_correction2 = 1 / sqrt(1 - _betta2_t);

-    //AVX_512 bias_correction1_4 = _mm512_set1_ps(bias_correction1);
-    AVX_512 bias2_sqrt ;
+    // AVX_512 bias_correction1_4 = _mm512_set1_ps(bias_correction1);
+    AVX_512 bias2_sqrt;
     bias2_sqrt.data = _mm512_set1_ps(bias_correction2);

     AVX_512 eps_4;
@@ -233,20 +226,28 @@ void Adam_Optimizer::Step_4(float* _params,
             if (_weight_decay > 0) {
                 AVX_512 weight_decay4;
                 weight_decay4.data = _mm512_set1_ps(_weight_decay);
-                grad_4[0].data = _mm512_fmadd_ps(param_4[0].data, weight_decay4.data, grad_4[0].data);
-                grad_4[1].data = _mm512_fmadd_ps(param_4[1].data, weight_decay4.data, grad_4[1].data);
-                grad_4[2].data = _mm512_fmadd_ps(param_4[2].data, weight_decay4.data, grad_4[2].data);
-                grad_4[3].data = _mm512_fmadd_ps(param_4[3].data, weight_decay4.data, grad_4[3].data);
+                grad_4[0].data =
+                    _mm512_fmadd_ps(param_4[0].data, weight_decay4.data, grad_4[0].data);
+                grad_4[1].data =
+                    _mm512_fmadd_ps(param_4[1].data, weight_decay4.data, grad_4[1].data);
+                grad_4[2].data =
+                    _mm512_fmadd_ps(param_4[2].data, weight_decay4.data, grad_4[2].data);
+                grad_4[3].data =
+                    _mm512_fmadd_ps(param_4[3].data, weight_decay4.data, grad_4[3].data);
             }

             momntum_4[0].data = _mm512_mul_ps(momntum_4[0].data, betta1_4.data);
-            momntum_4[0].data = _mm512_fmadd_ps(grad_4[0].data, betta1_minus1_4.data, momntum_4[0].data);
+            momntum_4[0].data =
+                _mm512_fmadd_ps(grad_4[0].data, betta1_minus1_4.data, momntum_4[0].data);
             momntum_4[1].data = _mm512_mul_ps(momntum_4[1].data, betta1_4.data);
-            momntum_4[1].data = _mm512_fmadd_ps(grad_4[1].data, betta1_minus1_4.data, momntum_4[1].data);
+            momntum_4[1].data =
+                _mm512_fmadd_ps(grad_4[1].data, betta1_minus1_4.data, momntum_4[1].data);
             momntum_4[2].data = _mm512_mul_ps(momntum_4[2].data, betta1_4.data);
-            momntum_4[2].data = _mm512_fmadd_ps(grad_4[2].data, betta1_minus1_4.data, momntum_4[2].data);
+            momntum_4[2].data =
+                _mm512_fmadd_ps(grad_4[2].data, betta1_minus1_4.data, momntum_4[2].data);
             momntum_4[3].data = _mm512_mul_ps(momntum_4[3].data, betta1_4.data);
-            momntum_4[3].data = _mm512_fmadd_ps(grad_4[3].data, betta1_minus1_4.data, momntum_4[3].data);
+            momntum_4[3].data =
+                _mm512_fmadd_ps(grad_4[3].data, betta1_minus1_4.data, momntum_4[3].data);

             varianc_4[0].data = _mm512_mul_ps(varianc_4[0].data, betta2_4.data);
             varianc_4[1].data = _mm512_mul_ps(varianc_4[1].data, betta2_4.data);
@@ -256,10 +257,14 @@ void Adam_Optimizer::Step_4(float* _params,
             grad_4[1].data = _mm512_mul_ps(grad_4[1].data, grad_4[1].data);
             grad_4[2].data = _mm512_mul_ps(grad_4[2].data, grad_4[2].data);
             grad_4[3].data = _mm512_mul_ps(grad_4[3].data, grad_4[3].data);
-            varianc_4[0].data = _mm512_fmadd_ps(grad_4[0].data, betta2_minus1_4.data, varianc_4[0].data);
-            varianc_4[1].data = _mm512_fmadd_ps(grad_4[1].data, betta2_minus1_4.data, varianc_4[1].data);
-            varianc_4[2].data = _mm512_fmadd_ps(grad_4[2].data, betta2_minus1_4.data, varianc_4[2].data);
-            varianc_4[3].data = _mm512_fmadd_ps(grad_4[3].data, betta2_minus1_4.data, varianc_4[3].data);
+            varianc_4[0].data =
+                _mm512_fmadd_ps(grad_4[0].data, betta2_minus1_4.data, varianc_4[0].data);
+            varianc_4[1].data =
+                _mm512_fmadd_ps(grad_4[1].data, betta2_minus1_4.data, varianc_4[1].data);
+            varianc_4[2].data =
+                _mm512_fmadd_ps(grad_4[2].data, betta2_minus1_4.data, varianc_4[2].data);
+            varianc_4[3].data =
+                _mm512_fmadd_ps(grad_4[3].data, betta2_minus1_4.data, varianc_4[3].data);

             grad_4[0].data = _mm512_sqrt_ps(varianc_4[0].data);
             grad_4[1].data = _mm512_sqrt_ps(varianc_4[1].data);
@@ -280,23 +285,21 @@ void Adam_Optimizer::Step_4(float* _params,
             param_4[2].data = _mm512_fmadd_ps(grad_4[2].data, step_size_4.data, param_4[2].data);
             param_4[3].data = _mm512_fmadd_ps(grad_4[3].data, step_size_4.data, param_4[3].data);

-            if (dev_params) {
-                for(int u = 0;u < 4;u++)
-                {
-                    for (size_t j = 0; j < SIMD_WIDTH; j += 4) {
-                        _doubled_buffer[_buf_index][(i - t) + (u << 4) + (j << 2)] = (__half)param_4[u].data_f[(j << 2)];
-                        _doubled_buffer[_buf_index][(i - t) + (u << 4) + (j << 2) + 1] = (__half)param_4[u].data_f[(j << 2) + 1];
-                        _doubled_buffer[_buf_index][(i - t) + (u << 4) + (j << 2) + 2] = (__half)param_4[u].data_f[(j << 2) + 2];
-                        _doubled_buffer[_buf_index][(i - t) + (u << 4) + (j << 2) + 3] = (__half)param_4[u].data_f[(j << 2) + 3];
-                    }
-                }
-            }
-
             _mm512_storeu_ps(_params + i, param_4[0].data);
             _mm512_storeu_ps(_params + i + SIMD_WIDTH, param_4[1].data);
             _mm512_storeu_ps(_params + i + (SIMD_WIDTH << 1), param_4[2].data);
             _mm512_storeu_ps(_params + i + SIMD_WIDTH * 3, param_4[3].data);

+            if (dev_params) {
+                _mm512_storeu_ps(_doubled_buffer[_buf_index] + (i - t), param_4[0].data);
+                _mm512_storeu_ps(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH,
+                                 param_4[1].data);
+                _mm512_storeu_ps(_doubled_buffer[_buf_index] + (i - t) + (SIMD_WIDTH << 1),
+                                 param_4[2].data);
+                _mm512_storeu_ps(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 3,
+                                 param_4[3].data);
+            }
+
             _mm512_storeu_ps(_exp_avg + i, momntum_4[0].data);
             _mm512_storeu_ps(_exp_avg + i + SIMD_WIDTH, momntum_4[1].data);
             _mm512_storeu_ps(_exp_avg + i + (SIMD_WIDTH << 1), momntum_4[2].data);
@@ -308,24 +311,29 @@ void Adam_Optimizer::Step_4(float* _params,
             _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH * 3, varianc_4[3].data);
         }
-        if (dev_params) {/*
-#pragma omp parallel for
-            for (size_t j = 0; j < copy_size; j += 4) {
-                _doubled_buffer[_buf_index][j] = (__half)_params[t + j];
-                _doubled_buffer[_buf_index][j + 1] = (__half)_params[t + j + 1];
-                _doubled_buffer[_buf_index][j + 2] = (__half)_params[t + j + 2];
-                _doubled_buffer[_buf_index][j + 3] = (__half)_params[t + j + 3];
-            }*/
-
-            CUDA_CHECK(cudaMemcpyAsync(dev_params + t,
-                                       _doubled_buffer[_buf_index],
-                                       copy_size * sizeof(__half),
+        if (dev_params) { /*
+#pragma omp parallel for
+             for (size_t j = 0; j < copy_size; j += 4) {
+                 _doubled_buffer[_buf_index][j] = (__half)_params[t + j];
+                 _doubled_buffer[_buf_index][j + 1] = (__half)_params[t + j + 1];
+                 _doubled_buffer[_buf_index][j + 2] = (__half)_params[t + j + 2];
+                 _doubled_buffer[_buf_index][j + 3] = (__half)_params[t + j + 3];
+             }
+
+             CUDA_CHECK(cudaMemcpyAsync(dev_params + t,
+                                        _doubled_buffer[_buf_index],
+                                        copy_size * sizeof(__half),
+                                        cudaMemcpyHostToDevice,
+                                        Context::Instance().GetCurrentStream()));
+             */
+            launch_param_update(_doubled_buffer[_buf_index],
+                                dev_params + t,
+                                copy_size,
+                                Context::Instance().GetCurrentStream());

             _buf_index = !_buf_index;
         }
     }
-    if(_param_size > rounded_size)
+    if (_param_size > rounded_size)
         Step((_params + rounded_size),
              (grads + rounded_size),
              (_exp_avg + rounded_size),
@@ -340,7 +348,6 @@ int create_adam_optimizer(int optimizer_id,
                          float betta2 = 0.999,
                          float eps = 1e-8,
                          float weight_decay = 0)
-
 {
     auto opt = std::make_shared<Adam_Optimizer>(alpha, betta1, betta2, eps, weight_decay);

@@ -375,8 +382,8 @@ void Adam_Optimizer::Step_8(float* _params,
     float bias_correction1 = 1 - _betta1_t;
     float bias_correction2 = 1 / sqrt(1 - _betta2_t);

-    //AVX_512 bias_correction1_4 = _mm512_set1_ps(bias_correction1);
-    AVX_512 bias2_sqrt ;
+    // AVX_512 bias_correction1_4 = _mm512_set1_ps(bias_correction1);
+    AVX_512 bias2_sqrt;
     bias2_sqrt.data = _mm512_set1_ps(bias_correction2);

     AVX_512 eps_4;
@@ -437,32 +444,48 @@ void Adam_Optimizer::Step_8(float* _params,
             if (_weight_decay > 0) {
                 AVX_512 weight_decay4;
                 weight_decay4.data = _mm512_set1_ps(_weight_decay);
-                grad_4[0].data = _mm512_fmadd_ps(param_4[0].data, weight_decay4.data, grad_4[0].data);
-                grad_4[1].data = _mm512_fmadd_ps(param_4[1].data, weight_decay4.data, grad_4[1].data);
-                grad_4[2].data = _mm512_fmadd_ps(param_4[2].data, weight_decay4.data, grad_4[2].data);
-                grad_4[3].data = _mm512_fmadd_ps(param_4[3].data, weight_decay4.data, grad_4[3].data);
-                grad_4[4].data = _mm512_fmadd_ps(param_4[4].data, weight_decay4.data, grad_4[4].data);
-                grad_4[5].data = _mm512_fmadd_ps(param_4[5].data, weight_decay4.data, grad_4[5].data);
-                grad_4[6].data = _mm512_fmadd_ps(param_4[6].data, weight_decay4.data, grad_4[6].data);
-                grad_4[7].data = _mm512_fmadd_ps(param_4[7].data, weight_decay4.data, grad_4[7].data);
+                grad_4[0].data =
+                    _mm512_fmadd_ps(param_4[0].data, weight_decay4.data, grad_4[0].data);
+                grad_4[1].data =
+                    _mm512_fmadd_ps(param_4[1].data, weight_decay4.data, grad_4[1].data);
+                grad_4[2].data =
+                    _mm512_fmadd_ps(param_4[2].data, weight_decay4.data, grad_4[2].data);
+                grad_4[3].data =
+                    _mm512_fmadd_ps(param_4[3].data, weight_decay4.data, grad_4[3].data);
+                grad_4[4].data =
+                    _mm512_fmadd_ps(param_4[4].data, weight_decay4.data, grad_4[4].data);
+                grad_4[5].data =
+                    _mm512_fmadd_ps(param_4[5].data, weight_decay4.data, grad_4[5].data);
+                grad_4[6].data =
+                    _mm512_fmadd_ps(param_4[6].data, weight_decay4.data, grad_4[6].data);
+                grad_4[7].data =
+                    _mm512_fmadd_ps(param_4[7].data, weight_decay4.data, grad_4[7].data);
             }

             momntum_4[0].data = _mm512_mul_ps(momntum_4[0].data, betta1_4.data);
-            momntum_4[0].data = _mm512_fmadd_ps(grad_4[0].data, betta1_minus1_4.data, momntum_4[0].data);
+            momntum_4[0].data =
+                _mm512_fmadd_ps(grad_4[0].data, betta1_minus1_4.data, momntum_4[0].data);
             momntum_4[1].data = _mm512_mul_ps(momntum_4[1].data, betta1_4.data);
-            momntum_4[1].data = _mm512_fmadd_ps(grad_4[1].data, betta1_minus1_4.data, momntum_4[1].data);
+            momntum_4[1].data =
+                _mm512_fmadd_ps(grad_4[1].data, betta1_minus1_4.data, momntum_4[1].data);
             momntum_4[2].data = _mm512_mul_ps(momntum_4[2].data, betta1_4.data);
-            momntum_4[2].data = _mm512_fmadd_ps(grad_4[2].data, betta1_minus1_4.data, momntum_4[2].data);
+            momntum_4[2].data =
+                _mm512_fmadd_ps(grad_4[2].data, betta1_minus1_4.data, momntum_4[2].data);
             momntum_4[3].data = _mm512_mul_ps(momntum_4[3].data, betta1_4.data);
-            momntum_4[3].data = _mm512_fmadd_ps(grad_4[3].data, betta1_minus1_4.data, momntum_4[3].data);
+            momntum_4[3].data =
+                _mm512_fmadd_ps(grad_4[3].data, betta1_minus1_4.data, momntum_4[3].data);
             momntum_4[4].data = _mm512_mul_ps(momntum_4[4].data, betta1_4.data);
-            momntum_4[4].data = _mm512_fmadd_ps(grad_4[4].data, betta1_minus1_4.data, momntum_4[4].data);
+            momntum_4[4].data =
+                _mm512_fmadd_ps(grad_4[4].data, betta1_minus1_4.data, momntum_4[4].data);
             momntum_4[5].data = _mm512_mul_ps(momntum_4[5].data, betta1_4.data);
-            momntum_4[5].data = _mm512_fmadd_ps(grad_4[5].data, betta1_minus1_4.data, momntum_4[5].data);
+            momntum_4[5].data =
+                _mm512_fmadd_ps(grad_4[5].data, betta1_minus1_4.data, momntum_4[5].data);
             momntum_4[6].data = _mm512_mul_ps(momntum_4[6].data, betta1_4.data);
-            momntum_4[6].data = _mm512_fmadd_ps(grad_4[6].data, betta1_minus1_4.data, momntum_4[6].data);
+            momntum_4[6].data =
+                _mm512_fmadd_ps(grad_4[6].data, betta1_minus1_4.data, momntum_4[6].data);
             momntum_4[7].data = _mm512_mul_ps(momntum_4[7].data, betta1_4.data);
-            momntum_4[7].data = _mm512_fmadd_ps(grad_4[7].data, betta1_minus1_4.data, momntum_4[7].data);
+            momntum_4[7].data =
+                _mm512_fmadd_ps(grad_4[7].data, betta1_minus1_4.data, momntum_4[7].data);

             varianc_4[0].data = _mm512_mul_ps(varianc_4[0].data, betta2_4.data);
             varianc_4[1].data = _mm512_mul_ps(varianc_4[1].data, betta2_4.data);
@@ -480,14 +503,22 @@ void Adam_Optimizer::Step_8(float* _params,
             grad_4[5].data = _mm512_mul_ps(grad_4[5].data, grad_4[5].data);
             grad_4[6].data = _mm512_mul_ps(grad_4[6].data, grad_4[6].data);
             grad_4[7].data = _mm512_mul_ps(grad_4[7].data, grad_4[7].data);
-            varianc_4[0].data = _mm512_fmadd_ps(grad_4[0].data, betta2_minus1_4.data, varianc_4[0].data);
-            varianc_4[1].data = _mm512_fmadd_ps(grad_4[1].data, betta2_minus1_4.data, varianc_4[1].data);
-            varianc_4[2].data = _mm512_fmadd_ps(grad_4[2].data, betta2_minus1_4.data, varianc_4[2].data);
-            varianc_4[3].data = _mm512_fmadd_ps(grad_4[3].data, betta2_minus1_4.data, varianc_4[3].data);
-            varianc_4[4].data = _mm512_fmadd_ps(grad_4[4].data, betta2_minus1_4.data, varianc_4[4].data);
-            varianc_4[5].data = _mm512_fmadd_ps(grad_4[5].data, betta2_minus1_4.data, varianc_4[5].data);
-            varianc_4[6].data = _mm512_fmadd_ps(grad_4[6].data, betta2_minus1_4.data, varianc_4[6].data);
-            varianc_4[7].data = _mm512_fmadd_ps(grad_4[7].data, betta2_minus1_4.data, varianc_4[7].data);
+            varianc_4[0].data =
+                _mm512_fmadd_ps(grad_4[0].data, betta2_minus1_4.data, varianc_4[0].data);
+            varianc_4[1].data =
+                _mm512_fmadd_ps(grad_4[1].data, betta2_minus1_4.data, varianc_4[1].data);
+            varianc_4[2].data =
+                _mm512_fmadd_ps(grad_4[2].data, betta2_minus1_4.data, varianc_4[2].data);
+            varianc_4[3].data =
+                _mm512_fmadd_ps(grad_4[3].data, betta2_minus1_4.data, varianc_4[3].data);
+            varianc_4[4].data =
+                _mm512_fmadd_ps(grad_4[4].data, betta2_minus1_4.data, varianc_4[4].data);
+            varianc_4[5].data =
+                _mm512_fmadd_ps(grad_4[5].data, betta2_minus1_4.data, varianc_4[5].data);
+            varianc_4[6].data =
+                _mm512_fmadd_ps(grad_4[6].data, betta2_minus1_4.data, varianc_4[6].data);
+            varianc_4[7].data =
+                _mm512_fmadd_ps(grad_4[7].data, betta2_minus1_4.data, varianc_4[7].data);

             grad_4[0].data = _mm512_sqrt_ps(varianc_4[0].data);
             grad_4[1].data = _mm512_sqrt_ps(varianc_4[1].data);
@@ -533,24 +564,22 @@ void Adam_Optimizer::Step_8(float* _params,
             _mm512_storeu_ps(_params + i + SIMD_WIDTH * 6, param_4[6].data);
             _mm512_storeu_ps(_params + i + SIMD_WIDTH * 7, param_4[7].data);

-        //_mm512_storeu_ps(_doubled_buffer[_buf_index] + (i - t), param_4[0]);
-        //_mm512_storeu_ps(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH, param_4[1]);
-        //_mm512_storeu_ps(_doubled_buffer[_buf_index] + (i - t) + (SIMD_WIDTH << 1), param_4[2]);
-        //_mm512_storeu_ps(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 3, param_4[3]);
-        //_mm512_storeu_ps(_doubled_buffer[_buf_index] + (i - t) + (SIMD_WIDTH << 2), param_4[4]);
-        //_mm512_storeu_ps(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 5, param_4[5]);
-        //_mm512_storeu_ps(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 6, param_4[6]);
-        //_mm512_storeu_ps(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 7, param_4[7]);
             if (dev_params) {
-                for(int u = 0;u < 8;u++)
-                {
-                    for (size_t j = 0; j < SIMD_WIDTH; j += 4) {
-                        _doubled_buffer[_buf_index][(i - t) + (u << 4) + (j << 2)] = (__half)param_4[u].data_f[(j << 2)];
-                        _doubled_buffer[_buf_index][(i - t) + (u << 4) + (j << 2) + 1] = (__half)param_4[u].data_f[(j << 2) + 1];
-                        _doubled_buffer[_buf_index][(i - t) + (u << 4) + (j << 2) + 2] = (__half)param_4[u].data_f[(j << 2) + 2];
-                        _doubled_buffer[_buf_index][(i - t) + (u << 4) + (j << 2) + 3] = (__half)param_4[u].data_f[(j << 2) + 3];
-                    }
-                }
+                _mm512_storeu_ps(_doubled_buffer[_buf_index] + (i - t), param_4[0].data);
+                _mm512_storeu_ps(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH,
+                                 param_4[1].data);
+                _mm512_storeu_ps(_doubled_buffer[_buf_index] + (i - t) + (SIMD_WIDTH << 1),
+                                 param_4[2].data);
+                _mm512_storeu_ps(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 3,
+                                 param_4[3].data);
+                _mm512_storeu_ps(_doubled_buffer[_buf_index] + (i - t) + (SIMD_WIDTH << 2),
+                                 param_4[4].data);
+                _mm512_storeu_ps(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 5,
+                                 param_4[5].data);
+                _mm512_storeu_ps(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 6,
+                                 param_4[6].data);
+                _mm512_storeu_ps(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 7,
+                                 param_4[7].data);
             }

             _mm512_storeu_ps(_exp_avg + i, momntum_4[0].data);
@@ -572,26 +601,20 @@ void Adam_Optimizer::Step_8(float* _params,
             _mm512_storeu_ps(_exp_avg_sq + i + SIMD_WIDTH * 7, varianc_4[7].data);
         }
         if (dev_params) {
-            /*launch_param_update(_doubled_buffer[_buf_index],
+            launch_param_update(_doubled_buffer[_buf_index],
                                 dev_params + t,
                                 copy_size,
                                 Context::Instance().GetCurrentStream());

-            _buf_index = !_buf_index;*/
-            CUDA_CHECK(cudaMemcpyAsync(dev_params + t,
-                                       _doubled_buffer[_buf_index],
-                                       copy_size * sizeof(__half),
-                                       cudaMemcpyHostToDevice,
-                                       Context::Instance().GetCurrentStream()));
             _buf_index = !_buf_index;
         }
     }
-    if(_param_size > rounded_size)
-        Step_4((_params + rounded_size),
-               (grads + rounded_size),
-               (_exp_avg + rounded_size),
-               (_exp_avg_sq + rounded_size),
-               (_param_size - rounded_size),
-               (dev_params != nullptr ? (dev_params + rounded_size) : dev_params));
+    if (_param_size > rounded_size)
+        Step_4((_params + rounded_size),
+               (grads + rounded_size),
+               (_exp_avg + rounded_size),
+               (_exp_avg_sq + rounded_size),
+               (_param_size - rounded_size),
+               (dev_params != nullptr ? (dev_params + rounded_size) : dev_params));
 }

 int ds_adam_step(int optimizer_id,
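Note on the central change in this file: each tile used to be cast to __half element by element on the CPU (the `data_f` loops) and then pushed with `cudaMemcpyAsync`; it is now stored as raw floats into the pinned double buffer and handed to `launch_param_update`, which performs the float-to-half conversion on the GPU. The kernel itself is declared in custom_cuda_layers.h and defined outside this diff, so the sketch below is only an assumption of what it plausibly looks like; the kernel name and launch geometry are illustrative, not taken from this patch.

```cuda
#include <cuda_fp16.h>
#include <cuda_runtime.h>

// Hypothetical sketch of launch_param_update: cast float -> __half on the GPU.
// The input is the pinned buffer from cudaMallocHost, which is directly
// addressable from device code under unified virtual addressing, so no
// separate host-to-device memcpy is required.
__global__ void param_update_kernel(const float* input, __half* output, int size)
{
    int id = blockIdx.x * blockDim.x + threadIdx.x;
    if (id < size) { output[id] = __float2half(input[id]); }
}

void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream)
{
    const int threads = 512;                      // assumed block size
    dim3 block_dim(threads);
    dim3 grid_dim((size + threads - 1) / threads);
    // Enqueued on the same stream the old cudaMemcpyAsync used, so ordering
    // relative to the rest of the training step is unchanged.
    param_update_kernel<<<grid_dim, block_dim, 0, stream>>>(input, output, size);
}
```

This is also what motivates the header change below: `_doubled_buffer` becomes a `float*` sized `TILE * sizeof(float)`, since the host now writes raw `_mm512_storeu_ps` results into it instead of pre-converted halves, and the `data_f` member of the AVX_512 union is no longer needed.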
diff --git a/csrc/includes/cpu_adam.h b/csrc/includes/cpu_adam.h
old mode 100755
new mode 100644
index f56eb1501aa4..b373162612f2
--- a/csrc/includes/cpu_adam.h
+++ b/csrc/includes/cpu_adam.h
@@ -2,12 +2,12 @@

 #include
 #include
+#include
 #include
 #include "context.h"
 #include "cublas_v2.h"
 #include "cuda.h"
 #include "curand.h"
-#include

 #define CUDA_CHECK(callstr)                                                    \
     {                                                                          \
@@ -38,8 +38,8 @@ class Adam_Optimizer {
           _betta2_t(1.0),
           _buf_index(false)
     {
-        cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(__half));
-        cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(__half));
+        cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
+        cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));
     }
     ~Adam_Optimizer()
     {
@@ -66,10 +66,9 @@ class Adam_Optimizer {
                 __half* dev_params = nullptr);

 private:
-
-    union AVX_512{
+    union AVX_512 {
         __m512 data;
-        float data_f[16];
+        // float data_f[16];
     };

     float _alpha;
@@ -81,6 +80,6 @@ class Adam_Optimizer {
     float _betta1_t;
     float _betta2_t;

-    __half* _doubled_buffer[2];
+    float* _doubled_buffer[2];
     bool _buf_index;
 };
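For reference, all three Step variants implement the same scalar Adam update, just unrolled over 1, 4, or 8 AVX-512 registers (16 floats each) per iteration. A minimal scalar sketch of that update, assuming `step_size = -alpha / bias_correction1` to match the `param = grad * step_size + param` form used above (this helper is illustrative and not part of the patch):

```cpp
#include <cmath>
#include <cstddef>

// Scalar equivalent of one Adam_Optimizer::Step over n parameters.
// Names mirror the members used in the diff; betta1_t/betta2_t are
// betta1^step and betta2^step.
void adam_step_scalar(float* params, const float* grads,
                      float* exp_avg, float* exp_avg_sq, size_t n,
                      float alpha, float betta1, float betta2,
                      float eps, float weight_decay,
                      float betta1_t, float betta2_t)
{
    float bias_correction1 = 1 - betta1_t;
    float bias_correction2 = 1 / sqrtf(1 - betta2_t);
    float step_size = -alpha / bias_correction1;  // assumed sign convention

    for (size_t k = 0; k < n; k++) {
        float grad = grads[k];
        if (weight_decay > 0) grad = params[k] * weight_decay + grad;

        // First and second moment estimates (momntum / varianc in the diff).
        float momntum = exp_avg[k] * betta1 + grad * (1 - betta1);
        float varianc = exp_avg_sq[k] * betta2 + grad * grad * (1 - betta2);

        // sqrt(v_hat) + eps, folded as sqrt(v) / sqrt(1 - betta2^step) + eps,
        // matching the _mm512_sqrt_ps / _mm512_fmadd_ps(bias2_sqrt, eps) pair.
        float denom = sqrtf(varianc) * bias_correction2 + eps;
        params[k] += step_size * (momntum / denom);

        exp_avg[k] = momntum;
        exp_avg_sq[k] = varianc;
    }
}
```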