Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 91 additions & 33 deletions faiss/impl/ScalarQuantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,11 +275,14 @@ struct Codec6bit {
* through a codec
*******************************************************************/

template <class Codec, bool uniform, int SIMD>
enum class QuantizerTemplateScaling { UNIFORM = 0, NON_UNIFORM = 1 };

template <class Codec, QuantizerTemplateScaling SCALING, int SIMD>
struct QuantizerTemplate {};

template <class Codec>
struct QuantizerTemplate<Codec, true, 1> : ScalarQuantizer::SQuantizer {
struct QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 1>
: ScalarQuantizer::SQuantizer {
const size_t d;
const float vmin, vdiff;

Expand Down Expand Up @@ -319,9 +322,12 @@ struct QuantizerTemplate<Codec, true, 1> : ScalarQuantizer::SQuantizer {
#ifdef __AVX2__

template <class Codec>
struct QuantizerTemplate<Codec, true, 8> : QuantizerTemplate<Codec, true, 1> {
struct QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 8>
: QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 1> {
QuantizerTemplate(size_t d, const std::vector<float>& trained)
: QuantizerTemplate<Codec, true, 1>(d, trained) {}
: QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 1>(
d,
trained) {}

FAISS_ALWAYS_INLINE __m256
reconstruct_8_components(const uint8_t* code, int i) const {
Expand All @@ -336,9 +342,12 @@ struct QuantizerTemplate<Codec, true, 8> : QuantizerTemplate<Codec, true, 1> {
#ifdef __aarch64__

template <class Codec>
struct QuantizerTemplate<Codec, true, 8> : QuantizerTemplate<Codec, true, 1> {
struct QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 8>
: QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 1> {
QuantizerTemplate(size_t d, const std::vector<float>& trained)
: QuantizerTemplate<Codec, true, 1>(d, trained) {}
: QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 1>(
d,
trained) {}

FAISS_ALWAYS_INLINE float32x4x2_t
reconstruct_8_components(const uint8_t* code, int i) const {
Expand All @@ -357,7 +366,8 @@ struct QuantizerTemplate<Codec, true, 8> : QuantizerTemplate<Codec, true, 1> {
#endif

template <class Codec>
struct QuantizerTemplate<Codec, false, 1> : ScalarQuantizer::SQuantizer {
struct QuantizerTemplate<Codec, QuantizerTemplateScaling::NON_UNIFORM, 1>
: ScalarQuantizer::SQuantizer {
const size_t d;
const float *vmin, *vdiff;

Expand Down Expand Up @@ -397,9 +407,13 @@ struct QuantizerTemplate<Codec, false, 1> : ScalarQuantizer::SQuantizer {
#ifdef __AVX2__

template <class Codec>
struct QuantizerTemplate<Codec, false, 8> : QuantizerTemplate<Codec, false, 1> {
struct QuantizerTemplate<Codec, QuantizerTemplateScaling::NON_UNIFORM, 8>
: QuantizerTemplate<Codec, QuantizerTemplateScaling::NON_UNIFORM, 1> {
QuantizerTemplate(size_t d, const std::vector<float>& trained)
: QuantizerTemplate<Codec, false, 1>(d, trained) {}
: QuantizerTemplate<
Codec,
QuantizerTemplateScaling::NON_UNIFORM,
1>(d, trained) {}

FAISS_ALWAYS_INLINE __m256
reconstruct_8_components(const uint8_t* code, int i) const {
Expand All @@ -416,9 +430,13 @@ struct QuantizerTemplate<Codec, false, 8> : QuantizerTemplate<Codec, false, 1> {
#ifdef __aarch64__

template <class Codec>
struct QuantizerTemplate<Codec, false, 8> : QuantizerTemplate<Codec, false, 1> {
struct QuantizerTemplate<Codec, QuantizerTemplateScaling::NON_UNIFORM, 8>
: QuantizerTemplate<Codec, QuantizerTemplateScaling::NON_UNIFORM, 1> {
QuantizerTemplate(size_t d, const std::vector<float>& trained)
: QuantizerTemplate<Codec, false, 1>(d, trained) {}
: QuantizerTemplate<
Codec,
QuantizerTemplateScaling::NON_UNIFORM,
1>(d, trained) {}

FAISS_ALWAYS_INLINE float32x4x2_t
reconstruct_8_components(const uint8_t* code, int i) const {
Expand Down Expand Up @@ -717,20 +735,30 @@ ScalarQuantizer::SQuantizer* select_quantizer_1(
const std::vector<float>& trained) {
switch (qtype) {
case ScalarQuantizer::QT_8bit:
return new QuantizerTemplate<Codec8bit, false, SIMDWIDTH>(
d, trained);
return new QuantizerTemplate<
Codec8bit,
QuantizerTemplateScaling::NON_UNIFORM,
SIMDWIDTH>(d, trained);
case ScalarQuantizer::QT_6bit:
return new QuantizerTemplate<Codec6bit, false, SIMDWIDTH>(
d, trained);
return new QuantizerTemplate<
Codec6bit,
QuantizerTemplateScaling::NON_UNIFORM,
SIMDWIDTH>(d, trained);
case ScalarQuantizer::QT_4bit:
return new QuantizerTemplate<Codec4bit, false, SIMDWIDTH>(
d, trained);
return new QuantizerTemplate<
Codec4bit,
QuantizerTemplateScaling::NON_UNIFORM,
SIMDWIDTH>(d, trained);
case ScalarQuantizer::QT_8bit_uniform:
return new QuantizerTemplate<Codec8bit, true, SIMDWIDTH>(
d, trained);
return new QuantizerTemplate<
Codec8bit,
QuantizerTemplateScaling::UNIFORM,
SIMDWIDTH>(d, trained);
case ScalarQuantizer::QT_4bit_uniform:
return new QuantizerTemplate<Codec4bit, true, SIMDWIDTH>(
d, trained);
return new QuantizerTemplate<
Codec4bit,
QuantizerTemplateScaling::UNIFORM,
SIMDWIDTH>(d, trained);
case ScalarQuantizer::QT_fp16:
return new QuantizerFP16<SIMDWIDTH>(d, trained);
case ScalarQuantizer::QT_bf16:
Expand Down Expand Up @@ -1494,31 +1522,46 @@ SQDistanceComputer* select_distance_computer(
switch (qtype) {
case ScalarQuantizer::QT_8bit_uniform:
return new DCTemplate<
QuantizerTemplate<Codec8bit, true, SIMDWIDTH>,
QuantizerTemplate<
Codec8bit,
QuantizerTemplateScaling::UNIFORM,
SIMDWIDTH>,
Sim,
SIMDWIDTH>(d, trained);

case ScalarQuantizer::QT_4bit_uniform:
return new DCTemplate<
QuantizerTemplate<Codec4bit, true, SIMDWIDTH>,
QuantizerTemplate<
Codec4bit,
QuantizerTemplateScaling::UNIFORM,
SIMDWIDTH>,
Sim,
SIMDWIDTH>(d, trained);

case ScalarQuantizer::QT_8bit:
return new DCTemplate<
QuantizerTemplate<Codec8bit, false, SIMDWIDTH>,
QuantizerTemplate<
Codec8bit,
QuantizerTemplateScaling::NON_UNIFORM,
SIMDWIDTH>,
Sim,
SIMDWIDTH>(d, trained);

case ScalarQuantizer::QT_6bit:
return new DCTemplate<
QuantizerTemplate<Codec6bit, false, SIMDWIDTH>,
QuantizerTemplate<
Codec6bit,
QuantizerTemplateScaling::NON_UNIFORM,
SIMDWIDTH>,
Sim,
SIMDWIDTH>(d, trained);

case ScalarQuantizer::QT_4bit:
return new DCTemplate<
QuantizerTemplate<Codec4bit, false, SIMDWIDTH>,
QuantizerTemplate<
Codec4bit,
QuantizerTemplateScaling::NON_UNIFORM,
SIMDWIDTH>,
Sim,
SIMDWIDTH>(d, trained);

Expand Down Expand Up @@ -1912,15 +1955,15 @@ InvertedListScanner* sel2_InvertedListScanner(
}
}

template <class Similarity, class Codec, bool uniform>
template <class Similarity, class Codec, QuantizerTemplateScaling SCALING>
InvertedListScanner* sel12_InvertedListScanner(
const ScalarQuantizer* sq,
const Index* quantizer,
bool store_pairs,
const IDSelector* sel,
bool r) {
constexpr int SIMDWIDTH = Similarity::simdwidth;
using QuantizerClass = QuantizerTemplate<Codec, uniform, SIMDWIDTH>;
using QuantizerClass = QuantizerTemplate<Codec, SCALING, SIMDWIDTH>;
using DCClass = DCTemplate<QuantizerClass, Similarity, SIMDWIDTH>;
return sel2_InvertedListScanner<DCClass>(
sq, quantizer, store_pairs, sel, r);
Expand All @@ -1936,19 +1979,34 @@ InvertedListScanner* sel1_InvertedListScanner(
constexpr int SIMDWIDTH = Similarity::simdwidth;
switch (sq->qtype) {
case ScalarQuantizer::QT_8bit_uniform:
return sel12_InvertedListScanner<Similarity, Codec8bit, true>(
return sel12_InvertedListScanner<
Similarity,
Codec8bit,
QuantizerTemplateScaling::UNIFORM>(
sq, quantizer, store_pairs, sel, r);
case ScalarQuantizer::QT_4bit_uniform:
return sel12_InvertedListScanner<Similarity, Codec4bit, true>(
return sel12_InvertedListScanner<
Similarity,
Codec4bit,
QuantizerTemplateScaling::UNIFORM>(
sq, quantizer, store_pairs, sel, r);
case ScalarQuantizer::QT_8bit:
return sel12_InvertedListScanner<Similarity, Codec8bit, false>(
return sel12_InvertedListScanner<
Similarity,
Codec8bit,
QuantizerTemplateScaling::NON_UNIFORM>(
sq, quantizer, store_pairs, sel, r);
case ScalarQuantizer::QT_4bit:
return sel12_InvertedListScanner<Similarity, Codec4bit, false>(
return sel12_InvertedListScanner<
Similarity,
Codec4bit,
QuantizerTemplateScaling::NON_UNIFORM>(
sq, quantizer, store_pairs, sel, r);
case ScalarQuantizer::QT_6bit:
return sel12_InvertedListScanner<Similarity, Codec6bit, false>(
return sel12_InvertedListScanner<
Similarity,
Codec6bit,
QuantizerTemplateScaling::NON_UNIFORM>(
sq, quantizer, store_pairs, sel, r);
case ScalarQuantizer::QT_fp16:
return sel2_InvertedListScanner<DCTemplate<
Expand Down