23
23
24
24
25
25
class _SampleNeighborsWaiter :
26
- def __init__ (self , fn , future , seed_offsets ):
26
+ def __init__ (
27
+ self , fn , future , seed_offsets , fetching_original_edge_ids_is_optional
28
+ ):
27
29
self .fn = fn
28
30
self .future = future
29
31
self .seed_offsets = seed_offsets
32
+ self .fetching_original_edge_ids_is_optional = (
33
+ fetching_original_edge_ids_is_optional
34
+ )
30
35
31
36
def wait (self ):
32
37
"""Returns the stored value when invoked."""
33
38
fn = self .fn
34
39
C_sampled_subgraph = self .future .wait ()
35
40
seed_offsets = self .seed_offsets
41
+ fetching_original_edge_ids_is_optional = (
42
+ self .fetching_original_edge_ids_is_optional
43
+ )
36
44
# Ensure there is no memory leak.
37
45
self .fn = self .future = self .seed_offsets = None
38
- return fn (C_sampled_subgraph , seed_offsets )
46
+ self .fetching_original_edge_ids_is_optional = None
47
+ return fn (
48
+ C_sampled_subgraph ,
49
+ seed_offsets ,
50
+ fetching_original_edge_ids_is_optional ,
51
+ )
39
52
40
53
41
54
class FusedCSCSamplingGraph (SamplingGraph ):
@@ -592,6 +605,7 @@ def _convert_to_sampled_subgraph(
592
605
self ,
593
606
C_sampled_subgraph : torch .ScriptObject ,
594
607
seed_offsets : Optional [list ] = None ,
608
+ fetching_original_edge_ids_is_optional : bool = False ,
595
609
) -> SampledSubgraphImpl :
596
610
"""An internal function used to convert a fused homogeneous sampled
597
611
subgraph to general struct 'SampledSubgraphImpl'."""
@@ -611,18 +625,24 @@ def _convert_to_sampled_subgraph(
611
625
and ORIGINAL_EDGE_ID in self .edge_attributes
612
626
)
613
627
original_edge_ids = (
614
- torch .ops .graphbolt .index_select (
615
- self .edge_attributes [ORIGINAL_EDGE_ID ],
616
- edge_ids_in_fused_csc_sampling_graph ,
628
+ (
629
+ torch .ops .graphbolt .index_select (
630
+ self .edge_attributes [ORIGINAL_EDGE_ID ],
631
+ edge_ids_in_fused_csc_sampling_graph ,
632
+ )
633
+ if not fetching_original_edge_ids_is_optional
634
+ or not edge_ids_in_fused_csc_sampling_graph .is_cuda
635
+ or not self .edge_attributes [ORIGINAL_EDGE_ID ].is_pinned ()
636
+ else None
617
637
)
618
638
if has_original_eids
619
639
else edge_ids_in_fused_csc_sampling_graph
620
640
)
621
641
if type_per_edge is None and etype_offsets is None :
622
642
# The sampled graph is already a homogeneous graph.
623
643
sampled_csc = CSCFormatBase (indptr = indptr , indices = indices )
624
- if indices is not None :
625
- # Only needed to fetch indices.
644
+ if indices is not None and original_edge_ids is not None :
645
+ # Only needed to fetch indices or original_edge_ids .
626
646
edge_ids_in_fused_csc_sampling_graph = None
627
647
else :
628
648
offset = self ._node_type_offset_list
@@ -691,10 +711,16 @@ def _convert_to_sampled_subgraph(
691
711
]
692
712
]
693
713
)
694
- original_hetero_edge_ids [etype ] = original_edge_ids [
695
- etype_offsets [etype_id ] : etype_offsets [etype_id + 1 ]
696
- ]
697
- if indices is None :
714
+ original_hetero_edge_ids [etype ] = (
715
+ None
716
+ if original_edge_ids is None
717
+ else original_edge_ids [
718
+ etype_offsets [etype_id ] : etype_offsets [
719
+ etype_id + 1
720
+ ]
721
+ ]
722
+ )
723
+ if indices is None or original_edge_ids is None :
698
724
# Only needed to fetch indices.
699
725
sampled_hetero_edge_ids_in_fused_csc_sampling_graph [
700
726
etype
@@ -728,6 +754,7 @@ def sample_neighbors(
728
754
replace : bool = False ,
729
755
probs_name : Optional [str ] = None ,
730
756
returning_indices_is_optional : bool = False ,
757
+ fetching_original_edge_ids_is_optional : bool = False ,
731
758
async_op : bool = False ,
732
759
) -> SampledSubgraphImpl :
733
760
"""Sample neighboring edges of the given nodes and return the induced
@@ -772,6 +799,11 @@ def sample_neighbors(
772
799
Boolean indicating whether it is okay for the call to this function
773
800
to leave the indices tensor uninitialized. In this case, it is the
774
801
user's responsibility to gather it using the edge ids.
802
+ fetching_original_edge_ids_is_optional: bool
803
+ Boolean indicating whether it is okay for the call to this function
804
+ to leave the original edge ids tensor uninitialized. In this case,
805
+ it is the user's responsibility to gather it using
806
+ _edge_ids_in_fused_csc_sampling_graph.
775
807
async_op: bool
776
808
Boolean indicating whether the call is asynchronous. If so, the
777
809
result can be obtained by calling wait on the returned future.
@@ -826,10 +858,13 @@ def sample_neighbors(
826
858
self ._convert_to_sampled_subgraph ,
827
859
C_sampled_subgraph ,
828
860
seed_offsets ,
861
+ fetching_original_edge_ids_is_optional ,
829
862
)
830
863
else :
831
864
return self ._convert_to_sampled_subgraph (
832
- C_sampled_subgraph , seed_offsets
865
+ C_sampled_subgraph ,
866
+ seed_offsets ,
867
+ fetching_original_edge_ids_is_optional ,
833
868
)
834
869
835
870
def _check_sampler_arguments (self , nodes , fanouts , probs_or_mask ):
@@ -957,6 +992,7 @@ def sample_layer_neighbors(
957
992
replace : bool = False ,
958
993
probs_name : Optional [str ] = None ,
959
994
returning_indices_is_optional : bool = False ,
995
+ fetching_original_edge_ids_is_optional : bool = False ,
960
996
random_seed : torch .Tensor = None ,
961
997
seed2_contribution : float = 0.0 ,
962
998
async_op : bool = False ,
@@ -1005,6 +1041,11 @@ def sample_layer_neighbors(
1005
1041
Boolean indicating whether it is okay for the call to this function
1006
1042
to leave the indices tensor uninitialized. In this case, it is the
1007
1043
user's responsibility to gather it using the edge ids.
1044
+ fetching_original_edge_ids_is_optional: bool
1045
+ Boolean indicating whether it is okay for the call to this function
1046
+ to leave the original edge ids tensor uninitialized. In this case,
1047
+ it is the user's responsibility to gather it using
1048
+ _edge_ids_in_fused_csc_sampling_graph.
1008
1049
random_seed: torch.Tensor, optional
1009
1050
An int64 tensor with one or two elements.
1010
1051
@@ -1102,10 +1143,13 @@ def sample_layer_neighbors(
1102
1143
self ._convert_to_sampled_subgraph ,
1103
1144
C_sampled_subgraph ,
1104
1145
seed_offsets ,
1146
+ fetching_original_edge_ids_is_optional ,
1105
1147
)
1106
1148
else :
1107
1149
return self ._convert_to_sampled_subgraph (
1108
- C_sampled_subgraph , seed_offsets
1150
+ C_sampled_subgraph ,
1151
+ seed_offsets ,
1152
+ fetching_original_edge_ids_is_optional ,
1109
1153
)
1110
1154
1111
1155
def temporal_sample_neighbors (
0 commit comments