Skip to content

Commit

Permalink
combine graph_table and feature_table in graph_engine (PaddlePaddle#4…
Browse files Browse the repository at this point in the history
…2134)

* extract sub-graph

* graph-engine merging

* fix

* fix

* fix heter-ps config

* test performance

* test performance

* test performance

* test

* test

* update bfs

* change cmake

* test

* test gpu speed

* gpu_graph_engine optimization

* add dsm sample method

* add graph_neighbor_sample_v2

* Add graph_neighbor_sample_v2

* fix for loop

* add cpu sample interface

* fix kernel judgement

* add ssd layer to graph_engine

* fix allocation

* fix syntax error

* fix syntax error

* fix pscore class

* fix

* change index settings

* recover test

* recover test

* fix spelling

* recover

* fix

* move cudamemcpy after cuda stream sync

* fix linking problem

* remove comment

* add cpu test

* test

* add cpu test

* change comment

* combine feature table and graph table

* test

* test

* pybind

* test

* test

* test

* test

* pybind

* pybind

* fix cmake

* pybind

* fix

* fix

* add pybind

* add pybind

Co-authored-by: DesmonDay <[email protected]>
  • Loading branch information
seemingwang and DesmonDay committed Apr 26, 2022
1 parent 0c68ae0 commit c158e4c
Show file tree
Hide file tree
Showing 24 changed files with 1,618 additions and 936 deletions.
107 changes: 23 additions & 84 deletions paddle/fluid/distributed/ps/service/graph_brpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ int GraphBrpcClient::get_server_index_by_id(int64_t id) {
}

std::future<int32_t> GraphBrpcClient::get_node_feat(
const uint32_t &table_id, const std::vector<int64_t> &node_ids,
const uint32_t &table_id, int idx_, const std::vector<int64_t> &node_ids,
const std::vector<std::string> &feature_names,
std::vector<std::vector<std::string>> &res) {
std::vector<int> request2server;
Expand Down Expand Up @@ -124,9 +124,11 @@ std::future<int32_t> GraphBrpcClient::get_node_feat(
int server_index = request2server[request_idx];
closure->request(request_idx)->set_cmd_id(PS_GRAPH_GET_NODE_FEAT);
closure->request(request_idx)->set_table_id(table_id);

closure->request(request_idx)->set_client_id(_client_id);
size_t node_num = node_id_buckets[request_idx].size();

closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));
closure->request(request_idx)
->add_params((char *)node_id_buckets[request_idx].data(),
sizeof(int64_t) * node_num);
Expand All @@ -144,7 +146,8 @@ std::future<int32_t> GraphBrpcClient::get_node_feat(
return fut;
}

std::future<int32_t> GraphBrpcClient::clear_nodes(uint32_t table_id) {
std::future<int32_t> GraphBrpcClient::clear_nodes(uint32_t table_id,
int type_id, int idx_) {
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
server_size, [&, server_size = this->server_size ](void *done) {
int ret = 0;
Expand All @@ -167,7 +170,8 @@ std::future<int32_t> GraphBrpcClient::clear_nodes(uint32_t table_id) {
closure->request(server_index)->set_cmd_id(PS_GRAPH_CLEAR);
closure->request(server_index)->set_table_id(table_id);
closure->request(server_index)->set_client_id(_client_id);

closure->request(server_index)->add_params((char *)&type_id, sizeof(int));
closure->request(server_index)->add_params((char *)&idx_, sizeof(int));
GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(server_index),
Expand All @@ -177,7 +181,7 @@ std::future<int32_t> GraphBrpcClient::clear_nodes(uint32_t table_id) {
return fut;
}
std::future<int32_t> GraphBrpcClient::add_graph_node(
uint32_t table_id, std::vector<int64_t> &node_id_list,
uint32_t table_id, int idx_, std::vector<int64_t> &node_id_list,
std::vector<bool> &is_weighted_list) {
std::vector<std::vector<int64_t>> request_bucket;
std::vector<std::vector<bool>> is_weighted_bucket;
Expand Down Expand Up @@ -225,6 +229,7 @@ std::future<int32_t> GraphBrpcClient::add_graph_node(
closure->request(request_idx)->set_table_id(table_id);
closure->request(request_idx)->set_client_id(_client_id);
size_t node_num = request_bucket[request_idx].size();
closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));
closure->request(request_idx)
->add_params((char *)request_bucket[request_idx].data(),
sizeof(int64_t) * node_num);
Expand All @@ -245,7 +250,7 @@ std::future<int32_t> GraphBrpcClient::add_graph_node(
return fut;
}
std::future<int32_t> GraphBrpcClient::remove_graph_node(
uint32_t table_id, std::vector<int64_t> &node_id_list) {
uint32_t table_id, int idx_, std::vector<int64_t> &node_id_list) {
std::vector<std::vector<int64_t>> request_bucket;
std::vector<int> server_index_arr;
std::vector<int> index_mapping(server_size, -1);
Expand Down Expand Up @@ -286,6 +291,7 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
closure->request(request_idx)->set_client_id(_client_id);
size_t node_num = request_bucket[request_idx].size();

closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));
closure->request(request_idx)
->add_params((char *)request_bucket[request_idx].data(),
sizeof(int64_t) * node_num);
Expand All @@ -299,7 +305,7 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
}
// char* &buffer,int &actual_size
std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
uint32_t table_id, std::vector<int64_t> node_ids, int sample_size,
uint32_t table_id, int idx_, std::vector<int64_t> node_ids, int sample_size,
// std::vector<std::vector<std::pair<int64_t, float>>> &res,
std::vector<std::vector<int64_t>> &res,
std::vector<std::vector<float>> &res_weight, bool need_weight,
Expand Down Expand Up @@ -353,6 +359,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
closure->request(0)->set_cmd_id(PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER);
closure->request(0)->set_table_id(table_id);
closure->request(0)->set_client_id(_client_id);
closure->request(0)->add_params((char *)&idx_, sizeof(int));
closure->request(0)->add_params((char *)node_ids.data(),
sizeof(int64_t) * node_ids.size());
closure->request(0)->add_params((char *)&sample_size, sizeof(int));
Expand Down Expand Up @@ -452,6 +459,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
closure->request(request_idx)->set_client_id(_client_id);
size_t node_num = node_id_buckets[request_idx].size();

closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));
closure->request(request_idx)
->add_params((char *)node_id_buckets[request_idx].data(),
sizeof(int64_t) * node_num);
Expand All @@ -469,7 +477,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
return fut;
}
std::future<int32_t> GraphBrpcClient::random_sample_nodes(
uint32_t table_id, int server_index, int sample_size,
uint32_t table_id, int type_id, int idx_, int server_index, int sample_size,
std::vector<int64_t> &ids) {
DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
int ret = 0;
Expand Down Expand Up @@ -498,6 +506,8 @@ std::future<int32_t> GraphBrpcClient::random_sample_nodes(
closure->request(0)->set_cmd_id(PS_GRAPH_SAMPLE_NODES);
closure->request(0)->set_table_id(table_id);
closure->request(0)->set_client_id(_client_id);
closure->request(0)->add_params((char *)&type_id, sizeof(int));
closure->request(0)->add_params((char *)&idx_, sizeof(int));
closure->request(0)->add_params((char *)&sample_size, sizeof(int));
;
// PsService_Stub rpc_stub(GetCmdChannel(server_index));
Expand All @@ -508,83 +518,9 @@ std::future<int32_t> GraphBrpcClient::random_sample_nodes(
return fut;
}

std::future<int32_t> GraphBrpcClient::load_graph_split_config(
uint32_t table_id, std::string path) {
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
server_size, [&, server_size = this->server_size ](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
size_t fail_num = 0;
for (size_t request_idx = 0; request_idx < server_size; ++request_idx) {
if (closure->check_response(request_idx,
PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG) != 0) {
++fail_num;
break;
}
}
ret = fail_num == 0 ? 0 : -1;
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
for (size_t i = 0; i < server_size; i++) {
int server_index = i;
closure->request(server_index)
->set_cmd_id(PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG);
closure->request(server_index)->set_table_id(table_id);
closure->request(server_index)->set_client_id(_client_id);
closure->request(server_index)->add_params(path);
GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(server_index),
closure->request(server_index),
closure->response(server_index), closure);
}
return fut;
}
std::future<int32_t> GraphBrpcClient::use_neighbors_sample_cache(
uint32_t table_id, size_t total_size_limit, size_t ttl) {
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
server_size, [&, server_size = this->server_size ](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
size_t fail_num = 0;
for (size_t request_idx = 0; request_idx < server_size; ++request_idx) {
if (closure->check_response(
request_idx, PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE) != 0) {
++fail_num;
break;
}
}
ret = fail_num == 0 ? 0 : -1;
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
size_t size_limit = total_size_limit / server_size +
(total_size_limit % server_size != 0 ? 1 : 0);
std::future<int> fut = promise->get_future();
for (size_t i = 0; i < server_size; i++) {
int server_index = i;
closure->request(server_index)
->set_cmd_id(PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE);
closure->request(server_index)->set_table_id(table_id);
closure->request(server_index)->set_client_id(_client_id);
closure->request(server_index)
->add_params((char *)&size_limit, sizeof(size_t));
closure->request(server_index)->add_params((char *)&ttl, sizeof(size_t));
GraphPsService_Stub rpc_stub = getServiceStub(GetCmdChannel(server_index));
closure->cntl(server_index)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(server_index),
closure->request(server_index),
closure->response(server_index), closure);
}
return fut;
}
std::future<int32_t> GraphBrpcClient::pull_graph_list(
uint32_t table_id, int server_index, int start, int size, int step,
std::vector<FeatureNode> &res) {
uint32_t table_id, int type_id, int idx_, int server_index, int start,
int size, int step, std::vector<FeatureNode> &res) {
DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
Expand Down Expand Up @@ -613,6 +549,8 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
closure->request(0)->set_cmd_id(PS_PULL_GRAPH_LIST);
closure->request(0)->set_table_id(table_id);
closure->request(0)->set_client_id(_client_id);
closure->request(0)->add_params((char *)&type_id, sizeof(int));
closure->request(0)->add_params((char *)&idx_, sizeof(int));
closure->request(0)->add_params((char *)&start, sizeof(int));
closure->request(0)->add_params((char *)&size, sizeof(int));
closure->request(0)->add_params((char *)&step, sizeof(int));
Expand All @@ -625,7 +563,7 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
}

std::future<int32_t> GraphBrpcClient::set_node_feat(
const uint32_t &table_id, const std::vector<int64_t> &node_ids,
const uint32_t &table_id, int idx_, const std::vector<int64_t> &node_ids,
const std::vector<std::string> &feature_names,
const std::vector<std::vector<std::string>> &features) {
std::vector<int> request2server;
Expand Down Expand Up @@ -686,6 +624,7 @@ std::future<int32_t> GraphBrpcClient::set_node_feat(
closure->request(request_idx)->set_client_id(_client_id);
size_t node_num = node_id_buckets[request_idx].size();

closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));
closure->request(request_idx)
->add_params((char *)node_id_buckets[request_idx].data(),
sizeof(int64_t) * node_num);
Expand Down
27 changes: 12 additions & 15 deletions paddle/fluid/distributed/ps/service/graph_brpc_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,40 +63,37 @@ class GraphBrpcClient : public BrpcPsClient {
virtual ~GraphBrpcClient() {}
// given a batch of nodes, sample graph_neighbors for each of them
virtual std::future<int32_t> batch_sample_neighbors(
uint32_t table_id, std::vector<int64_t> node_ids, int sample_size,
std::vector<std::vector<int64_t>>& res,
uint32_t table_id, int idx, std::vector<int64_t> node_ids,
int sample_size, std::vector<std::vector<int64_t>>& res,
std::vector<std::vector<float>>& res_weight, bool need_weight,
int server_index = -1);

virtual std::future<int32_t> pull_graph_list(uint32_t table_id,
int server_index, int start,
int size, int step,
virtual std::future<int32_t> pull_graph_list(uint32_t table_id, int type_id,
int idx, int server_index,
int start, int size, int step,
std::vector<FeatureNode>& res);
virtual std::future<int32_t> random_sample_nodes(uint32_t table_id,
int type_id, int idx,
int server_index,
int sample_size,
std::vector<int64_t>& ids);
virtual std::future<int32_t> get_node_feat(
const uint32_t& table_id, const std::vector<int64_t>& node_ids,
const uint32_t& table_id, int idx, const std::vector<int64_t>& node_ids,
const std::vector<std::string>& feature_names,
std::vector<std::vector<std::string>>& res);

virtual std::future<int32_t> set_node_feat(
const uint32_t& table_id, const std::vector<int64_t>& node_ids,
const uint32_t& table_id, int idx, const std::vector<int64_t>& node_ids,
const std::vector<std::string>& feature_names,
const std::vector<std::vector<std::string>>& features);

virtual std::future<int32_t> clear_nodes(uint32_t table_id);
virtual std::future<int32_t> clear_nodes(uint32_t table_id, int type_id,
int idx);
virtual std::future<int32_t> add_graph_node(
uint32_t table_id, std::vector<int64_t>& node_id_list,
uint32_t table_id, int idx, std::vector<int64_t>& node_id_list,
std::vector<bool>& is_weighted_list);
virtual std::future<int32_t> use_neighbors_sample_cache(uint32_t table_id,
size_t size_limit,
size_t ttl);
virtual std::future<int32_t> load_graph_split_config(uint32_t table_id,
std::string path);
virtual std::future<int32_t> remove_graph_node(
uint32_t table_id, std::vector<int64_t>& node_id_list);
uint32_t table_id, int idx_, std::vector<int64_t>& node_id_list);
virtual int32_t Initialize();
int get_shard_num() { return shard_num; }
void set_shard_num(int shard_num) { this->shard_num = shard_num; }
Expand Down
Loading

0 comments on commit c158e4c

Please sign in to comment.