diff --git a/firestore/pipeline_constant.go b/firestore/pipeline_constant.go index c90338ae508a..74c1f94e642b 100644 --- a/firestore/pipeline_constant.go +++ b/firestore/pipeline_constant.go @@ -47,7 +47,7 @@ func ConstantOf(value any) Expr { // Handle known scalar types switch value.(type) { - case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, float32, float64, time.Time, *ts.Timestamp, []byte, Vector32, Vector64, *latlng.LatLng: + case string, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, float32, float64, time.Time, *ts.Timestamp, []byte, Vector32, Vector64, bool, *latlng.LatLng, *DocumentRef: pbVal, _, err := toProtoValue(reflect.ValueOf(value)) if err != nil { return &constant{baseExpr: &baseExpr{err: err}} diff --git a/firestore/pipeline_expr.go b/firestore/pipeline_expr.go index be50d31ad7b2..25f7a43fac85 100644 --- a/firestore/pipeline_expr.go +++ b/firestore/pipeline_expr.go @@ -61,6 +61,20 @@ type Expr interface { Round() Expr Sqrt() Expr + // Array operations + ArrayContains(value any) BooleanExpr + ArrayContainsAll(values any) BooleanExpr + ArrayContainsAny(values any) BooleanExpr + ArrayLength() Expr + EqualAny(values any) BooleanExpr + NotEqualAny(values any) BooleanExpr + ArrayGet(offset any) Expr + ArrayReverse() Expr + ArrayConcat(otherArrays ...any) Expr + ArraySum() Expr + ArrayMaximum() Expr + ArrayMinimum() Expr + // Timestamp operations TimestampAdd(unit, amount any) Expr TimestampSubtract(unit, amount any) Expr @@ -85,6 +99,35 @@ type Expr interface { Average() AggregateFunction Count() AggregateFunction + // String functions + ByteLength() Expr + CharLength() Expr + EndsWith(suffix any) BooleanExpr + Like(suffix any) BooleanExpr + RegexContains(pattern any) BooleanExpr + RegexMatch(pattern any) BooleanExpr + StartsWith(prefix any) BooleanExpr + StringConcat(otherStrings ...any) Expr + StringContains(substring any) BooleanExpr + StringReverse() Expr + Join(separator any) Expr + Substring(index, offset any) Expr + ToLower() Expr + ToUpper() Expr + Trim() Expr + + // Type functions + IsNaN() BooleanExpr + IsNotNaN() BooleanExpr + IsNull() BooleanExpr + IsNotNull() BooleanExpr + + // Vector functions + CosineDistance(other any) Expr + DotProduct(other any) Expr + EuclideanDistance(other any) Expr + VectorLength() Expr + // Ordering Ascending() Ordering Descending() Ordering @@ -121,6 +164,20 @@ func (b *baseExpr) Pow(other any) Expr { return Pow(b, other) } func (b *baseExpr) Round() Expr { return Round(b) } func (b *baseExpr) Sqrt() Expr { return Sqrt(b) } +// Array functions +func (b *baseExpr) ArrayContains(value any) BooleanExpr { return ArrayContains(b, value) } +func (b *baseExpr) ArrayContainsAll(values any) BooleanExpr { return ArrayContainsAll(b, values) } +func (b *baseExpr) ArrayContainsAny(values any) BooleanExpr { return ArrayContainsAny(b, values) } +func (b *baseExpr) ArrayLength() Expr { return ArrayLength(b) } +func (b *baseExpr) EqualAny(values any) BooleanExpr { return EqualAny(b, values) } +func (b *baseExpr) NotEqualAny(values any) BooleanExpr { return NotEqualAny(b, values) } +func (b *baseExpr) ArrayGet(offset any) Expr { return ArrayGet(b, offset) } +func (b *baseExpr) ArrayReverse() Expr { return ArrayReverse(b) } +func (b *baseExpr) ArrayConcat(otherArrays ...any) Expr { return ArrayConcat(b, otherArrays...) } +func (b *baseExpr) ArraySum() Expr { return ArraySum(b) } +func (b *baseExpr) ArrayMaximum() Expr { return ArrayMaximum(b) } +func (b *baseExpr) ArrayMinimum() Expr { return ArrayMinimum(b) } + // Timestamp functions func (b *baseExpr) TimestampAdd(unit, amount any) Expr { return TimestampAdd(b, unit, amount) } func (b *baseExpr) TimestampSubtract(unit, amount any) Expr { @@ -151,9 +208,39 @@ func (b *baseExpr) CountIf() AggregateFunction { return CountIf(b) } func (b *baseExpr) Maximum() AggregateFunction { return Maximum(b) } func (b *baseExpr) Minimum() AggregateFunction { return Minimum(b) } +// String functions +func (b *baseExpr) ByteLength() Expr { return ByteLength(b) } +func (b *baseExpr) CharLength() Expr { return CharLength(b) } +func (b *baseExpr) EndsWith(suffix any) BooleanExpr { return EndsWith(b, suffix) } +func (b *baseExpr) Like(suffix any) BooleanExpr { return Like(b, suffix) } +func (b *baseExpr) RegexContains(pattern any) BooleanExpr { return RegexContains(b, pattern) } +func (b *baseExpr) RegexMatch(pattern any) BooleanExpr { return RegexMatch(b, pattern) } +func (b *baseExpr) StartsWith(prefix any) BooleanExpr { return StartsWith(b, prefix) } +func (b *baseExpr) StringConcat(otherStrings ...any) Expr { return StringConcat(b, otherStrings...) } +func (b *baseExpr) StringContains(substring any) BooleanExpr { return StringContains(b, substring) } +func (b *baseExpr) StringReverse() Expr { return StringReverse(b) } +func (b *baseExpr) Join(separator any) Expr { return Join(b, separator) } +func (b *baseExpr) Substring(index, offset any) Expr { return Substring(b, index, offset) } +func (b *baseExpr) ToLower() Expr { return ToLower(b) } +func (b *baseExpr) ToUpper() Expr { return ToUpper(b) } +func (b *baseExpr) Trim() Expr { return Trim(b) } + +// Type functions +func (b *baseExpr) IsNaN() BooleanExpr { return IsNaN(b) } +func (b *baseExpr) IsNotNaN() BooleanExpr { return IsNotNaN(b) } +func (b *baseExpr) IsNull() BooleanExpr { return IsNull(b) } +func (b *baseExpr) IsNotNull() BooleanExpr { return IsNotNull(b) } + +// Vector functions +func (b *baseExpr) CosineDistance(other any) Expr { return CosineDistance(b, other) } +func (b *baseExpr) DotProduct(other any) Expr { return DotProduct(b, other) } +func (b *baseExpr) EuclideanDistance(other any) Expr { return EuclideanDistance(b, other) } +func (b *baseExpr) VectorLength() Expr { return VectorLength(b) } + // Ordering func (b *baseExpr) Ascending() Ordering { return Ascending(b) } func (b *baseExpr) Descending() Ordering { return Descending(b) } + func (b *baseExpr) As(alias string) Selectable { return newAliasedExpr(b, alias) } diff --git a/firestore/pipeline_filter_condition.go b/firestore/pipeline_filter_condition.go index f511ee2419c7..618f6b7a75b6 100644 --- a/firestore/pipeline_filter_condition.go +++ b/firestore/pipeline_filter_condition.go @@ -30,6 +30,66 @@ func (b *baseBooleanExpr) isBooleanExpr() {} // Ensure that baseBooleanExpr implements the BooleanExpr interface. var _ BooleanExpr = (*baseBooleanExpr)(nil) +// ArrayContains creates an expression that checks if an array contains a specified element. +// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expr] that evaluates to an array. +// - value is the element to check for. +// +// Example: +// +// // Check if the 'tags' array contains "Go". +// ArrayContains("tags", "Go") +func ArrayContains(exprOrFieldPath any, value any) BooleanExpr { + return &baseBooleanExpr{baseFunction: newBaseFunction("array_contains", []Expr{toExprOrField(exprOrFieldPath), toExprOrConstant(value)})} +} + +// ArrayContainsAll creates an expression that checks if an array contains all of the provided values. +// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expr] that evaluates to an array. +// - values can be an array of values or an expression that evaluates to an array. +// +// Example: +// +// // Check if the 'tags' array contains both "Go" and "Firestore". +// ArrayContainsAll("tags", []string{"Go", "Firestore"}) +func ArrayContainsAll(exprOrFieldPath any, values any) BooleanExpr { + return newFieldAndArrayBooleanExpr("array_contains_all", exprOrFieldPath, values) +} + +// ArrayContainsAny creates an expression that checks if an array contains any of the provided values. +// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expr] that evaluates to an array. +// - values can be an array of values or an expression that evaluates to an array. +// +// Example: +// +// // Check if the 'tags' array contains either "Go" or "Firestore". +// ArrayContainsAny("tags", []string{"Go", "Firestore"}) +func ArrayContainsAny(exprOrFieldPath any, values any) BooleanExpr { + return newFieldAndArrayBooleanExpr("array_contains_any", exprOrFieldPath, values) +} + +// EqualAny creates an expression that checks if a field or expression is equal to any of the provided values. +// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expr]. +// - values can be an array of values or an expression that evaluates to an array. +// +// Example: +// +// // Check if the 'status' field is either "active" or "pending". +// EqualAny("status", []string{"active", "pending"}) +func EqualAny(exprOrFieldPath any, values any) BooleanExpr { + return newFieldAndArrayBooleanExpr("equal_any", exprOrFieldPath, values) +} + +// NotEqualAny creates an expression that checks if a field or expression is not equal to any of the provided values. +// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expr]. +// - values can be an array of values or an expression that evaluates to an array. +// +// Example: +// +// // Check if the 'status' field is not "archived" or "deleted". +// NotEqualAny("status", []string{"archived", "deleted"}) +func NotEqualAny(exprOrFieldPath any, values any) BooleanExpr { + return newFieldAndArrayBooleanExpr("not_equal_any", exprOrFieldPath, values) +} + // Equal creates an expression that checks if field's value or an expression is equal to an expression or a constant value, // returning it as a BooleanExpr. // - left: The field path string, [FieldPath] or [Expr] to compare. @@ -171,3 +231,123 @@ func LessThanOrEqual(left, right any) BooleanExpr { func Equivalent(left, right any) BooleanExpr { return &baseBooleanExpr{baseFunction: leftRightToBaseFunction("equivalent", left, right)} } + +// EndsWith creates an expression that checks if a string field or expression ends with a given suffix. +// - exprOrFieldPath can be a field path string, [FieldPath] or [Expr]. +// - suffix string or [Expr] to check for. +// +// Example: +// +// // Check if the 'filename' field ends with ".go". +// EndsWith("filename", ".go") +func EndsWith(exprOrFieldPath any, suffix any) BooleanExpr { + return &baseBooleanExpr{baseFunction: newBaseFunction("ends_with", []Expr{toExprOrField(exprOrFieldPath), asStringExpr(suffix)})} +} + +// Like creates an expression that performs a case-sensitive wildcard string comparison. +// - exprOrFieldPath can be a field path string, [FieldPath] or [Expr]. +// - pattern string or [Expr] to search for. You can use "%" as a wildcard character. +// +// Example: +// +// // Check if the 'name' field starts with "G". +// Like("name", "G%") +func Like(exprOrFieldPath any, pattern any) BooleanExpr { + return &baseBooleanExpr{baseFunction: newBaseFunction("like", []Expr{toExprOrField(exprOrFieldPath), asStringExpr(pattern)})} +} + +// RegexContains creates an expression that checks if a string contains a match for a regular expression. +// - exprOrFieldPath can be a field path string, [FieldPath] or [Expr]. +// - pattern is the regular expression to search for. +// +// Example: +// +// // Check if the 'email' field contains a gmail address. +// RegexContains("email", "@gmail\\.com$") +func RegexContains(exprOrFieldPath any, pattern any) BooleanExpr { + return &baseBooleanExpr{baseFunction: newBaseFunction("regex_contains", []Expr{toExprOrField(exprOrFieldPath), asStringExpr(pattern)})} +} + +// RegexMatch creates an expression that checks if a string matches a regular expression. +// - exprOrFieldPath can be a field path string, [FieldPath] or [Expr]. +// - pattern is the regular expression to match against. +// +// Example: +// +// // Check if the 'zip_code' field is a 5-digit number. +// RegexMatch("zip_code", "^[0-9]{5}$") +func RegexMatch(exprOrFieldPath any, pattern any) BooleanExpr { + return &baseBooleanExpr{baseFunction: newBaseFunction("regex_match", []Expr{toExprOrField(exprOrFieldPath), asStringExpr(pattern)})} +} + +// StartsWith creates an expression that checks if a string field or expression starts with a given prefix. +// - exprOrFieldPath can be a field path string, [FieldPath] or [Expr]. +// - prefix string or [Expr] to check for. +// +// Example: +// +// // Check if the 'name' field starts with "Mr.". +// StartsWith("name", "Mr.") +func StartsWith(exprOrFieldPath any, prefix any) BooleanExpr { + return &baseBooleanExpr{baseFunction: newBaseFunction("starts_with", []Expr{toExprOrField(exprOrFieldPath), asStringExpr(prefix)})} +} + +// StringContains creates an expression that checks if a string contains a specified substring. +// - exprOrFieldPath can be a field path string, [FieldPath] or [Expr]. +// - substring is the string to search for. +// +// Example: +// +// // Check if the 'description' field contains the word "Firestore". +// StringContains("description", "Firestore") +func StringContains(exprOrFieldPath any, substring any) BooleanExpr { + return &baseBooleanExpr{baseFunction: newBaseFunction("string_contains", []Expr{toExprOrField(exprOrFieldPath), asStringExpr(substring)})} +} + +// IsNaN creates a boolean expression that checks if a field or expression evaluates to NaN. +// +// exprOrFieldPath is the field path string, [FieldPath] or [Expr] that will be evaluated. +// +// Example: +// +// // Check if the 'score' field is NaN. +// IsNaN("score") +func IsNaN(exprOrFieldPath any) BooleanExpr { + return &baseBooleanExpr{baseFunction: newBaseFunction("is_nan", []Expr{toExprOrField(exprOrFieldPath)})} +} + +// IsNotNaN creates a boolean expression that checks if a field or expression does not evaluate to NaN. +// +// exprOrFieldPath is the field path string, [FieldPath] or [Expr] that will be evaluated. +// +// Example: +// +// // Check if the 'score' field is not NaN. +// IsNotNaN("score") +func IsNotNaN(exprOrFieldPath any) BooleanExpr { + return &baseBooleanExpr{baseFunction: newBaseFunction("is_not_nan", []Expr{toExprOrField(exprOrFieldPath)})} +} + +// IsNull creates a boolean expression that checks if a field or expression evaluates to null. +// +// exprOrFieldPath is the field path string, [FieldPath] or [Expr] that will be evaluated. +// +// Example: +// +// // Check if the 'address' field is null. +// IsNull("address") +func IsNull(exprOrFieldPath any) BooleanExpr { + return &baseBooleanExpr{baseFunction: newBaseFunction("is_null", []Expr{toExprOrField(exprOrFieldPath)})} +} + +// IsNotNull creates a boolean expression that checks if a field or expression does not evaluate to null. +// +// exprOrFieldPath is the field path string, [FieldPath] or [Expr] that will be evaluated. +// +// Example: +// +// // Check if the 'address' field is not null. +// IsNotNull("address") +func IsNotNull(exprOrFieldPath any) BooleanExpr { + return &baseBooleanExpr{baseFunction: newBaseFunction("is_not_null", []Expr{toExprOrField(exprOrFieldPath)})} +} diff --git a/firestore/pipeline_function.go b/firestore/pipeline_function.go index ae9e725ce99e..3d095bf69a67 100644 --- a/firestore/pipeline_function.go +++ b/firestore/pipeline_function.go @@ -43,7 +43,7 @@ func newBaseFunction(name string, params []Expr) *baseFunction { paramExpr := asFieldExpr(param) pbVal, err := paramExpr.toProto() if err != nil { - return &baseFunction{baseExpr: &baseExpr{err: fmt.Errorf("error converting arg %d for function %q: %w", i, name, err)}} + return &baseFunction{baseExpr: &baseExpr{err: fmt.Errorf("firestore: error converting arg %d for function %q: %w", i, name, err)}} } argsPbVals = append(argsPbVals, pbVal) } @@ -319,3 +319,252 @@ func UnixSecondsToTimestamp(seconds any) Expr { func CurrentTimestamp() Expr { return newBaseFunction("current_timestamp", []Expr{}) } + +// ArrayLength creates an expression that calculates the length of an array. +// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expr] that evaluates to an array. +// +// Example: +// +// // Get the length of the 'tags' array field. +// ArrayLength("tags") +func ArrayLength(exprOrFieldPath any) Expr { + return newBaseFunction("array_length", []Expr{toExprOrField(exprOrFieldPath)}) +} + +// Array creates an expression that represents a Firestore array. +// - elements can be any number of values or expressions that will form the elements of the array. +// +// Example: +// +// // Create an array of numbers. +// Array(1, 2, 3) +func Array(elements ...any) Expr { + return newBaseFunction("array", toExprs(elements)) +} + +// ArrayFromSlice creates a new array expression from a slice of elements. +// This function is necessary for creating an array from an existing typed slice (e.g., []int), +// as the [Array] function (which takes variadic arguments) cannot directly accept a typed slice +// using the spread operator (...). It handles the conversion of each element to `any` internally. +func ArrayFromSlice[T any](elements []T) Expr { + return newBaseFunction("array", toExprsFromSlice(elements)) +} + +// ArrayGet creates an expression that retrieves an element from an array at a specified index. +// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expr] that evaluates to an array. +// - offset is the 0-based index of the element to retrieve. +// +// Example: +// +// // Get the first element of the 'tags' array field. +// ArrayGet("tags", 0) +func ArrayGet(exprOrFieldPath any, offset any) Expr { + return newBaseFunction("array_get", []Expr{toExprOrField(exprOrFieldPath), asInt64Expr(offset)}) +} + +// ArrayReverse creates an expression that reverses the order of elements in an array. +// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expr] that evaluates to an array. +// +// Example: +// +// // Reverse the 'tags' array. +// ArrayReverse("tags") +func ArrayReverse(exprOrFieldPath any) Expr { + return newBaseFunction("array_reverse", []Expr{toExprOrField(exprOrFieldPath)}) +} + +// ArrayConcat creates an expression that concatenates multiple arrays into a single array. +// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expr] that evaluates to an array. +// - otherArrays are the other arrays to concatenate. +// +// Example: +// +// // Concatenate the 'tags' and 'categories' array fields. +// ArrayConcat("tags", FieldOf("categories")) +func ArrayConcat(exprOrFieldPath any, otherArrays ...any) Expr { + return newBaseFunction("array_concat", append([]Expr{toExprOrField(exprOrFieldPath)}, toExprs(otherArrays)...)) +} + +// ArraySum creates an expression that calculates the sum of all elements in a numeric array. +// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expr] that evaluates to a numeric array. +// +// Example: +// +// // Calculate the sum of the 'scores' array. +// ArraySum("scores") +func ArraySum(exprOrFieldPath any) Expr { + return newBaseFunction("sum", []Expr{toExprOrField(exprOrFieldPath)}) +} + +// ArrayMaximum creates an expression that finds the maximum element in a numeric array. +// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expr] that evaluates to a numeric array. +// +// Example: +// +// // Find the maximum value in the 'scores' array. +// ArrayMaximum("scores") +func ArrayMaximum(exprOrFieldPath any) Expr { + return newBaseFunction("maximum", []Expr{toExprOrField(exprOrFieldPath)}) +} + +// ArrayMinimum creates an expression that finds the minimum element in a numeric array. +// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expr] that evaluates to a numeric array. +// +// Example: +// +// // Find the minimum value in the 'scores' array. +// ArrayMinimum("scores") +func ArrayMinimum(exprOrFieldPath any) Expr { + return newBaseFunction("minimum", []Expr{toExprOrField(exprOrFieldPath)}) +} + +// ByteLength creates an expression that calculates the length of a string represented by a field or [Expr] in UTF-8 +// bytes, or just the length of a Blob. +// - exprOrFieldPath can be a field path string, [FieldPath] or [Expr]. +// +// Example: +// +// // Get the byte length of the 'name' field. +// ByteLength("name") +func ByteLength(exprOrFieldPath any) Expr { + return newBaseFunction("byte_length", []Expr{toExprOrField(exprOrFieldPath)}) +} + +// CharLength creates an expression that calculates the character length of a string field or expression in UTF8. +// - exprOrFieldPath can be a field path string, [FieldPath] or [Expr]. +// +// Example: +// +// // Get the character length of the 'name' field. +// CharLength("name") +func CharLength(exprOrFieldPath any) Expr { + return newBaseFunction("char_length", []Expr{toExprOrField(exprOrFieldPath)}) +} + +// StringConcat creates an expression that concatenates multiple strings into a single string. +// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expr] that evaluates to a string. +// - otherStrings are the other strings to concatenate. +// +// Example: +// +// // Concatenate first name and last name. +// StringConcat(FieldOf("firstName"), " ", FieldOf("lastName")) +func StringConcat(exprOrFieldPath any, otherStrings ...any) Expr { + return newBaseFunction("string_concat", append([]Expr{toExprOrField(exprOrFieldPath)}, toExprs(otherStrings)...)) +} + +// StringReverse creates an expression that reverses a string. +// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expr] that evaluates to a string. +// +// Example: +// +// // Reverse the 'name' field. +// StringReverse("name") +func StringReverse(exprOrFieldPath any) Expr { + return newBaseFunction("string_reverse", []Expr{toExprOrField(exprOrFieldPath)}) +} + +// Join creates an expression that joins the elements of a string array into a single string. +// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expr] that evaluates to a string array. +// - separator is the string to use as a separator between elements. +// +// Example: +// +// // Join the 'tags' array with a comma and space. +// Join("tags", ", ") +func Join(exprOrFieldPath any, separator any) Expr { + return newBaseFunction("join", []Expr{toExprOrField(exprOrFieldPath), asStringExpr(separator)}) +} + +// Substring creates an expression that returns a substring of a string. +// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expr] that evaluates to a string. +// - index is the starting index of the substring. +// - offset is the length of the substring. +// +// Example: +// +// // Get the first 5 characters of the 'description' field. +// Substring("description", 0, 5) +func Substring(exprOrFieldPath any, index any, offset any) Expr { + return newBaseFunction("substring", []Expr{toExprOrField(exprOrFieldPath), asInt64Expr(index), asInt64Expr(offset)}) +} + +// ToLower creates an expression that converts a string to lowercase. +// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expr] that evaluates to a string. +// +// Example: +// +// // Convert the 'username' to lowercase. +// ToLower("username") +func ToLower(exprOrFieldPath any) Expr { + return newBaseFunction("to_lower", []Expr{toExprOrField(exprOrFieldPath)}) +} + +// ToUpper creates an expression that converts a string to uppercase. +// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expr] that evaluates to a string. +// +// Example: +// +// // Convert the 'product_code' to uppercase. +// ToUpper("product_code") +func ToUpper(exprOrFieldPath any) Expr { + return newBaseFunction("to_upper", []Expr{toExprOrField(exprOrFieldPath)}) +} + +// Trim creates an expression that removes leading and trailing whitespace from a string. +// - exprOrFieldPath can be a field path string, [FieldPath] or an [Expr] that evaluates to a string. +// +// Example: +// +// // Trim the 'email' field. +// Trim("email") +func Trim(exprOrFieldPath any) Expr { + return newBaseFunction("trim", []Expr{toExprOrField(exprOrFieldPath)}) +} + +// CosineDistance creates an expression that calculates the cosine distance between two vectors. +// - vector1 can be a field path string, [FieldPath] or [Expr]. +// - vector2 can be [Vector32], [Vector64], []float32, []float64 or [Expr]. +// +// Example: +// +// // Calculate the cosine distance between two vector fields. +// CosineDistance("vector_field_1", FieldOf("vector_field_2")) +func CosineDistance(vector1 any, vector2 any) Expr { + return newBaseFunction("cosine_distance", []Expr{toExprOrField(vector1), asVectorExpr(vector2)}) +} + +// DotProduct creates an expression that calculates the dot product of two vectors. +// - vector1 can be a field path string, [FieldPath] or [Expr]. +// - vector2 can be [Vector32], [Vector64], []float32, []float64 or [Expr]. +// +// Example: +// +// // Calculate the dot product of two vector fields. +// DotProduct("vector_field_1", FieldOf("vector_field_2")) +func DotProduct(vector1 any, vector2 any) Expr { + return newBaseFunction("dot_product", []Expr{toExprOrField(vector1), asVectorExpr(vector2)}) +} + +// EuclideanDistance creates an expression that calculates the euclidean distance between two vectors. +// - vector1 can be a field path string, [FieldPath] or [Expr]. +// - vector2 can be [Vector32], [Vector64], []float32, []float64 or [Expr]. +// +// Example: +// +// // Calculate the euclidean distance between two vector fields. +// EuclideanDistance("vector_field_1", FieldOf("vector_field_2")) +func EuclideanDistance(vector1 any, vector2 any) Expr { + return newBaseFunction("euclidean_distance", []Expr{toExprOrField(vector1), asVectorExpr(vector2)}) +} + +// VectorLength creates an expression that calculates the length of a vector. +// - exprOrFieldPath can be a field path string, [FieldPath] or [Expr]. +// +// Example: +// +// // Calculate the length of a vector field. +// VectorLength("vector_field") +func VectorLength(exprOrFieldPath any) Expr { + return newBaseFunction("vector_length", []Expr{toExprOrField(exprOrFieldPath)}) +} diff --git a/firestore/pipeline_integration_test.go b/firestore/pipeline_integration_test.go new file mode 100644 index 000000000000..b8f9695a0235 --- /dev/null +++ b/firestore/pipeline_integration_test.go @@ -0,0 +1,543 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package firestore + +import ( + "context" + "math" + "strings" + "testing" + "time" + + "cloud.google.com/go/internal/testutil" + "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestIntegration_PipelineFunctions(t *testing.T) { + if testParams[firestoreEditionKey].(firestoreEdition) != editionEnterprise { + t.Skip("Skipping pipeline queries tests since the firestore edition of", testParams[databaseIDKey].(string), "database is not enterprise") + } + t.Run("arrayFuncs", arrayFuncs) + t.Run("stringFuncs", stringFuncs) + t.Run("typeFuncs", typeFuncs) + t.Run("vectorFuncs", vectorFuncs) + +} + +func arrayFuncs(t *testing.T) { + t.Parallel() + h := testHelper{t} + client := integrationClient(t) + coll := client.Collection(collectionIDs.New()) + docRef1 := coll.NewDoc() + h.mustCreate(docRef1, map[string]interface{}{ + "a": []interface{}{1, 2, 3}, + "b": []interface{}{4, 5, 6}, + "tags": []string{"Go", "Firestore", "GCP"}, + "tags2": []string{"Go", "Firestore"}, + "lang": "Go", + "status": "active", + }) + defer deleteDocuments([]*DocumentRef{docRef1}) + + tests := []struct { + name string + pipeline *Pipeline + want map[string]interface{} + }{ + { + name: "ArrayLength", + pipeline: client.Pipeline().Collection(coll.ID).Select(ArrayLength("a").As("length")), + want: map[string]interface{}{"length": int64(3)}, + }, + { + name: "Array", + pipeline: client.Pipeline().Collection(coll.ID).Select(Array(1, 2, 3).As("array")), + want: map[string]interface{}{"array": []interface{}{int64(1), int64(2), int64(3)}}, + }, + { + name: "ArrayFromSlice", + pipeline: client.Pipeline().Collection(coll.ID).Select(ArrayFromSlice([]int{1, 2, 3}).As("array")), + want: map[string]interface{}{"array": []interface{}{int64(1), int64(2), int64(3)}}, + }, + { + name: "ArrayGet", + pipeline: client.Pipeline().Collection(coll.ID).Select(ArrayGet("a", 1).As("element")), + want: map[string]interface{}{"element": int64(2)}, + }, + { + name: "ArrayReverse", + pipeline: client.Pipeline().Collection(coll.ID).Select(ArrayReverse("a").As("reversed")), + want: map[string]interface{}{"reversed": []interface{}{int64(3), int64(2), int64(1)}}, + }, + { + name: "ArrayConcat", + pipeline: client.Pipeline().Collection(coll.ID).Select(ArrayConcat("a", FieldOf("b")).As("concatenated")), + want: map[string]interface{}{"concatenated": []interface{}{int64(1), int64(2), int64(3), int64(4), int64(5), int64(6)}}, + }, + { + name: "ArraySum", + pipeline: client.Pipeline().Collection(coll.ID).Select(ArraySum("a").As("sum")), + want: map[string]interface{}{"sum": int64(6)}, + }, + { + name: "ArrayMaximum", + pipeline: client.Pipeline().Collection(coll.ID).Select(ArrayMaximum("a").As("max")), + want: map[string]interface{}{"max": int64(3)}, + }, + { + name: "ArrayMinimum", + pipeline: client.Pipeline().Collection(coll.ID).Select(ArrayMinimum("a").As("min")), + want: map[string]interface{}{"min": int64(1)}, + }, + // Array filter conditions + { + name: "ArrayContains", + pipeline: client.Pipeline().Collection(coll.ID).Where(ArrayContains("tags", "Go")), + want: map[string]interface{}{"lang": "Go", "tags": []interface{}{"Go", "Firestore", "GCP"}, "tags2": []interface{}{"Go", "Firestore"}, "status": "active", "a": []interface{}{int64(1), int64(2), int64(3)}, "b": []interface{}{int64(4), int64(5), int64(6)}}, + }, + { + name: "ArrayContainsAll - array of mixed types", + pipeline: client.Pipeline().Collection(coll.ID).Where(ArrayContainsAll("tags", []any{FieldOf("lang"), "Firestore"})), + want: map[string]interface{}{"lang": "Go", "tags": []interface{}{"Go", "Firestore", "GCP"}, "tags2": []interface{}{"Go", "Firestore"}, "status": "active", "a": []interface{}{int64(1), int64(2), int64(3)}, "b": []interface{}{int64(4), int64(5), int64(6)}}, + }, + { + name: "ArrayContainsAll - array of constants", + pipeline: client.Pipeline().Collection(coll.ID).Where(ArrayContainsAll("tags", []string{"Go", "Firestore"})), + want: map[string]interface{}{"lang": "Go", "tags": []interface{}{"Go", "Firestore", "GCP"}, "tags2": []interface{}{"Go", "Firestore"}, "status": "active", "a": []interface{}{int64(1), int64(2), int64(3)}, "b": []interface{}{int64(4), int64(5), int64(6)}}, + }, + { + name: "ArrayContainsAll - Expr", + pipeline: client.Pipeline().Collection(coll.ID).Where(ArrayContainsAll("tags", FieldOf("tags2"))), + want: map[string]interface{}{"lang": "Go", "tags": []interface{}{"Go", "Firestore", "GCP"}, "tags2": []interface{}{"Go", "Firestore"}, "status": "active", "a": []interface{}{int64(1), int64(2), int64(3)}, "b": []interface{}{int64(4), int64(5), int64(6)}}, + }, + { + name: "ArrayContainsAny", + pipeline: client.Pipeline().Collection(coll.ID).Where(ArrayContainsAny("tags", []string{"Go", "Java"})), + want: map[string]interface{}{"lang": "Go", "tags": []interface{}{"Go", "Firestore", "GCP"}, "tags2": []interface{}{"Go", "Firestore"}, "status": "active", "a": []interface{}{int64(1), int64(2), int64(3)}, "b": []interface{}{int64(4), int64(5), int64(6)}}, + }, + { + name: "EqualAny", + pipeline: client.Pipeline().Collection(coll.ID).Where(EqualAny("status", []string{"active", "pending"})), + want: map[string]interface{}{"lang": "Go", "tags": []interface{}{"Go", "Firestore", "GCP"}, "tags2": []interface{}{"Go", "Firestore"}, "status": "active", "a": []interface{}{int64(1), int64(2), int64(3)}, "b": []interface{}{int64(4), int64(5), int64(6)}}, + }, + { + name: "NotEqualAny", + pipeline: client.Pipeline().Collection(coll.ID).Where(NotEqualAny("status", []string{"archived", "deleted"})), + want: map[string]interface{}{"lang": "Go", "tags": []interface{}{"Go", "Firestore", "GCP"}, "tags2": []interface{}{"Go", "Firestore"}, "status": "active", "a": []interface{}{int64(1), int64(2), int64(3)}, "b": []interface{}{int64(4), int64(5), int64(6)}}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + testutil.Retry(t, 3, time.Second, func(r *testutil.R) { + ctx := context.Background() + iter := test.pipeline.Execute(ctx) + defer iter.Stop() + + docs, err := iter.GetAll() + if isRetryablePipelineExecuteErr(err) { + r.Errorf("GetAll: %v. Retrying....", err) + return + } else if err != nil { + r.Fatalf("GetAll: %v", err) + return + } + if len(docs) != 1 { + r.Fatalf("expected 1 doc, got %d", len(docs)) + return + } + got, err := docs[0].Data() + if err != nil { + r.Fatalf("Data: %v", err) + return + } + if diff := testutil.Diff(got, test.want); diff != "" { + r.Errorf("got: %v, want: %v, diff +want -got: %s", got, test.want, diff) + return + } + }) + }) + } +} + +func stringFuncs(t *testing.T) { + t.Parallel() + h := testHelper{t} + client := integrationClient(t) + coll := client.Collection(collectionIDs.New()) + docRef1 := coll.NewDoc() + h.mustCreate(docRef1, map[string]interface{}{ + "name": " John Doe ", + "description": "This is a Firestore document.", + "productCode": "abc-123", + "tags": []string{"tag1", "tag2", "tag3"}, + "email": "john.doe@example.com", + "zipCode": "12345", + }) + defer deleteDocuments([]*DocumentRef{docRef1}) + + doc1want := map[string]interface{}{ + "name": " John Doe ", + "description": "This is a Firestore document.", + "productCode": "abc-123", + "tags": []interface{}{"tag1", "tag2", "tag3"}, + "email": "john.doe@example.com", + "zipCode": "12345", + } + + tests := []struct { + name string + pipeline *Pipeline + want interface{} + }{ + { + name: "ByteLength", + pipeline: client.Pipeline().Collection(coll.ID).Select(ByteLength("name").As("byte_length")), + want: map[string]interface{}{"byte_length": int64(12)}, + }, + { + name: "CharLength", + pipeline: client.Pipeline().Collection(coll.ID).Select(CharLength("name").As("char_length")), + want: map[string]interface{}{"char_length": int64(12)}, + }, + { + name: "StringConcat", + pipeline: client.Pipeline().Collection(coll.ID).Select(StringConcat(FieldOf("name"), " - ", FieldOf("productCode")).As("concatenated_string")), + want: map[string]interface{}{"concatenated_string": " John Doe - abc-123"}, + }, + { + name: "StringReverse", + pipeline: client.Pipeline().Collection(coll.ID).Select(StringReverse("name").As("reversed_string")), + want: map[string]interface{}{"reversed_string": " eoD nhoJ "}, + }, + { + name: "Join", + pipeline: client.Pipeline().Collection(coll.ID).Select(Join("tags", ", ").As("joined_string")), + want: map[string]interface{}{"joined_string": "tag1, tag2, tag3"}, + }, + { + name: "Substring", + pipeline: client.Pipeline().Collection(coll.ID).Select(Substring("description", 0, 4).As("substring")), + want: map[string]interface{}{"substring": "This"}, + }, + { + name: "ToLower", + pipeline: client.Pipeline().Collection(coll.ID).Select(ToLower("name").As("lowercase_name")), + want: map[string]interface{}{"lowercase_name": " john doe "}, + }, + { + name: "ToUpper", + pipeline: client.Pipeline().Collection(coll.ID).Select(ToUpper("name").As("uppercase_name")), + want: map[string]interface{}{"uppercase_name": " JOHN DOE "}, + }, + { + name: "Trim", + pipeline: client.Pipeline().Collection(coll.ID).Select(Trim("name").As("trimmed_name")), + want: map[string]interface{}{"trimmed_name": "John Doe"}, + }, + // String filter conditions + { + name: "Like", + pipeline: client.Pipeline().Collection(coll.ID).Where(Like("name", "%John%")), + want: []map[string]interface{}{doc1want}, + }, + { + name: "StartsWith", + pipeline: client.Pipeline().Collection(coll.ID).Where(StartsWith("name", " John")), + want: []map[string]interface{}{doc1want}, + }, + { + name: "EndsWith", + pipeline: client.Pipeline().Collection(coll.ID).Where(EndsWith("name", "Doe ")), + want: []map[string]interface{}{doc1want}, + }, + { + name: "RegexContains", + pipeline: client.Pipeline().Collection(coll.ID).Where(RegexContains("email", "@example\\.com")), + want: []map[string]interface{}{doc1want}, + }, + { + name: "RegexMatch", + pipeline: client.Pipeline().Collection(coll.ID).Where(RegexMatch("zipCode", "^[0-9]{5}$")), + want: []map[string]interface{}{doc1want}, + }, + { + name: "StringContains", + pipeline: client.Pipeline().Collection(coll.ID).Where(StringContains("description", "Firestore")), + want: []map[string]interface{}{doc1want}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + testutil.Retry(t, 3, time.Second, func(r *testutil.R) { + ctx := context.Background() + + iter := test.pipeline.Execute(ctx) + defer iter.Stop() + + docs, err := iter.GetAll() + if isRetryablePipelineExecuteErr(err) { + r.Errorf("GetAll: %v. Retrying....", err) + return + } else if err != nil { + r.Fatalf("GetAll: %v", err) + return + } + lastStage := test.pipeline.stages[len(test.pipeline.stages)-1] + lastStageName := lastStage.name() + + if lastStageName == stageNameSelect { // This is a select query + want, ok := test.want.(map[string]interface{}) + if !ok { + r.Fatalf("invalid test.want type for select query: %T", test.want) + return + } + if len(docs) != 1 { + r.Fatalf("expected 1 doc, got %d", len(docs)) + return + } + got, err := docs[0].Data() + if err != nil { + r.Fatalf("Data: %v", err) + return + } + if diff := testutil.Diff(got, want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", got, want, diff) + } + } else if lastStageName == stageNameWhere { // This is a where query (filter condition) + want, ok := test.want.([]map[string]interface{}) + if !ok { + r.Fatalf("invalid test.want type for where query: %T", test.want) + return + } + if len(docs) != len(want) { + r.Fatalf("expected %d doc(s), got %d", len(want), len(docs)) + return + } + var gots []map[string]interface{} + for _, doc := range docs { + got, err := doc.Data() + if err != nil { + r.Fatalf("Data: %v", err) + return + } + gots = append(gots, got) + } + if diff := testutil.Diff(gots, want); diff != "" { + t.Errorf("got: %v, want: %v, diff +want -got: %s", gots, want, diff) + } + } else { + r.Fatalf("unknown pipeline stage: %s", lastStageName) + return + } + }) + }) + } + +} + +func typeFuncs(t *testing.T) { + t.Parallel() + ctx := context.Background() + client := integrationClient(t) + coll := client.Collection(collectionIDs.New()) + docWithNaN := map[string]interface{}{ + "docID": 1, + "value": math.NaN(), + "type": "nan", + } + _, err := coll.Doc("docNaN").Create(ctx, docWithNaN) + if err != nil { + t.Fatalf("Create: %v", err) + } + + docWithNull := map[string]interface{}{ + "docID": 2, + "value": nil, + "type": "null", + } + _, err = coll.Doc("docNull").Create(ctx, docWithNull) + if err != nil { + t.Fatalf("Create: %v", err) + } + + docWithNumber := map[string]interface{}{ + "docID": 3, + "value": 123, + "type": "number", + } + _, err = coll.Doc("docNum").Create(ctx, docWithNumber) + if err != nil { + t.Fatalf("Create: %v", err) + } + defer deleteDocuments([]*DocumentRef{coll.Doc("docNaN"), coll.Doc("docNull"), coll.Doc("docNum")}) + + wantNaN := map[string]interface{}{"docID": int64(1), "value": math.NaN(), "type": "nan"} + wantNull := map[string]interface{}{"docID": int64(2), "value": nil, "type": "null"} + wantNum := map[string]interface{}{"docID": int64(3), "value": int64(123), "type": "number"} + + tests := []struct { + name string + pipeline *Pipeline + want []map[string]interface{} + }{ + { + name: "IsNull", + pipeline: client.Pipeline().Collection(coll.ID).Where(IsNull("value")), + want: []map[string]interface{}{wantNull}, + }, + { + name: "IsNotNull", + pipeline: client.Pipeline().Collection(coll.ID).Where(IsNotNull("value")), + want: []map[string]interface{}{wantNaN, wantNum}, + }, + { + name: "IsNaN", + pipeline: client.Pipeline().Collection(coll.ID).Where(IsNaN("value")), + want: []map[string]interface{}{wantNaN}, + }, + { + name: "IsNotNaN", + pipeline: client.Pipeline().Collection(coll.ID).Where(IsNotNaN("value")), + want: []map[string]interface{}{wantNum}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + testutil.Retry(t, 3, time.Second, func(r *testutil.R) { + ctx := context.Background() + iter := test.pipeline.Execute(ctx) + defer iter.Stop() + + docs, err := iter.GetAll() + if isRetryablePipelineExecuteErr(err) { + r.Errorf("GetAll: %v. Retrying....", err) + return + } else if err != nil { + r.Fatalf("GetAll: %v", err) + return + } + if diff := testutil.Diff(docsToMaps(t, docs), test.want, + cmpopts.SortSlices(func(a, b map[string]interface{}) bool { return a["docID"].(int64) < b["docID"].(int64) })); diff != "" { + r.Errorf("mismatch (+want -got):\n%s", diff) + } + }) + }) + } +} + +func vectorFuncs(t *testing.T) { + t.Parallel() + h := testHelper{t} + client := integrationClient(t) + coll := client.Collection(collectionIDs.New()) + docRef1 := coll.NewDoc() + h.mustCreate(docRef1, map[string]interface{}{ + "v1": Vector64{1.0, 2.0, 3.0}, + "v2": Vector64{4.0, 5.0, 6.0}, + }) + defer deleteDocuments([]*DocumentRef{docRef1}) + + tests := []struct { + name string + pipeline *Pipeline + want map[string]interface{} + }{ + { + name: "VectorLength", + pipeline: client.Pipeline().Collection(coll.ID).Select(VectorLength("v1").As("length")), + want: map[string]interface{}{"length": int64(3)}, + }, + { + name: "DotProduct - field and field", + pipeline: client.Pipeline().Collection(coll.ID).Select(DotProduct("v1", FieldOf("v2")).As("dot_product")), + want: map[string]interface{}{"dot_product": float64(1*4 + 2*5 + 3*6)}, + }, + { + name: "DotProduct - field and constant", + pipeline: client.Pipeline().Collection(coll.ID).Select(DotProduct("v1", Vector64{4.0, 5.0, 6.0}).As("dot_product")), + want: map[string]interface{}{"dot_product": float64(1*4 + 2*5 + 3*6)}, + }, + { + name: "EuclideanDistance - field and field", + pipeline: client.Pipeline().Collection(coll.ID).Select(EuclideanDistance("v1", FieldOf("v2")).As("euclidean")), + want: map[string]interface{}{"euclidean": math.Sqrt(math.Pow(4-1, 2) + math.Pow(5-2, 2) + math.Pow(6-3, 2))}, + }, + { + name: "EuclideanDistance - field and constant", + pipeline: client.Pipeline().Collection(coll.ID).Select(EuclideanDistance("v1", Vector64{4.0, 5.0, 6.0}).As("euclidean")), + want: map[string]interface{}{"euclidean": math.Sqrt(math.Pow(4-1, 2) + math.Pow(5-2, 2) + math.Pow(6-3, 2))}, + }, + { + name: "CosineDistance - field and field", + pipeline: client.Pipeline().Collection(coll.ID).Select(CosineDistance("v1", FieldOf("v2")).As("cosine")), + want: map[string]interface{}{"cosine": 1 - (32 / (math.Sqrt(14) * math.Sqrt(77)))}, + }, + { + name: "CosineDistance - field and constant", + pipeline: client.Pipeline().Collection(coll.ID).Select(CosineDistance("v1", Vector64{4.0, 5.0, 6.0}).As("cosine")), + want: map[string]interface{}{"cosine": 1 - (32 / (math.Sqrt(14) * math.Sqrt(77)))}, + }, + } + + ctx := context.Background() + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + testutil.Retry(t, 3, time.Second, func(r *testutil.R) { + iter := test.pipeline.Execute(ctx) + defer iter.Stop() + + docs, err := iter.GetAll() + if isRetryablePipelineExecuteErr(err) { + r.Errorf("GetAll: %v. Retrying....", err) + return + } else if err != nil { + r.Fatalf("GetAll: %v", err) + return + } + if len(docs) != 1 { + r.Fatalf("expected 1 doc, got %d", len(docs)) + return + } + got, err := docs[0].Data() + if err != nil { + r.Fatalf("Data: %v", err) + return + } + if diff := testutil.Diff(got, test.want); diff != "" { + r.Errorf("got: %v, want: %v, diff +want -got: %s", got, test.want, diff) + } + }) + }) + } +} + +func isRetryablePipelineExecuteErr(err error) bool { + if err == nil { + return false + } + s, ok := status.FromError(err) + if !ok { + return false + } + return s.Code() == codes.InvalidArgument && + strings.Contains(s.Message(), "Invalid request routing header") && + strings.Contains(s.Message(), "Please fill in the request header with format") +} diff --git a/firestore/pipeline_utils.go b/firestore/pipeline_utils.go index 9999d4829047..2ced52e8c79e 100644 --- a/firestore/pipeline_utils.go +++ b/firestore/pipeline_utils.go @@ -17,10 +17,90 @@ package firestore import ( "errors" "fmt" + "reflect" pb "cloud.google.com/go/firestore/apiv1/firestorepb" ) +// newFieldAndArrayBooleanExpr creates a new BooleanExpr for functions that operate on a field/expression and an array of values. +func newFieldAndArrayBooleanExpr(name string, exprOrFieldPath any, values any) BooleanExpr { + return &baseBooleanExpr{baseFunction: newBaseFunction(name, []Expr{toExprOrField(exprOrFieldPath), asArrayFunctionExpr(values)})} +} + +// toExprs converts a plain Go value or an existing Expr into an Expr. +// Plain values are wrapped in a Constant. +func toExprs(val []any) []Expr { + exprs := make([]Expr, len(val)) + for i, v := range val { + exprs[i] = toExprOrConstant(v) + } + return exprs +} + +// toExprsFromSlice converts a slice of any type into a slice of Expr, wrapping plain values in Constants. +func toExprsFromSlice[T any](val []T) []Expr { + exprs := make([]Expr, len(val)) + for i, v := range val { + exprs[i] = toExprOrConstant(v) + } + return exprs +} + +// val should be single Expr or array of Expr/constants +func asArrayFunctionExpr(val any) Expr { + if expr, ok := val.(Expr); ok { + return expr + } + + arrayVal := reflect.ValueOf(val) + if arrayVal.Kind() != reflect.Slice { + return &baseExpr{err: fmt.Errorf("firestore: value must be a slice or Expr, but got %T", val)} + } + + // Convert the slice of any to []Expr + var exprs []Expr + for i := 0; i < arrayVal.Len(); i++ { + exprs = append(exprs, toExprOrConstant(arrayVal.Index(i).Interface())) + } + return newBaseFunction("array", exprs) +} + +// asInt64Expr converts a value to an Expr that evaluates to an int64, or returns an error Expr if conversion is not possible. +func asInt64Expr(val any) Expr { + switch v := val.(type) { + case Expr: + return v + case int, int8, int16, int32, int64, uint8, uint16, uint32: + return ConstantOf(v) + default: + return &baseExpr{err: fmt.Errorf("firestore: value must be a int, int8, int16, int32, int64, uint8, uint16, uint32 or Expr, but got %T", val)} + } +} + +// asStringExpr converts a value to an Expr that evaluates to a string, or returns an error Expr if conversion is not possible. +func asStringExpr(val any) Expr { + switch v := val.(type) { + case Expr: + return v + case string: + return ConstantOf(v) + default: + return &baseExpr{err: fmt.Errorf("firestore: value must be a string or Expr, but got %T", val)} + } +} + +// asVectorExpr converts a value to an Expr that evaluates to a vector type (Vector32, Vector64, []float32, []float64), or returns an error Expr if conversion is not possible. +func asVectorExpr(val any) Expr { + switch v := val.(type) { + case Expr: + return v + case Vector32, Vector64, []float32, []float64: + return ConstantOf(v) + default: + return &baseExpr{err: fmt.Errorf("firestore: value must be a []float32, []float64, Vector32, Vector64 or Expr, but got %T", val)} + } +} + // toExprOrConstant converts a plain Go value or an existing Expr into an Expr. // Plain values are wrapped in a Constant. func toExprOrConstant(val any) Expr { @@ -109,7 +189,7 @@ func aliasedAggregatesToMapValue(aggregates []*AliasedAggregate) (*pb.Value, err base := agg.getBaseAggregateFunction() if base.err != nil { - return nil, fmt.Errorf("error in aggregate expression for alias %q: %w", agg.alias, base.err) + return nil, fmt.Errorf("firestore: error in aggregate expression for alias %q: %w", agg.alias, base.err) } protoVal, err := base.toProto() if err != nil { diff --git a/firestore/util_test.go b/firestore/util_test.go index 825b8db92440..aeb68badb8c9 100644 --- a/firestore/util_test.go +++ b/firestore/util_test.go @@ -143,3 +143,15 @@ func mapval(m map[string]*pb.Value) *pb.Value { func refval(path string) *pb.Value { return &pb.Value{ValueType: &pb.Value_ReferenceValue{ReferenceValue: path}} } + +func docsToMaps(t *testing.T, docs []*PipelineResult) []map[string]interface{} { + var maps []map[string]interface{} + for _, doc := range docs { + data, err := doc.Data() + if err != nil { + t.Fatalf("Data: %v", err) + } + maps = append(maps, data) + } + return maps +}