diff --git a/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h b/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h index 1b996a9b9359b..92e9b33b7f5c6 100644 --- a/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h +++ b/paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h @@ -21,6 +21,7 @@ #include "paddle/fluid/memory/memory.h" #include "paddle/fluid/platform/cuda_device_guard.h" #include "paddle/phi/core/enforce.h" +DECLARE_bool(gpugraph_load_node_list_into_hbm); namespace paddle { namespace framework { struct GpuPsNodeInfo { @@ -31,7 +32,9 @@ struct GpuPsNodeInfo { }; struct GpuPsCommGraph { - uint64_t *node_list; //locate on both side + uint64_t *node_list; + // when FLAGS_gpugraph_load_node_list_into_hbm is true locate on both side + // else only locate on host side int64_t node_size; // the size of node_list GpuPsNodeInfo *node_info_list; // only locate on host side uint64_t *neighbor_list; //locate on both side diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu index d1e854fb16898..61ab2037673f7 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu @@ -435,10 +435,12 @@ void GpuPsGraphTable::build_graph_on_single_gpu(const GpuPsCommGraph& g, int i, size_t capacity = std::max((uint64_t)1, (uint64_t)g.node_size) / load_factor_; tables_[table_offset] = new Table(capacity); if (g.node_size > 0) { - CUDA_CHECK(cudaMalloc((void**)&gpu_graph_list_[offset].node_list, + if (FLAGS_gpugraph_load_node_list_into_hbm) { + CUDA_CHECK(cudaMalloc((void**)&gpu_graph_list_[offset].node_list, g.node_size * sizeof(uint64_t))); - CUDA_CHECK(cudaMemcpy(gpu_graph_list_[offset].node_list, g.node_list, + CUDA_CHECK(cudaMemcpy(gpu_graph_list_[offset].node_list, g.node_list, g.node_size * sizeof(uint64_t), cudaMemcpyHostToDevice)); + } build_ps(i, g.node_list, 
(uint64_t*)(g.node_info_list), g.node_size, 1024, 8, table_offset); diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu index 841bcddb9e1fd..01dc83c323040 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu @@ -337,6 +337,8 @@ std::vector GraphGpuWrapper::graph_neighbor_sample( NodeQueryResult GraphGpuWrapper::query_node_list(int gpu_id, int idx, int start, int query_size) { + PADDLE_ENFORCE(FLAGS_gpugraph_load_node_list_into_hbm == true, + "when use query_node_list should set gpugraph_load_node_list_into_hbm true"); return ((GpuPsGraphTable *)graph_table) ->query_node_list(gpu_id, idx, start, query_size); } diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc index c544a426d11f3..f6bb96e2f62dd 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -869,6 +869,9 @@ PADDLE_DEFINE_EXPORTED_uint64( PADDLE_DEFINE_EXPORTED_int32( gpugraph_dedup_pull_push_mode, 0, "enable dedup keys while pull push sparse, default 0"); +PADDLE_DEFINE_EXPORTED_bool( + gpugraph_load_node_list_into_hbm, true, + "enable load_node_list_into_hbm, default true"); /** * ProcessGroupNCCL related FLAG