Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[HeterPs]ps gpu dump #36157

Merged
merged 2 commits into the base branch from the contributor's branch
Sep 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions paddle/fluid/framework/device_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,6 @@ class PSGPUWorker : public HogwildWorker {
virtual void Initialize(const TrainerDesc& desc);
virtual void TrainFiles();
virtual void TrainFilesWithProfiler();
virtual void SetNeedDump(bool need_dump_field);
virtual void SetChannelWriter(ChannelObject<std::string>* queue);
virtual void SetWorkerNum(int num) { worker_num_ = num; }
virtual void CacheProgram(const ProgramDesc& main_program) {
Expand All @@ -467,26 +466,19 @@ class PSGPUWorker : public HogwildWorker {

protected:
void PushGradients();
void DumpParam();
void CopySparseTable();
void CopyDenseTable();
void CopyDenseVars();

private:
int mpi_rank_;
std::mutex mutex_;
std::vector<std::string> send_var_list_;
int worker_num_;
ProgramDesc program_;
HeterObjectPool<HeterTask> object_pool_;
bool need_dump_param_;
std::vector<std::string> dump_param_;
bool need_to_push_dense_;
bool need_dump_field_;
bool dump_slot_;
bool need_to_push_sparse_;
std::vector<std::string> dump_fields_;
ChannelWriter<std::string> writer_;
DownpourWorkerParameter param_;
float scale_datanorm_;
// just save the value in param_ for easy access
Expand Down
45 changes: 43 additions & 2 deletions paddle/fluid/framework/ps_gpu_trainer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,12 @@ namespace framework {

void PSGPUTrainer::Initialize(const TrainerDesc& trainer_desc,
Dataset* dataset) {
dataset_ = dataset;
SetDataset(dataset);
thread_num_ = trainer_desc.thread_num();
param_ = trainer_desc.downpour_param();
ParseDumpConfig(trainer_desc);
mpi_rank_ = trainer_desc.mpi_rank();
mpi_size_ = trainer_desc.mpi_size();
for (int i = 0; i < param_.dense_table_size(); ++i) {
uint64_t table_id = static_cast<uint64_t>(param_.dense_table(i).table_id());
auto table = param_.dense_table(i);
Expand All @@ -44,6 +47,8 @@ void PSGPUTrainer::Initialize(const TrainerDesc& trainer_desc,
int place_num = trainer_desc.worker_places_size();
const std::vector<paddle::framework::DataFeed*> readers =
dataset->GetReaders();
dump_file_num_ = trainer_desc.dump_file_num();
user_define_dump_filename_ = trainer_desc.user_define_dump_filename();
std::vector<int> dev_ids;
for (int i = 0; i < place_num; ++i) {
int num = trainer_desc.worker_places(i);
Expand All @@ -64,14 +69,26 @@ void PSGPUTrainer::Initialize(const TrainerDesc& trainer_desc,
workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
trainer_desc.device_worker_name());
workers_[i]->SetDeviceIndex(i);
workers_[i]->SetNeedDumpField(need_dump_field_);
workers_[i]->SetNeedDumpParam(need_dump_param_);
workers_[i]->SetDumpFieldVector(dump_fields_);
workers_[i]->SetDumpParamVector(dump_param_);
workers_[i]->InitRandomDumpConfig(trainer_desc);
workers_[i]->SetDataFeed(readers[i]);
workers_[i]->Initialize(trainer_desc);
workers_[i]->SetWorkerNum(place_num);
}
return;
}

// Intentional no-op stub -- NOTE(review): the dump threads in InitDumpEnv are
// bound to TrainerBase::DumpWork, so the real dump loop presumably lives in
// the base class; confirm this override is meant to stay empty.
void PSGPUTrainer::DumpWork(int tid) {}
// Builds the output path for dump thread `tid`.
// If the user supplied an explicit dump file name, the path is
//   "<dump_fields_path_>/part-<user_name>-<tid:05d>";
// otherwise the MPI rank is embedded instead of the user name so that files
// produced by different ranks do not collide:
//   "<dump_fields_path_>/part-<mpi_rank_:03d>-<tid:05d>".
std::string PSGPUTrainer::GetDumpPath(int tid) {
  // .empty() instead of comparison with "": no temporary string is built.
  if (!user_define_dump_filename_.empty()) {
    return string::format_string("%s/part-%s-%05d", dump_fields_path_.c_str(),
                                 user_define_dump_filename_.c_str(), tid);
  }
  return string::format_string("%s/part-%03d-%05d", dump_fields_path_.c_str(),
                               mpi_rank_, tid);
}

void PSGPUTrainer::RegisterHeterCallback() {
/*
Expand Down Expand Up @@ -124,7 +141,28 @@ void PSGPUTrainer::InitTrainerEnv(const ProgramDesc& main_program,
return;
}

void PSGPUTrainer::InitDumpEnv() {
queue_ = paddle::framework::MakeChannel<std::string>();
for (size_t i = 0; i < places_.size(); ++i) {
workers_[i]->SetChannelWriter(queue_.get());
}
dump_thread_num_ = 1;
if (dump_file_num_ > mpi_size_) {
dump_thread_num_ = dump_file_num_ / mpi_size_;
if (dump_file_num_ % mpi_size_ > mpi_rank_) {
dump_thread_num_ += 1;
}
}
for (int i = 0; i < dump_thread_num_; i++) {
dump_thread_.push_back(
std::thread(std::bind(&TrainerBase::DumpWork, this, i)));
}
}

// Sets up auxiliary training state; currently that is only the dump
// environment, which is needed as soon as either field or param dumping
// is requested.
void PSGPUTrainer::InitOtherEnv(const ProgramDesc& main_program) {
  const bool dump_enabled = need_dump_field_ || need_dump_param_;
  if (dump_enabled) {
    InitDumpEnv();
  }
  VLOG(3) << "init other env done.";
}

Expand Down Expand Up @@ -204,6 +242,9 @@ void PSGPUTrainer::Finalize() {
}
}
MergeDenseParam();
if (need_dump_field_ || need_dump_param_) {
FinalizeDumpEnv();
}
root_scope_->DropKids();
}
} // namespace framework
Expand Down
34 changes: 11 additions & 23 deletions paddle/fluid/framework/ps_gpu_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,6 @@ void PSGPUWorker::Initialize(const TrainerDesc& desc) {
dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
mpi_rank_ = desc.mpi_rank();
trainer_desc_ = desc;
/*
for (int i = 0; i < trainer_desc_.xpu_recv_list_size(); ++i) {
send_var_list_.push_back(trainer_desc_.xpu_recv_list(i));
}
*/
for (int i = 0; i < param_.sparse_table_size(); ++i) {
uint64_t table_id =
static_cast<uint64_t>(param_.sparse_table(i).table_id());
Expand Down Expand Up @@ -89,19 +84,7 @@ void PSGPUWorker::Initialize(const TrainerDesc& desc) {
no_cvm_ = desc.no_cvm();
scale_datanorm_ = desc.scale_datanorm();
dump_slot_ = desc.dump_slot();
dump_fields_.resize(desc.dump_fields_size());
for (int i = 0; i < desc.dump_fields_size(); ++i) {
dump_fields_[i] = desc.dump_fields(i);
}
adjust_ins_weight_config_ = desc.adjust_ins_weight_config();
need_dump_param_ = false;
dump_param_.resize(desc.dump_param_size());
for (int i = 0; i < desc.dump_param_size(); ++i) {
dump_param_[i] = desc.dump_param(i);
}
if (desc.dump_param_size() != 0) {
need_dump_param_ = true;
}
for (int i = 0; i < desc.check_nan_var_names_size(); ++i) {
check_nan_var_names_.push_back(desc.check_nan_var_names(i));
}
Expand Down Expand Up @@ -134,12 +117,6 @@ void PSGPUWorker::SetChannelWriter(ChannelObject<std::string>* queue) {
writer_.Reset(queue);
}

void PSGPUWorker::SetNeedDump(bool need_dump_field) {
need_dump_field_ = need_dump_field;
}

void PSGPUWorker::DumpParam() {}

void PSGPUWorker::TrainFiles() {
platform::SetNumThreads(1);
platform::Timer timeline;
Expand All @@ -150,6 +127,7 @@ void PSGPUWorker::TrainFiles() {
// how to accumulate fetched values here
device_reader_->Start();
int cur_batch;
int batch_cnt = 0;
while ((cur_batch = device_reader_->Next()) > 0) {
total_ins_num += cur_batch;
for (auto& op : ops_) {
Expand All @@ -164,9 +142,19 @@ void PSGPUWorker::TrainFiles() {
op->Run(*thread_scope_, place_);
}
}
if (need_dump_field_) {
DumpField(*thread_scope_, dump_mode_, dump_interval_);
}
if (need_dump_param_ && thread_id_ == 0) {
DumpParam(*thread_scope_, batch_cnt);
}

PrintFetchVars();
thread_scope_->DropKids();
++batch_cnt;
}
if (need_dump_field_ || need_dump_param_) {
writer_.Flush();
}
timeline.Pause();
VLOG(1) << "GpuPs worker " << thread_id_ << " train cost "
Expand Down
8 changes: 5 additions & 3 deletions paddle/fluid/framework/trainer.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,13 +258,12 @@ class PSGPUTrainer : public TrainerBase {
virtual void Run();
virtual void Finalize();
virtual void RegisterHeterCallback();
virtual void DumpWork(int tid);
virtual Scope* GetWorkerScope(int thread_id);
virtual void CacheProgram(const ProgramDesc& main_program) {
new (&program_) ProgramDesc(main_program);
}
virtual std::string GetDumpPath(int tid) { return ""; }
virtual void InitDumpEnv() {}
virtual std::string GetDumpPath(int tid);
virtual void InitDumpEnv() override;
virtual void MergeDenseParam();

template <typename T>
Expand All @@ -286,6 +285,9 @@ class PSGPUTrainer : public TrainerBase {
std::vector<std::thread> threads_;
int use_ps_gpu_;
int thread_num_;
int mpi_rank_;
int mpi_size_;
int dump_file_num_;
};
#endif

Expand Down