From b216ae6ae3a889c6e004d0d9bf8a19c2b1c613db Mon Sep 17 00:00:00 2001 From: matthijs <> Date: Mon, 10 Mar 2025 00:05:18 -0700 Subject: [PATCH 1/2] Generated from a GitHub Pull Request. Run 'jf sync' on this diff to load the correct commit data. Differential Revision: D69972250 --- faiss/CMakeLists.txt | 4 +- faiss/IndexFlatCodes.cpp | 2 +- faiss/IndexFlatCodes.h | 6 +- faiss/impl/HNSW.cpp | 2 +- faiss/impl/HNSW.h | 3 +- faiss/impl/index_read.cpp | 205 ++++++++++++++++++++-- faiss/impl/io.h | 4 +- faiss/impl/io_macros.h | 34 ++-- faiss/impl/mapped_io.cpp | 281 ++++++++++++++++++++++++++++++ faiss/impl/mapped_io.h | 44 +++++ faiss/impl/maybe_owned_vector.h | 232 ++++++++++++++++++++++++ faiss/impl/zerocopy_io.cpp | 56 ++++++ faiss/impl/zerocopy_io.h | 25 +++ faiss/index_io.h | 4 + faiss/python/array_conversions.py | 18 ++ faiss/python/swigfaiss.swig | 17 ++ tests/CMakeLists.txt | 2 + tests/test_mmap.cpp | 146 ++++++++++++++++ tests/test_zerocopy.cpp | 132 ++++++++++++++ 19 files changed, 1177 insertions(+), 40 deletions(-) create mode 100644 faiss/impl/mapped_io.cpp create mode 100644 faiss/impl/mapped_io.h create mode 100644 faiss/impl/maybe_owned_vector.h create mode 100644 faiss/impl/zerocopy_io.cpp create mode 100644 faiss/impl/zerocopy_io.h create mode 100644 tests/test_mmap.cpp create mode 100644 tests/test_zerocopy.cpp diff --git a/faiss/CMakeLists.txt b/faiss/CMakeLists.txt index 61a2b6d4c2..ca889ba795 100644 --- a/faiss/CMakeLists.txt +++ b/faiss/CMakeLists.txt @@ -69,12 +69,12 @@ set(FAISS_SRC impl/io.cpp impl/kmeans1d.cpp impl/lattice_Zn.cpp + impl/mapped_io.cpp impl/pq4_fast_scan.cpp impl/pq4_fast_scan_search_1.cpp impl/pq4_fast_scan_search_qbs.cpp impl/residual_quantizer_encode_steps.cpp - impl/io.cpp - impl/lattice_Zn.cpp + impl/zerocopy_io.cpp impl/NNDescent.cpp invlists/BlockInvertedLists.cpp invlists/DirectMap.cpp diff --git a/faiss/IndexFlatCodes.cpp b/faiss/IndexFlatCodes.cpp index 54c86c6a9b..d5b86b385e 100644 --- a/faiss/IndexFlatCodes.cpp +++ b/faiss/IndexFlatCodes.cpp @@ -110,7 +110,7 @@ CodePacker* IndexFlatCodes::get_CodePacker() const { } void IndexFlatCodes::permute_entries(const idx_t* perm) { - std::vector new_codes(codes.size()); + MaybeOwnedVector new_codes(codes.size()); for (idx_t i = 0; i < ntotal; i++) { memcpy(new_codes.data() + i * code_size, diff --git a/faiss/IndexFlatCodes.h b/faiss/IndexFlatCodes.h index 809862f1e2..56a11df795 100644 --- a/faiss/IndexFlatCodes.h +++ b/faiss/IndexFlatCodes.h @@ -7,9 +7,11 @@ #pragma once +#include + #include #include -#include +#include namespace faiss { @@ -21,7 +23,7 @@ struct IndexFlatCodes : Index { size_t code_size; /// encoded dataset, size ntotal * code_size - std::vector codes; + MaybeOwnedVector codes; IndexFlatCodes(); diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp index 61afa00231..4595db7e32 100644 --- a/faiss/impl/HNSW.cpp +++ b/faiss/impl/HNSW.cpp @@ -1085,7 +1085,7 @@ void HNSW::permute_entries(const idx_t* map) { // swap everyone std::swap(levels, new_levels); std::swap(offsets, new_offsets); - std::swap(neighbors, new_neighbors); + neighbors = std::move(new_neighbors); } /************************************************************** diff --git a/faiss/impl/HNSW.h b/faiss/impl/HNSW.h index f79f1c1199..5a3b7f93f2 100644 --- a/faiss/impl/HNSW.h +++ b/faiss/impl/HNSW.h @@ -15,6 +15,7 @@ #include #include +#include #include #include #include @@ -121,7 +122,7 @@ struct HNSW { /// neighbors[offsets[i]:offsets[i+1]] is the list of neighbors of vector i /// for all levels. this is where all storage goes. - std::vector neighbors; + MaybeOwnedVector neighbors; /// entry point in the search structure (one of the points with maximum /// level diff --git a/faiss/impl/index_read.cpp b/faiss/impl/index_read.cpp index c7887a4f53..58ed848a26 100644 --- a/faiss/impl/index_read.cpp +++ b/faiss/impl/index_read.cpp @@ -53,8 +53,167 @@ #include #include +// mmap-ing and viewing facilities +#include + +#include +#include + namespace faiss { +/************************************************************* + * Mmap-ing and viewing facilities + **************************************************************/ + +template +void read_vector_with_size(VectorT& target, IOReader* f, size_t size) { + ZeroCopyIOReader* zr = dynamic_cast(f); + if (zr != nullptr) { + if constexpr (is_maybe_owned_vector_v) { + // create a view + char* address = nullptr; + size_t nread = zr->get_data_view( + (void**)&address, + sizeof(typename VectorT::value_type), + size); + + FAISS_THROW_IF_NOT_FMT( + nread == (size), + "read error in %s: %zd != %zd (%s)", + f->name.c_str(), + nread, + size_t(size), + strerror(errno)); + + VectorT view = VectorT::create_view(address, nread); + target = std::move(view); + + return; + } + } + + target.resize(size); + READANDCHECK(target.data(), size); +} + +template +void read_vector(VectorT& target, IOReader* f) { + // is it a mmap-enabled reader? + MappedFileIOReader* mf = dynamic_cast(f); + if (mf != nullptr) { + // check if the use case is right + if constexpr (is_maybe_owned_vector_v) { + // read the size + size_t size = 0; + READANDCHECK(&size, 1); + // ok, mmap and check + char* address = nullptr; + const size_t nread = mf->mmap( + (void**)&address, + sizeof(typename VectorT::value_type), + size); + + FAISS_THROW_IF_NOT_FMT( + nread == (size), + "read error in %s: %zd != %zd (%s)", + f->name.c_str(), + nread, + size, + strerror(errno)); + + VectorT mmapped_view = + VectorT::create_view(address, nread, mf->mmap_owner); + target = std::move(mmapped_view); + + return; + } + } + + // is it a zero-copy reader? + ZeroCopyIOReader* zr = dynamic_cast(f); + if (zr != nullptr) { + if constexpr (is_maybe_owned_vector_v) { + // read the size first + size_t size = target.size(); + READANDCHECK(&size, 1); + + // create a view + char* address = nullptr; + size_t nread = zr->get_data_view( + (void**)&address, + sizeof(typename VectorT::value_type), + size); + VectorT view = VectorT::create_view(address, nread, nullptr); + target = std::move(view); + + return; + } + } + + // the default case + READVECTOR(target); +} + +template +void read_xb_vector(VectorT& target, IOReader* f) { + // is it a mmap-enabled reader? + MappedFileIOReader* mf = dynamic_cast(f); + if (mf != nullptr) { + // check if the use case is right + if constexpr (is_maybe_owned_vector_v) { + // read the size + size_t size = 0; + READANDCHECK(&size, 1); + + size *= 4; + + // ok, mmap and check + char* address = nullptr; + const size_t nread = mf->mmap( + (void**)&address, + sizeof(typename VectorT::value_type), + size); + + FAISS_THROW_IF_NOT_FMT( + nread == (size), + "read error in %s: %zd != %zd (%s)", + f->name.c_str(), + nread, + size, + strerror(errno)); + + VectorT mmapped_view = + VectorT::create_view(address, nread, mf->mmap_owner); + target = std::move(mmapped_view); + + return; + } + } + + ZeroCopyIOReader* zr = dynamic_cast(f); + if (zr != nullptr) { + if constexpr (std::is_same_v>) { + // read the size first + size_t size = target.size(); + READANDCHECK(&size, 1); + + size *= 4; + + char* address = nullptr; + size_t nread = zr->get_data_view( + (void**)&address, + sizeof(typename VectorT::value_type), + size); + VectorT view = VectorT::create_view(address, nread, nullptr); + target = std::move(view); + return; + } + } + + // the default case + READXBVECTOR(target); +} + /************************************************************* * Read **************************************************************/ @@ -275,7 +434,7 @@ static void read_AdditiveQuantizer(AdditiveQuantizer* aq, IOReader* f) { aq->search_type == AdditiveQuantizer::ST_norm_cqint4 || aq->search_type == AdditiveQuantizer::ST_norm_lsq2x4 || aq->search_type == AdditiveQuantizer::ST_norm_rq2x4) { - READXBVECTOR(aq->qnorm.codes); + read_xb_vector(aq->qnorm.codes, f); aq->qnorm.ntotal = aq->qnorm.codes.size() / 4; aq->qnorm.update_permutation(); } @@ -365,7 +524,7 @@ static void read_HNSW(HNSW* hnsw, IOReader* f) { READVECTOR(hnsw->cum_nneighbor_per_level); READVECTOR(hnsw->levels); READVECTOR(hnsw->offsets); - READVECTOR(hnsw->neighbors); + read_vector(hnsw->neighbors, f); READ1(hnsw->entry_point); READ1(hnsw->max_level); @@ -545,7 +704,7 @@ Index* read_index(IOReader* f, int io_flags) { } read_index_header(idxf, f); idxf->code_size = idxf->d * sizeof(float); - READXBVECTOR(idxf->codes); + read_xb_vector(idxf->codes, f); FAISS_THROW_IF_NOT( idxf->codes.size() == idxf->ntotal * idxf->code_size); // leak! @@ -576,7 +735,7 @@ Index* read_index(IOReader* f, int io_flags) { idxl->rrot = *rrot; delete rrot; } - READVECTOR(idxl->codes); + read_vector(idxl->codes, f); FAISS_THROW_IF_NOT( idxl->rrot.d_in == idxl->d && idxl->rrot.d_out == idxl->nbits); FAISS_THROW_IF_NOT( @@ -589,7 +748,7 @@ Index* read_index(IOReader* f, int io_flags) { read_index_header(idxp, f); read_ProductQuantizer(&idxp->pq, f); idxp->code_size = idxp->pq.code_size; - READVECTOR(idxp->codes); + read_vector(idxp->codes, f); if (h == fourcc("IxPo") || h == fourcc("IxPq")) { READ1(idxp->search_type); READ1(idxp->encode_signs); @@ -611,28 +770,28 @@ Index* read_index(IOReader* f, int io_flags) { read_ResidualQuantizer(&idxr->rq, f, io_flags); } READ1(idxr->code_size); - READVECTOR(idxr->codes); + read_vector(idxr->codes, f); idx = idxr; } else if (h == fourcc("IxLS")) { auto idxr = new IndexLocalSearchQuantizer(); read_index_header(idxr, f); read_LocalSearchQuantizer(&idxr->lsq, f); READ1(idxr->code_size); - READVECTOR(idxr->codes); + read_vector(idxr->codes, f); idx = idxr; } else if (h == fourcc("IxPR")) { auto idxpr = new IndexProductResidualQuantizer(); read_index_header(idxpr, f); read_ProductResidualQuantizer(&idxpr->prq, f, io_flags); READ1(idxpr->code_size); - READVECTOR(idxpr->codes); + read_vector(idxpr->codes, f); idx = idxpr; } else if (h == fourcc("IxPL")) { auto idxpl = new IndexProductLocalSearchQuantizer(); read_index_header(idxpl, f); read_ProductLocalSearchQuantizer(&idxpl->plsq, f); READ1(idxpl->code_size); - READVECTOR(idxpl->codes); + read_vector(idxpl->codes, f); idx = idxpl; } else if (h == fourcc("ImRQ")) { ResidualCoarseQuantizer* idxr = new ResidualCoarseQuantizer(); @@ -789,7 +948,7 @@ Index* read_index(IOReader* f, int io_flags) { IndexScalarQuantizer* idxs = new IndexScalarQuantizer(); read_index_header(idxs, f); read_ScalarQuantizer(&idxs->sq, f); - READVECTOR(idxs->codes); + read_vector(idxs->codes, f); idxs->code_size = idxs->sq.code_size; idx = idxs; } else if (h == fourcc("IxLa")) { @@ -947,7 +1106,7 @@ Index* read_index(IOReader* f, int io_flags) { READ1(idxp->code_size_1); READ1(idxp->code_size_2); READ1(idxp->code_size); - READVECTOR(idxp->codes); + read_vector(idxp->codes, f); idx = idxp; } else if ( h == fourcc("IHNf") || h == fourcc("IHNp") || h == fourcc("IHNs") || @@ -1071,14 +1230,28 @@ Index* read_index(IOReader* f, int io_flags) { } Index* read_index(FILE* f, int io_flags) { - FileIOReader reader(f); - return read_index(&reader, io_flags); + if ((io_flags & IO_FLAG_MMAP_IFC) == IO_FLAG_MMAP_IFC) { + // enable mmap-supporting IOReader + auto owner = std::make_shared(f); + MappedFileIOReader reader(owner); + return read_index(&reader, io_flags); + } else { + FileIOReader reader(f); + return read_index(&reader, io_flags); + } } Index* read_index(const char* fname, int io_flags) { - FileIOReader reader(fname); - Index* idx = read_index(&reader, io_flags); - return idx; + if ((io_flags & IO_FLAG_MMAP_IFC) == IO_FLAG_MMAP_IFC) { + // enable mmap-supporting IOReader + auto owner = std::make_shared(fname); + MappedFileIOReader reader(owner); + return read_index(&reader, io_flags); + } else { + FileIOReader reader(fname); + Index* idx = read_index(&reader, io_flags); + return idx; + } } VectorTransform* read_VectorTransform(const char* fname) { diff --git a/faiss/impl/io.h b/faiss/impl/io.h index 56a074fec1..ebd640fef5 100644 --- a/faiss/impl/io.h +++ b/faiss/impl/io.h @@ -16,12 +16,12 @@ #pragma once +#include +#include #include #include #include -#include - namespace faiss { struct IOReader { diff --git a/faiss/impl/io_macros.h b/faiss/impl/io_macros.h index c874ccf35c..b1eb88e1d5 100644 --- a/faiss/impl/io_macros.h +++ b/faiss/impl/io_macros.h @@ -7,6 +7,8 @@ #pragma once +#include + /************************************************************* * I/O macros * @@ -36,13 +38,14 @@ } // will fail if we write 256G of data at once... -#define READVECTOR(vec) \ - { \ - size_t size; \ - READANDCHECK(&size, 1); \ - FAISS_THROW_IF_NOT(size >= 0 && size < (uint64_t{1} << 40)); \ - (vec).resize(size); \ - READANDCHECK((vec).data(), size); \ +#define READVECTOR(vec) \ + { \ + static_assert(!faiss::is_maybe_owned_vector_v); \ + size_t size; \ + READANDCHECK(&size, 1); \ + FAISS_THROW_IF_NOT(size >= 0 && size < (uint64_t{1} << 40)); \ + (vec).resize(size); \ + READANDCHECK((vec).data(), size); \ } #define WRITEANDCHECK(ptr, n) \ @@ -76,12 +79,13 @@ WRITEANDCHECK((vec).data(), size * 4); \ } -#define READXBVECTOR(vec) \ - { \ - size_t size; \ - READANDCHECK(&size, 1); \ - FAISS_THROW_IF_NOT(size >= 0 && size < (uint64_t{1} << 40)); \ - size *= 4; \ - (vec).resize(size); \ - READANDCHECK((vec).data(), size); \ +#define READXBVECTOR(vec) \ + { \ + size_t size; \ + static_assert(!faiss::is_maybe_owned_vector_v); \ + READANDCHECK(&size, 1); \ + FAISS_THROW_IF_NOT(size >= 0 && size < (uint64_t{1} << 40)); \ + size *= 4; \ + (vec).resize(size); \ + READANDCHECK((vec).data(), size); \ } diff --git a/faiss/impl/mapped_io.cpp b/faiss/impl/mapped_io.cpp new file mode 100644 index 0000000000..2e23a8a47e --- /dev/null +++ b/faiss/impl/mapped_io.cpp @@ -0,0 +1,281 @@ +#include +#include + +#ifdef __linux__ + +#include +#include +#include +#include +#include + +#elif defined(_WIN32) + +#include +#include + +#endif + +#include + +#include +#include + +namespace faiss { + +#ifdef __linux__ + +struct MmappedFileMappingOwner::PImpl { + void* ptr = nullptr; + size_t ptr_size = 0; + + PImpl(const std::string& filename) { + auto f = std::unique_ptr( + fopen(filename.c_str(), "r"), &fclose); + FAISS_THROW_IF_NOT_FMT( + f.get(), + "could not open %s for reading: %s", + filename.c_str(), + strerror(errno)); + + // get the size + struct stat s; + int status = fstat(fileno(f.get()), &s); + FAISS_THROW_IF_NOT_FMT( + status >= 0, "fstat() failed: %s", strerror(errno)); + + const size_t filesize = s.st_size; + + void* address = mmap( + nullptr, filesize, PROT_READ, MAP_SHARED, fileno(f.get()), 0); + FAISS_THROW_IF_NOT_FMT( + address != nullptr, "could not mmap(): %s", strerror(errno)); + + // btw, fd can be closed here + + madvise(address, filesize, MADV_RANDOM); + + // save it + ptr = address; + ptr_size = filesize; + } + + PImpl(FILE* f) { + // get the size + struct stat s; + int status = fstat(fileno(f), &s); + FAISS_THROW_IF_NOT_FMT( + status >= 0, "fstat() failed: %s", strerror(errno)); + + const size_t filesize = s.st_size; + + void* address = + mmap(nullptr, filesize, PROT_READ, MAP_SHARED, fileno(f), 0); + FAISS_THROW_IF_NOT_FMT( + address != nullptr, "could not mmap(): %s", strerror(errno)); + + // btw, fd can be closed here + + madvise(address, filesize, MADV_RANDOM); + + // save it + ptr = address; + ptr_size = filesize; + } + + ~PImpl() { + // todo: check for an error + munmap(ptr, ptr_size); + } +}; + +#elif defined(_WIN32) + +struct MmappedFileMappingOwner::PImpl { + void* ptr = nullptr; + size_t ptr_size = 0; + + PImpl(const std::string& filename) { + HANDLE file_handle = CreateFile( + filename.c_str(), + GENERIC_READ, + FILE_SHARE_READ, + nullptr, + OPEN_EXISTING, + 0, + nullptr); + if (file_handle == INVALID_HANDLE_VALUE) { + const auto error = GetLastError(); + FAISS_THROW_FMT( + "could not open the file, %s (error %d)", + filename.c_str(), + error); + } + + // get the size of the file + LARGE_INTEGER len_li; + if (GetFileSizeEx(file_handle, &len_li) == 0) { + const auto error = GetLastError(); + FAISS_THROW_FMT( + "could not get the file size, %s (error %d)", + filename.c_str(), + error); + } + + // create a mapping + HANDLE mapping_handle = CreateFileMapping( + file_handle, nullptr, PAGE_READONLY, 0, 0, nullptr); + if (mapping_handle == 0) { + const auto error = GetLastError(); + FAISS_THROW_FMT( + "could not create a file mapping, %s (error %d)", + filename.c_str(), + error); + } + CloseHandle(file_handle); + + char* data = + (char*)MapViewOfFile(mapping_handle, FILE_MAP_READ, 0, 0, 0); + if (data == nullptr) { + const auto error = GetLastError(); + FAISS_THROW_FMT( + "could not get map the file, %s (error %d)", + filename.c_str(), + error); + } + + ptr = data; + ptr_size = len_li.QuadPart; + } + + PImpl(FILE* f) { + // obtain a HANDLE from a FILE + const int fd = _fileno(f); + if (fd == -1) { + // no good + FAISS_THROW_FMT("could not get a HANDLE"); + } + + HANDLE file_handle = (HANDLE)_get_osfhandle(fd); + if (file_handle == INVALID_HANDLE_VALUE) { + FAISS_THROW_FMT("could not get an OS HANDLE"); + } + + // get the size of the file + LARGE_INTEGER len_li; + if (GetFileSizeEx(file_handle, &len_li) == 0) { + const auto error = GetLastError(); + FAISS_THROW_FMT("could not get the file size (error %d)", error); + } + + // create a mapping + HANDLE mapping_handle = CreateFileMapping( + file_handle, nullptr, PAGE_READONLY, 0, 0, nullptr); + if (mapping_handle == 0) { + const auto error = GetLastError(); + FAISS_THROW_FMT( + "could not create a file mapping, (error %d)", error); + } + CloseHandle(file_handle); + + char* data = + (char*)MapViewOfFile(mapping_handle, FILE_MAP_READ, 0, 0, 0); + if (data == nullptr) { + const auto error = GetLastError(); + FAISS_THROW_FMT("could not get map the file, (error %d)", error); + } + + ptr = data; + ptr_size = len_li.QuadPart; + } + + ~PImpl() { + if (ptr != nullptr) { + UnmapViewOfFile(ptr); + CloseHandle(ptr); + + ptr = nullptr; + } + } +}; + +#else + +struct MmappedFileMappingOwner::PImpl { + PImpl(FILE* f) { + FAISS_THROW_FMT("Not implemented"); + } + + ~PImpl() { + FAISS_THROW_FMT("Not implemented"); + } +}; + +#endif + +MmappedFileMappingOwner::MmappedFileMappingOwner(const std::string& filename) { + p_impl = std::make_unique(filename); +} + +MmappedFileMappingOwner::MmappedFileMappingOwner(FILE* f) { + p_impl = std::make_unique(f); +} + +MmappedFileMappingOwner::~MmappedFileMappingOwner() = default; + +// +void* MmappedFileMappingOwner::data() const { + return p_impl->ptr; +} + +size_t MmappedFileMappingOwner::size() const { + return p_impl->ptr_size; +} + +MappedFileIOReader::MappedFileIOReader( + const std::shared_ptr& owner) + : mmap_owner(owner) {} + +// this operation performs a copy +size_t MappedFileIOReader::operator()(void* ptr, size_t size, size_t nitems) { + char* ptr_c = nullptr; + + const size_t actual_nitems = this->mmap((void**)&ptr_c, size, nitems); + if (actual_nitems > 0) { + memcpy(ptr, ptr_c, size * actual_nitems); + } + + return actual_nitems; +} + +// this operation returns a mmapped address, owned by mmap_owner +size_t MappedFileIOReader::mmap(void** ptr, size_t size, size_t nitems) { + if (size == 0) { + return nitems; + } + + size_t actual_size = size * nitems; + if (pos + size * nitems > mmap_owner->size()) { + actual_size = mmap_owner->size() - pos; + } + + size_t actual_nitems = (actual_size + size - 1) / size; + if (actual_nitems == 0) { + return 0; + } + + // get an address + *ptr = (void*)(reinterpret_cast(mmap_owner->data()) + pos); + + // alter pos + pos += size * actual_nitems; + + return actual_nitems; +} + +int MappedFileIOReader::filedescriptor() { + // todo + return -1; +} + +} // namespace faiss \ No newline at end of file diff --git a/faiss/impl/mapped_io.h b/faiss/impl/mapped_io.h new file mode 100644 index 0000000000..efbde1e2e7 --- /dev/null +++ b/faiss/impl/mapped_io.h @@ -0,0 +1,44 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace faiss { + +// holds a memory-mapped region over a file +struct MmappedFileMappingOwner : public MaybeOwnedVectorOwner { + MmappedFileMappingOwner(const std::string& filename); + MmappedFileMappingOwner(FILE* f); + ~MmappedFileMappingOwner(); + + void* data() const; + size_t size() const; + + struct PImpl; + std::unique_ptr p_impl; +}; + +// A deserializer that supports memory-mapped files. +// All de-allocations should happen as soon as the index gets destroyed, +// after all underlying the MaybeOwnerVector objects are destroyed. +struct MappedFileIOReader : IOReader { + std::shared_ptr mmap_owner; + + size_t pos = 0; + + MappedFileIOReader(const std::shared_ptr& owner); + + // perform a copy + size_t operator()(void* ptr, size_t size, size_t nitems) override; + // perform a quasi-read that returns a mmapped address, owned by mmap_owner, + // and updates the position + size_t mmap(void** ptr, size_t size, size_t nitems); + + int filedescriptor() override; +}; + +} // namespace faiss \ No newline at end of file diff --git a/faiss/impl/maybe_owned_vector.h b/faiss/impl/maybe_owned_vector.h new file mode 100644 index 0000000000..bca98d19c2 --- /dev/null +++ b/faiss/impl/maybe_owned_vector.h @@ -0,0 +1,232 @@ +#pragma once + +#include +#include +#include +#include + +#include + +namespace faiss { + +// An interface for an owner of a MaybeOwnedVector. +struct MaybeOwnedVectorOwner { + virtual ~MaybeOwnedVectorOwner() = default; +}; + +// a container that either works as std::vector that owns its own memory, +// or as a view of a memory buffer, with a known size +template +struct MaybeOwnedVector { + using value_type = T; + using self_type = MaybeOwnedVector; + using vec_iterator = typename std::vector::const_iterator; + + bool is_owned = true; + + // this one is used if is_owned == true + std::vector owned_data; + + // these three are used if is_owned == false + T* view_data = nullptr; + // the number of T elements + size_t view_size = 0; + // who owns the data. + // This field can be nullptr, and it is present ONLY in order + // to avoid possible tricky memory / resource leaks. + std::shared_ptr owner; + + // points either to view_data, or to owned.data() + T* c_ptr = nullptr; + // uses either view_size, or owned.size(); + size_t c_size = 0; + + MaybeOwnedVector() = default; + MaybeOwnedVector(const size_t initial_size) { + is_owned = true; + + owned_data.resize(initial_size); + c_ptr = owned_data.data(); + c_size = owned_data.size(); + } + + MaybeOwnedVector(const MaybeOwnedVector& other) { + is_owned = other.is_owned; + owned_data = other.owned_data; + + view_data = other.view_data; + view_size = other.view_size; + owner = other.owner; + + if (is_owned) { + c_ptr = owned_data.data(); + c_size = owned_data.size(); + } else { + c_ptr = view_data; + c_size = view_size; + } + } + + MaybeOwnedVector(MaybeOwnedVector&& other) { + is_owned = other.is_owned; + owned_data = std::move(other.owned_data); + + view_data = other.view_data; + view_size = other.view_size; + owner = std::move(other.owner); + other.owner = nullptr; + + if (is_owned) { + c_ptr = owned_data.data(); + c_size = owned_data.size(); + } else { + c_ptr = view_data; + c_size = view_size; + } + } + + MaybeOwnedVector& operator=(const MaybeOwnedVector& other) { + if (this == &other) { + return *this; + } + + // create a copy + MaybeOwnedVector cloned(other); + // swap + swap(*this, cloned); + + return *this; + } + + MaybeOwnedVector& operator=(MaybeOwnedVector&& other) { + if (this == &other) { + return *this; + } + + // moved + MaybeOwnedVector moved(std::move(other)); + // swap + swap(*this, moved); + + return *this; + } + + MaybeOwnedVector(std::vector&& other) { + is_owned = true; + + owned_data = std::move(other); + c_ptr = owned_data.data(); + c_size = owned_data.size(); + } + + static MaybeOwnedVector create_view( + void* address, + const size_t n_elements, + const std::shared_ptr& owner) { + MaybeOwnedVector vec; + vec.is_owned = false; + vec.view_data = reinterpret_cast(address); + vec.view_size = n_elements; + vec.owner = owner; + + vec.c_ptr = vec.view_data; + vec.c_size = vec.view_size; + + return vec; + } + + const T* data() const { + return c_ptr; + } + + T* data() { + return c_ptr; + } + + size_t size() const { + return c_size; + } + + T& operator[](const size_t idx) { + return c_ptr[idx]; + } + + const T& operator[](const size_t idx) const { + return c_ptr[idx]; + } + + vec_iterator begin() const { + FAISS_ASSERT_MSG( + is_owned, + "This operation cannot be performed on a viewed vector"); + + return owned_data.begin(); + } + + vec_iterator end() const { + FAISS_ASSERT_MSG( + is_owned, + "This operation cannot be performed on a viewed vector"); + + return owned_data.end(); + } + + vec_iterator erase(vec_iterator begin, vec_iterator end) { + FAISS_ASSERT_MSG( + is_owned, + "This operation cannot be performed on a viewed vector"); + + return owned_data.erase(begin, end); + } + + void clear() { + FAISS_ASSERT_MSG( + is_owned, + "This operation cannot be performed on a viewed vector"); + + owned_data.clear(); + c_ptr = owned_data.data(); + c_size = owned_data.size(); + } + + void resize(const size_t new_size) { + FAISS_ASSERT_MSG( + is_owned, + "This operation cannot be performed on a viewed vector"); + + owned_data.resize(new_size); + c_ptr = owned_data.data(); + c_size = owned_data.size(); + } + + void resize(const size_t new_size, const value_type v) { + FAISS_ASSERT_MSG( + is_owned, + "This operation cannot be performed on a viewed vector"); + + owned_data.resize(new_size, v); + c_ptr = owned_data.data(); + c_size = owned_data.size(); + } + + friend void swap(self_type& a, self_type& b) { + std::swap(a.is_owned, b.is_owned); + std::swap(a.owned_data, b.owned_data); + std::swap(a.view_data, b.view_data); + std::swap(a.view_size, b.view_size); + std::swap(a.owner, b.owner); + std::swap(a.c_ptr, b.c_ptr); + std::swap(a.c_size, b.c_size); + } +}; + +template +struct is_maybe_owned_vector : std::false_type {}; + +template +struct is_maybe_owned_vector> : std::true_type {}; + +template +inline constexpr bool is_maybe_owned_vector_v = is_maybe_owned_vector::value; + +} // namespace faiss \ No newline at end of file diff --git a/faiss/impl/zerocopy_io.cpp b/faiss/impl/zerocopy_io.cpp new file mode 100644 index 0000000000..31214ce3b2 --- /dev/null +++ b/faiss/impl/zerocopy_io.cpp @@ -0,0 +1,56 @@ +#include +#include + +namespace faiss { + +ZeroCopyIOReader::ZeroCopyIOReader(uint8_t* data, size_t size) + : data_(data), rp_(0), total_(size) {} + +ZeroCopyIOReader::~ZeroCopyIOReader() {} + +size_t ZeroCopyIOReader::get_data_view(void** ptr, size_t size, size_t nitems) { + if (size == 0) { + return nitems; + } + + size_t actual_size = size * nitems; + if (rp_ + size * nitems > total_) { + actual_size = total_ - rp_; + } + + size_t actual_nitems = (actual_size + size - 1) / size; + if (actual_nitems == 0) { + return 0; + } + + // get an address + *ptr = (void*)(reinterpret_cast(data_ + rp_)); + + // alter pos + rp_ += size * actual_nitems; + + return actual_nitems; +} + +void ZeroCopyIOReader::reset() { + rp_ = 0; +} + +size_t ZeroCopyIOReader::operator()(void* ptr, size_t size, size_t nitems) { + if (rp_ >= total_) { + return 0; + } + size_t nremain = (total_ - rp_) / size; + if (nremain < nitems) { + nitems = nremain; + } + memcpy(ptr, (data_ + rp_), size * nitems); + rp_ += size * nitems; + return nitems; +} + +int ZeroCopyIOReader::filedescriptor() { + return -1; // Indicating no file descriptor available for memory buffer +} + +} // namespace faiss \ No newline at end of file diff --git a/faiss/impl/zerocopy_io.h b/faiss/impl/zerocopy_io.h new file mode 100644 index 0000000000..7033b5fc51 --- /dev/null +++ b/faiss/impl/zerocopy_io.h @@ -0,0 +1,25 @@ +#pragma once + +#include + +#include + +namespace faiss { + +// ZeroCopyIOReader just maps the data from a given pointer. +struct ZeroCopyIOReader : public faiss::IOReader { + uint8_t* data_; + size_t rp_ = 0; + size_t total_ = 0; + + ZeroCopyIOReader(uint8_t* data, size_t size); + ~ZeroCopyIOReader(); + + void reset(); + size_t get_data_view(void** ptr, size_t size, size_t nitems); + size_t operator()(void* ptr, size_t size, size_t nitems) override; + + int filedescriptor() override; +}; + +} // namespace faiss \ No newline at end of file diff --git a/faiss/index_io.h b/faiss/index_io.h index b266712af7..f73109fe54 100644 --- a/faiss/index_io.h +++ b/faiss/index_io.h @@ -62,6 +62,10 @@ const int IO_FLAG_PQ_SKIP_SDC_TABLE = 32; // try to memmap data (useful to load an ArrayInvertedLists as an // OnDiskInvertedLists) const int IO_FLAG_MMAP = IO_FLAG_SKIP_IVF_DATA | 0x646f0000; +// mmap that handles codes for IndexFlatCodes-derived indices and HNSW. +// this is a temporary solution, it is expected to be merged with IO_FLAG_MMAP +// after OnDiskInvertedLists get properly updated. +const int IO_FLAG_MMAP_IFC = 1 << 9; Index* read_index(const char* fname, int io_flags = 0); Index* read_index(FILE* f, int io_flags = 0); diff --git a/faiss/python/array_conversions.py b/faiss/python/array_conversions.py index 0c57defe1e..b62c59e4ce 100644 --- a/faiss/python/array_conversions.py +++ b/faiss/python/array_conversions.py @@ -106,6 +106,13 @@ def vector_to_array(v): classname = v.__class__.__name__ if classname.startswith('AlignedTable'): return AlignedTable_to_array(v) + if classname.startswith('MaybeOwnedVector'): + dtype = np.dtype(vector_name_map[classname[16:]]) + a = np.empty(v.size(), dtype=dtype) + if v.size() > 0: + memcpy(swig_ptr(a), v.data(), a.nbytes) + return a + assert classname.endswith('Vector') dtype = np.dtype(vector_name_map[classname[:-6]]) a = np.empty(v.size(), dtype=dtype) @@ -122,6 +129,17 @@ def copy_array_to_vector(a, v): """ copy a numpy array to a vector """ n, = a.shape classname = v.__class__.__name__ + if classname.startswith('MaybeOwnedVector'): + assert v.is_owned, 'cannot copy to an non-owned MaybeOwnedVector' + dtype = np.dtype(vector_name_map[classname[16:]]) + assert dtype == a.dtype, ( + 'cannot copy a %s array to a %s (should be %s)' % ( + a.dtype, classname, dtype)) + v.resize(n) + if n > 0: + memcpy(v.data(), swig_ptr(a), a.nbytes) + return + assert classname.endswith('Vector') dtype = np.dtype(vector_name_map[classname[:-6]]) assert dtype == a.dtype, ( diff --git a/faiss/python/swigfaiss.swig b/faiss/python/swigfaiss.swig index 493e42ef0e..fc0e7c0df4 100644 --- a/faiss/python/swigfaiss.swig +++ b/faiss/python/swigfaiss.swig @@ -81,6 +81,11 @@ typedef uint64_t size_t; #endif +#include + +#include +#include +#include #include #include @@ -506,6 +511,14 @@ void gpu_sync_all_devices() %include +%include + +%ignore faiss::MmappedFileMappingOwner::p_impl; + +%include +%include +%include + %newobject *::get_FlatCodesDistanceComputer() const; %include %include @@ -992,6 +1005,10 @@ faiss::Quantizer * downcast_Quantizer (faiss::Quantizer *aq) %template(AlignedTableUint16) faiss::AlignedTable; %template(AlignedTableFloat32) faiss::AlignedTable; +%template(MaybeOwnedVectorUInt8) faiss::MaybeOwnedVector; +%template(MaybeOwnedVectorInt32) faiss::MaybeOwnedVector; +%template(MaybeOwnedVectorFloat32) faiss::MaybeOwnedVector; + // SWIG seems to have some trouble resolving function template types here, so // declare explicitly diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index dfab76e024..285b9090ed 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -36,6 +36,8 @@ set(FAISS_TEST_SRC test_callback.cpp test_utils.cpp test_hamming.cpp + test_mmap.cpp + test_zerocopy.cpp ) add_executable(faiss_test ${FAISS_TEST_SRC}) diff --git a/tests/test_mmap.cpp b/tests/test_mmap.cpp new file mode 100644 index 0000000000..c549499142 --- /dev/null +++ b/tests/test_mmap.cpp @@ -0,0 +1,146 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace { + +std::vector make_data(const size_t n, const size_t d, size_t seed) { + std::vector database(n * d); + std::mt19937 rng(seed); + std::uniform_real_distribution distrib; + + for (size_t i = 0; i < n * d; i++) { + database[i] = distrib(rng); + } + return database; +} + +} // namespace + +TEST(TestMmap, mmap_flatcodes) { + // the logic is the following: + // 1. generate two flatcodes-based indices, Index1 and Index2 + // 2. serialize both indices into std::vector<> buffers, Buf1 and Buf2 + // 3. save Buf1 into a temporary file, File1 + // 4. deserialize Index1 using mmap feature on File1 into Index1MM + // 5. ensure that Index1MM acts as Index2 if we write the data from Buf2 + // on top of the existing File1 + // 6. ensure that Index1MM acts as Index1 if we write the data from Buf1 + // on top of the existing File1 again + + // generate data + const size_t nt = 1000; + const size_t nq = 10; + const size_t d = 32; + const size_t k = 25; + + std::vector xt1 = make_data(nt, d, 123); + std::vector xt2 = make_data(nt, d, 456); + std::vector xq = make_data(nq, d, 789); + + // ensure that the data is different + ASSERT_NE(xt1, xt2); + + // make index1 and create reference results + faiss::IndexFlatL2 index1(d); + index1.train(nt, xt1.data()); + index1.add(nt, xt1.data()); + + std::vector ref_dis_1(k * nq); + std::vector ref_ids_1(k * nq); + index1.search(nq, xq.data(), k, ref_dis_1.data(), ref_ids_1.data()); + + // make index2 and create reference results + faiss::IndexFlatL2 index2(d); + index2.train(nt, xt2.data()); + index2.add(nt, xt2.data()); + + std::vector ref_dis_2(k * nq); + std::vector ref_ids_2(k * nq); + index2.search(nq, xq.data(), k, ref_dis_2.data(), ref_ids_2.data()); + + // ensure that the results are different + ASSERT_NE(ref_dis_1, ref_dis_2); + ASSERT_NE(ref_ids_1, ref_ids_2); + + // serialize both in a form of vectors + faiss::VectorIOWriter wr1; + faiss::write_index(&index1, &wr1); + + faiss::VectorIOWriter wr2; + faiss::write_index(&index2, &wr2); + + // generate a temporary file and write index1 into it + std::string tmpname = std::tmpnam(nullptr); + + { + std::ofstream ofs(tmpname); + ofs.write((const char*)wr1.data.data(), wr1.data.size()); + } + + // create a mmap index + std::unique_ptr index1mm( + faiss::read_index(tmpname.c_str(), faiss::IO_FLAG_MMAP_IFC)); + + ASSERT_NE(index1mm, nullptr); + + // perform a search + std::vector cand_dis_1(k * nq); + std::vector cand_ids_1(k * nq); + index1mm->search(nq, xq.data(), k, cand_dis_1.data(), cand_ids_1.data()); + + // match vs ref1 + ASSERT_EQ(ref_ids_1, cand_ids_1); + ASSERT_EQ(ref_dis_1, cand_dis_1); + + // ok now, overwrite the internals of the file without recreating it + { + std::ofstream ofs(tmpname); + ofs.seekp(0, std::ios::beg); + + ofs.write((const char*)wr2.data.data(), wr2.data.size()); + } + + // perform a search + std::vector cand_dis_2(k * nq); + std::vector cand_ids_2(k * nq); + index1mm->search(nq, xq.data(), k, cand_dis_2.data(), cand_ids_2.data()); + + // match vs ref1 + ASSERT_EQ(ref_ids_2, cand_ids_2); + ASSERT_EQ(ref_dis_2, cand_dis_2); + + // write back data1 + { + std::ofstream ofs(tmpname); + ofs.seekp(0, std::ios::beg); + + ofs.write((const char*)wr1.data.data(), wr1.data.size()); + } + + // perform a search + std::vector cand_dis_3(k * nq); + std::vector cand_ids_3(k * nq); + index1mm->search(nq, xq.data(), k, cand_dis_3.data(), cand_ids_3.data()); + + // match vs ref1 + ASSERT_EQ(ref_ids_1, cand_ids_3); + ASSERT_EQ(ref_dis_1, cand_dis_3); +} diff --git a/tests/test_zerocopy.cpp b/tests/test_zerocopy.cpp new file mode 100644 index 0000000000..33c4baefff --- /dev/null +++ b/tests/test_zerocopy.cpp @@ -0,0 +1,132 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace { + +std::vector make_data(const size_t n, const size_t d, size_t seed) { + std::vector database(n * d); + std::mt19937 rng(seed); + std::uniform_real_distribution distrib; + + for (size_t i = 0; i < n * d; i++) { + database[i] = distrib(rng); + } + return database; +} + +} // namespace + +TEST(TestZeroCopy, zerocopy_flatcodes) { + // the logic is the following: + // 1. generate two flatcodes-based indices, Index1 and Index2 + // 2. serialize both indices into std::vector<> buffers, Buf1 and Buf2 + // 3. deserialize Index1 using zero-copy feature on Buf1 into Index1ZC + // 4. ensure that Index1ZC acts as Index2 if we write the data from Buf2 + // on top of the existing Buf1 + + // generate data + const size_t nt = 1000; + const size_t nq = 10; + const size_t d = 32; + const size_t k = 25; + + std::vector xt1 = make_data(nt, d, 123); + std::vector xt2 = make_data(nt, d, 456); + std::vector xq = make_data(nq, d, 789); + + // ensure that the data is different + ASSERT_NE(xt1, xt2); + + // make index1 and create reference results + faiss::IndexFlatL2 index1(d); + index1.train(nt, xt1.data()); + index1.add(nt, xt1.data()); + + std::vector ref_dis_1(k * nq); + std::vector ref_ids_1(k * nq); + index1.search(nq, xq.data(), k, ref_dis_1.data(), ref_ids_1.data()); + + // make index2 and create reference results + faiss::IndexFlatL2 index2(d); + index2.train(nt, xt2.data()); + index2.add(nt, xt2.data()); + + std::vector ref_dis_2(k * nq); + std::vector ref_ids_2(k * nq); + index2.search(nq, xq.data(), k, ref_dis_2.data(), ref_ids_2.data()); + + // ensure that the results are different + ASSERT_NE(ref_dis_1, ref_dis_2); + ASSERT_NE(ref_ids_1, ref_ids_2); + + // serialize both in a form of vectors + faiss::VectorIOWriter wr1; + faiss::write_index(&index1, &wr1); + + faiss::VectorIOWriter wr2; + faiss::write_index(&index2, &wr2); + + ASSERT_EQ(wr1.data.size(), wr2.data.size()); + + // clone a buffer + std::vector buffer = wr1.data; + + // create a zero-copy index + faiss::ZeroCopyIOReader reader(buffer.data(), buffer.size()); + std::unique_ptr index1zc(faiss::read_index(&reader)); + + ASSERT_NE(index1zc, nullptr); + + // perform a search + std::vector cand_dis_1(k * nq); + std::vector cand_ids_1(k * nq); + index1zc->search(nq, xq.data(), k, cand_dis_1.data(), cand_ids_1.data()); + + // match vs ref1 + ASSERT_EQ(ref_ids_1, cand_ids_1); + ASSERT_EQ(ref_dis_1, cand_dis_1); + + // overwrite buffer without moving it + for (size_t i = 0; i < buffer.size(); i++) { + buffer[i] = wr2.data[i]; + } + + // perform a search + std::vector cand_dis_2(k * nq); + std::vector cand_ids_2(k * nq); + index1zc->search(nq, xq.data(), k, cand_dis_2.data(), cand_ids_2.data()); + + // match vs ref2 + ASSERT_EQ(ref_ids_2, cand_ids_2); + ASSERT_EQ(ref_dis_2, cand_dis_2); + + // overwrite again + for (size_t i = 0; i < buffer.size(); i++) { + buffer[i] = wr1.data[i]; + } + + // perform a search + std::vector cand_dis_3(k * nq); + std::vector cand_ids_3(k * nq); + index1zc->search(nq, xq.data(), k, cand_dis_3.data(), cand_ids_3.data()); + + // match vs ref1 + ASSERT_EQ(ref_ids_1, cand_ids_3); + ASSERT_EQ(ref_dis_1, cand_dis_3); +} From 10fd7e0a93c3077135d0708c878015d03229b815 Mon Sep 17 00:00:00 2001 From: Matthijs Douze Date: Mon, 10 Mar 2025 10:49:47 -0700 Subject: [PATCH 2/2] mem mapping and zero-copy python fixes (#4212) Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/4212 Add files to TARGETS fix python Reviewed By: mengdilin Differential Revision: D69984379 --- faiss/impl/mapped_io.cpp | 13 +++++++++--- faiss/impl/mapped_io.h | 9 +++++++- faiss/impl/maybe_owned_vector.h | 16 +++++++++++++- faiss/impl/zerocopy_io.cpp | 12 ++++++++++- faiss/impl/zerocopy_io.h | 9 +++++++- faiss/python/swigfaiss.swig | 1 + tests/test_fast_scan_ivf.py | 3 ++- tests/test_io.py | 37 +++++++++++++++++++++++++++++++++ 8 files changed, 92 insertions(+), 8 deletions(-) diff --git a/faiss/impl/mapped_io.cpp b/faiss/impl/mapped_io.cpp index 2e23a8a47e..a62e22dee1 100644 --- a/faiss/impl/mapped_io.cpp +++ b/faiss/impl/mapped_io.cpp @@ -1,3 +1,10 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + #include #include @@ -11,8 +18,8 @@ #elif defined(_WIN32) -#include -#include +#include // @manual +#include // @manual #endif @@ -278,4 +285,4 @@ int MappedFileIOReader::filedescriptor() { return -1; } -} // namespace faiss \ No newline at end of file +} // namespace faiss diff --git a/faiss/impl/mapped_io.h b/faiss/impl/mapped_io.h index efbde1e2e7..0e32df23d8 100644 --- a/faiss/impl/mapped_io.h +++ b/faiss/impl/mapped_io.h @@ -1,3 +1,10 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + #pragma once #include @@ -41,4 +48,4 @@ struct MappedFileIOReader : IOReader { int filedescriptor() override; }; -} // namespace faiss \ No newline at end of file +} // namespace faiss diff --git a/faiss/impl/maybe_owned_vector.h b/faiss/impl/maybe_owned_vector.h index bca98d19c2..2369c4ab4e 100644 --- a/faiss/impl/maybe_owned_vector.h +++ b/faiss/impl/maybe_owned_vector.h @@ -1,3 +1,10 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + #pragma once #include @@ -50,6 +57,13 @@ struct MaybeOwnedVector { c_size = owned_data.size(); } + explicit MaybeOwnedVector(const std::vector& vec) + : faiss::MaybeOwnedVector(vec.size()) { + if (vec.size() > 0) { + memcpy(owned_data.data(), vec.data(), sizeof(T) * vec.size()); + } + } + MaybeOwnedVector(const MaybeOwnedVector& other) { is_owned = other.is_owned; owned_data = other.owned_data; @@ -229,4 +243,4 @@ struct is_maybe_owned_vector> : std::true_type {}; template inline constexpr bool is_maybe_owned_vector_v = is_maybe_owned_vector::value; -} // namespace faiss \ No newline at end of file +} // namespace faiss diff --git a/faiss/impl/zerocopy_io.cpp b/faiss/impl/zerocopy_io.cpp index 31214ce3b2..c754f1f07b 100644 --- a/faiss/impl/zerocopy_io.cpp +++ b/faiss/impl/zerocopy_io.cpp @@ -1,3 +1,10 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + #include #include @@ -37,6 +44,9 @@ void ZeroCopyIOReader::reset() { } size_t ZeroCopyIOReader::operator()(void* ptr, size_t size, size_t nitems) { + if (size * nitems == 0) { + return 0; + } if (rp_ >= total_) { return 0; } @@ -53,4 +63,4 @@ int ZeroCopyIOReader::filedescriptor() { return -1; // Indicating no file descriptor available for memory buffer } -} // namespace faiss \ No newline at end of file +} // namespace faiss diff --git a/faiss/impl/zerocopy_io.h b/faiss/impl/zerocopy_io.h index 7033b5fc51..488b5d1e80 100644 --- a/faiss/impl/zerocopy_io.h +++ b/faiss/impl/zerocopy_io.h @@ -1,3 +1,10 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + #pragma once #include @@ -22,4 +29,4 @@ struct ZeroCopyIOReader : public faiss::IOReader { int filedescriptor() override; }; -} // namespace faiss \ No newline at end of file +} // namespace faiss diff --git a/faiss/python/swigfaiss.swig b/faiss/python/swigfaiss.swig index fc0e7c0df4..9cf81523d8 100644 --- a/faiss/python/swigfaiss.swig +++ b/faiss/python/swigfaiss.swig @@ -32,6 +32,7 @@ #pragma SWIG nowarn=341 #pragma SWIG nowarn=512 #pragma SWIG nowarn=362 +#pragma SWIG nowarn=509 // we need explict control of these typedefs... // %include diff --git a/tests/test_fast_scan_ivf.py b/tests/test_fast_scan_ivf.py index 75c9500f82..a1d6a21440 100644 --- a/tests/test_fast_scan_ivf.py +++ b/tests/test_fast_scan_ivf.py @@ -270,8 +270,9 @@ def test_equiv_pq(self): index_pq = faiss.index_factory(32, "PQ16x4np") index_pq.pq = index.pq index_pq.is_trained = True - index_pq.codes = faiss. downcast_InvertedLists( + codevec = faiss.downcast_InvertedLists( index.invlists).codes.at(0) + index_pq.codes = faiss.MaybeOwnedVectorUInt8(codevec) index_pq.ntotal = index.ntotal Dnew, Inew = index_pq.search(xq, 4) diff --git a/tests/test_io.py b/tests/test_io.py index 3cbd0a6e10..49dc3bf489 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -481,3 +481,40 @@ def test_reader(self): finally: if os.path.exists(fname): os.unlink(fname) + + +class TestIOFlatMMap(unittest.TestCase): + + def test_mmap(self): + xt, xb, xq = get_dataset_2(32, 0, 100, 50) + index = faiss.index_factory(32, "SQfp16", faiss.METRIC_L2) + # does not need training + index.add(xb) + Dref, Iref = index.search(xq, 10) + + fd, fname = tempfile.mkstemp() + os.close(fd) + try: + faiss.write_index(index, fname) + index2 = faiss.read_index(fname, faiss.IO_FLAG_MMAP_IFC) + Dnew, Inew = index2.search(xq, 10) + np.testing.assert_array_equal(Iref, Inew) + np.testing.assert_array_equal(Dref, Dnew) + finally: + if os.path.exists(fname): + os.unlink(fname) + + def test_zerocopy(self): + xt, xb, xq = get_dataset_2(32, 0, 100, 50) + index = faiss.index_factory(32, "SQfp16", faiss.METRIC_L2) + # does not need training + index.add(xb) + Dref, Iref = index.search(xq, 10) + + serialized_index = faiss.serialize_index(index) + reader = faiss.ZeroCopyIOReader( + faiss.swig_ptr(serialized_index), serialized_index.size) + index2 = faiss.read_index(reader) + Dnew, Inew = index2.search(xq, 10) + np.testing.assert_array_equal(Iref, Inew) + np.testing.assert_array_equal(Dref, Dnew)