From f31851a25071a9fdfa988e240165df7c6e1a4960 Mon Sep 17 00:00:00 2001 From: joel amming Date: Sun, 19 Oct 2025 20:56:07 +0100 Subject: [PATCH 1/2] HNSW fixes and tests --- CHANGELOG.md | 3 + tok/hnsw/ef_recall_test.go | 183 +++++++++++++++++++++++++++++++ tok/hnsw/persistent_hnsw.go | 211 ++++++++++++++++++++++++++++++++---- tok/index/index.go | 29 +++++ tok/index/search_path.go | 12 +- worker/task.go | 70 ++++++++++-- 6 files changed, 473 insertions(+), 35 deletions(-) create mode 100644 tok/hnsw/ef_recall_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index a9ada00e540..fcadea6faaf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,9 @@ adhere to [Semantic Versioning](https://semver.org) starting `v22.0.0`. - **Fixed** - fix(core): fix panic in verifyUniqueWithinMutation when mutation is conditionally pruned (#9450) - fix(query): return full float value in query results (#9492) +- **Vector** + - fix(vector/hnsw): correct early termination in bottom-layer search to ensure at least k candidates are considered before breaking + - feat(vector/hnsw): add optional per-query controls to similar_to via a 4th argument: `ef` (search breadth override) and `distance_threshold` (metric-domain cutoff); defaults unchanged ## [v24.X.X] - YYYY-MM-DD diff --git a/tok/hnsw/ef_recall_test.go b/tok/hnsw/ef_recall_test.go new file mode 100644 index 00000000000..d0a0c88c85d --- /dev/null +++ b/tok/hnsw/ef_recall_test.go @@ -0,0 +1,183 @@ +/* + * SPDX-FileCopyrightText: © Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package hnsw + +import ( + "context" + "encoding/binary" + "math" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/hypermodeinc/dgraph/v25/tok/index" + opt "github.com/hypermodeinc/dgraph/v25/tok/options" + "github.com/hypermodeinc/dgraph/v25/x" +) + +// memoryCache satisfies index.CacheType for synthetic tests. 
+type memoryCache struct { + data map[string][]byte +} + +func (m *memoryCache) Get(key []byte) ([]byte, error) { + if val, ok := m.data[string(key)]; ok { + return val, nil + } + return nil, nil +} + +func (m *memoryCache) Ts() uint64 { return 0 } + +func (m *memoryCache) Find([]byte, func([]byte) bool) (uint64, error) { return 0, nil } + +func float64ArrayAsBytes(v []float64) []byte { + buf := make([]byte, 8*len(v)) + for i, f := range v { + binary.LittleEndian.PutUint64(buf[i*8:], math.Float64bits(f)) + } + return buf +} + +// Test that EfOverride widens the bottom-layer candidate set and improves recall on a tiny graph. +func TestHNSWSearchEfOverrideImprovesRecall(t *testing.T) { + ctx := context.Background() + + factory := CreateFactory[float64](64) + options := opt.NewOptions() + options.SetOpt(MaxLevelsOpt, 2) + options.SetOpt(EfSearchOpt, 1) + options.SetOpt(MetricOpt, GetSimType[float64](Euclidean, 64)) + + predName := "joefix_pred" + predWithNamespace := x.NamespaceAttr(x.RootNamespace, predName) + + rawIdx, err := factory.Create(predWithNamespace, options, 64) + require.NoError(t, err) + + // Use concrete type directly (same package) to set up a tiny synthetic graph. + ph, ok := rawIdx.(*persistentHNSW[float64]) + require.True(t, ok) + require.Equal(t, predWithNamespace, ph.pred) + + // Populate vectors in memory via cache data map keyed by DataKey. + vectors := map[uint64][]float64{ + 1: {0, 0, 10, 0}, // entry + 100: {0, 0, 0.1, 0}, // true nearest to query + 200: {0, 0, 3, 0}, // local minimum path + 201: {0, 0, 3.2, 0}, + } + + data := make(map[string][]byte) + for uid, vec := range vectors { + key := string(DataKey(ph.pred, uid)) + data[key] = float64ArrayAsBytes(vec) + } + + // Set entry pointer to uid 1. + entryKey := string(DataKey(ph.vecEntryKey, 1)) + data[entryKey] = Uint64ToBytes(1) + + // Wire a small graph that requires wider search to find uid 100 from entry 1. 
+ ph.nodeAllEdges[1] = [][]uint64{{}, {200, 201}} + ph.nodeAllEdges[200] = [][]uint64{{1}, {1}} + ph.nodeAllEdges[201] = [][]uint64{{1}, {100}} + ph.nodeAllEdges[100] = [][]uint64{{201}, {201}} + + cache := &memoryCache{data: data} + + // Narrow ef behaves like legacy path: returns uid 200 for k=1. + narrow, err := ph.SearchWithOptions(ctx, cache, []float64{0, 0, 0.12, 0}, 1, index.VectorIndexOptions[float64]{}) + require.NoError(t, err) + require.Equal(t, []uint64{200}, narrow) + + // Wider ef surfaces the closer neighbor uid 100. + wide, err := ph.SearchWithOptions(ctx, cache, []float64{0, 0, 0.12, 0}, 1, index.VectorIndexOptions[float64]{EfOverride: 4}) + require.NoError(t, err) + require.Equal(t, []uint64{100}, wide) +} + +// Test Euclidean distance_threshold filters out results with squared distance above threshold. +func TestHNSWDistanceThreshold_Euclidean(t *testing.T) { + ctx := context.Background() + + factory := CreateFactory[float64](64) + options := opt.NewOptions() + options.SetOpt(MaxLevelsOpt, 1) + options.SetOpt(EfSearchOpt, 10) + options.SetOpt(MetricOpt, GetSimType[float64](Euclidean, 64)) + + pred := x.NamespaceAttr(x.RootNamespace, "thresh_pred_e") + rawIdx, err := factory.Create(pred, options, 64) + require.NoError(t, err) + ph := rawIdx.(*persistentHNSW[float64]) + + // Two vectors at known Euclidean distances from query. + // query q = (0,0), a=(0.6,0), b=(0.8,0) + // dist(q,a)=0.6, dist(q,b)=0.8 + data := map[string][]byte{ + string(DataKey(pred, 1)): float64ArrayAsBytes([]float64{0.6, 0}), + string(DataKey(pred, 2)): float64ArrayAsBytes([]float64{0.8, 0}), + string(DataKey(ph.vecEntryKey, 1)): Uint64ToBytes(1), + } + // Single-layer edges; ensure both are reachable from entry. 
+ ph.nodeAllEdges[1] = [][]uint64{{1, 2}} + ph.nodeAllEdges[2] = [][]uint64{{1}} + + cache := &memoryCache{data: data} + q := []float64{0, 0} + + // With current internal Euclidean values, use threshold 0.8 so that + // uid 1 (0.6) is included and uid 2 (0.8) is excluded. + th := 0.8 + res, err := ph.SearchWithOptions(ctx, cache, q, 10, index.VectorIndexOptions[float64]{ + DistanceThreshold: &th, + EfOverride: 10, + }) + require.NoError(t, err) + require.Equal(t, []uint64{1}, res) +} + +// Test Cosine distance_threshold uses distance d = 1 - cosine_similarity. +func TestHNSWDistanceThreshold_Cosine(t *testing.T) { + ctx := context.Background() + + factory := CreateFactory[float64](64) + options := opt.NewOptions() + options.SetOpt(MaxLevelsOpt, 1) + options.SetOpt(EfSearchOpt, 10) + options.SetOpt(MetricOpt, GetSimType[float64](Cosine, 64)) + + pred := x.NamespaceAttr(x.RootNamespace, "thresh_pred_c") + rawIdx, err := factory.Create(pred, options, 64) + require.NoError(t, err) + ph := rawIdx.(*persistentHNSW[float64]) + + // Query q is unit along x-axis. 
+ // a is exact match (cos sim 1.0, distance 0.0) + // b is 36.87 degrees (~cos 0.8, distance 0.2) + data := map[string][]byte{ + string(DataKey(pred, 1)): float64ArrayAsBytes([]float64{1, 0}), + string(DataKey(pred, 2)): float64ArrayAsBytes([]float64{0.8, 0.6}), + string(DataKey(ph.vecEntryKey, 1)): Uint64ToBytes(1), + } + ph.nodeAllEdges[1] = [][]uint64{{1, 2}} + ph.nodeAllEdges[2] = [][]uint64{{1}} + + cache := &memoryCache{data: data} + q := []float64{1, 0} + + // distance_threshold=0.1 should include uid 1 but exclude uid 2 (0.2 > 0.1) + th := 0.1 + res, err := ph.SearchWithOptions(ctx, cache, q, 10, index.VectorIndexOptions[float64]{ + DistanceThreshold: &th, + EfOverride: 10, + }) + require.NoError(t, err) + require.Equal(t, []uint64{1}, res) +} + + diff --git a/tok/hnsw/persistent_hnsw.go b/tok/hnsw/persistent_hnsw.go index e13ddddaf89..2faa5b26c58 100644 --- a/tok/hnsw/persistent_hnsw.go +++ b/tok/hnsw/persistent_hnsw.go @@ -28,8 +28,8 @@ type persistentHNSW[T c.Float] struct { vecDead string simType SimilarityType[T] floatBits int - // nodeAllEdges[65443][1][3] indicates the 3rd neighbor in the first - // layer for uuid 65443. The result will be a neighboring uuid. + // nodeAllEdges[65443][1][3] indicates the 3rd neighbour in the first + // layer for UUID 65443. The result will be a neighbouring UUID. nodeAllEdges map[uint64][][]uint64 deadNodes map[uint64]struct{} } @@ -120,7 +120,7 @@ func (ph *persistentHNSW[T]) emptySearchResultWithError(e error) (*searchLayerRe return newLayerResult[T](0), e } -// fillNeighborEdges(uuid, c, edges) will "fill" edges with the neighbors for +// fillNeighborEdges(uuid, c, edges) will "fill" edges with the neighbours for // all levels associated with given uuid and CacheType. // It returns true when we were able to find the node (either in cache or // in persistent store) and false otherwise. 
@@ -145,9 +145,9 @@ func (ph *persistentHNSW[T]) fillNeighborEdges(uuid uint64, c index.CacheType, e return true, nil } -// searchPersistentLayer searches a layer of the hnsw graph for the nearest -// neighbors of the query vector and returns the traversal path and the nearest -// neighbors +// searchPersistentLayer searches a layer of the HNSW graph for the nearest +// neighbours of the query vector and returns the traversal path and the nearest +// neighbours func (ph *persistentHNSW[T]) searchPersistentLayer( c index.CacheType, level int, @@ -177,9 +177,9 @@ func (ph *persistentHNSW[T]) searchPersistentLayer( //create set using map to append to on future visited nodes for candidateHeap.Len() != 0 { currCandidate := candidateHeap.Pop().(minPersistentHeapElement[T]) - if r.numNeighbors() < expectedNeighbors && + if r.numNeighbors() >= expectedNeighbors && ph.simType.isBetterScore(r.lastNeighborScore(), currCandidate.value) { - // If the "worst score" in our neighbors list is deemed to have + // If the "worst score" in our neighbours list is deemed to have // a better score than the current candidate -- and if we have at // least our expected number of nearest results -- we discontinue // the search. @@ -223,12 +223,12 @@ func (ph *persistentHNSW[T]) searchPersistentLayer( // If we have not yet found k candidates, we can consider // any candidate. Otherwise, only consider those that - // are better than our current k nearest neighbors. + // are better than our current k nearest neighbours. // Note that the "numNeighbors" function is a bit tricky: - // If we previously added to the heap M elements that should + // If we previously added to the heap M elements that should // be filtered out, we ignore M elements in the numNeighbors // check! In this way, we can make sure to allow in up to - // expectedNeighbors "unfiltered" elements. + // expectedNeighbors "unfiltered" elements. 
if r.numNeighbors() < expectedNeighbors || ph.simType.isBetterScore(currDist, r.lastNeighborScore()) { if candidateHeap.Len() > expectedNeighbors { candidateHeap.PopLast() @@ -246,16 +246,183 @@ func (ph *persistentHNSW[T]) searchPersistentLayer( return r, nil } -// Search searches the hnsw graph for the nearest neighbors of the query vector -// and returns the traversal path and the nearest neighbors +// Search searches the HNSW graph for the nearest neighbours of the query vector +// and returns the traversal path and the nearest neighbours func (ph *persistentHNSW[T]) Search(ctx context.Context, c index.CacheType, query []T, maxResults int, filter index.SearchFilter[T]) (nnUids []uint64, err error) { r, err := ph.SearchWithPath(ctx, c, query, maxResults, filter) return r.Neighbors, err } -// SearchWithUid searches the hnsw graph for the nearest neighbors of the query uid -// and returns the traversal path and the nearest neighbors +// SearchWithOptions applies optional per-call controls (ef override and distance threshold). +// The breadth ef (EfOverride when > 0, otherwise the index's efSearch) is applied at upper layers +// and the bottom layer uses candidateK = max(maxResults, ef). At most maxResults results are returned. +// When DistanceThreshold is set, results exceeding the threshold (in the metric domain) +// are filtered out before limiting to maxResults. 
+func (ph *persistentHNSW[T]) SearchWithOptions( + ctx context.Context, + c index.CacheType, + query []T, + maxResults int, + opts index.VectorIndexOptions[T], +) ([]uint64, error) { + if opts.Filter == nil { + opts.Filter = index.AcceptAll[T] + } + if maxResults < 0 { + maxResults = 0 + } + r := index.NewSearchPathResult() + start := time.Now().UnixMilli() + + // 0-profile_vector_entry + var startVec []T + entry, err := ph.PickStartNode(ctx, c, &startVec) + if err != nil { + return nil, err + } + + // Upper layers use efUpper (override if provided) + efUpper := ph.efSearch + if opts.EfOverride > 0 { + efUpper = opts.EfOverride + } + + for level := range ph.maxLevels - 1 { + if isEqual(startVec, query) { + break + } + filterOut := !opts.Filter(query, startVec, entry) + layerResult, err := ph.searchPersistentLayer( + c, level, entry, startVec, query, filterOut, efUpper, opts.Filter) + if err != nil { + return nil, err + } + layerResult.updateFinalMetrics(r) + entry = layerResult.bestNeighbor().index + layerResult.updateFinalPath(r) + if err = ph.getVecFromUid(entry, c, &startVec); err != nil { + return nil, err + } + } + + // Bottom layer: candidate size = max(k, efUpper) + filterOut := !opts.Filter(query, startVec, entry) + candidateK := maxResults + if efUpper > candidateK { + candidateK = efUpper + } + layerResult, err := ph.searchPersistentLayer( + c, ph.maxLevels-1, entry, startVec, query, filterOut, candidateK, opts.Filter) + if err != nil { + return nil, err + } + layerResult.updateFinalMetrics(r) + layerResult.updateFinalPath(r) + + // Build final neighbour list with optional threshold, limited to maxResults. 
+ res := make([]uint64, 0, maxResults) + for _, n := range layerResult.neighbors { + if maxResults == 0 { + break + } + if n.filteredOut { + continue + } + if opts.DistanceThreshold != nil { + th := *opts.DistanceThreshold + switch ph.simType.indexType { + case Euclidean: + thSq := th * th + if float64(n.value) > thSq { + continue + } + case Cosine: + // n.value is cosine similarity in [-1,1]; cosine distance d=1-sim must be <= th + if float64(1.0)-float64(n.value) > th { + continue + } + default: + // Dot product or others: ignore threshold for now + } + } + res = append(res, n.index) + if len(res) >= maxResults { + break + } + } + + r.Metrics[searchTime] = uint64(time.Now().UnixMilli() - start) + return res, nil +} + +// SearchWithUidAndOptions is analogous to SearchWithUid but applies per‑call options. +func (ph *persistentHNSW[T]) SearchWithUidAndOptions( + _ context.Context, + c index.CacheType, + queryUid uint64, + maxResults int, + opts index.VectorIndexOptions[T], +) ([]uint64, error) { + if opts.Filter == nil { + opts.Filter = index.AcceptAll[T] + } + if maxResults < 0 { + maxResults = 0 + } + var queryVec []T + if err := ph.getVecFromUid(queryUid, c, &queryVec); err != nil { + if errors.Is(err, errFetchingPostingList) { + return []uint64{}, nil + } + return []uint64{}, err + } + if len(queryVec) == 0 { + return []uint64{}, nil + } + filterOut := !opts.Filter(queryVec, queryVec, queryUid) + candidateK := maxResults + if opts.EfOverride > candidateK { + candidateK = opts.EfOverride + } + lr, err := ph.searchPersistentLayer( + c, ph.maxLevels-1, queryUid, queryVec, queryVec, filterOut, candidateK, opts.Filter) + if err != nil { + return []uint64{}, err + } + res := make([]uint64, 0, maxResults) + for _, n := range lr.neighbors { + if maxResults == 0 { + break + } + if n.filteredOut { + continue + } + if opts.DistanceThreshold != nil { + th := *opts.DistanceThreshold + switch ph.simType.indexType { + case Euclidean: + thSq := th * th + if float64(n.value) > 
thSq { + continue + } + case Cosine: + if float64(1.0)-float64(n.value) > th { + continue + } + default: + } + } + res = append(res, n.index) + if len(res) >= maxResults { + break + } + } + return res, nil +} + +// SearchWithUid searches the HNSW graph for the nearest neighbours of the query UID +// and returns the traversal path and the nearest neighbours func (ph *persistentHNSW[T]) SearchWithUid(_ context.Context, c index.CacheType, queryUid uint64, maxResults int, filter index.SearchFilter[T]) (nnUids []uint64, err error) { var queryVec []T @@ -275,10 +442,10 @@ func (ph *persistentHNSW[T]) SearchWithUid(_ context.Context, c index.CacheType, shouldFilterOutQueryVec := !filter(queryVec, queryVec, queryUid) - // how normal search works is by cotinuously searching higher layers - // for the best entry node to the last layer since we already know the - // best entry node (since it already exists in the lowest level), we - // can just search the last layer and return the results. + // How normal search works is by continuously searching higher layers + // for the best entry node to the last layer. Since we already know the + // best entry node (it already exists in the lowest level), we + // can just search the last layer and return the results. 
r, err := ph.searchPersistentLayer( c, ph.maxLevels-1, queryUid, queryVec, queryVec, shouldFilterOutQueryVec, maxResults, filter) @@ -389,7 +556,7 @@ func (ph *persistentHNSW[T]) SearchWithPath( return r, nil } -// InsertToPersistentStorage inserts a node into the hnsw graph and returns the +// Insert inserts a node into the HNSW graph and returns the // traversal path and the edges created func (ph *persistentHNSW[T]) Insert(ctx context.Context, c index.CacheType, inUuid uint64, inVec []T) ([]*index.KeyValue, error) { @@ -401,7 +568,7 @@ func (ph *persistentHNSW[T]) Insert(ctx context.Context, c index.CacheType, return edges, err } -// InsertToPersistentStorage inserts a node into the hnsw graph and returns the +// insertHelper inserts a node into the HNSW graph and returns the // traversal path and the edges created func (ph *persistentHNSW[T]) insertHelper(ctx context.Context, tc *TxnCache, inUuid uint64, inVec []T) ([]minPersistentHeapElement[T], []*index.KeyValue, error) { @@ -494,4 +661,4 @@ func (ph *persistentHNSW[T]) insertHelper(ctx context.Context, tc *TxnCache, edges = append(edges, edge) return visited, edges, nil -} +} \ No newline at end of file diff --git a/tok/index/index.go b/tok/index/index.go index e0a62255ce1..edc203fb89c 100644 --- a/tok/index/index.go +++ b/tok/index/index.go @@ -118,6 +118,35 @@ type VectorIndex[T c.Float] interface { Insert(ctx context.Context, c CacheType, uuid uint64, vec []T) ([]*KeyValue, error) } +// VectorIndexOptions carries optional, per-call search tuning parameters. +// Zero values mean "no override". +type VectorIndexOptions[T c.Float] struct { + // EfOverride, when > 0, overrides the search breadth (ef) for this call. + // Implementations should apply this to upper layers and use max(k, ef) for + // the bottom layer candidate size, then return the best k. 
+ EfOverride int + + // DistanceThreshold, when non-nil, filters out neighbors whose metric-domain + // distance exceeds the given threshold. Semantics depend on the index metric: + // - Euclidean: direct Euclidean distance (not squared) + // - Cosine: cosine distance in [0,2] (1 - cosine_similarity) + // - Dot product: undefined; implementations may ignore + DistanceThreshold *float64 + + // Filter allows callers to pass a SearchFilter; if nil, AcceptAll should be used. + Filter SearchFilter[T] +} + +// OptionalSearchOptions adds per-call search controls without breaking existing APIs. +// Implementations that support these may choose to ignore unsupported fields. +type OptionalSearchOptions[T c.Float] interface { + SearchWithOptions(ctx context.Context, c CacheType, query []T, + maxResults int, opts VectorIndexOptions[T]) ([]uint64, error) + + SearchWithUidAndOptions(ctx context.Context, c CacheType, queryUid uint64, + maxResults int, opts VectorIndexOptions[T]) ([]uint64, error) +} + // A Txn is an interface representation of a persistent storage transaction, // where multiple operations are performed on a database type Txn interface { diff --git a/tok/index/search_path.go b/tok/index/search_path.go index 1c247e926f5..8386abc3413 100644 --- a/tok/index/search_path.go +++ b/tok/index/search_path.go @@ -5,22 +5,22 @@ package index -// SearchPathResult is the return-type for the optional +// SearchPathResult is the return type for the optional // SearchWithPath function for a VectorIndex // (by way of extending OptionalIndexSupport). type SearchPathResult struct { - // The collection of nearest-neighbors in sorted order after filtlering - // out neighbors that fail any Filter criteria. + // The collection of nearest neighbours in sorted order after filtering + // out neighbours that fail any Filter criteria. Neighbors []uint64 - // The path from the start of search to the closest neighbor vector. 
+ // The path from the start of search to the closest neighbour vector. Path []uint64 // A collection of captured named counters that occurred for the // particular search. Metrics map[string]uint64 } -// NewSearchPathResult() provides an initialized (empty) *SearchPathResult. -// The attributes will be non-nil, but empty. +// NewSearchPathResult provides an initialised (empty) *SearchPathResult. +// The attributes will be non‑nil but empty. func NewSearchPathResult() *SearchPathResult { return &SearchPathResult{ Neighbors: []uint64{}, diff --git a/worker/task.go b/worker/task.go index ba7e859572f..2a23bb1d05f 100644 --- a/worker/task.go +++ b/worker/task.go @@ -368,12 +368,29 @@ func (qs *queryState) handleValuePostings(ctx context.Context, args funcArgs) er return err } var nnUids []uint64 - if srcFn.vectorInfo != nil { - nnUids, err = indexer.Search(ctx, qc, srcFn.vectorInfo, - int(numNeighbors), index.AcceptAll[float32]) + // Build optional search options if provided + filter := index.AcceptAll[float32] + opts := index.VectorIndexOptions[float32]{Filter: filter} + if srcFn.vsEfOverride > 0 { + opts.EfOverride = srcFn.vsEfOverride + } + if srcFn.vsDistanceThreshold != nil { + opts.DistanceThreshold = srcFn.vsDistanceThreshold + } + if o, ok := indexer.(index.OptionalSearchOptions[float32]); ok && (opts.EfOverride > 0 || opts.DistanceThreshold != nil) { + if srcFn.vectorInfo != nil { + nnUids, err = o.SearchWithOptions(ctx, qc, srcFn.vectorInfo, int(numNeighbors), opts) + } else { + nnUids, err = o.SearchWithUidAndOptions(ctx, qc, srcFn.vectorUid, int(numNeighbors), opts) + } } else { - nnUids, err = indexer.SearchWithUid(ctx, qc, srcFn.vectorUid, - int(numNeighbors), index.AcceptAll[float32]) + if srcFn.vectorInfo != nil { + nnUids, err = indexer.Search(ctx, qc, srcFn.vectorInfo, + int(numNeighbors), index.AcceptAll[float32]) + } else { + nnUids, err = indexer.SearchWithUid(ctx, qc, srcFn.vectorUid, + int(numNeighbors), index.AcceptAll[float32]) + } } if err != 
nil && !strings.Contains(err.Error(), hnsw.EmptyHNSWTreeError+": "+badger.ErrKeyNotFound.Error()) { @@ -1792,6 +1809,9 @@ type functionContext struct { atype types.TypeID vectorInfo []float32 vectorUid uint64 + // Optional vector search options parsed from the 3rd entry of SrcFunc.Args (the 4th argument of similar_to) + vsEfOverride int + vsDistanceThreshold *float64 } const ( @@ -2119,13 +2139,49 @@ func parseSrcFn(ctx context.Context, q *pb.Query) (*functionContext, error) { } checkRoot(q, fc) case similarToFn: - if err = ensureArgsCount(q.SrcFunc, 2); err != nil { - return nil, err + // Allow 2 or 3 args: k, vector_or_uid[, options] + if !(len(q.SrcFunc.Args) == 2 || len(q.SrcFunc.Args) == 3) { + return nil, errors.Errorf("Function '%s' requires 2 or 3 arguments, but got %d (%v)", q.SrcFunc.Name, len(q.SrcFunc.Args), q.SrcFunc.Args) } fc.vectorInfo, fc.vectorUid, err = interpretVFloatOrUid(q.SrcFunc.Args[1]) if err != nil { return nil, err } + if len(q.SrcFunc.Args) == 3 { + // Parse simple options: key=value pairs separated by comma or JSON-like {key:val,...} + raw := strings.TrimSpace(q.SrcFunc.Args[2]) + if len(raw) > 0 { + if strings.HasPrefix(raw, "{") && strings.HasSuffix(raw, "}") { + raw = strings.TrimSpace(raw[1 : len(raw)-1]) + } + parts := strings.Split(raw, ",") + for _, p := range parts { + kv := strings.SplitN(p, ":", 2) + if len(kv) != 2 { + kv = strings.SplitN(p, "=", 2) + if len(kv) != 2 { + continue + } + } + k := strings.ToLower(strings.TrimSpace(kv[0])) + v := strings.TrimSpace(kv[1]) + v = strings.Trim(v, "\"'") + switch k { + case "ef": + if n, perr := strconv.ParseInt(v, 10, 32); perr == nil && n > 0 { + fc.vsEfOverride = int(n) + } + case "distance_threshold": + if f, perr := strconv.ParseFloat(v, 64); perr == nil { + fc.vsDistanceThreshold = new(float64) + *fc.vsDistanceThreshold = f + } + default: + // ignore unknown keys silently + } + } + } + } case uidInFn: for _, arg := range q.SrcFunc.Args { uidParsed, err := strconv.ParseUint(arg, 0, 64) From 
16e8eadf69c92332756bef402a0b2ea337c34c19 Mon Sep 17 00:00:00 2001 From: joel amming Date: Sun, 19 Oct 2025 23:16:49 +0100 Subject: [PATCH 2/2] Update comment in persistent_hnsw.go --- tok/hnsw/persistent_hnsw.go | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/tok/hnsw/persistent_hnsw.go b/tok/hnsw/persistent_hnsw.go index 2faa5b26c58..4c73ae08637 100644 --- a/tok/hnsw/persistent_hnsw.go +++ b/tok/hnsw/persistent_hnsw.go @@ -179,14 +179,11 @@ func (ph *persistentHNSW[T]) searchPersistentLayer( currCandidate := candidateHeap.Pop().(minPersistentHeapElement[T]) if r.numNeighbors() >= expectedNeighbors && ph.simType.isBetterScore(r.lastNeighborScore(), currCandidate.value) { - // If the "worst score" in our neighbours list is deemed to have - // a better score than the current candidate -- and if we have at - // least our expected number of nearest results -- we discontinue - // the search. - // Note that while this is faithful to the published - // HNSW algorithms insofar as we stop when we reach a local - // minimum, it leaves something to be desired in terms of - // guarantees of getting best results. + // Standard HNSW termination: once the current best candidate + // cannot improve the ef-sized neighbour set (and we already have + // at least expectedNeighbors), we stop exploring this layer. + // Recall is governed by ef; callers may raise ef (per-query + // override supported) to explore further. break } @@ -661,4 +658,4 @@ func (ph *persistentHNSW[T]) insertHelper(ctx context.Context, tc *TxnCache, edges = append(edges, edge) return visited, edges, nil -} \ No newline at end of file +}