diff --git a/faiss/impl/ScalarQuantizer.cpp b/faiss/impl/ScalarQuantizer.cpp index 528843f606..0fb957ebb9 100644 --- a/faiss/impl/ScalarQuantizer.cpp +++ b/faiss/impl/ScalarQuantizer.cpp @@ -275,11 +275,14 @@ struct Codec6bit { * through a codec *******************************************************************/ -template +enum class QuantizerTemplateScaling { UNIFORM = 0, NON_UNIFORM = 1 }; + +template struct QuantizerTemplate {}; template -struct QuantizerTemplate : ScalarQuantizer::SQuantizer { +struct QuantizerTemplate + : ScalarQuantizer::SQuantizer { const size_t d; const float vmin, vdiff; @@ -319,9 +322,12 @@ struct QuantizerTemplate : ScalarQuantizer::SQuantizer { #ifdef __AVX2__ template -struct QuantizerTemplate : QuantizerTemplate { +struct QuantizerTemplate + : QuantizerTemplate { QuantizerTemplate(size_t d, const std::vector& trained) - : QuantizerTemplate(d, trained) {} + : QuantizerTemplate( + d, + trained) {} FAISS_ALWAYS_INLINE __m256 reconstruct_8_components(const uint8_t* code, int i) const { @@ -336,9 +342,12 @@ struct QuantizerTemplate : QuantizerTemplate { #ifdef __aarch64__ template -struct QuantizerTemplate : QuantizerTemplate { +struct QuantizerTemplate + : QuantizerTemplate { QuantizerTemplate(size_t d, const std::vector& trained) - : QuantizerTemplate(d, trained) {} + : QuantizerTemplate( + d, + trained) {} FAISS_ALWAYS_INLINE float32x4x2_t reconstruct_8_components(const uint8_t* code, int i) const { @@ -357,7 +366,8 @@ struct QuantizerTemplate : QuantizerTemplate { #endif template -struct QuantizerTemplate : ScalarQuantizer::SQuantizer { +struct QuantizerTemplate + : ScalarQuantizer::SQuantizer { const size_t d; const float *vmin, *vdiff; @@ -397,9 +407,13 @@ struct QuantizerTemplate : ScalarQuantizer::SQuantizer { #ifdef __AVX2__ template -struct QuantizerTemplate : QuantizerTemplate { +struct QuantizerTemplate + : QuantizerTemplate { QuantizerTemplate(size_t d, const std::vector& trained) - : QuantizerTemplate(d, trained) {} + : QuantizerTemplate< + Codec, + QuantizerTemplateScaling::NON_UNIFORM, + 1>(d, trained) {} FAISS_ALWAYS_INLINE __m256 reconstruct_8_components(const uint8_t* code, int i) const { @@ -416,9 +430,13 @@ struct QuantizerTemplate : QuantizerTemplate { #ifdef __aarch64__ template -struct QuantizerTemplate : QuantizerTemplate { +struct QuantizerTemplate + : QuantizerTemplate { QuantizerTemplate(size_t d, const std::vector& trained) - : QuantizerTemplate(d, trained) {} + : QuantizerTemplate< + Codec, + QuantizerTemplateScaling::NON_UNIFORM, + 1>(d, trained) {} FAISS_ALWAYS_INLINE float32x4x2_t reconstruct_8_components(const uint8_t* code, int i) const { @@ -717,20 +735,30 @@ ScalarQuantizer::SQuantizer* select_quantizer_1( const std::vector& trained) { switch (qtype) { case ScalarQuantizer::QT_8bit: - return new QuantizerTemplate( - d, trained); + return new QuantizerTemplate< + Codec8bit, + QuantizerTemplateScaling::NON_UNIFORM, + SIMDWIDTH>(d, trained); case ScalarQuantizer::QT_6bit: - return new QuantizerTemplate( - d, trained); + return new QuantizerTemplate< + Codec6bit, + QuantizerTemplateScaling::NON_UNIFORM, + SIMDWIDTH>(d, trained); case ScalarQuantizer::QT_4bit: - return new QuantizerTemplate( - d, trained); + return new QuantizerTemplate< + Codec4bit, + QuantizerTemplateScaling::NON_UNIFORM, + SIMDWIDTH>(d, trained); case ScalarQuantizer::QT_8bit_uniform: - return new QuantizerTemplate( - d, trained); + return new QuantizerTemplate< + Codec8bit, + QuantizerTemplateScaling::UNIFORM, + SIMDWIDTH>(d, trained); case ScalarQuantizer::QT_4bit_uniform: - return new QuantizerTemplate( - d, trained); + return new QuantizerTemplate< + Codec4bit, + QuantizerTemplateScaling::UNIFORM, + SIMDWIDTH>(d, trained); case ScalarQuantizer::QT_fp16: return new QuantizerFP16(d, trained); case ScalarQuantizer::QT_bf16: @@ -1494,31 +1522,46 @@ SQDistanceComputer* select_distance_computer( switch (qtype) { case ScalarQuantizer::QT_8bit_uniform: return new DCTemplate< - QuantizerTemplate, + QuantizerTemplate< + Codec8bit, + QuantizerTemplateScaling::UNIFORM, + SIMDWIDTH>, Sim, SIMDWIDTH>(d, trained); case ScalarQuantizer::QT_4bit_uniform: return new DCTemplate< - QuantizerTemplate, + QuantizerTemplate< + Codec4bit, + QuantizerTemplateScaling::UNIFORM, + SIMDWIDTH>, Sim, SIMDWIDTH>(d, trained); case ScalarQuantizer::QT_8bit: return new DCTemplate< - QuantizerTemplate, + QuantizerTemplate< + Codec8bit, + QuantizerTemplateScaling::NON_UNIFORM, + SIMDWIDTH>, Sim, SIMDWIDTH>(d, trained); case ScalarQuantizer::QT_6bit: return new DCTemplate< - QuantizerTemplate, + QuantizerTemplate< + Codec6bit, + QuantizerTemplateScaling::NON_UNIFORM, + SIMDWIDTH>, Sim, SIMDWIDTH>(d, trained); case ScalarQuantizer::QT_4bit: return new DCTemplate< - QuantizerTemplate, + QuantizerTemplate< + Codec4bit, + QuantizerTemplateScaling::NON_UNIFORM, + SIMDWIDTH>, Sim, SIMDWIDTH>(d, trained); @@ -1912,7 +1955,7 @@ InvertedListScanner* sel2_InvertedListScanner( } } -template +template InvertedListScanner* sel12_InvertedListScanner( const ScalarQuantizer* sq, const Index* quantizer, @@ -1920,7 +1963,7 @@ InvertedListScanner* sel12_InvertedListScanner( const IDSelector* sel, bool r) { constexpr int SIMDWIDTH = Similarity::simdwidth; - using QuantizerClass = QuantizerTemplate; + using QuantizerClass = QuantizerTemplate; using DCClass = DCTemplate; return sel2_InvertedListScanner( sq, quantizer, store_pairs, sel, r); @@ -1936,19 +1979,34 @@ InvertedListScanner* sel1_InvertedListScanner( constexpr int SIMDWIDTH = Similarity::simdwidth; switch (sq->qtype) { case ScalarQuantizer::QT_8bit_uniform: - return sel12_InvertedListScanner( + return sel12_InvertedListScanner< + Similarity, + Codec8bit, + QuantizerTemplateScaling::UNIFORM>( sq, quantizer, store_pairs, sel, r); case ScalarQuantizer::QT_4bit_uniform: - return sel12_InvertedListScanner( + return sel12_InvertedListScanner< + Similarity, + Codec4bit, + QuantizerTemplateScaling::UNIFORM>( sq, quantizer, store_pairs, sel, r); case ScalarQuantizer::QT_8bit: - return sel12_InvertedListScanner( + return sel12_InvertedListScanner< + Similarity, + Codec8bit, + QuantizerTemplateScaling::NON_UNIFORM>( sq, quantizer, store_pairs, sel, r); case ScalarQuantizer::QT_4bit: - return sel12_InvertedListScanner( + return sel12_InvertedListScanner< + Similarity, + Codec4bit, + QuantizerTemplateScaling::NON_UNIFORM>( sq, quantizer, store_pairs, sel, r); case ScalarQuantizer::QT_6bit: - return sel12_InvertedListScanner( + return sel12_InvertedListScanner< + Similarity, + Codec6bit, + QuantizerTemplateScaling::NON_UNIFORM>( sq, quantizer, store_pairs, sel, r); case ScalarQuantizer::QT_fp16: return sel2_InvertedListScanner