Skip to content

Commit 20a35a7

Browse files
committed
fix the last bug hopefully.
1 parent 01b77dc commit 20a35a7

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

python/dgl/graphbolt/impl/cooperative_conv.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,12 @@ def forward(
4747
(sum(counts_sent.get(ntype, [0])),) + typed_tensor.shape[1:],
4848
requires_grad=typed_tensor.requires_grad,
4949
)
50+
default_splits = [0] * torch.distributed.get_world_size()
5051
all_to_all(
51-
torch.split(out, counts_sent.get(ntype, 0)),
52+
torch.split(out, counts_sent.get(ntype, default_splits)),
5253
torch.split(
5354
typed_tensor[seed_inverse_ids.get(ntype, slice(None))],
54-
counts_received.get(ntype, 0),
55+
counts_received.get(ntype, default_splits),
5556
),
5657
)
5758
outs[ntype] = out

0 commit comments

Comments
 (0)