diff --git a/gql/parser.go b/gql/parser.go
index 23ac00306e4..dba4ee79eb9 100644
--- a/gql/parser.go
+++ b/gql/parser.go
@@ -1349,7 +1349,7 @@ func validFuncName(name string) bool {
 	switch name {
 	case "regexp", "anyofterms", "allofterms", "alloftext", "anyoftext",
-		"has", "uid", "uid_in", "anyof", "allof", "type":
+		"has", "uid", "uid_in", "anyof", "allof", "type", "match":
 		return true
 	}
 	return false
diff --git a/query/query.go b/query/query.go
index 39c63cd4f46..8a4ef43edcf 100644
--- a/query/query.go
+++ b/query/query.go
@@ -2423,7 +2423,7 @@ func isValidArg(a string) bool {
 func isValidFuncName(f string) bool {
 	switch f {
 	case "anyofterms", "allofterms", "val", "regexp", "anyoftext", "alloftext",
-		"has", "uid", "uid_in", "anyof", "allof", "type":
+		"has", "uid", "uid_in", "anyof", "allof", "type", "match":
 		return true
 	}
 	return isInequalityFn(f) || types.IsGeoFunc(f)
diff --git a/systest/queries_test.go b/systest/queries_test.go
index c884963f8ed..5722bbed8f7 100644
--- a/systest/queries_test.go
+++ b/systest/queries_test.go
@@ -50,6 +50,7 @@ func TestQuery(t *testing.T) {
 	t.Run("multiple block eval", wrap(MultipleBlockEval))
 	t.Run("unmatched var assignment eval", wrap(UnmatchedVarEval))
 	t.Run("hash index queries", wrap(QueryHashIndex))
+	t.Run("fuzzy matching", wrap(FuzzyMatch))
 	t.Run("regexp with toggled trigram index", wrap(RegexpToggleTrigramIndex))
 	t.Run("groupby uid that works", wrap(GroupByUidWorks))
 	t.Run("cleanup", wrap(SchemaQueryCleanup))
@@ -539,6 +540,147 @@ func SchemaQueryTestHTTP(t *testing.T, c *dgo.Dgraph) {
 	CompareJSON(t, js, string(m["data"]))
 }
 
+func FuzzyMatch(t *testing.T, c *dgo.Dgraph) {
+	ctx := context.Background()
+
+	require.NoError(t, c.Alter(ctx, &api.Operation{
+		Schema: `
+			term: string @index(trigram) .
+			name: string .
+		`,
+	}))
+
+	txn := c.NewTxn()
+	_, err := txn.Mutate(ctx, &api.Mutation{
+		SetNquads: []byte(`
+			_:t0 <term> "" .
+			_:t1 <term> "road" .
+			_:t2 <term> "avenue" .
+			_:t3 <term> "street" .
+			_:t4 <term> "boulevard" .
+			_:t5 <term> "drive" .
+			_:t6 <term> "route" .
+			_:t7 <term> "pass" .
+			_:t8 <term> "pathway" .
+			_:t9 <term> "lane" .
+			_:ta <term> "highway" .
+			_:tb <term> "parkway" .
+			_:tc <term> "motorway" .
+			_:td <term> "high road" .
+			_:te <term> "side street" .
+			_:tf <term> "dual carriageway" .
+			_:n0 <name> "srfrog" .
+		`),
+	})
+	require.NoError(t, err)
+	require.NoError(t, txn.Commit(ctx))
+
+	tests := []struct {
+		in, out, failure string
+	}{
+		{
+			in:  `{q(func:match(term, drive, 8)) {term}}`,
+			out: `{"q":[{"term":"drive"}]}`,
+		},
+		{
+			in:  `{q(func:match(term, "plano", 1)) {term}}`,
+			out: `{"q":[]}`,
+		},
+		{
+			in:  `{q(func:match(term, "plano", 2)) {term}}`,
+			out: `{"q":[{"term":"lane"}]}`,
+		},
+		{
+			in:  `{q(func:match(term, "plano", 8)) {term}}`,
+			out: `{"q":[{"term":"lane"}]}`,
+		},
+		{
+			in: `{q(func:match(term, way, 8)) {term}}`,
+			out: `{"q":[
+				{"term": "highway"},
+				{"term": "pathway"},
+				{"term": "parkway"},
+				{"term": "motorway"}
+			]}`,
+		},
+		{
+			in: `{q(func:match(term, pway, 8)) {term}}`,
+			out: `{"q":[
+				{"term": "highway"},
+				{"term": "pathway"},
+				{"term": "parkway"},
+				{"term": "motorway"}
+			]}`,
+		},
+		{
+			in: `{q(func:match(term, high, 8)) {term}}`,
+			out: `{"q":[
+				{"term": "highway"},
+				{"term": "high road"}
+			]}`,
+		},
+		{
+			in: `{q(func:match(term, str, 8)) {term}}`,
+			out: `{"q":[
+				{"term": "street"},
+				{"term": "side street"}
+			]}`,
+		},
+		{
+			in: `{q(func:match(term, strip, 8)) {term}}`,
+			out: `{"q":[
+				{"term": "street"},
+				{"term": "side street"}
+			]}`,
+		},
+		{
+			in:  `{q(func:match(term, strip, 3)) {term}}`,
+			out: `{"q":[{"term": "street"}]}`,
+		},
+		{
+			in: `{q(func:match(term, "carigeway", 8)) {term}}`,
+			out: `{"q":[
+				{"term": "dual carriageway"}
+			]}`,
+		},
+		{
+			in:  `{q(func:match(term, "carigeway", 4)) {term}}`,
+			out: `{"q":[]}`,
+		},
+		{
+			in: `{q(func:match(term, "dualway", 8)) {term}}`,
+			out: `{"q":[
+				{"term": "highway"},
+				{"term": "pathway"},
+				{"term": "parkway"},
+				{"term": "motorway"}
+			]}`,
+		},
+		{
+			in:  `{q(func:match(term, "dualway", 2)) {term}}`,
+			out: `{"q":[]}`,
+		},
+		{
+			in:      `{q(func:match(term, "", 8)) {term}}`,
+			failure: `Empty argument received`,
+		},
+		{
+			in:      `{q(func:match(name, "someone", 8)) {name}}`,
+			failure: `Attribute name is not indexed with type trigram`,
+		},
+	}
+	for _, tc := range tests {
+		resp, err := c.NewTxn().Query(ctx, tc.in)
+		if tc.failure != "" {
+			require.Error(t, err)
+			require.Contains(t, err.Error(), tc.failure)
+			continue
+		}
+		require.NoError(t, err)
+		CompareJSON(t, tc.out, string(resp.Json))
+	}
+}
+
 func QueryHashIndex(t *testing.T, c *dgo.Dgraph) {
 	ctx := context.Background()
diff --git a/tok/tok.go b/tok/tok.go
index d0c5b02c2d9..62e62c004ad 100644
--- a/tok/tok.go
+++ b/tok/tok.go
@@ -130,6 +130,17 @@ func LoadCustomTokenizer(soFile string) {
 	registerTokenizer(CustomTokenizer{PluginTokenizer: tokenizer})
 }
 
+// GetTokenizerByID tries to find a tokenizer by id in the registered list.
+// Returns the tokenizer and true if found, otherwise nil and false.
+func GetTokenizerByID(id byte) (Tokenizer, bool) {
+	for _, t := range tokenizers {
+		if id == t.Identifier() {
+			return t, true
+		}
+	}
+	return nil, false
+}
+
 // GetTokenizer returns tokenizer given unique name.
 func GetTokenizer(name string) (Tokenizer, bool) {
 	t, found := tokenizers[name]
@@ -332,6 +343,13 @@ func EncodeRegexTokens(tokens []string) {
 	}
 }
 
+// EncodeTokens encodes the given tokens using the tokenizer identified by id.
+func EncodeTokens(id byte, tokens []string) {
+	for i := 0; i < len(tokens); i++ {
+		tokens[i] = encodeToken(tokens[i], id)
+	}
+}
+
 type BoolTokenizer struct{}
 
 func (t BoolTokenizer) Name() string { return "bool" }
diff --git a/tok/tokens.go b/tok/tokens.go
index 14f41e82a42..9a16c731a7d 100644
--- a/tok/tokens.go
+++ b/tok/tokens.go
@@ -33,11 +33,19 @@ func GetLangTokenizer(t Tokenizer, lang string) Tokenizer {
 	return t
 }
 
-func GetTermTokens(funcArgs []string) ([]string, error) {
+func GetTokens(id byte, funcArgs ...string) ([]string, error) {
 	if l := len(funcArgs); l != 1 {
 		return nil, x.Errorf("Function requires 1 arguments, but got %d", l)
 	}
-	return BuildTokens(funcArgs[0], TermTokenizer{})
+	tokenizer, ok := GetTokenizerByID(id)
+	if !ok {
+		return nil, x.Errorf("No tokenizer was found with id %v", id)
+	}
+	return BuildTokens(funcArgs[0], tokenizer)
+}
+
+func GetTermTokens(funcArgs []string) ([]string, error) {
+	return GetTokens(IdentTerm, funcArgs...)
 }
 
 func GetFullTextTokens(funcArgs []string, lang string) ([]string, error) {
diff --git a/wiki/content/query-language/index.md b/wiki/content/query-language/index.md
index 20ad1f0bb44..ea1d9f244a7 100644
--- a/wiki/content/query-language/index.md
+++ b/wiki/content/query-language/index.md
@@ -353,6 +353,39 @@ Keep the following in mind when designing regular expression queries.
 
 - If the partial result (for subset of trigrams) exceeds 1000000 uids during index scan, the query is stopped to prohibit expensive queries.
 
+### Fuzzy matching
+
+
+Syntax: `match(predicate, string, distance)`
+
+Schema Types: `string`
+
+Index Required: `trigram`
+
+Matches predicate values by calculating the [Levenshtein distance](https://en.wikipedia.org/wiki/Levenshtein_distance) to the string,
+also known as _fuzzy matching_. The distance parameter must be greater than zero (0). A larger distance value can yield more results, but they may be less accurate.
+
+Query Example: At root, fuzzy match nodes similar to `Stephen`, with a distance value of 8.
+
+{{< runnable >}}
+{
+  directors(func: match(name@en, Stephen, 8)) {
+    name@en
+  }
+}
+{{< /runnable >}}
+
+Same query with a Levenshtein distance of 3.
+
+{{< runnable >}}
+{
+  directors(func: match(name@en, Stephen, 3)) {
+    name@en
+  }
+}
+{{< /runnable >}}
+
+
 ### Full Text Search
 
 Syntax Examples: `alloftext(predicate, "space-separated text")` and `anyoftext(predicate, "space-separated text")`
diff --git a/worker/match.go b/worker/match.go
new file mode 100644
index 00000000000..30cf16c2664
--- /dev/null
+++ b/worker/match.go
@@ -0,0 +1,117 @@
+/*
+ * Copyright 2019 Dgraph Labs, Inc. and Contributors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package worker
+
+import (
+	"github.com/dgraph-io/dgraph/algo"
+	"github.com/dgraph-io/dgraph/posting"
+	"github.com/dgraph-io/dgraph/protos/pb"
+	"github.com/dgraph-io/dgraph/tok"
+	"github.com/dgraph-io/dgraph/x"
+)
+
+// levenshteinDistance measures the difference between two strings.
+// The Levenshtein distance between two words is the minimum number of
+// single-character edits (i.e. insertions, deletions or substitutions)
+// required to change one word into the other.
+//
+// This implementation is optimized to use O(min(m,n)) space and is based on the
+// optimized C version found here:
+// http://en.wikibooks.org/wiki/Algorithm_implementation/Strings/Levenshtein_distance#C
+func levenshteinDistance(s, t string, max int) int {
+	r1, r2 := []rune(s), []rune(t)
+	if len(r1) > len(r2) {
+		// Swap on rune count, not byte length, so that len(r1) <= len(r2)
+		// also holds for multi-byte UTF-8 input.
+		r1, r2 = r2, r1
+	}
+	column := make([]int, len(r1)+1)
+
+	for y := 1; y <= len(r1); y++ {
+		column[y] = y
+	}
+
+	var minIdx int
+	for x := 1; x <= len(r2); x++ {
+		column[0] = x
+
+		for y, lastDiag := 1, x-1; y <= len(r1); y++ {
+			oldDiag := column[y]
+			cost := 0
+			if r1[y-1] != r2[x-1] {
+				cost = 1
+			}
+			column[y] = min(column[y]+1, column[y-1]+1, lastDiag+cost)
+			lastDiag = oldDiag
+		}
+		if minIdx < len(r1) && column[minIdx] > column[minIdx+1] {
+			minIdx++
+		}
+		if column[minIdx] > max {
+			return column[minIdx]
+		}
+	}
+	return column[len(r1)]
+}
+
+func min(a, b, c int) int {
+	if a < b && a < c {
+		return a
+	} else if b < c {
+		return b
+	}
+	return c
+}
+
+// matchFuzzy takes a value (from a posting list) and compares it to the query
+// string, returning true if the Levenshtein distance between them is at most
+// max, false otherwise.
+func matchFuzzy(query, val string, max int) bool {
+	if val == "" {
+		return false
+	}
+	return levenshteinDistance(val, query, max) <= max
+}
+
+// uidsForMatch collects a list of uids that "might" match a fuzzy term based on the
+// trigram index. matchFuzzy does the actual fuzzy match.
+// Returns the list of uids even if empty, or an error otherwise.
+func uidsForMatch(attr string, arg funcArgs) (*pb.List, error) {
+	opts := posting.ListOptions{ReadTs: arg.q.ReadTs}
+	uidsForNgram := func(ngram string) (*pb.List, error) {
+		key := x.IndexKey(attr, ngram)
+		pl, err := posting.GetNoStore(key)
+		if err != nil {
+			return nil, err
+		}
+		return pl.Uids(opts)
+	}
+
+	tokens, err := tok.GetTokens(tok.IdentTrigram, arg.srcFn.tokens...)
+	if err != nil {
+		return nil, err
+	}
+
+	uidMatrix := make([]*pb.List, len(tokens))
+	for i, t := range tokens {
+		uidMatrix[i], err = uidsForNgram(t)
+		if err != nil {
+			return nil, err
+		}
+	}
+	return algo.MergeSorted(uidMatrix), nil
+}
diff --git a/worker/match_test.go b/worker/match_test.go
new file mode 100644
index 00000000000..e97b41a672c
--- /dev/null
+++ b/worker/match_test.go
@@ -0,0 +1,18 @@
+package worker
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestDistance(t *testing.T) {
+	require.Equal(t, 0, levenshteinDistance("detour", "detour", 2))
+	require.Equal(t, 1, levenshteinDistance("detour", "det.our", 2))
+	require.Equal(t, 2, levenshteinDistance("detour", "det..our", 2))
+	require.Equal(t, 3, levenshteinDistance("detour", "..det..our", 2))
+	require.Equal(t, 2, levenshteinDistance("detour", "detour..", 2))
+	require.Equal(t, 3, levenshteinDistance("detour", "detour...", 2))
+	require.Equal(t, 3, levenshteinDistance("detour", "...detour", 2))
+	require.Equal(t, 3, levenshteinDistance("detour", "..detour.", 2))
+}
diff --git a/worker/task.go b/worker/task.go
index 45c43765011..e3ad8389d66 100644
--- a/worker/task.go
+++ b/worker/task.go
@@ -220,6 +220,7 @@ const (
 	HasFn
 	UidInFn
 	CustomIndexFn
+	MatchFn
 	StandardFn = 100
 )
 
@@ -260,6 +261,8 @@ func parseFuncTypeHelper(name string) (FuncType, string) {
 		return UidInFn, f
 	case "anyof", "allof":
 		return CustomIndexFn, f
+	case "match":
+		return MatchFn, f
 	default:
 		if types.IsGeoFunc(f) {
 			return GeoFn, f
@@ -270,11 +273,18 @@ func parseFuncTypeHelper(name string) (FuncType, string) {
 
 func needsIndex(fnType FuncType) bool {
 	switch fnType {
-	case CompareAttrFn, GeoFn, FullTextSearchFn, StandardFn:
+	case CompareAttrFn, GeoFn, FullTextSearchFn, StandardFn, MatchFn:
 		return true
-	default:
-		return false
 	}
+	return false
+}
+
+// needsIntersect checks if the function needs algo.IntersectSorted() after the results
+// are collected. This is needed for functions that require all values to match, like
+// "allofterms", "alloftext", and custom functions with "allof".
+// Returns true if function results need intersect, false otherwise.
+func needsIntersect(fnName string) bool {
+	return strings.HasPrefix(fnName, "allof") || strings.HasSuffix(fnName, "allof")
 }
 
 type funcArgs struct {
 
@@ -294,8 +304,9 @@ func (srcFn *functionContext) needsValuePostings(typ types.TypeID) (bool, error)
 			return false, nil
 		}
 		return true, nil
-	case GeoFn, RegexFn, FullTextSearchFn, StandardFn, HasFn, CustomIndexFn:
-		// All of these require index, hence would require fetching uid postings.
+	case GeoFn, RegexFn, FullTextSearchFn, StandardFn, HasFn, CustomIndexFn, MatchFn:
+		// All of these require an index, hence would require fetching uid postings.
 		return false, nil
 	case UidInFn, CompareScalarFn:
 		// Operate on uid postings
@@ -558,7 +568,7 @@ func (qs *queryState) handleUidPostings(
 		} else {
 			key = x.DataKey(q.Attr, q.UidList.Uids[i])
 		}
-	case GeoFn, RegexFn, FullTextSearchFn, StandardFn, CustomIndexFn:
+	case GeoFn, RegexFn, FullTextSearchFn, StandardFn, CustomIndexFn, MatchFn:
 		key = x.IndexKey(q.Attr, srcFn.tokens[i])
 	case CompareAttrFn:
 		key = x.IndexKey(q.Attr, srcFn.tokens[i])
@@ -819,6 +829,13 @@ func (qs *queryState) helpProcessTask(
 		}
 	}
 
+	if srcFn.fnType == MatchFn {
+		span.Annotate(nil, "handleMatchFunction")
+		if err := qs.handleMatchFunction(ctx, funcArgs{q, gid, srcFn, out}); err != nil {
+			return nil, err
+		}
+	}
+
 	// We fetch the actual value for the uids, compare them to the value in the
 	// request and filter the uids only if the tokenizer IsLossy.
 	if srcFn.fnType == CompareAttrFn && len(srcFn.tokens) > 0 {
@@ -1103,6 +1120,98 @@ func (qs *queryState) handleCompareFunction(ctx context.Context, arg funcArgs) e
 	return nil
 }
 
+func (qs *queryState) handleMatchFunction(ctx context.Context, arg funcArgs) error {
+	span := otrace.FromContext(ctx)
+	stop := x.SpanTimer(span, "handleMatchFunction")
+	defer stop()
+	if span != nil {
+		span.Annotatef(nil, "Number of uids: %d. args.srcFn: %+v", arg.srcFn.n, arg.srcFn)
+	}
+
+	attr := arg.q.Attr
+	typ := arg.srcFn.atype
+	span.Annotatef(nil, "Attr: %s. Type: %s", attr, typ.Name())
+	uids := &pb.List{}
+	switch {
+	case !typ.IsScalar():
+		return x.Errorf("Attribute not scalar: %s %v", attr, typ)
+
+	case typ != types.StringID:
+		return x.Errorf("Got non-string type. Fuzzy match is allowed only on string type.")
+
+	case arg.q.UidList != nil && len(arg.q.UidList.Uids) != 0:
+		uids = arg.q.UidList
+
+	case schema.State().HasTokenizer(tok.IdentTrigram, attr):
+		var err error
+		uids, err = uidsForMatch(attr, arg)
+		if err != nil {
+			return err
+		}
+
+	default:
+		return x.Errorf(
+			"Attribute %v does not have trigram index for fuzzy matching. "+
+				"Please add a trigram index or use has/uid function with match() as filter.",
+			attr)
+	}
+
+	isList := schema.State().IsList(attr)
+	lang := langForFunc(arg.q.Langs)
+	span.Annotatef(nil, "Total uids: %d, list: %t lang: %v", len(uids.Uids), isList, lang)
+	arg.out.UidMatrix = append(arg.out.UidMatrix, uids)
+
+	matchQuery := strings.Join(arg.srcFn.tokens, "")
+	filtered := &pb.List{}
+	for _, uid := range uids.Uids {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		default:
+		}
+		pl, err := qs.cache.Get(x.DataKey(attr, uid))
+		if err != nil {
+			return err
+		}
+
+		vals := make([]types.Val, 1)
+		switch {
+		case lang != "":
+			vals[0], err = pl.ValueForTag(arg.q.ReadTs, lang)
+
+		case isList:
+			vals, err = pl.AllUntaggedValues(arg.q.ReadTs)
+
+		default:
+			vals[0], err = pl.Value(arg.q.ReadTs)
+		}
+		if err != nil {
+			if err == posting.ErrNoValue {
+				continue
+			}
+			return err
+		}
+
+		max := int(arg.srcFn.threshold)
+		for _, val := range vals {
+			// Convert the value from binary to the appropriate string format.
+			strVal, err := types.Convert(val, types.StringID)
+			if err == nil && matchFuzzy(matchQuery, strVal.Value.(string), max) {
+				filtered.Uids = append(filtered.Uids, uid)
+				// NOTE: We only add the uid once.
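+				// A single matching value is enough, so skip this uid's remaining values.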
+				break
+			}
+		}
+	}
+
+	for i := 0; i < len(arg.out.UidMatrix); i++ {
+		algo.IntersectWith(arg.out.UidMatrix[i], filtered, arg.out.UidMatrix[i])
+	}
+
+	return nil
+}
+
 func (qs *queryState) filterGeoFunction(arg funcArgs) error {
 	attr := arg.q.Attr
 	uids := algo.MergeSorted(arg.out.UidMatrix)
@@ -1395,8 +1503,29 @@ func parseSrcFn(q *pb.Query) (*functionContext, error) {
 		if fc.tokens, err = getStringTokens(q.SrcFunc.Args, langForFunc(q.Langs), fnType); err != nil {
 			return nil, err
 		}
-		fnName := strings.ToLower(q.SrcFunc.Name)
-		fc.intersectDest = strings.HasPrefix(fnName, "allof") // allofterms and alloftext
+		fc.intersectDest = needsIntersect(f)
 		fc.n = len(fc.tokens)
+	case MatchFn:
+		if err = ensureArgsCount(q.SrcFunc, 2); err != nil {
+			return nil, err
+		}
+		required, found := verifyStringIndex(attr, fnType)
+		if !found {
+			return nil, x.Errorf("Attribute %s is not indexed with type %s", attr, required)
+		}
+		fc.intersectDest = needsIntersect(f)
+		// Max Levenshtein distance; pop it off the argument list.
+		var s string
+		s, q.SrcFunc.Args = q.SrcFunc.Args[1], q.SrcFunc.Args[:1]
+		max, err := strconv.ParseInt(s, 10, 32)
+		if err != nil {
+			return nil, x.Errorf("Levenshtein distance value must be an int, got %v", s)
+		}
+		if max <= 0 {
+			return nil, x.Errorf("Levenshtein distance value must be greater than 0, got %v", s)
+		}
+		fc.threshold = max
+		fc.tokens = q.SrcFunc.Args
+		fc.n = len(fc.tokens)
 	case CustomIndexFn:
 		if err = ensureArgsCount(q.SrcFunc, 2); err != nil {
@@ -1417,9 +1546,7 @@ func parseSrcFn(q *pb.Query) (*functionContext, error) {
 		}
 		fc.tokens, _ = tok.BuildTokens(valToTok.Value,
 			tok.GetLangTokenizer(tokenizer, langForFunc(q.Langs)))
-		fnName := strings.ToLower(q.SrcFunc.Name)
-		x.AssertTrue(fnName == "allof" || fnName == "anyof")
-		fc.intersectDest = strings.HasSuffix(fnName, "allof")
+		fc.intersectDest = needsIntersect(f)
 		fc.n = len(fc.tokens)
 	case RegexFn:
 		if err = ensureArgsCount(q.SrcFunc, 2); err != nil {
diff --git a/worker/tokens.go b/worker/tokens.go
index c787503384e..4c7852a8759 100644
--- a/worker/tokens.go
+++ b/worker/tokens.go
@@ -29,9 +29,12 @@ import (
 
 func verifyStringIndex(attr string, funcType FuncType) (string, bool) {
 	var requiredTokenizer tok.Tokenizer
-	if funcType == FullTextSearchFn {
+	switch funcType {
+	case FullTextSearchFn:
 		requiredTokenizer = tok.FullTextTokenizer{}
-	} else {
+	case MatchFn:
+		requiredTokenizer = tok.TrigramTokenizer{}
+	default:
 		requiredTokenizer = tok.TermTokenizer{}
 	}
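
A quick way to exercise the new function by hand once the patch is applied: the error
message in handleMatchFunction suggests using has/uid with match() as a filter, so a
minimal sketch against the systest schema above (the term predicate, its trigram index,
and the seeded data all come from the FuzzyMatch test, not production data) is:

	{
	  q(func: has(term)) @filter(match(term, street, 3)) {
	    term
	  }
	}

With the sample terms, only "street" itself should fall within Levenshtein distance 3,
so the expected response is {"q":[{"term":"street"}]}.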