Skip to content

Commit bec6f17

Browse files
committed
[GraphBolt][CUDA] Cooperative Minibatching hetero bug fixes.
1 parent 4a6bfa4 commit bec6f17

File tree

5 files changed: +22 -19 lines changed

python/dgl/graphbolt/feature_fetcher.py

+3 -3
Original file line number | Diff line number | Diff line change
@@ -166,13 +166,13 @@ def _cooperative_exchange(self, data):
166166
self.node_feature_keys, Dict
167167
) or isinstance(self.edge_feature_keys, Dict)
168168
if is_heterogeneous:
169-
node_features = {key: {} for key, _ in data.node_features.keys()}
170-
for (key, ntype), feature in data.node_features.items():
169+
node_features = {key: {} for _, key in data.node_features.keys()}
170+
for (ntype, key), feature in data.node_features.items():
171171
node_features[key][ntype] = feature
172172
for key, feature in node_features.items():
173173
new_feature = CooperativeConvFunction.apply(subgraph, feature)
174174
for ntype, tensor in new_feature.items():
175-
data.node_features[(key, ntype)] = tensor
175+
data.node_features[(ntype, key)] = tensor
176176
else:
177177
for key in data.node_features:
178178
feature = data.node_features[key]

python/dgl/graphbolt/impl/cooperative_conv.py

+6 -5
Original file line number | Diff line number | Diff line change
@@ -44,17 +44,18 @@ def forward(
4444
outs = {}
4545
for ntype, typed_tensor in convert_to_hetero(tensor).items():
4646
out = typed_tensor.new_empty(
47-
(sum(counts_sent[ntype]),) + typed_tensor.shape[1:]
47+
(sum(counts_sent.get(ntype, [0])),) + typed_tensor.shape[1:],
48+
requires_grad=typed_tensor.requires_grad,
4849
)
4950
all_to_all(
50-
torch.split(out, counts_sent[ntype]),
51+
torch.split(out, counts_sent.get(ntype, 0)),
5152
torch.split(
52-
typed_tensor[seed_inverse_ids[ntype]],
53-
counts_received[ntype],
53+
typed_tensor[seed_inverse_ids.get(ntype, slice(None))],
54+
counts_received.get(ntype, 0),
5455
),
5556
)
5657
outs[ntype] = out
57-
return revert_to_homo(out)
58+
return revert_to_homo(outs)
5859

5960
@staticmethod
6061
def backward(

python/dgl/graphbolt/impl/neighbor_sampler.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -561,7 +561,7 @@ def _seeds_cooperative_exchange_1(minibatch):
561561
seeds_offsets = {"_N": seeds_offsets}
562562
num_ntypes = len(seeds_offsets)
563563
counts_sent = torch.empty(world_size * num_ntypes, dtype=torch.int64)
564-
for i, offsets in enumerate(seeds_offsets.values()):
564+
for i, (_, offsets) in enumerate(sorted(seeds_offsets.items())):
565565
counts_sent[
566566
torch.arange(i, world_size * num_ntypes, num_ntypes)
567567
] = offsets.diff()
@@ -589,7 +589,7 @@ def _seeds_cooperative_exchange_2(minibatch):
589589
seeds_received = {}
590590
counts_sent = {}
591591
counts_received = {}
592-
for i, (ntype, typed_seeds) in enumerate(seeds.items()):
592+
for i, (ntype, typed_seeds) in enumerate(sorted(seeds.items())):
593593
idx = torch.arange(i, world_size * num_ntypes, num_ntypes)
594594
typed_counts_sent = subgraph._counts_sent[idx].tolist()
595595
typed_counts_received = subgraph._counts_received[idx].tolist()

python/dgl/graphbolt/subgraph_sampler.py

+4 -2
Original file line number | Diff line number | Diff line change
@@ -236,7 +236,9 @@ def _seeds_cooperative_exchange_1_wait_future(minibatch):
236236
else:
237237
minibatch._seeds_offsets = {"_N": minibatch._seeds_offsets}
238238
counts_sent = torch.empty(world_size * num_ntypes, dtype=torch.int64)
239-
for i, offsets in enumerate(minibatch._seeds_offsets.values()):
239+
for i, (_, offsets) in enumerate(
240+
sorted(minibatch._seeds_offsets.items())
241+
):
240242
counts_sent[
241243
torch.arange(i, world_size * num_ntypes, num_ntypes)
242244
] = offsets.diff()
@@ -261,7 +263,7 @@ def _seeds_cooperative_exchange_2(minibatch):
261263
seeds_received = {}
262264
counts_sent = {}
263265
counts_received = {}
264-
for i, (ntype, typed_seeds) in enumerate(seeds.items()):
266+
for i, (ntype, typed_seeds) in enumerate(sorted(seeds.items())):
265267
idx = torch.arange(i, world_size * num_ntypes, num_ntypes)
266268
typed_counts_sent = minibatch._counts_sent[idx].tolist()
267269
typed_counts_received = minibatch._counts_received[idx].tolist()

tests/python/pytorch/graphbolt/impl/test_cooperative_minibatching_utils.py

+7 -7
Original file line number | Diff line number | Diff line change
@@ -57,12 +57,12 @@ def test_rank_sort_and_unique_and_compact(dtype, rank):
5757
nodes1[off1[j] : off1[j + 1]], nodes4[off4[i] : off4[i + 1]]
5858
)
5959

60-
unique, compacted, offsets = gb.unique_and_compact(
61-
nodes_list1[:1], rank, WORLD_SIZE
62-
)
60+
nodes = {str(i): [typed_seeds] for i, typed_seeds in enumerate(nodes_list1)}
6361

64-
nodes1, idx1, offsets1 = res1[0]
62+
unique, compacted, offsets = gb.unique_and_compact(nodes, rank, WORLD_SIZE)
6563

66-
assert_equal(unique, nodes1)
67-
assert_equal(compacted[0], idx1)
68-
assert_equal(offsets, offsets1)
64+
for i in nodes.keys():
65+
nodes1, idx1, offsets1 = res1[int(i)]
66+
assert_equal(unique[i], nodes1)
67+
assert_equal(compacted[i][0], idx1)
68+
assert_equal(offsets[i], offsets1)

0 commit comments

Comments
 (0)