diff --git a/faiss/python/class_wrappers.py b/faiss/python/class_wrappers.py index d9031904cf..c66d1a507c 100644 --- a/faiss/python/class_wrappers.py +++ b/faiss/python/class_wrappers.py @@ -1066,6 +1066,28 @@ def add_to_referenced_objects(self, ref): else: self.referenced_objects.append(ref) +class RememberSwigOwnership(object): + """ + SWIG's seattr transfers ownership of SWIG wrapped objects to the class + (btw this seems to contradict https://www.swig.org/Doc1.3/Python.html#Python_nn22 + 31.4.2) + This interferes with how we manage ownership: with the referenced_objects + table. Therefore, we reset the thisown field in this context manager. + """ + + def __init__(self, obj): + self.obj = obj + + def __enter__(self): + if hasattr(self.obj, "thisown"): + self.old_thisown = self.obj.thisown + else: + self.old_thisown = None + + def __exit__(self, *ignored): + if self.old_thisown is not None: + self.obj.thisown = self.old_thisown + def handle_SearchParameters(the_class): """ this wrapper is to enable initializations of the form @@ -1080,8 +1102,9 @@ def replacement_init(self, **args): self.original_init() for k, v in args.items(): assert hasattr(self, k) - setattr(self, k, v) - if inspect.isclass(v): + with RememberSwigOwnership(v): + setattr(self, k, v) + if type(v) not in (int, float, bool, str): add_to_referenced_objects(self, v) the_class.__init__ = replacement_init diff --git a/tests/test_search_params.py b/tests/test_search_params.py index 7b83bbfcc5..d832a07cf8 100644 --- a/tests/test_search_params.py +++ b/tests/test_search_params.py @@ -7,6 +7,8 @@ import faiss import unittest +import sys +import gc from faiss.contrib import datasets from faiss.contrib.evaluation import sort_range_res_2, check_ref_range_results @@ -346,6 +348,28 @@ def test_max_codes(self): if stats.ndis < target_ndis: np.testing.assert_equal(I0[q], Iq[0]) + def test_ownership(self): + # see https://github.com/facebookresearch/faiss/issues/2996 + subset = np.arange(0, 50) + sel = faiss.IDSelectorBatch(subset) + self.assertTrue(sel.this.own()) + params = faiss.SearchParameters(sel=sel) + self.assertTrue(sel.this.own()) # otherwise mem leak! + # this is a somewhat fragile test because it assumes the + # gc decreases refcounts immediately. + prev_count = sys.getrefcount(sel) + del params + new_count = sys.getrefcount(sel) + self.assertEqual(new_count, prev_count - 1) + + # check for other objects as well + sel1 = faiss.IDSelectorBatch([1, 2, 3]) + sel2 = faiss.IDSelectorBatch([4, 5, 6]) + sel = faiss.IDSelectorAnd(sel1, sel2) + # make storage is still managed by python + self.assertTrue(sel1.this.own()) + self.assertTrue(sel2.this.own()) + class TestSelectorCallback(unittest.TestCase): @@ -417,6 +441,7 @@ def test_12_92(self): print(j01) assert j01[0] >= j01[1] + class TestPrecomputed(unittest.TestCase): def test_knn_and_range(self):