diff --git a/so_vector/operations/default.json b/so_vector/operations/default.json index 93b7591bf..a33f0fb29 100644 --- a/so_vector/operations/default.json +++ b/so_vector/operations/default.json @@ -13,6 +13,7 @@ { "name": "knn-recall-10-50-acceptedAnswerId", "operation-type": "knn-recall", + "request-timeout": 600, "param-source": "knn-recall-param-source", "k": 10, "num_candidates": 50, @@ -40,6 +41,7 @@ { "name": "knn-recall-10-50-java", "operation-type": "knn-recall", + "request-timeout": 600, "param-source": "knn-recall-param-source", "k": 10, "num_candidates": 50, @@ -69,6 +71,7 @@ { "name": "knn-recall-10-50-css", "operation-type": "knn-recall", + "request-timeout": 600, "param-source": "knn-recall-param-source", "k": 10, "num_candidates": 50, @@ -98,6 +101,7 @@ { "name": "knn-recall-10-50-concurrency", "operation-type": "knn-recall", + "request-timeout": 600, "param-source": "knn-recall-param-source", "k": 10, "num_candidates": 50, @@ -127,6 +131,7 @@ { "name": "knn-recall-10-50-random-10-percent", "operation-type": "knn-recall", + "request-timeout": 600, "param-source": "knn-recall-param-source", "k": 10, "num_candidates": 50, @@ -156,6 +161,7 @@ { "name": "knn-recall-100-300-random-10-percent", "operation-type": "knn-recall", + "request-timeout": 600, "param-source": "knn-recall-param-source", "k": 100, "num_candidates": 300, @@ -185,6 +191,7 @@ { "name": "knn-recall-10-50-random-20-percent", "operation-type": "knn-recall", + "request-timeout": 600, "param-source": "knn-recall-param-source", "k": 10, "num_candidates": 50, @@ -214,6 +221,7 @@ { "name": "knn-recall-100-300-random-20-percent", "operation-type": "knn-recall", + "request-timeout": 600, "param-source": "knn-recall-param-source", "k": 100, "num_candidates": 300, @@ -234,6 +242,7 @@ { "name": "knn-recall-default-match-all", "operation-type": "knn-recall", + "request-timeout": 600, "param-source": "knn-recall-param-source", "include-in-reporting": false }, @@ -253,6 +262,7 @@ { "name": "knn-recall-default-random-20-percent", "operation-type": "knn-recall", + "request-timeout": 600, "param-source": "knn-recall-param-source", "include-in-reporting": false, "filter": { @@ -273,6 +283,7 @@ { "name": "knn-recall-10-50-match-all", "operation-type": "knn-recall", + "request-timeout": 600, "param-source": "knn-recall-param-source", "k": 10, "num_candidates": 50, @@ -288,6 +299,7 @@ { "name": "knn-recall-100-300-match-all", "operation-type": "knn-recall", + "request-timeout": 600, "param-source": "knn-recall-param-source", "k": 100, "num_candidates": 300, diff --git a/so_vector/track.py b/so_vector/track.py index d95156aa7..aa6e13927 100644 --- a/so_vector/track.py +++ b/so_vector/track.py @@ -2,7 +2,7 @@ import json import logging import os -from typing import Any, List +from typing import Any, List, Optional logger = logging.getLogger(__name__) QUERIES_FILENAME: str = "queries.json.bz2" @@ -11,7 +11,9 @@ TRUE_KNN_FILENAME_1K: str = "queries-recall-1k.json.bz2" -async def extract_exact_neighbors(query_vector: List[float], index: str, max_size: int, vector_field: str, filter, client) -> List[str]: +async def extract_exact_neighbors( + query_vector: List[float], index: str, max_size: int, vector_field: str, request_timeout: Optional[float], filter, client +) -> List[str]: if filter is None: raise ValueError("Filter must be provided for exact neighbors extraction.") script_query = { @@ -27,11 +29,13 @@ async def extract_exact_neighbors(query_vector: List[float], index: str, max_siz "_source": False, "docvalue_fields": ["questionId"], } + es_kwargs = {"request_timeout": request_timeout} if request_timeout else {} script_result = await client.search( body=script_query, index=index, request_cache=True, size=max_size, + **es_kwargs, ) return [hit["fields"]["questionId"][0] for hit in script_result["hits"]["hits"]] @@ -131,7 +135,9 @@ def __init__(self): def get_query_vectors(self) -> List[List[float]]: return self._queries - async def get_neighbors_for_query(self, index: str, query_id: int, size: int, filter, client) -> List[str]: + async def get_neighbors_for_query( + self, index: str, query_id: int, size: int, request_timeout: Optional[float], filter, client + ) -> List[str]: # For now, we must calculate the exact neighbors, maybe we should cache this? # it would have to be cached per query and filter if filter is not None: @@ -141,6 +147,7 @@ async def get_neighbors_for_query(self, index: str, query_id: int, size: int, fi index=index, max_size=size, vector_field="titleVector", + request_timeout=request_timeout, filter=filter, client=client, ) @@ -200,6 +207,7 @@ async def __call__(self, es, params): k = params["size"] num_candidates = params["num_candidates"] index = params["index"] + request_timeout = params.get("request-timeout", None) request_cache = params["cache"] filter = params["filter"] recall_total = 0 @@ -210,14 +218,16 @@ async def __call__(self, es, params): knn_vector_store: KnnVectorStore = params["knn_vector_store"] for query_id, query_vector in enumerate(knn_vector_store.get_query_vectors()): knn_body = self.get_knn_query(query_vector, k, num_candidates, filter, params["oversample"]) + es_kwargs = {"request_timeout": request_timeout} if request_timeout else {} knn_result = await es.search( body=knn_body, index=index, request_cache=request_cache, size=k, + **es_kwargs, ) knn_hits = [hit["fields"]["questionId"][0] for hit in knn_result["hits"]["hits"]] - true_neighbors = await knn_vector_store.get_neighbors_for_query(index, query_id, k, filter, es) + true_neighbors = await knn_vector_store.get_neighbors_for_query(index, query_id, k, request_timeout, filter, es) current_recall = len(set(knn_hits).intersection(set(true_neighbors))) recall_total += current_recall exact_total += len(true_neighbors)