Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
254 changes: 254 additions & 0 deletions benchs/bench_fw/optimize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from copy import copy
import logging
import math
from dataclasses import dataclass

import faiss # @manual=//faiss/python:pyfaiss_gpu

from faiss.contrib.evaluation import ( # @manual=//faiss/contrib:faiss_contrib_gpu
OperatingPoints,
)

from .benchmark import Benchmark
from .descriptors import DatasetDescriptor, IndexDescriptor
from .index import IndexFromFactory

logger = logging.getLogger(__name__)


@dataclass
class Optimizer:
    """Searches for well-performing faiss index configurations for a dataset,
    benchmarking candidates and keeping the Pareto-optimal ones in
    (knn accuracy, search time).

    NOTE(review): this is experimental tuning code — several candidate
    generation paths below are commented out or short-circuited.
    """

    # Distance metric name; only "L2" and "IP" are accepted (__post_init__).
    distance_metric: str = "L2"

    def __post_init__(self):
        # Placeholder for benchmark caching; not read anywhere in this class.
        self.cached_benchmark = None
        # Map the metric name onto the corresponding faiss metric constant.
        if self.distance_metric == "IP":
            self.distance_metric_type = faiss.METRIC_INNER_PRODUCT
        elif self.distance_metric == "L2":
            self.distance_metric_type = faiss.METRIC_L2
        else:
            raise ValueError

    def set_io(self, benchmark_io):
        """Attach the benchmark I/O helper and propagate the distance-metric
        settings onto it."""
        self.io = benchmark_io
        self.io.distance_metric = self.distance_metric
        self.io.distance_metric_type = self.distance_metric_type

    def benchmark_candidates(
        self, index_descs, training_vectors, database_vectors, query_vectors, result_file
    ):
        """Benchmark the candidate index descriptors and return new
        IndexDescriptors for the experiments that are Pareto-optimal in
        (knn_intersection, time).

        Results are persisted to `result_file` by the Benchmark machinery.
        """
        benchmark = Benchmark(
            training_vectors=training_vectors,
            database_vectors=database_vectors,
            query_vectors=query_vectors,
            index_descs=index_descs,
            k=10,  # fixed k for the knn-intersection measurement
            distance_metric=self.distance_metric,
        )
        benchmark.set_io(self.io)
        results = benchmark.benchmark(result_file)
        # results = set(v['index'] for k, v in results['experiments'].items() if ".knn" in k and v['knn_intersection'] > 0.8)

        # Keep only operating points on the accuracy/time Pareto frontier:
        # an experiment survives if no other is both more accurate and faster.
        op = OperatingPoints()
        for k, v in results["experiments"].items():
            if ".knn" in k:
                op.add_operating_point(v, v["knn_intersection"], v["time"])
                # op.add_operating_point(v, v['knn_intersection'], Cost((v['time'], results['indices'][v['codec']]['code_size'])))

        # results = set(v["factory"] for v, _, _ in op.operating_points)
        # Re-wrap the surviving experiments as IndexDescriptors so callers
        # (e.g. the recursive optimize()) can reuse them as candidates.
        results = [
            IndexDescriptor(
                factory=v["factory"], construction_params=v["construction_params"], search_params=v["search_params"]
            )
            for v, _, _ in op.operating_points
        ]
        # breakpoint()
        return results

    def get_fine_factory(
        self, d: int, scale: int, hnsw: bool = True, refine: bool = False
    ):
        """Enumerate candidate fine-index factory strings for dimension `d`
        and dataset size `scale`.

        Returns three parallel lists:
          fs: candidate IndexDescriptors
          ps: the pretransform prefix for each candidate (or None)
          cs: estimated code size in bytes per vector for each candidate

        When `refine=True` the candidates are intended as refinement stages
        (used by the recursive call below).
        """
        # PQ bit width -> candidate dimensions-per-PQ-bit options.
        options_pq_bits = {
            4: [2, 4],
            6: [2, 4, 6, 8],
            8: [2, 4, 8, 16],
            # 12: [2, 4, 8, 16],
        }
        options_code_size_log2 = range(3, 16)
        options_with_opq = [False]  # [True, False]

        fs = []  # index descriptors
        ps = []  # pretransforms
        cs = []  # code sizes
        # Baseline candidates, always included.
        for factory, code_size in [
            ("Flat", d * 4),
            ("PQ64x4fs,Refine(SQfp16)", d * 2),
            # ("SQfp16", d * 2),
            # ("SQ8", d),
        ]:
            fs.append(IndexDescriptor(factory=factory))
            ps.append(None)
            cs.append(code_size)

        # HNSW candidates only for smaller scales (code size accounts for
        # the 32-neighbor graph links).
        if hnsw and scale < 1_000_000:
            for factory, code_size in [
                ("HNSW32", d * 4 + 32 * 8),
                # ("HNSW64", d * 4 + 64 * 8),
            ]:
                fs.append(IndexDescriptor(factory=factory))
                ps.append(None)
                cs.append(code_size)
        # if scale < 16384:
        # NOTE(review): with the guard above commented out this return is
        # unconditional, which makes the PQ enumeration below unreachable —
        # confirm whether that is intended.
        return fs, ps, cs
        for code_size_log2 in options_code_size_log2:
            code_size = 2**code_size_log2
            # Only consider code sizes that actually compress the vectors.
            if code_size >= d:
                continue
            for pq_bits, dimensions_per_pq_bits in options_pq_bits.items():
                # pq_bits must evenly divide the total number of code bits.
                if code_size * 8 % pq_bits > 0:
                    continue
                M = int(code_size * 8 / pq_bits)  # number of PQ subquantizers
                for dppb in dimensions_per_pq_bits:
                    d_out = M * dppb  # (possibly reduced) input dimension to PQ
                    if d_out > d:
                        continue
                    for with_opq in options_with_opq:
                        # Changing the dimension requires an OPQ pretransform.
                        if d_out != d and not with_opq:
                            continue
                        factory = ""
                        pt = None
                        if with_opq:
                            pt = f"OPQ{M}_{d_out}"
                            factory = pt + ","
                        ps.append(pt)
                        factory += f"PQ{M}x{pq_bits}"
                        # 4-bit PQ gets the fast-scan variant, except when
                        # generating refinement-stage candidates.
                        if pq_bits == 4 and not refine:
                            factory += "fs"
                        # no refinement
                        fs.append(IndexDescriptor(factory=factory))
                        cs.append(code_size)
                        # adding refinement options on top of fastscan
                        if not refine and pq_bits == 4:
                            (
                                refiners,
                                _,
                                refiner_code_sizes,
                            ) = self.get_fine_factory(
                                d, scale, hnsw=False, refine=True
                            )
                            for refiner, refiner_code_size in zip(
                                refiners, refiner_code_sizes
                            ):
                                # A refiner only makes sense if it is at least
                                # as precise (larger code) as the base index.
                                if refiner_code_size < code_size:
                                    continue
                                fs.append(
                                    IndexDescriptor(
                                        factory=f"{factory},Refine({refiner.factory})"
                                    )
                                )
                                cs.append(code_size + refiner_code_size)

        return fs, ps, cs

    def optimize(
        self,
        d: int,
        scale: int,
        training_vectors: DatasetDescriptor,
        database_vectors: DatasetDescriptor,
        query_vectors: DatasetDescriptor,
    ):
        """Recursively search for Pareto-optimal index configurations for a
        dataset of dimension `d` and size `scale`.

        Flat/PQ/HNSW candidates are always benchmarked; when `scale` supports
        it, IVF candidates are generated as well, with the coarse quantizer
        itself optimized by a recursive call on the cluster centroids.
        Returns the Pareto-optimal IndexDescriptors from
        benchmark_candidates().
        """
        # Fixed fine-candidate list; the generated alternative is disabled.
        fine_descs, _, _ = [
            IndexDescriptor(factory="Flat"),
            IndexDescriptor(factory="PQ256x4fs,Refine(SQfp16)"),
            IndexDescriptor(factory="HNSW32"),
        ], None, None  # self.get_fine_factory(d, scale)
        # IVF list-count range: from ~sqrt(scale) (at least 2^10) up to
        # scale/50 (at most 2^19), as log2 bounds for the range() below.
        nlist_log2_min = max(math.ceil(math.log2(math.sqrt(scale))), 10)
        nlist_log2_max = min(math.floor(math.log2(scale / 50)), 19) + 1
        # Scale too small for IVF: benchmark the flat candidates only.
        if nlist_log2_min >= nlist_log2_max:
            return self.benchmark_candidates(
                fine_descs, training_vectors, database_vectors, query_vectors, f"result_{d}_{scale}.json"
            )
        ivf_descs = []
        # fine_ivf_descs, pretransforms, _ = [IndexDescriptor(factory="Flat")], [None], None
        # if scale < 1_000_000:
        fine_ivf_descs = [IndexDescriptor(factory="PQ256x4fs,Refine(SQfp16)")]
        pretransforms = [None]
        # self.get_fine_factory(
        #     d, scale, hnsw=False
        #)
        for nlist_log2 in range(nlist_log2_min, nlist_log2_max):
            nlist = 2**nlist_log2
            done = set()  # pretransforms already handled for this nlist
            for pt in pretransforms:
                if pt in done:
                    continue
                done.add(pt)

                # pretransform
                # Apply the pretransform (if any) to training and query
                # vectors so clustering happens in the transformed space.
                if pt is None:
                    transformed_training_vectors = training_vectors
                    transformed_query_vectors = query_vectors
                else:
                    pretransform_index = IndexFromFactory(pt + ",Flat")
                    pretransform_index.set_io(self.io)
                    transformed_training_vectors = (
                        pretransform_index.transform(training_vectors)
                    )
                    transformed_query_vectors = pretransform_index.transform(
                        query_vectors
                    )

                # cluster
                centroids = transformed_training_vectors.k_means(
                    self.io, nlist
                )
                # NOTE(review): `d` is rebound here to the (possibly reduced)
                # centroid dimension and later reused in the result file name
                # below — confirm that is intended.
                d = self.io.get_dataset(centroids).shape[1]

                # optimize quantizer
                # Recurse: find good coarse-quantizer indices over the
                # centroids themselves (scale == nlist).
                quantizer_descs = self.optimize(
                    d=d,
                    scale=nlist,
                    training_vectors=centroids,
                    database_vectors=centroids,
                    query_vectors=transformed_query_vectors,
                )

                # build IVF index
                # Combine every good quantizer with every fine index that
                # shares this pretransform.
                for quantizer_desc in quantizer_descs:
                    for fine_ivf_desc, pretransform in zip(
                        fine_ivf_descs, pretransforms
                    ):
                        if pretransform != pt:
                            continue
                        # Turn the pretransform into a factory-string prefix.
                        if pretransform is None:
                            pretransform = ""
                        else:
                            pretransform = pretransform + ","
                        # Fold the quantizer's search params into slot 1 of
                        # the construction params (the coarse-quantizer level).
                        if quantizer_desc.construction_params is None:
                            construction_params = [None, quantizer_desc.search_params]
                        else:
                            construction_params = [None] + copy(quantizer_desc.construction_params)
                            if construction_params[1] is None:
                                construction_params[1] = quantizer_desc.search_params
                            elif quantizer_desc.search_params is not None:
                                construction_params[1] += quantizer_desc.search_params
                        # breakpoint()
                        ivf_descs.append(
                            IndexDescriptor(
                                factory=f"{pretransform}IVF{nlist}({quantizer_desc.factory}),{fine_ivf_desc.factory}",
                                construction_params=construction_params,
                            )
                        )
        return self.benchmark_candidates(
            fine_descs + ivf_descs,
            training_vectors,
            database_vectors,
            query_vectors,
            f"result_{d}_{scale}.json",
        )
37 changes: 37 additions & 0 deletions benchs/bench_fw_optimize_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

from bench_fw.benchmark_io import BenchmarkIO
from bench_fw.descriptors import DatasetDescriptor
from bench_fw.optimize import Optimizer

# Driver script: run the Optimizer over the bigann dataset and print the
# resulting Pareto-optimal index descriptors.
logging.basicConfig(level=logging.INFO)

optimizer = Optimizer(
    distance_metric="L2",
)
io = BenchmarkIO(
    path="/checkpoint/gsz/bench_fw/ivf",
)
optimizer.set_io(io)

# Training vectors; fetched here only to read the dimensionality below.
training_vectors = DatasetDescriptor(
    namespace="std_t", tablename="bigann1M"
)
xt = io.get_dataset(training_vectors)

# Target database size; also determines the database table name.
scale = 10_000_000
index_descs = optimizer.optimize(
    d=xt.shape[1],
    scale=scale,
    training_vectors=training_vectors,
    database_vectors=DatasetDescriptor(
        # Derive the table name from `scale` (the original f-string had no
        # placeholder, lint F541, and silently duplicated the constant).
        namespace="std_d", tablename=f"bigann{scale // 1_000_000}M"
    ),
    query_vectors=DatasetDescriptor(
        namespace="std_q", tablename="bigann1M"
    ),
)
print(index_descs)