Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions faiss/Index.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#define FAISS_INDEX_H

#include <faiss/MetricType.h>
#include <faiss/impl/FaissAssert.h>

#include <cstdio>
#include <sstream>
#include <string>
Expand Down Expand Up @@ -56,6 +58,23 @@ struct IDSelector;
struct RangeSearchResult;
struct DistanceComputer;

enum NumericType {
Float32,
Float16,
};

inline size_t get_numeric_type_size(NumericType numeric_type) {
switch (numeric_type) {
case NumericType::Float32:
return 4;
case NumericType::Float16:
return 2;
default:
FAISS_THROW_MSG(
"Unknown Numeric Type. Only supports Float32, Float16");
}
}

/** Parent class for the optional search paramenters.
*
* Sub-classes with additional search parameters should inherit this class.
Expand Down Expand Up @@ -107,6 +126,14 @@ struct Index {
*/
virtual void train(idx_t n, const float* x);

virtual void train(idx_t n, const void* x, NumericType numeric_type) {
if (numeric_type == NumericType::Float32) {
train(n, static_cast<const float*>(x));
} else {
FAISS_THROW_MSG("Index::train: unsupported numeric type");
}
}

/** Add n vectors of dimension d to the index.
*
* Vectors are implicitly assigned labels ntotal .. ntotal + n - 1
Expand All @@ -117,6 +144,14 @@ struct Index {
*/
virtual void add(idx_t n, const float* x) = 0;

virtual void add(idx_t n, const void* x, NumericType numeric_type) {
if (numeric_type == NumericType::Float32) {
add(n, static_cast<const float*>(x));
} else {
FAISS_THROW_MSG("Index::add: unsupported numeric type");
}
}

/** Same as add, but stores xids instead of sequential ids.
*
* The default implementation fails with an assertion, as it is
Expand All @@ -127,6 +162,17 @@ struct Index {
* @param xids if non-null, ids to store for the vectors (size n)
*/
virtual void add_with_ids(idx_t n, const float* x, const idx_t* xids);
virtual void add_with_ids(
idx_t n,
const void* x,
NumericType numeric_type,
const idx_t* xids) {
if (numeric_type == NumericType::Float32) {
add_with_ids(n, static_cast<const float*>(x), xids);
} else {
FAISS_THROW_MSG("Index::add_with_ids: unsupported numeric type");
}
}

/** query n vectors of dimension d to the index.
*
Expand All @@ -147,6 +193,26 @@ struct Index {
idx_t* labels,
const SearchParameters* params = nullptr) const = 0;

virtual void search(
idx_t n,
const void* x,
NumericType numeric_type,
idx_t k,
float* distances,
idx_t* labels,
const SearchParameters* params = nullptr) const {
if (numeric_type == NumericType::Float32) {
search(n,
static_cast<const float*>(x),
k,
distances,
labels,
params);
} else {
FAISS_THROW_MSG("Index::search: unsupported numeric type");
}
}

/** query n vectors of dimension d to the index.
*
* return all vectors with distance < radius. Note that many
Expand Down
37 changes: 31 additions & 6 deletions faiss/IndexHNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <random>

#include <cstdint>
#include "faiss/Index.h"

#include <faiss/Index2Layer.h>
#include <faiss/IndexFlat.h>
Expand Down Expand Up @@ -893,15 +894,31 @@ IndexHNSWCagra::IndexHNSWCagra() {
is_trained = true;
}

IndexHNSWCagra::IndexHNSWCagra(int d, int M, MetricType metric)
: IndexHNSW(
(metric == METRIC_L2)
? static_cast<IndexFlat*>(new IndexFlatL2(d))
: static_cast<IndexFlat*>(new IndexFlatIP(d)),
M) {
IndexHNSWCagra::IndexHNSWCagra(
int d,
int M,
MetricType metric,
NumericType numeric_type)
: IndexHNSW(d, M, metric) {
FAISS_THROW_IF_NOT_MSG(
((metric == METRIC_L2) || (metric == METRIC_INNER_PRODUCT)),
"unsupported metric type for IndexHNSWCagra");
numeric_type_ = numeric_type;
if (numeric_type == NumericType::Float32) {
// Use flat storage with full precision for fp32
storage = (metric == METRIC_L2)
? static_cast<Index*>(new IndexFlatL2(d))
: static_cast<Index*>(new IndexFlatIP(d));
} else if (numeric_type == NumericType::Float16) {
auto qtype = ScalarQuantizer::QT_fp16;
storage = new IndexScalarQuantizer(d, qtype, metric);
} else {
FAISS_THROW_MSG(
"Unsupported numeric_type: only F16 and F32 are supported for IndexHNSWCagra");
}

metric_arg = storage->metric_arg;

own_fields = true;
is_trained = true;
init_level0 = true;
Expand Down Expand Up @@ -967,4 +984,12 @@ void IndexHNSWCagra::search(
}
}

faiss::NumericType IndexHNSWCagra::get_numeric_type() const {
return numeric_type_;
}

void IndexHNSWCagra::set_numeric_type(faiss::NumericType numeric_type) {
numeric_type_ = numeric_type;
}

} // namespace faiss
11 changes: 10 additions & 1 deletion faiss/IndexHNSW.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#pragma once

#include <vector>
#include "faiss/Index.h"

#include <faiss/IndexFlat.h>
#include <faiss/IndexPQ.h>
Expand Down Expand Up @@ -170,7 +171,11 @@ struct IndexHNSW2Level : IndexHNSW {

struct IndexHNSWCagra : IndexHNSW {
IndexHNSWCagra();
IndexHNSWCagra(int d, int M, MetricType metric = METRIC_L2);
IndexHNSWCagra(
int d,
int M,
MetricType metric = METRIC_L2,
NumericType numeric_type = NumericType::Float32);

/// When set to true, the index is immutable.
/// This option is used to copy the knn graph from GpuIndexCagra
Expand All @@ -195,6 +200,10 @@ struct IndexHNSWCagra : IndexHNSW {
float* distances,
idx_t* labels,
const SearchParameters* params = nullptr) const override;

faiss::NumericType get_numeric_type() const;
void set_numeric_type(faiss::NumericType numeric_type);
NumericType numeric_type_;
};

} // namespace faiss
5 changes: 4 additions & 1 deletion faiss/gpu/GpuCloner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ Index* ToCPUCloner::clone_Index(const Index* index) {
#if defined USE_NVIDIA_CUVS
else if (auto icg = dynamic_cast<const GpuIndexCagra*>(index)) {
IndexHNSWCagra* res = new IndexHNSWCagra();
if (icg->get_numeric_type() == faiss::NumericType::Float16) {
res->base_level_only = true;
}
icg->copyTo(res);
return res;
}
Expand Down Expand Up @@ -236,7 +239,7 @@ Index* ToGpuCloner::clone_Index(const Index* index) {
config.device = device;
GpuIndexCagra* res =
new GpuIndexCagra(provider, icg->d, icg->metric_type, config);
res->copyFrom(icg);
res->copyFrom(icg, icg->get_numeric_type());
return res;
}
#endif
Expand Down
Loading
Loading