Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 9 additions & 12 deletions so_vector/track.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@
TRUE_KNN_FILENAME_1K: str = "queries-recall-1k.json.bz2"


async def extract_exact_neighbors(
query_vector: List[float], index: str, max_size: int, vector_field: str, request_timeout: Optional[float], filter, client
) -> List[str]:
async def extract_exact_neighbors(query_vector: List[float], index: str, max_size: int, vector_field: str, filter, client) -> List[str]:
if filter is None:
raise ValueError("Filter must be provided for exact neighbors extraction.")
script_query = {
Expand All @@ -29,13 +27,11 @@ async def extract_exact_neighbors(
"_source": False,
"docvalue_fields": ["questionId"],
}
es_kwargs = {"request_timeout": request_timeout} if request_timeout else {}
script_result = await client.search(
body=script_query,
index=index,
request_cache=True,
size=max_size,
**es_kwargs,
)
return [hit["fields"]["questionId"][0] for hit in script_result["hits"]["hits"]]

Expand Down Expand Up @@ -135,9 +131,7 @@ def __init__(self):
def get_query_vectors(self) -> List[List[float]]:
return self._queries

async def get_neighbors_for_query(
self, index: str, query_id: int, size: int, request_timeout: Optional[float], filter, client
) -> List[str]:
async def get_neighbors_for_query(self, index: str, query_id: int, size: int, filter, client) -> List[str]:
# For now, we must calculate the exact neighbors, maybe we should cache this?
# it would have to be cached per query and filter
if filter is not None:
Expand All @@ -147,7 +141,6 @@ async def get_neighbors_for_query(
index=index,
max_size=size,
vector_field="titleVector",
request_timeout=request_timeout,
filter=filter,
client=client,
)
Expand Down Expand Up @@ -176,6 +169,8 @@ def partition(self, partition_index, total_partitions):
return self

def params(self):
request_timeout = self._params.get("request-timeout", None)
optional_params = {"request-timeout": request_timeout} if request_timeout else {}
return {
"index": self._index_name,
"cache": self._params.get("cache", False),
Expand All @@ -184,6 +179,7 @@ def params(self):
"oversample": self._params.get("oversample", -1),
"knn_vector_store": KnnVectorStore(),
"filter": self._params.get("filter", None),
**optional_params,
}


Expand Down Expand Up @@ -215,19 +211,20 @@ async def __call__(self, es, params):
min_recall = k
max_recall = 0

if request_timeout:
es = es.options(request_timeout=request_timeout)

knn_vector_store: KnnVectorStore = params["knn_vector_store"]
for query_id, query_vector in enumerate(knn_vector_store.get_query_vectors()):
knn_body = self.get_knn_query(query_vector, k, num_candidates, filter, params["oversample"])
es_kwargs = {"request_timeout": request_timeout} if request_timeout else {}
knn_result = await es.search(
body=knn_body,
index=index,
request_cache=request_cache,
size=k,
**es_kwargs,
)
knn_hits = [hit["fields"]["questionId"][0] for hit in knn_result["hits"]["hits"]]
true_neighbors = await knn_vector_store.get_neighbors_for_query(index, query_id, k, request_timeout, filter, es)
true_neighbors = await knn_vector_store.get_neighbors_for_query(index, query_id, k, filter, es)
current_recall = len(set(knn_hits).intersection(set(true_neighbors)))
recall_total += current_recall
exact_total += len(true_neighbors)
Expand Down
Loading