import unittest
from collections.abc import Mapping
from typing import Literal, Optional, Union

import torch
import torch.multiprocessing as mp
from graphlearn_torch.distributed import shutdown_rpc
from graphlearn_torch.typing import reverse_edge_type
from parameterized import param, parameterized
from torch_geometric.data import Data, HeteroData


def _assert_sampled_edges_are_expected(
    datum: HeteroData,
    possible_edge_indices: dict[EdgeType, torch.Tensor],
    datum_name: str,
) -> None:
    """Assert that every edge in a sampled subgraph exists in the expected graph.

    Args:
        datum: A sampled batch produced by a neighbor loader.
        possible_edge_indices: For each edge type that may appear in ``datum``,
            a ``[2, P]`` tensor of the (global src id, global dst id) pairs that
            are allowed for that edge type.
        datum_name: Human readable label (e.g. ``"User"``) used in failure messages.

    Raises:
        AssertionError: If ``datum`` has an edge type missing from
            ``possible_edge_indices`` or contains an edge that is not one of the
            allowed pairs for its edge type.
    """
    for edge_type in datum.edge_types:
        assert (
            edge_type in possible_edge_indices
        ), f"{datum_name} HeteroData contains edge type {edge_type} that is not in the expected graph edge types: {list(possible_edge_indices.keys())}"

        # The edge index in the HeteroData object uses local node ids, so we remap it
        # to global node ids through the `.node` tensor of the src / dst node types.
        local_edge_index = datum[edge_type].edge_index
        global_edge_index = torch.stack(
            [
                datum[edge_type[0]].node[local_edge_index[0]],
                datum[edge_type[2]].node[local_edge_index[1]],
            ],
            dim=0,
        )

        # Compare every sampled edge against every allowed edge:
        # [2, E, 1] == [2, 1, P] -> [2, E, P]. Unlike a direct `==` between the two
        # edge indices, this does not depend on E being 1 or equal to P, and a batch
        # with zero sampled edges passes vacuously instead of raising.
        expected_edge_index = possible_edge_indices[edge_type]
        pairwise_matches = global_edge_index.unsqueeze(2) == expected_edge_index.unsqueeze(1)
        # An edge is valid when both of its endpoints match the same allowed column.
        edge_is_expected = pairwise_matches.all(dim=0).any(dim=1)
        assert (
            edge_is_expected.all()
        ), f"{datum_name} HeteroData contains an edge for edge type {edge_type} in {datum[edge_type].edge_index} that is not in the expected graph: {possible_edge_indices[edge_type]}"


def _run_subgraph_looks_as_expected_given_edge_direction(
    _,
    dataset: DistLinkPredictionDataset,
    possible_edge_indices: dict[EdgeType, torch.Tensor],
):
    """Check that sampled user / story subgraphs only contain expected edges.

    Intended to be launched through ``mp.spawn``, which calls the function with the
    process index as the first positional argument (unused here).

    Args:
        _: Process index supplied by ``mp.spawn``; ignored.
        dataset: The built dataset to sample from. Its ``node_ids`` must be a mapping
            keyed by node type.
        possible_edge_indices: For each edge type expected in the loader output, the
            ``[2, P]`` tensor of allowed (global src id, global dst id) pairs.

    Raises:
        AssertionError: If ``dataset.node_ids`` is not a mapping, or if any sampled
            batch contains an unexpected edge type or edge.
    """
    torch.distributed.init_process_group(
        rank=0, world_size=1, init_method=get_process_group_init_method()
    )

    assert isinstance(dataset.node_ids, Mapping)

    user_loader = DistNeighborLoader(
        dataset=dataset,
        input_nodes=(_USER, dataset.node_ids[_USER]),
        num_neighbors=[2, 2],
        pin_memory_device=torch.device("cpu"),
        batch_size=1,
    )

    story_loader = DistNeighborLoader(
        dataset=dataset,
        input_nodes=(_STORY, dataset.node_ids[_STORY]),
        num_neighbors=[2, 2],
        pin_memory_device=torch.device("cpu"),
        batch_size=1,
    )

    for user_datum, story_datum in zip(user_loader, story_loader):
        _assert_sampled_edges_are_expected(user_datum, possible_edge_indices, "User")
        _assert_sampled_edges_are_expected(story_datum, possible_edge_indices, "Story")

    shutdown_rpc()
class DistributedNeighborLoaderTest(unittest.TestCase):
    def setUp(self):
        self._master_ip_address = "localhost"
        # NOTE: the remainder of setUp and the other test methods of this class lie
        # outside the visible excerpt and are not reproduced here.

    @parameterized.expand(
        [
            param(
                "Test subgraph looks as expected given outward edge direction", "out"
            ),
            param("Test subgraph looks as expected given inward edge direction", "in"),
        ]
    )
    def test_subgraph_looks_as_expected_given_edge_direction(
        self, _, edge_direction: Literal["in", "out"]
    ):
        """Sampled subgraphs must only contain edges allowed for the edge direction.

        Builds a small two-node-type graph on a single partition, derives the edge
        types / edge indices the loader is expected to emit for ``edge_direction``,
        and verifies them in a spawned process.
        """
        # Toy graph with five users and five stories:
        #   User -> Story pairs each user i with story i.
        #   Story -> User pairs each story i with user (i + 1) % 5.
        input_edge_indices = {
            _USER_TO_STORY: torch.tensor([[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]),
            _STORY_TO_USER: torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 0]]),
        }
        node_types = (_USER, _STORY)

        # Everything lives on partition 0 since this test runs with world_size=1.
        node_partition_book = {node_type: torch.zeros(5) for node_type in node_types}
        edge_partition_book = {
            edge_type: torch.zeros(5) for edge_type in input_edge_indices
        }
        partitioned_edge_index = {
            edge_type: GraphPartitionData(edge_index=edge_index, edge_ids=None)
            for edge_type, edge_index in input_edge_indices.items()
        }
        partitioned_node_features = {
            node_type: FeaturePartitionData(
                feats=torch.zeros(5, 2), ids=torch.arange(5)
            )
            for node_type in node_types
        }

        partition_output = PartitionOutput(
            node_partition_book=node_partition_book,
            edge_partition_book=edge_partition_book,
            partitioned_edge_index=partitioned_edge_index,
            partitioned_node_features=partitioned_node_features,
            partitioned_edge_features=None,
            partitioned_positive_labels=None,
            partitioned_negative_labels=None,
            partitioned_node_labels=None,
        )

        dataset = DistLinkPredictionDataset(
            rank=0, world_size=1, edge_dir=edge_direction
        )
        dataset.build(partition_output=partition_output)

        if edge_direction != "out":
            # For the inward direction the loader output is expected to carry the same
            # edge types and edge indices as the input graph.
            # NOTE(review): the earlier wording here attributed this to GLT swapping the
            # *inward* direction, which contradicts the "out" case below -- confirm against
            # https://github.com/alibaba/graphlearn-for-pytorch/blob/main/graphlearn_torch/python/loader/transform.py#L116-L124
            possible_edge_indices = dict(input_edge_indices)
        else:
            # For the outward direction the produced HeteroData is expected to have both
            # the edge type reversed and the src / dst rows of the edge index swapped,
            # since GLT swaps the outward direction under-the-hood as a convenience for
            # message passing:
            # https://github.com/alibaba/graphlearn-for-pytorch/blob/main/graphlearn_torch/python/loader/transform.py#L116-L124
            possible_edge_indices = {
                reverse_edge_type(edge_type): edge_index.flip(0)
                for edge_type, edge_index in input_edge_indices.items()
            }

        mp.spawn(
            fn=_run_subgraph_looks_as_expected_given_edge_direction,
            args=(dataset, possible_edge_indices),
        )


if __name__ == "__main__":
    unittest.main()