Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
epinzur committed Oct 10, 2024
1 parent 5ea2c44 commit 5bd1c47
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions libs/community/langchain_community/graph_vectorstores/cassandra.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
cast,
)

from cassio.table.mixins.metadata import MetadataMixin
from langchain_core._api import beta
from langchain_core.documents import Document
from typing_extensions import override
Expand Down Expand Up @@ -47,13 +46,6 @@ def __init__(self, node: Node, embedding: list[float]) -> None:
self.embedding = embedding


def _serialize_metadata(md: dict[str, Any]) -> str:
if isinstance(md.get(METADATA_LINKS_KEY), set):
md = md.copy()
md[METADATA_LINKS_KEY] = list(md[METADATA_LINKS_KEY])
return MetadataMixin._serialize_md_dict(md)


def _serialize_links(links: list[Link]) -> str:
class SetAndLinkEncoder(json.JSONEncoder):
def default(self, obj: Any) -> Any: # noqa: ANN401
Expand Down Expand Up @@ -163,7 +155,16 @@ def __init__(
exposed since CassandraGraphVectorStore only supports the
deny_list option.
"""
try:
from cassio.table.mixins.metadata import MetadataMixin
except (ImportError, ModuleNotFoundError):
raise ImportError(
"Could not import cassio python package. "
"Please install it with `pip install cassio`."
)
self.embedding = embedding
self._serialize_md_dict = MetadataMixin._serialize_md_dict
self._coerce_string = MetadataMixin._coerce_string

deny_list = set(metadata_deny_list)
deny_list.add(METADATA_LINKS_KEY)
Expand Down Expand Up @@ -257,7 +258,7 @@ def add_nodes(
)
for node_id, text, text_embedding, metadata, incoming_links in tuples:
metadata_s = {
k: MetadataMixin._coerce_string(v)
k: self._coerce_string(v)
for k, v in metadata.items()
if k not in self._metadata_deny_list
}
Expand All @@ -267,7 +268,10 @@ def add_nodes(
_metadata_link_value()
)

attributes_blob = _serialize_metadata(metadata)
if isinstance(metadata.get(METADATA_LINKS_KEY), set):
metadata = metadata.copy()
metadata[METADATA_LINKS_KEY] = list(metadata[METADATA_LINKS_KEY])
attributes_blob = self._serialize_md_dict(metadata)

futures.append(
store_session.execute_async(
Expand Down

0 comments on commit 5bd1c47

Please sign in to comment.