diff --git a/google/cloud/firestore_v1/_pipeline_stages.py b/google/cloud/firestore_v1/_pipeline_stages.py index f7d311d89..aefddbcf8 100644 --- a/google/cloud/firestore_v1/_pipeline_stages.py +++ b/google/cloud/firestore_v1/_pipeline_stages.py @@ -23,11 +23,12 @@ from google.cloud.firestore_v1.vector import Vector from google.cloud.firestore_v1.base_vector_query import DistanceMeasure from google.cloud.firestore_v1.pipeline_expressions import ( - Accumulator, + AggregateFunction, Expr, - ExprWithAlias, + AliasedAggregate, + AliasedExpr, Field, - FilterCondition, + BooleanExpr, Selectable, Ordering, ) @@ -156,13 +157,7 @@ def __init__(self, *fields: Selectable): self.fields = list(fields) def _pb_args(self): - return [ - Value( - map_value={ - "fields": {m[0]: m[1] for m in [f._to_map() for f in self.fields]} - } - ) - ] + return [Selectable._to_value(self.fields)] class Aggregate(Stage): @@ -170,8 +165,8 @@ class Aggregate(Stage): def __init__( self, - *args: ExprWithAlias[Accumulator], - accumulators: Sequence[ExprWithAlias[Accumulator]] = (), + *args: AliasedExpr[AggregateFunction], + accumulators: Sequence[AliasedAggregate] = (), groups: Sequence[str | Selectable] = (), ): super().__init__() @@ -186,18 +181,8 @@ def __init__( def _pb_args(self): return [ - Value( - map_value={ - "fields": { - m[0]: m[1] for m in [f._to_map() for f in self.accumulators] - } - } - ), - Value( - map_value={ - "fields": {m[0]: m[1] for m in [f._to_map() for f in self.groups]} - } - ), + Selectable._to_value(self.accumulators), + Selectable._to_value(self.groups), ] def __repr__(self): @@ -254,13 +239,7 @@ def __init__(self, *fields: str | Selectable): ] def _pb_args(self) -> list[Value]: - return [ - Value( - map_value={ - "fields": {m[0]: m[1] for m in [f._to_map() for f in self.fields]} - } - ) - ] + return [Selectable._to_value(self.fields)] class Documents(Stage): @@ -461,7 +440,7 @@ def _pb_options(self): class Where(Stage): """Filters documents based on a specified condition.""" - def __init__(self, condition: FilterCondition): + def __init__(self, condition: BooleanExpr): super().__init__() self.condition = condition diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py index 50ae7ab62..01f48ee78 100644 --- a/google/cloud/firestore_v1/base_pipeline.py +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -23,11 +23,10 @@ from google.cloud.firestore_v1.types.firestore import ExecutePipelineRequest from google.cloud.firestore_v1.pipeline_result import PipelineResult from google.cloud.firestore_v1.pipeline_expressions import ( - Accumulator, + AliasedAggregate, Expr, - ExprWithAlias, Field, - FilterCondition, + BooleanExpr, Selectable, ) from google.cloud.firestore_v1 import _helpers @@ -220,14 +219,14 @@ def select(self, *selections: str | Selectable) -> "_BasePipeline": """ return self._append(stages.Select(*selections)) - def where(self, condition: FilterCondition) -> "_BasePipeline": + def where(self, condition: BooleanExpr) -> "_BasePipeline": """ Filters the documents from previous stages to only include those matching - the specified `FilterCondition`. + the specified `BooleanExpr`. This stage allows you to apply conditions to the data, similar to a "WHERE" clause in SQL. You can filter documents based on their field values, using - implementations of `FilterCondition`, typically including but not limited to: + implementations of `BooleanExpr`, typically including but not limited to: - field comparators: `eq`, `lt` (less than), `gt` (greater than), etc. - logical operators: `And`, `Or`, `Not`, etc. - advanced functions: `regex_matches`, `array_contains`, etc. @@ -252,7 +251,7 @@ def where(self, condition: FilterCondition) -> "_BasePipeline": Args: - condition: The `FilterCondition` to apply. + condition: The `BooleanExpr` to apply. Returns: A new Pipeline object with this stage appended to the stage list @@ -531,7 +530,7 @@ def limit(self, limit: int) -> "_BasePipeline": def aggregate( self, - *accumulators: ExprWithAlias[Accumulator], + *accumulators: AliasedAggregate, groups: Sequence[str | Selectable] = (), ) -> "_BasePipeline": """ @@ -541,7 +540,7 @@ def aggregate( This stage allows you to calculate aggregate values (like sum, average, count, min, max) over a set of documents. - - **Accumulators:** Define the aggregation calculations using `Accumulator` + - **Accumulators:** Define the aggregation calculations using `AggregateFunction` expressions (e.g., `sum()`, `avg()`, `count()`, `min()`, `max()`) combined with `as_()` to name the result field. - **Groups:** Optionally specify fields (by name or `Selectable`) to group @@ -569,7 +568,7 @@ def aggregate( Args: - *accumulators: One or more `ExprWithAlias[Accumulator]` expressions defining + *accumulators: One or more `AliasedAggregate` expressions defining the aggregations to perform and their output names. groups: An optional sequence of field names (str) or `Selectable` expressions to group by before aggregating. diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 70d619d3b..ef57f5b72 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -15,7 +15,6 @@ from __future__ import annotations from typing import ( Any, - List, Generic, TypeVar, Dict, @@ -117,7 +116,35 @@ def _to_pb(self) -> Value: def _cast_to_expr_or_convert_to_constant(o: Any) -> "Expr": return o if isinstance(o, Expr) else Constant(o) - def add(self, other: Expr | float) -> "Add": + class expose_as_static: + """ + Decorator to mark instance methods to be exposed as static methods as well as instance + methods. + + When called statically, the first argument is converted to a Field expression if needed. + + Example: + >>> Field.of("test").add(5) + >>> Function.add("test", 5) + """ + + def __init__(self, instance_func): + self.instance_func = instance_func + + def static_func(self, first_arg, *other_args, **kwargs): + first_expr = ( + Field.of(first_arg) if not isinstance(first_arg, Expr) else first_arg + ) + return self.instance_func(first_expr, *other_args, **kwargs) + + def __get__(self, instance, owner): + if instance is None: + return self.static_func.__get__(instance, owner) + else: + return self.instance_func.__get__(instance, owner) + + @expose_as_static + def add(self, other: Expr | float) -> "Expr": """Creates an expression that adds this expression to another expression or constant. Example: @@ -132,9 +159,10 @@ def add(self, other: Expr | float) -> "Add": Returns: A new `Expr` representing the addition operation. """ - return Add(self, self._cast_to_expr_or_convert_to_constant(other)) + return Function("add", [self, self._cast_to_expr_or_convert_to_constant(other)]) - def subtract(self, other: Expr | float) -> "Subtract": + @expose_as_static + def subtract(self, other: Expr | float) -> "Expr": """Creates an expression that subtracts another expression or constant from this expression. Example: @@ -149,9 +177,12 @@ def subtract(self, other: Expr | float) -> "Subtract": Returns: A new `Expr` representing the subtraction operation. """ - return Subtract(self, self._cast_to_expr_or_convert_to_constant(other)) + return Function( + "subtract", [self, self._cast_to_expr_or_convert_to_constant(other)] + ) - def multiply(self, other: Expr | float) -> "Multiply": + @expose_as_static + def multiply(self, other: Expr | float) -> "Expr": """Creates an expression that multiplies this expression by another expression or constant. Example: @@ -166,9 +197,12 @@ def multiply(self, other: Expr | float) -> "Multiply": Returns: A new `Expr` representing the multiplication operation. """ - return Multiply(self, self._cast_to_expr_or_convert_to_constant(other)) + return Function( + "multiply", [self, self._cast_to_expr_or_convert_to_constant(other)] + ) - def divide(self, other: Expr | float) -> "Divide": + @expose_as_static + def divide(self, other: Expr | float) -> "Expr": """Creates an expression that divides this expression by another expression or constant. Example: @@ -183,9 +217,12 @@ def divide(self, other: Expr | float) -> "Divide": Returns: A new `Expr` representing the division operation. """ - return Divide(self, self._cast_to_expr_or_convert_to_constant(other)) + return Function( + "divide", [self, self._cast_to_expr_or_convert_to_constant(other)] + ) - def mod(self, other: Expr | float) -> "Mod": + @expose_as_static + def mod(self, other: Expr | float) -> "Expr": """Creates an expression that calculates the modulo (remainder) to another expression or constant. Example: @@ -200,9 +237,10 @@ def mod(self, other: Expr | float) -> "Mod": Returns: A new `Expr` representing the modulo operation. """ - return Mod(self, self._cast_to_expr_or_convert_to_constant(other)) + return Function("mod", [self, self._cast_to_expr_or_convert_to_constant(other)]) - def logical_max(self, other: Expr | CONSTANT_TYPE) -> "LogicalMax": + @expose_as_static + def logical_maximum(self, other: Expr | CONSTANT_TYPE) -> "Expr": """Creates an expression that returns the larger value between this expression and another expression or constant, based on Firestore's value type ordering. @@ -211,19 +249,24 @@ def logical_max(self, other: Expr | CONSTANT_TYPE) -> "LogicalMax": Example: >>> # Returns the larger value between the 'discount' field and the 'cap' field. - >>> Field.of("discount").logical_max(Field.of("cap")) + >>> Field.of("discount").logical_maximum(Field.of("cap")) >>> # Returns the larger value between the 'value' field and 10. - >>> Field.of("value").logical_max(10) + >>> Field.of("value").logical_maximum(10) Args: other: The other expression or constant value to compare with. Returns: - A new `Expr` representing the logical max operation. + A new `Expr` representing the logical maximum operation. """ - return LogicalMax(self, self._cast_to_expr_or_convert_to_constant(other)) + return Function( + "maximum", + [self, self._cast_to_expr_or_convert_to_constant(other)], + infix_name_override="logical_maximum", + ) - def logical_min(self, other: Expr | CONSTANT_TYPE) -> "LogicalMin": + @expose_as_static + def logical_minimum(self, other: Expr | CONSTANT_TYPE) -> "Expr": """Creates an expression that returns the smaller value between this expression and another expression or constant, based on Firestore's value type ordering. @@ -232,27 +275,32 @@ def logical_min(self, other: Expr | CONSTANT_TYPE) -> "LogicalMin": Example: >>> # Returns the smaller value between the 'discount' field and the 'floor' field. - >>> Field.of("discount").logical_min(Field.of("floor")) + >>> Field.of("discount").logical_minimum(Field.of("floor")) >>> # Returns the smaller value between the 'value' field and 10. - >>> Field.of("value").logical_min(10) + >>> Field.of("value").logical_minimum(10) Args: other: The other expression or constant value to compare with. Returns: - A new `Expr` representing the logical min operation. + A new `Expr` representing the logical minimum operation. """ - return LogicalMin(self, self._cast_to_expr_or_convert_to_constant(other)) + return Function( + "minimum", + [self, self._cast_to_expr_or_convert_to_constant(other)], + infix_name_override="logical_minimum", + ) - def eq(self, other: Expr | CONSTANT_TYPE) -> "Eq": + @expose_as_static + def equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is equal to another expression or constant value. Example: >>> # Check if the 'age' field is equal to 21 - >>> Field.of("age").eq(21) + >>> Field.of("age").equal(21) >>> # Check if the 'city' field is equal to "London" - >>> Field.of("city").eq("London") + >>> Field.of("city").equal("London") Args: other: The expression or constant value to compare for equality. @@ -260,17 +308,20 @@ def eq(self, other: Expr | CONSTANT_TYPE) -> "Eq": Returns: A new `Expr` representing the equality comparison. """ - return Eq(self, self._cast_to_expr_or_convert_to_constant(other)) + return BooleanExpr( + "equal", [self, self._cast_to_expr_or_convert_to_constant(other)] + ) - def neq(self, other: Expr | CONSTANT_TYPE) -> "Neq": + @expose_as_static + def not_equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is not equal to another expression or constant value. Example: >>> # Check if the 'status' field is not equal to "completed" - >>> Field.of("status").neq("completed") + >>> Field.of("status").not_equal("completed") >>> # Check if the 'country' field is not equal to "USA" - >>> Field.of("country").neq("USA") + >>> Field.of("country").not_equal("USA") Args: other: The expression or constant value to compare for inequality. @@ -278,17 +329,20 @@ def neq(self, other: Expr | CONSTANT_TYPE) -> "Neq": Returns: A new `Expr` representing the inequality comparison. """ - return Neq(self, self._cast_to_expr_or_convert_to_constant(other)) + return BooleanExpr( + "not_equal", [self, self._cast_to_expr_or_convert_to_constant(other)] + ) - def gt(self, other: Expr | CONSTANT_TYPE) -> "Gt": + @expose_as_static + def greater_than(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is greater than another expression or constant value. Example: >>> # Check if the 'age' field is greater than the 'limit' field - >>> Field.of("age").gt(Field.of("limit")) + >>> Field.of("age").greater_than(Field.of("limit")) >>> # Check if the 'price' field is greater than 100 - >>> Field.of("price").gt(100) + >>> Field.of("price").greater_than(100) Args: other: The expression or constant value to compare for greater than. @@ -296,17 +350,20 @@ def gt(self, other: Expr | CONSTANT_TYPE) -> "Gt": Returns: A new `Expr` representing the greater than comparison. """ - return Gt(self, self._cast_to_expr_or_convert_to_constant(other)) + return BooleanExpr( + "greater_than", [self, self._cast_to_expr_or_convert_to_constant(other)] + ) - def gte(self, other: Expr | CONSTANT_TYPE) -> "Gte": + @expose_as_static + def greater_than_or_equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is greater than or equal to another expression or constant value. Example: >>> # Check if the 'quantity' field is greater than or equal to field 'requirement' plus 1 - >>> Field.of("quantity").gte(Field.of('requirement').add(1)) + >>> Field.of("quantity").greater_than_or_equal(Field.of('requirement').add(1)) >>> # Check if the 'score' field is greater than or equal to 80 - >>> Field.of("score").gte(80) + >>> Field.of("score").greater_than_or_equal(80) Args: other: The expression or constant value to compare for greater than or equal to. @@ -314,17 +371,21 @@ def gte(self, other: Expr | CONSTANT_TYPE) -> "Gte": Returns: A new `Expr` representing the greater than or equal to comparison. """ - return Gte(self, self._cast_to_expr_or_convert_to_constant(other)) + return BooleanExpr( + "greater_than_or_equal", + [self, self._cast_to_expr_or_convert_to_constant(other)], + ) - def lt(self, other: Expr | CONSTANT_TYPE) -> "Lt": + @expose_as_static + def less_than(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is less than another expression or constant value. Example: >>> # Check if the 'age' field is less than 'limit' - >>> Field.of("age").lt(Field.of('limit')) + >>> Field.of("age").less_than(Field.of('limit')) >>> # Check if the 'price' field is less than 50 - >>> Field.of("price").lt(50) + >>> Field.of("price").less_than(50) Args: other: The expression or constant value to compare for less than. @@ -332,17 +393,20 @@ def lt(self, other: Expr | CONSTANT_TYPE) -> "Lt": Returns: A new `Expr` representing the less than comparison. """ - return Lt(self, self._cast_to_expr_or_convert_to_constant(other)) + return BooleanExpr( + "less_than", [self, self._cast_to_expr_or_convert_to_constant(other)] + ) - def lte(self, other: Expr | CONSTANT_TYPE) -> "Lte": + @expose_as_static + def less_than_or_equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is less than or equal to another expression or constant value. Example: >>> # Check if the 'quantity' field is less than or equal to 20 - >>> Field.of("quantity").lte(Constant.of(20)) + >>> Field.of("quantity").less_than_or_equal(Constant.of(20)) >>> # Check if the 'score' field is less than or equal to 70 - >>> Field.of("score").lte(70) + >>> Field.of("score").less_than_or_equal(70) Args: other: The expression or constant value to compare for less than or equal to. @@ -350,15 +414,19 @@ def lte(self, other: Expr | CONSTANT_TYPE) -> "Lte": Returns: A new `Expr` representing the less than or equal to comparison. """ - return Lte(self, self._cast_to_expr_or_convert_to_constant(other)) + return BooleanExpr( + "less_than_or_equal", + [self, self._cast_to_expr_or_convert_to_constant(other)], + ) - def in_any(self, array: List[Expr | CONSTANT_TYPE]) -> "In": + @expose_as_static + def equal_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "BooleanExpr": """Creates an expression that checks if this expression is equal to any of the provided values or expressions. Example: >>> # Check if the 'category' field is either "Electronics" or value of field 'primaryType' - >>> Field.of("category").in_any(["Electronics", Field.of("primaryType")]) + >>> Field.of("category").equal_any(["Electronics", Field.of("primaryType")]) Args: array: The values or expressions to check against. @@ -366,25 +434,43 @@ def in_any(self, array: List[Expr | CONSTANT_TYPE]) -> "In": Returns: A new `Expr` representing the 'IN' comparison. """ - return In(self, [self._cast_to_expr_or_convert_to_constant(v) for v in array]) + return BooleanExpr( + "equal_any", + [ + self, + _ListOfExprs( + [self._cast_to_expr_or_convert_to_constant(v) for v in array] + ), + ], + ) - def not_in_any(self, array: List[Expr | CONSTANT_TYPE]) -> "Not": + @expose_as_static + def not_equal_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "BooleanExpr": """Creates an expression that checks if this expression is not equal to any of the provided values or expressions. Example: >>> # Check if the 'status' field is neither "pending" nor "cancelled" - >>> Field.of("status").not_in_any(["pending", "cancelled"]) + >>> Field.of("status").not_equal_any(["pending", "cancelled"]) Args: - *others: The values or expressions to check against. + array: The values or expressions to check against. Returns: A new `Expr` representing the 'NOT IN' comparison. """ - return Not(self.in_any(array)) + return BooleanExpr( + "not_equal_any", + [ + self, + _ListOfExprs( + [self._cast_to_expr_or_convert_to_constant(v) for v in array] + ), + ], + ) - def array_contains(self, element: Expr | CONSTANT_TYPE) -> "ArrayContains": + @expose_as_static + def array_contains(self, element: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if an array contains a specific element or value. Example: @@ -399,11 +485,15 @@ def array_contains(self, element: Expr | CONSTANT_TYPE) -> "ArrayContains": Returns: A new `Expr` representing the 'array_contains' comparison. """ - return ArrayContains(self, self._cast_to_expr_or_convert_to_constant(element)) + return BooleanExpr( + "array_contains", [self, self._cast_to_expr_or_convert_to_constant(element)] + ) + @expose_as_static def array_contains_all( - self, elements: List[Expr | CONSTANT_TYPE] - ) -> "ArrayContainsAll": + self, + elements: Sequence[Expr | CONSTANT_TYPE], + ) -> "BooleanExpr": """Creates an expression that checks if an array contains all the specified elements. Example: @@ -418,13 +508,21 @@ def array_contains_all( Returns: A new `Expr` representing the 'array_contains_all' comparison. """ - return ArrayContainsAll( - self, [self._cast_to_expr_or_convert_to_constant(e) for e in elements] + return BooleanExpr( + "array_contains_all", + [ + self, + _ListOfExprs( + [self._cast_to_expr_or_convert_to_constant(e) for e in elements] + ), + ], ) + @expose_as_static def array_contains_any( - self, elements: List[Expr | CONSTANT_TYPE] - ) -> "ArrayContainsAny": + self, + elements: Sequence[Expr | CONSTANT_TYPE], + ) -> "BooleanExpr": """Creates an expression that checks if an array contains any of the specified elements. Example: @@ -440,11 +538,18 @@ def array_contains_any( Returns: A new `Expr` representing the 'array_contains_any' comparison. """ - return ArrayContainsAny( - self, [self._cast_to_expr_or_convert_to_constant(e) for e in elements] + return BooleanExpr( + "array_contains_any", + [ + self, + _ListOfExprs( + [self._cast_to_expr_or_convert_to_constant(e) for e in elements] + ), + ], ) - def array_length(self) -> "ArrayLength": + @expose_as_static + def array_length(self) -> "Expr": """Creates an expression that calculates the length of an array. Example: @@ -454,9 +559,10 @@ def array_length(self) -> "ArrayLength": Returns: A new `Expr` representing the length of the array. """ - return ArrayLength(self) + return Function("array_length", [self]) - def array_reverse(self) -> "ArrayReverse": + @expose_as_static + def array_reverse(self) -> "Expr": """Creates an expression that returns the reversed content of an array. Example: @@ -466,9 +572,10 @@ def array_reverse(self) -> "ArrayReverse": Returns: A new `Expr` representing the reversed array. """ - return ArrayReverse(self) + return Function("array_reverse", [self]) - def is_nan(self) -> "IsNaN": + @expose_as_static + def is_nan(self) -> "BooleanExpr": """Creates an expression that checks if this expression evaluates to 'NaN' (Not a Number). Example: @@ -478,9 +585,10 @@ def is_nan(self) -> "IsNaN": Returns: A new `Expr` representing the 'isNaN' check. """ - return IsNaN(self) + return BooleanExpr("is_nan", [self]) - def exists(self) -> "Exists": + @expose_as_static + def exists(self) -> "BooleanExpr": """Creates an expression that checks if a field exists in the document. Example: @@ -490,9 +598,10 @@ def exists(self) -> "Exists": Returns: A new `Expr` representing the 'exists' check. """ - return Exists(self) + return BooleanExpr("exists", [self]) - def sum(self) -> "Sum": + @expose_as_static + def sum(self) -> "Expr": """Creates an aggregation that calculates the sum of a numeric field across multiple stage inputs. Example: @@ -500,24 +609,25 @@ def sum(self) -> "Sum": >>> Field.of("orderAmount").sum().as_("totalRevenue") Returns: - A new `Accumulator` representing the 'sum' aggregation. + A new `AggregateFunction` representing the 'sum' aggregation. """ - return Sum(self) + return AggregateFunction("sum", [self]) - def avg(self) -> "Avg": + @expose_as_static + def average(self) -> "Expr": """Creates an aggregation that calculates the average (mean) of a numeric field across multiple stage inputs. Example: >>> # Calculate the average age of users - >>> Field.of("age").avg().as_("averageAge") + >>> Field.of("age").average().as_("averageAge") Returns: - A new `Accumulator` representing the 'avg' aggregation. + A new `AggregateFunction` representing the 'avg' aggregation. """ - return Avg(self) + return AggregateFunction("average", [self]) - def count(self) -> "Count": + def count(self) -> "Expr": """Creates an aggregation that counts the number of stage inputs with valid evaluations of the expression or field. @@ -526,35 +636,38 @@ def count(self) -> "Count": >>> Field.of("productId").count().as_("totalProducts") Returns: - A new `Accumulator` representing the 'count' aggregation. + A new `AggregateFunction` representing the 'count' aggregation. """ - return Count(self) + return AggregateFunction("count", [self]) - def min(self) -> "Min": + @expose_as_static + def minimum(self) -> "Expr": """Creates an aggregation that finds the minimum value of a field across multiple stage inputs. Example: >>> # Find the lowest price of all products - >>> Field.of("price").min().as_("lowestPrice") + >>> Field.of("price").minimum().as_("lowestPrice") Returns: - A new `Accumulator` representing the 'min' aggregation. + A new `AggregateFunction` representing the 'minimum' aggregation. """ - return Min(self) + return AggregateFunction("minimum", [self]) - def max(self) -> "Max": + @expose_as_static + def maximum(self) -> "Expr": """Creates an aggregation that finds the maximum value of a field across multiple stage inputs. Example: >>> # Find the highest score in a leaderboard - >>> Field.of("score").max().as_("highestScore") + >>> Field.of("score").maximum().as_("highestScore") Returns: - A new `Accumulator` representing the 'max' aggregation. + A new `AggregateFunction` representing the 'maximum' aggregation. """ - return Max(self) + return AggregateFunction("maximum", [self]) - def char_length(self) -> "CharLength": + @expose_as_static + def char_length(self) -> "Expr": """Creates an expression that calculates the character length of a string. Example: @@ -564,9 +677,10 @@ def char_length(self) -> "CharLength": Returns: A new `Expr` representing the length of the string. """ - return CharLength(self) + return Function("char_length", [self]) - def byte_length(self) -> "ByteLength": + @expose_as_static + def byte_length(self) -> "Expr": """Creates an expression that calculates the byte length of a string in its UTF-8 form. Example: @@ -576,9 +690,10 @@ def byte_length(self) -> "ByteLength": Returns: A new `Expr` representing the byte length of the string. """ - return ByteLength(self) + return Function("byte_length", [self]) - def like(self, pattern: Expr | str) -> "Like": + @expose_as_static + def like(self, pattern: Expr | str) -> "BooleanExpr": """Creates an expression that performs a case-sensitive string comparison. Example: @@ -593,9 +708,12 @@ def like(self, pattern: Expr | str) -> "Like": Returns: A new `Expr` representing the 'like' comparison. """ - return Like(self, self._cast_to_expr_or_convert_to_constant(pattern)) + return BooleanExpr( + "like", [self, self._cast_to_expr_or_convert_to_constant(pattern)] + ) - def regex_contains(self, regex: Expr | str) -> "RegexContains": + @expose_as_static + def regex_contains(self, regex: Expr | str) -> "BooleanExpr": """Creates an expression that checks if a string contains a specified regular expression as a substring. @@ -611,16 +729,19 @@ def regex_contains(self, regex: Expr | str) -> "RegexContains": Returns: A new `Expr` representing the 'contains' comparison. """ - return RegexContains(self, self._cast_to_expr_or_convert_to_constant(regex)) + return BooleanExpr( + "regex_contains", [self, self._cast_to_expr_or_convert_to_constant(regex)] + ) - def regex_matches(self, regex: Expr | str) -> "RegexMatch": + @expose_as_static + def regex_match(self, regex: Expr | str) -> "BooleanExpr": """Creates an expression that checks if a string matches a specified regular expression. Example: >>> # Check if the 'email' field matches a valid email pattern - >>> Field.of("email").regex_matches("[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Za-z]{2,}") + >>> Field.of("email").regex_match("[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Za-z]{2,}") >>> # Check if the 'email' field matches a regular expression stored in field 'regex' - >>> Field.of("email").regex_matches(Field.of("regex")) + >>> Field.of("email").regex_match(Field.of("regex")) Args: regex: The regular expression (string or expression) to use for the match. @@ -628,16 +749,19 @@ def regex_matches(self, regex: Expr | str) -> "RegexMatch": Returns: A new `Expr` representing the regular expression match. """ - return RegexMatch(self, self._cast_to_expr_or_convert_to_constant(regex)) + return BooleanExpr( + "regex_match", [self, self._cast_to_expr_or_convert_to_constant(regex)] + ) - def str_contains(self, substring: Expr | str) -> "StrContains": + @expose_as_static + def string_contains(self, substring: Expr | str) -> "BooleanExpr": """Creates an expression that checks if this string expression contains a specified substring. Example: >>> # Check if the 'description' field contains "example". - >>> Field.of("description").str_contains("example") + >>> Field.of("description").string_contains("example") >>> # Check if the 'description' field contains the value of the 'keyword' field. - >>> Field.of("description").str_contains(Field.of("keyword")) + >>> Field.of("description").string_contains(Field.of("keyword")) Args: substring: The substring (string or expression) to use for the search. @@ -645,9 +769,13 @@ def str_contains(self, substring: Expr | str) -> "StrContains": Returns: A new `Expr` representing the 'contains' comparison. """ - return StrContains(self, self._cast_to_expr_or_convert_to_constant(substring)) + return BooleanExpr( + "string_contains", + [self, self._cast_to_expr_or_convert_to_constant(substring)], + ) - def starts_with(self, prefix: Expr | str) -> "StartsWith": + @expose_as_static + def starts_with(self, prefix: Expr | str) -> "BooleanExpr": """Creates an expression that checks if a string starts with a given prefix. Example: @@ -662,9 +790,12 @@ def starts_with(self, prefix: Expr | str) -> "StartsWith": Returns: A new `Expr` representing the 'starts with' comparison. """ - return StartsWith(self, self._cast_to_expr_or_convert_to_constant(prefix)) + return BooleanExpr( + "starts_with", [self, self._cast_to_expr_or_convert_to_constant(prefix)] + ) - def ends_with(self, postfix: Expr | str) -> "EndsWith": + @expose_as_static + def ends_with(self, postfix: Expr | str) -> "BooleanExpr": """Creates an expression that checks if a string ends with a given postfix. Example: @@ -679,14 +810,17 @@ def ends_with(self, postfix: Expr | str) -> "EndsWith": Returns: A new `Expr` representing the 'ends with' comparison. """ - return EndsWith(self, self._cast_to_expr_or_convert_to_constant(postfix)) + return BooleanExpr( + "ends_with", [self, self._cast_to_expr_or_convert_to_constant(postfix)] + ) - def str_concat(self, *elements: Expr | CONSTANT_TYPE) -> "StrConcat": + @expose_as_static + def string_concat(self, *elements: Expr | CONSTANT_TYPE) -> "Expr": """Creates an expression that concatenates string expressions, fields or constants together. Example: >>> # Combine the 'firstName', " ", and 'lastName' fields into a single string - >>> Field.of("firstName").str_concat(" ", Field.of("lastName")) + >>> Field.of("firstName").string_concat(" ", Field.of("lastName")) Args: *elements: The expressions or constants (typically strings) to concatenate. @@ -694,16 +828,17 @@ def str_concat(self, *elements: Expr | CONSTANT_TYPE) -> "StrConcat": Returns: A new `Expr` representing the concatenated string. """ - return StrConcat( - self, *[self._cast_to_expr_or_convert_to_constant(el) for el in elements] + return Function( + "string_concat", + [self] + [self._cast_to_expr_or_convert_to_constant(el) for el in elements], ) - def map_get(self, key: str) -> "MapGet": - """Accesses a value from a map (object) field using the provided key. + @expose_as_static + def map_get(self, key: str | Constant[str]) -> "Expr": + """Accesses a value from the map produced by evaluating this expression. Example: - >>> # Get the 'city' value from - >>> # the 'address' map field + >>> Expr.map({"city": "London"}).map_get("city") >>> Field.of("address").map_get("city") Args: @@ -712,9 +847,12 @@ def map_get(self, key: str) -> "MapGet": Returns: A new `Expr` representing the value associated with the given key in the map. """ - return MapGet(self, Constant.of(key)) + return Function( + "map_get", [self, Constant.of(key) if isinstance(key, str) else key] + ) - def vector_length(self) -> "VectorLength": + @expose_as_static + def vector_length(self) -> "Expr": """Creates an expression that calculates the length (dimension) of a Firestore Vector. Example: @@ -724,9 +862,10 @@ def vector_length(self) -> "VectorLength": Returns: A new `Expr` representing the length of the vector. """ - return VectorLength(self) + return Function("vector_length", [self]) - def timestamp_to_unix_micros(self) -> "TimestampToUnixMicros": + @expose_as_static + def timestamp_to_unix_micros(self) -> "Expr": """Creates an expression that converts a timestamp to the number of microseconds since the epoch (1970-01-01 00:00:00 UTC). @@ -739,9 +878,10 @@ def timestamp_to_unix_micros(self) -> "TimestampToUnixMicros": Returns: A new `Expr` representing the number of microseconds since the epoch. """ - return TimestampToUnixMicros(self) + return Function("timestamp_to_unix_micros", [self]) - def unix_micros_to_timestamp(self) -> "UnixMicrosToTimestamp": + @expose_as_static + def unix_micros_to_timestamp(self) -> "Expr": """Creates an expression that converts a number of microseconds since the epoch (1970-01-01 00:00:00 UTC) to a timestamp. @@ -752,9 +892,10 @@ def unix_micros_to_timestamp(self) -> "UnixMicrosToTimestamp": Returns: A new `Expr` representing the timestamp. """ - return UnixMicrosToTimestamp(self) + return Function("unix_micros_to_timestamp", [self]) - def timestamp_to_unix_millis(self) -> "TimestampToUnixMillis": + @expose_as_static + def timestamp_to_unix_millis(self) -> "Expr": """Creates an expression that converts a timestamp to the number of milliseconds since the epoch (1970-01-01 00:00:00 UTC). @@ -767,9 +908,10 @@ def timestamp_to_unix_millis(self) -> "TimestampToUnixMillis": Returns: A new `Expr` representing the number of milliseconds since the epoch. """ - return TimestampToUnixMillis(self) + return Function("timestamp_to_unix_millis", [self]) - def unix_millis_to_timestamp(self) -> "UnixMillisToTimestamp": + @expose_as_static + def unix_millis_to_timestamp(self) -> "Expr": """Creates an expression that converts a number of milliseconds since the epoch (1970-01-01 00:00:00 UTC) to a timestamp. @@ -780,9 +922,10 @@ def unix_millis_to_timestamp(self) -> "UnixMillisToTimestamp": Returns: A new `Expr` representing the timestamp. """ - return UnixMillisToTimestamp(self) + return Function("unix_millis_to_timestamp", [self]) - def timestamp_to_unix_seconds(self) -> "TimestampToUnixSeconds": + @expose_as_static + def timestamp_to_unix_seconds(self) -> "Expr": """Creates an expression that converts a timestamp to the number of seconds since the epoch (1970-01-01 00:00:00 UTC). @@ -795,9 +938,10 @@ def timestamp_to_unix_seconds(self) -> "TimestampToUnixSeconds": Returns: A new `Expr` representing the number of seconds since the epoch. """ - return TimestampToUnixSeconds(self) + return Function("timestamp_to_unix_seconds", [self]) - def unix_seconds_to_timestamp(self) -> "UnixSecondsToTimestamp": + @expose_as_static + def unix_seconds_to_timestamp(self) -> "Expr": """Creates an expression that converts a number of seconds since the epoch (1970-01-01 00:00:00 UTC) to a timestamp. @@ -808,9 +952,10 @@ def unix_seconds_to_timestamp(self) -> "UnixSecondsToTimestamp": Returns: A new `Expr` representing the timestamp. """ - return UnixSecondsToTimestamp(self) + return Function("unix_seconds_to_timestamp", [self]) - def timestamp_add(self, unit: Expr | str, amount: Expr | float) -> "TimestampAdd": + @expose_as_static + def timestamp_add(self, unit: Expr | str, amount: Expr | float) -> "Expr": """Creates an expression that adds a specified amount of time to this timestamp expression. Example: @@ -827,20 +972,24 @@ def timestamp_add(self, unit: Expr | str, amount: Expr | float) -> "TimestampAdd Returns: A new `Expr` representing the resulting timestamp. """ - return TimestampAdd( - self, - self._cast_to_expr_or_convert_to_constant(unit), - self._cast_to_expr_or_convert_to_constant(amount), + return Function( + "timestamp_add", + [ + self, + self._cast_to_expr_or_convert_to_constant(unit), + self._cast_to_expr_or_convert_to_constant(amount), + ], ) - def timestamp_sub(self, unit: Expr | str, amount: Expr | float) -> "TimestampSub": + @expose_as_static + def timestamp_subtract(self, unit: Expr | str, amount: Expr | float) -> "Expr": """Creates an expression that subtracts a specified amount of time from this timestamp expression. Example: >>> # Subtract a duration specified by the 'unit' and 'amount' fields from the 'timestamp' field. - >>> Field.of("timestamp").timestamp_sub(Field.of("unit"), Field.of("amount")) + >>> Field.of("timestamp").timestamp_subtract(Field.of("unit"), Field.of("amount")) >>> # Subtract 2.5 hours from the 'timestamp' field. - >>> Field.of("timestamp").timestamp_sub("hour", 2.5) + >>> Field.of("timestamp").timestamp_subtract("hour", 2.5) Args: unit: The expression or string evaluating to the unit of time to subtract, must be one of @@ -850,18 +999,34 @@ def timestamp_sub(self, unit: Expr | str, amount: Expr | float) -> "TimestampSub Returns: A new `Expr` representing the resulting timestamp. """ - return TimestampSub( - self, - self._cast_to_expr_or_convert_to_constant(unit), - self._cast_to_expr_or_convert_to_constant(amount), + return Function( + "timestamp_subtract", + [ + self, + self._cast_to_expr_or_convert_to_constant(unit), + self._cast_to_expr_or_convert_to_constant(amount), + ], ) + @expose_as_static + def collection_id(self): + """Creates an expression that returns the collection ID from a path. + + Example: + >>> # Get the collection ID from a path. + >>> Field.of("__name__").collection_id() + + Returns: + A new `Expr` representing the collection ID. + """ + return Function("collection_id", [self]) + def ascending(self) -> Ordering: """Creates an `Ordering` that sorts documents in ascending order based on this expression. Example: >>> # Sort documents by the 'name' field in ascending order - >>> firestore.pipeline().collection("users").sort(Field.of("name").ascending()) + >>> client.pipeline().collection("users").sort(Field.of("name").ascending()) Returns: A new `Ordering` for ascending sorting. @@ -873,14 +1038,14 @@ def descending(self) -> Ordering: Example: >>> # Sort documents by the 'createdAt' field in descending order - >>> firestore.pipeline().collection("users").sort(Field.of("createdAt").descending()) + >>> client.pipeline().collection("users").sort(Field.of("createdAt").descending()) Returns: A new `Ordering` for descending sorting. """ return Ordering(self, Ordering.Direction.DESCENDING) - def as_(self, alias: str) -> "ExprWithAlias": + def as_(self, alias: str) -> "AliasedExpr": """Assigns an alias to this expression. Aliases are useful for renaming fields in the output of a stage or for giving meaningful @@ -888,7 +1053,7 @@ def as_(self, alias: str) -> "ExprWithAlias": Example: >>> # Calculate the total price and assign it the alias "totalPrice" and add it to the output. - >>> firestore.pipeline().collection("items").add_fields( + >>> client.pipeline().collection("items").add_fields( ... Field.of("price").multiply(Field.of("quantity")).as_("totalPrice") ... ) @@ -896,10 +1061,10 @@ def as_(self, alias: str) -> "ExprWithAlias": alias: The alias to assign to this expression. Returns: - A new `Selectable` (typically an `ExprWithAlias`) that wraps this + A new `Selectable` (typically an `AliasedExpr`) that wraps this expression and associates it with the provided alias. """ - return ExprWithAlias(self, alias) + return AliasedExpr(self, alias) class Constant(Expr, Generic[CONSTANT_TYPE]): @@ -922,24 +1087,27 @@ def of(value: CONSTANT_TYPE) -> Constant[CONSTANT_TYPE]: def __repr__(self): return f"Constant.of({self.value!r})" + def __hash__(self): + return hash(self.value) + def _to_pb(self) -> Value: return encode_value(self.value) -class ListOfExprs(Expr): +class _ListOfExprs(Expr): """Represents a list of expressions, typically used as an argument to functions like 'in' or array functions.""" - def __init__(self, exprs: List[Expr]): - self.exprs: list[Expr] = exprs + def __init__(self, exprs: Sequence[Expr]): + self.exprs: list[Expr] = list(exprs) def __eq__(self, other): - if not isinstance(other, ListOfExprs): + if not isinstance(other, _ListOfExprs): return False else: return other.exprs == self.exprs def __repr__(self): - return f"{self.__class__.__name__}({self.exprs})" + return repr(self.exprs) def _to_pb(self): return Value(array_value={"values": [e._to_pb() for e in self.exprs]}) @@ -948,9 +1116,34 @@ def _to_pb(self): class Function(Expr): """A base class for expressions that represent function calls.""" - def __init__(self, name: str, params: Sequence[Expr]): + def __init__( + self, + name: str, + params: Sequence[Expr], + *, + use_infix_repr: bool = True, + infix_name_override: str | None = None, + ): self.name = name self.params = list(params) + self._use_infix_repr = use_infix_repr + self._infix_name_override = infix_name_override + + def __repr__(self): + """ + Most Functions can be triggered infix. Eg: Field.of('age').greater_than(18). + + Display them this way in the repr string where possible + """ + if self._use_infix_repr: + infix_name = self._infix_name_override or self.name + if len(self.params) == 1: + return f"{self.params[0]!r}.{infix_name}()" + elif len(self.params) == 2: + return f"{self.params[0]!r}.{infix_name}({self.params[1]!r})" + else: + return f"{self.params[0]!r}.{infix_name}({', '.join([repr(p) for p in self.params[1:]])})" + return f"{self.__class__.__name__}({', '.join([repr(p) for p in self.params])})" def __eq__(self, other): if not isinstance(other, Function): @@ -958,9 +1151,6 @@ def __eq__(self, other): else: return other.name == self.name and other.params == self.params - def __repr__(self): - return f"{self.__class__.__name__}({', '.join([repr(p) for p in self.params])})" - def _to_pb(self): return Value( function_value={ @@ -969,1118 +1159,143 @@ def _to_pb(self): } ) - def add(left: Expr | str, right: Expr | float) -> "Add": - """Creates an expression that adds two expressions together. - - Example: - >>> Function.add("rating", 5) - >>> Function.add(Field.of("quantity"), Field.of("reserve")) - - Args: - left: The first expression or field path to add. - right: The second expression or constant value to add. - - Returns: - A new `Expr` representing the addition operation. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.add(left_expr, right) - - def subtract(left: Expr | str, right: Expr | float) -> "Subtract": - """Creates an expression that subtracts another expression or constant from this expression. - - Example: - >>> Function.subtract("total", 20) - >>> Function.subtract(Field.of("price"), Field.of("discount")) - - Args: - left: The expression or field path to subtract from. - right: The expression or constant value to subtract. - Returns: - A new `Expr` representing the subtraction operation. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.subtract(left_expr, right) +class AggregateFunction(Function): + """A base class for aggregation functions that operate across multiple inputs.""" - def multiply(left: Expr | str, right: Expr | float) -> "Multiply": - """Creates an expression that multiplies this expression by another expression or constant. + def as_(self, alias: str) -> "AliasedAggregate": + """Assigns an alias to this expression. - Example: - >>> Function.multiply("value", 2) - >>> Function.multiply(Field.of("quantity"), Field.of("price")) + Aliases are useful for renaming fields in the output of a stage or for giving meaningful + names to calculated values. Args: - left: The expression or field path to multiply. - right: The expression or constant value to multiply by. + alias: The alias to assign to this expression. - Returns: - A new `Expr` representing the multiplication operation. + Returns: A new AliasedAggregate that wraps this expression and associates it with the + provided alias. """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.multiply(left_expr, right) + return AliasedAggregate(self, alias) - def divide(left: Expr | str, right: Expr | float) -> "Divide": - """Creates an expression that divides this expression by another expression or constant. - Example: - >>> Function.divide("value", 10) - >>> Function.divide(Field.of("total"), Field.of("count")) +class Selectable(Expr): + """Base class for expressions that can be selected or aliased in projection stages.""" - Args: - left: The expression or field path to be divided. - right: The expression or constant value to divide by. + def __eq__(self, other): + if not isinstance(other, type(self)): + return False + else: + return other._to_map() == self._to_map() - Returns: - A new `Expr` representing the division operation. + @abstractmethod + def _to_map(self) -> tuple[str, Value]: """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.divide(left_expr, right) - - def mod(left: Expr | str, right: Expr | float) -> "Mod": - """Creates an expression that calculates the modulo (remainder) to another expression or constant. - - Example: - >>> Function.mod("value", 5) - >>> Function.mod(Field.of("value"), Field.of("divisor")) - - Args: - left: The dividend expression or field path. - right: The divisor expression or constant. - - Returns: - A new `Expr` representing the modulo operation. + Returns a str: Value representation of the Selectable """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.mod(left_expr, right) - - def logical_max(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "LogicalMax": - """Creates an expression that returns the larger value between this expression - and another expression or constant, based on Firestore's value type ordering. - - Firestore's value type ordering is described here: - https://cloud.google.com/firestore/docs/concepts/data-types#value_type_ordering - - Example: - >>> Function.logical_max("value", 10) - >>> Function.logical_max(Field.of("discount"), Field.of("cap")) - - Args: - left: The expression or field path to compare. - right: The other expression or constant value to compare with. + raise NotImplementedError - Returns: - A new `Expr` representing the logical max operation. + @classmethod + def _value_from_selectables(cls, *selectables: Selectable) -> Value: """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.logical_max(left_expr, right) - - def logical_min(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "LogicalMin": - """Creates an expression that returns the smaller value between this expression - and another expression or constant, based on Firestore's value type ordering. - - Firestore's value type ordering is described here: - https://cloud.google.com/firestore/docs/concepts/data-types#value_type_ordering - - Example: - >>> Function.logical_min("value", 10) - >>> Function.logical_min(Field.of("discount"), Field.of("floor")) - - Args: - left: The expression or field path to compare. - right: The other expression or constant value to compare with. - - Returns: - A new `Expr` representing the logical min operation. + Returns a Value representing a map of Selectables """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.logical_min(left_expr, right) - - def eq(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Eq": - """Creates an expression that checks if this expression is equal to another - expression or constant value. - - Example: - >>> Function.eq("city", "London") - >>> Function.eq(Field.of("age"), 21) - - Args: - left: The expression or field path to compare. - right: The expression or constant value to compare for equality. + return Value( + map_value={ + "fields": {m[0]: m[1] for m in [s._to_map() for s in selectables]} + } + ) - Returns: - A new `Expr` representing the equality comparison. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.eq(left_expr, right) + @staticmethod + def _to_value(field_list: Sequence[Selectable]) -> Value: + return Value( + map_value={ + "fields": {m[0]: m[1] for m in [f._to_map() for f in field_list]} + } + ) - def neq(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Neq": - """Creates an expression that checks if this expression is not equal to another - expression or constant value. - Example: - >>> Function.neq("country", "USA") - >>> Function.neq(Field.of("status"), "completed") +T = TypeVar("T", bound=Expr) - Args: - left: The expression or field path to compare. - right: The expression or constant value to compare for inequality. - Returns: - A new `Expr` representing the inequality comparison. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.neq(left_expr, right) +class AliasedExpr(Selectable, Generic[T]): + """Wraps an expression with an alias.""" - def gt(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Gt": - """Creates an expression that checks if this expression is greater than another - expression or constant value. + def __init__(self, expr: T, alias: str): + self.expr = expr + self.alias = alias - Example: - >>> Function.gt("price", 100) - >>> Function.gt(Field.of("age"), Field.of("limit")) + def _to_map(self): + return self.alias, self.expr._to_pb() - Args: - left: The expression or field path to compare. - right: The expression or constant value to compare for greater than. + def __repr__(self): + return f"{self.expr}.as_('{self.alias}')" - Returns: - A new `Expr` representing the greater than comparison. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.gt(left_expr, right) + def _to_pb(self): + return Value(map_value={"fields": {self.alias: self.expr._to_pb()}}) - def gte(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Gte": - """Creates an expression that checks if this expression is greater than or equal - to another expression or constant value. - Example: - >>> Function.gte("score", 80) - >>> Function.gte(Field.of("quantity"), Field.of('requirement').add(1)) +class AliasedAggregate: + """Wraps an aggregate with an alias""" - Args: - left: The expression or field path to compare. - right: The expression or constant value to compare for greater than or equal to. + def __init__(self, expr: AggregateFunction, alias: str): + self.expr = expr + self.alias = alias - Returns: - A new `Expr` representing the greater than or equal to comparison. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.gte(left_expr, right) + def _to_map(self): + return self.alias, self.expr._to_pb() - def lt(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Lt": - """Creates an expression that checks if this expression is less than another - expression or constant value. + def __repr__(self): + return f"{self.expr}.as_('{self.alias}')" - Example: - >>> Function.lt("price", 50) - >>> Function.lt(Field.of("age"), Field.of('limit')) + def _to_pb(self): + return Value(map_value={"fields": {self.alias: self.expr._to_pb()}}) - Args: - left: The expression or field path to compare. - right: The expression or constant value to compare for less than. - Returns: - A new `Expr` representing the less than comparison. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.lt(left_expr, right) +class Field(Selectable): + """Represents a reference to a field within a document.""" - def lte(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Lte": - """Creates an expression that checks if this expression is less than or equal to - another expression or constant value. + DOCUMENT_ID = "__name__" - Example: - >>> Function.lte("score", 70) - >>> Function.lte(Field.of("quantity"), Constant.of(20)) + def __init__(self, path: str): + """Initializes a Field reference. Args: - left: The expression or field path to compare. - right: The expression or constant value to compare for less than or equal to. - - Returns: - A new `Expr` representing the less than or equal to comparison. + path: The dot-separated path to the field (e.g., "address.city"). + Use Field.DOCUMENT_ID for the document ID. """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.lte(left_expr, right) - - def in_any(left: Expr | str, array: List[Expr | CONSTANT_TYPE]) -> "In": - """Creates an expression that checks if this expression is equal to any of the - provided values or expressions. + self.path = path - Example: - >>> Function.in_any("category", ["Electronics", "Apparel"]) - >>> Function.in_any(Field.of("category"), ["Electronics", Field.of("primaryType")]) + @staticmethod + def of(path: str): + """Creates a Field reference. Args: - left: The expression or field path to compare. - array: The values or expressions to check against. + path: The dot-separated path to the field (e.g., "address.city"). + Use Field.DOCUMENT_ID for the document ID. Returns: - A new `Expr` representing the 'IN' comparison. + A new Field instance. """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.in_any(left_expr, array) - - def not_in_any(left: Expr | str, array: List[Expr | CONSTANT_TYPE]) -> "Not": - """Creates an expression that checks if this expression is not equal to any of the - provided values or expressions. + return Field(path) - Example: - >>> Function.not_in_any("status", ["pending", "cancelled"]) + def _to_map(self): + return self.path, self._to_pb() - Args: - left: The expression or field path to compare. - array: The values or expressions to check against. + def __repr__(self): + return f"Field.of({self.path!r})" - Returns: - A new `Expr` representing the 'NOT IN' comparison. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.not_in_any(left_expr, array) + def _to_pb(self): + return Value(field_reference_value=self.path) - def array_contains( - array: Expr | str, element: Expr | CONSTANT_TYPE - ) -> "ArrayContains": - """Creates an expression that checks if an array contains a specific element or value. - Example: - >>> Function.array_contains("colors", "red") - >>> Function.array_contains(Field.of("sizes"), Field.of("selectedSize")) - - Args: - array: The array expression or field path to check. - element: The element (expression or constant) to search for in the array. - - Returns: - A new `Expr` representing the 'array_contains' comparison. - """ - array_expr = Field.of(array) if isinstance(array, str) else array - return Expr.array_contains(array_expr, element) - - def array_contains_all( - array: Expr | str, elements: List[Expr | CONSTANT_TYPE] - ) -> "ArrayContainsAll": - """Creates an expression that checks if an array contains all the specified elements. - - Example: - >>> Function.array_contains_all("tags", ["news", "sports"]) - >>> Function.array_contains_all(Field.of("tags"), [Field.of("tag1"), "tag2"]) - - Args: - array: The array expression or field path to check. - elements: The list of elements (expressions or constants) to check for in the array. - - Returns: - A new `Expr` representing the 'array_contains_all' comparison. - """ - array_expr = Field.of(array) if isinstance(array, str) else array - return Expr.array_contains_all(array_expr, elements) - - def array_contains_any( - array: Expr | str, elements: List[Expr | CONSTANT_TYPE] - ) -> "ArrayContainsAny": - """Creates an expression that checks if an array contains any of the specified elements. - - Example: - >>> Function.array_contains_any("groups", ["admin", "editor"]) - >>> Function.array_contains_any(Field.of("categories"), [Field.of("cate1"), Field.of("cate2")]) - - Args: - array: The array expression or field path to check. - elements: The list of elements (expressions or constants) to check for in the array. - - Returns: - A new `Expr` representing the 'array_contains_any' comparison. - """ - array_expr = Field.of(array) if isinstance(array, str) else array - return Expr.array_contains_any(array_expr, elements) - - def array_length(array: Expr | str) -> "ArrayLength": - """Creates an expression that calculates the length of an array. - - Example: - >>> Function.array_length("cart") - - Returns: - A new `Expr` representing the length of the array. - """ - array_expr = Field.of(array) if isinstance(array, str) else array - return Expr.array_length(array_expr) - - def array_reverse(array: Expr | str) -> "ArrayReverse": - """Creates an expression that returns the reversed content of an array. - - Example: - >>> Function.array_reverse("preferences") - - Returns: - A new `Expr` representing the reversed array. - """ - array_expr = Field.of(array) if isinstance(array, str) else array - return Expr.array_reverse(array_expr) - - def is_nan(expr: Expr | str) -> "IsNaN": - """Creates an expression that checks if this expression evaluates to 'NaN' (Not a Number). - - Example: - >>> Function.is_nan("measurement") - - Returns: - A new `Expr` representing the 'isNaN' check. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.is_nan(expr_val) - - def exists(expr: Expr | str) -> "Exists": - """Creates an expression that checks if a field exists in the document. - - Example: - >>> Function.exists("phoneNumber") - - Returns: - A new `Expr` representing the 'exists' check. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.exists(expr_val) - - def sum(expr: Expr | str) -> "Sum": - """Creates an aggregation that calculates the sum of a numeric field across multiple stage inputs. - - Example: - >>> Function.sum("orderAmount") - - Returns: - A new `Accumulator` representing the 'sum' aggregation. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.sum(expr_val) - - def avg(expr: Expr | str) -> "Avg": - """Creates an aggregation that calculates the average (mean) of a numeric field across multiple - stage inputs. - - Example: - >>> Function.avg("age") - - Returns: - A new `Accumulator` representing the 'avg' aggregation. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.avg(expr_val) - - def count(expr: Expr | str | None = None) -> "Count": - """Creates an aggregation that counts the number of stage inputs with valid evaluations of the - expression or field. If no expression is provided, it counts all inputs. - - Example: - >>> Function.count("productId") - >>> Function.count() - - Returns: - A new `Accumulator` representing the 'count' aggregation. - """ - if expr is None: - return Count() - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.count(expr_val) - - def min(expr: Expr | str) -> "Min": - """Creates an aggregation that finds the minimum value of a field across multiple stage inputs. - - Example: - >>> Function.min("price") - - Returns: - A new `Accumulator` representing the 'min' aggregation. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.min(expr_val) - - def max(expr: Expr | str) -> "Max": - """Creates an aggregation that finds the maximum value of a field across multiple stage inputs. - - Example: - >>> Function.max("score") - - Returns: - A new `Accumulator` representing the 'max' aggregation. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.max(expr_val) - - def char_length(expr: Expr | str) -> "CharLength": - """Creates an expression that calculates the character length of a string. - - Example: - >>> Function.char_length("name") - - Returns: - A new `Expr` representing the length of the string. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.char_length(expr_val) - - def byte_length(expr: Expr | str) -> "ByteLength": - """Creates an expression that calculates the byte length of a string in its UTF-8 form. - - Example: - >>> Function.byte_length("name") - - Returns: - A new `Expr` representing the byte length of the string. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.byte_length(expr_val) - - def like(expr: Expr | str, pattern: Expr | str) -> "Like": - """Creates an expression that performs a case-sensitive string comparison. - - Example: - >>> Function.like("title", "%guide%") - >>> Function.like(Field.of("title"), Field.of("pattern")) - - Args: - expr: The expression or field path to perform the comparison on. - pattern: The pattern (string or expression) to search for. You can use "%" as a wildcard character. - - Returns: - A new `Expr` representing the 'like' comparison. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.like(expr_val, pattern) - - def regex_contains(expr: Expr | str, regex: Expr | str) -> "RegexContains": - """Creates an expression that checks if a string contains a specified regular expression as a - substring. - - Example: - >>> Function.regex_contains("description", "(?i)example") - >>> Function.regex_contains(Field.of("description"), Field.of("regex")) - - Args: - expr: The expression or field path to perform the comparison on. - regex: The regular expression (string or expression) to use for the search. - - Returns: - A new `Expr` representing the 'contains' comparison. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.regex_contains(expr_val, regex) - - def regex_matches(expr: Expr | str, regex: Expr | str) -> "RegexMatch": - """Creates an expression that checks if a string matches a specified regular expression. - - Example: - >>> # Check if the 'email' field matches a valid email pattern - >>> Function.regex_matches("email", "[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Za-z]{2,}") - >>> Function.regex_matches(Field.of("email"), Field.of("regex")) - - Args: - expr: The expression or field path to match against. - regex: The regular expression (string or expression) to use for the match. - - Returns: - A new `Expr` representing the regular expression match. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.regex_matches(expr_val, regex) - - def str_contains(expr: Expr | str, substring: Expr | str) -> "StrContains": - """Creates an expression that checks if this string expression contains a specified substring. - - Example: - >>> Function.str_contains("description", "example") - >>> Function.str_contains(Field.of("description"), Field.of("keyword")) - - Args: - expr: The expression or field path to perform the comparison on. - substring: The substring (string or expression) to use for the search. - - Returns: - A new `Expr` representing the 'contains' comparison. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.str_contains(expr_val, substring) - - def starts_with(expr: Expr | str, prefix: Expr | str) -> "StartsWith": - """Creates an expression that checks if a string starts with a given prefix. - - Example: - >>> Function.starts_with("name", "Mr.") - >>> Function.starts_with(Field.of("fullName"), Field.of("firstName")) - - Args: - expr: The expression or field path to check. - prefix: The prefix (string or expression) to check for. - - Returns: - A new `Expr` representing the 'starts with' comparison. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.starts_with(expr_val, prefix) - - def ends_with(expr: Expr | str, postfix: Expr | str) -> "EndsWith": - """Creates an expression that checks if a string ends with a given postfix. - - Example: - >>> Function.ends_with("filename", ".txt") - >>> Function.ends_with(Field.of("url"), Field.of("extension")) - - Args: - expr: The expression or field path to check. - postfix: The postfix (string or expression) to check for. - - Returns: - A new `Expr` representing the 'ends with' comparison. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.ends_with(expr_val, postfix) - - def str_concat(first: Expr | str, *elements: Expr | CONSTANT_TYPE) -> "StrConcat": - """Creates an expression that concatenates string expressions, fields or constants together. - - Example: - >>> Function.str_concat("firstName", " ", Field.of("lastName")) - - Args: - first: The first expression or field path to concatenate. - *elements: The expressions or constants (typically strings) to concatenate. - - Returns: - A new `Expr` representing the concatenated string. - """ - first_expr = Field.of(first) if isinstance(first, str) else first - return Expr.str_concat(first_expr, *elements) - - def map_get(map_expr: Expr | str, key: str) -> "MapGet": - """Accesses a value from a map (object) field using the provided key. - - Example: - >>> Function.map_get("address", "city") - - Args: - map_expr: The expression or field path of the map. - key: The key to access in the map. - - Returns: - A new `Expr` representing the value associated with the given key in the map. - """ - map_val = Field.of(map_expr) if isinstance(map_expr, str) else map_expr - return Expr.map_get(map_val, key) - - def vector_length(vector_expr: Expr | str) -> "VectorLength": - """Creates an expression that calculates the length (dimension) of a Firestore Vector. - - Example: - >>> Function.vector_length("embedding") - - Returns: - A new `Expr` representing the length of the vector. - """ - vector_val = ( - Field.of(vector_expr) if isinstance(vector_expr, str) else vector_expr - ) - return Expr.vector_length(vector_val) - - def timestamp_to_unix_micros(timestamp_expr: Expr | str) -> "TimestampToUnixMicros": - """Creates an expression that converts a timestamp to the number of microseconds since the epoch - (1970-01-01 00:00:00 UTC). - - Truncates higher levels of precision by rounding down to the beginning of the microsecond. - - Example: - >>> Function.timestamp_to_unix_micros("timestamp") - - Returns: - A new `Expr` representing the number of microseconds since the epoch. - """ - timestamp_val = ( - Field.of(timestamp_expr) - if isinstance(timestamp_expr, str) - else timestamp_expr - ) - return Expr.timestamp_to_unix_micros(timestamp_val) - - def unix_micros_to_timestamp(micros_expr: Expr | str) -> "UnixMicrosToTimestamp": - """Creates an expression that converts a number of microseconds since the epoch (1970-01-01 - 00:00:00 UTC) to a timestamp. - - Example: - >>> Function.unix_micros_to_timestamp("microseconds") - - Returns: - A new `Expr` representing the timestamp. - """ - micros_val = ( - Field.of(micros_expr) if isinstance(micros_expr, str) else micros_expr - ) - return Expr.unix_micros_to_timestamp(micros_val) - - def timestamp_to_unix_millis(timestamp_expr: Expr | str) -> "TimestampToUnixMillis": - """Creates an expression that converts a timestamp to the number of milliseconds since the epoch - (1970-01-01 00:00:00 UTC). - - Truncates higher levels of precision by rounding down to the beginning of the millisecond. - - Example: - >>> Function.timestamp_to_unix_millis("timestamp") - - Returns: - A new `Expr` representing the number of milliseconds since the epoch. - """ - timestamp_val = ( - Field.of(timestamp_expr) - if isinstance(timestamp_expr, str) - else timestamp_expr - ) - return Expr.timestamp_to_unix_millis(timestamp_val) - - def unix_millis_to_timestamp(millis_expr: Expr | str) -> "UnixMillisToTimestamp": - """Creates an expression that converts a number of milliseconds since the epoch (1970-01-01 - 00:00:00 UTC) to a timestamp. - - Example: - >>> Function.unix_millis_to_timestamp("milliseconds") - - Returns: - A new `Expr` representing the timestamp. - """ - millis_val = ( - Field.of(millis_expr) if isinstance(millis_expr, str) else millis_expr - ) - return Expr.unix_millis_to_timestamp(millis_val) - - def timestamp_to_unix_seconds( - timestamp_expr: Expr | str, - ) -> "TimestampToUnixSeconds": - """Creates an expression that converts a timestamp to the number of seconds since the epoch - (1970-01-01 00:00:00 UTC). - - Truncates higher levels of precision by rounding down to the beginning of the second. - - Example: - >>> Function.timestamp_to_unix_seconds("timestamp") - - Returns: - A new `Expr` representing the number of seconds since the epoch. - """ - timestamp_val = ( - Field.of(timestamp_expr) - if isinstance(timestamp_expr, str) - else timestamp_expr - ) - return Expr.timestamp_to_unix_seconds(timestamp_val) - - def unix_seconds_to_timestamp(seconds_expr: Expr | str) -> "UnixSecondsToTimestamp": - """Creates an expression that converts a number of seconds since the epoch (1970-01-01 00:00:00 - UTC) to a timestamp. - - Example: - >>> Function.unix_seconds_to_timestamp("seconds") - - Returns: - A new `Expr` representing the timestamp. - """ - seconds_val = ( - Field.of(seconds_expr) if isinstance(seconds_expr, str) else seconds_expr - ) - return Expr.unix_seconds_to_timestamp(seconds_val) - - def timestamp_add( - timestamp: Expr | str, unit: Expr | str, amount: Expr | float - ) -> "TimestampAdd": - """Creates an expression that adds a specified amount of time to this timestamp expression. - - Example: - >>> Function.timestamp_add("timestamp", "day", 1.5) - >>> Function.timestamp_add(Field.of("timestamp"), Field.of("unit"), Field.of("amount")) - - Args: - timestamp: The expression or field path of the timestamp. - unit: The expression or string evaluating to the unit of time to add, must be one of - 'microsecond', 'millisecond', 'second', 'minute', 'hour', 'day'. - amount: The expression or float representing the amount of time to add. - - Returns: - A new `Expr` representing the resulting timestamp. - """ - timestamp_expr = ( - Field.of(timestamp) if isinstance(timestamp, str) else timestamp - ) - return Expr.timestamp_add(timestamp_expr, unit, amount) - - def timestamp_sub( - timestamp: Expr | str, unit: Expr | str, amount: Expr | float - ) -> "TimestampSub": - """Creates an expression that subtracts a specified amount of time from this timestamp expression. - - Example: - >>> Function.timestamp_sub("timestamp", "hour", 2.5) - >>> Function.timestamp_sub(Field.of("timestamp"), Field.of("unit"), Field.of("amount")) - - Args: - timestamp: The expression or field path of the timestamp. - unit: The expression or string evaluating to the unit of time to subtract, must be one of - 'microsecond', 'millisecond', 'second', 'minute', 'hour', 'day'. - amount: The expression or float representing the amount of time to subtract. - - Returns: - A new `Expr` representing the resulting timestamp. - """ - timestamp_expr = ( - Field.of(timestamp) if isinstance(timestamp, str) else timestamp - ) - return Expr.timestamp_sub(timestamp_expr, unit, amount) - - -class Divide(Function): - """Represents the division function.""" - - def __init__(self, left: Expr, right: Expr): - super().__init__("divide", [left, right]) - - -class LogicalMax(Function): - """Represents the logical maximum function based on Firestore type ordering.""" - - def __init__(self, left: Expr, right: Expr): - super().__init__("logical_maximum", [left, right]) - - -class LogicalMin(Function): - """Represents the logical minimum function based on Firestore type ordering.""" - - def __init__(self, left: Expr, right: Expr): - super().__init__("logical_minimum", [left, right]) - - -class MapGet(Function): - """Represents accessing a value within a map by key.""" - - def __init__(self, map_: Expr, key: Constant[str]): - super().__init__("map_get", [map_, key]) - - -class Mod(Function): - """Represents the modulo function.""" - - def __init__(self, left: Expr, right: Expr): - super().__init__("mod", [left, right]) - - -class Multiply(Function): - """Represents the multiplication function.""" - - def __init__(self, left: Expr, right: Expr): - super().__init__("multiply", [left, right]) - - -class Parent(Function): - """Represents getting the parent document reference.""" - - def __init__(self, value: Expr): - super().__init__("parent", [value]) - - -class StrConcat(Function): - """Represents concatenating multiple strings.""" - - def __init__(self, *exprs: Expr): - super().__init__("str_concat", exprs) - - -class Subtract(Function): - """Represents the subtraction function.""" - - def __init__(self, left: Expr, right: Expr): - super().__init__("subtract", [left, right]) - - -class TimestampAdd(Function): - """Represents adding a duration to a timestamp.""" - - def __init__(self, timestamp: Expr, unit: Expr, amount: Expr): - super().__init__("timestamp_add", [timestamp, unit, amount]) - - -class TimestampSub(Function): - """Represents subtracting a duration from a timestamp.""" - - def __init__(self, timestamp: Expr, unit: Expr, amount: Expr): - super().__init__("timestamp_sub", [timestamp, unit, amount]) - - -class TimestampToUnixMicros(Function): - """Represents converting a timestamp to microseconds since epoch.""" - - def __init__(self, input: Expr): - super().__init__("timestamp_to_unix_micros", [input]) - - -class TimestampToUnixMillis(Function): - """Represents converting a timestamp to milliseconds since epoch.""" - - def __init__(self, input: Expr): - super().__init__("timestamp_to_unix_millis", [input]) - - -class TimestampToUnixSeconds(Function): - """Represents converting a timestamp to seconds since epoch.""" - - def __init__(self, input: Expr): - super().__init__("timestamp_to_unix_seconds", [input]) - - -class UnixMicrosToTimestamp(Function): - """Represents converting microseconds since epoch to a timestamp.""" - - def __init__(self, input: Expr): - super().__init__("unix_micros_to_timestamp", [input]) - - -class UnixMillisToTimestamp(Function): - """Represents converting milliseconds since epoch to a timestamp.""" - - def __init__(self, input: Expr): - super().__init__("unix_millis_to_timestamp", [input]) - - -class UnixSecondsToTimestamp(Function): - """Represents converting seconds since epoch to a timestamp.""" - - def __init__(self, input: Expr): - super().__init__("unix_seconds_to_timestamp", [input]) - - -class VectorLength(Function): - """Represents getting the length (dimension) of a vector.""" - - def __init__(self, array: Expr): - super().__init__("vector_length", [array]) - - -class Add(Function): - """Represents the addition function.""" - - def __init__(self, left: Expr, right: Expr): - super().__init__("add", [left, right]) - - -class ArrayElement(Function): - """Represents accessing an element within an array""" - - def __init__(self): - super().__init__("array_element", []) - - -class ArrayFilter(Function): - """Represents filtering elements from an array based on a condition.""" - - def __init__(self, array: Expr, filter: "FilterCondition"): - super().__init__("array_filter", [array, filter]) - - -class ArrayLength(Function): - """Represents getting the length of an array.""" - - def __init__(self, array: Expr): - super().__init__("array_length", [array]) - - -class ArrayReverse(Function): - """Represents reversing the elements of an array.""" - - def __init__(self, array: Expr): - super().__init__("array_reverse", [array]) - - -class ArrayTransform(Function): - """Represents applying a transformation function to each element of an array.""" - - def __init__(self, array: Expr, transform: Function): - super().__init__("array_transform", [array, transform]) - - -class ByteLength(Function): - """Represents getting the byte length of a string (UTF-8).""" - - def __init__(self, expr: Expr): - super().__init__("byte_length", [expr]) - - -class CharLength(Function): - """Represents getting the character length of a string.""" - - def __init__(self, expr: Expr): - super().__init__("char_length", [expr]) - - -class CollectionId(Function): - """Represents getting the collection ID from a document reference.""" - - def __init__(self, value: Expr): - super().__init__("collection_id", [value]) - - -class Accumulator(Function): - """A base class for aggregation functions that operate across multiple inputs.""" - - -class Max(Accumulator): - """Represents the maximum aggregation function.""" - - def __init__(self, value: Expr): - super().__init__("maximum", [value]) - - -class Min(Accumulator): - """Represents the minimum aggregation function.""" - - def __init__(self, value: Expr): - super().__init__("minimum", [value]) - - -class Sum(Accumulator): - """Represents the sum aggregation function.""" - - def __init__(self, value: Expr): - super().__init__("sum", [value]) - - -class Avg(Accumulator): - """Represents the average aggregation function.""" - - def __init__(self, value: Expr): - super().__init__("avg", [value]) - - -class Count(Accumulator): - """Represents an aggregation that counts the total number of inputs.""" - - def __init__(self, value: Expr | None = None): - super().__init__("count", [value] if value else []) - - -class Selectable(Expr): - """Base class for expressions that can be selected or aliased in projection stages.""" - - def __eq__(self, other): - if not isinstance(other, type(self)): - return False - else: - return other._to_map() == self._to_map() - - @abstractmethod - def _to_map(self) -> tuple[str, Value]: - """ - Returns a str: Value representation of the Selectable - """ - raise NotImplementedError - - @classmethod - def _value_from_selectables(cls, *selectables: Selectable) -> Value: - """ - Returns a Value representing a map of Selectables - """ - return Value( - map_value={ - "fields": {m[0]: m[1] for m in [s._to_map() for s in selectables]} - } - ) - - -T = TypeVar("T", bound=Expr) - - -class ExprWithAlias(Selectable, Generic[T]): - """Wraps an expression with an alias.""" - - def __init__(self, expr: T, alias: str): - self.expr = expr - self.alias = alias - - def _to_map(self): - return self.alias, self.expr._to_pb() - - def __repr__(self): - return f"{self.expr}.as_('{self.alias}')" - - def _to_pb(self): - return Value(map_value={"fields": {self.alias: self.expr._to_pb()}}) - - -class Field(Selectable): - """Represents a reference to a field within a document.""" - - DOCUMENT_ID = "__name__" - - def __init__(self, path: str): - """Initializes a Field reference. - - Args: - path: The dot-separated path to the field (e.g., "address.city"). - Use Field.DOCUMENT_ID for the document ID. - """ - self.path = path - - @staticmethod - def of(path: str): - """Creates a Field reference. - - Args: - path: The dot-separated path to the field (e.g., "address.city"). - Use Field.DOCUMENT_ID for the document ID. - - Returns: - A new Field instance. - """ - return Field(path) - - def _to_map(self): - return self.path, self._to_pb() - - def __repr__(self): - return f"Field.of({self.path!r})" - - def _to_pb(self): - return Value(field_reference_value=self.path) - - -class FilterCondition(Function): - """Filters the given data in some way.""" - - def __init__( - self, - *args, - use_infix_repr: bool = True, - infix_name_override: str | None = None, - **kwargs, - ): - self._use_infix_repr = use_infix_repr - self._infix_name_override = infix_name_override - super().__init__(*args, **kwargs) - - def __repr__(self): - """ - Most FilterConditions can be triggered infix. Eg: Field.of('age').gte(18). - - Display them this way in the repr string where possible - """ - if self._use_infix_repr: - infix_name = self._infix_name_override or self.name - if len(self.params) == 1: - return f"{self.params[0]!r}.{infix_name}()" - elif len(self.params) == 2: - return f"{self.params[0]!r}.{infix_name}({self.params[1]!r})" - return super().__repr__() +class BooleanExpr(Function): + """Filters the given data in some way.""" @staticmethod def _from_query_filter_pb(filter_pb, client): if isinstance(filter_pb, Query_pb.CompositeFilter): sub_filters = [ - FilterCondition._from_query_filter_pb(f, client) - for f in filter_pb.filters + BooleanExpr._from_query_filter_pb(f, client) for f in filter_pb.filters ] if filter_pb.op == Query_pb.CompositeFilter.Operator.OR: return Or(*sub_filters) @@ -2097,34 +1312,34 @@ def _from_query_filter_pb(filter_pb, client): elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NOT_NAN: return And(field.exists(), Not(field.is_nan())) elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NULL: - return And(field.exists(), field.eq(None)) + return And(field.exists(), field.equal(None)) elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NOT_NULL: - return And(field.exists(), Not(field.eq(None))) + return And(field.exists(), Not(field.equal(None))) else: raise TypeError(f"Unexpected UnaryFilter operator type: {filter_pb.op}") elif isinstance(filter_pb, Query_pb.FieldFilter): field = Field.of(filter_pb.field.field_path) value = decode_value(filter_pb.value, client) if filter_pb.op == Query_pb.FieldFilter.Operator.LESS_THAN: - return And(field.exists(), field.lt(value)) + return And(field.exists(), field.less_than(value)) elif filter_pb.op == Query_pb.FieldFilter.Operator.LESS_THAN_OR_EQUAL: - return And(field.exists(), field.lte(value)) + return And(field.exists(), field.less_than_or_equal(value)) elif filter_pb.op == Query_pb.FieldFilter.Operator.GREATER_THAN: - return And(field.exists(), field.gt(value)) + return And(field.exists(), field.greater_than(value)) elif filter_pb.op == Query_pb.FieldFilter.Operator.GREATER_THAN_OR_EQUAL: - return And(field.exists(), field.gte(value)) + return And(field.exists(), field.greater_than_or_equal(value)) elif filter_pb.op == Query_pb.FieldFilter.Operator.EQUAL: - return And(field.exists(), field.eq(value)) + return And(field.exists(), field.equal(value)) elif filter_pb.op == Query_pb.FieldFilter.Operator.NOT_EQUAL: - return And(field.exists(), field.neq(value)) + return And(field.exists(), field.not_equal(value)) if filter_pb.op == Query_pb.FieldFilter.Operator.ARRAY_CONTAINS: return And(field.exists(), field.array_contains(value)) elif filter_pb.op == Query_pb.FieldFilter.Operator.ARRAY_CONTAINS_ANY: return And(field.exists(), field.array_contains_any(value)) elif filter_pb.op == Query_pb.FieldFilter.Operator.IN: - return And(field.exists(), field.in_any(value)) + return And(field.exists(), field.equal_any(value)) elif filter_pb.op == Query_pb.FieldFilter.Operator.NOT_IN: - return And(field.exists(), field.not_in_any(value)) + return And(field.exists(), field.not_equal_any(value)) else: raise TypeError(f"Unexpected FieldFilter operator type: {filter_pb.op}") elif isinstance(filter_pb, Query_pb.Filter): @@ -2134,169 +1349,94 @@ def _from_query_filter_pb(filter_pb, client): or filter_pb.field_filter or filter_pb.unary_filter ) - return FilterCondition._from_query_filter_pb(f, client) + return BooleanExpr._from_query_filter_pb(f, client) else: raise TypeError(f"Unexpected filter type: {type(filter_pb)}") -class And(FilterCondition): - def __init__(self, *conditions: "FilterCondition"): - super().__init__("and", conditions, use_infix_repr=False) - - -class ArrayContains(FilterCondition): - def __init__(self, array: Expr, element: Expr): - super().__init__( - "array_contains", [array, element if element else Constant(None)] - ) - - -class ArrayContainsAll(FilterCondition): - """Represents checking if an array contains all specified elements.""" - - def __init__(self, array: Expr, elements: List[Expr]): - super().__init__("array_contains_all", [array, ListOfExprs(elements)]) - +class And(BooleanExpr): + """ + Represents an expression that performs a logical 'AND' operation on multiple filter conditions. -class ArrayContainsAny(FilterCondition): - """Represents checking if an array contains any of the specified elements.""" + Example: + >>> # Check if the 'age' field is greater than 18 AND the 'city' field is "London" AND + >>> # the 'status' field is "active" + >>> Expr.And(Field.of("age").greater_than(18), Field.of("city").equal("London"), Field.of("status").equal("active")) - def __init__(self, array: Expr, elements: List[Expr]): - super().__init__("array_contains_any", [array, ListOfExprs(elements)]) + Args: + *conditions: The filter conditions to 'AND' together. + """ + def __init__(self, *conditions: "BooleanExpr"): + super().__init__("and", conditions, use_infix_repr=False) -class EndsWith(FilterCondition): - """Represents checking if a string ends with a specific postfix.""" - def __init__(self, expr: Expr, postfix: Expr): - super().__init__("ends_with", [expr, postfix]) +class Not(BooleanExpr): + """ + Represents an expression that negates a filter condition. + Example: + >>> # Find documents where the 'completed' field is NOT true + >>> Expr.Not(Field.of("completed").equal(True)) -class Eq(FilterCondition): - """Represents the equality comparison.""" + Args: + condition: The filter condition to negate. + """ - def __init__(self, left: Expr, right: Expr): - super().__init__("eq", [left, right if right else Constant(None)]) + def __init__(self, condition: BooleanExpr): + super().__init__("not", [condition], use_infix_repr=False) -class Exists(FilterCondition): - """Represents checking if a field exists.""" +class Or(BooleanExpr): + """ + Represents expression that performs a logical 'OR' operation on multiple filter conditions. - def __init__(self, expr: Expr): - super().__init__("exists", [expr]) + Example: + >>> # Check if the 'age' field is greater than 18 OR the 'city' field is "London" OR + >>> # the 'status' field is "active" + >>> Expr.Or(Field.of("age").greater_than(18), Field.of("city").equal("London"), Field.of("status").equal("active")) + Args: + *conditions: The filter conditions to 'OR' together. + """ -class Gt(FilterCondition): - """Represents the greater than comparison.""" + def __init__(self, *conditions: "BooleanExpr"): + super().__init__("or", conditions, use_infix_repr=False) - def __init__(self, left: Expr, right: Expr): - super().__init__("gt", [left, right if right else Constant(None)]) +class Xor(BooleanExpr): + """ + Represents an expression that performs a logical 'XOR' (exclusive OR) operation on multiple filter conditions. -class Gte(FilterCondition): - """Represents the greater than or equal to comparison.""" + Example: + >>> # Check if only one of the conditions is true: 'age' greater than 18, 'city' is "London", + >>> # or 'status' is "active". + >>> Expr.Xor(Field.of("age").greater_than(18), Field.of("city").equal("London"), Field.of("status").equal("active")) - def __init__(self, left: Expr, right: Expr): - super().__init__("gte", [left, right if right else Constant(None)]) + Args: + *conditions: The filter conditions to 'XOR' together. + """ + def __init__(self, conditions: Sequence["BooleanExpr"]): + super().__init__("xor", conditions, use_infix_repr=False) -class If(FilterCondition): - """Represents a conditional expression (if-then-else).""" - def __init__(self, condition: "FilterCondition", true_expr: Expr, false_expr: Expr): - super().__init__( - "if", [condition, true_expr, false_expr if false_expr else Constant(None)] - ) +class Conditional(BooleanExpr): + """ + Represents a conditional expression that evaluates to a 'then' expression if a condition is true + and an 'else' expression if the condition is false. + Example: + >>> # If 'age' is greater than 18, return "Adult"; otherwise, return "Minor". + >>> Expr.conditional(Field.of("age").greater_than(18), Constant.of("Adult"), Constant.of("Minor")); -class In(FilterCondition): - """Represents checking if an expression's value is within a list of values.""" + Args: + condition: The condition to evaluate. + then_expr: The expression to return if the condition is true. + else_expr: The expression to return if the condition is false + """ - def __init__(self, left: Expr, others: List[Expr]): + def __init__(self, condition: BooleanExpr, then_expr: Expr, else_expr: Expr): super().__init__( - "in", [left, ListOfExprs(others)], infix_name_override="in_any" + "conditional", [condition, then_expr, else_expr], use_infix_repr=False ) - - -class IsNaN(FilterCondition): - """Represents checking if a numeric value is NaN.""" - - def __init__(self, value: Expr): - super().__init__("is_nan", [value]) - - -class Like(FilterCondition): - """Represents a case-sensitive wildcard string comparison.""" - - def __init__(self, expr: Expr, pattern: Expr): - super().__init__("like", [expr, pattern]) - - -class Lt(FilterCondition): - """Represents the less than comparison.""" - - def __init__(self, left: Expr, right: Expr): - super().__init__("lt", [left, right if right else Constant(None)]) - - -class Lte(FilterCondition): - """Represents the less than or equal to comparison.""" - - def __init__(self, left: Expr, right: Expr): - super().__init__("lte", [left, right if right else Constant(None)]) - - -class Neq(FilterCondition): - """Represents the inequality comparison.""" - - def __init__(self, left: Expr, right: Expr): - super().__init__("neq", [left, right if right else Constant(None)]) - - -class Not(FilterCondition): - """Represents the logical NOT of a filter condition.""" - - def __init__(self, condition: Expr): - super().__init__("not", [condition], use_infix_repr=False) - - -class Or(FilterCondition): - """Represents the logical OR of multiple filter conditions.""" - - def __init__(self, *conditions: "FilterCondition"): - super().__init__("or", conditions) - - -class RegexContains(FilterCondition): - """Represents checking if a string contains a substring matching a regex.""" - - def __init__(self, expr: Expr, regex: Expr): - super().__init__("regex_contains", [expr, regex]) - - -class RegexMatch(FilterCondition): - """Represents checking if a string fully matches a regex.""" - - def __init__(self, expr: Expr, regex: Expr): - super().__init__("regex_match", [expr, regex]) - - -class StartsWith(FilterCondition): - """Represents checking if a string starts with a specific prefix.""" - - def __init__(self, expr: Expr, prefix: Expr): - super().__init__("starts_with", [expr, prefix]) - - -class StrContains(FilterCondition): - """Represents checking if a string contains a specific substring.""" - - def __init__(self, expr: Expr, substring: Expr): - super().__init__("str_contains", [expr, substring]) - - -class Xor(FilterCondition): - """Represents the logical XOR of multiple filter conditions.""" - - def __init__(self, conditions: List["FilterCondition"]): - super().__init__("xor", conditions, use_infix_repr=False) diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index dc262f4a9..50cc7c29d 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -125,13 +125,25 @@ data: awards: hugo: true nebula: true + timestamps: + ts1: + time: "1993-04-28T12:01:00.654321+00:00" + micros: 735998460654321 + millis: 735998460654 + seconds: 735998460 + vectors: + vec1: + embedding: [1.0, 2.0, 3.0] + vec2: + embedding: [4.0, 5.0, 6.0, 7.0] tests: - description: "testAggregates - count" pipeline: - Collection: books - Aggregate: - - ExprWithAlias: - - Count + - AliasedExpr: + - Expr.count: + - Field: rating - "count" assert_results: - count: 10 @@ -147,25 +159,28 @@ tests: count: functionValue: name: count + args: + - fieldReferenceValue: rating - mapValue: {} name: aggregate - description: "testAggregates - avg, count, max" pipeline: - Collection: books - Where: - - Eq: + - Expr.equal: - Field: genre - Constant: Science Fiction - Aggregate: - - ExprWithAlias: - - Count + - AliasedExpr: + - Expr.count: + - Field: rating - "count" - - ExprWithAlias: - - Avg: + - AliasedExpr: + - Expr.average: - Field: rating - "avg_rating" - - ExprWithAlias: - - Max: + - AliasedExpr: + - Expr.maximum: - Field: rating - "max_rating" assert_results: @@ -183,7 +198,7 @@ tests: args: - fieldReferenceValue: genre - stringValue: Science Fiction - name: eq + name: equal name: where - args: - mapValue: @@ -192,10 +207,12 @@ tests: functionValue: args: - fieldReferenceValue: rating - name: avg + name: average count: functionValue: name: count + args: + - fieldReferenceValue: rating max_rating: functionValue: args: @@ -207,7 +224,7 @@ tests: pipeline: - Collection: books - Where: - - Lt: + - Expr.less_than: - Field: published - Constant: 1900 - Aggregate: @@ -218,18 +235,18 @@ tests: pipeline: - Collection: books - Where: - - Lt: + - Expr.less_than: - Field: published - Constant: 1984 - Aggregate: accumulators: - - ExprWithAlias: - - Avg: + - AliasedExpr: + - Expr.average: - Field: rating - "avg_rating" groups: [genre] - Where: - - Gt: + - Expr.greater_than: - Field: avg_rating - Constant: 4.3 - Sort: @@ -254,7 +271,7 @@ tests: args: - fieldReferenceValue: published - integerValue: '1984' - name: lt + name: less_than name: where - args: - mapValue: @@ -263,7 +280,7 @@ tests: functionValue: args: - fieldReferenceValue: rating - name: avg + name: average - mapValue: fields: genre: @@ -274,7 +291,7 @@ tests: args: - fieldReferenceValue: avg_rating - doubleValue: 4.3 - name: gt + name: greater_than name: where - args: - mapValue: @@ -288,15 +305,16 @@ tests: pipeline: - Collection: books - Aggregate: - - ExprWithAlias: - - Count + - AliasedExpr: + - Expr.count: + - Field: rating - "count" - - ExprWithAlias: - - Max: + - AliasedExpr: + - Expr.maximum: - Field: rating - "max_rating" - - ExprWithAlias: - - Min: + - AliasedExpr: + - Expr.minimum: - Field: published - "min_published" assert_results: @@ -314,6 +332,8 @@ tests: fields: count: functionValue: + args: + - fieldReferenceValue: rating name: count max_rating: functionValue: @@ -384,14 +404,14 @@ tests: pipeline: - Collection: books - AddFields: - - ExprWithAlias: - - StrConcat: + - AliasedExpr: + - Expr.string_concat: - Field: author - Constant: _ - Field: title - "author_title" - - ExprWithAlias: - - StrConcat: + - AliasedExpr: + - Expr.string_concat: - Field: title - Constant: _ - Field: author @@ -445,14 +465,14 @@ tests: - fieldReferenceValue: author - stringValue: _ - fieldReferenceValue: title - name: str_concat + name: string_concat title_author: functionValue: args: - fieldReferenceValue: title - stringValue: _ - fieldReferenceValue: author - name: str_concat + name: string_concat name: add_fields - args: - fieldReferenceValue: title_author @@ -477,10 +497,10 @@ tests: - Collection: books - Where: - And: - - Gt: + - Expr.greater_than: - Field: rating - Constant: 4.5 - - Eq: + - Expr.equal: - Field: genre - Constant: Science Fiction assert_results: @@ -509,12 +529,12 @@ tests: args: - fieldReferenceValue: rating - doubleValue: 4.5 - name: gt + name: greater_than - functionValue: args: - fieldReferenceValue: genre - stringValue: Science Fiction - name: eq + name: equal name: and name: where - description: whereByOrCondition @@ -522,10 +542,10 @@ tests: - Collection: books - Where: - Or: - - Eq: + - Expr.equal: - Field: genre - Constant: Romance - - Eq: + - Expr.equal: - Field: genre - Constant: Dystopian - Select: @@ -551,12 +571,12 @@ tests: args: - fieldReferenceValue: genre - stringValue: Romance - name: eq + name: equal - functionValue: args: - fieldReferenceValue: genre - stringValue: Dystopian - name: eq + name: equal name: or name: where - args: @@ -624,7 +644,7 @@ tests: pipeline: - Collection: books - Where: - - ArrayContains: + - Expr.array_contains: - Field: tags - Constant: comedy assert_results: @@ -654,7 +674,7 @@ tests: pipeline: - Collection: books - Where: - - ArrayContainsAny: + - Expr.array_contains_any: - Field: tags - - Constant: comedy - Constant: classic @@ -701,7 +721,7 @@ tests: pipeline: - Collection: books - Where: - - ArrayContainsAll: + - Expr.array_contains_all: - Field: tags - - Constant: adventure - Constant: magic @@ -735,12 +755,12 @@ tests: pipeline: - Collection: books - Select: - - ExprWithAlias: - - ArrayLength: + - AliasedExpr: + - Expr.array_length: - Field: tags - "tagsCount" - Where: - - Eq: + - Expr.equal: - Field: tagsCount - Constant: 3 assert_results: # All documents have 3 tags @@ -774,9 +794,9 @@ tests: args: - fieldReferenceValue: tagsCount - integerValue: '3' - name: eq + name: equal name: where - - description: testStrConcat + - description: testStringConcat pipeline: - Collection: books - Sort: @@ -784,8 +804,8 @@ tests: - Field: author - ASCENDING - Select: - - ExprWithAlias: - - StrConcat: + - AliasedExpr: + - Expr.string_concat: - Field: author - Constant: " - " - Field: title @@ -816,7 +836,7 @@ tests: - fieldReferenceValue: author - stringValue: ' - ' - fieldReferenceValue: title - name: str_concat + name: string_concat name: select - args: - integerValue: '1' @@ -825,7 +845,7 @@ tests: pipeline: - Collection: books - Where: - - StartsWith: + - Expr.starts_with: - Field: title - Constant: The - Select: @@ -870,7 +890,7 @@ tests: pipeline: - Collection: books - Where: - - EndsWith: + - Expr.ends_with: - Field: title - Constant: y - Select: @@ -913,13 +933,13 @@ tests: pipeline: - Collection: books - Select: - - ExprWithAlias: - - CharLength: + - AliasedExpr: + - Expr.char_length: - Field: title - "titleLength" - title - Where: - - Gt: + - Expr.greater_than: - Field: titleLength - Constant: 20 - Sort: @@ -957,7 +977,7 @@ tests: args: - fieldReferenceValue: titleLength - integerValue: '20' - name: gt + name: greater_than name: where - args: - mapValue: @@ -971,12 +991,12 @@ tests: pipeline: - Collection: books - Where: - - Eq: + - Expr.equal: - Field: author - Constant: "Douglas Adams" - Select: - - ExprWithAlias: - - CharLength: + - AliasedExpr: + - Expr.char_length: - Field: title - "title_length" assert_results: @@ -992,7 +1012,7 @@ tests: args: - fieldReferenceValue: author - stringValue: Douglas Adams - name: eq + name: equal name: where - args: - mapValue: @@ -1007,13 +1027,13 @@ tests: pipeline: - Collection: books - Where: - - Eq: + - Expr.equal: - Field: author - Constant: Douglas Adams - Select: - - ExprWithAlias: - - ByteLength: - - StrConcat: + - AliasedExpr: + - Expr.byte_length: + - Expr.string_concat: - Field: title - Constant: _银河系漫游指南 - "title_byte_length" @@ -1030,7 +1050,7 @@ tests: args: - fieldReferenceValue: author - stringValue: Douglas Adams - name: eq + name: equal name: where - args: - mapValue: @@ -1042,14 +1062,14 @@ tests: args: - fieldReferenceValue: title - stringValue: "_\u94F6\u6CB3\u7CFB\u6F2B\u6E38\u6307\u5357" - name: str_concat + name: string_concat name: byte_length name: select - description: testLike pipeline: - Collection: books - Where: - - Like: + - Expr.like: - Field: title - Constant: "%Guide%" - Select: @@ -1061,7 +1081,7 @@ tests: pipeline: - Collection: books - Where: - - RegexContains: + - Expr.regex_contains: - Field: title - Constant: "(?i)(the|of)" assert_count: 5 @@ -1083,7 +1103,7 @@ tests: pipeline: - Collection: books - Where: - - RegexMatch: + - Expr.regex_match: - Field: title - Constant: ".*(?i)(the|of).*" assert_count: 5 @@ -1104,42 +1124,42 @@ tests: pipeline: - Collection: books - Where: - - Eq: + - Expr.equal: - Field: title - Constant: To Kill a Mockingbird - Select: - - ExprWithAlias: - - Add: + - AliasedExpr: + - Expr.add: - Field: rating - Constant: 1 - "ratingPlusOne" - - ExprWithAlias: - - Subtract: + - AliasedExpr: + - Expr.subtract: - Field: published - Constant: 1900 - "yearsSince1900" - - ExprWithAlias: - - Multiply: + - AliasedExpr: + - Expr.multiply: - Field: rating - Constant: 10 - "ratingTimesTen" - - ExprWithAlias: - - Divide: + - AliasedExpr: + - Expr.divide: - Field: rating - Constant: 2 - "ratingDividedByTwo" - - ExprWithAlias: - - Multiply: + - AliasedExpr: + - Expr.multiply: - Field: rating - Constant: 20 - "ratingTimes20" - - ExprWithAlias: - - Add: + - AliasedExpr: + - Expr.add: - Field: rating - Constant: 3 - "ratingPlus3" - - ExprWithAlias: - - Mod: + - AliasedExpr: + - Expr.mod: - Field: rating - Constant: 2 - "ratingMod2" @@ -1162,7 +1182,7 @@ tests: args: - fieldReferenceValue: title - stringValue: To Kill a Mockingbird - name: eq + name: equal name: where - args: - mapValue: @@ -1215,13 +1235,13 @@ tests: - Collection: books - Where: - And: - - Gt: + - Expr.greater_than: - Field: rating - Constant: 4.2 - - Lte: + - Expr.less_than_or_equal: - Field: rating - Constant: 4.5 - - Neq: + - Expr.not_equal: - Field: genre - Constant: Science Fiction - Select: @@ -1251,17 +1271,17 @@ tests: args: - fieldReferenceValue: rating - doubleValue: 4.2 - name: gt + name: greater_than - functionValue: args: - fieldReferenceValue: rating - doubleValue: 4.5 - name: lte + name: less_than_or_equal - functionValue: args: - fieldReferenceValue: genre - stringValue: Science Fiction - name: neq + name: not_equal name: and name: where - args: @@ -1286,13 +1306,13 @@ tests: - Where: - Or: - And: - - Gt: + - Expr.greater_than: - Field: rating - Constant: 4.5 - - Eq: + - Expr.equal: - Field: genre - Constant: Science Fiction - - Lt: + - Expr.less_than: - Field: published - Constant: 1900 - Select: @@ -1320,18 +1340,18 @@ tests: args: - fieldReferenceValue: rating - doubleValue: 4.5 - name: gt + name: greater_than - functionValue: args: - fieldReferenceValue: genre - stringValue: Science Fiction - name: eq + name: equal name: and - functionValue: args: - fieldReferenceValue: published - integerValue: '1900' - name: lt + name: less_than name: or name: where - args: @@ -1353,12 +1373,12 @@ tests: - Collection: books - Where: - Not: - - IsNaN: + - Expr.is_nan: - Field: rating - Select: - - ExprWithAlias: + - AliasedExpr: - Not: - - IsNaN: + - Expr.is_nan: - Field: rating - "ratingIsNotNaN" - Limit: 1 @@ -1398,23 +1418,23 @@ tests: pipeline: - Collection: books - Where: - - Eq: + - Expr.equal: - Field: author - Constant: Douglas Adams - Select: - - ExprWithAlias: - - LogicalMax: + - AliasedExpr: + - Expr.logical_maximum: - Field: rating - Constant: 4.5 - "max_rating" - - ExprWithAlias: - - LogicalMax: + - AliasedExpr: + - Expr.logical_minimum: - Field: published - Constant: 1900 - - "max_published" + - "min_published" assert_results: - max_rating: 4.5 - max_published: 1979 + min_published: 1900 assert_proto: pipeline: stages: @@ -1426,23 +1446,23 @@ tests: args: - fieldReferenceValue: author - stringValue: Douglas Adams - name: eq + name: equal name: where - args: - mapValue: fields: - max_published: + min_published: functionValue: args: - fieldReferenceValue: published - integerValue: '1900' - name: logical_maximum + name: minimum max_rating: functionValue: args: - fieldReferenceValue: rating - doubleValue: 4.5 - name: logical_maximum + name: maximum name: select - description: testMapGet pipeline: @@ -1452,14 +1472,14 @@ tests: - Field: published - DESCENDING - Select: - - ExprWithAlias: - - MapGet: + - AliasedExpr: + - Expr.map_get: - Field: awards - - Constant: hugo + - hugo - "hugoAward" - Field: title - Where: - - Eq: + - Expr.equal: - Field: hugoAward - Constant: true assert_results: @@ -1498,13 +1518,13 @@ tests: args: - fieldReferenceValue: hugoAward - booleanValue: true - name: eq + name: equal name: where - description: testNestedFields pipeline: - Collection: books - Where: - - Eq: + - Expr.equal: - Field: awards.hugo - Constant: true - Sort: @@ -1530,7 +1550,7 @@ tests: args: - fieldReferenceValue: awards.hugo - booleanValue: true - name: eq + name: equal name: where - args: - mapValue: @@ -1604,7 +1624,7 @@ tests: pipeline: - Collection: books - Where: - - Eq: + - Expr.equal: - Field: title - Constant: The Hitchhiker's Guide to the Galaxy - Unnest: @@ -1626,7 +1646,7 @@ tests: args: - fieldReferenceValue: title - stringValue: The Hitchhiker's Guide to the Galaxy - name: eq + name: equal name: where - args: - fieldReferenceValue: tags @@ -1638,3 +1658,301 @@ tests: tags_alias: fieldReferenceValue: tags_alias name: select + - description: testGreaterThanOrEqual + pipeline: + - Collection: books + - Where: + - Expr.greater_than_or_equal: + - Field: rating + - Constant: 4.6 + - Select: + - title + - rating + - Sort: + - Ordering: + - Field: rating + - ASCENDING + assert_results: + - title: Dune + rating: 4.6 + - title: The Lord of the Rings + rating: 4.7 + - description: testInAndNotIn + pipeline: + - Collection: books + - Where: + - And: + - Expr.equal_any: + - Field: genre + - - Constant: Romance + - Constant: Dystopian + - Expr.not_equal_any: + - Field: author + - - Constant: "George Orwell" + assert_results: + - title: "Pride and Prejudice" + author: "Jane Austen" + genre: "Romance" + published: 1813 + rating: 4.5 + tags: + - classic + - social commentary + - love + awards: + none: true + - title: "The Handmaid's Tale" + author: "Margaret Atwood" + genre: "Dystopian" + published: 1985 + rating: 4.1 + tags: + - feminism + - totalitarianism + - resistance + awards: + "arthur c. clarke": true + "booker prize": false + - description: testArrayReverse + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Select: + - AliasedExpr: + - Expr.array_reverse: + - Field: tags + - "reversedTags" + assert_results: + - reversedTags: + - adventure + - space + - comedy + - description: testExists + pipeline: + - Collection: books + - Where: + - And: + - Expr.exists: + - Field: awards.pulitzer + - Expr.equal: + - Field: awards.pulitzer + - Constant: true + - Select: + - title + assert_results: + - title: To Kill a Mockingbird + - description: testSum + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: genre + - Constant: Science Fiction + - Aggregate: + - AliasedExpr: + - Expr.sum: + - Field: rating + - "total_rating" + assert_results: + - total_rating: 8.8 + - description: testStringContains + pipeline: + - Collection: books + - Where: + - Expr.string_contains: + - Field: title + - Constant: "Hitchhiker's" + - Select: + - title + assert_results: + - title: "The Hitchhiker's Guide to the Galaxy" + - description: testVectorLength + pipeline: + - Collection: vectors + - Select: + - AliasedExpr: + - Expr.vector_length: + - Field: embedding + - "embedding_length" + - Sort: + - Ordering: + - Field: embedding_length + - ASCENDING + assert_results: + - embedding_length: 3 + - embedding_length: 4 + - description: testTimestampFunctions + pipeline: + - Collection: timestamps + - Select: + - AliasedExpr: + - Expr.timestamp_to_unix_micros: + - Field: time + - "micros" + - AliasedExpr: + - Expr.timestamp_to_unix_millis: + - Field: time + - "millis" + - AliasedExpr: + - Expr.timestamp_to_unix_seconds: + - Field: time + - "seconds" + - AliasedExpr: + - Expr.unix_micros_to_timestamp: + - Field: micros + - "from_micros" + - AliasedExpr: + - Expr.unix_millis_to_timestamp: + - Field: millis + - "from_millis" + - AliasedExpr: + - Expr.unix_seconds_to_timestamp: + - Field: seconds + - "from_seconds" + - AliasedExpr: + - Expr.timestamp_add: + - Field: time + - Constant: "day" + - Constant: 1 + - "plus_day" + - AliasedExpr: + - Expr.timestamp_subtract: + - Field: time + - Constant: "hour" + - Constant: 1 + - "minus_hour" + assert_results: + - micros: 735998460654321 + millis: 735998460654 + seconds: 735998460 + from_micros: "1993-04-28T12:01:00.654321+00:00" + from_millis: "1993-04-28T12:01:00.654000+00:00" + from_seconds: "1993-04-28T12:01:00.000000+00:00" + plus_day: "1993-04-29T12:01:00.654321+00:00" + minus_hour: "1993-04-28T11:01:00.654321+00:00" + - description: testCollectionId + pipeline: + - Collection: books + - Limit: 1 + - Select: + - AliasedExpr: + - Expr.collection_id: + - Field: __name__ + - "collectionName" + assert_results: + - collectionName: "books" + - description: testXor + pipeline: + - Collection: books + - Where: + - Xor: + - - Expr.equal: + - Field: genre + - Constant: Romance + - Expr.greater_than: + - Field: published + - Constant: 1980 + - Select: + - title + - genre + - published + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: "Pride and Prejudice" + genre: "Romance" + published: 1813 + - title: "The Handmaid's Tale" + genre: "Dystopian" + published: 1985 + - description: testConditional + pipeline: + - Collection: books + - Select: + - title + - AliasedExpr: + - Conditional: + - Expr.greater_than: + - Field: published + - Constant: 1950 + - Constant: "Modern" + - Constant: "Classic" + - "era" + - Sort: + - Ordering: + - Field: title + - ASCENDING + - Limit: 4 + assert_results: + - title: "1984" + era: "Classic" + - title: "Crime and Punishment" + era: "Classic" + - title: "Dune" + era: "Modern" + - title: "One Hundred Years of Solitude" + era: "Modern" + - description: testFieldToFieldArithmetic + pipeline: + - Collection: books + - Where: + - Expr.equal: + - Field: title + - Constant: "Dune" + - Select: + - AliasedExpr: + - Expr.add: + - Field: published + - Field: rating + - "pub_plus_rating" + assert_results: + - pub_plus_rating: 1969.6 + - description: testFieldToFieldComparison + pipeline: + - Collection: books + - Where: + - Expr.greater_than: + - Field: published + - Field: rating + - Select: + - title + assert_count: 10 # All books were published after year 4.7 + - description: testExistsNegative + pipeline: + - Collection: books + - Where: + - Expr.exists: + - Field: non_existent_field + assert_count: 0 + - description: testConditionalWithFields + pipeline: + - Collection: books + - Where: + - Expr.equal_any: + - Field: title + - - Constant: "Dune" + - Constant: "1984" + - Select: + - title + - AliasedExpr: + - Conditional: + - Expr.greater_than: + - Field: published + - Constant: 1950 + - Field: author + - Field: genre + - "conditional_field" + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: "1984" + conditional_field: "Dystopian" + - title: "Dune" + conditional_field: "Frank Herbert" diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py index 9d44bbc57..d4c654e63 100644 --- a/tests/system/test_pipeline_acceptance.py +++ b/tests/system/test_pipeline_acceptance.py @@ -17,6 +17,7 @@ from __future__ import annotations import os +import datetime import pytest import yaml import re @@ -26,6 +27,7 @@ from google.cloud.firestore_v1 import _pipeline_stages as stages from google.cloud.firestore_v1 import pipeline_expressions +from google.cloud.firestore_v1.vector import Vector from google.api_core.exceptions import GoogleAPIError from google.cloud.firestore import Client, AsyncClient @@ -91,7 +93,7 @@ def test_pipeline_results(test_dict, client): """ Ensure pipeline returns expected results """ - expected_results = test_dict.get("assert_results", None) + expected_results = _parse_yaml_types(test_dict.get("assert_results", None)) expected_count = test_dict.get("assert_count", None) pipeline = parse_pipeline(client, test_dict["pipeline"]) # check if server responds as expected @@ -132,7 +134,7 @@ async def test_pipeline_results_async(test_dict, async_client): """ Ensure pipeline returns expected results """ - expected_results = test_dict.get("assert_results", None) + expected_results = _parse_yaml_types(test_dict.get("assert_results", None)) expected_count = test_dict.get("assert_count", None) pipeline = parse_pipeline(async_client, test_dict["pipeline"]) # check if server responds as expected @@ -160,7 +162,7 @@ def parse_pipeline(client, pipeline: list[dict[str, Any], str]): # find arguments if given if isinstance(stage, dict): stage_yaml_args = stage[stage_name] - stage_obj = _apply_yaml_args(stage_cls, client, stage_yaml_args) + stage_obj = _apply_yaml_args_to_callable(stage_cls, client, stage_yaml_args) else: # yaml has no arguments stage_obj = stage_cls() @@ -178,15 +180,21 @@ def _parse_expressions(client, yaml_element: Any): if len(yaml_element) == 1 and _is_expr_string(next(iter(yaml_element))): # build pipeline expressions if possible cls_str = next(iter(yaml_element)) - cls = getattr(pipeline_expressions, cls_str) + callable_obj = None + if "." in cls_str: + cls_name, method_name = cls_str.split(".") + cls = getattr(pipeline_expressions, cls_name) + callable_obj = getattr(cls, method_name) + else: + callable_obj = getattr(pipeline_expressions, cls_str) yaml_args = yaml_element[cls_str] - return _apply_yaml_args(cls, client, yaml_args) + return _apply_yaml_args_to_callable(callable_obj, client, yaml_args) elif len(yaml_element) == 1 and _is_stage_string(next(iter(yaml_element))): # build pipeline stage if possible (eg, for SampleOptions) cls_str = next(iter(yaml_element)) cls = getattr(stages, cls_str) yaml_args = yaml_element[cls_str] - return _apply_yaml_args(cls, client, yaml_args) + return _apply_yaml_args_to_callable(cls, client, yaml_args) elif len(yaml_element) == 1 and list(yaml_element)[0] == "Pipeline": # find Pipeline objects for Union expressions other_ppl = yaml_element["Pipeline"] @@ -203,25 +211,33 @@ def _parse_expressions(client, yaml_element: Any): return yaml_element -def _apply_yaml_args(cls, client, yaml_args): +def _apply_yaml_args_to_callable(callable_obj, client, yaml_args): """ Helper to instantiate a class with yaml arguments. The arguments will be applied as positional or keyword arguments, based on type """ if isinstance(yaml_args, dict): - return cls(**_parse_expressions(client, yaml_args)) + return callable_obj(**_parse_expressions(client, yaml_args)) elif isinstance(yaml_args, list): # yaml has an array of arguments. Treat as args - return cls(*_parse_expressions(client, yaml_args)) + return callable_obj(*_parse_expressions(client, yaml_args)) else: # yaml has a single argument - return cls(_parse_expressions(client, yaml_args)) + return callable_obj(_parse_expressions(client, yaml_args)) def _is_expr_string(yaml_str): """ Returns true if a string represents a class in pipeline_expressions """ + if isinstance(yaml_str, str) and "." in yaml_str: + parts = yaml_str.split(".") + if len(parts) == 2: + cls_name, method_name = parts + if hasattr(pipeline_expressions, cls_name): + cls = getattr(pipeline_expressions, cls_name) + if hasattr(cls, method_name): + return True return ( isinstance(yaml_str, str) and yaml_str[0].isupper() @@ -251,6 +267,26 @@ def event_loop(): loop.close() +def _parse_yaml_types(data): + """helper to convert yaml data to firestore objects when needed""" + if isinstance(data, dict): + return {key: _parse_yaml_types(value) for key, value in data.items()} + if isinstance(data, list): + # detect vectors + if all([isinstance(d, float) for d in data]): + return Vector(data) + else: + return [_parse_yaml_types(value) for value in data] + # detect timestamps + if isinstance(data, str) and ":" in data: + try: + parsed_datetime = datetime.datetime.fromisoformat(data) + return parsed_datetime + except ValueError: + pass + return data + + @pytest.fixture(scope="module") def client(): """ @@ -258,6 +294,7 @@ def client(): """ client = Client(project=FIRESTORE_PROJECT, database=FIRESTORE_ENTERPRISE_DB) data = yaml_loader("data") + to_delete = [] try: # setup data batch = client.batch() @@ -265,16 +302,14 @@ def client(): collection_ref = client.collection(collection_name) for document_id, document_data in documents.items(): document_ref = collection_ref.document(document_id) - batch.set(document_ref, document_data) + to_delete.append(document_ref) + batch.set(document_ref, _parse_yaml_types(document_data)) batch.commit() yield client finally: # clear data - for collection_name, documents in data.items(): - collection_ref = client.collection(collection_name) - for document_id in documents: - document_ref = collection_ref.document(document_id) - document_ref.delete() + for document_ref in to_delete: + document_ref.delete() @pytest.fixture(scope="module") diff --git a/tests/unit/v1/test_async_pipeline.py b/tests/unit/v1/test_async_pipeline.py index 47eedc983..b3ed83337 100644 --- a/tests/unit/v1/test_async_pipeline.py +++ b/tests/unit/v1/test_async_pipeline.py @@ -17,7 +17,6 @@ from google.cloud.firestore_v1 import _pipeline_stages as stages from google.cloud.firestore_v1.pipeline_expressions import Field -from google.cloud.firestore_v1.pipeline_expressions import Exists def _make_async_pipeline(*args, client=mock.Mock()): @@ -386,7 +385,7 @@ async def test_async_pipeline_stream_stream_equivalence_mocked(): ("remove_fields", (Field.of("n"),), stages.RemoveFields), ("select", ("name",), stages.Select), ("select", (Field.of("n"),), stages.Select), - ("where", (Exists(Field.of("n")),), stages.Where), + ("where", (Field.of("n").exists(),), stages.Where), ("find_nearest", ("name", [0.1], 0), stages.FindNearest), ( "find_nearest", diff --git a/tests/unit/v1/test_pipeline.py b/tests/unit/v1/test_pipeline.py index b237ad5ac..f90279e00 100644 --- a/tests/unit/v1/test_pipeline.py +++ b/tests/unit/v1/test_pipeline.py @@ -17,7 +17,6 @@ from google.cloud.firestore_v1 import _pipeline_stages as stages from google.cloud.firestore_v1.pipeline_expressions import Field -from google.cloud.firestore_v1.pipeline_expressions import Exists def _make_pipeline(*args, client=mock.Mock()): @@ -363,7 +362,7 @@ def test_pipeline_execute_stream_equivalence_mocked(): ("remove_fields", (Field.of("n"),), stages.RemoveFields), ("select", ("name",), stages.Select), ("select", (Field.of("n"),), stages.Select), - ("where", (Exists(Field.of("n")),), stages.Where), + ("where", (Field.of("n").exists(),), stages.Where), ("find_nearest", ("name", [0.1], 0), stages.FindNearest), ( "find_nearest", diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index 936c0a0a9..c5329df33 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -22,8 +22,13 @@ from google.cloud.firestore_v1.types.document import Value from google.cloud.firestore_v1.vector import Vector from google.cloud.firestore_v1._helpers import GeoPoint -from google.cloud.firestore_v1.pipeline_expressions import FilterCondition, ListOfExprs import google.cloud.firestore_v1.pipeline_expressions as expr +from google.cloud.firestore_v1.pipeline_expressions import BooleanExpr +from google.cloud.firestore_v1.pipeline_expressions import _ListOfExprs +from google.cloud.firestore_v1.pipeline_expressions import Expr +from google.cloud.firestore_v1.pipeline_expressions import Constant +from google.cloud.firestore_v1.pipeline_expressions import Field +from google.cloud.firestore_v1.pipeline_expressions import Ordering @pytest.fixture @@ -37,126 +42,43 @@ class TestOrdering: @pytest.mark.parametrize( "direction_arg,expected_direction", [ - ("ASCENDING", expr.Ordering.Direction.ASCENDING), - ("DESCENDING", expr.Ordering.Direction.DESCENDING), - ("ascending", expr.Ordering.Direction.ASCENDING), - ("descending", expr.Ordering.Direction.DESCENDING), - (expr.Ordering.Direction.ASCENDING, expr.Ordering.Direction.ASCENDING), - (expr.Ordering.Direction.DESCENDING, expr.Ordering.Direction.DESCENDING), + ("ASCENDING", Ordering.Direction.ASCENDING), + ("DESCENDING", Ordering.Direction.DESCENDING), + ("ascending", Ordering.Direction.ASCENDING), + ("descending", Ordering.Direction.DESCENDING), + (Ordering.Direction.ASCENDING, Ordering.Direction.ASCENDING), + (Ordering.Direction.DESCENDING, Ordering.Direction.DESCENDING), ], ) def test_ctor(self, direction_arg, expected_direction): - instance = expr.Ordering("field1", direction_arg) - assert isinstance(instance.expr, expr.Field) + instance = Ordering("field1", direction_arg) + assert isinstance(instance.expr, Field) assert instance.expr.path == "field1" assert instance.order_dir == expected_direction def test_repr(self): - field_expr = expr.Field.of("field1") - instance = expr.Ordering(field_expr, "ASCENDING") + field_expr = Field.of("field1") + instance = Ordering(field_expr, "ASCENDING") repr_str = repr(instance) assert repr_str == "Field.of('field1').ascending()" - instance = expr.Ordering(field_expr, "DESCENDING") + instance = Ordering(field_expr, "DESCENDING") repr_str = repr(instance) assert repr_str == "Field.of('field1').descending()" def test_to_pb(self): - field_expr = expr.Field.of("field1") - instance = expr.Ordering(field_expr, "ASCENDING") + field_expr = Field.of("field1") + instance = Ordering(field_expr, "ASCENDING") result = instance._to_pb() assert result.map_value.fields["expression"].field_reference_value == "field1" assert result.map_value.fields["direction"].string_value == "ascending" - instance = expr.Ordering(field_expr, "DESCENDING") + instance = Ordering(field_expr, "DESCENDING") result = instance._to_pb() assert result.map_value.fields["expression"].field_reference_value == "field1" assert result.map_value.fields["direction"].string_value == "descending" -class TestExpr: - def test_ctor(self): - """ - Base class should be abstract - """ - with pytest.raises(TypeError): - expr.Expr() - - @pytest.mark.parametrize( - "method,args,result_cls", - [ - ("add", (2,), expr.Add), - ("subtract", (2,), expr.Subtract), - ("multiply", (2,), expr.Multiply), - ("divide", (2,), expr.Divide), - ("mod", (2,), expr.Mod), - ("logical_max", (2,), expr.LogicalMax), - ("logical_min", (2,), expr.LogicalMin), - ("eq", (2,), expr.Eq), - ("neq", (2,), expr.Neq), - ("lt", (2,), expr.Lt), - ("lte", (2,), expr.Lte), - ("gt", (2,), expr.Gt), - ("gte", (2,), expr.Gte), - ("in_any", ([None],), expr.In), - ("not_in_any", ([None],), expr.Not), - ("array_contains", (None,), expr.ArrayContains), - ("array_contains_all", ([None],), expr.ArrayContainsAll), - ("array_contains_any", ([None],), expr.ArrayContainsAny), - ("array_length", (), expr.ArrayLength), - ("array_reverse", (), expr.ArrayReverse), - ("is_nan", (), expr.IsNaN), - ("exists", (), expr.Exists), - ("sum", (), expr.Sum), - ("avg", (), expr.Avg), - ("count", (), expr.Count), - ("min", (), expr.Min), - ("max", (), expr.Max), - ("char_length", (), expr.CharLength), - ("byte_length", (), expr.ByteLength), - ("like", ("pattern",), expr.Like), - ("regex_contains", ("regex",), expr.RegexContains), - ("regex_matches", ("regex",), expr.RegexMatch), - ("str_contains", ("substring",), expr.StrContains), - ("starts_with", ("prefix",), expr.StartsWith), - ("ends_with", ("postfix",), expr.EndsWith), - ("str_concat", ("elem1", expr.Constant("elem2")), expr.StrConcat), - ("map_get", ("key",), expr.MapGet), - ("vector_length", (), expr.VectorLength), - ("timestamp_to_unix_micros", (), expr.TimestampToUnixMicros), - ("unix_micros_to_timestamp", (), expr.UnixMicrosToTimestamp), - ("timestamp_to_unix_millis", (), expr.TimestampToUnixMillis), - ("unix_millis_to_timestamp", (), expr.UnixMillisToTimestamp), - ("timestamp_to_unix_seconds", (), expr.TimestampToUnixSeconds), - ("unix_seconds_to_timestamp", (), expr.UnixSecondsToTimestamp), - ("timestamp_add", ("day", 1), expr.TimestampAdd), - ("timestamp_sub", ("hour", 2.5), expr.TimestampSub), - ("ascending", (), expr.Ordering), - ("descending", (), expr.Ordering), - ("as_", ("alias",), expr.ExprWithAlias), - ], - ) - @pytest.mark.parametrize( - "base_instance", - [ - expr.Constant(1), - expr.Function.add("1", 1), - expr.Field.of("test"), - expr.Constant(1).as_("one"), - ], - ) - def test_infix_call(self, method, args, result_cls, base_instance): - """ - many FilterCondition expressions support infix execution, and are exposed as methods on Expr. Test calling them - """ - method_ptr = getattr(base_instance, method) - - result = method_ptr(*args) - assert isinstance(result, result_cls) - if isinstance(result, expr.Function) and not method == "not_in_any": - assert result.params[0] == base_instance - - class TestConstant: @pytest.mark.parametrize( "input_val, to_pb_val", @@ -200,7 +122,7 @@ class TestConstant: ], ) def test_to_pb(self, input_val, to_pb_val): - instance = expr.Constant.of(input_val) + instance = Constant.of(input_val) assert instance._to_pb() == to_pb_val @pytest.mark.parametrize( @@ -226,25 +148,25 @@ def test_to_pb(self, input_val, to_pb_val): ], ) def test_repr(self, input_val, expected): - instance = expr.Constant.of(input_val) + instance = Constant.of(input_val) repr_string = repr(instance) assert repr_string == expected @pytest.mark.parametrize( "first,second,expected", [ - (expr.Constant.of(1), expr.Constant.of(2), False), - (expr.Constant.of(1), expr.Constant.of(1), True), - (expr.Constant.of(1), 1, True), - (expr.Constant.of(1), 2, False), - (expr.Constant.of("1"), 1, False), - (expr.Constant.of("1"), "1", True), - (expr.Constant.of(None), expr.Constant.of(0), False), - (expr.Constant.of(None), expr.Constant.of(None), True), - (expr.Constant.of([1, 2, 3]), expr.Constant.of([1, 2, 3]), True), - (expr.Constant.of([1, 2, 3]), expr.Constant.of([1, 2]), False), - (expr.Constant.of([1, 2, 3]), [1, 2, 3], True), - (expr.Constant.of([1, 2, 3]), object(), False), + (Constant.of(1), Constant.of(2), False), + (Constant.of(1), Constant.of(1), True), + (Constant.of(1), 1, True), + (Constant.of(1), 2, False), + (Constant.of("1"), 1, False), + (Constant.of("1"), "1", True), + (Constant.of(None), Constant.of(0), False), + (Constant.of(None), Constant.of(None), True), + (Constant.of([1, 2, 3]), Constant.of([1, 2, 3]), True), + (Constant.of([1, 2, 3]), Constant.of([1, 2]), False), + (Constant.of([1, 2, 3]), [1, 2, 3], True), + (Constant.of([1, 2, 3]), object(), False), ], ) def test_equality(self, first, second, expected): @@ -253,49 +175,49 @@ def test_equality(self, first, second, expected): class TestListOfExprs: def test_to_pb(self): - instance = expr.ListOfExprs([expr.Constant(1), expr.Constant(2)]) + instance = _ListOfExprs([Constant(1), Constant(2)]) result = instance._to_pb() assert len(result.array_value.values) == 2 assert result.array_value.values[0].integer_value == 1 assert result.array_value.values[1].integer_value == 2 def test_empty_to_pb(self): - instance = expr.ListOfExprs([]) + instance = _ListOfExprs([]) result = instance._to_pb() assert len(result.array_value.values) == 0 def test_repr(self): - instance = expr.ListOfExprs([expr.Constant(1), expr.Constant(2)]) + instance = _ListOfExprs([Constant(1), Constant(2)]) repr_string = repr(instance) - assert repr_string == "ListOfExprs([Constant.of(1), Constant.of(2)])" - empty_instance = expr.ListOfExprs([]) + assert repr_string == "[Constant.of(1), Constant.of(2)]" + empty_instance = _ListOfExprs([]) empty_repr_string = repr(empty_instance) - assert empty_repr_string == "ListOfExprs([])" + assert empty_repr_string == "[]" @pytest.mark.parametrize( "first,second,expected", [ - (expr.ListOfExprs([]), expr.ListOfExprs([]), True), - (expr.ListOfExprs([]), expr.ListOfExprs([expr.Constant(1)]), False), - (expr.ListOfExprs([expr.Constant(1)]), expr.ListOfExprs([]), False), + (_ListOfExprs([]), _ListOfExprs([]), True), + (_ListOfExprs([]), _ListOfExprs([Constant(1)]), False), + (_ListOfExprs([Constant(1)]), _ListOfExprs([]), False), ( - expr.ListOfExprs([expr.Constant(1)]), - expr.ListOfExprs([expr.Constant(1)]), + _ListOfExprs([Constant(1)]), + _ListOfExprs([Constant(1)]), True, ), ( - expr.ListOfExprs([expr.Constant(1)]), - expr.ListOfExprs([expr.Constant(2)]), + _ListOfExprs([Constant(1)]), + _ListOfExprs([Constant(2)]), False, ), ( - expr.ListOfExprs([expr.Constant(1), expr.Constant(2)]), - expr.ListOfExprs([expr.Constant(1), expr.Constant(2)]), + _ListOfExprs([Constant(1), Constant(2)]), + _ListOfExprs([Constant(1), Constant(2)]), True, ), - (expr.ListOfExprs([expr.Constant(1)]), [expr.Constant(1)], False), - (expr.ListOfExprs([expr.Constant(1)]), [1], False), - (expr.ListOfExprs([expr.Constant(1)]), object(), False), + (_ListOfExprs([Constant(1)]), [Constant(1)], False), + (_ListOfExprs([Constant(1)]), [1], False), + (_ListOfExprs([Constant(1)]), object(), False), ], ) def test_equality(self, first, second, expected): @@ -316,8 +238,8 @@ def test_ctor(self): def test_value_from_selectables(self): selectable_list = [ - expr.Field.of("field1"), - expr.Field.of("field2").as_("alias2"), + Field.of("field1"), + Field.of("field2").as_("alias2"), ] result = expr.Selectable._value_from_selectables(*selectable_list) assert len(result.map_value.fields) == 2 @@ -327,14 +249,14 @@ def test_value_from_selectables(self): @pytest.mark.parametrize( "first,second,expected", [ - (expr.Field.of("field1"), expr.Field.of("field1"), True), - (expr.Field.of("field1"), expr.Field.of("field2"), False), - (expr.Field.of(None), object(), False), - (expr.Field.of("f").as_("a"), expr.Field.of("f").as_("a"), True), - (expr.Field.of("one").as_("a"), expr.Field.of("two").as_("a"), False), - (expr.Field.of("f").as_("one"), expr.Field.of("f").as_("two"), False), - (expr.Field.of("field"), expr.Field.of("field").as_("alias"), False), - (expr.Field.of("field").as_("alias"), expr.Field.of("field"), False), + (Field.of("field1"), Field.of("field1"), True), + (Field.of("field1"), Field.of("field2"), False), + (Field.of(None), object(), False), + (Field.of("f").as_("a"), Field.of("f").as_("a"), True), + (Field.of("one").as_("a"), Field.of("two").as_("a"), False), + (Field.of("f").as_("one"), Field.of("f").as_("two"), False), + (Field.of("field"), Field.of("field").as_("alias"), False), + (Field.of("field").as_("alias"), Field.of("field"), False), ], ) def test_equality(self, first, second, expected): @@ -342,52 +264,79 @@ def test_equality(self, first, second, expected): class TestField: def test_repr(self): - instance = expr.Field.of("field1") + instance = Field.of("field1") repr_string = repr(instance) assert repr_string == "Field.of('field1')" def test_of(self): - instance = expr.Field.of("field1") + instance = Field.of("field1") assert instance.path == "field1" def test_to_pb(self): - instance = expr.Field.of("field1") + instance = Field.of("field1") result = instance._to_pb() assert result.field_reference_value == "field1" def test_to_map(self): - instance = expr.Field.of("field1") + instance = Field.of("field1") result = instance._to_map() assert result[0] == "field1" assert result[1] == Value(field_reference_value="field1") - class TestExprWithAlias: + class TestAliasedExpr: def test_repr(self): - instance = expr.Field.of("field1").as_("alias1") + instance = Field.of("field1").as_("alias1") assert repr(instance) == "Field.of('field1').as_('alias1')" def test_ctor(self): - arg = expr.Field.of("field1") + arg = Field.of("field1") alias = "alias1" - instance = expr.ExprWithAlias(arg, alias) + instance = expr.AliasedExpr(arg, alias) assert instance.expr == arg assert instance.alias == alias def test_to_pb(self): - arg = expr.Field.of("field1") + arg = Field.of("field1") alias = "alias1" - instance = expr.ExprWithAlias(arg, alias) + instance = expr.AliasedExpr(arg, alias) result = instance._to_pb() assert result.map_value.fields.get("alias1") == arg._to_pb() def test_to_map(self): - instance = expr.Field.of("field1").as_("alias1") + instance = Field.of("field1").as_("alias1") result = instance._to_map() assert result[0] == "alias1" assert result[1] == Value(field_reference_value="field1") + class TestAliasedAggregate: + def test_repr(self): + instance = Field.of("field1").maximum().as_("alias1") + assert repr(instance) == "Field.of('field1').maximum().as_('alias1')" + + def test_ctor(self): + arg = Expr.minimum("field1") + alias = "alias1" + instance = expr.AliasedAggregate(arg, alias) + assert instance.expr == arg + assert instance.alias == alias + + def test_to_pb(self): + arg = Field.of("field1").average() + alias = "alias1" + instance = expr.AliasedAggregate(arg, alias) + result = instance._to_pb() + assert result.map_value.fields.get("alias1") == arg._to_pb() + + def test_to_map(self): + arg = Field.of("field1").count() + alias = "alias1" + instance = expr.AliasedAggregate(arg, alias) + result = instance._to_map() + assert result[0] == "alias1" + assert result[1] == arg._to_pb() + -class TestFilterCondition: +class TestBooleanExpr: def test__from_query_filter_pb_composite_filter_or(self, mock_client): """ test composite OR filters @@ -415,17 +364,13 @@ def test__from_query_filter_pb_composite_filter_or(self, mock_client): composite_filter=composite_pb ) - result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + result = BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) # should include existance checks - expected_cond1 = expr.And( - expr.Exists(expr.Field.of("field1")), - expr.Eq(expr.Field.of("field1"), expr.Constant("val1")), - ) - expected_cond2 = expr.And( - expr.Exists(expr.Field.of("field2")), - expr.Eq(expr.Field.of("field2"), expr.Constant(None)), - ) + field1 = Field.of("field1") + field2 = Field.of("field2") + expected_cond1 = expr.And(field1.exists(), field1.equal(Constant("val1"))) + expected_cond2 = expr.And(field2.exists(), field2.equal(Constant(None))) expected = expr.Or(expected_cond1, expected_cond2) assert repr(result) == repr(expected) @@ -458,17 +403,13 @@ def test__from_query_filter_pb_composite_filter_and(self, mock_client): composite_filter=composite_pb ) - result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + result = BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) # should include existance checks - expected_cond1 = expr.And( - expr.Exists(expr.Field.of("field1")), - expr.Gt(expr.Field.of("field1"), expr.Constant(100)), - ) - expected_cond2 = expr.And( - expr.Exists(expr.Field.of("field2")), - expr.Lt(expr.Field.of("field2"), expr.Constant(200)), - ) + field1 = Field.of("field1") + field2 = Field.of("field2") + expected_cond1 = expr.And(field1.exists(), field1.greater_than(Constant(100))) + expected_cond2 = expr.And(field2.exists(), field2.less_than(Constant(200))) expected = expr.And(expected_cond1, expected_cond2) assert repr(result) == repr(expected) @@ -509,19 +450,15 @@ def test__from_query_filter_pb_composite_filter_nested(self, mock_client): composite_filter=outer_or_pb ) - result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + result = BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) - expected_cond1 = expr.And( - expr.Exists(expr.Field.of("field1")), - expr.Eq(expr.Field.of("field1"), expr.Constant("val1")), - ) - expected_cond2 = expr.And( - expr.Exists(expr.Field.of("field2")), - expr.Gt(expr.Field.of("field2"), expr.Constant(10)), - ) + field1 = Field.of("field1") + field2 = Field.of("field2") + field3 = Field.of("field3") + expected_cond1 = expr.And(field1.exists(), field1.equal(Constant("val1"))) + expected_cond2 = expr.And(field2.exists(), field2.greater_than(Constant(10))) expected_cond3 = expr.And( - expr.Exists(expr.Field.of("field3")), - expr.Not(expr.Eq(expr.Field.of("field3"), expr.Constant(None))), + field3.exists(), expr.Not(field3.equal(Constant(None))) ) expected_inner_and = expr.And(expected_cond2, expected_cond3) expected_outer_or = expr.Or(expected_cond1, expected_inner_and) @@ -546,23 +483,23 @@ def test__from_query_filter_pb_composite_filter_unknown_op(self, mock_client): ) with pytest.raises(TypeError, match="Unexpected CompositeFilter operator type"): - FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) @pytest.mark.parametrize( "op_enum, expected_expr_func", [ - (query_pb.StructuredQuery.UnaryFilter.Operator.IS_NAN, expr.IsNaN), + (query_pb.StructuredQuery.UnaryFilter.Operator.IS_NAN, Expr.is_nan), ( query_pb.StructuredQuery.UnaryFilter.Operator.IS_NOT_NAN, lambda f: expr.Not(f.is_nan()), ), ( query_pb.StructuredQuery.UnaryFilter.Operator.IS_NULL, - lambda f: f.eq(None), + lambda f: f.equal(None), ), ( query_pb.StructuredQuery.UnaryFilter.Operator.IS_NOT_NULL, - lambda f: expr.Not(f.eq(None)), + lambda f: expr.Not(f.equal(None)), ), ], ) @@ -579,12 +516,12 @@ def test__from_query_filter_pb_unary_filter( ) wrapped_filter_pb = query_pb.StructuredQuery.Filter(unary_filter=filter_pb) - result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + result = BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) - field_expr_inst = expr.Field.of(field_path) + field_expr_inst = Field.of(field_path) expected_condition = expected_expr_func(field_expr_inst) # should include existance checks - expected = expr.And(expr.Exists(field_expr_inst), expected_condition) + expected = expr.And(field_expr_inst.exists(), expected_condition) assert repr(result) == repr(expected) @@ -600,40 +537,56 @@ def test__from_query_filter_pb_unary_filter_unknown_op(self, mock_client): wrapped_filter_pb = query_pb.StructuredQuery.Filter(unary_filter=filter_pb) with pytest.raises(TypeError, match="Unexpected UnaryFilter operator type"): - FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) @pytest.mark.parametrize( "op_enum, value, expected_expr_func", [ - (query_pb.StructuredQuery.FieldFilter.Operator.LESS_THAN, 10, expr.Lt), + ( + query_pb.StructuredQuery.FieldFilter.Operator.LESS_THAN, + 10, + Expr.less_than, + ), ( query_pb.StructuredQuery.FieldFilter.Operator.LESS_THAN_OR_EQUAL, 10, - expr.Lte, + Expr.less_than_or_equal, + ), + ( + query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN, + 10, + Expr.greater_than, ), - (query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN, 10, expr.Gt), ( query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN_OR_EQUAL, 10, - expr.Gte, + Expr.greater_than_or_equal, + ), + (query_pb.StructuredQuery.FieldFilter.Operator.EQUAL, 10, Expr.equal), + ( + query_pb.StructuredQuery.FieldFilter.Operator.NOT_EQUAL, + 10, + Expr.not_equal, ), - (query_pb.StructuredQuery.FieldFilter.Operator.EQUAL, 10, expr.Eq), - (query_pb.StructuredQuery.FieldFilter.Operator.NOT_EQUAL, 10, expr.Neq), ( query_pb.StructuredQuery.FieldFilter.Operator.ARRAY_CONTAINS, 10, - expr.ArrayContains, + Expr.array_contains, ), ( query_pb.StructuredQuery.FieldFilter.Operator.ARRAY_CONTAINS_ANY, [10, 20], - expr.ArrayContainsAny, + Expr.array_contains_any, + ), + ( + query_pb.StructuredQuery.FieldFilter.Operator.IN, + [10, 20], + Expr.equal_any, ), - (query_pb.StructuredQuery.FieldFilter.Operator.IN, [10, 20], expr.In), ( query_pb.StructuredQuery.FieldFilter.Operator.NOT_IN, [10, 20], - lambda f, v: expr.Not(f.in_any(v)), + Expr.not_equal_any, ), ], ) @@ -652,18 +605,16 @@ def test__from_query_filter_pb_field_filter( ) wrapped_filter_pb = query_pb.StructuredQuery.Filter(field_filter=filter_pb) - result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + result = BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) - field_expr = expr.Field.of(field_path) + field_expr = Field.of(field_path) # convert values into constants value = ( - [expr.Constant(e) for e in value] - if isinstance(value, list) - else expr.Constant(value) + [Constant(e) for e in value] if isinstance(value, list) else Constant(value) ) expected_condition = expected_expr_func(field_expr, value) # should include existance checks - expected = expr.And(expr.Exists(field_expr), expected_condition) + expected = expr.And(field_expr.exists(), expected_condition) assert repr(result) == repr(expected) @@ -681,7 +632,7 @@ def test__from_query_filter_pb_field_filter_unknown_op(self, mock_client): wrapped_filter_pb = query_pb.StructuredQuery.Filter(field_filter=filter_pb) with pytest.raises(TypeError, match="Unexpected FieldFilter operator type"): - FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) def test__from_query_filter_pb_unknown_filter_type(self, mock_client): """ @@ -689,26 +640,64 @@ def test__from_query_filter_pb_unknown_filter_type(self, mock_client): """ # Test with an unexpected protobuf type with pytest.raises(TypeError, match="Unexpected filter type"): - FilterCondition._from_query_filter_pb(document_pb.Value(), mock_client) + BooleanExpr._from_query_filter_pb(document_pb.Value(), mock_client) -class TestFilterConditionClasses: +class TestExpressionMethods: """ - contains test methods for each Expr class that derives from FilterCondition + contains test methods for each Expr method """ + @pytest.mark.parametrize( + "first,second,expected", + [ + ( + Field.of("a").char_length(), + Field.of("a").char_length(), + True, + ), + ( + Field.of("a").char_length(), + Field.of("b").char_length(), + False, + ), + ( + Field.of("a").char_length(), + Field.of("a").byte_length(), + False, + ), + ( + Field.of("a").char_length(), + Field.of("b").byte_length(), + False, + ), + ( + Constant.of("").byte_length(), + Field.of("").byte_length(), + False, + ), + (Field.of("").byte_length(), Field.of("").byte_length(), True), + ], + ) + def test_equality(self, first, second, expected): + assert (first == second) is expected + def _make_arg(self, name="Mock"): - arg = mock.Mock() - arg.__repr__ = lambda x: name + class MockExpr(Constant): + def __repr__(self): + return self.value + + arg = MockExpr(name) return arg def test_and(self): arg1 = self._make_arg() arg2 = self._make_arg() - instance = expr.And(arg1, arg2) + arg3 = self._make_arg() + instance = expr.And(arg1, arg2, arg3) assert instance.name == "and" - assert instance.params == [arg1, arg2] - assert repr(instance) == "And(Mock, Mock)" + assert instance.params == [arg1, arg2, arg3] + assert repr(instance) == "And(Mock, Mock, Mock)" def test_or(self): arg1 = self._make_arg("Arg1") @@ -716,102 +705,134 @@ def test_or(self): instance = expr.Or(arg1, arg2) assert instance.name == "or" assert instance.params == [arg1, arg2] - assert repr(instance) == "Arg1.or(Arg2)" + assert repr(instance) == "Or(Arg1, Arg2)" def test_array_contains(self): arg1 = self._make_arg("ArrayField") arg2 = self._make_arg("Element") - instance = expr.ArrayContains(arg1, arg2) + instance = Expr.array_contains(arg1, arg2) assert instance.name == "array_contains" assert instance.params == [arg1, arg2] assert repr(instance) == "ArrayField.array_contains(Element)" + infix_instance = arg1.array_contains(arg2) + assert infix_instance == instance def test_array_contains_any(self): arg1 = self._make_arg("ArrayField") arg2 = self._make_arg("Element1") arg3 = self._make_arg("Element2") - instance = expr.ArrayContainsAny(arg1, [arg2, arg3]) + instance = Expr.array_contains_any(arg1, [arg2, arg3]) assert instance.name == "array_contains_any" - assert isinstance(instance.params[1], ListOfExprs) + assert isinstance(instance.params[1], _ListOfExprs) assert instance.params[0] == arg1 assert instance.params[1].exprs == [arg2, arg3] - assert ( - repr(instance) - == "ArrayField.array_contains_any(ListOfExprs([Element1, Element2]))" - ) + assert repr(instance) == "ArrayField.array_contains_any([Element1, Element2])" + infix_instance = arg1.array_contains_any([arg2, arg3]) + assert infix_instance == instance def test_exists(self): arg1 = self._make_arg("Field") - instance = expr.Exists(arg1) + instance = Expr.exists(arg1) assert instance.name == "exists" assert instance.params == [arg1] assert repr(instance) == "Field.exists()" + infix_instance = arg1.exists() + assert infix_instance == instance - def test_eq(self): + def test_equal(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.Eq(arg1, arg2) - assert instance.name == "eq" + instance = Expr.equal(arg1, arg2) + assert instance.name == "equal" assert instance.params == [arg1, arg2] - assert repr(instance) == "Left.eq(Right)" + assert repr(instance) == "Left.equal(Right)" + infix_instance = arg1.equal(arg2) + assert infix_instance == instance - def test_gte(self): + def test_greater_than_or_equal(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.Gte(arg1, arg2) - assert instance.name == "gte" + instance = Expr.greater_than_or_equal(arg1, arg2) + assert instance.name == "greater_than_or_equal" assert instance.params == [arg1, arg2] - assert repr(instance) == "Left.gte(Right)" + assert repr(instance) == "Left.greater_than_or_equal(Right)" + infix_instance = arg1.greater_than_or_equal(arg2) + assert infix_instance == instance - def test_gt(self): + def test_greater_than(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.Gt(arg1, arg2) - assert instance.name == "gt" + instance = Expr.greater_than(arg1, arg2) + assert instance.name == "greater_than" assert instance.params == [arg1, arg2] - assert repr(instance) == "Left.gt(Right)" + assert repr(instance) == "Left.greater_than(Right)" + infix_instance = arg1.greater_than(arg2) + assert infix_instance == instance - def test_lte(self): + def test_less_than_or_equal(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.Lte(arg1, arg2) - assert instance.name == "lte" + instance = Expr.less_than_or_equal(arg1, arg2) + assert instance.name == "less_than_or_equal" assert instance.params == [arg1, arg2] - assert repr(instance) == "Left.lte(Right)" + assert repr(instance) == "Left.less_than_or_equal(Right)" + infix_instance = arg1.less_than_or_equal(arg2) + assert infix_instance == instance - def test_lt(self): + def test_less_than(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.Lt(arg1, arg2) - assert instance.name == "lt" + instance = Expr.less_than(arg1, arg2) + assert instance.name == "less_than" assert instance.params == [arg1, arg2] - assert repr(instance) == "Left.lt(Right)" + assert repr(instance) == "Left.less_than(Right)" + infix_instance = arg1.less_than(arg2) + assert infix_instance == instance - def test_neq(self): + def test_not_equal(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.Neq(arg1, arg2) - assert instance.name == "neq" + instance = Expr.not_equal(arg1, arg2) + assert instance.name == "not_equal" assert instance.params == [arg1, arg2] - assert repr(instance) == "Left.neq(Right)" + assert repr(instance) == "Left.not_equal(Right)" + infix_instance = arg1.not_equal(arg2) + assert infix_instance == instance + + def test_equal_any(self): + arg1 = self._make_arg("Field") + arg2 = self._make_arg("Value1") + arg3 = self._make_arg("Value2") + instance = Expr.equal_any(arg1, [arg2, arg3]) + assert instance.name == "equal_any" + assert isinstance(instance.params[1], _ListOfExprs) + assert instance.params[0] == arg1 + assert instance.params[1].exprs == [arg2, arg3] + assert repr(instance) == "Field.equal_any([Value1, Value2])" + infix_instance = arg1.equal_any([arg2, arg3]) + assert infix_instance == instance - def test_in(self): + def test_not_equal_any(self): arg1 = self._make_arg("Field") arg2 = self._make_arg("Value1") arg3 = self._make_arg("Value2") - instance = expr.In(arg1, [arg2, arg3]) - assert instance.name == "in" - assert isinstance(instance.params[1], ListOfExprs) + instance = Expr.not_equal_any(arg1, [arg2, arg3]) + assert instance.name == "not_equal_any" + assert isinstance(instance.params[1], _ListOfExprs) assert instance.params[0] == arg1 assert instance.params[1].exprs == [arg2, arg3] - assert repr(instance) == "Field.in_any(ListOfExprs([Value1, Value2]))" + assert repr(instance) == "Field.not_equal_any([Value1, Value2])" + infix_instance = arg1.not_equal_any([arg2, arg3]) + assert infix_instance == instance def test_is_nan(self): arg1 = self._make_arg("Value") - instance = expr.IsNaN(arg1) + instance = Expr.is_nan(arg1) assert instance.name == "is_nan" assert instance.params == [arg1] assert repr(instance) == "Value.is_nan()" + infix_instance = arg1.is_nan() + assert infix_instance == instance def test_not(self): arg1 = self._make_arg("Condition") @@ -824,72 +845,83 @@ def test_array_contains_all(self): arg1 = self._make_arg("ArrayField") arg2 = self._make_arg("Element1") arg3 = self._make_arg("Element2") - instance = expr.ArrayContainsAll(arg1, [arg2, arg3]) + instance = Expr.array_contains_all(arg1, [arg2, arg3]) assert instance.name == "array_contains_all" - assert isinstance(instance.params[1], ListOfExprs) + assert isinstance(instance.params[1], _ListOfExprs) assert instance.params[0] == arg1 assert instance.params[1].exprs == [arg2, arg3] - assert ( - repr(instance) - == "ArrayField.array_contains_all(ListOfExprs([Element1, Element2]))" - ) + assert repr(instance) == "ArrayField.array_contains_all([Element1, Element2])" + infix_instance = arg1.array_contains_all([arg2, arg3]) + assert infix_instance == instance def test_ends_with(self): arg1 = self._make_arg("Expr") arg2 = self._make_arg("Postfix") - instance = expr.EndsWith(arg1, arg2) + instance = Expr.ends_with(arg1, arg2) assert instance.name == "ends_with" assert instance.params == [arg1, arg2] assert repr(instance) == "Expr.ends_with(Postfix)" + infix_instance = arg1.ends_with(arg2) + assert infix_instance == instance - def test_if(self): + def test_conditional(self): arg1 = self._make_arg("Condition") - arg2 = self._make_arg("TrueExpr") - arg3 = self._make_arg("FalseExpr") - instance = expr.If(arg1, arg2, arg3) - assert instance.name == "if" + arg2 = self._make_arg("ThenExpr") + arg3 = self._make_arg("ElseExpr") + instance = expr.Conditional(arg1, arg2, arg3) + assert instance.name == "conditional" assert instance.params == [arg1, arg2, arg3] - assert repr(instance) == "If(Condition, TrueExpr, FalseExpr)" + assert repr(instance) == "Conditional(Condition, ThenExpr, ElseExpr)" def test_like(self): arg1 = self._make_arg("Expr") arg2 = self._make_arg("Pattern") - instance = expr.Like(arg1, arg2) + instance = Expr.like(arg1, arg2) assert instance.name == "like" assert instance.params == [arg1, arg2] assert repr(instance) == "Expr.like(Pattern)" + infix_instance = arg1.like(arg2) + assert infix_instance == instance def test_regex_contains(self): arg1 = self._make_arg("Expr") arg2 = self._make_arg("Regex") - instance = expr.RegexContains(arg1, arg2) + instance = Expr.regex_contains(arg1, arg2) assert instance.name == "regex_contains" assert instance.params == [arg1, arg2] assert repr(instance) == "Expr.regex_contains(Regex)" + infix_instance = arg1.regex_contains(arg2) + assert infix_instance == instance def test_regex_match(self): arg1 = self._make_arg("Expr") arg2 = self._make_arg("Regex") - instance = expr.RegexMatch(arg1, arg2) + instance = Expr.regex_match(arg1, arg2) assert instance.name == "regex_match" assert instance.params == [arg1, arg2] assert repr(instance) == "Expr.regex_match(Regex)" + infix_instance = arg1.regex_match(arg2) + assert infix_instance == instance def test_starts_with(self): arg1 = self._make_arg("Expr") arg2 = self._make_arg("Prefix") - instance = expr.StartsWith(arg1, arg2) + instance = Expr.starts_with(arg1, arg2) assert instance.name == "starts_with" assert instance.params == [arg1, arg2] assert repr(instance) == "Expr.starts_with(Prefix)" + infix_instance = arg1.starts_with(arg2) + assert infix_instance == instance - def test_str_contains(self): + def test_string_contains(self): arg1 = self._make_arg("Expr") arg2 = self._make_arg("Substring") - instance = expr.StrContains(arg1, arg2) - assert instance.name == "str_contains" + instance = Expr.string_contains(arg1, arg2) + assert instance.name == "string_contains" assert instance.params == [arg1, arg2] - assert repr(instance) == "Expr.str_contains(Substring)" + assert repr(instance) == "Expr.string_contains(Substring)" + infix_instance = arg1.string_contains(arg2) + assert infix_instance == instance def test_xor(self): arg1 = self._make_arg("Condition1") @@ -899,333 +931,268 @@ def test_xor(self): assert instance.params == [arg1, arg2] assert repr(instance) == "Xor(Condition1, Condition2)" - -class TestFunctionClasses: - """ - contains test methods for each Expr class that derives from Function - """ - - @pytest.mark.parametrize( - "method,args,result_cls", - [ - ("add", ("field", 2), expr.Add), - ("subtract", ("field", 2), expr.Subtract), - ("multiply", ("field", 2), expr.Multiply), - ("divide", ("field", 2), expr.Divide), - ("mod", ("field", 2), expr.Mod), - ("logical_max", ("field", 2), expr.LogicalMax), - ("logical_min", ("field", 2), expr.LogicalMin), - ("eq", ("field", 2), expr.Eq), - ("neq", ("field", 2), expr.Neq), - ("lt", ("field", 2), expr.Lt), - ("lte", ("field", 2), expr.Lte), - ("gt", ("field", 2), expr.Gt), - ("gte", ("field", 2), expr.Gte), - ("in_any", ("field", [None]), expr.In), - ("not_in_any", ("field", [None]), expr.Not), - ("array_contains", ("field", None), expr.ArrayContains), - ("array_contains_all", ("field", [None]), expr.ArrayContainsAll), - ("array_contains_any", ("field", [None]), expr.ArrayContainsAny), - ("array_length", ("field",), expr.ArrayLength), - ("array_reverse", ("field",), expr.ArrayReverse), - ("is_nan", ("field",), expr.IsNaN), - ("exists", ("field",), expr.Exists), - ("sum", ("field",), expr.Sum), - ("avg", ("field",), expr.Avg), - ("count", ("field",), expr.Count), - ("count", (), expr.Count), - ("min", ("field",), expr.Min), - ("max", ("field",), expr.Max), - ("char_length", ("field",), expr.CharLength), - ("byte_length", ("field",), expr.ByteLength), - ("like", ("field", "pattern"), expr.Like), - ("regex_contains", ("field", "regex"), expr.RegexContains), - ("regex_matches", ("field", "regex"), expr.RegexMatch), - ("str_contains", ("field", "substring"), expr.StrContains), - ("starts_with", ("field", "prefix"), expr.StartsWith), - ("ends_with", ("field", "postfix"), expr.EndsWith), - ("str_concat", ("field", "elem1", "elem2"), expr.StrConcat), - ("map_get", ("field", "key"), expr.MapGet), - ("vector_length", ("field",), expr.VectorLength), - ("timestamp_to_unix_micros", ("field",), expr.TimestampToUnixMicros), - ("unix_micros_to_timestamp", ("field",), expr.UnixMicrosToTimestamp), - ("timestamp_to_unix_millis", ("field",), expr.TimestampToUnixMillis), - ("unix_millis_to_timestamp", ("field",), expr.UnixMillisToTimestamp), - ("timestamp_to_unix_seconds", ("field",), expr.TimestampToUnixSeconds), - ("unix_seconds_to_timestamp", ("field",), expr.UnixSecondsToTimestamp), - ("timestamp_add", ("field", "day", 1), expr.TimestampAdd), - ("timestamp_sub", ("field", "hour", 2.5), expr.TimestampSub), - ], - ) - def test_function_builder(self, method, args, result_cls): - """ - Test building functions using methods exposed on base Function class. - """ - method_ptr = getattr(expr.Function, method) - - result = method_ptr(*args) - assert isinstance(result, result_cls) - - @pytest.mark.parametrize( - "first,second,expected", - [ - (expr.ArrayElement(), expr.ArrayElement(), True), - (expr.ArrayElement(), expr.CharLength(1), False), - (expr.ArrayElement(), object(), False), - (expr.ArrayElement(), None, False), - (expr.CharLength(1), expr.ArrayElement(), False), - (expr.CharLength(1), expr.CharLength(2), False), - (expr.CharLength(1), expr.CharLength(1), True), - (expr.CharLength(1), expr.ByteLength(1), False), - ], - ) - def test_equality(self, first, second, expected): - assert (first == second) is expected - - def _make_arg(self, name="Mock"): - arg = mock.Mock() - arg.__repr__ = lambda x: name - return arg - def test_divide(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.Divide(arg1, arg2) + instance = Expr.divide(arg1, arg2) assert instance.name == "divide" assert instance.params == [arg1, arg2] - assert repr(instance) == "Divide(Left, Right)" + assert repr(instance) == "Left.divide(Right)" + infix_instance = arg1.divide(arg2) + assert infix_instance == instance - def test_logical_max(self): + def test_logical_maximum(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.LogicalMax(arg1, arg2) - assert instance.name == "logical_maximum" + instance = Expr.logical_maximum(arg1, arg2) + assert instance.name == "maximum" assert instance.params == [arg1, arg2] - assert repr(instance) == "LogicalMax(Left, Right)" + assert repr(instance) == "Left.logical_maximum(Right)" + infix_instance = arg1.logical_maximum(arg2) + assert infix_instance == instance - def test_logical_min(self): + def test_logical_minimum(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.LogicalMin(arg1, arg2) - assert instance.name == "logical_minimum" + instance = Expr.logical_minimum(arg1, arg2) + assert instance.name == "minimum" assert instance.params == [arg1, arg2] - assert repr(instance) == "LogicalMin(Left, Right)" + assert repr(instance) == "Left.logical_minimum(Right)" + infix_instance = arg1.logical_minimum(arg2) + assert infix_instance == instance def test_map_get(self): arg1 = self._make_arg("Map") - arg2 = expr.Constant("Key") - instance = expr.MapGet(arg1, arg2) + arg2 = "key" + instance = Expr.map_get(arg1, arg2) assert instance.name == "map_get" - assert instance.params == [arg1, arg2] - assert repr(instance) == "MapGet(Map, Constant.of('Key'))" + assert instance.params == [arg1, Constant.of(arg2)] + assert repr(instance) == "Map.map_get(Constant.of('key'))" + infix_instance = arg1.map_get(Constant.of(arg2)) + assert infix_instance == instance def test_mod(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.Mod(arg1, arg2) + instance = Expr.mod(arg1, arg2) assert instance.name == "mod" assert instance.params == [arg1, arg2] - assert repr(instance) == "Mod(Left, Right)" + assert repr(instance) == "Left.mod(Right)" + infix_instance = arg1.mod(arg2) + assert infix_instance == instance def test_multiply(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.Multiply(arg1, arg2) + instance = Expr.multiply(arg1, arg2) assert instance.name == "multiply" assert instance.params == [arg1, arg2] - assert repr(instance) == "Multiply(Left, Right)" - - def test_parent(self): - arg1 = self._make_arg("Value") - instance = expr.Parent(arg1) - assert instance.name == "parent" - assert instance.params == [arg1] - assert repr(instance) == "Parent(Value)" + assert repr(instance) == "Left.multiply(Right)" + infix_instance = arg1.multiply(arg2) + assert infix_instance == instance - def test_str_concat(self): + def test_string_concat(self): arg1 = self._make_arg("Str1") arg2 = self._make_arg("Str2") - instance = expr.StrConcat(arg1, arg2) - assert instance.name == "str_concat" - assert instance.params == [arg1, arg2] - assert repr(instance) == "StrConcat(Str1, Str2)" + arg3 = self._make_arg("Str3") + instance = Expr.string_concat(arg1, arg2, arg3) + assert instance.name == "string_concat" + assert instance.params == [arg1, arg2, arg3] + assert repr(instance) == "Str1.string_concat(Str2, Str3)" + infix_instance = arg1.string_concat(arg2, arg3) + assert infix_instance == instance def test_subtract(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.Subtract(arg1, arg2) + instance = Expr.subtract(arg1, arg2) assert instance.name == "subtract" assert instance.params == [arg1, arg2] - assert repr(instance) == "Subtract(Left, Right)" + assert repr(instance) == "Left.subtract(Right)" + infix_instance = arg1.subtract(arg2) + assert infix_instance == instance def test_timestamp_add(self): arg1 = self._make_arg("Timestamp") arg2 = self._make_arg("Unit") arg3 = self._make_arg("Amount") - instance = expr.TimestampAdd(arg1, arg2, arg3) + instance = Expr.timestamp_add(arg1, arg2, arg3) assert instance.name == "timestamp_add" assert instance.params == [arg1, arg2, arg3] - assert repr(instance) == "TimestampAdd(Timestamp, Unit, Amount)" + assert repr(instance) == "Timestamp.timestamp_add(Unit, Amount)" + infix_instance = arg1.timestamp_add(arg2, arg3) + assert infix_instance == instance - def test_timestamp_sub(self): + def test_timestamp_subtract(self): arg1 = self._make_arg("Timestamp") arg2 = self._make_arg("Unit") arg3 = self._make_arg("Amount") - instance = expr.TimestampSub(arg1, arg2, arg3) - assert instance.name == "timestamp_sub" + instance = Expr.timestamp_subtract(arg1, arg2, arg3) + assert instance.name == "timestamp_subtract" assert instance.params == [arg1, arg2, arg3] - assert repr(instance) == "TimestampSub(Timestamp, Unit, Amount)" + assert repr(instance) == "Timestamp.timestamp_subtract(Unit, Amount)" + infix_instance = arg1.timestamp_subtract(arg2, arg3) + assert infix_instance == instance def test_timestamp_to_unix_micros(self): arg1 = self._make_arg("Input") - instance = expr.TimestampToUnixMicros(arg1) + instance = Expr.timestamp_to_unix_micros(arg1) assert instance.name == "timestamp_to_unix_micros" assert instance.params == [arg1] - assert repr(instance) == "TimestampToUnixMicros(Input)" + assert repr(instance) == "Input.timestamp_to_unix_micros()" + infix_instance = arg1.timestamp_to_unix_micros() + assert infix_instance == instance def test_timestamp_to_unix_millis(self): arg1 = self._make_arg("Input") - instance = expr.TimestampToUnixMillis(arg1) + instance = Expr.timestamp_to_unix_millis(arg1) assert instance.name == "timestamp_to_unix_millis" assert instance.params == [arg1] - assert repr(instance) == "TimestampToUnixMillis(Input)" + assert repr(instance) == "Input.timestamp_to_unix_millis()" + infix_instance = arg1.timestamp_to_unix_millis() + assert infix_instance == instance def test_timestamp_to_unix_seconds(self): arg1 = self._make_arg("Input") - instance = expr.TimestampToUnixSeconds(arg1) + instance = Expr.timestamp_to_unix_seconds(arg1) assert instance.name == "timestamp_to_unix_seconds" assert instance.params == [arg1] - assert repr(instance) == "TimestampToUnixSeconds(Input)" + assert repr(instance) == "Input.timestamp_to_unix_seconds()" + infix_instance = arg1.timestamp_to_unix_seconds() + assert infix_instance == instance def test_unix_micros_to_timestamp(self): arg1 = self._make_arg("Input") - instance = expr.UnixMicrosToTimestamp(arg1) + instance = Expr.unix_micros_to_timestamp(arg1) assert instance.name == "unix_micros_to_timestamp" assert instance.params == [arg1] - assert repr(instance) == "UnixMicrosToTimestamp(Input)" + assert repr(instance) == "Input.unix_micros_to_timestamp()" + infix_instance = arg1.unix_micros_to_timestamp() + assert infix_instance == instance def test_unix_millis_to_timestamp(self): arg1 = self._make_arg("Input") - instance = expr.UnixMillisToTimestamp(arg1) + instance = Expr.unix_millis_to_timestamp(arg1) assert instance.name == "unix_millis_to_timestamp" assert instance.params == [arg1] - assert repr(instance) == "UnixMillisToTimestamp(Input)" + assert repr(instance) == "Input.unix_millis_to_timestamp()" + infix_instance = arg1.unix_millis_to_timestamp() + assert infix_instance == instance def test_unix_seconds_to_timestamp(self): arg1 = self._make_arg("Input") - instance = expr.UnixSecondsToTimestamp(arg1) + instance = Expr.unix_seconds_to_timestamp(arg1) assert instance.name == "unix_seconds_to_timestamp" assert instance.params == [arg1] - assert repr(instance) == "UnixSecondsToTimestamp(Input)" + assert repr(instance) == "Input.unix_seconds_to_timestamp()" + infix_instance = arg1.unix_seconds_to_timestamp() + assert infix_instance == instance def test_vector_length(self): arg1 = self._make_arg("Array") - instance = expr.VectorLength(arg1) + instance = Expr.vector_length(arg1) assert instance.name == "vector_length" assert instance.params == [arg1] - assert repr(instance) == "VectorLength(Array)" + assert repr(instance) == "Array.vector_length()" + infix_instance = arg1.vector_length() + assert infix_instance == instance def test_add(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.Add(arg1, arg2) + instance = Expr.add(arg1, arg2) assert instance.name == "add" assert instance.params == [arg1, arg2] - assert repr(instance) == "Add(Left, Right)" - - def test_array_element(self): - instance = expr.ArrayElement() - assert instance.name == "array_element" - assert instance.params == [] - assert repr(instance) == "ArrayElement()" - - def test_array_filter(self): - arg1 = self._make_arg("Array") - arg2 = self._make_arg("FilterCond") - instance = expr.ArrayFilter(arg1, arg2) - assert instance.name == "array_filter" - assert instance.params == [arg1, arg2] - assert repr(instance) == "ArrayFilter(Array, FilterCond)" + assert repr(instance) == "Left.add(Right)" + infix_instance = arg1.add(arg2) + assert infix_instance == instance def test_array_length(self): arg1 = self._make_arg("Array") - instance = expr.ArrayLength(arg1) + instance = Expr.array_length(arg1) assert instance.name == "array_length" assert instance.params == [arg1] - assert repr(instance) == "ArrayLength(Array)" + assert repr(instance) == "Array.array_length()" + infix_instance = arg1.array_length() + assert infix_instance == instance def test_array_reverse(self): arg1 = self._make_arg("Array") - instance = expr.ArrayReverse(arg1) + instance = Expr.array_reverse(arg1) assert instance.name == "array_reverse" assert instance.params == [arg1] - assert repr(instance) == "ArrayReverse(Array)" - - def test_array_transform(self): - arg1 = self._make_arg("Array") - arg2 = self._make_arg("TransformFunc") - instance = expr.ArrayTransform(arg1, arg2) - assert instance.name == "array_transform" - assert instance.params == [arg1, arg2] - assert repr(instance) == "ArrayTransform(Array, TransformFunc)" + assert repr(instance) == "Array.array_reverse()" + infix_instance = arg1.array_reverse() + assert infix_instance == instance def test_byte_length(self): arg1 = self._make_arg("Expr") - instance = expr.ByteLength(arg1) + instance = Expr.byte_length(arg1) assert instance.name == "byte_length" assert instance.params == [arg1] - assert repr(instance) == "ByteLength(Expr)" + assert repr(instance) == "Expr.byte_length()" + infix_instance = arg1.byte_length() + assert infix_instance == instance def test_char_length(self): arg1 = self._make_arg("Expr") - instance = expr.CharLength(arg1) + instance = Expr.char_length(arg1) assert instance.name == "char_length" assert instance.params == [arg1] - assert repr(instance) == "CharLength(Expr)" + assert repr(instance) == "Expr.char_length()" + infix_instance = arg1.char_length() + assert infix_instance == instance def test_collection_id(self): arg1 = self._make_arg("Value") - instance = expr.CollectionId(arg1) + instance = Expr.collection_id(arg1) assert instance.name == "collection_id" assert instance.params == [arg1] - assert repr(instance) == "CollectionId(Value)" + assert repr(instance) == "Value.collection_id()" + infix_instance = arg1.collection_id() + assert infix_instance == instance def test_sum(self): arg1 = self._make_arg("Value") - instance = expr.Sum(arg1) + instance = Expr.sum(arg1) assert instance.name == "sum" assert instance.params == [arg1] - assert repr(instance) == "Sum(Value)" + assert repr(instance) == "Value.sum()" + infix_instance = arg1.sum() + assert infix_instance == instance - def test_avg(self): + def test_average(self): arg1 = self._make_arg("Value") - instance = expr.Avg(arg1) - assert instance.name == "avg" + instance = Expr.average(arg1) + assert instance.name == "average" assert instance.params == [arg1] - assert repr(instance) == "Avg(Value)" + assert repr(instance) == "Value.average()" + infix_instance = arg1.average() + assert infix_instance == instance def test_count(self): arg1 = self._make_arg("Value") - instance = expr.Count(arg1) + instance = Expr.count(arg1) assert instance.name == "count" assert instance.params == [arg1] - assert repr(instance) == "Count(Value)" - - def test_count_empty(self): - instance = expr.Count() - assert instance.params == [] - assert repr(instance) == "Count()" + assert repr(instance) == "Value.count()" + infix_instance = arg1.count() + assert infix_instance == instance - def test_min(self): + def test_minimum(self): arg1 = self._make_arg("Value") - instance = expr.Min(arg1) + instance = Expr.minimum(arg1) assert instance.name == "minimum" assert instance.params == [arg1] - assert repr(instance) == "Min(Value)" + assert repr(instance) == "Value.minimum()" + infix_instance = arg1.minimum() + assert infix_instance == instance - def test_max(self): + def test_maximum(self): arg1 = self._make_arg("Value") - instance = expr.Max(arg1) + instance = Expr.maximum(arg1) assert instance.name == "maximum" assert instance.params == [arg1] - assert repr(instance) == "Max(Value)" + assert repr(instance) == "Value.maximum()" + infix_instance = arg1.maximum() + assert infix_instance == instance diff --git a/tests/unit/v1/test_pipeline_stages.py b/tests/unit/v1/test_pipeline_stages.py index e67a4ca3a..d5b36e56c 100644 --- a/tests/unit/v1/test_pipeline_stages.py +++ b/tests/unit/v1/test_pipeline_stages.py @@ -21,8 +21,6 @@ Constant, Field, Ordering, - Sum, - Count, ) from google.cloud.firestore_v1.types.document import Value from google.cloud.firestore_v1._helpers import GeoPoint @@ -79,8 +77,8 @@ def _make_one(self, *args, **kwargs): def test_ctor_positional(self): """test with only positional arguments""" - sum_total = Sum(Field.of("total")).as_("sum_total") - avg_price = Field.of("price").avg().as_("avg_price") + sum_total = Field.of("total").sum().as_("sum_total") + avg_price = Field.of("price").average().as_("avg_price") instance = self._make_one(sum_total, avg_price) assert list(instance.accumulators) == [sum_total, avg_price] assert len(instance.groups) == 0 @@ -88,8 +86,8 @@ def test_ctor_positional(self): def test_ctor_keyword(self): """test with only keyword arguments""" - sum_total = Sum(Field.of("total")).as_("sum_total") - avg_price = Field.of("price").avg().as_("avg_price") + sum_total = Field.of("total").sum().as_("sum_total") + avg_price = Field.of("price").average().as_("avg_price") group_category = Field.of("category") instance = self._make_one( accumulators=[avg_price, sum_total], groups=[group_category, "city"] @@ -103,24 +101,24 @@ def test_ctor_keyword(self): def test_ctor_combined(self): """test with a mix of arguments""" - sum_total = Sum(Field.of("total")).as_("sum_total") - avg_price = Field.of("price").avg().as_("avg_price") - count = Count(Field.of("total")).as_("count") + sum_total = Field.of("total").sum().as_("sum_total") + avg_price = Field.of("price").average().as_("avg_price") + count = Field.of("total").count().as_("count") with pytest.raises(ValueError): self._make_one(sum_total, accumulators=[avg_price, count]) def test_repr(self): - sum_total = Sum(Field.of("total")).as_("sum_total") + sum_total = Field.of("total").sum().as_("sum_total") group_category = Field.of("category") instance = self._make_one(sum_total, groups=[group_category]) repr_str = repr(instance) assert ( repr_str - == "Aggregate(Sum(Field.of('total')).as_('sum_total'), groups=[Field.of('category')])" + == "Aggregate(Field.of('total').sum().as_('sum_total'), groups=[Field.of('category')])" ) def test_to_pb(self): - sum_total = Sum(Field.of("total")).as_("sum_total") + sum_total = Field.of("total").sum().as_("sum_total") group_category = Field.of("category") instance = self._make_one(sum_total, groups=[group_category]) result = instance._to_pb() @@ -790,19 +788,21 @@ def _make_one(self, *args, **kwargs): return stages.Where(*args, **kwargs) def test_repr(self): - condition = Field.of("age").gt(30) + condition = Field.of("age").greater_than(30) instance = self._make_one(condition) repr_str = repr(instance) - assert repr_str == "Where(condition=Field.of('age').gt(Constant.of(30)))" + assert ( + repr_str == "Where(condition=Field.of('age').greater_than(Constant.of(30)))" + ) def test_to_pb(self): - condition = Field.of("city").eq("SF") + condition = Field.of("city").equal("SF") instance = self._make_one(condition) result = instance._to_pb() assert result.name == "where" assert len(result.args) == 1 got_fn = result.args[0].function_value - assert got_fn.name == "eq" + assert got_fn.name == "equal" assert len(got_fn.args) == 2 assert got_fn.args[0].field_reference_value == "city" assert got_fn.args[1].string_value == "SF"