diff --git a/so_vector/track.py b/so_vector/track.py index aa6e13927..38d7e7d04 100644 --- a/so_vector/track.py +++ b/so_vector/track.py @@ -11,9 +11,7 @@ TRUE_KNN_FILENAME_1K: str = "queries-recall-1k.json.bz2" -async def extract_exact_neighbors( - query_vector: List[float], index: str, max_size: int, vector_field: str, request_timeout: Optional[float], filter, client -) -> List[str]: +async def extract_exact_neighbors(query_vector: List[float], index: str, max_size: int, vector_field: str, filter, client) -> List[str]: if filter is None: raise ValueError("Filter must be provided for exact neighbors extraction.") script_query = { @@ -29,13 +27,11 @@ async def extract_exact_neighbors( "_source": False, "docvalue_fields": ["questionId"], } - es_kwargs = {"request_timeout": request_timeout} if request_timeout else {} script_result = await client.search( body=script_query, index=index, request_cache=True, size=max_size, - **es_kwargs, ) return [hit["fields"]["questionId"][0] for hit in script_result["hits"]["hits"]] @@ -135,9 +131,7 @@ def __init__(self): def get_query_vectors(self) -> List[List[float]]: return self._queries - async def get_neighbors_for_query( - self, index: str, query_id: int, size: int, request_timeout: Optional[float], filter, client - ) -> List[str]: + async def get_neighbors_for_query(self, index: str, query_id: int, size: int, filter, client) -> List[str]: # For now, we must calculate the exact neighbors, maybe we should cache this? # it would have to be cached per query and filter if filter is not None: @@ -147,7 +141,6 @@ async def get_neighbors_for_query( index=index, max_size=size, vector_field="titleVector", - request_timeout=request_timeout, filter=filter, client=client, ) @@ -176,6 +169,8 @@ def partition(self, partition_index, total_partitions): return self def params(self): + request_timeout = self._params.get("request-timeout", None) + optional_params = {"request-timeout": request_timeout} if request_timeout else {} return { "index": self._index_name, "cache": self._params.get("cache", False), @@ -184,6 +179,7 @@ def params(self): "oversample": self._params.get("oversample", -1), "knn_vector_store": KnnVectorStore(), "filter": self._params.get("filter", None), + **optional_params, } @@ -215,19 +211,20 @@ async def __call__(self, es, params): min_recall = k max_recall = 0 + if request_timeout: + es = es.options(request_timeout=request_timeout) + knn_vector_store: KnnVectorStore = params["knn_vector_store"] for query_id, query_vector in enumerate(knn_vector_store.get_query_vectors()): knn_body = self.get_knn_query(query_vector, k, num_candidates, filter, params["oversample"]) - es_kwargs = {"request_timeout": request_timeout} if request_timeout else {} knn_result = await es.search( body=knn_body, index=index, request_cache=request_cache, size=k, - **es_kwargs, ) knn_hits = [hit["fields"]["questionId"][0] for hit in knn_result["hits"]["hits"]] - true_neighbors = await knn_vector_store.get_neighbors_for_query(index, query_id, k, request_timeout, filter, es) + true_neighbors = await knn_vector_store.get_neighbors_for_query(index, query_id, k, filter, es) current_recall = len(set(knn_hits).intersection(set(true_neighbors))) recall_total += current_recall exact_total += len(true_neighbors)