diff --git a/libs/partners/chroma/tests/integration_tests/test_vectorstores.py b/libs/partners/chroma/tests/integration_tests/test_vectorstores.py index 382b24cb54b47..4393d5f339b09 100644 --- a/libs/partners/chroma/tests/integration_tests/test_vectorstores.py +++ b/libs/partners/chroma/tests/integration_tests/test_vectorstores.py @@ -1,12 +1,16 @@ """Test Chroma functionality.""" import uuid -from typing import Generator +from typing import ( + Generator, + cast, +) import chromadb import pytest # type: ignore[import-not-found] import requests from chromadb.api.client import SharedSystemClient +from chromadb.api.types import Embeddable from langchain_core.documents import Document from langchain_core.embeddings.fake import FakeEmbeddings as Fak @@ -17,6 +21,15 @@ ) +class MyEmbeddingFunction: + def __init__(self, fak: Fak): + self.fak = fak + + def __call__(self, input: Embeddable) -> list[list[float]]: + texts = cast(list[str], input) + return self.fak.embed_documents(texts=texts) + + @pytest.fixture() def client() -> Generator[chromadb.ClientAPI, None, None]: SharedSystemClient.clear_system_cache() @@ -254,8 +267,8 @@ def test_chroma_update_document() -> None: # Assert that the updated document is returned by the search assert output == [Document(page_content=updated_content, metadata={"page": "0"})] - assert new_embedding == embedding.embed_documents([updated_content])[0] - assert new_embedding != old_embedding + assert list(new_embedding) == list(embedding.embed_documents([updated_content])[0]) + assert list(new_embedding) != list(old_embedding) # TODO: RELEVANCE SCORE IS BROKEN. FIX TEST @@ -341,17 +354,17 @@ def batch_support_chroma_version() -> bool: ) def test_chroma_large_batch() -> None: client = chromadb.HttpClient() - embedding_function = Fak(size=255) + embedding_function = MyEmbeddingFunction(fak=Fak(size=255)) col = client.get_or_create_collection( "my_collection", - embedding_function=embedding_function.embed_documents, # type: ignore + embedding_function=embedding_function, # type: ignore ) - docs = ["This is a test document"] * (client.max_batch_size + 100) # type: ignore + docs = ["This is a test document"] * (client.get_max_batch_size() + 100) # type: ignore db = Chroma.from_texts( client=client, collection_name=col.name, texts=docs, - embedding=embedding_function, + embedding=embedding_function.fak, ids=[str(uuid.uuid4()) for _ in range(len(docs))], ) @@ -369,18 +382,18 @@ def test_chroma_large_batch() -> None: ) def test_chroma_large_batch_update() -> None: client = chromadb.HttpClient() - embedding_function = Fak(size=255) + embedding_function = MyEmbeddingFunction(fak=Fak(size=255)) col = client.get_or_create_collection( "my_collection", - embedding_function=embedding_function.embed_documents, # type: ignore + embedding_function=embedding_function, # type: ignore ) - docs = ["This is a test document"] * (client.max_batch_size + 100) # type: ignore + docs = ["This is a test document"] * (client.get_max_batch_size() + 100) # type: ignore ids = [str(uuid.uuid4()) for _ in range(len(docs))] db = Chroma.from_texts( client=client, collection_name=col.name, texts=docs, - embedding=embedding_function, + embedding=embedding_function.fak, ids=ids, ) new_docs = [ @@ -408,7 +421,7 @@ def test_chroma_legacy_batching() -> None: embedding_function = Fak(size=255) col = client.get_or_create_collection( "my_collection", - embedding_function=embedding_function.embed_documents, # type: ignore + embedding_function=MyEmbeddingFunction, # type: ignore ) docs = ["This is a test document"] * 100 db = Chroma.from_texts(