-
Notifications
You must be signed in to change notification settings - Fork 8
Add test to ensure subgraph edges are correct given an edge direction #317
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,11 @@ | ||
import unittest | ||
from collections.abc import Mapping | ||
from typing import Optional, Union | ||
from typing import Literal, Optional, Union | ||
|
||
import torch | ||
import torch.multiprocessing as mp | ||
from graphlearn_torch.distributed import shutdown_rpc | ||
from graphlearn_torch.typing import reverse_edge_type | ||
from parameterized import param, parameterized | ||
from torch_geometric.data import Data, HeteroData | ||
|
||
|
@@ -549,6 +550,87 @@ def _run_cora_supervised_node_classification( | |
shutdown_rpc() | ||
|
||
|
||
def _run_subgraph_looks_as_expected_given_edge_direction( | ||
_, | ||
dataset: DistLinkPredictionDataset, | ||
possible_edge_indices: dict[EdgeType, torch.Tensor], | ||
): | ||
torch.distributed.init_process_group( | ||
rank=0, world_size=1, init_method=get_process_group_init_method() | ||
) | ||
|
||
assert isinstance(dataset.node_ids, Mapping) | ||
|
||
user_loader = DistNeighborLoader( | ||
dataset=dataset, | ||
input_nodes=(_USER, dataset.node_ids[_USER]), | ||
num_neighbors=[2, 2], | ||
pin_memory_device=torch.device("cpu"), | ||
batch_size=1, | ||
) | ||
|
||
story_loader = DistNeighborLoader( | ||
dataset=dataset, | ||
input_nodes=(_STORY, dataset.node_ids[_STORY]), | ||
num_neighbors=[2, 2], | ||
pin_memory_device=torch.device("cpu"), | ||
batch_size=1, | ||
) | ||
|
||
for user_datum, story_datum in zip(user_loader, story_loader): | ||
for edge_type in user_datum.edge_types: | ||
# First, we need to remap the edge index with local node ids in the HeteroData object to an edge index with the global node ids | ||
global_src_nodes = user_datum[edge_type[0]].node | ||
global_dst_nodes = user_datum[edge_type[2]].node | ||
global_src_edge_index = global_src_nodes[ | ||
user_datum[edge_type].edge_index[0] | ||
] | ||
global_dst_edge_index = global_dst_nodes[ | ||
user_datum[edge_type].edge_index[1] | ||
] | ||
global_edge_index = torch.stack( | ||
[global_src_edge_index, global_dst_edge_index], dim=0 | ||
) | ||
|
||
# Then, we can compare the global edge index with the possible edge indices that can exist for this edge type | ||
assert ( | ||
edge_type in possible_edge_indices | ||
), f"User HeteroData contains edge type {edge_type} that is not in the expected graph edge types: {list(possible_edge_indices.keys())}" | ||
matches = global_edge_index == possible_edge_indices[edge_type] | ||
column_matches = matches.all(dim=0) | ||
contains_column = column_matches.any() | ||
assert ( | ||
contains_column | ||
), f"User HeteroData contains an edge for edge type {edge_type} in {user_datum[edge_type].edge_index} that is not in the expected graph: {possible_edge_indices[edge_type]}" | ||
|
||
for edge_type in story_datum.edge_types: | ||
assert ( | ||
edge_type in possible_edge_indices | ||
), f"Story HeteroData contains edge type {edge_type} that is not inthe expected graph edge types: {list(possible_edge_indices.keys())}" | ||
# First, we need to remap the edge index with local node ids in the HeteroData object to an edge index with the global node ids | ||
global_src_nodes = story_datum[edge_type[0]].node | ||
global_dst_nodes = story_datum[edge_type[2]].node | ||
global_src_edge_index = global_src_nodes[ | ||
story_datum[edge_type].edge_index[0] | ||
] | ||
global_dst_edge_index = global_dst_nodes[ | ||
story_datum[edge_type].edge_index[1] | ||
] | ||
global_edge_index = torch.stack( | ||
[global_src_edge_index, global_dst_edge_index], dim=0 | ||
) | ||
|
||
# Then, we can compare the global edge index with the expected reversed edge index from the input graph | ||
matches = global_edge_index == possible_edge_indices[edge_type] | ||
column_matches = matches.all(dim=0) | ||
contains_column = column_matches.any() | ||
assert ( | ||
contains_column | ||
), f"User HeteroData contains an edge for edge type {edge_type} in {user_datum[edge_type].edge_index} that is not in the expected graph: {possible_edge_indices[edge_type]}" | ||
|
||
shutdown_rpc() | ||
|
||
|
||
class DistributedNeighborLoaderTest(unittest.TestCase): | ||
def setUp(self): | ||
self._master_ip_address = "localhost" | ||
|
@@ -1049,6 +1131,98 @@ def test_cora_supervised_node_classification(self): | |
), | ||
) | ||
|
||
@parameterized.expand( | ||
[ | ||
param( | ||
"Test subgraph looks as expected given outward edge direction", "out" | ||
), | ||
param("Test subgraph looks as expected given inward edge direction", "in"), | ||
] | ||
) | ||
def test_subgraph_looks_as_expected_given_edge_direction( | ||
self, _, edge_direction: Literal["in", "out"] | ||
): | ||
# We define the graph here so that we have edges | ||
# User -> Story | ||
# 0 -> 0 | ||
# 1 -> 1 | ||
# 2 -> 2 | ||
# 3 -> 3 | ||
# 4 -> 4 | ||
|
||
# Story -> User | ||
# 0 -> 1 | ||
# 1 -> 2 | ||
# 2 -> 3 | ||
# 3 -> 4 | ||
# 4 -> 0 | ||
|
||
user_to_story_edge_index = torch.tensor([[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]) | ||
story_to_user_edge_index = torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 0]]) | ||
Comment on lines
+1145
to
+1161
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is it possible for us to create some graphviz/etc for this graph and what is expected/etc? Can be hard to visualize just looking at COO :P There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. https://dreampuf.github.io/GraphvizOnline is a handy tool here. |
||
|
||
partition_output = PartitionOutput( | ||
node_partition_book={ | ||
_USER: torch.zeros(5), | ||
_STORY: torch.zeros(5), | ||
}, | ||
edge_partition_book={ | ||
_USER_TO_STORY: torch.zeros(5), | ||
_STORY_TO_USER: torch.zeros(5), | ||
}, | ||
partitioned_edge_index={ | ||
_USER_TO_STORY: GraphPartitionData( | ||
edge_index=user_to_story_edge_index, | ||
edge_ids=None, | ||
), | ||
_STORY_TO_USER: GraphPartitionData( | ||
edge_index=story_to_user_edge_index, | ||
edge_ids=None, | ||
), | ||
}, | ||
partitioned_node_features={ | ||
_USER: FeaturePartitionData( | ||
feats=torch.zeros(5, 2), ids=torch.arange(5) | ||
), | ||
_STORY: FeaturePartitionData( | ||
feats=torch.zeros(5, 2), ids=torch.arange(5) | ||
), | ||
}, | ||
partitioned_edge_features=None, | ||
partitioned_positive_labels=None, | ||
partitioned_negative_labels=None, | ||
partitioned_node_labels=None, | ||
) | ||
|
||
dataset = DistLinkPredictionDataset( | ||
rank=0, world_size=1, edge_dir=edge_direction | ||
) | ||
dataset.build(partition_output=partition_output) | ||
|
||
if edge_direction == "out": | ||
# If the edge direction is out, we expect the produced HeteroData object to have the edge type reversed and the | ||
# edge index tensor also swapped. This is because GLT swaps the outward direction under-the-hood as a convenience for message passing: | ||
# https://github.com/alibaba/graphlearn-for-pytorch/blob/main/graphlearn_torch/python/loader/transform.py#L116-L124 | ||
possible_edge_indices = { | ||
reverse_edge_type(_USER_TO_STORY): user_to_story_edge_index[[1, 0], :], | ||
reverse_edge_type(_STORY_TO_USER): story_to_user_edge_index[[1, 0], :], | ||
} | ||
else: | ||
# If the edge direction is in, we expect the produced HeteroData object to have the edge type and edge tensor be the same as the input | ||
# graph. This is because GLT swaps the inward direction under-the-hood as a convenience for message passing: | ||
# https://github.com/alibaba/graphlearn-for-pytorch/blob/main/graphlearn_torch/python/loader/transform.py#L116-L124 | ||
possible_edge_indices = { | ||
_USER_TO_STORY: user_to_story_edge_index, | ||
_STORY_TO_USER: story_to_user_edge_index, | ||
} | ||
|
||
mp.spawn( | ||
fn=_run_subgraph_looks_as_expected_given_edge_direction, | ||
args=( | ||
dataset, | ||
possible_edge_indices, | ||
), | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we use this car ever? Also can we document precisely what we're checking for here?