From 842fd7c3f9b9aac0a234b3bd13481a019195c529 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Tue, 9 Sep 2025 00:47:20 +0000 Subject: [PATCH 1/3] Initial commit --- .../libs/template_config_checks.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/python/gigl/src/validation_check/libs/template_config_checks.py b/python/gigl/src/validation_check/libs/template_config_checks.py index 588f92782..78ec2e3df 100644 --- a/python/gigl/src/validation_check/libs/template_config_checks.py +++ b/python/gigl/src/validation_check/libs/template_config_checks.py @@ -109,17 +109,21 @@ def check_if_task_metadata_valid( assert ( len(task_metadata_pb.supervision_edge_types) > 0 ), "Must provide at least one supervision edge type." - graph_metadata_pb_edge_types = [ - GbmlProtosTranslator.edge_type_from_EdgeTypePb(edge_type_pb=edge_type_pb) - for edge_type_pb in graph_metadata_pb.edge_types + graph_metadata_node_types = [ + GbmlProtosTranslator.node_type_from_NodeTypePb( + node_type_pb=graph_metadata_pb.node_types + ) ] for edge_type_pb in task_metadata_pb.supervision_edge_types: edge_type = GbmlProtosTranslator.edge_type_from_EdgeTypePb( edge_type_pb=edge_type_pb ) assert ( - edge_type in graph_metadata_pb_edge_types - ), f"Invalid supervision edge type: {edge_type}; not found in graphMetadata edge types {graph_metadata_pb_edge_types}." + edge_type.src_node_type in graph_metadata_node_types + ), f"Invalid supervision edge type: {edge_type}; which contains a source node type not found in graphMetadata node types: {graph_metadata_node_types}." + assert ( + edge_type.dst_node_type in graph_metadata_node_types + ), f"Invalid supervision edge type: {edge_type}; which contains a destination node type not found in graphMetadata node types: {graph_metadata_node_types}." else: raise ValueError( f"Invalid 'taskMetadata'; must be one of {[TaskMetadataType.NODE_BASED_TASK, TaskMetadataType.NODE_ANCHOR_BASED_LINK_PREDICTION_TASK]}.", From 9d0c4b5542729bcd359674d97b1b26097ae86b41 Mon Sep 17 00:00:00 2001 From: mkolodner Date: Tue, 9 Sep 2025 00:51:58 +0000 Subject: [PATCH 2/3] Update --- .../src/validation_check/libs/template_config_checks.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/python/gigl/src/validation_check/libs/template_config_checks.py b/python/gigl/src/validation_check/libs/template_config_checks.py index 78ec2e3df..7776c3ad3 100644 --- a/python/gigl/src/validation_check/libs/template_config_checks.py +++ b/python/gigl/src/validation_check/libs/template_config_checks.py @@ -109,11 +109,7 @@ def check_if_task_metadata_valid( assert ( len(task_metadata_pb.supervision_edge_types) > 0 ), "Must provide at least one supervision edge type." - graph_metadata_node_types = [ - GbmlProtosTranslator.node_type_from_NodeTypePb( - node_type_pb=graph_metadata_pb.node_types - ) - ] + graph_metadata_node_types = graph_metadata_pb.node_types for edge_type_pb in task_metadata_pb.supervision_edge_types: edge_type = GbmlProtosTranslator.edge_type_from_EdgeTypePb( edge_type_pb=edge_type_pb From 8a01adb893eae51109e7bd13a65b779c87e941ee Mon Sep 17 00:00:00 2001 From: mkolodner Date: Tue, 9 Sep 2025 01:07:03 +0000 Subject: [PATCH 3/3] update --- .../validation/task_metadata_is_valid_test.py | 71 +++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 python/tests/unit/src/validation/task_metadata_is_valid_test.py diff --git a/python/tests/unit/src/validation/task_metadata_is_valid_test.py b/python/tests/unit/src/validation/task_metadata_is_valid_test.py new file mode 100644 index 000000000..ece586beb --- /dev/null +++ b/python/tests/unit/src/validation/task_metadata_is_valid_test.py @@ -0,0 +1,71 @@ +import unittest + +from gigl.src.validation_check.libs.template_config_checks import ( + check_if_task_metadata_valid, +) +from snapchat.research.gbml import gbml_config_pb2, graph_schema_pb2 +from tests.test_assets.graph_metadata_constants import ( + DEFAULT_HOMOGENEOUS_GRAPH_METADATA_PB, + DEFAULT_HOMOGENEOUS_NODE_TYPE_STR, +) + + +class TaskMetadataIsValidTest(unittest.TestCase): + """ + Tests for the check_if_task_metadata_valid function. + Tests edge validation behavior for link prediction tasks. + """ + + def _create_link_prediction_task_config( + self, + supervision_edge_types: list[graph_schema_pb2.EdgeType], + graph_metadata: graph_schema_pb2.GraphMetadata, + ) -> gbml_config_pb2.GbmlConfig: + """Helper method to create a node-anchor-based link prediction task configuration.""" + + return gbml_config_pb2.GbmlConfig( + task_metadata=gbml_config_pb2.GbmlConfig.TaskMetadata( + node_anchor_based_link_prediction_task_metadata=gbml_config_pb2.GbmlConfig.TaskMetadata.NodeAnchorBasedLinkPredictionTaskMetadata( + supervision_edge_types=supervision_edge_types + ) + ), + graph_metadata=graph_metadata, + ) + + def test_link_prediction_task_edge_with_invalid_node_types_raises_error(self): + """Test that error is raised when supervision edge has node types not in graph metadata.""" + # Create an edge type with node types that don't exist in graph metadata + edge_with_invalid_nodes = graph_schema_pb2.EdgeType( + src_node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE_STR, # valid node type + relation="to", + dst_node_type="nonexistent_dst_node_type", # invalid destination node type + ) + config = self._create_link_prediction_task_config( + supervision_edge_types=[edge_with_invalid_nodes], + graph_metadata=DEFAULT_HOMOGENEOUS_GRAPH_METADATA_PB, + ) + + with self.assertRaises(AssertionError): + check_if_task_metadata_valid(config) + + def test_link_prediction_task_edge_not_in_graph_metadata_but_nodes_valid_passes( + self, + ): + """Test that no error is raised when edge type is not in graph metadata but node types are valid.""" + # Create an edge type with valid node types but a relation that doesn't exist in graph metadata + edge_with_new_relation = graph_schema_pb2.EdgeType( + src_node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE_STR, # Valid node type + relation="completely_new_relation", # This relation doesn't exist in graph metadata + dst_node_type=DEFAULT_HOMOGENEOUS_NODE_TYPE_STR, # Valid node type + ) + config = self._create_link_prediction_task_config( + supervision_edge_types=[edge_with_new_relation], + graph_metadata=DEFAULT_HOMOGENEOUS_GRAPH_METADATA_PB, + ) + + # This should not raise any errors + check_if_task_metadata_valid(config) + + +if __name__ == "__main__": + unittest.main()