Skip to content

Commit

Permalink
[cherry-pick] Refactor Heterogeneous Pipeline Parameter Server (#37446)
Browse files Browse the repository at this point in the history
* bug fix for  DeserializeSelectedRows. test=develop (#36520)

* fix SerializeSelectedRows (#36543)

* bug fix for  DeserializeSelectedRows. test=develop

* fix bug for SerializeSelectedRows. test=develop

* update. test=develop

* [Heterps]Refactor Heter Pipeline Parameter Server (#36845)

* change username

* fix

* fix

* fix

* fix

* fix

* update

* update

* update unittests

* fix

* update

* fix

* update

* fix

* fix

* fix

* update

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update send_and_recv op. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix ut. test=develop

* fix unit. notest,test=coverage

* fix ut. notest, test=coverage

* update. notest,test=coverage

* fix ut. notest, test=coverage

* fix ut. notest, test=coverage

* fix. notest, test=coverage

* fix. notest, test=coverage

* fix ut. notest, test=coverage

* fix ut. notest, test=coverage

* fix ut. notest, test=coverage

* fix ut. notest, test=coverage

* add func. notest, test=coverage

* fix ut. notest, test=coverage

* fix. test=develop

* fix. test=develop

* Fix unit test for send_and_recv_cpu & send_and_recv_gpu (#37129)

* [heterps]fix ut for heter_pipeline_trainer.cc  (#37136)

* fix ut. test=develop

* fix ut. test=develop

* [heterps]bug fix for local training with --heter_worker_num (#37166)

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix ut. test=develop

* fix ut. test=develop

* fix ut. test=develop

* [heterps]Refactor heterogeneous worker (#37244)

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix ut. test=develop

* fix ut. test=develop

* fix ut. test=develop

* refactor heter trainer. test=develop

* fix. test=develop

* fix ut. test=develop

* fix ut. test=develop

* fix ut. test=develop

* fix ut. test=develop

* fix ut. test=develop

* fix ut. test=develop

* fix ut. test=develop

* fix ut. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix ut. test=develop

* fix ut. test=develop

* fix ut. test=develop

* [heterps]add heterps mode judgement (#37298)

* [heterps]change default executor for heter trainer (#37314)

* fix pslib. test=develop

* add device to train_from_dataset. test=develop

* refine fleet.stop_worker. test=develop

* fix ut. test=develop

* fix ut. test=develop

* fix executor & ut. test=develop

* fix executor & ut. test=develop

* fix executor & ut. test=develop

* [heterps]remove api for heter pipeline ps (#37396)

* fix api. test=develop

* fix api. test=develop

* fix code style. test=release/2.2

* fix CMakeLists. test=develop (#37454)
  • Loading branch information
zmxdream authored Nov 23, 2021
1 parent 436808c commit 4dc426f
Show file tree
Hide file tree
Showing 51 changed files with 4,106 additions and 767 deletions.
18 changes: 3 additions & 15 deletions paddle/fluid/distributed/service/brpc_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,23 +138,11 @@ void SerializeSelectedRows(framework::Variable* var,
var_data->clear();
var_data->resize(rows->size() * sizeof(int64_t));
char* data_ptr = const_cast<char*>(var_data->data());

if (platform::is_cpu_place(tensor->place())) {
memcpy(data_ptr, &(*rows)[0], rows->size() * sizeof(int64_t));
} else {
#ifdef PADDLE_WITH_CUDA
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
memory::Copy(platform::CPUPlace(), data_ptr,
BOOST_GET_CONST(platform::CUDAPlace, tensor->place()),
&(*rows)[0], rows->size() * sizeof(int64_t), stream);
#endif
}
memcpy(data_ptr, &((*rows)[0]), rows->size() * sizeof(int64_t));
var_msg->set_data_type(static_cast<VarMsg::Type>(tensor->type()));
for (auto& dim : framework::vectorize(tensor->dims())) {
var_msg->add_dims(dim);
}

// IO Buffer
if (platform::is_cpu_place(tensor->place())) {
auto data_len = tensor->numel() * framework::SizeOfType(tensor->type());
Expand Down Expand Up @@ -273,8 +261,8 @@ void DeserializeSelectedRows(framework::Variable* var, const VarMsg& msg,
auto* slr = var->GetMutable<framework::SelectedRows>();
framework::Tensor* tensor = slr->mutable_value();
slr->set_height(msg.slr_height());
std::vector<int64_t> tmp_rows(msg.slr_height());
memcpy(&tmp_rows[0], msg.data().data(), msg.slr_height() * sizeof(int64_t));
std::vector<int64_t> tmp_rows(msg.dims()[0]);
memcpy(tmp_rows.data(), msg.data().data(), msg.dims()[0] * sizeof(int64_t));
slr->set_rows(tmp_rows);
std::vector<int> vec_dim;
for (auto& x : msg.dims()) {
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/distributed/service/communicator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/distributed/service/communicator.h"

#include <google/protobuf/text_format.h>

#include "gflags/gflags.h"
Expand Down Expand Up @@ -361,6 +360,8 @@ void Communicator::InitParams(const RecvCtxMap &recv_varname_to_ctx) {
<< " from 0' trainer done";
}
}
std::this_thread::sleep_for(
std::chrono::milliseconds(100 + trainer_id_ * 10));
BarrierWithTable(1);
return;
}
Expand Down Expand Up @@ -518,7 +519,6 @@ void AsyncCommunicator::SendByCommunicator() {
MergeVars<float>(var_name, vars[i], send_scope_.get(), 1);
}
}

if (ctx.is_tensor_table) {
SendGlobalStep(ctx, merged_var_num, send_scope_.get());
} else if (ctx.is_sparse) {
Expand Down
101 changes: 78 additions & 23 deletions paddle/fluid/distributed/service/heter_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,36 @@ namespace distributed {
std::shared_ptr<HeterClient> HeterClient::s_instance_ = NULL;
bool HeterClient::is_initialized_ = false;

// Extracts the integer micro-batch id from the "microbatch_id" variable in
// `scope`. The id is stored as the first element of a float LoDTensor; when
// the tensor lives on GPU it is first staged into host memory via the
// context's CUDA stream, then cast to int.
// NOTE(review): assumes element 0 holds the id and the dtype is float —
// confirm against the producer of "microbatch_id".
int GetMicroId(const platform::DeviceContext& ctx,
const framework::Scope* scope) {
framework::Variable* var = scope->FindVar("microbatch_id");
// NOTE(review): FindVar may return nullptr when the variable is absent; the
// dereference below would then crash — callers must guarantee it exists.
PADDLE_ENFORCE_EQ(var->IsType<framework::LoDTensor>(), true,
platform::errors::InvalidArgument(
"the type of micro id shoulde be LoDTensor."));
auto micro_id = -1;  // sentinel: remains -1 on GPU tensors in non-CUDA builds
auto* tensor = var->GetMutable<framework::LoDTensor>();
if (platform::is_cpu_place(tensor->place())) {
// CPU tensor: read element 0 directly from host memory.
auto data = reinterpret_cast<const float*>(tensor->data<void>());
micro_id = static_cast<int>(data[0]);
} else {
#ifdef PADDLE_WITH_CUDA
// GPU tensor: copy the whole buffer to a host-side staging vector on the
// device context's stream, then read element 0 from the staged copy.
std::vector<char> temp;
temp.resize(tensor->numel() * framework::SizeOfType(tensor->type()));
char* temp_ptr = temp.data();
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
memory::Copy(platform::CPUPlace(), temp_ptr,
BOOST_GET_CONST(platform::CUDAPlace, tensor->place()),
tensor->data<void>(),
tensor->numel() * framework::SizeOfType(tensor->type()),
stream);
// NOTE(review): a stream-based copy may be asynchronous — confirm the
// copy has completed before temp_ptr is read here.
float* temp_ptr_float = reinterpret_cast<float*>(temp_ptr);
micro_id = static_cast<int>(temp_ptr_float[0]);
#endif
}
return micro_id;
}

void HeterClient::MainThread() {
while (running_) {
RpcProfilerControl();
Expand Down Expand Up @@ -99,43 +129,68 @@ void HeterClient::CreateClient2XpuConnection() {
}
}
}
previous_xpu_channels_.resize(previous_xpu_list_.size());
for (size_t i = 0; i < previous_xpu_list_.size(); ++i) {
previous_xpu_channels_[i].reset(new brpc::Channel());
if (previous_xpu_channels_[i]->Init(previous_xpu_list_[i].c_str(), "",
&options) != 0) {
VLOG(0) << "HeterClient channel init fail. Try Again";
auto ip_port = paddle::string::Split(previous_xpu_list_[i], ':');
std::string ip = ip_port[0];
int port = std::stoi(ip_port[1]);
std::string int_ip_port = GetIntTypeEndpoint(ip, port);
if (previous_xpu_channels_[i]->Init(int_ip_port.c_str(), "", &options) !=
0) {
LOG(ERROR) << "BrpcPsServer start failed, ip_port= " << int_ip_port;
}
}
}
}

void HeterClient::SendAndRecvAsync(
const std::vector<std::string>& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& message_name,
const platform::DeviceContext& ctx, const framework::Scope& scope,
const std::string& message_name,
const std::vector<std::string>& send_var_name,
const std::vector<std::string>& recv_var_name) {
const std::vector<std::string>& recv_var_name, const std::string& mode) {
platform::RecordEvent record_event("HeterClient->SendAndRecvAsync");
const platform::DeviceContext* p_ctx = &ctx;
const framework::Scope* p_scope = &scope;
const std::string message_name_val = message_name;
const std::vector<std::string> send_var_name_val = send_var_name;
const std::vector<std::string> recv_var_name_val = recv_var_name;

VLOG(3) << "GRPCClient::SendAndRecv Begin, message_name: "
VLOG(3) << "BRPCClient::SendAndRecv Begin, message_name: "
<< message_name_val;
// Todo: get correct channel
int num = trainer_id_ % xpu_channels_.size();

brpc::Controller cntl;
cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
distributed::MultiVarMsg request, response;
auto& request_io_buffer = cntl.request_attachment();
::paddle::distributed::PsService_Stub stub(xpu_channels_[num].get());
brpc::Channel* channel = nullptr;
distributed::MultiVarMsg request;
OnHeterRpcDone* closure = new OnHeterRpcDone([p_ctx, p_scope](void* done) {
auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
PADDLE_ENFORCE_NE(
closure->cntl.Failed(), true,
platform::errors::Unimplemented(
"HeterClient::SendAndRecv meets brpc error, error message is %s",
closure->cntl.ErrorText()));

VLOG(4) << "call heter_worker success";
});
closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
auto& request_io_buffer = closure->cntl.request_attachment();
distributed::SerializeToMultiVarMsgAndIOBuf(
message_name_val, send_var_name_val, recv_var_name_val, *p_ctx, p_scope,
&request, &request_io_buffer);
stub.SendAndRecvVariable(&cntl, &request, &response, NULL);
PADDLE_ENFORCE_NE(
cntl.Failed(), true,
platform::errors::Unimplemented(
"HeterClient::SendAndRecv meets brpc error, error message is %s",
cntl.ErrorText()));
VLOG(4) << "call heter_worker success";
auto& response_io_buffer = cntl.response_attachment();
distributed::DeserializeFromMultiVarMsgAndIOBuf(response, &response_io_buffer,
ctx, p_scope);

int micro_id = GetMicroId(ctx, p_scope);
auto minibatch_id = micro_id / 10;
// select channel according to micro id
if (mode == "forward") {
int num = minibatch_id % xpu_channels_.size();
channel = xpu_channels_[num].get();
} else if (mode == "backward") {
int num = minibatch_id % previous_xpu_channels_.size();
channel = previous_xpu_channels_[num].get();
}
::paddle::distributed::PsService_Stub stub(channel);
stub.SendAndRecvVariable(&closure->cntl, &request, &closure->response,
closure);
}

std::future<int32_t> HeterClient::SendCmd(
Expand Down
17 changes: 13 additions & 4 deletions paddle/fluid/distributed/service/heter_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,20 +76,23 @@ class HeterClient {

void CreateClient2XpuConnection();

void SendAndRecvAsync(const std::vector<std::string>& ep,
const platform::DeviceContext& ctx,
void SendAndRecvAsync(const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& message_name,
const std::vector<std::string>& send_var_name,
const std::vector<std::string>& recv_var_name);
const std::vector<std::string>& recv_var_name,
const std::string& mode = "forward");

// HeterClient singleton
static std::shared_ptr<HeterClient> GetInstance(
const std::vector<std::string>& endpoint, const int& trainer_id) {
const std::vector<std::string>& endpoint,
const std::vector<std::string>& previous_endpoint,
const int& trainer_id) {
if (NULL == s_instance_) {
is_initialized_ = true;
s_instance_.reset(new paddle::distributed::HeterClient());
s_instance_->SetXpuList(endpoint);
s_instance_->SetPreviousXpuList(previous_endpoint);
s_instance_->SetTrainerID(trainer_id);
s_instance_->CreateClient2XpuConnection();
}
Expand Down Expand Up @@ -118,16 +121,22 @@ class HeterClient {
xpu_list_ = xpu_list;
}

void SetPreviousXpuList(const std::vector<std::string>& xpu_list) {
previous_xpu_list_ = xpu_list;
}

void SetTrainerID(const int& trainer_id) { trainer_id_ = trainer_id; }

private:
static std::shared_ptr<HeterClient> s_instance_;
static bool is_initialized_;
std::unique_ptr<std::thread> main_thread_{nullptr};
std::vector<std::shared_ptr<brpc::Channel>> xpu_channels_;
std::vector<std::shared_ptr<brpc::Channel>> previous_xpu_channels_;

DISABLE_COPY_AND_ASSIGN(HeterClient);
std::vector<std::string> xpu_list_;
std::vector<std::string> previous_xpu_list_;

bool running_ = false;
int trainer_id_;
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/distributed/service/heter_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,20 +46,20 @@ void HeterServer::StartHeterService() {
ready_ = 1;
}
condition_ready_.notify_all();

std::unique_lock<std::mutex> running_lock(mutex_);
stoped_ = false;
cv_.wait(running_lock, [&] {
VLOG(1) << "Heter Server is Stop? " << stoped_;
return stoped_;
});
}

void HeterServer::SetEndPoint(std::string& endpoint) {
void HeterServer::SetEndPoint(const std::string& endpoint) {
endpoint_ = endpoint;
service_.SetEndpoint(endpoint);
}

void HeterServer::SetFanin(int& fan_in) { service_.SetFanin(fan_in); }
void HeterServer::SetFanin(const int& fan_in) { service_.SetFanin(fan_in); }

void HeterServer::WaitServerReady() {
std::unique_lock<std::mutex> lock(this->mutex_ready_);
Expand Down
Loading

0 comments on commit 4dc426f

Please sign in to comment.