diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 82727cea542d6..c4bb12c6f3636 100755 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -335,7 +335,7 @@ if(WITH_DISTRIBUTE) index_sampler index_wrapper sampler index_dataset_proto lod_rank_table fs shell fleet_wrapper heter_wrapper box_wrapper metrics lodtensor_printer feed_fetch_method graph_to_program_pass variable_helper timer monitor - heter_service_proto fleet heter_server brpc fleet_executor) + heter_service_proto fleet heter_server brpc fleet_executor flags) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=parentheses") if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0) set(DISTRIBUTE_COMPILE_FLAGS diff --git a/paddle/fluid/framework/barrier.h b/paddle/fluid/framework/barrier.h new file mode 100644 index 0000000000000..b83726c42469b --- /dev/null +++ b/paddle/fluid/framework/barrier.h @@ -0,0 +1,44 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace paddle { +namespace framework { + +class Barrier { +public: + explicit Barrier(int count = 1) { + CHECK(count >= 1); + PCHECK(0 == pthread_barrier_init(&_barrier, NULL, count)); + } + ~Barrier() { + PCHECK(0 == pthread_barrier_destroy(&_barrier)); + } + void reset(int count) { + CHECK(count >= 1); + PCHECK(0 == pthread_barrier_destroy(&_barrier)); + PCHECK(0 == pthread_barrier_init(&_barrier, NULL, count)); + } + void wait() { + int err = pthread_barrier_wait(&_barrier); + PCHECK((err = pthread_barrier_wait(&_barrier), err == 0 || err == PTHREAD_BARRIER_SERIAL_THREAD)); + } +private: + pthread_barrier_t _barrier; + DISABLE_COPY_AND_ASSIGN(Barrier); +}; + +} +} diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index 7962e1591f0fa..8b5ec8f1f1c4d 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -2066,6 +2066,11 @@ void SlotRecordInMemoryDataFeed::Init(const DataFeedDesc& data_feed_desc) { so_parser_name_.clear(); } gpu_graph_data_generator_.SetConfig(data_feed_desc); + if (gpu_graph_mode_) { + train_mode_ = true; + } else { + train_mode_ = data_feed_desc.graph_config().gpu_graph_training(); + } } void SlotRecordInMemoryDataFeed::LoadIntoMemory() { diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h index 7833e9760c476..02e85b69e0ee8 100644 --- a/paddle/fluid/framework/data_feed.h +++ b/paddle/fluid/framework/data_feed.h @@ -1047,6 +1047,7 @@ class DataFeed { return ins_content_vec_; } virtual int GetCurBatchSize() { return batch_size_; } + virtual bool IsTrainMode() { return train_mode_; } virtual void LoadIntoMemory() { PADDLE_THROW(platform::errors::Unimplemented( "This function(LoadIntoMemory) is not implemented.")); @@ -1119,6 +1120,7 @@ class DataFeed { int input_type_; int gpu_graph_mode_ = 0; GraphDataGenerator gpu_graph_data_generator_; + bool train_mode_; }; // PrivateQueueDataFeed is the base virtual class for ohther DataFeeds. diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h index 0946dd0536075..5d056cc2fb36a 100644 --- a/paddle/fluid/framework/device_worker.h +++ b/paddle/fluid/framework/device_worker.h @@ -32,6 +32,7 @@ limitations under the License. */ #endif #include +#include "paddle/fluid/framework/barrier.h" #include "paddle/fluid/framework/data_feed.h" #include "paddle/fluid/framework/executor_gc_helper.h" #include "paddle/fluid/framework/heter_util.h" @@ -208,6 +209,9 @@ class DeviceWorker { virtual void SetDeviceContext(platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; } + virtual void SetThreadNum(int thread_num) { + thread_num_ = thread_num; + } virtual Scope* GetThreadScope() { return thread_scope_; } DataFeed* device_reader_ = nullptr; @@ -237,6 +241,7 @@ class DeviceWorker { ChannelWriter writer_; const size_t tensor_iterator_thread_num = 16; platform::DeviceContext* dev_ctx_ = nullptr; + int thread_num_; }; class CPUWorkerBase : public DeviceWorker { @@ -282,6 +287,7 @@ class HogwildWorker : public CPUWorkerBase { HogwildWorkerParameter param_; std::vector skip_ops_; std::map stat_var_name_map_; + static std::atomic worker_num_stat_; }; class DownpourWorker : public HogwildWorker { @@ -713,7 +719,6 @@ class HeterSectionWorker : public DeviceWorker { const platform::Place& place() const { return place_; } void SetDeviceIndex(int tid) override { thread_id_ = tid; } - void SetThreadNum(int thread_num) { thread_num_ = thread_num; } void SetMicrobatchNum(int num) { num_microbatches_ = num; } void SetPipelineStageNum(int num) { num_pipeline_stages_ = num; } void SetPipelineStage(int stage) { pipeline_stage_ = stage; } @@ -755,7 +760,6 @@ class HeterSectionWorker : public DeviceWorker { protected: int trainer_id_; int trainers_; - int thread_num_; int thread_id_; int num_microbatches_; int num_pipeline_stages_; diff --git a/paddle/fluid/framework/hogwild_worker.cc b/paddle/fluid/framework/hogwild_worker.cc index 84bf12ed31a66..8921a84ffb05e 100644 --- a/paddle/fluid/framework/hogwild_worker.cc +++ b/paddle/fluid/framework/hogwild_worker.cc @@ -25,9 +25,13 @@ limitations under the License. */ #include "paddle/fluid/distributed/ps/service/communicator/communicator.h" #endif +DECLARE_bool(enable_exit_when_partial_worker); namespace paddle { namespace framework { +std::atomic HogwildWorker::worker_num_stat_(0); +Barrier g_barrier; + void HogwildWorker::Initialize(const TrainerDesc &desc) { fetch_config_ = desc.fetch_config(); param_ = desc.hogwild_param(); @@ -138,9 +142,27 @@ void HogwildWorker::TrainFilesWithProfiler() { double read_time = 0.0; int cur_batch; int batch_cnt = 0; + if (thread_id_ == 0) { + worker_num_stat_.store(0); + } + g_barrier.wait(); + bool train_mode = device_reader_->IsTrainMode(); timeline.Start(); uint64_t total_inst = 0; - while ((cur_batch = device_reader_->Next()) > 0) { + while (1) { + cur_batch = device_reader_->Next(); + if (FLAGS_enable_exit_when_partial_worker && train_mode) { + if (cur_batch > 0) { + worker_num_stat_.fetch_add(1, std::memory_order_relaxed); + } + g_barrier.wait(); + if (worker_num_stat_.load(std::memory_order_relaxed) % thread_num_ != 0) { + break; + } + } + if (cur_batch <= 0) { + break; + } VLOG(3) << "read a batch in thread " << thread_id_; timeline.Pause(); read_time += timeline.ElapsedSec(); @@ -230,11 +252,30 @@ void HogwildWorker::TrainFiles() { device_reader_->Start(); int cur_batch; int batch_cnt = 0; + if (thread_id_ == 0) { + worker_num_stat_.store(0); + } + g_barrier.wait(); #if defined(PADDLE_WITH_HETERPS) && defined(PADDLE_WITH_CUDA) platform::SetDeviceId(thread_id_); #endif - while ((cur_batch = device_reader_->Next()) > 0) { + //while ((cur_batch = device_reader_->Next()) > 0) { + bool train_mode = device_reader_->IsTrainMode(); + while (1) { + cur_batch = device_reader_->Next(); + if (FLAGS_enable_exit_when_partial_worker && train_mode) { + if (cur_batch > 0) { + worker_num_stat_.fetch_add(1, std::memory_order_relaxed); + } + g_barrier.wait(); + if (worker_num_stat_.load(std::memory_order_relaxed) % thread_num_ != 0) { + break; + } + } + if (cur_batch <= 0) { + break; + } for (auto &op : ops_) { bool need_skip = false; for (auto t = 0u; t < skip_ops_.size(); ++t) { diff --git a/paddle/fluid/framework/multi_trainer.cc b/paddle/fluid/framework/multi_trainer.cc index 6479f7ae72654..ce967b7b5f245 100644 --- a/paddle/fluid/framework/multi_trainer.cc +++ b/paddle/fluid/framework/multi_trainer.cc @@ -25,6 +25,8 @@ limitations under the License. */ namespace paddle { namespace framework { +extern Barrier g_barrier; + void MultiTrainer::Initialize(const TrainerDesc& trainer_desc, Dataset* dataset) { thread_num_ = trainer_desc.thread_num(); @@ -61,7 +63,7 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc, thread_num_); } #endif - + g_barrier.reset(thread_num_); for (int i = 0; i < thread_num_; ++i) { workers_[i] = DeviceWorkerFactory::CreateDeviceWorker( trainer_desc.device_worker_name()); @@ -73,6 +75,7 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc, workers_[i]->Initialize(trainer_desc); workers_[i]->SetDeviceIndex(i); workers_[i]->SetDataFeed(readers[i]); + workers_[i]->SetThreadNum(thread_num_); } // set debug here diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc index f6bb96e2f62dd..ff3225723127a 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -752,6 +752,18 @@ PADDLE_DEFINE_EXPORTED_bool( graph_get_neighbor_id, false, "It controls get all neighbor id when running sub part graph."); +/** + * Distributed related FLAG + * Name: enable_exit_when_partial_worker + * Since Version: 2.2.0 + * Value Range: bool, default=false + * Example: + * Note: Control whether exit trainer when an worker has no ins. + * If it is not set, trainer will exit until all worker finish train. + */ +PADDLE_DEFINE_EXPORTED_bool( + enable_exit_when_partial_worker, false, + "It controls whether exit trainer when an worker has no ins."); /** * KP kernel related FLAG * Name: FLAGS_run_kp_kernel