Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 227 additions & 13 deletions demos/index_pq_flat_separate_codes_from_codebook.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
#!/usr/bin/env -S grimaldi --kernel faiss_binary_local
#!/usr/bin/env -S grimaldi --kernel bento_kernel_faiss
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# fmt: off
# flake8: noqa


""":md
# IndexPQ: separate codes from codebook
# Serializing codes separately, with IndexLSH and IndexPQ

Let's say, for example, you have a few vector embeddings per user
and want to shard a flat index by user so you can re-use the same LSH or PQ method
for all users but store each user's codes independently.

This notebook demonstrates how to separate serializing and deserializing the PQ codebook
(via faiss.write_index for IndexPQ) independently of the vector codes. For example, in the case
where you have a few vector embeddings per user and want to shard the flat index by user you
can re-use the same PQ method for all users but store each user's codes independently.

"""

Expand All @@ -24,11 +23,9 @@

""":py"""
d = 768
n = 10000
n = 1_000
ids = np.arange(n).astype('int64')
training_data = np.random.rand(n, d).astype('float32')
M = d//8
nbits = 8

""":py"""
def read_ids_codes():
Expand All @@ -50,9 +47,76 @@ def write_template_index(template_index):
def read_template_index_instance():
    """Deserialize and return the index stored at ``/tmp/template.index``.

    Every call re-reads the file through ``faiss.read_index``, so each caller
    receives its own index object.
    """
    template_path = "/tmp/template.index"
    return faiss.read_index(template_path)

""":md
## IndexLSH: separate codes

The first half of this notebook demonstrates how to store LSH codes. Unlike PQ, LSH does not require training. In fact, its compression method, a random projection matrix, is deterministic on construction because it is generated from a random seed value that's [hardcoded](https://github.com/facebookresearch/faiss/blob/2c961cc308ade8a85b3aa10a550728ce3387f625/faiss/IndexLSH.cpp#L35).
"""

""":py"""
# at train time
nbits = 1536

""":py"""
# Demonstrate that encoding is deterministic: ten independently constructed
# LSH indexes must all produce the same code for one and the same vector.

database_vector_float32 = np.random.rand(1, d).astype(np.float32)
codes = []
for _ in range(10):
    # A brand-new index per iteration, so nothing is shared between encodings.
    index = faiss.IndexIDMap2(faiss.IndexLSH(d, nbits))
    code = index.index.sa_encode(database_vector_float32)
    codes.append(code)

# Every later code must equal the first one.
assert all(np.array_equal(codes[0], other) for other in codes[1:])

""":py"""
# New database vector: encode it with a freshly built LSH index and append
# the resulting code (and its id) to whatever is already stored.

ids, codes = read_ids_codes()

# Next free id: one past the largest stored id, or 1 when nothing is stored.
if ids is None:
    database_vector_id = 1
else:
    database_vector_id = max(ids) + 1
database_vector_float32 = np.random.rand(1, d).astype(np.float32)
index = faiss.IndexIDMap2(faiss.IndexLSH(d, nbits))

# Encode through the wrapped IndexLSH; nothing is added to the index itself.
code = index.index.sa_encode(database_vector_float32)

if ids is None or codes is None:
    # Nothing stored yet: start new arrays holding just this entry.
    ids = np.array([database_vector_id])
    codes = np.array([code])
else:
    # Extend the stored arrays with the new entry.
    ids = np.concatenate((ids, [database_vector_id]))
    codes = np.vstack((codes, code))

write_ids_codes(ids, codes)

""":py '2840581589434841'"""
# then at query time
# Rebuild an empty LSH index with the same (d, nbits), load the stored ids and
# codes into it, and search. No trained state has to be restored: the
# determinism check earlier in this notebook asserts that separately
# constructed IndexLSH instances encode identically.

query_vector_float32 = np.random.rand(1, d).astype(np.float32)
index = faiss.IndexIDMap2(faiss.IndexLSH(d, nbits))
ids, codes = read_ids_codes()

# Insert the already-encoded vectors directly, together with their ids.
index.add_sa_codes(codes, ids)

# Bare expression: presumably rendered as the cell output by the notebook.
index.search(query_vector_float32, k=5)

""":py"""
!rm /tmp/ids.npy /tmp/codes.npy

""":md
## IndexPQ: separate codes from codebook

The second half of this notebook demonstrates how to separate serializing and deserializing the PQ codebook
(via faiss.write_index for IndexPQ) independently of the vector codes. For example, in the case
where you have a few vector embeddings per user and want to shard the flat index by user you
can re-use the same PQ method for all users but store each user's codes independently.

"""

""":py"""
M = d//8
nbits = 8

""":py"""
# at train time
# Build an ID-mapped PQ index from the factory string (parameterized by M and
# nbits), train it on training_data, and persist it with write_template_index.
# No vectors are added here: only the trained template is saved.
template_index = faiss.index_factory(d, f"IDMap2,PQ{M}x{nbits}")
template_index.train(training_data)
write_template_index(template_index)
Expand All @@ -61,8 +125,8 @@ def read_template_index_instance():
# New database vector

index = read_template_index_instance()
database_vector_id, database_vector_float32 = np.random.randint(10000), np.random.rand(1, d).astype(np.float32)
ids, codes = read_ids_codes()
database_vector_id, database_vector_float32 = max(ids) + 1 if ids is not None else 1, np.random.rand(1, d).astype(np.float32)

code = index.index.sa_encode(database_vector_float32)

Expand All @@ -75,7 +139,7 @@ def read_template_index_instance():

write_ids_codes(ids, codes)

""":py '331546060044009'"""
""":py '1858280061369209'"""
# then at query time
query_vector_float32 = np.random.rand(1, d).astype(np.float32)
id_wrapper_index = read_template_index_instance()
Expand All @@ -87,3 +151,153 @@ def read_template_index_instance():

""":py"""
!rm /tmp/ids.npy /tmp/codes.npy /tmp/template.index

""":md
## Comparing these methods

- methods: Flat, LSH, PQ
- vary cost: nbits, M for 1x, 2x, 4x, 8x, 16x, 32x compression
- measure: recall@1

We don't measure latency as the number of vectors per user shard is insignificant.

"""

""":py '2898032417027201'"""
n, d

""":py"""
# Random data for the comparison: n database vectors with ids 0..n-1 and
# n query vectors, all of dimension d and dtype float32.
database_vector_ids, database_vector_float32s = np.arange(n), np.random.rand(n, d).astype(np.float32)
query_vector_float32s = np.random.rand(n, d).astype(np.float32)

""":py"""
# Baseline: an "IDMap2,Flat" index over the same database. Its top-1 result
# ids are kept as the ground truth that the PQ and LSH results are compared to.
index = faiss.index_factory(d, "IDMap2,Flat")
index.add_with_ids(database_vector_float32s, database_vector_ids)
_, ground_truth_result_ids= index.search(query_vector_float32s, k=1)

""":py '857475336204238'"""
from dataclasses import dataclass

# (m, nbits) pairs fed to the "IDMap2,PQ{m}x{nbits}" factory string by the
# benchmark loop; the comment above each pair of entries gives their code size.
pq_m_nbits = (
    # 96 bytes
    (96, 8),
    (192, 4),
    # 192 bytes
    (192, 8),
    (384, 4),
    # 384 bytes
    (384, 8),
    (768, 4),
)
# nbits values passed to faiss.IndexLSH(d, nbits) by the benchmark loop.
lsh_nbits = (768, 1536, 3072, 6144, 12288, 24576)


@dataclass
class Record:
    """One benchmark measurement for a single index configuration."""

    # Index family tag; the benchmark loops store "pq" or "lsh".
    type_: str
    # The populated index the measurement was taken on.
    index: faiss.Index
    # Construction parameters: (m, nbits) for "pq", (nbits,) for "lsh".
    args: tuple
    # NOTE(review): despite the annotation, the benchmark loops store a sum
    # over the top-1 match mask here, and the plotting code reads recall[0]
    # and divides by n -- i.e. this holds a hit count, not a ratio.
    recall: float


# One Record per benchmarked index configuration, in the order they were run.
results = []

# Product quantization: train, populate, and query one index per (m, nbits).
for m, nbits in pq_m_nbits:
    print("pq", m, nbits)
    index = faiss.index_factory(d, f"IDMap2,PQ{m}x{nbits}")
    index.train(training_data)
    index.add_with_ids(database_vector_float32s, database_vector_ids)
    _, result_ids = index.search(query_vector_float32s, k=1)
    # Number of queries whose top-1 id matches the ground truth. np.sum over
    # axis 0 is vectorized and keeps the length-1 array shape that the
    # plotting code indexes as recall[0], even when there are no queries
    # (the builtin sum() would return a plain int 0 in that case).
    recall = np.sum(result_ids == ground_truth_result_ids, axis=0)
    results.append(Record("pq", index, (m, nbits), recall))

# LSH: no training step, only populate and query one index per nbits.
for nbits in lsh_nbits:
    print("lsh", nbits)
    index = faiss.IndexIDMap2(faiss.IndexLSH(d, nbits))
    index.add_with_ids(database_vector_float32s, database_vector_ids)
    _, result_ids = index.search(query_vector_float32s, k=1)
    # Same hit count as in the PQ loop above.
    recall = np.sum(result_ids == ground_truth_result_ids, axis=0)
    results.append(Record("lsh", index, (nbits,), recall))

""":py '556918346720794'"""
import matplotlib.pyplot as plt
import numpy as np

def create_grouped_bar_chart(x_values, y_values_list, labels_list, xlabel, ylabel, title):
    """Draw groups of annotated bars on a log-scaled x axis and show the figure.

    Args:
        x_values: x coordinate at which each group is centred; also used as
            the tick positions and (stringified) tick labels.
        y_values_list: one sequence of bar heights per entry of ``x_values``.
        labels_list: one sequence of annotation texts per entry of
            ``x_values``, parallel to the matching ``y_values_list`` entry.
        xlabel: x-axis label.
        ylabel: y-axis label.
        title: figure title.

    Returns:
        None. The chart is displayed with ``plt.show()``.
    """
    plt.figure(figsize=(12, 6))

    for x, y_values, labels in zip(x_values, y_values_list, labels_list):
        num_bars = len(y_values)
        # Bar width grows with x so that groups stay legible once the x axis
        # is switched to a log scale below.
        bar_width = 0.08 * x
        # Spread the bars of this group symmetrically around x.
        bar_positions = np.arange(num_bars) * bar_width - (num_bars - 1) * bar_width / 2 + x

        bars = plt.bar(bar_positions, y_values, width=bar_width)

        # Write each bar's label just above its top edge.
        for bar, label in zip(bars, labels):
            height = bar.get_height()
            plt.annotate(
                label,
                xy=(bar.get_x() + bar.get_width() / 2, height),
                xytext=(0, 3),
                textcoords="offset points",
                ha='center', va='bottom'
            )

    plt.xscale('log')
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.xticks(x_values, labels=[str(x) for x in x_values])
    plt.tight_layout()
    plt.show()

# # Example usage:
# x_values = [1, 2, 4, 8, 16, 32]
# y_values_list = [
# [2.5, 3.6, 1.8],
# [3.0, 2.8],
# [2.5, 3.5, 4.0, 1.0],
# [4.2],
# [3.0, 5.5, 2.2],
# [6.0, 4.5]
# ]
# labels_list = [
# ['A1', 'B1', 'C1'],
# ['A2', 'B2'],
# ['A3', 'B3', 'C3', 'D3'],
# ['A4'],
# ['A5', 'B5', 'C5'],
# ['A6', 'B6']
# ]

# create_grouped_bar_chart(x_values, y_values_list, labels_list, "x axis", "y axis", "title")

""":py '1630106834206134'"""
# x-axis: compression ratio
# y-axis: recall@1

from collections import defaultdict

# Maps compression ratio -> list of (label, recall@1) pairs drawn as one group.
x = defaultdict(list)
# The flat index is the ground truth itself: ratio 1, recall 1.
x[1].append(("flat", 1.00))
for r in results:
    # r.recall holds the hit count in a length-1 array; normalize by n queries.
    y_value = r.recall[0] / n
    # Ratio of the raw float32 vector size (d * 4 bytes) to the encoded size.
    x_value = int(d * 4 / r.index.sa_code_size())
    label = None
    if r.type_ == "pq":
        label = f"PQ{r.args[0]}x{r.args[1]}"
    elif r.type_ == "lsh":
        label = f"LSH{r.args[0]}"
    x[x_value].append((label, y_value))

x_values = sorted(x)
create_grouped_bar_chart(
    x_values,
    [[e[1] for e in x[x_value]] for x_value in x_values],
    [[e[0] for e in x[x_value]] for x_value in x_values],
    "compression ratio",
    # Derived from n and d so the captions cannot drift from the data actually
    # used; with n=1_000 and d=768 they render exactly as before.
    f"recall@1 q={n:,} queries",
    f"recall@1 for a database of n={n:,} d={d} vectors",
)