common: add support for profiling server (#4)
* initial support for server profiling

* simulate the customer

* add optional profile granularity

* output to json file

* improve format
ymjiang authored Sep 25, 2019
1 parent 21301d5 commit 92e012f
Showing 2 changed files with 98 additions and 3 deletions.
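
The commit gates everything behind environment variables (visible in the Receiving() and ProcessProfileData() hunks below): BYTEPS_SERVER_ENABLE_PROFILE turns the feature on, BYTEPS_SERVER_KEY_TO_PROFILE optionally restricts profiling to a single key, and BYTEPS_SERVER_PROFILE_OUTPUT_PATH overrides the default server_profile.json output file. A minimal standalone sketch of that configuration logic, using std::getenv as a stand-in for ps-lite's Environment helper (the struct and function names here are hypothetical):

    // Sketch only, not part of the commit: the three knobs the new code reads at startup.
    #include <cstdint>
    #include <cstdlib>
    #include <string>

    struct ProfileConfig {
      bool enabled = false;                           // BYTEPS_SERVER_ENABLE_PROFILE, off by default
      bool profile_all = true;                        // profile every key unless a single key is given
      uint64_t key_to_profile = 0;                    // BYTEPS_SERVER_KEY_TO_PROFILE
      std::string out_path = "server_profile.json";   // BYTEPS_SERVER_PROFILE_OUTPUT_PATH
    };

    ProfileConfig ReadProfileConfig() {
      ProfileConfig cfg;
      if (const char* v = std::getenv("BYTEPS_SERVER_ENABLE_PROFILE")) cfg.enabled = std::atoi(v) != 0;
      if (const char* v = std::getenv("BYTEPS_SERVER_KEY_TO_PROFILE")) {
        cfg.profile_all = false;
        cfg.key_to_profile = std::strtoull(v, nullptr, 10);
      }
      if (const char* v = std::getenv("BYTEPS_SERVER_PROFILE_OUTPUT_PATH")) cfg.out_path = v;
      return cfg;
    }
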
16 changes: 16 additions & 0 deletions include/ps/internal/customer.h
@@ -15,6 +15,17 @@
#include "ps/internal/threadsafe_queue.h"
namespace ps {

/**
* \brief The structure for profiling
*/
struct Profile {
uint64_t key;      // key of the tensor being pushed or pulled
int sender;        // node id of the request sender
bool is_push;      // true for a push request, false for a pull
std::string ts;    // timestamp in microseconds since epoch, as a string
bool is_begin;     // true when the request is received, false when processing finishes
};

/**
* \brief The object for communication.
*
@@ -91,9 +102,11 @@ class Customer {

void ProcessPullRequest(int worker_id);
void ProcessPushRequest(int thread_id);
void ProcessProfileData();
bool IsValidPushpull(const Message& msg);
uint64_t GetKeyFromMsg(const Message& msg);
void ProcessResponse(int thread_id);
std::string GetTimestampNow();

private:
/**
@@ -113,6 +126,9 @@
std::condition_variable tracker_cond_;
std::vector<std::pair<int, int>> tracker_;

// for storing profile data
ThreadsafeQueue<Profile> pdata_queue_;

DISALLOW_COPY_AND_ASSIGN(Customer);
};

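The header changes set up a producer/consumer flow: every request-processing thread pushes a Profile record into the new pdata_queue_, and a single profiling thread (ProcessProfileData, added in src/customer.cc below) blocks on the queue and writes the records out. A rough standalone sketch of that flow, with a mutex/condition-variable queue standing in for ps-lite's ThreadsafeQueue and hypothetical example values:

    // Sketch only: the producer/consumer shape of the profiling path.
    #include <condition_variable>
    #include <cstdint>
    #include <iostream>
    #include <mutex>
    #include <queue>
    #include <string>
    #include <thread>

    struct Profile { uint64_t key; int sender; bool is_push; std::string ts; bool is_begin; };

    // Stand-in for ps::ThreadsafeQueue<Profile>: Push() from many threads, WaitAndPop() from one.
    class ProfileQueue {
     public:
      void Push(Profile p) {
        { std::lock_guard<std::mutex> lk(mu_); q_.push(std::move(p)); }
        cv_.notify_one();
      }
      void WaitAndPop(Profile* p) {
        std::unique_lock<std::mutex> lk(mu_);
        cv_.wait(lk, [this] { return !q_.empty(); });
        *p = std::move(q_.front());
        q_.pop();
      }
     private:
      std::mutex mu_;
      std::condition_variable cv_;
      std::queue<Profile> q_;
    };

    int main() {
      ProfileQueue queue;
      // Consumer: plays the role of ProcessProfileData(), one line per record.
      std::thread writer([&] {
        for (int i = 0; i < 2; ++i) {
          Profile p;
          queue.WaitAndPop(&p);
          std::cout << (p.is_push ? "push" : "pull") << "-" << p.sender
                    << (p.is_begin ? " begin " : " end ") << "key=" << p.key << "\n";
        }
      });
      // Producer: plays the role of a push-processing thread.
      queue.Push({42, 9, true, "1569400000000000", true});   // request buffered
      queue.Push({42, 9, true, "1569400000000123", false});  // handler finished
      writer.join();
      return 0;
    }
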
85 changes: 82 additions & 3 deletions src/customer.cc
@@ -9,6 +9,7 @@
#include <set>
#include <list>
#include <fstream>
#include <chrono>
namespace ps {
const int Node::kEmpty = std::numeric_limits<int>::max();
const int Meta::kEmpty = std::numeric_limits<int>::max();
@@ -22,6 +23,7 @@ std::unordered_map<uint64_t, std::set<int> > pull_collected_;
std::vector<std::list<Message> > worker_buffer_;

std::atomic<int> thread_barrier_{0};
bool enable_profile_ = false;

Customer::Customer(int app_id, int customer_id, const Customer::RecvHandle& recv_handle)
: app_id_(app_id), customer_id_(customer_id), recv_handle_(recv_handle) {
@@ -123,6 +125,10 @@ void Customer::ProcessPullRequest(int worker_id) {
}
recv_handle_(msg);
it = pull_consumer.erase(it);
if (enable_profile_) {
Profile pdata = {key, msg.meta.sender, false, GetTimestampNow(), false};
pdata_queue_.Push(pdata);
}
break;
} else {
++it;
@@ -162,6 +168,10 @@ void Customer::ProcessPushRequest(int thread_id) {
CHECK(msg.meta.push);
uint64_t key = GetKeyFromMsg(msg);
recv_handle_(msg);
if (enable_profile_) {
Profile pdata = {key, msg.meta.sender, true, GetTimestampNow(), false};
pdata_queue_.Push(pdata);
}

it = push_consumer.erase(it);

@@ -181,6 +191,48 @@
}
}

void Customer::ProcessProfileData() {
LOG(INFO) << "profile thread is inited";
bool profile_all = true; // default: profile all keys
uint64_t key_to_profile;
const char *val;
val = Environment::Get()->find("BYTEPS_SERVER_KEY_TO_PROFILE");
if (val) {
profile_all = false;
key_to_profile = atoi(val);
}

std::fstream fout_;
val = Environment::Get()->find("BYTEPS_SERVER_PROFILE_OUTPUT_PATH");
fout_.open((val ? std::string(val) : "server_profile.json"), std::fstream::out);
fout_ << "{\n";
fout_ << "\t\"traceEvents\": [\n";
bool is_init = true;
while (true) {
Profile pdata;
pdata_queue_.WaitAndPop(&pdata);
if (profile_all || key_to_profile==pdata.key) {
if (!is_init) {
fout_ << ",\n";
} else {
is_init = false;
}
fout_ << "\t\t" << "{\"name\": " << "\"" <<(pdata.is_push?"push":"pull") << "-" << pdata.sender << "\"" << ", "
<< "\"ph\": " << "\"" << (pdata.is_begin?"B":"E") << "\"" << ", "
<< "\"pid\": " << pdata.key << ", "
<< "\"tid\": " << pdata.key << ", "
<< "\"ts\": " << pdata.ts
<< "}";
}
}
fout_ << "]\n";
fout_ << "}";
fout_.clear();
fout_.flush();
fout_.close();
LOG(INFO) << "profile thread ended";
}

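For reference, the records above are emitted in the Chrome trace event format (a top-level traceEvents array of begin/end events), so the resulting JSON can be opened in a viewer such as chrome://tracing; each push or pull shows up as a B/E pair grouped by key. A small sketch of the per-event formatting, mirroring the stream code above with hypothetical values:

    // Sketch only: how one Profile record maps to a single trace event line.
    #include <cstdint>
    #include <iostream>
    #include <string>

    struct Profile { uint64_t key; int sender; bool is_push; std::string ts; bool is_begin; };

    int main() {
      Profile p = {42, 9, true, "1569400000000000", true};  // hypothetical record
      std::cout << "{\"name\": \"" << (p.is_push ? "push" : "pull") << "-" << p.sender << "\", "
                << "\"ph\": \"" << (p.is_begin ? "B" : "E") << "\", "
                << "\"pid\": " << p.key << ", \"tid\": " << p.key << ", "
                << "\"ts\": " << p.ts << "}\n";
      // Prints: {"name": "push-9", "ph": "B", "pid": 42, "tid": 42, "ts": 1569400000000000}
      return 0;
    }
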
void Customer::ProcessResponse(int thread_id) {
{
std::lock_guard<std::mutex> lock(mu_);
@@ -207,6 +259,16 @@ void Customer::ProcessResponse(int thread_id) {
}
}

std::string Customer::GetTimestampNow() {
std::chrono::microseconds us =
std::chrono::duration_cast<std::chrono::microseconds >(std::chrono::system_clock::now().time_since_epoch());
std::stringstream temp_stream;
std::string ts_string;
temp_stream << us.count();
temp_stream >> ts_string;
return ts_string;
}

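GetTimestampNow() above renders the current wall-clock time as microseconds since the system clock's epoch (the Unix epoch in practice), which is what the ts field of each trace event carries. A terser equivalent, shown only as an illustration (the helper name is hypothetical):

    #include <chrono>
    #include <string>

    // Microseconds since the epoch, as a decimal string.
    std::string TimestampMicros() {
      auto us = std::chrono::duration_cast<std::chrono::microseconds>(
          std::chrono::system_clock::now().time_since_epoch());
      return std::to_string(us.count());
    }
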
void Customer::Receiving() {
const char *val;
val = CHECK_NOTNULL(Environment::Get()->find("DMLC_ROLE"));
@@ -224,11 +286,19 @@
bool enable_async = val ? atoi(val) : false;
if (is_server && enable_async) {
is_server_multi_pull_enabled = false;
LOG(INFO) << "Multi-threading has been disabled for asynchronous training";
}

// profiling
val = Environment::Get()->find("BYTEPS_SERVER_ENABLE_PROFILE");
enable_profile_ = val ? atoi(val) : false;
std::thread* profile_thread = nullptr;
if (enable_profile_ && is_server) {
LOG(INFO) << "Enable server profiling";
profile_thread = new std::thread(&Customer::ProcessProfileData, this);
}

if (is_server && is_server_multi_pull_enabled){ // server multi-thread
LOG(INFO) << "Use seperate thread to process pull requests from each worker.";
LOG(INFO) << "Use separate thread to process pull requests from each worker.";

std::vector<std::thread *> push_thread;
for (int i = 0; i < server_push_nthread; ++i) {
@@ -308,9 +378,17 @@ void Customer::Receiving() {

if (recv.meta.push) { // push: same key goes to same thread
std::lock_guard<std::mutex> lock(mu_);
buffered_push_[(key/num_server)%server_push_nthread].push_back(recv);
if (enable_profile_) {
Profile pdata = {key, recv.meta.sender, true, GetTimestampNow(), true};
pdata_queue_.Push(pdata);
}
buffered_push_[(key/num_server) % server_push_nthread].push_back(recv);
} else { // pull
std::lock_guard<std::mutex> lock(mu_);
if (enable_profile_) {
Profile pdata = {key, recv.meta.sender, false, GetTimestampNow(), true};
pdata_queue_.Push(pdata);
}
int worker_id = (recv.meta.sender - 9) / 2; // worker id: 9, 11, 13 ...
buffered_pull_[worker_id].push_back(recv);
}
@@ -320,6 +398,7 @@
// wait until the threads finish
for (auto t : push_thread) t->join();
for (auto t : pull_thread) t->join();
if (profile_thread) profile_thread->join();

} // server multi-thread
