Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions include/knowhere/comp/thread_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

#include "folly/executors/CPUThreadPoolExecutor.h"
#include "folly/futures/Future.h"
#include "knowhere/expected.h"
#include "knowhere/log.h"

namespace knowhere {
Expand Down Expand Up @@ -182,4 +183,23 @@ class ThreadPool {
inline static std::mutex global_thread_pool_mutex_;
constexpr static size_t kTaskQueueFactor = 16;
};

// T is either folly::Unit or Status
template <typename T>
inline Status
WaitAllSuccess(std::vector<folly::Future<T>>& futures) {
static_assert(std::is_same<T, folly::Unit>::value || std::is_same<T, Status>::value,
"WaitAllSuccess can only be used with folly::Unit or knowhere::Status");
auto allFuts = folly::collectAll(futures.begin(), futures.end()).get();
for (const auto& result : allFuts) {
result.throwUnlessValue();
if constexpr (!std::is_same_v<T, folly::Unit>) {
if (result.value() != Status::success) {
return result.value();
}
}
}
return Status::success;
}

} // namespace knowhere
36 changes: 13 additions & 23 deletions src/common/comp/brute_force.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,17 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
bool is_cosine = IsMetricType(metric_str, metric::COSINE);

int topk = cfg.k.value();
auto labels = new int64_t[nq * topk];
auto distances = new float[nq * topk];
auto labels = std::make_unique<int64_t[]>(nq * topk);
auto distances = std::make_unique<float[]>(nq * topk);

auto pool = ThreadPool::GetGlobalSearchThreadPool();
std::vector<folly::Future<Status>> futs;
futs.reserve(nq);
for (int i = 0; i < nq; ++i) {
futs.emplace_back(pool->push([&, index = i] {
futs.emplace_back(pool->push([&, index = i, labels_ptr = labels.get(), distances_ptr = distances.get()] {
ThreadPool::ScopedOmpSetter setter(1);
auto cur_labels = labels + topk * index;
auto cur_distances = distances + topk * index;
auto cur_labels = labels_ptr + topk * index;
auto cur_distances = distances_ptr + topk * index;
switch (faiss_metric_type) {
case faiss::METRIC_L2: {
auto cur_query = (const float*)xq + dim * index;
Expand Down Expand Up @@ -118,14 +118,11 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
return Status::success;
}));
}
for (auto& fut : futs) {
fut.wait();
auto ret = fut.result().value();
if (ret != Status::success) {
return expected<DataSetPtr>::Err(ret, "failed to brute force search");
}
auto ret = WaitAllSuccess(futs);
if (ret != Status::success) {
return expected<DataSetPtr>::Err(ret, "failed to brute force search");
}
return GenResultDataSet(nq, cfg.k.value(), labels, distances);
return GenResultDataSet(nq, cfg.k.value(), labels.release(), distances.release());
}

Status
Expand Down Expand Up @@ -212,11 +209,7 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
return Status::success;
}));
}
for (auto& fut : futs) {
fut.wait();
auto ret = fut.result().value();
RETURN_IF_ERROR(ret);
}
RETURN_IF_ERROR(WaitAllSuccess(futs));
return Status::success;
}

Expand Down Expand Up @@ -315,12 +308,9 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
return Status::success;
}));
}
for (auto& fut : futs) {
fut.wait();
auto ret = fut.result().value();
if (ret != Status::success) {
return expected<DataSetPtr>::Err(ret, "failed to brute force search");
}
auto ret = WaitAllSuccess(futs);
if (ret != Status::success) {
return expected<DataSetPtr>::Err(ret, "failed to brute force search");
}

int64_t* ids = nullptr;
Expand Down
42 changes: 13 additions & 29 deletions src/index/diskann/diskann.cc
Original file line number Diff line number Diff line change
Expand Up @@ -478,8 +478,6 @@ DiskANNIndexNode<T>::Deserialize(const BinarySet& binset, const Config& cfg) {
std::vector<int64_t> warmup_result_ids_64(warmup_num, 0);
std::vector<float> warmup_result_dists(warmup_num, 0);

bool all_searches_are_good = true;

std::vector<folly::Future<folly::Unit>> futures;
futures.reserve(warmup_num);
for (_s64 i = 0; i < (int64_t)warmup_num; ++i) {
Expand All @@ -489,16 +487,14 @@ DiskANNIndexNode<T>::Deserialize(const BinarySet& binset, const Config& cfg) {
warmup_result_dists.data() + (index * 1), 4);
}));
}
for (auto& future : futures) {
if (TryDiskANNCall([&]() { future.wait(); }) != Status::success) {
all_searches_are_good = false;
}
}

bool failed = TryDiskANNCall([&]() { WaitAllSuccess(futures); }) != Status::success;

if (warmup != nullptr) {
diskann::aligned_free(warmup);
}

if (!all_searches_are_good) {
if (failed) {
LOG_KNOWHERE_ERROR_ << "Failed to do search on warmup file for DiskANN.";
return Status::diskann_inner_error;
}
Expand Down Expand Up @@ -550,30 +546,24 @@ DiskANNIndexNode<T>::Search(const DataSet& dataset, const Config& cfg, const Bit
search_conf.search_list_size.value());
}

auto p_id = new int64_t[k * nq];
auto p_dist = new float[k * nq];
auto p_id = std::make_unique<int64_t[]>(k * nq);
auto p_dist = std::make_unique<float[]>(k * nq);

bool all_searches_are_good = true;
std::vector<folly::Future<folly::Unit>> futures;
futures.reserve(nq);
for (int64_t row = 0; row < nq; ++row) {
futures.emplace_back(search_pool_->push([&, index = row]() {
pq_flash_index_->cached_beam_search(xq + (index * dim), k, lsearch, p_id + (index * k),
p_dist + (index * k), beamwidth, false, nullptr, feder_result, bitset,
filter_ratio, for_tuning);
futures.emplace_back(search_pool_->push([&, index = row, p_id_ptr = p_id.get(), p_dist_ptr = p_dist.get()]() {
pq_flash_index_->cached_beam_search(xq + (index * dim), k, lsearch, p_id_ptr + (index * k),
p_dist_ptr + (index * k), beamwidth, false, nullptr, feder_result,
bitset, filter_ratio, for_tuning);
}));
}
for (auto& future : futures) {
if (TryDiskANNCall([&]() { future.wait(); }) != Status::success) {
all_searches_are_good = false;
}
}

if (!all_searches_are_good) {
if (TryDiskANNCall([&]() { WaitAllSuccess(futures); }) != Status::success) {
return expected<DataSetPtr>::Err(Status::diskann_inner_error, "some search failed");
}

auto res = GenResultDataSet(nq, k, p_id, p_dist);
auto res = GenResultDataSet(nq, k, p_id.release(), p_dist.release());

// set visit_info json string into result dataset
if (feder_result != nullptr) {
Expand Down Expand Up @@ -625,7 +615,6 @@ DiskANNIndexNode<T>::RangeSearch(const DataSet& dataset, const Config& cfg, cons

std::vector<folly::Future<folly::Unit>> futures;
futures.reserve(nq);
bool all_searches_are_good = true;
for (int64_t row = 0; row < nq; ++row) {
futures.emplace_back(search_pool_->push([&, index = row]() {
std::vector<int64_t> indices;
Expand All @@ -639,12 +628,7 @@ DiskANNIndexNode<T>::RangeSearch(const DataSet& dataset, const Config& cfg, cons
}
}));
}
for (auto& future : futures) {
if (TryDiskANNCall([&]() { future.wait(); }) != Status::success) {
all_searches_are_good = false;
}
}
if (!all_searches_are_good) {
if (TryDiskANNCall([&]() { WaitAllSuccess(futures); }) != Status::success) {
return expected<DataSetPtr>::Err(Status::diskann_inner_error, "some search failed");
}

Expand Down
18 changes: 2 additions & 16 deletions src/index/flat/flat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,7 @@ class FlatIndexNode : public IndexNode {
}));
}
// wait for the completion
for (auto& fut : futs) {
fut.wait();
}
// check for exceptions. value() is {}, so either
// a call does nothing, or it throws an inner exception.
for (auto& fut : futs) {
fut.result().value();
}
WaitAllSuccess(futs);
} catch (const std::exception& e) {
std::unique_ptr<int64_t[]> auto_delete_ids(ids);
std::unique_ptr<float[]> auto_delete_dis(distances);
Expand Down Expand Up @@ -193,14 +186,7 @@ class FlatIndexNode : public IndexNode {
}));
}
// wait for the completion
for (auto& fut : futs) {
fut.wait();
}
// check for exceptions. value() is {}, so either
// a call does nothing, or it throws an inner exception.
for (auto& fut : futs) {
fut.result().value();
}
WaitAllSuccess(futs);
GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius, range_filter, distances, ids,
lims);
} catch (const std::exception& e) {
Expand Down
28 changes: 10 additions & 18 deletions src/index/hnsw/hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,7 @@ class HnswIndexNode : public IndexNode {
}
}));
}
for (auto& future : futures) {
future.wait();
}
knowhere::WaitAllSuccess(futures);
build_time.RecordSection("");
LOG_KNOWHERE_INFO_ << "HNSW built with #points num:" << index_->max_elements_ << " #M:" << index_->M_
<< " #max level:" << index_->maxlevel_ << " #ef_construction:" << index_->ef_construction_
Expand Down Expand Up @@ -124,8 +122,8 @@ class HnswIndexNode : public IndexNode {
feder_result = std::make_unique<feder::hnsw::FederResult>();
}

auto p_id = new int64_t[k * nq];
auto p_dist = new float[k * nq];
auto p_id = std::make_unique<int64_t[]>(k * nq);
auto p_dist = std::make_unique<float[]>(k * nq);

hnswlib::SearchParam param{(size_t)hnsw_cfg.ef.value(), hnsw_cfg.for_tuning.value()};
bool transform =
Expand All @@ -134,12 +132,12 @@ class HnswIndexNode : public IndexNode {
std::vector<folly::Future<folly::Unit>> futs;
futs.reserve(nq);
for (int i = 0; i < nq; ++i) {
futs.emplace_back(search_pool_->push([&, idx = i]() {
futs.emplace_back(search_pool_->push([&, idx = i, p_id_ptr = p_id.get(), p_dist_ptr = p_dist.get()]() {
auto single_query = (const char*)xq + idx * index_->data_size_;
auto rst = index_->searchKnn(single_query, k, bitset, &param, feder_result);
size_t rst_size = rst.size();
auto p_single_dis = p_dist + idx * k;
auto p_single_id = p_id + idx * k;
auto p_single_dis = p_dist_ptr + idx * k;
auto p_single_id = p_id_ptr + idx * k;
for (size_t idx = 0; idx < rst_size; ++idx) {
const auto& [dist, id] = rst[idx];
p_single_dis[idx] = transform ? (-dist) : dist;
Expand All @@ -151,11 +149,9 @@ class HnswIndexNode : public IndexNode {
}
}));
}
for (auto& fut : futs) {
fut.wait();
}
WaitAllSuccess(futs);

auto res = GenResultDataSet(nq, k, p_id, p_dist);
auto res = GenResultDataSet(nq, k, p_id.release(), p_dist.release());

// set visit_info json string into result dataset
if (feder_result != nullptr) {
Expand Down Expand Up @@ -238,9 +234,7 @@ class HnswIndexNode : public IndexNode {
}));
}
// wait for initial search(in top layers and search for seed_ef in base layer) to finish
for (auto& fut : futs) {
fut.wait();
}
WaitAllSuccess(futs);

return vec;
}
Expand Down Expand Up @@ -303,9 +297,7 @@ class HnswIndexNode : public IndexNode {
}
}));
}
for (auto& fut : futs) {
fut.wait();
}
WaitAllSuccess(futs);

// filter range search result
GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius_for_filter, range_filter, dis, ids,
Expand Down
18 changes: 2 additions & 16 deletions src/index/ivf/ivf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -442,14 +442,7 @@ IvfIndexNode<T>::Search(const DataSet& dataset, const Config& cfg, const BitsetV
}));
}
// wait for the completion
for (auto& fut : futs) {
fut.wait();
}
// check for exceptions. value() is {}, so either
// a call does nothing, or it throws an inner exception.
for (auto& fut : futs) {
fut.result().value();
}
WaitAllSuccess(futs);
} catch (const std::exception& e) {
delete[] ids;
delete[] distances;
Expand Down Expand Up @@ -541,14 +534,7 @@ IvfIndexNode<T>::RangeSearch(const DataSet& dataset, const Config& cfg, const Bi
}));
}
// wait for the completion
for (auto& fut : futs) {
fut.wait();
}
// check for exceptions. value() is {}, so either
// a call does nothing, or it throws an inner exception.
for (auto& fut : futs) {
fut.result().value();
}
WaitAllSuccess(futs);
GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius, range_filter, distances, ids, lims);
} catch (const std::exception& e) {
LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what();
Expand Down
Loading