diff --git a/camel/storages/__init__.py b/camel/storages/__init__.py index 7863d9ec6b..37f5e52835 100644 --- a/camel/storages/__init__.py +++ b/camel/storages/__init__.py @@ -13,6 +13,7 @@ # =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== from .graph_storages.base import BaseGraphStorage +from .graph_storages.nebula_graph import NebulaGraph from .graph_storages.neo4j_graph import Neo4jGraph from .key_value_storages.base import BaseKeyValueStorage from .key_value_storages.in_memory import InMemoryKeyValueStorage @@ -40,4 +41,5 @@ 'MilvusStorage', 'BaseGraphStorage', 'Neo4jGraph', + 'NebulaGraph', ] diff --git a/camel/storages/graph_storages/__init__.py b/camel/storages/graph_storages/__init__.py index ae829acccc..f5898411c3 100644 --- a/camel/storages/graph_storages/__init__.py +++ b/camel/storages/graph_storages/__init__.py @@ -14,10 +14,12 @@ from .base import BaseGraphStorage from .graph_element import GraphElement +from .nebula_graph import NebulaGraph from .neo4j_graph import Neo4jGraph __all__ = [ 'BaseGraphStorage', 'GraphElement', 'Neo4jGraph', + 'NebulaGraph', ] diff --git a/camel/storages/graph_storages/nebula_graph.py b/camel/storages/graph_storages/nebula_graph.py new file mode 100644 index 0000000000..7220305119 --- /dev/null +++ b/camel/storages/graph_storages/nebula_graph.py @@ -0,0 +1,547 @@ +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +# Licensed under the Apache License, Version 2.0 (the “License”); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an “AS IS” BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== + +import time +from typing import TYPE_CHECKING, Any, Dict, List, Tuple + +if TYPE_CHECKING: + from nebula3.data.ResultSet import ( # type: ignore[import-untyped] + ResultSet, + ) + from nebula3.gclient.net import ( # type: ignore[import-untyped] + ConnectionPool, + Session, + ) + +from camel.storages.graph_storages.base import BaseGraphStorage +from camel.storages.graph_storages.graph_element import ( + GraphElement, +) +from camel.utils.commons import dependencies_required + +MAX_RETRIES = 5 +RETRY_DELAY = 3 + + +class NebulaGraph(BaseGraphStorage): + @dependencies_required('nebula3') + def __init__( + self, host, username, password, space, port=9669, timeout=10000 + ): + r"""Initializes the NebulaGraph client. + + Args: + host (str): The host address of the NebulaGraph service. + username (str): The username for authentication. + password (str): The password for authentication. + space (str): The graph space to use. If it doesn't exist, a new + one will be created. + port (int, optional): The port number for the connection. + (default: :obj:`9669`) + timeout (int, optional): The connection timeout in milliseconds. 
+            (default: :obj:`10000`)
+        """
+        self.host = host
+        self.username = username
+        self.password = password
+        self.space = space
+        self.timeout = timeout
+        self.port = port
+        self.schema: str = ""
+        self.structured_schema: Dict[str, Any] = {}
+        self.connection_pool = self._init_connection_pool()
+        self.session = self._get_session()
+
+    def _init_connection_pool(self) -> "ConnectionPool":
+        r"""Initialize the connection pool.
+
+        Returns:
+            ConnectionPool: A connection pool instance.
+
+        Raises:
+            Exception: If the connection pool initialization fails.
+        """
+        from nebula3.Config import Config  # type: ignore[import-untyped]
+        from nebula3.gclient.net import ConnectionPool
+
+        config = Config()
+        config.max_connection_pool_size = 10
+        config.timeout = self.timeout
+
+        # Create the connection pool
+        connection_pool = ConnectionPool()
+
+        # Initialize the connection pool with Nebula Graph's address and port
+        if not connection_pool.init([(self.host, self.port)], config):
+            raise Exception("Failed to initialize the connection pool")
+
+        return connection_pool
+
+    def _get_session(self) -> "Session":
+        r"""Get a session from the connection pool.
+
+        Returns:
+            Session: A session object connected to NebulaGraph.
+
+        Raises:
+            Exception: If session creation or space usage fails.
+        """
+        session = self.connection_pool.get_session(
+            self.username, self.password
+        )
+        if not session:
+            raise Exception("Failed to create a session")
+
+        # Create the space if it does not exist, then switch the session to
+        # it. Space creation takes effect asynchronously, hence the retry
+        # loop below.
+        session.execute(
+            f"CREATE SPACE IF NOT EXISTS {self.space} "
+            "(vid_type=FIXED_STRING(30));"
+        )
+
+        for attempt in range(MAX_RETRIES):
+            res = session.execute(f"USE {self.space};")
+
+            if res.is_succeeded():
+                return session
+
+            if attempt < MAX_RETRIES - 1:
+                time.sleep(RETRY_DELAY)
+            else:
+                # Final attempt failed, raise an exception
+                raise Exception(
+                    f"Failed to use space `{self.space}` after "
+                    f"{MAX_RETRIES} attempts: {res.error_msg()}"
+                )
+
+    @property
+    def get_client(self) -> Any:
+        r"""Get the underlying graph storage client."""
+        return self.session
+
+    def query(self, query: str) -> "ResultSet":  # type:ignore[override]
+        r"""Execute a query on the graph store.
+
+        Args:
+            query (str): The nGQL query to be executed.
+
+        Returns:
+            ResultSet: The result set of the query execution.
+
+        Raises:
+            ValueError: If the query execution fails.
+        """
+        try:
+            # Execute the query on the current session
+            result_set = self.session.execute(query)
+            return result_set
+
+        except Exception as e:
+            raise ValueError(f"Query execution error: {e!s}")
+
+    def get_relationship_types(self) -> List[str]:
+        r"""Retrieve relationship types from the graph.
+
+        Returns:
+            List[str]: A list of relationship (edge) type names.
+        """
+        # Query all edge types
+        result = self.query('SHOW EDGES')
+        rel_types = []
+
+        # Extract relationship type names
+        for row in result.rows():
+            edge_name = row.values[0].get_sVal().decode('utf-8')
+            rel_types.append(edge_name)
+
+        return rel_types
+
+    def add_graph_elements(
+        self,
+        graph_elements: List[GraphElement],
+    ) -> None:
+        r"""Add graph elements (nodes and relationships) to the graph.
+
+        Args:
+            graph_elements (List[GraphElement]): A list of graph elements
+                containing nodes and relationships.
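+
+        Example:
+            An illustrative call on an initialized :obj:`NebulaGraph`
+            instance ``graph`` (``Node``, ``Relationship`` and
+            ``GraphElement`` come from
+            ``camel.storages.graph_storages.graph_element``; ``source`` is
+            a placeholder for the element the graph was extracted from):
+
+            .. code-block:: python
+
+                camel = Node(id="CAMEL_AI", type="Agent_Framework")
+                nebula = Node(id="Nebula", type="Graph_Database")
+                rel = Relationship(subj=camel, obj=nebula, type="Supporting")
+                graph.add_graph_elements(
+                    [
+                        GraphElement(
+                            nodes=[camel, nebula],
+                            relationships=[rel],
+                            source=source,
+                        )
+                    ]
+                )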
+        """
+        nodes = self._extract_nodes(graph_elements)
+        for node in nodes:
+            self.add_node(node['id'], node['type'])
+
+        relationships = self._extract_relationships(graph_elements)
+        for rel in relationships:
+            self.add_triplet(rel['subj']['id'], rel['obj']['id'], rel['type'])
+
+    def ensure_edge_type_exists(
+        self,
+        edge_type: str,
+    ) -> None:
+        r"""Ensures that a specified edge type exists in the NebulaGraph
+        database. If the edge type already exists, this method does nothing.
+
+        Args:
+            edge_type (str): The name of the edge type to be created.
+
+        Raises:
+            Exception: If the edge type creation fails after multiple retry
+                attempts, an exception is raised with the error message.
+        """
+        create_edge_stmt = f'CREATE EDGE IF NOT EXISTS {edge_type}()'
+
+        for attempt in range(MAX_RETRIES):
+            res = self.query(create_edge_stmt)
+            if res.is_succeeded():
+                return  # Edge type creation succeeded, exit the method
+
+            if attempt < MAX_RETRIES - 1:
+                time.sleep(RETRY_DELAY)
+            else:
+                # Final attempt failed, raise an exception
+                raise Exception(
+                    f"Failed to create edge type `{edge_type}` after "
+                    f"{MAX_RETRIES} attempts: {res.error_msg()}"
+                )
+
+    def ensure_tag_exists(self, tag_name: str) -> None:
+        r"""Ensures a tag is created in the NebulaGraph database. If the tag
+        already exists, it does nothing.
+
+        Args:
+            tag_name (str): The name of the tag to be created.
+
+        Raises:
+            Exception: If the tag creation fails after retries, an exception
+                is raised with the error message.
+        """
+
+        create_tag_stmt = f'CREATE TAG IF NOT EXISTS {tag_name}()'
+
+        for attempt in range(MAX_RETRIES):
+            res = self.query(create_tag_stmt)
+            if res.is_succeeded():
+                return  # Tag creation succeeded, exit the method
+
+            if attempt < MAX_RETRIES - 1:
+                time.sleep(RETRY_DELAY)
+            else:
+                # Final attempt failed, raise an exception
+                raise Exception(
+                    f"Failed to create tag `{tag_name}` after "
+                    f"{MAX_RETRIES} attempts: {res.error_msg()}"
+                )
+
+    def add_node(
+        self,
+        node_id: str,
+        tag_name: str,
+    ) -> None:
+        r"""Add a node with the specified tag. The node is inserted without
+        properties.
+
+        Args:
+            node_id (str): The ID of the node.
+            tag_name (str): The tag name of the node.
+        """
+        self.ensure_tag_exists(tag_name)
+
+        # Insert node without properties
+        insert_stmt = (
+            f'INSERT VERTEX IF NOT EXISTS {tag_name}() VALUES "{node_id}":()'
+        )
+
+        for attempt in range(MAX_RETRIES):
+            res = self.query(insert_stmt)
+            if res.is_succeeded():
+                return  # Node insertion succeeded, exit the method
+
+            if attempt < MAX_RETRIES - 1:
+                time.sleep(RETRY_DELAY)
+            else:
+                # Final attempt failed, raise an exception
+                raise Exception(
+                    f"Failed to add node `{node_id}` after"
+                    f" {MAX_RETRIES} attempts: {res.error_msg()}"
+                )
+
+    def _extract_nodes(self, graph_elements: List[Any]) -> List[Dict]:
+        r"""Extracts unique nodes from graph elements.
+
+        Args:
+            graph_elements (List[Any]): A list of graph elements containing
+                nodes.
+
+        Returns:
+            List[Dict]: A list of dictionaries representing nodes.
+        """
+        nodes = []
+        seen_nodes = set()
+        for graph_element in graph_elements:
+            for node in graph_element.nodes:
+                node_key = (node.id, node.type)
+                if node_key not in seen_nodes:
+                    nodes.append(
+                        {
+                            'id': node.id,
+                            'type': node.type,
+                            'properties': node.properties,
+                        }
+                    )
+                    seen_nodes.add(node_key)
+        return nodes
+
+    def _extract_relationships(self, graph_elements: List[Any]) -> List[Dict]:
+        r"""Extracts relationships from graph elements.
+
+        Args:
+            graph_elements (List[Any]): A list of graph elements containing
+                relationships.
+
+        Returns:
+            List[Dict]: A list of dictionaries representing relationships.
+        """
+        relationships = []
+        for graph_element in graph_elements:
+            for rel in graph_element.relationships:
+                relationship_dict = {
+                    'subj': {'id': rel.subj.id, 'type': rel.subj.type},
+                    'obj': {'id': rel.obj.id, 'type': rel.obj.type},
+                    'type': rel.type,
+                }
+                relationships.append(relationship_dict)
+        return relationships
+
+    def refresh_schema(self) -> None:
+        r"""Refreshes the schema by fetching the latest schema details."""
+        self.schema = self.get_schema()
+        self.structured_schema = self.get_structured_schema
+
+    @property
+    def get_structured_schema(self) -> Dict[str, Any]:
+        r"""Generates a structured schema consisting of node and relationship
+        properties, relationships, and metadata.
+
+        Returns:
+            Dict[str, Any]: A dictionary representing the structured schema.
+        """
+        _, node_properties = self.get_node_properties()
+        _, rel_properties = self.get_relationship_properties()
+        relationships = self.get_relationship_types()
+        index = self.get_indexes()
+
+        # Build structured_schema
+        structured_schema = {
+            "node_props": {
+                el["labels"]: el["properties"] for el in node_properties
+            },
+            "rel_props": {
+                el["type"]: el["properties"] for el in rel_properties
+            },
+            "relationships": relationships,
+            "metadata": {"index": index},
+        }
+
+        return structured_schema
+
+    def get_schema(self):
+        r"""Generates a schema string describing node and relationship
+        properties and relationships.
+
+        Returns:
+            str: A string describing the schema.
+        """
+        # Get all node and relationship properties
+        formatted_node_props, _ = self.get_node_properties()
+        formatted_rel_props, _ = self.get_relationship_properties()
+        formatted_rels = self.get_relationship_types()
+
+        # Generate schema string
+        schema = "\n".join(
+            [
+                "Node properties are the following:",
+                ", ".join(formatted_node_props),
+                "Relationship properties are the following:",
+                ", ".join(formatted_rel_props),
+                "The relationships are the following:",
+                ", ".join(formatted_rels),
+            ]
+        )
+
+        return schema
+
+    def get_indexes(self):
+        r"""Fetches the tag indexes from the database.
+
+        Returns:
+            List[str]: A list of tag index names.
+        """
+        result = self.query('SHOW TAG INDEXES')
+        indexes = []
+
+        # Get tag indexes
+        for row in result.rows():
+            index_name = row.values[0].get_sVal().decode('utf-8')
+            indexes.append(index_name)
+
+        return indexes
+
+    def add_triplet(
+        self,
+        subj: str,
+        obj: str,
+        rel: str,
+    ) -> None:
+        r"""Adds a relationship (triplet) between two entities in the Nebula
+        Graph database.
+
+        Args:
+            subj (str): The identifier for the subject entity.
+            obj (str): The identifier for the object entity.
+            rel (str): The relationship between the subject and object.
+        """
+        self.ensure_tag_exists(subj)
+        self.ensure_tag_exists(obj)
+        self.ensure_edge_type_exists(rel)
+        self.add_node(node_id=subj, tag_name=subj)
+        self.add_node(node_id=obj, tag_name=obj)
+
+        # Pause briefly so the schema changes above have propagated before
+        # the edge is inserted
+        time.sleep(1)
+
+        insert_stmt = (
+            f'INSERT EDGE IF NOT EXISTS {rel}() VALUES "{subj}"->"{obj}":();'
+        )
+
+        res = self.query(insert_stmt)
+        if not res.is_succeeded():
+            raise Exception(
+                f'Failed to create relationship `{subj}` -> `{obj}`: '
+                f'{res.error_msg()}'
+            )
+
+    def delete_triplet(self, subj: str, obj: str, rel: str) -> None:
+        r"""Deletes a specific triplet (relationship between two entities)
+        from the Nebula Graph database.
+
+        Args:
+            subj (str): The identifier for the subject entity.
+            obj (str): The identifier for the object entity.
+            rel (str): The relationship between the subject and object.
+        """
+        delete_edge_query = f'DELETE EDGE {rel} "{subj}"->"{obj}";'
+        self.query(delete_edge_query)
+
+        if not self._check_edges(subj):
+            self.delete_entity(subj)
+        if not self._check_edges(obj):
+            self.delete_entity(obj)
+
+    def delete_entity(self, entity_id: str) -> None:
+        r"""Deletes an entity (vertex) from the graph.
+
+        Args:
+            entity_id (str): The identifier of the entity to be deleted.
+        """
+        delete_vertex_query = f'DELETE VERTEX "{entity_id}";'
+        self.query(delete_vertex_query)
+
+    def _check_edges(self, entity_id: str) -> bool:
+        r"""Checks if an entity has any remaining edges in the graph.
+
+        Args:
+            entity_id (str): The identifier of the entity.
+
+        Returns:
+            bool: :obj:`True` if the entity has edges, :obj:`False` otherwise.
+        """
+        # Combine the outgoing and incoming edge count queries; the string
+        # VID must be quoted in nGQL
+        check_query = f"""
+        (GO FROM "{entity_id}" OVER * YIELD count(*) as out_count)
+        UNION
+        (GO FROM "{entity_id}" REVERSELY OVER * YIELD count(*) as in_count)
+        """
+
+        # Execute the query
+        result = self.query(check_query)
+
+        # Check if the result contains non-zero edges
+        if result.is_succeeded():
+            rows = result.rows()
+            total_count = sum(int(row.values[0].get_iVal()) for row in rows)
+            return total_count > 0
+        else:
+            return False
+
+    def get_node_properties(self) -> Tuple[List[str], List[Dict[str, Any]]]:
+        r"""Retrieve node properties from the graph.
+
+        Returns:
+            Tuple[List[str], List[Dict[str, Any]]]: A tuple where the first
+                element is a list of node schema properties, and the second
+                element is a list of dictionaries representing node
+                structures.
+        """
+        # Query all tags
+        result = self.query('SHOW TAGS')
+        node_schema_props = []
+        node_structure_props = []
+
+        # Iterate through each tag to get its properties
+        for row in result.rows():
+            tag_name = row.values[0].get_sVal().decode('utf-8')
+            describe_result = self.query(f'DESCRIBE TAG {tag_name}')
+            properties = []
+
+            for prop_row in describe_result.rows():
+                prop_name = prop_row.values[0].get_sVal().decode('utf-8')
+                node_schema_props.append(f"{tag_name}.{prop_name}")
+                properties.append(prop_name)
+
+            node_structure_props.append(
+                {"labels": tag_name, "properties": properties}
+            )
+
+        return node_schema_props, node_structure_props
+
+    def get_relationship_properties(
+        self,
+    ) -> Tuple[List[str], List[Dict[str, Any]]]:
+        r"""Retrieve relationship (edge) properties from the graph.
+
+        Returns:
+            Tuple[List[str], List[Dict[str, Any]]]: A tuple where the first
+                element is a list of relationship schema properties, and the
+                second element is a list of dictionaries representing
+                relationship structures.
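+
+        Example:
+            For a single edge type ``Edge1`` with one property ``prop1``
+            (illustrative names), the returned tuple has the shape::
+
+                (['Edge1.prop1'],
+                 [{'type': 'Edge1', 'properties': ['prop1']}])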
+        """
+
+        # Query all edge types
+        result = self.query('SHOW EDGES')
+        rel_schema_props = []
+        rel_structure_props = []
+
+        # Iterate through each edge type to get its properties
+        for row in result.rows():
+            edge_name = row.values[0].get_sVal().decode('utf-8')
+            describe_result = self.query(f'DESCRIBE EDGE {edge_name}')
+            properties = []
+
+            for prop_row in describe_result.rows():
+                prop_name = prop_row.values[0].get_sVal().decode('utf-8')
+                rel_schema_props.append(f"{edge_name}.{prop_name}")
+                properties.append(prop_name)
+
+            rel_structure_props.append(
+                {"type": edge_name, "properties": properties}
+            )
+
+        return rel_schema_props, rel_structure_props
diff --git a/examples/storages/nebular_graph.py b/examples/storages/nebular_graph.py
new file mode 100644
index 0000000000..1c00f105d0
--- /dev/null
+++ b/examples/storages/nebular_graph.py
@@ -0,0 +1,104 @@
+# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+# Licensed under the Apache License, Version 2.0 (the “License”);
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an “AS IS” BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
+
+from unstructured.documents.elements import Element
+
+from camel.storages.graph_storages import NebulaGraph
+from camel.storages.graph_storages.graph_element import (
+    GraphElement,
+    Node,
+    Relationship,
+)
+
+# Initialize the NebulaGraph client
+host = '127.0.0.1'
+username = 'root'
+password = 'nebula'
+space = 'space_name'
+
+nebula_graph = NebulaGraph(host, username, password, space)
+
+# Ensure necessary tags (node types) exist
+nebula_graph.ensure_tag_exists("CAMEL_AI")
+nebula_graph.ensure_tag_exists("Agent_Framework")
+
+# Show existing tags
+query = 'SHOW TAGS;'
+print(nebula_graph.query(query))
+
+"""
+==============================================================================
+ResultSet(keys: ['Name'], values: ["CAMEL_AI"],["Agent_Framework"])
+==============================================================================
+"""
+
+# Add triplet
+nebula_graph.add_triplet(
+    subj="CAMEL_AI", obj="Agent_Framework", rel="contribute_to"
+)
+
+# Check structured schema
+print(nebula_graph.get_structured_schema)
+
+"""
+==============================================================================
+{'node_props': {'CAMEL_AI': [], 'Agent_Framework': []}, 'rel_props':
+{'contribute_to': []}, 'relationships': ['contribute_to'], 'metadata':
+{'index': []}}
+==============================================================================
+"""
+
+# Delete triplet
+nebula_graph.delete_triplet(
+    subj="CAMEL_AI", obj="Agent_Framework", rel="contribute_to"
+)
+
+# Create and add graph element
+node_camel = Node(
+    id="CAMEL_AI",
+    type="Agent_Framework",
+)
+node_nebula = Node(
+    id="Nebula",
+    type="Graph_Database",
+)
+
+graph_elements = [
+    GraphElement(
+        nodes=[node_camel, node_nebula],
+        relationships=[
+            Relationship(
+                subj=node_camel,
+                obj=node_nebula,
+                type="Supporting",
+            )
+        ],
+        source=Element(element_id="a05b820b51c760a41415c57c1eef8f08"),
+    )
+]
+
+# Add this graph element to graph db
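+# (this extracts each element's nodes and relationships and inserts them
+# via `add_node` / `add_triplet` under the hood)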
+nebula_graph.add_graph_elements(graph_elements) + +# Get structured schema +print(nebula_graph.get_structured_schema) + +""" +============================================================================== +{'node_props': {'Agent_Framework': [], 'CAMEL_AI': [], 'Graph_Database': [], +'Nebula': [], 'agent_framework': []}, 'rel_props': {'Supporting': [], +'contribute_to': []}, 'relationships': ['Supporting', 'contribute_to'], +'metadata': {'index': []}} +============================================================================== +""" diff --git a/poetry.lock b/poetry.lock index c51e00312e..ecc531c8b9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "accelerate" @@ -564,13 +564,13 @@ css = ["tinycss2 (>=1.1.0,<1.3)"] [[package]] name = "botocore" -version = "1.35.19" +version = "1.35.20" description = "Low-level, data-driven core of boto 3." optional = true python-versions = ">=3.8" files = [ - {file = "botocore-1.35.19-py3-none-any.whl", hash = "sha256:c83f7f0cacfe7c19b109b363ebfa8736e570d24922f16ed371681f58ebab44a9"}, - {file = "botocore-1.35.19.tar.gz", hash = "sha256:42d6d8db7250cbd7899f786f9861e02cab17dc238f64d6acb976098ed9809625"}, + {file = "botocore-1.35.20-py3-none-any.whl", hash = "sha256:62412038f960691a299e60492f9ee7e8e75af563f2eca7f3640b3b54b8f5d236"}, + {file = "botocore-1.35.20.tar.gz", hash = "sha256:82ad8a73fcd5852d127461c8dadbe40bf679f760a4efb0dde8d4d269ad3f126f"}, ] [package.dependencies] @@ -1281,13 +1281,13 @@ dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"] [[package]] name = "diffusers" -version = "0.30.2" +version = "0.30.3" description = "State-of-the-art diffusion in PyTorch and JAX." 
optional = true python-versions = ">=3.8.0" files = [ - {file = "diffusers-0.30.2-py3-none-any.whl", hash = "sha256:739826043147c2b59560944591dfdea5d24cd4fb15e751abbe20679a289bece8"}, - {file = "diffusers-0.30.2.tar.gz", hash = "sha256:641875f78f36bdfa4b9af752b124d1fd6d431eadd5547fe0a3f354ae0af2636c"}, + {file = "diffusers-0.30.3-py3-none-any.whl", hash = "sha256:1b70209e4d2c61223b96a7e13bc4d70869c8b0b68f54a35ce3a67fcf813edeee"}, + {file = "diffusers-0.30.3.tar.gz", hash = "sha256:67c5eb25d5b50bf0742624ef43fe0f6d1e1604f64aad3e8558469cbe89ecf72f"}, ] [package.dependencies] @@ -1916,6 +1916,17 @@ test-downstream = ["aiobotocore (>=2.5.4,<3.0.0)", "dask-expr", "dask[dataframe, test-full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "cloudpickle", "dask", "distributed", "dropbox", "dropboxdrivefs", "fastparquet", "fusepy", "gcsfs", "jinja2", "kerchunk", "libarchive-c", "lz4", "notebook", "numpy", "ocifs", "pandas", "panel", "paramiko", "pyarrow", "pyarrow (>=1)", "pyftpdlib", "pygit2", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "python-snappy", "requests", "smbprotocol", "tqdm", "urllib3", "zarr", "zstandard"] tqdm = ["tqdm"] +[[package]] +name = "future" +version = "1.0.0" +description = "Clean single-source support for Python 3 and 2" +optional = true +python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" +files = [ + {file = "future-1.0.0-py3-none-any.whl", hash = "sha256:929292d34f5872e70396626ef385ec22355a1fae8ad29e1a734c3e43f9fbc216"}, + {file = "future-1.0.0.tar.gz", hash = "sha256:bd2968309307861edae1458a4f8a4f3598c03be43b97521076aebf5d94c07b05"}, +] + [[package]] name = "geojson" version = "2.5.0" @@ -2582,13 +2593,13 @@ license = ["ukkonen"] [[package]] name = "idna" -version = "3.9" +version = "3.10" description = "Internationalized Domain Names in Applications (IDNA)" optional = false python-versions = ">=3.6" files = [ - {file = "idna-3.9-py3-none-any.whl", hash = "sha256:69297d5da0cc9281c77efffb4e730254dd45943f45bbfb461de5991713989b1e"}, - {file = "idna-3.9.tar.gz", hash = "sha256:e5c5dafde284f26e9e0f28f6ea2d6400abd5ca099864a67f576f3981c6476124"}, + {file = "idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"}, + {file = "idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"}, ] [package.extras] @@ -3012,13 +3023,13 @@ referencing = ">=0.31.0" [[package]] name = "jupyter-client" -version = "8.6.2" +version = "8.6.3" description = "Jupyter protocol implementation and client libraries" optional = false python-versions = ">=3.8" files = [ - {file = "jupyter_client-8.6.2-py3-none-any.whl", hash = "sha256:50cbc5c66fd1b8f65ecb66bc490ab73217993632809b6e505687de18e9dea39f"}, - {file = "jupyter_client-8.6.2.tar.gz", hash = "sha256:2bda14d55ee5ba58552a8c53ae43d215ad9868853489213f37da060ced54d8df"}, + {file = "jupyter_client-8.6.3-py3-none-any.whl", hash = "sha256:e8a19cc986cc45905ac3362915f410f3af85424b4c0905e94fa5f2cb08e8f23f"}, + {file = "jupyter_client-8.6.3.tar.gz", hash = "sha256:35b3a0947c4a6e9d589eb97d7d4cd5e90f910ee73101611f01283732bd6d9419"}, ] [package.dependencies] @@ -3282,13 +3293,13 @@ files = [ [[package]] name = "litellm" -version = "1.46.0" +version = "1.46.1" description = "Library to easily interface with LLM API providers" optional = true python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8" files = [ - {file = 
"litellm-1.46.0-py3-none-any.whl", hash = "sha256:40209dc6368677d03b21b2c9d9cb91937c9648f741d42bb5a8f992a1cd31fb42"}, - {file = "litellm-1.46.0.tar.gz", hash = "sha256:6707eb4b17a2eca714f81261c3b6f33297cd25470c4843b8297e345ebdff0560"}, + {file = "litellm-1.46.1-py3-none-any.whl", hash = "sha256:f6b78278cf21a38da0d10a8b3e7b1084b6410012552c0a413774d1c43706e5ba"}, + {file = "litellm-1.46.1.tar.gz", hash = "sha256:993c23d6f5e1d0f070b250d858a6ee87750a032e38f460f8c82385be854bc45f"}, ] [package.dependencies] @@ -4022,13 +4033,13 @@ testing-docutils = ["pygments", "pytest (>=8,<9)", "pytest-param-files (>=0.6.0, [[package]] name = "narwhals" -version = "1.8.0" +version = "1.8.1" description = "Extremely lightweight compatibility layer between dataframe libraries" optional = false python-versions = ">=3.8" files = [ - {file = "narwhals-1.8.0-py3-none-any.whl", hash = "sha256:73bde7b1721e1d95f749f6aacec12bc616fcccae5a926064d641c93e133fd548"}, - {file = "narwhals-1.8.0.tar.gz", hash = "sha256:b1572b8781273e5712ee76144b9b6f412f2b71f39d63053853322cc98201eeaa"}, + {file = "narwhals-1.8.1-py3-none-any.whl", hash = "sha256:91a3af813733df39a74f590fdd1bb0d2d6d8a33e32aa409f56d941c0a29f8cdd"}, + {file = "narwhals-1.8.1.tar.gz", hash = "sha256:97527778e11f39a1e5e2113b8fbb9ead788be41c0337f21852e684e378f583e8"}, ] [package.extras] @@ -4138,6 +4149,24 @@ nbformat = "*" sphinx = ">=1.8" traitlets = ">=5" +[[package]] +name = "nebula3-python" +version = "3.8.2" +description = "Python client for NebulaGraph v3" +optional = true +python-versions = ">=3.6.2" +files = [ + {file = "nebula3_python-3.8.2-py3-none-any.whl", hash = "sha256:8942ef87619f05115f643896408f8cbe602670405a3aeab01fdcc454eeabf0d7"}, + {file = "nebula3_python-3.8.2.tar.gz", hash = "sha256:889df21bac0f7ccad1d3a1807d9b736b2136770b24ed03d4fd49b76b3e2612ea"}, +] + +[package.dependencies] +future = ">=0.18.0" +httplib2 = ">=0.20.0" +httpx = {version = ">=0.22.0", extras = ["http2"]} +pytz = ">=2021.1" +six = ">=1.16.0" + [[package]] name = "neo4j" version = "5.24.0" @@ -4567,13 +4596,13 @@ sympy = "*" [[package]] name = "openai" -version = "1.45.0" +version = "1.45.1" description = "The official Python library for the openai API" optional = false python-versions = ">=3.7.1" files = [ - {file = "openai-1.45.0-py3-none-any.whl", hash = "sha256:2f1f7b7cf90f038a9f1c24f0d26c0f1790c102ec5acd07ffd70a9b7feac1ff4e"}, - {file = "openai-1.45.0.tar.gz", hash = "sha256:731207d10637335413aa3c0955f8f8df30d7636a4a0f9c381f2209d32cf8de97"}, + {file = "openai-1.45.1-py3-none-any.whl", hash = "sha256:4a6cce402aec803ae57ae7eff4b5b94bf6c0e1703a8d85541c27243c2adeadf8"}, + {file = "openai-1.45.1.tar.gz", hash = "sha256:f79e384916b219ab2f028bbf9c778e81291c61eb0645ccfa1828a4b18b55d534"}, ] [package.dependencies] @@ -5200,19 +5229,19 @@ virtualenv = ">=20.10.0" [[package]] name = "primp" -version = "0.6.1" +version = "0.6.2" description = "HTTP client that can impersonate web browsers, mimicking their headers and `TLS/JA3/JA4/HTTP2` fingerprints" optional = true python-versions = ">=3.8" files = [ - {file = "primp-0.6.1-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:60cfe95e0bdf154b0f9036d38acaddc9aef02d6723ed125839b01449672d3946"}, - {file = "primp-0.6.1-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:e1e92433ecf32639f9e800bc3a5d58b03792bdec99421b7fb06500e2fae63c85"}, - {file = "primp-0.6.1-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e02353f13f07fb5a6f91df9e2f4d8ec9f41312de95088744dce1c9729a3865d"}, - {file = 
"primp-0.6.1-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:c5a2ccfdf488b17be225a529a31e2b22724b2e22fba8e1ae168a222f857c2dc0"}, - {file = "primp-0.6.1-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:f335c2ace907800a23bbb7bc6e15acc7fff659b86a2d5858817f6ed79cea07cf"}, - {file = "primp-0.6.1-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:5dc15bd9d47ded7bc356fcb5d8321972dcbeba18e7d3b7250e12bb7365447b2b"}, - {file = "primp-0.6.1-cp38-abi3-win_amd64.whl", hash = "sha256:eebf0412ebba4089547b16b97b765d83f69f1433d811bb02b02cdcdbca20f672"}, - {file = "primp-0.6.1.tar.gz", hash = "sha256:64b3c12e3d463a887518811c46f3ec37cca02e6af1ddf1287e548342de436301"}, + {file = "primp-0.6.2-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:4a35d441462a55d9a9525bf170e2ffd2fcb3db6039b23e802859fa22c18cdd51"}, + {file = "primp-0.6.2-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:f67ccade95bdbca3cf9b96b93aa53f9617d85ddbf988da4e9c523aa785fd2d54"}, + {file = "primp-0.6.2-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8074b93befaf36567e4cf3d4a1a8cd6ab9cc6e4dd4ff710650678daa405aee71"}, + {file = "primp-0.6.2-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:7d3e2a3f8c6262e9b883651b79c4ff2b7677a76f47293a139f541c9ea333ce3b"}, + {file = "primp-0.6.2-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:a460ea389371c6d04839b4b50b5805d99da8ebe281a2e8b534d27377c6d44f0e"}, + {file = "primp-0.6.2-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:5b6b27e89d3c05c811aff0e4fde7a36d6957b15b3112f4ce28b6b99e8ca1e725"}, + {file = "primp-0.6.2-cp38-abi3-win_amd64.whl", hash = "sha256:1006a40a85f88a4c5222094813a1ebc01f85a63e9a33d2c443288c0720bed321"}, + {file = "primp-0.6.2.tar.gz", hash = "sha256:5a96a6b65195a8a989157e67d23bd171c49be238654e02bdf1b1fda36cbcc068"}, ] [package.extras] @@ -5380,7 +5409,6 @@ description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs optional = true python-versions = ">=3.8" files = [ - {file = "pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629"}, {file = "pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034"}, ] @@ -5391,7 +5419,6 @@ description = "A collection of ASN.1-based protocols modules" optional = true python-versions = ">=3.8" files = [ - {file = "pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd"}, {file = "pyasn1_modules-0.4.1.tar.gz", hash = "sha256:c28e2dbf9c06ad61c71a075c7e0f9fd0f1b0bb2d2ad4377f240d33ac2ab60a7c"}, ] @@ -6285,13 +6312,13 @@ cffi = {version = "*", markers = "implementation_name == \"pypy\""} [[package]] name = "qdrant-client" -version = "1.11.1" +version = "1.11.2" description = "Client library for the Qdrant vector search engine" optional = true python-versions = ">=3.8" files = [ - {file = "qdrant_client-1.11.1-py3-none-any.whl", hash = "sha256:1375fad77c825c957181ff53775fb900c4383e817f864ea30b2605314da92f07"}, - {file = "qdrant_client-1.11.1.tar.gz", hash = "sha256:bfc23239b027073352ad92152209ec50281519686b7da3041612faece0fcdfbd"}, + {file = "qdrant_client-1.11.2-py3-none-any.whl", hash = "sha256:3151e3da61588ad138dfcd6760c2f13e57251c8b0c62001bfd0e03bb7bcd6c8e"}, + {file = "qdrant_client-1.11.2.tar.gz", hash = "sha256:0d5aa3f778077762963a754459c9c7144ba48e13dea62e559323924126a1b4a4"}, ] [package.dependencies] @@ -7228,18 +7255,18 @@ files = [ [[package]] name = "setuptools" -version = "74.1.2" +version = "75.1.0" description = 
"Easily download, build, install, upgrade, and uninstall Python packages" optional = true python-versions = ">=3.8" files = [ - {file = "setuptools-74.1.2-py3-none-any.whl", hash = "sha256:5f4c08aa4d3ebcb57a50c33b1b07e94315d7fc7230f7115e47fc99776c8ce308"}, - {file = "setuptools-74.1.2.tar.gz", hash = "sha256:95b40ed940a1c67eb70fc099094bd6e99c6ee7c23aa2306f4d2697ba7916f9c6"}, + {file = "setuptools-75.1.0-py3-none-any.whl", hash = "sha256:35ab7fd3bcd95e6b7fd704e4a1539513edad446c097797f2985e0e4b960772f2"}, + {file = "setuptools-75.1.0.tar.gz", hash = "sha256:d59a21b17a275fb872a9c3dae73963160ae079f1049ed956880cd7c09b120538"}, ] [package.extras] check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.5.2)"] -core = ["importlib-metadata (>=6)", "importlib-resources (>=5.10.2)", "jaraco.text (>=3.7)", "more-itertools (>=8.8)", "packaging (>=24)", "platformdirs (>=2.6.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] +core = ["importlib-metadata (>=6)", "importlib-resources (>=5.10.2)", "jaraco.collections", "jaraco.functools", "jaraco.text (>=3.7)", "more-itertools", "more-itertools (>=8.8)", "packaging", "packaging (>=24)", "platformdirs (>=2.6.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] cover = ["pytest-cov"] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"] enabler = ["pytest-enabler (>=2.2)"] @@ -7269,13 +7296,13 @@ files = [ [[package]] name = "slack-sdk" -version = "3.32.0" +version = "3.33.0" description = "The Slack API Platform SDK for Python" optional = true python-versions = ">=3.6" files = [ - {file = "slack_sdk-3.32.0-py2.py3-none-any.whl", hash = "sha256:f35e85f2847e6c25cf7c2d1df206ca0ad75556263fb592457bf03cca68ef64bb"}, - {file = "slack_sdk-3.32.0.tar.gz", hash = "sha256:af8fc4ef1d1cbcecd28d01acf6955a3bb5b13d56f0a43a1b1c7e3b212cc5ec5b"}, + {file = "slack_sdk-3.33.0-py2.py3-none-any.whl", hash = "sha256:853bb55154115d080cae342c4099f2ccb559a78ae8d0f5109b49842401a920fa"}, + {file = "slack_sdk-3.33.0.tar.gz", hash = "sha256:070eb1fb355c149a5f80fa0be6eeb5f5588e4ddff4dd76acf060454435cb037e"}, ] [package.extras] @@ -8183,13 +8210,13 @@ urllib3 = ">=2" [[package]] name = "types-setuptools" -version = "74.1.0.20240907" +version = "75.1.0.20240917" description = "Typing stubs for setuptools" optional = false python-versions = ">=3.8" files = [ - {file = "types-setuptools-74.1.0.20240907.tar.gz", hash = "sha256:0abdb082552ca966c1e5fc244e4853adc62971f6cd724fb1d8a3713b580e5a65"}, - {file = "types_setuptools-74.1.0.20240907-py3-none-any.whl", hash = "sha256:15b38c8e63ca34f42f6063ff4b1dd662ea20086166d5ad6a102e670a52574120"}, + {file = "types-setuptools-75.1.0.20240917.tar.gz", hash = "sha256:12f12a165e7ed383f31def705e5c0fa1c26215dd466b0af34bd042f7d5331f55"}, + {file = "types_setuptools-75.1.0.20240917-py3-none-any.whl", hash = "sha256:06f78307e68d1bbde6938072c57b81cf8a99bc84bd6dc7e4c5014730b097dc0c"}, ] [[package]] @@ -9071,9 +9098,9 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", type = ["pytest-mypy"] [extras] -all = ["PyMuPDF", "accelerate", "agentops", "azure-storage-blob", "beautifulsoup4", "botocore", "cohere", "datasets", "diffusers", "discord.py", "docker", "docx2txt", "duckduckgo-search", "firecrawl-py", "google-cloud-storage", "google-generativeai", "googlemaps", 
"imageio", "jupyter_client", "litellm", "mistralai", "neo4j", "newspaper3k", "nltk", "openapi-spec-validator", "opencv-python", "pillow", "prance", "praw", "pyTelegramBotAPI", "pydub", "pygithub", "pymilvus", "pyowm", "qdrant-client", "rank-bm25", "redis", "reka-api", "requests_oauthlib", "sentence-transformers", "sentencepiece", "slack-sdk", "soundfile", "textblob", "torch", "transformers", "unstructured", "wikipedia", "wolframalpha"] +all = ["PyMuPDF", "accelerate", "agentops", "azure-storage-blob", "beautifulsoup4", "botocore", "cohere", "datasets", "diffusers", "discord.py", "docker", "docx2txt", "duckduckgo-search", "firecrawl-py", "google-cloud-storage", "google-generativeai", "googlemaps", "imageio", "jupyter_client", "litellm", "mistralai", "nebula3-python", "neo4j", "newspaper3k", "nltk", "openapi-spec-validator", "opencv-python", "pillow", "prance", "praw", "pyTelegramBotAPI", "pydub", "pygithub", "pymilvus", "pyowm", "qdrant-client", "rank-bm25", "redis", "reka-api", "requests_oauthlib", "sentence-transformers", "sentencepiece", "slack-sdk", "soundfile", "textblob", "torch", "transformers", "unstructured", "wikipedia", "wolframalpha"] encoders = ["sentence-transformers"] -graph-storages = ["neo4j"] +graph-storages = ["nebula3-python", "neo4j"] huggingface-agent = ["accelerate", "datasets", "diffusers", "opencv-python", "sentencepiece", "soundfile", "torch", "transformers"] kv-stroages = ["redis"] model-platforms = ["google-generativeai", "litellm", "mistralai", "reka-api"] @@ -9086,4 +9113,4 @@ vector-databases = ["pymilvus", "qdrant-client"] [metadata] lock-version = "2.0" python-versions = ">=3.10.0,<3.12" -content-hash = "9b38b864079b2bf438a4a7b73b48cdc09ece0f9ca573edaa49cf3a5af762d34e" +content-hash = "03ee11047eab0683a44a057d74f7eb4f66ab2b2480439f523d2b4d2c9673bbe0" diff --git a/pyproject.toml b/pyproject.toml index de119e65b4..e398ede120 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,6 +101,7 @@ pymilvus = { version = "^2.4.0", optional = true } # graph-storages neo4j = { version = "^5.18.0", optional = true } +nebula3-python = { version = "3.8.2", optional = true } # key-value-storages redis = { version = "^5.0.6", optional = true } @@ -176,6 +177,7 @@ vector-databases = [ graph-storages = [ "neo4j", + "nebula3-python", ] kv-stroages = [ @@ -240,6 +242,7 @@ all = [ "sentence-transformers", # graph-storages "neo4j", + "nebula3-python", # retrievers "rank-bm25", # model platforms diff --git a/test/storages/graph_storages/test_nebula_graph.py b/test/storages/graph_storages/test_nebula_graph.py new file mode 100644 index 0000000000..d712cc8762 --- /dev/null +++ b/test/storages/graph_storages/test_nebula_graph.py @@ -0,0 +1,470 @@ +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +# Licensed under the Apache License, Version 2.0 (the “License”); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an “AS IS” BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. 
=========== +import unittest +from unittest.mock import Mock, patch + +from unstructured.documents.elements import Element + +from camel.storages import NebulaGraph +from camel.storages.graph_storages.graph_element import ( + GraphElement, + Node, + Relationship, +) + +MAX_RETRIES = 5 + + +class TestNebulaGraph(unittest.TestCase): + def setUp(self): + # Mock the dependencies and external interactions + self.host = 'localhost' + self.username = 'user' + self.password = 'pass' + self.space = 'test_space' + self.port = 9669 + self.timeout = 10000 + + # Patch the methods that interact with the database + patcher1 = patch.object(NebulaGraph, '_init_connection_pool') + patcher2 = patch.object(NebulaGraph, '_get_session') + self.addCleanup(patcher1.stop) + self.addCleanup(patcher2.stop) + self.mock_init_connection_pool = patcher1.start() + self.mock_get_session = patcher2.start() + + # Mock the return values + self.mock_connection_pool = Mock() + self.mock_session = Mock() + self.mock_init_connection_pool.return_value = self.mock_connection_pool + self.mock_get_session.return_value = self.mock_session + + # Initialize the NebulaGraph instance with the mocks + self.graph = NebulaGraph( + host=self.host, + username=self.username, + password=self.password, + space=self.space, + port=self.port, + timeout=self.timeout, + ) + + def test_query_success(self): + # Mock session.execute to return a successful result + mock_result_set = Mock() + self.mock_session.execute.return_value = mock_result_set + + query_str = 'SHOW SPACES;' + + result = self.graph.query(query_str) + self.mock_session.execute.assert_called_with(query_str) + self.assertEqual(result, mock_result_set) + + def test_query_exception(self): + # Mock session.execute to raise an exception + self.mock_session.execute.side_effect = Exception('Database error') + + query_str = 'INVALID QUERY;' + + with self.assertRaises(ValueError) as context: + self.graph.query(query_str) + self.assertIn('Query execution error', str(context.exception)) + + def test_get_relationship_types(self): + # Mock the query method + mock_result = Mock() + # Mock the rows returned + row1 = Mock() + row1.values = [Mock()] + row1.values[0].get_sVal.return_value = b'Relationship1' + row2 = Mock() + row2.values = [Mock()] + row2.values[0].get_sVal.return_value = b'Relationship2' + mock_result.rows.return_value = [row1, row2] + self.graph.query = Mock(return_value=mock_result) + + rel_types = self.graph.get_relationship_types() + self.graph.query.assert_called_with('SHOW EDGES') + self.assertEqual(rel_types, ['Relationship1', 'Relationship2']) + + def test_add_node(self): + node_id = 'node1' + tag_name = 'Tag1' + self.graph.ensure_tag_exists = Mock() + self.graph.query = Mock() + + self.graph.add_node(node_id, tag_name) + + self.graph.ensure_tag_exists.assert_called_with(tag_name) + insert_stmt = ( + f'INSERT VERTEX IF NOT EXISTS {tag_name}() VALUES "{node_id}":()' + ) + self.graph.query.assert_called_with(insert_stmt) + + def test_ensure_tag_exists_success(self): + tag_name = 'Tag1' + # Mock query to return a successful result + mock_result = Mock() + mock_result.is_succeeded.return_value = True + self.graph.query = Mock(return_value=mock_result) + + self.graph.ensure_tag_exists(tag_name) + + create_tag_stmt = f'CREATE TAG IF NOT EXISTS {tag_name}()' + self.graph.query.assert_called_with(create_tag_stmt) + + @patch('time.sleep', return_value=None) + def test_ensure_tag_exists_failure(self, mock_sleep): + tag_name = 'Tag1' + # Mock query to return a failed result every time + 
mock_result = Mock() + mock_result.is_succeeded.return_value = False + mock_result.error_msg.return_value = 'Error message' + self.graph.query = Mock(return_value=mock_result) + + with self.assertRaises(Exception) as context: + self.graph.ensure_tag_exists(tag_name) + + self.assertIn( + f"Failed to create tag `{tag_name}` after {MAX_RETRIES} attempts", + str(context.exception), + ) + self.assertEqual(self.graph.query.call_count, MAX_RETRIES) + + def test_add_triplet(self): + subj = 'node1' + obj = 'node2' + rel = 'RELATES_TO' + self.graph.ensure_tag_exists = Mock() + self.graph.ensure_edge_type_exists = Mock() + self.graph.add_node = Mock() + self.graph.query = Mock() + + self.graph.add_triplet(subj, obj, rel) + + self.graph.ensure_tag_exists.assert_any_call(subj) + self.graph.ensure_tag_exists.assert_any_call(obj) + self.graph.ensure_edge_type_exists.assert_called_with(rel) + self.graph.add_node.assert_any_call(node_id=subj, tag_name=subj) + self.graph.add_node.assert_any_call(node_id=obj, tag_name=obj) + insert_stmt = ( + f'INSERT EDGE IF NOT EXISTS {rel}() VALUES "{subj}"->"{obj}":();' + ) + self.graph.query.assert_called_with(insert_stmt) + + def test_delete_triplet(self): + subj = 'node1' + obj = 'node2' + rel = 'RELATES_TO' + self.graph.query = Mock() + self.graph._check_edges = Mock(side_effect=[False, False]) + self.graph.delete_entity = Mock() + + self.graph.delete_triplet(subj, obj, rel) + + delete_edge_query = f'DELETE EDGE {rel} "{subj}"->"{obj}";' + self.graph.query.assert_called_with(delete_edge_query) + self.graph._check_edges.assert_any_call(subj) + self.graph._check_edges.assert_any_call(obj) + self.graph.delete_entity.assert_any_call(subj) + self.graph.delete_entity.assert_any_call(obj) + + def test_check_edges_with_edges(self): + entity_id = 'node1' + mock_result = Mock() + mock_result.is_succeeded.return_value = True + # Mock rows with counts indicating edges exist + row_out = Mock() + row_out.values = [Mock()] + row_out.values[0].get_iVal.return_value = 2 + row_in = Mock() + row_in.values = [Mock()] + row_in.values[0].get_iVal.return_value = 3 + mock_result.rows.return_value = [row_out, row_in] + + self.graph.query = Mock(return_value=mock_result) + + has_edges = self.graph._check_edges(entity_id) + self.assertTrue(has_edges) + self.graph.query.assert_called() + + def test_check_edges_no_edges(self): + entity_id = 'node1' + mock_result = Mock() + mock_result.is_succeeded.return_value = True + # Mock rows with counts indicating no edges + row_out = Mock() + row_out.values = [Mock()] + row_out.values[0].get_iVal.return_value = 0 + row_in = Mock() + row_in.values = [Mock()] + row_in.values[0].get_iVal.return_value = 0 + mock_result.rows.return_value = [row_out, row_in] + + self.graph.query = Mock(return_value=mock_result) + + has_edges = self.graph._check_edges(entity_id) + self.assertFalse(has_edges) + self.graph.query.assert_called() + + def test_get_node_properties(self): + # Mock query for 'SHOW TAGS' + mock_show_tags_result = Mock() + mock_show_tags_result.is_succeeded.return_value = True + row1 = Mock() + row1.values = [Mock()] + row1.values[0].get_sVal.return_value = b'Tag1' + row2 = Mock() + row2.values = [Mock()] + row2.values[0].get_sVal.return_value = b'Tag2' + mock_show_tags_result.rows.return_value = [row1, row2] + + # Mock query for 'DESCRIBE TAG ' + mock_describe_tag_result = Mock() + mock_describe_tag_result.is_succeeded.return_value = True + prop_row = Mock() + prop_row.values = [Mock()] + prop_row.values[0].get_sVal.return_value = b'prop1' + 
mock_describe_tag_result.rows.return_value = [prop_row] + + self.graph.query = Mock( + side_effect=[ + mock_show_tags_result, + mock_describe_tag_result, + mock_describe_tag_result, + ] + ) + + node_schema_props, node_structure_props = ( + self.graph.get_node_properties() + ) + + expected_node_schema_props = ['Tag1.prop1', 'Tag2.prop1'] + expected_node_structure_props = [ + {'labels': 'Tag1', 'properties': ['prop1']}, + {'labels': 'Tag2', 'properties': ['prop1']}, + ] + + self.assertEqual(node_schema_props, expected_node_schema_props) + self.assertEqual(node_structure_props, expected_node_structure_props) + + def test_get_relationship_properties(self): + # Mock query for 'SHOW EDGES' + mock_show_edges_result = Mock() + mock_show_edges_result.is_succeeded.return_value = True + edge_row1 = Mock() + edge_row1.values = [Mock()] + edge_row1.values[0].get_sVal.return_value = b'Edge1' + edge_row2 = Mock() + edge_row2.values = [Mock()] + edge_row2.values[0].get_sVal.return_value = b'Edge2' + mock_show_edges_result.rows.return_value = [edge_row1, edge_row2] + + # Mock query for 'DESCRIBE EDGE ' + mock_describe_edge_result = Mock() + mock_describe_edge_result.is_succeeded.return_value = True + prop_row = Mock() + prop_row.values = [Mock()] + prop_row.values[0].get_sVal.return_value = b'prop1' + mock_describe_edge_result.rows.return_value = [prop_row] + + self.graph.query = Mock( + side_effect=[ + mock_show_edges_result, + mock_describe_edge_result, + mock_describe_edge_result, + ] + ) + + rel_schema_props, rel_structure_props = ( + self.graph.get_relationship_properties() + ) + + expected_rel_schema_props = ['Edge1.prop1', 'Edge2.prop1'] + expected_rel_structure_props = [ + {'type': 'Edge1', 'properties': ['prop1']}, + {'type': 'Edge2', 'properties': ['prop1']}, + ] + + self.assertEqual(rel_schema_props, expected_rel_schema_props) + self.assertEqual(rel_structure_props, expected_rel_structure_props) + + def test_extract_nodes(self): + # Use actual Node instances + node1 = Node(id='node1', type='Tag1', properties={}) + node2 = Node(id='node2', type='Tag2', properties={}) + # Create a GraphElement with nodes and optional source + graph_elements = [ + GraphElement( + nodes=[node1, node2], + relationships=[], + source=Element(element_id="a05b820b51c760a41415c57c1eef8f08"), + ) + ] + # Call the method + nodes = self.graph._extract_nodes(graph_elements) + # Expected result + expected_nodes = [ + {'id': 'node1', 'type': 'Tag1', 'properties': {}}, + {'id': 'node2', 'type': 'Tag2', 'properties': {}}, + ] + # Assert + self.assertEqual(nodes, expected_nodes) + + def test_extract_relationships(self): + # Use actual Node and Relationship instances + node1 = Node(id='node1', type='Tag1', properties={}) + node2 = Node(id='node2', type='Tag2', properties={}) + rel = Relationship( + subj=node1, obj=node2, type='RELATES_TO', properties={} + ) + # Create a GraphElement with relationships and optional source + graph_elements = [ + GraphElement( + nodes=[], + relationships=[rel], + source=Element(element_id="a05b820b51c760a41415c57c1eef8f08"), + ) + ] + # Call the method + relationships = self.graph._extract_relationships(graph_elements) + # Expected result + expected_relationships = [ + { + 'subj': {'id': 'node1', 'type': 'Tag1'}, + 'obj': {'id': 'node2', 'type': 'Tag2'}, + 'type': 'RELATES_TO', + } + ] + # Assert + self.assertEqual(relationships, expected_relationships) + + def test_get_indexes(self): + # Mock query for 'SHOW TAG INDEXES' + mock_show_indexes_result = Mock() + 
mock_show_indexes_result.is_succeeded.return_value = True + index_row = Mock() + index_row.values = [Mock()] + index_row.values[0].get_sVal.return_value = b'index1' + mock_show_indexes_result.rows.return_value = [index_row] + self.graph.query = Mock(return_value=mock_show_indexes_result) + + indexes = self.graph.get_indexes() + expected_indexes = ['index1'] + self.assertEqual(indexes, expected_indexes) + + def test_delete_entity(self): + entity_id = 'node1' + self.graph.query = Mock() + delete_vertex_query = f'DELETE VERTEX "{entity_id}";' + self.graph.delete_entity(entity_id) + self.graph.query.assert_called_with(delete_vertex_query) + + def test_get_schema(self): + self.graph.get_node_properties = Mock( + return_value=( + ['Node.prop'], + [{'labels': 'Node', 'properties': ['prop']}], + ) + ) + self.graph.get_relationship_properties = Mock( + return_value=( + ['Rel.prop'], + [{'type': 'Rel', 'properties': ['prop']}], + ) + ) + self.graph.get_relationship_types = Mock(return_value=['RELATES_TO']) + schema = self.graph.get_schema() + expected_schema = "\n".join( + [ + "Node properties are the following:", + "Node.prop", + "Relationship properties are the following:", + "Rel.prop", + "The relationships are the following:", + "RELATES_TO", + ] + ) + self.assertEqual(schema, expected_schema) + + def test_get_structured_schema(self): + self.graph.get_node_properties = Mock( + return_value=( + ['Node.prop'], + [{'labels': 'Node', 'properties': ['prop']}], + ) + ) + self.graph.get_relationship_properties = Mock( + return_value=( + ['Rel.prop'], + [{'type': 'Rel', 'properties': ['prop']}], + ) + ) + self.graph.get_relationship_types = Mock(return_value=['RELATES_TO']) + self.graph.get_indexes = Mock(return_value=['index1']) + structured_schema = self.graph.get_structured_schema + expected_schema = { + "node_props": {'Node': ['prop']}, + "rel_props": {'Rel': ['prop']}, + "relationships": ['RELATES_TO'], + "metadata": {"index": ['index1']}, + } + self.assertEqual(structured_schema, expected_schema) + + def test_add_graph_elements(self): + # Create actual Node and Relationship instances + node1 = Node(id='node1', type='Tag1', properties={}) + node2 = Node(id='node2', type='Tag2', properties={}) + rel = Relationship( + subj=node1, obj=node2, type='RELATES_TO', properties={} + ) + # Create a GraphElement instance with nodes, relationships, and source + graph_elements = [ + GraphElement( + nodes=[node1, node2], + relationships=[rel], + source=Element(element_id="a05b820b51c760a41415c57c1eef8f08"), + ) + ] + + # Mock the methods called within add_graph_elements + self.graph._extract_nodes = Mock( + return_value=[ + {'id': 'node1', 'type': 'Tag1'}, + {'id': 'node2', 'type': 'Tag2'}, + ] + ) + self.graph._extract_relationships = Mock( + return_value=[ + { + 'subj': {'id': 'node1'}, + 'obj': {'id': 'node2'}, + 'type': 'RELATES_TO', + } + ] + ) + self.graph.add_node = Mock() + self.graph.add_triplet = Mock() + + # Call the method under test + self.graph.add_graph_elements(graph_elements) + self.graph.add_node.assert_any_call('node1', 'Tag1') + self.graph.add_node.assert_any_call('node2', 'Tag2') + self.graph.add_triplet.assert_called_with( + 'node1', 'node2', 'RELATES_TO' + ) + + +if __name__ == '__main__': + unittest.main()