From fe6cb9ad7cafe18d5e0ba84ee7686379f638fbdc Mon Sep 17 00:00:00 2001 From: Mingtian Yin Date: Fri, 8 Aug 2025 12:05:49 -0700 Subject: [PATCH 1/4] Include json schema in the graph schema representation Allow graph schema to include the json fields. This is useful for QA chain to handle queries that refers to json properties. For example, for node: Company(details = Json('market_cap', ...)) graph schema will include 'market_cap' as a json_fields of `details` in the schema representation, so that when a user ask for `get market capitalization of company`, QA chain can understand which subfield to refer to. Note: - json property schema is done via by inspecting the first non-null property. - other changes: refactor the tests, improve the logging --- src/langchain_google_spanner/graph_store.py | 162 +++++- tests/integration/test_spanner_graph_store.py | 498 +++++++++--------- 2 files changed, 397 insertions(+), 263 deletions(-) diff --git a/src/langchain_google_spanner/graph_store.py b/src/langchain_google_spanner/graph_store.py index 93b42ec..d7f7571 100644 --- a/src/langchain_google_spanner/graph_store.py +++ b/src/langchain_google_spanner/graph_store.py @@ -15,6 +15,7 @@ from __future__ import annotations import json +import logging import re import string from abc import ABC, abstractmethod @@ -35,6 +36,8 @@ EDGE_KIND = "EDGE" USER_AGENT_GRAPH_STORE = "langchain-google-spanner-python:graphstore/" + __version__ +logger = logging.getLogger(__name__) + class NodeWrapper(object): """Wrapper around Node to support set operations using node id""" @@ -763,6 +766,69 @@ def __init__(self, node_name: str, node_keys: List[str], edge_keys: List[str]): self.edge_keys = edge_keys +class JsonSchema(object): + NODE_JSON_PROPERTY_QUERY_TEMPLATE = """ + GRAPH `{graph_id}` + MATCH (n:`{label_name}`) + WHERE n.`{property_name}` IS NOT NULL + LET j = n.`{property_name}` + LIMIT 1 + LET keys = JSON_KEYS(j, 1) + FOR key IN keys + LET v = j[key] + LET type = json_type(v) + RETURN key, type + """ + + EDGE_JSON_PROPERTY_QUERY_TEMPLATE = """ + GRAPH `{graph_id}` + MATCH -[n:`{label_name}`]-> + WHERE n.`{property_name}` IS NOT NULL + LET j = n.`{property_name}` + LIMIT 1 + LET keys = JSON_KEYS(j, 1) + FOR key IN keys + LET v = j[key] + let type = json_type(v) + RETURN key, type + """ + + def __init__(self, graph_name: str, impl: SpannerInterface): + self._graph_name = graph_name + self._impl = impl + + def get_node_json_property_schema(self, node_label: str, property_names: List[str]): + return self._get_label_json_property_schema( + node_label, property_names, self.NODE_JSON_PROPERTY_QUERY_TEMPLATE + ) + + def get_edge_json_property_schema(self, edge_label: str, property_names: List[str]): + return self._get_label_json_property_schema( + edge_label, property_names, self.EDGE_JSON_PROPERTY_QUERY_TEMPLATE + ) + + def _get_label_json_property_schema( + self, label: str, property_names: List[str], query_template: str + ): + if len(property_names) == 0: + return CaseInsensitiveDict({}) + return CaseInsensitiveDict( + { + pname: [ + row + for row in self._impl.query( + query_template.format( + graph_id=self._graph_name, + label_name=label, + property_name=pname, + ) + ) + ] + for pname in property_names + } + ) + + class SpannerGraphSchema(object): """Schema representation of a property graph.""" @@ -778,6 +844,7 @@ def __init__( use_flexible_schema: bool, static_node_properties: List[str] = [], static_edge_properties: List[str] = [], + json_schema: Optional[JsonSchema] = None, ): """Initializes the graph schema. @@ -805,9 +872,16 @@ def __init__( self.edge_tables: CaseInsensitiveDict[ElementSchema] = CaseInsensitiveDict({}) self.labels: CaseInsensitiveDict[Label] = CaseInsensitiveDict({}) self.properties: CaseInsensitiveDict[param_types.Type] = CaseInsensitiveDict({}) + self.node_json_property_schema: CaseInsensitiveDict[Dict] = CaseInsensitiveDict( + {} + ) + self.edge_json_property_schema: CaseInsensitiveDict[Dict] = CaseInsensitiveDict( + {} + ) self.use_flexible_schema = use_flexible_schema self.static_node_properties = set(static_node_properties) self.static_edge_properties = set(static_edge_properties) + self.json_schema = json_schema def evolve(self, graph_documents: List[GraphDocument]) -> List[str]: """Evolves current schema into a schema representing the input documents. @@ -861,11 +935,13 @@ def from_information_schema(self, info_schema: Dict[str, Any]) -> None: node_schema = ElementSchema.from_info_schema(node, decl_by_types) self._update_node_schema(node_schema) self._update_labels_and_properties(node_schema) + self._update_json_property_schema(node_schema) for edge in info_schema.get("edgeTables", []): edge_schema = ElementSchema.from_info_schema(edge, decl_by_types) self._update_edge_schema(edge_schema) self._update_labels_and_properties(edge_schema) + self._update_json_property_schema(edge_schema) def node_type_name(self, name: str) -> str: return NODE_KIND if self.use_flexible_schema else name @@ -952,26 +1028,36 @@ def __repr__(self) -> str: triplets_per_label.setdefault(label, []).append( (source_node, edge, target_node) ) + + def repr_property(lname, pname, ptype, json_fields): + if not json_fields: + return {"name": pname, "type": ptype} + return {"name": pname, "type": ptype, "json_fields": json_fields} + return json.dumps( { "Name of graph": self.graph_name, "Node properties per node label": { label: [ - { - "name": name, - "type": properties[name], - } - for name in sorted(self.labels[label].prop_names) + repr_property( + label, + pname, + properties[pname], + self.node_json_property_schema.get(label, {}).get(pname), + ) + for pname in sorted(self.labels[label].prop_names) ] for label in sorted(node_labels) }, "Edge properties per edge label": { label: [ - { - "name": name, - "type": properties[name], - } - for name in sorted(self.labels[label].prop_names) + repr_property( + label, + pname, + properties[pname], + self.edge_json_property_schema.get(label, {}).get(pname), + ) + for pname in sorted(self.labels[label].prop_names) ] for label in sorted(edge_labels) }, @@ -1124,6 +1210,37 @@ def _update_edge_schema(self, edge_schema: ElementSchema) -> List[str]: self.edges[edge_schema.name] = old_schema or edge_schema return ddls + def _update_json_property_schema(self, element_schema: ElementSchema) -> None: + if self.json_schema is None: + return + if len(element_schema.labels) == 0: + return + lname = element_schema.labels[0] + if element_schema.kind == NODE_KIND: + json_property_schema = self.json_schema.get_node_json_property_schema( + lname, + [ + pname + for pname, ptype in element_schema.types.items() + if ptype == param_types.JSON + ], + ) + self.node_json_property_schema.update( + {l: json_property_schema for l in element_schema.labels} + ) + else: + json_property_schema = self.json_schema.get_edge_json_property_schema( + lname, + [ + pname + for pname, ptype in element_schema.types.items() + if ptype == param_types.JSON + ], + ) + self.edge_json_property_schema.update( + {l: json_property_schema for l in element_schema.labels} + ) + def _update_labels_and_properties(self, element_schema: ElementSchema) -> None: """Updates labels and properties based on an element schema. @@ -1176,7 +1293,6 @@ def add_edges( """ edge_schema = self.get_edge_schema(self.edge_type_name(name)) if edge_schema is None: - print(list(self.edges.keys())) raise ValueError("Unknown edge schema `%s`" % name) for v in edge_schema.add_edges(name, edges): yield v @@ -1265,7 +1381,7 @@ def apply_ddls(self, ddls: List[str], options: Dict[str, Any] = {}) -> None: return op = self.database.update_ddl(ddl_statements=ddls) - print("Waiting for DDL operations to complete...") + logger.info("Waiting for DDL operations to complete...") return op.result(options.get("timeout", DEFAULT_DDL_TIMEOUT)) def insert_or_update( @@ -1291,6 +1407,7 @@ def __init__( static_edge_properties: List[str] = [], impl: Optional[SpannerInterface] = None, timeout: Optional[float] = None, + include_json_schema: bool = False, ): """Initializes SpannerGraphStore. @@ -1306,7 +1423,9 @@ def __init__( static_edge_properties: in flexible schema, treat these edge properties as static. timeout (Optional[float]): The timeout for queries in seconds. + include_json_schema (Optional[bool]): Whether to include json fields in the schema. """ + self.graph_name = graph_name self.impl = impl or SpannerImpl( instance_id, database_id, @@ -1318,6 +1437,9 @@ def __init__( use_flexible_schema, static_node_properties, static_edge_properties, + json_schema=( + JsonSchema(graph_name, self.impl) if include_json_schema else None + ), ) self.refresh_schema() @@ -1345,25 +1467,28 @@ def add_graph_documents( ddls = self.schema.evolve(graph_documents) if ddls: self.impl.apply_ddls(ddls) - self.refresh_schema() else: - print("No schema change required...") + logger.info("No schema change required...") nodes, edges = partition_graph_docs(graph_documents) for name, elements in nodes.items(): if len(elements) == 0: continue for table, columns, rows in self.schema.add_nodes(name, elements): - print("Insert nodes of type `{}`...".format(name)) + logger.info("Insert nodes of type `{}`...".format(name)) self.impl.insert_or_update(table, columns, rows) for name, elements in edges.items(): if len(elements) == 0: continue for table, columns, rows in self.schema.add_edges(name, elements): - print("Insert edges of type `{}`...".format(name)) + logger.info("Insert edges of type `{}`...".format(name)) self.impl.insert_or_update(table, columns, rows) + # Refresh schema after data insertion because json property is sampled + # over the actual data. + self.refresh_schema() + def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]: """Query Spanner database. @@ -1435,5 +1560,8 @@ def cleanup(self): ] ) self.schema = SpannerGraphSchema( - self.schema.graph_name, self.schema.use_flexible_schema + self.schema.graph_name, self.schema.use_flexible_schema, + self.schema.static_node_properties, + self.schema.static_edge_properties, + self.schema.json_schema ) diff --git a/tests/integration/test_spanner_graph_store.py b/tests/integration/test_spanner_graph_store.py index 271dba3..ce1328a 100644 --- a/tests/integration/test_spanner_graph_store.py +++ b/tests/integration/test_spanner_graph_store.py @@ -163,43 +163,50 @@ def random_graph_doc(suffix): ) +@pytest.fixture +def setup_graph(request): + use_flexible_schema = request.getfixturevalue("use_flexible_schema") + include_json_schema = request.getfixturevalue("include_json_schema") + suffix = random_string(num_char=5, exclude_whitespaces=True) + graph_name = "test_graph{}".format(suffix) + graph = SpannerGraphStore( + instance_id, + google_database, + graph_name, + client=Client(project=project_id), + use_flexible_schema=use_flexible_schema, + static_node_properties=["a", "b"], + static_edge_properties=["a", "b"], + include_json_schema=include_json_schema, + ) + graph.refresh_schema() + + yield suffix, graph + + print("Clean up graph with name `{}`".format(graph.graph_name)) + graph.cleanup() + + class TestSpannerGraphStore: @pytest.mark.parametrize("use_flexible_schema", [False, True]) - def test_spanner_graph_random_doc(self, use_flexible_schema): - suffix = random_string(num_char=5, exclude_whitespaces=True) - graph_name = "test_graph{}".format(suffix) - graph = SpannerGraphStore( - instance_id, - google_database, - graph_name, - client=Client(project=project_id), - use_flexible_schema=use_flexible_schema, - static_node_properties=random_property_names( - random_int(l=0, u=len(properties)) - ), - static_edge_properties=random_property_names( - random_int(l=0, u=len(properties)) - ), - ) - graph.refresh_schema() - - try: - node_ids = set() - edge_ids = set() - for _ in range(3): - graph_doc = random_graph_doc(suffix) - graph.add_graph_documents([graph_doc]) - node_ids.update({(n.type, n.id) for n in graph_doc.nodes}) - edge_ids.update( - { - (e.type, e.source.id, e.target.id) - for e in graph_doc.relationships - } - ) - graph.refresh_schema() + @pytest.mark.parametrize("include_json_schema", [False, True]) + def test_spanner_graph_random_doc( + self, setup_graph, use_flexible_schema, include_json_schema + ): + suffix, graph = setup_graph + node_ids = set() + edge_ids = set() + for _ in range(3): + graph_doc = random_graph_doc(suffix) + graph.add_graph_documents([graph_doc]) + node_ids.update({(n.type, n.id) for n in graph_doc.nodes}) + edge_ids.update( + {(e.type, e.source.id, e.target.id) for e in graph_doc.relationships} + ) + graph.refresh_schema() - results = graph.query( - """ + results = graph.query( + """ GRAPH {} MATCH -> @@ -215,66 +222,41 @@ def test_spanner_graph_random_doc(self, use_flexible_schema): RETURN type, num_elements, @param AS param ORDER BY type """.format( - graph_name - ), - params={"param": random_param()}, - ) - assert len(results) == 2 - assert results[0]["type"] == "edge", "Mismatch type" - assert results[0]["num_elements"] == len( - edge_ids - ), "Mismatch number of edges" - assert results[1]["type"] == "node", "Mismatch type" - assert results[1]["num_elements"] == len( - node_ids - ), "Mismatch number of nodes" - - finally: - print("Clean up graph with name `{}`".format(graph_name)) - print(graph.get_schema) - print(graph.get_structured_schema) - print(graph.get_ddl()) - graph.cleanup() + graph.graph_name + ), + params={"param": random_param()}, + ) + assert len(results) == 2 + assert results[0]["type"] == "edge", "Mismatch type" + assert results[0]["num_elements"] == len(edge_ids), "Mismatch number of edges" + assert results[1]["type"] == "node", "Mismatch type" + assert results[1]["num_elements"] == len(node_ids), "Mismatch number of nodes" @pytest.mark.parametrize("use_flexible_schema", [False, True]) - def test_spanner_graph_doc_with_duplicate_elements(self, use_flexible_schema): - suffix = random_string(num_char=5, exclude_whitespaces=True) - graph_name = "test_graph{}".format(suffix) - graph = SpannerGraphStore( - instance_id, - google_database, - graph_name, - client=Client(project=project_id), - use_flexible_schema=use_flexible_schema, - static_node_properties=random_property_names( - random_int(l=0, u=len(properties)) - ), - static_edge_properties=random_property_names( - random_int(l=0, u=len(properties)) + @pytest.mark.parametrize("include_json_schema", [False, True]) + def test_spanner_graph_doc_with_duplicate_elements( + self, setup_graph, use_flexible_schema, include_json_schema + ): + suffix, graph = setup_graph + node0 = random_node("Node0{}".format(suffix)) + node1 = random_node("Node1{}".format(suffix)) + edge0 = random_edge("Edge01", node0, node1) + edge1 = random_edge("Edge01", node0, node1) + + doc = GraphDocument( + nodes=[node0, node1, node0, node1], + relationships=[edge0, edge1], + source=Document( + page_content="Hello, world!", + metadata={"source": "https://example.com"}, ), ) - graph.refresh_schema() - - try: - node0 = random_node("Node0{}".format(suffix)) - node1 = random_node("Node1{}".format(suffix)) - edge0 = random_edge("Edge01", node0, node1) - edge1 = random_edge("Edge01", node0, node1) - - doc = GraphDocument( - nodes=[node0, node1, node0, node1], - relationships=[edge0, edge1], - source=Document( - page_content="Hello, world!", - metadata={"source": "https://example.com"}, - ), - ) - graph.add_graph_documents([doc]) + graph.add_graph_documents([doc]) - # In the case of flexible schema, `properties` is a nested json - # field. - results = graph.query( - """ + # In the case of flexible schema, `properties` is a nested json + # field. + results = graph.query( + """ GRAPH {} MATCH -[e]-> @@ -282,72 +264,55 @@ def test_spanner_graph_doc_with_duplicate_elements(self, use_flexible_schema): RETURN COALESCE(properties.properties, JSON "{{}}") AS dynamic_properties, properties AS static_properties """.format( - graph_name - ), - params={"param": random_param()}, - ) - assert len(results) == 1 + graph.graph_name + ), + params={"param": random_param()}, + ) + assert len(results) == 1 - edge_properties = edge0.properties - edge_properties.update(edge1.properties) - missing_properties = set(edge_properties.keys()).difference( - set(results[0]["dynamic_properties"].keys()).union( - set(results[0]["static_properties"].keys()) - ) + edge_properties = edge0.properties + edge_properties.update(edge1.properties) + missing_properties = set(edge_properties.keys()).difference( + set(results[0]["dynamic_properties"].keys()).union( + set(results[0]["static_properties"].keys()) ) - print(edge0.properties) - print(edge1.properties) - print(results) - assert ( - len(missing_properties) == 0 - ), "Missing properties of edge: {}".format(missing_properties) - - finally: - print("Clean up graph with name `{}`".format(graph_name)) - graph.cleanup() + ) + assert len(missing_properties) == 0, "Missing properties of edge: {}".format( + missing_properties + ) @pytest.mark.parametrize("use_flexible_schema", [False, True]) - def test_spanner_graph_avoid_unnecessary_overwrite(self, use_flexible_schema): - suffix = random_string(num_char=5, exclude_whitespaces=True) - graph_name = "test_graph{}".format(suffix) - graph = SpannerGraphStore( - instance_id, - google_database, - graph_name, - client=Client(project=project_id), - use_flexible_schema=use_flexible_schema, - static_node_properties=["a", "b"], - static_edge_properties=["a", "b"], + @pytest.mark.parametrize("include_json_schema", [False, True]) + def test_spanner_graph_avoid_unnecessary_overwrite( + self, setup_graph, use_flexible_schema, include_json_schema + ): + suffix, graph = setup_graph + node0 = Node( + id=random_string(), + type="Node{}".format(suffix), + properties={"a": 1, "b": 1}, + ) + node1 = Node( + id=random_string(), + type="Node{}".format(suffix), + properties={"a": 1, "b": 1}, + ) + edge0 = Relationship( + source=node0, + target=node1, + type="Edge{}".format(suffix), + properties={"a": 1, "b": 1}, + ) + doc = GraphDocument( + nodes=[node0, node1], + relationships=[edge0], + source=Document( + page_content="Hello, world!", + metadata={"source": "https://example.com"}, + ), ) - graph.refresh_schema() - - try: - node0 = Node( - id=random_string(), - type="Node{}".format(suffix), - properties={"a": 1, "b": 1}, - ) - node1 = Node( - id=random_string(), - type="Node{}".format(suffix), - properties={"a": 1, "b": 1}, - ) - edge0 = Relationship( - source=node0, - target=node1, - type="Edge{}".format(suffix), - properties={"a": 1, "b": 1}, - ) - doc = GraphDocument( - nodes=[node0, node1], - relationships=[edge0], - source=Document( - page_content="Hello, world!", - metadata={"source": "https://example.com"}, - ), - ) - query = """GRAPH {} + query = """GRAPH {} MATCH (n {{id: @nodeId}}) LET properties = TO_JSON(n)['properties'] RETURN int64(properties.a) AS a, int64(properties.b) AS b @@ -356,50 +321,35 @@ def test_spanner_graph_avoid_unnecessary_overwrite(self, use_flexible_schema): LET properties = TO_JSON(e)['properties'] RETURN int64(properties.a) AS a, int64(properties.b) AS b """.format( - graph_name - ) - graph.add_graph_documents([doc]) - - # Test initial value: a=1, b=1 - results = graph.query(query, {"nodeId": node0.id}) - assert len(results) == 2, "Actual results: {}".format(results) - assert all((r["a"] == 1 for r in results)), "Actual results: {}".format( - results - ) - assert all((r["b"] == 1 for r in results)), "Actual results: {}".format( - results - ) - - node0.properties["a"] = 2 - edge0.properties["a"] = 2 - graph.add_graph_documents([doc]) - - # Test value after first overwrite: a=2, b=1 - results = graph.query(query, {"nodeId": node0.id}) - assert len(results) == 2, "Actual results: {}".format(results) - assert all((r["a"] == 2 for r in results)), "Actual results: {}".format( - results - ) - assert all((r["b"] == 1 for r in results)), "Actual results: {}".format( - results - ) - - node0.properties = {} - edge0.properties = {} - graph.add_graph_documents([doc]) - - # Test value after second overwrite: a=2, b=1 - results = graph.query(query, {"nodeId": node0.id}) - assert len(results) == 2, "Actual results: {}".format(results) - assert all((r["a"] == 2 for r in results)), "Actual results: {}".format( - results - ) - assert all((r["b"] == 1 for r in results)), "Actual results: {}".format( - results - ) - finally: - print("Clean up graph with name `{}`".format(graph_name)) - graph.cleanup() + graph.graph_name + ) + graph.add_graph_documents([doc]) + + # Test initial value: a=1, b=1 + results = graph.query(query, {"nodeId": node0.id}) + assert len(results) == 2, "Actual results: {}".format(results) + assert all((r["a"] == 1 for r in results)), "Actual results: {}".format(results) + assert all((r["b"] == 1 for r in results)), "Actual results: {}".format(results) + + node0.properties["a"] = 2 + edge0.properties["a"] = 2 + graph.add_graph_documents([doc]) + + # Test value after first overwrite: a=2, b=1 + results = graph.query(query, {"nodeId": node0.id}) + assert len(results) == 2, "Actual results: {}".format(results) + assert all((r["a"] == 2 for r in results)), "Actual results: {}".format(results) + assert all((r["b"] == 1 for r in results)), "Actual results: {}".format(results) + + node0.properties = {} + edge0.properties = {} + graph.add_graph_documents([doc]) + + # Test value after second overwrite: a=2, b=1 + results = graph.query(query, {"nodeId": node0.id}) + assert len(results) == 2, "Actual results: {}".format(results) + assert all((r["a"] == 2 for r in results)), "Actual results: {}".format(results) + assert all((r["b"] == 1 for r in results)), "Actual results: {}".format(results) @pytest.mark.parametrize( "graph_name, raises_exception", @@ -435,36 +385,30 @@ def test_spanner_graph_invalid_graph_name(self, graph_name, raises_exception): ) @pytest.mark.parametrize("use_flexible_schema", [False, True]) - def test_spanner_graph_with_existing_graph(self, use_flexible_schema): - suffix = random_string(num_char=5, exclude_whitespaces=True) - graph_name = "test_graph{}".format(suffix) + @pytest.mark.parametrize("include_json_schema", [False, True]) + def test_spanner_graph_with_existing_graph( + self, setup_graph, use_flexible_schema, include_json_schema + ): + suffix, graph = setup_graph + graph_name = graph.graph_name node_table_name = "{}_node".format(graph_name) edge_table_name = "{}_edge".format(graph_name) - graph = SpannerGraphStore( - instance_id, - google_database, - graph_name, - client=Client(project=project_id), - use_flexible_schema=use_flexible_schema, - ) - graph.refresh_schema() - try: - graph.impl.apply_ddls( - [ - f""" + graph.impl.apply_ddls( + [ + f""" CREATE TABLE IF NOT EXISTS {node_table_name} ( id INT64 NOT NULL, str STRING(MAX), token TOKENLIST AS (TOKENIZE_FULLTEXT(str)) HIDDEN, ) PRIMARY KEY (id) """, - f""" + f""" CREATE TABLE IF NOT EXISTS {edge_table_name} ( id INT64 NOT NULL, target_id INT64 NOT NULL, ) PRIMARY KEY (id, target_id) """, - f""" + f""" CREATE PROPERTY GRAPH IF NOT EXISTS {graph_name} NODE TABLES ( {node_table_name} AS NodeA @@ -487,39 +431,101 @@ def test_spanner_graph_with_existing_graph(self, use_flexible_schema): LABEL EdgeBA PROPERTIES(target_id AS node_a_id, id AS node_b_id), ) """, - ] - ) - graph.refresh_schema() - schema = json.loads(graph.get_schema) - edgeab = graph.schema.get_edge_schema("EdgeAB") - edgeba = graph.schema.get_edge_schema("EdgeBA") - assert (edgeab.source.node_name, edgeab.target.node_name) == ( - "NodeA", - "NodeB", - ) - assert (edgeba.source.node_name, edgeba.target.node_name) == ( - "NodeB", - "NodeA", - ) - # TOKENLIST-typed properties are ignored. - assert len(schema["Node properties per node label"]["Node"]) == 4, schema[ - "Node properties per node label" - ]["Node"] - assert len(schema["Node properties per node label"]["NodeA"]) == 3, schema[ - "Node properties per node label" - ]["NodeA"] - assert len(schema["Node properties per node label"]["NodeB"]) == 3, schema[ - "Node properties per node label" - ]["NodB"] - assert len(schema["Possible edges per label"]["EdgeAB"]) == 4, schema[ - "Possible edges per label" - ]["EdgeAB"] - assert len(schema["Possible edges per label"]["EdgeBA"]) == 4, schema[ - "Possible edges per label" - ]["EdgeBA"] - assert len(schema["Possible edges per label"]["Edge"]) == 8, schema[ - "Possible edges per label" - ]["Edge"] - finally: - print("Clean up graph with name `{}`".format(graph_name)) - graph.cleanup() + ] + ) + graph.refresh_schema() + schema = json.loads(graph.get_schema) + edgeab = graph.schema.get_edge_schema("EdgeAB") + edgeba = graph.schema.get_edge_schema("EdgeBA") + assert (edgeab.source.node_name, edgeab.target.node_name) == ( + "NodeA", + "NodeB", + ) + assert (edgeba.source.node_name, edgeba.target.node_name) == ( + "NodeB", + "NodeA", + ) + # TOKENLIST-typed properties are ignored. + assert len(schema["Node properties per node label"]["Node"]) == 4, schema[ + "Node properties per node label" + ]["Node"] + assert len(schema["Node properties per node label"]["NodeA"]) == 3, schema[ + "Node properties per node label" + ]["NodeA"] + assert len(schema["Node properties per node label"]["NodeB"]) == 3, schema[ + "Node properties per node label" + ]["NodB"] + assert len(schema["Possible edges per label"]["EdgeAB"]) == 4, schema[ + "Possible edges per label" + ]["EdgeAB"] + assert len(schema["Possible edges per label"]["EdgeBA"]) == 4, schema[ + "Possible edges per label" + ]["EdgeBA"] + assert len(schema["Possible edges per label"]["Edge"]) == 8, schema[ + "Possible edges per label" + ]["Edge"] + + @pytest.mark.parametrize("use_flexible_schema", [False, True]) + @pytest.mark.parametrize("include_json_schema", [True]) + def test_spanner_graph_schema_with_json( + self, setup_graph, use_flexible_schema, include_json_schema + ): + suffix, graph = setup_graph + node0 = Node( + id=random_string(), + type="Node0{}".format(suffix), + properties={"j0": random_json()}, + ) + node1 = Node( + id=random_string(), + type="Node1{}".format(suffix), + properties={"j1": random_json()}, + ) + + edge = Relationship( + source=node0, target=node1, type="Edge", properties={"j": random_json()} + ) + + doc = GraphDocument( + nodes=[node0, node1], + relationships=[edge], + source=Document( + page_content="Hello, world!", + metadata={"source": "https://example.com"}, + ), + ) + graph.add_graph_documents([doc]) + schema = json.loads(graph.get_schema) + if use_flexible_schema: + node_json_fields = [ + [f["key"] for f in p["json_fields"]] + for p in schema["Node properties per node label"]["NODE"] + if "json_fields" in p + ] + edge_json_fields = [ + [f["key"] for f in p["json_fields"]] + for p in schema["Edge properties per edge label"]["EDGE"] + if "json_fields" in p + ] + assert node_json_fields in ([["j0"]], [["j1"]]), schema + assert edge_json_fields == [["j"]], schema + else: + node0_json_fields = [ + [f["key"] for f in p["json_fields"]] + for p in schema["Node properties per node label"][node0.type] + if "json_fields" in p + ] + node1_json_fields = [ + [f["key"] for f in p["json_fields"]] + for p in schema["Node properties per node label"][node1.type] + if "json_fields" in p + ] + edge_json_fields = [ + [f["key"] for f in p["json_fields"]] + for edge in schema["Edge properties per edge label"].values() + for p in edge + if "json_fields" in p + ] + assert node0_json_fields == [list(node0.properties["j0"].keys())] + assert node1_json_fields == [list(node1.properties["j1"].keys())] + assert edge_json_fields == [list(edge.properties["j"].keys())] From 25bcbdc45617687f4caf037cc1b3718bf127da9a Mon Sep 17 00:00:00 2001 From: Mingtian Yin Date: Fri, 8 Aug 2025 16:15:27 -0700 Subject: [PATCH 2/4] Fix linter --- src/langchain_google_spanner/graph_store.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/langchain_google_spanner/graph_store.py b/src/langchain_google_spanner/graph_store.py index d7f7571..2acc11e 100644 --- a/src/langchain_google_spanner/graph_store.py +++ b/src/langchain_google_spanner/graph_store.py @@ -1560,8 +1560,9 @@ def cleanup(self): ] ) self.schema = SpannerGraphSchema( - self.schema.graph_name, self.schema.use_flexible_schema, + self.schema.graph_name, + self.schema.use_flexible_schema, self.schema.static_node_properties, self.schema.static_edge_properties, - self.schema.json_schema + self.schema.json_schema, ) From d33f86736ff5a5c759fa3aa90de3fb9807a2c0d3 Mon Sep 17 00:00:00 2001 From: Mingtian Yin Date: Thu, 14 Aug 2025 14:23:13 -0700 Subject: [PATCH 3/4] feat(graph): Support dynamic schema This is to support dynamic schema introduced in https://cloud.google.com/spanner/docs/graph/manage-schemaless-data ``` DYNAMIC LABEL () DYNAMIC PROPERTIES () ``` allows we add labels / properties that are dynamically stored as values (as opposed to schema objects). This is useful to manage graph with schema that dynamically evolves (without making schema changes). --- src/langchain_google_spanner/graph_store.py | 433 +++++++++++------- tests/integration/test_spanner_graph_store.py | 183 +++++--- 2 files changed, 387 insertions(+), 229 deletions(-) diff --git a/src/langchain_google_spanner/graph_store.py b/src/langchain_google_spanner/graph_store.py index 2acc11e..b491ef3 100644 --- a/src/langchain_google_spanner/graph_store.py +++ b/src/langchain_google_spanner/graph_store.py @@ -19,7 +19,7 @@ import re import string from abc import ABC, abstractmethod -from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, Union +from typing import Any, Dict, Set, Generator, Iterable, List, Mapping, Optional, Tuple, Union from google.cloud import spanner from google.cloud.spanner_v1 import JsonObject, param_types @@ -208,18 +208,103 @@ class ElementSchema(object): source: NodeReference target: NodeReference + # DYNAMIC LABEL() + # DYNAMIC PROPERTIES() + dynamic_label_expr: Optional[str] = None + dynamic_property_expr: Optional[str] = None + + # Cache of dynamically fetched labels and properties. + dynamic_schema: Optional[CaseInsensitiveDict[DynamicLabel]] = None + # Cache of dynamically fetched edge patterns. + dynamic_edge_patterns: Optional[List[Tuple[str, str, str]]] = None + def is_dynamic_schema(self) -> bool: return ( - self.types.get(ElementSchema.DYNAMIC_PROPERTY_COLUMN_NAME, None) - == param_types.JSON + self.dynamic_label_expr is not None + or self.dynamic_property_expr is not None ) + def refresh_dynamic_schema(self, dynamic_schema_util: DynamicSchemaUtility): + if self.kind == NODE_KIND: + self.dynamic_schema = dynamic_schema_util.get_dynamic_node_schema( + self.labels + ) + else: + self.dynamic_schema = dynamic_schema_util.get_dynamic_edge_schema( + self.labels + ) + self.dynamic_edge_patterns = dynamic_schema_util.get_dynamic_edge_patterns( + self.labels + ) + + def get_label_and_properties(self, graph: SpannerGraphSchema): + + def get_readable_property(pname, ptype, json_type=None): + prop = { + "name": pname, + "type": TypeUtility.spanner_type_to_schema_str(ptype), + } + # Dynamic properties will have json_types: this represents the + # underlying data type of the json value. + if json_type: + prop["json_type"] = json_type + return prop + + if self.dynamic_schema: + return { + lname: [ + get_readable_property( + pname, self.types.get(pname, param_types.JSON), ptype + ) + for pname, ptype in label.properties + ] + for lname, label in self.dynamic_schema.items() + # Ignore static labels. + if lname not in self.labels + } + return { + label: [ + get_readable_property(pname, self.types[pname]) + for pname in sorted(graph.labels[label].prop_names) + if pname in self.types + ] + for label in sorted(self.labels) + } + + def get_edge_patterns(self, graph: SpannerGraphSchema): + assert self.kind == EDGE_KIND + source = graph.get_node_schema(self.source.node_name) + assert source is not None + target = graph.get_node_schema(self.target.node_name) + assert target is not None + if self.dynamic_edge_patterns: + return [ + (source_node_label, label, target_node_label) + for ( + source_node_label, + label, + target_node_label, + ) in self.dynamic_edge_patterns + # Ignore static labels. + if label not in self.labels + and source_node_label not in source.labels + and target_node_label not in target.labels + ] + return [ + (source_node_label, label, target_node_label) + for label in sorted(self.labels) + for source_node_label in source.labels + for target_node_label in target.labels + ] + @staticmethod def make_node_schema( node_type: str, node_label: str, graph_name: str, property_types: CaseInsensitiveDict, + dynamic_label_expr: Optional[str] = None, + dynamic_property_expr: Optional[str] = None, ) -> ElementSchema: node = ElementSchema() node.types = property_types @@ -229,6 +314,8 @@ def make_node_schema( node.name = node_type node.kind = NODE_KIND node.key_columns = [ElementSchema.NODE_KEY_COLUMN_NAME] + node.dynamic_label_expr = dynamic_label_expr + node.dynamic_property_expr = dynamic_property_expr return node @staticmethod @@ -240,6 +327,8 @@ def make_edge_schema( property_types: CaseInsensitiveDict, source_node_type: str, target_node_type: str, + dynamic_label_expr: Optional[str] = None, + dynamic_property_expr: Optional[str] = None, ) -> ElementSchema: edge = ElementSchema() edge.types = property_types @@ -273,6 +362,8 @@ def make_edge_schema( [ElementSchema.NODE_KEY_COLUMN_NAME], [ElementSchema.TARGET_NODE_KEY_COLUMN_NAME], ) + edge.dynamic_label_expr = dynamic_label_expr + edge.dynamic_property_expr = dynamic_property_expr return edge @staticmethod @@ -354,7 +445,12 @@ def from_dynamic_nodes( ) ) return ElementSchema.make_node_schema( - NODE_KIND, NODE_KIND, graph_schema.graph_name, types + NODE_KIND, + NODE_KIND, + graph_schema.graph_name, + types, + dynamic_label_expr=ElementSchema.DYNAMIC_LABEL_COLUMN_NAME, + dynamic_property_expr=ElementSchema.DYNAMIC_PROPERTY_COLUMN_NAME, ) @staticmethod @@ -471,6 +567,8 @@ def from_dynamic_edges( types, edges[0].source.type, edges[0].target.type, + dynamic_label_expr=ElementSchema.DYNAMIC_LABEL_COLUMN_NAME, + dynamic_property_expr=ElementSchema.DYNAMIC_PROPERTY_COLUMN_NAME, ) def add_nodes( @@ -502,6 +600,19 @@ def add_nodes( properties[ElementSchema.NODE_KEY_COLUMN_NAME] = node.id if self.is_dynamic_schema(): + assert ( + self.dynamic_label_expr == ElementSchema.DYNAMIC_LABEL_COLUMN_NAME + ), "Require dynamic label expression to be %s: got %s" % ( + ElementSchema.DYNAMIC_LABEL_COLUMN_NAME, + self.dynamic_label_expr, + ) + assert ( + self.dynamic_property_expr == + ElementSchema.DYNAMIC_PROPERTY_COLUMN_NAME + ), "Require dynamic property expression to be %s: got %s" % ( + ElementSchema.DYNAMIC_PROPERTY_COLUMN_NAME, + self.dynamic_property_expr, + ) dynamic_properties = { k: TypeUtility.value_for_json(v) for k, v in node.properties.items() @@ -552,6 +663,19 @@ def add_edges( properties[ElementSchema.TARGET_NODE_KEY_COLUMN_NAME] = edge.target.id if self.is_dynamic_schema(): + assert ( + self.dynamic_label_expr == ElementSchema.DYNAMIC_LABEL_COLUMN_NAME + ), "Require dynamic label expression to be %s: got %s" % ( + ElementSchema.DYNAMIC_LABEL_COLUMN_NAME, + self.dynamic_label_expr, + ) + assert ( + self.dynamic_property_expr == + ElementSchema.DYNAMIC_PROPERTY_COLUMN_NAME + ), "Require dynamic property expression to be %s: got %s" % ( + ElementSchema.DYNAMIC_PROPERTY_COLUMN_NAME, + self.dynamic_property_expr, + ) dynamic_properties = { k: TypeUtility.value_for_json(v) for k, v in edge.properties.items() @@ -621,6 +745,9 @@ def from_info_schema( element_schema["destinationNodeTable"]["nodeTableColumns"], element_schema["destinationNodeTable"]["edgeTableColumns"], ) + + element.dynamic_label_expr = element_schema.get("dynamicLabelExpr") + element.dynamic_property_expr = element_schema.get("dynamicPropertyExpr") return element def to_ddl(self, graph_schema: SpannerGraphSchema) -> str: @@ -746,6 +873,11 @@ def evolve(self, new_schema: ElementSchema) -> List[str]: ] self.properties.update(new_schema.properties) self.types.update(new_schema.types) + + self.dynamic_label_expr = new_schema.dynamic_label_expr + self.dynamic_property_expr = new_schema.dynamic_property_expr + self.dynamic_schema = new_schema.dynamic_schema + self.dynamic_edge_patterns = new_schema.dynamic_edge_patterns return ddls @@ -756,6 +888,9 @@ def __init__(self, name: str, prop_names: set[str]): self.name = name self.prop_names = prop_names + def __repr__(self): + return f"Label({self.name}, {self.prop_names})" + class NodeReference(object): """Schema representation of a source or destination node reference.""" @@ -766,65 +901,108 @@ def __init__(self, node_name: str, node_keys: List[str], edge_keys: List[str]): self.edge_keys = edge_keys -class JsonSchema(object): - NODE_JSON_PROPERTY_QUERY_TEMPLATE = """ +class DynamicLabel(object): + """Representation of a dynamic label.""" + + def __init__(self, name: str, properties: List[Tuple[str, str]]): + self.name = name + self.properties = properties + + +class DynamicSchemaUtility(object): + """Utility class that dynamically fetches graph schema.""" + + # Sample a list of (label, properties) for nodes of static label_expr. + NODE_DYNAMIC_SCHEMA_QUERY_TEMPLATE = """ + GRAPH `{graph_id}` + MATCH (n:{label_expr}) + LET json = SAFE_TO_JSON(n).properties + FOR label IN LABELS(n) + RETURN label, ANY_VALUE(json) AS json + NEXT + LET json_fields = JSON_KEYS(json) + RETURN label, ARRAY {{ + GRAPH `{graph_id}` + FOR field IN json_fields + FILTER json[field] IS NOT NULL + LET type = JSON_TYPE(json[field]) + FILTER type != 'null' + RETURN STRUCT(field, type) AS field + }} AS properties + """ + + # Sample a list of (label, properties) for edges of static label_expr. + EDGE_DYNAMIC_SCHEMA_QUERY_TEMPLATE = """ GRAPH `{graph_id}` - MATCH (n:`{label_name}`) - WHERE n.`{property_name}` IS NOT NULL - LET j = n.`{property_name}` - LIMIT 1 - LET keys = JSON_KEYS(j, 1) - FOR key IN keys - LET v = j[key] - LET type = json_type(v) - RETURN key, type + MATCH -[n:{label_expr}]-> + LET json = SAFE_TO_JSON(n).properties + FOR label IN LABELS(n) + RETURN label, ANY_VALUE(json) AS json + NEXT + LET json_fields = JSON_KEYS(json) + RETURN label, ARRAY {{ + GRAPH `{graph_id}` + FOR field IN json_fields + FILTER json[field] IS NOT NULL + LET type = JSON_TYPE(json[field]) + FILTER type != 'null' + RETURN STRUCT(field, type) AS property + ORDER BY field + }} AS properties + ORDER BY label """ - EDGE_JSON_PROPERTY_QUERY_TEMPLATE = """ + # Find all (source_node_label, edge_label, target_node_label) triplets. + EDGE_PATTERN_QUERY_TEMPLATE = """ GRAPH `{graph_id}` - MATCH -[n:`{label_name}`]-> - WHERE n.`{property_name}` IS NOT NULL - LET j = n.`{property_name}` - LIMIT 1 - LET keys = JSON_KEYS(j, 1) - FOR key IN keys - LET v = j[key] - let type = json_type(v) - RETURN key, type + MATCH (src) -[n:{label_expr}]-> (dst) + FOR edge_label IN LABELS(n) + FOR src_label IN LABELS(src) + FOR dst_label IN LABELS(dst) + RETURN DISTINCT src_label, edge_label, dst_label + ORDER BY src_label, edge_label, dst_label """ def __init__(self, graph_name: str, impl: SpannerInterface): self._graph_name = graph_name self._impl = impl - def get_node_json_property_schema(self, node_label: str, property_names: List[str]): - return self._get_label_json_property_schema( - node_label, property_names, self.NODE_JSON_PROPERTY_QUERY_TEMPLATE - ) + @staticmethod + def make_label_expr(labels: List[str]) -> str: + return " & ".join([f"`{label}`" for label in labels]) - def get_edge_json_property_schema(self, edge_label: str, property_names: List[str]): - return self._get_label_json_property_schema( - edge_label, property_names, self.EDGE_JSON_PROPERTY_QUERY_TEMPLATE - ) + def get_dynamic_node_schema(self, labels: List[str]): + return self._get_dynamic_schema(labels, self.NODE_DYNAMIC_SCHEMA_QUERY_TEMPLATE) - def _get_label_json_property_schema( - self, label: str, property_names: List[str], query_template: str - ): - if len(property_names) == 0: - return CaseInsensitiveDict({}) + def get_dynamic_edge_schema(self, labels: List[str]): + return self._get_dynamic_schema(labels, self.EDGE_DYNAMIC_SCHEMA_QUERY_TEMPLATE) + + def _get_dynamic_schema(self, labels: List[str], query_template: str): + label_expr = self.make_label_expr(labels) return CaseInsensitiveDict( { - pname: [ - row - for row in self._impl.query( - query_template.format( - graph_id=self._graph_name, - label_name=label, - property_name=pname, - ) + row["label"]: DynamicLabel( + name=row["label"], + properties=row["properties"], + ) + for row in self._impl.query( + query_template.format( + graph_id=self._graph_name, label_expr=label_expr ) - ] - for pname in property_names + ) + } + ) + + def get_dynamic_edge_patterns(self, labels: List[str]): + return set( + { + (row["src_label"], row["edge_label"], row["dst_label"]) + for row in self._impl.query( + self.EDGE_PATTERN_QUERY_TEMPLATE.format( + graph_id=self._graph_name, + label_expr=self.make_label_expr(labels), + ) + ) } ) @@ -844,7 +1022,7 @@ def __init__( use_flexible_schema: bool, static_node_properties: List[str] = [], static_edge_properties: List[str] = [], - json_schema: Optional[JsonSchema] = None, + dynamic_schema_util: Optional[DynamicSchemaUtility] = None, ): """Initializes the graph schema. @@ -872,16 +1050,10 @@ def __init__( self.edge_tables: CaseInsensitiveDict[ElementSchema] = CaseInsensitiveDict({}) self.labels: CaseInsensitiveDict[Label] = CaseInsensitiveDict({}) self.properties: CaseInsensitiveDict[param_types.Type] = CaseInsensitiveDict({}) - self.node_json_property_schema: CaseInsensitiveDict[Dict] = CaseInsensitiveDict( - {} - ) - self.edge_json_property_schema: CaseInsensitiveDict[Dict] = CaseInsensitiveDict( - {} - ) self.use_flexible_schema = use_flexible_schema self.static_node_properties = set(static_node_properties) self.static_edge_properties = set(static_edge_properties) - self.json_schema = json_schema + self.dynamic_schema_util = dynamic_schema_util def evolve(self, graph_documents: List[GraphDocument]) -> List[str]: """Evolves current schema into a schema representing the input documents. @@ -933,15 +1105,17 @@ def from_information_schema(self, info_schema: Dict[str, Any]) -> None: ) for node in info_schema["nodeTables"]: node_schema = ElementSchema.from_info_schema(node, decl_by_types) + if node_schema.is_dynamic_schema() and self.dynamic_schema_util: + node_schema.refresh_dynamic_schema(self.dynamic_schema_util) self._update_node_schema(node_schema) self._update_labels_and_properties(node_schema) - self._update_json_property_schema(node_schema) for edge in info_schema.get("edgeTables", []): edge_schema = ElementSchema.from_info_schema(edge, decl_by_types) + if edge_schema.is_dynamic_schema() and self.dynamic_schema_util: + edge_schema.refresh_dynamic_schema(self.dynamic_schema_util) self._update_edge_schema(edge_schema) self._update_labels_and_properties(edge_schema) - self._update_json_property_schema(edge_schema) def node_type_name(self, name: str) -> str: return NODE_KIND if self.use_flexible_schema else name @@ -1007,73 +1181,36 @@ def __repr__(self) -> str: Returns: str: a string representation of the graph schema. """ - properties = CaseInsensitiveDict( - { - k: TypeUtility.spanner_type_to_schema_str(v) - for k, v in self.properties.items() - } - ) - node_labels = {label for node in self.nodes.values() for label in node.labels} - edge_labels = {label for edge in self.edges.values() for label in edge.labels} - Triplet = Tuple[ElementSchema, ElementSchema, ElementSchema] - triplets_per_label: CaseInsensitiveDict[List[Triplet]] = CaseInsensitiveDict({}) - for edge in self.edges.values(): - for label in edge.labels: - source_node = self.get_node_schema(edge.source.node_name) - target_node = self.get_node_schema(edge.target.node_name) - if source_node is None: - raise ValueError(f"Source node {edge.source.node_name} not found") - if target_node is None: - raise ValueError(f"Tource node {edge.target.node_name} not found") - triplets_per_label.setdefault(label, []).append( - (source_node, edge, target_node) - ) - - def repr_property(lname, pname, ptype, json_fields): - if not json_fields: - return {"name": pname, "type": ptype} - return {"name": pname, "type": ptype, "json_fields": json_fields} + node_properties_per_label: Dict[str, Dict] = {} + edge_properties_per_label: Dict[str, Dict] = {} + edge_patterns_per_label: Dict[str, Set[str]] = {} + for node in self.nodes.values(): + node_properties_per_label.update( + node.get_label_and_properties(self)) + for edge in self.edges.values(): + edge_properties_per_label.update( + edge.get_label_and_properties(self)) + for src_node_label, label, tgt_node_label in edge.get_edge_patterns( + self): + edge_patterns_per_label.setdefault(label, set()).add( + "(:{}) -[:{}]-> (:{})".format(src_node_label, label, + tgt_node_label)) return json.dumps( { "Name of graph": self.graph_name, - "Node properties per node label": { - label: [ - repr_property( - label, - pname, - properties[pname], - self.node_json_property_schema.get(label, {}).get(pname), - ) - for pname in sorted(self.labels[label].prop_names) - ] - for label in sorted(node_labels) - }, - "Edge properties per edge label": { - label: [ - repr_property( - label, - pname, - properties[pname], - self.edge_json_property_schema.get(label, {}).get(pname), - ) - for pname in sorted(self.labels[label].prop_names) - ] - for label in sorted(edge_labels) - }, - "Possible edges per label": { - label: [ - "(:{}) -[:{}]-> (:{})".format( - source_node_label, label, target_node_label - ) - for (source, edge, target) in triplets - for source_node_label in source.labels - for target_node_label in target.labels - ] - for label, triplets in triplets_per_label.items() - }, + "Node properties per node label": dict( + sorted(node_properties_per_label.items()) + ), + "Edge properties per edge label": dict( + sorted(edge_properties_per_label.items()) + ), + "Possible edges per label": dict( + sorted(edge_patterns_per_label.items()) + ), }, indent=2, + default=lambda s: sorted(s), ) def to_ddl(self) -> str: @@ -1105,12 +1242,17 @@ def construct_label_and_properties_list( labels: CaseInsensitiveDict[Label], element: ElementSchema, ) -> str: - return "\n".join( - ( - construct_label_and_properties(target_label, labels, element) - for target_label in target_labels + clauses = [ + construct_label_and_properties(target_label, labels, element) + for target_label in target_labels + ] + if element.dynamic_label_expr: + clauses.append("DYNAMIC LABEL ({})".format(element.dynamic_label_expr)) + if element.dynamic_property_expr: + clauses.append( + "DYNAMIC PROPERTIES ({})".format(element.dynamic_property_expr) ) - ) + return "\n".join(clauses) def construct_columns(cols: List[str]) -> str: return ", ".join(to_identifiers(cols)) @@ -1210,37 +1352,6 @@ def _update_edge_schema(self, edge_schema: ElementSchema) -> List[str]: self.edges[edge_schema.name] = old_schema or edge_schema return ddls - def _update_json_property_schema(self, element_schema: ElementSchema) -> None: - if self.json_schema is None: - return - if len(element_schema.labels) == 0: - return - lname = element_schema.labels[0] - if element_schema.kind == NODE_KIND: - json_property_schema = self.json_schema.get_node_json_property_schema( - lname, - [ - pname - for pname, ptype in element_schema.types.items() - if ptype == param_types.JSON - ], - ) - self.node_json_property_schema.update( - {l: json_property_schema for l in element_schema.labels} - ) - else: - json_property_schema = self.json_schema.get_edge_json_property_schema( - lname, - [ - pname - for pname, ptype in element_schema.types.items() - if ptype == param_types.JSON - ], - ) - self.edge_json_property_schema.update( - {l: json_property_schema for l in element_schema.labels} - ) - def _update_labels_and_properties(self, element_schema: ElementSchema) -> None: """Updates labels and properties based on an element schema. @@ -1407,7 +1518,6 @@ def __init__( static_edge_properties: List[str] = [], impl: Optional[SpannerInterface] = None, timeout: Optional[float] = None, - include_json_schema: bool = False, ): """Initializes SpannerGraphStore. @@ -1423,7 +1533,6 @@ def __init__( static_edge_properties: in flexible schema, treat these edge properties as static. timeout (Optional[float]): The timeout for queries in seconds. - include_json_schema (Optional[bool]): Whether to include json fields in the schema. """ self.graph_name = graph_name self.impl = impl or SpannerImpl( @@ -1435,11 +1544,9 @@ def __init__( self.schema = SpannerGraphSchema( graph_name, use_flexible_schema, - static_node_properties, - static_edge_properties, - json_schema=( - JsonSchema(graph_name, self.impl) if include_json_schema else None - ), + static_node_properties=static_node_properties, + static_edge_properties=static_edge_properties, + dynamic_schema_util=DynamicSchemaUtility(graph_name, self.impl), ) self.refresh_schema() @@ -1564,5 +1671,5 @@ def cleanup(self): self.schema.use_flexible_schema, self.schema.static_node_properties, self.schema.static_edge_properties, - self.schema.json_schema, + self.schema.dynamic_schema_util, ) diff --git a/tests/integration/test_spanner_graph_store.py b/tests/integration/test_spanner_graph_store.py index ce1328a..0b9c9cd 100644 --- a/tests/integration/test_spanner_graph_store.py +++ b/tests/integration/test_spanner_graph_store.py @@ -102,7 +102,6 @@ def random_generators(): + [random_none, random_json] ) - properties = [ ("p{}".format(i), random_val_gen) for i, random_val_gen in enumerate(random_generators()) @@ -166,7 +165,6 @@ def random_graph_doc(suffix): @pytest.fixture def setup_graph(request): use_flexible_schema = request.getfixturevalue("use_flexible_schema") - include_json_schema = request.getfixturevalue("include_json_schema") suffix = random_string(num_char=5, exclude_whitespaces=True) graph_name = "test_graph{}".format(suffix) graph = SpannerGraphStore( @@ -177,7 +175,6 @@ def setup_graph(request): use_flexible_schema=use_flexible_schema, static_node_properties=["a", "b"], static_edge_properties=["a", "b"], - include_json_schema=include_json_schema, ) graph.refresh_schema() @@ -189,9 +186,10 @@ def setup_graph(request): class TestSpannerGraphStore: @pytest.mark.parametrize("use_flexible_schema", [False, True]) - @pytest.mark.parametrize("include_json_schema", [False, True]) def test_spanner_graph_random_doc( - self, setup_graph, use_flexible_schema, include_json_schema + self, + setup_graph, + use_flexible_schema, ): suffix, graph = setup_graph node_ids = set() @@ -233,9 +231,10 @@ def test_spanner_graph_random_doc( assert results[1]["num_elements"] == len(node_ids), "Mismatch number of nodes" @pytest.mark.parametrize("use_flexible_schema", [False, True]) - @pytest.mark.parametrize("include_json_schema", [False, True]) def test_spanner_graph_doc_with_duplicate_elements( - self, setup_graph, use_flexible_schema, include_json_schema + self, + setup_graph, + use_flexible_schema, ): suffix, graph = setup_graph node0 = random_node("Node0{}".format(suffix)) @@ -282,9 +281,10 @@ def test_spanner_graph_doc_with_duplicate_elements( ) @pytest.mark.parametrize("use_flexible_schema", [False, True]) - @pytest.mark.parametrize("include_json_schema", [False, True]) def test_spanner_graph_avoid_unnecessary_overwrite( - self, setup_graph, use_flexible_schema, include_json_schema + self, + setup_graph, + use_flexible_schema, ): suffix, graph = setup_graph node0 = Node( @@ -385,9 +385,10 @@ def test_spanner_graph_invalid_graph_name(self, graph_name, raises_exception): ) @pytest.mark.parametrize("use_flexible_schema", [False, True]) - @pytest.mark.parametrize("include_json_schema", [False, True]) def test_spanner_graph_with_existing_graph( - self, setup_graph, use_flexible_schema, include_json_schema + self, + setup_graph, + use_flexible_schema, ): suffix, graph = setup_graph graph_name = graph.graph_name @@ -446,45 +447,91 @@ def test_spanner_graph_with_existing_graph( "NodeA", ) # TOKENLIST-typed properties are ignored. - assert len(schema["Node properties per node label"]["Node"]) == 4, schema[ - "Node properties per node label" - ]["Node"] - assert len(schema["Node properties per node label"]["NodeA"]) == 3, schema[ - "Node properties per node label" - ]["NodeA"] - assert len(schema["Node properties per node label"]["NodeB"]) == 3, schema[ - "Node properties per node label" - ]["NodB"] - assert len(schema["Possible edges per label"]["EdgeAB"]) == 4, schema[ - "Possible edges per label" - ]["EdgeAB"] - assert len(schema["Possible edges per label"]["EdgeBA"]) == 4, schema[ - "Possible edges per label" - ]["EdgeBA"] - assert len(schema["Possible edges per label"]["Edge"]) == 8, schema[ - "Possible edges per label" - ]["Edge"] + assert schema["Node properties per node label"]["Node"] == [ + { + "name": "id", + "type": "INT64" + }, + { + "name": "node_b_id", + "type": "INT64" + }, + { + "name": "str", + "type": "STRING" + }, + ], 'Invalid Node properties' + assert schema["Node properties per node label"]["NodeA"] == [ + { + "name": "id", + "type": "INT64" + }, + { + "name": "node_a_id", + "type": "INT64" + }, + { + "name": "str", + "type": "STRING" + }, + ], 'Invalid NodeA properties' + assert schema["Node properties per node label"]["NodeB"] == [ + { + "name": "id", + "type": "INT64" + }, + { + "name": "node_b_id", + "type": "INT64" + }, + { + "name": "str", + "type": "STRING" + }, + ], 'Invalid NodeB properties' + assert schema["Possible edges per label"]["EdgeAB"] == [ + '(:Node) -[:EdgeAB]-> (:Node)', + '(:Node) -[:EdgeAB]-> (:NodeB)', + '(:NodeA) -[:EdgeAB]-> (:Node)', + '(:NodeA) -[:EdgeAB]-> (:NodeB)', + ], 'Invalid EdgeAB patterns' + assert schema["Possible edges per label"]["EdgeBA"] == [ + '(:Node) -[:EdgeBA]-> (:Node)', + '(:Node) -[:EdgeBA]-> (:NodeA)', + '(:NodeB) -[:EdgeBA]-> (:Node)', + '(:NodeB) -[:EdgeBA]-> (:NodeA)', + ], 'Invalid EdgeBA patterns' + assert schema["Possible edges per label"]["Edge"] == [ + '(:Node) -[:Edge]-> (:Node)', + '(:Node) -[:Edge]-> (:NodeA)', + '(:Node) -[:Edge]-> (:NodeB)', + '(:NodeA) -[:Edge]-> (:Node)', + '(:NodeA) -[:Edge]-> (:NodeB)', + '(:NodeB) -[:Edge]-> (:Node)', + '(:NodeB) -[:Edge]-> (:NodeA)', + ], 'Invalid Edge patterns' @pytest.mark.parametrize("use_flexible_schema", [False, True]) - @pytest.mark.parametrize("include_json_schema", [True]) - def test_spanner_graph_schema_with_json( - self, setup_graph, use_flexible_schema, include_json_schema + def test_spanner_graph_schema_representation( + self, + setup_graph, + use_flexible_schema, ): suffix, graph = setup_graph node0 = Node( id=random_string(), type="Node0{}".format(suffix), - properties={"j0": random_json()}, + properties={"j0": random_int()}, ) node1 = Node( id=random_string(), type="Node1{}".format(suffix), - properties={"j1": random_json()}, - ) - - edge = Relationship( - source=node0, target=node1, type="Edge", properties={"j": random_json()} + properties={"j1": random_string()}, ) + edge = Relationship(source=node0, + target=node1, + type="Links", + properties={"j": random_json()}) doc = GraphDocument( nodes=[node0, node1], @@ -496,36 +543,40 @@ def test_spanner_graph_schema_with_json( ) graph.add_graph_documents([doc]) schema = json.loads(graph.get_schema) + node0_json_fields = sorted([ + p['name'] + for p in schema["Node properties per node label"][node0.type] + ]) + node1_json_fields = sorted([ + p['name'] + for p in schema["Node properties per node label"][node1.type] + ]) + edge_json_fields = sorted([ + p['name'] + for edge in schema["Edge properties per edge label"].values() + for p in edge + ]) + edge_patterns = sorted([ + pattern + for edge in schema["Possible edges per label"].values() + for pattern in edge + ]) if use_flexible_schema: - node_json_fields = [ - [f["key"] for f in p["json_fields"]] - for p in schema["Node properties per node label"]["NODE"] - if "json_fields" in p + assert node0_json_fields == ['id', 'j0', 'label', 'properties'] + assert node1_json_fields == ['id', 'j1', 'label', 'properties'] + assert edge_json_fields == [ + 'id', 'j', 'label', 'properties', 'target_id' ] - edge_json_fields = [ - [f["key"] for f in p["json_fields"]] - for p in schema["Edge properties per edge label"]["EDGE"] - if "json_fields" in p + assert edge_patterns == [ + '(:{src}) -[:{edge}]-> (:{dst})'.format(src=node0.type, + edge=edge.type, + dst=node1.type) ] - assert node_json_fields in ([["j0"]], [["j1"]]), schema - assert edge_json_fields == [["j"]], schema else: - node0_json_fields = [ - [f["key"] for f in p["json_fields"]] - for p in schema["Node properties per node label"][node0.type] - if "json_fields" in p - ] - node1_json_fields = [ - [f["key"] for f in p["json_fields"]] - for p in schema["Node properties per node label"][node1.type] - if "json_fields" in p - ] - edge_json_fields = [ - [f["key"] for f in p["json_fields"]] - for edge in schema["Edge properties per edge label"].values() - for p in edge - if "json_fields" in p + assert node0_json_fields == ['id', 'j0'] + assert node1_json_fields == ['id', 'j1'] + assert edge_json_fields == ['id', 'j', 'target_id'] + assert edge_patterns == [ + '(:{src}) -[:{src}_{edge}_{dst}]-> (:{dst})'.format( + src=node0.type, edge=edge.type, dst=node1.type) ] - assert node0_json_fields == [list(node0.properties["j0"].keys())] - assert node1_json_fields == [list(node1.properties["j1"].keys())] - assert edge_json_fields == [list(edge.properties["j"].keys())] From 56db4af728e17671a723dcd78b66da140b548499 Mon Sep 17 00:00:00 2001 From: Mingtian Yin Date: Fri, 15 Aug 2025 17:12:36 -0700 Subject: [PATCH 4/4] Fix linter --- src/langchain_google_spanner/graph_store.py | 34 ++-- tests/integration/test_spanner_graph_store.py | 158 ++++++++---------- 2 files changed, 87 insertions(+), 105 deletions(-) diff --git a/src/langchain_google_spanner/graph_store.py b/src/langchain_google_spanner/graph_store.py index b491ef3..34e4032 100644 --- a/src/langchain_google_spanner/graph_store.py +++ b/src/langchain_google_spanner/graph_store.py @@ -19,7 +19,18 @@ import re import string from abc import ABC, abstractmethod -from typing import Any, Dict, Set, Generator, Iterable, List, Mapping, Optional, Tuple, Union +from typing import ( + Any, + Dict, + Generator, + Iterable, + List, + Mapping, + Optional, + Set, + Tuple, + Union, +) from google.cloud import spanner from google.cloud.spanner_v1 import JsonObject, param_types @@ -607,8 +618,8 @@ def add_nodes( self.dynamic_label_expr, ) assert ( - self.dynamic_property_expr == - ElementSchema.DYNAMIC_PROPERTY_COLUMN_NAME + self.dynamic_property_expr + == ElementSchema.DYNAMIC_PROPERTY_COLUMN_NAME ), "Require dynamic property expression to be %s: got %s" % ( ElementSchema.DYNAMIC_PROPERTY_COLUMN_NAME, self.dynamic_property_expr, @@ -670,8 +681,8 @@ def add_edges( self.dynamic_label_expr, ) assert ( - self.dynamic_property_expr == - ElementSchema.DYNAMIC_PROPERTY_COLUMN_NAME + self.dynamic_property_expr + == ElementSchema.DYNAMIC_PROPERTY_COLUMN_NAME ), "Require dynamic property expression to be %s: got %s" % ( ElementSchema.DYNAMIC_PROPERTY_COLUMN_NAME, self.dynamic_property_expr, @@ -1185,17 +1196,14 @@ def __repr__(self) -> str: edge_properties_per_label: Dict[str, Dict] = {} edge_patterns_per_label: Dict[str, Set[str]] = {} for node in self.nodes.values(): - node_properties_per_label.update( - node.get_label_and_properties(self)) + node_properties_per_label.update(node.get_label_and_properties(self)) for edge in self.edges.values(): - edge_properties_per_label.update( - edge.get_label_and_properties(self)) - for src_node_label, label, tgt_node_label in edge.get_edge_patterns( - self): + edge_properties_per_label.update(edge.get_label_and_properties(self)) + for src_node_label, label, tgt_node_label in edge.get_edge_patterns(self): edge_patterns_per_label.setdefault(label, set()).add( - "(:{}) -[:{}]-> (:{})".format(src_node_label, label, - tgt_node_label)) + "(:{}) -[:{}]-> (:{})".format(src_node_label, label, tgt_node_label) + ) return json.dumps( { "Name of graph": self.graph_name, diff --git a/tests/integration/test_spanner_graph_store.py b/tests/integration/test_spanner_graph_store.py index 0b9c9cd..36536fc 100644 --- a/tests/integration/test_spanner_graph_store.py +++ b/tests/integration/test_spanner_graph_store.py @@ -102,6 +102,7 @@ def random_generators(): + [random_none, random_json] ) + properties = [ ("p{}".format(i), random_val_gen) for i, random_val_gen in enumerate(random_generators()) @@ -448,68 +449,41 @@ def test_spanner_graph_with_existing_graph( ) # TOKENLIST-typed properties are ignored. assert schema["Node properties per node label"]["Node"] == [ - { - "name": "id", - "type": "INT64" - }, - { - "name": "node_b_id", - "type": "INT64" - }, - { - "name": "str", - "type": "STRING" - }, - ], 'Invalid Node properties' + {"name": "id", "type": "INT64"}, + {"name": "node_b_id", "type": "INT64"}, + {"name": "str", "type": "STRING"}, + ], "Invalid Node properties" assert schema["Node properties per node label"]["NodeA"] == [ - { - "name": "id", - "type": "INT64" - }, - { - "name": "node_a_id", - "type": "INT64" - }, - { - "name": "str", - "type": "STRING" - }, - ], 'Invalid NodeA properties' + {"name": "id", "type": "INT64"}, + {"name": "node_a_id", "type": "INT64"}, + {"name": "str", "type": "STRING"}, + ], "Invalid NodeA properties" assert schema["Node properties per node label"]["NodeB"] == [ - { - "name": "id", - "type": "INT64" - }, - { - "name": "node_b_id", - "type": "INT64" - }, - { - "name": "str", - "type": "STRING" - }, - ], 'Invalid NodeB properties' + {"name": "id", "type": "INT64"}, + {"name": "node_b_id", "type": "INT64"}, + {"name": "str", "type": "STRING"}, + ], "Invalid NodeB properties" assert schema["Possible edges per label"]["EdgeAB"] == [ - '(:Node) -[:EdgeAB]-> (:Node)', - '(:Node) -[:EdgeAB]-> (:NodeB)', - '(:NodeA) -[:EdgeAB]-> (:Node)', - '(:NodeA) -[:EdgeAB]-> (:NodeB)', - ], 'Invalid EdgeAB patterns' + "(:Node) -[:EdgeAB]-> (:Node)", + "(:Node) -[:EdgeAB]-> (:NodeB)", + "(:NodeA) -[:EdgeAB]-> (:Node)", + "(:NodeA) -[:EdgeAB]-> (:NodeB)", + ], "Invalid EdgeAB patterns" assert schema["Possible edges per label"]["EdgeBA"] == [ - '(:Node) -[:EdgeBA]-> (:Node)', - '(:Node) -[:EdgeBA]-> (:NodeA)', - '(:NodeB) -[:EdgeBA]-> (:Node)', - '(:NodeB) -[:EdgeBA]-> (:NodeA)', - ], 'Invalid EdgeBA patterns' + "(:Node) -[:EdgeBA]-> (:Node)", + "(:Node) -[:EdgeBA]-> (:NodeA)", + "(:NodeB) -[:EdgeBA]-> (:Node)", + "(:NodeB) -[:EdgeBA]-> (:NodeA)", + ], "Invalid EdgeBA patterns" assert schema["Possible edges per label"]["Edge"] == [ - '(:Node) -[:Edge]-> (:Node)', - '(:Node) -[:Edge]-> (:NodeA)', - '(:Node) -[:Edge]-> (:NodeB)', - '(:NodeA) -[:Edge]-> (:Node)', - '(:NodeA) -[:Edge]-> (:NodeB)', - '(:NodeB) -[:Edge]-> (:Node)', - '(:NodeB) -[:Edge]-> (:NodeA)', - ], 'Invalid Edge patterns' + "(:Node) -[:Edge]-> (:Node)", + "(:Node) -[:Edge]-> (:NodeA)", + "(:Node) -[:Edge]-> (:NodeB)", + "(:NodeA) -[:Edge]-> (:Node)", + "(:NodeA) -[:Edge]-> (:NodeB)", + "(:NodeB) -[:Edge]-> (:Node)", + "(:NodeB) -[:Edge]-> (:NodeA)", + ], "Invalid Edge patterns" @pytest.mark.parametrize("use_flexible_schema", [False, True]) def test_spanner_graph_schema_representation( @@ -528,10 +502,9 @@ def test_spanner_graph_schema_representation( type="Node1{}".format(suffix), properties={"j1": random_string()}, ) - edge = Relationship(source=node0, - target=node1, - type="Links", - properties={"j": random_json()}) + edge = Relationship( + source=node0, target=node1, type="Links", properties={"j": random_json()} + ) doc = GraphDocument( nodes=[node0, node1], @@ -543,40 +516,41 @@ def test_spanner_graph_schema_representation( ) graph.add_graph_documents([doc]) schema = json.loads(graph.get_schema) - node0_json_fields = sorted([ - p['name'] - for p in schema["Node properties per node label"][node0.type] - ]) - node1_json_fields = sorted([ - p['name'] - for p in schema["Node properties per node label"][node1.type] - ]) - edge_json_fields = sorted([ - p['name'] - for edge in schema["Edge properties per edge label"].values() - for p in edge - ]) - edge_patterns = sorted([ - pattern - for edge in schema["Possible edges per label"].values() - for pattern in edge - ]) - if use_flexible_schema: - assert node0_json_fields == ['id', 'j0', 'label', 'properties'] - assert node1_json_fields == ['id', 'j1', 'label', 'properties'] - assert edge_json_fields == [ - 'id', 'j', 'label', 'properties', 'target_id' + node0_json_fields = sorted( + [p["name"] for p in schema["Node properties per node label"][node0.type]] + ) + node1_json_fields = sorted( + [p["name"] for p in schema["Node properties per node label"][node1.type]] + ) + edge_json_fields = sorted( + [ + p["name"] + for edge in schema["Edge properties per edge label"].values() + for p in edge ] + ) + edge_patterns = sorted( + [ + pattern + for edge in schema["Possible edges per label"].values() + for pattern in edge + ] + ) + if use_flexible_schema: + assert node0_json_fields == ["id", "j0", "label", "properties"] + assert node1_json_fields == ["id", "j1", "label", "properties"] + assert edge_json_fields == ["id", "j", "label", "properties", "target_id"] assert edge_patterns == [ - '(:{src}) -[:{edge}]-> (:{dst})'.format(src=node0.type, - edge=edge.type, - dst=node1.type) + "(:{src}) -[:{edge}]-> (:{dst})".format( + src=node0.type, edge=edge.type, dst=node1.type + ) ] else: - assert node0_json_fields == ['id', 'j0'] - assert node1_json_fields == ['id', 'j1'] - assert edge_json_fields == ['id', 'j', 'target_id'] + assert node0_json_fields == ["id", "j0"] + assert node1_json_fields == ["id", "j1"] + assert edge_json_fields == ["id", "j", "target_id"] assert edge_patterns == [ - '(:{src}) -[:{src}_{edge}_{dst}]-> (:{dst})'.format( - src=node0.type, edge=edge.type, dst=node1.type) + "(:{src}) -[:{src}_{edge}_{dst}]-> (:{dst})".format( + src=node0.type, edge=edge.type, dst=node1.type + ) ]