Skip to content

Commit

Permalink
add FLAGS_gpugraph_load_node_list_into_hbm to decide whether load node…
Browse files Browse the repository at this point in the history
…_list into hbm (PaddlePaddle#72)

Co-authored-by: yangjunchao <[email protected]>
  • Loading branch information
chao9527 and yangjunchao authored Jul 27, 2022
1 parent e7ca266 commit 416c558
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 3 deletions.
5 changes: 4 additions & 1 deletion paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/phi/core/enforce.h"
DECLARE_bool(gpugraph_load_node_list_into_hbm);
namespace paddle {
namespace framework {
struct GpuPsNodeInfo {
Expand All @@ -31,7 +32,9 @@ struct GpuPsNodeInfo {
};

struct GpuPsCommGraph {
uint64_t *node_list; //locate on both side
uint64_t *node_list;
// when FLAGS_gpugraph_load_node_list_into_hbm is true, locate on both sides;
// else only locate on host side
int64_t node_size; // the size of node_list
GpuPsNodeInfo *node_info_list; // only locate on host side
uint64_t *neighbor_list; //locate on both side
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -435,10 +435,12 @@ void GpuPsGraphTable::build_graph_on_single_gpu(const GpuPsCommGraph& g, int i,
size_t capacity = std::max((uint64_t)1, (uint64_t)g.node_size) / load_factor_;
tables_[table_offset] = new Table(capacity);
if (g.node_size > 0) {
CUDA_CHECK(cudaMalloc((void**)&gpu_graph_list_[offset].node_list,
if (FLAGS_gpugraph_load_node_list_into_hbm) {
CUDA_CHECK(cudaMalloc((void**)&gpu_graph_list_[offset].node_list,
g.node_size * sizeof(uint64_t)));
CUDA_CHECK(cudaMemcpy(gpu_graph_list_[offset].node_list, g.node_list,
CUDA_CHECK(cudaMemcpy(gpu_graph_list_[offset].node_list, g.node_list,
g.node_size * sizeof(uint64_t), cudaMemcpyHostToDevice));
}

build_ps(i, g.node_list, (uint64_t*)(g.node_info_list),
g.node_size, 1024, 8, table_offset);
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,8 @@ std::vector<uint64_t> GraphGpuWrapper::graph_neighbor_sample(

// Queries up to `query_size` node ids starting at offset `start` for edge-type
// index `idx` on device `gpu_id`, delegating to the underlying GpuPsGraphTable.
//
// Precondition: FLAGS_gpugraph_load_node_list_into_hbm must be true — the
// device-side node_list is only allocated when that flag is set (see
// build_graph_on_single_gpu), so querying without it would read unmapped HBM.
NodeQueryResult GraphGpuWrapper::query_node_list(int gpu_id, int idx, int start,
                                                 int query_size) {
  PADDLE_ENFORCE(
      FLAGS_gpugraph_load_node_list_into_hbm,
      "FLAGS_gpugraph_load_node_list_into_hbm must be set to true "
      "before calling query_node_list");
  // static_cast instead of a C-style cast: graph_table is stored type-erased,
  // but its dynamic type is always GpuPsGraphTable here.
  return static_cast<GpuPsGraphTable *>(graph_table)
      ->query_node_list(gpu_id, idx, start, query_size);
}
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/platform/flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,9 @@ PADDLE_DEFINE_EXPORTED_uint64(
PADDLE_DEFINE_EXPORTED_int32(
gpugraph_dedup_pull_push_mode, 0,
"enable dedup keys while pull push sparse, default 0");
/**
 * GpuGraph related FLAG
 * Name: gpugraph_load_node_list_into_hbm
 * Value Range: bool, default=true
 * Note: when true, the graph node_list is also copied into device HBM
 *       (required by query_node_list); when false it stays host-only.
 */
PADDLE_DEFINE_EXPORTED_bool(
    gpugraph_load_node_list_into_hbm, true,
    "enable load_node_list_into_hbm, default true");

/**
* ProcessGroupNCCL related FLAG
Expand Down

0 comments on commit 416c558

Please sign in to comment.