diff --git a/faiss/clone_index.cpp b/faiss/clone_index.cpp index bc08283740..5a1e5cfad2 100644 --- a/faiss/clone_index.cpp +++ b/faiss/clone_index.cpp @@ -335,9 +335,10 @@ Index* Cloner::clone_Index(const Index* index) { IndexNSG* res = clone_IndexNSG(insg); // copy the dynamic allocated graph - auto& new_graph = res->nsg.final_graph; - auto& old_graph = insg->nsg.final_graph; - new_graph = std::make_shared>(*old_graph); + if (auto& old_graph = insg->nsg.final_graph) { + auto& new_graph = res->nsg.final_graph; + new_graph = std::make_shared>(*old_graph); + } res->own_fields = true; res->storage = clone_Index(insg->storage); diff --git a/faiss/cppcontrib/factory_tools.cpp b/faiss/cppcontrib/factory_tools.cpp index d1f283b8ff..46ffada3e8 100644 --- a/faiss/cppcontrib/factory_tools.cpp +++ b/faiss/cppcontrib/factory_tools.cpp @@ -8,8 +8,22 @@ // -*- c++ -*- #include + #include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + namespace faiss { namespace { @@ -122,6 +136,11 @@ std::string reverse_index_factory(const faiss::Index* index) { const faiss::IndexHNSW* hnsw_index = dynamic_cast(index)) { return "HNSW" + std::to_string(get_hnsw_M(hnsw_index)); + } else if ( + const faiss::IndexNSG* nsg_index = + dynamic_cast(index)) { + return "NSG" + std::to_string(nsg_index->nsg.R) + "," + + reverse_index_factory(nsg_index->storage); } else if ( const faiss::IndexRefine* refine_index = dynamic_cast(index)) { diff --git a/faiss/cppcontrib/factory_tools.h b/faiss/cppcontrib/factory_tools.h index f83a6db4ad..20b9237254 100644 --- a/faiss/cppcontrib/factory_tools.h +++ b/faiss/cppcontrib/factory_tools.h @@ -9,20 +9,13 @@ #pragma once -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include namespace faiss { +struct Index; +struct IndexBinary; + std::string reverse_index_factory(const faiss::Index* index); std::string reverse_index_factory(const faiss::IndexBinary* index); diff --git a/tests/test_factory_tools.cpp b/tests/test_factory_tools.cpp index 2e77645e80..f5dda2ad68 100644 --- a/tests/test_factory_tools.cpp +++ b/tests/test_factory_tools.cpp @@ -24,6 +24,8 @@ TEST(TestFactoryTools, TestReverseIndexFactory) { "HNSW32", "SQ8", "SQfp16", + "NSG24,Flat", + "NSG16,SQ8", }) { std::unique_ptr index{index_factory(64, factory)}; ASSERT_TRUE(index); @@ -32,6 +34,8 @@ TEST(TestFactoryTools, TestReverseIndexFactory) { using Case = std::pair; for (auto [src, dst] : { Case{"SQ8,RFlat", "SQ8,Refine(Flat)"}, + Case{"NSG", "NSG32,Flat"}, + Case{"NSG,PQ8", "NSG32,PQ8x8"}, }) { std::unique_ptr index{index_factory(64, src)}; ASSERT_TRUE(index);