@@ -257,6 +257,7 @@ def _sample_etype_neighbors_dgl(
257
257
fan_out ,
258
258
edge_dir = "in" ,
259
259
prob = None ,
260
+ exclude_edges = None ,
260
261
replace = False ,
261
262
etype_offset = None ,
262
263
etype_sorted = False ,
@@ -278,9 +279,10 @@ def _sample_etype_neighbors_dgl(
278
279
local_ids ,
279
280
etype_offset ,
280
281
fan_out ,
281
- edge_dir ,
282
- prob ,
283
- replace ,
282
+ edge_dir = edge_dir ,
283
+ prob = prob ,
284
+ exclude_edges = exclude_edges ,
285
+ replace = replace ,
284
286
etype_sorted = etype_sorted ,
285
287
_dist_training = True ,
286
288
)
@@ -481,13 +483,15 @@ def __init__(
481
483
fan_out ,
482
484
edge_dir = "in" ,
483
485
prob = None ,
486
+ exclude_edges = None ,
484
487
replace = False ,
485
488
etype_sorted = True ,
486
489
use_graphbolt = False ,
487
490
):
488
491
self .seed_nodes = nodes
489
492
self .edge_dir = edge_dir
490
493
self .prob = prob
494
+ self .exclude_edges = exclude_edges
491
495
self .replace = replace
492
496
self .fan_out = fan_out
493
497
self .etype_sorted = etype_sorted
@@ -498,6 +502,7 @@ def __setstate__(self, state):
498
502
self .seed_nodes ,
499
503
self .edge_dir ,
500
504
self .prob ,
505
+ self .exclude_edges ,
501
506
self .replace ,
502
507
self .fan_out ,
503
508
self .etype_sorted ,
@@ -509,6 +514,7 @@ def __getstate__(self):
509
514
self .seed_nodes ,
510
515
self .edge_dir ,
511
516
self .prob ,
517
+ self .exclude_edges ,
512
518
self .replace ,
513
519
self .fan_out ,
514
520
self .etype_sorted ,
@@ -536,6 +542,7 @@ def process_request(self, server_state):
536
542
self .fan_out ,
537
543
edge_dir = self .edge_dir ,
538
544
prob = probs ,
545
+ exclude_edges = self .exclude_edges ,
539
546
replace = self .replace ,
540
547
etype_offset = etype_offset ,
541
548
etype_sorted = self .etype_sorted ,
@@ -801,20 +808,19 @@ def _frontier_to_heterogeneous_graph(g, frontier, gpb):
801
808
src_ntype_id = g .get_ntype_id (src_ntype )
802
809
dst_ntype_id = g .get_ntype_id (dst_ntype )
803
810
type_idx = etype_ids == etid
804
- if F .sum (type_idx , 0 ) > 0 :
805
- data_dict [etype ] = (
806
- F .boolean_mask (src , type_idx ),
807
- F .boolean_mask (dst , type_idx ),
808
- )
809
- if "DGL_DIST_DEBUG" in os .environ :
810
- assert torch .all (
811
- src_ntype_id == src_ntype_ids [type_idx ]
812
- ), "source ntype is is not expected."
813
- assert torch .all (
814
- dst_ntype_id == dst_ntype_ids [type_idx ]
815
- ), "destination ntype is is not expected."
816
- if type_wise_eids is not None :
817
- edge_ids [etype ] = F .boolean_mask (type_wise_eids , type_idx )
811
+ data_dict [etype ] = (
812
+ F .boolean_mask (src , type_idx ),
813
+ F .boolean_mask (dst , type_idx ),
814
+ )
815
+ if "DGL_DIST_DEBUG" in os .environ :
816
+ assert torch .all (
817
+ src_ntype_id == src_ntype_ids [type_idx ]
818
+ ), "source ntype is is not expected."
819
+ assert torch .all (
820
+ dst_ntype_id == dst_ntype_ids [type_idx ]
821
+ ), "destination ntype is is not expected."
822
+ if type_wise_eids is not None :
823
+ edge_ids [etype ] = F .boolean_mask (type_wise_eids , type_idx )
818
824
hg = heterograph (
819
825
data_dict ,
820
826
{ntype : g .num_nodes (ntype ) for ntype in g .ntypes },
@@ -832,6 +838,7 @@ def sample_etype_neighbors(
832
838
fanout ,
833
839
edge_dir = "in" ,
834
840
prob = None ,
841
+ exclude_edges = None ,
835
842
replace = False ,
836
843
etype_sorted = True ,
837
844
use_graphbolt = False ,
@@ -879,6 +886,8 @@ def sample_etype_neighbors(
879
886
The features must be non-negative floats, and the sum of the features of
880
887
inbound/outbound edges for every node must be positive (though they don't have
881
888
to sum up to one). Otherwise, the result will be undefined.
889
+ exclude_edges : tensor, optional
890
+ The edges to exclude when sampling.
882
891
replace : bool, optional
883
892
If True, sample with replacement.
884
893
@@ -947,6 +956,7 @@ def issue_remote_req(node_ids):
947
956
fanout ,
948
957
edge_dir = edge_dir ,
949
958
prob = _prob ,
959
+ exclude_edges = exclude_edges ,
950
960
replace = replace ,
951
961
etype_sorted = etype_sorted ,
952
962
use_graphbolt = use_graphbolt ,
@@ -974,6 +984,7 @@ def local_access(local_g, partition_book, local_nids):
974
984
fanout ,
975
985
edge_dir = edge_dir ,
976
986
prob = _prob ,
987
+ exclude_edges = exclude_edges ,
977
988
replace = replace ,
978
989
etype_offset = etype_offset ,
979
990
etype_sorted = etype_sorted ,
0 commit comments