Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions so_vector/challenges/default.json
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@
"iterations": 100,
"clients": 1
},
{
"name": "knn-recall-10-50-match-all",
"operation": "knn-recall-10-50-match-all",
"warmup-iterations": 1,
"iterations": 1,
"clients": 1
},
{
"name": "script-score-query-match-all",
"operation": "script-score-query-match-all",
Expand Down Expand Up @@ -132,6 +139,13 @@
"iterations": 100,
"clients": 1
},
{
"name": "knn-recall-10-50-match-all-force-merge",
"operation": "knn-recall-10-50-match-all",
"warmup-iterations": 1,
"iterations": 1,
"clients": 1
},
{
"name": "knn-search-10-50-acceptedAnswerId-force-merge",
"operation": "knn-search-10-50-acceptedAnswerId",
Expand Down
2 changes: 1 addition & 1 deletion so_vector/index.json
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"type": "keyword"
},
"questionId": {
"type": "keyword"
"type": "long"
},
"creationDate": {
"type": "date"
Expand Down
8 changes: 8 additions & 0 deletions so_vector/operations/default.json
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@
"k": 10,
"num_candidates": 50
},
{
"name": "knn-recall-10-50-match-all",
"operation-type": "knn-recall",
"param-source": "knn-recall-param-source",
"k": 10,
"num_candidates": 50,
"include-in-reporting": false
},
{
"name": "script-score-query-css",
"operation-type": "search",
Expand Down
Binary file added so_vector/queries-1k.json.bz2
Binary file not shown.
Binary file added so_vector/queries-recall-1k.json.bz2
Binary file not shown.
Binary file added so_vector/queries-recall.json.bz2
Binary file not shown.
1 change: 0 additions & 1 deletion so_vector/queries.json

This file was deleted.

Binary file added so_vector/queries.json.bz2
Binary file not shown.
152 changes: 146 additions & 6 deletions so_vector/track.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,23 @@
import bz2
import json
import logging
import os
from typing import Any, List

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# bz2-compressed file read line by line; every line is parsed as one JSON
# query vector.
QUERIES_FILENAME: str = "queries.json.bz2"
# bz2-compressed file read line by line; every line is parsed as JSON and
# holds the true nearest-neighbor ids of the query at the same position.
TRUE_KNN_FILENAME: str = "queries-recall.json.bz2"
# NOTE(review): the two "1k" names are not referenced by the visible code;
# presumably smaller variants of the files above -- confirm before relying on them.
QUERIES_FILENAME_1K: str = "queries-1k.json.bz2"
TRUE_KNN_FILENAME_1K: str = "queries-recall-1k.json.bz2"


def compute_percentile(data: List[Any], percentile):
    """Return the element of *data* at the given percentile rank.

    The values are sorted, the rank ``round(percentile * len(data) / 100) - 1``
    is computed and then clamped into the valid index range, so percentiles
    below 0 or above 100 yield the smallest or largest element respectively.

    Returns ``None`` when *data* is empty.
    """
    count = len(data)
    if count <= 0:
        return None

    ordered = sorted(data)
    rank = int(round(percentile * count / 100)) - 1
    # Clamp the rank so it always addresses an existing element.
    last = count - 1
    if rank > last:
        rank = last
    if rank < 0:
        rank = 0
    return ordered[rank]


class KnnParamSource:
Expand All @@ -15,18 +33,30 @@ def __init__(self, track, params, **kwargs):
self._cache = params.get("cache", False)
self._exact_scan = params.get("exact", False)
self._params = params
self._queries = []

cwd = os.path.dirname(__file__)
with open(os.path.join(cwd, "queries.json"), "r") as file:
lines = file.readlines()
self._queries = [json.loads(line) for line in lines]
with bz2.open(os.path.join(cwd, QUERIES_FILENAME), "r") as queries_file:
for vector_query in queries_file:
self._queries.append(json.loads(vector_query))
self.infinite = True
self._iters = 0
self._maxIters = len(self._queries)

def partition(self, partition_index, total_partitions):
return self

def params(self):
result = {"index": self._index_name, "cache": self._params.get("cache", False), "size": self._params.get("k", 10)}
num_candidates = self._params.get("num_candidates", 50)
# if -1, then its unset. If set, just set it.
oversample = self._params.get("oversample", -1)
if oversample > -1 and self._exact_scan:
raise ValueError("Oversampling is not supported for exact scan queries.")
query_vec = self._queries[self._iters]
self._iters += 1
if self._iters >= self._maxIters:
self._iters = 0

if self._exact_scan:
result["body"] = {
Expand All @@ -35,7 +65,7 @@ def params(self):
"query": {"match_all": {}},
"script": {
"source": "dotProduct(params.query, 'titleVector') + 1.0",
"params": {"query": self._queries[0]},
"params": {"query": query_vec},
},
}
},
Expand All @@ -47,16 +77,126 @@ def params(self):
result["body"] = {
"knn": {
"field": "titleVector",
"query_vector": self._queries[0],
"query_vector": query_vec,
"k": self._params.get("k", 10),
"num_candidates": self._params.get("num-candidates", 50),
"num_candidates": self._params.get("num_candidates", 50),
},
"_source": False,
}
if "filter" in self._params:
result["body"]["knn"]["filter"] = self._params["filter"]
if oversample > -1:
result["body"]["knn"]["rescore_vector"] = {"oversample": oversample}

return result


class KnnVectorStore:
    """Holds the query vectors and the true nearest neighbors of each query.

    Both data sets are loaded from bz2-compressed files that live next to
    this module; every line of a file is parsed as one JSON document. Entry
    ``i`` of the neighbor list belongs to query vector ``i``.
    """

    def __init__(self):
        base_dir = os.path.dirname(__file__)
        # Neighbors are loaded before the query vectors.
        self._query_nearest_neighbor_docids = self._read_json_lines(os.path.join(base_dir, TRUE_KNN_FILENAME))
        self._queries = self._read_json_lines(os.path.join(base_dir, QUERIES_FILENAME))

    @staticmethod
    def _read_json_lines(path):
        """Parse every line of the bz2 file at *path* as JSON and return the list."""
        with bz2.open(path, "r") as handle:
            return [json.loads(raw_line) for raw_line in handle]

    def get_query_vectors(self) -> List[List[float]]:
        """Return all loaded query vectors, in file order."""
        return self._queries

    def get_neighbors_for_query(self, query_id: int, size: int) -> List[str]:
        """Return the first *size* true neighbors of the query at *query_id*.

        Raises:
            ValueError: if *query_id* is not a valid position, or if *size* is
                negative or larger than the number of stored neighbors.
        """
        all_neighbors = self._query_nearest_neighbor_docids
        if query_id < 0 or query_id >= len(all_neighbors):
            raise ValueError(f"Unknown query with id: '{query_id}' provided")

        neighbors = all_neighbors[query_id]
        if size < 0 or size > len(neighbors):
            raise ValueError(f"Invalid size: '{size}' provided for query with id: '{query_id}'")
        return neighbors[:size]


class KnnRecallParamSource:
    """Parameter source that feeds the ``knn-recall`` runner.

    Every call to :meth:`params` yields the target index, the search settings
    (``k`` as ``size``, ``num_candidates``, ``oversample``, ``cache``) and a
    ``KnnVectorStore`` holding the query vectors and their true neighbors.

    Args:
        track: track object; when it has exactly one index, that index name
            is the default target, otherwise ``"_all"`` is used.
        params: operation parameters; ``index``, ``cache``, ``k``,
            ``num_candidates`` and ``oversample`` are read from it.
    """

    def __init__(self, track, params, **kwargs):
        if len(track.indices) == 1:
            default_index = track.indices[0].name
        else:
            default_index = "_all"

        self._index_name = params.get("index", default_index)
        self._cache = params.get("cache", False)
        self._params = params
        self.infinite = True
        # Created on the first params() call and then reused: building a
        # KnnVectorStore re-reads and decompresses both data files, which is
        # invariant across calls.
        self._knn_vector_store = None

    def partition(self, partition_index, total_partitions):
        """Return this same instance for every partition."""
        return self

    def params(self):
        """Return the parameter dict consumed by the ``knn-recall`` runner."""
        if self._knn_vector_store is None:
            self._knn_vector_store = KnnVectorStore()
        return {
            "index": self._index_name,
            "cache": self._params.get("cache", False),
            "size": self._params.get("k", 10),
            "num_candidates": self._params.get("num_candidates", 50),
            # -1 means "oversample not set".
            "oversample": self._params.get("oversample", -1),
            "knn_vector_store": self._knn_vector_store,
        }


# Used in tandem with KnnRecallParamSource.
# Reads the queries, executes a kNN search for each one and compares the results with the true nearest neighbors.
class KnnRecallRunner:
    """Runner that measures kNN recall against precomputed true neighbors.

    For every query vector in the supplied vector store it runs a kNN search,
    collects the ``questionId`` doc value of each hit and counts how many of
    them appear among the first ``k`` true neighbors of that query.
    """

    def get_knn_query(self, query_vec, k, num_candidates, oversample):
        """Build the search body for one query vector.

        ``rescore_vector`` is only added when *oversample* is greater than -1
        (-1 means "not set").
        """
        knn = {
            "field": "titleVector",
            "query_vector": query_vec,
            "k": k,
            "num_candidates": num_candidates,
        }
        if oversample > -1:
            knn["rescore_vector"] = {"oversample": oversample}
        return {"knn": knn, "_source": False, "docvalue_fields": ["questionId"]}

    async def __call__(self, es, params):
        """Run all recall queries and return the aggregated recall figures.

        Args:
            es: client exposing an awaitable ``search`` method.
            params: dict with the keys ``size``, ``num_candidates``, ``index``,
                ``cache``, ``oversample`` and ``knn_vector_store``.

        Returns:
            dict with ``avg_recall`` (matched neighbors divided by expected
            neighbors over all queries), ``min_recall`` / ``max_recall``
            (per-query match counts), ``k``, ``num_candidates`` and
            ``oversample``. When nothing could be evaluated (no queries or
            ``k == 0``) ``avg_recall`` is 0.0 and ``min_recall`` is 0.
        """
        k = params["size"]
        num_candidates = params["num_candidates"]
        index = params["index"]
        request_cache = params["cache"]
        oversample = params["oversample"]
        recall_total = 0
        exact_total = 0
        min_recall = k
        max_recall = 0

        knn_vector_store: KnnVectorStore = params["knn_vector_store"]
        for query_id, query_vector in enumerate(knn_vector_store.get_query_vectors()):
            knn_body = self.get_knn_query(query_vector, k, num_candidates, oversample)
            knn_result = await es.search(
                body=knn_body,
                index=index,
                request_cache=request_cache,
                size=k,
            )
            knn_hits = [hit["fields"]["questionId"][0] for hit in knn_result["hits"]["hits"]]
            true_neighbors = knn_vector_store.get_neighbors_for_query(query_id, k)
            current_recall = len(set(knn_hits).intersection(true_neighbors))
            recall_total += current_recall
            exact_total += len(true_neighbors)
            min_recall = min(min_recall, current_recall)
            max_recall = max(max_recall, current_recall)

        if exact_total > 0:
            avg_recall = recall_total / exact_total
        else:
            # Nothing was evaluated: avoid dividing by zero and do not report
            # the untouched initial minimum (k) as if it had been measured.
            avg_recall = 0.0
            min_recall = 0

        to_return = {
            "avg_recall": avg_recall,
            "min_recall": min_recall,
            "max_recall": max_recall,
            "k": k,
            "num_candidates": num_candidates,
            "oversample": oversample,
        }
        logger.info("Recall results: %s", to_return)
        return to_return

    def __repr__(self, *args, **kwargs):
        return "knn-recall"


def register(registry):
    """Register the custom parameter sources and the recall runner on *registry*."""
    param_sources = (
        ("knn-param-source", KnnParamSource),
        ("knn-recall-param-source", KnnRecallParamSource),
    )
    for source_name, source_cls in param_sources:
        registry.register_param_source(source_name, source_cls)
    # The recall runner is a coroutine-based callable, hence async_runner=True.
    registry.register_runner("knn-recall", KnnRecallRunner(), async_runner=True)
Loading