diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp index 9fc201ea39..bef961353c 100644 --- a/faiss/impl/HNSW.cpp +++ b/faiss/impl/HNSW.cpp @@ -9,6 +9,7 @@ #include +#include #include #include @@ -542,12 +543,11 @@ int search_from_candidates( for (int i = 0; i < candidates.size(); i++) { idx_t v1 = candidates.ids[i]; float d = candidates.dis[i]; - FAISS_ASSERT(v1 >= 0); + assert(v1 >= 0); if (!sel || sel->is_member(v1)) { - if (nres < k) { - faiss::maxheap_push(++nres, D, I, d, v1); - } else if (d < D[0]) { - faiss::maxheap_replace_top(nres, D, I, d, v1); + if (d < D[0]) { + faiss::maxheap_replace_top(k, D, I, d, v1); + nres++; } } vt.set(v1); @@ -612,10 +612,9 @@ int search_from_candidates( auto add_to_heap = [&](const size_t idx, const float dis) { if (!sel || sel->is_member(idx)) { - if (nres < k) { - faiss::maxheap_push(++nres, D, I, dis, idx); - } else if (dis < D[0]) { - faiss::maxheap_replace_top(nres, D, I, dis, idx); + if (dis < D[0]) { + faiss::maxheap_replace_top(k, D, I, dis, idx); + nres++; } } candidates.push(idx, dis); @@ -668,7 +667,7 @@ int search_from_candidates( stats.n3 += ndis; } - return nres; + return std::min(nres, k); } std::priority_queue search_from_candidate_unbounded( @@ -816,6 +815,11 @@ HNSWStats HNSW::search( // greedy search on upper levels storage_idx_t nearest = entry_point; float d_nearest = qdis(nearest); + if (!std::isfinite(d_nearest)) { + // means either the query or the entry point are NaN: in + // both cases we can only return -1 as a result + return stats; + } for (int level = max_level; level >= 1; level--) { greedy_update_nearest(*this, qdis, level, nearest, d_nearest); @@ -826,7 +830,6 @@ HNSWStats HNSW::search( MinimaxHeap candidates(ef); candidates.push(nearest, d_nearest); - search_from_candidates( *this, qdis, k, I, D, candidates, vt, stats, 0, 0, params); } else { diff --git a/faiss/impl/ResultHandler.h b/faiss/impl/ResultHandler.h index d096fbcfa3..945f68cf93 100644 --- a/faiss/impl/ResultHandler.h +++ b/faiss/impl/ResultHandler.h @@ -445,8 +445,8 @@ struct SingleBestResultHandler { /// begin results for query # i void begin(const size_t current_idx) { this->current_idx = current_idx; - min_dis = HUGE_VALF; - min_idx = 0; + min_dis = C::neutral(); + min_idx = -1; } /// add one result for query i @@ -472,7 +472,8 @@ struct SingleBestResultHandler { this->i1 = i1; for (size_t i = i0; i < i1; i++) { - this->dis_tab[i] = HUGE_VALF; + this->dis_tab[i] = C::neutral(); + this->ids_tab[i] = -1; } } diff --git a/faiss/impl/ScalarQuantizer.cpp b/faiss/impl/ScalarQuantizer.cpp index 680a3bc059..b1da370e6f 100644 --- a/faiss/impl/ScalarQuantizer.cpp +++ b/faiss/impl/ScalarQuantizer.cpp @@ -1075,6 +1075,11 @@ void ScalarQuantizer::set_derived_sizes() { } void ScalarQuantizer::train(size_t n, const float* x) { + for (size_t i = 0; i < n * d; i++) { + FAISS_THROW_IF_NOT_MSG( + std::isfinite(x[i]), "training data contains NaN or Inf"); + } + int bit_per_dim = qtype == QT_4bit_uniform ? 4 : qtype == QT_4bit ? 4 : qtype == QT_6bit ? 6 diff --git a/tests/test_error_reporting.py b/tests/test_error_reporting.py new file mode 100644 index 0000000000..851e8575b3 --- /dev/null +++ b/tests/test_error_reporting.py @@ -0,0 +1,181 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""this is a basic test script for simple indices work""" + +import numpy as np +import unittest +import faiss + +from common_faiss_tests import get_dataset_2 +from faiss.contrib.datasets import SyntheticDataset + + +class TestValidIndexParams(unittest.TestCase): + + def test_IndexIVFPQ(self): + d = 32 + nb = 1000 + nt = 1500 + nq = 200 + + (xt, xb, xq) = get_dataset_2(d, nt, nb, nq) + + coarse_quantizer = faiss.IndexFlatL2(d) + index = faiss.IndexIVFPQ(coarse_quantizer, d, 32, 8, 8) + index.cp.min_points_per_centroid = 5 # quiet warning + index.train(xt) + index.add(xb) + + # invalid nprobe + index.nprobe = 0 + k = 10 + self.assertRaises(RuntimeError, index.search, xq, k) + + # invalid k + index.nprobe = 4 + k = -10 + self.assertRaises(AssertionError, index.search, xq, k) + + # valid params + index.nprobe = 4 + k = 10 + D, nns = index.search(xq, k) + + self.assertEqual(D.shape[0], nq) + self.assertEqual(D.shape[1], k) + + def test_IndexFlat(self): + d = 32 + nb = 1000 + nt = 0 + nq = 200 + + (xt, xb, xq) = get_dataset_2(d, nt, nb, nq) + index = faiss.IndexFlat(d, faiss.METRIC_L2) + + index.add(xb) + + # invalid k + k = -5 + self.assertRaises(AssertionError, index.search, xq, k) + + # valid k + k = 5 + D, I = index.search(xq, k) + + self.assertEqual(D.shape[0], nq) + self.assertEqual(D.shape[1], k) + + +class TestReconsException(unittest.TestCase): + + def test_recons_exception(self): + + d = 64 # dimension + nb = 1000 + rs = np.random.RandomState(1234) + xb = rs.rand(nb, d).astype('float32') + nlist = 10 + quantizer = faiss.IndexFlatL2(d) # the other index + index = faiss.IndexIVFFlat(quantizer, d, nlist) + index.train(xb) + index.add(xb) + index.make_direct_map() + + index.reconstruct(9) + + self.assertRaises( + RuntimeError, + index.reconstruct, 100001 + ) + + def test_reconstuct_after_add(self): + index = faiss.index_factory(10, 'IVF5,SQfp16') + index.train(faiss.randn((100, 10), 123)) + index.add(faiss.randn((100, 10), 345)) + index.make_direct_map() + index.add(faiss.randn((100, 10), 678)) + + # should not raise an exception + index.reconstruct(5) + print(index.ntotal) + index.reconstruct(150) + + +class TestNaN(unittest.TestCase): + """ NaN values handling is transparent: they don't produce results + but should not crash. The tests below cover a few common index types. + """ + + def do_test_train(self, factory_string): + """ NaN and Inf should raise an exception at train time """ + ds = SyntheticDataset(32, 200, 20, 10) + index = faiss.index_factory(ds.d, factory_string) + # try to train with NaNs + xt = ds.get_train().copy() + xt[:, ::4] = np.nan + self.assertRaises(RuntimeError, index.train, xt) + + def test_train_IVFSQ(self): + self.do_test_train("IVF10,SQ8") + + def test_train_IVFPQ(self): + self.do_test_train("IVF10,PQ4np") + + def test_train_SQ(self): + self.do_test_train("SQ8") + + def do_test_add(self, factory_string): + """ stored NaNs should not be returned at search time """ + ds = SyntheticDataset(32, 200, 20, 10) + index = faiss.index_factory(ds.d, factory_string) + if not index.is_trained: + index.train(ds.get_train()) + xb = ds.get_database() + xb[12, 3] = np.nan + index.add(xb) + D, I = index.search(ds.get_queries(), 20) + self.assertTrue(np.where(I == 12)[0].size == 0) + + def test_add_Flat(self): + self.do_test_add("Flat") + + def test_add_HNSW(self): + self.do_test_add("HNSW32,Flat") + + def xx_test_add_SQ8(self): + # this is expected to fail because: + # in ASAN mode, the float NaN -> int conversion crashes + # in opt mode it works but there is no way to encode the NaN, + # so the value cannot be ignored. + self.do_test_add("SQ8") + + def test_add_IVFFlat(self): + self.do_test_add("IVF10,Flat") + + def do_test_search(self, factory_string): + """ NaN query vectors should return -1 """ + ds = SyntheticDataset(32, 200, 20, 10) + index = faiss.index_factory(ds.d, factory_string) + if not index.is_trained: + index.train(ds.get_train()) + index.add(ds.get_database()) + xq = ds.get_queries() + xq[7, 3] = np.nan + D, I = index.search(ds.get_queries(), 20) + self.assertTrue(np.all(I[7] == -1)) + + def test_search_Flat(self): + self.do_test_search("Flat") + + def test_search_HNSW(self): + self.do_test_search("HNSW32,Flat") + + def test_search_IVFFlat(self): + self.do_test_search("IVF10,Flat") + + def test_search_SQ(self): + self.do_test_search("SQ8") diff --git a/tests/test_index.py b/tests/test_index.py index 0e828e08c1..47bb46b6f3 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -4,8 +4,6 @@ # LICENSE file in the root directory of this source tree. """this is a basic test script for simple indices work""" -from __future__ import absolute_import, division, print_function -# no unicode_literals because it messes up in py2 import numpy as np import unittest @@ -13,7 +11,6 @@ import tempfile import os import re -import warnings from common_faiss_tests import get_dataset, get_dataset_2 @@ -1007,41 +1004,6 @@ def test_replica_flag_propagation(self): index.remove_replica(index1) self.assertEqual(index.ntotal, 0) -class TestReconsException(unittest.TestCase): - - def test_recons_exception(self): - - d = 64 # dimension - nb = 1000 - rs = np.random.RandomState(1234) - xb = rs.rand(nb, d).astype('float32') - nlist = 10 - quantizer = faiss.IndexFlatL2(d) # the other index - index = faiss.IndexIVFFlat(quantizer, d, nlist) - index.train(xb) - index.add(xb) - index.make_direct_map() - - index.reconstruct(9) - - self.assertRaises( - RuntimeError, - index.reconstruct, 100001 - ) - - def test_reconstuct_after_add(self): - index = faiss.index_factory(10, 'IVF5,SQfp16') - index.train(faiss.randn((100, 10), 123)) - index.add(faiss.randn((100, 10), 345)) - index.make_direct_map() - index.add(faiss.randn((100, 10), 678)) - - # should not raise an exception - index.reconstruct(5) - print(index.ntotal) - index.reconstruct(150) - - class TestReconsHash(unittest.TestCase): def do_test(self, index_key): @@ -1113,62 +1075,6 @@ def test_IVFPQ(self): self.do_test("IVF5,PQ4x4np") -class TestValidIndexParams(unittest.TestCase): - - def test_IndexIVFPQ(self): - d = 32 - nb = 1000 - nt = 1500 - nq = 200 - - (xt, xb, xq) = get_dataset_2(d, nt, nb, nq) - - coarse_quantizer = faiss.IndexFlatL2(d) - index = faiss.IndexIVFPQ(coarse_quantizer, d, 32, 8, 8) - index.cp.min_points_per_centroid = 5 # quiet warning - index.train(xt) - index.add(xb) - - # invalid nprobe - index.nprobe = 0 - k = 10 - self.assertRaises(RuntimeError, index.search, xq, k) - - # invalid k - index.nprobe = 4 - k = -10 - self.assertRaises(AssertionError, index.search, xq, k) - - # valid params - index.nprobe = 4 - k = 10 - D, nns = index.search(xq, k) - - self.assertEqual(D.shape[0], nq) - self.assertEqual(D.shape[1], k) - - def test_IndexFlat(self): - d = 32 - nb = 1000 - nt = 0 - nq = 200 - - (xt, xb, xq) = get_dataset_2(d, nt, nb, nq) - index = faiss.IndexFlat(d, faiss.METRIC_L2) - - index.add(xb) - - # invalid k - k = -5 - self.assertRaises(AssertionError, index.search, xq, k) - - # valid k - k = 5 - D, I = index.search(xq, k) - - self.assertEqual(D.shape[0], nq) - self.assertEqual(D.shape[1], k) - class TestLargeRangeSearch(unittest.TestCase):