diff --git a/src/search/hnsw_indexer.cc b/src/search/hnsw_indexer.cc
new file mode 100644
index 00000000000..f03e4c9580c
--- /dev/null
+++ b/src/search/hnsw_indexer.cc
@@ -0,0 +1,606 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#include "hnsw_indexer.h"
+
+#include <fmt/format.h>
+
+#include <algorithm>
+#include <cmath>
+#include <limits>
+#include <memory>
+#include <queue>
+#include <random>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "db_util.h"
+
+namespace redis {
+
+HnswNode::HnswNode(NodeKey key, uint16_t level) : key(std::move(key)), level(level) {}
+
+// Loads this node's record for its level from the Search column family and decodes it.
+// Returns NotOK (carrying the RocksDB message) when the read or the decoding fails.
+StatusOr<HnswNodeFieldMetadata> HnswNode::DecodeMetadata(const SearchKey& search_key, engine::Storage* storage) const {
+  auto node_index_key = search_key.ConstructHnswNode(level, key);
+  rocksdb::PinnableSlice value;
+  auto s = storage->Get(rocksdb::ReadOptions(), storage->GetCFHandle(ColumnFamilyID::Search), node_index_key, &value);
+  if (!s.ok()) return {Status::NotOK, s.ToString()};
+
+  HnswNodeFieldMetadata metadata;
+  s = metadata.Decode(&value);
+  if (!s.ok()) return {Status::NotOK, s.ToString()};
+  return metadata;
+}
+
+// Encodes node_meta and stages it in `batch` under this node's key; the status of the Put is not reported.
+void HnswNode::PutMetadata(HnswNodeFieldMetadata* node_meta, const SearchKey& search_key, engine::Storage* storage,
+                           rocksdb::WriteBatchBase* batch) const {
+  std::string updated_metadata;
+  node_meta->Encode(&updated_metadata);
+  batch->Put(storage->GetCFHandle(ColumnFamilyID::Search), search_key.ConstructHnswNode(level, key), updated_metadata);
+}
+
+// Rebuilds `neighbours` by scanning every stored edge key that starts at this node on its level.
+void HnswNode::DecodeNeighbours(const SearchKey& search_key, engine::Storage* storage) {
+  neighbours.clear();
+  auto edge_prefix = search_key.ConstructHnswEdgeWithSingleEnd(level, key);
+  util::UniqueIterator iter(storage, storage->DefaultScanOptions(), ColumnFamilyID::Search);
+  for (iter->Seek(edge_prefix); iter->Valid(); iter->Next()) {
+    if (!iter->key().starts_with(edge_prefix)) {
+      break;
+    }
+    auto neighbour_edge = iter->key();
+    neighbour_edge.remove_prefix(edge_prefix.size());
+    Slice neighbour;
+    // An edge key whose tail cannot be decoded must not be recorded as an (empty) neighbour key.
+    if (!GetSizedString(&neighbour_edge, &neighbour)) {
+      continue;
+    }
+    neighbours.push_back(neighbour.ToString());
+  }
+}
+
+// Stages the directed edge key -> neighbour_key and bumps this node's stored neighbour count (test helper).
+Status HnswNode::AddNeighbour(const NodeKey& neighbour_key, const SearchKey& search_key, engine::Storage* storage,
+                              rocksdb::WriteBatchBase* batch) const {
+  auto edge_index_key = search_key.ConstructHnswEdge(level, key, neighbour_key);
+  auto s = batch->Put(storage->GetCFHandle(ColumnFamilyID::Search), edge_index_key, Slice());
+  if (!s.ok()) {
+    return {Status::NotOK, fmt::format("failed to add edge, {}", s.ToString())};
+  }
+
+  HnswNodeFieldMetadata node_metadata = GET_OR_RET(DecodeMetadata(search_key, storage));
+  node_metadata.num_neighbours++;
+  PutMetadata(&node_metadata, search_key, storage, batch);
+  return Status::OK();
+}
+
+// Stages removal of the directed edge key -> neighbour_key and decrements the stored neighbour count (test helper).
+Status HnswNode::RemoveNeighbour(const NodeKey& neighbour_key, const SearchKey& search_key, engine::Storage* storage,
+                                 rocksdb::WriteBatchBase* batch) const {
+  auto edge_index_key = search_key.ConstructHnswEdge(level, key, neighbour_key);
+  auto s = batch->Delete(storage->GetCFHandle(ColumnFamilyID::Search), edge_index_key);
+  if (!s.ok()) {
+    return {Status::NotOK, fmt::format("failed to delete edge, {}", s.ToString())};
+  }
+
+  HnswNodeFieldMetadata node_metadata = GET_OR_RET(DecodeMetadata(search_key, storage));
+  node_metadata.num_neighbours--;
+  PutMetadata(&node_metadata, search_key, storage, batch);
+  return Status::OK();
+}
+
+// Builds a VectorItem into *out, copying `vector`.
+// Returns InvalidArgument when the vector length differs from metadata->dim.
+Status VectorItem::Create(NodeKey key, const kqir::NumericArray& vector, const HnswVectorFieldMetadata* metadata,
+                          VectorItem* out) {
+  if (metadata->dim != vector.size()) {
+    return {Status::InvalidArgument, "VectorItem's metadata dimension must be consistent with the vector itself."};
+  }
+
+  *out = VectorItem(std::move(key), vector, metadata);
+  return Status::OK();
+}
+
+// Same as above, but takes ownership of `vector` instead of copying it.
+Status VectorItem::Create(NodeKey key, kqir::NumericArray&& vector, const HnswVectorFieldMetadata* metadata,
+                          VectorItem* out) {
+  if (metadata->dim != vector.size()) {
+    return {Status::InvalidArgument, "VectorItem's metadata dimension must be consistent with the vector itself."};
+  }
+
+  *out = VectorItem(std::move(key), std::move(vector), metadata);
+  return Status::OK();
+}
+
+bool VectorItem::operator==(const VectorItem& other) const { return key == other.key; }
+
+bool VectorItem::operator<(const VectorItem& other) const { return key < other.key; }
+
+VectorItem::VectorItem(NodeKey&& key, const kqir::NumericArray& vector, const HnswVectorFieldMetadata* metadata)
+    : key(std::move(key)), vector(vector), metadata(metadata) {}
+
+VectorItem::VectorItem(NodeKey&& key, kqir::NumericArray&& vector, const HnswVectorFieldMetadata* metadata)
+    : key(std::move(key)), vector(std::move(vector)), metadata(metadata) {}
+
+// Returns a distance where smaller means closer: Euclidean distance for L2, the negated dot product for IP and
+// 1 - cosine similarity for COSINE. Fails with InvalidArgument when the two items disagree on metric or
+// dimension.
+StatusOr<double> ComputeSimilarity(const VectorItem& left, const VectorItem& right) {
+  if (left.metadata->distance_metric != right.metadata->distance_metric || left.metadata->dim != right.metadata->dim)
+    return {Status::InvalidArgument, "Vectors must be of the same metric and dimension to compute distance."};
+
+  auto metric = left.metadata->distance_metric;
+  auto dim = left.metadata->dim;
+
+  switch (metric) {
+    case DistanceMetric::L2: {
+      double dist = 0.0;
+      for (auto i = 0; i < dim; i++) {
+        double diff = left.vector[i] - right.vector[i];
+        dist += diff * diff;
+      }
+      return std::sqrt(dist);
+    }
+    case DistanceMetric::IP: {
+      double dist = 0.0;
+      for (auto i = 0; i < dim; i++) {
+        dist += left.vector[i] * right.vector[i];
+      }
+      return -dist;
+    }
+    case DistanceMetric::COSINE: {
+      double dist = 0.0;
+      double norm_left = 0.0;
+      double norm_right = 0.0;
+      for (auto i = 0; i < dim; i++) {
+        dist += left.vector[i] * right.vector[i];
+        norm_left += left.vector[i] * left.vector[i];
+        norm_right += right.vector[i] * right.vector[i];
+      }
+      // NOTE(review): an all-zero vector makes this 0/0 (NaN); confirm callers never index zero vectors.
+      auto similarity = dist / std::sqrt(norm_left * norm_right);
+      return 1.0 - similarity;
+    }
+    default:
+      __builtin_unreachable();
+  }
+}
+
+// Binds the index to one vector field and seeds the level generator from std::random_device.
+HnswIndex::HnswIndex(const SearchKey& search_key, HnswVectorFieldMetadata* vector, engine::Storage* storage)
+    : search_key(search_key),
+      metadata(vector),
+      storage(storage),
+      m_level_normalization_factor(1.0 / std::log(metadata->m)) {
+  std::random_device rand_dev;
+  generator = std::mt19937(rand_dev());
+}
+
+// Draws the top level for a new node as floor(-ln(r) / ln(m)) with r uniform in (0, 1].
+// The result is clamped to the uint16_t range; a NaN intermediate (possible when m == 1) maps to level 0.
+uint16_t HnswIndex::RandomizeLayer() {
+  std::uniform_real_distribution<double> level_dist(0.0, 1.0);
+  // The distribution yields [0, 1); flip it to (0, 1] so that log() never sees 0, which would give +inf.
+  double r = 1.0 - level_dist(generator);
+  double log_val = -std::log(r);
+  double layer_val = std::floor(log_val * m_level_normalization_factor);
+
+  // Converting an out-of-range or NaN double to uint16_t is undefined behaviour, so clamp explicitly.
+  constexpr double max_layer = std::numeric_limits<uint16_t>::max();
+  if (!(layer_val >= 0.0)) return 0;
+  if (layer_val > max_layer) return std::numeric_limits<uint16_t>::max();
+  return static_cast<uint16_t>(layer_val);
+}
+
+// Returns the key of the first node stored at `level` in iteration order, or NotFound when the level is empty.
+StatusOr<HnswIndex::NodeKey> HnswIndex::DefaultEntryPoint(uint16_t level) const {
+  auto prefix = search_key.ConstructHnswLevelNodePrefix(level);
+  util::UniqueIterator it(storage, storage->DefaultScanOptions(), ColumnFamilyID::Search);
+  it->Seek(prefix);
+
+  Slice node_key;
+  Slice node_key_dst;
+  if (it->Valid() && it->key().starts_with(prefix)) {
+    std::string node_key_str = it->key().ToString().substr(prefix.size());
+    node_key = Slice(node_key_str);
+    if (!GetSizedString(&node_key, &node_key_dst)) {
+      return {Status::NotOK, fmt::format("fail to decode the default node key layer {}", level)};
+    }
+    return node_key_dst.ToString();
+  }
+  return {Status::NotFound, fmt::format("No node found in layer {}", level)};
+}
+
+// Loads the vectors of `node_keys` at `level`. Keys whose metadata cannot be decoded are silently skipped, so
+// the result may be shorter than the input.
+StatusOr<std::vector<VectorItem>> HnswIndex::DecodeNodesToVectorItems(const std::vector<NodeKey>& node_keys,
+                                                                      uint16_t level, const SearchKey& search_key,
+                                                                      engine::Storage* storage,
+                                                                      const HnswVectorFieldMetadata* metadata) {
+  std::vector<VectorItem> vector_items;
+  vector_items.reserve(node_keys.size());
+
+  for (const auto& neighbour_key : node_keys) {
+    HnswNode neighbour_node(neighbour_key, level);
+    auto neighbour_metadata_status = neighbour_node.DecodeMetadata(search_key, storage);
+    if (!neighbour_metadata_status.IsOK()) {
+      continue;  // Skip this neighbour if metadata can't be decoded
+    }
+    auto neighbour_metadata = neighbour_metadata_status.GetValue();
+    VectorItem item;
+    GET_OR_RET(VectorItem::Create(neighbour_key, std::move(neighbour_metadata.vector), metadata, &item));
+    vector_items.emplace_back(std::move(item));
+  }
+  return vector_items;
+}
+
+// Stages both directions of the edge between the two nodes at `layer`; neighbour counts are not touched here.
+Status HnswIndex::AddEdge(const NodeKey& node_key1, const NodeKey& node_key2, uint16_t layer,
+                          ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) const {
+  auto edge_index_key1 = search_key.ConstructHnswEdge(layer, node_key1, node_key2);
+  auto s = batch->Put(storage->GetCFHandle(ColumnFamilyID::Search), edge_index_key1, Slice());
+  if (!s.ok()) {
+    return {Status::NotOK, fmt::format("failed to add edge, {}", s.ToString())};
+  }
+
+  auto edge_index_key2 = search_key.ConstructHnswEdge(layer, node_key2, node_key1);
+  s = batch->Put(storage->GetCFHandle(ColumnFamilyID::Search), edge_index_key2, Slice());
+  if (!s.ok()) {
+    return {Status::NotOK, fmt::format("failed to add edge, {}", s.ToString())};
+  }
+  return Status::OK();
+}
+
+// Stages deletion of both directions of the edge between the two nodes at `layer`.
+Status HnswIndex::RemoveEdge(const NodeKey& node_key1, const NodeKey& node_key2, uint16_t layer,
+                             ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) const {
+  auto edge_index_key1 = search_key.ConstructHnswEdge(layer, node_key1, node_key2);
+  auto s = batch->Delete(storage->GetCFHandle(ColumnFamilyID::Search), edge_index_key1);
+  if (!s.ok()) {
+    return {Status::NotOK, fmt::format("failed to delete edge, {}", s.ToString())};
+  }
+
+  auto edge_index_key2 = search_key.ConstructHnswEdge(layer, node_key2, node_key1);
+  s = batch->Delete(storage->GetCFHandle(ColumnFamilyID::Search), edge_index_key2);
+  if (!s.ok()) {
+    return {Status::NotOK, fmt::format("failed to delete edge, {}", s.ToString())};
+  }
+  return Status::OK();
+}
+
+// Returns the candidates closest to `vec`, nearest first: at most m of them on upper layers and at most 2 * m
+// on layer 0. Ties in distance are ordered by key.
+StatusOr<std::vector<VectorItem>> HnswIndex::SelectNeighbors(const VectorItem& vec,
+                                                             const std::vector<VectorItem>& vectors,
+                                                             uint16_t layer) const {
+  std::vector<std::pair<double, VectorItem>> distances;
+  distances.reserve(vectors.size());
+  for (const auto& candidate : vectors) {
+    auto dist = GET_OR_RET(ComputeSimilarity(vec, candidate));
+    distances.emplace_back(dist, candidate);
+  }
+
+  std::sort(distances.begin(), distances.end());
+
+  // Work in size_t: narrowing the candidate count (or 2 * m) to uint16_t could wrap and drop neighbours.
+  const size_t m_max = layer != 0 ? metadata->m : 2 * static_cast<size_t>(metadata->m);
+  const size_t num_selected = std::min(m_max, distances.size());
+
+  std::vector<VectorItem> selected_vs;
+  selected_vs.reserve(num_selected);
+  for (size_t i = 0; i < num_selected; i++) {
+    selected_vs.push_back(std::move(distances[i].second));
+  }
+  return selected_vs;
+}
+
+// Best-first search of one layer starting from `entry_points`. Returns the kept candidates ordered from nearest
+// to farthest; the result set is trimmed to ef_runtime entries each time a neighbour is added. Fails when the
+// metadata of an entry point or of a visited neighbour cannot be decoded.
+StatusOr<std::vector<VectorItem>> HnswIndex::SearchLayer(uint16_t level, const VectorItem& target_vector,
+                                                         uint32_t ef_runtime,
+                                                         const std::vector<NodeKey>& entry_points) const {
+  std::vector<VectorItem> candidates;
+  std::unordered_set<NodeKey> visited;
+  std::priority_queue<std::pair<double, VectorItem>, std::vector<std::pair<double, VectorItem>>, std::greater<>>
+      explore_heap;
+  std::priority_queue<std::pair<double, VectorItem>> result_heap;
+
+  for (const auto& entry_point_key : entry_points) {
+    HnswNode entry_node = HnswNode(entry_point_key, level);
+    auto entry_node_metadata = GET_OR_RET(entry_node.DecodeMetadata(search_key, storage));
+
+    VectorItem entry_point_vector;
+    GET_OR_RET(
+        VectorItem::Create(entry_point_key, std::move(entry_node_metadata.vector), metadata, &entry_point_vector));
+    auto dist = GET_OR_RET(ComputeSimilarity(target_vector, entry_point_vector));
+
+    explore_heap.push(std::make_pair(dist, entry_point_vector));
+    result_heap.push(std::make_pair(dist, std::move(entry_point_vector)));
+    visited.insert(entry_point_key);
+  }
+
+  while (!explore_heap.empty()) {
+    auto [dist, current_vector] = explore_heap.top();
+    explore_heap.pop();
+    // result_heap can be empty here when ef_runtime is 0; calling top() on an empty heap is undefined.
+    if (result_heap.empty() || dist > result_heap.top().first) {
+      break;
+    }
+
+    auto current_node = HnswNode(current_vector.key, level);
+    current_node.DecodeNeighbours(search_key, storage);
+
+    for (const auto& neighbour_key : current_node.neighbours) {
+      if (visited.find(neighbour_key) != visited.end()) {
+        continue;
+      }
+      visited.insert(neighbour_key);
+
+      auto neighbour_node = HnswNode(neighbour_key, level);
+      auto neighbour_node_metadata = GET_OR_RET(neighbour_node.DecodeMetadata(search_key, storage));
+
+      VectorItem neighbour_node_vector;
+      GET_OR_RET(VectorItem::Create(neighbour_key, std::move(neighbour_node_metadata.vector), metadata,
+                                    &neighbour_node_vector));
+
+      auto neighbour_dist = GET_OR_RET(ComputeSimilarity(target_vector, neighbour_node_vector));
+      explore_heap.push(std::make_pair(neighbour_dist, neighbour_node_vector));
+      result_heap.push(std::make_pair(neighbour_dist, neighbour_node_vector));
+      while (result_heap.size() > ef_runtime) {
+        result_heap.pop();
+      }
+    }
+  }
+
+  while (!result_heap.empty()) {
+    candidates.push_back(result_heap.top().second);
+    result_heap.pop();
+  }
+
+  std::reverse(candidates.begin(), candidates.end());
+  return candidates;
+}
+
+// Stages the insertion of `key` at levels 0..target_level: descends greedily from the top level, then on each
+// level from min(target_level, top) down to 0 links the node to its selected neighbours, pruning an edge of a
+// neighbour that is already full. All writes go into `batch`; metadata->num_levels is updated in memory too.
+Status HnswIndex::InsertVectorEntryInternal(std::string_view key, const kqir::NumericArray& vector,
+                                            ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch,
+                                            uint16_t target_level) const {
+  auto cf_handle = storage->GetCFHandle(ColumnFamilyID::Search);
+  VectorItem inserted_vector_item;
+  GET_OR_RET(VectorItem::Create(std::string(key), vector, metadata, &inserted_vector_item));
+  std::vector<VectorItem> nearest_vec_items;
+
+  if (metadata->num_levels != 0) {
+    auto level = metadata->num_levels - 1;
+
+    auto default_entry_node = GET_OR_RET(DefaultEntryPoint(level));
+    std::vector<NodeKey> entry_points{default_entry_node};
+
+    for (; level > target_level; level--) {
+      nearest_vec_items = GET_OR_RET(SearchLayer(level, inserted_vector_item, metadata->ef_runtime, entry_points));
+      if (nearest_vec_items.empty()) {
+        return {Status::NotOK, fmt::format("no candidate found while searching layer {}", level)};
+      }
+      entry_points = {nearest_vec_items[0].key};
+    }
+
+    for (; level >= 0; level--) {
+      nearest_vec_items = GET_OR_RET(SearchLayer(level, inserted_vector_item, metadata->ef_construction, entry_points));
+      auto candidate_vec_items = GET_OR_RET(SelectNeighbors(inserted_vector_item, nearest_vec_items, level));
+      auto node = HnswNode(std::string(key), level);
+      auto m_max = level == 0 ? 2 * metadata->m : metadata->m;
+
+      std::unordered_set<NodeKey> connected_edges_set;
+      std::unordered_map<NodeKey, std::unordered_set<NodeKey>> deleted_edges_map;
+
+      // Check if candidate node has room for more outgoing edges
+      auto has_room_for_more_edges = [&](uint16_t candidate_node_num_neighbours) {
+        return candidate_node_num_neighbours < m_max;
+      };
+
+      // Check if candidate node has room after some other nodes' are pruned in current batch
+      auto has_room_after_deletions = [&](const HnswNode& candidate_node, uint16_t candidate_node_num_neighbours) {
+        auto it = deleted_edges_map.find(candidate_node.key);
+        if (it != deleted_edges_map.end()) {
+          auto num_deleted_edges = static_cast<uint16_t>(it->second.size());
+          return (candidate_node_num_neighbours - num_deleted_edges) < m_max;
+        }
+        return false;
+      };
+
+      for (const auto& candidate_vec : candidate_vec_items) {
+        auto candidate_node = HnswNode(candidate_vec.key, level);
+        auto candidate_node_metadata = GET_OR_RET(candidate_node.DecodeMetadata(search_key, storage));
+        uint16_t candidate_node_num_neighbours = candidate_node_metadata.num_neighbours;
+
+        if (has_room_for_more_edges(candidate_node_num_neighbours) ||
+            has_room_after_deletions(candidate_node, candidate_node_num_neighbours)) {
+          GET_OR_RET(AddEdge(inserted_vector_item.key, candidate_node.key, level, batch));
+          connected_edges_set.insert(candidate_node.key);
+          continue;
+        }
+
+        // Re-evaluate the neighbours for the candidate node
+        candidate_node.DecodeNeighbours(search_key, storage);
+        auto candidate_node_neighbour_vec_items =
+            GET_OR_RET(DecodeNodesToVectorItems(candidate_node.neighbours, level, search_key, storage, metadata));
+        candidate_node_neighbour_vec_items.push_back(inserted_vector_item);
+        auto sorted_neighbours_by_distance =
+            GET_OR_RET(SelectNeighbors(candidate_vec, candidate_node_neighbour_vec_items, level));
+
+        bool inserted_node_is_selected =
+            std::find(sorted_neighbours_by_distance.begin(), sorted_neighbours_by_distance.end(),
+                      inserted_vector_item) != sorted_neighbours_by_distance.end();
+
+        if (inserted_node_is_selected) {
+          // Add the edge between candidate and inserted node
+          GET_OR_RET(AddEdge(inserted_vector_item.key, candidate_node.key, level, batch));
+          connected_edges_set.insert(candidate_node.key);
+
+          // Find the neighbour that did not survive re-selection; it is the edge to prune.
+          const auto& selected = sorted_neighbours_by_distance;
+          auto pruned_it = std::find_if(candidate_node_neighbour_vec_items.begin(),
+                                        candidate_node_neighbour_vec_items.end(), [&](const VectorItem& item) {
+                                          return std::find(selected.begin(), selected.end(), item) == selected.end();
+                                        });
+
+          // Neighbours whose metadata failed to decode were skipped above, so every remaining neighbour may have
+          // been selected. Dereferencing end() would be undefined behaviour; there is simply nothing to prune.
+          if (pruned_it != candidate_node_neighbour_vec_items.end()) {
+            // Remove the edge for candidate and the pruned node
+            const NodeKey& deleted_key = pruned_it->key;
+            GET_OR_RET(RemoveEdge(deleted_key, candidate_node.key, level, batch));
+            deleted_edges_map[candidate_node.key].insert(deleted_key);
+            deleted_edges_map[deleted_key].insert(candidate_node.key);
+          }
+        }
+      }
+
+      // Update inserted node metadata
+      HnswNodeFieldMetadata node_metadata(static_cast<uint16_t>(connected_edges_set.size()), vector);
+      node.PutMetadata(&node_metadata, search_key, storage, batch.Get());
+
+      // Update modified nodes metadata
+      for (const auto& node_edges : deleted_edges_map) {
+        auto& current_node_key = node_edges.first;
+        auto current_node = HnswNode(current_node_key, level);
+        auto current_node_metadata = GET_OR_RET(current_node.DecodeMetadata(search_key, storage));
+        auto new_num_neighbours = current_node_metadata.num_neighbours - node_edges.second.size();
+        if (connected_edges_set.count(current_node_key) != 0) {
+          new_num_neighbours++;
+          connected_edges_set.erase(current_node_key);
+        }
+        current_node_metadata.num_neighbours = new_num_neighbours;
+        current_node.PutMetadata(&current_node_metadata, search_key, storage, batch.Get());
+      }
+
+      for (const auto& current_node_key : connected_edges_set) {
+        auto current_node = HnswNode(current_node_key, level);
+        HnswNodeFieldMetadata current_node_metadata = GET_OR_RET(current_node.DecodeMetadata(search_key, storage));
+        current_node_metadata.num_neighbours++;
+        current_node.PutMetadata(&current_node_metadata, search_key, storage, batch.Get());
+      }
+
+      entry_points.clear();
+      for (const auto& new_entry_point : nearest_vec_items) {
+        entry_points.push_back(new_entry_point.key);
+      }
+    }
+  } else {
+    auto node = HnswNode(std::string(key), 0);
+    HnswNodeFieldMetadata node_metadata(0, vector);
+    node.PutMetadata(&node_metadata, search_key, storage, batch.Get());
+    metadata->num_levels = 1;
+  }
+
+  while (target_level > metadata->num_levels - 1) {
+    auto node = HnswNode(std::string(key), metadata->num_levels);
+    HnswNodeFieldMetadata node_metadata(0, vector);
+    node.PutMetadata(&node_metadata, search_key, storage, batch.Get());
+    metadata->num_levels++;
+  }
+
+  std::string encoded_index_metadata;
+  metadata->Encode(&encoded_index_metadata);
+  auto index_meta_key = search_key.ConstructFieldMeta();
+  batch->Put(cf_handle, index_meta_key, encoded_index_metadata);
+
+  return Status::OK();
+}
+
+// Picks a random top level for the new node and stages its insertion into `batch`.
+Status HnswIndex::InsertVectorEntry(std::string_view key, const kqir::NumericArray& vector,
+                                    ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) {
+  auto target_level = RandomizeLayer();
+  return InsertVectorEntryInternal(key, vector, batch, target_level);
+}
+
+// Stages removal of `key` from every level it is stored on (node record, edges, neighbour counts), then lowers
+// metadata->num_levels past top levels that hold no other node and stages the updated field metadata.
+Status HnswIndex::DeleteVectorEntry(std::string_view key, ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) const {
+  std::string node_key(key);
+  for (uint16_t level = 0; level < metadata->num_levels; level++) {
+    auto node = HnswNode(node_key, level);
+    auto node_metadata_status = node.DecodeMetadata(search_key, storage);
+    if (!node_metadata_status.IsOK()) {
+      break;
+    }
+
+    auto node_metadata = std::move(node_metadata_status).GetValue();
+    auto node_index_key = search_key.ConstructHnswNode(level, key);
+    auto s = batch->Delete(storage->GetCFHandle(ColumnFamilyID::Search), node_index_key);
+    if (!s.ok()) {
+      return {Status::NotOK, s.ToString()};
+    }
+
+    node.DecodeNeighbours(search_key, storage);
+    for (const auto& neighbour_key : node.neighbours) {
+      GET_OR_RET(RemoveEdge(node_key, neighbour_key, level, batch));
+      auto neighbour_node = HnswNode(neighbour_key, level);
+      HnswNodeFieldMetadata neighbour_node_metadata = GET_OR_RET(neighbour_node.DecodeMetadata(search_key, storage));
+      neighbour_node_metadata.num_neighbours--;
+      neighbour_node.PutMetadata(&neighbour_node_metadata, search_key, storage, batch.Get());
+    }
+  }
+
+  auto has_other_nodes_at_level = [&](uint16_t level, std::string_view skip_key) -> bool {
+    auto prefix = search_key.ConstructHnswLevelNodePrefix(level);
+    util::UniqueIterator it(storage, storage->DefaultScanOptions(), ColumnFamilyID::Search);
+
+    Slice node_key;
+    Slice node_key_dst;
+    // Advance in the loop header so that `continue` on an undecodable key cannot skip Next() and spin forever.
+    for (it->Seek(prefix); it->Valid() && it->key().starts_with(prefix); it->Next()) {
+      std::string node_key_str = it->key().ToString().substr(prefix.size());
+      node_key = Slice(node_key_str);
+      if (!GetSizedString(&node_key, &node_key_dst)) {
+        continue;
+      }
+      if (node_key_dst.ToString() != skip_key) {
+        return true;
+      }
+    }
+    return false;
+  };
+
+  while (metadata->num_levels > 0) {
+    if (has_other_nodes_at_level(metadata->num_levels - 1, key)) {
+      break;
+    }
+    metadata->num_levels--;
+  }
+
+  std::string encoded_index_metadata;
+  metadata->Encode(&encoded_index_metadata);
+  auto index_meta_key = search_key.ConstructFieldMeta();
+  batch->Put(storage->GetCFHandle(ColumnFamilyID::Search), index_meta_key, encoded_index_metadata);
+
+  return Status::OK();
+}
+
+}  // namespace redis
diff --git a/src/search/hnsw_indexer.h b/src/search/hnsw_indexer.h
new file mode 100644
index 00000000000..30bdf94ac46
--- /dev/null
+++ b/src/search/hnsw_indexer.h
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ */
+
+#pragma once
+
+#include <random>
+#include <string>
+#include <vector>
+
+#include "search/indexer.h"
+#include "search/search_encoding.h"
+#include "search/value.h"
+#include "storage/storage.h"
+
+namespace redis {
+
+struct HnswIndex;
+// One vertex of the HNSW graph at a single level, addressed by its key.
+struct HnswNode {
+  using NodeKey = std::string;
+  NodeKey key;
+  uint16_t level{};
+  std::vector<NodeKey> neighbours;  // filled only by DecodeNeighbours()
+
+  HnswNode(NodeKey key, uint16_t level);
+
+  StatusOr<HnswNodeFieldMetadata> DecodeMetadata(const SearchKey& search_key, engine::Storage* storage) const;
+  void PutMetadata(HnswNodeFieldMetadata* node_meta, const SearchKey& search_key, engine::Storage* storage,
+                   rocksdb::WriteBatchBase* batch) const;
+  void DecodeNeighbours(const SearchKey& search_key, engine::Storage* storage);
+
+  // For testing purposes: add or remove a single directed edge and update the stored neighbour count.
+  Status AddNeighbour(const NodeKey& neighbour_key, const SearchKey& search_key, engine::Storage* storage,
+                      rocksdb::WriteBatchBase* batch) const;
+  Status RemoveNeighbour(const NodeKey& neighbour_key, const SearchKey& search_key, engine::Storage* storage,
+                         rocksdb::WriteBatchBase* batch) const;
+  friend struct HnswIndex;
+};
+// A node key together with its vector; equality and ordering compare the key only.
+struct VectorItem {
+  using NodeKey = HnswNode::NodeKey;
+
+  NodeKey key;
+  kqir::NumericArray vector;
+  const HnswVectorFieldMetadata* metadata;  // not owned; null for a default-constructed item
+
+  VectorItem() : metadata(nullptr) {}
+  // Factories write into *out and fail with InvalidArgument when vector.size() differs from metadata->dim.
+  static Status Create(NodeKey key, const kqir::NumericArray& vector, const HnswVectorFieldMetadata* metadata,
+                       VectorItem* out);
+  static Status Create(NodeKey key, kqir::NumericArray&& vector, const HnswVectorFieldMetadata* metadata,
+                       VectorItem* out);
+
+  bool operator==(const VectorItem& other) const;
+  bool operator<(const VectorItem& other) const;
+
+ private:
+  VectorItem(NodeKey&& key, const kqir::NumericArray& vector, const HnswVectorFieldMetadata* metadata);
+  VectorItem(NodeKey&& key, kqir::NumericArray&& vector, const HnswVectorFieldMetadata* metadata);
+};
+// Distance between two items under their shared metric (smaller is closer); InvalidArgument if they differ.
+StatusOr<double> ComputeSimilarity(const VectorItem& left, const VectorItem& right);
+// HNSW graph operations for one vector field; changes are staged into a caller-provided write batch.
+struct HnswIndex {
+  using NodeKey = HnswNode::NodeKey;
+
+  SearchKey search_key;
+  HnswVectorFieldMetadata* metadata;  // not owned; num_levels is updated in place by insert and delete
+  engine::Storage* storage = nullptr;
+
+  std::mt19937 generator;
+  double m_level_normalization_factor;  // 1 / ln(m), used by RandomizeLayer()
+
+  HnswIndex(const SearchKey& search_key, HnswVectorFieldMetadata* vector, engine::Storage* storage);
+
+  static StatusOr<std::vector<VectorItem>> DecodeNodesToVectorItems(const std::vector<NodeKey>& node_keys,
+                                                                    uint16_t level, const SearchKey& search_key,
+                                                                    engine::Storage* storage,
+                                                                    const HnswVectorFieldMetadata* metadata);
+  uint16_t RandomizeLayer();
+  StatusOr<NodeKey> DefaultEntryPoint(uint16_t level) const;
+  Status AddEdge(const NodeKey& node_key1, const NodeKey& node_key2, uint16_t layer,
+                 ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) const;
+  Status RemoveEdge(const NodeKey& node_key1, const NodeKey& node_key2, uint16_t layer,
+                    ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) const;
+
+  StatusOr<std::vector<VectorItem>> SelectNeighbors(const VectorItem& vec, const std::vector<VectorItem>& vectors,
+                                                    uint16_t layer) const;
+  StatusOr<std::vector<VectorItem>> SearchLayer(uint16_t level, const VectorItem& target_vector, uint32_t ef_runtime,
+                                                const std::vector<NodeKey>& entry_points) const;
+  Status InsertVectorEntryInternal(std::string_view key, const kqir::NumericArray& vector,
+                                   ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch, uint16_t target_level) const;
+  Status InsertVectorEntry(std::string_view key, const kqir::NumericArray& vector,
+                           ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch);
+  Status DeleteVectorEntry(std::string_view key, ObserverOrUniquePtr<rocksdb::WriteBatchBase>& batch) const;
+};
+
+}  // namespace redis
diff --git a/src/search/indexer.cc
b/src/search/indexer.cc index 1212dd2f0d8..576de07343a 100644 --- a/src/search/indexer.cc +++ b/src/search/indexer.cc @@ -25,6 +25,7 @@ #include "db_util.h" #include "parse_util.h" +#include "search/hnsw_indexer.h" #include "search/search_encoding.h" #include "search/value.h" #include "storage/redis_metadata.h" @@ -57,10 +58,6 @@ StatusOr FieldValueRetriever::Create(IndexOnDataType type, } } -// placeholders, remove them after vector indexing is implemented -static bool IsVectorType(const redis::IndexFieldMetadata *) { return false; } -static size_t GetVectorDim(const redis::IndexFieldMetadata *) { return 1; } - StatusOr FieldValueRetriever::ParseFromJson(const jsoncons::json &val, const redis::IndexFieldMetadata *type) { if (auto numeric [[maybe_unused]] = dynamic_cast(type)) { @@ -82,8 +79,8 @@ StatusOr FieldValueRetriever::ParseFromJson(const jsoncons::json &v } else { return {Status::NotOK, "json value should be string or array of strings for tag fields"}; } - } else if (IsVectorType(type)) { - size_t dim = GetVectorDim(type); + } else if (auto vector = dynamic_cast(type)) { + const auto dim = vector->dim; if (!val.is_array()) return {Status::NotOK, "json value should be array of numbers for vector fields"}; if (dim != val.size()) return {Status::NotOK, "the size of the json array is not equal to the dim of the vector"}; std::vector nums; @@ -107,8 +104,8 @@ StatusOr FieldValueRetriever::ParseFromHash(const std::string &valu const char delim[] = {tag->separator, '\0'}; auto vec = util::Split(value, delim); return kqir::MakeValue(vec); - } else if (IsVectorType(type)) { - const size_t dim = GetVectorDim(type); + } else if (auto vector = dynamic_cast(type)) { + const auto dim = vector->dim; if (value.size() != dim * sizeof(double)) { return {Status::NotOK, "field value is too short or too long to be parsed as a vector"}; } @@ -246,7 +243,7 @@ Status IndexUpdater::UpdateTagIndex(std::string_view key, const kqir::Value &ori Status 
IndexUpdater::UpdateNumericIndex(std::string_view key, const kqir::Value &original, const kqir::Value ¤t, const SearchKey &search_key, const NumericFieldMetadata *num) const { CHECK(original.IsNull() || original.Is()); - CHECK(original.IsNull() || original.Is()); + CHECK(current.IsNull() || current.Is()); auto *storage = indexer->storage; auto batch = storage->GetWriteBatchBase(); @@ -269,6 +266,32 @@ Status IndexUpdater::UpdateNumericIndex(std::string_view key, const kqir::Value return Status::OK(); } +Status IndexUpdater::UpdateHnswVectorIndex(std::string_view key, const kqir::Value &original, + const kqir::Value ¤t, const SearchKey &search_key, + HnswVectorFieldMetadata *vector) const { + CHECK(original.IsNull() || original.Is()); + CHECK(current.IsNull() || current.Is()); + + auto storage = indexer->storage; + auto hnsw = HnswIndex(search_key, vector, storage); + + if (!original.IsNull()) { + auto batch = storage->GetWriteBatchBase(); + GET_OR_RET(hnsw.DeleteVectorEntry(key, batch)); + auto s = storage->Write(storage->DefaultWriteOptions(), batch->GetWriteBatch()); + if (!s.ok()) return {Status::NotOK, s.ToString()}; + } + + if (!current.IsNull()) { + auto batch = storage->GetWriteBatchBase(); + GET_OR_RET(hnsw.InsertVectorEntry(key, current.Get(), batch)); + auto s = storage->Write(storage->DefaultWriteOptions(), batch->GetWriteBatch()); + if (!s.ok()) return {Status::NotOK, s.ToString()}; + } + + return Status::OK(); +} + Status IndexUpdater::UpdateIndex(const std::string &field, std::string_view key, const kqir::Value &original, const kqir::Value ¤t) const { if (original == current) { @@ -287,6 +310,8 @@ Status IndexUpdater::UpdateIndex(const std::string &field, std::string_view key, GET_OR_RET(UpdateTagIndex(key, original, current, search_key, tag)); } else if (auto numeric [[maybe_unused]] = dynamic_cast(metadata)) { GET_OR_RET(UpdateNumericIndex(key, original, current, search_key, numeric)); + } else if (auto vector = dynamic_cast(metadata)) { + 
GET_OR_RET(UpdateHnswVectorIndex(key, original, current, search_key, vector)); } else { return {Status::NotOK, "Unexpected field type"}; } diff --git a/src/search/indexer.h b/src/search/indexer.h index 8ffd503b6ba..e5e0aa4fb50 100644 --- a/src/search/indexer.h +++ b/src/search/indexer.h @@ -89,6 +89,8 @@ struct IndexUpdater { const SearchKey &search_key, const TagFieldMetadata *tag) const; Status UpdateNumericIndex(std::string_view key, const kqir::Value &original, const kqir::Value ¤t, const SearchKey &search_key, const NumericFieldMetadata *num) const; + Status UpdateHnswVectorIndex(std::string_view key, const kqir::Value &original, const kqir::Value ¤t, + const SearchKey &search_key, HnswVectorFieldMetadata *vector) const; }; struct GlobalIndexer { diff --git a/src/search/search_encoding.h b/src/search/search_encoding.h index 68e248bb40b..2fbbde8c21e 100644 --- a/src/search/search_encoding.h +++ b/src/search/search_encoding.h @@ -33,6 +33,7 @@ enum class IndexOnDataType : uint8_t { }; inline constexpr auto kErrorInsufficientLength = "insufficient length while decoding metadata"; +inline constexpr auto kErrorIncorrectLength = "length is too short or too long to be parsed as a vector"; class IndexMetadata { public: @@ -76,6 +77,23 @@ enum class IndexFieldType : uint8_t { TAG = 1, NUMERIC = 2, + + VECTOR = 3, +}; + +enum class VectorType : uint8_t { + FLOAT64 = 1, +}; + +enum class DistanceMetric : uint8_t { + L2 = 0, + IP = 1, + COSINE = 2, +}; + +enum class HnswLevelType : uint8_t { + NODE = 1, + EDGE = 2, }; struct SearchKey { @@ -95,6 +113,26 @@ struct SearchKey { void PutIndex(std::string *dst) const { PutSizedString(dst, index); } + static void PutHnswLevelType(std::string *dst, HnswLevelType type) { PutFixed8(dst, uint8_t(type)); } + + void PutHnswLevelPrefix(std::string *dst, uint16_t level) const { + PutNamespace(dst); + PutType(dst, SearchSubkeyType::FIELD); + PutIndex(dst); + PutSizedString(dst, field); + PutFixed16(dst, level); + } + + void 
PutHnswLevelNodePrefix(std::string *dst, uint16_t level) const { + PutHnswLevelPrefix(dst, level); + PutHnswLevelType(dst, HnswLevelType::NODE); + } + + void PutHnswLevelEdgePrefix(std::string *dst, uint16_t level) const { + PutHnswLevelPrefix(dst, level); + PutHnswLevelType(dst, HnswLevelType::EDGE); + } + std::string ConstructIndexMeta() const { std::string dst; PutNamespace(&dst); @@ -177,6 +215,34 @@ struct SearchKey { PutSizedString(&dst, key); return dst; } + + std::string ConstructHnswLevelNodePrefix(uint16_t level) const { + std::string dst; + PutHnswLevelNodePrefix(&dst, level); + return dst; + } + + std::string ConstructHnswNode(uint16_t level, std::string_view key) const { + std::string dst; + PutHnswLevelNodePrefix(&dst, level); + PutSizedString(&dst, key); + return dst; + } + + std::string ConstructHnswEdgeWithSingleEnd(uint16_t level, std::string_view key) const { + std::string dst; + PutHnswLevelEdgePrefix(&dst, level); + PutSizedString(&dst, key); + return dst; + } + + std::string ConstructHnswEdge(uint16_t level, std::string_view key1, std::string_view key2) const { + std::string dst; + PutHnswLevelEdgePrefix(&dst, level); + PutSizedString(&dst, key1); + PutSizedString(&dst, key2); + return dst; + } }; struct IndexPrefixes { @@ -236,6 +302,8 @@ struct IndexFieldMetadata { return "tag"; case IndexFieldType::NUMERIC: return "numeric"; + case IndexFieldType::VECTOR: + return "vector"; default: return "unknown"; } @@ -291,6 +359,96 @@ struct NumericFieldMetadata : IndexFieldMetadata { bool IsSortable() const override { return true; } }; +struct HnswVectorFieldMetadata : IndexFieldMetadata { + VectorType vector_type; + uint16_t dim; + DistanceMetric distance_metric; + + uint32_t initial_cap = 500000; // Initial vector capacity + uint16_t m = 16; // Max allowed outgoing edges per node + uint32_t ef_construction = 200; // Max potential outgoing edge candidates during construction + uint32_t ef_runtime = 10; // Max top candidates held during KNN search + 
double epsilon = 0.01; // Relative factor setting search boundaries in range queries + uint16_t num_levels = 0; // Number of levels in the HNSW graph + + HnswVectorFieldMetadata() : IndexFieldMetadata(IndexFieldType::VECTOR) {} + + void Encode(std::string *dst) const override { + IndexFieldMetadata::Encode(dst); + PutFixed8(dst, uint8_t(vector_type)); + PutFixed16(dst, dim); + PutFixed8(dst, uint8_t(distance_metric)); + PutFixed32(dst, initial_cap); + PutFixed16(dst, m); + PutFixed32(dst, ef_construction); + PutFixed32(dst, ef_runtime); + PutDouble(dst, epsilon); + PutFixed16(dst, num_levels); + } + + rocksdb::Status Decode(Slice *input) override { + if (auto s = IndexFieldMetadata::Decode(input); !s.ok()) { + return s; + } + + constexpr size_t required_size = sizeof(uint8_t) + sizeof(uint16_t) + sizeof(uint8_t) + sizeof(uint32_t) + + sizeof(uint16_t) + sizeof(uint32_t) + sizeof(uint32_t) + sizeof(uint64_t) + + sizeof(uint16_t); + + if (input->size() < required_size) { + return rocksdb::Status::Corruption(kErrorInsufficientLength); + } + + GetFixed8(input, (uint8_t *)(&vector_type)); + GetFixed16(input, &dim); + GetFixed8(input, (uint8_t *)(&distance_metric)); + GetFixed32(input, &initial_cap); + GetFixed16(input, &m); + GetFixed32(input, &ef_construction); + GetFixed32(input, &ef_runtime); + GetDouble(input, &epsilon); + GetFixed16(input, &num_levels); + return rocksdb::Status::OK(); + } +}; + +struct HnswNodeFieldMetadata { + uint16_t num_neighbours = 0; + std::vector<double> vector; + + HnswNodeFieldMetadata() = default; + HnswNodeFieldMetadata(uint16_t num_neighbours, std::vector<double> vector) + : num_neighbours(num_neighbours), vector(std::move(vector)) {} + + void Encode(std::string *dst) const { + PutFixed16(dst, num_neighbours); + PutFixed16(dst, static_cast<uint16_t>(vector.size())); + for (double element : vector) { + PutDouble(dst, element); + } + } + + rocksdb::Status Decode(Slice *input) { + if (input->size() < 2 + 2) { + return
rocksdb::Status::Corruption(kErrorInsufficientLength); + } + GetFixed16(input, &num_neighbours); + + uint16_t dim = 0; + GetFixed16(input, &dim); + + if (input->size() != dim * sizeof(double)) { + return rocksdb::Status::Corruption(kErrorIncorrectLength); + } + vector.resize(dim); + + for (double &element : vector) { + GetDouble(input, &element); + } + return rocksdb::Status::OK(); + } +}; + inline rocksdb::Status IndexFieldMetadata::Decode(Slice *input, std::unique_ptr<IndexFieldMetadata> &ptr) { if (input->size() < 1) { return rocksdb::Status::Corruption(kErrorInsufficientLength); @@ -303,6 +461,9 @@ inline rocksdb::Status IndexFieldMetadata::Decode(Slice *input, std::unique_ptr< case IndexFieldType::NUMERIC: ptr = std::make_unique<NumericFieldMetadata>(); break; + case IndexFieldType::VECTOR: + ptr = std::make_unique<HnswVectorFieldMetadata>(); + break; default: return rocksdb::Status::Corruption("encountered unknown field type"); } diff --git a/tests/cppunit/hnsw_index_test.cc b/tests/cppunit/hnsw_index_test.cc new file mode 100644 index 00000000000..e09e9830262 --- /dev/null +++ b/tests/cppunit/hnsw_index_test.cc @@ -0,0 +1,664 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ * + */ + +#include +#include + +#include +#include +#include + +#include "search/hnsw_indexer.h" +#include "search/indexer.h" +#include "search/search_encoding.h" +#include "search/value.h" +#include "storage/storage.h" + +struct HnswIndexTest : TestBase { + redis::HnswVectorFieldMetadata metadata; + std::string ns = "hnsw_test_ns"; + std::string idx_name = "hnsw_test_idx"; + std::string key = "vector"; + std::unique_ptr hnsw_index; + + HnswIndexTest() { + metadata.vector_type = redis::VectorType::FLOAT64; + metadata.dim = 3; + metadata.m = 3; + metadata.distance_metric = redis::DistanceMetric::L2; + auto search_key = redis::SearchKey(ns, idx_name, key); + hnsw_index = std::make_unique(search_key, &metadata, storage_.get()); + } + + void TearDown() override { hnsw_index.reset(); } +}; + +TEST_F(HnswIndexTest, ComputeSimilarity) { + redis::VectorItem vec1; + auto status1 = redis::VectorItem::Create("1", {1.0, 1.2, 1.4}, hnsw_index->metadata, &vec1); + ASSERT_TRUE(status1.IsOK()); + redis::VectorItem vec2; + auto status2 = redis::VectorItem::Create("2", {3.0, 3.2, 3.4}, hnsw_index->metadata, &vec2); + ASSERT_TRUE(status2.IsOK()); + redis::VectorItem vec3; // identical to vec1 + auto status3 = redis::VectorItem::Create("3", {1.0, 1.2, 1.4}, hnsw_index->metadata, &vec3); + ASSERT_TRUE(status3.IsOK()); + + auto s1 = redis::ComputeSimilarity(vec1, vec3); + ASSERT_TRUE(s1.IsOK()); + double similarity = s1.GetValue(); + EXPECT_EQ(similarity, 0.0); + + auto s2 = redis::ComputeSimilarity(vec1, vec2); + ASSERT_TRUE(s2.IsOK()); + similarity = s2.GetValue(); + EXPECT_NEAR(similarity, std::sqrt(12), 1e-5); + + hnsw_index->metadata->distance_metric = redis::DistanceMetric::IP; + auto s3 = redis::ComputeSimilarity(vec1, vec2); + ASSERT_TRUE(s3.IsOK()); + similarity = s3.GetValue(); + EXPECT_NEAR(similarity, -(1.0 * 3.0 + 1.2 * 3.2 + 1.4 * 3.4), 1e-5); + + hnsw_index->metadata->distance_metric = redis::DistanceMetric::COSINE; + double expected_res = (1.0 * 3.0 + 1.2 * 3.2 + 1.4 * 
3.4) / + std::sqrt((1.0 * 1.0 + 1.2 * 1.2 + 1.4 * 1.4) * (3.0 * 3.0 + 3.2 * 3.2 + 3.4 * 3.4)); + auto s4 = redis::ComputeSimilarity(vec1, vec2); + ASSERT_TRUE(s4.IsOK()); + similarity = s4.GetValue(); + EXPECT_NEAR(similarity, 1 - expected_res, 1e-5); + + hnsw_index->metadata->distance_metric = redis::DistanceMetric::L2; +} + +TEST_F(HnswIndexTest, RandomizeLayer) { + constexpr size_t kSampleSize = 50000; + + std::vector layers; + layers.reserve(kSampleSize); + + for (size_t i = 0; i < kSampleSize; ++i) { + layers.push_back(hnsw_index->RandomizeLayer()); + EXPECT_GE(layers.back(), 0); + } + + std::map layer_frequency; + for (const auto& layer : layers) { + layer_frequency[layer]++; + } + + uint16_t max_observed_layer = 0; + for (const auto& [layer, freq] : layer_frequency) { + // std::cout << "Layer: " << layer << " Frequency: " << freq << std::endl; + if (layer > max_observed_layer) { + max_observed_layer = layer; + } + } + + // Calculate expected frequencies for each layer based on the theoretical distribution + std::vector expected_frequencies(max_observed_layer + 1, 0); + double normalization_factor = 1.0 / std::log(hnsw_index->metadata->m); + double total_probability = 0.0; + + for (uint16_t i = 0; i <= max_observed_layer; ++i) { + total_probability += std::exp(-i / normalization_factor); + } + + for (uint16_t i = 0; i <= max_observed_layer; ++i) { + double probability = std::exp(-i / normalization_factor) / total_probability; + expected_frequencies[i] = kSampleSize * probability; + } + + for (const auto& [layer, freq] : layer_frequency) { + if (layer < expected_frequencies.size() / 3) { + double expected_freq = expected_frequencies[layer]; + double deviation = std::abs(static_cast(freq) - expected_freq) / expected_freq; + EXPECT_LE(deviation, 0.1) << "Layer: " << layer << " Frequency: " << freq << " Expected: " << expected_freq; + } + } +} + +TEST_F(HnswIndexTest, DefaultEntryPointNotFound) { + auto initial_result = hnsw_index->DefaultEntryPoint(0); + 
ASSERT_EQ(initial_result.GetCode(), Status::NotFound); +} + +TEST_F(HnswIndexTest, DecodeNodesToVectorItems) { + uint16_t layer = 1; + std::string node_key1 = "node1"; + std::string node_key2 = "node2"; + std::string node_key3 = "node3"; + + redis::HnswNode node1(node_key1, layer); + redis::HnswNode node2(node_key2, layer); + redis::HnswNode node3(node_key3, layer); + + redis::HnswNodeFieldMetadata metadata1(0, {1, 2, 3}); + redis::HnswNodeFieldMetadata metadata2(0, {4, 5, 6}); + redis::HnswNodeFieldMetadata metadata3(0, {7, 8, 9}); + + auto batch = storage_->GetWriteBatchBase(); + node1.PutMetadata(&metadata1, hnsw_index->search_key, hnsw_index->storage, batch.Get()); + node2.PutMetadata(&metadata2, hnsw_index->search_key, hnsw_index->storage, batch.Get()); + node3.PutMetadata(&metadata3, hnsw_index->search_key, hnsw_index->storage, batch.Get()); + auto s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); + ASSERT_TRUE(s.ok()); + + std::vector keys = {node_key1, node_key2, node_key3}; + + auto s1 = hnsw_index->DecodeNodesToVectorItems(keys, layer, hnsw_index->search_key, hnsw_index->storage, + hnsw_index->metadata); + ASSERT_TRUE(s1.IsOK()); + auto vector_items = s1.GetValue(); + ASSERT_EQ(vector_items.size(), 3); + EXPECT_EQ(vector_items[0].key, node_key1); + EXPECT_EQ(vector_items[1].key, node_key2); + EXPECT_EQ(vector_items[2].key, node_key3); + EXPECT_TRUE(vector_items[0].vector == std::vector({1, 2, 3})); + EXPECT_TRUE(vector_items[1].vector == std::vector({4, 5, 6})); + EXPECT_TRUE(vector_items[2].vector == std::vector({7, 8, 9})); +} + +TEST_F(HnswIndexTest, SelectNeighbors) { + redis::VectorItem vec1; + auto status1 = redis::VectorItem::Create("1", {1.0, 1.0, 1.0}, hnsw_index->metadata, &vec1); + ASSERT_TRUE(status1.IsOK()); + + redis::VectorItem vec2; + auto status2 = redis::VectorItem::Create("2", {2.0, 2.0, 2.0}, hnsw_index->metadata, &vec2); + ASSERT_TRUE(status2.IsOK()); + + redis::VectorItem vec3; + auto status3 = 
redis::VectorItem::Create("3", {3.0, 3.0, 3.0}, hnsw_index->metadata, &vec3); + ASSERT_TRUE(status3.IsOK()); + + redis::VectorItem vec4; + auto status4 = redis::VectorItem::Create("4", {4.0, 4.0, 4.0}, hnsw_index->metadata, &vec4); + ASSERT_TRUE(status4.IsOK()); + + redis::VectorItem vec5; + auto status5 = redis::VectorItem::Create("5", {5.0, 5.0, 5.0}, hnsw_index->metadata, &vec5); + ASSERT_TRUE(status5.IsOK()); + + redis::VectorItem vec6; + auto status6 = redis::VectorItem::Create("6", {6.0, 6.0, 6.0}, hnsw_index->metadata, &vec6); + ASSERT_TRUE(status6.IsOK()); + + redis::VectorItem vec7; + auto status7 = redis::VectorItem::Create("7", {7.0, 7.0, 7.0}, hnsw_index->metadata, &vec7); + ASSERT_TRUE(status7.IsOK()); + + std::vector candidates = {vec3, vec2}; + auto s1 = hnsw_index->SelectNeighbors(vec1, candidates, 1); + ASSERT_TRUE(s1.IsOK()); + auto selected = s1.GetValue(); + EXPECT_EQ(selected.size(), candidates.size()); + + EXPECT_EQ(selected[0].key, vec2.key); + EXPECT_EQ(selected[1].key, vec3.key); + + candidates = {vec4, vec2, vec5, vec7, vec3, vec6}; + auto s2 = hnsw_index->SelectNeighbors(vec1, candidates, 1); + ASSERT_TRUE(s2.IsOK()); + selected = s2.GetValue(); + EXPECT_EQ(selected.size(), 3); + + EXPECT_EQ(selected[0].key, vec2.key); + EXPECT_EQ(selected[1].key, vec3.key); + EXPECT_EQ(selected[2].key, vec4.key); + + candidates = {vec4, vec2, vec5, vec7, vec3, vec6}; + auto s3 = hnsw_index->SelectNeighbors(vec1, candidates, 0); + ASSERT_TRUE(s3.IsOK()); + selected = s3.GetValue(); + EXPECT_EQ(selected.size(), 6); + + EXPECT_EQ(selected[0].key, vec2.key); + EXPECT_EQ(selected[1].key, vec3.key); + EXPECT_EQ(selected[2].key, vec4.key); + EXPECT_EQ(selected[3].key, vec5.key); + EXPECT_EQ(selected[4].key, vec6.key); + EXPECT_EQ(selected[5].key, vec7.key); +} + +TEST_F(HnswIndexTest, SearchLayer) { + uint16_t layer = 3; + std::string node_key1 = "node1"; + std::string node_key2 = "node2"; + std::string node_key3 = "node3"; + std::string node_key4 = "node4"; + 
std::string node_key5 = "node5"; + + redis::HnswNode node1(node_key1, layer); + redis::HnswNode node2(node_key2, layer); + redis::HnswNode node3(node_key3, layer); + redis::HnswNode node4(node_key4, layer); + redis::HnswNode node5(node_key5, layer); + + redis::HnswNodeFieldMetadata metadata1(0, {1.0, 2.0, 3.0}); + redis::HnswNodeFieldMetadata metadata2(0, {4.0, 5.0, 6.0}); + redis::HnswNodeFieldMetadata metadata3(0, {7.0, 8.0, 9.0}); + redis::HnswNodeFieldMetadata metadata4(0, {2.0, 3.0, 4.0}); + redis::HnswNodeFieldMetadata metadata5(0, {6.0, 6.0, 7.0}); + + // Add Nodes + auto batch = storage_->GetWriteBatchBase(); + node1.PutMetadata(&metadata1, hnsw_index->search_key, hnsw_index->storage, batch.Get()); + node2.PutMetadata(&metadata2, hnsw_index->search_key, hnsw_index->storage, batch.Get()); + node3.PutMetadata(&metadata3, hnsw_index->search_key, hnsw_index->storage, batch.Get()); + node4.PutMetadata(&metadata4, hnsw_index->search_key, hnsw_index->storage, batch.Get()); + node5.PutMetadata(&metadata5, hnsw_index->search_key, hnsw_index->storage, batch.Get()); + auto s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); + ASSERT_TRUE(s.ok()); + + // Add Neighbours + batch = storage_->GetWriteBatchBase(); + auto s1 = node1.AddNeighbour("node2", hnsw_index->search_key, hnsw_index->storage, batch.Get()); + ASSERT_TRUE(s1.IsOK()); + auto s2 = node1.AddNeighbour("node4", hnsw_index->search_key, hnsw_index->storage, batch.Get()); + ASSERT_TRUE(s2.IsOK()); + auto s3 = node2.AddNeighbour("node1", hnsw_index->search_key, hnsw_index->storage, batch.Get()); + ASSERT_TRUE(s3.IsOK()); + auto s4 = node2.AddNeighbour("node3", hnsw_index->search_key, hnsw_index->storage, batch.Get()); + ASSERT_TRUE(s4.IsOK()); + auto s5 = node3.AddNeighbour("node2", hnsw_index->search_key, hnsw_index->storage, batch.Get()); + ASSERT_TRUE(s5.IsOK()); + auto s6 = node3.AddNeighbour("node5", hnsw_index->search_key, hnsw_index->storage, batch.Get()); +
ASSERT_TRUE(s6.IsOK()); + auto s7 = node4.AddNeighbour("node1", hnsw_index->search_key, hnsw_index->storage, batch.Get()); + ASSERT_TRUE(s7.IsOK()); + auto s8 = node5.AddNeighbour("node3", hnsw_index->search_key, hnsw_index->storage, batch.Get()); + ASSERT_TRUE(s8.IsOK()); + s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); + ASSERT_TRUE(s.ok()); + + redis::VectorItem target_vector; + auto status = redis::VectorItem::Create("target", {2.0, 3.0, 4.0}, hnsw_index->metadata, &target_vector); + ASSERT_TRUE(status.IsOK()); + + // Test with multiple entry points + std::vector entry_points = {"node3", "node2"}; + uint32_t ef_runtime = 3; + + auto s9 = hnsw_index->SearchLayer(layer, target_vector, ef_runtime, entry_points); + ASSERT_TRUE(s9.IsOK()); + auto candidates = s9.GetValue(); + + ASSERT_EQ(candidates.size(), ef_runtime); + EXPECT_EQ(candidates[0].key, "node4"); + EXPECT_EQ(candidates[1].key, "node1"); + EXPECT_EQ(candidates[2].key, "node2"); + + // Test with a single entry point + entry_points = {"node5"}; + auto s10 = hnsw_index->SearchLayer(layer, target_vector, ef_runtime, entry_points); + ASSERT_TRUE(s10.IsOK()); + candidates = s10.GetValue(); + + ASSERT_EQ(candidates.size(), ef_runtime); + EXPECT_EQ(candidates[0].key, "node4"); + EXPECT_EQ(candidates[1].key, "node1"); + EXPECT_EQ(candidates[2].key, "node2"); + + // Test with different ef_runtime + ef_runtime = 10; + auto s11 = hnsw_index->SearchLayer(layer, target_vector, ef_runtime, entry_points); + ASSERT_TRUE(s11.IsOK()); + candidates = s11.GetValue(); + + ASSERT_EQ(candidates.size(), 5); + EXPECT_EQ(candidates[0].key, "node4"); + EXPECT_EQ(candidates[1].key, "node1"); + EXPECT_EQ(candidates[2].key, "node2"); + EXPECT_EQ(candidates[3].key, "node5"); + EXPECT_EQ(candidates[4].key, "node3"); +} + +TEST_F(HnswIndexTest, InsertAndDeleteVectorEntry) { + std::vector vec1 = {11.0, 12.0, 13.0}; + std::vector vec2 = {14.0, 15.0, 16.0}; + std::vector vec3 = {17.0, 18.0, 19.0}; + 
std::vector vec4 = {12.0, 13.0, 14.0}; + std::vector vec5 = {15.0, 16.0, 17.0}; + + std::string key1 = "n1"; + std::string key2 = "n2"; + std::string key3 = "n3"; + std::string key4 = "n4"; + std::string key5 = "n5"; + + // Insert n1 into layer 1 + uint16_t target_level = 1; + auto batch = storage_->GetWriteBatchBase(); + auto s1 = hnsw_index->InsertVectorEntryInternal(key1, vec1, batch, target_level); + ASSERT_TRUE(s1.IsOK()); + auto s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); + ASSERT_TRUE(s.ok()); + + rocksdb::PinnableSlice value; + auto index_meta_key = hnsw_index->search_key.ConstructFieldMeta(); + s = storage_->Get(rocksdb::ReadOptions(), hnsw_index->storage->GetCFHandle(ColumnFamilyID::Search), index_meta_key, + &value); + ASSERT_TRUE(s.ok()); + redis::HnswVectorFieldMetadata decoded_metadata; + decoded_metadata.Decode(&value); + ASSERT_TRUE(decoded_metadata.num_levels == 2); + + redis::HnswNode node1_layer0(key1, 0); + auto s2 = node1_layer0.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); + ASSERT_TRUE(s2.IsOK()); + redis::HnswNodeFieldMetadata node1_layer0_meta = s2.GetValue(); + EXPECT_EQ(node1_layer0_meta.num_neighbours, 0); + + redis::HnswNode node1_layer1(key1, 1); + auto s3 = node1_layer1.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); + ASSERT_TRUE(s3.IsOK()); + redis::HnswNodeFieldMetadata node1_layer1_meta = s3.GetValue(); + EXPECT_EQ(node1_layer1_meta.num_neighbours, 0); + + // Insert n2 into layer 3 + batch = storage_->GetWriteBatchBase(); + target_level = 3; + auto s4 = hnsw_index->InsertVectorEntryInternal(key2, vec2, batch, target_level); + ASSERT_TRUE(s4.IsOK()); + s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); + ASSERT_TRUE(s.ok()); + + index_meta_key = hnsw_index->search_key.ConstructFieldMeta(); + s = storage_->Get(rocksdb::ReadOptions(), hnsw_index->storage->GetCFHandle(ColumnFamilyID::Search), index_meta_key, + &value); + ASSERT_TRUE(s.ok()); +
decoded_metadata.Decode(&value); + ASSERT_TRUE(decoded_metadata.num_levels == 4); + + node1_layer0.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage); + EXPECT_EQ(node1_layer0.neighbours.size(), 1); + EXPECT_EQ(node1_layer0.neighbours[0], "n2"); + + node1_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage); + EXPECT_EQ(node1_layer1.neighbours.size(), 1); + EXPECT_EQ(node1_layer1.neighbours[0], "n2"); + + redis::HnswNode node2_layer0(key2, 0); + node2_layer0.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage); + EXPECT_EQ(node2_layer0.neighbours.size(), 1); + EXPECT_EQ(node2_layer0.neighbours[0], "n1"); + + redis::HnswNode node2_layer1(key2, 1); + node2_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage); + EXPECT_EQ(node2_layer1.neighbours.size(), 1); + EXPECT_EQ(node2_layer1.neighbours[0], "n1"); + + redis::HnswNode node2_layer2(key2, 2); + auto s5 = node2_layer2.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); + ASSERT_TRUE(s5.IsOK()); + redis::HnswNodeFieldMetadata node2_layer2_meta = s5.GetValue(); + EXPECT_EQ(node2_layer2_meta.num_neighbours, 0); + + redis::HnswNode node2_layer3(key2, 3); + auto s6 = node2_layer3.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); + ASSERT_TRUE(s6.IsOK()); + redis::HnswNodeFieldMetadata node2_layer3_meta = s6.GetValue(); + EXPECT_EQ(node2_layer3_meta.num_neighbours, 0); + + // Insert n3 into layer 2 + batch = storage_->GetWriteBatchBase(); + target_level = 2; + auto s7 = hnsw_index->InsertVectorEntryInternal(key3, vec3, batch, target_level); + ASSERT_TRUE(s7.IsOK()); + s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); + ASSERT_TRUE(s.ok()); + + index_meta_key = hnsw_index->search_key.ConstructFieldMeta(); + s = storage_->Get(rocksdb::ReadOptions(), hnsw_index->storage->GetCFHandle(ColumnFamilyID::Search), index_meta_key, + &value); + ASSERT_TRUE(s.ok()); + decoded_metadata.Decode(&value); + 
ASSERT_TRUE(decoded_metadata.num_levels == 4); + + redis::HnswNode node3_layer2(key3, target_level); + auto s8 = node3_layer2.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); + ASSERT_TRUE(s8.IsOK()); + redis::HnswNodeFieldMetadata node3_layer2_meta = s8.GetValue(); + EXPECT_EQ(node3_layer2_meta.num_neighbours, 1); + node3_layer2.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage); + EXPECT_EQ(node3_layer2.neighbours.size(), 1); + EXPECT_EQ(node3_layer2.neighbours[0], "n2"); + + redis::HnswNode node3_layer1(key3, 1); + auto s9 = node3_layer1.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); + ASSERT_TRUE(s9.IsOK()); + redis::HnswNodeFieldMetadata node3_layer1_meta = s9.GetValue(); + EXPECT_EQ(node3_layer1_meta.num_neighbours, 2); + node3_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage); + EXPECT_EQ(node3_layer1.neighbours.size(), 2); + std::unordered_set expected_set = {"n1", "n2"}; + std::unordered_set actual_set{node3_layer1.neighbours.begin(), node3_layer1.neighbours.end()}; + EXPECT_EQ(actual_set, expected_set); + + // Insert n4 into layer 1 + batch = storage_->GetWriteBatchBase(); + target_level = 1; + auto s10 = hnsw_index->InsertVectorEntryInternal(key4, vec4, batch, target_level); + ASSERT_TRUE(s10.IsOK()); + s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); + ASSERT_TRUE(s.ok()); + + redis::HnswNode node4_layer0(key4, 0); + auto s11 = node4_layer0.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); + ASSERT_TRUE(s11.IsOK()); + redis::HnswNodeFieldMetadata node4_layer0_meta = s11.GetValue(); + EXPECT_EQ(node4_layer0_meta.num_neighbours, 3); + + auto s12 = node1_layer1.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); + ASSERT_TRUE(s12.IsOK()); + node1_layer1_meta = s12.GetValue(); + EXPECT_EQ(node1_layer1_meta.num_neighbours, 3); + node1_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage); + expected_set = {"n2", "n3", "n4"}; + actual_set = 
{node1_layer1.neighbours.begin(), node1_layer1.neighbours.end()}; + EXPECT_EQ(actual_set, expected_set); + + auto s13 = node2_layer1.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); + ASSERT_TRUE(s13.IsOK()); + auto node2_layer1_meta = s13.GetValue(); + EXPECT_EQ(node2_layer1_meta.num_neighbours, 3); + node2_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage); + expected_set = {"n1", "n3", "n4"}; + actual_set = {node2_layer1.neighbours.begin(), node2_layer1.neighbours.end()}; + EXPECT_EQ(actual_set, expected_set); + + auto s14 = node3_layer1.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); + ASSERT_TRUE(s14.IsOK()); + node3_layer1_meta = s14.GetValue(); + EXPECT_EQ(node3_layer1_meta.num_neighbours, 3); + node3_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage); + expected_set = {"n1", "n2", "n4"}; + actual_set = {node3_layer1.neighbours.begin(), node3_layer1.neighbours.end()}; + EXPECT_EQ(actual_set, expected_set); + + // Insert n5 into layer 1 + batch = storage_->GetWriteBatchBase(); + auto s15 = hnsw_index->InsertVectorEntryInternal(key5, vec5, batch, target_level); + ASSERT_TRUE(s15.IsOK()); + s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); + ASSERT_TRUE(s.ok()); + + auto s16 = node2_layer1.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); + ASSERT_TRUE(s16.IsOK()); + node2_layer1_meta = s16.GetValue(); + EXPECT_EQ(node2_layer1_meta.num_neighbours, 3); + node2_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage); + expected_set = {"n1", "n4", "n5"}; + actual_set = {node2_layer1.neighbours.begin(), node2_layer1.neighbours.end()}; + EXPECT_EQ(actual_set, expected_set); + + auto s17 = node3_layer1.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); + ASSERT_TRUE(s17.IsOK()); + node3_layer1_meta = s17.GetValue(); + EXPECT_EQ(node3_layer1_meta.num_neighbours, 2); + node3_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage); + 
expected_set = {"n1", "n5"}; + actual_set = {node3_layer1.neighbours.begin(), node3_layer1.neighbours.end()}; + EXPECT_EQ(actual_set, expected_set); + + redis::HnswNode node4_layer1(key4, 1); + auto s18 = node4_layer1.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); + ASSERT_TRUE(s18.IsOK()); + auto node4_layer1_meta = s18.GetValue(); + EXPECT_EQ(node4_layer1_meta.num_neighbours, 3); + node4_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage); + expected_set = {"n1", "n2", "n5"}; + actual_set = {node4_layer1.neighbours.begin(), node4_layer1.neighbours.end()}; + EXPECT_EQ(actual_set, expected_set); + + redis::HnswNode node5_layer1(key5, 1); + auto s19 = node5_layer1.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); + ASSERT_TRUE(s19.IsOK()); + auto node5_layer1_meta = s19.GetValue(); + EXPECT_EQ(node5_layer1_meta.num_neighbours, 3); + node5_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage); + expected_set = {"n2", "n3", "n4"}; + actual_set = {node5_layer1.neighbours.begin(), node5_layer1.neighbours.end()}; + EXPECT_EQ(actual_set, expected_set); + + auto s20 = node1_layer0.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); + ASSERT_TRUE(s20.IsOK()); + node1_layer0_meta = s20.GetValue(); + EXPECT_EQ(node1_layer0_meta.num_neighbours, 4); + node1_layer0.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage); + expected_set = {"n2", "n3", "n4", "n5"}; + actual_set = {node1_layer0.neighbours.begin(), node1_layer0.neighbours.end()}; + EXPECT_EQ(actual_set, expected_set); + + redis::HnswNode node5_layer0(key5, 0); + auto s21 = node5_layer0.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); + ASSERT_TRUE(s21.IsOK()); + auto node5_layer0_meta = s21.GetValue(); + EXPECT_EQ(node5_layer0_meta.num_neighbours, 4); + node5_layer0.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage); + expected_set = {"n1", "n2", "n3", "n4"}; + actual_set = {node5_layer0.neighbours.begin(), 
node5_layer0.neighbours.end()}; + EXPECT_EQ(actual_set, expected_set); + + // Delete n2 + batch = storage_->GetWriteBatchBase(); + auto s22 = hnsw_index->DeleteVectorEntry(key2, batch); + ASSERT_TRUE(s22.IsOK()); + s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); + ASSERT_TRUE(s.ok()); + + index_meta_key = hnsw_index->search_key.ConstructFieldMeta(); + s = storage_->Get(rocksdb::ReadOptions(), hnsw_index->storage->GetCFHandle(ColumnFamilyID::Search), index_meta_key, + &value); + ASSERT_TRUE(s.ok()); + decoded_metadata.Decode(&value); + ASSERT_TRUE(decoded_metadata.num_levels == 3); + + auto s23 = node2_layer3.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); + EXPECT_TRUE(!s23.IsOK()); + + auto s24 = node2_layer2.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); + EXPECT_TRUE(!s24.IsOK()); + + auto s25 = node2_layer1.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); + EXPECT_TRUE(!s25.IsOK()); + + auto s26 = node2_layer0.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); + EXPECT_TRUE(!s26.IsOK()); + + auto s27 = node3_layer2.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); + ASSERT_TRUE(s27.IsOK()); + node3_layer2_meta = s27.GetValue(); + EXPECT_EQ(node3_layer2_meta.num_neighbours, 0); + + auto s28 = node1_layer1.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); + ASSERT_TRUE(s28.IsOK()); + node1_layer1_meta = s28.GetValue(); + EXPECT_EQ(node1_layer1_meta.num_neighbours, 2); + node1_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage); + expected_set = {"n3", "n4"}; + actual_set = {node1_layer1.neighbours.begin(), node1_layer1.neighbours.end()}; + EXPECT_EQ(actual_set, expected_set); + + auto s29 = node3_layer1.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); + ASSERT_TRUE(s29.IsOK()); + node3_layer1_meta = s29.GetValue(); + EXPECT_EQ(node3_layer1_meta.num_neighbours, 2); + node3_layer1.DecodeNeighbours(hnsw_index->search_key, 
hnsw_index->storage); + expected_set = {"n1", "n5"}; + actual_set = {node3_layer1.neighbours.begin(), node3_layer1.neighbours.end()}; + EXPECT_EQ(actual_set, expected_set); + + auto s30 = node4_layer1.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); + ASSERT_TRUE(s30.IsOK()); + node4_layer1_meta = s30.GetValue(); + EXPECT_EQ(node4_layer1_meta.num_neighbours, 2); + node4_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage); + expected_set = {"n1", "n5"}; + actual_set = {node4_layer1.neighbours.begin(), node4_layer1.neighbours.end()}; + EXPECT_EQ(actual_set, expected_set); + + auto s31 = node5_layer1.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); + ASSERT_TRUE(s31.IsOK()); + node5_layer1_meta = s31.GetValue(); + EXPECT_EQ(node5_layer1_meta.num_neighbours, 2); + node5_layer1.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage); + expected_set = {"n3", "n4"}; + actual_set = {node5_layer1.neighbours.begin(), node5_layer1.neighbours.end()}; + EXPECT_EQ(actual_set, expected_set); + + auto s32 = node1_layer0.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); + ASSERT_TRUE(s32.IsOK()); + node1_layer0_meta = s32.GetValue(); + EXPECT_EQ(node1_layer0_meta.num_neighbours, 3); + node1_layer0.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage); + expected_set = {"n3", "n4", "n5"}; + actual_set = {node1_layer0.neighbours.begin(), node1_layer0.neighbours.end()}; + EXPECT_EQ(actual_set, expected_set); + + redis::HnswNode node3_layer0(key3, 0); + auto s33 = node3_layer0.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); + ASSERT_TRUE(s33.IsOK()); + auto node3_layer0_meta = s33.GetValue(); + EXPECT_EQ(node3_layer0_meta.num_neighbours, 3); + node3_layer0.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage); + expected_set = {"n1", "n4", "n5"}; + actual_set = {node3_layer0.neighbours.begin(), node3_layer0.neighbours.end()}; + EXPECT_EQ(actual_set, expected_set); + + auto s34 = 
node4_layer0.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); + ASSERT_TRUE(s34.IsOK()); + node4_layer0_meta = s34.GetValue(); + EXPECT_EQ(node4_layer0_meta.num_neighbours, 3); + node4_layer0.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage); + expected_set = {"n1", "n3", "n5"}; + actual_set = {node4_layer0.neighbours.begin(), node4_layer0.neighbours.end()}; + EXPECT_EQ(actual_set, expected_set); + + auto s35 = node5_layer0.DecodeMetadata(hnsw_index->search_key, hnsw_index->storage); + ASSERT_TRUE(s35.IsOK()); + node5_layer0_meta = s35.GetValue(); + EXPECT_EQ(node5_layer0_meta.num_neighbours, 3); + node5_layer0.DecodeNeighbours(hnsw_index->search_key, hnsw_index->storage); + expected_set = {"n1", "n3", "n4"}; + actual_set = {node5_layer0.neighbours.begin(), node5_layer0.neighbours.end()}; + EXPECT_EQ(actual_set, expected_set); +} diff --git a/tests/cppunit/hnsw_node_test.cc b/tests/cppunit/hnsw_node_test.cc new file mode 100644 index 00000000000..5fadf9927a5 --- /dev/null +++ b/tests/cppunit/hnsw_node_test.cc @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ * + */ + +#include +#include +#include + +#include +#include +#include + +#include "search/hnsw_indexer.h" +#include "search/indexer.h" +#include "search/search_encoding.h" +#include "storage/storage.h" + +struct NodeTest : public TestBase { + std::string ns = "hnsw_node_test_ns"; + std::string idx_name = "hnsw_node_test_idx"; + std::string key = "vector"; + redis::SearchKey search_key; + + NodeTest() : search_key(ns, idx_name, key) {} + + void TearDown() override {} +}; + +TEST_F(NodeTest, PutAndDecodeMetadata) { + uint16_t layer = 0; + redis::HnswNode node1("node1", layer); + redis::HnswNode node2("node2", layer); + redis::HnswNode node3("node3", layer); + + redis::HnswNodeFieldMetadata metadata1(0, {1, 2, 3}); + redis::HnswNodeFieldMetadata metadata2(0, {4, 5, 6}); + redis::HnswNodeFieldMetadata metadata3(0, {7, 8, 9}); + + auto batch = storage_->GetWriteBatchBase(); + node1.PutMetadata(&metadata1, search_key, storage_.get(), batch.Get()); + node2.PutMetadata(&metadata2, search_key, storage_.get(), batch.Get()); + node3.PutMetadata(&metadata3, search_key, storage_.get(), batch.Get()); + auto s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); + ASSERT_TRUE(s.ok()); + + auto decoded_metadata1 = node1.DecodeMetadata(search_key, storage_.get()); + ASSERT_TRUE(decoded_metadata1.IsOK()); + ASSERT_EQ(decoded_metadata1.GetValue().num_neighbours, 0); + ASSERT_EQ(decoded_metadata1.GetValue().vector, std::vector({1, 2, 3})); + + auto decoded_metadata2 = node2.DecodeMetadata(search_key, storage_.get()); + ASSERT_TRUE(decoded_metadata2.IsOK()); + ASSERT_EQ(decoded_metadata2.GetValue().num_neighbours, 0); + ASSERT_EQ(decoded_metadata2.GetValue().vector, std::vector({4, 5, 6})); + + auto decoded_metadata3 = node3.DecodeMetadata(search_key, storage_.get()); + ASSERT_TRUE(decoded_metadata3.IsOK()); + ASSERT_EQ(decoded_metadata3.GetValue().num_neighbours, 0); + ASSERT_EQ(decoded_metadata3.GetValue().vector, std::vector({7, 8, 9})); + + // Prepare 
edges between node1 and node2 + batch = storage_->GetWriteBatchBase(); + auto edge1 = search_key.ConstructHnswEdge(layer, "node1", "node2"); + auto edge2 = search_key.ConstructHnswEdge(layer, "node2", "node1"); + auto edge3 = search_key.ConstructHnswEdge(layer, "node2", "node3"); + auto edge4 = search_key.ConstructHnswEdge(layer, "node3", "node2"); + + batch->Put(storage_->GetCFHandle(ColumnFamilyID::Search), edge1, Slice()); + batch->Put(storage_->GetCFHandle(ColumnFamilyID::Search), edge2, Slice()); + batch->Put(storage_->GetCFHandle(ColumnFamilyID::Search), edge3, Slice()); + batch->Put(storage_->GetCFHandle(ColumnFamilyID::Search), edge4, Slice()); + s = storage_->Write(storage_->DefaultWriteOptions(), batch->GetWriteBatch()); + ASSERT_TRUE(s.ok()); + + node1.DecodeNeighbours(search_key, storage_.get()); + EXPECT_EQ(node1.neighbours.size(), 1); + EXPECT_EQ(node1.neighbours[0], "node2"); + + node2.DecodeNeighbours(search_key, storage_.get()); + EXPECT_EQ(node2.neighbours.size(), 2); + std::unordered_set expected_neighbours = {"node1", "node3"}; + std::unordered_set actual_neighbours(node2.neighbours.begin(), node2.neighbours.end()); + EXPECT_EQ(actual_neighbours, expected_neighbours); + + node3.DecodeNeighbours(search_key, storage_.get()); + EXPECT_EQ(node3.neighbours.size(), 1); + EXPECT_EQ(node3.neighbours[0], "node2"); +} + +TEST_F(NodeTest, ModifyNeighbours) { + uint16_t layer = 1; + redis::HnswNode node1("node1", layer); + redis::HnswNode node2("node2", layer); + redis::HnswNode node3("node3", layer); + redis::HnswNode node4("node4", layer); + + redis::HnswNodeFieldMetadata metadata1(0, {1, 2, 3}); + redis::HnswNodeFieldMetadata metadata2(0, {4, 5, 6}); + redis::HnswNodeFieldMetadata metadata3(0, {7, 8, 9}); + redis::HnswNodeFieldMetadata metadata4(0, {10, 11, 12}); + + // Add Nodes + auto batch1 = storage_->GetWriteBatchBase(); + node1.PutMetadata(&metadata1, search_key, storage_.get(), batch1.Get()); + node2.PutMetadata(&metadata2, search_key, 
storage_.get(), batch1.Get()); + node3.PutMetadata(&metadata3, search_key, storage_.get(), batch1.Get()); + node4.PutMetadata(&metadata4, search_key, storage_.get(), batch1.Get()); + auto s = storage_->Write(storage_->DefaultWriteOptions(), batch1->GetWriteBatch()); + ASSERT_TRUE(s.ok()); + + // Add Edges + auto batch2 = storage_->GetWriteBatchBase(); + auto s1 = node1.AddNeighbour("node2", search_key, storage_.get(), batch2.Get()); + ASSERT_TRUE(s1.IsOK()); + auto s2 = node2.AddNeighbour("node1", search_key, storage_.get(), batch2.Get()); + ASSERT_TRUE(s2.IsOK()); + auto s3 = node2.AddNeighbour("node3", search_key, storage_.get(), batch2.Get()); + ASSERT_TRUE(s3.IsOK()); + auto s4 = node3.AddNeighbour("node2", search_key, storage_.get(), batch2.Get()); + ASSERT_TRUE(s4.IsOK()); + s = storage_->Write(storage_->DefaultWriteOptions(), batch2->GetWriteBatch()); + ASSERT_TRUE(s.ok()); + + node1.DecodeNeighbours(search_key, storage_.get()); + EXPECT_EQ(node1.neighbours.size(), 1); + EXPECT_EQ(node1.neighbours[0], "node2"); + + node2.DecodeNeighbours(search_key, storage_.get()); + EXPECT_EQ(node2.neighbours.size(), 2); + std::unordered_set expected_neighbours = {"node1", "node3"}; + std::unordered_set actual_neighbours(node2.neighbours.begin(), node2.neighbours.end()); + EXPECT_EQ(actual_neighbours, expected_neighbours); + + node3.DecodeNeighbours(search_key, storage_.get()); + EXPECT_EQ(node3.neighbours.size(), 1); + EXPECT_EQ(node3.neighbours[0], "node2"); + + // Remove Edges + auto batch3 = storage_->GetWriteBatchBase(); + auto s5 = node2.RemoveNeighbour("node3", search_key, storage_.get(), batch3.Get()); + ASSERT_TRUE(s5.IsOK()); + + s = storage_->Write(storage_->DefaultWriteOptions(), batch3->GetWriteBatch()); + ASSERT_TRUE(s.ok()); + + node2.DecodeNeighbours(search_key, storage_.get()); + EXPECT_EQ(node2.neighbours.size(), 1); + EXPECT_EQ(node2.neighbours[0], "node1"); +}