load parallel between nodes and edges (PaddlePaddle#37)
* load parallel between nodes and edges
miaoli06 authored Jun 23, 2022
1 parent d5011c7 commit c176c00
Showing 7 changed files with 216 additions and 76 deletions.
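Summary of the change in this file: parse_node_file and parse_edge_file no longer accumulate into shared counters through mutex-guarded reference parameters; each call now returns its own (count, valid_count) pair, and load_nodes, load_edges, and the new load_node_and_edge_file enqueue one task per file (or per edge type) on a task pool, then sum the pairs from the returned futures. The standalone sketch below illustrates only that pattern; it substitutes std::async for Paddle's task pools and a trivial parse_file stand-in for the real parsers, so all names and paths in it are illustrative.

// Minimal sketch of the parallel-load pattern (not Paddle code): each worker
// returns its local counters and the caller aggregates them, so no mutex is needed.
#include <cstdint>
#include <fstream>
#include <future>
#include <string>
#include <utility>
#include <vector>

// Stand-in for parse_node_file / parse_edge_file: returns {lines_seen, lines_kept}.
std::pair<uint64_t, uint64_t> parse_file(const std::string &path) {
  uint64_t count = 0, valid_count = 0;
  std::ifstream file(path);
  std::string line;
  while (std::getline(file, line)) {
    ++count;
    if (line.find('\t') != std::string::npos) ++valid_count;  // toy validity check
  }
  return {count, valid_count};
}

int main() {
  std::vector<std::string> paths = {"part-00000", "part-00001", "part-00002"};
  std::vector<std::future<std::pair<uint64_t, uint64_t>>> tasks;
  for (size_t i = 0; i < paths.size(); i++) {
    // One task per file, mirroring load_node_edge_task_pool->enqueue(...) in the diff.
    tasks.push_back(std::async(std::launch::async,
                               [&paths, i] { return parse_file(paths[i]); }));
  }
  uint64_t count = 0, valid_count = 0;
  for (auto &t : tasks) {
    auto res = t.get();  // aggregate in the caller, as load_nodes/load_edges now do
    count += res.first;
    valid_count += res.second;
  }
  return valid_count <= count ? 0 : 1;
}
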
262 changes: 194 additions & 68 deletions paddle/fluid/distributed/ps/table/common_graph_table.cc
@@ -18,8 +18,12 @@
#include <chrono>
#include <set>
#include <sstream>
#include <boost/algorithm/string.hpp>

#include "gflags/gflags.h"
#include "paddle/fluid/platform/timer.h"
#include "paddle/fluid/distributed/common/utils.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/string/printf.h"
@@ -1022,6 +1026,75 @@ int32_t GraphTable::Load(const std::string &path, const std::string &param) {
return 0;
}

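// Reverses an edge-type string split on "2": e.g. "a2b" -> "b2a", and a
// three-part type "a2b2c" -> "c2b2a" (only the outer node types swap).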
std::string GraphTable::get_inverse_etype(std::string &etype) {
auto etype_split = paddle::string::split_string<std::string>(etype, "2");
std::string res;
if ((int)etype_split.size() == 3) {
res = etype_split[2] + "2" + etype_split[1] + "2" + etype_split[0];
} else {
res = etype_split[1] + "2" + etype_split[0];
}
return res;
}

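// Loads all edge types and the node files concurrently: one task per edge type
// plus one task for the node path, scheduled on _shards_task_pool. A positive
// part_num limits each directory listing to its first part_num files, and
// reverse additionally loads each edge type's inverse direction.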
int32_t GraphTable::load_node_and_edge_file(std::string etype, std::string ntype, std::string epath,
std::string npath, int part_num, bool reverse) {
auto etypes = paddle::string::split_string<std::string>(etype, ",");
auto ntypes = paddle::string::split_string<std::string>(ntype, ",");
VLOG(0) << "etypes size: " << etypes.size();
VLOG(0) << "whether reverse: " << reverse;
std::string delim = ";";
size_t total_len = etypes.size() + 1; // 1 is for node

std::vector<std::future<int>> tasks;
for (size_t i = 0; i < total_len; i++) {
tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
[&, i, this]() -> int {
if (i < etypes.size()) {
std::string etype_path = epath + "/" + etypes[i];
auto etype_path_list = paddle::framework::localfs_list(etype_path);
std::string etype_path_str;
if (part_num > 0 && part_num < (int)etype_path_list.size()) {
std::vector<std::string> sub_etype_path_list(etype_path_list.begin(), etype_path_list.begin() + part_num);
etype_path_str = boost::algorithm::join(sub_etype_path_list, delim);
} else {
etype_path_str = boost::algorithm::join(etype_path_list, delim);
}
this->load_edges(etype_path_str, false, etypes[i]);
if (reverse) {
std::string r_etype = get_inverse_etype(etypes[i]);
this->load_edges(etype_path_str, true, r_etype);
}
} else {
auto npath_list = paddle::framework::localfs_list(npath);
std::string npath_str;
if (part_num > 0 && part_num < (int)npath_list.size()) {
std::vector<std::string> sub_npath_list(npath_list.begin(), npath_list.begin() + part_num);
npath_str = boost::algorithm::join(sub_npath_list, delim);
} else {
npath_str = boost::algorithm::join(npath_list, delim);
}

if (ntypes.size() == 0) {
VLOG(0) << "node_type not specified, nothing will be loaded ";
return 0;
} else {
for (size_t j = 0; j < ntypes.size(); j++) {
if (feature_to_id.find(ntypes[j]) == feature_to_id.end()) {
VLOG(0) << "node_type " << ntypes[j] << " is not defined, will not load";
return 0;
}
}
}
this->load_nodes(npath_str, "");
}
return 0;
}));
}
for (int i = 0; i < (int)tasks.size(); i++) tasks[i].get();
return 0;
}

int32_t GraphTable::get_nodes_ids_by_ranges(
int type_id, int idx, std::vector<std::pair<int, int>> ranges,
std::vector<uint64_t> &res) {
@@ -1061,19 +1134,20 @@ int32_t GraphTable::get_nodes_ids_by_ranges(
return 0;
}

int32_t GraphTable::parse_node_file(const std::string &path, const std::string &node_type, int idx, uint64_t &count, uint64_t &valid_count) {
std::pair<uint64_t, uint64_t> GraphTable::parse_node_file(const std::string &path, const std::string &node_type, int idx) {
std::ifstream file(path);
std::string line;
uint64_t local_count = 0;
uint64_t local_valid_count = 0;
while (std::getline(file, line)) {
auto values = paddle::string::split_string<std::string>(line, "\t");
if (values.size() < 2) continue;
if (values[0] != node_type) {
size_t start = line.find_first_of('\t');
if (start == std::string::npos) continue;
std::string parse_node_type = line.substr(0, start);
if (parse_node_type != node_type) {
continue;
}

auto id = std::stoull(values[1]);
size_t end = line.find_first_of('\t', start + 1);
uint64_t id = std::stoull(line.substr(start + 1, end - start - 1));
size_t shard_id = id % shard_num;
if (shard_id >= shard_end || shard_id < shard_start) {
VLOG(4) << "will not load " << id << " from " << path
@@ -1086,18 +1160,64 @@ int32_t GraphTable::parse_node_file(const std::string &path, const std::string &
auto node = feature_shards[idx][index]->add_feature_node(id, false);
if (node != NULL) {
node->set_feature_size(feat_name[idx].size());
for (size_t slice = 2; slice < values.size(); slice++) {
parse_feature(idx, values[slice], node);
while (end != std::string::npos) {
start = end;
end = line.find_first_of('\t', start + 1);
std::string tmp_str = line.substr(start + 1, end - start - 1);
parse_feature(idx, tmp_str, node);
}
}
local_valid_count++;
}
mutex_.lock();
count += local_count;
valid_count += local_valid_count;
mutex_.unlock();
VLOG(0) << "node_type[" << node_type << "] loads " << local_count << " nodes from filepath->" << path;
return 0;
return {local_count, local_valid_count};
}

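// Variant used by the parallel path: the node type is read from the first
// tab-separated column of every line and mapped through feature_to_id, so a
// single pass over the file can load all node types at once.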
std::pair<uint64_t, uint64_t> GraphTable::parse_node_file(const std::string &path) {
std::ifstream file(path);
std::string line;
uint64_t local_count = 0;
uint64_t local_valid_count = 0;
int idx = 0;

auto path_split = paddle::string::split_string<std::string>(path, "/");
auto path_name = path_split[path_split.size() - 1];

while (std::getline(file, line)) {
size_t start = line.find_first_of('\t');
if (start == std::string::npos) continue;
std::string parse_node_type = line.substr(0, start);
auto it = feature_to_id.find(parse_node_type);
if (it == feature_to_id.end()) {
VLOG(0) << parse_node_type << " type error, please check";
continue;
}
idx = it->second;
size_t end = line.find_first_of('\t', start + 1);
uint64_t id = std::stoull(line.substr(start + 1, end - start - 1));
size_t shard_id = id % shard_num;
if (shard_id >= shard_end || shard_id < shard_start) {
VLOG(4) << "will not load " << id << " from " << path
<< ", please check id distribution";
continue;
}
local_count++;

size_t index = shard_id - shard_start;
auto node = feature_shards[idx][index]->add_feature_node(id, false);
if (node != NULL) {
while (end != std::string::npos) {
start = end;
end = line.find_first_of('\t', start + 1);
std::string tmp_str = line.substr(start + 1, end - start - 1);
parse_feature(idx, tmp_str, node);
}
}

local_valid_count++;
}
VLOG(0) << local_valid_count << "/" << local_count << " nodes from filepath->" << path;
return {local_count, local_valid_count};
}

// TODO: optimize to load all node_types in one read
@@ -1106,33 +1226,40 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {
uint64_t count = 0;
uint64_t valid_count = 0;
int idx = 0;
if (node_type == "") {
VLOG(0) << "node_type not specified, loading edges to " << id_to_feature[0]
<< " part";
} else {
if (feature_to_id.find(node_type) == feature_to_id.end()) {
VLOG(0) << "node_type " << node_type
<< " is not defined, nothing will be loaded";
return 0;
}
idx = feature_to_id[node_type];
}

VLOG(0) << "Begin GraphTable::load_nodes() node_type[" << node_type << "]";
if (FLAGS_graph_load_in_parallel) {
std::vector<std::future<int>> tasks;
if (node_type == "") {
VLOG(0) << "Begin GraphTable::load_nodes(), will load all node_type once";
}
std::vector<std::future<std::pair<uint64_t, uint64_t>>> tasks;
for (size_t i = 0; i < paths.size(); i++) {
tasks.push_back(load_node_edge_task_pool->enqueue(
[&, i, idx, this]() -> int {
parse_node_file(paths[i], node_type, idx, count, valid_count);
return 0;
[&, i, this]() -> std::pair<uint64_t, uint64_t> {
return parse_node_file(paths[i]);
}));
}
for (int i = 0; i < (int)tasks.size(); i++) tasks[i].get();
for (int i = 0; i < (int)tasks.size(); i++) {
auto res = tasks[i].get();
count += res.first;
valid_count += res.second;
}
} else {
VLOG(0) << "Begin GraphTable::load_nodes() node_type[" << node_type << "]";
if (node_type == "") {
VLOG(0) << "node_type not specified, loading edges to " << id_to_feature[0]
<< " part";
} else {
if (feature_to_id.find(node_type) == feature_to_id.end()) {
VLOG(0) << "node_type " << node_type
<< " is not defined, nothing will be loaded";
return 0;
}
idx = feature_to_id[node_type];
}
for (auto path : paths) {
VLOG(2) << "Begin GraphTable::load_nodes(), path[" << path << "]";
parse_node_file(path, node_type, idx, count, valid_count);
auto res = parse_node_file(path, node_type, idx);
count += res.first;
valid_count += res.second;
}
}

@@ -1151,7 +1278,7 @@ int32_t GraphTable::build_sampler(int idx, std::string sample_type) {
return 0;
}

int32_t GraphTable::parse_edge_file(const std::string &path, int idx, bool reverse, uint64_t &count, uint64_t &valid_count) {
std::pair<uint64_t, uint64_t> GraphTable::parse_edge_file(const std::string &path, int idx, bool reverse) {
std::string sample_type = "random";
bool is_weighted = false;
std::ifstream file(path);
@@ -1164,13 +1291,13 @@ int32_t GraphTable::parse_edge_file(const std::string &path, int idx, bool rever
auto part_name_split = paddle::string::split_string<std::string>(path_split[path_split.size() - 1], "-");
part_num = std::stoull(part_name_split[part_name_split.size() - 1]);
}

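// Parse src/dst ids in place with find_first_of('\t') instead of splitting the
// line, and treat a third tab-separated column, when present, as the edge weight.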
while (std::getline(file, line)) {
auto values = paddle::string::split_string<std::string>(line, "\t");
size_t start = line.find_first_of('\t');
if (start == std::string::npos) continue;
local_count++;
if (values.size() < 2) continue;
auto src_id = std::stoull(values[0]);
auto dst_id = std::stoull(values[1]);
uint64_t src_id = std::stoull(line.substr(0, start));
uint64_t dst_id = std::stoull(line.substr(start + 1));
if (reverse) {
std::swap(src_id, dst_id);
}
@@ -1182,8 +1309,9 @@ int32_t GraphTable::parse_edge_file(const std::string &path, int idx, bool rever
}

float weight = 1;
if (values.size() == 3) {
weight = std::stof(values[2]);
size_t last = line.find_last_of('\t');
if (start != last) {
weight = std::stof(line.substr(last + 1));
sample_type = "weighted";
is_weighted = true;
}
@@ -1193,34 +1321,26 @@ int32_t GraphTable::parse_edge_file(const std::string &path, int idx, bool rever
<< ", please check id distribution";
continue;
}

size_t index = src_shard_id - shard_start;
edge_shards[idx][index]->add_graph_node(src_id)->build_edges(is_weighted);
edge_shards[idx][index]->add_neighbor(src_id, dst_id, weight);
auto node = edge_shards[idx][index]->add_graph_node(src_id);
if (node != NULL) {
node->build_edges(is_weighted);
node->add_edge(dst_id, weight);
}

local_valid_count++;

}
mutex_.lock();
count += local_count;
valid_count += local_valid_count;
#ifdef PADDLE_WITH_HETERPS
const uint64_t fixed_load_edges = 1000000;
if (count > fixed_load_edges && search_level == 2) {
dump_edges_to_ssd(idx);
VLOG(0) << "dumping edges to ssd, edge count is reset to 0";
clear_graph(idx);
count = 0;
}
#endif
mutex_.unlock();
VLOG(0) << local_count << " edges are loaded from filepath->" << path;
return 0;
return {local_count, local_valid_count};

}

int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge,
const std::string &edge_type) {
#ifdef PADDLE_WITH_HETERPS
if (search_level == 2) total_memory_cost = 0;
//const uint64_t fixed_load_edges = 1000000;
const uint64_t fixed_load_edges = 1000000;
#endif
int idx = 0;
if (edge_type == "") {
@@ -1241,18 +1361,23 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge,

VLOG(0) << "Begin GraphTable::load_edges() edge_type[" << edge_type << "]";
if (FLAGS_graph_load_in_parallel) {
std::vector<std::future<int>> tasks;
std::vector<std::future<std::pair<uint64_t, uint64_t>>> tasks;
for (int i = 0; i < (int)paths.size(); i++) {
tasks.push_back(load_node_edge_task_pool->enqueue(
[&, i, idx, this]() -> int {
parse_edge_file(paths[i], idx, reverse_edge, count, valid_count);
return 0;
tasks.push_back(load_node_edge_task_pool->enqueue(
[&, i, idx, this]() -> std::pair<uint64_t, uint64_t> {
return parse_edge_file(paths[i], idx, reverse_edge);
}));
}
for (int j = 0; j < (int)tasks.size(); j++) tasks[j].get();
for (int j = 0; j < (int)tasks.size(); j++) {
auto res = tasks[j].get();
count += res.first;
valid_count += res.second;
}
} else {
for (auto path : paths) {
parse_edge_file(path, idx, reverse_edge, count, valid_count);
auto res = parse_edge_file(path, idx, reverse_edge);
count += res.first;
valid_count += res.second;
}
}
VLOG(0) << valid_count << "/" << count << " edge_type[" << edge_type << "] edges are loaded successfully";
@@ -1581,6 +1706,7 @@ int GraphTable::parse_feature(int idx, const std::string& feat_str,
// "")
std::vector<std::string> fields =
paddle::string::split_string<std::string>(feat_str, feature_separator_);

auto it = feat_id_map[idx].find(fields[0]);
if (it != feat_id_map[idx].end()) {
int32_t id = it->second;
@@ -1604,10 +1730,10 @@
} else if (dtype == "int64") {
FeatureNode::parse_value_to_bytes<uint64_t>(fields.begin() + 1, fields.end(), fea_ptr);
return 0;
}
}
} else {
VLOG(2) << "feature_name[" << fields[0]
<< "] is not in feat_id_map, ntype_id[" << idx
VLOG(2) << "feature_name[" << fields[0]
<< "] is not in feat_id_map, ntype_id[" << idx
<< "] feat_id_map_size[" << feat_id_map.size() << "]";
}

(Diffs for the remaining changed files are not shown.)