
Commit

[GPUGraph] enable exiting the trainer when a worker has no instances (Paddle#75)

* add barrier

* train mode
Thunderbrook authored Jul 29, 2022
1 parent eed5f6b commit a0a6e81
Showing 8 changed files with 117 additions and 6 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/framework/CMakeLists.txt
@@ -335,7 +335,7 @@ if(WITH_DISTRIBUTE)
index_sampler index_wrapper sampler index_dataset_proto
lod_rank_table fs shell fleet_wrapper heter_wrapper box_wrapper metrics lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper timer monitor
heter_service_proto fleet heter_server brpc fleet_executor)
heter_service_proto fleet heter_server brpc fleet_executor flags)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=parentheses")
if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
set(DISTRIBUTE_COMPILE_FLAGS
44 changes: 44 additions & 0 deletions paddle/fluid/framework/barrier.h
@@ -0,0 +1,44 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <pthread.h>

#include "glog/logging.h"                  // CHECK / PCHECK
#include "paddle/fluid/platform/macros.h"  // DISABLE_COPY_AND_ASSIGN

namespace paddle {
namespace framework {

// Reusable pthread-based barrier for synchronizing trainer threads.
class Barrier {
 public:
  explicit Barrier(int count = 1) {
    CHECK(count >= 1);
    PCHECK(0 == pthread_barrier_init(&_barrier, NULL, count));
  }
  ~Barrier() { PCHECK(0 == pthread_barrier_destroy(&_barrier)); }
  // Re-initialize the barrier for a new participant count.
  void reset(int count) {
    CHECK(count >= 1);
    PCHECK(0 == pthread_barrier_destroy(&_barrier));
    PCHECK(0 == pthread_barrier_init(&_barrier, NULL, count));
  }
  // Block until all participants have reached the barrier.
  void wait() {
    int err = pthread_barrier_wait(&_barrier);
    PCHECK(err == 0 || err == PTHREAD_BARRIER_SERIAL_THREAD);
  }

 private:
  pthread_barrier_t _barrier;
  DISABLE_COPY_AND_ASSIGN(Barrier);
};

}  // namespace framework
}  // namespace paddle
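The class above is a thin wrapper around pthread_barrier_t. As a usage illustration only (the thread count and worker body below are hypothetical, not part of this commit), the intended rendezvous pattern looks roughly like this:

// Sketch: N threads synchronize at a barrier before and after a work phase.
#include <pthread.h>

#include <cstdio>
#include <thread>
#include <vector>

int main() {
  const int kThreads = 4;
  pthread_barrier_t barrier;
  pthread_barrier_init(&barrier, nullptr, kThreads);

  std::vector<std::thread> workers;
  for (int i = 0; i < kThreads; ++i) {
    workers.emplace_back([&barrier, i] {
      pthread_barrier_wait(&barrier);  // all threads start the phase together
      std::printf("worker %d running\n", i);
      pthread_barrier_wait(&barrier);  // all threads finish the phase together
    });
  }
  for (auto& t : workers) t.join();
  pthread_barrier_destroy(&barrier);
  return 0;
}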
5 changes: 5 additions & 0 deletions paddle/fluid/framework/data_feed.cc
@@ -2066,6 +2066,11 @@ void SlotRecordInMemoryDataFeed::Init(const DataFeedDesc& data_feed_desc) {
    so_parser_name_.clear();
  }
  gpu_graph_data_generator_.SetConfig(data_feed_desc);
  if (gpu_graph_mode_) {
    train_mode_ = true;
  } else {
    train_mode_ = data_feed_desc.graph_config().gpu_graph_training();
  }
}

void SlotRecordInMemoryDataFeed::LoadIntoMemory() {
2 changes: 2 additions & 0 deletions paddle/fluid/framework/data_feed.h
@@ -1047,6 +1047,7 @@ class DataFeed {
    return ins_content_vec_;
  }
  virtual int GetCurBatchSize() { return batch_size_; }
  virtual bool IsTrainMode() { return train_mode_; }
  virtual void LoadIntoMemory() {
    PADDLE_THROW(platform::errors::Unimplemented(
        "This function(LoadIntoMemory) is not implemented."));
@@ -1119,6 +1120,7 @@ class DataFeed {
  int input_type_;
  int gpu_graph_mode_ = 0;
  GraphDataGenerator gpu_graph_data_generator_;
  bool train_mode_;
};

// PrivateQueueDataFeed is the base virtual class for other DataFeeds.
8 changes: 6 additions & 2 deletions paddle/fluid/framework/device_worker.h
@@ -32,6 +32,7 @@ limitations under the License. */
#endif

#include <map>
#include "paddle/fluid/framework/barrier.h"
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/heter_util.h"
@@ -208,6 +209,9 @@ class DeviceWorker {
  virtual void SetDeviceContext(platform::DeviceContext* dev_ctx) {
    dev_ctx_ = dev_ctx;
  }
  virtual void SetThreadNum(int thread_num) {
    thread_num_ = thread_num;
  }
  virtual Scope* GetThreadScope() { return thread_scope_; }
  DataFeed* device_reader_ = nullptr;

@@ -237,6 +241,7 @@ class DeviceWorker {
  ChannelWriter<std::string> writer_;
  const size_t tensor_iterator_thread_num = 16;
  platform::DeviceContext* dev_ctx_ = nullptr;
  int thread_num_;
};

class CPUWorkerBase : public DeviceWorker {
@@ -282,6 +287,7 @@ class HogwildWorker : public CPUWorkerBase {
  HogwildWorkerParameter param_;
  std::vector<std::string> skip_ops_;
  std::map<std::string, int> stat_var_name_map_;
  static std::atomic<uint64_t> worker_num_stat_;
};

class DownpourWorker : public HogwildWorker {
@@ -713,7 +719,6 @@ class HeterSectionWorker : public DeviceWorker {
  const platform::Place& place() const { return place_; }

  void SetDeviceIndex(int tid) override { thread_id_ = tid; }
  void SetThreadNum(int thread_num) { thread_num_ = thread_num; }
  void SetMicrobatchNum(int num) { num_microbatches_ = num; }
  void SetPipelineStageNum(int num) { num_pipeline_stages_ = num; }
  void SetPipelineStage(int stage) { pipeline_stage_ = stage; }
@@ -755,7 +760,6 @@
 protected:
  int trainer_id_;
  int trainers_;
  int thread_num_;
  int thread_id_;
  int num_microbatches_;
  int num_pipeline_stages_;
45 changes: 43 additions & 2 deletions paddle/fluid/framework/hogwild_worker.cc
@@ -25,9 +25,13 @@ limitations under the License. */
#include "paddle/fluid/distributed/ps/service/communicator/communicator.h"
#endif

DECLARE_bool(enable_exit_when_partial_worker);
namespace paddle {
namespace framework {

std::atomic<uint64_t> HogwildWorker::worker_num_stat_(0);
Barrier g_barrier;

void HogwildWorker::Initialize(const TrainerDesc &desc) {
  fetch_config_ = desc.fetch_config();
  param_ = desc.hogwild_param();
@@ -138,9 +142,27 @@ void HogwildWorker::TrainFilesWithProfiler() {
  double read_time = 0.0;
  int cur_batch;
  int batch_cnt = 0;
  if (thread_id_ == 0) {
    worker_num_stat_.store(0);
  }
  g_barrier.wait();
  bool train_mode = device_reader_->IsTrainMode();
  timeline.Start();
  uint64_t total_inst = 0;
while ((cur_batch = device_reader_->Next()) > 0) {
  while (1) {
    cur_batch = device_reader_->Next();
    if (FLAGS_enable_exit_when_partial_worker && train_mode) {
      if (cur_batch > 0) {
        worker_num_stat_.fetch_add(1, std::memory_order_relaxed);
      }
      // Rendezvous with every other worker thread; if any of them failed to
      // fetch a batch this round, the running count is no longer a multiple
      // of thread_num_ and all threads leave the loop together.
      g_barrier.wait();
      if (worker_num_stat_.load(std::memory_order_relaxed) % thread_num_ != 0) {
        break;
      }
    }
    if (cur_batch <= 0) {
      break;
    }
    VLOG(3) << "read a batch in thread " << thread_id_;
    timeline.Pause();
    read_time += timeline.ElapsedSec();
@@ -230,11 +252,30 @@ void HogwildWorker::TrainFiles() {
  device_reader_->Start();
  int cur_batch;
  int batch_cnt = 0;
  if (thread_id_ == 0) {
    worker_num_stat_.store(0);
  }
  g_barrier.wait();

#if defined(PADDLE_WITH_HETERPS) && defined(PADDLE_WITH_CUDA)
  platform::SetDeviceId(thread_id_);
#endif
while ((cur_batch = device_reader_->Next()) > 0) {
//while ((cur_batch = device_reader_->Next()) > 0) {
  bool train_mode = device_reader_->IsTrainMode();
  while (1) {
    cur_batch = device_reader_->Next();
    if (FLAGS_enable_exit_when_partial_worker && train_mode) {
      if (cur_batch > 0) {
        worker_num_stat_.fetch_add(1, std::memory_order_relaxed);
      }
      g_barrier.wait();
      if (worker_num_stat_.load(std::memory_order_relaxed) % thread_num_ != 0) {
        break;
      }
    }
    if (cur_batch <= 0) {
      break;
    }
    for (auto &op : ops_) {
      bool need_skip = false;
      for (auto t = 0u; t < skip_ops_.size(); ++t) {
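The two loops above implement the commit's early-exit protocol: in every round, each worker that fetched a batch bumps worker_num_stat_, all workers meet at g_barrier, and if the accumulated count is no longer a multiple of thread_num_ then at least one worker ran out of instances and every worker leaves the loop in the same iteration. Below is a simplified, self-contained model of that protocol; the per-thread batch counts are made up, and a second rendezvous is added purely so one round's check cannot overlap with the next round's increments.

#include <pthread.h>

#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

int main() {
  const int kThreads = 3;
  const int kBatches[kThreads] = {5, 5, 3};  // thread 2 runs out of data first

  std::atomic<uint64_t> num_with_batch{0};
  pthread_barrier_t barrier;
  pthread_barrier_init(&barrier, nullptr, kThreads);

  auto worker = [&](int tid) {
    int consumed = 0;
    while (true) {
      bool got_batch = consumed < kBatches[tid];  // stand-in for Next() > 0
      if (got_batch) {
        num_with_batch.fetch_add(1, std::memory_order_relaxed);
      }
      pthread_barrier_wait(&barrier);  // all increments for this round are in
      uint64_t total = num_with_batch.load(std::memory_order_relaxed);
      pthread_barrier_wait(&barrier);  // everyone has read before the next round
      if (total % kThreads != 0) break;  // some worker ran dry: exit together
      if (!got_batch) break;             // every worker ran dry in the same round
      ++consumed;
    }
    std::printf("thread %d stopped after %d batches\n", tid, consumed);
  };

  std::vector<std::thread> threads;
  for (int i = 0; i < kThreads; ++i) threads.emplace_back(worker, i);
  for (auto& t : threads) t.join();
  pthread_barrier_destroy(&barrier);
  return 0;
}

With the counts above, all three threads stop after the third batch even though two of them still have data, which is exactly the behavior the flag opts into.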
5 changes: 4 additions & 1 deletion paddle/fluid/framework/multi_trainer.cc
@@ -25,6 +25,8 @@ limitations under the License. */
namespace paddle {
namespace framework {

extern Barrier g_barrier;

void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
                              Dataset* dataset) {
  thread_num_ = trainer_desc.thread_num();
@@ -61,7 +63,7 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
        thread_num_);
  }
#endif

  g_barrier.reset(thread_num_);
  for (int i = 0; i < thread_num_; ++i) {
    workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
        trainer_desc.device_worker_name());
@@ -73,6 +75,7 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
    workers_[i]->Initialize(trainer_desc);
    workers_[i]->SetDeviceIndex(i);
    workers_[i]->SetDataFeed(readers[i]);
    workers_[i]->SetThreadNum(thread_num_);
  }

// set debug here
12 changes: 12 additions & 0 deletions paddle/fluid/platform/flags.cc
@@ -752,6 +752,18 @@ PADDLE_DEFINE_EXPORTED_bool(
graph_get_neighbor_id, false,
"It controls get all neighbor id when running sub part graph.");

/**
 * Distributed related FLAG
 * Name: enable_exit_when_partial_worker
 * Since Version: 2.2.0
 * Value Range: bool, default=false
 * Example:
 * Note: Controls whether the trainer exits when a worker has no instances.
 *       If not set, the trainer exits only after all workers finish training.
 */
PADDLE_DEFINE_EXPORTED_bool(
    enable_exit_when_partial_worker, false,
    "It controls whether to exit the trainer when a worker has no instances.");
/**
* KP kernel related FLAG
* Name: FLAGS_run_kp_kernel
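For completeness, a hedged sketch of how the new flag would typically be consumed or toggled from other C++ code. This assumes PADDLE_DEFINE_EXPORTED_bool ultimately defines an ordinary gflags boolean, which is how hogwild_worker.cc above reaches it via DECLARE_bool; the helper function below is illustrative and not part of the commit.

#include "gflags/gflags.h"

// Matching declaration for the flag defined in paddle/fluid/platform/flags.cc,
// mirroring DECLARE_bool(enable_exit_when_partial_worker) in hogwild_worker.cc.
DECLARE_bool(enable_exit_when_partial_worker);

// Illustrative helper: opt in to the early-exit behavior before the trainer
// threads are launched.
void EnableExitWhenPartialWorker() {
  FLAGS_enable_exit_when_partial_worker = true;
}

In practice, Paddle's exported flags can usually also be switched on through an environment variable of the same name (FLAGS_enable_exit_when_partial_worker=1) before launching training, without touching C++ code.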
