Skip to content

Commit

Permalink
progress on tests
Browse files Browse the repository at this point in the history
  • Loading branch information
epinzur committed Oct 17, 2024
1 parent 86fdb71 commit a2edc8c
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 167 deletions.
18 changes: 16 additions & 2 deletions libs/community/langchain_community/graph_vectorstores/cassandra.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,6 +1143,16 @@ async def afrom_texts(
**kwargs,
)

@staticmethod
def _add_ids_to_docs(
docs: List[Document],
ids: Optional[List[str]] = None,
) -> List[Document]:
if ids is not None:
for doc, doc_id in zip(docs, ids):
doc.id = doc_id
return docs

@classmethod
def from_documents(
cls: Type[CGVST],
Expand All @@ -1152,6 +1162,7 @@ def from_documents(
session: Optional[Session] = None,
keyspace: Optional[str] = None,
table_name: str = "",
ids: Optional[List[str]] = None,
ttl_seconds: Optional[int] = None,
body_index_options: Optional[List[Tuple[str, Any]]] = None,
metadata_deny_list: Iterable[str] = [],
Expand All @@ -1167,6 +1178,7 @@ def from_documents(
keyspace: Cassandra key space.
If not provided, it is resolved from cassio.
table_name: Cassandra table (required).
ids: Optional list of IDs associated with the documents.
ttl_seconds: Optional time-to-live for the added documents.
body_index_options: Optional options used to create the body index.
Eg. body_index_options = [cassio.table.cql.STANDARD_ANALYZER]
Expand All @@ -1193,7 +1205,7 @@ def from_documents(
metadata_deny_list=metadata_deny_list,
**kwargs,
)
store.add_documents(documents=documents)
store.add_documents(documents=cls._add_ids_to_docs(docs=documents, ids=ids))
return store

@classmethod
Expand All @@ -1205,6 +1217,7 @@ async def afrom_documents(
session: Optional[Session] = None,
keyspace: Optional[str] = None,
table_name: str = "",
ids: Optional[List[str]] = None,
ttl_seconds: Optional[int] = None,
body_index_options: Optional[List[Tuple[str, Any]]] = None,
metadata_deny_list: Iterable[str] = [],
Expand All @@ -1220,6 +1233,7 @@ async def afrom_documents(
keyspace: Cassandra key space.
If not provided, it is resolved from cassio.
table_name: Cassandra table (required).
ids: Optional list of IDs associated with the documents.
ttl_seconds: Optional time-to-live for the added documents.
body_index_options: Optional options used to create the body index.
Eg. body_index_options = [cassio.table.cql.STANDARD_ANALYZER]
Expand Down Expand Up @@ -1248,5 +1262,5 @@ async def afrom_documents(
metadata_deny_list=metadata_deny_list,
**kwargs,
)
await store.aadd_documents(documents=documents)
await store.aadd_documents(documents=cls._add_ids_to_docs(docs=documents, ids=ids))
return store
112 changes: 63 additions & 49 deletions libs/community/langchain_community/vectorstores/cassandra.py
Original file line number Diff line number Diff line change
Expand Up @@ -1134,6 +1134,24 @@ async def amax_marginal_relevance_search(
body_search=body_search,
)

@staticmethod
def _build_docs_from_texts(
texts: List[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
) -> List[Document]:
docs: List[Document] = []
for i, text in enumerate(texts):
doc = Document(
page_content=text,
)
if metadatas is not None:
doc.metadata = metadatas[i]
if ids is not None:
doc.id = ids[i]
docs.append(doc)
return docs

@classmethod
def from_texts(
cls: Type[CVST],
Expand All @@ -1145,13 +1163,12 @@ def from_texts(
keyspace: Optional[str] = None,
table_name: str = "",
ids: Optional[List[str]] = None,
batch_size: int = 16,
ttl_seconds: Optional[int] = None,
body_index_options: Optional[List[Tuple[str, Any]]] = None,
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
metadata_indexing: Iterable[str] = [],
**kwargs: Any,
) -> CVST:
"""Create a Cassandra vectorstore from raw texts.
"""Create a Cassandra vector store from raw texts.
Args:
texts: Texts to add to the vectorstore.
Expand All @@ -1163,8 +1180,6 @@ def from_texts(
If not provided, it is resolved from cassio.
table_name: Cassandra table (required).
ids: Optional list of IDs associated with the texts.
batch_size: Number of concurrent requests to send to the server.
Defaults to 16.
ttl_seconds: Optional time-to-live for the added texts.
body_index_options: Optional options used to create the body index.
Eg. body_index_options = [cassio.table.cql.STANDARD_ANALYZER]
Expand All @@ -1181,9 +1196,16 @@ def from_texts(
(and to overcome max-length limitations).
Returns:
a Cassandra vectorstore.
a Cassandra vector store.
"""
store = cls(
docs = cls._build_docs_from_texts(
texts=texts,
metadatas=metadatas,
ids=ids,
)

return cls.from_documents(
documents=docs,
embedding=embedding,
session=session,
keyspace=keyspace,
Expand All @@ -1193,10 +1215,6 @@ def from_texts(
metadata_indexing=metadata_indexing,
**kwargs,
)
store.add_texts(
texts=texts, metadatas=metadatas, ids=ids, batch_size=batch_size
)
return store

@classmethod
async def afrom_texts(
Expand All @@ -1209,13 +1227,12 @@ async def afrom_texts(
keyspace: Optional[str] = None,
table_name: str = "",
ids: Optional[List[str]] = None,
concurrency: int = 16,
ttl_seconds: Optional[int] = None,
body_index_options: Optional[List[Tuple[str, Any]]] = None,
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
metadata_indexing: Iterable[str] = [],
**kwargs: Any,
) -> CVST:
"""Create a Cassandra vectorstore from raw texts.
"""Create a Cassandra vector store from raw texts.
Args:
texts: Texts to add to the vectorstore.
Expand All @@ -1227,8 +1244,6 @@ async def afrom_texts(
If not provided, it is resolved from cassio.
table_name: Cassandra table (required).
ids: Optional list of IDs associated with the texts.
concurrency: Number of concurrent queries to send to the database.
Defaults to 16.
ttl_seconds: Optional time-to-live for the added texts.
body_index_options: Optional options used to create the body index.
Eg. body_index_options = [cassio.table.cql.STANDARD_ANALYZER]
Expand All @@ -1245,23 +1260,35 @@ async def afrom_texts(
(and to overcome max-length limitations).
Returns:
a Cassandra vectorstore.
a Cassandra vector store.
"""
store = cls(
docs = cls._build_docs_from_texts(
texts=texts,
metadatas=metadatas,
ids=ids,
)

return await cls.afrom_documents(
documents=docs,
embedding=embedding,
session=session,
keyspace=keyspace,
table_name=table_name,
ttl_seconds=ttl_seconds,
setup_mode=SetupMode.ASYNC,
body_index_options=body_index_options,
metadata_indexing=metadata_indexing,
**kwargs,
)
await store.aadd_texts(
texts=texts, metadatas=metadatas, ids=ids, concurrency=concurrency
)
return store

@staticmethod
def _add_ids_to_docs(
docs: List[Document],
ids: Optional[List[str]] = None,
) -> List[Document]:
if ids is not None:
for doc, doc_id in zip(docs, ids):
doc.id = doc_id
return docs

@classmethod
def from_documents(
Expand All @@ -1273,13 +1300,12 @@ def from_documents(
keyspace: Optional[str] = None,
table_name: str = "",
ids: Optional[List[str]] = None,
batch_size: int = 16,
ttl_seconds: Optional[int] = None,
body_index_options: Optional[List[Tuple[str, Any]]] = None,
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
metadata_indexing: Iterable[str] = [],
**kwargs: Any,
) -> CVST:
"""Create a Cassandra vectorstore from a document list.
"""Create a Cassandra vector store from a document list.
Args:
documents: Documents to add to the vectorstore.
Expand All @@ -1290,8 +1316,6 @@ def from_documents(
If not provided, it is resolved from cassio.
table_name: Cassandra table (required).
ids: Optional list of IDs associated with the documents.
batch_size: Number of concurrent requests to send to the server.
Defaults to 16.
ttl_seconds: Optional time-to-live for the added documents.
body_index_options: Optional options used to create the body index.
Eg. body_index_options = [cassio.table.cql.STANDARD_ANALYZER]
Expand All @@ -1308,24 +1332,20 @@ def from_documents(
(and to overcome max-length limitations).
Returns:
a Cassandra vectorstore.
a Cassandra vector store.
"""
texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
return cls.from_texts(
texts=texts,
store = cls(
embedding=embedding,
metadatas=metadatas,
session=session,
keyspace=keyspace,
table_name=table_name,
ids=ids,
batch_size=batch_size,
ttl_seconds=ttl_seconds,
body_index_options=body_index_options,
metadata_indexing=metadata_indexing,
**kwargs,
)
store.add_documents(documents=cls._add_ids_to_docs(docs=documents, ids=ids))
return store

@classmethod
async def afrom_documents(
Expand All @@ -1337,13 +1357,12 @@ async def afrom_documents(
keyspace: Optional[str] = None,
table_name: str = "",
ids: Optional[List[str]] = None,
concurrency: int = 16,
ttl_seconds: Optional[int] = None,
body_index_options: Optional[List[Tuple[str, Any]]] = None,
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
metadata_indexing: Iterable[str] = [],
**kwargs: Any,
) -> CVST:
"""Create a Cassandra vectorstore from a document list.
"""Create a Cassandra vector store from a document list.
Args:
documents: Documents to add to the vectorstore.
Expand All @@ -1354,8 +1373,6 @@ async def afrom_documents(
If not provided, it is resolved from cassio.
table_name: Cassandra table (required).
ids: Optional list of IDs associated with the documents.
concurrency: Number of concurrent queries to send to the database.
Defaults to 16.
ttl_seconds: Optional time-to-live for the added documents.
body_index_options: Optional options used to create the body index.
Eg. body_index_options = [cassio.table.cql.STANDARD_ANALYZER]
Expand All @@ -1372,24 +1389,21 @@ async def afrom_documents(
(and to overcome max-length limitations).
Returns:
a Cassandra vectorstore.
a Cassandra vector store.
"""
texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
return await cls.afrom_texts(
texts=texts,
store = cls(
embedding=embedding,
metadatas=metadatas,
session=session,
keyspace=keyspace,
table_name=table_name,
ids=ids,
concurrency=concurrency,
ttl_seconds=ttl_seconds,
setup_mode=SetupMode.ASYNC,
body_index_options=body_index_options,
metadata_indexing=metadata_indexing,
**kwargs,
)
await store.aadd_documents(documents=cls._add_ids_to_docs(docs=documents, ids=ids))
return store

def as_retriever(
self,
Expand Down
Loading

0 comments on commit a2edc8c

Please sign in to comment.