Skip to content

Commit 4101a29

Browse files
Rhett-Yinglijialin03
authored and committed
[DistDGL] enable exclude_edges for sample_etype_neighbors() (dmlc#7427)
1 parent 3a9b9c9 commit 4101a29

File tree

6 files changed

+137
-43
lines changed

6 files changed

+137
-43
lines changed

examples/distributed/graphsage/node_classification_unsupervised.py

+21
Original file line number | Diff line number | Diff line change
@@ -257,6 +257,21 @@ def run(args, device, data):
257257
for step, (input_nodes, pos_graph, neg_graph, blocks) in enumerate(
258258
dataloader
259259
):
260+
if args.debug:
261+
# Verify exclude_edges functionality.
262+
for block in blocks:
263+
current_eids = block.edata[dgl.EID]
264+
seed_eids = pos_graph.edata[dgl.EID]
265+
if exclude is None:
266+
assert th.any(th.isin(current_eids, seed_eids))
267+
elif exclude == "self":
268+
assert not th.any(th.isin(current_eids, seed_eids))
269+
elif exclude == "reverse_id":
270+
assert not th.any(th.isin(current_eids, seed_eids))
271+
else:
272+
raise ValueError(
273+
f"Unsupported exclude type: {exclude}"
274+
)
260275
tic_step = time.time()
261276
sample_t.append(tic_step - start)
262277

@@ -449,6 +464,12 @@ def main(args):
449464
action="store_true",
450465
help="whether to remove edges during sampling",
451466
)
467+
parser.add_argument(
468+
"--debug",
469+
default=False,
470+
action="store_true",
471+
help="whether to verify functionality of remove edges",
472+
)
452473
args = parser.parse_args()
453474
print(args)
454475
main(args)

python/dgl/dataloading/dist_dataloader.py

+3
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
"""Distributed dataloaders.
22
"""
3+
34
import inspect
45
from abc import ABC, abstractmethod, abstractproperty
56
from collections.abc import Mapping
@@ -83,6 +84,8 @@ def _find_exclude_eids(g, exclude_mode, eids, **kwargs):
8384
if exclude_mode is None:
8485
return None
8586
elif exclude_mode == "self":
87+
if isinstance(eids, Mapping):
88+
eids = {g.to_canonical_etype(k): v for k, v in eids.items()}
8689
return eids
8790
elif exclude_mode == "reverse_id":
8891
return _find_exclude_eids_with_reverse_id(

python/dgl/distributed/dist_graph.py

+13-6
Original file line number | Diff line number | Diff line change
@@ -10,12 +10,13 @@
1010

1111
from .. import backend as F, graphbolt as gb, heterograph_index
1212
from .._ffi.ndarray import empty_shared_mem
13-
from ..base import ALL, dgl_warning, DGLError, EID, ETYPE, is_all, NID
13+
from ..base import ALL, DGLError, EID, ETYPE, is_all, NID
1414
from ..convert import graph as dgl_graph, heterograph as dgl_heterograph
1515
from ..frame import infer_scheme
1616

1717
from ..heterograph import DGLGraph
1818
from ..ndarray import exist_shared_mem_array
19+
from ..sampling.utils import EidExcluder
1920
from ..transforms import compact_graphs
2021
from . import graph_services, role, rpc
2122
from .dist_tensor import DistTensor
@@ -1422,17 +1423,14 @@ def sample_neighbors(
14221423
# pylint: disable=unused-argument
14231424
"""Sample neighbors from a distributed graph."""
14241425
if len(self.etypes) > 1:
1425-
if exclude_edges is not None:
1426-
dgl_warning(
1427-
"exclude_edges is not supported for a graph with multiple edge types."
1428-
)
14291426
frontier = graph_services.sample_etype_neighbors(
14301427
self,
14311428
seed_nodes,
14321429
fanout,
14331430
replace=replace,
14341431
etype_sorted=etype_sorted,
14351432
prob=prob,
1433+
exclude_edges=None,
14361434
use_graphbolt=self._use_graphbolt,
14371435
)
14381436
else:
@@ -1442,9 +1440,18 @@ def sample_neighbors(
14421440
fanout,
14431441
replace=replace,
14441442
prob=prob,
1445-
exclude_edges=exclude_edges,
1443+
exclude_edges=None,
14461444
use_graphbolt=self._use_graphbolt,
14471445
)
1446+
# [TODO][Rui]
1447+
# For now, exclude_edges is applied after sampling. Namely, we first sample
1448+
# the neighbors and then exclude the edges before returning frontier. This
1449+
# is probably not efficient. We could try to exclude the edges during
1450+
# sampling. Or we pass exclude_edges IDs to local and remote sampling
1451+
# functions and let them handle the exclusion.
1452+
if exclude_edges is not None:
1453+
eid_excluder = EidExcluder(exclude_edges)
1454+
frontier = eid_excluder(frontier)
14481455
return frontier
14491456

14501457
def _get_ndata_names(self, ntype=None):

python/dgl/distributed/graph_services.py

+28-17
Original file line number | Diff line number | Diff line change
@@ -257,6 +257,7 @@ def _sample_etype_neighbors_dgl(
257257
fan_out,
258258
edge_dir="in",
259259
prob=None,
260+
exclude_edges=None,
260261
replace=False,
261262
etype_offset=None,
262263
etype_sorted=False,
@@ -278,9 +279,10 @@ def _sample_etype_neighbors_dgl(
278279
local_ids,
279280
etype_offset,
280281
fan_out,
281-
edge_dir,
282-
prob,
283-
replace,
282+
edge_dir=edge_dir,
283+
prob=prob,
284+
exclude_edges=exclude_edges,
285+
replace=replace,
284286
etype_sorted=etype_sorted,
285287
_dist_training=True,
286288
)
@@ -481,13 +483,15 @@ def __init__(
481483
fan_out,
482484
edge_dir="in",
483485
prob=None,
486+
exclude_edges=None,
484487
replace=False,
485488
etype_sorted=True,
486489
use_graphbolt=False,
487490
):
488491
self.seed_nodes = nodes
489492
self.edge_dir = edge_dir
490493
self.prob = prob
494+
self.exclude_edges = exclude_edges
491495
self.replace = replace
492496
self.fan_out = fan_out
493497
self.etype_sorted = etype_sorted
@@ -498,6 +502,7 @@ def __setstate__(self, state):
498502
self.seed_nodes,
499503
self.edge_dir,
500504
self.prob,
505+
self.exclude_edges,
501506
self.replace,
502507
self.fan_out,
503508
self.etype_sorted,
@@ -509,6 +514,7 @@ def __getstate__(self):
509514
self.seed_nodes,
510515
self.edge_dir,
511516
self.prob,
517+
self.exclude_edges,
512518
self.replace,
513519
self.fan_out,
514520
self.etype_sorted,
@@ -536,6 +542,7 @@ def process_request(self, server_state):
536542
self.fan_out,
537543
edge_dir=self.edge_dir,
538544
prob=probs,
545+
exclude_edges=self.exclude_edges,
539546
replace=self.replace,
540547
etype_offset=etype_offset,
541548
etype_sorted=self.etype_sorted,
@@ -801,20 +808,19 @@ def _frontier_to_heterogeneous_graph(g, frontier, gpb):
801808
src_ntype_id = g.get_ntype_id(src_ntype)
802809
dst_ntype_id = g.get_ntype_id(dst_ntype)
803810
type_idx = etype_ids == etid
804-
if F.sum(type_idx, 0) > 0:
805-
data_dict[etype] = (
806-
F.boolean_mask(src, type_idx),
807-
F.boolean_mask(dst, type_idx),
808-
)
809-
if "DGL_DIST_DEBUG" in os.environ:
810-
assert torch.all(
811-
src_ntype_id == src_ntype_ids[type_idx]
812-
), "source ntype is is not expected."
813-
assert torch.all(
814-
dst_ntype_id == dst_ntype_ids[type_idx]
815-
), "destination ntype is is not expected."
816-
if type_wise_eids is not None:
817-
edge_ids[etype] = F.boolean_mask(type_wise_eids, type_idx)
811+
data_dict[etype] = (
812+
F.boolean_mask(src, type_idx),
813+
F.boolean_mask(dst, type_idx),
814+
)
815+
if "DGL_DIST_DEBUG" in os.environ:
816+
assert torch.all(
817+
src_ntype_id == src_ntype_ids[type_idx]
818+
), "source ntype is is not expected."
819+
assert torch.all(
820+
dst_ntype_id == dst_ntype_ids[type_idx]
821+
), "destination ntype is is not expected."
822+
if type_wise_eids is not None:
823+
edge_ids[etype] = F.boolean_mask(type_wise_eids, type_idx)
818824
hg = heterograph(
819825
data_dict,
820826
{ntype: g.num_nodes(ntype) for ntype in g.ntypes},
@@ -832,6 +838,7 @@ def sample_etype_neighbors(
832838
fanout,
833839
edge_dir="in",
834840
prob=None,
841+
exclude_edges=None,
835842
replace=False,
836843
etype_sorted=True,
837844
use_graphbolt=False,
@@ -879,6 +886,8 @@ def sample_etype_neighbors(
879886
The features must be non-negative floats, and the sum of the features of
880887
inbound/outbound edges for every node must be positive (though they don't have
881888
to sum up to one). Otherwise, the result will be undefined.
889+
exclude_edges : tensor, optional
890+
The edges to exclude when sampling.
882891
replace : bool, optional
883892
If True, sample with replacement.
884893
@@ -947,6 +956,7 @@ def issue_remote_req(node_ids):
947956
fanout,
948957
edge_dir=edge_dir,
949958
prob=_prob,
959+
exclude_edges=exclude_edges,
950960
replace=replace,
951961
etype_sorted=etype_sorted,
952962
use_graphbolt=use_graphbolt,
@@ -974,6 +984,7 @@ def local_access(local_g, partition_book, local_nids):
974984
fanout,
975985
edge_dir=edge_dir,
976986
prob=_prob,
987+
exclude_edges=exclude_edges,
977988
replace=replace,
978989
etype_offset=etype_offset,
979990
etype_sorted=etype_sorted,

python/dgl/sampling/neighbor.py

+15-3
Original file line number | Diff line number | Diff line change
@@ -49,9 +49,11 @@ def _prepare_edge_arrays(g, arg):
4949
result.append(None)
5050

5151
result = [
52-
F.to_dgl_nd(F.copy_to(F.tensor([], dtype=dtype), ctx))
53-
if x is None
54-
else x
52+
(
53+
F.to_dgl_nd(F.copy_to(F.tensor([], dtype=dtype), ctx))
54+
if x is None
55+
else x
56+
)
5557
for x in result
5658
]
5759
return result
@@ -74,6 +76,7 @@ def sample_etype_neighbors(
7476
fanout,
7577
edge_dir="in",
7678
prob=None,
79+
exclude_edges=None,
7780
replace=False,
7881
copy_ndata=True,
7982
copy_edata=True,
@@ -116,6 +119,11 @@ def sample_etype_neighbors(
116119
117120
The features must be non-negative floats or boolean. Otherwise, the
118121
result will be undefined.
122+
exclude_edges: tensor or dict
123+
Edge IDs to exclude during sampling neighbors for the seed nodes.
124+
125+
This argument can take a single ID tensor or a dictionary of edge types and ID tensors.
126+
If a single tensor is given, the graph must only have one type of nodes.
119127
replace : bool, optional
120128
If True, sample with replacement.
121129
copy_ndata: bool, optional
@@ -154,6 +162,10 @@ def sample_etype_neighbors(
154162
As a result, users should avoid performing in-place operations
155163
on the node features of the new graph to avoid feature corruption.
156164
"""
165+
if exclude_edges is not None:
166+
raise DGLError(
167+
"exclude_edges is not supported for sample_etype_neighbors"
168+
)
157169
if g.device != F.cpu():
158170
raise DGLError("The graph should be in cpu.")
159171
# (BarclayII) because the homogenized graph no longer contains the *name* of edge

0 commit comments

Comments
 (0)