Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions faiss/gpu/test/test_gpu_index_refs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import print_function
import unittest
import numpy as np
import faiss
from enum import Enum
from faiss.contrib.datasets import SyntheticDataset


class DeletionSite(Enum):
BEFORE_TRAIN = 1
BEFORE_ADD = 2
BEFORE_SEARCH = 3


def do_test(idx, index_to_delete, db, deletion_site: DeletionSite):
if deletion_site == DeletionSite.BEFORE_TRAIN:
del index_to_delete

idx.train(db)

if deletion_site == DeletionSite.BEFORE_ADD:
del index_to_delete

idx.add(db)

if deletion_site == DeletionSite.BEFORE_SEARCH:
del index_to_delete

idx.search(db, 1)


def do_multi_test(idx, index_to_delete, db):
for site in DeletionSite:
do_test(idx, index_to_delete, db, site)


#
# Test
#


class TestRefs(unittest.TestCase):
d = 32
nv = 1000
nlist = 10
res = faiss.StandardGpuResources() # pyre-ignore
db = np.random.rand(nv, d)

# These GPU classes reference another index.
# This tests to make sure the deletion of the other index
# does not cause a crash.

def test_GpuIndexIVFFlat(self):
index_to_delete = faiss.IndexIVFFlat(
faiss.IndexFlat(self.d), self.d, self.nlist
)
idx = faiss.GpuIndexIVFFlat(
self.res, index_to_delete, faiss.GpuIndexIVFFlatConfig()
)
do_multi_test(idx, index_to_delete, self.db)

def test_GpuIndexBinaryFlat(self):
ds = SyntheticDataset(64, 1000, 1000, 200)
index_to_delete = faiss.IndexBinaryFlat(ds.d)
idx = faiss.GpuIndexBinaryFlat(self.res, index_to_delete)
tobinary = faiss.index_factory(ds.d, "LSHrt")
tobinary.train(ds.get_train())
xb = tobinary.sa_encode(ds.get_database())
do_multi_test(idx, index_to_delete, xb)

def test_GpuIndexFlat(self):
index_to_delete = faiss.IndexFlat(self.d, faiss.METRIC_L2)
idx = faiss.GpuIndexFlat(self.res, index_to_delete)
do_multi_test(idx, index_to_delete, self.db)

def test_GpuIndexIVFPQ(self):
index_to_delete = faiss.IndexIVFPQ(
faiss.IndexFlatL2(self.d),
self.d, self.nlist, 2, 8)
idx = faiss.GpuIndexIVFPQ(self.res, index_to_delete)
do_multi_test(idx, index_to_delete, self.db)

def test_GpuIndexIVFScalarQuantizer(self):
index_to_delete = faiss.IndexIVFScalarQuantizer(
faiss.IndexFlat(self.d, faiss.METRIC_L2),
self.d,
self.nlist,
faiss.ScalarQuantizer.QT_8bit_direct,
faiss.METRIC_L2,
False
)
idx = faiss.GpuIndexIVFScalarQuantizer(self.res, index_to_delete)
do_multi_test(idx, index_to_delete, self.db)
10 changes: 10 additions & 0 deletions faiss/python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,16 @@ def replacement_function(*args):
setattr(this_module, function_name, replacement_function)


try:
from swigfaiss_gpu import GpuIndexIVFFlat, GpuIndexBinaryFlat, GpuIndexFlat, GpuIndexIVFPQ, GpuIndexIVFScalarQuantizer
add_ref_in_constructor(GpuIndexIVFFlat, 1)
add_ref_in_constructor(GpuIndexBinaryFlat, 1)
add_ref_in_constructor(GpuIndexFlat, 1)
add_ref_in_constructor(GpuIndexIVFPQ, 1)
add_ref_in_constructor(GpuIndexIVFScalarQuantizer, 1)
except ImportError as e:
print("Failed to load GPU Faiss: %s. Will not load constructor refs for GPU indexes." % e.args[0])

add_ref_in_constructor(IndexIVFFlat, 0)
add_ref_in_constructor(IndexIVFFlatDedup, 0)
add_ref_in_constructor(IndexPreTransform, {2: [0, 1], 1: [0]})
Expand Down