diff --git a/google/cloud/firestore_v1/_pipeline_stages.py b/google/cloud/firestore_v1/_pipeline_stages.py
index 3871a363d..f7d311d89 100644
--- a/google/cloud/firestore_v1/_pipeline_stages.py
+++ b/google/cloud/firestore_v1/_pipeline_stages.py
@@ -13,13 +13,109 @@
 # limitations under the License.
 
 from __future__ import annotations
-from typing import Optional
+from typing import Optional, Sequence, TYPE_CHECKING
 from abc import ABC
 from abc import abstractmethod
+from enum import Enum
 
 from google.cloud.firestore_v1.types.document import Pipeline as Pipeline_pb
 from google.cloud.firestore_v1.types.document import Value
-from google.cloud.firestore_v1.pipeline_expressions import Expr
+from google.cloud.firestore_v1.vector import Vector
+from google.cloud.firestore_v1.base_vector_query import DistanceMeasure
+from google.cloud.firestore_v1.pipeline_expressions import (
+    Accumulator,
+    Expr,
+    ExprWithAlias,
+    Field,
+    FilterCondition,
+    Selectable,
+    Ordering,
+)
+from google.cloud.firestore_v1._helpers import encode_value
+
+if TYPE_CHECKING:  # pragma: NO COVER
+    from google.cloud.firestore_v1.base_pipeline import _BasePipeline
+    from google.cloud.firestore_v1.base_document import BaseDocumentReference
+
+
+class FindNearestOptions:
+    """Options for configuring the `FindNearest` pipeline stage.
+
+    Attributes:
+        limit (Optional[int]): The maximum number of nearest neighbors to return.
+        distance_field (Optional[Field]): An optional field to store the calculated
+            distance in the output documents.
+    """
+
+    def __init__(
+        self,
+        limit: Optional[int] = None,
+        distance_field: Optional[Field] = None,
+    ):
+        self.limit = limit
+        self.distance_field = distance_field
+
+    def __repr__(self):
+        args = []
+        if self.limit is not None:
+            args.append(f"limit={self.limit}")
+        if self.distance_field is not None:
+            args.append(f"distance_field={self.distance_field}")
+        return f"{self.__class__.__name__}({', '.join(args)})"
+
+
+class SampleOptions:
+    """Options for the 'sample' pipeline stage."""
+
+    class Mode(Enum):
+        DOCUMENTS = "documents"
+        PERCENT = "percent"
+
+    def __init__(self, value: int | float, mode: Mode | str):
+        self.value = value
+        self.mode = SampleOptions.Mode[mode.upper()] if isinstance(mode, str) else mode
+
+    def __repr__(self):
+        if self.mode == SampleOptions.Mode.DOCUMENTS:
+            mode_str = "doc_limit"
+        else:
+            mode_str = "percentage"
+        return f"SampleOptions.{mode_str}({self.value})"
+
+    @staticmethod
+    def doc_limit(value: int):
+        """
+        Sample a set number of documents.
+
+        Args:
+            value: number of documents to sample
+        """
+        return SampleOptions(value, mode=SampleOptions.Mode.DOCUMENTS)
+
+    @staticmethod
+    def percentage(value: float):
+        """
+        Sample a percentage of documents.
+
+        Args:
+            value: percentage of documents to return
+        """
+        return SampleOptions(value, mode=SampleOptions.Mode.PERCENT)
+
+
+class UnnestOptions:
+    """Options for configuring the `Unnest` pipeline stage.
+
+    Attributes:
+        index_field (str): The name of the field to add to each output document,
+            storing the original 0-based index of the element within the array.
+    """
+
+    def __init__(self, index_field: str):
+        self.index_field = index_field
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(index_field={self.index_field!r})"
 
 
 class Stage(ABC):
@@ -52,6 +148,68 @@ def __repr__(self):
         return f"{self.__class__.__name__}({', '.join(items)})"
 
 
+class AddFields(Stage):
+    """Adds new fields to outputs from previous stages."""
+
+    def __init__(self, *fields: Selectable):
+        super().__init__("add_fields")
+        self.fields = list(fields)
+
+    def _pb_args(self):
+        return [
+            Value(
+                map_value={
+                    "fields": {m[0]: m[1] for m in [f._to_map() for f in self.fields]}
+                }
+            )
+        ]
+
+
+class Aggregate(Stage):
+    """Performs aggregation operations, optionally grouped."""
+
+    def __init__(
+        self,
+        *args: ExprWithAlias[Accumulator],
+        accumulators: Sequence[ExprWithAlias[Accumulator]] = (),
+        groups: Sequence[str | Selectable] = (),
+    ):
+        super().__init__()
+        self.groups: list[Selectable] = [
+            Field(f) if isinstance(f, str) else f for f in groups
+        ]
+        if args and accumulators:
+            raise ValueError(
+                "Aggregate stage contains both positional and keyword accumulators"
+            )
+        self.accumulators = args or accumulators
+
+    def _pb_args(self):
+        return [
+            Value(
+                map_value={
+                    "fields": {
+                        m[0]: m[1] for m in [f._to_map() for f in self.accumulators]
+                    }
+                }
+            ),
+            Value(
+                map_value={
+                    "fields": {m[0]: m[1] for m in [f._to_map() for f in self.groups]}
+                }
+            ),
+        ]
+
+    def __repr__(self):
+        accumulator_str = ", ".join(repr(v) for v in self.accumulators)
+        group_str = ""
+        if self.groups:
+            if self.accumulators:
+                group_str = ", "
+            group_str += f"groups={self.groups}"
+        return f"{self.__class__.__name__}({accumulator_str}{group_str})"
+
+
 class Collection(Stage):
     """Specifies a collection as the initial data source."""
 
@@ -65,6 +223,103 @@ def _pb_args(self):
         return [Value(reference_value=self.path)]
 
 
+class CollectionGroup(Stage):
+    """Specifies a collection group as the initial data source."""
+
+    def __init__(self, collection_id: str):
+        super().__init__("collection_group")
+        self.collection_id = collection_id
+
+    def _pb_args(self):
+        return [Value(string_value=self.collection_id)]
+
+
+class Database(Stage):
+    """Specifies the default database as the initial data source."""
+
+    def __init__(self):
+        super().__init__()
+
+    def _pb_args(self):
+        return []
+
+
+class Distinct(Stage):
+    """Returns documents with distinct combinations of specified field values."""
+
+    def __init__(self, *fields: str | Selectable):
+        super().__init__()
+        self.fields: list[Selectable] = [
+            Field(f) if isinstance(f, str) else f for f in fields
+        ]
+
+    def _pb_args(self) -> list[Value]:
+        return [
+            Value(
+                map_value={
+                    "fields": {m[0]: m[1] for m in [f._to_map() for f in self.fields]}
+                }
+            )
+        ]
+
+
+class Documents(Stage):
+    """Specifies specific documents as the initial data source."""
+
+    def __init__(self, *paths: str):
+        super().__init__()
+        self.paths = paths
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({', '.join([repr(p) for p in self.paths])})"
+
+    @staticmethod
+    def of(*documents: "BaseDocumentReference") -> "Documents":
+        doc_paths = ["/" + doc.path for doc in documents]
+        return Documents(*doc_paths)
+
+    def _pb_args(self):
+        return [
+            Value(
+                array_value={
+                    "values": [Value(string_value=path) for path in self.paths]
+                }
+            )
+        ]
+
+
+class FindNearest(Stage):
+    """Performs vector distance (similarity) search."""
+
+    def __init__(
+        self,
+        field: str | Expr,
+        vector: Sequence[float] | Vector,
+        distance_measure: "DistanceMeasure",
+        options: Optional["FindNearestOptions"] = None,
+    ):
+        super().__init__("find_nearest")
+        self.field: Expr = Field(field) if isinstance(field, str) else field
+        self.vector: Vector = vector if isinstance(vector, Vector) else Vector(vector)
+        self.distance_measure = distance_measure
+        self.options = options or FindNearestOptions()
+
+    def _pb_args(self):
+        return [
+            self.field._to_pb(),
+            encode_value(self.vector),
+            Value(string_value=self.distance_measure.name.lower()),
+        ]
+
+    def _pb_options(self) -> dict[str, Value]:
+        options = {}
+        if self.options and self.options.limit is not None:
+            options["limit"] = Value(integer_value=self.options.limit)
+        if self.options and self.options.distance_field is not None:
+            options["distance_field"] = self.options.distance_field._to_pb()
+        return options
+
+
 class GenericStage(Stage):
     """Represents a generic, named stage with parameters."""
 
@@ -79,3 +334,136 @@ def _pb_args(self):
 
     def __repr__(self):
         return f"{self.__class__.__name__}(name='{self.name}')"
+
+
+class Limit(Stage):
+    """Limits the maximum number of documents returned."""
+
+    def __init__(self, limit: int):
+        super().__init__()
+        self.limit = limit
+
+    def _pb_args(self):
+        return [Value(integer_value=self.limit)]
+
+
+class Offset(Stage):
+    """Skips a specified number of documents."""
+
+    def __init__(self, offset: int):
+        super().__init__()
+        self.offset = offset
+
+    def _pb_args(self):
+        return [Value(integer_value=self.offset)]
+
+
+class RemoveFields(Stage):
+    """Removes specified fields from outputs."""
+
+    def __init__(self, *fields: str | Field):
+        super().__init__("remove_fields")
+        self.fields = [Field(f) if isinstance(f, str) else f for f in fields]
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({', '.join(repr(f) for f in self.fields)})"
+
+    def _pb_args(self) -> list[Value]:
+        return [f._to_pb() for f in self.fields]
+
+
+class Sample(Stage):
+    """Performs pseudo-random sampling of documents."""
+
+    def __init__(self, limit_or_options: int | SampleOptions):
+        super().__init__()
+        if isinstance(limit_or_options, int):
+            options = SampleOptions.doc_limit(limit_or_options)
+        else:
+            options = limit_or_options
+        self.options: SampleOptions = options
+
+    def _pb_args(self):
+        if self.options.mode == SampleOptions.Mode.DOCUMENTS:
+            return [
+                Value(integer_value=self.options.value),
+                Value(string_value="documents"),
+            ]
+        else:
+            return [
+                Value(double_value=self.options.value),
+                Value(string_value="percent"),
+            ]
+
+
+class Select(Stage):
+    """Selects or creates a set of fields."""
+
+    def __init__(self, *selections: str | Selectable):
+        super().__init__()
+        self.projections = [Field(s) if isinstance(s, str) else s for s in selections]
+
+    def _pb_args(self) -> list[Value]:
+        return [Selectable._value_from_selectables(*self.projections)]
+
+
+class Sort(Stage):
+    """Sorts documents based on specified criteria."""
+
+    def __init__(self, *orders: "Ordering"):
+        super().__init__()
+        self.orders = list(orders)
+
+    def _pb_args(self):
+        return [o._to_pb() for o in self.orders]
+
+
+class Union(Stage):
+    """Performs a union of documents from two pipelines."""
+
+    def __init__(self, other: _BasePipeline):
+        super().__init__()
+        self.other = other
+
+    def _pb_args(self):
+        return [Value(pipeline_value=self.other._to_pb().pipeline)]
+
+
+class Unnest(Stage):
+    """Produces a document for each element in an array field."""
+
+    def __init__(
+        self,
+        field: Selectable | str,
+        alias: Field | str | None = None,
+        options: UnnestOptions | None = None,
+    ):
+        super().__init__()
+        self.field: Selectable = Field(field) if isinstance(field, str) else field
+        if alias is None:
+            self.alias = self.field
+        elif isinstance(alias, str):
+            self.alias = Field(alias)
+        else:
+            self.alias = alias
+        self.options = options
+
+    def _pb_args(self):
+        return [self.field._to_pb(), self.alias._to_pb()]
+
+    def _pb_options(self):
+        options = {}
+        if self.options is not None:
+            options["index_field"] = Value(string_value=self.options.index_field)
+        return options
+
+
+class Where(Stage):
+    """Filters documents based on a specified condition."""
+
+    def __init__(self, condition: FilterCondition):
+        super().__init__()
+        self.condition = condition
+
+    def _pb_args(self):
+        return [self.condition._to_pb()]
diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py
index dde906fe6..50ae7ab62 100644
--- a/google/cloud/firestore_v1/base_pipeline.py
+++ b/google/cloud/firestore_v1/base_pipeline.py
@@ -18,9 +18,18 @@
 from google.cloud.firestore_v1.types.pipeline import (
     StructuredPipeline as StructuredPipeline_pb,
 )
+from google.cloud.firestore_v1.vector import Vector
+from google.cloud.firestore_v1.base_vector_query import DistanceMeasure
 from google.cloud.firestore_v1.types.firestore import ExecutePipelineRequest
 from google.cloud.firestore_v1.pipeline_result import PipelineResult
-from google.cloud.firestore_v1.pipeline_expressions import Expr
+from google.cloud.firestore_v1.pipeline_expressions import (
+    Accumulator,
+    Expr,
+    ExprWithAlias,
+    Field,
+    FilterCondition,
+    Selectable,
+)
 from google.cloud.firestore_v1 import _helpers
 
 if TYPE_CHECKING:  # pragma: NO COVER
@@ -35,7 +44,7 @@ class _BasePipeline:
     Base class for building Firestore data transformation and query pipelines.
 
     This class is not intended to be instantiated directly.
-    Use `client.collection.("...").pipeline()` to create pipeline instances.
+    Use `client.pipeline()` to create pipeline instances.
     """
 
     def __init__(self, client: Client | AsyncClient):
@@ -127,6 +136,328 @@ def _execute_response_helper(
             doc._pb.update_time if doc.update_time else None,
         )
 
+    def add_fields(self, *fields: Selectable) -> "_BasePipeline":
+        """
+        Adds new fields to outputs from previous stages.
+
+        This stage allows you to compute values on-the-fly based on existing data
+        from previous stages or constants. You can use this to create new fields
+        or overwrite existing ones (if there is name overlap).
+
+        The added fields are defined using `Selectable` expressions, which can be:
+        - `Field`: References an existing document field.
+        - `Function`: Performs a calculation using functions like `add` or
+          `multiply`, with assigned aliases using `Expr.as_()`.
+
+        Example:
+            >>> from google.cloud.firestore_v1.pipeline_expressions import Field, add
+            >>> pipeline = client.pipeline().collection("books")
+            >>> pipeline = pipeline.add_fields(
+            ...     Field.of("rating").as_("bookRating"),  # Rename 'rating' to 'bookRating'
+            ...     add(5, Field.of("quantity")).as_("totalCost")  # Calculate 'totalCost'
+            ... )
+
+        Args:
+            *fields: The fields to add to the documents, specified as `Selectable`
+                expressions.
+
+        Returns:
+            A new Pipeline object with this stage appended to the stage list
+        """
+        return self._append(stages.AddFields(*fields))
+
+    def remove_fields(self, *fields: Field | str) -> "_BasePipeline":
+        """
+        Removes fields from outputs of previous stages.
+
+        Example:
+            >>> from google.cloud.firestore_v1.pipeline_expressions import Field
+            >>> pipeline = client.pipeline().collection("books")
+            >>> # Remove by name
+            >>> pipeline = pipeline.remove_fields("rating", "cost")
+            >>> # Remove by Field object
+            >>> pipeline = pipeline.remove_fields(Field.of("rating"), Field.of("cost"))
+
+        Args:
+            *fields: The fields to remove, specified as field names (str) or
+                `Field` objects.
+
+        Returns:
+            A new Pipeline object with this stage appended to the stage list
+        """
+        return self._append(stages.RemoveFields(*fields))
+
+    def select(self, *selections: str | Selectable) -> "_BasePipeline":
+        """
+        Selects or creates a set of fields from the outputs of previous stages.
+
+        The selected fields are defined using `Selectable` expressions or field names:
+        - `Field`: References an existing document field.
+        - `Function`: Represents the result of a function with an assigned alias
+          name using `Expr.as_()`.
+        - `str`: The name of an existing field.
+
+        If no selections are provided, the output of this stage is empty. Use
+        `add_fields()` instead if only additions are desired.
+
+        Example:
+            >>> from google.cloud.firestore_v1.pipeline_expressions import Field
+            >>> pipeline = client.pipeline().collection("books")
+            >>> # Select by name
+            >>> pipeline = pipeline.select("name", "address")
+            >>> # Select using Field and Function expressions
+            >>> pipeline = pipeline.select(
+            ...     Field.of("name"),
+            ...     Field.of("address").to_upper().as_("upperAddress"),
+            ... )
+
+        Args:
+            *selections: The fields to include in the output documents, specified as
+                field names (str) or `Selectable` expressions.
+
+        Returns:
+            A new Pipeline object with this stage appended to the stage list
+        """
+        return self._append(stages.Select(*selections))
+
+    def where(self, condition: FilterCondition) -> "_BasePipeline":
+        """
+        Filters the documents from previous stages to only include those matching
+        the specified `FilterCondition`.
+
+        This stage allows you to apply conditions to the data, similar to a "WHERE"
+        clause in SQL. You can filter documents based on their field values, using
+        implementations of `FilterCondition`, typically including but not limited to:
+        - field comparators: `eq`, `lt` (less than), `gt` (greater than), etc.
+        - logical operators: `And`, `Or`, `Not`, etc.
+        - advanced functions: `regex_matches`, `array_contains`, etc.
+
+        Example:
+            >>> from google.cloud.firestore_v1.pipeline_expressions import Field, And
+            >>> pipeline = client.pipeline().collection("books")
+            >>> pipeline = pipeline.where(
+            ...     And(
+            ...         Field.of("rating").gt(4.0),  # Filter for ratings > 4.0
+            ...         Field.of("genre").eq("Science Fiction")  # Filter for genre
+            ...     )
+            ... )
+
+        Args:
+            condition: The `FilterCondition` to apply.
+
+        Returns:
+            A new Pipeline object with this stage appended to the stage list
+        """
+        return self._append(stages.Where(condition))
+
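The stage methods above return new pipeline objects rather than mutating in place, so they chain naturally. A minimal end-to-end sketch (assuming a `client` exposing `pipeline()` as documented above; the collection and field names are hypothetical):

```python
# Sketch only: composes the stages documented above. All collection and
# field names are illustrative, not part of the API.
from google.cloud.firestore_v1.pipeline_expressions import Field

pipeline = (
    client.pipeline()
    .collection("books")
    .where(Field.of("rating").gt(4.0))            # FilterCondition stage
    .select("title", "author", "rating")          # project a field subset
    .add_fields(Field.of("rating").as_("stars"))  # alias a computed field
)
```

Each call appends one `stages.*` entry to the stage list, so partially built pipelines can be shared and extended independently.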
+    def find_nearest(
+        self,
+        field: str | Expr,
+        vector: Sequence[float] | "Vector",
+        distance_measure: "DistanceMeasure",
+        options: stages.FindNearestOptions | None = None,
+    ) -> "_BasePipeline":
+        """
+        Performs vector distance (similarity) search with given parameters on the
+        stage inputs.
+
+        This stage adds a "nearest neighbor search" capability to your pipelines.
+        Given a field or expression that evaluates to a vector and a target vector,
+        this stage will identify and return the inputs whose vector is closest to
+        the target vector, using the specified distance measure and options.
+
+        Example:
+            >>> from google.cloud.firestore_v1.base_vector_query import DistanceMeasure
+            >>> from google.cloud.firestore_v1.pipeline_stages import FindNearestOptions
+            >>> from google.cloud.firestore_v1.pipeline_expressions import Field
+            >>>
+            >>> target_vector = [0.1, 0.2, 0.3]
+            >>> pipeline = client.pipeline().collection("books")
+            >>> # Find using field name
+            >>> pipeline = pipeline.find_nearest(
+            ...     "topicVectors",
+            ...     target_vector,
+            ...     DistanceMeasure.COSINE,
+            ...     options=FindNearestOptions(limit=10, distance_field=Field.of("distance"))
+            ... )
+            >>> # Find using Field expression
+            >>> pipeline = pipeline.find_nearest(
+            ...     Field.of("topicVectors"),
+            ...     target_vector,
+            ...     DistanceMeasure.COSINE,
+            ...     options=FindNearestOptions(limit=10, distance_field=Field.of("distance"))
+            ... )
+
+        Args:
+            field: The name of the field (str) or an expression (`Expr`) that
+                evaluates to the vector data. This field should store vector values.
+            vector: The target vector (sequence of floats or `Vector` object) to
+                compare against.
+            distance_measure: The distance measure (`DistanceMeasure`) to use
+                (e.g., `DistanceMeasure.COSINE`, `DistanceMeasure.EUCLIDEAN`).
+            options: Configuration options (`FindNearestOptions`) for the search,
+                such as the limit and the output distance field name.
+
+        Returns:
+            A new Pipeline object with this stage appended to the stage list
+        """
+        return self._append(
+            stages.FindNearest(field, vector, distance_measure, options)
+        )
+
+    def sort(self, *orders: stages.Ordering) -> "_BasePipeline":
+        """
+        Sorts the documents from previous stages based on one or more `Ordering` criteria.
+
+        This stage allows you to order the results of your pipeline. You can specify
+        multiple `Ordering` instances to sort by multiple fields or expressions in
+        ascending or descending order. If documents have the same value for a sorting
+        criterion, the next specified ordering will be used. If all orderings result
+        in equal comparison, the documents are considered equal and the relative order
+        is unspecified.
+
+        Example:
+            >>> from google.cloud.firestore_v1.pipeline_expressions import Field
+            >>> pipeline = client.pipeline().collection("books")
+            >>> # Sort books by rating descending, then title ascending
+            >>> pipeline = pipeline.sort(
+            ...     Field.of("rating").descending(),
+            ...     Field.of("title").ascending()
+            ... )
+
+        Args:
+            *orders: One or more `Ordering` instances specifying the sorting criteria.
+
+        Returns:
+            A new Pipeline object with this stage appended to the stage list
+        """
+        return self._append(stages.Sort(*orders))
+
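For reference, the `FindNearest` stage class in `_pipeline_stages.py` (shown earlier in this diff) turns these arguments into three positional `Value`s plus a named-options map. A rough sketch of that round trip, with hypothetical field names:

```python
# Sketch: exercises the serialization helpers defined on FindNearest above.
from google.cloud.firestore_v1._pipeline_stages import FindNearest, FindNearestOptions
from google.cloud.firestore_v1.base_vector_query import DistanceMeasure
from google.cloud.firestore_v1.pipeline_expressions import Field

stage = FindNearest(
    "embedding",                 # a str is coerced to Field(field)
    [0.1, 0.2, 0.3],             # a plain sequence is coerced to Vector(...)
    DistanceMeasure.COSINE,
    options=FindNearestOptions(limit=5, distance_field=Field.of("distance")),
)
args = stage._pb_args()      # [field expr, encoded vector, Value("cosine")]
opts = stage._pb_options()   # {"limit": Value(5), "distance_field": field expr}
```

The lowercase `"cosine"` string comes from `self.distance_measure.name.lower()` in `_pb_args`.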
+    def sample(self, limit_or_options: int | stages.SampleOptions) -> "_BasePipeline":
+        """
+        Performs a pseudo-random sampling of the documents from the previous stage.
+
+        This stage filters documents pseudo-randomly.
+        - If an `int` limit is provided, it specifies the maximum number of documents
+          to emit. If fewer documents are available, all are passed through.
+        - If `SampleOptions` are provided, they specify how sampling is performed
+          (e.g., by document count or percentage).
+
+        Example:
+            >>> from google.cloud.firestore_v1.pipeline_stages import SampleOptions
+            >>> pipeline = client.pipeline().collection("books")
+            >>> # Sample 10 books, if available.
+            >>> pipeline = pipeline.sample(10)
+            >>> pipeline = pipeline.sample(SampleOptions.doc_limit(10))
+            >>> # Sample 50% of books.
+            >>> pipeline = pipeline.sample(SampleOptions.percentage(0.5))
+
+        Args:
+            limit_or_options: Either an integer specifying the maximum number of
+                documents to sample, or a `SampleOptions` object.
+
+        Returns:
+            A new Pipeline object with this stage appended to the stage list
+        """
+        return self._append(stages.Sample(limit_or_options))
+
+    def union(self, other: "_BasePipeline") -> "_BasePipeline":
+        """
+        Performs a union of all documents from this pipeline and another pipeline,
+        including duplicates.
+
+        This stage passes through documents from the previous stage of this pipeline,
+        and also passes through documents from the previous stage of the `other`
+        pipeline provided. The order of documents emitted from this stage is undefined.
+
+        Example:
+            >>> books_pipeline = client.pipeline().collection("books")
+            >>> magazines_pipeline = client.pipeline().collection("magazines")
+            >>> # Emit documents from both collections
+            >>> combined_pipeline = books_pipeline.union(magazines_pipeline)
+
+        Args:
+            other: The other `Pipeline` whose results will be unioned with this one.
+
+        Returns:
+            A new Pipeline object with this stage appended to the stage list
+        """
+        return self._append(stages.Union(other))
+
+    def unnest(
+        self,
+        field: str | Selectable,
+        alias: str | Field | None = None,
+        options: stages.UnnestOptions | None = None,
+    ) -> "_BasePipeline":
+        """
+        Produces a document for each element in an array field from the previous stage document.
+
+        For each input document, this stage emits zero or more augmented documents:
+        one for each element of the array stored in the field specified by the
+        `field` parameter. Each emitted document augments the input document by
+        setting the `alias` field to the corresponding array element value. If
+        `alias` is unset, the data in `field` is overwritten.
+
+        Example:
+            Input document:
+            ```json
+            { "title": "The Hitchhiker's Guide", "tags": [ "comedy", "sci-fi" ], ... }
+            ```
+
+            >>> pipeline = client.pipeline().collection("books")
+            >>> # Emit a document for each tag
+            >>> pipeline = pipeline.unnest("tags", alias="tag")
+
+            Output documents (without options):
+            ```json
+            { "title": "The Hitchhiker's Guide", "tag": "comedy", ... }
+            { "title": "The Hitchhiker's Guide", "tag": "sci-fi", ... }
+            ```
+
+        Optionally, `UnnestOptions` can specify a field to store the original index
+        of the element within the array.
+
+        Example:
+            Input document:
+            ```json
+            { "title": "The Hitchhiker's Guide", "tags": [ "comedy", "sci-fi" ], ... }
+            ```
+
+            >>> from google.cloud.firestore_v1.pipeline_stages import UnnestOptions
+            >>> pipeline = client.pipeline().collection("books")
+            >>> # Emit a document for each tag, including the index
+            >>> pipeline = pipeline.unnest("tags", options=UnnestOptions(index_field="tagIndex"))
+
+            Output documents (with index_field="tagIndex"):
+            ```json
+            { "title": "The Hitchhiker's Guide", "tags": "comedy", "tagIndex": 0, ... }
+            { "title": "The Hitchhiker's Guide", "tags": "sci-fi", "tagIndex": 1, ... }
+            ```
+
+        Args:
+            field: The name of the field containing the array to unnest.
+            alias: The field name used to hold each array element in the output
+                documents. If unset, or if `alias` matches `field`, the output data
+                will overwrite the original field.
+            options: Optional `UnnestOptions` to configure additional behavior, like adding an index field.
+
+        Returns:
+            A new Pipeline object with this stage appended to the stage list
+        """
+        return self._append(stages.Unnest(field, alias, options))
+
     def generic_stage(self, name: str, *params: Expr) -> "_BasePipeline":
         """
         Adds a generic, named stage to the pipeline with specified parameters.
@@ -149,3 +480,132 @@ def generic_stage(self, name: str, *params: Expr) -> "_BasePipeline":
             A new Pipeline object with this stage appended to the stage list
         """
         return self._append(stages.GenericStage(name, *params))
+
+    def offset(self, offset: int) -> "_BasePipeline":
+        """
+        Skips the first `offset` number of documents from the results of previous stages.
+
+        This stage is useful for implementing pagination, allowing you to retrieve
+        results in chunks. It is typically used in conjunction with `limit()` to
+        control the size of each page.
+
+        Example:
+            >>> from google.cloud.firestore_v1.pipeline_expressions import Field
+            >>> pipeline = client.pipeline().collection("books")
+            >>> # Retrieve the second page of 20 results (assuming sorted)
+            >>> pipeline = pipeline.sort(Field.of("published").descending())
+            >>> pipeline = pipeline.offset(20)  # Skip the first 20 results
+            >>> pipeline = pipeline.limit(20)  # Take the next 20 results
+
+        Args:
+            offset: The non-negative number of documents to skip.
+
+        Returns:
+            A new Pipeline object with this stage appended to the stage list
+        """
+        return self._append(stages.Offset(offset))
+
+    def limit(self, limit: int) -> "_BasePipeline":
+        """
+        Limits the maximum number of documents returned by previous stages to `limit`.
+
+        This stage is useful for controlling the size of the result set, often used for:
+        - **Pagination:** In combination with `offset()` to retrieve specific pages.
+        - **Top-N queries:** To get a limited number of results after sorting.
+        - **Performance:** To prevent excessive data transfer.
+
+        Example:
+            >>> from google.cloud.firestore_v1.pipeline_expressions import Field
+            >>> pipeline = client.pipeline().collection("books")
+            >>> # Limit the results to the top 10 highest-rated books
+            >>> pipeline = pipeline.sort(Field.of("rating").descending())
+            >>> pipeline = pipeline.limit(10)
+
+        Args:
+            limit: The non-negative maximum number of documents to return.
+
+        Returns:
+            A new Pipeline object with this stage appended to the stage list
+        """
+        return self._append(stages.Limit(limit))
+
+    def aggregate(
+        self,
+        *accumulators: ExprWithAlias[Accumulator],
+        groups: Sequence[str | Selectable] = (),
+    ) -> "_BasePipeline":
+        """
+        Performs aggregation operations on the documents from previous stages,
+        optionally grouped by specified fields or expressions.
+
+        This stage allows you to calculate aggregate values (like sum, average, count,
+        min, max) over a set of documents.
+
+        - **Accumulators:** Define the aggregation calculations using `Accumulator`
+          expressions (e.g., `sum()`, `avg()`, `count()`, `min()`, `max()`) combined
+          with `as_()` to name the result field.
+        - **Groups:** Optionally specify fields (by name or `Selectable`) to group
+          the documents by. Aggregations are then performed within each distinct group.
+          If no groups are provided, the aggregation is performed over the entire input.
+
+        Example:
+            >>> from google.cloud.firestore_v1.pipeline_expressions import Count, Field
+            >>> pipeline = client.pipeline().collection("books")
+            >>> # Calculate the average rating and total count for all books
+            >>> pipeline = pipeline.aggregate(
+            ...     Field.of("rating").avg().as_("averageRating"),
+            ...     Field.of("rating").count().as_("totalBooks")
+            ... )
+            >>> # Calculate the average rating for each genre
+            >>> pipeline = pipeline.aggregate(
+            ...     Field.of("rating").avg().as_("avg_rating"),
+            ...     groups=["genre"]  # Group by the 'genre' field
+            ... )
+            >>> # Calculate the count for each author, grouping by Field object
+            >>> pipeline = pipeline.aggregate(
+            ...     Count().as_("bookCount"),
+            ...     groups=[Field.of("author")]
+            ... )
+
+        Args:
+            *accumulators: One or more `ExprWithAlias[Accumulator]` expressions defining
+                the aggregations to perform and their output names.
+            groups: An optional sequence of field names (str) or `Selectable`
+                expressions to group by before aggregating.
+
+        Returns:
+            A new Pipeline object with this stage appended to the stage list
+        """
+        return self._append(stages.Aggregate(*accumulators, groups=groups))
+
+    def distinct(self, *fields: str | Selectable) -> "_BasePipeline":
+        """
+        Returns documents with distinct combinations of values for the specified
+        fields or expressions.
+
+        This stage filters the results from previous stages to include only one
+        document for each unique combination of values in the specified `fields`.
+        The output documents contain only the fields specified in the `distinct` call.
+
+        Example:
+            >>> from google.cloud.firestore_v1.pipeline_expressions import Field
+            >>> pipeline = client.pipeline().collection("books")
+            >>> # Get a list of unique genres (output has only 'genre' field)
+            >>> pipeline = pipeline.distinct("genre")
+            >>> # Get unique combinations of author (uppercase) and genre
+            >>> pipeline = pipeline.distinct(
+            ...     Field.of("author").to_upper().as_("authorUpper"),
+            ...     Field.of("genre")
+            ... )
+
+        Args:
+            *fields: Field names (str) or `Selectable` expressions to consider when
+                determining distinct value combinations. The output will only
+                contain these fields/expressions.
+
+        Returns:
+            A new Pipeline object with this stage appended to the stage list
+        """
+        return self._append(stages.Distinct(*fields))
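Before moving to the expressions module, it helps to see how `aggregate()` lands on the wire: `stages.Aggregate._pb_args` (first file in this diff) emits one map `Value` for the accumulators and a second for the grouping keys. A small sketch with hypothetical names:

```python
# Sketch: builds the stage that pipeline.aggregate(...) appends internally.
from google.cloud.firestore_v1._pipeline_stages import Aggregate
from google.cloud.firestore_v1.pipeline_expressions import Field

stage = Aggregate(
    Field.of("rating").avg().as_("avg_rating"),
    groups=["genre"],  # bare strings are wrapped as Field("genre")
)
args = stage._pb_args()
# args[0]: map Value {"avg_rating": <avg(rating) expr>}   (accumulators)
# args[1]: map Value {"genre": <field(genre) expr>}       (groups)
```

Passing both positional accumulators and the `accumulators=` keyword raises `ValueError`, as guarded in `Aggregate.__init__`.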
diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py
index 5e0c775a2..70d619d3b 100644
--- a/google/cloud/firestore_v1/pipeline_expressions.py
+++ b/google/cloud/firestore_v1/pipeline_expressions.py
@@ -15,17 +15,22 @@
 from __future__ import annotations
 from typing import (
     Any,
+    List,
     Generic,
     TypeVar,
     Dict,
+    Sequence,
 )
 from abc import ABC
 from abc import abstractmethod
+from enum import Enum
 import datetime
 from google.cloud.firestore_v1.types.document import Value
+from google.cloud.firestore_v1.types.query import StructuredQuery as Query_pb
 from google.cloud.firestore_v1.vector import Vector
 from google.cloud.firestore_v1._helpers import GeoPoint
 from google.cloud.firestore_v1._helpers import encode_value
+from google.cloud.firestore_v1._helpers import decode_value
 
 CONSTANT_TYPE = TypeVar(
     "CONSTANT_TYPE",
@@ -43,6 +48,48 @@
 )
 
 
+class Ordering:
+    """Represents the direction for sorting results in a pipeline."""
+
+    class Direction(Enum):
+        ASCENDING = "ascending"
+        DESCENDING = "descending"
+
+    def __init__(self, expr: Expr | str, order_dir: Direction | str = Direction.ASCENDING):
+        """
+        Initializes an Ordering instance.
+
+        Args:
+            expr (Expr | str): The expression or field path string to sort by.
+                If a string is provided, it's treated as a field path.
+            order_dir (Direction | str): The direction to sort in.
+                Defaults to ascending.
+        """
+        self.expr = expr if isinstance(expr, Expr) else Field.of(expr)
+        self.order_dir = (
+            Ordering.Direction[order_dir.upper()]
+            if isinstance(order_dir, str)
+            else order_dir
+        )
+
+    def __repr__(self):
+        if self.order_dir is Ordering.Direction.ASCENDING:
+            order_str = ".ascending()"
+        else:
+            order_str = ".descending()"
+        return f"{self.expr!r}{order_str}"
+
+    def _to_pb(self) -> Value:
+        return Value(
+            map_value={
+                "fields": {
+                    "direction": Value(string_value=self.order_dir.value),
+                    "expression": self.expr._to_pb(),
+                }
+            }
+        )
+
+
 class Expr(ABC):
     """Represents an expression that can be evaluated to a value within the
     execution of a pipeline.
@@ -66,6 +113,794 @@ def __repr__(self):
     def _to_pb(self) -> Value:
         raise NotImplementedError
 
+    @staticmethod
+    def _cast_to_expr_or_convert_to_constant(o: Any) -> "Expr":
+        return o if isinstance(o, Expr) else Constant(o)
+
+    def add(self, other: Expr | float) -> "Add":
+        """Creates an expression that adds this expression to another expression or constant.
+
+        Example:
+            >>> # Add the value of the 'quantity' field and the 'reserve' field.
+            >>> Field.of("quantity").add(Field.of("reserve"))
+            >>> # Add 5 to the value of the 'age' field
+            >>> Field.of("age").add(5)
+
+        Args:
+            other: The expression or constant value to add to this expression.
+
+        Returns:
+            A new `Expr` representing the addition operation.
+        """
+        return Add(self, self._cast_to_expr_or_convert_to_constant(other))
+
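The `_cast_to_expr_or_convert_to_constant` helper above is what lets every arithmetic and comparison method accept plain Python values alongside expressions. A short illustration, grounded in the code in this diff (values arbitrary):

```python
# Sketch: plain values are wrapped in Constant before being embedded in an
# expression tree; Constant.__eq__ (defined later in this file) compares
# transparently against raw values.
from google.cloud.firestore_v1.pipeline_expressions import Constant, Field

implicit = Field.of("age").add(5)               # 5 becomes Constant(5)
explicit = Field.of("age").add(Constant.of(5))  # equivalent explicit form
assert Constant.of(5) == 5                      # raw-value comparison is True
```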
+ """ + return Subtract(self, self._cast_to_expr_or_convert_to_constant(other)) + + def multiply(self, other: Expr | float) -> "Multiply": + """Creates an expression that multiplies this expression by another expression or constant. + + Example: + >>> # Multiply the 'quantity' field by the 'price' field + >>> Field.of("quantity").multiply(Field.of("price")) + >>> # Multiply the 'value' field by 2 + >>> Field.of("value").multiply(2) + + Args: + other: The expression or constant value to multiply by. + + Returns: + A new `Expr` representing the multiplication operation. + """ + return Multiply(self, self._cast_to_expr_or_convert_to_constant(other)) + + def divide(self, other: Expr | float) -> "Divide": + """Creates an expression that divides this expression by another expression or constant. + + Example: + >>> # Divide the 'total' field by the 'count' field + >>> Field.of("total").divide(Field.of("count")) + >>> # Divide the 'value' field by 10 + >>> Field.of("value").divide(10) + + Args: + other: The expression or constant value to divide by. + + Returns: + A new `Expr` representing the division operation. + """ + return Divide(self, self._cast_to_expr_or_convert_to_constant(other)) + + def mod(self, other: Expr | float) -> "Mod": + """Creates an expression that calculates the modulo (remainder) to another expression or constant. + + Example: + >>> # Calculate the remainder of dividing the 'value' field by field 'divisor'. + >>> Field.of("value").mod(Field.of("divisor")) + >>> # Calculate the remainder of dividing the 'value' field by 5. + >>> Field.of("value").mod(5) + + Args: + other: The divisor expression or constant. + + Returns: + A new `Expr` representing the modulo operation. + """ + return Mod(self, self._cast_to_expr_or_convert_to_constant(other)) + + def logical_max(self, other: Expr | CONSTANT_TYPE) -> "LogicalMax": + """Creates an expression that returns the larger value between this expression + and another expression or constant, based on Firestore's value type ordering. + + Firestore's value type ordering is described here: + https://cloud.google.com/firestore/docs/concepts/data-types#value_type_ordering + + Example: + >>> # Returns the larger value between the 'discount' field and the 'cap' field. + >>> Field.of("discount").logical_max(Field.of("cap")) + >>> # Returns the larger value between the 'value' field and 10. + >>> Field.of("value").logical_max(10) + + Args: + other: The other expression or constant value to compare with. + + Returns: + A new `Expr` representing the logical max operation. + """ + return LogicalMax(self, self._cast_to_expr_or_convert_to_constant(other)) + + def logical_min(self, other: Expr | CONSTANT_TYPE) -> "LogicalMin": + """Creates an expression that returns the smaller value between this expression + and another expression or constant, based on Firestore's value type ordering. + + Firestore's value type ordering is described here: + https://cloud.google.com/firestore/docs/concepts/data-types#value_type_ordering + + Example: + >>> # Returns the smaller value between the 'discount' field and the 'floor' field. + >>> Field.of("discount").logical_min(Field.of("floor")) + >>> # Returns the smaller value between the 'value' field and 10. + >>> Field.of("value").logical_min(10) + + Args: + other: The other expression or constant value to compare with. + + Returns: + A new `Expr` representing the logical min operation. 
+ """ + return LogicalMin(self, self._cast_to_expr_or_convert_to_constant(other)) + + def eq(self, other: Expr | CONSTANT_TYPE) -> "Eq": + """Creates an expression that checks if this expression is equal to another + expression or constant value. + + Example: + >>> # Check if the 'age' field is equal to 21 + >>> Field.of("age").eq(21) + >>> # Check if the 'city' field is equal to "London" + >>> Field.of("city").eq("London") + + Args: + other: The expression or constant value to compare for equality. + + Returns: + A new `Expr` representing the equality comparison. + """ + return Eq(self, self._cast_to_expr_or_convert_to_constant(other)) + + def neq(self, other: Expr | CONSTANT_TYPE) -> "Neq": + """Creates an expression that checks if this expression is not equal to another + expression or constant value. + + Example: + >>> # Check if the 'status' field is not equal to "completed" + >>> Field.of("status").neq("completed") + >>> # Check if the 'country' field is not equal to "USA" + >>> Field.of("country").neq("USA") + + Args: + other: The expression or constant value to compare for inequality. + + Returns: + A new `Expr` representing the inequality comparison. + """ + return Neq(self, self._cast_to_expr_or_convert_to_constant(other)) + + def gt(self, other: Expr | CONSTANT_TYPE) -> "Gt": + """Creates an expression that checks if this expression is greater than another + expression or constant value. + + Example: + >>> # Check if the 'age' field is greater than the 'limit' field + >>> Field.of("age").gt(Field.of("limit")) + >>> # Check if the 'price' field is greater than 100 + >>> Field.of("price").gt(100) + + Args: + other: The expression or constant value to compare for greater than. + + Returns: + A new `Expr` representing the greater than comparison. + """ + return Gt(self, self._cast_to_expr_or_convert_to_constant(other)) + + def gte(self, other: Expr | CONSTANT_TYPE) -> "Gte": + """Creates an expression that checks if this expression is greater than or equal + to another expression or constant value. + + Example: + >>> # Check if the 'quantity' field is greater than or equal to field 'requirement' plus 1 + >>> Field.of("quantity").gte(Field.of('requirement').add(1)) + >>> # Check if the 'score' field is greater than or equal to 80 + >>> Field.of("score").gte(80) + + Args: + other: The expression or constant value to compare for greater than or equal to. + + Returns: + A new `Expr` representing the greater than or equal to comparison. + """ + return Gte(self, self._cast_to_expr_or_convert_to_constant(other)) + + def lt(self, other: Expr | CONSTANT_TYPE) -> "Lt": + """Creates an expression that checks if this expression is less than another + expression or constant value. + + Example: + >>> # Check if the 'age' field is less than 'limit' + >>> Field.of("age").lt(Field.of('limit')) + >>> # Check if the 'price' field is less than 50 + >>> Field.of("price").lt(50) + + Args: + other: The expression or constant value to compare for less than. + + Returns: + A new `Expr` representing the less than comparison. + """ + return Lt(self, self._cast_to_expr_or_convert_to_constant(other)) + + def lte(self, other: Expr | CONSTANT_TYPE) -> "Lte": + """Creates an expression that checks if this expression is less than or equal to + another expression or constant value. 
+ + Example: + >>> # Check if the 'quantity' field is less than or equal to 20 + >>> Field.of("quantity").lte(Constant.of(20)) + >>> # Check if the 'score' field is less than or equal to 70 + >>> Field.of("score").lte(70) + + Args: + other: The expression or constant value to compare for less than or equal to. + + Returns: + A new `Expr` representing the less than or equal to comparison. + """ + return Lte(self, self._cast_to_expr_or_convert_to_constant(other)) + + def in_any(self, array: List[Expr | CONSTANT_TYPE]) -> "In": + """Creates an expression that checks if this expression is equal to any of the + provided values or expressions. + + Example: + >>> # Check if the 'category' field is either "Electronics" or value of field 'primaryType' + >>> Field.of("category").in_any(["Electronics", Field.of("primaryType")]) + + Args: + array: The values or expressions to check against. + + Returns: + A new `Expr` representing the 'IN' comparison. + """ + return In(self, [self._cast_to_expr_or_convert_to_constant(v) for v in array]) + + def not_in_any(self, array: List[Expr | CONSTANT_TYPE]) -> "Not": + """Creates an expression that checks if this expression is not equal to any of the + provided values or expressions. + + Example: + >>> # Check if the 'status' field is neither "pending" nor "cancelled" + >>> Field.of("status").not_in_any(["pending", "cancelled"]) + + Args: + *others: The values or expressions to check against. + + Returns: + A new `Expr` representing the 'NOT IN' comparison. + """ + return Not(self.in_any(array)) + + def array_contains(self, element: Expr | CONSTANT_TYPE) -> "ArrayContains": + """Creates an expression that checks if an array contains a specific element or value. + + Example: + >>> # Check if the 'sizes' array contains the value from the 'selectedSize' field + >>> Field.of("sizes").array_contains(Field.of("selectedSize")) + >>> # Check if the 'colors' array contains "red" + >>> Field.of("colors").array_contains("red") + + Args: + element: The element (expression or constant) to search for in the array. + + Returns: + A new `Expr` representing the 'array_contains' comparison. + """ + return ArrayContains(self, self._cast_to_expr_or_convert_to_constant(element)) + + def array_contains_all( + self, elements: List[Expr | CONSTANT_TYPE] + ) -> "ArrayContainsAll": + """Creates an expression that checks if an array contains all the specified elements. + + Example: + >>> # Check if the 'tags' array contains both "news" and "sports" + >>> Field.of("tags").array_contains_all(["news", "sports"]) + >>> # Check if the 'tags' array contains both of the values from field 'tag1' and "tag2" + >>> Field.of("tags").array_contains_all([Field.of("tag1"), "tag2"]) + + Args: + elements: The list of elements (expressions or constants) to check for in the array. + + Returns: + A new `Expr` representing the 'array_contains_all' comparison. + """ + return ArrayContainsAll( + self, [self._cast_to_expr_or_convert_to_constant(e) for e in elements] + ) + + def array_contains_any( + self, elements: List[Expr | CONSTANT_TYPE] + ) -> "ArrayContainsAny": + """Creates an expression that checks if an array contains any of the specified elements. 
+ + Example: + >>> # Check if the 'categories' array contains either values from field "cate1" or "cate2" + >>> Field.of("categories").array_contains_any([Field.of("cate1"), Field.of("cate2")]) + >>> # Check if the 'groups' array contains either the value from the 'userGroup' field + >>> # or the value "guest" + >>> Field.of("groups").array_contains_any([Field.of("userGroup"), "guest"]) + + Args: + elements: The list of elements (expressions or constants) to check for in the array. + + Returns: + A new `Expr` representing the 'array_contains_any' comparison. + """ + return ArrayContainsAny( + self, [self._cast_to_expr_or_convert_to_constant(e) for e in elements] + ) + + def array_length(self) -> "ArrayLength": + """Creates an expression that calculates the length of an array. + + Example: + >>> # Get the number of items in the 'cart' array + >>> Field.of("cart").array_length() + + Returns: + A new `Expr` representing the length of the array. + """ + return ArrayLength(self) + + def array_reverse(self) -> "ArrayReverse": + """Creates an expression that returns the reversed content of an array. + + Example: + >>> # Get the 'preferences' array in reversed order. + >>> Field.of("preferences").array_reverse() + + Returns: + A new `Expr` representing the reversed array. + """ + return ArrayReverse(self) + + def is_nan(self) -> "IsNaN": + """Creates an expression that checks if this expression evaluates to 'NaN' (Not a Number). + + Example: + >>> # Check if the result of a calculation is NaN + >>> Field.of("value").divide(0).is_nan() + + Returns: + A new `Expr` representing the 'isNaN' check. + """ + return IsNaN(self) + + def exists(self) -> "Exists": + """Creates an expression that checks if a field exists in the document. + + Example: + >>> # Check if the document has a field named "phoneNumber" + >>> Field.of("phoneNumber").exists() + + Returns: + A new `Expr` representing the 'exists' check. + """ + return Exists(self) + + def sum(self) -> "Sum": + """Creates an aggregation that calculates the sum of a numeric field across multiple stage inputs. + + Example: + >>> # Calculate the total revenue from a set of orders + >>> Field.of("orderAmount").sum().as_("totalRevenue") + + Returns: + A new `Accumulator` representing the 'sum' aggregation. + """ + return Sum(self) + + def avg(self) -> "Avg": + """Creates an aggregation that calculates the average (mean) of a numeric field across multiple + stage inputs. + + Example: + >>> # Calculate the average age of users + >>> Field.of("age").avg().as_("averageAge") + + Returns: + A new `Accumulator` representing the 'avg' aggregation. + """ + return Avg(self) + + def count(self) -> "Count": + """Creates an aggregation that counts the number of stage inputs with valid evaluations of the + expression or field. + + Example: + >>> # Count the total number of products + >>> Field.of("productId").count().as_("totalProducts") + + Returns: + A new `Accumulator` representing the 'count' aggregation. + """ + return Count(self) + + def min(self) -> "Min": + """Creates an aggregation that finds the minimum value of a field across multiple stage inputs. + + Example: + >>> # Find the lowest price of all products + >>> Field.of("price").min().as_("lowestPrice") + + Returns: + A new `Accumulator` representing the 'min' aggregation. + """ + return Min(self) + + def max(self) -> "Max": + """Creates an aggregation that finds the maximum value of a field across multiple stage inputs. 
+ + Example: + >>> # Find the highest score in a leaderboard + >>> Field.of("score").max().as_("highestScore") + + Returns: + A new `Accumulator` representing the 'max' aggregation. + """ + return Max(self) + + def char_length(self) -> "CharLength": + """Creates an expression that calculates the character length of a string. + + Example: + >>> # Get the character length of the 'name' field + >>> Field.of("name").char_length() + + Returns: + A new `Expr` representing the length of the string. + """ + return CharLength(self) + + def byte_length(self) -> "ByteLength": + """Creates an expression that calculates the byte length of a string in its UTF-8 form. + + Example: + >>> # Get the byte length of the 'name' field + >>> Field.of("name").byte_length() + + Returns: + A new `Expr` representing the byte length of the string. + """ + return ByteLength(self) + + def like(self, pattern: Expr | str) -> "Like": + """Creates an expression that performs a case-sensitive string comparison. + + Example: + >>> # Check if the 'title' field contains the word "guide" (case-sensitive) + >>> Field.of("title").like("%guide%") + >>> # Check if the 'title' field matches the pattern specified in field 'pattern'. + >>> Field.of("title").like(Field.of("pattern")) + + Args: + pattern: The pattern (string or expression) to search for. You can use "%" as a wildcard character. + + Returns: + A new `Expr` representing the 'like' comparison. + """ + return Like(self, self._cast_to_expr_or_convert_to_constant(pattern)) + + def regex_contains(self, regex: Expr | str) -> "RegexContains": + """Creates an expression that checks if a string contains a specified regular expression as a + substring. + + Example: + >>> # Check if the 'description' field contains "example" (case-insensitive) + >>> Field.of("description").regex_contains("(?i)example") + >>> # Check if the 'description' field contains the regular expression stored in field 'regex' + >>> Field.of("description").regex_contains(Field.of("regex")) + + Args: + regex: The regular expression (string or expression) to use for the search. + + Returns: + A new `Expr` representing the 'contains' comparison. + """ + return RegexContains(self, self._cast_to_expr_or_convert_to_constant(regex)) + + def regex_matches(self, regex: Expr | str) -> "RegexMatch": + """Creates an expression that checks if a string matches a specified regular expression. + + Example: + >>> # Check if the 'email' field matches a valid email pattern + >>> Field.of("email").regex_matches("[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Za-z]{2,}") + >>> # Check if the 'email' field matches a regular expression stored in field 'regex' + >>> Field.of("email").regex_matches(Field.of("regex")) + + Args: + regex: The regular expression (string or expression) to use for the match. + + Returns: + A new `Expr` representing the regular expression match. + """ + return RegexMatch(self, self._cast_to_expr_or_convert_to_constant(regex)) + + def str_contains(self, substring: Expr | str) -> "StrContains": + """Creates an expression that checks if this string expression contains a specified substring. + + Example: + >>> # Check if the 'description' field contains "example". + >>> Field.of("description").str_contains("example") + >>> # Check if the 'description' field contains the value of the 'keyword' field. + >>> Field.of("description").str_contains(Field.of("keyword")) + + Args: + substring: The substring (string or expression) to use for the search. + + Returns: + A new `Expr` representing the 'contains' comparison. 
+ """ + return StrContains(self, self._cast_to_expr_or_convert_to_constant(substring)) + + def starts_with(self, prefix: Expr | str) -> "StartsWith": + """Creates an expression that checks if a string starts with a given prefix. + + Example: + >>> # Check if the 'name' field starts with "Mr." + >>> Field.of("name").starts_with("Mr.") + >>> # Check if the 'fullName' field starts with the value of the 'firstName' field + >>> Field.of("fullName").starts_with(Field.of("firstName")) + + Args: + prefix: The prefix (string or expression) to check for. + + Returns: + A new `Expr` representing the 'starts with' comparison. + """ + return StartsWith(self, self._cast_to_expr_or_convert_to_constant(prefix)) + + def ends_with(self, postfix: Expr | str) -> "EndsWith": + """Creates an expression that checks if a string ends with a given postfix. + + Example: + >>> # Check if the 'filename' field ends with ".txt" + >>> Field.of("filename").ends_with(".txt") + >>> # Check if the 'url' field ends with the value of the 'extension' field + >>> Field.of("url").ends_with(Field.of("extension")) + + Args: + postfix: The postfix (string or expression) to check for. + + Returns: + A new `Expr` representing the 'ends with' comparison. + """ + return EndsWith(self, self._cast_to_expr_or_convert_to_constant(postfix)) + + def str_concat(self, *elements: Expr | CONSTANT_TYPE) -> "StrConcat": + """Creates an expression that concatenates string expressions, fields or constants together. + + Example: + >>> # Combine the 'firstName', " ", and 'lastName' fields into a single string + >>> Field.of("firstName").str_concat(" ", Field.of("lastName")) + + Args: + *elements: The expressions or constants (typically strings) to concatenate. + + Returns: + A new `Expr` representing the concatenated string. + """ + return StrConcat( + self, *[self._cast_to_expr_or_convert_to_constant(el) for el in elements] + ) + + def map_get(self, key: str) -> "MapGet": + """Accesses a value from a map (object) field using the provided key. + + Example: + >>> # Get the 'city' value from + >>> # the 'address' map field + >>> Field.of("address").map_get("city") + + Args: + key: The key to access in the map. + + Returns: + A new `Expr` representing the value associated with the given key in the map. + """ + return MapGet(self, Constant.of(key)) + + def vector_length(self) -> "VectorLength": + """Creates an expression that calculates the length (dimension) of a Firestore Vector. + + Example: + >>> # Get the vector length (dimension) of the field 'embedding'. + >>> Field.of("embedding").vector_length() + + Returns: + A new `Expr` representing the length of the vector. + """ + return VectorLength(self) + + def timestamp_to_unix_micros(self) -> "TimestampToUnixMicros": + """Creates an expression that converts a timestamp to the number of microseconds since the epoch + (1970-01-01 00:00:00 UTC). + + Truncates higher levels of precision by rounding down to the beginning of the microsecond. + + Example: + >>> # Convert the 'timestamp' field to microseconds since the epoch. + >>> Field.of("timestamp").timestamp_to_unix_micros() + + Returns: + A new `Expr` representing the number of microseconds since the epoch. + """ + return TimestampToUnixMicros(self) + + def unix_micros_to_timestamp(self) -> "UnixMicrosToTimestamp": + """Creates an expression that converts a number of microseconds since the epoch (1970-01-01 + 00:00:00 UTC) to a timestamp. + + Example: + >>> # Convert the 'microseconds' field to a timestamp. 
+ >>> Field.of("microseconds").unix_micros_to_timestamp() + + Returns: + A new `Expr` representing the timestamp. + """ + return UnixMicrosToTimestamp(self) + + def timestamp_to_unix_millis(self) -> "TimestampToUnixMillis": + """Creates an expression that converts a timestamp to the number of milliseconds since the epoch + (1970-01-01 00:00:00 UTC). + + Truncates higher levels of precision by rounding down to the beginning of the millisecond. + + Example: + >>> # Convert the 'timestamp' field to milliseconds since the epoch. + >>> Field.of("timestamp").timestamp_to_unix_millis() + + Returns: + A new `Expr` representing the number of milliseconds since the epoch. + """ + return TimestampToUnixMillis(self) + + def unix_millis_to_timestamp(self) -> "UnixMillisToTimestamp": + """Creates an expression that converts a number of milliseconds since the epoch (1970-01-01 + 00:00:00 UTC) to a timestamp. + + Example: + >>> # Convert the 'milliseconds' field to a timestamp. + >>> Field.of("milliseconds").unix_millis_to_timestamp() + + Returns: + A new `Expr` representing the timestamp. + """ + return UnixMillisToTimestamp(self) + + def timestamp_to_unix_seconds(self) -> "TimestampToUnixSeconds": + """Creates an expression that converts a timestamp to the number of seconds since the epoch + (1970-01-01 00:00:00 UTC). + + Truncates higher levels of precision by rounding down to the beginning of the second. + + Example: + >>> # Convert the 'timestamp' field to seconds since the epoch. + >>> Field.of("timestamp").timestamp_to_unix_seconds() + + Returns: + A new `Expr` representing the number of seconds since the epoch. + """ + return TimestampToUnixSeconds(self) + + def unix_seconds_to_timestamp(self) -> "UnixSecondsToTimestamp": + """Creates an expression that converts a number of seconds since the epoch (1970-01-01 00:00:00 + UTC) to a timestamp. + + Example: + >>> # Convert the 'seconds' field to a timestamp. + >>> Field.of("seconds").unix_seconds_to_timestamp() + + Returns: + A new `Expr` representing the timestamp. + """ + return UnixSecondsToTimestamp(self) + + def timestamp_add(self, unit: Expr | str, amount: Expr | float) -> "TimestampAdd": + """Creates an expression that adds a specified amount of time to this timestamp expression. + + Example: + >>> # Add a duration specified by the 'unit' and 'amount' fields to the 'timestamp' field. + >>> Field.of("timestamp").timestamp_add(Field.of("unit"), Field.of("amount")) + >>> # Add 1.5 days to the 'timestamp' field. + >>> Field.of("timestamp").timestamp_add("day", 1.5) + + Args: + unit: The expression or string evaluating to the unit of time to add, must be one of + 'microsecond', 'millisecond', 'second', 'minute', 'hour', 'day'. + amount: The expression or float representing the amount of time to add. + + Returns: + A new `Expr` representing the resulting timestamp. + """ + return TimestampAdd( + self, + self._cast_to_expr_or_convert_to_constant(unit), + self._cast_to_expr_or_convert_to_constant(amount), + ) + + def timestamp_sub(self, unit: Expr | str, amount: Expr | float) -> "TimestampSub": + """Creates an expression that subtracts a specified amount of time from this timestamp expression. + + Example: + >>> # Subtract a duration specified by the 'unit' and 'amount' fields from the 'timestamp' field. + >>> Field.of("timestamp").timestamp_sub(Field.of("unit"), Field.of("amount")) + >>> # Subtract 2.5 hours from the 'timestamp' field. 
+ >>> Field.of("timestamp").timestamp_sub("hour", 2.5) + + Args: + unit: The expression or string evaluating to the unit of time to subtract, must be one of + 'microsecond', 'millisecond', 'second', 'minute', 'hour', 'day'. + amount: The expression or float representing the amount of time to subtract. + + Returns: + A new `Expr` representing the resulting timestamp. + """ + return TimestampSub( + self, + self._cast_to_expr_or_convert_to_constant(unit), + self._cast_to_expr_or_convert_to_constant(amount), + ) + + def ascending(self) -> Ordering: + """Creates an `Ordering` that sorts documents in ascending order based on this expression. + + Example: + >>> # Sort documents by the 'name' field in ascending order + >>> firestore.pipeline().collection("users").sort(Field.of("name").ascending()) + + Returns: + A new `Ordering` for ascending sorting. + """ + return Ordering(self, Ordering.Direction.ASCENDING) + + def descending(self) -> Ordering: + """Creates an `Ordering` that sorts documents in descending order based on this expression. + + Example: + >>> # Sort documents by the 'createdAt' field in descending order + >>> firestore.pipeline().collection("users").sort(Field.of("createdAt").descending()) + + Returns: + A new `Ordering` for descending sorting. + """ + return Ordering(self, Ordering.Direction.DESCENDING) + + def as_(self, alias: str) -> "ExprWithAlias": + """Assigns an alias to this expression. + + Aliases are useful for renaming fields in the output of a stage or for giving meaningful + names to calculated values. + + Example: + >>> # Calculate the total price and assign it the alias "totalPrice" and add it to the output. + >>> firestore.pipeline().collection("items").add_fields( + ... Field.of("price").multiply(Field.of("quantity")).as_("totalPrice") + ... ) + + Args: + alias: The alias to assign to this expression. + + Returns: + A new `Selectable` (typically an `ExprWithAlias`) that wraps this + expression and associates it with the provided alias. 
+ """ + return ExprWithAlias(self, alias) + class Constant(Expr, Generic[CONSTANT_TYPE]): """Represents a constant literal value in an expression.""" @@ -73,6 +908,12 @@ class Constant(Expr, Generic[CONSTANT_TYPE]): def __init__(self, value: CONSTANT_TYPE): self.value: CONSTANT_TYPE = value + def __eq__(self, other): + if not isinstance(other, Constant): + return other == self.value + else: + return other.value == self.value + @staticmethod def of(value: CONSTANT_TYPE) -> Constant[CONSTANT_TYPE]: """Creates a constant expression from a Python value.""" @@ -83,3 +924,1379 @@ def __repr__(self): def _to_pb(self) -> Value: return encode_value(self.value) + + +class ListOfExprs(Expr): + """Represents a list of expressions, typically used as an argument to functions like 'in' or array functions.""" + + def __init__(self, exprs: List[Expr]): + self.exprs: list[Expr] = exprs + + def __eq__(self, other): + if not isinstance(other, ListOfExprs): + return False + else: + return other.exprs == self.exprs + + def __repr__(self): + return f"{self.__class__.__name__}({self.exprs})" + + def _to_pb(self): + return Value(array_value={"values": [e._to_pb() for e in self.exprs]}) + + +class Function(Expr): + """A base class for expressions that represent function calls.""" + + def __init__(self, name: str, params: Sequence[Expr]): + self.name = name + self.params = list(params) + + def __eq__(self, other): + if not isinstance(other, Function): + return False + else: + return other.name == self.name and other.params == self.params + + def __repr__(self): + return f"{self.__class__.__name__}({', '.join([repr(p) for p in self.params])})" + + def _to_pb(self): + return Value( + function_value={ + "name": self.name, + "args": [p._to_pb() for p in self.params], + } + ) + + def add(left: Expr | str, right: Expr | float) -> "Add": + """Creates an expression that adds two expressions together. + + Example: + >>> Function.add("rating", 5) + >>> Function.add(Field.of("quantity"), Field.of("reserve")) + + Args: + left: The first expression or field path to add. + right: The second expression or constant value to add. + + Returns: + A new `Expr` representing the addition operation. + """ + left_expr = Field.of(left) if isinstance(left, str) else left + return Expr.add(left_expr, right) + + def subtract(left: Expr | str, right: Expr | float) -> "Subtract": + """Creates an expression that subtracts another expression or constant from this expression. + + Example: + >>> Function.subtract("total", 20) + >>> Function.subtract(Field.of("price"), Field.of("discount")) + + Args: + left: The expression or field path to subtract from. + right: The expression or constant value to subtract. + + Returns: + A new `Expr` representing the subtraction operation. + """ + left_expr = Field.of(left) if isinstance(left, str) else left + return Expr.subtract(left_expr, right) + + def multiply(left: Expr | str, right: Expr | float) -> "Multiply": + """Creates an expression that multiplies this expression by another expression or constant. + + Example: + >>> Function.multiply("value", 2) + >>> Function.multiply(Field.of("quantity"), Field.of("price")) + + Args: + left: The expression or field path to multiply. + right: The expression or constant value to multiply by. + + Returns: + A new `Expr` representing the multiplication operation. 
+ """ + left_expr = Field.of(left) if isinstance(left, str) else left + return Expr.multiply(left_expr, right) + + def divide(left: Expr | str, right: Expr | float) -> "Divide": + """Creates an expression that divides this expression by another expression or constant. + + Example: + >>> Function.divide("value", 10) + >>> Function.divide(Field.of("total"), Field.of("count")) + + Args: + left: The expression or field path to be divided. + right: The expression or constant value to divide by. + + Returns: + A new `Expr` representing the division operation. + """ + left_expr = Field.of(left) if isinstance(left, str) else left + return Expr.divide(left_expr, right) + + def mod(left: Expr | str, right: Expr | float) -> "Mod": + """Creates an expression that calculates the modulo (remainder) to another expression or constant. + + Example: + >>> Function.mod("value", 5) + >>> Function.mod(Field.of("value"), Field.of("divisor")) + + Args: + left: The dividend expression or field path. + right: The divisor expression or constant. + + Returns: + A new `Expr` representing the modulo operation. + """ + left_expr = Field.of(left) if isinstance(left, str) else left + return Expr.mod(left_expr, right) + + def logical_max(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "LogicalMax": + """Creates an expression that returns the larger value between this expression + and another expression or constant, based on Firestore's value type ordering. + + Firestore's value type ordering is described here: + https://cloud.google.com/firestore/docs/concepts/data-types#value_type_ordering + + Example: + >>> Function.logical_max("value", 10) + >>> Function.logical_max(Field.of("discount"), Field.of("cap")) + + Args: + left: The expression or field path to compare. + right: The other expression or constant value to compare with. + + Returns: + A new `Expr` representing the logical max operation. + """ + left_expr = Field.of(left) if isinstance(left, str) else left + return Expr.logical_max(left_expr, right) + + def logical_min(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "LogicalMin": + """Creates an expression that returns the smaller value between this expression + and another expression or constant, based on Firestore's value type ordering. + + Firestore's value type ordering is described here: + https://cloud.google.com/firestore/docs/concepts/data-types#value_type_ordering + + Example: + >>> Function.logical_min("value", 10) + >>> Function.logical_min(Field.of("discount"), Field.of("floor")) + + Args: + left: The expression or field path to compare. + right: The other expression or constant value to compare with. + + Returns: + A new `Expr` representing the logical min operation. + """ + left_expr = Field.of(left) if isinstance(left, str) else left + return Expr.logical_min(left_expr, right) + + def eq(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Eq": + """Creates an expression that checks if this expression is equal to another + expression or constant value. + + Example: + >>> Function.eq("city", "London") + >>> Function.eq(Field.of("age"), 21) + + Args: + left: The expression or field path to compare. + right: The expression or constant value to compare for equality. + + Returns: + A new `Expr` representing the equality comparison. 
+ """ + left_expr = Field.of(left) if isinstance(left, str) else left + return Expr.eq(left_expr, right) + + def neq(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Neq": + """Creates an expression that checks if this expression is not equal to another + expression or constant value. + + Example: + >>> Function.neq("country", "USA") + >>> Function.neq(Field.of("status"), "completed") + + Args: + left: The expression or field path to compare. + right: The expression or constant value to compare for inequality. + + Returns: + A new `Expr` representing the inequality comparison. + """ + left_expr = Field.of(left) if isinstance(left, str) else left + return Expr.neq(left_expr, right) + + def gt(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Gt": + """Creates an expression that checks if this expression is greater than another + expression or constant value. + + Example: + >>> Function.gt("price", 100) + >>> Function.gt(Field.of("age"), Field.of("limit")) + + Args: + left: The expression or field path to compare. + right: The expression or constant value to compare for greater than. + + Returns: + A new `Expr` representing the greater than comparison. + """ + left_expr = Field.of(left) if isinstance(left, str) else left + return Expr.gt(left_expr, right) + + def gte(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Gte": + """Creates an expression that checks if this expression is greater than or equal + to another expression or constant value. + + Example: + >>> Function.gte("score", 80) + >>> Function.gte(Field.of("quantity"), Field.of('requirement').add(1)) + + Args: + left: The expression or field path to compare. + right: The expression or constant value to compare for greater than or equal to. + + Returns: + A new `Expr` representing the greater than or equal to comparison. + """ + left_expr = Field.of(left) if isinstance(left, str) else left + return Expr.gte(left_expr, right) + + def lt(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Lt": + """Creates an expression that checks if this expression is less than another + expression or constant value. + + Example: + >>> Function.lt("price", 50) + >>> Function.lt(Field.of("age"), Field.of('limit')) + + Args: + left: The expression or field path to compare. + right: The expression or constant value to compare for less than. + + Returns: + A new `Expr` representing the less than comparison. + """ + left_expr = Field.of(left) if isinstance(left, str) else left + return Expr.lt(left_expr, right) + + def lte(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Lte": + """Creates an expression that checks if this expression is less than or equal to + another expression or constant value. + + Example: + >>> Function.lte("score", 70) + >>> Function.lte(Field.of("quantity"), Constant.of(20)) + + Args: + left: The expression or field path to compare. + right: The expression or constant value to compare for less than or equal to. + + Returns: + A new `Expr` representing the less than or equal to comparison. + """ + left_expr = Field.of(left) if isinstance(left, str) else left + return Expr.lte(left_expr, right) + + def in_any(left: Expr | str, array: List[Expr | CONSTANT_TYPE]) -> "In": + """Creates an expression that checks if this expression is equal to any of the + provided values or expressions. + + Example: + >>> Function.in_any("category", ["Electronics", "Apparel"]) + >>> Function.in_any(Field.of("category"), ["Electronics", Field.of("primaryType")]) + + Args: + left: The expression or field path to compare. 
+ array: The values or expressions to check against. + + Returns: + A new `Expr` representing the 'IN' comparison. + """ + left_expr = Field.of(left) if isinstance(left, str) else left + return Expr.in_any(left_expr, array) + + def not_in_any(left: Expr | str, array: List[Expr | CONSTANT_TYPE]) -> "Not": + """Creates an expression that checks if this expression is not equal to any of the + provided values or expressions. + + Example: + >>> Function.not_in_any("status", ["pending", "cancelled"]) + + Args: + left: The expression or field path to compare. + array: The values or expressions to check against. + + Returns: + A new `Expr` representing the 'NOT IN' comparison. + """ + left_expr = Field.of(left) if isinstance(left, str) else left + return Expr.not_in_any(left_expr, array) + + def array_contains( + array: Expr | str, element: Expr | CONSTANT_TYPE + ) -> "ArrayContains": + """Creates an expression that checks if an array contains a specific element or value. + + Example: + >>> Function.array_contains("colors", "red") + >>> Function.array_contains(Field.of("sizes"), Field.of("selectedSize")) + + Args: + array: The array expression or field path to check. + element: The element (expression or constant) to search for in the array. + + Returns: + A new `Expr` representing the 'array_contains' comparison. + """ + array_expr = Field.of(array) if isinstance(array, str) else array + return Expr.array_contains(array_expr, element) + + def array_contains_all( + array: Expr | str, elements: List[Expr | CONSTANT_TYPE] + ) -> "ArrayContainsAll": + """Creates an expression that checks if an array contains all the specified elements. + + Example: + >>> Function.array_contains_all("tags", ["news", "sports"]) + >>> Function.array_contains_all(Field.of("tags"), [Field.of("tag1"), "tag2"]) + + Args: + array: The array expression or field path to check. + elements: The list of elements (expressions or constants) to check for in the array. + + Returns: + A new `Expr` representing the 'array_contains_all' comparison. + """ + array_expr = Field.of(array) if isinstance(array, str) else array + return Expr.array_contains_all(array_expr, elements) + + def array_contains_any( + array: Expr | str, elements: List[Expr | CONSTANT_TYPE] + ) -> "ArrayContainsAny": + """Creates an expression that checks if an array contains any of the specified elements. + + Example: + >>> Function.array_contains_any("groups", ["admin", "editor"]) + >>> Function.array_contains_any(Field.of("categories"), [Field.of("cate1"), Field.of("cate2")]) + + Args: + array: The array expression or field path to check. + elements: The list of elements (expressions or constants) to check for in the array. + + Returns: + A new `Expr` representing the 'array_contains_any' comparison. + """ + array_expr = Field.of(array) if isinstance(array, str) else array + return Expr.array_contains_any(array_expr, elements) + + def array_length(array: Expr | str) -> "ArrayLength": + """Creates an expression that calculates the length of an array. + + Example: + >>> Function.array_length("cart") + + Returns: + A new `Expr` representing the length of the array. + """ + array_expr = Field.of(array) if isinstance(array, str) else array + return Expr.array_length(array_expr) + + def array_reverse(array: Expr | str) -> "ArrayReverse": + """Creates an expression that returns the reversed content of an array. + + Example: + >>> Function.array_reverse("preferences") + + Returns: + A new `Expr` representing the reversed array. 
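The list-taking array helpers wrap their candidates in a single `ListOfExprs` argument (see the `ArrayContainsAll`/`ArrayContainsAny` condition classes later in this diff), so constants and expressions can be mixed freely. A short sketch with illustrative field names:

```python
cond = Function.array_contains_any("tags", ["news", Field.of("favoriteTag")])
```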
+ """ + array_expr = Field.of(array) if isinstance(array, str) else array + return Expr.array_reverse(array_expr) + + def is_nan(expr: Expr | str) -> "IsNaN": + """Creates an expression that checks if this expression evaluates to 'NaN' (Not a Number). + + Example: + >>> Function.is_nan("measurement") + + Returns: + A new `Expr` representing the 'isNaN' check. + """ + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.is_nan(expr_val) + + def exists(expr: Expr | str) -> "Exists": + """Creates an expression that checks if a field exists in the document. + + Example: + >>> Function.exists("phoneNumber") + + Returns: + A new `Expr` representing the 'exists' check. + """ + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.exists(expr_val) + + def sum(expr: Expr | str) -> "Sum": + """Creates an aggregation that calculates the sum of a numeric field across multiple stage inputs. + + Example: + >>> Function.sum("orderAmount") + + Returns: + A new `Accumulator` representing the 'sum' aggregation. + """ + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.sum(expr_val) + + def avg(expr: Expr | str) -> "Avg": + """Creates an aggregation that calculates the average (mean) of a numeric field across multiple + stage inputs. + + Example: + >>> Function.avg("age") + + Returns: + A new `Accumulator` representing the 'avg' aggregation. + """ + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.avg(expr_val) + + def count(expr: Expr | str | None = None) -> "Count": + """Creates an aggregation that counts the number of stage inputs with valid evaluations of the + expression or field. If no expression is provided, it counts all inputs. + + Example: + >>> Function.count("productId") + >>> Function.count() + + Returns: + A new `Accumulator` representing the 'count' aggregation. + """ + if expr is None: + return Count() + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.count(expr_val) + + def min(expr: Expr | str) -> "Min": + """Creates an aggregation that finds the minimum value of a field across multiple stage inputs. + + Example: + >>> Function.min("price") + + Returns: + A new `Accumulator` representing the 'min' aggregation. + """ + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.min(expr_val) + + def max(expr: Expr | str) -> "Max": + """Creates an aggregation that finds the maximum value of a field across multiple stage inputs. + + Example: + >>> Function.max("score") + + Returns: + A new `Accumulator` representing the 'max' aggregation. + """ + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.max(expr_val) + + def char_length(expr: Expr | str) -> "CharLength": + """Creates an expression that calculates the character length of a string. + + Example: + >>> Function.char_length("name") + + Returns: + A new `Expr` representing the length of the string. + """ + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.char_length(expr_val) + + def byte_length(expr: Expr | str) -> "ByteLength": + """Creates an expression that calculates the byte length of a string in its UTF-8 form. + + Example: + >>> Function.byte_length("name") + + Returns: + A new `Expr` representing the byte length of the string. 
+ """ + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.byte_length(expr_val) + + def like(expr: Expr | str, pattern: Expr | str) -> "Like": + """Creates an expression that performs a case-sensitive string comparison. + + Example: + >>> Function.like("title", "%guide%") + >>> Function.like(Field.of("title"), Field.of("pattern")) + + Args: + expr: The expression or field path to perform the comparison on. + pattern: The pattern (string or expression) to search for. You can use "%" as a wildcard character. + + Returns: + A new `Expr` representing the 'like' comparison. + """ + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.like(expr_val, pattern) + + def regex_contains(expr: Expr | str, regex: Expr | str) -> "RegexContains": + """Creates an expression that checks if a string contains a specified regular expression as a + substring. + + Example: + >>> Function.regex_contains("description", "(?i)example") + >>> Function.regex_contains(Field.of("description"), Field.of("regex")) + + Args: + expr: The expression or field path to perform the comparison on. + regex: The regular expression (string or expression) to use for the search. + + Returns: + A new `Expr` representing the 'contains' comparison. + """ + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.regex_contains(expr_val, regex) + + def regex_matches(expr: Expr | str, regex: Expr | str) -> "RegexMatch": + """Creates an expression that checks if a string matches a specified regular expression. + + Example: + >>> # Check if the 'email' field matches a valid email pattern + >>> Function.regex_matches("email", "[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Za-z]{2,}") + >>> Function.regex_matches(Field.of("email"), Field.of("regex")) + + Args: + expr: The expression or field path to match against. + regex: The regular expression (string or expression) to use for the match. + + Returns: + A new `Expr` representing the regular expression match. + """ + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.regex_matches(expr_val, regex) + + def str_contains(expr: Expr | str, substring: Expr | str) -> "StrContains": + """Creates an expression that checks if this string expression contains a specified substring. + + Example: + >>> Function.str_contains("description", "example") + >>> Function.str_contains(Field.of("description"), Field.of("keyword")) + + Args: + expr: The expression or field path to perform the comparison on. + substring: The substring (string or expression) to use for the search. + + Returns: + A new `Expr` representing the 'contains' comparison. + """ + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.str_contains(expr_val, substring) + + def starts_with(expr: Expr | str, prefix: Expr | str) -> "StartsWith": + """Creates an expression that checks if a string starts with a given prefix. + + Example: + >>> Function.starts_with("name", "Mr.") + >>> Function.starts_with(Field.of("fullName"), Field.of("firstName")) + + Args: + expr: The expression or field path to check. + prefix: The prefix (string or expression) to check for. + + Returns: + A new `Expr` representing the 'starts with' comparison. + """ + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.starts_with(expr_val, prefix) + + def ends_with(expr: Expr | str, postfix: Expr | str) -> "EndsWith": + """Creates an expression that checks if a string ends with a given postfix. 
+ + Example: + >>> Function.ends_with("filename", ".txt") + >>> Function.ends_with(Field.of("url"), Field.of("extension")) + + Args: + expr: The expression or field path to check. + postfix: The postfix (string or expression) to check for. + + Returns: + A new `Expr` representing the 'ends with' comparison. + """ + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.ends_with(expr_val, postfix) + + def str_concat(first: Expr | str, *elements: Expr | CONSTANT_TYPE) -> "StrConcat": + """Creates an expression that concatenates string expressions, fields or constants together. + + Example: + >>> Function.str_concat("firstName", " ", Field.of("lastName")) + + Args: + first: The first expression or field path to concatenate. + *elements: The expressions or constants (typically strings) to concatenate. + + Returns: + A new `Expr` representing the concatenated string. + """ + first_expr = Field.of(first) if isinstance(first, str) else first + return Expr.str_concat(first_expr, *elements) + + def map_get(map_expr: Expr | str, key: str) -> "MapGet": + """Accesses a value from a map (object) field using the provided key. + + Example: + >>> Function.map_get("address", "city") + + Args: + map_expr: The expression or field path of the map. + key: The key to access in the map. + + Returns: + A new `Expr` representing the value associated with the given key in the map. + """ + map_val = Field.of(map_expr) if isinstance(map_expr, str) else map_expr + return Expr.map_get(map_val, key) + + def vector_length(vector_expr: Expr | str) -> "VectorLength": + """Creates an expression that calculates the length (dimension) of a Firestore Vector. + + Example: + >>> Function.vector_length("embedding") + + Returns: + A new `Expr` representing the length of the vector. + """ + vector_val = ( + Field.of(vector_expr) if isinstance(vector_expr, str) else vector_expr + ) + return Expr.vector_length(vector_val) + + def timestamp_to_unix_micros(timestamp_expr: Expr | str) -> "TimestampToUnixMicros": + """Creates an expression that converts a timestamp to the number of microseconds since the epoch + (1970-01-01 00:00:00 UTC). + + Truncates higher levels of precision by rounding down to the beginning of the microsecond. + + Example: + >>> Function.timestamp_to_unix_micros("timestamp") + + Returns: + A new `Expr` representing the number of microseconds since the epoch. + """ + timestamp_val = ( + Field.of(timestamp_expr) + if isinstance(timestamp_expr, str) + else timestamp_expr + ) + return Expr.timestamp_to_unix_micros(timestamp_val) + + def unix_micros_to_timestamp(micros_expr: Expr | str) -> "UnixMicrosToTimestamp": + """Creates an expression that converts a number of microseconds since the epoch (1970-01-01 + 00:00:00 UTC) to a timestamp. + + Example: + >>> Function.unix_micros_to_timestamp("microseconds") + + Returns: + A new `Expr` representing the timestamp. + """ + micros_val = ( + Field.of(micros_expr) if isinstance(micros_expr, str) else micros_expr + ) + return Expr.unix_micros_to_timestamp(micros_val) + + def timestamp_to_unix_millis(timestamp_expr: Expr | str) -> "TimestampToUnixMillis": + """Creates an expression that converts a timestamp to the number of milliseconds since the epoch + (1970-01-01 00:00:00 UTC). + + Truncates higher levels of precision by rounding down to the beginning of the millisecond. + + Example: + >>> Function.timestamp_to_unix_millis("timestamp") + + Returns: + A new `Expr` representing the number of milliseconds since the epoch. 
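Every helper in this block repeats the same normalization step before delegating to the fluent `Expr` method. Distilled into a hypothetical helper (`_as_expr` is not part of this module, just a restatement of the recurring idiom):

```python
def _as_expr(value: "Expr | str") -> "Expr":
    # Bare strings are treated as field paths; Expr instances pass through.
    return Field.of(value) if isinstance(value, str) else value
```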
+ """ + timestamp_val = ( + Field.of(timestamp_expr) + if isinstance(timestamp_expr, str) + else timestamp_expr + ) + return Expr.timestamp_to_unix_millis(timestamp_val) + + def unix_millis_to_timestamp(millis_expr: Expr | str) -> "UnixMillisToTimestamp": + """Creates an expression that converts a number of milliseconds since the epoch (1970-01-01 + 00:00:00 UTC) to a timestamp. + + Example: + >>> Function.unix_millis_to_timestamp("milliseconds") + + Returns: + A new `Expr` representing the timestamp. + """ + millis_val = ( + Field.of(millis_expr) if isinstance(millis_expr, str) else millis_expr + ) + return Expr.unix_millis_to_timestamp(millis_val) + + def timestamp_to_unix_seconds( + timestamp_expr: Expr | str, + ) -> "TimestampToUnixSeconds": + """Creates an expression that converts a timestamp to the number of seconds since the epoch + (1970-01-01 00:00:00 UTC). + + Truncates higher levels of precision by rounding down to the beginning of the second. + + Example: + >>> Function.timestamp_to_unix_seconds("timestamp") + + Returns: + A new `Expr` representing the number of seconds since the epoch. + """ + timestamp_val = ( + Field.of(timestamp_expr) + if isinstance(timestamp_expr, str) + else timestamp_expr + ) + return Expr.timestamp_to_unix_seconds(timestamp_val) + + def unix_seconds_to_timestamp(seconds_expr: Expr | str) -> "UnixSecondsToTimestamp": + """Creates an expression that converts a number of seconds since the epoch (1970-01-01 00:00:00 + UTC) to a timestamp. + + Example: + >>> Function.unix_seconds_to_timestamp("seconds") + + Returns: + A new `Expr` representing the timestamp. + """ + seconds_val = ( + Field.of(seconds_expr) if isinstance(seconds_expr, str) else seconds_expr + ) + return Expr.unix_seconds_to_timestamp(seconds_val) + + def timestamp_add( + timestamp: Expr | str, unit: Expr | str, amount: Expr | float + ) -> "TimestampAdd": + """Creates an expression that adds a specified amount of time to this timestamp expression. + + Example: + >>> Function.timestamp_add("timestamp", "day", 1.5) + >>> Function.timestamp_add(Field.of("timestamp"), Field.of("unit"), Field.of("amount")) + + Args: + timestamp: The expression or field path of the timestamp. + unit: The expression or string evaluating to the unit of time to add, must be one of + 'microsecond', 'millisecond', 'second', 'minute', 'hour', 'day'. + amount: The expression or float representing the amount of time to add. + + Returns: + A new `Expr` representing the resulting timestamp. + """ + timestamp_expr = ( + Field.of(timestamp) if isinstance(timestamp, str) else timestamp + ) + return Expr.timestamp_add(timestamp_expr, unit, amount) + + def timestamp_sub( + timestamp: Expr | str, unit: Expr | str, amount: Expr | float + ) -> "TimestampSub": + """Creates an expression that subtracts a specified amount of time from this timestamp expression. + + Example: + >>> Function.timestamp_sub("timestamp", "hour", 2.5) + >>> Function.timestamp_sub(Field.of("timestamp"), Field.of("unit"), Field.of("amount")) + + Args: + timestamp: The expression or field path of the timestamp. + unit: The expression or string evaluating to the unit of time to subtract, must be one of + 'microsecond', 'millisecond', 'second', 'minute', 'hour', 'day'. + amount: The expression or float representing the amount of time to subtract. + + Returns: + A new `Expr` representing the resulting timestamp. 
+ """ + timestamp_expr = ( + Field.of(timestamp) if isinstance(timestamp, str) else timestamp + ) + return Expr.timestamp_sub(timestamp_expr, unit, amount) + + +class Divide(Function): + """Represents the division function.""" + + def __init__(self, left: Expr, right: Expr): + super().__init__("divide", [left, right]) + + +class LogicalMax(Function): + """Represents the logical maximum function based on Firestore type ordering.""" + + def __init__(self, left: Expr, right: Expr): + super().__init__("logical_maximum", [left, right]) + + +class LogicalMin(Function): + """Represents the logical minimum function based on Firestore type ordering.""" + + def __init__(self, left: Expr, right: Expr): + super().__init__("logical_minimum", [left, right]) + + +class MapGet(Function): + """Represents accessing a value within a map by key.""" + + def __init__(self, map_: Expr, key: Constant[str]): + super().__init__("map_get", [map_, key]) + + +class Mod(Function): + """Represents the modulo function.""" + + def __init__(self, left: Expr, right: Expr): + super().__init__("mod", [left, right]) + + +class Multiply(Function): + """Represents the multiplication function.""" + + def __init__(self, left: Expr, right: Expr): + super().__init__("multiply", [left, right]) + + +class Parent(Function): + """Represents getting the parent document reference.""" + + def __init__(self, value: Expr): + super().__init__("parent", [value]) + + +class StrConcat(Function): + """Represents concatenating multiple strings.""" + + def __init__(self, *exprs: Expr): + super().__init__("str_concat", exprs) + + +class Subtract(Function): + """Represents the subtraction function.""" + + def __init__(self, left: Expr, right: Expr): + super().__init__("subtract", [left, right]) + + +class TimestampAdd(Function): + """Represents adding a duration to a timestamp.""" + + def __init__(self, timestamp: Expr, unit: Expr, amount: Expr): + super().__init__("timestamp_add", [timestamp, unit, amount]) + + +class TimestampSub(Function): + """Represents subtracting a duration from a timestamp.""" + + def __init__(self, timestamp: Expr, unit: Expr, amount: Expr): + super().__init__("timestamp_sub", [timestamp, unit, amount]) + + +class TimestampToUnixMicros(Function): + """Represents converting a timestamp to microseconds since epoch.""" + + def __init__(self, input: Expr): + super().__init__("timestamp_to_unix_micros", [input]) + + +class TimestampToUnixMillis(Function): + """Represents converting a timestamp to milliseconds since epoch.""" + + def __init__(self, input: Expr): + super().__init__("timestamp_to_unix_millis", [input]) + + +class TimestampToUnixSeconds(Function): + """Represents converting a timestamp to seconds since epoch.""" + + def __init__(self, input: Expr): + super().__init__("timestamp_to_unix_seconds", [input]) + + +class UnixMicrosToTimestamp(Function): + """Represents converting microseconds since epoch to a timestamp.""" + + def __init__(self, input: Expr): + super().__init__("unix_micros_to_timestamp", [input]) + + +class UnixMillisToTimestamp(Function): + """Represents converting milliseconds since epoch to a timestamp.""" + + def __init__(self, input: Expr): + super().__init__("unix_millis_to_timestamp", [input]) + + +class UnixSecondsToTimestamp(Function): + """Represents converting seconds since epoch to a timestamp.""" + + def __init__(self, input: Expr): + super().__init__("unix_seconds_to_timestamp", [input]) + + +class VectorLength(Function): + """Represents getting the length (dimension) of a vector.""" + + def 
__init__(self, array: Expr): + super().__init__("vector_length", [array]) + + +class Add(Function): + """Represents the addition function.""" + + def __init__(self, left: Expr, right: Expr): + super().__init__("add", [left, right]) + + +class ArrayElement(Function): + """Represents accessing an element within an array""" + + def __init__(self): + super().__init__("array_element", []) + + +class ArrayFilter(Function): + """Represents filtering elements from an array based on a condition.""" + + def __init__(self, array: Expr, filter: "FilterCondition"): + super().__init__("array_filter", [array, filter]) + + +class ArrayLength(Function): + """Represents getting the length of an array.""" + + def __init__(self, array: Expr): + super().__init__("array_length", [array]) + + +class ArrayReverse(Function): + """Represents reversing the elements of an array.""" + + def __init__(self, array: Expr): + super().__init__("array_reverse", [array]) + + +class ArrayTransform(Function): + """Represents applying a transformation function to each element of an array.""" + + def __init__(self, array: Expr, transform: Function): + super().__init__("array_transform", [array, transform]) + + +class ByteLength(Function): + """Represents getting the byte length of a string (UTF-8).""" + + def __init__(self, expr: Expr): + super().__init__("byte_length", [expr]) + + +class CharLength(Function): + """Represents getting the character length of a string.""" + + def __init__(self, expr: Expr): + super().__init__("char_length", [expr]) + + +class CollectionId(Function): + """Represents getting the collection ID from a document reference.""" + + def __init__(self, value: Expr): + super().__init__("collection_id", [value]) + + +class Accumulator(Function): + """A base class for aggregation functions that operate across multiple inputs.""" + + +class Max(Accumulator): + """Represents the maximum aggregation function.""" + + def __init__(self, value: Expr): + super().__init__("maximum", [value]) + + +class Min(Accumulator): + """Represents the minimum aggregation function.""" + + def __init__(self, value: Expr): + super().__init__("minimum", [value]) + + +class Sum(Accumulator): + """Represents the sum aggregation function.""" + + def __init__(self, value: Expr): + super().__init__("sum", [value]) + + +class Avg(Accumulator): + """Represents the average aggregation function.""" + + def __init__(self, value: Expr): + super().__init__("avg", [value]) + + +class Count(Accumulator): + """Represents an aggregation that counts the total number of inputs.""" + + def __init__(self, value: Expr | None = None): + super().__init__("count", [value] if value else []) + + +class Selectable(Expr): + """Base class for expressions that can be selected or aliased in projection stages.""" + + def __eq__(self, other): + if not isinstance(other, type(self)): + return False + else: + return other._to_map() == self._to_map() + + @abstractmethod + def _to_map(self) -> tuple[str, Value]: + """ + Returns a str: Value representation of the Selectable + """ + raise NotImplementedError + + @classmethod + def _value_from_selectables(cls, *selectables: Selectable) -> Value: + """ + Returns a Value representing a map of Selectables + """ + return Value( + map_value={ + "fields": {m[0]: m[1] for m in [s._to_map() for s in selectables]} + } + ) + + +T = TypeVar("T", bound=Expr) + + +class ExprWithAlias(Selectable, Generic[T]): + """Wraps an expression with an alias.""" + + def __init__(self, expr: T, alias: str): + self.expr = expr + self.alias = alias + + 
def _to_map(self): + return self.alias, self.expr._to_pb() + + def __repr__(self): + return f"{self.expr}.as_('{self.alias}')" + + def _to_pb(self): + return Value(map_value={"fields": {self.alias: self.expr._to_pb()}}) + + +class Field(Selectable): + """Represents a reference to a field within a document.""" + + DOCUMENT_ID = "__name__" + + def __init__(self, path: str): + """Initializes a Field reference. + + Args: + path: The dot-separated path to the field (e.g., "address.city"). + Use Field.DOCUMENT_ID for the document ID. + """ + self.path = path + + @staticmethod + def of(path: str): + """Creates a Field reference. + + Args: + path: The dot-separated path to the field (e.g., "address.city"). + Use Field.DOCUMENT_ID for the document ID. + + Returns: + A new Field instance. + """ + return Field(path) + + def _to_map(self): + return self.path, self._to_pb() + + def __repr__(self): + return f"Field.of({self.path!r})" + + def _to_pb(self): + return Value(field_reference_value=self.path) + + +class FilterCondition(Function): + """Filters the given data in some way.""" + + def __init__( + self, + *args, + use_infix_repr: bool = True, + infix_name_override: str | None = None, + **kwargs, + ): + self._use_infix_repr = use_infix_repr + self._infix_name_override = infix_name_override + super().__init__(*args, **kwargs) + + def __repr__(self): + """ + Most FilterConditions can be triggered infix. Eg: Field.of('age').gte(18). + + Display them this way in the repr string where possible + """ + if self._use_infix_repr: + infix_name = self._infix_name_override or self.name + if len(self.params) == 1: + return f"{self.params[0]!r}.{infix_name}()" + elif len(self.params) == 2: + return f"{self.params[0]!r}.{infix_name}({self.params[1]!r})" + return super().__repr__() + + @staticmethod + def _from_query_filter_pb(filter_pb, client): + if isinstance(filter_pb, Query_pb.CompositeFilter): + sub_filters = [ + FilterCondition._from_query_filter_pb(f, client) + for f in filter_pb.filters + ] + if filter_pb.op == Query_pb.CompositeFilter.Operator.OR: + return Or(*sub_filters) + elif filter_pb.op == Query_pb.CompositeFilter.Operator.AND: + return And(*sub_filters) + else: + raise TypeError( + f"Unexpected CompositeFilter operator type: {filter_pb.op}" + ) + elif isinstance(filter_pb, Query_pb.UnaryFilter): + field = Field.of(filter_pb.field.field_path) + if filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NAN: + return And(field.exists(), field.is_nan()) + elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NOT_NAN: + return And(field.exists(), Not(field.is_nan())) + elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NULL: + return And(field.exists(), field.eq(None)) + elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NOT_NULL: + return And(field.exists(), Not(field.eq(None))) + else: + raise TypeError(f"Unexpected UnaryFilter operator type: {filter_pb.op}") + elif isinstance(filter_pb, Query_pb.FieldFilter): + field = Field.of(filter_pb.field.field_path) + value = decode_value(filter_pb.value, client) + if filter_pb.op == Query_pb.FieldFilter.Operator.LESS_THAN: + return And(field.exists(), field.lt(value)) + elif filter_pb.op == Query_pb.FieldFilter.Operator.LESS_THAN_OR_EQUAL: + return And(field.exists(), field.lte(value)) + elif filter_pb.op == Query_pb.FieldFilter.Operator.GREATER_THAN: + return And(field.exists(), field.gt(value)) + elif filter_pb.op == Query_pb.FieldFilter.Operator.GREATER_THAN_OR_EQUAL: + return And(field.exists(), field.gte(value)) + elif filter_pb.op == 
Query_pb.FieldFilter.Operator.EQUAL:
+                return And(field.exists(), field.eq(value))
+            elif filter_pb.op == Query_pb.FieldFilter.Operator.NOT_EQUAL:
+                return And(field.exists(), field.neq(value))
+            elif filter_pb.op == Query_pb.FieldFilter.Operator.ARRAY_CONTAINS:
+                return And(field.exists(), field.array_contains(value))
+            elif filter_pb.op == Query_pb.FieldFilter.Operator.ARRAY_CONTAINS_ANY:
+                return And(field.exists(), field.array_contains_any(value))
+            elif filter_pb.op == Query_pb.FieldFilter.Operator.IN:
+                return And(field.exists(), field.in_any(value))
+            elif filter_pb.op == Query_pb.FieldFilter.Operator.NOT_IN:
+                return And(field.exists(), field.not_in_any(value))
+            else:
+                raise TypeError(f"Unexpected FieldFilter operator type: {filter_pb.op}")
+        elif isinstance(filter_pb, Query_pb.Filter):
+            # unwrap oneof
+            f = (
+                filter_pb.composite_filter
+                or filter_pb.field_filter
+                or filter_pb.unary_filter
+            )
+            return FilterCondition._from_query_filter_pb(f, client)
+        else:
+            raise TypeError(f"Unexpected filter type: {type(filter_pb)}")
+
+
+class And(FilterCondition):
+    """Represents the logical AND of multiple filter conditions."""
+
+    def __init__(self, *conditions: "FilterCondition"):
+        super().__init__("and", conditions, use_infix_repr=False)
+
+
+class ArrayContains(FilterCondition):
+    """Represents checking if an array contains a specific element."""
+
+    def __init__(self, array: Expr, element: Expr):
+        super().__init__(
+            "array_contains", [array, element if element else Constant(None)]
+        )
+
+
+class ArrayContainsAll(FilterCondition):
+    """Represents checking if an array contains all specified elements."""
+
+    def __init__(self, array: Expr, elements: List[Expr]):
+        super().__init__("array_contains_all", [array, ListOfExprs(elements)])
+
+
+class ArrayContainsAny(FilterCondition):
+    """Represents checking if an array contains any of the specified elements."""
+
+    def __init__(self, array: Expr, elements: List[Expr]):
+        super().__init__("array_contains_any", [array, ListOfExprs(elements)])
+
+
+class EndsWith(FilterCondition):
+    """Represents checking if a string ends with a specific postfix."""
+
+    def __init__(self, expr: Expr, postfix: Expr):
+        super().__init__("ends_with", [expr, postfix])
+
+
+class Eq(FilterCondition):
+    """Represents the equality comparison."""
+
+    def __init__(self, left: Expr, right: Expr):
+        super().__init__("eq", [left, right if right else Constant(None)])
+
+
+class Exists(FilterCondition):
+    """Represents checking if a field exists."""
+
+    def __init__(self, expr: Expr):
+        super().__init__("exists", [expr])
+
+
+class Gt(FilterCondition):
+    """Represents the greater than comparison."""
+
+    def __init__(self, left: Expr, right: Expr):
+        super().__init__("gt", [left, right if right else Constant(None)])
+
+
+class Gte(FilterCondition):
+    """Represents the greater than or equal to comparison."""
+
+    def __init__(self, left: Expr, right: Expr):
+        super().__init__("gte", [left, right if right else Constant(None)])
+
+
+class If(FilterCondition):
+    """Represents a conditional expression (if-then-else)."""
+
+    def __init__(self, condition: "FilterCondition", true_expr: Expr, false_expr: Expr):
+        super().__init__(
+            "if", [condition, true_expr, false_expr if false_expr else Constant(None)]
+        )
+
+
+class In(FilterCondition):
+    """Represents checking if an expression's value is within a list of values."""
+
+    def __init__(self, left: Expr, others: List[Expr]):
+        super().__init__(
+            "in", [left, ListOfExprs(others)], infix_name_override="in_any"
+        )
+
+
+class IsNaN(FilterCondition):
+    """Represents checking if a numeric value is NaN."""
+
+    def __init__(self, value: Expr):
+        
super().__init__("is_nan", [value])
+
+
+class Like(FilterCondition):
+    """Represents a case-sensitive wildcard string comparison."""
+
+    def __init__(self, expr: Expr, pattern: Expr):
+        super().__init__("like", [expr, pattern])
+
+
+class Lt(FilterCondition):
+    """Represents the less than comparison."""
+
+    def __init__(self, left: Expr, right: Expr):
+        super().__init__("lt", [left, right if right else Constant(None)])
+
+
+class Lte(FilterCondition):
+    """Represents the less than or equal to comparison."""
+
+    def __init__(self, left: Expr, right: Expr):
+        super().__init__("lte", [left, right if right else Constant(None)])
+
+
+class Neq(FilterCondition):
+    """Represents the inequality comparison."""
+
+    def __init__(self, left: Expr, right: Expr):
+        super().__init__("neq", [left, right if right else Constant(None)])
+
+
+class Not(FilterCondition):
+    """Represents the logical NOT of a filter condition."""
+
+    def __init__(self, condition: Expr):
+        super().__init__("not", [condition], use_infix_repr=False)
+
+
+class Or(FilterCondition):
+    """Represents the logical OR of multiple filter conditions."""
+
+    def __init__(self, *conditions: "FilterCondition"):
+        super().__init__("or", conditions)
+
+
+class RegexContains(FilterCondition):
+    """Represents checking if a string contains a substring matching a regex."""
+
+    def __init__(self, expr: Expr, regex: Expr):
+        super().__init__("regex_contains", [expr, regex])
+
+
+class RegexMatch(FilterCondition):
+    """Represents checking if a string fully matches a regex."""
+
+    def __init__(self, expr: Expr, regex: Expr):
+        super().__init__("regex_match", [expr, regex])
+
+
+class StartsWith(FilterCondition):
+    """Represents checking if a string starts with a specific prefix."""
+
+    def __init__(self, expr: Expr, prefix: Expr):
+        super().__init__("starts_with", [expr, prefix])
+
+
+class StrContains(FilterCondition):
+    """Represents checking if a string contains a specific substring."""
+
+    def __init__(self, expr: Expr, substring: Expr):
+        super().__init__("str_contains", [expr, substring])
+
+
+class Xor(FilterCondition):
+    """Represents the logical XOR of multiple filter conditions."""
+
+    def __init__(self, conditions: List["FilterCondition"]):
+        super().__init__("xor", conditions, use_infix_repr=False)
diff --git a/google/cloud/firestore_v1/pipeline_source.py b/google/cloud/firestore_v1/pipeline_source.py
index f2f081fee..6d83ae533 100644
--- a/google/cloud/firestore_v1/pipeline_source.py
+++ b/google/cloud/firestore_v1/pipeline_source.py
@@ -16,10 +16,12 @@
 from typing import Generic, TypeVar, TYPE_CHECKING
 from google.cloud.firestore_v1 import _pipeline_stages as stages
 from google.cloud.firestore_v1.base_pipeline import _BasePipeline
+from google.cloud.firestore_v1._helpers import DOCUMENT_PATH_DELIMITER
 
 if TYPE_CHECKING:  # pragma: NO COVER
     from google.cloud.firestore_v1.client import Client
     from google.cloud.firestore_v1.async_client import AsyncClient
+    from google.cloud.firestore_v1.base_document import BaseDocumentReference
 
 PipelineType = TypeVar("PipelineType", bound=_BasePipeline)
@@ -41,13 +43,45 @@ def __init__(self, client: Client | AsyncClient):
     def _create_pipeline(self, source_stage):
         return self.client._pipeline_cls._create_with_stages(self.client, source_stage)
 
-    def collection(self, path: str) -> PipelineType:
+    def collection(self, path: str | tuple[str, ...]) -> PipelineType:
         """
         Creates a new Pipeline that operates on a specified Firestore collection.
 
         
Args:
-            path: The path to the Firestore collection (e.g., "users")
+            path: The path to the Firestore collection (e.g., "users"). Can be either:
+                * A single ``/``-delimited path to a collection
+                * A tuple of collection path segments
 
         Returns:
             a new pipeline instance targeting the specified collection
         """
+        if isinstance(path, tuple):
+            path = DOCUMENT_PATH_DELIMITER.join(path)
         return self._create_pipeline(stages.Collection(path))
+
+    def collection_group(self, collection_id: str) -> PipelineType:
+        """
+        Creates a new Pipeline that operates on all documents in a collection group.
+        Args:
+            collection_id: The ID of the collection group
+        Returns:
+            a new pipeline instance targeting the specified collection group
+        """
+        return self._create_pipeline(stages.CollectionGroup(collection_id))
+
+    def database(self) -> PipelineType:
+        """
+        Creates a new Pipeline that operates on all documents in the Firestore database.
+        Returns:
+            a new pipeline instance targeting all documents in the database
+        """
+        return self._create_pipeline(stages.Database())
+
+    def documents(self, *docs: "BaseDocumentReference") -> PipelineType:
+        """
+        Creates a new Pipeline that operates on a specific set of Firestore documents.
+        Args:
+            docs: The DocumentReference instances representing the documents to include in the pipeline.
+        Returns:
+            a new pipeline instance targeting the specified documents
+        """
+        return self._create_pipeline(stages.Documents.of(*docs))
diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml
new file mode 100644
index 000000000..dc262f4a9
--- /dev/null
+++ b/tests/system/pipeline_e2e.yaml
@@ -0,0 +1,1640 @@
+data:
+  books:
+    book1:
+      title: "The Hitchhiker's Guide to the Galaxy"
+      author: "Douglas Adams"
+      genre: "Science Fiction"
+      published: 1979
+      rating: 4.2
+      tags:
+        - comedy
+        - space
+        - adventure
+      awards:
+        hugo: true
+        nebula: false
+    book2:
+      title: "Pride and Prejudice"
+      author: "Jane Austen"
+      genre: "Romance"
+      published: 1813
+      rating: 4.5
+      tags:
+        - classic
+        - social commentary
+        - love
+      awards:
+        none: true
+    book3:
+      title: "One Hundred Years of Solitude"
+      author: "Gabriel García Márquez"
+      genre: "Magical Realism"
+      published: 1967
+      rating: 4.3
+      tags:
+        - family
+        - history
+        - fantasy
+      awards:
+        nobel: true
+        nebula: false
+    book4:
+      title: "The Lord of the Rings"
+      author: "J.R.R. Tolkien"
+      genre: "Fantasy"
+      published: 1954
+      rating: 4.7
+      tags:
+        - adventure
+        - magic
+        - epic
+      awards:
+        hugo: false
+        nebula: false
+    book5:
+      title: "The Handmaid's Tale"
+      author: "Margaret Atwood"
+      genre: "Dystopian"
+      published: 1985
+      rating: 4.1
+      tags:
+        - feminism
+        - totalitarianism
+        - resistance
+      awards:
+        arthur c. clarke: true
+        booker prize: false
+    book6:
+      title: "Crime and Punishment"
+      author: "Fyodor Dostoevsky"
+      genre: "Psychological Thriller"
+      published: 1866
+      rating: 4.3
+      tags:
+        - philosophy
+        - crime
+        - redemption
+      awards:
+        none: true
+    book7:
+      title: "To Kill a Mockingbird"
+      author: "Harper Lee"
+      genre: "Southern Gothic"
+      published: 1960
+      rating: 4.2
+      tags:
+        - racism
+        - injustice
+        - coming-of-age
+      awards:
+        pulitzer: true
+    book8:
+      title: "1984"
+      author: "George Orwell"
+      genre: "Dystopian"
+      published: 1949
+      rating: 4.2
+      tags:
+        - surveillance
+        - totalitarianism
+        - propaganda
+      awards:
+        prometheus: true
+    book9:
+      title: "The Great Gatsby"
+      author: "F. 
Scott Fitzgerald" + genre: "Modernist" + published: 1925 + rating: 4.0 + tags: + - wealth + - american dream + - love + awards: + none: true + book10: + title: "Dune" + author: "Frank Herbert" + genre: "Science Fiction" + published: 1965 + rating: 4.6 + tags: + - politics + - desert + - ecology + awards: + hugo: true + nebula: true +tests: + - description: "testAggregates - count" + pipeline: + - Collection: books + - Aggregate: + - ExprWithAlias: + - Count + - "count" + assert_results: + - count: 10 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + count: + functionValue: + name: count + - mapValue: {} + name: aggregate + - description: "testAggregates - avg, count, max" + pipeline: + - Collection: books + - Where: + - Eq: + - Field: genre + - Constant: Science Fiction + - Aggregate: + - ExprWithAlias: + - Count + - "count" + - ExprWithAlias: + - Avg: + - Field: rating + - "avg_rating" + - ExprWithAlias: + - Max: + - Field: rating + - "max_rating" + assert_results: + - count: 2 + avg_rating: 4.4 + max_rating: 4.6 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: genre + - stringValue: Science Fiction + name: eq + name: where + - args: + - mapValue: + fields: + avg_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: avg + count: + functionValue: + name: count + max_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: maximum + - mapValue: {} + name: aggregate + - description: testGroupBysWithoutAccumulators + pipeline: + - Collection: books + - Where: + - Lt: + - Field: published + - Constant: 1900 + - Aggregate: + accumulators: [] + groups: [genre] + assert_error: ".* requires at least one accumulator" + - description: testGroupBysAndAggregate + pipeline: + - Collection: books + - Where: + - Lt: + - Field: published + - Constant: 1984 + - Aggregate: + accumulators: + - ExprWithAlias: + - Avg: + - Field: rating + - "avg_rating" + groups: [genre] + - Where: + - Gt: + - Field: avg_rating + - Constant: 4.3 + - Sort: + - Ordering: + - Field: avg_rating + - ASCENDING + assert_results: + - avg_rating: 4.4 + genre: Science Fiction + - avg_rating: 4.5 + genre: Romance + - avg_rating: 4.7 + genre: Fantasy + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: published + - integerValue: '1984' + name: lt + name: where + - args: + - mapValue: + fields: + avg_rating: + functionValue: + args: + - fieldReferenceValue: rating + name: avg + - mapValue: + fields: + genre: + fieldReferenceValue: genre + name: aggregate + - args: + - functionValue: + args: + - fieldReferenceValue: avg_rating + - doubleValue: 4.3 + name: gt + name: where + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: avg_rating + name: sort + - description: testMinMax + pipeline: + - Collection: books + - Aggregate: + - ExprWithAlias: + - Count + - "count" + - ExprWithAlias: + - Max: + - Field: rating + - "max_rating" + - ExprWithAlias: + - Min: + - Field: published + - "min_published" + assert_results: + - count: 10 + max_rating: 4.7 + min_published: 1813 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + count: + functionValue: + name: count + max_rating: + functionValue: + args: 
+ - fieldReferenceValue: rating + name: maximum + min_published: + functionValue: + args: + - fieldReferenceValue: published + name: minimum + - mapValue: {} + name: aggregate + - description: selectSpecificFields + pipeline: + - Collection: books + - Select: + - title + - author + - Sort: + - Ordering: + - Field: author + - ASCENDING + assert_results: + - title: "The Hitchhiker's Guide to the Galaxy" + author: "Douglas Adams" + - title: "The Great Gatsby" + author: "F. Scott Fitzgerald" + - title: "Dune" + author: "Frank Herbert" + - title: "Crime and Punishment" + author: "Fyodor Dostoevsky" + - title: "One Hundred Years of Solitude" + author: "Gabriel García Márquez" + - title: "1984" + author: "George Orwell" + - title: "To Kill a Mockingbird" + author: "Harper Lee" + - title: "The Lord of the Rings" + author: "J.R.R. Tolkien" + - title: "Pride and Prejudice" + author: "Jane Austen" + - title: "The Handmaid's Tale" + author: "Margaret Atwood" + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + author: + fieldReferenceValue: author + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: author + name: sort + - description: addAndRemoveFields + pipeline: + - Collection: books + - AddFields: + - ExprWithAlias: + - StrConcat: + - Field: author + - Constant: _ + - Field: title + - "author_title" + - ExprWithAlias: + - StrConcat: + - Field: title + - Constant: _ + - Field: author + - "title_author" + - RemoveFields: + - title_author + - tags + - awards + - rating + - title + - Field: published + - Field: genre + - Field: nestedField # Field does not exist, should be ignored + - Sort: + - Ordering: + - Field: author_title + - ASCENDING + assert_results: + - author: Douglas Adams + author_title: Douglas Adams_The Hitchhiker's Guide to the Galaxy + - author: F. Scott Fitzgerald + author_title: F. Scott Fitzgerald_The Great Gatsby + - author: Frank Herbert + author_title: Frank Herbert_Dune + - author: Fyodor Dostoevsky + author_title: Fyodor Dostoevsky_Crime and Punishment + - author: Gabriel García Márquez + author_title: Gabriel García Márquez_One Hundred Years of Solitude + - author: George Orwell + author_title: George Orwell_1984 + - author: Harper Lee + author_title: Harper Lee_To Kill a Mockingbird + - author: J.R.R. Tolkien + author_title: J.R.R. 
Tolkien_The Lord of the Rings + - author: Jane Austen + author_title: Jane Austen_Pride and Prejudice + - author: Margaret Atwood + author_title: Margaret Atwood_The Handmaid's Tale + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + author_title: + functionValue: + args: + - fieldReferenceValue: author + - stringValue: _ + - fieldReferenceValue: title + name: str_concat + title_author: + functionValue: + args: + - fieldReferenceValue: title + - stringValue: _ + - fieldReferenceValue: author + name: str_concat + name: add_fields + - args: + - fieldReferenceValue: title_author + - fieldReferenceValue: tags + - fieldReferenceValue: awards + - fieldReferenceValue: rating + - fieldReferenceValue: title + - fieldReferenceValue: published + - fieldReferenceValue: genre + - fieldReferenceValue: nestedField + name: remove_fields + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: author_title + name: sort + - description: whereByMultipleConditions + pipeline: + - Collection: books + - Where: + - And: + - Gt: + - Field: rating + - Constant: 4.5 + - Eq: + - Field: genre + - Constant: Science Fiction + assert_results: + - title: Dune + author: Frank Herbert + genre: Science Fiction + published: 1965 + rating: 4.6 + tags: + - politics + - desert + - ecology + awards: + hugo: true + nebula: true + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: rating + - doubleValue: 4.5 + name: gt + - functionValue: + args: + - fieldReferenceValue: genre + - stringValue: Science Fiction + name: eq + name: and + name: where + - description: whereByOrCondition + pipeline: + - Collection: books + - Where: + - Or: + - Eq: + - Field: genre + - Constant: Romance + - Eq: + - Field: genre + - Constant: Dystopian + - Select: + - title + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: "1984" + - title: Pride and Prejudice + - title: The Handmaid's Tale + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: genre + - stringValue: Romance + name: eq + - functionValue: + args: + - fieldReferenceValue: genre + - stringValue: Dystopian + name: eq + name: or + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + - description: testPipelineWithOffsetAndLimit + pipeline: + - Collection: books + - Sort: + - Ordering: + - Field: author + - ASCENDING + - Offset: 5 + - Limit: 3 + - Select: + - title + - author + assert_results: + - title: "1984" + author: George Orwell + - title: To Kill a Mockingbird + author: Harper Lee + - title: The Lord of the Rings + author: J.R.R. 
Tolkien + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: author + name: sort + - args: + - integerValue: '5' + name: offset + - args: + - integerValue: '3' + name: limit + - args: + - mapValue: + fields: + author: + fieldReferenceValue: author + title: + fieldReferenceValue: title + name: select + - description: testArrayContains + pipeline: + - Collection: books + - Where: + - ArrayContains: + - Field: tags + - Constant: comedy + assert_results: + - title: The Hitchhiker's Guide to the Galaxy + author: Douglas Adams + awards: + hugo: true + nebula: false + genre: Science Fiction + published: 1979 + rating: 4.2 + tags: ["comedy", "space", "adventure"] + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: tags + - stringValue: comedy + name: array_contains + name: where + - description: testArrayContainsAny + pipeline: + - Collection: books + - Where: + - ArrayContainsAny: + - Field: tags + - - Constant: comedy + - Constant: classic + - Select: + - title + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: Pride and Prejudice + - title: The Hitchhiker's Guide to the Galaxy + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: tags + - arrayValue: + values: + - stringValue: comedy + - stringValue: classic + name: array_contains_any + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + - description: testArrayContainsAll + pipeline: + - Collection: books + - Where: + - ArrayContainsAll: + - Field: tags + - - Constant: adventure + - Constant: magic + - Select: + - title + assert_results: + - title: The Lord of the Rings + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: tags + - arrayValue: + values: + - stringValue: adventure + - stringValue: magic + name: array_contains_all + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select + - description: testArrayLength + pipeline: + - Collection: books + - Select: + - ExprWithAlias: + - ArrayLength: + - Field: tags + - "tagsCount" + - Where: + - Eq: + - Field: tagsCount + - Constant: 3 + assert_results: # All documents have 3 tags + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + - tagsCount: 3 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + tagsCount: + functionValue: + args: + - fieldReferenceValue: tags + name: array_length + name: select + - args: + - functionValue: + args: + - fieldReferenceValue: tagsCount + - integerValue: '3' + name: eq + name: where + - description: testStrConcat + pipeline: + - Collection: books + - Sort: + - Ordering: + - Field: author + - ASCENDING + - Select: + - ExprWithAlias: + - StrConcat: + - Field: author + - Constant: " - " + - Field: title + - "bookInfo" + - Limit: 1 + assert_results: + - bookInfo: Douglas Adams - The 
Hitchhiker's Guide to the Galaxy + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: author + name: sort + - args: + - mapValue: + fields: + bookInfo: + functionValue: + args: + - fieldReferenceValue: author + - stringValue: ' - ' + - fieldReferenceValue: title + name: str_concat + name: select + - args: + - integerValue: '1' + name: limit + - description: testStartsWith + pipeline: + - Collection: books + - Where: + - StartsWith: + - Field: title + - Constant: The + - Select: + - title + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: The Great Gatsby + - title: The Handmaid's Tale + - title: The Hitchhiker's Guide to the Galaxy + - title: The Lord of the Rings + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: The + name: starts_with + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + - description: testEndsWith + pipeline: + - Collection: books + - Where: + - EndsWith: + - Field: title + - Constant: y + - Select: + - title + - Sort: + - Ordering: + - Field: title + - DESCENDING + assert_results: + - title: The Hitchhiker's Guide to the Galaxy + - title: The Great Gatsby + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: y + name: ends_with + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: descending + expression: + fieldReferenceValue: title + name: sort + - description: testLength + pipeline: + - Collection: books + - Select: + - ExprWithAlias: + - CharLength: + - Field: title + - "titleLength" + - title + - Where: + - Gt: + - Field: titleLength + - Constant: 20 + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - titleLength: 29 + title: One Hundred Years of Solitude + - titleLength: 36 + title: The Hitchhiker's Guide to the Galaxy + - titleLength: 21 + title: The Lord of the Rings + - titleLength: 21 + title: To Kill a Mockingbird + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + titleLength: + functionValue: + args: + - fieldReferenceValue: title + name: char_length + name: select + - args: + - functionValue: + args: + - fieldReferenceValue: titleLength + - integerValue: '20' + name: gt + name: where + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + - description: testStringFunctions - CharLength + pipeline: + - Collection: books + - Where: + - Eq: + - Field: author + - Constant: "Douglas Adams" + - Select: + - ExprWithAlias: + - CharLength: + - Field: title + - "title_length" + assert_results: + - title_length: 36 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: eq + name: where + - args: + - 
mapValue: + fields: + title_length: + functionValue: + args: + - fieldReferenceValue: title + name: char_length + name: select + - description: testStringFunctions - ByteLength + pipeline: + - Collection: books + - Where: + - Eq: + - Field: author + - Constant: Douglas Adams + - Select: + - ExprWithAlias: + - ByteLength: + - StrConcat: + - Field: title + - Constant: _银河系漫游指南 + - "title_byte_length" + assert_results: + - title_byte_length: 58 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: eq + name: where + - args: + - mapValue: + fields: + title_byte_length: + functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "_\u94F6\u6CB3\u7CFB\u6F2B\u6E38\u6307\u5357" + name: str_concat + name: byte_length + name: select + - description: testLike + pipeline: + - Collection: books + - Where: + - Like: + - Field: title + - Constant: "%Guide%" + - Select: + - title + assert_results: + - title: The Hitchhiker's Guide to the Galaxy + - description: testRegexContains + # Find titles that contain either "the" or "of" (case-insensitive) + pipeline: + - Collection: books + - Where: + - RegexContains: + - Field: title + - Constant: "(?i)(the|of)" + assert_count: 5 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: "(?i)(the|of)" + name: regex_contains + name: where + - description: testRegexMatches + # Find titles that contain either "the" or "of" (case-insensitive) + pipeline: + - Collection: books + - Where: + - RegexMatch: + - Field: title + - Constant: ".*(?i)(the|of).*" + assert_count: 5 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: ".*(?i)(the|of).*" + name: regex_match + name: where + - description: testArithmeticOperations + pipeline: + - Collection: books + - Where: + - Eq: + - Field: title + - Constant: To Kill a Mockingbird + - Select: + - ExprWithAlias: + - Add: + - Field: rating + - Constant: 1 + - "ratingPlusOne" + - ExprWithAlias: + - Subtract: + - Field: published + - Constant: 1900 + - "yearsSince1900" + - ExprWithAlias: + - Multiply: + - Field: rating + - Constant: 10 + - "ratingTimesTen" + - ExprWithAlias: + - Divide: + - Field: rating + - Constant: 2 + - "ratingDividedByTwo" + - ExprWithAlias: + - Multiply: + - Field: rating + - Constant: 20 + - "ratingTimes20" + - ExprWithAlias: + - Add: + - Field: rating + - Constant: 3 + - "ratingPlus3" + - ExprWithAlias: + - Mod: + - Field: rating + - Constant: 2 + - "ratingMod2" + assert_results: + - ratingPlusOne: 5.2 + yearsSince1900: 60 + ratingTimesTen: 42.0 + ratingDividedByTwo: 2.1 + ratingTimes20: 84 + ratingPlus3: 7.2 + ratingMod2: 0.20000000000000018 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: To Kill a Mockingbird + name: eq + name: where + - args: + - mapValue: + fields: + ratingDividedByTwo: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '2' + name: divide + ratingPlusOne: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '1' + name: add + ratingTimesTen: + functionValue: + args: + - fieldReferenceValue: 
rating + - integerValue: '10' + name: multiply + yearsSince1900: + functionValue: + args: + - fieldReferenceValue: published + - integerValue: '1900' + name: subtract + ratingTimes20: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '20' + name: multiply + ratingPlus3: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '3' + name: add + ratingMod2: + functionValue: + args: + - fieldReferenceValue: rating + - integerValue: '2' + name: mod + name: select + - description: testComparisonOperators + pipeline: + - Collection: books + - Where: + - And: + - Gt: + - Field: rating + - Constant: 4.2 + - Lte: + - Field: rating + - Constant: 4.5 + - Neq: + - Field: genre + - Constant: Science Fiction + - Select: + - rating + - title + - Sort: + - Ordering: + - title + - ASCENDING + assert_results: + - rating: 4.3 + title: Crime and Punishment + - rating: 4.3 + title: One Hundred Years of Solitude + - rating: 4.5 + title: Pride and Prejudice + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: rating + - doubleValue: 4.2 + name: gt + - functionValue: + args: + - fieldReferenceValue: rating + - doubleValue: 4.5 + name: lte + - functionValue: + args: + - fieldReferenceValue: genre + - stringValue: Science Fiction + name: neq + name: and + name: where + - args: + - mapValue: + fields: + rating: + fieldReferenceValue: rating + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + - description: testLogicalOperators + pipeline: + - Collection: books + - Where: + - Or: + - And: + - Gt: + - Field: rating + - Constant: 4.5 + - Eq: + - Field: genre + - Constant: Science Fiction + - Lt: + - Field: published + - Constant: 1900 + - Select: + - title + - Sort: + - Ordering: + - Field: title + - ASCENDING + assert_results: + - title: Crime and Punishment + - title: Dune + - title: Pride and Prejudice + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: rating + - doubleValue: 4.5 + name: gt + - functionValue: + args: + - fieldReferenceValue: genre + - stringValue: Science Fiction + name: eq + name: and + - functionValue: + args: + - fieldReferenceValue: published + - integerValue: '1900' + name: lt + name: or + name: where + - args: + - mapValue: + fields: + title: + fieldReferenceValue: title + name: select + - args: + - mapValue: + fields: + direction: + stringValue: ascending + expression: + fieldReferenceValue: title + name: sort + - description: testChecks + pipeline: + - Collection: books + - Where: + - Not: + - IsNaN: + - Field: rating + - Select: + - ExprWithAlias: + - Not: + - IsNaN: + - Field: rating + - "ratingIsNotNaN" + - Limit: 1 + assert_results: + - ratingIsNotNaN: true + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: rating + name: is_nan + name: not + name: where + - args: + - mapValue: + fields: + ratingIsNotNaN: + functionValue: + args: + - functionValue: + args: + - fieldReferenceValue: rating + name: is_nan + name: not + name: select + - args: + - integerValue: '1' + name: limit + - description: 
testLogicalMinMax + pipeline: + - Collection: books + - Where: + - Eq: + - Field: author + - Constant: Douglas Adams + - Select: + - ExprWithAlias: + - LogicalMax: + - Field: rating + - Constant: 4.5 + - "max_rating" + - ExprWithAlias: + - LogicalMax: + - Field: published + - Constant: 1900 + - "max_published" + assert_results: + - max_rating: 4.5 + max_published: 1979 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: eq + name: where + - args: + - mapValue: + fields: + max_published: + functionValue: + args: + - fieldReferenceValue: published + - integerValue: '1900' + name: logical_maximum + max_rating: + functionValue: + args: + - fieldReferenceValue: rating + - doubleValue: 4.5 + name: logical_maximum + name: select + - description: testMapGet + pipeline: + - Collection: books + - Sort: + - Ordering: + - Field: published + - DESCENDING + - Select: + - ExprWithAlias: + - MapGet: + - Field: awards + - Constant: hugo + - "hugoAward" + - Field: title + - Where: + - Eq: + - Field: hugoAward + - Constant: true + assert_results: + - hugoAward: true + title: The Hitchhiker's Guide to the Galaxy + - hugoAward: true + title: Dune + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + direction: + stringValue: descending + expression: + fieldReferenceValue: published + name: sort + - args: + - mapValue: + fields: + hugoAward: + functionValue: + args: + - fieldReferenceValue: awards + - stringValue: hugo + name: map_get + title: + fieldReferenceValue: title + name: select + - args: + - functionValue: + args: + - fieldReferenceValue: hugoAward + - booleanValue: true + name: eq + name: where + - description: testNestedFields + pipeline: + - Collection: books + - Where: + - Eq: + - Field: awards.hugo + - Constant: true + - Sort: + - Ordering: + - Field: title + - DESCENDING + - Select: + - title + - Field: awards.hugo + assert_results: + - title: The Hitchhiker's Guide to the Galaxy + awards.hugo: true + - title: Dune + awards.hugo: true + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: awards.hugo + - booleanValue: true + name: eq + name: where + - args: + - mapValue: + fields: + direction: + stringValue: descending + expression: + fieldReferenceValue: title + name: sort + - args: + - mapValue: + fields: + awards.hugo: + fieldReferenceValue: awards.hugo + title: + fieldReferenceValue: title + name: select + - description: testSampleLimit + pipeline: + - Collection: books + - Sample: 3 + assert_count: 3 # Results will vary due to randomness + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - integerValue: '3' + - stringValue: documents + name: sample + - description: testSamplePercentage + pipeline: + - Collection: books + - Sample: + - SampleOptions: + - 0.6 + - percent + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - doubleValue: 0.6 + - stringValue: percent + name: sample + - description: testUnion + pipeline: + - Collection: books + - Union: + - Pipeline: + - Collection: books + assert_count: 20 # Results will be duplicated + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - pipelineValue: + stages: + 
- args: + - referenceValue: /books + name: collection + name: union + - description: testUnnest + pipeline: + - Collection: books + - Where: + - Eq: + - Field: title + - Constant: The Hitchhiker's Guide to the Galaxy + - Unnest: + - tags + - tags_alias + - Select: tags_alias + assert_results: + - tags_alias: comedy + - tags_alias: space + - tags_alias: adventure + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: The Hitchhiker's Guide to the Galaxy + name: eq + name: where + - args: + - fieldReferenceValue: tags + - fieldReferenceValue: tags_alias + name: unnest + - args: + - mapValue: + fields: + tags_alias: + fieldReferenceValue: tags_alias + name: select diff --git a/tests/system/test__helpers.py b/tests/system/test__helpers.py index d6ee9b944..c146a5763 100644 --- a/tests/system/test__helpers.py +++ b/tests/system/test__helpers.py @@ -10,7 +10,13 @@ RANDOM_ID_REGEX = re.compile("^[a-zA-Z0-9]{20}$") MISSING_DOCUMENT = "No document to update: " DOCUMENT_EXISTS = "Document already exists: " +ENTERPRISE_MODE_ERROR = "only allowed on ENTERPRISE mode" UNIQUE_RESOURCE_ID = unique_resource_id("-") EMULATOR_CREDS = EmulatorCreds() FIRESTORE_EMULATOR = os.environ.get(_FIRESTORE_EMULATOR_HOST) is not None FIRESTORE_OTHER_DB = os.environ.get("SYSTEM_TESTS_DATABASE", "system-tests-named-db") +FIRESTORE_ENTERPRISE_DB = os.environ.get("ENTERPRISE_DATABASE", "enterprise-db") + +# run all tests against default database, and a named database +# TODO: add enterprise mode when GA (RunQuery not currently supported) +TEST_DATABASES = [None, FIRESTORE_OTHER_DB] diff --git a/tests/system/test_pipeline_acceptance.py b/tests/system/test_pipeline_acceptance.py new file mode 100644 index 000000000..9d44bbc57 --- /dev/null +++ b/tests/system/test_pipeline_acceptance.py @@ -0,0 +1,285 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
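+
+# For reference, the yaml cases above follow this rough shape (a minimal,
+# illustrative example; `exampleLimit` is a hypothetical case name — the real
+# cases are defined in pipeline_e2e.yaml):
+#
+#   tests:
+#     - description: exampleLimit
+#       pipeline:
+#         - Collection: books
+#         - Limit: 1
+#       assert_count: 1
+#
+# Each case is parsed by parse_pipeline() below and executed roughly as:
+#
+#   pipeline = parse_pipeline(client, case["pipeline"])
+#   results = [snapshot.data() for snapshot in pipeline.stream()]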
+""" +This file loads and executes yaml-encoded test cases from pipeline_e2e.yaml +""" + +from __future__ import annotations +import os +import pytest +import yaml +import re +from typing import Any + +from google.protobuf.json_format import MessageToDict + +from google.cloud.firestore_v1 import _pipeline_stages as stages +from google.cloud.firestore_v1 import pipeline_expressions +from google.api_core.exceptions import GoogleAPIError + +from google.cloud.firestore import Client, AsyncClient + +from test__helpers import FIRESTORE_ENTERPRISE_DB + +FIRESTORE_PROJECT = os.environ.get("GCLOUD_PROJECT") + +test_dir_name = os.path.dirname(__file__) + + +def yaml_loader(field="tests", file_name="pipeline_e2e.yaml"): + """ + Helper to load test cases or data from yaml file + """ + with open(f"{test_dir_name}/{file_name}") as f: + test_cases = yaml.safe_load(f) + return test_cases[field] + + +@pytest.mark.parametrize( + "test_dict", + [t for t in yaml_loader() if "assert_proto" in t], + ids=lambda x: f"{x.get('description', '')}", +) +def test_pipeline_parse_proto(test_dict, client): + """ + Finds assert_proto statements in yaml, and compares generated proto against expected value + """ + expected_proto = test_dict.get("assert_proto", None) + pipeline = parse_pipeline(client, test_dict["pipeline"]) + # check if proto matches as expected + if expected_proto: + got_proto = MessageToDict(pipeline._to_pb()._pb) + assert yaml.dump(expected_proto) == yaml.dump(got_proto) + + +@pytest.mark.parametrize( + "test_dict", + [t for t in yaml_loader() if "assert_error" in t], + ids=lambda x: f"{x.get('description', '')}", +) +def test_pipeline_expected_errors(test_dict, client): + """ + Finds assert_error statements in yaml, and ensures the pipeline raises the expected error + """ + error_regex = test_dict["assert_error"] + pipeline = parse_pipeline(client, test_dict["pipeline"]) + # check if server responds as expected + with pytest.raises(GoogleAPIError) as err: + pipeline.execute() + found_error = str(err.value) + match = re.search(error_regex, found_error) + assert match, f"error '{found_error}' does not match '{error_regex}'" + + +@pytest.mark.parametrize( + "test_dict", + [t for t in yaml_loader() if "assert_results" in t or "assert_count" in t], + ids=lambda x: f"{x.get('description', '')}", +) +def test_pipeline_results(test_dict, client): + """ + Ensure pipeline returns expected results + """ + expected_results = test_dict.get("assert_results", None) + expected_count = test_dict.get("assert_count", None) + pipeline = parse_pipeline(client, test_dict["pipeline"]) + # check if server responds as expected + got_results = [snapshot.data() for snapshot in pipeline.stream()] + if expected_results: + assert got_results == expected_results + if expected_count is not None: + assert len(got_results) == expected_count + + +@pytest.mark.parametrize( + "test_dict", + [t for t in yaml_loader() if "assert_error" in t], + ids=lambda x: f"{x.get('description', '')}", +) +@pytest.mark.asyncio +async def test_pipeline_expected_errors_async(test_dict, async_client): + """ + Finds assert_error statements in yaml, and ensures the pipeline raises the expected error + """ + error_regex = test_dict["assert_error"] + pipeline = parse_pipeline(async_client, test_dict["pipeline"]) + # check if server responds as expected + with pytest.raises(GoogleAPIError) as err: + await pipeline.execute() + found_error = str(err.value) + match = re.search(error_regex, found_error) + assert match, f"error '{found_error}' does not match 
'{error_regex}'" + + +@pytest.mark.parametrize( + "test_dict", + [t for t in yaml_loader() if "assert_results" in t or "assert_count" in t], + ids=lambda x: f"{x.get('description', '')}", +) +@pytest.mark.asyncio +async def test_pipeline_results_async(test_dict, async_client): + """ + Ensure pipeline returns expected results + """ + expected_results = test_dict.get("assert_results", None) + expected_count = test_dict.get("assert_count", None) + pipeline = parse_pipeline(async_client, test_dict["pipeline"]) + # check if server responds as expected + got_results = [snapshot.data() async for snapshot in pipeline.stream()] + if expected_results: + assert got_results == expected_results + if expected_count is not None: + assert len(got_results) == expected_count + + +################################################################################# +# Helpers & Fixtures +################################################################################# + + +def parse_pipeline(client, pipeline: list[dict[str, Any], str]): + """ + parse a yaml list of pipeline stages into firestore._pipeline_stages.Stage classes + """ + result_list = [] + for stage in pipeline: + # stage will be either a map of the stage_name and its args, or just the stage_name itself + stage_name: str = stage if isinstance(stage, str) else list(stage.keys())[0] + stage_cls: type[stages.Stage] = getattr(stages, stage_name) + # find arguments if given + if isinstance(stage, dict): + stage_yaml_args = stage[stage_name] + stage_obj = _apply_yaml_args(stage_cls, client, stage_yaml_args) + else: + # yaml has no arguments + stage_obj = stage_cls() + result_list.append(stage_obj) + return client._pipeline_cls._create_with_stages(client, *result_list) + + +def _parse_expressions(client, yaml_element: Any): + """ + Turn yaml objects into pipeline expressions or native python object arguments + """ + if isinstance(yaml_element, list): + return [_parse_expressions(client, v) for v in yaml_element] + elif isinstance(yaml_element, dict): + if len(yaml_element) == 1 and _is_expr_string(next(iter(yaml_element))): + # build pipeline expressions if possible + cls_str = next(iter(yaml_element)) + cls = getattr(pipeline_expressions, cls_str) + yaml_args = yaml_element[cls_str] + return _apply_yaml_args(cls, client, yaml_args) + elif len(yaml_element) == 1 and _is_stage_string(next(iter(yaml_element))): + # build pipeline stage if possible (eg, for SampleOptions) + cls_str = next(iter(yaml_element)) + cls = getattr(stages, cls_str) + yaml_args = yaml_element[cls_str] + return _apply_yaml_args(cls, client, yaml_args) + elif len(yaml_element) == 1 and list(yaml_element)[0] == "Pipeline": + # find Pipeline objects for Union expressions + other_ppl = yaml_element["Pipeline"] + return parse_pipeline(client, other_ppl) + else: + # otherwise, return dict + return { + _parse_expressions(client, k): _parse_expressions(client, v) + for k, v in yaml_element.items() + } + elif _is_expr_string(yaml_element): + return getattr(pipeline_expressions, yaml_element)() + else: + return yaml_element + + +def _apply_yaml_args(cls, client, yaml_args): + """ + Helper to instantiate a class with yaml arguments. The arguments will be applied + as positional or keyword arguments, based on type + """ + if isinstance(yaml_args, dict): + return cls(**_parse_expressions(client, yaml_args)) + elif isinstance(yaml_args, list): + # yaml has an array of arguments. 
Treat as args + return cls(*_parse_expressions(client, yaml_args)) + else: + # yaml has a single argument + return cls(_parse_expressions(client, yaml_args)) + + +def _is_expr_string(yaml_str): + """ + Returns true if a string represents a class in pipeline_expressions + """ + return ( + isinstance(yaml_str, str) + and yaml_str[0].isupper() + and hasattr(pipeline_expressions, yaml_str) + ) + + +def _is_stage_string(yaml_str): + """ + Returns true if a string represents a class in pipeline_stages + """ + return ( + isinstance(yaml_str, str) + and yaml_str[0].isupper() + and hasattr(stages, yaml_str) + ) + + +@pytest.fixture(scope="module") +def event_loop(): + """Change event_loop fixture to module level.""" + import asyncio + + policy = asyncio.get_event_loop_policy() + loop = policy.new_event_loop() + yield loop + loop.close() + + +@pytest.fixture(scope="module") +def client(): + """ + Build a client to use for requests + """ + client = Client(project=FIRESTORE_PROJECT, database=FIRESTORE_ENTERPRISE_DB) + data = yaml_loader("data") + try: + # setup data + batch = client.batch() + for collection_name, documents in data.items(): + collection_ref = client.collection(collection_name) + for document_id, document_data in documents.items(): + document_ref = collection_ref.document(document_id) + batch.set(document_ref, document_data) + batch.commit() + yield client + finally: + # clear data + for collection_name, documents in data.items(): + collection_ref = client.collection(collection_name) + for document_id in documents: + document_ref = collection_ref.document(document_id) + document_ref.delete() + + +@pytest.fixture(scope="module") +def async_client(client): + """ + Build an async client to use for AsyncPipeline requests + """ + yield AsyncClient(project=client.project, database=client._database) diff --git a/tests/system/test_system.py b/tests/system/test_system.py index c66340de1..9909fb05e 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -38,11 +38,11 @@ EMULATOR_CREDS, FIRESTORE_CREDS, FIRESTORE_EMULATOR, - FIRESTORE_OTHER_DB, FIRESTORE_PROJECT, MISSING_DOCUMENT, RANDOM_ID_REGEX, UNIQUE_RESOURCE_ID, + TEST_DATABASES, ) @@ -80,13 +80,13 @@ def cleanup(): operation() -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_collections(client, database): collections = list(client.collections()) assert isinstance(collections, list) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB]) +@pytest.mark.parametrize("database", TEST_DATABASES) def test_collections_w_import(database): from google.cloud import firestore @@ -103,7 +103,7 @@ def test_collections_w_import(database): FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_collection_stream_or_get_w_no_explain_options(database, query_docs, method): from google.cloud.firestore_v1.query_profile import QueryExplainError @@ -125,7 +125,7 @@ def test_collection_stream_or_get_w_no_explain_options(database, query_docs, met FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." 
) @pytest.mark.parametrize("method", ["get", "stream"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_collection_stream_or_get_w_explain_options_analyze_false( database, method, query_docs ): @@ -163,7 +163,7 @@ def test_collection_stream_or_get_w_explain_options_analyze_false( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["get", "stream"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_collection_stream_or_get_w_explain_options_analyze_true( database, method, query_docs ): @@ -217,7 +217,7 @@ def test_collection_stream_or_get_w_explain_options_analyze_true( assert len(execution_stats.debug_stats) > 0 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_collections_w_read_time(client, cleanup, database): first_collection_id = "doc-create" + UNIQUE_RESOURCE_ID first_document_id = "doc" + UNIQUE_RESOURCE_ID @@ -228,7 +228,6 @@ def test_collections_w_read_time(client, cleanup, database): data = {"status": "new"} write_result = first_document.create(data) read_time = write_result.update_time - num_collections = len(list(client.collections())) second_collection_id = "doc-create" + UNIQUE_RESOURCE_ID + "-2" second_document_id = "doc" + UNIQUE_RESOURCE_ID + "-2" @@ -238,7 +237,6 @@ def test_collections_w_read_time(client, cleanup, database): # Test that listing current collections does have the second id. curr_collections = list(client.collections()) - assert len(curr_collections) > num_collections ids = [collection.id for collection in curr_collections] assert second_collection_id in ids assert first_collection_id in ids @@ -250,7 +248,7 @@ def test_collections_w_read_time(client, cleanup, database): assert first_collection_id in ids -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_create_document(client, cleanup, database): now = datetime.datetime.now(tz=datetime.timezone.utc) collection_id = "doc-create" + UNIQUE_RESOURCE_ID @@ -295,7 +293,7 @@ def test_create_document(client, cleanup, database): assert stored_data == expected_data -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_create_document_w_vector(client, cleanup, database): collection_id = "doc-create" + UNIQUE_RESOURCE_ID document1 = client.document(collection_id, "doc1") @@ -326,7 +324,7 @@ def on_snapshot(docs, changes, read_time): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @pytest.mark.parametrize( "distance_measure", [ @@ -355,7 +353,7 @@ def test_vector_search_collection(client, database, distance_measure): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @pytest.mark.parametrize( "distance_measure", [ @@ -384,7 +382,7 @@ def 
test_vector_search_collection_with_filter(client, database, distance_measure @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_vector_search_collection_with_distance_parameters_euclid(client, database): # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" @@ -414,7 +412,7 @@ def test_vector_search_collection_with_distance_parameters_euclid(client, databa @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_vector_search_collection_with_distance_parameters_cosine(client, database): # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" @@ -444,7 +442,7 @@ def test_vector_search_collection_with_distance_parameters_cosine(client, databa @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @pytest.mark.parametrize( "distance_measure", [ @@ -480,7 +478,7 @@ def test_vector_search_collection_group(client, database, distance_measure): DistanceMeasure.COSINE, ], ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_vector_search_collection_group_with_filter(client, database, distance_measure): # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" @@ -502,7 +500,7 @@ def test_vector_search_collection_group_with_filter(client, database, distance_m @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_vector_search_collection_group_with_distance_parameters_euclid( client, database ): @@ -534,7 +532,7 @@ def test_vector_search_collection_group_with_distance_parameters_euclid( @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_vector_search_collection_group_with_distance_parameters_cosine( client, database ): @@ -569,7 +567,7 @@ def test_vector_search_collection_group_with_distance_parameters_cosine( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_vector_query_stream_or_get_w_no_explain_options(client, database, method): from google.cloud.firestore_v1.query_profile import QueryExplainError @@ -599,7 +597,7 @@ def test_vector_query_stream_or_get_w_no_explain_options(client, database, metho FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." 
) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_vector_query_stream_or_get_w_explain_options_analyze_true( client, database, method ): @@ -668,7 +666,7 @@ def test_vector_query_stream_or_get_w_explain_options_analyze_true( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_vector_query_stream_or_get_w_explain_options_analyze_false( client, database, method ): @@ -715,7 +713,7 @@ def test_vector_query_stream_or_get_w_explain_options_analyze_false( explain_metrics.execution_stats -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_create_document_w_subcollection(client, cleanup, database): collection_id = "doc-create-sub" + UNIQUE_RESOURCE_ID document_id = "doc" + UNIQUE_RESOURCE_ID @@ -741,7 +739,7 @@ def assert_timestamp_less(timestamp_pb1, timestamp_pb2): assert timestamp_pb1 < timestamp_pb2 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_document_collections_w_read_time(client, cleanup, database): collection_id = "doc-create-sub" + UNIQUE_RESOURCE_ID document_id = "doc" + UNIQUE_RESOURCE_ID @@ -777,7 +775,7 @@ def test_document_collections_w_read_time(client, cleanup, database): ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_no_document(client, database): document_id = "no_document" + UNIQUE_RESOURCE_ID document = client.document("abcde", document_id) @@ -785,7 +783,7 @@ def test_no_document(client, database): assert snapshot.to_dict() is None -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_document_set(client, cleanup, database): document_id = "for-set" + UNIQUE_RESOURCE_ID document = client.document("i-did-it", document_id) @@ -815,7 +813,7 @@ def test_document_set(client, cleanup, database): assert snapshot2.update_time == write_result2.update_time -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_document_integer_field(client, cleanup, database): document_id = "for-set" + UNIQUE_RESOURCE_ID document = client.document("i-did-it", document_id) @@ -832,7 +830,7 @@ def test_document_integer_field(client, cleanup, database): assert snapshot.to_dict() == expected -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_document_set_merge(client, cleanup, database): document_id = "for-set" + UNIQUE_RESOURCE_ID document = client.document("i-did-it", document_id) @@ -865,7 +863,7 @@ def test_document_set_merge(client, cleanup, database): assert snapshot2.update_time == write_result2.update_time -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def 
test_document_set_w_int_field(client, cleanup, database): document_id = "set-int-key" + UNIQUE_RESOURCE_ID document = client.document("i-did-it", document_id) @@ -889,7 +887,7 @@ def test_document_set_w_int_field(client, cleanup, database): assert snapshot1.to_dict() == data -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_document_update_w_int_field(client, cleanup, database): # Attempt to reproduce #5489. document_id = "update-int-key" + UNIQUE_RESOURCE_ID @@ -917,7 +915,7 @@ def test_document_update_w_int_field(client, cleanup, database): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Internal Issue b/137867104") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_update_document(client, cleanup, database): document_id = "for-update" + UNIQUE_RESOURCE_ID document = client.document("made", document_id) @@ -989,7 +987,7 @@ def check_snapshot(snapshot, document, data, write_result): assert snapshot.update_time == write_result.update_time -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_document_get(client, cleanup, database): now = datetime.datetime.now(tz=datetime.timezone.utc) document_id = "for-get" + UNIQUE_RESOURCE_ID @@ -1015,7 +1013,7 @@ def test_document_get(client, cleanup, database): check_snapshot(snapshot, document, data, write_result) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_document_delete(client, cleanup, database): document_id = "deleted" + UNIQUE_RESOURCE_ID document = client.document("here-to-be", document_id) @@ -1052,7 +1050,7 @@ def test_document_delete(client, cleanup, database): assert_timestamp_less(delete_time3, delete_time4) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_collection_add(client, cleanup, database): # TODO(microgen): list_documents is returning a generator, not a list. # Consider if this is desired. Also, Document isn't hashable. @@ -1141,7 +1139,7 @@ def test_collection_add(client, cleanup, database): assert set(collection3.list_documents()) == {document_ref5} -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_list_collections_with_read_time(client, cleanup, database): # TODO(microgen): list_documents is returning a generator, not a list. # Consider if this is desired. Also, Document isn't hashable. 
@@ -1166,7 +1164,7 @@ def test_list_collections_with_read_time(client, cleanup, database): } -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_unicode_doc(client, cleanup, database): collection_id = "coll-unicode" + UNIQUE_RESOURCE_ID collection = client.collection(collection_id) @@ -1233,7 +1231,7 @@ def query(collection): return collection.where(filter=FieldFilter("a", "==", 1)) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_legacy_where(query_docs, database): """Assert the legacy code still works and returns value""" collection, stored, allowed_vals = query_docs @@ -1249,7 +1247,7 @@ def test_query_stream_legacy_where(query_docs, database): assert value["a"] == 1 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_w_simple_field_eq_op(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("a", "==", 1)) @@ -1260,7 +1258,7 @@ def test_query_stream_w_simple_field_eq_op(query_docs, database): assert value["a"] == 1 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_w_simple_field_array_contains_op(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("c", "array_contains", 1)) @@ -1271,7 +1269,7 @@ def test_query_stream_w_simple_field_array_contains_op(query_docs, database): assert value["a"] == 1 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_w_simple_field_in_op(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1283,7 +1281,7 @@ def test_query_stream_w_simple_field_in_op(query_docs, database): assert value["a"] == 1 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_w_not_eq_op(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.sum", "!=", 4)) @@ -1305,7 +1303,7 @@ def test_query_stream_w_not_eq_op(query_docs, database): assert expected_ab_pairs == ab_pairs2 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_w_simple_not_in_op(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1317,7 +1315,7 @@ def test_query_stream_w_simple_not_in_op(query_docs, database): assert len(values) == 22 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_w_simple_field_array_contains_any_op(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1331,7 +1329,7 @@ def test_query_stream_w_simple_field_array_contains_any_op(query_docs, database) assert value["a"] == 1 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], 
indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_w_order_by(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.order_by("b", direction=firestore.Query.DESCENDING) @@ -1345,7 +1343,7 @@ def test_query_stream_w_order_by(query_docs, database): assert sorted(b_vals, reverse=True) == b_vals -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_w_field_path(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.sum", ">", 4)) @@ -1367,7 +1365,7 @@ def test_query_stream_w_field_path(query_docs, database): assert expected_ab_pairs == ab_pairs2 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_w_start_end_cursor(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1383,7 +1381,7 @@ def test_query_stream_w_start_end_cursor(query_docs, database): assert value["a"] == num_vals - 2 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_wo_results(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1392,7 +1390,7 @@ def test_query_stream_wo_results(query_docs, database): assert len(values) == 0 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_w_projection(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1409,7 +1407,7 @@ def test_query_stream_w_projection(query_docs, database): assert expected == value -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_w_multiple_filters(query_docs, database): collection, stored, allowed_vals = query_docs query = collection.where(filter=FieldFilter("stats.product", ">", 5)).where( @@ -1429,7 +1427,7 @@ def test_query_stream_w_multiple_filters(query_docs, database): assert pair in matching_pairs -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_w_offset(query_docs, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1449,7 +1447,7 @@ def test_query_stream_w_offset(query_docs, database): FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_or_get_w_no_explain_options(query_docs, database, method): from google.cloud.firestore_v1.query_profile import QueryExplainError @@ -1471,7 +1469,7 @@ def test_query_stream_or_get_w_no_explain_options(query_docs, database, method): FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." 
) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_or_get_w_explain_options_analyze_true( query_docs, database, method ): @@ -1531,7 +1529,7 @@ def test_query_stream_or_get_w_explain_options_analyze_true( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) @pytest.mark.parametrize("method", ["stream", "get"]) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_or_get_w_explain_options_analyze_false( query_docs, database, method ): @@ -1571,7 +1569,7 @@ def test_query_stream_or_get_w_explain_options_analyze_false( explain_metrics.execution_stats -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_stream_w_read_time(query_docs, cleanup, database): collection, stored, allowed_vals = query_docs num_vals = len(allowed_vals) @@ -1609,7 +1607,7 @@ def test_query_stream_w_read_time(query_docs, cleanup, database): assert new_values[new_ref.id] == new_data -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_with_order_dot_key(client, cleanup, database): db = client collection_id = "collek" + UNIQUE_RESOURCE_ID @@ -1622,15 +1620,16 @@ def test_query_with_order_dot_key(client, cleanup, database): query = collection.order_by("wordcount.page1").limit(3) data = [doc.to_dict()["wordcount"]["page1"] for doc in query.stream()] assert [100, 110, 120] == data - for snapshot in collection.order_by("wordcount.page1").limit(3).stream(): + query2 = collection.order_by("wordcount.page1").limit(3) + for snapshot in query2.stream(): last_value = snapshot.get("wordcount.page1") cursor_with_nested_keys = {"wordcount": {"page1": last_value}} - found = list( + query3 = ( collection.order_by("wordcount.page1") .start_after(cursor_with_nested_keys) .limit(3) - .stream() ) + found = list(query3.stream()) found_data = [ {"count": 30, "wordcount": {"page1": 130}}, {"count": 40, "wordcount": {"page1": 140}}, @@ -1638,16 +1637,16 @@ def test_query_with_order_dot_key(client, cleanup, database): ] assert found_data == [snap.to_dict() for snap in found] cursor_with_dotted_paths = {"wordcount.page1": last_value} - cursor_with_key_data = list( + query4 = ( collection.order_by("wordcount.page1") .start_after(cursor_with_dotted_paths) .limit(3) - .stream() ) + cursor_with_key_data = list(query4.stream()) assert found_data == [snap.to_dict() for snap in cursor_with_key_data] -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_query_unary(client, cleanup, database): collection_name = "unary" + UNIQUE_RESOURCE_ID collection = client.collection(collection_name) @@ -1702,7 +1701,7 @@ def test_query_unary(client, cleanup, database): assert snapshot3.to_dict() == {field_name: 123} -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) def test_collection_group_queries(client, cleanup, database): collection_group = "b" + UNIQUE_RESOURCE_ID @@ -1735,7 +1734,7 @@ def test_collection_group_queries(client, cleanup, 
@@ -1735,7 +1734,7 @@ def test_collection_group_queries(client, cleanup, database):
     assert found == expected


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_collection_group_queries_startat_endat(client, cleanup, database):
     collection_group = "b" + UNIQUE_RESOURCE_ID
@@ -1778,7 +1777,7 @@ def test_collection_group_queries_startat_endat(client, cleanup, database):
     assert found == set(["cg-doc2"])


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_collection_group_queries_filters(client, cleanup, database):
     collection_group = "b" + UNIQUE_RESOURCE_ID
@@ -1847,7 +1846,7 @@ def test_collection_group_queries_filters(client, cleanup, database):
 @pytest.mark.skipif(
     FIRESTORE_EMULATOR, reason="PartitionQuery not implemented in emulator"
 )
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_partition_query_no_partitions(client, cleanup, database):
     collection_group = "b" + UNIQUE_RESOURCE_ID
@@ -1882,7 +1881,7 @@ def test_partition_query_no_partitions(client, cleanup, database):
 @pytest.mark.skipif(
     FIRESTORE_EMULATOR, reason="PartitionQuery not implemented in emulator"
 )
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_partition_query(client, cleanup, database):
     collection_group = "b" + UNIQUE_RESOURCE_ID
     n_docs = 128 * 2 + 127  # Minimum partition size is 128
@@ -1910,7 +1909,7 @@ def test_partition_query(client, cleanup, database):
 @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Internal Issue b/137865992")
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_get_all(client, cleanup, database):
     collection_name = "get-all" + UNIQUE_RESOURCE_ID
@@ -1986,7 +1985,7 @@ def test_get_all(client, cleanup, database):
     check_snapshot(snapshot3, document3, data3, write_result3)


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_batch(client, cleanup, database):
     collection_name = "batch" + UNIQUE_RESOURCE_ID
@@ -2032,7 +2031,7 @@ def test_batch(client, cleanup, database):
     assert not document3.get().exists


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_live_bulk_writer(client, cleanup, database):
     from google.cloud.firestore_v1.bulk_writer import BulkWriter
     from google.cloud.firestore_v1.client import Client
@@ -2056,7 +2055,7 @@ def test_live_bulk_writer(client, cleanup, database):
     assert len(col.get()) == 50


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_watch_document(client, cleanup, database):
     db = client
     collection_ref = db.collection("wd-users" + UNIQUE_RESOURCE_ID)
@@ -2093,7 +2092,7 @@ def on_snapshot(docs, changes, read_time):
     )


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_watch_collection(client, cleanup, database):
     db = client
     collection_ref = db.collection("wc-users" + UNIQUE_RESOURCE_ID)
@@ -2130,7 +2129,7 @@ def on_snapshot(docs, changes, read_time):
     )


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_watch_query(client, cleanup, database):
     db = client
     collection_ref = db.collection("wq-users" + UNIQUE_RESOURCE_ID)
@@ -2148,9 +2147,8 @@ def on_snapshot(docs, changes, read_time):
         on_snapshot.called_count += 1

         # A snapshot should return the same thing as if a query ran now.
-        query_ran = collection_ref.where(
-            filter=FieldFilter("first", "==", "Ada")
-        ).stream()
+        query_ran_query = collection_ref.where(filter=FieldFilter("first", "==", "Ada"))
+        query_ran = query_ran_query.stream()
         assert len(docs) == len([i for i in query_ran])

     on_snapshot.called_count = 0
@@ -2172,7 +2170,7 @@ def on_snapshot(docs, changes, read_time):
     )


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_array_union(client, cleanup, database):
     doc_ref = client.document("gcp-7523", "test-document")
     cleanup(doc_ref.delete)
@@ -2319,7 +2317,7 @@ def _do_recursive_delete(client, bulk_writer, empty_philosophers=False):
         ), f"Snapshot at Socrates{path} should have been deleted"


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_recursive_delete_parallelized(client, cleanup, database):
     from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions, SendMode
@@ -2327,7 +2325,7 @@ def test_recursive_delete_parallelized(client, cleanup, database):
     _do_recursive_delete(client, bw)


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_recursive_delete_serialized(client, cleanup, database):
     from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions, SendMode
@@ -2335,7 +2333,7 @@ def test_recursive_delete_serialized(client, cleanup, database):
     _do_recursive_delete(client, bw)


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_recursive_delete_parallelized_empty(client, cleanup, database):
     from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions, SendMode
@@ -2343,7 +2341,7 @@ def test_recursive_delete_parallelized_empty(client, cleanup, database):
     _do_recursive_delete(client, bw, empty_philosophers=True)


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_recursive_delete_serialized_empty(client, cleanup, database):
     from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions, SendMode
@@ -2351,12 +2349,13 @@ def test_recursive_delete_serialized_empty(client, cleanup, database):
     _do_recursive_delete(client, bw, empty_philosophers=True)


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_recursive_query(client, cleanup, database):
     col_id: str = f"philosophers-recursive-query{UNIQUE_RESOURCE_ID}"
     _persist_documents(client, col_id, philosophers_data_set, cleanup)

-    ids = [doc.id for doc in client.collection_group(col_id).recursive().get()]
+    query = client.collection_group(col_id).recursive()
+    ids = [doc.id for doc in query.get()]

     expected_ids = [
         # Aristotle doc and subdocs
@@ -2390,14 +2389,15 @@ def test_recursive_query(client, cleanup, database):
         assert ids[index] == expected_ids[index], error_msg


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_nested_recursive_query(client, cleanup, database):
     col_id: str = f"philosophers-nested-recursive-query{UNIQUE_RESOURCE_ID}"
     _persist_documents(client, col_id, philosophers_data_set, cleanup)

     collection_ref = client.collection(col_id)
     aristotle = collection_ref.document("Aristotle")
-    ids = [doc.id for doc in aristotle.collection("pets").recursive().get()]
+    query = aristotle.collection("pets").recursive()
+    ids = [doc.id for doc in query.get()]

     expected_ids = [
         # Aristotle pets
@@ -2414,7 +2414,7 @@ def test_nested_recursive_query(client, cleanup, database):
         assert ids[index] == expected_ids[index], error_msg


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_chunked_query(client, cleanup, database):
     col = client.collection(f"chunked-test{UNIQUE_RESOURCE_ID}")
     for index in range(10):
@@ -2429,7 +2429,7 @@ def test_chunked_query(client, cleanup, database):
     assert len(next(iter)) == 1


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_chunked_query_smaller_limit(client, cleanup, database):
     col = client.collection(f"chunked-test-smaller-limit{UNIQUE_RESOURCE_ID}")
     for index in range(10):
@@ -2441,7 +2441,7 @@ def test_chunked_query_smaller_limit(client, cleanup, database):
     assert len(next(iter)) == 5


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_chunked_and_recursive(client, cleanup, database):
     col_id = f"chunked-recursive-test{UNIQUE_RESOURCE_ID}"
     documents = [
@@ -2490,7 +2490,7 @@ def test_chunked_and_recursive(client, cleanup, database):
     assert [doc.id for doc in next(iter)] == page_3_ids


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_watch_query_order(client, cleanup, database):
     db = client
     collection_ref = db.collection("users")
@@ -2566,7 +2566,7 @@ def on_snapshot(docs, changes, read_time):
     )


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_repro_429(client, cleanup, database):
     # See: https://github.com/googleapis/python-firestore/issues/429
     now = datetime.datetime.now(tz=datetime.timezone.utc)
@@ -2594,7 +2594,7 @@ def test_repro_429(client, cleanup, database):
         print(f"id: {snapshot.id}")


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_repro_391(client, cleanup, database):
     # See: https://github.com/googleapis/python-firestore/issues/391
     now = datetime.datetime.now(tz=datetime.timezone.utc)
@@ -2609,7 +2609,7 @@ def test_repro_391(client, cleanup, database):
     assert len(set(collection.stream())) == len(document_ids)


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_count_query_get_default_alias(query, database):
     count_query = query.count()
     result = count_query.get()
@@ -2618,7 +2618,7 @@ def test_count_query_get_default_alias(query, database):
         assert r.alias == "field_1"


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_count_query_get_with_alias(query, database):
     count_query = query.count(alias="total")
     result = count_query.get()
@@ -2627,7 +2627,7 @@ def test_count_query_get_with_alias(query, database):
         assert r.alias == "total"


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_count_query_get_with_limit(query, database):
     # count without limit
     count_query = query.count(alias="total")
@@ -2647,7 +2647,7 @@ def test_count_query_get_with_limit(query, database):
         assert r.value == 2


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_count_query_get_multiple_aggregations(query, database):
     count_query = query.count(alias="total").count(alias="all")

@@ -2662,7 +2662,7 @@ def test_count_query_get_multiple_aggregations(query, database):
     assert found_alias == set(expected_aliases)


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_count_query_get_multiple_aggregations_duplicated_alias(query, database):
     count_query = query.count(alias="total").count(alias="total")

@@ -2672,7 +2672,7 @@ def test_count_query_get_multiple_aggregations_duplicated_alias(query, database)
     assert "Aggregation aliases contain duplicate alias" in exc_info.value.message


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_count_query_get_empty_aggregation(query, database):
     from google.cloud.firestore_v1.aggregation import AggregationQuery

@@ -2684,7 +2684,7 @@ def test_count_query_get_empty_aggregation(query, database):
     assert "Aggregations can not be empty" in exc_info.value.message


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_count_query_stream_default_alias(query, database):
     count_query = query.count()
     for result in count_query.stream():
@@ -2692,7 +2692,7 @@ def test_count_query_stream_default_alias(query, database):
             assert aggregation_result.alias == "field_1"


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_count_query_stream_with_alias(query, database):
     count_query = query.count(alias="total")
     for result in count_query.stream():
@@ -2700,7 +2700,7 @@ def test_count_query_stream_with_alias(query, database):
             assert aggregation_result.alias == "total"


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_count_query_stream_with_limit(query, database):
     # count without limit
     count_query = query.count(alias="total")
@@ -2718,7 +2718,7 @@ def test_count_query_stream_with_limit(query, database):
             assert aggregation_result.value == 2


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_count_query_stream_multiple_aggregations(query, database):
     count_query = query.count(alias="total").count(alias="all")

@@ -2727,7 +2727,7 @@ def test_count_query_stream_multiple_aggregations(query, database):
             assert aggregation_result.alias in ["total", "all"]


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_count_query_stream_multiple_aggregations_duplicated_alias(query, database):
     count_query = query.count(alias="total").count(alias="total")

@@ -2738,7 +2738,7 @@ def test_count_query_stream_multiple_aggregations_duplicated_alias(query, databa
     assert "Aggregation aliases contain duplicate alias" in exc_info.value.message


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_count_query_stream_empty_aggregation(query, database):
     from google.cloud.firestore_v1.aggregation import AggregationQuery

@@ -2751,7 +2751,7 @@ def test_count_query_stream_empty_aggregation(query, database):
     assert "Aggregations can not be empty" in exc_info.value.message


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_count_query_with_start_at(query, database):
     """
     Ensure that count aggregation queries work when chained with a start_at
@@ -2770,7 +2770,7 @@ def test_count_query_with_start_at(query, database):
             assert aggregation_result.value == expected_count


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_sum_query_get_default_alias(collection, database):
     sum_query = collection.sum("stats.product")
     result = sum_query.get()
@@ -2780,7 +2780,7 @@ def test_sum_query_get_default_alias(collection, database):
         assert r.value == 100


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_sum_query_get_with_alias(collection, database):
     sum_query = collection.sum("stats.product", alias="total")
     result = sum_query.get()
@@ -2790,7 +2790,7 @@ def test_sum_query_get_with_alias(collection, database):
         assert r.value == 100


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_sum_query_get_with_limit(collection, database):
     # sum without limit
     sum_query = collection.sum("stats.product", alias="total")
@@ -2811,7 +2811,7 @@ def test_sum_query_get_with_limit(collection, database):
         assert r.value == 5


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_sum_query_get_multiple_aggregations(collection, database):
     sum_query = collection.sum("stats.product", alias="total").sum(
         "stats.product", alias="all"
@@ -2828,7 +2828,7 @@ def test_sum_query_get_multiple_aggregations(collection, database):
     assert found_alias == set(expected_aliases)


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_sum_query_stream_default_alias(collection, database):
     sum_query = collection.sum("stats.product")
     for result in sum_query.stream():
@@ -2837,7 +2837,7 @@ def test_sum_query_stream_default_alias(collection, database):
             assert aggregation_result.value == 100


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_sum_query_stream_with_alias(collection, database):
     sum_query = collection.sum("stats.product", alias="total")
     for result in sum_query.stream():
@@ -2846,7 +2846,7 @@ def test_sum_query_stream_with_alias(collection, database):
             assert aggregation_result.value == 100


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_sum_query_stream_with_limit(collection, database):
     # sum without limit
     sum_query = collection.sum("stats.product", alias="total")
@@ -2864,7 +2864,7 @@ def test_sum_query_stream_with_limit(collection, database):
             assert aggregation_result.value == 5


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_sum_query_stream_multiple_aggregations(collection, database):
     sum_query = collection.sum("stats.product", alias="total").sum(
         "stats.product", alias="all"
@@ -2878,7 +2878,7 @@ def test_sum_query_stream_multiple_aggregations(collection, database):
 # tests for issue reported in b/306241058
 # we will skip test in client for now, until backend fix is implemented
 @pytest.mark.skip(reason="backend fix required")
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_sum_query_with_start_at(query, database):
     """
     Ensure that sum aggregation queries work when chained with a start_at
@@ -2896,7 +2896,7 @@ def test_sum_query_with_start_at(query, database):
     assert sum_result[0].value == expected_sum


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_avg_query_get_default_alias(collection, database):
     avg_query = collection.avg("stats.product")
     result = avg_query.get()
@@ -2907,7 +2907,7 @@ def test_avg_query_get_default_alias(collection, database):
         assert isinstance(r.value, float)


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_avg_query_get_with_alias(collection, database):
     avg_query = collection.avg("stats.product", alias="total")
     result = avg_query.get()
@@ -2917,7 +2917,7 @@ def test_avg_query_get_with_alias(collection, database):
         assert r.value == 4


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_avg_query_get_with_limit(collection, database):
     # avg without limit
     avg_query = collection.avg("stats.product", alias="total")
@@ -2939,7 +2939,7 @@ def test_avg_query_get_with_limit(collection, database):
         assert isinstance(r.value, float)


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_avg_query_get_multiple_aggregations(collection, database):
     avg_query = collection.avg("stats.product", alias="total").avg(
         "stats.product", alias="all"
@@ -2956,7 +2956,7 @@ def test_avg_query_get_multiple_aggregations(collection, database):
     assert found_alias == set(expected_aliases)


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_avg_query_stream_default_alias(collection, database):
     avg_query = collection.avg("stats.product")
     for result in avg_query.stream():
@@ -2965,7 +2965,7 @@ def test_avg_query_stream_default_alias(collection, database):
             assert aggregation_result.value == 4


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_avg_query_stream_with_alias(collection, database):
     avg_query = collection.avg("stats.product", alias="total")
     for result in avg_query.stream():
@@ -2974,7 +2974,7 @@ def test_avg_query_stream_with_alias(collection, database):
             assert aggregation_result.value == 4


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_avg_query_stream_with_limit(collection, database):
     # avg without limit
     avg_query = collection.avg("stats.product", alias="total")
@@ -2992,7 +2992,7 @@ def test_avg_query_stream_with_limit(collection, database):
             assert aggregation_result.value == 5 / 12


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_avg_query_stream_multiple_aggregations(collection, database):
     avg_query = collection.avg("stats.product", alias="total").avg(
         "stats.product", alias="all"
@@ -3006,7 +3006,7 @@ def test_avg_query_stream_multiple_aggregations(collection, database):
 # tests for issue reported in b/306241058
 # we will skip test in client for now, until backend fix is implemented
 @pytest.mark.skip(reason="backend fix required")
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_avg_query_with_start_at(query, database):
     """
     Ensure that avg aggregation queries work when chained with a start_at
@@ -3030,7 +3030,7 @@ def test_avg_query_with_start_at(query, database):
     FIRESTORE_EMULATOR, reason="Query profile not supported in emulator."
 )
 @pytest.mark.parametrize("method", ["stream", "get"])
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_aggregation_query_stream_or_get_w_no_explain_options(query, database, method):
     # Because all aggregation methods end up calling AggregationQuery.get() or
     # AggregationQuery.stream(), only use count() for testing here.
@@ -3056,7 +3056,7 @@ def test_aggregation_query_stream_or_get_w_no_explain_options(query, database, m
     FIRESTORE_EMULATOR, reason="Query profile not supported in emulator."
 )
 @pytest.mark.parametrize("method", ["stream", "get"])
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_aggregation_query_stream_or_get_w_explain_options_analyze_true(
     query, database, method
 ):
@@ -3120,7 +3120,7 @@ def test_aggregation_query_stream_or_get_w_explain_options_analyze_true(
     FIRESTORE_EMULATOR, reason="Query profile not supported in emulator."
 )
 @pytest.mark.parametrize("method", ["stream", "get"])
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_aggregation_query_stream_or_get_w_explain_options_analyze_false(
     query, database, method
 ):
@@ -3160,7 +3160,7 @@ def test_aggregation_query_stream_or_get_w_explain_options_analyze_false(
         explain_metrics.execution_stats


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_query_with_and_composite_filter(collection, database):
     and_filter = And(
         filters=[
@@ -3175,7 +3175,7 @@ def test_query_with_and_composite_filter(collection, database):
         assert result.get("stats.product") < 10


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_query_with_or_composite_filter(collection, database):
     or_filter = Or(
         filters=[
@@ -3198,12 +3198,17 @@ def test_query_with_or_composite_filter(collection, database):
     assert lt_10 > 0


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 @pytest.mark.parametrize(
     "aggregation_type,expected_value", [("count", 5), ("sum", 100), ("avg", 4.0)]
 )
 def test_aggregation_queries_with_read_time(
-    collection, query, cleanup, database, aggregation_type, expected_value
+    collection,
+    query,
+    cleanup,
+    database,
+    aggregation_type,
+    expected_value,
 ):
     """
     Ensure that all aggregation queries work when read_time is passed into
@@ -3240,7 +3245,7 @@ def test_aggregation_queries_with_read_time(
         assert r.value == expected_value


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_query_with_complex_composite_filter(collection, database):
     field_filter = FieldFilter("b", "==", 0)
     or_filter = Or(
@@ -3289,9 +3294,14 @@ def test_query_with_complex_composite_filter(collection, database):
     "aggregation_type,aggregation_args,expected",
     [("count", (), 3), ("sum", ("b"), 12), ("avg", ("b"), 4)],
 )
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_aggregation_query_in_transaction(
-    client, cleanup, database, aggregation_type, aggregation_args, expected
+    client,
+    cleanup,
+    database,
+    aggregation_type,
+    aggregation_args,
+    expected,
 ):
     """
     Test creating an aggregation query inside a transaction
@@ -3331,7 +3341,7 @@ def in_transaction(transaction):
     assert inner_fn_ran is True


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_or_query_in_transaction(client, cleanup, database):
     """
     Test running or query inside a transaction.
     Should pass transaction id along with request
@@ -3376,7 +3386,7 @@ def in_transaction(transaction):
     assert inner_fn_ran is True


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_transaction_w_uuid(client, cleanup, database):
     """
     https://github.com/googleapis/python-firestore/issues/1012
@@ -3401,7 +3411,7 @@ def update_doc(tx, doc_ref, key, value):
 @pytest.mark.skipif(
     FIRESTORE_EMULATOR, reason="Query profile not supported in emulator."
 )
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_query_in_transaction_with_explain_options(client, cleanup, database):
     """
     Test query profiling in transactions.
@@ -3453,7 +3463,7 @@ def in_transaction(transaction):
     assert inner_fn_ran is True


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_query_in_transaction_with_read_time(client, cleanup, database):
     """
     Test query profiling in transactions.
@@ -3499,7 +3509,7 @@ def in_transaction(transaction):
     assert inner_fn_ran is True


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_update_w_uuid(client, cleanup, database):
     """
     https://github.com/googleapis/python-firestore/issues/1012
@@ -3518,7 +3528,7 @@ def test_update_w_uuid(client, cleanup, database):
 @pytest.mark.parametrize("with_rollback,expected", [(True, 2), (False, 3)])
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 def test_transaction_rollback(client, cleanup, database, with_rollback, expected):
     """
     Create a document in a transaction that is rolled back
diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py
index 945e7cb12..bc79ee2df 100644
--- a/tests/system/test_system_async.py
+++ b/tests/system/test_system_async.py
@@ -49,11 +49,11 @@
     EMULATOR_CREDS,
     FIRESTORE_CREDS,
     FIRESTORE_EMULATOR,
-    FIRESTORE_OTHER_DB,
     FIRESTORE_PROJECT,
     MISSING_DOCUMENT,
     RANDOM_ID_REGEX,
     UNIQUE_RESOURCE_ID,
+    TEST_DATABASES,
 )

 RETRIES = retries.AsyncRetry(
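Note: `TEST_DATABASES` replaces the `FIRESTORE_OTHER_DB` import above, but its definition is not part of this patch. A plausible sketch of the helper-module side, assuming the constant simply centralizes the database list these tests previously inlined (the environment-variable name is an assumption):

```python
# tests/system/test__helpers.py (sketch, not shown in this patch)
import os

FIRESTORE_OTHER_DB = os.environ.get("SYSTEM_TESTS_DATABASE", "system-tests-named-db")

# None selects the default database; adding a database here fans it out to
# every test parametrized with TEST_DATABASES, with no per-test edits.
TEST_DATABASES = [None, FIRESTORE_OTHER_DB]
```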
@@ -169,13 +169,13 @@ def event_loop():
     loop.close()


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_collections(client, database):
     collections = [x async for x in client.collections(retry=RETRIES)]
     assert isinstance(collections, list)


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB])
+@pytest.mark.parametrize("database", TEST_DATABASES)
 async def test_collections_w_import(database):
     from google.cloud import firestore

@@ -188,7 +188,7 @@ async def test_collections_w_import(database):
     assert isinstance(collections, list)


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_create_document(client, cleanup, database):
     now = datetime.datetime.now(tz=datetime.timezone.utc)
     collection_id = "doc-create" + UNIQUE_RESOURCE_ID
@@ -234,7 +234,7 @@ async def test_create_document(client, cleanup, database):
     assert stored_data == expected_data


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_collections_w_read_time(client, cleanup, database):
     first_collection_id = "doc-create" + UNIQUE_RESOURCE_ID
     first_document_id = "doc" + UNIQUE_RESOURCE_ID
@@ -245,7 +245,6 @@ async def test_collections_w_read_time(client, cleanup, database):
     data = {"status": "new"}
     write_result = await first_document.create(data)
     read_time = write_result.update_time
-    num_collections = len([x async for x in client.collections(retry=RETRIES)])

     second_collection_id = "doc-create" + UNIQUE_RESOURCE_ID + "-2"
     second_document_id = "doc" + UNIQUE_RESOURCE_ID + "-2"
@@ -255,7 +254,6 @@ async def test_collections_w_read_time(client, cleanup, database):
     # Test that listing current collections does have the second id.
     curr_collections = [x async for x in client.collections(retry=RETRIES)]
-    assert len(curr_collections) > num_collections
     ids = [collection.id for collection in curr_collections]
     assert second_collection_id in ids
     assert first_collection_id in ids
@@ -269,7 +267,7 @@ async def test_collections_w_read_time(client, cleanup, database):
     assert first_collection_id in ids


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_create_document_w_subcollection(client, cleanup, database):
     collection_id = "doc-create-sub" + UNIQUE_RESOURCE_ID
     document_id = "doc" + UNIQUE_RESOURCE_ID
@@ -295,7 +293,7 @@ def assert_timestamp_less(timestamp_pb1, timestamp_pb2):
     assert timestamp_pb1 < timestamp_pb2


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_document_collections_w_read_time(client, cleanup, database):
     collection_id = "doc-create-sub" + UNIQUE_RESOURCE_ID
     document_id = "doc" + UNIQUE_RESOURCE_ID
@@ -331,7 +329,7 @@ async def test_document_collections_w_read_time(client, cleanup, database):
     )


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_no_document(client, database):
     document_id = "no_document" + UNIQUE_RESOURCE_ID
     document = client.document("abcde", document_id)
@@ -339,7 +337,7 @@ async def test_no_document(client, database):
     assert snapshot.to_dict() is None


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_document_set(client, cleanup, database):
     document_id = "for-set" + UNIQUE_RESOURCE_ID
     document = client.document("i-did-it", document_id)
@@ -369,7 +367,7 @@ async def test_document_set(client, cleanup, database):
     assert snapshot2.update_time == write_result2.update_time


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_document_integer_field(client, cleanup, database):
     document_id = "for-set" + UNIQUE_RESOURCE_ID
     document = client.document("i-did-it", document_id)
@@ -386,7 +384,7 @@ async def test_document_integer_field(client, cleanup, database):
     assert snapshot.to_dict() == expected


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_document_set_merge(client, cleanup, database):
     document_id = "for-set" + UNIQUE_RESOURCE_ID
     document = client.document("i-did-it", document_id)
@@ -419,7 +417,7 @@ async def test_document_set_merge(client, cleanup, database):
     assert snapshot2.update_time == write_result2.update_time


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_document_set_w_int_field(client, cleanup, database):
     document_id = "set-int-key" + UNIQUE_RESOURCE_ID
     document = client.document("i-did-it", document_id)
@@ -443,7 +441,7 @@ async def test_document_set_w_int_field(client, cleanup, database):
     assert snapshot1.to_dict() == data


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_document_update_w_int_field(client, cleanup, database):
     # Attempt to reproduce #5489.
     document_id = "update-int-key" + UNIQUE_RESOURCE_ID
@@ -471,7 +469,7 @@ async def test_document_update_w_int_field(client, cleanup, database):
 @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 @pytest.mark.parametrize(
     "distance_measure",
     [
@@ -499,7 +497,7 @@ async def test_vector_search_collection(client, database, distance_measure):
 @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 @pytest.mark.parametrize(
     "distance_measure",
     [
@@ -527,7 +525,7 @@ async def test_vector_search_collection_with_filter(client, database, distance_m
 @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_vector_search_collection_with_distance_parameters_euclid(
     client, database
 ):
@@ -559,7 +557,7 @@ async def test_vector_search_collection_with_distance_parameters_euclid(
 @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_vector_search_collection_with_distance_parameters_cosine(
     client, database
 ):
@@ -591,7 +589,7 @@ async def test_vector_search_collection_with_distance_parameters_cosine(
 @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 @pytest.mark.parametrize(
     "distance_measure",
     [
@@ -620,7 +618,7 @@ async def test_vector_search_collection_group(client, database, distance_measure
 @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 @pytest.mark.parametrize(
     "distance_measure",
     [
@@ -651,7 +649,7 @@ async def test_vector_search_collection_group_with_filter(
 @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_vector_search_collection_group_with_distance_parameters_euclid(
     client, database
 ):
@@ -683,7 +681,7 @@ async def test_vector_search_collection_group_with_distance_parameters_euclid(
 @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_vector_search_collection_group_with_distance_parameters_cosine(
     client, database
 ):
@@ -718,7 +716,7 @@ async def test_vector_search_collection_group_with_distance_parameters_cosine(
     FIRESTORE_EMULATOR, reason="Query profile not supported in emulator."
 )
 @pytest.mark.parametrize("method", ["stream", "get"])
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_vector_query_stream_or_get_w_no_explain_options(
     client, database, method
 ):
@@ -753,7 +751,7 @@ async def test_vector_query_stream_or_get_w_no_explain_options(
     FIRESTORE_EMULATOR, reason="Query profile not supported in emulator."
 )
 @pytest.mark.parametrize("method", ["stream", "get"])
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_vector_query_stream_or_get_w_explain_options_analyze_true(
     client, query_docs, database, method
 ):
@@ -832,7 +830,7 @@ async def test_vector_query_stream_or_get_w_explain_options_analyze_true(
     FIRESTORE_EMULATOR, reason="Query profile not supported in emulator."
 )
 @pytest.mark.parametrize("method", ["stream", "get"])
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_vector_query_stream_or_get_w_explain_options_analyze_false(
     client, query_docs, database, method
 ):
@@ -895,7 +893,7 @@ async def test_vector_query_stream_or_get_w_explain_options_analyze_false(
 @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Internal Issue b/137867104")
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_update_document(client, cleanup, database):
     document_id = "for-update" + UNIQUE_RESOURCE_ID
     document = client.document("made", document_id)
@@ -968,7 +966,7 @@ def check_snapshot(snapshot, document, data, write_result):
     assert snapshot.update_time == write_result.update_time


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_document_get(client, cleanup, database):
     now = datetime.datetime.now(tz=datetime.timezone.utc)
     document_id = "for-get" + UNIQUE_RESOURCE_ID
@@ -994,7 +992,7 @@ async def test_document_get(client, cleanup, database):
     check_snapshot(snapshot, document, data, write_result)


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_document_delete(client, cleanup, database):
     document_id = "deleted" + UNIQUE_RESOURCE_ID
     document = client.document("here-to-be", document_id)
@@ -1031,7 +1029,7 @@ async def test_document_delete(client, cleanup, database):
     assert_timestamp_less(delete_time3, delete_time4)


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_collection_add(client, cleanup, database):
     # TODO(microgen): list_documents is returning a generator, not a list.
     # Consider if this is desired. Also, Document isn't hashable.
@@ -1133,7 +1131,7 @@ async def test_collection_add(client, cleanup, database):
     assert set([i async for i in collection3.list_documents()]) == {document_ref5}


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_list_collections_with_read_time(client, cleanup, database):
     # TODO(microgen): list_documents is returning a generator, not a list.
     # Consider if this is desired. Also, Document isn't hashable.
@@ -1205,7 +1203,7 @@ async def async_query(collection):
     return collection.where(filter=FieldFilter("a", "==", 1))


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_query_stream_legacy_where(query_docs, database):
     """Assert the legacy code still works and returns value, and shows UserWarning"""
     collection, stored, allowed_vals = query_docs
@@ -1221,7 +1219,7 @@ async def test_query_stream_legacy_where(query_docs, database):
         assert value["a"] == 1


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_query_stream_w_simple_field_eq_op(query_docs, database):
     collection, stored, allowed_vals = query_docs
     query = collection.where(filter=FieldFilter("a", "==", 1))
@@ -1232,7 +1230,7 @@ async def test_query_stream_w_simple_field_eq_op(query_docs, database):
         assert value["a"] == 1


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_query_stream_w_simple_field_array_contains_op(query_docs, database):
     collection, stored, allowed_vals = query_docs
     query = collection.where(filter=FieldFilter("c", "array_contains", 1))
@@ -1243,7 +1241,7 @@ async def test_query_stream_w_simple_field_array_contains_op(query_docs, databas
         assert value["a"] == 1


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_query_stream_w_simple_field_in_op(query_docs, database):
     collection, stored, allowed_vals = query_docs
     num_vals = len(allowed_vals)
@@ -1255,7 +1253,7 @@ async def test_query_stream_w_simple_field_in_op(query_docs, database):
         assert value["a"] == 1


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_query_stream_w_simple_field_array_contains_any_op(query_docs, database):
     collection, stored, allowed_vals = query_docs
     num_vals = len(allowed_vals)
@@ -1269,7 +1267,7 @@ async def test_query_stream_w_simple_field_array_contains_any_op(query_docs, dat
         assert value["a"] == 1


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_query_stream_w_order_by(query_docs, database):
     collection, stored, allowed_vals = query_docs
     query = collection.order_by("b", direction=firestore.Query.DESCENDING)
@@ -1283,7 +1281,7 @@ async def test_query_stream_w_order_by(query_docs, database):
     assert sorted(b_vals, reverse=True) == b_vals


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_query_stream_w_field_path(query_docs, database):
     collection, stored, allowed_vals = query_docs
     query = collection.where(filter=FieldFilter("stats.sum", ">", 4))
@@ -1305,7 +1303,7 @@ async def test_query_stream_w_field_path(query_docs, database):
     assert expected_ab_pairs == ab_pairs2


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_query_stream_w_start_end_cursor(query_docs, database):
     collection, stored, allowed_vals = query_docs
     num_vals = len(allowed_vals)
@@ -1321,7 +1319,7 @@ async def test_query_stream_w_start_end_cursor(query_docs, database):
         assert value["a"] == num_vals - 2


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_query_stream_wo_results(query_docs, database):
     collection, stored, allowed_vals = query_docs
     num_vals = len(allowed_vals)
@@ -1330,7 +1328,7 @@ async def test_query_stream_wo_results(query_docs, database):
     assert len(values) == 0


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_query_stream_w_projection(query_docs, database):
     collection, stored, allowed_vals = query_docs
     num_vals = len(allowed_vals)
@@ -1347,7 +1345,7 @@ async def test_query_stream_w_projection(query_docs, database):
         assert expected == value


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_query_stream_w_multiple_filters(query_docs, database):
     collection, stored, allowed_vals = query_docs
     query = collection.where(filter=FieldFilter("stats.product", ">", 5)).where(
@@ -1367,7 +1365,7 @@ async def test_query_stream_w_multiple_filters(query_docs, database):
         assert pair in matching_pairs


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_query_stream_w_offset(query_docs, database):
     collection, stored, allowed_vals = query_docs
     num_vals = len(allowed_vals)
@@ -1387,7 +1385,7 @@ async def test_query_stream_w_offset(query_docs, database):
     FIRESTORE_EMULATOR, reason="Query profile not supported in emulator."
 )
 @pytest.mark.parametrize("method", ["stream", "get"])
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_query_stream_or_get_w_no_explain_options(query_docs, database, method):
     from google.cloud.firestore_v1.query_profile import QueryExplainError

@@ -1412,7 +1410,7 @@ async def test_query_stream_or_get_w_no_explain_options(query_docs, database, me
     FIRESTORE_EMULATOR, reason="Query profile not supported in emulator."
 )
 @pytest.mark.parametrize("method", ["stream", "get"])
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_query_stream_or_get_w_explain_options_analyze_true(
     query_docs, database, method
 ):
@@ -1457,7 +1455,7 @@ async def test_query_stream_or_get_w_explain_options_analyze_true(
     FIRESTORE_EMULATOR, reason="Query profile not supported in emulator."
 )
 @pytest.mark.parametrize("method", ["stream", "get"])
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_query_stream_or_get_w_explain_options_analyze_false(
     query_docs, database, method
 ):
@@ -1492,7 +1490,7 @@ async def test_query_stream_or_get_w_explain_options_analyze_false(
     _verify_explain_metrics_analyze_false(explain_metrics)


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_query_stream_w_read_time(query_docs, cleanup, database):
     collection, stored, allowed_vals = query_docs
     num_vals = len(allowed_vals)
@@ -1532,7 +1530,7 @@ async def test_query_stream_w_read_time(query_docs, cleanup, database):
     assert new_values[new_ref.id] == new_data


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_query_with_order_dot_key(client, cleanup, database):
     db = client
     collection_id = "collek" + UNIQUE_RESOURCE_ID
@@ -1572,7 +1570,7 @@ async def test_query_with_order_dot_key(client, cleanup, database):
     assert found_data == [snap.to_dict() for snap in cursor_with_key_data]


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_query_unary(client, cleanup, database):
     collection_name = "unary" + UNIQUE_RESOURCE_ID
     collection = client.collection(collection_name)
@@ -1627,7 +1625,7 @@ async def test_query_unary(client, cleanup, database):
     assert snapshot3.to_dict() == {field_name: 123}


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_collection_group_queries(client, cleanup, database):
     collection_group = "b" + UNIQUE_RESOURCE_ID
@@ -1660,7 +1658,7 @@ async def test_collection_group_queries(client, cleanup, database):
     assert found == expected


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_collection_group_queries_startat_endat(client, cleanup, database):
     collection_group = "b" + UNIQUE_RESOURCE_ID
@@ -1703,7 +1701,7 @@ async def test_collection_group_queries_startat_endat(client, cleanup, database)
     assert found == set(["cg-doc2"])


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_collection_group_queries_filters(client, cleanup, database):
     collection_group = "b" + UNIQUE_RESOURCE_ID
@@ -1772,7 +1770,7 @@ async def test_collection_group_queries_filters(client, cleanup, database):
     FIRESTORE_EMULATOR, reason="Query profile not supported in emulator."
 )
 @pytest.mark.parametrize("method", ["stream", "get"])
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_collection_stream_or_get_w_no_explain_options(
     query_docs, database, method
 ):
@@ -1797,7 +1795,7 @@ async def test_collection_stream_or_get_w_no_explain_options(
     FIRESTORE_EMULATOR, reason="Query profile not supported in emulator."
 )
 @pytest.mark.parametrize("method", ["stream", "get"])
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_collection_stream_or_get_w_explain_options_analyze_true(
     query_docs, database, method
 ):
@@ -1865,7 +1863,7 @@ async def test_collection_stream_or_get_w_explain_options_analyze_true(
     FIRESTORE_EMULATOR, reason="Query profile not supported in emulator."
 )
 @pytest.mark.parametrize("method", ["stream", "get"])
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_collection_stream_or_get_w_explain_options_analyze_false(
     query_docs, database, method
 ):
@@ -1919,7 +1917,7 @@ async def test_collection_stream_or_get_w_explain_options_analyze_false(
 @pytest.mark.skipif(
     FIRESTORE_EMULATOR, reason="PartitionQuery not implemented in emulator"
 )
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_partition_query_no_partitions(client, cleanup, database):
     collection_group = "b" + UNIQUE_RESOURCE_ID
@@ -1953,7 +1951,7 @@ async def test_partition_query_no_partitions(client, cleanup, database):
 @pytest.mark.skipif(
     FIRESTORE_EMULATOR, reason="PartitionQuery not implemented in emulator"
 )
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_partition_query(client, cleanup, database):
     collection_group = "b" + UNIQUE_RESOURCE_ID
     n_docs = 128 * 2 + 127  # Minimum partition size is 128
@@ -1980,7 +1978,7 @@ async def test_partition_query(client, cleanup, database):
 @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Internal Issue b/137865992")
-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_get_all(client, cleanup, database):
     collection_name = "get-all" + UNIQUE_RESOURCE_ID
@@ -2053,7 +2051,7 @@ async def test_get_all(client, cleanup, database):
     assert not snapshots[2].exists


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_live_bulk_writer(client, cleanup, database):
     from google.cloud.firestore_v1.async_client import AsyncClient
     from google.cloud.firestore_v1.bulk_writer import BulkWriter
@@ -2077,7 +2075,7 @@ async def test_live_bulk_writer(client, cleanup, database):
     assert len(await col.get()) == 50


-@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True)
 async def test_batch(client, cleanup, database):
     collection_name = "batch" + UNIQUE_RESOURCE_ID
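The recursive-delete hunks that follow exercise `BulkWriter` in both send modes. As context for reviewers, here is a minimal sketch of how the parallel and serial variants are typically constructed, grounded in the `BulkWriterOptions, SendMode` imports visible in these tests (the exact keyword arguments are an assumption, not confirmed by this patch):

```python
from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions, SendMode


def make_bulk_writer(client, parallel=True):
    # SendMode.parallel sends batches concurrently; SendMode.serial waits for
    # each batch to finish before sending the next one.
    mode = SendMode.parallel if parallel else SendMode.serial
    return client.bulk_writer(options=BulkWriterOptions(mode=mode))
```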
deleted" -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_recursive_delete_parallelized(client, cleanup, database): from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions, SendMode @@ -2252,7 +2250,7 @@ async def test_async_recursive_delete_parallelized(client, cleanup, database): await _do_recursive_delete(client, bw) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_recursive_delete_serialized(client, cleanup, database): from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions, SendMode @@ -2260,7 +2258,7 @@ async def test_async_recursive_delete_serialized(client, cleanup, database): await _do_recursive_delete(client, bw) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_recursive_delete_parallelized_empty(client, cleanup, database): from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions, SendMode @@ -2268,7 +2266,7 @@ async def test_async_recursive_delete_parallelized_empty(client, cleanup, databa await _do_recursive_delete(client, bw, empty_philosophers=True) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_recursive_delete_serialized_empty(client, cleanup, database): from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions, SendMode @@ -2276,7 +2274,7 @@ async def test_async_recursive_delete_serialized_empty(client, cleanup, database await _do_recursive_delete(client, bw, empty_philosophers=True) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_recursive_query(client, cleanup, database): col_id: str = f"philosophers-recursive-async-query{UNIQUE_RESOURCE_ID}" await _persist_documents(client, col_id, philosophers_data_set, cleanup) @@ -2315,7 +2313,7 @@ async def test_recursive_query(client, cleanup, database): assert ids[index] == expected_ids[index], error_msg -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_nested_recursive_query(client, cleanup, database): col_id: str = f"philosophers-nested-recursive-async-query{UNIQUE_RESOURCE_ID}" await _persist_documents(client, col_id, philosophers_data_set, cleanup) @@ -2339,7 +2337,7 @@ async def test_nested_recursive_query(client, cleanup, database): assert ids[index] == expected_ids[index], error_msg -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_chunked_query(client, cleanup, database): col = client.collection(f"async-chunked-test{UNIQUE_RESOURCE_ID}") for index in range(10): @@ -2355,7 +2353,7 @@ async def test_chunked_query(client, cleanup, database): assert lengths[3] == 1 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_chunked_query_smaller_limit(client, cleanup, database): col = client.collection(f"chunked-test-smaller-limit{UNIQUE_RESOURCE_ID}") for 
index in range(10): @@ -2368,7 +2366,7 @@ async def test_chunked_query_smaller_limit(client, cleanup, database): assert lengths[0] == 5 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_chunked_and_recursive(client, cleanup, database): col_id = f"chunked-async-recursive-test{UNIQUE_RESOURCE_ID}" documents = [ @@ -2427,7 +2425,7 @@ async def _chain(*iterators): yield value -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_count_async_query_get_default_alias(async_query, database): count_query = async_query.count() result = await count_query.get() @@ -2435,7 +2433,7 @@ async def test_count_async_query_get_default_alias(async_query, database): assert r.alias == "field_1" -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_count_query_get_with_alias(async_query, database): count_query = async_query.count(alias="total") result = await count_query.get() @@ -2443,7 +2441,7 @@ async def test_async_count_query_get_with_alias(async_query, database): assert r.alias == "total" -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_count_query_get_with_limit(async_query, database): count_query = async_query.count(alias="total") result = await count_query.get() @@ -2459,7 +2457,7 @@ async def test_async_count_query_get_with_limit(async_query, database): assert r.value == 2 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_count_query_get_multiple_aggregations(async_query, database): count_query = async_query.count(alias="total").count(alias="all") @@ -2474,7 +2472,7 @@ async def test_async_count_query_get_multiple_aggregations(async_query, database assert found_alias == set(expected_aliases) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_count_query_get_multiple_aggregations_duplicated_alias( async_query, database ): @@ -2486,7 +2484,7 @@ async def test_async_count_query_get_multiple_aggregations_duplicated_alias( assert "Aggregation aliases contain duplicate alias" in exc_info.value.message -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_count_query_get_empty_aggregation(async_query, database): from google.cloud.firestore_v1.async_aggregation import AsyncAggregationQuery @@ -2498,7 +2496,7 @@ async def test_async_count_query_get_empty_aggregation(async_query, database): assert "Aggregations can not be empty" in exc_info.value.message -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_count_query_stream_default_alias(async_query, database): count_query = async_query.count() @@ -2507,7 +2505,7 @@ async def test_async_count_query_stream_default_alias(async_query, database): assert aggregation_result.alias == "field_1" -@pytest.mark.parametrize("database", [None, 
FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_count_query_stream_with_alias(async_query, database): count_query = async_query.count(alias="total") async for result in count_query.stream(): @@ -2515,7 +2513,7 @@ async def test_async_count_query_stream_with_alias(async_query, database): assert aggregation_result.alias == "total" -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_count_query_stream_with_limit(async_query, database): # count without limit count_query = async_query.count(alias="total") @@ -2530,7 +2528,7 @@ async def test_async_count_query_stream_with_limit(async_query, database): assert aggregation_result.value == 2 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_count_query_stream_multiple_aggregations(async_query, database): count_query = async_query.count(alias="total").count(alias="all") @@ -2540,7 +2538,7 @@ async def test_async_count_query_stream_multiple_aggregations(async_query, datab assert aggregation_result.alias in ["total", "all"] -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_count_query_stream_multiple_aggregations_duplicated_alias( async_query, database ): @@ -2553,7 +2551,7 @@ async def test_async_count_query_stream_multiple_aggregations_duplicated_alias( assert "Aggregation aliases contain duplicate alias" in exc_info.value.message -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_count_query_stream_empty_aggregation(async_query, database): from google.cloud.firestore_v1.async_aggregation import AsyncAggregationQuery @@ -2566,7 +2564,7 @@ async def test_async_count_query_stream_empty_aggregation(async_query, database) assert "Aggregations can not be empty" in exc_info.value.message -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_sum_query_get_default_alias(collection, database): sum_query = collection.sum("stats.product") result = await sum_query.get() @@ -2575,7 +2573,7 @@ async def test_async_sum_query_get_default_alias(collection, database): assert r.value == 100 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_sum_query_get_with_alias(collection, database): sum_query = collection.sum("stats.product", alias="total") result = await sum_query.get() @@ -2584,7 +2582,7 @@ async def test_async_sum_query_get_with_alias(collection, database): assert r.value == 100 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_sum_query_get_with_limit(collection, database): sum_query = collection.sum("stats.product", alias="total") result = await sum_query.get() @@ -2600,7 +2598,7 @@ async def test_async_sum_query_get_with_limit(collection, database): assert r.value == 5 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) 
+@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_sum_query_get_multiple_aggregations(collection, database): sum_query = collection.sum("stats.product", alias="total").sum( "stats.product", alias="all" @@ -2617,7 +2615,7 @@ async def test_async_sum_query_get_multiple_aggregations(collection, database): assert found_alias == set(expected_aliases) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_sum_query_stream_default_alias(collection, database): sum_query = collection.sum("stats.product") @@ -2627,7 +2625,7 @@ async def test_async_sum_query_stream_default_alias(collection, database): assert aggregation_result.value == 100 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_sum_query_stream_with_alias(collection, database): sum_query = collection.sum("stats.product", alias="total") async for result in sum_query.stream(): @@ -2635,7 +2633,7 @@ async def test_async_sum_query_stream_with_alias(collection, database): assert aggregation_result.alias == "total" -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_sum_query_stream_with_limit(collection, database): # sum without limit sum_query = collection.sum("stats.product", alias="total") @@ -2650,7 +2648,7 @@ async def test_async_sum_query_stream_with_limit(collection, database): assert aggregation_result.value == 5 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_sum_query_stream_multiple_aggregations(collection, database): sum_query = collection.sum("stats.product", alias="total").sum( "stats.product", alias="all" @@ -2662,7 +2660,7 @@ async def test_async_sum_query_stream_multiple_aggregations(collection, database assert aggregation_result.alias in ["total", "all"] -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_get_default_alias(collection, database): avg_query = collection.avg("stats.product") result = await avg_query.get() @@ -2672,7 +2670,7 @@ async def test_async_avg_query_get_default_alias(collection, database): assert isinstance(r.value, float) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_get_with_alias(collection, database): avg_query = collection.avg("stats.product", alias="total") result = await avg_query.get() @@ -2681,7 +2679,7 @@ async def test_async_avg_query_get_with_alias(collection, database): assert r.value == 4 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_get_with_limit(collection, database): avg_query = collection.avg("stats.product", alias="total") result = await avg_query.get() @@ -2697,7 +2695,7 @@ async def test_async_avg_query_get_with_limit(collection, database): assert r.value == 5 / 12 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", 
TEST_DATABASES, indirect=True) async def test_async_avg_query_get_multiple_aggregations(collection, database): avg_query = collection.avg("stats.product", alias="total").avg( "stats.product", alias="all" @@ -2714,7 +2712,7 @@ async def test_async_avg_query_get_multiple_aggregations(collection, database): assert found_alias == set(expected_aliases) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_get_w_no_explain_options(collection, database): avg_query = collection.avg("stats.product", alias="total") results = await avg_query.get() @@ -2725,7 +2723,7 @@ async def test_async_avg_query_get_w_no_explain_options(collection, database): @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_get_w_explain_options_analyze_true(collection, database): avg_query = collection.avg("stats.product", alias="total") results = await avg_query.get(explain_options=ExplainOptions(analyze=True)) @@ -2760,7 +2758,7 @@ async def test_async_avg_query_get_w_explain_options_analyze_true(collection, da @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_get_w_explain_options_analyze_false( collection, database ): @@ -2791,7 +2789,7 @@ async def test_async_avg_query_get_w_explain_options_analyze_false( explain_metrics.execution_stats -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_stream_default_alias(collection, database): avg_query = collection.avg("stats.product") @@ -2802,7 +2800,7 @@ async def test_async_avg_query_stream_default_alias(collection, database): assert isinstance(aggregation_result.value, float) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_stream_with_alias(collection, database): avg_query = collection.avg("stats.product", alias="total") async for result in avg_query.stream(): @@ -2810,7 +2808,7 @@ async def test_async_avg_query_stream_with_alias(collection, database): assert aggregation_result.alias == "total" -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_stream_with_limit(collection, database): # avg without limit avg_query = collection.avg("stats.product", alias="total") @@ -2826,7 +2824,7 @@ async def test_async_avg_query_stream_with_limit(collection, database): assert isinstance(aggregation_result.value, float) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_stream_multiple_aggregations(collection, database): avg_query = collection.avg("stats.product", alias="total").avg( "stats.product", alias="all" @@ -2838,7 +2836,7 @@ async def test_async_avg_query_stream_multiple_aggregations(collection, database assert 
aggregation_result.alias in ["total", "all"] -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_stream_w_no_explain_options(collection, database): avg_query = collection.avg("stats.product", alias="total") results = avg_query.stream() @@ -2849,7 +2847,7 @@ async def test_async_avg_query_stream_w_no_explain_options(collection, database) @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_stream_w_explain_options_analyze_true( collection, database ): @@ -2894,7 +2892,7 @@ async def test_async_avg_query_stream_w_explain_options_analyze_true( @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_async_avg_query_stream_w_explain_options_analyze_false( collection, database ): @@ -2926,7 +2924,7 @@ async def test_async_avg_query_stream_w_explain_options_analyze_false( explain_metrics.execution_stats -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) @pytest.mark.parametrize( "aggregation_type,expected_value", [("count", 5), ("sum", 100), ("avg", 4.0)] ) @@ -2989,7 +2987,7 @@ async def create_in_transaction_helper( raise ValueError("Collection can't have more than 2 docs") -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_count_query_in_transaction(client, cleanup, database): collection_id = "doc-create" + UNIQUE_RESOURCE_ID document_id_1 = "doc1" + UNIQUE_RESOURCE_ID @@ -3021,7 +3019,7 @@ async def test_count_query_in_transaction(client, cleanup, database): assert r.value == 2 # there are still only 2 docs -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_with_and_composite_filter(query_docs, database): collection, stored, allowed_vals = query_docs and_filter = And( @@ -3037,7 +3035,7 @@ async def test_query_with_and_composite_filter(query_docs, database): assert result.get("stats.product") < 10 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_with_or_composite_filter(query_docs, database): collection, stored, allowed_vals = query_docs or_filter = Or( @@ -3061,7 +3059,7 @@ async def test_query_with_or_composite_filter(query_docs, database): assert lt_10 > 0 -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_with_complex_composite_filter(query_docs, database): collection, stored, allowed_vals = query_docs field_filter = FieldFilter("b", "==", 0) @@ -3107,7 +3105,7 @@ async def test_query_with_complex_composite_filter(query_docs, database): assert b_not_3 is True -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", 
TEST_DATABASES, indirect=True) async def test_or_query_in_transaction(client, cleanup, database): collection_id = "doc-create" + UNIQUE_RESOURCE_ID document_id_1 = "doc1" + UNIQUE_RESOURCE_ID @@ -3171,7 +3169,7 @@ async def _make_transaction_query(client, cleanup): @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_transaction_w_query_w_no_explain_options(client, cleanup, database): from google.cloud.firestore_v1.query_profile import QueryExplainError @@ -3204,7 +3202,7 @@ async def in_transaction(transaction): @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_transaction_w_query_w_explain_options_analyze_true( client, cleanup, database ): @@ -3246,7 +3244,7 @@ async def in_transaction(transaction): @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_transaction_w_query_w_explain_options_analyze_false( client, cleanup, database ): @@ -3283,7 +3281,7 @@ async def in_transaction(transaction): @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_in_transaction_w_no_explain_options(client, cleanup, database): from google.cloud.firestore_v1.query_profile import QueryExplainError @@ -3316,7 +3314,7 @@ async def in_transaction(transaction): @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_in_transaction_w_explain_options_analyze_true( client, cleanup, database ): @@ -3350,7 +3348,7 @@ async def in_transaction(transaction): @pytest.mark.skipif( FIRESTORE_EMULATOR, reason="Query profile not supported in emulator." ) -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_in_transaction_w_explain_options_analyze_false( client, cleanup, database ): @@ -3383,7 +3381,7 @@ async def in_transaction(transaction): assert inner_fn_ran is True -@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +@pytest.mark.parametrize("database", TEST_DATABASES, indirect=True) async def test_query_in_transaction_with_read_time(client, cleanup, database): """ Test query profiling in transactions. 
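[Editorial note on the hunks above: they mechanically replace each inline `[None, FIRESTORE_OTHER_DB]` parametrization with a shared `TEST_DATABASES` constant, so the set of databases exercised by the system tests is defined in one place. The constant's definition is not part of this patch; below is a minimal sketch of what it presumably looks like in the shared test setup — the environment-variable name and fallback value are assumptions, not shown in this diff.]

```python
import os

# Hypothetical sketch of the shared constant these hunks reference.
# None selects the client's default database; the named entry exercises
# multi-database support. The env var and fallback value are assumed.
FIRESTORE_OTHER_DB = os.environ.get("SYSTEM_TESTS_DATABASE", "system-tests-named-db")
TEST_DATABASES = [None, FIRESTORE_OTHER_DB]
```

[With `indirect=True`, pytest routes each entry through the `database` fixture, so every decorated test runs once against the default database and once against the named one.]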
diff --git a/tests/unit/v1/test_async_pipeline.py b/tests/unit/v1/test_async_pipeline.py index 3abc3619b..47eedc983 100644 --- a/tests/unit/v1/test_async_pipeline.py +++ b/tests/unit/v1/test_async_pipeline.py @@ -16,6 +16,8 @@ import pytest from google.cloud.firestore_v1 import _pipeline_stages as stages +from google.cloud.firestore_v1.pipeline_expressions import Field +from google.cloud.firestore_v1.pipeline_expressions import Exists def _make_async_pipeline(*args, client=mock.Mock()): @@ -379,8 +381,34 @@ async def test_async_pipeline_stream_stream_equivalence_mocked(): @pytest.mark.parametrize( "method,args,result_cls", [ - ("generic_stage", ("name",), stages.GenericStage), - ("generic_stage", ("name", mock.Mock()), stages.GenericStage), + ("add_fields", (Field.of("n"),), stages.AddFields), + ("remove_fields", ("name",), stages.RemoveFields), + ("remove_fields", (Field.of("n"),), stages.RemoveFields), + ("select", ("name",), stages.Select), + ("select", (Field.of("n"),), stages.Select), + ("where", (Exists(Field.of("n")),), stages.Where), + ("find_nearest", ("name", [0.1], 0), stages.FindNearest), + ( + "find_nearest", + ("name", [0.1], 0, stages.FindNearestOptions(10)), + stages.FindNearest, + ), + ("sort", (Field.of("n").descending(),), stages.Sort), + ("sort", (Field.of("n").descending(), Field.of("m").ascending()), stages.Sort), + ("sample", (10,), stages.Sample), + ("sample", (stages.SampleOptions.doc_limit(10),), stages.Sample), + ("union", (_make_async_pipeline(),), stages.Union), + ("unnest", ("field_name",), stages.Unnest), + ("unnest", ("field_name", "alias"), stages.Unnest), + ("unnest", (Field.of("n"), Field.of("alias")), stages.Unnest), + ("unnest", ("n", "a", stages.UnnestOptions("idx")), stages.Unnest), + ("generic_stage", ("stage_name",), stages.GenericStage), + ("generic_stage", ("stage_name", Field.of("n")), stages.GenericStage), + ("offset", (1,), stages.Offset), + ("limit", (1,), stages.Limit), + ("aggregate", (Field.of("n").as_("alias"),), stages.Aggregate), + ("distinct", ("field_name",), stages.Distinct), + ("distinct", (Field.of("n"), "second"), stages.Distinct), ], ) def test_async_pipeline_methods(method, args, result_cls): @@ -391,3 +419,13 @@ def test_async_pipeline_methods(method, args, result_cls): assert len(start_ppl.stages) == 0 assert len(result_ppl.stages) == 1 assert isinstance(result_ppl.stages[0], result_cls) + + +def test_async_pipeline_aggregate_with_groups(): + start_ppl = _make_async_pipeline() + result_ppl = start_ppl.aggregate(Field.of("title"), groups=[Field.of("author")]) + assert len(start_ppl.stages) == 0 + assert len(result_ppl.stages) == 1 + assert isinstance(result_ppl.stages[0], stages.Aggregate) + assert list(result_ppl.stages[0].groups) == [Field.of("author")] + assert list(result_ppl.stages[0].accumulators) == [Field.of("title")] diff --git a/tests/unit/v1/test_pipeline.py b/tests/unit/v1/test_pipeline.py index 6a3fef3ac..b237ad5ac 100644 --- a/tests/unit/v1/test_pipeline.py +++ b/tests/unit/v1/test_pipeline.py @@ -16,6 +16,8 @@ import pytest from google.cloud.firestore_v1 import _pipeline_stages as stages +from google.cloud.firestore_v1.pipeline_expressions import Field +from google.cloud.firestore_v1.pipeline_expressions import Exists def _make_pipeline(*args, client=mock.Mock()): @@ -356,8 +358,34 @@ def test_pipeline_execute_stream_equivalence_mocked(): @pytest.mark.parametrize( "method,args,result_cls", [ - ("generic_stage", ("name",), stages.GenericStage), - ("generic_stage", ("name", mock.Mock()), stages.GenericStage), + 
("add_fields", (Field.of("n"),), stages.AddFields), + ("remove_fields", ("name",), stages.RemoveFields), + ("remove_fields", (Field.of("n"),), stages.RemoveFields), + ("select", ("name",), stages.Select), + ("select", (Field.of("n"),), stages.Select), + ("where", (Exists(Field.of("n")),), stages.Where), + ("find_nearest", ("name", [0.1], 0), stages.FindNearest), + ( + "find_nearest", + ("name", [0.1], 0, stages.FindNearestOptions(10)), + stages.FindNearest, + ), + ("sort", (Field.of("n").descending(),), stages.Sort), + ("sort", (Field.of("n").descending(), Field.of("m").ascending()), stages.Sort), + ("sample", (10,), stages.Sample), + ("sample", (stages.SampleOptions.doc_limit(10),), stages.Sample), + ("union", (_make_pipeline(),), stages.Union), + ("unnest", ("field_name",), stages.Unnest), + ("unnest", ("field_name", "alias"), stages.Unnest), + ("unnest", (Field.of("n"), Field.of("alias")), stages.Unnest), + ("unnest", ("n", "a", stages.UnnestOptions("idx")), stages.Unnest), + ("generic_stage", ("stage_name",), stages.GenericStage), + ("generic_stage", ("stage_name", Field.of("n")), stages.GenericStage), + ("offset", (1,), stages.Offset), + ("limit", (1,), stages.Limit), + ("aggregate", (Field.of("n").as_("alias"),), stages.Aggregate), + ("distinct", ("field_name",), stages.Distinct), + ("distinct", (Field.of("n"), "second"), stages.Distinct), ], ) def test_pipeline_methods(method, args, result_cls): @@ -368,3 +396,13 @@ def test_pipeline_methods(method, args, result_cls): assert len(start_ppl.stages) == 0 assert len(result_ppl.stages) == 1 assert isinstance(result_ppl.stages[0], result_cls) + + +def test_pipeline_aggregate_with_groups(): + start_ppl = _make_pipeline() + result_ppl = start_ppl.aggregate(Field.of("title"), groups=[Field.of("author")]) + assert len(start_ppl.stages) == 0 + assert len(result_ppl.stages) == 1 + assert isinstance(result_ppl.stages[0], stages.Aggregate) + assert list(result_ppl.stages[0].groups) == [Field.of("author")] + assert list(result_ppl.stages[0].accumulators) == [Field.of("title")] diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index 19ebed3b5..936c0a0a9 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -10,15 +10,68 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License +# limitations under the License. 
import pytest +import mock import datetime -import google.cloud.firestore_v1.pipeline_expressions as expressions +from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1.types import document as document_pb +from google.cloud.firestore_v1.types import query as query_pb from google.cloud.firestore_v1.types.document import Value from google.cloud.firestore_v1.vector import Vector from google.cloud.firestore_v1._helpers import GeoPoint +from google.cloud.firestore_v1.pipeline_expressions import FilterCondition, ListOfExprs +import google.cloud.firestore_v1.pipeline_expressions as expr + + +@pytest.fixture +def mock_client(): + client = mock.Mock(spec=["_database_string", "collection"]) + client._database_string = "projects/p/databases/d" + return client + + +class TestOrdering: + @pytest.mark.parametrize( + "direction_arg,expected_direction", + [ + ("ASCENDING", expr.Ordering.Direction.ASCENDING), + ("DESCENDING", expr.Ordering.Direction.DESCENDING), + ("ascending", expr.Ordering.Direction.ASCENDING), + ("descending", expr.Ordering.Direction.DESCENDING), + (expr.Ordering.Direction.ASCENDING, expr.Ordering.Direction.ASCENDING), + (expr.Ordering.Direction.DESCENDING, expr.Ordering.Direction.DESCENDING), + ], + ) + def test_ctor(self, direction_arg, expected_direction): + instance = expr.Ordering("field1", direction_arg) + assert isinstance(instance.expr, expr.Field) + assert instance.expr.path == "field1" + assert instance.order_dir == expected_direction + + def test_repr(self): + field_expr = expr.Field.of("field1") + instance = expr.Ordering(field_expr, "ASCENDING") + repr_str = repr(instance) + assert repr_str == "Field.of('field1').ascending()" + + instance = expr.Ordering(field_expr, "DESCENDING") + repr_str = repr(instance) + assert repr_str == "Field.of('field1').descending()" + + def test_to_pb(self): + field_expr = expr.Field.of("field1") + instance = expr.Ordering(field_expr, "ASCENDING") + result = instance._to_pb() + assert result.map_value.fields["expression"].field_reference_value == "field1" + assert result.map_value.fields["direction"].string_value == "ascending" + + instance = expr.Ordering(field_expr, "DESCENDING") + result = instance._to_pb() + assert result.map_value.fields["expression"].field_reference_value == "field1" + assert result.map_value.fields["direction"].string_value == "descending" class TestExpr: @@ -27,7 +80,81 @@ def test_ctor(self): Base class should be abstract """ with pytest.raises(TypeError): - expressions.Expr() + expr.Expr() + + @pytest.mark.parametrize( + "method,args,result_cls", + [ + ("add", (2,), expr.Add), + ("subtract", (2,), expr.Subtract), + ("multiply", (2,), expr.Multiply), + ("divide", (2,), expr.Divide), + ("mod", (2,), expr.Mod), + ("logical_max", (2,), expr.LogicalMax), + ("logical_min", (2,), expr.LogicalMin), + ("eq", (2,), expr.Eq), + ("neq", (2,), expr.Neq), + ("lt", (2,), expr.Lt), + ("lte", (2,), expr.Lte), + ("gt", (2,), expr.Gt), + ("gte", (2,), expr.Gte), + ("in_any", ([None],), expr.In), + ("not_in_any", ([None],), expr.Not), + ("array_contains", (None,), expr.ArrayContains), + ("array_contains_all", ([None],), expr.ArrayContainsAll), + ("array_contains_any", ([None],), expr.ArrayContainsAny), + ("array_length", (), expr.ArrayLength), + ("array_reverse", (), expr.ArrayReverse), + ("is_nan", (), expr.IsNaN), + ("exists", (), expr.Exists), + ("sum", (), expr.Sum), + ("avg", (), expr.Avg), + ("count", (), expr.Count), + ("min", (), expr.Min), + ("max", (), expr.Max), + ("char_length", (), expr.CharLength), + 
("byte_length", (), expr.ByteLength), + ("like", ("pattern",), expr.Like), + ("regex_contains", ("regex",), expr.RegexContains), + ("regex_matches", ("regex",), expr.RegexMatch), + ("str_contains", ("substring",), expr.StrContains), + ("starts_with", ("prefix",), expr.StartsWith), + ("ends_with", ("postfix",), expr.EndsWith), + ("str_concat", ("elem1", expr.Constant("elem2")), expr.StrConcat), + ("map_get", ("key",), expr.MapGet), + ("vector_length", (), expr.VectorLength), + ("timestamp_to_unix_micros", (), expr.TimestampToUnixMicros), + ("unix_micros_to_timestamp", (), expr.UnixMicrosToTimestamp), + ("timestamp_to_unix_millis", (), expr.TimestampToUnixMillis), + ("unix_millis_to_timestamp", (), expr.UnixMillisToTimestamp), + ("timestamp_to_unix_seconds", (), expr.TimestampToUnixSeconds), + ("unix_seconds_to_timestamp", (), expr.UnixSecondsToTimestamp), + ("timestamp_add", ("day", 1), expr.TimestampAdd), + ("timestamp_sub", ("hour", 2.5), expr.TimestampSub), + ("ascending", (), expr.Ordering), + ("descending", (), expr.Ordering), + ("as_", ("alias",), expr.ExprWithAlias), + ], + ) + @pytest.mark.parametrize( + "base_instance", + [ + expr.Constant(1), + expr.Function.add("1", 1), + expr.Field.of("test"), + expr.Constant(1).as_("one"), + ], + ) + def test_infix_call(self, method, args, result_cls, base_instance): + """ + many FilterCondition expressions support infix execution, and are exposed as methods on Expr. Test calling them + """ + method_ptr = getattr(base_instance, method) + + result = method_ptr(*args) + assert isinstance(result, result_cls) + if isinstance(result, expr.Function) and not method == "not_in_any": + assert result.params[0] == base_instance class TestConstant: @@ -73,7 +200,7 @@ class TestConstant: ], ) def test_to_pb(self, input_val, to_pb_val): - instance = expressions.Constant.of(input_val) + instance = expr.Constant.of(input_val) assert instance._to_pb() == to_pb_val @pytest.mark.parametrize( @@ -99,6 +226,1006 @@ def test_to_pb(self, input_val, to_pb_val): ], ) def test_repr(self, input_val, expected): - instance = expressions.Constant.of(input_val) + instance = expr.Constant.of(input_val) repr_string = repr(instance) assert repr_string == expected + + @pytest.mark.parametrize( + "first,second,expected", + [ + (expr.Constant.of(1), expr.Constant.of(2), False), + (expr.Constant.of(1), expr.Constant.of(1), True), + (expr.Constant.of(1), 1, True), + (expr.Constant.of(1), 2, False), + (expr.Constant.of("1"), 1, False), + (expr.Constant.of("1"), "1", True), + (expr.Constant.of(None), expr.Constant.of(0), False), + (expr.Constant.of(None), expr.Constant.of(None), True), + (expr.Constant.of([1, 2, 3]), expr.Constant.of([1, 2, 3]), True), + (expr.Constant.of([1, 2, 3]), expr.Constant.of([1, 2]), False), + (expr.Constant.of([1, 2, 3]), [1, 2, 3], True), + (expr.Constant.of([1, 2, 3]), object(), False), + ], + ) + def test_equality(self, first, second, expected): + assert (first == second) is expected + + +class TestListOfExprs: + def test_to_pb(self): + instance = expr.ListOfExprs([expr.Constant(1), expr.Constant(2)]) + result = instance._to_pb() + assert len(result.array_value.values) == 2 + assert result.array_value.values[0].integer_value == 1 + assert result.array_value.values[1].integer_value == 2 + + def test_empty_to_pb(self): + instance = expr.ListOfExprs([]) + result = instance._to_pb() + assert len(result.array_value.values) == 0 + + def test_repr(self): + instance = expr.ListOfExprs([expr.Constant(1), expr.Constant(2)]) + repr_string = repr(instance) + assert 
repr_string == "ListOfExprs([Constant.of(1), Constant.of(2)])" + empty_instance = expr.ListOfExprs([]) + empty_repr_string = repr(empty_instance) + assert empty_repr_string == "ListOfExprs([])" + + @pytest.mark.parametrize( + "first,second,expected", + [ + (expr.ListOfExprs([]), expr.ListOfExprs([]), True), + (expr.ListOfExprs([]), expr.ListOfExprs([expr.Constant(1)]), False), + (expr.ListOfExprs([expr.Constant(1)]), expr.ListOfExprs([]), False), + ( + expr.ListOfExprs([expr.Constant(1)]), + expr.ListOfExprs([expr.Constant(1)]), + True, + ), + ( + expr.ListOfExprs([expr.Constant(1)]), + expr.ListOfExprs([expr.Constant(2)]), + False, + ), + ( + expr.ListOfExprs([expr.Constant(1), expr.Constant(2)]), + expr.ListOfExprs([expr.Constant(1), expr.Constant(2)]), + True, + ), + (expr.ListOfExprs([expr.Constant(1)]), [expr.Constant(1)], False), + (expr.ListOfExprs([expr.Constant(1)]), [1], False), + (expr.ListOfExprs([expr.Constant(1)]), object(), False), + ], + ) + def test_equality(self, first, second, expected): + assert (first == second) is expected + + +class TestSelectable: + """ + contains tests for each Expr class that derives from Selectable + """ + + def test_ctor(self): + """ + Base class should be abstract + """ + with pytest.raises(TypeError): + expr.Selectable() + + def test_value_from_selectables(self): + selectable_list = [ + expr.Field.of("field1"), + expr.Field.of("field2").as_("alias2"), + ] + result = expr.Selectable._value_from_selectables(*selectable_list) + assert len(result.map_value.fields) == 2 + assert result.map_value.fields["field1"].field_reference_value == "field1" + assert result.map_value.fields["alias2"].field_reference_value == "field2" + + @pytest.mark.parametrize( + "first,second,expected", + [ + (expr.Field.of("field1"), expr.Field.of("field1"), True), + (expr.Field.of("field1"), expr.Field.of("field2"), False), + (expr.Field.of(None), object(), False), + (expr.Field.of("f").as_("a"), expr.Field.of("f").as_("a"), True), + (expr.Field.of("one").as_("a"), expr.Field.of("two").as_("a"), False), + (expr.Field.of("f").as_("one"), expr.Field.of("f").as_("two"), False), + (expr.Field.of("field"), expr.Field.of("field").as_("alias"), False), + (expr.Field.of("field").as_("alias"), expr.Field.of("field"), False), + ], + ) + def test_equality(self, first, second, expected): + assert (first == second) is expected + + class TestField: + def test_repr(self): + instance = expr.Field.of("field1") + repr_string = repr(instance) + assert repr_string == "Field.of('field1')" + + def test_of(self): + instance = expr.Field.of("field1") + assert instance.path == "field1" + + def test_to_pb(self): + instance = expr.Field.of("field1") + result = instance._to_pb() + assert result.field_reference_value == "field1" + + def test_to_map(self): + instance = expr.Field.of("field1") + result = instance._to_map() + assert result[0] == "field1" + assert result[1] == Value(field_reference_value="field1") + + class TestExprWithAlias: + def test_repr(self): + instance = expr.Field.of("field1").as_("alias1") + assert repr(instance) == "Field.of('field1').as_('alias1')" + + def test_ctor(self): + arg = expr.Field.of("field1") + alias = "alias1" + instance = expr.ExprWithAlias(arg, alias) + assert instance.expr == arg + assert instance.alias == alias + + def test_to_pb(self): + arg = expr.Field.of("field1") + alias = "alias1" + instance = expr.ExprWithAlias(arg, alias) + result = instance._to_pb() + assert result.map_value.fields.get("alias1") == arg._to_pb() + + def test_to_map(self): + instance = 
expr.Field.of("field1").as_("alias1") + result = instance._to_map() + assert result[0] == "alias1" + assert result[1] == Value(field_reference_value="field1") + + +class TestFilterCondition: + def test__from_query_filter_pb_composite_filter_or(self, mock_client): + """ + test composite OR filters + + should create an or statement, made up of ands checking of existance of relevant fields + """ + filter1_pb = query_pb.StructuredQuery.FieldFilter( + field=query_pb.StructuredQuery.FieldReference(field_path="field1"), + op=query_pb.StructuredQuery.FieldFilter.Operator.EQUAL, + value=_helpers.encode_value("val1"), + ) + filter2_pb = query_pb.StructuredQuery.UnaryFilter( + field=query_pb.StructuredQuery.FieldReference(field_path="field2"), + op=query_pb.StructuredQuery.UnaryFilter.Operator.IS_NULL, + ) + + composite_pb = query_pb.StructuredQuery.CompositeFilter( + op=query_pb.StructuredQuery.CompositeFilter.Operator.OR, + filters=[ + query_pb.StructuredQuery.Filter(field_filter=filter1_pb), + query_pb.StructuredQuery.Filter(unary_filter=filter2_pb), + ], + ) + wrapped_filter_pb = query_pb.StructuredQuery.Filter( + composite_filter=composite_pb + ) + + result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + + # should include existance checks + expected_cond1 = expr.And( + expr.Exists(expr.Field.of("field1")), + expr.Eq(expr.Field.of("field1"), expr.Constant("val1")), + ) + expected_cond2 = expr.And( + expr.Exists(expr.Field.of("field2")), + expr.Eq(expr.Field.of("field2"), expr.Constant(None)), + ) + expected = expr.Or(expected_cond1, expected_cond2) + + assert repr(result) == repr(expected) + + def test__from_query_filter_pb_composite_filter_and(self, mock_client): + """ + test composite AND filters + + should create an and statement, made up of ands checking of existance of relevant fields + """ + filter1_pb = query_pb.StructuredQuery.FieldFilter( + field=query_pb.StructuredQuery.FieldReference(field_path="field1"), + op=query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN, + value=_helpers.encode_value(100), + ) + filter2_pb = query_pb.StructuredQuery.FieldFilter( + field=query_pb.StructuredQuery.FieldReference(field_path="field2"), + op=query_pb.StructuredQuery.FieldFilter.Operator.LESS_THAN, + value=_helpers.encode_value(200), + ) + + composite_pb = query_pb.StructuredQuery.CompositeFilter( + op=query_pb.StructuredQuery.CompositeFilter.Operator.AND, + filters=[ + query_pb.StructuredQuery.Filter(field_filter=filter1_pb), + query_pb.StructuredQuery.Filter(field_filter=filter2_pb), + ], + ) + wrapped_filter_pb = query_pb.StructuredQuery.Filter( + composite_filter=composite_pb + ) + + result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + + # should include existance checks + expected_cond1 = expr.And( + expr.Exists(expr.Field.of("field1")), + expr.Gt(expr.Field.of("field1"), expr.Constant(100)), + ) + expected_cond2 = expr.And( + expr.Exists(expr.Field.of("field2")), + expr.Lt(expr.Field.of("field2"), expr.Constant(200)), + ) + expected = expr.And(expected_cond1, expected_cond2) + assert repr(result) == repr(expected) + + def test__from_query_filter_pb_composite_filter_nested(self, mock_client): + """ + test composite filter with complex nested checks + """ + # OR (field1 == "val1", AND(field2 > 10, field3 IS NOT NULL)) + filter1_pb = query_pb.StructuredQuery.FieldFilter( + field=query_pb.StructuredQuery.FieldReference(field_path="field1"), + op=query_pb.StructuredQuery.FieldFilter.Operator.EQUAL, + value=_helpers.encode_value("val1"), + 
) + filter2_pb = query_pb.StructuredQuery.FieldFilter( + field=query_pb.StructuredQuery.FieldReference(field_path="field2"), + op=query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN, + value=_helpers.encode_value(10), + ) + filter3_pb = query_pb.StructuredQuery.UnaryFilter( + field=query_pb.StructuredQuery.FieldReference(field_path="field3"), + op=query_pb.StructuredQuery.UnaryFilter.Operator.IS_NOT_NULL, + ) + inner_and_pb = query_pb.StructuredQuery.CompositeFilter( + op=query_pb.StructuredQuery.CompositeFilter.Operator.AND, + filters=[ + query_pb.StructuredQuery.Filter(field_filter=filter2_pb), + query_pb.StructuredQuery.Filter(unary_filter=filter3_pb), + ], + ) + outer_or_pb = query_pb.StructuredQuery.CompositeFilter( + op=query_pb.StructuredQuery.CompositeFilter.Operator.OR, + filters=[ + query_pb.StructuredQuery.Filter(field_filter=filter1_pb), + query_pb.StructuredQuery.Filter(composite_filter=inner_and_pb), + ], + ) + wrapped_filter_pb = query_pb.StructuredQuery.Filter( + composite_filter=outer_or_pb + ) + + result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + + expected_cond1 = expr.And( + expr.Exists(expr.Field.of("field1")), + expr.Eq(expr.Field.of("field1"), expr.Constant("val1")), + ) + expected_cond2 = expr.And( + expr.Exists(expr.Field.of("field2")), + expr.Gt(expr.Field.of("field2"), expr.Constant(10)), + ) + expected_cond3 = expr.And( + expr.Exists(expr.Field.of("field3")), + expr.Not(expr.Eq(expr.Field.of("field3"), expr.Constant(None))), + ) + expected_inner_and = expr.And(expected_cond2, expected_cond3) + expected_outer_or = expr.Or(expected_cond1, expected_inner_and) + + assert repr(result) == repr(expected_outer_or) + + def test__from_query_filter_pb_composite_filter_unknown_op(self, mock_client): + """ + check composite filter with unsupported operator type + """ + filter1_pb = query_pb.StructuredQuery.FieldFilter( + field=query_pb.StructuredQuery.FieldReference(field_path="field1"), + op=query_pb.StructuredQuery.FieldFilter.Operator.EQUAL, + value=_helpers.encode_value("val1"), + ) + composite_pb = query_pb.StructuredQuery.CompositeFilter( + op=query_pb.StructuredQuery.CompositeFilter.Operator.OPERATOR_UNSPECIFIED, + filters=[query_pb.StructuredQuery.Filter(field_filter=filter1_pb)], + ) + wrapped_filter_pb = query_pb.StructuredQuery.Filter( + composite_filter=composite_pb + ) + + with pytest.raises(TypeError, match="Unexpected CompositeFilter operator type"): + FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + + @pytest.mark.parametrize( + "op_enum, expected_expr_func", + [ + (query_pb.StructuredQuery.UnaryFilter.Operator.IS_NAN, expr.IsNaN), + ( + query_pb.StructuredQuery.UnaryFilter.Operator.IS_NOT_NAN, + lambda f: expr.Not(f.is_nan()), + ), + ( + query_pb.StructuredQuery.UnaryFilter.Operator.IS_NULL, + lambda f: f.eq(None), + ), + ( + query_pb.StructuredQuery.UnaryFilter.Operator.IS_NOT_NULL, + lambda f: expr.Not(f.eq(None)), + ), + ], + ) + def test__from_query_filter_pb_unary_filter( + self, mock_client, op_enum, expected_expr_func + ): + """ + test supported unary filters + """ + field_path = "unary_field" + filter_pb = query_pb.StructuredQuery.UnaryFilter( + field=query_pb.StructuredQuery.FieldReference(field_path=field_path), + op=op_enum, + ) + wrapped_filter_pb = query_pb.StructuredQuery.Filter(unary_filter=filter_pb) + + result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + + field_expr_inst = expr.Field.of(field_path) + expected_condition = 
expected_expr_func(field_expr_inst) + # should include existence checks + expected = expr.And(expr.Exists(field_expr_inst), expected_condition) + + assert repr(result) == repr(expected) + + def test__from_query_filter_pb_unary_filter_unknown_op(self, mock_client): + """ + check unary filter with unsupported operator type + """ + field_path = "unary_field" + filter_pb = query_pb.StructuredQuery.UnaryFilter( + field=query_pb.StructuredQuery.FieldReference(field_path=field_path), + op=query_pb.StructuredQuery.UnaryFilter.Operator.OPERATOR_UNSPECIFIED, # Unknown op + ) + wrapped_filter_pb = query_pb.StructuredQuery.Filter(unary_filter=filter_pb) + + with pytest.raises(TypeError, match="Unexpected UnaryFilter operator type"): + FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + + @pytest.mark.parametrize( + "op_enum, value, expected_expr_func", + [ + (query_pb.StructuredQuery.FieldFilter.Operator.LESS_THAN, 10, expr.Lt), + ( + query_pb.StructuredQuery.FieldFilter.Operator.LESS_THAN_OR_EQUAL, + 10, + expr.Lte, + ), + (query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN, 10, expr.Gt), + ( + query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN_OR_EQUAL, + 10, + expr.Gte, + ), + (query_pb.StructuredQuery.FieldFilter.Operator.EQUAL, 10, expr.Eq), + (query_pb.StructuredQuery.FieldFilter.Operator.NOT_EQUAL, 10, expr.Neq), + ( + query_pb.StructuredQuery.FieldFilter.Operator.ARRAY_CONTAINS, + 10, + expr.ArrayContains, + ), + ( + query_pb.StructuredQuery.FieldFilter.Operator.ARRAY_CONTAINS_ANY, + [10, 20], + expr.ArrayContainsAny, + ), + (query_pb.StructuredQuery.FieldFilter.Operator.IN, [10, 20], expr.In), + ( + query_pb.StructuredQuery.FieldFilter.Operator.NOT_IN, + [10, 20], + lambda f, v: expr.Not(f.in_any(v)), + ), + ], + ) + def test__from_query_filter_pb_field_filter( + self, mock_client, op_enum, value, expected_expr_func + ): + """ + test supported field filters + """ + field_path = "test_field" + value_pb = _helpers.encode_value(value) + filter_pb = query_pb.StructuredQuery.FieldFilter( + field=query_pb.StructuredQuery.FieldReference(field_path=field_path), + op=op_enum, + value=value_pb, + ) + wrapped_filter_pb = query_pb.StructuredQuery.Filter(field_filter=filter_pb) + + result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + + field_expr = expr.Field.of(field_path) + # convert values into constants + value = ( + [expr.Constant(e) for e in value] + if isinstance(value, list) + else expr.Constant(value) + ) + expected_condition = expected_expr_func(field_expr, value) + # should include existence checks + expected = expr.And(expr.Exists(field_expr), expected_condition) + + assert repr(result) == repr(expected) + + def test__from_query_filter_pb_field_filter_unknown_op(self, mock_client): + """ + check field filter with unsupported operator type + """ + field_path = "test_field" + value_pb = _helpers.encode_value(10) + filter_pb = query_pb.StructuredQuery.FieldFilter( + field=query_pb.StructuredQuery.FieldReference(field_path=field_path), + op=query_pb.StructuredQuery.FieldFilter.Operator.OPERATOR_UNSPECIFIED, # Unknown op + value=value_pb, + ) + wrapped_filter_pb = query_pb.StructuredQuery.Filter(field_filter=filter_pb) + + with pytest.raises(TypeError, match="Unexpected FieldFilter operator type"): + FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + + def test__from_query_filter_pb_unknown_filter_type(self, mock_client): + """ + test with unsupported filter type + """ + # Test with an unexpected protobuf type + 
with pytest.raises(TypeError, match="Unexpected filter type"): + FilterCondition._from_query_filter_pb(document_pb.Value(), mock_client) + + +class TestFilterConditionClasses: + """ + contains test methods for each Expr class that derives from FilterCondition + """ + + def _make_arg(self, name="Mock"): + arg = mock.Mock() + arg.__repr__ = lambda x: name + return arg + + def test_and(self): + arg1 = self._make_arg() + arg2 = self._make_arg() + instance = expr.And(arg1, arg2) + assert instance.name == "and" + assert instance.params == [arg1, arg2] + assert repr(instance) == "And(Mock, Mock)" + + def test_or(self): + arg1 = self._make_arg("Arg1") + arg2 = self._make_arg("Arg2") + instance = expr.Or(arg1, arg2) + assert instance.name == "or" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Arg1.or(Arg2)" + + def test_array_contains(self): + arg1 = self._make_arg("ArrayField") + arg2 = self._make_arg("Element") + instance = expr.ArrayContains(arg1, arg2) + assert instance.name == "array_contains" + assert instance.params == [arg1, arg2] + assert repr(instance) == "ArrayField.array_contains(Element)" + + def test_array_contains_any(self): + arg1 = self._make_arg("ArrayField") + arg2 = self._make_arg("Element1") + arg3 = self._make_arg("Element2") + instance = expr.ArrayContainsAny(arg1, [arg2, arg3]) + assert instance.name == "array_contains_any" + assert isinstance(instance.params[1], ListOfExprs) + assert instance.params[0] == arg1 + assert instance.params[1].exprs == [arg2, arg3] + assert ( + repr(instance) + == "ArrayField.array_contains_any(ListOfExprs([Element1, Element2]))" + ) + + def test_exists(self): + arg1 = self._make_arg("Field") + instance = expr.Exists(arg1) + assert instance.name == "exists" + assert instance.params == [arg1] + assert repr(instance) == "Field.exists()" + + def test_eq(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = expr.Eq(arg1, arg2) + assert instance.name == "eq" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Left.eq(Right)" + + def test_gte(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = expr.Gte(arg1, arg2) + assert instance.name == "gte" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Left.gte(Right)" + + def test_gt(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = expr.Gt(arg1, arg2) + assert instance.name == "gt" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Left.gt(Right)" + + def test_lte(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = expr.Lte(arg1, arg2) + assert instance.name == "lte" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Left.lte(Right)" + + def test_lt(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = expr.Lt(arg1, arg2) + assert instance.name == "lt" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Left.lt(Right)" + + def test_neq(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = expr.Neq(arg1, arg2) + assert instance.name == "neq" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Left.neq(Right)" + + def test_in(self): + arg1 = self._make_arg("Field") + arg2 = self._make_arg("Value1") + arg3 = self._make_arg("Value2") + instance = expr.In(arg1, [arg2, arg3]) + assert instance.name == "in" + assert isinstance(instance.params[1], ListOfExprs) + assert instance.params[0] == arg1 + 
assert instance.params[1].exprs == [arg2, arg3] + assert repr(instance) == "Field.in_any(ListOfExprs([Value1, Value2]))" + + def test_is_nan(self): + arg1 = self._make_arg("Value") + instance = expr.IsNaN(arg1) + assert instance.name == "is_nan" + assert instance.params == [arg1] + assert repr(instance) == "Value.is_nan()" + + def test_not(self): + arg1 = self._make_arg("Condition") + instance = expr.Not(arg1) + assert instance.name == "not" + assert instance.params == [arg1] + assert repr(instance) == "Not(Condition)" + + def test_array_contains_all(self): + arg1 = self._make_arg("ArrayField") + arg2 = self._make_arg("Element1") + arg3 = self._make_arg("Element2") + instance = expr.ArrayContainsAll(arg1, [arg2, arg3]) + assert instance.name == "array_contains_all" + assert isinstance(instance.params[1], ListOfExprs) + assert instance.params[0] == arg1 + assert instance.params[1].exprs == [arg2, arg3] + assert ( + repr(instance) + == "ArrayField.array_contains_all(ListOfExprs([Element1, Element2]))" + ) + + def test_ends_with(self): + arg1 = self._make_arg("Expr") + arg2 = self._make_arg("Postfix") + instance = expr.EndsWith(arg1, arg2) + assert instance.name == "ends_with" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Expr.ends_with(Postfix)" + + def test_if(self): + arg1 = self._make_arg("Condition") + arg2 = self._make_arg("TrueExpr") + arg3 = self._make_arg("FalseExpr") + instance = expr.If(arg1, arg2, arg3) + assert instance.name == "if" + assert instance.params == [arg1, arg2, arg3] + assert repr(instance) == "If(Condition, TrueExpr, FalseExpr)" + + def test_like(self): + arg1 = self._make_arg("Expr") + arg2 = self._make_arg("Pattern") + instance = expr.Like(arg1, arg2) + assert instance.name == "like" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Expr.like(Pattern)" + + def test_regex_contains(self): + arg1 = self._make_arg("Expr") + arg2 = self._make_arg("Regex") + instance = expr.RegexContains(arg1, arg2) + assert instance.name == "regex_contains" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Expr.regex_contains(Regex)" + + def test_regex_match(self): + arg1 = self._make_arg("Expr") + arg2 = self._make_arg("Regex") + instance = expr.RegexMatch(arg1, arg2) + assert instance.name == "regex_match" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Expr.regex_match(Regex)" + + def test_starts_with(self): + arg1 = self._make_arg("Expr") + arg2 = self._make_arg("Prefix") + instance = expr.StartsWith(arg1, arg2) + assert instance.name == "starts_with" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Expr.starts_with(Prefix)" + + def test_str_contains(self): + arg1 = self._make_arg("Expr") + arg2 = self._make_arg("Substring") + instance = expr.StrContains(arg1, arg2) + assert instance.name == "str_contains" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Expr.str_contains(Substring)" + + def test_xor(self): + arg1 = self._make_arg("Condition1") + arg2 = self._make_arg("Condition2") + instance = expr.Xor([arg1, arg2]) + assert instance.name == "xor" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Xor(Condition1, Condition2)" + + +class TestFunctionClasses: + """ + contains test methods for each Expr class that derives from Function + """ + + @pytest.mark.parametrize( + "method,args,result_cls", + [ + ("add", ("field", 2), expr.Add), + ("subtract", ("field", 2), expr.Subtract), + ("multiply", ("field", 2), expr.Multiply), + ("divide", ("field", 
2), expr.Divide), + ("mod", ("field", 2), expr.Mod), + ("logical_max", ("field", 2), expr.LogicalMax), + ("logical_min", ("field", 2), expr.LogicalMin), + ("eq", ("field", 2), expr.Eq), + ("neq", ("field", 2), expr.Neq), + ("lt", ("field", 2), expr.Lt), + ("lte", ("field", 2), expr.Lte), + ("gt", ("field", 2), expr.Gt), + ("gte", ("field", 2), expr.Gte), + ("in_any", ("field", [None]), expr.In), + ("not_in_any", ("field", [None]), expr.Not), + ("array_contains", ("field", None), expr.ArrayContains), + ("array_contains_all", ("field", [None]), expr.ArrayContainsAll), + ("array_contains_any", ("field", [None]), expr.ArrayContainsAny), + ("array_length", ("field",), expr.ArrayLength), + ("array_reverse", ("field",), expr.ArrayReverse), + ("is_nan", ("field",), expr.IsNaN), + ("exists", ("field",), expr.Exists), + ("sum", ("field",), expr.Sum), + ("avg", ("field",), expr.Avg), + ("count", ("field",), expr.Count), + ("count", (), expr.Count), + ("min", ("field",), expr.Min), + ("max", ("field",), expr.Max), + ("char_length", ("field",), expr.CharLength), + ("byte_length", ("field",), expr.ByteLength), + ("like", ("field", "pattern"), expr.Like), + ("regex_contains", ("field", "regex"), expr.RegexContains), + ("regex_matches", ("field", "regex"), expr.RegexMatch), + ("str_contains", ("field", "substring"), expr.StrContains), + ("starts_with", ("field", "prefix"), expr.StartsWith), + ("ends_with", ("field", "postfix"), expr.EndsWith), + ("str_concat", ("field", "elem1", "elem2"), expr.StrConcat), + ("map_get", ("field", "key"), expr.MapGet), + ("vector_length", ("field",), expr.VectorLength), + ("timestamp_to_unix_micros", ("field",), expr.TimestampToUnixMicros), + ("unix_micros_to_timestamp", ("field",), expr.UnixMicrosToTimestamp), + ("timestamp_to_unix_millis", ("field",), expr.TimestampToUnixMillis), + ("unix_millis_to_timestamp", ("field",), expr.UnixMillisToTimestamp), + ("timestamp_to_unix_seconds", ("field",), expr.TimestampToUnixSeconds), + ("unix_seconds_to_timestamp", ("field",), expr.UnixSecondsToTimestamp), + ("timestamp_add", ("field", "day", 1), expr.TimestampAdd), + ("timestamp_sub", ("field", "hour", 2.5), expr.TimestampSub), + ], + ) + def test_function_builder(self, method, args, result_cls): + """ + Test building functions using methods exposed on base Function class. 
+ """ + method_ptr = getattr(expr.Function, method) + + result = method_ptr(*args) + assert isinstance(result, result_cls) + + @pytest.mark.parametrize( + "first,second,expected", + [ + (expr.ArrayElement(), expr.ArrayElement(), True), + (expr.ArrayElement(), expr.CharLength(1), False), + (expr.ArrayElement(), object(), False), + (expr.ArrayElement(), None, False), + (expr.CharLength(1), expr.ArrayElement(), False), + (expr.CharLength(1), expr.CharLength(2), False), + (expr.CharLength(1), expr.CharLength(1), True), + (expr.CharLength(1), expr.ByteLength(1), False), + ], + ) + def test_equality(self, first, second, expected): + assert (first == second) is expected + + def _make_arg(self, name="Mock"): + arg = mock.Mock() + arg.__repr__ = lambda x: name + return arg + + def test_divide(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = expr.Divide(arg1, arg2) + assert instance.name == "divide" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Divide(Left, Right)" + + def test_logical_max(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = expr.LogicalMax(arg1, arg2) + assert instance.name == "logical_maximum" + assert instance.params == [arg1, arg2] + assert repr(instance) == "LogicalMax(Left, Right)" + + def test_logical_min(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = expr.LogicalMin(arg1, arg2) + assert instance.name == "logical_minimum" + assert instance.params == [arg1, arg2] + assert repr(instance) == "LogicalMin(Left, Right)" + + def test_map_get(self): + arg1 = self._make_arg("Map") + arg2 = expr.Constant("Key") + instance = expr.MapGet(arg1, arg2) + assert instance.name == "map_get" + assert instance.params == [arg1, arg2] + assert repr(instance) == "MapGet(Map, Constant.of('Key'))" + + def test_mod(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = expr.Mod(arg1, arg2) + assert instance.name == "mod" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Mod(Left, Right)" + + def test_multiply(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = expr.Multiply(arg1, arg2) + assert instance.name == "multiply" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Multiply(Left, Right)" + + def test_parent(self): + arg1 = self._make_arg("Value") + instance = expr.Parent(arg1) + assert instance.name == "parent" + assert instance.params == [arg1] + assert repr(instance) == "Parent(Value)" + + def test_str_concat(self): + arg1 = self._make_arg("Str1") + arg2 = self._make_arg("Str2") + instance = expr.StrConcat(arg1, arg2) + assert instance.name == "str_concat" + assert instance.params == [arg1, arg2] + assert repr(instance) == "StrConcat(Str1, Str2)" + + def test_subtract(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = expr.Subtract(arg1, arg2) + assert instance.name == "subtract" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Subtract(Left, Right)" + + def test_timestamp_add(self): + arg1 = self._make_arg("Timestamp") + arg2 = self._make_arg("Unit") + arg3 = self._make_arg("Amount") + instance = expr.TimestampAdd(arg1, arg2, arg3) + assert instance.name == "timestamp_add" + assert instance.params == [arg1, arg2, arg3] + assert repr(instance) == "TimestampAdd(Timestamp, Unit, Amount)" + + def test_timestamp_sub(self): + arg1 = self._make_arg("Timestamp") + arg2 = self._make_arg("Unit") + arg3 = 
self._make_arg("Amount") + instance = expr.TimestampSub(arg1, arg2, arg3) + assert instance.name == "timestamp_sub" + assert instance.params == [arg1, arg2, arg3] + assert repr(instance) == "TimestampSub(Timestamp, Unit, Amount)" + + def test_timestamp_to_unix_micros(self): + arg1 = self._make_arg("Input") + instance = expr.TimestampToUnixMicros(arg1) + assert instance.name == "timestamp_to_unix_micros" + assert instance.params == [arg1] + assert repr(instance) == "TimestampToUnixMicros(Input)" + + def test_timestamp_to_unix_millis(self): + arg1 = self._make_arg("Input") + instance = expr.TimestampToUnixMillis(arg1) + assert instance.name == "timestamp_to_unix_millis" + assert instance.params == [arg1] + assert repr(instance) == "TimestampToUnixMillis(Input)" + + def test_timestamp_to_unix_seconds(self): + arg1 = self._make_arg("Input") + instance = expr.TimestampToUnixSeconds(arg1) + assert instance.name == "timestamp_to_unix_seconds" + assert instance.params == [arg1] + assert repr(instance) == "TimestampToUnixSeconds(Input)" + + def test_unix_micros_to_timestamp(self): + arg1 = self._make_arg("Input") + instance = expr.UnixMicrosToTimestamp(arg1) + assert instance.name == "unix_micros_to_timestamp" + assert instance.params == [arg1] + assert repr(instance) == "UnixMicrosToTimestamp(Input)" + + def test_unix_millis_to_timestamp(self): + arg1 = self._make_arg("Input") + instance = expr.UnixMillisToTimestamp(arg1) + assert instance.name == "unix_millis_to_timestamp" + assert instance.params == [arg1] + assert repr(instance) == "UnixMillisToTimestamp(Input)" + + def test_unix_seconds_to_timestamp(self): + arg1 = self._make_arg("Input") + instance = expr.UnixSecondsToTimestamp(arg1) + assert instance.name == "unix_seconds_to_timestamp" + assert instance.params == [arg1] + assert repr(instance) == "UnixSecondsToTimestamp(Input)" + + def test_vector_length(self): + arg1 = self._make_arg("Array") + instance = expr.VectorLength(arg1) + assert instance.name == "vector_length" + assert instance.params == [arg1] + assert repr(instance) == "VectorLength(Array)" + + def test_add(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = expr.Add(arg1, arg2) + assert instance.name == "add" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Add(Left, Right)" + + def test_array_element(self): + instance = expr.ArrayElement() + assert instance.name == "array_element" + assert instance.params == [] + assert repr(instance) == "ArrayElement()" + + def test_array_filter(self): + arg1 = self._make_arg("Array") + arg2 = self._make_arg("FilterCond") + instance = expr.ArrayFilter(arg1, arg2) + assert instance.name == "array_filter" + assert instance.params == [arg1, arg2] + assert repr(instance) == "ArrayFilter(Array, FilterCond)" + + def test_array_length(self): + arg1 = self._make_arg("Array") + instance = expr.ArrayLength(arg1) + assert instance.name == "array_length" + assert instance.params == [arg1] + assert repr(instance) == "ArrayLength(Array)" + + def test_array_reverse(self): + arg1 = self._make_arg("Array") + instance = expr.ArrayReverse(arg1) + assert instance.name == "array_reverse" + assert instance.params == [arg1] + assert repr(instance) == "ArrayReverse(Array)" + + def test_array_transform(self): + arg1 = self._make_arg("Array") + arg2 = self._make_arg("TransformFunc") + instance = expr.ArrayTransform(arg1, arg2) + assert instance.name == "array_transform" + assert instance.params == [arg1, arg2] + assert repr(instance) == "ArrayTransform(Array, 
TransformFunc)" + + def test_byte_length(self): + arg1 = self._make_arg("Expr") + instance = expr.ByteLength(arg1) + assert instance.name == "byte_length" + assert instance.params == [arg1] + assert repr(instance) == "ByteLength(Expr)" + + def test_char_length(self): + arg1 = self._make_arg("Expr") + instance = expr.CharLength(arg1) + assert instance.name == "char_length" + assert instance.params == [arg1] + assert repr(instance) == "CharLength(Expr)" + + def test_collection_id(self): + arg1 = self._make_arg("Value") + instance = expr.CollectionId(arg1) + assert instance.name == "collection_id" + assert instance.params == [arg1] + assert repr(instance) == "CollectionId(Value)" + + def test_sum(self): + arg1 = self._make_arg("Value") + instance = expr.Sum(arg1) + assert instance.name == "sum" + assert instance.params == [arg1] + assert repr(instance) == "Sum(Value)" + + def test_avg(self): + arg1 = self._make_arg("Value") + instance = expr.Avg(arg1) + assert instance.name == "avg" + assert instance.params == [arg1] + assert repr(instance) == "Avg(Value)" + + def test_count(self): + arg1 = self._make_arg("Value") + instance = expr.Count(arg1) + assert instance.name == "count" + assert instance.params == [arg1] + assert repr(instance) == "Count(Value)" + + def test_count_empty(self): + instance = expr.Count() + assert instance.params == [] + assert repr(instance) == "Count()" + + def test_min(self): + arg1 = self._make_arg("Value") + instance = expr.Min(arg1) + assert instance.name == "minimum" + assert instance.params == [arg1] + assert repr(instance) == "Min(Value)" + + def test_max(self): + arg1 = self._make_arg("Value") + instance = expr.Max(arg1) + assert instance.name == "maximum" + assert instance.params == [arg1] + assert repr(instance) == "Max(Value)" diff --git a/tests/unit/v1/test_pipeline_source.py b/tests/unit/v1/test_pipeline_source.py index cd8b56b68..bed1bd05a 100644 --- a/tests/unit/v1/test_pipeline_source.py +++ b/tests/unit/v1/test_pipeline_source.py @@ -18,6 +18,7 @@ from google.cloud.firestore_v1.client import Client from google.cloud.firestore_v1.async_client import AsyncClient from google.cloud.firestore_v1 import _pipeline_stages as stages +from google.cloud.firestore_v1.base_document import BaseDocumentReference class TestPipelineSource: @@ -44,6 +45,49 @@ def test_collection(self): assert isinstance(first_stage, stages.Collection) assert first_stage.path == "/path" + def test_collection_w_tuple(self): + instance = self._make_client().pipeline() + ppl = instance.collection(("a", "b", "c")) + assert isinstance(ppl, self._expected_pipeline_type) + assert len(ppl.stages) == 1 + first_stage = ppl.stages[0] + assert isinstance(first_stage, stages.Collection) + assert first_stage.path == "/a/b/c" + + def test_collection_group(self): + instance = self._make_client().pipeline() + ppl = instance.collection_group("id") + assert isinstance(ppl, self._expected_pipeline_type) + assert len(ppl.stages) == 1 + first_stage = ppl.stages[0] + assert isinstance(first_stage, stages.CollectionGroup) + assert first_stage.collection_id == "id" + + def test_database(self): + instance = self._make_client().pipeline() + ppl = instance.database() + assert isinstance(ppl, self._expected_pipeline_type) + assert len(ppl.stages) == 1 + first_stage = ppl.stages[0] + assert isinstance(first_stage, stages.Database) + + def test_documents(self): + instance = self._make_client().pipeline() + test_documents = [ + BaseDocumentReference("a", "1"), + BaseDocumentReference("a", "2"), + 
diff --git a/tests/unit/v1/test_pipeline_source.py b/tests/unit/v1/test_pipeline_source.py
index cd8b56b68..bed1bd05a 100644
--- a/tests/unit/v1/test_pipeline_source.py
+++ b/tests/unit/v1/test_pipeline_source.py
@@ -18,6 +18,7 @@
 from google.cloud.firestore_v1.client import Client
 from google.cloud.firestore_v1.async_client import AsyncClient
 from google.cloud.firestore_v1 import _pipeline_stages as stages
+from google.cloud.firestore_v1.base_document import BaseDocumentReference
 
 
 class TestPipelineSource:
@@ -44,6 +45,49 @@ def test_collection(self):
         assert isinstance(first_stage, stages.Collection)
         assert first_stage.path == "/path"
 
+    def test_collection_w_tuple(self):
+        instance = self._make_client().pipeline()
+        ppl = instance.collection(("a", "b", "c"))
+        assert isinstance(ppl, self._expected_pipeline_type)
+        assert len(ppl.stages) == 1
+        first_stage = ppl.stages[0]
+        assert isinstance(first_stage, stages.Collection)
+        assert first_stage.path == "/a/b/c"
+
+    def test_collection_group(self):
+        instance = self._make_client().pipeline()
+        ppl = instance.collection_group("id")
+        assert isinstance(ppl, self._expected_pipeline_type)
+        assert len(ppl.stages) == 1
+        first_stage = ppl.stages[0]
+        assert isinstance(first_stage, stages.CollectionGroup)
+        assert first_stage.collection_id == "id"
+
+    def test_database(self):
+        instance = self._make_client().pipeline()
+        ppl = instance.database()
+        assert isinstance(ppl, self._expected_pipeline_type)
+        assert len(ppl.stages) == 1
+        first_stage = ppl.stages[0]
+        assert isinstance(first_stage, stages.Database)
+
+    def test_documents(self):
+        instance = self._make_client().pipeline()
+        test_documents = [
+            BaseDocumentReference("a", "1"),
+            BaseDocumentReference("a", "2"),
+            BaseDocumentReference("a", "3"),
+        ]
+        ppl = instance.documents(*test_documents)
+        assert isinstance(ppl, self._expected_pipeline_type)
+        assert len(ppl.stages) == 1
+        first_stage = ppl.stages[0]
+        assert isinstance(first_stage, stages.Documents)
+        assert len(first_stage.paths) == 3
+        assert first_stage.paths[0] == "/a/1"
+        assert first_stage.paths[1] == "/a/2"
+        assert first_stage.paths[2] == "/a/3"
+
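+    # A minimal composition sketch (same client fixture as the tests above):
+    # each source method seeds a new pipeline with exactly one stage, e.g.
+    #
+    #     ppl = self._make_client().pipeline().collection_group("id")
+    #     assert isinstance(ppl.stages[0], stages.CollectionGroup)
+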
 
 class TestPipelineSourceWithAsyncClient(TestPipelineSource):
     """
diff --git a/tests/unit/v1/test_pipeline_stages.py b/tests/unit/v1/test_pipeline_stages.py
index 59d808d63..e67a4ca3a 100644
--- a/tests/unit/v1/test_pipeline_stages.py
+++ b/tests/unit/v1/test_pipeline_stages.py
@@ -13,11 +13,21 @@
 # limitations under the License
 
 import pytest
+from unittest import mock
+from google.cloud.firestore_v1.base_pipeline import _BasePipeline
 import google.cloud.firestore_v1._pipeline_stages as stages
-from google.cloud.firestore_v1.pipeline_expressions import Constant
+from google.cloud.firestore_v1.pipeline_expressions import (
+    Constant,
+    Field,
+    Ordering,
+    Sum,
+    Count,
+)
 from google.cloud.firestore_v1.types.document import Value
 from google.cloud.firestore_v1._helpers import GeoPoint
+from google.cloud.firestore_v1.vector import Vector
+from google.cloud.firestore_v1.base_vector_query import DistanceMeasure
 
 
 class TestStage:
@@ -29,6 +39,113 @@ def test_ctor(self):
         stages.Stage()
 
 
+class TestAddFields:
+    def _make_one(self, *args, **kwargs):
+        return stages.AddFields(*args, **kwargs)
+
+    def test_ctor(self):
+        field1 = Field.of("field1")
+        field2_aliased = Field.of("field2").as_("alias2")
+        instance = self._make_one(field1, field2_aliased)
+        assert instance.fields == [field1, field2_aliased]
+        assert instance.name == "add_fields"
+
+    def test_repr(self):
+        field1 = Field.of("field1").as_("f1")
+        instance = self._make_one(field1)
+        repr_str = repr(instance)
+        assert repr_str == "AddFields(fields=[Field.of('field1').as_('f1')])"
+
+    def test_to_pb(self):
+        field1 = Field.of("field1")
+        field2_aliased = Field.of("field2").as_("alias2")
+        instance = self._make_one(field1, field2_aliased)
+        result = instance._to_pb()
+        assert result.name == "add_fields"
+        assert len(result.args) == 1
+        expected_map_value = {
+            "fields": {
+                "field1": Value(field_reference_value="field1"),
+                "alias2": Value(field_reference_value="field2"),
+            }
+        }
+        assert result.args[0].map_value.fields == expected_map_value["fields"]
+        assert len(result.options) == 0
+
+
+class TestAggregate:
+    def _make_one(self, *args, **kwargs):
+        return stages.Aggregate(*args, **kwargs)
+
+    def test_ctor_positional(self):
+        """test with only positional arguments"""
+        sum_total = Sum(Field.of("total")).as_("sum_total")
+        avg_price = Field.of("price").avg().as_("avg_price")
+        instance = self._make_one(sum_total, avg_price)
+        assert list(instance.accumulators) == [sum_total, avg_price]
+        assert len(instance.groups) == 0
+        assert instance.name == "aggregate"
+
+    def test_ctor_keyword(self):
+        """test with only keyword arguments"""
+        sum_total = Sum(Field.of("total")).as_("sum_total")
+        avg_price = Field.of("price").avg().as_("avg_price")
+        group_category = Field.of("category")
+        instance = self._make_one(
+            accumulators=[avg_price, sum_total], groups=[group_category, "city"]
+        )
+        assert instance.accumulators == [avg_price, sum_total]
+        assert len(instance.groups) == 2
+        assert instance.groups[0] == group_category
+        assert isinstance(instance.groups[1], Field)
+        assert instance.groups[1].path == "city"
+        assert instance.name == "aggregate"
+
+    def test_ctor_combined(self):
+        """test with a mix of positional and keyword accumulators"""
+        sum_total = Sum(Field.of("total")).as_("sum_total")
+        avg_price = Field.of("price").avg().as_("avg_price")
+        count = Count(Field.of("total")).as_("count")
+        with pytest.raises(ValueError):
+            self._make_one(sum_total, accumulators=[avg_price, count])
+
+    def test_repr(self):
+        sum_total = Sum(Field.of("total")).as_("sum_total")
+        group_category = Field.of("category")
+        instance = self._make_one(sum_total, groups=[group_category])
+        repr_str = repr(instance)
+        assert (
+            repr_str
+            == "Aggregate(Sum(Field.of('total')).as_('sum_total'), groups=[Field.of('category')])"
+        )
+
+    def test_to_pb(self):
+        sum_total = Sum(Field.of("total")).as_("sum_total")
+        group_category = Field.of("category")
+        instance = self._make_one(sum_total, groups=[group_category])
+        result = instance._to_pb()
+        assert result.name == "aggregate"
+        assert len(result.args) == 2
+
+        expected_accumulators_map = {
+            "fields": {
+                "sum_total": Value(
+                    function_value={
+                        "name": "sum",
+                        "args": [Value(field_reference_value="total")],
+                    }
+                )
+            }
+        }
+        assert result.args[0].map_value.fields == expected_accumulators_map["fields"]
+
+        expected_groups_map = {
+            "fields": {"category": Value(field_reference_value="category")}
+        }
+        assert result.args[1].map_value.fields == expected_groups_map["fields"]
+        assert len(result.options) == 0
+
+
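+# A minimal stage-construction sketch (names imported at the top of this
+# module): positional accumulators plus string group names, as the ctor
+# tests above allow:
+#
+#     stage = stages.Aggregate(Sum(Field.of("total")).as_("sum_total"), groups=["category"])
+#     assert stage._to_pb().name == "aggregate"
+
+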
 class TestCollection:
     def _make_one(self, *args, **kwargs):
         return stages.Collection(*args, **kwargs)
@@ -55,6 +172,256 @@ def test_to_pb(self):
         assert len(result.options) == 0
 
 
+class TestCollectionGroup:
+    def _make_one(self, *args, **kwargs):
+        return stages.CollectionGroup(*args, **kwargs)
+
+    def test_repr(self):
+        input_arg = "test"
+        instance = self._make_one(input_arg)
+        repr_str = repr(instance)
+        assert repr_str == "CollectionGroup(collection_id='test')"
+
+    def test_to_pb(self):
+        input_arg = "test"
+        instance = self._make_one(input_arg)
+        result = instance._to_pb()
+        assert result.name == "collection_group"
+        assert len(result.args) == 1
+        assert result.args[0].string_value == "test"
+        assert len(result.options) == 0
+
+
+class TestDatabase:
+    def _make_one(self, *args, **kwargs):
+        return stages.Database(*args, **kwargs)
+
+    def test_ctor(self):
+        instance = self._make_one()
+        assert instance.name == "database"
+
+    def test_repr(self):
+        instance = self._make_one()
+        repr_str = repr(instance)
+        assert repr_str == "Database()"
+
+    def test_to_pb(self):
+        instance = self._make_one()
+        result = instance._to_pb()
+        assert result.name == "database"
+        assert len(result.args) == 0
+        assert len(result.options) == 0
+
+
+class TestDistinct:
+    def _make_one(self, *args, **kwargs):
+        return stages.Distinct(*args, **kwargs)
+
+    def test_ctor(self):
+        field1 = Field.of("field1")
+        instance = self._make_one("field2", field1)
+        assert len(instance.fields) == 2
+        assert isinstance(instance.fields[0], Field)
+        assert instance.fields[0].path == "field2"
+        assert instance.fields[1] == field1
+        assert instance.name == "distinct"
+
+    def test_repr(self):
+        instance = self._make_one("field1", Field.of("field2"))
+        repr_str = repr(instance)
+        assert repr_str == "Distinct(fields=[Field.of('field1'), Field.of('field2')])"
+
+    def test_to_pb(self):
+        instance = self._make_one("field1", Field.of("field2"))
+        result = instance._to_pb()
+        assert result.name == "distinct"
+        assert len(result.args) == 1
+        expected_map_value = {
+            "fields": {
+                "field1": Value(field_reference_value="field1"),
+                "field2": Value(field_reference_value="field2"),
+            }
+        }
+        assert result.args[0].map_value.fields == expected_map_value["fields"]
+        assert len(result.options) == 0
+
+
+class TestDocuments:
+    def _make_one(self, *args, **kwargs):
+        return stages.Documents(*args, **kwargs)
+
+    def test_ctor(self):
+        instance = self._make_one("/projects/p/databases/d/documents/c/doc1", "/c/doc2")
+        assert instance.paths == ("/projects/p/databases/d/documents/c/doc1", "/c/doc2")
+        assert instance.name == "documents"
+
+    def test_of(self):
+        mock_doc_ref1 = mock.Mock()
+        mock_doc_ref1.path = "projects/p/databases/d/documents/c/doc1"
+        mock_doc_ref2 = mock.Mock()
+        mock_doc_ref2.path = "c/doc2"  # Test relative path as well
+        instance = stages.Documents.of(mock_doc_ref1, mock_doc_ref2)
+        assert instance.paths == (
+            "/projects/p/databases/d/documents/c/doc1",
+            "/c/doc2",
+        )
+
+    def test_repr(self):
+        instance = self._make_one("/a/b", "/c/d")
+        repr_str = repr(instance)
+        assert repr_str == "Documents('/a/b', '/c/d')"
+
+    def test_to_pb(self):
+        instance = self._make_one("/projects/p/databases/d/documents/c/doc1", "/c/doc2")
+        result = instance._to_pb()
+        assert result.name == "documents"
+        assert len(result.args) == 1
+        assert (
+            result.args[0].array_value.values[0].string_value
+            == "/projects/p/databases/d/documents/c/doc1"
+        )
+        assert result.args[0].array_value.values[1].string_value == "/c/doc2"
+        assert len(result.options) == 0
+
+
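+# A minimal sketch of Documents.of (document refs mocked as in test_of above):
+# the stage stores each ref's path with a leading "/":
+#
+#     stage = stages.Documents.of(doc_ref)  # doc_ref.path == "c/doc2"
+#     assert stage.paths == ("/c/doc2",)
+
+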
+class TestFindNearest:
+    class TestFindNearestOptions:
+        def _make_one_options(self, *args, **kwargs):
+            return stages.FindNearestOptions(*args, **kwargs)
+
+        def test_ctor_options(self):
+            limit_val = 10
+            distance_field_val = Field.of("dist")
+            instance = self._make_one_options(
+                limit=limit_val, distance_field=distance_field_val
+            )
+            assert instance.limit == limit_val
+            assert instance.distance_field == distance_field_val
+
+        def test_ctor_defaults(self):
+            instance_default = self._make_one_options()
+            assert instance_default.limit is None
+            assert instance_default.distance_field is None
+
+        def test_repr(self):
+            instance_empty = self._make_one_options()
+            assert repr(instance_empty) == "FindNearestOptions()"
+            instance_limit = self._make_one_options(limit=5)
+            assert repr(instance_limit) == "FindNearestOptions(limit=5)"
+            instance_distance = self._make_one_options(distance_field=Field.of("dist"))
+            assert (
+                repr(instance_distance)
+                == "FindNearestOptions(distance_field=Field.of('dist'))"
+            )
+            instance_full = self._make_one_options(
+                limit=5, distance_field=Field.of("dist")
+            )
+            assert (
+                repr(instance_full)
+                == "FindNearestOptions(limit=5, distance_field=Field.of('dist'))"
+            )
+
+    def _make_one(self, *args, **kwargs):
+        return stages.FindNearest(*args, **kwargs)
+
+    def test_ctor_w_str_field(self):
+        field_path = "embedding_field"
+        vector_val = Vector([1.0, 2.0, 3.0])
+        distance_measure_val = DistanceMeasure.EUCLIDEAN
+        options_val = stages.FindNearestOptions(
+            limit=5, distance_field=Field.of("distance")
+        )
+
+        instance_str_field = self._make_one(
+            field_path, vector_val, distance_measure_val, options=options_val
+        )
+        assert isinstance(instance_str_field.field, Field)
+        assert instance_str_field.field.path == field_path
+        assert instance_str_field.vector == vector_val
+        assert instance_str_field.distance_measure == distance_measure_val
+        assert instance_str_field.options == options_val
+        assert instance_str_field.name == "find_nearest"
+
+    def test_ctor_w_field_obj(self):
+        field_path = "embedding_field"
+        field_obj = Field.of(field_path)
+        vector_val = Vector([1.0, 2.0, 3.0])
+        distance_measure_val = DistanceMeasure.EUCLIDEAN
+        instance_field_obj = self._make_one(field_obj, vector_val, distance_measure_val)
+        assert instance_field_obj.field == field_obj
+        assert instance_field_obj.options.limit is None  # Default options
+        assert instance_field_obj.options.distance_field is None
+
+    def test_ctor_w_vector_list(self):
+        field_path = "embedding_field"
+        distance_measure_val = DistanceMeasure.EUCLIDEAN
+
+        vector_list = [4.0, 5.0]
+        instance_list_vector = self._make_one(
+            field_path, vector_list, distance_measure_val
+        )
+        assert isinstance(instance_list_vector.vector, Vector)
+        assert instance_list_vector.vector == Vector(vector_list)
+
+    def test_repr(self):
+        field_path = "embedding_field"
+        vector_val = Vector([1.0, 2.0])
+        distance_measure_val = DistanceMeasure.EUCLIDEAN
+        options_val = stages.FindNearestOptions(limit=5)
+        instance = self._make_one(
+            field_path, vector_val, distance_measure_val, options=options_val
+        )
+        repr_str = repr(instance)
+        expected_repr = (
+            "FindNearest(field=Field.of('embedding_field'), "
+            "vector=Vector<1.0, 2.0>, "
+            "distance_measure=<DistanceMeasure.EUCLIDEAN: 1>, "
+            "options=FindNearestOptions(limit=5))"
+        )
+        assert repr_str == expected_repr
+
+    @pytest.mark.parametrize(
+        "distance_measure_val, expected_str",
+        [
+            (DistanceMeasure.COSINE, "cosine"),
+            (DistanceMeasure.DOT_PRODUCT, "dot_product"),
+            (DistanceMeasure.EUCLIDEAN, "euclidean"),
+        ],
+    )
+    def test_to_pb(self, distance_measure_val, expected_str):
+        field_path = "embedding"
+        vector_val = Vector([0.1, 0.2])
+        options_val = stages.FindNearestOptions(
+            limit=7, distance_field=Field.of("dist_val")
+        )
+        instance = self._make_one(
+            field_path, vector_val, distance_measure_val, options=options_val
+        )
+
+        result = instance._to_pb()
+        assert result.name == "find_nearest"
+        assert len(result.args) == 3
+        # test field arg
+        assert result.args[0].field_reference_value == field_path
+        # test for vector arg
+        assert result.args[1].map_value.fields["__type__"].string_value == "__vector__"
+        assert (
+            result.args[1].map_value.fields["value"].array_value.values[0].double_value
+            == 0.1
+        )
+        assert (
+            result.args[1].map_value.fields["value"].array_value.values[1].double_value
+            == 0.2
+        )
+        # test for distance measure arg
+        assert result.args[2].string_value == expected_str
+        # test options
+        assert len(result.options) == 2
+        assert result.options["limit"].integer_value == 7
+        assert result.options["distance_field"].field_reference_value == "dist_val"
+
+    def test_to_pb_no_options(self):
+        instance = self._make_one("emb", [1.0], DistanceMeasure.DOT_PRODUCT)
+        result = instance._to_pb()
+        assert len(result.options) == 0
+        assert len(result.args) == 3
+
+
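+# A minimal FindNearest sketch (per the ctor tests, str and list arguments
+# are wrapped into Field and Vector automatically):
+#
+#     stage = stages.FindNearest("embedding", [0.1, 0.2], DistanceMeasure.COSINE)
+#     assert stage._to_pb().name == "find_nearest"
+
+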
 
 class TestGenericStage:
     def _make_one(self, *args, **kwargs):
         return stages.GenericStage(*args, **kwargs)
@@ -119,3 +486,324 @@ def test_to_pb(self):
         assert result.args[0].boolean_value is True
         assert result.args[1].string_value == "test"
         assert len(result.options) == 0
+
+
+class TestLimit:
+    def _make_one(self, *args, **kwargs):
+        return stages.Limit(*args, **kwargs)
+
+    def test_repr(self):
+        instance = self._make_one(10)
+        repr_str = repr(instance)
+        assert repr_str == "Limit(limit=10)"
+
+    def test_to_pb(self):
+        instance = self._make_one(5)
+        result = instance._to_pb()
+        assert result.name == "limit"
+        assert len(result.args) == 1
+        assert result.args[0].integer_value == 5
+        assert len(result.options) == 0
+
+
+class TestOffset:
+    def _make_one(self, *args, **kwargs):
+        return stages.Offset(*args, **kwargs)
+
+    def test_repr(self):
+        instance = self._make_one(20)
+        repr_str = repr(instance)
+        assert repr_str == "Offset(offset=20)"
+
+    def test_to_pb(self):
+        instance = self._make_one(3)
+        result = instance._to_pb()
+        assert result.name == "offset"
+        assert len(result.args) == 1
+        assert result.args[0].integer_value == 3
+        assert len(result.options) == 0
+
+
+class TestRemoveFields:
+    def _make_one(self, *args, **kwargs):
+        return stages.RemoveFields(*args, **kwargs)
+
+    def test_ctor(self):
+        field1 = Field.of("field1")
+        instance = self._make_one("field2", field1)
+        assert len(instance.fields) == 2
+        assert isinstance(instance.fields[0], Field)
+        assert instance.fields[0].path == "field2"
+        assert instance.fields[1] == field1
+        assert instance.name == "remove_fields"
+
+    def test_repr(self):
+        instance = self._make_one("field1", Field.of("field2"))
+        repr_str = repr(instance)
+        assert repr_str == "RemoveFields(Field.of('field1'), Field.of('field2'))"
+
+    def test_to_pb(self):
+        instance = self._make_one("field1", Field.of("field2"))
+        result = instance._to_pb()
+        assert result.name == "remove_fields"
+        assert len(result.args) == 2
+        assert result.args[0].field_reference_value == "field1"
+        assert result.args[1].field_reference_value == "field2"
+        assert len(result.options) == 0
+
+
+class TestSample:
+    class TestSampleOptions:
+        def test_ctor_percent(self):
+            instance = stages.SampleOptions(0.25, stages.SampleOptions.Mode.PERCENT)
+            assert instance.value == 0.25
+            assert instance.mode == stages.SampleOptions.Mode.PERCENT
+
+        def test_ctor_documents(self):
+            instance = stages.SampleOptions(10, stages.SampleOptions.Mode.DOCUMENTS)
+            assert instance.value == 10
+            assert instance.mode == stages.SampleOptions.Mode.DOCUMENTS
+
+        def test_percentage(self):
+            instance = stages.SampleOptions.percentage(1)
+            assert instance.value == 1
+            assert instance.mode == stages.SampleOptions.Mode.PERCENT
+
+        def test_doc_limit(self):
+            instance = stages.SampleOptions.doc_limit(2)
+            assert instance.value == 2
+            assert instance.mode == stages.SampleOptions.Mode.DOCUMENTS
+
+        def test_repr_percentage(self):
+            instance = stages.SampleOptions.percentage(0.5)
+            assert repr(instance) == "SampleOptions.percentage(0.5)"
+
+        def test_repr_documents(self):
+            instance = stages.SampleOptions.doc_limit(10)
+            assert repr(instance) == "SampleOptions.doc_limit(10)"
+
+    def _make_one(self, *args, **kwargs):
+        return stages.Sample(*args, **kwargs)
+
+    def test_ctor_w_int(self):
+        instance_int = self._make_one(10)
+        assert isinstance(instance_int.options, stages.SampleOptions)
+        assert instance_int.options.value == 10
+        assert instance_int.options.mode == stages.SampleOptions.Mode.DOCUMENTS
+        assert instance_int.name == "sample"
+
+    def test_ctor_w_options(self):
+        options = stages.SampleOptions.percentage(0.5)
+        instance_options = self._make_one(options)
+        assert instance_options.options == options
+        assert instance_options.name == "sample"
+
+    def test_repr(self):
+        instance_int = self._make_one(10)
+        repr_str_int = repr(instance_int)
+        assert repr_str_int == "Sample(options=SampleOptions.doc_limit(10))"
+
+        options = stages.SampleOptions.percentage(0.5)
+        instance_options = self._make_one(options)
+        repr_str_options = repr(instance_options)
+        assert repr_str_options == "Sample(options=SampleOptions.percentage(0.5))"
+
+    def test_to_pb_documents_mode(self):
+        instance_docs = self._make_one(10)
+        result_docs = instance_docs._to_pb()
+        assert result_docs.name == "sample"
+        assert len(result_docs.args) == 2
+        assert result_docs.args[0].integer_value == 10
+        assert result_docs.args[1].string_value == "documents"
+        assert len(result_docs.options) == 0
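+
+    # A minimal note on the two spellings (both covered above): Sample(10) is
+    # shorthand for Sample(SampleOptions.doc_limit(10)), and the proto args
+    # are always [value, mode_string]:
+    #
+    #     assert stages.Sample(10)._to_pb().args[1].string_value == "documents"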
+
+    def test_to_pb_percent_mode(self):
+        options_percent = stages.SampleOptions.percentage(0.25)
+        instance_percent = self._make_one(options_percent)
+        result_percent = instance_percent._to_pb()
+        assert result_percent.name == "sample"
+        assert len(result_percent.args) == 2
+        assert result_percent.args[0].double_value == 0.25
+        assert result_percent.args[1].string_value == "percent"
+        assert len(result_percent.options) == 0
+
+
+class TestSelect:
+    def _make_one(self, *args, **kwargs):
+        return stages.Select(*args, **kwargs)
+
+    def test_repr(self):
+        instance = self._make_one("field1", Field.of("field2"))
+        repr_str = repr(instance)
+        assert (
+            repr_str == "Select(projections=[Field.of('field1'), Field.of('field2')])"
+        )
+
+    def test_to_pb(self):
+        instance = self._make_one("field1", "field2.subfield", Field.of("field3"))
+        result = instance._to_pb()
+        assert result.name == "select"
+        assert len(result.args) == 1
+        got_map = result.args[0].map_value.fields
+        assert got_map.get("field1").field_reference_value == "field1"
+        assert got_map.get("field2.subfield").field_reference_value == "field2.subfield"
+        assert got_map.get("field3").field_reference_value == "field3"
+        assert len(result.options) == 0
+
+
+class TestSort:
+    def _make_one(self, *args, **kwargs):
+        return stages.Sort(*args, **kwargs)
+
+    def test_repr(self):
+        order1 = Ordering(Field.of("field1"), "ASCENDING")
+        instance = self._make_one(order1)
+        repr_str = repr(instance)
+        assert repr_str == "Sort(orders=[Field.of('field1').ascending()])"
+
+    def test_to_pb(self):
+        order1 = Ordering(Field.of("name"), "ASCENDING")
+        order2 = Ordering(Field.of("age"), "DESCENDING")
+        instance = self._make_one(order1, order2)
+        result = instance._to_pb()
+        assert result.name == "sort"
+        assert len(result.args) == 2
+        got_map = result.args[0].map_value.fields
+        assert got_map.get("expression").field_reference_value == "name"
+        assert got_map.get("direction").string_value == "ascending"
+        assert len(result.options) == 0
+
+
+class TestUnion:
+    def _make_one(self, *args, **kwargs):
+        return stages.Union(*args, **kwargs)
+
+    def test_ctor(self):
+        mock_pipeline = mock.Mock(spec=_BasePipeline)
+        instance = self._make_one(mock_pipeline)
+        assert instance.other == mock_pipeline
+        assert instance.name == "union"
+
+    def test_repr(self):
+        test_pipeline = _BasePipeline(mock.Mock()).sample(5)
+        instance = self._make_one(test_pipeline)
+        repr_str = repr(instance)
+        assert repr_str == f"Union(other={test_pipeline!r})"
+
+    def test_to_pb(self):
+        test_pipeline = _BasePipeline(mock.Mock()).sample(5)
+
+        instance = self._make_one(test_pipeline)
+        result = instance._to_pb()
+
+        assert result.name == "union"
+        assert len(result.args) == 1
+        assert result.args[0].pipeline_value == test_pipeline._to_pb().pipeline
+        assert len(result.options) == 0
+
+
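+# A minimal Union sketch (as in test_to_pb above, the other pipeline is
+# serialized whole into a single pipeline_value argument):
+#
+#     other = _BasePipeline(mock.Mock()).sample(5)
+#     assert stages.Union(other)._to_pb().name == "union"
+
+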
+class TestUnnest:
+    class TestUnnestOptions:
+        def _make_one_options(self, *args, **kwargs):
+            return stages.UnnestOptions(*args, **kwargs)
+
+        def test_ctor_options(self):
+            index_field_val = "my_index"
+            instance = self._make_one_options(index_field=index_field_val)
+            assert instance.index_field == index_field_val
+
+        def test_repr(self):
+            instance = self._make_one_options(index_field="my_idx")
+            repr_str = repr(instance)
+            assert repr_str == "UnnestOptions(index_field='my_idx')"
+
+    def _make_one(self, *args, **kwargs):
+        return stages.Unnest(*args, **kwargs)
+
+    def test_ctor(self):
+        instance = self._make_one("my_field")
+        assert isinstance(instance.field, Field)
+        assert instance.field.path == "my_field"
+        assert isinstance(instance.alias, Field)
+        assert instance.alias.path == "my_field"
+        assert instance.options is None
+        assert instance.name == "unnest"
+
+    def test_ctor_full(self):
+        """constructor with alias and options set"""
+        field = Field.of("items")
+        alias = Field.of("alias")
+        options = stages.UnnestOptions(index_field="item_index")
+        instance = self._make_one(field, alias, options=options)
+        assert isinstance(instance.field, Field)
+        assert instance.field == field
+        assert isinstance(instance.alias, Field)
+        assert instance.alias == alias
+        assert instance.options == options
+        assert instance.name == "unnest"
+
+    def test_repr(self):
+        instance_simple = self._make_one("my_field")
+        repr_str_simple = repr(instance_simple)
+        assert (
+            repr_str_simple
+            == "Unnest(field=Field.of('my_field'), alias=Field.of('my_field'), options=None)"
+        )
+
+        options = stages.UnnestOptions(index_field="item_idx")
+        instance_full = self._make_one(
+            Field.of("items"), Field.of("alias"), options=options
+        )
+        repr_str_full = repr(instance_full)
+        assert (
+            repr_str_full
+            == "Unnest(field=Field.of('items'), alias=Field.of('alias'), options=UnnestOptions(index_field='item_idx'))"
+        )
+
+    def test_to_pb(self):
+        instance = self._make_one(Field.of("dataPoints"))
+        result = instance._to_pb()
+        assert result.name == "unnest"
+        assert len(result.args) == 2
+        assert result.args[0].field_reference_value == "dataPoints"
+        assert result.args[1].field_reference_value == "dataPoints"
+        assert len(result.options) == 0
+
+    def test_to_pb_full(self):
+        field_str = "items"
+        alias_str = "single_item"
+        options_val = stages.UnnestOptions(index_field="item_index")
+        instance = self._make_one(field_str, alias_str, options=options_val)
+
+        result = instance._to_pb()
+        assert result.name == "unnest"
+        assert len(result.args) == 2
+        assert result.args[0].field_reference_value == field_str
+        assert result.args[1].field_reference_value == alias_str
+
+        assert len(result.options) == 1
+        assert result.options["index_field"].string_value == "item_index"
+
+
+class TestWhere:
+    def _make_one(self, *args, **kwargs):
+        return stages.Where(*args, **kwargs)
+
+    def test_repr(self):
+        condition = Field.of("age").gt(30)
+        instance = self._make_one(condition)
+        repr_str = repr(instance)
+        assert repr_str == "Where(condition=Field.of('age').gt(Constant.of(30)))"
+
+    def test_to_pb(self):
+        condition = Field.of("city").eq("SF")
+        instance = self._make_one(condition)
+        result = instance._to_pb()
+        assert result.name == "where"
+        assert len(result.args) == 1
+        got_fn = result.args[0].function_value
+        assert got_fn.name == "eq"
+        assert len(got_fn.args) == 2
+        assert got_fn.args[0].field_reference_value == "city"
+        assert got_fn.args[1].string_value == "SF"
+        assert len(result.options) == 0
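+
+    # A minimal Where sketch (mirroring test_to_pb): the condition expression
+    # is serialized as a single function_value argument:
+    #
+    #     stage = stages.Where(Field.of("city").eq("SF"))
+    #     assert stage._to_pb().args[0].function_value.name == "eq"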