Skip to content

Commit

Permalink
[BUG] Test that query result shapes are correct in invariants (#2807)
Browse files Browse the repository at this point in the history
## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
- Discovered an issue with `PersistentLocalHnswSegment`
where `query_vectors` can return results whose lengths differ from
`n_results`. This PR implements a test that both demonstrates and catches this
issue.
- Fixes an off-by-one issue in the hnsw/BF merge logic in both single-node and
distributed

## Test plan
*How are these changes tested?*
- Added a test for the breaking case that fails on main; with the fix, the test
passes
- [x] Tests pass locally with `pytest` for python, `yarn test` for js,
`cargo test` for rust

---------

Co-authored-by: hammadb <[email protected]>
  • Loading branch information
drewkim and HammadB authored Sep 17, 2024
1 parent 9ab0196 commit d17472e
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 5 deletions.
4 changes: 3 additions & 1 deletion chromadb/segment/impl/vector/local_persistent_hnsw.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,8 @@ def query_vectors(
# Overquery by updated and deleted elements layered on the index because they may
# hide the real nearest neighbors in the hnsw index
hnsw_k = k + self._curr_batch.update_count + self._curr_batch.delete_count
# self._id_to_label contains the ids of the elements in the hnsw index
# so its length is the number of elements in the hnsw index
if hnsw_k > len(self._id_to_label):
hnsw_k = len(self._id_to_label)
hnsw_query = VectorQuery(
Expand Down Expand Up @@ -472,7 +474,7 @@ def query_vectors(
if remaining > 0 and hnsw_pointer < len(curr_hnsw_result):
for i in range(
hnsw_pointer,
min(len(curr_hnsw_result), hnsw_pointer + remaining + 1),
min(len(curr_hnsw_result), hnsw_pointer + remaining),
):
id = curr_hnsw_result[i]["id"]
if not self._brute_force_index.has_id(id):
Expand Down
18 changes: 15 additions & 3 deletions chromadb/test/property/invariants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
from chromadb.db.base import get_sql
from chromadb.db.impl.sqlite import SqliteDB
from time import sleep

import psutil

from chromadb.test.property.strategies import NormalizedRecordSet, RecordSet
from typing import Callable, Optional, Tuple, Union, List, TypeVar, cast
from typing_extensions import Literal
Expand Down Expand Up @@ -261,10 +259,14 @@ def ann_accuracy(
include=["embeddings", "documents", "metadatas", "distances"], # type: ignore[list-item]
)

_query_results_are_correct_shape(query_results, n_results)

# Assert fields are not None for type checking
assert query_results["ids"] is not None
assert query_results["distances"] is not None
assert query_results["embeddings"] is not None
assert query_results["documents"] is not None
assert query_results["metadatas"] is not None
assert query_results["embeddings"] is not None

# Dict of ids to indices
id_to_index = {id: i for i, id in enumerate(normalized_record_set["ids"])}
Expand Down Expand Up @@ -324,6 +326,16 @@ def ann_accuracy(
assert np.allclose(np.sort(distance_result), distance_result)


def _query_results_are_correct_shape(
query_results: types.QueryResult, n_results: int
) -> None:
for result_type in ["distances", "embeddings", "documents", "metadatas"]:
assert query_results[result_type] is not None # type: ignore[literal-required]
assert all(
len(result) == n_results for result in query_results[result_type] # type: ignore[literal-required]
)


def _total_embedding_queue_log_size(sqlite: SqliteDB) -> int:
t = Table("embeddings_queue")
q = sqlite.querybuilder().from_(t)
Expand Down
50 changes: 50 additions & 0 deletions chromadb/test/property/test_persist.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import tempfile
from chromadb.api.client import Client as ClientCreator
from chromadb.utils.embedding_functions import DefaultEmbeddingFunction
import numpy as np

CreatePersistAPI = Callable[[], ServerAPI]

Expand Down Expand Up @@ -308,6 +309,55 @@ def test_persist_embeddings_state(
) # type: ignore


def test_delete_less_than_k(
    caplog: pytest.LogCaptureFixture, settings: Settings
) -> None:
    """Replay a fixed add/add/delete sequence against the persist state machine.

    The collection uses ``hnsw:sync_threshold`` and ``hnsw:batch_size`` of 3.
    One embedding is added, then two more, then the last one is deleted; the
    state machine's invariants are run after initialization and after each
    add, and ``ann_accuracy`` is run once more after the delete.
    """
    client = chromadb.Client(settings)
    state = PersistEmbeddingsStateMachine(settings=settings, client=client)

    def run_invariants() -> None:
        # The full invariant suite, in the order it is exercised after each step.
        state.ann_accuracy()
        state.count()
        state.fields_match()
        state.log_size_below_max()
        state.no_duplicates()

    collection = strategies.Collection(
        name="A00",
        metadata={
            "hnsw:construction_ef": 128,
            "hnsw:search_ef": 128,
            "hnsw:M": 128,
            "hnsw:sync_threshold": 3,
            "hnsw:batch_size": 3,
        },
        embedding_function=None,
        id=UUID("2d3eddc7-2314-45f4-a951-47a9a8e099d2"),
        dimension=2,
        dtype=np.float16,
        known_metadata_keys={},
        known_document_keywords=[],
        has_documents=False,
        has_embeddings=True,
    )
    state.initialize(collection=collection)
    run_invariants()

    # Step 1: a single record.
    first_batch = {
        "ids": ["0"],
        "embeddings": [[0.09765625, 0.430419921875]],
        "metadatas": [None],
        "documents": None,
    }
    (added_id_0,) = state.add_embeddings(record_set=first_batch)  # type: ignore
    # recall: 1.0, missing 0 out of 1, accuracy threshold 1e-06
    run_invariants()

    # Step 2: two more records, bringing the total to three.
    second_batch = {
        "ids": ["1", "2"],
        "embeddings": [
            [0.20556640625, 0.08978271484375],
            [-0.1527099609375, 0.291748046875],
        ],
        "metadatas": [None, None],
        "documents": None,
    }
    added_id_1, added_id_2 = state.add_embeddings(record_set=second_batch)  # type: ignore
    # recall: 1.0, missing 0 out of 3, accuracy threshold 1e-06
    run_invariants()

    # Step 3: delete the most recently added record, then re-check query results.
    state.delete_by_ids(ids=[added_id_2])
    state.ann_accuracy()
    state.teardown()


# Ideally this scenario would be exercised by Hypothesis, but most runs don't seem to trigger this particular state.
def test_delete_add_after_persist(settings: Settings) -> None:
client = chromadb.Client(settings)
Expand Down
2 changes: 1 addition & 1 deletion rust/worker/src/execution/operators/merge_knn_results.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ fn merge_results(
let mut brute_force_index = 0;

// TODO: This doesn't have to clone the user IDs, but it's easier for now
while (result_user_ids.len() <= k)
while (result_user_ids.len() < k)
&& (hnsw_index < hnsw_result_user_ids.len()
|| brute_force_index < brute_force_result_user_ids.len())
{
Expand Down

0 comments on commit d17472e

Please sign in to comment.