diff --git a/firestore/pipeline_integration_test.go b/firestore/pipeline_integration_test.go index a6a2194a8f07..bca501857144 100644 --- a/firestore/pipeline_integration_test.go +++ b/firestore/pipeline_integration_test.go @@ -19,7 +19,6 @@ import ( "fmt" "math" "sort" - "strings" "testing" "time" @@ -27,14 +26,45 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" "google.golang.org/api/iterator" "google.golang.org/genproto/googleapis/type/latlng" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" ) -func TestIntegration_PipelineExecute(t *testing.T) { +func skipIfNotEnterprise(t *testing.T) { if testParams[firestoreEditionKey].(firestoreEdition) != editionEnterprise { - t.Skip("Skipping pipeline queries tests since the firestore edition of", testParams[databaseIDKey].(string), "database is not enterprise") + t.Skip("Skipping test in non-enterprise environment") } +} + +type Author struct { + Name string `firestore:"name"` + Country string `firestore:"country"` +} + +type Book struct { + Title string `firestore:"title"` + Author Author `firestore:"author"` + Genre string `firestore:"genre"` + Published int `firestore:"published"` + Rating float64 `firestore:"rating"` + Tags []string `firestore:"tags"` +} + +func testBooks() []Book { + return []Book{ + {Title: "The Hitchhiker's Guide to the Galaxy", Author: Author{Name: "Douglas Adams", Country: "UK"}, Genre: "Science Fiction", Published: 1979, Rating: 4.2, Tags: []string{"comedy", "space", "adventure"}}, + {Title: "Pride and Prejudice", Author: Author{Name: "Jane Austen", Country: "UK"}, Genre: "Romance", Published: 1813, Rating: 4.5, Tags: []string{"classic", "social commentary", "love"}}, + {Title: "One Hundred Years of Solitude", Author: Author{Name: "Gabriel García Márquez", Country: "Colombia"}, Genre: "Magical Realism", Published: 1967, Rating: 4.3, Tags: []string{"family", "history", "fantasy"}}, + {Title: "The Lord of the Rings", Author: Author{Name: "J.R.R. Tolkien", Country: "UK"}, Genre: "Fantasy", Published: 1954, Rating: 4.7, Tags: []string{"adventure", "magic", "epic"}}, + {Title: "The Handmaid's Tale", Author: Author{Name: "Margaret Atwood", Country: "Canada"}, Genre: "Dystopian", Published: 1985, Rating: 4.1, Tags: []string{"feminism", "totalitarianism", "resistance"}}, + {Title: "Crime and Punishment", Author: Author{Name: "Fyodor Dostoevsky", Country: "Russia"}, Genre: "Psychological Thriller", Published: 1866, Rating: 4.3, Tags: []string{"philosophy", "crime", "redemption"}}, + {Title: "To Kill a Mockingbird", Author: Author{Name: "Harper Lee", Country: "USA"}, Genre: "Southern Gothic", Published: 1960, Rating: 4.2, Tags: []string{"racism", "injustice", "coming-of-age"}}, + {Title: "1984", Author: Author{Name: "George Orwell", Country: "UK"}, Genre: "Dystopian", Published: 1949, Rating: 4.2, Tags: []string{"surveillance", "totalitarianism", "propaganda"}}, + {Title: "The Great Gatsby", Author: Author{Name: "F. Scott Fitzgerald", Country: "USA"}, Genre: "Modernist", Published: 1925, Rating: 4.0, Tags: []string{"wealth", "american dream", "love"}}, + {Title: "Dune", Author: Author{Name: "Frank Herbert", Country: "USA"}, Genre: "Science Fiction", Published: 1965, Rating: 4.6, Tags: []string{"politics", "desert", "ecology"}}, + } +} + +func TestIntegration_PipelineExecute(t *testing.T) { + skipIfNotEnterprise(t) ctx := context.Background() client := integrationClient(t) coll := integrationColl(t) @@ -67,82 +97,10 @@ func TestIntegration_PipelineExecute(t *testing.T) { if len(res) != 0 { t.Errorf("got %d documents, want 0", len(res)) } - - stats := iter.ExplainStats() - if stats != nil { - t.Fatal("ExplainStats should be nil when WithExplainMode is not used") - } - }) - - t.Run("ExplainModeAnalyze and recommended index", func(t *testing.T) { - doc := coll.NewDoc() - _, err := doc.Create(ctx, map[string]interface{}{"a": 1}) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { - deleteDocuments([]*DocumentRef{doc}) - }) - - iter := client.Pipeline().Collection(coll.ID).WithExecuteOptions(WithExplainMode(ExplainModeAnalyze), WithIndexMode("recommended")).Execute(ctx) - defer iter.Stop() - _, err = iter.GetAll() - if err != nil { - t.Fatalf("Failed to execute pipeline with explain mode: %v", err) - } - stats := iter.ExplainStats() - if stats == nil { - t.Fatal("ExplainStats should not be nil when WithExplainMode is used") - } - - text, err := stats.GetText() - if err != nil { - t.Fatalf("GetText() error: %v", err) - } - if text == "" { - t.Error("GetText() should not be empty") - } - - rawData, err := stats.GetRawData() - if err != nil { - t.Fatalf("GetRawData() error: %v", err) - } - if rawData == nil { - t.Error("GetRawData() should not be nil") - } }) t.Run("WithTransaction", func(t *testing.T) { h := testHelper{t} - type Author struct { - Name string `firestore:"name"` - Country string `firestore:"country"` - } - type Book struct { - Title string `firestore:"title"` - Author `firestore:"author"` - Genre string `firestore:"genre"` - Published int `firestore:"published"` - Rating float64 `firestore:"rating"` - Tags []string `firestore:"tags"` - } - books := []Book{ - { - Title: "The Hitchhiker's Guide to the Galaxy", - Author: Author{Name: "Douglas Adams", Country: "UK"}, - Genre: "Science Fiction", - Published: 1979, - Rating: 4.2, - Tags: []string{"comedy", "space", "adventure"}, - }, - { - Title: "Pride and Prejudice", - Author: Author{Name: "Jane Austen", Country: "UK"}, - Genre: "Romance", - Published: 1813, - Rating: 4.5, - Tags: []string{"classic", "social commentary", "love"}, - }, - } + books := testBooks()[:2] var docRefs []*DocumentRef for _, b := range books { docRef := coll.NewDoc() @@ -171,65 +129,34 @@ func TestIntegration_PipelineExecute(t *testing.T) { } func TestIntegration_PipelineStages(t *testing.T) { - if testParams[firestoreEditionKey].(firestoreEdition) != editionEnterprise { - t.Skip("Skipping pipeline queries tests since the firestore edition of", testParams[databaseIDKey].(string), "database is not enterprise") - } + skipIfNotEnterprise(t) ctx := context.Background() client := integrationClient(t) coll := integrationColl(t) h := testHelper{t} + type Author struct { + Name string `firestore:"name"` + Country string `firestore:"country"` + } type Book struct { - Title string `firestore:"title"` - Author struct { - Name string `firestore:"name"` - Country string `firestore:"country"` - } `firestore:"author"` + Title string `firestore:"title"` + Author Author `firestore:"author"` Genre string `firestore:"genre"` Published int `firestore:"published"` Rating float64 `firestore:"rating"` Tags []string `firestore:"tags"` } books := []Book{ - {Title: "The Hitchhiker's Guide to the Galaxy", Author: struct { - Name string `firestore:"name"` - Country string `firestore:"country"` - }{Name: "Douglas Adams", Country: "UK"}, Genre: "Science Fiction", Published: 1979, Rating: 4.2, Tags: []string{"comedy", "space", "adventure"}}, - {Title: "Pride and Prejudice", Author: struct { - Name string `firestore:"name"` - Country string `firestore:"country"` - }{Name: "Jane Austen", Country: "UK"}, Genre: "Romance", Published: 1813, Rating: 4.5, Tags: []string{"classic", "social commentary", "love"}}, - {Title: "One Hundred Years of Solitude", Author: struct { - Name string `firestore:"name"` - Country string `firestore:"country"` - }{Name: "Gabriel García Márquez", Country: "Colombia"}, Genre: "Magical Realism", Published: 1967, Rating: 4.3, Tags: []string{"family", "history", "fantasy"}}, - {Title: "The Lord of the Rings", Author: struct { - Name string `firestore:"name"` - Country string `firestore:"country"` - }{Name: "J.R.R. Tolkien", Country: "UK"}, Genre: "Fantasy", Published: 1954, Rating: 4.7, Tags: []string{"adventure", "magic", "epic"}}, - {Title: "The Handmaid's Tale", Author: struct { - Name string `firestore:"name"` - Country string `firestore:"country"` - }{Name: "Margaret Atwood", Country: "Canada"}, Genre: "Dystopian", Published: 1985, Rating: 4.1, Tags: []string{"feminism", "totalitarianism", "resistance"}}, - {Title: "Crime and Punishment", Author: struct { - Name string `firestore:"name"` - Country string `firestore:"country"` - }{Name: "Fyodor Dostoevsky", Country: "Russia"}, Genre: "Psychological Thriller", Published: 1866, Rating: 4.3, Tags: []string{"philosophy", "crime", "redemption"}}, - {Title: "To Kill a Mockingbird", Author: struct { - Name string `firestore:"name"` - Country string `firestore:"country"` - }{Name: "Harper Lee", Country: "USA"}, Genre: "Southern Gothic", Published: 1960, Rating: 4.2, Tags: []string{"racism", "injustice", "coming-of-age"}}, - {Title: "1984", Author: struct { - Name string `firestore:"name"` - Country string `firestore:"country"` - }{Name: "George Orwell", Country: "UK"}, Genre: "Dystopian", Published: 1949, Rating: 4.2, Tags: []string{"surveillance", "totalitarianism", "propaganda"}}, - {Title: "The Great Gatsby", Author: struct { - Name string `firestore:"name"` - Country string `firestore:"country"` - }{Name: "F. Scott Fitzgerald", Country: "USA"}, Genre: "Modernist", Published: 1925, Rating: 4.0, Tags: []string{"wealth", "american dream", "love"}}, - {Title: "Dune", Author: struct { - Name string `firestore:"name"` - Country string `firestore:"country"` - }{Name: "Frank Herbert", Country: "USA"}, Genre: "Science Fiction", Published: 1965, Rating: 4.6, Tags: []string{"politics", "desert", "ecology"}}, + {Title: "The Hitchhiker's Guide to the Galaxy", Author: Author{Name: "Douglas Adams", Country: "UK"}, Genre: "Science Fiction", Published: 1979, Rating: 4.2, Tags: []string{"comedy", "space", "adventure"}}, + {Title: "Pride and Prejudice", Author: Author{Name: "Jane Austen", Country: "UK"}, Genre: "Romance", Published: 1813, Rating: 4.5, Tags: []string{"classic", "social commentary", "love"}}, + {Title: "One Hundred Years of Solitude", Author: Author{Name: "Gabriel García Márquez", Country: "Colombia"}, Genre: "Magical Realism", Published: 1967, Rating: 4.3, Tags: []string{"family", "history", "fantasy"}}, + {Title: "The Lord of the Rings", Author: Author{Name: "J.R.R. Tolkien", Country: "UK"}, Genre: "Fantasy", Published: 1954, Rating: 4.7, Tags: []string{"adventure", "magic", "epic"}}, + {Title: "The Handmaid's Tale", Author: Author{Name: "Margaret Atwood", Country: "Canada"}, Genre: "Dystopian", Published: 1985, Rating: 4.1, Tags: []string{"feminism", "totalitarianism", "resistance"}}, + {Title: "Crime and Punishment", Author: Author{Name: "Fyodor Dostoevsky", Country: "Russia"}, Genre: "Psychological Thriller", Published: 1866, Rating: 4.3, Tags: []string{"philosophy", "crime", "redemption"}}, + {Title: "To Kill a Mockingbird", Author: Author{Name: "Harper Lee", Country: "USA"}, Genre: "Southern Gothic", Published: 1960, Rating: 4.2, Tags: []string{"racism", "injustice", "coming-of-age"}}, + {Title: "1984", Author: Author{Name: "George Orwell", Country: "UK"}, Genre: "Dystopian", Published: 1949, Rating: 4.2, Tags: []string{"surveillance", "totalitarianism", "propaganda"}}, + {Title: "The Great Gatsby", Author: Author{Name: "F. Scott Fitzgerald", Country: "USA"}, Genre: "Modernist", Published: 1925, Rating: 4.0, Tags: []string{"wealth", "american dream", "love"}}, + {Title: "Dune", Author: Author{Name: "Frank Herbert", Country: "USA"}, Genre: "Science Fiction", Published: 1965, Rating: 4.6, Tags: []string{"politics", "desert", "ecology"}}, } var docRefs []*DocumentRef for _, b := range books { @@ -328,15 +255,6 @@ func TestIntegration_PipelineStages(t *testing.T) { t.Errorf("got %d documents, want 2", len(results)) } }) - t.Run("CollectionWithOptions", func(t *testing.T) { - hints := CollectionHints{}.WithForceIndex("title") - iter := client.Pipeline().Collection(coll.ID, WithCollectionHints(hints)).Execute(ctx) - defer iter.Stop() - _, err := iter.Next() - if s, ok := status.FromError(err); !ok || s.Code() != codes.InvalidArgument { - t.Errorf("got err %v, want InvalidArgument", err) - } - }) t.Run("Database", func(t *testing.T) { dbDoc1 := coll.Doc("db_doc1") otherColl := client.Collection(collectionIDs.New()) @@ -762,9 +680,7 @@ func TestIntegration_PipelineStages(t *testing.T) { } func TestIntegration_PipelineFunctions(t *testing.T) { - if testParams[firestoreEditionKey].(firestoreEdition) != editionEnterprise { - t.Skip("Skipping pipeline queries tests since the firestore edition of", testParams[databaseIDKey].(string), "database is not enterprise") - } + skipIfNotEnterprise(t) t.Run("arrayFuncs", arrayFuncs) t.Run("stringFuncs", stringFuncs) t.Run("vectorFuncs", vectorFuncs) @@ -870,10 +786,7 @@ func typeFuncs(t *testing.T) { defer iter.Stop() docs, err := iter.GetAll() - if isRetryablePipelineExecuteErr(err) { - t.Errorf("GetAll: %v. Retrying....", err) - return - } else if err != nil { + if err != nil { t.Fatalf("GetAll: %v", err) return } @@ -888,6 +801,198 @@ func typeFuncs(t *testing.T) { } } +func TestIntegration_Query_Pipeline(t *testing.T) { + skipIfNotEnterprise(t) + ctx := context.Background() + coll := integrationColl(t) + h := testHelper{t} + type Book struct { + Title string `firestore:"title"` + Genre string `firestore:"genre"` + Published int `firestore:"published"` + Rating float64 `firestore:"rating"` + } + books := []Book{ + {Title: "The Hitchhiker's Guide to the Galaxy", Genre: "Science Fiction", Published: 1979, Rating: 4.2}, + {Title: "Pride and Prejudice", Genre: "Romance", Published: 1813, Rating: 4.5}, + {Title: "One Hundred Years of Solitude", Genre: "Magical Realism", Published: 1967, Rating: 4.3}, + } + var docRefs []*DocumentRef + for _, b := range books { + docRef := coll.NewDoc() + h.mustCreate(docRef, b) + docRefs = append(docRefs, docRef) + } + t.Cleanup(func() { + deleteDocuments(docRefs) + }) + + t.Run("Where", func(t *testing.T) { + q := coll.Where("published", ">", 1900) + p := q.Pipeline() + iter := p.Execute(ctx) + defer iter.Stop() + res, err := iter.GetAll() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if len(res) != 2 { + t.Errorf("got %d documents, want 2", len(res)) + } + }) + + t.Run("OrderBy", func(t *testing.T) { + q := coll.OrderBy("published", Asc) + p := q.Pipeline() + iter := p.Execute(ctx) + defer iter.Stop() + res, err := iter.GetAll() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if len(res) != 3 { + t.Errorf("got %d documents, want 3", len(res)) + } + var publishedYears []int64 + for _, r := range res { + publishedYears = append(publishedYears, r.Data()["published"].(int64)) + } + if !sort.SliceIsSorted(publishedYears, func(i, j int) bool { return publishedYears[i] < publishedYears[j] }) { + t.Errorf("results not sorted by published year: %v", publishedYears) + } + }) + + t.Run("Limit", func(t *testing.T) { + q := coll.Limit(2) + p := q.Pipeline() + iter := p.Execute(ctx) + defer iter.Stop() + res, err := iter.GetAll() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if len(res) != 2 { + t.Errorf("got %d documents, want 2", len(res)) + } + }) + + t.Run("Offset", func(t *testing.T) { + q := coll.OrderBy("published", Asc).Offset(1) + p := q.Pipeline() + iter := p.Execute(ctx) + defer iter.Stop() + res, err := iter.GetAll() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if len(res) != 2 { + t.Errorf("got %d documents, want 2", len(res)) + } + }) + + t.Run("Select", func(t *testing.T) { + q := coll.Select("title") + p := q.Pipeline() + iter := p.Execute(ctx) + defer iter.Stop() + doc, err := iter.Next() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + data := doc.Data() + if _, ok := data["title"]; !ok { + t.Error("missing 'title' field") + } + if _, ok := data["genre"]; ok { + t.Error("unexpected 'genre' field") + } + }) +} + +func TestIntegration_AggregationQuery_Pipeline(t *testing.T) { + skipIfNotEnterprise(t) + ctx := context.Background() + coll := integrationColl(t) + h := testHelper{t} + type Book struct { + Title string `firestore:"title"` + Genre string `firestore:"genre"` + Published int `firestore:"published"` + Rating float64 `firestore:"rating"` + } + books := []Book{ + {Title: "The Hitchhiker's Guide to the Galaxy", Genre: "Science Fiction", Published: 1979, Rating: 4.2}, + {Title: "Pride and Prejudice", Genre: "Romance", Published: 1813, Rating: 4.5}, + {Title: "One Hundred Years of Solitude", Genre: "Magical Realism", Published: 1967, Rating: 4.3}, + } + var docRefs []*DocumentRef + for _, b := range books { + docRef := coll.NewDoc() + h.mustCreate(docRef, b) + docRefs = append(docRefs, docRef) + } + t.Cleanup(func() { + deleteDocuments(docRefs) + }) + + t.Run("Count", func(t *testing.T) { + ag := coll.NewAggregationQuery().WithCount("count") + p := ag.Pipeline() + iter := p.Execute(ctx) + defer iter.Stop() + doc, err := iter.Next() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + + if !doc.Exists() { + t.Fatalf("Exists: got: false, want: true") + } + data := doc.Data() + if data["count"] != int64(3) { + t.Errorf("got %d count, want 3", data["count"]) + } + }) + + t.Run("Sum", func(t *testing.T) { + ag := coll.NewAggregationQuery().WithSum("published", "total_published") + p := ag.Pipeline() + iter := p.Execute(ctx) + defer iter.Stop() + doc, err := iter.Next() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + + if !doc.Exists() { + t.Fatalf("Exists: got: false, want: true") + } + data := doc.Data() + if data["total_published"] != int64(1979+1813+1967) { + t.Errorf("got %d total_published, want %d", data["total_published"], int64(1979+1813+1967)) + } + }) + + t.Run("Average", func(t *testing.T) { + ag := coll.NewAggregationQuery().WithAvg("rating", "avg_rating") + p := ag.Pipeline() + iter := p.Execute(ctx) + defer iter.Stop() + doc, err := iter.Next() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + + if !doc.Exists() { + t.Fatalf("Exists: got: false, want: true") + } + data := doc.Data() + if data["avg_rating"] != (4.2+4.5+4.3)/3 { + t.Errorf("got %f avg_rating, want %f", data["avg_rating"], (4.2+4.5+4.3)/3) + } + }) +} + func objectFuncs(t *testing.T) { t.Parallel() h := testHelper{t} @@ -934,10 +1039,7 @@ func objectFuncs(t *testing.T) { defer iter.Stop() docs, err := iter.GetAll() - if isRetryablePipelineExecuteErr(err) { - t.Errorf("GetAll: %v. Retrying....", err) - return - } else if err != nil { + if err != nil { t.Fatalf("GetAll: %v", err) return } @@ -1064,10 +1166,7 @@ func arrayFuncs(t *testing.T) { defer iter.Stop() docs, err := iter.GetAll() - if isRetryablePipelineExecuteErr(err) { - t.Errorf("GetAll: %v. Retrying....", err) - return - } else if err != nil { + if err != nil { t.Fatalf("GetAll: %v", err) return } @@ -1208,10 +1307,7 @@ func stringFuncs(t *testing.T) { defer iter.Stop() docs, err := iter.GetAll() - if isRetryablePipelineExecuteErr(err) { - t.Errorf("GetAll: %v. Retrying....", err) - return - } else if err != nil { + if err != nil { t.Fatalf("GetAll: %v", err) return } @@ -1320,10 +1416,7 @@ func vectorFuncs(t *testing.T) { defer iter.Stop() docs, err := iter.GetAll() - if isRetryablePipelineExecuteErr(err) { - t.Errorf("GetAll: %v. Retrying....", err) - return - } else if err != nil { + if err != nil { t.Fatalf("GetAll: %v", err) return } @@ -1339,19 +1432,6 @@ func vectorFuncs(t *testing.T) { } } -func isRetryablePipelineExecuteErr(err error) bool { - if err == nil { - return false - } - s, ok := status.FromError(err) - if !ok { - return false - } - return s.Code() == codes.InvalidArgument && - strings.Contains(s.Message(), "Invalid request routing header") && - strings.Contains(s.Message(), "Please fill in the request header with format") -} - func timestampFuncs(t *testing.T) { t.Parallel() client := integrationClient(t) @@ -1905,10 +1985,7 @@ func keyFuncs(t *testing.T) { defer iter.Stop() docs, err := iter.GetAll() - if isRetryablePipelineExecuteErr(err) { - t.Errorf("GetAll: %v. Retrying....", err) - return - } else if err != nil { + if err != nil { t.Fatalf("GetAll: %v", err) return } @@ -1999,10 +2076,7 @@ func generalFuncs(t *testing.T) { defer iter.Stop() docs, err := iter.GetAll() - if isRetryablePipelineExecuteErr(err) { - t.Errorf("GetAll: %v. Retrying....", err) - return - } else if err != nil { + if err != nil { t.Fatalf("GetAll: %v", err) return } @@ -2022,80 +2096,156 @@ func logicalFuncs(t *testing.T) { h := testHelper{t} client := integrationClient(t) coll := client.Collection(collectionIDs.New()) - docRef1 := coll.NewDoc() - h.mustCreate(docRef1, map[string]interface{}{ + docRef1 := coll.Doc("doc1") + doc1Data := map[string]interface{}{ "a": 1, "b": 2, "c": nil, "d": true, "e": false, - }) - defer deleteDocuments([]*DocumentRef{docRef1}) + } + h.mustCreate(docRef1, doc1Data) + + docRef2 := coll.Doc("doc2") + doc2Data := map[string]interface{}{ + "a": 1, + "b": 1, + "d": true, + "e": true, + } + h.mustCreate(docRef2, doc2Data) + defer deleteDocuments([]*DocumentRef{docRef1, docRef2}) + + doc1Want := map[string]interface{}{ + "a": int64(1), + "b": int64(2), + "c": nil, + "d": true, + "e": false, + } + doc2Want := map[string]interface{}{ + "a": int64(1), + "b": int64(1), + "d": true, + "e": true, + } tests := []struct { name string pipeline *Pipeline - want map[string]interface{} + want interface{} }{ { name: "Conditional - true", pipeline: client.Pipeline().Collection(coll.ID).Select(Conditional(Equal(ConstantOf(1), ConstantOf(1)), FieldOf("a"), FieldOf("b")).As("result")), - want: map[string]interface{}{"result": int64(1)}, + want: []map[string]interface{}{{"result": int64(1)}, {"result": int64(1)}}, }, { name: "Conditional - false", pipeline: client.Pipeline().Collection(coll.ID).Select(Conditional(Equal(ConstantOf(1), ConstantOf(0)), FieldOf("a"), FieldOf("b")).As("result")), - want: map[string]interface{}{"result": int64(2)}, + want: []map[string]interface{}{{"result": int64(2)}, {"result": int64(1)}}, }, { name: "Conditional - field true", pipeline: client.Pipeline().Collection(coll.ID).Select(Conditional(Equal(FieldOf("d"), ConstantOf(true)), FieldOf("a"), FieldOf("b")).As("result")), - want: map[string]interface{}{"result": int64(1)}, + want: []map[string]interface{}{{"result": int64(1)}, {"result": int64(1)}}, }, { name: "Conditional - field false", pipeline: client.Pipeline().Collection(coll.ID).Select(Conditional(Equal(FieldOf("e"), ConstantOf(true)), FieldOf("a"), FieldOf("b")).As("result")), - want: map[string]interface{}{"result": int64(2)}, + want: []map[string]interface{}{{"result": int64(2)}, {"result": int64(1)}}, }, { name: "LogicalMax", pipeline: client.Pipeline().Collection(coll.ID).Select(LogicalMaximum(FieldOf("a"), FieldOf("b")).As("max")), - want: map[string]interface{}{"max": int64(2)}, + want: []map[string]interface{}{{"max": int64(2)}, {"max": int64(1)}}, }, { name: "LogicalMin", pipeline: client.Pipeline().Collection(coll.ID).Select(LogicalMinimum(FieldOf("a"), FieldOf("b")).As("min")), - want: map[string]interface{}{"min": int64(1)}, + want: []map[string]interface{}{{"min": int64(1)}, {"min": int64(1)}}, }, { name: "IfError - no error", pipeline: client.Pipeline().Collection(coll.ID).Select(IfError(FieldOf("a"), ConstantOf(100)).As("result")), - want: map[string]interface{}{"result": int64(1)}, + want: []map[string]interface{}{{"result": int64(1)}, {"result": int64(1)}}, }, { name: "IfError - error", pipeline: client.Pipeline().Collection(coll.ID).Select(Divide("a", 0).IfError(ConstantOf("was error")).As("ifError")), - want: map[string]interface{}{"ifError": "was error"}, + want: []map[string]interface{}{{"ifError": "was error"}, {"ifError": "was error"}}, }, { name: "IfErrorBoolean - no error", pipeline: client.Pipeline().Collection(coll.ID).Select(IfErrorBoolean(Equal(FieldOf("d"), ConstantOf(true)), Equal(ConstantOf(1), ConstantOf(0))).As("result")), - want: map[string]interface{}{"result": true}, + want: []map[string]interface{}{{"result": true}, {"result": true}}, }, { name: "IfErrorBoolean - error", pipeline: client.Pipeline().Collection(coll.ID).Select(IfErrorBoolean(Equal(FieldOf("x"), ConstantOf(true)), Equal(ConstantOf(1), ConstantOf(0))).As("result")), - want: map[string]interface{}{"result": false}, + want: []map[string]interface{}{{"result": false}, {"result": false}}, }, { name: "IfAbsent - not absent", pipeline: client.Pipeline().Collection(coll.ID).Select(IfAbsent(FieldOf("a"), ConstantOf(100)).As("result")), - want: map[string]interface{}{"result": int64(1)}, + want: []map[string]interface{}{{"result": int64(1)}, {"result": int64(1)}}, }, { name: "IfAbsent - absent", pipeline: client.Pipeline().Collection(coll.ID).Select(IfAbsent(FieldOf("x"), ConstantOf(100)).As("result")), - want: map[string]interface{}{"result": int64(100)}, + want: []map[string]interface{}{{"result": int64(100)}, {"result": int64(100)}}, + }, + { + name: "And", + pipeline: client.Pipeline().Collection(coll.ID).Where( + And( + Equal(FieldOf("a"), 1), + Equal(FieldOf("b"), 2), + ), + ), + want: []map[string]interface{}{doc1Want}, + }, + { + name: "Or", + pipeline: client.Pipeline().Collection(coll.ID).Where( + Or( + Equal(FieldOf("b"), 2), + Equal(FieldOf("e"), true), + ), + ), + want: []map[string]interface{}{doc1Want, doc2Want}, + }, + { + name: "Not", + pipeline: client.Pipeline().Collection(coll.ID).Where( + Not(Equal(FieldOf("b"), 1)), + ), + want: []map[string]interface{}{doc1Want}, + }, + { + name: "Xor", + pipeline: client.Pipeline().Collection(coll.ID).Where( + Xor( + Equal(FieldOf("d"), true), + Equal(FieldOf("e"), true), + ), + ), + want: []map[string]interface{}{doc1Want}, + }, + { + name: "FieldExists", + pipeline: client.Pipeline().Collection(coll.ID).Where(FieldExists("c")), + want: []map[string]interface{}{doc1Want}, + }, + { + name: "IsError", + pipeline: client.Pipeline().Collection(coll.ID).Where(IsError(Divide("a", 0))), + want: []map[string]interface{}{doc1Want, doc2Want}, + }, + { + name: "IsAbsent", + pipeline: client.Pipeline().Collection(coll.ID).Where(IsAbsent("c")), + want: []map[string]interface{}{doc2Want}, }, } @@ -2106,20 +2256,154 @@ func logicalFuncs(t *testing.T) { defer iter.Stop() docs, err := iter.GetAll() - if isRetryablePipelineExecuteErr(err) { - t.Errorf("GetAll: %v. Retrying....", err) - return - } else if err != nil { + if err != nil { t.Fatalf("GetAll: %v", err) return } - if len(docs) != 1 { - t.Fatalf("expected 1 doc, got %d", len(docs)) - } - got := docs[0].Data() - if diff := testutil.Diff(got, test.want); diff != "" { - t.Errorf("got: %v, want: %v, diff +want -got: %s", got, test.want, diff) + + lastStage := test.pipeline.stages[len(test.pipeline.stages)-1] + lastStageName := lastStage.name() + + if lastStageName == stageNameSelect { // This is a select query + want, ok := test.want.([]map[string]interface{}) + if !ok { + t.Fatalf("invalid test.want type for select query: %T", test.want) + return + } + if len(docs) != len(want) { + t.Fatalf("expected %d doc(s), got %d", len(want), len(docs)) + return + } + var gots []map[string]interface{} + for _, doc := range docs { + gots = append(gots, doc.Data()) + } + if diff := testutil.Diff(gots, want, cmpopts.SortSlices(func(a, b map[string]interface{}) bool { + // A stable sort for the results. + // Try to sort by "result", "max", "min", "ifError" + if v1, ok := a["result"]; ok { + v2 := b["result"] + switch v1 := v1.(type) { + case int64: + return v1 < v2.(int64) + case bool: + return !v1 && v2.(bool) + } + } + if v1, ok := a["max"]; ok { + return v1.(int64) < b["max"].(int64) + } + if v1, ok := a["min"]; ok { + return v1.(int64) < b["min"].(int64) + } + if v1, ok := a["ifError"]; ok { + return v1.(string) < b["ifError"].(string) + } + return false + })); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", gots, want, diff) + } + } else if lastStageName == stageNameWhere { // This is a where query (filter condition) + want, ok := test.want.([]map[string]interface{}) + if !ok { + t.Fatalf("invalid test.want type for where query: %T", test.want) + return + } + if len(docs) != len(want) { + t.Fatalf("expected %d doc(s), got %d", len(want), len(docs)) + return + } + var gots []map[string]interface{} + for _, doc := range docs { + got := doc.Data() + gots = append(gots, got) + } + // Sort slices before comparing for consistent test results + sort.Slice(gots, func(i, j int) bool { + if gots[i]["a"].(int64) == gots[j]["a"].(int64) { + return gots[i]["b"].(int64) < gots[j]["b"].(int64) + } + return gots[i]["a"].(int64) < gots[j]["a"].(int64) + }) + sort.Slice(want, func(i, j int) bool { + if want[i]["a"].(int64) == want[j]["a"].(int64) { + return want[i]["b"].(int64) < want[j]["b"].(int64) + } + return want[i]["a"].(int64) < want[j]["a"].(int64) + }) + if diff := testutil.Diff(gots, want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", gots, want, diff) + } + } else { + t.Fatalf("unknown pipeline stage: %s", lastStageName) + return } }) } } + +func TestIntegration_CreateFromQuery(t *testing.T) { + skipIfNotEnterprise(t) + ctx := context.Background() + client := integrationClient(t) + coll := integrationColl(t) + h := testHelper{t} + + books := testBooks()[:3] + var docRefs []*DocumentRef + for _, b := range books { + docRef := coll.NewDoc() + h.mustCreate(docRef, b) + docRefs = append(docRefs, docRef) + } + t.Cleanup(func() { + deleteDocuments(docRefs) + }) + + q := coll.Where("rating", ">", 4.2) + p := client.Pipeline().CreateFromQuery(q) + iter := p.Execute(ctx) + defer iter.Stop() + results, err := iter.GetAll() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if len(results) != 2 { + t.Errorf("got %d documents, want 2", len(results)) + } +} + +func TestIntegration_CreateFromAggregationQuery(t *testing.T) { + skipIfNotEnterprise(t) + ctx := context.Background() + client := integrationClient(t) + coll := integrationColl(t) + h := testHelper{t} + + books := testBooks()[:3] + var docRefs []*DocumentRef + for _, b := range books { + docRef := coll.NewDoc() + h.mustCreate(docRef, b) + docRefs = append(docRefs, docRef) + } + t.Cleanup(func() { + deleteDocuments(docRefs) + }) + + ag := coll.NewAggregationQuery().WithCount("count") + p := client.Pipeline().CreateFromAggregationQuery(ag) + iter := p.Execute(ctx) + defer iter.Stop() + doc, err := iter.Next() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if !doc.Exists() { + t.Fatalf("Exists: got: false, want: true") + } + data := doc.Data() + if data["count"] != int64(3) { + t.Errorf("got count %d, want 3", data["count"]) + } +} diff --git a/firestore/pipeline_source.go b/firestore/pipeline_source.go index 153eb60804a5..ce0ac5519b57 100644 --- a/firestore/pipeline_source.go +++ b/firestore/pipeline_source.go @@ -162,3 +162,15 @@ func (ps *PipelineSource) Database() *Pipeline { func (ps *PipelineSource) Documents(refs ...*DocumentRef) *Pipeline { return newPipeline(ps.client, newInputStageDocuments(refs...)) } + +// CreateFromQuery creates a new [Pipeline] from the given [Query]. Under the hood, this will +// translate the query semantics (order by document ID, etc.) to an equivalent pipeline. +func (ps *PipelineSource) CreateFromQuery(query Query) *Pipeline { + return query.Pipeline() +} + +// CreateFromAggregationQuery creates a new [Pipeline] from the given [AggregationQuery]. Under the hood, this will +// translate the query semantics (order by document ID, etc.) to an equivalent pipeline. +func (ps *PipelineSource) CreateFromAggregationQuery(query *AggregationQuery) *Pipeline { + return query.Pipeline() +} diff --git a/firestore/query.go b/firestore/query.go index 493e56cba9c8..729b47eb7c84 100644 --- a/firestore/query.go +++ b/firestore/query.go @@ -1768,3 +1768,378 @@ type AggregationResponse struct { // Query explain metrics. This is only present when ExplainOptions is provided. ExplainMetrics *ExplainMetrics } + +func (q *Query) toPipeline() *Pipeline { + var p *Pipeline + if q.allDescendants { + p = q.c.Pipeline().CollectionGroup(q.collectionID) + } else { + p = q.c.Pipeline().Collection(q.collectionID) + } + + if q.err != nil { + p.err = q.err + return p + } + + var allFilters []BooleanExpression + + // Original filters + for _, f := range q.filters { + var filterExpr BooleanExpression + var err error + if fieldFilter := f.GetFieldFilter(); fieldFilter != nil { + filterExpr, err = newQueryFilter(q, fieldFilter) + if err != nil { + p.err = err + return p + } + } else if unaryFilter := f.GetUnaryFilter(); unaryFilter != nil { + filterExpr, err = newQueryUnaryFilter(unaryFilter) + if err != nil { + p.err = err + return p + } + } + allFilters = append(allFilters, filterExpr) + } + + // Start at + if q.startCursorSpecified() { + var startFilter BooleanExpression + var err error + if q.startDoc != nil { + startFilter, err = newCursorFilter(q.orders, q.startDoc, q.startBefore, true) + } else { + startFilter, err = newCursorFilterWithValues(q.orders, q.startVals, q.startBefore, true) + } + if err != nil { + p.err = err + return p + } + allFilters = append(allFilters, startFilter) + } + + // End at + if q.endCursorSpecified() { + var endFilter BooleanExpression + var err error + if q.endDoc != nil { + endFilter, err = newCursorFilter(q.orders, q.endDoc, q.endBefore, false) + } else { + endFilter, err = newCursorFilterWithValues(q.orders, q.endVals, q.endBefore, false) + } + if err != nil { + p.err = err + return p + } + allFilters = append(allFilters, endFilter) + } + + // Order by + if len(q.orders) > 0 { + var orders []Ordering + for _, o := range q.orders { + var fp FieldPath + if o.fieldReference != nil { + var err error + fp, err = fieldPathFromFieldRef(o.fieldReference) + if err != nil { + p.err = err + return p + } + } else { + fp = o.fieldPath + } + field := FieldOf(fp) + var direction OrderingDirection + if o.dir == Asc { + direction = OrderingAsc + } else { + direction = OrderingDesc + } + orders = append(orders, Ordering{Expr: field, Direction: direction}) + } + p = p.Sort(orders...) + } + // Combine all filters + if len(allFilters) == 1 { + p = p.Where(allFilters[0]) + } else if len(allFilters) > 1 { + p = p.Where(And(allFilters[0], allFilters[1:]...)) + } + + // Offset + if q.offset > 0 { + p = p.Offset(int(q.offset)) + } + + // Limit + if q.limit != nil { + p = p.Limit(int(q.limit.Value)) + } + + // Select + if len(q.selection) > 0 { + var fields []interface{} + for _, s := range q.selection { + fp, err := fieldPathFromFieldRef(s) + if err != nil { + p.err = err + return p + } + fields = append(fields, fp) + } + p = p.Select(fields...) + } + + // FindNearest + if q.findNearest != nil { + var measure PipelineDistanceMeasure + switch q.findNearest.DistanceMeasure { + case pb.StructuredQuery_FindNearest_EUCLIDEAN: + measure = PipelineDistanceMeasureEuclidean + case pb.StructuredQuery_FindNearest_COSINE: + measure = PipelineDistanceMeasureCosine + case pb.StructuredQuery_FindNearest_DOT_PRODUCT: + measure = PipelineDistanceMeasureDotProduct + } + + vectorField, err := fieldPathFromFieldRef(q.findNearest.VectorField) + if err != nil { + p.err = err + return p + } + + queryVector, err := createFromProtoValue(q.findNearest.QueryVector, q.c) + if err != nil { + p.err = err + return p + } + var limit *int + if q.findNearest.Limit != nil { + val := int(q.findNearest.Limit.Value) + limit = &val + } + + var distanceField *string + if q.findNearest.DistanceResultField != "" { + distanceField = &q.findNearest.DistanceResultField + } + + p = p.FindNearest(vectorField, queryVector, measure, &PipelineFindNearestOptions{ + Limit: limit, + DistanceField: distanceField, + }) + } + + return p +} + +func fieldPathFromFieldRef(ref *pb.StructuredQuery_FieldReference) (FieldPath, error) { + return parseDotSeparatedString(ref.FieldPath) +} + +func newQueryFilter(q *Query, f *pb.StructuredQuery_FieldFilter) (BooleanExpression, error) { + fp, err := fieldPathFromFieldRef(f.GetField()) + if err != nil { + return nil, err + } + v, err := createFromProtoValue(f.GetValue(), q.c) + if err != nil { + return nil, err + } + + switch f.Op { + case pb.StructuredQuery_FieldFilter_EQUAL: + return Equal(fp, v), nil + case pb.StructuredQuery_FieldFilter_NOT_EQUAL: + return NotEqual(fp, v), nil + case pb.StructuredQuery_FieldFilter_LESS_THAN: + return LessThan(fp, v), nil + case pb.StructuredQuery_FieldFilter_LESS_THAN_OR_EQUAL: + return LessThanOrEqual(fp, v), nil + case pb.StructuredQuery_FieldFilter_GREATER_THAN: + return GreaterThan(fp, v), nil + case pb.StructuredQuery_FieldFilter_GREATER_THAN_OR_EQUAL: + return GreaterThanOrEqual(fp, v), nil + case pb.StructuredQuery_FieldFilter_IN: + return EqualAny(fp, v), nil + case pb.StructuredQuery_FieldFilter_NOT_IN: + return NotEqualAny(fp, v), nil + case pb.StructuredQuery_FieldFilter_ARRAY_CONTAINS: + return ArrayContains(fp, v), nil + case pb.StructuredQuery_FieldFilter_ARRAY_CONTAINS_ANY: + return ArrayContainsAny(fp, v), nil + default: + return nil, fmt.Errorf("firestore: unsupported query filter operator: %v", f.Op) + } +} + +func newQueryUnaryFilter(f *pb.StructuredQuery_UnaryFilter) (BooleanExpression, error) { + fp, err := fieldPathFromFieldRef(f.GetField()) + if err != nil { + return nil, err + } + switch f.Op { + case pb.StructuredQuery_UnaryFilter_IS_NULL: + return Equal(fp, nil), nil + case pb.StructuredQuery_UnaryFilter_IS_NOT_NULL: + return NotEqual(fp, nil), nil + case pb.StructuredQuery_UnaryFilter_IS_NAN: + return Equal(fp, math.NaN()), nil + case pb.StructuredQuery_UnaryFilter_IS_NOT_NAN: + return NotEqual(fp, math.NaN()), nil + default: + return nil, fmt.Errorf("firestore: unsupported unary filter operator: %v", f.Op) + } +} + +// newCursorFilter creates a pipeline filter expression from a document snapshot cursor. +func newCursorFilter(orders []order, doc *DocumentSnapshot, before, isStart bool) (BooleanExpression, error) { + values := make([]interface{}, len(orders)) + for i, o := range orders { + var err error + if o.isDocumentID() { + values[i] = doc.Ref.ID + } else { + values[i], err = doc.DataAt(o.fieldPath.toServiceFieldPath()) + if err != nil { + return nil, err + } + } + } + return newCursorFilterWithValues(orders, values, before, isStart) +} + +// newCursorFilterWithValues creates a pipeline filter expression from a list of values. +func newCursorFilterWithValues(orders []order, values []interface{}, before, isStart bool) (BooleanExpression, error) { + if len(orders) != len(values) { + return nil, errors.New("firestore: number of cursor values does not match number of OrderBy fields") + } + + var orTerms []BooleanExpression + for i := 1; i <= len(orders); i++ { + prefixOrders := orders[:i] + prefixValues := values[:i] + var andTerms []BooleanExpression + for j, o := range prefixOrders { + fp := o.fieldPath + val := prefixValues[j] + + var op string + if j < len(prefixOrders)-1 { + op = "==" + } else { + if isStart { + if before { // StartAt + if o.dir == Asc { + op = ">=" + } else { + op = "<=" + } + } else { // StartAfter + if o.dir == Asc { + op = ">" + } else { + op = "<" + } + } + } else { // End + if before { // EndBefore + if o.dir == Asc { + op = "<" + } else { + op = ">" + } + } else { // EndAt + if o.dir == Asc { + op = "<=" + } else { + op = ">=" + } + } + } + } + + switch op { + case "==": + andTerms = append(andTerms, Equal(fp, val)) + case ">": + andTerms = append(andTerms, GreaterThan(fp, val)) + case ">=": + andTerms = append(andTerms, GreaterThanOrEqual(fp, val)) + case "<": + andTerms = append(andTerms, LessThan(fp, val)) + case "<=": + andTerms = append(andTerms, LessThanOrEqual(fp, val)) + } + } + if len(andTerms) == 1 { + orTerms = append(orTerms, andTerms[0]) + } else if len(andTerms) > 1 { + orTerms = append(orTerms, And(andTerms[0], andTerms[1:]...)) + } + } + + if len(orTerms) == 1 { + return orTerms[0], nil + } + return Or(orTerms[0], orTerms[1:]...), nil +} + +// Pipeline creates a new [Pipeline] from the query. +// All of the operations of the query will be converted to pipeline stages. +// For example, `query.Where("f", "==", 1).Limit(10).OrderBy("f", Asc).Pipeline()` is equivalent to +// `client.Pipeline().Collection("C").Where(Equal("f", 1)).Limit(10).Sort(Ascending("f"))`. +func (q Query) Pipeline() *Pipeline { + return q.toPipeline() +} + +// Pipeline creates a new [Pipeline] from the aggregation query. +// All of the operations of the underlying query will be converted to pipeline stages, +// and an aggregate stage will be added for the aggregations. +func (aq *AggregationQuery) Pipeline() *Pipeline { + p := aq.query.toPipeline() + if p.err != nil { + return p + } + + if len(aq.aggregateQueries) == 0 { + return p + } + + var aggregations []*AliasedAggregate + for _, aggQuery := range aq.aggregateQueries { + alias := aggQuery.GetAlias() + + var agg *AliasedAggregate + if _, ok := aggQuery.Operator.(*pb.StructuredAggregationQuery_Aggregation_Count_); ok { + // AggregationQuery's Count is a count of all documents. We can achieve + // this in a pipeline by counting the document ID, which is always present. + agg = Count(DocumentID).As(alias) + } else if sum := aggQuery.GetSum(); sum != nil { + fp, err := fieldPathFromFieldRef(sum.GetField()) + if err != nil { + p.err = err + return p + } + agg = Sum(fp).As(alias) + } else if avg := aggQuery.GetAvg(); avg != nil { + fp, err := fieldPathFromFieldRef(avg.GetField()) + if err != nil { + p.err = err + return p + } + agg = Average(fp).As(alias) + } else { + // This case should not be reachable with the current AggregationQuery API. + p.err = fmt.Errorf("firestore: unsupported aggregation operator in Pipeline conversion") + return p + } + aggregations = append(aggregations, agg) + } + + p = p.Aggregate(aggregations...) + return p +} diff --git a/firestore/query_test.go b/firestore/query_test.go index 54c0337d8734..f8d147335f62 100644 --- a/firestore/query_test.go +++ b/firestore/query_test.go @@ -1751,6 +1751,141 @@ func TestQueryRunOptionsAndGetAllWithOptions(t *testing.T) { t.Fatal(err) } } + +func TestQuery_Pipeline(t *testing.T) { + t.Parallel() + client, _, cleanup := newMock(t) + defer cleanup() + + coll := client.Collection("C") + + testCases := []struct { + name string + query Query + expPipe *Pipeline + }{ + { + name: "simple query", + query: coll.Where("f", "==", 1).Limit(10), + expPipe: client.Pipeline().Collection("C").Where(Equal("f", 1)).Limit(10), + }, + { + name: "query with all clauses", + query: coll.Where("f", ">", 1).OrderBy("f", Asc).Select("f").Offset(1), + expPipe: client.Pipeline().Collection("C").Sort(Ordering{Expr: FieldOf("f"), Direction: OrderingAsc}).Where(GreaterThan("f", 1)).Offset(1).Select("f"), + }, + { + name: "query with collection group", + query: client.CollectionGroup("C").Where("f", "==", 1).Limit(10), + expPipe: client.Pipeline().CollectionGroup("C").Where(Equal("f", 1)).Limit(10), + }, + { + name: "query with cursor", + query: coll.OrderBy("f", Asc).StartAt(1), + expPipe: client.Pipeline().Collection("C").Sort(Ordering{Expr: FieldOf("f"), Direction: OrderingAsc}).Where(GreaterThanOrEqual(FieldPath{"f"}, 1)), + }, + { + name: "query with findNearest", + query: coll.FindNearest("f", []float32{1, 2, 3}, 5, DistanceMeasureEuclidean, &FindNearestOptions{DistanceResultField: "dist"}).q, + expPipe: client.Pipeline().Collection("C").FindNearest("f", []float32{1, 2, 3}, PipelineDistanceMeasureEuclidean, &PipelineFindNearestOptions{Limit: intptr(5), DistanceField: stringptr("dist")}), + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + p := tc.query.Pipeline() + if p.err != nil { + t.Fatalf("Pipeline() got error: %v", p.err) + } + gotProto, err := p.toProto() + if err != nil { + t.Fatalf("p.toProto() got error: %v", err) + } + expProto, err := tc.expPipe.toProto() + if err != nil { + t.Fatalf("expPipe.toProto() got error: %v", err) + } + + if diff := cmp.Diff(expProto, gotProto, protocmp.Transform()); diff != "" { + t.Errorf("mismatch (-want, +got)\n: %s", diff) + } + }) + } +} + +func intptr(i int) *int { + return &i +} + +func stringptr(s string) *string { + return &s +} + +func TestAggregationQuery_Pipeline(t *testing.T) { + t.Parallel() + client, _, cleanup := newMock(t) + defer cleanup() + + coll := client.Collection("C") + queryWithWhere := coll.Where("f", "==", 1) + + testCases := []struct { + name string + query *AggregationQuery + expPipe *Pipeline + }{ + { + name: "simple aggregation query", + query: coll.NewAggregationQuery().WithCount("total"), + expPipe: client.Pipeline().Collection("C").Aggregate(Count(DocumentID).As("total")), + }, + { + name: "aggregation query with where", + query: queryWithWhere.NewAggregationQuery().WithCount("total"), + expPipe: client.Pipeline().Collection("C").Where(Equal("f", 1)).Aggregate(Count(DocumentID).As("total")), + }, + { + name: "aggregation query with sum", + query: coll.NewAggregationQuery().WithSum("f", "sum_f"), + expPipe: client.Pipeline().Collection("C").Aggregate(Sum("f").As("sum_f")), + }, + { + name: "aggregation query with avg", + query: coll.NewAggregationQuery().WithAvg("f", "avg_f"), + expPipe: client.Pipeline().Collection("C").Aggregate(Average("f").As("avg_f")), + }, + { + name: "aggregation query with multiple aggregations", + query: coll.NewAggregationQuery().WithCount("total").WithSum("f", "sum_f"), + expPipe: client.Pipeline().Collection("C").Aggregate(Count(DocumentID).As("total"), Sum("f").As("sum_f")), + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + p := tc.query.Pipeline() + if p.err != nil { + t.Fatalf("Pipeline() got error: %v", p.err) + } + gotProto, err := p.toProto() + if err != nil { + t.Fatalf("p.toProto() got error: %v", err) + } + expProto, err := tc.expPipe.toProto() + if err != nil { + t.Fatalf("expPipe.toProto() got error: %v", err) + } + + if diff := cmp.Diff(expProto, gotProto, protocmp.Transform()); diff != "" { + t.Errorf("mismatch (-want, +got)\n: %s", diff) + } + }) + } +} func TestFindNearest(t *testing.T) { ctx := context.Background() c, srv, cleanup := newMock(t)