Skip to content

Commit

Permalink
Merge pull request #10 from WeiyueSu/FeatureNode
Browse files Browse the repository at this point in the history
Feature node
  • Loading branch information
seemingwang authored Mar 23, 2021
2 parents 3e68780 + fe4afed commit e96c14b
Show file tree
Hide file tree
Showing 11 changed files with 173 additions and 82 deletions.
6 changes: 3 additions & 3 deletions paddle/fluid/distributed/service/graph_brpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ std::future<int32_t> GraphBrpcClient::random_sample_nodes(
}
std::future<int32_t> GraphBrpcClient::pull_graph_list(
uint32_t table_id, int server_index, int start, int size, int step,
std::vector<GraphNode> &res) {
std::vector<FeatureNode> &res) {
DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
Expand All @@ -190,9 +190,9 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
int index = 0;
while (index < bytes_size) {
GraphNode node;
FeatureNode node;
node.recover_from_buffer(buffer + index);
index += node.get_size(true);
index += node.get_size(false);
res.push_back(node);
}
}
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/service/graph_brpc_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class GraphBrpcClient : public BrpcPsClient {
virtual std::future<int32_t> pull_graph_list(uint32_t table_id,
int server_index, int start,
int size, int step,
std::vector<GraphNode> &res);
std::vector<FeatureNode> &res);
virtual std::future<int32_t> random_sample_nodes(uint32_t table_id,
int server_index,
int sample_size,
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/service/graph_brpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ int32_t GraphBrpcService::pull_graph_list(Table *table,
int step = *(int *)(request.params(2).c_str());
std::unique_ptr<char[]> buffer;
int actual_size;
table->pull_graph_list(start, size, buffer, actual_size, true, step);
table->pull_graph_list(start, size, buffer, actual_size, false, step);
cntl->response_attachment().append(buffer.get(), actual_size);
return 0;
}
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/distributed/service/graph_py_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -310,11 +310,11 @@ std::vector<uint64_t> GraphPyClient::random_sample_nodes(std::string name,
}
return v;
}
std::vector<GraphNode> GraphPyClient::pull_graph_list(std::string name,
std::vector<FeatureNode> GraphPyClient::pull_graph_list(std::string name,
int server_index,
int start, int size,
int step) {
std::vector<GraphNode> res;
std::vector<FeatureNode> res;
if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name];
auto status = worker_ptr->pull_graph_list(table_id, server_index, start,
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/service/graph_py_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ class GraphPyClient : public GraphPyService {
std::string name, std::vector<uint64_t> node_ids, int sample_size);
std::vector<uint64_t> random_sample_nodes(std::string name, int server_index,
int sample_size);
std::vector<GraphNode> pull_graph_list(std::string name, int server_index,
std::vector<FeatureNode> pull_graph_list(std::string name, int server_index,
int start, int size, int step = 1);
::paddle::distributed::PSParameter GetWorkerProto();

Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/service/ps_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ class PSClient {
virtual std::future<int32_t> pull_graph_list(uint32_t table_id,
int server_index, int start,
int size, int step,
std::vector<GraphNode> &res) {
std::vector<FeatureNode> &res) {
LOG(FATAL) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
Expand Down
59 changes: 34 additions & 25 deletions paddle/fluid/distributed/table/common_graph_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
namespace paddle {
namespace distributed {

std::vector<GraphNode *> GraphShard::get_batch(int start, int end, int step) {
std::vector<Node *> GraphShard::get_batch(int start, int end, int step) {
if (start < 0) start = 0;
std::vector<GraphNode *> res;
std::vector<Node *> res;
for (int pos = start; pos < std::min(end, (int)bucket.size()); pos += step) {
res.push_back(bucket[pos]);
}
Expand All @@ -34,21 +34,29 @@ std::vector<GraphNode *> GraphShard::get_batch(int start, int end, int step) {

// Number of nodes currently stored in this shard's bucket.
size_t GraphShard::get_size() {
  return bucket.size();
}

GraphNode *GraphShard::add_node(uint64_t id, std::string feature) {
if (node_location.find(id) != node_location.end())
return bucket[node_location[id]];
node_location[id] = bucket.size();
bucket.push_back(new GraphNode(id, feature));
return bucket.back();
// Inserts a GraphNode for `id` into this shard if it is not already present,
// and returns the stored node. Idempotent: a second call with the same id
// returns the existing entry without allocating.
// NOTE(review): the downcast is unchecked — if `id` was previously inserted
// via add_feature_node, the stored object is a FeatureNode, not a GraphNode;
// confirm callers never mix the two for one id.
GraphNode *GraphShard::add_graph_node(uint64_t id) {
  if (node_location.find(id) == node_location.end()) {
    node_location[id] = bucket.size();
    bucket.push_back(new GraphNode(id));
  }
  // Prefer static_cast over a C-style cast for the Node* -> GraphNode*
  // downcast; intent is explicit and greppable.
  return static_cast<GraphNode *>(bucket[node_location[id]]);
}

// Inserts a FeatureNode for `id` into this shard if it is not already
// present, and returns the stored node. Idempotent, mirroring
// add_graph_node.
// NOTE(review): the downcast is unchecked — if `id` was previously inserted
// via add_graph_node, the stored object is a GraphNode; confirm callers
// never mix the two for one id.
FeatureNode *GraphShard::add_feature_node(uint64_t id) {
  if (node_location.find(id) == node_location.end()) {
    node_location[id] = bucket.size();
    bucket.push_back(new FeatureNode(id));
  }
  // static_cast instead of a C-style cast for the checked-by-design
  // Node* -> FeatureNode* downcast.
  return static_cast<FeatureNode *>(bucket[node_location[id]]);
}

void GraphShard::add_neighboor(uint64_t id, uint64_t dst_id, float weight) {
add_node(id, std::string(""))->add_edge(dst_id, weight);
find_node(id)->add_edge(dst_id, weight);
}

GraphNode *GraphShard::find_node(uint64_t id) {
Node *GraphShard::find_node(uint64_t id) {
auto iter = node_location.find(id);
return iter == node_location.end() ? NULL : bucket[iter->second];
return iter == node_location.end() ? nullptr : bucket[iter->second];
}

int32_t GraphTable::load(const std::string &path, const std::string &param) {
Expand Down Expand Up @@ -132,9 +140,10 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {
}
size_t index = shard_id - shard_start;
if (feature.size() > 0) {
shards[index].add_node(id, paddle::string::join_strings(feature, '\t'));
// TODO add feature
shards[index].add_feature_node(id);
} else {
shards[index].add_node(id, std::string(""));
shards[index].add_feature_node(id);
}
}
}
Expand Down Expand Up @@ -175,7 +184,7 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
}

size_t index = src_shard_id - shard_start;
shards[index].add_node(src_id, std::string(""))->build_edges(is_weighted);
shards[index].add_graph_node(src_id)->build_edges(is_weighted);
shards[index].add_neighboor(src_id, dst_id, weight);
}
}
Expand All @@ -192,13 +201,13 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
return 0;
}

GraphNode *GraphTable::find_node(uint64_t id) {
Node *GraphTable::find_node(uint64_t id) {
size_t shard_id = id % shard_num;
if (shard_id >= shard_end || shard_id < shard_start) {
return NULL;
}
size_t index = shard_id - shard_start;
GraphNode *node = shards[index].find_node(id);
Node *node = shards[index].find_node(id);
return node;
}
uint32_t GraphTable::get_thread_pool_index(uint64_t node_id) {
Expand Down Expand Up @@ -282,15 +291,15 @@ int GraphTable::random_sample_neighboors(
int &actual_size = actual_sizes[idx];
tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]->enqueue(
[&]() -> int {
GraphNode *node = find_node(node_id);
Node *node = find_node(node_id);

if (node == NULL) {
actual_size = 0;
return 0;
}
std::vector<int> res = node->sample_k(sample_size);
actual_size =
res.size() * (GraphNode::id_size + GraphNode::weight_size);
res.size() * (Node::id_size + Node::weight_size);
int offset = 0;
uint64_t id;
float weight;
Expand All @@ -299,10 +308,10 @@ int GraphTable::random_sample_neighboors(
for (int &x : res) {
id = node->get_neighbor_id(x);
weight = node->get_neighbor_weight(x);
memcpy(buffer_addr + offset, &id, GraphNode::id_size);
offset += GraphNode::id_size;
memcpy(buffer_addr + offset, &weight, GraphNode::weight_size);
offset += GraphNode::weight_size;
memcpy(buffer_addr + offset, &id, Node::id_size);
offset += Node::id_size;
memcpy(buffer_addr + offset, &weight, Node::weight_size);
offset += Node::weight_size;
}
return 0;
}));
Expand All @@ -318,7 +327,7 @@ int32_t GraphTable::pull_graph_list(int start, int total_size,
int step) {
if (start < 0) start = 0;
int size = 0, cur_size;
std::vector<std::future<std::vector<GraphNode *>>> tasks;
std::vector<std::future<std::vector<Node *>>> tasks;
for (size_t i = 0; i < shards.size() && total_size > 0; i++) {
cur_size = shards[i].get_size();
if (size + cur_size <= start) {
Expand All @@ -328,7 +337,7 @@ int32_t GraphTable::pull_graph_list(int start, int total_size,
int count = std::min(1 + (size + cur_size - start - 1) / step, total_size);
int end = start + (count - 1) * step + 1;
tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
[this, i, start, end, step, size]() -> std::vector<GraphNode *> {
[this, i, start, end, step, size]() -> std::vector<Node *> {

return this->shards[i].get_batch(start - size, end - size, step);
}));
Expand All @@ -340,7 +349,7 @@ int32_t GraphTable::pull_graph_list(int start, int total_size,
tasks[i].wait();
}
size = 0;
std::vector<std::vector<GraphNode *>> res;
std::vector<std::vector<Node *>> res;
for (size_t i = 0; i < tasks.size(); i++) {
res.push_back(tasks[i].get());
for (size_t j = 0; j < res.back().size(); j++) {
Expand Down
15 changes: 8 additions & 7 deletions paddle/fluid/distributed/table/common_graph_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ class GraphShard {
// bucket_size = init_bucket_size(shard_num);
// bucket.resize(bucket_size);
}
std::vector<GraphNode *> &get_bucket() { return bucket; }
std::vector<GraphNode *> get_batch(int start, int end, int step);
std::vector<Node *> &get_bucket() { return bucket; }
std::vector<Node *> get_batch(int start, int end, int step);
// int init_bucket_size(int shard_num) {
// for (int i = bucket_low_bound;; i++) {
// if (gcd(i, shard_num) == 1) return i;
Expand All @@ -59,8 +59,9 @@ class GraphShard {
}
return res;
}
GraphNode *add_node(uint64_t id, std::string feature);
GraphNode *find_node(uint64_t id);
GraphNode *add_graph_node(uint64_t id);
FeatureNode *add_feature_node(uint64_t id);
Node *find_node(uint64_t id);
void add_neighboor(uint64_t id, uint64_t dst_id, float weight);
// std::unordered_map<uint64_t, std::list<GraphNode *>::iterator>
std::unordered_map<uint64_t, int> get_node_location() {
Expand All @@ -70,7 +71,7 @@ class GraphShard {
private:
std::unordered_map<uint64_t, int> node_location;
int shard_num;
std::vector<GraphNode *> bucket;
std::vector<Node *> bucket;
};
class GraphTable : public SparseTable {
public:
Expand Down Expand Up @@ -98,8 +99,8 @@ class GraphTable : public SparseTable {
int32_t load_edges(const std::string &path, bool reverse);

int32_t load_nodes(const std::string &path, std::string node_type);

GraphNode *find_node(uint64_t id);
Node *find_node(uint64_t id);

virtual int32_t pull_sparse(float *values, const uint64_t *keys, size_t num) {
return 0;
Expand Down
91 changes: 68 additions & 23 deletions paddle/fluid/distributed/table/graph_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,37 @@ GraphNode::~GraphNode() {
}
}

int GraphNode::weight_size = sizeof(float);
int GraphNode::id_size = sizeof(uint64_t);
int GraphNode::int_size = sizeof(int);
int GraphNode::get_size(bool need_feature) {
return id_size + int_size + (need_feature ? feature.size() : 0);
int Node::weight_size = sizeof(float);
int Node::id_size = sizeof(uint64_t);
int Node::int_size = sizeof(int);

// Serialized footprint of a base Node: the id plus the feat_num slot.
// Base nodes never carry a feature payload, so `need_feature` does not
// affect the result here (FeatureNode overrides this).
int Node::get_size(bool need_feature) {
  (void)need_feature;  // intentionally unused for the featureless base node
  return id_size + int_size;
}

// Serializes this node into `buffer`.
// Layout: [id : id_size bytes][feat_num : int] — a base Node always writes
// feat_num == 0, regardless of `need_feature`.
void Node::to_buffer(char *buffer, bool need_feature) {
  char *cursor = buffer;
  memcpy(cursor, &id, id_size);
  cursor += id_size;

  const int feat_num = 0;  // base nodes serialize zero features
  memcpy(cursor, &feat_num, sizeof(int));
}

// Restores only the node id from a serialized buffer. The trailing feat_num
// field written by Node::to_buffer (always 0 for a base Node) is not read;
// callers advance through the buffer using get_size().
void Node::recover_from_buffer(char* buffer) {
memcpy(&id, buffer, id_size);
}

// Bytes required to serialize this node: the fixed header (id + feat_num),
// plus — when `need_feature` is set — one int length prefix and the raw
// bytes of every feature string.
int FeatureNode::get_size(bool need_feature) {
  int total = id_size + int_size;  // header: id, feat_num
  if (need_feature) {
    for (const std::string &fea : feature) {
      total += int_size;                      // per-feature length prefix
      total += static_cast<int>(fea.size());  // per-feature payload
    }
  }
  return total;
}

void GraphNode::build_edges(bool is_weighted) {
if (edges == nullptr){
if (is_weighted == true){
Expand All @@ -52,28 +77,48 @@ void GraphNode::build_sampler(std::string sample_type) {
}
sampler->build(edges);
}
void GraphNode::to_buffer(char* buffer, bool need_feature) {
int size = get_size(need_feature);
memcpy(buffer, &size, int_size);
void FeatureNode::to_buffer(char* buffer, bool need_feature) {
memcpy(buffer, &id, id_size);
buffer += id_size;

int feat_num = 0;
int feat_len;
if (need_feature) {
memcpy(buffer + int_size, feature.c_str(), feature.size());
memcpy(buffer + int_size + feature.size(), &id, id_size);
feat_num += feature.size();
memcpy(buffer, &feat_num, sizeof(int));
buffer += sizeof(int);
for (int i = 0; i < feat_num; ++i){
feat_len = feature[i].size();
memcpy(buffer, &feat_len, sizeof(int));
buffer += sizeof(int);
memcpy(buffer, feature[i].c_str(), feature[i].size());
buffer += feature[i].size();
}
} else {
memcpy(buffer + int_size, &id, id_size);
memcpy(buffer, &feat_num, sizeof(int));
}
}
void GraphNode::recover_from_buffer(char* buffer) {
int size;
memcpy(&size, buffer, int_size);
int feature_size = size - id_size - int_size;
char str[feature_size + 1];
memcpy(str, buffer + int_size, feature_size);
str[feature_size] = '\0';
feature = str;
memcpy(&id, buffer + int_size + feature_size, id_size);
// int int_state;
// memcpy(&int_state, buffer + int_size + feature_size + id_size, enum_size);
// type = GraphNodeType(int_state);
// Rebuilds the node from a buffer produced by FeatureNode::to_buffer.
// Layout: [id][feat_num][feat_len_0][bytes_0]...[feat_len_n-1][bytes_n-1].
// Any previously held features are discarded.
void FeatureNode::recover_from_buffer(char* buffer) {
  int feat_num = 0;
  int feat_len = 0;

  memcpy(&id, buffer, id_size);
  buffer += id_size;

  memcpy(&feat_num, buffer, sizeof(int));
  buffer += sizeof(int);

  feature.clear();
  feature.reserve(feat_num);  // one allocation for the vector up front
  for (int i = 0; i < feat_num; ++i) {
    memcpy(&feat_len, buffer, sizeof(int));
    buffer += sizeof(int);

    // Construct the string directly from the buffer bytes. The original
    // used a variable-length array (`char str[feat_len + 1]`), which is a
    // non-standard C++ extension and an extra copy; std::string's
    // (const char*, size_t) constructor handles embedded data safely
    // without needing a NUL terminator.
    feature.emplace_back(buffer, feat_len);
    buffer += feat_len;
  }
}
}
}
Loading

0 comments on commit e96c14b

Please sign in to comment.