From 2e7c61cb519ec5365cd1a240c0e2135f6dcfd1d9 Mon Sep 17 00:00:00 2001 From: MaggieQi Date: Sat, 29 Dec 2018 20:17:21 +0800 Subject: [PATCH] add LICENSE (#1) * add LICENSE * new release * add url --- .gitignore | 1 + AnnService/CMakeLists.txt | 32 +- AnnService/CoreLibrary.vcxproj | 11 + AnnService/CoreLibrary.vcxproj.filters | 6 +- AnnService/IndexBuilder.vcxproj | 2 + AnnService/Server.vcxproj | 2 + AnnService/inc/Core/BKT/Index.h | 113 +-- AnnService/inc/Core/Common.h | 6 + AnnService/inc/Core/Common/DataUtils.h | 378 ++++++---- AnnService/inc/Core/Common/Dataset.h | 177 +---- AnnService/inc/Core/Common/FineGrainedLock.h | 48 ++ AnnService/inc/Core/Common/Heap.h | 2 +- AnnService/inc/Core/Common/QueryResultSet.h | 8 +- AnnService/inc/Core/Common/WorkSpace.h | 4 +- AnnService/inc/Core/KDT/Index.h | 103 +-- .../inc/Core/KDT/ParameterDefinitionList.h | 2 +- AnnService/inc/Core/MetadataSet.h | 48 +- AnnService/inc/Core/SearchQuery.h | 34 +- AnnService/inc/Core/VectorIndex.h | 48 +- AnnService/inc/Core/VectorSet.h | 16 +- AnnService/packages.config | 2 + AnnService/src/Client/main.cpp | 2 +- AnnService/src/Core/BKT/BKTIndex.cpp | 679 +++++++++++------- AnnService/src/Core/Common/WorkSpacePool.cpp | 3 + AnnService/src/Core/KDT/KDTIndex.cpp | 663 +++++++++++------ AnnService/src/Core/MetadataSet.cpp | 155 ++-- AnnService/src/Core/VectorIndex.cpp | 88 ++- AnnService/src/Core/VectorSet.cpp | 38 +- AnnService/src/Helper/ArgumentsParser.cpp | 5 +- AnnService/src/Helper/CommonHelper.cpp | 4 +- AnnService/src/IndexBuilder/Options.cpp | 2 +- .../VectorSetReaders/DefaultReader.cpp | 2 +- AnnService/src/IndexBuilder/main.cpp | 74 +- AnnService/src/Server/SearchExecutor.cpp | 4 +- AnnService/src/Server/SearchService.cpp | 2 +- AnnService/src/Socket/RemoteSearchQuery.cpp | 6 +- CMakeLists.txt | 31 + LICENSE | 21 + PythonWrapper/CMakeLists.txt | 48 +- PythonWrapper/PythonClient.vcxproj | 2 +- PythonWrapper/PythonCore.vcxproj | 4 +- PythonWrapper/inc/CoreInterface.h | 11 +- 
PythonWrapper/inc/PyByteArray.i | 6 +- PythonWrapper/src/CoreInterface.cpp | 88 ++- README.md | 110 ++- SPTAG.sln | 10 + Search/CMakeLists.txt | 14 + Search/Search.vcxproj | 170 +++++ Search/Search.vcxproj.filters | 25 + Search/main.cpp | 285 ++++++++ Search/packages.config | 12 + Test/CMakeLists.txt | 12 +- Test/Test.vcxproj | 14 +- Test/Test.vcxproj.filters | 17 +- Test/Test.vcxproj.user | 5 +- Test/packages.config | 2 + Test/src/AlgoTest.cpp | 131 ++++ Test/src/BKTTest.cpp | 25 - Test/src/Base64HelperTest.cpp | 40 ++ Test/src/CommonHelperTest.cpp | 90 +++ Test/src/DistanceTest.cpp | 4 +- Test/src/IniReaderTest.cpp | 37 + Test/src/StringConvertTest.cpp | 125 ++++ azure-pipelines.yml | 4 +- docs/GettingStart.md | 217 ++++++ 65 files changed, 3072 insertions(+), 1258 deletions(-) create mode 100644 AnnService/inc/Core/Common/FineGrainedLock.h create mode 100644 LICENSE create mode 100644 Search/CMakeLists.txt create mode 100644 Search/Search.vcxproj create mode 100644 Search/Search.vcxproj.filters create mode 100644 Search/main.cpp create mode 100644 Search/packages.config create mode 100644 Test/src/AlgoTest.cpp delete mode 100644 Test/src/BKTTest.cpp create mode 100644 Test/src/Base64HelperTest.cpp create mode 100644 Test/src/CommonHelperTest.cpp create mode 100644 Test/src/IniReaderTest.cpp create mode 100644 Test/src/StringConvertTest.cpp create mode 100644 docs/GettingStart.md diff --git a/.gitignore b/.gitignore index 2869b140..57fdf2db 100644 --- a/.gitignore +++ b/.gitignore @@ -71,3 +71,4 @@ dkms.conf /obj/x64_Debug /x64/Debug /packages +/Search/Search.vcxproj.user diff --git a/AnnService/CMakeLists.txt b/AnnService/CMakeLists.txt index d834e400..f5102115 100644 --- a/AnnService/CMakeLists.txt +++ b/AnnService/CMakeLists.txt @@ -1,25 +1,3 @@ -find_package(OpenMP) -if (OpenMP_FOUND) - set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") - set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") - set (CMAKE_EXE_LINKER_FLAGS 
"${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}") - message (STATUS "Found openmp.") -else() - message (FATAL_ERROR "Could no find openmp!") -endif() - -find_package(Boost 1.67 COMPONENTS system thread serialization wserialization regex) -if (Boost_FOUND) - include_directories (${Boost_INCLUDE_DIR}) - link_directories (${Boost_LIBRARY_DIR} "/usr/lib") - message (STATUS "Found Boost.") - message (STATUS "Include Path: ${Boost_INCLUDE_DIRS}") - message (STATUS "Library Path: ${Boost_LIBRARY_DIRS}") - message (STATUS "Library: ${Boost_LIBRARIES}") -else() - message (FATAL_ERROR "Could not find Boost 1.67!") -endif() - file(GLOB HDR_FILES ${PROJECT_SOURCE_DIR}/AnnService/inc/Core/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Core/Common/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Core/BKT/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Core/KDT/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Helper/*.h) file(GLOB SRC_FILES ${PROJECT_SOURCE_DIR}/AnnService/src/Core/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Core/Common/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Core/BKT/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Core/KDT/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Helper/*.cpp) @@ -32,24 +10,24 @@ set_target_properties(SPTAGLibStatic PROPERTIES OUTPUT_NAME SPTAGLib) file(GLOB SERVER_HDR_FILES ${HDR_FILES} ${PROJECT_SOURCE_DIR}/AnnService/inc/Server/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Socket/*.h) file(GLOB SERVER_FILES ${SRC_FILES} ${PROJECT_SOURCE_DIR}/AnnService/src/Server/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Socket/*.cpp) add_executable (server ${SERVER_FILES} ${SERVER_HDR_FILES}) -target_link_libraries(server ${Boost_LIBRARIES}) +target_link_libraries(server ${Boost_LIBRARIES} ${TBB_LIBRARIES}) file(GLOB CLIENT_HDR_FILES ${HDR_FILES} ${PROJECT_SOURCE_DIR}/AnnService/inc/Client/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Socket/*.h) file(GLOB CLIENT_FILES ${SRC_FILES} ${PROJECT_SOURCE_DIR}/AnnService/src/Client/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Socket/*.cpp) 
add_executable (client ${CLIENT_FILES} ${CLIENT_HDR_FILES}) -target_link_libraries(client ${Boost_LIBRARIES}) +target_link_libraries(client ${Boost_LIBRARIES} ${TBB_LIBRARIES}) file(GLOB AGG_HDR_FILES ${HDR_FILES} ${PROJECT_SOURCE_DIR}/AnnService/inc/Aggregator/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Socket/*.h) file(GLOB AGG_FILES ${SRC_FILES} ${PROJECT_SOURCE_DIR}/AnnService/src/Aggregator/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Socket/*.cpp) add_executable (aggregator ${AGG_FILES} ${AGG_HDR_FILES}) -target_link_libraries(aggregator ${Boost_LIBRARIES}) +target_link_libraries(aggregator ${Boost_LIBRARIES} ${TBB_LIBRARIES}) file(GLOB BUILDER_HDR_FILES ${HDR_FILES} ${PROJECT_SOURCE_DIR}/AnnService/inc/IndexBuilder/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/IndexBuilder/VectorSetReaders/*.h) file(GLOB BUILDER_FILES ${SRC_FILES} ${PROJECT_SOURCE_DIR}/AnnService/src/IndexBuilder/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/IndexBuilder/VectorSetReaders/*.cpp) add_executable (indexbuilder ${BUILDER_FILES} ${BUILDER_HDR_FILES}) -target_link_libraries(indexbuilder ${Boost_LIBRARIES}) +target_link_libraries(indexbuilder ${Boost_LIBRARIES} ${TBB_LIBRARIES}) install(TARGETS SPTAGLib SPTAGLibStatic server client aggregator indexbuilder RUNTIME DESTINATION bin ARCHIVE DESTINATION lib - LIBRARY DESTINATION lib) \ No newline at end of file + LIBRARY DESTINATION lib) diff --git a/AnnService/CoreLibrary.vcxproj b/AnnService/CoreLibrary.vcxproj index adda90e4..9844c709 100644 --- a/AnnService/CoreLibrary.vcxproj +++ b/AnnService/CoreLibrary.vcxproj @@ -131,6 +131,7 @@ + @@ -171,7 +172,17 @@ + + + + + + + This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. 
+ + + \ No newline at end of file diff --git a/AnnService/CoreLibrary.vcxproj.filters b/AnnService/CoreLibrary.vcxproj.filters index 9c44efaa..7d27224d 100644 --- a/AnnService/CoreLibrary.vcxproj.filters +++ b/AnnService/CoreLibrary.vcxproj.filters @@ -115,6 +115,9 @@ Header Files\Core\KDT + + Header Files\Core\Common + @@ -132,9 +135,6 @@ Source Files\Helper - - Source Files\Core - Source Files\Core diff --git a/AnnService/IndexBuilder.vcxproj b/AnnService/IndexBuilder.vcxproj index 50638b29..b5efd475 100644 --- a/AnnService/IndexBuilder.vcxproj +++ b/AnnService/IndexBuilder.vcxproj @@ -159,6 +159,7 @@ + @@ -171,5 +172,6 @@ + \ No newline at end of file diff --git a/AnnService/Server.vcxproj b/AnnService/Server.vcxproj index d830f3bc..c2336176 100644 --- a/AnnService/Server.vcxproj +++ b/AnnService/Server.vcxproj @@ -137,6 +137,7 @@ + @@ -149,5 +150,6 @@ + \ No newline at end of file diff --git a/AnnService/inc/Core/BKT/Index.h b/AnnService/inc/Core/BKT/Index.h index 8d8cc21f..c14aa815 100644 --- a/AnnService/inc/Core/BKT/Index.h +++ b/AnnService/inc/Core/BKT/Index.h @@ -12,11 +12,13 @@ #include "../Common/Dataset.h" #include "../Common/WorkSpace.h" #include "../Common/WorkSpacePool.h" +#include "../Common/FineGrainedLock.h" +#include "../Common/DataUtils.h" #include -#include #include #include +#include namespace SPTAG { @@ -36,7 +38,7 @@ namespace BKT int childStart; int childEnd; - BKTNode(int cid = -1) : centerid(cid), childStart(-1) {} + BKTNode(int cid = -1) : centerid(cid), childStart(-1), childEnd(-1) {} }; template @@ -119,8 +121,7 @@ namespace BKT int m_iDataSize; int m_iDataDimension; COMMON::Dataset m_pSamples; - std::shared_ptr m_pMetadata; - + // BKT structures. 
int m_iBKTNumber; std::vector m_pBKTStart; @@ -156,7 +157,6 @@ namespace BKT char* m_pGraphMemoryFile; char* m_pDataPointsMemoryFile; - int m_iNumberOfThreads; DistCalcMethod m_iDistCalcMethod; float(*m_fComputeDistance)(const T* pX, const T* pY, int length); @@ -167,8 +167,11 @@ namespace BKT int g_iNumberOfInitialDynamicPivots; int g_iNumberOfOtherDynamicPivots; + int m_iNumberOfThreads; + std::mutex m_dataAllocLock; + COMMON::FineGrainedLock m_dataUpdateLock; + tbb::concurrent_unordered_set m_deletedID; std::unique_ptr m_workSpacePool; - public: Index() : m_iBKTNumber(1), m_iBKTKmeansK(32), @@ -204,95 +207,29 @@ namespace BKT int GetFeatureDim() const { return m_pSamples.C(); } int GetNumThreads() const { return m_iNumberOfThreads; } int GetCurrMaxCheck() const { return m_iMaxCheck; } + DistCalcMethod GetDistCalcMethod() const { return m_iDistCalcMethod; } IndexAlgoType GetIndexAlgoType() const { return IndexAlgoType::BKT; } - VectorValueType AcceptableQueryValueType() const { return GetEnumValueType(); } - void SetMetadata(const std::string& metadataFilePath, const std::string& metadataIndexPath) { - m_pMetadata.reset(new FileMetadataSet(metadataFilePath, metadataIndexPath)); - } - ByteArray GetMetadata(IndexType p_vectorID) const { - if (nullptr != m_pMetadata) - { - return m_pMetadata->GetMetadata(p_vectorID); - } - return ByteArray::c_empty; - } + VectorValueType GetVectorValueType() const { return GetEnumValueType(); } - bool BuildIndex(); - bool BuildIndex(void* p_data, int p_vectorNum, int p_dimension); - ErrorCode BuildIndex(std::shared_ptr p_vectorSet, - std::shared_ptr p_metadataSet); + ErrorCode BuildIndex(const void* p_data, int p_vectorNum, int p_dimension); + + ErrorCode LoadIndex(const std::string& p_folderPath); + ErrorCode LoadIndexFromMemory(const std::vector& p_indexBlobs); - bool LoadIndex(); - ErrorCode LoadIndex(const std::string& p_folderPath, const Helper::IniReader& p_configReader); - - bool SaveIndex(); ErrorCode SaveIndex(const 
std::string& p_folderPath); - void SearchIndex(COMMON::QueryResultSet &query, COMMON::WorkSpace &space) const; - ErrorCode SearchIndex(QueryResult &query) const; - - void AddNodes(const T* pData, int num, COMMON::WorkSpace &space); - + void SearchIndex(COMMON::QueryResultSet &p_query, COMMON::WorkSpace &p_space, const tbb::concurrent_unordered_set &p_deleted) const; + ErrorCode SearchIndex(QueryResult &p_query) const; + + ErrorCode AddIndex(const void* p_vectors, int p_vectorNum, int p_dimension); + ErrorCode DeleteIndex(const void* p_vectors, int p_vectorNum); + ErrorCode RefineIndex(const std::string& p_folderPath); + ErrorCode MergeIndex(const char* p_indexFilePath1, const char* p_indexFilePath2); + ErrorCode SetParameter(const char* p_param, const char* p_value); std::string GetParameter(const char* p_param) const; - // This can be used for both building model files or searching with model files loaded. - void SetParameters(std::string dataPointsFile, - std::string BKTFile, - std::string graphFile, - int numBKT, - int neighborhoodSize, - int kmeansK, - int BKTLeafSize, - int numSamplesBKT, - int numTPTrees, - int TPTLeafSize, - int maxCheckForRefineGraph, - int numThreads, - DistCalcMethod distCalcMethod, - int cacheSize = -1, - int numPoints = -1) - { - m_sDataPointsFilename = dataPointsFile; - m_sBKTFilename = BKTFile; - m_sGraphFilename = graphFile; - m_iBKTNumber = numBKT; - m_iNeighborhoodSize = neighborhoodSize; - m_iBKTKmeansK = kmeansK; - m_iBKTLeafSize = BKTLeafSize; - m_iSamples = numSamplesBKT; - m_iTptreeNumber = numTPTrees; - m_iTPTLeafSize = TPTLeafSize; - m_iMaxCheckForRefineGraph = maxCheckForRefineGraph; - m_iNumberOfThreads = numThreads; - m_iDistCalcMethod = distCalcMethod; - m_fComputeDistance = COMMON::DistanceCalcSelector(m_iDistCalcMethod); - - m_iCacheSize = cacheSize; - m_iDebugLoad = numPoints; - } - - // Only used for searching with memory mapped model files - void SetParameters(char* pDataPointsMemFile, - char* pBKTMemFile, - char* 
pGraphMemFile, - DistCalcMethod distCalcMethod, - int maxCheck, - int numBKT, - int neighborhoodSize) - { - m_pDataPointsMemoryFile = pDataPointsMemFile; - m_pBKTMemoryFile = pBKTMemFile; - m_pGraphMemoryFile = pGraphMemFile; - m_iMaxCheck = maxCheck; - m_iBKTNumber = numBKT; - m_iNeighborhoodSize = neighborhoodSize; - m_iNumberOfThreads = 1; - m_iDistCalcMethod = distCalcMethod; - m_fComputeDistance = COMMON::DistanceCalcSelector(m_iDistCalcMethod); - } - private: // Functions for loading models from files bool LoadDataPoints(std::string sDataPointsFileName); @@ -307,8 +244,8 @@ namespace BKT bool SaveDataPoints(std::string sDataPointsFileName); // Functions for building balanced kmeans tree - void BuildBKT(); - bool SaveBKT(std::string sBKTFilename) const; + void BuildBKT(std::vector& indices, std::vector& newStart, std::vector& newRoot); + bool SaveBKT(std::string sBKTFilename, std::vector& newStart, std::vector& newRoot) const; float KmeansAssign(std::vector& indices, const int first, const int last, KmeansArgs& args, bool updateCenters); int KmeansClustering(std::vector& indices, const int first, const int last, KmeansArgs& args); diff --git a/AnnService/inc/Core/Common.h b/AnnService/inc/Core/Common.h index 75279390..7d61675b 100644 --- a/AnnService/inc/Core/Common.h +++ b/AnnService/inc/Core/Common.h @@ -7,6 +7,7 @@ #include #include #include +#include #ifndef _MSC_VER #include @@ -29,6 +30,11 @@ template inline T max(T a, T b) { return a > b ? 
a : b; } + +#ifndef _rotl +#define _rotl(x, n) (((x) << (n)) | ((x) >> (32-(n)))) +#endif + #else #define WIN32_LEAN_AND_MEAN #include diff --git a/AnnService/inc/Core/Common/DataUtils.h b/AnnService/inc/Core/Common/DataUtils.h index 97474827..2ecd7faa 100644 --- a/AnnService/inc/Core/Common/DataUtils.h +++ b/AnnService/inc/Core/Common/DataUtils.h @@ -12,177 +12,275 @@ namespace SPTAG { const int bufsize = 1024 * 1024 * 1024; - template - void ProcessTSVData(int id, int threadbase, long long blocksize, - std::string filename, std::string outfile, std::string outmetafile, std::string outmetaindexfile, - std::atomic_int& numSamples, int& D, DistCalcMethod distCalcMethod) { - std::ifstream inputStream(filename); - if (!inputStream.is_open()) { - std::cerr << "unable to open file " + filename << std::endl; - throw MyException("unable to open file " + filename); - exit(1); - } - std::ofstream outputStream, metaStream_out, metaStream_index; - outputStream.open(outfile + std::to_string(id + threadbase), std::ofstream::binary); - metaStream_out.open(outmetafile + std::to_string(id + threadbase), std::ofstream::binary); - metaStream_index.open(outmetaindexfile + std::to_string(id + threadbase), std::ofstream::binary); - if (!outputStream.is_open() || !metaStream_out.is_open() || !metaStream_index.is_open()) { - std::cerr << "unable to open output file " << outfile << " " << outmetafile << " " << outmetaindexfile << std::endl; - throw MyException("unable to open output files"); - exit(1); - } + class DataUtils { + public: + template + static void ProcessTSVData(int id, int threadbase, std::uint64_t blocksize, + std::string filename, std::string outfile, std::string outmetafile, std::string outmetaindexfile, + std::atomic_int& numSamples, int& D, DistCalcMethod distCalcMethod) { + std::ifstream inputStream(filename); + if (!inputStream.is_open()) { + std::cerr << "unable to open file " + filename << std::endl; + throw MyException("unable to open file " + filename); + exit(1); 
+ } + std::ofstream outputStream, metaStream_out, metaStream_index; + outputStream.open(outfile + std::to_string(id + threadbase), std::ofstream::binary); + metaStream_out.open(outmetafile + std::to_string(id + threadbase), std::ofstream::binary); + metaStream_index.open(outmetaindexfile + std::to_string(id + threadbase), std::ofstream::binary); + if (!outputStream.is_open() || !metaStream_out.is_open() || !metaStream_index.is_open()) { + std::cerr << "unable to open output file " << outfile << " " << outmetafile << " " << outmetaindexfile << std::endl; + throw MyException("unable to open output files"); + exit(1); + } - std::vector arr; - std::vector sample; + std::vector arr; + std::vector sample; - int base = 1; - if (distCalcMethod == DistCalcMethod::Cosine) { - base = Utils::GetBase(); - } - long long writepos = 0; - int sampleSize = 0; - long long totalread = 0; - std::streamoff startpos = id * blocksize; + int base = 1; + if (distCalcMethod == DistCalcMethod::Cosine) { + base = Utils::GetBase(); + } + std::uint64_t writepos = 0; + int sampleSize = 0; + std::uint64_t totalread = 0; + std::streamoff startpos = id * blocksize; #ifndef _MSC_VER - int enter_size = 1; + int enter_size = 1; #else - int enter_size = 1; + int enter_size = 1; #endif - std::string currentLine; - size_t index; - inputStream.seekg(startpos, std::ifstream::beg); - if (id != 0) { - std::getline(inputStream, currentLine); - totalread += currentLine.length() + enter_size; - } - std::cout << "Begin thread " << id << " begin at:" << (startpos + totalread) << std::endl; - while (!inputStream.eof() && totalread <= blocksize) { - std::getline(inputStream, currentLine); - if (currentLine.length() <= enter_size || (index = Utils::ProcessLine(currentLine, arr, D, base, distCalcMethod)) < 0) { + std::string currentLine; + size_t index; + inputStream.seekg(startpos, std::ifstream::beg); + if (id != 0) { + std::getline(inputStream, currentLine); totalread += currentLine.length() + enter_size; - 
continue; } - sample.resize(D); - for (int j = 0; j < D; j++) sample[j] = (T)arr[j]; + std::cout << "Begin thread " << id << " begin at:" << (startpos + totalread) << std::endl; + while (!inputStream.eof() && totalread <= blocksize) { + std::getline(inputStream, currentLine); + if (currentLine.length() <= enter_size || (index = Utils::ProcessLine(currentLine, arr, D, base, distCalcMethod)) < 0) { + totalread += currentLine.length() + enter_size; + continue; + } + sample.resize(D); + for (int j = 0; j < D; j++) sample[j] = (T)arr[j]; - outputStream.write((char *)(sample.data()), sizeof(T)*D); - metaStream_index.write((char *)&writepos, sizeof(long long)); - metaStream_out.write(currentLine.c_str(), index); + outputStream.write((char *)(sample.data()), sizeof(T)*D); + metaStream_index.write((char *)&writepos, sizeof(std::uint64_t)); + metaStream_out.write(currentLine.c_str(), index); - writepos += index; - sampleSize += 1; - totalread += currentLine.length() + enter_size; - } - metaStream_index.write((char *)&writepos, sizeof(long long)); - metaStream_index.write((char *)&sampleSize, sizeof(int)); - inputStream.close(); - outputStream.close(); - metaStream_out.close(); - metaStream_index.close(); - - numSamples.fetch_add(sampleSize); - - std::cout << "Finish Thread[" << id << ", " << sampleSize << "] at:" << (startpos + totalread) << std::endl; - } - - void MergeData(int threadbase, std::string outfile, std::string outmetafile, std::string outmetaindexfile, - std::atomic_int& numSamples, int D) { - std::ifstream inputStream; - std::ofstream outputStream; - char * buf = new char[bufsize]; - long long * offsets; - int partSamples; - int metaSamples = 0; - long long lastoff = 0; - - outputStream.open(outfile, std::ofstream::binary); - outputStream.write((char *)&numSamples, sizeof(int)); - outputStream.write((char *)&D, sizeof(int)); - for (int i = 0; i < threadbase; i++) { - std::string file = outfile + std::to_string(i); - inputStream.open(file, 
std::ifstream::binary); - while (!inputStream.eof()) { - inputStream.read(buf, bufsize); - outputStream.write(buf, inputStream.gcount()); + writepos += index; + sampleSize += 1; + totalread += currentLine.length() + enter_size; } + metaStream_index.write((char *)&writepos, sizeof(std::uint64_t)); + metaStream_index.write((char *)&sampleSize, sizeof(int)); inputStream.close(); - remove(file.c_str()); + outputStream.close(); + metaStream_out.close(); + metaStream_index.close(); + + numSamples.fetch_add(sampleSize); + + std::cout << "Finish Thread[" << id << ", " << sampleSize << "] at:" << (startpos + totalread) << std::endl; } - outputStream.close(); - - outputStream.open(outmetafile, std::ofstream::binary); - for (int i = 0; i < threadbase; i++) { - std::string file = outmetafile + std::to_string(i); - inputStream.open(file, std::ifstream::binary); - while (!inputStream.eof()) { - inputStream.read(buf, bufsize); - outputStream.write(buf, inputStream.gcount()); + + static void MergeData(int threadbase, std::string outfile, std::string outmetafile, std::string outmetaindexfile, + std::atomic_int& numSamples, int D) { + std::ifstream inputStream; + std::ofstream outputStream; + char * buf = new char[bufsize]; + std::uint64_t * offsets; + int partSamples; + int metaSamples = 0; + std::uint64_t lastoff = 0; + + outputStream.open(outfile, std::ofstream::binary); + outputStream.write((char *)&numSamples, sizeof(int)); + outputStream.write((char *)&D, sizeof(int)); + for (int i = 0; i < threadbase; i++) { + std::string file = outfile + std::to_string(i); + inputStream.open(file, std::ifstream::binary); + while (!inputStream.eof()) { + inputStream.read(buf, bufsize); + outputStream.write(buf, inputStream.gcount()); + } + inputStream.close(); + remove(file.c_str()); } - inputStream.close(); - remove(file.c_str()); - } - outputStream.close(); - delete[] buf; + outputStream.close(); + + outputStream.open(outmetafile, std::ofstream::binary); + for (int i = 0; i < threadbase; 
i++) { + std::string file = outmetafile + std::to_string(i); + inputStream.open(file, std::ifstream::binary); + while (!inputStream.eof()) { + inputStream.read(buf, bufsize); + outputStream.write(buf, inputStream.gcount()); + } + inputStream.close(); + remove(file.c_str()); + } + outputStream.close(); + delete[] buf; - outputStream.open(outmetaindexfile, std::ofstream::binary); - outputStream.write((char *)&numSamples, sizeof(int)); - for (int i = 0; i < threadbase; i++) { - std::string file = outmetaindexfile + std::to_string(i); - inputStream.open(file, std::ifstream::binary); + outputStream.open(outmetaindexfile, std::ofstream::binary); + outputStream.write((char *)&numSamples, sizeof(int)); + for (int i = 0; i < threadbase; i++) { + std::string file = outmetaindexfile + std::to_string(i); + inputStream.open(file, std::ifstream::binary); - inputStream.seekg(-((long long)sizeof(int)), inputStream.end); - inputStream.read((char *)&partSamples, sizeof(int)); - offsets = new long long[partSamples + 1]; + inputStream.seekg(-((long long)sizeof(int)), inputStream.end); + inputStream.read((char *)&partSamples, sizeof(int)); + offsets = new std::uint64_t[partSamples + 1]; - inputStream.seekg(0, inputStream.beg); - inputStream.read((char *)offsets, sizeof(long long)*(partSamples + 1)); - inputStream.close(); - remove(file.c_str()); + inputStream.seekg(0, inputStream.beg); + inputStream.read((char *)offsets, sizeof(std::uint64_t)*(partSamples + 1)); + inputStream.close(); + remove(file.c_str()); - for (int j = 0; j < partSamples + 1; j++) - offsets[j] += lastoff; - outputStream.write((char *)offsets, sizeof(long long)*partSamples); + for (int j = 0; j < partSamples + 1; j++) + offsets[j] += lastoff; + outputStream.write((char *)offsets, sizeof(std::uint64_t)*partSamples); + + lastoff = offsets[partSamples]; + metaSamples += partSamples; + delete[] offsets; + } + outputStream.write((char *)&lastoff, sizeof(std::uint64_t)); + outputStream.close(); - lastoff = 
offsets[partSamples]; - metaSamples += partSamples; - delete[] offsets; + std::cout << "numSamples:" << numSamples << " metaSamples:" << metaSamples << " D:" << D << std::endl; } - outputStream.write((char *)&lastoff, sizeof(long long)); - outputStream.close(); - std::cout << "numSamples:" << numSamples << " metaSamples:" << metaSamples << " D:" << D << std::endl; - } + static bool MergeIndex(const std::string& p_vectorfile1, const std::string& p_metafile1, const std::string& p_metaindexfile1, + const std::string& p_vectorfile2, const std::string& p_metafile2, const std::string& p_metaindexfile2) { + std::ifstream inputStream1, inputStream2; + std::ofstream outputStream; + char * buf = new char[bufsize]; + int R1, R2, C1, C2; + +#define MergeVector(inputStream, vectorFile, R, C) \ + inputStream.open(vectorFile, std::ifstream::binary); \ + if (!inputStream.is_open()) { \ + std::cout << "Cannot open vector file: " << vectorFile <<"!" << std::endl; \ + return false; \ + } \ + inputStream.read((char *)&(R), sizeof(int)); \ + inputStream.read((char *)&(C), sizeof(int)); \ + + MergeVector(inputStream1, p_vectorfile1, R1, C1) + MergeVector(inputStream2, p_vectorfile2, R2, C2) +#undef MergeVector + if (C1 != C2) { + inputStream1.close(); inputStream2.close(); + std::cout << "Vector dimensions are not the same!" 
<< std::endl; + return false; + } + R1 += R2; + outputStream.open(p_vectorfile1 + "_tmp", std::ofstream::binary); + outputStream.write((char *)&R1, sizeof(int)); + outputStream.write((char *)&C1, sizeof(int)); + while (!inputStream1.eof()) { + inputStream1.read(buf, bufsize); + outputStream.write(buf, inputStream1.gcount()); + } + while (!inputStream2.eof()) { + inputStream2.read(buf, bufsize); + outputStream.write(buf, inputStream2.gcount()); + } + inputStream1.close(); inputStream2.close(); + outputStream.close(); + + if (p_metafile1 != "" && p_metafile2 != "") { + outputStream.open(p_metafile1 + "_tmp", std::ofstream::binary); +#define MergeMeta(inputStream, metaFile) \ + inputStream.open(metaFile, std::ifstream::binary); \ + if (!inputStream.is_open()) { \ + std::cout << "Cannot open meta file: " << metaFile << "!" << std::endl; \ + return false; \ + } \ + while (!inputStream.eof()) { \ + inputStream.read(buf, bufsize); \ + outputStream.write(buf, inputStream.gcount()); \ + } \ + inputStream.close(); \ + + MergeMeta(inputStream1, p_metafile1) + MergeMeta(inputStream2, p_metafile2) +#undef MergeMeta + outputStream.close(); + delete[] buf; + + + std::uint64_t * offsets; + int partSamples; + std::uint64_t lastoff = 0; + outputStream.open(p_metaindexfile1 + "_tmp", std::ofstream::binary); + outputStream.write((char *)&R1, sizeof(int)); +#define MergeMetaIndex(inputStream, metaIndexFile) \ + inputStream.open(metaIndexFile, std::ifstream::binary); \ + if (!inputStream.is_open()) { \ + std::cout << "Cannot open meta index file: " << metaIndexFile << "!" 
<< std::endl; \ + return false; \ + } \ + inputStream.read((char *)&partSamples, sizeof(int)); \ + offsets = new std::uint64_t[partSamples + 1]; \ + inputStream.read((char *)offsets, sizeof(std::uint64_t)*(partSamples + 1)); \ + inputStream.close(); \ + for (int j = 0; j < partSamples + 1; j++) offsets[j] += lastoff; \ + outputStream.write((char *)offsets, sizeof(std::uint64_t)*partSamples); \ + lastoff = offsets[partSamples]; \ + delete[] offsets; \ + + MergeMetaIndex(inputStream1, p_metaindexfile1) + MergeMetaIndex(inputStream2, p_metaindexfile2) +#undef MergeMetaIndex + outputStream.write((char *)&lastoff, sizeof(std::uint64_t)); + outputStream.close(); + + rename((p_metafile1 + "_tmp").c_str(), p_metafile1.c_str()); + rename((p_metaindexfile1 + "_tmp").c_str(), p_metaindexfile1.c_str()); + } + rename((p_vectorfile1 + "_tmp").c_str(), p_vectorfile1.c_str()); + + std::cout << "Merged -> numSamples:" << R1 << " D:" << C1 << std::endl; + return true; + } - template - void ParseData(std::string filenames, std::string outfile, std::string outmetafile, std::string outmetaindexfile, - int threadnum, DistCalcMethod distCalcMethod) { - omp_set_num_threads(threadnum); + template + static void ParseData(std::string filenames, std::string outfile, std::string outmetafile, std::string outmetaindexfile, + int threadnum, DistCalcMethod distCalcMethod) { + omp_set_num_threads(threadnum); - std::atomic_int numSamples = { 0 }; - int D = -1; + std::atomic_int numSamples = { 0 }; + int D = -1; - int threadbase = 0; - std::vector inputFileNames = Helper::StrUtils::SplitString(filenames, ","); - for (std::string inputFileName : inputFileNames) - { + int threadbase = 0; + std::vector inputFileNames = Helper::StrUtils::SplitString(filenames, ","); + for (std::string inputFileName : inputFileNames) + { #ifndef _MSC_VER - struct stat stat_buf; - stat(inputFileName.c_str(), &stat_buf); + struct stat stat_buf; + stat(inputFileName.c_str(), &stat_buf); #else - struct _stat64 stat_buf; - int 
res = _stat64(inputFileName.c_str(), &stat_buf); + struct _stat64 stat_buf; + int res = _stat64(inputFileName.c_str(), &stat_buf); #endif - long long blocksize = (stat_buf.st_size + threadnum - 1) / threadnum; + std::uint64_t blocksize = (stat_buf.st_size + threadnum - 1) / threadnum; #pragma omp parallel for - for (int i = 0; i < threadnum; i++) { - ProcessTSVData(i, threadbase, blocksize, inputFileName, outfile, outmetafile, outmetaindexfile, numSamples, D, distCalcMethod); + for (int i = 0; i < threadnum; i++) { + ProcessTSVData(i, threadbase, blocksize, inputFileName, outfile, outmetafile, outmetaindexfile, numSamples, D, distCalcMethod); + } + threadbase += threadnum; } - threadbase += threadnum; + MergeData(threadbase, outfile, outmetafile, outmetaindexfile, numSamples, D); } - MergeData(threadbase, outfile, outmetafile, outmetaindexfile, numSamples, D); - } + }; } } diff --git a/AnnService/inc/Core/Common/Dataset.h b/AnnService/inc/Core/Common/Dataset.h index 516edda3..4753b088 100644 --- a/AnnService/inc/Core/Common/Dataset.h +++ b/AnnService/inc/Core/Common/Dataset.h @@ -12,108 +12,6 @@ namespace SPTAG { namespace COMMON { - template - class LRUCache - { - private: - struct Item - { - int idx; - T* data; - Item* next; - Item(int idx_, T* data_) : idx(idx_), data(data_), next(nullptr) {} - }; - - int rows; - int cols; - T* data; - std::ifstream fp; - std::unordered_map cache; - Item* head; - - public: - LRUCache(const char * filename_, int caches_ = 10000000) - { - fp.open(filename_, std::ifstream::binary); - fp.read((char *)&rows, sizeof(int)); - fp.read((char *)&cols, sizeof(int)); - if (caches_ > rows) caches_ = rows; - - data = (T*)aligned_malloc(sizeof(T) * caches_ * cols, ALIGN); - - int i = 0, batch = 10000; - while (i + batch < caches_) - { - fp.read((char *)(data + (size_t)i*cols), sizeof(T)*cols*batch); - i += batch; - } - fp.read((char *)(data + (size_t)i*cols), sizeof(T)*cols*(caches_ - i)); - - Item *p = head = new Item(-1, nullptr); - for (i = 
0; i < caches_; i++) - { - p->next = cache[(i + 1) % caches_] = new Item(i, data + (size_t)i*cols); - p = p->next; - } - p->next = head->next; - delete head; - head = p; - - std::cout << "Use LRUCache (" << caches_ << ")" << std::endl; - } - ~LRUCache() - { - fp.close(); - - Item *p = head->next, *q; - while (p != head) - { - q = p; - p = p->next; - delete q; - } - delete head; - - aligned_free(data); - } - int R() { return rows; } - int C() { return cols; } - T* get(int index) - { - auto iter = cache.find(index); - if (iter == cache.end()) - { - Item *p = head->next; - cache[index] = cache[p->idx]; - cache.erase(p->idx); - p->idx = index; - fp.seekg(sizeof(int) + sizeof(int) + index * sizeof(T) * cols, std::ios_base::beg); - fp.read((char *)p->data, sizeof(T)*cols); - head = p; - return p->data; - } - else - { - Item *p = iter->second, *q = p->next; - if (q == head || q == head->next) - { - head = q; - return q->data; - } - p->next = q->next; - cache[q->next->idx] = p; - - q->next = head->next; - head->next = q; - cache[index] = head; - cache[q->next->idx] = q; - - head = q; - return q->data; - } - } - }; - // structure to save Data and Graph template class Dataset @@ -123,43 +21,44 @@ namespace SPTAG int cols; bool ownData = false; T* data = nullptr; - LRUCache* cache = nullptr; std::vector* dataIncremental = nullptr; public: Dataset() {} - Dataset(int rows_, int cols_, T* data_ = nullptr, const char * filename_ = nullptr, int cachesize_ = 0) { Initialize(rows_, cols_, data_, filename_, cachesize_); } + Dataset(int rows_, int cols_, T* data_ = nullptr) + { + Initialize(rows_, cols_, data_); + } ~Dataset() { - if (cache != nullptr) delete cache; if (ownData) aligned_free(data); if (dataIncremental) { dataIncremental->clear(); delete dataIncremental; } } - void Initialize(int rows_, int cols_, T* data_ = nullptr, const char * filename_ = nullptr, int cachesize_ = 0) + void Initialize(int rows_, int cols_, T* data_ = nullptr) { - if (filename_ != nullptr) + rows = 
rows_; + cols = cols_; + data = data_; + if (data == nullptr) { - cache = new LRUCache(filename_, cachesize_); - rows = cache->R(); - cols = cache->C(); + ownData = true; + data = (T*)aligned_malloc(sizeof(T) * rows * cols, ALIGN); } - else + dataIncremental = new std::vector(); + } + void SetR(int R_) + { + if (R_ >= rows) + dataIncremental->resize((R_ - rows) * cols); + else { - rows = rows_; - cols = cols_; - data = data_; - if (data == nullptr) - { - ownData = true; - data = (T*)aligned_malloc(sizeof(T) * rows * cols, ALIGN); - } + rows = R_; + dataIncremental->clear(); } - dataIncremental = new std::vector(); } - void SetR(int R_) { rows = R_; } int R() const { return (int)(rows + dataIncremental->size() / cols); } int C() const { return cols; } T* operator[](int index) @@ -167,39 +66,39 @@ namespace SPTAG if (index >= rows) { return dataIncremental->data() + (size_t)(index - rows)*cols; } - if (cache != nullptr) - { - return cache->get(index); - } - else - { - return data + (size_t)index*cols; - } + return data + (size_t)index*cols; } const T* operator[](int index) const { if (index >= rows) { return dataIncremental->data() + (size_t)(index - rows)*cols; } - if (cache != nullptr) - { - return cache->get(index); - } - else - { - return data + (size_t)index*cols; - } + return data + (size_t)index*cols; } - T* GetData() { + + T* GetData() + { return data; } + void reset() + { + if (ownData) { + aligned_free(data); + ownData = false; + } + if (dataIncremental) { + dataIncremental->clear(); + delete dataIncremental; + } + } + void AddBatch(const T* pData, int num) { dataIncremental->insert(dataIncremental->end(), pData, pData + num*cols); } - void AddReserved(int num) + void AddBatch(int num) { dataIncremental->insert(dataIncremental->end(), (size_t)num*cols, T(-1)); } diff --git a/AnnService/inc/Core/Common/FineGrainedLock.h b/AnnService/inc/Core/Common/FineGrainedLock.h new file mode 100644 index 00000000..e1d5dc39 --- /dev/null +++ 
b/AnnService/inc/Core/Common/FineGrainedLock.h @@ -0,0 +1,48 @@ +#ifndef _SPTAG_COMMON_FINEGRAINEDLOCK_H_ +#define _SPTAG_COMMON_FINEGRAINEDLOCK_H_ + +#include +#include +#include + +namespace SPTAG +{ + namespace COMMON + { + class FineGrainedLock { + public: + FineGrainedLock() {} + ~FineGrainedLock() { + for (int i = 0; i < locks.size(); i++) + locks[i].reset(); + locks.clear(); + } + + void resize(int n) { + int current = (int)locks.size(); + if (current <= n) { + locks.resize(n); + for (int i = current; i < n; i++) + locks[i].reset(new std::mutex); + } + else { + for (int i = n; i < current; i++) + locks[i].reset(); + locks.resize(n); + } + } + + std::mutex& operator[](int idx) { + return *locks[idx]; + } + + const std::mutex& operator[](int idx) const { + return *locks[idx]; + } + private: + std::vector> locks; + }; + } +} + +#endif // _SPTAG_COMMON_FINEGRAINEDLOCK_H_ \ No newline at end of file diff --git a/AnnService/inc/Core/Common/Heap.h b/AnnService/inc/Core/Common/Heap.h index f35ddc46..7d4dcc56 100644 --- a/AnnService/inc/Core/Common/Heap.h +++ b/AnnService/inc/Core/Common/Heap.h @@ -25,7 +25,7 @@ namespace SPTAG inline int size() { return count; } inline bool empty() { return count == 0; } inline void clear() { count = 0; } - inline T& Top() { return heap[1]; } + inline T& Top() { if (count == 0) return heap[0]; else return heap[1]; } // Insert a new element in the heap. 
void insert(T value) diff --git a/AnnService/inc/Core/Common/QueryResultSet.h b/AnnService/inc/Core/Common/QueryResultSet.h index 9808414b..33dcf5c7 100644 --- a/AnnService/inc/Core/Common/QueryResultSet.h +++ b/AnnService/inc/Core/Common/QueryResultSet.h @@ -10,13 +10,13 @@ namespace COMMON inline bool operator < (const BasicResult& lhs, const BasicResult& rhs) { - return ((lhs.Dist < rhs.Dist) || ((lhs.Dist == rhs.Dist) && (lhs.Key < rhs.Key))); + return ((lhs.Dist < rhs.Dist) || ((lhs.Dist == rhs.Dist) && (lhs.VID < rhs.VID))); } inline bool Compare(BasicResult& lhs, BasicResult& rhs) { - return ((lhs.Dist < rhs.Dist) || ((lhs.Dist == rhs.Dist) && (lhs.Key < rhs.Key))); + return ((lhs.Dist < rhs.Dist) || ((lhs.Dist == rhs.Dist) && (lhs.VID < rhs.VID))); } @@ -50,9 +50,9 @@ class QueryResultSet : public QueryResult bool AddPoint(const int index, float dist) { - if (dist < m_results[0].Dist || (dist == m_results[0].Dist && index < m_results[0].Key)) + if (dist < m_results[0].Dist || (dist == m_results[0].Dist && index < m_results[0].VID)) { - m_results[0].Key = index; + m_results[0].VID = index; m_results[0].Dist = dist; Heapify(m_resultNum); return true; diff --git a/AnnService/inc/Core/Common/WorkSpace.h b/AnnService/inc/Core/Common/WorkSpace.h index 57c1b609..f17ff0bb 100644 --- a/AnnService/inc/Core/Common/WorkSpace.h +++ b/AnnService/inc/Core/Common/WorkSpace.h @@ -14,7 +14,7 @@ namespace SPTAG int node; float distance; - HeapCell(int _node = -1, float _distance = 0) : node(_node), distance(_distance) {} + HeapCell(int _node = -1, float _distance = MaxDist) : node(_node), distance(_distance) {} inline bool operator < (const HeapCell& rhs) { @@ -203,7 +203,7 @@ namespace SPTAG return nodeCheckStatus.CheckAndSet(idx); } - CountVector nodeCheckStatus; + OptHashPosVector nodeCheckStatus; //OptHashPosVector nodeCheckStatus; // counter for dynamic pivoting diff --git a/AnnService/inc/Core/KDT/Index.h b/AnnService/inc/Core/KDT/Index.h index b84f5087..d21e76b0 100644 
--- a/AnnService/inc/Core/KDT/Index.h +++ b/AnnService/inc/Core/KDT/Index.h @@ -12,11 +12,12 @@ #include "../Common/Dataset.h" #include "../Common/WorkSpace.h" #include "../Common/WorkSpacePool.h" +#include "../Common/FineGrainedLock.h" +#include "../Common/DataUtils.h" #include -#include #include -#include +#include namespace SPTAG { @@ -45,7 +46,6 @@ namespace SPTAG int m_iDataSize; int m_iDataDimension; COMMON::Dataset m_pSamples; - std::shared_ptr m_pMetadata; // KDT structures. int m_iKDTNumber; @@ -82,7 +82,6 @@ namespace SPTAG char* m_pGraphMemoryFile; char* m_pDataPointsMemoryFile; - int m_iNumberOfThreads; DistCalcMethod m_iDistCalcMethod; float(*m_fComputeDistance)(const T* pX, const T* pY, int length); @@ -93,8 +92,11 @@ namespace SPTAG int g_iNumberOfInitialDynamicPivots; int g_iNumberOfOtherDynamicPivots; + int m_iNumberOfThreads; + std::mutex m_dataAllocLock; + COMMON::FineGrainedLock m_dataUpdateLock; + tbb::concurrent_unordered_set m_deletedID; std::unique_ptr m_workSpacePool; - public: Index() : m_iKDTNumber(1), m_numTopDimensionKDTSplit(5), @@ -130,89 +132,28 @@ namespace SPTAG int GetFeatureDim() const { return m_pSamples.C(); } int GetNumThreads() const { return m_iNumberOfThreads; } int GetCurrMaxCheck() const { return m_iMaxCheck; } + DistCalcMethod GetDistCalcMethod() const { return m_iDistCalcMethod; } IndexAlgoType GetIndexAlgoType() const { return IndexAlgoType::KDT; } - VectorValueType AcceptableQueryValueType() const { return GetEnumValueType(); } - void SetMetadata(const std::string& metadataFilePath, const std::string& metadataIndexPath) { - m_pMetadata.reset(new FileMetadataSet(metadataFilePath, metadataIndexPath)); - } - ByteArray GetMetadata(IndexType p_vectorID) const { - if (nullptr != m_pMetadata) - { - return m_pMetadata->GetMetadata(p_vectorID); - } - return ByteArray::c_empty; - } - - bool BuildIndex(); - bool BuildIndex(void* p_data, int p_vectorNum, int p_dimension); - ErrorCode BuildIndex(std::shared_ptr p_vectorSet, - 
std::shared_ptr p_metadataSet); + VectorValueType GetVectorValueType() const { return GetEnumValueType(); } - bool LoadIndex(); - ErrorCode LoadIndex(const std::string& p_folderPath, const Helper::IniReader& p_configReader); + ErrorCode BuildIndex(const void* p_data, int p_vectorNum, int p_dimension); + ErrorCode LoadIndex(const std::string& p_folderPath); + ErrorCode LoadIndexFromMemory(const std::vector& p_indexBlobs); - bool SaveIndex(); ErrorCode SaveIndex(const std::string& p_folderPath); - void SearchIndex(COMMON::QueryResultSet &query, COMMON::WorkSpace &space) const; - ErrorCode SearchIndex(QueryResult &query) const; - - void AddNodes(const T* pData, int num, COMMON::WorkSpace &space); + void SearchIndex(COMMON::QueryResultSet &p_query, COMMON::WorkSpace &p_space, const tbb::concurrent_unordered_set &p_deleted) const; + ErrorCode SearchIndex(QueryResult &p_query) const; + + ErrorCode AddIndex(const void* p_vectors, int p_vectorNum, int p_dimension); + ErrorCode DeleteIndex(const void* p_vectors, int p_vectorNum); + ErrorCode RefineIndex(const std::string& p_folderPath); + ErrorCode MergeIndex(const char* p_indexFilePath1, const char* p_indexFilePath2); ErrorCode SetParameter(const char* p_param, const char* p_value); std::string GetParameter(const char* p_param) const; - // This can be used for both building model files or searching with model files loaded. 
- void SetParameters(std::string dataPointsFile, - std::string KDTFile, - std::string graphFile, - int numKDT, - int neighborhoodSize, - int numTPTrees, - int TPTLeafSize, - int maxCheckForRefineGraph, - int numThreads, - DistCalcMethod distCalcMethod, - int cacheSize = -1, - int numPoints = -1) - { - m_sDataPointsFilename = dataPointsFile; - m_sKDTFilename = KDTFile; - m_sGraphFilename = graphFile; - m_iKDTNumber = numKDT; - m_iNeighborhoodSize = neighborhoodSize; - m_iTPTNumber = numTPTrees; - m_iTPTLeafSize = TPTLeafSize; - m_iMaxCheckForRefineGraph = maxCheckForRefineGraph; - m_iNumberOfThreads = numThreads; - m_iDistCalcMethod = distCalcMethod; - m_fComputeDistance = COMMON::DistanceCalcSelector(m_iDistCalcMethod); - - m_iCacheSize = cacheSize; - m_iDebugLoad = numPoints; - } - - // Only used for searching with memory mapped model files - void SetParameters(char* pDataPointsMemFile, - char* pKDTMemFile, - char* pGraphMemFile, - DistCalcMethod distCalcMethod, - int maxCheck, - int numKDT, - int neighborhoodSize) - { - m_pDataPointsMemoryFile = pDataPointsMemFile; - m_pKDTMemoryFile = pKDTMemFile; - m_pGraphMemoryFile = pGraphMemFile; - m_iMaxCheck = maxCheck; - m_iKDTNumber = numKDT; - m_iNeighborhoodSize = neighborhoodSize; - m_iNumberOfThreads = 1; - m_iDistCalcMethod = distCalcMethod; - m_fComputeDistance = COMMON::DistanceCalcSelector(m_iDistCalcMethod); - } - private: // Functions for loading models from files bool LoadDataPoints(std::string sDataPointsFileName); @@ -227,8 +168,8 @@ namespace SPTAG bool SaveDataPoints(std::string sDataPointsFileName); // Functions for building kdtree - void BuildKDT(); - bool SaveKDT(std::string sKDTFilename) const; + void BuildKDT(std::vector& indices, std::vector& newStart, std::vector& newRoot); + bool SaveKDT(std::string sKDTFilename, std::vector& newStart, std::vector& newRoot) const; void DivideTree(KDTNode* pTree, std::vector& indices,int first, int last, int index, int &iTreeSize); void ChooseDivision(KDTNode& 
node, const std::vector& indices, int first, int last); @@ -249,7 +190,7 @@ namespace SPTAG // Functions for hybrid search void KDTSearch(const int node, const bool isInit, const float distBound, - COMMON::WorkSpace& space, COMMON::QueryResultSet &query) const; + COMMON::WorkSpace& space, COMMON::QueryResultSet &query, const tbb::concurrent_unordered_set &deleted) const; }; } // namespace KDT } // namespace SPTAG diff --git a/AnnService/inc/Core/KDT/ParameterDefinitionList.h b/AnnService/inc/Core/KDT/ParameterDefinitionList.h index 4caa1ca5..932a525f 100644 --- a/AnnService/inc/Core/KDT/ParameterDefinitionList.h +++ b/AnnService/inc/Core/KDT/ParameterDefinitionList.h @@ -1,4 +1,4 @@ -#ifdef DefineParameter +#ifdef DefineKDTParameter // DefineKDTParameter(VarName, VarType, DefaultValue, RepresentStr) DefineKDTParameter(m_sKDTFilename, std::string, std::string("tree.bin"), "TreeFilePath") diff --git a/AnnService/inc/Core/MetadataSet.h b/AnnService/inc/Core/MetadataSet.h index c5449132..e9794893 100644 --- a/AnnService/inc/Core/MetadataSet.h +++ b/AnnService/inc/Core/MetadataSet.h @@ -22,7 +22,11 @@ class MetadataSet virtual bool Available() const = 0; - virtual ErrorCode SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile) const = 0; + virtual void AddBatch(MetadataSet& data) = 0; + + virtual ErrorCode SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile) = 0; + + static ErrorCode MetaCopy(const std::string& p_src, const std::string& p_dst); }; @@ -39,18 +43,22 @@ class FileMetadataSet : public MetadataSet bool Available() const; - virtual ErrorCode SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile) const; + void AddBatch(MetadataSet& data); + + ErrorCode SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile); private: std::ifstream* m_fp = nullptr; - const long long *m_pOffsets = nullptr; + std::vector m_pOffsets; - const int m_iCount = 0; + int m_count; 
std::string m_metaFile; std::string m_metaindexFile; + + std::vector m_newdata; }; @@ -67,38 +75,20 @@ class MemMetadataSet : public MetadataSet bool Available() const; - virtual ErrorCode SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile) const; - -private: - const std::uint64_t *m_offsets; + void AddBatch(MetadataSet& data); - ByteArray m_metadataHolder; + ErrorCode SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile); - ByteArray m_offsetHolder; +private: + std::vector m_offsets; SizeType m_count; -}; - -class MetadataSetFileTransfer : public MetadataSet -{ -public: - MetadataSetFileTransfer(const std::string& p_metaFile, const std::string& p_metaindexFile); - - virtual ~MetadataSetFileTransfer(); - - virtual ByteArray GetMetadata(IndexType p_vectorID) const; - - virtual SizeType Count() const; - - virtual bool Available() const; - - virtual ErrorCode SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile) const; + ByteArray m_metadataHolder; -private: - std::string m_metaFile; + ByteArray m_offsetHolder; - std::string m_metaindexFile; + std::vector m_newdata; }; diff --git a/AnnService/inc/Core/SearchQuery.h b/AnnService/inc/Core/SearchQuery.h index 7fd6806b..8b8c5f7b 100644 --- a/AnnService/inc/Core/SearchQuery.h +++ b/AnnService/inc/Core/SearchQuery.h @@ -8,19 +8,15 @@ namespace SPTAG { -struct BasicResult -{ - int Key; - float Dist; - - BasicResult() : Key(-1), Dist(MaxDist) + struct BasicResult { - } + int VID; + float Dist; - BasicResult(int p_key, float p_dist) : Key(p_key), Dist(p_dist) - { - } -}; + BasicResult() : VID(-1), Dist(MaxDist) {} + + BasicResult(int p_vid, float p_dist) : VID(p_vid), Dist(p_dist) {} + }; // Space to save temporary answer, similar with TopKCache @@ -46,6 +42,16 @@ class QueryResult Init(p_target, p_resultNum, p_withMeta); } + + QueryResult(const void* p_target, int p_resultNum, std::vector& p_results) + : m_target(p_target), + 
m_resultNum(p_resultNum), + m_withMeta(false) + { + p_results.resize(p_resultNum); + m_results.reset(p_results.data()); + } + QueryResult(const QueryResult& p_other) : m_target(p_other.m_target), @@ -130,11 +136,11 @@ class QueryResult } - inline void SetResult(int p_index, int p_key, float p_dist) + inline void SetResult(int p_index, int p_VID, float p_dist) { if (p_index < m_resultNum) { - m_results[p_index].Key = p_key; + m_results[p_index].VID = p_VID; m_results[p_index].Dist = p_dist; } } @@ -176,7 +182,7 @@ class QueryResult { for (int i = 0; i < m_resultNum; i++) { - m_results[i].Key = -1; + m_results[i].VID = -1; m_results[i].Dist = MaxDist; } diff --git a/AnnService/inc/Core/VectorIndex.h b/AnnService/inc/Core/VectorIndex.h index ababafd9..6f648d36 100644 --- a/AnnService/inc/Core/VectorIndex.h +++ b/AnnService/inc/Core/VectorIndex.h @@ -9,12 +9,6 @@ namespace SPTAG { -namespace Helper -{ -class IniReader; -} - - class VectorIndex { public: @@ -24,25 +18,48 @@ class VectorIndex virtual ErrorCode SaveIndex(const std::string& p_folderPath) = 0; - virtual ErrorCode LoadIndex(const std::string& p_folderPath, const Helper::IniReader& p_configReader) = 0; + virtual ErrorCode LoadIndex(const std::string& p_folderPath) = 0; - virtual ErrorCode BuildIndex(std::shared_ptr p_vectorSet, std::shared_ptr p_metadataSet) = 0; + virtual ErrorCode LoadIndexFromMemory(const std::vector& p_indexBlobs) = 0; + + virtual ErrorCode BuildIndex(const void* p_data, int p_vectorNum, int p_dimension) = 0; virtual ErrorCode SearchIndex(QueryResult& p_results) const = 0; - virtual std::string GetParameter(const char* p_param) const = 0; - virtual ErrorCode SetParameter(const char* p_param, const char* p_value) = 0; + virtual ErrorCode AddIndex(const void* p_vectors, int p_vectorNum, int p_dimension) = 0; - virtual std::string GetParameter(const std::string& p_param) const; - virtual ErrorCode SetParameter(const std::string& p_param, const std::string& p_value); + virtual ErrorCode 
DeleteIndex(const void* p_vectors, int p_vectorNum) = 0; + + virtual ErrorCode RefineIndex(const std::string& p_folderPath) = 0; - virtual ByteArray GetMetadata(IndexType p_vectorID) const = 0; + virtual ErrorCode MergeIndex(const char* p_indexFilePath1, const char* p_indexFilePath2) = 0; - virtual VectorValueType AcceptableQueryValueType() const = 0; + //virtual ErrorCode AddIndexWithID(const void* p_vector, const int& p_id) = 0; + + //virtual ErrorCode DeleteIndexWithID(const void* p_vector, const int& p_id) = 0; virtual int GetFeatureDim() const = 0; + virtual int GetNumSamples() const = 0; + virtual DistCalcMethod GetDistCalcMethod() const = 0; virtual IndexAlgoType GetIndexAlgoType() const = 0; + virtual VectorValueType GetVectorValueType() const = 0; + virtual int GetNumThreads() const = 0; + + virtual std::string GetParameter(const char* p_param) const = 0; + virtual ErrorCode SetParameter(const char* p_param, const char* p_value) = 0; + + virtual ErrorCode BuildIndex(std::shared_ptr p_vectorSet, std::shared_ptr p_metadataSet); + + virtual ErrorCode SearchIndex(const void* p_vector, int p_neighborCount, std::vector& p_results) const; + + virtual ErrorCode AddIndex(std::shared_ptr p_vectorSet, std::shared_ptr p_metadataSet); + + virtual std::string GetParameter(const std::string& p_param) const; + virtual ErrorCode SetParameter(const std::string& p_param, const std::string& p_value); + + virtual ByteArray GetMetadata(IndexType p_vectorID) const; + virtual void SetMetadata(const std::string& p_metadataFilePath, const std::string& p_metadataIndexPath); void SetIndexName(const std::string& p_indexName); @@ -52,8 +69,9 @@ class VectorIndex static ErrorCode LoadIndex(const std::string& p_loaderFilePath, std::shared_ptr& p_vectorIndex); -private: +protected: std::string m_indexName; + std::shared_ptr m_pMetadata; }; diff --git a/AnnService/inc/Core/VectorSet.h b/AnnService/inc/Core/VectorSet.h index a1c35e6b..09a6620a 100644 --- a/AnnService/inc/Core/VectorSet.h 
+++ b/AnnService/inc/Core/VectorSet.h @@ -13,15 +13,19 @@ class VectorSet virtual ~VectorSet(); + virtual VectorValueType GetValueType() const = 0; + virtual void* GetVector(IndexType p_vectorID) const = 0; virtual void* GetData() const = 0; - virtual VectorValueType ValueType() const = 0; - virtual SizeType Dimension() const = 0; virtual SizeType Count() const = 0; + + virtual bool Available() const = 0; + + virtual ErrorCode Save(const std::string& p_vectorFile) const = 0; }; @@ -35,16 +39,20 @@ class BasicVectorSet : public VectorSet virtual ~BasicVectorSet(); + virtual VectorValueType GetValueType() const; + virtual void* GetVector(IndexType p_vectorID) const; virtual void* GetData() const; - virtual VectorValueType ValueType() const; - virtual SizeType Dimension() const; virtual SizeType Count() const; + virtual bool Available() const; + + virtual ErrorCode Save(const std::string& p_vectorFile) const; + private: ByteArray m_data; diff --git a/AnnService/packages.config b/AnnService/packages.config index 2dbed9b5..424245f6 100644 --- a/AnnService/packages.config +++ b/AnnService/packages.config @@ -7,4 +7,6 @@ + + \ No newline at end of file diff --git a/AnnService/src/Client/main.cpp b/AnnService/src/Client/main.cpp index 2c35ab67..2b3a2f76 100644 --- a/AnnService/src/Client/main.cpp +++ b/AnnService/src/Client/main.cpp @@ -56,7 +56,7 @@ int main(int argc, char** argv) for (const auto& res : indexRes.m_results) { fprintf(stdout, "------------------\n"); - fprintf(stdout, "DocIndex: %d Distance: %f\n", res.Key, res.Dist); + fprintf(stdout, "DocIndex: %d Distance: %f\n", res.VID, res.Dist); if (indexRes.m_results.WithMeta()) { const auto& metadata = indexRes.m_results.GetMetadata(idx); diff --git a/AnnService/src/Core/BKT/BKTIndex.cpp b/AnnService/src/Core/BKT/BKTIndex.cpp index 1e2be638..8f9a1862 100644 --- a/AnnService/src/Core/BKT/BKTIndex.cpp +++ b/AnnService/src/Core/BKT/BKTIndex.cpp @@ -16,53 +16,96 @@ namespace SPTAG { #pragma region Load data points, 
kd-tree, neighborhood graph template - bool Index::LoadIndex() + ErrorCode Index::LoadIndexFromMemory(const std::vector& p_indexBlobs) { - bool loadedDataPoints = m_pDataPointsMemoryFile != NULL ? LoadDataPoints(m_pDataPointsMemoryFile) : LoadDataPoints(m_sDataPointsFilename); - if (!loadedDataPoints) return false; + if (!LoadDataPoints((char*)p_indexBlobs[0])) return ErrorCode::FailedParseValue; + if (!LoadBKT((char*)p_indexBlobs[1])) return ErrorCode::FailedParseValue; + if (!LoadGraph((char*)p_indexBlobs[2])) return ErrorCode::FailedParseValue; + return ErrorCode::Success; + } - m_iDataSize = m_pSamples.R(); - m_iDataDimension = m_pSamples.C(); + template + ErrorCode Index::LoadIndex(const std::string& p_folderPath) + { + std::string folderPath(p_folderPath); + if (!folderPath.empty() && *(folderPath.rbegin()) != FolderSep) + { + folderPath += FolderSep; + } + + Helper::IniReader p_configReader; + if (ErrorCode::Success != p_configReader.LoadIniFile(folderPath + "/indexloader.ini")) + { + return ErrorCode::FailedOpenFile; + } - bool loadedBKT = m_pBKTMemoryFile != NULL ? LoadBKT(m_pBKTMemoryFile) : LoadBKT(m_sBKTFilename); - if (!loadedBKT) return false; + std::string metadataSection("MetaData"); + if (p_configReader.DoesSectionExist(metadataSection)) + { + std::string metadataFilePath = p_configReader.GetParameter(metadataSection, + "MetaDataFilePath", + std::string()); + std::string metadataIndexFilePath = p_configReader.GetParameter(metadataSection, + "MetaDataIndexPath", + std::string()); - bool loadedGraph = m_pGraphMemoryFile != NULL ? LoadGraph(m_pGraphMemoryFile) : LoadGraph(m_sGraphFilename); - if (!loadedGraph) return false; + m_pMetadata.reset(new FileMetadataSet(folderPath + metadataFilePath, folderPath + metadataIndexFilePath)); - if (m_iRefineIter > 0) { - for (int i = 0; i < m_iRefineIter; i++) RefineRNG(); - SaveRNG(m_sGraphFilename); + if (!m_pMetadata->Available()) + { + std::cerr << "Error: Failed to load metadata." 
<< std::endl; + return ErrorCode::Fail; + } } - return true; + +#define DefineBKTParameter(VarName, VarType, DefaultValue, RepresentStr) \ + SetParameter(RepresentStr, \ + p_configReader.GetParameter("Index", \ + RepresentStr, \ + std::string(#DefaultValue)).c_str()); \ + +#include "inc/Core/BKT/ParameterDefinitionList.h" +#undef DefineBKTParameter + + if (DistCalcMethod::Undefined == m_iDistCalcMethod) + { + return ErrorCode::Fail; + } + + if (!LoadDataPoints(folderPath + m_sDataPointsFilename)) return ErrorCode::Fail; + if (!LoadBKT(folderPath + m_sBKTFilename)) return ErrorCode::Fail; + if (!LoadGraph(folderPath + m_sGraphFilename)) return ErrorCode::Fail; + + m_iDataSize = m_pSamples.R(); + m_iDataDimension = m_pSamples.C(); + m_dataUpdateLock.resize(m_iDataSize); + + m_workSpacePool.reset(new COMMON::WorkSpacePool(m_iMaxCheck, GetNumSamples())); + m_workSpacePool->Init(m_iNumberOfThreads); + return ErrorCode::Success; } template bool Index::LoadDataPoints(std::string sDataPointsFileName) { std::cout << "Load Data Points From " << sDataPointsFileName << std::endl; - if (m_iCacheSize >= 0) - m_pSamples.Initialize(0, 0, nullptr, sDataPointsFileName.c_str(), m_iCacheSize); - else - { - FILE * fp = fopen(sDataPointsFileName.c_str(), "rb"); - if (fp == NULL) return false; + FILE * fp = fopen(sDataPointsFileName.c_str(), "rb"); + if (fp == NULL) return false; - int R, C; - fread(&R, sizeof(int), 1, fp); - fread(&C, sizeof(int), 1, fp); + int R, C; + fread(&R, sizeof(int), 1, fp); + fread(&C, sizeof(int), 1, fp); - if (m_iDebugLoad > 0 && R > m_iDebugLoad) R = m_iDebugLoad; + if (m_iDebugLoad > 0 && R > m_iDebugLoad) R = m_iDebugLoad; - m_pSamples.Initialize(R, C); - int i = 0, batch = 10000; - while (i + batch < R) { - fread((m_pSamples)[i], sizeof(T), C * batch, fp); - i += batch; - } - fread((m_pSamples)[i], sizeof(T), C * (R - i), fp); - fclose(fp); + m_pSamples.Initialize(R, C); + int i = 0, batch = 10000; + while (i + batch < R) { + fread((m_pSamples)[i], 
sizeof(T), C * batch, fp); + i += batch; } + fread((m_pSamples)[i], sizeof(T), C * (R - i), fp); + fclose(fp); std::cout << "Load Data Points (" << m_pSamples.R() << ", " << m_pSamples.C() << ") Finish!" << std::endl; return true; } @@ -181,64 +224,68 @@ namespace SPTAG #pragma region K-NN search template - void Index::SearchIndex(COMMON::QueryResultSet &query, COMMON::WorkSpace &space) const + void Index::SearchIndex(COMMON::QueryResultSet &p_query, COMMON::WorkSpace &p_space, const tbb::concurrent_unordered_set &p_deleted) const { for (char i = 0; i < m_iBKTNumber; i++) { const BKTNode& node = m_pBKTRoots[m_pBKTStart[i]]; if (node.childStart < 0) { - space.m_SPTQueue.insert(COMMON::HeapCell(m_pBKTStart[i], m_fComputeDistance(query.GetTarget(), (m_pSamples)[node.centerid], m_iDataDimension))); + p_space.m_SPTQueue.insert(COMMON::HeapCell(m_pBKTStart[i], m_fComputeDistance(p_query.GetTarget(), (m_pSamples)[node.centerid], m_iDataDimension))); } else { for (int begin = node.childStart; begin < node.childEnd; begin++) { int index = m_pBKTRoots[begin].centerid; - space.m_SPTQueue.insert(COMMON::HeapCell(begin, m_fComputeDistance(query.GetTarget(), (m_pSamples)[index], m_iDataDimension))); + p_space.m_SPTQueue.insert(COMMON::HeapCell(begin, m_fComputeDistance(p_query.GetTarget(), (m_pSamples)[index], m_iDataDimension))); } } } int checkLimit = g_iNumberOfInitialDynamicPivots; const int checkPos = m_iNeighborhoodSize - 1; - while (!space.m_SPTQueue.empty()) { + while (!p_space.m_SPTQueue.empty()) { do { - COMMON::HeapCell bcell = space.m_SPTQueue.pop(); + COMMON::HeapCell bcell = p_space.m_SPTQueue.pop(); const BKTNode& tnode = m_pBKTRoots[bcell.node]; if (tnode.childStart < 0) { - if (!space.CheckAndSet(tnode.centerid)) { - space.m_iNumberOfCheckedLeaves++; - space.m_NGQueue.insert(COMMON::HeapCell(tnode.centerid, bcell.distance)); + if (!p_space.CheckAndSet(tnode.centerid)) { + p_space.m_iNumberOfCheckedLeaves++; + 
p_space.m_NGQueue.insert(COMMON::HeapCell(tnode.centerid, bcell.distance)); } - if (space.m_iNumberOfCheckedLeaves >= checkLimit) break; + if (p_space.m_iNumberOfCheckedLeaves >= checkLimit) break; } else { - if (!space.CheckAndSet(tnode.centerid)) { - space.m_NGQueue.insert(COMMON::HeapCell(tnode.centerid, bcell.distance)); + if (!p_space.CheckAndSet(tnode.centerid)) { + p_space.m_NGQueue.insert(COMMON::HeapCell(tnode.centerid, bcell.distance)); } for (int begin = tnode.childStart; begin < tnode.childEnd; begin++) { int index = m_pBKTRoots[begin].centerid; - space.m_SPTQueue.insert(COMMON::HeapCell(begin, m_fComputeDistance(query.GetTarget(), (m_pSamples)[index], m_iDataDimension))); + p_space.m_SPTQueue.insert(COMMON::HeapCell(begin, m_fComputeDistance(p_query.GetTarget(), (m_pSamples)[index], m_iDataDimension))); } } - } while (!space.m_SPTQueue.empty()); - while (!space.m_NGQueue.empty()) { - COMMON::HeapCell gnode = space.m_NGQueue.pop(); + } while (!p_space.m_SPTQueue.empty()); + while (!p_space.m_NGQueue.empty()) { + COMMON::HeapCell gnode = p_space.m_NGQueue.pop(); const int *node = (m_pNeighborhoodGraph)[gnode.node]; _mm_prefetch((const char *)node, _MM_HINT_T0); - if (query.AddPoint(gnode.node, gnode.distance)) { - space.m_iNumOfContinuousNoBetterPropagation = 0; - - int checkNode = node[checkPos]; - if (checkNode < -1) { - const BKTNode& tnode = m_pBKTRoots[-2 - checkNode]; - for (int i = -tnode.childStart; i < tnode.childEnd; i++) { - if (!query.AddPoint(m_pBKTRoots[i].centerid, gnode.distance)) break; + if (p_deleted.find(gnode.node) == p_deleted.end()) { + if (p_query.AddPoint(gnode.node, gnode.distance)) { + p_space.m_iNumOfContinuousNoBetterPropagation = 0; + + int checkNode = node[checkPos]; + if (checkNode < -1) { + const BKTNode& tnode = m_pBKTRoots[-2 - checkNode]; + for (int i = -tnode.childStart; i < tnode.childEnd; i++) { + if (p_deleted.find(m_pBKTRoots[i].centerid) == p_deleted.end()) { + if (!p_query.AddPoint(m_pBKTRoots[i].centerid, 
gnode.distance)) break; + } + } } } - } - else { - space.m_iNumOfContinuousNoBetterPropagation++; - if (space.m_iNumOfContinuousNoBetterPropagation > space.m_iContinuousLimit || space.m_iNumberOfCheckedLeaves > space.m_iMaxCheck) { - query.SortResult(); return; + else { + p_space.m_iNumOfContinuousNoBetterPropagation++; + if (p_space.m_iNumOfContinuousNoBetterPropagation > p_space.m_iContinuousLimit || p_space.m_iNumberOfCheckedLeaves > p_space.m_iMaxCheck) { + p_query.SortResult(); return; + } } } @@ -254,103 +301,70 @@ namespace SPTAG // do not check it if it has been checked if (nn_index < 0) break; - if (space.CheckAndSet(nn_index)) continue; + if (p_space.CheckAndSet(nn_index)) continue; // count the number of the computed nodes - float distance2leaf = m_fComputeDistance(query.GetTarget(), (m_pSamples)[nn_index], m_iDataDimension); - space.m_iNumberOfCheckedLeaves++; - space.m_NGQueue.insert(COMMON::HeapCell(nn_index, distance2leaf)); + float distance2leaf = m_fComputeDistance(p_query.GetTarget(), (m_pSamples)[nn_index], m_iDataDimension); + p_space.m_iNumberOfCheckedLeaves++; + p_space.m_NGQueue.insert(COMMON::HeapCell(nn_index, distance2leaf)); } - if (space.m_NGQueue.Top().distance >= space.m_SPTQueue.Top().distance) { - checkLimit = g_iNumberOfOtherDynamicPivots + space.m_iNumberOfCheckedLeaves; + if (p_space.m_NGQueue.Top().distance > p_space.m_SPTQueue.Top().distance) { + checkLimit = g_iNumberOfOtherDynamicPivots + p_space.m_iNumberOfCheckedLeaves; break; } } } + p_query.SortResult(); } template ErrorCode - Index::SearchIndex(QueryResult &query) const + Index::SearchIndex(QueryResult &p_query) const { auto workSpace = m_workSpacePool->Rent(); workSpace->Reset(m_iMaxCheck); - SearchIndex(*((COMMON::QueryResultSet*)&query), *workSpace); + SearchIndex(*((COMMON::QueryResultSet*)&p_query), *workSpace, m_deletedID); m_workSpacePool->Return(workSpace); - if (query.WithMeta() && nullptr != m_pMetadata) + if (p_query.WithMeta() && nullptr != m_pMetadata) { - 
for (int i = 0; i < query.GetResultNum(); ++i) + for (int i = 0; i < p_query.GetResultNum(); ++i) { - query.SetMetadata(i, m_pMetadata->GetMetadata(query.GetResult(i)->Key)); + for (int i = 0; i < p_query.GetResultNum(); ++i) + { + int result = p_query.GetResult(i)->VID; + p_query.SetMetadata(i, (result < 0) ? ByteArray::c_empty : m_pMetadata->GetMetadata(result)); + } } } return ErrorCode::Success; } - #pragma endregion #pragma region Build/Save kd-tree & neighborhood graphs template - bool Index::BuildIndex() - { - if (!LoadDataPoints(m_sDataPointsFilename.c_str())) return false; - m_iDataSize = m_pSamples.R(); - m_iDataDimension = m_pSamples.C(); - - BuildBKT(); - if (!SaveBKT(m_sBKTFilename)) return false; - - BuildRNG(); - if (!SaveRNG(m_sGraphFilename)) return false; - return true; - } - - template - bool Index::BuildIndex(void* p_data, int p_vectorNum, int p_dimension) + ErrorCode Index::BuildIndex(const void* p_data, int p_vectorNum, int p_dimension) { - m_pSamples.Initialize(p_vectorNum, p_dimension, reinterpret_cast(p_data)); + m_pSamples.Initialize(p_vectorNum, p_dimension); + std::memcpy(m_pSamples.GetData(), p_data, p_vectorNum * p_dimension * sizeof(T)); m_iDataSize = m_pSamples.R(); m_iDataDimension = m_pSamples.C(); - - BuildBKT(); - BuildRNG(); - return true; - } - - template - ErrorCode Index::BuildIndex(std::shared_ptr p_vectorSet, - std::shared_ptr p_metadataSet) - { - if (nullptr == p_vectorSet || p_vectorSet->Count() == 0 || p_vectorSet->Dimension() == 0 || p_vectorSet->ValueType() != GetEnumValueType()) - { - return ErrorCode::Fail; - } + m_dataUpdateLock.resize(m_iDataSize); if (DistCalcMethod::Cosine == m_iDistCalcMethod) { - m_pSamples.Initialize(p_vectorSet->Count(), p_vectorSet->Dimension()); - std::memcpy(m_pSamples.GetData(), p_vectorSet->GetData(), p_vectorSet->Count() * p_vectorSet->Dimension() * sizeof(T)); - int base = COMMON::Utils::GetBase(); - for (SPTAG::SizeType i = 0; i < p_vectorSet->Count(); i++) { - 
COMMON::Utils::Normalize(m_pSamples[i], m_pSamples.C(), base); + for (int i = 0; i < m_iDataSize; i++) { + COMMON::Utils::Normalize(m_pSamples[i], m_iDataDimension, base); } - p_vectorSet.reset(); - } else { - m_pSamples.Initialize(p_vectorSet->Count(), p_vectorSet->Dimension(), reinterpret_cast(p_vectorSet->GetData())); } - - m_iDataSize = m_pSamples.R(); - m_iDataDimension = m_pSamples.C(); - - BuildBKT(); + std::vector indices(m_iDataSize); + for (int i = 0; i < m_iDataSize; i++) indices[i] = i; + BuildBKT(indices, m_pBKTStart, m_pBKTRoots); BuildRNG(); - m_pMetadata = std::move(p_metadataSet); - m_workSpacePool.reset(new COMMON::WorkSpacePool(m_iMaxCheck, GetNumSamples())); m_workSpacePool->Init(m_iNumberOfThreads); return ErrorCode::Success; @@ -358,19 +372,19 @@ namespace SPTAG #pragma region Build/Save kd-tree template - bool Index::SaveBKT(std::string sBKTFilename) const + bool Index::SaveBKT(std::string sBKTFilename, std::vector& newStart, std::vector& newRoot) const { std::cout << "Save BKT to " << sBKTFilename << std::endl; FILE *fp = fopen(sBKTFilename.c_str(), "wb"); if(fp == NULL) return false; fwrite(&m_iBKTNumber, sizeof(int), 1, fp); - fwrite(m_pBKTStart.data(), sizeof(int), m_iBKTNumber, fp); - int treeNodeSize = (int)m_pBKTRoots.size(); + fwrite(newStart.data(), sizeof(int), m_iBKTNumber, fp); + int treeNodeSize = (int)newRoot.size(); fwrite(&treeNodeSize, sizeof(int), 1, fp); for (int i = 0; i < treeNodeSize; i++) { - fwrite(&(m_pBKTRoots[i].centerid), sizeof(int), 1, fp); - fwrite(&(m_pBKTRoots[i].childStart), sizeof(int), 1, fp); - fwrite(&(m_pBKTRoots[i].childEnd), sizeof(int), 1, fp); + fwrite(&(newRoot[i].centerid), sizeof(int), 1, fp); + fwrite(&(newRoot[i].childStart), sizeof(int), 1, fp); + fwrite(&(newRoot[i].childEnd), sizeof(int), 1, fp); } fclose(fp); std::cout << "Save BKT Finish!" 
<< std::endl; @@ -378,69 +392,61 @@ namespace SPTAG } template - void Index::BuildBKT() + void Index::BuildBKT(std::vector& indices, std::vector& newStart, std::vector& newRoot) { omp_set_num_threads(m_iNumberOfThreads); struct BKTStackItem { int index, first, last; BKTStackItem(int index_, int first_, int last_) : index(index_), first(first_), last(last_) {} }; - std::stack ss; - KmeansArgs args(m_iBKTKmeansK, m_iDataDimension, m_iDataSize, m_iNumberOfThreads); - std::vector indices(m_iDataSize); - for (int i = 0; i < m_iDataSize; i++) indices[i] = i; + KmeansArgs args(m_iBKTKmeansK, m_iDataDimension, (int)indices.size(), m_iNumberOfThreads); + m_pSampleToCenter.clear(); for (char i = 0; i < m_iBKTNumber; i++) { std::random_shuffle(indices.begin(), indices.end()); - m_pBKTStart.push_back((int)m_pBKTRoots.size()); - m_pBKTRoots.push_back(BKTNode(m_iDataSize)); + newStart.push_back((int)newRoot.size()); + newRoot.push_back(BKTNode((int)indices.size())); std::cout << "Start to build tree " << i + 1 << std::endl; - ss.push(BKTStackItem(m_pBKTStart[i], 0, m_iDataSize)); + ss.push(BKTStackItem(newStart[i], 0, (int)indices.size())); while (!ss.empty()) { BKTStackItem item = ss.top(); ss.pop(); - int newBKTid = (int)m_pBKTRoots.size(); - m_pBKTRoots[item.index].childStart = newBKTid; + int newBKTid = (int)newRoot.size(); + newRoot[item.index].childStart = newBKTid; if (item.last - item.first <= m_iBKTLeafSize) { - if (item.last == item.first) { - m_pBKTRoots[item.index].childStart = -m_pBKTRoots[item.index].childStart; - } - else { - for (int j = item.first; j < item.last; j++) { - m_pBKTRoots.push_back(BKTNode(indices[j])); - } + for (int j = item.first; j < item.last; j++) { + newRoot.push_back(BKTNode(indices[j])); } } else { // clustering the data into BKTKmeansK clusters int numClusters = KmeansClustering(indices, item.first, item.last, args); if (numClusters <= 1) { - int end = min(item.last + 1, m_iDataSize); + int end = min(item.last + 1, (int)indices.size()); 
std::sort(indices.begin() + item.first, indices.begin() + end); - m_pBKTRoots[item.index].centerid = indices[item.first]; - m_pBKTRoots[item.index].childStart = -m_pBKTRoots[item.index].childStart; + newRoot[item.index].centerid = indices[item.first]; + newRoot[item.index].childStart = -newRoot[item.index].childStart; for (int j = item.first + 1; j < end; j++) { - m_pBKTRoots.push_back(BKTNode(indices[j])); - m_pSampleToCenter[indices[j]] = m_pBKTRoots[item.index].centerid; + newRoot.push_back(BKTNode(indices[j])); + m_pSampleToCenter[indices[j]] = newRoot[item.index].centerid; } - m_pSampleToCenter[-1 - m_pBKTRoots[item.index].centerid] = item.index; + m_pSampleToCenter[-1 - newRoot[item.index].centerid] = item.index; } else { for (int k = 0; k < m_iBKTKmeansK; k++) { - if (args.counts[k] > 0) { - m_pBKTRoots.push_back(BKTNode(indices[item.first + args.counts[k] - 1])); - ss.push(BKTStackItem(newBKTid++, item.first, item.first + args.counts[k] - 1)); - item.first += args.counts[k]; - } + if (args.counts[k] == 0) continue; + newRoot.push_back(BKTNode(indices[item.first + args.counts[k] - 1])); + if (args.counts[k] > 1) ss.push(BKTStackItem(newBKTid++, item.first, item.first + args.counts[k] - 1)); + item.first += args.counts[k]; } } } - m_pBKTRoots[item.index].childEnd = (int)m_pBKTRoots.size(); + newRoot[item.index].childEnd = (int)newRoot.size(); } - std::cout << i + 1 << " trees built, " << m_pBKTRoots.size() - m_pBKTStart[i] << " " << m_iDataSize << std::endl; + std::cout << i + 1 << " trees built, " << newRoot.size() - newStart[i] << " " << indices.size() << std::endl; } } @@ -678,7 +684,7 @@ namespace SPTAG float bestvariance = Variance[m_iDataDimension - 1].Dist; for (int i = 0; i < m_numTopDimensionTpTreeSplit; i++) { - index[i] = Variance[m_iDataDimension - 1 - i].Key; + index[i] = Variance[m_iDataDimension - 1 - i].VID; bestweight[i] = 0; } bestweight[0] = 1; @@ -788,6 +794,18 @@ namespace SPTAG m_iGraphSize = m_iDataSize; 
m_pNeighborhoodGraph.Initialize(m_iGraphSize, m_iNeighborhoodSize); + if (m_iGraphSize < 1000) { + std::memset(m_pNeighborhoodGraph.GetData(), -1, m_iGraphSize * m_iNeighborhoodSize * sizeof(int)); + m_iNeighborhoodSize /= graphScale; + RefineRNG(); + for (int i = 0; i < m_iGraphSize; i++) { + if (m_pSampleToCenter.find(-1 - i) != m_pSampleToCenter.end()) + m_pNeighborhoodGraph[i][m_iNeighborhoodSize - 1] = -2 - m_pSampleToCenter[-1 - i]; + } + std::cout << "Build RNG Graph end!" << std::endl; + return; + } + { COMMON::Dataset NeighborhoodDists(m_iGraphSize, m_iNeighborhoodSize); std::vector> TptreeDataIndices(m_iTptreeNumber, std::vector(m_iGraphSize)); @@ -894,8 +912,8 @@ namespace SPTAG #pragma omp parallel for schedule(dynamic) for (int i = 0; i < NSample; i++) { - //int x = Utils::rand_int(m_iGraphSize); - int x = i; + int x = COMMON::Utils::rand_int(m_iGraphSize); + //int x = i; COMMON::QueryResultSet query((m_pSamples)[x], m_iCEF); for (int y = 0; y < m_iGraphSize; y++) { @@ -910,7 +928,7 @@ namespace SPTAG } else { for (int j = 0; j < m_iNeighborhoodSize && j < m_iCEF; j++) { - exact_rng[j] = query.GetResult(j)->Key; + exact_rng[j] = query.GetResult(j)->VID; } for (int j = m_iCEF; j < m_iNeighborhoodSize; j++) exact_rng[j] = -1; } @@ -936,18 +954,209 @@ namespace SPTAG } template - void Index::AddNodes(const T* pData, int num, COMMON::WorkSpace &space) { - m_pSamples.AddBatch(pData, num); - m_pNeighborhoodGraph.AddReserved(num); + ErrorCode Index::RefineIndex(const std::string& p_folderPath) + { + std::string folderPath(p_folderPath); + if (!folderPath.empty() && *(folderPath.rbegin()) != FolderSep) + { + folderPath += FolderSep; + } + + if (!direxists(folderPath.c_str())) + { + mkdir(folderPath.c_str()); + } + tbb::concurrent_unordered_set deleted(m_deletedID.begin(), m_deletedID.end()); + std::vector indices; + std::unordered_map old2new; + int newR = m_iDataSize; + for (int i = 0; i < newR; i++) { + if (deleted.find(i) == deleted.end()) { + 
indices.push_back(i); + old2new[i] = i; + } + else { + while (deleted.find(newR - 1) != deleted.end() && newR > i) newR--; + if (newR == i) break; + indices.push_back(newR - 1); + old2new[newR - 1] = i; + newR--; + } + } + old2new[-1] = -1; + + std::cout << "Refine... from " << m_iDataSize << "->" << newR << std::endl; + std::ofstream vecOut(folderPath + m_sDataPointsFilename, std::ios::binary); + if (!vecOut.is_open()) return ErrorCode::FailedCreateFile; + vecOut.write((char*)&newR, sizeof(int)); + vecOut.write((char*)&m_iDataDimension, sizeof(int)); + for (int i = 0; i < newR; i++) { + vecOut.write((char*)m_pSamples[indices[i]], sizeof(T)*m_iDataDimension); + } + vecOut.close(); + + if (nullptr != m_pMetadata) + { + std::ofstream metaOut(folderPath + "metadata.bin_tmp", std::ios::binary); + std::ofstream metaIndexOut(folderPath + "metadataIndex.bin", std::ios::binary); + if (!metaOut.is_open() || !metaIndexOut.is_open()) return ErrorCode::FailedCreateFile; + metaIndexOut.write((char*)&newR, sizeof(int)); + std::uint64_t offset = 0; + for (int i = 0; i < newR; i++) { + metaIndexOut.write((char*)&offset, sizeof(std::uint64_t)); + ByteArray meta = m_pMetadata->GetMetadata(indices[i]); + metaOut.write((char*)meta.Data(), sizeof(uint8_t)*meta.Length()); + offset += meta.Length(); + } + metaOut.close(); + metaIndexOut.write((char*)&offset, sizeof(std::uint64_t)); + metaIndexOut.close(); + + SPTAG::MetadataSet::MetaCopy(folderPath + "metadata.bin_tmp", folderPath + "metadata.bin"); + } + + std::vector newRoot; + std::vector newStart; + std::vector tmpindices(indices.begin(), indices.end()); + BuildBKT(tmpindices, newStart, newRoot); +#pragma omp parallel for + for (int i = 0; i < newRoot.size(); i++) { + newRoot[i].centerid = old2new[newRoot[i].centerid]; + } + SaveBKT(folderPath + m_sBKTFilename, newStart, newRoot); + + std::ofstream graphOut(folderPath + m_sGraphFilename, std::ios::binary); + if (!graphOut.is_open()) return ErrorCode::FailedCreateFile; + 
graphOut.write((char*)&newR, sizeof(int)); + graphOut.write((char*)&m_iNeighborhoodSize, sizeof(int)); + + int *neighbors = new int[m_iNeighborhoodSize]; + COMMON::WorkSpace space; + space.Initialize(m_iMaxCheckForRefineGraph, m_iDataSize); + for (int i = 0; i < newR; i++) { + space.Reset(m_iMaxCheckForRefineGraph); + COMMON::QueryResultSet query((m_pSamples)[indices[i]], m_iCEF); + space.CheckAndSet(indices[i]); + for (int j = 0; j < m_iNeighborhoodSize; j++) { + int index = m_pNeighborhoodGraph[indices[i]][j]; + if (index < 0 || space.CheckAndSet(index)) continue; + space.m_NGQueue.insert(COMMON::HeapCell(index, m_fComputeDistance(query.GetTarget(), m_pSamples[index], m_iDataDimension))); + } + SearchIndex(query, space, deleted); + RebuildRNGNodeNeighbors(neighbors, query.GetResults(), m_iCEF); + for (int j = 0; j < m_iNeighborhoodSize; j++) + neighbors[j] = old2new[neighbors[j]]; + if (m_pSampleToCenter.find(-1 - indices[i]) != m_pSampleToCenter.end()) { + neighbors[m_iNeighborhoodSize - 1] = -2 - m_pSampleToCenter[-1 - indices[i]]; + } + graphOut.write((char*)neighbors, sizeof(int) * m_iNeighborhoodSize); + } + delete[]neighbors; + graphOut.close(); + + return ErrorCode::Success; + } + + template + ErrorCode Index::MergeIndex(const char* p_indexFilePath1, const char* p_indexFilePath2) { + std::string folderPath1(p_indexFilePath1), folderPath2(p_indexFilePath2); + if (!folderPath1.empty() && *(folderPath1.rbegin()) != FolderSep) folderPath1 += FolderSep; + if (!folderPath2.empty() && *(folderPath2.rbegin()) != FolderSep) folderPath2 += FolderSep; + + Helper::IniReader p_configReader1, p_configReader2; + if (ErrorCode::Success != p_configReader1.LoadIniFile(folderPath1 + "/indexloader.ini")) + return ErrorCode::FailedOpenFile; + + if (ErrorCode::Success != p_configReader2.LoadIniFile(folderPath2 + "/indexloader.ini")) + return ErrorCode::FailedOpenFile; + + std::string empty(""); + if (!COMMON::DataUtils::MergeIndex(folderPath1 + 
p_configReader1.GetParameter("Index", "VectorFilePath", empty), + folderPath1 + p_configReader1.GetParameter("MetaData", "MetaDataFilePath", empty), + folderPath1 + p_configReader1.GetParameter("MetaData", "MetaDataIndexPath", empty), + folderPath2 + p_configReader1.GetParameter("Index", "VectorFilePath", empty), + folderPath2 + p_configReader1.GetParameter("MetaData", "MetaDataFilePath", empty), + folderPath2 + p_configReader1.GetParameter("MetaData", "MetaDataIndexPath", empty))) + return ErrorCode::Fail; + +#define DefineBKTParameter(VarName, VarType, DefaultValue, RepresentStr) \ + SetParameter(RepresentStr, \ + p_configReader1.GetParameter("Index", \ + RepresentStr, \ + std::string(#DefaultValue)).c_str()); \ + +#include "inc/Core/BKT/ParameterDefinitionList.h" +#undef DefineBKTParameter + + if (!LoadDataPoints(folderPath1 + p_configReader1.GetParameter("Index", "VectorFilePath", empty))) return ErrorCode::FailedOpenFile; + std::vector indices(m_iDataSize); + for (int i = 0; i < m_iDataSize; i++) indices[i] = i; + BuildBKT(indices, m_pBKTStart, m_pBKTRoots); + BuildRNG(); + + SaveBKT(folderPath1 + p_configReader1.GetParameter("Index", "TreeFilePath", empty), m_pBKTStart, m_pBKTRoots); + SaveRNG(folderPath1 + p_configReader1.GetParameter("Index", "GraphFilePath", empty)); + return ErrorCode::Success; + } + + template + ErrorCode Index::DeleteIndex(const void* p_vectors, int p_vectorNum) { + const T* ptr_v = (const T*)p_vectors; +#pragma omp parallel for schedule(dynamic) + for (int i = 0; i < p_vectorNum; i++) { + COMMON::QueryResultSet query(ptr_v + i * m_iDataDimension, m_iCEF); + SearchIndex(query); + for (int i = 0; i < m_iCEF; i++) { + if (query.GetResult(i)->Dist < 1e-6) { + m_deletedID.insert(query.GetResult(i)->VID); + } + } + } + return ErrorCode::Success; + } + + template + ErrorCode Index::AddIndex(const void* p_vectors, int p_vectorNum, int p_dimension) { + if (m_pBKTRoots.size() == 0) { + return BuildIndex(p_vectors, p_vectorNum, p_dimension); + } 
+ if (p_dimension != m_iDataDimension) return ErrorCode::FailedParseValue; - if (m_pSamples.R() != m_iDataSize + num || m_pNeighborhoodGraph.R() != m_iDataSize + num) - std::cout << "Error m_iDataSize" << std::endl; - m_iDataSize += num; + int begin, end; + { + std::lock_guard lock(m_dataAllocLock); + + m_pSamples.AddBatch((const T*)p_vectors, p_vectorNum); + m_pNeighborhoodGraph.AddBatch(p_vectorNum); + + end = m_iDataSize + p_vectorNum; + if (m_pSamples.R() != end || m_pNeighborhoodGraph.R() != end) { + std::cout << "Memory Error: Cannot alloc space for vectors" << std::endl; + m_pSamples.SetR(m_iDataSize); + m_pNeighborhoodGraph.SetR(m_iDataSize); + return ErrorCode::Fail; + } + begin = m_iDataSize; + m_iDataSize = end; + m_iGraphSize = end; + m_dataUpdateLock.resize(m_iDataSize); + } + if (DistCalcMethod::Cosine == m_iDistCalcMethod) + { + int base = COMMON::Utils::GetBase(); + for (int i = begin; i < end; i++) { + COMMON::Utils::Normalize((T*)m_pSamples[i], m_iDataDimension, base); + } + } - for (int node = m_iDataSize - num; node < m_iDataSize; node++) + auto space = m_workSpacePool->Rent(); + for (int node = begin; node < end; node++) { - RefineRNGNode(node, space, true); + RefineRNGNode(node, *(space.get()), true); } + m_workSpacePool->Return(space); + std::cout << "Add " << p_vectorNum << " vectors" << std::endl; + return ErrorCode::Success; } template @@ -960,7 +1169,7 @@ namespace SPTAG if (index < 0 || space.CheckAndSet(index)) continue; space.m_NGQueue.insert(COMMON::HeapCell(index, m_fComputeDistance(query.GetTarget(), m_pSamples[index], m_iDataDimension))); } - SearchIndex(query, space); + SearchIndex(query, space, m_deletedID); RebuildRNGNodeNeighbors(m_pNeighborhoodGraph[node], query.GetResults(), m_iCEF); if (updateNeighbors) { @@ -968,18 +1177,50 @@ namespace SPTAG for (int j = 0; j < m_iCEF; j++) { BasicResult* item = query.GetResult(j); - if (item->Key < 0) continue; - COMMON::QueryResultSet queryNbs(m_pSamples[item->Key], m_iNeighborhoodSize + 
1); - queryNbs.AddPoint(node, item->Dist); + if (item->VID < 0) break; + + int insertID = node; + int* nodes = m_pNeighborhoodGraph[item->VID]; + std::lock_guard lock(m_dataUpdateLock[item->VID]); for (int k = 0; k < m_iNeighborhoodSize; k++) { - int tmpNode = m_pNeighborhoodGraph[item->Key][k]; - if (tmpNode < 0) break; - float distance = m_fComputeDistance(queryNbs.GetTarget(), m_pSamples[tmpNode], m_iDataDimension); - queryNbs.AddPoint(tmpNode, distance); + int tmpNode = nodes[k]; + if (tmpNode < -1) continue; + + if (tmpNode < 0) + { + bool good = true; + for (int t = 0; t < k; t++) { + if (m_fComputeDistance((m_pSamples)[insertID], (m_pSamples)[nodes[t]], m_iDataDimension) < item->Dist) { + good = false; + break; + } + } + if (good) { + nodes[k] = insertID; + } + break; + } + float tmpDist = m_fComputeDistance(m_pSamples[item->VID], m_pSamples[tmpNode], m_iDataDimension); + if (item->Dist < tmpDist || (item->Dist == tmpDist && insertID < tmpNode)) + { + bool good = true; + for (int t = 0; t < k; t++) { + if (m_fComputeDistance((m_pSamples)[insertID], (m_pSamples)[nodes[t]], m_iDataDimension) < item->Dist) { + good = false; + break; + } + } + if (good) { + nodes[k] = insertID; + insertID = tmpNode; + item->Dist = tmpDist; + } + else { + break; + } + } } - queryNbs.SortResult(); - RebuildRNGNodeNeighbors(m_pNeighborhoodGraph[item->Key], queryNbs.GetResults(), m_iNeighborhoodSize + 1); } } } @@ -989,16 +1230,16 @@ namespace SPTAG int count = 0; for (int j = 0; j < numResults && count < m_iNeighborhoodSize; j++) { const BasicResult& item = queryResults[j]; - if (item.Key < 0) continue; + if (item.VID < 0) continue; bool good = true; for (int k = 0; k < count; k++) { - if (m_fComputeDistance((m_pSamples)[nodes[k]], (m_pSamples)[item.Key], m_iDataDimension) < item.Dist) { + if (m_fComputeDistance((m_pSamples)[nodes[k]], (m_pSamples)[item.VID], m_iDataDimension) <= item.Dist) { good = false; break; } } - if (good) nodes[count++] = item.Key; + if (good) nodes[count++] 
= item.VID; } for (int j = count; j < m_iNeighborhoodSize; j++) nodes[j] = -1; } @@ -1025,15 +1266,6 @@ namespace SPTAG return true; } - template - bool Index::SaveIndex() - { - if (!SaveDataPoints(m_sDataPointsFilename)) return false; - if (!SaveBKT(m_sBKTFilename)) return false; - if (!SaveRNG(m_sGraphFilename)) return false; - return true; - } - template ErrorCode Index::SaveIndex(const std::string& p_folderPath) @@ -1060,13 +1292,10 @@ namespace SPTAG m_sDataPointsFilename = "vectors.bin"; m_sBKTFilename = "tree.bin"; m_sGraphFilename = "graph.bin"; - - if (!SaveDataPoints(folderPath + m_sDataPointsFilename)) return ErrorCode::Fail; - if (!SaveBKT(folderPath + m_sBKTFilename)) return ErrorCode::Fail; - if (!SaveRNG(folderPath + m_sGraphFilename)) return ErrorCode::Fail; - - loaderFile << "[Index]" << std::endl; + std::string metadataFile = "metadata.bin"; + std::string metadataIndexFile = "metadataIndex.bin"; + loaderFile << "[Index]" << std::endl; loaderFile << "IndexAlgoType=" << Helper::Convert::ConvertToString(IndexAlgoType::BKT) << std::endl; loaderFile << "ValueType=" << Helper::Convert::ConvertToString(GetEnumValueType()) << std::endl; loaderFile << std::endl; @@ -1081,11 +1310,6 @@ namespace SPTAG if (nullptr != m_pMetadata) { - std::string metadataFile = "metadata.bin"; - std::string metadataIndexFile = "metadataIndex.bin"; - - m_pMetadata->SaveMetadata(folderPath + metadataFile, folderPath + metadataIndexFile); - loaderFile << "[MetaData]" << std::endl; loaderFile << "MetaDataFilePath=" << metadataFile << std::endl; loaderFile << "MetaDataIndexPath=" << metadataIndexFile << std::endl; @@ -1093,61 +1317,18 @@ namespace SPTAG } loaderFile.close(); - return ErrorCode::Success; - } - - template - ErrorCode Index::LoadIndex(const std::string& p_folderPath, const Helper::IniReader& p_configReader) - { - std::string folderPath(p_folderPath); - if (!folderPath.empty() && *(folderPath.rbegin()) != FolderSep) - { - folderPath += FolderSep; + if 
(m_deletedID.size() > 0) { + RefineIndex(folderPath); } - - std::string metadataSection("MetaData"); - if (p_configReader.DoesSectionExist(metadataSection)) - { - std::string metadataFilePath = p_configReader.GetParameter(metadataSection, - "MetaDataFilePath", - std::string()); - std::string metadataIndexFilePath = p_configReader.GetParameter(metadataSection, - "MetaDataIndexPath", - std::string()); - - m_pMetadata.reset(new FileMetadataSet(folderPath + metadataFilePath, folderPath + metadataIndexFilePath)); - - if (!m_pMetadata->Available()) + else { + if (!SaveDataPoints(folderPath + m_sDataPointsFilename)) return ErrorCode::Fail; + if (!SaveBKT(folderPath + m_sBKTFilename, m_pBKTStart, m_pBKTRoots)) return ErrorCode::Fail; + if (!SaveRNG(folderPath + m_sGraphFilename)) return ErrorCode::Fail; + if (nullptr != m_pMetadata) { - std::cerr << "Error: Failed to load metadata." << std::endl; - return ErrorCode::Fail; + m_pMetadata->SaveMetadata(folderPath + metadataFile, folderPath + metadataIndexFile); } } - -#define DefineBKTParameter(VarName, VarType, DefaultValue, RepresentStr) \ - SetParameter(RepresentStr, \ - p_configReader.GetParameter("Index", \ - RepresentStr, \ - std::string(#DefaultValue)).c_str()); \ - -#include "inc/Core/BKT/ParameterDefinitionList.h" -#undef DefineBKTParameter - - if (DistCalcMethod::Undefined == m_iDistCalcMethod) - { - return ErrorCode::Fail; - } - - if (!LoadDataPoints(folderPath + m_sDataPointsFilename)) return ErrorCode::Fail; - if (!LoadBKT(folderPath + m_sBKTFilename)) return ErrorCode::Fail; - if (!LoadGraph(folderPath + m_sGraphFilename)) return ErrorCode::Fail; - - m_iDataSize = m_pSamples.R(); - m_iDataDimension = m_pSamples.C(); - - m_workSpacePool.reset(new COMMON::WorkSpacePool(m_iMaxCheck, GetNumSamples())); - m_workSpacePool->Init(m_iNumberOfThreads); - return ErrorCode::Success; } #pragma endregion diff --git a/AnnService/src/Core/Common/WorkSpacePool.cpp b/AnnService/src/Core/Common/WorkSpacePool.cpp index 
5c9f6e1b..6ac6ee88 100644 --- a/AnnService/src/Core/Common/WorkSpacePool.cpp +++ b/AnnService/src/Core/Common/WorkSpacePool.cpp @@ -13,6 +13,9 @@ WorkSpacePool::WorkSpacePool(int p_maxCheck, int p_vectorCount) WorkSpacePool::~WorkSpacePool() { + for (auto& workSpace : m_workSpacePool) + workSpace.reset(); + m_workSpacePool.clear(); } diff --git a/AnnService/src/Core/KDT/KDTIndex.cpp b/AnnService/src/Core/KDT/KDTIndex.cpp index 8b120a2e..a6db67d1 100644 --- a/AnnService/src/Core/KDT/KDTIndex.cpp +++ b/AnnService/src/Core/KDT/KDTIndex.cpp @@ -16,53 +16,96 @@ namespace SPTAG { #pragma region Load data points, kd-tree, neighborhood graph template - bool Index::LoadIndex() + ErrorCode Index::LoadIndexFromMemory(const std::vector& p_indexBlobs) { - bool loadedDataPoints = m_pDataPointsMemoryFile != NULL ? LoadDataPoints(m_pDataPointsMemoryFile) : LoadDataPoints(m_sDataPointsFilename); - if (!loadedDataPoints) return false; + if (!LoadDataPoints((char*)p_indexBlobs[0])) return ErrorCode::FailedParseValue; + if (!LoadKDT((char*)p_indexBlobs[1])) return ErrorCode::FailedParseValue; + if (!LoadGraph((char*)p_indexBlobs[2])) return ErrorCode::FailedParseValue; + return ErrorCode::Success; + } - m_iDataSize = m_pSamples.R(); - m_iDataDimension = m_pSamples.C(); + template + ErrorCode Index::LoadIndex(const std::string& p_folderPath) + { + std::string folderPath(p_folderPath); + if (!folderPath.empty() && *(folderPath.rbegin()) != FolderSep) + { + folderPath += FolderSep; + } + + Helper::IniReader p_configReader; + if (ErrorCode::Success != p_configReader.LoadIniFile(folderPath + "/indexloader.ini")) + { + return ErrorCode::FailedOpenFile; + } - bool loadedKDT = m_pKDTMemoryFile != NULL ? 
LoadKDT(m_pKDTMemoryFile) : LoadKDT(m_sKDTFilename); - if (!loadedKDT) return false; + std::string metadataSection("MetaData"); + if (p_configReader.DoesSectionExist(metadataSection)) + { + std::string metadataFilePath = p_configReader.GetParameter(metadataSection, + "MetaDataFilePath", + std::string()); + std::string metadataIndexFilePath = p_configReader.GetParameter(metadataSection, + "MetaDataIndexPath", + std::string()); - bool loadedGraph = m_pGraphMemoryFile != NULL ? LoadGraph(m_pGraphMemoryFile) : LoadGraph(m_sGraphFilename); - if (!loadedGraph) return false; + m_pMetadata.reset(new FileMetadataSet(folderPath + metadataFilePath, folderPath + metadataIndexFilePath)); - if (m_iRefineIter > 0) { - for (int i = 0; i < m_iRefineIter; i++) RefineRNG(); - SaveRNG(m_sGraphFilename); + if (!m_pMetadata->Available()) + { + std::cerr << "Error: Failed to load metadata." << std::endl; + return ErrorCode::Fail; + } } - return true; + +#define DefineKDTParameter(VarName, VarType, DefaultValue, RepresentStr) \ + SetParameter(RepresentStr, \ + p_configReader.GetParameter("Index", \ + RepresentStr, \ + std::string(#DefaultValue)).c_str()); \ + +#include "inc/Core/KDT/ParameterDefinitionList.h" +#undef DefineKDTParameter + + if (DistCalcMethod::Undefined == m_iDistCalcMethod) + { + return ErrorCode::Fail; + } + + if (!LoadDataPoints(folderPath + m_sDataPointsFilename)) return ErrorCode::Fail; + if (!LoadKDT(folderPath + m_sKDTFilename)) return ErrorCode::Fail; + if (!LoadGraph(folderPath + m_sGraphFilename)) return ErrorCode::Fail; + + m_iDataSize = m_pSamples.R(); + m_iDataDimension = m_pSamples.C(); + m_dataUpdateLock.resize(m_iDataSize); + + m_workSpacePool.reset(new COMMON::WorkSpacePool(m_iMaxCheck, GetNumSamples())); + m_workSpacePool->Init(m_iNumberOfThreads); + return ErrorCode::Success; } template bool Index::LoadDataPoints(std::string sDataPointsFileName) { std::cout << "Load Data Points From " << sDataPointsFileName << std::endl; - if (m_iCacheSize >= 0) - 
m_pSamples.Initialize(0, 0, nullptr, sDataPointsFileName.c_str(), m_iCacheSize); - else - { - FILE * fp = fopen(sDataPointsFileName.c_str(), "rb"); - if (fp == NULL) return false; + FILE * fp = fopen(sDataPointsFileName.c_str(), "rb"); + if (fp == NULL) return false; - int R, C; - fread(&R, sizeof(int), 1, fp); - fread(&C, sizeof(int), 1, fp); + int R, C; + fread(&R, sizeof(int), 1, fp); + fread(&C, sizeof(int), 1, fp); - if (m_iDebugLoad > 0 && R > m_iDebugLoad) R = m_iDebugLoad; + if (m_iDebugLoad > 0 && R > m_iDebugLoad) R = m_iDebugLoad; - m_pSamples.Initialize(R, C); - int i = 0, batch = 10000; - while (i + batch < R) { - fread((m_pSamples)[i], sizeof(T), C * batch, fp); - i += batch; - } - fread((m_pSamples)[i], sizeof(T), C * (R - i), fp); - fclose(fp); + m_pSamples.Initialize(R, C); + int i = 0, batch = 10000; + while (i + batch < R) { + fread((m_pSamples)[i], sizeof(T), C * batch, fp); + i += batch; } + fread((m_pSamples)[i], sizeof(T), C * (R - i), fp); + fclose(fp); std::cout << "Load Data Points (" << m_pSamples.R() << ", " << m_pSamples.C() << ") Finish!" 
<< std::endl; return true; } @@ -92,15 +135,17 @@ namespace SPTAG int realKDTNumber; fread(&realKDTNumber, sizeof(int), 1, fp); if (realKDTNumber < m_iKDTNumber) m_iKDTNumber = realKDTNumber; - m_pKDTStart.clear(); + m_pKDTStart.resize(m_iKDTNumber + 1, -1); for (int i = 0; i < m_iKDTNumber; i++) { - m_pKDTStart.push_back((int)(m_pKDTRoots.size())); int treeNodeSize; fread(&treeNodeSize, sizeof(int), 1, fp); - m_pKDTRoots.resize(m_pKDTRoots.size() + treeNodeSize); - fread(&(m_pKDTRoots[m_pKDTStart[i]]), sizeof(KDTNode), treeNodeSize, fp); + if (treeNodeSize > 0) { + m_pKDTStart[i] = (int)(m_pKDTRoots.size()); + m_pKDTRoots.resize(m_pKDTRoots.size() + treeNodeSize); + fread(&(m_pKDTRoots[m_pKDTStart[i]]), sizeof(KDTNode), treeNodeSize, fp); + } } - m_pKDTStart.push_back((int)(m_pKDTRoots.size())); + if (m_pKDTRoots.size() > 0) m_pKDTStart[m_iKDTNumber] = (int)(m_pKDTRoots.size()); fclose(fp); std::cout << "Load KDT (" << m_iKDTNumber << ", " << m_pKDTRoots.size() << ") Finish!" << std::endl; return true; @@ -176,29 +221,30 @@ namespace SPTAG #pragma region K-NN search template void Index::KDTSearch(const int node, const bool isInit, const float distBound, - COMMON::WorkSpace& space, COMMON::QueryResultSet &query) const { + COMMON::WorkSpace& p_space, COMMON::QueryResultSet &p_query, const tbb::concurrent_unordered_set &p_deleted) const { if (node < 0) { int index = -node - 1; + if (index >= m_iDataSize) return; #ifdef PREFETCH const char* data = (const char *)(m_pSamples[index]); _mm_prefetch(data, _MM_HINT_T0); _mm_prefetch(data + 64, _MM_HINT_T0); #endif - if (space.CheckAndSet(index)) return; + if (p_space.CheckAndSet(index)) return; - float distance = m_fComputeDistance(query.GetTarget(), (T*)data, m_iDataDimension); - query.AddPoint(index, distance); - ++space.m_iNumberOfTreeCheckedLeaves; - ++space.m_iNumberOfCheckedLeaves; - space.m_NGQueue.insert(COMMON::HeapCell(index, distance)); + float distance = m_fComputeDistance(p_query.GetTarget(), (T*)data, 
m_iDataDimension); + if (p_deleted.find(index) == p_deleted.end()) p_query.AddPoint(index, distance); + ++p_space.m_iNumberOfTreeCheckedLeaves; + ++p_space.m_iNumberOfCheckedLeaves; + p_space.m_NGQueue.insert(COMMON::HeapCell(index, distance)); return; } auto& tnode = m_pKDTRoots[node]; - float diff = (query.GetTarget())[tnode.split_dim] - tnode.split_value; + float diff = (p_query.GetTarget())[tnode.split_dim] - tnode.split_value; float distanceBound = distBound + diff * diff; int otherChild, bestChild; if (diff < 0) @@ -212,30 +258,30 @@ namespace SPTAG bestChild = tnode.right; } - if (!isInit || distanceBound < query.worstDist()) + if (!isInit || distanceBound < p_query.worstDist()) { - space.m_SPTQueue.insert(COMMON::HeapCell(otherChild, distanceBound)); + p_space.m_SPTQueue.insert(COMMON::HeapCell(otherChild, distanceBound)); } - KDTSearch(bestChild, isInit, distBound, space, query); + KDTSearch(bestChild, isInit, distBound, p_space, p_query, p_deleted); } template - void Index::SearchIndex(COMMON::QueryResultSet &query, COMMON::WorkSpace &space) const + void Index::SearchIndex(COMMON::QueryResultSet &p_query, COMMON::WorkSpace &p_space, const tbb::concurrent_unordered_set &p_deleted) const { for (char i = 0; i < m_iKDTNumber; i++) { - KDTSearch(m_pKDTStart[i], true, 0, space, query); + KDTSearch(m_pKDTStart[i], true, 0, p_space, p_query, p_deleted); } - while (!space.m_SPTQueue.empty() && space.m_iNumberOfCheckedLeaves < g_iNumberOfInitialDynamicPivots) + while (!p_space.m_SPTQueue.empty() && p_space.m_iNumberOfCheckedLeaves < g_iNumberOfInitialDynamicPivots) { - auto& tcell = space.m_SPTQueue.pop(); - if (query.worstDist() < tcell.distance) break; - KDTSearch(tcell.node, true, tcell.distance, space, query); + auto& tcell = p_space.m_SPTQueue.pop(); + if (p_query.worstDist() < tcell.distance) break; + KDTSearch(tcell.node, true, tcell.distance, p_space, p_query, p_deleted); } - while (!space.m_NGQueue.empty()) { + while (!p_space.m_NGQueue.empty()) { bool 
bLocalOpt = true; - COMMON::HeapCell gnode = space.m_NGQueue.pop(); + COMMON::HeapCell gnode = p_space.m_NGQueue.pop(); const int *node = (m_pNeighborhoodGraph)[gnode.node]; #ifdef PREFETCH @@ -252,123 +298,85 @@ namespace SPTAG // do not check it if it has been checked if (nn_index < 0) break; - if (space.CheckAndSet(nn_index)) continue; + if (p_space.CheckAndSet(nn_index)) continue; // count the number of the computed nodes - float distance2leaf = m_fComputeDistance(query.GetTarget(), (m_pSamples)[nn_index], m_iDataDimension); + float distance2leaf = m_fComputeDistance(p_query.GetTarget(), (m_pSamples)[nn_index], m_iDataDimension); - bool inserted = query.AddPoint(nn_index, distance2leaf); - space.m_iNumberOfCheckedLeaves++; - space.m_NGQueue.insert(COMMON::HeapCell(nn_index, distance2leaf)); - if (inserted || distance2leaf < gnode.distance) bLocalOpt = false; + if (p_deleted.find(nn_index) == p_deleted.end()) p_query.AddPoint(nn_index, distance2leaf); + if (distance2leaf <= p_query.worstDist()|| distance2leaf < gnode.distance) bLocalOpt = false; + p_space.m_iNumberOfCheckedLeaves++; + p_space.m_NGQueue.insert(COMMON::HeapCell(nn_index, distance2leaf)); } - if (bLocalOpt) space.m_iNumOfContinuousNoBetterPropagation++; - else space.m_iNumOfContinuousNoBetterPropagation = 0; + if (bLocalOpt) p_space.m_iNumOfContinuousNoBetterPropagation++; + else p_space.m_iNumOfContinuousNoBetterPropagation = 0; - if (space.m_iNumOfContinuousNoBetterPropagation > g_iThresholdOfNumberOfContinuousNoBetterPropagation) + if (p_space.m_iNumOfContinuousNoBetterPropagation > g_iThresholdOfNumberOfContinuousNoBetterPropagation) { - if (space.m_iNumberOfTreeCheckedLeaves < space.m_iNumberOfCheckedLeaves / 10) + if (p_space.m_iNumberOfTreeCheckedLeaves < p_space.m_iNumberOfCheckedLeaves / 10) { - int nextNumberOfCheckedLeaves = g_iNumberOfOtherDynamicPivots + space.m_iNumberOfCheckedLeaves; - while (!space.m_SPTQueue.empty() && space.m_iNumberOfCheckedLeaves < nextNumberOfCheckedLeaves) + 
int nextNumberOfCheckedLeaves = g_iNumberOfOtherDynamicPivots + p_space.m_iNumberOfCheckedLeaves; + while (!p_space.m_SPTQueue.empty() && p_space.m_iNumberOfCheckedLeaves < nextNumberOfCheckedLeaves) { - auto& tcell = space.m_SPTQueue.pop(); - KDTSearch(tcell.node, false, tcell.distance, space, query); + auto& tcell = p_space.m_SPTQueue.pop(); + KDTSearch(tcell.node, false, tcell.distance, p_space, p_query, p_deleted); } } - else if (gnode.distance > query.worstDist()) { + else if (gnode.distance > p_query.worstDist()) { break; } } - if (space.m_iNumberOfCheckedLeaves >= space.m_iMaxCheck) break; + if (p_space.m_iNumberOfCheckedLeaves >= p_space.m_iMaxCheck) break; } - query.SortResult(); + p_query.SortResult(); } template ErrorCode - Index::SearchIndex(QueryResult &query) const + Index::SearchIndex(QueryResult &p_query) const { auto workSpace = m_workSpacePool->Rent(); workSpace->Reset(m_iMaxCheck); - SearchIndex(*((COMMON::QueryResultSet*)&query), *workSpace); + SearchIndex(*((COMMON::QueryResultSet*)&p_query), *workSpace, m_deletedID); m_workSpacePool->Return(workSpace); - if (query.WithMeta() && nullptr != m_pMetadata) + if (p_query.WithMeta() && nullptr != m_pMetadata) { - for (int i = 0; i < query.GetResultNum(); ++i) + for (int i = 0; i < p_query.GetResultNum(); ++i) { - query.SetMetadata(i, m_pMetadata->GetMetadata(query.GetResult(i)->Key)); + int result = p_query.GetResult(i)->VID; + p_query.SetMetadata(i, (result < 0) ? 
ByteArray::c_empty : m_pMetadata->GetMetadata(result)); } } return ErrorCode::Success; } - #pragma endregion #pragma region Build/Save kd-tree & neighborhood graphs template - bool Index::BuildIndex() - { - if (!LoadDataPoints(m_sDataPointsFilename.c_str())) return false; - m_iDataSize = m_pSamples.R(); - m_iDataDimension = m_pSamples.C(); - - BuildKDT(); - if (!SaveKDT(m_sKDTFilename)) return false; - - BuildRNG(); - if (!SaveRNG(m_sGraphFilename)) return false; - return true; - } - - template - bool Index::BuildIndex(void* p_data, int p_vectorNum, int p_dimension) + ErrorCode Index::BuildIndex(const void* p_data, int p_vectorNum, int p_dimension) { - m_pSamples.Initialize(p_vectorNum, p_dimension, reinterpret_cast(p_data)); + m_pSamples.Initialize(p_vectorNum, p_dimension); + std::memcpy(m_pSamples.GetData(), p_data, p_vectorNum * p_dimension * sizeof(T)); m_iDataSize = m_pSamples.R(); m_iDataDimension = m_pSamples.C(); - - BuildKDT(); - BuildRNG(); - return true; - } - - template - ErrorCode Index::BuildIndex(std::shared_ptr p_vectorSet, - std::shared_ptr p_metadataSet) - { - if (nullptr == p_vectorSet || p_vectorSet->Count() == 0 || p_vectorSet->Dimension() == 0 || p_vectorSet->ValueType() != GetEnumValueType()) - { - return ErrorCode::Fail; - } + m_dataUpdateLock.resize(m_iDataSize); if (DistCalcMethod::Cosine == m_iDistCalcMethod) { - m_pSamples.Initialize(p_vectorSet->Count(), p_vectorSet->Dimension()); - std::memcpy(m_pSamples.GetData(), p_vectorSet->GetData(), p_vectorSet->Count() * p_vectorSet->Dimension() * sizeof(T)); - int base = COMMON::Utils::GetBase(); - for (SPTAG::SizeType i = 0; i < p_vectorSet->Count(); i++) { - COMMON::Utils::Normalize(m_pSamples[i], m_pSamples.C(), base); + for (int i = 0; i < m_iDataSize; i++) { + COMMON::Utils::Normalize(m_pSamples[i], m_iDataDimension, base); } - p_vectorSet.reset(); - } - else { - m_pSamples.Initialize(p_vectorSet->Count(), p_vectorSet->Dimension(), reinterpret_cast(p_vectorSet->GetData())); } - 
m_iDataSize = m_pSamples.R(); - m_iDataDimension = m_pSamples.C(); - - BuildKDT(); + std::vector indices(m_iDataSize); + for (int j = 0; j < m_iDataSize; j++) indices[j] = j; + BuildKDT(indices, m_pKDTStart, m_pKDTRoots); BuildRNG(); - - m_pMetadata = std::move(p_metadataSet); - m_workSpacePool.reset(new COMMON::WorkSpacePool(m_iMaxCheck, GetNumSamples())); m_workSpacePool->Init(m_iNumberOfThreads); return ErrorCode::Success; @@ -376,7 +384,7 @@ namespace SPTAG #pragma region Build/Save kd-tree template - bool Index::SaveKDT(std::string sKDTFilename) const + bool Index::SaveKDT(std::string sKDTFilename, std::vector& newStart, std::vector& newRoot) const { std::cout << "Save KDT to " << sKDTFilename << std::endl; FILE *fp = fopen(sKDTFilename.c_str(), "wb"); @@ -384,9 +392,9 @@ namespace SPTAG fwrite(&m_iKDTNumber, sizeof(int), 1, fp); for (int i = 0; i < m_iKDTNumber; i++) { - int treeNodeSize = m_pKDTStart[i + 1] - m_pKDTStart[i]; + int treeNodeSize = newStart[i + 1] - newStart[i]; fwrite(&treeNodeSize, sizeof(int), 1, fp); - fwrite(&(m_pKDTRoots[m_pKDTStart[i]]), sizeof(KDTNode), treeNodeSize, fp); + if (treeNodeSize > 0) fwrite(&(newRoot[newStart[i]]), sizeof(KDTNode), treeNodeSize, fp); } fclose(fp); std::cout << "Save KDT Finish!" 
<< std::endl; @@ -394,25 +402,30 @@ namespace SPTAG } template - void Index::BuildKDT() + void Index::BuildKDT(std::vector& indices, std::vector& newStart, std::vector& newRoot) { omp_set_num_threads(m_iNumberOfThreads); - m_pKDTRoots.resize(m_iKDTNumber * m_iDataSize); - m_pKDTStart.resize(m_iKDTNumber + 1, (int)(m_pKDTRoots.size())); - + newRoot.resize(m_iKDTNumber * indices.size()); + if (indices.size() > 0) + newStart.resize(m_iKDTNumber + 1, (int)(newRoot.size())); + else + { + newStart.resize(m_iKDTNumber + 1, -1); + return; + } #pragma omp parallel for for (int i = 0; i < m_iKDTNumber; i++) { Sleep(i * 100); std::srand(clock()); - m_pKDTStart[i] = i * m_iDataSize; - std::vector KDTDataIndices(m_iDataSize); - for (int j = 0; j < m_iDataSize; j++) KDTDataIndices[j] = j; - std::random_shuffle(KDTDataIndices.begin(), KDTDataIndices.end()); + + std::vector pindices(indices.begin(), indices.end()); + std::random_shuffle(pindices.begin(), pindices.end()); + newStart[i] = i * (int)pindices.size(); std::cout << "Start to build tree " << i + 1 << std::endl; - int iTreeSize = m_pKDTStart[i]; - DivideTree(m_pKDTRoots.data(), KDTDataIndices, 0, m_iDataSize - 1, m_pKDTStart[i], iTreeSize); - std::cout << i + 1 << " trees built, " << iTreeSize - m_pKDTStart[i] << " " << m_iDataSize << std::endl; + int iTreeSize = newStart[i]; + DivideTree(newRoot.data(), pindices, 0, (int)pindices.size() - 1, newStart[i], iTreeSize); + std::cout << i + 1 << " trees built, " << iTreeSize - newStart[i] << " " << pindices.size() << std::endl; } } @@ -421,8 +434,7 @@ namespace SPTAG int index, int &iTreeSize) { ChooseDivision(pTree[index], indices, first, last); int i = Subdivide(pTree[index], indices, first, last); - - if (i - 1 == first) + if (i - 1 <= first) { pTree[index].left = -indices[first] - 1; } @@ -612,7 +624,7 @@ namespace SPTAG float bestvariance = Variance[m_iDataDimension - 1].Dist; for (int i = 0; i < m_numTopDimensionTPTSplit; i++) { - index[i] = Variance[m_iDataDimension - 1 
- i].Key; + index[i] = Variance[m_iDataDimension - 1 - i].VID; bestweight[i] = 0; } bestweight[0] = 1; @@ -722,6 +734,20 @@ namespace SPTAG m_iGraphSize = m_iDataSize; m_pNeighborhoodGraph.Initialize(m_iGraphSize, m_iNeighborhoodSize); + if (m_iGraphSize < 1000) { + std::memset(m_pNeighborhoodGraph.GetData(), -1, m_iGraphSize * m_iNeighborhoodSize * sizeof(int)); + m_iNeighborhoodSize /= graphScale; + + COMMON::WorkSpace space; + space.Initialize(m_iMaxCheckForRefineGraph, m_iGraphSize); + for (int i = 0; i < m_iGraphSize; i++) + { + RefineRNGNode(i, space, true); + } + std::cout << "Build RNG Graph end!" << std::endl; + return; + } + { COMMON::Dataset NeighborhoodDists(m_iGraphSize, m_iNeighborhoodSize); std::vector> TptreeDataIndices(m_iTPTNumber, std::vector(m_iGraphSize)); @@ -802,8 +828,8 @@ namespace SPTAG #pragma omp parallel for schedule(dynamic) for (int i = 0; i < NSample; i++) { - //int x = Utils::rand_int(m_iGraphSize); - int x = i; + int x = COMMON::Utils::rand_int(m_iGraphSize); + //int x = i; COMMON::QueryResultSet query((m_pSamples)[x], m_iCEF); for (int y = 0; y < m_iGraphSize; y++) { @@ -818,7 +844,7 @@ namespace SPTAG } else { for (int j = 0; j < m_iNeighborhoodSize && j < m_iCEF; j++) { - exact_rng[j] = query.GetResult(j)->Key; + exact_rng[j] = query.GetResult(j)->VID; } for (int j = m_iCEF; j < m_iNeighborhoodSize; j++) exact_rng[j] = -1; } @@ -844,18 +870,211 @@ namespace SPTAG } template - void Index::AddNodes(const T* pData, int num, COMMON::WorkSpace &space) { - m_pSamples.AddBatch(pData, num); - m_pNeighborhoodGraph.AddReserved(num); + ErrorCode Index::RefineIndex(const std::string& p_folderPath) + { + std::string folderPath(p_folderPath); + if (!folderPath.empty() && *(folderPath.rbegin()) != FolderSep) + { + folderPath += FolderSep; + } + + if (!direxists(folderPath.c_str())) + { + mkdir(folderPath.c_str()); + } + tbb::concurrent_unordered_set deleted(m_deletedID.begin(), m_deletedID.end()); + std::vector indices; + std::unordered_map 
old2new; + int newR = m_iDataSize; + for (int i = 0; i < newR; i++) { + if (deleted.find(i) == deleted.end()) { + indices.push_back(i); + old2new[i] = i; + } + else { + while (deleted.find(newR - 1) != deleted.end() && newR > i) newR--; + if (newR == i) break; + indices.push_back(newR - 1); + old2new[newR - 1] = i; + newR--; + } + } + old2new[-1] = -1; + + std::cout << "Refine... from " << m_iDataSize << "->" << newR << std::endl; + std::ofstream vecOut(folderPath + m_sDataPointsFilename, std::ios::binary); + if (!vecOut.is_open()) return ErrorCode::FailedCreateFile; + vecOut.write((char*)&newR, sizeof(int)); + vecOut.write((char*)&m_iDataDimension, sizeof(int)); + for (int i = 0; i < newR; i++) { + vecOut.write((char*)(m_pSamples[indices[i]]), sizeof(T)*m_iDataDimension); + } + vecOut.close(); + + if (nullptr != m_pMetadata) + { + std::ofstream metaOut(folderPath + "metadata.bin_tmp", std::ios::binary); + std::ofstream metaIndexOut(folderPath + "metadataIndex.bin", std::ios::binary); + if (!metaOut.is_open() || !metaIndexOut.is_open()) return ErrorCode::FailedCreateFile; + metaIndexOut.write((char*)&newR, sizeof(int)); + std::uint64_t offset = 0; + for (int i = 0; i < newR; i++) { + metaIndexOut.write((char*)&offset, sizeof(std::uint64_t)); + ByteArray meta = m_pMetadata->GetMetadata(indices[i]); + metaOut.write((char*)meta.Data(), sizeof(uint8_t)*meta.Length()); + offset += meta.Length(); + } + metaOut.close(); + metaIndexOut.write((char*)&offset, sizeof(std::uint64_t)); + metaIndexOut.close(); + + SPTAG::MetadataSet::MetaCopy(folderPath + "metadata.bin_tmp", folderPath + "metadata.bin"); + } + + std::ofstream graphOut(folderPath + m_sGraphFilename, std::ios::binary); + if (!graphOut.is_open()) return ErrorCode::FailedCreateFile; + graphOut.write((char*)&newR, sizeof(int)); + graphOut.write((char*)&m_iNeighborhoodSize, sizeof(int)); + + int *neighbors = new int[m_iNeighborhoodSize]; + COMMON::WorkSpace space; + space.Initialize(m_iMaxCheckForRefineGraph, 
m_iDataSize); + for (int i = 0; i < newR; i++) { + space.Reset(m_iMaxCheckForRefineGraph); + COMMON::QueryResultSet query((m_pSamples)[indices[i]], m_iCEF); + space.CheckAndSet(indices[i]); + for (int j = 0; j < m_iNeighborhoodSize; j++) { + int index = m_pNeighborhoodGraph[indices[i]][j]; + if (index < 0 || space.CheckAndSet(index)) continue; + space.m_NGQueue.insert(COMMON::HeapCell(index, m_fComputeDistance(query.GetTarget(), m_pSamples[index], m_iDataDimension))); + } + SearchIndex(query, space, deleted); + RebuildRNGNodeNeighbors(neighbors, query.GetResults(), m_iCEF); + for (int j = 0; j < m_iNeighborhoodSize; j++) + neighbors[j] = old2new[neighbors[j]]; + graphOut.write((char*)neighbors, sizeof(int) * m_iNeighborhoodSize); + } + delete[]neighbors; + graphOut.close(); + + std::vector newRoot; + std::vector newStart; + BuildKDT(indices, newStart, newRoot); + +#pragma omp parallel for + for (int i = 0; i < m_iKDTNumber; i++) + { + for (int j = newStart[i]; j < newStart[i+1]; j++) { + if (newRoot[j].left < 0) + newRoot[j].left = -old2new[-newRoot[j].left - 1] - 1; + if (newRoot[j].right < 0) + newRoot[j].right = -old2new[-newRoot[j].right - 1] - 1; + } + } + SaveKDT(folderPath + m_sKDTFilename, newStart, newRoot); + return ErrorCode::Success; + } + + template + ErrorCode Index::MergeIndex(const char* p_indexFilePath1, const char* p_indexFilePath2) { + std::string folderPath1(p_indexFilePath1), folderPath2(p_indexFilePath2); + if (!folderPath1.empty() && *(folderPath1.rbegin()) != FolderSep) folderPath1 += FolderSep; + if (!folderPath2.empty() && *(folderPath2.rbegin()) != FolderSep) folderPath2 += FolderSep; + + Helper::IniReader p_configReader1, p_configReader2; + if (ErrorCode::Success != p_configReader1.LoadIniFile(folderPath1 + "/indexloader.ini")) + return ErrorCode::FailedOpenFile; + + if (ErrorCode::Success != p_configReader2.LoadIniFile(folderPath2 + "/indexloader.ini")) + return ErrorCode::FailedOpenFile; + + std::string empty(""); + if 
(!COMMON::DataUtils::MergeIndex(folderPath1 + p_configReader1.GetParameter("Index", "VectorFilePath", empty), + folderPath1 + p_configReader1.GetParameter("MetaData", "MetaDataFilePath", empty), + folderPath1 + p_configReader1.GetParameter("MetaData", "MetaDataIndexPath", empty), + folderPath2 + p_configReader1.GetParameter("Index", "VectorFilePath", empty), + folderPath2 + p_configReader1.GetParameter("MetaData", "MetaDataFilePath", empty), + folderPath2 + p_configReader1.GetParameter("MetaData", "MetaDataIndexPath", empty))) + return ErrorCode::Fail; + +#define DefineKDTParameter(VarName, VarType, DefaultValue, RepresentStr) \ + SetParameter(RepresentStr, \ + p_configReader1.GetParameter("Index", \ + RepresentStr, \ + std::string(#DefaultValue)).c_str()); \ - if (m_pSamples.R() != m_iDataSize + num || m_pNeighborhoodGraph.R() != m_iDataSize + num) - std::cout << "Error m_iDataSize" << std::endl; - m_iDataSize += num; +#include "inc/Core/KDT/ParameterDefinitionList.h" +#undef DefineKDTParameter + + if (!LoadDataPoints(folderPath1 + p_configReader1.GetParameter("Index", "VectorFilePath", empty))) return ErrorCode::FailedOpenFile; + std::vector indices(m_iDataSize); + for (int j = 0; j < m_iDataSize; j++) indices[j] = j; + BuildKDT(indices, m_pKDTStart, m_pKDTRoots); + BuildRNG(); - for (int node = m_iDataSize - num; node < m_iDataSize; node++) + SaveKDT(folderPath1 + p_configReader1.GetParameter("Index", "TreeFilePath", empty), m_pKDTStart, m_pKDTRoots); + SaveRNG(folderPath1 + p_configReader1.GetParameter("Index", "GraphFilePath", empty)); + return ErrorCode::Success; + } + + template + ErrorCode Index::DeleteIndex(const void* p_vectors, int p_vectorNum) { + const T* ptr_v = (const T*)p_vectors; +#pragma omp parallel for schedule(dynamic) + for (int i = 0; i < p_vectorNum; i++) { + COMMON::QueryResultSet query(ptr_v + i * m_iDataDimension, m_iCEF); + SearchIndex(query); + for (int i = 0; i < m_iCEF; i++) { + if (query.GetResult(i)->Dist < 1e-6) { + 
m_deletedID.insert(query.GetResult(i)->VID); + } + } + } + return ErrorCode::Success; + } + + template + ErrorCode Index::AddIndex(const void* p_vectors, int p_vectorNum, int p_dimension) { + if (m_pKDTRoots.size() == 0) { + return BuildIndex(p_vectors, p_vectorNum, p_dimension); + } + if (p_dimension != m_iDataDimension) return ErrorCode::FailedParseValue; + + int begin, end; { - RefineRNGNode(node, space, true); + std::lock_guard lock(m_dataAllocLock); + + m_pSamples.AddBatch((const T*)p_vectors, p_vectorNum); + m_pNeighborhoodGraph.AddBatch(p_vectorNum); + + end = m_iDataSize + p_vectorNum; + if (m_pSamples.R() != end || m_pNeighborhoodGraph.R() != end) { + std::cout << "Memory Error: Cannot alloc space for vectors" << std::endl; + m_pSamples.SetR(m_iDataSize); + m_pNeighborhoodGraph.SetR(m_iDataSize); + return ErrorCode::Fail; + } + begin = m_iDataSize; + m_iDataSize = end; + m_iGraphSize = end; + m_dataUpdateLock.resize(m_iDataSize); + } + if (DistCalcMethod::Cosine == m_iDistCalcMethod) + { + int base = COMMON::Utils::GetBase(); + for (int i = begin; i < end; i++) { + COMMON::Utils::Normalize((T*)m_pSamples[i], m_iDataDimension, base); + } } + + auto space = m_workSpacePool->Rent(); + for (int node = begin; node < end; node++) + { + RefineRNGNode(node, *(space.get()), true); + } + m_workSpacePool->Return(space); + std::cout << "Add " << p_vectorNum << " vectors" << std::endl; + return ErrorCode::Success; } template @@ -868,7 +1087,7 @@ namespace SPTAG if (index < 0 || space.CheckAndSet(index)) continue; space.m_NGQueue.insert(COMMON::HeapCell(index, m_fComputeDistance(query.GetTarget(), m_pSamples[index], m_iDataDimension))); } - SearchIndex(query, space); + SearchIndex(query, space, m_deletedID); RebuildRNGNodeNeighbors(m_pNeighborhoodGraph[node], query.GetResults(), m_iCEF); if (updateNeighbors) { @@ -876,18 +1095,48 @@ namespace SPTAG for (int j = 0; j < m_iCEF; j++) { BasicResult* item = query.GetResult(j); - if (item->Key < 0) continue; - 
COMMON::QueryResultSet queryNbs(m_pSamples[item->Key], m_iNeighborhoodSize + 1); - queryNbs.AddPoint(node, item->Dist); + if (item->VID < 0) break; + + int insertID = node; + int* nodes = m_pNeighborhoodGraph[item->VID]; + std::lock_guard lock(m_dataUpdateLock[item->VID]); for (int k = 0; k < m_iNeighborhoodSize; k++) { - int tmpNode = m_pNeighborhoodGraph[item->Key][k]; - if (tmpNode < 0) break; - float distance = m_fComputeDistance(queryNbs.GetTarget(), m_pSamples[tmpNode], m_iDataDimension); - queryNbs.AddPoint(tmpNode, distance); + int tmpNode = nodes[k]; + if (tmpNode < 0) + { + bool good = true; + for (int t = 0; t < k; t++) { + if (m_fComputeDistance((m_pSamples)[insertID], (m_pSamples)[nodes[t]], m_iDataDimension) < item->Dist) { + good = false; + break; + } + } + if (good) { + nodes[k] = insertID; + } + break; + } + float tmpDist = m_fComputeDistance(m_pSamples[item->VID], m_pSamples[tmpNode], m_iDataDimension); + if (item->Dist < tmpDist || (item->Dist == tmpDist && insertID < tmpNode)) + { + bool good = true; + for (int t = 0; t < k; t++) { + if (m_fComputeDistance((m_pSamples)[insertID], (m_pSamples)[nodes[t]], m_iDataDimension) < item->Dist) { + good = false; + break; + } + } + if (good) { + nodes[k] = insertID; + insertID = tmpNode; + item->Dist = tmpDist; + } + else { + break; + } + } } - queryNbs.SortResult(); - RebuildRNGNodeNeighbors(m_pNeighborhoodGraph[item->Key], queryNbs.GetResults(), m_iNeighborhoodSize + 1); } } } @@ -897,16 +1146,16 @@ namespace SPTAG int count = 0; for (int j = 0; j < numResults && count < m_iNeighborhoodSize; j++) { const BasicResult& item = queryResults[j]; - if (item.Key < 0) continue; + if (item.VID < 0) continue; bool good = true; for (int k = 0; k < count; k++) { - if (m_fComputeDistance((m_pSamples)[nodes[k]], (m_pSamples)[item.Key], m_iDataDimension) < item.Dist) { + if (m_fComputeDistance((m_pSamples)[nodes[k]], (m_pSamples)[item.VID], m_iDataDimension) <= item.Dist) { good = false; break; } } - if (good) 
nodes[count++] = item.Key; + if (good) nodes[count++] = item.VID; } for (int j = count; j < m_iNeighborhoodSize; j++) nodes[j] = -1; } @@ -933,15 +1182,6 @@ namespace SPTAG return true; } - template - bool Index::SaveIndex() - { - if (!SaveDataPoints(m_sDataPointsFilename)) return false; - if (!SaveKDT(m_sKDTFilename)) return false; - if (!SaveRNG(m_sGraphFilename)) return false; - return true; - } - template ErrorCode Index::SaveIndex(const std::string& p_folderPath) @@ -968,13 +1208,10 @@ namespace SPTAG m_sDataPointsFilename = "vectors.bin"; m_sKDTFilename = "tree.bin"; m_sGraphFilename = "graph.bin"; - - if (!SaveDataPoints(folderPath + m_sDataPointsFilename)) return ErrorCode::Fail; - if (!SaveKDT(folderPath + m_sKDTFilename)) return ErrorCode::Fail; - if (!SaveRNG(folderPath + m_sGraphFilename)) return ErrorCode::Fail; + std::string metadataFile = "metadata.bin"; + std::string metadataIndexFile = "metadataIndex.bin"; loaderFile << "[Index]" << std::endl; - loaderFile << "IndexAlgoType=" << Helper::Convert::ConvertToString(IndexAlgoType::KDT) << std::endl; loaderFile << "ValueType=" << Helper::Convert::ConvertToString(GetEnumValueType()) << std::endl; loaderFile << std::endl; @@ -986,14 +1223,9 @@ namespace SPTAG #undef DefineKDTParameter loaderFile << std::endl; - + if (nullptr != m_pMetadata) { - std::string metadataFile = "metadata.bin"; - std::string metadataIndexFile = "metadataIndex.bin"; - - m_pMetadata->SaveMetadata(folderPath + metadataFile, folderPath + metadataIndexFile); - loaderFile << "[MetaData]" << std::endl; loaderFile << "MetaDataFilePath=" << metadataFile << std::endl; loaderFile << "MetaDataIndexPath=" << metadataIndexFile << std::endl; @@ -1001,61 +1233,18 @@ namespace SPTAG } loaderFile.close(); - return ErrorCode::Success; - } - - template - ErrorCode Index::LoadIndex(const std::string& p_folderPath, const Helper::IniReader& p_configReader) - { - std::string folderPath(p_folderPath); - if (!folderPath.empty() && *(folderPath.rbegin()) != 
FolderSep) - { - folderPath += FolderSep; + if (m_deletedID.size() > 0) { + RefineIndex(folderPath); } - - std::string metadataSection("MetaData"); - if (p_configReader.DoesSectionExist(metadataSection)) - { - std::string metadataFilePath = p_configReader.GetParameter(metadataSection, - "MetaDataFilePath", - std::string()); - std::string metadataIndexFilePath = p_configReader.GetParameter(metadataSection, - "MetaDataIndexPath", - std::string()); - - m_pMetadata.reset(new FileMetadataSet(folderPath + metadataFilePath, folderPath + metadataIndexFilePath)); - - if (!m_pMetadata->Available()) + else { + if (!SaveDataPoints(folderPath + m_sDataPointsFilename)) return ErrorCode::Fail; + if (!SaveKDT(folderPath + m_sKDTFilename, m_pKDTStart, m_pKDTRoots)) return ErrorCode::Fail; + if (!SaveRNG(folderPath + m_sGraphFilename)) return ErrorCode::Fail; + if (nullptr != m_pMetadata) { - std::cerr << "Error: Failed to load metadata." << std::endl; - return ErrorCode::Fail; + m_pMetadata->SaveMetadata(folderPath + metadataFile, folderPath + metadataIndexFile); } } - -#define DefineKDTParameter(VarName, VarType, DefaultValue, RepresentStr) \ - SetParameter(RepresentStr, \ - p_configReader.GetParameter("Index", \ - RepresentStr, \ - std::string(#DefaultValue)).c_str()); \ - -#include "inc/Core/KDT/ParameterDefinitionList.h" -#undef DefineKDTParameter - - if (DistCalcMethod::Undefined == m_iDistCalcMethod) - { - return ErrorCode::Fail; - } - - if (!LoadDataPoints(folderPath + m_sDataPointsFilename)) return ErrorCode::Fail; - if (!LoadKDT(folderPath + m_sKDTFilename)) return ErrorCode::Fail; - if (!LoadGraph(folderPath + m_sGraphFilename)) return ErrorCode::Fail; - - m_iDataSize = m_pSamples.R(); - m_iDataDimension = m_pSamples.C(); - - m_workSpacePool.reset(new COMMON::WorkSpacePool(m_iMaxCheck, GetNumSamples())); - m_workSpacePool->Init(m_iNumberOfThreads); - return ErrorCode::Success; } #pragma endregion diff --git a/AnnService/src/Core/MetadataSet.cpp 
b/AnnService/src/Core/MetadataSet.cpp index 16fafde1..405698ed 100644 --- a/AnnService/src/Core/MetadataSet.cpp +++ b/AnnService/src/Core/MetadataSet.cpp @@ -5,14 +5,11 @@ using namespace SPTAG; -namespace -{ -namespace Local -{ - ErrorCode -CopyFile(const std::string& p_src, const std::string& p_dst) +MetadataSet::MetaCopy(const std::string& p_src, const std::string& p_dst) { + if (p_src == p_dst) return ErrorCode::Success; + std::ifstream src(p_src, std::ios::binary); if (!src.is_open()) { @@ -35,17 +32,12 @@ CopyFile(const std::string& p_src, const std::string& p_dst) dst.write(buf, src.gcount()); } delete[] buf; - src.close(); dst.close(); return ErrorCode::Success; } -} // namespace Local -} // namespace - - MetadataSet::MetadataSet() { } @@ -68,9 +60,9 @@ FileMetadataSet::FileMetadataSet(const std::string& p_metafile, const std::strin return; } - fpidx.read((char *)&m_iCount, sizeof(int)); - m_pOffsets = new long long[m_iCount + 1]; - fpidx.read((char *)m_pOffsets, sizeof(long long) * (m_iCount + 1)); + fpidx.read((char *)&m_count, sizeof(int)); + m_pOffsets.resize(m_count + 1); + fpidx.read((char *)m_pOffsets.data(), sizeof(std::uint64_t) * (m_count + 1)); fpidx.close(); } @@ -82,56 +74,76 @@ FileMetadataSet::~FileMetadataSet() m_fp->close(); delete m_fp; } - delete[] m_pOffsets; } ByteArray FileMetadataSet::GetMetadata(IndexType p_vectorID) const { - long long startoff = m_pOffsets[p_vectorID]; - long long bytes = m_pOffsets[p_vectorID + 1] - startoff; - m_fp->seekg(startoff, std::ios_base::beg); - ByteArray b = ByteArray::Alloc((SizeType)bytes); - m_fp->read((char*)b.Data(), bytes); - return b; + std::uint64_t startoff = m_pOffsets[p_vectorID]; + std::uint64_t bytes = m_pOffsets[p_vectorID + 1] - startoff; + if (p_vectorID < m_count) { + m_fp->seekg(startoff, std::ios_base::beg); + ByteArray b = ByteArray::Alloc((SizeType)bytes); + m_fp->read((char*)b.Data(), bytes); + return b; + } + else { + startoff -= m_pOffsets[m_count]; + return 
ByteArray((std::uint8_t*)m_newdata.data() + startoff, static_cast(bytes), false); + } } SizeType FileMetadataSet::Count() const { - return m_iCount; + return static_cast(m_pOffsets.size() - 1); } bool FileMetadataSet::Available() const { - return m_fp && m_fp->is_open() && m_pOffsets; + return m_fp && m_fp->is_open() && m_pOffsets.size() > 1; +} + + +void +FileMetadataSet::AddBatch(MetadataSet& data) +{ + for (int i = 0; i < static_cast(data.Count()); i++) + { + ByteArray newdata = data.GetMetadata(i); + m_newdata.insert(m_newdata.end(), newdata.Data(), newdata.Data() + newdata.Length()); + m_pOffsets.push_back(m_pOffsets[m_pOffsets.size() - 1] + newdata.Length()); + } } ErrorCode -FileMetadataSet::SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile) const +FileMetadataSet::SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile) { ErrorCode ret = ErrorCode::Success; - if (p_metaFile != m_metaFile) + m_fp->close(); + ret = MetaCopy(m_metaFile, p_metaFile); + if (ErrorCode::Success != ret) { - m_fp->close(); - ret = Local::CopyFile(m_metaFile, p_metaFile); - if (ErrorCode::Success != ret) - { - return ret; - } - m_fp->open(p_metaFile, std::ifstream::binary); + return ret; } - - if (p_metaindexFile != m_metaindexFile) - { - ret = Local::CopyFile(m_metaindexFile, p_metaindexFile); + if (m_newdata.size() > 0) { + std::ofstream tmpout(p_metaFile, std::ofstream::app|std::ios::binary); + if (!tmpout.is_open()) return ErrorCode::FailedOpenFile; + tmpout.write((char*)m_newdata.data(), m_newdata.size()); + tmpout.close(); } - + m_fp->open(p_metaFile, std::ifstream::binary); + + std::ofstream dst(p_metaindexFile, std::ios::binary); + m_count = static_cast(m_pOffsets.size()) - 1; + m_newdata.clear(); + dst.write((char*)&m_count, sizeof(int)); + dst.write((char*)m_pOffsets.data(), sizeof(std::uint64_t) * m_pOffsets.size()); return ret; } @@ -141,7 +153,8 @@ MemMetadataSet::MemMetadataSet(ByteArray p_metadata, ByteArray 
p_offsets, SizeTy m_offsetHolder(std::move(p_offsets)), m_count(p_count) { - m_offsets = reinterpret_cast(m_offsetHolder.Data()); + const std::uint64_t* newdata = reinterpret_cast(m_offsetHolder.Data()); + m_offsets.insert(m_offsets.end(), newdata, newdata + p_count + 1); } @@ -159,6 +172,11 @@ MemMetadataSet::GetMetadata(IndexType p_vectorID) const static_cast(m_offsets[p_vectorID + 1] - m_offsets[p_vectorID]), m_metadataHolder.DataHolder()); } + else if (p_vectorID < m_offsets.size() - 1) { + return ByteArray((std::uint8_t*)m_newdata.data() + m_offsets[p_vectorID] - m_offsets[m_count], + static_cast(m_offsets[p_vectorID + 1] - m_offsets[p_vectorID]), + false); + } return ByteArray::c_empty; } @@ -177,9 +195,19 @@ MemMetadataSet::Available() const return m_metadataHolder.Length() > 0 && m_offsetHolder.Length() > 0; } +void +MemMetadataSet::AddBatch(MetadataSet& data) +{ + for (int i = 0; i < static_cast(data.Count()); i++) + { + ByteArray newdata = data.GetMetadata(i); + m_newdata.insert(m_newdata.end(), newdata.Data(), newdata.Data() + newdata.Length()); + m_offsets.push_back(m_offsets[m_offsets.size() - 1] + newdata.Length()); + } +} ErrorCode -MemMetadataSet::SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile) const +MemMetadataSet::SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile) { std::ofstream outputStream; outputStream.open(p_metaFile, std::ios::binary); @@ -190,6 +218,7 @@ MemMetadataSet::SaveMetadata(const std::string& p_metaFile, const std::string& p } outputStream.write(reinterpret_cast(m_metadataHolder.Data()), m_metadataHolder.Length()); + outputStream.write((const char*)m_newdata.data(), sizeof(std::uint8_t)*m_newdata.size()); outputStream.close(); outputStream.open(p_metaindexFile, std::ios::binary); @@ -199,55 +228,11 @@ MemMetadataSet::SaveMetadata(const std::string& p_metaFile, const std::string& p return ErrorCode::FailedCreateFile; } + m_count = static_cast(m_offsets.size()) - 1; 
outputStream.write(reinterpret_cast(&m_count), sizeof(m_count)); - outputStream.write(reinterpret_cast(m_offsetHolder.Data()), m_offsetHolder.Length()); + outputStream.write(reinterpret_cast(m_offsets.data()), sizeof(std::uint64_t)*m_offsets.size()); outputStream.close(); return ErrorCode::Success; } - -MetadataSetFileTransfer::MetadataSetFileTransfer(const std::string& p_metaFile, const std::string& p_metaindexFile) - : m_metaFile(p_metaFile), - m_metaindexFile(p_metaindexFile) -{ -} - - -MetadataSetFileTransfer::~MetadataSetFileTransfer() -{ -} - - -ByteArray -MetadataSetFileTransfer::GetMetadata(IndexType p_vectorID) const -{ - return ByteArray::c_empty; -} - - -SizeType -MetadataSetFileTransfer::Count() const -{ - return 0; -} - - -bool -MetadataSetFileTransfer::Available() const -{ - return false; -} - - -ErrorCode -MetadataSetFileTransfer::SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile) const -{ - auto ret = Local::CopyFile(m_metaFile, p_metaFile); - if (ErrorCode::Success != ret) - { - return ret; - } - - return Local::CopyFile(m_metaindexFile, p_metaindexFile); -} diff --git a/AnnService/src/Core/VectorIndex.cpp b/AnnService/src/Core/VectorIndex.cpp index aeacb594..341d30c7 100644 --- a/AnnService/src/Core/VectorIndex.cpp +++ b/AnnService/src/Core/VectorIndex.cpp @@ -4,7 +4,7 @@ #include "inc/Helper/SimpleIniReader.h" #include "inc/Core/BKT/Index.h" - +#include "inc/Core/KDT/Index.h" #include @@ -49,6 +49,62 @@ VectorIndex::SetParameter(const std::string& p_param, const std::string& p_value } +void +VectorIndex::SetMetadata(const std::string& p_metadataFilePath, const std::string& p_metadataIndexPath) { + m_pMetadata.reset(new FileMetadataSet(p_metadataFilePath, p_metadataIndexPath)); +} + + +ByteArray +VectorIndex::GetMetadata(IndexType p_vectorID) const { + if (nullptr != m_pMetadata) + { + return m_pMetadata->GetMetadata(p_vectorID); + } + return ByteArray::c_empty; +} + + +ErrorCode 
+VectorIndex::BuildIndex(std::shared_ptr p_vectorSet, + std::shared_ptr p_metadataSet) +{ + if (nullptr == p_vectorSet || p_vectorSet->Count() == 0 || p_vectorSet->Dimension() == 0 || p_vectorSet->GetValueType() != GetVectorValueType()) + { + return ErrorCode::Fail; + } + + BuildIndex(p_vectorSet->GetData(), p_vectorSet->Count(), p_vectorSet->Dimension()); + m_pMetadata = std::move(p_metadataSet); + return ErrorCode::Success; +} + + +ErrorCode +VectorIndex::SearchIndex(const void* p_vector, int p_neighborCount, std::vector& p_results) const { + QueryResult res(p_vector, p_neighborCount, p_results); + SearchIndex(res); + return ErrorCode::Success; +} + + +ErrorCode +VectorIndex::AddIndex(std::shared_ptr p_vectorSet, std::shared_ptr p_metadataSet) { + if (nullptr == p_vectorSet || p_vectorSet->Count() == 0 || p_vectorSet->Dimension() == 0 || p_vectorSet->GetValueType() != GetVectorValueType()) + { + return ErrorCode::Fail; + } + AddIndex(p_vectorSet->GetData(), p_vectorSet->Count(), p_vectorSet->Dimension()); + if (m_pMetadata == nullptr) { + m_pMetadata = std::move(p_metadataSet); + } + else { + m_pMetadata->AddBatch(*p_metadataSet); + } + return ErrorCode::Success; +} + + std::shared_ptr VectorIndex::CreateInstance(IndexAlgoType p_algo, VectorValueType p_valuetype) { @@ -64,6 +120,19 @@ VectorIndex::CreateInstance(IndexAlgoType p_algo, VectorValueType p_valuetype) case VectorValueType::Name: \ return std::shared_ptr(new BKT::Index); \ +#include "inc/Core/DefinitionList.h" +#undef DefineVectorValueType + + default: break; + } + } + else if (p_algo == IndexAlgoType::KDT) { + switch (p_valuetype) + { +#define DefineVectorValueType(Name, Type) \ + case VectorValueType::Name: \ + return std::shared_ptr(new KDT::Index); \ + #include "inc/Core/DefinitionList.h" #undef DefineVectorValueType @@ -97,7 +166,22 @@ VectorIndex::LoadIndex(const std::string& p_loaderFilePath, std::shared_ptr); \ - p_vectorIndex->LoadIndex(p_loaderFilePath, iniReader); \ + 
p_vectorIndex->LoadIndex(p_loaderFilePath); \ + break; \ + +#include "inc/Core/DefinitionList.h" +#undef DefineVectorValueType + + default: break; + } + } + else if (algoType == IndexAlgoType::KDT) { + switch (valueType) + { +#define DefineVectorValueType(Name, Type) \ + case VectorValueType::Name: \ + p_vectorIndex.reset(new KDT::Index); \ + p_vectorIndex->LoadIndex(p_loaderFilePath); \ break; \ #include "inc/Core/DefinitionList.h" diff --git a/AnnService/src/Core/VectorSet.cpp b/AnnService/src/Core/VectorSet.cpp index b7e778c6..99f801fe 100644 --- a/AnnService/src/Core/VectorSet.cpp +++ b/AnnService/src/Core/VectorSet.cpp @@ -2,6 +2,7 @@ using namespace SPTAG; +#pragma warning(disable:4996) // 'fopen': This function or variable may be unsafe. Consider using fopen_s instead. To disable deprecation, use _CRT_SECURE_NO_WARNINGS. See online help for details. VectorSet::VectorSet() { @@ -31,6 +32,13 @@ BasicVectorSet::~BasicVectorSet() } +VectorValueType +BasicVectorSet::GetValueType() const +{ + return m_valueType; +} + + void* BasicVectorSet::GetVector(IndexType p_vectorID) const { @@ -50,14 +58,6 @@ BasicVectorSet::GetData() const return reinterpret_cast(m_data.Data()); } - -VectorValueType -BasicVectorSet::ValueType() const -{ - return m_valueType; -} - - SizeType BasicVectorSet::Dimension() const { @@ -70,3 +70,25 @@ BasicVectorSet::Count() const { return m_vectorCount; } + + +bool +BasicVectorSet::Available() const +{ + return m_data.Data() != nullptr; +} + + +ErrorCode +BasicVectorSet::Save(const std::string& p_vectorFile) const +{ + FILE * fp = fopen(p_vectorFile.c_str(), "wb"); + if (fp == NULL) return ErrorCode::FailedOpenFile; + + fwrite(&m_vectorCount, sizeof(int), 1, fp); + fwrite(&m_dimension, sizeof(int), 1, fp); + + fwrite((const void*)(m_data.Data()), m_data.Length(), 1, fp); + fclose(fp); + return ErrorCode::Success; +} diff --git a/AnnService/src/Helper/ArgumentsParser.cpp b/AnnService/src/Helper/ArgumentsParser.cpp index 12c4149b..b20df09c 100644 
--- a/AnnService/src/Helper/ArgumentsParser.cpp +++ b/AnnService/src/Helper/ArgumentsParser.cpp @@ -41,9 +41,8 @@ ArgumentsParser::Parse(int p_argc, char** p_args) if (last == p_argc) { - fprintf(stderr, "Unrecognized arg \"%s\"\n", *p_args); - PrintHelp(); - return false; + p_argc -= 1; + p_args += 1; } } diff --git a/AnnService/src/Helper/CommonHelper.cpp b/AnnService/src/Helper/CommonHelper.cpp index 29556b4d..4d472a95 100644 --- a/AnnService/src/Helper/CommonHelper.cpp +++ b/AnnService/src/Helper/CommonHelper.cpp @@ -89,12 +89,14 @@ StrUtils::StartsWith(const char* p_str, const char* p_prefix) return false; } - while ('\0' == (*p_prefix) && '\0' != (*p_str)) + while ('\0' != (*p_prefix) && '\0' != (*p_str)) { if (*p_prefix != *p_str) { return false; } + ++p_prefix; + ++p_str; } return '\0' == *p_prefix; diff --git a/AnnService/src/IndexBuilder/Options.cpp b/AnnService/src/IndexBuilder/Options.cpp index 815fe6f9..cdfe7307 100644 --- a/AnnService/src/IndexBuilder/Options.cpp +++ b/AnnService/src/IndexBuilder/Options.cpp @@ -8,7 +8,7 @@ using namespace SPTAG::IndexBuilder; BuilderOptions::BuilderOptions() - : m_threadNum(4), + : m_threadNum(32), m_inputValueType(VectorValueType::Float), m_vectorDelimiter("|") { diff --git a/AnnService/src/IndexBuilder/VectorSetReaders/DefaultReader.cpp b/AnnService/src/IndexBuilder/VectorSetReaders/DefaultReader.cpp index 553209f5..30639bcb 100644 --- a/AnnService/src/IndexBuilder/VectorSetReaders/DefaultReader.cpp +++ b/AnnService/src/IndexBuilder/VectorSetReaders/DefaultReader.cpp @@ -254,7 +254,7 @@ DefaultReader::GetVectorSet() const std::shared_ptr DefaultReader::GetMetadataSet() const { - return std::shared_ptr(new MetadataSetFileTransfer(m_metadataConentOutput, m_metadataIndexOutput)); + return std::shared_ptr(new FileMetadataSet(m_metadataConentOutput, m_metadataIndexOutput)); } diff --git a/AnnService/src/IndexBuilder/main.cpp b/AnnService/src/IndexBuilder/main.cpp index 7f686e6d..d370409d 100644 --- 
a/AnnService/src/IndexBuilder/main.cpp +++ b/AnnService/src/IndexBuilder/main.cpp @@ -2,6 +2,7 @@ #include "inc/IndexBuilder/Options.h" #include "inc/IndexBuilder/VectorSetReader.h" #include "inc/Core/VectorIndex.h" +#include "inc/Core/Common.h" #include "inc/Helper/SimpleIniReader.h" #include @@ -16,36 +17,77 @@ int main(int argc, char* argv[]) { exit(1); } - IndexBuilder::ThreadPool::Init(options->m_threadNum); + auto indexBuilder = VectorIndex::CreateInstance(options->m_indexAlgoType, options->m_inputValueType); - auto vectorReader = IndexBuilder::VectorSetReader::CreateInstance(options); - if (ErrorCode::Success != vectorReader->LoadFile(options->m_inputFiles)) + Helper::IniReader iniReader; + if (!options->m_builderConfigFile.empty()) { - fprintf(stderr, "Failed to read input file.\n"); - exit(1); + iniReader.LoadIniFile(options->m_builderConfigFile); } - auto indexBuilder = VectorIndex::CreateInstance(options->m_indexAlgoType, options->m_inputValueType); - if (!options->m_builderConfigFile.empty()) + for (int i = 1; i < argc; i++) { - Helper::IniReader iniReader; - iniReader.LoadIniFile(options->m_builderConfigFile); + std::string param(argv[i]); + int idx = (int)param.find("="); + if (idx < 0) continue; + + std::string paramName = param.substr(0, idx); + std::string paramVal = param.substr(idx + 1); + std::string sectionName; + idx = paramName.find("."); + if (idx >= 0) { + sectionName = paramName.substr(0, idx); + paramName = paramName.substr(idx + 1); + } + iniReader.SetParameter(sectionName, paramName, paramVal); + std::cout << "Set [" << sectionName << "]" << paramName << " = " << paramVal << std::endl; + } + + if (!iniReader.DoesParameterExist("Index", "NumberOfThreads")) { + iniReader.SetParameter("Index", "NumberOfThreads", std::to_string(options->m_threadNum)); + } + for (const auto& iter : iniReader.GetParameters("Index")) + { + indexBuilder->SetParameter(iter.first.c_str(), iter.second.c_str()); + } - for (const auto& iter : 
iniReader.GetParameters("Index")) + ErrorCode code; + if (options->m_inputFiles.find("BIN:") == 0) { + options->m_inputFiles = options->m_inputFiles.substr(4); + std::ifstream inputStream(options->m_inputFiles, std::ifstream::binary); + if (!inputStream.is_open()) { + fprintf(stderr, "Failed to read input file.\n"); + exit(1); + } + int row, col; + inputStream.read((char*)&row, sizeof(int)); + inputStream.read((char*)&col, sizeof(int)); + std::uint64_t totalRecordVectorBytes = ((std::uint64_t)GetValueTypeSize(options->m_inputValueType)) * row * col; + ByteArray vectorSet = ByteArray::Alloc(totalRecordVectorBytes); + char* vecBuf = reinterpret_cast(vectorSet.Data()); + inputStream.read(vecBuf, totalRecordVectorBytes); + inputStream.close(); + + std::shared_ptr p_vectorSet(new BasicVectorSet(vectorSet, options->m_inputValueType, col, row)); + code = indexBuilder->BuildIndex(p_vectorSet, nullptr); + indexBuilder->SaveIndex(options->m_outputFolder); + } + else { + auto vectorReader = IndexBuilder::VectorSetReader::CreateInstance(options); + if (ErrorCode::Success != vectorReader->LoadFile(options->m_inputFiles)) { - indexBuilder->SetParameter(iter.first.c_str(), iter.second.c_str()); + fprintf(stderr, "Failed to read input file.\n"); + exit(1); } + code = indexBuilder->BuildIndex(vectorReader->GetVectorSet(), vectorReader->GetMetadataSet()); + indexBuilder->SaveIndex(options->m_outputFolder); } - if (ErrorCode::Success != indexBuilder->BuildIndex(vectorReader->GetVectorSet(), - vectorReader->GetMetadataSet())) + if (ErrorCode::Success != code) { fprintf(stderr, "Failed to build index.\n"); exit(1); } - - indexBuilder->SaveIndex(options->m_outputFolder); - return 0; } diff --git a/AnnService/src/Server/SearchExecutor.cpp b/AnnService/src/Server/SearchExecutor.cpp index 6bbc0fd4..d89cc094 100644 --- a/AnnService/src/Server/SearchExecutor.cpp +++ b/AnnService/src/Server/SearchExecutor.cpp @@ -47,7 +47,7 @@ SearchExecutor::ExecuteInternal() const auto& firstIndex = 
m_selectedIndex.front(); - if (ErrorCode::Success != m_executionContext->ExtractVector(firstIndex->AcceptableQueryValueType())) + if (ErrorCode::Success != m_executionContext->ExtractVector(firstIndex->GetVectorValueType())) { return; } @@ -63,7 +63,7 @@ SearchExecutor::ExecuteInternal() for (const auto& vectorIndex : m_selectedIndex) { - if (vectorIndex->AcceptableQueryValueType() != firstIndex->AcceptableQueryValueType() + if (vectorIndex->GetVectorValueType() != firstIndex->GetVectorValueType() || vectorIndex->GetFeatureDim() != firstIndex->GetFeatureDim()) { continue; diff --git a/AnnService/src/Server/SearchService.cpp b/AnnService/src/Server/SearchService.cpp index 62ea440f..fd9349d3 100644 --- a/AnnService/src/Server/SearchService.cpp +++ b/AnnService/src/Server/SearchService.cpp @@ -181,7 +181,7 @@ SearchService::RunInteractiveMode() for (const auto& res : result.m_results) { fprintf(stdout, "------------------\n"); - fprintf(stdout, "DocIndex: %d Distance: %f", res.Key, res.Dist); + fprintf(stdout, "DocIndex: %d Distance: %f", res.VID, res.Dist); if (result.m_results.WithMeta()) { const auto& metadata = result.m_results.GetMetadata(idx); diff --git a/AnnService/src/Socket/RemoteSearchQuery.cpp b/AnnService/src/Socket/RemoteSearchQuery.cpp index 8002dcf4..ab1f7285 100644 --- a/AnnService/src/Socket/RemoteSearchQuery.cpp +++ b/AnnService/src/Socket/RemoteSearchQuery.cpp @@ -105,7 +105,7 @@ RemoteSearchResult::EstimateBufferSize() const for (const auto& res : indexRes.m_results) { - sum += SimpleSerialization::EstimateBufferSize(res.Key); + sum += SimpleSerialization::EstimateBufferSize(res.VID); sum += SimpleSerialization::EstimateBufferSize(res.Dist); } @@ -139,7 +139,7 @@ RemoteSearchResult::Write(std::uint8_t* p_buffer) const for (const auto& res : indexRes.m_results) { - p_buffer = SimpleSerialization::SimpleWriteBuffer(res.Key, p_buffer); + p_buffer = SimpleSerialization::SimpleWriteBuffer(res.VID, p_buffer); p_buffer = 
SimpleSerialization::SimpleWriteBuffer(res.Dist, p_buffer); } @@ -188,7 +188,7 @@ RemoteSearchResult::Read(const std::uint8_t* p_buffer) indexRes.m_results.Init(nullptr, resNum, withMeta); for (auto& res : indexRes.m_results) { - p_buffer = SimpleSerialization::SimpleReadBuffer(p_buffer, res.Key); + p_buffer = SimpleSerialization::SimpleReadBuffer(p_buffer, res.VID); p_buffer = SimpleSerialization::SimpleReadBuffer(p_buffer, res.Dist); } diff --git a/CMakeLists.txt b/CMakeLists.txt index a19f8424..617dd76f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -61,6 +61,37 @@ endif() set (CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}) +find_package(OpenMP) +if (OpenMP_FOUND) + set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") + set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") + set (CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}") + message (STATUS "Found openmp.") +else() + message (FATAL_ERROR "Could no find openmp!") +endif() + +find_package(Boost 1.67 COMPONENTS system thread serialization wserialization regex) +if (Boost_FOUND) + include_directories (${Boost_INCLUDE_DIR}) + link_directories (${Boost_LIBRARY_DIR} "/usr/lib") + message (STATUS "Found Boost.") + message (STATUS "Include Path: ${Boost_INCLUDE_DIRS}") + message (STATUS "Library Path: ${Boost_LIBRARY_DIRS}") + message (STATUS "Library: ${Boost_LIBRARIES}") +else() + message (FATAL_ERROR "Could not find Boost 1.67!") +endif() + +find_library(TBB_LIBRARIES libtbb${CMAKE_SHARED_LIBRARY_SUFFIX}) +if (TBB_LIBRARIES) + message (STATUS "Found TBB.") + message (STATUS "Library: ${TBB_LIBRARIES}") +else() + message (FATAL_ERROR "Could not find TBB!") +endif() + add_subdirectory (AnnService) add_subdirectory (PythonWrapper) add_subdirectory (Test) +add_subdirectory (Search) diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..d1ca00f2 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ + MIT License + + Copyright (c) Microsoft Corporation. 
All rights reserved. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE \ No newline at end of file diff --git a/PythonWrapper/CMakeLists.txt b/PythonWrapper/CMakeLists.txt index 54e2d437..261c4181 100644 --- a/PythonWrapper/CMakeLists.txt +++ b/PythonWrapper/CMakeLists.txt @@ -1,42 +1,32 @@ find_package(Python2 COMPONENTS Development) -if (Python2_FOUND) - if (WIN32) - set(PY_SUFFIX .pyd) - else() - set(PY_SUFFIX .so) - endif() +if (Python2_FOUND) include_directories (${Python2_INCLUDE_DIRS}) link_directories (${Python2_LIBRARY_DIRS}) message (STATUS "Found Python.") message (STATUS "Include Path: ${Python2_INCLUDE_DIRS}") message (STATUS "Library Path: ${Python2_LIBRARIES}") - message (STATUS "Suffix: ${PY_SUFFIX}") + set (Python_LIBRARIES ${Python2_LIBRARIES}) else() - message (FATAL_ERROR "Could not find Python 2.7!") + message (STATUS "Could not find Python 2.7!") + find_package(Python3 COMPONENTS Development) + if (Python3_FOUND) + include_directories 
(${Python3_INCLUDE_DIRS}) + link_directories (${Python3_LIBRARY_DIRS}) + message (STATUS "Found Python.") + message (STATUS "Include Path: ${Python3_INCLUDE_DIRS}") + message (STATUS "Library Path: ${Python3_LIBRARIES}") + set (Python_LIBRARIES ${Python3_LIBRARIES}) + else () + message (FATAL_ERROR "Could not find python2 or python3!") + endif() endif() -find_package(OpenMP) -if (OpenMP_FOUND) - set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") - set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") - set (CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}") - message (STATUS "Found openmp.") +if (WIN32) + set(PY_SUFFIX .pyd) else() - message (FATAL_ERROR "Could no find openmp!") + set(PY_SUFFIX .so) endif() -find_package(Boost 1.67 COMPONENTS system thread serialization wserialization regex) -if (Boost_FOUND) - include_directories (${Boost_INCLUDE_DIR}) - link_directories (${Boost_LIBRARY_DIR}) - message (STATUS "Found Boost.") - message (STATUS "Include Path: ${Boost_INCLUDE_DIRS}") - message (STATUS "Library Path: ${Boost_LIBRARY_DIRS}") - message (STATUS "Library: ${Boost_LIBRARIES}") -else() - message (FATAL_ERROR "Could not find Boost 1.67!") -endif() - execute_process(COMMAND swig -l${PROJECT_SOURCE_DIR}/PythonWrapper/inc/PyByteArray.i -python -c++ ${PROJECT_SOURCE_DIR}/PythonWrapper/inc/CoreInterface.h) execute_process(COMMAND swig -l${PROJECT_SOURCE_DIR}/PythonWrapper/inc/PyByteArray.i -python -c++ ${PROJECT_SOURCE_DIR}/PythonWrapper/inc/ClientInterface.h) @@ -47,14 +37,14 @@ file(GLOB CORE_HDR_FILES ${PROJECT_SOURCE_DIR}/PythonWrapper/inc/CoreInterface.h file(GLOB CORE_SRC_FILES ${PROJECT_SOURCE_DIR}/PythonWrapper/src/CoreInterface.cpp ${PROJECT_SOURCE_DIR}/PythonWrapper/inc/CoreInterface_wrap.cxx) add_library (_SPTAG SHARED ${CORE_SRC_FILES} ${CORE_HDR_FILES}) set_target_properties(_SPTAG PROPERTIES SUFFIX ${PY_SUFFIX}) -target_link_libraries(_SPTAG SPTAGLib ${Python2_LIBRARIES}) +target_link_libraries(_SPTAG 
SPTAGLib ${Python_LIBRARIES} ${TBB_LIBRARIES}) add_custom_command(TARGET _SPTAG POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${PROJECT_SOURCE_DIR}/PythonWrapper/inc/SPTAG.py ${EXECUTABLE_OUTPUT_PATH}) file(GLOB CLIENT_HDR_FILES ${PROJECT_SOURCE_DIR}/PythonWrapper/inc/ClientInterface.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Socket/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Client/*.h) file(GLOB CLIENT_SRC_FILES ${PROJECT_SOURCE_DIR}/PythonWrapper/src/ClientInterface.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Socket/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Client/*.cpp ${PROJECT_SOURCE_DIR}/PythonWrapper/inc/ClientInterface_wrap.cxx) add_library (_SPTAGClient SHARED ${CLIENT_SRC_FILES} ${CLIENT_HDR_FILES}) set_target_properties(_SPTAGClient PROPERTIES SUFFIX ${PY_SUFFIX}) -target_link_libraries(_SPTAGClient SPTAGLib ${Python2_LIBRARIES} ${Boost_LIBRARIES}) +target_link_libraries(_SPTAGClient SPTAGLib ${Python_LIBRARIES} ${Boost_LIBRARIES} ${TBB_LIBRARIES}) add_custom_command(TARGET _SPTAGClient POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${PROJECT_SOURCE_DIR}/PythonWrapper/inc/SPTAGClient.py ${EXECUTABLE_OUTPUT_PATH}) install(TARGETS _SPTAG _SPTAGClient diff --git a/PythonWrapper/PythonClient.vcxproj b/PythonWrapper/PythonClient.vcxproj index 30fad14a..fd745df4 100644 --- a/PythonWrapper/PythonClient.vcxproj +++ b/PythonWrapper/PythonClient.vcxproj @@ -168,7 +168,7 @@ - + diff --git a/PythonWrapper/PythonCore.vcxproj b/PythonWrapper/PythonCore.vcxproj index a895adaf..4e23ba73 100644 --- a/PythonWrapper/PythonCore.vcxproj +++ b/PythonWrapper/PythonCore.vcxproj @@ -113,9 +113,10 @@ + - + @@ -127,5 +128,6 @@ + \ No newline at end of file diff --git a/PythonWrapper/inc/CoreInterface.h b/PythonWrapper/inc/CoreInterface.h index f242f48d..2a37e3c9 100644 --- a/PythonWrapper/inc/CoreInterface.h +++ b/PythonWrapper/inc/CoreInterface.h @@ -17,7 +17,6 @@ %include %shared_ptr(AnnIndex) %shared_ptr(QueryResult) - %include "PyByteArray.i" %{ @@ -44,6 +43,8 @@ class AnnIndex bool 
Build(ByteArray p_data, SizeType p_num); + bool BuildWithMetaData(ByteArray p_data, ByteArray p_meta, SizeType p_num); + std::shared_ptr Search(ByteArray p_data, SizeType p_resultNum); std::shared_ptr SearchWithMetaData(ByteArray p_data, SizeType p_resultNum); @@ -52,6 +53,14 @@ class AnnIndex bool Save(const char* p_saveFile) const; + bool Add(ByteArray p_data, SizeType p_num); + + bool AddWithMetaData(ByteArray p_data, ByteArray p_meta, SizeType p_num); + + bool Delete(ByteArray p_data, SizeType p_num); + + bool Refine(const char* p_loaderFile); + static AnnIndex Load(const char* p_loaderFile); private: diff --git a/PythonWrapper/inc/PyByteArray.i b/PythonWrapper/inc/PyByteArray.i index a8e0040a..7c3b19b8 100644 --- a/PythonWrapper/inc/PyByteArray.i +++ b/PythonWrapper/inc/PyByteArray.i @@ -11,7 +11,7 @@ int i = 0; for (const auto& res : *($1)) { - PyList_SetItem(dstVecIDs, i, PyInt_FromLong(res.Key)); + PyList_SetItem(dstVecIDs, i, PyInt_FromLong(res.VID)); PyList_SetItem(dstVecDists, i, PyFloat_FromDouble(res.Dist)); i++; } @@ -43,7 +43,7 @@ { for (const auto& res : indexRes.m_results) { - PyList_Append(dstVecIDs, PyInt_FromLong(res.Key)); + PyList_Append(dstVecIDs, PyInt_FromLong(res.VID)); PyList_Append(dstVecDists, PyFloat_FromDouble(res.Dist)); } @@ -63,7 +63,7 @@ } %} -%typemap(in) ByteArray p_data +%typemap(in) ByteArray %{ $1 = SPTAG::ByteArray((std::uint8_t*)PyBytes_AsString($input), PyBytes_Size($input), false); %} diff --git a/PythonWrapper/src/CoreInterface.cpp b/PythonWrapper/src/CoreInterface.cpp index c0e658df..45dc6504 100644 --- a/PythonWrapper/src/CoreInterface.cpp +++ b/PythonWrapper/src/CoreInterface.cpp @@ -24,7 +24,7 @@ AnnIndex::AnnIndex(const char* p_algoType, const char* p_valueType, SizeType p_d AnnIndex::AnnIndex(const std::shared_ptr& p_index) : m_algoType(p_index->GetIndexAlgoType()), - m_inputValueType(p_index->AcceptableQueryValueType()), + m_inputValueType(p_index->GetVectorValueType()), m_dimension(p_index->GetFeatureDim()), 
m_index(p_index) { @@ -39,6 +39,21 @@ AnnIndex::~AnnIndex() bool AnnIndex::Build(ByteArray p_data, SizeType p_num) +{ + if (nullptr == m_index) + { + m_index = SPTAG::VectorIndex::CreateInstance(m_algoType, m_inputValueType); + } + if (nullptr == m_index || p_num == 0 || m_dimension == 0 || p_data.Length() != p_num * m_inputVectorSize) + { + return false; + } + return (SPTAG::ErrorCode::Success == m_index->BuildIndex(p_data.Data(), p_num, m_dimension)); +} + + +bool +AnnIndex::BuildWithMetaData(ByteArray p_data, ByteArray p_meta, SizeType p_num) { if (nullptr == m_index) { @@ -54,11 +69,14 @@ AnnIndex::Build(ByteArray p_data, SizeType p_num) static_cast(m_dimension), static_cast(p_num))); - if (SPTAG::ErrorCode::Success != m_index->BuildIndex(vectors, nullptr)) - { - return false; + std::uint64_t* offsets = new std::uint64_t[p_num + 1]{ 0 }; + int current = 1; + for (size_t i = 0; i < p_meta.Length(); i++) { + if (((char)p_meta.Data()[i]) == '\n') + offsets[current++] = (std::uint64_t)(i + 1); } - return true; + std::shared_ptr meta(new SPTAG::MemMetadataSet(p_meta, ByteArray((std::uint8_t*)offsets, (p_num + 1) * sizeof(std::uint64_t), true), p_num)); + return (SPTAG::ErrorCode::Success == m_index->BuildIndex(vectors, meta)); } @@ -136,3 +154,63 @@ AnnIndex::Load(const char* p_loaderFile) return AnnIndex(vecIndex); } + + +bool +AnnIndex::Add(ByteArray p_data, SizeType p_num) +{ + if (nullptr == m_index) + { + m_index = SPTAG::VectorIndex::CreateInstance(m_algoType, m_inputValueType); + } + if (nullptr == m_index || p_num == 0 || m_dimension == 0 || p_data.Length() != p_num * m_inputVectorSize) + { + return false; + } + return (SPTAG::ErrorCode::Success == m_index->AddIndex(p_data.Data(), p_num, m_dimension)); +} + + +bool +AnnIndex::AddWithMetaData(ByteArray p_data, ByteArray p_meta, SizeType p_num) +{ + if (nullptr == m_index) + { + m_index = SPTAG::VectorIndex::CreateInstance(m_algoType, m_inputValueType); + } + if (nullptr == m_index || p_num == 0 || m_dimension 
== 0 || p_data.Length() != p_num * m_inputVectorSize) + { + return false; + } + + std::shared_ptr vectors(new SPTAG::BasicVectorSet(p_data, + m_inputValueType, + static_cast(m_dimension), + static_cast(p_num))); + + std::uint64_t* offsets = new std::uint64_t[p_num + 1]{ 0 }; + int current = 1; + for (size_t i = 0; i < p_meta.Length(); i++) { + if (((char)p_meta.Data()[i]) == '\n') + offsets[current++] = (std::uint64_t)(i + 1); + } + std::shared_ptr meta(new SPTAG::MemMetadataSet(p_meta, ByteArray((std::uint8_t*)offsets, (p_num + 1) * sizeof(std::uint64_t), true), p_num)); + return (SPTAG::ErrorCode::Success == m_index->AddIndex(vectors, meta)); +} + + +bool +AnnIndex::Delete(ByteArray p_data, SizeType p_num) +{ + if (nullptr != m_index && p_num > 0) + { + return (SPTAG::ErrorCode::Success == m_index->DeleteIndex(p_data.Data(), p_num)); + } + return false; +} + +bool +AnnIndex::Refine(const char* p_loaderFile) +{ + return (SPTAG::ErrorCode::Success == m_index->RefineIndex(std::string(p_loaderFile))); +} \ No newline at end of file diff --git a/README.md b/README.md index 963513b3..e849b029 100644 --- a/README.md +++ b/README.md @@ -1,28 +1,50 @@ -# Space Partition Tree And Graph for Vector Search +# SPTAG: A library for fast approximate nearest neighbor search [![MIT licensed](https://img.shields.io/badge/license-MIT-yellow.svg)](https://github.com/Microsoft/nni/blob/master/LICENSE) -[![Build status](https://sysdnn.visualstudio.com/SPTAG/_apis/build/status/SPTAG-GITHUB)](https://sysdnn.visualstudio.com/SPTAG/_build/latest?definitionId=6) +[![Build status](https://sysdnn.visualstudio.com/SPTAG/_apis/build/status/SPTAG-CI)](https://sysdnn.visualstudio.com/SPTAG/_build/latest?definitionId=2) + +## **SPTAG** + SPTAG (Space Partition Tree And Graph) is a library for large scale vector approximate nearest neighbor search scenerio, which is written in C++ and wrapped by Python. 
- SPTAG (Space Partition Tree And Graph) is a toolkit which provides a high quality vector index build, search and distributed online serving toolkits for large scale vector search scenario.

architecture

- ## **Why to consider using SPTAG** - * Performance - * New features - 1. .. - 2. .. - ## **Install** -### **requirements** +## **Introduction** + +This library assumes that the samples are represented as vectors and that the vectors can be compared by L2 distances or cosine distances. +Vectors returned for a query vector are the vectors that have smallest L2 distance or cosine distances with the query vector. + +SPTAG provides two methods: kd-tree and relative neighborhood graph (SPTAG-KDT) +and balanced k-means tree and relatrive neighborhood graph (SPTAG-BKT). +SPTAG-KDT is advantageous in index building cost, and SPTAG-BKT is advantageous in search accuracy in very high-dimensional data. + + + +## **How it works** + +SPTAG is inspired by the NGS approach [[WangL12](#References)]. It contains two basic modules: index builder and searcher. +The RNG is built on the k-nearest neighborhood graph [[WangWZTG12](#References), [WangWJLZZH14](References)] +for boosting the conectivity. Balanced k-means trees are used to replace kd-trees to avoid the inaccurate distance bound estimation in kd-trees for very high-dimensional vectors. +The search begins with the search in the space partition trees for +finding several seeds to start the search in the RNG. +The searches in the trees and the graph are iteratively conducted. + + ## **Highlights** + * Fresh update: Support online vector deletion and insertion + * Distributed serving: Search over multiple machines + + ## **Build** + +### **Requirements** * swig >= 3.0 * cmake >= 3.12.0 * boost >= 1.67.0 -### **compile && install** +### **Install** > For Linux: ```bash @@ -36,16 +58,72 @@ It will generate a Release folder in the code directory which contains all the b mkdir build cd build && cmake -A x64 .. ``` -It will generate a SPTAGLib.sln in the build directory, open the solution in the Visual Studio (at least 2015) and compile the ALL_BUILD project, it will generate a Release directory which contains all the build targets. 
+It will generate a SPTAGLib.sln in the build directory. +Compiling the ALL_BUILD project in the Visual Studio (at least 2015) will generate a Release directory which contains all the build targets. -### **test** +### **Verify** Run the test (or Test.exe) in the Release folder to verify all the tests have passed. -## **Documentation** +### **Usage** -* Overview -* Get started +The detailed usage can be found in [Get started](docs/GettingStart.md). + +## **References** +Please cite SPTAG in your publications if it helps your research: +``` +@manual{ChenW18, + author = {Qi Chen and + Haidong Wang and + Mingqin Li and + Gang Ren and + Jeffery Zhu and + Jason Li and + Lintao Zhang and + Jingdong Wang}, + title = {SPTAG: A library for fast approximate nearest neighbor search}, + url = {https://github.com/Microsoft/SPTAG}, + year = {2018} +} + +@inproceedings{WangL12, + author = {Jingdong Wang and + Shipeng Li}, + title = {Query-driven iterated neighborhood graph search for large scale indexing}, + booktitle = {ACM Multimedia 2012}, + pages = {179--188}, + year = {2012} +} + +@inproceedings{WangWZTGL12, + author = {Jing Wang and + Jingdong Wang and + Gang Zeng and + Zhuowen Tu and + Rui Gan and + Shipeng Li}, + title = {Scalable k-NN graph construction for visual descriptors}, + booktitle = {CVPR 2012}, + pages = {1106--1113}, + year = {2012} +} + +@article{WangWJLZZH14, + author = {Jingdong Wang and + Naiyan Wang and + You Jia and + Jian Li and + Gang Zeng and + Hongbin Zha and + Xian{-}Sheng Hua}, + title = {Trinary-Projection Trees for Approximate Nearest Neighbor Search}, + journal = {{IEEE} Trans. Pattern Anal. Mach. 
Intell.}, + volume = {36}, + number = {2}, + pages = {388--403}, + year = {2014 +} +``` ## **Contribute** diff --git a/SPTAG.sln b/SPTAG.sln index 20c1ccf2..a63464d9 100644 --- a/SPTAG.sln +++ b/SPTAG.sln @@ -46,6 +46,8 @@ Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "IndexBuilder", "AnnService\ EndProject Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "Test", "Test\Test.vcxproj", "{29A25655-CCF2-47F8-8BC8-DFE1B5CF993C}" EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "Search", "Search\Search.vcxproj", "{97615D3B-9FA0-469E-B229-95A91A5087E0}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|x64 = Debug|x64 @@ -126,6 +128,14 @@ Global {29A25655-CCF2-47F8-8BC8-DFE1B5CF993C}.Release|x64.Build.0 = Release|x64 {29A25655-CCF2-47F8-8BC8-DFE1B5CF993C}.Release|x86.ActiveCfg = Release|Win32 {29A25655-CCF2-47F8-8BC8-DFE1B5CF993C}.Release|x86.Build.0 = Release|Win32 + {97615D3B-9FA0-469E-B229-95A91A5087E0}.Debug|x64.ActiveCfg = Debug|x64 + {97615D3B-9FA0-469E-B229-95A91A5087E0}.Debug|x64.Build.0 = Debug|x64 + {97615D3B-9FA0-469E-B229-95A91A5087E0}.Debug|x86.ActiveCfg = Debug|Win32 + {97615D3B-9FA0-469E-B229-95A91A5087E0}.Debug|x86.Build.0 = Debug|Win32 + {97615D3B-9FA0-469E-B229-95A91A5087E0}.Release|x64.ActiveCfg = Release|x64 + {97615D3B-9FA0-469E-B229-95A91A5087E0}.Release|x64.Build.0 = Release|x64 + {97615D3B-9FA0-469E-B229-95A91A5087E0}.Release|x86.ActiveCfg = Release|Win32 + {97615D3B-9FA0-469E-B229-95A91A5087E0}.Release|x86.Build.0 = Release|Win32 EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/Search/CMakeLists.txt b/Search/CMakeLists.txt new file mode 100644 index 00000000..acfa0b38 --- /dev/null +++ b/Search/CMakeLists.txt @@ -0,0 +1,14 @@ +file(GLOB HDR_FILES ${PROJECT_SOURCE_DIR}/AnnService/inc/Core/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Core/Common/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Core/BKT/*.h 
${PROJECT_SOURCE_DIR}/AnnService/inc/Core/KDT/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Helper/*.h) +file(GLOB SRC_FILES ${PROJECT_SOURCE_DIR}/AnnService/src/Core/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Core/Common/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Core/BKT/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Core/KDT/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Helper/*.cpp) + +include_directories(${PROJECT_SOURCE_DIR}/AnnService) + +file(GLOB SEARCH_FILES ${SRC_FILES} ${PROJECT_SOURCE_DIR}/Search/*.cpp) +add_executable (search ${SEARCH_FILES} ${HDR_FILES}) +target_link_libraries(search ${Boost_LIBRARIES} ${TBB_LIBRARIES}) + +install(TARGETS search + RUNTIME DESTINATION bin + ARCHIVE DESTINATION lib + LIBRARY DESTINATION lib) + diff --git a/Search/Search.vcxproj b/Search/Search.vcxproj new file mode 100644 index 00000000..8d2d1cec --- /dev/null +++ b/Search/Search.vcxproj @@ -0,0 +1,170 @@ + + + + + Debug + Win32 + + + Release + Win32 + + + Debug + x64 + + + Release + x64 + + + + 15.0 + {97615D3B-9FA0-469E-B229-95A91A5087E0} + Search + 8.1 + Search + + + + + Application + true + v140 + MultiByte + + + Application + false + v140 + true + MultiByte + + + Application + true + v140 + MultiByte + + + Application + false + v140 + true + MultiByte + + + + + + + + + + + + + + + + + + + + $(SolutionDir)obj\$(Platform)_$(Configuration)\$(ProjectName)\ + $(ProjectDir);$(SolutionDir)AnnService\;$(IncludePath) + $(OutLibDir);$(LibraryPath) + $(OutAppDir) + + + false + + + + CoreLibrary.lib;%(AdditionalDependencies) + + + + + Level3 + MaxSpeed + true + true + true + true + _MBCS;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + + + true + true + + + + + Level3 + Disabled + true + true + + + + + Level3 + Disabled + true + true + _MBCS;_SCL_SECURE_NO_WARNINGS;%(PreprocessorDefinitions) + Guard + ProgramDatabase + + + /guard:cf %(AdditionalOptions) + + + + + Level3 + MaxSpeed + true + true + true + true + + + true + true + + + + + + + + Designer + + + + + + + + + + + + + + + + 
This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. + + + + + + + + + + + \ No newline at end of file diff --git a/Search/Search.vcxproj.filters b/Search/Search.vcxproj.filters new file mode 100644 index 00000000..32a81453 --- /dev/null +++ b/Search/Search.vcxproj.filters @@ -0,0 +1,25 @@ + + + + + {4FC737F1-C7A5-4376-A066-2A32D752A2FF} + cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx + + + {93995380-89BD-4b04-88EB-625FBE52EBFB} + h;hh;hpp;hxx;hm;inl;inc;xsd + + + {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} + rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms + + + + + Source Files + + + + + + \ No newline at end of file diff --git a/Search/main.cpp b/Search/main.cpp new file mode 100644 index 00000000..7cd28b60 --- /dev/null +++ b/Search/main.cpp @@ -0,0 +1,285 @@ +#include "inc/Helper/SimpleIniReader.h" +#include "inc/Helper/CommonHelper.h" +#include "inc/Core/Common.h" +#include "inc/Core/MetadataSet.h" +#include "inc/Core/VectorIndex.h" +#include "inc/Core/SearchQuery.h" +#include "inc/Core/Common/WorkSpace.h" +#include "inc/Core/Common/DataUtils.h" +#include +#include + +using namespace SPTAG; + +template +float CalcRecall(std::vector &results, const std::vector> &truth, int NumQuerys, int K, std::ofstream& log) +{ + float meanrecall = 0, minrecall = MaxDist, maxrecall = 0, stdrecall = 0; + std::vector thisrecall(NumQuerys, 0); + for (int i = 0; i < NumQuerys; i++) + { + for (int id : truth[i]) + { + for (int j = 0; j < K; j++) + { + if (results[i].GetResult(j)->VID == id) + { + thisrecall[i] += 1; + break; + } + } + } + thisrecall[i] /= K; + meanrecall += thisrecall[i]; + if (thisrecall[i] < minrecall) minrecall = thisrecall[i]; + if (thisrecall[i] > maxrecall) maxrecall = thisrecall[i]; + } + meanrecall /= NumQuerys; + for (int i = 0; i < NumQuerys; i++) + { + stdrecall 
+= (thisrecall[i] - meanrecall) * (thisrecall[i] - meanrecall); + } + stdrecall = std::sqrt(stdrecall / NumQuerys); + log << meanrecall << " " << stdrecall << " " << minrecall << " " << maxrecall << std::endl; + return meanrecall; +} + +void LoadTruth(std::ifstream& fp, std::vector>& truth, int NumQuerys, int K) +{ + int get; + std::string line; + for (int i = 0; i < NumQuerys; ++i) + { + truth[i].clear(); + for (int j = 0; j < K; ++j) + { + fp >> get; + truth[i].insert(get); + } + std::getline(fp, line); + } +} + +template +int Process(Helper::IniReader& reader, VectorIndex& index) +{ + std::string queryFile = reader.GetParameter("Index", "QueryFile", std::string("querys.bin")); + std::string truthFile = reader.GetParameter("Index", "TruthFile", std::string("truth.txt")); + std::string outputFile = reader.GetParameter("Index", "ResultFile", std::string("")); + + int numBatchQuerys = reader.GetParameter("Index", "NumBatchQuerys", 10000); + int numDebugQuerys = reader.GetParameter("Index", "NumDebugQuerys", -1); + int K = reader.GetParameter("Index", "K", 32); + + std::vector maxCheck = Helper::StrUtils::SplitString(reader.GetParameter("Index", "MaxCheck", std::string("2048")), "#"); + + std::ifstream inStream(queryFile); + std::ifstream ftruth(truthFile); + std::ofstream fp; + if (!inStream.is_open()) + { + std::cout << "ERROR: Cannot Load Query file " << queryFile << "!" << std::endl; + return -1; + } + if (outputFile != "") + { + fp.open(outputFile); + if (!fp.is_open()) + { + std::cout << "ERROR: Cannot open " << outputFile << " for write!" << std::endl; + } + } + + std::ofstream log(index.GetIndexName() + "_" + std::to_string(K) + ".txt"); + if (!log.is_open()) + { + std::cout << "ERROR: Cannot open logging file!" << std::endl; + return -1; + } + + int numQuerys = (numDebugQuerys >= 0) ? 
numDebugQuerys : numBatchQuerys; + + std::vector> Query(numQuerys, std::vector(index.GetFeatureDim(), 0)); + std::vector> truth(numQuerys); + std::vector results(numQuerys, QueryResult(NULL, K, 0)); + + int * latencies = new int[numQuerys + 1]; + + int base = 1; + if (index.GetDistCalcMethod() == DistCalcMethod::Cosine) { + base = COMMON::Utils::GetBase(); + } + int basesquare = base * base; + + int dims = index.GetFeatureDim(); + std::vector QStrings; + while (!inStream.eof()) + { + QStrings.clear(); + COMMON::Utils::PrepareQuerys(inStream, QStrings, Query, numQuerys, dims, index.GetDistCalcMethod(), base); + if (numQuerys == 0) break; + + for (int i = 0; i < numQuerys; i++) results[i].SetTarget(Query[i].data()); + if (ftruth.is_open()) LoadTruth(ftruth, truth, numQuerys, K); + + std::cout << " \t[avg] \t[99%] \t[95%] \t[recall] \t[mem]" << std::endl; + + int subSize = (numQuerys - 1) / index.GetNumThreads() + 1; + for (std::string& mc : maxCheck) + { + index.SetParameter("MaxCheck", mc.c_str()); + for (int i = 0; i < numQuerys; i++) results[i].Reset(); + + if (index.GetNumThreads() == 1) + { + for (int i = 0; i < numQuerys; i++) + { + latencies[i] = clock(); + index.SearchIndex(results[i]); + } + } + else + { +#pragma omp parallel for + for (int tid = 0; tid < index.GetNumThreads(); tid++) + { + int start = tid * subSize; + int end = min((tid + 1) * subSize, numQuerys); + for (int i = start; i < end; i++) + { + latencies[i] = clock(); + index.SearchIndex(results[i]); + } + } + } + latencies[numQuerys] = clock(); + + float timeMean = 0, timeMin = MaxDist, timeMax = 0, timeStd = 0; + for (int i = 0; i < numQuerys; i++) + { + if (latencies[i + 1] >= latencies[i]) + latencies[i] = latencies[i + 1] - latencies[i]; + else + latencies[i] = latencies[numQuerys] - latencies[i]; + timeMean += latencies[i]; + if (latencies[i] > timeMax) timeMax = (float)latencies[i]; + if (latencies[i] < timeMin) timeMin = (float)latencies[i]; + } + timeMean /= numQuerys; + for (int i = 0; 
i < numQuerys; i++) timeStd += ((float)latencies[i] - timeMean) * ((float)latencies[i] - timeMean); + timeStd = std::sqrt(timeStd / numQuerys); + log << timeMean << " " << timeStd << " " << timeMin << " " << timeMax << " "; + + std::sort(latencies, latencies + numQuerys, [](int x, int y) + { + return x < y; + }); + float l99 = float(latencies[int(numQuerys * 0.99)]) / CLOCKS_PER_SEC; + float l95 = float(latencies[int(numQuerys * 0.95)]) / CLOCKS_PER_SEC; + + float recall = 0; + if (ftruth.is_open()) + { + recall = CalcRecall(results, truth, numQuerys, K, log); + } + +#ifndef _MSC_VER + struct rusage rusage; + getrusage(RUSAGE_SELF, &rusage); + unsigned long long peakWSS = rusage.ru_maxrss * 1024 / 1000000000; +#else + PROCESS_MEMORY_COUNTERS pmc; + GetProcessMemoryInfo(GetCurrentProcess(), &pmc, sizeof(pmc)); + unsigned long long peakWSS = pmc.PeakWorkingSetSize / 1000000000; +#endif + std::cout << mc << "\t" << std::fixed << std::setprecision(6) << (timeMean / CLOCKS_PER_SEC) << "\t" << std::setprecision(4) << l99 << "\t" << l95 << "\t" << recall << "\t\t" << peakWSS << "GB" << std::endl; + + } + + if (fp.is_open()) + { + fp << std::setprecision(3) << std::fixed; + for (int i = 0; i < numQuerys; i++) + { + fp << QStrings[i] << ":"; + for (int j = 0; j < K; j++) + { + if (results[i].GetResult(j)->VID < 0) { + fp << results[i].GetResult(j)->Dist << "@" << results[i].GetResult(j)->VID << std::endl; + } + else { + ByteArray vm = index.GetMetadata(results[i].GetResult(j)->VID); + fp << (results[i].GetResult(j)->Dist / basesquare) << "@"; + fp.write((const char*)vm.Data(), vm.Length()); + } + fp << "|"; + } + fp << std::endl; + } + } + + if (numQuerys < numBatchQuerys || numDebugQuerys >= 0) break; + } + std::cout << "Output results finish!" 
<< std::endl; + + inStream.close(); + fp.close(); + log.close(); + ftruth.close(); + delete[] latencies; + + QStrings.clear(); + results.clear(); + + return 0; +} + +int main(int argc, char** argv) +{ + if (argc < 2) + { + std::cerr << "Search.exe folder" << std::endl; + return -1; + } + + std::shared_ptr vecIndex; + auto ret = SPTAG::VectorIndex::LoadIndex(argv[1], vecIndex); + if (SPTAG::ErrorCode::Success != ret || nullptr == vecIndex) + { + std::cerr << "Cannot open configure file!" << std::endl; + return -1; + } + + Helper::IniReader iniReader; + for (int i = 1; i < argc; i++) + { + std::string param(argv[i]); + size_t idx = param.find("="); + if (idx < 0) continue; + + std::string paramName = param.substr(0, idx); + std::string paramVal = param.substr(idx + 1); + std::string sectionName; + idx = paramName.find("."); + if (idx >= 0) { + sectionName = paramName.substr(0, idx); + paramName = paramName.substr(idx + 1); + } + iniReader.SetParameter(sectionName, paramName, paramVal); + std::cout << "Set [" << sectionName << "]" << paramName << " = " << paramVal << std::endl; + } + + switch (vecIndex->GetVectorValueType()) + { +#define DefineVectorValueType(Name, Type) \ + case VectorValueType::Name: \ + Process(iniReader, *(vecIndex.get())); \ + break; \ + +#include "inc/Core/DefinitionList.h" +#undef DefineVectorValueType + + default: break; + } + return 0; +} diff --git a/Search/packages.config b/Search/packages.config new file mode 100644 index 00000000..424245f6 --- /dev/null +++ b/Search/packages.config @@ -0,0 +1,12 @@ + + + + + + + + + + + + \ No newline at end of file diff --git a/Test/CMakeLists.txt b/Test/CMakeLists.txt index 88384813..93a9f1b8 100644 --- a/Test/CMakeLists.txt +++ b/Test/CMakeLists.txt @@ -1,13 +1,3 @@ -find_package(OpenMP) -if (OpenMP_FOUND) - set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") - set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") - set (CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} 
${OpenMP_EXE_LINKER_FLAGS}") - message (STATUS "Found openmp.") -else() - message (FATAL_ERROR "Could no find openmp!") -endif() - if(NOT WIN32) ADD_DEFINITIONS(-DBOOST_TEST_DYN_LINK) message (STATUS "BOOST_TEST_DYN_LINK") @@ -30,7 +20,7 @@ include_directories(${PYTHON_INCLUDE_PATH} ${PROJECT_SOURCE_DIR}/AnnService ${PR file(GLOB TEST_HDR_FILES ${PROJECT_SOURCE_DIR}/Test/inc/Test.h) file(GLOB TEST_SRC_FILES ${PROJECT_SOURCE_DIR}/Test/src/*.cpp) add_executable (test ${TEST_SRC_FILES} ${TEST_HDR_FILES}) -target_link_libraries(test SPTAGLib ${Boost_LIBRARIES}) +target_link_libraries(test SPTAGLib ${Boost_LIBRARIES} ${TBB_LIBRARIES}) install(TARGETS test RUNTIME DESTINATION bin diff --git a/Test/Test.vcxproj b/Test/Test.vcxproj index bd509da9..da9f2274 100644 --- a/Test/Test.vcxproj +++ b/Test/Test.vcxproj @@ -140,15 +140,21 @@ - + + + + + - + + Designer + @@ -160,6 +166,8 @@ + + @@ -173,5 +181,7 @@ + + \ No newline at end of file diff --git a/Test/Test.vcxproj.filters b/Test/Test.vcxproj.filters index 71509659..a814c3ec 100644 --- a/Test/Test.vcxproj.filters +++ b/Test/Test.vcxproj.filters @@ -21,7 +21,19 @@ Source Files - + + Source Files + + + Source Files + + + Source Files + + + Source Files + + Source Files @@ -30,4 +42,7 @@ Header Files + + + \ No newline at end of file diff --git a/Test/Test.vcxproj.user b/Test/Test.vcxproj.user index abe8dd89..10f0fcf2 100644 --- a/Test/Test.vcxproj.user +++ b/Test/Test.vcxproj.user @@ -1,4 +1,7 @@  - + + $(OutLibDir) + WindowsLocalDebugger + \ No newline at end of file diff --git a/Test/packages.config b/Test/packages.config index 651c7547..ddc362df 100644 --- a/Test/packages.config +++ b/Test/packages.config @@ -2,4 +2,6 @@ + + \ No newline at end of file diff --git a/Test/src/AlgoTest.cpp b/Test/src/AlgoTest.cpp new file mode 100644 index 00000000..1999d349 --- /dev/null +++ b/Test/src/AlgoTest.cpp @@ -0,0 +1,131 @@ +#include "inc/Test.h" +#include "inc/Helper/SimpleIniReader.h" +#include "inc/Core/VectorIndex.h" + 
+template +void Build(SPTAG::IndexAlgoType algo, std::string distCalcMethod, T* vec, int n, int m) +{ + std::vector meta; + std::vector metaoffset; + for (int i = 0; i < n; i++) { + metaoffset.push_back(meta.size()); + std::string a = std::to_string(i); + for (int j = 0; j < a.length(); j++) + meta.push_back(a[j]); + } + metaoffset.push_back(meta.size()); + + std::shared_ptr vecset(new SPTAG::BasicVectorSet( + SPTAG::ByteArray((std::uint8_t*)vec, n * m * sizeof(T), false), + SPTAG::GetEnumValueType(), m, n)); + + std::shared_ptr metaset(new SPTAG::MemMetadataSet( + SPTAG::ByteArray((std::uint8_t*)meta.data(), meta.size() * sizeof(char), false), + SPTAG::ByteArray((std::uint8_t*)metaoffset.data(), metaoffset.size() * sizeof(long long), false), + n)); + std::shared_ptr vecIndex = SPTAG::VectorIndex::CreateInstance(algo, SPTAG::GetEnumValueType()); + vecIndex->SetParameter("DistCalcMethod", distCalcMethod); + BOOST_CHECK(nullptr != vecIndex); + BOOST_CHECK(SPTAG::ErrorCode::Success == vecIndex->BuildIndex(vecset, metaset)); + BOOST_CHECK(SPTAG::ErrorCode::Success == vecIndex->SaveIndex("origindices")); +} + +template +void Search(std::string folder, T* vec, int k) +{ + std::shared_ptr vecIndex; + BOOST_CHECK(SPTAG::ErrorCode::Success == SPTAG::VectorIndex::LoadIndex(folder, vecIndex)); + BOOST_CHECK(nullptr != vecIndex); + + SPTAG::QueryResult res(vec, k, true); + vecIndex->SearchIndex(res); + for (int i = 0; i < k; i++) { + std::cout << res.GetResult(i)->Dist << "@(" << res.GetResult(i)->VID << "," << std::string((char*)res.GetMetadata(i).Data(), res.GetMetadata(i).Length()) << ") "; + } + std::cout << std::endl; + vecIndex.reset(); +} + +template +void Add(T* vec, int n) +{ + std::shared_ptr vecIndex; + BOOST_CHECK(SPTAG::ErrorCode::Success == SPTAG::VectorIndex::LoadIndex("origindices", vecIndex)); + BOOST_CHECK(nullptr != vecIndex); + + std::vector meta; + std::vector metaoffset; + for (int i = 0; i < n; i++) { + metaoffset.push_back(meta.size()); + std::string a 
= std::to_string(vecIndex->GetNumSamples() + i); + for (int j = 0; j < a.length(); j++) + meta.push_back(a[j]); + } + metaoffset.push_back(meta.size()); + + int m = vecIndex->GetFeatureDim(); + std::shared_ptr vecset(new SPTAG::BasicVectorSet( + SPTAG::ByteArray((std::uint8_t*)vec, n * m * sizeof(T), false), + SPTAG::GetEnumValueType(), m, n)); + + std::shared_ptr metaset(new SPTAG::MemMetadataSet( + SPTAG::ByteArray((std::uint8_t*)meta.data(), meta.size() * sizeof(char), false), + SPTAG::ByteArray((std::uint8_t*)metaoffset.data(), metaoffset.size() * sizeof(long long), false), + n)); + + BOOST_CHECK(SPTAG::ErrorCode::Success == vecIndex->AddIndex(vecset, metaset)); + BOOST_CHECK(SPTAG::ErrorCode::Success == vecIndex->SaveIndex("addindices")); + vecIndex.reset(); +} + +template +void Delete(T* vec, int n) +{ + std::shared_ptr vecIndex; + BOOST_CHECK(SPTAG::ErrorCode::Success == SPTAG::VectorIndex::LoadIndex("addindices", vecIndex)); + BOOST_CHECK(nullptr != vecIndex); + + BOOST_CHECK(SPTAG::ErrorCode::Success == vecIndex->DeleteIndex((const void*)vec, n)); + BOOST_CHECK(SPTAG::ErrorCode::Success == vecIndex->SaveIndex("delindices")); + vecIndex.reset(); +} + +template +void Test(SPTAG::IndexAlgoType algo, std::string distCalcMethod) +{ + int n = 100, q = 3, m = 10, k = 3; + std::vector vec; + for (int i = 0; i < n; i++) { + for (int j = 0; j < m; j++) { + vec.push_back((T)i); + } + } + + std::vector query; + for (int i = 0; i < q; i++) { + for (int j = 0; j < m; j++) { + query.push_back((T)i*2); + } + } + + Build(algo, distCalcMethod, vec.data(), n, m); + Search("origindices", query.data(), k); + Add(query.data(), q); + Search("addindices", query.data(), k); + Delete(query.data(), q); + Search("delindices", query.data(), k); +} + +BOOST_AUTO_TEST_SUITE (AlgoTest) + +BOOST_AUTO_TEST_CASE(KDTTest) +{ + Test(SPTAG::IndexAlgoType::KDT, "L2"); +} + +BOOST_AUTO_TEST_CASE(BKTTest) +{ + Test(SPTAG::IndexAlgoType::BKT, "L2"); +} + +BOOST_AUTO_TEST_SUITE_END() diff --git 
a/Test/src/BKTTest.cpp b/Test/src/BKTTest.cpp deleted file mode 100644 index 0ec934ea..00000000 --- a/Test/src/BKTTest.cpp +++ /dev/null @@ -1,25 +0,0 @@ -#include "inc/Test.h" -#include "inc/Helper/SimpleIniReader.h" -#include "inc/Core/VectorIndex.h" - -BOOST_AUTO_TEST_SUITE (BKTTest) - -BOOST_AUTO_TEST_CASE(ParameterTest) -{ - SPTAG::Helper::IniReader reader; - reader.SetParameter("Index", "DistCalcMethod", "Cosine"); - reader.SetParameter("Index", "BKTNumber", "2"); - reader.SetParameter("Index", "NeighborhoodSize", "16"); - - auto index = SPTAG::VectorIndex::CreateInstance(SPTAG::IndexAlgoType::BKT, SPTAG::VectorValueType::Float); - - for (const auto& iter : reader.GetParameters("Index")) - { - index->SetParameter(iter.first.c_str(), iter.second.c_str()); - } - - BOOST_CHECK(index->GetParameter("BKTNumber") == "2"); - BOOST_CHECK(index->GetParameter("NeighborhoodSize") == "16"); -} - -BOOST_AUTO_TEST_SUITE_END() \ No newline at end of file diff --git a/Test/src/Base64HelperTest.cpp b/Test/src/Base64HelperTest.cpp new file mode 100644 index 00000000..57fec601 --- /dev/null +++ b/Test/src/Base64HelperTest.cpp @@ -0,0 +1,40 @@ +#include "inc/Test.h" +#include "inc/Helper/Base64Encode.h" + +#include + +BOOST_AUTO_TEST_SUITE(Base64Test) + +BOOST_AUTO_TEST_CASE(Base64EncDec) +{ + using namespace SPTAG::Helper::Base64; + + const size_t bufferSize = 1 << 10; + std::unique_ptr rawBuffer(new uint8_t[bufferSize]); + std::unique_ptr encBuffer(new char[bufferSize]); + std::unique_ptr rawBuffer2(new uint8_t[bufferSize]); + + for (size_t inputSize = 1; inputSize < 128; ++inputSize) + { + for (size_t i = 0; i < inputSize; ++i) + { + rawBuffer[i] = static_cast(i); + } + + size_t encBufLen = CapacityForEncode(inputSize); + BOOST_CHECK(encBufLen < bufferSize); + + size_t encOutLen = 0; + BOOST_CHECK(Encode(rawBuffer.get(), inputSize, encBuffer.get(), encOutLen)); + BOOST_CHECK(encBufLen >= encOutLen); + + size_t decBufLen = CapacityForDecode(encOutLen); + BOOST_CHECK(decBufLen < 
bufferSize); + + size_t decOutLen = 0; + BOOST_CHECK(Decode(encBuffer.get(), encOutLen, rawBuffer.get(), decOutLen)); + BOOST_CHECK(decBufLen >= decOutLen); + } +} + +BOOST_AUTO_TEST_SUITE_END() \ No newline at end of file diff --git a/Test/src/CommonHelperTest.cpp b/Test/src/CommonHelperTest.cpp new file mode 100644 index 00000000..581fa53b --- /dev/null +++ b/Test/src/CommonHelperTest.cpp @@ -0,0 +1,90 @@ +#include "inc/Test.h" +#include "inc/Helper/CommonHelper.h" + +#include + +BOOST_AUTO_TEST_SUITE(CommonHelperTest) + + +BOOST_AUTO_TEST_CASE(ToLowerInPlaceTest) +{ + auto runTestCase = [](std::string p_input, const std::string& p_expected) + { + SPTAG::Helper::StrUtils::ToLowerInPlace(p_input); + BOOST_CHECK(p_input == p_expected); + }; + + runTestCase("abc", "abc"); + runTestCase("ABC", "abc"); + runTestCase("abC", "abc"); + runTestCase("Upper-Case", "upper-case"); + runTestCase("123!-=aBc", "123!-=abc"); +} + + +BOOST_AUTO_TEST_CASE(SplitStringTest) +{ + std::string input("seg1 seg2 seg3 seg4"); + + const auto& segs = SPTAG::Helper::StrUtils::SplitString(input, " "); + BOOST_CHECK(segs.size() == 4); + BOOST_CHECK(segs[0] == "seg1"); + BOOST_CHECK(segs[1] == "seg2"); + BOOST_CHECK(segs[2] == "seg3"); + BOOST_CHECK(segs[3] == "seg4"); +} + + +BOOST_AUTO_TEST_CASE(FindTrimmedSegmentTest) +{ + using namespace SPTAG::Helper::StrUtils; + std::string input("\t Space End \r\n\t"); + + const auto& pos = FindTrimmedSegment(input.c_str(), + input.c_str() + input.size(), + [](char p_val)->bool + { + return std::isspace(p_val) > 0; + }); + + BOOST_CHECK(pos.first == input.c_str() + 2); + BOOST_CHECK(pos.second == input.c_str() + 13); +} + + +BOOST_AUTO_TEST_CASE(StartsWithTest) +{ + using namespace SPTAG::Helper::StrUtils; + + BOOST_CHECK(StartsWith("Abcd", "A")); + BOOST_CHECK(StartsWith("Abcd", "Ab")); + BOOST_CHECK(StartsWith("Abcd", "Abc")); + BOOST_CHECK(StartsWith("Abcd", "Abcd")); + + BOOST_CHECK(!StartsWith("Abcd", "a")); + BOOST_CHECK(!StartsWith("Abcd", "F")); + 
BOOST_CHECK(!StartsWith("Abcd", "AF")); + BOOST_CHECK(!StartsWith("Abcd", "AbF")); + BOOST_CHECK(!StartsWith("Abcd", "AbcF")); + BOOST_CHECK(!StartsWith("Abcd", "Abcde")); +} + + +BOOST_AUTO_TEST_CASE(StrEqualIgnoreCaseTest) +{ + using namespace SPTAG::Helper::StrUtils; + + BOOST_CHECK(StrEqualIgnoreCase("Abcd", "Abcd")); + BOOST_CHECK(StrEqualIgnoreCase("Abcd", "abcd")); + BOOST_CHECK(StrEqualIgnoreCase("Abcd", "abCD")); + BOOST_CHECK(StrEqualIgnoreCase("Abcd-123", "abcd-123")); + BOOST_CHECK(StrEqualIgnoreCase(" ZZZ", " zzz")); + + BOOST_CHECK(!StrEqualIgnoreCase("abcd", "abcd1")); + BOOST_CHECK(!StrEqualIgnoreCase("Abcd", " abcd")); + BOOST_CHECK(!StrEqualIgnoreCase("000", "OOO")); +} + + + +BOOST_AUTO_TEST_SUITE_END() \ No newline at end of file diff --git a/Test/src/DistanceTest.cpp b/Test/src/DistanceTest.cpp index 54ff51be..548dd32e 100644 --- a/Test/src/DistanceTest.cpp +++ b/Test/src/DistanceTest.cpp @@ -36,8 +36,8 @@ void test(int high) { X[i] = random(high, -high); Y[i] = random(high, -high); } - BOOST_CHECK_CLOSE_FRACTION(ComputeL2Distance(X, Y, dimension), SPTAG::COMMON::DistanceUtils::ComputeL2Distance(X, Y, dimension), 1e-6); - BOOST_CHECK_CLOSE_FRACTION(high*high - ComputeCosineDistance(X, Y, dimension), SPTAG::COMMON::DistanceUtils::ComputeCosineDistance(X, Y, dimension), 1e-6); + BOOST_CHECK_CLOSE_FRACTION(ComputeL2Distance(X, Y, dimension), SPTAG::COMMON::DistanceUtils::ComputeL2Distance(X, Y, dimension), 1e-5); + BOOST_CHECK_CLOSE_FRACTION(high*high - ComputeCosineDistance(X, Y, dimension), SPTAG::COMMON::DistanceUtils::ComputeCosineDistance(X, Y, dimension), 1e-5); delete[] X; delete[] Y; diff --git a/Test/src/IniReaderTest.cpp b/Test/src/IniReaderTest.cpp new file mode 100644 index 00000000..6f060f4d --- /dev/null +++ b/Test/src/IniReaderTest.cpp @@ -0,0 +1,37 @@ +#include "inc/Test.h" +#include "inc/Helper/SimpleIniReader.h" + +#include + +BOOST_AUTO_TEST_SUITE(IniReaderTest) + +BOOST_AUTO_TEST_CASE(IniReaderLoadTest) +{ + std::ofstream 
tmpIni("temp.ini"); + tmpIni << "[Common]" << std::endl; + tmpIni << "; Comment " << std::endl; + tmpIni << "Param1=1" << std::endl; + tmpIni << "Param2=Exp=2" << std::endl; + + tmpIni.close(); + + SPTAG::Helper::IniReader reader; + BOOST_CHECK(SPTAG::ErrorCode::Success == reader.LoadIniFile("temp.ini")); + + BOOST_CHECK(reader.DoesSectionExist("Common")); + BOOST_CHECK(reader.DoesParameterExist("Common", "Param1")); + BOOST_CHECK(reader.DoesParameterExist("Common", "Param2")); + + BOOST_CHECK(!reader.DoesSectionExist("NotExist")); + BOOST_CHECK(!reader.DoesParameterExist("NotExist", "Param1")); + BOOST_CHECK(!reader.DoesParameterExist("Common", "ParamNotExist")); + + BOOST_CHECK(1 == reader.GetParameter("Common", "Param1", 0)); + BOOST_CHECK(0 == reader.GetParameter("Common", "ParamNotExist", 0)); + + BOOST_CHECK(std::string("Exp=2") == reader.GetParameter("Common", "Param2", std::string())); + BOOST_CHECK(std::string("1") == reader.GetParameter("Common", "Param1", std::string())); + BOOST_CHECK(std::string() == reader.GetParameter("Common", "ParamNotExist", std::string())); +} + +BOOST_AUTO_TEST_SUITE_END() \ No newline at end of file diff --git a/Test/src/StringConvertTest.cpp b/Test/src/StringConvertTest.cpp new file mode 100644 index 00000000..82903e03 --- /dev/null +++ b/Test/src/StringConvertTest.cpp @@ -0,0 +1,125 @@ +#include "inc/Test.h" +#include "inc/Helper/StringConvert.h" + +namespace +{ + namespace Local + { + + template + void TestConvertSuccCase(ValueType p_val, const char* p_valStr) + { + using namespace SPTAG::Helper::Convert; + + std::string str = ConvertToString(p_val); + if (nullptr != p_valStr) + { + BOOST_CHECK(str == p_valStr); + } + + ValueType val; + BOOST_CHECK(ConvertStringTo(str.c_str(), val)); + BOOST_CHECK(val == p_val); + } + + } +} + +BOOST_AUTO_TEST_SUITE(StringConvertTest) + +BOOST_AUTO_TEST_CASE(ConvertInt8) +{ + Local::TestConvertSuccCase(static_cast(-1), "-1"); + Local::TestConvertSuccCase(static_cast(0), "0"); + 
Local::TestConvertSuccCase(static_cast(3), "3"); + Local::TestConvertSuccCase(static_cast(100), "100"); +} + +BOOST_AUTO_TEST_CASE(ConvertInt16) +{ + Local::TestConvertSuccCase(static_cast(-1), "-1"); + Local::TestConvertSuccCase(static_cast(0), "0"); + Local::TestConvertSuccCase(static_cast(3), "3"); + Local::TestConvertSuccCase(static_cast(100), "100"); +} + +BOOST_AUTO_TEST_CASE(ConvertInt32) +{ + Local::TestConvertSuccCase(static_cast(-1), "-1"); + Local::TestConvertSuccCase(static_cast(0), "0"); + Local::TestConvertSuccCase(static_cast(3), "3"); + Local::TestConvertSuccCase(static_cast(100), "100"); +} + +BOOST_AUTO_TEST_CASE(ConvertInt64) +{ + Local::TestConvertSuccCase(static_cast(-1), "-1"); + Local::TestConvertSuccCase(static_cast(0), "0"); + Local::TestConvertSuccCase(static_cast(3), "3"); + Local::TestConvertSuccCase(static_cast(100), "100"); +} + +BOOST_AUTO_TEST_CASE(ConvertUInt8) +{ + Local::TestConvertSuccCase(static_cast(0), "0"); + Local::TestConvertSuccCase(static_cast(3), "3"); + Local::TestConvertSuccCase(static_cast(100), "100"); +} + +BOOST_AUTO_TEST_CASE(ConvertUInt16) +{ + Local::TestConvertSuccCase(static_cast(0), "0"); + Local::TestConvertSuccCase(static_cast(3), "3"); + Local::TestConvertSuccCase(static_cast(100), "100"); +} + +BOOST_AUTO_TEST_CASE(ConvertUInt32) +{ + Local::TestConvertSuccCase(static_cast(0), "0"); + Local::TestConvertSuccCase(static_cast(3), "3"); + Local::TestConvertSuccCase(static_cast(100), "100"); +} + +BOOST_AUTO_TEST_CASE(ConvertUInt64) +{ + Local::TestConvertSuccCase(static_cast(0), "0"); + Local::TestConvertSuccCase(static_cast(3), "3"); + Local::TestConvertSuccCase(static_cast(100), "100"); +} + +BOOST_AUTO_TEST_CASE(ConvertFloat) +{ + Local::TestConvertSuccCase(static_cast(-1), nullptr); + Local::TestConvertSuccCase(static_cast(0), nullptr); + Local::TestConvertSuccCase(static_cast(3), nullptr); + Local::TestConvertSuccCase(static_cast(100), nullptr); +} + +BOOST_AUTO_TEST_CASE(ConvertDouble) +{ + 
 Local::TestConvertSuccCase(static_cast(-1), nullptr); + Local::TestConvertSuccCase(static_cast(0), nullptr); + Local::TestConvertSuccCase(static_cast(3), nullptr); + Local::TestConvertSuccCase(static_cast(100), nullptr); +} + +BOOST_AUTO_TEST_CASE(ConvertIndexAlgoType) +{ + Local::TestConvertSuccCase(SPTAG::IndexAlgoType::BKT, "BKT"); + Local::TestConvertSuccCase(SPTAG::IndexAlgoType::KDT, "KDT"); +} + +BOOST_AUTO_TEST_CASE(ConvertVectorValueType) +{ + Local::TestConvertSuccCase(SPTAG::VectorValueType::Float, "Float"); + Local::TestConvertSuccCase(SPTAG::VectorValueType::Int8, "Int8"); + Local::TestConvertSuccCase(SPTAG::VectorValueType::Int16, "Int16"); +} + +BOOST_AUTO_TEST_CASE(ConvertDistCalcMethod) +{ + Local::TestConvertSuccCase(SPTAG::DistCalcMethod::Cosine, "Cosine"); + Local::TestConvertSuccCase(SPTAG::DistCalcMethod::L2, "L2"); +} + +BOOST_AUTO_TEST_SUITE_END() \ No newline at end of file diff --git a/azure-pipelines.yml b/azure-pipelines.yml index a62dbdc4..22f697bd 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -24,7 +24,7 @@ phases: cmake .. make cd ../Release - ./test + ./test displayName: 'Command Line Script' @@ -110,7 +110,7 @@ phases: - - script: '.\x64\Debug\Test.exe' + - script: '.\x64\Debug\Test.exe' displayName: 'Command Line Script' diff --git a/docs/GettingStart.md b/docs/GettingStart.md new file mode 100644 index 00000000..2f1e868f --- /dev/null +++ b/docs/GettingStart.md @@ -0,0 +1,217 @@ +## **Quick start** + +### **Index Build** + ```bash + Usage: + ./IndexBuilder [options] + Options: + -d, --dimension Dimension of vector, required. + -v, --vectortype Input vector data type (e.g. Float, Int8, Int16), required. + -i, --input Input raw data, required. + -o, --outputfolder Output folder, required. + -a, --algo Index Algorithm type, required. + + -t, --thread Thread Number, default is 32. + --delimiter Vector delimiter, default is |. + Index.<ArgName>=<ArgValue> Set the algorithm parameter ArgName with value ArgValue. 
+ ``` + + ### **Index Search** + ```bash + Usage: + ./Search [options] + Options: + Index.QueryFile=XXX Input Query file + Index.ResultFile=XXX Output result file + Index.TruthFile=XXX Truth file that can help to calculate the recall + Index.K=XXX How many nearest neighbors to return + Index.MaxCheck=XXX The maxcheck of the search + ``` + +### **Server** +```bash +Usage: +./Server [options] +Options: + -m, --mode Service mode, interactive or socket. + -c, --config Configure file of the index + +Write a server configuration file service.ini as follows: + +[Service] +ListenAddr=0.0.0.0 +ListenPort=8000 +ThreadNumber=8 +SocketThreadNumber=8 + +[QueryConfig] +DefaultMaxResultNumber=6 +DefaultSeparator=| + +[Index] +List=BKT + +[Index_BKT] +IndexFolder=BKT_gist +``` + +### **Client** +```bash +Usage: +./Client [options] +Options: +-s, --server Server address +-p, --port Server port +-t, Search timeout +-cth, Client Thread Number +-sth Socket Thread Number +``` + +### **Aggregator** +```bash +Usage: +./Aggregator + +Write Aggregator.ini as follows: + +[Service] +ListenAddr=0.0.0.0 +ListenPort=8100 +ThreadNumber=8 +SocketThreadNumber=8 + +[Servers] +Number=2 + +[Server_0] +Address=127.0.0.1 +Port=8000 + +[Server_1] +Address=127.0.0.1 +Port=8010 +``` + +### **Python Support** +> Singlebox PythonWrapper + ```python + +import SPTAG +import numpy as np + +n = 100 +k = 3 +r = 3 + +def testBuild(algo, distmethod, x, out): + i = SPTAG.AnnIndex(algo, 'Float', x.shape[1]) + i.SetBuildParam("NumberOfThreads", '4') + i.SetBuildParam("DistCalcMethod", distmethod) + ret = i.Build(x.tobytes(), x.shape[0]) + i.Save(out) + +def testBuildWithMetaData(algo, distmethod, x, s, out): + i = SPTAG.AnnIndex(algo, 'Float', x.shape[1]) + i.SetBuildParam("NumberOfThreads", '4') + i.SetBuildParam("DistCalcMethod", distmethod) + if i.BuildWithMetaData(x.tobytes(), s, x.shape[0]): + i.Save(out) + +def testSearch(index, q, k): + j = SPTAG.AnnIndex.Load(index) + for t in range(q.shape[0]): + result = 
j.Search(q[t].tobytes(), k) + print (result[0]) # ids + print (result[1]) # distances + +def testSearchWithMetaData(index, q, k): + j = SPTAG.AnnIndex.Load(index) + j.SetSearchParam("MaxCheck", '1024') + for t in range(q.shape[0]): + result = j.SearchWithMetaData(q[t].tobytes(), k) + print (result[0]) # ids + print (result[1]) # distances + print (result[2]) # metadata + +def testAdd(index, x, out, algo, distmethod): + if index != None: + i = SPTAG.AnnIndex.Load(index) + else: + i = SPTAG.AnnIndex(algo, 'Float', x.shape[1]) + i.SetBuildParam("NumberOfThreads", '4') + i.SetBuildParam("DistCalcMethod", distmethod) + if i.Add(x.tobytes(), x.shape[0]): + i.Save(out) + +def testAddWithMetaData(index, x, s, out, algo, distmethod): + if index != None: + i = SPTAG.AnnIndex.Load(index) + else: + i = SPTAG.AnnIndex(algo, 'Float', x.shape[1]) + i = SPTAG.AnnIndex(algo, 'Float', x.shape[1]) + i.SetBuildParam("NumberOfThreads", '4') + i.SetBuildParam("DistCalcMethod", distmethod) + if i.AddWithMetaData(x.tobytes(), s, x.shape[0]): + i.Save(out) + +def testDelete(index, x, out): + i = SPTAG.AnnIndex.Load(index) + ret = i.Delete(x.tobytes(), x.shape[0]) + print (ret) + i.Save(out) + +def Test(algo, distmethod): + x = np.ones((n, 10), dtype=np.float32) * np.reshape(np.arange(n, dtype=np.float32), (n, 1)) + q = np.ones((r, 10), dtype=np.float32) * np.reshape(np.arange(r, dtype=np.float32), (r, 1)) * 2 + m = '' + for i in range(n): + m += str(i) + '\n' + + print ("Build.............................") + testBuild(algo, distmethod, x, 'testindices') + testSearch('testindices', q, k) + print ("Add.............................") + testAdd('testindices', x, 'testindices', algo, distmethod) + testSearch('testindices', q, k) + print ("Delete.............................") + testDelete('testindices', q, 'testindices') + testSearch('testindices', q, k) + + print ("AddWithMetaData.............................") + testAddWithMetaData(None, x, m, 'testindices', algo, distmethod) + print 
 ("Delete.............................") + testSearchWithMetaData('testindices', q, k) + testDelete('testindices', q, 'testindices') + testSearchWithMetaData('testindices', q, k) + +if __name__ == '__main__': + Test('BKT', 'L2') + Test('KDT', 'L2') + + ``` + + > Python Client Wrapper. Suppose there is a server running at 127.0.0.1:8000 serving nytimes datasets: + ```python +import SPTAGClient +import numpy as np +import time + +def testSPTAGClient(): + index = SPTAGClient.AnnClient('127.0.0.1', '8100') + while not index.IsConnected(): + time.sleep(1) + index.SetTimeoutMilliseconds(18000) + + q = np.ones((10, 10), dtype=np.float32) + for t in range(q.shape[0]): + result = index.Search(q[t].tobytes(), 6, 'Float', False) + print (result[0]) + print (result[1]) + +if __name__ == '__main__': + testSPTAGClient() + + ``` + + + \ No newline at end of file