Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions faiss/IndexAdditiveQuantizerFastScan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,30 +32,31 @@ IndexAdditiveQuantizerFastScan::IndexAdditiveQuantizerFastScan(
}

void IndexAdditiveQuantizerFastScan::init(
AdditiveQuantizer* aq_2,
AdditiveQuantizer* aq_init,
MetricType metric,
int bbs) {
FAISS_THROW_IF_NOT(aq_2 != nullptr);
FAISS_THROW_IF_NOT(!aq_2->nbits.empty());
FAISS_THROW_IF_NOT(aq_2->nbits[0] == 4);
FAISS_THROW_IF_NOT(aq_init != nullptr);
FAISS_THROW_IF_NOT(!aq_init->nbits.empty());
FAISS_THROW_IF_NOT(aq_init->nbits[0] == 4);
if (metric == METRIC_INNER_PRODUCT) {
FAISS_THROW_IF_NOT_MSG(
aq_2->search_type == AdditiveQuantizer::ST_LUT_nonorm,
aq_init->search_type == AdditiveQuantizer::ST_LUT_nonorm,
"Search type must be ST_LUT_nonorm for IP metric");
} else {
FAISS_THROW_IF_NOT_MSG(
aq_2->search_type == AdditiveQuantizer::ST_norm_lsq2x4 ||
aq_2->search_type == AdditiveQuantizer::ST_norm_rq2x4,
aq_init->search_type == AdditiveQuantizer::ST_norm_lsq2x4 ||
aq_init->search_type ==
AdditiveQuantizer::ST_norm_rq2x4,
"Search type must be lsq2x4 or rq2x4 for L2 metric");
}

this->aq = aq_2;
this->aq = aq_init;
if (metric == METRIC_L2) {
M = aq_2->M + 2; // 2x4 bits AQ
M = aq_init->M + 2; // 2x4 bits AQ
} else {
M = aq_2->M;
M = aq_init->M;
}
init_fastscan(aq_2->d, M, 4, metric, bbs);
init_fastscan(aq_init->d, M, 4, metric, bbs);

max_train_points = 1024 * ksub * M;
}
Expand Down
46 changes: 25 additions & 21 deletions faiss/utils/NeuralNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,19 +212,23 @@ nn::Int32Tensor2D QINCoStep::encode(
// repeated codebook
Tensor2D zqs_r(n * K, d); // size n, K, d
Tensor2D cc(n * K, d * 2); // size n, K, d * 2
size_t d_2 = this->d;

auto copy_row = [d_2](Tensor2D& t, size_t i, size_t j, const float* data) {
assert(i <= t.shape[0] && j <= t.shape[1]);
memcpy(t.data() + i * t.shape[1] + j, data, sizeof(float) * d_2);
};
size_t local_d = this->d;

auto copy_row =
[local_d](Tensor2D& t, size_t i, size_t j, const float* data) {
assert(i <= t.shape[0] && j <= t.shape[1]);
memcpy(t.data() + i * t.shape[1] + j,
data,
sizeof(float) * local_d);
};

// manual broadcasting
for (size_t i = 0; i < n; i++) {
for (size_t j = 0; j < K; j++) {
copy_row(zqs_r, i * K + j, 0, codebook.data() + j * d_2);
copy_row(cc, i * K + j, 0, codebook.data() + j * d_2);
copy_row(cc, i * K + j, d_2, xhat.data() + i * d_2);
copy_row(zqs_r, i * K + j, 0, codebook.data() + j * d);
copy_row(cc, i * K + j, 0, codebook.data() + j * d);
copy_row(cc, i * K + j, d, xhat.data() + i * d);
}
}

Expand All @@ -237,13 +241,13 @@ nn::Int32Tensor2D QINCoStep::encode(

// add the xhat
for (size_t i = 0; i < n; i++) {
float* zqs_r_row = zqs_r.data() + i * K * d_2;
const float* xhat_row = xhat.data() + i * d_2;
float* zqs_r_row = zqs_r.data() + i * K * d;
const float* xhat_row = xhat.data() + i * d;
for (size_t l = 0; l < K; l++) {
for (size_t j = 0; j < d_2; j++) {
for (size_t j = 0; j < d; j++) {
zqs_r_row[j] += xhat_row[j];
}
zqs_r_row += d_2;
zqs_r_row += d;
}
}

Expand All @@ -252,31 +256,31 @@ nn::Int32Tensor2D QINCoStep::encode(
float* res = nullptr;
if (residuals) {
FAISS_THROW_IF_NOT(
residuals->shape[0] == n && residuals->shape[1] == d_2);
residuals->shape[0] == n && residuals->shape[1] == d);
res = residuals->data();
}

for (size_t i = 0; i < n; i++) {
const float* q = x.data() + i * d_2;
const float* db = zqs_r.data() + i * K * d_2;
const float* q = x.data() + i * d;
const float* db = zqs_r.data() + i * K * d;
float dis_min = HUGE_VALF;
int64_t idx = -1;
for (size_t j = 0; j < K; j++) {
float dis = fvec_L2sqr(q, db, d_2);
float dis = fvec_L2sqr(q, db, d);
if (dis < dis_min) {
dis_min = dis;
idx = j;
}
db += d_2;
db += d;
}
codes.v[i] = idx;
if (res) {
const float* xhat_row = xhat.data() + i * d_2;
const float* xhat_next_row = zqs_r.data() + (i * K + idx) * d_2;
for (size_t j = 0; j < d_2; j++) {
const float* xhat_row = xhat.data() + i * d;
const float* xhat_next_row = zqs_r.data() + (i * K + idx) * d;
for (size_t j = 0; j < d; j++) {
res[j] = xhat_next_row[j] - xhat_row[j];
}
res += d_2;
res += d;
}
}
return codes;
Expand Down
10 changes: 5 additions & 5 deletions faiss/utils/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -589,9 +589,9 @@ int64_t count_gt(int64_t n, const T* row, T threshold) {
} // namespace

template <typename T>
void CombinerRangeKNN<T>::compute_sizes(int64_t* L_res_2) {
this->L_res = L_res_2;
L_res_2[0] = 0;
void CombinerRangeKNN<T>::compute_sizes(int64_t* L_res_init) {
this->L_res = L_res_init;
L_res_init[0] = 0;
int64_t j = 0;
for (int64_t i = 0; i < nq; i++) {
int64_t n_in;
Expand All @@ -602,11 +602,11 @@ void CombinerRangeKNN<T>::compute_sizes(int64_t* L_res_2) {
n_in = lim_remain[j + 1] - lim_remain[j];
j++;
}
L_res_2[i + 1] = n_in; // L_res_2[i] + n_in;
L_res_init[i + 1] = n_in; // L_res_init[i] + n_in;
}
// cumsum
for (int64_t i = 0; i < nq; i++) {
L_res_2[i + 1] += L_res_2[i];
L_res_init[i + 1] += L_res_init[i];
}
}

Expand Down