diff --git a/firestore/apiv1/firestore_client.go b/firestore/apiv1/firestore_client.go index 2493c6ed5811..931d439edbd2 100644 --- a/firestore/apiv1/firestore_client.go +++ b/firestore/apiv1/firestore_client.go @@ -1086,19 +1086,21 @@ func (c *gRPCClient) RunQuery(ctx context.Context, req *firestorepb.RunQueryRequ } func (c *gRPCClient) ExecutePipeline(ctx context.Context, req *firestorepb.ExecutePipelineRequest, opts ...gax.CallOption) (firestorepb.Firestore_ExecutePipelineClient, error) { - routingHeaders := "" - routingHeadersMap := make(map[string]string) + var routingHeaders []string + seen := make(map[string]bool) if reg := regexp.MustCompile("projects/(?P[^/]+)(?:/.*)?"); reg.MatchString(req.GetDatabase()) && len(url.QueryEscape(reg.FindStringSubmatch(req.GetDatabase())[1])) > 0 { - routingHeadersMap["project_id"] = url.QueryEscape(reg.FindStringSubmatch(req.GetDatabase())[1]) + if !seen["project_id"] { + routingHeaders = append(routingHeaders, fmt.Sprintf("%s=%s", "project_id", url.QueryEscape(reg.FindStringSubmatch(req.GetDatabase())[1]))) + seen["project_id"] = true + } } if reg := regexp.MustCompile("projects/[^/]+/databases/(?P[^/]+)(?:/.*)?"); reg.MatchString(req.GetDatabase()) && len(url.QueryEscape(reg.FindStringSubmatch(req.GetDatabase())[1])) > 0 { - routingHeadersMap["database_id"] = url.QueryEscape(reg.FindStringSubmatch(req.GetDatabase())[1]) - } - for headerName, headerValue := range routingHeadersMap { - routingHeaders = fmt.Sprintf("%s%s=%s&", routingHeaders, headerName, headerValue) + if !seen["database_id"] { + routingHeaders = append(routingHeaders, fmt.Sprintf("%s=%s", "database_id", url.QueryEscape(reg.FindStringSubmatch(req.GetDatabase())[1]))) + seen["database_id"] = true + } } - routingHeaders = strings.TrimSuffix(routingHeaders, "&") - hds := []string{"x-goog-request-params", routingHeaders} + hds := []string{"x-goog-request-params", strings.Join(routingHeaders, "&")} hds = append(c.xGoogHeaders, hds...) ctx = gax.InsertMetadataIntoOutgoingContext(ctx, hds...) diff --git a/firestore/pipeline_expr.go b/firestore/pipeline_expr.go index 0d79020512f8..662b5b7a3ff0 100644 --- a/firestore/pipeline_expr.go +++ b/firestore/pipeline_expr.go @@ -45,84 +45,297 @@ type Expr interface { getBaseExpr() *baseExpr // Aritmetic operations + + // Add creates an expression that adds two expressions together, returning it as an Expr. + // + // The parameter 'other' can be a numeric constant or a numeric [Expr]. Add(other any) Expr + // Subtract creates an expression that subtracts the right expression from the left expression, returning it as an Expr. + // + // The parameter 'other' can be a numeric constant or a numeric [Expr]. Subtract(other any) Expr + // Multiply creates an expression that multiplies the left and right expressions, returning it as an Expr. + // + // The parameter 'other' can be a numeric constant or a numeric [Expr]. Multiply(other any) Expr + // Divide creates an expression that divides the left expression by the right expression, returning it as an Expr. + // + // The parameter 'other' can be a numeric constant or a numeric [Expr]. Divide(other any) Expr + // Abs creates an expression that is the absolute value of the input field or expression. Abs() Expr + // Floor creates an expression that is the largest integer that isn't less than the input field or expression. Floor() Expr + // Ceil creates an expression that is the smallest integer that isn't less than the input field or expression. Ceil() Expr + // Exp creates an expression that is the Euler's number e raised to the power of the input field or expression. Exp() Expr + // Log creates an expression that is logarithm of the left expression to base as the right expression, returning it as an Expr. + // + // The parameter 'other' can be a numeric constant or a numeric [Expr]. Log(other any) Expr + // Log10 creates an expression that is the base 10 logarithm of the input field or expression. Log10() Expr + // Ln creates an expression that is the natural logarithm (base e) of the input field or expression. Ln() Expr + // Mod creates an expression that computes the modulo of the left expression by the right expression, returning it as an Expr. + // + // The parameter 'other' can be a numeric constant or a numeric [Expr]. Mod(other any) Expr + // Pow creates an expression that computes the left expression raised to the power of the right expression, returning it as an Expr. + // + // The parameter 'other' can be a numeric constant or a numeric [Expr]. Pow(other any) Expr + // Round creates an expression that rounds the input field or expression to nearest integer. Round() Expr + // Sqrt creates an expression that is the square root of the input field or expression. Sqrt() Expr // Array operations + // ArrayContains creates a boolean expression that checks if an array contains a specific value. + // + // The parameter 'value' can be a constant (e.g., string, int, bool) or an [Expr]. ArrayContains(value any) BooleanExpr + // ArrayContainsAll creates a boolean expression that checks if an array contains all the specified values. + // + // The parameter 'values' can be a slice of constants (e.g., []string, []int) or an [Expr] that evaluates to an array. ArrayContainsAll(values any) BooleanExpr + // ArrayContainsAny creates a boolean expression that checks if an array contains any of the specified values. + // + // The parameter 'values' can be a slice of constants (e.g., []string, []int) or an [Expr] that evaluates to an array. ArrayContainsAny(values any) BooleanExpr + // ArrayLength creates an expression that calculates the length of an array. ArrayLength() Expr + // EqualAny creates a boolean expression that checks if the expression is equal to any of the specified values. + // + // The parameter 'values' can be a slice of constants (e.g., []string, []int) or an [Expr] that evaluates to an array. EqualAny(values any) BooleanExpr + // NotEqualAny creates a boolean expression that checks if the expression is not equal to any of the specified values. + // + // The parameter 'values' can be a slice of constants (e.g., []string, []int) or an [Expr] that evaluates to an array. NotEqualAny(values any) BooleanExpr + // ArrayGet creates an expression that retrieves an element from an array at a specified index. + // + // The parameter 'offset' is the 0-based index of the element to retrieve. + // It can be an integer constant or an [Expr] that evaluates to an integer. ArrayGet(offset any) Expr + // ArrayReverse creates an expression that reverses the order of elements in an array. ArrayReverse() Expr + // ArrayConcat creates an expression that concatenates multiple arrays into a single array. + // + // The parameter 'otherArrays' can be a mix of array constants (e.g., []string, []int) or [Expr]s that evaluate to arrays. ArrayConcat(otherArrays ...any) Expr + // ArraySum creates an expression that calculates the sum of all elements in a numeric array. ArraySum() Expr + // ArrayMaximum creates an expression that finds the maximum element in a numeric array. ArrayMaximum() Expr + // ArrayMinimum creates an expression that finds the minimum element in a numeric array. ArrayMinimum() Expr // Timestamp operations + // TimestampAdd creates an expression that adds a specified amount of time to a timestamp. + // + // The parameter 'unit' can be a string constant (e.g., "day") or an [Expr] that evaluates to a valid unit string. + // Valid units include "microsecond", "millisecond", "second", "minute", "hour" and "day". + // The parameter 'amount' can be an integer constant or an [Expr] that evaluates to an integer. TimestampAdd(unit, amount any) Expr + // TimestampSubtract creates an expression that subtracts a specified amount of time from a timestamp. + // + // The parameter 'unit' can be a string constant (e.g., "hour") or an [Expr] that evaluates to a valid unit string. + // Valid units include "microsecond", "millisecond", "second", "minute", "hour" and "day". + // The parameter 'amount' can be an integer constant or an [Expr] that evaluates to an integer. TimestampSubtract(unit, amount any) Expr + // TimestampTruncate creates an expression that truncates a timestamp to a specified granularity. + // + // The parameter 'granularity' can be a string constant (e.g., "month") or an [Expr] that evaluates to a valid granularity string. + // Valid values are "microsecond", "millisecond", "second", "minute", "hour", "day", "week", "week(monday)", "week(tuesday)", + // "week(wednesday)", "week(thursday)", "week(friday)", "week(saturday)", "week(sunday)", "isoweek", "month", "quarter", "year", and "isoyear". + TimestampTruncate(granularity any) Expr + // TimestampTruncateWithTimezone creates an expression that truncates a timestamp to a specified granularity in a given timezone. + // + // The parameter 'granularity' can be a string constant (e.g., "week") or an [Expr] that evaluates to a valid granularity string. + // Valid values are "microsecond", "millisecond", "second", "minute", "hour", "day", "week", "week(monday)", "week(tuesday)", + // "week(wednesday)", "week(thursday)", "week(friday)", "week(saturday)", "week(sunday)", "isoweek", "month", "quarter", "year", and "isoyear". + // The parameter 'timezone' can be a string constant (e.g., "America/Los_Angeles") or an [Expr] that evaluates to a valid timezone string. + // Valid values are from the TZ database or in the format "Etc/GMT-1". + TimestampTruncateWithTimezone(granularity any, timezone string) Expr + // TimestampToUnixMicros creates an expression that converts a timestamp expression to the number of microseconds since + // the Unix epoch (1970-01-01 00:00:00 UTC). TimestampToUnixMicros() Expr + // TimestampToUnixMillis creates an expression that converts a timestamp expression to the number of milliseconds since + // the Unix epoch (1970-01-01 00:00:00 UTC). TimestampToUnixMillis() Expr + // TimestampToUnixSeconds creates an expression that converts a timestamp expression to the number of seconds since + // the Unix epoch (1970-01-01 00:00:00 UTC). TimestampToUnixSeconds() Expr + // UnixMicrosToTimestamp creates an expression that converts a Unix timestamp in microseconds to a Firestore timestamp. UnixMicrosToTimestamp() Expr + // UnixMillisToTimestamp creates an expression that converts a Unix timestamp in milliseconds to a Firestore timestamp. UnixMillisToTimestamp() Expr + // UnixSecondsToTimestamp creates an expression that converts a Unix timestamp in seconds to a Firestore timestamp. UnixSecondsToTimestamp() Expr // Comparison operations + // Equal creates a boolean expression that checks if the expression is equal to the other value. + // + // The parameter 'other' can be a constant (e.g., string, int, bool) or an [Expr]. Equal(other any) BooleanExpr + // NotEqual creates a boolean expression that checks if the expression is not equal to the other value. + // + // The parameter 'other' can be a constant (e.g., string, int, bool) or an [Expr]. NotEqual(other any) BooleanExpr + // GreaterThan creates a boolean expression that checks if the expression is greater than the other value. + // + // The parameter 'other' can be a constant (e.g., string, int, bool) or an [Expr]. GreaterThan(other any) BooleanExpr + // GreaterThanOrEqual creates a boolean expression that checks if the expression is greater than or equal to the other value. + // + // The parameter 'other' can be a constant (e.g., string, int, bool) or an [Expr]. GreaterThanOrEqual(other any) BooleanExpr + // LessThan creates a boolean expression that checks if the expression is less than the other value. + // + // The parameter 'other' can be a constant (e.g., string, int, bool) or an [Expr]. LessThan(other any) BooleanExpr + // LessThanOrEqual creates a boolean expression that checks if the expression is less than or equal to the other value. + // + // The parameter 'other' can be a constant (e.g., string, int, bool) or an [Expr]. LessThanOrEqual(other any) BooleanExpr + // General functions + // Length creates an expression that calculates the length of string, array, map or vector. + Length() Expr + // Reverse creates an expression that reverses a string, or array. + Reverse() Expr + // Concat creates an expression that concatenates expressions together. + // + // The parameter 'others' can be a list of constants (e.g., string, int) or [Expr]. + Concat(others ...any) Expr + + // Key functions + // GetCollectionID creates an expression that returns the ID of the collection that contains the document. + GetCollectionID() Expr + // GetDocumentID creates an expression that returns the ID of the document. + GetDocumentID() Expr + + // Logical functions + // IfError creates an expression that evaluates and returns the receiver expression if it does not produce an error; + // otherwise, it evaluates and returns `catchExprOrValue`. + // + // The parameter 'catchExprOrValue' is the expression or value to return if the receiver expression errors. + IfError(catchExprOrValue any) Expr + // IfAbsent creates an expression that returns a default value if an expression evaluates to an absent value. + // + // The parameter 'catchExprOrValue' is the value to return if the expression is absent. + // It can be a constant or an [Expr]. + IfAbsent(catchExprOrValue any) Expr + + // Object functions + // MapGet creates an expression that accesses a value from a map (object) field using the provided key. + // + // The parameter 'strOrExprkey' is the key to access in the map. + // It can be a string constant or an [Expr] that evaluates to a string. + MapGet(strOrExprkey any) Expr + // MapMerge creates an expression that merges multiple maps into a single map. + // If multiple maps have the same key, the later value is used. + // + // The parameter 'secondMap' is an [Expr] representing the second map. + // The parameter 'otherMaps' is a list of additional [Expr]s representing maps to merge. + MapMerge(secondMap Expr, otherMaps ...Expr) Expr + // MapRemove creates an expression that removes a key from a map. + // + // The parameter 'strOrExprkey' is the key to remove from the map. + // It can be a string constant or an [Expr] that evaluates to a string. + MapRemove(strOrExprkey any) Expr + // Aggregators + // Sum creates an aggregate function that calculates the sum of the expression. Sum() AggregateFunction + // Average creates an aggregate function that calculates the average of the expression. Average() AggregateFunction + // Count creates an aggregate function that counts the number of documents. Count() AggregateFunction // String functions + // ByteLength creates an expression that calculates the length of a string represented by a field or [Expr] in UTF-8 + // bytes. ByteLength() Expr + // CharLength creates an expression that calculates the character length of a string field or expression in UTF8. CharLength() Expr + // EndsWith creates a boolean expression that checks if the string expression ends with the specified suffix. + // + // The parameter 'suffix' can be a string constant or an [Expr] that evaluates to a string. EndsWith(suffix any) BooleanExpr + // Like creates a boolean expression that checks if the string expression matches the specified pattern. + // + // The parameter 'suffix' can be a string constant or an [Expr] that evaluates to a string. Like(suffix any) BooleanExpr + // RegexContains creates a boolean expression that checks if the string expression contains a match for the specified regex pattern. + // + // The parameter 'pattern' can be a string constant or an [Expr] that evaluates to a string. RegexContains(pattern any) BooleanExpr + // RegexMatch creates a boolean expression that checks if the string expression matches the specified regex pattern. + // + // The parameter 'pattern' can be a string constant or an [Expr] that evaluates to a string. RegexMatch(pattern any) BooleanExpr + // StartsWith creates a boolean expression that checks if the string expression starts with the specified prefix. + // + // The parameter 'prefix' can be a string constant or an [Expr] that evaluates to a string. StartsWith(prefix any) BooleanExpr + // StringConcat creates an expression that concatenates multiple strings into a single string. + // + // The parameter 'otherStrings' can be a mix of string constants or [Expr]s that evaluate to strings. StringConcat(otherStrings ...any) Expr + // StringContains creates a boolean expression that checks if the string expression contains the specified substring. + // + // The parameter 'substring' can be a string constant or an [Expr] that evaluates to a string. StringContains(substring any) BooleanExpr + // StringReverse creates an expression that reverses a string. StringReverse() Expr - Join(separator any) Expr + // Join creates an expression that joins the elements of a string array into a single string. + // + // The parameter 'delimiter' can be a string constant or an [Expr] that evaluates to a string. + Join(delimiter any) Expr + // Substring creates an expression that returns a substring of a string. + // + // The parameter 'index' is the starting index of the substring. + // It can be an integer constant or an [Expr] that evaluates to an integer. + // The parameter 'offset' is the length of the substring. + // It can be an integer constant or an [Expr] that evaluates to an integer. Substring(index, offset any) Expr + // ToLower creates an expression that converts a string to lowercase. ToLower() Expr + // ToUpper creates an expression that converts a string to uppercase. ToUpper() Expr + // Trim creates an expression that removes leading and trailing whitespace from a string. Trim() Expr + // Split creates an expression that splits a string by a delimiter. + // + // The parameter 'delimiter' can be a string constant or an [Expr] that evaluates to a string. + Split(delimiter any) Expr + + // Type creates an expression that returns the type of the expression. + Type() Expr // Vector functions + // CosineDistance creates an expression that calculates the cosine distance between two vectors. + // + // The parameter 'other' can be [Vector32], [Vector64], []float32, []float64 or an [Expr] that evaluates to a vector. CosineDistance(other any) Expr + // DotProduct creates an expression that calculates the dot product of two vectors. + // + // The parameter 'other' can be [Vector32], [Vector64], []float32, []float64 or an [Expr] that evaluates to a vector. DotProduct(other any) Expr + // EuclideanDistance creates an expression that calculates the euclidean distance between two vectors. + // + // The parameter 'other' can be [Vector32], [Vector64], []float32, []float64 or an [Expr] that evaluates to a vector. EuclideanDistance(other any) Expr + // VectorLength creates an expression that calculates the length of a vector. VectorLength() Expr // Ordering + // Ascending creates an ordering expression for ascending order. Ascending() Ordering + // Descending creates an ordering expression for descending order. Descending() Ordering // As assigns an alias to an expression. @@ -176,6 +389,12 @@ func (b *baseExpr) TimestampAdd(unit, amount any) Expr { return TimestampAdd(b, func (b *baseExpr) TimestampSubtract(unit, amount any) Expr { return TimestampSubtract(b, unit, amount) } +func (b *baseExpr) TimestampTruncate(granularity any) Expr { + return TimestampTruncate(b, granularity) +} +func (b *baseExpr) TimestampTruncateWithTimezone(granularity any, timezone string) Expr { + return TimestampTruncateWithTimezone(b, granularity, timezone) +} func (b *baseExpr) TimestampToUnixMicros() Expr { return TimestampToUnixMicros(b) } func (b *baseExpr) TimestampToUnixMillis() Expr { return TimestampToUnixMillis(b) } func (b *baseExpr) TimestampToUnixSeconds() Expr { return TimestampToUnixSeconds(b) } @@ -191,6 +410,26 @@ func (b *baseExpr) GreaterThanOrEqual(other any) BooleanExpr { return GreaterTha func (b *baseExpr) LessThan(other any) BooleanExpr { return LessThan(b, other) } func (b *baseExpr) LessThanOrEqual(other any) BooleanExpr { return LessThanOrEqual(b, other) } +// General functions +func (b *baseExpr) Length() Expr { return Length(b) } +func (b *baseExpr) Reverse() Expr { return Reverse(b) } +func (b *baseExpr) Concat(others ...any) Expr { return Concat(b, others...) } + +// Key functions +func (b *baseExpr) GetCollectionID() Expr { return GetCollectionID(b) } +func (b *baseExpr) GetDocumentID() Expr { return GetDocumentID(b) } + +// Logical functions +func (b *baseExpr) IfError(catchExprOrValue any) Expr { return IfError(b, catchExprOrValue) } +func (b *baseExpr) IfAbsent(catchExprOrValue any) Expr { return IfAbsent(b, catchExprOrValue) } + +// Object functions +func (b *baseExpr) MapGet(strOrExprkey any) Expr { return MapGet(b, strOrExprkey) } +func (b *baseExpr) MapMerge(secondMap Expr, otherMaps ...Expr) Expr { + return MapMerge(b, secondMap, otherMaps...) +} +func (b *baseExpr) MapRemove(strOrExprkey any) Expr { return MapRemove(b, strOrExprkey) } + // Aggregation operations func (b *baseExpr) Sum() AggregateFunction { return Sum(b) } func (b *baseExpr) Average() AggregateFunction { return Average(b) } @@ -211,11 +450,15 @@ func (b *baseExpr) StartsWith(prefix any) BooleanExpr { return StartsWith func (b *baseExpr) StringConcat(otherStrings ...any) Expr { return StringConcat(b, otherStrings...) } func (b *baseExpr) StringContains(substring any) BooleanExpr { return StringContains(b, substring) } func (b *baseExpr) StringReverse() Expr { return StringReverse(b) } -func (b *baseExpr) Join(separator any) Expr { return Join(b, separator) } +func (b *baseExpr) Join(delimiter any) Expr { return Join(b, delimiter) } func (b *baseExpr) Substring(index, offset any) Expr { return Substring(b, index, offset) } func (b *baseExpr) ToLower() Expr { return ToLower(b) } func (b *baseExpr) ToUpper() Expr { return ToUpper(b) } func (b *baseExpr) Trim() Expr { return Trim(b) } +func (b *baseExpr) Split(delimiter any) Expr { return Split(b, delimiter) } + +// Type functions +func (b *baseExpr) Type() Expr { return Type(b) } // Vector functions func (b *baseExpr) CosineDistance(other any) Expr { return CosineDistance(b, other) } diff --git a/firestore/pipeline_filter_condition.go b/firestore/pipeline_filter_condition.go index e571e7cbdd04..33a9fa3f9aad 100644 --- a/firestore/pipeline_filter_condition.go +++ b/firestore/pipeline_filter_condition.go @@ -18,6 +18,19 @@ package firestore type BooleanExpr interface { Expr // Embed Expr interface isBooleanExpr() + + // Conditional creates an expression that evaluates a condition and returns one of two expressions. + // + // The parameter 'thenVal' is the expression to return if the condition is true. + // The parameter 'elseVal' is the expression to return if the condition is false. + Conditional(thenVal, elseVal any) Expr + // IfErrorBoolean creates a boolean expression that evaluates and returns the receiver expression if it does not produce an error; + // otherwise, it evaluates and returns `catchExpr`. + // + // The parameter 'catchExpr' is the boolean expression to return if the receiver expression errors. + IfErrorBoolean(catchExpr BooleanExpr) BooleanExpr + // Not creates an expression that negates a boolean expression. + Not() BooleanExpr } // baseBooleanExpr provides common methods for all BooleanExpr implementations. @@ -26,6 +39,15 @@ type baseBooleanExpr struct { } func (b *baseBooleanExpr) isBooleanExpr() {} +func (b *baseBooleanExpr) Conditional(thenVal, elseVal any) Expr { + return Conditional(b, thenVal, elseVal) +} +func (b *baseBooleanExpr) IfErrorBoolean(catchExpr BooleanExpr) BooleanExpr { + return IfErrorBoolean(b, catchExpr) +} +func (b *baseBooleanExpr) Not() BooleanExpr { + return Not(b) +} // Ensure that baseBooleanExpr implements the BooleanExpr interface. var _ BooleanExpr = (*baseBooleanExpr)(nil) @@ -210,28 +232,6 @@ func LessThanOrEqual(left, right any) BooleanExpr { return &baseBooleanExpr{baseFunction: leftRightToBaseFunction("less_than_or_equal", left, right)} } -// Equivalent creates an expression that checks if field's value or an expression is equal to an expression or a constant value, -// returning it as a BooleanExpr. This is an alias for Equal. -// - left: The field path string, [FieldPath] or [Expr] to compare. -// - right: The constant value or [Expr] to compare to. -// -// Example: -// -// // Check if the 'age' field is equal to 21 -// Equivalent(FieldOf("age"), 21) -// -// // Check if the 'age' field is equal to an expression -// Equivalent(FieldOf("age"), FieldOf("minAge").Add(10)) -// -// // Check if the 'age' field is equal to the 'limit' field -// Equivalent("age", FieldOf("limit")) -// -// // Check if the 'city' field is equal to string constant "London" -// Equivalent("city", "London") -func Equivalent(left, right any) BooleanExpr { - return &baseBooleanExpr{baseFunction: leftRightToBaseFunction("equivalent", left, right)} -} - // EndsWith creates an expression that checks if a string field or expression ends with a given suffix. // - exprOrFieldPath can be a field path string, [FieldPath] or [Expr]. // - suffix string or [Expr] to check for. @@ -303,3 +303,38 @@ func StartsWith(exprOrFieldPath any, prefix any) BooleanExpr { func StringContains(exprOrFieldPath any, substring any) BooleanExpr { return &baseBooleanExpr{baseFunction: newBaseFunction("string_contains", []Expr{asFieldExpr(exprOrFieldPath), asStringExpr(substring)})} } + +// And creates an expression that performs a logical 'AND' operation. +func And(condition BooleanExpr, right ...BooleanExpr) BooleanExpr { + return &baseBooleanExpr{baseFunction: newBaseFunctionFromBooleans("and", append([]BooleanExpr{condition}, right...))} +} + +// FieldExists creates an expression that checks if a field exists. +func FieldExists(exprOrField any) BooleanExpr { + return &baseBooleanExpr{baseFunction: newBaseFunction("exists", []Expr{asFieldExpr(exprOrField)})} +} + +// Not creates an expression that negates a boolean expression. +func Not(condition BooleanExpr) BooleanExpr { + return &baseBooleanExpr{baseFunction: newBaseFunction("not", []Expr{condition})} +} + +// Or creates an expression that performs a logical 'OR' operation. +func Or(condition BooleanExpr, right ...BooleanExpr) BooleanExpr { + return &baseBooleanExpr{baseFunction: newBaseFunctionFromBooleans("or", append([]BooleanExpr{condition}, right...))} +} + +// Xor creates an expression that performs a logical 'XOR' operation. +func Xor(condition BooleanExpr, right ...BooleanExpr) BooleanExpr { + return &baseBooleanExpr{baseFunction: newBaseFunctionFromBooleans("xor", append([]BooleanExpr{condition}, right...))} +} + +// IsError creates an expression that checks if an expression evaluates to an error. +func IsError(expr Expr) BooleanExpr { + return &baseBooleanExpr{baseFunction: newBaseFunction("is_error", []Expr{expr})} +} + +// IsAbsent creates an expression that checks if an expression evaluates to an absent value. +func IsAbsent(exprOrField any) BooleanExpr { + return &baseBooleanExpr{baseFunction: newBaseFunction("is_absent", []Expr{asFieldExpr(exprOrField)})} +} diff --git a/firestore/pipeline_function.go b/firestore/pipeline_function.go index 280be1225996..8a5e4654a7b4 100644 --- a/firestore/pipeline_function.go +++ b/firestore/pipeline_function.go @@ -39,7 +39,6 @@ var _ Function = (*baseFunction)(nil) func newBaseFunction(name string, params []Expr) *baseFunction { argsPbVals := make([]*pb.Value, 0, len(params)) for i, param := range params { - paramExpr := asFieldExpr(param) pbVal, err := paramExpr.toProto() if err != nil { @@ -57,6 +56,14 @@ func newBaseFunction(name string, params []Expr) *baseFunction { return &baseFunction{baseExpr: &baseExpr{pbVal: pbVal}} } +func newBaseFunctionFromBooleans(name string, params []BooleanExpr) *baseFunction { + exprs := make([]Expr, len(params)) + for i, p := range params { + exprs[i] = p + } + return newBaseFunction(name, exprs) +} + // Add creates an expression that adds two expressions together, returning it as an Expr. // - left can be a field path string, [FieldPath] or [Expr]. // - right can be a numeric constant or a numeric [Expr]. @@ -170,6 +177,28 @@ func TimestampSubtract(timestamp, unit, amount any) Expr { return newBaseFunction("timestamp_subtract", []Expr{asFieldExpr(timestamp), asStringExpr(unit), asInt64Expr(amount)}) } +// TimestampTruncate creates an expression that truncates a timestamp to a specified granularity. +// - timestamp can be a field path string, [FieldPath] or [Expr]. +// - granularity can be a string or an [Expr]. Valid values are "microsecond", +// "millisecond", "second", "minute", "hour", "day", "week", "week(monday)", "week(tuesday)", +// "week(wednesday)", "week(thursday)", "week(friday)", "week(saturday)", "week(sunday)", +// "isoweek", "month", "quarter", "year", and "isoyear". +func TimestampTruncate(timestamp, granularity any) Expr { + return newBaseFunction("timestamp_trunc", []Expr{asFieldExpr(timestamp), asStringExpr(granularity)}) +} + +// TimestampTruncateWithTimezone creates an expression that truncates a timestamp to a specified granularity in a given timezone. +// - timestamp can be a field path string, [FieldPath] or [Expr]. +// - granularity can be a string or an [Expr]. Valid values are "microsecond", +// "millisecond", "second", "minute", "hour", "day", "week", "week(monday)", "week(tuesday)", +// "week(wednesday)", "week(thursday)", "week(friday)", "week(saturday)", "week(sunday)", +// "isoweek", "month", "quarter", "year", and "isoyear". +// - timezone can be a string or an [Expr]. Valid values are from the TZ database +// (e.g., "America/Los_Angeles") or in the format "Etc/GMT-1". +func TimestampTruncateWithTimezone(timestamp, granularity any, timezone string) Expr { + return newBaseFunction("timestamp_trunc", []Expr{asFieldExpr(timestamp), asStringExpr(granularity), asStringExpr(timezone)}) +} + // TimestampToUnixMicros creates an expression that converts a timestamp expression to the number of microseconds since // the Unix epoch (1970-01-01 00:00:00 UTC). // - timestamp can be a field path string, [FieldPath] or [Expr]. @@ -273,7 +302,7 @@ func ArrayMinimum(exprOrFieldPath any) Expr { } // ByteLength creates an expression that calculates the length of a string represented by a field or [Expr] in UTF-8 -// bytes, or just the length of a Blob. +// bytes. // - exprOrFieldPath can be a field path string, [FieldPath] or [Expr]. func ByteLength(exprOrFieldPath any) Expr { return newBaseFunction("byte_length", []Expr{asFieldExpr(exprOrFieldPath)}) @@ -300,9 +329,9 @@ func StringReverse(exprOrFieldPath any) Expr { // Join creates an expression that joins the elements of a string array into a single string. // - exprOrFieldPath can be a field path string, [FieldPath] or an [Expr] that evaluates to a string array. -// - separator is the string to use as a separator between elements. -func Join(exprOrFieldPath any, separator any) Expr { - return newBaseFunction("join", []Expr{asFieldExpr(exprOrFieldPath), asStringExpr(separator)}) +// - delimiter is the string to use as a separator between elements. +func Join(exprOrFieldPath any, delimiter any) Expr { + return newBaseFunction("join", []Expr{asFieldExpr(exprOrFieldPath), asStringExpr(delimiter)}) } // Substring creates an expression that returns a substring of a string. @@ -331,6 +360,19 @@ func Trim(exprOrFieldPath any) Expr { return newBaseFunction("trim", []Expr{asFieldExpr(exprOrFieldPath)}) } +// Split creates an expression that splits a string by a delimiter. +// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expr] that evaluates to a string. +// - delimiter is the string to use to split by. +func Split(exprOrFieldPath any, delimiter any) Expr { + return newBaseFunction("split", []Expr{asFieldExpr(exprOrFieldPath), asStringExpr(delimiter)}) +} + +// Type creates an expression that returns the type of the expression. +// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expr]. +func Type(exprOrFieldPath any) Expr { + return newBaseFunction("type", []Expr{asFieldExpr(exprOrFieldPath)}) +} + // CosineDistance creates an expression that calculates the cosine distance between two vectors. // - vector1 can be a field path string, [FieldPath] or [Expr]. // - vector2 can be [Vector32], [Vector64], []float32, []float64 or [Expr]. @@ -357,3 +399,142 @@ func EuclideanDistance(vector1 any, vector2 any) Expr { func VectorLength(exprOrFieldPath any) Expr { return newBaseFunction("vector_length", []Expr{asFieldExpr(exprOrFieldPath)}) } + +// Length creates an expression that calculates the length of string, array, map or vector. +// - exprOrField can be a field path string, [FieldPath] or an [Expr] that returns a string, array, map or vector when evaluated. +// +// Example: +// +// // Length of the 'name' field. +// Length("name") +func Length(exprOrField any) Expr { + return newBaseFunction("length", []Expr{asFieldExpr(exprOrField)}) +} + +// Reverse creates an expression that reverses a string, or array. +// - exprOrField can be a field path string, [FieldPath] or an [Expr] that returns a string, or array when evaluated. +// +// Example: +// +// // Reverse the 'name' field. +// +// Reverse("name") +func Reverse(exprOrField any) Expr { + return newBaseFunction("reverse", []Expr{asFieldExpr(exprOrField)}) +} + +// Concat creates an expression that concatenates expressions together. +// - exprOrField can be a field path string, [FieldPath] or an [Expr]. +// - others can be a list of constants or [Expr]. +// +// Example: +// +// // Concat the 'name' field with a constant string. +// Concat("name", "-suffix") +func Concat(exprOrField any, others ...any) Expr { + return newBaseFunction("concat", append([]Expr{asFieldExpr(exprOrField)}, toArrayOfExprOrConstant(others)...)) +} + +// GetCollectionID creates an expression that returns the ID of the collection that contains the document. +// - exprOrField can be a field path string, [FieldPath] or an [Expr] that evaluates to a field path. +func GetCollectionID(exprOrField any) Expr { + return newBaseFunction("collection_id", []Expr{asFieldExpr(exprOrField)}) +} + +// GetDocumentID creates an expression that returns the ID of the document. +// - exprStringOrDocRef can be a string, a [DocumentRef], or an [Expr] that evaluates to a document reference. +func GetDocumentID(exprStringOrDocRef any) Expr { + var expr Expr + switch v := exprStringOrDocRef.(type) { + case string: + expr = ConstantOf(v) + case *DocumentRef: + expr = ConstantOf(v) + case Expr: + expr = v + default: + return &baseFunction{baseExpr: &baseExpr{err: fmt.Errorf("firestore: value must be a string, DocumentRef, or Expr, but got %T", exprStringOrDocRef)}} + } + + return newBaseFunction("document_id", []Expr{expr}) +} + +// Conditional creates an expression that evaluates a condition and returns one of two expressions. +// - condition is the boolean expression to evaluate. +// - thenVal is the expression to return if the condition is true. +// - elseVal is the expression to return if the condition is false. +func Conditional(condition BooleanExpr, thenVal, elseVal any) Expr { + return newBaseFunction("conditional", []Expr{condition, toExprOrConstant(thenVal), toExprOrConstant(elseVal)}) +} + +// LogicalMaximum creates an expression that evaluates to the maximum value in a list of expressions. +// - exprOrField can be a field path string, [FieldPath] or an [Expr]. +// - others can be a list of constants or [Expr]. +func LogicalMaximum(exprOrField any, others ...any) Expr { + return newBaseFunction("maximum", append([]Expr{asFieldExpr(exprOrField)}, toArrayOfExprOrConstant(others)...)) +} + +// LogicalMinimum creates an expression that evaluates to the minimum value in a list of expressions. +// - exprOrField can be a field path string, [FieldPath] or an [Expr]. +// - others can be a list of constants or [Expr]. +func LogicalMinimum(exprOrField any, others ...any) Expr { + return newBaseFunction("minimum", append([]Expr{asFieldExpr(exprOrField)}, toArrayOfExprOrConstant(others)...)) +} + +// IfError creates an expression that evaluates and returns `tryExpr` if it does not produce an error; +// otherwise, it evaluates and returns `catchExprOrValue`. It returns a new [Expr] representing +// the if_error operation. +// - tryExpr is the expression to try. +// - catchExprOrValue is the expression or value to return if `tryExpr` errors. +func IfError(tryExpr Expr, catchExprOrValue any) Expr { + return newBaseFunction("if_error", []Expr{tryExpr, toExprOrConstant(catchExprOrValue)}) +} + +// IfErrorBoolean creates a boolean expression that evaluates and returns `tryExpr` if it does not produce an error; +// otherwise, it evaluates and returns `catchExpr`. It returns a new [BooleanExpr] representing +// the if_error operation. +// - tryExpr is the boolean expression to try. +// - catchExpr is the boolean expression to return if `tryExpr` errors. +func IfErrorBoolean(tryExpr BooleanExpr, catchExpr BooleanExpr) BooleanExpr { + return &baseBooleanExpr{baseFunction: newBaseFunction("if_error", []Expr{tryExpr, catchExpr})} +} + +// IfAbsent creates an expression that returns a default value if an expression evaluates to an absent value. +// - exprOrField can be a field path string, [FieldPath] or an [Expr]. +// - elseValue is the value to return if the expression is absent. +func IfAbsent(exprOrField any, elseValue any) Expr { + return newBaseFunction("if_absent", []Expr{asFieldExpr(exprOrField), toExprOrConstant(elseValue)}) +} + +// Map creates an expression that creates a Firestore map value from an input object. +// - elements: The input map to evaluate in the expression. +func Map(elements map[string]any) Expr { + exprs := make([]Expr, 0, len(elements)*2) + for k, v := range elements { + exprs = append(exprs, ConstantOf(k), toExprOrConstant(v)) + } + return newBaseFunction("map", exprs) +} + +// MapGet creates an expression that accesses a value from a map (object) field using the provided key. +// - exprOrField: The expression representing the map. +// - strOrExprkey: The key to access in the map. +func MapGet(exprOrField any, strOrExprkey any) Expr { + return newBaseFunction("map_get", []Expr{asFieldExpr(exprOrField), asStringExpr(strOrExprkey)}) +} + +// MapMerge creates an expression that merges multiple maps into a single map. +// If multiple maps have the same key, the later value is used. +// - exprOrField: First map expression that will be merged. +// - secondMap: Second map expression that will be merged. +// - otherMaps: Additional maps to merge. +func MapMerge(exprOrField any, secondMap Expr, otherMaps ...Expr) Expr { + return newBaseFunction("map_merge", append([]Expr{asFieldExpr(exprOrField), secondMap}, otherMaps...)) +} + +// MapRemove creates an expression that removes a key from a map. +// - exprOrField: The expression representing the map. +// - strOrExprkey: The key to remove from the map. +func MapRemove(exprOrField any, strOrExprkey any) Expr { + return newBaseFunction("map_remove", []Expr{asFieldExpr(exprOrField), asStringExpr(strOrExprkey)}) +} diff --git a/firestore/pipeline_integration_test.go b/firestore/pipeline_integration_test.go index 0d6a1ab38c95..e0f1cee1914b 100644 --- a/firestore/pipeline_integration_test.go +++ b/firestore/pipeline_integration_test.go @@ -26,6 +26,7 @@ import ( "cloud.google.com/go/internal/testutil" "github.com/google/go-cmp/cmp/cmpopts" "google.golang.org/api/iterator" + "google.golang.org/genproto/googleapis/type/latlng" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -720,7 +721,183 @@ func TestIntegration_PipelineFunctions(t *testing.T) { t.Run("arithmeticFuncs", arithmeticFuncs) t.Run("aggregateFuncs", aggregateFuncs) t.Run("comparisonFuncs", comparisonFuncs) + t.Run("generalFuncs", generalFuncs) + t.Run("keyFuncs", keyFuncs) + t.Run("objectFuncs", objectFuncs) + t.Run("logicalFuncs", logicalFuncs) + t.Run("typeFuncs", typeFuncs) +} + +func typeFuncs(t *testing.T) { + t.Parallel() + h := testHelper{t} + client := integrationClient(t) + coll := client.Collection(collectionIDs.New()) + docRef1 := coll.NewDoc() + h.mustCreate(docRef1, map[string]interface{}{ + "a": nil, + "b": true, + "c": 1, + "d": "hello", + "e": []byte("world"), + "f": time.Now(), + "g": &latlng.LatLng{Latitude: 32.1, Longitude: -4.5}, + "h": []interface{}{1, 2, 3}, + "i": map[string]interface{}{"j": 1}, + "k": Vector64{1, 2, 3}, + "l": docRef1, + }) + defer deleteDocuments([]*DocumentRef{docRef1}) + + tests := []struct { + name string + pipeline *Pipeline + want map[string]interface{} + }{ + { + name: "Type of null", + pipeline: client.Pipeline().Collection(coll.ID).Select(Type("a").As("type")), + want: map[string]interface{}{"type": "null"}, + }, + { + name: "Type of boolean", + pipeline: client.Pipeline().Collection(coll.ID).Select(Type("b").As("type")), + want: map[string]interface{}{"type": "boolean"}, + }, + { + name: "Type of int64", + pipeline: client.Pipeline().Collection(coll.ID).Select(Type("c").As("type")), + want: map[string]interface{}{"type": "int64"}, + }, + { + name: "Type of string", + pipeline: client.Pipeline().Collection(coll.ID).Select(Type("d").As("type")), + want: map[string]interface{}{"type": "string"}, + }, + { + name: "Type of bytes", + pipeline: client.Pipeline().Collection(coll.ID).Select(Type("e").As("type")), + want: map[string]interface{}{"type": "bytes"}, + }, + { + name: "Type of timestamp", + pipeline: client.Pipeline().Collection(coll.ID).Select(Type("f").As("type")), + want: map[string]interface{}{"type": "timestamp"}, + }, + { + name: "Type of geopoint", + pipeline: client.Pipeline().Collection(coll.ID).Select(Type("g").As("type")), + want: map[string]interface{}{"type": "geo_point"}, + }, + { + name: "Type of array", + pipeline: client.Pipeline().Collection(coll.ID).Select(Type("h").As("type")), + want: map[string]interface{}{"type": "array"}, + }, + { + name: "Type of map", + pipeline: client.Pipeline().Collection(coll.ID).Select(Type("i").As("type")), + want: map[string]interface{}{"type": "map"}, + }, + { + name: "Type of vector", + pipeline: client.Pipeline().Collection(coll.ID).Select(Type("k").As("type")), + want: map[string]interface{}{"type": "vector"}, + }, + { + name: "Type of reference", + pipeline: client.Pipeline().Collection(coll.ID).Select(Type("l").As("type")), + want: map[string]interface{}{"type": "reference"}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx := context.Background() + iter := test.pipeline.Execute(ctx) + defer iter.Stop() + + docs, err := iter.GetAll() + if isRetryablePipelineExecuteErr(err) { + t.Errorf("GetAll: %v. Retrying....", err) + return + } else if err != nil { + t.Fatalf("GetAll: %v", err) + return + } + if len(docs) != 1 { + t.Fatalf("expected 1 doc, got %d", len(docs)) + } + got := docs[0].Data() + if diff := testutil.Diff(got, test.want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", got, test.want, diff) + } + }) + } +} + +func objectFuncs(t *testing.T) { + t.Parallel() + h := testHelper{t} + client := integrationClient(t) + coll := client.Collection(collectionIDs.New()) + docRef1 := coll.NewDoc() + h.mustCreate(docRef1, map[string]interface{}{ + "m1": map[string]interface{}{"a": 1, "b": 2}, + "m2": map[string]interface{}{"c": 3, "d": 4}, + }) + defer deleteDocuments([]*DocumentRef{docRef1}) + + tests := []struct { + name string + pipeline *Pipeline + want map[string]interface{} + }{ + { + name: "Map", + pipeline: client.Pipeline().Collection(coll.ID).Select(Map(map[string]any{"a": 1, "b": 2}).As("map")), + want: map[string]interface{}{"map": map[string]interface{}{"a": int64(1), "b": int64(2)}}, + }, + { + name: "MapGet", + pipeline: client.Pipeline().Collection(coll.ID).Select(MapGet("m1", "a").As("value")), + want: map[string]interface{}{"value": int64(1)}, + }, + { + name: "MapMerge", + pipeline: client.Pipeline().Collection(coll.ID).Select(MapMerge("m1", FieldOf("m2")).As("merged")), + want: map[string]interface{}{"merged": map[string]interface{}{"a": int64(1), "b": int64(2), "c": int64(3), "d": int64(4)}}, + }, + { + name: "MapRemove", + pipeline: client.Pipeline().Collection(coll.ID).Select(MapRemove("m1", "a").As("removed")), + want: map[string]interface{}{"removed": map[string]interface{}{"b": int64(2)}}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx := context.Background() + iter := test.pipeline.Execute(ctx) + defer iter.Stop() + docs, err := iter.GetAll() + if isRetryablePipelineExecuteErr(err) { + t.Errorf("GetAll: %v. Retrying....", err) + return + } else if err != nil { + t.Fatalf("GetAll: %v", err) + return + } + if len(docs) != 1 { + t.Fatalf("expected 1 doc, got %d", len(docs)) + } + got := docs[0].Data() + if diff := testutil.Diff(got, test.want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", got, test.want, diff) + } + }) + } } func arrayFuncs(t *testing.T) { @@ -836,19 +1013,19 @@ func arrayFuncs(t *testing.T) { docs, err := iter.GetAll() if isRetryablePipelineExecuteErr(err) { - r.Errorf("GetAll: %v. Retrying....", err) + t.Errorf("GetAll: %v. Retrying....", err) return } else if err != nil { - r.Fatalf("GetAll: %v", err) + t.Fatalf("GetAll: %v", err) return } if len(docs) != 1 { - r.Fatalf("expected 1 doc, got %d", len(docs)) + t.Fatalf("expected 1 doc, got %d", len(docs)) return } got := docs[0].Data() if diff := testutil.Diff(got, test.want); diff != "" { - r.Errorf("got: %v, want: %v, diff +want -got: %s", got, test.want, diff) + t.Errorf("got: %v, want: %v, diff +want -got: %s", got, test.want, diff) return } }) @@ -869,6 +1046,7 @@ func stringFuncs(t *testing.T) { "tags": []string{"tag1", "tag2", "tag3"}, "email": "john.doe@example.com", "zipCode": "12345", + "csv": "a,b,c", }) defer deleteDocuments([]*DocumentRef{docRef1}) @@ -879,6 +1057,7 @@ func stringFuncs(t *testing.T) { "tags": []interface{}{"tag1", "tag2", "tag3"}, "email": "john.doe@example.com", "zipCode": "12345", + "csv": "a,b,c", } tests := []struct { @@ -931,6 +1110,11 @@ func stringFuncs(t *testing.T) { pipeline: client.Pipeline().Collection(coll.ID).Select(Trim("name").As("trimmed_name")), want: map[string]interface{}{"trimmed_name": "John Doe"}, }, + { + name: "Split", + pipeline: client.Pipeline().Collection(coll.ID).Select(Split("csv", ",").As("split_string")), + want: map[string]interface{}{"split_string": []interface{}{"a", "b", "c"}}, + }, // String filter conditions { name: "Like", @@ -966,60 +1150,58 @@ func stringFuncs(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - testutil.Retry(t, 3, time.Second, func(r *testutil.R) { - ctx := context.Background() + ctx := context.Background() - iter := test.pipeline.Execute(ctx) - defer iter.Stop() + iter := test.pipeline.Execute(ctx) + defer iter.Stop() - docs, err := iter.GetAll() - if isRetryablePipelineExecuteErr(err) { - r.Errorf("GetAll: %v. Retrying....", err) + docs, err := iter.GetAll() + if isRetryablePipelineExecuteErr(err) { + t.Errorf("GetAll: %v. Retrying....", err) + return + } else if err != nil { + t.Fatalf("GetAll: %v", err) + return + } + lastStage := test.pipeline.stages[len(test.pipeline.stages)-1] + lastStageName := lastStage.name() + + if lastStageName == stageNameSelect { // This is a select query + want, ok := test.want.(map[string]interface{}) + if !ok { + t.Fatalf("invalid test.want type for select query: %T", test.want) return - } else if err != nil { - r.Fatalf("GetAll: %v", err) + } + if len(docs) != 1 { + t.Fatalf("expected 1 doc, got %d", len(docs)) return } - lastStage := test.pipeline.stages[len(test.pipeline.stages)-1] - lastStageName := lastStage.name() - - if lastStageName == stageNameSelect { // This is a select query - want, ok := test.want.(map[string]interface{}) - if !ok { - r.Fatalf("invalid test.want type for select query: %T", test.want) - return - } - if len(docs) != 1 { - r.Fatalf("expected 1 doc, got %d", len(docs)) - return - } - got := docs[0].Data() - if diff := testutil.Diff(got, want); diff != "" { - t.Errorf("got: %v, want: %v, diff +want -got: %s", got, want, diff) - } - } else if lastStageName == stageNameWhere { // This is a where query (filter condition) - want, ok := test.want.([]map[string]interface{}) - if !ok { - r.Fatalf("invalid test.want type for where query: %T", test.want) - return - } - if len(docs) != len(want) { - r.Fatalf("expected %d doc(s), got %d", len(want), len(docs)) - return - } - var gots []map[string]interface{} - for _, doc := range docs { - got := doc.Data() - gots = append(gots, got) - } - if diff := testutil.Diff(gots, want); diff != "" { - t.Errorf("got: %v, want: %v, diff +want -got: %s", gots, want, diff) - } - } else { - r.Fatalf("unknown pipeline stage: %s", lastStageName) + got := docs[0].Data() + if diff := testutil.Diff(got, want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", got, want, diff) + } + } else if lastStageName == stageNameWhere { // This is a where query (filter condition) + want, ok := test.want.([]map[string]interface{}) + if !ok { + t.Fatalf("invalid test.want type for where query: %T", test.want) return } - }) + if len(docs) != len(want) { + t.Fatalf("expected %d doc(s), got %d", len(want), len(docs)) + return + } + var gots []map[string]interface{} + for _, doc := range docs { + got := doc.Data() + gots = append(gots, got) + } + if diff := testutil.Diff(gots, want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", gots, want, diff) + } + } else { + t.Fatalf("unknown pipeline stage: %s", lastStageName) + return + } }) } @@ -1082,27 +1264,25 @@ func vectorFuncs(t *testing.T) { ctx := context.Background() for _, test := range tests { t.Run(test.name, func(t *testing.T) { - testutil.Retry(t, 3, time.Second, func(r *testutil.R) { - iter := test.pipeline.Execute(ctx) - defer iter.Stop() + iter := test.pipeline.Execute(ctx) + defer iter.Stop() - docs, err := iter.GetAll() - if isRetryablePipelineExecuteErr(err) { - r.Errorf("GetAll: %v. Retrying....", err) - return - } else if err != nil { - r.Fatalf("GetAll: %v", err) - return - } - if len(docs) != 1 { - r.Fatalf("expected 1 doc, got %d", len(docs)) - return - } - got := docs[0].Data() - if diff := testutil.Diff(got, test.want); diff != "" { - r.Errorf("got: %v, want: %v, diff +want -got: %s", got, test.want, diff) - } - }) + docs, err := iter.GetAll() + if isRetryablePipelineExecuteErr(err) { + t.Errorf("GetAll: %v. Retrying....", err) + return + } else if err != nil { + t.Fatalf("GetAll: %v", err) + return + } + if len(docs) != 1 { + t.Fatalf("expected 1 doc, got %d", len(docs)) + return + } + got := docs[0].Data() + if diff := testutil.Diff(got, test.want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", got, test.want, diff) + } }) } } @@ -1231,6 +1411,45 @@ func timestampFuncs(t *testing.T) { Select(CurrentTimestamp().As("current_timestamp")), want: map[string]interface{}{"current_timestamp": time.Now().Truncate(time.Microsecond)}, }, + { + name: "TimestampTruncate day", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(TimestampTruncate("timestamp", "day").As("timestamp_trunc_day")), + want: map[string]interface{}{"timestamp_trunc_day": time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location()).Truncate(time.Microsecond)}, + }, + { + name: "TimestampTruncate hour", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(TimestampTruncate("timestamp", "hour").As("timestamp_trunc_hour")), + want: map[string]interface{}{"timestamp_trunc_hour": time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), 0, 0, 0, now.Location()).Truncate(time.Microsecond)}, + }, + { + name: "TimestampTruncate minute", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(TimestampTruncate("timestamp", "minute").As("timestamp_trunc_minute")), + want: map[string]interface{}{"timestamp_trunc_minute": time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), now.Minute(), 0, 0, now.Location()).Truncate(time.Microsecond)}, + }, + { + name: "TimestampTruncate second", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(TimestampTruncate("timestamp", "second").As("timestamp_trunc_second")), + want: map[string]interface{}{"timestamp_trunc_second": time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), now.Minute(), now.Second(), 0, now.Location()).Truncate(time.Microsecond)}, + }, + { + name: "TimestampTruncateWithTimezone day", + pipeline: client.Pipeline(). + Collection(coll.ID). + Select(TimestampTruncateWithTimezone("timestamp", "day", "America/New_York").As("timestamp_trunc_day_ny")), + want: map[string]interface{}{"timestamp_trunc_day_ny": func() time.Time { + loc, _ := time.LoadLocation("America/New_York") + nowInLoc := now.In(loc) + return time.Date(nowInLoc.Year(), nowInLoc.Month(), nowInLoc.Day(), 0, 0, 0, 0, loc).Truncate(time.Microsecond) + }()}, + }, } ctx := context.Background() @@ -1567,13 +1786,6 @@ func comparisonFuncs(t *testing.T) { Where(LessThan("a", 2)), want: []map[string]interface{}{doc1want}, }, - { - name: "Equivalent", - pipeline: client.Pipeline(). - Collection(coll.ID). - Where(Equivalent("a", 1)), - want: []map[string]interface{}{doc1want}, - }, } for _, test := range tests { @@ -1604,3 +1816,258 @@ func comparisonFuncs(t *testing.T) { }) } } + +func keyFuncs(t *testing.T) { + t.Parallel() + h := testHelper{t} + client := integrationClient(t) + coll := client.Collection(collectionIDs.New()) + docRef1 := coll.Doc("doc1") + h.mustCreate(docRef1, map[string]interface{}{ + "a": "hello", + "b": "world", + }) + defer deleteDocuments([]*DocumentRef{docRef1}) + + tests := []struct { + name string + pipeline *Pipeline + want map[string]interface{} + }{ + { + name: "CollectionId", + pipeline: client.Pipeline().Collection(coll.ID).Select(GetCollectionID("__name__").As("collectionId")), + want: map[string]interface{}{"collectionId": coll.ID}, + }, + { + name: "DocumentId", + pipeline: client.Pipeline().Collection(coll.ID).Select(GetDocumentID(docRef1).As("documentId")), + want: map[string]interface{}{"documentId": "doc1"}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx := context.Background() + iter := test.pipeline.Execute(ctx) + defer iter.Stop() + + docs, err := iter.GetAll() + if isRetryablePipelineExecuteErr(err) { + t.Errorf("GetAll: %v. Retrying....", err) + return + } else if err != nil { + t.Fatalf("GetAll: %v", err) + return + } + if len(docs) != 1 { + t.Fatalf("expected 1 doc, got %d", len(docs)) + } + got := docs[0].Data() + if diff := testutil.Diff(got, test.want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", got, test.want, diff) + } + }) + } +} + +func generalFuncs(t *testing.T) { + t.Parallel() + h := testHelper{t} + client := integrationClient(t) + coll := client.Collection(collectionIDs.New()) + docRef1 := coll.NewDoc() + h.mustCreate(docRef1, map[string]interface{}{ + "a": "hello", + "b": "world", + }) + defer deleteDocuments([]*DocumentRef{docRef1}) + + tests := []struct { + name string + pipeline *Pipeline + want map[string]interface{} + }{ + { + name: "Length - string literal", + pipeline: client.Pipeline().Collection(coll.ID).Select(Length(ConstantOf("hello")).As("len")), + want: map[string]interface{}{"len": int64(5)}, + }, + { + name: "Length - field", + pipeline: client.Pipeline().Collection(coll.ID).Select(Length("a").As("len")), + want: map[string]interface{}{"len": int64(5)}, + }, + { + name: "Length - field path", + pipeline: client.Pipeline().Collection(coll.ID).Select(Length(FieldPath{"a"}).As("len")), + want: map[string]interface{}{"len": int64(5)}, + }, + { + name: "Reverse - string literal", + pipeline: client.Pipeline().Collection(coll.ID).Select(Reverse(ConstantOf("hello")).As("reverse")), + want: map[string]interface{}{"reverse": "olleh"}, + }, + { + name: "Reverse - field", + pipeline: client.Pipeline().Collection(coll.ID).Select(Reverse("a").As("reverse")), + want: map[string]interface{}{"reverse": "olleh"}, + }, + { + name: "Reverse - field path", + pipeline: client.Pipeline().Collection(coll.ID).Select(Reverse(FieldPath{"a"}).As("reverse")), + want: map[string]interface{}{"reverse": "olleh"}, + }, + { + name: "Concat - two literals", + pipeline: client.Pipeline().Collection(coll.ID).Select(Concat(ConstantOf("hello"), ConstantOf("world")).As("concat")), + want: map[string]interface{}{"concat": "helloworld"}, + }, + { + name: "Concat - literal and field", + pipeline: client.Pipeline().Collection(coll.ID).Select(Concat(ConstantOf("hello"), FieldOf("b")).As("concat")), + want: map[string]interface{}{"concat": "helloworld"}, + }, + { + name: "Concat - two fields", + pipeline: client.Pipeline().Collection(coll.ID).Select(Concat(FieldOf("a"), FieldOf("b")).As("concat")), + want: map[string]interface{}{"concat": "helloworld"}, + }, + { + name: "Concat - field and literal", + pipeline: client.Pipeline().Collection(coll.ID).Select(Concat(FieldOf("a"), ConstantOf("world")).As("concat")), + want: map[string]interface{}{"concat": "helloworld"}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx := context.Background() + iter := test.pipeline.Execute(ctx) + defer iter.Stop() + + docs, err := iter.GetAll() + if isRetryablePipelineExecuteErr(err) { + t.Errorf("GetAll: %v. Retrying....", err) + return + } else if err != nil { + t.Fatalf("GetAll: %v", err) + return + } + if len(docs) != 1 { + t.Fatalf("expected 1 doc, got %d", len(docs)) + } + got := docs[0].Data() + if diff := testutil.Diff(got, test.want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", got, test.want, diff) + } + }) + } +} + +func logicalFuncs(t *testing.T) { + t.Parallel() + h := testHelper{t} + client := integrationClient(t) + coll := client.Collection(collectionIDs.New()) + docRef1 := coll.NewDoc() + h.mustCreate(docRef1, map[string]interface{}{ + "a": 1, + "b": 2, + "c": nil, + "d": true, + "e": false, + }) + defer deleteDocuments([]*DocumentRef{docRef1}) + + tests := []struct { + name string + pipeline *Pipeline + want map[string]interface{} + }{ + { + name: "Conditional - true", + pipeline: client.Pipeline().Collection(coll.ID).Select(Conditional(Equal(ConstantOf(1), ConstantOf(1)), FieldOf("a"), FieldOf("b")).As("result")), + want: map[string]interface{}{"result": int64(1)}, + }, + { + name: "Conditional - false", + pipeline: client.Pipeline().Collection(coll.ID).Select(Conditional(Equal(ConstantOf(1), ConstantOf(0)), FieldOf("a"), FieldOf("b")).As("result")), + want: map[string]interface{}{"result": int64(2)}, + }, + { + name: "Conditional - field true", + pipeline: client.Pipeline().Collection(coll.ID).Select(Conditional(Equal(FieldOf("d"), ConstantOf(true)), FieldOf("a"), FieldOf("b")).As("result")), + want: map[string]interface{}{"result": int64(1)}, + }, + { + name: "Conditional - field false", + pipeline: client.Pipeline().Collection(coll.ID).Select(Conditional(Equal(FieldOf("e"), ConstantOf(true)), FieldOf("a"), FieldOf("b")).As("result")), + want: map[string]interface{}{"result": int64(2)}, + }, + { + name: "LogicalMax", + pipeline: client.Pipeline().Collection(coll.ID).Select(LogicalMaximum(FieldOf("a"), FieldOf("b")).As("max")), + want: map[string]interface{}{"max": int64(2)}, + }, + { + name: "LogicalMin", + pipeline: client.Pipeline().Collection(coll.ID).Select(LogicalMinimum(FieldOf("a"), FieldOf("b")).As("min")), + want: map[string]interface{}{"min": int64(1)}, + }, + { + name: "IfError - no error", + pipeline: client.Pipeline().Collection(coll.ID).Select(IfError(FieldOf("a"), ConstantOf(100)).As("result")), + want: map[string]interface{}{"result": int64(1)}, + }, + { + name: "IfError - error", + pipeline: client.Pipeline().Collection(coll.ID).Select(Divide("a", 0).IfError(ConstantOf("was error")).As("ifError")), + want: map[string]interface{}{"ifError": "was error"}, + }, + { + name: "IfErrorBoolean - no error", + pipeline: client.Pipeline().Collection(coll.ID).Select(IfErrorBoolean(Equal(FieldOf("d"), ConstantOf(true)), Equal(ConstantOf(1), ConstantOf(0))).As("result")), + want: map[string]interface{}{"result": true}, + }, + { + name: "IfErrorBoolean - error", + pipeline: client.Pipeline().Collection(coll.ID).Select(IfErrorBoolean(Equal(FieldOf("x"), ConstantOf(true)), Equal(ConstantOf(1), ConstantOf(0))).As("result")), + want: map[string]interface{}{"result": false}, + }, + { + name: "IfAbsent - not absent", + pipeline: client.Pipeline().Collection(coll.ID).Select(IfAbsent(FieldOf("a"), ConstantOf(100)).As("result")), + want: map[string]interface{}{"result": int64(1)}, + }, + { + name: "IfAbsent - absent", + pipeline: client.Pipeline().Collection(coll.ID).Select(IfAbsent(FieldOf("x"), ConstantOf(100)).As("result")), + want: map[string]interface{}{"result": int64(100)}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx := context.Background() + iter := test.pipeline.Execute(ctx) + defer iter.Stop() + + docs, err := iter.GetAll() + if isRetryablePipelineExecuteErr(err) { + t.Errorf("GetAll: %v. Retrying....", err) + return + } else if err != nil { + t.Fatalf("GetAll: %v", err) + return + } + if len(docs) != 1 { + t.Fatalf("expected 1 doc, got %d", len(docs)) + } + got := docs[0].Data() + if diff := testutil.Diff(got, test.want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", got, test.want, diff) + } + }) + } +} diff --git a/firestore/pipeline_result.go b/firestore/pipeline_result.go index 6e2a3be7706b..a9e5b0aa2501 100644 --- a/firestore/pipeline_result.go +++ b/firestore/pipeline_result.go @@ -222,7 +222,6 @@ func (it *streamPipelineResultIterator) next() (_ *PipelineResult, err error) { } ctx := withRequestParamsHeader(it.ctx, reqParamsHeaderVal(client.path())) - it.streamClient, err = client.c.ExecutePipeline(ctx, req) if err != nil { return nil, err diff --git a/firestore/pipeline_utils.go b/firestore/pipeline_utils.go index 3cf0d0982eac..97aeb105843c 100644 --- a/firestore/pipeline_utils.go +++ b/firestore/pipeline_utils.go @@ -22,6 +22,14 @@ import ( pb "cloud.google.com/go/firestore/apiv1/firestorepb" ) +func toArrayOfExprOrConstant(val []any) []Expr { + exprs := make([]Expr, 0, len(val)) + for _, v := range val { + exprs = append(exprs, toExprOrConstant(v)) + } + return exprs +} + // newFieldAndArrayBooleanExpr creates a new BooleanExpr for functions that operate on a field/expression and an array of values. func newFieldAndArrayBooleanExpr(name string, exprOrFieldPath any, values any) BooleanExpr { return &baseBooleanExpr{baseFunction: newBaseFunction(name, []Expr{asFieldExpr(exprOrFieldPath), asArrayFunctionExpr(values)})}