Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions faiss/index_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,19 @@ VectorTransform* parse_VectorTransform(const std::string& description, int d) {
* Parse IndexIVF
*/

size_t parse_nlist(std::string s) {
size_t multiplier = 1;
if (s.back() == 'k') {
s.pop_back();
multiplier = 1024;
}
if (s.back() == 'M') {
s.pop_back();
multiplier = 1024 * 1024;
}
return std::stoi(s) * multiplier;
}

// parsing guard + function
Index* parse_coarse_quantizer(
const std::string& description,
Expand All @@ -240,8 +253,8 @@ Index* parse_coarse_quantizer(
};
use_2layer = false;

if (match("IVF([0-9]+)")) {
nlist = std::stoi(sm[1].str());
if (match("IVF([0-9]+[kM]?)")) {
nlist = parse_nlist(sm[1].str());
return new IndexFlat(d, mt);
}
if (match("IMI2x([0-9]+)")) {
Expand All @@ -252,18 +265,18 @@ Index* parse_coarse_quantizer(
nlist = (size_t)1 << (2 * nbit);
return new MultiIndexQuantizer(d, 2, nbit);
}
if (match("IVF([0-9]+)_HNSW([0-9]*)")) {
nlist = std::stoi(sm[1].str());
if (match("IVF([0-9]+[kM]?)_HNSW([0-9]*)")) {
nlist = parse_nlist(sm[1].str());
int hnsw_M = sm[2].length() > 0 ? std::stoi(sm[2]) : 32;
return new IndexHNSWFlat(d, hnsw_M, mt);
}
if (match("IVF([0-9]+)_NSG([0-9]+)")) {
nlist = std::stoi(sm[1].str());
if (match("IVF([0-9]+[kM]?)_NSG([0-9]+)")) {
nlist = parse_nlist(sm[1].str());
int R = std::stoi(sm[2]);
return new IndexNSGFlat(d, R, mt);
}
if (match("IVF([0-9]+)\\(Index([0-9])\\)")) {
nlist = std::stoi(sm[1].str());
if (match("IVF([0-9]+[kM]?)\\(Index([0-9])\\)")) {
nlist = parse_nlist(sm[1].str());
int no = std::stoi(sm[2].str());
FAISS_ASSERT(no >= 0 && no < parenthesis_indexes.size());
return parenthesis_indexes[no].release();
Expand Down
12 changes: 12 additions & 0 deletions tests/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,18 @@ def test_ivf(self):
index = faiss.index_factory(123, "IVF456,Flat")
self.assertEqual(index.__class__, faiss.IndexIVFFlat)

def test_ivf_suffix_k(self):
index = faiss.index_factory(123, "IVF3k,Flat")
self.assertEqual(index.nlist, 3072)

def test_ivf_suffix_M(self):
index = faiss.index_factory(123, "IVF1M,Flat")
self.assertEqual(index.nlist, 1024 * 1024)

def test_ivf_suffix_HNSW_M(self):
index = faiss.index_factory(123, "IVF1M_HNSW,Flat")
self.assertEqual(index.nlist, 1024 * 1024)

def test_idmap(self):
index = faiss.index_factory(123, "Flat,IDMap")
self.assertEqual(index.__class__, faiss.IndexIDMap)
Expand Down