support async update (PaddlePaddle#37)
hutuxian authored Aug 6, 2020
1 parent 321c0a3 commit 17ec1a3
Showing 7 changed files with 282 additions and 2 deletions.
31 changes: 30 additions & 1 deletion paddle/fluid/framework/device_worker.h
@@ -378,11 +378,29 @@ class SectionWorker : public DeviceWorker {
void SetNextSectionPlace(const paddle::platform::Place& place) {
next_section_place_ = place;
}

SyncFunctor* sync_func_ = nullptr;
void SetSyncFunctor(SyncFunctor* sync_func) { sync_func_ = sync_func; }

static std::atomic<int> cpu_id_;

// Async update plumbing, wired up by PipelineTrainer before the section
// threads start
void SetAsyncParamName(std::vector<std::string>* async_param_list) {
async_param_list_ = async_param_list;
}
void SetAsyncParamSize(std::vector<size_t>* async_param_size) {
async_param_size_ = async_param_size;
}
void SetPsBuffer(
operators::reader::BlockingQueue<std::vector<LoDTensor>*>* ps_buffer) {
ps_buffer_ = ps_buffer;
}
void SetPs(std::vector<LoDTensor>* ps, RWLock* ps_lock) {
ps_ = ps;
ps_lock_ = ps_lock;
}
void SetAsyncMode(bool async_mode) { async_mode_ = async_mode; }

protected:
void AutoSetCPUAffinity(bool reuse);
int section_id_;
@@ -401,8 +419,19 @@ class SectionWorker : public DeviceWorker {
paddle::platform::Place next_section_place_;

std::vector<std::unique_ptr<OperatorBase>> ops_;

platform::DeviceContext* dev_ctx_ = nullptr;

// async
void PullDense(const Scope& scope);
void PushDense(const Scope& scope);
std::vector<std::string>* async_param_list_ = nullptr;
std::vector<size_t>* async_param_size_ = nullptr;
std::vector<LoDTensor> grad_;
operators::reader::BlockingQueue<std::vector<LoDTensor>*>* ps_buffer_ =
nullptr;
RWLock* ps_lock_ = nullptr;
std::vector<LoDTensor>* ps_ = nullptr;
bool async_mode_ = false;
};
#endif
} // namespace framework
13 changes: 12 additions & 1 deletion paddle/fluid/framework/fleet/box_wrapper.h
@@ -235,7 +235,7 @@ class BoxWrapper {
void InitializeGPUAndLoadModel(
const char* conf_file, const std::vector<int>& slot_vector,
const std::vector<std::string>& slot_omit_in_feedpass,
const std::string& model_path) {
const std::string& model_path, const std::map<std::string, float>& lr_map) {
if (nullptr != s_instance_) {
VLOG(3) << "Begin InitializeGPU";
std::vector<cudaStream_t*> stream_list;
@@ -260,6 +260,15 @@ class BoxWrapper {
}
slot_vector_ = slot_vector;
device_caches_ = new DeviceBoxData[gpu_num];

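// lr_map arrives keyed by the bare parameter name; expand each entry to the
// concrete weight/bias variable names ("<name>.w_0", "<name>.b_0") so later
// lookups can use scope variable names directly.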
VLOG(0) << "lr_map.size(): " << lr_map.size();
for (const auto& e : lr_map) {
VLOG(0) << e.first << "'s lr is " << e.second;
if (e.first.find("param") != std::string::npos) {
lr_map_[e.first + ".w_0"] = e.second;
lr_map_[e.first + ".b_0"] = e.second;
}
}
}
}

@@ -680,6 +689,7 @@ class BoxWrapper {
}
int Phase() const { return phase_; }
void FlipPhase() { phase_ = (phase_ + 1) % phase_num_; }
const std::map<std::string, float>& GetLRMap() const { return lr_map_; }
std::map<std::string, MetricMsg*>& GetMetricList() { return metric_lists_; }

void InitMetric(const std::string& method, const std::string& name,
@@ -761,6 +771,7 @@ class BoxWrapper {
std::shared_ptr<boxps::PaddleFileMgr> file_manager_ = nullptr;
// box device cache
DeviceBoxData* device_caches_ = nullptr;
std::map<std::string, float> lr_map_;

public:
static std::shared_ptr<boxps::PaddleShuffler> data_shuffle_;
166 changes: 166 additions & 0 deletions paddle/fluid/framework/pipeline_trainer.cc
@@ -15,12 +15,82 @@
#if defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/fleet/box_wrapper.h"
#include "paddle/fluid/framework/trainer.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"

namespace paddle {
namespace framework {

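// Background dense-update loop (started from Run() when async_mode_ is on):
// drain gradient packages that SectionWorker::PushDense sends into ps_buffer_,
// average up to 4 pending packages, then apply an Adam-style update to the
// CPU-side parameter copy ps_ while holding the write lock.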
void PipelineTrainer::AsyncUpdate() {
#ifdef PADDLE_WITH_BOX_PS
VLOG(0) << "Begin AsyncUpdate";
std::vector<std::vector<LoDTensor>*> grad(4, nullptr); // merge at most 4 packages

auto box_ptr = BoxWrapper::GetInstance();
std::map<std::string, float> lr_map = box_ptr->GetLRMap();

while (ps_buffer_->Receive(&grad[0])) {
size_t merge_num = ps_buffer_->Size() + 1;
if (merge_num > 4) {
merge_num = 4;
}
for (size_t i = 1; i < merge_num; ++i) {
ps_buffer_->Receive(&grad[i]);
}
AutoWRLock ps_lock(&ps_lock_);
// VLOG(0) << "AsyncUpdate recevie grads, and begin to update param, merge "
// << merge_num;
for (size_t i = 0; i < async_param_list_.size() / 3; ++i) {
LoDTensor* param_tensor = &ps_[i * 3];
LoDTensor* mom1_tensor = &ps_[i * 3 + 1];
LoDTensor* mom2_tensor = &ps_[i * 3 + 2];
LoDTensor* grad_tensor = &(*grad[0])[i];
auto len = async_param_size_[i * 3];
float* grad_data = grad_tensor->mutable_data<float>(platform::CPUPlace());
float* param_data =
param_tensor->mutable_data<float>(platform::CPUPlace());
float* mom1_data = mom1_tensor->mutable_data<float>(platform::CPUPlace());
float* mom2_data = mom2_tensor->mutable_data<float>(platform::CPUPlace());

// merge grad
for (size_t k = 1; k < merge_num; ++k) {
LoDTensor* other_grad_tensor = &(*grad[k])[i];
float* other_grad_data =
other_grad_tensor->mutable_data<float>(platform::CPUPlace());
for (size_t j = 0; j < len; ++j) {
grad_data[j] += other_grad_data[j];
}
}
if (merge_num > 1) {
for (size_t j = 0; j < len; ++j) {
grad_data[j] /= merge_num;
}
}
// float tmp = param_data[0];
float learning_rate = base_lr_;
if (lr_map.find(async_param_list_[i * 3]) != lr_map.end()) {
learning_rate = lr_map[async_param_list_[i * 3]];
}
// VLOG(0) << "learning rate for " << async_param_list_[i * 3] << " is "
// << learning_rate;
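// Adam-style update with hard-coded hyper-parameters and no bias correction:
//   mom1  = 0.99   * mom1 + 0.01   * grad          (beta1   = 0.99)
//   mom2  = 0.9999 * mom2 + 0.0001 * grad * grad   (beta2   = 0.9999)
//   param -= lr * mom1 / (sqrt(mom2) + 1e-8)       (epsilon = 1e-8)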
for (size_t j = 0; j < len; ++j) {
mom1_data[j] = 0.99 * mom1_data[j] +
0.01 * grad_data[j]; // magic beta and epsilon
mom2_data[j] =
0.9999 * mom2_data[j] + 0.0001 * grad_data[j] * grad_data[j];
param_data[j] -=
learning_rate * (mom1_data[j] / (sqrt(mom2_data[j]) + 1e-8));
}
// VLOG(0) << "update dense for " << async_param_list_[i*3] << ", param["
// << tmp << "] - 0.000005 * [" << mom1_data[0] << "] / [" << mom1_data[1]
// << "] = [" << param_data[0] << "]";
}
}
VLOG(0) << "Quit AsyncUpdate";
#endif
}

void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
Dataset* dataset) {
pipeline_num_ = trainer_desc.thread_num();
@@ -37,6 +107,7 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
scope_queue_size_ = pipeline_config_.queue_size();
sync_steps_ = pipeline_config_.sync_steps();
section_num_ = pipeline_config_.section_config_size();
async_mode_ = pipeline_config_.async_mode();

VLOG(3) << "scope_queue_size: " << scope_queue_size_;
VLOG(3) << "section num: " << section_num_;
@@ -222,6 +293,76 @@ void PipelineTrainer::InitTrainerEnv(const ProgramDesc& main_program,
}
}

if (async_mode_) {
VLOG(0) << "Begin Init For Aysnc Optimize";
// Variable *v = root_scope_->Var("whole_param");
// VLOG(0) << "create whole_param tensor done";
// std::vector<std::string> input_names{"h4_param.w_0", "h5_param.w_0"};
// std::string fused_var_name = "whole_param";
// std::vector<std::string> output_names{"g_h4_param.w_0",
// "g_h5_param.w_0"};
// paddle::framework::AttributeMap attrs;
// auto coa_op = framework::OpRegistry::CreateOp("coa_op", {{"Input",
// {input_names}}},
// {{"Output",
// {output_names}},
// {"FusedOutput",
// {fused_var_name}}},
// attrs);

// VLOG(0) << "create op done";
// coa_op->Run(*(root_scope_), platform::CPUPlace());

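// Track each chosen parameter as a contiguous triplet
// {param, param_moment1_0, param_moment2_0}; the sort below keeps every
// triplet adjacent, which the i * 3 indexing in AsyncUpdate, PullDense and
// PushDense relies on.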
for (const auto& e : *param_need_sync_) {
if (e.find("param") != std::string::npos &&
e.find("pow_acc") == std::string::npos) {
VLOG(0) << "async mode choose " << e << " to update";
async_param_list_.push_back(e);
async_param_list_.push_back(e + "_moment1_0");
async_param_list_.push_back(e + "_moment2_0");
}
}
ps_.resize(async_param_list_.size());
VLOG(0) << "async_param_list_.size(): " << async_param_list_.size();
std::sort(
async_param_list_.begin(),
async_param_list_
.end()); // xx_param.b_0, xx_param_moment1_0, xx_param_moment2_0
for (size_t i = 0; i < async_param_list_.size(); ++i) {
VLOG(0) << "begin to copy " << async_param_list_[i];
const LoDTensor& root_tensor =
root_scope_->FindVar(async_param_list_[i])->Get<LoDTensor>();
VLOG(0) << "its size is " << root_tensor.numel();
async_param_size_.push_back(root_tensor.numel());
ps_[i].mutable_data<float>({root_tensor.numel(), 1},
platform::CPUPlace());
TensorCopy(*static_cast<const Tensor*>(&root_tensor),
platform::CPUPlace(), static_cast<Tensor*>(&(ps_[i])));
}

// Copy global lr for async mode
for (const auto& e : persistable_vars_) {
if (e.find("learning_rate_") != std::string::npos) {
PADDLE_ENFORCE_LE(
base_lr_, 0,
platform::errors::PreconditionNotMet(
"lr has already been set; previous value: %f, current var is %s",
base_lr_, e.c_str()));
VLOG(0) << "begin to copy global learning rate: " << e;
const LoDTensor& root_tensor =
root_scope_->FindVar(e)->Get<LoDTensor>();
const float* gpu_lr = root_tensor.data<float>();
if (platform::is_gpu_place(root_tensor.place())) {
cudaMemcpy(&base_lr_, gpu_lr, sizeof(float), cudaMemcpyDeviceToHost);
} else {
base_lr_ = *gpu_lr;
}
}
}
VLOG(0) << "base lr is " << base_lr_;
ps_buffer_ = new operators::reader::BlockingQueue<std::vector<LoDTensor>*>(
8 * 3); // queue capacity; an empirical magic number
}
for (int i = 0; i < section_num_; ++i) {
for (int j = 0; j < pipeline_num_; ++j) {
for (size_t k = 0; k < workers_[i][j].size(); ++k) {
@@ -243,6 +384,14 @@ void PipelineTrainer::InitTrainerEnv(const ProgramDesc& main_program,
workers_[i + 1][j][0])
->place());
}
if (async_mode_) {
this_worker->SetAsyncParamName(&async_param_list_);
this_worker->SetAsyncParamSize(&async_param_size_);
this_worker->SetPsBuffer(ps_buffer_);
this_worker->SetAsyncMode(async_mode_);
VLOG(0) << "set ps buffer for card " << j;
this_worker->SetPs(&ps_, &ps_lock_);
}
}
}
}
@@ -300,12 +449,20 @@ void PipelineTrainer::Run() {
}
}
}
if (async_mode_) {
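// Launch the dense-update thread; it exits once Finalize() closes ps_buffer_.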
update_thread_ = new std::thread(&PipelineTrainer::AsyncUpdate, this);
}
}

void PipelineTrainer::Finalize() {
for (auto& th : section_threads_) {
th.join();
}
if (async_mode_) {
// must run after the section threads have joined; otherwise ps_buffer_
// would be closed while workers could still send to it
ps_buffer_->Close();
update_thread_->join();
}
if (need_dump_field_) {
FinalizeDumpEnv();
}
@@ -316,6 +473,15 @@ void PipelineTrainer::Finalize() {
pipeline_scopes_[0]->FindVar(var)->Get<LoDTensor>();
TensorCopySync(thread_tensor, platform::CPUPlace(), root_tensor);
}
if (async_mode_) {
for (size_t i = 0; i < async_param_list_.size(); ++i) {
VLOG(0) << "begin to copy back" << async_param_list_[i];
auto* root_tensor =
root_scope_->Var(async_param_list_[i])->GetMutable<LoDTensor>();
TensorCopySync(*static_cast<const Tensor*>(&ps_[i]), platform::CPUPlace(),
root_tensor);
}
}
root_scope_->DropKids();
}

60 changes: 60 additions & 0 deletions paddle/fluid/framework/section_worker.cc
@@ -155,6 +155,51 @@ void SectionWorker::AutoSetCPUAffinity(bool reuse) {
#endif
}

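// Copy the latest CPU-side parameters (ps_) into this card's scope under a
// read lock: several cards may pull concurrently, while AsyncUpdate's write
// lock keeps parameter updates exclusive.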
void SectionWorker::PullDense(const Scope& scope) {
// while (ps_buffer_->Size() != 0) { // Size() takes a lock, which may cost
//   ;                               // performance, and the wait hangs if
// }                                 // the lock is removed
AutoRDLock ps_lock(ps_lock_);
for (size_t i = 0; i < async_param_list_->size(); ++i) {
if (i % 3 != 0) {
continue;
}
const std::string& param_name = (*async_param_list_)[i];
Variable* var = scope.FindVar(param_name);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
TensorCopy(*static_cast<const Tensor*>(&(*ps_)[i]), place_,
static_cast<Tensor*>(tensor));

// float *p = (*ps_)[i].mutable_data<float>(platform::CPUPlace());
// VLOG(0) << "pull dense for " << (*async_param_list_)[i] << ", and the
// first ele is " << p[0];
}
VLOG(0) << "card[" << pipeline_id_ << "] pull dense done";
}

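// Copy each parameter's gradient ("<param>@GRAD") from the device into the
// pre-allocated buffers in grad_, then enqueue grad_ for AsyncUpdate to
// consume. Note grad_ is reused across mini-batches.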
void SectionWorker::PushDense(const Scope& scope) {
for (size_t i = 0; i < async_param_list_->size(); ++i) {
if (i % 3 != 0) {
continue;
}
// VLOG(0) << "push dense for " << (*async_param_list_)[i] << "@GRAD";
std::string grad_name = (*async_param_list_)[i] + "@GRAD";
Variable* var = scope.FindVar(grad_name);
CHECK(var != nullptr) << "var[" << grad_name << "] not found";
LoDTensor* tensor = var->GetMutable<LoDTensor>();
// For debugging: fetch the first gradient element back to the host. Note
// that this cudaMemcpy still runs even though the VLOG below is commented
// out.
float* g = tensor->mutable_data<float>(place_);
float tmp;
cudaMemcpy(&tmp, g, 1 * sizeof(float), cudaMemcpyDeviceToHost);
// VLOG(0) << "the first element of grad_name is: " << tmp;
TensorCopy(*static_cast<const Tensor*>(tensor), platform::CPUPlace(),
static_cast<Tensor*>(&grad_[i / 3]));
}
ps_buffer_->Send(&grad_);
VLOG(0) << "card[" << pipeline_id_ << "] push dense done";
}

void SectionWorker::TrainFiles() {
SEC_LOG << "begin section_worker TrainFiles";
AutoSetCPUAffinity(true);
@@ -166,6 +211,15 @@ void SectionWorker::TrainFiles() {
if (device_reader_ != nullptr) {
device_reader_->Start();
}
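// Pre-allocate one gradient buffer per parameter triplet; PushDense reuses
// these buffers every mini-batch.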
if (async_mode_) {
grad_.resize(async_param_size_->size() / 3);
for (size_t i = 0; i < async_param_size_->size(); ++i) {
if (i % 3 != 0) continue;
grad_[i / 3].mutable_data<float>(
{static_cast<int64_t>((*async_param_size_)[i]), 1}, place_);
}
}

while (in_scope_queue_->Receive(&scope)) {
if (device_reader_ != nullptr) {
device_reader_->AssignFeedVar(*scope);
@@ -213,9 +267,15 @@ void SectionWorker::TrainFiles() {

SEC_LOG << "begin running ops";

if (async_mode_) {
PullDense(*exe_scope);
}
for (auto& op : ops_) {
op->Run(*exe_scope, place_);
}
if (async_mode_) {
PushDense(*exe_scope);
}
exe_scope->DropKids();
// Wait for GPU calc finishing, as the cudaMemcpy and GPU calc may be in
// different streams
