diff --git a/firestore/client.go b/firestore/client.go index a89a1ed46cb3..e96d57592804 100644 --- a/firestore/client.go +++ b/firestore/client.go @@ -47,7 +47,7 @@ const resourcePrefixHeader = "google-cloud-resource-prefix" // requestParamsHeader is routing header required to access named databases const reqParamsHeader = "x-goog-request-params" -// reqParamsHeaderVal constructs header from dbPath +// reqParamsHeaderVal constructs header from dbPath. // dbPath is of the form projects/{project_id}/databases/{database_id} func reqParamsHeaderVal(dbPath string) string { splitPath := strings.Split(dbPath, "/") diff --git a/firestore/integration_test.go b/firestore/integration_test.go index 5e4039258a8c..044292d80700 100644 --- a/firestore/integration_test.go +++ b/firestore/integration_test.go @@ -3566,6 +3566,558 @@ func TestIntegration_FindNearest(t *testing.T) { } } +func TestIntegration_PipelineStages(t *testing.T) { + if testParams[firestoreEditionKey].(firestoreEdition) != editionEnterprise { + t.Skip("Skipping pipeline queries tests since the firestore edition of", testParams[databaseIDKey].(string), "database is not enterprise") + } + ctx := context.Background() + client := integrationClient(t) + coll := integrationColl(t) + h := testHelper{t} + type Book struct { + Title string `firestore:"title"` + Author struct { + Name string `firestore:"name"` + Country string `firestore:"country"` + } `firestore:"author"` + Genre string `firestore:"genre"` + Published int `firestore:"published"` + Rating float64 `firestore:"rating"` + Tags []string `firestore:"tags"` + } + books := []Book{ + {Title: "The Hitchhiker's Guide to the Galaxy", Author: struct { + Name string `firestore:"name"` + Country string `firestore:"country"` + }{Name: "Douglas Adams", Country: "UK"}, Genre: "Science Fiction", Published: 1979, Rating: 4.2, Tags: []string{"comedy", "space", "adventure"}}, + {Title: "Pride and Prejudice", Author: struct { + Name string `firestore:"name"` + Country string `firestore:"country"` + }{Name: "Jane Austen", Country: "UK"}, Genre: "Romance", Published: 1813, Rating: 4.5, Tags: []string{"classic", "social commentary", "love"}}, + {Title: "One Hundred Years of Solitude", Author: struct { + Name string `firestore:"name"` + Country string `firestore:"country"` + }{Name: "Gabriel García Márquez", Country: "Colombia"}, Genre: "Magical Realism", Published: 1967, Rating: 4.3, Tags: []string{"family", "history", "fantasy"}}, + {Title: "The Lord of the Rings", Author: struct { + Name string `firestore:"name"` + Country string `firestore:"country"` + }{Name: "J.R.R. 
Tolkien", Country: "UK"}, Genre: "Fantasy", Published: 1954, Rating: 4.7, Tags: []string{"adventure", "magic", "epic"}}, + {Title: "The Handmaid's Tale", Author: struct { + Name string `firestore:"name"` + Country string `firestore:"country"` + }{Name: "Margaret Atwood", Country: "Canada"}, Genre: "Dystopian", Published: 1985, Rating: 4.1, Tags: []string{"feminism", "totalitarianism", "resistance"}}, + {Title: "Crime and Punishment", Author: struct { + Name string `firestore:"name"` + Country string `firestore:"country"` + }{Name: "Fyodor Dostoevsky", Country: "Russia"}, Genre: "Psychological Thriller", Published: 1866, Rating: 4.3, Tags: []string{"philosophy", "crime", "redemption"}}, + {Title: "To Kill a Mockingbird", Author: struct { + Name string `firestore:"name"` + Country string `firestore:"country"` + }{Name: "Harper Lee", Country: "USA"}, Genre: "Southern Gothic", Published: 1960, Rating: 4.2, Tags: []string{"racism", "injustice", "coming-of-age"}}, + {Title: "1984", Author: struct { + Name string `firestore:"name"` + Country string `firestore:"country"` + }{Name: "George Orwell", Country: "UK"}, Genre: "Dystopian", Published: 1949, Rating: 4.2, Tags: []string{"surveillance", "totalitarianism", "propaganda"}}, + {Title: "The Great Gatsby", Author: struct { + Name string `firestore:"name"` + Country string `firestore:"country"` + }{Name: "F. Scott Fitzgerald", Country: "USA"}, Genre: "Modernist", Published: 1925, Rating: 4.0, Tags: []string{"wealth", "american dream", "love"}}, + {Title: "Dune", Author: struct { + Name string `firestore:"name"` + Country string `firestore:"country"` + }{Name: "Frank Herbert", Country: "USA"}, Genre: "Science Fiction", Published: 1965, Rating: 4.6, Tags: []string{"politics", "desert", "ecology"}}, + } + var docRefs []*DocumentRef + for _, b := range books { + docRef := coll.NewDoc() + h.mustCreate(docRef, b) + docRefs = append(docRefs, docRef) + } + t.Cleanup(func() { + deleteDocuments(docRefs) + }) + t.Run("AddFields", func(t *testing.T) { + iter := client.Pipeline().Collection(coll.ID).AddFields(Multiply(FieldOf("rating"), 2).As("doubled_rating")).Limit(1).Execute(ctx) + defer iter.Stop() + doc, err := iter.Next() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if !doc.Exists() { + t.Fatalf("Exists: got: false, want: true") + } + data := doc.Data() + if dr, ok := data["doubled_rating"]; !ok || dr.(float64) != data["rating"].(float64)*2 { + t.Errorf("got doubled_rating %v, want %v", dr, data["rating"].(float64)*2) + } + }) + t.Run("Aggregate", func(t *testing.T) { + iter := client.Pipeline().Collection(coll.ID).Aggregate(Count("rating").As("total_books")).Execute(ctx) + defer iter.Stop() + doc, err := iter.Next() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + + if !doc.Exists() { + t.Fatalf("Exists: got: false, want: true") + } + data := doc.Data() + if data["total_books"] != int64(10) { + t.Errorf("got %d total_books, want 10", data["total_books"]) + } + }) + t.Run("AggregateWithSpec", func(t *testing.T) { + spec := NewAggregateSpec(Average("rating").As("avg_rating")).WithGroups("genre") + iter := client.Pipeline().Collection(coll.ID).AggregateWithSpec(spec).Execute(ctx) + defer iter.Stop() + results, err := iter.GetAll() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if len(results) != 8 { + t.Errorf("got %d groups, want 8", len(results)) + } + }) + t.Run("Distinct", func(t *testing.T) { + iter := client.Pipeline().Collection(coll.ID).Distinct("genre").Execute(ctx) + defer iter.Stop() + 
results, err := iter.GetAll() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if len(results) != 8 { + t.Errorf("got %d distinct genres, want 8", len(results)) + } + }) + t.Run("Documents", func(t *testing.T) { + iter := client.Pipeline().Documents(docRefs[0], docRefs[1]).Execute(ctx) + defer iter.Stop() + results, err := iter.GetAll() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if len(results) != 2 { + t.Errorf("got %d documents, want 2", len(results)) + } + }) + t.Run("CollectionGroup", func(t *testing.T) { + cgCollID := collectionIDs.New() + doc1 := coll.Doc("cg_doc1") + doc2 := coll.Doc("cg_doc2") + cgColl1 := doc1.Collection(cgCollID) + cgColl2 := doc2.Collection(cgCollID) + cgDoc1 := cgColl1.NewDoc() + cgDoc2 := cgColl2.NewDoc() + h.mustCreate(cgDoc1, map[string]string{"val": "a"}) + h.mustCreate(cgDoc2, map[string]string{"val": "b"}) + t.Cleanup(func() { + deleteDocuments([]*DocumentRef{cgDoc1, cgDoc2, doc1, doc2}) + }) + iter := client.Pipeline().CollectionGroup(cgCollID).Execute(ctx) + defer iter.Stop() + results, err := iter.GetAll() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if len(results) != 2 { + t.Errorf("got %d documents, want 2", len(results)) + } + }) + t.Run("Database", func(t *testing.T) { + dbDoc1 := coll.Doc("db_doc1") + otherColl := client.Collection(collectionIDs.New()) + dbDoc2 := otherColl.Doc("db_doc2") + h.mustCreate(dbDoc1, map[string]string{"val": "a"}) + h.mustCreate(dbDoc2, map[string]string{"val": "b"}) + t.Cleanup(func() { + deleteDocuments([]*DocumentRef{dbDoc1, dbDoc2}) + }) + iter := client.Pipeline().Database().Limit(2).Execute(ctx) + defer iter.Stop() + results, err := iter.GetAll() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if len(results) != 2 { + t.Errorf("got %d documents, want 2", len(results)) + } + }) + t.Run("FindNearest", func(t *testing.T) { + type DocWithVector struct { + ID string `firestore:"id"` + Vector Vector32 `firestore:"vector"` + } + docsWithVector := []DocWithVector{ + {ID: "doc1", Vector: Vector32{1.0, 2.0, 3.0}}, + {ID: "doc2", Vector: Vector32{4.0, 5.0, 6.0}}, + {ID: "doc3", Vector: Vector32{7.0, 8.0, 9.0}}, + } + var vectorDocRefs []*DocumentRef + for _, d := range docsWithVector { + docRef := coll.NewDoc() + h.mustCreate(docRef, d) + vectorDocRefs = append(vectorDocRefs, docRef) + } + t.Cleanup(func() { + deleteDocuments(vectorDocRefs) + }) + queryVector := Vector32{1.1, 2.1, 3.1} + limit := 2 + distanceField := "distance" + options := &PipelineFindNearestOptions{ + Limit: &limit, + DistanceField: &distanceField, + } + iter := client.Pipeline().Collection(coll.ID). + FindNearest("vector", queryVector, PipelineDistanceMeasureEuclidean, options). 
+ Execute(ctx) + defer iter.Stop() + results, err := iter.GetAll() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if len(results) != 2 { + t.Errorf("got %d documents, want 2", len(results)) + } + // Check if the results are sorted by distance + + if !results[0].Exists() { + t.Fatalf("results[0] Exists: got: false, want: true") + } + dist1 := results[0].Data() + + if !results[1].Exists() { + t.Fatalf("results[1] Exists: got: false, want: true") + } + dist2 := results[1].Data() + if dist1[distanceField].(float64) > dist2[distanceField].(float64) { + t.Errorf("documents are not sorted by distance") + } + // Check if the correct documents are returned + if dist1["id"] != "doc1" { + t.Errorf("got doc id %q, want 'doc1'", dist1["id"]) + } + }) + t.Run("Limit", func(t *testing.T) { + iter := client.Pipeline().Collection(coll.ID).Limit(3).Execute(ctx) + defer iter.Stop() + results, err := iter.GetAll() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if len(results) != 3 { + t.Errorf("got %d documents, want 3", len(results)) + } + }) + t.Run("Offset", func(t *testing.T) { + iter := client.Pipeline().Collection(coll.ID).Sort(Ascending(FieldOf("published"))).Offset(2).Limit(1).Execute(ctx) + defer iter.Stop() + doc, err := iter.Next() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if !doc.Exists() { + t.Fatalf("Exists: got: false, want: true") + } + data := doc.Data() + if data["title"] != "The Great Gatsby" { + t.Errorf("got title %q, want 'The Great Gatsby'", data["title"]) + } + }) + t.Run("RemoveFields", func(t *testing.T) { + iter := client.Pipeline().Collection(coll.ID). + Limit(1). + RemoveFields("genre", "rating"). + Execute(ctx) + defer iter.Stop() + doc, err := iter.Next() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if !doc.Exists() { + t.Fatalf("Exists: got: false, want: true") + } + data := doc.Data() + if _, ok := data["genre"]; ok { + t.Error("unexpected 'genre' field") + } + if _, ok := data["rating"]; ok { + t.Error("unexpected 'rating' field") + } + if _, ok := data["title"]; !ok { + t.Error("missing 'title' field") + } + }) + t.Run("Replace", func(t *testing.T) { + type DocWithMap struct { + ID string `firestore:"id"` + Data map[string]int `firestore:"data"` + } + docWithMap := DocWithMap{ID: "docWithMap", Data: map[string]int{"a": 1, "b": 2}} + docRef := coll.NewDoc() + h.mustCreate(docRef, docWithMap) + t.Cleanup(func() { + deleteDocuments([]*DocumentRef{docRef}) + }) + iter := client.Pipeline().Collection(coll.ID). + Where(Equal(FieldOf("id"), "docWithMap")). + Replace("data"). 
+ Execute(ctx) + defer iter.Stop() + doc, err := iter.Next() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if !doc.Exists() { + t.Fatalf("Exists: got: false, want: true") + } + data := doc.Data() + want := map[string]interface{}{"a": int64(1), "b": int64(2)} + if diff := testutil.Diff(data, want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", data, want, diff) + } + }) + t.Run("Sample", func(t *testing.T) { + t.Run("SampleByDocuments", func(t *testing.T) { + iter := client.Pipeline().Collection(coll.ID).Sample(SampleByDocuments(5)).Execute(ctx) + defer iter.Stop() + var got []map[string]interface{} + for { + doc, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if !doc.Exists() { + t.Fatalf("Exists: got: false, want: true") + } + data := doc.Data() + got = append(got, data) + } + if len(got) != 5 { + t.Errorf("got %d documents, want 5", len(got)) + } + }) + t.Run("SampleByPercentage", func(t *testing.T) { + iter := client.Pipeline().Collection(coll.ID).Sample(&SampleSpec{Size: 0.6, Mode: SampleModePercent}).Execute(ctx) + defer iter.Stop() + var got []map[string]interface{} + for { + doc, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if !doc.Exists() { + t.Fatalf("Exists: got: false, want: true") + } + data := doc.Data() + got = append(got, data) + } + if len(got) >= 10 { + t.Errorf("Sampled documents count should be less than total. got %d, total 10", len(got)) + } + if len(got) == 0 { + t.Errorf("Sampled documents count should be greater than 0. got %d", len(got)) + } + }) + }) + t.Run("Select", func(t *testing.T) { + iter := client.Pipeline().Collection(coll.ID).Select("title", "author.name").Limit(1).Execute(ctx) + defer iter.Stop() + doc, err := iter.Next() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if !doc.Exists() { + t.Fatalf("Exists: got: false, want: true") + } + data := doc.Data() + if _, ok := data["title"]; !ok { + t.Error("missing 'title' field") + } + if _, ok := data["author.name"]; !ok { + t.Error("missing 'author.name' field") + } + if _, ok := data["author"]; ok { + t.Error("unexpected 'author' field") + } + if _, ok := data["genre"]; ok { + t.Error("unexpected 'genre' field") + } + }) + t.Run("Sort", func(t *testing.T) { + iter := client.Pipeline().Collection(coll.ID).Sort(Descending(FieldOf("rating"))).Limit(1).Execute(ctx) + defer iter.Stop() + doc, err := iter.Next() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if !doc.Exists() { + t.Fatalf("Exists: got: false, want: true") + } + data := doc.Data() + if data["title"] != "The Lord of the Rings" { + t.Errorf("got title %q, want 'The Lord of the Rings'", data["title"]) + } + }) + t.Run("Union", func(t *testing.T) { + type Employee struct { + Name string `firestore:"name"` + Age int `firestore:"age"` + } + type Customer struct { + Name string `firestore:"name"` + Address string `firestore:"address"` + } + employeeColl := client.Collection(collectionIDs.New()) + customerColl := client.Collection(collectionIDs.New()) + employees := []Employee{ + {Name: "John Doe", Age: 42}, + {Name: "Jane Smith", Age: 35}, + } + customers := []Customer{ + {Name: "Alice", Address: "123 Main St"}, + {Name: "Bob", Address: "456 Oak Ave"}, + } + var unionDocRefs []*DocumentRef + for _, e := range employees { + docRef := employeeColl.NewDoc() + h.mustCreate(docRef, e) + unionDocRefs = append(unionDocRefs, docRef) + 
} + for _, c := range customers { + docRef := customerColl.NewDoc() + h.mustCreate(docRef, c) + unionDocRefs = append(unionDocRefs, docRef) + } + t.Cleanup(func() { + deleteDocuments(unionDocRefs) + }) + employeePipeline := client.Pipeline().Collection(employeeColl.ID) + customerPipeline := client.Pipeline().Collection(customerColl.ID) + iter := employeePipeline.Union(customerPipeline).Execute(context.Background()) + defer iter.Stop() + var got []map[string]interface{} + for { + doc, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if !doc.Exists() { + t.Fatalf("Exists: got: false, want: true") + } + data := doc.Data() + got = append(got, data) + } + want := []map[string]interface{}{ + {"name": "John Doe", "age": int64(42)}, + {"name": "Jane Smith", "age": int64(35)}, + {"name": "Alice", "address": "123 Main St"}, + {"name": "Bob", "address": "456 Oak Ave"}, + } + sort.Slice(got, func(i, j int) bool { + return got[i]["name"].(string) < got[j]["name"].(string) + }) + sort.Slice(want, func(i, j int) bool { + return want[i]["name"].(string) < want[j]["name"].(string) + }) + if diff := testutil.Diff(got, want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", got, want, diff) + } + }) + t.Run("Unnest", func(t *testing.T) { + iter := client.Pipeline().Collection(coll.ID). + Where(Equal(FieldOf("title"), "The Hitchhiker's Guide to the Galaxy")). + UnnestWithAlias("tags", "tag", nil). + Select("title", "tag"). + Execute(ctx) + defer iter.Stop() + var got []map[string]interface{} + for { + doc, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if !doc.Exists() { + t.Fatalf("Exists: got: false, want: true") + } + data := doc.Data() + got = append(got, data) + } + want := []map[string]interface{}{ + {"title": "The Hitchhiker's Guide to the Galaxy", "tag": "comedy"}, + {"title": "The Hitchhiker's Guide to the Galaxy", "tag": "space"}, + {"title": "The Hitchhiker's Guide to the Galaxy", "tag": "adventure"}, + } + sort.Slice(got, func(i, j int) bool { + return got[i]["tag"].(string) < got[j]["tag"].(string) + }) + sort.Slice(want, func(i, j int) bool { + return want[i]["tag"].(string) < want[j]["tag"].(string) + }) + if diff := testutil.Diff(got, want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", got, want, diff) + } + }) + t.Run("UnnestWithIndexField", func(t *testing.T) { + iter := client.Pipeline().Collection(coll.ID). + Where(Equal(FieldOf("title"), "The Hitchhiker's Guide to the Galaxy")). + UnnestWithAlias("tags", "tag", &UnnestOptions{IndexField: "tagIndex"}). + Select("title", "tag", "tagIndex"). 
+ Execute(ctx) + defer iter.Stop() + var got []map[string]interface{} + for { + doc, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if !doc.Exists() { + t.Fatalf("Exists: got: false, want: true") + } + data := doc.Data() + got = append(got, data) + } + want := []map[string]interface{}{ + {"title": "The Hitchhiker's Guide to the Galaxy", "tag": "comedy", "tagIndex": int64(0)}, + {"title": "The Hitchhiker's Guide to the Galaxy", "tag": "space", "tagIndex": int64(1)}, + {"title": "The Hitchhiker's Guide to the Galaxy", "tag": "adventure", "tagIndex": int64(2)}, + } + sort.Slice(got, func(i, j int) bool { + return got[i]["tagIndex"].(int64) < got[j]["tagIndex"].(int64) + }) + if diff := testutil.Diff(got, want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", got, want, diff) + } + }) + t.Run("Where", func(t *testing.T) { + iter := client.Pipeline().Collection(coll.ID).Where(Equal(FieldOf("author.country"), "UK")).Execute(ctx) + defer iter.Stop() + results, err := iter.GetAll() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if len(results) != 4 { + t.Errorf("got %d documents, want 4", len(results)) + } + }) +} + func TestIntegration_PipelineFunctions(t *testing.T) { if testParams[firestoreEditionKey].(firestoreEdition) != editionEnterprise { t.Skip("Skipping pipeline queries tests since the firestore edition of", testParams[databaseIDKey].(string), "database is not enterprise") diff --git a/firestore/pipeline.go b/firestore/pipeline.go index f2a86c952767..a7048d496c30 100644 --- a/firestore/pipeline.go +++ b/firestore/pipeline.go @@ -51,12 +51,30 @@ func newPipeline(client *Client, initialStage pipelineStage) *Pipeline { // Execute executes the pipeline and returns an iterator for streaming the results. // TODO: Accept PipelineOptions func (p *Pipeline) Execute(ctx context.Context) *PipelineResultIterator { + ctx = withResourceHeader(ctx, p.c.path()) + ctx = withRequestParamsHeader(ctx, reqParamsHeaderVal(p.c.path())) return &PipelineResultIterator{ - iter: newStreamPipelineResultIterator(withResourceHeader(ctx, p.c.path()), p), + iter: newStreamPipelineResultIterator(ctx, p), } } func (p *Pipeline) toExecutePipelineRequest() (*pb.ExecutePipelineRequest, error) { + pipelinePb, err := p.toProto() + if err != nil { + return nil, err + } + + req := &pb.ExecutePipelineRequest{ + Database: p.c.path(), + PipelineType: &pb.ExecutePipelineRequest_StructuredPipeline{ + StructuredPipeline: &pb.StructuredPipeline{Pipeline: pipelinePb}, + }, + // TODO: Add consistency selector + } + return req, nil +} + +func (p *Pipeline) toProto() (*pb.Pipeline, error) { if p.err != nil { return nil, p.err } @@ -68,20 +86,7 @@ func (p *Pipeline) toExecutePipelineRequest() (*pb.ExecutePipelineRequest, error } protoStages[i] = ps } - - req := &pb.ExecutePipelineRequest{ - Database: p.c.path(), - PipelineType: &pb.ExecutePipelineRequest_StructuredPipeline{ - StructuredPipeline: &pb.StructuredPipeline{ - Pipeline: &pb.Pipeline{ - Stages: protoStages, - }, - }, - }, - // TODO: Add consistencyselector - } - - return req, nil + return &pb.Pipeline{Stages: protoStages}, nil } // append creates a new Pipeline by adding a stage to the current one. @@ -103,6 +108,54 @@ func (p *Pipeline) Limit(limit int) *Pipeline { return p.append(newLimitStage(limit)) } +// OrderingDirection is the sort direction for pipeline result ordering.
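+//
+// For example (a sketch mirroring the integration tests; the "books"
+// collection and its "rating" field are illustrative):
+//
+//	client.Pipeline().Collection("books").Sort(Descending(FieldOf("rating")))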
+type OrderingDirection string + +const ( + // OrderingAsc sorts results from smallest to largest. + OrderingAsc OrderingDirection = "ascending" + + // OrderingDesc sorts results from largest to smallest. + OrderingDesc OrderingDirection = "descending" +) + +// Ordering specifies the field and direction for sorting. +type Ordering struct { + Expr Expr + Direction OrderingDirection +} + +// Ascending creates an Ordering for ascending sort direction. +func Ascending(expr Expr) Ordering { + return Ordering{Expr: expr, Direction: OrderingAsc} +} + +// Descending creates an Ordering for descending sort direction. +func Descending(expr Expr) Ordering { + return Ordering{Expr: expr, Direction: OrderingDesc} +} + +// Sort sorts the documents by the given fields and directions. +func (p *Pipeline) Sort(orders ...Ordering) *Pipeline { + return p.append(newSortStage(orders...)) +} + +// Offset skips the first `offset` documents from the results of previous stages. +// +// This stage is useful for implementing pagination in your pipelines, allowing you to retrieve +// results in chunks. It is typically used in conjunction with [*Pipeline.Limit] to control the +// size of each page. +// +// Example: +// Retrieve the second page of 20 results +// +// client.Pipeline().Collection("books"). +// Offset(20). // Skip the first 20 results +// Limit(20) // Take the next 20 results +func (p *Pipeline) Offset(offset int) *Pipeline { + return p.append(newOffsetStage(offset)) +} + // Select selects or creates a set of fields from the outputs of previous stages. // The selected fields are defined using field path string, [FieldPath] or [Selectable] expressions. // [Selectable] expressions can be: @@ -116,11 +169,27 @@ func (p *Pipeline) Limit(limit int) *Pipeline { // client.Pipeline().Collection("users").Select(FieldOfPath([]string{"info", "email"})) // client.Pipeline().Collection("users").Select(Add("age", 5).As("agePlus5")) -func (p *Pipeline) Select(fieldsOrSelectables ...any) *Pipeline { +func (p *Pipeline) Select(fieldpathsOrSelectables ...any) *Pipeline { if p.err != nil { return p } - stage, err := newSelectStage(fieldsOrSelectables...) + stage, err := newSelectStage(fieldpathsOrSelectables...) + if err != nil { + p.err = err + return p + } + return p.append(stage) +} + +// Distinct removes duplicate documents from the outputs of previous stages. +// +// You can optionally specify fields or [Selectable] expressions to determine distinctness. +// If no fields are specified, the entire document is used to determine distinctness. +func (p *Pipeline) Distinct(fieldpathsOrSelectables ...any) *Pipeline { + if p.err != nil { + return p + } + stage, err := newDistinctStage(fieldpathsOrSelectables...) if err != nil { p.err = err return p @@ -147,6 +216,19 @@ func (p *Pipeline) AddFields(selectables ...Selectable) *Pipeline { return p.append(stage) } +// RemoveFields removes fields from the outputs of previous stages. +func (p *Pipeline) RemoveFields(fieldpaths ...any) *Pipeline { + if p.err != nil { + return p + } + stage, err := newRemoveFieldsStage(fieldpaths...) + if err != nil { + p.err = err + return p + } + return p.append(stage) +} + // Where filters the documents from previous stages to only include those matching the specified [BooleanExpr]. // // This stage allows you to apply conditions to the data, similar to a "WHERE" clause in SQL.
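+//
+// A minimal sketch (the "books" collection and nested "author.country"
+// field mirror the integration tests):
+//
+//	client.Pipeline().Collection("books").Where(Equal(FieldOf("author.country"), "UK"))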
@@ -175,8 +257,8 @@ func NewAggregateSpec(accumulators ...*AliasedAggregate) *AggregateSpec { } // WithGroups sets the grouping keys for the aggregation. -func (a *AggregateSpec) WithGroups(fieldsOrSelectables ...any) *AggregateSpec { - a.groups, a.err = fieldsOrSelectablesToSelectables(fieldsOrSelectables...) +func (a *AggregateSpec) WithGroups(fieldpathsOrSelectables ...any) *AggregateSpec { + a.groups, a.err = fieldsOrSelectablesToSelectables(fieldpathsOrSelectables...) return a } @@ -214,7 +296,7 @@ func (p *Pipeline) Aggregate(accumulators ...*AliasedAggregate) *Pipeline { // // // Calculate the average rating for each genre. // client.Pipeline().Collection("books"). -// AggregateWithSpec(NewAggregateSpec(Avg("rating").As("avg_rating")).WithGroups("genre")) +// AggregateWithSpec(NewAggregateSpec(Average("rating").As("avg_rating")).WithGroups("genre")) func (p *Pipeline) AggregateWithSpec(spec *AggregateSpec) *Pipeline { aggStage, err := newAggregateStage(spec) if err != nil { @@ -223,3 +305,182 @@ func (p *Pipeline) AggregateWithSpec(spec *AggregateSpec) *Pipeline { } return p.append(aggStage) } + +// UnnestOptions holds the configuration for the Unnest stage. +type UnnestOptions struct { + // IndexField specifies the name of the field to store the array index of the unnested element. + IndexField any +} + +// Unnest produces a document for each element in an array field. +// For each input document, this stage outputs zero or more documents. +// Each output document is a copy of the input document, but the array field is replaced by an element from the array. +// The `fieldpathOrSelectable` parameter specifies the array field to unnest. It can be a string representing the field path or a [Selectable] expression. +// If a [Selectable] is provided, the alias of the selectable will be used as the new field name. +func (p *Pipeline) Unnest(fieldpathOrSelectable any) *Pipeline { + if p.err != nil { + return p + } + stage, err := newUnnestStageFromAny(fieldpathOrSelectable) + if err != nil { + p.err = err + return p + } + return p.append(stage) +} + +// UnnestWithAlias produces a document for each element in an array field, with a specified alias for the unnested field. +// It can optionally take UnnestOptions; a nil opts selects the default behavior. +func (p *Pipeline) UnnestWithAlias(fieldpath any, alias string, opts *UnnestOptions) *Pipeline { + if p.err != nil { + return p + } + + var fieldExpr Expr + switch v := fieldpath.(type) { + case string: + fieldExpr = FieldOf(v) + case FieldPath: + fieldExpr = FieldOfPath(v) + default: + p.err = errInvalidArg(fieldpath, "string", "FieldPath") + return p + } + + stage, err := newUnnestStage(fieldExpr, alias, opts) + if err != nil { + p.err = err + return p + } + return p.append(stage) +} + +// Union performs a union of all documents from two pipelines, including duplicates. +// +// This stage passes through documents from the previous stage, and also passes through documents +// from the previous stage of the other [*Pipeline] given as a parameter. The order of documents +// emitted from this stage is undefined. +// +// Example: +// +// // Emit documents from books collection and magazines collection. +// client.Pipeline().Collection("books"). +// Union(client.Pipeline().Collection("magazines")) +func (p *Pipeline) Union(other *Pipeline) *Pipeline { + if p.err != nil { + return p + } + stage, err := newUnionStage(other) + if err != nil { + p.err = err + return p + } + return p.append(stage) +} + +// SampleMode defines the mode for the sample stage.
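+//
+// With SampleModeDocuments, SampleSpec.Size is an integer document count;
+// with SampleModePercent, it is a float64 fraction (e.g. 0.5 for 50%).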
+type SampleMode string + +const ( + // SampleModeDocuments samples a fixed number of documents. + SampleModeDocuments SampleMode = "documents" + // SampleModePercent samples a percentage of documents. + SampleModePercent SampleMode = "percent" +) + +// SampleSpec is used to define a sample operation. +type SampleSpec struct { + Size any + Mode SampleMode +} + +// SampleByDocuments creates a SampleSpec for sampling a fixed number of documents. +func SampleByDocuments(limit int) *SampleSpec { + return &SampleSpec{Size: limit, Mode: SampleModeDocuments} +} + +// Sample performs a pseudo-random sampling of the documents from the previous stage. +// +// This stage will filter documents pseudo-randomly. The behavior is defined by the SampleSpec. +// Use SampleByDocuments, or construct a SampleSpec directly for percentage sampling. +// +// Example: +// +// // Sample 10 books, if available. +// client.Pipeline().Collection("books").Sample(SampleByDocuments(10)) +// +// // Sample 50% of books. +// client.Pipeline().Collection("books").Sample(&SampleSpec{Size: 0.5, Mode: SampleModePercent}) +func (p *Pipeline) Sample(spec *SampleSpec) *Pipeline { + if p.err != nil { + return p + } + stage, err := newSampleStage(spec) + if err != nil { + p.err = err + return p + } + return p.append(stage) +} + +// Replace fully overwrites all fields in a document with those coming from a nested map. +// +// This stage allows you to emit a map value as a document. Each key of the map becomes a field +// on the document that contains the corresponding value. +// +// Example: +// +// // Input: { "name": "John Doe Jr.", "parents": { "father": "John Doe Sr.", "mother": "Jane Doe" } } +// // Emit parents as document. +// client.Pipeline().Collection("people").Replace("parents") +// // Output: { "father": "John Doe Sr.", "mother": "Jane Doe" } +func (p *Pipeline) Replace(fieldpathOrSelectable any) *Pipeline { + if p.err != nil { + return p + } + stage, err := newReplaceStage(fieldpathOrSelectable) + if err != nil { + p.err = err + return p + } + return p.append(stage) +} + +// PipelineDistanceMeasure is the distance measure for the find_nearest pipeline stage. +type PipelineDistanceMeasure string + +const ( + // PipelineDistanceMeasureEuclidean measures the Euclidean distance between the vectors. + PipelineDistanceMeasureEuclidean PipelineDistanceMeasure = "euclidean" + // PipelineDistanceMeasureCosine compares vectors based on the angle between them. + PipelineDistanceMeasureCosine PipelineDistanceMeasure = "cosine" + // PipelineDistanceMeasureDotProduct is similar to cosine but is affected by the magnitude of the vectors. + PipelineDistanceMeasureDotProduct PipelineDistanceMeasure = "dot_product" +) + +// PipelineFindNearestOptions are options for a FindNearest pipeline stage. +type PipelineFindNearestOptions struct { + Limit *int + DistanceField *string +} + +// FindNearest applies a vector distance (similarity) search with the given parameters to the stage inputs. +// +// This stage adds a "nearest neighbor search" capability to your pipelines. Given a field that +// stores vectors and a target vector, this stage will identify and return the inputs whose vector +// field is closest to the target vector. +// +// The vectorField can be a string, a FieldPath or an Expr. +// The queryVector can be Vector32, Vector64, []float32, or []float64.
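+//
+// A sketch of usage (the "embedding" field name and the query vector are
+// illustrative, mirroring the integration tests):
+//
+//	limit := 10
+//	distanceField := "distance"
+//	client.Pipeline().Collection("books").
+//		FindNearest("embedding", Vector32{1.0, 2.0, 3.0}, PipelineDistanceMeasureEuclidean,
+//			&PipelineFindNearestOptions{Limit: &limit, DistanceField: &distanceField})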
+func (p *Pipeline) FindNearest(vectorField any, queryVector any, measure PipelineDistanceMeasure, options *PipelineFindNearestOptions) *Pipeline { + if p.err != nil { + return p + } + + stage, err := newFindNearestStage(vectorField, queryVector, measure, options) + if err != nil { + p.err = err + return p + } + return p.append(stage) +} diff --git a/firestore/pipeline_expr.go b/firestore/pipeline_expr.go index af36273d822a..be50d31ad7b2 100644 --- a/firestore/pipeline_expr.go +++ b/firestore/pipeline_expr.go @@ -85,6 +85,10 @@ type Expr interface { Average() AggregateFunction Count() AggregateFunction + // Ordering + Ascending() Ordering + Descending() Ordering + // As assigns an alias to an expression. // Aliases are useful for renaming fields in the output of a stage. As(alias string) Selectable @@ -147,6 +151,9 @@ func (b *baseExpr) CountIf() AggregateFunction { return CountIf(b) } func (b *baseExpr) Maximum() AggregateFunction { return Maximum(b) } func (b *baseExpr) Minimum() AggregateFunction { return Minimum(b) } +// Ordering +func (b *baseExpr) Ascending() Ordering { return Ascending(b) } +func (b *baseExpr) Descending() Ordering { return Descending(b) } func (b *baseExpr) As(alias string) Selectable { return newAliasedExpr(b, alias) } diff --git a/firestore/pipeline_result.go b/firestore/pipeline_result.go index c06da71bd82f..6e2a3be7706b 100644 --- a/firestore/pipeline_result.go +++ b/firestore/pipeline_result.go @@ -83,39 +83,38 @@ func newPipelineResult(ref *DocumentRef, proto *pb.Document, c *Client, executio return pr, nil } +// Exists reports whether the PipelineResult represents a document. +// Even if Exists returns false, the rest of the fields are valid. +func (p *PipelineResult) Exists() bool { + return p.proto != nil +} + // Data returns the PipelineResult's fields as a map. // It is equivalent to // // var m map[string]any // p.DataTo(&m) -func (p *PipelineResult) Data() (map[string]any, error) { - if p == nil { - return nil, status.Errorf(codes.NotFound, "result does not exist") - } - var fields map[string]*pb.Value - if p.proto != nil { - fields = p.proto.Fields +func (p *PipelineResult) Data() map[string]any { + if p == nil || !p.Exists() { + return nil } - m, err := createMapFromValueMap(fields, p.c) + m, err := createMapFromValueMap(p.proto.Fields, p.c) + // Any error here is a bug in the client. if err != nil { panic(fmt.Sprintf("firestore: %v", err)) } - return m, nil + return m } // DataTo uses the PipelineResult's fields to populate v, which can be a pointer to a // map[string]any or a pointer to a struct. // This is similar to [DocumentSnapshot.DataTo] func (p *PipelineResult) DataTo(v any) error { - if p == nil { + if p == nil || !p.Exists() { return status.Errorf(codes.NotFound, "document does not exist") } - var fields map[string]*pb.Value - if p.proto != nil { - fields = p.proto.Fields - } - return setFromProtoValue(v, &pb.Value{ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{Fields: fields}}}, p.c) + return setFromProtoValue(v, &pb.Value{ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{Fields: p.proto.Fields}}}, p.c) } // PipelineResultIterator is an iterator over PipelineResults from a pipeline execution.
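+//
+// A typical consumption loop (sketch, mirroring the integration tests):
+//
+//	iter := client.Pipeline().Collection("books").Execute(ctx)
+//	defer iter.Stop()
+//	for {
+//		doc, err := iter.Next()
+//		if err == iterator.Done {
+//			break
+//		}
+//		if err != nil {
+//			// handle the error
+//		}
+//		_ = doc.Data()
+//	}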
@@ -223,6 +222,7 @@ func (it *streamPipelineResultIterator) next() (_ *PipelineResult, err error) { } ctx := withRequestParamsHeader(it.ctx, reqParamsHeaderVal(client.path())) + it.streamClient, err = client.c.ExecutePipeline(ctx, req) if err != nil { return nil, err diff --git a/firestore/pipeline_result_test.go b/firestore/pipeline_result_test.go index fb8c3ba5358c..ecaba4fb2419 100644 --- a/firestore/pipeline_result_test.go +++ b/firestore/pipeline_result_test.go @@ -113,6 +113,7 @@ func TestStreamPipelineResultIterator_Next(t *testing.T) { RecvErrors: tc.errors, ContextVal: ctx, } + iter := &streamPipelineResultIterator{ ctx: ctx, cancel: func() {}, @@ -151,10 +152,7 @@ func TestStreamPipelineResultIterator_Next(t *testing.T) { t.Fatalf("Result count mismatch for data check: expected %d, got %d", len(tc.wantData), len(results)) } for i, pr := range results { - data, err := pr.Data() - if err != nil { - t.Fatalf("Data: %v", err) - } + data := pr.Data() if diff := cmp.Diff(tc.wantData[i], data); diff != "" { t.Errorf("Data mismatch for result %d (-want +got):\n%s", i, diff) } @@ -238,18 +236,12 @@ func TestPipelineResultIterator_GetAll(t *testing.T) { t.Errorf("results from GetAll(): got %d, want: 2", len(allResults)) } - data, err := allResults[0].Data() - if err != nil { - t.Fatalf("Data: %v", err) - } + data := allResults[0].Data() if data["id"].(int64) != 1 { t.Errorf("first result id: got %v, want: 1", data["id"]) } - data, err = allResults[1].Data() - if err != nil { - t.Fatalf("Data: %v", err) - } + data = allResults[1].Data() if data["id"].(int64) != 2 { t.Errorf("second result id: got %v, want: 2", data["id"]) } @@ -274,11 +266,12 @@ func TestPipelineResult_DataExtraction(t *testing.T) { "stringProp": {ValueType: &pb.Value_StringValue{StringValue: "hello"}}, "intProp": {ValueType: &pb.Value_IntegerValue{IntegerValue: 123}}, "boolProp": {ValueType: &pb.Value_BooleanValue{BooleanValue: true}}, - "mapProp": {ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{ - Fields: map[string]*pb.Value{ - "nestedString": {ValueType: &pb.Value_StringValue{StringValue: "world"}}, - }, - }}}, + "mapProp": { + ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + "nestedString": {ValueType: &pb.Value_StringValue{StringValue: "world"}}, + }, + }}}, }, } execTimeProto := timestamppb.New(now.Add(time.Second)) @@ -290,11 +283,7 @@ func TestPipelineResult_DataExtraction(t *testing.T) { } // Test Data() - dataMap, err := pr.Data() - if err != nil { - t.Fatalf("Data: %+v", err) - } - + dataMap := pr.Data() if dataMap["stringProp"].(string) != "hello" { t.Errorf("stringProp: got %v, want 'hello'", dataMap["stringProp"]) } @@ -361,12 +350,9 @@ func TestPipelineResult_NoResults(t *testing.T) { t.Fatalf("newPipelineResult: %v", err) } - data, err := pr.Data() - if err != nil { - t.Errorf("pr.Data() for non-existent result err: got %v, want nil", err) - } - if data == nil { - t.Errorf("pr.Data() for non-existent result: got nil, want non-nil empty map") + data := pr.Data() + if data != nil { + t.Errorf("pr.Data() for non-existent result: got non-nil empty map, want nil") } if len(data) != 0 { t.Errorf("pr.Data() for non-existent result: got map with %d elements, want empty map", len(data)) @@ -374,10 +360,9 @@ func TestPipelineResult_NoResults(t *testing.T) { type MyStruct struct{ Foo string } var s MyStruct - err = pr.DataTo(&s) // Should behave like populating from an empty map - if err != nil { - // DataTo on a non-existent PipelineResult should not error out but result 
in a zero-valued struct. - t.Fatalf("pr.DataTo(&s) on non-existent result failed: %v", err) + err = pr.DataTo(&s) + if err == nil || status.Code(err) != codes.NotFound { + t.Fatalf("pr.DataTo(&s) on non-existent result failed: %v status.Code(err): %v", err, status.Code(err)) } if s.Foo != "" { t.Errorf("Struct Foo for non-existent result: got %q, want \"\"", s.Foo) diff --git a/firestore/pipeline_source.go b/firestore/pipeline_source.go index c0aa7240fe7c..92cf76f5b29d 100644 --- a/firestore/pipeline_source.go +++ b/firestore/pipeline_source.go @@ -38,20 +38,12 @@ func (ps *PipelineSource) CollectionGroup(collectionID string) *Pipeline { return newPipeline(ps.client, newInputStageCollectionGroup("", collectionID)) } -// CollectionGroupWithAncestor creates a new [Pipeline] that operates on all documents in a group -// of collections that include the given ID, that are underneath a given document. -// -// For example, consider: -// /continents/Europe/Countries/Germany/Cities/Paris = {population: 100} -// /continents/Europe/Countries/France/Cities/Paris = {population: 100} -// /continents/NorthAmerica/Countries/Canada/Cities/Montreal = {population: 90} -// -// CollectionGroupWithAncestor can be used to query across all "Cities" in "/continents/Europe". -func (ps *PipelineSource) CollectionGroupWithAncestor(ancestor, collectionID string) *Pipeline { - return newPipeline(ps.client, newInputStageCollectionGroup(ancestor, collectionID)) -} - // Database creates a new [Pipeline] that operates on all documents in the Firestore database. func (ps *PipelineSource) Database() *Pipeline { return newPipeline(ps.client, newInputStageDatabase()) } + +// Documents creates a new [Pipeline] that operates on a specific set of Firestore documents. +func (ps *PipelineSource) Documents(refs ...*DocumentRef) *Pipeline { + return newPipeline(ps.client, newInputStageDocuments(refs...)) +} diff --git a/firestore/pipeline_source_test.go b/firestore/pipeline_source_test.go index c3d979721f54..c275ccc3fdc7 100644 --- a/firestore/pipeline_source_test.go +++ b/firestore/pipeline_source_test.go @@ -84,39 +84,6 @@ func TestPipelineSource_CollectionGroup(t *testing.T) { } } -func TestPipelineSource_CollectionGroupWithAncestor(t *testing.T) { - client := newTestClient() - ps := &PipelineSource{client: client} - p := ps.CollectionGroupWithAncestor("ancestor/path", "items") - - if p.err != nil { - t.Fatalf("CollectionGroupWithAncestor: %v", p.err) - } - if len(p.stages) != 1 { - t.Fatalf("initial stages: got %d, want 1", len(p.stages)) - } - - req, err := p.toExecutePipelineRequest() - if err != nil { - t.Fatalf("toExecutePipelineRequest: %v", err) - } - - wantStage := &pb.Pipeline_Stage{ - Name: "collection_group", - Args: []*pb.Value{ - {ValueType: &pb.Value_ReferenceValue{ReferenceValue: "ancestor/path"}}, - {ValueType: &pb.Value_StringValue{StringValue: "items"}}, - }, - } - - if len(req.GetStructuredPipeline().GetPipeline().GetStages()) != 1 { - t.Fatalf("stage in proto: got %d, want 1", len(req.GetStructuredPipeline().GetPipeline().GetStages())) - } - if diff := testutil.Diff(wantStage, req.GetStructuredPipeline().GetPipeline().GetStages()[0]); diff != "" { - t.Errorf("toExecutePipelineRequest mismatch for collectionGroupWithAncestor stage (-want +got):\n%s", diff) - } -} - func TestPipelineSource_Database(t *testing.T) { client := newTestClient() ps := &PipelineSource{client: client} diff --git a/firestore/pipeline_stage.go b/firestore/pipeline_stage.go index 9546f2afce61..0df166aae45f 100644 --- 
a/firestore/pipeline_stage.go +++ b/firestore/pipeline_stage.go @@ -15,16 +15,42 @@ package firestore import ( + "fmt" "strings" pb "cloud.google.com/go/firestore/apiv1/firestorepb" ) +// baseStage is an internal helper to reduce repetition in pipelineStage +// implementations. +type baseStage struct { + stageName string + stagePb *pb.Pipeline_Stage +} + +func (s *baseStage) name() string { return s.stageName } +func (s *baseStage) toProto() (*pb.Pipeline_Stage, error) { return s.stagePb, nil } + +func errInvalidArg(v any, expected ...string) error { + return fmt.Errorf("firestore: invalid argument type: %T, expected one of: [%s]", v, strings.Join(expected, ", ")) +} + const ( - stageNameAddFields = "add_fields" - stageNameSelect = "select" - stageNameWhere = "where" - stageNameAggregate = "aggregate" + stageNameAddFields = "add_fields" + stageNameAggregate = "aggregate" + stageNameCollection = "collection" + stageNameCollectionGroup = "collection_group" + stageNameDatabase = "database" + stageNameDistinct = "distinct" + stageNameDocuments = "documents" + stageNameFindNearest = "find_nearest" + stageNameRemoveFields = "remove_fields" + stageNameReplace = "replace_with" + stageNameSample = "sample" + stageNameSelect = "select" + stageNameUnion = "union" + stageNameUnnest = "unnest" + stageNameWhere = "where" ) // internal interface for pipeline stages. @@ -44,7 +70,7 @@ func newInputStageCollection(path string) *inputStageCollection { } return &inputStageCollection{path: path} } -func (s *inputStageCollection) name() string { return "collection" } +func (s *inputStageCollection) name() string { return stageNameCollection } func (s *inputStageCollection) toProto() (*pb.Pipeline_Stage, error) { arg := &pb.Value{ValueType: &pb.Value_ReferenceValue{ReferenceValue: s.path}} return &pb.Pipeline_Stage{ @@ -62,7 +88,7 @@ type inputStageCollectionGroup struct { } func newInputStageCollectionGroup(ancestor, collectionID string) *inputStageCollectionGroup { return &inputStageCollectionGroup{ancestor: ancestor, collectionID: collectionID} } -func (s *inputStageCollectionGroup) name() string { return "collection_group" } +func (s *inputStageCollectionGroup) name() string { return stageNameCollectionGroup } func (s *inputStageCollectionGroup) toProto() (*pb.Pipeline_Stage, error) { ancestor := &pb.Value{ValueType: &pb.Value_ReferenceValue{ReferenceValue: s.ancestor}} collectionID := &pb.Value{ValueType: &pb.Value_StringValue{StringValue: s.collectionID}} @@ -78,13 +104,159 @@ type inputStageDatabase struct{} func newInputStageDatabase() *inputStageDatabase { return &inputStageDatabase{} } -func (s *inputStageDatabase) name() string { return "database" } +func (s *inputStageDatabase) name() string { return stageNameDatabase } func (s *inputStageDatabase) toProto() (*pb.Pipeline_Stage, error) { return &pb.Pipeline_Stage{ Name: s.name(), }, nil } +type inputStageDocuments struct { + baseStage +} + +func newInputStageDocuments(refs ...*DocumentRef) *inputStageDocuments { + args := make([]*pb.Value, len(refs)) + for i, ref := range refs { + args[i] = &pb.Value{ValueType: &pb.Value_ReferenceValue{ReferenceValue: "/" + ref.shortPath}} + } + return &inputStageDocuments{baseStage{ + stageName: stageNameDocuments, + stagePb: &pb.Pipeline_Stage{ + Name: stageNameDocuments, + Args: args, + }, + }} +} + +// addFieldsStage is the internal representation of an AddFields stage.
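+//
+// Produced by [*Pipeline.AddFields]; for example (a sketch from the
+// integration tests):
+//
+//	client.Pipeline().Collection("books").AddFields(Multiply(FieldOf("rating"), 2).As("doubled_rating"))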
+type addFieldsStage struct { + baseStage +} + +func newAddFieldsStage(selectables ...Selectable) (*addFieldsStage, error) { + mapVal, err := projectionsToMapValue(selectables) + if err != nil { + return nil, err + } + stagePb := newUnaryStage(stageNameAddFields, mapVal) + return &addFieldsStage{baseStage{ + stageName: stageNameAddFields, + stagePb: stagePb, + }}, nil +} + +type aggregateStage struct { + baseStage +} + +func newAggregateStage(a *AggregateSpec) (*aggregateStage, error) { + if a.err != nil { + return nil, a.err + } + targetsPb, err := aliasedAggregatesToMapValue(a.accTargets) + if err != nil { + return nil, err + } + groupsPb, err := projectionsToMapValue(a.groups) + if err != nil { + return nil, err + } + return &aggregateStage{baseStage{ + stageName: stageNameAggregate, + stagePb: &pb.Pipeline_Stage{ + Name: stageNameAggregate, + Args: []*pb.Value{ + targetsPb, + groupsPb, + }, + }, + }}, nil +} + +type distinctStage struct { + baseStage +} + +// newProjectionStage is a helper for creating pipeline stages that take a +// projection as an argument. +func newProjectionStage(name string, fieldsOrSelectables ...any) (*pb.Pipeline_Stage, error) { + selectables, err := fieldsOrSelectablesToSelectables(fieldsOrSelectables...) + if err != nil { + return nil, err + } + mapVal, err := projectionsToMapValue(selectables) + if err != nil { + return nil, err + } + return &pb.Pipeline_Stage{ + Name: name, + Args: []*pb.Value{mapVal}, + }, nil +} + +func newDistinctStage(fieldsOrSelectables ...any) (*distinctStage, error) { + stagePb, err := newProjectionStage(stageNameDistinct, fieldsOrSelectables...) + if err != nil { + return nil, err + } + return &distinctStage{baseStage{stageName: stageNameDistinct, stagePb: stagePb}}, nil +} + +type findNearestStage struct { + baseStage +} + +func newFindNearestStage(vectorField any, queryVector any, measure PipelineDistanceMeasure, options *PipelineFindNearestOptions) (*findNearestStage, error) { + var propertyExpr Expr + switch v := vectorField.(type) { + case string: + propertyExpr = FieldOf(v) + case FieldPath: + propertyExpr = FieldOfPath(v) + case Expr: + propertyExpr = v + default: + return nil, errInvalidArg(vectorField, "string", "FieldPath", "Expr") + } + propPb, err := propertyExpr.toProto() + if err != nil { + return nil, err + } + var vectorPb *pb.Value + switch v := queryVector.(type) { + case Vector32: + vectorPb = vectorToProtoValue([]float32(v)) + case []float32: + vectorPb = vectorToProtoValue(v) + case Vector64: + vectorPb = vectorToProtoValue([]float64(v)) + case []float64: + vectorPb = vectorToProtoValue(v) + default: + return nil, errInvalidVector + } + measurePb := &pb.Value{ValueType: &pb.Value_StringValue{StringValue: string(measure)}} + var optionsPb map[string]*pb.Value + if options != nil { + optionsPb = make(map[string]*pb.Value) + if options.Limit != nil { + optionsPb["limit"] = &pb.Value{ValueType: &pb.Value_IntegerValue{IntegerValue: int64(*options.Limit)}} + } + if options.DistanceField != nil { + optionsPb["distance_field"] = &pb.Value{ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: *options.DistanceField}} + } + } + return &findNearestStage{baseStage{ + stageName: stageNameFindNearest, + stagePb: &pb.Pipeline_Stage{ + Name: stageNameFindNearest, + Args: []*pb.Value{propPb, vectorPb, measurePb}, + Options: optionsPb, + }, + }}, nil +} + type limitStage struct { limit int } @@ -101,99 +273,254 @@ func (s *limitStage) toProto() (*pb.Pipeline_Stage, error) { }, nil } -type selectStage struct { - stagePb 
*pb.Pipeline_Stage +type offsetStage struct { + offset int } -func newSelectStage(fieldsOrSelectables ...any) (*selectStage, error) { - selectables, err := fieldsOrSelectablesToSelectables(fieldsOrSelectables...) - if err != nil { - return nil, err +func newOffsetStage(offset int) *offsetStage { + return &offsetStage{offset: offset} +} +func (s *offsetStage) name() string { return "offset" } +func (s *offsetStage) toProto() (*pb.Pipeline_Stage, error) { + arg := &pb.Value{ValueType: &pb.Value_IntegerValue{IntegerValue: int64(s.offset)}} + return &pb.Pipeline_Stage{ + Name: s.name(), + Args: []*pb.Value{arg}, + }, nil +} + +type removeFieldsStage struct { + baseStage +} + +func newRemoveFieldsStage(fieldpaths ...any) (*removeFieldsStage, error) { + fields := make([]Expr, len(fieldpaths)) + for i, fp := range fieldpaths { + switch v := fp.(type) { + case string: + fields[i] = FieldOf(v) + case FieldPath: + fields[i] = FieldOfPath(v) + default: + return nil, errInvalidArg(fp, "string", "FieldPath") + } + } + args := make([]*pb.Value, len(fields)) + for i, f := range fields { + pb, err := f.toProto() + if err != nil { + return nil, err + } + args[i] = pb } + return &removeFieldsStage{baseStage{ + stageName: stageNameRemoveFields, + stagePb: &pb.Pipeline_Stage{ + Name: stageNameRemoveFields, + Args: args, + }, + }}, nil +} - mapVal, err := projectionsToMapValue(selectables) +type replaceStage struct { + baseStage +} + +func newReplaceStage(fieldOrSelectable any) (*replaceStage, error) { + var expr Expr + switch v := fieldOrSelectable.(type) { + case string: + expr = FieldOf(v) + case FieldPath: + expr = FieldOfPath(v) + case Selectable: + _, expr = v.getSelectionDetails() + default: + return nil, errInvalidArg(fieldOrSelectable, "string", "FieldPath", "Selectable") + } + exprPb, err := expr.toProto() if err != nil { return nil, err } + return &replaceStage{baseStage{ + stageName: stageNameReplace, + stagePb: &pb.Pipeline_Stage{ + Name: stageNameReplace, + Args: []*pb.Value{exprPb, &pb.Value{ValueType: &pb.Value_StringValue{StringValue: "full_replace"}}}, + }, + }}, nil +} - return &selectStage{ +type sampleStage struct { + baseStage +} + +func newSampleStage(spec *SampleSpec) (*sampleStage, error) { + var sizePb *pb.Value + switch v := spec.Size.(type) { + case int: + sizePb = &pb.Value{ValueType: &pb.Value_IntegerValue{IntegerValue: int64(v)}} + case int64: + sizePb = &pb.Value{ValueType: &pb.Value_IntegerValue{IntegerValue: v}} + case float64: + sizePb = &pb.Value{ValueType: &pb.Value_DoubleValue{DoubleValue: v}} + default: + return nil, fmt.Errorf("firestore: invalid type for sample size: %T", spec.Size) + } + modePb := &pb.Value{ValueType: &pb.Value_StringValue{StringValue: string(spec.Mode)}} + return &sampleStage{baseStage{ + stageName: stageNameSample, stagePb: &pb.Pipeline_Stage{ - Name: stageNameSelect, - Args: []*pb.Value{mapVal}, + Name: stageNameSample, + Args: []*pb.Value{sizePb, modePb}, }, - }, nil + }}, nil } -func (s *selectStage) name() string { return "select" } -func (s *selectStage) toProto() (*pb.Pipeline_Stage, error) { return s.stagePb, nil } -// addFieldsStage is the internal representation of a AddFields stage. 
-type addFieldsStage struct { - stagePb *pb.Pipeline_Stage +type selectStage struct { + baseStage } -func newAddFieldsStage(selectables ...Selectable) (*addFieldsStage, error) { - mapVal, err := projectionsToMapValue(selectables) +func newSelectStage(fieldsOrSelectables ...any) (*selectStage, error) { + stagePb, err := newProjectionStage(stageNameSelect, fieldsOrSelectables...) if err != nil { return nil, err } + return &selectStage{baseStage{stageName: stageNameSelect, stagePb: stagePb}}, nil +} - return &addFieldsStage{ - stagePb: &pb.Pipeline_Stage{ - Name: stageNameAddFields, - Args: []*pb.Value{mapVal}, - }, +type sortStage struct { + orders []Ordering +} + +func newSortStage(orders ...Ordering) *sortStage { + return &sortStage{orders: orders} +} +func (s *sortStage) name() string { return "sort" } +func (s *sortStage) toProto() (*pb.Pipeline_Stage, error) { + sortOrders := make([]*pb.Value, len(s.orders)) + for i, so := range s.orders { + fieldPb, err := so.Expr.toProto() + if err != nil { + return nil, err + } + sortOrders[i] = &pb.Value{ + ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + "direction": { + ValueType: &pb.Value_StringValue{ + StringValue: string(so.Direction), + }, + }, + "expression": fieldPb, + }, + }, + }, + } + } + return &pb.Pipeline_Stage{ + Name: s.name(), + Args: sortOrders, }, nil } -func (s *addFieldsStage) name() string { return stageNameAddFields } -func (s *addFieldsStage) toProto() (*pb.Pipeline_Stage, error) { return s.stagePb, nil } -type whereStage struct { - stagePb *pb.Pipeline_Stage +type unionStage struct { + baseStage } -func newWhereStage(condition BooleanExpr) (*whereStage, error) { - argsPb, err := condition.toProto() +func newUnionStage(other *Pipeline) (*unionStage, error) { + otherPb, err := other.toProto() if err != nil { return nil, err } - return &whereStage{ + return &unionStage{baseStage{ + stageName: stageNameUnion, stagePb: &pb.Pipeline_Stage{ - Name: stageNameWhere, - Args: []*pb.Value{argsPb}, + Name: stageNameUnion, + Args: []*pb.Value{ + {ValueType: &pb.Value_PipelineValue{PipelineValue: otherPb}}, + }, }, - }, nil + }}, nil } -func (s *whereStage) name() string { return stageNameWhere } -func (s *whereStage) toProto() (*pb.Pipeline_Stage, error) { return s.stagePb, nil } - -type aggregateStage struct { - stagePb *pb.Pipeline_Stage +type unnestStage struct { + baseStage } -func newAggregateStage(a *AggregateSpec) (*aggregateStage, error) { - if a.err != nil { - return nil, a.err - } - targetsPb, err := aliasedAggregatesToMapValue(a.accTargets) +func newUnnestStage(fieldExpr Expr, alias string, opts *UnnestOptions) (*unnestStage, error) { + exprPb, err := fieldExpr.toProto() if err != nil { return nil, err } - - groupsPb, err := projectionsToMapValue(a.groups) + aliasPb, err := FieldOf(alias).toProto() if err != nil { return nil, err } - - return &aggregateStage{ + var optionsPb map[string]*pb.Value + if opts != nil && opts.IndexField != nil { + var indexFieldExpr Expr + switch v := opts.IndexField.(type) { + case FieldPath: + indexFieldExpr = FieldOfPath(v) + case string: + indexFieldExpr = FieldOf(v) + default: + return nil, errInvalidArg(opts.IndexField, "string", "FieldPath") + } + indexPb, err := indexFieldExpr.toProto() + if err != nil { + return nil, err + } + optionsPb = make(map[string]*pb.Value) + optionsPb["index_field"] = indexPb + } + return &unnestStage{baseStage{ + stageName: stageNameUnnest, stagePb: &pb.Pipeline_Stage{ - Name: stageNameAggregate, - Args: []*pb.Value{ - 
targetsPb, - groupsPb, - }, + Name: stageNameUnnest, + Args: []*pb.Value{exprPb, aliasPb}, + Options: optionsPb, }, - }, nil + }}, nil +} + +func newUnnestStageFromAny(fieldOrSelectable any) (*unnestStage, error) { + var expr Expr + var alias string + switch v := fieldOrSelectable.(type) { + case string: + expr = FieldOf(v) + alias = v + case Selectable: + alias, expr = v.getSelectionDetails() + default: + return nil, errInvalidArg(fieldOrSelectable, "string", "Selectable") + } + return newUnnestStage(expr, alias, nil) +} + +type whereStage struct { + baseStage +} + +// newUnaryStage is a helper for creating pipeline stages that take a single +// proto as an argument. +func newUnaryStage(name string, val *pb.Value) *pb.Pipeline_Stage { + return &pb.Pipeline_Stage{ + Name: name, + Args: []*pb.Value{val}, + } +} + +func newWhereStage(condition BooleanExpr) (*whereStage, error) { + argsPb, err := condition.toProto() + if err != nil { + return nil, err + } + return &whereStage{baseStage{ + stageName: stageNameWhere, + stagePb: newUnaryStage(stageNameWhere, argsPb), + }}, nil } -func (s *aggregateStage) name() string { return stageNameAggregate } -func (s *aggregateStage) toProto() (*pb.Pipeline_Stage, error) { return s.stagePb, nil } diff --git a/firestore/pipeline_stage_test.go b/firestore/pipeline_stage_test.go new file mode 100644 index 000000000000..86e577238fb8 --- /dev/null +++ b/firestore/pipeline_stage_test.go @@ -0,0 +1,431 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+package firestore
+
+import (
+	"context"
+	"testing"
+
+	pb "cloud.google.com/go/firestore/apiv1/firestorepb"
+	"cloud.google.com/go/internal/testutil"
+)
+
+func TestPipelineStages(t *testing.T) {
+	docRef1 := &DocumentRef{
+		Path:      "projects/projectID/databases/(default)/documents/collection/doc1",
+		shortPath: "collection/doc1",
+	}
+	docRef2 := &DocumentRef{
+		Path:      "projects/projectID/databases/(default)/documents/collection/doc2",
+		shortPath: "collection/doc2",
+	}
+
+	testcases := []struct {
+		desc  string
+		stage pipelineStage
+		want  *pb.Pipeline_Stage
+	}{
+		{
+			desc:  "inputStageCollection",
+			stage: newInputStageCollection("my-collection"),
+			want: &pb.Pipeline_Stage{
+				Name: "collection",
+				Args: []*pb.Value{{ValueType: &pb.Value_ReferenceValue{ReferenceValue: "/my-collection"}}},
+			},
+		},
+		{
+			desc:  "inputStageCollectionGroup",
+			stage: newInputStageCollectionGroup("ancestor/path", "my-collection-group"),
+			want: &pb.Pipeline_Stage{
+				Name: "collection_group",
+				Args: []*pb.Value{
+					{ValueType: &pb.Value_ReferenceValue{ReferenceValue: "ancestor/path"}},
+					{ValueType: &pb.Value_StringValue{StringValue: "my-collection-group"}},
+				},
+			},
+		},
+		{
+			desc:  "inputStageDatabase",
+			stage: newInputStageDatabase(),
+			want:  &pb.Pipeline_Stage{Name: "database"},
+		},
+		{
+			desc:  "inputStageDocuments",
+			stage: newInputStageDocuments(docRef1, docRef2),
+			want: &pb.Pipeline_Stage{
+				Name: "documents",
+				Args: []*pb.Value{
+					{ValueType: &pb.Value_ReferenceValue{ReferenceValue: "/collection/doc1"}},
+					{ValueType: &pb.Value_ReferenceValue{ReferenceValue: "/collection/doc2"}},
+				},
+			},
+		},
+		{
+			desc:  "limitStage",
+			stage: newLimitStage(10),
+			want: &pb.Pipeline_Stage{
+				Name: "limit",
+				Args: []*pb.Value{{ValueType: &pb.Value_IntegerValue{IntegerValue: 10}}},
+			},
+		},
+		{
+			desc:  "offsetStage",
+			stage: newOffsetStage(5),
+			want: &pb.Pipeline_Stage{
+				Name: "offset",
+				Args: []*pb.Value{{ValueType: &pb.Value_IntegerValue{IntegerValue: 5}}},
+			},
+		},
+		{
+			desc:  "sortStage",
+			stage: newSortStage(Ascending(FieldOf("name")), Descending(FieldOf("age"))),
+			want: &pb.Pipeline_Stage{
+				Name: "sort",
+				Args: []*pb.Value{
+					{ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{Fields: map[string]*pb.Value{
+						"direction":  {ValueType: &pb.Value_StringValue{StringValue: "ascending"}},
+						"expression": {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "name"}},
+					}}}},
+					{ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{Fields: map[string]*pb.Value{
+						"direction":  {ValueType: &pb.Value_StringValue{StringValue: "descending"}},
+						"expression": {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "age"}},
+					}}}},
+				},
+			},
+		},
+	}
+
+	for _, tc := range testcases {
+		t.Run(tc.desc, func(t *testing.T) {
+			got, err := tc.stage.toProto()
+			if err != nil {
+				t.Fatalf("toProto() failed: %v", err)
+			}
+			if diff := testutil.Diff(got, tc.want); diff != "" {
+				t.Errorf("toProto() returned diff (-got +want): %s", diff)
+			}
+		})
+	}
+}
+
+func TestSelectStage(t *testing.T) {
+	stage, err := newSelectStage("name", FieldOf("age"), Add(FieldOf("score"), 10).As("new_score"))
+	if err != nil {
+		t.Fatalf("newSelectStage() failed: %v", err)
+	}
+
+	want := &pb.Pipeline_Stage{
+		Name: "select",
+		Args: []*pb.Value{
+			{ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{Fields: map[string]*pb.Value{
+				"name": {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "name"}},
+				"age":  {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "age"}},
+				"new_score": {ValueType: &pb.Value_FunctionValue{FunctionValue: &pb.Function{
+					Name: "add",
+					Args: []*pb.Value{
+						{ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "score"}},
+						{ValueType: &pb.Value_IntegerValue{IntegerValue: 10}},
+					},
+				}}},
+			}}}},
+		},
+	}
+
+	got, err := stage.toProto()
+	if err != nil {
+		t.Fatalf("toProto() failed: %v", err)
+	}
+	if diff := testutil.Diff(got, want); diff != "" {
+		t.Errorf("toProto() returned diff (-got +want): %s", diff)
+	}
+}
+
+func TestWhereStage(t *testing.T) {
+	condition := Equal(FieldOf("genre"), "Sci-Fi")
+	stage, err := newWhereStage(condition)
+	if err != nil {
+		t.Fatalf("newWhereStage() failed: %v", err)
+	}
+
+	want := &pb.Pipeline_Stage{
+		Name: "where",
+		Args: []*pb.Value{
+			{ValueType: &pb.Value_FunctionValue{FunctionValue: &pb.Function{
+				Name: "equal",
+				Args: []*pb.Value{
+					{ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "genre"}},
+					{ValueType: &pb.Value_StringValue{StringValue: "Sci-Fi"}},
+				},
+			}}},
+		},
+	}
+
+	got, err := stage.toProto()
+	if err != nil {
+		t.Fatalf("toProto() failed: %v", err)
+	}
+	if diff := testutil.Diff(got, want); diff != "" {
+		t.Errorf("toProto() returned diff (-got +want): %s", diff)
+	}
+}
+
+func TestAddFieldsStage(t *testing.T) {
+	stage, err := newAddFieldsStage(FieldOf("name").As("name"), Add(FieldOf("score"), 10).As("new_score"))
+	if err != nil {
+		t.Fatalf("newAddFieldsStage() failed: %v", err)
+	}
+
+	want := &pb.Pipeline_Stage{
+		Name: "add_fields",
+		Args: []*pb.Value{
+			{ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{Fields: map[string]*pb.Value{
+				"name": {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "name"}},
+				"new_score": {ValueType: &pb.Value_FunctionValue{FunctionValue: &pb.Function{
+					Name: "add",
+					Args: []*pb.Value{
+						{ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "score"}},
+						{ValueType: &pb.Value_IntegerValue{IntegerValue: 10}},
+					},
+				}}},
+			}}}},
+		},
+	}
+
+	got, err := stage.toProto()
+	if err != nil {
+		t.Fatalf("toProto() failed: %v", err)
+	}
+	if diff := testutil.Diff(got, want); diff != "" {
+		t.Errorf("toProto() returned diff (-got +want): %s", diff)
+	}
+}
+
+func TestAggregateStage(t *testing.T) {
+	spec := NewAggregateSpec(Sum("score").As("total_score")).WithGroups("category")
+	stage, err := newAggregateStage(spec)
+	if err != nil {
+		t.Fatalf("newAggregateStage() failed: %v", err)
+	}
+
+	want := &pb.Pipeline_Stage{
+		Name: "aggregate",
+		Args: []*pb.Value{
+			{ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{Fields: map[string]*pb.Value{
+				"total_score": {ValueType: &pb.Value_FunctionValue{FunctionValue: &pb.Function{
+					Name: "sum",
+					Args: []*pb.Value{
+						{ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "score"}},
+					},
+				}}},
+			}}}},
+			{ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{Fields: map[string]*pb.Value{
+				"category": {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "category"}},
+			}}}},
+		},
+	}
+
+	got, err := stage.toProto()
+	if err != nil {
+		t.Fatalf("toProto() failed: %v", err)
+	}
+	if diff := testutil.Diff(got, want); diff != "" {
+		t.Errorf("toProto() returned diff (-got +want): %s", diff)
+	}
+}
+
+func TestDistinctStage(t *testing.T) {
+	stage, err := newDistinctStage("category", FieldOf("author"))
+	if err != nil {
+		t.Fatalf("newDistinctStage() failed: %v", err)
+	}
+
+	want := &pb.Pipeline_Stage{
+		Name: "distinct",
+		Args: []*pb.Value{
+			{ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{Fields: map[string]*pb.Value{
+				"category": {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "category"}},
+				"author":   {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "author"}},
+			}}}},
+		},
+	}
+
+	got, err := stage.toProto()
+	if err != nil {
+		t.Fatalf("toProto() failed: %v", err)
+	}
+	if diff := testutil.Diff(got, want); diff != "" {
+		t.Errorf("toProto() returned diff (-got +want): %s", diff)
+	}
+}
+
+func TestFindNearestStage(t *testing.T) {
+	limit := 10
+	distanceField := "distance"
+	stage, err := newFindNearestStage("embedding", []float64{1, 2, 3}, PipelineDistanceMeasureEuclidean, &PipelineFindNearestOptions{Limit: &limit, DistanceField: &distanceField})
+	if err != nil {
+		t.Fatalf("newFindNearestStage() failed: %v", err)
+	}
+
+	want := &pb.Pipeline_Stage{
+		Name: "find_nearest",
+		Args: []*pb.Value{
+			{ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "embedding"}},
+			vectorToProtoValue([]float64{1, 2, 3}),
+			{ValueType: &pb.Value_StringValue{StringValue: "euclidean"}},
+		},
+		Options: map[string]*pb.Value{
+			"limit":          {ValueType: &pb.Value_IntegerValue{IntegerValue: 10}},
+			"distance_field": {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "distance"}},
+		},
+	}
+
+	got, err := stage.toProto()
+	if err != nil {
+		t.Fatalf("toProto() failed: %v", err)
+	}
+	if diff := testutil.Diff(got, want); diff != "" {
+		t.Errorf("toProto() returned diff (-got +want): %s", diff)
+	}
+}
+
+func TestRemoveFieldsStage(t *testing.T) {
+	stage, err := newRemoveFieldsStage("price", FieldPath{"author", "name"})
+	if err != nil {
+		t.Fatalf("newRemoveFieldsStage() failed: %v", err)
+	}
+
+	want := &pb.Pipeline_Stage{
+		Name: "remove_fields",
+		Args: []*pb.Value{
+			{ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "price"}},
+			{ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "author.name"}},
+		},
+	}
+
+	got, err := stage.toProto()
+	if err != nil {
+		t.Fatalf("toProto() failed: %v", err)
+	}
+	if diff := testutil.Diff(got, want); diff != "" {
+		t.Errorf("toProto() returned diff (-got +want): %s", diff)
+	}
+}
+
+func TestReplaceStage(t *testing.T) {
+	stage, err := newReplaceStage("metadata")
+	if err != nil {
+		t.Fatalf("newReplaceStage() failed: %v", err)
+	}
+
+	want := &pb.Pipeline_Stage{
+		Name: "replace_with",
+		Args: []*pb.Value{
+			{ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "metadata"}},
+			{ValueType: &pb.Value_StringValue{StringValue: "full_replace"}},
+		},
+	}
+
+	got, err := stage.toProto()
+	if err != nil {
+		t.Fatalf("toProto() failed: %v", err)
+	}
+	if diff := testutil.Diff(got, want); diff != "" {
+		t.Errorf("toProto() returned diff (-got +want): %s", diff)
+	}
+}
+
+func TestSampleStage(t *testing.T) {
+	spec := SampleByDocuments(100)
+	stage, err := newSampleStage(spec)
+	if err != nil {
+		t.Fatalf("newSampleStage() failed: %v", err)
+	}
+
+	want := &pb.Pipeline_Stage{
+		Name: "sample",
+		Args: []*pb.Value{
+			{ValueType: &pb.Value_IntegerValue{IntegerValue: 100}},
+			{ValueType: &pb.Value_StringValue{StringValue: "documents"}},
+		},
+	}
+
+	got, err := stage.toProto()
+	if err != nil {
+		t.Fatalf("toProto() failed: %v", err)
+	}
+	if diff := testutil.Diff(got, want); diff != "" {
+		t.Errorf("toProto() returned diff (-got +want): %s", diff)
+	}
+}
+
+func TestUnionStage(t *testing.T) {
+	client, err := NewClient(context.Background(), "projectID")
+	if err != nil {
+		t.Fatalf("NewClient: %v", err)
+	}
+	otherPipeline := newPipeline(client, newInputStageCollection("other_collection"))
+	stage, err := newUnionStage(otherPipeline)
+	if err != nil {
+		t.Fatalf("newUnionStage() failed: %v", err)
+	}
+
+	want := &pb.Pipeline_Stage{
+		Name: "union",
+		Args: []*pb.Value{
+			{ValueType: &pb.Value_PipelineValue{PipelineValue: &pb.Pipeline{
+				Stages: []*pb.Pipeline_Stage{
+					{
+						Name: "collection",
+						Args: []*pb.Value{{ValueType: &pb.Value_ReferenceValue{ReferenceValue: "/other_collection"}}},
+					},
+				},
+			}}},
+		},
+	}
+
+	got, err := stage.toProto()
+	if err != nil {
+		t.Fatalf("toProto() failed: %v", err)
+	}
+	if diff := testutil.Diff(got, want); diff != "" {
+		t.Errorf("toProto() returned diff (-got +want): %s", diff)
+	}
+}
+
+func TestUnnestStage(t *testing.T) {
+	stage, err := newUnnestStage(FieldOf("tags"), "tag", &UnnestOptions{IndexField: "index"})
+	if err != nil {
+		t.Fatalf("newUnnestStage() failed: %v", err)
+	}
+
+	want := &pb.Pipeline_Stage{
+		Name: "unnest",
+		Args: []*pb.Value{
+			{ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "tags"}},
+			{ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "tag"}},
+		},
+		Options: map[string]*pb.Value{
+			"index_field": {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "index"}},
+		},
+	}
+
+	got, err := stage.toProto()
+	if err != nil {
+		t.Fatalf("toProto() failed: %v", err)
+	}
+	if diff := testutil.Diff(got, want); diff != "" {
+		t.Errorf("toProto() returned diff (-got +want): %s", diff)
+	}
+}
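TestUnnestStage above exercises the explicit-alias constructor; the from-any variant has two accepted argument forms. A minimal sketch of both, in package firestore, using the unexported helpers from pipeline_stages.go (the helper name sketchUnnest is hypothetical, not part of the change):

package firestore

import "fmt"

// sketchUnnest illustrates the two argument kinds newUnnestStageFromAny
// accepts: a bare field name, where the alias defaults to the name, and a
// Selectable, where the explicit alias is used.
func sketchUnnest() error {
	byName, err := newUnnestStageFromAny("tags") // alias stays "tags"
	if err != nil {
		return err
	}
	byAlias, err := newUnnestStageFromAny(FieldOf("tags").As("tag")) // alias "tag"
	if err != nil {
		return err
	}
	fmt.Println(byName.name(), byAlias.name()) // both stages are named "unnest"
	return nil
}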
diff --git a/firestore/pipeline_test.go b/firestore/pipeline_test.go
index ee81322e8163..dbf9de74686e 100644
--- a/firestore/pipeline_test.go
+++ b/firestore/pipeline_test.go
@@ -150,3 +150,255 @@ func TestPipeline_ToExecutePipelineRequest(t *testing.T) {
 		t.Errorf("Limit stage mismatch (-want +got):\n%s", diff)
 	}
 }
+
+func TestPipeline_Sort(t *testing.T) {
+	client := newTestClient()
+	ps := &PipelineSource{client: client}
+	p := ps.Collection("users").Sort(Ordering{Expr: FieldOf("age"), Direction: OrderingDesc})
+
+	req, err := p.toExecutePipelineRequest()
+	if err != nil {
+		t.Fatalf("p.toExecutePipelineRequest() failed: %v", err)
+	}
+
+	stages := req.GetStructuredPipeline().GetPipeline().GetStages()
+	if len(stages) != 2 {
+		t.Fatalf("Expected 2 stages in proto, got %d", len(stages))
+	}
+
+	wantSortStage := &pb.Pipeline_Stage{
+		Name: "sort",
+		Args: []*pb.Value{
+			{
+				ValueType: &pb.Value_MapValue{
+					MapValue: &pb.MapValue{
+						Fields: map[string]*pb.Value{
+							"expression": {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "age"}},
+							"direction":  {ValueType: &pb.Value_StringValue{StringValue: "descending"}},
+						},
+					},
+				},
+			},
+		},
+	}
+	if diff := cmp.Diff(wantSortStage, stages[1], protocmp.Transform()); diff != "" {
+		t.Errorf("toExecutePipelineRequest() mismatch for sort stage (-want +got):\n%s", diff)
+	}
+}
+
+func TestPipeline_Offset(t *testing.T) {
+	client := newTestClient()
+	ps := &PipelineSource{client: client}
+	p := ps.Collection("users").Offset(20)
+
+	req, err := p.toExecutePipelineRequest()
+	if err != nil {
+		t.Fatalf("p.toExecutePipelineRequest() failed: %v", err)
+	}
+
+	stages := req.GetStructuredPipeline().GetPipeline().GetStages()
+	if len(stages) != 2 {
+		t.Fatalf("Expected 2 stages in proto, got %d", len(stages))
+	}
+
+	wantOffsetStage := &pb.Pipeline_Stage{
+		Name: "offset",
+		Args: []*pb.Value{{ValueType: &pb.Value_IntegerValue{IntegerValue: 20}}},
+	}
+	if diff := cmp.Diff(wantOffsetStage, stages[1], protocmp.Transform()); diff != "" {
+		t.Errorf("toExecutePipelineRequest() mismatch for offset stage (-want +got):\n%s", diff)
+	}
+}
+
+func TestPipeline_Select(t *testing.T) {
+	client := newTestClient()
+	ps := &PipelineSource{client: client}
+	p := ps.Collection("users").Select("name", FieldOf("age"), Add(FieldOf("score"), 10).As("new_score"))
+
+	req, err := p.toExecutePipelineRequest()
+	if err != nil {
+		t.Fatalf("p.toExecutePipelineRequest() failed: %v", err)
+	}
+
+	stages := req.GetStructuredPipeline().GetPipeline().GetStages()
+	if len(stages) != 2 {
+		t.Fatalf("Expected 2 stages in proto, got %d", len(stages))
+	}
+
+	wantSelectStage := &pb.Pipeline_Stage{
+		Name: "select",
+		Args: []*pb.Value{
+			{ValueType: &pb.Value_MapValue{
+				MapValue: &pb.MapValue{
+					Fields: map[string]*pb.Value{
+						"name": {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "name"}},
+						"age":  {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "age"}},
+						"new_score": {ValueType: &pb.Value_FunctionValue{FunctionValue: &pb.Function{
+							Name: "add",
+							Args: []*pb.Value{
+								{ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "score"}},
+								{ValueType: &pb.Value_IntegerValue{IntegerValue: 10}},
+							},
+						}}},
+					},
+				},
+			}},
+		},
+	}
+	if diff := cmp.Diff(wantSelectStage, stages[1], protocmp.Transform()); diff != "" {
+		t.Errorf("toExecutePipelineRequest() mismatch for select stage (-want +got):\n%s", diff)
+	}
+}
+
+func TestPipeline_AddFields(t *testing.T) {
+	client := newTestClient()
+	ps := &PipelineSource{client: client}
+	p := ps.Collection("users").AddFields(Add(FieldOf("score"), 10).As("new_score"))
+
+	req, err := p.toExecutePipelineRequest()
+	if err != nil {
+		t.Fatalf("p.toExecutePipelineRequest() failed: %v", err)
+	}
+
+	stages := req.GetStructuredPipeline().GetPipeline().GetStages()
+	if len(stages) != 2 {
+		t.Fatalf("Expected 2 stages in proto, got %d", len(stages))
+	}
+
+	wantAddFieldsStage := &pb.Pipeline_Stage{
+		Name: "add_fields",
+		Args: []*pb.Value{
+			{ValueType: &pb.Value_MapValue{
+				MapValue: &pb.MapValue{
+					Fields: map[string]*pb.Value{
+						"new_score": {ValueType: &pb.Value_FunctionValue{FunctionValue: &pb.Function{
+							Name: "add",
+							Args: []*pb.Value{
+								{ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "score"}},
+								{ValueType: &pb.Value_IntegerValue{IntegerValue: 10}},
+							},
+						}}},
+					},
+				},
+			}},
+		},
+	}
+	if diff := cmp.Diff(wantAddFieldsStage, stages[1], protocmp.Transform()); diff != "" {
+		t.Errorf("toExecutePipelineRequest() mismatch for addFields stage (-want +got):\n%s", diff)
+	}
+}
+
+func TestPipeline_Where(t *testing.T) {
+	client := newTestClient()
+	ps := &PipelineSource{client: client}
+	p := ps.Collection("users").Where(Equal(FieldOf("age"), 30))
+
+	req, err := p.toExecutePipelineRequest()
+	if err != nil {
+		t.Fatalf("p.toExecutePipelineRequest() failed: %v", err)
+	}
+
+	stages := req.GetStructuredPipeline().GetPipeline().GetStages()
+	if len(stages) != 2 {
+		t.Fatalf("Expected 2 stages in proto, got %d", len(stages))
+	}
+
+	wantWhereStage := &pb.Pipeline_Stage{
+		Name: "where",
+		Args: []*pb.Value{
+			{ValueType: &pb.Value_FunctionValue{FunctionValue: &pb.Function{
+				Name: "equal",
+				Args: []*pb.Value{
+					{ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "age"}},
+					{ValueType: &pb.Value_IntegerValue{IntegerValue: 30}},
+				},
+			}}},
+		},
+	}
+	if diff := cmp.Diff(wantWhereStage, stages[1], protocmp.Transform()); diff != "" {
+		t.Errorf("toExecutePipelineRequest() mismatch for where stage (-want +got):\n%s", diff)
+	}
+}
+
+func TestPipeline_Aggregate(t *testing.T) {
+	client := newTestClient()
+	ps := &PipelineSource{client: client}
+	p := ps.Collection("users").Aggregate(Sum("age").As("total_age"))
+
+	req, err := p.toExecutePipelineRequest()
+	if err != nil {
+		t.Fatalf("p.toExecutePipelineRequest() failed: %v", err)
+	}
+
+	stages := req.GetStructuredPipeline().GetPipeline().GetStages()
+	if len(stages) != 2 {
+		t.Fatalf("Expected 2 stages in proto, got %d", len(stages))
+	}
+
+	wantAggregateStage := &pb.Pipeline_Stage{
+		Name: "aggregate",
+		Args: []*pb.Value{
+			{ValueType: &pb.Value_MapValue{
+				MapValue: &pb.MapValue{
+					Fields: map[string]*pb.Value{
+						"total_age": {ValueType: &pb.Value_FunctionValue{FunctionValue: &pb.Function{
+							Name: "sum",
+							Args: []*pb.Value{
+								{ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "age"}},
+							},
+						}}},
+					},
+				},
+			}},
+			{ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{}}},
+		},
+	}
+	if diff := cmp.Diff(wantAggregateStage, stages[1], protocmp.Transform()); diff != "" {
+		t.Errorf("toExecutePipelineRequest() mismatch for aggregate stage (-want +got):\n%s", diff)
+	}
+}
+
+func TestPipeline_AggregateWithSpec(t *testing.T) {
+	client := newTestClient()
+	ps := &PipelineSource{client: client}
+	spec := NewAggregateSpec(Average("rating").As("avg_rating")).WithGroups("genre")
+	p := ps.Collection("books").AggregateWithSpec(spec)
+
+	req, err := p.toExecutePipelineRequest()
+	if err != nil {
+		t.Fatalf("p.toExecutePipelineRequest() failed: %v", err)
+	}
+
+	stages := req.GetStructuredPipeline().GetPipeline().GetStages()
+	if len(stages) != 2 {
+		t.Fatalf("Expected 2 stages in proto, got %d", len(stages))
+	}
+
+	wantAggregateStage := &pb.Pipeline_Stage{
+		Name: "aggregate",
+		Args: []*pb.Value{
+			{ValueType: &pb.Value_MapValue{
+				MapValue: &pb.MapValue{
+					Fields: map[string]*pb.Value{
+						"avg_rating": {ValueType: &pb.Value_FunctionValue{FunctionValue: &pb.Function{
+							Name: "average",
+							Args: []*pb.Value{
+								{ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "rating"}},
+							},
+						}}},
+					},
+				},
+			}},
+			{ValueType: &pb.Value_MapValue{
+				MapValue: &pb.MapValue{
+					Fields: map[string]*pb.Value{
+						"genre": {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "genre"}},
+					},
+				},
+			}},
+		},
+	}
+	if diff := cmp.Diff(wantAggregateStage, stages[1], protocmp.Transform()); diff != "" {
+		t.Errorf("toExecutePipelineRequest() mismatch for aggregate stage (-want +got):\n%s", diff)
+	}
+}
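The tests above assert full stage protos one stage at a time. A minimal sketch, reusing the same newTestClient() harness, of checking a multi-stage pipeline by stage name only (sketchComposedPipeline is a hypothetical helper, not part of the change):

package firestore

import "testing"

// sketchComposedPipeline chains several of the tested stages and verifies
// only the order of stage names in the lowered request proto.
func sketchComposedPipeline(t *testing.T) {
	client := newTestClient()
	ps := &PipelineSource{client: client}
	p := ps.Collection("books").
		Where(Equal(FieldOf("genre"), "Sci-Fi")).
		Sort(Ordering{Expr: FieldOf("rating"), Direction: OrderingDesc}).
		Limit(3)

	req, err := p.toExecutePipelineRequest()
	if err != nil {
		t.Fatalf("toExecutePipelineRequest() failed: %v", err)
	}

	want := []string{"collection", "where", "sort", "limit"}
	stages := req.GetStructuredPipeline().GetPipeline().GetStages()
	if len(stages) != len(want) {
		t.Fatalf("Expected %d stages in proto, got %d", len(want), len(stages))
	}
	for i, s := range stages {
		if s.GetName() != want[i] {
			t.Errorf("stage %d: got %q, want %q", i, s.GetName(), want[i])
		}
	}
}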