From df3c7fe019ce97268af3aba2475b11e88011d016 Mon Sep 17 00:00:00 2001 From: Justin Gibbs Date: Thu, 19 Feb 2026 13:47:56 -0800 Subject: [PATCH] Additional hardening of index load path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Address input validation gaps in the load path: #: 1 FPE (divide-by-zero) Location: ProductQuantizer::set_derived_values() Root Cause: M=0 causes d % M to divide by zero Fix: Validate M > 0 in read_ProductQuantizer and set_derived_values() ──────────────────────────────────────── #: 2 OOB vector access Location: AdditiveQuantizer::set_derived_values() Root Cause: nbits.size() != M causes access to nbits[i] beyond bounds Fix: Validate nbits.size() == M in read_AdditiveQuantizer ──────────────────────────────────────── #: 3 OOB vector access Location: ResidualQuantizer old format Root Cause: Same as #2 but via read_ResidualQuantizer_old path Fix: Validate nbits.size() == M in read_ResidualQuantizer_old ──────────────────────────────────────── #: 4 Integer overflow Location: ProductQuantizer::set_derived_values() Root Cause: d * ksub overflows when nbits is large Fix: Use mul_no_overflow() for centroids.resize(), moved nbits > 24 check before code_size computation ──────────────────────────────────────── #: 5 FPE (divide-by-zero) + stack overflow Location: IndexLattice constructor Root Cause: nsq=0 causes d/nsq divide-by-zero; garbage nsq causes stack overflow in ZnSphereCodec Fix: Validate nsq > 0, d > 0, d % nsq == 0, r2 >= 0 before construction Files Modified Reviewed By: mdouze Differential Revision: D93486058 --- faiss/impl/ProductQuantizer.cpp | 5 +- faiss/impl/index_read.cpp | 21 +++ tests/test_read_index_deserialize.cpp | 223 ++++++++++++++++++++++++++ 3 files changed, 247 insertions(+), 2 deletions(-) create mode 100644 tests/test_read_index_deserialize.cpp diff --git a/faiss/impl/ProductQuantizer.cpp b/faiss/impl/ProductQuantizer.cpp index 9fbde1101c..f3e8ae7b6a 100644 --- a/faiss/impl/ProductQuantizer.cpp +++ b/faiss/impl/ProductQuantizer.cpp @@ -57,14 +57,15 @@ ProductQuantizer::ProductQuantizer() : ProductQuantizer(0, 1, 0) {} void ProductQuantizer::set_derived_values() { // quite a few derived values + FAISS_THROW_IF_NOT_MSG(M > 0, "M must be > 0"); FAISS_THROW_IF_NOT_MSG( d % M == 0, "The dimension of the vector (d) should be a multiple of the number of subquantizers (M)"); dsub = d / M; - code_size = (nbits * M + 7) / 8; FAISS_THROW_IF_MSG(nbits > 24, "nbits larger than 24 is not practical."); + code_size = (nbits * M + 7) / 8; ksub = 1 << nbits; - centroids.resize(d * ksub); + centroids.resize(mul_no_overflow(d, (size_t)ksub, "PQ centroids")); verbose = false; train_type = Train_default; } diff --git a/faiss/impl/index_read.cpp b/faiss/impl/index_read.cpp index 49a51e5bb7..c33ebf0199 100644 --- a/faiss/impl/index_read.cpp +++ b/faiss/impl/index_read.cpp @@ -447,6 +447,8 @@ void read_ProductQuantizer(ProductQuantizer* pq, IOReader* f) { READ1(pq->d); READ1(pq->M); READ1(pq->nbits); + FAISS_THROW_IF_NOT_FMT( + pq->M > 0, "invalid ProductQuantizer M=%zd (must be > 0)", pq->M); pq->set_derived_values(); READVECTOR(pq->centroids); } @@ -455,6 +457,11 @@ static void read_ResidualQuantizer_old(ResidualQuantizer& rq, IOReader* f) { READ1(rq.d); READ1(rq.M); READVECTOR(rq.nbits); + FAISS_THROW_IF_NOT_FMT( + rq.nbits.size() == rq.M, + "ResidualQuantizer nbits size %zd != M %zd", + rq.nbits.size(), + rq.M); READ1(rq.is_trained); READ1(rq.train_type); READ1(rq.max_beam_size); @@ -471,6 +478,11 @@ static void read_AdditiveQuantizer(AdditiveQuantizer& aq, IOReader* f) { READVECTOR(aq.nbits); READ1(aq.is_trained); READVECTOR(aq.codebooks); + FAISS_THROW_IF_NOT_FMT( + aq.nbits.size() == aq.M, + "AdditiveQuantizer nbits size %zd != M %zd", + aq.nbits.size(), + aq.M); READ1(aq.search_type); READ1(aq.norm_min); READ1(aq.norm_max); @@ -1190,6 +1202,15 @@ std::unique_ptr read_index_up(IOReader* f, int io_flags) { READ1(nsq); READ1(scale_nbit); READ1(r2); + FAISS_THROW_IF_NOT_FMT( + nsq > 0, "invalid IndexLattice nsq %d (must be > 0)", nsq); + FAISS_THROW_IF_NOT_FMT( + d > 0 && d % nsq == 0, + "invalid IndexLattice d=%d, nsq=%d (d must be > 0 and divisible by nsq)", + d, + nsq); + FAISS_THROW_IF_NOT_FMT( + r2 >= 0, "invalid IndexLattice r2 %d (must be >= 0)", r2); auto idxl = std::make_unique(d, nsq, scale_nbit, r2); read_index_header(*idxl, f); READVECTOR(idxl->trained); diff --git a/tests/test_read_index_deserialize.cpp b/tests/test_read_index_deserialize.cpp new file mode 100644 index 0000000000..5f24a26895 --- /dev/null +++ b/tests/test_read_index_deserialize.cpp @@ -0,0 +1,223 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include + +#include +#include +#include +#include + +using namespace faiss; + +/// Helper: append a scalar value to the buffer in little-endian format, +/// matching WRITE1. +template +static void push_val(std::vector& buf, T val) { + const auto* p = reinterpret_cast(&val); + buf.insert(buf.end(), p, p + sizeof(T)); +} + +/// Helper: append a WRITEVECTOR-formatted vector (size_t length prefix +/// followed by raw element data). +template +static void push_vector(std::vector& buf, const std::vector& vec) { + push_val(buf, vec.size()); + const auto* p = reinterpret_cast(vec.data()); + buf.insert(buf.end(), p, p + vec.size() * sizeof(T)); +} + +/// Helper: append a fourcc string as a uint32_t. +static void push_fourcc(std::vector& buf, const char s[4]) { + const auto* x = reinterpret_cast(s); + uint32_t h = x[0] | (x[1] << 8) | (x[2] << 16) | (x[3] << 24); + push_val(buf, h); +} + +/// Helper: append write_index_header fields. +static void push_index_header( + std::vector& buf, + int d, + int64_t ntotal, + bool is_trained = true, + int metric_type = 1 /* L2 */) { + push_val(buf, d); + push_val(buf, ntotal); + int64_t dummy = 1 << 20; + push_val(buf, dummy); + push_val(buf, dummy); + push_val(buf, is_trained); + push_val(buf, metric_type); +} + +/// Helper: append write_ProductQuantizer fields (d, M, nbits, centroids vec). +static void push_pq( + std::vector& buf, + size_t d, + size_t M, + size_t nbits, + const std::vector& centroids = {}) { + push_val(buf, d); + push_val(buf, M); + push_val(buf, nbits); + push_vector(buf, centroids); +} + +/// Try to read a float index from the given buffer and expect a FaissException. +static void expect_read_throws(const std::vector& data) { + VectorIOReader reader; + reader.data = data; + EXPECT_THROW(read_index_up(&reader), FaissException); +} + +/// Try to read a float index and expect a FaissException whose message +/// contains the given substring. +static void expect_read_throws_with( + const std::vector& data, + const std::string& expected_substr) { + VectorIOReader reader; + reader.data = data; + try { + read_index_up(&reader); + FAIL() << "expected FaissException"; + } catch (const FaissException& e) { + EXPECT_NE( + std::string(e.what()).find(expected_substr), std::string::npos) + << "expected '" << expected_substr << "' in: " << e.what(); + } +} + +// ----------------------------------------------------------------------- +// Test: ProductQuantizer with M=0 causes divide-by-zero in +// set_derived_values(). The fix validates M > 0 in read_ProductQuantizer +// before calling set_derived_values(). +// ----------------------------------------------------------------------- +TEST(ReadIndexDeserialize, PQWithMZeroDivideByZero) { + // Build a minimal MultiIndexQuantizer ("Imiq") payload with M=0. + // Format: fourcc("Imiq") + index_header + PQ(d=4, M=0, nbits=8, []) + std::vector buf; + push_fourcc(buf, "Imiq"); + push_index_header(buf, /*d=*/4, /*ntotal=*/0); + push_pq(buf, /*d=*/4, /*M=*/0, /*nbits=*/8); + + expect_read_throws(buf); +} + +// ----------------------------------------------------------------------- +// Test: AdditiveQuantizer with nbits.size() != M causes out-of-bounds +// access on the nbits vector in set_derived_values(). The fix validates +// nbits.size() == M in read_AdditiveQuantizer. +// ----------------------------------------------------------------------- +TEST(ReadIndexDeserialize, AdditiveQuantizerNbitsSizeMismatch) { + // Build a minimal IndexProductResidualQuantizerFastScan ("IPRf") payload + // whose AdditiveQuantizer has M=10 but nbits vector has only 1 element. + // + // "IPRf" format: + // fourcc + index_header + read_AdditiveQuantizer(d, M, nbits, + // is_trained, codebooks, search_type, norm_min, norm_max) + // + set_derived_values() + // + // We must provide enough data to reach set_derived_values() so the + // OOB access on nbits[i] actually triggers (rather than an earlier + // read-error throwing first). + std::vector buf; + push_fourcc(buf, "IPRf"); + push_index_header(buf, /*d=*/4, /*ntotal=*/0); + // read_AdditiveQuantizer fields: + push_val(buf, 4); // d + push_val(buf, 10); // M = 10 + // nbits vector with only 1 element (should be 10 to match M) + push_vector(buf, {8}); + // is_trained + push_val(buf, true); + // codebooks (empty vector is fine) + push_vector(buf, {}); + // search_type (ST_decompress = 0) + push_val(buf, 0); + // norm_min, norm_max + push_val(buf, 0.0f); + push_val(buf, 1.0f); + // After these reads, set_derived_values() will access nbits[1..9] + // which are out of bounds. + + expect_read_throws_with(buf, "nbits size"); +} + +// ----------------------------------------------------------------------- +// Test: ResidualQuantizer (old format) with nbits.size() != M also causes +// out-of-bounds access. Uses the "IxRQ" fourcc path. +// ----------------------------------------------------------------------- +TEST(ReadIndexDeserialize, ResidualQuantizerOldNbitsSizeMismatch) { + // "IxRQ" format: fourcc + index_header + read_ResidualQuantizer_old(...) + // read_ResidualQuantizer_old reads: d, M, nbits_vec, is_trained, ... + std::vector buf; + push_fourcc(buf, "IxRQ"); + push_index_header(buf, /*d=*/4, /*ntotal=*/0); + // ResidualQuantizer_old fields: + push_val(buf, 4); // d + push_val(buf, 5); // M = 5 + // nbits vector with only 2 elements (should be 5 to match M) + push_vector(buf, {8, 8}); + + expect_read_throws_with(buf, "nbits size"); +} + +// ----------------------------------------------------------------------- +// Test: ProductQuantizer with d * ksub overflow in centroids allocation. +// The fix uses mul_no_overflow to detect the overflow. +// ----------------------------------------------------------------------- +TEST(ReadIndexDeserialize, PQCentroidsOverflow) { + // Build a minimal "Imiq" with d very large and nbits=24 so that + // d * (1 << 24) overflows size_t. + std::vector buf; + push_fourcc(buf, "Imiq"); + // Use a huge d that, when multiplied by ksub=2^24, overflows + size_t huge_d = (size_t)1 << 48; + push_index_header(buf, /*d=*/(int)huge_d, /*ntotal=*/0); + // M must divide d; set M=1 so d % M == 0 + push_pq(buf, /*d=*/huge_d, /*M=*/1, /*nbits=*/24); + + expect_read_throws(buf); +} + +// ----------------------------------------------------------------------- +// Test: IndexLattice with nsq=0 causes divide-by-zero in the +// constructor's member initializer list (dsq = d / nsq). The fix +// validates nsq > 0 in the deserialization path. +// ----------------------------------------------------------------------- +TEST(ReadIndexDeserialize, IndexLatticeNsqZeroDivideByZero) { + // "IxLa" format: fourcc + d(int) + nsq(int) + scale_nbit(int) + // + r2(int) + index_header + READVECTOR(trained) + std::vector buf; + push_fourcc(buf, "IxLa"); + push_val(buf, 16); // d + push_val(buf, 0); // nsq = 0 -> divide by zero + push_val(buf, 4); // scale_nbit + push_val(buf, 14); // r2 + + expect_read_throws(buf); +} + +// ----------------------------------------------------------------------- +// Test: IndexLattice with d not divisible by nsq causes undefined +// behavior in the constructor. The fix validates d % nsq == 0 before +// construction. +// ----------------------------------------------------------------------- +TEST(ReadIndexDeserialize, IndexLatticeDNotDivisibleByNsq) { + std::vector buf; + push_fourcc(buf, "IxLa"); + push_val(buf, 17); // d = 17 (not divisible by nsq=4) + push_val(buf, 4); // nsq + push_val(buf, 4); // scale_nbit + push_val(buf, 14); // r2 + + expect_read_throws_with(buf, "divisible by nsq"); +}