Skip to content

Commit

Permalink
updated tests
Browse files Browse the repository at this point in the history
  • Loading branch information
epinzur committed Oct 17, 2024
1 parent a2edc8c commit 2184e0c
Show file tree
Hide file tree
Showing 5 changed files with 1,461 additions and 1,467 deletions.
22 changes: 12 additions & 10 deletions libs/community/langchain_community/graph_vectorstores/cassandra.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def __init__(
*,
body_index_options: list[tuple[str, Any]] | None = None,
setup_mode: SetupMode = SetupMode.SYNC,
metadata_deny_list: Iterable[str] = [],
metadata_deny_list: list[str] | None = None,
) -> None:
"""Apache Cassandra(R) for graph-vector-store workloads.
Expand Down Expand Up @@ -164,9 +164,9 @@ def __init__(
"""
self.embedding = embedding

deny_list = set(metadata_deny_list)
deny_list.add(METADATA_LINKS_KEY)
self._metadata_deny_list = deny_list
if metadata_deny_list is None:
metadata_deny_list = []
metadata_deny_list.append(METADATA_LINKS_KEY)

self.vector_store = CassandraVectorStore(
embedding=embedding,
Expand All @@ -176,7 +176,7 @@ def __init__(
ttl_seconds=ttl_seconds,
body_index_options=body_index_options,
setup_mode=setup_mode,
metadata_indexing=("deny_list", deny_list),
metadata_indexing=("deny_list", metadata_deny_list),
)

store_session: Session = self.vector_store.session
Expand Down Expand Up @@ -1032,7 +1032,7 @@ def from_texts(
ids: Optional[List[str]] = None,
ttl_seconds: Optional[int] = None,
body_index_options: Optional[List[Tuple[str, Any]]] = None,
metadata_deny_list: Iterable[str] = [],
metadata_deny_list: list[str] | None = None,
**kwargs: Any,
) -> CGVST:
"""Create a CassandraGraphVectorStore from raw texts.
Expand Down Expand Up @@ -1094,7 +1094,7 @@ async def afrom_texts(
ids: Optional[List[str]] = None,
ttl_seconds: Optional[int] = None,
body_index_options: Optional[List[Tuple[str, Any]]] = None,
metadata_deny_list: Iterable[str] = [],
metadata_deny_list: list[str] | None = None,
**kwargs: Any,
) -> CGVST:
"""Create a CassandraGraphVectorStore from raw texts.
Expand Down Expand Up @@ -1165,7 +1165,7 @@ def from_documents(
ids: Optional[List[str]] = None,
ttl_seconds: Optional[int] = None,
body_index_options: Optional[List[Tuple[str, Any]]] = None,
metadata_deny_list: Iterable[str] = [],
metadata_deny_list: list[str] | None = None,
**kwargs: Any,
) -> CGVST:
"""Create a CassandraGraphVectorStore from a document list.
Expand Down Expand Up @@ -1220,7 +1220,7 @@ async def afrom_documents(
ids: Optional[List[str]] = None,
ttl_seconds: Optional[int] = None,
body_index_options: Optional[List[Tuple[str, Any]]] = None,
metadata_deny_list: Iterable[str] = [],
metadata_deny_list: list[str] | None = None,
**kwargs: Any,
) -> CGVST:
"""Create a CassandraGraphVectorStore from a document list.
Expand Down Expand Up @@ -1262,5 +1262,7 @@ async def afrom_documents(
metadata_deny_list=metadata_deny_list,
**kwargs,
)
await store.aadd_documents(documents=cls._add_ids_to_docs(docs=documents, ids=ids))
await store.aadd_documents(
documents=cls._add_ids_to_docs(docs=documents, ids=ids)
)
return store
37 changes: 32 additions & 5 deletions libs/community/langchain_community/vectorstores/cassandra.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import importlib.metadata
import typing
import uuid
import warnings
from typing import (
Any,
Awaitable,
Expand Down Expand Up @@ -1165,7 +1166,7 @@ def from_texts(
ids: Optional[List[str]] = None,
ttl_seconds: Optional[int] = None,
body_index_options: Optional[List[Tuple[str, Any]]] = None,
metadata_indexing: Iterable[str] = [],
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
**kwargs: Any,
) -> CVST:
"""Create a Cassandra vector store from raw texts.
Expand Down Expand Up @@ -1229,7 +1230,7 @@ async def afrom_texts(
ids: Optional[List[str]] = None,
ttl_seconds: Optional[int] = None,
body_index_options: Optional[List[Tuple[str, Any]]] = None,
metadata_indexing: Iterable[str] = [],
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
**kwargs: Any,
) -> CVST:
"""Create a Cassandra vector store from raw texts.
Expand Down Expand Up @@ -1302,7 +1303,7 @@ def from_documents(
ids: Optional[List[str]] = None,
ttl_seconds: Optional[int] = None,
body_index_options: Optional[List[Tuple[str, Any]]] = None,
metadata_indexing: Iterable[str] = [],
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
**kwargs: Any,
) -> CVST:
"""Create a Cassandra vector store from a document list.
Expand Down Expand Up @@ -1334,6 +1335,18 @@ def from_documents(
Returns:
a Cassandra vector store.
"""
if ids is not None:
warnings.warn(
(
"Parameter `ids` to Cassandra's `from_documents` "
"method is deprecated. Please set the supplied documents' "
"`.id` attribute instead. The id attribute of Document "
"is ignored as long as the `ids` parameter is passed."
),
DeprecationWarning,
stacklevel=2,
)

store = cls(
embedding=embedding,
session=session,
Expand All @@ -1359,7 +1372,7 @@ async def afrom_documents(
ids: Optional[List[str]] = None,
ttl_seconds: Optional[int] = None,
body_index_options: Optional[List[Tuple[str, Any]]] = None,
metadata_indexing: Iterable[str] = [],
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
**kwargs: Any,
) -> CVST:
"""Create a Cassandra vector store from a document list.
Expand Down Expand Up @@ -1391,6 +1404,18 @@ async def afrom_documents(
Returns:
a Cassandra vector store.
"""
if ids is not None:
warnings.warn(
(
"Parameter `ids` to Cassandra's `afrom_documents` "
"method is deprecated. Please set the supplied documents' "
"`.id` attribute instead. The id attribute of Document "
"is ignored as long as the `ids` parameter is passed."
),
DeprecationWarning,
stacklevel=2,
)

store = cls(
embedding=embedding,
session=session,
Expand All @@ -1402,7 +1427,9 @@ async def afrom_documents(
metadata_indexing=metadata_indexing,
**kwargs,
)
await store.aadd_documents(documents=cls._add_ids_to_docs(docs=documents, ids=ids))
await store.aadd_documents(
documents=cls._add_ids_to_docs(docs=documents, ids=ids)
)
return store

def as_retriever(
Expand Down
Loading

0 comments on commit 2184e0c

Please sign in to comment.