Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 48 additions & 16 deletions benchs/bench_fw/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,8 @@ def build_index_wrapper(self, knn_desc: KnnDescriptor):
if hasattr(knn_desc, "index"):
return

if knn_desc.index_desc.index is not None:
assert knn_desc.index_desc is not None
if hasattr(knn_desc.index_desc, "index"):
knn_desc.index = knn_desc.index_desc.index
knn_desc.index.knn_name = knn_desc.get_name()
knn_desc.index.search_params = knn_desc.search_params
Expand All @@ -359,6 +360,7 @@ def build_index_wrapper(self, knn_desc: KnnDescriptor):
metric=self.distance_metric,
bucket=knn_desc.index_desc.bucket,
index_path=knn_desc.index_desc.path,
index_name=knn_desc.index_desc.get_name(),
# knn_name=knn_desc.get_name(),
search_params=knn_desc.search_params,
)
Expand Down Expand Up @@ -544,15 +546,40 @@ def experiment(parameters, cost_metric, perf_metric):
def knn_search_benchmark(
self, dry_run, results: Dict[str, Any], knn_desc: KnnDescriptor
):
gt_knn_D = None
gt_knn_I = None
if hasattr(self, "gt_knn_D"):
gt_knn_D = self.gt_knn_D
gt_knn_I = self.gt_knn_I

assert hasattr(knn_desc, "index")
if not knn_desc.index.is_flat_index() and gt_knn_I is None:
key = knn_desc.index.get_knn_search_name(
search_parameters=knn_desc.search_params,
query_vectors=knn_desc.query_dataset,
k=knn_desc.k,
reconstruct=False,
)
metrics, requires = knn_desc.index.knn_search(
dry_run,
knn_desc.search_params,
knn_desc.query_dataset,
knn_desc.k,
)[3:]
if requires is not None:
return results, requires
results["experiments"][key] = metrics
return results, requires

return self.search_benchmark(
name="knn_search",
search_func=lambda parameters: knn_desc.index.knn_search(
dry_run,
parameters,
knn_desc.query_dataset,
knn_desc.k,
self.gt_knn_I,
self.gt_knn_D,
gt_knn_I,
gt_knn_D,
)[3:],
key_func=lambda parameters: knn_desc.index.get_knn_search_name(
search_parameters=parameters,
Expand Down Expand Up @@ -634,6 +661,7 @@ class ExecutionOperator:
train_op: Optional[TrainOperator] = None
build_op: Optional[BuildOperator] = None
search_op: Optional[SearchOperator] = None
compute_gt: bool = True

def __post_init__(self):
if self.distance_metric == "IP":
Expand Down Expand Up @@ -698,16 +726,11 @@ def search_one(
faiss.omp_set_num_threads(self.num_threads)
assert self.search_op is not None

if not dry_run:
if not dry_run and self.compute_gt:
self.create_gt_knn(knn_desc)
self.create_range_ref_knn(knn_desc)

self.search_op.build_index_wrapper(knn_desc)
meta, requires = knn_desc.index.fetch_meta(dry_run=dry_run)
if requires is not None:
# return results, (requires if train else None)
return results, requires
results["indices"][knn_desc.index.get_codec_name()] = meta

# results, requires = self.reconstruct_benchmark(
# dry_run=True,
Expand Down Expand Up @@ -766,9 +789,11 @@ def search_one(
ref_index_desc.search_params,
range_metric,
)
gt_rsm = self.search_op.range_ground_truth(
gt_radius, range_search_metric_function
)
gt_rsm = None
if self.compute_gt:
gt_rsm = self.search_op.range_ground_truth(
gt_radius, range_search_metric_function
)
results, requires = self.search_op.range_search_benchmark(
dry_run=True,
results=results,
Expand Down Expand Up @@ -847,9 +872,13 @@ def create_gt_knn(self, knn_desc, search=True) -> Optional[KnnDescriptor]:
if self.search_op:
gt_knn_desc = self.search_op.get_flat_desc(knn_desc.flat_name())
if gt_knn_desc is None:
gt_index_desc = self.build_op.get_flat_desc(
knn_desc.index_desc.flat_name()
)
if knn_desc.index_desc is not None:
gt_index_desc = knn_desc.gt_index_desc
else:
gt_index_desc = self.build_op.get_flat_desc(
knn_desc.index_desc.flat_name()
)
knn_desc.gt_index_desc = gt_index_desc
assert gt_index_desc is not None
gt_knn_desc = KnnDescriptor(
d=knn_desc.d,
Expand Down Expand Up @@ -933,7 +962,10 @@ def execute(self, results: Dict[str, Any], dry_run: False):
if self.search_op is not None:
for desc in self.search_op.knn_descs:
results, requires = self.search_one(
knn_desc=desc, results=results, dry_run=dry_run, range=self.search_op.range
knn_desc=desc,
results=results,
dry_run=dry_run,
range=self.search_op.range,
)
if dry_run:
if requires is None:
Expand Down
64 changes: 46 additions & 18 deletions benchs/bench_fw/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,10 @@ def knn_search(
D_gt=None,
):
logger.info("knn_search: begin")
if search_parameters is not None and search_parameters["snap"] == 1:
if (
search_parameters is not None and
search_parameters.get("snap", 0) == 1
):
query_vectors = self.snap(query_vectors)
filename = (
self.get_knn_search_name(search_parameters, query_vectors, k)
Expand Down Expand Up @@ -322,7 +325,11 @@ def knn_search(
else:
xq = self.io.get_dataset(query_vectors)
(D, I), t, _ = timer("knn_search", lambda: index.search(xq, k))
if self.is_flat() or not hasattr(self, "database_vectors"): # TODO
if (
self.is_flat() or
not hasattr(self, "database_vectors") or
(self.database_vectors is None)
): # TODO
R = D
else:
xq = self.io.get_dataset(query_vectors)
Expand Down Expand Up @@ -352,20 +359,24 @@ def knn_search(
"factory": self.get_model_name(),
"construction_params": self.get_construction_params(),
"search_params": search_parameters,
"knn_intersection": knn_intersection_measure(
I,
I_gt,
)
if I_gt is not None
else None,
"distance_ratio": distance_ratio_measure(
I,
R,
D_gt,
self.metric_type,
)
if D_gt is not None
else None,
"knn_intersection": (
knn_intersection_measure(
I,
I_gt,
)
if I_gt is not None
else None
),
"distance_ratio": (
distance_ratio_measure(
I,
R,
D_gt,
self.metric_type,
)
if D_gt is not None
else None
),
}
logger.info("knn_search: end")
return D, I, R, P, None
Expand Down Expand Up @@ -467,7 +478,10 @@ def range_search(
radius: Optional[float] = None,
):
logger.info("range_search: begin")
if search_parameters is not None and search_parameters.get("snap") == 1:
if (
search_parameters is not None and
search_parameters.get("snap", 0) == 1
):
query_vectors = self.snap(query_vectors)
filename = (
self.get_range_search_name(
Expand Down Expand Up @@ -607,6 +621,12 @@ def get_codec(self):
Index.cached_codec.popitem(last=False)
return Index.cached_codec[codec_name]

def get_model(self):
    """Return the model for this object.

    Delegates to ``self.get_index()`` and returns its result unchanged.
    """
    return self.get_index()

def get_model_name(self):
    """Return the model name.

    Delegates to ``self.get_index_name()`` and returns its result unchanged.
    """
    return self.get_index_name()

def get_codec_name(self) -> Optional[str]:
    """Return the ``codec_name`` attribute as stored on this object."""
    return self.codec_name

Expand Down Expand Up @@ -709,6 +729,11 @@ def get_operating_points(self):
def add_range_or_val(name, range):
op.add_range(
name,
(
[self.search_params[name]]
if self.search_params and name in self.search_params
else range
),
[self.search_params[name]]
if self.search_params and name in self.search_params
else range,
Expand Down Expand Up @@ -808,7 +833,10 @@ def get_pretransform(self):
return quantizer

def get_model_name(self):
    """Return a name for the model.

    Uses the final component of ``self.path`` when a path is set;
    otherwise falls back to ``self.get_codec_name()``.
    """
    # Check for a missing path first: os.path.basename(None) raises
    # TypeError, so the fallback must not sit behind an unconditional return.
    if self.path is None:
        return self.get_codec_name()
    return os.path.basename(self.path)

def fetch_meta(self, dry_run=False):
    """Return a ``(meta, requires)`` pair; both are ``None`` here.

    ``dry_run`` is accepted but not consulted by this implementation.
    """
    return None, None
Expand Down