Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 19 additions & 19 deletions faiss/utils/NeuralNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,19 +212,19 @@ nn::Int32Tensor2D QINCoStep::encode(
// repeated codebook
Tensor2D zqs_r(n * K, d); // size n, K, d
Tensor2D cc(n * K, d * 2); // size n, K, d * 2
size_t d = this->d;
size_t d_2 = this->d;

auto copy_row = [d](Tensor2D& t, size_t i, size_t j, const float* data) {
auto copy_row = [d_2](Tensor2D& t, size_t i, size_t j, const float* data) {
assert(i <= t.shape[0] && j <= t.shape[1]);
memcpy(t.data() + i * t.shape[1] + j, data, sizeof(float) * d);
memcpy(t.data() + i * t.shape[1] + j, data, sizeof(float) * d_2);
};

// manual broadcasting
for (size_t i = 0; i < n; i++) {
for (size_t j = 0; j < K; j++) {
copy_row(zqs_r, i * K + j, 0, codebook.data() + j * d);
copy_row(cc, i * K + j, 0, codebook.data() + j * d);
copy_row(cc, i * K + j, d, xhat.data() + i * d);
copy_row(zqs_r, i * K + j, 0, codebook.data() + j * d_2);
copy_row(cc, i * K + j, 0, codebook.data() + j * d_2);
copy_row(cc, i * K + j, d_2, xhat.data() + i * d_2);
}
}

Expand All @@ -237,13 +237,13 @@ nn::Int32Tensor2D QINCoStep::encode(

// add the xhat
for (size_t i = 0; i < n; i++) {
float* zqs_r_row = zqs_r.data() + i * K * d;
const float* xhat_row = xhat.data() + i * d;
float* zqs_r_row = zqs_r.data() + i * K * d_2;
const float* xhat_row = xhat.data() + i * d_2;
for (size_t l = 0; l < K; l++) {
for (size_t j = 0; j < d; j++) {
for (size_t j = 0; j < d_2; j++) {
zqs_r_row[j] += xhat_row[j];
}
zqs_r_row += d;
zqs_r_row += d_2;
}
}

Expand All @@ -252,31 +252,31 @@ nn::Int32Tensor2D QINCoStep::encode(
float* res = nullptr;
if (residuals) {
FAISS_THROW_IF_NOT(
residuals->shape[0] == n && residuals->shape[1] == d);
residuals->shape[0] == n && residuals->shape[1] == d_2);
res = residuals->data();
}

for (size_t i = 0; i < n; i++) {
const float* q = x.data() + i * d;
const float* db = zqs_r.data() + i * K * d;
const float* q = x.data() + i * d_2;
const float* db = zqs_r.data() + i * K * d_2;
float dis_min = HUGE_VALF;
int64_t idx = -1;
for (size_t j = 0; j < K; j++) {
float dis = fvec_L2sqr(q, db, d);
float dis = fvec_L2sqr(q, db, d_2);
if (dis < dis_min) {
dis_min = dis;
idx = j;
}
db += d;
db += d_2;
}
codes.v[i] = idx;
if (res) {
const float* xhat_row = xhat.data() + i * d;
const float* xhat_next_row = zqs_r.data() + (i * K + idx) * d;
for (size_t j = 0; j < d; j++) {
const float* xhat_row = xhat.data() + i * d_2;
const float* xhat_next_row = zqs_r.data() + (i * K + idx) * d_2;
for (size_t j = 0; j < d_2; j++) {
res[j] = xhat_next_row[j] - xhat_row[j];
}
res += d;
res += d_2;
}
}
return codes;
Expand Down