Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Add max_distance filter to query method #3268

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions chromadb/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,8 @@ def _get(
page: Optional[int] = None,
page_size: Optional[int] = None,
where_document: Optional[WhereDocument] = None,
include: Include = IncludeMetadataDocumentsEmbeddings,
include: Include = IncludeMetadataDocumentsEmbeddings, # type: ignore
max_distance: Optional[float] = None,
) -> GetResult:
"""[Internal] Returns entries from a collection specified by UUID.

Expand Down Expand Up @@ -277,7 +278,8 @@ def _query(
n_results: int = 10,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
include: Include = IncludeMetadataDocumentsEmbeddingsDistances,
include: Include = IncludeMetadataDocumentsEmbeddingsDistances, # type: ignore
max_distance: Optional[float] = None,
) -> QueryResult:
"""[Internal] Performs a nearest neighbors query on a collection specified by UUID.

Expand Down Expand Up @@ -648,6 +650,7 @@ def _get(
page_size: Optional[int] = None,
where_document: Optional[WhereDocument] = None,
include: Include = ["metadatas", "documents"], # type: ignore[list-item]
max_distance: Optional[float] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> GetResult:
Expand Down Expand Up @@ -708,6 +711,7 @@ def _query(
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
include: Include = ["metadatas", "documents", "distances"], # type: ignore[list-item]
max_distance: Optional[float] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> QueryResult:
Expand Down
1 change: 1 addition & 0 deletions chromadb/api/async_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,7 @@ async def _query(
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
include: Include = ["metadatas", "documents", "distances"], # type: ignore[list-item]
max_distance: Optional[float] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> QueryResult:
Expand Down
2 changes: 2 additions & 0 deletions chromadb/api/async_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,7 @@ async def _query(
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
include: Include = ["metadatas", "documents", "distances"], # type: ignore[list-item]
max_distance: Optional[float] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> QueryResult:
Expand All @@ -557,6 +558,7 @@ async def _query(
"where": where,
"where_document": where_document,
"include": include,
"max_distance": max_distance,
},
)

Expand Down
4 changes: 4 additions & 0 deletions chromadb/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@ def _get(
page_size: Optional[int] = None,
where_document: Optional[WhereDocument] = None,
include: Include = ["embeddings", "metadatas", "documents"], # type: ignore[list-item]
max_distance: Optional[float] = None,
) -> GetResult:
return self._server._get(
collection_id=collection_id,
Expand All @@ -342,6 +343,7 @@ def _get(
page_size=page_size,
where_document=where_document,
include=include,
max_distance=max_distance,
)

def _delete(
Expand Down Expand Up @@ -369,6 +371,7 @@ def _query(
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
include: Include = ["embeddings", "metadatas", "documents", "distances"], # type: ignore[list-item]
max_distance: Optional[float] = None,
) -> QueryResult:
return self._server._query(
collection_id=collection_id,
Expand All @@ -379,6 +382,7 @@ def _query(
where=where,
where_document=where_document,
include=include,
max_distance=max_distance,
)

@override
Expand Down
4 changes: 4 additions & 0 deletions chromadb/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ def _get(
page_size: Optional[int] = None,
where_document: Optional[WhereDocument] = None,
include: Include = ["metadatas", "documents"], # type: ignore[list-item]
max_distance: Optional[float] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> GetResult:
Expand All @@ -344,6 +345,7 @@ def _get(
"offset": offset,
"where_document": where_document,
"include": include,
"max_distance": max_distance,
},
)

Expand Down Expand Up @@ -512,6 +514,7 @@ def _query(
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
include: Include = ["metadatas", "documents", "distances"], # type: ignore[list-item]
max_distance: Optional[float] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> QueryResult:
Expand All @@ -527,6 +530,7 @@ def _query(
"where": where,
"where_document": where_document,
"include": include,
"max_distance": max_distance,
},
)

Expand Down
3 changes: 3 additions & 0 deletions chromadb/api/models/AsyncCollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ async def query(
IncludeEnum.documents,
IncludeEnum.distances,
],
max_distance: Optional[float] = None,
) -> QueryResult:
"""Get the n_results nearest neighbor embeddings for provided query_embeddings or query_texts.

Expand Down Expand Up @@ -210,6 +211,7 @@ async def query(
where=where,
where_document=where_document,
include=include,
max_distance=max_distance,
)

query_results = await self._client._query(
Expand All @@ -219,6 +221,7 @@ async def query(
where=query_request["where"],
where_document=query_request["where_document"],
include=query_request["include"],
max_distance=query_request["max_distance"],
tenant=self.tenant,
database=self.database,
)
Expand Down
5 changes: 4 additions & 1 deletion chromadb/api/models/Collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
ID,
OneOrMany,
WhereDocument,
IncludeEnum,
)

import logging
Expand Down Expand Up @@ -183,6 +182,7 @@ def query(
IncludeEnum.documents,
IncludeEnum.distances,
],
max_distance: Optional[float] = None,
) -> QueryResult:
"""Get the n_results nearest neighbor embeddings for provided query_embeddings or query_texts.

Expand All @@ -195,6 +195,7 @@ def query(
where: A Where type dict used to filter results by. E.g. `{"$and": [{"color" : "red"}, {"price": {"$gte": 4.20}}]}`. Optional.
where_document: A WhereDocument type dict used to filter by the documents. E.g. `{$contains: {"text": "hello"}}`. Optional.
include: A list of what to include in the results. Can contain `"embeddings"`, `"metadatas"`, `"documents"`, `"distances"`. Ids are always included. Defaults to `["metadatas", "documents", "distances"]`. Optional.
max_distance: A float to filter results by distance ≤ this value. Applies to the collection's distance metric. If `None`, no filtering is applied.

Returns:
QueryResult: A QueryResult object containing the results.
Expand All @@ -216,6 +217,7 @@ def query(
where=where,
where_document=where_document,
include=include,
max_distance=max_distance,
)

query_results = self._client._query(
Expand All @@ -225,6 +227,7 @@ def query(
where=query_request["where"],
where_document=query_request["where_document"],
include=query_request["include"],
max_distance=query_request["max_distance"],
tenant=self.tenant,
database=self.database,
)
Expand Down
6 changes: 6 additions & 0 deletions chromadb/api/models/CollectionCommon.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
validate_record_set_contains_any,
validate_record_set_for_embedding,
validate_filter_set,
validate_max_distance,
)

# TODO: We should rename the types in chromadb.types to be Models where
Expand Down Expand Up @@ -229,6 +230,7 @@ def _validate_and_prepare_get_request(
where: Optional[Where],
where_document: Optional[WhereDocument],
include: Include,
max_distance: Optional[float],
) -> GetRequest:
# Unpack
unpacked_ids: Optional[IDs] = maybe_cast_one_to_many(target=ids)
Expand Down Expand Up @@ -257,6 +259,7 @@ def _validate_and_prepare_get_request(
where=filters["where"],
where_document=filters["where_document"],
include=request_include,
max_distance=max_distance,
)

@validation_context("query")
Expand All @@ -275,6 +278,7 @@ def _validate_and_prepare_query_request(
where: Optional[Where],
where_document: Optional[WhereDocument],
include: Include,
max_distance: Optional[float],
) -> QueryRequest:
# Unpack
query_records = normalize_base_record_set(
Expand All @@ -294,6 +298,7 @@ def _validate_and_prepare_query_request(
validate_filter_set(filter_set=filters)
validate_include(include=include)
validate_n_results(n_results=n_results)
validate_max_distance(max_distance=max_distance)

# Prepare
if query_records["embeddings"] is None:
Expand All @@ -315,6 +320,7 @@ def _validate_and_prepare_query_request(
where=request_where,
where_document=request_where_document,
include=request_include,
max_distance=max_distance,
n_results=n_results,
)

Expand Down
6 changes: 5 additions & 1 deletion chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def create_collection(
id=model.id,
name=model.name,
configuration=model.get_configuration(),
segments=[], # Passing empty till backend changes are deployed.
segments=[], # Passing empty till backend changes are deployed.
metadata=model.metadata,
dimension=None, # This is lazily populated on the first add
get_or_create=get_or_create,
Expand Down Expand Up @@ -570,6 +570,7 @@ def _get(
page_size: Optional[int] = None,
where_document: Optional[WhereDocument] = None,
include: Include = ["embeddings", "metadatas", "documents"], # type: ignore[list-item]
max_distance: Optional[float] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> GetResult:
Expand Down Expand Up @@ -629,6 +630,7 @@ def _get(
False,
IncludeEnum.uris in include,
),
max_distance,
)
)

Expand Down Expand Up @@ -753,6 +755,7 @@ def _query(
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
include: Include = ["documents", "metadatas", "distances"], # type: ignore[list-item]
max_distance: Optional[float] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> QueryResult:
Expand Down Expand Up @@ -810,6 +813,7 @@ def _query(
IncludeEnum.distances in include,
IncludeEnum.uris in include,
),
max_distance=max_distance,
)
)

Expand Down
37 changes: 35 additions & 2 deletions chromadb/api/types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,15 @@
from typing import Optional, Set, Union, TypeVar, List, Dict, Any, Tuple, cast
from typing import (
Optional,
Set,
Union,
TypeVar,
List,
Dict,
Any,
Tuple,
cast,
TypeAlias,
)
from numpy.typing import NDArray
import numpy as np
from typing_extensions import TypedDict, Protocol, runtime_checkable
Expand Down Expand Up @@ -50,6 +61,17 @@ def maybe_cast_one_to_many(target: Optional[OneOrMany[T]]) -> Optional[List[T]]:
Embedding = Vector
Embeddings = List[Embedding]

EmbeddingsType: TypeAlias = Optional[
Union[
List[Embeddings],
List[PyEmbeddings],
List[NDArray[Union[np.int32, np.float32]]],
]
]
EmbeddingType: TypeAlias = Optional[
Union[Embeddings, PyEmbeddings, NDArray[Union[np.int32, np.float32]]]
]


def normalize_embeddings(
target: Optional[Union[OneOrMany[Embedding], OneOrMany[PyEmbedding]]]
Expand Down Expand Up @@ -371,6 +393,7 @@ class GetRequest(TypedDict):
where: Optional[Where]
where_document: Optional[WhereDocument]
include: Include
max_distance: Optional[float]


class GetResult(TypedDict):
Expand All @@ -391,6 +414,7 @@ class QueryRequest(TypedDict):
where_document: Optional[WhereDocument]
include: Include
n_results: int
max_distance: Optional[float]


class QueryResult(TypedDict):
Expand Down Expand Up @@ -750,6 +774,15 @@ def validate_n_results(n_results: int) -> int:
return n_results


def validate_max_distance(max_distance: Optional[float] = None) -> None:
"""Validates max_distance to ensure it is a float"""
if max_distance is not None:
if not isinstance(max_distance, float):
raise ValueError(f"Expected max_distance to be a float, got {max_distance}")
if max_distance < 0:
raise ValueError(f"Max distance must be non-negative, got {max_distance}")


def validate_embeddings(embeddings: Embeddings) -> Embeddings:
"""Validates embeddings to ensure it is a list of numpy arrays of ints, or floats"""
if not isinstance(embeddings, (list, np.ndarray)):
Expand Down Expand Up @@ -836,7 +869,7 @@ def validate_batch(


def convert_np_embeddings_to_list(embeddings: Embeddings) -> PyEmbeddings:
return [embedding.tolist() for embedding in embeddings]
return cast(PyEmbeddings, [embedding.tolist() for embedding in embeddings])


def convert_list_embeddings_to_np(embeddings: PyEmbeddings) -> Embeddings:
Expand Down
6 changes: 6 additions & 0 deletions chromadb/execution/executor/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,12 @@ def knn(self, plan: KNNPlan) -> QueryResult:
)
knns = self._vector_segment(plan.scan.collection).query_vectors(query)

if plan.max_distance is not None:
knns = [
[r for r in result if r["distance"] <= plan.max_distance]
for result in knns
]

ids = [[r["id"] for r in result] for result in knns]
embeddings = None
documents = None
Expand Down
4 changes: 4 additions & 0 deletions chromadb/execution/expression/plan.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from dataclasses import dataclass, field

from typing import Optional

from chromadb.execution.expression.operator import KNN, Filter, Limit, Projection, Scan


Expand All @@ -14,6 +16,7 @@ class GetPlan:
filter: Filter = field(default_factory=Filter)
limit: Limit = field(default_factory=Limit)
projection: Projection = field(default_factory=Projection)
max_distance: Optional[float] = None


@dataclass
Expand All @@ -22,3 +25,4 @@ class KNNPlan:
knn: KNN
filter: Filter = field(default_factory=Filter)
projection: Projection = field(default_factory=Projection)
max_distance: Optional[float] = None
Loading
Loading