diff --git a/example/gsat.py b/example/gsat.py index 1a89172..6984ba6 100644 --- a/example/gsat.py +++ b/example/gsat.py @@ -41,7 +41,7 @@ def forward_pass(self, data, epoch, training): if self.learn_edge_att: if is_undirected(data.edge_index): nodesize = data.x.shape[0] - edge_att = (att + transpose(data.edge_index, att, nodesize, nodesize)[1]) / 2 + edge_att = (att + transpose(data.edge_index, att, nodesize, nodesize, coalesced=False)[1]) / 2 else: edge_att = att else: