
Commit 220140f

mfbalin authored and lijialin03 committed
[GraphBolt][CUDA] Refactor overlap_graph_fetch, simplify gb.DataLoader. (dmlc#7681)
1 parent 58e7f0d commit 220140f

11 files changed: +187 -237 lines
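
In short, the overlap-related options move off gb.DataLoader and onto the sampling step of the datapipe. A minimal before/after sketch of caller code, pieced together from the example changes below; graph, fanout, datapipe and the numeric values are placeholders, not part of this commit:

    # Before this commit: overlap options were gb.DataLoader arguments.
    datapipe = datapipe.sample_neighbor(graph, fanout)
    dataloader = gb.DataLoader(
        datapipe,
        num_workers=0,
        overlap_graph_fetch=True,
        num_gpu_cached_edges=1_000_000,
        gpu_cache_threshold=1,
    )

    # After this commit: pass them to the sampler instead; gb.DataLoader keeps
    # only datapipe, num_workers, persistent_workers and max_uva_threads.
    datapipe = datapipe.sample_neighbor(
        graph,
        fanout,
        overlap_fetch=True,
        num_gpu_cached_edges=1_000_000,
        gpu_cache_threshold=1,
    )
    dataloader = gb.DataLoader(datapipe, num_workers=0)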


examples/graphbolt/disk_based_feature/node_classification.py (+5 -6)

@@ -115,7 +115,10 @@ def create_dataloader(
         else {}
     )
     datapipe = getattr(datapipe, args.sample_mode)(
-        graph, fanout if job != "infer" else [-1], **kwargs
+        graph,
+        fanout if job != "infer" else [-1],
+        overlap_fetch=args.overlap_graph_fetch,
+        **kwargs,
     )
     # Copy the data to the specified device.
     if args.feature_device != "cpu":
@@ -130,11 +133,7 @@ def create_dataloader(
     if args.feature_device == "cpu":
         datapipe = datapipe.copy_to(device=device)
     # Create and return a DataLoader to handle data loading.
-    return gb.DataLoader(
-        datapipe,
-        num_workers=args.num_workers,
-        overlap_graph_fetch=args.overlap_graph_fetch,
-    )
+    return gb.DataLoader(datapipe, num_workers=args.num_workers)


 def train_step(minibatch, optimizer, model, loss_fn):

examples/graphbolt/node_classification.py (+4 -6)

@@ -117,7 +117,9 @@ def create_dataloader(
     # Initialize a neighbor sampler for sampling the neighborhoods of nodes.
     ############################################################################
     datapipe = getattr(datapipe, args.sample_mode)(
-        graph, fanout if job != "infer" else [-1]
+        graph,
+        fanout if job != "infer" else [-1],
+        overlap_fetch=args.storage_device == "pinned",
     )

     ############################################################################
@@ -156,11 +158,7 @@ def create_dataloader(
     # [Role]:
     # Initialize a multi-process dataloader to load the data in parallel.
     ############################################################################
-    dataloader = gb.DataLoader(
-        datapipe,
-        num_workers=num_workers,
-        overlap_graph_fetch=args.storage_device == "pinned",
-    )
+    dataloader = gb.DataLoader(datapipe, num_workers=num_workers)

     # Return the fully-initialized DataLoader object.
     return dataloader

examples/graphbolt/pyg/labor/node_classification.py (+5 -6)

@@ -147,7 +147,10 @@ def create_dataloader(
         else {}
     )
     datapipe = getattr(datapipe, args.sample_mode)(
-        graph, fanout if job != "infer" else [-1], **kwargs
+        graph,
+        fanout if job != "infer" else [-1],
+        overlap_fetch=args.overlap_graph_fetch,
+        **kwargs,
     )
     # Copy the data to the specified device.
     if args.feature_device != "cpu" and need_copy:
@@ -163,11 +166,7 @@ def create_dataloader(
     if need_copy:
         datapipe = datapipe.copy_to(device=device)
     # Create and return a DataLoader to handle data loading.
-    return gb.DataLoader(
-        datapipe,
-        num_workers=args.num_workers,
-        overlap_graph_fetch=args.overlap_graph_fetch,
-    )
+    return gb.DataLoader(datapipe, num_workers=args.num_workers)


 @torch.compile

examples/graphbolt/pyg/node_classification_advanced.py (+6 -8)

@@ -195,7 +195,11 @@ def create_dataloader(
     need_copy = False
     # Sample neighbors for each node in the mini-batch.
     datapipe = getattr(datapipe, args.sample_mode)(
-        graph, fanout if job != "infer" else [-1]
+        graph,
+        fanout if job != "infer" else [-1],
+        overlap_fetch=args.overlap_graph_fetch,
+        num_gpu_cached_edges=args.num_gpu_cached_edges,
+        gpu_cache_threshold=args.gpu_graph_caching_threshold,
     )
     # Copy the data to the specified device.
     if args.feature_device != "cpu" and need_copy:
@@ -211,13 +215,7 @@ def create_dataloader(
     if need_copy:
         datapipe = datapipe.copy_to(device=device)
     # Create and return a DataLoader to handle data loading.
-    return gb.DataLoader(
-        datapipe,
-        num_workers=args.num_workers,
-        overlap_graph_fetch=args.overlap_graph_fetch,
-        num_gpu_cached_edges=args.num_gpu_cached_edges,
-        gpu_cache_threshold=args.gpu_graph_caching_threshold,
-    )
+    return gb.DataLoader(datapipe, num_workers=args.num_workers)


 @torch.compile

examples/graphbolt/rgcn/hetero_rgcn.py (+4 -6)

@@ -124,7 +124,9 @@ def create_dataloader(
     # The graph(FusedCSCSamplingGraph) from which to sample neighbors.
     # `fanouts`:
     # The number of neighbors to sample for each node in each layer.
-    datapipe = datapipe.sample_neighbor(graph, fanouts=fanouts)
+    datapipe = datapipe.sample_neighbor(
+        graph, fanouts=fanouts, overlap_fetch=args.overlap_graph_fetch
+    )

     # Fetch the features for each node in the mini-batch.
     # `features`:
@@ -141,11 +143,7 @@ def create_dataloader(
     # Create a DataLoader from the datapipe.
     # `num_workers`:
     # The number of worker processes to use for data loading.
-    return gb.DataLoader(
-        datapipe,
-        num_workers=num_workers,
-        overlap_graph_fetch=args.overlap_graph_fetch,
-    )
+    return gb.DataLoader(datapipe, num_workers=num_workers)


 def extract_embed(node_embed, input_nodes):

examples/multigpu/graphbolt/node_classification.py (+4 -6)

@@ -134,16 +134,14 @@ def create_dataloader(
     ############################################################################
     if args.storage_device != "cpu":
         datapipe = datapipe.copy_to(device)
-    datapipe = datapipe.sample_neighbor(graph, args.fanout)
+    datapipe = datapipe.sample_neighbor(
+        graph, args.fanout, overlap_fetch=args.storage_device == "pinned"
+    )
     datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
     if args.storage_device == "cpu":
         datapipe = datapipe.copy_to(device)

-    dataloader = gb.DataLoader(
-        datapipe,
-        args.num_workers,
-        overlap_graph_fetch=args.storage_device == "pinned",
-    )
+    dataloader = gb.DataLoader(datapipe, args.num_workers)

     # Return the fully-initialized DataLoader object.
     return dataloader

python/dgl/graphbolt/dataloader.py (+4 -87)

@@ -1,13 +1,10 @@
 """Graph Bolt DataLoaders"""

-from collections import OrderedDict
-
 import torch
 import torch.utils.data as torch_data

-from .base import CopyTo, get_host_to_device_uva_stream
+from .base import CopyTo
 from .feature_fetcher import FeatureFetcher, FeatureFetcherStartMarker
-from .impl.gpu_graph_cache import GPUGraphCache
 from .impl.neighbor_sampler import SamplePerLayer

 from .internal import (
@@ -22,34 +19,9 @@

 __all__ = [
     "DataLoader",
-    "construct_gpu_graph_cache",
 ]


-def construct_gpu_graph_cache(
-    sample_per_layer_obj, num_gpu_cached_edges, gpu_cache_threshold
-):
-    "Construct a GPUGraphCache given a sample_per_layer_obj and cache parameters."
-    graph = sample_per_layer_obj.sampler.__self__
-    num_gpu_cached_edges = min(num_gpu_cached_edges, graph.total_num_edges)
-    dtypes = OrderedDict()
-    dtypes["indices"] = graph.indices.dtype
-    if graph.type_per_edge is not None:
-        dtypes["type_per_edge"] = graph.type_per_edge.dtype
-    if graph.edge_attributes is not None:
-        probs_or_mask = graph.edge_attributes.get(
-            sample_per_layer_obj.prob_name, None
-        )
-        if probs_or_mask is not None:
-            dtypes["probs_or_mask"] = probs_or_mask.dtype
-    return GPUGraphCache(
-        num_gpu_cached_edges,
-        gpu_cache_threshold,
-        graph.csc_indptr.dtype,
-        list(dtypes.values()),
-    )
-
-
 def _find_and_wrap_parent(datapipe_graph, target_datapipe, wrapper, **kwargs):
     """Find parent of target_datapipe and wrap it with ."""
     datapipes = find_dps(
@@ -125,18 +97,6 @@ class DataLoader(torch_data.DataLoader):
         If True, the data loader will not shut down the worker processes after a
         dataset has been consumed once. This allows to maintain the workers
         instances alive.
-    overlap_graph_fetch : bool, optional
-        If True, the data loader will overlap the UVA graph fetching operations
-        with the rest of operations by using an alternative CUDA stream. This
-        option should be enabled if you have moved your graph to the pinned
-        memory for optimal performance. Default is False.
-    num_gpu_cached_edges : int, optional
-        If positive and overlap_graph_fetch is True, then the GPU will cache
-        frequently accessed vertex neighborhoods to reduce the PCI-e bandwidth
-        demand due to pinned graph accesses.
-    gpu_cache_threshold : int, optional
-        Determines how many times a vertex needs to be accessed before its
-        neighborhood ends up being cached on the GPU.
     max_uva_threads : int, optional
         Limits the number of CUDA threads used for UVA copies so that the rest
         of the computations can run simultaneously with it. Setting it to a too
@@ -150,9 +110,6 @@ def __init__(
         datapipe,
         num_workers=0,
         persistent_workers=True,
-        overlap_graph_fetch=False,
-        num_gpu_cached_edges=0,
-        gpu_cache_threshold=1,
         max_uva_threads=10240,
     ):
         # Multiprocessing requires two modifications to the datapipe:
@@ -200,54 +157,14 @@ def __init__(
         if feature_fetcher.max_num_stages > 0:  # Overlap enabled.
             torch.ops.graphbolt.set_max_uva_threads(max_uva_threads)

-        if (
-            overlap_graph_fetch
-            and num_workers == 0
-            and torch.cuda.is_available()
-        ):
-            torch.ops.graphbolt.set_max_uva_threads(max_uva_threads)
+        if num_workers == 0 and torch.cuda.is_available():
             samplers = find_dps(
                 datapipe_graph,
                 SamplePerLayer,
             )
-            gpu_graph_cache = None
             for sampler in samplers:
-                if num_gpu_cached_edges > 0 and gpu_graph_cache is None:
-                    gpu_graph_cache = construct_gpu_graph_cache(
-                        sampler, num_gpu_cached_edges, gpu_cache_threshold
-                    )
-                if (
-                    sampler.sampler.__name__ == "sample_layer_neighbors"
-                    or gpu_graph_cache is not None
-                ):
-                    # This code path is not faster for sample_neighbors.
-                    datapipe_graph = replace_dp(
-                        datapipe_graph,
-                        sampler,
-                        sampler.fetch_and_sample(
-                            gpu_graph_cache,
-                            get_host_to_device_uva_stream(),
-                            1,
-                        ),
-                    )
-                elif sampler.sampler.__name__ == "sample_neighbors":
-                    # This code path is faster for sample_neighbors.
-                    datapipe_graph = replace_dp(
-                        datapipe_graph,
-                        sampler,
-                        sampler.datapipe.sample_per_layer(
-                            sampler=sampler.sampler,
-                            fanout=sampler.fanout,
-                            replace=sampler.replace,
-                            prob_name=sampler.prob_name,
-                            returning_indices_is_optional=True,
-                        ),
-                    )
-                else:
-                    raise AssertionError(
-                        "overlap_graph_fetch is supported only for "
-                        "sample_neighbor and sample_layer_neighbor."
-                    )
+                if sampler.overlap_fetch:
+                    torch.ops.graphbolt.set_max_uva_threads(max_uva_threads)

         # (4) Cut datapipe at CopyTo and wrap with pinning and prefetching
         # before it. This enables enables non_blocking copies to the device.
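
With the sampler rewiring gone, the DataLoader's only remaining GPU-side duty here is raising the UVA thread limit: it scans the datapipe graph for SamplePerLayer nodes and calls set_max_uva_threads when one of them was built with overlap_fetch=True. A hedged usage sketch of the surviving knob; the thread value is illustrative, not from this commit:

    # Assumes the sampler in `datapipe` was created with overlap_fetch=True and
    # that loading runs in-process: the overlap branch in DataLoader.__init__
    # only triggers when num_workers == 0 and CUDA is available.
    dataloader = gb.DataLoader(
        datapipe,
        num_workers=0,
        max_uva_threads=6144,  # cap CUDA threads used for UVA copies
    )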

python/dgl/graphbolt/impl/fused_csc_sampling_graph.py (+31)

@@ -10,6 +10,7 @@
 from ..base import etype_str_to_tuple, etype_tuple_to_str, ORIGINAL_EDGE_ID
 from ..internal_utils import gb_warning, is_wsl, recursive_apply
 from ..sampling_graph import SamplingGraph
+from .gpu_graph_cache import GPUGraphCache
 from .sampled_subgraph_impl import CSCFormatBase, SampledSubgraphImpl


@@ -315,6 +316,14 @@ def _indptr_node_type_offset_list(
         """Sets the indptr node type offset list if present."""
         self._indptr_node_type_offset_list_ = indptr_node_type_offset_list

+    @property
+    def _gpu_graph_cache(self) -> Optional[GPUGraphCache]:
+        return (
+            self._gpu_graph_cache_
+            if hasattr(self, "_gpu_graph_cache_")
+            else None
+        )
+
     @property
     def type_per_edge(self) -> Optional[torch.Tensor]:
         """Returns the edge type tensor if present.
@@ -1432,6 +1441,28 @@ def _pin(x):

         return self._apply_to_members(_pin)

+    def _initialize_gpu_graph_cache(
+        self,
+        num_gpu_cached_edges: int,
+        gpu_cache_threshold: int,
+        prob_name: Optional[str] = None,
+    ):
+        "Construct a GPUGraphCache given the cache parameters."
+        num_gpu_cached_edges = min(num_gpu_cached_edges, self.total_num_edges)
+        dtypes = [self.indices.dtype]
+        if self.type_per_edge is not None:
+            dtypes.append(self.type_per_edge.dtype)
+        if self.edge_attributes is not None:
+            probs_or_mask = self.edge_attributes.get(prob_name, None)
+            if probs_or_mask is not None:
+                dtypes.append(probs_or_mask.dtype)
+        self._gpu_graph_cache_ = GPUGraphCache(
+            num_gpu_cached_edges,
+            gpu_cache_threshold,
+            self.csc_indptr.dtype,
+            dtypes,
+        )
+

 def fused_csc_sampling_graph(
     csc_indptr: torch.Tensor,