diff --git a/faiss/IndexHNSW.cpp b/faiss/IndexHNSW.cpp index ed7824ebbb..5983e9d831 100644 --- a/faiss/IndexHNSW.cpp +++ b/faiss/IndexHNSW.cpp @@ -210,7 +210,9 @@ IndexHNSW::IndexHNSW(int d, int M, MetricType metric) : Index(d, metric), hnsw(M) {} IndexHNSW::IndexHNSW(Index* storage, int M) - : Index(storage->d, storage->metric_type), hnsw(M), storage(storage) {} + : Index(storage->d, storage->metric_type), hnsw(M), storage(storage) { + metric_arg = storage->metric_arg; +} IndexHNSW::~IndexHNSW() { if (own_fields) { diff --git a/tests/test_hnsw.cpp b/tests/test_hnsw.cpp index 9424bd3499..9c33c08a9e 100644 --- a/tests/test_hnsw.cpp +++ b/tests/test_hnsw.cpp @@ -193,6 +193,27 @@ TEST(HNSW, Test_popmin_infinite_distances) { } } +TEST(HNSW, Test_IndexHNSW_METRIC_Lp) { + // Create an HNSW index with METRIC_Lp and metric_arg = 3 + faiss::IndexFlat storage_index(1, faiss::METRIC_Lp); + storage_index.metric_arg = 3; + faiss::IndexHNSW index(&storage_index, 32); + + // Add a single data point + float data[1] = {0.0}; + index.add(1, data); + + // Prepare a query + float query[1] = {2.0}; + float distance; + faiss::idx_t label; + + index.search(1, query, 1, &distance, &label); + + EXPECT_NEAR(distance, 8.0, 1e-5); // Distance should be 8.0 (2^3) + EXPECT_EQ(label, 0); // Label should be 0 +} + class HNSWTest : public testing::Test { protected: HNSWTest() {