Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pull nodes with step #10

Merged
merged 1 commit into from
Mar 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion paddle/fluid/distributed/service/graph_brpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ std::future<int32_t> GraphBrpcClient::random_sample_nodes(
return fut;
}
// GraphBrpcClient::pull_graph_list: client side of the pull-graph-list RPC.
// This excerpt is a rendered diff without +/- markers: the two consecutive
// parameter lines below appear to be the pre-change and post-change versions
// of the same line (the change adds `int step`). The body is only partly
// visible; the lines between the two hunks are elided.
std::future<int32_t> GraphBrpcClient::pull_graph_list(
uint32_t table_id, int server_index, int start, int size,
uint32_t table_id, int server_index, int start, int size, int step,
std::vector<GraphNode> &res) {
// NOTE(review): the completion callback captures by reference ([&]). Its body
// is cut off in this view, so which variables it uses cannot be confirmed
// here; verify nothing it touches goes out of scope before the RPC completes.
DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
int ret = 0;
Expand Down Expand Up @@ -207,6 +207,7 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
closure->request(0)->set_client_id(_client_id);
// start, size and step are appended as raw sizeof(int)-byte params, in that
// order. The server handler shown on this page decodes params(0), params(1)
// and params(2) as start, size and step; presumably no other add_params call
// precedes these in the elided lines (TODO confirm).
closure->request(0)->add_params((char *)&start, sizeof(int));
closure->request(0)->add_params((char *)&size, sizeof(int));
closure->request(0)->add_params((char *)&step, sizeof(int));
// The request is sent over the command channel selected by server_index.
PsService_Stub rpc_stub(get_cmd_channel(server_index));
closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0),
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/service/graph_brpc_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class GraphBrpcClient : public BrpcPsClient {
std::vector<std::vector<std::pair<uint64_t, float>>> &res);
// Declaration of the virtual pull_graph_list member. Results are written into
// `res` and the call returns a future carrying an int32_t status.
// The `int size,` / `int size, int step,` lines appear to be the pre-change
// and post-change sides of the diff (the change inserts `step` before `res`).
virtual std::future<int32_t> pull_graph_list(uint32_t table_id,
int server_index, int start,
int size,
int size, int step,
std::vector<GraphNode> &res);
virtual std::future<int32_t> random_sample_nodes(uint32_t table_id,
int server_index,
Expand Down
7 changes: 4 additions & 3 deletions paddle/fluid/distributed/service/graph_brpc_server.cc
Original file line number Diff line number Diff line change
// GraphBrpcService::pull_graph_list: server-side handler for the
// pull-graph-list request. It checks the parameter count, decodes
// start/size/step from the request params, asks the table to fill `buffer`,
// and appends `actual_size` bytes of it to the response attachment.
// Both the error path and the success path shown here return 0; the error is
// reported through set_response_code(response, -1, ...).
// This is a rendered diff without +/- markers: adjacent near-duplicate lines
// appear to be the pre-change / post-change versions of the same statement.
Expand Up @@ -260,16 +260,17 @@ int32_t GraphBrpcService::pull_graph_list(Table *table,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
// Required param count goes from 2 (start, size) to 3 (start, size, step).
if (request.params_size() < 2) {
if (request.params_size() < 3) {
set_response_code(response, -1,
"pull_graph_list request requires at least 2 arguments");
"pull_graph_list request requires at least 3 arguments");
return 0;
}
// Each param's bytes are reinterpreted in place as an int.
// NOTE(review): only the number of params is checked, not that each one holds
// at least sizeof(int) bytes; confirm callers always send full ints.
int start = *(int *)(request.params(0).c_str());
int size = *(int *)(request.params(1).c_str());
int step = *(int *)(request.params(2).c_str());
std::unique_ptr<char[]> buffer;
// NOTE(review): actual_size is not initialized here and the table call's
// return value is ignored. The Table base-class default shown on this page
// returns 0 without writing actual_size; confirm every table reached here
// sets it before it is used below.
int actual_size;
table->pull_graph_list(start, size, buffer, actual_size, true);
// NOTE(review): the Table / GraphTable signatures shown on this page are
// (start, total_size, buffer, actual_size, bool need_feature, int step).
// This call passes `step, true` in the last two positions, so `step` would
// bind to need_feature and `true` (1) to step. Verify the intended order;
// it looks like it should be `true, step`.
table->pull_graph_list(start, size, buffer, actual_size, step, true);
cntl->response_attachment().append(buffer.get(), actual_size);
return 0;
}
Expand Down
7 changes: 4 additions & 3 deletions paddle/fluid/distributed/service/graph_py_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,13 @@ std::vector<uint64_t> GraphPyClient::random_sample_nodes(std::string name,
}
// GraphPyClient::pull_graph_list: blocking wrapper around the worker's
// pull_graph_list. Looks up the table id registered under `name`; if present,
// issues the request and waits for it. Returns the collected nodes, or an
// empty vector when `name` is not in table_id_map.
// Rendered diff without +/- markers: the signature tail and the worker call
// each appear twice (pre-change without `step`, post-change with it).
std::vector<GraphNode> GraphPyClient::pull_graph_list(std::string name,
int server_index,
int start, int size) {
int start, int size,
int step) {
std::vector<GraphNode> res;
if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name];
// The returned future is only waited on; its int32_t value is not inspected.
auto status =
worker_ptr->pull_graph_list(table_id, server_index, start, size, res);
auto status = worker_ptr->pull_graph_list(table_id, server_index, start,
size, step, res);
status.wait();
}
return res;
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/service/graph_py_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ class GraphPyClient : public GraphPyService {
std::vector<uint64_t> random_sample_nodes(std::string name, int server_index,
int sample_size);
// Declaration of GraphPyClient::pull_graph_list. The two trailing lines appear
// to be the pre-change and post-change sides of the diff; the post-change
// version adds `step` with a default of 1, so existing four-argument calls
// still compile.
std::vector<GraphNode> pull_graph_list(std::string name, int server_index,
int start, int size);
int start, int size, int step = 1);
::paddle::distributed::PSParameter GetWorkerProto();

protected:
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/service/ps_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ class PSClient {
}
// PSClient default for pull_graph_list: logs "Did not implement" at FATAL
// severity. The rest of the body is elided in this view.
// The `int size,` / `int size, int step,` lines appear to be the pre-change
// and post-change sides of the diff.
virtual std::future<int32_t> pull_graph_list(uint32_t table_id,
int server_index, int start,
int size,
int size, int step,
std::vector<GraphNode> &res) {
LOG(FATAL) << "Did not implement";
std::promise<int32_t> promise;
Expand Down
59 changes: 24 additions & 35 deletions paddle/fluid/distributed/table/common_graph_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
namespace paddle {
namespace distributed {

// GraphShard::get_batch: collects node pointers from this shard's `bucket`.
// Rendered diff without +/- markers: the first signature and first `for` line
// appear to be the pre-change version (contiguous range of total_size
// entries); the second of each is the post-change version, which visits
// positions start, start+step, ... below min(end, bucket.size()).
// A negative start is clamped to 0.
std::vector<GraphNode *> GraphShard::get_batch(int start, int total_size) {
std::vector<GraphNode *> GraphShard::get_batch(int start, int end, int step) {
if (start < 0) start = 0;
std::vector<GraphNode *> res;
for (int pos = start; pos < start + total_size; pos++) {
// NOTE(review): with step <= 0 `pos` never advances past the bound;
// presumably callers guarantee step >= 1 (TODO confirm).
for (int pos = start; pos < std::min(end, (int)bucket.size()); pos += step) {
res.push_back(bucket[pos]);
}
return res;
Expand All @@ -52,15 +52,14 @@ GraphNode *GraphShard::find_node(uint64_t id) {
}

// GraphTable::load: dispatches on the first character of `param`.
//   'e' -> load_edges(path, reverse_edge), with reverse_edge = (param[1] == '<')
//   'n' -> load_nodes(path, node_type), with node_type = param.substr(1)
// Rendered diff without +/- markers: the `if (load_node)` line and the
// `node_type` line each appear twice (pre-change / post-change formatting).
// NOTE(review): no return statement is visible for the case where param[0] is
// neither 'e' nor 'n'; confirm the fall-through path returns a value.
int32_t GraphTable::load(const std::string &path, const std::string &param) {

bool load_edge = (param[0] == 'e');
bool load_node = (param[0] == 'n');
if (load_edge) {
bool reverse_edge = (param[1] == '<');
return this->load_edges(path, reverse_edge);
}
if (load_node){
std::string node_type = param.substr(1);
if (load_node) {
std::string node_type = param.substr(1);
return this->load_nodes(path, node_type);
}
}
Expand Down Expand Up @@ -125,18 +124,17 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {

std::string nt = values[0];
if (nt != node_type) {
continue;
continue;
}
std::vector<std::string> feature;
for (size_t slice = 2; slice < values.size(); slice++) {
feature.push_back(values[slice]);
}
size_t index = shard_id - shard_start;
if(feature.size() > 0) {
shards[index].add_node(id, paddle::string::join_strings(feature, '\t'));
}
else {
shards[index].add_node(id, std::string(""));
if (feature.size() > 0) {
shards[index].add_node(id, paddle::string::join_strings(feature, '\t'));
} else {
shards[index].add_node(id, std::string(""));
}
}
}
Expand Down Expand Up @@ -188,7 +186,8 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
for (auto &shard : shards) {
auto bucket = shard.get_bucket();
for (int i = 0; i < bucket.size(); i++) {
bucket[i]->build_sampler(sample_type); }
bucket[i]->build_sampler(sample_type);
}
}
return 0;
}
Expand Down Expand Up @@ -315,37 +314,27 @@ int GraphTable::random_sample_neighboors(
}
// GraphTable::pull_graph_list: selects up to total_size nodes across all
// shards, starting at global position `start`, and schedules one get_batch
// task per shard that contributes nodes. The part of the function that
// consumes the task results is elided in this view.
// Rendered diff without +/- markers: the signature tail and the `for` header
// appear twice, and the if/else block calling the two-argument get_batch
// appears to be the pre-change body that the count/end/step block replaces.
int32_t GraphTable::pull_graph_list(int start, int total_size,
std::unique_ptr<char[]> &buffer,
int &actual_size, bool need_feature) {
int &actual_size, bool need_feature,
int step) {
if (start < 0) start = 0;
// `size` accumulates the sizes of the shards already passed, i.e. the global
// position at which shard i begins.
int size = 0, cur_size;
if (total_size <= 0) {
actual_size = 0;
return 0;
}
std::vector<std::future<std::vector<GraphNode *>>> tasks;
for (size_t i = 0; i < shards.size(); i++) {
for (size_t i = 0; i < shards.size() && total_size > 0; i++) {
cur_size = shards[i].get_size();
// Shard i ends at or before `start`: nothing to take from it.
if (size + cur_size <= start) {
size += cur_size;
continue;
}
if (size + cur_size - start >= total_size) {
tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
[this, i, start, size, total_size]() -> std::vector<GraphNode *> {
return this->shards[i].get_batch(start - size, total_size);
}));
break;
} else {
tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
[this, i, start, size, total_size,
cur_size]() -> std::vector<GraphNode *> {
return this->shards[i].get_batch(start - size,
size + cur_size - start);
}));
total_size -= size + cur_size - start;
size += cur_size;
start = size;
}
// count = number of positions start, start+step, ... that fall before the end
// of shard i (size + cur_size), capped at the remaining total_size.
// NOTE(review): this divides by `step`. The server handler on this page takes
// step from the request, so step == 0 would divide by zero; confirm it is
// validated upstream.
int count = std::min(1 + (size + cur_size - start - 1) / step, total_size);
// end = one past the last selected global position.
int end = start + (count - 1) * step + 1;
tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
[this, i, start, end, step, size]() -> std::vector<GraphNode *> {

// Subtracting `size` converts global positions to shard-local ones.
return this->shards[i].get_batch(start - size, end - size, step);
}));
// Advance to the next position to sample and account for what was taken.
start += count * step;
total_size -= count;
size += cur_size;
}
for (size_t i = 0; i < tasks.size(); ++i) {
tasks[i].wait();
Expand Down
5 changes: 3 additions & 2 deletions paddle/fluid/distributed/table/common_graph_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class GraphShard {
// bucket.resize(bucket_size);
}
std::vector<GraphNode *> &get_bucket() { return bucket; }
std::vector<GraphNode *> get_batch(int start, int total_size);
std::vector<GraphNode *> get_batch(int start, int end, int step);
// int init_bucket_size(int shard_num) {
// for (int i = bucket_low_bound;; i++) {
// if (gcd(i, shard_num) == 1) return i;
Expand Down Expand Up @@ -78,7 +78,8 @@ class GraphTable : public SparseTable {
virtual ~GraphTable() {}
virtual int32_t pull_graph_list(int start, int size,
std::unique_ptr<char[]> &buffer,
int &actual_size, bool need_feature);
int &actual_size, bool need_feature,
int step);

virtual int32_t random_sample_neighboors(
uint64_t *node_ids, int sample_size,
Expand Down
25 changes: 13 additions & 12 deletions paddle/fluid/distributed/table/graph_edge.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,33 +13,34 @@
// limitations under the License.

#pragma once
#include <vector>
#include <cstddef>
#include <cstdint>
#include <vector>
namespace paddle {
namespace distributed {


// GraphEdgeBlob: edge container that stores one uint64_t id per edge in
// id_arr. get_weight() in this class returns 1 for every index; add_edge is
// declared here and defined outside this header.
// Rendered diff without +/- markers: access specifiers and accessors appear
// twice. The post-change versions drop the `const` qualifier from the
// by-value return types (size_t, uint64_t, float).
class GraphEdgeBlob {
public:
public:
GraphEdgeBlob() {}
virtual ~GraphEdgeBlob() {}
const size_t size() {return id_arr.size();}
// Number of stored edge ids.
size_t size() { return id_arr.size(); }
virtual void add_edge(uint64_t id, float weight);
const uint64_t get_id(int idx) { return id_arr[idx]; }
virtual const float get_weight(int idx) { return 1; }
protected:
// Unchecked index into id_arr.
uint64_t get_id(int idx) { return id_arr[idx]; }
virtual float get_weight(int idx) { return 1; }

protected:
std::vector<uint64_t> id_arr;
};

// WeightedGraphEdgeBlob: GraphEdgeBlob variant that keeps a float weight per
// edge in weight_arr and returns it from get_weight(). add_edge is declared
// here and defined outside this header.
// Rendered diff without +/- markers: the class header, access specifiers and
// get_weight appear twice (pre-change / post-change formatting; the
// post-change get_weight drops `const` from the return type).
class WeightedGraphEdgeBlob: public GraphEdgeBlob{
public:
class WeightedGraphEdgeBlob : public GraphEdgeBlob {
public:
WeightedGraphEdgeBlob() {}
virtual ~WeightedGraphEdgeBlob() {}
virtual void add_edge(uint64_t id, float weight);
virtual const float get_weight(int idx) { return weight_arr[idx]; }
protected:
// Unchecked index into weight_arr.
virtual float get_weight(int idx) { return weight_arr[idx]; }

protected:
std::vector<float> weight_arr;
};

}
}
3 changes: 2 additions & 1 deletion paddle/fluid/distributed/table/table.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ class Table {
// only for graph table
// Default implementation: returns 0 without writing `buffer` or
// `actual_size`.
// The two `int &actual_size, ...` lines appear to be the pre-change and
// post-change sides of the diff; the post-change version appends `step`
// (default 1) after need_feature.
virtual int32_t pull_graph_list(int start, int total_size,
std::unique_ptr<char[]> &buffer,
int &actual_size, bool need_feature) {
int &actual_size, bool need_feature,
int step = 1) {
return 0;
}
// only for graph table
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/distributed/test/graph_node_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -335,12 +335,12 @@ void RunBrpcPushSparse() {
ASSERT_EQ(0, vs[0].size());

std::vector<distributed::GraphNode> nodes;
pull_status = worker_ptr_->pull_graph_list(0, 0, 0, 1, nodes);
pull_status = worker_ptr_->pull_graph_list(0, 0, 0, 1, 1, nodes);
pull_status.wait();
ASSERT_EQ(nodes.size(), 1);
ASSERT_EQ(nodes[0].get_id(), 37);
nodes.clear();
pull_status = worker_ptr_->pull_graph_list(0, 0, 1, 4, nodes);
pull_status = worker_ptr_->pull_graph_list(0, 0, 1, 4, 1, nodes);
pull_status.wait();
ASSERT_EQ(nodes.size(), 1);
ASSERT_EQ(nodes[0].get_id(), 59);
Expand Down Expand Up @@ -373,7 +373,7 @@ void RunBrpcPushSparse() {
// client2.load_edge_file(std::string("user2item"), std::string(file_name),
// 0);
nodes.clear();
nodes = client1.pull_graph_list(std::string("user"), 0, 1, 4);
nodes = client1.pull_graph_list(std::string("user"), 0, 1, 4, 1);

for (auto g : nodes) {
std::cout << "node_ids: " << g.get_id() << std::endl;
Expand Down