We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 702a4bc, commit 3ee4ea6 (Copy full SHA for 3ee4ea6)
python/dgl/graphbolt/impl/neighbor_sampler.py
@@ -398,6 +398,12 @@ def record_stream(tensor):
398
orig_edge_ids is not None
399
and subgraph.original_edge_ids[etype] is None
400
):
401
+ edge_ids = (
402
+ subgraph._edge_ids_in_fused_csc_sampling_graph[
403
+ etype
404
+ ]
405
+ )
406
+ edge_ids.record_stream(torch.cuda.current_stream())
407
subgraph.original_edge_ids[etype] = record_stream(
408
index_select(orig_edge_ids, edge_ids)
409
)
@@ -418,6 +424,9 @@ def record_stream(tensor):
418
424
419
425
and subgraph.original_edge_ids is None
420
426
427
+ subgraph._edge_ids_in_fused_csc_sampling_graph.record_stream(
428
+ torch.cuda.current_stream()
429
421
430
subgraph.original_edge_ids = record_stream(
422
431
index_select(
423
432
orig_edge_ids,
0 commit comments