diff --git a/google/cloud/firestore_v1/_pipeline_stages.py b/google/cloud/firestore_v1/_pipeline_stages.py index 7233a8eec..c63b748ac 100644 --- a/google/cloud/firestore_v1/_pipeline_stages.py +++ b/google/cloud/firestore_v1/_pipeline_stages.py @@ -274,13 +274,17 @@ def __init__( self, field: str | Expr, vector: Sequence[float] | Vector, - distance_measure: "DistanceMeasure", + distance_measure: "DistanceMeasure" | str, options: Optional["FindNearestOptions"] = None, ): super().__init__("find_nearest") self.field: Expr = Field(field) if isinstance(field, str) else field self.vector: Vector = vector if isinstance(vector, Vector) else Vector(vector) - self.distance_measure = distance_measure + self.distance_measure = ( + distance_measure + if isinstance(distance_measure, DistanceMeasure) + else DistanceMeasure[distance_measure.upper()] + ) self.options = options or FindNearestOptions() def _pb_args(self): diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 4639e0f7d..b113e2874 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -17,7 +17,6 @@ Any, Generic, TypeVar, - Dict, Sequence, ) from abc import ABC @@ -41,8 +40,6 @@ bytes, GeoPoint, Vector, - list, - Dict[str, Any], None, ) @@ -113,8 +110,20 @@ def _to_pb(self) -> Value: raise NotImplementedError @staticmethod - def _cast_to_expr_or_convert_to_constant(o: Any) -> "Expr": - return o if isinstance(o, Expr) else Constant(o) + def _cast_to_expr_or_convert_to_constant(o: Any, include_vector=False) -> "Expr": + """Convert arbitrary object to an Expr.""" + if isinstance(o, Constant) and isinstance(o.value, list): + o = o.value + if isinstance(o, Expr): + return o + if isinstance(o, dict): + return Map(o) + if isinstance(o, list): + if include_vector and all([isinstance(i, (float, int)) for i in o]): + return Constant(Vector(o)) + else: + return Array(o) + return Constant(o) class expose_as_static: """ @@ -132,6 +141,10 @@ def __init__(self, instance_func): self.instance_func = instance_func def static_func(self, first_arg, *other_args, **kwargs): + if not isinstance(first_arg, (Expr, str)): + raise TypeError( + f"`expressions must be called on an Expr or a string representing a field name. got {type(first_arg)}." + ) first_expr = ( Field.of(first_arg) if not isinstance(first_arg, Expr) else first_arg ) @@ -239,6 +252,147 @@ def mod(self, other: Expr | float) -> "Expr": """ return Function("mod", [self, self._cast_to_expr_or_convert_to_constant(other)]) + @expose_as_static + def abs(self) -> "Expr": + """Creates an expression that calculates the absolute value of this expression. + + Example: + >>> # Get the absolute value of the 'change' field. + >>> Field.of("change").abs() + + Returns: + A new `Expr` representing the absolute value. + """ + return Function("abs", [self]) + + @expose_as_static + def ceil(self) -> "Expr": + """Creates an expression that calculates the ceiling of this expression. + + Example: + >>> # Get the ceiling of the 'value' field. + >>> Field.of("value").ceil() + + Returns: + A new `Expr` representing the ceiling value. + """ + return Function("ceil", [self]) + + @expose_as_static + def exp(self) -> "Expr": + """Creates an expression that computes e to the power of this expression. + + Example: + >>> # Compute e to the power of the 'value' field + >>> Field.of("value").exp() + + Returns: + A new `Expr` representing the exponential value. + """ + return Function("exp", [self]) + + @expose_as_static + def floor(self) -> "Expr": + """Creates an expression that calculates the floor of this expression. + + Example: + >>> # Get the floor of the 'value' field. + >>> Field.of("value").floor() + + Returns: + A new `Expr` representing the floor value. + """ + return Function("floor", [self]) + + @expose_as_static + def ln(self) -> "Expr": + """Creates an expression that calculates the natural logarithm of this expression. + + Example: + >>> # Get the natural logarithm of the 'value' field. + >>> Field.of("value").ln() + + Returns: + A new `Expr` representing the natural logarithm. + """ + return Function("ln", [self]) + + @expose_as_static + def log(self, base: Expr | float) -> "Expr": + """Creates an expression that calculates the logarithm of this expression with a given base. + + Example: + >>> # Get the logarithm of 'value' with base 2. + >>> Field.of("value").log(2) + >>> # Get the logarithm of 'value' with base from 'base_field'. + >>> Field.of("value").log(Field.of("base_field")) + + Args: + base: The base of the logarithm. + + Returns: + A new `Expr` representing the logarithm. + """ + return Function("log", [self, self._cast_to_expr_or_convert_to_constant(base)]) + + @expose_as_static + def log10(self) -> "Expr": + """Creates an expression that calculates the base 10 logarithm of this expression. + + Example: + >>> Field.of("value").log10() + + Returns: + A new `Expr` representing the logarithm. + """ + return Function("log10", [self]) + + @expose_as_static + def pow(self, exponent: Expr | float) -> "Expr": + """Creates an expression that calculates this expression raised to the power of the exponent. + + Example: + >>> # Raise 'base_val' to the power of 2. + >>> Field.of("base_val").pow(2) + >>> # Raise 'base_val' to the power of 'exponent_val'. + >>> Field.of("base_val").pow(Field.of("exponent_val")) + + Args: + exponent: The exponent. + + Returns: + A new `Expr` representing the power operation. + """ + return Function( + "pow", [self, self._cast_to_expr_or_convert_to_constant(exponent)] + ) + + @expose_as_static + def round(self) -> "Expr": + """Creates an expression that rounds this expression to the nearest integer. + + Example: + >>> # Round the 'value' field. + >>> Field.of("value").round() + + Returns: + A new `Expr` representing the rounded value. + """ + return Function("round", [self]) + + @expose_as_static + def sqrt(self) -> "Expr": + """Creates an expression that calculates the square root of this expression. + + Example: + >>> # Get the square root of the 'area' field. + >>> Field.of("area").sqrt() + + Returns: + A new `Expr` representing the square root. + """ + return Function("sqrt", [self]) + @expose_as_static def logical_maximum(self, other: Expr | CONSTANT_TYPE) -> "Expr": """Creates an expression that returns the larger value between this expression @@ -420,7 +574,9 @@ def less_than_or_equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": ) @expose_as_static - def equal_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "BooleanExpr": + def equal_any( + self, array: Array | Sequence[Expr | CONSTANT_TYPE] | Expr + ) -> "BooleanExpr": """Creates an expression that checks if this expression is equal to any of the provided values or expressions. @@ -438,14 +594,14 @@ def equal_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "BooleanExpr": "equal_any", [ self, - _ListOfExprs( - [self._cast_to_expr_or_convert_to_constant(v) for v in array] - ), + self._cast_to_expr_or_convert_to_constant(array), ], ) @expose_as_static - def not_equal_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "BooleanExpr": + def not_equal_any( + self, array: Array | list[Expr | CONSTANT_TYPE] | Expr + ) -> "BooleanExpr": """Creates an expression that checks if this expression is not equal to any of the provided values or expressions. @@ -463,9 +619,7 @@ def not_equal_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "BooleanExpr": "not_equal_any", [ self, - _ListOfExprs( - [self._cast_to_expr_or_convert_to_constant(v) for v in array] - ), + self._cast_to_expr_or_convert_to_constant(array), ], ) @@ -492,7 +646,7 @@ def array_contains(self, element: Expr | CONSTANT_TYPE) -> "BooleanExpr": @expose_as_static def array_contains_all( self, - elements: Sequence[Expr | CONSTANT_TYPE], + elements: Array | list[Expr | CONSTANT_TYPE] | Expr, ) -> "BooleanExpr": """Creates an expression that checks if an array contains all the specified elements. @@ -512,16 +666,14 @@ def array_contains_all( "array_contains_all", [ self, - _ListOfExprs( - [self._cast_to_expr_or_convert_to_constant(e) for e in elements] - ), + self._cast_to_expr_or_convert_to_constant(elements), ], ) @expose_as_static def array_contains_any( self, - elements: Sequence[Expr | CONSTANT_TYPE], + elements: Array | list[Expr | CONSTANT_TYPE] | Expr, ) -> "BooleanExpr": """Creates an expression that checks if an array contains any of the specified elements. @@ -542,9 +694,7 @@ def array_contains_any( "array_contains_any", [ self, - _ListOfExprs( - [self._cast_to_expr_or_convert_to_constant(e) for e in elements] - ), + self._cast_to_expr_or_convert_to_constant(elements), ], ) @@ -574,6 +724,90 @@ def array_reverse(self) -> "Expr": """ return Function("array_reverse", [self]) + @expose_as_static + def array_concat( + self, *other_arrays: Array | list[Expr | CONSTANT_TYPE] | Expr + ) -> "Expr": + """Creates an expression that concatenates an array expression with another array. + + Example: + >>> # Combine the 'tags' array with a new array and an array field + >>> Field.of("tags").array_concat(["newTag1", "newTag2", Field.of("otherTag")]) + + Args: + array: The list of constants or expressions to concat with. + + Returns: + A new `Expr` representing the concatenated array. + """ + return Function( + "array_concat", + [self] + + [self._cast_to_expr_or_convert_to_constant(arr) for arr in other_arrays], + ) + + @expose_as_static + def concat(self, *others: Expr | CONSTANT_TYPE) -> "Expr": + """Creates an expression that concatenates expressions together + + Args: + *others: The expressions to concatenate. + + Returns: + A new `Expr` representing the concatenated value. + """ + return Function( + "concat", + [self] + [self._cast_to_expr_or_convert_to_constant(o) for o in others], + ) + + @expose_as_static + def length(self) -> "Expr": + """ + Creates an expression that calculates the length of the expression if it is a string, array, map, or blob. + + Example: + >>> # Get the length of the 'name' field. + >>> Field.of("name").length() + + Returns: + A new `Expr` representing the length of the expression. + """ + return Function("length", [self]) + + @expose_as_static + def is_absent(self) -> "BooleanExpr": + """Creates an expression that returns true if a value is absent. Otherwise, returns false even if + the value is null. + + Example: + >>> # Check if the 'email' field is absent. + >>> Field.of("email").is_absent() + + Returns: + A new `BooleanExpression` representing the isAbsent operation. + """ + return BooleanExpr("is_absent", [self]) + + @expose_as_static + def if_absent(self, default_value: Expr | CONSTANT_TYPE) -> "Expr": + """Creates an expression that returns a default value if an expression evaluates to an absent value. + + Example: + >>> # Return the value of the 'email' field, or "N/A" if it's absent. + >>> Field.of("email").if_absent("N/A") + + Args: + default_value: The expression or constant value to return if this expression is absent. + + Returns: + A new `Expr` representing the ifAbsent operation. + """ + return Function( + "if_absent", + [self, self._cast_to_expr_or_convert_to_constant(default_value)], + ) + @expose_as_static def is_nan(self) -> "BooleanExpr": """Creates an expression that checks if this expression evaluates to 'NaN' (Not a Number). @@ -587,9 +821,22 @@ def is_nan(self) -> "BooleanExpr": """ return BooleanExpr("is_nan", [self]) + @expose_as_static + def is_not_nan(self) -> "BooleanExpr": + """Creates an expression that checks if this expression evaluates to a non-'NaN' (Not a Number) value. + + Example: + >>> # Check if the result of a calculation is not NaN + >>> Field.of("value").divide(1).is_not_nan() + + Returns: + A new `Expr` representing the 'is not NaN' check. + """ + return BooleanExpr("is_not_nan", [self]) + @expose_as_static def is_null(self) -> "BooleanExpr": - """Creates an expression that checks if this expression evaluates to 'Null'. + """Creates an expression that checks if the value of a field is 'Null'. Example: >>> Field.of("value").is_null() @@ -599,6 +846,50 @@ def is_null(self) -> "BooleanExpr": """ return BooleanExpr("is_null", [self]) + @expose_as_static + def is_not_null(self) -> "BooleanExpr": + """Creates an expression that checks if the value of a field is not 'Null'. + + Example: + >>> Field.of("value").is_not_null() + + Returns: + A new `Expr` representing the 'isNotNull' check. + """ + return BooleanExpr("is_not_null", [self]) + + @expose_as_static + def is_error(self): + """Creates an expression that checks if a given expression produces an error + + Example: + >>> # Resolves to True if an expression produces an error + >>> Field.of("value").divide("string").is_error() + + Returns: + A new `Expr` representing the isError operation. + """ + return Function("is_error", [self]) + + @expose_as_static + def if_error(self, then_value: Expr | CONSTANT_TYPE) -> "Expr": + """Creates an expression that returns ``then_value`` if this expression evaluates to an error. + Otherwise, returns the value of this expression. + + Example: + >>> # Resolves to 0 if an expression produces an error + >>> Field.of("value").divide("string").if_error(0) + + Args: + then_value: The value to return if this expression evaluates to an error. + + Returns: + A new `Expr` representing the ifError operation. + """ + return Function( + "if_error", [self, self._cast_to_expr_or_convert_to_constant(then_value)] + ) + @expose_as_static def exists(self) -> "BooleanExpr": """Creates an expression that checks if a field exists in the document. @@ -653,6 +944,35 @@ def count(self) -> "Expr": """ return AggregateFunction("count", [self]) + @expose_as_static + def count_if(self) -> "Expr": + """Creates an aggregation that counts the number of values of the provided field or expression + that evaluate to True. + + Example: + >>> # Count the number of adults + >>> Field.of("age").greater_than(18).count_if().as_("totalAdults") + + + Returns: + A new `AggregateFunction` representing the 'count_if' aggregation. + """ + return AggregateFunction("count_if", [self]) + + @expose_as_static + def count_distinct(self) -> "Expr": + """Creates an aggregation that counts the number of distinct values of the + provided field or expression. + + Example: + >>> # Count the total number of countries in the data + >>> Field.of("country").count_distinct().as_("totalCountries") + + Returns: + A new `AggregateFunction` representing the 'count_distinct' aggregation. + """ + return AggregateFunction("count_distinct", [self]) + @expose_as_static def minimum(self) -> "Expr": """Creates an aggregation that finds the minimum value of a field across multiple stage inputs. @@ -846,12 +1166,106 @@ def string_concat(self, *elements: Expr | CONSTANT_TYPE) -> "Expr": [self] + [self._cast_to_expr_or_convert_to_constant(el) for el in elements], ) + @expose_as_static + def to_lower(self) -> "Expr": + """Creates an expression that converts a string to lowercase. + + Example: + >>> # Convert the 'name' field to lowercase + >>> Field.of("name").to_lower() + + Returns: + A new `Expr` representing the lowercase string. + """ + return Function("to_lower", [self]) + + @expose_as_static + def to_upper(self) -> "Expr": + """Creates an expression that converts a string to uppercase. + + Example: + >>> # Convert the 'title' field to uppercase + >>> Field.of("title").to_upper() + + Returns: + A new `Expr` representing the uppercase string. + """ + return Function("to_upper", [self]) + + @expose_as_static + def trim(self) -> "Expr": + """Creates an expression that removes leading and trailing whitespace from a string. + + Example: + >>> # Trim whitespace from the 'userInput' field + >>> Field.of("userInput").trim() + + Returns: + A new `Expr` representing the trimmed string. + """ + return Function("trim", [self]) + + @expose_as_static + def string_reverse(self) -> "Expr": + """Creates an expression that reverses a string. + + Example: + >>> # Reverse the 'userInput' field + >>> Field.of("userInput").reverse() + + Returns: + A new `Expr` representing the reversed string. + """ + return Function("string_reverse", [self]) + + @expose_as_static + def substring( + self, position: Expr | int, length: Expr | int | None = None + ) -> "Expr": + """Creates an expression that returns a substring of the results of this expression. + + + Example: + >>> Field.of("description").substring(5, 10) + >>> Field.of("description").substring(5) + + Args: + position: the index of the first character of the substring. + length: the length of the substring. If not provided the substring + will end at the end of the input. + + Returns: + A new `Expr` representing the extracted substring. + """ + args = [self, self._cast_to_expr_or_convert_to_constant(position)] + if length is not None: + args.append(self._cast_to_expr_or_convert_to_constant(length)) + return Function("substring", args) + + @expose_as_static + def join(self, delimeter: Expr | str) -> "Expr": + """Creates an expression that joins the elements of an array into a string + + + Example: + >>> Field.of("tags").join(", ") + + Args: + delimiter: The delimiter to add between the elements of the array. + + Returns: + A new `Expr` representing the joined string. + """ + return Function( + "join", [self, self._cast_to_expr_or_convert_to_constant(delimeter)] + ) + @expose_as_static def map_get(self, key: str | Constant[str]) -> "Expr": """Accesses a value from the map produced by evaluating this expression. Example: - >>> Expr.map({"city": "London"}).map_get("city") + >>> Map({"city": "London"}).map_get("city") >>> Field.of("address").map_get("city") Args: @@ -861,7 +1275,118 @@ def map_get(self, key: str | Constant[str]) -> "Expr": A new `Expr` representing the value associated with the given key in the map. """ return Function( - "map_get", [self, Constant.of(key) if isinstance(key, str) else key] + "map_get", [self, self._cast_to_expr_or_convert_to_constant(key)] + ) + + @expose_as_static + def map_remove(self, key: str | Constant[str]) -> "Expr": + """Remove a key from the map produced by evaluating this expression. + + Example: + >>> Map({"city": "London"}).map_remove("city") + >>> Field.of("address").map_remove("city") + + Args: + key: The key to remove in the map. + + Returns: + A new `Expr` representing the map_remove operation. + """ + return Function( + "map_remove", [self, self._cast_to_expr_or_convert_to_constant(key)] + ) + + @expose_as_static + def map_merge( + self, *other_maps: Map | dict[str | Constant[str], Expr | CONSTANT_TYPE] | Expr + ) -> "Expr": + """Creates an expression that merges one or more dicts into a single map. + + Example: + >>> Map({"city": "London"}).map_merge({"country": "UK"}, {"isCapital": True}) + >>> Field.of("settings").map_merge({"enabled":True}, Function.conditional(Field.of('isAdmin'), {"admin":True}, {}}) + + Args: + *other_maps: Sequence of maps to merge into the resulting map. + + Returns: + A new `Expr` representing the value associated with the given key in the map. + """ + return Function( + "map_merge", + [self] + [self._cast_to_expr_or_convert_to_constant(m) for m in other_maps], + ) + + @expose_as_static + def cosine_distance(self, other: Expr | list[float] | Vector) -> "Expr": + """Calculates the cosine distance between two vectors. + + Example: + >>> # Calculate the cosine distance between the 'userVector' field and the 'itemVector' field + >>> Field.of("userVector").cosine_distance(Field.of("itemVector")) + >>> # Calculate the Cosine distance between the 'location' field and a target location + >>> Field.of("location").cosine_distance([37.7749, -122.4194]) + + Args: + other: The other vector (represented as an Expr, list of floats, or Vector) to compare against. + + Returns: + A new `Expr` representing the cosine distance between the two vectors. + """ + return Function( + "cosine_distance", + [ + self, + self._cast_to_expr_or_convert_to_constant(other, include_vector=True), + ], + ) + + @expose_as_static + def euclidean_distance(self, other: Expr | list[float] | Vector) -> "Expr": + """Calculates the Euclidean distance between two vectors. + + Example: + >>> # Calculate the Euclidean distance between the 'location' field and a target location + >>> Field.of("location").euclidean_distance([37.7749, -122.4194]) + >>> # Calculate the Euclidean distance between two vector fields: 'pointA' and 'pointB' + >>> Field.of("pointA").euclidean_distance(Field.of("pointB")) + + Args: + other: The other vector (represented as an Expr, list of floats, or Vector) to compare against. + + Returns: + A new `Expr` representing the Euclidean distance between the two vectors. + """ + return Function( + "euclidean_distance", + [ + self, + self._cast_to_expr_or_convert_to_constant(other, include_vector=True), + ], + ) + + @expose_as_static + def dot_product(self, other: Expr | list[float] | Vector) -> "Expr": + """Calculates the dot product between two vectors. + + Example: + >>> # Calculate the dot product between a feature vector and a target vector + >>> Field.of("features").dot_product([0.5, 0.8, 0.2]) + >>> # Calculate the dot product between two document vectors: 'docVector1' and 'docVector2' + >>> Field.of("docVector1").dot_product(Field.of("docVector2")) + + Args: + other: The other vector (represented as an Expr, list of floats, or Vector) to calculate dot product with. + + Returns: + A new `Expr` representing the dot product between the two vectors. + """ + return Function( + "dot_product", + [ + self, + self._cast_to_expr_or_convert_to_constant(other, include_vector=True), + ], ) @expose_as_static @@ -1034,6 +1559,19 @@ def collection_id(self): """ return Function("collection_id", [self]) + @expose_as_static + def document_id(self): + """Creates an expression that returns the document ID from a path. + + Example: + >>> # Get the document ID from a path. + >>> Field.of("__name__").document_id() + + Returns: + A new `Expr` representing the document ID. + """ + return Function("document_id", [self]) + def ascending(self) -> Ordering: """Creates an `Ordering` that sorts documents in ascending order based on this expression. @@ -1107,25 +1645,6 @@ def _to_pb(self) -> Value: return encode_value(self.value) -class _ListOfExprs(Expr): - """Represents a list of expressions, typically used as an argument to functions like 'in' or array functions.""" - - def __init__(self, exprs: Sequence[Expr]): - self.exprs: list[Expr] = list(exprs) - - def __eq__(self, other): - if not isinstance(other, _ListOfExprs): - return False - else: - return other.exprs == self.exprs - - def __repr__(self): - return repr(self.exprs) - - def _to_pb(self): - return Value(array_value={"values": [e._to_pb() for e in self.exprs]}) - - class Function(Expr): """A base class for expressions that represent function calls.""" @@ -1323,11 +1842,11 @@ def _from_query_filter_pb(filter_pb, client): if filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NAN: return And(field.exists(), field.is_nan()) elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NOT_NAN: - return And(field.exists(), Not(field.is_nan())) + return And(field.exists(), field.is_not_nan()) elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NULL: return And(field.exists(), field.is_null()) elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NOT_NULL: - return And(field.exists(), Not(field.is_null())) + return And(field.exists(), field.is_not_null()) else: raise TypeError(f"Unexpected UnaryFilter operator type: {filter_pb.op}") elif isinstance(filter_pb, Query_pb.FieldFilter): @@ -1367,6 +1886,55 @@ def _from_query_filter_pb(filter_pb, client): raise TypeError(f"Unexpected filter type: {type(filter_pb)}") +class Array(Function): + """ + Creates an expression that creates a Firestore array value from an input list. + + Example: + >>> Expr.array(["bar", Field.of("baz")]) + + Args: + elements: The input list to evaluate in the expression + """ + + def __init__(self, elements: list[Expr | CONSTANT_TYPE]): + if not isinstance(elements, list): + raise TypeError("Array must be constructed with a list") + converted_elements = [ + self._cast_to_expr_or_convert_to_constant(el) for el in elements + ] + super().__init__("array", converted_elements) + + def __repr__(self): + return f"Array({self.params})" + + +class Map(Function): + """ + Creates an expression that creates a Firestore map value from an input dict. + + Example: + >>> Expr.map({"foo": "bar", "baz": Field.of("baz")}) + + Args: + elements: The input dict to evaluate in the expression + """ + + def __init__(self, elements: dict[str | Constant[str], Expr | CONSTANT_TYPE]): + element_list = [] + for k, v in elements.items(): + element_list.append(self._cast_to_expr_or_convert_to_constant(k)) + element_list.append(self._cast_to_expr_or_convert_to_constant(v)) + super().__init__("map", element_list) + + def __repr__(self): + formatted_params = [ + a.value if isinstance(a, Constant) else a for a in self.params + ] + d = {a: b for a, b in zip(formatted_params[::2], formatted_params[1::2])} + return f"Map({d})" + + class And(BooleanExpr): """ Represents an expression that performs a logical 'AND' operation on multiple filter conditions. @@ -1454,6 +2022,7 @@ def __init__(self, condition: BooleanExpr, then_expr: Expr, else_expr: Expr): "conditional", [condition, then_expr, else_expr], use_infix_repr=False ) + class Count(AggregateFunction): """ Represents an aggregation that counts the number of stage inputs with valid evaluations of the @@ -1471,6 +2040,15 @@ class Count(AggregateFunction): def __init__(self, expression: Expr | None = None): expression_list = [expression] if expression else [] - super().__init__( - "count", expression_list, use_infix_repr=bool(expression_list) - ) + super().__init__("count", expression_list, use_infix_repr=bool(expression_list)) + + +class CurrentTimestamp(Function): + """Creates an expression that returns the current timestamp + + Returns: + A new `Expr` representing the current timestamp. + """ + + def __init__(self): + super().__init__("current_timestamp", [], use_infix_repr=False) diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index 50cc7c29d..38595224a 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -136,6 +136,10 @@ data: embedding: [1.0, 2.0, 3.0] vec2: embedding: [4.0, 5.0, 6.0, 7.0] + vec3: + embedding: [5.0, 6.0, 7.0] + vec4: + embedding: [1.0, 2.0, 4.0] tests: - description: "testAggregates - count" pipeline: @@ -163,6 +167,64 @@ tests: - fieldReferenceValue: rating - mapValue: {} name: aggregate + - description: "testAggregates - count_if" + pipeline: + - Collection: books + - Aggregate: + - AliasedExpr: + - Expr.count_if: + - Expr.greater_than: + - Field: rating + - Constant: 4.2 + - "count_if_rating_gt_4_2" + assert_results: + - count_if_rating_gt_4_2: 5 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + count_if_rating_gt_4_2: + functionValue: + name: count_if + args: + - functionValue: + name: greater_than + args: + - fieldReferenceValue: rating + - doubleValue: 4.2 + - mapValue: {} + name: aggregate + - description: "testAggregates - count_distinct" + pipeline: + - Collection: books + - Aggregate: + - AliasedExpr: + - Expr.count_distinct: + - Field: genre + - "distinct_genres" + assert_results: + - distinct_genres: 8 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + distinct_genres: + functionValue: + name: count_distinct + args: + - fieldReferenceValue: genre + - mapValue: {} + name: aggregate - description: "testAggregates - avg, count, max" pipeline: - Collection: books @@ -697,10 +759,11 @@ tests: - functionValue: args: - fieldReferenceValue: tags - - arrayValue: - values: + - functionValue: + args: - stringValue: comedy - stringValue: classic + name: array name: array_contains_any name: where - args: @@ -739,10 +802,11 @@ tests: - functionValue: args: - fieldReferenceValue: tags - - arrayValue: - values: + - functionValue: + args: - stringValue: adventure - stringValue: magic + name: array name: array_contains_all name: where - args: @@ -929,7 +993,57 @@ tests: expression: fieldReferenceValue: title name: sort + - description: testConcat + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Select: + - AliasedExpr: + - Expr.concat: + - Field: author + - Constant: ": " + - Field: title + - "author_title" + - AliasedExpr: + - Expr.concat: + - Field: tags + - - Constant: "new_tag" + - "concatenatedTags" + assert_results: + - author_title: "Douglas Adams: The Hitchhiker's Guide to the Galaxy" + concatenatedTags: + - comedy + - space + - adventure + - new_tag - description: testLength + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Select: + - AliasedExpr: + - Expr.length: + - Field: title + - "titleLength" + - AliasedExpr: + - Expr.length: + - Field: tags + - "tagsLength" + - AliasedExpr: + - Expr.length: + - Field: awards + - "awardsLength" + assert_results: + - titleLength: 36 + tagsLength: 3 + awardsLength: 2 + - description: testCharLength pipeline: - Collection: books - Select: @@ -1414,6 +1528,177 @@ tests: - args: - integerValue: '1' name: limit + - description: testIsNotNull + pipeline: + - Collection: books + - Where: + - Expr.is_not_null: + - Field: rating + assert_count: 10 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: rating + name: is_not_null + name: where + - description: testIsNotNaN + pipeline: + - Collection: books + - Where: + - Expr.is_not_nan: + - Field: rating + assert_count: 10 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: rating + name: is_not_nan + name: where + - description: testIsAbsent + pipeline: + - Collection: books + - Where: + - Expr.is_absent: + - Field: awards.pulitzer + assert_count: 9 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: awards.pulitzer + name: is_absent + name: where + - description: testIfAbsent + pipeline: + - Collection: books + - Select: + - AliasedExpr: + - Expr.if_absent: + - Field: awards.pulitzer + - Constant: false + - "pulitzer_award" + - title + - Where: + - Expr.equal: + - Field: pulitzer_award + - Constant: true + assert_results: + - pulitzer_award: true + title: To Kill a Mockingbird + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + pulitzer_award: + functionValue: + name: if_absent + args: + - fieldReferenceValue: awards.pulitzer + - booleanValue: false + title: + fieldReferenceValue: title + name: select + - args: + - functionValue: + args: + - fieldReferenceValue: pulitzer_award + - booleanValue: true + name: equal + name: where + - description: testIsError + pipeline: + - Collection: books + - Select: + - AliasedExpr: + - Expr.is_error: + - Expr.divide: + - Field: rating + - Constant: "string" + - "is_error_result" + - Limit: 1 + assert_results: + - is_error_result: true + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + is_error_result: + functionValue: + name: is_error + args: + - functionValue: + name: divide + args: + - fieldReferenceValue: rating + - stringValue: "string" + name: select + - args: + - integerValue: '1' + name: limit + - description: testIfError + pipeline: + - Collection: books + - Select: + - AliasedExpr: + - Expr.if_error: + - Expr.divide: + - Field: rating + - Field: genre + - Constant: "An error occurred" + - "if_error_result" + - Limit: 1 + assert_results: + - if_error_result: "An error occurred" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + if_error_result: + functionValue: + name: if_error + args: + - functionValue: + name: divide + args: + - fieldReferenceValue: rating + - fieldReferenceValue: genre + - stringValue: "An error occurred" + name: select + - args: + - integerValue: '1' + name: limit - description: testLogicalMinMax pipeline: - Collection: books @@ -1520,25 +1805,27 @@ tests: - booleanValue: true name: equal name: where - - description: testNestedFields + - description: testMapGetWithField pipeline: - Collection: books - Where: - Expr.equal: - - Field: awards.hugo - - Constant: true - - Sort: - - Ordering: - - Field: title - - DESCENDING + - Field: title + - Constant: "Dune" + - AddFields: + - AliasedExpr: + - Constant: "hugo" + - "award_name" - Select: - - title - - Field: awards.hugo + - AliasedExpr: + - Expr.map_get: + - Field: awards + - Field: award_name + - "hugoAward" + - Field: title assert_results: - - title: The Hitchhiker's Guide to the Galaxy - awards.hugo: true - - title: Dune - awards.hugo: true + - hugoAward: true + title: Dune assert_proto: pipeline: stages: @@ -1548,31 +1835,44 @@ tests: - args: - functionValue: args: - - fieldReferenceValue: awards.hugo - - booleanValue: true + - fieldReferenceValue: title + - stringValue: "Dune" name: equal name: where - args: - mapValue: fields: - direction: - stringValue: descending - expression: - fieldReferenceValue: title - name: sort + award_name: + stringValue: "hugo" + name: add_fields - args: - mapValue: fields: - awards.hugo: - fieldReferenceValue: awards.hugo + hugoAward: + functionValue: + name: map_get + args: + - fieldReferenceValue: awards + - fieldReferenceValue: award_name title: fieldReferenceValue: title name: select - - description: testSampleLimit + - description: testMapRemove pipeline: - Collection: books - - Sample: 3 - assert_count: 3 # Results will vary due to randomness + - Where: + - Expr.equal: + - Field: title + - Constant: "Dune" + - Select: + - AliasedExpr: + - Expr.map_remove: + - Field: awards + - "nebula" + - "awards_removed" + assert_results: + - awards_removed: + hugo: true assert_proto: pipeline: stages: @@ -1580,14 +1880,146 @@ tests: - referenceValue: /books name: collection - args: - - integerValue: '3' - - stringValue: documents - name: sample - - description: testSamplePercentage - pipeline: - - Collection: books - - Sample: - - SampleOptions: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "Dune" + name: equal + name: where + - args: + - mapValue: + fields: + awards_removed: + functionValue: + name: map_remove + args: + - fieldReferenceValue: awards + - stringValue: "nebula" + name: select + - description: testMapMerge + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: title + - Constant: "Dune" + - Select: + - AliasedExpr: + - Expr.map_merge: + - Field: awards + - Map: + elements: {"new_award": true, "hugo": false} + - Map: + elements: {"another_award": "yes"} + - "awards_merged" + assert_results: + - awards_merged: + hugo: false + nebula: true + new_award: true + another_award: "yes" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "Dune" + name: equal + name: where + - args: + - mapValue: + fields: + awards_merged: + functionValue: + name: map_merge + args: + - fieldReferenceValue: awards + - functionValue: + name: map + args: + - stringValue: "new_award" + - booleanValue: true + - stringValue: "hugo" + - booleanValue: false + - functionValue: + name: map + args: + - stringValue: "another_award" + - stringValue: "yes" + name: select + - description: testNestedFields + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: awards.hugo + - Constant: true + - Sort: + - Ordering: + - Field: title + - DESCENDING + - Select: + - title + - Field: awards.hugo + assert_results: + - title: The Hitchhiker's Guide to the Galaxy + awards.hugo: true + - title: Dune + awards.hugo: true + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: awards.hugo + - booleanValue: true + name: equal + name: where + - args: + - mapValue: + fields: + direction: + stringValue: descending + expression: + fieldReferenceValue: title + name: sort + - args: + - mapValue: + fields: + awards.hugo: + fieldReferenceValue: awards.hugo + title: + fieldReferenceValue: title + name: select + - description: testSampleLimit + pipeline: + - Collection: books + - Sample: 3 + assert_count: 3 # Results will vary due to randomness + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - integerValue: '3' + - stringValue: documents + name: sample + - description: testSamplePercentage + pipeline: + - Collection: books + - Sample: + - SampleOptions: - 0.6 - percent assert_proto: @@ -1730,59 +2162,379 @@ tests: - adventure - space - comedy - - description: testExists + - description: testDocumentId pipeline: - Collection: books - Where: - - And: - - Expr.exists: - - Field: awards.pulitzer - - Expr.equal: - - Field: awards.pulitzer - - Constant: true + - Expr.equal: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" - Select: - - title + - AliasedExpr: + - Expr.document_id: + - Field: __name__ + - "doc_id" assert_results: - - title: To Kill a Mockingbird - - description: testSum + - doc_id: "book1" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "The Hitchhiker's Guide to the Galaxy" + name: equal + name: where + - args: + - mapValue: + fields: + doc_id: + functionValue: + name: document_id + args: + - fieldReferenceValue: __name__ + name: select + - description: testCurrentTimestamp + pipeline: + - Collection: books + - Limit: 1 + - Select: + - AliasedExpr: + - And: + - Expr.greater_than_or_equal: + - CurrentTimestamp: [] + - Expr.unix_seconds_to_timestamp: + - Constant: 1735689600 # 2025-01-01 + - Expr.less_than: + - CurrentTimestamp: [] + - Expr.unix_seconds_to_timestamp: + - Constant: 4892438400 # 2125-01-01 + - "is_between_2025_and_2125" + assert_results: + - is_between_2025_and_2125: true + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - integerValue: '1' + name: limit + - args: + - mapValue: + fields: + is_between_2025_and_2125: + functionValue: + name: and + args: + - functionValue: + name: greater_than_or_equal + args: + - functionValue: + name: current_timestamp + - functionValue: + name: unix_seconds_to_timestamp + args: + - integerValue: '1735689600' + - functionValue: + name: less_than + args: + - functionValue: + name: current_timestamp + - functionValue: + name: unix_seconds_to_timestamp + args: + - integerValue: '4892438400' + name: select + - description: testArrayConcat pipeline: - Collection: books - Where: - Expr.equal: - - Field: genre - - Constant: Science Fiction - - Aggregate: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Select: - AliasedExpr: - - Expr.sum: - - Field: rating - - "total_rating" + - Expr.array_concat: + - Field: tags + - Constant: ["new_tag", "another_tag"] + - "concatenatedTags" assert_results: - - total_rating: 8.8 - - description: testStringContains + - concatenatedTags: + - comedy + - space + - adventure + - new_tag + - another_tag + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "The Hitchhiker's Guide to the Galaxy" + name: equal + name: where + - args: + - mapValue: + fields: + concatenatedTags: + functionValue: + args: + - fieldReferenceValue: tags + - functionValue: + args: + - stringValue: "new_tag" + - stringValue: "another_tag" + name: array + name: array_concat + name: select + - description: testArrayConcatMultiple pipeline: - Collection: books - Where: - - Expr.string_contains: + - Expr.equal: - Field: title - - Constant: "Hitchhiker's" + - Constant: "Dune" - Select: - - title + - AliasedExpr: + - Expr.array_concat: + - Field: tags + - Constant: ["sci-fi"] + - Constant: ["classic", "epic"] + - "concatenatedTags" assert_results: - - title: "The Hitchhiker's Guide to the Galaxy" - - description: testVectorLength + - concatenatedTags: + - politics + - desert + - ecology + - sci-fi + - classic + - epic + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "Dune" + name: equal + name: where + - args: + - mapValue: + fields: + concatenatedTags: + functionValue: + args: + - fieldReferenceValue: tags + - functionValue: + args: + - stringValue: "sci-fi" + name: array + - functionValue: + args: + - stringValue: "classic" + - stringValue: "epic" + name: array + name: array_concat + name: select + - description: testMapMergeLiterals pipeline: - - Collection: vectors + - Collection: books + - Limit: 1 - Select: - AliasedExpr: - - Expr.vector_length: - - Field: embedding - - "embedding_length" - - Sort: - - Ordering: - - Field: embedding_length - - ASCENDING + - Expr.map_merge: + - Map: + elements: {"a": "orig", "b": "orig"} + - Map: + elements: {"b": "new", "c": "new"} + - "merged" assert_results: - - embedding_length: 3 + - merged: + a: "orig" + b: "new" + c: "new" + - description: testArrayContainsAnyWithField + pipeline: + - Collection: books + - AddFields: + - AliasedExpr: + - Expr.array_concat: + - Field: tags + - Array: ["Dystopian"] + - "new_tags" + - Where: + - Expr.array_contains_any: + - Field: new_tags + - - Constant: non_existent_tag + - Field: genre + - Select: + - title + - genre + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: "1984" + genre: "Dystopian" + - title: "The Handmaid's Tale" + genre: "Dystopian" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + new_tags: + functionValue: + args: + - fieldReferenceValue: tags + - functionValue: + args: + - stringValue: "Dystopian" + name: array + name: array_concat + name: add_fields + - args: + - functionValue: + args: + - fieldReferenceValue: new_tags + - functionValue: + args: + - stringValue: "non_existent_tag" + - fieldReferenceValue: genre + name: array + name: array_contains_any + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + genre: + fieldReferenceValue: genre + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + - description: testArrayConcatLiterals + pipeline: + - Collection: books + - Limit: 1 + - Select: + - AliasedExpr: + - Expr.array_concat: + - Array: [1, 2, 3] + - Array: [4, 5] + - "concatenated" + assert_results: + - concatenated: [1, 2, 3, 4, 5] + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - integerValue: '1' + name: limit + - args: + - mapValue: + fields: + concatenated: + functionValue: + args: + - functionValue: + args: + - integerValue: '1' + - integerValue: '2' + - integerValue: '3' + name: array + - functionValue: + args: + - integerValue: '4' + - integerValue: '5' + name: array + name: array_concat + name: select + - description: testExists + pipeline: + - Collection: books + - Where: + - And: + - Expr.exists: + - Field: awards.pulitzer + - Expr.equal: + - Field: awards.pulitzer + - Constant: true + - Select: + - title + assert_results: + - title: To Kill a Mockingbird + - description: testSum + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: genre + - Constant: Science Fiction + - Aggregate: + - AliasedExpr: + - Expr.sum: + - Field: rating + - "total_rating" + assert_results: + - total_rating: 8.8 + - description: testStringContains + pipeline: + - Collection: books + - Where: + - Expr.string_contains: + - Field: title + - Constant: "Hitchhiker's" + - Select: + - title + assert_results: + - title: "The Hitchhiker's Guide to the Galaxy" + - description: testVectorLength + pipeline: + - Collection: vectors + - Select: + - AliasedExpr: + - Expr.vector_length: + - Field: embedding + - "embedding_length" + - Sort: + - Ordering: + - Field: embedding_length + - ASCENDING + assert_results: + - embedding_length: 3 + - embedding_length: 3 + - embedding_length: 3 - embedding_length: 4 - description: testTimestampFunctions pipeline: @@ -1956,3 +2708,596 @@ tests: conditional_field: "Dystopian" - title: "Dune" conditional_field: "Frank Herbert" + - description: testFindNearestEuclidean + pipeline: + - Collection: vectors + - FindNearest: + field: embedding + vector: [1.0, 2.0, 3.0] + distance_measure: EUCLIDEAN + options: + FindNearestOptions: + limit: 2 + distance_field: + Field: distance + - Select: + - distance + assert_results: + - distance: 0.0 + - distance: 1.0 + assert_proto: + pipeline: + stages: + - name: collection + args: + - referenceValue: /vectors + - name: find_nearest + args: + - fieldReferenceValue: embedding + - mapValue: + fields: + __type__: + stringValue: __vector__ + value: + arrayValue: + values: + - doubleValue: 1.0 + - doubleValue: 2.0 + - doubleValue: 3.0 + - stringValue: euclidean + options: + limit: + integerValue: '2' + distance_field: + fieldReferenceValue: distance + - name: select + args: + - mapValue: + fields: + distance: + fieldReferenceValue: distance + - description: testMathExpressions + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: title + - Constant: To Kill a Mockingbird + - Select: + - AliasedExpr: + - Expr.abs: + - Field: rating + - "abs_rating" + - AliasedExpr: + - Expr.ceil: + - Field: rating + - "ceil_rating" + - AliasedExpr: + - Expr.exp: + - Field: rating + - "exp_rating" + - AliasedExpr: + - Expr.floor: + - Field: rating + - "floor_rating" + - AliasedExpr: + - Expr.ln: + - Field: rating + - "ln_rating" + - AliasedExpr: + - Expr.log10: + - Field: rating + - "log_rating_base10" + - AliasedExpr: + - Expr.log: + - Field: rating + - Constant: 2 + - "log_rating_base2" + - AliasedExpr: + - Expr.pow: + - Field: rating + - Constant: 2 + - "pow_rating" + - AliasedExpr: + - Expr.sqrt: + - Field: rating + - "sqrt_rating" + assert_results_approximate: + - abs_rating: 4.2 + ceil_rating: 5.0 + exp_rating: 66.686331 + floor_rating: 4.0 + ln_rating: 1.4350845 + log_rating_base10: 0.623249 + log_rating_base2: 2.0704 + pow_rating: 17.64 + sqrt_rating: 2.049390 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: To Kill a Mockingbird + name: equal + name: where + - args: + - mapValue: + fields: + abs_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: abs + ceil_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: ceil + exp_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: exp + floor_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: floor + ln_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: ln + log_rating_base10: + functionValue: + args: + - fieldReferenceValue: rating + name: log10 + log_rating_base2: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '2' + name: log + pow_rating: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '2' + name: pow + sqrt_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: sqrt + name: select + - description: testRoundExpressions + pipeline: + - Collection: books + - Where: + - Expr.equal_any: + - Field: title + - - Constant: "To Kill a Mockingbird" # rating 4.2 + - Constant: "Pride and Prejudice" # rating 4.5 + - Constant: "The Lord of the Rings" # rating 4.7 + - Select: + - title + - AliasedExpr: + - Expr.round: + - Field: rating + - "round_rating" + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: "Pride and Prejudice" + round_rating: 5.0 + - title: "The Lord of the Rings" + round_rating: 5.0 + - title: "To Kill a Mockingbird" + round_rating: 4.0 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - functionValue: + args: + - stringValue: "To Kill a Mockingbird" + - stringValue: "Pride and Prejudice" + - stringValue: "The Lord of the Rings" + name: array + name: equal_any + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + round_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: round + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + - description: testFindNearestDotProduct + pipeline: + - Collection: vectors + - FindNearest: + field: embedding + vector: [1.0, 2.0, 3.0] + distance_measure: DOT_PRODUCT + options: + FindNearestOptions: + limit: 3 + distance_field: + Field: distance + - Select: + - distance + assert_results: + - distance: 38.0 + - distance: 17.0 + - distance: 14.0 + assert_proto: + pipeline: + stages: + - name: collection + args: + - referenceValue: /vectors + - name: find_nearest + args: + - fieldReferenceValue: embedding + - mapValue: + fields: + __type__: + stringValue: __vector__ + value: + arrayValue: + values: + - doubleValue: 1.0 + - doubleValue: 2.0 + - doubleValue: 3.0 + - stringValue: dot_product + options: + limit: + integerValue: '3' + distance_field: + fieldReferenceValue: distance + - name: select + args: + - mapValue: + fields: + distance: + fieldReferenceValue: distance + - description: testDotProductWithConstant + pipeline: + - Collection: vectors + - Where: + - Expr.equal: + - Field: embedding + - Vector: [1.0, 2.0, 3.0] + - Select: + - AliasedExpr: + - Expr.dot_product: + - Field: embedding + - Vector: [1.0, 1.0, 1.0] + - "dot_product_result" + assert_results: + - dot_product_result: 6.0 + - description: testEuclideanDistanceWithConstant + pipeline: + - Collection: vectors + - Where: + - Expr.equal: + - Field: embedding + - Vector: [1.0, 2.0, 3.0] + - Select: + - AliasedExpr: + - Expr.euclidean_distance: + - Field: embedding + - Vector: [1.0, 2.0, 3.0] + - "euclidean_distance_result" + assert_results: + - euclidean_distance_result: 0.0 + - description: testCosineDistanceWithConstant + pipeline: + - Collection: vectors + - Where: + - Expr.equal: + - Field: embedding + - Vector: [1.0, 2.0, 3.0] + - Select: + - AliasedExpr: + - Expr.cosine_distance: + - Field: embedding + - Vector: [1.0, 2.0, 3.0] + - "cosine_distance_result" + assert_results: + - cosine_distance_result: 0.0 + - description: testStringFunctions - ToLower + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: author + - Constant: "Douglas Adams" + - Select: + - AliasedExpr: + - Expr.to_lower: + - Field: title + - "lower_title" + assert_results: + - lower_title: "the hitchhiker's guide to the galaxy" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: equal + name: where + - args: + - mapValue: + fields: + lower_title: + functionValue: + args: + - fieldReferenceValue: title + name: to_lower + name: select + - description: testStringFunctions - ToUpper + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: author + - Constant: "Douglas Adams" + - Select: + - AliasedExpr: + - Expr.to_upper: + - Field: title + - "upper_title" + assert_results: + - upper_title: "THE HITCHHIKER'S GUIDE TO THE GALAXY" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: equal + name: where + - args: + - mapValue: + fields: + upper_title: + functionValue: + args: + - fieldReferenceValue: title + name: to_upper + name: select + - description: testStringFunctions - Trim + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: author + - Constant: "Douglas Adams" + - Select: + - AliasedExpr: + - Expr.trim: + - Expr.string_concat: + - Constant: " " + - Field: title + - Constant: " " + - "trimmed_title" + assert_results: + - trimmed_title: "The Hitchhiker's Guide to the Galaxy" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: equal + name: where + - args: + - mapValue: + fields: + trimmed_title: + functionValue: + args: + - functionValue: + args: + - stringValue: " " + - fieldReferenceValue: title + - stringValue: " " + name: string_concat + name: trim + name: select + - description: testStringFunctions - StringReverse + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: author + - Constant: "Jane Austen" + - Select: + - AliasedExpr: + - Expr.string_reverse: + - Field: title + - "reversed_title" + assert_results: + - reversed_title: "ecidujerP dna edirP" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: "Jane Austen" + name: equal + name: where + - args: + - mapValue: + fields: + reversed_title: + functionValue: + args: + - fieldReferenceValue: title + name: string_reverse + name: select + - description: testStringFunctions - Substring + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: author + - Constant: "Douglas Adams" + - Select: + - AliasedExpr: + - Expr.substring: + - Field: title + - Constant: 4 + - Constant: 11 + - "substring_title" + assert_results: + - substring_title: "Hitchhiker'" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: "Douglas Adams" + name: equal + name: where + - args: + - mapValue: + fields: + substring_title: + functionValue: + args: + - fieldReferenceValue: title + - integerValue: '4' + - integerValue: '11' + name: substring + name: select + - description: testStringFunctions - Substring without length + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: author + - Constant: "Fyodor Dostoevsky" + - Select: + - AliasedExpr: + - Expr.substring: + - Field: title + - Constant: 10 + - "substring_title" + assert_results: + - substring_title: "Punishment" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: "Fyodor Dostoevsky" + name: equal + name: where + - args: + - mapValue: + fields: + substring_title: + functionValue: + args: + - fieldReferenceValue: title + - integerValue: '10' + name: substring + name: select + - description: testStringFunctions - Join + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: author + - Constant: "Douglas Adams" + - Select: + - AliasedExpr: + - Expr.join: + - Field: tags + - Constant: ", " + - "joined_tags" + assert_results: + - joined_tags: "comedy, space, adventure" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: "Douglas Adams" + name: equal + name: where + - args: + - mapValue: + fields: + joined_tags: + functionValue: + args: + - fieldReferenceValue: tags + - stringValue: ", " + name: join + name: select diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index d4c654e63..682fe5e23 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -28,6 +28,7 @@ from google.cloud.firestore_v1 import _pipeline_stages as stages from google.cloud.firestore_v1 import pipeline_expressions from google.cloud.firestore_v1.vector import Vector +from google.cloud.firestore_v1 import pipeline_expressions as expr from google.api_core.exceptions import GoogleAPIError from google.cloud.firestore import Client, AsyncClient @@ -86,7 +87,13 @@ def test_pipeline_expected_errors(test_dict, client): @pytest.mark.parametrize( "test_dict", - [t for t in yaml_loader() if "assert_results" in t or "assert_count" in t], + [ + t + for t in yaml_loader() + if "assert_results" in t + or "assert_count" in t + or "assert_results_approximate" in t + ], ids=lambda x: f"{x.get('description', '')}", ) def test_pipeline_results(test_dict, client): @@ -94,12 +101,23 @@ def test_pipeline_results(test_dict, client): Ensure pipeline returns expected results """ expected_results = _parse_yaml_types(test_dict.get("assert_results", None)) + expected_approximate_results = _parse_yaml_types( + test_dict.get("assert_results_approximate", None) + ) expected_count = test_dict.get("assert_count", None) pipeline = parse_pipeline(client, test_dict["pipeline"]) # check if server responds as expected got_results = [snapshot.data() for snapshot in pipeline.stream()] if expected_results: assert got_results == expected_results + if expected_approximate_results: + assert len(got_results) == len( + expected_approximate_results + ), "got unexpected result count" + for idx in range(len(got_results)): + assert got_results[idx] == pytest.approx( + expected_approximate_results[idx], abs=1e-4 + ) if expected_count is not None: assert len(got_results) == expected_count @@ -126,7 +144,13 @@ async def test_pipeline_expected_errors_async(test_dict, async_client): @pytest.mark.parametrize( "test_dict", - [t for t in yaml_loader() if "assert_results" in t or "assert_count" in t], + [ + t + for t in yaml_loader() + if "assert_results" in t + or "assert_count" in t + or "assert_results_approximate" in t + ], ids=lambda x: f"{x.get('description', '')}", ) @pytest.mark.asyncio @@ -135,12 +159,23 @@ async def test_pipeline_results_async(test_dict, async_client): Ensure pipeline returns expected results """ expected_results = _parse_yaml_types(test_dict.get("assert_results", None)) + expected_approximate_results = _parse_yaml_types( + test_dict.get("assert_results_approximate", None) + ) expected_count = test_dict.get("assert_count", None) pipeline = parse_pipeline(async_client, test_dict["pipeline"]) # check if server responds as expected got_results = [snapshot.data() async for snapshot in pipeline.stream()] if expected_results: assert got_results == expected_results + if expected_approximate_results: + assert len(got_results) == len( + expected_approximate_results + ), "got unexpected result count" + for idx in range(len(got_results)): + assert got_results[idx] == pytest.approx( + expected_approximate_results[idx], abs=1e-4 + ) if expected_count is not None: assert len(got_results) == expected_count @@ -218,7 +253,11 @@ def _apply_yaml_args_to_callable(callable_obj, client, yaml_args): """ if isinstance(yaml_args, dict): return callable_obj(**_parse_expressions(client, yaml_args)) - elif isinstance(yaml_args, list): + elif isinstance(yaml_args, list) and not ( + callable_obj == expr.Constant + or callable_obj == Vector + or callable_obj == expr.Array + ): # yaml has an array of arguments. Treat as args return callable_obj(*_parse_expressions(client, yaml_args)) else: diff --git a/tests/system/test_system.py b/tests/system/test_system.py index a8f94e2ba..c2bd93ef8 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -109,9 +109,11 @@ def _clean_results(results): if isinstance(query, BaseAggregationQuery): # aggregation queries return a list of lists of aggregation results query_results = _clean_results( - list(itertools.chain.from_iterable( - [[a._to_dict() for a in s] for s in query.get()] - )) + list( + itertools.chain.from_iterable( + [[a._to_dict() for a in s] for s in query.get()] + ) + ) ) else: # other qureies return a simple list of results @@ -1531,6 +1533,7 @@ def test_query_stream_or_get_w_no_explain_options(query_docs, database, method): results.get_explain_metrics() verify_pipeline(query) + @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py index b78a77786..d053cbd7a 100644 --- a/tests/system/test_system_async.py +++ b/tests/system/test_system_async.py @@ -208,7 +208,9 @@ def _clean_results(results): await pipeline.execute() else: # ensure results match query - pipeline_results = _clean_results([s.data() async for s in pipeline.stream()]) + pipeline_results = _clean_results( + [s.data() async for s in pipeline.stream()] + ) assert query_results == pipeline_results except FailedPrecondition as e: # if testing against a non-enterprise db, skip this check @@ -216,7 +218,6 @@ def _clean_results(results): raise e - @pytest.fixture(scope="module") def event_loop(): """Change event_loop fixture to module level.""" diff --git a/tests/unit/v1/test_aggregation.py b/tests/unit/v1/test_aggregation.py index 5064e87ae..9a20fd386 100644 --- a/tests/unit/v1/test_aggregation.py +++ b/tests/unit/v1/test_aggregation.py @@ -1136,7 +1136,7 @@ def test_aggreation_to_pipeline_count_increment(): assert len(aggregate_stage.accumulators) == n for i in range(n): assert isinstance(aggregate_stage.accumulators[i].expr, Count) - assert aggregate_stage.accumulators[i].alias == f"field_{i+1}" + assert aggregate_stage.accumulators[i].alias == f"field_{i + 1}" def test_aggreation_to_pipeline_complex(): diff --git a/tests/unit/v1/test_async_aggregation.py b/tests/unit/v1/test_async_aggregation.py index fdd4a1450..701feab5b 100644 --- a/tests/unit/v1/test_async_aggregation.py +++ b/tests/unit/v1/test_async_aggregation.py @@ -810,7 +810,7 @@ def test_aggreation_to_pipeline_count_increment(): assert len(aggregate_stage.accumulators) == n for i in range(n): assert isinstance(aggregate_stage.accumulators[i].expr, Count) - assert aggregate_stage.accumulators[i].alias == f"field_{i+1}" + assert aggregate_stage.accumulators[i].alias == f"field_{i + 1}" def test_async_aggreation_to_pipeline_complex(): diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index 9f06c47b8..aec721e7d 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -24,7 +24,6 @@ from google.cloud.firestore_v1._helpers import GeoPoint import google.cloud.firestore_v1.pipeline_expressions as expr from google.cloud.firestore_v1.pipeline_expressions import BooleanExpr -from google.cloud.firestore_v1.pipeline_expressions import _ListOfExprs from google.cloud.firestore_v1.pipeline_expressions import Expr from google.cloud.firestore_v1.pipeline_expressions import Constant from google.cloud.firestore_v1.pipeline_expressions import Field @@ -97,13 +96,6 @@ class TestConstant: Value(timestamp_value={"seconds": 1747008000}), ), (GeoPoint(1, 2), Value(geo_point_value={"latitude": 1, "longitude": 2})), - ( - [0.0, 1.0, 2.0], - Value( - array_value={"values": [Value(double_value=i) for i in range(3)]} - ), - ), - ({"a": "b"}, Value(map_value={"fields": {"a": Value(string_value="b")}})), ( Vector([1.0, 2.0]), Value( @@ -173,57 +165,6 @@ def test_equality(self, first, second, expected): assert (first == second) is expected -class TestListOfExprs: - def test_to_pb(self): - instance = _ListOfExprs([Constant(1), Constant(2)]) - result = instance._to_pb() - assert len(result.array_value.values) == 2 - assert result.array_value.values[0].integer_value == 1 - assert result.array_value.values[1].integer_value == 2 - - def test_empty_to_pb(self): - instance = _ListOfExprs([]) - result = instance._to_pb() - assert len(result.array_value.values) == 0 - - def test_repr(self): - instance = _ListOfExprs([Constant(1), Constant(2)]) - repr_string = repr(instance) - assert repr_string == "[Constant.of(1), Constant.of(2)]" - empty_instance = _ListOfExprs([]) - empty_repr_string = repr(empty_instance) - assert empty_repr_string == "[]" - - @pytest.mark.parametrize( - "first,second,expected", - [ - (_ListOfExprs([]), _ListOfExprs([]), True), - (_ListOfExprs([]), _ListOfExprs([Constant(1)]), False), - (_ListOfExprs([Constant(1)]), _ListOfExprs([]), False), - ( - _ListOfExprs([Constant(1)]), - _ListOfExprs([Constant(1)]), - True, - ), - ( - _ListOfExprs([Constant(1)]), - _ListOfExprs([Constant(2)]), - False, - ), - ( - _ListOfExprs([Constant(1), Constant(2)]), - _ListOfExprs([Constant(1), Constant(2)]), - True, - ), - (_ListOfExprs([Constant(1)]), [Constant(1)], False), - (_ListOfExprs([Constant(1)]), [1], False), - (_ListOfExprs([Constant(1)]), object(), False), - ], - ) - def test_equality(self, first, second, expected): - assert (first == second) is expected - - class TestSelectable: """ contains tests for each Expr class that derives from Selectable @@ -370,7 +311,7 @@ def test__from_query_filter_pb_composite_filter_or(self, mock_client): field1 = Field.of("field1") field2 = Field.of("field2") expected_cond1 = expr.And(field1.exists(), field1.equal(Constant("val1"))) - expected_cond2 = expr.And(field2.exists(), field2.equal(Constant(None))) + expected_cond2 = expr.And(field2.exists(), field2.is_null()) expected = expr.Or(expected_cond1, expected_cond2) assert repr(result) == repr(expected) @@ -457,9 +398,7 @@ def test__from_query_filter_pb_composite_filter_nested(self, mock_client): field3 = Field.of("field3") expected_cond1 = expr.And(field1.exists(), field1.equal(Constant("val1"))) expected_cond2 = expr.And(field2.exists(), field2.greater_than(Constant(10))) - expected_cond3 = expr.And( - field3.exists(), expr.Not(field3.equal(Constant(None))) - ) + expected_cond3 = expr.And(field3.exists(), field3.is_not_null()) expected_inner_and = expr.And(expected_cond2, expected_cond3) expected_outer_or = expr.Or(expected_cond1, expected_inner_and) @@ -491,15 +430,15 @@ def test__from_query_filter_pb_composite_filter_unknown_op(self, mock_client): (query_pb.StructuredQuery.UnaryFilter.Operator.IS_NAN, Expr.is_nan), ( query_pb.StructuredQuery.UnaryFilter.Operator.IS_NOT_NAN, - lambda f: expr.Not(f.is_nan()), + Expr.is_not_nan, ), ( query_pb.StructuredQuery.UnaryFilter.Operator.IS_NULL, - lambda f: f.equal(None), + Expr.is_null, ), ( query_pb.StructuredQuery.UnaryFilter.Operator.IS_NOT_NULL, - lambda f: expr.Not(f.equal(None)), + Expr.is_not_null, ), ], ) @@ -643,6 +582,69 @@ def test__from_query_filter_pb_unknown_filter_type(self, mock_client): BooleanExpr._from_query_filter_pb(document_pb.Value(), mock_client) +class TestArray: + """Tests for the array class""" + + def test_array(self): + arg1 = Field.of("field1") + instance = expr.Array([arg1]) + assert instance.name == "array" + assert instance.params == [arg1] + assert repr(instance) == "Array([Field.of('field1')])" + + def test_empty_array(self): + instance = expr.Array([]) + assert instance.name == "array" + assert instance.params == [] + assert repr(instance) == "Array([])" + + def test_array_w_primitives(self): + a = expr.Array([1, Constant.of(2), "3"]) + assert a.name == "array" + assert a.params == [Constant.of(1), Constant.of(2), Constant.of("3")] + assert repr(a) == "Array([Constant.of(1), Constant.of(2), Constant.of('3')])" + + def test_array_w_non_list(self): + with pytest.raises(TypeError): + expr.Array(1) + + +class TestMap: + """Tests for the map class""" + + def test_map(self): + instance = expr.Map({Constant.of("a"): Constant.of("b")}) + assert instance.name == "map" + assert instance.params == [Constant.of("a"), Constant.of("b")] + assert repr(instance) == "Map({'a': 'b'})" + + def test_map_w_primitives(self): + instance = expr.Map({"a": "b", "0": 0, "bool": True}) + assert instance.params == [ + Constant.of("a"), + Constant.of("b"), + Constant.of("0"), + Constant.of(0), + Constant.of("bool"), + Constant.of(True), + ] + assert repr(instance) == "Map({'a': 'b', '0': 0, 'bool': True})" + + def test_empty_map(self): + instance = expr.Map({}) + assert instance.name == "map" + assert instance.params == [] + assert repr(instance) == "Map({})" + + def test_w_exprs(self): + instance = expr.Map({Constant.of("a"): expr.Array([1, 2, 3])}) + assert instance.params == [Constant.of("a"), expr.Array([1, 2, 3])] + assert ( + repr(instance) + == "Map({'a': Array([Constant.of(1), Constant.of(2), Constant.of(3)])})" + ) + + class TestExpressionMethods: """ contains test methods for each Expr method @@ -723,10 +725,13 @@ def test_array_contains_any(self): arg3 = self._make_arg("Element2") instance = Expr.array_contains_any(arg1, [arg2, arg3]) assert instance.name == "array_contains_any" - assert isinstance(instance.params[1], _ListOfExprs) + assert isinstance(instance.params[1], expr.Array) assert instance.params[0] == arg1 - assert instance.params[1].exprs == [arg2, arg3] - assert repr(instance) == "ArrayField.array_contains_any([Element1, Element2])" + assert instance.params[1].params == [arg2, arg3] + assert ( + repr(instance) + == "ArrayField.array_contains_any(Array([Element1, Element2]))" + ) infix_instance = arg1.array_contains_any([arg2, arg3]) assert infix_instance == instance @@ -805,10 +810,10 @@ def test_equal_any(self): arg3 = self._make_arg("Value2") instance = Expr.equal_any(arg1, [arg2, arg3]) assert instance.name == "equal_any" - assert isinstance(instance.params[1], _ListOfExprs) + assert isinstance(instance.params[1], expr.Array) assert instance.params[0] == arg1 - assert instance.params[1].exprs == [arg2, arg3] - assert repr(instance) == "Field.equal_any([Value1, Value2])" + assert instance.params[1].params == [arg2, arg3] + assert repr(instance) == "Field.equal_any(Array([Value1, Value2]))" infix_instance = arg1.equal_any([arg2, arg3]) assert infix_instance == instance @@ -818,13 +823,32 @@ def test_not_equal_any(self): arg3 = self._make_arg("Value2") instance = Expr.not_equal_any(arg1, [arg2, arg3]) assert instance.name == "not_equal_any" - assert isinstance(instance.params[1], _ListOfExprs) + assert isinstance(instance.params[1], expr.Array) assert instance.params[0] == arg1 - assert instance.params[1].exprs == [arg2, arg3] - assert repr(instance) == "Field.not_equal_any([Value1, Value2])" + assert instance.params[1].params == [arg2, arg3] + assert repr(instance) == "Field.not_equal_any(Array([Value1, Value2]))" infix_instance = arg1.not_equal_any([arg2, arg3]) assert infix_instance == instance + def test_is_absent(self): + arg1 = self._make_arg("Field") + instance = Expr.is_absent(arg1) + assert instance.name == "is_absent" + assert instance.params == [arg1] + assert repr(instance) == "Field.is_absent()" + infix_instance = arg1.is_absent() + assert infix_instance == instance + + def test_if_absent(self): + arg1 = self._make_arg("Field") + arg2 = self._make_arg("ThenExpr") + instance = Expr.if_absent(arg1, arg2) + assert instance.name == "if_absent" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Field.if_absent(ThenExpr)" + infix_instance = arg1.if_absent(arg2) + assert infix_instance == instance + def test_is_nan(self): arg1 = self._make_arg("Value") instance = Expr.is_nan(arg1) @@ -834,15 +858,52 @@ def test_is_nan(self): infix_instance = arg1.is_nan() assert infix_instance == instance + def test_is_not_nan(self): + arg1 = self._make_arg("Value") + instance = Expr.is_not_nan(arg1) + assert instance.name == "is_not_nan" + assert instance.params == [arg1] + assert repr(instance) == "Value.is_not_nan()" + infix_instance = arg1.is_not_nan() + assert infix_instance == instance + def test_is_null(self): arg1 = self._make_arg("Value") - instance = Expr.is_ull(arg1) + instance = Expr.is_null(arg1) assert instance.name == "is_null" assert instance.params == [arg1] assert repr(instance) == "Value.is_null()" infix_instance = arg1.is_null() assert infix_instance == instance + def test_is_not_null(self): + arg1 = self._make_arg("Value") + instance = Expr.is_not_null(arg1) + assert instance.name == "is_not_null" + assert instance.params == [arg1] + assert repr(instance) == "Value.is_not_null()" + infix_instance = arg1.is_not_null() + assert infix_instance == instance + + def test_is_error(self): + arg1 = self._make_arg("Value") + instance = Expr.is_error(arg1) + assert instance.name == "is_error" + assert instance.params == [arg1] + assert repr(instance) == "Value.is_error()" + infix_instance = arg1.is_error() + assert infix_instance == instance + + def test_if_error(self): + arg1 = self._make_arg("Value") + arg2 = self._make_arg("ThenExpr") + instance = Expr.if_error(arg1, arg2) + assert instance.name == "if_error" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Value.if_error(ThenExpr)" + infix_instance = arg1.if_error(arg2) + assert infix_instance == instance + def test_not(self): arg1 = self._make_arg("Condition") instance = expr.Not(arg1) @@ -856,10 +917,13 @@ def test_array_contains_all(self): arg3 = self._make_arg("Element2") instance = Expr.array_contains_all(arg1, [arg2, arg3]) assert instance.name == "array_contains_all" - assert isinstance(instance.params[1], _ListOfExprs) + assert isinstance(instance.params[1], expr.Array) assert instance.params[0] == arg1 - assert instance.params[1].exprs == [arg2, arg3] - assert repr(instance) == "ArrayField.array_contains_all([Element1, Element2])" + assert instance.params[1].params == [arg2, arg3] + assert ( + repr(instance) + == "ArrayField.array_contains_all(Array([Element1, Element2]))" + ) infix_instance = arg1.array_contains_all([arg2, arg3]) assert infix_instance == instance @@ -970,6 +1034,73 @@ def test_logical_minimum(self): infix_instance = arg1.logical_minimum(arg2) assert infix_instance == instance + def test_to_lower(self): + arg1 = self._make_arg("Input") + instance = Expr.to_lower(arg1) + assert instance.name == "to_lower" + assert instance.params == [arg1] + assert repr(instance) == "Input.to_lower()" + infix_instance = arg1.to_lower() + assert infix_instance == instance + + def test_to_upper(self): + arg1 = self._make_arg("Input") + instance = Expr.to_upper(arg1) + assert instance.name == "to_upper" + assert instance.params == [arg1] + assert repr(instance) == "Input.to_upper()" + infix_instance = arg1.to_upper() + assert infix_instance == instance + + def test_trim(self): + arg1 = self._make_arg("Input") + instance = Expr.trim(arg1) + assert instance.name == "trim" + assert instance.params == [arg1] + assert repr(instance) == "Input.trim()" + infix_instance = arg1.trim() + assert infix_instance == instance + + def test_string_reverse(self): + arg1 = self._make_arg("Input") + instance = Expr.string_reverse(arg1) + assert instance.name == "string_reverse" + assert instance.params == [arg1] + assert repr(instance) == "Input.string_reverse()" + infix_instance = arg1.string_reverse() + assert infix_instance == instance + + def test_substring(self): + arg1 = self._make_arg("Input") + arg2 = self._make_arg("Position") + instance = Expr.substring(arg1, arg2) + assert instance.name == "substring" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Input.substring(Position)" + infix_instance = arg1.substring(arg2) + assert infix_instance == instance + + def test_substring_w_length(self): + arg1 = self._make_arg("Input") + arg2 = self._make_arg("Position") + arg3 = self._make_arg("Length") + instance = Expr.substring(arg1, arg2, arg3) + assert instance.name == "substring" + assert instance.params == [arg1, arg2, arg3] + assert repr(instance) == "Input.substring(Position, Length)" + infix_instance = arg1.substring(arg2, arg3) + assert infix_instance == instance + + def test_join(self): + arg1 = self._make_arg("Array") + arg2 = self._make_arg("Separator") + instance = Expr.join(arg1, arg2) + assert instance.name == "join" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Array.join(Separator)" + infix_instance = arg1.join(arg2) + assert infix_instance == instance + def test_map_get(self): arg1 = self._make_arg("Map") arg2 = "key" @@ -980,6 +1111,27 @@ def test_map_get(self): infix_instance = arg1.map_get(Constant.of(arg2)) assert infix_instance == instance + def test_map_remove(self): + arg1 = self._make_arg("Map") + arg2 = "key" + instance = Expr.map_remove(arg1, arg2) + assert instance.name == "map_remove" + assert instance.params == [arg1, Constant.of(arg2)] + assert repr(instance) == "Map.map_remove(Constant.of('key'))" + infix_instance = arg1.map_remove(Constant.of(arg2)) + assert infix_instance == instance + + def test_map_merge(self): + arg1 = expr.Map({"a": 1}) + arg2 = expr.Map({"b": 2}) + arg3 = {"c": 3} + instance = Expr.map_merge(arg1, arg2, arg3) + assert instance.name == "map_merge" + assert instance.params == [arg1, arg2, expr.Map(arg3)] + assert repr(instance) == "Map({'a': 1}).map_merge(Map({'b': 2}), Map({'c': 3}))" + infix_instance = arg1.map_merge(arg2, arg3) + assert infix_instance == instance + def test_mod(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") @@ -1021,6 +1173,12 @@ def test_subtract(self): infix_instance = arg1.subtract(arg2) assert infix_instance == instance + def test_current_timestamp(self): + instance = expr.CurrentTimestamp() + assert instance.name == "current_timestamp" + assert instance.params == [] + assert repr(instance) == "CurrentTimestamp()" + def test_timestamp_add(self): arg1 = self._make_arg("Timestamp") arg2 = self._make_arg("Unit") @@ -1097,6 +1255,54 @@ def test_unix_seconds_to_timestamp(self): infix_instance = arg1.unix_seconds_to_timestamp() assert infix_instance == instance + def test_euclidean_distance(self): + arg1 = self._make_arg("Vector1") + arg2 = self._make_arg("Vector2") + instance = Expr.euclidean_distance(arg1, arg2) + assert instance.name == "euclidean_distance" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Vector1.euclidean_distance(Vector2)" + infix_instance = arg1.euclidean_distance(arg2) + assert infix_instance == instance + + def test_cosine_distance(self): + arg1 = self._make_arg("Vector1") + arg2 = self._make_arg("Vector2") + instance = Expr.cosine_distance(arg1, arg2) + assert instance.name == "cosine_distance" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Vector1.cosine_distance(Vector2)" + infix_instance = arg1.cosine_distance(arg2) + assert infix_instance == instance + + def test_dot_product(self): + arg1 = self._make_arg("Vector1") + arg2 = self._make_arg("Vector2") + instance = Expr.dot_product(arg1, arg2) + assert instance.name == "dot_product" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Vector1.dot_product(Vector2)" + infix_instance = arg1.dot_product(arg2) + assert infix_instance == instance + + @pytest.mark.parametrize( + "method", ["euclidean_distance", "cosine_distance", "dot_product"] + ) + @pytest.mark.parametrize( + "input", [Vector([1.0, 2.0]), [1, 2], Constant.of(Vector([1.0, 2.0])), []] + ) + def test_vector_ctor(self, method, input): + """ + test constructing various vector expressions with + different inputs + """ + arg1 = self._make_arg("VectorRef") + instance = getattr(arg1, method)(input) + assert instance.name == method + got_second_param = instance.params[1] + assert isinstance(got_second_param, Constant) + assert isinstance(got_second_param.value, Vector) + def test_vector_length(self): arg1 = self._make_arg("Array") instance = Expr.vector_length(arg1) @@ -1116,6 +1322,98 @@ def test_add(self): infix_instance = arg1.add(arg2) assert infix_instance == instance + def test_abs(self): + arg1 = self._make_arg("Value") + instance = Expr.abs(arg1) + assert instance.name == "abs" + assert instance.params == [arg1] + assert repr(instance) == "Value.abs()" + infix_instance = arg1.abs() + assert infix_instance == instance + + def test_ceil(self): + arg1 = self._make_arg("Value") + instance = Expr.ceil(arg1) + assert instance.name == "ceil" + assert instance.params == [arg1] + assert repr(instance) == "Value.ceil()" + infix_instance = arg1.ceil() + assert infix_instance == instance + + def test_exp(self): + arg1 = self._make_arg("Value") + instance = Expr.exp(arg1) + assert instance.name == "exp" + assert instance.params == [arg1] + assert repr(instance) == "Value.exp()" + infix_instance = arg1.exp() + assert infix_instance == instance + + def test_floor(self): + arg1 = self._make_arg("Value") + instance = Expr.floor(arg1) + assert instance.name == "floor" + assert instance.params == [arg1] + assert repr(instance) == "Value.floor()" + infix_instance = arg1.floor() + assert infix_instance == instance + + def test_ln(self): + arg1 = self._make_arg("Value") + instance = Expr.ln(arg1) + assert instance.name == "ln" + assert instance.params == [arg1] + assert repr(instance) == "Value.ln()" + infix_instance = arg1.ln() + assert infix_instance == instance + + def test_log(self): + arg1 = self._make_arg("Value") + arg2 = self._make_arg("Base") + instance = Expr.log(arg1, arg2) + assert instance.name == "log" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Value.log(Base)" + infix_instance = arg1.log(arg2) + assert infix_instance == instance + + def test_log10(self): + arg1 = self._make_arg("Value") + instance = Expr.log10(arg1) + assert instance.name == "log10" + assert instance.params == [arg1] + assert repr(instance) == "Value.log10()" + infix_instance = arg1.log10() + assert infix_instance == instance + + def test_pow(self): + arg1 = self._make_arg("Value") + arg2 = self._make_arg("Exponent") + instance = Expr.pow(arg1, arg2) + assert instance.name == "pow" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Value.pow(Exponent)" + infix_instance = arg1.pow(arg2) + assert infix_instance == instance + + def test_round(self): + arg1 = self._make_arg("Value") + instance = Expr.round(arg1) + assert instance.name == "round" + assert instance.params == [arg1] + assert repr(instance) == "Value.round()" + infix_instance = arg1.round() + assert infix_instance == instance + + def test_sqrt(self): + arg1 = self._make_arg("Value") + instance = Expr.sqrt(arg1) + assert instance.name == "sqrt" + assert instance.params == [arg1] + assert repr(instance) == "Value.sqrt()" + infix_instance = arg1.sqrt() + assert infix_instance == instance + def test_array_length(self): arg1 = self._make_arg("Array") instance = Expr.array_length(arg1) @@ -1134,6 +1432,29 @@ def test_array_reverse(self): infix_instance = arg1.array_reverse() assert infix_instance == instance + def test_array_concat(self): + arg1 = self._make_arg("ArrayRef1") + arg2 = self._make_arg("ArrayRef2") + instance = Expr.array_concat(arg1, arg2) + assert instance.name == "array_concat" + assert instance.params == [arg1, arg2] + assert repr(instance) == "ArrayRef1.array_concat(ArrayRef2)" + infix_instance = arg1.array_concat(arg2) + assert infix_instance == instance + + def test_array_concat_multiple(self): + arg1 = expr.Array([Constant.of(0)]) + arg2 = Field.of("ArrayRef2") + arg3 = Field.of("ArrayRef3") + arg4 = [self._make_arg("Constant")] + instance = arg1.array_concat(arg2, arg3, arg4) + assert instance.name == "array_concat" + assert instance.params == [arg1, arg2, arg3, expr.Array(arg4)] + assert ( + repr(instance) + == "Array([Constant.of(0)]).array_concat(Field.of('ArrayRef2'), Field.of('ArrayRef3'), Array([Constant]))" + ) + def test_byte_length(self): arg1 = self._make_arg("Expr") instance = Expr.byte_length(arg1) @@ -1152,6 +1473,26 @@ def test_char_length(self): infix_instance = arg1.char_length() assert infix_instance == instance + def test_concat(self): + arg1 = self._make_arg("First") + arg2 = self._make_arg("Second") + arg3 = "Third" + instance = Expr.concat(arg1, arg2, arg3) + assert instance.name == "concat" + assert instance.params == [arg1, arg2, Constant.of(arg3)] + assert repr(instance) == "First.concat(Second, Constant.of('Third'))" + infix_instance = arg1.concat(arg2, arg3) + assert infix_instance == instance + + def test_length(self): + arg1 = self._make_arg("Expr") + instance = Expr.length(arg1) + assert instance.name == "length" + assert instance.params == [arg1] + assert repr(instance) == "Expr.length()" + infix_instance = arg1.length() + assert infix_instance == instance + def test_collection_id(self): arg1 = self._make_arg("Value") instance = Expr.collection_id(arg1) @@ -1161,6 +1502,15 @@ def test_collection_id(self): infix_instance = arg1.collection_id() assert infix_instance == instance + def test_document_id(self): + arg1 = self._make_arg("Value") + instance = Expr.document_id(arg1) + assert instance.name == "document_id" + assert instance.params == [arg1] + assert repr(instance) == "Value.document_id()" + infix_instance = arg1.document_id() + assert infix_instance == instance + def test_sum(self): arg1 = self._make_arg("Value") instance = Expr.sum(arg1) @@ -1194,6 +1544,24 @@ def test_base_count(self): assert instance.params == [] assert repr(instance) == "Count()" + def test_count_if(self): + arg1 = self._make_arg("Value") + instance = Expr.count_if(arg1) + assert instance.name == "count_if" + assert instance.params == [arg1] + assert repr(instance) == "Value.count_if()" + infix_instance = arg1.count_if() + assert infix_instance == instance + + def test_count_distinct(self): + arg1 = self._make_arg("Value") + instance = Expr.count_distinct(arg1) + assert instance.name == "count_distinct" + assert instance.params == [arg1] + assert repr(instance) == "Value.count_distinct()" + infix_instance = arg1.count_distinct() + assert infix_instance == instance + def test_minimum(self): arg1 = self._make_arg("Value") instance = Expr.minimum(arg1)