diff --git a/include/ps/internal/customer.h b/include/ps/internal/customer.h
index c0a6e7c3e..f25298b66 100644
--- a/include/ps/internal/customer.h
+++ b/include/ps/internal/customer.h
@@ -15,6 +15,17 @@
 #include "ps/internal/threadsafe_queue.h"
 namespace ps {

+/**
+ * \brief The structure for profiling
+ */
+struct Profile {
+  uint64_t key;
+  int sender;
+  bool is_push;
+  std::string ts;
+  bool is_begin;
+};
+
 /**
  * \brief The object for communication.
  *
@@ -91,9 +102,11 @@ class Customer {
   void ProcessPullRequest(int worker_id);
   void ProcessPushRequest(int thread_id);
+  void ProcessProfileData();
   bool IsValidPushpull(const Message& msg);
   uint64_t GetKeyFromMsg(const Message& msg);
   void ProcessResponse(int thread_id);
+  std::string GetTimestampNow();

  private:
   /**
@@ -113,6 +126,9 @@ class Customer {
   std::condition_variable tracker_cond_;
   std::vector<std::pair<int, int>> tracker_;

+  // for storing profile data
+  ThreadsafeQueue<Profile> pdata_queue_;
+
   DISALLOW_COPY_AND_ASSIGN(Customer);
 };
diff --git a/src/customer.cc b/src/customer.cc
index bb9673d15..606b0ed63 100644
--- a/src/customer.cc
+++ b/src/customer.cc
@@ -9,6 +9,7 @@
 #include
 #include
 #include
+#include <fstream>
 namespace ps {
 const int Node::kEmpty = std::numeric_limits<int>::max();
 const int Meta::kEmpty = std::numeric_limits<int>::max();
@@ -22,6 +23,7 @@
 std::unordered_map > pull_collected_;
 std::vector > worker_buffer_;
 std::atomic<int> thread_barrier_{0};
+bool enable_profile_ = false;

 Customer::Customer(int app_id, int customer_id, const Customer::RecvHandle& recv_handle)
     : app_id_(app_id), customer_id_(customer_id), recv_handle_(recv_handle) {
@@ -123,6 +125,10 @@ void Customer::ProcessPullRequest(int worker_id) {
       }
       recv_handle_(msg);
       it = pull_consumer.erase(it);
+      if (enable_profile_) {
+        Profile pdata = {key, msg.meta.sender, false, GetTimestampNow(), false};
+        pdata_queue_.Push(pdata);
+      }
       break;
     } else {
       ++it;
@@ -162,6 +168,10 @@ void Customer::ProcessPushRequest(int thread_id) {
       CHECK(msg.meta.push);
       uint64_t key = GetKeyFromMsg(msg);
       recv_handle_(msg);
+      if (enable_profile_) {
+        Profile pdata = {key, msg.meta.sender, true, GetTimestampNow(), false};
+        pdata_queue_.Push(pdata);
+      }
       it = push_consumer.erase(it);

@@ -181,6 +191,48 @@
     }
   }
 }

+void Customer::ProcessProfileData() {
+  LOG(INFO) << "profile thread initialized";
+  bool profile_all = true;  // default: profile all keys
+  uint64_t key_to_profile;
+  const char *val;
+  val = Environment::Get()->find("BYTEPS_SERVER_KEY_TO_PROFILE");
+  if (val) {
+    profile_all = false;
+    key_to_profile = atoi(val);
+  }
+
+  std::fstream fout_;
+  val = Environment::Get()->find("BYTEPS_SERVER_PROFILE_OUTPUT_PATH");
+  fout_.open((val ? std::string(val) : "server_profile.json"), std::fstream::out);
+  fout_ << "{\n";
+  fout_ << "\t\"traceEvents\": [\n";
+  bool is_init = true;
+  while (true) {
+    Profile pdata;
+    pdata_queue_.WaitAndPop(&pdata);
+    if (profile_all || key_to_profile == pdata.key) {
+      if (!is_init) {
+        fout_ << ",\n";
+      } else {
+        is_init = false;
+      }
+      fout_ << "\t\t" << "{\"name\": " << "\"" << (pdata.is_push ? "push" : "pull") << "-" << pdata.sender << "\"" << ", "
+            << "\"ph\": " << "\"" << (pdata.is_begin ? "B" : "E") << "\"" << ", "
+            << "\"pid\": " << pdata.key << ", "
+            << "\"tid\": " << pdata.key << ", "
+            << "\"ts\": " << pdata.ts
+            << "}";
+    }
+  }
+  fout_ << "]\n";
+  fout_ << "}";
+  fout_.clear();
+  fout_.flush();
+  fout_.close();
+  LOG(INFO) << "profile thread ended";
+}
+
 void Customer::ProcessResponse(int thread_id) {
   {
     std::lock_guard<std::mutex> lock(mu_);
@@ -207,6 +259,16 @@
   }
 }

+std::string Customer::GetTimestampNow() {
+  std::chrono::microseconds us =
+      std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now().time_since_epoch());
+  std::stringstream temp_stream;
+  std::string ts_string;
+  temp_stream << us.count();
+  temp_stream >> ts_string;
+  return ts_string;
+}
+
 void Customer::Receiving() {
   const char *val;
   val = CHECK_NOTNULL(Environment::Get()->find("DMLC_ROLE"));
@@ -224,11 +286,19 @@ void Customer::Receiving() {
   bool enable_async = val ? atoi(val) : false;
   if (is_server && enable_async) {
     is_server_multi_pull_enabled = false;
-    LOG(INFO) << "Multi-threading has been disabled for asynchronous training";
+  }
+
+  // profiling
+  val = Environment::Get()->find("BYTEPS_SERVER_ENABLE_PROFILE");
+  enable_profile_ = val ? atoi(val) : false;
+  std::thread* profile_thread = nullptr;
+  if (enable_profile_ && is_server) {
+    LOG(INFO) << "Enable server profiling";
+    profile_thread = new std::thread(&Customer::ProcessProfileData, this);
   }

   if (is_server && is_server_multi_pull_enabled){ // server multi-thread
-    LOG(INFO) << "Use seperate thread to process pull requests from each worker.";
+    LOG(INFO) << "Use separate thread to process pull requests from each worker.";
     std::vector<std::thread*> push_thread;
     for (int i = 0; i < server_push_nthread; ++i) {
@@ -308,9 +378,17 @@
       if (recv.meta.push) { // push: same key goes to same thread
         std::lock_guard<std::mutex> lock(mu_);
-        buffered_push_[(key/num_server)%server_push_nthread].push_back(recv);
+        if (enable_profile_) {
+          Profile pdata = {key, recv.meta.sender, true, GetTimestampNow(), true};
+          pdata_queue_.Push(pdata);
+        }
+        buffered_push_[(key/num_server) % server_push_nthread].push_back(recv);
       } else { // pull
         std::lock_guard<std::mutex> lock(mu_);
+        if (enable_profile_) {
+          Profile pdata = {key, recv.meta.sender, false, GetTimestampNow(), true};
+          pdata_queue_.Push(pdata);
+        }
         int worker_id = (recv.meta.sender - 9) / 2; // worker id: 9, 11, 13 ...
         buffered_pull_[worker_id].push_back(recv);
       }
     }
   }
@@ -320,6 +398,7 @@
     // wait until the threads finish
     for (auto t : push_thread) t->join();
     for (auto t : pull_thread) t->join();
+    if (profile_thread) profile_thread->join();
   } // server multi-thread
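Usage note (not part of the patch): profiling is off by default; setting BYTEPS_SERVER_ENABLE_PROFILE=1 on a server process starts the ProcessProfileData thread, BYTEPS_SERVER_KEY_TO_PROFILE optionally restricts recording to a single key, and BYTEPS_SERVER_PROFILE_OUTPUT_PATH overrides the default output file server_profile.json.

The records are written in Chrome's trace-event JSON layout ("traceEvents" entries with paired "B"/"E" phase events, pid/tid set to the key, ts in microseconds), so a complete output file should load in chrome://tracing or the Perfetto UI. Below is a minimal, self-contained C++ sketch of that format, mirroring what ProcessProfileData writes; the struct is copied from the patch, while main(), TimestampNowUs(), the key/sender values, and the file name server_profile_example.json are made up for illustration.

// Standalone sketch: emits two Profile-style records (a matching begin/end
// pair) in the same trace-event JSON layout that ProcessProfileData writes.
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <string>
#include <vector>

struct Profile {
  uint64_t key;      // tensor key; used as both pid and tid in the trace
  int sender;        // node id of the worker that sent the request
  bool is_push;      // push vs. pull request
  std::string ts;    // timestamp in microseconds since epoch, as a string
  bool is_begin;     // true -> "B" (begin) event, false -> "E" (end) event
};

// Same idea as Customer::GetTimestampNow(): microseconds since epoch as text.
static std::string TimestampNowUs() {
  auto us = std::chrono::duration_cast<std::chrono::microseconds>(
      std::chrono::system_clock::now().time_since_epoch());
  return std::to_string(us.count());
}

int main() {
  std::vector<Profile> events;
  events.push_back({42, 9, true, TimestampNowUs(), true});   // push from sender 9 begins
  events.push_back({42, 9, true, TimestampNowUs(), false});  // push from sender 9 ends

  std::ofstream fout("server_profile_example.json");
  fout << "{\n\t\"traceEvents\": [\n";
  for (std::size_t i = 0; i < events.size(); ++i) {
    const Profile& p = events[i];
    if (i) fout << ",\n";
    fout << "\t\t{\"name\": \"" << (p.is_push ? "push" : "pull") << "-" << p.sender
         << "\", \"ph\": \"" << (p.is_begin ? "B" : "E")
         << "\", \"pid\": " << p.key << ", \"tid\": " << p.key
         << ", \"ts\": " << p.ts << "}";
  }
  fout << "\n\t]\n}\n";
  return 0;
}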