simplified node insertion
epinzur committed Oct 14, 2024
1 parent a6add20 commit 034a485
Showing 1 changed file with 78 additions and 67 deletions.
145 changes: 78 additions & 67 deletions libs/community/langchain_community/graph_vectorstores/cassandra.py
@@ -64,7 +64,7 @@ def default(self, obj: Any) -> Any: # noqa: ANN401
def _deserialize_links(json_blob: str | None) -> set[Link]:
return {
Link(kind=link["kind"], direction=link["direction"], tag=link["tag"])
for link in cast(list[dict[str, Any]], json.loads(json_blob or ""))
for link in cast(list[dict[str, Any]], json.loads(json_blob or "[]"))
}
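For context, the change to the fallback value above is what lets link deserialization tolerate rows with no serialized links at all: with the old `""` fallback, `json.loads` raises. A minimal standalone sketch of the behavior (the `Link` stand-in below is illustrative, not the library class):

import json
from dataclasses import dataclass
from typing import Any, cast


@dataclass(frozen=True)
class Link:
    # Stand-in for langchain_community.graph_vectorstores.links.Link
    kind: str
    direction: str
    tag: str


def deserialize_links(json_blob: str | None) -> set[Link]:
    # Falling back to "[]" yields an empty list (and so an empty set) when the
    # metadata value is missing or empty; the old fallback of "" would make
    # json.loads raise json.JSONDecodeError.
    return {
        Link(kind=link["kind"], direction=link["direction"], tag=link["tag"])
        for link in cast(list[dict[str, Any]], json.loads(json_blob or "[]"))
    }


assert deserialize_links(None) == set()
assert deserialize_links('[{"kind": "kw", "direction": "in", "tag": "a"}]') == {
    Link(kind="kw", direction="in", tag="a")
}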


@@ -155,16 +155,7 @@ def __init__(
exposed since CassandraGraphVectorStore only supports the
deny_list option.
"""
try:
from cassio.table.mixins.metadata import MetadataMixin
except (ImportError, ModuleNotFoundError):
raise ImportError(
"Could not import cassio python package. "
"Please install it with `pip install cassio`."
)
self.embedding = embedding
self._serialize_md_dict = MetadataMixin._serialize_md_dict
self._coerce_string = MetadataMixin._coerce_string

deny_list = set(metadata_deny_list)
deny_list.add(METADATA_LINKS_KEY)
@@ -193,7 +184,7 @@ def __init__(

@property
@override
def embeddings(self) -> Embeddings:
def embeddings(self) -> Embeddings | None:
return self.embedding

def _get_metadata_filter(
@@ -219,9 +210,44 @@ def _restore_links(self, doc: Document) -> Document:
"""
links = _deserialize_links(doc.metadata.get(METADATA_LINKS_KEY))
doc.metadata[METADATA_LINKS_KEY] = links
# TODO: Could this be skipped if we put these metadata entries
# only in the searchable `metadata_s` column?
for incoming_link_key in [
_metadata_link_key(link=link)
for link in links
if link.direction in ["in", "bidir"]
]:
if incoming_link_key in doc.metadata:
del doc.metadata[incoming_link_key]

return doc
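To make the two halves of the link round trip easier to follow: at insertion time a marker key is written for every incoming or bidirectional link (see `_get_node_metadata_for_insertion` below), and `_restore_links` strips those markers back out after rebuilding the link set. A rough sketch of the read side, with hypothetical key names ("links", "link:kw:a") standing in for METADATA_LINKS_KEY and the values produced by `_metadata_link_key` / `_metadata_link_value`:

import json

# Metadata as it might come back from the underlying vector store row
# (key names are illustrative, not the module's real constants).
stored = {
    "links": '[{"kind": "kw", "direction": "in", "tag": "a"}]',
    "link:kw:a": "link",  # marker written at insertion time for the incoming link
    "topic": "graphs",
}

links = json.loads(stored["links"] or "[]")
restored = dict(stored)
restored["links"] = links

# Mirror of the loop in _restore_links: drop the marker keys for incoming
# and bidirectional links so callers only see the reconstructed links.
for link in links:
    if link["direction"] in ("in", "bidir"):
        restored.pop(f"link:{link['kind']}:{link['tag']}", None)

assert "link:kw:a" not in restored
assert restored["topic"] == "graphs"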

# TODO: Async (aadd_nodes)
def _get_node_metadata_for_insertion(self, node: Node) -> dict[str, Any]:
metadata = node.metadata.copy()
metadata[METADATA_LINKS_KEY] = _serialize_links(node.links)
# TODO: Could we put these metadata entries
# only in the searchable `metadata_s` column?
for incoming_link in _incoming_links(node=node):
metadata[_metadata_link_key(link=incoming_link)] = _metadata_link_value()
return metadata

def _get_docs_for_insertion(
self, nodes: Iterable[Node]
) -> tuple[list[Document], list[str]]:
docs = []
ids = []
for node in nodes:
node_id = secrets.token_hex(8) if not node.id else node.id

doc = Document(
page_content=node.text,
metadata=self._get_node_metadata_for_insertion(node=node),
id=node_id,
)
docs.append(doc)
ids.append(node_id)
return (docs, ids)
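`_get_docs_for_insertion` is what lets both insertion paths delegate to the wrapped vector store: each node becomes a Document whose id is either the node's own id or a random 16-hex-character token. A small sketch of that id/document pairing, using illustrative stand-ins for the library's Node and Document types:

import secrets
from dataclasses import dataclass, field
from typing import Any


@dataclass
class Node:
    # Stand-in for the library's Node; only the fields used here.
    text: str
    id: str | None = None
    metadata: dict[str, Any] = field(default_factory=dict)


@dataclass
class Document:
    page_content: str
    metadata: dict[str, Any]
    id: str


def docs_for_insertion(nodes: list[Node]) -> tuple[list[Document], list[str]]:
    docs, ids = [], []
    for node in nodes:
        # Same fallback as above: generate a random id when the node has none.
        node_id = secrets.token_hex(8) if not node.id else node.id
        docs.append(
            Document(page_content=node.text, metadata=dict(node.metadata), id=node_id)
        )
        ids.append(node_id)
    return docs, ids


docs, ids = docs_for_insertion([Node(text="a"), Node(text="b", id="fixed-id")])
assert ids[1] == "fixed-id" and len(ids[0]) == 16  # token_hex(8) -> 16 hex chars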

@override
def add_nodes(
self,
@@ -234,63 +260,24 @@ def add_nodes(
nodes: the nodes to add.
**kwargs: Additional keyword arguments.
"""
node_ids: list[str] = []
texts: list[str] = []
metadata_list: list[dict[str, Any]] = []
incoming_links_list: list[set[Link]] = []
for node in nodes:
if not node.id:
node_ids.append(secrets.token_hex(8))
else:
node_ids.append(node.id)
texts.append(node.text)
combined_metadata = node.metadata.copy()
combined_metadata[METADATA_LINKS_KEY] = _serialize_links(node.links)
metadata_list.append(combined_metadata)
incoming_links_list.append(_incoming_links(node=node))
(docs, ids) = self._get_docs_for_insertion(nodes=nodes)
return self.vector_store.add_documents(docs, ids=ids)

text_embeddings = self.embedding.embed_documents(texts)

futures = []
store_session: Session = self.vector_store.session
tuples = zip(
node_ids, texts, text_embeddings, metadata_list, incoming_links_list
)
for node_id, text, text_embedding, metadata, incoming_links in tuples:
metadata_s = {
k: self._coerce_string(v)
for k, v in metadata.items()
if k not in self._metadata_deny_list
}

for incoming_link in incoming_links:
metadata_s[_metadata_link_key(link=incoming_link)] = (
_metadata_link_value()
)

if isinstance(metadata.get(METADATA_LINKS_KEY), set):
metadata = metadata.copy()
metadata[METADATA_LINKS_KEY] = list(metadata[METADATA_LINKS_KEY])
attributes_blob = self._serialize_md_dict(metadata)

futures.append(
store_session.execute_async(
self._insert_node,
parameters=(
node_id,
text,
text_embedding,
attributes_blob,
metadata_s,
),
timeout=30.0,
)
)

for future in futures:
future.result()

return node_ids

@override
async def aadd_nodes(
self,
nodes: Iterable[Node],
**kwargs: Any,
) -> AsyncIterable[str]:
"""Add nodes to the graph store.
Args:
nodes: the nodes to add.
**kwargs: Additional keyword arguments.
"""
(docs, ids) = self._get_docs_for_insertion(nodes=nodes)
for inserted_id in await self.vector_store.aadd_documents(docs, ids=ids):
yield inserted_id
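Taken together, the simplification replaces the hand-rolled CQL insertion path with two thin wrappers: build Documents via `_get_docs_for_insertion`, then hand them to the wrapped vector store's `add_documents` / `aadd_documents`. A hedged usage sketch, assuming a CassandraGraphVectorStore named `store` is already configured (the import paths below are from the community package and may differ by version):

from langchain_community.graph_vectorstores.base import Node
from langchain_community.graph_vectorstores.links import Link

# `store` is assumed to be an initialized CassandraGraphVectorStore.
nodes = [
    Node(
        text="Cassandra is a distributed database.",
        metadata={"topic": "db"},
        links=[Link(kind="kw", direction="bidir", tag="cassandra")],
    ),
    Node(
        id="node-2",
        text="Graph vector stores add typed links between documents.",
        links=[Link(kind="kw", direction="in", tag="cassandra")],
    ),
]

# Synchronous path: the ids come back from the wrapped store's add_documents.
ids = store.add_nodes(nodes)

# The async path mirrors it, yielding ids as documents are inserted:
#     async for inserted_id in store.aadd_nodes(nodes):
#         ...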

@classmethod
@override
@@ -475,6 +462,30 @@ async def ametadata_search(
)
]

def get_by_document_id(self, document_id: str) -> Document | None:
"""Retrieve a single document from the store, given its document ID.
Args:
document_id: The document ID
Returns:
The document if it exists. Otherwise None.
"""
doc = self.vector_store.get_by_document_id(document_id=document_id)
return self._restore_links(doc) if doc is not None else None

async def aget_by_document_id(self, document_id: str) -> Document | None:
"""Retrieve a single document from the store, given its document ID.
Args:
document_id: The document ID
Returns:
The document if it exists. Otherwise None.
"""
doc = await self.vector_store.aget_by_document_id(document_id=document_id)
return self._restore_links(doc) if doc is not None else None

def get_node(self, node_id: str) -> Node | None:
"""Retrieve a single node from the store, given its ID.
