diff --git a/firestore/integration_test.go b/firestore/integration_test.go index 044292d80700..ab01ad6b2a62 100644 --- a/firestore/integration_test.go +++ b/firestore/integration_test.go @@ -4117,511 +4117,3 @@ func TestIntegration_PipelineStages(t *testing.T) { } }) } - -func TestIntegration_PipelineFunctions(t *testing.T) { - if testParams[firestoreEditionKey].(firestoreEdition) != editionEnterprise { - t.Skip("Skipping pipeline queries tests since the firestore edition of", testParams[databaseIDKey].(string), "database is not enterprise") - } - t.Run("timestampFuncs", timestampFuncs) - t.Run("arithmeticFuncs", arithmeticFuncs) - t.Run("aggregateFuncs", aggregateFuncs) - t.Run("comparisonFuncs", comparisonFuncs) -} - -func timestampFuncs(t *testing.T) { - t.Parallel() - client := integrationClient(t) - coll := client.Collection(collectionIDs.New()) - h := testHelper{t} - now := time.Now() - docRef1 := coll.NewDoc() - h.mustCreate(docRef1, map[string]interface{}{ - "timestamp": now, - "unixMicros": now.UnixNano() / 1000, - "unixMillis": now.UnixNano() / 1e6, - "unixSeconds": now.Unix(), - }) - defer deleteDocuments([]*DocumentRef{docRef1}) - - tests := []struct { - name string - pipeline *Pipeline - want map[string]interface{} - }{ - { - name: "TimestampAdd day", - pipeline: client.Pipeline(). - Collection(coll.ID). - Select(TimestampAdd("timestamp", "day", 1).As("timestamp_plus_day")), - want: map[string]interface{}{"timestamp_plus_day": now.AddDate(0, 0, 1).Truncate(time.Microsecond)}, - }, - { - name: "TimestampAdd hour", - pipeline: client.Pipeline(). - Collection(coll.ID). - Select(TimestampAdd("timestamp", "hour", 1).As("timestamp_plus_hour")), - want: map[string]interface{}{"timestamp_plus_hour": now.Add(time.Hour).Truncate(time.Microsecond)}, - }, - { - name: "TimestampAdd minute", - pipeline: client.Pipeline(). - Collection(coll.ID). - Select(TimestampAdd("timestamp", "minute", 1).As("timestamp_plus_minute")), - want: map[string]interface{}{"timestamp_plus_minute": now.Add(time.Minute).Truncate(time.Microsecond)}, - }, - { - name: "TimestampAdd second", - pipeline: client.Pipeline(). - Collection(coll.ID). - Select(TimestampAdd("timestamp", "second", 1).As("timestamp_plus_second")), - want: map[string]interface{}{"timestamp_plus_second": now.Add(time.Second).Truncate(time.Microsecond)}, - }, - { - name: "TimestampSubtract", - pipeline: client.Pipeline(). - Collection(coll.ID). - Select(TimestampSubtract("timestamp", "hour", 1).As("timestamp_minus_hour")), - want: map[string]interface{}{"timestamp_minus_hour": now.Add(-time.Hour).Truncate(time.Microsecond)}, - }, - { - name: "TimestampToUnixMicros", - pipeline: client.Pipeline(). - Collection(coll.ID). - Select(FieldOf("timestamp").TimestampToUnixMicros().As("timestamp_micros")), - want: map[string]interface{}{"timestamp_micros": now.UnixNano() / 1000}, - }, - { - name: "TimestampToUnixMillis", - pipeline: client.Pipeline(). - Collection(coll.ID). - Select(FieldOf("timestamp").TimestampToUnixMillis().As("timestamp_millis")), - want: map[string]interface{}{"timestamp_millis": now.UnixNano() / 1e6}, - }, - { - name: "TimestampToUnixSeconds", - pipeline: client.Pipeline(). - Collection(coll.ID). - Select(FieldOf("timestamp").TimestampToUnixSeconds().As("timestamp_seconds")), - want: map[string]interface{}{"timestamp_seconds": now.Unix()}, - }, - { - name: "UnixMicrosToTimestamp - constant", - pipeline: client.Pipeline(). - Collection(coll.ID). - Select(UnixMicrosToTimestamp(ConstantOf(now.UnixNano() / 1000)).As("timestamp_from_micros")), - want: map[string]interface{}{"timestamp_from_micros": now.Truncate(time.Microsecond)}, - }, - { - name: "UnixMicrosToTimestamp - fieldname", - pipeline: client.Pipeline(). - Collection(coll.ID). - Select(UnixMicrosToTimestamp("unixMicros").As("timestamp_from_micros")), - want: map[string]interface{}{"timestamp_from_micros": now.Truncate(time.Microsecond)}, - }, - { - name: "UnixMillisToTimestamp", - pipeline: client.Pipeline(). - Collection(coll.ID). - Select(UnixMillisToTimestamp(ConstantOf(now.UnixNano() / 1e6)).As("timestamp_from_millis")), - want: map[string]interface{}{"timestamp_from_millis": now.Truncate(time.Millisecond)}, - }, - { - name: "UnixSecondsToTimestamp", - pipeline: client.Pipeline(). - Collection(coll.ID). - Select(UnixSecondsToTimestamp("unixSeconds").As("timestamp_from_seconds")), - want: map[string]interface{}{"timestamp_from_seconds": now.Truncate(time.Second)}, - }, - { - name: "CurrentTimestamp", - pipeline: client.Pipeline(). - Collection(coll.ID). - Select(CurrentTimestamp().As("current_timestamp")), - want: map[string]interface{}{"current_timestamp": time.Now().Truncate(time.Microsecond)}, - }, - } - - ctx := context.Background() - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - iter := test.pipeline.Execute(ctx) - defer iter.Stop() - - docs, err := iter.GetAll() - if err != nil { - t.Fatalf("GetAll: %v", err) - } - if len(docs) != 1 { - t.Fatalf("expected 1 doc, got %d", len(docs)) - } - got, err := docs[0].Data() - if err != nil { - t.Fatalf("Data: %v", err) - } - - margin := 0 * time.Microsecond - if test.name == "CurrentTimestamp" { - margin = 5 * time.Second - } - if diff := testutil.Diff(got, test.want, cmpopts.EquateApproxTime(margin)); diff != "" { - t.Errorf("got: %v, want: %v, diff: %s", got, test.want, diff) - } - }) - } -} - -func arithmeticFuncs(t *testing.T) { - t.Parallel() - h := testHelper{t} - client := integrationClient(t) - coll := client.Collection(collectionIDs.New()) - docRef1 := coll.NewDoc() - h.mustCreate(docRef1, map[string]interface{}{ - "a": int(1), - "b": int(2), - "c": -3, - "d": 4.5, - "e": -5.5, - }) - defer deleteDocuments([]*DocumentRef{docRef1}) - - tests := []struct { - name string - pipeline *Pipeline - want map[string]interface{} - }{ - { - name: "Add - left FieldOf, right FieldOf", - pipeline: client.Pipeline().Collection(coll.ID).Select(Add(FieldOf("a"), FieldOf("b")).As("add")), - want: map[string]interface{}{"add": int64(3)}, - }, - { - name: "Add - left FieldOf, right ConstantOf", - pipeline: client.Pipeline().Collection(coll.ID).Select(Add(FieldOf("a"), ConstantOf(2)).As("add")), - want: map[string]interface{}{"add": int64(3)}, - }, - { - name: "Add - left FieldOf, right constant", - pipeline: client.Pipeline().Collection(coll.ID).Select(Add(FieldOf("a"), 5).As("add")), - want: map[string]interface{}{"add": int64(6)}, - }, - { - name: "Add - left fieldname, right constant", - pipeline: client.Pipeline().Collection(coll.ID).Select(Add("a", 5).As("add")), - want: map[string]interface{}{"add": int64(6)}, - }, - { - name: "Add - left fieldpath, right constant", - pipeline: client.Pipeline().Collection(coll.ID).Select(Add(FieldPath([]string{"a"}), 5).As("add")), - want: map[string]interface{}{"add": int64(6)}, - }, - { - name: "Add - left fieldpath, right expression", - pipeline: client.Pipeline().Collection(coll.ID).Select(Add(FieldPath([]string{"a"}), Add(FieldOf("b"), FieldOf("d"))).As("add")), - want: map[string]interface{}{"add": float64(7.5)}, - }, - { - name: "Subtract", - pipeline: client.Pipeline().Collection(coll.ID).Select(Subtract("a", FieldOf("b")).As("subtract")), - want: map[string]interface{}{"subtract": int64(-1)}, - }, - { - name: "Multiply", - pipeline: client.Pipeline().Collection(coll.ID).Select(Multiply("a", 5).As("multiply")), - want: map[string]interface{}{"multiply": int64(5)}, - }, - { - name: "Divide", - pipeline: client.Pipeline().Collection(coll.ID).Select(Divide("a", FieldOf("d")).As("divide")), - want: map[string]interface{}{"divide": float64(1 / 4.5)}, - }, - { - name: "Mod", - pipeline: client.Pipeline().Collection(coll.ID).Select(Mod("a", FieldOf("b")).As("mod")), - want: map[string]interface{}{"mod": int64(1)}, - }, - { - name: "Pow", - pipeline: client.Pipeline().Collection(coll.ID).Select(Pow("a", FieldOf("b")).As("pow")), - want: map[string]interface{}{"pow": float64(1)}, - }, - { - name: "Abs - fieldname", - pipeline: client.Pipeline().Collection(coll.ID).Select(Abs("c").As("abs")), - want: map[string]interface{}{"abs": int64(3)}, - }, - { - name: "Abs - fieldPath", - pipeline: client.Pipeline().Collection(coll.ID).Select(Abs(FieldPath([]string{"c"})).As("abs")), - want: map[string]interface{}{"abs": int64(3)}, - }, - { - name: "Abs - Expr", - pipeline: client.Pipeline().Collection(coll.ID).Select(Abs(Add(FieldOf("b"), FieldOf("d"))).As("abs")), - want: map[string]interface{}{"abs": float64(6.5)}, - }, - { - name: "Ceil", - pipeline: client.Pipeline().Collection(coll.ID).Select(Ceil("d").As("ceil")), - want: map[string]interface{}{"ceil": float64(5)}, - }, - { - name: "Floor", - pipeline: client.Pipeline().Collection(coll.ID).Select(Floor("d").As("floor")), - want: map[string]interface{}{"floor": float64(4)}, - }, - { - name: "Round", - pipeline: client.Pipeline().Collection(coll.ID).Select(Round("d").As("round")), - want: map[string]interface{}{"round": float64(5)}, - }, - { - name: "Sqrt", - pipeline: client.Pipeline().Collection(coll.ID).Select(Sqrt("d").As("sqrt")), - want: map[string]interface{}{"sqrt": math.Sqrt(4.5)}, - }, - { - name: "Log", - pipeline: client.Pipeline().Collection(coll.ID).Select(Log("d", 2).As("log")), - want: map[string]interface{}{"log": math.Log2(4.5)}, - }, - { - name: "Log10", - pipeline: client.Pipeline().Collection(coll.ID).Select(Log10("d").As("log10")), - want: map[string]interface{}{"log10": math.Log10(4.5)}, - }, - { - name: "Ln", - pipeline: client.Pipeline().Collection(coll.ID).Select(Ln("d").As("ln")), - want: map[string]interface{}{"ln": math.Log(4.5)}, - }, - { - name: "Exp", - pipeline: client.Pipeline().Collection(coll.ID).Select(Exp("d").As("exp")), - want: map[string]interface{}{"exp": math.Exp(4.5)}, - }, - } - - ctx := context.Background() - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - iter := test.pipeline.Execute(ctx) - defer iter.Stop() - - docs, err := iter.GetAll() - if err != nil { - t.Fatalf("GetAll: %v", err) - } - if len(docs) != 1 { - t.Fatalf("expected 1 doc, got %d", len(docs)) - } - got, err := docs[0].Data() - if err != nil { - t.Fatalf("Data: %v", err) - } - if diff := testutil.Diff(got, test.want); diff != "" { - t.Errorf("got: %v, want: %v, diff +want -got: %s", got, test.want, diff) - } - }) - } -} - -func aggregateFuncs(t *testing.T) { - t.Parallel() - h := testHelper{t} - client := integrationClient(t) - coll := client.Collection(collectionIDs.New()) - docRef1 := coll.NewDoc() - h.mustCreate(docRef1, map[string]interface{}{ - "a": 1, - }) - docRef2 := coll.NewDoc() - h.mustCreate(docRef2, map[string]interface{}{ - "a": 2, - }) - docRef3 := coll.NewDoc() - h.mustCreate(docRef3, map[string]interface{}{ - "b": 2, - }) - defer deleteDocuments([]*DocumentRef{docRef1, docRef2, docRef3}) - - tests := []struct { - name string - pipeline *Pipeline - want map[string]interface{} - }{ - { - name: "Sum - fieldname arg", - pipeline: client.Pipeline(). - Collection(coll.ID). - Aggregate(Sum("a").As("sum_a")), - want: map[string]interface{}{"sum_a": int64(3)}, - }, - { - name: "Sum - fieldpath arg", - pipeline: client.Pipeline(). - Collection(coll.ID). - Aggregate(Sum(FieldPath([]string{"a"})).As("sum_a")), - want: map[string]interface{}{"sum_a": int64(3)}, - }, - { - name: "Sum - FieldOf Expr", - pipeline: client.Pipeline(). - Collection(coll.ID). - Aggregate(Sum(FieldOf("a")).As("sum_a")), - want: map[string]interface{}{"sum_a": int64(3)}, - }, - { - name: "Sum - FieldOfPath Expr", - pipeline: client.Pipeline(). - Collection(coll.ID). - Aggregate(Sum(FieldOfPath(FieldPath([]string{"a"}))).As("sum_a")), - want: map[string]interface{}{"sum_a": int64(3)}, - }, - { - name: "Avg", - pipeline: client.Pipeline(). - Collection(coll.ID). - Aggregate(Average("a").As("avg_a")), - want: map[string]interface{}{"avg_a": float64(1.5)}, - }, - { - name: "Count", - pipeline: client.Pipeline(). - Collection(coll.ID). - Aggregate(Count("a").As("count_a")), - want: map[string]interface{}{"count_a": int64(2)}, - }, - { - name: "CountAll", - pipeline: client.Pipeline(). - Collection(coll.ID). - Aggregate(CountAll().As("count_all")), - want: map[string]interface{}{"count_all": int64(3)}, - }, - } - - ctx := context.Background() - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - iter := test.pipeline.Execute(ctx) - defer iter.Stop() - - docs, err := iter.GetAll() - if err != nil { - t.Fatalf("GetAll: %v", err) - } - if len(docs) != 1 { - t.Fatalf("expected 1 doc, got %d", len(docs)) - } - got, err := docs[0].Data() - if err != nil { - t.Fatalf("Data: %v", err) - } - if diff := testutil.Diff(got, test.want); diff != "" { - t.Errorf("got: %v, want: %v, diff +want -got: %s", got, test.want, diff) - } - }) - } -} - -func comparisonFuncs(t *testing.T) { - t.Parallel() - ctx := context.Background() - client := integrationClient(t) - now := time.Now() - coll := client.Collection(collectionIDs.New()) - doc1data := map[string]interface{}{ - "timestamp": now, - "a": 1, - "b": 2, - "c": -3, - "d": 4.5, - "e": -5.5, - } - _, err := coll.Doc("doc1").Create(ctx, doc1data) - if err != nil { - t.Fatalf("Create: %v", err) - } - doc2data := map[string]interface{}{ - "timestamp": now, - "a": 2, - "b": 2, - "c": -3, - "d": 4.5, - "e": -5.5, - } - _, err = coll.Doc("doc2").Create(ctx, doc2data) - if err != nil { - t.Fatalf("Create: %v", err) - } - defer deleteDocuments([]*DocumentRef{coll.Doc("doc1"), coll.Doc("doc2")}) - - doc1want := map[string]interface{}{"a": int64(1), "b": int64(2), "c": int64(-3), "d": float64(4.5), "e": float64(-5.5), "timestamp": now.Truncate(time.Microsecond)} - - tests := []struct { - name string - pipeline *Pipeline - want []map[string]interface{} - }{ - { - name: "Equal", - pipeline: client.Pipeline(). - Collection(coll.ID). - Where(Equal("a", 1)), - want: []map[string]interface{}{doc1want}, - }, - { - name: "NotEqual", - pipeline: client.Pipeline(). - Collection(coll.ID). - Where(NotEqual("a", 2)), - want: []map[string]interface{}{doc1want}, - }, - { - name: "LessThan", - pipeline: client.Pipeline(). - Collection(coll.ID). - Where(LessThan("a", 2)), - want: []map[string]interface{}{doc1want}, - }, - { - name: "Equivalent", - pipeline: client.Pipeline(). - Collection(coll.ID). - Where(Equivalent("a", 1)), - want: []map[string]interface{}{doc1want}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - iter := test.pipeline.Execute(ctx) - defer iter.Stop() - - docs, err := iter.GetAll() - if err != nil { - t.Fatalf("GetAll: %v", err) - } - if len(docs) != len(test.want) { - t.Fatalf("expected %d doc(s), got %d", len(test.want), len(docs)) - } - - var gots []map[string]interface{} - for _, doc := range docs { - got, err := doc.Data() - if err != nil { - t.Fatalf("Data: %v", err) - } - if ts, ok := got["timestamp"].(time.Time); ok { - got["timestamp"] = ts.Truncate(time.Microsecond) - } - gots = append(gots, got) - } - - if diff := testutil.Diff(gots, test.want); diff != "" { - t.Errorf("got: %v, want: %v, diff +want -got: %s", gots, test.want, diff) - } - }) - } -} diff --git a/firestore/pipeline_filter_condition.go b/firestore/pipeline_filter_condition.go index 618f6b7a75b6..e9a1c25dedad 100644 --- a/firestore/pipeline_filter_condition.go +++ b/firestore/pipeline_filter_condition.go @@ -39,7 +39,7 @@ var _ BooleanExpr = (*baseBooleanExpr)(nil) // // Check if the 'tags' array contains "Go". // ArrayContains("tags", "Go") func ArrayContains(exprOrFieldPath any, value any) BooleanExpr { - return &baseBooleanExpr{baseFunction: newBaseFunction("array_contains", []Expr{toExprOrField(exprOrFieldPath), toExprOrConstant(value)})} + return &baseBooleanExpr{baseFunction: newBaseFunction("array_contains", []Expr{asFieldExpr(exprOrFieldPath), toExprOrConstant(value)})} } // ArrayContainsAll creates an expression that checks if an array contains all of the provided values. @@ -241,7 +241,7 @@ func Equivalent(left, right any) BooleanExpr { // // Check if the 'filename' field ends with ".go". // EndsWith("filename", ".go") func EndsWith(exprOrFieldPath any, suffix any) BooleanExpr { - return &baseBooleanExpr{baseFunction: newBaseFunction("ends_with", []Expr{toExprOrField(exprOrFieldPath), asStringExpr(suffix)})} + return &baseBooleanExpr{baseFunction: newBaseFunction("ends_with", []Expr{asFieldExpr(exprOrFieldPath), asStringExpr(suffix)})} } // Like creates an expression that performs a case-sensitive wildcard string comparison. @@ -253,7 +253,7 @@ func EndsWith(exprOrFieldPath any, suffix any) BooleanExpr { // // Check if the 'name' field starts with "G". // Like("name", "G%") func Like(exprOrFieldPath any, pattern any) BooleanExpr { - return &baseBooleanExpr{baseFunction: newBaseFunction("like", []Expr{toExprOrField(exprOrFieldPath), asStringExpr(pattern)})} + return &baseBooleanExpr{baseFunction: newBaseFunction("like", []Expr{asFieldExpr(exprOrFieldPath), asStringExpr(pattern)})} } // RegexContains creates an expression that checks if a string contains a match for a regular expression. @@ -265,7 +265,7 @@ func Like(exprOrFieldPath any, pattern any) BooleanExpr { // // Check if the 'email' field contains a gmail address. // RegexContains("email", "@gmail\\.com$") func RegexContains(exprOrFieldPath any, pattern any) BooleanExpr { - return &baseBooleanExpr{baseFunction: newBaseFunction("regex_contains", []Expr{toExprOrField(exprOrFieldPath), asStringExpr(pattern)})} + return &baseBooleanExpr{baseFunction: newBaseFunction("regex_contains", []Expr{asFieldExpr(exprOrFieldPath), asStringExpr(pattern)})} } // RegexMatch creates an expression that checks if a string matches a regular expression. @@ -277,7 +277,7 @@ func RegexContains(exprOrFieldPath any, pattern any) BooleanExpr { // // Check if the 'zip_code' field is a 5-digit number. // RegexMatch("zip_code", "^[0-9]{5}$") func RegexMatch(exprOrFieldPath any, pattern any) BooleanExpr { - return &baseBooleanExpr{baseFunction: newBaseFunction("regex_match", []Expr{toExprOrField(exprOrFieldPath), asStringExpr(pattern)})} + return &baseBooleanExpr{baseFunction: newBaseFunction("regex_match", []Expr{asFieldExpr(exprOrFieldPath), asStringExpr(pattern)})} } // StartsWith creates an expression that checks if a string field or expression starts with a given prefix. @@ -289,7 +289,7 @@ func RegexMatch(exprOrFieldPath any, pattern any) BooleanExpr { // // Check if the 'name' field starts with "Mr.". // StartsWith("name", "Mr.") func StartsWith(exprOrFieldPath any, prefix any) BooleanExpr { - return &baseBooleanExpr{baseFunction: newBaseFunction("starts_with", []Expr{toExprOrField(exprOrFieldPath), asStringExpr(prefix)})} + return &baseBooleanExpr{baseFunction: newBaseFunction("starts_with", []Expr{asFieldExpr(exprOrFieldPath), asStringExpr(prefix)})} } // StringContains creates an expression that checks if a string contains a specified substring. @@ -301,7 +301,7 @@ func StartsWith(exprOrFieldPath any, prefix any) BooleanExpr { // // Check if the 'description' field contains the word "Firestore". // StringContains("description", "Firestore") func StringContains(exprOrFieldPath any, substring any) BooleanExpr { - return &baseBooleanExpr{baseFunction: newBaseFunction("string_contains", []Expr{toExprOrField(exprOrFieldPath), asStringExpr(substring)})} + return &baseBooleanExpr{baseFunction: newBaseFunction("string_contains", []Expr{asFieldExpr(exprOrFieldPath), asStringExpr(substring)})} } // IsNaN creates a boolean expression that checks if a field or expression evaluates to NaN. @@ -313,7 +313,7 @@ func StringContains(exprOrFieldPath any, substring any) BooleanExpr { // // Check if the 'score' field is NaN. // IsNaN("score") func IsNaN(exprOrFieldPath any) BooleanExpr { - return &baseBooleanExpr{baseFunction: newBaseFunction("is_nan", []Expr{toExprOrField(exprOrFieldPath)})} + return &baseBooleanExpr{baseFunction: newBaseFunction("is_nan", []Expr{asFieldExpr(exprOrFieldPath)})} } // IsNotNaN creates a boolean expression that checks if a field or expression does not evaluate to NaN. @@ -325,7 +325,7 @@ func IsNaN(exprOrFieldPath any) BooleanExpr { // // Check if the 'score' field is not NaN. // IsNotNaN("score") func IsNotNaN(exprOrFieldPath any) BooleanExpr { - return &baseBooleanExpr{baseFunction: newBaseFunction("is_not_nan", []Expr{toExprOrField(exprOrFieldPath)})} + return &baseBooleanExpr{baseFunction: newBaseFunction("is_not_nan", []Expr{asFieldExpr(exprOrFieldPath)})} } // IsNull creates a boolean expression that checks if a field or expression evaluates to null. @@ -337,7 +337,7 @@ func IsNotNaN(exprOrFieldPath any) BooleanExpr { // // Check if the 'address' field is null. // IsNull("address") func IsNull(exprOrFieldPath any) BooleanExpr { - return &baseBooleanExpr{baseFunction: newBaseFunction("is_null", []Expr{toExprOrField(exprOrFieldPath)})} + return &baseBooleanExpr{baseFunction: newBaseFunction("is_null", []Expr{asFieldExpr(exprOrFieldPath)})} } // IsNotNull creates a boolean expression that checks if a field or expression does not evaluate to null. @@ -349,5 +349,5 @@ func IsNull(exprOrFieldPath any) BooleanExpr { // // Check if the 'address' field is not null. // IsNotNull("address") func IsNotNull(exprOrFieldPath any) BooleanExpr { - return &baseBooleanExpr{baseFunction: newBaseFunction("is_not_null", []Expr{toExprOrField(exprOrFieldPath)})} + return &baseBooleanExpr{baseFunction: newBaseFunction("is_not_null", []Expr{asFieldExpr(exprOrFieldPath)})} } diff --git a/firestore/pipeline_function.go b/firestore/pipeline_function.go index 3d095bf69a67..d55cf0ae682e 100644 --- a/firestore/pipeline_function.go +++ b/firestore/pipeline_function.go @@ -328,7 +328,7 @@ func CurrentTimestamp() Expr { // // Get the length of the 'tags' array field. // ArrayLength("tags") func ArrayLength(exprOrFieldPath any) Expr { - return newBaseFunction("array_length", []Expr{toExprOrField(exprOrFieldPath)}) + return newBaseFunction("array_length", []Expr{asFieldExpr(exprOrFieldPath)}) } // Array creates an expression that represents a Firestore array. @@ -359,7 +359,7 @@ func ArrayFromSlice[T any](elements []T) Expr { // // Get the first element of the 'tags' array field. // ArrayGet("tags", 0) func ArrayGet(exprOrFieldPath any, offset any) Expr { - return newBaseFunction("array_get", []Expr{toExprOrField(exprOrFieldPath), asInt64Expr(offset)}) + return newBaseFunction("array_get", []Expr{asFieldExpr(exprOrFieldPath), asInt64Expr(offset)}) } // ArrayReverse creates an expression that reverses the order of elements in an array. @@ -370,7 +370,7 @@ func ArrayGet(exprOrFieldPath any, offset any) Expr { // // Reverse the 'tags' array. // ArrayReverse("tags") func ArrayReverse(exprOrFieldPath any) Expr { - return newBaseFunction("array_reverse", []Expr{toExprOrField(exprOrFieldPath)}) + return newBaseFunction("array_reverse", []Expr{asFieldExpr(exprOrFieldPath)}) } // ArrayConcat creates an expression that concatenates multiple arrays into a single array. @@ -382,7 +382,7 @@ func ArrayReverse(exprOrFieldPath any) Expr { // // Concatenate the 'tags' and 'categories' array fields. // ArrayConcat("tags", FieldOf("categories")) func ArrayConcat(exprOrFieldPath any, otherArrays ...any) Expr { - return newBaseFunction("array_concat", append([]Expr{toExprOrField(exprOrFieldPath)}, toExprs(otherArrays)...)) + return newBaseFunction("array_concat", append([]Expr{asFieldExpr(exprOrFieldPath)}, toExprs(otherArrays)...)) } // ArraySum creates an expression that calculates the sum of all elements in a numeric array. @@ -393,7 +393,7 @@ func ArrayConcat(exprOrFieldPath any, otherArrays ...any) Expr { // // Calculate the sum of the 'scores' array. // ArraySum("scores") func ArraySum(exprOrFieldPath any) Expr { - return newBaseFunction("sum", []Expr{toExprOrField(exprOrFieldPath)}) + return newBaseFunction("sum", []Expr{asFieldExpr(exprOrFieldPath)}) } // ArrayMaximum creates an expression that finds the maximum element in a numeric array. @@ -404,7 +404,7 @@ func ArraySum(exprOrFieldPath any) Expr { // // Find the maximum value in the 'scores' array. // ArrayMaximum("scores") func ArrayMaximum(exprOrFieldPath any) Expr { - return newBaseFunction("maximum", []Expr{toExprOrField(exprOrFieldPath)}) + return newBaseFunction("maximum", []Expr{asFieldExpr(exprOrFieldPath)}) } // ArrayMinimum creates an expression that finds the minimum element in a numeric array. @@ -415,7 +415,7 @@ func ArrayMaximum(exprOrFieldPath any) Expr { // // Find the minimum value in the 'scores' array. // ArrayMinimum("scores") func ArrayMinimum(exprOrFieldPath any) Expr { - return newBaseFunction("minimum", []Expr{toExprOrField(exprOrFieldPath)}) + return newBaseFunction("minimum", []Expr{asFieldExpr(exprOrFieldPath)}) } // ByteLength creates an expression that calculates the length of a string represented by a field or [Expr] in UTF-8 @@ -427,7 +427,7 @@ func ArrayMinimum(exprOrFieldPath any) Expr { // // Get the byte length of the 'name' field. // ByteLength("name") func ByteLength(exprOrFieldPath any) Expr { - return newBaseFunction("byte_length", []Expr{toExprOrField(exprOrFieldPath)}) + return newBaseFunction("byte_length", []Expr{asFieldExpr(exprOrFieldPath)}) } // CharLength creates an expression that calculates the character length of a string field or expression in UTF8. @@ -438,7 +438,7 @@ func ByteLength(exprOrFieldPath any) Expr { // // Get the character length of the 'name' field. // CharLength("name") func CharLength(exprOrFieldPath any) Expr { - return newBaseFunction("char_length", []Expr{toExprOrField(exprOrFieldPath)}) + return newBaseFunction("char_length", []Expr{asFieldExpr(exprOrFieldPath)}) } // StringConcat creates an expression that concatenates multiple strings into a single string. @@ -450,7 +450,7 @@ func CharLength(exprOrFieldPath any) Expr { // // Concatenate first name and last name. // StringConcat(FieldOf("firstName"), " ", FieldOf("lastName")) func StringConcat(exprOrFieldPath any, otherStrings ...any) Expr { - return newBaseFunction("string_concat", append([]Expr{toExprOrField(exprOrFieldPath)}, toExprs(otherStrings)...)) + return newBaseFunction("string_concat", append([]Expr{asFieldExpr(exprOrFieldPath)}, toExprs(otherStrings)...)) } // StringReverse creates an expression that reverses a string. @@ -461,7 +461,7 @@ func StringConcat(exprOrFieldPath any, otherStrings ...any) Expr { // // Reverse the 'name' field. // StringReverse("name") func StringReverse(exprOrFieldPath any) Expr { - return newBaseFunction("string_reverse", []Expr{toExprOrField(exprOrFieldPath)}) + return newBaseFunction("string_reverse", []Expr{asFieldExpr(exprOrFieldPath)}) } // Join creates an expression that joins the elements of a string array into a single string. @@ -473,7 +473,7 @@ func StringReverse(exprOrFieldPath any) Expr { // // Join the 'tags' array with a comma and space. // Join("tags", ", ") func Join(exprOrFieldPath any, separator any) Expr { - return newBaseFunction("join", []Expr{toExprOrField(exprOrFieldPath), asStringExpr(separator)}) + return newBaseFunction("join", []Expr{asFieldExpr(exprOrFieldPath), asStringExpr(separator)}) } // Substring creates an expression that returns a substring of a string. @@ -486,7 +486,7 @@ func Join(exprOrFieldPath any, separator any) Expr { // // Get the first 5 characters of the 'description' field. // Substring("description", 0, 5) func Substring(exprOrFieldPath any, index any, offset any) Expr { - return newBaseFunction("substring", []Expr{toExprOrField(exprOrFieldPath), asInt64Expr(index), asInt64Expr(offset)}) + return newBaseFunction("substring", []Expr{asFieldExpr(exprOrFieldPath), asInt64Expr(index), asInt64Expr(offset)}) } // ToLower creates an expression that converts a string to lowercase. @@ -497,7 +497,7 @@ func Substring(exprOrFieldPath any, index any, offset any) Expr { // // Convert the 'username' to lowercase. // ToLower("username") func ToLower(exprOrFieldPath any) Expr { - return newBaseFunction("to_lower", []Expr{toExprOrField(exprOrFieldPath)}) + return newBaseFunction("to_lower", []Expr{asFieldExpr(exprOrFieldPath)}) } // ToUpper creates an expression that converts a string to uppercase. @@ -508,7 +508,7 @@ func ToLower(exprOrFieldPath any) Expr { // // Convert the 'product_code' to uppercase. // ToUpper("product_code") func ToUpper(exprOrFieldPath any) Expr { - return newBaseFunction("to_upper", []Expr{toExprOrField(exprOrFieldPath)}) + return newBaseFunction("to_upper", []Expr{asFieldExpr(exprOrFieldPath)}) } // Trim creates an expression that removes leading and trailing whitespace from a string. @@ -519,7 +519,7 @@ func ToUpper(exprOrFieldPath any) Expr { // // Trim the 'email' field. // Trim("email") func Trim(exprOrFieldPath any) Expr { - return newBaseFunction("trim", []Expr{toExprOrField(exprOrFieldPath)}) + return newBaseFunction("trim", []Expr{asFieldExpr(exprOrFieldPath)}) } // CosineDistance creates an expression that calculates the cosine distance between two vectors. @@ -531,7 +531,7 @@ func Trim(exprOrFieldPath any) Expr { // // Calculate the cosine distance between two vector fields. // CosineDistance("vector_field_1", FieldOf("vector_field_2")) func CosineDistance(vector1 any, vector2 any) Expr { - return newBaseFunction("cosine_distance", []Expr{toExprOrField(vector1), asVectorExpr(vector2)}) + return newBaseFunction("cosine_distance", []Expr{asFieldExpr(vector1), asVectorExpr(vector2)}) } // DotProduct creates an expression that calculates the dot product of two vectors. @@ -543,7 +543,7 @@ func CosineDistance(vector1 any, vector2 any) Expr { // // Calculate the dot product of two vector fields. // DotProduct("vector_field_1", FieldOf("vector_field_2")) func DotProduct(vector1 any, vector2 any) Expr { - return newBaseFunction("dot_product", []Expr{toExprOrField(vector1), asVectorExpr(vector2)}) + return newBaseFunction("dot_product", []Expr{asFieldExpr(vector1), asVectorExpr(vector2)}) } // EuclideanDistance creates an expression that calculates the euclidean distance between two vectors. @@ -555,7 +555,7 @@ func DotProduct(vector1 any, vector2 any) Expr { // // Calculate the euclidean distance between two vector fields. // EuclideanDistance("vector_field_1", FieldOf("vector_field_2")) func EuclideanDistance(vector1 any, vector2 any) Expr { - return newBaseFunction("euclidean_distance", []Expr{toExprOrField(vector1), asVectorExpr(vector2)}) + return newBaseFunction("euclidean_distance", []Expr{asFieldExpr(vector1), asVectorExpr(vector2)}) } // VectorLength creates an expression that calculates the length of a vector. @@ -566,5 +566,5 @@ func EuclideanDistance(vector1 any, vector2 any) Expr { // // Calculate the length of a vector field. // VectorLength("vector_field") func VectorLength(exprOrFieldPath any) Expr { - return newBaseFunction("vector_length", []Expr{toExprOrField(exprOrFieldPath)}) + return newBaseFunction("vector_length", []Expr{asFieldExpr(exprOrFieldPath)}) } diff --git a/firestore/pipeline_integration_test.go b/firestore/pipeline_integration_test.go index b8f9695a0235..5b6c98e6e57b 100644 --- a/firestore/pipeline_integration_test.go +++ b/firestore/pipeline_integration_test.go @@ -36,6 +36,11 @@ func TestIntegration_PipelineFunctions(t *testing.T) { t.Run("typeFuncs", typeFuncs) t.Run("vectorFuncs", vectorFuncs) + t.Run("timestampFuncs", timestampFuncs) + t.Run("arithmeticFuncs", arithmeticFuncs) + t.Run("aggregateFuncs", aggregateFuncs) + t.Run("comparisonFuncs", comparisonFuncs) + } func arrayFuncs(t *testing.T) { @@ -161,11 +166,7 @@ func arrayFuncs(t *testing.T) { r.Fatalf("expected 1 doc, got %d", len(docs)) return } - got, err := docs[0].Data() - if err != nil { - r.Fatalf("Data: %v", err) - return - } + got := docs[0].Data() if diff := testutil.Diff(got, test.want); diff != "" { r.Errorf("got: %v, want: %v, diff +want -got: %s", got, test.want, diff) return @@ -312,11 +313,7 @@ func stringFuncs(t *testing.T) { r.Fatalf("expected 1 doc, got %d", len(docs)) return } - got, err := docs[0].Data() - if err != nil { - r.Fatalf("Data: %v", err) - return - } + got := docs[0].Data() if diff := testutil.Diff(got, want); diff != "" { t.Errorf("got: %v, want: %v, diff +want -got: %s", got, want, diff) } @@ -332,11 +329,7 @@ func stringFuncs(t *testing.T) { } var gots []map[string]interface{} for _, doc := range docs { - got, err := doc.Data() - if err != nil { - r.Fatalf("Data: %v", err) - return - } + got := doc.Data() gots = append(gots, got) } if diff := testutil.Diff(gots, want); diff != "" { @@ -516,11 +509,7 @@ func vectorFuncs(t *testing.T) { r.Fatalf("expected 1 doc, got %d", len(docs)) return } - got, err := docs[0].Data() - if err != nil { - r.Fatalf("Data: %v", err) - return - } + got := docs[0].Data() if diff := testutil.Diff(got, test.want); diff != "" { r.Errorf("got: %v, want: %v, diff +want -got: %s", got, test.want, diff) } @@ -541,3 +530,488 @@ func isRetryablePipelineExecuteErr(err error) bool { strings.Contains(s.Message(), "Invalid request routing header") && strings.Contains(s.Message(), "Please fill in the request header with format") } + +func timestampFuncs(t *testing.T) { + t.Parallel() + client := integrationClient(t) + coll := client.Collection(collectionIDs.New()) + h := testHelper{t} + now := time.Now() + docRef1 := coll.NewDoc() + h.mustCreate(docRef1, map[string]interface{}{ + "timestamp": now, + "unixMicros": now.UnixNano() / 1000, + "unixMillis": now.UnixNano() / 1e6, + "unixSeconds": now.Unix(), + }) + defer deleteDocuments([]*DocumentRef{docRef1}) + + tests := []struct { + name string + pipeline *Pipeline + want map[string]interface{} + }{ + { + name: "TimestampAdd day", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(TimestampAdd("timestamp", "day", 1).As("timestamp_plus_day")), + want: map[string]interface{}{"timestamp_plus_day": now.AddDate(0, 0, 1).Truncate(time.Microsecond)}, + }, + { + name: "TimestampAdd hour", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(TimestampAdd("timestamp", "hour", 1).As("timestamp_plus_hour")), + want: map[string]interface{}{"timestamp_plus_hour": now.Add(time.Hour).Truncate(time.Microsecond)}, + }, + { + name: "TimestampAdd minute", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(TimestampAdd("timestamp", "minute", 1).As("timestamp_plus_minute")), + want: map[string]interface{}{"timestamp_plus_minute": now.Add(time.Minute).Truncate(time.Microsecond)}, + }, + { + name: "TimestampAdd second", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(TimestampAdd("timestamp", "second", 1).As("timestamp_plus_second")), + want: map[string]interface{}{"timestamp_plus_second": now.Add(time.Second).Truncate(time.Microsecond)}, + }, + { + name: "TimestampSubtract", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(TimestampSubtract("timestamp", "hour", 1).As("timestamp_minus_hour")), + want: map[string]interface{}{"timestamp_minus_hour": now.Add(-time.Hour).Truncate(time.Microsecond)}, + }, + { + name: "TimestampToUnixMicros", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(FieldOf("timestamp").TimestampToUnixMicros().As("timestamp_micros")), + want: map[string]interface{}{"timestamp_micros": now.UnixNano() / 1000}, + }, + { + name: "TimestampToUnixMillis", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(FieldOf("timestamp").TimestampToUnixMillis().As("timestamp_millis")), + want: map[string]interface{}{"timestamp_millis": now.UnixNano() / 1e6}, + }, + { + name: "TimestampToUnixSeconds", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(FieldOf("timestamp").TimestampToUnixSeconds().As("timestamp_seconds")), + want: map[string]interface{}{"timestamp_seconds": now.Unix()}, + }, + { + name: "UnixMicrosToTimestamp - constant", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(UnixMicrosToTimestamp(ConstantOf(now.UnixNano() / 1000)).As("timestamp_from_micros")), + want: map[string]interface{}{"timestamp_from_micros": now.Truncate(time.Microsecond)}, + }, + { + name: "UnixMicrosToTimestamp - fieldname", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(UnixMicrosToTimestamp("unixMicros").As("timestamp_from_micros")), + want: map[string]interface{}{"timestamp_from_micros": now.Truncate(time.Microsecond)}, + }, + { + name: "UnixMillisToTimestamp", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(UnixMillisToTimestamp(ConstantOf(now.UnixNano() / 1e6)).As("timestamp_from_millis")), + want: map[string]interface{}{"timestamp_from_millis": now.Truncate(time.Millisecond)}, + }, + { + name: "UnixSecondsToTimestamp", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(UnixSecondsToTimestamp("unixSeconds").As("timestamp_from_seconds")), + want: map[string]interface{}{"timestamp_from_seconds": now.Truncate(time.Second)}, + }, + { + name: "CurrentTimestamp", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(CurrentTimestamp().As("current_timestamp")), + want: map[string]interface{}{"current_timestamp": time.Now().Truncate(time.Microsecond)}, + }, + } + + ctx := context.Background() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + iter := test.pipeline.Execute(ctx) + defer iter.Stop() + + docs, err := iter.GetAll() + if err != nil { + t.Fatalf("GetAll: %v", err) + } + if len(docs) != 1 { + t.Fatalf("expected 1 doc, got %d", len(docs)) + } + got := docs[0].Data() + margin := 0 * time.Microsecond + if test.name == "CurrentTimestamp" { + margin = 5 * time.Second + } + if diff := testutil.Diff(got, test.want, cmpopts.EquateApproxTime(margin)); diff != "" { + t.Errorf("got: %v, want: %v, diff: %s", got, test.want, diff) + } + }) + } +} + +func arithmeticFuncs(t *testing.T) { + t.Parallel() + h := testHelper{t} + client := integrationClient(t) + coll := client.Collection(collectionIDs.New()) + docRef1 := coll.NewDoc() + h.mustCreate(docRef1, map[string]interface{}{ + "a": int(1), + "b": int(2), + "c": -3, + "d": 4.5, + "e": -5.5, + }) + defer deleteDocuments([]*DocumentRef{docRef1}) + + tests := []struct { + name string + pipeline *Pipeline + want map[string]interface{} + }{ + { + name: "Add - left FieldOf, right FieldOf", + pipeline: client.Pipeline().Collection(coll.ID).Select(Add(FieldOf("a"), FieldOf("b")).As("add")), + want: map[string]interface{}{"add": int64(3)}, + }, + { + name: "Add - left FieldOf, right ConstantOf", + pipeline: client.Pipeline().Collection(coll.ID).Select(Add(FieldOf("a"), ConstantOf(2)).As("add")), + want: map[string]interface{}{"add": int64(3)}, + }, + { + name: "Add - left FieldOf, right constant", + pipeline: client.Pipeline().Collection(coll.ID).Select(Add(FieldOf("a"), 5).As("add")), + want: map[string]interface{}{"add": int64(6)}, + }, + { + name: "Add - left fieldname, right constant", + pipeline: client.Pipeline().Collection(coll.ID).Select(Add("a", 5).As("add")), + want: map[string]interface{}{"add": int64(6)}, + }, + { + name: "Add - left fieldpath, right constant", + pipeline: client.Pipeline().Collection(coll.ID).Select(Add(FieldPath([]string{"a"}), 5).As("add")), + want: map[string]interface{}{"add": int64(6)}, + }, + { + name: "Add - left fieldpath, right expression", + pipeline: client.Pipeline().Collection(coll.ID).Select(Add(FieldPath([]string{"a"}), Add(FieldOf("b"), FieldOf("d"))).As("add")), + want: map[string]interface{}{"add": float64(7.5)}, + }, + { + name: "Subtract", + pipeline: client.Pipeline().Collection(coll.ID).Select(Subtract("a", FieldOf("b")).As("subtract")), + want: map[string]interface{}{"subtract": int64(-1)}, + }, + { + name: "Multiply", + pipeline: client.Pipeline().Collection(coll.ID).Select(Multiply("a", 5).As("multiply")), + want: map[string]interface{}{"multiply": int64(5)}, + }, + { + name: "Divide", + pipeline: client.Pipeline().Collection(coll.ID).Select(Divide("a", FieldOf("d")).As("divide")), + want: map[string]interface{}{"divide": float64(1 / 4.5)}, + }, + { + name: "Mod", + pipeline: client.Pipeline().Collection(coll.ID).Select(Mod("a", FieldOf("b")).As("mod")), + want: map[string]interface{}{"mod": int64(1)}, + }, + { + name: "Pow", + pipeline: client.Pipeline().Collection(coll.ID).Select(Pow("a", FieldOf("b")).As("pow")), + want: map[string]interface{}{"pow": float64(1)}, + }, + { + name: "Abs - fieldname", + pipeline: client.Pipeline().Collection(coll.ID).Select(Abs("c").As("abs")), + want: map[string]interface{}{"abs": int64(3)}, + }, + { + name: "Abs - fieldPath", + pipeline: client.Pipeline().Collection(coll.ID).Select(Abs(FieldPath([]string{"c"})).As("abs")), + want: map[string]interface{}{"abs": int64(3)}, + }, + { + name: "Abs - Expr", + pipeline: client.Pipeline().Collection(coll.ID).Select(Abs(Add(FieldOf("b"), FieldOf("d"))).As("abs")), + want: map[string]interface{}{"abs": float64(6.5)}, + }, + { + name: "Ceil", + pipeline: client.Pipeline().Collection(coll.ID).Select(Ceil("d").As("ceil")), + want: map[string]interface{}{"ceil": float64(5)}, + }, + { + name: "Floor", + pipeline: client.Pipeline().Collection(coll.ID).Select(Floor("d").As("floor")), + want: map[string]interface{}{"floor": float64(4)}, + }, + { + name: "Round", + pipeline: client.Pipeline().Collection(coll.ID).Select(Round("d").As("round")), + want: map[string]interface{}{"round": float64(5)}, + }, + { + name: "Sqrt", + pipeline: client.Pipeline().Collection(coll.ID).Select(Sqrt("d").As("sqrt")), + want: map[string]interface{}{"sqrt": math.Sqrt(4.5)}, + }, + { + name: "Log", + pipeline: client.Pipeline().Collection(coll.ID).Select(Log("d", 2).As("log")), + want: map[string]interface{}{"log": math.Log2(4.5)}, + }, + { + name: "Log10", + pipeline: client.Pipeline().Collection(coll.ID).Select(Log10("d").As("log10")), + want: map[string]interface{}{"log10": math.Log10(4.5)}, + }, + { + name: "Ln", + pipeline: client.Pipeline().Collection(coll.ID).Select(Ln("d").As("ln")), + want: map[string]interface{}{"ln": math.Log(4.5)}, + }, + { + name: "Exp", + pipeline: client.Pipeline().Collection(coll.ID).Select(Exp("d").As("exp")), + want: map[string]interface{}{"exp": math.Exp(4.5)}, + }, + } + + ctx := context.Background() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + iter := test.pipeline.Execute(ctx) + defer iter.Stop() + + docs, err := iter.GetAll() + if err != nil { + t.Fatalf("GetAll: %v", err) + } + if len(docs) != 1 { + t.Fatalf("expected 1 doc, got %d", len(docs)) + } + got := docs[0].Data() + if diff := testutil.Diff(got, test.want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", got, test.want, diff) + } + }) + } +} + +func aggregateFuncs(t *testing.T) { + t.Parallel() + h := testHelper{t} + client := integrationClient(t) + coll := client.Collection(collectionIDs.New()) + docRef1 := coll.NewDoc() + h.mustCreate(docRef1, map[string]interface{}{ + "a": 1, + }) + docRef2 := coll.NewDoc() + h.mustCreate(docRef2, map[string]interface{}{ + "a": 2, + }) + docRef3 := coll.NewDoc() + h.mustCreate(docRef3, map[string]interface{}{ + "b": 2, + }) + defer deleteDocuments([]*DocumentRef{docRef1, docRef2, docRef3}) + + tests := []struct { + name string + pipeline *Pipeline + want map[string]interface{} + }{ + { + name: "Sum - fieldname arg", + pipeline: client.Pipeline(). + Collection(coll.ID). + Aggregate(Sum("a").As("sum_a")), + want: map[string]interface{}{"sum_a": int64(3)}, + }, + { + name: "Sum - fieldpath arg", + pipeline: client.Pipeline(). + Collection(coll.ID). + Aggregate(Sum(FieldPath([]string{"a"})).As("sum_a")), + want: map[string]interface{}{"sum_a": int64(3)}, + }, + { + name: "Sum - FieldOf Expr", + pipeline: client.Pipeline(). + Collection(coll.ID). + Aggregate(Sum(FieldOf("a")).As("sum_a")), + want: map[string]interface{}{"sum_a": int64(3)}, + }, + { + name: "Sum - FieldOfPath Expr", + pipeline: client.Pipeline(). + Collection(coll.ID). + Aggregate(Sum(FieldOfPath(FieldPath([]string{"a"}))).As("sum_a")), + want: map[string]interface{}{"sum_a": int64(3)}, + }, + { + name: "Avg", + pipeline: client.Pipeline(). + Collection(coll.ID). + Aggregate(Average("a").As("avg_a")), + want: map[string]interface{}{"avg_a": float64(1.5)}, + }, + { + name: "Count", + pipeline: client.Pipeline(). + Collection(coll.ID). + Aggregate(Count("a").As("count_a")), + want: map[string]interface{}{"count_a": int64(2)}, + }, + { + name: "CountAll", + pipeline: client.Pipeline(). + Collection(coll.ID). + Aggregate(CountAll().As("count_all")), + want: map[string]interface{}{"count_all": int64(3)}, + }, + } + + ctx := context.Background() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + iter := test.pipeline.Execute(ctx) + defer iter.Stop() + + docs, err := iter.GetAll() + if err != nil { + t.Fatalf("GetAll: %v", err) + } + if len(docs) != 1 { + t.Fatalf("expected 1 doc, got %d", len(docs)) + } + got := docs[0].Data() + if diff := testutil.Diff(got, test.want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", got, test.want, diff) + } + }) + } +} + +func comparisonFuncs(t *testing.T) { + t.Parallel() + ctx := context.Background() + client := integrationClient(t) + now := time.Now() + coll := client.Collection(collectionIDs.New()) + doc1data := map[string]interface{}{ + "timestamp": now, + "a": 1, + "b": 2, + "c": -3, + "d": 4.5, + "e": -5.5, + } + _, err := coll.Doc("doc1").Create(ctx, doc1data) + if err != nil { + t.Fatalf("Create: %v", err) + } + doc2data := map[string]interface{}{ + "timestamp": now, + "a": 2, + "b": 2, + "c": -3, + "d": 4.5, + "e": -5.5, + } + _, err = coll.Doc("doc2").Create(ctx, doc2data) + if err != nil { + t.Fatalf("Create: %v", err) + } + defer deleteDocuments([]*DocumentRef{coll.Doc("doc1"), coll.Doc("doc2")}) + + doc1want := map[string]interface{}{"a": int64(1), "b": int64(2), "c": int64(-3), "d": float64(4.5), "e": float64(-5.5), "timestamp": now.Truncate(time.Microsecond)} + + tests := []struct { + name string + pipeline *Pipeline + want []map[string]interface{} + }{ + { + name: "Equal", + pipeline: client.Pipeline(). + Collection(coll.ID). + Where(Equal("a", 1)), + want: []map[string]interface{}{doc1want}, + }, + { + name: "NotEqual", + pipeline: client.Pipeline(). + Collection(coll.ID). + Where(NotEqual("a", 2)), + want: []map[string]interface{}{doc1want}, + }, + { + name: "LessThan", + pipeline: client.Pipeline(). + Collection(coll.ID). + Where(LessThan("a", 2)), + want: []map[string]interface{}{doc1want}, + }, + { + name: "Equivalent", + pipeline: client.Pipeline(). + Collection(coll.ID). + Where(Equivalent("a", 1)), + want: []map[string]interface{}{doc1want}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + iter := test.pipeline.Execute(ctx) + defer iter.Stop() + + docs, err := iter.GetAll() + if err != nil { + t.Fatalf("GetAll: %v", err) + } + if len(docs) != len(test.want) { + t.Fatalf("expected %d doc(s), got %d", len(test.want), len(docs)) + } + + var gots []map[string]interface{} + for _, doc := range docs { + got := doc.Data() + if ts, ok := got["timestamp"].(time.Time); ok { + got["timestamp"] = ts.Truncate(time.Microsecond) + } + gots = append(gots, got) + } + + if diff := testutil.Diff(gots, test.want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", gots, test.want, diff) + } + }) + } +} diff --git a/firestore/pipeline_utils.go b/firestore/pipeline_utils.go index 2ced52e8c79e..0c2a49b61018 100644 --- a/firestore/pipeline_utils.go +++ b/firestore/pipeline_utils.go @@ -24,7 +24,7 @@ import ( // newFieldAndArrayBooleanExpr creates a new BooleanExpr for functions that operate on a field/expression and an array of values. func newFieldAndArrayBooleanExpr(name string, exprOrFieldPath any, values any) BooleanExpr { - return &baseBooleanExpr{baseFunction: newBaseFunction(name, []Expr{toExprOrField(exprOrFieldPath), asArrayFunctionExpr(values)})} + return &baseBooleanExpr{baseFunction: newBaseFunction(name, []Expr{asFieldExpr(exprOrFieldPath), asArrayFunctionExpr(values)})} } // toExprs converts a plain Go value or an existing Expr into an Expr. @@ -125,28 +125,6 @@ func asFieldExpr(val any) Expr { } } -func asInt64Expr(val any) Expr { - switch v := val.(type) { - case Expr: - return v - case int, int32, int64: - return ConstantOf(v) - default: - return &baseExpr{err: fmt.Errorf("firestore: value must be a int, int32, int64 or Expr, but got %T", val)} - } -} - -func asStringExpr(val any) Expr { - switch v := val.(type) { - case Expr: - return v - case string: - return ConstantOf(v) - default: - return &baseExpr{err: fmt.Errorf("firestore: value must be a string or Expr, but got %T", val)} - } -} - // leftRightToBaseFunction is a helper for creating binary functions like Add or Eq. // It ensures the left operand is a field-like expression and the right is a constant-like expression. func leftRightToBaseFunction(name string, left, right any) *baseFunction { diff --git a/firestore/util_test.go b/firestore/util_test.go index aeb68badb8c9..004e0f3b4014 100644 --- a/firestore/util_test.go +++ b/firestore/util_test.go @@ -147,10 +147,7 @@ func refval(path string) *pb.Value { func docsToMaps(t *testing.T, docs []*PipelineResult) []map[string]interface{} { var maps []map[string]interface{} for _, doc := range docs { - data, err := doc.Data() - if err != nil { - t.Fatalf("Data: %v", err) - } + data := doc.Data() maps = append(maps, data) } return maps