Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions c_api/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ set(FAISS_C_SRC
index_factory_c.cpp
index_io_c.cpp
impl/AuxIndexStructures_c.cpp
impl/io_c.cpp
utils/distances_c.cpp
utils/utils_c.cpp
)
Expand Down
76 changes: 76 additions & 0 deletions c_api/impl/io_c.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

// -*- c++ -*-

#include "io_c.h"
#include <faiss/impl/io.h>
#include "../macros_impl.h"

using faiss::IOReader;
using faiss::IOWriter;

struct CustomIOReader : IOReader {
size_t (*func)(void* ptr, size_t size, size_t nitems) = nullptr;

CustomIOReader(size_t (*func_in)(void* ptr, size_t size, size_t nitems));

size_t operator()(void* ptr, size_t size, size_t nitems) override;
};

CustomIOReader::CustomIOReader(
size_t (*func_in)(void* ptr, size_t size, size_t nitems))
: func(func_in) {}

size_t CustomIOReader::operator()(void* ptr, size_t size, size_t nitems) {
return func(ptr, size, nitems);
}

int faiss_CustomIOReader_new(
FaissCustomIOReader** p_out,
size_t (*func_in)(void* ptr, size_t size, size_t nitems)) {
try {
*p_out = reinterpret_cast<FaissCustomIOReader*>(
new CustomIOReader(func_in));
}
CATCH_AND_HANDLE
}

void faiss_CustomIOReader_free(FaissCustomIOReader* obj) {
delete reinterpret_cast<CustomIOReader*>(obj);
}

struct CustomIOWriter : IOWriter {
size_t (*func)(const void* ptr, size_t size, size_t nitems) = nullptr;

CustomIOWriter(
size_t (*func_in)(const void* ptr, size_t size, size_t nitems));

size_t operator()(const void* ptr, size_t size, size_t nitems) override;
};

CustomIOWriter::CustomIOWriter(
size_t (*func_in)(const void* ptr, size_t size, size_t nitems))
: func(func_in) {}

size_t CustomIOWriter::operator()(const void* ptr, size_t size, size_t nitems) {
return func(ptr, size, nitems);
}

int faiss_CustomIOWriter_new(
FaissCustomIOWriter** p_out,
size_t (*func_in)(const void* ptr, size_t size, size_t nitems)) {
try {
*p_out = reinterpret_cast<FaissCustomIOWriter*>(
new CustomIOWriter(func_in));
}
CATCH_AND_HANDLE
}

void faiss_CustomIOWriter_free(FaissCustomIOWriter* obj) {
delete reinterpret_cast<CustomIOWriter*>(obj);
}
50 changes: 50 additions & 0 deletions c_api/impl/io_c.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

// -*- c -*-

#ifndef FAISS_IO_C_H
#define FAISS_IO_C_H

#include <stddef.h>
#include "../faiss_c.h"

#ifdef __cplusplus
extern "C" {
#endif

FAISS_DECLARE_CLASS(IOReader)
FAISS_DECLARE_DESTRUCTOR(IOReader)

FAISS_DECLARE_CLASS(IOWriter)
FAISS_DECLARE_DESTRUCTOR(IOWriter)

/*******************************************************
* Custom reader + writer
*
* Reader and writer which wraps a function pointer,
* primarily for FFI use.
*******************************************************/

FAISS_DECLARE_CLASS(CustomIOReader)
FAISS_DECLARE_DESTRUCTOR(CustomIOReader)

int faiss_CustomIOReader_new(
FaissCustomIOReader** p_out,
size_t (*func_in)(void* ptr, size_t size, size_t nitems));

FAISS_DECLARE_CLASS(CustomIOWriter)
FAISS_DECLARE_DESTRUCTOR(CustomIOWriter)

int faiss_CustomIOWriter_new(
FaissCustomIOWriter** p_out,
size_t (*func_in)(const void* ptr, size_t size, size_t nitems));

#ifdef __cplusplus
}
#endif
#endif
50 changes: 50 additions & 0 deletions c_api/index_io_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

using faiss::Index;
using faiss::IndexBinary;
using faiss::IOReader;
using faiss::IOWriter;
using faiss::VectorTransform;

int faiss_write_index(const FaissIndex* idx, FILE* f) {
Expand All @@ -30,6 +32,19 @@ int faiss_write_index_fname(const FaissIndex* idx, const char* fname) {
CATCH_AND_HANDLE
}

int faiss_write_index_custom(
const FaissIndex* idx,
FaissIOWriter* io_writer,
int io_flags) {
try {
faiss::write_index(
reinterpret_cast<const Index*>(idx),
reinterpret_cast<IOWriter*>(io_writer),
io_flags);
}
CATCH_AND_HANDLE
}

int faiss_read_index(FILE* f, int io_flags, FaissIndex** p_out) {
try {
auto out = faiss::read_index(f, io_flags);
Expand All @@ -49,6 +64,18 @@ int faiss_read_index_fname(
CATCH_AND_HANDLE
}

int faiss_read_index_custom(
FaissIOReader* io_reader,
int io_flags,
FaissIndex** p_out) {
try {
auto out = faiss::read_index(
reinterpret_cast<IOReader*>(io_reader), io_flags);
*p_out = reinterpret_cast<FaissIndex*>(out);
}
CATCH_AND_HANDLE
}

int faiss_write_index_binary(const FaissIndexBinary* idx, FILE* f) {
try {
faiss::write_index_binary(reinterpret_cast<const IndexBinary*>(idx), f);
Expand All @@ -66,6 +93,17 @@ int faiss_write_index_binary_fname(
CATCH_AND_HANDLE
}

int faiss_write_index_binary_custom(
const FaissIndexBinary* idx,
FaissIOWriter* io_writer) {
try {
faiss::write_index_binary(
reinterpret_cast<const IndexBinary*>(idx),
reinterpret_cast<IOWriter*>(io_writer));
}
CATCH_AND_HANDLE
}

int faiss_read_index_binary(FILE* f, int io_flags, FaissIndexBinary** p_out) {
try {
auto out = faiss::read_index_binary(f, io_flags);
Expand All @@ -85,6 +123,18 @@ int faiss_read_index_binary_fname(
CATCH_AND_HANDLE
}

int faiss_read_index_binary_custom(
FaissIOReader* io_reader,
int io_flags,
FaissIndexBinary** p_out) {
try {
auto out = faiss::read_index_binary(
reinterpret_cast<IOReader*>(io_reader), io_flags);
*p_out = reinterpret_cast<FaissIndexBinary*>(out);
}
CATCH_AND_HANDLE
}

int faiss_read_VectorTransform_fname(
const char* fname,
FaissVectorTransform** p_out) {
Expand Down
28 changes: 28 additions & 0 deletions c_api/index_io_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "Index_c.h"
#include "VectorTransform_c.h"
#include "faiss_c.h"
#include "impl/io_c.h"

#ifdef __cplusplus
extern "C" {
Expand All @@ -32,6 +33,13 @@ int faiss_write_index(const FaissIndex* idx, FILE* f);
*/
int faiss_write_index_fname(const FaissIndex* idx, const char* fname);

/** Write index to a custom writer.
*/
int faiss_write_index_custom(
const FaissIndex* idx,
FaissIOWriter* io_writer,
int io_flags);

#define FAISS_IO_FLAG_MMAP 1
#define FAISS_IO_FLAG_READ_ONLY 2

Expand All @@ -45,6 +53,13 @@ int faiss_read_index(FILE* f, int io_flags, FaissIndex** p_out);
*/
int faiss_read_index_fname(const char* fname, int io_flags, FaissIndex** p_out);

/** Read index from a custom reader.
*/
int faiss_read_index_custom(
FaissIOReader* io_reader,
int io_flags,
FaissIndex** p_out);

/** Write index to a file.
* This is equivalent to `faiss::write_index_binary` when a file descriptor is
* provided.
Expand All @@ -59,6 +74,12 @@ int faiss_write_index_binary_fname(
const FaissIndexBinary* idx,
const char* fname);

/** Write binary index to a custom writer.
*/
int faiss_write_index_binary_custom(
const FaissIndexBinary* idx,
FaissIOWriter* io_writer);

/** Read index from a file.
* This is equivalent to `faiss:read_index_binary` when a file descriptor is
* given.
Expand All @@ -73,6 +94,13 @@ int faiss_read_index_binary_fname(
int io_flags,
FaissIndexBinary** p_out);

/** Read binary index from a custom reader.
*/
int faiss_read_index_binary_custom(
FaissIOReader* io_reader,
int io_flags,
FaissIndexBinary** p_out);

/** Read vector transform from a file.
* This is equivalent to `faiss:read_VectorTransform` when a file path is given.
*/
Expand Down
Loading