Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 143 additions & 0 deletions faiss/IVFlib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
#include <faiss/IndexPreTransform.h>
#include <faiss/IndexRefine.h>
#include <faiss/MetaIndexes.h>
#include <faiss/clone_index.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/index_io.h>
#include <faiss/utils/distances.h>
#include <faiss/utils/hamming.h>
#include <faiss/utils/utils.h>

#include <memory>
Expand Down Expand Up @@ -519,5 +521,146 @@ void ivf_residual_add_from_flat_codes(
index->ntotal += nb;
}

/// Default sharding rule: the i-th centroid goes to shard (i mod shard_count).
/// shard_count must be non-zero, since the remainder of a division by zero is
/// undefined.
int64_t DefaultShardingFunction::operator()(int64_t i, int64_t shard_count) {
    const int64_t shard_id = i % shard_count;
    return shard_id;
}

/**
 * Distributes the coarse-quantizer centroids of an IndexIVF over shard_count
 * copies of the index and writes every copy to disk.
 *
 * Each shard is a clone of `index` whose quantizer has been reset and then
 * refilled with only the centroids that `sharding_function` assigns to it.
 *
 * @param index             source index; its quantizer holds the centroids.
 * @param shard_count       number of shards to produce; expected to be
 *                          positive.
 * @param filename_template printf-style template; the shard number is
 *                          substituted for its "%d" conversion.
 * @param sharding_function maps (centroid number, shard_count) to a shard id
 *                          in [0, shard_count); must not be null.
 * @throws FaissException if the sharding function returns an out-of-range
 *                        shard id or a generated filename does not fit the
 *                        internal buffer.
 */
void handle_ivf(
        faiss::IndexIVF* index,
        int64_t shard_count,
        const std::string& filename_template,
        ShardingFunction* sharding_function) {
    const int64_t d = index->quantizer->d;
    const int64_t centroid_count = index->quantizer->ntotal;

    // Template for all shards: a copy of the input with an empty quantizer.
    // Held by unique_ptr so it is released on every exit path.
    std::unique_ptr<faiss::IndexIVF> clone(
            static_cast<faiss::IndexIVF*>(faiss::clone_index(index)));
    clone->quantizer->reset();

    // The shards are owned by unique_ptr as well, so an exception thrown by
    // the sharding function, reconstruct(), add() or write_index() does not
    // leak them.
    std::vector<std::unique_ptr<faiss::IndexIVF>> sharded_indexes(shard_count);
    for (auto& shard : sharded_indexes) {
        shard.reset(static_cast<faiss::IndexIVF*>(
                faiss::clone_index(clone.get())));
    }

    // assign centroids to each sharded Index based on sharding_function, and
    // add them to the quantizer of each sharded index
    std::vector<std::vector<float>> sharded_centroids(shard_count);
    std::vector<float> reconstructed(d); // scratch buffer reused per centroid
    for (int64_t i = 0; i < centroid_count; i++) {
        const int64_t shard_id = (*sharding_function)(i, shard_count);
        // A user-supplied function may return anything; reject ids that
        // would index outside sharded_centroids.
        FAISS_THROW_IF_MSG(
                shard_id < 0 || shard_id >= shard_count,
                "Sharding function returned an out-of-range shard id.");
        index->quantizer->reconstruct(i, reconstructed.data());
        sharded_centroids[shard_id].insert(
                sharded_centroids[shard_id].end(),
                reconstructed.begin(),
                reconstructed.end());
    }
    for (int64_t i = 0; i < shard_count; i++) {
        sharded_indexes[i]->quantizer->add(
                sharded_centroids[i].size() / d, sharded_centroids[i].data());
    }

    for (int64_t i = 0; i < shard_count; i++) {
        char fname[256];
        // "%d" consumes an int; passing an int64_t through varargs is
        // undefined behaviour, hence the explicit cast.
        const int written = snprintf(
                fname,
                sizeof(fname),
                filename_template.c_str(),
                static_cast<int>(i));
        FAISS_THROW_IF_MSG(
                written < 0 || static_cast<size_t>(written) >= sizeof(fname),
                "Shard filename could not be formatted or is too long.");
        faiss::write_index(sharded_indexes[i].get(), fname);
    }
}

/**
 * Distributes the coarse-quantizer centroids of an IndexBinaryIVF over
 * shard_count copies of the index and writes every copy to disk.
 *
 * Each shard is a clone of `index` whose quantizer has been reset and then
 * refilled with only the centroids that `sharding_function` assigns to it.
 * A centroid occupies quantizer->d / 8 bytes.
 *
 * @param index             source index; its quantizer holds the centroids.
 * @param shard_count       number of shards to produce; expected to be
 *                          positive.
 * @param filename_template printf-style template; the shard number is
 *                          substituted for its "%d" conversion.
 * @param sharding_function maps (centroid number, shard_count) to a shard id
 *                          in [0, shard_count); must not be null.
 * @throws FaissException if the sharding function returns an out-of-range
 *                        shard id or a generated filename does not fit the
 *                        internal buffer.
 */
void handle_binary_ivf(
        faiss::IndexBinaryIVF* index,
        int64_t shard_count,
        const std::string& filename_template,
        ShardingFunction* sharding_function) {
    const int64_t reconstruction_size = index->quantizer->d / 8;
    const int64_t centroid_count = index->quantizer->ntotal;

    // Template for all shards: a copy of the input with an empty quantizer.
    // Held by unique_ptr so it is released on every exit path.
    std::unique_ptr<faiss::IndexBinaryIVF> clone(
            static_cast<faiss::IndexBinaryIVF*>(
                    faiss::clone_binary_index(index)));
    clone->quantizer->reset();

    // The shards are owned by unique_ptr as well, so an exception thrown by
    // the sharding function, reconstruct(), add() or write_index_binary()
    // does not leak them.
    std::vector<std::unique_ptr<faiss::IndexBinaryIVF>> sharded_indexes(
            shard_count);
    for (auto& shard : sharded_indexes) {
        shard.reset(static_cast<faiss::IndexBinaryIVF*>(
                faiss::clone_binary_index(clone.get())));
    }

    // assign centroids to each sharded Index based on sharding_function, and
    // add them to the quantizer of each sharded index
    std::vector<std::vector<uint8_t>> sharded_centroids(shard_count);
    // scratch buffer reused per centroid
    std::vector<uint8_t> reconstructed(reconstruction_size);
    for (int64_t i = 0; i < centroid_count; i++) {
        const int64_t shard_id = (*sharding_function)(i, shard_count);
        // A user-supplied function may return anything; reject ids that
        // would index outside sharded_centroids.
        FAISS_THROW_IF_MSG(
                shard_id < 0 || shard_id >= shard_count,
                "Sharding function returned an out-of-range shard id.");
        index->quantizer->reconstruct(i, reconstructed.data());
        sharded_centroids[shard_id].insert(
                sharded_centroids[shard_id].end(),
                reconstructed.begin(),
                reconstructed.end());
    }
    for (int64_t i = 0; i < shard_count; i++) {
        sharded_indexes[i]->quantizer->add(
                sharded_centroids[i].size() / reconstruction_size,
                sharded_centroids[i].data());
    }

    for (int64_t i = 0; i < shard_count; i++) {
        char fname[256];
        // "%d" consumes an int; passing an int64_t through varargs is
        // undefined behaviour, hence the explicit cast.
        const int written = snprintf(
                fname,
                sizeof(fname),
                filename_template.c_str(),
                static_cast<int>(i));
        FAISS_THROW_IF_MSG(
                written < 0 || static_cast<size_t>(written) >= sizeof(fname),
                "Shard filename could not be formatted or is too long.");
        faiss::write_index_binary(sharded_indexes[i].get(), fname);
    }
}

/**
 * Validates the sharding request and dispatches to handle_ivf or
 * handle_binary_ivf depending on IndexType.
 *
 * @tparam IndexType        faiss::IndexIVF or faiss::IndexBinaryIVF; for any
 *                          other type the function validates its arguments
 *                          and then does nothing.
 * @param index             index whose quantizer centroids are sharded.
 * @param shard_count       number of shards; must be positive.
 * @param filename_template printf-style template that must contain "%d".
 * @param sharding_function centroid-to-shard mapping; when null a
 *                          DefaultShardingFunction is used.
 * @throws FaissException if the quantizer holds no centroids, the template
 *                        lacks "%d", or shard_count is not positive.
 */
template <typename IndexType>
void sharding_helper(
        IndexType* index,
        int64_t shard_count,
        const std::string& filename_template,
        ShardingFunction* sharding_function) {
    FAISS_THROW_IF_MSG(index->quantizer->ntotal == 0, "No centroids to shard.");
    FAISS_THROW_IF_MSG(
            filename_template.find("%d") == std::string::npos,
            "Invalid filename_template. Must contain format specifier for shard count.");
    // Zero would make the default sharding function divide by zero and a
    // negative value cannot size the per-shard containers.
    FAISS_THROW_IF_MSG(shard_count <= 0, "shard_count must be positive.");

    // Lives on the stack for the duration of the call, so handing out its
    // address below is safe.
    DefaultShardingFunction default_sharding_function;
    if (sharding_function == nullptr) {
        sharding_function = &default_sharding_function;
    }

    if (typeid(IndexType) == typeid(faiss::IndexIVF)) {
        handle_ivf(
                dynamic_cast<faiss::IndexIVF*>(index),
                shard_count,
                filename_template,
                sharding_function);
    } else if (typeid(IndexType) == typeid(faiss::IndexBinaryIVF)) {
        handle_binary_ivf(
                dynamic_cast<faiss::IndexBinaryIVF*>(index),
                shard_count,
                filename_template,
                sharding_function);
    }
}

/// Public entry point for sharding the centroids of an IndexIVF; all of the
/// work is forwarded to sharding_helper.
void shard_ivf_index_centroids(
        faiss::IndexIVF* index,
        int64_t shard_count,
        const std::string& filename_template,
        ShardingFunction* sharding_function) {
    sharding_helper<faiss::IndexIVF>(
            index, shard_count, filename_template, sharding_function);
}

/// Public entry point for sharding the centroids of an IndexBinaryIVF; all of
/// the work is forwarded to sharding_helper.
void shard_binary_ivf_index_centroids(
        faiss::IndexBinaryIVF* index,
        int64_t shard_count,
        const std::string& filename_template,
        ShardingFunction* sharding_function) {
    sharding_helper<faiss::IndexBinaryIVF>(
            index, shard_count, filename_template, sharding_function);
}

} // namespace ivflib
} // namespace faiss
38 changes: 38 additions & 0 deletions faiss/IVFlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* IndexIVFs embedded within an IndexPreTransform.
*/

#include <faiss/IndexBinaryIVF.h>
#include <faiss/IndexIVF.h>
#include <vector>

Expand Down Expand Up @@ -167,6 +168,43 @@ void ivf_residual_add_from_flat_codes(
const uint8_t* codes,
int64_t code_size = -1);

/**
 * Abstract callable that decides which shard an item belongs to.
 * Implementations override operator() and return a shard id for item `i`
 * given the total number of shards.
 */
struct ShardingFunction {
    // Construction, copy and move.
    ShardingFunction() {}
    ShardingFunction(const ShardingFunction&) = default;
    ShardingFunction(ShardingFunction&&) = default;

    // Assignment.
    ShardingFunction& operator=(const ShardingFunction&) = default;
    ShardingFunction& operator=(ShardingFunction&&) = default;

    /// Returns the shard id chosen for item `i` out of `shard_count` shards.
    virtual int64_t operator()(int64_t i, int64_t shard_count) = 0;

    /// Virtual so that implementations can be destroyed through a base
    /// pointer.
    virtual ~ShardingFunction() = default;
};
/// Sharding function used when the caller does not supply one; its
/// operator() is defined out of line.
struct DefaultShardingFunction : public ShardingFunction {
    int64_t operator()(int64_t i, int64_t shard_count) override;
};

/**
* Shards the centroids of an IVF index with the given sharding function and
* writes each resulting shard index to a file whose name is built from
* filename_template. The centroids must already be added to the index
* quantizer.
*
* @param index The IVF index containing centroids to shard.
* @param shard_count Number of shards.
* @param filename_template Template for shard filenames; must contain "%d",
* which is replaced by the shard number.
* @param sharding_function The function to shard by. When null, the default is
* ith vector mod shard_count.
*/
void shard_ivf_index_centroids(
IndexIVF* index,
int64_t shard_count = 20,
const std::string& filename_template = "shard.%d.index",
ShardingFunction* sharding_function = nullptr);

/**
* Binary-index counterpart of shard_ivf_index_centroids: shards the centroids
* held by the quantizer of an IndexBinaryIVF and writes each shard index to a
* file whose name is built from filename_template.
*
* @param index The binary IVF index containing centroids to shard.
* @param shard_count Number of shards.
* @param filename_template Template for shard filenames; must contain "%d",
* which is replaced by the shard number.
* @param sharding_function The function to shard by. When null, the default is
* ith vector mod shard_count.
*/
void shard_binary_ivf_index_centroids(
faiss::IndexBinaryIVF* index,
int64_t shard_count = 20,
const std::string& filename_template = "shard.%d.index",
ShardingFunction* sharding_function = nullptr);

} // namespace ivflib
} // namespace faiss

Expand Down
34 changes: 34 additions & 0 deletions faiss/clone_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include <faiss/IndexAdditiveQuantizerFastScan.h>
#include <faiss/IndexBinary.h>
#include <faiss/IndexBinaryFlat.h>
#include <faiss/IndexBinaryHNSW.h>
#include <faiss/IndexBinaryIVF.h>
#include <faiss/IndexFlat.h>
#include <faiss/IndexHNSW.h>
#include <faiss/IndexIVF.h>
Expand Down Expand Up @@ -107,6 +109,11 @@ IndexIVF* Cloner::clone_IndexIVF(const IndexIVF* ivf) {
return nullptr;
}

// Attempts to copy an IndexBinaryIVF via the TRYCLONE macro.
// NOTE(review): TRYCLONE is defined outside this view; presumably it returns
// a copy when the dynamic type of `ivf` matches -- confirm against the macro.
// When TRYCLONE does not return, the function yields nullptr, so callers
// must be prepared for a null result.
IndexBinaryIVF* clone_IndexBinaryIVF(const IndexBinaryIVF* ivf) {
TRYCLONE(IndexBinaryIVF, ivf)
return nullptr;
}

IndexRefine* clone_IndexRefine(const IndexRefine* ir) {
TRYCLONE(IndexRefineFlat, ir)
TRYCLONE(IndexRefine, ir) {
Expand All @@ -131,6 +138,11 @@ IndexHNSW* clone_IndexHNSW(const IndexHNSW* ihnsw) {
}
}

// Attempts to copy an IndexBinaryHNSW via the TRYCLONE macro.
// NOTE(review): TRYCLONE is defined outside this view; presumably it returns
// a copy when the dynamic type of `ihnsw` matches -- confirm against the
// macro. When TRYCLONE does not return, the function yields nullptr, so
// callers must be prepared for a null result.
IndexBinaryHNSW* clone_IndexBinaryHNSW(const IndexBinaryHNSW* ihnsw) {
TRYCLONE(IndexBinaryHNSW, ihnsw)
return nullptr;
}

IndexNNDescent* clone_IndexNNDescent(const IndexNNDescent* innd) {
TRYCLONE(IndexNNDescentFlat, innd)
TRYCLONE(IndexNNDescent, innd) {
Expand Down Expand Up @@ -385,6 +397,28 @@ Quantizer* clone_Quantizer(const Quantizer* quant) {
IndexBinary* clone_binary_index(const IndexBinary* index) {
if (auto ii = dynamic_cast<const IndexBinaryFlat*>(index)) {
return new IndexBinaryFlat(*ii);
} else if (
const IndexBinaryIVF* ivf =
dynamic_cast<const IndexBinaryIVF*>(index)) {
IndexBinaryIVF* res = clone_IndexBinaryIVF(ivf);
if (ivf->invlists == nullptr) {
res->invlists = nullptr;
} else {
res->invlists = clone_InvertedLists(ivf->invlists);
res->own_invlists = true;
}

res->own_fields = true;
res->quantizer = clone_binary_index(ivf->quantizer);

return res;
} else if (
const IndexBinaryHNSW* ihnsw =
dynamic_cast<const IndexBinaryHNSW*>(index)) {
IndexBinaryHNSW* res = clone_IndexBinaryHNSW(ihnsw);
res->own_fields = true;
res->storage = clone_binary_index(ihnsw->storage);
return res;
} else {
FAISS_THROW_MSG("cannot clone this type of index");
}
Expand Down
3 changes: 2 additions & 1 deletion faiss/python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
class_wrappers.handle_Linear(Linear)
class_wrappers.handle_QINCo(QINCo)
class_wrappers.handle_QINCoStep(QINCoStep)
# Rebind shard_ivf_index_centroids to the wrapper returned by
# class_wrappers.handle_shard_ivf_index_centroids, which converts a Python
# callable passed as the fourth positional argument into a
# PyCallbackShardingFunction before calling the underlying function.
shard_ivf_index_centroids = class_wrappers.handle_shard_ivf_index_centroids(shard_ivf_index_centroids)


this_module = sys.modules[__name__]
Expand Down Expand Up @@ -170,7 +171,7 @@ def replacement_function(*args):
add_ref_in_constructor(GpuIndexIVFPQ, 1)
add_ref_in_constructor(GpuIndexIVFScalarQuantizer, 1)
except NameError as e:
logger.info("Failed to load GPU Faiss: %s. Will not load constructor refs for GPU indexes." % e.args[0])
logger.info("Failed to load GPU Faiss: %s. Will not load constructor refs for GPU indexes. This is only an error if you're trying to use GPU Faiss." % e.args[0])

add_ref_in_constructor(IndexIVFFlat, 0)
add_ref_in_constructor(IndexIVFFlatDedup, 0)
Expand Down
9 changes: 9 additions & 0 deletions faiss/python/class_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1395,3 +1395,12 @@ def from_torch(self, qinco):

the_class.__init__ = replacement_init
the_class.from_torch = from_torch


def handle_shard_ivf_index_centroids(func):
    """Wrap a sharding entry point so it accepts a plain Python callable.

    The returned wrapper replaces the sharding function argument with a
    ``faiss.PyCallbackShardingFunction`` built from it, then forwards every
    argument to ``func`` unchanged. The sharding function is recognised
    either as the fourth positional argument or as the ``sharding_function``
    keyword argument; ``None`` is passed through untouched so the callee can
    apply its own default.

    Parameters
    ----------
    func : callable
        The function to wrap.

    Returns
    -------
    callable
        A wrapper with the same name and docstring as ``func``.
    """
    import functools

    # functools.wraps keeps func's name and docstring visible on the wrapper.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        args = list(args)
        if len(args) > 3 and args[3] is not None:
            args[3] = faiss.PyCallbackShardingFunction(args[3])
        elif kwargs.get("sharding_function") is not None:
            # A callback supplied by keyword needs the same conversion as a
            # positional one; otherwise the raw callable reaches func.
            kwargs["sharding_function"] = faiss.PyCallbackShardingFunction(
                kwargs["sharding_function"])
        return func(*args, **kwargs)
    return wrapper
24 changes: 24 additions & 0 deletions faiss/python/python_callbacks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,27 @@ PyCallbackIDSelector::~PyCallbackIDSelector() {
PyThreadLock gil;
Py_DECREF(callback);
}

/***********************************************************
* Callbacks for IVF index sharding
***********************************************************/

// Stores the Python callable and takes a reference to it so that it stays
// alive for as long as this object does.
PyCallbackShardingFunction::PyCallbackShardingFunction(PyObject* callback)
        : callback(callback) {
    // PyThreadLock is held while the reference count is modified.
    PyThreadLock gil;
    Py_INCREF(this->callback);
}

// Calls the stored Python callable as callback(i, shard_count) and returns
// its result converted to a 64-bit integer.
//
// Throws a Faiss exception with the message "propagate py error" when the
// call itself fails or when the result cannot be converted to an integer.
int64_t PyCallbackShardingFunction::operator()(int64_t i, int64_t shard_count) {
    PyThreadLock gil;
    // The "L" format unit reads a long long from the varargs, so the
    // arguments are cast explicitly rather than passed as int64_t.
    PyObject* shard_id = PyObject_CallFunction(
            callback,
            "LL",
            static_cast<long long>(i),
            static_cast<long long>(shard_count));
    if (shard_id == nullptr) {
        FAISS_THROW_MSG("propagate py error");
    }
    const long long result = PyLong_AsLongLong(shard_id);
    // PyObject_CallFunction returns a new reference; release it so that a
    // reference is not leaked on every call.
    Py_DECREF(shard_id);
    // -1 is also a legitimate value, so the error indicator has to be
    // consulted to tell a failed conversion apart.
    if (result == -1 && PyErr_Occurred()) {
        FAISS_THROW_MSG("propagate py error");
    }
    return result;
}

// Drops the reference to the Python callable taken by the constructor.
PyCallbackShardingFunction::~PyCallbackShardingFunction() {
    // PyThreadLock is held while the reference count is modified.
    PyThreadLock gil;
    Py_DECREF(this->callback);
}
22 changes: 22 additions & 0 deletions faiss/python/python_callbacks.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#pragma once

#include <faiss/IVFlib.h>
#include <faiss/impl/IDSelector.h>
#include <faiss/impl/io.h>
#include <faiss/invlists/InvertedLists.h>
Expand Down Expand Up @@ -58,3 +59,24 @@ struct PyCallbackIDSelector : faiss::IDSelector {

~PyCallbackIDSelector() override;
};

/***********************************************************
* Callbacks for IVF index sharding
***********************************************************/

/**
 * ShardingFunction that delegates to a Python callable.
 *
 * The object holds a reference to `callback`, taken in the constructor and
 * released in the destructor. Copying or moving the object would duplicate
 * the raw pointer without adjusting the reference count, so the destructor
 * would release the same reference more than once; all copy and move
 * operations are therefore disabled.
 */
struct PyCallbackShardingFunction : faiss::ivflib::ShardingFunction {
    /// The Python callable invoked by operator().
    PyObject* callback;

    explicit PyCallbackShardingFunction(PyObject* callback);

    /// Returns the shard id computed by the Python callable for item `i`.
    int64_t operator()(int64_t i, int64_t shard_count) override;

    ~PyCallbackShardingFunction() override;

    PyCallbackShardingFunction(const PyCallbackShardingFunction&) = delete;
    PyCallbackShardingFunction(PyCallbackShardingFunction&&) = delete;
    PyCallbackShardingFunction& operator=(const PyCallbackShardingFunction&) =
            delete;
    PyCallbackShardingFunction& operator=(PyCallbackShardingFunction&&) =
            delete;
};
Loading
Loading