diff --git a/benchs/bench_fw/benchmark.py b/benchs/bench_fw/benchmark.py index 1053f99388..8ca68c4cd8 100644 --- a/benchs/bench_fw/benchmark.py +++ b/benchs/bench_fw/benchmark.py @@ -208,9 +208,11 @@ def set_io(self, benchmark_io): self.io.distance_metric = self.distance_metric self.io.distance_metric_type = self.distance_metric_type - def get_index_desc(self, factory: str) -> Optional[IndexDescriptor]: + def get_index_desc(self, factory_or_codec: str) -> Optional[IndexDescriptor]: for desc in self.index_descs: - if desc.factory == factory: + if desc.factory == factory_or_codec: + return desc + if desc.codec_alias == factory_or_codec: return desc return None @@ -232,7 +234,7 @@ def range_search_reference(self, index, parameters, range_metric): parameters, radius=m_radius, ) - flat = index.factory == "Flat" + flat = index.is_flat_index() ( gt_radius, range_search_metric_function, @@ -650,6 +652,7 @@ def benchmark( f"Range index {index_desc.factory} has no radius_score" ) results["metrics"] = {} + self.build_index_wrapper(index_desc) for metric_key, range_metric in index_desc.range_metrics.items(): ( gt_radius, diff --git a/benchs/bench_fw/descriptors.py b/benchs/bench_fw/descriptors.py index f1dd7354c2..173b07ce16 100644 --- a/benchs/bench_fw/descriptors.py +++ b/benchs/bench_fw/descriptors.py @@ -20,6 +20,7 @@ class IndexDescriptor: # but not both at the same time. path: Optional[str] = None factory: Optional[str] = None + codec_alias: Optional[str] = None construction_params: Optional[List[Dict[str, int]]] = None search_params: Optional[Dict[str, int]] = None # range metric definitions diff --git a/benchs/bench_fw/index.py b/benchs/bench_fw/index.py index 14f2158e64..3deaa4afcf 100644 --- a/benchs/bench_fw/index.py +++ b/benchs/bench_fw/index.py @@ -495,7 +495,7 @@ def range_search( radius: Optional[float] = None, ): logger.info("range_search: begin") - if search_parameters is not None and search_parameters["snap"] == 1: + if search_parameters is not None and search_parameters.get("snap") == 1: query_vectors = self.snap(query_vectors) filename = ( self.get_range_search_name( @@ -776,6 +776,9 @@ def add_range_or_val(name, range): ) return op + def is_flat_index(self): + return self.get_index_name().startswith("Flat") + # IndexFromCodec, IndexFromQuantizer and IndexFromPreTransform # are used to wrap pre-trained Faiss indices (codecs) @@ -807,6 +810,9 @@ def get_codec_name(self): name += Index.param_dict_list_to_name(self.construction_params) return name + def fetch_meta(self, dry_run=False): + return None, None + def fetch_codec(self): codec = self.io.read_index( os.path.basename(self.path), @@ -911,7 +917,7 @@ def fetch_codec(self, dry_run=False): assert codec_size is not None meta = { "training_time": training_time, - "training_size": self.training_vectors.num_vectors, + "training_size": self.training_vectors.num_vectors if self.training_vectors else 0, "codec_size": codec_size, "sa_code_size": self.get_sa_code_size(codec), "code_size": self.get_code_size(codec),