diff --git a/CHANGELOG.md b/CHANGELOG.md index fd19108bf2a..9fde3691d5b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -88,6 +88,11 @@ as a guide. - **Query** - fix(query): return full float value in query results (#9492) +- **Vector** + - fix(vector/hnsw): correct early termination in bottom-layer search to ensure at least k + candidates are considered before breaking + - feat(vector/hnsw): add optional per-query controls to similar_to via named parameters: `ef` + (search breadth override) and `distance_threshold` (metric-domain cutoff); defaults unchanged - **Changed** diff --git a/dql/parser.go b/dql/parser.go index 618d6c0ca4d..52a69388668 100644 --- a/dql/parser.go +++ b/dql/parser.go @@ -1745,6 +1745,10 @@ L: name := collectName(it, item.Val) function.Name = strings.ToLower(name) + var similarToOptSeen map[string]struct{} + if function.Name == similarToFn { + similarToOptSeen = make(map[string]struct{}) + } if _, ok := tryParseItemType(it, itemLeftRound); !ok { return nil, it.Errorf("Expected ( after func name [%s]", function.Name) } @@ -1874,7 +1878,10 @@ L: case IsInequalityFn(function.Name): err = parseFuncArgs(it, function) - case function.Name == "uid_in" || function.Name == "similar_to": + case function.Name == "uid_in": + err = parseFuncArgs(it, function) + + case function.Name == "similar_to": err = parseFuncArgs(it, function) default: @@ -1892,7 +1899,87 @@ L: } expectArg = false continue + case itemLeftCurl: + return nil, itemInFunc.Errorf("Unrecognized character inside a func: U+007B '{'") + case itemRightCurl: + // Right curly braces are never valid in function arguments outside of + // the (unsupported) object literal syntax. Always error on stray '}'. 
+ return nil, itemInFunc.Errorf("Unrecognized character inside a func: U+007D '}'") default: + // similar_to supports named optional parameters after the 3rd positional argument: + // similar_to(pred, k, vec, ef: 64, distance_threshold: 0.5) + // + // Internally we represent each option as two args appended after k and vec: + // ["ef", "64", "distance_threshold", "0.5", ...] + if itemInFunc.Typ == itemName && function.Name == similarToFn && + function.Attr != "" && len(function.Args) >= 2 { + next, ok := it.PeekOne() + if ok && next.Typ == itemColon { + key := strings.ToLower(collectName(it, itemInFunc.Val)) + switch key { + case "ef", "distance_threshold": + default: + return nil, itemInFunc.Errorf("Unknown option %q in similar_to", key) + } + if _, exists := similarToOptSeen[key]; exists { + return nil, itemInFunc.Errorf("Duplicate key %q in similar_to options", key) + } + similarToOptSeen[key] = struct{}{} + + if ok := trySkipItemTyp(it, itemColon); !ok { + return nil, it.Errorf("Expected colon(:) after %s", key) + } + if !it.Next() { + return nil, it.Errorf("Expected value for %s", key) + } + valItem := it.Item() + switch valItem.Typ { + case itemDollar: + varName, err := parseVarName(it) + if err != nil { + return nil, err + } + function.Args = append(function.Args, Arg{Value: key}) + function.Args = append(function.Args, Arg{Value: varName, IsDQLVar: true}) + case itemMathOp: + // Allow signed numeric literals, e.g. 
distance_threshold: -0.5 + prefix := valItem.Val + if !it.Next() { + return nil, it.Errorf("Expected value after %s for %s", prefix, key) + } + valItem = it.Item() + if valItem.Typ != itemName { + return nil, valItem.Errorf("Expected value for %s", key) + } + v := collectName(it, valItem.Val) + v = strings.Trim(v, " \t") + uq, err := unquoteIfQuoted(v) + if err != nil { + return nil, err + } + function.Args = append(function.Args, Arg{Value: key}) + function.Args = append(function.Args, Arg{Value: prefix + uq}) + default: + if valItem.Typ != itemName { + return nil, valItem.Errorf("Expected value for %s", key) + } + v := collectName(it, valItem.Val) + v = strings.Trim(v, " \t") + uq, err := unquoteIfQuoted(v) + if err != nil { + return nil, err + } + function.Args = append(function.Args, Arg{Value: key}) + function.Args = append(function.Args, Arg{Value: uq}) + } + + expectArg = false + continue + } + + // Disallow extra positional args after (k, vec). Options must be named. + return nil, itemInFunc.Errorf("Expected named parameter in similar_to options (e.g. ef: 64)") + } if itemInFunc.Typ != itemName { return nil, itemInFunc.Errorf("Expected arg after func [%s], but got item %v", function.Name, itemInFunc) @@ -2408,6 +2495,10 @@ loop: // The parentheses are balanced out. Let's break. break loop } + case item.Typ == itemLeftCurl: + return nil, item.Errorf("Unrecognized character inside a func: U+007B '{'") + case item.Typ == itemRightCurl: + return nil, item.Errorf("Unrecognized character inside a func: U+007D '}'") default: return nil, item.Errorf("Unexpected item while parsing @filter: %v", item) } diff --git a/dql/parser_test.go b/dql/parser_test.go index 6d49c7f909c..9c6f6bddced 100644 --- a/dql/parser_test.go +++ b/dql/parser_test.go @@ -2518,6 +2518,12 @@ func TestParseFilter_brac(t *testing.T) { } // Test if unbalanced brac will lead to errors. +// Note: This query has two errors: missing ')' after '()' AND a stray '{'. 
+// As part of the similar_to named-options change the lexer now emits brace tokens +// instead of erroring immediately. This causes the query to fail on the structural +// error (unclosed brackets) rather than the character-specific error. This is an +// acceptable trade-off because queries with multiple syntax errors may report a different +// (but equally fatal) error first. func TestParseFilter_unbalancedbrac(t *testing.T) { query := ` query { @@ -2532,8 +2538,119 @@ func TestParseFilter_unbalancedbrac(t *testing.T) { ` _, err := Parse(Request{Str: query}) require.Error(t, err) - require.Contains(t, err.Error(), - "Unrecognized character inside a func: U+007B '{'") + require.Contains(t, err.Error(), "Unclosed Brackets") +} + +func TestParseSimilarToNamedParams(t *testing.T) { + query := `{ + q(func: similar_to(voptions, 4, "[0,0]", distance_threshold: 1.5, ef: 12)) { + uid + } + }` + res, err := Parse(Request{Str: query}) + require.NoError(t, err) + require.Len(t, res.Query, 1) + require.NotNil(t, res.Query[0]) + require.NotNil(t, res.Query[0].Func) + require.Equal(t, "similar_to", res.Query[0].Func.Name) + require.Equal(t, "voptions", res.Query[0].Func.Attr) + require.Equal(t, "4", res.Query[0].Func.Args[0].Value) + require.Equal(t, "[0,0]", res.Query[0].Func.Args[1].Value) + + // Options are appended as (key, value) pairs after k and vec.
+ require.Len(t, res.Query[0].Func.Args, 6) + require.Equal(t, "distance_threshold", res.Query[0].Func.Args[2].Value) + require.Equal(t, "1.5", res.Query[0].Func.Args[3].Value) + require.Equal(t, "ef", res.Query[0].Func.Args[4].Value) + require.Equal(t, "12", res.Query[0].Func.Args[5].Value) +} + +func TestParseSimilarToThreeArgs(t *testing.T) { + // Test three-arg form (no options) + query := `{ + q(func: similar_to(voptions, 4, "[0,0]")) { + uid + } + }` + res, err := Parse(Request{Str: query}) + require.NoError(t, err) + require.Equal(t, "similar_to", res.Query[0].Func.Name) + require.Len(t, res.Query[0].Func.Args, 2) +} + +func TestParseSimilarToRejectsObjectLiteralSyntax(t *testing.T) { + query := `{ + q(func: similar_to(voptions, 4, "[0,0]", {ef: 12})) { + uid + } + }` + _, err := Parse(Request{Str: query}) + require.Error(t, err) + require.Contains(t, err.Error(), "Unrecognized character inside a func: U+007B '{'") +} + +func TestParseSimilarToWithQueryVariable(t *testing.T) { + query := `query test($eff: int) { + q(func: similar_to(voptions, 4, "[0,0]", ef: $eff)) { + uid + } + }` + res, err := Parse(Request{ + Str: query, + Variables: map[string]string{"$eff": "64"}, + }) + require.NoError(t, err) + require.Equal(t, "similar_to", res.Query[0].Func.Name) + require.Len(t, res.Query[0].Func.Args, 4) + require.Equal(t, "ef", res.Query[0].Func.Args[2].Value) + require.Equal(t, "64", res.Query[0].Func.Args[3].Value) +} + +func TestParseSimilarToRejectsLegacyStringOptionsSyntax(t *testing.T) { + query := `{ + q(func: similar_to(voptions, 4, "[0,0]", "ef=64,distance_threshold=0.45")) { + uid + } + }` + _, err := Parse(Request{Str: query}) + require.Error(t, err) + require.Contains(t, err.Error(), "Expected named parameter in similar_to options") +} + +func TestParseSimilarToUnknownOption(t *testing.T) { + query := `{ + q(func: similar_to(voptions, 4, "[0,0]", foo: 5)) { + uid + } + }` + _, err := Parse(Request{Str: query}) + require.Error(t, err) + 
+ require.Contains(t, err.Error(), "Unknown option") + require.Contains(t, err.Error(), "foo") +} + +func TestParseSimilarToDuplicateOption(t *testing.T) { + query := `{ + q(func: similar_to(voptions, 4, "[0,0]", ef: 10, ef: 20)) { + uid + } + }` + _, err := Parse(Request{Str: query}) + require.Error(t, err) + require.Contains(t, err.Error(), "Duplicate key") + require.Contains(t, err.Error(), "ef") +} + +func TestParseNonSimilarToWithBrace(t *testing.T) { + // Braces in non-similar_to functions should be rejected + query := `{ + q(func: eq(name, {value: "test"})) { + uid + } + }` + _, err := Parse(Request{Str: query}) + require.Error(t, err) + require.Contains(t, err.Error(), "Unrecognized character inside a func: U+007B '{'") } func TestParseFilter_Geo1(t *testing.T) { @@ -2768,6 +2885,10 @@ func TestParseCountAsFunc(t *testing.T) { } +// Note: This query has two errors: missing ')' after 'friends' AND a stray '}'. +// As part of the similar_to named-options change the lexer emits brace tokens instead +// of erroring immediately -- causing this to fail on unclosed brackets rather than the +// specific character error. See TestParseFilter_unbalancedbrac for full explanation. func TestParseCountError1(t *testing.T) { query := `{ me(func: uid(1)) { @@ -2779,10 +2900,11 @@ func TestParseCountError1(t *testing.T) { ` _, err := Parse(Request{Str: query}) require.Error(t, err) - require.Contains(t, err.Error(), - "Unrecognized character inside a func: U+007D '}'") + require.Contains(t, err.Error(), "Unclosed Brackets") } +// Note: Similar to TestParseCountError1, this has missing ')' and stray '}', +// now reports structural error instead of character-specific error.
func TestParseCountError2(t *testing.T) { query := `{ me(func: uid(1)) { @@ -2794,8 +2916,7 @@ func TestParseCountError2(t *testing.T) { ` _, err := Parse(Request{Str: query}) require.Error(t, err) - require.Contains(t, err.Error(), - "Unrecognized character inside a func: U+007D '}'") + require.Contains(t, err.Error(), "Unclosed Brackets") } func TestParseCheckPwd(t *testing.T) { diff --git a/dql/state.go b/dql/state.go index 507bf9f64f3..48260990914 100644 --- a/dql/state.go +++ b/dql/state.go @@ -306,6 +306,18 @@ func lexFuncOrArg(l *lex.Lexer) lex.StateFn { l.Emit(itemLeftSquare) case r == rightSquare: l.Emit(itemRightSquare) + case r == leftCurl: + empty = false + l.Emit(itemLeftCurl) + // Design decision: Emit brace tokens without affecting ArgDepth tracking. + // The parser validates whether braces are legal in context. + // Trade-off: Queries with multiple syntax errors (e.g., missing ')' AND stray '}') + // will report structural errors (Unclosed Brackets) rather than character-specific + // errors. This is acceptable as the query is still rejected with a clear error. + case r == rightCurl: + l.Emit(itemRightCurl) + // Don't decrement ArgDepth for braces; let parser validate context. + // See leftCurl case above for full rationale. case r == '#': return lexComment case r == '.': diff --git a/query/vector/vector_test.go b/query/vector/vector_test.go index 14d02c16d82..e5b687ac758 100644 --- a/query/vector/vector_test.go +++ b/query/vector/vector_test.go @@ -417,6 +417,74 @@ func TestVectorIndexRebuildWhenChange(t *testing.T) { require.Greater(t, dur, time.Second*4) } +func TestSimilarToOptionsIntegration(t *testing.T) { + const pred = "voptions" + dropPredicate(pred) + t.Cleanup(func() { dropPredicate(pred) }) + + setSchema(fmt.Sprintf(vectorSchemaWithIndex, pred, "4", "euclidean")) + + rdf := `<0x1> "[0,0]" . + <0x2> "[1,0]" . + <0x3> "[2,0]" . 
+ <0x4> "[5,0]" .` + require.NoError(t, addTriplesToCluster(rdf)) + + t.Run("ef_override_named_param", func(t *testing.T) { + query := `{ + results(func: similar_to(voptions, 3, "[0,0]", ef: 2)) { + uid + } + }` + resp := processQueryNoErr(t, query) + + var result struct { + Data struct { + Results []struct { + UID string `json:"uid"` + } `json:"results"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal([]byte(resp), &result)) + require.Len(t, result.Data.Results, 3) + + expected := map[string]struct{}{"0x1": {}, "0x2": {}, "0x3": {}} + for _, r := range result.Data.Results { + _, ok := expected[r.UID] + require.Truef(t, ok, "unexpected uid %s", r.UID) + delete(expected, r.UID) + } + require.Empty(t, expected) + }) + + t.Run("distance_threshold_named_param", func(t *testing.T) { + query := `{ + results(func: similar_to(voptions, 4, "[0,0]", distance_threshold: 1.5)) { + uid + } + }` + resp := processQueryNoErr(t, query) + + var result struct { + Data struct { + Results []struct { + UID string `json:"uid"` + } `json:"results"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal([]byte(resp), &result)) + require.Len(t, result.Data.Results, 2) + + expected := map[string]struct{}{"0x1": {}, "0x2": {}} + for _, r := range result.Data.Results { + _, ok := expected[r.UID] + require.Truef(t, ok, "unexpected uid %s", r.UID) + delete(expected, r.UID) + } + require.Empty(t, expected) + }) +} + func TestVectorInQueryArgument(t *testing.T) { dropPredicate("vtest") setSchema(fmt.Sprintf(vectorSchemaWithIndex, "vtest", "4", "euclidean")) diff --git a/tok/hnsw/ef_recall_test.go b/tok/hnsw/ef_recall_test.go new file mode 100644 index 00000000000..5ae22282222 --- /dev/null +++ b/tok/hnsw/ef_recall_test.go @@ -0,0 +1,211 @@ +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +package hnsw + +import ( + "context" + "encoding/binary" + "math" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/dgraph-io/dgraph/v25/tok/index" + opt "github.com/dgraph-io/dgraph/v25/tok/options" + "github.com/dgraph-io/dgraph/v25/x" +) + +// memoryCache satisfies index.CacheType for synthetic tests. +type memoryCache struct { + data map[string][]byte +} + +func (m *memoryCache) Get(key []byte) ([]byte, error) { + if val, ok := m.data[string(key)]; ok { + return val, nil + } + return nil, nil +} + +func (m *memoryCache) Ts() uint64 { return 0 } + +func (m *memoryCache) Find([]byte, func([]byte) bool) (uint64, error) { return 0, nil } + +func float64ArrayAsBytes(v []float64) []byte { + buf := make([]byte, 8*len(v)) + for i, f := range v { + binary.LittleEndian.PutUint64(buf[i*8:], math.Float64bits(f)) + } + return buf +} + +// Test that EfOverride widens the bottom-layer candidate set and improves recall on a tiny graph. +func TestHNSWSearchEfOverrideImprovesRecall(t *testing.T) { + ctx := context.Background() + + factory := CreateFactory[float64](64) + options := opt.NewOptions() + options.SetOpt(MaxLevelsOpt, 2) + options.SetOpt(EfSearchOpt, 1) + options.SetOpt(MetricOpt, GetSimType[float64](Euclidean, 64)) + + predName := "ef_override_test_pred" + predWithNamespace := x.NamespaceAttr(x.RootNamespace, predName) + + rawIdx, err := factory.Create(predWithNamespace, options, 64) + require.NoError(t, err) + + // Use concrete type directly (same package) to set up a tiny synthetic graph. + ph, ok := rawIdx.(*persistentHNSW[float64]) + require.True(t, ok) + require.Equal(t, predWithNamespace, ph.pred) + + // Populate vectors in memory via cache data map keyed by DataKey. 
+ vectors := map[uint64][]float64{ + 1: {0, 0, 10, 0}, // entry + 100: {0, 0, 0.1, 0}, // true nearest to query + 200: {0, 0, 3, 0}, // local minimum path + 201: {0, 0, 3.2, 0}, + } + + data := make(map[string][]byte) + for uid, vec := range vectors { + key := string(DataKey(ph.pred, uid)) + data[key] = float64ArrayAsBytes(vec) + } + + // Set entry pointer to uid 1. + entryKey := string(DataKey(ph.vecEntryKey, 1)) + data[entryKey] = Uint64ToBytes(1) + + // Wire a small graph that requires wider search to find uid 100 from entry 1. + ph.nodeAllEdges[1] = [][]uint64{{}, {200, 201}} + ph.nodeAllEdges[200] = [][]uint64{{1}, {1}} + ph.nodeAllEdges[201] = [][]uint64{{1}, {100}} + ph.nodeAllEdges[100] = [][]uint64{{201}, {201}} + + cache := &memoryCache{data: data} + query := []float64{0, 0, 0.12, 0} + + // Narrow ef behaves like legacy path: returns uid 200 for k=1. + narrowOpts := index.VectorIndexOptions[float64]{} + narrow, err := ph.SearchWithOptions(ctx, cache, query, 1, narrowOpts) + require.NoError(t, err) + require.Equal(t, []uint64{200}, narrow) + + // Wider ef surfaces the closer neighbor uid 100. + wideOpts := index.VectorIndexOptions[float64]{EfOverride: 4} + wide, err := ph.SearchWithOptions(ctx, cache, query, 1, wideOpts) + require.NoError(t, err) + require.Equal(t, []uint64{100}, wide) +} + +// Test Euclidean distance_threshold filters out results with distance (not squared) above threshold. +func TestHNSWDistanceThreshold_Euclidean(t *testing.T) { + ctx := context.Background() + + factory := CreateFactory[float64](64) + options := opt.NewOptions() + options.SetOpt(MaxLevelsOpt, 1) + options.SetOpt(EfSearchOpt, 10) + options.SetOpt(MetricOpt, GetSimType[float64](Euclidean, 64)) + + pred := x.NamespaceAttr(x.RootNamespace, "thresh_pred_e") + rawIdx, err := factory.Create(pred, options, 64) + require.NoError(t, err) + ph := rawIdx.(*persistentHNSW[float64]) + + // Two vectors at known Euclidean distances from query.
+ // query q = (0,0), a=(0.6,0), b=(0.8,0) + // dist(q,a)=0.6, dist(q,b)=0.8 + data := map[string][]byte{ + string(DataKey(pred, 1)): float64ArrayAsBytes([]float64{0.6, 0}), + string(DataKey(pred, 2)): float64ArrayAsBytes([]float64{0.8, 0}), + string(DataKey(ph.vecEntryKey, 1)): Uint64ToBytes(1), + } + // Single-layer edges; ensure both are reachable from entry. + ph.nodeAllEdges[1] = [][]uint64{{1, 2}} + ph.nodeAllEdges[2] = [][]uint64{{1}} + + cache := &memoryCache{data: data} + q := []float64{0, 0} + + // With current internal Euclidean values, use thresholds directly in the metric domain. + // threshold 0.75: include uid 1 (0.6) and exclude uid 2 (0.8). + th := 0.75 + res, err := ph.SearchWithOptions(ctx, cache, q, 10, index.VectorIndexOptions[float64]{ + DistanceThreshold: &th, + EfOverride: 10, + }) + require.NoError(t, err) + require.Equal(t, []uint64{1}, res) + + // threshold 1.5 (>1) should include both neighbors and demonstrate we keep working in raw distance. + thHigh := 1.5 + resHigh, err := ph.SearchWithOptions(ctx, cache, q, 10, index.VectorIndexOptions[float64]{ + DistanceThreshold: &thHigh, + EfOverride: 10, + }) + require.NoError(t, err) + require.ElementsMatch(t, []uint64{1, 2}, resHigh) + + // threshold 0.6 includes uid 1 (distance 0.6) but excludes the rest (inclusive comparison). + thExact := 0.6 + resExact, err := ph.SearchWithOptions(ctx, cache, q, 10, index.VectorIndexOptions[float64]{ + DistanceThreshold: &thExact, + EfOverride: 10, + }) + require.NoError(t, err) + require.Equal(t, []uint64{1}, resExact) + + // threshold 0.5 should filter everything. + thLow := 0.5 + resLow, err := ph.SearchWithOptions(ctx, cache, q, 10, index.VectorIndexOptions[float64]{ + DistanceThreshold: &thLow, + EfOverride: 10, + }) + require.NoError(t, err) + require.Empty(t, resLow) +} + +// Test Cosine distance_threshold uses distance d = 1 - cosine_similarity. 
+func TestHNSWDistanceThreshold_Cosine(t *testing.T) { + ctx := context.Background() + + factory := CreateFactory[float64](64) + options := opt.NewOptions() + options.SetOpt(MaxLevelsOpt, 1) + options.SetOpt(EfSearchOpt, 10) + options.SetOpt(MetricOpt, GetSimType[float64](Cosine, 64)) + + pred := x.NamespaceAttr(x.RootNamespace, "thresh_pred_c") + rawIdx, err := factory.Create(pred, options, 64) + require.NoError(t, err) + ph := rawIdx.(*persistentHNSW[float64]) + + // Query q is unit along x-axis. + // a is exact match (cos sim 1.0, distance 0.0) + // b is 36.87 degrees (~cos 0.8, distance 0.2) + data := map[string][]byte{ + string(DataKey(pred, 1)): float64ArrayAsBytes([]float64{1, 0}), + string(DataKey(pred, 2)): float64ArrayAsBytes([]float64{0.8, 0.6}), + string(DataKey(ph.vecEntryKey, 1)): Uint64ToBytes(1), + } + ph.nodeAllEdges[1] = [][]uint64{{1, 2}} + ph.nodeAllEdges[2] = [][]uint64{{1}} + + cache := &memoryCache{data: data} + q := []float64{1, 0} + + // distance_threshold=0.1 should include uid 1 but exclude uid 2 (0.2 > 0.1) + th := 0.1 + res, err := ph.SearchWithOptions(ctx, cache, q, 10, index.VectorIndexOptions[float64]{ + DistanceThreshold: &th, + EfOverride: 10, + }) + require.NoError(t, err) + require.Equal(t, []uint64{1}, res) +} diff --git a/tok/hnsw/persistent_hnsw.go b/tok/hnsw/persistent_hnsw.go index ec4c6a892a8..e516f0aef2e 100644 --- a/tok/hnsw/persistent_hnsw.go +++ b/tok/hnsw/persistent_hnsw.go @@ -29,7 +29,7 @@ type persistentHNSW[T c.Float] struct { simType SimilarityType[T] floatBits int // nodeAllEdges[65443][1][3] indicates the 3rd neighbor in the first - // layer for uuid 65443. The result will be a neighboring uuid. + // layer for UUID 65443. The result will be a neighboring UUID. 
nodeAllEdges map[uint64][][]uint64 deadNodes map[uint64]struct{} } @@ -145,7 +145,7 @@ func (ph *persistentHNSW[T]) fillNeighborEdges(uuid uint64, c index.CacheType, e return true, nil } -// searchPersistentLayer searches a layer of the hnsw graph for the nearest +// searchPersistentLayer searches a layer of the HNSW graph for the nearest // neighbors of the query vector and returns the traversal path and the nearest // neighbors func (ph *persistentHNSW[T]) searchPersistentLayer( @@ -177,16 +177,13 @@ func (ph *persistentHNSW[T]) searchPersistentLayer( //create set using map to append to on future visited nodes for candidateHeap.Len() != 0 { currCandidate := candidateHeap.Pop().(minPersistentHeapElement[T]) - if r.numNeighbors() < expectedNeighbors && + if r.numNeighbors() >= expectedNeighbors && ph.simType.isBetterScore(r.lastNeighborScore(), currCandidate.value) { - // If the "worst score" in our neighbors list is deemed to have - // a better score than the current candidate -- and if we have at - // least our expected number of nearest results -- we discontinue - // the search. - // Note that while this is faithful to the published - // HNSW algorithms insofar as we stop when we reach a local - // minimum, it leaves something to be desired in terms of - // guarantees of getting best results. + // Standard HNSW termination: once the current best candidate + // cannot improve the ef-sized neighbor set (and we already have + // at least expectedNeighbors), we stop exploring this layer. + // Recall is governed by ef; callers may raise ef (per-query + // override supported) to explore further.
break } @@ -246,7 +243,7 @@ func (ph *persistentHNSW[T]) searchPersistentLayer( return r, nil } -// Search searches the hnsw graph for the nearest neighbors of the query vector +// Search searches the HNSW graph for the nearest neighbors of the query vector // and returns the traversal path and the nearest neighbors func (ph *persistentHNSW[T]) Search(ctx context.Context, c index.CacheType, query []T, maxResults int, filter index.SearchFilter[T]) (nnUids []uint64, err error) { @@ -254,7 +251,173 @@ func (ph *persistentHNSW[T]) Search(ctx context.Context, c index.CacheType, quer return r.Neighbors, err } -// SearchWithUid searches the hnsw graph for the nearest neighbors of the query uid +// SearchWithOptions applies optional per-call controls (ef override and distance threshold). +// When EfOverride > 0, it is applied at upper layers and the bottom layer uses +// candidateK = max(maxResults, EfOverride). Results return the best maxResults. +// When DistanceThreshold is set, results exceeding the threshold (in the metric domain) +// are filtered out before limiting to maxResults. 
+func (ph *persistentHNSW[T]) SearchWithOptions( + ctx context.Context, + c index.CacheType, + query []T, + maxResults int, + opts index.VectorIndexOptions[T], +) ([]uint64, error) { + if opts.Filter == nil { + opts.Filter = index.AcceptAll[T] + } + if maxResults < 0 { + maxResults = 0 + } + r := index.NewSearchPathResult() + start := time.Now().UnixMilli() + + // 0-profile_vector_entry + var startVec []T + entry, err := ph.PickStartNode(ctx, c, &startVec) + if err != nil { + return nil, err + } + + // Upper layers use efUpper (override if provided) + efUpper := ph.efSearch + if opts.EfOverride > 0 { + efUpper = opts.EfOverride + } + + for level := range ph.maxLevels - 1 { + if isEqual(startVec, query) { + break + } + filterOut := !opts.Filter(query, startVec, entry) + layerResult, err := ph.searchPersistentLayer( + c, level, entry, startVec, query, filterOut, efUpper, opts.Filter) + if err != nil { + return nil, err + } + layerResult.updateFinalMetrics(r) + entry = layerResult.bestNeighbor().index + layerResult.updateFinalPath(r) + if err = ph.getVecFromUid(entry, c, &startVec); err != nil { + return nil, err + } + } + + // Bottom layer: candidate size = max(k, efUpper) + filterOut := !opts.Filter(query, startVec, entry) + candidateK := maxResults + if efUpper > candidateK { + candidateK = efUpper + } + layerResult, err := ph.searchPersistentLayer( + c, ph.maxLevels-1, entry, startVec, query, filterOut, candidateK, opts.Filter) + if err != nil { + return nil, err + } + layerResult.updateFinalMetrics(r) + layerResult.updateFinalPath(r) + + // Build final neighbor list with optional threshold, limited to maxResults. + res := make([]uint64, 0, maxResults) + for _, n := range layerResult.neighbors { + if maxResults == 0 { + break + } + if n.filteredOut { + continue + } + if opts.DistanceThreshold != nil { + th := *opts.DistanceThreshold + switch ph.simType.indexType { + case Euclidean: + // n.value stores the metric-domain distance (not squared). 
+ if float64(n.value) > th { + continue + } + case Cosine: + // n.value is cosine similarity in [-1,1]; cosine distance d = 1 - sim must be <= th. + if float64(1.0)-float64(n.value) > th { + continue + } + default: + // Dot product or others: ignore threshold for now. + } + } + res = append(res, n.index) + if len(res) >= maxResults { + break + } + } + + r.Metrics[searchTime] = uint64(time.Now().UnixMilli() - start) + return res, nil +} + +// SearchWithUidAndOptions is analogous to SearchWithUid but applies per-call options. +func (ph *persistentHNSW[T]) SearchWithUidAndOptions( + _ context.Context, + c index.CacheType, + queryUid uint64, + maxResults int, + opts index.VectorIndexOptions[T], +) ([]uint64, error) { + if opts.Filter == nil { + opts.Filter = index.AcceptAll[T] + } + if maxResults < 0 { + maxResults = 0 + } + var queryVec []T + if err := ph.getVecFromUid(queryUid, c, &queryVec); err != nil { + if errors.Is(err, errFetchingPostingList) { + return []uint64{}, nil + } + return []uint64{}, err + } + if len(queryVec) == 0 { + return []uint64{}, nil + } + filterOut := !opts.Filter(queryVec, queryVec, queryUid) + candidateK := maxResults + if opts.EfOverride > candidateK { + candidateK = opts.EfOverride + } + lr, err := ph.searchPersistentLayer( + c, ph.maxLevels-1, queryUid, queryVec, queryVec, filterOut, candidateK, opts.Filter) + if err != nil { + return []uint64{}, err + } + res := make([]uint64, 0, maxResults) + for _, n := range lr.neighbors { + if maxResults == 0 { + break + } + if n.filteredOut { + continue + } + if opts.DistanceThreshold != nil { + th := *opts.DistanceThreshold + switch ph.simType.indexType { + case Euclidean: + if float64(n.value) > th { + continue + } + case Cosine: + if float64(1.0)-float64(n.value) > th { + continue + } + default: + } + } + res = append(res, n.index) + if len(res) >= maxResults { + break + } + } + return res, nil +} + +// SearchWithUid searches the HNSW graph for the nearest neighbors of the query UID // and
returns the traversal path and the nearest neighbors func (ph *persistentHNSW[T]) SearchWithUid(_ context.Context, c index.CacheType, queryUid uint64, maxResults int, filter index.SearchFilter[T]) (nnUids []uint64, err error) { @@ -275,9 +438,9 @@ func (ph *persistentHNSW[T]) SearchWithUid(_ context.Context, c index.CacheType, shouldFilterOutQueryVec := !filter(queryVec, queryVec, queryUid) - // how normal search works is by cotinuously searching higher layers - // for the best entry node to the last layer since we already know the - // best entry node (since it already exists in the lowest level), we + // How normal search works is by continuously searching higher layers + // for the best entry node to the last layer. Since we already know the + // best entry node (it already exists in the lowest level), we // can just search the last layer and return the results. r, err := ph.searchPersistentLayer( c, ph.maxLevels-1, queryUid, queryVec, queryVec, @@ -389,7 +552,7 @@ func (ph *persistentHNSW[T]) SearchWithPath( return r, nil } -// InsertToPersistentStorage inserts a node into the hnsw graph and returns the +// InsertToPersistentStorage inserts a node into the HNSW graph and returns the // traversal path and the edges created func (ph *persistentHNSW[T]) Insert(ctx context.Context, c index.CacheType, inUuid uint64, inVec []T) ([]*index.KeyValue, error) { @@ -401,7 +564,7 @@ func (ph *persistentHNSW[T]) Insert(ctx context.Context, c index.CacheType, return edges, err } -// InsertToPersistentStorage inserts a node into the hnsw graph and returns the +// InsertToPersistentStorage inserts a node into the HNSW graph and returns the // traversal path and the edges created func (ph *persistentHNSW[T]) insertHelper(ctx context.Context, tc *TxnCache, inUuid uint64, inVec []T) ([]minPersistentHeapElement[T], []*index.KeyValue, error) { diff --git a/tok/index/index.go b/tok/index/index.go index 1ce56c02e49..1e981ef189e 100644 --- a/tok/index/index.go +++ b/tok/index/index.go 
@@ -118,6 +118,35 @@ type VectorIndex[T c.Float] interface { Insert(ctx context.Context, c CacheType, uuid uint64, vec []T) ([]*KeyValue, error) } +// VectorIndexOptions carries optional, per-call search tuning parameters. +// Zero values mean "no override". +type VectorIndexOptions[T c.Float] struct { + // EfOverride, when > 0, overrides the search breadth (ef) for this call. + // Implementations should apply this to upper layers and use max(k, ef) for + // the bottom layer candidate size, then return the best k. + EfOverride int + + // DistanceThreshold, when non-nil, filters out neighbors whose metric-domain + // distance exceeds the given threshold. Semantics depend on the index metric: + // - Euclidean: direct Euclidean distance (not squared) + // - Cosine: cosine distance in [0,2] (1 - cosine_similarity) + // - Dot product: undefined; implementations may ignore + DistanceThreshold *float64 + + // Filter allows callers to pass a SearchFilter; if nil, AcceptAll should be used. + Filter SearchFilter[T] +} + +// OptionalSearchOptions adds per-call search controls without breaking existing APIs. +// Implementations that support these may choose to ignore unsupported fields. 
+type OptionalSearchOptions[T c.Float] interface { + SearchWithOptions(ctx context.Context, c CacheType, query []T, + maxResults int, opts VectorIndexOptions[T]) ([]uint64, error) + + SearchWithUidAndOptions(ctx context.Context, c CacheType, queryUid uint64, + maxResults int, opts VectorIndexOptions[T]) ([]uint64, error) +} + // A Txn is an interface representation of a persistent storage transaction, // where multiple operations are performed on a database type Txn interface { diff --git a/tok/index/search_path.go b/tok/index/search_path.go index e60e1e318d4..7e24b7d068c 100644 --- a/tok/index/search_path.go +++ b/tok/index/search_path.go @@ -5,11 +5,11 @@ package index -// SearchPathResult is the return-type for the optional +// SearchPathResult is the return type for the optional // SearchWithPath function for a VectorIndex // (by way of extending OptionalIndexSupport). type SearchPathResult struct { - // The collection of nearest-neighbors in sorted order after filtlering + // The collection of nearest neighbors in sorted order after filtering // out neighbors that fail any Filter criteria. Neighbors []uint64 // The path from the start of search to the closest neighbor vector. @@ -19,8 +19,8 @@ type SearchPathResult struct { Metrics map[string]uint64 } -// NewSearchPathResult() provides an initialized (empty) *SearchPathResult. -// The attributes will be non-nil, but empty. +// NewSearchPathResult provides an initialised (empty) *SearchPathResult. +// The attributes will be non‑nil but empty. func NewSearchPathResult() *SearchPathResult { return &SearchPathResult{ Neighbors: []uint64{}, diff --git a/worker/similar_to_options_test.go b/worker/similar_to_options_test.go new file mode 100644 index 00000000000..e6534d78d13 --- /dev/null +++ b/worker/similar_to_options_test.go @@ -0,0 +1,129 @@ +/* + * SPDX-FileCopyrightText: © 2017-2025 Istari Digital, Inc. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +package worker + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseSimilarToOptions_ValidEf(t *testing.T) { + fc := &functionContext{} + err := parseSimilarToOptions([]string{"ef", "64"}, fc) + require.NoError(t, err) + require.Equal(t, 64, fc.vsEfOverride) + require.Nil(t, fc.vsDistanceThreshold) +} + +func TestParseSimilarToOptions_ValidDistanceThreshold(t *testing.T) { + fc := &functionContext{} + err := parseSimilarToOptions([]string{"distance_threshold", "0.45"}, fc) + require.NoError(t, err) + require.Equal(t, 0, fc.vsEfOverride) + require.NotNil(t, fc.vsDistanceThreshold) + require.Equal(t, 0.45, *fc.vsDistanceThreshold) +} + +func TestParseSimilarToOptions_BothOptions(t *testing.T) { + fc := &functionContext{} + err := parseSimilarToOptions([]string{"ef", "32", "distance_threshold", "1.5"}, fc) + require.NoError(t, err) + require.Equal(t, 32, fc.vsEfOverride) + require.NotNil(t, fc.vsDistanceThreshold) + require.Equal(t, 1.5, *fc.vsDistanceThreshold) +} + +func TestParseSimilarToOptions_EmptyArgs(t *testing.T) { + fc := &functionContext{} + err := parseSimilarToOptions(nil, fc) + require.NoError(t, err) + require.Equal(t, 0, fc.vsEfOverride) + require.Nil(t, fc.vsDistanceThreshold) +} + +func TestParseSimilarToOptions_DuplicateKey(t *testing.T) { + fc := &functionContext{} + err := parseSimilarToOptions([]string{"ef", "10", "ef", "20"}, fc) + require.Error(t, err) + require.Contains(t, err.Error(), "Duplicate key in similar_to options") + require.Contains(t, err.Error(), "ef") +} + +func TestParseSimilarToOptions_UnknownKey(t *testing.T) { + fc := &functionContext{} + err := parseSimilarToOptions([]string{"unknown_key", "123"}, fc) + require.Error(t, err) + require.Contains(t, err.Error(), "Unknown option in similar_to") + require.Contains(t, err.Error(), "unknown_key") +} + +func TestParseSimilarToOptions_MalformedOption_OddArgs(t *testing.T) { + fc := &functionContext{} + 
err := parseSimilarToOptions([]string{"ef"}, fc) + require.Error(t, err) + require.Contains(t, err.Error(), "Malformed option in similar_to") +} + +func TestParseSimilarToOptions_InvalidEfValue(t *testing.T) { + fc := &functionContext{} + err := parseSimilarToOptions([]string{"ef", "abc"}, fc) + require.Error(t, err) + require.Contains(t, err.Error(), "Invalid value for 'ef'") +} + +func TestParseSimilarToOptions_NegativeEf(t *testing.T) { + fc := &functionContext{} + err := parseSimilarToOptions([]string{"ef", "-5"}, fc) + require.Error(t, err) + require.Contains(t, err.Error(), "Value for 'ef' must be positive") +} + +func TestParseSimilarToOptions_ZeroEf(t *testing.T) { + fc := &functionContext{} + err := parseSimilarToOptions([]string{"ef", "0"}, fc) + require.Error(t, err) + require.Contains(t, err.Error(), "Value for 'ef' must be positive") +} + +func TestParseSimilarToOptions_InvalidDistanceThreshold(t *testing.T) { + fc := &functionContext{} + err := parseSimilarToOptions([]string{"distance_threshold", "notanumber"}, fc) + require.Error(t, err) + require.Contains(t, err.Error(), "Invalid value for 'distance_threshold'") +} + +func TestParseSimilarToOptions_NegativeDistanceThreshold(t *testing.T) { + fc := &functionContext{} + err := parseSimilarToOptions([]string{"distance_threshold", "-0.5"}, fc) + require.Error(t, err) + require.Contains(t, err.Error(), "Value for 'distance_threshold' must be non-negative") +} + +func TestParseSimilarToOptions_WhitespaceHandling(t *testing.T) { + fc := &functionContext{} + err := parseSimilarToOptions([]string{" ef ", " 64 ", " distance_threshold ", " 0.5 "}, fc) + require.NoError(t, err) + require.Equal(t, 64, fc.vsEfOverride) + require.NotNil(t, fc.vsDistanceThreshold) + require.Equal(t, 0.5, *fc.vsDistanceThreshold) +} + +func TestParseSimilarToOptions_QuotedValues(t *testing.T) { + fc := &functionContext{} + err := parseSimilarToOptions([]string{"ef", `"64"`}, fc) + require.NoError(t, err) + require.Equal(t, 64, 
fc.vsEfOverride) +} + +func TestParseSimilarToOptions_MisspelledKey(t *testing.T) { + // Ensure typos like "distanc_threshold" are caught + fc := &functionContext{} + err := parseSimilarToOptions([]string{"distanc_threshold", "0.5"}, fc) + require.Error(t, err) + require.Contains(t, err.Error(), "Unknown option in similar_to") + require.Contains(t, err.Error(), "distanc_threshold") +} diff --git a/worker/task.go b/worker/task.go index 22000483c36..94e54c4ffd3 100644 --- a/worker/task.go +++ b/worker/task.go @@ -368,12 +368,30 @@ func (qs *queryState) handleValuePostings(ctx context.Context, args funcArgs) er return err } var nnUids []uint64 - if srcFn.vectorInfo != nil { - nnUids, err = indexer.Search(ctx, qc, srcFn.vectorInfo, - int(numNeighbors), index.AcceptAll[float32]) + // Build optional search options if provided + filter := index.AcceptAll[float32] + opts := index.VectorIndexOptions[float32]{Filter: filter} + if srcFn.vsEfOverride > 0 { + opts.EfOverride = srcFn.vsEfOverride + } + if srcFn.vsDistanceThreshold != nil { + opts.DistanceThreshold = srcFn.vsDistanceThreshold + } + hasOptions := opts.EfOverride > 0 || opts.DistanceThreshold != nil + if o, ok := indexer.(index.OptionalSearchOptions[float32]); ok && hasOptions { + if srcFn.vectorInfo != nil { + nnUids, err = o.SearchWithOptions(ctx, qc, srcFn.vectorInfo, int(numNeighbors), opts) + } else { + nnUids, err = o.SearchWithUidAndOptions(ctx, qc, srcFn.vectorUid, int(numNeighbors), opts) + } } else { - nnUids, err = indexer.SearchWithUid(ctx, qc, srcFn.vectorUid, - int(numNeighbors), index.AcceptAll[float32]) + if srcFn.vectorInfo != nil { + nnUids, err = indexer.Search(ctx, qc, srcFn.vectorInfo, + int(numNeighbors), index.AcceptAll[float32]) + } else { + nnUids, err = indexer.SearchWithUid(ctx, qc, srcFn.vectorUid, + int(numNeighbors), index.AcceptAll[float32]) + } } if err != nil && !strings.Contains(err.Error(), hnsw.EmptyHNSWTreeError+": "+badger.ErrKeyNotFound.Error()) { @@ -1792,6 +1810,9 @@ type 
functionContext struct { atype types.TypeID vectorInfo []float32 vectorUid uint64 + // Optional vector search options parsed from a 3rd arg on similar_to + vsEfOverride int + vsDistanceThreshold *float64 } const ( @@ -2125,13 +2146,21 @@ func parseSrcFn(ctx context.Context, q *pb.Query) (*functionContext, error) { } checkRoot(q, fc) case similarToFn: - if err = ensureArgsCount(q.SrcFunc, 2); err != nil { - return nil, err + // similar_to accepts 2 mandatory args: k, vector_or_uid followed by optional key:value pairs + // Example: similar_to(vpred, 3, $vec, ef: 64, distance_threshold: 0.5) + if len(q.SrcFunc.Args) < 2 || (len(q.SrcFunc.Args) > 2 && (len(q.SrcFunc.Args)-2)%2 != 0) { + return nil, errors.Errorf("Function '%s' requires 2 arguments plus optional key:value pairs, but got %d (%v)", + q.SrcFunc.Name, len(q.SrcFunc.Args), q.SrcFunc.Args) } fc.vectorInfo, fc.vectorUid, err = interpretVFloatOrUid(q.SrcFunc.Args[1]) if err != nil { return nil, err } + if len(q.SrcFunc.Args) > 2 { + if err := parseSimilarToOptions(q.SrcFunc.Args[2:], fc); err != nil { + return nil, err + } + } case uidInFn: for _, arg := range q.SrcFunc.Args { uidParsed, err := strconv.ParseUint(arg, 0, 64) @@ -2710,3 +2739,55 @@ func (qs *queryState) handleHasFunction(ctx context.Context, q *pb.Query, out *p out.UidMatrix = append(out.UidMatrix, result) return nil } + +// parseSimilarToOptions parses named options passed after similar_to 2 mandatory args (k, vecOrUid) +// The parser encodes these as key/value pairs: ["ef", "64", "distance_threshold", "0.5", ...] 
+func parseSimilarToOptions(args []string, fc *functionContext) error { + if len(args) == 0 { + return nil + } + if len(args)%2 != 0 { + return errors.Errorf("Malformed option in similar_to: expected key:value pairs, got %v", args) + } + seen := make(map[string]struct{}, len(args)/2) + for i := 0; i < len(args); i += 2 { + k := strings.ToLower(strings.TrimSpace(args[i])) + v := strings.TrimSpace(args[i+1]) + if strings.HasSuffix(k, ":") { + k = strings.TrimSuffix(k, ":") + } + if len(k) == 0 { + return errors.Errorf("Malformed option in similar_to: empty key") + } + if _, dup := seen[k]; dup { + return errors.Errorf("Duplicate key in similar_to options: %q", k) + } + seen[k] = struct{}{} + + v = strings.Trim(v, "\"'") + switch k { + case "ef": + n, perr := strconv.ParseInt(v, 10, 32) + if perr != nil { + return errors.Errorf("Invalid value for 'ef' in similar_to: %q", v) + } + if n <= 0 { + return errors.Errorf("Value for 'ef' must be positive, got: %d", n) + } + fc.vsEfOverride = int(n) + case "distance_threshold": + f, perr := strconv.ParseFloat(v, 64) + if perr != nil { + return errors.Errorf("Invalid value for 'distance_threshold' in similar_to: %q", v) + } + if f < 0 { + return errors.Errorf("Value for 'distance_threshold' must be non-negative, got: %v", f) + } + fc.vsDistanceThreshold = new(float64) + *fc.vsDistanceThreshold = f + default: + return errors.Errorf("Unknown option in similar_to: %q", k) + } + } + return nil +}