Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions include/knowhere/comp/thread_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

#include "folly/executors/CPUThreadPoolExecutor.h"
#include "folly/futures/Future.h"
#include "knowhere/expected.h"
#include "knowhere/log.h"

namespace knowhere {
Expand Down Expand Up @@ -211,4 +212,23 @@ class ThreadPool {

constexpr static size_t kTaskQueueFactor = 16;
};

// T is either folly::Unit or Status
template <typename T>
inline Status
WaitAllSuccess(std::vector<folly::Future<T>>& futures) {
static_assert(std::is_same<T, folly::Unit>::value || std::is_same<T, Status>::value,
"WaitAllSuccess can only be used with folly::Unit or knowhere::Status");
auto allFuts = folly::collectAll(futures.begin(), futures.end()).get();
for (const auto& result : allFuts) {
result.throwUnlessValue();
if constexpr (!std::is_same_v<T, folly::Unit>) {
if (result.value() != Status::success) {
return result.value();
}
}
}
return Status::success;
}

} // namespace knowhere
36 changes: 13 additions & 23 deletions src/common/comp/brute_force.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,17 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
bool is_cosine = IsMetricType(metric_str, metric::COSINE);

int topk = cfg.k.value();
auto labels = new int64_t[nq * topk];
auto distances = new float[nq * topk];
auto labels = std::make_unique<int64_t[]>(nq * topk);
auto distances = std::make_unique<float[]>(nq * topk);

auto pool = ThreadPool::GetGlobalSearchThreadPool();
std::vector<folly::Future<Status>> futs;
futs.reserve(nq);
for (int i = 0; i < nq; ++i) {
futs.emplace_back(pool->push([&, index = i] {
futs.emplace_back(pool->push([&, index = i, labels_ptr = labels.get(), distances_ptr = distances.get()] {
ThreadPool::ScopedOmpSetter setter(1);
auto cur_labels = labels + topk * index;
auto cur_distances = distances + topk * index;
auto cur_labels = labels_ptr + topk * index;
auto cur_distances = distances_ptr + topk * index;

BitsetViewIDSelector bw_idselector(bitset);
faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector;
Expand Down Expand Up @@ -128,14 +128,11 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
return Status::success;
}));
}
for (auto& fut : futs) {
fut.wait();
auto ret = fut.result().value();
if (ret != Status::success) {
return expected<DataSetPtr>::Err(ret, "failed to brute force search");
}
auto ret = WaitAllSuccess(futs);
if (ret != Status::success) {
return expected<DataSetPtr>::Err(ret, "failed to brute force search");
}
return GenResultDataSet(nq, cfg.k.value(), labels, distances);
return GenResultDataSet(nq, cfg.k.value(), labels.release(), distances.release());
}

template <typename DataType>
Expand Down Expand Up @@ -233,11 +230,7 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
return Status::success;
}));
}
for (auto& fut : futs) {
fut.wait();
auto ret = fut.result().value();
RETURN_IF_ERROR(ret);
}
RETURN_IF_ERROR(WaitAllSuccess(futs));
return Status::success;
}

Expand Down Expand Up @@ -348,12 +341,9 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
return Status::success;
}));
}
for (auto& fut : futs) {
fut.wait();
auto ret = fut.result().value();
if (ret != Status::success) {
return expected<DataSetPtr>::Err(ret, "failed to brute force search");
}
auto ret = WaitAllSuccess(futs);
if (ret != Status::success) {
return expected<DataSetPtr>::Err(ret, "failed to brute force search");
}

int64_t* ids = nullptr;
Expand Down
19 changes: 3 additions & 16 deletions src/common/thread/thread.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <utility>

#include "knowhere/comp/thread_pool.h"

namespace knowhere {

void
Expand All @@ -33,14 +34,7 @@ ExecOverSearchThreadPool(std::vector<std::function<void()>>& tasks) {
}));
}
std::this_thread::yield();
// check for exceptions. value() is {}, so either
// a call does nothing, or it throws an inner exception.
for (auto& f : futures) {
f.wait();
}
for (auto& f : futures) {
f.result().value();
}
WaitAllSuccess(futures);
}

void
Expand All @@ -55,14 +49,7 @@ ExecOverBuildThreadPool(std::vector<std::function<void()>>& tasks) {
}));
}
std::this_thread::yield();
// check for exceptions. value() is {}, so either
// a call does nothing, or it throws an inner exception.
for (auto& f : futures) {
f.wait();
}
for (auto& f : futures) {
f.result().value();
}
WaitAllSuccess(futures);
}

void
Expand Down
42 changes: 13 additions & 29 deletions src/index/diskann/diskann.cc
Original file line number Diff line number Diff line change
Expand Up @@ -479,8 +479,6 @@ DiskANNIndexNode<DataType>::Deserialize(const BinarySet& binset, const Config& c
std::vector<int64_t> warmup_result_ids_64(warmup_num, 0);
std::vector<DistType> warmup_result_dists(warmup_num, 0);

bool all_searches_are_good = true;

std::vector<folly::Future<folly::Unit>> futures;
futures.reserve(warmup_num);
for (_s64 i = 0; i < (int64_t)warmup_num; ++i) {
Expand All @@ -490,16 +488,14 @@ DiskANNIndexNode<DataType>::Deserialize(const BinarySet& binset, const Config& c
warmup_result_dists.data() + (index * 1), 4);
}));
}
for (auto& future : futures) {
if (TryDiskANNCall([&]() { future.wait(); }) != Status::success) {
all_searches_are_good = false;
}
}

bool failed = TryDiskANNCall([&]() { WaitAllSuccess(futures); }) != Status::success;

if (warmup != nullptr) {
diskann::aligned_free(warmup);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a memory leak in case of exception

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TryDiskANNCall catches all exceptions and returns non success Status, thus memory leak won't happen here.

}

if (!all_searches_are_good) {
if (failed) {
LOG_KNOWHERE_ERROR_ << "Failed to do search on warmup file for DiskANN.";
return Status::diskann_inner_error;
}
Expand Down Expand Up @@ -542,34 +538,28 @@ DiskANNIndexNode<DataType>::Search(const DataSet& dataset, const Config& cfg, co
search_conf.search_list_size.value());
}

auto p_id = new int64_t[k * nq];
auto p_dist = new DistType[k * nq];
auto p_id = std::make_unique<int64_t[]>(k * nq);
auto p_dist = std::make_unique<DistType[]>(k * nq);

bool all_searches_are_good = true;
std::vector<folly::Future<folly::Unit>> futures;
futures.reserve(nq);
for (int64_t row = 0; row < nq; ++row) {
futures.emplace_back(search_pool_->push([&, index = row]() {
futures.emplace_back(search_pool_->push([&, index = row, p_id_ptr = p_id.get(), p_dist_ptr = p_dist.get()]() {
diskann::QueryStats stats;
pq_flash_index_->cached_beam_search(xq + (index * dim), k, lsearch, p_id + (index * k),
p_dist + (index * k), beamwidth, false, &stats, feder_result, bitset,
filter_ratio, for_tuning);
pq_flash_index_->cached_beam_search(xq + (index * dim), k, lsearch, p_id_ptr + (index * k),
p_dist_ptr + (index * k), beamwidth, false, &stats, feder_result,
bitset, filter_ratio, for_tuning);
#ifdef NOT_COMPILE_FOR_SWIG
knowhere_diskann_search_hops.Observe(stats.n_hops);
#endif
}));
}
for (auto& future : futures) {
if (TryDiskANNCall([&]() { future.wait(); }) != Status::success) {
all_searches_are_good = false;
}
}

if (!all_searches_are_good) {
if (TryDiskANNCall([&]() { WaitAllSuccess(futures); }) != Status::success) {
return expected<DataSetPtr>::Err(Status::diskann_inner_error, "some search failed");
}

auto res = GenResultDataSet(nq, k, p_id, p_dist);
auto res = GenResultDataSet(nq, k, p_id.release(), p_dist.release());

// set visit_info json string into result dataset
if (feder_result != nullptr) {
Expand Down Expand Up @@ -621,7 +611,6 @@ DiskANNIndexNode<DataType>::RangeSearch(const DataSet& dataset, const Config& cf

std::vector<folly::Future<folly::Unit>> futures;
futures.reserve(nq);
bool all_searches_are_good = true;
for (int64_t row = 0; row < nq; ++row) {
futures.emplace_back(search_pool_->push([&, index = row]() {
std::vector<int64_t> indices;
Expand All @@ -639,12 +628,7 @@ DiskANNIndexNode<DataType>::RangeSearch(const DataSet& dataset, const Config& cf
}
}));
}
for (auto& future : futures) {
if (TryDiskANNCall([&]() { future.wait(); }) != Status::success) {
all_searches_are_good = false;
}
}
if (!all_searches_are_good) {
if (TryDiskANNCall([&]() { WaitAllSuccess(futures); }) != Status::success) {
return expected<DataSetPtr>::Err(Status::diskann_inner_error, "some search failed");
}

Expand Down
18 changes: 2 additions & 16 deletions src/index/flat/flat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,7 @@ class FlatIndexNode : public IndexNode {
}));
}
// wait for the completion
for (auto& fut : futs) {
fut.wait();
}
// check for exceptions. value() is {}, so either
// a call does nothing, or it throws an inner exception.
for (auto& fut : futs) {
fut.result().value();
}
WaitAllSuccess(futs);
} catch (const std::exception& e) {
std::unique_ptr<int64_t[]> auto_delete_ids(ids);
std::unique_ptr<float[]> auto_delete_dis(distances);
Expand Down Expand Up @@ -216,14 +209,7 @@ class FlatIndexNode : public IndexNode {
}));
}
// wait for the completion
for (auto& fut : futs) {
fut.wait();
}
// check for exceptions. value() is {}, so either
// a call does nothing, or it throws an inner exception.
for (auto& fut : futs) {
fut.result().value();
}
WaitAllSuccess(futs);
GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius, range_filter, distances, ids,
lims);
} catch (const std::exception& e) {
Expand Down
48 changes: 15 additions & 33 deletions src/index/hnsw/hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,7 @@ class HnswIndexNode : public IndexNode {
}
}));
}
for (auto& future : futures) {
future.wait();
}
// check for exceptions
for (auto& future : futures) {
future.result().value();
}
WaitAllSuccess(futures);
futures.clear();
}

Expand All @@ -146,13 +140,7 @@ class HnswIndexNode : public IndexNode {
futures.emplace_back(
build_pool->push([&, idx = i]() { index_->repairGraphConnectivity(unreached[idx]); }));
}
for (auto& future : futures) {
future.wait();
}
// check for exceptions
for (auto& future : futures) {
future.result().value();
}
WaitAllSuccess(futures);
}
build_time.RecordSection("graph repair");
LOG_KNOWHERE_INFO_ << "HNSW built with #points num:" << index_->max_elements_ << " #M:" << index_->M_
Expand Down Expand Up @@ -186,8 +174,8 @@ class HnswIndexNode : public IndexNode {
feder_result = std::make_unique<feder::hnsw::FederResult>();
}

auto p_id = new int64_t[k * nq];
auto p_dist = new DistType[k * nq];
auto p_id = std::make_unique<int64_t[]>(k * nq);
auto p_dist = std::make_unique<DistType[]>(k * nq);

hnswlib::SearchParam param{(size_t)hnsw_cfg.ef.value(), hnsw_cfg.for_tuning.value()};
bool transform =
Expand All @@ -196,12 +184,12 @@ class HnswIndexNode : public IndexNode {
std::vector<folly::Future<folly::Unit>> futs;
futs.reserve(nq);
for (int i = 0; i < nq; ++i) {
futs.emplace_back(search_pool_->push([&, idx = i]() {
futs.emplace_back(search_pool_->push([&, idx = i, p_id_ptr = p_id.get(), p_dist_ptr = p_dist.get()]() {
auto single_query = (const char*)xq + idx * index_->data_size_;
auto rst = index_->searchKnn(single_query, k, bitset, &param, feder_result);
size_t rst_size = rst.size();
auto p_single_dis = p_dist + idx * k;
auto p_single_id = p_id + idx * k;
auto p_single_dis = p_dist_ptr + idx * k;
auto p_single_id = p_id_ptr + idx * k;
for (size_t idx = 0; idx < rst_size; ++idx) {
const auto& [dist, id] = rst[idx];
p_single_dis[idx] = transform ? (-dist) : dist;
Expand All @@ -213,11 +201,9 @@ class HnswIndexNode : public IndexNode {
}
}));
}
for (auto& fut : futs) {
fut.wait();
}
WaitAllSuccess(futs);

auto res = GenResultDataSet(nq, k, p_id, p_dist);
auto res = GenResultDataSet(nq, k, p_id.release(), p_dist.release());

// set visit_info json string into result dataset
if (feder_result != nullptr) {
Expand Down Expand Up @@ -300,9 +286,7 @@ class HnswIndexNode : public IndexNode {
}));
}
// wait for initial search(in top layers and search for seed_ef in base layer) to finish
for (auto& fut : futs) {
fut.wait();
}
WaitAllSuccess(futs);

return vec;
}
Expand Down Expand Up @@ -335,10 +319,6 @@ class HnswIndexNode : public IndexNode {

hnswlib::SearchParam param{(size_t)hnsw_cfg.ef.value()};

int64_t* ids = nullptr;
DistType* dis = nullptr;
size_t* lims = nullptr;

std::vector<std::vector<int64_t>> result_id_array(nq);
std::vector<std::vector<DistType>> result_dist_array(nq);
std::vector<size_t> result_size(nq);
Expand All @@ -365,9 +345,11 @@ class HnswIndexNode : public IndexNode {
}
}));
}
for (auto& fut : futs) {
fut.wait();
}
WaitAllSuccess(futs);

int64_t* ids = nullptr;
DistType* dis = nullptr;
size_t* lims = nullptr;

// filter range search result
GetRangeSearchResult(result_dist_array, result_id_array, is_ip, nq, radius_for_filter, range_filter, dis, ids,
Expand Down
Loading