Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions demos/demo_simd_levels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import time
import faiss
import numpy as np
import os
from collections import defaultdict
from faiss.contrib.datasets import SyntheticDataset


print("compile options", faiss.get_compile_options())
print("SIMD level: ", faiss.SIMDConfig.get_level_name())


ds = SyntheticDataset(32, 8000, 10000, 8000)


index = faiss.index_factory(ds.d, "PQ16x4fs")
# index = faiss.index_factory(ds.d, "IVF64,PQ16x4fs")
# index = faiss.index_factory(ds.d, "SQ8")

index.train(ds.get_train())
index.add(ds.get_database())


if False:
faiss.omp_set_num_threads(1)
print("PID=", os.getpid())
input("press enter to continue")
# for simd_level in faiss.NONE, faiss.AVX2, faiss.AVX512F:
for simd_level in faiss.AVX2, faiss.AVX512F:

faiss.SIMDConfig.set_level(simd_level)
print("simd_level=", faiss.SIMDConfig.get_level_name())
for run in range(1000):
D, I = index.search(ds.get_queries(), 10)

times = defaultdict(list)

for run in range(10):
for simd_level in faiss.SIMDLevel_NONE, faiss.SIMDLevel_AVX2, faiss.SIMDLevel_AVX512F:
faiss.SIMDConfig.set_level(simd_level)

t0 = time.time()
D, I = index.search(ds.get_queries(), 10)
t1 = time.time()

times[faiss.SIMDConfig.get_level_name()].append(t1 - t0)

for simd_level in times:
print(f"simd_level={simd_level} search time: {np.mean(times[simd_level])*1000:.3f} ms")
10 changes: 5 additions & 5 deletions faiss/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,11 @@ set(FAISS_HEADERS
impl/pq4_fast_scan.h
impl/residual_quantizer_encode_steps.h
impl/simd_result_handlers.h
impl/code_distance/code_distance.h
impl/code_distance/code_distance-generic.h
impl/code_distance/code_distance-avx2.h
impl/code_distance/code_distance-avx512.h
impl/code_distance/code_distance-sve.h
impl/pq_code_distance/code_distance.h
impl/pq_code_distance/code_distance-generic.h
impl/pq_code_distance/code_distance-avx2.h
impl/pq_code_distance/code_distance-avx512.h
impl/pq_code_distance/code_distance-sve.h
invlists/BlockInvertedLists.h
invlists/DirectMap.h
invlists/InvertedLists.h
Expand Down
11 changes: 3 additions & 8 deletions faiss/IndexAdditiveQuantizerFastScan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,10 @@

#include <faiss/IndexAdditiveQuantizerFastScan.h>

#include <cassert>
#include <memory>

#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/LocalSearchQuantizer.h>
#include <faiss/impl/LookupTableScaler.h>
#include <faiss/impl/ResidualQuantizer.h>
#include <faiss/impl/pq4_fast_scan.h>
#include <faiss/impl/pq_4bit/pq4_fast_scan.h>
#include <faiss/utils/quantize_lut.h>
#include <faiss/utils/utils.h>

Expand Down Expand Up @@ -199,11 +195,10 @@ void IndexAdditiveQuantizerFastScan::search(
return;
}

NormTableScaler scaler(norm_scale);
if (metric_type == METRIC_L2) {
search_dispatch_implem<true>(n, x, k, distances, labels, &scaler);
search_dispatch_implem<true>(n, x, k, distances, labels, norm_scale);
} else {
search_dispatch_implem<false>(n, x, k, distances, labels, &scaler);
search_dispatch_implem<false>(n, x, k, distances, labels, norm_scale);
}
}

Expand Down
138 changes: 65 additions & 73 deletions faiss/IndexFastScan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,16 @@

#include <faiss/IndexFastScan.h>

#include <cassert>
#include <climits>
#include <memory>

#include <omp.h>

#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/IDSelector.h>
#include <faiss/impl/LookupTableScaler.h>
#include <faiss/impl/ResultHandler.h>
#include <faiss/impl/pq_4bit/LookupTableScaler.h>
#include <faiss/utils/hamming.h>

#include <faiss/impl/pq4_fast_scan.h>
#include <faiss/impl/simd_result_handlers.h>
#include <faiss/impl/pq_4bit/pq4_fast_scan.h>
#include <faiss/impl/pq_4bit/simd_result_handlers.h>
#include <faiss/utils/quantize_lut.h>

namespace faiss {
Expand Down Expand Up @@ -163,7 +159,7 @@ void estimators_from_tables_generic(
size_t k,
typename C::T* heap_dis,
int64_t* heap_ids,
const NormTableScaler* scaler) {
const Scaler2x4bit* scaler) {
using accu_t = typename C::T;

for (size_t j = 0; j < ncodes; ++j) {
Expand Down Expand Up @@ -193,28 +189,6 @@ void estimators_from_tables_generic(
}
}

template <class C>
ResultHandlerCompare<C, false>* make_knn_handler(
int impl,
idx_t n,
idx_t k,
size_t ntotal,
float* distances,
idx_t* labels,
const IDSelector* sel = nullptr) {
using HeapHC = HeapHandler<C, false>;
using ReservoirHC = ReservoirHandler<C, false>;
using SingleResultHC = SingleResultHandler<C, false>;

if (k == 1) {
return new SingleResultHC(n, ntotal, distances, labels, sel);
} else if (impl % 2 == 0) {
return new HeapHC(n, ntotal, k, distances, labels, sel);
} else /* if (impl % 2 == 1) */ {
return new ReservoirHC(n, ntotal, k, 2 * k, distances, labels, sel);
}
}

} // anonymous namespace

using namespace quantize_lut;
Expand Down Expand Up @@ -264,9 +238,9 @@ void IndexFastScan::search(
FAISS_THROW_IF_NOT(k > 0);

if (metric_type == METRIC_L2) {
search_dispatch_implem<true>(n, x, k, distances, labels, nullptr);
search_dispatch_implem<true>(n, x, k, distances, labels, -1);
} else {
search_dispatch_implem<false>(n, x, k, distances, labels, nullptr);
search_dispatch_implem<false>(n, x, k, distances, labels, -1);
}
}

Expand All @@ -277,7 +251,7 @@ void IndexFastScan::search_dispatch_implem(
idx_t k,
float* distances,
idx_t* labels,
const NormTableScaler* scaler) const {
int norm_scale) const {
using Cfloat = typename std::conditional<
is_max,
CMax<float, int64_t>,
Expand Down Expand Up @@ -308,15 +282,17 @@ void IndexFastScan::search_dispatch_implem(
FAISS_THROW_MSG("not implemented");
} else if (implem == 2 || implem == 3 || implem == 4) {
FAISS_THROW_IF_NOT(orig_codes != nullptr);
search_implem_234<Cfloat>(n, x, k, distances, labels, scaler);
search_implem_234<Cfloat>(n, x, k, distances, labels, norm_scale);
} else if (impl >= 12 && impl <= 15) {
FAISS_THROW_IF_NOT(ntotal < INT_MAX);
int nt = std::min(omp_get_max_threads(), int(n));
if (nt < 2) {
if (impl == 12 || impl == 13) {
search_implem_12<C>(n, x, k, distances, labels, impl, scaler);
search_implem_12<C>(
n, x, k, distances, labels, impl, norm_scale);
} else {
search_implem_14<C>(n, x, k, distances, labels, impl, scaler);
search_implem_14<C>(
n, x, k, distances, labels, impl, norm_scale);
}
} else {
// explicitly slice over threads
Expand All @@ -328,10 +304,22 @@ void IndexFastScan::search_dispatch_implem(
idx_t* lab_i = labels + i0 * k;
if (impl == 12 || impl == 13) {
search_implem_12<C>(
i1 - i0, x + i0 * d, k, dis_i, lab_i, impl, scaler);
i1 - i0,
x + i0 * d,
k,
dis_i,
lab_i,
impl,
norm_scale);
} else {
search_implem_14<C>(
i1 - i0, x + i0 * d, k, dis_i, lab_i, impl, scaler);
i1 - i0,
x + i0 * d,
k,
dis_i,
lab_i,
impl,
norm_scale);
}
}
}
Expand All @@ -347,7 +335,7 @@ void IndexFastScan::search_implem_234(
idx_t k,
float* distances,
idx_t* labels,
const NormTableScaler* scaler) const {
int norm_scale) const {
FAISS_THROW_IF_NOT(implem == 2 || implem == 3 || implem == 4);

const size_t dim12 = ksub * M;
Expand All @@ -369,6 +357,11 @@ void IndexFastScan::search_implem_234(
}
}

std::unique_ptr<Scaler2x4bit> scaler;
if (norm_scale != -1) {
scaler.reset(new Scaler2x4bit(norm_scale));
}

#pragma omp parallel for if (n > 1000)
for (int64_t i = 0; i < n; i++) {
int64_t* heap_ids = labels + i * k;
Expand All @@ -384,7 +377,7 @@ void IndexFastScan::search_implem_234(
k,
heap_dis,
heap_ids,
scaler);
scaler.get());

heap_reorder<Cfloat>(k, heap_dis, heap_ids);

Expand All @@ -407,8 +400,8 @@ void IndexFastScan::search_implem_12(
float* distances,
idx_t* labels,
int impl,
const NormTableScaler* scaler) const {
using RH = ResultHandlerCompare<C, false>;
int norm_scale) const {
using RH = PQ4CodeScanner;
FAISS_THROW_IF_NOT(bbs == 32);

// handle qbs2 blocking by recursive call
Expand All @@ -423,7 +416,7 @@ void IndexFastScan::search_implem_12(
distances + i0 * k,
labels + i0 * k,
impl,
scaler);
norm_scale);
}
return;
}
Expand Down Expand Up @@ -454,22 +447,22 @@ void IndexFastScan::search_implem_12(
pq4_pack_LUT_qbs(qbs, M2, quantized_dis_tables.get(), LUT.get());
FAISS_THROW_IF_NOT(LUT_nq == n);

std::unique_ptr<RH> handler(
make_knn_handler<C>(impl, n, k, ntotal, distances, labels));
handler->disable = bool(skip & 2);
handler->normalizers = normalizers.get();
std::unique_ptr<PQ4CodeScanner> handler(pq4_make_flat_knn_handler(
metric_type,
impl % 2 == 1,
n,
k,
ntotal,
distances,
labels,
norm_scale,
normalizers.get(),
bool(skip & 2)));

if (skip & 4) {
// pass
} else {
pq4_accumulate_loop_qbs(
qbs,
ntotal2,
M2,
codes.get(),
LUT.get(),
*handler.get(),
scaler);
handler->accumulate_loop_qbs(qbs, ntotal2, M2, codes.get(), LUT.get());
}
if (!(skip & 8)) {
handler->end();
Expand All @@ -486,8 +479,8 @@ void IndexFastScan::search_implem_14(
float* distances,
idx_t* labels,
int impl,
const NormTableScaler* scaler) const {
using RH = ResultHandlerCompare<C, false>;
int norm_scale) const {
using RH = PQ4CodeScanner;
FAISS_THROW_IF_NOT(bbs % 32 == 0);

int qbs2 = qbs == 0 ? 4 : qbs;
Expand All @@ -503,7 +496,7 @@ void IndexFastScan::search_implem_14(
distances + i0 * k,
labels + i0 * k,
impl,
scaler);
norm_scale);
}
return;
}
Expand All @@ -522,23 +515,22 @@ void IndexFastScan::search_implem_14(
AlignedTable<uint8_t> LUT(n * dim12);
pq4_pack_LUT(n, M2, quantized_dis_tables.get(), LUT.get());

std::unique_ptr<RH> handler(
make_knn_handler<C>(impl, n, k, ntotal, distances, labels));
handler->disable = bool(skip & 2);
handler->normalizers = normalizers.get();
std::unique_ptr<PQ4CodeScanner> handler(pq4_make_flat_knn_handler(
metric_type,
impl % 2 == 1,
n,
k,
ntotal,
distances,
labels,
norm_scale,
normalizers.get(),
bool(skip & 2)));

if (skip & 4) {
// pass
} else {
pq4_accumulate_loop(
n,
ntotal2,
bbs,
M2,
codes.get(),
LUT.get(),
*handler.get(),
scaler);
handler->accumulate_loop(n, ntotal2, bbs, M2, codes.get(), LUT.get());
}
if (!(skip & 8)) {
handler->end();
Expand All @@ -551,15 +543,15 @@ template void IndexFastScan::search_dispatch_implem<true>(
idx_t k,
float* distances,
idx_t* labels,
const NormTableScaler* scaler) const;
int norm_scale) const;

template void IndexFastScan::search_dispatch_implem<false>(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels,
const NormTableScaler* scaler) const;
int norm_scale) const;

void IndexFastScan::reconstruct(idx_t key, float* recons) const {
std::vector<uint8_t> code(code_size, 0);
Expand Down
Loading
Loading