Skip to content

Commit 4ac369a

Browse files
committed
simplify the logic. Reduce the number of args.
1 parent 0a52ad2 commit 4ac369a

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

python/dgl/graphbolt/impl/fused_csc_sampling_graph.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -797,7 +797,7 @@ def sample_neighbors(
797797
returning_indices_and_original_edge_ids_are_optional: bool
798798
Boolean indicating whether it is okay for the call to this function
799799
to leave the indices and the original edge ids tensors
800-
uninitialized. In this case, it is the user's responsibility to
800+
uninitialized. In this case, it is the user's responsibility to
801801
gather them using _edge_ids_in_fused_csc_sampling_graph if either is
802802
missing.
803803
async_op: bool
@@ -1035,7 +1035,7 @@ def sample_layer_neighbors(
10351035
returning_indices_and_original_edge_ids_are_optional: bool
10361036
Boolean indicating whether it is okay for the call to this function
10371037
to leave the indices and the original edge ids tensors
1038-
uninitialized. In this case, it is the user's responsibility to
1038+
uninitialized. In this case, it is the user's responsibility to
10391039
gather them using _edge_ids_in_fused_csc_sampling_graph if either is
10401040
missing.
10411041
random_seed: torch.Tensor, optional

python/dgl/graphbolt/impl/neighbor_sampler.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,13 @@ def __init__(
272272
if (
273273
overlap_fetch
274274
and sampler.__name__ == "sample_neighbors"
275-
and (graph.indices.is_pinned() or (original_edge_ids is not None and original_edge_ids.is_pinned()))
275+
and (
276+
graph.indices.is_pinned()
277+
or (
278+
original_edge_ids is not None
279+
and original_edge_ids.is_pinned()
280+
)
281+
)
276282
and graph._gpu_graph_cache is None
277283
):
278284
datapipe = datapipe.transform(self._sample_per_layer)

0 commit comments

Comments
 (0)