Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions faiss/IndexAdditiveQuantizerFastScan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,30 +32,31 @@ IndexAdditiveQuantizerFastScan::IndexAdditiveQuantizerFastScan(
}

void IndexAdditiveQuantizerFastScan::init(
AdditiveQuantizer* aq_2,
AdditiveQuantizer* aq_init,
MetricType metric,
int bbs) {
FAISS_THROW_IF_NOT(aq_2 != nullptr);
FAISS_THROW_IF_NOT(!aq_2->nbits.empty());
FAISS_THROW_IF_NOT(aq_2->nbits[0] == 4);
FAISS_THROW_IF_NOT(aq_init != nullptr);
FAISS_THROW_IF_NOT(!aq_init->nbits.empty());
FAISS_THROW_IF_NOT(aq_init->nbits[0] == 4);
if (metric == METRIC_INNER_PRODUCT) {
FAISS_THROW_IF_NOT_MSG(
aq_2->search_type == AdditiveQuantizer::ST_LUT_nonorm,
aq_init->search_type == AdditiveQuantizer::ST_LUT_nonorm,
"Search type must be ST_LUT_nonorm for IP metric");
} else {
FAISS_THROW_IF_NOT_MSG(
aq_2->search_type == AdditiveQuantizer::ST_norm_lsq2x4 ||
aq_2->search_type == AdditiveQuantizer::ST_norm_rq2x4,
aq_init->search_type == AdditiveQuantizer::ST_norm_lsq2x4 ||
aq_init->search_type ==
AdditiveQuantizer::ST_norm_rq2x4,
"Search type must be lsq2x4 or rq2x4 for L2 metric");
}

this->aq = aq_2;
this->aq = aq_init;
if (metric == METRIC_L2) {
M = aq_2->M + 2; // 2x4 bits AQ
M = aq_init->M + 2; // 2x4 bits AQ
} else {
M = aq_2->M;
M = aq_init->M;
}
init_fastscan(aq_2->d, M, 4, metric, bbs);
init_fastscan(aq_init->d, M, 4, metric, bbs);

max_train_points = 1024 * ksub * M;
}
Expand Down
34 changes: 17 additions & 17 deletions faiss/IndexPQ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,16 +159,16 @@ void IndexPQ::search(
FAISS_THROW_IF_NOT(is_trained);

const SearchParametersPQ* params = nullptr;
Search_type_t search_type_2 = this->search_type;
Search_type_t param_search_type = this->search_type;

if (iparams) {
params = dynamic_cast<const SearchParametersPQ*>(iparams);
FAISS_THROW_IF_NOT_MSG(params, "invalid search params");
FAISS_THROW_IF_NOT_MSG(!params->sel, "selector not supported");
search_type_2 = params->search_type;
param_search_type = params->search_type;
}

if (search_type_2 == ST_PQ) { // Simple PQ search
if (param_search_type == ST_PQ) { // Simple PQ search

if (metric_type == METRIC_L2) {
float_maxheap_array_t res = {
Expand All @@ -183,19 +183,19 @@ void IndexPQ::search(
indexPQ_stats.ncode += n * ntotal;

} else if (
search_type_2 == ST_polysemous ||
search_type_2 == ST_polysemous_generalize) {
param_search_type == ST_polysemous ||
param_search_type == ST_polysemous_generalize) {
FAISS_THROW_IF_NOT(metric_type == METRIC_L2);
int polysemous_ht_2 =
int param_polysemous_ht =
params ? params->polysemous_ht : this->polysemous_ht;
search_core_polysemous(
n,
x,
k,
distances,
labels,
polysemous_ht_2,
search_type_2 == ST_polysemous_generalize);
param_polysemous_ht,
param_search_type == ST_polysemous_generalize);

} else { // code-to-code distances

Expand All @@ -215,7 +215,7 @@ void IndexPQ::search(
}
}

if (search_type_2 == ST_SDC) {
if (param_search_type == ST_SDC) {
float_maxheap_array_t res = {
size_t(n), size_t(k), labels, distances};

Expand All @@ -227,7 +227,7 @@ void IndexPQ::search(
int_maxheap_array_t res = {
size_t(n), size_t(k), labels, idistances.get()};

if (search_type_2 == ST_HE) {
if (param_search_type == ST_HE) {
hammings_knn_hc(
&res,
q_codes.get(),
Expand All @@ -236,7 +236,7 @@ void IndexPQ::search(
pq.code_size,
true);

} else if (search_type_2 == ST_generalized_HE) {
} else if (param_search_type == ST_generalized_HE) {
generalized_hammings_knn_hc(
&res,
q_codes.get(),
Expand Down Expand Up @@ -322,13 +322,13 @@ void IndexPQ::search_core_polysemous(
idx_t k,
float* distances,
idx_t* labels,
int polysemous_ht_2,
int param_polysemous_ht,
bool generalized_hamming) const {
FAISS_THROW_IF_NOT(k > 0);
FAISS_THROW_IF_NOT(pq.nbits == 8);

if (polysemous_ht_2 == 0) {
polysemous_ht_2 = pq.nbits * pq.M + 1;
if (param_polysemous_ht == 0) {
param_polysemous_ht = pq.nbits * pq.M + 1;
}

// PQ distance tables
Expand Down Expand Up @@ -374,7 +374,7 @@ void IndexPQ::search_core_polysemous(
k,
heap_dis,
heap_ids,
polysemous_ht_2);
param_polysemous_ht);

} else { // generalized hamming
switch (pq.code_size) {
Expand All @@ -387,7 +387,7 @@ void IndexPQ::search_core_polysemous(
k, \
heap_dis, \
heap_ids, \
polysemous_ht_2); \
param_polysemous_ht); \
break;
DISPATCH(8)
DISPATCH(16)
Expand All @@ -401,7 +401,7 @@ void IndexPQ::search_core_polysemous(
k,
heap_dis,
heap_ids,
polysemous_ht_2);
param_polysemous_ht);
} else {
bad_code_size++;
}
Expand Down
10 changes: 5 additions & 5 deletions faiss/utils/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -589,9 +589,9 @@ int64_t count_gt(int64_t n, const T* row, T threshold) {
} // namespace

template <typename T>
void CombinerRangeKNN<T>::compute_sizes(int64_t* L_res_2) {
this->L_res = L_res_2;
L_res_2[0] = 0;
void CombinerRangeKNN<T>::compute_sizes(int64_t* L_res_init) {
this->L_res = L_res_init;
L_res_init[0] = 0;
int64_t j = 0;
for (int64_t i = 0; i < nq; i++) {
int64_t n_in;
Expand All @@ -602,11 +602,11 @@ void CombinerRangeKNN<T>::compute_sizes(int64_t* L_res_2) {
n_in = lim_remain[j + 1] - lim_remain[j];
j++;
}
L_res_2[i + 1] = n_in; // L_res_2[i] + n_in;
L_res_init[i + 1] = n_in; // L_res_init[i] + n_in;
}
// cumsum
for (int64_t i = 0; i < nq; i++) {
L_res_2[i + 1] += L_res_2[i];
L_res_init[i + 1] += L_res_init[i];
}
}

Expand Down