Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
… performance_opt_v2
  • Loading branch information
zyfncg committed Apr 21, 2022
2 parents b150b30 + c51f55f commit e6fb189
Show file tree
Hide file tree
Showing 48 changed files with 1,117 additions and 534 deletions.
6 changes: 5 additions & 1 deletion cmake/cuda.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,11 @@ function(select_nvcc_arch_flags out_variable)
elseif(${CUDA_ARCH_NAME} STREQUAL "Turing")
set(cuda_arch_bin "75")
elseif(${CUDA_ARCH_NAME} STREQUAL "Ampere")
set(cuda_arch_bin "80")
if (${CMAKE_CUDA_COMPILER_VERSION} LESS 11.1) # CUDA 11.0
set(cuda_arch_bin "80")
elseif (${CMAKE_CUDA_COMPILER_VERSION} LESS 12.0) # CUDA 11.1+
set(cuda_arch_bin "80 86")
endif()
elseif(${CUDA_ARCH_NAME} STREQUAL "All")
set(cuda_arch_bin ${paddle_known_gpu_archs})
elseif(${CUDA_ARCH_NAME} STREQUAL "Auto")
Expand Down
37 changes: 33 additions & 4 deletions paddle/fluid/framework/fleet/heter_ps/hashtable.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ limitations under the License. */
#include "xpu/kernel/simd.h"
#endif

#if defined(PADDLE_WITH_XPU_KP)
#include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h"
#endif

namespace paddle {
namespace framework {

Expand All @@ -56,11 +60,10 @@ class TableContainer
capacity, ValType()) {}
};
#elif defined(PADDLE_WITH_XPU_KP)

template <typename KeyType, typename ValType>
class XPUCacheArray {
public:
explicit XPUCacheArray(size_t capacity) : capacity_(capacity), size_(0) {
explicit XPUCacheArray(long long capacity) : capacity_(capacity), size_(0) {
xpu_malloc(reinterpret_cast<void**>(&keys), capacity_ * sizeof(KeyType));
xpu_malloc(reinterpret_cast<void**>(&vals), capacity_ * sizeof(ValType));
}
Expand All @@ -71,8 +74,27 @@ class XPUCacheArray {
}

void print() {}
// ValType* find(const KeyType& key) { return NULL; }
// bool insert(const KeyType& key, const ValType& val) { return true; }

#if defined(__xpu__)
// Device-side lookup: linear scan of the flat key array; returns a pointer
// into vals for the first matching key, or NULL when absent. O(size_).
__device__ ValType* find(const KeyType& key) {
// NOTE(review): loop index is `int` while capacity is constructed as
// `long long` — confirm size_ can never exceed INT_MAX.
for (int i = 0; i < size_; i++) {
if (keys[i] == key) return &vals[i];
}
return NULL;
}
// Device-side append: stores (key, val) at the tail without checking for
// duplicate keys; returns false only when the array is already full.
// NOTE(review): `size_++` is not atomic — concurrent inserts from multiple
// XPU cores would race on size_; verify callers serialize insertion.
__device__ bool insert(const KeyType& key, const ValType& val) {
// # NOTE(zhangminxu): we set the capacity larger than the feasign number of
// one batch
if (size_ == capacity_) {
return false;
} else {
keys[size_] = key;
vals[size_] = val;
size_++;
return true;
}
}
#endif

int prefetch(const int dev_id, XPUStream stream = NULL) { return 0; }
size_t size() { return size_; }
Expand Down Expand Up @@ -110,6 +132,11 @@ class HashTable {

void show();

#if defined(PADDLE_WITH_XPU_KP)
void set_sparse_sgd(const OptimizerConfig& optimizer_config);
void set_embedx_sgd(const OptimizerConfig& optimizer_config);
#endif

template <typename StreamType>
void dump_to_cpu(int devid, StreamType stream);

Expand Down Expand Up @@ -151,6 +178,8 @@ class HashTable {
TableContainer<KeyType, ValType>* container_;
#elif defined(PADDLE_WITH_XPU_KP)
XPUCacheArray<KeyType, ValType>* container_;
OptimizerConfig* xpu_optimizer_config_;
OptimizerConfig cpu_optimizer_config_;
#endif
int BLOCK_SIZE_{256};
float LOAD_FACTOR{0.75f};
Expand Down
150 changes: 79 additions & 71 deletions paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.kps
Original file line number Diff line number Diff line change
Expand Up @@ -14,41 +14,21 @@ limitations under the License. */

#ifdef PADDLE_WITH_HETERPS
#include "paddle/fluid/framework/fleet/heter_ps/hashtable.h"

namespace optimizer_config {
extern _global_ptr_ float* nonclk_coeff;
extern _global_ptr_ float* clk_coeff;

extern _global_ptr_ float* min_bound;
extern _global_ptr_ float* max_bound;
extern _global_ptr_ float* learning_rate;
extern _global_ptr_ float* initial_g2sum;
extern _global_ptr_ float* initial_range;

extern _global_ptr_ float* mf_create_thresholds;
extern _global_ptr_ float* mf_learning_rate;
extern _global_ptr_ float* mf_initial_g2sum;
extern _global_ptr_ float* mf_initial_range;
extern _global_ptr_ float* mf_min_bound;
extern _global_ptr_ float* mf_max_bound;
}
#include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h"

namespace paddle {
namespace framework {

#if defined(PADDLE_WITH_XPU_KP)

__device__ void update_lr(float& w, float& g2sum, float g, // NOLINT
__device__ void update_lr(OptimizerConfig& optimizer_config, float& w,
float& g2sum,
float g, // NOLINT
float scale) {
__local__ float local_learning_rate;
__local__ float local_initial_g2sum;
__local__ float local_min_bound;
__local__ float local_max_bound;

GM2LM(optimizer_config::learning_rate, &local_learning_rate, sizeof(float));
GM2LM(optimizer_config::initial_g2sum, &local_initial_g2sum, sizeof(float));
GM2LM(optimizer_config::min_bound, &local_min_bound, sizeof(float));
GM2LM(optimizer_config::max_bound, &local_max_bound, sizeof(float));
float local_learning_rate = optimizer_config.learning_rate;
float local_initial_g2sum = optimizer_config.initial_g2sum;
float local_min_bound = optimizer_config.min_bound;
float local_max_bound = optimizer_config.max_bound;

double add_g2sum = 0;
double ratio = local_learning_rate *
Expand All @@ -65,19 +45,12 @@ __device__ void update_lr(float& w, float& g2sum, float g, // NOLINT
g2sum += add_g2sum;
}

__device__ void update_mf(int n, float* w, float& g2sum, const float* g,
float scale) {
__local__ float local_mf_learning_rate;
__local__ float local_mf_initial_g2sum;
__local__ float local_mf_min_bound;
__local__ float local_mf_max_bound;

GM2LM(optimizer_config::mf_learning_rate, &local_mf_learning_rate,
sizeof(float));
GM2LM(optimizer_config::mf_initial_g2sum, &local_mf_initial_g2sum,
sizeof(float));
GM2LM(optimizer_config::mf_min_bound, &local_mf_min_bound, sizeof(float));
GM2LM(optimizer_config::mf_max_bound, &local_mf_max_bound, sizeof(float));
__device__ void update_mf(OptimizerConfig& optimizer_config, int n, float* w,
float& g2sum, const float* g, float scale) {
float local_mf_learning_rate = optimizer_config.mf_learning_rate;
float local_mf_initial_g2sum = optimizer_config.mf_initial_g2sum;
float local_mf_min_bound = optimizer_config.mf_min_bound;
float local_mf_max_bound = optimizer_config.mf_max_bound;

double add_g2sum = 0;
double ratio =
Expand All @@ -98,26 +71,22 @@ __device__ void update_mf(int n, float* w, float& g2sum, const float* g,
__device__ float xpu_rand_uniform() { return 0.1; }

template <typename ValType, typename GradType>
__device__ void update_value(ValType& val, const GradType& grad) { // NOLINT
__device__ void update_value(OptimizerConfig& optimizer_config, ValType& val,
const GradType& grad) { // NOLINT
val.slot = grad.slot;
val.show += grad.show;
val.clk += grad.clk;

__local__ float local_nonclk_coeff;
__local__ float local_clk_coeff;
float local_nonclk_coeff = optimizer_config.nonclk_coeff;
float local_clk_coeff = optimizer_config.clk_coeff;

__local__ float local_mf_create_thresholds;
__local__ float local_mf_initial_range;

GM2LM(optimizer_config::nonclk_coeff, &local_nonclk_coeff, sizeof(float));
GM2LM(optimizer_config::clk_coeff, &local_clk_coeff, sizeof(float));
GM2LM(optimizer_config::mf_create_thresholds, &local_mf_create_thresholds,
sizeof(float));
float local_mf_create_thresholds = optimizer_config.mf_create_thresholds;
float local_mf_initial_range = optimizer_config.mf_initial_range;

val.delta_score +=
local_nonclk_coeff * (grad.show - grad.clk) + local_clk_coeff * grad.clk;

update_lr(val.lr, val.lr_g2sum, grad.lr_g, grad.show);
update_lr(optimizer_config, val.lr, val.lr_g2sum, grad.lr_g, grad.show);

if (val.mf_size == 0) {
if (local_mf_create_thresholds <=
Expand All @@ -130,12 +99,13 @@ __device__ void update_value(ValType& val, const GradType& grad) { // NOLINT
}
}
} else {
update_mf(MF_DIM, &val.mf[1], val.mf[0], grad.mf_g, grad.show);
update_mf(optimizer_config, MF_DIM, &val.mf[1], val.mf[0], grad.mf_g,
grad.show);
}
}

template <typename KeyType, typename ValType, typename Table>
__global__ void insert_kernel(Table* table, const KeyType* const keys,
__global__ void insert_kernel(Table& table, const KeyType* const keys,
const ValType* const vals, long long len) {
int cid = core_id();
int ncores = core_num();
Expand All @@ -156,14 +126,14 @@ __global__ void insert_kernel(Table* table, const KeyType* const keys,
GM2LM(keys, local_keys, read_len * sizeof(KeyType));
GM2LM(vals, local_vals, read_len * sizeof(ValType));
for (int k = 0; k < read_len; k++) {
// auto status = table->insert(local_keys[k], local_vals[k]);
// assert(status != false && "error: insert fails: table is full");
auto status = table.insert(local_keys[k], local_vals[k]);
assert(status != false && "error: insert fails: table is full");
}
}
}

template <typename KeyType, typename ValType, typename Table>
__global__ void search_kernel(Table* table, const KeyType* const keys,
__global__ void search_kernel(Table& table, const KeyType* const keys,
ValType* const vals, long long len) {
int cid = core_id();
int ncores = core_num();
Expand All @@ -183,17 +153,18 @@ __global__ void search_kernel(Table* table, const KeyType* const keys,
int read_len = min(len_per_loop, len - i);
GM2LM(keys, local_keys, read_len * sizeof(KeyType));
for (int k = 0; k < read_len; k++) {
// ValType* val = table->find(local_keys[k]);
// if (val != NULL) {
// local_vals[k] = *val;
// }
ValType* val = table.find(local_keys[k]);
if (val != NULL) {
local_vals[k] = *val;
}
}
LM2GM(local_vals, vals + i, read_len * sizeof(ValType));
}
}

template <typename KeyType, typename ValType, typename Table, typename GradType>
__global__ void update_kernel(Table* table, const KeyType* const keys,
__global__ void update_kernel(OptimizerConfig& optimizer_config, Table& table,
const KeyType* const keys,
const GradType* const grads, long long len) {
int cid = core_id();
int ncores = core_num();
Expand All @@ -216,10 +187,10 @@ __global__ void update_kernel(Table* table, const KeyType* const keys,
GM2LM(grads, local_grads, read_len * sizeof(GradType));

for (int k = 0; k < read_len; k++) {
// ValType* val = table->find(local_keys[k]);
// if (val != NULL) {
// update_value(*val, grads[i]);
//}
ValType* val = table.find(local_keys[k]);
if (val != NULL) {
update_value(optimizer_config, *val, local_grads[i]);
}
}
}
}
Expand All @@ -229,21 +200,58 @@ HashTable<KeyType, ValType>::HashTable(size_t capacity) {
auto tmp_container = XPUCacheArray<KeyType, ValType>(capacity);
xpu_malloc(reinterpret_cast<void**>(&container_),
sizeof(XPUCacheArray<KeyType, ValType>));
xpu_memcpy(container_, &tmp_container,
xpu_memcpy((void*)container_, &tmp_container,
sizeof(XPUCacheArray<KeyType, ValType>), XPU_HOST_TO_DEVICE);

OptimizerConfig tmp_opt_config;
xpu_malloc(reinterpret_cast<void**>(&xpu_optimizer_config_),
sizeof(OptimizerConfig));

xpu_memcpy((void*)xpu_optimizer_config_, &tmp_opt_config,
sizeof(OptimizerConfig), XPU_HOST_TO_DEVICE);

rwlock_.reset(new phi::RWLock);
}

template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::~HashTable() {
// Release the device-side container and the device-resident optimizer
// config; both were xpu_malloc'd in the constructor.
xpu_free((void*)container_);
xpu_free((void*)xpu_optimizer_config_);
}

template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::show() {
// Delegates to XPUCacheArray::print(), whose body is currently empty.
// NOTE(review): container_ points at xpu_malloc'd device memory but is
// dereferenced here on the host — harmless only while print() stays a
// no-op that touches no members; confirm before adding real printing.
container_->print();
}

template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::set_sparse_sgd(
    const OptimizerConfig& optimizer_config) {
  // Copy only the sparse-SGD hyperparameters into the host-side shadow
  // config (the embedx mf_* fields are left untouched), then mirror the
  // whole struct into the device-resident copy read by the XPU kernels.
  const OptimizerConfig& src = optimizer_config;
  OptimizerConfig& dst = cpu_optimizer_config_;
  dst.nonclk_coeff = src.nonclk_coeff;
  dst.clk_coeff = src.clk_coeff;
  dst.min_bound = src.min_bound;
  dst.max_bound = src.max_bound;
  dst.learning_rate = src.learning_rate;
  dst.initial_g2sum = src.initial_g2sum;
  dst.initial_range = src.initial_range;
  xpu_memcpy(reinterpret_cast<void*>(xpu_optimizer_config_), &dst,
             sizeof(OptimizerConfig), XPU_HOST_TO_DEVICE);
}

template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::set_embedx_sgd(
    const OptimizerConfig& optimizer_config) {
  // Update only the embedding (mf_*) hyperparameters on the host-side
  // shadow config, then push the full struct to device memory so the XPU
  // update kernels observe the new values on their next launch.
  const OptimizerConfig& src = optimizer_config;
  OptimizerConfig& dst = cpu_optimizer_config_;
  dst.mf_create_thresholds = src.mf_create_thresholds;
  dst.mf_learning_rate = src.mf_learning_rate;
  dst.mf_initial_g2sum = src.mf_initial_g2sum;
  dst.mf_initial_range = src.mf_initial_range;
  dst.mf_min_bound = src.mf_min_bound;
  dst.mf_max_bound = src.mf_max_bound;
  xpu_memcpy(reinterpret_cast<void*>(xpu_optimizer_config_), &dst,
             sizeof(OptimizerConfig), XPU_HOST_TO_DEVICE);
}

template <typename KeyType, typename ValType>
template <typename StreamType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys, ValType* d_vals,
Expand All @@ -254,7 +262,7 @@ void HashTable<KeyType, ValType>::get(const KeyType* d_keys, ValType* d_vals,
long long c_len = (long long)len;
search_kernel<KeyType, ValType,
XPUCacheArray<KeyType, ValType>><<<4, 64, stream>>>(
container_, d_keys, d_vals, c_len);
*container_, d_keys, d_vals, c_len);
}

template <typename KeyType, typename ValType>
Expand All @@ -278,7 +286,7 @@ void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
long long c_len = (long long)len;
insert_kernel<KeyType, ValType,
XPUCacheArray<KeyType, ValType>><<<4, 64, stream>>>(
container_, d_keys, d_vals, c_len);
*container_, d_keys, d_vals, c_len);
}

template <typename KeyType, typename ValType>
Expand All @@ -297,8 +305,8 @@ void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
}
long long c_len = (long long)len;
update_kernel<KeyType, ValType, XPUCacheArray<KeyType, ValType>,
GradType><<<4, 64, stream>>>(container_, d_keys, d_grads,
c_len);
GradType><<<4, 64, stream>>>(
*xpu_optimizer_config_, *container_, d_keys, d_grads, c_len);
}

template <typename KeyType, typename ValType>
Expand Down
6 changes: 6 additions & 0 deletions paddle/fluid/framework/fleet/heter_ps/heter_comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/platform/dynload/nccl.h"
#include "thrust/pair.h"
#elif defined(PADDLE_WITH_XPU_KP)
// #include "paddle/fluid/framework/fleet/heter_ps/optimizer_conf.h"
#include <xpu/runtime.h>
#include "paddle/fluid/platform/device/xpu/enforce_xpu.h"
#endif
Expand Down Expand Up @@ -64,6 +65,11 @@ class HeterComm {
void push_sparse(int num, KeyType* d_keys, GradType* d_grads, size_t len);
#endif

#if defined(PADDLE_WITH_XPU_KP)
void set_sparse_sgd(const OptimizerConfig& optimizer_config);
void set_embedx_sgd(const OptimizerConfig& optimizer_config);
#endif

int log2i(int x);

template <typename DstPlace, typename SrcPlace, typename StreamType>
Expand Down
Loading

1 comment on commit e6fb189

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulations! Your pull request passed all required CI checks. You can now ask the reviewer(s) to approve and merge. 🎉

Please sign in to comment.