Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
fddbd3e
Add support for search params to IndexBinaryFlat
Dec 4, 2024
bf90fc1
update tests
Dec 4, 2024
c5c9cab
revert default param changes
Dec 4, 2024
d70e64b
add missing sel to hamming.h, add no heap test case, simplify valid_c…
Dec 12, 2024
cd8b5d3
fix import, no heap test, linting
Dec 16, 2024
b7e73e1
Merge branch 'main' into gustavz/search_params_support_for_index_bina…
Dec 16, 2024
e96237e
lint
Dec 19, 2024
b3db31b
Merge branch 'main' into gustavz/search_params_support_for_index_bina…
Dec 19, 2024
d92184e
remove default from definition
Dec 24, 2024
d40594f
Merge branch 'main' into gustavz/search_params_support_for_index_bina…
Dec 24, 2024
16027c6
add #include <faiss/impl/IDSelector.h> to hamming.cpp
Jan 14, 2025
54d8256
Merge branch 'main' into gustavz/search_params_support_for_index_bina…
Jan 14, 2025
983ae3d
add faiss namespace to IDSelector in hamming.cpp
Jan 21, 2025
44d7977
Merge branch 'main' into gustavz/search_params_support_for_index_bina…
Jan 21, 2025
fc1f675
Merge branch 'main' into gustavz/search_params_support_for_index_bina…
Jan 22, 2025
dfccc69
add faiss:: to IDSelector in hamming.h
Feb 13, 2025
a5ec86c
Merge branch 'main' into gustavz/search_params_support_for_index_bina…
Feb 13, 2025
56f13ea
add params to IndexBinary replacement_search and replacement_range_se…
Feb 14, 2025
95af571
Merge branch 'main' into gustavz/search_params_support_for_index_bina…
Feb 14, 2025
75df1b5
Merge branch 'main' into gustavz/search_params_support_for_index_bina…
mnorris11 Feb 20, 2025
5b424cb
update tests
Feb 28, 2025
3b56fc2
Merge branch 'main' into gustavz/search_params_support_for_index_bina…
Feb 28, 2025
8be85c8
small test optimization
Feb 28, 2025
9d16d52
lint
Mar 11, 2025
1c52c9c
Merge branch 'main' into gustavz/search_params_support_for_index_bina…
Mar 11, 2025
9f851fc
Merge branch 'main' into gustavz/search_params_support_for_index_bina…
gtwang01 Mar 11, 2025
565a9c9
Merge branch 'main' into gustavz/search_params_support_for_index_bina…
gtwang01 Mar 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions faiss/IndexBinaryFlat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ void IndexBinaryFlat::search(
int32_t* distances,
idx_t* labels,
const SearchParameters* params) const {
FAISS_THROW_IF_NOT_MSG(
!params, "search params not supported for this index");
// Extract IDSelector from params if present
const IDSelector* sel = params ? params->sel : nullptr;
FAISS_THROW_IF_NOT(k > 0);

const idx_t block_size = query_batch_size;
Expand All @@ -60,7 +60,8 @@ void IndexBinaryFlat::search(
ntotal,
code_size,
/* ordered = */ true,
approx_topk_mode);
approx_topk_mode,
sel);
} else {
hammings_knn_mc(
x + s * code_size,
Expand All @@ -70,7 +71,8 @@ void IndexBinaryFlat::search(
k,
code_size,
distances + s * k,
labels + s * k);
labels + s * k,
sel);
}
}
}
Expand Down Expand Up @@ -107,9 +109,9 @@ void IndexBinaryFlat::range_search(
int radius,
RangeSearchResult* result,
const SearchParameters* params) const {
FAISS_THROW_IF_NOT_MSG(
!params, "search params not supported for this index");
hamming_range_search(x, xb.data(), n, ntotal, radius, code_size, result);
const IDSelector* sel = params ? params->sel : nullptr;
hamming_range_search(
x, xb.data(), n, ntotal, radius, code_size, result, sel);
}

} // namespace faiss
9 changes: 5 additions & 4 deletions faiss/python/class_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,7 +869,7 @@ def replacement_reconstruct_n(self, n0=0, ni=-1, x=None):
self.reconstruct_n_c(n0, ni, swig_ptr(x))
return x

def replacement_search(self, x, k):
def replacement_search(self, x, k, *, params=None):
x = _check_dtype_uint8(x)
n, d = x.shape
assert d == self.code_size
Expand All @@ -878,7 +878,8 @@ def replacement_search(self, x, k):
labels = np.empty((n, k), dtype=np.int64)
self.search_c(n, swig_ptr(x),
k, swig_ptr(distances),
swig_ptr(labels))
swig_ptr(labels),
params=params)
return distances, labels

def replacement_search_preassigned(self, x, k, Iq, Dq):
Expand Down Expand Up @@ -906,12 +907,12 @@ def replacement_search_preassigned(self, x, k, Iq, Dq):
)
return D, I

def replacement_range_search(self, x, thresh):
def replacement_range_search(self, x, thresh, *, params=None):
n, d = x.shape
x = _check_dtype_uint8(x)
assert d == self.code_size
res = RangeSearchResult(n)
self.range_search_c(n, swig_ptr(x), thresh, res)
self.range_search_c(n, swig_ptr(x), thresh, res, params=params)
# get pointers and copy them
lims = rev_swig_ptr(res.lims, n + 1).copy()
nd = int(lims[-1])
Expand Down
42 changes: 30 additions & 12 deletions faiss/utils/approx_topk_hamming/approx_topk_hamming.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,11 @@ struct HeapWithBucketsForHamming32<
// output distances
int* const __restrict bh_val,
// output indices, each being within [0, n) range
int64_t* const __restrict bh_ids) {
int64_t* const __restrict bh_ids,
// optional id selector for filtering
const IDSelector* sel = nullptr) {
// forward a call to bs_addn with 1 beam
bs_addn(1, n, hc, binaryVectors, k, bh_val, bh_ids);
bs_addn(1, n, hc, binaryVectors, k, bh_val, bh_ids, sel);
}

static void bs_addn(
Expand All @@ -66,7 +68,9 @@ struct HeapWithBucketsForHamming32<
int* const __restrict bh_val,
// output indices, each being within [0, n_per_beam * beam_size)
// range
int64_t* const __restrict bh_ids) {
int64_t* const __restrict bh_ids,
// optional id selector for filtering
const IDSelector* sel = nullptr) {
//
using C = CMax<int, int64_t>;

Expand Down Expand Up @@ -95,11 +99,22 @@ struct HeapWithBucketsForHamming32<
for (uint32_t ip = 0; ip < nb; ip += NBUCKETS) {
for (uint32_t j = 0; j < NBUCKETS_8; j++) {
uint32_t hamming_distances[8];
uint8_t valid_counter = 0;
for (size_t j8 = 0; j8 < 8; j8++) {
hamming_distances[j8] = hc.hamming(
binary_vectors +
(j8 + j * 8 + ip + n_per_beam * beam_index) *
code_size);
const uint32_t idx =
j8 + j * 8 + ip + n_per_beam * beam_index;
if (!sel || sel->is_member(idx)) {
hamming_distances[j8] = hc.hamming(
binary_vectors + idx * code_size);
valid_counter++;
} else {
hamming_distances[j8] =
std::numeric_limits<int32_t>::max();
}
}

// valid_counter is incremented once for every candidate in this group of 8
// that passed the selector (or for all 8 when sel is null). A group is
// therefore fully filtered out exactly when the counter stayed at 0.
// Comparing against 8 would skip every fully valid group instead, which
// drops all bucketed candidates whenever no selector is supplied.
if (valid_counter == 0) {
    continue; // Skip if all vectors are filtered out
}

// loop. Compiler should get rid of unneeded ops
Expand Down Expand Up @@ -157,7 +172,8 @@ struct HeapWithBucketsForHamming32<
const auto value = min_distances_scalar[j8];
const auto index = min_indices_scalar[j8];

if (C::cmp2(bh_val[0], value, bh_ids[0], index)) {
if (value < std::numeric_limits<int32_t>::max() &&
C::cmp2(bh_val[0], value, bh_ids[0], index)) {
heap_replace_top<C>(
k, bh_val, bh_ids, value, index);
}
Expand All @@ -168,11 +184,13 @@ struct HeapWithBucketsForHamming32<
// process leftovers
for (uint32_t ip = nb; ip < n_per_beam; ip++) {
const auto index = ip + n_per_beam * beam_index;
const auto value =
hc.hamming(binary_vectors + (index)*code_size);
if (!sel || sel->is_member(index)) {
const auto value =
hc.hamming(binary_vectors + (index)*code_size);

if (C::cmp(bh_val[0], value)) {
heap_replace_top<C>(k, bh_val, bh_ids, value, index);
if (C::cmp(bh_val[0], value)) {
heap_replace_top<C>(k, bh_val, bh_ids, value, index);
}
}
}
}
Expand Down
52 changes: 38 additions & 14 deletions faiss/utils/hamming.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/IDSelector.h>
#include <faiss/utils/Heap.h>
#include <faiss/utils/approx_topk_hamming/approx_topk_hamming.h>
#include <faiss/utils/utils.h>
Expand Down Expand Up @@ -171,7 +172,8 @@ void hammings_knn_hc(
size_t n2,
bool order = true,
bool init_heap = true,
ApproxTopK_mode_t approx_topk_mode = ApproxTopK_mode_t::EXACT_TOPK) {
ApproxTopK_mode_t approx_topk_mode = ApproxTopK_mode_t::EXACT_TOPK,
const faiss::IDSelector* sel = nullptr) {
size_t k = ha->k;
if (init_heap)
ha->heapify();
Expand Down Expand Up @@ -204,7 +206,7 @@ void hammings_knn_hc(
NB, \
BD, \
HammingComputer>:: \
addn(j1 - j0, hc, bs2_, k, bh_val_, bh_ids_); \
addn(j1 - j0, hc, bs2_, k, bh_val_, bh_ids_, sel); \
break;

switch (approx_topk_mode) {
Expand All @@ -214,6 +216,9 @@ void hammings_knn_hc(
HANDLE_APPROX(32, 2)
default: {
for (size_t j = j0; j < j1; j++, bs2_ += bytes_per_code) {
if (sel && !sel->is_member(j)) {
continue;
}
dis = hc.hamming(bs2_);
if (dis < bh_val_[0]) {
faiss::maxheap_replace_top<hamdis_t>(
Expand All @@ -238,7 +243,8 @@ void hammings_knn_mc(
size_t nb,
size_t k,
int32_t* __restrict distances,
int64_t* __restrict labels) {
int64_t* __restrict labels,
const faiss::IDSelector* sel) {
const int nBuckets = bytes_per_code * 8 + 1;
std::vector<int> all_counters(na * nBuckets, 0);
std::unique_ptr<int64_t[]> all_ids_per_dis(new int64_t[na * nBuckets * k]);
Expand All @@ -259,7 +265,9 @@ void hammings_knn_mc(
#pragma omp parallel for
for (int64_t i = 0; i < na; ++i) {
for (size_t j = j0; j < j1; ++j) {
cs[i].update_counter(b + j * bytes_per_code, j);
if (!sel || sel->is_member(j)) {
cs[i].update_counter(b + j * bytes_per_code, j);
}
}
}
}
Expand Down Expand Up @@ -291,7 +299,8 @@ void hamming_range_search(
size_t nb,
int radius,
size_t code_size,
RangeSearchResult* res) {
RangeSearchResult* res,
const faiss::IDSelector* sel) {
#pragma omp parallel
{
RangeSearchPartialResult pres(res);
Expand All @@ -303,9 +312,11 @@ void hamming_range_search(
RangeQueryResult& qres = pres.new_result(i);

for (size_t j = 0; j < nb; j++) {
int dis = hc.hamming(yi);
if (dis < radius) {
qres.add(dis, j);
if (!sel || sel->is_member(j)) {
int dis = hc.hamming(yi);
if (dis < radius) {
qres.add(dis, j);
}
}
yi += code_size;
}
Expand Down Expand Up @@ -489,10 +500,21 @@ void hammings_knn_hc(
size_t nb,
size_t ncodes,
int order,
ApproxTopK_mode_t approx_topk_mode) {
ApproxTopK_mode_t approx_topk_mode,
const faiss::IDSelector* sel) {
Run_hammings_knn_hc r;
dispatch_HammingComputer(
ncodes, r, ncodes, ha, a, b, nb, order, true, approx_topk_mode);
ncodes,
r,
ncodes,
ha,
a,
b,
nb,
order,
true,
approx_topk_mode,
sel);
}

void hammings_knn_mc(
Expand All @@ -503,10 +525,11 @@ void hammings_knn_mc(
size_t k,
size_t ncodes,
int32_t* __restrict distances,
int64_t* __restrict labels) {
int64_t* __restrict labels,
const faiss::IDSelector* sel) {
Run_hammings_knn_mc r;
dispatch_HammingComputer(
ncodes, r, ncodes, a, b, na, nb, k, distances, labels);
ncodes, r, ncodes, a, b, na, nb, k, distances, labels, sel);
}

void hamming_range_search(
Expand All @@ -516,10 +539,11 @@ void hamming_range_search(
size_t nb,
int radius,
size_t code_size,
RangeSearchResult* result) {
RangeSearchResult* result,
const faiss::IDSelector* sel) {
Run_hamming_range_search r;
dispatch_HammingComputer(
code_size, r, a, b, na, nb, radius, code_size, result);
code_size, r, a, b, na, nb, radius, code_size, result, sel);
}

/* Count number of matches given a max threshold */
Expand Down
10 changes: 7 additions & 3 deletions faiss/utils/hamming.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#include <stdint.h>

#include <faiss/impl/IDSelector.h>
#include <faiss/impl/platform_macros.h>
#include <faiss/utils/Heap.h>

Expand Down Expand Up @@ -135,7 +136,8 @@ void hammings_knn_hc(
size_t nb,
size_t ncodes,
int ordered,
ApproxTopK_mode_t approx_topk_mode = ApproxTopK_mode_t::EXACT_TOPK);
ApproxTopK_mode_t approx_topk_mode = ApproxTopK_mode_t::EXACT_TOPK,
const faiss::IDSelector* sel = nullptr);

/* Legacy alias to hammings_knn_hc. */
void hammings_knn(
Expand Down Expand Up @@ -166,7 +168,8 @@ void hammings_knn_mc(
size_t k,
size_t ncodes,
int32_t* distances,
int64_t* labels);
int64_t* labels,
const faiss::IDSelector* sel = nullptr);

/** same as hammings_knn except we are doing a range search with radius */
void hamming_range_search(
Expand All @@ -176,7 +179,8 @@ void hamming_range_search(
size_t nb,
int radius,
size_t ncodes,
RangeSearchResult* result);
RangeSearchResult* result,
const faiss::IDSelector* sel = nullptr);

/* Counting the number of matches or of cross-matches (without returning them)
For use with function that assume pre-allocated memory */
Expand Down
Loading
Loading