Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 175 additions & 1 deletion python/tests/unit/distributed/distributed_neighborloader_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import unittest
from collections.abc import Mapping
from typing import Optional, Union
from typing import Literal, Optional, Union

import torch
import torch.multiprocessing as mp
from graphlearn_torch.distributed import shutdown_rpc
from graphlearn_torch.typing import reverse_edge_type
from parameterized import param, parameterized
from torch_geometric.data import Data, HeteroData

Expand Down Expand Up @@ -549,6 +550,87 @@ def _run_cora_supervised_node_classification(
shutdown_rpc()


def _run_subgraph_looks_as_expected_given_edge_direction(
_,
dataset: DistLinkPredictionDataset,
possible_edge_indices: dict[EdgeType, torch.Tensor],
):
torch.distributed.init_process_group(
rank=0, world_size=1, init_method=get_process_group_init_method()
)

assert isinstance(dataset.node_ids, Mapping)

user_loader = DistNeighborLoader(
dataset=dataset,
input_nodes=(_USER, dataset.node_ids[_USER]),
num_neighbors=[2, 2],
pin_memory_device=torch.device("cpu"),
batch_size=1,
)

story_loader = DistNeighborLoader(
dataset=dataset,
input_nodes=(_STORY, dataset.node_ids[_STORY]),
num_neighbors=[2, 2],
pin_memory_device=torch.device("cpu"),
batch_size=1,
)

for user_datum, story_datum in zip(user_loader, story_loader):
for edge_type in user_datum.edge_types:
# First, we need to remap the edge index with local node ids in the HeteroData object to an edge index with the global node ids
global_src_nodes = user_datum[edge_type[0]].node
global_dst_nodes = user_datum[edge_type[2]].node
global_src_edge_index = global_src_nodes[
user_datum[edge_type].edge_index[0]
]
global_dst_edge_index = global_dst_nodes[
user_datum[edge_type].edge_index[1]
]
global_edge_index = torch.stack(
[global_src_edge_index, global_dst_edge_index], dim=0
)

# Then, we can compare the global edge index with the possible edge indices that can exist for this edge type
assert (
edge_type in possible_edge_indices
), f"User HeteroData contains edge type {edge_type} that is not in the expected graph edge types: {list(possible_edge_indices.keys())}"
matches = global_edge_index == possible_edge_indices[edge_type]
column_matches = matches.all(dim=0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we use this car ever? Also can we document precisely what we're checking for here?

contains_column = column_matches.any()
assert (
contains_column
), f"User HeteroData contains an edge for edge type {edge_type} in {user_datum[edge_type].edge_index} that is not in the expected graph: {possible_edge_indices[edge_type]}"

for edge_type in story_datum.edge_types:
assert (
edge_type in possible_edge_indices
), f"Story HeteroData contains edge type {edge_type} that is not inthe expected graph edge types: {list(possible_edge_indices.keys())}"
# First, we need to remap the edge index with local node ids in the HeteroData object to an edge index with the global node ids
global_src_nodes = story_datum[edge_type[0]].node
global_dst_nodes = story_datum[edge_type[2]].node
global_src_edge_index = global_src_nodes[
story_datum[edge_type].edge_index[0]
]
global_dst_edge_index = global_dst_nodes[
story_datum[edge_type].edge_index[1]
]
global_edge_index = torch.stack(
[global_src_edge_index, global_dst_edge_index], dim=0
)

# Then, we can compare the global edge index with the expected reversed edge index from the input graph
matches = global_edge_index == possible_edge_indices[edge_type]
column_matches = matches.all(dim=0)
contains_column = column_matches.any()
assert (
contains_column
), f"User HeteroData contains an edge for edge type {edge_type} in {user_datum[edge_type].edge_index} that is not in the expected graph: {possible_edge_indices[edge_type]}"

shutdown_rpc()


class DistributedNeighborLoaderTest(unittest.TestCase):
def setUp(self):
self._master_ip_address = "localhost"
Expand Down Expand Up @@ -1049,6 +1131,98 @@ def test_cora_supervised_node_classification(self):
),
)

@parameterized.expand(
[
param(
"Test subgraph looks as expected given outward edge direction", "out"
),
param("Test subgraph looks as expected given inward edge direction", "in"),
]
)
def test_subgraph_looks_as_expected_given_edge_direction(
self, _, edge_direction: Literal["in", "out"]
):
# We define the graph here so that we have edges
# User -> Story
# 0 -> 0
# 1 -> 1
# 2 -> 2
# 3 -> 3
# 4 -> 4

# Story -> User
# 0 -> 1
# 1 -> 2
# 2 -> 3
# 3 -> 4
# 4 -> 0

user_to_story_edge_index = torch.tensor([[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]])
story_to_user_edge_index = torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 0]])
Comment on lines +1145 to +1161
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it possible for us to create some graphviz/etc for this graph and what is expected/etc?

Can be hard to visualize just looking at COO :P

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


partition_output = PartitionOutput(
node_partition_book={
_USER: torch.zeros(5),
_STORY: torch.zeros(5),
},
edge_partition_book={
_USER_TO_STORY: torch.zeros(5),
_STORY_TO_USER: torch.zeros(5),
},
partitioned_edge_index={
_USER_TO_STORY: GraphPartitionData(
edge_index=user_to_story_edge_index,
edge_ids=None,
),
_STORY_TO_USER: GraphPartitionData(
edge_index=story_to_user_edge_index,
edge_ids=None,
),
},
partitioned_node_features={
_USER: FeaturePartitionData(
feats=torch.zeros(5, 2), ids=torch.arange(5)
),
_STORY: FeaturePartitionData(
feats=torch.zeros(5, 2), ids=torch.arange(5)
),
},
partitioned_edge_features=None,
partitioned_positive_labels=None,
partitioned_negative_labels=None,
partitioned_node_labels=None,
)

dataset = DistLinkPredictionDataset(
rank=0, world_size=1, edge_dir=edge_direction
)
dataset.build(partition_output=partition_output)

if edge_direction == "out":
# If the edge direction is out, we expect the produced HeteroData object to have the edge type reversed and the
# edge index tensor also swapped. This is because GLT swaps the outward direction under-the-hood as a convenience for message passing:
# https://github.com/alibaba/graphlearn-for-pytorch/blob/main/graphlearn_torch/python/loader/transform.py#L116-L124
possible_edge_indices = {
reverse_edge_type(_USER_TO_STORY): user_to_story_edge_index[[1, 0], :],
reverse_edge_type(_STORY_TO_USER): story_to_user_edge_index[[1, 0], :],
}
else:
# If the edge direction is in, we expect the produced HeteroData object to have the edge type and edge tensor be the same as the input
# graph. This is because GLT swaps the inward direction under-the-hood as a convenience for message passing:
# https://github.com/alibaba/graphlearn-for-pytorch/blob/main/graphlearn_torch/python/loader/transform.py#L116-L124
possible_edge_indices = {
_USER_TO_STORY: user_to_story_edge_index,
_STORY_TO_USER: story_to_user_edge_index,
}

mp.spawn(
fn=_run_subgraph_looks_as_expected_given_edge_direction,
args=(
dataset,
possible_edge_indices,
),
)


if __name__ == "__main__":
unittest.main()