Skip to content

Commit 3ee4ea6

Browse files
committed
record stream in both branches
1 parent 702a4bc commit 3ee4ea6

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

python/dgl/graphbolt/impl/neighbor_sampler.py

+9
Original file line number | Diff line number | Diff line change
@@ -398,6 +398,12 @@ def record_stream(tensor):
398 398   orig_edge_ids is not None
399 399   and subgraph.original_edge_ids[etype] is None
400 400   ):
    401 + edge_ids = (
    402 + subgraph._edge_ids_in_fused_csc_sampling_graph[
    403 + etype
    404 + ]
    405 + )
    406 + edge_ids.record_stream(torch.cuda.current_stream())
401 407   subgraph.original_edge_ids[etype] = record_stream(
402 408   index_select(orig_edge_ids, edge_ids)
403 409   )
@@ -418,6 +424,9 @@ def record_stream(tensor):
418 424   orig_edge_ids is not None
419 425   and subgraph.original_edge_ids is None
420 426   ):
    427 + subgraph._edge_ids_in_fused_csc_sampling_graph.record_stream(
    428 + torch.cuda.current_stream()
    429 + )
421 430   subgraph.original_edge_ids = record_stream(
422 431   index_select(
423 432   orig_edge_ids,

0 commit comments

Comments
 (0)