diff --git a/redisearch/aggregate.go b/redisearch/aggregate.go index 270e5eb..191be4e 100644 --- a/redisearch/aggregate.go +++ b/redisearch/aggregate.go @@ -123,8 +123,8 @@ type AggregateQuery struct { Max int WithSchema bool Verbatim bool - WithCursor bool - Cursor *Cursor + WithCursor bool + Cursor *Cursor // TODO: add load fields } @@ -187,12 +187,12 @@ func (a *AggregateQuery) Limit(offset int, num int) *AggregateQuery { } //Load document fields from the document HASH objects (if they are not in the sortables) -func (a *AggregateQuery) Load( Properties []string) *AggregateQuery { +func (a *AggregateQuery) Load(Properties []string) *AggregateQuery { nproperties := len(Properties) if nproperties > 0 { a.AggregatePlan = a.AggregatePlan.Add("LOAD", nproperties) for _, property := range Properties { - a.AggregatePlan = a.AggregatePlan.Add(fmt.Sprintf( "@%s", property )) + a.AggregatePlan = a.AggregatePlan.Add(fmt.Sprintf("@%s", property)) } } return a @@ -260,6 +260,7 @@ func (q AggregateQuery) Serialize() redis.Args { return args } +// Deprecated: Please use processAggReply() instead func ProcessAggResponse(res []interface{}) [][]string { aggregateReply := make([][]string, len(res), len(res)) for i := 0; i < len(res); i++ { @@ -273,6 +274,25 @@ func ProcessAggResponse(res []interface{}) [][]string { return aggregateReply } +func processAggReply(res []interface{}) (total int, aggregateReply [][]string, err error) { + aggregateReply = [][]string{} + total = 0 + aggregate_results := len(res) - 1 + if aggregate_results > 0 { + total = aggregate_results + aggregateReply = make([][]string, aggregate_results, aggregate_results) + for i := 0; i < aggregate_results; i++ { + if d, e := redis.Strings(res[i+1], nil); e == nil { + aggregateReply[i] = d + } else { + err = fmt.Errorf("Error parsing Aggregate Reply: %v on reply position %d", e, i) + aggregateReply[i] = nil + } + } + } + return +} + func ProcessAggResponseSS(res []interface{}) [][]string { var lout = len(res) aggregateReply := make([][]string, lout, lout) diff --git a/redisearch/aggregate_test.go b/redisearch/aggregate_test.go index 0ece9f9..b485ba8 100644 --- a/redisearch/aggregate_test.go +++ b/redisearch/aggregate_test.go @@ -1,11 +1,10 @@ -package redisearch_test +package redisearch import ( "bufio" "compress/bzip2" "encoding/json" "fmt" - "github.com/RediSearch/redisearch-go/redisearch" "github.com/gomodule/redigo/redis" "github.com/stretchr/testify/assert" "log" @@ -17,6 +16,15 @@ import ( "testing" ) +func createClient(indexName string) *Client { + value, exists := os.LookupEnv("REDISEARCH_TEST_HOST") + host := "localhost:6379" + if exists && value != "" { + host = value + } + return NewClient(host, indexName) +} + // Game struct which contains a Asin, a Description, a Title, a Price, and a list of categories // a type and a list of social links @@ -30,7 +38,7 @@ type Game struct { Categories []string `json:"categories"` } -func AddValues(c *redisearch.Client) { +func AddValues(c *Client) { // Open our jsonFile bzipfile := "../tests/games.json.bz2" @@ -47,7 +55,7 @@ func AddValues(c *redisearch.Client) { d := bufio.NewReader(cr) // create a scanner scanner := bufio.NewScanner(d) - docs := make([]redisearch.Document, 0) + docs := make([]Document, 0) docPos := 1 for scanner.Scan() { // we initialize our Users array @@ -57,7 +65,7 @@ func AddValues(c *redisearch.Client) { if err != nil { fmt.Println("error:", err) } - docs = append(docs, redisearch.NewDocument(fmt.Sprintf("docs-games-%d", docPos), 1). + docs = append(docs, NewDocument(fmt.Sprintf("docs-games-%d", docPos), 1). Set("title", game.Title). Set("brand", game.Brand). Set("description", game.Description). @@ -66,7 +74,7 @@ func AddValues(c *redisearch.Client) { docPos = docPos + 1 } - if err := c.IndexOptions(redisearch.DefaultIndexingOptions, docs...); err != nil { + if err := c.IndexOptions(DefaultIndexingOptions, docs...); err != nil { log.Fatal(err) } @@ -75,12 +83,12 @@ func init() { /* load test data */ c := createClient("docs-games-idx1") - sc := redisearch.NewSchema(redisearch.DefaultOptions). - AddField(redisearch.NewTextFieldOptions("title", redisearch.TextFieldOptions{Sortable: true})). - AddField(redisearch.NewTextFieldOptions("brand", redisearch.TextFieldOptions{Sortable: true, NoStem: true})). - AddField(redisearch.NewTextField("description")). - AddField(redisearch.NewSortableNumericField("price")). - AddField(redisearch.NewTagField("categories")) + sc := NewSchema(DefaultOptions). + AddField(NewTextFieldOptions("title", TextFieldOptions{Sortable: true})). + AddField(NewTextFieldOptions("brand", TextFieldOptions{Sortable: true, NoStem: true})). + AddField(NewTextField("description")). + AddField(NewSortableNumericField("price")). + AddField(NewTagField("categories")) c.Drop() c.CreateIndex(sc) @@ -91,10 +99,10 @@ func TestAggregateGroupBy(t *testing.T) { c := createClient("docs-games-idx1") - q1 := redisearch.NewAggregateQuery(). - GroupBy(*redisearch.NewGroupBy().AddFields("@brand"). - Reduce(*redisearch.NewReducerAlias(redisearch.GroupByReducerCount, []string{}, "count"))). - SortBy([]redisearch.SortingKey{*redisearch.NewSortingKeyDir("@count", false)}). + q1 := NewAggregateQuery(). + GroupBy(*NewGroupBy().AddFields("@brand"). + Reduce(*NewReducerAlias(GroupByReducerCount, []string{}, "count"))). + SortBy([]SortingKey{*NewSortingKeyDir("@count", false)}). Limit(0, 5) _, count, err := c.Aggregate(q1) @@ -106,24 +114,25 @@ func TestAggregateMinMax(t *testing.T) { c := createClient("docs-games-idx1") - q1 := redisearch.NewAggregateQuery().SetQuery(redisearch.NewQuery("sony")). - GroupBy(*redisearch.NewGroupBy().AddFields("@brand"). - Reduce(*redisearch.NewReducer(redisearch.GroupByReducerCount, []string{})). - Reduce(*redisearch.NewReducerAlias(redisearch.GroupByReducerMin, []string{"@price"}, "minPrice"))). - SortBy([]redisearch.SortingKey{*redisearch.NewSortingKeyDir("@minPrice", false)}) + q1 := NewAggregateQuery().SetQuery(NewQuery("sony")). + GroupBy(*NewGroupBy().AddFields("@brand"). + Reduce(*NewReducer(GroupByReducerCount, []string{})). + Reduce(*NewReducerAlias(GroupByReducerMin, []string{"@price"}, "minPrice"))). + SortBy([]SortingKey{*NewSortingKeyDir("@minPrice", false)}) res, _, err := c.Aggregate(q1) assert.Nil(t, err) row := res[0] + fmt.Println(row) f, _ := strconv.ParseFloat(row[5], 64) assert.GreaterOrEqual(t, f, 88.0) assert.Less(t, f, 89.0) - q2 := redisearch.NewAggregateQuery().SetQuery(redisearch.NewQuery("sony")). - GroupBy(*redisearch.NewGroupBy().AddFields("@brand"). - Reduce(*redisearch.NewReducer(redisearch.GroupByReducerCount, []string{})). - Reduce(*redisearch.NewReducerAlias(redisearch.GroupByReducerMax, []string{"@price"}, "maxPrice"))). - SortBy([]redisearch.SortingKey{*redisearch.NewSortingKeyDir("@maxPrice", false)}) + q2 := NewAggregateQuery().SetQuery(NewQuery("sony")). + GroupBy(*NewGroupBy().AddFields("@brand"). + Reduce(*NewReducer(GroupByReducerCount, []string{})). + Reduce(*NewReducerAlias(GroupByReducerMax, []string{"@price"}, "maxPrice"))). + SortBy([]SortingKey{*NewSortingKeyDir("@maxPrice", false)}) res, _, err = c.Aggregate(q2) assert.Nil(t, err) @@ -137,10 +146,10 @@ func TestAggregateCountDistinct(t *testing.T) { c := createClient("docs-games-idx1") - q1 := redisearch.NewAggregateQuery(). - GroupBy(*redisearch.NewGroupBy().AddFields("@brand"). - Reduce(*redisearch.NewReducer(redisearch.GroupByReducerCountDistinct, []string{"@title"}).SetAlias("count_distinct(title)")). - Reduce(*redisearch.NewReducer(redisearch.GroupByReducerCount, []string{}))) + q1 := NewAggregateQuery(). + GroupBy(*NewGroupBy().AddFields("@brand"). + Reduce(*NewReducer(GroupByReducerCountDistinct, []string{"@title"}).SetAlias("count_distinct(title)")). + Reduce(*NewReducer(GroupByReducerCount, []string{}))) res, _, err := c.Aggregate(q1) assert.Nil(t, err) @@ -152,9 +161,9 @@ func TestAggregateFilter(t *testing.T) { c := createClient("docs-games-idx1") - q1 := redisearch.NewAggregateQuery(). - GroupBy(*redisearch.NewGroupBy().AddFields("@brand"). - Reduce(*redisearch.NewReducerAlias(redisearch.GroupByReducerCount, []string{}, "count"))). + q1 := NewAggregateQuery(). + GroupBy(*NewGroupBy().AddFields("@brand"). + Reduce(*NewReducerAlias(GroupByReducerCount, []string{}, "count"))). Filter("@count > 5") res, _, err := c.Aggregate(q1) @@ -184,13 +193,13 @@ func makeAggResponseInterface(seed int64, nElements int, responseSizes []int) (r func benchmarkProcessAggResponseSS(res []interface{}, total int, b *testing.B) { for n := 0; n < b.N; n++ { - redisearch.ProcessAggResponseSS(res) + ProcessAggResponseSS(res) } } func benchmarkProcessAggResponse(res []interface{}, total int, b *testing.B) { for n := 0; n < b.N; n++ { - redisearch.ProcessAggResponse(res) + ProcessAggResponse(res) } } @@ -232,7 +241,7 @@ func TestProjection_Serialize(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - p := redisearch.Projection{ + p := Projection{ Expression: tt.fields.Expression, Alias: tt.fields.Alias, } @@ -260,7 +269,7 @@ func TestCursor_Serialize(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := redisearch.Cursor{ + c := Cursor{ Id: tt.fields.Id, Count: tt.fields.Count, MaxIdle: tt.fields.MaxIdle, @@ -275,8 +284,8 @@ func TestCursor_Serialize(t *testing.T) { func TestGroupBy_AddFields(t *testing.T) { type fields struct { Fields []string - Reducers []redisearch.Reducer - Paging *redisearch.Paging + Reducers []Reducer + Paging *Paging } type args struct { fields interface{} @@ -285,17 +294,17 @@ func TestGroupBy_AddFields(t *testing.T) { name string fields fields args args - want *redisearch.GroupBy + want *GroupBy }{ {"TestGroupBy_AddFields_1", fields{[]string{}, nil, nil}, args{"a",}, - &redisearch.GroupBy{[]string{"a"}, nil, nil}, + &GroupBy{[]string{"a"}, nil, nil}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - g := &redisearch.GroupBy{ + g := &GroupBy{ Fields: tt.fields.Fields, Reducers: tt.fields.Reducers, Paging: tt.fields.Paging, @@ -310,8 +319,8 @@ func TestGroupBy_AddFields(t *testing.T) { func TestGroupBy_Limit(t *testing.T) { type fields struct { Fields []string - Reducers []redisearch.Reducer - Paging *redisearch.Paging + Reducers []Reducer + Paging *Paging } type args struct { offset int @@ -321,17 +330,17 @@ func TestGroupBy_Limit(t *testing.T) { name string fields fields args args - want *redisearch.GroupBy + want *GroupBy }{ {"TestGroupBy_Limit_1", fields{[]string{}, nil, nil}, args{10, 20}, - &redisearch.GroupBy{[]string{}, nil, &redisearch.Paging{10, 20}}, + &GroupBy{[]string{}, nil, &Paging{10, 20}}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - g := &redisearch.GroupBy{ + g := &GroupBy{ Fields: tt.fields.Fields, Reducers: tt.fields.Reducers, Paging: tt.fields.Paging, @@ -347,14 +356,14 @@ func TestGroupBy_Limit(t *testing.T) { func TestAggregateQuery_SetMax(t *testing.T) { type fields struct { - Query *redisearch.Query + Query *Query AggregatePlan redis.Args - Paging *redisearch.Paging + Paging *Paging Max int WithSchema bool Verbatim bool WithCursor bool - Cursor *redisearch.Cursor + Cursor *Cursor } type args struct { value int @@ -363,17 +372,17 @@ func TestAggregateQuery_SetMax(t *testing.T) { name string fields fields args args - want *redisearch.AggregateQuery + want *AggregateQuery }{ {"TestAggregateQuery_SetMax_1", fields{nil, redis.Args{}, nil, 0, false, false, false, nil}, args{10}, - &redisearch.AggregateQuery{nil, redis.Args{}, nil, 10, false, false, false, nil}, + &AggregateQuery{nil, redis.Args{}, nil, 10, false, false, false, nil}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - a := &redisearch.AggregateQuery{ + a := &AggregateQuery{ Query: tt.fields.Query, AggregatePlan: tt.fields.AggregatePlan, Paging: tt.fields.Paging, @@ -392,14 +401,14 @@ func TestAggregateQuery_SetMax(t *testing.T) { func TestAggregateQuery_SetVerbatim(t *testing.T) { type fields struct { - Query *redisearch.Query + Query *Query AggregatePlan redis.Args - Paging *redisearch.Paging + Paging *Paging Max int WithSchema bool Verbatim bool WithCursor bool - Cursor *redisearch.Cursor + Cursor *Cursor } type args struct { value bool @@ -408,17 +417,17 @@ func TestAggregateQuery_SetVerbatim(t *testing.T) { name string fields fields args args - want *redisearch.AggregateQuery + want *AggregateQuery }{ {"TestAggregateQuery_SetVerbatim_1", fields{nil, redis.Args{}, nil, 0, false, false, false, nil}, args{true}, - &redisearch.AggregateQuery{nil, redis.Args{}, nil, 0, false, true, false, nil}, + &AggregateQuery{nil, redis.Args{}, nil, 0, false, true, false, nil}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - a := &redisearch.AggregateQuery{ + a := &AggregateQuery{ Query: tt.fields.Query, AggregatePlan: tt.fields.AggregatePlan, Paging: tt.fields.Paging, @@ -437,14 +446,14 @@ func TestAggregateQuery_SetVerbatim(t *testing.T) { func TestAggregateQuery_SetWithSchema(t *testing.T) { type fields struct { - Query *redisearch.Query + Query *Query AggregatePlan redis.Args - Paging *redisearch.Paging + Paging *Paging Max int WithSchema bool Verbatim bool WithCursor bool - Cursor *redisearch.Cursor + Cursor *Cursor } type args struct { value bool @@ -453,17 +462,17 @@ func TestAggregateQuery_SetWithSchema(t *testing.T) { name string fields fields args args - want *redisearch.AggregateQuery + want *AggregateQuery }{ {"TestAggregateQuery_SetWithSchema_1", fields{nil, redis.Args{}, nil, 0, false, false, false, nil}, args{true}, - &redisearch.AggregateQuery{nil, redis.Args{}, nil, 0, true, false, false, nil}, + &AggregateQuery{nil, redis.Args{}, nil, 0, true, false, false, nil}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - a := &redisearch.AggregateQuery{ + a := &AggregateQuery{ Query: tt.fields.Query, AggregatePlan: tt.fields.AggregatePlan, Paging: tt.fields.Paging, @@ -482,14 +491,14 @@ func TestAggregateQuery_SetWithSchema(t *testing.T) { func TestAggregateQuery_CursorHasResults(t *testing.T) { type fields struct { - Query *redisearch.Query + Query *Query AggregatePlan redis.Args - Paging *redisearch.Paging + Paging *Paging Max int WithSchema bool Verbatim bool WithCursor bool - Cursor *redisearch.Cursor + Cursor *Cursor } tests := []struct { name string @@ -501,13 +510,13 @@ func TestAggregateQuery_CursorHasResults(t *testing.T) { false, }, {"TestAggregateQuery_CursorHasResults_1_true", - fields{nil, redis.Args{}, nil, 0, false, false, false, redisearch.NewCursor().SetId(10)}, + fields{nil, redis.Args{}, nil, 0, false, false, false, NewCursor().SetId(10)}, true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - a := &redisearch.AggregateQuery{ + a := &AggregateQuery{ Query: tt.fields.Query, AggregatePlan: tt.fields.AggregatePlan, Paging: tt.fields.Paging, @@ -526,14 +535,14 @@ func TestAggregateQuery_CursorHasResults(t *testing.T) { func TestAggregateQuery_Load(t *testing.T) { type fields struct { - Query *redisearch.Query + Query *Query AggregatePlan redis.Args - Paging *redisearch.Paging + Paging *Paging Max int WithSchema bool Verbatim bool WithCursor bool - Cursor *redisearch.Cursor + Cursor *Cursor } type args struct { Properties []string @@ -562,7 +571,7 @@ func TestAggregateQuery_Load(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - a := &redisearch.AggregateQuery{ + a := &AggregateQuery{ Query: tt.fields.Query, AggregatePlan: tt.fields.AggregatePlan, Paging: tt.fields.Paging, @@ -578,3 +587,56 @@ func TestAggregateQuery_Load(t *testing.T) { }) } } + +func TestProcessAggResponse(t *testing.T) { + type args struct { + res []interface{} + } + tests := []struct { + name string + args args + want [][]string + }{ + {"empty-reply", args{[]interface{}{}}, [][]string{},}, + {"1-element-reply", args{[]interface{}{[]interface{}{"userFullName", "berge, julius", "count", "2783"}}}, [][]string{{"userFullName", "berge, julius", "count", "2783"}},}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ProcessAggResponse(tt.args.res); !reflect.DeepEqual(got, tt.want) { + t.Errorf("ProcessAggResponse() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_processAggReply(t *testing.T) { + type args struct { + res []interface{} + } + tests := []struct { + name string + args args + wantTotal int + wantAggregateReply [][]string + wantErr bool + }{ + {"empty-reply", args{[]interface{}{}}, 0, [][]string{}, false}, + {"1-element-reply", args{[]interface{}{1, []interface{}{"userFullName", "j", "count", "2"}}}, 1, [][]string{{"userFullName", "j", "count", "2"}}, false}, + {"multi-element-reply", args{[]interface{}{2, []interface{}{"userFullName", "j"}, []interface{}{"userFullName", "a"}}}, 2, [][]string{{"userFullName", "j"}, {"userFullName", "a"}}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotTotal, gotAggregateReply, err := processAggReply(tt.args.res) + if (err != nil) != tt.wantErr { + t.Errorf("processAggReply() error = %v, wantErr %v", err, tt.wantErr) + return + } + if gotTotal != tt.wantTotal { + t.Errorf("processAggReply() gotTotal = %v, want %v", gotTotal, tt.wantTotal) + } + if !reflect.DeepEqual(gotAggregateReply, tt.wantAggregateReply) { + t.Errorf("processAggReply() gotAggregateReply = %v, want %v", gotAggregateReply, tt.wantAggregateReply) + } + }) + } +} diff --git a/redisearch/client.go b/redisearch/client.go index e5bbf9a..f7a0465 100644 --- a/redisearch/client.go +++ b/redisearch/client.go @@ -412,13 +412,8 @@ func (i *Client) Aggregate(q *AggregateQuery) (aggregateReply [][]string, total } // has no cursor if !hasCursor { - total = len(res) - 1 - // there is a case when only 1 data from aggregate, it returns nothing - // then set total > 0 so the data will be return - if total > 0 { - aggregateReply = ProcessAggResponse(res[1:]) - } - // has cursor + total, aggregateReply,err = processAggReply(res) + // has cursor } else { var partialResults, err = redis.Values(res[0], nil) if err != nil { @@ -428,12 +423,7 @@ func (i *Client) Aggregate(q *AggregateQuery) (aggregateReply [][]string, total if err != nil { return aggregateReply, total, err } - total = len(partialResults) - 1 - // there is a case when only 1 data from aggregate, it returns nothing - // then set total > 0 so the data will be return - if total > 0 { - aggregateReply = ProcessAggResponse(partialResults[1:]) - } + total, aggregateReply,err = processAggReply(partialResults) } return diff --git a/redisearch/document_test.go b/redisearch/document_test.go index 512cf3d..94d2d37 100644 --- a/redisearch/document_test.go +++ b/redisearch/document_test.go @@ -2,6 +2,7 @@ package redisearch_test import ( "github.com/RediSearch/redisearch-go/redisearch" + "reflect" "testing" ) @@ -29,3 +30,76 @@ func TestEscapeTextFileString(t *testing.T) { }) } } + +func TestDocument_EstimateSize(t *testing.T) { + type fields struct { + Id string + Score float32 + Payload []byte + Properties map[string]interface{} + } + tests := []struct { + name string + fields fields + wantSz int + }{ + { + "only-id", fields{"doc1", 1.0, []byte{}, map[string]interface{}{},}, len("doc1"), + }, + { + "id-payload", fields{"doc1", 1.0, []byte("payload"), map[string]interface{}{},}, len("doc1") + len([]byte("payload")), + }, + { + "id-payload-fields", fields{"doc1", 1.0, []byte("payload"), map[string]interface{}{"text1": []byte("text1")},}, len("doc1") + len([]byte("payload")) + 2*len([]byte("text1")), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := &redisearch.Document{ + Id: tt.fields.Id, + Score: tt.fields.Score, + Payload: tt.fields.Payload, + Properties: tt.fields.Properties, + } + if gotSz := d.EstimateSize(); !reflect.DeepEqual(gotSz, tt.wantSz) { + t.Errorf("EstimateSize() = %v, want %v", gotSz, tt.wantSz) + } + }) + } +} + +func TestDocument_SetPayload(t *testing.T) { + type fields struct { + Id string + Score float32 + Payload []byte + Properties map[string]interface{} + } + type args struct { + payload []byte + } + tests := []struct { + name string + fields fields + args args + wantPayload []byte + }{ + {"empty-payload", fields{"doc1", 1.0, []byte{}, map[string]interface{}{},}, args{[]byte{}}, []byte{}}, + {"simple-set", fields{"doc1", 1.0, []byte{}, map[string]interface{}{},}, args{[]byte("payload")},[]byte("payload")}, + {"set-with-previous-payload", fields{"doc1", 1.0, []byte("previous_payload"), map[string]interface{}{},}, args{[]byte("payload")},[]byte("payload")}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := &redisearch.Document{ + Id: tt.fields.Id, + Score: tt.fields.Score, + Payload: tt.fields.Payload, + Properties: tt.fields.Properties, + } + d.SetPayload(tt.args.payload) + if !reflect.DeepEqual(d.Payload, tt.wantPayload) { + t.Errorf("SetPayload() = %v, want %v", d.Payload, tt.wantPayload) + } + }) + } +} diff --git a/redisearch/query_test.go b/redisearch/query_test.go index b6d30b5..2d8cc32 100644 --- a/redisearch/query_test.go +++ b/redisearch/query_test.go @@ -18,6 +18,8 @@ func TestPaging_serialize(t *testing.T) { }{ {"default", fields{0, 10}, redis.Args{}}, {"0-1000", fields{0, 1000}, redis.Args{"LIMIT", 0, 1000}}, + {"0-2", fields{0, 2}, redis.Args{"LIMIT", 0, 2}}, + {"100-10", fields{100, 10}, redis.Args{"LIMIT", 100, 10}}, {"100-200", fields{100, 200}, redis.Args{"LIMIT", 100, 200}}, } for _, tt := range tests { diff --git a/redisearch/spellcheck.go b/redisearch/spellcheck.go index 8528616..560f3dd 100644 --- a/redisearch/spellcheck.go +++ b/redisearch/spellcheck.go @@ -95,7 +95,10 @@ func NewMisspelledTerm(term string) MisspelledTerm { func (l MisspelledTerm) Len() int { return len(l.MisspelledSuggestionList) } func (l MisspelledTerm) Swap(i, j int) { - l.MisspelledSuggestionList[i], l.MisspelledSuggestionList[j] = l.MisspelledSuggestionList[j], l.MisspelledSuggestionList[i] + maxLen := len(l.MisspelledSuggestionList) + if i < maxLen && j < maxLen { + l.MisspelledSuggestionList[i], l.MisspelledSuggestionList[j] = l.MisspelledSuggestionList[j], l.MisspelledSuggestionList[i] + } } func (l MisspelledTerm) Less(i, j int) bool { return l.MisspelledSuggestionList[i].Score > l.MisspelledSuggestionList[j].Score @@ -108,11 +111,20 @@ func (l MisspelledTerm) Sort() { // convert the result from a redis spelling correction on a query to a proper MisspelledTerm object func loadMisspelledTerm(arr []interface{}, termIdx, suggIdx int) (missT MisspelledTerm, err error) { + if len(arr) == 0 { + return MisspelledTerm{}, nil + } + if termIdx >= len(arr) { + return MisspelledTerm{}, fmt.Errorf("term index: (%d) is larger than reply size: %d", termIdx, len(arr)) + } term, err := redis.String(arr[termIdx], err) if err != nil { return MisspelledTerm{}, fmt.Errorf("Could not parse term: %s", err) } missT = NewMisspelledTerm(term) + if suggIdx >= len(arr) { + return MisspelledTerm{}, fmt.Errorf("suggestion index: (%d) is larger than reply size: %d", suggIdx, len(arr)) + } lst, err := redis.Values(arr[suggIdx], err) if err != nil { return MisspelledTerm{}, fmt.Errorf("Could not get the array of suggestions for spelling corrections on term %s. Error: %s", term, err) @@ -122,6 +134,9 @@ func loadMisspelledTerm(arr []interface{}, termIdx, suggIdx int) (missT Misspell if err != nil { return MisspelledTerm{}, fmt.Errorf("Could not get the inner array of suggestions for spelling corrections on term %s. Error: %s", term, err) } + if len(innerLst) != 2 { + return MisspelledTerm{}, fmt.Errorf("expects 2 elements per inner-array") + } score, err := redis.Float64(innerLst[0], err) if err != nil { return MisspelledTerm{}, fmt.Errorf("Could not parse score: %s", err) diff --git a/redisearch/spellcheck_test.go b/redisearch/spellcheck_test.go new file mode 100644 index 0000000..1ad402c --- /dev/null +++ b/redisearch/spellcheck_test.go @@ -0,0 +1,392 @@ +package redisearch + +import ( + "github.com/gomodule/redigo/redis" + "reflect" + "testing" +) + +func TestMisspelledTerm_Len(t *testing.T) { + type fields struct { + Term string + MisspelledSuggestionList []MisspelledSuggestion + } + tests := []struct { + name string + fields fields + want int + }{ + {"empty", fields{"empty", []MisspelledSuggestion{},}, 0,}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + l := MisspelledTerm{ + Term: tt.fields.Term, + MisspelledSuggestionList: tt.fields.MisspelledSuggestionList, + } + if got := l.Len(); got != tt.want { + t.Errorf("Len() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMisspelledTerm_Less(t *testing.T) { + type fields struct { + Term string + MisspelledSuggestionList []MisspelledSuggestion + } + type args struct { + i int + j int + } + tests := []struct { + name string + fields fields + args args + want bool + }{ + {"double-value-list-true", fields{"double", []MisspelledSuggestion{NewMisspelledSuggestion("double", 0), NewMisspelledSuggestion("doublee", 0.1)},}, args{1, 0}, true,}, + {"double-value-list-false", fields{"double", []MisspelledSuggestion{NewMisspelledSuggestion("double", 0), NewMisspelledSuggestion("doublee", 0.1)},}, args{0, 1}, false,}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + l := MisspelledTerm{ + Term: tt.fields.Term, + MisspelledSuggestionList: tt.fields.MisspelledSuggestionList, + } + if got := l.Less(tt.args.i, tt.args.j); got != tt.want { + t.Errorf("Less() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMisspelledTerm_Sort(t *testing.T) { + type fields struct { + Term string + MisspelledSuggestionList []MisspelledSuggestion + } + tests := []struct { + name string + fields fields + want []MisspelledSuggestion + }{ + {"empty", fields{"empty", []MisspelledSuggestion{},}, []MisspelledSuggestion{},}, + {"double-value-list", fields{"double", []MisspelledSuggestion{NewMisspelledSuggestion("double", 0), NewMisspelledSuggestion("doublee", 0.1)},}, []MisspelledSuggestion{NewMisspelledSuggestion("doublee", 0.1), NewMisspelledSuggestion("double", 0)},}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + l := MisspelledTerm{ + Term: tt.fields.Term, + MisspelledSuggestionList: tt.fields.MisspelledSuggestionList, + } + l.Sort() + if !reflect.DeepEqual(l.MisspelledSuggestionList, tt.want) { + t.Errorf("Sort() = %v, want %v", l.MisspelledSuggestionList, tt.want) + } + }) + } +} + +func TestMisspelledTerm_Swap(t *testing.T) { + type fields struct { + Term string + MisspelledSuggestionList []MisspelledSuggestion + } + type args struct { + i int + j int + } + tests := []struct { + name string + fields fields + args args + want []MisspelledSuggestion + }{ + {"empty-list", fields{"empty", []MisspelledSuggestion{},}, args{0, 1}, []MisspelledSuggestion{},}, + {"single-value-list", fields{"single", []MisspelledSuggestion{NewMisspelledSuggestion("first", 1)},}, args{0, 1}, []MisspelledSuggestion{NewMisspelledSuggestion("first", 1)},}, + {"double-value-list", fields{"doubl", []MisspelledSuggestion{NewMisspelledSuggestion("double", 1), NewMisspelledSuggestion("doublee", 0.1)},}, args{0, 1}, []MisspelledSuggestion{NewMisspelledSuggestion("doublee", 0.1), NewMisspelledSuggestion("double", 1)},}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + l := MisspelledTerm{ + Term: tt.fields.Term, + MisspelledSuggestionList: tt.fields.MisspelledSuggestionList, + } + l.Swap(tt.args.i, tt.args.j) + if !reflect.DeepEqual(l.MisspelledSuggestionList, tt.want) { + t.Errorf("Sort() = %v, want %v", l.MisspelledSuggestionList, tt.want) + } + }) + } +} + +func TestNewMisspelledSuggestion(t *testing.T) { + type args struct { + term string + score float32 + } + tests := []struct { + name string + args args + want MisspelledSuggestion + }{ + {"simple", args{"term", 1}, MisspelledSuggestion{"term", 1},}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := NewMisspelledSuggestion(tt.args.term, tt.args.score); !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewMisspelledSuggestion() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNewMisspelledTerm(t *testing.T) { + type args struct { + term string + } + tests := []struct { + name string + args args + want MisspelledTerm + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := NewMisspelledTerm(tt.args.term); !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewMisspelledTerm() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNewSpellCheckOptions(t *testing.T) { + type args struct { + distance int + } + tests := []struct { + name string + args args + wantD int + wantExclusion []string + wantInclusion []string + }{ + {"1", args{1}, 1, []string{}, []string{}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NewSpellCheckOptions(tt.args.distance) + if !reflect.DeepEqual(got.Distance, tt.wantD) { + t.Errorf("NewSpellCheckOptions() = %v, want %v", got.Distance, tt.wantD) + } + if !reflect.DeepEqual(got.ExclusionDicts, tt.wantExclusion) { + t.Errorf("NewSpellCheckOptions() = %v, want %v", got.ExclusionDicts, tt.wantExclusion) + } + if !reflect.DeepEqual(got.InclusionDicts, tt.wantInclusion) { + t.Errorf("NewSpellCheckOptions() = %v, want %v", got.InclusionDicts, tt.wantInclusion) + } + }) + } +} + +func TestNewSpellCheckOptionsDefaults(t *testing.T) { + tests := []struct { + name string + wantD int + wantExclusion []string + wantInclusion []string + }{ + {"1", 1, []string{}, []string{}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NewSpellCheckOptionsDefaults() + if !reflect.DeepEqual(got.Distance, tt.wantD) { + t.Errorf("TestNewSpellCheckOptionsDefaults() = %v, want %v", got.Distance, tt.wantD) + } + if !reflect.DeepEqual(got.ExclusionDicts, tt.wantExclusion) { + t.Errorf("TestNewSpellCheckOptionsDefaults() = %v, want %v", got.ExclusionDicts, tt.wantExclusion) + } + if !reflect.DeepEqual(got.InclusionDicts, tt.wantInclusion) { + t.Errorf("TestNewSpellCheckOptionsDefaults() = %v, want %v", got.InclusionDicts, tt.wantInclusion) + } + }) + } +} + +func TestSpellCheckOptions_AddExclusionDict(t *testing.T) { + type fields struct { + Distance int + ExclusionDicts []string + InclusionDicts []string + } + type args struct { + dictname string + } + tests := []struct { + name string + fields fields + args args + want []string + }{ + {"empty", fields{1, []string{}, []string{}}, args{"dict1"}, []string{"dict1"},}, + {"one-prior", fields{1, []string{"dict1"}, []string{}}, args{"dict2"}, []string{"dict1", "dict2"},}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &SpellCheckOptions{ + Distance: tt.fields.Distance, + ExclusionDicts: tt.fields.ExclusionDicts, + InclusionDicts: tt.fields.InclusionDicts, + } + if got := s.AddExclusionDict(tt.args.dictname).ExclusionDicts; !reflect.DeepEqual(got, tt.want) { + t.Errorf("AddExclusionDict() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSpellCheckOptions_AddInclusionDict(t *testing.T) { + type fields struct { + Distance int + ExclusionDicts []string + InclusionDicts []string + } + type args struct { + dictname string + } + tests := []struct { + name string + fields fields + args args + want []string + }{ + {"empty", fields{1, []string{}, []string{}}, args{"dict1"}, []string{"dict1"},}, + {"one-prior", fields{1, []string{}, []string{"dict1"}}, args{"dict2"}, []string{"dict1", "dict2"},}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &SpellCheckOptions{ + Distance: tt.fields.Distance, + ExclusionDicts: tt.fields.ExclusionDicts, + InclusionDicts: tt.fields.InclusionDicts, + } + if got := s.AddInclusionDict(tt.args.dictname).InclusionDicts; !reflect.DeepEqual(got, tt.want) { + t.Errorf("AddInclusionDict() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSpellCheckOptions_SetDistance(t *testing.T) { + type fields struct { + Distance int + ExclusionDicts []string + InclusionDicts []string + } + type args struct { + distance int + } + tests := []struct { + name string + fields fields + args args + want int + wantErr bool + }{ + {"error-lower", fields{1, []string{}, []string{}}, args{0}, 1, true}, + {"error-upper", fields{1, []string{}, []string{}}, args{5}, 1, true}, + {"distance-4", fields{1, []string{}, []string{}}, args{4}, 4, false}, + {"distance-1", fields{4, []string{}, []string{}}, args{1}, 1, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &SpellCheckOptions{ + Distance: tt.fields.Distance, + ExclusionDicts: tt.fields.ExclusionDicts, + InclusionDicts: tt.fields.InclusionDicts, + } + got, err := s.SetDistance(tt.args.distance) + if (err != nil) != tt.wantErr { + t.Errorf("SetDistance() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got.Distance, tt.want) { + t.Errorf("SetDistance() got = %v, want %v", got.Distance, tt.want) + } + }) + } +} + +func TestSpellCheckOptions_serialize(t *testing.T) { + type fields struct { + Distance int + ExclusionDicts []string + InclusionDicts []string + } + tests := []struct { + name string + fields fields + want redis.Args + }{ + // TODO: Add test cases. + {"empty", fields{1, []string{}, []string{}}, redis.Args{}}, + {"exclude", fields{1, []string{"dict1"}, []string{}}, redis.Args{"TERMS", "EXCLUDE", "dict1"}}, + {"include", fields{1, []string{}, []string{"dict1"}}, redis.Args{"TERMS", "INCLUDE", "dict1"}}, + {"all", fields{2, []string{"dict1"}, []string{"dict2"}}, redis.Args{"DISTANCE", 2, "TERMS", "EXCLUDE", "dict1", "TERMS", "INCLUDE", "dict2"}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := SpellCheckOptions{ + Distance: tt.fields.Distance, + ExclusionDicts: tt.fields.ExclusionDicts, + InclusionDicts: tt.fields.InclusionDicts, + } + if got := s.serialize(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("serialize() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_loadMisspelledTerm(t *testing.T) { + type args struct { + arr []interface{} + termIdx int + suggIdx int + } + // Each misspelled term, in turn, is a 3-element array consisting of + // - the constant string "TERM" ( 3-element position 0 -- we dont use it ) + // - the term itself ( 3-element position 1 ) + // - an array of suggestions for spelling corrections ( 3-element position 2 ) + //termIdx := 1 + //suggIdx := 2 + // + tests := []struct { + name string + args args + wantMissT MisspelledTerm + wantErr bool + }{ + {"empty", args{[]interface{}{}, 1, 2,}, MisspelledTerm{}, false}, + {"missing term", args{[]interface{}{"TERM",}, 1, 2,}, MisspelledTerm{}, true}, + {"missing sugestion array", args{[]interface{}{"TERM", "hockye"}, 1, 2,}, MisspelledTerm{}, true}, + {"incorrect float", args{[]interface{}{"TERM", "hockye", []interface{}{[]interface{}{[]byte("INCORRECT"), []byte("hockey")}}}, 1, 2,}, MisspelledTerm{}, true}, + {"correct1", args{[]interface{}{"TERM", "hockye", []interface{}{[]interface{}{[]byte("1"), []byte("hockey")}}}, 1, 2,}, MisspelledTerm{"hockye", []MisspelledSuggestion{NewMisspelledSuggestion("hockey", 1.0)}}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotMissT, err := loadMisspelledTerm(tt.args.arr, tt.args.termIdx, tt.args.suggIdx) + if (err != nil) != tt.wantErr { + t.Errorf("loadMisspelledTerm() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(gotMissT, tt.wantMissT) { + t.Errorf("loadMisspelledTerm() gotMissT = %v, want %v", gotMissT, tt.wantMissT) + } + }) + } +}