From 66b68dd34494b1a666bd49efed6c9855959ee101 Mon Sep 17 00:00:00 2001
From: Eric Pinzur
Date: Wed, 16 Oct 2024 12:53:49 +0200
Subject: [PATCH] added more tests from AstraDBVectorStore

---
 .../vectorstores/test_cassandra.py | 905 ++++++------
 1 file changed, 267 insertions(+), 638 deletions(-)

diff --git a/libs/community/tests/integration_tests/vectorstores/test_cassandra.py b/libs/community/tests/integration_tests/vectorstores/test_cassandra.py
index dd1f52c9a21fd8..b546bb887ab21e 100644
--- a/libs/community/tests/integration_tests/vectorstores/test_cassandra.py
+++ b/libs/community/tests/integration_tests/vectorstores/test_cassandra.py
@@ -1,18 +1,17 @@
 """Test Cassandra functionality."""
 
 import asyncio
+import json
+import math
 import os
 import time
-import json
-from typing import Iterable, List, Optional, Tuple, Type, Union, TYPE_CHECKING
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
 
 import pytest
 from langchain_core.documents import Document
 
 from langchain_community.vectorstores import Cassandra
-from langchain_community.vectorstores.cassandra import SetupMode
 from tests.integration_tests.vectorstores.fake_embeddings import (
-    AngularTwoDimensionalEmbeddings,
     ConsistentFakeEmbeddings,
     Embeddings,
 )
@@ -22,6 +21,11 @@
 
 TEST_KEYSPACE = "vector_test_keyspace"
 
+# similarity threshold definitions
+EUCLIDEAN_MIN_SIM_UNIT_VECTORS = 0.2
+MATCH_EPSILON = 0.0001
+
+
 class ParserEmbeddings(Embeddings):
     """Parse input texts: if they are json for a List[float], fine.
     Otherwise, return all zeros and call it a day.
@@ -48,9 +52,11 @@ def embed_query(self, text: str) -> list[float]:
     async def aembed_query(self, text: str) -> list[float]:
         return self.embed_query(text)
 
+
 def _embedding_d2() -> Embeddings:
     return ParserEmbeddings(dimension=2)
 
+
 def _metadata_documents() -> list[Document]:
     """Documents for metadata and id tests"""
     return [
@@ -97,8 +103,10 @@ def _strip_doc(document: Document) -> Document:
         metadata=document.metadata,
     )
 
+
 def _get_cassandra_session(table_name: str, drop: bool) -> Session:
     from cassandra.cluster import Cluster
+
     # get db connection
     if "CASSANDRA_CONTACT_POINTS" in os.environ:
         contact_points = [
@@ -145,6 +153,7 @@ def _vectorstore_from_texts(
         metadata_indexing=metadata_indexing,
     )
 
+
 async def _vectorstore_from_texts_async(
     texts: List[str],
     embedding: Embeddings,
@@ -166,6 +175,7 @@ async def _vectorstore_from_texts_async(
         metadata_indexing=metadata_indexing,
     )
 
+
 def _vectorstore_from_documents(
     docs: List[Document],
     embedding: Embeddings,
@@ -210,7 +220,9 @@ def test_cassandra_add_texts() -> None:
     """Test end to end construction with further insertions."""
     texts = ["foo", "bar", "baz"]
     metadatas = [{"page": i} for i in range(len(texts))]
-    vstore = _vectorstore_from_texts(texts, metadatas=metadatas, embedding=ConsistentFakeEmbeddings())
+    vstore = _vectorstore_from_texts(
+        texts, metadatas=metadatas, embedding=ConsistentFakeEmbeddings()
+    )
 
     texts2 = ["foo2", "bar2", "baz2"]
     metadatas2 = [{"page": i + 3} for i in range(len(texts))]
@@ -220,11 +232,13 @@ def test_cassandra_add_texts() -> None:
     assert len(output) == 6
 
 
-async def test_cassandra_aadd_texts() -> None:
+async def test_cassandra_add_texts_async() -> None:
     """Test end to end construction with further insertions."""
     texts = ["foo", "bar", "baz"]
     metadatas = [{"page": i} for i in range(len(texts))]
-    vstore = _vectorstore_from_texts(texts, metadatas=metadatas, embedding=ConsistentFakeEmbeddings())
+    vstore = await _vectorstore_from_texts_async(
+        texts, metadatas=metadatas, embedding=ConsistentFakeEmbeddings()
+    )
 
     texts2 = ["foo2", "bar2", "baz2"]
     metadatas2 = [{"page": i + 3} for i in range(len(texts))]
@@ -238,10 +252,14 @@ def test_cassandra_no_drop() -> None:
     """Test end to end construction and re-opening the same index."""
     texts = ["foo", "bar", "baz"]
     metadatas = [{"page": i} for i in range(len(texts))]
-    _vectorstore_from_texts(texts, metadatas=metadatas, embedding=ConsistentFakeEmbeddings())
+    _vectorstore_from_texts(
+        texts, metadatas=metadatas, embedding=ConsistentFakeEmbeddings()
+    )
 
     texts2 = ["foo2", "bar2", "baz2"]
-    vstore = _vectorstore_from_texts(texts2, metadatas=metadatas, drop=False, embedding=ConsistentFakeEmbeddings())
+    vstore = _vectorstore_from_texts(
+        texts2, metadatas=metadatas, drop=False, embedding=ConsistentFakeEmbeddings()
+    )
 
     output = vstore.similarity_search("foo", k=10)
     assert len(output) == 6
@@ -251,7 +269,9 @@ async def test_cassandra_no_drop_async() -> None:
     """Test end to end construction and re-opening the same index."""
     texts = ["foo", "bar", "baz"]
     metadatas = [{"page": i} for i in range(len(texts))]
-    await _vectorstore_from_texts_async(texts, metadatas=metadatas, embedding=ConsistentFakeEmbeddings())
+    await _vectorstore_from_texts_async(
+        texts, metadatas=metadatas, embedding=ConsistentFakeEmbeddings()
+    )
 
     texts2 = ["foo2", "bar2", "baz2"]
     vstore = await _vectorstore_from_texts_async(
@@ -266,7 +286,9 @@ def test_cassandra_delete() -> None:
     """Test delete methods from vector store."""
     texts = ["foo", "bar", "baz", "gni"]
     metadatas = [{"page": i, "mod2": i % 2} for i in range(len(texts))]
-    vstore = _vectorstore_from_texts([], metadatas=metadatas, embedding=ConsistentFakeEmbeddings())
+    vstore = _vectorstore_from_texts(
+        [], metadatas=metadatas, embedding=ConsistentFakeEmbeddings()
+    )
 
     ids = vstore.add_texts(texts, metadatas)
     output = vstore.similarity_search("foo", k=10)
@@ -300,11 +322,13 @@ def test_cassandra_delete() -> None:
         vstore.delete_by_metadata_filter({})
 
 
-async def test_cassandra_adelete() -> None:
+async def test_cassandra_delete_async() -> None:
     """Test delete methods from vector store."""
     texts = ["foo", "bar", "baz", "gni"]
     metadatas = [{"page": i, "mod2": i % 2} for i in range(len(texts))]
-    vstore = await _vectorstore_from_texts_async([], metadatas=metadatas, embedding=ConsistentFakeEmbeddings())
+    vstore = await _vectorstore_from_texts_async(
+        [], metadatas=metadatas, embedding=ConsistentFakeEmbeddings()
+    )
 
     ids = await vstore.aadd_texts(texts, metadatas)
     output = await vstore.asimilarity_search("foo", k=10)
@@ -342,7 +366,9 @@ def test_cassandra_metadata_indexing() -> None:
     """Test comparing metadata indexing policies."""
     texts = ["foo"]
     metadatas = [{"field1": "a", "field2": "b"}]
-    vstore_all = _vectorstore_from_texts(texts, metadatas=metadatas, embedding=ConsistentFakeEmbeddings())
+    vstore_all = _vectorstore_from_texts(
+        texts, metadatas=metadatas, embedding=ConsistentFakeEmbeddings()
+    )
     vstore_f1 = _vectorstore_from_texts(
         texts,
         metadatas=metadatas,
@@ -365,122 +391,16 @@ def test_cassandra_metadata_indexing() -> None:
         vstore_f1.similarity_search("bar", filter={"field2": "b"}, k=2)
 
 
-def test_cassandra_replace_metadata() -> None:
-    """Test of replacing metadata."""
-    N_DOCS = 100
-    REPLACE_RATIO = 2  # one in ... will have replaced metadata
-    BATCH_SIZE = 3
-
-    vstore_f1 = _vectorstore_from_texts(
-        texts=[],
-        metadata_indexing=("allowlist", ["field1", "field2"]),
-        table_name="vector_test_table_indexing",
-        embedding=ConsistentFakeEmbeddings(),
-    )
-    orig_documents = [
-        Document(
-            page_content=f"doc_{doc_i}",
-            id=f"doc_id_{doc_i}",
-            metadata={"field1": f"f1_{doc_i}", "otherf": "pre"},
-        )
-        for doc_i in range(N_DOCS)
-    ]
-    vstore_f1.add_documents(orig_documents)
-
-    ids_to_replace = [
-        f"doc_id_{doc_i}" for doc_i in range(N_DOCS) if doc_i % REPLACE_RATIO == 0
-    ]
-
-    # various kinds of replacement at play here:
-    def _make_new_md(mode: int, doc_id: str) -> dict[str, str]:
-        if mode == 0:
-            return {}
-        elif mode == 1:
-            return {"field2": f"NEW_{doc_id}"}
-        elif mode == 2:
-            return {"field2": f"NEW_{doc_id}", "ofherf2": "post"}
-        else:
-            return {"ofherf2": "post"}
-
-    ids_to_new_md = {
-        doc_id: _make_new_md(rep_i % 4, doc_id)
-        for rep_i, doc_id in enumerate(ids_to_replace)
-    }
-
-    vstore_f1.replace_metadata(ids_to_new_md, batch_size=BATCH_SIZE)
-    # thorough check
-    expected_id_to_metadata: dict[str, dict] = {
-        **{(document.id or ""): document.metadata for document in orig_documents},
-        **ids_to_new_md,
-    }
-    for hit in vstore_f1.similarity_search("doc", k=N_DOCS + 1):
-        assert hit.id is not None
-        assert hit.metadata == expected_id_to_metadata[hit.id]
-
-
-async def test_cassandra_areplace_metadata() -> None:
-    """Test of replacing metadata."""
-    N_DOCS = 100
-    REPLACE_RATIO = 2  # one in ... will have replaced metadata
-    BATCH_SIZE = 3
-
-    vstore_f1 = _vectorstore_from_texts(
-        texts=[],
-        metadata_indexing=("allowlist", ["field1", "field2"]),
-        table_name="vector_test_table_indexing",
-        embedding=ConsistentFakeEmbeddings(),
-    )
-    orig_documents = [
-        Document(
-            page_content=f"doc_{doc_i}",
-            id=f"doc_id_{doc_i}",
-            metadata={"field1": f"f1_{doc_i}", "otherf": "pre"},
-        )
-        for doc_i in range(N_DOCS)
-    ]
-    await vstore_f1.aadd_documents(orig_documents)
-
-    ids_to_replace = [
-        f"doc_id_{doc_i}" for doc_i in range(N_DOCS) if doc_i % REPLACE_RATIO == 0
-    ]
-
-    # various kinds of replacement at play here:
-    def _make_new_md(mode: int, doc_id: str) -> dict[str, str]:
-        if mode == 0:
-            return {}
-        elif mode == 1:
-            return {"field2": f"NEW_{doc_id}"}
-        elif mode == 2:
-            return {"field2": f"NEW_{doc_id}", "ofherf2": "post"}
-        else:
-            return {"ofherf2": "post"}
-
-    ids_to_new_md = {
-        doc_id: _make_new_md(rep_i % 4, doc_id)
-        for rep_i, doc_id in enumerate(ids_to_replace)
-    }
-
-    await vstore_f1.areplace_metadata(ids_to_new_md, concurrency=BATCH_SIZE)
-    # thorough check
-    expected_id_to_metadata: dict[str, dict] = {
-        **{(document.id or ""): document.metadata for document in orig_documents},
-        **ids_to_new_md,
-    }
-    for hit in await vstore_f1.asimilarity_search("doc", k=N_DOCS + 1):
-        assert hit.id is not None
-        assert hit.metadata == expected_id_to_metadata[hit.id]
-
-
 def test_cassandra_vectorstore_from_texts_sync() -> None:
     """from_texts methods and the associated warnings."""
     page_contents = [
-            "[1,2]",
-            "[3,4]",
-            "[5,6]",
-            "[7,8]",
-            "[9,10]",
-            "[11,12]",
-        ]
+        "[1,2]",
+        "[3,4]",
+        "[5,6]",
+        "[7,8]",
+        "[9,10]",
+        "[11,12]",
+    ]
 
     table_name = "empty_collection_d2"
     v_store = _vectorstore_from_texts(
@@ -488,7 +408,7 @@ def test_cassandra_vectorstore_from_texts_sync() -> None:
         metadatas=[{"m": 1}, {"m": 3}],
         embedding=_embedding_d2(),
         ids=["ft1", "ft3"],
-        table_name=table_name
+        table_name=table_name,
     )
     search_results_triples_0 = v_store.similarity_search_with_score_id(
         page_contents[1],
@@ -571,7 +491,7 @@ def test_cassandra_vectorstore_from_documents_sync() -> None:
         embedding=_embedding_d2(),
         table_name=table_name,
         drop=False,
-        ids=["idx1", "idx3"]
+        ids=["idx1", "idx3"],
     )
     hits = v_store_2.similarity_search(pc2, k=1)
     assert len(hits) == 1
@@ -619,16 +539,17 @@ def test_cassandra_vectorstore_from_documents_sync() -> None:
     assert hits[0].id == "idx3"
     v_store_4.clear()
 
+
 async def test_cassandra_vectorstore_from_texts_async() -> None:
     """from_texts methods and the associated warnings, async version."""
     page_contents = [
-            "[1,2]",
-            "[3,4]",
-            "[5,6]",
-            "[7,8]",
-            "[9,10]",
-            "[11,12]",
-        ]
+        "[1,2]",
+        "[3,4]",
+        "[5,6]",
+        "[7,8]",
+        "[9,10]",
+        "[11,12]",
+    ]
 
     table_name = "empty_collection_d2"
     v_store = await _vectorstore_from_texts_async(
@@ -684,6 +605,7 @@ async def test_cassandra_vectorstore_from_texts_async() -> None:
     assert res_doc_2.metadata == {"m": 11}
     assert res_id_2 == "ft11"
 
+
 async def test_cassandra_vectorstore_from_documents_async() -> None:
     """
     from_documents, esp. the various handling of ID-in-doc vs external.
@@ -722,7 +644,7 @@ async def test_cassandra_vectorstore_from_documents_async() -> None:
         embedding=_embedding_d2(),
         table_name=table_name,
         drop=False,
-        ids=["idx1", "idx3"]
+        ids=["idx1", "idx3"],
    )
     hits = await v_store_2.asimilarity_search(pc2, k=1)
     assert len(hits) == 1
@@ -770,23 +692,13 @@ async def test_cassandra_vectorstore_from_documents_async() -> None:
     assert hits[0].id == "idx3"
     await v_store_4.aclear()
 
-### UPDATED TO HERE
-
-@pytest.mark.parametrize(
-    "vector_store",
-    [
-        "vector_store_d2",
-        "vector_store_d2_stringtoken",
-    ],
-)
-def test_cassandra_vectorstore_crud_sync(
-    self,
-    vector_store: str,
-    request: pytest.FixtureRequest,
-) -> None:
+
+def test_cassandra_vectorstore_crud_sync() -> None:
     """Add/delete/update behaviour."""
-    vstore: Cassandra = request.getfixturevalue(vector_store)
+    vstore = _vectorstore_from_documents(
+        docs=[],
+        embedding=_embedding_d2(),
+    )
 
     res0 = vstore.similarity_search("[-1,-1]", k=2)
     assert res0 == []
@@ -865,20 +777,13 @@ def test_cassandra_vectorstore_crud_sync(
     vstore.delete_by_document_id("s")
     assert len(vstore.similarity_search("[-1,-1]", k=10)) == 3
 
-@pytest.mark.parametrize(
-    "vector_store",
-    [
-        "vector_store_d2",
-        "vector_store_d2_stringtoken",
-    ],
-)
-async def test_cassandra_vectorstore_crud_async(
-    self,
-    vector_store: str,
-    request: pytest.FixtureRequest,
-) -> None:
+
+async def test_cassandra_vectorstore_crud_async() -> None:
     """Add/delete/update behaviour, async version."""
-    vstore: Cassandra = request.getfixturevalue(vector_store)
+    vstore = await _vectorstore_from_documents_async(
+        docs=[],
+        embedding=_embedding_d2(),
+    )
 
     res0 = await vstore.asimilarity_search("[-1,-1]", k=2)
     assert res0 == []
@@ -957,11 +862,14 @@ async def test_cassandra_vectorstore_crud_async(
     await vstore.adelete_by_document_id("s")
     assert len(await vstore.asimilarity_search("[-1,-1]", k=10)) == 3
 
-def test_cassandra_vectorstore_massive_insert_replace_sync(
-    self,
-    vector_store_d2: Cassandra,
-) -> None:
+
+def test_cassandra_vectorstore_massive_insert_replace_sync() -> None:
     """Testing the insert-many-and-replace-some patterns thoroughly."""
+    vector_store_d2 = _vectorstore_from_documents(
+        docs=[],
+        embedding=_embedding_d2(),
+    )
+
     full_size = 300
     first_group_size = 150
     second_group_slicer = [30, 100, 2]
@@ -1002,14 +910,17 @@ def test_cassandra_vectorstore_massive_insert_replace_sync(
     for doc, _, doc_id in full_results:
         assert doc.page_content == expected_text_by_id[doc_id]
 
-async def test_cassandra_vectorstore_massive_insert_replace_async(
-    self,
-    vector_store_d2: Cassandra,
-) -> None:
+
+async def test_cassandra_vectorstore_massive_insert_replace_async() -> None:
     """
     Testing the insert-many-and-replace-some patterns thoroughly.
     Async version.
     """
+    vector_store_d2 = await _vectorstore_from_documents_async(
+        docs=[],
+        embedding=_embedding_d2(),
+    )
+
     full_size = 300
     first_group_size = 150
     second_group_slicer = [30, 100, 2]
@@ -1059,19 +970,20 @@ async def test_cassandra_vectorstore_massive_insert_replace_async(
         assert doc.page_content == expected_text_by_id[doc_id]
         assert embedding == expected_embedding_by_id[doc_id]
 
-def test_cassandra_vectorstore_delete_by_metadata_sync(
-    self,
-    vector_store_d2: Cassandra,
-) -> None:
+
+def test_cassandra_vectorstore_delete_by_metadata_sync() -> None:
     """Testing delete_by_metadata_filter."""
+    vector_store_d2 = _vectorstore_from_documents(
+        docs=[],
+        embedding=_embedding_d2(),
+    )
+
     full_size = 400
     # one in ... will be deleted
     deletee_ratio = 3
     documents = [
-        Document(
-            page_content="[1,1]", metadata={"deletee": doc_i % deletee_ratio == 0}
-        )
+        Document(page_content="[1,1]", metadata={"deletee": doc_i % deletee_ratio == 0})
         for doc_i in range(full_size)
     ]
     num_deletees = len([doc for doc in documents if doc.metadata["deletee"]])
@@ -1081,31 +993,27 @@ def test_cassandra_vectorstore_delete_by_metadata_sync(
     d_result0 = vector_store_d2.delete_by_metadata_filter({"deletee": True})
     assert d_result0 == num_deletees
-    count_on_store0 = len(
-        vector_store_d2.similarity_search("[1,1]", k=full_size + 1)
-    )
+    count_on_store0 = len(vector_store_d2.similarity_search("[1,1]", k=full_size + 1))
     assert count_on_store0 == full_size - num_deletees
 
     with pytest.raises(ValueError, match="does not accept an empty"):
         vector_store_d2.delete_by_metadata_filter({})
-    count_on_store1 = len(
-        vector_store_d2.similarity_search("[1,1]", k=full_size + 1)
-    )
+    count_on_store1 = len(vector_store_d2.similarity_search("[1,1]", k=full_size + 1))
     assert count_on_store1 == full_size - num_deletees
 
-async def test_cassandra_vectorstore_delete_by_metadata_async(
-    self,
-    vector_store_d2: Cassandra,
-) -> None:
+
+async def test_cassandra_vectorstore_delete_by_metadata_async() -> None:
     """Testing delete_by_metadata_filter, async version."""
+    vector_store_d2 = await _vectorstore_from_documents_async(
+        docs=[],
+        embedding=_embedding_d2(),
+    )
     full_size = 400
     # one in ... will be deleted
     deletee_ratio = 3
     documents = [
-        Document(
-            page_content="[1,1]", metadata={"deletee": doc_i % deletee_ratio == 0}
-        )
+        Document(page_content="[1,1]", metadata={"deletee": doc_i % deletee_ratio == 0})
         for doc_i in range(full_size)
     ]
     num_deletees = len([doc for doc in documents if doc.metadata["deletee"]])
@@ -1127,120 +1035,122 @@ async def test_cassandra_vectorstore_delete_by_metadata_async(
     )
     assert count_on_store1 == full_size - num_deletees
 
-def test_cassandra_vectorstore_update_metadata_sync(
-    self,
-    vector_store_d2: Cassandra,
-) -> None:
-    """Testing update_metadata."""
-    # this should not exceed the max number of hits from ANN search
-    full_size = 20
-    # one in ... will be updated
-    updatee_ratio = 2
-    # set this to lower than full_size // updatee_ratio to test everything.
-    update_concurrency = 7
-
-    def doc_sorter(doc: Document) -> str:
-        return doc.id or ""
-
-    orig_documents0 = [
+
+def test_cassandra_replace_metadata() -> None:
+    """Test of replacing metadata."""
+    N_DOCS = 100
+    REPLACE_RATIO = 2  # one in ... will have replaced metadata
+    BATCH_SIZE = 3
+
+    vstore_f1 = _vectorstore_from_texts(
+        texts=[],
+        metadata_indexing=("allowlist", ["field1", "field2"]),
+        table_name="vector_test_table_indexing",
+        embedding=ConsistentFakeEmbeddings(),
+    )
+    orig_documents = [
         Document(
-            page_content="[1,1]",
-            metadata={
-                "to_update": doc_i % updatee_ratio == 0,
-                "inert_field": "I",
-                "updatee_field": "0",
-            },
-            id=f"um_doc_{doc_i}",
+            page_content=f"doc_{doc_i}",
+            id=f"doc_id_{doc_i}",
+            metadata={"field1": f"f1_{doc_i}", "otherf": "pre"},
         )
-        for doc_i in range(full_size)
+        for doc_i in range(N_DOCS)
     ]
-    orig_documents = sorted(orig_documents0, key=doc_sorter)
+    vstore_f1.add_documents(orig_documents)
 
-    inserted_ids0 = vector_store_d2.add_documents(orig_documents)
-    assert len(inserted_ids0) == len(orig_documents)
+    ids_to_replace = [
+        f"doc_id_{doc_i}" for doc_i in range(N_DOCS) if doc_i % REPLACE_RATIO == 0
+    ]
 
-    update_map = {
-        f"um_doc_{doc_i}": {"updatee_field": "1", "to_update": False}
-        for doc_i in range(full_size)
-        if doc_i % updatee_ratio == 0
+    # various kinds of replacement at play here:
+    def _make_new_md(mode: int, doc_id: str) -> dict[str, str]:
+        if mode == 0:
+            return {}
+        elif mode == 1:
+            return {"field2": f"NEW_{doc_id}"}
+        elif mode == 2:
+            return {"field2": f"NEW_{doc_id}", "ofherf2": "post"}
+        else:
+            return {"ofherf2": "post"}
+
+    ids_to_new_md = {
+        doc_id: _make_new_md(rep_i % 4, doc_id)
+        for rep_i, doc_id in enumerate(ids_to_replace)
+    }
+
+    vstore_f1.replace_metadata(ids_to_new_md, batch_size=BATCH_SIZE)
+    # thorough check
+    expected_id_to_metadata: dict[str, dict] = {
+        **{(document.id or ""): document.metadata for document in orig_documents},
+        **ids_to_new_md,
     }
-    u_result0 = vector_store_d2.update_metadata(
-        update_map,
-        overwrite_concurrency=update_concurrency,
-    )
-    assert u_result0 == len(update_map)
-
-    all_documents = sorted(
-        vector_store_d2.similarity_search("[1,1]", k=full_size),
-        key=doc_sorter,
-    )
-    assert len(all_documents) == len(orig_documents)
-    for doc, orig_doc in zip(all_documents, orig_documents):
-        assert doc.id == orig_doc.id
-        if doc.id in update_map:
-            assert doc.metadata == orig_doc.metadata | update_map[doc.id]
-
-async def test_cassandra_vectorstore_update_metadata_async(
-    self,
-    vector_store_d2: Cassandra,
-) -> None:
-    """Testing update_metadata, async version."""
-    # this should not exceed the max number of hits from ANN search
-    full_size = 20
-    # one in ... will be updated
-    updatee_ratio = 2
-    # set this to lower than full_size // updatee_ratio to test everything.
-    update_concurrency = 7
-
-    def doc_sorter(doc: Document) -> str:
-        return doc.id or ""
-
-    orig_documents0 = [
+    for hit in vstore_f1.similarity_search("doc", k=N_DOCS + 1):
+        assert hit.id is not None
+        assert hit.metadata == expected_id_to_metadata[hit.id]
+
+
+async def test_cassandra_replace_metadata_async() -> None:
+    """Test of replacing metadata."""
+    N_DOCS = 100
+    REPLACE_RATIO = 2  # one in ... will have replaced metadata
+    BATCH_SIZE = 3
+
+    vstore_f1 = await _vectorstore_from_texts_async(
+        texts=[],
+        metadata_indexing=("allowlist", ["field1", "field2"]),
+        table_name="vector_test_table_indexing",
+        embedding=ConsistentFakeEmbeddings(),
+    )
+    orig_documents = [
         Document(
-            page_content="[1,1]",
-            metadata={
-                "to_update": doc_i % updatee_ratio == 0,
-                "inert_field": "I",
-                "updatee_field": "0",
-            },
-            id=f"um_doc_{doc_i}",
+            page_content=f"doc_{doc_i}",
+            id=f"doc_id_{doc_i}",
+            metadata={"field1": f"f1_{doc_i}", "otherf": "pre"},
         )
-        for doc_i in range(full_size)
+        for doc_i in range(N_DOCS)
     ]
-    orig_documents = sorted(orig_documents0, key=doc_sorter)
+    await vstore_f1.aadd_documents(orig_documents)
 
-    inserted_ids0 = await vector_store_d2.aadd_documents(orig_documents)
-    assert len(inserted_ids0) == len(orig_documents)
+    ids_to_replace = [
+        f"doc_id_{doc_i}" for doc_i in range(N_DOCS) if doc_i % REPLACE_RATIO == 0
+    ]
 
-    update_map = {
-        f"um_doc_{doc_i}": {"updatee_field": "1", "to_update": False}
-        for doc_i in range(full_size)
-        if doc_i % updatee_ratio == 0
+    # various kinds of replacement at play here:
+    def _make_new_md(mode: int, doc_id: str) -> dict[str, str]:
+        if mode == 0:
+            return {}
+        elif mode == 1:
+            return {"field2": f"NEW_{doc_id}"}
+        elif mode == 2:
+            return {"field2": f"NEW_{doc_id}", "ofherf2": "post"}
+        else:
+            return {"ofherf2": "post"}
+
+    ids_to_new_md = {
+        doc_id: _make_new_md(rep_i % 4, doc_id)
+        for rep_i, doc_id in enumerate(ids_to_replace)
+    }
+
+    await vstore_f1.areplace_metadata(ids_to_new_md, concurrency=BATCH_SIZE)
+    # thorough check
+    expected_id_to_metadata: dict[str, dict] = {
+        **{(document.id or ""): document.metadata for document in orig_documents},
+        **ids_to_new_md,
     }
-    u_result0 = await vector_store_d2.aupdate_metadata(
-        update_map,
-        overwrite_concurrency=update_concurrency,
-    )
-    assert u_result0 == len(update_map)
-
-    all_documents = sorted(
-        await vector_store_d2.asimilarity_search("[1,1]", k=full_size),
-        key=doc_sorter,
-    )
-    assert len(all_documents) == len(orig_documents)
-    for doc, orig_doc in zip(all_documents, orig_documents):
-        assert doc.id == orig_doc.id
-        if doc.id in update_map:
-            assert doc.metadata == orig_doc.metadata | update_map[doc.id]
-
-def test_cassandra_vectorstore_mmr_sync(
-    self,
-    vector_store_d2: Cassandra,
-) -> None:
+    for hit in await vstore_f1.asimilarity_search("doc", k=N_DOCS + 1):
+        assert hit.id is not None
+        assert hit.metadata == expected_id_to_metadata[hit.id]
+
+
+def test_cassandra_vectorstore_mmr_sync() -> None:
     """MMR testing. We work on the unit circle with angle multiples
     of 2*pi/20 and prepare a store with known vectors for a controlled
     MMR outcome.
     """
+    vector_store_d2 = _vectorstore_from_documents(
+        docs=[],
+        embedding=_embedding_d2(),
+    )
 
     def _v_from_i(i: int, n: int) -> str:
         angle = 2 * math.pi * i / n
@@ -1260,15 +1170,17 @@ def _v_from_i(i: int, n: int) -> str:
     res_i_vals = {doc.metadata["i"] for doc in res1}
     assert res_i_vals == {0, 4}
 
-async def test_cassandra_vectorstore_mmr_async(
-    self,
-    vector_store_d2: Cassandra,
-) -> None:
+
+async def test_cassandra_vectorstore_mmr_async() -> None:
     """MMR testing. We work on the unit circle with angle multiples
     of 2*pi/20 and prepare a store with known vectors for a controlled
     MMR outcome.
     Async version.
""" + vector_store_d2 = await _vectorstore_from_documents_async( + docs=[], + embedding=_embedding_d2(), + ) def _v_from_i(i: int, n: int) -> str: angle = 2 * math.pi * i / n @@ -1289,21 +1201,13 @@ def _v_from_i(i: int, n: int) -> str: res_i_vals = {doc.metadata["i"] for doc in res1} assert res_i_vals == {0, 4} -@pytest.mark.parametrize( - "vector_store", - [ - "vector_store_d2", - ], -) -def test_cassandra_vectorstore_metadata_filter( - self, - vector_store: str, - request: pytest.FixtureRequest, - metadata_documents: list[Document], -) -> None: + +def test_cassandra_vectorstore_metadata_filter() -> None: """Metadata filtering.""" - vstore: Cassandra = request.getfixturevalue(vector_store) - vstore.add_documents(metadata_documents) + vstore = _vectorstore_from_documents( + docs=_metadata_documents(), + embedding=_embedding_d2(), + ) # no filters res0 = vstore.similarity_search("[-1,-1]", k=10) assert {doc.metadata["letter"] for doc in res0} == set("qwreio") @@ -1318,39 +1222,24 @@ def test_cassandra_vectorstore_metadata_filter( res2 = vstore.similarity_search( "[-1,-1]", k=10, - filter={"group": "consonant", "ord": ord("q")}, + filter={"group": "consonant", "ord": str(ord("q"))}, ) assert {doc.metadata["letter"] for doc in res2} == set("q") # excessive filters res3 = vstore.similarity_search( "[-1,-1]", k=10, - filter={"group": "consonant", "ord": ord("q"), "case": "upper"}, + filter={"group": "consonant", "ord": str(ord("q")), "case": "upper"}, ) assert res3 == [] - # filter with logical operator - res4 = vstore.similarity_search( - "[-1,-1]", - k=10, - filter={"$or": [{"ord": ord("q")}, {"ord": ord("r")}]}, - ) - assert {doc.metadata["letter"] for doc in res4} == {"q", "r"} -@pytest.mark.parametrize( - "vector_store", - [ - "vector_store_d2", - ], -) -def test_cassandra_vectorstore_metadata_search_sync( - self, - vector_store: str, - request: pytest.FixtureRequest, - metadata_documents: list[Document], -) -> None: + +def test_cassandra_vectorstore_metadata_search_sync() -> None: """Metadata Search""" - vstore: Cassandra = request.getfixturevalue(vector_store) - vstore.add_documents(metadata_documents) + vstore = _vectorstore_from_documents( + docs=_metadata_documents(), + embedding=_embedding_d2(), + ) # no filters res0 = vstore.metadata_search(filter={}, n=10) assert {doc.metadata["letter"] for doc in res0} == set("qwreio") @@ -1363,37 +1252,24 @@ def test_cassandra_vectorstore_metadata_search_sync( # multiple filters res2 = vstore.metadata_search( n=10, - filter={"group": "consonant", "ord": ord("q")}, + filter={"group": "consonant", "ord": str(ord("q"))}, ) assert {doc.metadata["letter"] for doc in res2} == set("q") # excessive filters res3 = vstore.metadata_search( n=10, - filter={"group": "consonant", "ord": ord("q"), "case": "upper"}, + filter={"group": "consonant", "ord": str(ord("q")), "case": "upper"}, ) assert res3 == [] - # filter with logical operator - res4 = vstore.metadata_search( - n=10, - filter={"$or": [{"ord": ord("q")}, {"ord": ord("r")}]}, - ) - assert {doc.metadata["letter"] for doc in res4} == {"q", "r"} -@pytest.mark.parametrize( - "vector_store", - [ - "vector_store_d2", - ], -) -async def test_cassandra_vectorstore_metadata_search_async( - self, - vector_store: str, - request: pytest.FixtureRequest, - metadata_documents: list[Document], -) -> None: + +async def test_cassandra_vectorstore_metadata_search_async() -> None: """Metadata Search""" - vstore: Cassandra = request.getfixturevalue(vector_store) - await vstore.aadd_documents(metadata_documents) + vstore = 
await _vectorstore_from_documents_async( + docs=_metadata_documents(), + embedding=_embedding_d2(), + ) + # no filters res0 = await vstore.ametadata_search(filter={}, n=10) assert {doc.metadata["letter"] for doc in res0} == set("qwreio") @@ -1422,21 +1298,13 @@ async def test_cassandra_vectorstore_metadata_search_async( ) assert {doc.metadata["letter"] for doc in res4} == {"q", "r"} -@pytest.mark.parametrize( - "vector_store", - [ - "vector_store_d2", - ], -) -def test_cassandra_vectorstore_get_by_document_id_sync( - self, - vector_store: str, - request: pytest.FixtureRequest, - metadata_documents: list[Document], -) -> None: + +def test_cassandra_vectorstore_get_by_document_id_sync() -> None: """Get by document_id""" - vstore: Cassandra = request.getfixturevalue(vector_store) - vstore.add_documents(metadata_documents) + vstore = _vectorstore_from_documents( + docs=_metadata_documents(), + embedding=_embedding_d2(), + ) # invalid id invalid = vstore.get_by_document_id(document_id="z") assert invalid is None @@ -1448,21 +1316,13 @@ def test_cassandra_vectorstore_get_by_document_id_sync( assert valid.metadata["group"] == "consonant" assert valid.metadata["letter"] == "q" -@pytest.mark.parametrize( - "vector_store", - [ - "vector_store_d2", - ], -) -async def test_cassandra_vectorstore_get_by_document_id_async( - self, - vector_store: str, - request: pytest.FixtureRequest, - metadata_documents: list[Document], -) -> None: + +async def test_cassandra_vectorstore_get_by_document_id_async() -> None: """Get by document_id""" - vstore: Cassandra = request.getfixturevalue(vector_store) - await vstore.aadd_documents(metadata_documents) + vstore = await _vectorstore_from_documents_async( + docs=_metadata_documents(), + embedding=_embedding_d2(), + ) # invalid id invalid = await vstore.aget_by_document_id(document_id="z") assert invalid is None @@ -1474,33 +1334,16 @@ async def test_cassandra_vectorstore_get_by_document_id_async( assert valid.metadata["group"] == "consonant" assert valid.metadata["letter"] == "q" -@pytest.mark.parametrize( - ("is_vectorize", "vector_store", "texts", "query"), - [ - ( - False, - "vector_store_d2", - ["[1,1]", "[-1,-1]"], - "[0.99999,1.00001]", - ), - ], - ids=["nonvectorize_store"], -) -def test_cassandra_vectorstore_similarity_scale_sync( - self, - *, - is_vectorize: bool, - vector_store: str, - texts: list[str], - query: str, - request: pytest.FixtureRequest, -) -> None: + +def test_cassandra_vectorstore_similarity_scale_sync() -> None: """Scale of the similarity scores.""" - vstore: Cassandra = request.getfixturevalue(vector_store) - vstore.add_texts( - texts=texts, + vstore = _vectorstore_from_texts( + texts=["[1,1]", "[-1,-1]"], ids=["near", "far"], + embedding=_embedding_d2(), ) + query = "[0.99999,1.00001]" + res1 = vstore.similarity_search_with_score( query, k=2, @@ -1508,37 +1351,19 @@ def test_cassandra_vectorstore_similarity_scale_sync( scores = [sco for _, sco in res1] sco_near, sco_far = scores assert sco_far >= 0 - if not is_vectorize: - assert abs(1 - sco_near) < MATCH_EPSILON - assert sco_far < EUCLIDEAN_MIN_SIM_UNIT_VECTORS + MATCH_EPSILON + assert abs(1 - sco_near) < MATCH_EPSILON + assert sco_far < EUCLIDEAN_MIN_SIM_UNIT_VECTORS + MATCH_EPSILON -@pytest.mark.parametrize( - ("is_vectorize", "vector_store", "texts", "query"), - [ - ( - False, - "vector_store_d2", - ["[1,1]", "[-1,-1]"], - "[0.99999,1.00001]", - ), - ], - ids=["nonvectorize_store"], -) -async def test_cassandra_vectorstore_similarity_scale_async( - self, - *, - is_vectorize: bool, - 
-    vector_store: str,
-    texts: list[str],
-    query: str,
-    request: pytest.FixtureRequest,
-) -> None:
+
+async def test_cassandra_vectorstore_similarity_scale_async() -> None:
     """Scale of the similarity scores, async version."""
-    vstore: Cassandra = request.getfixturevalue(vector_store)
-    await vstore.aadd_texts(
-        texts=texts,
+    vstore = await _vectorstore_from_texts_async(
+        texts=["[1,1]", "[-1,-1]"],
         ids=["near", "far"],
+        embedding=_embedding_d2(),
     )
+
+    query = "[0.99999,1.00001]"
     res1 = await vstore.asimilarity_search_with_score(
         query,
         k=2,
@@ -1546,23 +1371,16 @@ async def test_cassandra_vectorstore_similarity_scale_async(
     scores = [sco for _, sco in res1]
     sco_near, sco_far = scores
     assert sco_far >= 0
-    if not is_vectorize:
-        assert abs(1 - sco_near) < MATCH_EPSILON
-        assert sco_far < EUCLIDEAN_MIN_SIM_UNIT_VECTORS + MATCH_EPSILON
-
-@pytest.mark.parametrize(
-    "vector_store",
-    [
-        "vector_store_d2",
-    ],
-)
-def test_cassandra_vectorstore_massive_delete(
-    self,
-    vector_store: str,
-    request: pytest.FixtureRequest,
-) -> None:
+    assert abs(1 - sco_near) < MATCH_EPSILON
+    assert sco_far < EUCLIDEAN_MIN_SIM_UNIT_VECTORS + MATCH_EPSILON
+
+
+def test_cassandra_vectorstore_massive_delete() -> None:
     """Larger-scale bulk deletes."""
-    vstore: Cassandra = request.getfixturevalue(vector_store)
+    vstore = _vectorstore_from_documents(
+        docs=[],
+        embedding=_embedding_d2(),
+    )
     m = 150
     texts = [f"[0,{i + 1 / 7.0}]" for i in range(2 * m)]
     ids0 = [f"doc_{i}" for i in range(m)]
@@ -1577,192 +1395,20 @@ def test_cassandra_vectorstore_massive_delete(
     assert del_res1 is True  # ensure no error
     # nothing left
     assert vstore.similarity_search("[-1,-1]", k=2 * m) == []
-
-def test_cassandra_vectorstore_custom_params_sync(
-    self,
-    astra_db_credentials: AstraDBCredentials,
-    empty_collection_d2: Collection,
-    embedding_d2: Embeddings,
-) -> None:
-    """Custom batch size and concurrency params."""
-    v_store = Cassandra(
-        embedding=embedding_d2,
-        collection_name=empty_collection_d2.name,
-        token=StaticTokenProvider(astra_db_credentials["token"]),
-        api_endpoint=astra_db_credentials["api_endpoint"],
-        namespace=astra_db_credentials["namespace"],
-        environment=astra_db_credentials["environment"],
-        setup_mode=SetupMode.OFF,
-        batch_size=17,
-        bulk_insert_batch_concurrency=13,
-        bulk_insert_overwrite_concurrency=7,
-        bulk_delete_concurrency=19,
-    )
-    # add_texts and delete some
-    n = 120
-    texts = [f"[0,{i + 1 / 7.0}]" for i in range(n)]
-    ids = [f"doc_{i}" for i in range(n)]
-    v_store.add_texts(texts=texts, ids=ids)
-    v_store.add_texts(
-        texts=texts,
-        ids=ids,
-        batch_size=19,
-        batch_concurrency=7,
-        overwrite_concurrency=13,
-    )
-    v_store.delete(ids[: n // 2])
-    v_store.delete(ids[n // 2 :], concurrency=23)
-
-async def test_cassandra_vectorstore_custom_params_async(
-    self,
-    astra_db_credentials: AstraDBCredentials,
-    empty_collection_d2: Collection,
-    embedding_d2: Embeddings,
-) -> None:
-    """Custom batch size and concurrency params, async version"""
-    v_store = Cassandra(
-        embedding=embedding_d2,
-        collection_name=empty_collection_d2.name,
-        token=StaticTokenProvider(astra_db_credentials["token"]),
-        api_endpoint=astra_db_credentials["api_endpoint"],
-        namespace=astra_db_credentials["namespace"],
-        environment=astra_db_credentials["environment"],
-        setup_mode=SetupMode.OFF,
-        batch_size=17,
-        bulk_insert_batch_concurrency=13,
-        bulk_insert_overwrite_concurrency=7,
-        bulk_delete_concurrency=19,
-    )
-    # add_texts and delete some
-    n = 120
-    texts = [f"[0,{i + 1 / 7.0}]" for i in range(n)]
-    ids = [f"doc_{i}" for i in range(n)]
-    await v_store.aadd_texts(texts=texts, ids=ids)
-    await v_store.aadd_texts(
-        texts=texts,
-        ids=ids,
-        batch_size=19,
-        batch_concurrency=7,
-        overwrite_concurrency=13,
-    )
-    await v_store.adelete(ids[: n // 2])
-    await v_store.adelete(ids[n // 2 :], concurrency=23)
-
-def test_cassandra_vectorstore_metrics(
-    self,
-    astra_db_credentials: AstraDBCredentials,
-    embedding_d2: Embeddings,
-    vector_store_d2: Cassandra,
-    ephemeral_collection_cleaner_d2: str,
-) -> None:
-    """Different choices of similarity metric.
-    Both stores (with "cosine" and "euclidea" metrics) contain these two:
-        - a vector slightly rotated w.r.t query vector
-        - a vector which is a long multiple of query vector
-    so, which one is "the closest one" depends on the metric.
-    """
-    euclidean_store = vector_store_d2
-
-    isq2 = 0.5**0.5
-    isa = 0.7
-    isb = (1.0 - isa * isa) ** 0.5
-    texts = [
-        json.dumps([isa, isb]),
-        json.dumps([10 * isq2, 10 * isq2]),
-    ]
-    ids = ["rotated", "scaled"]
-    query_text = json.dumps([isq2, isq2])
-
-    # prepare empty collections
-    cosine_store = Cassandra(
-        embedding=embedding_d2,
-        collection_name=ephemeral_collection_cleaner_d2,
-        metric="cosine",
-        token=StaticTokenProvider(astra_db_credentials["token"]),
-        api_endpoint=astra_db_credentials["api_endpoint"],
-        namespace=astra_db_credentials["namespace"],
-        environment=astra_db_credentials["environment"],
-    )
-
-    cosine_store.add_texts(texts=texts, ids=ids)
-    euclidean_store.add_texts(texts=texts, ids=ids)
-
-    cosine_triples = cosine_store.similarity_search_with_score_id(
-        query_text,
-        k=1,
-    )
-    euclidean_triples = euclidean_store.similarity_search_with_score_id(
-        query_text,
-        k=1,
-    )
-    assert len(cosine_triples) == 1
-    assert len(euclidean_triples) == 1
-    assert cosine_triples[0][2] == "scaled"
-    assert euclidean_triples[0][2] == "rotated"
-
-@pytest.mark.skipif(
-    os.environ.get("ASTRA_DB_ENVIRONMENT", "prod").upper() != "PROD",
-    reason="Can run on Astra DB production environment only",
-)
-def test_cassandra_vectorstore_coreclients_init_sync(
-    self,
-    core_astra_db: AstraDB,
-    embedding_d2: Embeddings,
-    vector_store_d2: Cassandra,
-) -> None:
-    """
-    Expect a deprecation warning from passing a (core) AstraDB class,
-    but it must work.
-    """
-    vector_store_d2.add_texts(["[1,2]"])
-
-    with pytest.warns(DeprecationWarning) as rec_warnings:
-        v_store_init_core = Cassandra(
-            embedding=embedding_d2,
-            collection_name=vector_store_d2.collection_name,
-            astra_db_client=core_astra_db,
-            metric="euclidean",
-        )
-
-    results = v_store_init_core.similarity_search("[-1,-1]", k=1)
-    # cleaning out 'spurious' "unclosed socket/transport..." warnings
-    f_rec_warnings = [
-        wrn for wrn in rec_warnings if issubclass(wrn.category, DeprecationWarning)
-    ]
-    assert len(f_rec_warnings) == 1
-    assert len(results) == 1
-    assert results[0].page_content == "[1,2]"
-
-@pytest.mark.skipif(
-    os.environ.get("ASTRA_DB_ENVIRONMENT", "prod").upper() != "PROD",
-    reason="Can run on Astra DB production environment only",
-)
-async def test_cassandra_vectorstore_coreclients_init_async(
-    self,
-    core_astra_db: AstraDB,
-    embedding_d2: Embeddings,
-    vector_store_d2: Cassandra,
-) -> None:
-    """
-    Expect a deprecation warning from passing a (core) AstraDB class,
-    but it must work. Async version.
-    """
-    vector_store_d2.add_texts(["[1,2]"])
-
-    with pytest.warns(DeprecationWarning) as rec_warnings:
-        v_store_init_core = Cassandra(
-            embedding=embedding_d2,
-            collection_name=vector_store_d2.collection_name,
-            astra_db_client=core_astra_db,
-            metric="euclidean",
-            setup_mode=SetupMode.ASYNC,
-        )
-
-    results = await v_store_init_core.asimilarity_search("[-1,-1]", k=1)
-    # cleaning out 'spurious' "unclosed socket/transport..." warnings
-    f_rec_warnings = [
-        wrn for wrn in rec_warnings if issubclass(wrn.category, DeprecationWarning)
-    ]
-    assert len(f_rec_warnings) == 1
-    assert len(results) == 1
-    assert results[0].page_content == "[1,2]"