diff --git a/libs/community/langchain_community/graph_vectorstores/cassandra.py b/libs/community/langchain_community/graph_vectorstores/cassandra.py index f6c99d6cf0547f..e4a5fa030b4799 100644 --- a/libs/community/langchain_community/graph_vectorstores/cassandra.py +++ b/libs/community/langchain_community/graph_vectorstores/cassandra.py @@ -16,7 +16,6 @@ cast, ) -from cassio.table.mixins.metadata import MetadataMixin from langchain_core._api import beta from langchain_core.documents import Document from typing_extensions import override @@ -47,13 +46,6 @@ def __init__(self, node: Node, embedding: list[float]) -> None: self.embedding = embedding -def _serialize_metadata(md: dict[str, Any]) -> str: - if isinstance(md.get(METADATA_LINKS_KEY), set): - md = md.copy() - md[METADATA_LINKS_KEY] = list(md[METADATA_LINKS_KEY]) - return MetadataMixin._serialize_md_dict(md) - - def _serialize_links(links: list[Link]) -> str: class SetAndLinkEncoder(json.JSONEncoder): def default(self, obj: Any) -> Any: # noqa: ANN401 @@ -163,7 +155,16 @@ def __init__( exposed since CassandraGraphVectorStore only supports the deny_list option. """ + try: + from cassio.table.mixins.metadata import MetadataMixin + except (ImportError, ModuleNotFoundError): + raise ImportError( + "Could not import cassio python package. " + "Please install it with `pip install cassio`." + ) self.embedding = embedding + self._serialize_md_dict = MetadataMixin._serialize_md_dict + self._coerce_string = MetadataMixin._coerce_string deny_list = set(metadata_deny_list) deny_list.add(METADATA_LINKS_KEY) @@ -257,7 +258,7 @@ def add_nodes( ) for node_id, text, text_embedding, metadata, incoming_links in tuples: metadata_s = { - k: MetadataMixin._coerce_string(v) + k: self._coerce_string(v) for k, v in metadata.items() if k not in self._metadata_deny_list } @@ -267,7 +268,10 @@ def add_nodes( _metadata_link_value() ) - attributes_blob = _serialize_metadata(metadata) + if isinstance(metadata.get(METADATA_LINKS_KEY), set): + metadata = metadata.copy() + metadata[METADATA_LINKS_KEY] = list(metadata[METADATA_LINKS_KEY]) + attributes_blob = self._serialize_md_dict(metadata) futures.append( store_session.execute_async(