Skip to content

Commit 0f3604b

Browse files
committed
[GraphBolt][CUDA] Overlap original edge ids fetch.
1 parent 2521081 commit 0f3604b

File tree

2 files changed

+102
-16
lines changed

2 files changed

+102
-16
lines changed

python/dgl/graphbolt/impl/fused_csc_sampling_graph.py

+57-13
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,32 @@
2323

2424

2525
class _SampleNeighborsWaiter:
26-
def __init__(self, fn, future, seed_offsets):
26+
def __init__(
27+
self, fn, future, seed_offsets, fetching_original_edge_ids_is_optional
28+
):
2729
self.fn = fn
2830
self.future = future
2931
self.seed_offsets = seed_offsets
32+
self.fetching_original_edge_ids_is_optional = (
33+
fetching_original_edge_ids_is_optional
34+
)
3035

3136
def wait(self):
3237
"""Returns the stored value when invoked."""
3338
fn = self.fn
3439
C_sampled_subgraph = self.future.wait()
3540
seed_offsets = self.seed_offsets
41+
fetching_original_edge_ids_is_optional = (
42+
self.fetching_original_edge_ids_is_optional
43+
)
3644
# Ensure there is no memory leak.
3745
self.fn = self.future = self.seed_offsets = None
38-
return fn(C_sampled_subgraph, seed_offsets)
46+
self.fetching_original_edge_ids_is_optional = None
47+
return fn(
48+
C_sampled_subgraph,
49+
seed_offsets,
50+
fetching_original_edge_ids_is_optional,
51+
)
3952

4053

4154
class FusedCSCSamplingGraph(SamplingGraph):
@@ -592,6 +605,7 @@ def _convert_to_sampled_subgraph(
592605
self,
593606
C_sampled_subgraph: torch.ScriptObject,
594607
seed_offsets: Optional[list] = None,
608+
fetching_original_edge_ids_is_optional: bool = False,
595609
) -> SampledSubgraphImpl:
596610
"""An internal function used to convert a fused homogeneous sampled
597611
subgraph to general struct 'SampledSubgraphImpl'."""
@@ -611,18 +625,24 @@ def _convert_to_sampled_subgraph(
611625
and ORIGINAL_EDGE_ID in self.edge_attributes
612626
)
613627
original_edge_ids = (
614-
torch.ops.graphbolt.index_select(
615-
self.edge_attributes[ORIGINAL_EDGE_ID],
616-
edge_ids_in_fused_csc_sampling_graph,
628+
(
629+
torch.ops.graphbolt.index_select(
630+
self.edge_attributes[ORIGINAL_EDGE_ID],
631+
edge_ids_in_fused_csc_sampling_graph,
632+
)
633+
if not fetching_original_edge_ids_is_optional
634+
or not edge_ids_in_fused_csc_sampling_graph.is_cuda
635+
or not self.edge_attributes[ORIGINAL_EDGE_ID].is_pinned()
636+
else None
617637
)
618638
if has_original_eids
619639
else edge_ids_in_fused_csc_sampling_graph
620640
)
621641
if type_per_edge is None and etype_offsets is None:
622642
# The sampled graph is already a homogeneous graph.
623643
sampled_csc = CSCFormatBase(indptr=indptr, indices=indices)
624-
if indices is not None:
625-
# Only needed to fetch indices.
644+
if indices is not None and original_edge_ids is not None:
645+
# Only needed to fetch indices or original_edge_ids.
626646
edge_ids_in_fused_csc_sampling_graph = None
627647
else:
628648
offset = self._node_type_offset_list
@@ -691,10 +711,16 @@ def _convert_to_sampled_subgraph(
691711
]
692712
]
693713
)
694-
original_hetero_edge_ids[etype] = original_edge_ids[
695-
etype_offsets[etype_id] : etype_offsets[etype_id + 1]
696-
]
697-
if indices is None:
714+
original_hetero_edge_ids[etype] = (
715+
None
716+
if original_edge_ids is None
717+
else original_edge_ids[
718+
etype_offsets[etype_id] : etype_offsets[
719+
etype_id + 1
720+
]
721+
]
722+
)
723+
if indices is None or original_edge_ids is None:
698724
# Only needed to fetch indices.
699725
sampled_hetero_edge_ids_in_fused_csc_sampling_graph[
700726
etype
@@ -728,6 +754,7 @@ def sample_neighbors(
728754
replace: bool = False,
729755
probs_name: Optional[str] = None,
730756
returning_indices_is_optional: bool = False,
757+
fetching_original_edge_ids_is_optional: bool = False,
731758
async_op: bool = False,
732759
) -> SampledSubgraphImpl:
733760
"""Sample neighboring edges of the given nodes and return the induced
@@ -772,6 +799,11 @@ def sample_neighbors(
772799
Boolean indicating whether it is okay for the call to this function
773800
to leave the indices tensor uninitialized. In this case, it is the
774801
user's responsibility to gather it using the edge ids.
802+
fetching_original_edge_ids_is_optional: bool
803+
Boolean indicating whether it is okay for the call to this function
804+
to leave the original edge ids tensor uninitialized. In this case,
805+
it is the user's responsibility to gather it using
806+
_edge_ids_in_fused_csc_sampling_graph.
775807
async_op: bool
776808
Boolean indicating whether the call is asynchronous. If so, the
777809
result can be obtained by calling wait on the returned future.
@@ -826,10 +858,13 @@ def sample_neighbors(
826858
self._convert_to_sampled_subgraph,
827859
C_sampled_subgraph,
828860
seed_offsets,
861+
fetching_original_edge_ids_is_optional,
829862
)
830863
else:
831864
return self._convert_to_sampled_subgraph(
832-
C_sampled_subgraph, seed_offsets
865+
C_sampled_subgraph,
866+
seed_offsets,
867+
fetching_original_edge_ids_is_optional,
833868
)
834869

835870
def _check_sampler_arguments(self, nodes, fanouts, probs_or_mask):
@@ -957,6 +992,7 @@ def sample_layer_neighbors(
957992
replace: bool = False,
958993
probs_name: Optional[str] = None,
959994
returning_indices_is_optional: bool = False,
995+
fetching_original_edge_ids_is_optional: bool = False,
960996
random_seed: torch.Tensor = None,
961997
seed2_contribution: float = 0.0,
962998
async_op: bool = False,
@@ -1005,6 +1041,11 @@ def sample_layer_neighbors(
10051041
Boolean indicating whether it is okay for the call to this function
10061042
to leave the indices tensor uninitialized. In this case, it is the
10071043
user's responsibility to gather it using the edge ids.
1044+
fetching_original_edge_ids_is_optional: bool
1045+
Boolean indicating whether it is okay for the call to this function
1046+
to leave the original edge ids tensor uninitialized. In this case,
1047+
it is the user's responsibility to gather it using
1048+
_edge_ids_in_fused_csc_sampling_graph.
10081049
random_seed: torch.Tensor, optional
10091050
An int64 tensor with one or two elements.
10101051
@@ -1102,10 +1143,13 @@ def sample_layer_neighbors(
11021143
self._convert_to_sampled_subgraph,
11031144
C_sampled_subgraph,
11041145
seed_offsets,
1146+
fetching_original_edge_ids_is_optional,
11051147
)
11061148
else:
11071149
return self._convert_to_sampled_subgraph(
1108-
C_sampled_subgraph, seed_offsets
1150+
C_sampled_subgraph,
1151+
seed_offsets,
1152+
fetching_original_edge_ids_is_optional,
11091153
)
11101154

11111155
def temporal_sample_neighbors(

python/dgl/graphbolt/impl/neighbor_sampler.py

+45-3
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,21 @@ def __init__(
252252
):
253253
graph = sampler.__self__
254254
self.returning_indices_is_optional = False
255+
original_edge_ids = (
256+
None
257+
if graph.edge_attributes is None
258+
else graph.edge_attributes.get(ORIGINAL_EDGE_ID, None)
259+
)
260+
self.fetching_original_edge_ids_is_optional = (
261+
overlap_fetch
262+
and original_edge_ids is not None
263+
and original_edge_ids.is_pinned()
264+
)
265+
fetch_indices_and_original_edge_ids_fn = partial(
266+
self._fetch_indices_and_original_edge_ids,
267+
graph.indices,
268+
original_edge_ids,
269+
)
255270
if (
256271
overlap_fetch
257272
and sampler.__name__ == "sample_neighbors"
@@ -263,7 +278,7 @@ def __init__(
263278
datapipe = datapipe.buffer()
264279
datapipe = datapipe.transform(self._wait_subgraph_future)
265280
datapipe = (
266-
datapipe.transform(partial(self._fetch_indices, graph.indices))
281+
datapipe.transform(fetch_indices_and_original_edge_ids_fn)
267282
.buffer()
268283
.wait()
269284
)
@@ -285,6 +300,12 @@ def __init__(
285300
if asynchronous:
286301
datapipe = datapipe.buffer()
287302
datapipe = datapipe.transform(self._wait_subgraph_future)
303+
if self.fetching_original_edge_ids_is_optional:
304+
datapipe = (
305+
datapipe.transform(fetch_indices_and_original_edge_ids_fn)
306+
.buffer()
307+
.wait()
308+
)
288309
else:
289310
datapipe = datapipe.transform(self._sample_per_layer)
290311
if asynchronous:
@@ -310,6 +331,7 @@ def _sample_per_layer(self, minibatch):
310331
self.replace,
311332
self.prob_name,
312333
self.returning_indices_is_optional,
334+
self.fetching_original_edge_ids_is_optional,
313335
async_op=self.asynchronous,
314336
**kwargs,
315337
)
@@ -329,6 +351,8 @@ def _sample_per_layer_from_fetched_subgraph(self, minibatch):
329351
self.fanout,
330352
self.replace,
331353
self.prob_name,
354+
False,
355+
self.fetching_original_edge_ids_is_optional,
332356
async_op=self.asynchronous,
333357
**kwargs,
334358
)
@@ -341,7 +365,7 @@ def _wait_subgraph_future(minibatch):
341365
return minibatch
342366

343367
@staticmethod
344-
def _fetch_indices(indices, minibatch):
368+
def _fetch_indices_and_original_edge_ids(indices, orig_edge_ids, minibatch):
345369
stream = torch.cuda.current_stream()
346370
host_to_device_stream = get_host_to_device_uva_stream()
347371
host_to_device_stream.wait_stream(stream)
@@ -366,6 +390,13 @@ def record_stream(tensor):
366390
index_select(indices, edge_ids)
367391
)
368392
minibatch._indices_needs_offset_subtraction = True
393+
if (
394+
orig_edge_ids is not None
395+
and subgraph.original_edge_ids[etype] is None
396+
):
397+
subgraph.original_edge_ids[etype] = record_stream(
398+
index_select(orig_edge_ids, edge_ids)
399+
)
369400
elif subgraph.sampled_csc.indices is None:
370401
subgraph._edge_ids_in_fused_csc_sampling_graph.record_stream(
371402
torch.cuda.current_stream()
@@ -375,7 +406,18 @@ def record_stream(tensor):
375406
indices, subgraph._edge_ids_in_fused_csc_sampling_graph
376407
)
377408
)
378-
minibatch._indices_needs_offset_subtraction = True
409+
# homo case does not need subtraction of offsets from indices.
410+
minibatch._indices_needs_offset_subtraction = False
411+
if (
412+
orig_edge_ids is not None
413+
and subgraph.original_edge_ids is None
414+
):
415+
subgraph.original_edge_ids = record_stream(
416+
index_select(
417+
orig_edge_ids,
418+
subgraph._edge_ids_in_fused_csc_sampling_graph,
419+
)
420+
)
379421
subgraph._edge_ids_in_fused_csc_sampling_graph = None
380422
minibatch.wait = torch.cuda.current_stream().record_event().wait
381423

0 commit comments

Comments
 (0)