Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/knowhere/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,7 @@ class BaseConfig : public Config {
CFG_BOOL trace_visit;
CFG_BOOL enable_mmap;
CFG_BOOL for_tuning;
CFG_BOOL shuffle_build;
KNOHWERE_DECLARE_CONFIG(BaseConfig) {
KNOWHERE_CONFIG_DECLARE_FIELD(metric_type)
.set_default("L2")
Expand Down Expand Up @@ -567,6 +568,10 @@ class BaseConfig : public Config {
.for_deserialize()
.for_deserialize_from_file();
KNOWHERE_CONFIG_DECLARE_FIELD(for_tuning).set_default(false).description("for tuning").for_search();
KNOWHERE_CONFIG_DECLARE_FIELD(shuffle_build)
.set_default(true)
.description("shuffle ids before index building")
.for_train();
}
};
} // namespace knowhere
Expand Down
3 changes: 2 additions & 1 deletion src/index/diskann/diskann.cc
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,8 @@ DiskANNIndexNode<T>::Build(const DataSet& dataset, const Config& cfg) {
static_cast<uint32_t>(build_conf.disk_pq_dims.value()),
false,
build_conf.accelerate_build.value(),
static_cast<uint32_t>(num_nodes_to_cache)};
static_cast<uint32_t>(num_nodes_to_cache),
build_conf.shuffle_build.value()};
RETURN_IF_ERROR(TryDiskANNCall([&]() {
int res = diskann::build_disk_index<T>(diskann_internal_build_config);
if (res != 0)
Expand Down
56 changes: 42 additions & 14 deletions src/index/hnsw/hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "knowhere/feder/HNSW.h"

#include <new>
#include <numeric>

#include "common/range_util.h"
#include "hnswlib/hnswalg.h"
Expand Down Expand Up @@ -76,27 +77,54 @@ class HnswIndexNode : public IndexNode {

knowhere::TimeRecorder build_time("Building HNSW cost");
auto rows = dataset.GetRows();
if (rows <= 0) {
LOG_KNOWHERE_ERROR_ << "Can not add empty data to HNSW index.";
return Status::empty_index;
}
auto tensor = dataset.GetTensor();
auto hnsw_cfg = static_cast<const HnswConfig&>(cfg);
index_->addPoint(tensor, 0);
auto build_pool = ThreadPool::GetGlobalBuildThreadPool();
std::vector<folly::Future<folly::Unit>> futures;
futures.reserve(rows);
bool shuffle_build = hnsw_cfg.shuffle_build.value();

std::atomic<uint64_t> counter{0};
uint64_t one_tenth_row = rows / 10;
for (int i = 1; i < rows; ++i) {
futures.emplace_back(build_pool->push([&, idx = i]() {
index_->addPoint(((const char*)tensor + index_->data_size_ * idx), idx);
uint64_t added = counter.fetch_add(1);
if (added % one_tenth_row == 0) {
LOG_KNOWHERE_INFO_ << "HNSW build progress: " << (added / one_tenth_row) << "0%";
}
}));

std::vector<int> shuffle_batch_ids;
constexpr int64_t batch_size = 8192; // same with diskann
int64_t round_num = std::ceil(float(rows - 1) / batch_size);
auto build_pool = ThreadPool::GetGlobalBuildThreadPool();
std::vector<folly::Future<folly::Unit>> futures;

if (shuffle_build) {
shuffle_batch_ids.reserve(round_num);
for (int i = 0; i < round_num; ++i) {
shuffle_batch_ids.emplace_back(i);
}
std::random_device rng;
std::mt19937 urng(rng());
std::shuffle(shuffle_batch_ids.begin(), shuffle_batch_ids.end(), urng);
}
for (auto& future : futures) {
future.wait();
index_->addPoint(tensor, 0);

futures.reserve(batch_size);
for (int64_t round_id = 0; round_id < round_num; round_id++) {
int64_t start_id = (shuffle_build ? shuffle_batch_ids[round_id] : round_id) * batch_size;
int64_t end_id =
std::min(rows - 1, ((shuffle_build ? shuffle_batch_ids[round_id] : round_id) + 1) * batch_size);
for (int64_t i = start_id; i < end_id; ++i) {
futures.emplace_back(build_pool->push([&, idx = i + 1]() {
index_->addPoint(((const char*)tensor + index_->data_size_ * idx), idx);
uint64_t added = counter.fetch_add(1);
if (added % one_tenth_row == 0) {
LOG_KNOWHERE_INFO_ << "HNSW build progress: " << (added / one_tenth_row) << "0%";
}
}));
}
for (auto& future : futures) {
future.wait();
}
futures.clear();
}

build_time.RecordSection("");
LOG_KNOWHERE_INFO_ << "HNSW built with #points num:" << index_->max_elements_ << " #M:" << index_->M_
<< " #max level:" << index_->maxlevel_ << " #ef_construction:" << index_->ef_construction_
Expand Down
4 changes: 3 additions & 1 deletion thirdparty/DiskANN/include/diskann/aux_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ namespace diskann {
template<typename T>
std::unique_ptr<diskann::Index<T>> build_merged_vamana_index(
std::string base_file, diskann::Metric _compareMetric, unsigned L,
unsigned R, bool accelerate_build, double sampling_rate,
unsigned R, bool accelerate_build, bool shuffle_build, double sampling_rate,
double ram_budget, std::string mem_index_path, std::string medoids_file,
std::string centroids_file);

Expand Down Expand Up @@ -127,6 +127,8 @@ namespace diskann {
bool accelerate_build = false;
// the cached nodes number
uint32_t num_nodes_to_cache = 0;
// shuffle id to build index
bool shuffle_build = false;
};

template<typename T>
Expand Down
11 changes: 6 additions & 5 deletions thirdparty/DiskANN/src/aux_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ namespace diskann {
template<typename T>
std::unique_ptr<diskann::Index<T>> build_merged_vamana_index(
std::string base_file, bool ip_prepared, diskann::Metric compareMetric,
unsigned L, unsigned R, bool accelerate_build, double sampling_rate,
unsigned L, unsigned R, bool accelerate_build, bool shuffle_build, double sampling_rate,
double ram_budget, std::string mem_index_path, std::string medoids_file,
std::string centroids_file) {
size_t base_num, base_dim;
Expand All @@ -496,6 +496,7 @@ namespace diskann {
paras.Set<bool>("saturate_graph", 1);
paras.Set<std::string>("save_path", mem_index_path);
paras.Set<bool>("accelerate_build", accelerate_build);
paras.Set<bool>("shuffle_build", shuffle_build);

std::unique_ptr<diskann::Index<T>> _pvamanaIndex =
std::unique_ptr<diskann::Index<T>>(new diskann::Index<T>(
Expand Down Expand Up @@ -1242,7 +1243,7 @@ namespace diskann {
auto graph_s = std::chrono::high_resolution_clock::now();
auto vamana_index = diskann::build_merged_vamana_index<T>(
data_file_to_use.c_str(), ip_prepared, diskann::Metric::L2, L, R,
config.accelerate_build, p_val, indexing_ram_budget, mem_index_path,
config.accelerate_build, config.shuffle_build, p_val, indexing_ram_budget, mem_index_path,
medoids_path, centroids_path);
auto graph_e = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> graph_diff = graph_e - graph_s;
Expand Down Expand Up @@ -1347,23 +1348,23 @@ namespace diskann {
template std::unique_ptr<diskann::Index<int8_t>>
build_merged_vamana_index<int8_t>(std::string base_file, bool ip_prepared,
diskann::Metric compareMetric, unsigned L,
unsigned R, bool accelerate_build,
unsigned R, bool accelerate_build, bool shuffle_build,
double sampling_rate, double ram_budget,
std::string mem_index_path,
std::string medoids_path,
std::string centroids_file);
template std::unique_ptr<diskann::Index<float>>
build_merged_vamana_index<float>(std::string base_file, bool ip_prepared,
diskann::Metric compareMetric, unsigned L,
unsigned R, bool accelerate_build,
unsigned R, bool accelerate_build, bool shuffle_build,
double sampling_rate, double ram_budget,
std::string mem_index_path,
std::string medoids_path,
std::string centroids_file);
template std::unique_ptr<diskann::Index<uint8_t>>
build_merged_vamana_index<uint8_t>(std::string base_file, bool ip_prepared,
diskann::Metric compareMetric, unsigned L,
unsigned R, bool accelerate_build,
unsigned R, bool accelerate_build, bool shuffle_build,
double sampling_rate, double ram_budget,
std::string mem_index_path,
std::string medoids_path,
Expand Down
15 changes: 13 additions & 2 deletions thirdparty/DiskANN/src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1343,6 +1343,7 @@ namespace diskann {
_indexingRange = parameters.Get<unsigned>("R");
_indexingMaxC = parameters.Get<unsigned>("C");
const bool accelerate_build = parameters.Get<bool>("accelerate_build");
const bool shuffle_build = parameters.Get<bool>("shuffle_build");
const float last_round_alpha = parameters.Get<float>("alpha");
unsigned L = _indexingQueueSize;

Expand Down Expand Up @@ -1432,10 +1433,20 @@ namespace diskann {
}

futures.reserve(round_size);
std::vector<unsigned> shuffle_batch_ids;
if (shuffle_build) {
shuffle_batch_ids.reserve(round_num_syncs);
for (unsigned i = 0; i < (unsigned) round_num_syncs; i++) {
shuffle_batch_ids.emplace_back(i);
}
std::random_device rng;
std::mt19937 urng(rng());
std::shuffle(shuffle_batch_ids.begin(), shuffle_batch_ids.end(), urng);
}
for (uint32_t sync_num = 0; sync_num < round_num_syncs; sync_num++) {
size_t start_id = sync_num * round_size;
size_t start_id = (shuffle_build ? shuffle_batch_ids[sync_num] : sync_num) * round_size;
size_t end_id =
(std::min)(_nd + _num_frozen_pts, (sync_num + 1) * round_size);
(std::min)(_nd + _num_frozen_pts, ((shuffle_build ? shuffle_batch_ids[sync_num] : sync_num) + 1) * round_size);

auto s = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff;
Expand Down