diff --git a/google/cloud/firestore_v1/_helpers.py b/google/cloud/firestore_v1/_helpers.py index 399bdb066..1fbc1a476 100644 --- a/google/cloud/firestore_v1/_helpers.py +++ b/google/cloud/firestore_v1/_helpers.py @@ -120,6 +120,9 @@ def __ne__(self, other): else: return not equality_val + def __repr__(self): + return f"{type(self).__name__}(latitude={self.latitude}, longitude={self.longitude})" + def verify_path(path, is_collection) -> None: """Verifies that a ``path`` has the correct form. diff --git a/google/cloud/firestore_v1/_pipeline_stages.py b/google/cloud/firestore_v1/_pipeline_stages.py new file mode 100644 index 000000000..3871a363d --- /dev/null +++ b/google/cloud/firestore_v1/_pipeline_stages.py @@ -0,0 +1,81 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations +from typing import Optional +from abc import ABC +from abc import abstractmethod + +from google.cloud.firestore_v1.types.document import Pipeline as Pipeline_pb +from google.cloud.firestore_v1.types.document import Value +from google.cloud.firestore_v1.pipeline_expressions import Expr + + +class Stage(ABC): + """Base class for all pipeline stages. + + Each stage represents a specific operation (e.g., filtering, sorting, + transforming) within a Firestore pipeline. Subclasses define the specific + arguments and behavior for each operation. + """ + + def __init__(self, custom_name: Optional[str] = None): + self.name = custom_name or type(self).__name__.lower() + + def _to_pb(self) -> Pipeline_pb.Stage: + return Pipeline_pb.Stage( + name=self.name, args=self._pb_args(), options=self._pb_options() + ) + + @abstractmethod + def _pb_args(self) -> list[Value]: + """Return Ordered list of arguments the given stage expects""" + raise NotImplementedError + + def _pb_options(self) -> dict[str, Value]: + """Return optional named arguments that certain functions may support.""" + return {} + + def __repr__(self): + items = ("%s=%r" % (k, v) for k, v in self.__dict__.items() if k != "name") + return f"{self.__class__.__name__}({', '.join(items)})" + + +class Collection(Stage): + """Specifies a collection as the initial data source.""" + + def __init__(self, path: str): + super().__init__() + if not path.startswith("/"): + path = f"/{path}" + self.path = path + + def _pb_args(self): + return [Value(reference_value=self.path)] + + +class GenericStage(Stage): + """Represents a generic, named stage with parameters.""" + + def __init__(self, name: str, *params: Expr | Value): + super().__init__(name) + self.params: list[Value] = [ + p._to_pb() if isinstance(p, Expr) else p for p in params + ] + + def _pb_args(self): + return self.params + + def __repr__(self): + return f"{self.__class__.__name__}(name='{self.name}')" diff --git a/google/cloud/firestore_v1/async_client.py b/google/cloud/firestore_v1/async_client.py index 15b31af31..3acbedc76 100644 --- a/google/cloud/firestore_v1/async_client.py +++ b/google/cloud/firestore_v1/async_client.py @@ -47,6 +47,8 @@ from google.cloud.firestore_v1.services.firestore.transports import ( grpc_asyncio as firestore_grpc_transport, ) +from google.cloud.firestore_v1.async_pipeline import AsyncPipeline +from google.cloud.firestore_v1.pipeline_source import PipelineSource if TYPE_CHECKING: # pragma: NO COVER import datetime @@ -427,3 +429,10 @@ def transaction(self, **kwargs) -> AsyncTransaction: A transaction attached to this client. """ return AsyncTransaction(self, **kwargs) + + @property + def _pipeline_cls(self): + return AsyncPipeline + + def pipeline(self) -> PipelineSource: + return PipelineSource(self) diff --git a/google/cloud/firestore_v1/async_pipeline.py b/google/cloud/firestore_v1/async_pipeline.py new file mode 100644 index 000000000..471c33093 --- /dev/null +++ b/google/cloud/firestore_v1/async_pipeline.py @@ -0,0 +1,96 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations +from typing import AsyncIterable, TYPE_CHECKING +from google.cloud.firestore_v1 import _pipeline_stages as stages +from google.cloud.firestore_v1.base_pipeline import _BasePipeline + +if TYPE_CHECKING: # pragma: NO COVER + from google.cloud.firestore_v1.async_client import AsyncClient + from google.cloud.firestore_v1.pipeline_result import PipelineResult + from google.cloud.firestore_v1.async_transaction import AsyncTransaction + + +class AsyncPipeline(_BasePipeline): + """ + Pipelines allow for complex data transformations and queries involving + multiple stages like filtering, projection, aggregation, and vector search. + + This class extends `_BasePipeline` and provides methods to execute the + defined pipeline stages using an asynchronous `AsyncClient`. + + Usage Example: + >>> from google.cloud.firestore_v1.pipeline_expressions import Field + >>> + >>> async def run_pipeline(): + ... client = AsyncClient(...) + ... pipeline = client.pipeline() + ... .collection("books") + ... .where(Field.of("published").gt(1980)) + ... .select("title", "author") + ... async for result in pipeline.execute(): + ... print(result) + + Use `client.pipeline()` to create instances of this class. + """ + + def __init__(self, client: AsyncClient, *stages: stages.Stage): + """ + Initializes an asynchronous Pipeline. + + Args: + client: The asynchronous `AsyncClient` instance to use for execution. + *stages: Initial stages for the pipeline. + """ + super().__init__(client, *stages) + + async def execute( + self, + transaction: "AsyncTransaction" | None = None, + ) -> list[PipelineResult]: + """ + Executes this pipeline and returns results as a list + + Args: + transaction + (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + An existing transaction that this query will run in. + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). + """ + return [result async for result in self.stream(transaction=transaction)] + + async def stream( + self, + transaction: "AsyncTransaction" | None = None, + ) -> AsyncIterable[PipelineResult]: + """ + Process this pipeline as a stream, providing results through an Iterable + + Args: + transaction + (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + An existing transaction that this query will run in. + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). + """ + request = self._prep_execute_request(transaction) + async for response in await self._client._firestore_api.execute_pipeline( + request + ): + for result in self._execute_response_helper(response): + yield result diff --git a/google/cloud/firestore_v1/base_client.py b/google/cloud/firestore_v1/base_client.py index 4a0e3f6b8..8c8b9532d 100644 --- a/google/cloud/firestore_v1/base_client.py +++ b/google/cloud/firestore_v1/base_client.py @@ -37,6 +37,7 @@ Optional, Tuple, Union, + Type, ) import google.api_core.client_options @@ -61,6 +62,8 @@ from google.cloud.firestore_v1.bulk_writer import BulkWriter, BulkWriterOptions from google.cloud.firestore_v1.field_path import render_field_path from google.cloud.firestore_v1.services.firestore import client as firestore_client +from google.cloud.firestore_v1.pipeline_source import PipelineSource +from google.cloud.firestore_v1.base_pipeline import _BasePipeline DEFAULT_DATABASE = "(default)" """str: The default database used in a :class:`~google.cloud.firestore_v1.client.Client`.""" @@ -500,6 +503,20 @@ def batch(self) -> BaseWriteBatch: def transaction(self, **kwargs) -> BaseTransaction: raise NotImplementedError + def pipeline(self) -> PipelineSource: + """ + Start a pipeline with this client. + + Returns: + :class:`~google.cloud.firestore_v1.pipeline_source.PipelineSource`: + A pipeline that uses this client` + """ + raise NotImplementedError + + @property + def _pipeline_cls(self) -> Type["_BasePipeline"]: + raise NotImplementedError + def _reference_info(references: list) -> Tuple[list, dict]: """Get information about document references. diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py new file mode 100644 index 000000000..dde906fe6 --- /dev/null +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -0,0 +1,151 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations +from typing import Iterable, Sequence, TYPE_CHECKING +from google.cloud.firestore_v1 import _pipeline_stages as stages +from google.cloud.firestore_v1.types.pipeline import ( + StructuredPipeline as StructuredPipeline_pb, +) +from google.cloud.firestore_v1.types.firestore import ExecutePipelineRequest +from google.cloud.firestore_v1.pipeline_result import PipelineResult +from google.cloud.firestore_v1.pipeline_expressions import Expr +from google.cloud.firestore_v1 import _helpers + +if TYPE_CHECKING: # pragma: NO COVER + from google.cloud.firestore_v1.client import Client + from google.cloud.firestore_v1.async_client import AsyncClient + from google.cloud.firestore_v1.types.firestore import ExecutePipelineResponse + from google.cloud.firestore_v1.transaction import BaseTransaction + + +class _BasePipeline: + """ + Base class for building Firestore data transformation and query pipelines. + + This class is not intended to be instantiated directly. + Use `client.collection.("...").pipeline()` to create pipeline instances. + """ + + def __init__(self, client: Client | AsyncClient): + """ + Initializes a new pipeline. + + Pipelines should not be instantiated directly. Instead, + call client.pipeline() to create an instance + + Args: + client: The client associated with the pipeline + """ + self._client = client + self.stages: Sequence[stages.Stage] = tuple() + + @classmethod + def _create_with_stages( + cls, client: Client | AsyncClient, *stages + ) -> _BasePipeline: + """ + Initializes a new pipeline with the given stages. + + Pipeline classes should not be instantiated directly. + + Args: + client: The client associated with the pipeline + *stages: Initial stages for the pipeline. + """ + new_instance = cls(client) + new_instance.stages = tuple(stages) + return new_instance + + def __repr__(self): + cls_str = type(self).__name__ + if not self.stages: + return f"{cls_str}()" + elif len(self.stages) == 1: + return f"{cls_str}({self.stages[0]!r})" + else: + stages_str = ",\n ".join([repr(s) for s in self.stages]) + return f"{cls_str}(\n {stages_str}\n)" + + def _to_pb(self) -> StructuredPipeline_pb: + return StructuredPipeline_pb( + pipeline={"stages": [s._to_pb() for s in self.stages]} + ) + + def _append(self, new_stage): + """ + Create a new Pipeline object with a new stage appended + """ + return self.__class__._create_with_stages(self._client, *self.stages, new_stage) + + def _prep_execute_request( + self, transaction: BaseTransaction | None + ) -> ExecutePipelineRequest: + """ + shared logic for creating an ExecutePipelineRequest + """ + database_name = ( + f"projects/{self._client.project}/databases/{self._client._database}" + ) + transaction_id = ( + _helpers.get_transaction_id(transaction) + if transaction is not None + else None + ) + request = ExecutePipelineRequest( + database=database_name, + transaction=transaction_id, + structured_pipeline=self._to_pb(), + ) + return request + + def _execute_response_helper( + self, response: ExecutePipelineResponse + ) -> Iterable[PipelineResult]: + """ + shared logic for unpacking an ExecutePipelineReponse into PipelineResults + """ + for doc in response.results: + ref = self._client.document(doc.name) if doc.name else None + yield PipelineResult( + self._client, + doc.fields, + ref, + response._pb.execution_time, + doc._pb.create_time if doc.create_time else None, + doc._pb.update_time if doc.update_time else None, + ) + + def generic_stage(self, name: str, *params: Expr) -> "_BasePipeline": + """ + Adds a generic, named stage to the pipeline with specified parameters. + + This method provides a flexible way to extend the pipeline's functionality + by adding custom stages. Each generic stage is defined by a unique `name` + and a set of `params` that control its behavior. + + Example: + >>> # Assume we don't have a built-in "where" stage + >>> pipeline = client.pipeline().collection("books") + >>> pipeline = pipeline.generic_stage("where", [Field.of("published").lt(900)]) + >>> pipeline = pipeline.select("title", "author") + + Args: + name: The name of the generic stage. + *params: A sequence of `Expr` objects representing the parameters for the stage. + + Returns: + A new Pipeline object with this stage appended to the stage list + """ + return self._append(stages.GenericStage(name, *params)) diff --git a/google/cloud/firestore_v1/client.py b/google/cloud/firestore_v1/client.py index ec906f991..c23943b24 100644 --- a/google/cloud/firestore_v1/client.py +++ b/google/cloud/firestore_v1/client.py @@ -49,6 +49,8 @@ grpc as firestore_grpc_transport, ) from google.cloud.firestore_v1.transaction import Transaction +from google.cloud.firestore_v1.pipeline import Pipeline +from google.cloud.firestore_v1.pipeline_source import PipelineSource if TYPE_CHECKING: # pragma: NO COVER from google.cloud.firestore_v1.bulk_writer import BulkWriter @@ -408,3 +410,10 @@ def transaction(self, **kwargs) -> Transaction: A transaction attached to this client. """ return Transaction(self, **kwargs) + + @property + def _pipeline_cls(self): + return Pipeline + + def pipeline(self) -> PipelineSource: + return PipelineSource(self) diff --git a/google/cloud/firestore_v1/field_path.py b/google/cloud/firestore_v1/field_path.py index 27ac6cc45..32516d3be 100644 --- a/google/cloud/firestore_v1/field_path.py +++ b/google/cloud/firestore_v1/field_path.py @@ -16,7 +16,7 @@ from __future__ import annotations import re from collections import abc -from typing import Iterable, cast +from typing import Any, Iterable, cast, MutableMapping _FIELD_PATH_MISSING_TOP = "{!r} is not contained in the data" _FIELD_PATH_MISSING_KEY = "{!r} is not contained in the data for the key {!r}" @@ -170,7 +170,7 @@ def render_field_path(field_names: Iterable[str]): get_field_path = render_field_path # backward-compatibility -def get_nested_value(field_path: str, data: dict): +def get_nested_value(field_path: str, data: MutableMapping[str, Any]): """Get a (potentially nested) value from a dictionary. If the data is nested, for example: diff --git a/google/cloud/firestore_v1/pipeline.py b/google/cloud/firestore_v1/pipeline.py new file mode 100644 index 000000000..9f568f925 --- /dev/null +++ b/google/cloud/firestore_v1/pipeline.py @@ -0,0 +1,90 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations +from typing import Iterable, TYPE_CHECKING +from google.cloud.firestore_v1 import _pipeline_stages as stages +from google.cloud.firestore_v1.base_pipeline import _BasePipeline + +if TYPE_CHECKING: # pragma: NO COVER + from google.cloud.firestore_v1.client import Client + from google.cloud.firestore_v1.pipeline_result import PipelineResult + from google.cloud.firestore_v1.transaction import Transaction + + +class Pipeline(_BasePipeline): + """ + Pipelines allow for complex data transformations and queries involving + multiple stages like filtering, projection, aggregation, and vector search. + + Usage Example: + >>> from google.cloud.firestore_v1.pipeline_expressions import Field + >>> + >>> def run_pipeline(): + ... client = Client(...) + ... pipeline = client.pipeline() + ... .collection("books") + ... .where(Field.of("published").gt(1980)) + ... .select("title", "author") + ... for result in pipeline.execute(): + ... print(result) + + Use `client.pipeline()` to create instances of this class. + """ + + def __init__(self, client: Client, *stages: stages.Stage): + """ + Initializes a Pipeline. + + Args: + client: The `Client` instance to use for execution. + *stages: Initial stages for the pipeline. + """ + super().__init__(client, *stages) + + def execute( + self, + transaction: "Transaction" | None = None, + ) -> list[PipelineResult]: + """ + Executes this pipeline and returns results as a list + + Args: + transaction + (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + An existing transaction that this query will run in. + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). + """ + return [result for result in self.stream(transaction=transaction)] + + def stream( + self, + transaction: "Transaction" | None = None, + ) -> Iterable[PipelineResult]: + """ + Process this pipeline as a stream, providing results through an Iterable + + Args: + transaction + (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): + An existing transaction that this query will run in. + If a ``transaction`` is used and it already has write operations + added, this method cannot be used (i.e. read-after-write is not + allowed). + """ + request = self._prep_execute_request(transaction) + for response in self._client._firestore_api.execute_pipeline(request): + yield from self._execute_response_helper(response) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py new file mode 100644 index 000000000..5e0c775a2 --- /dev/null +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -0,0 +1,85 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations +from typing import ( + Any, + Generic, + TypeVar, + Dict, +) +from abc import ABC +from abc import abstractmethod +import datetime +from google.cloud.firestore_v1.types.document import Value +from google.cloud.firestore_v1.vector import Vector +from google.cloud.firestore_v1._helpers import GeoPoint +from google.cloud.firestore_v1._helpers import encode_value + +CONSTANT_TYPE = TypeVar( + "CONSTANT_TYPE", + str, + int, + float, + bool, + datetime.datetime, + bytes, + GeoPoint, + Vector, + list, + Dict[str, Any], + None, +) + + +class Expr(ABC): + """Represents an expression that can be evaluated to a value within the + execution of a pipeline. + + Expressions are the building blocks for creating complex queries and + transformations in Firestore pipelines. They can represent: + + - **Field references:** Access values from document fields. + - **Literals:** Represent constant values (strings, numbers, booleans). + - **Function calls:** Apply functions to one or more expressions. + - **Aggregations:** Calculate aggregate values (e.g., sum, average) over a set of documents. + + The `Expr` class provides a fluent API for building expressions. You can chain + together method calls to create complex expressions. + """ + + def __repr__(self): + return f"{self.__class__.__name__}()" + + @abstractmethod + def _to_pb(self) -> Value: + raise NotImplementedError + + +class Constant(Expr, Generic[CONSTANT_TYPE]): + """Represents a constant literal value in an expression.""" + + def __init__(self, value: CONSTANT_TYPE): + self.value: CONSTANT_TYPE = value + + @staticmethod + def of(value: CONSTANT_TYPE) -> Constant[CONSTANT_TYPE]: + """Creates a constant expression from a Python value.""" + return Constant(value) + + def __repr__(self): + return f"Constant.of({self.value!r})" + + def _to_pb(self) -> Value: + return encode_value(self.value) diff --git a/google/cloud/firestore_v1/pipeline_result.py b/google/cloud/firestore_v1/pipeline_result.py new file mode 100644 index 000000000..ada855fea --- /dev/null +++ b/google/cloud/firestore_v1/pipeline_result.py @@ -0,0 +1,139 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations +from typing import Any, MutableMapping, TYPE_CHECKING +from google.cloud.firestore_v1 import _helpers +from google.cloud.firestore_v1.field_path import get_nested_value +from google.cloud.firestore_v1.field_path import FieldPath + +if TYPE_CHECKING: # pragma: NO COVER + from google.cloud.firestore_v1.base_client import BaseClient + from google.cloud.firestore_v1.base_document import BaseDocumentReference + from google.protobuf.timestamp_pb2 import Timestamp + from google.cloud.firestore_v1.types.document import Value as ValueProto + from google.cloud.firestore_v1.vector import Vector + + +class PipelineResult: + """ + Contains data read from a Firestore Pipeline. The data can be extracted with + the `data()` or `get()` methods. + + If the PipelineResult represents a non-document result `ref` may be `None`. + """ + + def __init__( + self, + client: BaseClient, + fields_pb: MutableMapping[str, ValueProto], + ref: BaseDocumentReference | None = None, + execution_time: Timestamp | None = None, + create_time: Timestamp | None = None, + update_time: Timestamp | None = None, + ): + """ + PipelineResult should be returned from `pipeline.execute()`, not constructed manually. + + Args: + client: The Firestore client instance. + fields_pb: A map of field names to their protobuf Value representations. + ref: The DocumentReference or AsyncDocumentReference if this result corresponds to a document. + execution_time: The time at which the pipeline execution producing this result occurred. + create_time: The creation time of the document, if applicable. + update_time: The last update time of the document, if applicable. + """ + self._client = client + self._fields_pb = fields_pb + self._ref = ref + self._execution_time = execution_time + self._create_time = create_time + self._update_time = update_time + + def __repr__(self): + return f"{type(self).__name__}(data={self.data()})" + + @property + def ref(self) -> BaseDocumentReference | None: + """ + The `BaseDocumentReference` if this result represents a document, else `None`. + """ + return self._ref + + @property + def id(self) -> str | None: + """The ID of the document if this result represents a document, else `None`.""" + return self._ref.id if self._ref else None + + @property + def create_time(self) -> Timestamp | None: + """The creation time of the document. `None` if not applicable.""" + return self._create_time + + @property + def update_time(self) -> Timestamp | None: + """The last update time of the document. `None` if not applicable.""" + return self._update_time + + @property + def execution_time(self) -> Timestamp: + """ + The time at which the pipeline producing this result was executed. + + Raise: + ValueError: if not set + """ + if self._execution_time is None: + raise ValueError("'execution_time' is expected to exist, but it is None.") + return self._execution_time + + def __eq__(self, other: object) -> bool: + """ + Compares this `PipelineResult` to another object for equality. + + Two `PipelineResult` instances are considered equal if their document + references (if any) are equal and their underlying field data + (protobuf representation) is identical. + """ + if not isinstance(other, PipelineResult): + return NotImplemented + return (self._ref == other._ref) and (self._fields_pb == other._fields_pb) + + def data(self) -> dict | "Vector" | None: + """ + Retrieves all fields in the result. + + Returns: + The data in dictionary format, or `None` if the document doesn't exist. + """ + if self._fields_pb is None: + return None + + return _helpers.decode_dict(self._fields_pb, self._client) + + def get(self, field_path: str | FieldPath) -> Any: + """ + Retrieves the field specified by `field_path`. + + Args: + field_path: The field path (e.g. 'foo' or 'foo.bar') to a specific field. + + Returns: + The data at the specified field location, decoded to Python types. + """ + str_path = ( + field_path if isinstance(field_path, str) else field_path.to_api_repr() + ) + value = get_nested_value(str_path, self._fields_pb) + return _helpers.decode_value(value, self._client) diff --git a/google/cloud/firestore_v1/pipeline_source.py b/google/cloud/firestore_v1/pipeline_source.py new file mode 100644 index 000000000..f2f081fee --- /dev/null +++ b/google/cloud/firestore_v1/pipeline_source.py @@ -0,0 +1,53 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations +from typing import Generic, TypeVar, TYPE_CHECKING +from google.cloud.firestore_v1 import _pipeline_stages as stages +from google.cloud.firestore_v1.base_pipeline import _BasePipeline + +if TYPE_CHECKING: # pragma: NO COVER + from google.cloud.firestore_v1.client import Client + from google.cloud.firestore_v1.async_client import AsyncClient + + +PipelineType = TypeVar("PipelineType", bound=_BasePipeline) + + +class PipelineSource(Generic[PipelineType]): + """ + A factory for creating Pipeline instances, which provide a framework for building data + transformation and query pipelines for Firestore. + + Not meant to be instantiated directly. Instead, start by calling client.pipeline() + to obtain an instance of PipelineSource. From there, you can use the provided + methods to specify the data source for your pipeline. + """ + + def __init__(self, client: Client | AsyncClient): + self.client = client + + def _create_pipeline(self, source_stage): + return self.client._pipeline_cls._create_with_stages(self.client, source_stage) + + def collection(self, path: str) -> PipelineType: + """ + Creates a new Pipeline that operates on a specified Firestore collection. + + Args: + path: The path to the Firestore collection (e.g., "users") + Returns: + a new pipeline instance targeting the specified collection + """ + return self._create_pipeline(stages.Collection(path)) diff --git a/noxfile.py b/noxfile.py index 9e81d7179..a01af1bad 100644 --- a/noxfile.py +++ b/noxfile.py @@ -70,6 +70,7 @@ SYSTEM_TEST_EXTERNAL_DEPENDENCIES: List[str] = [ "pytest-asyncio==0.21.2", "six", + "pyyaml", ] SYSTEM_TEST_LOCAL_DEPENDENCIES: List[str] = [] SYSTEM_TEST_DEPENDENCIES: List[str] = [] diff --git a/tests/unit/v1/test_async_client.py b/tests/unit/v1/test_async_client.py index 4924856a8..210aae88d 100644 --- a/tests/unit/v1/test_async_client.py +++ b/tests/unit/v1/test_async_client.py @@ -560,6 +560,17 @@ def test_asyncclient_transaction(): assert transaction._id is None +def test_asyncclient_pipeline(): + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + from google.cloud.firestore_v1.pipeline_source import PipelineSource + + client = _make_default_async_client() + ppl = client.pipeline() + assert client._pipeline_cls == AsyncPipeline + assert isinstance(ppl, PipelineSource) + assert ppl.client == client + + def _make_credentials(): import google.auth.credentials diff --git a/tests/unit/v1/test_async_pipeline.py b/tests/unit/v1/test_async_pipeline.py new file mode 100644 index 000000000..3abc3619b --- /dev/null +++ b/tests/unit/v1/test_async_pipeline.py @@ -0,0 +1,393 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import mock +import pytest + +from google.cloud.firestore_v1 import _pipeline_stages as stages + + +def _make_async_pipeline(*args, client=mock.Mock()): + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + + return AsyncPipeline._create_with_stages(client, *args) + + +async def _async_it(list): + for value in list: + yield value + + +def test_ctor(): + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + + client = object() + instance = AsyncPipeline(client) + assert instance._client == client + assert len(instance.stages) == 0 + + +def test_create(): + from google.cloud.firestore_v1.async_pipeline import AsyncPipeline + + client = object() + stages = [object() for i in range(10)] + instance = AsyncPipeline._create_with_stages(client, *stages) + assert instance._client == client + assert len(instance.stages) == 10 + assert instance.stages[0] == stages[0] + assert instance.stages[-1] == stages[-1] + + +def test_async_pipeline_repr_empty(): + ppl = _make_async_pipeline() + repr_str = repr(ppl) + assert repr_str == "AsyncPipeline()" + + +def test_async_pipeline_repr_single_stage(): + stage = mock.Mock() + stage.__repr__ = lambda x: "SingleStage" + ppl = _make_async_pipeline(stage) + repr_str = repr(ppl) + assert repr_str == "AsyncPipeline(SingleStage)" + + +def test_async_pipeline_repr_multiple_stage(): + stage_1 = stages.Collection("path") + stage_2 = stages.GenericStage("second", 2) + stage_3 = stages.GenericStage("third", 3) + ppl = _make_async_pipeline(stage_1, stage_2, stage_3) + repr_str = repr(ppl) + assert repr_str == ( + "AsyncPipeline(\n" + " Collection(path='/path'),\n" + " GenericStage(name='second'),\n" + " GenericStage(name='third')\n" + ")" + ) + + +def test_async_pipeline_repr_long(): + num_stages = 100 + stage_list = [stages.GenericStage("custom", i) for i in range(num_stages)] + ppl = _make_async_pipeline(*stage_list) + repr_str = repr(ppl) + assert repr_str.count("GenericStage") == num_stages + assert repr_str.count("\n") == num_stages + 1 + + +def test_async_pipeline__to_pb(): + from google.cloud.firestore_v1.types.pipeline import StructuredPipeline + + stage_1 = stages.GenericStage("first") + stage_2 = stages.GenericStage("second") + ppl = _make_async_pipeline(stage_1, stage_2) + pb = ppl._to_pb() + assert isinstance(pb, StructuredPipeline) + assert pb.pipeline.stages[0] == stage_1._to_pb() + assert pb.pipeline.stages[1] == stage_2._to_pb() + + +def test_async_pipeline_append(): + """append should create a new pipeline with the additional stage""" + stage_1 = stages.GenericStage("first") + ppl_1 = _make_async_pipeline(stage_1, client=object()) + stage_2 = stages.GenericStage("second") + ppl_2 = ppl_1._append(stage_2) + assert ppl_1 != ppl_2 + assert len(ppl_1.stages) == 1 + assert len(ppl_2.stages) == 2 + assert ppl_2.stages[0] == stage_1 + assert ppl_2.stages[1] == stage_2 + assert ppl_1._client == ppl_2._client + assert isinstance(ppl_2, type(ppl_1)) + + +@pytest.mark.asyncio +async def test_async_pipeline_stream_empty(): + """ + test stream pipeline with mocked empty response + """ + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + + client = mock.Mock() + client.project = "A" + client._database = "B" + mock_rpc = mock.AsyncMock() + client._firestore_api.execute_pipeline = mock_rpc + mock_rpc.return_value = _async_it([ExecutePipelineResponse()]) + ppl_1 = _make_async_pipeline(stages.GenericStage("s"), client=client) + + results = [r async for r in ppl_1.stream()] + assert results == [] + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + + +@pytest.mark.asyncio +async def test_async_pipeline_stream_no_doc_ref(): + """ + test stream pipeline with no doc ref + """ + from google.cloud.firestore_v1.types import Document + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + from google.cloud.firestore_v1.pipeline_result import PipelineResult + + client = mock.Mock() + client.project = "A" + client._database = "B" + mock_rpc = mock.AsyncMock() + client._firestore_api.execute_pipeline = mock_rpc + mock_rpc.return_value = _async_it( + [ExecutePipelineResponse(results=[Document()], execution_time={"seconds": 9})] + ) + ppl_1 = _make_async_pipeline(stages.GenericStage("s"), client=client) + + results = [r async for r in ppl_1.stream()] + assert len(results) == 1 + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + assert request.transaction == b"" + + response = results[0] + assert isinstance(response, PipelineResult) + assert response.ref is None + assert response.id is None + assert response.create_time is None + assert response.update_time is None + assert response.execution_time.seconds == 9 + assert response.data() == {} + + +@pytest.mark.asyncio +async def test_async_pipeline_stream_populated(): + """ + test stream pipeline with fully populated doc ref + """ + from google.cloud.firestore_v1.types import Document + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + from google.cloud.firestore_v1.types import Value + from google.cloud.firestore_v1.client import Client + from google.cloud.firestore_v1.document import DocumentReference + from google.cloud.firestore_v1.pipeline_result import PipelineResult + + real_client = Client() + client = mock.Mock() + client.project = "A" + client._database = "B" + client.document = real_client.document + mock_rpc = mock.AsyncMock() + client._firestore_api.execute_pipeline = mock_rpc + + mock_rpc.return_value = _async_it( + [ + ExecutePipelineResponse( + results=[ + Document( + name="test/my_doc", + create_time={"seconds": 1}, + update_time={"seconds": 2}, + fields={"key": Value(string_value="str_val")}, + ) + ], + execution_time={"seconds": 9}, + ) + ] + ) + ppl_1 = _make_async_pipeline(client=client) + + results = [r async for r in ppl_1.stream()] + assert len(results) == 1 + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + + response = results[0] + assert isinstance(response, PipelineResult) + assert isinstance(response.ref, DocumentReference) + assert response.ref.path == "test/my_doc" + assert response.id == "my_doc" + assert response.create_time.seconds == 1 + assert response.update_time.seconds == 2 + assert response.execution_time.seconds == 9 + assert response.data() == {"key": "str_val"} + + +@pytest.mark.asyncio +async def test_async_pipeline_stream_multiple(): + """ + test stream pipeline with multiple docs and responses + """ + from google.cloud.firestore_v1.types import Document + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + from google.cloud.firestore_v1.types import Value + from google.cloud.firestore_v1.client import Client + from google.cloud.firestore_v1.pipeline_result import PipelineResult + + real_client = Client() + client = mock.Mock() + client.project = "A" + client._database = "B" + client.document = real_client.document + mock_rpc = mock.AsyncMock() + client._firestore_api.execute_pipeline = mock_rpc + + mock_rpc.return_value = _async_it( + [ + ExecutePipelineResponse( + results=[ + Document(fields={"key": Value(integer_value=0)}), + Document(fields={"key": Value(integer_value=1)}), + ], + execution_time={"seconds": 0}, + ), + ExecutePipelineResponse( + results=[ + Document(fields={"key": Value(integer_value=2)}), + Document(fields={"key": Value(integer_value=3)}), + ], + execution_time={"seconds": 1}, + ), + ] + ) + ppl_1 = _make_async_pipeline(client=client) + + results = [r async for r in ppl_1.stream()] + assert len(results) == 4 + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + + for idx, response in enumerate(results): + assert isinstance(response, PipelineResult) + assert response.data() == {"key": idx} + + +@pytest.mark.asyncio +async def test_async_pipeline_stream_with_transaction(): + """ + test stream pipeline with transaction context + """ + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + from google.cloud.firestore_v1.async_transaction import AsyncTransaction + + client = mock.Mock() + client.project = "A" + client._database = "B" + mock_rpc = mock.AsyncMock() + client._firestore_api.execute_pipeline = mock_rpc + + transaction = AsyncTransaction(client) + transaction._id = b"123" + + mock_rpc.return_value = _async_it([ExecutePipelineResponse()]) + ppl_1 = _make_async_pipeline(client=client) + + [r async for r in ppl_1.stream(transaction=transaction)] + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + assert request.transaction == b"123" + + +@pytest.mark.asyncio +async def test_async_pipeline_stream_stream_equivalence(): + """ + Pipeline.stream should provide same results from pipeline.stream, as a list + """ + from google.cloud.firestore_v1.types import Document + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import Value + from google.cloud.firestore_v1.client import Client + + real_client = Client() + client = mock.Mock() + client.project = "A" + client._database = "B" + client.document = real_client.document + mock_rpc = mock.AsyncMock() + client._firestore_api.execute_pipeline = mock_rpc + mock_response = [ + ExecutePipelineResponse( + results=[ + Document( + name="test/my_doc", + fields={"key": Value(string_value="str_val")}, + ) + ], + ) + ] + mock_rpc.return_value = _async_it(mock_response) + ppl_1 = _make_async_pipeline(client=client) + + stream_results = [r async for r in ppl_1.stream()] + # reset response + mock_rpc.return_value = _async_it(mock_response) + stream_results = await ppl_1.execute() + assert stream_results == stream_results + assert stream_results[0].data()["key"] == "str_val" + assert stream_results[0].data()["key"] == "str_val" + + +@pytest.mark.asyncio +async def test_async_pipeline_stream_stream_equivalence_mocked(): + """ + pipeline.stream should call pipeline.stream internally + """ + ppl_1 = _make_async_pipeline() + expected_data = [object(), object()] + expected_arg = object() + with mock.patch.object(ppl_1, "stream") as mock_stream: + mock_stream.return_value = _async_it(expected_data) + stream_results = await ppl_1.execute(expected_arg) + assert mock_stream.call_count == 1 + assert mock_stream.call_args[0] == () + assert len(mock_stream.call_args[1]) == 1 + assert mock_stream.call_args[1]["transaction"] == expected_arg + assert stream_results == expected_data + + +@pytest.mark.parametrize( + "method,args,result_cls", + [ + ("generic_stage", ("name",), stages.GenericStage), + ("generic_stage", ("name", mock.Mock()), stages.GenericStage), + ], +) +def test_async_pipeline_methods(method, args, result_cls): + start_ppl = _make_async_pipeline() + method_ptr = getattr(start_ppl, method) + result_ppl = method_ptr(*args) + assert result_ppl != start_ppl + assert len(start_ppl.stages) == 0 + assert len(result_ppl.stages) == 1 + assert isinstance(result_ppl.stages[0], result_cls) diff --git a/tests/unit/v1/test_client.py b/tests/unit/v1/test_client.py index df3ae15b4..9d0199f92 100644 --- a/tests/unit/v1/test_client.py +++ b/tests/unit/v1/test_client.py @@ -648,6 +648,18 @@ def test_client_transaction(database): assert transaction._id is None +@pytest.mark.parametrize("database", [None, DEFAULT_DATABASE, "somedb"]) +def test_client_pipeline(database): + from google.cloud.firestore_v1.pipeline import Pipeline + from google.cloud.firestore_v1.pipeline_source import PipelineSource + + client = _make_default_client(database=database) + ppl = client.pipeline() + assert client._pipeline_cls == Pipeline + assert isinstance(ppl, PipelineSource) + assert ppl.client == client + + def _make_batch_response(**kwargs): from google.cloud.firestore_v1.types import firestore diff --git a/tests/unit/v1/test_pipeline.py b/tests/unit/v1/test_pipeline.py new file mode 100644 index 000000000..6a3fef3ac --- /dev/null +++ b/tests/unit/v1/test_pipeline.py @@ -0,0 +1,370 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import mock +import pytest + +from google.cloud.firestore_v1 import _pipeline_stages as stages + + +def _make_pipeline(*args, client=mock.Mock()): + from google.cloud.firestore_v1.pipeline import Pipeline + + return Pipeline._create_with_stages(client, *args) + + +def test_ctor(): + from google.cloud.firestore_v1.pipeline import Pipeline + + client = object() + instance = Pipeline(client) + assert instance._client == client + assert len(instance.stages) == 0 + + +def test_create(): + from google.cloud.firestore_v1.pipeline import Pipeline + + client = object() + stages = [object() for i in range(10)] + instance = Pipeline._create_with_stages(client, *stages) + assert instance._client == client + assert len(instance.stages) == 10 + assert instance.stages[0] == stages[0] + assert instance.stages[-1] == stages[-1] + + +def test_pipeline_repr_empty(): + ppl = _make_pipeline() + repr_str = repr(ppl) + assert repr_str == "Pipeline()" + + +def test_pipeline_repr_single_stage(): + stage = mock.Mock() + stage.__repr__ = lambda x: "SingleStage" + ppl = _make_pipeline(stage) + repr_str = repr(ppl) + assert repr_str == "Pipeline(SingleStage)" + + +def test_pipeline_repr_multiple_stage(): + stage_1 = stages.Collection("path") + stage_2 = stages.GenericStage("second", 2) + stage_3 = stages.GenericStage("third", 3) + ppl = _make_pipeline(stage_1, stage_2, stage_3) + repr_str = repr(ppl) + assert repr_str == ( + "Pipeline(\n" + " Collection(path='/path'),\n" + " GenericStage(name='second'),\n" + " GenericStage(name='third')\n" + ")" + ) + + +def test_pipeline_repr_long(): + num_stages = 100 + stage_list = [stages.GenericStage("custom", i) for i in range(num_stages)] + ppl = _make_pipeline(*stage_list) + repr_str = repr(ppl) + assert repr_str.count("GenericStage") == num_stages + assert repr_str.count("\n") == num_stages + 1 + + +def test_pipeline__to_pb(): + from google.cloud.firestore_v1.types.pipeline import StructuredPipeline + + stage_1 = stages.GenericStage("first") + stage_2 = stages.GenericStage("second") + ppl = _make_pipeline(stage_1, stage_2) + pb = ppl._to_pb() + assert isinstance(pb, StructuredPipeline) + assert pb.pipeline.stages[0] == stage_1._to_pb() + assert pb.pipeline.stages[1] == stage_2._to_pb() + + +def test_pipeline_append(): + """append should create a new pipeline with the additional stage""" + + stage_1 = stages.GenericStage("first") + ppl_1 = _make_pipeline(stage_1, client=object()) + stage_2 = stages.GenericStage("second") + ppl_2 = ppl_1._append(stage_2) + assert ppl_1 != ppl_2 + assert len(ppl_1.stages) == 1 + assert len(ppl_2.stages) == 2 + assert ppl_2.stages[0] == stage_1 + assert ppl_2.stages[1] == stage_2 + assert ppl_1._client == ppl_2._client + assert isinstance(ppl_2, type(ppl_1)) + + +def test_pipeline_stream_empty(): + """ + test stream pipeline with mocked empty response + """ + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + + client = mock.Mock() + client.project = "A" + client._database = "B" + mock_rpc = client._firestore_api.execute_pipeline + mock_rpc.return_value = [ExecutePipelineResponse()] + ppl_1 = _make_pipeline(stages.GenericStage("s"), client=client) + + results = list(ppl_1.stream()) + assert results == [] + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + + +def test_pipeline_stream_no_doc_ref(): + """ + test stream pipeline with no doc ref + """ + from google.cloud.firestore_v1.types import Document + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + from google.cloud.firestore_v1.pipeline_result import PipelineResult + + client = mock.Mock() + client.project = "A" + client._database = "B" + mock_rpc = client._firestore_api.execute_pipeline + mock_rpc.return_value = [ + ExecutePipelineResponse(results=[Document()], execution_time={"seconds": 9}) + ] + ppl_1 = _make_pipeline(stages.GenericStage("s"), client=client) + + results = list(ppl_1.stream()) + assert len(results) == 1 + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + + response = results[0] + assert isinstance(response, PipelineResult) + assert response.ref is None + assert response.id is None + assert response.create_time is None + assert response.update_time is None + assert response.execution_time.seconds == 9 + assert response.data() == {} + + +def test_pipeline_stream_populated(): + """ + test stream pipeline with fully populated doc ref + """ + from google.cloud.firestore_v1.types import Document + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + from google.cloud.firestore_v1.types import Value + from google.cloud.firestore_v1.client import Client + from google.cloud.firestore_v1.document import DocumentReference + from google.cloud.firestore_v1.pipeline_result import PipelineResult + + real_client = Client() + client = mock.Mock() + client.project = "A" + client._database = "B" + client.document = real_client.document + mock_rpc = client._firestore_api.execute_pipeline + + mock_rpc.return_value = [ + ExecutePipelineResponse( + results=[ + Document( + name="test/my_doc", + create_time={"seconds": 1}, + update_time={"seconds": 2}, + fields={"key": Value(string_value="str_val")}, + ) + ], + execution_time={"seconds": 9}, + ) + ] + ppl_1 = _make_pipeline(client=client) + + results = list(ppl_1.stream()) + assert len(results) == 1 + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + assert request.transaction == b"" + + response = results[0] + assert isinstance(response, PipelineResult) + assert isinstance(response.ref, DocumentReference) + assert response.ref.path == "test/my_doc" + assert response.id == "my_doc" + assert response.create_time.seconds == 1 + assert response.update_time.seconds == 2 + assert response.execution_time.seconds == 9 + assert response.data() == {"key": "str_val"} + + +def test_pipeline_stream_multiple(): + """ + test stream pipeline with multiple docs and responses + """ + from google.cloud.firestore_v1.types import Document + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + from google.cloud.firestore_v1.types import Value + from google.cloud.firestore_v1.client import Client + from google.cloud.firestore_v1.pipeline_result import PipelineResult + + real_client = Client() + client = mock.Mock() + client.project = "A" + client._database = "B" + client.document = real_client.document + mock_rpc = client._firestore_api.execute_pipeline + + mock_rpc.return_value = [ + ExecutePipelineResponse( + results=[ + Document(fields={"key": Value(integer_value=0)}), + Document(fields={"key": Value(integer_value=1)}), + ], + execution_time={"seconds": 0}, + ), + ExecutePipelineResponse( + results=[ + Document(fields={"key": Value(integer_value=2)}), + Document(fields={"key": Value(integer_value=3)}), + ], + execution_time={"seconds": 1}, + ), + ] + ppl_1 = _make_pipeline(client=client) + + results = list(ppl_1.stream()) + assert len(results) == 4 + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + + for idx, response in enumerate(results): + assert isinstance(response, PipelineResult) + assert response.data() == {"key": idx} + + +def test_pipeline_stream_with_transaction(): + """ + test stream pipeline with fully populated doc ref + """ + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import ExecutePipelineRequest + from google.cloud.firestore_v1.transaction import Transaction + + client = mock.Mock() + client.project = "A" + client._database = "B" + mock_rpc = client._firestore_api.execute_pipeline + + transaction = Transaction(client) + transaction._id = b"123" + + mock_rpc.return_value = [ExecutePipelineResponse()] + ppl_1 = _make_pipeline(client=client) + + list(ppl_1.stream(transaction=transaction)) + assert mock_rpc.call_count == 1 + request = mock_rpc.call_args[0][0] + assert isinstance(request, ExecutePipelineRequest) + assert request.structured_pipeline == ppl_1._to_pb() + assert request.database == "projects/A/databases/B" + assert request.transaction == b"123" + + +def test_pipeline_execute_stream_equivalence(): + """ + Pipeline.execute should provide same results from pipeline.stream, as a list + """ + from google.cloud.firestore_v1.types import Document + from google.cloud.firestore_v1.types import ExecutePipelineResponse + from google.cloud.firestore_v1.types import Value + from google.cloud.firestore_v1.client import Client + + real_client = Client() + client = mock.Mock() + client.project = "A" + client._database = "B" + client.document = real_client.document + mock_rpc = client._firestore_api.execute_pipeline + + mock_rpc.return_value = [ + ExecutePipelineResponse( + results=[ + Document( + name="test/my_doc", + fields={"key": Value(string_value="str_val")}, + ) + ], + ) + ] + ppl_1 = _make_pipeline(client=client) + + stream_results = list(ppl_1.stream()) + execute_results = ppl_1.execute() + assert stream_results == execute_results + assert stream_results[0].data()["key"] == "str_val" + assert execute_results[0].data()["key"] == "str_val" + + +def test_pipeline_execute_stream_equivalence_mocked(): + """ + pipeline.execute should call pipeline.stream internally + """ + ppl_1 = _make_pipeline() + expected_data = [object(), object()] + expected_arg = object() + with mock.patch.object(ppl_1, "stream") as mock_stream: + mock_stream.return_value = expected_data + stream_results = ppl_1.execute(expected_arg) + assert mock_stream.call_count == 1 + assert mock_stream.call_args[0] == () + assert len(mock_stream.call_args[1]) == 1 + assert mock_stream.call_args[1]["transaction"] == expected_arg + assert stream_results == expected_data + + +@pytest.mark.parametrize( + "method,args,result_cls", + [ + ("generic_stage", ("name",), stages.GenericStage), + ("generic_stage", ("name", mock.Mock()), stages.GenericStage), + ], +) +def test_pipeline_methods(method, args, result_cls): + start_ppl = _make_pipeline() + method_ptr = getattr(start_ppl, method) + result_ppl = method_ptr(*args) + assert result_ppl != start_ppl + assert len(start_ppl.stages) == 0 + assert len(result_ppl.stages) == 1 + assert isinstance(result_ppl.stages[0], result_cls) diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py new file mode 100644 index 000000000..19ebed3b5 --- /dev/null +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -0,0 +1,104 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import pytest +import datetime + +import google.cloud.firestore_v1.pipeline_expressions as expressions +from google.cloud.firestore_v1.types.document import Value +from google.cloud.firestore_v1.vector import Vector +from google.cloud.firestore_v1._helpers import GeoPoint + + +class TestExpr: + def test_ctor(self): + """ + Base class should be abstract + """ + with pytest.raises(TypeError): + expressions.Expr() + + +class TestConstant: + @pytest.mark.parametrize( + "input_val, to_pb_val", + [ + ("test", Value(string_value="test")), + ("", Value(string_value="")), + (10, Value(integer_value=10)), + (0, Value(integer_value=0)), + (10.0, Value(double_value=10)), + (0.0, Value(double_value=0)), + (True, Value(boolean_value=True)), + (b"test", Value(bytes_value=b"test")), + (None, Value(null_value=0)), + ( + datetime.datetime(2025, 5, 12), + Value(timestamp_value={"seconds": 1747008000}), + ), + (GeoPoint(1, 2), Value(geo_point_value={"latitude": 1, "longitude": 2})), + ( + [0.0, 1.0, 2.0], + Value( + array_value={"values": [Value(double_value=i) for i in range(3)]} + ), + ), + ({"a": "b"}, Value(map_value={"fields": {"a": Value(string_value="b")}})), + ( + Vector([1.0, 2.0]), + Value( + map_value={ + "fields": { + "__type__": Value(string_value="__vector__"), + "value": Value( + array_value={ + "values": [Value(double_value=v) for v in [1, 2]], + } + ), + } + } + ), + ), + ], + ) + def test_to_pb(self, input_val, to_pb_val): + instance = expressions.Constant.of(input_val) + assert instance._to_pb() == to_pb_val + + @pytest.mark.parametrize( + "input_val,expected", + [ + ("test", "Constant.of('test')"), + ("", "Constant.of('')"), + (10, "Constant.of(10)"), + (0, "Constant.of(0)"), + (10.0, "Constant.of(10.0)"), + (0.0, "Constant.of(0.0)"), + (True, "Constant.of(True)"), + (b"test", "Constant.of(b'test')"), + (None, "Constant.of(None)"), + ( + datetime.datetime(2025, 5, 12), + "Constant.of(datetime.datetime(2025, 5, 12, 0, 0))", + ), + (GeoPoint(1, 2), "Constant.of(GeoPoint(latitude=1, longitude=2))"), + ([1, 2, 3], "Constant.of([1, 2, 3])"), + ({"a": "b"}, "Constant.of({'a': 'b'})"), + (Vector([1.0, 2.0]), "Constant.of(Vector<1.0, 2.0>)"), + ], + ) + def test_repr(self, input_val, expected): + instance = expressions.Constant.of(input_val) + repr_string = repr(instance) + assert repr_string == expected diff --git a/tests/unit/v1/test_pipeline_result.py b/tests/unit/v1/test_pipeline_result.py new file mode 100644 index 000000000..2facf7110 --- /dev/null +++ b/tests/unit/v1/test_pipeline_result.py @@ -0,0 +1,176 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import mock +import pytest + +from google.cloud.firestore_v1.pipeline_result import PipelineResult + + +class TestPipelineResult: + def _make_one(self, *args, **kwargs): + if not args: + # use defaults if not passed + args = [mock.Mock(), {}] + return PipelineResult(*args, **kwargs) + + def test_ref(self): + expected = object() + instance = self._make_one(ref=expected) + assert instance.ref == expected + # should be None if not set + assert self._make_one().ref is None + + def test_id(self): + ref = mock.Mock() + ref.id = "test" + instance = self._make_one(ref=ref) + assert instance.id == "test" + # should be None if not set + assert self._make_one().id is None + + def test_create_time(self): + expected = object() + instance = self._make_one(create_time=expected) + assert instance.create_time == expected + # should be None if not set + assert self._make_one().create_time is None + + def test_update_time(self): + expected = object() + instance = self._make_one(update_time=expected) + assert instance.update_time == expected + # should be None if not set + assert self._make_one().update_time is None + + def test_exection_time(self): + expected = object() + instance = self._make_one(execution_time=expected) + assert instance.execution_time == expected + # should raise if not set + with pytest.raises(ValueError) as e: + self._make_one().execution_time + assert "execution_time" in e + + @pytest.mark.parametrize( + "first,second,result", + [ + ((object(), {}), (object(), {}), True), + ((object(), {1: 1}), (object(), {1: 1}), True), + ((object(), {1: 1}), (object(), {2: 2}), False), + ((object(), {}, "ref"), (object(), {}, "ref"), True), + ((object(), {}, "ref"), (object(), {}, "diff"), False), + ((object(), {1: 1}, "ref"), (object(), {1: 1}, "ref"), True), + ((object(), {1: 1}, "ref"), (object(), {2: 2}, "ref"), False), + ((object(), {1: 1}, "ref"), (object(), {1: 1}, "diff"), False), + ( + (object(), {1: 1}, "ref", 1, 2, 3), + (object(), {1: 1}, "ref", 4, 5, 6), + True, + ), + ], + ) + def test_eq(self, first, second, result): + first_obj = self._make_one(*first) + second_obj = self._make_one(*second) + assert (first_obj == second_obj) is result + + def test_eq_wrong_type(self): + instance = self._make_one() + result = instance == object() + assert result is False + + def test_data(self): + from google.cloud.firestore_v1.types.document import Value + + client = mock.Mock() + data = {"str": Value(string_value="hello world"), "int": Value(integer_value=5)} + instance = self._make_one(client, data) + got = instance.data() + assert len(got) == 2 + assert got["str"] == "hello world" + assert got["int"] == 5 + + def test_data_none(self): + client = object() + data = None + instance = self._make_one(client, data) + assert instance.data() is None + + def test_data_call(self): + """ + ensure decode_dict is called on .data + """ + client = object() + data = {"hello": "world"} + instance = self._make_one(client, data) + with mock.patch( + "google.cloud.firestore_v1._helpers.decode_dict" + ) as decode_mock: + got = instance.data() + decode_mock.assert_called_once_with(data, client) + assert got == decode_mock.return_value + + def test_get(self): + from google.cloud.firestore_v1.types.document import Value + + client = object() + data = {"key": Value(string_value="hello world")} + instance = self._make_one(client, data) + got = instance.get("key") + assert got == "hello world" + + def test_get_nested(self): + from google.cloud.firestore_v1.types.document import Value + + client = object() + data = {"first": {"second": Value(string_value="hello world")}} + instance = self._make_one(client, data) + got = instance.get("first.second") + assert got == "hello world" + + def test_get_field_path(self): + from google.cloud.firestore_v1.types.document import Value + from google.cloud.firestore_v1.field_path import FieldPath + + client = object() + data = {"first": {"second": Value(string_value="hello world")}} + path = FieldPath.from_string("first.second") + instance = self._make_one(client, data) + got = instance.get(path) + assert got == "hello world" + + def test_get_failure(self): + """ + test calling get on value not in data + """ + client = object() + data = {} + instance = self._make_one(client, data) + with pytest.raises(KeyError): + instance.get("key") + + def test_get_call(self): + """ + ensure decode_value is called on .get() + """ + client = object() + data = {"key": "value"} + instance = self._make_one(client, data) + with mock.patch( + "google.cloud.firestore_v1._helpers.decode_value" + ) as decode_mock: + got = instance.get("key") + decode_mock.assert_called_once_with("value", client) + assert got == decode_mock.return_value diff --git a/tests/unit/v1/test_pipeline_source.py b/tests/unit/v1/test_pipeline_source.py new file mode 100644 index 000000000..cd8b56b68 --- /dev/null +++ b/tests/unit/v1/test_pipeline_source.py @@ -0,0 +1,56 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +from google.cloud.firestore_v1.pipeline_source import PipelineSource +from google.cloud.firestore_v1.pipeline import Pipeline +from google.cloud.firestore_v1.async_pipeline import AsyncPipeline +from google.cloud.firestore_v1.client import Client +from google.cloud.firestore_v1.async_client import AsyncClient +from google.cloud.firestore_v1 import _pipeline_stages as stages + + +class TestPipelineSource: + _expected_pipeline_type = Pipeline + + def _make_client(self): + return Client() + + def test_make_from_client(self): + instance = self._make_client().pipeline() + assert isinstance(instance, PipelineSource) + + def test_create_pipeline(self): + instance = self._make_client().pipeline() + ppl = instance._create_pipeline(None) + assert isinstance(ppl, self._expected_pipeline_type) + + def test_collection(self): + instance = self._make_client().pipeline() + ppl = instance.collection("path") + assert isinstance(ppl, self._expected_pipeline_type) + assert len(ppl.stages) == 1 + first_stage = ppl.stages[0] + assert isinstance(first_stage, stages.Collection) + assert first_stage.path == "/path" + + +class TestPipelineSourceWithAsyncClient(TestPipelineSource): + """ + When an async client is used, it should produce async pipelines + """ + + _expected_pipeline_type = AsyncPipeline + + def _make_client(self): + return AsyncClient() diff --git a/tests/unit/v1/test_pipeline_stages.py b/tests/unit/v1/test_pipeline_stages.py new file mode 100644 index 000000000..59d808d63 --- /dev/null +++ b/tests/unit/v1/test_pipeline_stages.py @@ -0,0 +1,121 @@ +# Copyright 2025 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import pytest + +import google.cloud.firestore_v1._pipeline_stages as stages +from google.cloud.firestore_v1.pipeline_expressions import Constant +from google.cloud.firestore_v1.types.document import Value +from google.cloud.firestore_v1._helpers import GeoPoint + + +class TestStage: + def test_ctor(self): + """ + Base class should be abstract + """ + with pytest.raises(TypeError): + stages.Stage() + + +class TestCollection: + def _make_one(self, *args, **kwargs): + return stages.Collection(*args, **kwargs) + + @pytest.mark.parametrize( + "input_arg,expected", + [ + ("test", "Collection(path='/test')"), + ("/test", "Collection(path='/test')"), + ], + ) + def test_repr(self, input_arg, expected): + instance = self._make_one(input_arg) + repr_str = repr(instance) + assert repr_str == expected + + def test_to_pb(self): + input_arg = "test/col" + instance = self._make_one(input_arg) + result = instance._to_pb() + assert result.name == "collection" + assert len(result.args) == 1 + assert result.args[0].reference_value == "/test/col" + assert len(result.options) == 0 + + +class TestGenericStage: + def _make_one(self, *args, **kwargs): + return stages.GenericStage(*args, **kwargs) + + @pytest.mark.parametrize( + "input_args,expected_params", + [ + (("name",), []), + (("custom", Value(string_value="val")), [Value(string_value="val")]), + (("n", Value(integer_value=1)), [Value(integer_value=1)]), + (("n", Constant.of(1)), [Value(integer_value=1)]), + ( + ("n", Constant.of(True), Constant.of(False)), + [Value(boolean_value=True), Value(boolean_value=False)], + ), + ( + ("n", Constant.of(GeoPoint(1, 2))), + [Value(geo_point_value={"latitude": 1, "longitude": 2})], + ), + (("n", Constant.of(None)), [Value(null_value=0)]), + ( + ("n", Constant.of([0, 1, 2])), + [ + Value( + array_value={ + "values": [Value(integer_value=n) for n in range(3)] + } + ) + ], + ), + ( + ("n", Value(reference_value="/projects/p/databases/d/documents/doc")), + [Value(reference_value="/projects/p/databases/d/documents/doc")], + ), + ( + ("n", Constant.of({"a": "b"})), + [Value(map_value={"fields": {"a": Value(string_value="b")}})], + ), + ], + ) + def test_ctor(self, input_args, expected_params): + instance = self._make_one(*input_args) + assert instance.params == expected_params + + @pytest.mark.parametrize( + "input_args,expected", + [ + (("name",), "GenericStage(name='name')"), + (("custom", Value(string_value="val")), "GenericStage(name='custom')"), + ], + ) + def test_repr(self, input_args, expected): + instance = self._make_one(*input_args) + repr_str = repr(instance) + assert repr_str == expected + + def test_to_pb(self): + instance = self._make_one("name", Constant.of(True), Constant.of("test")) + result = instance._to_pb() + assert result.name == "name" + assert len(result.args) == 2 + assert result.args[0].boolean_value is True + assert result.args[1].string_value == "test" + assert len(result.options) == 0