diff --git a/firestore/client.go b/firestore/client.go index beb78797197d..e96d57592804 100644 --- a/firestore/client.go +++ b/firestore/client.go @@ -47,7 +47,7 @@ const resourcePrefixHeader = "google-cloud-resource-prefix" // requestParamsHeader is routing header required to access named databases const reqParamsHeader = "x-goog-request-params" -// reqParamsHeaderVal constructs header from dbPath +// reqParamsHeaderVal constructs header from dbPath. // dbPath is of the form projects/{project_id}/databases/{database_id} func reqParamsHeaderVal(dbPath string) string { splitPath := strings.Split(dbPath, "/") @@ -181,6 +181,11 @@ func withRequestParamsHeader(ctx context.Context, requestParams string) context. return metadata.NewOutgoingContext(ctx, md) } +// Pipeline creates a PipelineSource to start building a Firestore pipeline. +func (c *Client) Pipeline() *PipelineSource { + return &PipelineSource{client: c} +} + // Collection creates a reference to a collection with the given path. // A path is a sequence of IDs separated by slashes. // diff --git a/firestore/integration_test.go b/firestore/integration_test.go index 52591e01f45f..974e06c732a5 100644 --- a/firestore/integration_test.go +++ b/firestore/integration_test.go @@ -47,16 +47,29 @@ import ( "google.golang.org/protobuf/types/known/structpb" ) -func TestMain(m *testing.M) { - databaseIDs := []string{DefaultDatabaseID} - databasesStr, ok := os.LookupEnv(envDatabases) - if ok { - databaseIDs = append(databaseIDs, strings.Split(databasesStr, ",")...) - } +type firestoreEdition int + +const ( + editionStandard firestoreEdition = iota // 0 + editionEnterprise // 1 +) + +const ( + envProjID = "GCLOUD_TESTS_GOLANG_FIRESTORE_PROJECT_ID" + envPrivateKey = "GCLOUD_TESTS_GOLANG_FIRESTORE_KEY" + envDatabases = "GCLOUD_TESTS_GOLANG_FIRESTORE_DATABASES" + envEnterpriseDatabases = "GCLOUD_TESTS_GOLANG_FIRESTORE_ENTERPRISE_DATABASES" + envEmulator = "FIRESTORE_EMULATOR_HOST" + indexBuilding = "index is currently building" + databaseIDKey = "databaseID" + firestoreEditionKey = "edition" +) +func TestMain(m *testing.M) { testParams = make(map[string]interface{}) - for _, databaseID := range databaseIDs { - testParams["databaseID"] = databaseID + for databaseID, edition := range parseDatabases() { + testParams[databaseIDKey] = databaseID + testParams[firestoreEditionKey] = edition initIntegrationTest() status := m.Run() if status != 0 { @@ -68,13 +81,26 @@ func TestMain(m *testing.M) { os.Exit(0) } -const ( - envProjID = "GCLOUD_TESTS_GOLANG_FIRESTORE_PROJECT_ID" - envPrivateKey = "GCLOUD_TESTS_GOLANG_FIRESTORE_KEY" - envDatabases = "GCLOUD_TESTS_GOLANG_FIRESTORE_DATABASES" - envEmulator = "FIRESTORE_EMULATOR_HOST" - indexBuilding = "index is currently building" -) +func parseDatabases() map[string]firestoreEdition { + databases := map[string]firestoreEdition{ + DefaultDatabaseID: editionStandard, + } + + databasesStr, ok := os.LookupEnv(envDatabases) + if ok { + for _, databaseID := range strings.Split(databasesStr, ",") { + databases[databaseID] = editionStandard + } + } + + databasesStr, ok = os.LookupEnv(envEnterpriseDatabases) + if ok { + for _, databaseID := range strings.Split(databasesStr, ",") { + databases[databaseID] = editionEnterprise + } + } + return databases +} var ( iClient *Client @@ -88,7 +114,7 @@ var ( ) func initIntegrationTest() { - databaseID := testParams["databaseID"].(string) + databaseID := testParams[databaseIDKey].(string) log.Printf("Setting up tests to run on databaseID: %q\n", databaseID) flag.Parse() // needed for 
testing.Short() if testing.Short() { @@ -2730,12 +2756,12 @@ func TestIntegration_NewClientWithDatabase(t *testing.T) { }{ { desc: "Success", - dbName: testParams["databaseID"].(string), + dbName: testParams[databaseIDKey].(string), wantErr: false, }, { desc: "Error from NewClient bubbled to NewClientWithDatabase", - dbName: testParams["databaseID"].(string), + dbName: testParams[databaseIDKey].(string), wantErr: true, opt: []option.ClientOption{option.WithCredentialsFile("non existent filepath")}, }, diff --git a/firestore/pipeline.go b/firestore/pipeline.go new file mode 100644 index 000000000000..ad9a4e05ecb4 --- /dev/null +++ b/firestore/pipeline.go @@ -0,0 +1,616 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package firestore + +import ( + "context" + "fmt" + + pb "cloud.google.com/go/firestore/apiv1/firestorepb" +) + +// Pipeline provides a flexible and expressive framework for building complex data +// transformation and query pipelines for Firestore. +// +// A pipeline takes data sources, such as Firestore collections or collection groups, and applies +// a series of stages that are chained together. Each stage takes the output from the previous stage +// (or the data source) and produces an output for the next stage (or as the final output of the +// pipeline). +// +// Expressions can be used within +// each stage to filter and transform the data passing through it. +// +// NOTE: The chained stages do not prescribe exactly how Firestore will execute the pipeline. +// Instead, Firestore only guarantees that the result is the same as if the chained stages were +// executed in order. +type Pipeline struct { + c *Client + stages []pipelineStage + readSettings *readSettings + executeSettings *executeSettings + tx *Transaction + err error +} + +func newPipeline(client *Client, initialStage pipelineStage) *Pipeline { + return &Pipeline{ + c: client, + stages: []pipelineStage{initialStage}, + readSettings: &readSettings{}, + executeSettings: &executeSettings{}, + } +} + +// executeSettings holds the options for executing a pipeline. +type executeSettings struct { + ExplainOptions *executeExplainOptions + IndexMode string +} + +// ExecuteOption is an option for executing a pipeline query. +type ExecuteOption interface { + apply(*executeSettings) +} + +type funcExecuteOption struct { + f func(*executeSettings) +} + +func (fdo *funcExecuteOption) apply(do *executeSettings) { + fdo.f(do) +} + +func newFuncExecuteOption(f func(*executeSettings)) *funcExecuteOption { + return &funcExecuteOption{ + f: f, + } +} + +// ExplainMode is the execution mode for pipeline explain. +type ExplainMode string + +const ( + // ExplainModeAnalyze both plans and executes the query. + ExplainModeAnalyze ExplainMode = "analyze" +) + +// executeExplainOptions are options for explaining a pipeline execution. +type executeExplainOptions struct { + Mode ExplainMode +} + +// WithExplainMode sets the execution mode for pipeline explain.
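+// +// Example (illustrative usage sketch; assumes a "books" collection exists): +// +// snapshot := client.Pipeline().Collection("books"). +// Limit(10). +// WithExecuteOptions(WithExplainMode(ExplainModeAnalyze)). +// Execute(ctx)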
+func WithExplainMode(mode ExplainMode) ExecuteOption { + return newFuncExecuteOption(func(eo *executeSettings) { + eo.ExplainOptions = &executeExplainOptions{Mode: mode} + }) +} + +// Execute executes the pipeline and returns a snapshot of the results. +func (p *Pipeline) Execute(ctx context.Context) *PipelineSnapshot { + ctx = withResourceHeader(ctx, p.c.path()) + ctx = withRequestParamsHeader(ctx, reqParamsHeaderVal(p.c.path())) + + return &PipelineSnapshot{ + iter: &PipelineResultIterator{ + iter: newStreamPipelineResultIterator(ctx, p), + }, + } +} + +func (p *Pipeline) toExecutePipelineRequest() (*pb.ExecutePipelineRequest, error) { + pipelinePb, err := p.toProto() + if err != nil { + return nil, err + } + + options := make(map[string]*pb.Value) + if p.executeSettings.ExplainOptions != nil { + options["explain_options"] = &pb.Value{ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + "mode": {ValueType: &pb.Value_StringValue{StringValue: string(p.executeSettings.ExplainOptions.Mode)}}, + }, + }}} + } + if p.executeSettings.IndexMode != "" { + options["index_mode"] = &pb.Value{ValueType: &pb.Value_StringValue{StringValue: p.executeSettings.IndexMode}} + } + + req := &pb.ExecutePipelineRequest{ + Database: p.c.path(), + PipelineType: &pb.ExecutePipelineRequest_StructuredPipeline{ + StructuredPipeline: &pb.StructuredPipeline{ + Pipeline: pipelinePb, + Options: options, + }, + }, + } + + // Note that transaction ID and other consistency selectors are mutually exclusive. + // We respect the transaction first, any read options passed by the caller second, + // and any read options stored in the client third. + if p.tx != nil { + req.ConsistencySelector = &pb.ExecutePipelineRequest_Transaction{Transaction: p.tx.id} + } else if rt, hasOpts := parseReadTime(p.c, p.readSettings); hasOpts { + req.ConsistencySelector = &pb.ExecutePipelineRequest_ReadTime{ReadTime: rt} + } + return req, nil +} + +func (p *Pipeline) toProto() (*pb.Pipeline, error) { + if p.err != nil { + return nil, p.err + } + protoStages := make([]*pb.Pipeline_Stage, len(p.stages)) + for i, s := range p.stages { + ps, err := s.toProto() + if err != nil { + return nil, fmt.Errorf("firestore: error converting stage %q to proto: %w", s.name(), err) + } + protoStages[i] = ps + } + return &pb.Pipeline{Stages: protoStages}, nil +} + +func (p *Pipeline) copy() *Pipeline { + newP := &Pipeline{ + c: p.c, + stages: make([]pipelineStage, len(p.stages)), + readSettings: &readSettings{}, + executeSettings: &executeSettings{}, + tx: p.tx, + err: p.err, + } + copy(newP.stages, p.stages) + *newP.readSettings = *p.readSettings + *newP.executeSettings = *p.executeSettings + return newP +} + +// WithReadOptions specifies constraints for accessing documents from the database, +// such as ReadTime. +func (p *Pipeline) WithReadOptions(opts ...ReadOption) *Pipeline { + newP := p.copy() + for _, opt := range opts { + if opt != nil { + opt.apply(newP.readSettings) + } + } + return newP +} + +// WithExecuteOptions specifies options for executing a pipeline. +func (p *Pipeline) WithExecuteOptions(opts ...ExecuteOption) *Pipeline { + newP := p.copy() + for _, opt := range opts { + if opt != nil { + opt.apply(newP.executeSettings) + } + } + return newP +} + +// append creates a new Pipeline by adding a stage to the current one. 
+func (p *Pipeline) append(s pipelineStage) *Pipeline { + if p.err != nil { + return p + } + newP := p.copy() + newP.stages = append(newP.stages, s) + return newP +} + +// Limit limits the maximum number of documents returned by previous stages. +func (p *Pipeline) Limit(limit int) *Pipeline { + return p.append(newLimitStage(limit)) +} + +// OrderingDirection is the sort direction for pipeline result ordering. +type OrderingDirection string + +const ( + // OrderingAsc sorts results from smallest to largest. + OrderingAsc OrderingDirection = "ascending" + + // OrderingDesc sorts results from largest to smallest. + OrderingDesc OrderingDirection = "descending" +) + +// Ordering specifies the field and direction for sorting. +type Ordering struct { + Expr Expression + Direction OrderingDirection +} + +// Ascending creates an Ordering for ascending sort direction. +func Ascending(expr Expression) Ordering { + return Ordering{Expr: expr, Direction: OrderingAsc} +} + +// Descending creates an Ordering for descending sort direction. +func Descending(expr Expression) Ordering { + return Ordering{Expr: expr, Direction: OrderingDesc} +} + +// Sort sorts the documents by the given fields and directions. +func (p *Pipeline) Sort(orders ...Ordering) *Pipeline { + return p.append(newSortStage(orders...)) +} + +// Offset skips the first `offset` number of documents from the results of previous stages. +// +// This stage is useful for implementing pagination in your pipelines, allowing you to retrieve +// results in chunks. It is typically used in conjunction with [*Pipeline.Limit] to control the +// size of each page. +// +// Example: +// Retrieve the second page of 20 results +// +// client.Pipeline().Collection("books"). +// Offset(20). // Skip the first 20 results +// Limit(20) // Take the next 20 results +func (p *Pipeline) Offset(offset int) *Pipeline { + return p.append(newOffsetStage(offset)) +} + +// Select selects or creates a set of fields from the outputs of previous stages. +// The selected fields are defined using field path strings, [FieldPath] or [Selectable] expressions. +// [Selectable] expressions can be: +// - Field: References an existing field. +// - Function: Represents the result of a function with an assigned alias name using [FunctionExpression.As]. +// +// Example: +// +// client.Pipeline().Collection("users").Select("info.email") +// client.Pipeline().Collection("users").Select(FieldOf("info.email")) +// client.Pipeline().Collection("users").Select(FieldOf(FieldPath{"info", "email"})) +// client.Pipeline().Collection("users").Select(Add("age", 5).As("agePlus5")) +func (p *Pipeline) Select(fieldpathsOrSelectables ...any) *Pipeline { + if p.err != nil { + return p + } + stage, err := newSelectStage(fieldpathsOrSelectables...) + if err != nil { + p.err = err + return p + } + return p.append(stage) +} + +// Distinct removes duplicate documents from the outputs of previous stages. +// +// You can optionally specify fields or [Selectable] expressions to determine distinctness. +// If no fields are specified, the entire document is used to determine distinctness. +func (p *Pipeline) Distinct(fieldpathsOrSelectables ...any) *Pipeline { + if p.err != nil { + return p + } + stage, err := newDistinctStage(fieldpathsOrSelectables...)
+ if err != nil { + p.err = err + return p + } + return p.append(stage) +} + +// AddFields adds new fields to outputs from previous stages. +// +// This stage allows you to compute values on-the-fly based on existing data from previous +// stages or constants. You can use this to create new fields or overwrite existing ones (if +// names overlap). +// +// The added fields are defined using [Selectable] expressions. +func (p *Pipeline) AddFields(selectables ...Selectable) *Pipeline { + if p.err != nil { + return p + } + stage, err := newAddFieldsStage(selectables...) + if err != nil { + p.err = err + return p + } + return p.append(stage) +} + +// RemoveFields removes fields from outputs from previous stages. +func (p *Pipeline) RemoveFields(fieldpaths ...any) *Pipeline { + if p.err != nil { + return p + } + stage, err := newRemoveFieldsStage(fieldpaths...) + if err != nil { + p.err = err + return p + } + return p.append(stage) +} + +// Where filters the documents from previous stages to only include those matching the specified [BooleanExpression]. +// +// This stage allows you to apply conditions to the data, similar to a "WHERE" clause in SQL. +func (p *Pipeline) Where(condition BooleanExpression) *Pipeline { + if p.err != nil { + return p + } + stage, err := newWhereStage(condition) + if err != nil { + p.err = err + return p + } + return p.append(stage) +} + +// AggregateSpec is used to perform aggregation operations. +type AggregateSpec struct { + groups []Selectable + accTargets []*AliasedAggregate + err error +} + +// NewAggregateSpec creates a new AggregateSpec with the given accumulator targets. +func NewAggregateSpec(accumulators ...*AliasedAggregate) *AggregateSpec { + return &AggregateSpec{accTargets: accumulators} +} + +// WithGroups sets the grouping keys for the aggregation. +func (a *AggregateSpec) WithGroups(fieldpathsOrSelectables ...any) *AggregateSpec { + a.groups, a.err = fieldsOrSelectablesToSelectables(fieldpathsOrSelectables...) + return a +} + +// Aggregate performs aggregation operations on the documents from previous stages. +// This stage allows you to calculate aggregate values over a set of documents. You define the +// aggregations to perform using [AliasedAggregate] expressions, which are typically results of +// calling [AggregateFunction.As] on [AggregateFunction] instances. +// Example: +// +// client.Pipeline().Collection("users"). +// Aggregate(Sum("age").As("age_sum")) +func (p *Pipeline) Aggregate(accumulators ...*AliasedAggregate) *Pipeline { + a := NewAggregateSpec(accumulators...) + aggStage, err := newAggregateStage(a) + if err != nil { + p.err = err + return p + } + return p.append(aggStage) +} + +// AggregateWithSpec performs optionally grouped aggregation operations on the documents from previous stages. +// This stage allows you to calculate aggregate values over a set of documents, optionally +// grouped by one or more fields or functions. You can specify: +// - Grouping Fields or Functions: One or more fields or functions to group the documents +// by. For each distinct combination of values in these fields, a separate group is created. +// If no grouping fields are provided, a single group containing all documents is used. Not +// specifying groups is the same as putting all inputs into one group. +// - Accumulator targets: One or more accumulation operations to perform within each group.
These +// are defined using [AliasedAggregate] expressions, which are typically results of calling +// [AggregateFunction.As] on [AggregateFunction] instances. Each aggregation +// calculates a value (e.g., sum, average, count) based on the documents within its group. +// +// Example: +// +// // Calculate the average rating for each genre. +// client.Pipeline().Collection("books"). +// AggregateWithSpec(NewAggregateSpec(Average("rating").As("avg_rating")).WithGroups("genre")) +func (p *Pipeline) AggregateWithSpec(spec *AggregateSpec) *Pipeline { + aggStage, err := newAggregateStage(spec) + if err != nil { + p.err = err + return p + } + return p.append(aggStage) +} + +// UnnestOptions holds the configuration for the Unnest stage. +type UnnestOptions struct { + // IndexField specifies the name of the field to store the array index of the unnested element. + IndexField any +} + +// Unnest produces a document for each element in an array field. +// For each input document, this stage outputs zero or more documents. +// Each output document is a copy of the input document, but the array field is replaced by an element from the array. +// The `field` parameter specifies the array field to unnest. It can be a string representing the field path or a [Selectable] expression. +// The alias of the selectable will be used as the new field name. +func (p *Pipeline) Unnest(field Selectable, opts *UnnestOptions) *Pipeline { + if p.err != nil { + return p + } + stage, err := newUnnestStageFromSelectable(field, opts) + if err != nil { + p.err = err + return p + } + return p.append(stage) +} + +// UnnestWithAlias produces a document for each element in an array field, with a specified alias for the unnested field. +// It can optionally take UnnestOptions. +func (p *Pipeline) UnnestWithAlias(fieldpath any, alias string, opts *UnnestOptions) *Pipeline { + if p.err != nil { + return p + } + + var fieldExpr Expression + switch v := fieldpath.(type) { + case string: + fieldExpr = FieldOf(v) + case FieldPath: + fieldExpr = FieldOf(v) + default: + p.err = errInvalidArg(fieldpath, "string", "FieldPath") + return p + } + + stage, err := newUnnestStage(fieldExpr, alias, opts) + if err != nil { + p.err = err + return p + } + return p.append(stage) +} + +// Union performs a union of all documents from two pipelines, including duplicates. +// +// This stage passes through documents from the previous stage, and also passes through documents +// from the previous stage of the other [*Pipeline] given as a parameter. The order of documents +// emitted from this stage is undefined. +// +// Example: +// +// // Emit documents from books collection and magazines collection. +// client.Pipeline().Collection("books"). +// Union(client.Pipeline().Collection("magazines")) +func (p *Pipeline) Union(other *Pipeline) *Pipeline { + if p.err != nil { + return p + } + stage, err := newUnionStage(other) + if err != nil { + p.err = err + return p + } + return p.append(stage) +} + +// SampleMode defines the mode for the sample stage. +type SampleMode string + +const ( + // SampleModeDocuments samples a fixed number of documents. + SampleModeDocuments SampleMode = "documents" + // SampleModePercent samples a percentage of documents. + SampleModePercent SampleMode = "percent" +) + +// SampleSpec is used to define a sample operation. +type SampleSpec struct { + Size any + Mode SampleMode +} + +// SampleByDocuments creates a SampleSpec for sampling a fixed number of documents.
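+// +// Example (illustrative): +// +// // Sample at most 10 documents from the previous stage. +// client.Pipeline().Collection("books").Sample(SampleByDocuments(10))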
+func SampleByDocuments(limit int) *SampleSpec { + return &SampleSpec{Size: limit, Mode: SampleModeDocuments} +} + +// Sample performs a pseudo-random sampling of the documents from the previous stage. +// +// This stage will filter documents pseudo-randomly. The behavior is defined by the SampleSpec. +// Use SampleByDocuments to create a SampleSpec for a fixed document count, or construct a +// SampleSpec directly for percentage sampling. +// +// Example: +// +// // Sample 10 books, if available. +// client.Pipeline().Collection("books").Sample(SampleByDocuments(10)) +// +// // Sample 50% of books. +// client.Pipeline().Collection("books").Sample(&SampleSpec{Size: 0.5, Mode: SampleModePercent}) +func (p *Pipeline) Sample(spec *SampleSpec) *Pipeline { + if p.err != nil { + return p + } + stage, err := newSampleStage(spec) + if err != nil { + p.err = err + return p + } + return p.append(stage) +} + +// ReplaceWith fully overwrites all fields in a document with those coming from a nested map. +// +// This stage allows you to emit a map value as a document. Each key of the map becomes a field +// on the document that contains the corresponding value. +// +// Example: +// +// // Input: { "name": "John Doe Jr.", "parents": { "father": "John Doe Sr.", "mother": "Jane Doe" } } +// // Emit parents as document. +// client.Pipeline().Collection("people").ReplaceWith("parents") +// // Output: { "father": "John Doe Sr.", "mother": "Jane Doe" } +func (p *Pipeline) ReplaceWith(fieldpathOrExpr any) *Pipeline { + if p.err != nil { + return p + } + stage, err := newReplaceWithStage(fieldpathOrExpr) + if err != nil { + p.err = err + return p + } + return p.append(stage) +} + +// PipelineDistanceMeasure is the distance measure for find_nearest pipeline stage. +type PipelineDistanceMeasure string + +const ( + // PipelineDistanceMeasureEuclidean measures the Euclidean distance between two vectors. + PipelineDistanceMeasureEuclidean PipelineDistanceMeasure = "euclidean" + // PipelineDistanceMeasureCosine compares vectors based on the angle between them. + PipelineDistanceMeasureCosine PipelineDistanceMeasure = "cosine" + // PipelineDistanceMeasureDotProduct is similar to cosine but is affected by the magnitude of the vectors. + PipelineDistanceMeasureDotProduct PipelineDistanceMeasure = "dot_product" +) + +// PipelineFindNearestOptions are options for a FindNearest pipeline stage. +type PipelineFindNearestOptions struct { + Limit *int + DistanceField *string +} + +// FindNearest applies a vector distance (similarity) search with the given parameters to the stage inputs. +// +// This stage adds a "nearest neighbor search" capability to your pipelines. Given a field that +// stores vectors and a target vector, this stage will identify and return the inputs whose vector +// field is closest to the target vector. +// +// The vectorField can be a string, a [FieldPath] or an [Expression]. +// The queryVector can be Vector32, Vector64, []float32, or []float64. +func (p *Pipeline) FindNearest(vectorField any, queryVector any, measure PipelineDistanceMeasure, options *PipelineFindNearestOptions) *Pipeline { + if p.err != nil { + return p + } + + stage, err := newFindNearestStage(vectorField, queryVector, measure, options) + if err != nil { + p.err = err + return p + } + return p.append(stage) +} + +// RawStage adds a generic stage to the pipeline. +// This method provides a flexible way to extend the pipeline's functionality by adding custom stages. +// +// Example: +// +// // Assume we don't have a built-in "where" stage +// client.Pipeline().Collection("books").
+// RawStage( +// NewRawStage("where"). +// WithArguments( +// LessThan(FieldOf("published"), 1900), +// ), +// ). +// Select("title", "author") +func (p *Pipeline) RawStage(stage *RawStage) *Pipeline { + if p.err != nil { + return p + } + return p.append(stage) +} diff --git a/firestore/pipeline_aggregate.go b/firestore/pipeline_aggregate.go new file mode 100644 index 000000000000..bf2fda84aa3d --- /dev/null +++ b/firestore/pipeline_aggregate.go @@ -0,0 +1,192 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package firestore + +import ( + "fmt" + + pb "cloud.google.com/go/firestore/apiv1/firestorepb" +) + +// AggregateFunction represents an aggregation function in a pipeline. +type AggregateFunction interface { + toProto() (*pb.Value, error) + getBaseAggregateFunction() *baseAggregateFunction + isAggregateFunction() + As(alias string) *AliasedAggregate +} + +// baseAggregateFunction provides common methods for all AggregateFunction implementations. +type baseAggregateFunction struct { + pbVal *pb.Value + err error +} + +func newBaseAggregateFunction(name string, fieldOrExpr any) *baseAggregateFunction { + var argsPbVals []*pb.Value + var err error + + if fieldOrExpr != nil { + var valueExpr Expression + switch value := fieldOrExpr.(type) { + case string: + valueExpr = FieldOf(value) + case FieldPath: + valueExpr = FieldOf(value) + case Expression: + valueExpr = value + default: + err = fmt.Errorf("firestore: invalid type for parameter 'value' for %s: expected string, FieldPath, or Expr, but got %T", name, value) + } + + if err == nil { + var pbVal *pb.Value + pbVal, err = valueExpr.toProto() + if err == nil { + argsPbVals = append(argsPbVals, pbVal) + } + } + } + + if err != nil { + return &baseAggregateFunction{err: err} + } + + pbVal := &pb.Value{ValueType: &pb.Value_FunctionValue{ + FunctionValue: &pb.Function{ + Name: name, + Args: argsPbVals, + }, + }} + return &baseAggregateFunction{pbVal: pbVal} +} + +func (b *baseAggregateFunction) toProto() (*pb.Value, error) { + if b.err != nil { + return nil, b.err + } + return b.pbVal, nil +} + +func (b *baseAggregateFunction) getBaseAggregateFunction() *baseAggregateFunction { return b } +func (b *baseAggregateFunction) isAggregateFunction() {} +func (b *baseAggregateFunction) As(alias string) *AliasedAggregate { + return &AliasedAggregate{baseAggregateFunction: b, alias: alias} +} + +// Ensure that baseAggregateFunction implements the AggregateFunction interface. +var _ AggregateFunction = (*baseAggregateFunction)(nil) + +// AliasedAggregate is an aliased [AggregateFunction]. +// It's used to give a name to the result of an aggregation. +type AliasedAggregate struct { + *baseAggregateFunction + alias string +} + +// Sum creates an aggregation that calculates the sum of values from an expression or a field's values +// across multiple stage inputs. 
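+// fieldOrExpr can be a field path string, [FieldPath] or [Expression].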
+// +// Example: +// +// // Calculate the total revenue from a set of orders +// Sum(FieldOf("orderAmount")).As("totalRevenue") // FieldOf returns an Expression +// Sum("orderAmount").As("totalRevenue") // A string is implicitly converted with FieldOf(...) +func Sum(fieldOrExpr any) AggregateFunction { + return newBaseAggregateFunction("sum", fieldOrExpr) +} + +// Average creates an aggregation that calculates the average (mean) of values from an expression or a field's values +// across multiple stage inputs. +// fieldOrExpr can be a field path string, [FieldPath] or [Expression]. +// Example: +// +// // Calculate the average age of users +// Average(FieldOf("info.age")).As("averageAge") // FieldOf returns an Expression +// Average("info.age").As("averageAge") // A string is implicitly converted with FieldOf(...) +// Average(FieldPath([]string{"info", "age"})).As("averageAge") +func Average(fieldOrExpr any) AggregateFunction { + return newBaseAggregateFunction("average", fieldOrExpr) +} + +// Count creates an aggregation that counts the number of stage inputs with valid evaluations of the +// provided field or expression. +// fieldOrExpr can be a field path string, [FieldPath] or [Expression]. +// Example: +// +// // Count the number of items where the price is greater than 10 +// Count(FieldOf("price").GreaterThan(10)).As("expensiveItemCount") // FieldOf("price").GreaterThan(10) is a BooleanExpression +// // Count the total number of products +// Count("productId").As("totalProducts") // A string is implicitly converted with FieldOf(...) +func Count(fieldOrExpr any) AggregateFunction { + return newBaseAggregateFunction("count", fieldOrExpr) +} + +// CountAll creates an aggregation that counts the total number of stage inputs. +// +// Example: +// +// // Count the total number of users +// CountAll().As("totalUsers") +func CountAll() AggregateFunction { + return newBaseAggregateFunction("count", nil) +} + +// CountDistinct creates an aggregation that counts the number of distinct values of the +// provided field or expression. +// fieldOrExpr can be a field path string, [FieldPath] or [Expression]. +// Example: +// +// // Count the number of distinct items where the price is greater than 10 +// CountDistinct(FieldOf("price").GreaterThan(10)).As("expensiveItemCount") // FieldOf("price").GreaterThan(10) is a BooleanExpression +// // Count the total number of distinct products +// CountDistinct("productId").As("totalProducts") // A string is implicitly converted with FieldOf(...) +func CountDistinct(fieldOrExpr any) AggregateFunction { + return newBaseAggregateFunction("count_distinct", fieldOrExpr) +} + +// CountIf creates an aggregation that counts the number of stage inputs where the provided boolean +// expression evaluates to true. +// Example: +// +// CountIf(FieldOf("published").Equal(true)).As("publishedCount") +func CountIf(condition BooleanExpression) AggregateFunction { + return newBaseAggregateFunction("count_if", condition) +} + +// Maximum creates an aggregation that calculates the maximum of values from an expression or a field's values +// across multiple stage inputs. +// +// Example: +// +// // Find the highest order amount +// Maximum(FieldOf("orderAmount")).As("maxOrderAmount") // FieldOf returns an Expression +// Maximum("orderAmount").As("maxOrderAmount") // A string is implicitly converted with FieldOf(...)
+func Maximum(fieldOrExpr any) AggregateFunction { + return newBaseAggregateFunction("maximum", fieldOrExpr) +} + +// Minimum creates an aggregation that calculates the minimum of values from an expression or a field's values +// across multiple stage inputs. +// +// Example: +// +// // Find the lowest order amount +// Minimum(FieldOf("orderAmount")).As("minOrderAmount") // FieldOf returns an Expression +// Minimum("orderAmount").As("minOrderAmount") // A string is implicitly converted with FieldOf(...) +func Minimum(fieldOrExpr any) AggregateFunction { + return newBaseAggregateFunction("minimum", fieldOrExpr) +} diff --git a/firestore/pipeline_constant.go b/firestore/pipeline_constant.go new file mode 100644 index 000000000000..583799b3266f --- /dev/null +++ b/firestore/pipeline_constant.go @@ -0,0 +1,75 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package firestore + +import ( + "fmt" + "reflect" + "time" + + "google.golang.org/genproto/googleapis/type/latlng" + ts "google.golang.org/protobuf/types/known/timestamppb" +) + +// constant represents a constant value that can be used in a Firestore pipeline expression. +// It implements the [Expression] interface. +type constant struct { + *baseExpression +} + +// ConstantOf creates a new constant [Expression] from a Go value. +func ConstantOf(value any) Expression { + if value == nil { + return ConstantOfNull() + } + + switch value := value.(type) { + case *constant: // Already our private constant type; return it as-is. + return value + case Expression: + // A non-constant Expression (e.g., a function result) cannot be + // converted to a constant; execution proceeds to the scalar type + // checks below, which reject it with an error. + } + + // Handle known scalar types + switch value.(type) { + case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, float32, float64, time.Time, *ts.Timestamp, []byte, Vector32, Vector64, bool, *latlng.LatLng, *DocumentRef: + pbVal, _, err := toProtoValue(reflect.ValueOf(value)) + if err != nil { + return &constant{baseExpression: &baseExpression{err: err}} + } + return &constant{baseExpression: &baseExpression{pbVal: pbVal}} + default: + return &constant{baseExpression: &baseExpression{err: fmt.Errorf("firestore: unknown constant type: %T", value)}} + } +} + +// ConstantOfNull creates a new constant [Expression] representing a null value. +func ConstantOfNull() Expression { + pbVal, _, err := toProtoValue(reflect.ValueOf(nil)) + return &constant{baseExpression: &baseExpression{pbVal: pbVal, err: err}} +} + +// ConstantOfVector32 creates a new [Vector32] constant [Expression] from a slice of float32s. +func ConstantOfVector32(value []float32) Expression { + return ConstantOf(Vector32(value)) +} + +// ConstantOfVector64 creates a new [Vector64] constant [Expression] from a slice of float64s.
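+// +// Example (illustrative): +// +// ConstantOfVector64([]float64{0.1, 0.2, 0.3})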
+func ConstantOfVector64(value []float64) Expression { + return ConstantOf(Vector64(value)) +} diff --git a/firestore/pipeline_expression.go b/firestore/pipeline_expression.go new file mode 100644 index 000000000000..46a758df8880 --- /dev/null +++ b/firestore/pipeline_expression.go @@ -0,0 +1,521 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package firestore + +import ( + pb "cloud.google.com/go/firestore/apiv1/firestorepb" +) + +// Selectable is an interface for expressions that can be selected in a pipeline. +type Selectable interface { + // getSelectionDetails returns the output alias and the underlying expression. + getSelectionDetails() (alias string, expr Expression) + + isSelectable() +} + +// Expression represents an expression that can be evaluated to a value within the execution of a +// [Pipeline]. +// +// Expressions are the building blocks for creating complex queries and transformations in +// Firestore pipelines. They can represent: +// +// - Field references: Access values from document fields. +// - Literals: Represent constant values (strings, numbers, booleans). +// - Function calls: Apply functions to one or more expressions. +// - Aggregations: Calculate aggregate values (e.g., sum, average) using [AggregateFunction] instances. +// +// The [Expression] interface provides a fluent API for building expressions. You can chain together +// method calls to create complex expressions. +type Expression interface { + isExpr() + toProto() (*pb.Value, error) + getBaseExpr() *baseExpression + + // Arithmetic operations + + // Add creates an expression that adds two expressions together, returning it as an [Expression]. + // + // The parameter 'other' can be a numeric constant or a numeric [Expression]. + Add(other any) Expression + // Subtract creates an expression that subtracts the right expression from the left expression, returning it as an [Expression]. + // + // The parameter 'other' can be a numeric constant or a numeric [Expression]. + Subtract(other any) Expression + // Multiply creates an expression that multiplies the left and right expressions, returning it as an [Expression]. + // + // The parameter 'other' can be a numeric constant or a numeric [Expression]. + Multiply(other any) Expression + // Divide creates an expression that divides the left expression by the right expression, returning it as an [Expression]. + // + // The parameter 'other' can be a numeric constant or a numeric [Expression]. + Divide(other any) Expression + // Abs creates an expression that is the absolute value of the input field or expression. + Abs() Expression + // Floor creates an expression that is the largest integer that isn't greater than the input field or expression. + Floor() Expression + // Ceil creates an expression that is the smallest integer that isn't less than the input field or expression. + Ceil() Expression + // Exp creates an expression that is Euler's number e raised to the power of the input field or expression.
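+ // For example (illustrative), FieldOf("x").Exp() evaluates to e^x.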
+ Exp() Expression + // Log creates an expression that is the logarithm of the left expression, using the right expression as the base, returning it as an [Expression]. + // + // The parameter 'other' can be a numeric constant or a numeric [Expression]. + Log(other any) Expression + // Log10 creates an expression that is the base 10 logarithm of the input field or expression. + Log10() Expression + // Ln creates an expression that is the natural logarithm (base e) of the input field or expression. + Ln() Expression + // Mod creates an expression that computes the modulo of the left expression by the right expression, returning it as an [Expression]. + // + // The parameter 'other' can be a numeric constant or a numeric [Expression]. + Mod(other any) Expression + // Pow creates an expression that computes the left expression raised to the power of the right expression, returning it as an [Expression]. + // + // The parameter 'other' can be a numeric constant or a numeric [Expression]. + Pow(other any) Expression + // Round creates an expression that rounds the input field or expression to the nearest integer. + Round() Expression + // Sqrt creates an expression that is the square root of the input field or expression. + Sqrt() Expression + + // Array operations + // ArrayContains creates a boolean expression that checks if an array contains a specific value. + // + // The parameter 'value' can be a constant (e.g., string, int, bool) or an [Expression]. + ArrayContains(value any) BooleanExpression + // ArrayContainsAll creates a boolean expression that checks if an array contains all the specified values. + // + // The parameter 'values' can be a slice of constants (e.g., []string, []int) or an [Expression] that evaluates to an array. + ArrayContainsAll(values any) BooleanExpression + // ArrayContainsAny creates a boolean expression that checks if an array contains any of the specified values. + // + // The parameter 'values' can be a slice of constants (e.g., []string, []int) or an [Expression] that evaluates to an array. + ArrayContainsAny(values any) BooleanExpression + // ArrayLength creates an expression that calculates the length of an array. + ArrayLength() Expression + // EqualAny creates a boolean expression that checks if the expression is equal to any of the specified values. + // + // The parameter 'values' can be a slice of constants (e.g., []string, []int) or an [Expression] that evaluates to an array. + EqualAny(values any) BooleanExpression + // NotEqualAny creates a boolean expression that checks if the expression is not equal to any of the specified values. + // + // The parameter 'values' can be a slice of constants (e.g., []string, []int) or an [Expression] that evaluates to an array. + NotEqualAny(values any) BooleanExpression + // ArrayGet creates an expression that retrieves an element from an array at a specified index. + // + // The parameter 'offset' is the 0-based index of the element to retrieve. + // It can be an integer constant or an [Expression] that evaluates to an integer. + ArrayGet(offset any) Expression + // ArrayReverse creates an expression that reverses the order of elements in an array. + ArrayReverse() Expression + // ArrayConcat creates an expression that concatenates multiple arrays into a single array. + // + // The parameter 'otherArrays' can be a mix of array constants (e.g., []string, []int) or [Expression]s that evaluate to arrays. + ArrayConcat(otherArrays ...any) Expression + // ArraySum creates an expression that calculates the sum of all elements in a numeric array.
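+ // For example (illustrative), FieldOf("scores").ArraySum() sums the elements of the "scores" array.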
+ ArraySum() Expression + // ArrayMaximum creates an expression that finds the maximum element in a numeric array. + ArrayMaximum() Expression + // ArrayMinimum creates an expression that finds the minimum element in a numeric array. + ArrayMinimum() Expression + + // Timestamp operations + // TimestampAdd creates an expression that adds a specified amount of time to a timestamp. + // + // The parameter 'unit' can be a string constant (e.g., "day") or an [Expression] that evaluates to a valid unit string. + // Valid units include "microsecond", "millisecond", "second", "minute", "hour" and "day". + // The parameter 'amount' can be an integer constant or an [Expression] that evaluates to an integer. + TimestampAdd(unit, amount any) Expression + // TimestampSubtract creates an expression that subtracts a specified amount of time from a timestamp. + // + // The parameter 'unit' can be a string constant (e.g., "hour") or an [Expression] that evaluates to a valid unit string. + // Valid units include "microsecond", "millisecond", "second", "minute", "hour" and "day". + // The parameter 'amount' can be an integer constant or an [Expression] that evaluates to an integer. + TimestampSubtract(unit, amount any) Expression + // TimestampTruncate creates an expression that truncates a timestamp to a specified granularity. + // + // The parameter 'granularity' can be a string constant (e.g., "month") or an [Expression] that evaluates to a valid granularity string. + // Valid values are "microsecond", "millisecond", "second", "minute", "hour", "day", "week", "week(monday)", "week(tuesday)", + // "week(wednesday)", "week(thursday)", "week(friday)", "week(saturday)", "week(sunday)", "isoweek", "month", "quarter", "year", and "isoyear". + TimestampTruncate(granularity any) Expression + // TimestampTruncateWithTimezone creates an expression that truncates a timestamp to a specified granularity in a given timezone. + // + // The parameter 'granularity' can be a string constant (e.g., "week") or an [Expression] that evaluates to a valid granularity string. + // Valid values are "microsecond", "millisecond", "second", "minute", "hour", "day", "week", "week(monday)", "week(tuesday)", + // "week(wednesday)", "week(thursday)", "week(friday)", "week(saturday)", "week(sunday)", "isoweek", "month", "quarter", "year", and "isoyear". + // The parameter 'timezone' can be a string constant (e.g., "America/Los_Angeles") or an [Expression] that evaluates to a valid timezone string. + // Valid values are from the TZ database or in the format "Etc/GMT-1". + TimestampTruncateWithTimezone(granularity any, timezone string) Expression + // TimestampToUnixMicros creates an expression that converts a timestamp expression to the number of microseconds since + // the Unix epoch (1970-01-01 00:00:00 UTC). + TimestampToUnixMicros() Expression + // TimestampToUnixMillis creates an expression that converts a timestamp expression to the number of milliseconds since + // the Unix epoch (1970-01-01 00:00:00 UTC). + TimestampToUnixMillis() Expression + // TimestampToUnixSeconds creates an expression that converts a timestamp expression to the number of seconds since + // the Unix epoch (1970-01-01 00:00:00 UTC). + TimestampToUnixSeconds() Expression + // UnixMicrosToTimestamp creates an expression that converts a Unix timestamp in microseconds to a Firestore timestamp. + UnixMicrosToTimestamp() Expression + // UnixMillisToTimestamp creates an expression that converts a Unix timestamp in milliseconds to a Firestore timestamp. 
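+ // For example (illustrative), ConstantOf(int64(1700000000000)).UnixMillisToTimestamp() yields the corresponding Firestore timestamp.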
+ UnixMillisToTimestamp() Expression + // UnixSecondsToTimestamp creates an expression that converts a Unix timestamp in seconds to a Firestore timestamp. + UnixSecondsToTimestamp() Expression + + // Comparison operations + // Equal creates a boolean expression that checks if the expression is equal to the other value. + // + // The parameter 'other' can be a constant (e.g., string, int, bool) or an [Expression]. + Equal(other any) BooleanExpression + // NotEqual creates a boolean expression that checks if the expression is not equal to the other value. + // + // The parameter 'other' can be a constant (e.g., string, int, bool) or an [Expression]. + NotEqual(other any) BooleanExpression + // GreaterThan creates a boolean expression that checks if the expression is greater than the other value. + // + // The parameter 'other' can be a constant (e.g., string, int, bool) or an [Expression]. + GreaterThan(other any) BooleanExpression + // GreaterThanOrEqual creates a boolean expression that checks if the expression is greater than or equal to the other value. + // + // The parameter 'other' can be a constant (e.g., string, int, bool) or an [Expression]. + GreaterThanOrEqual(other any) BooleanExpression + // LessThan creates a boolean expression that checks if the expression is less than the other value. + // + // The parameter 'other' can be a constant (e.g., string, int, bool) or an [Expression]. + LessThan(other any) BooleanExpression + // LessThanOrEqual creates a boolean expression that checks if the expression is less than or equal to the other value. + // + // The parameter 'other' can be a constant (e.g., string, int, bool) or an [Expression]. + LessThanOrEqual(other any) BooleanExpression + + // General functions + // Length creates an expression that calculates the length of a string, array, map or vector. + Length() Expression + // Reverse creates an expression that reverses a string or an array. + Reverse() Expression + // Concat creates an expression that concatenates expressions together. + // + // The parameter 'others' can be a list of constants (e.g., string, int) or [Expression]s. + Concat(others ...any) Expression + + // Key functions + // GetCollectionID creates an expression that returns the ID of the collection that contains the document. + GetCollectionID() Expression + // GetDocumentID creates an expression that returns the ID of the document. + GetDocumentID() Expression + + // Logical functions + // IfError creates an expression that evaluates and returns the receiver expression if it does not produce an error; + // otherwise, it evaluates and returns `catchExprOrValue`. + // + // The parameter 'catchExprOrValue' is the expression or value to return if the receiver expression errors. + IfError(catchExprOrValue any) Expression + // IfAbsent creates an expression that returns a default value if an expression evaluates to an absent value. + // + // The parameter 'catchExprOrValue' is the value to return if the expression is absent. + // It can be a constant or an [Expression]. + IfAbsent(catchExprOrValue any) Expression + + // Object functions + // MapGet creates an expression that accesses a value from a map (object) field using the provided key. + // + // The parameter 'strOrExprkey' is the key to access in the map. + // It can be a string constant or an [Expression] that evaluates to a string. + MapGet(strOrExprkey any) Expression + // MapMerge creates an expression that merges multiple maps into a single map.
+ // If multiple maps have the same key, the later value is used. + // + // The parameter 'secondMap' is an [Expression] representing the second map. + // The parameter 'otherMaps' is a list of additional [Expression]s representing maps to merge. + MapMerge(secondMap Expression, otherMaps ...Expression) Expression + // MapRemove creates an expression that removes a key from a map. + // + // The parameter 'strOrExprkey' is the key to remove from the map. + // It can be a string constant or an [Expression] that evaluates to a string. + MapRemove(strOrExprkey any) Expression + + // Aggregators + // Sum creates an aggregate function that calculates the sum of the expression. + Sum() AggregateFunction + // Average creates an aggregate function that calculates the average of the expression. + Average() AggregateFunction + // Count creates an aggregate function that counts the number of documents. + Count() AggregateFunction + + // String functions + // ByteLength creates an expression that calculates the length of a string represented by a field or [Expression] in UTF-8 + // bytes. + ByteLength() Expression + // CharLength creates an expression that calculates the character length of a string field or expression in UTF-8. + CharLength() Expression + // EndsWith creates a boolean expression that checks if the string expression ends with the specified suffix. + // + // The parameter 'suffix' can be a string constant or an [Expression] that evaluates to a string. + EndsWith(suffix any) BooleanExpression + // Like creates a boolean expression that checks if the string expression matches the specified pattern. + // + // The parameter 'pattern' can be a string constant or an [Expression] that evaluates to a string. + Like(pattern any) BooleanExpression + // RegexContains creates a boolean expression that checks if the string expression contains a match for the specified regex pattern. + // + // The parameter 'pattern' can be a string constant or an [Expression] that evaluates to a string. + RegexContains(pattern any) BooleanExpression + // RegexMatch creates a boolean expression that checks if the string expression matches the specified regex pattern. + // + // The parameter 'pattern' can be a string constant or an [Expression] that evaluates to a string. + RegexMatch(pattern any) BooleanExpression + // StartsWith creates a boolean expression that checks if the string expression starts with the specified prefix. + // + // The parameter 'prefix' can be a string constant or an [Expression] that evaluates to a string. + StartsWith(prefix any) BooleanExpression + // StringConcat creates an expression that concatenates multiple strings into a single string. + // + // The parameter 'otherStrings' can be a mix of string constants or [Expression]s that evaluate to strings. + StringConcat(otherStrings ...any) Expression + // StringContains creates a boolean expression that checks if the string expression contains the specified substring. + // + // The parameter 'substring' can be a string constant or an [Expression] that evaluates to a string. + StringContains(substring any) BooleanExpression + // StringReverse creates an expression that reverses a string. + StringReverse() Expression + // Join creates an expression that joins the elements of a string array into a single string. + // + // The parameter 'delimiter' can be a string constant or an [Expression] that evaluates to a string. + Join(delimiter any) Expression + // Substring creates an expression that returns a substring of a string.
+ // + // The parameter 'index' is the starting index of the substring. + // It can be an integer constant or an [Expression] that evaluates to an integer. + // The parameter 'offset' is the length of the substring. + // It can be an integer constant or an [Expression] that evaluates to an integer. + Substring(index, offset any) Expression + // ToLower creates an expression that converts a string to lowercase. + ToLower() Expression + // ToUpper creates an expression that converts a string to uppercase. + ToUpper() Expression + // Trim creates an expression that removes leading and trailing whitespace from a string. + Trim() Expression + // Split creates an expression that splits a string by a delimiter. + // + // The parameter 'delimiter' can be a string constant or an [Expression] that evaluates to a string. + Split(delimiter any) Expression + + // Type creates an expression that returns the type of the expression. + Type() Expression + + // Vector functions + // CosineDistance creates an expression that calculates the cosine distance between two vectors. + // + // The parameter 'other' can be [Vector32], [Vector64], []float32, []float64 or an [Expression] that evaluates to a vector. + CosineDistance(other any) Expression + // DotProduct creates an expression that calculates the dot product of two vectors. + // + // The parameter 'other' can be [Vector32], [Vector64], []float32, []float64 or an [Expression] that evaluates to a vector. + DotProduct(other any) Expression + // EuclideanDistance creates an expression that calculates the Euclidean distance between two vectors. + // + // The parameter 'other' can be [Vector32], [Vector64], []float32, []float64 or an [Expression] that evaluates to a vector. + EuclideanDistance(other any) Expression + // VectorLength creates an expression that calculates the length of a vector. + VectorLength() Expression + + // Ordering + // Ascending creates an ordering expression for ascending order. + Ascending() Ordering + // Descending creates an ordering expression for descending order. + Descending() Ordering + + // As assigns an alias to an expression. + // Aliases are useful for renaming fields in the output of a stage. + As(alias string) Selectable +} + +// baseExpression provides common methods for all Expression implementations, allowing for method chaining.
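+// For example (illustrative), FieldOf("price").Multiply(1.1).As("priceWithTax") chains a field reference through an arithmetic function into a [Selectable].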
+type baseExpression struct { + pbVal *pb.Value + err error +} + +func (b *baseExpression) isExpr() {} +func (b *baseExpression) toProto() (*pb.Value, error) { return b.pbVal, b.err } +func (b *baseExpression) getBaseExpr() *baseExpression { return b } + +// Arithmetic functions +func (b *baseExpression) Add(other any) Expression { return Add(b, other) } +func (b *baseExpression) Subtract(other any) Expression { return Subtract(b, other) } +func (b *baseExpression) Multiply(other any) Expression { return Multiply(b, other) } +func (b *baseExpression) Divide(other any) Expression { return Divide(b, other) } +func (b *baseExpression) Abs() Expression { return Abs(b) } +func (b *baseExpression) Floor() Expression { return Floor(b) } +func (b *baseExpression) Ceil() Expression { return Ceil(b) } +func (b *baseExpression) Exp() Expression { return Exp(b) } +func (b *baseExpression) Log(other any) Expression { return Log(b, other) } +func (b *baseExpression) Log10() Expression { return Log10(b) } +func (b *baseExpression) Ln() Expression { return Ln(b) } +func (b *baseExpression) Mod(other any) Expression { return Mod(b, other) } +func (b *baseExpression) Pow(other any) Expression { return Pow(b, other) } +func (b *baseExpression) Round() Expression { return Round(b) } +func (b *baseExpression) Sqrt() Expression { return Sqrt(b) } + +// Array functions +func (b *baseExpression) ArrayContains(value any) BooleanExpression { return ArrayContains(b, value) } +func (b *baseExpression) ArrayContainsAll(values any) BooleanExpression { + return ArrayContainsAll(b, values) +} +func (b *baseExpression) ArrayContainsAny(values any) BooleanExpression { + return ArrayContainsAny(b, values) +} +func (b *baseExpression) ArrayLength() Expression { return ArrayLength(b) } +func (b *baseExpression) EqualAny(values any) BooleanExpression { return EqualAny(b, values) } +func (b *baseExpression) NotEqualAny(values any) BooleanExpression { return NotEqualAny(b, values) } +func (b *baseExpression) ArrayGet(offset any) Expression { return ArrayGet(b, offset) } +func (b *baseExpression) ArrayReverse() Expression { return ArrayReverse(b) } +func (b *baseExpression) ArrayConcat(otherArrays ...any) Expression { + return ArrayConcat(b, otherArrays...)
+} +func (b *baseExpression) ArraySum() Expression { return ArraySum(b) } +func (b *baseExpression) ArrayMaximum() Expression { return ArrayMaximum(b) } +func (b *baseExpression) ArrayMinimum() Expression { return ArrayMinimum(b) } + +// Timestamp functions +func (b *baseExpression) TimestampAdd(unit, amount any) Expression { + return TimestampAdd(b, unit, amount) +} +func (b *baseExpression) TimestampSubtract(unit, amount any) Expression { + return TimestampSubtract(b, unit, amount) +} +func (b *baseExpression) TimestampTruncate(granularity any) Expression { + return TimestampTruncate(b, granularity) +} +func (b *baseExpression) TimestampTruncateWithTimezone(granularity any, timezone string) Expression { + return TimestampTruncateWithTimezone(b, granularity, timezone) +} +func (b *baseExpression) TimestampToUnixMicros() Expression { return TimestampToUnixMicros(b) } +func (b *baseExpression) TimestampToUnixMillis() Expression { return TimestampToUnixMillis(b) } +func (b *baseExpression) TimestampToUnixSeconds() Expression { return TimestampToUnixSeconds(b) } +func (b *baseExpression) UnixMicrosToTimestamp() Expression { return UnixMicrosToTimestamp(b) } +func (b *baseExpression) UnixMillisToTimestamp() Expression { return UnixMillisToTimestamp(b) } +func (b *baseExpression) UnixSecondsToTimestamp() Expression { return UnixSecondsToTimestamp(b) } + +// Comparison functions +func (b *baseExpression) Equal(other any) BooleanExpression { return Equal(b, other) } +func (b *baseExpression) NotEqual(other any) BooleanExpression { return NotEqual(b, other) } +func (b *baseExpression) GreaterThan(other any) BooleanExpression { return GreaterThan(b, other) } +func (b *baseExpression) GreaterThanOrEqual(other any) BooleanExpression { + return GreaterThanOrEqual(b, other) +} +func (b *baseExpression) LessThan(other any) BooleanExpression { return LessThan(b, other) } +func (b *baseExpression) LessThanOrEqual(other any) BooleanExpression { + return LessThanOrEqual(b, other) +} + +// General functions +func (b *baseExpression) Length() Expression { return Length(b) } +func (b *baseExpression) Reverse() Expression { return Reverse(b) } +func (b *baseExpression) Concat(others ...any) Expression { return Concat(b, others...) } + +// Key functions +func (b *baseExpression) GetCollectionID() Expression { return GetCollectionID(b) } +func (b *baseExpression) GetDocumentID() Expression { return GetDocumentID(b) } + +// Logical functions +func (b *baseExpression) IfError(catchExprOrValue any) Expression { + return IfError(b, catchExprOrValue) +} +func (b *baseExpression) IfAbsent(catchExprOrValue any) Expression { + return IfAbsent(b, catchExprOrValue) +} + +// Object functions +func (b *baseExpression) MapGet(strOrExprkey any) Expression { return MapGet(b, strOrExprkey) } +func (b *baseExpression) MapMerge(secondMap Expression, otherMaps ...Expression) Expression { + return MapMerge(b, secondMap, otherMaps...) 
+}
+func (b *baseExpression) MapRemove(strOrExprkey any) Expression { return MapRemove(b, strOrExprkey) }
+
+// Aggregation operations
+func (b *baseExpression) Sum() AggregateFunction { return Sum(b) }
+func (b *baseExpression) Average() AggregateFunction { return Average(b) }
+func (b *baseExpression) Count() AggregateFunction { return Count(b) }
+func (b *baseExpression) CountDistinct() AggregateFunction { return CountDistinct(b) }
+func (b *baseExpression) Maximum() AggregateFunction { return Maximum(b) }
+func (b *baseExpression) Minimum() AggregateFunction { return Minimum(b) }
+
+// String functions
+func (b *baseExpression) ByteLength() Expression { return ByteLength(b) }
+func (b *baseExpression) CharLength() Expression { return CharLength(b) }
+func (b *baseExpression) EndsWith(suffix any) BooleanExpression { return EndsWith(b, suffix) }
+func (b *baseExpression) Like(pattern any) BooleanExpression { return Like(b, pattern) }
+func (b *baseExpression) RegexContains(pattern any) BooleanExpression {
+	return RegexContains(b, pattern)
+}
+func (b *baseExpression) RegexMatch(pattern any) BooleanExpression { return RegexMatch(b, pattern) }
+func (b *baseExpression) StartsWith(prefix any) BooleanExpression { return StartsWith(b, prefix) }
+func (b *baseExpression) StringConcat(otherStrings ...any) Expression {
+	return StringConcat(b, otherStrings...)
+}
+func (b *baseExpression) StringContains(substring any) BooleanExpression {
+	return StringContains(b, substring)
+}
+func (b *baseExpression) StringReverse() Expression { return StringReverse(b) }
+func (b *baseExpression) Join(delimiter any) Expression { return Join(b, delimiter) }
+func (b *baseExpression) Substring(index, offset any) Expression { return Substring(b, index, offset) }
+func (b *baseExpression) ToLower() Expression { return ToLower(b) }
+func (b *baseExpression) ToUpper() Expression { return ToUpper(b) }
+func (b *baseExpression) Trim() Expression { return Trim(b) }
+func (b *baseExpression) Split(delimiter any) Expression { return Split(b, delimiter) }
+
+// Type functions
+func (b *baseExpression) Type() Expression { return Type(b) }
+
+// Vector functions
+func (b *baseExpression) CosineDistance(other any) Expression { return CosineDistance(b, other) }
+func (b *baseExpression) DotProduct(other any) Expression { return DotProduct(b, other) }
+func (b *baseExpression) EuclideanDistance(other any) Expression { return EuclideanDistance(b, other) }
+func (b *baseExpression) VectorLength() Expression { return VectorLength(b) }
+
+// Ordering
+func (b *baseExpression) Ascending() Ordering { return Ascending(b) }
+func (b *baseExpression) Descending() Ordering { return Descending(b) }
+
+func (b *baseExpression) As(alias string) Selectable {
+	return newAliasedExpr(b, alias)
+}
+
+// Ensure that baseExpression implements the Expression interface.
+var _ Expression = (*baseExpression)(nil)
+
+// AliasedExpression represents an expression with an alias.
+// It implements the [Selectable] interface, allowing it to be used in projection stages like `Select` and `AddFields`.
+type AliasedExpression struct {
+	*baseExpression
+	alias string
+}
+
+func newAliasedExpr(expr Expression, alias string) *AliasedExpression {
+	return &AliasedExpression{baseExpression: expr.getBaseExpr(), alias: alias}
+}
+
+// getSelectionDetails returns the alias and the underlying expression for this AliasedExpression.
+// This method allows AliasedExpression to satisfy the Selectable interface.
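+//
+// For example (illustrative), FieldOf("rating").As("score") produces an
+// AliasedExpression whose getSelectionDetails returns "score" together with
+// the underlying 'rating' field expression.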
+func (e *AliasedExpression) getSelectionDetails() (string, Expression) {
+	return e.alias, e.baseExpression
+}
+
+func (e *AliasedExpression) isSelectable() {}
+
+// Ensure that AliasedExpression implements the Selectable interface.
+var _ Selectable = (*AliasedExpression)(nil)
diff --git a/firestore/pipeline_field.go b/firestore/pipeline_field.go
new file mode 100644
index 000000000000..e75137810be4
--- /dev/null
+++ b/firestore/pipeline_field.go
@@ -0,0 +1,65 @@
+// Copyright 2026 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package firestore
+
+import (
+	pb "cloud.google.com/go/firestore/apiv1/firestorepb"
+)
+
+// field represents a reference to a field in a Firestore document, or an output of a [Pipeline] stage.
+// It implements the [Expression] and [Selectable] interfaces.
+//
+// Field references are used to access document field values in expressions and to specify fields
+// for sorting, filtering, and projecting data in Firestore pipelines.
+type field struct {
+	*baseExpression
+	fieldPath FieldPath
+}
+
+// FieldOf creates a new field [Expression] from a dot-separated field path string or [FieldPath].
+func FieldOf[T string | FieldPath](path T) Expression {
+	var fieldPath FieldPath
+	switch p := any(path).(type) {
+	case string:
+		fp, err := parseDotSeparatedString(p)
+		if err != nil {
+			return &field{baseExpression: &baseExpression{err: err}}
+		}
+		fieldPath = fp
+	case FieldPath:
+		fieldPath = p
+	}
+
+	if err := fieldPath.validate(); err != nil {
+		return &field{baseExpression: &baseExpression{err: err}}
+	}
+	pbVal := &pb.Value{
+		ValueType: &pb.Value_FieldReferenceValue{
+			FieldReferenceValue: fieldPath.toServiceFieldPath(),
+		},
+	}
+	return &field{fieldPath: fieldPath, baseExpression: &baseExpression{pbVal: pbVal}}
+}
+
+// getSelectionDetails returns the field path string as the default alias and the field expression itself.
+// This allows a field [Expression] to satisfy the [Selectable] interface, making it directly usable
+// in `Select` or `AddFields` stages without explicit aliasing if the original field name is desired.
+func (f *field) getSelectionDetails() (string, Expression) {
+	// For Selectable, the alias is the field path itself if not otherwise aliased by `As`.
+	// This makes `FieldOf("name")` selectable as "name".
+	return f.fieldPath.toServiceFieldPath(), f
+}
+
+func (f *field) isSelectable() {}
diff --git a/firestore/pipeline_filter_condition.go b/firestore/pipeline_filter_condition.go
new file mode 100644
index 000000000000..f126b15eae51
--- /dev/null
+++ b/firestore/pipeline_filter_condition.go
@@ -0,0 +1,346 @@
+// Copyright 2026 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package firestore
+
+// BooleanExpression is an interface that represents a boolean expression in a pipeline.
+type BooleanExpression interface {
+	Expression // Embed the Expression interface
+	isBooleanExpr()
+
+	// Conditional creates an expression that evaluates a condition and returns one of two expressions.
+	//
+	// The parameter 'thenVal' is the expression to return if the condition is true.
+	// The parameter 'elseVal' is the expression to return if the condition is false.
+	Conditional(thenVal, elseVal any) Expression
+	// IfErrorBoolean creates a boolean expression that evaluates and returns the receiver expression if it does not produce an error;
+	// otherwise, it evaluates and returns `catchExpr`.
+	//
+	// The parameter 'catchExpr' is the boolean expression to return if the receiver expression errors.
+	IfErrorBoolean(catchExpr BooleanExpression) BooleanExpression
+	// Not creates an expression that negates a boolean expression.
+	Not() BooleanExpression
+	// CountIf creates an aggregation that counts the number of stage inputs where this boolean expression
+	// evaluates to true.
+	CountIf() AggregateFunction
+}
+
+// baseBooleanExpression provides common methods for all BooleanExpression implementations.
+type baseBooleanExpression struct {
+	*baseFunction // Embed baseFunction to get Expression methods and toProto
+}
+
+func (b *baseBooleanExpression) isBooleanExpr() {}
+func (b *baseBooleanExpression) Conditional(thenVal, elseVal any) Expression {
+	return Conditional(b, thenVal, elseVal)
+}
+func (b *baseBooleanExpression) IfErrorBoolean(catchExpr BooleanExpression) BooleanExpression {
+	return IfErrorBoolean(b, catchExpr)
+}
+func (b *baseBooleanExpression) Not() BooleanExpression {
+	return Not(b)
+}
+func (b *baseBooleanExpression) CountIf() AggregateFunction {
+	return CountIf(b)
+}
+
+// Ensure that baseBooleanExpression implements the BooleanExpression interface.
+var _ BooleanExpression = (*baseBooleanExpression)(nil)
+
+// ArrayContains creates an expression that checks if an array contains a specified element.
+// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expression] that evaluates to an array.
+// - value is the element to check for.
+//
+// Example:
+//
+//	// Check if the 'tags' array contains "Go".
+//	ArrayContains("tags", "Go")
+func ArrayContains(exprOrFieldPath any, value any) BooleanExpression {
+	return &baseBooleanExpression{baseFunction: newBaseFunction("array_contains", []Expression{asFieldExpr(exprOrFieldPath), toExprOrConstant(value)})}
+}
+
+// ArrayContainsAll creates an expression that checks if an array contains all of the provided values.
+// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expression] that evaluates to an array.
+// - values can be an array of values or an expression that evaluates to an array.
+//
+// Example:
+//
+//	// Check if the 'tags' array contains both "Go" and "Firestore".
+//	ArrayContainsAll("tags", []string{"Go", "Firestore"})
+func ArrayContainsAll(exprOrFieldPath any, values any) BooleanExpression {
+	return newFieldAndArrayBooleanExpr("array_contains_all", exprOrFieldPath, values)
+}
+
+// ArrayContainsAny creates an expression that checks if an array contains any of the provided values.
+// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expression] that evaluates to an array.
+// - values can be an array of values or an expression that evaluates to an array.
+//
+// Example:
+//
+//	// Check if the 'tags' array contains either "Go" or "Firestore".
+//	ArrayContainsAny("tags", []string{"Go", "Firestore"})
+func ArrayContainsAny(exprOrFieldPath any, values any) BooleanExpression {
+	return newFieldAndArrayBooleanExpr("array_contains_any", exprOrFieldPath, values)
+}
+
+// EqualAny creates an expression that checks if a field or expression is equal to any of the provided values.
+// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expression].
+// - values can be an array of values or an expression that evaluates to an array.
+//
+// Example:
+//
+//	// Check if the 'status' field is either "active" or "pending".
+//	EqualAny("status", []string{"active", "pending"})
+func EqualAny(exprOrFieldPath any, values any) BooleanExpression {
+	return newFieldAndArrayBooleanExpr("equal_any", exprOrFieldPath, values)
+}
+
+// NotEqualAny creates an expression that checks if a field or expression is not equal to any of the provided values.
+// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expression].
+// - values can be an array of values or an expression that evaluates to an array.
+//
+// Example:
+//
+//	// Check if the 'status' field is not "archived" or "deleted".
+//	NotEqualAny("status", []string{"archived", "deleted"})
+func NotEqualAny(exprOrFieldPath any, values any) BooleanExpression {
+	return newFieldAndArrayBooleanExpr("not_equal_any", exprOrFieldPath, values)
+}
+
+// Equal creates an expression that checks if a field's value or an expression is equal to an expression or a constant value,
+// returning it as a [BooleanExpression].
+// - left: The field path string, [FieldPath] or [Expression] to compare.
+// - right: The constant value or [Expression] to compare to.
+//
+// Example:
+//
+//	// Check if the 'age' field is equal to 21
+//	Equal(FieldOf("age"), 21)
+//
+//	// Check if the 'age' field is equal to an expression
+//	Equal(FieldOf("age"), FieldOf("minAge").Add(10))
+//
+//	// Check if the 'age' field is equal to the 'limit' field
+//	Equal("age", FieldOf("limit"))
+//
+//	// Check if the 'city' field is equal to string constant "London"
+//	Equal("city", "London")
+func Equal(left, right any) BooleanExpression {
+	return &baseBooleanExpression{baseFunction: leftRightToBaseFunction("equal", left, right)}
+}
+
+// NotEqual creates an expression that checks if a field's value or an expression is not equal to an expression or a constant value,
+// returning it as a [BooleanExpression].
+// - left: The field path string, [FieldPath] or [Expression] to compare.
+// - right: The constant value or [Expression] to compare to.
+//
+// Example:
+//
+//	// Check if the 'age' field is not equal to 21
+//	NotEqual(FieldOf("age"), 21)
+//
+//	// Check if the 'age' field is not equal to an expression
+//	NotEqual(FieldOf("age"), FieldOf("minAge").Add(10))
+//
+//	// Check if the 'age' field is not equal to the 'limit' field
+//	NotEqual("age", FieldOf("limit"))
+//
+//	// Check if the 'city' field is not equal to string constant "London"
+//	NotEqual("city", "London")
+func NotEqual(left, right any) BooleanExpression {
+	return &baseBooleanExpression{baseFunction: leftRightToBaseFunction("not_equal", left, right)}
+}
+
+// GreaterThan creates an expression that checks if a field's value or an expression is greater than an expression or a constant value,
+// returning it as a [BooleanExpression].
+// - left: The field path string, [FieldPath] or [Expression] to compare.
+// - right: The constant value or [Expression] to compare to.
+//
+// Example:
+//
+//	// Check if the 'age' field is greater than 21
+//	GreaterThan(FieldOf("age"), 21)
+//
+//	// Check if the 'age' field is greater than an expression
+//	GreaterThan(FieldOf("age"), FieldOf("minAge").Add(10))
+//
+//	// Check if the 'age' field is greater than the 'limit' field
+//	GreaterThan("age", FieldOf("limit"))
+func GreaterThan(left, right any) BooleanExpression {
+	return &baseBooleanExpression{baseFunction: leftRightToBaseFunction("greater_than", left, right)}
+}
+
+// GreaterThanOrEqual creates an expression that checks if a field's value or an expression is greater than or equal to an expression or a constant value,
+// returning it as a [BooleanExpression].
+// - left: The field path string, [FieldPath] or [Expression] to compare.
+// - right: The constant value or [Expression] to compare to.
+//
+// Example:
+//
+//	// Check if the 'age' field is greater than or equal to 21
+//	GreaterThanOrEqual(FieldOf("age"), 21)
+//
+//	// Check if the 'age' field is greater than or equal to an expression
+//	GreaterThanOrEqual(FieldOf("age"), FieldOf("minAge").Add(10))
+//
+//	// Check if the 'age' field is greater than or equal to the 'limit' field
+//	GreaterThanOrEqual("age", FieldOf("limit"))
+func GreaterThanOrEqual(left, right any) BooleanExpression {
+	return &baseBooleanExpression{baseFunction: leftRightToBaseFunction("greater_than_or_equal", left, right)}
+}
+
+// LessThan creates an expression that checks if a field's value or an expression is less than an expression or a constant value,
+// returning it as a [BooleanExpression].
+// - left: The field path string, [FieldPath] or [Expression] to compare.
+// - right: The constant value or [Expression] to compare to.
+//
+// Example:
+//
+//	// Check if the 'age' field is less than 21
+//	LessThan(FieldOf("age"), 21)
+//
+//	// Check if the 'age' field is less than an expression
+//	LessThan(FieldOf("age"), FieldOf("minAge").Add(10))
+//
+//	// Check if the 'age' field is less than the 'limit' field
+//	LessThan("age", FieldOf("limit"))
+func LessThan(left, right any) BooleanExpression {
+	return &baseBooleanExpression{baseFunction: leftRightToBaseFunction("less_than", left, right)}
+}
+
+// LessThanOrEqual creates an expression that checks if a field's value or an expression is less than or equal to an expression or a constant value,
+// returning it as a [BooleanExpression].
+// - left: The field path string, [FieldPath] or [Expression] to compare.
+// - right: The constant value or [Expression] to compare to.
+// +// Example: +// +// // Check if the 'age' field is less than or equal to 21 +// LessThanOrEqual(FieldOf("age"), 21) +// +// // Check if the 'age' field is less than or equal to an expression +// LessThanOrEqual(FieldOf("age"), FieldOf("minAge").Add(10)) +// +// // Check if the 'age' field is less than or equal to the 'limit' field +// LessThanOrEqual("age", FieldOf("limit")) +func LessThanOrEqual(left, right any) BooleanExpression { + return &baseBooleanExpression{baseFunction: leftRightToBaseFunction("less_than_or_equal", left, right)} +} + +// EndsWith creates an expression that checks if a string field or expression ends with a given suffix. +// - exprOrFieldPath can be a field path string, [FieldPath] or [Expression]. +// - suffix string or [Expression] to check for. +// +// Example: +// +// // Check if the 'filename' field ends with ".go". +// EndsWith("filename", ".go") +func EndsWith(exprOrFieldPath any, suffix any) BooleanExpression { + return &baseBooleanExpression{baseFunction: newBaseFunction("ends_with", []Expression{asFieldExpr(exprOrFieldPath), asStringExpr(suffix)})} +} + +// Like creates an expression that performs a case-sensitive wildcard string comparison. +// - exprOrFieldPath can be a field path string, [FieldPath] or [Expression]. +// - pattern string or [Expression] to search for. You can use "%" as a wildcard character. +// +// Example: +// +// // Check if the 'name' field starts with "G". +// Like("name", "G%") +func Like(exprOrFieldPath any, pattern any) BooleanExpression { + return &baseBooleanExpression{baseFunction: newBaseFunction("like", []Expression{asFieldExpr(exprOrFieldPath), asStringExpr(pattern)})} +} + +// RegexContains creates an expression that checks if a string contains a match for a regular expression. +// - exprOrFieldPath can be a field path string, [FieldPath] or [Expression]. +// - pattern is the regular expression to search for. +// +// Example: +// +// // Check if the 'email' field contains a gmail address. +// RegexContains("email", "@gmail\\.com$") +func RegexContains(exprOrFieldPath any, pattern any) BooleanExpression { + return &baseBooleanExpression{baseFunction: newBaseFunction("regex_contains", []Expression{asFieldExpr(exprOrFieldPath), asStringExpr(pattern)})} +} + +// RegexMatch creates an expression that checks if a string matches a regular expression. +// - exprOrFieldPath can be a field path string, [FieldPath] or [Expression]. +// - pattern is the regular expression to match against. +// +// Example: +// +// // Check if the 'zip_code' field is a 5-digit number. +// RegexMatch("zip_code", "^[0-9]{5}$") +func RegexMatch(exprOrFieldPath any, pattern any) BooleanExpression { + return &baseBooleanExpression{baseFunction: newBaseFunction("regex_match", []Expression{asFieldExpr(exprOrFieldPath), asStringExpr(pattern)})} +} + +// StartsWith creates an expression that checks if a string field or expression starts with a given prefix. +// - exprOrFieldPath can be a field path string, [FieldPath] or [Expression]. +// - prefix string or [Expression] to check for. +// +// Example: +// +// // Check if the 'name' field starts with "Mr.". +// StartsWith("name", "Mr.") +func StartsWith(exprOrFieldPath any, prefix any) BooleanExpression { + return &baseBooleanExpression{baseFunction: newBaseFunction("starts_with", []Expression{asFieldExpr(exprOrFieldPath), asStringExpr(prefix)})} +} + +// StringContains creates an expression that checks if a string contains a specified substring. 
+// - exprOrFieldPath can be a field path string, [FieldPath] or [Expression]. +// - substring is the string to search for. +// +// Example: +// +// // Check if the 'description' field contains the word "Firestore". +// StringContains("description", "Firestore") +func StringContains(exprOrFieldPath any, substring any) BooleanExpression { + return &baseBooleanExpression{baseFunction: newBaseFunction("string_contains", []Expression{asFieldExpr(exprOrFieldPath), asStringExpr(substring)})} +} + +// And creates an expression that performs a logical 'AND' operation. +func And(condition BooleanExpression, right ...BooleanExpression) BooleanExpression { + return &baseBooleanExpression{baseFunction: newBaseFunctionFromBooleans("and", append([]BooleanExpression{condition}, right...))} +} + +// FieldExists creates an expression that checks if a field exists. +func FieldExists(exprOrField any) BooleanExpression { + return &baseBooleanExpression{baseFunction: newBaseFunction("exists", []Expression{asFieldExpr(exprOrField)})} +} + +// Not creates an expression that negates a boolean expression. +func Not(condition BooleanExpression) BooleanExpression { + return &baseBooleanExpression{baseFunction: newBaseFunction("not", []Expression{condition})} +} + +// Or creates an expression that performs a logical 'OR' operation. +func Or(condition BooleanExpression, right ...BooleanExpression) BooleanExpression { + return &baseBooleanExpression{baseFunction: newBaseFunctionFromBooleans("or", append([]BooleanExpression{condition}, right...))} +} + +// Xor creates an expression that performs a logical 'XOR' operation. +func Xor(condition BooleanExpression, right ...BooleanExpression) BooleanExpression { + return &baseBooleanExpression{baseFunction: newBaseFunctionFromBooleans("xor", append([]BooleanExpression{condition}, right...))} +} + +// IsError creates an expression that checks if an expression evaluates to an error. +func IsError(expr Expression) BooleanExpression { + return &baseBooleanExpression{baseFunction: newBaseFunction("is_error", []Expression{expr})} +} + +// IsAbsent creates an expression that checks if an expression evaluates to an absent value. +func IsAbsent(exprOrField any) BooleanExpression { + return &baseBooleanExpression{baseFunction: newBaseFunction("is_absent", []Expression{asFieldExpr(exprOrField)})} +} diff --git a/firestore/pipeline_function.go b/firestore/pipeline_function.go new file mode 100644 index 000000000000..6fca3f134e98 --- /dev/null +++ b/firestore/pipeline_function.go @@ -0,0 +1,540 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package firestore + +import ( + "fmt" + + pb "cloud.google.com/go/firestore/apiv1/firestorepb" +) + +// FunctionExpression represents Firestore [Pipeline] functions, which can be evaluated within pipeline +// execution. 
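+//
+// For example (illustrative), values returned by package-level helpers such as
+// Add("a", 1) or ToUpper("name") are backed by FunctionExpression implementations.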
+type FunctionExpression interface {
+	Expression
+	isFunction()
+}
+
+type baseFunction struct {
+	*baseExpression
+}
+
+func (b *baseFunction) isFunction() {}
+
+// Ensure that *baseFunction implements the FunctionExpression interface.
+var _ FunctionExpression = (*baseFunction)(nil)
+
+func newBaseFunction(name string, params []Expression) *baseFunction {
+	argsPbVals := make([]*pb.Value, 0, len(params))
+	for i, param := range params {
+		paramExpr := asFieldExpr(param)
+		pbVal, err := paramExpr.toProto()
+		if err != nil {
+			return &baseFunction{baseExpression: &baseExpression{err: fmt.Errorf("firestore: error converting arg %d for function %q: %w", i, name, err)}}
+		}
+		argsPbVals = append(argsPbVals, pbVal)
+	}
+	pbVal := &pb.Value{ValueType: &pb.Value_FunctionValue{
+		FunctionValue: &pb.Function{
+			Name: name,
+			Args: argsPbVals,
+		},
+	}}
+
+	return &baseFunction{baseExpression: &baseExpression{pbVal: pbVal}}
+}
+
+func newBaseFunctionFromBooleans(name string, params []BooleanExpression) *baseFunction {
+	exprs := make([]Expression, len(params))
+	for i, p := range params {
+		exprs[i] = p
+	}
+	return newBaseFunction(name, exprs)
+}
+
+// Add creates an expression that adds two expressions together, returning it as an [Expression].
+// - left can be a field path string, [FieldPath] or [Expression].
+// - right can be a numeric constant or a numeric [Expression].
+func Add(left, right any) Expression {
+	return leftRightToBaseFunction("add", left, right)
+}
+
+// Subtract creates an expression that subtracts the right expression from the left expression, returning it as an [Expression].
+// - left can be a field path string, [FieldPath] or [Expression].
+// - right can be a constant or an [Expression].
+func Subtract(left, right any) Expression {
+	return leftRightToBaseFunction("subtract", left, right)
+}
+
+// Multiply creates an expression that multiplies the left and right expressions, returning it as an [Expression].
+// - left can be a field path string, [FieldPath] or [Expression].
+// - right can be a constant or an [Expression].
+func Multiply(left, right any) Expression {
+	return leftRightToBaseFunction("multiply", left, right)
+}
+
+// Divide creates an expression that divides the left expression by the right expression, returning it as an [Expression].
+// - left can be a field path string, [FieldPath] or [Expression].
+// - right can be a constant or an [Expression].
+func Divide(left, right any) Expression {
+	return leftRightToBaseFunction("divide", left, right)
+}
+
+// Abs creates an expression that is the absolute value of the input field or expression.
+// - numericExprOrFieldPath can be a field path string, [FieldPath] or an [Expression] that returns a number when evaluated.
+func Abs(numericExprOrFieldPath any) Expression {
+	return newBaseFunction("abs", []Expression{asFieldExpr(numericExprOrFieldPath)})
+}
+
+// Floor creates an expression that is the largest integer that isn't greater than the input field or expression.
+// - numericExprOrFieldPath can be a field path string, [FieldPath] or an [Expression] that returns a number when evaluated.
+func Floor(numericExprOrFieldPath any) Expression {
+	return newBaseFunction("floor", []Expression{asFieldExpr(numericExprOrFieldPath)})
+}
+
+// Ceil creates an expression that is the smallest integer that isn't less than the input field or expression.
+// - numericExprOrFieldPath can be a field path string, [FieldPath] or an [Expression] that returns a number when evaluated.
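+//
+// Example (illustrative field name):
+//
+//	// Round the 'price' field up to the nearest integer.
+//	Ceil("price")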
+func Ceil(numericExprOrFieldPath any) Expression {
+	return newBaseFunction("ceil", []Expression{asFieldExpr(numericExprOrFieldPath)})
+}
+
+// Exp creates an expression that is Euler's number e raised to the power of the input field or expression.
+// - numericExprOrFieldPath can be a field path string, [FieldPath] or an [Expression] that returns a number when evaluated.
+func Exp(numericExprOrFieldPath any) Expression {
+	return newBaseFunction("exp", []Expression{asFieldExpr(numericExprOrFieldPath)})
+}
+
+// Log creates an expression that computes the logarithm of the left expression to the base given by the right expression, returning it as an [Expression].
+// - left can be a field path string, [FieldPath] or [Expression].
+// - right can be a constant or an [Expression].
+func Log(left, right any) Expression {
+	return leftRightToBaseFunction("log", left, right)
+}
+
+// Log10 creates an expression that is the base 10 logarithm of the input field or expression.
+// - numericExprOrFieldPath can be a field path string, [FieldPath] or an [Expression] that returns a number when evaluated.
+func Log10(numericExprOrFieldPath any) Expression {
+	return newBaseFunction("log10", []Expression{asFieldExpr(numericExprOrFieldPath)})
+}
+
+// Ln creates an expression that is the natural logarithm (base e) of the input field or expression.
+// - numericExprOrFieldPath can be a field path string, [FieldPath] or an [Expression] that returns a number when evaluated.
+func Ln(numericExprOrFieldPath any) Expression {
+	return newBaseFunction("ln", []Expression{asFieldExpr(numericExprOrFieldPath)})
+}
+
+// Mod creates an expression that computes the modulo of the left expression by the right expression, returning it as an [Expression].
+// - left can be a field path string, [FieldPath] or [Expression].
+// - right can be a constant or an [Expression].
+func Mod(left, right any) Expression {
+	return leftRightToBaseFunction("mod", left, right)
+}
+
+// Pow creates an expression that computes the left expression raised to the power of the right expression, returning it as an [Expression].
+// - left can be a field path string, [FieldPath] or [Expression].
+// - right can be a constant or an [Expression].
+func Pow(left, right any) Expression {
+	return leftRightToBaseFunction("pow", left, right)
+}
+
+// Round creates an expression that rounds the input field or expression to the nearest integer.
+// - numericExprOrFieldPath can be a field path string, [FieldPath] or an [Expression] that returns a number when evaluated.
+func Round(numericExprOrFieldPath any) Expression {
+	return newBaseFunction("round", []Expression{asFieldExpr(numericExprOrFieldPath)})
+}
+
+// Sqrt creates an expression that is the square root of the input field or expression.
+// - numericExprOrFieldPath can be a field path string, [FieldPath] or an [Expression] that returns a number when evaluated.
+func Sqrt(numericExprOrFieldPath any) Expression {
+	return newBaseFunction("sqrt", []Expression{asFieldExpr(numericExprOrFieldPath)})
+}
+
+// TimestampAdd creates an expression that adds a specified amount of time to a timestamp.
+// - timestamp can be a field path string, [FieldPath] or [Expression].
+// - unit can be a string or an [Expression]. Valid units include "microsecond", "millisecond", "second", "minute", "hour" and "day".
+// - amount can be an int, int32, int64 or [Expression].
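+//
+// Example (illustrative field name):
+//
+//	// Add 1 day to the 'createdAt' timestamp.
+//	TimestampAdd("createdAt", "day", 1)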
+func TimestampAdd(timestamp, unit, amount any) Expression {
+	return newBaseFunction("timestamp_add", []Expression{asFieldExpr(timestamp), asStringExpr(unit), asInt64Expr(amount)})
+}
+
+// TimestampSubtract creates an expression that subtracts a specified amount of time from a timestamp.
+// - timestamp can be a field path string, [FieldPath] or [Expression].
+// - unit can be a string or an [Expression]. Valid units include "microsecond", "millisecond", "second", "minute", "hour" and "day".
+// - amount can be an int, int32, int64 or [Expression].
+func TimestampSubtract(timestamp, unit, amount any) Expression {
+	return newBaseFunction("timestamp_subtract", []Expression{asFieldExpr(timestamp), asStringExpr(unit), asInt64Expr(amount)})
+}
+
+// TimestampTruncate creates an expression that truncates a timestamp to a specified granularity.
+// - timestamp can be a field path string, [FieldPath] or [Expression].
+// - granularity can be a string or an [Expression]. Valid values are "microsecond",
+// "millisecond", "second", "minute", "hour", "day", "week", "week(monday)", "week(tuesday)",
+// "week(wednesday)", "week(thursday)", "week(friday)", "week(saturday)", "week(sunday)",
+// "isoweek", "month", "quarter", "year", and "isoyear".
+func TimestampTruncate(timestamp, granularity any) Expression {
+	return newBaseFunction("timestamp_trunc", []Expression{asFieldExpr(timestamp), asStringExpr(granularity)})
+}
+
+// TimestampTruncateWithTimezone creates an expression that truncates a timestamp to a specified granularity in a given timezone.
+// - timestamp can be a field path string, [FieldPath] or [Expression].
+// - granularity can be a string or an [Expression]. Valid values are "microsecond",
+// "millisecond", "second", "minute", "hour", "day", "week", "week(monday)", "week(tuesday)",
+// "week(wednesday)", "week(thursday)", "week(friday)", "week(saturday)", "week(sunday)",
+// "isoweek", "month", "quarter", "year", and "isoyear".
+// - timezone is a string. Valid values are from the TZ database
+// (e.g., "America/Los_Angeles") or in the format "Etc/GMT-1".
+func TimestampTruncateWithTimezone(timestamp, granularity any, timezone string) Expression {
+	return newBaseFunction("timestamp_trunc", []Expression{asFieldExpr(timestamp), asStringExpr(granularity), asStringExpr(timezone)})
+}
+
+// TimestampToUnixMicros creates an expression that converts a timestamp expression to the number of microseconds since
+// the Unix epoch (1970-01-01 00:00:00 UTC).
+// - timestamp can be a field path string, [FieldPath] or [Expression].
+func TimestampToUnixMicros(timestamp any) Expression {
+	return newBaseFunction("timestamp_to_unix_micros", []Expression{asFieldExpr(timestamp)})
+}
+
+// TimestampToUnixMillis creates an expression that converts a timestamp expression to the number of milliseconds since
+// the Unix epoch (1970-01-01 00:00:00 UTC).
+// - timestamp can be a field path string, [FieldPath] or [Expression].
+func TimestampToUnixMillis(timestamp any) Expression {
+	return newBaseFunction("timestamp_to_unix_millis", []Expression{asFieldExpr(timestamp)})
+}
+
+// TimestampToUnixSeconds creates an expression that converts a timestamp expression to the number of seconds since
+// the Unix epoch (1970-01-01 00:00:00 UTC).
+// - timestamp can be a field path string, [FieldPath] or [Expression].
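+//
+// Example (illustrative field name):
+//
+//	// Convert the 'createdAt' timestamp to seconds since the Unix epoch.
+//	TimestampToUnixSeconds("createdAt")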
+func TimestampToUnixSeconds(timestamp any) Expression { + return newBaseFunction("timestamp_to_unix_seconds", []Expression{asFieldExpr(timestamp)}) +} + +// UnixMicrosToTimestamp creates an expression that converts a Unix timestamp in microseconds to a Firestore timestamp. +// - micros can be a field path string, [FieldPath] or [Expression]. +func UnixMicrosToTimestamp(micros any) Expression { + return newBaseFunction("unix_micros_to_timestamp", []Expression{asFieldExpr(micros)}) +} + +// UnixMillisToTimestamp creates an expression that converts a Unix timestamp in milliseconds to a Firestore timestamp. +// - millis can be a field path string, [FieldPath] or [Expression]. +func UnixMillisToTimestamp(millis any) Expression { + return newBaseFunction("unix_millis_to_timestamp", []Expression{asFieldExpr(millis)}) +} + +// UnixSecondsToTimestamp creates an expression that converts a Unix timestamp in seconds to a Firestore timestamp. +// - seconds can be a field path string, [FieldPath] or [Expression]. +func UnixSecondsToTimestamp(seconds any) Expression { + return newBaseFunction("unix_seconds_to_timestamp", []Expression{asFieldExpr(seconds)}) +} + +// CurrentTimestamp creates an expression that returns the current timestamp. +func CurrentTimestamp() Expression { + return newBaseFunction("current_timestamp", []Expression{}) +} + +// ArrayLength creates an expression that calculates the length of an array. +// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expression] that evaluates to an array. +func ArrayLength(exprOrFieldPath any) Expression { + return newBaseFunction("array_length", []Expression{asFieldExpr(exprOrFieldPath)}) +} + +// Array creates an expression that represents a Firestore array. +// - elements can be any number of values or expressions that will form the elements of the array. +func Array(elements ...any) Expression { + return newBaseFunction("array", toExprs(elements)) +} + +// ArrayFromSlice creates a new array expression from a slice of elements. +// This function is necessary for creating an array from an existing typed slice (e.g., []int), +// as the [Array] function (which takes variadic arguments) cannot directly accept a typed slice +// using the spread operator (...). It handles the conversion of each element to `any` internally. +func ArrayFromSlice[T any](elements []T) Expression { + return newBaseFunction("array", toExprsFromSlice(elements)) +} + +// ArrayGet creates an expression that retrieves an element from an array at a specified index. +// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expression] that evaluates to an array. +// - offset is the 0-based index of the element to retrieve. +func ArrayGet(exprOrFieldPath any, offset any) Expression { + return newBaseFunction("array_get", []Expression{asFieldExpr(exprOrFieldPath), asInt64Expr(offset)}) +} + +// ArrayReverse creates an expression that reverses the order of elements in an array. +// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expression] that evaluates to an array. +func ArrayReverse(exprOrFieldPath any) Expression { + return newBaseFunction("array_reverse", []Expression{asFieldExpr(exprOrFieldPath)}) +} + +// ArrayConcat creates an expression that concatenates multiple arrays into a single array. +// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expression] that evaluates to an array. +// - otherArrays are the other arrays to concatenate. 
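+//
+// Example (illustrative field names):
+//
+//	// Concatenate the 'tags' array with the 'extraTags' array.
+//	ArrayConcat("tags", FieldOf("extraTags"))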
+func ArrayConcat(exprOrFieldPath any, otherArrays ...any) Expression { + return newBaseFunction("array_concat", append([]Expression{asFieldExpr(exprOrFieldPath)}, toExprs(otherArrays)...)) +} + +// ArraySum creates an expression that calculates the sum of all elements in a numeric array. +// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expression] that evaluates to a numeric array. +func ArraySum(exprOrFieldPath any) Expression { + return newBaseFunction("sum", []Expression{asFieldExpr(exprOrFieldPath)}) +} + +// ArrayMaximum creates an expression that finds the maximum element in a numeric array. +// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expression] that evaluates to a numeric array. +func ArrayMaximum(exprOrFieldPath any) Expression { + return newBaseFunction("maximum", []Expression{asFieldExpr(exprOrFieldPath)}) +} + +// ArrayMinimum creates an expression that finds the minimum element in a numeric array. +// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expression] that evaluates to a numeric array. +func ArrayMinimum(exprOrFieldPath any) Expression { + return newBaseFunction("minimum", []Expression{asFieldExpr(exprOrFieldPath)}) +} + +// ByteLength creates an expression that calculates the length of a string represented by a field or [Expression] in UTF-8 +// bytes. +// - exprOrFieldPath can be a field path string, [FieldPath] or [Expression]. +func ByteLength(exprOrFieldPath any) Expression { + return newBaseFunction("byte_length", []Expression{asFieldExpr(exprOrFieldPath)}) +} + +// CharLength creates an expression that calculates the character length of a string field or expression in UTF8. +// - exprOrFieldPath can be a field path string, [FieldPath] or [Expression]. +func CharLength(exprOrFieldPath any) Expression { + return newBaseFunction("char_length", []Expression{asFieldExpr(exprOrFieldPath)}) +} + +// StringConcat creates an expression that concatenates multiple strings into a single string. +// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expression] that evaluates to a string. +// - otherStrings are the other strings to concatenate. +func StringConcat(exprOrFieldPath any, otherStrings ...any) Expression { + return newBaseFunction("string_concat", append([]Expression{asFieldExpr(exprOrFieldPath)}, toExprs(otherStrings)...)) +} + +// StringReverse creates an expression that reverses a string. +// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expression] that evaluates to a string. +func StringReverse(exprOrFieldPath any) Expression { + return newBaseFunction("string_reverse", []Expression{asFieldExpr(exprOrFieldPath)}) +} + +// Join creates an expression that joins the elements of a string array into a single string. +// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expression] that evaluates to a string array. +// - delimiter is the string to use as a separator between elements. +func Join(exprOrFieldPath any, delimiter any) Expression { + return newBaseFunction("join", []Expression{asFieldExpr(exprOrFieldPath), asStringExpr(delimiter)}) +} + +// Substring creates an expression that returns a substring of a string. +// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expression] that evaluates to a string. +// - index is the starting index of the substring. +// - offset is the length of the substring. 
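+//
+// Example (illustrative field name):
+//
+//	// The first 5 characters of the 'title' field.
+//	Substring("title", 0, 5)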
+func Substring(exprOrFieldPath any, index any, offset any) Expression {
+	return newBaseFunction("substring", []Expression{asFieldExpr(exprOrFieldPath), asInt64Expr(index), asInt64Expr(offset)})
+}
+
+// ToLower creates an expression that converts a string to lowercase.
+// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expression] that evaluates to a string.
+func ToLower(exprOrFieldPath any) Expression {
+	return newBaseFunction("to_lower", []Expression{asFieldExpr(exprOrFieldPath)})
+}
+
+// ToUpper creates an expression that converts a string to uppercase.
+// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expression] that evaluates to a string.
+func ToUpper(exprOrFieldPath any) Expression {
+	return newBaseFunction("to_upper", []Expression{asFieldExpr(exprOrFieldPath)})
+}
+
+// Trim creates an expression that removes leading and trailing whitespace from a string.
+// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expression] that evaluates to a string.
+func Trim(exprOrFieldPath any) Expression {
+	return newBaseFunction("trim", []Expression{asFieldExpr(exprOrFieldPath)})
+}
+
+// Split creates an expression that splits a string by a delimiter.
+// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expression] that evaluates to a string.
+// - delimiter is the delimiter string to split on.
+func Split(exprOrFieldPath any, delimiter any) Expression {
+	return newBaseFunction("split", []Expression{asFieldExpr(exprOrFieldPath), asStringExpr(delimiter)})
+}
+
+// Type creates an expression that returns the type of the expression.
+// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expression].
+func Type(exprOrFieldPath any) Expression {
+	return newBaseFunction("type", []Expression{asFieldExpr(exprOrFieldPath)})
+}
+
+// CosineDistance creates an expression that calculates the cosine distance between two vectors.
+// - vector1 can be a field path string, [FieldPath] or [Expression].
+// - vector2 can be [Vector32], [Vector64], []float32, []float64 or [Expression].
+func CosineDistance(vector1 any, vector2 any) Expression {
+	return newBaseFunction("cosine_distance", []Expression{asFieldExpr(vector1), asVectorExpr(vector2)})
+}
+
+// DotProduct creates an expression that calculates the dot product of two vectors.
+// - vector1 can be a field path string, [FieldPath] or [Expression].
+// - vector2 can be [Vector32], [Vector64], []float32, []float64 or [Expression].
+func DotProduct(vector1 any, vector2 any) Expression {
+	return newBaseFunction("dot_product", []Expression{asFieldExpr(vector1), asVectorExpr(vector2)})
+}
+
+// EuclideanDistance creates an expression that calculates the Euclidean distance between two vectors.
+// - vector1 can be a field path string, [FieldPath] or [Expression].
+// - vector2 can be [Vector32], [Vector64], []float32, []float64 or [Expression].
+func EuclideanDistance(vector1 any, vector2 any) Expression {
+	return newBaseFunction("euclidean_distance", []Expression{asFieldExpr(vector1), asVectorExpr(vector2)})
+}
+
+// VectorLength creates an expression that calculates the length of a vector.
+// - exprOrFieldPath can be a field path string, [FieldPath] or [Expression].
+func VectorLength(exprOrFieldPath any) Expression {
+	return newBaseFunction("vector_length", []Expression{asFieldExpr(exprOrFieldPath)})
+}
+
+// Length creates an expression that calculates the length of a string, array, map or vector.
+// - exprOrField can be a field path string, [FieldPath] or an [Expression] that returns a string, array, map or vector when evaluated.
+//
+// Example:
+//
+//	// Length of the 'name' field.
+//	Length("name")
+func Length(exprOrField any) Expression {
+	return newBaseFunction("length", []Expression{asFieldExpr(exprOrField)})
+}
+
+// Reverse creates an expression that reverses a string or array.
+// - exprOrField can be a field path string, [FieldPath] or an [Expression] that returns a string or array when evaluated.
+//
+// Example:
+//
+//	// Reverse the 'name' field.
+//	Reverse("name")
+func Reverse(exprOrField any) Expression {
+	return newBaseFunction("reverse", []Expression{asFieldExpr(exprOrField)})
+}
+
+// Concat creates an expression that concatenates expressions together.
+// - exprOrField can be a field path string, [FieldPath] or an [Expression].
+// - others can be a list of constants or [Expression].
+//
+// Example:
+//
+//	// Concat the 'name' field with a constant string.
+//	Concat("name", "-suffix")
+func Concat(exprOrField any, others ...any) Expression {
+	return newBaseFunction("concat", append([]Expression{asFieldExpr(exprOrField)}, toArrayOfExprOrConstant(others)...))
+}
+
+// GetCollectionID creates an expression that returns the ID of the collection that contains the document.
+// - exprOrField can be a field path string, [FieldPath] or an [Expression] that evaluates to a field path.
+func GetCollectionID(exprOrField any) Expression {
+	return newBaseFunction("collection_id", []Expression{asFieldExpr(exprOrField)})
+}
+
+// GetDocumentID creates an expression that returns the ID of the document.
+// - exprStringOrDocRef can be a string, a [DocumentRef], or an [Expression] that evaluates to a document reference.
+func GetDocumentID(exprStringOrDocRef any) Expression {
+	var expr Expression
+	switch v := exprStringOrDocRef.(type) {
+	case string:
+		expr = ConstantOf(v)
+	case *DocumentRef:
+		expr = ConstantOf(v)
+	case Expression:
+		expr = v
+	default:
+		return &baseFunction{baseExpression: &baseExpression{err: fmt.Errorf("firestore: value must be a string, DocumentRef, or Expression, but got %T", exprStringOrDocRef)}}
+	}
+
+	return newBaseFunction("document_id", []Expression{expr})
+}
+
+// Conditional creates an expression that evaluates a condition and returns one of two expressions.
+// - condition is the boolean expression to evaluate.
+// - thenVal is the expression to return if the condition is true.
+// - elseVal is the expression to return if the condition is false.
+func Conditional(condition BooleanExpression, thenVal, elseVal any) Expression {
+	return newBaseFunction("conditional", []Expression{condition, toExprOrConstant(thenVal), toExprOrConstant(elseVal)})
+}
+
+// LogicalMaximum creates an expression that evaluates to the maximum value in a list of expressions.
+// - exprOrField can be a field path string, [FieldPath] or an [Expression].
+// - others can be a list of constants or [Expression].
+func LogicalMaximum(exprOrField any, others ...any) Expression {
+	return newBaseFunction("maximum", append([]Expression{asFieldExpr(exprOrField)}, toArrayOfExprOrConstant(others)...))
+}
+
+// LogicalMinimum creates an expression that evaluates to the minimum value in a list of expressions.
+// - exprOrField can be a field path string, [FieldPath] or an [Expression].
+// - others can be a list of constants or [Expression].
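+//
+// Example (illustrative field name):
+//
+//	// The smaller of the 'price' field and the constant 100.
+//	LogicalMinimum("price", 100)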
+func LogicalMinimum(exprOrField any, others ...any) Expression {
+	return newBaseFunction("minimum", append([]Expression{asFieldExpr(exprOrField)}, toArrayOfExprOrConstant(others)...))
+}
+
+// IfError creates an expression that evaluates and returns `tryExpr` if it does not produce an error;
+// otherwise, it evaluates and returns `catchExprOrValue`. It returns a new [Expression] representing
+// the if_error operation.
+// - tryExpr is the expression to try.
+// - catchExprOrValue is the expression or value to return if `tryExpr` errors.
+func IfError(tryExpr Expression, catchExprOrValue any) Expression {
+	return newBaseFunction("if_error", []Expression{tryExpr, toExprOrConstant(catchExprOrValue)})
+}
+
+// IfErrorBoolean creates a boolean expression that evaluates and returns `tryExpr` if it does not produce an error;
+// otherwise, it evaluates and returns `catchExpr`. It returns a new [BooleanExpression] representing
+// the if_error operation.
+// - tryExpr is the boolean expression to try.
+// - catchExpr is the boolean expression to return if `tryExpr` errors.
+func IfErrorBoolean(tryExpr BooleanExpression, catchExpr BooleanExpression) BooleanExpression {
+	return &baseBooleanExpression{baseFunction: newBaseFunction("if_error", []Expression{tryExpr, catchExpr})}
+}
+
+// IfAbsent creates an expression that returns a default value if an expression evaluates to an absent value.
+// - exprOrField can be a field path string, [FieldPath] or an [Expression].
+// - elseValue is the value to return if the expression is absent.
+func IfAbsent(exprOrField any, elseValue any) Expression {
+	return newBaseFunction("if_absent", []Expression{asFieldExpr(exprOrField), toExprOrConstant(elseValue)})
+}
+
+// Map creates an expression that constructs a Firestore map value from an input map.
+// - elements: The input map to evaluate in the expression.
+func Map(elements map[string]any) Expression {
+	exprs := make([]Expression, 0, len(elements)*2)
+	for k, v := range elements {
+		exprs = append(exprs, ConstantOf(k), toExprOrConstant(v))
+	}
+	return newBaseFunction("map", exprs)
+}
+
+// MapGet creates an expression that accesses a value from a map (object) field using the provided key.
+// - exprOrField: The expression representing the map.
+// - strOrExprkey: The key to access in the map.
+func MapGet(exprOrField any, strOrExprkey any) Expression {
+	return newBaseFunction("map_get", []Expression{asFieldExpr(exprOrField), asStringExpr(strOrExprkey)})
+}
+
+// MapMerge creates an expression that merges multiple maps into a single map.
+// If multiple maps have the same key, the later value is used.
+// - exprOrField: First map expression that will be merged.
+// - secondMap: Second map expression that will be merged.
+// - otherMaps: Additional maps to merge.
+func MapMerge(exprOrField any, secondMap Expression, otherMaps ...Expression) Expression {
+	return newBaseFunction("map_merge", append([]Expression{asFieldExpr(exprOrField), secondMap}, otherMaps...))
+}
+
+// MapRemove creates an expression that removes a key from a map.
+// - exprOrField: The expression representing the map.
+// - strOrExprkey: The key to remove from the map.
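+//
+// Example (illustrative field names):
+//
+//	// Remove the 'theme' key from the 'settings' map.
+//	MapRemove("settings", "theme")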
+func MapRemove(exprOrField any, strOrExprkey any) Expression { + return newBaseFunction("map_remove", []Expression{asFieldExpr(exprOrField), asStringExpr(strOrExprkey)}) +} diff --git a/firestore/pipeline_integration_test.go b/firestore/pipeline_integration_test.go new file mode 100644 index 000000000000..0d1550ff8d27 --- /dev/null +++ b/firestore/pipeline_integration_test.go @@ -0,0 +1,2409 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package firestore + +import ( + "context" + "fmt" + "math" + "sort" + "testing" + "time" + + "cloud.google.com/go/internal/testutil" + "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/api/iterator" + "google.golang.org/genproto/googleapis/type/latlng" +) + +func skipIfNotEnterprise(t *testing.T) { + if testParams[firestoreEditionKey].(firestoreEdition) != editionEnterprise { + t.Skip("Skipping test in non-enterprise environment") + } +} + +type Author struct { + Name string `firestore:"name"` + Country string `firestore:"country"` +} + +type Book struct { + Title string `firestore:"title"` + Author Author `firestore:"author"` + Genre string `firestore:"genre"` + Published int `firestore:"published"` + Rating float64 `firestore:"rating"` + Tags []string `firestore:"tags"` +} + +func testBooks() []Book { + return []Book{ + {Title: "The Hitchhiker's Guide to the Galaxy", Author: Author{Name: "Douglas Adams", Country: "UK"}, Genre: "Science Fiction", Published: 1979, Rating: 4.2, Tags: []string{"comedy", "space", "adventure"}}, + {Title: "Pride and Prejudice", Author: Author{Name: "Jane Austen", Country: "UK"}, Genre: "Romance", Published: 1813, Rating: 4.5, Tags: []string{"classic", "social commentary", "love"}}, + {Title: "One Hundred Years of Solitude", Author: Author{Name: "Gabriel García Márquez", Country: "Colombia"}, Genre: "Magical Realism", Published: 1967, Rating: 4.3, Tags: []string{"family", "history", "fantasy"}}, + {Title: "The Lord of the Rings", Author: Author{Name: "J.R.R. Tolkien", Country: "UK"}, Genre: "Fantasy", Published: 1954, Rating: 4.7, Tags: []string{"adventure", "magic", "epic"}}, + {Title: "The Handmaid's Tale", Author: Author{Name: "Margaret Atwood", Country: "Canada"}, Genre: "Dystopian", Published: 1985, Rating: 4.1, Tags: []string{"feminism", "totalitarianism", "resistance"}}, + {Title: "Crime and Punishment", Author: Author{Name: "Fyodor Dostoevsky", Country: "Russia"}, Genre: "Psychological Thriller", Published: 1866, Rating: 4.3, Tags: []string{"philosophy", "crime", "redemption"}}, + {Title: "To Kill a Mockingbird", Author: Author{Name: "Harper Lee", Country: "USA"}, Genre: "Southern Gothic", Published: 1960, Rating: 4.2, Tags: []string{"racism", "injustice", "coming-of-age"}}, + {Title: "1984", Author: Author{Name: "George Orwell", Country: "UK"}, Genre: "Dystopian", Published: 1949, Rating: 4.2, Tags: []string{"surveillance", "totalitarianism", "propaganda"}}, + {Title: "The Great Gatsby", Author: Author{Name: "F. 
+		{Title: "Dune", Author: Author{Name: "Frank Herbert", Country: "USA"}, Genre: "Science Fiction", Published: 1965, Rating: 4.6, Tags: []string{"politics", "desert", "ecology"}},
+	}
+}
+
+func TestIntegration_PipelineExecute(t *testing.T) {
+	skipIfNotEnterprise(t)
+	ctx := context.Background()
+	client := integrationClient(t)
+	coll := integrationColl(t)
+
+	t.Run("WithReadOptions", func(t *testing.T) {
+		timeBeforeCreate := time.Now()
+		doc1 := coll.NewDoc()
+		_, err := doc1.Create(ctx, map[string]interface{}{"a": 1})
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		// Let a little time pass to ensure the next write has a later timestamp.
+		time.Sleep(1 * time.Millisecond)
+
+		doc2 := coll.NewDoc()
+		_, err = doc2.Create(ctx, map[string]interface{}{"a": 2})
+		if err != nil {
+			t.Fatal(err)
+		}
+		t.Cleanup(func() {
+			deleteDocuments([]*DocumentRef{doc1, doc2})
+		})
+
+		iter := client.Pipeline().Collection(coll.ID).WithReadOptions(ReadTime(timeBeforeCreate)).Execute(ctx).Results()
+		res, err := iter.GetAll()
+		if err != nil {
+			t.Fatal(err)
+		}
+		if len(res) != 0 {
+			t.Errorf("got %d documents, want 0", len(res))
+		}
+	})
+	t.Run("WithTransaction", func(t *testing.T) {
+		h := testHelper{t}
+		books := testBooks()[:2]
+		var docRefs []*DocumentRef
+		for _, b := range books {
+			docRef := coll.NewDoc()
+			h.mustCreate(docRef, b)
+			docRefs = append(docRefs, docRef)
+		}
+		t.Cleanup(func() {
+			deleteDocuments(docRefs)
+		})
+		p := client.Pipeline().Collection(coll.ID)
+		err := client.RunTransaction(ctx, func(ctx context.Context, txn *Transaction) error {
+			iter := txn.Execute(p).Results()
+			res, err := iter.GetAll()
+			if err != nil {
+				return err
+			}
+			if len(res) != len(books) {
+				return fmt.Errorf("got %d documents, want %d", len(res), len(books))
+			}
+			return nil
+		})
+		if err != nil {
+			t.Fatal(err)
+		}
+	})
+}
+
+func TestIntegration_PipelineStages(t *testing.T) {
+	skipIfNotEnterprise(t)
+	ctx := context.Background()
+	client := integrationClient(t)
+	coll := integrationColl(t)
+	h := testHelper{t}
+	books := testBooks()
+	var docRefs []*DocumentRef
+	for _, b := range books {
+		docRef := coll.NewDoc()
+		h.mustCreate(docRef, b)
+		docRefs = append(docRefs, docRef)
+	}
+	t.Cleanup(func() {
+		deleteDocuments(docRefs)
+	})
+	t.Run("AddFields", func(t *testing.T) {
+		iter := client.Pipeline().Collection(coll.ID).AddFields(Multiply(FieldOf("rating"), 2).As("doubled_rating")).Limit(1).Execute(ctx).Results()
+		defer iter.Stop()
+		doc, err := iter.Next()
+		if err != nil {
+			t.Fatalf("Failed to iterate: %v", err)
+		}
+		if !doc.Exists() {
+			t.Fatalf("Exists: got: false, want: true")
+		}
+		data := doc.Data()
+		if dr, ok := data["doubled_rating"]; !ok || dr.(float64) != data["rating"].(float64)*2 {
+			t.Errorf("got doubled_rating %v, want %v", dr, data["rating"].(float64)*2)
+		}
+	})
+	t.Run("Aggregate", func(t *testing.T) {
+		iter := client.Pipeline().Collection(coll.ID).Aggregate(Count("rating").As("total_books")).Execute(ctx).Results()
+		defer iter.Stop()
+		doc, err := iter.Next()
+		if err != nil {
+			t.Fatalf("Failed to iterate: %v", err)
+		}
+
+		if !doc.Exists() {
+			t.Fatalf("Exists: got: false, want: true")
+		}
+		data := doc.Data()
+		if data["total_books"] != int64(10) {
+			t.Errorf("got %d total_books, want 10", data["total_books"])
+		}
+	})
+	t.Run("AggregateWithSpec", func(t *testing.T) {
+		spec := NewAggregateSpec(Average("rating").As("avg_rating")).WithGroups("genre")
+		iter := client.Pipeline().Collection(coll.ID).AggregateWithSpec(spec).Execute(ctx).Results()
+		defer iter.Stop()
+		results, err := iter.GetAll()
+		if err != nil {
+			t.Fatalf("Failed to iterate: %v", err)
+		}
+		if len(results) != 8 {
+			t.Errorf("got %d groups, want 8", len(results))
+		}
+	})
+	t.Run("Distinct", func(t *testing.T) {
+		iter := client.Pipeline().Collection(coll.ID).Distinct("genre").Execute(ctx).Results()
+		defer iter.Stop()
+		results, err := iter.GetAll()
+		if err != nil {
+			t.Fatalf("Failed to iterate: %v", err)
+		}
+		if len(results) != 8 {
+			t.Errorf("got %d distinct genres, want 8", len(results))
+		}
+	})
+	t.Run("Documents", func(t *testing.T) {
+		iter := client.Pipeline().Documents(docRefs[0], docRefs[1]).Execute(ctx).Results()
+		defer iter.Stop()
+		results, err := iter.GetAll()
+		if err != nil {
+			t.Fatalf("Failed to iterate: %v", err)
+		}
+		if 
len(results) != 2 { + t.Errorf("got %d documents, want 2", len(results)) + } + }) + t.Run("CollectionGroup", func(t *testing.T) { + cgCollID := collectionIDs.New() + doc1 := coll.Doc("cg_doc1") + doc2 := coll.Doc("cg_doc2") + cgColl1 := doc1.Collection(cgCollID) + cgColl2 := doc2.Collection(cgCollID) + cgDoc1 := cgColl1.NewDoc() + cgDoc2 := cgColl2.NewDoc() + h.mustCreate(cgDoc1, map[string]string{"val": "a"}) + h.mustCreate(cgDoc2, map[string]string{"val": "b"}) + t.Cleanup(func() { + deleteDocuments([]*DocumentRef{cgDoc1, cgDoc2, doc1, doc2}) + }) + iter := client.Pipeline().CollectionGroup(cgCollID).Execute(ctx).Results() + defer iter.Stop() + results, err := iter.GetAll() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if len(results) != 2 { + t.Errorf("got %d documents, want 2", len(results)) + } + }) + t.Run("Database", func(t *testing.T) { + dbDoc1 := coll.Doc("db_doc1") + otherColl := client.Collection(collectionIDs.New()) + dbDoc2 := otherColl.Doc("db_doc2") + h.mustCreate(dbDoc1, map[string]string{"val": "a"}) + h.mustCreate(dbDoc2, map[string]string{"val": "b"}) + t.Cleanup(func() { + deleteDocuments([]*DocumentRef{dbDoc1, dbDoc2}) + }) + iter := client.Pipeline().Database().Limit(2).Execute(ctx).Results() + defer iter.Stop() + results, err := iter.GetAll() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if len(results) != 2 { + t.Errorf("got %d documents, want 2", len(results)) + } + }) + t.Run("FindNearest", func(t *testing.T) { + type DocWithVector struct { + ID string `firestore:"id"` + Vector Vector32 `firestore:"vector"` + } + docsWithVector := []DocWithVector{ + {ID: "doc1", Vector: Vector32{1.0, 2.0, 3.0}}, + {ID: "doc2", Vector: Vector32{4.0, 5.0, 6.0}}, + {ID: "doc3", Vector: Vector32{7.0, 8.0, 9.0}}, + } + var vectorDocRefs []*DocumentRef + for _, d := range docsWithVector { + docRef := coll.NewDoc() + h.mustCreate(docRef, d) + vectorDocRefs = append(vectorDocRefs, docRef) + } + t.Cleanup(func() { + deleteDocuments(vectorDocRefs) + }) + queryVector := Vector32{1.1, 2.1, 3.1} + limit := 2 + distanceField := "distance" + options := &PipelineFindNearestOptions{ + Limit: &limit, + DistanceField: &distanceField, + } + iter := client.Pipeline().Collection(coll.ID). + FindNearest("vector", queryVector, PipelineDistanceMeasureEuclidean, options). 
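+			// With Limit 2 and DistanceField "distance", the two nearest documents are returned with their computed distance.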
+ Execute(ctx).Results() + defer iter.Stop() + results, err := iter.GetAll() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if len(results) != 2 { + t.Errorf("got %d documents, want 2", len(results)) + } + // Check if the results are sorted by distance + + if !results[0].Exists() { + t.Fatalf("results[0] Exists: got: false, want: true") + } + dist1 := results[0].Data() + + if !results[1].Exists() { + t.Fatalf("results[1] Exists: got: false, want: true") + } + dist2 := results[1].Data() + if dist1[distanceField].(float64) > dist2[distanceField].(float64) { + t.Errorf("documents are not sorted by distance") + } + // Check if the correct documents are returned + if dist1["id"] != "doc1" { + t.Errorf("got doc id %q, want 'doc1'", dist1["id"]) + } + }) + t.Run("Limit", func(t *testing.T) { + iter := client.Pipeline().Collection(coll.ID).Limit(3).Execute(ctx).Results() + defer iter.Stop() + results, err := iter.GetAll() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if len(results) != 3 { + t.Errorf("got %d documents, want 3", len(results)) + } + }) + t.Run("Offset", func(t *testing.T) { + iter := client.Pipeline().Collection(coll.ID).Sort(Ascending(FieldOf("published"))).Offset(2).Limit(1).Execute(ctx).Results() + defer iter.Stop() + doc, err := iter.Next() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if !doc.Exists() { + t.Fatalf("Exists: got: false, want: true") + } + data := doc.Data() + if data["title"] != "The Great Gatsby" { + t.Errorf("got title %q, want 'The Great Gatsby'", data["title"]) + } + }) + t.Run("RawStage", func(t *testing.T) { + // Using RawStage to perform a Limit operation + iter := client.Pipeline().Collection(coll.ID).RawStage(NewRawStage("limit").WithArguments(3)).Execute(ctx).Results() + defer iter.Stop() + results, err := iter.GetAll() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if len(results) != 3 { + t.Errorf("got %d documents, want 3", len(results)) + } + + // Using RawStage to perform a Select operation with options + iter = client.Pipeline().Collection(coll.ID).RawStage(NewRawStage("select").WithArguments(map[string]interface{}{"title": FieldOf("title")})).Limit(1).Execute(ctx).Results() + defer iter.Stop() + doc, err := iter.Next() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if !doc.Exists() { + t.Fatalf("Exists: got: false, want: true") + } + data := doc.Data() + if _, ok := data["title"]; !ok { + t.Error("missing 'title' field") + } + if _, ok := data["genre"]; ok { + t.Error("unexpected 'genre' field") + } + }) + t.Run("RemoveFields", func(t *testing.T) { + iter := client.Pipeline().Collection(coll.ID). + Limit(1). + RemoveFields("genre", "rating"). 
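+			// Fields not listed here (e.g. "title") remain on the returned documents.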
+ Execute(ctx).Results() + defer iter.Stop() + doc, err := iter.Next() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if !doc.Exists() { + t.Fatalf("Exists: got: false, want: true") + } + data := doc.Data() + if _, ok := data["genre"]; ok { + t.Error("unexpected 'genre' field") + } + if _, ok := data["rating"]; ok { + t.Error("unexpected 'rating' field") + } + if _, ok := data["title"]; !ok { + t.Error("missing 'title' field") + } + }) + t.Run("Replace", func(t *testing.T) { + type DocWithMap struct { + ID string `firestore:"id"` + Data map[string]int `firestore:"data"` + } + docWithMap := DocWithMap{ID: "docWithMap", Data: map[string]int{"a": 1, "b": 2}} + docRef := coll.NewDoc() + h.mustCreate(docRef, docWithMap) + t.Cleanup(func() { + deleteDocuments([]*DocumentRef{docRef}) + }) + iter := client.Pipeline().Collection(coll.ID). + Where(Equal(FieldOf("id"), "docWithMap")). + ReplaceWith("data"). + Execute(ctx).Results() + defer iter.Stop() + doc, err := iter.Next() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if !doc.Exists() { + t.Fatalf("Exists: got: false, want: true") + } + data := doc.Data() + want := map[string]interface{}{"a": int64(1), "b": int64(2)} + if diff := testutil.Diff(data, want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", data, want, diff) + } + }) + t.Run("Sample", func(t *testing.T) { + t.Run("SampleByDocuments", func(t *testing.T) { + iter := client.Pipeline().Collection(coll.ID).Sample(SampleByDocuments(5)).Execute(ctx).Results() + defer iter.Stop() + var got []map[string]interface{} + for { + doc, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if !doc.Exists() { + t.Fatalf("Exists: got: false, want: true") + } + data := doc.Data() + got = append(got, data) + } + if len(got) != 5 { + t.Errorf("got %d documents, want 5", len(got)) + } + }) + t.Run("SampleByPercentage", func(t *testing.T) { + iter := client.Pipeline().Collection(coll.ID).Sample(&SampleSpec{Size: 0.6, Mode: SampleModePercent}).Execute(ctx).Results() + defer iter.Stop() + var got []map[string]interface{} + for { + doc, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if !doc.Exists() { + t.Fatalf("Exists: got: false, want: true") + } + data := doc.Data() + got = append(got, data) + } + if len(got) >= 10 { + t.Errorf("Sampled documents count should be less than total. got %d, total 10", len(got)) + } + if len(got) == 0 { + t.Errorf("Sampled documents count should be greater than 0. 
got %d", len(got)) + } + }) + }) + t.Run("Select", func(t *testing.T) { + iter := client.Pipeline().Collection(coll.ID).Select("title", "author.name").Limit(1).Execute(ctx).Results() + defer iter.Stop() + doc, err := iter.Next() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if !doc.Exists() { + t.Fatalf("Exists: got: false, want: true") + } + data := doc.Data() + if _, ok := data["title"]; !ok { + t.Error("missing 'title' field") + } + if _, ok := data["author.name"]; !ok { + t.Error("missing 'author.name' field") + } + if _, ok := data["author"]; ok { + t.Error("unexpected 'author' field") + } + if _, ok := data["genre"]; ok { + t.Error("unexpected 'genre' field") + } + }) + t.Run("Sort", func(t *testing.T) { + iter := client.Pipeline().Collection(coll.ID).Sort(Descending(FieldOf("rating"))).Limit(1).Execute(ctx).Results() + defer iter.Stop() + doc, err := iter.Next() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if !doc.Exists() { + t.Fatalf("Exists: got: false, want: true") + } + data := doc.Data() + if data["title"] != "The Lord of the Rings" { + t.Errorf("got title %q, want 'The Lord of the Rings'", data["title"]) + } + }) + t.Run("Union", func(t *testing.T) { + type Employee struct { + Name string `firestore:"name"` + Age int `firestore:"age"` + } + type Customer struct { + Name string `firestore:"name"` + Address string `firestore:"address"` + } + employeeColl := client.Collection(collectionIDs.New()) + customerColl := client.Collection(collectionIDs.New()) + employees := []Employee{ + {Name: "John Doe", Age: 42}, + {Name: "Jane Smith", Age: 35}, + } + customers := []Customer{ + {Name: "Alice", Address: "123 Main St"}, + {Name: "Bob", Address: "456 Oak Ave"}, + } + var unionDocRefs []*DocumentRef + for _, e := range employees { + docRef := employeeColl.NewDoc() + h.mustCreate(docRef, e) + unionDocRefs = append(unionDocRefs, docRef) + } + for _, c := range customers { + docRef := customerColl.NewDoc() + h.mustCreate(docRef, c) + unionDocRefs = append(unionDocRefs, docRef) + } + t.Cleanup(func() { + deleteDocuments(unionDocRefs) + }) + employeePipeline := client.Pipeline().Collection(employeeColl.ID) + customerPipeline := client.Pipeline().Collection(customerColl.ID) + iter := employeePipeline.Union(customerPipeline).Execute(context.Background()).Results() + defer iter.Stop() + var got []map[string]interface{} + for { + doc, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if !doc.Exists() { + t.Fatalf("Exists: got: false, want: true") + } + data := doc.Data() + got = append(got, data) + } + want := []map[string]interface{}{ + {"name": "John Doe", "age": int64(42)}, + {"name": "Jane Smith", "age": int64(35)}, + {"name": "Alice", "address": "123 Main St"}, + {"name": "Bob", "address": "456 Oak Ave"}, + } + sort.Slice(got, func(i, j int) bool { + return got[i]["name"].(string) < got[j]["name"].(string) + }) + sort.Slice(want, func(i, j int) bool { + return want[i]["name"].(string) < want[j]["name"].(string) + }) + if diff := testutil.Diff(got, want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", got, want, diff) + } + }) + t.Run("Unnest", func(t *testing.T) { + iter := client.Pipeline().Collection(coll.ID). + Where(Equal(FieldOf("title"), "The Hitchhiker's Guide to the Galaxy")). + UnnestWithAlias("tags", "tag", nil). + Select("title", "tag"). 
+ Execute(ctx).Results() + defer iter.Stop() + var got []map[string]interface{} + for { + doc, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if !doc.Exists() { + t.Fatalf("Exists: got: false, want: true") + } + data := doc.Data() + got = append(got, data) + } + want := []map[string]interface{}{ + {"title": "The Hitchhiker's Guide to the Galaxy", "tag": "comedy"}, + {"title": "The Hitchhiker's Guide to the Galaxy", "tag": "space"}, + {"title": "The Hitchhiker's Guide to the Galaxy", "tag": "adventure"}, + } + sort.Slice(got, func(i, j int) bool { + return got[i]["tag"].(string) < got[j]["tag"].(string) + }) + sort.Slice(want, func(i, j int) bool { + return want[i]["tag"].(string) < want[j]["tag"].(string) + }) + if diff := testutil.Diff(got, want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", got, want, diff) + } + }) + t.Run("UnnestWithIndexField", func(t *testing.T) { + iter := client.Pipeline().Collection(coll.ID). + Where(Equal(FieldOf("title"), "The Hitchhiker's Guide to the Galaxy")). + UnnestWithAlias("tags", "tag", &UnnestOptions{IndexField: "tagIndex"}). + Select("title", "tag", "tagIndex"). + Execute(ctx).Results() + defer iter.Stop() + var got []map[string]interface{} + for { + doc, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if !doc.Exists() { + t.Fatalf("Exists: got: false, want: true") + } + data := doc.Data() + got = append(got, data) + } + want := []map[string]interface{}{ + {"title": "The Hitchhiker's Guide to the Galaxy", "tag": "comedy", "tagIndex": int64(0)}, + {"title": "The Hitchhiker's Guide to the Galaxy", "tag": "space", "tagIndex": int64(1)}, + {"title": "The Hitchhiker's Guide to the Galaxy", "tag": "adventure", "tagIndex": int64(2)}, + } + sort.Slice(got, func(i, j int) bool { + return got[i]["tagIndex"].(int64) < got[j]["tagIndex"].(int64) + }) + if diff := testutil.Diff(got, want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", got, want, diff) + } + }) + t.Run("Where", func(t *testing.T) { + iter := client.Pipeline().Collection(coll.ID).Where(Equal(FieldOf("author.country"), "UK")).Execute(ctx).Results() + defer iter.Stop() + results, err := iter.GetAll() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if len(results) != 4 { + t.Errorf("got %d documents, want 4", len(results)) + } + }) +} + +func TestIntegration_PipelineFunctions(t *testing.T) { + skipIfNotEnterprise(t) + t.Run("arrayFuncs", arrayFuncs) + t.Run("stringFuncs", stringFuncs) + t.Run("vectorFuncs", vectorFuncs) + + t.Run("timestampFuncs", timestampFuncs) + t.Run("arithmeticFuncs", arithmeticFuncs) + t.Run("aggregateFuncs", aggregateFuncs) + t.Run("comparisonFuncs", comparisonFuncs) + t.Run("generalFuncs", generalFuncs) + t.Run("keyFuncs", keyFuncs) + t.Run("objectFuncs", objectFuncs) + t.Run("logicalFuncs", logicalFuncs) + t.Run("typeFuncs", typeFuncs) +} + +func typeFuncs(t *testing.T) { + t.Parallel() + h := testHelper{t} + client := integrationClient(t) + coll := client.Collection(collectionIDs.New()) + docRef1 := coll.NewDoc() + h.mustCreate(docRef1, map[string]interface{}{ + "a": nil, + "b": true, + "c": 1, + "d": "hello", + "e": []byte("world"), + "f": time.Now(), + "g": &latlng.LatLng{Latitude: 32.1, Longitude: -4.5}, + "h": []interface{}{1, 2, 3}, + "i": map[string]interface{}{"j": 1}, + "k": Vector64{1, 2, 3}, + "l": docRef1, + }) + defer 
deleteDocuments([]*DocumentRef{docRef1}) + + tests := []struct { + name string + pipeline *Pipeline + want map[string]interface{} + }{ + { + name: "Type of null", + pipeline: client.Pipeline().Collection(coll.ID).Select(Type("a").As("type")), + want: map[string]interface{}{"type": "null"}, + }, + { + name: "Type of boolean", + pipeline: client.Pipeline().Collection(coll.ID).Select(Type("b").As("type")), + want: map[string]interface{}{"type": "boolean"}, + }, + { + name: "Type of int64", + pipeline: client.Pipeline().Collection(coll.ID).Select(Type("c").As("type")), + want: map[string]interface{}{"type": "int64"}, + }, + { + name: "Type of string", + pipeline: client.Pipeline().Collection(coll.ID).Select(Type("d").As("type")), + want: map[string]interface{}{"type": "string"}, + }, + { + name: "Type of bytes", + pipeline: client.Pipeline().Collection(coll.ID).Select(Type("e").As("type")), + want: map[string]interface{}{"type": "bytes"}, + }, + { + name: "Type of timestamp", + pipeline: client.Pipeline().Collection(coll.ID).Select(Type("f").As("type")), + want: map[string]interface{}{"type": "timestamp"}, + }, + { + name: "Type of geopoint", + pipeline: client.Pipeline().Collection(coll.ID).Select(Type("g").As("type")), + want: map[string]interface{}{"type": "geo_point"}, + }, + { + name: "Type of array", + pipeline: client.Pipeline().Collection(coll.ID).Select(Type("h").As("type")), + want: map[string]interface{}{"type": "array"}, + }, + { + name: "Type of map", + pipeline: client.Pipeline().Collection(coll.ID).Select(Type("i").As("type")), + want: map[string]interface{}{"type": "map"}, + }, + { + name: "Type of vector", + pipeline: client.Pipeline().Collection(coll.ID).Select(Type("k").As("type")), + want: map[string]interface{}{"type": "vector"}, + }, + { + name: "Type of reference", + pipeline: client.Pipeline().Collection(coll.ID).Select(Type("l").As("type")), + want: map[string]interface{}{"type": "reference"}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx := context.Background() + iter := test.pipeline.Execute(ctx).Results() + defer iter.Stop() + + docs, err := iter.GetAll() + if err != nil { + t.Fatalf("GetAll: %v", err) + return + } + if len(docs) != 1 { + t.Fatalf("expected 1 doc, got %d", len(docs)) + } + got := docs[0].Data() + if diff := testutil.Diff(got, test.want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", got, test.want, diff) + } + }) + } +} + +func TestIntegration_Query_Pipeline(t *testing.T) { + skipIfNotEnterprise(t) + ctx := context.Background() + coll := integrationColl(t) + h := testHelper{t} + type Book struct { + Title string `firestore:"title"` + Genre string `firestore:"genre"` + Published int `firestore:"published"` + Rating float64 `firestore:"rating"` + } + books := []Book{ + {Title: "The Hitchhiker's Guide to the Galaxy", Genre: "Science Fiction", Published: 1979, Rating: 4.2}, + {Title: "Pride and Prejudice", Genre: "Romance", Published: 1813, Rating: 4.5}, + {Title: "One Hundred Years of Solitude", Genre: "Magical Realism", Published: 1967, Rating: 4.3}, + } + var docRefs []*DocumentRef + for _, b := range books { + docRef := coll.NewDoc() + h.mustCreate(docRef, b) + docRefs = append(docRefs, docRef) + } + t.Cleanup(func() { + deleteDocuments(docRefs) + }) + + t.Run("Where", func(t *testing.T) { + q := coll.Where("published", ">", 1900) + p := q.Pipeline() + iter := p.Execute(ctx).Results() + defer iter.Stop() + res, err := iter.GetAll() + if err != nil { + t.Fatalf("Failed to iterate: %v", 
err) + } + if len(res) != 2 { + t.Errorf("got %d documents, want 2", len(res)) + } + }) + + t.Run("OrderBy", func(t *testing.T) { + q := coll.OrderBy("published", Asc) + p := q.Pipeline() + iter := p.Execute(ctx).Results() + defer iter.Stop() + res, err := iter.GetAll() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if len(res) != 3 { + t.Errorf("got %d documents, want 3", len(res)) + } + var publishedYears []int64 + for _, r := range res { + publishedYears = append(publishedYears, r.Data()["published"].(int64)) + } + if !sort.SliceIsSorted(publishedYears, func(i, j int) bool { return publishedYears[i] < publishedYears[j] }) { + t.Errorf("results not sorted by published year: %v", publishedYears) + } + }) + + t.Run("Limit", func(t *testing.T) { + q := coll.Limit(2) + p := q.Pipeline() + iter := p.Execute(ctx).Results() + defer iter.Stop() + res, err := iter.GetAll() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if len(res) != 2 { + t.Errorf("got %d documents, want 2", len(res)) + } + }) + + t.Run("Offset", func(t *testing.T) { + q := coll.OrderBy("published", Asc).Offset(1) + p := q.Pipeline() + iter := p.Execute(ctx).Results() + defer iter.Stop() + res, err := iter.GetAll() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if len(res) != 2 { + t.Errorf("got %d documents, want 2", len(res)) + } + }) + + t.Run("Select", func(t *testing.T) { + q := coll.Select("title") + p := q.Pipeline() + iter := p.Execute(ctx).Results() + defer iter.Stop() + doc, err := iter.Next() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + data := doc.Data() + if _, ok := data["title"]; !ok { + t.Error("missing 'title' field") + } + if _, ok := data["genre"]; ok { + t.Error("unexpected 'genre' field") + } + }) +} + +func TestIntegration_AggregationQuery_Pipeline(t *testing.T) { + skipIfNotEnterprise(t) + ctx := context.Background() + coll := integrationColl(t) + h := testHelper{t} + type Book struct { + Title string `firestore:"title"` + Genre string `firestore:"genre"` + Published int `firestore:"published"` + Rating float64 `firestore:"rating"` + } + books := []Book{ + {Title: "The Hitchhiker's Guide to the Galaxy", Genre: "Science Fiction", Published: 1979, Rating: 4.2}, + {Title: "Pride and Prejudice", Genre: "Romance", Published: 1813, Rating: 4.5}, + {Title: "One Hundred Years of Solitude", Genre: "Magical Realism", Published: 1967, Rating: 4.3}, + } + var docRefs []*DocumentRef + for _, b := range books { + docRef := coll.NewDoc() + h.mustCreate(docRef, b) + docRefs = append(docRefs, docRef) + } + t.Cleanup(func() { + deleteDocuments(docRefs) + }) + + t.Run("Count", func(t *testing.T) { + ag := coll.NewAggregationQuery().WithCount("count") + p := ag.Pipeline() + iter := p.Execute(ctx).Results() + defer iter.Stop() + doc, err := iter.Next() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + + if !doc.Exists() { + t.Fatalf("Exists: got: false, want: true") + } + data := doc.Data() + if data["count"] != int64(3) { + t.Errorf("got %d count, want 3", data["count"]) + } + }) + + t.Run("Sum", func(t *testing.T) { + ag := coll.NewAggregationQuery().WithSum("published", "total_published") + p := ag.Pipeline() + iter := p.Execute(ctx).Results() + defer iter.Stop() + doc, err := iter.Next() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + + if !doc.Exists() { + t.Fatalf("Exists: got: false, want: true") + } + data := doc.Data() + if data["total_published"] != int64(1979+1813+1967) { + t.Errorf("got %d total_published, want 
%d", data["total_published"], int64(1979+1813+1967)) + } + }) + + t.Run("Average", func(t *testing.T) { + ag := coll.NewAggregationQuery().WithAvg("rating", "avg_rating") + p := ag.Pipeline() + iter := p.Execute(ctx).Results() + defer iter.Stop() + doc, err := iter.Next() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + + if !doc.Exists() { + t.Fatalf("Exists: got: false, want: true") + } + data := doc.Data() + if data["avg_rating"] != (4.2+4.5+4.3)/3 { + t.Errorf("got %f avg_rating, want %f", data["avg_rating"], (4.2+4.5+4.3)/3) + } + }) +} + +func objectFuncs(t *testing.T) { + t.Parallel() + h := testHelper{t} + client := integrationClient(t) + coll := client.Collection(collectionIDs.New()) + docRef1 := coll.NewDoc() + h.mustCreate(docRef1, map[string]interface{}{ + "m1": map[string]interface{}{"a": 1, "b": 2}, + "m2": map[string]interface{}{"c": 3, "d": 4}, + }) + defer deleteDocuments([]*DocumentRef{docRef1}) + + tests := []struct { + name string + pipeline *Pipeline + want map[string]interface{} + }{ + { + name: "Map", + pipeline: client.Pipeline().Collection(coll.ID).Select(Map(map[string]any{"a": 1, "b": 2}).As("map")), + want: map[string]interface{}{"map": map[string]interface{}{"a": int64(1), "b": int64(2)}}, + }, + { + name: "MapGet", + pipeline: client.Pipeline().Collection(coll.ID).Select(MapGet("m1", "a").As("value")), + want: map[string]interface{}{"value": int64(1)}, + }, + { + name: "MapMerge", + pipeline: client.Pipeline().Collection(coll.ID).Select(MapMerge("m1", FieldOf("m2")).As("merged")), + want: map[string]interface{}{"merged": map[string]interface{}{"a": int64(1), "b": int64(2), "c": int64(3), "d": int64(4)}}, + }, + { + name: "MapRemove", + pipeline: client.Pipeline().Collection(coll.ID).Select(MapRemove("m1", "a").As("removed")), + want: map[string]interface{}{"removed": map[string]interface{}{"b": int64(2)}}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx := context.Background() + iter := test.pipeline.Execute(ctx).Results() + defer iter.Stop() + + docs, err := iter.GetAll() + if err != nil { + t.Fatalf("GetAll: %v", err) + return + } + if len(docs) != 1 { + t.Fatalf("expected 1 doc, got %d", len(docs)) + } + got := docs[0].Data() + if diff := testutil.Diff(got, test.want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", got, test.want, diff) + } + }) + } +} + +func arrayFuncs(t *testing.T) { + t.Parallel() + h := testHelper{t} + client := integrationClient(t) + coll := client.Collection(collectionIDs.New()) + docRef1 := coll.NewDoc() + h.mustCreate(docRef1, map[string]interface{}{ + "a": []interface{}{1, 2, 3}, + "b": []interface{}{4, 5, 6}, + "tags": []string{"Go", "Firestore", "GCP"}, + "tags2": []string{"Go", "Firestore"}, + "lang": "Go", + "status": "active", + }) + defer deleteDocuments([]*DocumentRef{docRef1}) + + tests := []struct { + name string + pipeline *Pipeline + want map[string]interface{} + }{ + { + name: "ArrayLength", + pipeline: client.Pipeline().Collection(coll.ID).Select(ArrayLength("a").As("length")), + want: map[string]interface{}{"length": int64(3)}, + }, + { + name: "Array", + pipeline: client.Pipeline().Collection(coll.ID).Select(Array(1, 2, 3).As("array")), + want: map[string]interface{}{"array": []interface{}{int64(1), int64(2), int64(3)}}, + }, + { + name: "ArrayFromSlice", + pipeline: client.Pipeline().Collection(coll.ID).Select(ArrayFromSlice([]int{1, 2, 3}).As("array")), + want: map[string]interface{}{"array": []interface{}{int64(1), int64(2), int64(3)}}, + 
}, + { + name: "ArrayGet", + pipeline: client.Pipeline().Collection(coll.ID).Select(ArrayGet("a", 1).As("element")), + want: map[string]interface{}{"element": int64(2)}, + }, + { + name: "ArrayReverse", + pipeline: client.Pipeline().Collection(coll.ID).Select(ArrayReverse("a").As("reversed")), + want: map[string]interface{}{"reversed": []interface{}{int64(3), int64(2), int64(1)}}, + }, + { + name: "ArrayConcat", + pipeline: client.Pipeline().Collection(coll.ID).Select(ArrayConcat("a", FieldOf("b")).As("concatenated")), + want: map[string]interface{}{"concatenated": []interface{}{int64(1), int64(2), int64(3), int64(4), int64(5), int64(6)}}, + }, + { + name: "ArraySum", + pipeline: client.Pipeline().Collection(coll.ID).Select(ArraySum("a").As("sum")), + want: map[string]interface{}{"sum": int64(6)}, + }, + { + name: "ArrayMaximum", + pipeline: client.Pipeline().Collection(coll.ID).Select(ArrayMaximum("a").As("max")), + want: map[string]interface{}{"max": int64(3)}, + }, + { + name: "ArrayMinimum", + pipeline: client.Pipeline().Collection(coll.ID).Select(ArrayMinimum("a").As("min")), + want: map[string]interface{}{"min": int64(1)}, + }, + // Array filter conditions + { + name: "ArrayContains", + pipeline: client.Pipeline().Collection(coll.ID).Where(ArrayContains("tags", "Go")), + want: map[string]interface{}{"lang": "Go", "tags": []interface{}{"Go", "Firestore", "GCP"}, "tags2": []interface{}{"Go", "Firestore"}, "status": "active", "a": []interface{}{int64(1), int64(2), int64(3)}, "b": []interface{}{int64(4), int64(5), int64(6)}}, + }, + { + name: "ArrayContainsAll - array of mixed types", + pipeline: client.Pipeline().Collection(coll.ID).Where(ArrayContainsAll("tags", []any{FieldOf("lang"), "Firestore"})), + want: map[string]interface{}{"lang": "Go", "tags": []interface{}{"Go", "Firestore", "GCP"}, "tags2": []interface{}{"Go", "Firestore"}, "status": "active", "a": []interface{}{int64(1), int64(2), int64(3)}, "b": []interface{}{int64(4), int64(5), int64(6)}}, + }, + { + name: "ArrayContainsAll - array of constants", + pipeline: client.Pipeline().Collection(coll.ID).Where(ArrayContainsAll("tags", []string{"Go", "Firestore"})), + want: map[string]interface{}{"lang": "Go", "tags": []interface{}{"Go", "Firestore", "GCP"}, "tags2": []interface{}{"Go", "Firestore"}, "status": "active", "a": []interface{}{int64(1), int64(2), int64(3)}, "b": []interface{}{int64(4), int64(5), int64(6)}}, + }, + { + name: "ArrayContainsAll - Expr", + pipeline: client.Pipeline().Collection(coll.ID).Where(ArrayContainsAll("tags", FieldOf("tags2"))), + want: map[string]interface{}{"lang": "Go", "tags": []interface{}{"Go", "Firestore", "GCP"}, "tags2": []interface{}{"Go", "Firestore"}, "status": "active", "a": []interface{}{int64(1), int64(2), int64(3)}, "b": []interface{}{int64(4), int64(5), int64(6)}}, + }, + { + name: "ArrayContainsAny", + pipeline: client.Pipeline().Collection(coll.ID).Where(ArrayContainsAny("tags", []string{"Go", "Java"})), + want: map[string]interface{}{"lang": "Go", "tags": []interface{}{"Go", "Firestore", "GCP"}, "tags2": []interface{}{"Go", "Firestore"}, "status": "active", "a": []interface{}{int64(1), int64(2), int64(3)}, "b": []interface{}{int64(4), int64(5), int64(6)}}, + }, + { + name: "EqualAny", + pipeline: client.Pipeline().Collection(coll.ID).Where(EqualAny("status", []string{"active", "pending"})), + want: map[string]interface{}{"lang": "Go", "tags": []interface{}{"Go", "Firestore", "GCP"}, "tags2": []interface{}{"Go", "Firestore"}, "status": "active", "a": []interface{}{int64(1), 
int64(2), int64(3)}, "b": []interface{}{int64(4), int64(5), int64(6)}},
+		},
+		{
+			name:     "NotEqualAny",
+			pipeline: client.Pipeline().Collection(coll.ID).Where(NotEqualAny("status", []string{"archived", "deleted"})),
+			want:     map[string]interface{}{"lang": "Go", "tags": []interface{}{"Go", "Firestore", "GCP"}, "tags2": []interface{}{"Go", "Firestore"}, "status": "active", "a": []interface{}{int64(1), int64(2), int64(3)}, "b": []interface{}{int64(4), int64(5), int64(6)}},
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			testutil.Retry(t, 3, time.Second, func(r *testutil.R) {
+				ctx := context.Background()
+				iter := test.pipeline.Execute(ctx).Results()
+				defer iter.Stop()
+
+				// Report failures through r, not t: t.Fatalf/t.Errorf inside
+				// testutil.Retry would end the test immediately and defeat the retry.
+				docs, err := iter.GetAll()
+				if err != nil {
+					r.Errorf("GetAll: %v", err)
+					return
+				}
+				if len(docs) != 1 {
+					r.Errorf("expected 1 doc, got %d", len(docs))
+					return
+				}
+				got := docs[0].Data()
+				if diff := testutil.Diff(got, test.want); diff != "" {
+					r.Errorf("got: %v, want: %v, diff +want -got: %s", got, test.want, diff)
+					return
+				}
+			})
+		})
+	}
+}
+
+func stringFuncs(t *testing.T) {
+	t.Parallel()
+	h := testHelper{t}
+	client := integrationClient(t)
+	coll := client.Collection(collectionIDs.New())
+	docRef1 := coll.NewDoc()
+	h.mustCreate(docRef1, map[string]interface{}{
+		"name":        " John Doe ",
+		"description": "This is a Firestore document.",
+		"productCode": "abc-123",
+		"tags":        []string{"tag1", "tag2", "tag3"},
+		"email":       "john.doe@example.com",
+		"zipCode":     "12345",
+		"csv":         "a,b,c",
+	})
+	defer deleteDocuments([]*DocumentRef{docRef1})
+
+	doc1want := map[string]interface{}{
+		"name":        " John Doe ",
+		"description": "This is a Firestore document.",
+		"productCode": "abc-123",
+		"tags":        []interface{}{"tag1", "tag2", "tag3"},
+		"email":       "john.doe@example.com",
+		"zipCode":     "12345",
+		"csv":         "a,b,c",
+	}
+
+	tests := []struct {
+		name     string
+		pipeline *Pipeline
+		want     interface{}
+	}{
+		{
+			name:     "ByteLength",
+			pipeline: client.Pipeline().Collection(coll.ID).Select(ByteLength("name").As("byte_length")),
+			want:     map[string]interface{}{"byte_length": int64(12)},
+		},
+		{
+			name:     "CharLength",
+			pipeline: client.Pipeline().Collection(coll.ID).Select(CharLength("name").As("char_length")),
+			want:     map[string]interface{}{"char_length": int64(12)},
+		},
+		{
+			name:     "StringConcat",
+			pipeline: client.Pipeline().Collection(coll.ID).Select(StringConcat(FieldOf("name"), " - ", FieldOf("productCode")).As("concatenated_string")),
+			want:     map[string]interface{}{"concatenated_string": " John Doe - abc-123"},
+		},
+		{
+			name:     "StringReverse",
+			pipeline: client.Pipeline().Collection(coll.ID).Select(StringReverse("name").As("reversed_string")),
+			want:     map[string]interface{}{"reversed_string": " eoD nhoJ "},
+		},
+		{
+			name:     "Join",
+			pipeline: client.Pipeline().Collection(coll.ID).Select(Join("tags", ", ").As("joined_string")),
+			want:     map[string]interface{}{"joined_string": "tag1, tag2, tag3"},
+		},
+		{
+			name:     "Substring",
+			pipeline: client.Pipeline().Collection(coll.ID).Select(Substring("description", 0, 4).As("substring")),
+			want:     map[string]interface{}{"substring": "This"},
+		},
+		{
+			name:     "ToLower",
+			pipeline: client.Pipeline().Collection(coll.ID).Select(ToLower("name").As("lowercase_name")),
+			want:     map[string]interface{}{"lowercase_name": " john doe "},
+		},
+		{
+			name:     "ToUpper",
+			pipeline: client.Pipeline().Collection(coll.ID).Select(ToUpper("name").As("uppercase_name")),
+			want:     map[string]interface{}{"uppercase_name": " JOHN DOE "},
+		},
+		{
+			name: "Trim",
+			pipeline: 
client.Pipeline().Collection(coll.ID).Select(Trim("name").As("trimmed_name")), + want: map[string]interface{}{"trimmed_name": "John Doe"}, + }, + { + name: "Split", + pipeline: client.Pipeline().Collection(coll.ID).Select(Split("csv", ",").As("split_string")), + want: map[string]interface{}{"split_string": []interface{}{"a", "b", "c"}}, + }, + // String filter conditions + { + name: "Like", + pipeline: client.Pipeline().Collection(coll.ID).Where(Like("name", "%John%")), + want: []map[string]interface{}{doc1want}, + }, + { + name: "StartsWith", + pipeline: client.Pipeline().Collection(coll.ID).Where(StartsWith("name", " John")), + want: []map[string]interface{}{doc1want}, + }, + { + name: "EndsWith", + pipeline: client.Pipeline().Collection(coll.ID).Where(EndsWith("name", "Doe ")), + want: []map[string]interface{}{doc1want}, + }, + { + name: "RegexContains", + pipeline: client.Pipeline().Collection(coll.ID).Where(RegexContains("email", "@example\\.com")), + want: []map[string]interface{}{doc1want}, + }, + { + name: "RegexMatch", + pipeline: client.Pipeline().Collection(coll.ID).Where(RegexMatch("zipCode", "^[0-9]{5}$")), + want: []map[string]interface{}{doc1want}, + }, + { + name: "StringContains", + pipeline: client.Pipeline().Collection(coll.ID).Where(StringContains("description", "Firestore")), + want: []map[string]interface{}{doc1want}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx := context.Background() + + iter := test.pipeline.Execute(ctx).Results() + defer iter.Stop() + + docs, err := iter.GetAll() + if err != nil { + t.Fatalf("GetAll: %v", err) + return + } + lastStage := test.pipeline.stages[len(test.pipeline.stages)-1] + lastStageName := lastStage.name() + + if lastStageName == stageNameSelect { // This is a select query + want, ok := test.want.(map[string]interface{}) + if !ok { + t.Fatalf("invalid test.want type for select query: %T", test.want) + return + } + if len(docs) != 1 { + t.Fatalf("expected 1 doc, got %d", len(docs)) + return + } + got := docs[0].Data() + if diff := testutil.Diff(got, want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", got, want, diff) + } + } else if lastStageName == stageNameWhere { // This is a where query (filter condition) + want, ok := test.want.([]map[string]interface{}) + if !ok { + t.Fatalf("invalid test.want type for where query: %T", test.want) + return + } + if len(docs) != len(want) { + t.Fatalf("expected %d doc(s), got %d", len(want), len(docs)) + return + } + var gots []map[string]interface{} + for _, doc := range docs { + got := doc.Data() + gots = append(gots, got) + } + if diff := testutil.Diff(gots, want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", gots, want, diff) + } + } else { + t.Fatalf("unknown pipeline stage: %s", lastStageName) + return + } + }) + } + +} + +func vectorFuncs(t *testing.T) { + t.Parallel() + h := testHelper{t} + client := integrationClient(t) + coll := client.Collection(collectionIDs.New()) + docRef1 := coll.NewDoc() + h.mustCreate(docRef1, map[string]interface{}{ + "v1": Vector64{1.0, 2.0, 3.0}, + "v2": Vector64{4.0, 5.0, 6.0}, + }) + defer deleteDocuments([]*DocumentRef{docRef1}) + + tests := []struct { + name string + pipeline *Pipeline + want map[string]interface{} + }{ + { + name: "VectorLength", + pipeline: client.Pipeline().Collection(coll.ID).Select(VectorLength("v1").As("length")), + want: map[string]interface{}{"length": int64(3)}, + }, + { + name: "DotProduct - field and field", + pipeline: 
client.Pipeline().Collection(coll.ID).Select(DotProduct("v1", FieldOf("v2")).As("dot_product")), + want: map[string]interface{}{"dot_product": float64(1*4 + 2*5 + 3*6)}, + }, + { + name: "DotProduct - field and constant", + pipeline: client.Pipeline().Collection(coll.ID).Select(DotProduct("v1", Vector64{4.0, 5.0, 6.0}).As("dot_product")), + want: map[string]interface{}{"dot_product": float64(1*4 + 2*5 + 3*6)}, + }, + { + name: "EuclideanDistance - field and field", + pipeline: client.Pipeline().Collection(coll.ID).Select(EuclideanDistance("v1", FieldOf("v2")).As("euclidean")), + want: map[string]interface{}{"euclidean": math.Sqrt(math.Pow(4-1, 2) + math.Pow(5-2, 2) + math.Pow(6-3, 2))}, + }, + { + name: "EuclideanDistance - field and constant", + pipeline: client.Pipeline().Collection(coll.ID).Select(EuclideanDistance("v1", Vector64{4.0, 5.0, 6.0}).As("euclidean")), + want: map[string]interface{}{"euclidean": math.Sqrt(math.Pow(4-1, 2) + math.Pow(5-2, 2) + math.Pow(6-3, 2))}, + }, + { + name: "CosineDistance - field and field", + pipeline: client.Pipeline().Collection(coll.ID).Select(CosineDistance("v1", FieldOf("v2")).As("cosine")), + want: map[string]interface{}{"cosine": 1 - (32 / (math.Sqrt(14) * math.Sqrt(77)))}, + }, + { + name: "CosineDistance - field and constant", + pipeline: client.Pipeline().Collection(coll.ID).Select(CosineDistance("v1", Vector64{4.0, 5.0, 6.0}).As("cosine")), + want: map[string]interface{}{"cosine": 1 - (32 / (math.Sqrt(14) * math.Sqrt(77)))}, + }, + } + + ctx := context.Background() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + iter := test.pipeline.Execute(ctx).Results() + defer iter.Stop() + + docs, err := iter.GetAll() + if err != nil { + t.Fatalf("GetAll: %v", err) + return + } + if len(docs) != 1 { + t.Fatalf("expected 1 doc, got %d", len(docs)) + return + } + got := docs[0].Data() + if diff := testutil.Diff(got, test.want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", got, test.want, diff) + } + }) + } +} + +func timestampFuncs(t *testing.T) { + t.Parallel() + client := integrationClient(t) + coll := client.Collection(collectionIDs.New()) + h := testHelper{t} + now := time.Now() + docRef1 := coll.NewDoc() + h.mustCreate(docRef1, map[string]interface{}{ + "timestamp": now, + "unixMicros": now.UnixNano() / 1000, + "unixMillis": now.UnixNano() / 1e6, + "unixSeconds": now.Unix(), + }) + defer deleteDocuments([]*DocumentRef{docRef1}) + + tests := []struct { + name string + pipeline *Pipeline + want map[string]interface{} + }{ + { + name: "TimestampAdd day", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(TimestampAdd("timestamp", "day", 1).As("timestamp_plus_day")), + want: map[string]interface{}{"timestamp_plus_day": now.AddDate(0, 0, 1).Truncate(time.Microsecond)}, + }, + { + name: "TimestampAdd hour", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(TimestampAdd("timestamp", "hour", 1).As("timestamp_plus_hour")), + want: map[string]interface{}{"timestamp_plus_hour": now.Add(time.Hour).Truncate(time.Microsecond)}, + }, + { + name: "TimestampAdd minute", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(TimestampAdd("timestamp", "minute", 1).As("timestamp_plus_minute")), + want: map[string]interface{}{"timestamp_plus_minute": now.Add(time.Minute).Truncate(time.Microsecond)}, + }, + { + name: "TimestampAdd second", + pipeline: client.Pipeline(). + Collection(coll.ID). 
+ Select(TimestampAdd("timestamp", "second", 1).As("timestamp_plus_second")), + want: map[string]interface{}{"timestamp_plus_second": now.Add(time.Second).Truncate(time.Microsecond)}, + }, + { + name: "TimestampSubtract", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(TimestampSubtract("timestamp", "hour", 1).As("timestamp_minus_hour")), + want: map[string]interface{}{"timestamp_minus_hour": now.Add(-time.Hour).Truncate(time.Microsecond)}, + }, + { + name: "TimestampToUnixMicros", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(FieldOf("timestamp").TimestampToUnixMicros().As("timestamp_micros")), + want: map[string]interface{}{"timestamp_micros": now.UnixNano() / 1000}, + }, + { + name: "TimestampToUnixMillis", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(FieldOf("timestamp").TimestampToUnixMillis().As("timestamp_millis")), + want: map[string]interface{}{"timestamp_millis": now.UnixNano() / 1e6}, + }, + { + name: "TimestampToUnixSeconds", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(FieldOf("timestamp").TimestampToUnixSeconds().As("timestamp_seconds")), + want: map[string]interface{}{"timestamp_seconds": now.Unix()}, + }, + { + name: "UnixMicrosToTimestamp - constant", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(UnixMicrosToTimestamp(ConstantOf(now.UnixNano() / 1000)).As("timestamp_from_micros")), + want: map[string]interface{}{"timestamp_from_micros": now.Truncate(time.Microsecond)}, + }, + { + name: "UnixMicrosToTimestamp - fieldname", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(UnixMicrosToTimestamp("unixMicros").As("timestamp_from_micros")), + want: map[string]interface{}{"timestamp_from_micros": now.Truncate(time.Microsecond)}, + }, + { + name: "UnixMillisToTimestamp", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(UnixMillisToTimestamp(ConstantOf(now.UnixNano() / 1e6)).As("timestamp_from_millis")), + want: map[string]interface{}{"timestamp_from_millis": now.Truncate(time.Millisecond)}, + }, + { + name: "UnixSecondsToTimestamp", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(UnixSecondsToTimestamp("unixSeconds").As("timestamp_from_seconds")), + want: map[string]interface{}{"timestamp_from_seconds": now.Truncate(time.Second)}, + }, + { + name: "CurrentTimestamp", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(CurrentTimestamp().As("current_timestamp")), + want: map[string]interface{}{"current_timestamp": time.Now().Truncate(time.Microsecond)}, + }, + { + name: "TimestampTruncate day", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(TimestampTruncate("timestamp", "day").As("timestamp_trunc_day")), + want: map[string]interface{}{"timestamp_trunc_day": time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location()).Truncate(time.Microsecond)}, + }, + { + name: "TimestampTruncate hour", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(TimestampTruncate("timestamp", "hour").As("timestamp_trunc_hour")), + want: map[string]interface{}{"timestamp_trunc_hour": time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), 0, 0, 0, now.Location()).Truncate(time.Microsecond)}, + }, + { + name: "TimestampTruncate minute", + pipeline: client.Pipeline(). + Collection(coll.ID). 
+ Select(TimestampTruncate("timestamp", "minute").As("timestamp_trunc_minute")), + want: map[string]interface{}{"timestamp_trunc_minute": time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), now.Minute(), 0, 0, now.Location()).Truncate(time.Microsecond)}, + }, + { + name: "TimestampTruncate second", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(TimestampTruncate("timestamp", "second").As("timestamp_trunc_second")), + want: map[string]interface{}{"timestamp_trunc_second": time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), now.Minute(), now.Second(), 0, now.Location()).Truncate(time.Microsecond)}, + }, + { + name: "TimestampTruncateWithTimezone day", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(TimestampTruncateWithTimezone("timestamp", "day", "America/New_York").As("timestamp_trunc_day_ny")), + want: map[string]interface{}{"timestamp_trunc_day_ny": func() time.Time { + loc, _ := time.LoadLocation("America/New_York") + nowInLoc := now.In(loc) + return time.Date(nowInLoc.Year(), nowInLoc.Month(), nowInLoc.Day(), 0, 0, 0, 0, loc).Truncate(time.Microsecond) + }()}, + }, + } + + ctx := context.Background() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + iter := test.pipeline.Execute(ctx).Results() + defer iter.Stop() + + docs, err := iter.GetAll() + if err != nil { + t.Fatalf("GetAll: %v", err) + } + if len(docs) != 1 { + t.Fatalf("expected 1 doc, got %d", len(docs)) + } + got := docs[0].Data() + margin := 0 * time.Microsecond + if test.name == "CurrentTimestamp" { + margin = 5 * time.Second + } + if diff := testutil.Diff(got, test.want, cmpopts.EquateApproxTime(margin)); diff != "" { + t.Errorf("got: %v, want: %v, diff: %s", got, test.want, diff) + } + }) + } +} + +func arithmeticFuncs(t *testing.T) { + t.Parallel() + h := testHelper{t} + client := integrationClient(t) + coll := client.Collection(collectionIDs.New()) + docRef1 := coll.NewDoc() + h.mustCreate(docRef1, map[string]interface{}{ + "a": int(1), + "b": int(2), + "c": -3, + "d": 4.5, + "e": -5.5, + }) + defer deleteDocuments([]*DocumentRef{docRef1}) + + tests := []struct { + name string + pipeline *Pipeline + want map[string]interface{} + }{ + { + name: "Add - left FieldOf, right FieldOf", + pipeline: client.Pipeline().Collection(coll.ID).Select(Add(FieldOf("a"), FieldOf("b")).As("add")), + want: map[string]interface{}{"add": int64(3)}, + }, + { + name: "Add - left FieldOf, right ConstantOf", + pipeline: client.Pipeline().Collection(coll.ID).Select(Add(FieldOf("a"), ConstantOf(2)).As("add")), + want: map[string]interface{}{"add": int64(3)}, + }, + { + name: "Add - left FieldOf, right constant", + pipeline: client.Pipeline().Collection(coll.ID).Select(Add(FieldOf("a"), 5).As("add")), + want: map[string]interface{}{"add": int64(6)}, + }, + { + name: "Add - left fieldname, right constant", + pipeline: client.Pipeline().Collection(coll.ID).Select(Add("a", 5).As("add")), + want: map[string]interface{}{"add": int64(6)}, + }, + { + name: "Add - left fieldpath, right constant", + pipeline: client.Pipeline().Collection(coll.ID).Select(Add(FieldPath([]string{"a"}), 5).As("add")), + want: map[string]interface{}{"add": int64(6)}, + }, + { + name: "Add - left fieldpath, right expression", + pipeline: client.Pipeline().Collection(coll.ID).Select(Add(FieldPath([]string{"a"}), Add(FieldOf("b"), FieldOf("d"))).As("add")), + want: map[string]interface{}{"add": float64(7.5)}, + }, + { + name: "Subtract", + pipeline: client.Pipeline().Collection(coll.ID).Select(Subtract("a", 
FieldOf("b")).As("subtract")), + want: map[string]interface{}{"subtract": int64(-1)}, + }, + { + name: "Multiply", + pipeline: client.Pipeline().Collection(coll.ID).Select(Multiply("a", 5).As("multiply")), + want: map[string]interface{}{"multiply": int64(5)}, + }, + { + name: "Divide", + pipeline: client.Pipeline().Collection(coll.ID).Select(Divide("a", FieldOf("d")).As("divide")), + want: map[string]interface{}{"divide": float64(1 / 4.5)}, + }, + { + name: "Mod", + pipeline: client.Pipeline().Collection(coll.ID).Select(Mod("a", FieldOf("b")).As("mod")), + want: map[string]interface{}{"mod": int64(1)}, + }, + { + name: "Pow", + pipeline: client.Pipeline().Collection(coll.ID).Select(Pow("a", FieldOf("b")).As("pow")), + want: map[string]interface{}{"pow": float64(1)}, + }, + { + name: "Abs - fieldname", + pipeline: client.Pipeline().Collection(coll.ID).Select(Abs("c").As("abs")), + want: map[string]interface{}{"abs": int64(3)}, + }, + { + name: "Abs - fieldPath", + pipeline: client.Pipeline().Collection(coll.ID).Select(Abs(FieldPath([]string{"c"})).As("abs")), + want: map[string]interface{}{"abs": int64(3)}, + }, + { + name: "Abs - Expr", + pipeline: client.Pipeline().Collection(coll.ID).Select(Abs(Add(FieldOf("b"), FieldOf("d"))).As("abs")), + want: map[string]interface{}{"abs": float64(6.5)}, + }, + { + name: "Ceil", + pipeline: client.Pipeline().Collection(coll.ID).Select(Ceil("d").As("ceil")), + want: map[string]interface{}{"ceil": float64(5)}, + }, + { + name: "Floor", + pipeline: client.Pipeline().Collection(coll.ID).Select(Floor("d").As("floor")), + want: map[string]interface{}{"floor": float64(4)}, + }, + { + name: "Round", + pipeline: client.Pipeline().Collection(coll.ID).Select(Round("d").As("round")), + want: map[string]interface{}{"round": float64(5)}, + }, + { + name: "Sqrt", + pipeline: client.Pipeline().Collection(coll.ID).Select(Sqrt("d").As("sqrt")), + want: map[string]interface{}{"sqrt": math.Sqrt(4.5)}, + }, + { + name: "Log", + pipeline: client.Pipeline().Collection(coll.ID).Select(Log("d", 2).As("log")), + want: map[string]interface{}{"log": math.Log2(4.5)}, + }, + { + name: "Log10", + pipeline: client.Pipeline().Collection(coll.ID).Select(Log10("d").As("log10")), + want: map[string]interface{}{"log10": math.Log10(4.5)}, + }, + { + name: "Ln", + pipeline: client.Pipeline().Collection(coll.ID).Select(Ln("d").As("ln")), + want: map[string]interface{}{"ln": math.Log(4.5)}, + }, + { + name: "Exp", + pipeline: client.Pipeline().Collection(coll.ID).Select(Exp("d").As("exp")), + want: map[string]interface{}{"exp": math.Exp(4.5)}, + }, + } + + ctx := context.Background() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + iter := test.pipeline.Execute(ctx).Results() + defer iter.Stop() + + docs, err := iter.GetAll() + if err != nil { + t.Fatalf("GetAll: %v", err) + } + if len(docs) != 1 { + t.Fatalf("expected 1 doc, got %d", len(docs)) + } + got := docs[0].Data() + if diff := testutil.Diff(got, test.want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", got, test.want, diff) + } + }) + } +} + +func aggregateFuncs(t *testing.T) { + t.Parallel() + h := testHelper{t} + client := integrationClient(t) + coll := client.Collection(collectionIDs.New()) + docRef1 := coll.NewDoc() + h.mustCreate(docRef1, map[string]interface{}{ + "a": 1, + }) + docRef2 := coll.NewDoc() + h.mustCreate(docRef2, map[string]interface{}{ + "a": 2, + }) + docRef3 := coll.NewDoc() + h.mustCreate(docRef3, map[string]interface{}{ + "b": 2, + }) + defer 
deleteDocuments([]*DocumentRef{docRef1, docRef2, docRef3}) + + tests := []struct { + name string + pipeline *Pipeline + want map[string]interface{} + }{ + { + name: "Sum - fieldname arg", + pipeline: client.Pipeline(). + Collection(coll.ID). + Aggregate(Sum("a").As("sum_a")), + want: map[string]interface{}{"sum_a": int64(3)}, + }, + { + name: "Sum - fieldpath arg", + pipeline: client.Pipeline(). + Collection(coll.ID). + Aggregate(Sum(FieldPath([]string{"a"})).As("sum_a")), + want: map[string]interface{}{"sum_a": int64(3)}, + }, + { + name: "Sum - FieldOf Expr", + pipeline: client.Pipeline(). + Collection(coll.ID). + Aggregate(Sum(FieldOf("a")).As("sum_a")), + want: map[string]interface{}{"sum_a": int64(3)}, + }, + { + name: "Sum - FieldOf Path Expr", + pipeline: client.Pipeline(). + Collection(coll.ID). + Aggregate(Sum(FieldOf(FieldPath([]string{"a"}))).As("sum_a")), + want: map[string]interface{}{"sum_a": int64(3)}, + }, + { + name: "Avg", + pipeline: client.Pipeline(). + Collection(coll.ID). + Aggregate(Average("a").As("avg_a")), + want: map[string]interface{}{"avg_a": float64(1.5)}, + }, + { + name: "Count", + pipeline: client.Pipeline(). + Collection(coll.ID). + Aggregate(Count("a").As("count_a")), + want: map[string]interface{}{"count_a": int64(2)}, + }, + { + name: "CountAll", + pipeline: client.Pipeline(). + Collection(coll.ID). + Aggregate(CountAll().As("count_all")), + want: map[string]interface{}{"count_all": int64(3)}, + }, + } + + ctx := context.Background() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + iter := test.pipeline.Execute(ctx).Results() + defer iter.Stop() + + docs, err := iter.GetAll() + if err != nil { + t.Fatalf("GetAll: %v", err) + } + if len(docs) != 1 { + t.Fatalf("expected 1 doc, got %d", len(docs)) + } + got := docs[0].Data() + if diff := testutil.Diff(got, test.want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", got, test.want, diff) + } + }) + } +} + +func comparisonFuncs(t *testing.T) { + t.Parallel() + ctx := context.Background() + client := integrationClient(t) + now := time.Now() + coll := client.Collection(collectionIDs.New()) + doc1data := map[string]interface{}{ + "timestamp": now, + "a": 1, + "b": 2, + "c": -3, + "d": 4.5, + "e": -5.5, + } + _, err := coll.Doc("doc1").Create(ctx, doc1data) + if err != nil { + t.Fatalf("Create: %v", err) + } + doc2data := map[string]interface{}{ + "timestamp": now, + "a": 2, + "b": 2, + "c": -3, + "d": 4.5, + "e": -5.5, + } + _, err = coll.Doc("doc2").Create(ctx, doc2data) + if err != nil { + t.Fatalf("Create: %v", err) + } + defer deleteDocuments([]*DocumentRef{coll.Doc("doc1"), coll.Doc("doc2")}) + + doc1want := map[string]interface{}{"a": int64(1), "b": int64(2), "c": int64(-3), "d": float64(4.5), "e": float64(-5.5), "timestamp": now.Truncate(time.Microsecond)} + + tests := []struct { + name string + pipeline *Pipeline + want []map[string]interface{} + }{ + { + name: "Equal", + pipeline: client.Pipeline(). + Collection(coll.ID). + Where(Equal("a", 1)), + want: []map[string]interface{}{doc1want}, + }, + { + name: "NotEqual", + pipeline: client.Pipeline(). + Collection(coll.ID). + Where(NotEqual("a", 2)), + want: []map[string]interface{}{doc1want}, + }, + { + name: "LessThan", + pipeline: client.Pipeline(). + Collection(coll.ID). 
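+				// Only doc1 (a == 1) satisfies a < 2; doc2 has a == 2.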
+ Where(LessThan("a", 2)), + want: []map[string]interface{}{doc1want}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + iter := test.pipeline.Execute(ctx).Results() + defer iter.Stop() + + docs, err := iter.GetAll() + if err != nil { + t.Fatalf("GetAll: %v", err) + } + if len(docs) != len(test.want) { + t.Fatalf("expected %d doc(s), got %d", len(test.want), len(docs)) + } + + var gots []map[string]interface{} + for _, doc := range docs { + got := doc.Data() + if ts, ok := got["timestamp"].(time.Time); ok { + got["timestamp"] = ts.Truncate(time.Microsecond) + } + gots = append(gots, got) + } + + if diff := testutil.Diff(gots, test.want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", gots, test.want, diff) + } + }) + } +} + +func keyFuncs(t *testing.T) { + t.Parallel() + h := testHelper{t} + client := integrationClient(t) + coll := client.Collection(collectionIDs.New()) + docRef1 := coll.Doc("doc1") + h.mustCreate(docRef1, map[string]interface{}{ + "a": "hello", + "b": "world", + }) + defer deleteDocuments([]*DocumentRef{docRef1}) + + tests := []struct { + name string + pipeline *Pipeline + want map[string]interface{} + }{ + { + name: "CollectionId", + pipeline: client.Pipeline().Collection(coll.ID).Select(GetCollectionID("__name__").As("collectionId")), + want: map[string]interface{}{"collectionId": coll.ID}, + }, + { + name: "DocumentId", + pipeline: client.Pipeline().Collection(coll.ID).Select(GetDocumentID(docRef1).As("documentId")), + want: map[string]interface{}{"documentId": "doc1"}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx := context.Background() + iter := test.pipeline.Execute(ctx).Results() + defer iter.Stop() + + docs, err := iter.GetAll() + if err != nil { + t.Fatalf("GetAll: %v", err) + return + } + if len(docs) != 1 { + t.Fatalf("expected 1 doc, got %d", len(docs)) + } + got := docs[0].Data() + if diff := testutil.Diff(got, test.want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", got, test.want, diff) + } + }) + } +} + +func generalFuncs(t *testing.T) { + t.Parallel() + h := testHelper{t} + client := integrationClient(t) + coll := client.Collection(collectionIDs.New()) + docRef1 := coll.NewDoc() + h.mustCreate(docRef1, map[string]interface{}{ + "a": "hello", + "b": "world", + }) + defer deleteDocuments([]*DocumentRef{docRef1}) + + tests := []struct { + name string + pipeline *Pipeline + want map[string]interface{} + }{ + { + name: "Length - string literal", + pipeline: client.Pipeline().Collection(coll.ID).Select(Length(ConstantOf("hello")).As("len")), + want: map[string]interface{}{"len": int64(5)}, + }, + { + name: "Length - field", + pipeline: client.Pipeline().Collection(coll.ID).Select(Length("a").As("len")), + want: map[string]interface{}{"len": int64(5)}, + }, + { + name: "Length - field path", + pipeline: client.Pipeline().Collection(coll.ID).Select(Length(FieldPath{"a"}).As("len")), + want: map[string]interface{}{"len": int64(5)}, + }, + { + name: "Reverse - string literal", + pipeline: client.Pipeline().Collection(coll.ID).Select(Reverse(ConstantOf("hello")).As("reverse")), + want: map[string]interface{}{"reverse": "olleh"}, + }, + { + name: "Reverse - field", + pipeline: client.Pipeline().Collection(coll.ID).Select(Reverse("a").As("reverse")), + want: map[string]interface{}{"reverse": "olleh"}, + }, + { + name: "Reverse - field path", + pipeline: client.Pipeline().Collection(coll.ID).Select(Reverse(FieldPath{"a"}).As("reverse")), + want: 
map[string]interface{}{"reverse": "olleh"}, + }, + { + name: "Concat - two literals", + pipeline: client.Pipeline().Collection(coll.ID).Select(Concat(ConstantOf("hello"), ConstantOf("world")).As("concat")), + want: map[string]interface{}{"concat": "helloworld"}, + }, + { + name: "Concat - literal and field", + pipeline: client.Pipeline().Collection(coll.ID).Select(Concat(ConstantOf("hello"), FieldOf("b")).As("concat")), + want: map[string]interface{}{"concat": "helloworld"}, + }, + { + name: "Concat - two fields", + pipeline: client.Pipeline().Collection(coll.ID).Select(Concat(FieldOf("a"), FieldOf("b")).As("concat")), + want: map[string]interface{}{"concat": "helloworld"}, + }, + { + name: "Concat - field and literal", + pipeline: client.Pipeline().Collection(coll.ID).Select(Concat(FieldOf("a"), ConstantOf("world")).As("concat")), + want: map[string]interface{}{"concat": "helloworld"}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx := context.Background() + iter := test.pipeline.Execute(ctx).Results() + defer iter.Stop() + + docs, err := iter.GetAll() + if err != nil { + t.Fatalf("GetAll: %v", err) + return + } + if len(docs) != 1 { + t.Fatalf("expected 1 doc, got %d", len(docs)) + } + got := docs[0].Data() + if diff := testutil.Diff(got, test.want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", got, test.want, diff) + } + }) + } +} + +func logicalFuncs(t *testing.T) { + t.Parallel() + h := testHelper{t} + client := integrationClient(t) + coll := client.Collection(collectionIDs.New()) + docRef1 := coll.Doc("doc1") + doc1Data := map[string]interface{}{ + "a": 1, + "b": 2, + "c": nil, + "d": true, + "e": false, + } + h.mustCreate(docRef1, doc1Data) + + docRef2 := coll.Doc("doc2") + doc2Data := map[string]interface{}{ + "a": 1, + "b": 1, + "d": true, + "e": true, + } + h.mustCreate(docRef2, doc2Data) + defer deleteDocuments([]*DocumentRef{docRef1, docRef2}) + + doc1Want := map[string]interface{}{ + "a": int64(1), + "b": int64(2), + "c": nil, + "d": true, + "e": false, + } + doc2Want := map[string]interface{}{ + "a": int64(1), + "b": int64(1), + "d": true, + "e": true, + } + + tests := []struct { + name string + pipeline *Pipeline + want interface{} + }{ + { + name: "Conditional - true", + pipeline: client.Pipeline().Collection(coll.ID).Select(Conditional(Equal(ConstantOf(1), ConstantOf(1)), FieldOf("a"), FieldOf("b")).As("result")), + want: []map[string]interface{}{{"result": int64(1)}, {"result": int64(1)}}, + }, + { + name: "Conditional - false", + pipeline: client.Pipeline().Collection(coll.ID).Select(Conditional(Equal(ConstantOf(1), ConstantOf(0)), FieldOf("a"), FieldOf("b")).As("result")), + want: []map[string]interface{}{{"result": int64(2)}, {"result": int64(1)}}, + }, + { + name: "Conditional - field true", + pipeline: client.Pipeline().Collection(coll.ID).Select(Conditional(Equal(FieldOf("d"), ConstantOf(true)), FieldOf("a"), FieldOf("b")).As("result")), + want: []map[string]interface{}{{"result": int64(1)}, {"result": int64(1)}}, + }, + { + name: "Conditional - field false", + pipeline: client.Pipeline().Collection(coll.ID).Select(Conditional(Equal(FieldOf("e"), ConstantOf(true)), FieldOf("a"), FieldOf("b")).As("result")), + want: []map[string]interface{}{{"result": int64(2)}, {"result": int64(1)}}, + }, + { + name: "LogicalMax", + pipeline: client.Pipeline().Collection(coll.ID).Select(LogicalMaximum(FieldOf("a"), FieldOf("b")).As("max")), + want: []map[string]interface{}{{"max": int64(2)}, {"max": int64(1)}}, + }, + { + 
name: "LogicalMin", + pipeline: client.Pipeline().Collection(coll.ID).Select(LogicalMinimum(FieldOf("a"), FieldOf("b")).As("min")), + want: []map[string]interface{}{{"min": int64(1)}, {"min": int64(1)}}, + }, + { + name: "IfError - no error", + pipeline: client.Pipeline().Collection(coll.ID).Select(IfError(FieldOf("a"), ConstantOf(100)).As("result")), + want: []map[string]interface{}{{"result": int64(1)}, {"result": int64(1)}}, + }, + { + name: "IfError - error", + pipeline: client.Pipeline().Collection(coll.ID).Select(Divide("a", 0).IfError(ConstantOf("was error")).As("ifError")), + want: []map[string]interface{}{{"ifError": "was error"}, {"ifError": "was error"}}, + }, + { + name: "IfErrorBoolean - no error", + pipeline: client.Pipeline().Collection(coll.ID).Select(IfErrorBoolean(Equal(FieldOf("d"), ConstantOf(true)), Equal(ConstantOf(1), ConstantOf(0))).As("result")), + want: []map[string]interface{}{{"result": true}, {"result": true}}, + }, + { + name: "IfErrorBoolean - error", + pipeline: client.Pipeline().Collection(coll.ID).Select(IfErrorBoolean(Equal(FieldOf("x"), ConstantOf(true)), Equal(ConstantOf(1), ConstantOf(0))).As("result")), + want: []map[string]interface{}{{"result": false}, {"result": false}}, + }, + { + name: "IfAbsent - not absent", + pipeline: client.Pipeline().Collection(coll.ID).Select(IfAbsent(FieldOf("a"), ConstantOf(100)).As("result")), + want: []map[string]interface{}{{"result": int64(1)}, {"result": int64(1)}}, + }, + { + name: "IfAbsent - absent", + pipeline: client.Pipeline().Collection(coll.ID).Select(IfAbsent(FieldOf("x"), ConstantOf(100)).As("result")), + want: []map[string]interface{}{{"result": int64(100)}, {"result": int64(100)}}, + }, + { + name: "And", + pipeline: client.Pipeline().Collection(coll.ID).Where( + And( + Equal(FieldOf("a"), 1), + Equal(FieldOf("b"), 2), + ), + ), + want: []map[string]interface{}{doc1Want}, + }, + { + name: "Or", + pipeline: client.Pipeline().Collection(coll.ID).Where( + Or( + Equal(FieldOf("b"), 2), + Equal(FieldOf("e"), true), + ), + ), + want: []map[string]interface{}{doc1Want, doc2Want}, + }, + { + name: "Not", + pipeline: client.Pipeline().Collection(coll.ID).Where( + Not(Equal(FieldOf("b"), 1)), + ), + want: []map[string]interface{}{doc1Want}, + }, + { + name: "Xor", + pipeline: client.Pipeline().Collection(coll.ID).Where( + Xor( + Equal(FieldOf("d"), true), + Equal(FieldOf("e"), true), + ), + ), + want: []map[string]interface{}{doc1Want}, + }, + { + name: "FieldExists", + pipeline: client.Pipeline().Collection(coll.ID).Where(FieldExists("c")), + want: []map[string]interface{}{doc1Want}, + }, + { + name: "IsError", + pipeline: client.Pipeline().Collection(coll.ID).Where(IsError(Divide("a", 0))), + want: []map[string]interface{}{doc1Want, doc2Want}, + }, + { + name: "IsAbsent", + pipeline: client.Pipeline().Collection(coll.ID).Where(IsAbsent("c")), + want: []map[string]interface{}{doc2Want}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx := context.Background() + iter := test.pipeline.Execute(ctx).Results() + defer iter.Stop() + + docs, err := iter.GetAll() + if err != nil { + t.Fatalf("GetAll: %v", err) + return + } + + lastStage := test.pipeline.stages[len(test.pipeline.stages)-1] + lastStageName := lastStage.name() + + if lastStageName == stageNameSelect { // This is a select query + want, ok := test.want.([]map[string]interface{}) + if !ok { + t.Fatalf("invalid test.want type for select query: %T", test.want) + return + } + if len(docs) != len(want) { + t.Fatalf("expected 
%d doc(s), got %d", len(want), len(docs)) + return + } + var gots []map[string]interface{} + for _, doc := range docs { + gots = append(gots, doc.Data()) + } + if diff := testutil.Diff(gots, want, cmpopts.SortSlices(func(a, b map[string]interface{}) bool { + // A stable sort for the results. + // Try to sort by "result", "max", "min", "ifError" + if v1, ok := a["result"]; ok { + v2 := b["result"] + switch v1 := v1.(type) { + case int64: + return v1 < v2.(int64) + case bool: + return !v1 && v2.(bool) + } + } + if v1, ok := a["max"]; ok { + return v1.(int64) < b["max"].(int64) + } + if v1, ok := a["min"]; ok { + return v1.(int64) < b["min"].(int64) + } + if v1, ok := a["ifError"]; ok { + return v1.(string) < b["ifError"].(string) + } + return false + })); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", gots, want, diff) + } + } else if lastStageName == stageNameWhere { // This is a where query (filter condition) + want, ok := test.want.([]map[string]interface{}) + if !ok { + t.Fatalf("invalid test.want type for where query: %T", test.want) + return + } + if len(docs) != len(want) { + t.Fatalf("expected %d doc(s), got %d", len(want), len(docs)) + return + } + var gots []map[string]interface{} + for _, doc := range docs { + got := doc.Data() + gots = append(gots, got) + } + // Sort slices before comparing for consistent test results + sort.Slice(gots, func(i, j int) bool { + if gots[i]["a"].(int64) == gots[j]["a"].(int64) { + return gots[i]["b"].(int64) < gots[j]["b"].(int64) + } + return gots[i]["a"].(int64) < gots[j]["a"].(int64) + }) + sort.Slice(want, func(i, j int) bool { + if want[i]["a"].(int64) == want[j]["a"].(int64) { + return want[i]["b"].(int64) < want[j]["b"].(int64) + } + return want[i]["a"].(int64) < want[j]["a"].(int64) + }) + if diff := testutil.Diff(gots, want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", gots, want, diff) + } + } else { + t.Fatalf("unknown pipeline stage: %s", lastStageName) + return + } + }) + } +} + +func TestIntegration_CreateFromQuery(t *testing.T) { + skipIfNotEnterprise(t) + ctx := context.Background() + client := integrationClient(t) + coll := integrationColl(t) + h := testHelper{t} + + books := testBooks()[:3] + var docRefs []*DocumentRef + for _, b := range books { + docRef := coll.NewDoc() + h.mustCreate(docRef, b) + docRefs = append(docRefs, docRef) + } + t.Cleanup(func() { + deleteDocuments(docRefs) + }) + + q := coll.Where("rating", ">", 4.2) + p := client.Pipeline().CreateFromQuery(q) + iter := p.Execute(ctx).Results() + defer iter.Stop() + results, err := iter.GetAll() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if len(results) != 2 { + t.Errorf("got %d documents, want 2", len(results)) + } +} + +func TestIntegration_CreateFromAggregationQuery(t *testing.T) { + skipIfNotEnterprise(t) + ctx := context.Background() + client := integrationClient(t) + coll := integrationColl(t) + h := testHelper{t} + + books := testBooks()[:3] + var docRefs []*DocumentRef + for _, b := range books { + docRef := coll.NewDoc() + h.mustCreate(docRef, b) + docRefs = append(docRefs, docRef) + } + t.Cleanup(func() { + deleteDocuments(docRefs) + }) + + ag := coll.NewAggregationQuery().WithCount("count") + p := client.Pipeline().CreateFromAggregationQuery(ag) + iter := p.Execute(ctx).Results() + defer iter.Stop() + doc, err := iter.Next() + if err != nil { + t.Fatalf("Failed to iterate: %v", err) + } + if !doc.Exists() { + t.Fatalf("Exists: got: false, want: true") + } + data := doc.Data() + if data["count"] != 
int64(3) {
+		t.Errorf("got count %d, want 3", data["count"])
+	}
+}
diff --git a/firestore/pipeline_result.go b/firestore/pipeline_result.go new file mode 100644 index 000000000000..c60389655033 --- /dev/null +++ b/firestore/pipeline_result.go @@ -0,0 +1,287 @@
+// Copyright 2026 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package firestore
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"io"
+	"time"
+
+	pb "cloud.google.com/go/firestore/apiv1/firestorepb"
+	"cloud.google.com/go/internal/trace"
+	"google.golang.org/api/iterator"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
+	"google.golang.org/protobuf/types/known/timestamppb"
+)
+
+// PipelineResult is a result returned from executing a pipeline.
+type PipelineResult struct {
+	// Ref is the DocumentRef for this result. It may be nil if the result
+	// does not correspond to a specific Firestore document (e.g., an aggregation result
+	// without grouping, or a synthetic document from a stage).
+	Ref *DocumentRef
+
+	// CreateTime is the time at which the document was created.
+	// It may be nil if the result does not correspond to a specific Firestore document.
+	CreateTime *time.Time
+
+	// UpdateTime is the time at which the document was last changed.
+	// It may be nil if the result does not correspond to a specific Firestore document.
+	UpdateTime *time.Time
+
+	// ExecutionTime is the time at which the document(s) were read.
+	ExecutionTime *time.Time
+
+	c     *Client
+	proto *pb.Document
+}
+
+func newPipelineResult(ref *DocumentRef, proto *pb.Document, c *Client, executionTime *timestamppb.Timestamp) (*PipelineResult, error) {
+	pr := &PipelineResult{
+		Ref:   ref,
+		c:     c,
+		proto: proto,
+	}
+	if proto != nil {
+		if proto.GetCreateTime() != nil {
+			if err := proto.GetCreateTime().CheckValid(); err != nil {
+				return nil, err
+			}
+			createTime := proto.GetCreateTime().AsTime()
+			pr.CreateTime = &createTime
+		}
+		if proto.GetUpdateTime() != nil {
+			if err := proto.GetUpdateTime().CheckValid(); err != nil {
+				return nil, err
+			}
+			updateTime := proto.GetUpdateTime().AsTime()
+			pr.UpdateTime = &updateTime
+		}
+	}
+	if executionTime != nil {
+		if err := executionTime.CheckValid(); err != nil {
+			return nil, err
+		}
+		execTime := executionTime.AsTime()
+		pr.ExecutionTime = &execTime
+	}
+	return pr, nil
+}
+
+// Exists reports whether the PipelineResult represents a document.
+// Even if Exists returns false, the rest of the fields are valid.
+func (p *PipelineResult) Exists() bool {
+	return p.proto != nil
+}
+
+// Data returns the PipelineResult's fields as a map.
+// It is equivalent to
+//
+//	var m map[string]any
+//	p.DataTo(&m)
+func (p *PipelineResult) Data() map[string]any {
+	if p == nil || !p.Exists() {
+		return nil
+	}
+	m, err := createMapFromValueMap(p.proto.Fields, p.c)
+
+	// Any error here is a bug in the client.
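+	// createMapFromValueMap decodes values sent by the server, so a failure here
+	// means the client could not interpret the response rather than a caller error.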
+	if err != nil {
+		panic(fmt.Sprintf("firestore: %v", err))
+	}
+	return m
+}
+
+// DataTo uses the PipelineResult's fields to populate v, which can be a pointer to a
+// map[string]any or a pointer to a struct.
+// This is similar to [DocumentSnapshot.DataTo].
+func (p *PipelineResult) DataTo(v any) error {
+	if p == nil || !p.Exists() {
+		return status.Errorf(codes.NotFound, "document does not exist")
+	}
+	return setFromProtoValue(v, &pb.Value{ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{Fields: p.proto.Fields}}}, p.c)
+}
+
+// PipelineResultIterator is an iterator over PipelineResults from a pipeline execution.
+type PipelineResultIterator struct {
+	iter pipelineResultIteratorInternal
+	err  error // Stores sticky error from Next() or construction
+}
+
+// Next returns the next result. Its second return value is iterator.Done if there
+// are no more results. Once Next returns Done, all subsequent calls will return
+// Done.
+func (it *PipelineResultIterator) Next() (*PipelineResult, error) {
+	if it.err != nil {
+		return nil, it.err
+	}
+	if it.iter == nil { // Iterator was stopped or not initialized
+		return nil, iterator.Done
+	}
+
+	pr, err := it.iter.next()
+	if err != nil {
+		it.err = err // Store sticky error
+	}
+	return pr, err
+}
+
+// Stop stops the iterator, freeing its resources.
+// Always call Stop when you are done with a PipelineResultIterator.
+// It is not safe to call Stop concurrently with Next.
+func (it *PipelineResultIterator) Stop() {
+	if it.iter != nil {
+		it.iter.stop()
+	}
+	// Set a sticky error indicating the iterator is now done if not already errored.
+	if it.err == nil {
+		it.err = iterator.Done
+	}
+}
+
+// GetAll returns all the documents remaining from the iterator.
+// It is not necessary to call Stop on the iterator after calling GetAll.
+func (it *PipelineResultIterator) GetAll() ([]*PipelineResult, error) {
+	if it.err != nil {
+		return nil, it.err
+	}
+	defer it.Stop()
+
+	var results []*PipelineResult
+	for {
+		pr, err := it.Next()
+		if err == iterator.Done {
+			break
+		}
+		if err != nil {
+			return results, err
+		}
+		results = append(results, pr)
+	}
+	return results, nil
+}
+
+// pipelineResultIteratorInternal is an unexported interface defining the core iteration logic.
+type pipelineResultIteratorInternal interface {
+	next() (*PipelineResult, error)
+	stop()
+	getExplainStats() (*pb.ExplainStats, error)
+}
+
+// streamPipelineResultIterator is the concrete implementation for gRPC streaming of pipeline results.
+type streamPipelineResultIterator struct {
+	ctx                context.Context
+	cancel             func()
+	p                  *Pipeline
+	streamClient       pb.Firestore_ExecutePipelineClient
+	currResp           *pb.ExecutePipelineResponse
+	currRespResultsIdx int
+	statsPb            *pb.ExplainStats
+}
+
+// Ensure that streamPipelineResultIterator implements the pipelineResultIteratorInternal interface.
+var _ pipelineResultIteratorInternal = (*streamPipelineResultIterator)(nil)
+
+func newStreamPipelineResultIterator(ctx context.Context, p *Pipeline) *streamPipelineResultIterator {
+	ctx, cancel := context.WithCancel(ctx)
+	return &streamPipelineResultIterator{
+		ctx:    ctx,
+		cancel: cancel,
+		p:      p,
+	}
+}
+
+// Each ExecutePipelineResponse received from the Firestore service contains a list of documents.
+// Each call to next() returns a single document from the most recently buffered response.
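+// Responses that contain no results represent partial progress from the server; next() skips
+// them and keeps receiving until it buffers a response with documents or the stream ends.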
+func (it *streamPipelineResultIterator) next() (_ *PipelineResult, err error) { + client := it.p.c + + // streamClient is initialized on first next call + if it.streamClient == nil { + it.ctx = trace.StartSpan(it.ctx, "cloud.google.com/go/firestore.ExecutePipeline") + defer func() { + if errors.Is(err, iterator.Done) { + trace.EndSpan(it.ctx, nil) + } else { + trace.EndSpan(it.ctx, err) + } + }() + req, err := it.p.toExecutePipelineRequest() + if err != nil { + return nil, err + } + + ctx := withRequestParamsHeader(it.ctx, reqParamsHeaderVal(client.path())) + it.streamClient, err = client.c.ExecutePipeline(ctx, req) + if err != nil { + return nil, err + } + } + + // If the current response is nil or all its results have been processed, + // receive the next response from the stream. + if it.currResp == nil || it.currRespResultsIdx >= len(it.currResp.GetResults()) { + var res *pb.ExecutePipelineResponse + for { + res, err = it.streamClient.Recv() + if err == io.EOF { + return nil, iterator.Done + } + if err != nil { + return nil, err + } + if res.GetResults() != nil { + it.currResp = res + it.currRespResultsIdx = 0 + it.statsPb = res.GetExplainStats() + break + } + // No results => partial progress; keep receiving + } + } + + // Get the next document proto from the current response. + docProto := it.currResp.GetResults()[it.currRespResultsIdx] + it.currRespResultsIdx++ + + var docRef *DocumentRef + if len(docProto.GetName()) != 0 { + var pathErr error + docRef, pathErr = pathToDoc(docProto.GetName(), client) + if pathErr != nil { + return nil, pathErr + } + } + + pr, err := newPipelineResult(docRef, docProto, client, it.currResp.GetExecutionTime()) + if err != nil { + return nil, err + } + return pr, nil +} + +func (it *streamPipelineResultIterator) stop() { + it.cancel() +} + +func (it *streamPipelineResultIterator) getExplainStats() (*pb.ExplainStats, error) { + if it == nil { + return nil, fmt.Errorf("firestore: iterator is nil") + } + return it.statsPb, nil +} diff --git a/firestore/pipeline_result_test.go b/firestore/pipeline_result_test.go new file mode 100644 index 000000000000..37482d0b0a8c --- /dev/null +++ b/firestore/pipeline_result_test.go @@ -0,0 +1,504 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package firestore + +import ( + "context" + "errors" + "io" + "testing" + "time" + + pb "cloud.google.com/go/firestore/apiv1/firestorepb" + "cloud.google.com/go/internal/testutil" + "github.com/google/go-cmp/cmp" + "google.golang.org/api/iterator" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/timestamppb" + "google.golang.org/protobuf/types/known/wrapperspb" +) + +func TestStreamPipelineResultIterator_Next(t *testing.T) { + ctx := context.Background() + client := newTestClient() // For PipelineResult construction + p := &Pipeline{c: client} // Dummy pipeline for iterator context + + now := time.Now() + tsNow := timestamppb.New(now) + ts2MinLater := timestamppb.New(now.Add(-2 * time.Minute)) + + mockResponses := []*pb.ExecutePipelineResponse{ + { // First response with two results + Results: []*pb.Document{ + {Name: "projects/test-project/databases/test-db/documents/col/doc1", Fields: map[string]*pb.Value{"foo": {ValueType: &pb.Value_StringValue{StringValue: "bar1"}}}, CreateTime: tsNow, UpdateTime: tsNow}, + {Name: "projects/test-project/databases/test-db/documents/col/doc2", Fields: map[string]*pb.Value{"foo": {ValueType: &pb.Value_StringValue{StringValue: "bar2"}}}, CreateTime: tsNow, UpdateTime: tsNow}, + }, + ExecutionTime: tsNow, + }, + { // Second response with one result + Results: []*pb.Document{ + {Name: "projects/test-project/databases/test-db/documents/col/doc3", Fields: map[string]*pb.Value{"foo": {ValueType: &pb.Value_StringValue{StringValue: "bar3"}}}, CreateTime: tsNow, UpdateTime: tsNow}, + }, + ExecutionTime: ts2MinLater, + }, + } + + tests := []struct { + name string + responses []*pb.ExecutePipelineResponse + errors []error + gotCount int + wantErr error + wantData []map[string]any + }{ + { + name: "successful iteration", + responses: mockResponses, + errors: []error{nil, nil, io.EOF}, // EOF after 2 responses (containing 3 docs) + gotCount: 3, + wantErr: iterator.Done, + wantData: []map[string]any{ + {"foo": "bar1"}, + {"foo": "bar2"}, + {"foo": "bar3"}, + }, + }, + { + name: "iteration with gRPC error", + responses: []*pb.ExecutePipelineResponse{mockResponses[0]}, // Only first response + errors: []error{nil, status.Error(codes.Unavailable, "service unavailable")}, + gotCount: 2, // Expect results from the first response before error + wantErr: status.Error(codes.Unavailable, "service unavailable"), + wantData: []map[string]any{ + {"foo": "bar1"}, + {"foo": "bar2"}, + }, + }, + { + name: "no results", + responses: []*pb.ExecutePipelineResponse{{Results: []*pb.Document{}}}, + errors: []error{io.EOF}, + gotCount: 0, + wantErr: iterator.Done, + }, + { + name: "partial progress then results", + responses: []*pb.ExecutePipelineResponse{{ExecutionTime: tsNow /* No results */}, mockResponses[0]}, + errors: []error{nil, nil, io.EOF}, + gotCount: 2, + wantErr: iterator.Done, + wantData: []map[string]any{ + {"foo": "bar1"}, + {"foo": "bar2"}, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mockStreamClient := &mockExecutePipelineClient{ + RecvResponses: tc.responses, + RecvErrors: tc.errors, + ContextVal: ctx, + } + + iter := &streamPipelineResultIterator{ + ctx: ctx, + cancel: func() {}, + p: p, + streamClient: mockStreamClient, + } + defer iter.stop() + var results []*PipelineResult + var gotErr error + var pr *PipelineResult + for { + pr, gotErr = iter.next() + if gotErr != nil { + break + 
} + results = append(results, pr) + } + + if len(results) != tc.gotCount { + t.Errorf("results got %d, want %d", len(results), tc.gotCount) + } + + if tc.wantErr != nil { + if gotErr == nil { + t.Fatalf("error %v, got nil", tc.wantErr) + } + if !errors.Is(gotErr, tc.wantErr) && gotErr.Error() != tc.wantErr.Error() { + t.Errorf("error got %v, want %v", gotErr, tc.wantErr) + } + } else if gotErr != nil { + t.Errorf("error got %v, want %v", gotErr, nil) + } + + if tc.wantData != nil { + if len(results) != len(tc.wantData) { + t.Fatalf("Result count mismatch for data check: expected %d, got %d", len(tc.wantData), len(results)) + } + for i, pr := range results { + data := pr.Data() + if diff := cmp.Diff(tc.wantData[i], data); diff != "" { + t.Errorf("Data mismatch for result %d (-want +got):\n%s", i, diff) + } + } + } + }) + } +} + +func TestPipelineResultIterator_Stop(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + client := newTestClient() + p := &Pipeline{c: client} + + mockStreamClient := &mockExecutePipelineClient{ + ContextVal: ctx, // Iterator will use this context + } + + // Create the public iterator which wraps the stream iterator + publicIter := &PipelineResultIterator{ + iter: &streamPipelineResultIterator{ + ctx: ctx, // This context is passed to the stream client + cancel: cancel, // This cancel func should be called by Stop + p: p, + streamClient: mockStreamClient, + }, + } + + publicIter.Stop() + + // Check if the context was cancelled + select { + case <-ctx.Done(): + // Expected: context is cancelled + default: + t.Errorf("Expected context to be cancelled after Stop(), but it was not") + } + + // Calling Stop again should be a no-op + publicIter.Stop() // Should not panic or error + + // Check that Next after Stop returns iterator.Done + _, err := publicIter.Next() + if !errors.Is(err, iterator.Done) { + t.Errorf("Next after Stop(): got %v, want %v", err, iterator.Done) + } +} + +func TestPipelineResultIterator_GetAll(t *testing.T) { + ctx := context.Background() + client := newTestClient() + p := &Pipeline{c: client} + + mockStreamClient := &mockExecutePipelineClient{ + RecvResponses: []*pb.ExecutePipelineResponse{ + {Results: []*pb.Document{ + {Name: "projects/p/databases/d/documents/c/doc1", Fields: map[string]*pb.Value{"id": {ValueType: &pb.Value_IntegerValue{IntegerValue: 1}}}}, + }}, + {Results: []*pb.Document{ + {Name: "projects/p/databases/d/documents/c/doc2", Fields: map[string]*pb.Value{"id": {ValueType: &pb.Value_IntegerValue{IntegerValue: 2}}}}, + }}, + }, + RecvErrors: []error{nil, nil, io.EOF}, // EOF after two responses + ContextVal: ctx, + } + + publicIter := &PipelineResultIterator{ + iter: &streamPipelineResultIterator{ + ctx: ctx, + cancel: func() {}, + p: p, + streamClient: mockStreamClient, + }, + } + + allResults, err := publicIter.GetAll() + if err != nil { + t.Fatalf("GetAll: %v", err) + } + if len(allResults) != 2 { + t.Errorf("results from GetAll(): got %d, want: 2", len(allResults)) + } + + data := allResults[0].Data() + if data["id"].(int64) != 1 { + t.Errorf("first result id: got %v, want: 1", data["id"]) + } + + data = allResults[1].Data() + if data["id"].(int64) != 2 { + t.Errorf("second result id: got %v, want: 2", data["id"]) + } + + // After GetAll, Next should return iterator.Done + _, nextErr := publicIter.Next() + if !errors.Is(nextErr, iterator.Done) { + t.Errorf("Next after GetAll(): got %v, want: %v", nextErr, iterator.Done) + } +} + +func TestPipelineResult_DataExtraction(t *testing.T) { + client := 
newTestClient() + now := time.Now() + tsNowProto := timestamppb.New(now) + + docProto := &pb.Document{ + Name: "projects/test/databases/d/documents/mycoll/mydoc", + CreateTime: tsNowProto, + UpdateTime: tsNowProto, + Fields: map[string]*pb.Value{ + "stringProp": {ValueType: &pb.Value_StringValue{StringValue: "hello"}}, + "intProp": {ValueType: &pb.Value_IntegerValue{IntegerValue: 123}}, + "boolProp": {ValueType: &pb.Value_BooleanValue{BooleanValue: true}}, + "mapProp": { + ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + "nestedString": {ValueType: &pb.Value_StringValue{StringValue: "world"}}, + }, + }}}, + }, + } + execTimeProto := timestamppb.New(now.Add(time.Second)) + + docRef, _ := pathToDoc(docProto.Name, client) + pr, err := newPipelineResult(docRef, docProto, client, execTimeProto) + if err != nil { + t.Fatalf("newPipelineResult: %v", err) + } + + // Test Data() + dataMap := pr.Data() + if dataMap["stringProp"].(string) != "hello" { + t.Errorf("stringProp: got %v, want 'hello'", dataMap["stringProp"]) + } + + if dataMap["intProp"].(int64) != 123 { + t.Errorf("intProp: got %v, want 123", dataMap["intProp"]) + } + nestedMap, ok := dataMap["mapProp"].(map[string]any) + if !ok { + t.Fatalf("mapProp is not a map[string]any") + } + if nestedMap["nestedString"].(string) != "world" { + t.Errorf("nestedString: got %v, want 'world'", nestedMap["nestedString"]) + } + + // Test DataTo() with a struct + type MyStruct struct { + StringProp string `firestore:"stringProp"` + IntProp int `firestore:"intProp"` + BoolProp bool `firestore:"boolProp"` + MapProp map[string]any `firestore:"mapProp"` + NonExistent float64 `firestore:"nonExistent"` + } + gotDst := MyStruct{ + StringProp: "world", + IntProp: 456, + BoolProp: false, + MapProp: map[string]any{"nestedString": "hello"}, + NonExistent: 456.789, + } + + wantDst := MyStruct{ + StringProp: "hello", + IntProp: 123, + BoolProp: true, + MapProp: map[string]any{"nestedString": "world"}, + NonExistent: 456.789, + } + + if err := pr.DataTo(&gotDst); err != nil { + t.Fatalf("pr.DataTo(&gotDst): %v", err) + } + + if diff := testutil.Diff(wantDst, gotDst); diff != "" { + t.Errorf("dst mismatch (-want +got):\n%s", diff) + } + + // Test Timestamps + if pr.CreateTime == nil || !pr.CreateTime.Equal(now) { + t.Errorf("CreateTime: got %v, want %v", pr.CreateTime, now) + } + if pr.ExecutionTime == nil || !pr.ExecutionTime.Equal(now.Add(time.Second)) { + t.Errorf("ExecutionTime: got %v, want %v", pr.ExecutionTime, now.Add(time.Second)) + } +} + +func TestPipelineResult_NoResults(t *testing.T) { + client := newTestClient() + execTime := time.Now() + execTimeProto := timestamppb.New(execTime) + + pr, err := newPipelineResult(nil, nil, client, execTimeProto) // No proto document + if err != nil { + t.Fatalf("newPipelineResult: %v", err) + } + + data := pr.Data() + if data != nil { + t.Errorf("pr.Data() for non-existent result: got non-nil empty map, want nil") + } + if len(data) != 0 { + t.Errorf("pr.Data() for non-existent result: got map with %d elements, want empty map", len(data)) + } + + type MyStruct struct{ Foo string } + var s MyStruct + err = pr.DataTo(&s) + if err == nil || status.Code(err) != codes.NotFound { + t.Fatalf("pr.DataTo(&s) on non-existent result failed: %v status.Code(err): %v", err, status.Code(err)) + } + if s.Foo != "" { + t.Errorf("Struct Foo for non-existent result: got %q, want \"\"", s.Foo) + } + + if pr.ExecutionTime == nil || !pr.ExecutionTime.Equal(execTime) { + t.Errorf("ExecutionTime for 
non-existent result: got %v, want %v", pr.ExecutionTime, execTime) + } +} + +func TestPipelineResultIterator_ExplainStats(t *testing.T) { + ctx := context.Background() + client := newTestClient() + p := &Pipeline{c: client} + + // Prepare mock stats data + explainText := "Executed in 10ms" + stringValue := &wrapperspb.StringValue{Value: explainText} + anyText, err := anypb.New(stringValue) + if err != nil { + t.Fatalf("anypb.New(stringValue): %v", err) + } + statsTextPb := &pb.ExplainStats{Data: anyText} + + // For raw data test + boolValue := &wrapperspb.BoolValue{Value: true} + anyBool, err := anypb.New(boolValue) + if err != nil { + t.Fatalf("anypb.New(boolValue): %v", err) + } + statsRawPb := &pb.ExplainStats{Data: anyBool} + + t.Run("successful case with text data", func(t *testing.T) { + mockIter := &streamPipelineResultIterator{ + ctx: ctx, + cancel: func() {}, + p: p, + statsPb: statsTextPb, + } + ps := &PipelineSnapshot{&PipelineResultIterator{iter: mockIter, err: iterator.Done}} // Pre-set to done + + stats := ps.ExplainStats() + if stats.err != nil { + t.Fatalf("ExplainStats() error: %v", stats.err) + } + + text, err := stats.Text() + if err != nil { + t.Fatalf("GetText() error: %v", err) + } + if text != explainText { + t.Errorf("GetText(): got %q, want %q", text, explainText) + } + }) + + t.Run("successful case with raw data", func(t *testing.T) { + mockIter := &streamPipelineResultIterator{ + ctx: ctx, + cancel: func() {}, + p: p, + statsPb: statsRawPb, + } + ps := &PipelineSnapshot{&PipelineResultIterator{iter: mockIter, err: iterator.Done}} + + stats := ps.ExplainStats() + if stats.err != nil { + t.Fatalf("ExplainStats() error: %v", stats.err) + } + + rawData, err := stats.RawData() + if err != nil { + t.Fatalf("GetRawData() error: %v", err) + } + if !proto.Equal(rawData, anyBool) { + t.Errorf("GetRawData(): got %v, want %v", rawData, anyBool) + } + }) + + t.Run("error case - iterator not done", func(t *testing.T) { + mockIter := &streamPipelineResultIterator{} + ps := &PipelineSnapshot{&PipelineResultIterator{iter: mockIter}} // err is nil + + stats := ps.ExplainStats() + if stats.err == nil { + t.Fatal("ExplainStats() expected error, got nil") + } + if !errors.Is(stats.err, errStatsBeforeEnd) { + t.Errorf("ExplainStats() error: got %v, want %v", stats.err, errStatsBeforeEnd) + } + }) + + t.Run("error case - iterator is nil", func(t *testing.T) { + var ps *PipelineSnapshot + stats := ps.ExplainStats() + if stats.err == nil { + t.Fatal("ExplainStats() on nil iterator expected error, got nil") + } + }) + + t.Run("error case - GetText with wrong data type", func(t *testing.T) { + mockIter := &streamPipelineResultIterator{statsPb: statsRawPb} + ps := &PipelineSnapshot{&PipelineResultIterator{iter: mockIter, err: iterator.Done}} + + stats := ps.ExplainStats() + _, err := stats.Text() + if err == nil { + t.Fatal("GetText() with wrong data type expected error, got nil") + } + }) + + t.Run("no stats available", func(t *testing.T) { + mockIter := &streamPipelineResultIterator{statsPb: nil} // No stats + ps := &PipelineSnapshot{&PipelineResultIterator{iter: mockIter, err: iterator.Done}} + + stats := ps.ExplainStats() + if stats.err != nil { + t.Fatalf("ExplainStats() error: %v", stats.err) + } + + text, err := stats.Text() + if err != nil { + t.Fatalf("GetText() error: %v", err) + } + if text != "" { + t.Errorf("GetText(): got %q, want empty string", text) + } + + rawData, err := stats.RawData() + if err != nil { + t.Fatalf("GetRawData() error: %v", err) + } + if rawData != nil { + 
t.Errorf("GetRawData(): got %v, want nil", rawData) + } + }) +} diff --git a/firestore/pipeline_snapshot.go b/firestore/pipeline_snapshot.go new file mode 100644 index 000000000000..18d5f64a5625 --- /dev/null +++ b/firestore/pipeline_snapshot.go @@ -0,0 +1,95 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package firestore + +import ( + "errors" + "fmt" + + pb "cloud.google.com/go/firestore/apiv1/firestorepb" + "google.golang.org/api/iterator" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/wrapperspb" +) + +// PipelineSnapshot contains zero or more [PipelineResult] objects +// representing the documents returned by a pipeline query. It provides methods +// to iterate over the documents and access metadata about the query results. +type PipelineSnapshot struct { + iter *PipelineResultIterator +} + +// Results returns an iterator over the query results. +func (ps *PipelineSnapshot) Results() *PipelineResultIterator { + return ps.iter +} + +// ExplainStats returns stats from query explain. +// If [WithExplainMode] was set to [ExplainModeExplain] or left unset, then no stats will be available. +func (ps *PipelineSnapshot) ExplainStats() *ExplainStats { + if ps == nil { + return &ExplainStats{err: errors.New("firestore: PipelineSnapshot is nil")} + } + if ps.iter == nil { + return &ExplainStats{err: errors.New("firestore: PipelineResultIterator is nil")} + } + if ps.iter.err == nil || ps.iter.err != iterator.Done { + return &ExplainStats{err: errStatsBeforeEnd} + } + statsPb, statsErr := ps.iter.iter.getExplainStats() + return &ExplainStats{statsPb: statsPb, err: statsErr} +} + +// ExplainStats is query explain stats. +// +// Contains all metadata related to pipeline planning and execution, specific +// contents depend on the supplied pipeline options. +type ExplainStats struct { + statsPb *pb.ExplainStats + err error +} + +// RawData returns the explain stats in an encoded proto format, as returned from the Firestore backend. +// The caller is responsible for unpacking this proto message. +func (es *ExplainStats) RawData() (*anypb.Any, error) { + if es.err != nil { + return nil, es.err + } + if es.statsPb == nil { + return nil, nil + } + + return es.statsPb.GetData(), nil +} + +// Text returns the explain stats as a string from the Firestore backend. +// If explain stats were requested with `outputFormat = 'text'`, the string is +// returned verbatim. If `outputFormat = 'json'`, this returns the explain stats +// as stringified JSON. 
+func (es *ExplainStats) Text() (string, error) { + if es.err != nil { + return "", es.err + } + if es.statsPb == nil || es.statsPb.GetData() == nil { + return "", nil + } + + var data wrapperspb.StringValue + if err := es.statsPb.GetData().UnmarshalTo(&data); err != nil { + return "", fmt.Errorf("firestore: failed to unmarshal Any to wrapperspb.StringValue: %w", err) + } + + return data.GetValue(), nil +} diff --git a/firestore/pipeline_source.go b/firestore/pipeline_source.go new file mode 100644 index 000000000000..43c3aff91b02 --- /dev/null +++ b/firestore/pipeline_source.go @@ -0,0 +1,176 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package firestore + +import ( + "fmt" + "reflect" + + pb "cloud.google.com/go/firestore/apiv1/firestorepb" +) + +// PipelineSource is a factory for creating Pipeline instances. +// It is obtained by calling [Client.Pipeline()]. +type PipelineSource struct { + client *Client +} + +// CollectionHints provides hints to the query planner. +type CollectionHints map[string]any + +// WithForceIndex specifies an index to force the query to use. +func (ch CollectionHints) WithForceIndex(index string) CollectionHints { + newCH := make(CollectionHints, len(ch)+1) + for k, v := range ch { + newCH[k] = v + } + newCH["force_index"] = index + return newCH +} + +// WithIgnoreIndexFields specifies fields to ignore when selecting an index. +func (ch CollectionHints) WithIgnoreIndexFields(fields ...string) CollectionHints { + newCH := make(CollectionHints, len(ch)+1) + for k, v := range ch { + newCH[k] = v + } + newCH["ignore_index_fields"] = fields + return newCH +} + +func (ch CollectionHints) toProto() (map[string]*pb.Value, error) { + if ch == nil { + return nil, nil + } + optsMap := make(map[string]*pb.Value) + for key, val := range ch { + valPb, _, err := toProtoValue(reflect.ValueOf(val)) + if err != nil { + return nil, fmt.Errorf("firestore: error converting option %q: %w", key, err) + } + optsMap[key] = valPb + } + return optsMap, nil +} + +// collectionStageSettings provides settings for Collection and CollectionGroup pipeline stages. +type collectionStageSettings struct { + Hints CollectionHints +} + +func (cs *collectionStageSettings) toProto() (map[string]*pb.Value, error) { + if cs == nil { + return nil, nil + } + return cs.Hints.toProto() +} + +// CollectionOption is an option for a Collection pipeline stage. +type CollectionOption interface { + apply(co *collectionStageSettings) + isCollectionOption() +} + +// CollectionGroupOption is an option for a CollectionGroup pipeline stage. +type CollectionGroupOption interface { + apply(co *collectionStageSettings) + isCollectionGroupOption() +} + +// funcOption wraps a function that modifies collectionStageSettings +// into an implementation of the CollectionOption and CollectionGroupOption interfaces. 
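+// A single implementation satisfies both interfaces, so one constructor can back
+// options for Collection and CollectionGroup stages alike.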
+type funcOption struct { + f func(*collectionStageSettings) +} + +func (fo *funcOption) apply(cs *collectionStageSettings) { + fo.f(cs) +} + +func (*funcOption) isCollectionOption() {} + +func (*funcOption) isCollectionGroupOption() {} + +func newFuncOption(f func(*collectionStageSettings)) *funcOption { + return &funcOption{ + f: f, + } +} + +// WithCollectionHints specifies hints for the query planner. +func WithCollectionHints(hints CollectionHints) CollectionOption { + return newFuncOption(func(cs *collectionStageSettings) { + cs.Hints = hints + }) +} + +// WithCollectionGroupHints specifies hints for the query planner. +func WithCollectionGroupHints(hints CollectionHints) CollectionGroupOption { + return newFuncOption(func(cs *collectionStageSettings) { + cs.Hints = hints + }) +} + +// Collection creates a new [Pipeline] that operates on the specified Firestore collection. +func (ps *PipelineSource) Collection(path string, opts ...CollectionOption) *Pipeline { + cs := &collectionStageSettings{} + for _, opt := range opts { + if opt != nil { + opt.apply(cs) + } + } + return newPipeline(ps.client, newInputStageCollection(path, cs)) +} + +// CollectionGroup creates a new [Pipeline] that operates on all documents in a group +// of collections that include the given ID, regardless of parent document. +// +// For example, consider: +// Countries/France/Cities/Paris = {population: 100} +// Countries/Canada/Cities/Montreal = {population: 90} +// +// CollectionGroup can be used to query across all "Cities" regardless of +// its parent "Countries". +func (ps *PipelineSource) CollectionGroup(collectionID string, opts ...CollectionGroupOption) *Pipeline { + cgs := &collectionStageSettings{} + for _, opt := range opts { + if opt != nil { + opt.apply(cgs) + } + } + return newPipeline(ps.client, newInputStageCollectionGroup("", collectionID, cgs)) +} + +// Database creates a new [Pipeline] that operates on all documents in the Firestore database. +func (ps *PipelineSource) Database() *Pipeline { + return newPipeline(ps.client, newInputStageDatabase()) +} + +// Documents creates a new [Pipeline] that operates on a specific set of Firestore documents. +func (ps *PipelineSource) Documents(refs ...*DocumentRef) *Pipeline { + return newPipeline(ps.client, newInputStageDocuments(refs...)) +} + +// CreateFromQuery creates a new [Pipeline] from the given [Query]. Under the hood, this will +// translate the query semantics (order by document ID, etc.) to an equivalent pipeline. +func (ps *PipelineSource) CreateFromQuery(query Query) *Pipeline { + return query.Pipeline() +} + +// CreateFromAggregationQuery creates a new [Pipeline] from the given [AggregationQuery]. Under the hood, this will +// translate the query semantics (order by document ID, etc.) to an equivalent pipeline. +func (ps *PipelineSource) CreateFromAggregationQuery(query *AggregationQuery) *Pipeline { + return query.Pipeline() +} diff --git a/firestore/pipeline_source_test.go b/firestore/pipeline_source_test.go new file mode 100644 index 000000000000..34b9f41abc71 --- /dev/null +++ b/firestore/pipeline_source_test.go @@ -0,0 +1,115 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package firestore + +import ( + "testing" + + pb "cloud.google.com/go/firestore/apiv1/firestorepb" + "cloud.google.com/go/internal/testutil" +) + +func TestPipelineSource_Collection(t *testing.T) { + client := newTestClient() + ps := &PipelineSource{client: client} + p := ps.Collection("users") + + if p.err != nil { + t.Fatalf("Collection: %v", p.err) + } + if len(p.stages) != 1 { + t.Fatalf("initial stages: got %d, want 1", len(p.stages)) + } + + req, err := p.toExecutePipelineRequest() + if err != nil { + t.Fatalf("toExecutePipelineRequest: %v", err) + } + + wantStage := &pb.Pipeline_Stage{ + Name: "collection", + Args: []*pb.Value{{ValueType: &pb.Value_ReferenceValue{ReferenceValue: "/users"}}}, + } + + if len(req.GetStructuredPipeline().GetPipeline().GetStages()) != 1 { + t.Fatalf("stage in proto: got %d, want 1", len(req.GetStructuredPipeline().GetPipeline().GetStages())) + } + if diff := testutil.Diff(wantStage, req.GetStructuredPipeline().GetPipeline().GetStages()[0]); diff != "" { + t.Errorf("toExecutePipelineRequest mismatch for collection stage (-want +got):\n%s", diff) + } +} + +func TestPipelineSource_CollectionGroup(t *testing.T) { + client := newTestClient() + ps := &PipelineSource{client: client} + p := ps.CollectionGroup("cities") + + if p.err != nil { + t.Fatalf("CollectionGroup: %v", p.err) + } + if len(p.stages) != 1 { + t.Fatalf("initial stages: got %d, want 1", len(p.stages)) + } + + req, err := p.toExecutePipelineRequest() + if err != nil { + t.Fatalf("toExecutePipelineRequest: %v", err) + } + + wantStage := &pb.Pipeline_Stage{ + Name: "collection_group", + Args: []*pb.Value{ + {ValueType: &pb.Value_ReferenceValue{ReferenceValue: ""}}, + {ValueType: &pb.Value_StringValue{StringValue: "cities"}}, + }, + } + + if len(req.GetStructuredPipeline().GetPipeline().GetStages()) != 1 { + t.Fatalf("stage in proto: got %d, want 1", len(req.GetStructuredPipeline().GetPipeline().GetStages())) + } + if diff := testutil.Diff(wantStage, req.GetStructuredPipeline().GetPipeline().GetStages()[0]); diff != "" { + t.Errorf("toExecutePipelineRequest mismatch for collectionGroup stage (-want +got):\n%s", diff) + } +} + +func TestPipelineSource_Database(t *testing.T) { + client := newTestClient() + ps := &PipelineSource{client: client} + p := ps.Database() + + if p.err != nil { + t.Fatalf("Database: %v", p.err) + } + if len(p.stages) != 1 { + t.Fatalf("initial stages: got %d, want 1", len(p.stages)) + } + + req, err := p.toExecutePipelineRequest() + if err != nil { + t.Fatalf("toExecutePipelineRequest: %v", err) + } + + wantStage := &pb.Pipeline_Stage{ + Name: "database", + Args: nil, + } + + if len(req.GetStructuredPipeline().GetPipeline().GetStages()) != 1 { + t.Fatalf("stage in proto: got %d, want 1", len(req.GetStructuredPipeline().GetPipeline().GetStages())) + } + if diff := testutil.Diff(wantStage, req.GetStructuredPipeline().GetPipeline().GetStages()[0]); diff != "" { + t.Errorf("toExecutePipelineRequest mismatch for database stage (-want +got):\n%s", diff) + } +} diff --git a/firestore/pipeline_stage.go b/firestore/pipeline_stage.go new file mode 100644 index 
000000000000..c9602004fcfd --- /dev/null +++ b/firestore/pipeline_stage.go @@ -0,0 +1,587 @@
+// Copyright 2026 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package firestore
+
+import (
+	"fmt"
+	"reflect"
+	"strings"
+
+	pb "cloud.google.com/go/firestore/apiv1/firestorepb"
+)
+
+// baseStage is an internal helper to reduce repetition in pipelineStage
+// implementations.
+type baseStage struct {
+	stageName string
+	stagePb   *pb.Pipeline_Stage
+}
+
+func (s *baseStage) name() string { return s.stageName }
+func (s *baseStage) toProto() (*pb.Pipeline_Stage, error) { return s.stagePb, nil }
+
+func errInvalidArg(v any, expected ...string) error {
+	return fmt.Errorf("firestore: invalid argument type: %T, expected one of: [%s]", v, strings.Join(expected, ", "))
+}
+
+const (
+	stageNameAddFields       = "add_fields"
+	stageNameAggregate       = "aggregate"
+	stageNameCollection      = "collection"
+	stageNameCollectionGroup = "collection_group"
+	stageNameDatabase        = "database"
+	stageNameDistinct        = "distinct"
+	stageNameDocuments       = "documents"
+	stageNameFindNearest     = "find_nearest"
+	stageNameRemoveFields    = "remove_fields"
+	stageNameReplaceWith     = "replace_with"
+	stageNameSample          = "sample"
+	stageNameSelect          = "select"
+	stageNameUnion           = "union"
+	stageNameUnnest          = "unnest"
+	stageNameWhere           = "where"
+)
+
+// pipelineStage is the internal interface implemented by all pipeline stages.
+type pipelineStage interface {
+	toProto() (*pb.Pipeline_Stage, error)
+	name() string // For identification, logging, and potential validation
+}
+
+// inputStageCollection returns all documents from the entire collection.
+type inputStageCollection struct {
+	path    string
+	options *collectionStageSettings
+}
+
+func newInputStageCollection(path string, options *collectionStageSettings) *inputStageCollection {
+	if !strings.HasPrefix(path, "/") {
+		path = "/" + path
+	}
+	return &inputStageCollection{path: path, options: options}
+}
+func (s *inputStageCollection) name() string { return stageNameCollection }
+func (s *inputStageCollection) toProto() (*pb.Pipeline_Stage, error) {
+	optionsPb, err := s.options.toProto()
+	if err != nil {
+		return nil, err
+	}
+	return &pb.Pipeline_Stage{
+		Name:    s.name(),
+		Args:    []*pb.Value{{ValueType: &pb.Value_ReferenceValue{ReferenceValue: s.path}}},
+		Options: optionsPb,
+	}, nil
+}
+
+// inputStageCollectionGroup returns all documents from every collection with the
+// given collection ID, regardless of parent document.
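+// The ancestor reference is left empty when the group spans the whole database,
+// which is how PipelineSource.CollectionGroup currently constructs it.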
+type inputStageCollectionGroup struct {
+	collectionID string
+	ancestor     string
+	options      *collectionStageSettings
+}
+
+func newInputStageCollectionGroup(ancestor, collectionID string, options *collectionStageSettings) *inputStageCollectionGroup {
+	return &inputStageCollectionGroup{ancestor: ancestor, collectionID: collectionID, options: options}
+}
+func (s *inputStageCollectionGroup) name() string { return stageNameCollectionGroup }
+func (s *inputStageCollectionGroup) toProto() (*pb.Pipeline_Stage, error) {
+	optionsPb, err := s.options.toProto()
+	if err != nil {
+		return nil, err
+	}
+	return &pb.Pipeline_Stage{
+		Name: s.name(),
+		Args: []*pb.Value{
+			{ValueType: &pb.Value_ReferenceValue{ReferenceValue: s.ancestor}},
+			{ValueType: &pb.Value_StringValue{StringValue: s.collectionID}},
+		},
+		Options: optionsPb,
+	}, nil
+}
+
+// inputStageDatabase returns all documents from the entire database.
+type inputStageDatabase struct{}
+
+func newInputStageDatabase() *inputStageDatabase {
+	return &inputStageDatabase{}
+}
+func (s *inputStageDatabase) name() string { return stageNameDatabase }
+func (s *inputStageDatabase) toProto() (*pb.Pipeline_Stage, error) {
+	return &pb.Pipeline_Stage{
+		Name: s.name(),
+	}, nil
+}
+
+type inputStageDocuments struct {
+	baseStage
+}
+
+func newInputStageDocuments(refs ...*DocumentRef) *inputStageDocuments {
+	args := make([]*pb.Value, len(refs))
+	for i, ref := range refs {
+		args[i] = &pb.Value{ValueType: &pb.Value_ReferenceValue{ReferenceValue: "/" + ref.shortPath}}
+	}
+	return &inputStageDocuments{baseStage{
+		stageName: stageNameDocuments,
+		stagePb: &pb.Pipeline_Stage{
+			Name: stageNameDocuments,
+			Args: args,
+		},
+	}}
+}
+
+// addFieldsStage is the internal representation of an AddFields stage.
+type addFieldsStage struct {
+	baseStage
+}
+
+func newAddFieldsStage(selectables ...Selectable) (*addFieldsStage, error) {
+	mapVal, err := projectionsToMapValue(selectables)
+	if err != nil {
+		return nil, err
+	}
+	stagePb := newUnaryStage(stageNameAddFields, mapVal)
+	return &addFieldsStage{baseStage{
+		stageName: stageNameAddFields,
+		stagePb:   stagePb,
+	}}, nil
+}
+
+type aggregateStage struct {
+	baseStage
+}
+
+func newAggregateStage(a *AggregateSpec) (*aggregateStage, error) {
+	if a.err != nil {
+		return nil, a.err
+	}
+	targetsPb, err := aliasedAggregatesToMapValue(a.accTargets)
+	if err != nil {
+		return nil, err
+	}
+	groupsPb, err := projectionsToMapValue(a.groups)
+	if err != nil {
+		return nil, err
+	}
+	return &aggregateStage{baseStage{
+		stageName: stageNameAggregate,
+		stagePb: &pb.Pipeline_Stage{
+			Name: stageNameAggregate,
+			Args: []*pb.Value{
+				targetsPb,
+				groupsPb,
+			},
+		},
+	}}, nil
+}
+
+type distinctStage struct {
+	baseStage
+}
+
+// newProjectionStage is a helper for creating pipeline stages that take a
+// projection as an argument.
+func newProjectionStage(name string, fieldsOrSelectables ...any) (*pb.Pipeline_Stage, error) {
+	selectables, err := fieldsOrSelectablesToSelectables(fieldsOrSelectables...)
+	if err != nil {
+		return nil, err
+	}
+	mapVal, err := projectionsToMapValue(selectables)
+	if err != nil {
+		return nil, err
+	}
+	return &pb.Pipeline_Stage{
+		Name: name,
+		Args: []*pb.Value{mapVal},
+	}, nil
+}
+
+func newDistinctStage(fieldsOrSelectables ...any) (*distinctStage, error) {
+	stagePb, err := newProjectionStage(stageNameDistinct, fieldsOrSelectables...)
+ if err != nil {
+ return nil, err
+ }
+ return &distinctStage{baseStage{stageName: stageNameDistinct, stagePb: stagePb}}, nil
+}
+
+type findNearestStage struct {
+ baseStage
+}
+
+func newFindNearestStage(vectorField any, queryVector any, measure PipelineDistanceMeasure, options *PipelineFindNearestOptions) (*findNearestStage, error) {
+ var propertyExpr Expression
+ switch v := vectorField.(type) {
+ case string:
+ propertyExpr = FieldOf(v)
+ case FieldPath:
+ propertyExpr = FieldOf(v)
+ case Expression:
+ propertyExpr = v
+ default:
+ return nil, errInvalidArg(vectorField, "string", "FieldPath", "Expression")
+ }
+ propPb, err := propertyExpr.toProto()
+ if err != nil {
+ return nil, err
+ }
+ var vectorPb *pb.Value
+ switch v := queryVector.(type) {
+ case Vector32:
+ vectorPb = vectorToProtoValue([]float32(v))
+ case []float32:
+ vectorPb = vectorToProtoValue(v)
+ case Vector64:
+ vectorPb = vectorToProtoValue([]float64(v))
+ case []float64:
+ vectorPb = vectorToProtoValue(v)
+ default:
+ return nil, errInvalidVector
+ }
+ measurePb := &pb.Value{ValueType: &pb.Value_StringValue{StringValue: string(measure)}}
+ var optionsPb map[string]*pb.Value
+ if options != nil {
+ optionsPb = make(map[string]*pb.Value)
+ if options.Limit != nil {
+ optionsPb["limit"] = &pb.Value{ValueType: &pb.Value_IntegerValue{IntegerValue: int64(*options.Limit)}}
+ }
+ if options.DistanceField != nil {
+ optionsPb["distance_field"] = &pb.Value{ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: *options.DistanceField}}
+ }
+ }
+ return &findNearestStage{baseStage{
+ stageName: stageNameFindNearest,
+ stagePb: &pb.Pipeline_Stage{
+ Name: stageNameFindNearest,
+ Args: []*pb.Value{propPb, vectorPb, measurePb},
+ Options: optionsPb,
+ },
+ }}, nil
+}
+
+type limitStage struct {
+ limit int
+}
+
+func newLimitStage(limit int) *limitStage {
+ return &limitStage{limit: limit}
+}
+func (s *limitStage) name() string { return "limit" }
+func (s *limitStage) toProto() (*pb.Pipeline_Stage, error) {
+ arg := &pb.Value{ValueType: &pb.Value_IntegerValue{IntegerValue: int64(s.limit)}}
+ return &pb.Pipeline_Stage{
+ Name: s.name(),
+ Args: []*pb.Value{arg},
+ }, nil
+}
+
+type offsetStage struct {
+ offset int
+}
+
+func newOffsetStage(offset int) *offsetStage {
+ return &offsetStage{offset: offset}
+}
+func (s *offsetStage) name() string { return "offset" }
+func (s *offsetStage) toProto() (*pb.Pipeline_Stage, error) {
+ arg := &pb.Value{ValueType: &pb.Value_IntegerValue{IntegerValue: int64(s.offset)}}
+ return &pb.Pipeline_Stage{
+ Name: s.name(),
+ Args: []*pb.Value{arg},
+ }, nil
+}
+
+type removeFieldsStage struct {
+ baseStage
+}
+
+func newRemoveFieldsStage(fieldpaths ...any) (*removeFieldsStage, error) {
+ fields := make([]Expression, len(fieldpaths))
+ for i, fp := range fieldpaths {
+ switch v := fp.(type) {
+ case string:
+ fields[i] = FieldOf(v)
+ case FieldPath:
+ fields[i] = FieldOf(v)
+ default:
+ return nil, errInvalidArg(fp, "string", "FieldPath")
+ }
+ }
+ args := make([]*pb.Value, len(fields))
+ for i, f := range fields {
+ fieldPb, err := f.toProto()
+ if err != nil {
+ return nil, err
+ }
+ args[i] = fieldPb
+ }
+ return &removeFieldsStage{baseStage{
+ stageName: stageNameRemoveFields,
+ stagePb: &pb.Pipeline_Stage{
+ Name: stageNameRemoveFields,
+ Args: args,
+ },
+ }}, nil
+}
+
+type replaceWithStage struct {
+ baseStage
+}
+
+func newReplaceWithStage(fieldpathOrExpr any) (*replaceWithStage, error) {
+ var expr Expression
+ switch v := fieldpathOrExpr.(type) {
+ case string:
+ expr = FieldOf(v)
+ case FieldPath:
+ expr = FieldOf(v)
+ case Expression:
+ expr = v
+ default:
+ return nil, errInvalidArg(fieldpathOrExpr, "string", "FieldPath", "Expression")
+ }
+ exprPb, err := expr.toProto()
+ if err != nil {
+ return nil, err
+ }
+ return &replaceWithStage{baseStage{
+ stageName: stageNameReplaceWith,
+ stagePb: &pb.Pipeline_Stage{
+ Name: stageNameReplaceWith,
+ Args: []*pb.Value{exprPb, {ValueType: &pb.Value_StringValue{StringValue: "full_replace"}}},
+ },
+ }}, nil
+}
+
+type sampleStage struct {
+ baseStage
+}
+
+func newSampleStage(spec *SampleSpec) (*sampleStage, error) {
+ var sizePb *pb.Value
+ switch v := spec.Size.(type) {
+ case int:
+ sizePb = &pb.Value{ValueType: &pb.Value_IntegerValue{IntegerValue: int64(v)}}
+ case int64:
+ sizePb = &pb.Value{ValueType: &pb.Value_IntegerValue{IntegerValue: v}}
+ case float64:
+ sizePb = &pb.Value{ValueType: &pb.Value_DoubleValue{DoubleValue: v}}
+ default:
+ return nil, fmt.Errorf("firestore: invalid type for sample size: %T", spec.Size)
+ }
+ modePb := &pb.Value{ValueType: &pb.Value_StringValue{StringValue: string(spec.Mode)}}
+ return &sampleStage{baseStage{
+ stageName: stageNameSample,
+ stagePb: &pb.Pipeline_Stage{
+ Name: stageNameSample,
+ Args: []*pb.Value{sizePb, modePb},
+ },
+ }}, nil
+}
+
+type selectStage struct {
+ baseStage
+}
+
+func newSelectStage(fieldsOrSelectables ...any) (*selectStage, error) {
+ stagePb, err := newProjectionStage(stageNameSelect, fieldsOrSelectables...)
+ if err != nil {
+ return nil, err
+ }
+ return &selectStage{baseStage{stageName: stageNameSelect, stagePb: stagePb}}, nil
+}
+
+type sortStage struct {
+ orders []Ordering
+}
+
+func newSortStage(orders ...Ordering) *sortStage {
+ return &sortStage{orders: orders}
+}
+func (s *sortStage) name() string { return "sort" }
+func (s *sortStage) toProto() (*pb.Pipeline_Stage, error) {
+ sortOrders := make([]*pb.Value, len(s.orders))
+ for i, so := range s.orders {
+ fieldPb, err := so.Expr.toProto()
+ if err != nil {
+ return nil, err
+ }
+ sortOrders[i] = &pb.Value{
+ ValueType: &pb.Value_MapValue{
+ MapValue: &pb.MapValue{
+ Fields: map[string]*pb.Value{
+ "direction": {
+ ValueType: &pb.Value_StringValue{
+ StringValue: string(so.Direction),
+ },
+ },
+ "expression": fieldPb,
+ },
+ },
+ },
+ }
+ }
+ return &pb.Pipeline_Stage{
+ Name: s.name(),
+ Args: sortOrders,
+ }, nil
+}
+
+type unionStage struct {
+ baseStage
+}
+
+func newUnionStage(other *Pipeline) (*unionStage, error) {
+ otherPb, err := other.toProto()
+ if err != nil {
+ return nil, err
+ }
+ return &unionStage{baseStage{
+ stageName: stageNameUnion,
+ stagePb: &pb.Pipeline_Stage{
+ Name: stageNameUnion,
+ Args: []*pb.Value{
+ {ValueType: &pb.Value_PipelineValue{PipelineValue: otherPb}},
+ },
+ },
+ }}, nil
+}
+
+type unnestStage struct {
+ baseStage
+}
+
+func newUnnestStage(fieldExpr Expression, alias string, opts *UnnestOptions) (*unnestStage, error) {
+ exprPb, err := fieldExpr.toProto()
+ if err != nil {
+ return nil, err
+ }
+ aliasPb, err := FieldOf(alias).toProto()
+ if err != nil {
+ return nil, err
+ }
+ var optionsPb map[string]*pb.Value
+ if opts != nil && opts.IndexField != nil {
+ var indexFieldExpr Expression
+ switch v := opts.IndexField.(type) {
+ case FieldPath:
+ indexFieldExpr = FieldOf(v)
+ case string:
+ indexFieldExpr = FieldOf(v)
+ default:
+ return nil, errInvalidArg(opts.IndexField, "string", "FieldPath")
+ }
+ indexPb, err := indexFieldExpr.toProto()
+ if err != nil {
+ return nil, err
+ }
+ optionsPb = make(map[string]*pb.Value)
+ optionsPb["index_field"] = indexPb
+ }
+ return &unnestStage{baseStage{
+ stageName: stageNameUnnest,
+ stagePb: &pb.Pipeline_Stage{
+ Name: stageNameUnnest,
+ Args: []*pb.Value{exprPb, aliasPb},
+ Options: optionsPb,
+ },
+ }}, nil
+}
+
+func newUnnestStageFromSelectable(field Selectable, opts *UnnestOptions) (*unnestStage, error) {
+ alias, expr := field.getSelectionDetails()
+ return newUnnestStage(expr, alias, opts)
+}
+
+type whereStage struct {
+ baseStage
+}
+
+// newUnaryStage is a helper for creating pipeline stages that take a single
+// proto as an argument.
+func newUnaryStage(name string, val *pb.Value) *pb.Pipeline_Stage {
+ return &pb.Pipeline_Stage{
+ Name: name,
+ Args: []*pb.Value{val},
+ }
+}
+
+func newWhereStage(condition BooleanExpression) (*whereStage, error) {
+ argsPb, err := condition.toProto()
+ if err != nil {
+ return nil, err
+ }
+ return &whereStage{baseStage{
+ stageName: stageNameWhere,
+ stagePb: newUnaryStage(stageNameWhere, argsPb),
+ }}, nil
+}
+
+// RawStageOptions holds the options for a RawStage.
+type RawStageOptions map[string]any
+
+// RawStage is a generic stage in the pipeline.
+// It provides a flexible way to extend the pipeline's functionality with custom
+// stages, and it lets users call stages that the Firestore backend supports
+// but that are not yet exposed in the current SDK version.
+type RawStage struct {
+ stageName string
+ args []any
+ options RawStageOptions
+}
+
+// NewRawStage creates a new RawStage with the given name.
+func NewRawStage(name string) *RawStage {
+ return &RawStage{stageName: name}
+}
+
+// WithArguments sets the arguments for the RawStage.
+func (s *RawStage) WithArguments(args ...any) *RawStage {
+ s.args = args
+ return s
+}
+
+// WithOptions sets the options for the RawStage.
+func (s *RawStage) WithOptions(options RawStageOptions) *RawStage {
+ s.options = options
+ return s
+}
+
+func (s *RawStage) name() string { return s.stageName }
+
+func (s *RawStage) toProto() (*pb.Pipeline_Stage, error) {
+ argsPb := make([]*pb.Value, len(s.args))
+ for i, arg := range s.args {
+ val, _, err := toProtoValue(reflect.ValueOf(arg))
+ if err != nil {
+ return nil, fmt.Errorf("firestore: error converting raw stage argument %d: %w", i, err)
+ }
+ argsPb[i] = val
+ }
+
+ optionsPb := make(map[string]*pb.Value, len(s.options))
+ for key, val := range s.options {
+ valPb, _, err := toProtoValue(reflect.ValueOf(val))
+ if err != nil {
+ return nil, fmt.Errorf("firestore: error converting raw stage option %q: %w", key, err)
+ }
+ optionsPb[key] = valPb
+ }
+
+ return &pb.Pipeline_Stage{
+ Name: s.name(),
+ Args: argsPb,
+ Options: optionsPb,
+ }, nil
+}
diff --git a/firestore/pipeline_stage_test.go b/firestore/pipeline_stage_test.go
new file mode 100644
index 000000000000..8f6a78f7e357
--- /dev/null
+++ b/firestore/pipeline_stage_test.go
@@ -0,0 +1,431 @@
+// Copyright 2026 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +package firestore + +import ( + "context" + "testing" + + pb "cloud.google.com/go/firestore/apiv1/firestorepb" + "cloud.google.com/go/internal/testutil" +) + +func TestPipelineStages(t *testing.T) { + docRef1 := &DocumentRef{ + Path: "projects/projectID/databases/(default)/documents/collection/doc1", + shortPath: "collection/doc1", + } + docRef2 := &DocumentRef{ + Path: "projects/projectID/databases/(default)/documents/collection/doc2", + shortPath: "collection/doc2", + } + + testcases := []struct { + desc string + stage pipelineStage + want *pb.Pipeline_Stage + }{ + { + desc: "inputStageCollection", + stage: newInputStageCollection("my-collection", nil), + want: &pb.Pipeline_Stage{ + Name: "collection", + Args: []*pb.Value{{ValueType: &pb.Value_ReferenceValue{ReferenceValue: "/my-collection"}}}, + }, + }, + { + desc: "inputStageCollectionGroup", + stage: newInputStageCollectionGroup("ancestor/path", "my-collection-group", nil), + want: &pb.Pipeline_Stage{ + Name: "collection_group", + Args: []*pb.Value{ + {ValueType: &pb.Value_ReferenceValue{ReferenceValue: "ancestor/path"}}, + {ValueType: &pb.Value_StringValue{StringValue: "my-collection-group"}}, + }, + }, + }, + { + desc: "inputStageDatabase", + stage: newInputStageDatabase(), + want: &pb.Pipeline_Stage{Name: "database"}, + }, + { + desc: "inputStageDocuments", + stage: newInputStageDocuments(docRef1, docRef2), + want: &pb.Pipeline_Stage{ + Name: "documents", + Args: []*pb.Value{ + {ValueType: &pb.Value_ReferenceValue{ReferenceValue: "/collection/doc1"}}, + {ValueType: &pb.Value_ReferenceValue{ReferenceValue: "/collection/doc2"}}, + }, + }, + }, + { + desc: "limitStage", + stage: newLimitStage(10), + want: &pb.Pipeline_Stage{ + Name: "limit", + Args: []*pb.Value{{ValueType: &pb.Value_IntegerValue{IntegerValue: 10}}}, + }, + }, + { + desc: "offsetStage", + stage: newOffsetStage(5), + want: &pb.Pipeline_Stage{ + Name: "offset", + Args: []*pb.Value{{ValueType: &pb.Value_IntegerValue{IntegerValue: 5}}}, + }, + }, + { + desc: "sortStage", + stage: newSortStage(Ascending(FieldOf("name")), Descending(FieldOf("age"))), + want: &pb.Pipeline_Stage{ + Name: "sort", + Args: []*pb.Value{ + {ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{Fields: map[string]*pb.Value{ + "direction": {ValueType: &pb.Value_StringValue{StringValue: "ascending"}}, + "expression": {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "name"}}, + }}}}, + {ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{Fields: map[string]*pb.Value{ + "direction": {ValueType: &pb.Value_StringValue{StringValue: "descending"}}, + "expression": {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "age"}}, + }}}}, + }, + }, + }, + } + + for _, tc := range testcases { + t.Run(tc.desc, func(t *testing.T) { + got, err := tc.stage.toProto() + if err != nil { + t.Fatalf("toProto() failed: %v", err) + } + if diff := testutil.Diff(got, tc.want); diff != "" { + t.Errorf("toProto() returned diff (-got +want): %s", diff) + } + }) + } +} + +func TestSelectStage(t *testing.T) { + stage, err := newSelectStage("name", FieldOf("age"), Add(FieldOf("score"), 10).As("new_score")) + if err != nil { + t.Fatalf("newSelectStage() failed: %v", err) + } + + want := &pb.Pipeline_Stage{ + Name: "select", + Args: []*pb.Value{ + {ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{Fields: map[string]*pb.Value{ + "name": {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "name"}}, + "age": {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "age"}}, + "new_score": 
{ValueType: &pb.Value_FunctionValue{FunctionValue: &pb.Function{ + Name: "add", + Args: []*pb.Value{ + {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "score"}}, + {ValueType: &pb.Value_IntegerValue{IntegerValue: 10}}, + }, + }}}, + }}}}, + }, + } + + got, err := stage.toProto() + if err != nil { + t.Fatalf("toProto() failed: %v", err) + } + if diff := testutil.Diff(got, want); diff != "" { + t.Errorf("toProto() returned diff (-got +want): %s", diff) + } +} + +func TestWhereStage(t *testing.T) { + condition := Equal(FieldOf("genre"), "Sci-Fi") + stage, err := newWhereStage(condition) + if err != nil { + t.Fatalf("newWhereStage() failed: %v", err) + } + + want := &pb.Pipeline_Stage{ + Name: "where", + Args: []*pb.Value{ + {ValueType: &pb.Value_FunctionValue{FunctionValue: &pb.Function{ + Name: "equal", + Args: []*pb.Value{ + {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "genre"}}, + {ValueType: &pb.Value_StringValue{StringValue: "Sci-Fi"}}, + }, + }}}, + }, + } + + got, err := stage.toProto() + if err != nil { + t.Fatalf("toProto() failed: %v", err) + } + if diff := testutil.Diff(got, want); diff != "" { + t.Errorf("toProto() returned diff (-got +want): %s", diff) + } +} + +func TestAddFieldsStage(t *testing.T) { + stage, err := newAddFieldsStage(FieldOf("name").As("name"), Add(FieldOf("score"), 10).As("new_score")) + if err != nil { + t.Fatalf("newAddFieldsStage() failed: %v", err) + } + + want := &pb.Pipeline_Stage{ + Name: "add_fields", + Args: []*pb.Value{ + {ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{Fields: map[string]*pb.Value{ + "name": {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "name"}}, + "new_score": {ValueType: &pb.Value_FunctionValue{FunctionValue: &pb.Function{ + Name: "add", + Args: []*pb.Value{ + {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "score"}}, + {ValueType: &pb.Value_IntegerValue{IntegerValue: 10}}, + }, + }}}, + }}}}, + }, + } + + got, err := stage.toProto() + if err != nil { + t.Fatalf("toProto() failed: %v", err) + } + if diff := testutil.Diff(got, want); diff != "" { + t.Errorf("toProto() returned diff (-got +want): %s", diff) + } +} + +func TestAggregateStage(t *testing.T) { + spec := NewAggregateSpec(Sum("score").As("total_score")).WithGroups("category") + stage, err := newAggregateStage(spec) + if err != nil { + t.Fatalf("newAggregateStage() failed: %v", err) + } + + want := &pb.Pipeline_Stage{ + Name: "aggregate", + Args: []*pb.Value{ + {ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{Fields: map[string]*pb.Value{ + "total_score": {ValueType: &pb.Value_FunctionValue{FunctionValue: &pb.Function{ + Name: "sum", + Args: []*pb.Value{ + {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "score"}}, + }, + }}}, + }}}}, + {ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{Fields: map[string]*pb.Value{ + "category": {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "category"}}, + }}}}, + }, + } + + got, err := stage.toProto() + if err != nil { + t.Fatalf("toProto() failed: %v", err) + } + if diff := testutil.Diff(got, want); diff != "" { + t.Errorf("toProto() returned diff (-got +want): %s", diff) + } +} + +func TestDistinctStage(t *testing.T) { + stage, err := newDistinctStage("category", FieldOf("author")) + if err != nil { + t.Fatalf("newDistinctStage() failed: %v", err) + } + + want := &pb.Pipeline_Stage{ + Name: "distinct", + Args: []*pb.Value{ + {ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{Fields: map[string]*pb.Value{ + "category": 
{ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "category"}},
+ "author": {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "author"}},
+ }}}},
+ },
+ }
+
+ got, err := stage.toProto()
+ if err != nil {
+ t.Fatalf("toProto() failed: %v", err)
+ }
+ if diff := testutil.Diff(got, want); diff != "" {
+ t.Errorf("toProto() returned diff (-got +want): %s", diff)
+ }
+}
+
+func TestFindNearestStage(t *testing.T) {
+ limit := 10
+ distanceField := "distance"
+ stage, err := newFindNearestStage("embedding", []float64{1, 2, 3}, PipelineDistanceMeasureEuclidean, &PipelineFindNearestOptions{Limit: &limit, DistanceField: &distanceField})
+ if err != nil {
+ t.Fatalf("newFindNearestStage() failed: %v", err)
+ }
+
+ want := &pb.Pipeline_Stage{
+ Name: "find_nearest",
+ Args: []*pb.Value{
+ {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "embedding"}},
+ vectorToProtoValue([]float64{1, 2, 3}),
+ {ValueType: &pb.Value_StringValue{StringValue: "euclidean"}},
+ },
+ Options: map[string]*pb.Value{
+ "limit": {ValueType: &pb.Value_IntegerValue{IntegerValue: 10}},
+ "distance_field": {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "distance"}},
+ },
+ }
+
+ got, err := stage.toProto()
+ if err != nil {
+ t.Fatalf("toProto() failed: %v", err)
+ }
+ if diff := testutil.Diff(got, want); diff != "" {
+ t.Errorf("toProto() returned diff (-got +want): %s", diff)
+ }
+}
+
+func TestRemoveFieldsStage(t *testing.T) {
+ stage, err := newRemoveFieldsStage("price", FieldPath{"author", "name"})
+ if err != nil {
+ t.Fatalf("newRemoveFieldsStage() failed: %v", err)
+ }
+
+ want := &pb.Pipeline_Stage{
+ Name: "remove_fields",
+ Args: []*pb.Value{
+ {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "price"}},
+ {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "author.name"}},
+ },
+ }
+
+ got, err := stage.toProto()
+ if err != nil {
+ t.Fatalf("toProto() failed: %v", err)
+ }
+ if diff := testutil.Diff(got, want); diff != "" {
+ t.Errorf("toProto() returned diff (-got +want): %s", diff)
+ }
+}
+
+func TestReplaceStage(t *testing.T) {
+ stage, err := newReplaceWithStage("metadata")
+ if err != nil {
+ t.Fatalf("newReplaceWithStage() failed: %v", err)
+ }
+
+ want := &pb.Pipeline_Stage{
+ Name: "replace_with",
+ Args: []*pb.Value{
+ {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "metadata"}},
+ {ValueType: &pb.Value_StringValue{StringValue: "full_replace"}},
+ },
+ }
+
+ got, err := stage.toProto()
+ if err != nil {
+ t.Fatalf("toProto() failed: %v", err)
+ }
+ if diff := testutil.Diff(got, want); diff != "" {
+ t.Errorf("toProto() returned diff (-got +want): %s", diff)
+ }
+}
+
+func TestSampleStage(t *testing.T) {
+ spec := SampleByDocuments(100)
+ stage, err := newSampleStage(spec)
+ if err != nil {
+ t.Fatalf("newSampleStage() failed: %v", err)
+ }
+
+ want := &pb.Pipeline_Stage{
+ Name: "sample",
+ Args: []*pb.Value{
+ {ValueType: &pb.Value_IntegerValue{IntegerValue: 100}},
+ {ValueType: &pb.Value_StringValue{StringValue: "documents"}},
+ },
+ }
+
+ got, err := stage.toProto()
+ if err != nil {
+ t.Fatalf("toProto() failed: %v", err)
+ }
+ if diff := testutil.Diff(got, want); diff != "" {
+ t.Errorf("toProto() returned diff (-got +want): %s", diff)
+ }
+}
+
+func TestUnionStage(t *testing.T) {
+ client, err := NewClient(context.Background(), "projectID")
+ if err != nil {
+ t.Fatalf("NewClient: %v", err)
+ }
+ otherPipeline := newPipeline(client, newInputStageCollection("other_collection", nil))
+ stage, err
:= newUnionStage(otherPipeline) + if err != nil { + t.Fatalf("newUnionStage() failed: %v", err) + } + + want := &pb.Pipeline_Stage{ + Name: "union", + Args: []*pb.Value{ + {ValueType: &pb.Value_PipelineValue{PipelineValue: &pb.Pipeline{ + Stages: []*pb.Pipeline_Stage{ + { + Name: "collection", + Args: []*pb.Value{{ValueType: &pb.Value_ReferenceValue{ReferenceValue: "/other_collection"}}}, + }, + }, + }}}, + }, + } + + got, err := stage.toProto() + if err != nil { + t.Fatalf("toProto() failed: %v", err) + } + if diff := testutil.Diff(got, want); diff != "" { + t.Errorf("toProto() returned diff (-got +want): %s", diff) + } +} + +func TestUnnestStage(t *testing.T) { + stage, err := newUnnestStage(FieldOf("tags"), "tag", &UnnestOptions{IndexField: "index"}) + if err != nil { + t.Fatalf("newUnnestStage() failed: %v", err) + } + + want := &pb.Pipeline_Stage{ + Name: "unnest", + Args: []*pb.Value{ + {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "tags"}}, + {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "tag"}}, + }, + Options: map[string]*pb.Value{ + "index_field": {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "index"}}, + }, + } + + got, err := stage.toProto() + if err != nil { + t.Fatalf("toProto() failed: %v", err) + } + if diff := testutil.Diff(got, want); diff != "" { + t.Errorf("toProto() returned diff (-got +want): %s", diff) + } +} diff --git a/firestore/pipeline_test.go b/firestore/pipeline_test.go new file mode 100644 index 000000000000..e3509c836bb4 --- /dev/null +++ b/firestore/pipeline_test.go @@ -0,0 +1,404 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package firestore + +import ( + "context" + "io" + "testing" + + pb "cloud.google.com/go/firestore/apiv1/firestorepb" + "github.com/google/go-cmp/cmp" + "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/testing/protocmp" +) + +// mockExecutePipelineClient is a mock implementation of pb.Firestore_ExecutePipelineClient. 
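+// Recv returns queued RecvResponses and RecvErrors in order and then io.EOF,
+// mimicking a server stream that ends after its queued items.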
+type mockExecutePipelineClient struct { + pb.Firestore_ExecutePipelineClient // Embed for forward compatibility + RecvResponses []*pb.ExecutePipelineResponse + RecvErrors []error + RecvIdx int + CloseSendErr error + HeaderVal metadata.MD + TrailerVal metadata.MD + ContextVal context.Context + SendHeaderVal metadata.MD +} + +func (m *mockExecutePipelineClient) Recv() (*pb.ExecutePipelineResponse, error) { + if m.ContextVal != nil && m.ContextVal.Err() != nil { + return nil, m.ContextVal.Err() + } + if m.RecvIdx < len(m.RecvResponses) || m.RecvIdx < len(m.RecvErrors) { + var resp *pb.ExecutePipelineResponse + var err error + if m.RecvIdx < len(m.RecvResponses) { + resp = m.RecvResponses[m.RecvIdx] + } + if m.RecvIdx < len(m.RecvErrors) { + err = m.RecvErrors[m.RecvIdx] + } + m.RecvIdx++ + return resp, err + } + return nil, io.EOF +} +func (m *mockExecutePipelineClient) CloseSend() error { return m.CloseSendErr } +func (m *mockExecutePipelineClient) Header() (metadata.MD, error) { return m.HeaderVal, nil } +func (m *mockExecutePipelineClient) Trailer() metadata.MD { return m.TrailerVal } +func (m *mockExecutePipelineClient) Context() context.Context { return m.ContextVal } +func (m *mockExecutePipelineClient) SendHeader(md metadata.MD) error { + m.SendHeaderVal = md + return nil +} +func (m *mockExecutePipelineClient) SetHeader(md metadata.MD) error { return nil } +func (m *mockExecutePipelineClient) SetTrailer(md metadata.MD) {} +func (m *mockExecutePipelineClient) SendMsg(i any) error { return nil } +func (m *mockExecutePipelineClient) RecvMsg(i any) error { return nil } + +// Test helper to create a minimal Client for non-RPC tests +func newTestClient() *Client { + return &Client{ + projectID: "test-project", + databaseID: "test-db", + } +} + +func TestPipeline_Limit(t *testing.T) { + client := newTestClient() + ps := &PipelineSource{client: client} + p := ps.Collection("users").Limit(10) + + if p.err != nil { + t.Fatalf("Pipeline.Limit() returned error: %v", p.err) + } + if len(p.stages) != 2 { + t.Fatalf("Expected 2 stages, got %d", len(p.stages)) + } + + req, err := p.toExecutePipelineRequest() + if err != nil { + t.Fatalf("p.toExecutePipelineRequest() failed: %v", err) + } + + stages := req.GetStructuredPipeline().GetPipeline().GetStages() + if len(stages) != 2 { + t.Fatalf("Expected 2 stages in proto, got %d", len(stages)) + } + + wantLimitStage := &pb.Pipeline_Stage{ + Name: "limit", + Args: []*pb.Value{{ValueType: &pb.Value_IntegerValue{IntegerValue: 10}}}, + } + if diff := cmp.Diff(wantLimitStage, stages[1], protocmp.Transform()); diff != "" { + t.Errorf("toExecutePipelineRequest() mismatch for limit stage (-want +got):\n%s", diff) + } +} + +func TestPipeline_ToExecutePipelineRequest(t *testing.T) { + client := newTestClient() + ps := &PipelineSource{client: client} + p := ps.Collection("items").Limit(5) + + req, err := p.toExecutePipelineRequest() + if err != nil { + t.Fatalf("toExecutePipelineRequest: %v", err) + } + + if req.GetDatabase() != "projects/test-project/databases/test-db" { + t.Errorf("req.GetDatabase: got %s, want %s", req.GetDatabase(), "projects/test-project/databases/test-db") + } + + pipelineProto := req.GetStructuredPipeline().GetPipeline() + if pipelineProto == nil { + t.Fatal("StructuredPipeline.Pipeline is nil") + } + + stagesProto := pipelineProto.GetStages() + if len(stagesProto) != 2 { + t.Fatalf("stages: got %d want 2", len(stagesProto)) + } + + // Check collection stage + wantCollStage := &pb.Pipeline_Stage{ + Name: "collection", + Args: 
[]*pb.Value{{ValueType: &pb.Value_ReferenceValue{ReferenceValue: "/items"}}}, + } + if diff := cmp.Diff(wantCollStage, stagesProto[0], protocmp.Transform()); diff != "" { + t.Errorf("Collection stage mismatch (-want +got):\n%s", diff) + } + + // Check limit stage + wantLimitStage := &pb.Pipeline_Stage{ + Name: "limit", + Args: []*pb.Value{{ValueType: &pb.Value_IntegerValue{IntegerValue: 5}}}, + } + if diff := cmp.Diff(wantLimitStage, stagesProto[1], protocmp.Transform()); diff != "" { + t.Errorf("Limit stage mismatch (-want +got):\n%s", diff) + } +} + +func TestPipeline_Sort(t *testing.T) { + client := newTestClient() + ps := &PipelineSource{client: client} + p := ps.Collection("users").Sort(Ordering{Expr: FieldOf("age"), Direction: OrderingDesc}) + + req, err := p.toExecutePipelineRequest() + if err != nil { + t.Fatalf("p.toExecutePipelineRequest() failed: %v", err) + } + + stages := req.GetStructuredPipeline().GetPipeline().GetStages() + if len(stages) != 2 { + t.Fatalf("Expected 2 stages in proto, got %d", len(stages)) + } + + wantSortStage := &pb.Pipeline_Stage{ + Name: "sort", + Args: []*pb.Value{ + { + ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + "expression": {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "age"}}, + "direction": {ValueType: &pb.Value_StringValue{StringValue: "descending"}}, + }, + }, + }, + }, + }, + } + if diff := cmp.Diff(wantSortStage, stages[1], protocmp.Transform()); diff != "" { + t.Errorf("toExecutePipelineRequest() mismatch for sort stage (-want +got):\n%s", diff) + } +} + +func TestPipeline_Offset(t *testing.T) { + client := newTestClient() + ps := &PipelineSource{client: client} + p := ps.Collection("users").Offset(20) + + req, err := p.toExecutePipelineRequest() + if err != nil { + t.Fatalf("p.toExecutePipelineRequest() failed: %v", err) + } + + stages := req.GetStructuredPipeline().GetPipeline().GetStages() + if len(stages) != 2 { + t.Fatalf("Expected 2 stages in proto, got %d", len(stages)) + } + + wantOffsetStage := &pb.Pipeline_Stage{ + Name: "offset", + Args: []*pb.Value{{ValueType: &pb.Value_IntegerValue{IntegerValue: 20}}}, + } + if diff := cmp.Diff(wantOffsetStage, stages[1], protocmp.Transform()); diff != "" { + t.Errorf("toExecutePipelineRequest() mismatch for offset stage (-want +got):\n%s", diff) + } +} + +func TestPipeline_Select(t *testing.T) { + client := newTestClient() + ps := &PipelineSource{client: client} + p := ps.Collection("users").Select("name", FieldOf("age"), Add(FieldOf("score"), 10).As("new_score")) + + req, err := p.toExecutePipelineRequest() + if err != nil { + t.Fatalf("p.toExecutePipelineRequest() failed: %v", err) + } + + stages := req.GetStructuredPipeline().GetPipeline().GetStages() + if len(stages) != 2 { + t.Fatalf("Expected 2 stages in proto, got %d", len(stages)) + } + + wantSelectStage := &pb.Pipeline_Stage{ + Name: "select", + Args: []*pb.Value{ + {ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + "name": {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "name"}}, + "age": {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "age"}}, + "new_score": {ValueType: &pb.Value_FunctionValue{FunctionValue: &pb.Function{ + Name: "add", + Args: []*pb.Value{ + {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "score"}}, + {ValueType: &pb.Value_IntegerValue{IntegerValue: 10}}, + }, + }}}, + }, + }, + }}, + }, + } + if diff := cmp.Diff(wantSelectStage, stages[1], protocmp.Transform()); diff 
!= "" { + t.Errorf("toExecutePipelineRequest() mismatch for select stage (-want +got):\n%s", diff) + } +} + +func TestPipeline_AddFields(t *testing.T) { + client := newTestClient() + ps := &PipelineSource{client: client} + p := ps.Collection("users").AddFields(Add(FieldOf("score"), 10).As("new_score")) + + req, err := p.toExecutePipelineRequest() + if err != nil { + t.Fatalf("p.toExecutePipelineRequest() failed: %v", err) + } + + stages := req.GetStructuredPipeline().GetPipeline().GetStages() + if len(stages) != 2 { + t.Fatalf("Expected 2 stages in proto, got %d", len(stages)) + } + + wantAddFieldsStage := &pb.Pipeline_Stage{ + Name: "add_fields", + Args: []*pb.Value{ + {ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + "new_score": {ValueType: &pb.Value_FunctionValue{FunctionValue: &pb.Function{ + Name: "add", + Args: []*pb.Value{ + {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "score"}}, + {ValueType: &pb.Value_IntegerValue{IntegerValue: 10}}, + }, + }}}, + }, + }, + }}, + }, + } + if diff := cmp.Diff(wantAddFieldsStage, stages[1], protocmp.Transform()); diff != "" { + t.Errorf("toExecutePipelineRequest() mismatch for addFields stage (-want +got):\n%s", diff) + } +} + +func TestPipeline_Where(t *testing.T) { + client := newTestClient() + ps := &PipelineSource{client: client} + p := ps.Collection("users").Where(Equal(FieldOf("age"), 30)) + + req, err := p.toExecutePipelineRequest() + if err != nil { + t.Fatalf("p.toExecutePipelineRequest() failed: %v", err) + } + + stages := req.GetStructuredPipeline().GetPipeline().GetStages() + if len(stages) != 2 { + t.Fatalf("Expected 2 stages in proto, got %d", len(stages)) + } + + wantWhereStage := &pb.Pipeline_Stage{ + Name: "where", + Args: []*pb.Value{ + {ValueType: &pb.Value_FunctionValue{FunctionValue: &pb.Function{ + Name: "equal", + Args: []*pb.Value{ + {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "age"}}, + {ValueType: &pb.Value_IntegerValue{IntegerValue: 30}}, + }, + }}}, + }, + } + if diff := cmp.Diff(wantWhereStage, stages[1], protocmp.Transform()); diff != "" { + t.Errorf("toExecutePipelineRequest() mismatch for where stage (-want +got):\n%s", diff) + } +} + +func TestPipeline_Aggregate(t *testing.T) { + client := newTestClient() + ps := &PipelineSource{client: client} + p := ps.Collection("users").Aggregate(Sum("age").As("total_age")) + + req, err := p.toExecutePipelineRequest() + if err != nil { + t.Fatalf("p.toExecutePipelineRequest() failed: %v", err) + } + + stages := req.GetStructuredPipeline().GetPipeline().GetStages() + if len(stages) != 2 { + t.Fatalf("Expected 2 stages in proto, got %d", len(stages)) + } + + wantAggregateStage := &pb.Pipeline_Stage{ + Name: "aggregate", + Args: []*pb.Value{ + {ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + "total_age": {ValueType: &pb.Value_FunctionValue{FunctionValue: &pb.Function{ + Name: "sum", + Args: []*pb.Value{ + {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "age"}}, + }, + }}}, + }, + }, + }}, + {ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{}}}, + }, + } + if diff := cmp.Diff(wantAggregateStage, stages[1], protocmp.Transform()); diff != "" { + t.Errorf("toExecutePipelineRequest() mismatch for aggregate stage (-want +got):\n%s", diff) + } +} + +func TestPipeline_AggregateWithSpec(t *testing.T) { + client := newTestClient() + ps := &PipelineSource{client: client} + spec := NewAggregateSpec(Average("rating").As("avg_rating")).WithGroups("genre") + 
p := ps.Collection("books").AggregateWithSpec(spec)
+
+ req, err := p.toExecutePipelineRequest()
+ if err != nil {
+ t.Fatalf("p.toExecutePipelineRequest() failed: %v", err)
+ }
+
+ stages := req.GetStructuredPipeline().GetPipeline().GetStages()
+ if len(stages) != 2 {
+ t.Fatalf("Expected 2 stages in proto, got %d", len(stages))
+ }
+
+ wantAggregateStage := &pb.Pipeline_Stage{
+ Name: "aggregate",
+ Args: []*pb.Value{
+ {ValueType: &pb.Value_MapValue{
+ MapValue: &pb.MapValue{
+ Fields: map[string]*pb.Value{
+ "avg_rating": {ValueType: &pb.Value_FunctionValue{FunctionValue: &pb.Function{
+ Name: "average",
+ Args: []*pb.Value{
+ {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "rating"}},
+ },
+ }}},
+ },
+ },
+ }},
+ {ValueType: &pb.Value_MapValue{
+ MapValue: &pb.MapValue{
+ Fields: map[string]*pb.Value{
+ "genre": {ValueType: &pb.Value_FieldReferenceValue{FieldReferenceValue: "genre"}},
+ },
+ },
+ }},
+ },
+ }
+ if diff := cmp.Diff(wantAggregateStage, stages[1], protocmp.Transform()); diff != "" {
+ t.Errorf("toExecutePipelineRequest() mismatch for aggregate stage (-want +got):\n%s", diff)
+ }
+}
diff --git a/firestore/pipeline_utils.go b/firestore/pipeline_utils.go
new file mode 100644
index 000000000000..f6e0a0748013
--- /dev/null
+++ b/firestore/pipeline_utils.go
@@ -0,0 +1,220 @@
+// Copyright 2026 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package firestore
+
+import (
+ "errors"
+ "fmt"
+ "reflect"
+
+ pb "cloud.google.com/go/firestore/apiv1/firestorepb"
+)
+
+func toArrayOfExprOrConstant(val []any) []Expression {
+ exprs := make([]Expression, 0, len(val))
+ for _, v := range val {
+ exprs = append(exprs, toExprOrConstant(v))
+ }
+ return exprs
+}
+
+// newFieldAndArrayBooleanExpr creates a new BooleanExpression for functions that operate on a field/expression and an array of values.
+func newFieldAndArrayBooleanExpr(name string, exprOrFieldPath any, values any) BooleanExpression {
+ return &baseBooleanExpression{baseFunction: newBaseFunction(name, []Expression{asFieldExpr(exprOrFieldPath), asArrayFunctionExpr(values)})}
+}
+
+// toExprs converts a slice of plain Go values or existing Expressions into a slice of Expressions.
+// Plain values are wrapped in a Constant.
+func toExprs(val []any) []Expression {
+ exprs := make([]Expression, len(val))
+ for i, v := range val {
+ exprs[i] = toExprOrConstant(v)
+ }
+ return exprs
+}
+
+// toExprsFromSlice converts a slice of any element type into a slice of Expressions, wrapping plain values in Constants.
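+// For example, toExprsFromSlice([]int{1, 2}) yields the Expressions
+// [ConstantOf(1), ConstantOf(2)].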
+func toExprsFromSlice[T any](val []T) []Expression {
+ exprs := make([]Expression, len(val))
+ for i, v := range val {
+ exprs[i] = toExprOrConstant(v)
+ }
+ return exprs
+}
+
+// asArrayFunctionExpr accepts a single Expression, or a slice of Expressions and/or constants.
+func asArrayFunctionExpr(val any) Expression {
+ if expr, ok := val.(Expression); ok {
+ return expr
+ }
+
+ arrayVal := reflect.ValueOf(val)
+ if arrayVal.Kind() != reflect.Slice {
+ return &baseExpression{err: fmt.Errorf("firestore: value must be a slice or Expr, but got %T", val)}
+ }
+
+ // Convert the slice of any to []Expression
+ var exprs []Expression
+ for i := 0; i < arrayVal.Len(); i++ {
+ exprs = append(exprs, toExprOrConstant(arrayVal.Index(i).Interface()))
+ }
+ return newBaseFunction("array", exprs)
+}
+
+// asInt64Expr converts a value to an Expr that evaluates to an int64, or returns an error Expr if conversion is not possible.
+func asInt64Expr(val any) Expression {
+ switch v := val.(type) {
+ case Expression:
+ return v
+ case int, int8, int16, int32, int64, uint8, uint16, uint32:
+ return ConstantOf(v)
+ default:
+ return &baseExpression{err: fmt.Errorf("firestore: value must be an int, int8, int16, int32, int64, uint8, uint16, uint32, or Expr, but got %T", val)}
+ }
+}
+
+// asStringExpr converts a value to an Expr that evaluates to a string, or returns an error Expr if conversion is not possible.
+func asStringExpr(val any) Expression {
+ switch v := val.(type) {
+ case Expression:
+ return v
+ case string:
+ return ConstantOf(v)
+ default:
+ return &baseExpression{err: fmt.Errorf("firestore: value must be a string or Expr, but got %T", val)}
+ }
+}
+
+// asVectorExpr converts a value to an Expr that evaluates to a vector type (Vector32, Vector64, []float32, []float64), or returns an error Expr if conversion is not possible.
+func asVectorExpr(val any) Expression {
+ switch v := val.(type) {
+ case Expression:
+ return v
+ case Vector32, Vector64, []float32, []float64:
+ return ConstantOf(v)
+ default:
+ return &baseExpression{err: fmt.Errorf("firestore: value must be a []float32, []float64, Vector32, Vector64 or Expr, but got %T", val)}
+ }
+}
+
+// toExprOrConstant converts a plain Go value or an existing Expr into an Expr.
+// Plain values are wrapped in a Constant.
+func toExprOrConstant(val any) Expression {
+ if expr, ok := val.(Expression); ok {
+ return expr
+ }
+ return ConstantOf(val)
+}
+
+// asFieldExpr converts a plain Go string or FieldPath into a field expression.
+// If the value is already an Expr, it's returned directly.
+func asFieldExpr(val any) Expression {
+ switch v := val.(type) {
+ case Expression:
+ return v
+ case FieldPath:
+ return FieldOf(v)
+ case string:
+ return FieldOf(v)
+ default:
+ return &baseExpression{err: fmt.Errorf("firestore: value must be a string, FieldPath, or Expr, but got %T", val)}
+ }
+}
+
+// leftRightToBaseFunction is a helper for creating binary functions like Add or Equal.
+// It ensures the left operand is a field-like expression and the right is a constant-like expression.
+func leftRightToBaseFunction(name string, left, right any) *baseFunction {
+ return newBaseFunction(name, []Expression{asFieldExpr(left), toExprOrConstant(right)})
+}
+
+// projectionsToMapValue converts a slice of Selectable items into a single
+// protobuf MapValue.
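+// For example, the selectables built by Select("name", Add(FieldOf("score"), 10).As("new_score"))
+// produce the map {"name": <field name>, "new_score": <add(score, 10)>}.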
+func projectionsToMapValue(selectables []Selectable) (*pb.Value, error) { + if selectables == nil { + return &pb.Value{ValueType: &pb.Value_MapValue{}}, nil + } + fieldsProto := make(map[string]*pb.Value, len(selectables)) + for _, s := range selectables { + alias, expr := s.getSelectionDetails() + if _, exists := fieldsProto[alias]; exists { + return nil, fmt.Errorf("firestore: duplicate alias or field name %q in selectables", alias) + } + + protoVal, err := expr.toProto() + if err != nil { + return nil, fmt.Errorf("firestore: error processing expression for alias %q: %w", alias, err) + } + fieldsProto[alias] = protoVal + } + return &pb.Value{ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{Fields: fieldsProto}}}, nil +} + +// aliasedAggregatesToMapValue converts a slice of AliasedAggregate items into a single +// protobuf MapValue. +func aliasedAggregatesToMapValue(aggregates []*AliasedAggregate) (*pb.Value, error) { + if aggregates == nil { + return &pb.Value{ValueType: &pb.Value_MapValue{}}, nil + } + fieldsProto := make(map[string]*pb.Value, len(aggregates)) + for _, agg := range aggregates { + if _, exists := fieldsProto[agg.alias]; exists { + return nil, fmt.Errorf("firestore: duplicate alias %q in aggregations", agg.alias) + } + + base := agg.getBaseAggregateFunction() + if base.err != nil { + return nil, fmt.Errorf("firestore: error in aggregate expression for alias %q: %w", agg.alias, base.err) + } + protoVal, err := base.toProto() + if err != nil { + return nil, fmt.Errorf("firestore: error converting aggregate for alias %q to proto: %w", agg.alias, err) + } + fieldsProto[agg.alias] = protoVal + } + return &pb.Value{ValueType: &pb.Value_MapValue{MapValue: &pb.MapValue{Fields: fieldsProto}}}, nil +} + +// fieldsOrSelectablesToSelectables converts a user-provided list of mixed types +// (string, FieldPath, Selectable) into a uniform []Selectable slice. +func fieldsOrSelectablesToSelectables(fieldsOrSelectables ...any) ([]Selectable, error) { + selectables := make([]Selectable, 0, len(fieldsOrSelectables)) + for _, f := range fieldsOrSelectables { + var s Selectable + switch v := f.(type) { + case string: + if v == "" { + return nil, errors.New("firestore: path cannot be empty") + } + s = FieldOf(v).(*field) + case FieldPath: + s = FieldOf(v).(*field) + case Selectable: + s = v + default: + return nil, fmt.Errorf("firestore: value must be a string, FieldPath, or Selectable, but got %T", v) + } + selectables = append(selectables, s) + } + return selectables, nil +} + +// exprToProtoValue converts an Expr to a protobuf Value. +// If the expression is nil, it returns a Null Value. +func exprToProtoValue(expr Expression) (*pb.Value, error) { + if expr == nil { + return ConstantOfNull().getBaseExpr().pbVal, nil + } + return expr.toProto() +} diff --git a/firestore/query.go b/firestore/query.go index bb43ec93a888..729b47eb7c84 100644 --- a/firestore/query.go +++ b/firestore/query.go @@ -34,6 +34,7 @@ import ( var ( errMetricsBeforeEnd = errors.New("firestore: ExplainMetrics are available only after the iterator reaches the end") + errStatsBeforeEnd = errors.New("firestore: ExplainStats are available only after the iterator reaches the end") errInvalidVector = errors.New("firestore: queryVector must be Vector32 or Vector64") errMalformedVectorQuery = errors.New("firestore: Malformed VectorQuery. Use FindNearest or FindNearestPath to create VectorQuery") ) @@ -1767,3 +1768,378 @@ type AggregationResponse struct { // Query explain metrics. 
This is only present when ExplainOptions is provided.
 ExplainMetrics *ExplainMetrics
 }
+
+func (q *Query) toPipeline() *Pipeline {
+ var p *Pipeline
+ if q.allDescendants {
+ p = q.c.Pipeline().CollectionGroup(q.collectionID)
+ } else {
+ p = q.c.Pipeline().Collection(q.collectionID)
+ }
+
+ if q.err != nil {
+ p.err = q.err
+ return p
+ }
+
+ var allFilters []BooleanExpression
+
+ // Original filters
+ for _, f := range q.filters {
+ var filterExpr BooleanExpression
+ var err error
+ if fieldFilter := f.GetFieldFilter(); fieldFilter != nil {
+ filterExpr, err = newQueryFilter(q, fieldFilter)
+ if err != nil {
+ p.err = err
+ return p
+ }
+ } else if unaryFilter := f.GetUnaryFilter(); unaryFilter != nil {
+ filterExpr, err = newQueryUnaryFilter(unaryFilter)
+ if err != nil {
+ p.err = err
+ return p
+ }
+ } else {
+ // Guard against silently appending a nil filter for filter kinds this
+ // conversion does not handle (for example, composite filters).
+ p.err = errors.New("firestore: unsupported filter type in Pipeline conversion")
+ return p
+ }
+ allFilters = append(allFilters, filterExpr)
+ }
+
+ // Start at
+ if q.startCursorSpecified() {
+ var startFilter BooleanExpression
+ var err error
+ if q.startDoc != nil {
+ startFilter, err = newCursorFilter(q.orders, q.startDoc, q.startBefore, true)
+ } else {
+ startFilter, err = newCursorFilterWithValues(q.orders, q.startVals, q.startBefore, true)
+ }
+ if err != nil {
+ p.err = err
+ return p
+ }
+ allFilters = append(allFilters, startFilter)
+ }
+
+ // End at
+ if q.endCursorSpecified() {
+ var endFilter BooleanExpression
+ var err error
+ if q.endDoc != nil {
+ endFilter, err = newCursorFilter(q.orders, q.endDoc, q.endBefore, false)
+ } else {
+ endFilter, err = newCursorFilterWithValues(q.orders, q.endVals, q.endBefore, false)
+ }
+ if err != nil {
+ p.err = err
+ return p
+ }
+ allFilters = append(allFilters, endFilter)
+ }
+
+ // Order by
+ if len(q.orders) > 0 {
+ var orders []Ordering
+ for _, o := range q.orders {
+ var fp FieldPath
+ if o.fieldReference != nil {
+ var err error
+ fp, err = fieldPathFromFieldRef(o.fieldReference)
+ if err != nil {
+ p.err = err
+ return p
+ }
+ } else {
+ fp = o.fieldPath
+ }
+ field := FieldOf(fp)
+ var direction OrderingDirection
+ if o.dir == Asc {
+ direction = OrderingAsc
+ } else {
+ direction = OrderingDesc
+ }
+ orders = append(orders, Ordering{Expr: field, Direction: direction})
+ }
+ p = p.Sort(orders...)
+ }
+ // Combine all filters
+ if len(allFilters) == 1 {
+ p = p.Where(allFilters[0])
+ } else if len(allFilters) > 1 {
+ p = p.Where(And(allFilters[0], allFilters[1:]...))
+ }
+
+ // Offset
+ if q.offset > 0 {
+ p = p.Offset(int(q.offset))
+ }
+
+ // Limit
+ if q.limit != nil {
+ p = p.Limit(int(q.limit.Value))
+ }
+
+ // Select
+ if len(q.selection) > 0 {
+ var fields []interface{}
+ for _, s := range q.selection {
+ fp, err := fieldPathFromFieldRef(s)
+ if err != nil {
+ p.err = err
+ return p
+ }
+ fields = append(fields, fp)
+ }
+ p = p.Select(fields...)
+ } + + // FindNearest + if q.findNearest != nil { + var measure PipelineDistanceMeasure + switch q.findNearest.DistanceMeasure { + case pb.StructuredQuery_FindNearest_EUCLIDEAN: + measure = PipelineDistanceMeasureEuclidean + case pb.StructuredQuery_FindNearest_COSINE: + measure = PipelineDistanceMeasureCosine + case pb.StructuredQuery_FindNearest_DOT_PRODUCT: + measure = PipelineDistanceMeasureDotProduct + } + + vectorField, err := fieldPathFromFieldRef(q.findNearest.VectorField) + if err != nil { + p.err = err + return p + } + + queryVector, err := createFromProtoValue(q.findNearest.QueryVector, q.c) + if err != nil { + p.err = err + return p + } + var limit *int + if q.findNearest.Limit != nil { + val := int(q.findNearest.Limit.Value) + limit = &val + } + + var distanceField *string + if q.findNearest.DistanceResultField != "" { + distanceField = &q.findNearest.DistanceResultField + } + + p = p.FindNearest(vectorField, queryVector, measure, &PipelineFindNearestOptions{ + Limit: limit, + DistanceField: distanceField, + }) + } + + return p +} + +func fieldPathFromFieldRef(ref *pb.StructuredQuery_FieldReference) (FieldPath, error) { + return parseDotSeparatedString(ref.FieldPath) +} + +func newQueryFilter(q *Query, f *pb.StructuredQuery_FieldFilter) (BooleanExpression, error) { + fp, err := fieldPathFromFieldRef(f.GetField()) + if err != nil { + return nil, err + } + v, err := createFromProtoValue(f.GetValue(), q.c) + if err != nil { + return nil, err + } + + switch f.Op { + case pb.StructuredQuery_FieldFilter_EQUAL: + return Equal(fp, v), nil + case pb.StructuredQuery_FieldFilter_NOT_EQUAL: + return NotEqual(fp, v), nil + case pb.StructuredQuery_FieldFilter_LESS_THAN: + return LessThan(fp, v), nil + case pb.StructuredQuery_FieldFilter_LESS_THAN_OR_EQUAL: + return LessThanOrEqual(fp, v), nil + case pb.StructuredQuery_FieldFilter_GREATER_THAN: + return GreaterThan(fp, v), nil + case pb.StructuredQuery_FieldFilter_GREATER_THAN_OR_EQUAL: + return GreaterThanOrEqual(fp, v), nil + case pb.StructuredQuery_FieldFilter_IN: + return EqualAny(fp, v), nil + case pb.StructuredQuery_FieldFilter_NOT_IN: + return NotEqualAny(fp, v), nil + case pb.StructuredQuery_FieldFilter_ARRAY_CONTAINS: + return ArrayContains(fp, v), nil + case pb.StructuredQuery_FieldFilter_ARRAY_CONTAINS_ANY: + return ArrayContainsAny(fp, v), nil + default: + return nil, fmt.Errorf("firestore: unsupported query filter operator: %v", f.Op) + } +} + +func newQueryUnaryFilter(f *pb.StructuredQuery_UnaryFilter) (BooleanExpression, error) { + fp, err := fieldPathFromFieldRef(f.GetField()) + if err != nil { + return nil, err + } + switch f.Op { + case pb.StructuredQuery_UnaryFilter_IS_NULL: + return Equal(fp, nil), nil + case pb.StructuredQuery_UnaryFilter_IS_NOT_NULL: + return NotEqual(fp, nil), nil + case pb.StructuredQuery_UnaryFilter_IS_NAN: + return Equal(fp, math.NaN()), nil + case pb.StructuredQuery_UnaryFilter_IS_NOT_NAN: + return NotEqual(fp, math.NaN()), nil + default: + return nil, fmt.Errorf("firestore: unsupported unary filter operator: %v", f.Op) + } +} + +// newCursorFilter creates a pipeline filter expression from a document snapshot cursor. 
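+// A cursor over orderings (a ASC, b ASC) with values (1, 2) expands to an OR of
+// ANDs; for StartAt this is: a > 1 OR (a == 1 AND b >= 2).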
+func newCursorFilter(orders []order, doc *DocumentSnapshot, before, isStart bool) (BooleanExpression, error) {
+ values := make([]interface{}, len(orders))
+ for i, o := range orders {
+ var err error
+ if o.isDocumentID() {
+ values[i] = doc.Ref.ID
+ } else {
+ values[i], err = doc.DataAt(o.fieldPath.toServiceFieldPath())
+ if err != nil {
+ return nil, err
+ }
+ }
+ }
+ return newCursorFilterWithValues(orders, values, before, isStart)
+}
+
+// newCursorFilterWithValues creates a pipeline filter expression from a list of values.
+func newCursorFilterWithValues(orders []order, values []interface{}, before, isStart bool) (BooleanExpression, error) {
+ if len(orders) != len(values) {
+ return nil, errors.New("firestore: number of cursor values does not match number of OrderBy fields")
+ }
+ if len(orders) == 0 {
+ return nil, errors.New("firestore: cursor specified without OrderBy fields")
+ }
+
+ var orTerms []BooleanExpression
+ for i := 1; i <= len(orders); i++ {
+ prefixOrders := orders[:i]
+ prefixValues := values[:i]
+ var andTerms []BooleanExpression
+ for j, o := range prefixOrders {
+ fp := o.fieldPath
+ val := prefixValues[j]
+
+ var op string
+ if j < len(prefixOrders)-1 {
+ op = "=="
+ } else if i < len(orders) {
+ // Non-final prefixes are always strict in the cursor's direction;
+ // only the full-length prefix decides inclusivity from before/after.
+ if isStart {
+ if o.dir == Asc {
+ op = ">"
+ } else {
+ op = "<"
+ }
+ } else {
+ if o.dir == Asc {
+ op = "<"
+ } else {
+ op = ">"
+ }
+ }
+ } else {
+ if isStart {
+ if before { // StartAt
+ if o.dir == Asc {
+ op = ">="
+ } else {
+ op = "<="
+ }
+ } else { // StartAfter
+ if o.dir == Asc {
+ op = ">"
+ } else {
+ op = "<"
+ }
+ }
+ } else { // End
+ if before { // EndBefore
+ if o.dir == Asc {
+ op = "<"
+ } else {
+ op = ">"
+ }
+ } else { // EndAt
+ if o.dir == Asc {
+ op = "<="
+ } else {
+ op = ">="
+ }
+ }
+ }
+ }
+
+ switch op {
+ case "==":
+ andTerms = append(andTerms, Equal(fp, val))
+ case ">":
+ andTerms = append(andTerms, GreaterThan(fp, val))
+ case ">=":
+ andTerms = append(andTerms, GreaterThanOrEqual(fp, val))
+ case "<":
+ andTerms = append(andTerms, LessThan(fp, val))
+ case "<=":
+ andTerms = append(andTerms, LessThanOrEqual(fp, val))
+ }
+ }
+ if len(andTerms) == 1 {
+ orTerms = append(orTerms, andTerms[0])
+ } else if len(andTerms) > 1 {
+ orTerms = append(orTerms, And(andTerms[0], andTerms[1:]...))
+ }
+ }
+
+ if len(orTerms) == 1 {
+ return orTerms[0], nil
+ }
+ return Or(orTerms[0], orTerms[1:]...), nil
+}
+
+// Pipeline creates a new [Pipeline] from the query.
+// All of the operations of the query will be converted to pipeline stages.
+// For example, `client.Collection("C").Where("f", "==", 1).Limit(10).OrderBy("f", Asc).Pipeline()` is equivalent to
+// `client.Pipeline().Collection("C").Where(Equal("f", 1)).Limit(10).Sort(Ascending("f"))`.
+func (q Query) Pipeline() *Pipeline {
+ return q.toPipeline()
+}
+
+// Pipeline creates a new [Pipeline] from the aggregation query.
+// All of the operations of the underlying query will be converted to pipeline stages,
+// and an aggregate stage will be added for the aggregations.
+func (aq *AggregationQuery) Pipeline() *Pipeline {
+ p := aq.query.toPipeline()
+ if p.err != nil {
+ return p
+ }
+
+ if len(aq.aggregateQueries) == 0 {
+ return p
+ }
+
+ var aggregations []*AliasedAggregate
+ for _, aggQuery := range aq.aggregateQueries {
+ alias := aggQuery.GetAlias()
+
+ var agg *AliasedAggregate
+ if _, ok := aggQuery.Operator.(*pb.StructuredAggregationQuery_Aggregation_Count_); ok {
+ // AggregationQuery's Count is a count of all documents. We can achieve
+ // this in a pipeline by counting the document ID, which is always present.
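+ // (DocumentID is the "__name__" pseudo-field, so counting it matches COUNT(*).)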
+ agg = Count(DocumentID).As(alias)
+ } else if sum := aggQuery.GetSum(); sum != nil {
+ fp, err := fieldPathFromFieldRef(sum.GetField())
+ if err != nil {
+ p.err = err
+ return p
+ }
+ agg = Sum(fp).As(alias)
+ } else if avg := aggQuery.GetAvg(); avg != nil {
+ fp, err := fieldPathFromFieldRef(avg.GetField())
+ if err != nil {
+ p.err = err
+ return p
+ }
+ agg = Average(fp).As(alias)
+ } else {
+ // This case should not be reachable with the current AggregationQuery API.
+ p.err = fmt.Errorf("firestore: unsupported aggregation operator in Pipeline conversion")
+ return p
+ }
+ aggregations = append(aggregations, agg)
+ }
+
+ p = p.Aggregate(aggregations...)
+ return p
+}
diff --git a/firestore/query_test.go b/firestore/query_test.go
index 54c0337d8734..f8d147335f62 100644
--- a/firestore/query_test.go
+++ b/firestore/query_test.go
@@ -1751,6 +1751,141 @@ func TestQueryRunOptionsAndGetAllWithOptions(t *testing.T) {
 t.Fatal(err)
 }
 }
+
+func TestQuery_Pipeline(t *testing.T) {
+ t.Parallel()
+ client, _, cleanup := newMock(t)
+ defer cleanup()
+
+ coll := client.Collection("C")
+
+ testCases := []struct {
+ name string
+ query Query
+ expPipe *Pipeline
+ }{
+ {
+ name: "simple query",
+ query: coll.Where("f", "==", 1).Limit(10),
+ expPipe: client.Pipeline().Collection("C").Where(Equal("f", 1)).Limit(10),
+ },
+ {
+ name: "query with all clauses",
+ query: coll.Where("f", ">", 1).OrderBy("f", Asc).Select("f").Offset(1),
+ expPipe: client.Pipeline().Collection("C").Sort(Ordering{Expr: FieldOf("f"), Direction: OrderingAsc}).Where(GreaterThan("f", 1)).Offset(1).Select("f"),
+ },
+ {
+ name: "query with collection group",
+ query: client.CollectionGroup("C").Where("f", "==", 1).Limit(10),
+ expPipe: client.Pipeline().CollectionGroup("C").Where(Equal("f", 1)).Limit(10),
+ },
+ {
+ name: "query with cursor",
+ query: coll.OrderBy("f", Asc).StartAt(1),
+ expPipe: client.Pipeline().Collection("C").Sort(Ordering{Expr: FieldOf("f"), Direction: OrderingAsc}).Where(GreaterThanOrEqual(FieldPath{"f"}, 1)),
+ },
+ {
+ name: "query with findNearest",
+ query: coll.FindNearest("f", []float32{1, 2, 3}, 5, DistanceMeasureEuclidean, &FindNearestOptions{DistanceResultField: "dist"}).q,
+ expPipe: client.Pipeline().Collection("C").FindNearest("f", []float32{1, 2, 3}, PipelineDistanceMeasureEuclidean, &PipelineFindNearestOptions{Limit: intptr(5), DistanceField: stringptr("dist")}),
+ },
+ }
+
+ for _, tc := range testCases {
+ tc := tc
+ t.Run(tc.name, func(t *testing.T) {
+ t.Parallel()
+ p := tc.query.Pipeline()
+ if p.err != nil {
+ t.Fatalf("Pipeline() got error: %v", p.err)
+ }
+ gotProto, err := p.toProto()
+ if err != nil {
+ t.Fatalf("p.toProto() got error: %v", err)
+ }
+ expProto, err := tc.expPipe.toProto()
+ if err != nil {
+ t.Fatalf("expPipe.toProto() got error: %v", err)
+ }
+
+ if diff := cmp.Diff(expProto, gotProto, protocmp.Transform()); diff != "" {
+ t.Errorf("mismatch (-want, +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func intptr(i int) *int {
+ return &i
+}
+
+func stringptr(s string) *string {
+ return &s
+}
+
+func TestAggregationQuery_Pipeline(t *testing.T) {
+ t.Parallel()
+ client, _, cleanup := newMock(t)
+ defer cleanup()
+
+ coll := client.Collection("C")
+ queryWithWhere := coll.Where("f", "==", 1)
+
+ testCases := []struct {
+ name string
+ query *AggregationQuery
+ expPipe *Pipeline
+ }{
+ {
+ name: "simple aggregation query",
+ query: coll.NewAggregationQuery().WithCount("total"),
+ expPipe: client.Pipeline().Collection("C").Aggregate(Count(DocumentID).As("total")),
+ },
+ {
"aggregation query with where", + query: queryWithWhere.NewAggregationQuery().WithCount("total"), + expPipe: client.Pipeline().Collection("C").Where(Equal("f", 1)).Aggregate(Count(DocumentID).As("total")), + }, + { + name: "aggregation query with sum", + query: coll.NewAggregationQuery().WithSum("f", "sum_f"), + expPipe: client.Pipeline().Collection("C").Aggregate(Sum("f").As("sum_f")), + }, + { + name: "aggregation query with avg", + query: coll.NewAggregationQuery().WithAvg("f", "avg_f"), + expPipe: client.Pipeline().Collection("C").Aggregate(Average("f").As("avg_f")), + }, + { + name: "aggregation query with multiple aggregations", + query: coll.NewAggregationQuery().WithCount("total").WithSum("f", "sum_f"), + expPipe: client.Pipeline().Collection("C").Aggregate(Count(DocumentID).As("total"), Sum("f").As("sum_f")), + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + p := tc.query.Pipeline() + if p.err != nil { + t.Fatalf("Pipeline() got error: %v", p.err) + } + gotProto, err := p.toProto() + if err != nil { + t.Fatalf("p.toProto() got error: %v", err) + } + expProto, err := tc.expPipe.toProto() + if err != nil { + t.Fatalf("expPipe.toProto() got error: %v", err) + } + + if diff := cmp.Diff(expProto, gotProto, protocmp.Transform()); diff != "" { + t.Errorf("mismatch (-want, +got)\n: %s", diff) + } + }) + } +} func TestFindNearest(t *testing.T) { ctx := context.Background() c, srv, cleanup := newMock(t) diff --git a/firestore/to_value.go b/firestore/to_value.go index 409018430f96..71fb9db9ca49 100644 --- a/firestore/to_value.go +++ b/firestore/to_value.go @@ -78,6 +78,9 @@ func toProtoValue(v reflect.Value) (pbv *pb.Value, sawTransform bool, err error) return &pb.Value{ValueType: &pb.Value_TimestampValue{TimestampValue: x}}, false, nil case Vector32: return vectorToProtoValue(x), false, nil + case Expression: + pbVal, err := exprToProtoValue(x) + return pbVal, false, err case Vector64: return vectorToProtoValue(x), false, nil case *latlng.LatLng: diff --git a/firestore/transaction.go b/firestore/transaction.go index 00a59c9fa327..7577adea7e4e 100644 --- a/firestore/transaction.go +++ b/firestore/transaction.go @@ -353,3 +353,16 @@ func (t *Transaction) WithReadOptions(opts ...ReadOption) *Transaction { } return t } + +// Execute runs the given pipeline in the context of the transaction. +func (t *Transaction) Execute(p *Pipeline) *PipelineSnapshot { + if len(t.writes) > 0 { + t.readAfterWrite = true + return &PipelineSnapshot{ + iter: &PipelineResultIterator{err: errReadAfterWrite}, + } + } + p2 := p.copy() + p2.tx = t + return p2.Execute(t.ctx) +} diff --git a/firestore/transaction_test.go b/firestore/transaction_test.go index f044e19c0c86..e98357fe9932 100644 --- a/firestore/transaction_test.go +++ b/firestore/transaction_test.go @@ -339,6 +339,27 @@ func TestTransactionErrors(t *testing.T) { } }) + t.Run("Read after write, with pipeline", func(t *testing.T) { + srv.reset() + srv.addRPC(beginReq, beginRes) + srv.addRPC(rollbackReq, &emptypb.Empty{}) + err := c.RunTransaction(ctx, func(_ context.Context, tx *Transaction) error { + if err := tx.Delete(c.Doc("C/a")); err != nil { + return err + } + p := c.Pipeline().Collection("C").Select("x") + it := tx.Execute(p).Results() + it.Stop() + return it.err + }) + if err != errReadAfterWrite { + t.Errorf("got <%v>, want <%v>", err, errReadAfterWrite) + } + if !srv.isEmpty() { + t.Errorf("Expected %+v requests but not received. 
srv.reqItems: %+v", len(srv.reqItems), srv.reqItems) + } + }) + t.Run("Read after write fails even if the user ignores the read's error", func(t *testing.T) { srv.reset() srv.addRPC(beginReq, beginRes) diff --git a/firestore/util_test.go b/firestore/util_test.go index 825b8db92440..004e0f3b4014 100644 --- a/firestore/util_test.go +++ b/firestore/util_test.go @@ -143,3 +143,12 @@ func mapval(m map[string]*pb.Value) *pb.Value { func refval(path string) *pb.Value { return &pb.Value{ValueType: &pb.Value_ReferenceValue{ReferenceValue: path}} } + +func docsToMaps(t *testing.T, docs []*PipelineResult) []map[string]interface{} { + var maps []map[string]interface{} + for _, doc := range docs { + data := doc.Data() + maps = append(maps, data) + } + return maps +}
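+
+// examplePipelineUsage is an untested sketch (not part of the change above) of
+// the end-to-end flow this change introduces: build a pipeline from the client,
+// execute it, and drain the results. The collection and field names are
+// illustrative only, and the loop assumes Results() follows the package's usual
+// Next/Stop iterator convention with google.golang.org/api/iterator.Done.
+func examplePipelineUsage(ctx context.Context, c *Client) error {
+ p := c.Pipeline().Collection("books").
+ Where(Equal(FieldOf("genre"), "Sci-Fi")).
+ Sort(Descending(FieldOf("rating"))).
+ Limit(10)
+ it := p.Execute(ctx).Results()
+ defer it.Stop()
+ for {
+ res, err := it.Next()
+ if err == iterator.Done {
+ return nil
+ }
+ if err != nil {
+ return err
+ }
+ _ = res.Data() // a PipelineResult exposes its fields as map[string]interface{}
+ }
+}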