diff --git a/include/knowhere/config.h b/include/knowhere/config.h index 8236d9dea..a9684ef8a 100644 --- a/include/knowhere/config.h +++ b/include/knowhere/config.h @@ -519,6 +519,7 @@ class BaseConfig : public Config { CFG_BOOL trace_visit; CFG_BOOL enable_mmap; CFG_BOOL for_tuning; + CFG_BOOL shuffle_build; KNOHWERE_DECLARE_CONFIG(BaseConfig) { KNOWHERE_CONFIG_DECLARE_FIELD(metric_type) .set_default("L2") @@ -567,6 +568,10 @@ class BaseConfig : public Config { .for_deserialize() .for_deserialize_from_file(); KNOWHERE_CONFIG_DECLARE_FIELD(for_tuning).set_default(false).description("for tuning").for_search(); + KNOWHERE_CONFIG_DECLARE_FIELD(shuffle_build) + .set_default(true) + .description("shuffle ids before index building") + .for_train(); } }; } // namespace knowhere diff --git a/src/index/diskann/diskann.cc b/src/index/diskann/diskann.cc index ad272e1e4..15fa34a8a 100644 --- a/src/index/diskann/diskann.cc +++ b/src/index/diskann/diskann.cc @@ -305,7 +305,8 @@ DiskANNIndexNode::Build(const DataSet& dataset, const Config& cfg) { static_cast(build_conf.disk_pq_dims.value()), false, build_conf.accelerate_build.value(), - static_cast(num_nodes_to_cache)}; + static_cast(num_nodes_to_cache), + build_conf.shuffle_build.value()}; RETURN_IF_ERROR(TryDiskANNCall([&]() { int res = diskann::build_disk_index(diskann_internal_build_config); if (res != 0) diff --git a/src/index/hnsw/hnsw.cc b/src/index/hnsw/hnsw.cc index d82dcd581..febf02173 100644 --- a/src/index/hnsw/hnsw.cc +++ b/src/index/hnsw/hnsw.cc @@ -12,6 +12,7 @@ #include "knowhere/feder/HNSW.h" #include +#include #include "common/range_util.h" #include "hnswlib/hnswalg.h" @@ -76,27 +77,54 @@ class HnswIndexNode : public IndexNode { knowhere::TimeRecorder build_time("Building HNSW cost"); auto rows = dataset.GetRows(); + if (rows <= 0) { + LOG_KNOWHERE_ERROR_ << "Can not add empty data to HNSW index."; + return Status::empty_index; + } auto tensor = dataset.GetTensor(); auto hnsw_cfg = static_cast(cfg); - index_->addPoint(tensor, 0); - auto build_pool = ThreadPool::GetGlobalBuildThreadPool(); - std::vector> futures; - futures.reserve(rows); + bool shuffle_build = hnsw_cfg.shuffle_build.value(); std::atomic counter{0}; uint64_t one_tenth_row = rows / 10; - for (int i = 1; i < rows; ++i) { - futures.emplace_back(build_pool->push([&, idx = i]() { - index_->addPoint(((const char*)tensor + index_->data_size_ * idx), idx); - uint64_t added = counter.fetch_add(1); - if (added % one_tenth_row == 0) { - LOG_KNOWHERE_INFO_ << "HNSW build progress: " << (added / one_tenth_row) << "0%"; - } - })); + + std::vector shuffle_batch_ids; + constexpr int64_t batch_size = 8192; // same with diskann + int64_t round_num = std::ceil(float(rows - 1) / batch_size); + auto build_pool = ThreadPool::GetGlobalBuildThreadPool(); + std::vector> futures; + + if (shuffle_build) { + shuffle_batch_ids.reserve(round_num); + for (int i = 0; i < round_num; ++i) { + shuffle_batch_ids.emplace_back(i); + } + std::random_device rng; + std::mt19937 urng(rng()); + std::shuffle(shuffle_batch_ids.begin(), shuffle_batch_ids.end(), urng); } - for (auto& future : futures) { - future.wait(); + index_->addPoint(tensor, 0); + + futures.reserve(batch_size); + for (int64_t round_id = 0; round_id < round_num; round_id++) { + int64_t start_id = (shuffle_build ? shuffle_batch_ids[round_id] : round_id) * batch_size; + int64_t end_id = + std::min(rows - 1, ((shuffle_build ? shuffle_batch_ids[round_id] : round_id) + 1) * batch_size); + for (int64_t i = start_id; i < end_id; ++i) { + futures.emplace_back(build_pool->push([&, idx = i + 1]() { + index_->addPoint(((const char*)tensor + index_->data_size_ * idx), idx); + uint64_t added = counter.fetch_add(1); + if (added % one_tenth_row == 0) { + LOG_KNOWHERE_INFO_ << "HNSW build progress: " << (added / one_tenth_row) << "0%"; + } + })); + } + for (auto& future : futures) { + future.wait(); + } + futures.clear(); } + build_time.RecordSection(""); LOG_KNOWHERE_INFO_ << "HNSW built with #points num:" << index_->max_elements_ << " #M:" << index_->M_ << " #max level:" << index_->maxlevel_ << " #ef_construction:" << index_->ef_construction_ diff --git a/thirdparty/DiskANN/include/diskann/aux_utils.h b/thirdparty/DiskANN/include/diskann/aux_utils.h index f1595fbcf..262da0851 100644 --- a/thirdparty/DiskANN/include/diskann/aux_utils.h +++ b/thirdparty/DiskANN/include/diskann/aux_utils.h @@ -87,7 +87,7 @@ namespace diskann { template std::unique_ptr> build_merged_vamana_index( std::string base_file, diskann::Metric _compareMetric, unsigned L, - unsigned R, bool accelerate_build, double sampling_rate, + unsigned R, bool accelerate_build, bool shuffle_build, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_file, std::string centroids_file); @@ -127,6 +127,8 @@ namespace diskann { bool accelerate_build = false; // the cached nodes number uint32_t num_nodes_to_cache = 0; + // shuffle id to build index + bool shuffle_build = false; }; template diff --git a/thirdparty/DiskANN/src/aux_utils.cpp b/thirdparty/DiskANN/src/aux_utils.cpp index e99cb8a35..8f4099254 100644 --- a/thirdparty/DiskANN/src/aux_utils.cpp +++ b/thirdparty/DiskANN/src/aux_utils.cpp @@ -474,7 +474,7 @@ namespace diskann { template std::unique_ptr> build_merged_vamana_index( std::string base_file, bool ip_prepared, diskann::Metric compareMetric, - unsigned L, unsigned R, bool accelerate_build, double sampling_rate, + unsigned L, unsigned R, bool accelerate_build, bool shuffle_build, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_file, std::string centroids_file) { size_t base_num, base_dim; @@ -496,6 +496,7 @@ namespace diskann { paras.Set("saturate_graph", 1); paras.Set("save_path", mem_index_path); paras.Set("accelerate_build", accelerate_build); + paras.Set("shuffle_build", shuffle_build); std::unique_ptr> _pvamanaIndex = std::unique_ptr>(new diskann::Index( @@ -1242,7 +1243,7 @@ namespace diskann { auto graph_s = std::chrono::high_resolution_clock::now(); auto vamana_index = diskann::build_merged_vamana_index( data_file_to_use.c_str(), ip_prepared, diskann::Metric::L2, L, R, - config.accelerate_build, p_val, indexing_ram_budget, mem_index_path, + config.accelerate_build, config.shuffle_build, p_val, indexing_ram_budget, mem_index_path, medoids_path, centroids_path); auto graph_e = std::chrono::high_resolution_clock::now(); std::chrono::duration graph_diff = graph_e - graph_s; @@ -1347,7 +1348,7 @@ namespace diskann { template std::unique_ptr> build_merged_vamana_index(std::string base_file, bool ip_prepared, diskann::Metric compareMetric, unsigned L, - unsigned R, bool accelerate_build, + unsigned R, bool accelerate_build, bool shuffle_build, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, @@ -1355,7 +1356,7 @@ namespace diskann { template std::unique_ptr> build_merged_vamana_index(std::string base_file, bool ip_prepared, diskann::Metric compareMetric, unsigned L, - unsigned R, bool accelerate_build, + unsigned R, bool accelerate_build, bool shuffle_build, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, @@ -1363,7 +1364,7 @@ namespace diskann { template std::unique_ptr> build_merged_vamana_index(std::string base_file, bool ip_prepared, diskann::Metric compareMetric, unsigned L, - unsigned R, bool accelerate_build, + unsigned R, bool accelerate_build, bool shuffle_build, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, diff --git a/thirdparty/DiskANN/src/index.cpp b/thirdparty/DiskANN/src/index.cpp index ded3f265c..e631bbfc9 100644 --- a/thirdparty/DiskANN/src/index.cpp +++ b/thirdparty/DiskANN/src/index.cpp @@ -1343,6 +1343,7 @@ namespace diskann { _indexingRange = parameters.Get("R"); _indexingMaxC = parameters.Get("C"); const bool accelerate_build = parameters.Get("accelerate_build"); + const bool shuffle_build = parameters.Get("shuffle_build"); const float last_round_alpha = parameters.Get("alpha"); unsigned L = _indexingQueueSize; @@ -1432,10 +1433,20 @@ namespace diskann { } futures.reserve(round_size); + std::vector shuffle_batch_ids; + if (shuffle_build) { + shuffle_batch_ids.reserve(round_num_syncs); + for (unsigned i = 0; i < (unsigned) round_num_syncs; i++) { + shuffle_batch_ids.emplace_back(i); + } + std::random_device rng; + std::mt19937 urng(rng()); + std::shuffle(shuffle_batch_ids.begin(), shuffle_batch_ids.end(), urng); + } for (uint32_t sync_num = 0; sync_num < round_num_syncs; sync_num++) { - size_t start_id = sync_num * round_size; + size_t start_id = (shuffle_build ? shuffle_batch_ids[sync_num] : sync_num) * round_size; size_t end_id = - (std::min)(_nd + _num_frozen_pts, (sync_num + 1) * round_size); + (std::min)(_nd + _num_frozen_pts, ((shuffle_build ? shuffle_batch_ids[sync_num] : sync_num) + 1) * round_size); auto s = std::chrono::high_resolution_clock::now(); std::chrono::duration diff;