From 2c4a99cc9a49d173b2d56ad26e8d7c2aea773f08 Mon Sep 17 00:00:00 2001 From: Baha Aiman Date: Thu, 30 Oct 2025 21:44:55 +0000 Subject: [PATCH 1/4] test(firestore): [PQ] move tests and resolve build failures --- firestore/integration_test.go | 508 ------------------------ firestore/pipeline_filter_condition.go | 22 +- firestore/pipeline_function.go | 40 +- firestore/pipeline_integration_test.go | 514 ++++++++++++++++++++++++- firestore/pipeline_utils.go | 24 +- firestore/util_test.go | 5 +- 6 files changed, 527 insertions(+), 586 deletions(-) diff --git a/firestore/integration_test.go b/firestore/integration_test.go index 044292d80700..ab01ad6b2a62 100644 --- a/firestore/integration_test.go +++ b/firestore/integration_test.go @@ -4117,511 +4117,3 @@ func TestIntegration_PipelineStages(t *testing.T) { } }) } - -func TestIntegration_PipelineFunctions(t *testing.T) { - if testParams[firestoreEditionKey].(firestoreEdition) != editionEnterprise { - t.Skip("Skipping pipeline queries tests since the firestore edition of", testParams[databaseIDKey].(string), "database is not enterprise") - } - t.Run("timestampFuncs", timestampFuncs) - t.Run("arithmeticFuncs", arithmeticFuncs) - t.Run("aggregateFuncs", aggregateFuncs) - t.Run("comparisonFuncs", comparisonFuncs) -} - -func timestampFuncs(t *testing.T) { - t.Parallel() - client := integrationClient(t) - coll := client.Collection(collectionIDs.New()) - h := testHelper{t} - now := time.Now() - docRef1 := coll.NewDoc() - h.mustCreate(docRef1, map[string]interface{}{ - "timestamp": now, - "unixMicros": now.UnixNano() / 1000, - "unixMillis": now.UnixNano() / 1e6, - "unixSeconds": now.Unix(), - }) - defer deleteDocuments([]*DocumentRef{docRef1}) - - tests := []struct { - name string - pipeline *Pipeline - want map[string]interface{} - }{ - { - name: "TimestampAdd day", - pipeline: client.Pipeline(). - Collection(coll.ID). - Select(TimestampAdd("timestamp", "day", 1).As("timestamp_plus_day")), - want: map[string]interface{}{"timestamp_plus_day": now.AddDate(0, 0, 1).Truncate(time.Microsecond)}, - }, - { - name: "TimestampAdd hour", - pipeline: client.Pipeline(). - Collection(coll.ID). - Select(TimestampAdd("timestamp", "hour", 1).As("timestamp_plus_hour")), - want: map[string]interface{}{"timestamp_plus_hour": now.Add(time.Hour).Truncate(time.Microsecond)}, - }, - { - name: "TimestampAdd minute", - pipeline: client.Pipeline(). - Collection(coll.ID). - Select(TimestampAdd("timestamp", "minute", 1).As("timestamp_plus_minute")), - want: map[string]interface{}{"timestamp_plus_minute": now.Add(time.Minute).Truncate(time.Microsecond)}, - }, - { - name: "TimestampAdd second", - pipeline: client.Pipeline(). - Collection(coll.ID). - Select(TimestampAdd("timestamp", "second", 1).As("timestamp_plus_second")), - want: map[string]interface{}{"timestamp_plus_second": now.Add(time.Second).Truncate(time.Microsecond)}, - }, - { - name: "TimestampSubtract", - pipeline: client.Pipeline(). - Collection(coll.ID). - Select(TimestampSubtract("timestamp", "hour", 1).As("timestamp_minus_hour")), - want: map[string]interface{}{"timestamp_minus_hour": now.Add(-time.Hour).Truncate(time.Microsecond)}, - }, - { - name: "TimestampToUnixMicros", - pipeline: client.Pipeline(). - Collection(coll.ID). - Select(FieldOf("timestamp").TimestampToUnixMicros().As("timestamp_micros")), - want: map[string]interface{}{"timestamp_micros": now.UnixNano() / 1000}, - }, - { - name: "TimestampToUnixMillis", - pipeline: client.Pipeline(). - Collection(coll.ID). - Select(FieldOf("timestamp").TimestampToUnixMillis().As("timestamp_millis")), - want: map[string]interface{}{"timestamp_millis": now.UnixNano() / 1e6}, - }, - { - name: "TimestampToUnixSeconds", - pipeline: client.Pipeline(). - Collection(coll.ID). - Select(FieldOf("timestamp").TimestampToUnixSeconds().As("timestamp_seconds")), - want: map[string]interface{}{"timestamp_seconds": now.Unix()}, - }, - { - name: "UnixMicrosToTimestamp - constant", - pipeline: client.Pipeline(). - Collection(coll.ID). - Select(UnixMicrosToTimestamp(ConstantOf(now.UnixNano() / 1000)).As("timestamp_from_micros")), - want: map[string]interface{}{"timestamp_from_micros": now.Truncate(time.Microsecond)}, - }, - { - name: "UnixMicrosToTimestamp - fieldname", - pipeline: client.Pipeline(). - Collection(coll.ID). - Select(UnixMicrosToTimestamp("unixMicros").As("timestamp_from_micros")), - want: map[string]interface{}{"timestamp_from_micros": now.Truncate(time.Microsecond)}, - }, - { - name: "UnixMillisToTimestamp", - pipeline: client.Pipeline(). - Collection(coll.ID). - Select(UnixMillisToTimestamp(ConstantOf(now.UnixNano() / 1e6)).As("timestamp_from_millis")), - want: map[string]interface{}{"timestamp_from_millis": now.Truncate(time.Millisecond)}, - }, - { - name: "UnixSecondsToTimestamp", - pipeline: client.Pipeline(). - Collection(coll.ID). - Select(UnixSecondsToTimestamp("unixSeconds").As("timestamp_from_seconds")), - want: map[string]interface{}{"timestamp_from_seconds": now.Truncate(time.Second)}, - }, - { - name: "CurrentTimestamp", - pipeline: client.Pipeline(). - Collection(coll.ID). - Select(CurrentTimestamp().As("current_timestamp")), - want: map[string]interface{}{"current_timestamp": time.Now().Truncate(time.Microsecond)}, - }, - } - - ctx := context.Background() - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - iter := test.pipeline.Execute(ctx) - defer iter.Stop() - - docs, err := iter.GetAll() - if err != nil { - t.Fatalf("GetAll: %v", err) - } - if len(docs) != 1 { - t.Fatalf("expected 1 doc, got %d", len(docs)) - } - got, err := docs[0].Data() - if err != nil { - t.Fatalf("Data: %v", err) - } - - margin := 0 * time.Microsecond - if test.name == "CurrentTimestamp" { - margin = 5 * time.Second - } - if diff := testutil.Diff(got, test.want, cmpopts.EquateApproxTime(margin)); diff != "" { - t.Errorf("got: %v, want: %v, diff: %s", got, test.want, diff) - } - }) - } -} - -func arithmeticFuncs(t *testing.T) { - t.Parallel() - h := testHelper{t} - client := integrationClient(t) - coll := client.Collection(collectionIDs.New()) - docRef1 := coll.NewDoc() - h.mustCreate(docRef1, map[string]interface{}{ - "a": int(1), - "b": int(2), - "c": -3, - "d": 4.5, - "e": -5.5, - }) - defer deleteDocuments([]*DocumentRef{docRef1}) - - tests := []struct { - name string - pipeline *Pipeline - want map[string]interface{} - }{ - { - name: "Add - left FieldOf, right FieldOf", - pipeline: client.Pipeline().Collection(coll.ID).Select(Add(FieldOf("a"), FieldOf("b")).As("add")), - want: map[string]interface{}{"add": int64(3)}, - }, - { - name: "Add - left FieldOf, right ConstantOf", - pipeline: client.Pipeline().Collection(coll.ID).Select(Add(FieldOf("a"), ConstantOf(2)).As("add")), - want: map[string]interface{}{"add": int64(3)}, - }, - { - name: "Add - left FieldOf, right constant", - pipeline: client.Pipeline().Collection(coll.ID).Select(Add(FieldOf("a"), 5).As("add")), - want: map[string]interface{}{"add": int64(6)}, - }, - { - name: "Add - left fieldname, right constant", - pipeline: client.Pipeline().Collection(coll.ID).Select(Add("a", 5).As("add")), - want: map[string]interface{}{"add": int64(6)}, - }, - { - name: "Add - left fieldpath, right constant", - pipeline: client.Pipeline().Collection(coll.ID).Select(Add(FieldPath([]string{"a"}), 5).As("add")), - want: map[string]interface{}{"add": int64(6)}, - }, - { - name: "Add - left fieldpath, right expression", - pipeline: client.Pipeline().Collection(coll.ID).Select(Add(FieldPath([]string{"a"}), Add(FieldOf("b"), FieldOf("d"))).As("add")), - want: map[string]interface{}{"add": float64(7.5)}, - }, - { - name: "Subtract", - pipeline: client.Pipeline().Collection(coll.ID).Select(Subtract("a", FieldOf("b")).As("subtract")), - want: map[string]interface{}{"subtract": int64(-1)}, - }, - { - name: "Multiply", - pipeline: client.Pipeline().Collection(coll.ID).Select(Multiply("a", 5).As("multiply")), - want: map[string]interface{}{"multiply": int64(5)}, - }, - { - name: "Divide", - pipeline: client.Pipeline().Collection(coll.ID).Select(Divide("a", FieldOf("d")).As("divide")), - want: map[string]interface{}{"divide": float64(1 / 4.5)}, - }, - { - name: "Mod", - pipeline: client.Pipeline().Collection(coll.ID).Select(Mod("a", FieldOf("b")).As("mod")), - want: map[string]interface{}{"mod": int64(1)}, - }, - { - name: "Pow", - pipeline: client.Pipeline().Collection(coll.ID).Select(Pow("a", FieldOf("b")).As("pow")), - want: map[string]interface{}{"pow": float64(1)}, - }, - { - name: "Abs - fieldname", - pipeline: client.Pipeline().Collection(coll.ID).Select(Abs("c").As("abs")), - want: map[string]interface{}{"abs": int64(3)}, - }, - { - name: "Abs - fieldPath", - pipeline: client.Pipeline().Collection(coll.ID).Select(Abs(FieldPath([]string{"c"})).As("abs")), - want: map[string]interface{}{"abs": int64(3)}, - }, - { - name: "Abs - Expr", - pipeline: client.Pipeline().Collection(coll.ID).Select(Abs(Add(FieldOf("b"), FieldOf("d"))).As("abs")), - want: map[string]interface{}{"abs": float64(6.5)}, - }, - { - name: "Ceil", - pipeline: client.Pipeline().Collection(coll.ID).Select(Ceil("d").As("ceil")), - want: map[string]interface{}{"ceil": float64(5)}, - }, - { - name: "Floor", - pipeline: client.Pipeline().Collection(coll.ID).Select(Floor("d").As("floor")), - want: map[string]interface{}{"floor": float64(4)}, - }, - { - name: "Round", - pipeline: client.Pipeline().Collection(coll.ID).Select(Round("d").As("round")), - want: map[string]interface{}{"round": float64(5)}, - }, - { - name: "Sqrt", - pipeline: client.Pipeline().Collection(coll.ID).Select(Sqrt("d").As("sqrt")), - want: map[string]interface{}{"sqrt": math.Sqrt(4.5)}, - }, - { - name: "Log", - pipeline: client.Pipeline().Collection(coll.ID).Select(Log("d", 2).As("log")), - want: map[string]interface{}{"log": math.Log2(4.5)}, - }, - { - name: "Log10", - pipeline: client.Pipeline().Collection(coll.ID).Select(Log10("d").As("log10")), - want: map[string]interface{}{"log10": math.Log10(4.5)}, - }, - { - name: "Ln", - pipeline: client.Pipeline().Collection(coll.ID).Select(Ln("d").As("ln")), - want: map[string]interface{}{"ln": math.Log(4.5)}, - }, - { - name: "Exp", - pipeline: client.Pipeline().Collection(coll.ID).Select(Exp("d").As("exp")), - want: map[string]interface{}{"exp": math.Exp(4.5)}, - }, - } - - ctx := context.Background() - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - iter := test.pipeline.Execute(ctx) - defer iter.Stop() - - docs, err := iter.GetAll() - if err != nil { - t.Fatalf("GetAll: %v", err) - } - if len(docs) != 1 { - t.Fatalf("expected 1 doc, got %d", len(docs)) - } - got, err := docs[0].Data() - if err != nil { - t.Fatalf("Data: %v", err) - } - if diff := testutil.Diff(got, test.want); diff != "" { - t.Errorf("got: %v, want: %v, diff +want -got: %s", got, test.want, diff) - } - }) - } -} - -func aggregateFuncs(t *testing.T) { - t.Parallel() - h := testHelper{t} - client := integrationClient(t) - coll := client.Collection(collectionIDs.New()) - docRef1 := coll.NewDoc() - h.mustCreate(docRef1, map[string]interface{}{ - "a": 1, - }) - docRef2 := coll.NewDoc() - h.mustCreate(docRef2, map[string]interface{}{ - "a": 2, - }) - docRef3 := coll.NewDoc() - h.mustCreate(docRef3, map[string]interface{}{ - "b": 2, - }) - defer deleteDocuments([]*DocumentRef{docRef1, docRef2, docRef3}) - - tests := []struct { - name string - pipeline *Pipeline - want map[string]interface{} - }{ - { - name: "Sum - fieldname arg", - pipeline: client.Pipeline(). - Collection(coll.ID). - Aggregate(Sum("a").As("sum_a")), - want: map[string]interface{}{"sum_a": int64(3)}, - }, - { - name: "Sum - fieldpath arg", - pipeline: client.Pipeline(). - Collection(coll.ID). - Aggregate(Sum(FieldPath([]string{"a"})).As("sum_a")), - want: map[string]interface{}{"sum_a": int64(3)}, - }, - { - name: "Sum - FieldOf Expr", - pipeline: client.Pipeline(). - Collection(coll.ID). - Aggregate(Sum(FieldOf("a")).As("sum_a")), - want: map[string]interface{}{"sum_a": int64(3)}, - }, - { - name: "Sum - FieldOfPath Expr", - pipeline: client.Pipeline(). - Collection(coll.ID). - Aggregate(Sum(FieldOfPath(FieldPath([]string{"a"}))).As("sum_a")), - want: map[string]interface{}{"sum_a": int64(3)}, - }, - { - name: "Avg", - pipeline: client.Pipeline(). - Collection(coll.ID). - Aggregate(Average("a").As("avg_a")), - want: map[string]interface{}{"avg_a": float64(1.5)}, - }, - { - name: "Count", - pipeline: client.Pipeline(). - Collection(coll.ID). - Aggregate(Count("a").As("count_a")), - want: map[string]interface{}{"count_a": int64(2)}, - }, - { - name: "CountAll", - pipeline: client.Pipeline(). - Collection(coll.ID). - Aggregate(CountAll().As("count_all")), - want: map[string]interface{}{"count_all": int64(3)}, - }, - } - - ctx := context.Background() - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - iter := test.pipeline.Execute(ctx) - defer iter.Stop() - - docs, err := iter.GetAll() - if err != nil { - t.Fatalf("GetAll: %v", err) - } - if len(docs) != 1 { - t.Fatalf("expected 1 doc, got %d", len(docs)) - } - got, err := docs[0].Data() - if err != nil { - t.Fatalf("Data: %v", err) - } - if diff := testutil.Diff(got, test.want); diff != "" { - t.Errorf("got: %v, want: %v, diff +want -got: %s", got, test.want, diff) - } - }) - } -} - -func comparisonFuncs(t *testing.T) { - t.Parallel() - ctx := context.Background() - client := integrationClient(t) - now := time.Now() - coll := client.Collection(collectionIDs.New()) - doc1data := map[string]interface{}{ - "timestamp": now, - "a": 1, - "b": 2, - "c": -3, - "d": 4.5, - "e": -5.5, - } - _, err := coll.Doc("doc1").Create(ctx, doc1data) - if err != nil { - t.Fatalf("Create: %v", err) - } - doc2data := map[string]interface{}{ - "timestamp": now, - "a": 2, - "b": 2, - "c": -3, - "d": 4.5, - "e": -5.5, - } - _, err = coll.Doc("doc2").Create(ctx, doc2data) - if err != nil { - t.Fatalf("Create: %v", err) - } - defer deleteDocuments([]*DocumentRef{coll.Doc("doc1"), coll.Doc("doc2")}) - - doc1want := map[string]interface{}{"a": int64(1), "b": int64(2), "c": int64(-3), "d": float64(4.5), "e": float64(-5.5), "timestamp": now.Truncate(time.Microsecond)} - - tests := []struct { - name string - pipeline *Pipeline - want []map[string]interface{} - }{ - { - name: "Equal", - pipeline: client.Pipeline(). - Collection(coll.ID). - Where(Equal("a", 1)), - want: []map[string]interface{}{doc1want}, - }, - { - name: "NotEqual", - pipeline: client.Pipeline(). - Collection(coll.ID). - Where(NotEqual("a", 2)), - want: []map[string]interface{}{doc1want}, - }, - { - name: "LessThan", - pipeline: client.Pipeline(). - Collection(coll.ID). - Where(LessThan("a", 2)), - want: []map[string]interface{}{doc1want}, - }, - { - name: "Equivalent", - pipeline: client.Pipeline(). - Collection(coll.ID). - Where(Equivalent("a", 1)), - want: []map[string]interface{}{doc1want}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - iter := test.pipeline.Execute(ctx) - defer iter.Stop() - - docs, err := iter.GetAll() - if err != nil { - t.Fatalf("GetAll: %v", err) - } - if len(docs) != len(test.want) { - t.Fatalf("expected %d doc(s), got %d", len(test.want), len(docs)) - } - - var gots []map[string]interface{} - for _, doc := range docs { - got, err := doc.Data() - if err != nil { - t.Fatalf("Data: %v", err) - } - if ts, ok := got["timestamp"].(time.Time); ok { - got["timestamp"] = ts.Truncate(time.Microsecond) - } - gots = append(gots, got) - } - - if diff := testutil.Diff(gots, test.want); diff != "" { - t.Errorf("got: %v, want: %v, diff +want -got: %s", gots, test.want, diff) - } - }) - } -} diff --git a/firestore/pipeline_filter_condition.go b/firestore/pipeline_filter_condition.go index 618f6b7a75b6..e9a1c25dedad 100644 --- a/firestore/pipeline_filter_condition.go +++ b/firestore/pipeline_filter_condition.go @@ -39,7 +39,7 @@ var _ BooleanExpr = (*baseBooleanExpr)(nil) // // Check if the 'tags' array contains "Go". // ArrayContains("tags", "Go") func ArrayContains(exprOrFieldPath any, value any) BooleanExpr { - return &baseBooleanExpr{baseFunction: newBaseFunction("array_contains", []Expr{toExprOrField(exprOrFieldPath), toExprOrConstant(value)})} + return &baseBooleanExpr{baseFunction: newBaseFunction("array_contains", []Expr{asFieldExpr(exprOrFieldPath), toExprOrConstant(value)})} } // ArrayContainsAll creates an expression that checks if an array contains all of the provided values. @@ -241,7 +241,7 @@ func Equivalent(left, right any) BooleanExpr { // // Check if the 'filename' field ends with ".go". // EndsWith("filename", ".go") func EndsWith(exprOrFieldPath any, suffix any) BooleanExpr { - return &baseBooleanExpr{baseFunction: newBaseFunction("ends_with", []Expr{toExprOrField(exprOrFieldPath), asStringExpr(suffix)})} + return &baseBooleanExpr{baseFunction: newBaseFunction("ends_with", []Expr{asFieldExpr(exprOrFieldPath), asStringExpr(suffix)})} } // Like creates an expression that performs a case-sensitive wildcard string comparison. @@ -253,7 +253,7 @@ func EndsWith(exprOrFieldPath any, suffix any) BooleanExpr { // // Check if the 'name' field starts with "G". // Like("name", "G%") func Like(exprOrFieldPath any, pattern any) BooleanExpr { - return &baseBooleanExpr{baseFunction: newBaseFunction("like", []Expr{toExprOrField(exprOrFieldPath), asStringExpr(pattern)})} + return &baseBooleanExpr{baseFunction: newBaseFunction("like", []Expr{asFieldExpr(exprOrFieldPath), asStringExpr(pattern)})} } // RegexContains creates an expression that checks if a string contains a match for a regular expression. @@ -265,7 +265,7 @@ func Like(exprOrFieldPath any, pattern any) BooleanExpr { // // Check if the 'email' field contains a gmail address. // RegexContains("email", "@gmail\\.com$") func RegexContains(exprOrFieldPath any, pattern any) BooleanExpr { - return &baseBooleanExpr{baseFunction: newBaseFunction("regex_contains", []Expr{toExprOrField(exprOrFieldPath), asStringExpr(pattern)})} + return &baseBooleanExpr{baseFunction: newBaseFunction("regex_contains", []Expr{asFieldExpr(exprOrFieldPath), asStringExpr(pattern)})} } // RegexMatch creates an expression that checks if a string matches a regular expression. @@ -277,7 +277,7 @@ func RegexContains(exprOrFieldPath any, pattern any) BooleanExpr { // // Check if the 'zip_code' field is a 5-digit number. // RegexMatch("zip_code", "^[0-9]{5}$") func RegexMatch(exprOrFieldPath any, pattern any) BooleanExpr { - return &baseBooleanExpr{baseFunction: newBaseFunction("regex_match", []Expr{toExprOrField(exprOrFieldPath), asStringExpr(pattern)})} + return &baseBooleanExpr{baseFunction: newBaseFunction("regex_match", []Expr{asFieldExpr(exprOrFieldPath), asStringExpr(pattern)})} } // StartsWith creates an expression that checks if a string field or expression starts with a given prefix. @@ -289,7 +289,7 @@ func RegexMatch(exprOrFieldPath any, pattern any) BooleanExpr { // // Check if the 'name' field starts with "Mr.". // StartsWith("name", "Mr.") func StartsWith(exprOrFieldPath any, prefix any) BooleanExpr { - return &baseBooleanExpr{baseFunction: newBaseFunction("starts_with", []Expr{toExprOrField(exprOrFieldPath), asStringExpr(prefix)})} + return &baseBooleanExpr{baseFunction: newBaseFunction("starts_with", []Expr{asFieldExpr(exprOrFieldPath), asStringExpr(prefix)})} } // StringContains creates an expression that checks if a string contains a specified substring. @@ -301,7 +301,7 @@ func StartsWith(exprOrFieldPath any, prefix any) BooleanExpr { // // Check if the 'description' field contains the word "Firestore". // StringContains("description", "Firestore") func StringContains(exprOrFieldPath any, substring any) BooleanExpr { - return &baseBooleanExpr{baseFunction: newBaseFunction("string_contains", []Expr{toExprOrField(exprOrFieldPath), asStringExpr(substring)})} + return &baseBooleanExpr{baseFunction: newBaseFunction("string_contains", []Expr{asFieldExpr(exprOrFieldPath), asStringExpr(substring)})} } // IsNaN creates a boolean expression that checks if a field or expression evaluates to NaN. @@ -313,7 +313,7 @@ func StringContains(exprOrFieldPath any, substring any) BooleanExpr { // // Check if the 'score' field is NaN. // IsNaN("score") func IsNaN(exprOrFieldPath any) BooleanExpr { - return &baseBooleanExpr{baseFunction: newBaseFunction("is_nan", []Expr{toExprOrField(exprOrFieldPath)})} + return &baseBooleanExpr{baseFunction: newBaseFunction("is_nan", []Expr{asFieldExpr(exprOrFieldPath)})} } // IsNotNaN creates a boolean expression that checks if a field or expression does not evaluate to NaN. @@ -325,7 +325,7 @@ func IsNaN(exprOrFieldPath any) BooleanExpr { // // Check if the 'score' field is not NaN. // IsNotNaN("score") func IsNotNaN(exprOrFieldPath any) BooleanExpr { - return &baseBooleanExpr{baseFunction: newBaseFunction("is_not_nan", []Expr{toExprOrField(exprOrFieldPath)})} + return &baseBooleanExpr{baseFunction: newBaseFunction("is_not_nan", []Expr{asFieldExpr(exprOrFieldPath)})} } // IsNull creates a boolean expression that checks if a field or expression evaluates to null. @@ -337,7 +337,7 @@ func IsNotNaN(exprOrFieldPath any) BooleanExpr { // // Check if the 'address' field is null. // IsNull("address") func IsNull(exprOrFieldPath any) BooleanExpr { - return &baseBooleanExpr{baseFunction: newBaseFunction("is_null", []Expr{toExprOrField(exprOrFieldPath)})} + return &baseBooleanExpr{baseFunction: newBaseFunction("is_null", []Expr{asFieldExpr(exprOrFieldPath)})} } // IsNotNull creates a boolean expression that checks if a field or expression does not evaluate to null. @@ -349,5 +349,5 @@ func IsNull(exprOrFieldPath any) BooleanExpr { // // Check if the 'address' field is not null. // IsNotNull("address") func IsNotNull(exprOrFieldPath any) BooleanExpr { - return &baseBooleanExpr{baseFunction: newBaseFunction("is_not_null", []Expr{toExprOrField(exprOrFieldPath)})} + return &baseBooleanExpr{baseFunction: newBaseFunction("is_not_null", []Expr{asFieldExpr(exprOrFieldPath)})} } diff --git a/firestore/pipeline_function.go b/firestore/pipeline_function.go index 3d095bf69a67..d55cf0ae682e 100644 --- a/firestore/pipeline_function.go +++ b/firestore/pipeline_function.go @@ -328,7 +328,7 @@ func CurrentTimestamp() Expr { // // Get the length of the 'tags' array field. // ArrayLength("tags") func ArrayLength(exprOrFieldPath any) Expr { - return newBaseFunction("array_length", []Expr{toExprOrField(exprOrFieldPath)}) + return newBaseFunction("array_length", []Expr{asFieldExpr(exprOrFieldPath)}) } // Array creates an expression that represents a Firestore array. @@ -359,7 +359,7 @@ func ArrayFromSlice[T any](elements []T) Expr { // // Get the first element of the 'tags' array field. // ArrayGet("tags", 0) func ArrayGet(exprOrFieldPath any, offset any) Expr { - return newBaseFunction("array_get", []Expr{toExprOrField(exprOrFieldPath), asInt64Expr(offset)}) + return newBaseFunction("array_get", []Expr{asFieldExpr(exprOrFieldPath), asInt64Expr(offset)}) } // ArrayReverse creates an expression that reverses the order of elements in an array. @@ -370,7 +370,7 @@ func ArrayGet(exprOrFieldPath any, offset any) Expr { // // Reverse the 'tags' array. // ArrayReverse("tags") func ArrayReverse(exprOrFieldPath any) Expr { - return newBaseFunction("array_reverse", []Expr{toExprOrField(exprOrFieldPath)}) + return newBaseFunction("array_reverse", []Expr{asFieldExpr(exprOrFieldPath)}) } // ArrayConcat creates an expression that concatenates multiple arrays into a single array. @@ -382,7 +382,7 @@ func ArrayReverse(exprOrFieldPath any) Expr { // // Concatenate the 'tags' and 'categories' array fields. // ArrayConcat("tags", FieldOf("categories")) func ArrayConcat(exprOrFieldPath any, otherArrays ...any) Expr { - return newBaseFunction("array_concat", append([]Expr{toExprOrField(exprOrFieldPath)}, toExprs(otherArrays)...)) + return newBaseFunction("array_concat", append([]Expr{asFieldExpr(exprOrFieldPath)}, toExprs(otherArrays)...)) } // ArraySum creates an expression that calculates the sum of all elements in a numeric array. @@ -393,7 +393,7 @@ func ArrayConcat(exprOrFieldPath any, otherArrays ...any) Expr { // // Calculate the sum of the 'scores' array. // ArraySum("scores") func ArraySum(exprOrFieldPath any) Expr { - return newBaseFunction("sum", []Expr{toExprOrField(exprOrFieldPath)}) + return newBaseFunction("sum", []Expr{asFieldExpr(exprOrFieldPath)}) } // ArrayMaximum creates an expression that finds the maximum element in a numeric array. @@ -404,7 +404,7 @@ func ArraySum(exprOrFieldPath any) Expr { // // Find the maximum value in the 'scores' array. // ArrayMaximum("scores") func ArrayMaximum(exprOrFieldPath any) Expr { - return newBaseFunction("maximum", []Expr{toExprOrField(exprOrFieldPath)}) + return newBaseFunction("maximum", []Expr{asFieldExpr(exprOrFieldPath)}) } // ArrayMinimum creates an expression that finds the minimum element in a numeric array. @@ -415,7 +415,7 @@ func ArrayMaximum(exprOrFieldPath any) Expr { // // Find the minimum value in the 'scores' array. // ArrayMinimum("scores") func ArrayMinimum(exprOrFieldPath any) Expr { - return newBaseFunction("minimum", []Expr{toExprOrField(exprOrFieldPath)}) + return newBaseFunction("minimum", []Expr{asFieldExpr(exprOrFieldPath)}) } // ByteLength creates an expression that calculates the length of a string represented by a field or [Expr] in UTF-8 @@ -427,7 +427,7 @@ func ArrayMinimum(exprOrFieldPath any) Expr { // // Get the byte length of the 'name' field. // ByteLength("name") func ByteLength(exprOrFieldPath any) Expr { - return newBaseFunction("byte_length", []Expr{toExprOrField(exprOrFieldPath)}) + return newBaseFunction("byte_length", []Expr{asFieldExpr(exprOrFieldPath)}) } // CharLength creates an expression that calculates the character length of a string field or expression in UTF8. @@ -438,7 +438,7 @@ func ByteLength(exprOrFieldPath any) Expr { // // Get the character length of the 'name' field. // CharLength("name") func CharLength(exprOrFieldPath any) Expr { - return newBaseFunction("char_length", []Expr{toExprOrField(exprOrFieldPath)}) + return newBaseFunction("char_length", []Expr{asFieldExpr(exprOrFieldPath)}) } // StringConcat creates an expression that concatenates multiple strings into a single string. @@ -450,7 +450,7 @@ func CharLength(exprOrFieldPath any) Expr { // // Concatenate first name and last name. // StringConcat(FieldOf("firstName"), " ", FieldOf("lastName")) func StringConcat(exprOrFieldPath any, otherStrings ...any) Expr { - return newBaseFunction("string_concat", append([]Expr{toExprOrField(exprOrFieldPath)}, toExprs(otherStrings)...)) + return newBaseFunction("string_concat", append([]Expr{asFieldExpr(exprOrFieldPath)}, toExprs(otherStrings)...)) } // StringReverse creates an expression that reverses a string. @@ -461,7 +461,7 @@ func StringConcat(exprOrFieldPath any, otherStrings ...any) Expr { // // Reverse the 'name' field. // StringReverse("name") func StringReverse(exprOrFieldPath any) Expr { - return newBaseFunction("string_reverse", []Expr{toExprOrField(exprOrFieldPath)}) + return newBaseFunction("string_reverse", []Expr{asFieldExpr(exprOrFieldPath)}) } // Join creates an expression that joins the elements of a string array into a single string. @@ -473,7 +473,7 @@ func StringReverse(exprOrFieldPath any) Expr { // // Join the 'tags' array with a comma and space. // Join("tags", ", ") func Join(exprOrFieldPath any, separator any) Expr { - return newBaseFunction("join", []Expr{toExprOrField(exprOrFieldPath), asStringExpr(separator)}) + return newBaseFunction("join", []Expr{asFieldExpr(exprOrFieldPath), asStringExpr(separator)}) } // Substring creates an expression that returns a substring of a string. @@ -486,7 +486,7 @@ func Join(exprOrFieldPath any, separator any) Expr { // // Get the first 5 characters of the 'description' field. // Substring("description", 0, 5) func Substring(exprOrFieldPath any, index any, offset any) Expr { - return newBaseFunction("substring", []Expr{toExprOrField(exprOrFieldPath), asInt64Expr(index), asInt64Expr(offset)}) + return newBaseFunction("substring", []Expr{asFieldExpr(exprOrFieldPath), asInt64Expr(index), asInt64Expr(offset)}) } // ToLower creates an expression that converts a string to lowercase. @@ -497,7 +497,7 @@ func Substring(exprOrFieldPath any, index any, offset any) Expr { // // Convert the 'username' to lowercase. // ToLower("username") func ToLower(exprOrFieldPath any) Expr { - return newBaseFunction("to_lower", []Expr{toExprOrField(exprOrFieldPath)}) + return newBaseFunction("to_lower", []Expr{asFieldExpr(exprOrFieldPath)}) } // ToUpper creates an expression that converts a string to uppercase. @@ -508,7 +508,7 @@ func ToLower(exprOrFieldPath any) Expr { // // Convert the 'product_code' to uppercase. // ToUpper("product_code") func ToUpper(exprOrFieldPath any) Expr { - return newBaseFunction("to_upper", []Expr{toExprOrField(exprOrFieldPath)}) + return newBaseFunction("to_upper", []Expr{asFieldExpr(exprOrFieldPath)}) } // Trim creates an expression that removes leading and trailing whitespace from a string. @@ -519,7 +519,7 @@ func ToUpper(exprOrFieldPath any) Expr { // // Trim the 'email' field. // Trim("email") func Trim(exprOrFieldPath any) Expr { - return newBaseFunction("trim", []Expr{toExprOrField(exprOrFieldPath)}) + return newBaseFunction("trim", []Expr{asFieldExpr(exprOrFieldPath)}) } // CosineDistance creates an expression that calculates the cosine distance between two vectors. @@ -531,7 +531,7 @@ func Trim(exprOrFieldPath any) Expr { // // Calculate the cosine distance between two vector fields. // CosineDistance("vector_field_1", FieldOf("vector_field_2")) func CosineDistance(vector1 any, vector2 any) Expr { - return newBaseFunction("cosine_distance", []Expr{toExprOrField(vector1), asVectorExpr(vector2)}) + return newBaseFunction("cosine_distance", []Expr{asFieldExpr(vector1), asVectorExpr(vector2)}) } // DotProduct creates an expression that calculates the dot product of two vectors. @@ -543,7 +543,7 @@ func CosineDistance(vector1 any, vector2 any) Expr { // // Calculate the dot product of two vector fields. // DotProduct("vector_field_1", FieldOf("vector_field_2")) func DotProduct(vector1 any, vector2 any) Expr { - return newBaseFunction("dot_product", []Expr{toExprOrField(vector1), asVectorExpr(vector2)}) + return newBaseFunction("dot_product", []Expr{asFieldExpr(vector1), asVectorExpr(vector2)}) } // EuclideanDistance creates an expression that calculates the euclidean distance between two vectors. @@ -555,7 +555,7 @@ func DotProduct(vector1 any, vector2 any) Expr { // // Calculate the euclidean distance between two vector fields. // EuclideanDistance("vector_field_1", FieldOf("vector_field_2")) func EuclideanDistance(vector1 any, vector2 any) Expr { - return newBaseFunction("euclidean_distance", []Expr{toExprOrField(vector1), asVectorExpr(vector2)}) + return newBaseFunction("euclidean_distance", []Expr{asFieldExpr(vector1), asVectorExpr(vector2)}) } // VectorLength creates an expression that calculates the length of a vector. @@ -566,5 +566,5 @@ func EuclideanDistance(vector1 any, vector2 any) Expr { // // Calculate the length of a vector field. // VectorLength("vector_field") func VectorLength(exprOrFieldPath any) Expr { - return newBaseFunction("vector_length", []Expr{toExprOrField(exprOrFieldPath)}) + return newBaseFunction("vector_length", []Expr{asFieldExpr(exprOrFieldPath)}) } diff --git a/firestore/pipeline_integration_test.go b/firestore/pipeline_integration_test.go index b8f9695a0235..5b6c98e6e57b 100644 --- a/firestore/pipeline_integration_test.go +++ b/firestore/pipeline_integration_test.go @@ -36,6 +36,11 @@ func TestIntegration_PipelineFunctions(t *testing.T) { t.Run("typeFuncs", typeFuncs) t.Run("vectorFuncs", vectorFuncs) + t.Run("timestampFuncs", timestampFuncs) + t.Run("arithmeticFuncs", arithmeticFuncs) + t.Run("aggregateFuncs", aggregateFuncs) + t.Run("comparisonFuncs", comparisonFuncs) + } func arrayFuncs(t *testing.T) { @@ -161,11 +166,7 @@ func arrayFuncs(t *testing.T) { r.Fatalf("expected 1 doc, got %d", len(docs)) return } - got, err := docs[0].Data() - if err != nil { - r.Fatalf("Data: %v", err) - return - } + got := docs[0].Data() if diff := testutil.Diff(got, test.want); diff != "" { r.Errorf("got: %v, want: %v, diff +want -got: %s", got, test.want, diff) return @@ -312,11 +313,7 @@ func stringFuncs(t *testing.T) { r.Fatalf("expected 1 doc, got %d", len(docs)) return } - got, err := docs[0].Data() - if err != nil { - r.Fatalf("Data: %v", err) - return - } + got := docs[0].Data() if diff := testutil.Diff(got, want); diff != "" { t.Errorf("got: %v, want: %v, diff +want -got: %s", got, want, diff) } @@ -332,11 +329,7 @@ func stringFuncs(t *testing.T) { } var gots []map[string]interface{} for _, doc := range docs { - got, err := doc.Data() - if err != nil { - r.Fatalf("Data: %v", err) - return - } + got := doc.Data() gots = append(gots, got) } if diff := testutil.Diff(gots, want); diff != "" { @@ -516,11 +509,7 @@ func vectorFuncs(t *testing.T) { r.Fatalf("expected 1 doc, got %d", len(docs)) return } - got, err := docs[0].Data() - if err != nil { - r.Fatalf("Data: %v", err) - return - } + got := docs[0].Data() if diff := testutil.Diff(got, test.want); diff != "" { r.Errorf("got: %v, want: %v, diff +want -got: %s", got, test.want, diff) } @@ -541,3 +530,488 @@ func isRetryablePipelineExecuteErr(err error) bool { strings.Contains(s.Message(), "Invalid request routing header") && strings.Contains(s.Message(), "Please fill in the request header with format") } + +func timestampFuncs(t *testing.T) { + t.Parallel() + client := integrationClient(t) + coll := client.Collection(collectionIDs.New()) + h := testHelper{t} + now := time.Now() + docRef1 := coll.NewDoc() + h.mustCreate(docRef1, map[string]interface{}{ + "timestamp": now, + "unixMicros": now.UnixNano() / 1000, + "unixMillis": now.UnixNano() / 1e6, + "unixSeconds": now.Unix(), + }) + defer deleteDocuments([]*DocumentRef{docRef1}) + + tests := []struct { + name string + pipeline *Pipeline + want map[string]interface{} + }{ + { + name: "TimestampAdd day", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(TimestampAdd("timestamp", "day", 1).As("timestamp_plus_day")), + want: map[string]interface{}{"timestamp_plus_day": now.AddDate(0, 0, 1).Truncate(time.Microsecond)}, + }, + { + name: "TimestampAdd hour", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(TimestampAdd("timestamp", "hour", 1).As("timestamp_plus_hour")), + want: map[string]interface{}{"timestamp_plus_hour": now.Add(time.Hour).Truncate(time.Microsecond)}, + }, + { + name: "TimestampAdd minute", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(TimestampAdd("timestamp", "minute", 1).As("timestamp_plus_minute")), + want: map[string]interface{}{"timestamp_plus_minute": now.Add(time.Minute).Truncate(time.Microsecond)}, + }, + { + name: "TimestampAdd second", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(TimestampAdd("timestamp", "second", 1).As("timestamp_plus_second")), + want: map[string]interface{}{"timestamp_plus_second": now.Add(time.Second).Truncate(time.Microsecond)}, + }, + { + name: "TimestampSubtract", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(TimestampSubtract("timestamp", "hour", 1).As("timestamp_minus_hour")), + want: map[string]interface{}{"timestamp_minus_hour": now.Add(-time.Hour).Truncate(time.Microsecond)}, + }, + { + name: "TimestampToUnixMicros", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(FieldOf("timestamp").TimestampToUnixMicros().As("timestamp_micros")), + want: map[string]interface{}{"timestamp_micros": now.UnixNano() / 1000}, + }, + { + name: "TimestampToUnixMillis", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(FieldOf("timestamp").TimestampToUnixMillis().As("timestamp_millis")), + want: map[string]interface{}{"timestamp_millis": now.UnixNano() / 1e6}, + }, + { + name: "TimestampToUnixSeconds", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(FieldOf("timestamp").TimestampToUnixSeconds().As("timestamp_seconds")), + want: map[string]interface{}{"timestamp_seconds": now.Unix()}, + }, + { + name: "UnixMicrosToTimestamp - constant", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(UnixMicrosToTimestamp(ConstantOf(now.UnixNano() / 1000)).As("timestamp_from_micros")), + want: map[string]interface{}{"timestamp_from_micros": now.Truncate(time.Microsecond)}, + }, + { + name: "UnixMicrosToTimestamp - fieldname", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(UnixMicrosToTimestamp("unixMicros").As("timestamp_from_micros")), + want: map[string]interface{}{"timestamp_from_micros": now.Truncate(time.Microsecond)}, + }, + { + name: "UnixMillisToTimestamp", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(UnixMillisToTimestamp(ConstantOf(now.UnixNano() / 1e6)).As("timestamp_from_millis")), + want: map[string]interface{}{"timestamp_from_millis": now.Truncate(time.Millisecond)}, + }, + { + name: "UnixSecondsToTimestamp", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(UnixSecondsToTimestamp("unixSeconds").As("timestamp_from_seconds")), + want: map[string]interface{}{"timestamp_from_seconds": now.Truncate(time.Second)}, + }, + { + name: "CurrentTimestamp", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(CurrentTimestamp().As("current_timestamp")), + want: map[string]interface{}{"current_timestamp": time.Now().Truncate(time.Microsecond)}, + }, + } + + ctx := context.Background() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + iter := test.pipeline.Execute(ctx) + defer iter.Stop() + + docs, err := iter.GetAll() + if err != nil { + t.Fatalf("GetAll: %v", err) + } + if len(docs) != 1 { + t.Fatalf("expected 1 doc, got %d", len(docs)) + } + got := docs[0].Data() + margin := 0 * time.Microsecond + if test.name == "CurrentTimestamp" { + margin = 5 * time.Second + } + if diff := testutil.Diff(got, test.want, cmpopts.EquateApproxTime(margin)); diff != "" { + t.Errorf("got: %v, want: %v, diff: %s", got, test.want, diff) + } + }) + } +} + +func arithmeticFuncs(t *testing.T) { + t.Parallel() + h := testHelper{t} + client := integrationClient(t) + coll := client.Collection(collectionIDs.New()) + docRef1 := coll.NewDoc() + h.mustCreate(docRef1, map[string]interface{}{ + "a": int(1), + "b": int(2), + "c": -3, + "d": 4.5, + "e": -5.5, + }) + defer deleteDocuments([]*DocumentRef{docRef1}) + + tests := []struct { + name string + pipeline *Pipeline + want map[string]interface{} + }{ + { + name: "Add - left FieldOf, right FieldOf", + pipeline: client.Pipeline().Collection(coll.ID).Select(Add(FieldOf("a"), FieldOf("b")).As("add")), + want: map[string]interface{}{"add": int64(3)}, + }, + { + name: "Add - left FieldOf, right ConstantOf", + pipeline: client.Pipeline().Collection(coll.ID).Select(Add(FieldOf("a"), ConstantOf(2)).As("add")), + want: map[string]interface{}{"add": int64(3)}, + }, + { + name: "Add - left FieldOf, right constant", + pipeline: client.Pipeline().Collection(coll.ID).Select(Add(FieldOf("a"), 5).As("add")), + want: map[string]interface{}{"add": int64(6)}, + }, + { + name: "Add - left fieldname, right constant", + pipeline: client.Pipeline().Collection(coll.ID).Select(Add("a", 5).As("add")), + want: map[string]interface{}{"add": int64(6)}, + }, + { + name: "Add - left fieldpath, right constant", + pipeline: client.Pipeline().Collection(coll.ID).Select(Add(FieldPath([]string{"a"}), 5).As("add")), + want: map[string]interface{}{"add": int64(6)}, + }, + { + name: "Add - left fieldpath, right expression", + pipeline: client.Pipeline().Collection(coll.ID).Select(Add(FieldPath([]string{"a"}), Add(FieldOf("b"), FieldOf("d"))).As("add")), + want: map[string]interface{}{"add": float64(7.5)}, + }, + { + name: "Subtract", + pipeline: client.Pipeline().Collection(coll.ID).Select(Subtract("a", FieldOf("b")).As("subtract")), + want: map[string]interface{}{"subtract": int64(-1)}, + }, + { + name: "Multiply", + pipeline: client.Pipeline().Collection(coll.ID).Select(Multiply("a", 5).As("multiply")), + want: map[string]interface{}{"multiply": int64(5)}, + }, + { + name: "Divide", + pipeline: client.Pipeline().Collection(coll.ID).Select(Divide("a", FieldOf("d")).As("divide")), + want: map[string]interface{}{"divide": float64(1 / 4.5)}, + }, + { + name: "Mod", + pipeline: client.Pipeline().Collection(coll.ID).Select(Mod("a", FieldOf("b")).As("mod")), + want: map[string]interface{}{"mod": int64(1)}, + }, + { + name: "Pow", + pipeline: client.Pipeline().Collection(coll.ID).Select(Pow("a", FieldOf("b")).As("pow")), + want: map[string]interface{}{"pow": float64(1)}, + }, + { + name: "Abs - fieldname", + pipeline: client.Pipeline().Collection(coll.ID).Select(Abs("c").As("abs")), + want: map[string]interface{}{"abs": int64(3)}, + }, + { + name: "Abs - fieldPath", + pipeline: client.Pipeline().Collection(coll.ID).Select(Abs(FieldPath([]string{"c"})).As("abs")), + want: map[string]interface{}{"abs": int64(3)}, + }, + { + name: "Abs - Expr", + pipeline: client.Pipeline().Collection(coll.ID).Select(Abs(Add(FieldOf("b"), FieldOf("d"))).As("abs")), + want: map[string]interface{}{"abs": float64(6.5)}, + }, + { + name: "Ceil", + pipeline: client.Pipeline().Collection(coll.ID).Select(Ceil("d").As("ceil")), + want: map[string]interface{}{"ceil": float64(5)}, + }, + { + name: "Floor", + pipeline: client.Pipeline().Collection(coll.ID).Select(Floor("d").As("floor")), + want: map[string]interface{}{"floor": float64(4)}, + }, + { + name: "Round", + pipeline: client.Pipeline().Collection(coll.ID).Select(Round("d").As("round")), + want: map[string]interface{}{"round": float64(5)}, + }, + { + name: "Sqrt", + pipeline: client.Pipeline().Collection(coll.ID).Select(Sqrt("d").As("sqrt")), + want: map[string]interface{}{"sqrt": math.Sqrt(4.5)}, + }, + { + name: "Log", + pipeline: client.Pipeline().Collection(coll.ID).Select(Log("d", 2).As("log")), + want: map[string]interface{}{"log": math.Log2(4.5)}, + }, + { + name: "Log10", + pipeline: client.Pipeline().Collection(coll.ID).Select(Log10("d").As("log10")), + want: map[string]interface{}{"log10": math.Log10(4.5)}, + }, + { + name: "Ln", + pipeline: client.Pipeline().Collection(coll.ID).Select(Ln("d").As("ln")), + want: map[string]interface{}{"ln": math.Log(4.5)}, + }, + { + name: "Exp", + pipeline: client.Pipeline().Collection(coll.ID).Select(Exp("d").As("exp")), + want: map[string]interface{}{"exp": math.Exp(4.5)}, + }, + } + + ctx := context.Background() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + iter := test.pipeline.Execute(ctx) + defer iter.Stop() + + docs, err := iter.GetAll() + if err != nil { + t.Fatalf("GetAll: %v", err) + } + if len(docs) != 1 { + t.Fatalf("expected 1 doc, got %d", len(docs)) + } + got := docs[0].Data() + if diff := testutil.Diff(got, test.want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", got, test.want, diff) + } + }) + } +} + +func aggregateFuncs(t *testing.T) { + t.Parallel() + h := testHelper{t} + client := integrationClient(t) + coll := client.Collection(collectionIDs.New()) + docRef1 := coll.NewDoc() + h.mustCreate(docRef1, map[string]interface{}{ + "a": 1, + }) + docRef2 := coll.NewDoc() + h.mustCreate(docRef2, map[string]interface{}{ + "a": 2, + }) + docRef3 := coll.NewDoc() + h.mustCreate(docRef3, map[string]interface{}{ + "b": 2, + }) + defer deleteDocuments([]*DocumentRef{docRef1, docRef2, docRef3}) + + tests := []struct { + name string + pipeline *Pipeline + want map[string]interface{} + }{ + { + name: "Sum - fieldname arg", + pipeline: client.Pipeline(). + Collection(coll.ID). + Aggregate(Sum("a").As("sum_a")), + want: map[string]interface{}{"sum_a": int64(3)}, + }, + { + name: "Sum - fieldpath arg", + pipeline: client.Pipeline(). + Collection(coll.ID). + Aggregate(Sum(FieldPath([]string{"a"})).As("sum_a")), + want: map[string]interface{}{"sum_a": int64(3)}, + }, + { + name: "Sum - FieldOf Expr", + pipeline: client.Pipeline(). + Collection(coll.ID). + Aggregate(Sum(FieldOf("a")).As("sum_a")), + want: map[string]interface{}{"sum_a": int64(3)}, + }, + { + name: "Sum - FieldOfPath Expr", + pipeline: client.Pipeline(). + Collection(coll.ID). + Aggregate(Sum(FieldOfPath(FieldPath([]string{"a"}))).As("sum_a")), + want: map[string]interface{}{"sum_a": int64(3)}, + }, + { + name: "Avg", + pipeline: client.Pipeline(). + Collection(coll.ID). + Aggregate(Average("a").As("avg_a")), + want: map[string]interface{}{"avg_a": float64(1.5)}, + }, + { + name: "Count", + pipeline: client.Pipeline(). + Collection(coll.ID). + Aggregate(Count("a").As("count_a")), + want: map[string]interface{}{"count_a": int64(2)}, + }, + { + name: "CountAll", + pipeline: client.Pipeline(). + Collection(coll.ID). + Aggregate(CountAll().As("count_all")), + want: map[string]interface{}{"count_all": int64(3)}, + }, + } + + ctx := context.Background() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + iter := test.pipeline.Execute(ctx) + defer iter.Stop() + + docs, err := iter.GetAll() + if err != nil { + t.Fatalf("GetAll: %v", err) + } + if len(docs) != 1 { + t.Fatalf("expected 1 doc, got %d", len(docs)) + } + got := docs[0].Data() + if diff := testutil.Diff(got, test.want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", got, test.want, diff) + } + }) + } +} + +func comparisonFuncs(t *testing.T) { + t.Parallel() + ctx := context.Background() + client := integrationClient(t) + now := time.Now() + coll := client.Collection(collectionIDs.New()) + doc1data := map[string]interface{}{ + "timestamp": now, + "a": 1, + "b": 2, + "c": -3, + "d": 4.5, + "e": -5.5, + } + _, err := coll.Doc("doc1").Create(ctx, doc1data) + if err != nil { + t.Fatalf("Create: %v", err) + } + doc2data := map[string]interface{}{ + "timestamp": now, + "a": 2, + "b": 2, + "c": -3, + "d": 4.5, + "e": -5.5, + } + _, err = coll.Doc("doc2").Create(ctx, doc2data) + if err != nil { + t.Fatalf("Create: %v", err) + } + defer deleteDocuments([]*DocumentRef{coll.Doc("doc1"), coll.Doc("doc2")}) + + doc1want := map[string]interface{}{"a": int64(1), "b": int64(2), "c": int64(-3), "d": float64(4.5), "e": float64(-5.5), "timestamp": now.Truncate(time.Microsecond)} + + tests := []struct { + name string + pipeline *Pipeline + want []map[string]interface{} + }{ + { + name: "Equal", + pipeline: client.Pipeline(). + Collection(coll.ID). + Where(Equal("a", 1)), + want: []map[string]interface{}{doc1want}, + }, + { + name: "NotEqual", + pipeline: client.Pipeline(). + Collection(coll.ID). + Where(NotEqual("a", 2)), + want: []map[string]interface{}{doc1want}, + }, + { + name: "LessThan", + pipeline: client.Pipeline(). + Collection(coll.ID). + Where(LessThan("a", 2)), + want: []map[string]interface{}{doc1want}, + }, + { + name: "Equivalent", + pipeline: client.Pipeline(). + Collection(coll.ID). + Where(Equivalent("a", 1)), + want: []map[string]interface{}{doc1want}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + iter := test.pipeline.Execute(ctx) + defer iter.Stop() + + docs, err := iter.GetAll() + if err != nil { + t.Fatalf("GetAll: %v", err) + } + if len(docs) != len(test.want) { + t.Fatalf("expected %d doc(s), got %d", len(test.want), len(docs)) + } + + var gots []map[string]interface{} + for _, doc := range docs { + got := doc.Data() + if ts, ok := got["timestamp"].(time.Time); ok { + got["timestamp"] = ts.Truncate(time.Microsecond) + } + gots = append(gots, got) + } + + if diff := testutil.Diff(gots, test.want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", gots, test.want, diff) + } + }) + } +} diff --git a/firestore/pipeline_utils.go b/firestore/pipeline_utils.go index 2ced52e8c79e..0c2a49b61018 100644 --- a/firestore/pipeline_utils.go +++ b/firestore/pipeline_utils.go @@ -24,7 +24,7 @@ import ( // newFieldAndArrayBooleanExpr creates a new BooleanExpr for functions that operate on a field/expression and an array of values. func newFieldAndArrayBooleanExpr(name string, exprOrFieldPath any, values any) BooleanExpr { - return &baseBooleanExpr{baseFunction: newBaseFunction(name, []Expr{toExprOrField(exprOrFieldPath), asArrayFunctionExpr(values)})} + return &baseBooleanExpr{baseFunction: newBaseFunction(name, []Expr{asFieldExpr(exprOrFieldPath), asArrayFunctionExpr(values)})} } // toExprs converts a plain Go value or an existing Expr into an Expr. @@ -125,28 +125,6 @@ func asFieldExpr(val any) Expr { } } -func asInt64Expr(val any) Expr { - switch v := val.(type) { - case Expr: - return v - case int, int32, int64: - return ConstantOf(v) - default: - return &baseExpr{err: fmt.Errorf("firestore: value must be a int, int32, int64 or Expr, but got %T", val)} - } -} - -func asStringExpr(val any) Expr { - switch v := val.(type) { - case Expr: - return v - case string: - return ConstantOf(v) - default: - return &baseExpr{err: fmt.Errorf("firestore: value must be a string or Expr, but got %T", val)} - } -} - // leftRightToBaseFunction is a helper for creating binary functions like Add or Eq. // It ensures the left operand is a field-like expression and the right is a constant-like expression. func leftRightToBaseFunction(name string, left, right any) *baseFunction { diff --git a/firestore/util_test.go b/firestore/util_test.go index aeb68badb8c9..004e0f3b4014 100644 --- a/firestore/util_test.go +++ b/firestore/util_test.go @@ -147,10 +147,7 @@ func refval(path string) *pb.Value { func docsToMaps(t *testing.T, docs []*PipelineResult) []map[string]interface{} { var maps []map[string]interface{} for _, doc := range docs { - data, err := doc.Data() - if err != nil { - t.Fatalf("Data: %v", err) - } + data := doc.Data() maps = append(maps, data) } return maps From 27ddeee6229aea6153ce2112dc8a40089f3c513b Mon Sep 17 00:00:00 2001 From: Baha Aiman Date: Thu, 30 Oct 2025 22:06:34 +0000 Subject: [PATCH 2/4] refactor(firestore): Combine FieldOf and FieldOfPath --- firestore/pipeline.go | 6 +++--- firestore/pipeline_aggregate.go | 2 +- firestore/pipeline_field.go | 20 +++++++++++--------- firestore/pipeline_integration_test.go | 4 ++-- firestore/pipeline_stage.go | 8 ++++---- firestore/pipeline_utils.go | 4 ++-- 6 files changed, 23 insertions(+), 21 deletions(-) diff --git a/firestore/pipeline.go b/firestore/pipeline.go index a7048d496c30..830eff99a611 100644 --- a/firestore/pipeline.go +++ b/firestore/pipeline.go @@ -166,8 +166,8 @@ func (p *Pipeline) Offset(offset int) *Pipeline { // // client.Pipeline().Collection("users").Select("info.email") // client.Pipeline().Collection("users").Select(FieldOf("info.email")) -// client.Pipeline().Collection("users").Select(FieldOfPath([]string{"info", "email"})) -// client.Pipeline().Collection("users").Select(FieldOfPath([]string{"info", "email"})) +// client.Pipeline().Collection("users").Select(FieldOf([]string{"info", "email"})) +// client.Pipeline().Collection("users").Select(FieldOf([]string{"info", "email"})) // client.Pipeline().Collection("users").Select(Add("age", 5).As("agePlus5")) func (p *Pipeline) Select(fieldpathsOrSelectables ...any) *Pipeline { if p.err != nil { @@ -341,7 +341,7 @@ func (p *Pipeline) UnnestWithAlias(fieldpath any, alias string, opts *UnnestOpti case string: fieldExpr = FieldOf(v) case FieldPath: - fieldExpr = FieldOfPath(v) + fieldExpr = FieldOf(v) default: p.err = errInvalidArg(fieldpath, "string", "FieldPath") return p diff --git a/firestore/pipeline_aggregate.go b/firestore/pipeline_aggregate.go index 0143b39c2a55..f9cc8a68ed74 100644 --- a/firestore/pipeline_aggregate.go +++ b/firestore/pipeline_aggregate.go @@ -44,7 +44,7 @@ func newBaseAggregateFunction(name string, fieldOrExpr any) *baseAggregateFuncti case string: valueExpr = FieldOf(value) case FieldPath: - valueExpr = FieldOfPath(value) + valueExpr = FieldOf(value) case Expr: valueExpr = value default: diff --git a/firestore/pipeline_field.go b/firestore/pipeline_field.go index fa1edb6f7740..9c1d03c0d3c1 100644 --- a/firestore/pipeline_field.go +++ b/firestore/pipeline_field.go @@ -29,20 +29,22 @@ type field struct { } // FieldOf creates a new field [Expr] from a field path string. -func FieldOf(path string) Expr { - fieldPath, err := parseDotSeparatedString(path) - if err != nil { - return &field{baseExpr: &baseExpr{err: err}} +func FieldOf[T string | FieldPath](path T) Expr { + anyPath := any(path) + var fieldPath FieldPath + var err error + if v, ok := anyPath.(string); ok { + fieldPath, err = parseDotSeparatedString(v) + if err != nil { + return &field{baseExpr: &baseExpr{err: err}} + } + } else { + fieldPath = anyPath.(FieldPath) } - return FieldOfPath(fieldPath) -} -// FieldOfPath creates a new field [Expr] for the given [FieldPath]. -func FieldOfPath(fieldPath FieldPath) Expr { if err := fieldPath.validate(); err != nil { return &field{baseExpr: &baseExpr{err: err}} } - pbVal := &pb.Value{ ValueType: &pb.Value_FieldReferenceValue{ FieldReferenceValue: fieldPath.toServiceFieldPath(), diff --git a/firestore/pipeline_integration_test.go b/firestore/pipeline_integration_test.go index 5b6c98e6e57b..f3936349ef0c 100644 --- a/firestore/pipeline_integration_test.go +++ b/firestore/pipeline_integration_test.go @@ -868,10 +868,10 @@ func aggregateFuncs(t *testing.T) { want: map[string]interface{}{"sum_a": int64(3)}, }, { - name: "Sum - FieldOfPath Expr", + name: "Sum - FieldOf Path Expr", pipeline: client.Pipeline(). Collection(coll.ID). - Aggregate(Sum(FieldOfPath(FieldPath([]string{"a"}))).As("sum_a")), + Aggregate(Sum(FieldOf(FieldPath([]string{"a"}))).As("sum_a")), want: map[string]interface{}{"sum_a": int64(3)}, }, { diff --git a/firestore/pipeline_stage.go b/firestore/pipeline_stage.go index 0df166aae45f..e05aac5cfd56 100644 --- a/firestore/pipeline_stage.go +++ b/firestore/pipeline_stage.go @@ -213,7 +213,7 @@ func newFindNearestStage(vectorField any, queryVector any, measure PipelineDista case string: propertyExpr = FieldOf(v) case FieldPath: - propertyExpr = FieldOfPath(v) + propertyExpr = FieldOf(v) case Expr: propertyExpr = v default: @@ -300,7 +300,7 @@ func newRemoveFieldsStage(fieldpaths ...any) (*removeFieldsStage, error) { case string: fields[i] = FieldOf(v) case FieldPath: - fields[i] = FieldOfPath(v) + fields[i] = FieldOf(v) default: return nil, errInvalidArg(fp, "string", "FieldPath") } @@ -332,7 +332,7 @@ func newReplaceStage(fieldOrSelectable any) (*replaceStage, error) { case string: expr = FieldOf(v) case FieldPath: - expr = FieldOfPath(v) + expr = FieldOf(v) case Selectable: _, expr = v.getSelectionDetails() default: @@ -463,7 +463,7 @@ func newUnnestStage(fieldExpr Expr, alias string, opts *UnnestOptions) (*unnestS var indexFieldExpr Expr switch v := opts.IndexField.(type) { case FieldPath: - indexFieldExpr = FieldOfPath(v) + indexFieldExpr = FieldOf(v) case string: indexFieldExpr = FieldOf(v) default: diff --git a/firestore/pipeline_utils.go b/firestore/pipeline_utils.go index 0c2a49b61018..3cf0d0982eac 100644 --- a/firestore/pipeline_utils.go +++ b/firestore/pipeline_utils.go @@ -117,7 +117,7 @@ func asFieldExpr(val any) Expr { case Expr: return v case FieldPath: - return FieldOfPath(v) + return FieldOf(v) case string: return FieldOf(v) default: @@ -191,7 +191,7 @@ func fieldsOrSelectablesToSelectables(fieldsOrSelectables ...any) ([]Selectable, } s = FieldOf(v).(*field) case FieldPath: - s = FieldOfPath(v).(*field) + s = FieldOf(v).(*field) case Selectable: s = v default: From 1a52f88cc0e1e9607967bce8594e628c53e6644a Mon Sep 17 00:00:00 2001 From: Baha Aiman Date: Thu, 30 Oct 2025 22:15:29 +0000 Subject: [PATCH 3/4] correct comment --- firestore/pipeline_field.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/firestore/pipeline_field.go b/firestore/pipeline_field.go index 9c1d03c0d3c1..d66627642c31 100644 --- a/firestore/pipeline_field.go +++ b/firestore/pipeline_field.go @@ -28,7 +28,7 @@ type field struct { fieldPath FieldPath } -// FieldOf creates a new field [Expr] from a field path string. +// FieldOf creates a new field [Expr] from a dot separated field path string or [FieldPath]. func FieldOf[T string | FieldPath](path T) Expr { anyPath := any(path) var fieldPath FieldPath From 4da1686c528f7028092c300d1ddf211d7e5d42bc Mon Sep 17 00:00:00 2001 From: Baha Aiman Date: Thu, 30 Oct 2025 15:16:26 -0700 Subject: [PATCH 4/4] use type switch Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- firestore/pipeline_field.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/firestore/pipeline_field.go b/firestore/pipeline_field.go index d66627642c31..578e7ede37e2 100644 --- a/firestore/pipeline_field.go +++ b/firestore/pipeline_field.go @@ -30,16 +30,16 @@ type field struct { // FieldOf creates a new field [Expr] from a dot separated field path string or [FieldPath]. func FieldOf[T string | FieldPath](path T) Expr { - anyPath := any(path) var fieldPath FieldPath - var err error - if v, ok := anyPath.(string); ok { - fieldPath, err = parseDotSeparatedString(v) + switch p := any(path).(type) { + case string: + fp, err := parseDotSeparatedString(p) if err != nil { return &field{baseExpr: &baseExpr{err: err}} } - } else { - fieldPath = anyPath.(FieldPath) + fieldPath = fp + case FieldPath: + fieldPath = p } if err := fieldPath.validate(); err != nil {