diff --git a/.gitignore b/.gitignore
index 2869b140..57fdf2db 100644
--- a/.gitignore
+++ b/.gitignore
@@ -71,3 +71,4 @@ dkms.conf
/obj/x64_Debug
/x64/Debug
/packages
+/Search/Search.vcxproj.user
diff --git a/AnnService/CMakeLists.txt b/AnnService/CMakeLists.txt
index d834e400..f5102115 100644
--- a/AnnService/CMakeLists.txt
+++ b/AnnService/CMakeLists.txt
@@ -1,25 +1,3 @@
-find_package(OpenMP)
-if (OpenMP_FOUND)
- set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
- set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
- set (CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}")
- message (STATUS "Found openmp.")
-else()
- message (FATAL_ERROR "Could no find openmp!")
-endif()
-
-find_package(Boost 1.67 COMPONENTS system thread serialization wserialization regex)
-if (Boost_FOUND)
- include_directories (${Boost_INCLUDE_DIR})
- link_directories (${Boost_LIBRARY_DIR} "/usr/lib")
- message (STATUS "Found Boost.")
- message (STATUS "Include Path: ${Boost_INCLUDE_DIRS}")
- message (STATUS "Library Path: ${Boost_LIBRARY_DIRS}")
- message (STATUS "Library: ${Boost_LIBRARIES}")
-else()
- message (FATAL_ERROR "Could not find Boost 1.67!")
-endif()
-
file(GLOB HDR_FILES ${PROJECT_SOURCE_DIR}/AnnService/inc/Core/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Core/Common/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Core/BKT/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Core/KDT/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Helper/*.h)
file(GLOB SRC_FILES ${PROJECT_SOURCE_DIR}/AnnService/src/Core/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Core/Common/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Core/BKT/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Core/KDT/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Helper/*.cpp)
@@ -32,24 +10,24 @@ set_target_properties(SPTAGLibStatic PROPERTIES OUTPUT_NAME SPTAGLib)
file(GLOB SERVER_HDR_FILES ${HDR_FILES} ${PROJECT_SOURCE_DIR}/AnnService/inc/Server/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Socket/*.h)
file(GLOB SERVER_FILES ${SRC_FILES} ${PROJECT_SOURCE_DIR}/AnnService/src/Server/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Socket/*.cpp)
add_executable (server ${SERVER_FILES} ${SERVER_HDR_FILES})
-target_link_libraries(server ${Boost_LIBRARIES})
+target_link_libraries(server ${Boost_LIBRARIES} ${TBB_LIBRARIES})
file(GLOB CLIENT_HDR_FILES ${HDR_FILES} ${PROJECT_SOURCE_DIR}/AnnService/inc/Client/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Socket/*.h)
file(GLOB CLIENT_FILES ${SRC_FILES} ${PROJECT_SOURCE_DIR}/AnnService/src/Client/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Socket/*.cpp)
add_executable (client ${CLIENT_FILES} ${CLIENT_HDR_FILES})
-target_link_libraries(client ${Boost_LIBRARIES})
+target_link_libraries(client ${Boost_LIBRARIES} ${TBB_LIBRARIES})
file(GLOB AGG_HDR_FILES ${HDR_FILES} ${PROJECT_SOURCE_DIR}/AnnService/inc/Aggregator/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/Socket/*.h)
file(GLOB AGG_FILES ${SRC_FILES} ${PROJECT_SOURCE_DIR}/AnnService/src/Aggregator/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/Socket/*.cpp)
add_executable (aggregator ${AGG_FILES} ${AGG_HDR_FILES})
-target_link_libraries(aggregator ${Boost_LIBRARIES})
+target_link_libraries(aggregator ${Boost_LIBRARIES} ${TBB_LIBRARIES})
file(GLOB BUILDER_HDR_FILES ${HDR_FILES} ${PROJECT_SOURCE_DIR}/AnnService/inc/IndexBuilder/*.h ${PROJECT_SOURCE_DIR}/AnnService/inc/IndexBuilder/VectorSetReaders/*.h)
file(GLOB BUILDER_FILES ${SRC_FILES} ${PROJECT_SOURCE_DIR}/AnnService/src/IndexBuilder/*.cpp ${PROJECT_SOURCE_DIR}/AnnService/src/IndexBuilder/VectorSetReaders/*.cpp)
add_executable (indexbuilder ${BUILDER_FILES} ${BUILDER_HDR_FILES})
-target_link_libraries(indexbuilder ${Boost_LIBRARIES})
+target_link_libraries(indexbuilder ${Boost_LIBRARIES} ${TBB_LIBRARIES})
install(TARGETS SPTAGLib SPTAGLibStatic server client aggregator indexbuilder
RUNTIME DESTINATION bin
ARCHIVE DESTINATION lib
- LIBRARY DESTINATION lib)
\ No newline at end of file
+ LIBRARY DESTINATION lib)
diff --git a/AnnService/CoreLibrary.vcxproj b/AnnService/CoreLibrary.vcxproj
index adda90e4..9844c709 100644
--- a/AnnService/CoreLibrary.vcxproj
+++ b/AnnService/CoreLibrary.vcxproj
@@ -131,6 +131,7 @@
+
@@ -171,7 +172,17 @@
+
+
+
+
+
+
+ This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}.
+
+
+
\ No newline at end of file
diff --git a/AnnService/CoreLibrary.vcxproj.filters b/AnnService/CoreLibrary.vcxproj.filters
index 9c44efaa..7d27224d 100644
--- a/AnnService/CoreLibrary.vcxproj.filters
+++ b/AnnService/CoreLibrary.vcxproj.filters
@@ -115,6 +115,9 @@
Header Files\Core\KDT
+
+ Header Files\Core\Common
+
@@ -132,9 +135,6 @@
Source Files\Helper
-
- Source Files\Core
-
Source Files\Core
diff --git a/AnnService/IndexBuilder.vcxproj b/AnnService/IndexBuilder.vcxproj
index 50638b29..b5efd475 100644
--- a/AnnService/IndexBuilder.vcxproj
+++ b/AnnService/IndexBuilder.vcxproj
@@ -159,6 +159,7 @@
+
@@ -171,5 +172,6 @@
+
\ No newline at end of file
diff --git a/AnnService/Server.vcxproj b/AnnService/Server.vcxproj
index d830f3bc..c2336176 100644
--- a/AnnService/Server.vcxproj
+++ b/AnnService/Server.vcxproj
@@ -137,6 +137,7 @@
+
@@ -149,5 +150,6 @@
+
\ No newline at end of file
diff --git a/AnnService/inc/Core/BKT/Index.h b/AnnService/inc/Core/BKT/Index.h
index 8d8cc21f..c14aa815 100644
--- a/AnnService/inc/Core/BKT/Index.h
+++ b/AnnService/inc/Core/BKT/Index.h
@@ -12,11 +12,13 @@
#include "../Common/Dataset.h"
#include "../Common/WorkSpace.h"
#include "../Common/WorkSpacePool.h"
+#include "../Common/FineGrainedLock.h"
+#include "../Common/DataUtils.h"
#include
-#include
#include
#include
+#include
namespace SPTAG
{
@@ -36,7 +38,7 @@ namespace BKT
int childStart;
int childEnd;
- BKTNode(int cid = -1) : centerid(cid), childStart(-1) {}
+ BKTNode(int cid = -1) : centerid(cid), childStart(-1), childEnd(-1) {}
};
template
@@ -119,8 +121,7 @@ namespace BKT
int m_iDataSize;
int m_iDataDimension;
COMMON::Dataset m_pSamples;
- std::shared_ptr m_pMetadata;
-
+
// BKT structures.
int m_iBKTNumber;
std::vector m_pBKTStart;
@@ -156,7 +157,6 @@ namespace BKT
char* m_pGraphMemoryFile;
char* m_pDataPointsMemoryFile;
- int m_iNumberOfThreads;
DistCalcMethod m_iDistCalcMethod;
float(*m_fComputeDistance)(const T* pX, const T* pY, int length);
@@ -167,8 +167,11 @@ namespace BKT
int g_iNumberOfInitialDynamicPivots;
int g_iNumberOfOtherDynamicPivots;
+ int m_iNumberOfThreads;
+ std::mutex m_dataAllocLock;
+ COMMON::FineGrainedLock m_dataUpdateLock;
+ tbb::concurrent_unordered_set m_deletedID;
std::unique_ptr m_workSpacePool;
-
public:
Index() : m_iBKTNumber(1),
m_iBKTKmeansK(32),
@@ -204,95 +207,29 @@ namespace BKT
int GetFeatureDim() const { return m_pSamples.C(); }
int GetNumThreads() const { return m_iNumberOfThreads; }
int GetCurrMaxCheck() const { return m_iMaxCheck; }
+
DistCalcMethod GetDistCalcMethod() const { return m_iDistCalcMethod; }
IndexAlgoType GetIndexAlgoType() const { return IndexAlgoType::BKT; }
- VectorValueType AcceptableQueryValueType() const { return GetEnumValueType(); }
- void SetMetadata(const std::string& metadataFilePath, const std::string& metadataIndexPath) {
- m_pMetadata.reset(new FileMetadataSet(metadataFilePath, metadataIndexPath));
- }
- ByteArray GetMetadata(IndexType p_vectorID) const {
- if (nullptr != m_pMetadata)
- {
- return m_pMetadata->GetMetadata(p_vectorID);
- }
- return ByteArray::c_empty;
- }
+ VectorValueType GetVectorValueType() const { return GetEnumValueType(); }
- bool BuildIndex();
- bool BuildIndex(void* p_data, int p_vectorNum, int p_dimension);
- ErrorCode BuildIndex(std::shared_ptr p_vectorSet,
- std::shared_ptr p_metadataSet);
+ ErrorCode BuildIndex(const void* p_data, int p_vectorNum, int p_dimension);
+
+ ErrorCode LoadIndex(const std::string& p_folderPath);
+ ErrorCode LoadIndexFromMemory(const std::vector& p_indexBlobs);
- bool LoadIndex();
- ErrorCode LoadIndex(const std::string& p_folderPath, const Helper::IniReader& p_configReader);
-
- bool SaveIndex();
ErrorCode SaveIndex(const std::string& p_folderPath);
- void SearchIndex(COMMON::QueryResultSet &query, COMMON::WorkSpace &space) const;
- ErrorCode SearchIndex(QueryResult &query) const;
-
- void AddNodes(const T* pData, int num, COMMON::WorkSpace &space);
-
+ void SearchIndex(COMMON::QueryResultSet &p_query, COMMON::WorkSpace &p_space, const tbb::concurrent_unordered_set &p_deleted) const;
+ ErrorCode SearchIndex(QueryResult &p_query) const;
+
+ ErrorCode AddIndex(const void* p_vectors, int p_vectorNum, int p_dimension);
+ ErrorCode DeleteIndex(const void* p_vectors, int p_vectorNum);
+ ErrorCode RefineIndex(const std::string& p_folderPath);
+ ErrorCode MergeIndex(const char* p_indexFilePath1, const char* p_indexFilePath2);
+
ErrorCode SetParameter(const char* p_param, const char* p_value);
std::string GetParameter(const char* p_param) const;
- // This can be used for both building model files or searching with model files loaded.
- void SetParameters(std::string dataPointsFile,
- std::string BKTFile,
- std::string graphFile,
- int numBKT,
- int neighborhoodSize,
- int kmeansK,
- int BKTLeafSize,
- int numSamplesBKT,
- int numTPTrees,
- int TPTLeafSize,
- int maxCheckForRefineGraph,
- int numThreads,
- DistCalcMethod distCalcMethod,
- int cacheSize = -1,
- int numPoints = -1)
- {
- m_sDataPointsFilename = dataPointsFile;
- m_sBKTFilename = BKTFile;
- m_sGraphFilename = graphFile;
- m_iBKTNumber = numBKT;
- m_iNeighborhoodSize = neighborhoodSize;
- m_iBKTKmeansK = kmeansK;
- m_iBKTLeafSize = BKTLeafSize;
- m_iSamples = numSamplesBKT;
- m_iTptreeNumber = numTPTrees;
- m_iTPTLeafSize = TPTLeafSize;
- m_iMaxCheckForRefineGraph = maxCheckForRefineGraph;
- m_iNumberOfThreads = numThreads;
- m_iDistCalcMethod = distCalcMethod;
- m_fComputeDistance = COMMON::DistanceCalcSelector(m_iDistCalcMethod);
-
- m_iCacheSize = cacheSize;
- m_iDebugLoad = numPoints;
- }
-
- // Only used for searching with memory mapped model files
- void SetParameters(char* pDataPointsMemFile,
- char* pBKTMemFile,
- char* pGraphMemFile,
- DistCalcMethod distCalcMethod,
- int maxCheck,
- int numBKT,
- int neighborhoodSize)
- {
- m_pDataPointsMemoryFile = pDataPointsMemFile;
- m_pBKTMemoryFile = pBKTMemFile;
- m_pGraphMemoryFile = pGraphMemFile;
- m_iMaxCheck = maxCheck;
- m_iBKTNumber = numBKT;
- m_iNeighborhoodSize = neighborhoodSize;
- m_iNumberOfThreads = 1;
- m_iDistCalcMethod = distCalcMethod;
- m_fComputeDistance = COMMON::DistanceCalcSelector(m_iDistCalcMethod);
- }
-
private:
// Functions for loading models from files
bool LoadDataPoints(std::string sDataPointsFileName);
@@ -307,8 +244,8 @@ namespace BKT
bool SaveDataPoints(std::string sDataPointsFileName);
// Functions for building balanced kmeans tree
- void BuildBKT();
- bool SaveBKT(std::string sBKTFilename) const;
+ void BuildBKT(std::vector& indices, std::vector& newStart, std::vector& newRoot);
+ bool SaveBKT(std::string sBKTFilename, std::vector& newStart, std::vector& newRoot) const;
float KmeansAssign(std::vector& indices, const int first, const int last, KmeansArgs& args, bool updateCenters);
int KmeansClustering(std::vector& indices, const int first, const int last, KmeansArgs& args);
diff --git a/AnnService/inc/Core/Common.h b/AnnService/inc/Core/Common.h
index 75279390..7d61675b 100644
--- a/AnnService/inc/Core/Common.h
+++ b/AnnService/inc/Core/Common.h
@@ -7,6 +7,7 @@
#include
#include
#include
+#include
#ifndef _MSC_VER
#include
@@ -29,6 +30,11 @@ template
inline T max(T a, T b) {
return a > b ? a : b;
}
+
+#ifndef _rotl
+#define _rotl(x, n) (((x) << (n)) | ((x) >> (32-(n))))
+#endif
+
#else
#define WIN32_LEAN_AND_MEAN
#include
diff --git a/AnnService/inc/Core/Common/DataUtils.h b/AnnService/inc/Core/Common/DataUtils.h
index 97474827..2ecd7faa 100644
--- a/AnnService/inc/Core/Common/DataUtils.h
+++ b/AnnService/inc/Core/Common/DataUtils.h
@@ -12,177 +12,275 @@ namespace SPTAG
{
const int bufsize = 1024 * 1024 * 1024;
- template
- void ProcessTSVData(int id, int threadbase, long long blocksize,
- std::string filename, std::string outfile, std::string outmetafile, std::string outmetaindexfile,
- std::atomic_int& numSamples, int& D, DistCalcMethod distCalcMethod) {
- std::ifstream inputStream(filename);
- if (!inputStream.is_open()) {
- std::cerr << "unable to open file " + filename << std::endl;
- throw MyException("unable to open file " + filename);
- exit(1);
- }
- std::ofstream outputStream, metaStream_out, metaStream_index;
- outputStream.open(outfile + std::to_string(id + threadbase), std::ofstream::binary);
- metaStream_out.open(outmetafile + std::to_string(id + threadbase), std::ofstream::binary);
- metaStream_index.open(outmetaindexfile + std::to_string(id + threadbase), std::ofstream::binary);
- if (!outputStream.is_open() || !metaStream_out.is_open() || !metaStream_index.is_open()) {
- std::cerr << "unable to open output file " << outfile << " " << outmetafile << " " << outmetaindexfile << std::endl;
- throw MyException("unable to open output files");
- exit(1);
- }
+ class DataUtils {
+ public:
+ template
+ static void ProcessTSVData(int id, int threadbase, std::uint64_t blocksize,
+ std::string filename, std::string outfile, std::string outmetafile, std::string outmetaindexfile,
+ std::atomic_int& numSamples, int& D, DistCalcMethod distCalcMethod) {
+ std::ifstream inputStream(filename);
+ if (!inputStream.is_open()) {
+ std::cerr << "unable to open file " + filename << std::endl;
+ throw MyException("unable to open file " + filename);
+ exit(1);
+ }
+ std::ofstream outputStream, metaStream_out, metaStream_index;
+ outputStream.open(outfile + std::to_string(id + threadbase), std::ofstream::binary);
+ metaStream_out.open(outmetafile + std::to_string(id + threadbase), std::ofstream::binary);
+ metaStream_index.open(outmetaindexfile + std::to_string(id + threadbase), std::ofstream::binary);
+ if (!outputStream.is_open() || !metaStream_out.is_open() || !metaStream_index.is_open()) {
+ std::cerr << "unable to open output file " << outfile << " " << outmetafile << " " << outmetaindexfile << std::endl;
+ throw MyException("unable to open output files");
+ exit(1);
+ }
- std::vector arr;
- std::vector sample;
+ std::vector arr;
+ std::vector sample;
- int base = 1;
- if (distCalcMethod == DistCalcMethod::Cosine) {
- base = Utils::GetBase();
- }
- long long writepos = 0;
- int sampleSize = 0;
- long long totalread = 0;
- std::streamoff startpos = id * blocksize;
+ int base = 1;
+ if (distCalcMethod == DistCalcMethod::Cosine) {
+ base = Utils::GetBase();
+ }
+ std::uint64_t writepos = 0;
+ int sampleSize = 0;
+ std::uint64_t totalread = 0;
+ std::streamoff startpos = id * blocksize;
#ifndef _MSC_VER
- int enter_size = 1;
+ int enter_size = 1;
#else
- int enter_size = 1;
+ int enter_size = 1;
#endif
- std::string currentLine;
- size_t index;
- inputStream.seekg(startpos, std::ifstream::beg);
- if (id != 0) {
- std::getline(inputStream, currentLine);
- totalread += currentLine.length() + enter_size;
- }
- std::cout << "Begin thread " << id << " begin at:" << (startpos + totalread) << std::endl;
- while (!inputStream.eof() && totalread <= blocksize) {
- std::getline(inputStream, currentLine);
- if (currentLine.length() <= enter_size || (index = Utils::ProcessLine(currentLine, arr, D, base, distCalcMethod)) < 0) {
+ std::string currentLine;
+ size_t index;
+ inputStream.seekg(startpos, std::ifstream::beg);
+ if (id != 0) {
+ std::getline(inputStream, currentLine);
totalread += currentLine.length() + enter_size;
- continue;
}
- sample.resize(D);
- for (int j = 0; j < D; j++) sample[j] = (T)arr[j];
+ std::cout << "Begin thread " << id << " begin at:" << (startpos + totalread) << std::endl;
+ while (!inputStream.eof() && totalread <= blocksize) {
+ std::getline(inputStream, currentLine);
+ if (currentLine.length() <= enter_size || (index = Utils::ProcessLine(currentLine, arr, D, base, distCalcMethod)) < 0) {
+ totalread += currentLine.length() + enter_size;
+ continue;
+ }
+ sample.resize(D);
+ for (int j = 0; j < D; j++) sample[j] = (T)arr[j];
- outputStream.write((char *)(sample.data()), sizeof(T)*D);
- metaStream_index.write((char *)&writepos, sizeof(long long));
- metaStream_out.write(currentLine.c_str(), index);
+ outputStream.write((char *)(sample.data()), sizeof(T)*D);
+ metaStream_index.write((char *)&writepos, sizeof(std::uint64_t));
+ metaStream_out.write(currentLine.c_str(), index);
- writepos += index;
- sampleSize += 1;
- totalread += currentLine.length() + enter_size;
- }
- metaStream_index.write((char *)&writepos, sizeof(long long));
- metaStream_index.write((char *)&sampleSize, sizeof(int));
- inputStream.close();
- outputStream.close();
- metaStream_out.close();
- metaStream_index.close();
-
- numSamples.fetch_add(sampleSize);
-
- std::cout << "Finish Thread[" << id << ", " << sampleSize << "] at:" << (startpos + totalread) << std::endl;
- }
-
- void MergeData(int threadbase, std::string outfile, std::string outmetafile, std::string outmetaindexfile,
- std::atomic_int& numSamples, int D) {
- std::ifstream inputStream;
- std::ofstream outputStream;
- char * buf = new char[bufsize];
- long long * offsets;
- int partSamples;
- int metaSamples = 0;
- long long lastoff = 0;
-
- outputStream.open(outfile, std::ofstream::binary);
- outputStream.write((char *)&numSamples, sizeof(int));
- outputStream.write((char *)&D, sizeof(int));
- for (int i = 0; i < threadbase; i++) {
- std::string file = outfile + std::to_string(i);
- inputStream.open(file, std::ifstream::binary);
- while (!inputStream.eof()) {
- inputStream.read(buf, bufsize);
- outputStream.write(buf, inputStream.gcount());
+ writepos += index;
+ sampleSize += 1;
+ totalread += currentLine.length() + enter_size;
}
+ metaStream_index.write((char *)&writepos, sizeof(std::uint64_t));
+ metaStream_index.write((char *)&sampleSize, sizeof(int));
inputStream.close();
- remove(file.c_str());
+ outputStream.close();
+ metaStream_out.close();
+ metaStream_index.close();
+
+ numSamples.fetch_add(sampleSize);
+
+ std::cout << "Finish Thread[" << id << ", " << sampleSize << "] at:" << (startpos + totalread) << std::endl;
}
- outputStream.close();
-
- outputStream.open(outmetafile, std::ofstream::binary);
- for (int i = 0; i < threadbase; i++) {
- std::string file = outmetafile + std::to_string(i);
- inputStream.open(file, std::ifstream::binary);
- while (!inputStream.eof()) {
- inputStream.read(buf, bufsize);
- outputStream.write(buf, inputStream.gcount());
+
+ static void MergeData(int threadbase, std::string outfile, std::string outmetafile, std::string outmetaindexfile,
+ std::atomic_int& numSamples, int D) {
+ std::ifstream inputStream;
+ std::ofstream outputStream;
+ char * buf = new char[bufsize];
+ std::uint64_t * offsets;
+ int partSamples;
+ int metaSamples = 0;
+ std::uint64_t lastoff = 0;
+
+ outputStream.open(outfile, std::ofstream::binary);
+ outputStream.write((char *)&numSamples, sizeof(int));
+ outputStream.write((char *)&D, sizeof(int));
+ for (int i = 0; i < threadbase; i++) {
+ std::string file = outfile + std::to_string(i);
+ inputStream.open(file, std::ifstream::binary);
+ while (!inputStream.eof()) {
+ inputStream.read(buf, bufsize);
+ outputStream.write(buf, inputStream.gcount());
+ }
+ inputStream.close();
+ remove(file.c_str());
}
- inputStream.close();
- remove(file.c_str());
- }
- outputStream.close();
- delete[] buf;
+ outputStream.close();
+
+ outputStream.open(outmetafile, std::ofstream::binary);
+ for (int i = 0; i < threadbase; i++) {
+ std::string file = outmetafile + std::to_string(i);
+ inputStream.open(file, std::ifstream::binary);
+ while (!inputStream.eof()) {
+ inputStream.read(buf, bufsize);
+ outputStream.write(buf, inputStream.gcount());
+ }
+ inputStream.close();
+ remove(file.c_str());
+ }
+ outputStream.close();
+ delete[] buf;
- outputStream.open(outmetaindexfile, std::ofstream::binary);
- outputStream.write((char *)&numSamples, sizeof(int));
- for (int i = 0; i < threadbase; i++) {
- std::string file = outmetaindexfile + std::to_string(i);
- inputStream.open(file, std::ifstream::binary);
+ outputStream.open(outmetaindexfile, std::ofstream::binary);
+ outputStream.write((char *)&numSamples, sizeof(int));
+ for (int i = 0; i < threadbase; i++) {
+ std::string file = outmetaindexfile + std::to_string(i);
+ inputStream.open(file, std::ifstream::binary);
- inputStream.seekg(-((long long)sizeof(int)), inputStream.end);
- inputStream.read((char *)&partSamples, sizeof(int));
- offsets = new long long[partSamples + 1];
+ inputStream.seekg(-((long long)sizeof(int)), inputStream.end);
+ inputStream.read((char *)&partSamples, sizeof(int));
+ offsets = new std::uint64_t[partSamples + 1];
- inputStream.seekg(0, inputStream.beg);
- inputStream.read((char *)offsets, sizeof(long long)*(partSamples + 1));
- inputStream.close();
- remove(file.c_str());
+ inputStream.seekg(0, inputStream.beg);
+ inputStream.read((char *)offsets, sizeof(std::uint64_t)*(partSamples + 1));
+ inputStream.close();
+ remove(file.c_str());
- for (int j = 0; j < partSamples + 1; j++)
- offsets[j] += lastoff;
- outputStream.write((char *)offsets, sizeof(long long)*partSamples);
+ for (int j = 0; j < partSamples + 1; j++)
+ offsets[j] += lastoff;
+ outputStream.write((char *)offsets, sizeof(std::uint64_t)*partSamples);
+
+ lastoff = offsets[partSamples];
+ metaSamples += partSamples;
+ delete[] offsets;
+ }
+ outputStream.write((char *)&lastoff, sizeof(std::uint64_t));
+ outputStream.close();
- lastoff = offsets[partSamples];
- metaSamples += partSamples;
- delete[] offsets;
+ std::cout << "numSamples:" << numSamples << " metaSamples:" << metaSamples << " D:" << D << std::endl;
}
- outputStream.write((char *)&lastoff, sizeof(long long));
- outputStream.close();
- std::cout << "numSamples:" << numSamples << " metaSamples:" << metaSamples << " D:" << D << std::endl;
- }
+ static bool MergeIndex(const std::string& p_vectorfile1, const std::string& p_metafile1, const std::string& p_metaindexfile1,
+ const std::string& p_vectorfile2, const std::string& p_metafile2, const std::string& p_metaindexfile2) {
+ std::ifstream inputStream1, inputStream2;
+ std::ofstream outputStream;
+ char * buf = new char[bufsize];
+ int R1, R2, C1, C2;
+
+#define MergeVector(inputStream, vectorFile, R, C) \
+ inputStream.open(vectorFile, std::ifstream::binary); \
+ if (!inputStream.is_open()) { \
+ std::cout << "Cannot open vector file: " << vectorFile <<"!" << std::endl; \
+ return false; \
+ } \
+ inputStream.read((char *)&(R), sizeof(int)); \
+ inputStream.read((char *)&(C), sizeof(int)); \
+
+ MergeVector(inputStream1, p_vectorfile1, R1, C1)
+ MergeVector(inputStream2, p_vectorfile2, R2, C2)
+#undef MergeVector
+ if (C1 != C2) {
+ inputStream1.close(); inputStream2.close();
+ std::cout << "Vector dimensions are not the same!" << std::endl;
+ return false;
+ }
+ R1 += R2;
+ outputStream.open(p_vectorfile1 + "_tmp", std::ofstream::binary);
+ outputStream.write((char *)&R1, sizeof(int));
+ outputStream.write((char *)&C1, sizeof(int));
+ while (!inputStream1.eof()) {
+ inputStream1.read(buf, bufsize);
+ outputStream.write(buf, inputStream1.gcount());
+ }
+ while (!inputStream2.eof()) {
+ inputStream2.read(buf, bufsize);
+ outputStream.write(buf, inputStream2.gcount());
+ }
+ inputStream1.close(); inputStream2.close();
+ outputStream.close();
+
+ if (p_metafile1 != "" && p_metafile2 != "") {
+ outputStream.open(p_metafile1 + "_tmp", std::ofstream::binary);
+#define MergeMeta(inputStream, metaFile) \
+ inputStream.open(metaFile, std::ifstream::binary); \
+ if (!inputStream.is_open()) { \
+ std::cout << "Cannot open meta file: " << metaFile << "!" << std::endl; \
+ return false; \
+ } \
+ while (!inputStream.eof()) { \
+ inputStream.read(buf, bufsize); \
+ outputStream.write(buf, inputStream.gcount()); \
+ } \
+ inputStream.close(); \
+
+ MergeMeta(inputStream1, p_metafile1)
+ MergeMeta(inputStream2, p_metafile2)
+#undef MergeMeta
+ outputStream.close();
+ delete[] buf;
+
+
+ std::uint64_t * offsets;
+ int partSamples;
+ std::uint64_t lastoff = 0;
+ outputStream.open(p_metaindexfile1 + "_tmp", std::ofstream::binary);
+ outputStream.write((char *)&R1, sizeof(int));
+#define MergeMetaIndex(inputStream, metaIndexFile) \
+ inputStream.open(metaIndexFile, std::ifstream::binary); \
+ if (!inputStream.is_open()) { \
+ std::cout << "Cannot open meta index file: " << metaIndexFile << "!" << std::endl; \
+ return false; \
+ } \
+ inputStream.read((char *)&partSamples, sizeof(int)); \
+ offsets = new std::uint64_t[partSamples + 1]; \
+ inputStream.read((char *)offsets, sizeof(std::uint64_t)*(partSamples + 1)); \
+ inputStream.close(); \
+ for (int j = 0; j < partSamples + 1; j++) offsets[j] += lastoff; \
+ outputStream.write((char *)offsets, sizeof(std::uint64_t)*partSamples); \
+ lastoff = offsets[partSamples]; \
+ delete[] offsets; \
+
+ MergeMetaIndex(inputStream1, p_metaindexfile1)
+ MergeMetaIndex(inputStream2, p_metaindexfile2)
+#undef MergeMetaIndex
+ outputStream.write((char *)&lastoff, sizeof(std::uint64_t));
+ outputStream.close();
+
+ rename((p_metafile1 + "_tmp").c_str(), p_metafile1.c_str());
+ rename((p_metaindexfile1 + "_tmp").c_str(), p_metaindexfile1.c_str());
+ }
+ rename((p_vectorfile1 + "_tmp").c_str(), p_vectorfile1.c_str());
+
+ std::cout << "Merged -> numSamples:" << R1 << " D:" << C1 << std::endl;
+ return true;
+ }
- template
- void ParseData(std::string filenames, std::string outfile, std::string outmetafile, std::string outmetaindexfile,
- int threadnum, DistCalcMethod distCalcMethod) {
- omp_set_num_threads(threadnum);
+ template
+ static void ParseData(std::string filenames, std::string outfile, std::string outmetafile, std::string outmetaindexfile,
+ int threadnum, DistCalcMethod distCalcMethod) {
+ omp_set_num_threads(threadnum);
- std::atomic_int numSamples = { 0 };
- int D = -1;
+ std::atomic_int numSamples = { 0 };
+ int D = -1;
- int threadbase = 0;
- std::vector inputFileNames = Helper::StrUtils::SplitString(filenames, ",");
- for (std::string inputFileName : inputFileNames)
- {
+ int threadbase = 0;
+ std::vector inputFileNames = Helper::StrUtils::SplitString(filenames, ",");
+ for (std::string inputFileName : inputFileNames)
+ {
#ifndef _MSC_VER
- struct stat stat_buf;
- stat(inputFileName.c_str(), &stat_buf);
+ struct stat stat_buf;
+ stat(inputFileName.c_str(), &stat_buf);
#else
- struct _stat64 stat_buf;
- int res = _stat64(inputFileName.c_str(), &stat_buf);
+ struct _stat64 stat_buf;
+ int res = _stat64(inputFileName.c_str(), &stat_buf);
#endif
- long long blocksize = (stat_buf.st_size + threadnum - 1) / threadnum;
+ std::uint64_t blocksize = (stat_buf.st_size + threadnum - 1) / threadnum;
#pragma omp parallel for
- for (int i = 0; i < threadnum; i++) {
- ProcessTSVData(i, threadbase, blocksize, inputFileName, outfile, outmetafile, outmetaindexfile, numSamples, D, distCalcMethod);
+ for (int i = 0; i < threadnum; i++) {
+ ProcessTSVData(i, threadbase, blocksize, inputFileName, outfile, outmetafile, outmetaindexfile, numSamples, D, distCalcMethod);
+ }
+ threadbase += threadnum;
}
- threadbase += threadnum;
+ MergeData(threadbase, outfile, outmetafile, outmetaindexfile, numSamples, D);
}
- MergeData(threadbase, outfile, outmetafile, outmetaindexfile, numSamples, D);
- }
+ };
}
}
diff --git a/AnnService/inc/Core/Common/Dataset.h b/AnnService/inc/Core/Common/Dataset.h
index 516edda3..4753b088 100644
--- a/AnnService/inc/Core/Common/Dataset.h
+++ b/AnnService/inc/Core/Common/Dataset.h
@@ -12,108 +12,6 @@ namespace SPTAG
{
namespace COMMON
{
- template
- class LRUCache
- {
- private:
- struct Item
- {
- int idx;
- T* data;
- Item* next;
- Item(int idx_, T* data_) : idx(idx_), data(data_), next(nullptr) {}
- };
-
- int rows;
- int cols;
- T* data;
- std::ifstream fp;
- std::unordered_map cache;
- Item* head;
-
- public:
- LRUCache(const char * filename_, int caches_ = 10000000)
- {
- fp.open(filename_, std::ifstream::binary);
- fp.read((char *)&rows, sizeof(int));
- fp.read((char *)&cols, sizeof(int));
- if (caches_ > rows) caches_ = rows;
-
- data = (T*)aligned_malloc(sizeof(T) * caches_ * cols, ALIGN);
-
- int i = 0, batch = 10000;
- while (i + batch < caches_)
- {
- fp.read((char *)(data + (size_t)i*cols), sizeof(T)*cols*batch);
- i += batch;
- }
- fp.read((char *)(data + (size_t)i*cols), sizeof(T)*cols*(caches_ - i));
-
- Item *p = head = new Item(-1, nullptr);
- for (i = 0; i < caches_; i++)
- {
- p->next = cache[(i + 1) % caches_] = new Item(i, data + (size_t)i*cols);
- p = p->next;
- }
- p->next = head->next;
- delete head;
- head = p;
-
- std::cout << "Use LRUCache (" << caches_ << ")" << std::endl;
- }
- ~LRUCache()
- {
- fp.close();
-
- Item *p = head->next, *q;
- while (p != head)
- {
- q = p;
- p = p->next;
- delete q;
- }
- delete head;
-
- aligned_free(data);
- }
- int R() { return rows; }
- int C() { return cols; }
- T* get(int index)
- {
- auto iter = cache.find(index);
- if (iter == cache.end())
- {
- Item *p = head->next;
- cache[index] = cache[p->idx];
- cache.erase(p->idx);
- p->idx = index;
- fp.seekg(sizeof(int) + sizeof(int) + index * sizeof(T) * cols, std::ios_base::beg);
- fp.read((char *)p->data, sizeof(T)*cols);
- head = p;
- return p->data;
- }
- else
- {
- Item *p = iter->second, *q = p->next;
- if (q == head || q == head->next)
- {
- head = q;
- return q->data;
- }
- p->next = q->next;
- cache[q->next->idx] = p;
-
- q->next = head->next;
- head->next = q;
- cache[index] = head;
- cache[q->next->idx] = q;
-
- head = q;
- return q->data;
- }
- }
- };
-
// structure to save Data and Graph
template
class Dataset
@@ -123,43 +21,44 @@ namespace SPTAG
int cols;
bool ownData = false;
T* data = nullptr;
- LRUCache* cache = nullptr;
std::vector* dataIncremental = nullptr;
public:
Dataset() {}
- Dataset(int rows_, int cols_, T* data_ = nullptr, const char * filename_ = nullptr, int cachesize_ = 0) { Initialize(rows_, cols_, data_, filename_, cachesize_); }
+ Dataset(int rows_, int cols_, T* data_ = nullptr)
+ {
+ Initialize(rows_, cols_, data_);
+ }
~Dataset()
{
- if (cache != nullptr) delete cache;
if (ownData) aligned_free(data);
if (dataIncremental) {
dataIncremental->clear();
delete dataIncremental;
}
}
- void Initialize(int rows_, int cols_, T* data_ = nullptr, const char * filename_ = nullptr, int cachesize_ = 0)
+ void Initialize(int rows_, int cols_, T* data_ = nullptr)
{
- if (filename_ != nullptr)
+ rows = rows_;
+ cols = cols_;
+ data = data_;
+ if (data == nullptr)
{
- cache = new LRUCache(filename_, cachesize_);
- rows = cache->R();
- cols = cache->C();
+ ownData = true;
+ data = (T*)aligned_malloc(sizeof(T) * rows * cols, ALIGN);
}
- else
+ dataIncremental = new std::vector();
+ }
+ void SetR(int R_)
+ {
+ if (R_ >= rows)
+ dataIncremental->resize((R_ - rows) * cols);
+ else
{
- rows = rows_;
- cols = cols_;
- data = data_;
- if (data == nullptr)
- {
- ownData = true;
- data = (T*)aligned_malloc(sizeof(T) * rows * cols, ALIGN);
- }
+ rows = R_;
+ dataIncremental->clear();
}
- dataIncremental = new std::vector();
}
- void SetR(int R_) { rows = R_; }
int R() const { return (int)(rows + dataIncremental->size() / cols); }
int C() const { return cols; }
T* operator[](int index)
@@ -167,39 +66,39 @@ namespace SPTAG
if (index >= rows) {
return dataIncremental->data() + (size_t)(index - rows)*cols;
}
- if (cache != nullptr)
- {
- return cache->get(index);
- }
- else
- {
- return data + (size_t)index*cols;
- }
+ return data + (size_t)index*cols;
}
const T* operator[](int index) const
{
if (index >= rows) {
return dataIncremental->data() + (size_t)(index - rows)*cols;
}
- if (cache != nullptr)
- {
- return cache->get(index);
- }
- else
- {
- return data + (size_t)index*cols;
- }
+ return data + (size_t)index*cols;
}
- T* GetData() {
+
+ T* GetData()
+ {
return data;
}
+ void reset()
+ {
+ if (ownData) {
+ aligned_free(data);
+ ownData = false;
+ }
+ if (dataIncremental) {
+ dataIncremental->clear();
+ delete dataIncremental;
+ }
+ }
+
void AddBatch(const T* pData, int num)
{
dataIncremental->insert(dataIncremental->end(), pData, pData + num*cols);
}
- void AddReserved(int num)
+ void AddBatch(int num)
{
dataIncremental->insert(dataIncremental->end(), (size_t)num*cols, T(-1));
}
diff --git a/AnnService/inc/Core/Common/FineGrainedLock.h b/AnnService/inc/Core/Common/FineGrainedLock.h
new file mode 100644
index 00000000..e1d5dc39
--- /dev/null
+++ b/AnnService/inc/Core/Common/FineGrainedLock.h
@@ -0,0 +1,48 @@
+#ifndef _SPTAG_COMMON_FINEGRAINEDLOCK_H_
+#define _SPTAG_COMMON_FINEGRAINEDLOCK_H_
+
+#include
+#include
+#include
+
+namespace SPTAG
+{
+ namespace COMMON
+ {
+ class FineGrainedLock {
+ public:
+ FineGrainedLock() {}
+ ~FineGrainedLock() {
+ for (int i = 0; i < locks.size(); i++)
+ locks[i].reset();
+ locks.clear();
+ }
+
+ void resize(int n) {
+ int current = (int)locks.size();
+ if (current <= n) {
+ locks.resize(n);
+ for (int i = current; i < n; i++)
+ locks[i].reset(new std::mutex);
+ }
+ else {
+ for (int i = n; i < current; i++)
+ locks[i].reset();
+ locks.resize(n);
+ }
+ }
+
+ std::mutex& operator[](int idx) {
+ return *locks[idx];
+ }
+
+ const std::mutex& operator[](int idx) const {
+ return *locks[idx];
+ }
+ private:
+ std::vector> locks;
+ };
+ }
+}
+
+#endif // _SPTAG_COMMON_FINEGRAINEDLOCK_H_
\ No newline at end of file
diff --git a/AnnService/inc/Core/Common/Heap.h b/AnnService/inc/Core/Common/Heap.h
index f35ddc46..7d4dcc56 100644
--- a/AnnService/inc/Core/Common/Heap.h
+++ b/AnnService/inc/Core/Common/Heap.h
@@ -25,7 +25,7 @@ namespace SPTAG
inline int size() { return count; }
inline bool empty() { return count == 0; }
inline void clear() { count = 0; }
- inline T& Top() { return heap[1]; }
+ inline T& Top() { if (count == 0) return heap[0]; else return heap[1]; }
// Insert a new element in the heap.
void insert(T value)
diff --git a/AnnService/inc/Core/Common/QueryResultSet.h b/AnnService/inc/Core/Common/QueryResultSet.h
index 9808414b..33dcf5c7 100644
--- a/AnnService/inc/Core/Common/QueryResultSet.h
+++ b/AnnService/inc/Core/Common/QueryResultSet.h
@@ -10,13 +10,13 @@ namespace COMMON
inline bool operator < (const BasicResult& lhs, const BasicResult& rhs)
{
- return ((lhs.Dist < rhs.Dist) || ((lhs.Dist == rhs.Dist) && (lhs.Key < rhs.Key)));
+ return ((lhs.Dist < rhs.Dist) || ((lhs.Dist == rhs.Dist) && (lhs.VID < rhs.VID)));
}
inline bool Compare(BasicResult& lhs, BasicResult& rhs)
{
- return ((lhs.Dist < rhs.Dist) || ((lhs.Dist == rhs.Dist) && (lhs.Key < rhs.Key)));
+ return ((lhs.Dist < rhs.Dist) || ((lhs.Dist == rhs.Dist) && (lhs.VID < rhs.VID)));
}
@@ -50,9 +50,9 @@ class QueryResultSet : public QueryResult
bool AddPoint(const int index, float dist)
{
- if (dist < m_results[0].Dist || (dist == m_results[0].Dist && index < m_results[0].Key))
+ if (dist < m_results[0].Dist || (dist == m_results[0].Dist && index < m_results[0].VID))
{
- m_results[0].Key = index;
+ m_results[0].VID = index;
m_results[0].Dist = dist;
Heapify(m_resultNum);
return true;
diff --git a/AnnService/inc/Core/Common/WorkSpace.h b/AnnService/inc/Core/Common/WorkSpace.h
index 57c1b609..f17ff0bb 100644
--- a/AnnService/inc/Core/Common/WorkSpace.h
+++ b/AnnService/inc/Core/Common/WorkSpace.h
@@ -14,7 +14,7 @@ namespace SPTAG
int node;
float distance;
- HeapCell(int _node = -1, float _distance = 0) : node(_node), distance(_distance) {}
+ HeapCell(int _node = -1, float _distance = MaxDist) : node(_node), distance(_distance) {}
inline bool operator < (const HeapCell& rhs)
{
@@ -203,7 +203,7 @@ namespace SPTAG
return nodeCheckStatus.CheckAndSet(idx);
}
- CountVector nodeCheckStatus;
+ OptHashPosVector nodeCheckStatus;
//OptHashPosVector nodeCheckStatus;
// counter for dynamic pivoting
diff --git a/AnnService/inc/Core/KDT/Index.h b/AnnService/inc/Core/KDT/Index.h
index b84f5087..d21e76b0 100644
--- a/AnnService/inc/Core/KDT/Index.h
+++ b/AnnService/inc/Core/KDT/Index.h
@@ -12,11 +12,12 @@
#include "../Common/Dataset.h"
#include "../Common/WorkSpace.h"
#include "../Common/WorkSpacePool.h"
+#include "../Common/FineGrainedLock.h"
+#include "../Common/DataUtils.h"
#include
-#include
#include
-#include
+#include
namespace SPTAG
{
@@ -45,7 +46,6 @@ namespace SPTAG
int m_iDataSize;
int m_iDataDimension;
COMMON::Dataset m_pSamples;
- std::shared_ptr m_pMetadata;
// KDT structures.
int m_iKDTNumber;
@@ -82,7 +82,6 @@ namespace SPTAG
char* m_pGraphMemoryFile;
char* m_pDataPointsMemoryFile;
- int m_iNumberOfThreads;
DistCalcMethod m_iDistCalcMethod;
float(*m_fComputeDistance)(const T* pX, const T* pY, int length);
@@ -93,8 +92,11 @@ namespace SPTAG
int g_iNumberOfInitialDynamicPivots;
int g_iNumberOfOtherDynamicPivots;
+ int m_iNumberOfThreads;
+ std::mutex m_dataAllocLock;
+ COMMON::FineGrainedLock m_dataUpdateLock;
+ tbb::concurrent_unordered_set m_deletedID;
std::unique_ptr m_workSpacePool;
-
public:
Index() : m_iKDTNumber(1),
m_numTopDimensionKDTSplit(5),
@@ -130,89 +132,28 @@ namespace SPTAG
int GetFeatureDim() const { return m_pSamples.C(); }
int GetNumThreads() const { return m_iNumberOfThreads; }
int GetCurrMaxCheck() const { return m_iMaxCheck; }
+
DistCalcMethod GetDistCalcMethod() const { return m_iDistCalcMethod; }
IndexAlgoType GetIndexAlgoType() const { return IndexAlgoType::KDT; }
- VectorValueType AcceptableQueryValueType() const { return GetEnumValueType(); }
- void SetMetadata(const std::string& metadataFilePath, const std::string& metadataIndexPath) {
- m_pMetadata.reset(new FileMetadataSet(metadataFilePath, metadataIndexPath));
- }
- ByteArray GetMetadata(IndexType p_vectorID) const {
- if (nullptr != m_pMetadata)
- {
- return m_pMetadata->GetMetadata(p_vectorID);
- }
- return ByteArray::c_empty;
- }
-
- bool BuildIndex();
- bool BuildIndex(void* p_data, int p_vectorNum, int p_dimension);
- ErrorCode BuildIndex(std::shared_ptr p_vectorSet,
- std::shared_ptr p_metadataSet);
+ VectorValueType GetVectorValueType() const { return GetEnumValueType(); }
- bool LoadIndex();
- ErrorCode LoadIndex(const std::string& p_folderPath, const Helper::IniReader& p_configReader);
+ ErrorCode BuildIndex(const void* p_data, int p_vectorNum, int p_dimension);
+ ErrorCode LoadIndex(const std::string& p_folderPath);
+ ErrorCode LoadIndexFromMemory(const std::vector& p_indexBlobs);
- bool SaveIndex();
ErrorCode SaveIndex(const std::string& p_folderPath);
- void SearchIndex(COMMON::QueryResultSet &query, COMMON::WorkSpace &space) const;
- ErrorCode SearchIndex(QueryResult &query) const;
-
- void AddNodes(const T* pData, int num, COMMON::WorkSpace &space);
+ void SearchIndex(COMMON::QueryResultSet &p_query, COMMON::WorkSpace &p_space, const tbb::concurrent_unordered_set &p_deleted) const;
+ ErrorCode SearchIndex(QueryResult &p_query) const;
+
+ ErrorCode AddIndex(const void* p_vectors, int p_vectorNum, int p_dimension);
+ ErrorCode DeleteIndex(const void* p_vectors, int p_vectorNum);
+ ErrorCode RefineIndex(const std::string& p_folderPath);
+ ErrorCode MergeIndex(const char* p_indexFilePath1, const char* p_indexFilePath2);
ErrorCode SetParameter(const char* p_param, const char* p_value);
std::string GetParameter(const char* p_param) const;
- // This can be used for both building model files or searching with model files loaded.
- void SetParameters(std::string dataPointsFile,
- std::string KDTFile,
- std::string graphFile,
- int numKDT,
- int neighborhoodSize,
- int numTPTrees,
- int TPTLeafSize,
- int maxCheckForRefineGraph,
- int numThreads,
- DistCalcMethod distCalcMethod,
- int cacheSize = -1,
- int numPoints = -1)
- {
- m_sDataPointsFilename = dataPointsFile;
- m_sKDTFilename = KDTFile;
- m_sGraphFilename = graphFile;
- m_iKDTNumber = numKDT;
- m_iNeighborhoodSize = neighborhoodSize;
- m_iTPTNumber = numTPTrees;
- m_iTPTLeafSize = TPTLeafSize;
- m_iMaxCheckForRefineGraph = maxCheckForRefineGraph;
- m_iNumberOfThreads = numThreads;
- m_iDistCalcMethod = distCalcMethod;
- m_fComputeDistance = COMMON::DistanceCalcSelector(m_iDistCalcMethod);
-
- m_iCacheSize = cacheSize;
- m_iDebugLoad = numPoints;
- }
-
- // Only used for searching with memory mapped model files
- void SetParameters(char* pDataPointsMemFile,
- char* pKDTMemFile,
- char* pGraphMemFile,
- DistCalcMethod distCalcMethod,
- int maxCheck,
- int numKDT,
- int neighborhoodSize)
- {
- m_pDataPointsMemoryFile = pDataPointsMemFile;
- m_pKDTMemoryFile = pKDTMemFile;
- m_pGraphMemoryFile = pGraphMemFile;
- m_iMaxCheck = maxCheck;
- m_iKDTNumber = numKDT;
- m_iNeighborhoodSize = neighborhoodSize;
- m_iNumberOfThreads = 1;
- m_iDistCalcMethod = distCalcMethod;
- m_fComputeDistance = COMMON::DistanceCalcSelector(m_iDistCalcMethod);
- }
-
private:
// Functions for loading models from files
bool LoadDataPoints(std::string sDataPointsFileName);
@@ -227,8 +168,8 @@ namespace SPTAG
bool SaveDataPoints(std::string sDataPointsFileName);
// Functions for building kdtree
- void BuildKDT();
- bool SaveKDT(std::string sKDTFilename) const;
+ void BuildKDT(std::vector& indices, std::vector& newStart, std::vector& newRoot);
+ bool SaveKDT(std::string sKDTFilename, std::vector& newStart, std::vector& newRoot) const;
void DivideTree(KDTNode* pTree, std::vector& indices,int first, int last,
int index, int &iTreeSize);
void ChooseDivision(KDTNode& node, const std::vector& indices, int first, int last);
@@ -249,7 +190,7 @@ namespace SPTAG
// Functions for hybrid search
void KDTSearch(const int node, const bool isInit, const float distBound,
- COMMON::WorkSpace& space, COMMON::QueryResultSet &query) const;
+ COMMON::WorkSpace& space, COMMON::QueryResultSet &query, const tbb::concurrent_unordered_set &deleted) const;
};
} // namespace KDT
} // namespace SPTAG
diff --git a/AnnService/inc/Core/KDT/ParameterDefinitionList.h b/AnnService/inc/Core/KDT/ParameterDefinitionList.h
index 4caa1ca5..932a525f 100644
--- a/AnnService/inc/Core/KDT/ParameterDefinitionList.h
+++ b/AnnService/inc/Core/KDT/ParameterDefinitionList.h
@@ -1,4 +1,4 @@
-#ifdef DefineParameter
+#ifdef DefineKDTParameter
// DefineKDTParameter(VarName, VarType, DefaultValue, RepresentStr)
DefineKDTParameter(m_sKDTFilename, std::string, std::string("tree.bin"), "TreeFilePath")
diff --git a/AnnService/inc/Core/MetadataSet.h b/AnnService/inc/Core/MetadataSet.h
index c5449132..e9794893 100644
--- a/AnnService/inc/Core/MetadataSet.h
+++ b/AnnService/inc/Core/MetadataSet.h
@@ -22,7 +22,11 @@ class MetadataSet
virtual bool Available() const = 0;
- virtual ErrorCode SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile) const = 0;
+ virtual void AddBatch(MetadataSet& data) = 0;
+
+ virtual ErrorCode SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile) = 0;
+
+ static ErrorCode MetaCopy(const std::string& p_src, const std::string& p_dst);
};
@@ -39,18 +43,22 @@ class FileMetadataSet : public MetadataSet
bool Available() const;
- virtual ErrorCode SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile) const;
+ void AddBatch(MetadataSet& data);
+
+ ErrorCode SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile);
private:
std::ifstream* m_fp = nullptr;
- const long long *m_pOffsets = nullptr;
+ std::vector m_pOffsets;
- const int m_iCount = 0;
+ int m_count;
std::string m_metaFile;
std::string m_metaindexFile;
+
+ std::vector m_newdata;
};
@@ -67,38 +75,20 @@ class MemMetadataSet : public MetadataSet
bool Available() const;
- virtual ErrorCode SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile) const;
-
-private:
- const std::uint64_t *m_offsets;
+ void AddBatch(MetadataSet& data);
- ByteArray m_metadataHolder;
+ ErrorCode SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile);
- ByteArray m_offsetHolder;
+private:
+ std::vector m_offsets;
SizeType m_count;
-};
-
-class MetadataSetFileTransfer : public MetadataSet
-{
-public:
- MetadataSetFileTransfer(const std::string& p_metaFile, const std::string& p_metaindexFile);
-
- virtual ~MetadataSetFileTransfer();
-
- virtual ByteArray GetMetadata(IndexType p_vectorID) const;
-
- virtual SizeType Count() const;
-
- virtual bool Available() const;
-
- virtual ErrorCode SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile) const;
+ ByteArray m_metadataHolder;
-private:
- std::string m_metaFile;
+ ByteArray m_offsetHolder;
- std::string m_metaindexFile;
+ std::vector m_newdata;
};
diff --git a/AnnService/inc/Core/SearchQuery.h b/AnnService/inc/Core/SearchQuery.h
index 7fd6806b..8b8c5f7b 100644
--- a/AnnService/inc/Core/SearchQuery.h
+++ b/AnnService/inc/Core/SearchQuery.h
@@ -8,19 +8,15 @@
namespace SPTAG
{
-struct BasicResult
-{
- int Key;
- float Dist;
-
- BasicResult() : Key(-1), Dist(MaxDist)
+ struct BasicResult
{
- }
+ int VID;
+ float Dist;
- BasicResult(int p_key, float p_dist) : Key(p_key), Dist(p_dist)
- {
- }
-};
+ BasicResult() : VID(-1), Dist(MaxDist) {}
+
+ BasicResult(int p_vid, float p_dist) : VID(p_vid), Dist(p_dist) {}
+ };
// Space to save temporary answer, similar with TopKCache
@@ -46,6 +42,16 @@ class QueryResult
Init(p_target, p_resultNum, p_withMeta);
}
+
+ QueryResult(const void* p_target, int p_resultNum, std::vector& p_results)
+ : m_target(p_target),
+ m_resultNum(p_resultNum),
+ m_withMeta(false)
+ {
+ p_results.resize(p_resultNum);
+ m_results.reset(p_results.data());
+ }
+
QueryResult(const QueryResult& p_other)
: m_target(p_other.m_target),
@@ -130,11 +136,11 @@ class QueryResult
}
- inline void SetResult(int p_index, int p_key, float p_dist)
+ inline void SetResult(int p_index, int p_VID, float p_dist)
{
if (p_index < m_resultNum)
{
- m_results[p_index].Key = p_key;
+ m_results[p_index].VID = p_VID;
m_results[p_index].Dist = p_dist;
}
}
@@ -176,7 +182,7 @@ class QueryResult
{
for (int i = 0; i < m_resultNum; i++)
{
- m_results[i].Key = -1;
+ m_results[i].VID = -1;
m_results[i].Dist = MaxDist;
}
diff --git a/AnnService/inc/Core/VectorIndex.h b/AnnService/inc/Core/VectorIndex.h
index ababafd9..6f648d36 100644
--- a/AnnService/inc/Core/VectorIndex.h
+++ b/AnnService/inc/Core/VectorIndex.h
@@ -9,12 +9,6 @@
namespace SPTAG
{
-namespace Helper
-{
-class IniReader;
-}
-
-
class VectorIndex
{
public:
@@ -24,25 +18,48 @@ class VectorIndex
virtual ErrorCode SaveIndex(const std::string& p_folderPath) = 0;
- virtual ErrorCode LoadIndex(const std::string& p_folderPath, const Helper::IniReader& p_configReader) = 0;
+ virtual ErrorCode LoadIndex(const std::string& p_folderPath) = 0;
- virtual ErrorCode BuildIndex(std::shared_ptr p_vectorSet, std::shared_ptr p_metadataSet) = 0;
+ virtual ErrorCode LoadIndexFromMemory(const std::vector& p_indexBlobs) = 0;
+
+ virtual ErrorCode BuildIndex(const void* p_data, int p_vectorNum, int p_dimension) = 0;
virtual ErrorCode SearchIndex(QueryResult& p_results) const = 0;
- virtual std::string GetParameter(const char* p_param) const = 0;
- virtual ErrorCode SetParameter(const char* p_param, const char* p_value) = 0;
+ virtual ErrorCode AddIndex(const void* p_vectors, int p_vectorNum, int p_dimension) = 0;
- virtual std::string GetParameter(const std::string& p_param) const;
- virtual ErrorCode SetParameter(const std::string& p_param, const std::string& p_value);
+ virtual ErrorCode DeleteIndex(const void* p_vectors, int p_vectorNum) = 0;
+
+ virtual ErrorCode RefineIndex(const std::string& p_folderPath) = 0;
- virtual ByteArray GetMetadata(IndexType p_vectorID) const = 0;
+ virtual ErrorCode MergeIndex(const char* p_indexFilePath1, const char* p_indexFilePath2) = 0;
- virtual VectorValueType AcceptableQueryValueType() const = 0;
+ //virtual ErrorCode AddIndexWithID(const void* p_vector, const int& p_id) = 0;
+
+ //virtual ErrorCode DeleteIndexWithID(const void* p_vector, const int& p_id) = 0;
virtual int GetFeatureDim() const = 0;
+ virtual int GetNumSamples() const = 0;
+ virtual DistCalcMethod GetDistCalcMethod() const = 0;
virtual IndexAlgoType GetIndexAlgoType() const = 0;
+ virtual VectorValueType GetVectorValueType() const = 0;
+ virtual int GetNumThreads() const = 0;
+
+ virtual std::string GetParameter(const char* p_param) const = 0;
+ virtual ErrorCode SetParameter(const char* p_param, const char* p_value) = 0;
+
+ virtual ErrorCode BuildIndex(std::shared_ptr p_vectorSet, std::shared_ptr p_metadataSet);
+
+ virtual ErrorCode SearchIndex(const void* p_vector, int p_neighborCount, std::vector& p_results) const;
+
+ virtual ErrorCode AddIndex(std::shared_ptr p_vectorSet, std::shared_ptr p_metadataSet);
+
+ virtual std::string GetParameter(const std::string& p_param) const;
+ virtual ErrorCode SetParameter(const std::string& p_param, const std::string& p_value);
+
+ virtual ByteArray GetMetadata(IndexType p_vectorID) const;
+ virtual void SetMetadata(const std::string& p_metadataFilePath, const std::string& p_metadataIndexPath);
void SetIndexName(const std::string& p_indexName);
@@ -52,8 +69,9 @@ class VectorIndex
static ErrorCode LoadIndex(const std::string& p_loaderFilePath, std::shared_ptr& p_vectorIndex);
-private:
+protected:
std::string m_indexName;
+ std::shared_ptr m_pMetadata;
};
diff --git a/AnnService/inc/Core/VectorSet.h b/AnnService/inc/Core/VectorSet.h
index a1c35e6b..09a6620a 100644
--- a/AnnService/inc/Core/VectorSet.h
+++ b/AnnService/inc/Core/VectorSet.h
@@ -13,15 +13,19 @@ class VectorSet
virtual ~VectorSet();
+ virtual VectorValueType GetValueType() const = 0;
+
virtual void* GetVector(IndexType p_vectorID) const = 0;
virtual void* GetData() const = 0;
- virtual VectorValueType ValueType() const = 0;
-
virtual SizeType Dimension() const = 0;
virtual SizeType Count() const = 0;
+
+ virtual bool Available() const = 0;
+
+ virtual ErrorCode Save(const std::string& p_vectorFile) const = 0;
};
@@ -35,16 +39,20 @@ class BasicVectorSet : public VectorSet
virtual ~BasicVectorSet();
+ virtual VectorValueType GetValueType() const;
+
virtual void* GetVector(IndexType p_vectorID) const;
virtual void* GetData() const;
- virtual VectorValueType ValueType() const;
-
virtual SizeType Dimension() const;
virtual SizeType Count() const;
+ virtual bool Available() const;
+
+ virtual ErrorCode Save(const std::string& p_vectorFile) const;
+
private:
ByteArray m_data;
diff --git a/AnnService/packages.config b/AnnService/packages.config
index 2dbed9b5..424245f6 100644
--- a/AnnService/packages.config
+++ b/AnnService/packages.config
@@ -7,4 +7,6 @@
+
+
\ No newline at end of file
diff --git a/AnnService/src/Client/main.cpp b/AnnService/src/Client/main.cpp
index 2c35ab67..2b3a2f76 100644
--- a/AnnService/src/Client/main.cpp
+++ b/AnnService/src/Client/main.cpp
@@ -56,7 +56,7 @@ int main(int argc, char** argv)
for (const auto& res : indexRes.m_results)
{
fprintf(stdout, "------------------\n");
- fprintf(stdout, "DocIndex: %d Distance: %f\n", res.Key, res.Dist);
+ fprintf(stdout, "DocIndex: %d Distance: %f\n", res.VID, res.Dist);
if (indexRes.m_results.WithMeta())
{
const auto& metadata = indexRes.m_results.GetMetadata(idx);
diff --git a/AnnService/src/Core/BKT/BKTIndex.cpp b/AnnService/src/Core/BKT/BKTIndex.cpp
index 1e2be638..8f9a1862 100644
--- a/AnnService/src/Core/BKT/BKTIndex.cpp
+++ b/AnnService/src/Core/BKT/BKTIndex.cpp
@@ -16,53 +16,96 @@ namespace SPTAG
{
#pragma region Load data points, kd-tree, neighborhood graph
template
- bool Index::LoadIndex()
+ ErrorCode Index::LoadIndexFromMemory(const std::vector& p_indexBlobs)
{
- bool loadedDataPoints = m_pDataPointsMemoryFile != NULL ? LoadDataPoints(m_pDataPointsMemoryFile) : LoadDataPoints(m_sDataPointsFilename);
- if (!loadedDataPoints) return false;
+ if (!LoadDataPoints((char*)p_indexBlobs[0])) return ErrorCode::FailedParseValue;
+ if (!LoadBKT((char*)p_indexBlobs[1])) return ErrorCode::FailedParseValue;
+ if (!LoadGraph((char*)p_indexBlobs[2])) return ErrorCode::FailedParseValue;
+ return ErrorCode::Success;
+ }
- m_iDataSize = m_pSamples.R();
- m_iDataDimension = m_pSamples.C();
+ template
+ ErrorCode Index::LoadIndex(const std::string& p_folderPath)
+ {
+ std::string folderPath(p_folderPath);
+ if (!folderPath.empty() && *(folderPath.rbegin()) != FolderSep)
+ {
+ folderPath += FolderSep;
+ }
+
+ Helper::IniReader p_configReader;
+ if (ErrorCode::Success != p_configReader.LoadIniFile(folderPath + "/indexloader.ini"))
+ {
+ return ErrorCode::FailedOpenFile;
+ }
- bool loadedBKT = m_pBKTMemoryFile != NULL ? LoadBKT(m_pBKTMemoryFile) : LoadBKT(m_sBKTFilename);
- if (!loadedBKT) return false;
+ std::string metadataSection("MetaData");
+ if (p_configReader.DoesSectionExist(metadataSection))
+ {
+ std::string metadataFilePath = p_configReader.GetParameter(metadataSection,
+ "MetaDataFilePath",
+ std::string());
+ std::string metadataIndexFilePath = p_configReader.GetParameter(metadataSection,
+ "MetaDataIndexPath",
+ std::string());
- bool loadedGraph = m_pGraphMemoryFile != NULL ? LoadGraph(m_pGraphMemoryFile) : LoadGraph(m_sGraphFilename);
- if (!loadedGraph) return false;
+ m_pMetadata.reset(new FileMetadataSet(folderPath + metadataFilePath, folderPath + metadataIndexFilePath));
- if (m_iRefineIter > 0) {
- for (int i = 0; i < m_iRefineIter; i++) RefineRNG();
- SaveRNG(m_sGraphFilename);
+ if (!m_pMetadata->Available())
+ {
+ std::cerr << "Error: Failed to load metadata." << std::endl;
+ return ErrorCode::Fail;
+ }
}
- return true;
+
+#define DefineBKTParameter(VarName, VarType, DefaultValue, RepresentStr) \
+ SetParameter(RepresentStr, \
+ p_configReader.GetParameter("Index", \
+ RepresentStr, \
+ std::string(#DefaultValue)).c_str()); \
+
+#include "inc/Core/BKT/ParameterDefinitionList.h"
+#undef DefineBKTParameter
+
+ if (DistCalcMethod::Undefined == m_iDistCalcMethod)
+ {
+ return ErrorCode::Fail;
+ }
+
+ if (!LoadDataPoints(folderPath + m_sDataPointsFilename)) return ErrorCode::Fail;
+ if (!LoadBKT(folderPath + m_sBKTFilename)) return ErrorCode::Fail;
+ if (!LoadGraph(folderPath + m_sGraphFilename)) return ErrorCode::Fail;
+
+ m_iDataSize = m_pSamples.R();
+ m_iDataDimension = m_pSamples.C();
+ m_dataUpdateLock.resize(m_iDataSize);
+
+ m_workSpacePool.reset(new COMMON::WorkSpacePool(m_iMaxCheck, GetNumSamples()));
+ m_workSpacePool->Init(m_iNumberOfThreads);
+ return ErrorCode::Success;
}
template
bool Index::LoadDataPoints(std::string sDataPointsFileName)
{
std::cout << "Load Data Points From " << sDataPointsFileName << std::endl;
- if (m_iCacheSize >= 0)
- m_pSamples.Initialize(0, 0, nullptr, sDataPointsFileName.c_str(), m_iCacheSize);
- else
- {
- FILE * fp = fopen(sDataPointsFileName.c_str(), "rb");
- if (fp == NULL) return false;
+ FILE * fp = fopen(sDataPointsFileName.c_str(), "rb");
+ if (fp == NULL) return false;
- int R, C;
- fread(&R, sizeof(int), 1, fp);
- fread(&C, sizeof(int), 1, fp);
+ int R, C;
+ fread(&R, sizeof(int), 1, fp);
+ fread(&C, sizeof(int), 1, fp);
- if (m_iDebugLoad > 0 && R > m_iDebugLoad) R = m_iDebugLoad;
+ if (m_iDebugLoad > 0 && R > m_iDebugLoad) R = m_iDebugLoad;
- m_pSamples.Initialize(R, C);
- int i = 0, batch = 10000;
- while (i + batch < R) {
- fread((m_pSamples)[i], sizeof(T), C * batch, fp);
- i += batch;
- }
- fread((m_pSamples)[i], sizeof(T), C * (R - i), fp);
- fclose(fp);
+ m_pSamples.Initialize(R, C);
+ int i = 0, batch = 10000;
+ while (i + batch < R) {
+ fread((m_pSamples)[i], sizeof(T), C * batch, fp);
+ i += batch;
}
+ fread((m_pSamples)[i], sizeof(T), C * (R - i), fp);
+ fclose(fp);
std::cout << "Load Data Points (" << m_pSamples.R() << ", " << m_pSamples.C() << ") Finish!" << std::endl;
return true;
}
@@ -181,64 +224,68 @@ namespace SPTAG
#pragma region K-NN search
template
- void Index::SearchIndex(COMMON::QueryResultSet &query, COMMON::WorkSpace &space) const
+ void Index::SearchIndex(COMMON::QueryResultSet &p_query, COMMON::WorkSpace &p_space, const tbb::concurrent_unordered_set &p_deleted) const
{
for (char i = 0; i < m_iBKTNumber; i++) {
const BKTNode& node = m_pBKTRoots[m_pBKTStart[i]];
if (node.childStart < 0) {
- space.m_SPTQueue.insert(COMMON::HeapCell(m_pBKTStart[i], m_fComputeDistance(query.GetTarget(), (m_pSamples)[node.centerid], m_iDataDimension)));
+ p_space.m_SPTQueue.insert(COMMON::HeapCell(m_pBKTStart[i], m_fComputeDistance(p_query.GetTarget(), (m_pSamples)[node.centerid], m_iDataDimension)));
}
else {
for (int begin = node.childStart; begin < node.childEnd; begin++) {
int index = m_pBKTRoots[begin].centerid;
- space.m_SPTQueue.insert(COMMON::HeapCell(begin, m_fComputeDistance(query.GetTarget(), (m_pSamples)[index], m_iDataDimension)));
+ p_space.m_SPTQueue.insert(COMMON::HeapCell(begin, m_fComputeDistance(p_query.GetTarget(), (m_pSamples)[index], m_iDataDimension)));
}
}
}
int checkLimit = g_iNumberOfInitialDynamicPivots;
const int checkPos = m_iNeighborhoodSize - 1;
- while (!space.m_SPTQueue.empty()) {
+ while (!p_space.m_SPTQueue.empty()) {
do
{
- COMMON::HeapCell bcell = space.m_SPTQueue.pop();
+ COMMON::HeapCell bcell = p_space.m_SPTQueue.pop();
const BKTNode& tnode = m_pBKTRoots[bcell.node];
if (tnode.childStart < 0) {
- if (!space.CheckAndSet(tnode.centerid)) {
- space.m_iNumberOfCheckedLeaves++;
- space.m_NGQueue.insert(COMMON::HeapCell(tnode.centerid, bcell.distance));
+ if (!p_space.CheckAndSet(tnode.centerid)) {
+ p_space.m_iNumberOfCheckedLeaves++;
+ p_space.m_NGQueue.insert(COMMON::HeapCell(tnode.centerid, bcell.distance));
}
- if (space.m_iNumberOfCheckedLeaves >= checkLimit) break;
+ if (p_space.m_iNumberOfCheckedLeaves >= checkLimit) break;
}
else {
- if (!space.CheckAndSet(tnode.centerid)) {
- space.m_NGQueue.insert(COMMON::HeapCell(tnode.centerid, bcell.distance));
+ if (!p_space.CheckAndSet(tnode.centerid)) {
+ p_space.m_NGQueue.insert(COMMON::HeapCell(tnode.centerid, bcell.distance));
}
for (int begin = tnode.childStart; begin < tnode.childEnd; begin++) {
int index = m_pBKTRoots[begin].centerid;
- space.m_SPTQueue.insert(COMMON::HeapCell(begin, m_fComputeDistance(query.GetTarget(), (m_pSamples)[index], m_iDataDimension)));
+ p_space.m_SPTQueue.insert(COMMON::HeapCell(begin, m_fComputeDistance(p_query.GetTarget(), (m_pSamples)[index], m_iDataDimension)));
}
}
- } while (!space.m_SPTQueue.empty());
- while (!space.m_NGQueue.empty()) {
- COMMON::HeapCell gnode = space.m_NGQueue.pop();
+ } while (!p_space.m_SPTQueue.empty());
+ while (!p_space.m_NGQueue.empty()) {
+ COMMON::HeapCell gnode = p_space.m_NGQueue.pop();
const int *node = (m_pNeighborhoodGraph)[gnode.node];
_mm_prefetch((const char *)node, _MM_HINT_T0);
- if (query.AddPoint(gnode.node, gnode.distance)) {
- space.m_iNumOfContinuousNoBetterPropagation = 0;
-
- int checkNode = node[checkPos];
- if (checkNode < -1) {
- const BKTNode& tnode = m_pBKTRoots[-2 - checkNode];
- for (int i = -tnode.childStart; i < tnode.childEnd; i++) {
- if (!query.AddPoint(m_pBKTRoots[i].centerid, gnode.distance)) break;
+ if (p_deleted.find(gnode.node) == p_deleted.end()) {
+ if (p_query.AddPoint(gnode.node, gnode.distance)) {
+ p_space.m_iNumOfContinuousNoBetterPropagation = 0;
+
+ int checkNode = node[checkPos];
+ if (checkNode < -1) {
+ const BKTNode& tnode = m_pBKTRoots[-2 - checkNode];
+ for (int i = -tnode.childStart; i < tnode.childEnd; i++) {
+ if (p_deleted.find(m_pBKTRoots[i].centerid) == p_deleted.end()) {
+ if (!p_query.AddPoint(m_pBKTRoots[i].centerid, gnode.distance)) break;
+ }
+ }
}
}
- }
- else {
- space.m_iNumOfContinuousNoBetterPropagation++;
- if (space.m_iNumOfContinuousNoBetterPropagation > space.m_iContinuousLimit || space.m_iNumberOfCheckedLeaves > space.m_iMaxCheck) {
- query.SortResult(); return;
+ else {
+ p_space.m_iNumOfContinuousNoBetterPropagation++;
+ if (p_space.m_iNumOfContinuousNoBetterPropagation > p_space.m_iContinuousLimit || p_space.m_iNumberOfCheckedLeaves > p_space.m_iMaxCheck) {
+ p_query.SortResult(); return;
+ }
}
}
@@ -254,103 +301,70 @@ namespace SPTAG
// do not check it if it has been checked
if (nn_index < 0) break;
- if (space.CheckAndSet(nn_index)) continue;
+ if (p_space.CheckAndSet(nn_index)) continue;
// count the number of the computed nodes
- float distance2leaf = m_fComputeDistance(query.GetTarget(), (m_pSamples)[nn_index], m_iDataDimension);
- space.m_iNumberOfCheckedLeaves++;
- space.m_NGQueue.insert(COMMON::HeapCell(nn_index, distance2leaf));
+ float distance2leaf = m_fComputeDistance(p_query.GetTarget(), (m_pSamples)[nn_index], m_iDataDimension);
+ p_space.m_iNumberOfCheckedLeaves++;
+ p_space.m_NGQueue.insert(COMMON::HeapCell(nn_index, distance2leaf));
}
- if (space.m_NGQueue.Top().distance >= space.m_SPTQueue.Top().distance) {
- checkLimit = g_iNumberOfOtherDynamicPivots + space.m_iNumberOfCheckedLeaves;
+ if (p_space.m_NGQueue.Top().distance > p_space.m_SPTQueue.Top().distance) {
+ checkLimit = g_iNumberOfOtherDynamicPivots + p_space.m_iNumberOfCheckedLeaves;
break;
}
}
}
+ p_query.SortResult();
}
template
ErrorCode
- Index::SearchIndex(QueryResult &query) const
+ Index::SearchIndex(QueryResult &p_query) const
{
auto workSpace = m_workSpacePool->Rent();
workSpace->Reset(m_iMaxCheck);
- SearchIndex(*((COMMON::QueryResultSet*)&query), *workSpace);
+ SearchIndex(*((COMMON::QueryResultSet*)&p_query), *workSpace, m_deletedID);
m_workSpacePool->Return(workSpace);
- if (query.WithMeta() && nullptr != m_pMetadata)
+ if (p_query.WithMeta() && nullptr != m_pMetadata)
{
- for (int i = 0; i < query.GetResultNum(); ++i)
+ for (int i = 0; i < p_query.GetResultNum(); ++i)
{
- query.SetMetadata(i, m_pMetadata->GetMetadata(query.GetResult(i)->Key));
+ for (int i = 0; i < p_query.GetResultNum(); ++i)
+ {
+ int result = p_query.GetResult(i)->VID;
+ p_query.SetMetadata(i, (result < 0) ? ByteArray::c_empty : m_pMetadata->GetMetadata(result));
+ }
}
}
return ErrorCode::Success;
}
-
#pragma endregion
#pragma region Build/Save kd-tree & neighborhood graphs
template
- bool Index::BuildIndex()
- {
- if (!LoadDataPoints(m_sDataPointsFilename.c_str())) return false;
- m_iDataSize = m_pSamples.R();
- m_iDataDimension = m_pSamples.C();
-
- BuildBKT();
- if (!SaveBKT(m_sBKTFilename)) return false;
-
- BuildRNG();
- if (!SaveRNG(m_sGraphFilename)) return false;
- return true;
- }
-
- template
- bool Index::BuildIndex(void* p_data, int p_vectorNum, int p_dimension)
+ ErrorCode Index::BuildIndex(const void* p_data, int p_vectorNum, int p_dimension)
{
- m_pSamples.Initialize(p_vectorNum, p_dimension, reinterpret_cast(p_data));
+ m_pSamples.Initialize(p_vectorNum, p_dimension);
+ std::memcpy(m_pSamples.GetData(), p_data, p_vectorNum * p_dimension * sizeof(T));
m_iDataSize = m_pSamples.R();
m_iDataDimension = m_pSamples.C();
-
- BuildBKT();
- BuildRNG();
- return true;
- }
-
- template
- ErrorCode Index::BuildIndex(std::shared_ptr p_vectorSet,
- std::shared_ptr p_metadataSet)
- {
- if (nullptr == p_vectorSet || p_vectorSet->Count() == 0 || p_vectorSet->Dimension() == 0 || p_vectorSet->ValueType() != GetEnumValueType())
- {
- return ErrorCode::Fail;
- }
+ m_dataUpdateLock.resize(m_iDataSize);
if (DistCalcMethod::Cosine == m_iDistCalcMethod)
{
- m_pSamples.Initialize(p_vectorSet->Count(), p_vectorSet->Dimension());
- std::memcpy(m_pSamples.GetData(), p_vectorSet->GetData(), p_vectorSet->Count() * p_vectorSet->Dimension() * sizeof(T));
-
int base = COMMON::Utils::GetBase();
- for (SPTAG::SizeType i = 0; i < p_vectorSet->Count(); i++) {
- COMMON::Utils::Normalize(m_pSamples[i], m_pSamples.C(), base);
+ for (int i = 0; i < m_iDataSize; i++) {
+ COMMON::Utils::Normalize(m_pSamples[i], m_iDataDimension, base);
}
- p_vectorSet.reset();
- } else {
- m_pSamples.Initialize(p_vectorSet->Count(), p_vectorSet->Dimension(), reinterpret_cast(p_vectorSet->GetData()));
}
-
- m_iDataSize = m_pSamples.R();
- m_iDataDimension = m_pSamples.C();
-
- BuildBKT();
+ std::vector indices(m_iDataSize);
+ for (int i = 0; i < m_iDataSize; i++) indices[i] = i;
+ BuildBKT(indices, m_pBKTStart, m_pBKTRoots);
BuildRNG();
- m_pMetadata = std::move(p_metadataSet);
-
m_workSpacePool.reset(new COMMON::WorkSpacePool(m_iMaxCheck, GetNumSamples()));
m_workSpacePool->Init(m_iNumberOfThreads);
return ErrorCode::Success;
@@ -358,19 +372,19 @@ namespace SPTAG
#pragma region Build/Save kd-tree
template
- bool Index::SaveBKT(std::string sBKTFilename) const
+ bool Index::SaveBKT(std::string sBKTFilename, std::vector& newStart, std::vector& newRoot) const
{
std::cout << "Save BKT to " << sBKTFilename << std::endl;
FILE *fp = fopen(sBKTFilename.c_str(), "wb");
if(fp == NULL) return false;
fwrite(&m_iBKTNumber, sizeof(int), 1, fp);
- fwrite(m_pBKTStart.data(), sizeof(int), m_iBKTNumber, fp);
- int treeNodeSize = (int)m_pBKTRoots.size();
+ fwrite(newStart.data(), sizeof(int), m_iBKTNumber, fp);
+ int treeNodeSize = (int)newRoot.size();
fwrite(&treeNodeSize, sizeof(int), 1, fp);
for (int i = 0; i < treeNodeSize; i++) {
- fwrite(&(m_pBKTRoots[i].centerid), sizeof(int), 1, fp);
- fwrite(&(m_pBKTRoots[i].childStart), sizeof(int), 1, fp);
- fwrite(&(m_pBKTRoots[i].childEnd), sizeof(int), 1, fp);
+ fwrite(&(newRoot[i].centerid), sizeof(int), 1, fp);
+ fwrite(&(newRoot[i].childStart), sizeof(int), 1, fp);
+ fwrite(&(newRoot[i].childEnd), sizeof(int), 1, fp);
}
fclose(fp);
std::cout << "Save BKT Finish!" << std::endl;
@@ -378,69 +392,61 @@ namespace SPTAG
}
template
- void Index::BuildBKT()
+ void Index::BuildBKT(std::vector& indices, std::vector& newStart, std::vector& newRoot)
{
omp_set_num_threads(m_iNumberOfThreads);
struct BKTStackItem {
int index, first, last;
BKTStackItem(int index_, int first_, int last_) : index(index_), first(first_), last(last_) {}
};
-
std::stack ss;
- KmeansArgs args(m_iBKTKmeansK, m_iDataDimension, m_iDataSize, m_iNumberOfThreads);
- std::vector indices(m_iDataSize);
- for (int i = 0; i < m_iDataSize; i++) indices[i] = i;
+ KmeansArgs args(m_iBKTKmeansK, m_iDataDimension, (int)indices.size(), m_iNumberOfThreads);
+ m_pSampleToCenter.clear();
for (char i = 0; i < m_iBKTNumber; i++)
{
std::random_shuffle(indices.begin(), indices.end());
- m_pBKTStart.push_back((int)m_pBKTRoots.size());
- m_pBKTRoots.push_back(BKTNode(m_iDataSize));
+ newStart.push_back((int)newRoot.size());
+ newRoot.push_back(BKTNode((int)indices.size()));
std::cout << "Start to build tree " << i + 1 << std::endl;
- ss.push(BKTStackItem(m_pBKTStart[i], 0, m_iDataSize));
+ ss.push(BKTStackItem(newStart[i], 0, (int)indices.size()));
while (!ss.empty()) {
BKTStackItem item = ss.top(); ss.pop();
- int newBKTid = (int)m_pBKTRoots.size();
- m_pBKTRoots[item.index].childStart = newBKTid;
+ int newBKTid = (int)newRoot.size();
+ newRoot[item.index].childStart = newBKTid;
if (item.last - item.first <= m_iBKTLeafSize) {
- if (item.last == item.first) {
- m_pBKTRoots[item.index].childStart = -m_pBKTRoots[item.index].childStart;
- }
- else {
- for (int j = item.first; j < item.last; j++) {
- m_pBKTRoots.push_back(BKTNode(indices[j]));
- }
+ for (int j = item.first; j < item.last; j++) {
+ newRoot.push_back(BKTNode(indices[j]));
}
}
else { // clustering the data into BKTKmeansK clusters
int numClusters = KmeansClustering(indices, item.first, item.last, args);
if (numClusters <= 1) {
- int end = min(item.last + 1, m_iDataSize);
+ int end = min(item.last + 1, (int)indices.size());
std::sort(indices.begin() + item.first, indices.begin() + end);
- m_pBKTRoots[item.index].centerid = indices[item.first];
- m_pBKTRoots[item.index].childStart = -m_pBKTRoots[item.index].childStart;
+ newRoot[item.index].centerid = indices[item.first];
+ newRoot[item.index].childStart = -newRoot[item.index].childStart;
for (int j = item.first + 1; j < end; j++) {
- m_pBKTRoots.push_back(BKTNode(indices[j]));
- m_pSampleToCenter[indices[j]] = m_pBKTRoots[item.index].centerid;
+ newRoot.push_back(BKTNode(indices[j]));
+ m_pSampleToCenter[indices[j]] = newRoot[item.index].centerid;
}
- m_pSampleToCenter[-1 - m_pBKTRoots[item.index].centerid] = item.index;
+ m_pSampleToCenter[-1 - newRoot[item.index].centerid] = item.index;
}
else {
for (int k = 0; k < m_iBKTKmeansK; k++) {
- if (args.counts[k] > 0) {
- m_pBKTRoots.push_back(BKTNode(indices[item.first + args.counts[k] - 1]));
- ss.push(BKTStackItem(newBKTid++, item.first, item.first + args.counts[k] - 1));
- item.first += args.counts[k];
- }
+ if (args.counts[k] == 0) continue;
+ newRoot.push_back(BKTNode(indices[item.first + args.counts[k] - 1]));
+ if (args.counts[k] > 1) ss.push(BKTStackItem(newBKTid++, item.first, item.first + args.counts[k] - 1));
+ item.first += args.counts[k];
}
}
}
- m_pBKTRoots[item.index].childEnd = (int)m_pBKTRoots.size();
+ newRoot[item.index].childEnd = (int)newRoot.size();
}
- std::cout << i + 1 << " trees built, " << m_pBKTRoots.size() - m_pBKTStart[i] << " " << m_iDataSize << std::endl;
+ std::cout << i + 1 << " trees built, " << newRoot.size() - newStart[i] << " " << indices.size() << std::endl;
}
}
@@ -678,7 +684,7 @@ namespace SPTAG
float bestvariance = Variance[m_iDataDimension - 1].Dist;
for (int i = 0; i < m_numTopDimensionTpTreeSplit; i++)
{
- index[i] = Variance[m_iDataDimension - 1 - i].Key;
+ index[i] = Variance[m_iDataDimension - 1 - i].VID;
bestweight[i] = 0;
}
bestweight[0] = 1;
@@ -788,6 +794,18 @@ namespace SPTAG
m_iGraphSize = m_iDataSize;
m_pNeighborhoodGraph.Initialize(m_iGraphSize, m_iNeighborhoodSize);
+ if (m_iGraphSize < 1000) {
+ std::memset(m_pNeighborhoodGraph.GetData(), -1, m_iGraphSize * m_iNeighborhoodSize * sizeof(int));
+ m_iNeighborhoodSize /= graphScale;
+ RefineRNG();
+ for (int i = 0; i < m_iGraphSize; i++) {
+ if (m_pSampleToCenter.find(-1 - i) != m_pSampleToCenter.end())
+ m_pNeighborhoodGraph[i][m_iNeighborhoodSize - 1] = -2 - m_pSampleToCenter[-1 - i];
+ }
+ std::cout << "Build RNG Graph end!" << std::endl;
+ return;
+ }
+
{
COMMON::Dataset NeighborhoodDists(m_iGraphSize, m_iNeighborhoodSize);
std::vector> TptreeDataIndices(m_iTptreeNumber, std::vector(m_iGraphSize));
@@ -894,8 +912,8 @@ namespace SPTAG
#pragma omp parallel for schedule(dynamic)
for (int i = 0; i < NSample; i++)
{
- //int x = Utils::rand_int(m_iGraphSize);
- int x = i;
+ int x = COMMON::Utils::rand_int(m_iGraphSize);
+ //int x = i;
COMMON::QueryResultSet query((m_pSamples)[x], m_iCEF);
for (int y = 0; y < m_iGraphSize; y++)
{
@@ -910,7 +928,7 @@ namespace SPTAG
}
else {
for (int j = 0; j < m_iNeighborhoodSize && j < m_iCEF; j++) {
- exact_rng[j] = query.GetResult(j)->Key;
+ exact_rng[j] = query.GetResult(j)->VID;
}
for (int j = m_iCEF; j < m_iNeighborhoodSize; j++) exact_rng[j] = -1;
}
@@ -936,18 +954,209 @@ namespace SPTAG
}
template
- void Index::AddNodes(const T* pData, int num, COMMON::WorkSpace &space) {
- m_pSamples.AddBatch(pData, num);
- m_pNeighborhoodGraph.AddReserved(num);
+ ErrorCode Index::RefineIndex(const std::string& p_folderPath)
+ {
+ std::string folderPath(p_folderPath);
+ if (!folderPath.empty() && *(folderPath.rbegin()) != FolderSep)
+ {
+ folderPath += FolderSep;
+ }
+
+ if (!direxists(folderPath.c_str()))
+ {
+ mkdir(folderPath.c_str());
+ }
+ tbb::concurrent_unordered_set deleted(m_deletedID.begin(), m_deletedID.end());
+ std::vector indices;
+ std::unordered_map old2new;
+ int newR = m_iDataSize;
+ for (int i = 0; i < newR; i++) {
+ if (deleted.find(i) == deleted.end()) {
+ indices.push_back(i);
+ old2new[i] = i;
+ }
+ else {
+ while (deleted.find(newR - 1) != deleted.end() && newR > i) newR--;
+ if (newR == i) break;
+ indices.push_back(newR - 1);
+ old2new[newR - 1] = i;
+ newR--;
+ }
+ }
+ old2new[-1] = -1;
+
+ std::cout << "Refine... from " << m_iDataSize << "->" << newR << std::endl;
+ std::ofstream vecOut(folderPath + m_sDataPointsFilename, std::ios::binary);
+ if (!vecOut.is_open()) return ErrorCode::FailedCreateFile;
+ vecOut.write((char*)&newR, sizeof(int));
+ vecOut.write((char*)&m_iDataDimension, sizeof(int));
+ for (int i = 0; i < newR; i++) {
+ vecOut.write((char*)m_pSamples[indices[i]], sizeof(T)*m_iDataDimension);
+ }
+ vecOut.close();
+
+ if (nullptr != m_pMetadata)
+ {
+ std::ofstream metaOut(folderPath + "metadata.bin_tmp", std::ios::binary);
+ std::ofstream metaIndexOut(folderPath + "metadataIndex.bin", std::ios::binary);
+ if (!metaOut.is_open() || !metaIndexOut.is_open()) return ErrorCode::FailedCreateFile;
+ metaIndexOut.write((char*)&newR, sizeof(int));
+ std::uint64_t offset = 0;
+ for (int i = 0; i < newR; i++) {
+ metaIndexOut.write((char*)&offset, sizeof(std::uint64_t));
+ ByteArray meta = m_pMetadata->GetMetadata(indices[i]);
+ metaOut.write((char*)meta.Data(), sizeof(uint8_t)*meta.Length());
+ offset += meta.Length();
+ }
+ metaOut.close();
+ metaIndexOut.write((char*)&offset, sizeof(std::uint64_t));
+ metaIndexOut.close();
+
+ SPTAG::MetadataSet::MetaCopy(folderPath + "metadata.bin_tmp", folderPath + "metadata.bin");
+ }
+
+ std::vector newRoot;
+ std::vector newStart;
+ std::vector tmpindices(indices.begin(), indices.end());
+ BuildBKT(tmpindices, newStart, newRoot);
+#pragma omp parallel for
+ for (int i = 0; i < newRoot.size(); i++) {
+ newRoot[i].centerid = old2new[newRoot[i].centerid];
+ }
+ SaveBKT(folderPath + m_sBKTFilename, newStart, newRoot);
+
+ std::ofstream graphOut(folderPath + m_sGraphFilename, std::ios::binary);
+ if (!graphOut.is_open()) return ErrorCode::FailedCreateFile;
+ graphOut.write((char*)&newR, sizeof(int));
+ graphOut.write((char*)&m_iNeighborhoodSize, sizeof(int));
+
+ int *neighbors = new int[m_iNeighborhoodSize];
+ COMMON::WorkSpace space;
+ space.Initialize(m_iMaxCheckForRefineGraph, m_iDataSize);
+ for (int i = 0; i < newR; i++) {
+ space.Reset(m_iMaxCheckForRefineGraph);
+ COMMON::QueryResultSet query((m_pSamples)[indices[i]], m_iCEF);
+ space.CheckAndSet(indices[i]);
+ for (int j = 0; j < m_iNeighborhoodSize; j++) {
+ int index = m_pNeighborhoodGraph[indices[i]][j];
+ if (index < 0 || space.CheckAndSet(index)) continue;
+ space.m_NGQueue.insert(COMMON::HeapCell(index, m_fComputeDistance(query.GetTarget(), m_pSamples[index], m_iDataDimension)));
+ }
+ SearchIndex(query, space, deleted);
+ RebuildRNGNodeNeighbors(neighbors, query.GetResults(), m_iCEF);
+ for (int j = 0; j < m_iNeighborhoodSize; j++)
+ neighbors[j] = old2new[neighbors[j]];
+ if (m_pSampleToCenter.find(-1 - indices[i]) != m_pSampleToCenter.end()) {
+ neighbors[m_iNeighborhoodSize - 1] = -2 - m_pSampleToCenter[-1 - indices[i]];
+ }
+ graphOut.write((char*)neighbors, sizeof(int) * m_iNeighborhoodSize);
+ }
+ delete[]neighbors;
+ graphOut.close();
+
+ return ErrorCode::Success;
+ }
+
+ template
+ ErrorCode Index::MergeIndex(const char* p_indexFilePath1, const char* p_indexFilePath2) {
+ std::string folderPath1(p_indexFilePath1), folderPath2(p_indexFilePath2);
+ if (!folderPath1.empty() && *(folderPath1.rbegin()) != FolderSep) folderPath1 += FolderSep;
+ if (!folderPath2.empty() && *(folderPath2.rbegin()) != FolderSep) folderPath2 += FolderSep;
+
+ Helper::IniReader p_configReader1, p_configReader2;
+ if (ErrorCode::Success != p_configReader1.LoadIniFile(folderPath1 + "/indexloader.ini"))
+ return ErrorCode::FailedOpenFile;
+
+ if (ErrorCode::Success != p_configReader2.LoadIniFile(folderPath2 + "/indexloader.ini"))
+ return ErrorCode::FailedOpenFile;
+
+ std::string empty("");
+ if (!COMMON::DataUtils::MergeIndex(folderPath1 + p_configReader1.GetParameter("Index", "VectorFilePath", empty),
+ folderPath1 + p_configReader1.GetParameter("MetaData", "MetaDataFilePath", empty),
+ folderPath1 + p_configReader1.GetParameter("MetaData", "MetaDataIndexPath", empty),
+ folderPath2 + p_configReader1.GetParameter("Index", "VectorFilePath", empty),
+ folderPath2 + p_configReader1.GetParameter("MetaData", "MetaDataFilePath", empty),
+ folderPath2 + p_configReader1.GetParameter("MetaData", "MetaDataIndexPath", empty)))
+ return ErrorCode::Fail;
+
+#define DefineBKTParameter(VarName, VarType, DefaultValue, RepresentStr) \
+ SetParameter(RepresentStr, \
+ p_configReader1.GetParameter("Index", \
+ RepresentStr, \
+ std::string(#DefaultValue)).c_str()); \
+
+#include "inc/Core/BKT/ParameterDefinitionList.h"
+#undef DefineBKTParameter
+
+ if (!LoadDataPoints(folderPath1 + p_configReader1.GetParameter("Index", "VectorFilePath", empty))) return ErrorCode::FailedOpenFile;
+ std::vector indices(m_iDataSize);
+ for (int i = 0; i < m_iDataSize; i++) indices[i] = i;
+ BuildBKT(indices, m_pBKTStart, m_pBKTRoots);
+ BuildRNG();
+
+ SaveBKT(folderPath1 + p_configReader1.GetParameter("Index", "TreeFilePath", empty), m_pBKTStart, m_pBKTRoots);
+ SaveRNG(folderPath1 + p_configReader1.GetParameter("Index", "GraphFilePath", empty));
+ return ErrorCode::Success;
+ }
+
+ template
+ ErrorCode Index::DeleteIndex(const void* p_vectors, int p_vectorNum) {
+ const T* ptr_v = (const T*)p_vectors;
+#pragma omp parallel for schedule(dynamic)
+ for (int i = 0; i < p_vectorNum; i++) {
+ COMMON::QueryResultSet query(ptr_v + i * m_iDataDimension, m_iCEF);
+ SearchIndex(query);
+ for (int i = 0; i < m_iCEF; i++) {
+ if (query.GetResult(i)->Dist < 1e-6) {
+ m_deletedID.insert(query.GetResult(i)->VID);
+ }
+ }
+ }
+ return ErrorCode::Success;
+ }
+
+ template
+ ErrorCode Index::AddIndex(const void* p_vectors, int p_vectorNum, int p_dimension) {
+ if (m_pBKTRoots.size() == 0) {
+ return BuildIndex(p_vectors, p_vectorNum, p_dimension);
+ }
+ if (p_dimension != m_iDataDimension) return ErrorCode::FailedParseValue;
- if (m_pSamples.R() != m_iDataSize + num || m_pNeighborhoodGraph.R() != m_iDataSize + num)
- std::cout << "Error m_iDataSize" << std::endl;
- m_iDataSize += num;
+ int begin, end;
+ {
+ std::lock_guard lock(m_dataAllocLock);
+
+ m_pSamples.AddBatch((const T*)p_vectors, p_vectorNum);
+ m_pNeighborhoodGraph.AddBatch(p_vectorNum);
+
+ end = m_iDataSize + p_vectorNum;
+ if (m_pSamples.R() != end || m_pNeighborhoodGraph.R() != end) {
+ std::cout << "Memory Error: Cannot alloc space for vectors" << std::endl;
+ m_pSamples.SetR(m_iDataSize);
+ m_pNeighborhoodGraph.SetR(m_iDataSize);
+ return ErrorCode::Fail;
+ }
+ begin = m_iDataSize;
+ m_iDataSize = end;
+ m_iGraphSize = end;
+ m_dataUpdateLock.resize(m_iDataSize);
+ }
+ if (DistCalcMethod::Cosine == m_iDistCalcMethod)
+ {
+ int base = COMMON::Utils::GetBase();
+ for (int i = begin; i < end; i++) {
+ COMMON::Utils::Normalize((T*)m_pSamples[i], m_iDataDimension, base);
+ }
+ }
- for (int node = m_iDataSize - num; node < m_iDataSize; node++)
+ auto space = m_workSpacePool->Rent();
+ for (int node = begin; node < end; node++)
{
- RefineRNGNode(node, space, true);
+ RefineRNGNode(node, *(space.get()), true);
}
+ m_workSpacePool->Return(space);
+ std::cout << "Add " << p_vectorNum << " vectors" << std::endl;
+ return ErrorCode::Success;
}
template
@@ -960,7 +1169,7 @@ namespace SPTAG
if (index < 0 || space.CheckAndSet(index)) continue;
space.m_NGQueue.insert(COMMON::HeapCell(index, m_fComputeDistance(query.GetTarget(), m_pSamples[index], m_iDataDimension)));
}
- SearchIndex(query, space);
+ SearchIndex(query, space, m_deletedID);
RebuildRNGNodeNeighbors(m_pNeighborhoodGraph[node], query.GetResults(), m_iCEF);
if (updateNeighbors) {
@@ -968,18 +1177,50 @@ namespace SPTAG
for (int j = 0; j < m_iCEF; j++)
{
BasicResult* item = query.GetResult(j);
- if (item->Key < 0) continue;
- COMMON::QueryResultSet queryNbs(m_pSamples[item->Key], m_iNeighborhoodSize + 1);
- queryNbs.AddPoint(node, item->Dist);
+ if (item->VID < 0) break;
+
+ int insertID = node;
+ int* nodes = m_pNeighborhoodGraph[item->VID];
+ std::lock_guard lock(m_dataUpdateLock[item->VID]);
for (int k = 0; k < m_iNeighborhoodSize; k++)
{
- int tmpNode = m_pNeighborhoodGraph[item->Key][k];
- if (tmpNode < 0) break;
- float distance = m_fComputeDistance(queryNbs.GetTarget(), m_pSamples[tmpNode], m_iDataDimension);
- queryNbs.AddPoint(tmpNode, distance);
+ int tmpNode = nodes[k];
+ if (tmpNode < -1) continue;
+
+ if (tmpNode < 0)
+ {
+ bool good = true;
+ for (int t = 0; t < k; t++) {
+ if (m_fComputeDistance((m_pSamples)[insertID], (m_pSamples)[nodes[t]], m_iDataDimension) < item->Dist) {
+ good = false;
+ break;
+ }
+ }
+ if (good) {
+ nodes[k] = insertID;
+ }
+ break;
+ }
+ float tmpDist = m_fComputeDistance(m_pSamples[item->VID], m_pSamples[tmpNode], m_iDataDimension);
+ if (item->Dist < tmpDist || (item->Dist == tmpDist && insertID < tmpNode))
+ {
+ bool good = true;
+ for (int t = 0; t < k; t++) {
+ if (m_fComputeDistance((m_pSamples)[insertID], (m_pSamples)[nodes[t]], m_iDataDimension) < item->Dist) {
+ good = false;
+ break;
+ }
+ }
+ if (good) {
+ nodes[k] = insertID;
+ insertID = tmpNode;
+ item->Dist = tmpDist;
+ }
+ else {
+ break;
+ }
+ }
}
- queryNbs.SortResult();
- RebuildRNGNodeNeighbors(m_pNeighborhoodGraph[item->Key], queryNbs.GetResults(), m_iNeighborhoodSize + 1);
}
}
}
@@ -989,16 +1230,16 @@ namespace SPTAG
int count = 0;
for (int j = 0; j < numResults && count < m_iNeighborhoodSize; j++) {
const BasicResult& item = queryResults[j];
- if (item.Key < 0) continue;
+ if (item.VID < 0) continue;
bool good = true;
for (int k = 0; k < count; k++) {
- if (m_fComputeDistance((m_pSamples)[nodes[k]], (m_pSamples)[item.Key], m_iDataDimension) < item.Dist) {
+ if (m_fComputeDistance((m_pSamples)[nodes[k]], (m_pSamples)[item.VID], m_iDataDimension) <= item.Dist) {
good = false;
break;
}
}
- if (good) nodes[count++] = item.Key;
+ if (good) nodes[count++] = item.VID;
}
for (int j = count; j < m_iNeighborhoodSize; j++) nodes[j] = -1;
}
@@ -1025,15 +1266,6 @@ namespace SPTAG
return true;
}
- template
- bool Index::SaveIndex()
- {
- if (!SaveDataPoints(m_sDataPointsFilename)) return false;
- if (!SaveBKT(m_sBKTFilename)) return false;
- if (!SaveRNG(m_sGraphFilename)) return false;
- return true;
- }
-
template
ErrorCode
Index::SaveIndex(const std::string& p_folderPath)
@@ -1060,13 +1292,10 @@ namespace SPTAG
m_sDataPointsFilename = "vectors.bin";
m_sBKTFilename = "tree.bin";
m_sGraphFilename = "graph.bin";
-
- if (!SaveDataPoints(folderPath + m_sDataPointsFilename)) return ErrorCode::Fail;
- if (!SaveBKT(folderPath + m_sBKTFilename)) return ErrorCode::Fail;
- if (!SaveRNG(folderPath + m_sGraphFilename)) return ErrorCode::Fail;
-
- loaderFile << "[Index]" << std::endl;
+ std::string metadataFile = "metadata.bin";
+ std::string metadataIndexFile = "metadataIndex.bin";
+ loaderFile << "[Index]" << std::endl;
loaderFile << "IndexAlgoType=" << Helper::Convert::ConvertToString(IndexAlgoType::BKT) << std::endl;
loaderFile << "ValueType=" << Helper::Convert::ConvertToString(GetEnumValueType()) << std::endl;
loaderFile << std::endl;
@@ -1081,11 +1310,6 @@ namespace SPTAG
if (nullptr != m_pMetadata)
{
- std::string metadataFile = "metadata.bin";
- std::string metadataIndexFile = "metadataIndex.bin";
-
- m_pMetadata->SaveMetadata(folderPath + metadataFile, folderPath + metadataIndexFile);
-
loaderFile << "[MetaData]" << std::endl;
loaderFile << "MetaDataFilePath=" << metadataFile << std::endl;
loaderFile << "MetaDataIndexPath=" << metadataIndexFile << std::endl;
@@ -1093,61 +1317,18 @@ namespace SPTAG
}
loaderFile.close();
- return ErrorCode::Success;
- }
-
- template
- ErrorCode Index::LoadIndex(const std::string& p_folderPath, const Helper::IniReader& p_configReader)
- {
- std::string folderPath(p_folderPath);
- if (!folderPath.empty() && *(folderPath.rbegin()) != FolderSep)
- {
- folderPath += FolderSep;
+ if (m_deletedID.size() > 0) {
+ RefineIndex(folderPath);
}
-
- std::string metadataSection("MetaData");
- if (p_configReader.DoesSectionExist(metadataSection))
- {
- std::string metadataFilePath = p_configReader.GetParameter(metadataSection,
- "MetaDataFilePath",
- std::string());
- std::string metadataIndexFilePath = p_configReader.GetParameter(metadataSection,
- "MetaDataIndexPath",
- std::string());
-
- m_pMetadata.reset(new FileMetadataSet(folderPath + metadataFilePath, folderPath + metadataIndexFilePath));
-
- if (!m_pMetadata->Available())
+ else {
+ if (!SaveDataPoints(folderPath + m_sDataPointsFilename)) return ErrorCode::Fail;
+ if (!SaveBKT(folderPath + m_sBKTFilename, m_pBKTStart, m_pBKTRoots)) return ErrorCode::Fail;
+ if (!SaveRNG(folderPath + m_sGraphFilename)) return ErrorCode::Fail;
+ if (nullptr != m_pMetadata)
{
- std::cerr << "Error: Failed to load metadata." << std::endl;
- return ErrorCode::Fail;
+ m_pMetadata->SaveMetadata(folderPath + metadataFile, folderPath + metadataIndexFile);
}
}
-
-#define DefineBKTParameter(VarName, VarType, DefaultValue, RepresentStr) \
- SetParameter(RepresentStr, \
- p_configReader.GetParameter("Index", \
- RepresentStr, \
- std::string(#DefaultValue)).c_str()); \
-
-#include "inc/Core/BKT/ParameterDefinitionList.h"
-#undef DefineBKTParameter
-
- if (DistCalcMethod::Undefined == m_iDistCalcMethod)
- {
- return ErrorCode::Fail;
- }
-
- if (!LoadDataPoints(folderPath + m_sDataPointsFilename)) return ErrorCode::Fail;
- if (!LoadBKT(folderPath + m_sBKTFilename)) return ErrorCode::Fail;
- if (!LoadGraph(folderPath + m_sGraphFilename)) return ErrorCode::Fail;
-
- m_iDataSize = m_pSamples.R();
- m_iDataDimension = m_pSamples.C();
-
- m_workSpacePool.reset(new COMMON::WorkSpacePool(m_iMaxCheck, GetNumSamples()));
- m_workSpacePool->Init(m_iNumberOfThreads);
-
return ErrorCode::Success;
}
#pragma endregion
diff --git a/AnnService/src/Core/Common/WorkSpacePool.cpp b/AnnService/src/Core/Common/WorkSpacePool.cpp
index 5c9f6e1b..6ac6ee88 100644
--- a/AnnService/src/Core/Common/WorkSpacePool.cpp
+++ b/AnnService/src/Core/Common/WorkSpacePool.cpp
@@ -13,6 +13,9 @@ WorkSpacePool::WorkSpacePool(int p_maxCheck, int p_vectorCount)
WorkSpacePool::~WorkSpacePool()
{
+ for (auto& workSpace : m_workSpacePool)
+ workSpace.reset();
+ m_workSpacePool.clear();
}
diff --git a/AnnService/src/Core/KDT/KDTIndex.cpp b/AnnService/src/Core/KDT/KDTIndex.cpp
index 8b120a2e..a6db67d1 100644
--- a/AnnService/src/Core/KDT/KDTIndex.cpp
+++ b/AnnService/src/Core/KDT/KDTIndex.cpp
@@ -16,53 +16,96 @@ namespace SPTAG
{
#pragma region Load data points, kd-tree, neighborhood graph
template
- bool Index::LoadIndex()
+ ErrorCode Index::LoadIndexFromMemory(const std::vector& p_indexBlobs)
{
- bool loadedDataPoints = m_pDataPointsMemoryFile != NULL ? LoadDataPoints(m_pDataPointsMemoryFile) : LoadDataPoints(m_sDataPointsFilename);
- if (!loadedDataPoints) return false;
+ if (!LoadDataPoints((char*)p_indexBlobs[0])) return ErrorCode::FailedParseValue;
+ if (!LoadKDT((char*)p_indexBlobs[1])) return ErrorCode::FailedParseValue;
+ if (!LoadGraph((char*)p_indexBlobs[2])) return ErrorCode::FailedParseValue;
+ return ErrorCode::Success;
+ }
- m_iDataSize = m_pSamples.R();
- m_iDataDimension = m_pSamples.C();
+ template
+ ErrorCode Index::LoadIndex(const std::string& p_folderPath)
+ {
+ std::string folderPath(p_folderPath);
+ if (!folderPath.empty() && *(folderPath.rbegin()) != FolderSep)
+ {
+ folderPath += FolderSep;
+ }
+
+ Helper::IniReader p_configReader;
+ if (ErrorCode::Success != p_configReader.LoadIniFile(folderPath + "/indexloader.ini"))
+ {
+ return ErrorCode::FailedOpenFile;
+ }
- bool loadedKDT = m_pKDTMemoryFile != NULL ? LoadKDT(m_pKDTMemoryFile) : LoadKDT(m_sKDTFilename);
- if (!loadedKDT) return false;
+ std::string metadataSection("MetaData");
+ if (p_configReader.DoesSectionExist(metadataSection))
+ {
+ std::string metadataFilePath = p_configReader.GetParameter(metadataSection,
+ "MetaDataFilePath",
+ std::string());
+ std::string metadataIndexFilePath = p_configReader.GetParameter(metadataSection,
+ "MetaDataIndexPath",
+ std::string());
- bool loadedGraph = m_pGraphMemoryFile != NULL ? LoadGraph(m_pGraphMemoryFile) : LoadGraph(m_sGraphFilename);
- if (!loadedGraph) return false;
+ m_pMetadata.reset(new FileMetadataSet(folderPath + metadataFilePath, folderPath + metadataIndexFilePath));
- if (m_iRefineIter > 0) {
- for (int i = 0; i < m_iRefineIter; i++) RefineRNG();
- SaveRNG(m_sGraphFilename);
+ if (!m_pMetadata->Available())
+ {
+ std::cerr << "Error: Failed to load metadata." << std::endl;
+ return ErrorCode::Fail;
+ }
}
- return true;
+
+#define DefineKDTParameter(VarName, VarType, DefaultValue, RepresentStr) \
+ SetParameter(RepresentStr, \
+ p_configReader.GetParameter("Index", \
+ RepresentStr, \
+ std::string(#DefaultValue)).c_str()); \
+
+#include "inc/Core/KDT/ParameterDefinitionList.h"
+#undef DefineKDTParameter
+
+ if (DistCalcMethod::Undefined == m_iDistCalcMethod)
+ {
+ return ErrorCode::Fail;
+ }
+
+ if (!LoadDataPoints(folderPath + m_sDataPointsFilename)) return ErrorCode::Fail;
+ if (!LoadKDT(folderPath + m_sKDTFilename)) return ErrorCode::Fail;
+ if (!LoadGraph(folderPath + m_sGraphFilename)) return ErrorCode::Fail;
+
+ m_iDataSize = m_pSamples.R();
+ m_iDataDimension = m_pSamples.C();
+ m_dataUpdateLock.resize(m_iDataSize);
+
+ m_workSpacePool.reset(new COMMON::WorkSpacePool(m_iMaxCheck, GetNumSamples()));
+ m_workSpacePool->Init(m_iNumberOfThreads);
+ return ErrorCode::Success;
}
template
bool Index::LoadDataPoints(std::string sDataPointsFileName)
{
std::cout << "Load Data Points From " << sDataPointsFileName << std::endl;
- if (m_iCacheSize >= 0)
- m_pSamples.Initialize(0, 0, nullptr, sDataPointsFileName.c_str(), m_iCacheSize);
- else
- {
- FILE * fp = fopen(sDataPointsFileName.c_str(), "rb");
- if (fp == NULL) return false;
+ FILE * fp = fopen(sDataPointsFileName.c_str(), "rb");
+ if (fp == NULL) return false;
- int R, C;
- fread(&R, sizeof(int), 1, fp);
- fread(&C, sizeof(int), 1, fp);
+ int R, C;
+ fread(&R, sizeof(int), 1, fp);
+ fread(&C, sizeof(int), 1, fp);
- if (m_iDebugLoad > 0 && R > m_iDebugLoad) R = m_iDebugLoad;
+ if (m_iDebugLoad > 0 && R > m_iDebugLoad) R = m_iDebugLoad;
- m_pSamples.Initialize(R, C);
- int i = 0, batch = 10000;
- while (i + batch < R) {
- fread((m_pSamples)[i], sizeof(T), C * batch, fp);
- i += batch;
- }
- fread((m_pSamples)[i], sizeof(T), C * (R - i), fp);
- fclose(fp);
+ m_pSamples.Initialize(R, C);
+ int i = 0, batch = 10000;
+ while (i + batch < R) {
+ fread((m_pSamples)[i], sizeof(T), C * batch, fp);
+ i += batch;
}
+ fread((m_pSamples)[i], sizeof(T), C * (R - i), fp);
+ fclose(fp);
std::cout << "Load Data Points (" << m_pSamples.R() << ", " << m_pSamples.C() << ") Finish!" << std::endl;
return true;
}
@@ -92,15 +135,17 @@ namespace SPTAG
int realKDTNumber;
fread(&realKDTNumber, sizeof(int), 1, fp);
if (realKDTNumber < m_iKDTNumber) m_iKDTNumber = realKDTNumber;
- m_pKDTStart.clear();
+ m_pKDTStart.resize(m_iKDTNumber + 1, -1);
for (int i = 0; i < m_iKDTNumber; i++) {
- m_pKDTStart.push_back((int)(m_pKDTRoots.size()));
int treeNodeSize;
fread(&treeNodeSize, sizeof(int), 1, fp);
- m_pKDTRoots.resize(m_pKDTRoots.size() + treeNodeSize);
- fread(&(m_pKDTRoots[m_pKDTStart[i]]), sizeof(KDTNode), treeNodeSize, fp);
+ if (treeNodeSize > 0) {
+ m_pKDTStart[i] = (int)(m_pKDTRoots.size());
+ m_pKDTRoots.resize(m_pKDTRoots.size() + treeNodeSize);
+ fread(&(m_pKDTRoots[m_pKDTStart[i]]), sizeof(KDTNode), treeNodeSize, fp);
+ }
}
- m_pKDTStart.push_back((int)(m_pKDTRoots.size()));
+ if (m_pKDTRoots.size() > 0) m_pKDTStart[m_iKDTNumber] = (int)(m_pKDTRoots.size());
fclose(fp);
std::cout << "Load KDT (" << m_iKDTNumber << ", " << m_pKDTRoots.size() << ") Finish!" << std::endl;
return true;
@@ -176,29 +221,30 @@ namespace SPTAG
#pragma region K-NN search
template
void Index::KDTSearch(const int node, const bool isInit, const float distBound,
- COMMON::WorkSpace& space, COMMON::QueryResultSet &query) const {
+ COMMON::WorkSpace& p_space, COMMON::QueryResultSet &p_query, const tbb::concurrent_unordered_set &p_deleted) const {
if (node < 0)
{
int index = -node - 1;
+ if (index >= m_iDataSize) return;
#ifdef PREFETCH
const char* data = (const char *)(m_pSamples[index]);
_mm_prefetch(data, _MM_HINT_T0);
_mm_prefetch(data + 64, _MM_HINT_T0);
#endif
- if (space.CheckAndSet(index)) return;
+ if (p_space.CheckAndSet(index)) return;
- float distance = m_fComputeDistance(query.GetTarget(), (T*)data, m_iDataDimension);
- query.AddPoint(index, distance);
- ++space.m_iNumberOfTreeCheckedLeaves;
- ++space.m_iNumberOfCheckedLeaves;
- space.m_NGQueue.insert(COMMON::HeapCell(index, distance));
+ float distance = m_fComputeDistance(p_query.GetTarget(), (T*)data, m_iDataDimension);
+ if (p_deleted.find(index) == p_deleted.end()) p_query.AddPoint(index, distance);
+ ++p_space.m_iNumberOfTreeCheckedLeaves;
+ ++p_space.m_iNumberOfCheckedLeaves;
+ p_space.m_NGQueue.insert(COMMON::HeapCell(index, distance));
return;
}
auto& tnode = m_pKDTRoots[node];
- float diff = (query.GetTarget())[tnode.split_dim] - tnode.split_value;
+ float diff = (p_query.GetTarget())[tnode.split_dim] - tnode.split_value;
float distanceBound = distBound + diff * diff;
int otherChild, bestChild;
if (diff < 0)
@@ -212,30 +258,30 @@ namespace SPTAG
bestChild = tnode.right;
}
- if (!isInit || distanceBound < query.worstDist())
+ if (!isInit || distanceBound < p_query.worstDist())
{
- space.m_SPTQueue.insert(COMMON::HeapCell(otherChild, distanceBound));
+ p_space.m_SPTQueue.insert(COMMON::HeapCell(otherChild, distanceBound));
}
- KDTSearch(bestChild, isInit, distBound, space, query);
+ KDTSearch(bestChild, isInit, distBound, p_space, p_query, p_deleted);
}
template
- void Index::SearchIndex(COMMON::QueryResultSet &query, COMMON::WorkSpace &space) const
+ void Index::SearchIndex(COMMON::QueryResultSet &p_query, COMMON::WorkSpace &p_space, const tbb::concurrent_unordered_set &p_deleted) const
{
for (char i = 0; i < m_iKDTNumber; i++) {
- KDTSearch(m_pKDTStart[i], true, 0, space, query);
+ KDTSearch(m_pKDTStart[i], true, 0, p_space, p_query, p_deleted);
}
- while (!space.m_SPTQueue.empty() && space.m_iNumberOfCheckedLeaves < g_iNumberOfInitialDynamicPivots)
+ while (!p_space.m_SPTQueue.empty() && p_space.m_iNumberOfCheckedLeaves < g_iNumberOfInitialDynamicPivots)
{
- auto& tcell = space.m_SPTQueue.pop();
- if (query.worstDist() < tcell.distance) break;
- KDTSearch(tcell.node, true, tcell.distance, space, query);
+ auto& tcell = p_space.m_SPTQueue.pop();
+ if (p_query.worstDist() < tcell.distance) break;
+ KDTSearch(tcell.node, true, tcell.distance, p_space, p_query, p_deleted);
}
- while (!space.m_NGQueue.empty()) {
+ while (!p_space.m_NGQueue.empty()) {
bool bLocalOpt = true;
- COMMON::HeapCell gnode = space.m_NGQueue.pop();
+ COMMON::HeapCell gnode = p_space.m_NGQueue.pop();
const int *node = (m_pNeighborhoodGraph)[gnode.node];
#ifdef PREFETCH
@@ -252,123 +298,85 @@ namespace SPTAG
// do not check it if it has been checked
if (nn_index < 0) break;
- if (space.CheckAndSet(nn_index)) continue;
+ if (p_space.CheckAndSet(nn_index)) continue;
// count the number of the computed nodes
- float distance2leaf = m_fComputeDistance(query.GetTarget(), (m_pSamples)[nn_index], m_iDataDimension);
+ float distance2leaf = m_fComputeDistance(p_query.GetTarget(), (m_pSamples)[nn_index], m_iDataDimension);
- bool inserted = query.AddPoint(nn_index, distance2leaf);
- space.m_iNumberOfCheckedLeaves++;
- space.m_NGQueue.insert(COMMON::HeapCell(nn_index, distance2leaf));
- if (inserted || distance2leaf < gnode.distance) bLocalOpt = false;
+ if (p_deleted.find(nn_index) == p_deleted.end()) p_query.AddPoint(nn_index, distance2leaf);
+ if (distance2leaf <= p_query.worstDist()|| distance2leaf < gnode.distance) bLocalOpt = false;
+ p_space.m_iNumberOfCheckedLeaves++;
+ p_space.m_NGQueue.insert(COMMON::HeapCell(nn_index, distance2leaf));
}
- if (bLocalOpt) space.m_iNumOfContinuousNoBetterPropagation++;
- else space.m_iNumOfContinuousNoBetterPropagation = 0;
+ if (bLocalOpt) p_space.m_iNumOfContinuousNoBetterPropagation++;
+ else p_space.m_iNumOfContinuousNoBetterPropagation = 0;
- if (space.m_iNumOfContinuousNoBetterPropagation > g_iThresholdOfNumberOfContinuousNoBetterPropagation)
+ if (p_space.m_iNumOfContinuousNoBetterPropagation > g_iThresholdOfNumberOfContinuousNoBetterPropagation)
{
- if (space.m_iNumberOfTreeCheckedLeaves < space.m_iNumberOfCheckedLeaves / 10)
+ if (p_space.m_iNumberOfTreeCheckedLeaves < p_space.m_iNumberOfCheckedLeaves / 10)
{
- int nextNumberOfCheckedLeaves = g_iNumberOfOtherDynamicPivots + space.m_iNumberOfCheckedLeaves;
- while (!space.m_SPTQueue.empty() && space.m_iNumberOfCheckedLeaves < nextNumberOfCheckedLeaves)
+ int nextNumberOfCheckedLeaves = g_iNumberOfOtherDynamicPivots + p_space.m_iNumberOfCheckedLeaves;
+ while (!p_space.m_SPTQueue.empty() && p_space.m_iNumberOfCheckedLeaves < nextNumberOfCheckedLeaves)
{
- auto& tcell = space.m_SPTQueue.pop();
- KDTSearch(tcell.node, false, tcell.distance, space, query);
+ auto& tcell = p_space.m_SPTQueue.pop();
+ KDTSearch(tcell.node, false, tcell.distance, p_space, p_query, p_deleted);
}
}
- else if (gnode.distance > query.worstDist()) {
+ else if (gnode.distance > p_query.worstDist()) {
break;
}
}
- if (space.m_iNumberOfCheckedLeaves >= space.m_iMaxCheck) break;
+ if (p_space.m_iNumberOfCheckedLeaves >= p_space.m_iMaxCheck) break;
}
- query.SortResult();
+ p_query.SortResult();
}
template
ErrorCode
- Index::SearchIndex(QueryResult &query) const
+ Index::SearchIndex(QueryResult &p_query) const
{
auto workSpace = m_workSpacePool->Rent();
workSpace->Reset(m_iMaxCheck);
- SearchIndex(*((COMMON::QueryResultSet*)&query), *workSpace);
+ SearchIndex(*((COMMON::QueryResultSet*)&p_query), *workSpace, m_deletedID);
m_workSpacePool->Return(workSpace);
- if (query.WithMeta() && nullptr != m_pMetadata)
+ if (p_query.WithMeta() && nullptr != m_pMetadata)
{
- for (int i = 0; i < query.GetResultNum(); ++i)
+ for (int i = 0; i < p_query.GetResultNum(); ++i)
{
- query.SetMetadata(i, m_pMetadata->GetMetadata(query.GetResult(i)->Key));
+ int result = p_query.GetResult(i)->VID;
+ p_query.SetMetadata(i, (result < 0) ? ByteArray::c_empty : m_pMetadata->GetMetadata(result));
}
}
return ErrorCode::Success;
}
-
#pragma endregion
#pragma region Build/Save kd-tree & neighborhood graphs
template
- bool Index::BuildIndex()
- {
- if (!LoadDataPoints(m_sDataPointsFilename.c_str())) return false;
- m_iDataSize = m_pSamples.R();
- m_iDataDimension = m_pSamples.C();
-
- BuildKDT();
- if (!SaveKDT(m_sKDTFilename)) return false;
-
- BuildRNG();
- if (!SaveRNG(m_sGraphFilename)) return false;
- return true;
- }
-
- template
- bool Index::BuildIndex(void* p_data, int p_vectorNum, int p_dimension)
+ ErrorCode Index::BuildIndex(const void* p_data, int p_vectorNum, int p_dimension)
{
- m_pSamples.Initialize(p_vectorNum, p_dimension, reinterpret_cast(p_data));
+ m_pSamples.Initialize(p_vectorNum, p_dimension);
+ std::memcpy(m_pSamples.GetData(), p_data, p_vectorNum * p_dimension * sizeof(T));
m_iDataSize = m_pSamples.R();
m_iDataDimension = m_pSamples.C();
-
- BuildKDT();
- BuildRNG();
- return true;
- }
-
- template
- ErrorCode Index::BuildIndex(std::shared_ptr p_vectorSet,
- std::shared_ptr p_metadataSet)
- {
- if (nullptr == p_vectorSet || p_vectorSet->Count() == 0 || p_vectorSet->Dimension() == 0 || p_vectorSet->ValueType() != GetEnumValueType())
- {
- return ErrorCode::Fail;
- }
+ m_dataUpdateLock.resize(m_iDataSize);
if (DistCalcMethod::Cosine == m_iDistCalcMethod)
{
- m_pSamples.Initialize(p_vectorSet->Count(), p_vectorSet->Dimension());
- std::memcpy(m_pSamples.GetData(), p_vectorSet->GetData(), p_vectorSet->Count() * p_vectorSet->Dimension() * sizeof(T));
-
int base = COMMON::Utils::GetBase();
- for (SPTAG::SizeType i = 0; i < p_vectorSet->Count(); i++) {
- COMMON::Utils::Normalize(m_pSamples[i], m_pSamples.C(), base);
+ for (int i = 0; i < m_iDataSize; i++) {
+ COMMON::Utils::Normalize(m_pSamples[i], m_iDataDimension, base);
}
- p_vectorSet.reset();
- }
- else {
- m_pSamples.Initialize(p_vectorSet->Count(), p_vectorSet->Dimension(), reinterpret_cast(p_vectorSet->GetData()));
}
- m_iDataSize = m_pSamples.R();
- m_iDataDimension = m_pSamples.C();
-
- BuildKDT();
+ std::vector indices(m_iDataSize);
+ for (int j = 0; j < m_iDataSize; j++) indices[j] = j;
+ BuildKDT(indices, m_pKDTStart, m_pKDTRoots);
BuildRNG();
-
- m_pMetadata = std::move(p_metadataSet);
-
m_workSpacePool.reset(new COMMON::WorkSpacePool(m_iMaxCheck, GetNumSamples()));
m_workSpacePool->Init(m_iNumberOfThreads);
return ErrorCode::Success;
@@ -376,7 +384,7 @@ namespace SPTAG
#pragma region Build/Save kd-tree
template
- bool Index::SaveKDT(std::string sKDTFilename) const
+ bool Index::SaveKDT(std::string sKDTFilename, std::vector& newStart, std::vector& newRoot) const
{
std::cout << "Save KDT to " << sKDTFilename << std::endl;
FILE *fp = fopen(sKDTFilename.c_str(), "wb");
@@ -384,9 +392,9 @@ namespace SPTAG
fwrite(&m_iKDTNumber, sizeof(int), 1, fp);
for (int i = 0; i < m_iKDTNumber; i++)
{
- int treeNodeSize = m_pKDTStart[i + 1] - m_pKDTStart[i];
+ int treeNodeSize = newStart[i + 1] - newStart[i];
fwrite(&treeNodeSize, sizeof(int), 1, fp);
- fwrite(&(m_pKDTRoots[m_pKDTStart[i]]), sizeof(KDTNode), treeNodeSize, fp);
+ if (treeNodeSize > 0) fwrite(&(newRoot[newStart[i]]), sizeof(KDTNode), treeNodeSize, fp);
}
fclose(fp);
std::cout << "Save KDT Finish!" << std::endl;
@@ -394,25 +402,30 @@ namespace SPTAG
}
template
- void Index::BuildKDT()
+ void Index::BuildKDT(std::vector& indices, std::vector& newStart, std::vector& newRoot)
{
omp_set_num_threads(m_iNumberOfThreads);
- m_pKDTRoots.resize(m_iKDTNumber * m_iDataSize);
- m_pKDTStart.resize(m_iKDTNumber + 1, (int)(m_pKDTRoots.size()));
-
+ newRoot.resize(m_iKDTNumber * indices.size());
+ if (indices.size() > 0)
+ newStart.resize(m_iKDTNumber + 1, (int)(newRoot.size()));
+ else
+ {
+ newStart.resize(m_iKDTNumber + 1, -1);
+ return;
+ }
#pragma omp parallel for
for (int i = 0; i < m_iKDTNumber; i++)
{
Sleep(i * 100); std::srand(clock());
- m_pKDTStart[i] = i * m_iDataSize;
- std::vector KDTDataIndices(m_iDataSize);
- for (int j = 0; j < m_iDataSize; j++) KDTDataIndices[j] = j;
- std::random_shuffle(KDTDataIndices.begin(), KDTDataIndices.end());
+
+ std::vector pindices(indices.begin(), indices.end());
+ std::random_shuffle(pindices.begin(), pindices.end());
+ newStart[i] = i * (int)pindices.size();
std::cout << "Start to build tree " << i + 1 << std::endl;
- int iTreeSize = m_pKDTStart[i];
- DivideTree(m_pKDTRoots.data(), KDTDataIndices, 0, m_iDataSize - 1, m_pKDTStart[i], iTreeSize);
- std::cout << i + 1 << " trees built, " << iTreeSize - m_pKDTStart[i] << " " << m_iDataSize << std::endl;
+ int iTreeSize = newStart[i];
+ DivideTree(newRoot.data(), pindices, 0, (int)pindices.size() - 1, newStart[i], iTreeSize);
+ std::cout << i + 1 << " trees built, " << iTreeSize - newStart[i] << " " << pindices.size() << std::endl;
}
}
@@ -421,8 +434,7 @@ namespace SPTAG
int index, int &iTreeSize) {
ChooseDivision(pTree[index], indices, first, last);
int i = Subdivide(pTree[index], indices, first, last);
-
- if (i - 1 == first)
+ if (i - 1 <= first)
{
pTree[index].left = -indices[first] - 1;
}
@@ -612,7 +624,7 @@ namespace SPTAG
float bestvariance = Variance[m_iDataDimension - 1].Dist;
for (int i = 0; i < m_numTopDimensionTPTSplit; i++)
{
- index[i] = Variance[m_iDataDimension - 1 - i].Key;
+ index[i] = Variance[m_iDataDimension - 1 - i].VID;
bestweight[i] = 0;
}
bestweight[0] = 1;
@@ -722,6 +734,20 @@ namespace SPTAG
m_iGraphSize = m_iDataSize;
m_pNeighborhoodGraph.Initialize(m_iGraphSize, m_iNeighborhoodSize);
+ if (m_iGraphSize < 1000) {
+ std::memset(m_pNeighborhoodGraph.GetData(), -1, m_iGraphSize * m_iNeighborhoodSize * sizeof(int));
+ m_iNeighborhoodSize /= graphScale;
+
+ COMMON::WorkSpace space;
+ space.Initialize(m_iMaxCheckForRefineGraph, m_iGraphSize);
+ for (int i = 0; i < m_iGraphSize; i++)
+ {
+ RefineRNGNode(i, space, true);
+ }
+ std::cout << "Build RNG Graph end!" << std::endl;
+ return;
+ }
+
{
COMMON::Dataset NeighborhoodDists(m_iGraphSize, m_iNeighborhoodSize);
std::vector> TptreeDataIndices(m_iTPTNumber, std::vector(m_iGraphSize));
@@ -802,8 +828,8 @@ namespace SPTAG
#pragma omp parallel for schedule(dynamic)
for (int i = 0; i < NSample; i++)
{
- //int x = Utils::rand_int(m_iGraphSize);
- int x = i;
+ int x = COMMON::Utils::rand_int(m_iGraphSize);
+ //int x = i;
COMMON::QueryResultSet query((m_pSamples)[x], m_iCEF);
for (int y = 0; y < m_iGraphSize; y++)
{
@@ -818,7 +844,7 @@ namespace SPTAG
}
else {
for (int j = 0; j < m_iNeighborhoodSize && j < m_iCEF; j++) {
- exact_rng[j] = query.GetResult(j)->Key;
+ exact_rng[j] = query.GetResult(j)->VID;
}
for (int j = m_iCEF; j < m_iNeighborhoodSize; j++) exact_rng[j] = -1;
}
@@ -844,18 +870,211 @@ namespace SPTAG
}
template
- void Index::AddNodes(const T* pData, int num, COMMON::WorkSpace &space) {
- m_pSamples.AddBatch(pData, num);
- m_pNeighborhoodGraph.AddReserved(num);
+ ErrorCode Index::RefineIndex(const std::string& p_folderPath)
+ {
+ std::string folderPath(p_folderPath);
+ if (!folderPath.empty() && *(folderPath.rbegin()) != FolderSep)
+ {
+ folderPath += FolderSep;
+ }
+
+ if (!direxists(folderPath.c_str()))
+ {
+ mkdir(folderPath.c_str());
+ }
+ tbb::concurrent_unordered_set deleted(m_deletedID.begin(), m_deletedID.end());
+ std::vector indices;
+ std::unordered_map old2new;
+ int newR = m_iDataSize;
+ for (int i = 0; i < newR; i++) {
+ if (deleted.find(i) == deleted.end()) {
+ indices.push_back(i);
+ old2new[i] = i;
+ }
+ else {
+ while (deleted.find(newR - 1) != deleted.end() && newR > i) newR--;
+ if (newR == i) break;
+ indices.push_back(newR - 1);
+ old2new[newR - 1] = i;
+ newR--;
+ }
+ }
+ old2new[-1] = -1;
+
+ std::cout << "Refine... from " << m_iDataSize << "->" << newR << std::endl;
+ std::ofstream vecOut(folderPath + m_sDataPointsFilename, std::ios::binary);
+ if (!vecOut.is_open()) return ErrorCode::FailedCreateFile;
+ vecOut.write((char*)&newR, sizeof(int));
+ vecOut.write((char*)&m_iDataDimension, sizeof(int));
+ for (int i = 0; i < newR; i++) {
+ vecOut.write((char*)(m_pSamples[indices[i]]), sizeof(T)*m_iDataDimension);
+ }
+ vecOut.close();
+
+ if (nullptr != m_pMetadata)
+ {
+ std::ofstream metaOut(folderPath + "metadata.bin_tmp", std::ios::binary);
+ std::ofstream metaIndexOut(folderPath + "metadataIndex.bin", std::ios::binary);
+ if (!metaOut.is_open() || !metaIndexOut.is_open()) return ErrorCode::FailedCreateFile;
+ metaIndexOut.write((char*)&newR, sizeof(int));
+ std::uint64_t offset = 0;
+ for (int i = 0; i < newR; i++) {
+ metaIndexOut.write((char*)&offset, sizeof(std::uint64_t));
+ ByteArray meta = m_pMetadata->GetMetadata(indices[i]);
+ metaOut.write((char*)meta.Data(), sizeof(uint8_t)*meta.Length());
+ offset += meta.Length();
+ }
+ metaOut.close();
+ metaIndexOut.write((char*)&offset, sizeof(std::uint64_t));
+ metaIndexOut.close();
+
+ SPTAG::MetadataSet::MetaCopy(folderPath + "metadata.bin_tmp", folderPath + "metadata.bin");
+ }
+
+ std::ofstream graphOut(folderPath + m_sGraphFilename, std::ios::binary);
+ if (!graphOut.is_open()) return ErrorCode::FailedCreateFile;
+ graphOut.write((char*)&newR, sizeof(int));
+ graphOut.write((char*)&m_iNeighborhoodSize, sizeof(int));
+
+ int *neighbors = new int[m_iNeighborhoodSize];
+ COMMON::WorkSpace space;
+ space.Initialize(m_iMaxCheckForRefineGraph, m_iDataSize);
+ for (int i = 0; i < newR; i++) {
+ space.Reset(m_iMaxCheckForRefineGraph);
+ COMMON::QueryResultSet query((m_pSamples)[indices[i]], m_iCEF);
+ space.CheckAndSet(indices[i]);
+ for (int j = 0; j < m_iNeighborhoodSize; j++) {
+ int index = m_pNeighborhoodGraph[indices[i]][j];
+ if (index < 0 || space.CheckAndSet(index)) continue;
+ space.m_NGQueue.insert(COMMON::HeapCell(index, m_fComputeDistance(query.GetTarget(), m_pSamples[index], m_iDataDimension)));
+ }
+ SearchIndex(query, space, deleted);
+ RebuildRNGNodeNeighbors(neighbors, query.GetResults(), m_iCEF);
+ for (int j = 0; j < m_iNeighborhoodSize; j++)
+ neighbors[j] = old2new[neighbors[j]];
+ graphOut.write((char*)neighbors, sizeof(int) * m_iNeighborhoodSize);
+ }
+ delete[]neighbors;
+ graphOut.close();
+
+ std::vector