From 9b67878767b3e258f8b93186a5ea8d3c220b678b Mon Sep 17 00:00:00 2001 From: Baha Aiman Date: Mon, 27 Oct 2025 05:25:39 +0000 Subject: [PATCH] feat(firestore): [PQ] Add all Aggregation and Timestamp functions --- firestore/integration_test.go | 508 ++++++++++++++++++++++++++++++++ firestore/pipeline_aggregate.go | 62 +++- firestore/pipeline_expr.go | 34 ++- firestore/pipeline_function.go | 128 ++++++-- firestore/pipeline_utils.go | 28 +- 5 files changed, 714 insertions(+), 46 deletions(-) diff --git a/firestore/integration_test.go b/firestore/integration_test.go index 974e06c732a5..5e4039258a8c 100644 --- a/firestore/integration_test.go +++ b/firestore/integration_test.go @@ -3565,3 +3565,511 @@ func TestIntegration_FindNearest(t *testing.T) { }) } } + +func TestIntegration_PipelineFunctions(t *testing.T) { + if testParams[firestoreEditionKey].(firestoreEdition) != editionEnterprise { + t.Skip("Skipping pipeline queries tests since the firestore edition of", testParams[databaseIDKey].(string), "database is not enterprise") + } + t.Run("timestampFuncs", timestampFuncs) + t.Run("arithmeticFuncs", arithmeticFuncs) + t.Run("aggregateFuncs", aggregateFuncs) + t.Run("comparisonFuncs", comparisonFuncs) +} + +func timestampFuncs(t *testing.T) { + t.Parallel() + client := integrationClient(t) + coll := client.Collection(collectionIDs.New()) + h := testHelper{t} + now := time.Now() + docRef1 := coll.NewDoc() + h.mustCreate(docRef1, map[string]interface{}{ + "timestamp": now, + "unixMicros": now.UnixNano() / 1000, + "unixMillis": now.UnixNano() / 1e6, + "unixSeconds": now.Unix(), + }) + defer deleteDocuments([]*DocumentRef{docRef1}) + + tests := []struct { + name string + pipeline *Pipeline + want map[string]interface{} + }{ + { + name: "TimestampAdd day", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(TimestampAdd("timestamp", "day", 1).As("timestamp_plus_day")), + want: map[string]interface{}{"timestamp_plus_day": now.AddDate(0, 0, 1).Truncate(time.Microsecond)}, + }, + { + name: "TimestampAdd hour", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(TimestampAdd("timestamp", "hour", 1).As("timestamp_plus_hour")), + want: map[string]interface{}{"timestamp_plus_hour": now.Add(time.Hour).Truncate(time.Microsecond)}, + }, + { + name: "TimestampAdd minute", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(TimestampAdd("timestamp", "minute", 1).As("timestamp_plus_minute")), + want: map[string]interface{}{"timestamp_plus_minute": now.Add(time.Minute).Truncate(time.Microsecond)}, + }, + { + name: "TimestampAdd second", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(TimestampAdd("timestamp", "second", 1).As("timestamp_plus_second")), + want: map[string]interface{}{"timestamp_plus_second": now.Add(time.Second).Truncate(time.Microsecond)}, + }, + { + name: "TimestampSubtract", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(TimestampSubtract("timestamp", "hour", 1).As("timestamp_minus_hour")), + want: map[string]interface{}{"timestamp_minus_hour": now.Add(-time.Hour).Truncate(time.Microsecond)}, + }, + { + name: "TimestampToUnixMicros", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(FieldOf("timestamp").TimestampToUnixMicros().As("timestamp_micros")), + want: map[string]interface{}{"timestamp_micros": now.UnixNano() / 1000}, + }, + { + name: "TimestampToUnixMillis", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(FieldOf("timestamp").TimestampToUnixMillis().As("timestamp_millis")), + want: map[string]interface{}{"timestamp_millis": now.UnixNano() / 1e6}, + }, + { + name: "TimestampToUnixSeconds", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(FieldOf("timestamp").TimestampToUnixSeconds().As("timestamp_seconds")), + want: map[string]interface{}{"timestamp_seconds": now.Unix()}, + }, + { + name: "UnixMicrosToTimestamp - constant", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(UnixMicrosToTimestamp(ConstantOf(now.UnixNano() / 1000)).As("timestamp_from_micros")), + want: map[string]interface{}{"timestamp_from_micros": now.Truncate(time.Microsecond)}, + }, + { + name: "UnixMicrosToTimestamp - fieldname", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(UnixMicrosToTimestamp("unixMicros").As("timestamp_from_micros")), + want: map[string]interface{}{"timestamp_from_micros": now.Truncate(time.Microsecond)}, + }, + { + name: "UnixMillisToTimestamp", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(UnixMillisToTimestamp(ConstantOf(now.UnixNano() / 1e6)).As("timestamp_from_millis")), + want: map[string]interface{}{"timestamp_from_millis": now.Truncate(time.Millisecond)}, + }, + { + name: "UnixSecondsToTimestamp", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(UnixSecondsToTimestamp("unixSeconds").As("timestamp_from_seconds")), + want: map[string]interface{}{"timestamp_from_seconds": now.Truncate(time.Second)}, + }, + { + name: "CurrentTimestamp", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(CurrentTimestamp().As("current_timestamp")), + want: map[string]interface{}{"current_timestamp": time.Now().Truncate(time.Microsecond)}, + }, + } + + ctx := context.Background() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + iter := test.pipeline.Execute(ctx) + defer iter.Stop() + + docs, err := iter.GetAll() + if err != nil { + t.Fatalf("GetAll: %v", err) + } + if len(docs) != 1 { + t.Fatalf("expected 1 doc, got %d", len(docs)) + } + got, err := docs[0].Data() + if err != nil { + t.Fatalf("Data: %v", err) + } + + margin := 0 * time.Microsecond + if test.name == "CurrentTimestamp" { + margin = 5 * time.Second + } + if diff := testutil.Diff(got, test.want, cmpopts.EquateApproxTime(margin)); diff != "" { + t.Errorf("got: %v, want: %v, diff: %s", got, test.want, diff) + } + }) + } +} + +func arithmeticFuncs(t *testing.T) { + t.Parallel() + h := testHelper{t} + client := integrationClient(t) + coll := client.Collection(collectionIDs.New()) + docRef1 := coll.NewDoc() + h.mustCreate(docRef1, map[string]interface{}{ + "a": int(1), + "b": int(2), + "c": -3, + "d": 4.5, + "e": -5.5, + }) + defer deleteDocuments([]*DocumentRef{docRef1}) + + tests := []struct { + name string + pipeline *Pipeline + want map[string]interface{} + }{ + { + name: "Add - left FieldOf, right FieldOf", + pipeline: client.Pipeline().Collection(coll.ID).Select(Add(FieldOf("a"), FieldOf("b")).As("add")), + want: map[string]interface{}{"add": int64(3)}, + }, + { + name: "Add - left FieldOf, right ConstantOf", + pipeline: client.Pipeline().Collection(coll.ID).Select(Add(FieldOf("a"), ConstantOf(2)).As("add")), + want: map[string]interface{}{"add": int64(3)}, + }, + { + name: "Add - left FieldOf, right constant", + pipeline: client.Pipeline().Collection(coll.ID).Select(Add(FieldOf("a"), 5).As("add")), + want: map[string]interface{}{"add": int64(6)}, + }, + { + name: "Add - left fieldname, right constant", + pipeline: client.Pipeline().Collection(coll.ID).Select(Add("a", 5).As("add")), + want: map[string]interface{}{"add": int64(6)}, + }, + { + name: "Add - left fieldpath, right constant", + pipeline: client.Pipeline().Collection(coll.ID).Select(Add(FieldPath([]string{"a"}), 5).As("add")), + want: map[string]interface{}{"add": int64(6)}, + }, + { + name: "Add - left fieldpath, right expression", + pipeline: client.Pipeline().Collection(coll.ID).Select(Add(FieldPath([]string{"a"}), Add(FieldOf("b"), FieldOf("d"))).As("add")), + want: map[string]interface{}{"add": float64(7.5)}, + }, + { + name: "Subtract", + pipeline: client.Pipeline().Collection(coll.ID).Select(Subtract("a", FieldOf("b")).As("subtract")), + want: map[string]interface{}{"subtract": int64(-1)}, + }, + { + name: "Multiply", + pipeline: client.Pipeline().Collection(coll.ID).Select(Multiply("a", 5).As("multiply")), + want: map[string]interface{}{"multiply": int64(5)}, + }, + { + name: "Divide", + pipeline: client.Pipeline().Collection(coll.ID).Select(Divide("a", FieldOf("d")).As("divide")), + want: map[string]interface{}{"divide": float64(1 / 4.5)}, + }, + { + name: "Mod", + pipeline: client.Pipeline().Collection(coll.ID).Select(Mod("a", FieldOf("b")).As("mod")), + want: map[string]interface{}{"mod": int64(1)}, + }, + { + name: "Pow", + pipeline: client.Pipeline().Collection(coll.ID).Select(Pow("a", FieldOf("b")).As("pow")), + want: map[string]interface{}{"pow": float64(1)}, + }, + { + name: "Abs - fieldname", + pipeline: client.Pipeline().Collection(coll.ID).Select(Abs("c").As("abs")), + want: map[string]interface{}{"abs": int64(3)}, + }, + { + name: "Abs - fieldPath", + pipeline: client.Pipeline().Collection(coll.ID).Select(Abs(FieldPath([]string{"c"})).As("abs")), + want: map[string]interface{}{"abs": int64(3)}, + }, + { + name: "Abs - Expr", + pipeline: client.Pipeline().Collection(coll.ID).Select(Abs(Add(FieldOf("b"), FieldOf("d"))).As("abs")), + want: map[string]interface{}{"abs": float64(6.5)}, + }, + { + name: "Ceil", + pipeline: client.Pipeline().Collection(coll.ID).Select(Ceil("d").As("ceil")), + want: map[string]interface{}{"ceil": float64(5)}, + }, + { + name: "Floor", + pipeline: client.Pipeline().Collection(coll.ID).Select(Floor("d").As("floor")), + want: map[string]interface{}{"floor": float64(4)}, + }, + { + name: "Round", + pipeline: client.Pipeline().Collection(coll.ID).Select(Round("d").As("round")), + want: map[string]interface{}{"round": float64(5)}, + }, + { + name: "Sqrt", + pipeline: client.Pipeline().Collection(coll.ID).Select(Sqrt("d").As("sqrt")), + want: map[string]interface{}{"sqrt": math.Sqrt(4.5)}, + }, + { + name: "Log", + pipeline: client.Pipeline().Collection(coll.ID).Select(Log("d", 2).As("log")), + want: map[string]interface{}{"log": math.Log2(4.5)}, + }, + { + name: "Log10", + pipeline: client.Pipeline().Collection(coll.ID).Select(Log10("d").As("log10")), + want: map[string]interface{}{"log10": math.Log10(4.5)}, + }, + { + name: "Ln", + pipeline: client.Pipeline().Collection(coll.ID).Select(Ln("d").As("ln")), + want: map[string]interface{}{"ln": math.Log(4.5)}, + }, + { + name: "Exp", + pipeline: client.Pipeline().Collection(coll.ID).Select(Exp("d").As("exp")), + want: map[string]interface{}{"exp": math.Exp(4.5)}, + }, + } + + ctx := context.Background() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + iter := test.pipeline.Execute(ctx) + defer iter.Stop() + + docs, err := iter.GetAll() + if err != nil { + t.Fatalf("GetAll: %v", err) + } + if len(docs) != 1 { + t.Fatalf("expected 1 doc, got %d", len(docs)) + } + got, err := docs[0].Data() + if err != nil { + t.Fatalf("Data: %v", err) + } + if diff := testutil.Diff(got, test.want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", got, test.want, diff) + } + }) + } +} + +func aggregateFuncs(t *testing.T) { + t.Parallel() + h := testHelper{t} + client := integrationClient(t) + coll := client.Collection(collectionIDs.New()) + docRef1 := coll.NewDoc() + h.mustCreate(docRef1, map[string]interface{}{ + "a": 1, + }) + docRef2 := coll.NewDoc() + h.mustCreate(docRef2, map[string]interface{}{ + "a": 2, + }) + docRef3 := coll.NewDoc() + h.mustCreate(docRef3, map[string]interface{}{ + "b": 2, + }) + defer deleteDocuments([]*DocumentRef{docRef1, docRef2, docRef3}) + + tests := []struct { + name string + pipeline *Pipeline + want map[string]interface{} + }{ + { + name: "Sum - fieldname arg", + pipeline: client.Pipeline(). + Collection(coll.ID). + Aggregate(Sum("a").As("sum_a")), + want: map[string]interface{}{"sum_a": int64(3)}, + }, + { + name: "Sum - fieldpath arg", + pipeline: client.Pipeline(). + Collection(coll.ID). + Aggregate(Sum(FieldPath([]string{"a"})).As("sum_a")), + want: map[string]interface{}{"sum_a": int64(3)}, + }, + { + name: "Sum - FieldOf Expr", + pipeline: client.Pipeline(). + Collection(coll.ID). + Aggregate(Sum(FieldOf("a")).As("sum_a")), + want: map[string]interface{}{"sum_a": int64(3)}, + }, + { + name: "Sum - FieldOfPath Expr", + pipeline: client.Pipeline(). + Collection(coll.ID). + Aggregate(Sum(FieldOfPath(FieldPath([]string{"a"}))).As("sum_a")), + want: map[string]interface{}{"sum_a": int64(3)}, + }, + { + name: "Avg", + pipeline: client.Pipeline(). + Collection(coll.ID). + Aggregate(Average("a").As("avg_a")), + want: map[string]interface{}{"avg_a": float64(1.5)}, + }, + { + name: "Count", + pipeline: client.Pipeline(). + Collection(coll.ID). + Aggregate(Count("a").As("count_a")), + want: map[string]interface{}{"count_a": int64(2)}, + }, + { + name: "CountAll", + pipeline: client.Pipeline(). + Collection(coll.ID). + Aggregate(CountAll().As("count_all")), + want: map[string]interface{}{"count_all": int64(3)}, + }, + } + + ctx := context.Background() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + iter := test.pipeline.Execute(ctx) + defer iter.Stop() + + docs, err := iter.GetAll() + if err != nil { + t.Fatalf("GetAll: %v", err) + } + if len(docs) != 1 { + t.Fatalf("expected 1 doc, got %d", len(docs)) + } + got, err := docs[0].Data() + if err != nil { + t.Fatalf("Data: %v", err) + } + if diff := testutil.Diff(got, test.want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", got, test.want, diff) + } + }) + } +} + +func comparisonFuncs(t *testing.T) { + t.Parallel() + ctx := context.Background() + client := integrationClient(t) + now := time.Now() + coll := client.Collection(collectionIDs.New()) + doc1data := map[string]interface{}{ + "timestamp": now, + "a": 1, + "b": 2, + "c": -3, + "d": 4.5, + "e": -5.5, + } + _, err := coll.Doc("doc1").Create(ctx, doc1data) + if err != nil { + t.Fatalf("Create: %v", err) + } + doc2data := map[string]interface{}{ + "timestamp": now, + "a": 2, + "b": 2, + "c": -3, + "d": 4.5, + "e": -5.5, + } + _, err = coll.Doc("doc2").Create(ctx, doc2data) + if err != nil { + t.Fatalf("Create: %v", err) + } + defer deleteDocuments([]*DocumentRef{coll.Doc("doc1"), coll.Doc("doc2")}) + + doc1want := map[string]interface{}{"a": int64(1), "b": int64(2), "c": int64(-3), "d": float64(4.5), "e": float64(-5.5), "timestamp": now.Truncate(time.Microsecond)} + + tests := []struct { + name string + pipeline *Pipeline + want []map[string]interface{} + }{ + { + name: "Equal", + pipeline: client.Pipeline(). + Collection(coll.ID). + Where(Equal("a", 1)), + want: []map[string]interface{}{doc1want}, + }, + { + name: "NotEqual", + pipeline: client.Pipeline(). + Collection(coll.ID). + Where(NotEqual("a", 2)), + want: []map[string]interface{}{doc1want}, + }, + { + name: "LessThan", + pipeline: client.Pipeline(). + Collection(coll.ID). + Where(LessThan("a", 2)), + want: []map[string]interface{}{doc1want}, + }, + { + name: "Equivalent", + pipeline: client.Pipeline(). + Collection(coll.ID). + Where(Equivalent("a", 1)), + want: []map[string]interface{}{doc1want}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + iter := test.pipeline.Execute(ctx) + defer iter.Stop() + + docs, err := iter.GetAll() + if err != nil { + t.Fatalf("GetAll: %v", err) + } + if len(docs) != len(test.want) { + t.Fatalf("expected %d doc(s), got %d", len(test.want), len(docs)) + } + + var gots []map[string]interface{} + for _, doc := range docs { + got, err := doc.Data() + if err != nil { + t.Fatalf("Data: %v", err) + } + if ts, ok := got["timestamp"].(time.Time); ok { + got["timestamp"] = ts.Truncate(time.Microsecond) + } + gots = append(gots, got) + } + + if diff := testutil.Diff(gots, test.want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", gots, test.want, diff) + } + }) + } +} diff --git a/firestore/pipeline_aggregate.go b/firestore/pipeline_aggregate.go index 157ac3bf22db..0143b39c2a55 100644 --- a/firestore/pipeline_aggregate.go +++ b/firestore/pipeline_aggregate.go @@ -108,18 +108,18 @@ func Sum(fieldOrExpr any) AggregateFunction { return newBaseAggregateFunction("sum", fieldOrExpr) } -// Avg creates an aggregation that calculates the average (mean) of values from an expression or a field's values +// Average creates an aggregation that calculates the average (mean) of values from an expression or a field's values // across multiple stage inputs. // fieldOrExpr can be a field path string, [FieldPath] or [Expr] // Example: // // // Calculate the average age of users -// Avg(FieldOf("info.age")).As("averageAge") // FieldOf returns Expr -// Avg(FieldOfPath("info.age")).As("averageAge") // FieldOfPath returns Expr -// Avg("info.age").As("averageAge") // String implicitly becomes FieldOf(...).As(...) -// Avg(FieldPath([]string{"info", "age"})).As("averageAge") -func Avg(fieldOrExpr any) AggregateFunction { - return newBaseAggregateFunction("avg", fieldOrExpr) +// Average(FieldOf("info.age")).As("averageAge") // FieldOf returns Expr +// Average(FieldOfPath("info.age")).As("averageAge") // FieldOfPath returns Expr +// Average("info.age").As("averageAge") // String implicitly becomes FieldOf(...).As(...) +// Average(FieldPath([]string{"info", "age"})).As("averageAge") +func Average(fieldOrExpr any) AggregateFunction { + return newBaseAggregateFunction("average", fieldOrExpr) } // Count creates an aggregation that counts the number of stage inputs with valid evaluations of the @@ -144,3 +144,51 @@ func Count(fieldOrExpr any) AggregateFunction { func CountAll() AggregateFunction { return newBaseAggregateFunction("count", nil) } + +// CountDistinct creates an aggregation that counts the number of distinct values of the +// provided field or expression. +// fieldOrExpr can be a field path string, [FieldPath] or [Expr] +// Example: +// +// // CountDistinct the number of distinct items where the price is greater than 10 +// CountDistinct(FieldOf("price").Gt(10)).As("expensiveItemCount") // FieldOf("price").Gt(10) is a BooleanExpr +// // CountDistinct the total number of distinct products +// CountDistinct("productId").As("totalProducts") // String implicitly becomes FieldOf(...).As(...) +func CountDistinct(fieldOrExpr any) AggregateFunction { + return newBaseAggregateFunction("count_distinct", fieldOrExpr) +} + +// CountIf creates an aggregation that counts the number of values of the +// provided field or expression evaluates to TRUE. +// fieldOrExpr can be a field path string, [FieldPath] or [Expr] +// Example: +// +// CountIf(FieldOf("published")).As("publishedCount") +// CountIf("published").As("publishedCount") +func CountIf(fieldOrExpr any) AggregateFunction { + return newBaseAggregateFunction("count_if", fieldOrExpr) +} + +// Maximum creates an aggregation that calculates the maximum of values from an expression or a field's values +// across multiple stage inputs. +// +// Example: +// +// // Find the highest order amount +// Maximum(FieldOf("orderAmount")).As("maxOrderAmount") // FieldOf returns Expr +// Maximum("orderAmount").As("maxOrderAmount") // String implicitly becomes FieldOf(...).As(...) +func Maximum(fieldOrExpr any) AggregateFunction { + return newBaseAggregateFunction("maximum", fieldOrExpr) +} + +// Minimum creates an aggregation that calculates the minimum of values from an expression or a field's values +// across multiple stage inputs. +// +// Example: +// +// // Find the lowest order amount +// Minimum(FieldOf("orderAmount")).As("minOrderAmount") // FieldOf returns Expr +// Minimum("orderAmount").As("minOrderAmount") // String implicitly becomes FieldOf(...).As(...) +func Minimum(fieldOrExpr any) AggregateFunction { + return newBaseAggregateFunction("minimum", fieldOrExpr) +} diff --git a/firestore/pipeline_expr.go b/firestore/pipeline_expr.go index 4f5c81ea37f0..af36273d822a 100644 --- a/firestore/pipeline_expr.go +++ b/firestore/pipeline_expr.go @@ -61,6 +61,16 @@ type Expr interface { Round() Expr Sqrt() Expr + // Timestamp operations + TimestampAdd(unit, amount any) Expr + TimestampSubtract(unit, amount any) Expr + TimestampToUnixMicros() Expr + TimestampToUnixMillis() Expr + TimestampToUnixSeconds() Expr + UnixMicrosToTimestamp() Expr + UnixMillisToTimestamp() Expr + UnixSecondsToTimestamp() Expr + // Comparison operations Equal(other any) BooleanExpr NotEqual(other any) BooleanExpr @@ -72,7 +82,7 @@ type Expr interface { // Aggregators Sum() AggregateFunction - Avg() AggregateFunction + Average() AggregateFunction Count() AggregateFunction // As assigns an alias to an expression. @@ -107,6 +117,18 @@ func (b *baseExpr) Pow(other any) Expr { return Pow(b, other) } func (b *baseExpr) Round() Expr { return Round(b) } func (b *baseExpr) Sqrt() Expr { return Sqrt(b) } +// Timestamp functions +func (b *baseExpr) TimestampAdd(unit, amount any) Expr { return TimestampAdd(b, unit, amount) } +func (b *baseExpr) TimestampSubtract(unit, amount any) Expr { + return TimestampSubtract(b, unit, amount) +} +func (b *baseExpr) TimestampToUnixMicros() Expr { return TimestampToUnixMicros(b) } +func (b *baseExpr) TimestampToUnixMillis() Expr { return TimestampToUnixMillis(b) } +func (b *baseExpr) TimestampToUnixSeconds() Expr { return TimestampToUnixSeconds(b) } +func (b *baseExpr) UnixMicrosToTimestamp() Expr { return UnixMicrosToTimestamp(b) } +func (b *baseExpr) UnixMillisToTimestamp() Expr { return UnixMillisToTimestamp(b) } +func (b *baseExpr) UnixSecondsToTimestamp() Expr { return UnixSecondsToTimestamp(b) } + // Comparison functions func (b *baseExpr) Equal(other any) BooleanExpr { return Equal(b, other) } func (b *baseExpr) NotEqual(other any) BooleanExpr { return NotEqual(b, other) } @@ -117,9 +139,13 @@ func (b *baseExpr) LessThanOrEqual(other any) BooleanExpr { return LessThanOr func (b *baseExpr) Equivalent(other any) BooleanExpr { return Equivalent(b, other) } // Aggregation operations -func (b *baseExpr) Sum() AggregateFunction { return Sum(b) } -func (b *baseExpr) Avg() AggregateFunction { return Avg(b) } -func (b *baseExpr) Count() AggregateFunction { return Count(b) } +func (b *baseExpr) Sum() AggregateFunction { return Sum(b) } +func (b *baseExpr) Average() AggregateFunction { return Average(b) } +func (b *baseExpr) Count() AggregateFunction { return Count(b) } +func (b *baseExpr) CountDistinct() AggregateFunction { return CountDistinct(b) } +func (b *baseExpr) CountIf() AggregateFunction { return CountIf(b) } +func (b *baseExpr) Maximum() AggregateFunction { return Maximum(b) } +func (b *baseExpr) Minimum() AggregateFunction { return Minimum(b) } func (b *baseExpr) As(alias string) Selectable { return newAliasedExpr(b, alias) diff --git a/firestore/pipeline_function.go b/firestore/pipeline_function.go index f4b91b51a037..ae9e725ce99e 100644 --- a/firestore/pipeline_function.go +++ b/firestore/pipeline_function.go @@ -40,7 +40,7 @@ func newBaseFunction(name string, params []Expr) *baseFunction { argsPbVals := make([]*pb.Value, 0, len(params)) for i, param := range params { - paramExpr := toExprOrField(param) + paramExpr := asFieldExpr(param) pbVal, err := paramExpr.toProto() if err != nil { return &baseFunction{baseExpr: &baseExpr{err: fmt.Errorf("error converting arg %d for function %q: %w", i, name, err)}} @@ -59,7 +59,7 @@ func newBaseFunction(name string, params []Expr) *baseFunction { // Add creates an expression that adds two expressions together, returning it as an Expr. // - left can be a field path string, [FieldPath] or [Expr]. -// - right can be a constant or an [Expr]. +// - right can be a numeric constant or a numeric [Expr]. // // Example: // @@ -118,47 +118,47 @@ func Divide(left, right any) Expr { } // Abs creates an expression that is the absolute value of the input field or expression. -// - numericExprOrField can be a field path string, [FieldPath] or an [Expr] that returns a number when evaluated. +// - numericExprOrFieldPath can be a field path string, [FieldPath] or an [Expr] that returns a number when evaluated. // // Example: // // // Absolute value of the 'age' field. // Abs("age") -func Abs(numericExprOrField any) Expr { - return newBaseFunction("abs", []Expr{toExprOrField(numericExprOrField)}) +func Abs(numericExprOrFieldPath any) Expr { + return newBaseFunction("abs", []Expr{asFieldExpr(numericExprOrFieldPath)}) } // Floor creates an expression that is the largest integer that isn't less than the input field or expression. -// - numericExprOrField can be a field path string, [FieldPath] or an [Expr] that returns a number when evaluated. +// - numericExprOrFieldPath can be a field path string, [FieldPath] or an [Expr] that returns a number when evaluated. // // Example: // // // Floor value of the 'age' field. // Floor("age") -func Floor(numericExprOrField any) Expr { - return newBaseFunction("floor", []Expr{toExprOrField(numericExprOrField)}) +func Floor(numericExprOrFieldPath any) Expr { + return newBaseFunction("floor", []Expr{asFieldExpr(numericExprOrFieldPath)}) } // Ceil creates an expression that is the smallest integer that isn't less than the input field or expression. -// - numericExprOrField can be a field path string, [FieldPath] or an [Expr] that returns a number when evaluated. +// - numericExprOrFieldPath can be a field path string, [FieldPath] or an [Expr] that returns a number when evaluated. // // Example: // // // Ceiling value of the 'age' field. // Ceil("age") -func Ceil(numericExprOrField any) Expr { - return newBaseFunction("ceil", []Expr{toExprOrField(numericExprOrField)}) +func Ceil(numericExprOrFieldPath any) Expr { + return newBaseFunction("ceil", []Expr{asFieldExpr(numericExprOrFieldPath)}) } // Exp creates an expression that is the Euler's number e raised to the power of the input field or expression. -// - numericExprOrField can be a field path string, [FieldPath] or an [Expr] that returns a number when evaluated. +// - numericExprOrFieldPath can be a field path string, [FieldPath] or an [Expr] that returns a number when evaluated. // // Example: // // // e to the power of the value of the 'age' field. // Exp("age") -func Exp(numericExprOrField any) Expr { - return newBaseFunction("exp", []Expr{toExprOrField(numericExprOrField)}) +func Exp(numericExprOrFieldPath any) Expr { + return newBaseFunction("exp", []Expr{asFieldExpr(numericExprOrFieldPath)}) } // Log creates an expression that is logarithm of the left expression to base as the right expression, returning it as an Expr. @@ -177,25 +177,25 @@ func Log(left, right any) Expr { } // Log10 creates an expression that is the base 10 logarithm of the input field or expression. -// - numericExprOrField can be a field path string, [FieldPath] or an [Expr] that returns a number when evaluated. +// - numericExprOrFieldPath can be a field path string, [FieldPath] or an [Expr] that returns a number when evaluated. // // Example: // // // Base 10 logarithmic value of the 'age' field. // Log10("age") -func Log10(numericExprOrField any) Expr { - return newBaseFunction("log10", []Expr{toExprOrField(numericExprOrField)}) +func Log10(numericExprOrFieldPath any) Expr { + return newBaseFunction("log10", []Expr{asFieldExpr(numericExprOrFieldPath)}) } // Ln creates an expression that is the natural logarithm (base e) of the input field or expression. -// - numericExprOrField can be a field path string, [FieldPath] or an [Expr] that returns a number when evaluated. +// - numericExprOrFieldPath can be a field path string, [FieldPath] or an [Expr] that returns a number when evaluated. // // Example: // // // Natural logarithmic value of the 'age' field. // Ln("age") -func Ln(numericExprOrField any) Expr { - return newBaseFunction("ln", []Expr{toExprOrField(numericExprOrField)}) +func Ln(numericExprOrFieldPath any) Expr { + return newBaseFunction("ln", []Expr{asFieldExpr(numericExprOrFieldPath)}) } // Mod creates an expression that computes the modulo of the left expression by the right expression, returning it as an Expr. @@ -228,30 +228,94 @@ func Pow(left, right any) Expr { return leftRightToBaseFunction("pow", left, right) } -// Rand creates an expression that return a pseudo-random number of type double in the range of [0, 1), -// inclusive of 0 and exclusive of 1. -func Rand() Expr { - return newBaseFunction("rand", []Expr{}) -} - // Round creates an expression that rounds the input field or expression to nearest integer. -// - numericExprOrField can be a field path string, [FieldPath] or an [Expr] that returns a number when evaluated. +// - numericExprOrFieldPath can be a field path string, [FieldPath] or an [Expr] that returns a number when evaluated. // // Example: // // // Round the value of the 'age' field. // Round("age") -func Round(numericExprOrField any) Expr { - return newBaseFunction("round", []Expr{toExprOrField(numericExprOrField)}) +func Round(numericExprOrFieldPath any) Expr { + return newBaseFunction("round", []Expr{asFieldExpr(numericExprOrFieldPath)}) } // Sqrt creates an expression that is the square root of the input field or expression. -// - numericExprOrField can be a field path string, [FieldPath] or an [Expr] that returns a number when evaluated. +// - numericExprOrFieldPath can be a field path string, [FieldPath] or an [Expr] that returns a number when evaluated. // // Example: // // // Square root of the value of the 'age' field. // Sqrt("age") -func Sqrt(numericExprOrField any) Expr { - return newBaseFunction("sqrt", []Expr{toExprOrField(numericExprOrField)}) +func Sqrt(numericExprOrFieldPath any) Expr { + return newBaseFunction("sqrt", []Expr{asFieldExpr(numericExprOrFieldPath)}) +} + +// TimestampAdd creates an expression that adds a specified amount of time to a timestamp. +// - timestamp can be a field path string, [FieldPath] or [Expr]. +// - unit can be a string or an [Expr]. Valid units include "microsecond", "millisecond", "second", "minute", "hour" and "day". +// - amount can be an int, int32, int64 or [Expr]. +// +// Example: +// +// // Add 5 hours to the value of the 'last_updated' field. +// TimestampAdd("last_updated", "hour", 5) +func TimestampAdd(timestamp, unit, amount any) Expr { + return newBaseFunction("timestamp_add", []Expr{asFieldExpr(timestamp), asStringExpr(unit), asInt64Expr(amount)}) +} + +// TimestampSubtract creates an expression that subtracts a specified amount of time from a timestamp. +// - timestamp can be a field path string, [FieldPath] or [Expr]. +// - unit can be a string or an [Expr]. Valid units include "microsecond", "millisecond", "second", "minute", "hour" and "day". +// - amount can be an int, int32, int64 or [Expr]. +// +// Example: +// +// // Subtract 10 days from the value of the 'last_updated' field. +// TimestampSubtract("last_updated", "day", 10) +func TimestampSubtract(timestamp, unit, amount any) Expr { + return newBaseFunction("timestamp_subtract", []Expr{asFieldExpr(timestamp), asStringExpr(unit), asInt64Expr(amount)}) +} + +// TimestampToUnixMicros creates an expression that converts a timestamp expression to the number of microseconds since +// the Unix epoch (1970-01-01 00:00:00 UTC). +// - timestamp can be a field path string, [FieldPath] or [Expr]. +func TimestampToUnixMicros(timestamp any) Expr { + return newBaseFunction("timestamp_to_unix_micros", []Expr{asFieldExpr(timestamp)}) +} + +// TimestampToUnixMillis creates an expression that converts a timestamp expression to the number of milliseconds since +// the Unix epoch (1970-01-01 00:00:00 UTC). +// - timestamp can be a field path string, [FieldPath] or [Expr]. +func TimestampToUnixMillis(timestamp any) Expr { + return newBaseFunction("timestamp_to_unix_millis", []Expr{asFieldExpr(timestamp)}) +} + +// TimestampToUnixSeconds creates an expression that converts a timestamp expression to the number of seconds since +// the Unix epoch (1970-01-01 00:00:00 UTC). +// - timestamp can be a field path string, [FieldPath] or [Expr]. +func TimestampToUnixSeconds(timestamp any) Expr { + return newBaseFunction("timestamp_to_unix_seconds", []Expr{asFieldExpr(timestamp)}) +} + +// UnixMicrosToTimestamp creates an expression that converts a Unix timestamp in microseconds to a Firestore timestamp. +// - micros can be a field path string, [FieldPath] or [Expr]. +func UnixMicrosToTimestamp(micros any) Expr { + return newBaseFunction("unix_micros_to_timestamp", []Expr{asFieldExpr(micros)}) +} + +// UnixMillisToTimestamp creates an expression that converts a Unix timestamp in milliseconds to a Firestore timestamp. +// - millis can be a field path string, [FieldPath] or [Expr]. +func UnixMillisToTimestamp(millis any) Expr { + return newBaseFunction("unix_millis_to_timestamp", []Expr{asFieldExpr(millis)}) +} + +// UnixSecondsToTimestamp creates an expression that converts a Unix timestamp in seconds to a Firestore timestamp. +// - seconds can be a field path string, [FieldPath] or [Expr]. +func UnixSecondsToTimestamp(seconds any) Expr { + return newBaseFunction("unix_seconds_to_timestamp", []Expr{asFieldExpr(seconds)}) +} + +// CurrentTimestamp creates an expression that returns the current timestamp. +func CurrentTimestamp() Expr { + return newBaseFunction("current_timestamp", []Expr{}) } diff --git a/firestore/pipeline_utils.go b/firestore/pipeline_utils.go index 7925aad2491a..9999d4829047 100644 --- a/firestore/pipeline_utils.go +++ b/firestore/pipeline_utils.go @@ -30,9 +30,9 @@ func toExprOrConstant(val any) Expr { return ConstantOf(val) } -// toExprOrField converts a plain Go string or FieldPath into a field expression. +// asFieldExpr converts a plain Go string or FieldPath into a field expression. // If the value is already an Expr, it's returned directly. -func toExprOrField(val any) Expr { +func asFieldExpr(val any) Expr { switch v := val.(type) { case Expr: return v @@ -45,10 +45,32 @@ func toExprOrField(val any) Expr { } } +func asInt64Expr(val any) Expr { + switch v := val.(type) { + case Expr: + return v + case int, int32, int64: + return ConstantOf(v) + default: + return &baseExpr{err: fmt.Errorf("firestore: value must be a int, int32, int64 or Expr, but got %T", val)} + } +} + +func asStringExpr(val any) Expr { + switch v := val.(type) { + case Expr: + return v + case string: + return ConstantOf(v) + default: + return &baseExpr{err: fmt.Errorf("firestore: value must be a string or Expr, but got %T", val)} + } +} + // leftRightToBaseFunction is a helper for creating binary functions like Add or Eq. // It ensures the left operand is a field-like expression and the right is a constant-like expression. func leftRightToBaseFunction(name string, left, right any) *baseFunction { - return newBaseFunction(name, []Expr{toExprOrField(left), toExprOrConstant(right)}) + return newBaseFunction(name, []Expr{asFieldExpr(left), toExprOrConstant(right)}) } // projectionsToMapValue converts a slice of Selectable items into a single