diff --git a/libs/community/langchain_community/graph_vectorstores/__init__.py b/libs/community/langchain_community/graph_vectorstores/__init__.py index 5773b224dc565a..ad044624b1b4f8 100644 --- a/libs/community/langchain_community/graph_vectorstores/__init__.py +++ b/libs/community/langchain_community/graph_vectorstores/__init__.py @@ -144,6 +144,7 @@ from langchain_community.graph_vectorstores.links import ( Link, ) +from langchain_community.graph_vectorstores.mmr_helper import MmrHelper __all__ = [ "GraphVectorStore", @@ -151,4 +152,5 @@ "Node", "Link", "CassandraGraphVectorStore", + "MmrHelper", ] diff --git a/libs/community/langchain_community/graph_vectorstores/base.py b/libs/community/langchain_community/graph_vectorstores/base.py index 97a47471311827..0a320d98f9eca9 100644 --- a/libs/community/langchain_community/graph_vectorstores/base.py +++ b/libs/community/langchain_community/graph_vectorstores/base.py @@ -1,11 +1,13 @@ from __future__ import annotations +import logging from abc import abstractmethod from collections.abc import AsyncIterable, Collection, Iterable, Iterator from typing import ( Any, ClassVar, Optional, + Sequence, ) from langchain_core._api import beta @@ -21,6 +23,8 @@ from langchain_community.graph_vectorstores.links import METADATA_LINKS_KEY, Link +logger = logging.getLogger(__name__) + def _has_next(iterator: Iterator) -> bool: """Checks if the iterator has more elements. @@ -158,6 +162,7 @@ def add_nodes( Args: nodes: the nodes to add. + **kwargs: Additional keyword arguments. """ async def aadd_nodes( @@ -169,6 +174,7 @@ async def aadd_nodes( Args: nodes: the nodes to add. + **kwargs: Additional keyword arguments. """ iterator = iter(await run_in_executor(None, self.add_nodes, nodes, **kwargs)) done = object() @@ -186,7 +192,7 @@ def add_texts( ids: Optional[Iterable[str]] = None, **kwargs: Any, ) -> list[str]: - """Run more texts through the embeddings and add to the vectorstore. + """Run more texts through the embeddings and add to the vector store. The Links present in the metadata field `links` will be extracted to create the `Node` links. @@ -214,15 +220,15 @@ def add_texts( ) Args: - texts: Iterable of strings to add to the vectorstore. + texts: Iterable of strings to add to the vector store. metadatas: Optional list of metadatas associated with the texts. The metadata key `links` shall be an iterable of :py:class:`~langchain_community.graph_vectorstores.links.Link`. ids: Optional list of IDs associated with the texts. - **kwargs: vectorstore specific parameters. + **kwargs: vector store specific parameters. Returns: - List of ids from adding the texts into the vectorstore. + List of ids from adding the texts into the vector store. """ nodes = _texts_to_nodes(texts, metadatas, ids) return list(self.add_nodes(nodes, **kwargs)) @@ -235,7 +241,7 @@ async def aadd_texts( ids: Optional[Iterable[str]] = None, **kwargs: Any, ) -> list[str]: - """Run more texts through the embeddings and add to the vectorstore. + """Run more texts through the embeddings and add to the vector store. The Links present in the metadata field `links` will be extracted to create the `Node` links. @@ -263,15 +269,15 @@ async def aadd_texts( ) Args: - texts: Iterable of strings to add to the vectorstore. + texts: Iterable of strings to add to the vector store. metadatas: Optional list of metadatas associated with the texts. The metadata key `links` shall be an iterable of :py:class:`~langchain_community.graph_vectorstores.links.Link`. 
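# A minimal sketch of the `links` metadata contract described in the docstrings
# above. `store` is assumed to be any GraphVectorStore implementation; the texts,
# kinds, and tags are illustrative only.
from langchain_community.graph_vectorstores.links import Link

store.add_texts(
    texts=["Paris is the capital of France.", "France is in western Europe."],
    metadatas=[
        {"links": [Link(kind="kw", direction="bidir", tag="france")]},
        {"links": [Link(kind="kw", direction="bidir", tag="france")]},
    ],
)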
ids: Optional list of IDs associated with the texts. - **kwargs: vectorstore specific parameters. + **kwargs: vector store specific parameters. Returns: - List of ids from adding the texts into the vectorstore. + List of ids from adding the texts into the vector store. """ nodes = _texts_to_nodes(texts, metadatas, ids) return [_id async for _id in self.aadd_nodes(nodes, **kwargs)] @@ -281,7 +287,7 @@ def add_documents( documents: Iterable[Document], **kwargs: Any, ) -> list[str]: - """Run more documents through the embeddings and add to the vectorstore. + """Run more documents through the embeddings and add to the vector store. The Links present in the document metadata field `links` will be extracted to create the `Node` links. @@ -316,7 +322,7 @@ def add_documents( ) Args: - documents: Documents to add to the vectorstore. + documents: Documents to add to the vector store. The document's metadata key `links` shall be an iterable of :py:class:`~langchain_community.graph_vectorstores.links.Link`. @@ -331,7 +337,7 @@ async def aadd_documents( documents: Iterable[Document], **kwargs: Any, ) -> list[str]: - """Run more documents through the embeddings and add to the vectorstore. + """Run more documents through the embeddings and add to the vector store. The Links present in the document metadata field `links` will be extracted to create the `Node` links. @@ -366,7 +372,7 @@ async def aadd_documents( ) Args: - documents: Documents to add to the vectorstore. + documents: Documents to add to the vector store. The document's metadata key `links` shall be an iterable of :py:class:`~langchain_community.graph_vectorstores.links.Link`. @@ -383,6 +389,7 @@ def traversal_search( *, k: int = 4, depth: int = 1, + filter: dict[str, Any] | None = None, # noqa: A002 **kwargs: Any, ) -> Iterable[Document]: """Retrieve documents from traversing this graph store. @@ -396,8 +403,10 @@ def traversal_search( k: The number of Documents to return from the initial search. Defaults to 4. Applies to each of the query strings. depth: The maximum depth of edges to traverse. Defaults to 1. + filter: Optional metadata to filter the results. + **kwargs: Additional keyword arguments. Returns: - Retrieved documents. + Collection of retrieved documents. """ async def atraversal_search( @@ -406,6 +415,7 @@ async def atraversal_search( *, k: int = 4, depth: int = 1, + filter: dict[str, Any] | None = None, # noqa: A002 **kwargs: Any, ) -> AsyncIterable[Document]: """Retrieve documents from traversing this graph store. @@ -419,12 +429,20 @@ async def atraversal_search( k: The number of Documents to return from the initial search. Defaults to 4. Applies to each of the query strings. depth: The maximum depth of edges to traverse. Defaults to 1. + filter: Optional metadata to filter the results. + **kwargs: Additional keyword arguments. Returns: - Retrieved documents. + Collection of retrieved documents. """ iterator = iter( await run_in_executor( - None, self.traversal_search, query, k=k, depth=depth, **kwargs + None, + self.traversal_search, + query, + k=k, + depth=depth, + filter=filter, + **kwargs, ) ) done = object() @@ -439,12 +457,14 @@ def mmr_traversal_search( self, query: str, *, + initial_roots: Sequence[str] = (), k: int = 4, depth: int = 2, fetch_k: int = 100, adjacent_k: int = 10, lambda_mult: float = 0.5, score_threshold: float = float("-inf"), + filter: dict[str, Any] | None = None, # noqa: A002 **kwargs: Any, ) -> Iterable[Document]: """Retrieve documents from this graph store using MMR-traversal. 
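# Hedged usage sketch for the new `initial_roots` and `filter` parameters above.
# `store` is assumed to be a concrete GraphVectorStore (e.g. CassandraGraphVectorStore);
# the document ID and metadata key are hypothetical.
docs = list(
    store.mmr_traversal_search(
        "What is the capital of France?",
        k=4,
        depth=2,
        fetch_k=0,                    # 0 => explore only around the given roots
        initial_roots=["doc-paris"],
        filter={"topic": "geography"},
    )
)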
@@ -459,6 +479,10 @@ def mmr_traversal_search( Args: query: The query string to search for. + initial_roots: Optional list of document IDs to use for initializing search. + The top `adjacent_k` nodes adjacent to each initial root will be + included in the set of initial candidates. To fetch only in the + neighborhood of these nodes, set `fetch_k = 0`. k: Number of Documents to return. Defaults to 4. fetch_k: Number of Documents to fetch via similarity. Defaults to 100. @@ -471,18 +495,22 @@ def mmr_traversal_search( diversity and 1 to minimum diversity. Defaults to 0.5. score_threshold: Only documents with a score greater than or equal this threshold will be chosen. Defaults to negative infinity. + filter: Optional metadata to filter the results. + **kwargs: Additional keyword arguments. """ async def ammr_traversal_search( self, query: str, *, + initial_roots: Sequence[str] = (), k: int = 4, depth: int = 2, fetch_k: int = 100, adjacent_k: int = 10, lambda_mult: float = 0.5, score_threshold: float = float("-inf"), + filter: dict[str, Any] | None = None, # noqa: A002 **kwargs: Any, ) -> AsyncIterable[Document]: """Retrieve documents from this graph store using MMR-traversal. @@ -497,6 +525,10 @@ async def ammr_traversal_search( Args: query: The query string to search for. + initial_roots: Optional list of document IDs to use for initializing search. + The top `adjacent_k` nodes adjacent to each initial root will be + included in the set of initial candidates. To fetch only in the + neighborhood of these nodes, set `fetch_k = 0`. k: Number of Documents to return. Defaults to 4. fetch_k: Number of Documents to fetch via similarity. Defaults to 100. @@ -509,18 +541,22 @@ async def ammr_traversal_search( diversity and 1 to minimum diversity. Defaults to 0.5. score_threshold: Only documents with a score greater than or equal this threshold will be chosen. Defaults to negative infinity. + filter: Optional metadata to filter the results. + **kwargs: Additional keyword arguments. """ iterator = iter( await run_in_executor( None, self.mmr_traversal_search, query, + initial_roots=initial_roots, k=k, fetch_k=fetch_k, adjacent_k=adjacent_k, depth=depth, lambda_mult=lambda_mult, score_threshold=score_threshold, + filter=filter, **kwargs, ) ) @@ -544,6 +580,11 @@ def max_marginal_relevance_search( lambda_mult: float = 0.5, **kwargs: Any, ) -> list[Document]: + if kwargs.get("depth", 0) > 0: + logger.warning( + "'mmr' search started with depth > 0. " + "Maybe you meant to do a 'mmr_traversal' search?" + ) return list( self.mmr_traversal_search( query, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, depth=0 @@ -573,7 +614,7 @@ def search(self, query: str, search_type: str, **kwargs: Any) -> list[Document]: raise ValueError( f"search_type of {search_type} not allowed. Expected " "search_type to be 'similarity', 'similarity_score_threshold', " - "'mmr' or 'traversal'." + "'mmr', 'traversal', or 'mmr_traversal'." ) async def asearch( @@ -590,11 +631,13 @@ async def asearch( return await self.amax_marginal_relevance_search(query, **kwargs) elif search_type == "traversal": return [doc async for doc in self.atraversal_search(query, **kwargs)] + elif search_type == "mmr_traversal": + return [doc async for doc in self.ammr_traversal_search(query, **kwargs)] else: raise ValueError( f"search_type of {search_type} not allowed. Expected " "search_type to be 'similarity', 'similarity_score_threshold', " - "'mmr' or 'traversal'." + "'mmr', 'traversal', or 'mmr_traversal'." 
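# Dispatch sketch for the extended search types. `store` is assumed to implement
# GraphVectorStore and the call to run inside an async context. Note that
# max_marginal_relevance_search now logs a warning when called with depth > 0,
# pointing callers at the MMR-traversal path instead.
docs = await store.asearch(
    "capital of France", search_type="mmr_traversal", k=4, depth=2
)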
) def as_retriever(self, **kwargs: Any) -> GraphVectorStoreRetriever: @@ -606,13 +649,14 @@ def as_retriever(self, **kwargs: Any) -> GraphVectorStoreRetriever: - search_type (Optional[str]): Defines the type of search that the Retriever should perform. - Can be ``traversal`` (default), ``similarity``, ``mmr``, or - ``similarity_score_threshold``. + Can be ``traversal`` (default), ``similarity``, ``mmr``, + ``mmr_traversal``, or ``similarity_score_threshold``. - search_kwargs (Optional[Dict]): Keyword arguments to pass to the search function. Can include things like: - k(int): Amount of documents to return (Default: 4). - depth(int): The maximum depth of edges to traverse (Default: 1). + Only applies to search_type: ``traversal`` and ``mmr_traversal``. - score_threshold(float): Minimum relevance threshold for similarity_score_threshold. - fetch_k(int): Amount of documents to pass to MMR algorithm @@ -629,21 +673,21 @@ def as_retriever(self, **kwargs: Any) -> GraphVectorStoreRetriever: # Retrieve documents traversing edges docsearch.as_retriever( search_type="traversal", - search_kwargs={'k': 6, 'depth': 3} + search_kwargs={'k': 6, 'depth': 2} ) - # Retrieve more documents with higher diversity + # Retrieve documents with higher diversity # Useful if your dataset has many similar documents docsearch.as_retriever( - search_type="mmr", - search_kwargs={'k': 6, 'lambda_mult': 0.25} + search_type="mmr_traversal", + search_kwargs={'k': 6, 'lambda_mult': 0.25, 'depth': 2} ) # Fetch more documents for the MMR algorithm to consider # But only return the top 5 docsearch.as_retriever( - search_type="mmr", - search_kwargs={'k': 5, 'fetch_k': 50} + search_type="mmr_traversal", + search_kwargs={'k': 5, 'fetch_k': 50, 'depth': 2} ) # Only retrieve documents that have a relevance score @@ -657,7 +701,7 @@ def as_retriever(self, **kwargs: Any) -> GraphVectorStoreRetriever: docsearch.as_retriever(search_kwargs={'k': 1}) """ - return GraphVectorStoreRetriever(vectorstore=self, **kwargs) + return GraphVectorStoreRetriever(vector_store=self, **kwargs) @beta(message="Added in version 0.3.1 of langchain_community. API subject to change.") @@ -744,7 +788,7 @@ class GraphVectorStoreRetriever(VectorStoreRetriever): Passing search parameters ------------------------- - We can pass parameters to the underlying graph vectorstore's search methods using + We can pass parameters to the underlying graph vector store's search methods using ``search_kwargs``. Specifying graph traversal depth @@ -793,7 +837,7 @@ class GraphVectorStoreRetriever(VectorStoreRetriever): retriever = graph_vectorstore.as_retriever(search_kwargs={"score_threshold": 0.5}) """ # noqa: E501 - vectorstore: GraphVectorStore + vector_store: GraphVectorStore """GraphVectorStore to use for retrieval.""" search_type: str = "traversal" """Type of search to perform. 
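# Retriever sketch matching the as_retriever docstring above; `graph_vectorstore`
# is assumed to be an existing GraphVectorStore instance.
retriever = graph_vectorstore.as_retriever(
    search_type="mmr_traversal",
    search_kwargs={"k": 4, "depth": 2, "fetch_k": 50},
)
top_docs = retriever.invoke("What is the capital of France?")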
Defaults to "traversal".""" @@ -809,10 +853,10 @@ def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> list[Document]: if self.search_type == "traversal": - return list(self.vectorstore.traversal_search(query, **self.search_kwargs)) + return list(self.vector_store.traversal_search(query, **self.search_kwargs)) elif self.search_type == "mmr_traversal": return list( - self.vectorstore.mmr_traversal_search(query, **self.search_kwargs) + self.vector_store.mmr_traversal_search(query, **self.search_kwargs) ) else: return super()._get_relevant_documents(query, run_manager=run_manager) @@ -823,14 +867,14 @@ async def _aget_relevant_documents( if self.search_type == "traversal": return [ doc - async for doc in self.vectorstore.atraversal_search( + async for doc in self.vector_store.atraversal_search( query, **self.search_kwargs ) ] elif self.search_type == "mmr_traversal": return [ doc - async for doc in self.vectorstore.ammr_traversal_search( + async for doc in self.vector_store.ammr_traversal_search( query, **self.search_kwargs ) ] diff --git a/libs/community/langchain_community/graph_vectorstores/cassandra.py b/libs/community/langchain_community/graph_vectorstores/cassandra.py index 42525cf03dded8..5b377d61f51120 100644 --- a/libs/community/langchain_community/graph_vectorstores/cassandra.py +++ b/libs/community/langchain_community/graph_vectorstores/cassandra.py @@ -1,27 +1,107 @@ +"""Apache Cassandra DB graph vector store integration.""" + from __future__ import annotations +import asyncio +import json +import logging +import secrets +from dataclasses import asdict, is_dataclass from typing import ( TYPE_CHECKING, Any, + AsyncIterable, Iterable, List, Optional, + Sequence, + Tuple, Type, + TypeVar, + cast, ) from langchain_core._api import beta from langchain_core.documents import Document -from langchain_core.embeddings import Embeddings +from typing_extensions import override -from langchain_community.graph_vectorstores.base import ( - GraphVectorStore, - Node, - nodes_to_documents, -) +from langchain_community.graph_vectorstores.base import GraphVectorStore, Node +from langchain_community.graph_vectorstores.links import METADATA_LINKS_KEY, Link +from langchain_community.graph_vectorstores.mmr_helper import MmrHelper from langchain_community.utilities.cassandra import SetupMode +from langchain_community.vectorstores.cassandra import Cassandra as CassandraVectorStore + +CGVST = TypeVar("CGVST", bound="CassandraGraphVectorStore") if TYPE_CHECKING: from cassandra.cluster import Session + from langchain_core.embeddings import Embeddings + + +logger = logging.getLogger(__name__) + + +class AdjacentNode: + id: str + links: list[Link] + embedding: list[float] + + def __init__(self, node: Node, embedding: list[float]) -> None: + """Create an Adjacent Node.""" + self.id = node.id or "" + self.links = node.links + self.embedding = embedding + + +def _serialize_links(links: list[Link]) -> str: + class SetAndLinkEncoder(json.JSONEncoder): + def default(self, obj: Any) -> Any: # noqa: ANN401 + if not isinstance(obj, type) and is_dataclass(obj): + return asdict(obj) + + if isinstance(obj, Iterable): + return list(obj) + + # Let the base class default method raise the TypeError + return super().default(obj) + + return json.dumps(links, cls=SetAndLinkEncoder) + + +def _deserialize_links(json_blob: str | None) -> set[Link]: + return { + Link(kind=link["kind"], direction=link["direction"], tag=link["tag"]) + for link in cast(list[dict[str, Any]], 
json.loads(json_blob or "[]")) + } + + +def _metadata_link_key(link: Link) -> str: + return f"link:{link.kind}:{link.tag}" + + +def _metadata_link_value() -> str: + return "link" + + +def _doc_to_node(doc: Document) -> Node: + metadata = doc.metadata.copy() + links = _deserialize_links(metadata.get(METADATA_LINKS_KEY)) + metadata[METADATA_LINKS_KEY] = links + + return Node( + id=doc.id, + text=doc.page_content, + metadata=metadata, + links=list(links), + ) + + +def _incoming_links(node: Node | AdjacentNode) -> set[Link]: + return {link for link in node.links if link.direction in ["in", "bidir"]} + + +def _outgoing_links(node: Node | AdjacentNode) -> set[Link]: + return {link for link in node.links if link.direction in ["out", "bidir"]} @beta() @@ -29,162 +109,1160 @@ class CassandraGraphVectorStore(GraphVectorStore): def __init__( self, embedding: Embeddings, + session: Session | None = None, + keyspace: str | None = None, + table_name: str = "", + ttl_seconds: int | None = None, *, - node_table: str = "graph_nodes", - session: Optional[Session] = None, - keyspace: Optional[str] = None, + body_index_options: list[tuple[str, Any]] | None = None, setup_mode: SetupMode = SetupMode.SYNC, - **kwargs: Any, - ): - """ - Create the hybrid graph store. + metadata_deny_list: Optional[list[str]] = None, + ) -> None: + """Apache Cassandra(R) for graph-vector-store workloads. - Args: - embedding: The embeddings to use for the document content. - setup_mode: Mode used to create the Cassandra table (SYNC, - ASYNC or OFF). - """ - try: - from ragstack_knowledge_store import EmbeddingModel, graph_store - except (ImportError, ModuleNotFoundError): - raise ImportError( - "Could not import ragstack_knowledge_store python package. " - "Please install it with `pip install ragstack-ai-knowledge-store`." - ) - - self._embedding = embedding - _setup_mode = getattr(graph_store.SetupMode, setup_mode.name) + To use it, you need a recent installation of the `cassio` library + and a Cassandra cluster / Astra DB instance supporting vector capabilities. - class _EmbeddingModelAdapter(EmbeddingModel): - def __init__(self, embeddings: Embeddings): - self.embeddings = embeddings + Example: + .. code-block:: python - def embed_texts(self, texts: List[str]) -> List[List[float]]: - return self.embeddings.embed_documents(texts) + from langchain_community.graph_vectorstores import + CassandraGraphVectorStore + from langchain_openai import OpenAIEmbeddings - def embed_query(self, text: str) -> List[float]: - return self.embeddings.embed_query(text) + embeddings = OpenAIEmbeddings() + session = ... # create your Cassandra session object + keyspace = 'my_keyspace' # the keyspace should exist already + table_name = 'my_graph_vector_store' + vectorstore = CassandraGraphVectorStore( + embeddings, + session, + keyspace, + table_name, + ) - async def aembed_texts(self, texts: List[str]) -> List[List[float]]: - return await self.embeddings.aembed_documents(texts) + Args: + embedding: Embedding function to use. + session: Cassandra driver session. If not provided, it is resolved from + cassio. + keyspace: Cassandra keyspace. If not provided, it is resolved from cassio. + table_name: Cassandra table (required). + ttl_seconds: Optional time-to-live for the added texts. + body_index_options: Optional options used to create the body index. + Eg. body_index_options = [cassio.table.cql.STANDARD_ANALYZER] + setup_mode: mode used to create the Cassandra table (SYNC, + ASYNC or OFF). 
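# Construction sketch relying on cassio-resolved defaults, per the Args above.
# Assumes cassio.init() accepts a driver session and keyspace; the embeddings,
# session object, and deny-listed key are placeholders.
import cassio

cassio.init(session=my_session, keyspace="my_keyspace")
store = CassandraGraphVectorStore(
    embedding=my_embeddings,
    table_name="my_graph_vector_store",
    metadata_deny_list=["raw_html"],
)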
+ metadata_deny_list: Optional list of metadata keys to not index. + i.e. to fine-tune which of the metadata fields are indexed. + Note: if you plan to have massive unique text metadata entries, + consider not indexing them for performance + (and to overcome max-length limitations). + Note: the `metadata_indexing` parameter from + langchain_community.utilities.cassandra.Cassandra is not + exposed since CassandraGraphVectorStore only supports the + deny_list option. + """ + self.embedding = embedding - async def aembed_query(self, text: str) -> List[float]: - return await self.embeddings.aembed_query(text) + if metadata_deny_list is None: + metadata_deny_list = [] + metadata_deny_list.append(METADATA_LINKS_KEY) - self.store = graph_store.GraphStore( - embedding=_EmbeddingModelAdapter(embedding), - node_table=node_table, + self.vector_store = CassandraVectorStore( + embedding=embedding, session=session, keyspace=keyspace, - setup_mode=_setup_mode, - **kwargs, + table_name=table_name, + ttl_seconds=ttl_seconds, + body_index_options=body_index_options, + setup_mode=setup_mode, + metadata_indexing=("deny_list", metadata_deny_list), + ) + + store_session: Session = self.vector_store.session + + self._insert_node = store_session.prepare( + f""" + INSERT INTO {keyspace}.{table_name} ( + row_id, body_blob, vector, attributes_blob, metadata_s + ) VALUES (?, ?, ?, ?, ?) + """ # noqa: S608 ) @property - def embeddings(self) -> Optional[Embeddings]: - return self._embedding + @override + def embeddings(self) -> Embeddings | None: + return self.embedding + def _get_metadata_filter( + self, + metadata: dict[str, Any] | None = None, + outgoing_link: Link | None = None, + ) -> dict[str, Any]: + if outgoing_link is None: + return metadata or {} + + metadata_filter = {} if metadata is None else metadata.copy() + metadata_filter[_metadata_link_key(link=outgoing_link)] = _metadata_link_value() + return metadata_filter + + def _restore_links(self, doc: Document) -> Document: + """Restores the links in the document by deserializing them from metadata. + + Args: + doc: A single Document + + Returns: + The same Document with restored links. + """ + links = _deserialize_links(doc.metadata.get(METADATA_LINKS_KEY)) + doc.metadata[METADATA_LINKS_KEY] = links + # TODO: Could this be skipped if we put these metadata entries + # only in the searchable `metadata_s` column? + for incoming_link_key in [ + _metadata_link_key(link=link) + for link in links + if link.direction in ["in", "bidir"] + ]: + if incoming_link_key in doc.metadata: + del doc.metadata[incoming_link_key] + + return doc + + def _get_node_metadata_for_insertion(self, node: Node) -> dict[str, Any]: + metadata = node.metadata.copy() + metadata[METADATA_LINKS_KEY] = _serialize_links(node.links) + # TODO: Could we could put these metadata entries + # only in the searchable `metadata_s` column? 
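# Illustrative values for the link-to-metadata mapping used here and in
# _get_metadata_filter above:
#   _metadata_link_key(Link(kind="kw", direction="in", tag="paris")) -> "link:kw:paris"
#   _get_metadata_filter(metadata={"topic": "geography"},
#                        outgoing_link=Link(kind="kw", direction="out", tag="paris"))
#       -> {"topic": "geography", "link:kw:paris": "link"}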
+ for incoming_link in _incoming_links(node=node): + metadata[_metadata_link_key(link=incoming_link)] = _metadata_link_value() + return metadata + + def _get_docs_for_insertion( + self, nodes: Iterable[Node] + ) -> tuple[list[Document], list[str]]: + docs = [] + ids = [] + for node in nodes: + node_id = secrets.token_hex(8) if not node.id else node.id + + doc = Document( + page_content=node.text, + metadata=self._get_node_metadata_for_insertion(node=node), + id=node_id, + ) + docs.append(doc) + ids.append(node_id) + return (docs, ids) + + @override def add_nodes( self, nodes: Iterable[Node], **kwargs: Any, ) -> Iterable[str]: - return self.store.add_nodes(nodes) + """Add nodes to the graph store. - @classmethod - def from_texts( - cls: Type["CassandraGraphVectorStore"], - texts: Iterable[str], - embedding: Embeddings, - metadatas: Optional[List[dict]] = None, - ids: Optional[Iterable[str]] = None, - **kwargs: Any, - ) -> "CassandraGraphVectorStore": - """Return CassandraGraphVectorStore initialized from texts and embeddings.""" - store = cls(embedding, **kwargs) - store.add_texts(texts, metadatas, ids=ids) - return store + Args: + nodes: the nodes to add. + **kwargs: Additional keyword arguments. + """ + (docs, ids) = self._get_docs_for_insertion(nodes=nodes) + return self.vector_store.add_documents(docs, ids=ids) - @classmethod - def from_documents( - cls: Type["CassandraGraphVectorStore"], - documents: Iterable[Document], - embedding: Embeddings, - ids: Optional[Iterable[str]] = None, + @override + async def aadd_nodes( + self, + nodes: Iterable[Node], **kwargs: Any, - ) -> "CassandraGraphVectorStore": - """Return CassandraGraphVectorStore initialized from documents and - embeddings.""" - store = cls(embedding, **kwargs) - store.add_documents(documents, ids=ids) - return store + ) -> AsyncIterable[str]: + """Add nodes to the graph store. + Args: + nodes: the nodes to add. + **kwargs: Additional keyword arguments. + """ + (docs, ids) = self._get_docs_for_insertion(nodes=nodes) + for inserted_id in await self.vector_store.aadd_documents(docs, ids=ids): + yield inserted_id + + @override def similarity_search( self, query: str, k: int = 4, - metadata_filter: dict[str, Any] = {}, + filter: dict[str, Any] | None = None, **kwargs: Any, - ) -> List[Document]: - embedding_vector = self._embedding.embed_query(query) - return self.similarity_search_by_vector( - embedding_vector, - k=k, - metadata_filter=metadata_filter, - ) + ) -> list[Document]: + """Retrieve documents from this graph store. + + Args: + query: The query string. + k: The number of Documents to return. Defaults to 4. + filter: Optional metadata to filter the results. + **kwargs: Additional keyword arguments. + + Returns: + Collection of retrieved documents. + """ + return [ + self._restore_links(doc) + for doc in self.vector_store.similarity_search( + query=query, + k=k, + filter=filter, + **kwargs, + ) + ] + @override + async def asimilarity_search( + self, + query: str, + k: int = 4, + filter: dict[str, Any] | None = None, + **kwargs: Any, + ) -> list[Document]: + """Retrieve documents from this graph store. + + Args: + query: The query string. + k: The number of Documents to return. Defaults to 4. + filter: Optional metadata to filter the results. + **kwargs: Additional keyword arguments. + + Returns: + Collection of retrieved documents. 
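# Usage sketch for the node-insertion and filtered-search paths above; `store` is
# assumed to be a CassandraGraphVectorStore and the IDs/metadata are illustrative.
from langchain_community.graph_vectorstores import Node
from langchain_community.graph_vectorstores.links import Link

inserted_ids = list(
    store.add_nodes(
        [
            Node(
                id="doc-paris",
                text="Paris is the capital of France.",
                metadata={"topic": "geography"},
                links=[Link(kind="kw", direction="bidir", tag="france")],
            )
        ]
    )
)
hits = store.similarity_search("capital of France", k=4, filter={"topic": "geography"})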
+ """ + return [ + self._restore_links(doc) + for doc in await self.vector_store.asimilarity_search( + query=query, + k=k, + filter=filter, + **kwargs, + ) + ] + + @override def similarity_search_by_vector( self, - embedding: List[float], + embedding: list[float], k: int = 4, - metadata_filter: dict[str, Any] = {}, + filter: dict[str, Any] | None = None, **kwargs: Any, - ) -> List[Document]: - nodes = self.store.similarity_search( - embedding, - k=k, - metadata_filter=metadata_filter, - ) - return list(nodes_to_documents(nodes)) + ) -> list[Document]: + """Return docs most similar to embedding vector. - def traversal_search( + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filter: Filter on the metadata to apply. + **kwargs: Additional arguments are ignored. + + Returns: + The list of Documents most similar to the query vector. + """ + return [ + self._restore_links(doc) + for doc in self.vector_store.similarity_search_by_vector( + embedding, + k=k, + filter=filter, + **kwargs, + ) + ] + + @override + async def asimilarity_search_by_vector( + self, + embedding: list[float], + k: int = 4, + filter: dict[str, Any] | None = None, + **kwargs: Any, + ) -> list[Document]: + """Return docs most similar to embedding vector. + + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filter: Filter on the metadata to apply. + **kwargs: Additional arguments are ignored. + + Returns: + The list of Documents most similar to the query vector. + """ + return [ + self._restore_links(doc) + for doc in await self.vector_store.asimilarity_search_by_vector( + embedding, + k=k, + filter=filter, + **kwargs, + ) + ] + + def metadata_search( + self, + filter: dict[str, Any] | None = None, # noqa: A002 + n: int = 5, + ) -> Iterable[Document]: + """Get documents via a metadata search. + + Args: + filter: the metadata to query for. + n: the maximum number of documents to return. + """ + return [ + self._restore_links(doc) + for doc in self.vector_store.metadata_search( + filter=filter or {}, + n=n, + ) + ] + + async def ametadata_search( + self, + filter: dict[str, Any] | None = None, # noqa: A002 + n: int = 5, + ) -> Iterable[Document]: + """Get documents via a metadata search. + + Args: + filter: the metadata to query for. + n: the maximum number of documents to return. + """ + return [ + self._restore_links(doc) + for doc in await self.vector_store.ametadata_search( + filter=filter or {}, + n=n, + ) + ] + + def get_by_document_id(self, document_id: str) -> Document | None: + """Retrieve a single document from the store, given its document ID. + + Args: + document_id: The document ID + + Returns: + The the document if it exists. Otherwise None. + """ + doc = self.vector_store.get_by_document_id(document_id=document_id) + return self._restore_links(doc) if doc is not None else None + + async def aget_by_document_id(self, document_id: str) -> Document | None: + """Retrieve a single document from the store, given its document ID. + + Args: + document_id: The document ID + + Returns: + The the document if it exists. Otherwise None. + """ + doc = await self.vector_store.aget_by_document_id(document_id=document_id) + return self._restore_links(doc) if doc is not None else None + + def get_node(self, node_id: str) -> Node | None: + """Retrieve a single node from the store, given its ID. + + Args: + node_id: The node ID + + Returns: + The the node if it exists. Otherwise None. 
+ """ + doc = self.vector_store.get_by_document_id(document_id=node_id) + if doc is None: + return None + return _doc_to_node(doc=doc) + + @override + async def ammr_traversal_search( # noqa: C901 self, query: str, *, + initial_roots: Sequence[str] = (), k: int = 4, - depth: int = 1, - metadata_filter: dict[str, Any] = {}, + depth: int = 2, + fetch_k: int = 100, + adjacent_k: int = 10, + lambda_mult: float = 0.5, + score_threshold: float = float("-inf"), + filter: dict[str, Any] | None = None, **kwargs: Any, - ) -> Iterable[Document]: - nodes = self.store.traversal_search( - query, + ) -> AsyncIterable[Document]: + """Retrieve documents from this graph store using MMR-traversal. + + This strategy first retrieves the top `fetch_k` results by similarity to + the question. It then selects the top `k` results based on + maximum-marginal relevance using the given `lambda_mult`. + + At each step, it considers the (remaining) documents from `fetch_k` as + well as any documents connected by edges to a selected document + retrieved based on similarity (a "root"). + + Args: + query: The query string to search for. + initial_roots: Optional list of document IDs to use for initializing search. + The top `adjacent_k` nodes adjacent to each initial root will be + included in the set of initial candidates. To fetch only in the + neighborhood of these nodes, set `fetch_k = 0`. + k: Number of Documents to return. Defaults to 4. + fetch_k: Number of initial Documents to fetch via similarity. + Will be added to the nodes adjacent to `initial_roots`. + Defaults to 100. + adjacent_k: Number of adjacent Documents to fetch. + Defaults to 10. + depth: Maximum depth of a node (number of edges) from a node + retrieved via similarity. Defaults to 2. + lambda_mult: Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding to maximum + diversity and 1 to minimum diversity. Defaults to 0.5. + score_threshold: Only documents with a score greater than or equal + this threshold will be chosen. Defaults to -infinity. + filter: Optional metadata to filter the results. + **kwargs: Additional keyword arguments. + """ + query_embedding = self.embedding.embed_query(query) + helper = MmrHelper( k=k, - depth=depth, - metadata_filter=metadata_filter, + query_embedding=query_embedding, + lambda_mult=lambda_mult, + score_threshold=score_threshold, ) - return nodes_to_documents(nodes) + # For each unselected node, stores the outgoing links. + outgoing_links_map: dict[str, set[Link]] = {} + visited_links: set[Link] = set() + # Map from id to Document + retrieved_docs: dict[str, Document] = {} + + async def fetch_neighborhood(neighborhood: Sequence[str]) -> None: + nonlocal outgoing_links_map, visited_links, retrieved_docs + + # Put the neighborhood into the outgoing links, to avoid adding it + # to the candidate set in the future. + outgoing_links_map.update( + {content_id: set() for content_id in neighborhood} + ) + + # Initialize the visited_links with the set of outgoing links from the + # neighborhood. This prevents re-visiting them. + visited_links = await self._get_outgoing_links(neighborhood) + + # Call `self._get_adjacent` to fetch the candidates. 
+ adjacent_nodes = await self._get_adjacent( + links=visited_links, + query_embedding=query_embedding, + k_per_link=adjacent_k, + filter=filter, + retrieved_docs=retrieved_docs, + ) + + new_candidates: dict[str, list[float]] = {} + for adjacent_node in adjacent_nodes: + if adjacent_node.id not in outgoing_links_map: + outgoing_links_map[adjacent_node.id] = _outgoing_links( + node=adjacent_node + ) + new_candidates[adjacent_node.id] = adjacent_node.embedding + helper.add_candidates(new_candidates) + + async def fetch_initial_candidates() -> None: + nonlocal outgoing_links_map, visited_links, retrieved_docs + + results = ( + await self.vector_store.asimilarity_search_with_embedding_id_by_vector( + embedding=query_embedding, + k=fetch_k, + filter=filter, + ) + ) + + candidates: dict[str, list[float]] = {} + for doc, embedding, doc_id in results: + if doc_id not in retrieved_docs: + retrieved_docs[doc_id] = doc + + if doc_id not in outgoing_links_map: + node = _doc_to_node(doc) + outgoing_links_map[doc_id] = _outgoing_links(node=node) + candidates[doc_id] = embedding + helper.add_candidates(candidates) + + if initial_roots: + await fetch_neighborhood(initial_roots) + if fetch_k > 0: + await fetch_initial_candidates() + + # Tracks the depth of each candidate. + depths = {candidate_id: 0 for candidate_id in helper.candidate_ids()} + + # Select the best item, K times. + for _ in range(k): + selected_id = helper.pop_best() + + if selected_id is None: + break + + next_depth = depths[selected_id] + 1 + if next_depth < depth: + # If the next nodes would not exceed the depth limit, find the + # adjacent nodes. + + # Find the links linked to from the selected ID. + selected_outgoing_links = outgoing_links_map.pop(selected_id) + + # Don't re-visit already visited links. + selected_outgoing_links.difference_update(visited_links) + + # Find the nodes with incoming links from those links. + adjacent_nodes = await self._get_adjacent( + links=selected_outgoing_links, + query_embedding=query_embedding, + k_per_link=adjacent_k, + filter=filter, + retrieved_docs=retrieved_docs, + ) + + # Record the selected_outgoing_links as visited. + visited_links.update(selected_outgoing_links) + + new_candidates = {} + for adjacent_node in adjacent_nodes: + if adjacent_node.id not in outgoing_links_map: + outgoing_links_map[adjacent_node.id] = _outgoing_links( + node=adjacent_node + ) + new_candidates[adjacent_node.id] = adjacent_node.embedding + if next_depth < depths.get(adjacent_node.id, depth + 1): + # If this is a new shortest depth, or there was no + # previous depth, update the depths. This ensures that + # when we discover a node we will have the shortest + # depth available. + # + # NOTE: No effort is made to traverse from nodes that + # were previously selected if they become reachable via + # a shorter path via nodes selected later. This is + # currently "intended", but may be worth experimenting + # with. 
+ depths[adjacent_node.id] = next_depth + helper.add_candidates(new_candidates) + + for doc_id, similarity_score, mmr_score in zip( + helper.selected_ids, + helper.selected_similarity_scores, + helper.selected_mmr_scores, + ): + if doc_id in retrieved_docs: + doc = self._restore_links(retrieved_docs[doc_id]) + doc.metadata["similarity_score"] = similarity_score + doc.metadata["mmr_score"] = mmr_score + yield doc + else: + msg = f"retrieved_docs should contain id: {doc_id}" + raise RuntimeError(msg) + + @override def mmr_traversal_search( self, query: str, *, + initial_roots: Sequence[str] = (), k: int = 4, depth: int = 2, fetch_k: int = 100, adjacent_k: int = 10, lambda_mult: float = 0.5, score_threshold: float = float("-inf"), - metadata_filter: dict[str, Any] = {}, + filter: dict[str, Any] | None = None, **kwargs: Any, ) -> Iterable[Document]: - nodes = self.store.mmr_traversal_search( - query, + """Retrieve documents from this graph store using MMR-traversal. + + This strategy first retrieves the top `fetch_k` results by similarity to + the question. It then selects the top `k` results based on + maximum-marginal relevance using the given `lambda_mult`. + + At each step, it considers the (remaining) documents from `fetch_k` as + well as any documents connected by edges to a selected document + retrieved based on similarity (a "root"). + + Args: + query: The query string to search for. + initial_roots: Optional list of document IDs to use for initializing search. + The top `adjacent_k` nodes adjacent to each initial root will be + included in the set of initial candidates. To fetch only in the + neighborhood of these nodes, set `fetch_k = 0`. + k: Number of Documents to return. Defaults to 4. + fetch_k: Number of initial Documents to fetch via similarity. + Will be added to the nodes adjacent to `initial_roots`. + Defaults to 100. + adjacent_k: Number of adjacent Documents to fetch. + Defaults to 10. + depth: Maximum depth of a node (number of edges) from a node + retrieved via similarity. Defaults to 2. + lambda_mult: Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding to maximum + diversity and 1 to minimum diversity. Defaults to 0.5. + score_threshold: Only documents with a score greater than or equal + this threshold will be chosen. Defaults to -infinity. + filter: Optional metadata to filter the results. + **kwargs: Additional keyword arguments. + """ + + async def collect_docs() -> Iterable[Document]: + async_iter = self.ammr_traversal_search( + query=query, + initial_roots=initial_roots, + k=k, + depth=depth, + fetch_k=fetch_k, + adjacent_k=adjacent_k, + lambda_mult=lambda_mult, + score_threshold=score_threshold, + filter=filter, + **kwargs, + ) + return [doc async for doc in async_iter] + + return asyncio.run(collect_docs()) + + @override + async def atraversal_search( # noqa: C901 + self, + query: str, + *, + k: int = 4, + depth: int = 1, + filter: dict[str, Any] | None = None, + **kwargs: Any, + ) -> AsyncIterable[Document]: + """Retrieve documents from this knowledge store. + + First, `k` nodes are retrieved using a vector search for the `query` string. + Then, additional nodes are discovered up to the given `depth` from those + starting nodes. + + Args: + query: The query string. + k: The number of Documents to return from the initial vector search. + Defaults to 4. + depth: The maximum depth of edges to traverse. Defaults to 1. + filter: Optional metadata to filter the results. + **kwargs: Additional keyword arguments. 
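# Each document yielded by the MMR traversal above carries its scores in
# metadata; a reading sketch with the same assumed `store`:
for doc in store.mmr_traversal_search("capital of France", k=4, depth=2):
    print(doc.id, doc.metadata["similarity_score"], doc.metadata["mmr_score"])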
+ + Returns: + Collection of retrieved documents. + """ + # Depth 0: + # Query for `k` nodes similar to the question. + # Retrieve `content_id` and `outgoing_links()`. + # + # Depth 1: + # Query for nodes that have an incoming link in the `outgoing_links()` set. + # Combine node IDs. + # Query for `outgoing_links()` of those "new" node IDs. + # + # ... + + # Map from visited ID to depth + visited_ids: dict[str, int] = {} + + # Map from visited link to depth + visited_links: dict[Link, int] = {} + + # Map from id to Document + retrieved_docs: dict[str, Document] = {} + + async def visit_nodes(d: int, docs: Iterable[Document]) -> None: + """Recursively visit nodes and their outgoing links.""" + nonlocal visited_ids, visited_links, retrieved_docs + + # Iterate over nodes, tracking the *new* outgoing links for this + # depth. These are links that are either new, or newly discovered at a + # lower depth. + outgoing_links: set[Link] = set() + for doc in docs: + if doc.id is not None: + if doc.id not in retrieved_docs: + retrieved_docs[doc.id] = doc + + # If this node is at a closer depth, update visited_ids + if d <= visited_ids.get(doc.id, depth): + visited_ids[doc.id] = d + + # If we can continue traversing from this node, + if d < depth: + node = _doc_to_node(doc=doc) + # Record any new (or newly discovered at a lower depth) + # links to the set to traverse. + for link in _outgoing_links(node=node): + if d <= visited_links.get(link, depth): + # Record that we'll query this link at the + # given depth, so we don't fetch it again + # (unless we find it an earlier depth) + visited_links[link] = d + outgoing_links.add(link) + + if outgoing_links: + metadata_search_tasks = [] + for outgoing_link in outgoing_links: + metadata_filter = self._get_metadata_filter( + metadata=filter, + outgoing_link=outgoing_link, + ) + metadata_search_tasks.append( + asyncio.create_task( + self.vector_store.ametadata_search( + filter=metadata_filter, n=1000 + ) + ) + ) + results = await asyncio.gather(*metadata_search_tasks) + + # Visit targets concurrently + visit_target_tasks = [ + visit_targets(d=d + 1, docs=docs) for docs in results + ] + await asyncio.gather(*visit_target_tasks) + + async def visit_targets(d: int, docs: Iterable[Document]) -> None: + """Visit target nodes retrieved from outgoing links.""" + nonlocal visited_ids, retrieved_docs + + new_ids_at_next_depth = set() + for doc in docs: + if doc.id is not None: + if doc.id not in retrieved_docs: + retrieved_docs[doc.id] = doc + + if d <= visited_ids.get(doc.id, depth): + new_ids_at_next_depth.add(doc.id) + + if new_ids_at_next_depth: + visit_node_tasks = [ + visit_nodes(d=d, docs=[retrieved_docs[doc_id]]) + for doc_id in new_ids_at_next_depth + if doc_id in retrieved_docs + ] + + fetch_tasks = [ + asyncio.create_task( + self.vector_store.aget_by_document_id(document_id=doc_id) + ) + for doc_id in new_ids_at_next_depth + if doc_id not in retrieved_docs + ] + + new_docs: list[Document | None] = await asyncio.gather(*fetch_tasks) + + visit_node_tasks.extend( + visit_nodes(d=d, docs=[new_doc]) + for new_doc in new_docs + if new_doc is not None + ) + + await asyncio.gather(*visit_node_tasks) + + # Start the traversal + initial_docs = self.vector_store.similarity_search( + query=query, k=k, - depth=depth, - fetch_k=fetch_k, - adjacent_k=adjacent_k, - lambda_mult=lambda_mult, - score_threshold=score_threshold, - metadata_filter=metadata_filter, + filter=filter, + ) + await visit_nodes(d=0, docs=initial_docs) + + for doc_id in visited_ids: + if doc_id in 
retrieved_docs: + yield self._restore_links(retrieved_docs[doc_id]) + else: + msg = f"retrieved_docs should contain id: {doc_id}" + raise RuntimeError(msg) + + @override + def traversal_search( + self, + query: str, + *, + k: int = 4, + depth: int = 1, + filter: dict[str, Any] | None = None, + **kwargs: Any, + ) -> Iterable[Document]: + """Retrieve documents from this knowledge store. + + First, `k` nodes are retrieved using a vector search for the `query` string. + Then, additional nodes are discovered up to the given `depth` from those + starting nodes. + + Args: + query: The query string. + k: The number of Documents to return from the initial vector search. + Defaults to 4. + depth: The maximum depth of edges to traverse. Defaults to 1. + filter: Optional metadata to filter the results. + **kwargs: Additional keyword arguments. + + Returns: + Collection of retrieved documents. + """ + + async def collect_docs() -> Iterable[Document]: + async_iter = self.atraversal_search( + query=query, + k=k, + depth=depth, + filter=filter, + **kwargs, + ) + return [doc async for doc in async_iter] + + return asyncio.run(collect_docs()) + + async def _get_outgoing_links(self, source_ids: Iterable[str]) -> set[Link]: + """Return the set of outgoing links for the given source IDs asynchronously. + + Args: + source_ids: The IDs of the source nodes to retrieve outgoing links for. + + Returns: + A set of `Link` objects representing the outgoing links from the source + nodes. + """ + links = set() + + # Create coroutine objects without scheduling them yet + coroutines = [ + self.vector_store.aget_by_document_id(document_id=source_id) + for source_id in source_ids + ] + + # Schedule and await all coroutines + docs = await asyncio.gather(*coroutines) + + for doc in docs: + if doc is not None: + node = _doc_to_node(doc=doc) + links.update(_outgoing_links(node=node)) + + return links + + async def _get_adjacent( + self, + links: set[Link], + query_embedding: list[float], + retrieved_docs: dict[str, Document], + k_per_link: int | None = None, + filter: dict[str, Any] | None = None, # noqa: A002 + ) -> Iterable[AdjacentNode]: + """Return the target nodes with incoming links from any of the given links. + + Args: + links: The links to look for. + query_embedding: The query embedding. Used to rank target nodes. + retrieved_docs: A cache of retrieved docs. This will be added to. + k_per_link: The number of target nodes to fetch for each link. + filter: Optional metadata to filter the results. + + Returns: + Iterable of adjacent edges. + """ + targets: dict[str, AdjacentNode] = {} + + tasks = [] + for link in links: + metadata_filter = self._get_metadata_filter( + metadata=filter, + outgoing_link=link, + ) + + tasks.append( + self.vector_store.asimilarity_search_with_embedding_id_by_vector( + embedding=query_embedding, + k=k_per_link or 10, + filter=metadata_filter, + ) + ) + + results = await asyncio.gather(*tasks) + + for result in results: + for doc, embedding, doc_id in result: + if doc_id not in retrieved_docs: + retrieved_docs[doc_id] = doc + if doc_id not in targets: + node = _doc_to_node(doc=doc) + targets[doc_id] = AdjacentNode(node=node, embedding=embedding) + + # TODO: Consider a combined limit based on the similarity and/or + # predicated MMR score? 
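# Traversal sketch for the synchronous wrapper defined above (assumed `store`
# and metadata key as before):
docs = list(
    store.traversal_search(
        "capital of France",
        k=4,
        depth=1,
        filter={"topic": "geography"},
    )
)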
+ return targets.values() + + @staticmethod + def _build_docs_from_texts( + texts: List[str], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + ) -> List[Document]: + docs: List[Document] = [] + for i, text in enumerate(texts): + doc = Document( + page_content=text, + ) + if metadatas is not None: + doc.metadata = metadatas[i] + if ids is not None: + doc.id = ids[i] + docs.append(doc) + return docs + + @classmethod + def from_texts( + cls: Type[CGVST], + texts: List[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + *, + session: Optional[Session] = None, + keyspace: Optional[str] = None, + table_name: str = "", + ids: Optional[List[str]] = None, + ttl_seconds: Optional[int] = None, + body_index_options: Optional[List[Tuple[str, Any]]] = None, + metadata_deny_list: Optional[list[str]] = None, + **kwargs: Any, + ) -> CGVST: + """Create a CassandraGraphVectorStore from raw texts. + + Args: + texts: Texts to add to the vectorstore. + embedding: Embedding function to use. + metadatas: Optional list of metadatas associated with the texts. + session: Cassandra driver session. + If not provided, it is resolved from cassio. + keyspace: Cassandra key space. + If not provided, it is resolved from cassio. + table_name: Cassandra table (required). + ids: Optional list of IDs associated with the texts. + ttl_seconds: Optional time-to-live for the added texts. + body_index_options: Optional options used to create the body index. + Eg. body_index_options = [cassio.table.cql.STANDARD_ANALYZER] + metadata_deny_list: Optional list of metadata keys to not index. + i.e. to fine-tune which of the metadata fields are indexed. + Note: if you plan to have massive unique text metadata entries, + consider not indexing them for performance + (and to overcome max-length limitations). + Note: the `metadata_indexing` parameter from + langchain_community.utilities.cassandra.Cassandra is not + exposed since CassandraGraphVectorStore only supports the + deny_list option. + + Returns: + a CassandraGraphVectorStore. + """ + docs = cls._build_docs_from_texts( + texts=texts, + metadatas=metadatas, + ids=ids, ) - return nodes_to_documents(nodes) + + return cls.from_documents( + documents=docs, + embedding=embedding, + session=session, + keyspace=keyspace, + table_name=table_name, + ttl_seconds=ttl_seconds, + body_index_options=body_index_options, + metadata_deny_list=metadata_deny_list, + **kwargs, + ) + + @classmethod + async def afrom_texts( + cls: Type[CGVST], + texts: List[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + *, + session: Optional[Session] = None, + keyspace: Optional[str] = None, + table_name: str = "", + ids: Optional[List[str]] = None, + ttl_seconds: Optional[int] = None, + body_index_options: Optional[List[Tuple[str, Any]]] = None, + metadata_deny_list: Optional[list[str]] = None, + **kwargs: Any, + ) -> CGVST: + """Create a CassandraGraphVectorStore from raw texts. + + Args: + texts: Texts to add to the vectorstore. + embedding: Embedding function to use. + metadatas: Optional list of metadatas associated with the texts. + session: Cassandra driver session. + If not provided, it is resolved from cassio. + keyspace: Cassandra key space. + If not provided, it is resolved from cassio. + table_name: Cassandra table (required). + ids: Optional list of IDs associated with the texts. + ttl_seconds: Optional time-to-live for the added texts. + body_index_options: Optional options used to create the body index. + Eg. 
body_index_options = [cassio.table.cql.STANDARD_ANALYZER] + metadata_deny_list: Optional list of metadata keys to not index. + i.e. to fine-tune which of the metadata fields are indexed. + Note: if you plan to have massive unique text metadata entries, + consider not indexing them for performance + (and to overcome max-length limitations). + Note: the `metadata_indexing` parameter from + langchain_community.utilities.cassandra.Cassandra is not + exposed since CassandraGraphVectorStore only supports the + deny_list option. + + Returns: + a CassandraGraphVectorStore. + """ + docs = cls._build_docs_from_texts( + texts=texts, + metadatas=metadatas, + ids=ids, + ) + + return await cls.afrom_documents( + documents=docs, + embedding=embedding, + session=session, + keyspace=keyspace, + table_name=table_name, + ttl_seconds=ttl_seconds, + body_index_options=body_index_options, + metadata_deny_list=metadata_deny_list, + **kwargs, + ) + + @staticmethod + def _add_ids_to_docs( + docs: List[Document], + ids: Optional[List[str]] = None, + ) -> List[Document]: + if ids is not None: + for doc, doc_id in zip(docs, ids): + doc.id = doc_id + return docs + + @classmethod + def from_documents( + cls: Type[CGVST], + documents: List[Document], + embedding: Embeddings, + *, + session: Optional[Session] = None, + keyspace: Optional[str] = None, + table_name: str = "", + ids: Optional[List[str]] = None, + ttl_seconds: Optional[int] = None, + body_index_options: Optional[List[Tuple[str, Any]]] = None, + metadata_deny_list: Optional[list[str]] = None, + **kwargs: Any, + ) -> CGVST: + """Create a CassandraGraphVectorStore from a document list. + + Args: + documents: Documents to add to the vectorstore. + embedding: Embedding function to use. + session: Cassandra driver session. + If not provided, it is resolved from cassio. + keyspace: Cassandra key space. + If not provided, it is resolved from cassio. + table_name: Cassandra table (required). + ids: Optional list of IDs associated with the documents. + ttl_seconds: Optional time-to-live for the added documents. + body_index_options: Optional options used to create the body index. + Eg. body_index_options = [cassio.table.cql.STANDARD_ANALYZER] + metadata_deny_list: Optional list of metadata keys to not index. + i.e. to fine-tune which of the metadata fields are indexed. + Note: if you plan to have massive unique text metadata entries, + consider not indexing them for performance + (and to overcome max-length limitations). + Note: the `metadata_indexing` parameter from + langchain_community.utilities.cassandra.Cassandra is not + exposed since CassandraGraphVectorStore only supports the + deny_list option. + + Returns: + a CassandraGraphVectorStore. + """ + store = cls( + embedding=embedding, + session=session, + keyspace=keyspace, + table_name=table_name, + ttl_seconds=ttl_seconds, + body_index_options=body_index_options, + metadata_deny_list=metadata_deny_list, + **kwargs, + ) + store.add_documents(documents=cls._add_ids_to_docs(docs=documents, ids=ids)) + return store + + @classmethod + async def afrom_documents( + cls: Type[CGVST], + documents: List[Document], + embedding: Embeddings, + *, + session: Optional[Session] = None, + keyspace: Optional[str] = None, + table_name: str = "", + ids: Optional[List[str]] = None, + ttl_seconds: Optional[int] = None, + body_index_options: Optional[List[Tuple[str, Any]]] = None, + metadata_deny_list: Optional[list[str]] = None, + **kwargs: Any, + ) -> CGVST: + """Create a CassandraGraphVectorStore from a document list. 
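# Classmethod construction sketch; the session, keyspace, and table name are
# placeholders, and the link rides along in the `links` metadata key as
# described in the base-class docstrings.
from langchain_core.documents import Document
from langchain_community.graph_vectorstores.links import Link

doc = Document(
    page_content="Paris is the capital of France.",
    metadata={"links": [Link(kind="kw", direction="bidir", tag="france")]},
)
store = CassandraGraphVectorStore.from_documents(
    documents=[doc],
    embedding=my_embeddings,
    session=my_session,
    keyspace="my_keyspace",
    table_name="my_graph_vector_store",
)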
+ + Args: + documents: Documents to add to the vectorstore. + embedding: Embedding function to use. + session: Cassandra driver session. + If not provided, it is resolved from cassio. + keyspace: Cassandra key space. + If not provided, it is resolved from cassio. + table_name: Cassandra table (required). + ids: Optional list of IDs associated with the documents. + ttl_seconds: Optional time-to-live for the added documents. + body_index_options: Optional options used to create the body index. + Eg. body_index_options = [cassio.table.cql.STANDARD_ANALYZER] + metadata_deny_list: Optional list of metadata keys to not index. + i.e. to fine-tune which of the metadata fields are indexed. + Note: if you plan to have massive unique text metadata entries, + consider not indexing them for performance + (and to overcome max-length limitations). + Note: the `metadata_indexing` parameter from + langchain_community.utilities.cassandra.Cassandra is not + exposed since CassandraGraphVectorStore only supports the + deny_list option. + + + Returns: + a CassandraGraphVectorStore. + """ + store = cls( + embedding=embedding, + session=session, + keyspace=keyspace, + table_name=table_name, + ttl_seconds=ttl_seconds, + setup_mode=SetupMode.ASYNC, + body_index_options=body_index_options, + metadata_deny_list=metadata_deny_list, + **kwargs, + ) + await store.aadd_documents( + documents=cls._add_ids_to_docs(docs=documents, ids=ids) + ) + return store diff --git a/libs/community/langchain_community/graph_vectorstores/mmr_helper.py b/libs/community/langchain_community/graph_vectorstores/mmr_helper.py new file mode 100644 index 00000000000000..43aa8c0949fc43 --- /dev/null +++ b/libs/community/langchain_community/graph_vectorstores/mmr_helper.py @@ -0,0 +1,272 @@ +"""Tools for the Graph Traversal Maximal Marginal Relevance (MMR) reranking.""" + +from __future__ import annotations + +import dataclasses +from typing import TYPE_CHECKING, Iterable + +import numpy as np + +from langchain_community.utils.math import cosine_similarity + +if TYPE_CHECKING: + from numpy.typing import NDArray + + +def _emb_to_ndarray(embedding: list[float]) -> NDArray[np.float32]: + emb_array = np.array(embedding, dtype=np.float32) + if emb_array.ndim == 1: + emb_array = np.expand_dims(emb_array, axis=0) + return emb_array + + +NEG_INF = float("-inf") + + +@dataclasses.dataclass +class _Candidate: + id: str + similarity: float + weighted_similarity: float + weighted_redundancy: float + score: float = dataclasses.field(init=False) + + def __post_init__(self) -> None: + self.score = self.weighted_similarity - self.weighted_redundancy + + def update_redundancy(self, new_weighted_redundancy: float) -> None: + if new_weighted_redundancy > self.weighted_redundancy: + self.weighted_redundancy = new_weighted_redundancy + self.score = self.weighted_similarity - self.weighted_redundancy + + +class MmrHelper: + """Helper for executing an MMR traversal query. + + Args: + query_embedding: The embedding of the query to use for scoring. + lambda_mult: Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding to maximum + diversity and 1 to minimum diversity. Defaults to 0.5. + score_threshold: Only documents with a score greater than or equal + this threshold will be chosen. Defaults to -infinity. + """ + + dimensions: int + """Dimensions of the embedding.""" + + query_embedding: NDArray[np.float32] + """Embedding of the query as a (1,dim) ndarray.""" + + lambda_mult: float + """Number between 0 and 1. 
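# The candidate score maintained by this helper (see _Candidate above) is the
# usual MMR trade-off:
#   score = lambda_mult * similarity(candidate, query)
#           - (1 - lambda_mult) * max(similarity(candidate, s) for s in selected)
# e.g. with lambda_mult = 0.5, query similarity 0.9 and worst-case redundancy 0.6:
#   score = 0.5 * 0.9 - 0.5 * 0.6 = 0.15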
+ + Determines the degree of diversity among the results with 0 corresponding to + maximum diversity and 1 to minimum diversity.""" + + lambda_mult_complement: float + """1 - lambda_mult.""" + + score_threshold: float + """Only documents with a score greater than or equal to this will be chosen.""" + + selected_ids: list[str] + """List of selected IDs (in selection order).""" + + selected_mmr_scores: list[float] + """List of MMR score at the time each document is selected.""" + + selected_similarity_scores: list[float] + """List of similarity score for each selected document.""" + + selected_embeddings: NDArray[np.float32] + """(N, dim) ndarray with a row for each selected node.""" + + candidate_id_to_index: dict[str, int] + """Dictionary of candidate IDs to indices in candidates and candidate_embeddings.""" + candidates: list[_Candidate] + """List containing information about candidates. + + Same order as rows in `candidate_embeddings`. + """ + candidate_embeddings: NDArray[np.float32] + """(N, dim) ndarray with a row for each candidate.""" + + best_score: float + best_id: str | None + + def __init__( + self, + k: int, + query_embedding: list[float], + lambda_mult: float = 0.5, + score_threshold: float = NEG_INF, + ) -> None: + """Create a new Traversal MMR helper.""" + self.query_embedding = _emb_to_ndarray(query_embedding) + self.dimensions = self.query_embedding.shape[1] + + self.lambda_mult = lambda_mult + self.lambda_mult_complement = 1 - lambda_mult + self.score_threshold = score_threshold + + self.selected_ids = [] + self.selected_similarity_scores = [] + self.selected_mmr_scores = [] + + # List of selected embeddings (in selection order). + self.selected_embeddings = np.ndarray((k, self.dimensions), dtype=np.float32) + + self.candidate_id_to_index = {} + + # List of the candidates. + self.candidates = [] + # numpy n-dimensional array of the candidate embeddings. + self.candidate_embeddings = np.ndarray((0, self.dimensions), dtype=np.float32) + + self.best_score = NEG_INF + self.best_id = None + + def candidate_ids(self) -> Iterable[str]: + """Return the IDs of the candidates.""" + return self.candidate_id_to_index.keys() + + def _already_selected_embeddings(self) -> NDArray[np.float32]: + """Return the selected embeddings sliced to the already assigned values.""" + selected = len(self.selected_ids) + return np.vsplit(self.selected_embeddings, [selected])[0] + + def _pop_candidate(self, candidate_id: str) -> tuple[float, NDArray[np.float32]]: + """Pop the candidate with the given ID. + + Returns: + The similarity score and embedding of the candidate. + """ + # Get the embedding for the id. + index = self.candidate_id_to_index.pop(candidate_id) + if self.candidates[index].id != candidate_id: + msg = ( + "ID in self.candidate_id_to_index doesn't match the ID of the " + "corresponding index in self.candidates" + ) + raise ValueError(msg) + embedding: NDArray[np.float32] = self.candidate_embeddings[index].copy() + + # Swap that index with the last index in the candidates and + # candidate_embeddings. + last_index = self.candidate_embeddings.shape[0] - 1 + + similarity = 0.0 + if index == last_index: + # Already the last item. We don't need to swap. 
+ similarity = self.candidates.pop().similarity + else: + self.candidate_embeddings[index] = self.candidate_embeddings[last_index] + + similarity = self.candidates[index].similarity + + old_last = self.candidates.pop() + self.candidates[index] = old_last + self.candidate_id_to_index[old_last.id] = index + + self.candidate_embeddings = np.vsplit(self.candidate_embeddings, [last_index])[ + 0 + ] + + return similarity, embedding + + def pop_best(self) -> str | None: + """Select and pop the best item being considered. + + Updates the consideration set based on it. + + Returns: + The ID of the best item, or None if no candidate meets the score threshold. + """ + if self.best_id is None or self.best_score < self.score_threshold: + return None + + # Get the selection and remove from candidates. + selected_id = self.best_id + selected_similarity, selected_embedding = self._pop_candidate(selected_id) + + # Add the ID and embedding to the selected information. + selection_index = len(self.selected_ids) + self.selected_ids.append(selected_id) + self.selected_mmr_scores.append(self.best_score) + self.selected_similarity_scores.append(selected_similarity) + self.selected_embeddings[selection_index] = selected_embedding + + # Reset the best score / best ID. + self.best_score = NEG_INF + self.best_id = None + + # Update the candidates' redundancy, tracking the best node. + if self.candidate_embeddings.shape[0] > 0: + similarity = cosine_similarity( + self.candidate_embeddings, np.expand_dims(selected_embedding, axis=0) + ) + for index, candidate in enumerate(self.candidates): + candidate.update_redundancy(similarity[index][0]) + if candidate.score > self.best_score: + self.best_score = candidate.score + self.best_id = candidate.id + + return selected_id + + def add_candidates(self, candidates: dict[str, list[float]]) -> None: + """Add candidates to the consideration set.""" + # Determine the keys to actually include. + # These are the candidates that aren't already selected + # or under consideration. + include_ids_set = set(candidates.keys()) + include_ids_set.difference_update(self.selected_ids) + include_ids_set.difference_update(self.candidate_id_to_index.keys()) + include_ids = list(include_ids_set) + + # Now, build up a matrix of the remaining candidate embeddings + # and add them to the candidate set. + new_embeddings: NDArray[np.float32] = np.ndarray( + ( + len(include_ids), + self.dimensions, + ) + ) + offset = self.candidate_embeddings.shape[0] + for index, candidate_id in enumerate(include_ids): + if candidate_id in include_ids: + self.candidate_id_to_index[candidate_id] = offset + index + embedding = candidates[candidate_id] + new_embeddings[index] = embedding + + # Compute the similarity to the query. + similarity = cosine_similarity(new_embeddings, self.query_embedding) + + # Compute the distance metrics of all pairs in the selected set with + # the new candidates. + redundancy = cosine_similarity( + new_embeddings, self._already_selected_embeddings() + ) + for index, candidate_id in enumerate(include_ids): + max_redundancy = 0.0 + if redundancy.shape[0] > 0: + max_redundancy = redundancy[index].max() + candidate = _Candidate( + id=candidate_id, + similarity=similarity[index][0], + weighted_similarity=self.lambda_mult * similarity[index][0], + weighted_redundancy=self.lambda_mult_complement * max_redundancy, + ) + self.candidates.append(candidate) + + if candidate.score >= self.best_score: + self.best_score = candidate.score + self.best_id = candidate.id + + # Add the new embeddings to the candidate set.
+ self.candidate_embeddings = np.vstack( + ( + self.candidate_embeddings, + new_embeddings, + ) + ) diff --git a/libs/community/langchain_community/vectorstores/cassandra.py b/libs/community/langchain_community/vectorstores/cassandra.py index 3e9ea17cbb0ae4..b5e63d3e3f2f79 100644 --- a/libs/community/langchain_community/vectorstores/cassandra.py +++ b/libs/community/langchain_community/vectorstores/cassandra.py @@ -4,6 +4,7 @@ import importlib.metadata import typing import uuid +import warnings from typing import ( Any, Awaitable, @@ -501,10 +502,13 @@ def _row_to_document(row: Dict[str, Any]) -> Document: ) def get_by_document_id(self, document_id: str) -> Document | None: - """Get by document ID. + """Retrieve a single document from the store, given its document ID. Args: - document_id: the document ID to get. + document_id: The document ID. + + Returns: + The document if it exists. Otherwise None. """ row = self.table.get(row_id=document_id) if row is None: @@ -512,10 +516,13 @@ def get_by_document_id(self, document_id: str) -> Document | None: return self._row_to_document(row=row) async def aget_by_document_id(self, document_id: str) -> Document | None: - """Get by document ID. + """Retrieve a single document from the store, given its document ID. Args: - document_id: the document ID to get. + document_id: The document ID. + + Returns: + The document if it exists. Otherwise None. """ row = await self.table.aget(row_id=document_id) if row is None: @@ -524,28 +531,30 @@ async def aget_by_document_id(self, document_id: str) -> Document | None: def metadata_search( self, - metadata: dict[str, Any] = {}, # noqa: B006 + filter: dict[str, Any] = {}, # noqa: B006 n: int = 5, ) -> Iterable[Document]: """Get documents via a metadata search. Args: - metadata: the metadata to query for. + filter: the metadata to query for. + n: the maximum number of documents to return. """ - rows = self.table.find_entries(metadata=metadata, n=n) + rows = self.table.find_entries(metadata=filter, n=n) return [self._row_to_document(row=row) for row in rows if row] async def ametadata_search( self, - metadata: dict[str, Any] = {}, # noqa: B006 + filter: dict[str, Any] = {}, # noqa: B006 n: int = 5, ) -> Iterable[Document]: """Get documents via a metadata search. Args: - metadata: the metadata to query for. + filter: the metadata to query for. + n: the maximum number of documents to return.
""" - rows = await self.table.afind_entries(metadata=metadata, n=n) + rows = await self.table.afind_entries(metadata=filter, n=n) return [self._row_to_document(row=row) for row in rows] async def asimilarity_search_with_embedding_id_by_vector( @@ -1126,6 +1135,24 @@ async def amax_marginal_relevance_search( body_search=body_search, ) + @staticmethod + def _build_docs_from_texts( + texts: List[str], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + ) -> List[Document]: + docs: List[Document] = [] + for i, text in enumerate(texts): + doc = Document( + page_content=text, + ) + if metadatas is not None: + doc.metadata = metadatas[i] + if ids is not None: + doc.id = ids[i] + docs.append(doc) + return docs + @classmethod def from_texts( cls: Type[CVST], @@ -1137,13 +1164,12 @@ def from_texts( keyspace: Optional[str] = None, table_name: str = "", ids: Optional[List[str]] = None, - batch_size: int = 16, ttl_seconds: Optional[int] = None, body_index_options: Optional[List[Tuple[str, Any]]] = None, metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all", **kwargs: Any, ) -> CVST: - """Create a Cassandra vectorstore from raw texts. + """Create a Cassandra vector store from raw texts. Args: texts: Texts to add to the vectorstore. @@ -1155,16 +1181,32 @@ def from_texts( If not provided, it is resolved from cassio. table_name: Cassandra table (required). ids: Optional list of IDs associated with the texts. - batch_size: Number of concurrent requests to send to the server. - Defaults to 16. ttl_seconds: Optional time-to-live for the added texts. body_index_options: Optional options used to create the body index. Eg. body_index_options = [cassio.table.cql.STANDARD_ANALYZER] + metadata_indexing: Optional specification of a metadata indexing policy, + i.e. to fine-tune which of the metadata fields are indexed. + It can be a string ("all" or "none"), or a 2-tuple. The following + means that all fields except 'f1', 'f2' ... are NOT indexed: + metadata_indexing=("allowlist", ["f1", "f2", ...]) + The following means all fields EXCEPT 'g1', 'g2', ... are indexed: + metadata_indexing("denylist", ["g1", "g2", ...]) + The default is to index every metadata field. + Note: if you plan to have massive unique text metadata entries, + consider not indexing them for performance + (and to overcome max-length limitations). Returns: - a Cassandra vectorstore. + a Cassandra vector store. """ - store = cls( + docs = cls._build_docs_from_texts( + texts=texts, + metadatas=metadatas, + ids=ids, + ) + + return cls.from_documents( + documents=docs, embedding=embedding, session=session, keyspace=keyspace, @@ -1172,11 +1214,8 @@ def from_texts( ttl_seconds=ttl_seconds, body_index_options=body_index_options, metadata_indexing=metadata_indexing, + **kwargs, ) - store.add_texts( - texts=texts, metadatas=metadatas, ids=ids, batch_size=batch_size - ) - return store @classmethod async def afrom_texts( @@ -1189,13 +1228,12 @@ async def afrom_texts( keyspace: Optional[str] = None, table_name: str = "", ids: Optional[List[str]] = None, - concurrency: int = 16, ttl_seconds: Optional[int] = None, body_index_options: Optional[List[Tuple[str, Any]]] = None, metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all", **kwargs: Any, ) -> CVST: - """Create a Cassandra vectorstore from raw texts. + """Create a Cassandra vector store from raw texts. Args: texts: Texts to add to the vectorstore. @@ -1207,29 +1245,51 @@ async def afrom_texts( If not provided, it is resolved from cassio. 
table_name: Cassandra table (required). ids: Optional list of IDs associated with the texts. - concurrency: Number of concurrent queries to send to the database. - Defaults to 16. ttl_seconds: Optional time-to-live for the added texts. body_index_options: Optional options used to create the body index. Eg. body_index_options = [cassio.table.cql.STANDARD_ANALYZER] + metadata_indexing: Optional specification of a metadata indexing policy, + i.e. to fine-tune which of the metadata fields are indexed. + It can be a string ("all" or "none"), or a 2-tuple. The following + means that all fields except 'f1', 'f2' ... are NOT indexed: + metadata_indexing=("allowlist", ["f1", "f2", ...]) + The following means all fields EXCEPT 'g1', 'g2', ... are indexed: + metadata_indexing("denylist", ["g1", "g2", ...]) + The default is to index every metadata field. + Note: if you plan to have massive unique text metadata entries, + consider not indexing them for performance + (and to overcome max-length limitations). Returns: - a Cassandra vectorstore. + a Cassandra vector store. """ - store = cls( + docs = cls._build_docs_from_texts( + texts=texts, + metadatas=metadatas, + ids=ids, + ) + + return await cls.afrom_documents( + documents=docs, embedding=embedding, session=session, keyspace=keyspace, table_name=table_name, ttl_seconds=ttl_seconds, - setup_mode=SetupMode.ASYNC, body_index_options=body_index_options, metadata_indexing=metadata_indexing, + **kwargs, ) - await store.aadd_texts( - texts=texts, metadatas=metadatas, ids=ids, concurrency=concurrency - ) - return store + + @staticmethod + def _add_ids_to_docs( + docs: List[Document], + ids: Optional[List[str]] = None, + ) -> List[Document]: + if ids is not None: + for doc, doc_id in zip(docs, ids): + doc.id = doc_id + return docs @classmethod def from_documents( @@ -1241,13 +1301,12 @@ def from_documents( keyspace: Optional[str] = None, table_name: str = "", ids: Optional[List[str]] = None, - batch_size: int = 16, ttl_seconds: Optional[int] = None, body_index_options: Optional[List[Tuple[str, Any]]] = None, metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all", **kwargs: Any, ) -> CVST: - """Create a Cassandra vectorstore from a document list. + """Create a Cassandra vector store from a document list. Args: documents: Documents to add to the vectorstore. @@ -1258,31 +1317,48 @@ def from_documents( If not provided, it is resolved from cassio. table_name: Cassandra table (required). ids: Optional list of IDs associated with the documents. - batch_size: Number of concurrent requests to send to the server. - Defaults to 16. ttl_seconds: Optional time-to-live for the added documents. body_index_options: Optional options used to create the body index. Eg. body_index_options = [cassio.table.cql.STANDARD_ANALYZER] + metadata_indexing: Optional specification of a metadata indexing policy, + i.e. to fine-tune which of the metadata fields are indexed. + It can be a string ("all" or "none"), or a 2-tuple. The following + means that all fields except 'f1', 'f2' ... are NOT indexed: + metadata_indexing=("allowlist", ["f1", "f2", ...]) + The following means all fields EXCEPT 'g1', 'g2', ... are indexed: + metadata_indexing("denylist", ["g1", "g2", ...]) + The default is to index every metadata field. + Note: if you plan to have massive unique text metadata entries, + consider not indexing them for performance + (and to overcome max-length limitations). Returns: - a Cassandra vectorstore. + a Cassandra vector store. 
""" - texts = [doc.page_content for doc in documents] - metadatas = [doc.metadata for doc in documents] - return cls.from_texts( - texts=texts, + if ids is not None: + warnings.warn( + ( + "Parameter `ids` to Cassandra's `from_documents` " + "method is deprecated. Please set the supplied documents' " + "`.id` attribute instead. The id attribute of Document " + "is ignored as long as the `ids` parameter is passed." + ), + DeprecationWarning, + stacklevel=2, + ) + + store = cls( embedding=embedding, - metadatas=metadatas, session=session, keyspace=keyspace, table_name=table_name, - ids=ids, - batch_size=batch_size, ttl_seconds=ttl_seconds, body_index_options=body_index_options, metadata_indexing=metadata_indexing, **kwargs, ) + store.add_documents(documents=cls._add_ids_to_docs(docs=documents, ids=ids)) + return store @classmethod async def afrom_documents( @@ -1294,13 +1370,12 @@ async def afrom_documents( keyspace: Optional[str] = None, table_name: str = "", ids: Optional[List[str]] = None, - concurrency: int = 16, ttl_seconds: Optional[int] = None, body_index_options: Optional[List[Tuple[str, Any]]] = None, metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all", **kwargs: Any, ) -> CVST: - """Create a Cassandra vectorstore from a document list. + """Create a Cassandra vector store from a document list. Args: documents: Documents to add to the vectorstore. @@ -1311,31 +1386,51 @@ async def afrom_documents( If not provided, it is resolved from cassio. table_name: Cassandra table (required). ids: Optional list of IDs associated with the documents. - concurrency: Number of concurrent queries to send to the database. - Defaults to 16. ttl_seconds: Optional time-to-live for the added documents. body_index_options: Optional options used to create the body index. Eg. body_index_options = [cassio.table.cql.STANDARD_ANALYZER] + metadata_indexing: Optional specification of a metadata indexing policy, + i.e. to fine-tune which of the metadata fields are indexed. + It can be a string ("all" or "none"), or a 2-tuple. The following + means that all fields except 'f1', 'f2' ... are NOT indexed: + metadata_indexing=("allowlist", ["f1", "f2", ...]) + The following means all fields EXCEPT 'g1', 'g2', ... are indexed: + metadata_indexing("denylist", ["g1", "g2", ...]) + The default is to index every metadata field. + Note: if you plan to have massive unique text metadata entries, + consider not indexing them for performance + (and to overcome max-length limitations). Returns: - a Cassandra vectorstore. + a Cassandra vector store. """ - texts = [doc.page_content for doc in documents] - metadatas = [doc.metadata for doc in documents] - return await cls.afrom_texts( - texts=texts, + if ids is not None: + warnings.warn( + ( + "Parameter `ids` to Cassandra's `afrom_documents` " + "method is deprecated. Please set the supplied documents' " + "`.id` attribute instead. The id attribute of Document " + "is ignored as long as the `ids` parameter is passed." 
+ ), + DeprecationWarning, + stacklevel=2, + ) + + store = cls( embedding=embedding, - metadatas=metadatas, session=session, keyspace=keyspace, table_name=table_name, - ids=ids, - concurrency=concurrency, ttl_seconds=ttl_seconds, + setup_mode=SetupMode.ASYNC, body_index_options=body_index_options, metadata_indexing=metadata_indexing, **kwargs, ) + await store.aadd_documents( + documents=cls._add_ids_to_docs(docs=documents, ids=ids) + ) + return store def as_retriever( self, diff --git a/libs/community/tests/integration_tests/graph_vectorstores/test_cassandra.py b/libs/community/tests/integration_tests/graph_vectorstores/test_cassandra.py index 870c5141f8486e..d55f5469e546ad 100644 --- a/libs/community/tests/integration_tests/graph_vectorstores/test_cassandra.py +++ b/libs/community/tests/integration_tests/graph_vectorstores/test_cassandra.py @@ -1,116 +1,255 @@ -import math +"""Test of Apache Cassandra graph vector g_store class `CassandraGraphVectorStore`""" + +import json import os -from typing import Iterable, List, Optional, Type +import random +from contextlib import contextmanager +from typing import Any, Generator, Iterable, List, Optional +import pytest from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_community.graph_vectorstores import CassandraGraphVectorStore -from langchain_community.graph_vectorstores.links import METADATA_LINKS_KEY, Link +from langchain_community.graph_vectorstores.base import Node +from langchain_community.graph_vectorstores.links import ( + METADATA_LINKS_KEY, + Link, + add_links, +) +from tests.integration_tests.cache.fake_embeddings import ( + AngularTwoDimensionalEmbeddings, + FakeEmbeddings, +) + +TEST_KEYSPACE = "graph_test_keyspace" + + +class ParserEmbeddings(Embeddings): + """Parse input texts: if they are json for a List[float], fine. + Otherwise, return all zeros and call it a day. 
+ """ -CASSANDRA_DEFAULT_KEYSPACE = "graph_test_keyspace" + def __init__(self, dimension: int) -> None: + self.dimension = dimension + def embed_documents(self, texts: list[str]) -> list[list[float]]: + return [self.embed_query(txt) for txt in texts] -def _get_graph_store( - embedding_class: Type[Embeddings], documents: Iterable[Document] = () -) -> CassandraGraphVectorStore: - import cassio - from cassandra.cluster import Cluster - from cassio.config import check_resolve_session, resolve_keyspace + def embed_query(self, text: str) -> list[float]: + try: + vals = json.loads(text) + except json.JSONDecodeError: + return [0.0] * self.dimension + else: + assert len(vals) == self.dimension + return vals - node_table = "graph_test_node_table" - edge_table = "graph_test_edge_table" - if any( - env_var in os.environ - for env_var in [ - "CASSANDRA_CONTACT_POINTS", - "ASTRA_DB_APPLICATION_TOKEN", - "ASTRA_DB_INIT_STRING", - ] - ): - cassio.init(auto=True) - session = check_resolve_session() - else: - cluster = Cluster() - session = cluster.connect() - keyspace = resolve_keyspace() or CASSANDRA_DEFAULT_KEYSPACE - cassio.init(session=session, keyspace=keyspace) - # ensure keyspace exists - session.execute( - ( - f"CREATE KEYSPACE IF NOT EXISTS {keyspace} " - f"WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}" - ) - ) - session.execute(f"DROP TABLE IF EXISTS {keyspace}.{node_table}") - session.execute(f"DROP TABLE IF EXISTS {keyspace}.{edge_table}") - store = CassandraGraphVectorStore.from_documents( - documents, - embedding=embedding_class(), - session=session, - keyspace=keyspace, - node_table=node_table, - targets_table=edge_table, - ) - return store +@pytest.fixture +def embedding_d2() -> Embeddings: + return ParserEmbeddings(dimension=2) -class FakeEmbeddings(Embeddings): - """Fake embeddings functionality for testing.""" +class EarthEmbeddings(Embeddings): + def get_vector_near(self, value: float) -> List[float]: + base_point = [value, (1 - value**2) ** 0.5] + fluctuation = random.random() / 100.0 + return [base_point[0] + fluctuation, base_point[1] - fluctuation] - def embed_documents(self, texts: List[str]) -> List[List[float]]: - """Return simple embeddings. - Embeddings encode each text as its index.""" - return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))] + def embed_documents(self, texts: list[str]) -> list[list[float]]: + return [self.embed_query(txt) for txt in texts] - async def aembed_documents(self, texts: List[str]) -> List[List[float]]: - return self.embed_documents(texts) + def embed_query(self, text: str) -> list[float]: + words = set(text.lower().split()) + if "earth" in words: + vector = self.get_vector_near(0.9) + elif {"planet", "world", "globe", "sphere"}.intersection(words): + vector = self.get_vector_near(0.8) + else: + vector = self.get_vector_near(0.1) + return vector - def embed_query(self, text: str) -> List[float]: - """Return constant query embeddings. - Embeddings are identical to embed_documents(texts)[0]. - Distance to each text will be that text's index, - as it was passed to embed_documents.""" - return [float(1.0)] * 9 + [float(0.0)] - async def aembed_query(self, text: str) -> List[float]: - return self.embed_query(text) +def _result_ids(docs: Iterable[Document]) -> List[Optional[str]]: + return [doc.id for doc in docs] -class AngularTwoDimensionalEmbeddings(Embeddings): +@pytest.fixture +def graph_vector_store_docs() -> list[Document]: """ - From angles (as strings in units of pi) to unit embedding vectors on a circle. 
+ This is a set of Documents to pre-populate a graph vector store, + with entries placed in a certain way. + + Space of the entries (under Euclidean similarity): + + A0 (*) + .... AL AR <.... + : | : + : | ^ : + v | . v + | : + TR | : BL + T0 --------------x-------------- B0 + TL | : BR + | : + | . + | . + | + FL FR + F0 + + the query point is meant to be at (*). + the A are bidirectionally with B + the A are outgoing to T + the A are incoming from F + The links are like: L with L, 0 with 0 and R with R. """ - def embed_documents(self, texts: List[str]) -> List[List[float]]: - """ - Make a list of texts into a list of embedding vectors. - """ - return [self.embed_query(text) for text in texts] - - def embed_query(self, text: str) -> List[float]: - """ - Convert input text to a 'vector' (list of floats). - If the text is a number, use it as the angle for the - unit vector in units of pi. - Any other input text becomes the singular result [0, 0] ! - """ - try: - angle = float(text) - return [math.cos(angle * math.pi), math.sin(angle * math.pi)] - except ValueError: - # Assume: just test string, no attention is paid to values. - return [0.0, 0.0] + docs_a = [ + Document(id="AL", page_content="[-1, 9]", metadata={"label": "AL"}), + Document(id="A0", page_content="[0, 10]", metadata={"label": "A0"}), + Document(id="AR", page_content="[1, 9]", metadata={"label": "AR"}), + ] + docs_b = [ + Document(id="BL", page_content="[9, 1]", metadata={"label": "BL"}), + Document(id="B0", page_content="[10, 0]", metadata={"label": "B0"}), + Document(id="BL", page_content="[9, -1]", metadata={"label": "BR"}), + ] + docs_f = [ + Document(id="FL", page_content="[1, -9]", metadata={"label": "FL"}), + Document(id="F0", page_content="[0, -10]", metadata={"label": "F0"}), + Document(id="FR", page_content="[-1, -9]", metadata={"label": "FR"}), + ] + docs_t = [ + Document(id="TL", page_content="[-9, -1]", metadata={"label": "TL"}), + Document(id="T0", page_content="[-10, 0]", metadata={"label": "T0"}), + Document(id="TR", page_content="[-9, 1]", metadata={"label": "TR"}), + ] + for doc_a, suffix in zip(docs_a, ["l", "0", "r"]): + add_links(doc_a, Link.bidir(kind="ab_example", tag=f"tag_{suffix}")) + add_links(doc_a, Link.outgoing(kind="at_example", tag=f"tag_{suffix}")) + add_links(doc_a, Link.incoming(kind="af_example", tag=f"tag_{suffix}")) + for doc_b, suffix in zip(docs_b, ["l", "0", "r"]): + add_links(doc_b, Link.bidir(kind="ab_example", tag=f"tag_{suffix}")) + for doc_t, suffix in zip(docs_t, ["l", "0", "r"]): + add_links(doc_t, Link.incoming(kind="at_example", tag=f"tag_{suffix}")) + for doc_f, suffix in zip(docs_f, ["l", "0", "r"]): + add_links(doc_f, Link.outgoing(kind="af_example", tag=f"tag_{suffix}")) + return docs_a + docs_b + docs_f + docs_t + + +class CassandraSession: + table_name: str + session: Any + + def __init__(self, table_name: str, session: Any): + self.table_name = table_name + self.session = session + + +@contextmanager +def get_cassandra_session( + table_name: str, drop: bool = True +) -> Generator[CassandraSession, None, None]: + """Initialize the Cassandra cluster and session""" + from cassandra.cluster import Cluster + if "CASSANDRA_CONTACT_POINTS" in os.environ: + contact_points = [ + cp.strip() + for cp in os.environ["CASSANDRA_CONTACT_POINTS"].split(",") + if cp.strip() + ] + else: + contact_points = None -def _result_ids(docs: Iterable[Document]) -> List[Optional[str]]: - return [doc.id for doc in docs] + cluster = Cluster(contact_points) + session = cluster.connect() + try: + 
session.execute( + ( + f"CREATE KEYSPACE IF NOT EXISTS {TEST_KEYSPACE}" + " WITH replication = " + "{'class': 'SimpleStrategy', 'replication_factor': 1}" + ) + ) + if drop: + session.execute(f"DROP TABLE IF EXISTS {TEST_KEYSPACE}.{table_name}") + + # Yield the session for usage + yield CassandraSession(table_name=table_name, session=session) + finally: + # Ensure proper shutdown/cleanup of resources + session.shutdown() + cluster.shutdown() + + +@pytest.fixture(scope="function") +def graph_vector_store_angular( + table_name: str = "graph_test_table", +) -> Generator[CassandraGraphVectorStore, None, None]: + with get_cassandra_session(table_name=table_name) as session: + yield CassandraGraphVectorStore( + embedding=AngularTwoDimensionalEmbeddings(), + session=session.session, + keyspace=TEST_KEYSPACE, + table_name=session.table_name, + ) -def test_mmr_traversal() -> None: - """ - Test end to end construction and MMR search. + +@pytest.fixture(scope="function") +def graph_vector_store_earth( + table_name: str = "graph_test_table", +) -> Generator[CassandraGraphVectorStore, None, None]: + with get_cassandra_session(table_name=table_name) as session: + yield CassandraGraphVectorStore( + embedding=EarthEmbeddings(), + session=session.session, + keyspace=TEST_KEYSPACE, + table_name=session.table_name, + ) + + +@pytest.fixture(scope="function") +def graph_vector_store_fake( + table_name: str = "graph_test_table", +) -> Generator[CassandraGraphVectorStore, None, None]: + with get_cassandra_session(table_name=table_name) as session: + yield CassandraGraphVectorStore( + embedding=FakeEmbeddings(), + session=session.session, + keyspace=TEST_KEYSPACE, + table_name=session.table_name, + ) + + +@pytest.fixture(scope="function") +def graph_vector_store_d2( + embedding_d2: Embeddings, + table_name: str = "graph_test_table", +) -> Generator[CassandraGraphVectorStore, None, None]: + with get_cassandra_session(table_name=table_name) as session: + yield CassandraGraphVectorStore( + embedding=embedding_d2, + session=session.session, + keyspace=TEST_KEYSPACE, + table_name=session.table_name, + ) + + +@pytest.fixture(scope="function") +def populated_graph_vector_store_d2( + graph_vector_store_d2: CassandraGraphVectorStore, + graph_vector_store_docs: list[Document], +) -> Generator[CassandraGraphVectorStore, None, None]: + graph_vector_store_d2.add_documents(graph_vector_store_docs) + yield graph_vector_store_d2 + + +def test_mmr_traversal(graph_vector_store_angular: CassandraGraphVectorStore) -> None: + """ Test end to end construction and MMR search. The embedding function used here ensures `texts` become the following vectors on a circle (numbered v0 through v3): @@ -128,140 +267,128 @@ def test_mmr_traversal() -> None: Both v2 and v3 are reachable via edges from v0, so once it is selected, those are both considered. 
""" - store = _get_graph_store(AngularTwoDimensionalEmbeddings) - - v0 = Document( + v0 = Node( id="v0", - page_content="-0.124", - metadata={ - METADATA_LINKS_KEY: [ - Link.outgoing(kind="explicit", tag="link"), - ], - }, + text="-0.124", + links=[ + Link.outgoing(kind="explicit", tag="link"), + ], ) - v1 = Document( + v1 = Node( id="v1", - page_content="+0.127", + text="+0.127", ) - v2 = Document( + v2 = Node( id="v2", - page_content="+0.25", - metadata={ - METADATA_LINKS_KEY: [ - Link.incoming(kind="explicit", tag="link"), - ], - }, + text="+0.25", + links=[ + Link.incoming(kind="explicit", tag="link"), + ], ) - v3 = Document( + v3 = Node( id="v3", - page_content="+1.0", - metadata={ - METADATA_LINKS_KEY: [ - Link.incoming(kind="explicit", tag="link"), - ], - }, + text="+1.0", + links=[ + Link.incoming(kind="explicit", tag="link"), + ], ) - store.add_documents([v0, v1, v2, v3]) - results = store.mmr_traversal_search("0.0", k=2, fetch_k=2) + g_store = graph_vector_store_angular + g_store.add_nodes([v0, v1, v2, v3]) + + results = g_store.mmr_traversal_search("0.0", k=2, fetch_k=2) assert _result_ids(results) == ["v0", "v2"] # With max depth 0, no edges are traversed, so this doesn't reach v2 or v3. # So it ends up picking "v1" even though it's similar to "v0". - results = store.mmr_traversal_search("0.0", k=2, fetch_k=2, depth=0) + results = g_store.mmr_traversal_search("0.0", k=2, fetch_k=2, depth=0) assert _result_ids(results) == ["v0", "v1"] # With max depth 0 but higher `fetch_k`, we encounter v2 - results = store.mmr_traversal_search("0.0", k=2, fetch_k=3, depth=0) + results = g_store.mmr_traversal_search("0.0", k=2, fetch_k=3, depth=0) assert _result_ids(results) == ["v0", "v2"] # v0 score is .46, v2 score is 0.16 so it won't be chosen. - results = store.mmr_traversal_search("0.0", k=2, score_threshold=0.2) + results = g_store.mmr_traversal_search("0.0", k=2, score_threshold=0.2) assert _result_ids(results) == ["v0"] # with k=4 we should get all of the documents. 
- results = store.mmr_traversal_search("0.0", k=4) + results = g_store.mmr_traversal_search("0.0", k=4) assert _result_ids(results) == ["v0", "v2", "v1", "v3"] -def test_write_retrieve_keywords() -> None: - from langchain_openai import OpenAIEmbeddings - - greetings = Document( +def test_write_retrieve_keywords( + graph_vector_store_earth: CassandraGraphVectorStore, +) -> None: + greetings = Node( id="greetings", - page_content="Typical Greetings", - metadata={ - METADATA_LINKS_KEY: [ - Link.incoming(kind="parent", tag="parent"), - ], - }, + text="Typical Greetings", + links=[ + Link.incoming(kind="parent", tag="parent"), + ], ) - doc1 = Document( + + node1 = Node( id="doc1", - page_content="Hello World", - metadata={ - METADATA_LINKS_KEY: [ - Link.outgoing(kind="parent", tag="parent"), - Link.bidir(kind="kw", tag="greeting"), - Link.bidir(kind="kw", tag="world"), - ], - }, + text="Hello World", + links=[ + Link.outgoing(kind="parent", tag="parent"), + Link.bidir(kind="kw", tag="greeting"), + Link.bidir(kind="kw", tag="world"), + ], ) - doc2 = Document( + + node2 = Node( id="doc2", - page_content="Hello Earth", - metadata={ - METADATA_LINKS_KEY: [ - Link.outgoing(kind="parent", tag="parent"), - Link.bidir(kind="kw", tag="greeting"), - Link.bidir(kind="kw", tag="earth"), - ], - }, + text="Hello Earth", + links=[ + Link.outgoing(kind="parent", tag="parent"), + Link.bidir(kind="kw", tag="greeting"), + Link.bidir(kind="kw", tag="earth"), + ], ) - store = _get_graph_store(OpenAIEmbeddings, [greetings, doc1, doc2]) + + g_store = graph_vector_store_earth + g_store.add_nodes(nodes=[greetings, node1, node2]) # Doc2 is more similar, but World and Earth are similar enough that doc1 also # shows up. - results: Iterable[Document] = store.similarity_search("Earth", k=2) + results: Iterable[Document] = g_store.similarity_search("Earth", k=2) assert _result_ids(results) == ["doc2", "doc1"] - results = store.similarity_search("Earth", k=1) + results = g_store.similarity_search("Earth", k=1) assert _result_ids(results) == ["doc2"] - results = store.traversal_search("Earth", k=2, depth=0) + results = g_store.traversal_search("Earth", k=2, depth=0) assert _result_ids(results) == ["doc2", "doc1"] - results = store.traversal_search("Earth", k=2, depth=1) + results = g_store.traversal_search("Earth", k=2, depth=1) assert _result_ids(results) == ["doc2", "doc1", "greetings"] # K=1 only pulls in doc2 (Hello Earth) - results = store.traversal_search("Earth", k=1, depth=0) + results = g_store.traversal_search("Earth", k=1, depth=0) assert _result_ids(results) == ["doc2"] # K=1 only pulls in doc2 (Hello Earth). Depth=1 traverses to parent and via # keyword edge. 
- results = store.traversal_search("Earth", k=1, depth=1) + results = g_store.traversal_search("Earth", k=1, depth=1) assert set(_result_ids(results)) == {"doc2", "doc1", "greetings"} -def test_metadata() -> None: - store = _get_graph_store(FakeEmbeddings) - store.add_documents( - [ - Document( - id="a", - page_content="A", - metadata={ - METADATA_LINKS_KEY: [ - Link.incoming(kind="hyperlink", tag="http://a"), - Link.bidir(kind="other", tag="foo"), - ], - "other": "some other field", - }, - ) - ] +def test_metadata(graph_vector_store_fake: CassandraGraphVectorStore) -> None: + doc_a = Node( + id="a", + text="A", + metadata={"other": "some other field"}, + links=[ + Link.incoming(kind="hyperlink", tag="http://a"), + Link.bidir(kind="other", tag="foo"), + ], ) - results = store.similarity_search("A") + + g_store = graph_vector_store_fake + g_store.add_nodes([doc_a]) + results = g_store.similarity_search("A") assert len(results) == 1 assert results[0].id == "a" metadata = results[0].metadata @@ -270,3 +397,274 @@ def test_metadata() -> None: Link.incoming(kind="hyperlink", tag="http://a"), Link.bidir(kind="other", tag="foo"), } + + +class TestCassandraGraphVectorStore: + def test_gvs_similarity_search_sync( + self, + populated_graph_vector_store_d2: CassandraGraphVectorStore, + ) -> None: + """Simple (non-graph) similarity search on a graph vector g_store.""" + g_store = populated_graph_vector_store_d2 + ss_response = g_store.similarity_search(query="[2, 10]", k=2) + ss_labels = [doc.metadata["label"] for doc in ss_response] + assert ss_labels == ["AR", "A0"] + ss_by_v_response = g_store.similarity_search_by_vector(embedding=[2, 10], k=2) + ss_by_v_labels = [doc.metadata["label"] for doc in ss_by_v_response] + assert ss_by_v_labels == ["AR", "A0"] + + async def test_gvs_similarity_search_async( + self, + populated_graph_vector_store_d2: CassandraGraphVectorStore, + ) -> None: + """Simple (non-graph) similarity search on a graph vector store.""" + g_store = populated_graph_vector_store_d2 + ss_response = await g_store.asimilarity_search(query="[2, 10]", k=2) + ss_labels = [doc.metadata["label"] for doc in ss_response] + assert ss_labels == ["AR", "A0"] + ss_by_v_response = await g_store.asimilarity_search_by_vector( + embedding=[2, 10], k=2 + ) + ss_by_v_labels = [doc.metadata["label"] for doc in ss_by_v_response] + assert ss_by_v_labels == ["AR", "A0"] + + def test_gvs_traversal_search_sync( + self, + populated_graph_vector_store_d2: CassandraGraphVectorStore, + ) -> None: + """Graph traversal search on a graph vector store.""" + g_store = populated_graph_vector_store_d2 + ts_response = g_store.traversal_search(query="[2, 10]", k=2, depth=2) + # this is a set, as some of the internals of trav.search are set-driven + # so ordering is not deterministic: + ts_labels = {doc.metadata["label"] for doc in ts_response} + assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"} + + async def test_gvs_traversal_search_async( + self, + populated_graph_vector_store_d2: CassandraGraphVectorStore, + ) -> None: + """Graph traversal search on a graph vector store.""" + g_store = populated_graph_vector_store_d2 + ts_labels = set() + async for doc in g_store.atraversal_search(query="[2, 10]", k=2, depth=2): + ts_labels.add(doc.metadata["label"]) + # this is a set, as some of the internals of trav.search are set-driven + # so ordering is not deterministic: + assert ts_labels == {"AR", "A0", "BR", "B0", "TR", "T0"} + + def test_gvs_mmr_traversal_search_sync( + self, + populated_graph_vector_store_d2: 
CassandraGraphVectorStore, + ) -> None: + """MMR Graph traversal search on a graph vector store.""" + g_store = populated_graph_vector_store_d2 + mt_response = g_store.mmr_traversal_search( + query="[2, 10]", + k=2, + depth=2, + fetch_k=1, + adjacent_k=2, + lambda_mult=0.1, + ) + # TODO: can this rightfully be a list (or must it be a set)? + mt_labels = {doc.metadata["label"] for doc in mt_response} + assert mt_labels == {"AR", "BR"} + + async def test_gvs_mmr_traversal_search_async( + self, + populated_graph_vector_store_d2: CassandraGraphVectorStore, + ) -> None: + """MMR Graph traversal search on a graph vector store.""" + g_store = populated_graph_vector_store_d2 + mt_labels = set() + async for doc in g_store.ammr_traversal_search( + query="[2, 10]", + k=2, + depth=2, + fetch_k=1, + adjacent_k=2, + lambda_mult=0.1, + ): + mt_labels.add(doc.metadata["label"]) + # TODO: can this rightfully be a list (or must it be a set)? + assert mt_labels == {"AR", "BR"} + + def test_gvs_metadata_search_sync( + self, + populated_graph_vector_store_d2: CassandraGraphVectorStore, + ) -> None: + """Metadata search on a graph vector store.""" + g_store = populated_graph_vector_store_d2 + mt_response = g_store.metadata_search( + filter={"label": "T0"}, + n=2, + ) + doc: Document = next(iter(mt_response)) + assert doc.page_content == "[-10, 0]" + links = doc.metadata["links"] + assert len(links) == 1 + link: Link = links.pop() + assert isinstance(link, Link) + assert link.direction == "in" + assert link.kind == "at_example" + assert link.tag == "tag_0" + + async def test_gvs_metadata_search_async( + self, + populated_graph_vector_store_d2: CassandraGraphVectorStore, + ) -> None: + """Metadata search on a graph vector store.""" + g_store = populated_graph_vector_store_d2 + mt_response = await g_store.ametadata_search( + filter={"label": "T0"}, + n=2, + ) + doc: Document = next(iter(mt_response)) + assert doc.page_content == "[-10, 0]" + links: set[Link] = doc.metadata["links"] + assert len(links) == 1 + link: Link = links.pop() + assert isinstance(link, Link) + assert link.direction == "in" + assert link.kind == "at_example" + assert link.tag == "tag_0" + + def test_gvs_get_by_document_id_sync( + self, + populated_graph_vector_store_d2: CassandraGraphVectorStore, + ) -> None: + """Get by document_id on a graph vector store.""" + g_store = populated_graph_vector_store_d2 + doc = g_store.get_by_document_id(document_id="FL") + assert doc is not None + assert doc.page_content == "[1, -9]" + links = doc.metadata["links"] + assert len(links) == 1 + link: Link = links.pop() + assert isinstance(link, Link) + assert link.direction == "out" + assert link.kind == "af_example" + assert link.tag == "tag_l" + + invalid_doc = g_store.get_by_document_id(document_id="invalid") + assert invalid_doc is None + + async def test_gvs_get_by_document_id_async( + self, + populated_graph_vector_store_d2: CassandraGraphVectorStore, + ) -> None: + """Get by document_id on a graph vector store.""" + g_store = populated_graph_vector_store_d2 + doc = await g_store.aget_by_document_id(document_id="FL") + assert doc is not None + assert doc.page_content == "[1, -9]" + links = doc.metadata["links"] + assert len(links) == 1 + link: Link = links.pop() + assert isinstance(link, Link) + assert link.direction == "out" + assert link.kind == "af_example" + assert link.tag == "tag_l" + + invalid_doc = await g_store.aget_by_document_id(document_id="invalid") + assert invalid_doc is None + + def test_gvs_from_texts( + self, + graph_vector_store_d2: 
CassandraGraphVectorStore, + ) -> None: + g_store = graph_vector_store_d2 + g_store.add_texts( + texts=["[1, 2]"], + metadatas=[{"md": 1}], + ids=["x_id"], + ) + + hits = g_store.similarity_search("[2, 1]", k=2) + assert len(hits) == 1 + assert hits[0].page_content == "[1, 2]" + assert hits[0].id == "x_id" + # there may be more re:graph structure. + assert hits[0].metadata["md"] == "1.0" + + def test_gvs_from_documents_containing_ids( + self, + graph_vector_store_d2: CassandraGraphVectorStore, + ) -> None: + the_document = Document( + page_content="[1, 2]", + metadata={"md": 1}, + id="x_id", + ) + g_store = graph_vector_store_d2 + g_store.add_documents([the_document]) + hits = g_store.similarity_search("[2, 1]", k=2) + assert len(hits) == 1 + assert hits[0].page_content == "[1, 2]" + assert hits[0].id == "x_id" + # there may be more re:graph structure. + assert hits[0].metadata["md"] == "1.0" + + def test_gvs_add_nodes_sync( + self, + *, + graph_vector_store_d2: CassandraGraphVectorStore, + ) -> None: + links0 = [ + Link(kind="kA", direction="out", tag="tA"), + Link(kind="kB", direction="bidir", tag="tB"), + ] + links1 = [ + Link(kind="kC", direction="in", tag="tC"), + ] + nodes = [ + Node(id="id0", text="[1, 0]", metadata={"m": 0}, links=links0), + Node(text="[-1, 0]", metadata={"m": 1}, links=links1), + ] + graph_vector_store_d2.add_nodes(nodes) + hits = graph_vector_store_d2.similarity_search_by_vector([0.9, 0.1]) + assert len(hits) == 2 + assert hits[0].id == "id0" + assert hits[0].page_content == "[1, 0]" + md0 = hits[0].metadata + assert md0["m"] == "0.0" + assert any(isinstance(v, set) for k, v in md0.items() if k != "m") + + assert hits[1].id != "id0" + assert hits[1].page_content == "[-1, 0]" + md1 = hits[1].metadata + assert md1["m"] == "1.0" + assert any(isinstance(v, set) for k, v in md1.items() if k != "m") + + async def test_gvs_add_nodes_async( + self, + *, + graph_vector_store_d2: CassandraGraphVectorStore, + ) -> None: + links0 = [ + Link(kind="kA", direction="out", tag="tA"), + Link(kind="kB", direction="bidir", tag="tB"), + ] + links1 = [ + Link(kind="kC", direction="in", tag="tC"), + ] + nodes = [ + Node(id="id0", text="[1, 0]", metadata={"m": 0}, links=links0), + Node(text="[-1, 0]", metadata={"m": 1}, links=links1), + ] + async for _ in graph_vector_store_d2.aadd_nodes(nodes): + pass + + hits = await graph_vector_store_d2.asimilarity_search_by_vector([0.9, 0.1]) + assert len(hits) == 2 + assert hits[0].id == "id0" + assert hits[0].page_content == "[1, 0]" + md0 = hits[0].metadata + assert md0["m"] == "0.0" + assert any(isinstance(v, set) for k, v in md0.items() if k != "m") + assert hits[1].id != "id0" + assert hits[1].page_content == "[-1, 0]" + md1 = hits[1].metadata + assert md1["m"] == "1.0" + assert any(isinstance(v, set) for k, v in md1.items() if k != "m") diff --git a/libs/community/tests/integration_tests/graph_vectorstores/test_upgrade_to_cassandra.py b/libs/community/tests/integration_tests/graph_vectorstores/test_upgrade_to_cassandra.py new file mode 100644 index 00000000000000..6a09a2c5cb04ba --- /dev/null +++ b/libs/community/tests/integration_tests/graph_vectorstores/test_upgrade_to_cassandra.py @@ -0,0 +1,269 @@ +"""Test of Upgrading to Apache Cassandra graph vector store class: +`CassandraGraphVectorStore` from an existing table used +by the Cassandra vector store class: `Cassandra` +""" + +from __future__ import annotations + +import json +import os +from contextlib import contextmanager +from typing import Any, Generator, Iterable, Optional, Tuple, 
Union + +import pytest +from langchain_core.documents import Document +from langchain_core.embeddings import Embeddings + +from langchain_community.graph_vectorstores import CassandraGraphVectorStore +from langchain_community.utilities.cassandra import SetupMode +from langchain_community.vectorstores import Cassandra + +TEST_KEYSPACE = "graph_test_keyspace" + +TABLE_NAME_ALLOW_INDEXING = "allow_graph_table" +TABLE_NAME_DEFAULT = "default_graph_table" +TABLE_NAME_DENY_INDEXING = "deny_graph_table" + + +class ParserEmbeddings(Embeddings): + """Parse input texts: if they are json for a List[float], fine. + Otherwise, return all zeros and call it a day. + """ + + def __init__(self, dimension: int) -> None: + self.dimension = dimension + + def embed_documents(self, texts: list[str]) -> list[list[float]]: + return [self.embed_query(txt) for txt in texts] + + async def aembed_documents(self, texts: list[str]) -> list[list[float]]: + return self.embed_documents(texts) + + def embed_query(self, text: str) -> list[float]: + try: + vals = json.loads(text) + except json.JSONDecodeError: + return [0.0] * self.dimension + else: + assert len(vals) == self.dimension + return vals + + async def aembed_query(self, text: str) -> list[float]: + return self.embed_query(text) + + +@pytest.fixture +def embedding_d2() -> Embeddings: + return ParserEmbeddings(dimension=2) + + +class CassandraSession: + table_name: str + session: Any + + def __init__(self, table_name: str, session: Any): + self.table_name = table_name + self.session = session + + +@contextmanager +def get_cassandra_session( + table_name: str, drop: bool = True +) -> Generator[CassandraSession, None, None]: + """Initialize the Cassandra cluster and session""" + from cassandra.cluster import Cluster + + if "CASSANDRA_CONTACT_POINTS" in os.environ: + contact_points = [ + cp.strip() + for cp in os.environ["CASSANDRA_CONTACT_POINTS"].split(",") + if cp.strip() + ] + else: + contact_points = None + + cluster = Cluster(contact_points) + session = cluster.connect() + + try: + session.execute( + ( + f"CREATE KEYSPACE IF NOT EXISTS {TEST_KEYSPACE}" + " WITH replication = " + "{'class': 'SimpleStrategy', 'replication_factor': 1}" + ) + ) + if drop: + session.execute(f"DROP TABLE IF EXISTS {TEST_KEYSPACE}.{table_name}") + + # Yield the session for usage + yield CassandraSession(table_name=table_name, session=session) + finally: + # Ensure proper shutdown/cleanup of resources + session.shutdown() + cluster.shutdown() + + +@contextmanager +def vector_store( + embedding: Embeddings, + table_name: str, + setup_mode: SetupMode, + metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all", + drop: bool = True, +) -> Generator[Cassandra, None, None]: + with get_cassandra_session(table_name=table_name, drop=drop) as session: + yield Cassandra( + table_name=session.table_name, + keyspace=TEST_KEYSPACE, + session=session.session, + embedding=embedding, + setup_mode=setup_mode, + metadata_indexing=metadata_indexing, + ) + + +@contextmanager +def graph_vector_store( + embedding: Embeddings, + table_name: str, + setup_mode: SetupMode, + metadata_deny_list: Optional[list[str]] = None, + drop: bool = True, +) -> Generator[CassandraGraphVectorStore, None, None]: + with get_cassandra_session(table_name=table_name, drop=drop) as session: + yield CassandraGraphVectorStore( + table_name=session.table_name, + keyspace=TEST_KEYSPACE, + session=session.session, + embedding=embedding, + setup_mode=setup_mode, + metadata_deny_list=metadata_deny_list, + ) + + +def 
_vs_indexing_policy(table_name: str) -> Union[Tuple[str, Iterable[str]], str]: + if table_name == TABLE_NAME_ALLOW_INDEXING: + return ("allowlist", ["test"]) + if table_name == TABLE_NAME_DEFAULT: + return "all" + if table_name == TABLE_NAME_DENY_INDEXING: + return ("denylist", ["test"]) + msg = f"Unknown table_name: {table_name} in _vs_indexing_policy()" + raise ValueError(msg) + + +class TestUpgradeToGraphVectorStore: + @pytest.mark.parametrize( + ("table_name", "gvs_setup_mode", "gvs_metadata_deny_list"), + [ + (TABLE_NAME_DEFAULT, SetupMode.SYNC, None), + (TABLE_NAME_DENY_INDEXING, SetupMode.SYNC, ["test"]), + (TABLE_NAME_DEFAULT, SetupMode.OFF, None), + (TABLE_NAME_DENY_INDEXING, SetupMode.OFF, ["test"]), + # for this one, even though the passed policy doesn't + # match the policy used to create the collection, + # there is no error since the SetupMode is OFF and + # and no attempt is made to re-create the collection. + (TABLE_NAME_DENY_INDEXING, SetupMode.OFF, None), + ], + ids=[ + "default_upgrade_no_policy_sync", + "deny_list_upgrade_same_policy_sync", + "default_upgrade_no_policy_off", + "deny_list_upgrade_same_policy_off", + "deny_list_upgrade_change_policy_off", + ], + ) + def test_upgrade_to_gvs_success_sync( + self, + *, + embedding_d2: Embeddings, + gvs_setup_mode: SetupMode, + table_name: str, + gvs_metadata_deny_list: list[str], + ) -> None: + doc_id = "AL" + doc_al = Document(id=doc_id, page_content="[-1, 9]", metadata={"label": "AL"}) + + # Create vector store using SetupMode.SYNC + with vector_store( + embedding=embedding_d2, + table_name=table_name, + setup_mode=SetupMode.SYNC, + metadata_indexing=_vs_indexing_policy(table_name=table_name), + drop=True, + ) as v_store: + # load a document to the vector store + v_store.add_documents([doc_al]) + + # get the document from the vector store + v_doc = v_store.get_by_document_id(document_id=doc_id) + assert v_doc is not None + assert v_doc.page_content == doc_al.page_content + + # Create a GRAPH Vector Store using the existing collection from above + # with setup_mode=gvs_setup_mode and indexing_policy=gvs_indexing_policy + with graph_vector_store( + embedding=embedding_d2, + table_name=table_name, + setup_mode=gvs_setup_mode, + metadata_deny_list=gvs_metadata_deny_list, + drop=False, + ) as gv_store: + # get the document from the GRAPH vector store + gv_doc = gv_store.get_by_document_id(document_id=doc_id) + assert gv_doc is not None + assert gv_doc.page_content == doc_al.page_content + + @pytest.mark.parametrize( + ("table_name", "gvs_setup_mode", "gvs_metadata_deny_list"), + [ + (TABLE_NAME_DEFAULT, SetupMode.ASYNC, None), + (TABLE_NAME_DENY_INDEXING, SetupMode.ASYNC, ["test"]), + ], + ids=[ + "default_upgrade_no_policy_async", + "deny_list_upgrade_same_policy_async", + ], + ) + async def test_upgrade_to_gvs_success_async( + self, + *, + embedding_d2: Embeddings, + gvs_setup_mode: SetupMode, + table_name: str, + gvs_metadata_deny_list: list[str], + ) -> None: + doc_id = "AL" + doc_al = Document(id=doc_id, page_content="[-1, 9]", metadata={"label": "AL"}) + + # Create vector store using SetupMode.ASYNC + with vector_store( + embedding=embedding_d2, + table_name=table_name, + setup_mode=SetupMode.ASYNC, + metadata_indexing=_vs_indexing_policy(table_name=table_name), + drop=True, + ) as v_store: + # load a document to the vector store + await v_store.aadd_documents([doc_al]) + + # get the document from the vector store + v_doc = await v_store.aget_by_document_id(document_id=doc_id) + assert v_doc is not None + assert 
v_doc.page_content == doc_al.page_content + + # Create a GRAPH Vector Store using the existing collection from above + # with setup_mode=gvs_setup_mode and indexing_policy=gvs_indexing_policy + with graph_vector_store( + embedding=embedding_d2, + table_name=table_name, + setup_mode=gvs_setup_mode, + metadata_deny_list=gvs_metadata_deny_list, + drop=False, + ) as gv_store: + # get the document from the GRAPH vector store + gv_doc = await gv_store.aget_by_document_id(document_id=doc_id) + assert gv_doc is not None + assert gv_doc.page_content == doc_al.page_content diff --git a/libs/community/tests/integration_tests/vectorstores/test_cassandra.py b/libs/community/tests/integration_tests/vectorstores/test_cassandra.py index fd55bab2d3163b..70158f1fc5e8f7 100644 --- a/libs/community/tests/integration_tests/vectorstores/test_cassandra.py +++ b/libs/community/tests/integration_tests/vectorstores/test_cassandra.py @@ -1,21 +1,38 @@ """Test Cassandra functionality.""" import asyncio +import json +import math import os import time -from typing import Iterable, List, Optional, Tuple, Type, Union +from contextlib import asynccontextmanager, contextmanager +from typing import ( + Any, + AsyncGenerator, + Generator, + Iterable, + List, + Optional, + Tuple, + Union, +) import pytest from langchain_core.documents import Document from langchain_community.vectorstores import Cassandra -from langchain_community.vectorstores.cassandra import SetupMode from tests.integration_tests.vectorstores.fake_embeddings import ( AngularTwoDimensionalEmbeddings, ConsistentFakeEmbeddings, Embeddings, ) +TEST_KEYSPACE = "vector_test_keyspace" + +# similarity threshold definitions +EUCLIDEAN_MIN_SIM_UNIT_VECTORS = 0.2 +MATCH_EPSILON = 0.0001 + def _strip_docs(documents: List[Document]) -> List[Document]: return [_strip_doc(doc) for doc in documents] @@ -28,18 +45,85 @@ def _strip_doc(document: Document) -> Document: ) -def _vectorstore_from_texts( - texts: List[str], - metadatas: Optional[List[dict]] = None, - embedding_class: Type[Embeddings] = ConsistentFakeEmbeddings, - drop: bool = True, - metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all", - table_name: str = "vector_test_table", -) -> Cassandra: +class ParserEmbeddings(Embeddings): + """Parse input texts: if they are json for a List[float], fine. + Otherwise, return all zeros and call it a day. 
+ """ + + def __init__(self, dimension: int) -> None: + self.dimension = dimension + + def embed_documents(self, texts: list[str]) -> list[list[float]]: + return [self.embed_query(txt) for txt in texts] + + def embed_query(self, text: str) -> list[float]: + try: + vals = json.loads(text) + except json.JSONDecodeError: + return [0.0] * self.dimension + else: + assert len(vals) == self.dimension + return vals + + +@pytest.fixture +def embedding_d2() -> Embeddings: + return ParserEmbeddings(dimension=2) + + +@pytest.fixture +def metadata_documents() -> list[Document]: + """Documents for metadata and id tests""" + return [ + Document( + id="q", + page_content="[1,2]", + metadata={"ord": str(ord("q")), "group": "consonant", "letter": "q"}, + ), + Document( + id="w", + page_content="[3,4]", + metadata={"ord": str(ord("w")), "group": "consonant", "letter": "w"}, + ), + Document( + id="r", + page_content="[5,6]", + metadata={"ord": str(ord("r")), "group": "consonant", "letter": "r"}, + ), + Document( + id="e", + page_content="[-1,2]", + metadata={"ord": str(ord("e")), "group": "vowel", "letter": "e"}, + ), + Document( + id="i", + page_content="[-3,4]", + metadata={"ord": str(ord("i")), "group": "vowel", "letter": "i"}, + ), + Document( + id="o", + page_content="[-5,6]", + metadata={"ord": str(ord("o")), "group": "vowel", "letter": "o"}, + ), + ] + + +class CassandraSession: + table_name: str + session: Any + + def __init__(self, table_name: str, session: Any): + self.table_name = table_name + self.session = session + + +@contextmanager +def get_cassandra_session( + table_name: str, drop: bool = True +) -> Generator[CassandraSession, None, None]: + """Initialize the Cassandra cluster and session""" from cassandra.cluster import Cluster - keyspace = "vector_test_keyspace" - # get db connection if "CASSANDRA_CONTACT_POINTS" in os.environ: contact_points = [ cp.strip() @@ -48,107 +132,133 @@ def _vectorstore_from_texts( ] else: contact_points = None + cluster = Cluster(contact_points) session = cluster.connect() - # ensure keyspace exists - session.execute( - ( - f"CREATE KEYSPACE IF NOT EXISTS {keyspace} " - f"WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}" + + try: + session.execute( + ( + f"CREATE KEYSPACE IF NOT EXISTS {TEST_KEYSPACE}" + " WITH replication = " + "{'class': 'SimpleStrategy', 'replication_factor': 1}" + ) ) - ) - # drop table if required - if drop: - session.execute(f"DROP TABLE IF EXISTS {keyspace}.{table_name}") - # - return Cassandra.from_texts( - texts, - embedding_class(), - metadatas=metadatas, - session=session, - keyspace=keyspace, - table_name=table_name, - metadata_indexing=metadata_indexing, - ) + if drop: + session.execute(f"DROP TABLE IF EXISTS {TEST_KEYSPACE}.{table_name}") + # Yield the session for usage + yield CassandraSession(table_name=table_name, session=session) + finally: + # Ensure proper shutdown/cleanup of resources + session.shutdown() + cluster.shutdown() -async def _vectorstore_from_texts_async( + +@pytest.fixture +def cassandra_session( + request: pytest.FixtureRequest, +) -> Generator[CassandraSession, None, None]: + request_param = getattr(request, "param", {}) + table_name = request_param.get("table_name", "vector_test_table") + drop = request_param.get("drop", True) + + with get_cassandra_session(table_name, drop) as session: + yield session + + +@contextmanager +def vector_store_from_texts( texts: List[str], metadatas: Optional[List[dict]] = None, - embedding_class: Type[Embeddings] = ConsistentFakeEmbeddings, + embedding: 
Optional[Embeddings] = None, drop: bool = True, metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all", table_name: str = "vector_test_table", -) -> Cassandra: - from cassandra.cluster import Cluster +) -> Generator[Cassandra, None, None]: + if embedding is None: + embedding = ConsistentFakeEmbeddings() + with get_cassandra_session(table_name=table_name, drop=drop) as session: + yield Cassandra.from_texts( + texts, + embedding=embedding, + metadatas=metadatas, + session=session.session, + keyspace=TEST_KEYSPACE, + table_name=session.table_name, + metadata_indexing=metadata_indexing, + ) - keyspace = "vector_test_keyspace" - # get db connection - if "CASSANDRA_CONTACT_POINTS" in os.environ: - contact_points = [ - cp.strip() - for cp in os.environ["CASSANDRA_CONTACT_POINTS"].split(",") - if cp.strip() - ] - else: - contact_points = None - cluster = Cluster(contact_points) - session = cluster.connect() - # ensure keyspace exists - session.execute( - ( - f"CREATE KEYSPACE IF NOT EXISTS {keyspace} " - f"WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}" + +@asynccontextmanager +async def vector_store_from_texts_async( + texts: List[str], + metadatas: Optional[List[dict]] = None, + embedding: Optional[Embeddings] = None, + drop: bool = True, + metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all", + table_name: str = "vector_test_table", +) -> AsyncGenerator[Cassandra, None]: + if embedding is None: + embedding = ConsistentFakeEmbeddings() + with get_cassandra_session(table_name=table_name, drop=drop) as session: + yield await Cassandra.afrom_texts( + texts, + embedding=embedding, + metadatas=metadatas, + session=session.session, + keyspace=TEST_KEYSPACE, + table_name=session.table_name, + metadata_indexing=metadata_indexing, + ) + + +@pytest.fixture(scope="function") +def vector_store_d2( + embedding_d2: Embeddings, + table_name: str = "vector_test_table_d2", +) -> Generator[Cassandra, None, None]: + with get_cassandra_session(table_name=table_name) as session: + yield Cassandra( + embedding=embedding_d2, + session=session.session, + keyspace=TEST_KEYSPACE, + table_name=session.table_name, ) - ) - # drop table if required - if drop: - session.execute(f"DROP TABLE IF EXISTS {keyspace}.{table_name}") - # - return await Cassandra.afrom_texts( - texts, - embedding_class(), - metadatas=metadatas, - session=session, - keyspace=keyspace, - table_name=table_name, - setup_mode=SetupMode.ASYNC, - ) async def test_cassandra() -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] - docsearch = _vectorstore_from_texts(texts) - output = docsearch.similarity_search("foo", k=1) - assert _strip_docs(output) == _strip_docs([Document(page_content="foo")]) - output = await docsearch.asimilarity_search("foo", k=1) - assert _strip_docs(output) == _strip_docs([Document(page_content="foo")]) + with vector_store_from_texts(texts) as vstore: + output = vstore.similarity_search("foo", k=1) + assert _strip_docs(output) == _strip_docs([Document(page_content="foo")]) + output = await vstore.asimilarity_search("foo", k=1) + assert _strip_docs(output) == _strip_docs([Document(page_content="foo")]) async def test_cassandra_with_score() -> None: """Test end to end construction and search with scores and IDs.""" texts = ["foo", "bar", "baz"] metadatas = [{"page": i} for i in range(len(texts))] - docsearch = _vectorstore_from_texts(texts, metadatas=metadatas) - - expected_docs = [ - Document(page_content="foo", metadata={"page": "0.0"}), - 
Document(page_content="bar", metadata={"page": "1.0"}), - Document(page_content="baz", metadata={"page": "2.0"}), - ] + with vector_store_from_texts(texts, metadatas=metadatas) as vstore: + expected_docs = [ + Document(page_content="foo", metadata={"page": "0.0"}), + Document(page_content="bar", metadata={"page": "1.0"}), + Document(page_content="baz", metadata={"page": "2.0"}), + ] - output = docsearch.similarity_search_with_score("foo", k=3) - docs = [o[0] for o in output] - scores = [o[1] for o in output] - assert _strip_docs(docs) == _strip_docs(expected_docs) - assert scores[0] > scores[1] > scores[2] + output = vstore.similarity_search_with_score("foo", k=3) + docs = [o[0] for o in output] + scores = [o[1] for o in output] + assert _strip_docs(docs) == _strip_docs(expected_docs) + assert scores[0] > scores[1] > scores[2] - output = await docsearch.asimilarity_search_with_score("foo", k=3) - docs = [o[0] for o in output] - scores = [o[1] for o in output] - assert _strip_docs(docs) == _strip_docs(expected_docs) - assert scores[0] > scores[1] > scores[2] + output = await vstore.asimilarity_search_with_score("foo", k=3) + docs = [o[0] for o in output] + scores = [o[1] for o in output] + assert _strip_docs(docs) == _strip_docs(expected_docs) + assert scores[0] > scores[1] > scores[2] async def test_cassandra_max_marginal_relevance_search() -> None: @@ -169,285 +279,1265 @@ async def test_cassandra_max_marginal_relevance_search() -> None: """ texts = ["-0.124", "+0.127", "+0.25", "+1.0"] metadatas = [{"page": i} for i in range(len(texts))] - docsearch = _vectorstore_from_texts( - texts, metadatas=metadatas, embedding_class=AngularTwoDimensionalEmbeddings - ) - - expected_set = { - ("+0.25", "2.0"), - ("-0.124", "0.0"), - } - - output = docsearch.max_marginal_relevance_search("0.0", k=2, fetch_k=3) - output_set = { - (mmr_doc.page_content, mmr_doc.metadata["page"]) for mmr_doc in output - } - assert output_set == expected_set - - output = await docsearch.amax_marginal_relevance_search("0.0", k=2, fetch_k=3) - output_set = { - (mmr_doc.page_content, mmr_doc.metadata["page"]) for mmr_doc in output - } - assert output_set == expected_set + with vector_store_from_texts( + texts, + metadatas=metadatas, + embedding=AngularTwoDimensionalEmbeddings(), + ) as vstore: + expected_set = { + ("+0.25", "2.0"), + ("-0.124", "0.0"), + } + + output = vstore.max_marginal_relevance_search("0.0", k=2, fetch_k=3) + output_set = { + (mmr_doc.page_content, mmr_doc.metadata["page"]) for mmr_doc in output + } + assert output_set == expected_set + + output = await vstore.amax_marginal_relevance_search("0.0", k=2, fetch_k=3) + output_set = { + (mmr_doc.page_content, mmr_doc.metadata["page"]) for mmr_doc in output + } + assert output_set == expected_set def test_cassandra_add_texts() -> None: """Test end to end construction with further insertions.""" texts = ["foo", "bar", "baz"] metadatas = [{"page": i} for i in range(len(texts))] - docsearch = _vectorstore_from_texts(texts, metadatas=metadatas) - - texts2 = ["foo2", "bar2", "baz2"] - metadatas2 = [{"page": i + 3} for i in range(len(texts))] - docsearch.add_texts(texts2, metadatas2) + with vector_store_from_texts(texts, metadatas=metadatas) as vstore: + texts2 = ["foo2", "bar2", "baz2"] + metadatas2 = [{"page": i + 3} for i in range(len(texts))] + vstore.add_texts(texts2, metadatas2) - output = docsearch.similarity_search("foo", k=10) - assert len(output) == 6 + output = vstore.similarity_search("foo", k=10) + assert len(output) == 6 -async def 
test_cassandra_aadd_texts() -> None: +async def test_cassandra_add_texts_async() -> None: """Test end to end construction with further insertions.""" texts = ["foo", "bar", "baz"] metadatas = [{"page": i} for i in range(len(texts))] - docsearch = _vectorstore_from_texts(texts, metadatas=metadatas) + async with vector_store_from_texts_async(texts, metadatas=metadatas) as vstore: + texts2 = ["foo2", "bar2", "baz2"] + metadatas2 = [{"page": i + 3} for i in range(len(texts))] + await vstore.aadd_texts(texts2, metadatas2) - texts2 = ["foo2", "bar2", "baz2"] - metadatas2 = [{"page": i + 3} for i in range(len(texts))] - await docsearch.aadd_texts(texts2, metadatas2) - - output = await docsearch.asimilarity_search("foo", k=10) - assert len(output) == 6 + output = await vstore.asimilarity_search("foo", k=10) + assert len(output) == 6 def test_cassandra_no_drop() -> None: """Test end to end construction and re-opening the same index.""" texts = ["foo", "bar", "baz"] metadatas = [{"page": i} for i in range(len(texts))] - _vectorstore_from_texts(texts, metadatas=metadatas) + with vector_store_from_texts(texts, metadatas=metadatas) as vstore: + output = vstore.similarity_search("foo", k=10) + assert len(output) == 3 texts2 = ["foo2", "bar2", "baz2"] - docsearch = _vectorstore_from_texts(texts2, metadatas=metadatas, drop=False) - - output = docsearch.similarity_search("foo", k=10) - assert len(output) == 6 + with vector_store_from_texts(texts2, metadatas=metadatas, drop=False) as vstore: + output = vstore.similarity_search("foo", k=10) + assert len(output) == 6 async def test_cassandra_no_drop_async() -> None: """Test end to end construction and re-opening the same index.""" texts = ["foo", "bar", "baz"] metadatas = [{"page": i} for i in range(len(texts))] - await _vectorstore_from_texts_async(texts, metadatas=metadatas) + async with vector_store_from_texts_async(texts, metadatas=metadatas) as vstore: + output = await vstore.asimilarity_search("foo", k=10) + assert len(output) == 3 texts2 = ["foo2", "bar2", "baz2"] - docsearch = await _vectorstore_from_texts_async( + async with vector_store_from_texts_async( texts2, metadatas=metadatas, drop=False - ) - - output = await docsearch.asimilarity_search("foo", k=10) - assert len(output) == 6 + ) as vstore: + output = await vstore.asimilarity_search("foo", k=10) + assert len(output) == 6 def test_cassandra_delete() -> None: """Test delete methods from vector store.""" texts = ["foo", "bar", "baz", "gni"] metadatas = [{"page": i, "mod2": i % 2} for i in range(len(texts))] - docsearch = _vectorstore_from_texts([], metadatas=metadatas) + with vector_store_from_texts([], metadatas=metadatas) as vstore: + ids = vstore.add_texts(texts, metadatas) + output = vstore.similarity_search("foo", k=10) + assert len(output) == 4 - ids = docsearch.add_texts(texts, metadatas) - output = docsearch.similarity_search("foo", k=10) - assert len(output) == 4 + vstore.delete_by_document_id(ids[0]) + output = vstore.similarity_search("foo", k=10) + assert len(output) == 3 - docsearch.delete_by_document_id(ids[0]) - output = docsearch.similarity_search("foo", k=10) - assert len(output) == 3 + vstore.delete(ids[1:3]) + output = vstore.similarity_search("foo", k=10) + assert len(output) == 1 - docsearch.delete(ids[1:3]) - output = docsearch.similarity_search("foo", k=10) - assert len(output) == 1 + vstore.delete(["not-existing"]) + output = vstore.similarity_search("foo", k=10) + assert len(output) == 1 - docsearch.delete(["not-existing"]) - output = docsearch.similarity_search("foo", 
k=10) - assert len(output) == 1 + vstore.clear() + time.sleep(0.3) + output = vstore.similarity_search("foo", k=10) + assert len(output) == 0 - docsearch.clear() - time.sleep(0.3) - output = docsearch.similarity_search("foo", k=10) - assert len(output) == 0 + vstore.add_texts(texts, metadatas) + num_deleted = vstore.delete_by_metadata_filter({"mod2": 0}, batch_size=1) + assert num_deleted == 2 + output = vstore.similarity_search("foo", k=10) + assert len(output) == 2 + vstore.clear() - docsearch.add_texts(texts, metadatas) - num_deleted = docsearch.delete_by_metadata_filter({"mod2": 0}, batch_size=1) - assert num_deleted == 2 - output = docsearch.similarity_search("foo", k=10) - assert len(output) == 2 - docsearch.clear() + with pytest.raises(ValueError): + vstore.delete_by_metadata_filter({}) - with pytest.raises(ValueError): - docsearch.delete_by_metadata_filter({}) - -async def test_cassandra_adelete() -> None: +async def test_cassandra_delete_async() -> None: """Test delete methods from vector store.""" texts = ["foo", "bar", "baz", "gni"] metadatas = [{"page": i, "mod2": i % 2} for i in range(len(texts))] - docsearch = await _vectorstore_from_texts_async([], metadatas=metadatas) - - ids = await docsearch.aadd_texts(texts, metadatas) - output = await docsearch.asimilarity_search("foo", k=10) - assert len(output) == 4 + async with vector_store_from_texts_async([], metadatas=metadatas) as vstore: + ids = await vstore.aadd_texts(texts, metadatas) + output = await vstore.asimilarity_search("foo", k=10) + assert len(output) == 4 - await docsearch.adelete_by_document_id(ids[0]) - output = await docsearch.asimilarity_search("foo", k=10) - assert len(output) == 3 + await vstore.adelete_by_document_id(ids[0]) + output = await vstore.asimilarity_search("foo", k=10) + assert len(output) == 3 - await docsearch.adelete(ids[1:3]) - output = await docsearch.asimilarity_search("foo", k=10) - assert len(output) == 1 + await vstore.adelete(ids[1:3]) + output = await vstore.asimilarity_search("foo", k=10) + assert len(output) == 1 - await docsearch.adelete(["not-existing"]) - output = await docsearch.asimilarity_search("foo", k=10) - assert len(output) == 1 + await vstore.adelete(["not-existing"]) + output = await vstore.asimilarity_search("foo", k=10) + assert len(output) == 1 - await docsearch.aclear() - await asyncio.sleep(0.3) - output = docsearch.similarity_search("foo", k=10) - assert len(output) == 0 + await vstore.aclear() + await asyncio.sleep(0.3) + output = vstore.similarity_search("foo", k=10) + assert len(output) == 0 - await docsearch.aadd_texts(texts, metadatas) - num_deleted = await docsearch.adelete_by_metadata_filter({"mod2": 0}, batch_size=1) - assert num_deleted == 2 - output = await docsearch.asimilarity_search("foo", k=10) - assert len(output) == 2 - await docsearch.aclear() + await vstore.aadd_texts(texts, metadatas) + num_deleted = await vstore.adelete_by_metadata_filter({"mod2": 0}, batch_size=1) + assert num_deleted == 2 + output = await vstore.asimilarity_search("foo", k=10) + assert len(output) == 2 + await vstore.aclear() - with pytest.raises(ValueError): - await docsearch.adelete_by_metadata_filter({}) + with pytest.raises(ValueError): + await vstore.adelete_by_metadata_filter({}) def test_cassandra_metadata_indexing() -> None: """Test comparing metadata indexing policies.""" texts = ["foo"] metadatas = [{"field1": "a", "field2": "b"}] - vstore_all = _vectorstore_from_texts(texts, metadatas=metadatas) - vstore_f1 = _vectorstore_from_texts( - texts, - metadatas=metadatas, - 
metadata_indexing=("allowlist", ["field1"]), - table_name="vector_test_table_indexing", + with vector_store_from_texts(texts, metadatas=metadatas) as vstore_all: + with vector_store_from_texts( + texts, + metadatas=metadatas, + metadata_indexing=("allowlist", ["field1"]), + table_name="vector_test_table_indexing", + embedding=ConsistentFakeEmbeddings(), + ) as vstore_f1: + output_all = vstore_all.similarity_search("bar", k=2) + output_f1 = vstore_f1.similarity_search("bar", filter={"field1": "a"}, k=2) + output_f1_no = vstore_f1.similarity_search( + "bar", filter={"field1": "Z"}, k=2 + ) + assert len(output_all) == 1 + assert output_all[0].metadata == metadatas[0] + assert len(output_f1) == 1 + assert output_f1[0].metadata == metadatas[0] + assert len(output_f1_no) == 0 + + with pytest.raises(ValueError): + # "Non-indexed metadata fields cannot be used in queries." + vstore_f1.similarity_search("bar", filter={"field2": "b"}, k=2) + + +class TestCassandraVectorStore: + @pytest.mark.parametrize( + "page_contents", + [ + [ + "[1,2]", + "[3,4]", + "[5,6]", + "[7,8]", + "[9,10]", + "[11,12]", + ], + ], ) + def test_cassandra_vectorstore_from_texts_sync( + self, + *, + cassandra_session: CassandraSession, + embedding_d2: Embeddings, + page_contents: list[str], + ) -> None: + """from_texts methods and the associated warnings.""" + v_store = Cassandra.from_texts( + texts=page_contents[0:2], + metadatas=[{"m": 1}, {"m": 3}], + ids=["ft1", "ft3"], + table_name=cassandra_session.table_name, + session=cassandra_session.session, + keyspace=TEST_KEYSPACE, + embedding=embedding_d2, + ) + search_results_triples_0 = v_store.similarity_search_with_score_id( + page_contents[1], + k=1, + ) + assert len(search_results_triples_0) == 1 + res_doc_0, _, res_id_0 = search_results_triples_0[0] + assert res_doc_0.page_content == page_contents[1] + assert res_doc_0.metadata == {"m": "3.0"} + assert res_id_0 == "ft3" + + Cassandra.from_texts( + texts=page_contents[2:4], + metadatas=[{"m": 5}, {"m": 7}], + ids=["ft5", "ft7"], + table_name=cassandra_session.table_name, + session=cassandra_session.session, + keyspace=TEST_KEYSPACE, + embedding=embedding_d2, + ) - output_all = vstore_all.similarity_search("bar", k=2) - output_f1 = vstore_f1.similarity_search("bar", filter={"field1": "a"}, k=2) - output_f1_no = vstore_f1.similarity_search("bar", filter={"field1": "Z"}, k=2) - assert len(output_all) == 1 - assert output_all[0].metadata == metadatas[0] - assert len(output_f1) == 1 - assert output_f1[0].metadata == metadatas[0] - assert len(output_f1_no) == 0 - - with pytest.raises(ValueError): - # "Non-indexed metadata fields cannot be used in queries." - vstore_f1.similarity_search("bar", filter={"field2": "b"}, k=2) - - -def test_cassandra_replace_metadata() -> None: - """Test of replacing metadata.""" - N_DOCS = 100 - REPLACE_RATIO = 2 # one in ... 
will have replaced metadata - BATCH_SIZE = 3 - - vstore_f1 = _vectorstore_from_texts( - texts=[], - metadata_indexing=("allowlist", ["field1", "field2"]), - table_name="vector_test_table_indexing", + search_results_triples_1 = v_store.similarity_search_with_score_id( + page_contents[3], + k=1, + ) + assert len(search_results_triples_1) == 1 + res_doc_1, _, res_id_1 = search_results_triples_1[0] + assert res_doc_1.page_content == page_contents[3] + assert res_doc_1.metadata == {"m": "7.0"} + assert res_id_1 == "ft7" + v_store_2 = Cassandra.from_texts( + texts=page_contents[4:6], + metadatas=[{"m": 9}, {"m": 11}], + ids=["ft9", "ft11"], + table_name=cassandra_session.table_name, + session=cassandra_session.session, + keyspace=TEST_KEYSPACE, + embedding=embedding_d2, + ) + search_results_triples_2 = v_store_2.similarity_search_with_score_id( + page_contents[5], + k=1, + ) + assert len(search_results_triples_2) == 1 + res_doc_2, _, res_id_2 = search_results_triples_2[0] + assert res_doc_2.page_content == page_contents[5] + assert res_doc_2.metadata == {"m": "11.0"} + assert res_id_2 == "ft11" + v_store_2.clear() + + @pytest.mark.parametrize( + "page_contents", + [ + ["[1,2]", "[3,4]"], + ], ) - orig_documents = [ - Document( - page_content=f"doc_{doc_i}", - id=f"doc_id_{doc_i}", - metadata={"field1": f"f1_{doc_i}", "otherf": "pre"}, + def test_cassandra_vectorstore_from_documents_sync( + self, + *, + cassandra_session: CassandraSession, + embedding_d2: Embeddings, + page_contents: list[str], + ) -> None: + """from_documents, esp. the various handling of ID-in-doc vs external.""" + pc1, pc2 = page_contents + # no IDs. + v_store = Cassandra.from_documents( + [ + Document(page_content=pc1, metadata={"m": 1}), + Document(page_content=pc2, metadata={"m": 3}), + ], + table_name=cassandra_session.table_name, + session=cassandra_session.session, + keyspace=TEST_KEYSPACE, + embedding=embedding_d2, ) - for doc_i in range(N_DOCS) - ] - vstore_f1.add_documents(orig_documents) + hits = v_store.similarity_search(pc2, k=1) + assert len(hits) == 1 + assert hits[0].page_content == pc2 + assert hits[0].metadata == {"m": "3.0"} + v_store.clear() + + # IDs passed separately. + with pytest.warns(DeprecationWarning) as rec_warnings: + v_store_2 = Cassandra.from_documents( + [ + Document(page_content=pc1, metadata={"m": 1}), + Document(page_content=pc2, metadata={"m": 3}), + ], + ids=["idx1", "idx3"], + table_name=cassandra_session.table_name, + session=cassandra_session.session, + keyspace=TEST_KEYSPACE, + embedding=embedding_d2, + ) + f_rec_warnings = [ + wrn for wrn in rec_warnings if issubclass(wrn.category, DeprecationWarning) + ] + assert len(f_rec_warnings) == 1 + hits = v_store_2.similarity_search(pc2, k=1) + assert len(hits) == 1 + assert hits[0].page_content == pc2 + assert hits[0].metadata == {"m": "3.0"} + assert hits[0].id == "idx3" + v_store_2.clear() + + # IDs in documents. + v_store_3 = Cassandra.from_documents( + [ + Document(page_content=pc1, metadata={"m": 1}, id="idx1"), + Document(page_content=pc2, metadata={"m": 3}, id="idx3"), + ], + table_name=cassandra_session.table_name, + session=cassandra_session.session, + keyspace=TEST_KEYSPACE, + embedding=embedding_d2, + ) + hits = v_store_3.similarity_search(pc2, k=1) + assert len(hits) == 1 + assert hits[0].page_content == pc2 + assert hits[0].metadata == {"m": "3.0"} + assert hits[0].id == "idx3" + v_store_3.clear() + + # IDs both in documents and aside. 
+ with pytest.warns(DeprecationWarning) as rec_warnings: + v_store_4 = Cassandra.from_documents( + [ + Document(page_content=pc1, metadata={"m": 1}), + Document(page_content=pc2, metadata={"m": 3}, id="idy3"), + ], + ids=["idx1", "idx3"], + table_name=cassandra_session.table_name, + session=cassandra_session.session, + keyspace=TEST_KEYSPACE, + embedding=embedding_d2, + ) + f_rec_warnings = [ + wrn for wrn in rec_warnings if issubclass(wrn.category, DeprecationWarning) + ] + hits = v_store_4.similarity_search(pc2, k=1) + assert len(hits) == 1 + assert hits[0].page_content == pc2 + assert hits[0].metadata == {"m": "3.0"} + assert hits[0].id == "idx3" + v_store_4.clear() + + @pytest.mark.parametrize( + "page_contents", + [ + [ + "[1,2]", + "[3,4]", + "[5,6]", + "[7,8]", + "[9,10]", + "[11,12]", + ], + ], + ) + async def test_cassandra_vectorstore_from_texts_async( + self, + *, + cassandra_session: CassandraSession, + embedding_d2: Embeddings, + page_contents: list[str], + ) -> None: + """from_texts methods and the associated warnings, async version.""" + v_store = await Cassandra.afrom_texts( + texts=page_contents[0:2], + metadatas=[{"m": 1}, {"m": 3}], + ids=["ft1", "ft3"], + table_name=cassandra_session.table_name, + session=cassandra_session.session, + keyspace=TEST_KEYSPACE, + embedding=embedding_d2, + ) + search_results_triples_0 = await v_store.asimilarity_search_with_score_id( + page_contents[1], + k=1, + ) + assert len(search_results_triples_0) == 1 + res_doc_0, _, res_id_0 = search_results_triples_0[0] + assert res_doc_0.page_content == page_contents[1] + assert res_doc_0.metadata == {"m": "3.0"} + assert res_id_0 == "ft3" + + await Cassandra.afrom_texts( + texts=page_contents[2:4], + metadatas=[{"m": 5}, {"m": 7}], + ids=["ft5", "ft7"], + table_name=cassandra_session.table_name, + session=cassandra_session.session, + keyspace=TEST_KEYSPACE, + embedding=embedding_d2, + ) + search_results_triples_1 = await v_store.asimilarity_search_with_score_id( + page_contents[3], + k=1, + ) + assert len(search_results_triples_1) == 1 + res_doc_1, _, res_id_1 = search_results_triples_1[0] + assert res_doc_1.page_content == page_contents[3] + assert res_doc_1.metadata == {"m": "7.0"} + assert res_id_1 == "ft7" + + v_store_2 = await Cassandra.afrom_texts( + texts=page_contents[4:6], + metadatas=[{"m": 9}, {"m": 11}], + ids=["ft9", "ft11"], + table_name=cassandra_session.table_name, + session=cassandra_session.session, + keyspace=TEST_KEYSPACE, + embedding=embedding_d2, + ) + search_results_triples_2 = await v_store_2.asimilarity_search_with_score_id( + page_contents[5], + k=1, + ) + assert len(search_results_triples_2) == 1 + res_doc_2, _, res_id_2 = search_results_triples_2[0] + assert res_doc_2.page_content == page_contents[5] + assert res_doc_2.metadata == {"m": "11.0"} + assert res_id_2 == "ft11" + await v_store_2.aclear() + + @pytest.mark.parametrize( + "page_contents", + [ + ["[1,2]", "[3,4]"], + ], + ) + async def test_cassandra_vectorstore_from_documents_async( + self, + *, + cassandra_session: CassandraSession, + embedding_d2: Embeddings, + page_contents: list[str], + ) -> None: + """ + from_documents, esp. the various handling of ID-in-doc vs external. + Async version. + """ + pc1, pc2 = page_contents + + # no IDs. 
+ v_store = await Cassandra.afrom_documents( + [ + Document(page_content=pc1, metadata={"m": 1}), + Document(page_content=pc2, metadata={"m": 3}), + ], + table_name=cassandra_session.table_name, + session=cassandra_session.session, + keyspace=TEST_KEYSPACE, + embedding=embedding_d2, + ) + hits = await v_store.asimilarity_search(pc2, k=1) + assert len(hits) == 1 + assert hits[0].page_content == pc2 + assert hits[0].metadata == {"m": "3.0"} + await v_store.aclear() + + # IDs passed separately. + with pytest.warns(DeprecationWarning) as rec_warnings: + v_store_2 = await Cassandra.afrom_documents( + [ + Document(page_content=pc1, metadata={"m": 1}), + Document(page_content=pc2, metadata={"m": 3}), + ], + ids=["idx1", "idx3"], + table_name=cassandra_session.table_name, + session=cassandra_session.session, + keyspace=TEST_KEYSPACE, + embedding=embedding_d2, + ) + f_rec_warnings = [ + wrn for wrn in rec_warnings if issubclass(wrn.category, DeprecationWarning) + ] + assert len(f_rec_warnings) == 1 + hits = await v_store_2.asimilarity_search(pc2, k=1) + assert len(hits) == 1 + assert hits[0].page_content == pc2 + assert hits[0].metadata == {"m": "3.0"} + assert hits[0].id == "idx3" + await v_store_2.aclear() + + # IDs in documents. + + v_store_3 = await Cassandra.afrom_documents( + [ + Document(page_content=pc1, metadata={"m": 1}, id="idx1"), + Document(page_content=pc2, metadata={"m": 3}, id="idx3"), + ], + table_name=cassandra_session.table_name, + session=cassandra_session.session, + keyspace=TEST_KEYSPACE, + embedding=embedding_d2, + ) + hits = await v_store_3.asimilarity_search(pc2, k=1) + assert len(hits) == 1 + assert hits[0].page_content == pc2 + assert hits[0].metadata == {"m": "3.0"} + assert hits[0].id == "idx3" + await v_store_3.aclear() + + # IDs both in documents and aside. 
+ with pytest.warns(DeprecationWarning) as rec_warnings: + v_store_4 = await Cassandra.afrom_documents( + [ + Document(page_content=pc1, metadata={"m": 1}), + Document(page_content=pc2, metadata={"m": 3}, id="idy3"), + ], + ids=["idx1", "idx3"], + table_name=cassandra_session.table_name, + session=cassandra_session.session, + keyspace=TEST_KEYSPACE, + embedding=embedding_d2, + ) + f_rec_warnings = [ + wrn for wrn in rec_warnings if issubclass(wrn.category, DeprecationWarning) + ] + assert len(f_rec_warnings) == 1 + hits = await v_store_4.asimilarity_search(pc2, k=1) + assert len(hits) == 1 + assert hits[0].page_content == pc2 + assert hits[0].metadata == {"m": "3.0"} + assert hits[0].id == "idx3" + await v_store_4.aclear() + + def test_cassandra_vectorstore_crud_sync( + self, + vector_store_d2: Cassandra, + ) -> None: + """Add/delete/update behaviour.""" + vstore = vector_store_d2 + + res0 = vstore.similarity_search("[-1,-1]", k=2) + assert res0 == [] + # write and check again + added_ids = vstore.add_texts( + texts=["[1,2]", "[3,4]", "[5,6]"], + metadatas=[ + {"k": "a", "ord": 0}, + {"k": "b", "ord": 1}, + {"k": "c", "ord": 2}, + ], + ids=["a", "b", "c"], + ) + # not requiring ordered match (elsewhere it may be overwriting some) + assert set(added_ids) == {"a", "b", "c"} + res1 = vstore.similarity_search("[-1,-1]", k=5) + assert {doc.page_content for doc in res1} == {"[1,2]", "[3,4]", "[5,6]"} + res2 = vstore.similarity_search("[3,4]", k=1) + assert len(res2) == 1 + assert res2[0].page_content == "[3,4]" + assert res2[0].metadata == {"k": "b", "ord": "1.0"} + assert res2[0].id == "b" + # partial overwrite and count total entries + added_ids_1 = vstore.add_texts( + texts=["[5,6]", "[7,8]"], + metadatas=[ + {"k": "c_new", "ord": 102}, + {"k": "d_new", "ord": 103}, + ], + ids=["c", "d"], + ) + # not requiring ordered match (elsewhere it may be overwriting some) + assert set(added_ids_1) == {"c", "d"} + res2 = vstore.similarity_search("[-1,-1]", k=10) + assert len(res2) == 4 + # pick one that was just updated and check its metadata + res3 = vstore.similarity_search_with_score_id( + query="[5,6]", k=1, filter={"k": "c_new"} + ) + doc3, _, id3 = res3[0] + assert doc3.page_content == "[5,6]" + assert doc3.metadata == {"k": "c_new", "ord": "102.0"} + assert id3 == "c" + # delete and count again + del1_res = vstore.delete(["b"]) + assert del1_res is True + del2_res = vstore.delete(["a", "c", "Z!"]) + assert del2_res is True # a non-existing ID was supplied + assert len(vstore.similarity_search("[-1,-1]", k=10)) == 1 + # clear store + vstore.clear() + assert vstore.similarity_search("[-1,-1]", k=2) == [] + # add_documents with "ids" arg passthrough + vstore.add_documents( + [ + Document(page_content="[9,10]", metadata={"k": "v", "ord": 204}), + Document(page_content="[11,12]", metadata={"k": "w", "ord": 205}), + ], + ids=["v", "w"], + ) + assert len(vstore.similarity_search("[-1,-1]", k=10)) == 2 + res4 = vstore.similarity_search("[11,12]", k=1, filter={"k": "w"}) + assert res4[0].metadata["ord"] == "205.0" + assert res4[0].id == "w" + # add_texts with "ids" arg passthrough + vstore.add_texts( + texts=["[13,14]", "[15,16]"], + metadatas=[{"k": "r", "ord": 306}, {"k": "s", "ord": 307}], + ids=["r", "s"], + ) + assert len(vstore.similarity_search("[-1,-1]", k=10)) == 4 + res4 = vstore.similarity_search("[-1,-1]", k=1, filter={"k": "s"}) + assert res4[0].metadata["ord"] == "307.0" + assert res4[0].id == "s" + # delete_by_document_id + vstore.delete_by_document_id("s") + assert 
len(vstore.similarity_search("[-1,-1]", k=10)) == 3 + + async def test_cassandra_vectorstore_crud_async( + self, + vector_store_d2: Cassandra, + ) -> None: + """Add/delete/update behaviour, async version.""" + vstore = vector_store_d2 + + res0 = await vstore.asimilarity_search("[-1,-1]", k=2) + assert res0 == [] + # write and check again + added_ids = await vstore.aadd_texts( + texts=["[1,2]", "[3,4]", "[5,6]"], + metadatas=[ + {"k": "a", "ord": 0}, + {"k": "b", "ord": 1}, + {"k": "c", "ord": 2}, + ], + ids=["a", "b", "c"], + ) + # not requiring ordered match (elsewhere it may be overwriting some) + assert set(added_ids) == {"a", "b", "c"} + res1 = await vstore.asimilarity_search("[-1,-1]", k=5) + assert {doc.page_content for doc in res1} == {"[1,2]", "[3,4]", "[5,6]"} + res2 = await vstore.asimilarity_search("[3,4]", k=1) + assert len(res2) == 1 + assert res2[0].page_content == "[3,4]" + assert res2[0].metadata == {"k": "b", "ord": "1.0"} + assert res2[0].id == "b" + # partial overwrite and count total entries + added_ids_1 = await vstore.aadd_texts( + texts=["[5,6]", "[7,8]"], + metadatas=[ + {"k": "c_new", "ord": 102}, + {"k": "d_new", "ord": 103}, + ], + ids=["c", "d"], + ) + # not requiring ordered match (elsewhere it may be overwriting some) + assert set(added_ids_1) == {"c", "d"} + res2 = await vstore.asimilarity_search("[-1,-1]", k=10) + assert len(res2) == 4 + # pick one that was just updated and check its metadata + res3 = await vstore.asimilarity_search_with_score_id( + query="[5,6]", k=1, filter={"k": "c_new"} + ) + doc3, _, id3 = res3[0] + assert doc3.page_content == "[5,6]" + assert doc3.metadata == {"k": "c_new", "ord": "102.0"} + assert id3 == "c" + # delete and count again + del1_res = await vstore.adelete(["b"]) + assert del1_res is True + del2_res = await vstore.adelete(["a", "c", "Z!"]) + assert del2_res is True # a non-existing ID was supplied + assert len(await vstore.asimilarity_search("[-1,-1]", k=10)) == 1 + # clear store + await vstore.aclear() + assert await vstore.asimilarity_search("[-1,-1]", k=2) == [] + # add_documents with "ids" arg passthrough + await vstore.aadd_documents( + [ + Document(page_content="[9,10]", metadata={"k": "v", "ord": 204}), + Document(page_content="[11,12]", metadata={"k": "w", "ord": 205}), + ], + ids=["v", "w"], + ) + assert len(await vstore.asimilarity_search("[-1,-1]", k=10)) == 2 + res4 = await vstore.asimilarity_search("[11,12]", k=1, filter={"k": "w"}) + assert res4[0].metadata["ord"] == "205.0" + assert res4[0].id == "w" + # add_texts with "ids" arg passthrough + await vstore.aadd_texts( + texts=["[13,14]", "[15,16]"], + metadatas=[{"k": "r", "ord": 306}, {"k": "s", "ord": 307}], + ids=["r", "s"], + ) + assert len(await vstore.asimilarity_search("[-1,-1]", k=10)) == 4 + res4 = await vstore.asimilarity_search("[-1,-1]", k=1, filter={"k": "s"}) + assert res4[0].metadata["ord"] == "307.0" + assert res4[0].id == "s" + # delete_by_document_id + await vstore.adelete_by_document_id("s") + assert len(await vstore.asimilarity_search("[-1,-1]", k=10)) == 3 + + def test_cassandra_vectorstore_massive_insert_replace_sync( + self, + vector_store_d2: Cassandra, + ) -> None: + """Testing the insert-many-and-replace-some patterns thoroughly.""" + full_size = 300 + first_group_size = 150 + second_group_slicer = [30, 100, 2] + + all_ids = [f"doc_{idx}" for idx in range(full_size)] + all_texts = [f"[0,{idx + 1}]" for idx in range(full_size)] + + # massive insertion on empty + group0_ids = all_ids[0:first_group_size] + group0_texts = 
all_texts[0:first_group_size] + inserted_ids0 = vector_store_d2.add_texts( + texts=group0_texts, + ids=group0_ids, + ) + assert set(inserted_ids0) == set(group0_ids) + # massive insertion with many overwrites scattered through + # (we change the text to later check on DB for successful update) + _s, _e, _st = second_group_slicer + group1_ids = all_ids[_s:_e:_st] + all_ids[first_group_size:full_size] + group1_texts = [ + txt.upper() + for txt in (all_texts[_s:_e:_st] + all_texts[first_group_size:full_size]) + ] + inserted_ids1 = vector_store_d2.add_texts( + texts=group1_texts, + ids=group1_ids, + ) + assert set(inserted_ids1) == set(group1_ids) + # final read (we want the IDs to do a full check) + expected_text_by_id = { + **dict(zip(group0_ids, group0_texts)), + **dict(zip(group1_ids, group1_texts)), + } + full_results = vector_store_d2.similarity_search_with_score_id_by_vector( + embedding=[1.0, 1.0], + k=full_size, + ) + for doc, _, doc_id in full_results: + assert doc.page_content == expected_text_by_id[doc_id] + + async def test_cassandra_vectorstore_massive_insert_replace_async( + self, + vector_store_d2: Cassandra, + ) -> None: + """ + Testing the insert-many-and-replace-some patterns thoroughly. + Async version. + """ + full_size = 300 + first_group_size = 150 + second_group_slicer = [30, 100, 2] + + all_ids = [f"doc_{idx}" for idx in range(full_size)] + all_texts = [f"[0,{idx + 1}]" for idx in range(full_size)] + all_embeddings = [[0, idx + 1] for idx in range(full_size)] + + # massive insertion on empty + group0_ids = all_ids[0:first_group_size] + group0_texts = all_texts[0:first_group_size] + + inserted_ids0 = await vector_store_d2.aadd_texts( + texts=group0_texts, + ids=group0_ids, + ) + assert set(inserted_ids0) == set(group0_ids) + # massive insertion with many overwrites scattered through + # (we change the text to later check on DB for successful update) + _s, _e, _st = second_group_slicer + group1_ids = all_ids[_s:_e:_st] + all_ids[first_group_size:full_size] + group1_texts = [ + txt.upper() + for txt in (all_texts[_s:_e:_st] + all_texts[first_group_size:full_size]) + ] + inserted_ids1 = await vector_store_d2.aadd_texts( + texts=group1_texts, + ids=group1_ids, + ) + assert set(inserted_ids1) == set(group1_ids) + # final read (we want the IDs to do a full check) + expected_text_by_id = dict(zip(all_ids, all_texts)) + full_results = await vector_store_d2.asimilarity_search_with_score_id_by_vector( + embedding=[1.0, 1.0], + k=full_size, + ) + for doc, _, doc_id in full_results: + assert doc.page_content == expected_text_by_id[doc_id] + expected_embedding_by_id = dict(zip(all_ids, all_embeddings)) + full_results_with_embeddings = ( + await vector_store_d2.asimilarity_search_with_embedding_id_by_vector( + embedding=[1.0, 1.0], + k=full_size, + ) + ) + for doc, embedding, doc_id in full_results_with_embeddings: + assert doc.page_content == expected_text_by_id[doc_id] + assert embedding == expected_embedding_by_id[doc_id] + + def test_cassandra_vectorstore_delete_by_metadata_sync( + self, + vector_store_d2: Cassandra, + ) -> None: + """Testing delete_by_metadata_filter.""" + full_size = 400 + # one in ... 
will be deleted + deletee_ratio = 3 + + documents = [ + Document( + page_content="[1,1]", metadata={"deletee": doc_i % deletee_ratio == 0} + ) + for doc_i in range(full_size) + ] + num_deletees = len([doc for doc in documents if doc.metadata["deletee"]]) - ids_to_replace = [ - f"doc_id_{doc_i}" for doc_i in range(N_DOCS) if doc_i % REPLACE_RATIO == 0 - ] + inserted_ids0 = vector_store_d2.add_documents(documents) + assert len(inserted_ids0) == len(documents) - # various kinds of replacement at play here: - def _make_new_md(mode: int, doc_id: str) -> dict[str, str]: - if mode == 0: - return {} - elif mode == 1: - return {"field2": f"NEW_{doc_id}"} - elif mode == 2: - return {"field2": f"NEW_{doc_id}", "ofherf2": "post"} - else: - return {"ofherf2": "post"} - - ids_to_new_md = { - doc_id: _make_new_md(rep_i % 4, doc_id) - for rep_i, doc_id in enumerate(ids_to_replace) - } - - vstore_f1.replace_metadata(ids_to_new_md, batch_size=BATCH_SIZE) - # thorough check - expected_id_to_metadata: dict[str, dict] = { - **{(document.id or ""): document.metadata for document in orig_documents}, - **ids_to_new_md, - } - for hit in vstore_f1.similarity_search("doc", k=N_DOCS + 1): - assert hit.id is not None - assert hit.metadata == expected_id_to_metadata[hit.id] - - -async def test_cassandra_areplace_metadata() -> None: - """Test of replacing metadata.""" - N_DOCS = 100 - REPLACE_RATIO = 2 # one in ... will have replaced metadata - BATCH_SIZE = 3 - - vstore_f1 = _vectorstore_from_texts( - texts=[], - metadata_indexing=("allowlist", ["field1", "field2"]), - table_name="vector_test_table_indexing", - ) - orig_documents = [ - Document( - page_content=f"doc_{doc_i}", - id=f"doc_id_{doc_i}", - metadata={"field1": f"f1_{doc_i}", "otherf": "pre"}, + d_result0 = vector_store_d2.delete_by_metadata_filter({"deletee": True}) + assert d_result0 == num_deletees + count_on_store0 = len( + vector_store_d2.similarity_search("[1,1]", k=full_size + 1) ) - for doc_i in range(N_DOCS) - ] - await vstore_f1.aadd_documents(orig_documents) + assert count_on_store0 == full_size - num_deletees - ids_to_replace = [ - f"doc_id_{doc_i}" for doc_i in range(N_DOCS) if doc_i % REPLACE_RATIO == 0 - ] + with pytest.raises(ValueError, match="does not accept an empty"): + vector_store_d2.delete_by_metadata_filter({}) + count_on_store1 = len( + vector_store_d2.similarity_search("[1,1]", k=full_size + 1) + ) + assert count_on_store1 == full_size - num_deletees + + async def test_cassandra_vectorstore_delete_by_metadata_async( + self, + vector_store_d2: Cassandra, + ) -> None: + """Testing delete_by_metadata_filter, async version.""" + full_size = 400 + # one in ... 
will be deleted + deletee_ratio = 3 + + documents = [ + Document( + page_content="[1,1]", metadata={"deletee": doc_i % deletee_ratio == 0} + ) + for doc_i in range(full_size) + ] + num_deletees = len([doc for doc in documents if doc.metadata["deletee"]]) - # various kinds of replacement at play here: - def _make_new_md(mode: int, doc_id: str) -> dict[str, str]: - if mode == 0: - return {} - elif mode == 1: - return {"field2": f"NEW_{doc_id}"} - elif mode == 2: - return {"field2": f"NEW_{doc_id}", "ofherf2": "post"} - else: - return {"ofherf2": "post"} - - ids_to_new_md = { - doc_id: _make_new_md(rep_i % 4, doc_id) - for rep_i, doc_id in enumerate(ids_to_replace) - } - - await vstore_f1.areplace_metadata(ids_to_new_md, concurrency=BATCH_SIZE) - # thorough check - expected_id_to_metadata: dict[str, dict] = { - **{(document.id or ""): document.metadata for document in orig_documents}, - **ids_to_new_md, - } - for hit in await vstore_f1.asimilarity_search("doc", k=N_DOCS + 1): - assert hit.id is not None - assert hit.metadata == expected_id_to_metadata[hit.id] + inserted_ids0 = await vector_store_d2.aadd_documents(documents) + assert len(inserted_ids0) == len(documents) + + d_result0 = await vector_store_d2.adelete_by_metadata_filter({"deletee": True}) + assert d_result0 == num_deletees + count_on_store0 = len( + await vector_store_d2.asimilarity_search("[1,1]", k=full_size + 1) + ) + assert count_on_store0 == full_size - num_deletees + + with pytest.raises(ValueError, match="does not accept an empty"): + await vector_store_d2.adelete_by_metadata_filter({}) + count_on_store1 = len( + await vector_store_d2.asimilarity_search("[1,1]", k=full_size + 1) + ) + assert count_on_store1 == full_size - num_deletees + + def test_cassandra_replace_metadata(self) -> None: + """Test of replacing metadata.""" + N_DOCS = 100 + REPLACE_RATIO = 2 # one in ... will have replaced metadata + BATCH_SIZE = 3 + + with vector_store_from_texts( + texts=[], + metadata_indexing=("allowlist", ["field1", "field2"]), + table_name="vector_test_table_indexing", + ) as vstore_f1: + orig_documents = [ + Document( + page_content=f"doc_{doc_i}", + id=f"doc_id_{doc_i}", + metadata={"field1": f"f1_{doc_i}", "otherf": "pre"}, + ) + for doc_i in range(N_DOCS) + ] + vstore_f1.add_documents(orig_documents) + + ids_to_replace = [ + f"doc_id_{doc_i}" + for doc_i in range(N_DOCS) + if doc_i % REPLACE_RATIO == 0 + ] + + # various kinds of replacement at play here: + def _make_new_md(mode: int, doc_id: str) -> dict[str, str]: + if mode == 0: + return {} + elif mode == 1: + return {"field2": f"NEW_{doc_id}"} + elif mode == 2: + return {"field2": f"NEW_{doc_id}", "ofherf2": "post"} + else: + return {"ofherf2": "post"} + + ids_to_new_md = { + doc_id: _make_new_md(rep_i % 4, doc_id) + for rep_i, doc_id in enumerate(ids_to_replace) + } + + vstore_f1.replace_metadata(ids_to_new_md, batch_size=BATCH_SIZE) + # thorough check + expected_id_to_metadata: dict[str, dict] = { + **{ + (document.id or ""): document.metadata + for document in orig_documents + }, + **ids_to_new_md, + } + for hit in vstore_f1.similarity_search("doc", k=N_DOCS + 1): + assert hit.id is not None + assert hit.metadata == expected_id_to_metadata[hit.id] + + async def test_cassandra_replace_metadata_async(self) -> None: + """Test of replacing metadata.""" + N_DOCS = 100 + REPLACE_RATIO = 2 # one in ... 
will have replaced metadata + BATCH_SIZE = 3 + + async with vector_store_from_texts_async( + texts=[], + metadata_indexing=("allowlist", ["field1", "field2"]), + table_name="vector_test_table_indexing", + embedding=ConsistentFakeEmbeddings(), + ) as vstore_f1: + orig_documents = [ + Document( + page_content=f"doc_{doc_i}", + id=f"doc_id_{doc_i}", + metadata={"field1": f"f1_{doc_i}", "otherf": "pre"}, + ) + for doc_i in range(N_DOCS) + ] + await vstore_f1.aadd_documents(orig_documents) + + ids_to_replace = [ + f"doc_id_{doc_i}" + for doc_i in range(N_DOCS) + if doc_i % REPLACE_RATIO == 0 + ] + + # various kinds of replacement at play here: + def _make_new_md(mode: int, doc_id: str) -> dict[str, str]: + if mode == 0: + return {} + elif mode == 1: + return {"field2": f"NEW_{doc_id}"} + elif mode == 2: + return {"field2": f"NEW_{doc_id}", "ofherf2": "post"} + else: + return {"ofherf2": "post"} + + ids_to_new_md = { + doc_id: _make_new_md(rep_i % 4, doc_id) + for rep_i, doc_id in enumerate(ids_to_replace) + } + + await vstore_f1.areplace_metadata(ids_to_new_md, concurrency=BATCH_SIZE) + # thorough check + expected_id_to_metadata: dict[str, dict] = { + **{ + (document.id or ""): document.metadata + for document in orig_documents + }, + **ids_to_new_md, + } + for hit in await vstore_f1.asimilarity_search("doc", k=N_DOCS + 1): + assert hit.id is not None + assert hit.metadata == expected_id_to_metadata[hit.id] + + def test_cassandra_vectorstore_mmr_sync( + self, + vector_store_d2: Cassandra, + ) -> None: + """MMR testing. We work on the unit circle with angle multiples + of 2*pi/20 and prepare a store with known vectors for a controlled + MMR outcome. + """ + + def _v_from_i(i: int, n: int) -> str: + angle = 2 * math.pi * i / n + vector = [math.cos(angle), math.sin(angle)] + return json.dumps(vector) + + i_vals = [0, 4, 5, 13] + n_val = 20 + vector_store_d2.add_texts( + [_v_from_i(i, n_val) for i in i_vals], metadatas=[{"i": i} for i in i_vals] + ) + res1 = vector_store_d2.max_marginal_relevance_search( + _v_from_i(3, n_val), + k=2, + fetch_k=3, + ) + res_i_vals = {doc.metadata["i"] for doc in res1} + assert res_i_vals == {"0.0", "4.0"} + + async def test_cassandra_vectorstore_mmr_async( + self, + vector_store_d2: Cassandra, + ) -> None: + """MMR testing. We work on the unit circle with angle multiples + of 2*pi/20 and prepare a store with known vectors for a controlled + MMR outcome. + Async version. 
+        """
+
+        def _v_from_i(i: int, n: int) -> str:
+            angle = 2 * math.pi * i / n
+            vector = [math.cos(angle), math.sin(angle)]
+            return json.dumps(vector)
+
+        i_vals = [0, 4, 5, 13]
+        n_val = 20
+        await vector_store_d2.aadd_texts(
+            [_v_from_i(i, n_val) for i in i_vals],
+            metadatas=[{"i": i} for i in i_vals],
+        )
+        res1 = await vector_store_d2.amax_marginal_relevance_search(
+            _v_from_i(3, n_val),
+            k=2,
+            fetch_k=3,
+        )
+        res_i_vals = {doc.metadata["i"] for doc in res1}
+        assert res_i_vals == {"0.0", "4.0"}
+
+    def test_cassandra_vectorstore_metadata_filter(
+        self,
+        vector_store_d2: Cassandra,
+        metadata_documents: list[Document],
+    ) -> None:
+        """Metadata filtering."""
+        vstore = vector_store_d2
+        vstore.add_documents(metadata_documents)
+        # no filters
+        res0 = vstore.similarity_search("[-1,-1]", k=10)
+        assert {doc.metadata["letter"] for doc in res0} == set("qwreio")
+        # single filter
+        res1 = vstore.similarity_search(
+            "[-1,-1]",
+            k=10,
+            filter={"group": "vowel"},
+        )
+        assert {doc.metadata["letter"] for doc in res1} == set("eio")
+        # multiple filters
+        res2 = vstore.similarity_search(
+            "[-1,-1]",
+            k=10,
+            filter={"group": "consonant", "ord": str(ord("q"))},
+        )
+        assert {doc.metadata["letter"] for doc in res2} == set("q")
+        # excessive filters
+        res3 = vstore.similarity_search(
+            "[-1,-1]",
+            k=10,
+            filter={"group": "consonant", "ord": str(ord("q")), "case": "upper"},
+        )
+        assert res3 == []
+
+    def test_cassandra_vectorstore_metadata_search_sync(
+        self,
+        vector_store_d2: Cassandra,
+        metadata_documents: list[Document],
+    ) -> None:
+        """Metadata Search"""
+        vstore = vector_store_d2
+        vstore.add_documents(metadata_documents)
+        # no filters
+        res0 = vstore.metadata_search(filter={}, n=10)
+        assert {doc.metadata["letter"] for doc in res0} == set("qwreio")
+        # single filter
+        res1 = vstore.metadata_search(
+            n=10,
+            filter={"group": "vowel"},
+        )
+        assert {doc.metadata["letter"] for doc in res1} == set("eio")
+        # multiple filters
+        res2 = vstore.metadata_search(
+            n=10,
+            filter={"group": "consonant", "ord": str(ord("q"))},
+        )
+        assert {doc.metadata["letter"] for doc in res2} == set("q")
+        # excessive filters
+        res3 = vstore.metadata_search(
+            n=10,
+            filter={"group": "consonant", "ord": str(ord("q")), "case": "upper"},
+        )
+        assert res3 == []
+
+    async def test_cassandra_vectorstore_metadata_search_async(
+        self,
+        vector_store_d2: Cassandra,
+        metadata_documents: list[Document],
+    ) -> None:
+        """Metadata Search, async version."""
+        vstore = vector_store_d2
+        await vstore.aadd_documents(metadata_documents)
+        # no filters
+        res0 = await vstore.ametadata_search(filter={}, n=10)
+        assert {doc.metadata["letter"] for doc in res0} == set("qwreio")
+        # single filter
+        res1 = await vstore.ametadata_search(
+            n=10,
+            filter={"group": "vowel"},
+        )
+        assert {doc.metadata["letter"] for doc in res1} == set("eio")
+        # multiple filters
+        res2 = await vstore.ametadata_search(
+            n=10,
+            filter={"group": "consonant", "ord": str(ord("q"))},
+        )
+        assert {doc.metadata["letter"] for doc in res2} == set("q")
+        # excessive filters
+        res3 = await vstore.ametadata_search(
+            n=10,
+            filter={"group": "consonant", "ord": str(ord("q")), "case": "upper"},
+        )
+        assert res3 == []
+
+    def test_cassandra_vectorstore_get_by_document_id_sync(
+        self,
+        vector_store_d2: Cassandra,
+        metadata_documents: list[Document],
+    ) -> None:
+        """Get by document_id"""
+        vstore = vector_store_d2
+        vstore.add_documents(metadata_documents)
+        # invalid id
+        invalid = vstore.get_by_document_id(document_id="z")
+        assert invalid is None
+        # valid id
+ valid = vstore.get_by_document_id(document_id="q") + assert isinstance(valid, Document) + assert valid.id == "q" + assert valid.page_content == "[1,2]" + assert valid.metadata["group"] == "consonant" + assert valid.metadata["letter"] == "q" + + async def test_cassandra_vectorstore_get_by_document_id_async( + self, + vector_store_d2: Cassandra, + metadata_documents: list[Document], + ) -> None: + """Get by document_id""" + vstore = vector_store_d2 + await vstore.aadd_documents(metadata_documents) + # invalid id + invalid = await vstore.aget_by_document_id(document_id="z") + assert invalid is None + # valid id + valid = await vstore.aget_by_document_id(document_id="q") + assert isinstance(valid, Document) + assert valid.id == "q" + assert valid.page_content == "[1,2]" + assert valid.metadata["group"] == "consonant" + assert valid.metadata["letter"] == "q" + + @pytest.mark.parametrize( + ("texts", "query"), + [ + ( + ["[1,1]", "[-1,-1]"], + "[0.99999,1.00001]", + ), + ], + ) + def test_cassandra_vectorstore_similarity_scale_sync( + self, + *, + vector_store_d2: Cassandra, + texts: list[str], + query: str, + ) -> None: + """Scale of the similarity scores.""" + vstore = vector_store_d2 + vstore.add_texts( + texts=texts, + ids=["near", "far"], + ) + res1 = vstore.similarity_search_with_score( + query, + k=2, + ) + scores = [sco for _, sco in res1] + sco_near, sco_far = scores + assert sco_far >= 0 + assert abs(1 - sco_near) < MATCH_EPSILON + assert sco_far < EUCLIDEAN_MIN_SIM_UNIT_VECTORS + MATCH_EPSILON + + @pytest.mark.parametrize( + ("texts", "query"), + [ + ( + ["[1,1]", "[-1,-1]"], + "[0.99999,1.00001]", + ), + ], + ) + async def test_cassandra_vectorstore_similarity_scale_async( + self, + *, + vector_store_d2: Cassandra, + texts: list[str], + query: str, + ) -> None: + """Scale of the similarity scores, async version.""" + vstore = vector_store_d2 + await vstore.aadd_texts( + texts=texts, + ids=["near", "far"], + ) + res1 = await vstore.asimilarity_search_with_score( + query, + k=2, + ) + scores = [sco for _, sco in res1] + sco_near, sco_far = scores + assert sco_far >= 0 + assert abs(1 - sco_near) < MATCH_EPSILON + assert sco_far < EUCLIDEAN_MIN_SIM_UNIT_VECTORS + MATCH_EPSILON + + def test_cassandra_vectorstore_massive_delete( + self, + vector_store_d2: Cassandra, + ) -> None: + """Larger-scale bulk deletes.""" + vstore = vector_store_d2 + m = 150 + texts = [f"[0,{i + 1 / 7.0}]" for i in range(2 * m)] + ids0 = [f"doc_{i}" for i in range(m)] + ids1 = [f"doc_{i + m}" for i in range(m)] + ids = ids0 + ids1 + vstore.add_texts(texts=texts, ids=ids) + # deleting a bunch of these + del_res0 = vstore.delete(ids0) + assert del_res0 is True + # deleting the rest plus a fake one + del_res1 = vstore.delete([*ids1, "ghost!"]) + assert del_res1 is True # ensure no error + # nothing left + assert vstore.similarity_search("[-1,-1]", k=2 * m) == [] diff --git a/libs/community/tests/unit_tests/graph_vectorstores/test_mmr_helper.py b/libs/community/tests/unit_tests/graph_vectorstores/test_mmr_helper.py new file mode 100644 index 00000000000000..0f5fef97ce95be --- /dev/null +++ b/libs/community/tests/unit_tests/graph_vectorstores/test_mmr_helper.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +import math + +from langchain_community.graph_vectorstores.mmr_helper import MmrHelper + +IDS = { + "-1", + "-2", + "-3", + "-4", + "-5", + "+1", + "+2", + "+3", + "+4", + "+5", +} + + +class TestMmrHelper: + def test_mmr_helper_functional(self) -> None: + helper = MmrHelper(k=3, query_embedding=[6, 5], 
lambda_mult=0.5)
+
+        assert len(list(helper.candidate_ids())) == 0
+
+        helper.add_candidates({"-1": [3, 5]})
+        helper.add_candidates({"-2": [3, 5]})
+        helper.add_candidates({"-3": [2, 6]})
+        helper.add_candidates({"-4": [1, 6]})
+        helper.add_candidates({"-5": [0, 6]})
+
+        assert len(list(helper.candidate_ids())) == 5
+
+        helper.add_candidates({"+1": [5, 3]})
+        helper.add_candidates({"+2": [5, 3]})
+        helper.add_candidates({"+3": [6, 2]})
+        helper.add_candidates({"+4": [6, 1]})
+        helper.add_candidates({"+5": [6, 0]})
+
+        assert len(list(helper.candidate_ids())) == 10
+
+        for idx in range(3):
+            best_id = helper.pop_best()
+            assert best_id in IDS
+            assert len(list(helper.candidate_ids())) == 9 - idx
+            assert best_id not in helper.candidate_ids()
+
+    def test_mmr_helper_max_diversity(self) -> None:
+        helper = MmrHelper(k=2, query_embedding=[6, 5], lambda_mult=0)
+        helper.add_candidates({"-1": [3, 5]})
+        helper.add_candidates({"-2": [3, 5]})
+        helper.add_candidates({"-3": [2, 6]})
+        helper.add_candidates({"-4": [1, 6]})
+        helper.add_candidates({"-5": [0, 6]})
+
+        best = {helper.pop_best(), helper.pop_best()}
+        assert best == {"-1", "-5"}
+
+    def test_mmr_helper_max_similarity(self) -> None:
+        helper = MmrHelper(k=2, query_embedding=[6, 5], lambda_mult=1)
+        helper.add_candidates({"-1": [3, 5]})
+        helper.add_candidates({"-2": [3, 5]})
+        helper.add_candidates({"-3": [2, 6]})
+        helper.add_candidates({"-4": [1, 6]})
+        helper.add_candidates({"-5": [0, 6]})
+
+        best = {helper.pop_best(), helper.pop_best()}
+        assert best == {"-1", "-2"}
+
+    def test_mmr_helper_add_candidate(self) -> None:
+        helper = MmrHelper(5, [0.0, 1.0])
+        helper.add_candidates(
+            {
+                "a": [0.0, 1.0],
+                "b": [1.0, 0.0],
+            }
+        )
+        assert helper.best_id == "a"
+
+    def test_mmr_helper_pop_best(self) -> None:
+        helper = MmrHelper(5, [0.0, 1.0])
+        helper.add_candidates(
+            {
+                "a": [0.0, 1.0],
+                "b": [1.0, 0.0],
+            }
+        )
+        assert helper.pop_best() == "a"
+        assert helper.pop_best() == "b"
+        assert helper.pop_best() is None
+
+    def angular_embedding(self, angle: float) -> list[float]:
+        return [math.cos(angle * math.pi), math.sin(angle * math.pi)]
+
+    def test_mmr_helper_added_documents(self) -> None:
+        """Test end to end construction and MMR search.
+        The embedding function used here ensures `texts` become
+        the following vectors on a circle (numbered v0 through v3):
+
+               ______ v2
+              /      \
+             /        |  v1
+            v3 |     .     | query
+              |   /        v0
+              |______/                 (N.B. very crude drawing)
+
+
+        With fetch_k==2 and k==2, when the query is at angle 0.0 (i.e. (1, 0)),
+        one expects that v2 and v0 are returned (in some order)
+        because v1 is "too close" to v0 (and v0 is closer than v1).
+
+        Both v2 and v3 are discovered after v0.
+        """
+        helper = MmrHelper(5, self.angular_embedding(0.0))
+
+        # Fetching the 2 nearest neighbors to 0.0
+        helper.add_candidates(
+            {
+                "v0": self.angular_embedding(-0.124),
+                "v1": self.angular_embedding(+0.127),
+            }
+        )
+        assert helper.pop_best() == "v0"
+
+        # After v0 is selected, new nodes are discovered.
+        # v2 is closer than v3. v1 is "too similar" to "v0" so it's not included.
+        helper.add_candidates(
+            {
+                "v2": self.angular_embedding(+0.25),
+                "v3": self.angular_embedding(+1.0),
+            }
+        )
+        assert helper.pop_best() == "v2"
+
+        assert math.isclose(
+            helper.selected_similarity_scores[0], 0.9251, abs_tol=0.0001
+        )
+        assert math.isclose(
+            helper.selected_similarity_scores[1], 0.7071, abs_tol=0.0001
+        )
+        assert math.isclose(helper.selected_mmr_scores[0], 0.4625, abs_tol=0.0001)
+        assert math.isclose(helper.selected_mmr_scores[1], 0.1608, abs_tol=0.0001)
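
Note: the unit tests above drive MmrHelper only through its visible surface: the constructor (k, query_embedding, lambda_mult), add_candidates, pop_best, candidate_ids, best_id, and the selected_similarity_scores / selected_mmr_scores lists. For orientation, the snippet below is a minimal, hypothetical sketch of how those same calls could be combined into an incremental MMR selection loop; the fetch_neighbors callable and its return shape are assumptions invented for the example, not part of the library.

from __future__ import annotations

from typing import Callable

from langchain_community.graph_vectorstores.mmr_helper import MmrHelper


def select_with_mmr(
    query_embedding: list[float],
    initial_candidates: dict[str, list[float]],
    fetch_neighbors: Callable[[str], dict[str, list[float]]],
    k: int = 4,
    lambda_mult: float = 0.5,
) -> list[str]:
    """Illustrative MMR loop mirroring the calls exercised in the tests above.

    `fetch_neighbors` is a made-up stand-in for whatever lookup returns the
    embeddings of nodes adjacent to a just-selected document.
    """
    helper = MmrHelper(k=k, query_embedding=query_embedding, lambda_mult=lambda_mult)
    helper.add_candidates(initial_candidates)  # seed with the first batch

    selected: list[str] = []
    while len(selected) < k:
        best_id = helper.pop_best()  # highest-MMR candidate, or None when exhausted
        if best_id is None:
            break
        selected.append(best_id)
        # Feed newly discovered neighbors back in, as the added-documents test does.
        helper.add_candidates(fetch_neighbors(best_id))
    return selected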