Skip to content

Commit

Permalink
langchain_chroma: fixed integration tests (#27968)
Browse files Browse the repository at this point in the history
Description:
* I'm planning to add `Document.id` support to the Chroma VectorStore,
but I wanted to make sure all the integration tests were passing
first. They weren't. This PR fixes the broken tests.
* I found 2 issues:
* This change (from a year ago, exactly :) ) for supporting multi-modal
embeddings:
https://docs.trychroma.com/deployment/migration#migration-to-0.4.16---november-7,-2023
* The change in #27827, which was needed due
to an update in the chroma client.
  
Also ran `format` and `lint` on the changes.

Note: I am not a member of the Chroma team.
  • Loading branch information
epinzur authored Nov 20, 2024
1 parent 218b4e0 commit 923ef85
Showing 1 changed file with 25 additions and 12 deletions.
37 changes: 25 additions & 12 deletions libs/partners/chroma/tests/integration_tests/test_vectorstores.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
"""Test Chroma functionality."""

import uuid
from typing import Generator
from typing import (
Generator,
cast,
)

import chromadb
import pytest # type: ignore[import-not-found]
import requests
from chromadb.api.client import SharedSystemClient
from chromadb.api.types import Embeddable
from langchain_core.documents import Document
from langchain_core.embeddings.fake import FakeEmbeddings as Fak

Expand All @@ -17,6 +21,15 @@
)


class MyEmbeddingFunction:
    """Callable wrapper that exposes a ``FakeEmbeddings`` instance to Chroma.

    An instance is passed as ``embedding_function`` when creating a Chroma
    collection; the wrapped embedder stays reachable as ``fak`` so the same
    object can also be handed to LangChain's ``Chroma`` vector store.
    """

    def __init__(self, fak: Fak) -> None:
        """Store the embedder that ``__call__`` delegates to.

        Args:
            fak: Embedder providing ``embed_documents(texts=...)``.
        """
        self.fak = fak

    # NOTE(review): the parameter is presumably named ``input`` (shadowing the
    # builtin) because Chroma's embedding-function interface requires that
    # name -- confirm against the Chroma migration notes before renaming.
    def __call__(self, input: Embeddable) -> list[list[float]]:
        """Embed ``input`` by delegating to the wrapped embedder.

        Args:
            input: Items to embed; treated as a list of strings.

        Returns:
            Whatever ``fak.embed_documents`` returns for those texts
            (expected to be one vector per text).
        """
        # ``cast`` only informs the type checker; nothing is converted or
        # validated at runtime.
        texts = cast(list[str], input)
        return self.fak.embed_documents(texts=texts)


@pytest.fixture()
def client() -> Generator[chromadb.ClientAPI, None, None]:
SharedSystemClient.clear_system_cache()
Expand Down Expand Up @@ -254,8 +267,8 @@ def test_chroma_update_document() -> None:
# Assert that the updated document is returned by the search
assert output == [Document(page_content=updated_content, metadata={"page": "0"})]

assert new_embedding == embedding.embed_documents([updated_content])[0]
assert new_embedding != old_embedding
assert list(new_embedding) == list(embedding.embed_documents([updated_content])[0])
assert list(new_embedding) != list(old_embedding)


# TODO: RELEVANCE SCORE IS BROKEN. FIX TEST
Expand Down Expand Up @@ -341,17 +354,17 @@ def batch_support_chroma_version() -> bool:
)
def test_chroma_large_batch() -> None:
client = chromadb.HttpClient()
embedding_function = Fak(size=255)
embedding_function = MyEmbeddingFunction(fak=Fak(size=255))
col = client.get_or_create_collection(
"my_collection",
embedding_function=embedding_function.embed_documents, # type: ignore
embedding_function=embedding_function, # type: ignore
)
docs = ["This is a test document"] * (client.max_batch_size + 100) # type: ignore
docs = ["This is a test document"] * (client.get_max_batch_size() + 100) # type: ignore
db = Chroma.from_texts(
client=client,
collection_name=col.name,
texts=docs,
embedding=embedding_function,
embedding=embedding_function.fak,
ids=[str(uuid.uuid4()) for _ in range(len(docs))],
)

Expand All @@ -369,18 +382,18 @@ def test_chroma_large_batch() -> None:
)
def test_chroma_large_batch_update() -> None:
client = chromadb.HttpClient()
embedding_function = Fak(size=255)
embedding_function = MyEmbeddingFunction(fak=Fak(size=255))
col = client.get_or_create_collection(
"my_collection",
embedding_function=embedding_function.embed_documents, # type: ignore
embedding_function=embedding_function, # type: ignore
)
docs = ["This is a test document"] * (client.max_batch_size + 100) # type: ignore
docs = ["This is a test document"] * (client.get_max_batch_size() + 100) # type: ignore
ids = [str(uuid.uuid4()) for _ in range(len(docs))]
db = Chroma.from_texts(
client=client,
collection_name=col.name,
texts=docs,
embedding=embedding_function,
embedding=embedding_function.fak,
ids=ids,
)
new_docs = [
Expand Down Expand Up @@ -408,7 +421,7 @@ def test_chroma_legacy_batching() -> None:
embedding_function = Fak(size=255)
col = client.get_or_create_collection(
"my_collection",
embedding_function=embedding_function.embed_documents, # type: ignore
embedding_function=MyEmbeddingFunction, # type: ignore
)
docs = ["This is a test document"] * 100
db = Chroma.from_texts(
Expand Down

0 comments on commit 923ef85

Please sign in to comment.