diff --git a/google/cloud/firestore_v1/base_aggregation.py b/google/cloud/firestore_v1/base_aggregation.py index d8d7cc6b4..6f392207e 100644 --- a/google/cloud/firestore_v1/base_aggregation.py +++ b/google/cloud/firestore_v1/base_aggregation.py @@ -48,6 +48,7 @@ from google.cloud.firestore_v1.stream_generator import ( StreamGenerator, ) + from google.cloud.firestore_v1.pipeline_source import PipelineSource import datetime @@ -356,14 +357,15 @@ def stream( A generator of the query results. """ - def pipeline(self): + def _build_pipeline(self, source: "PipelineSource"): """ Convert this query into a Pipeline Queries containing a `cursor` or `limit_to_last` are not currently supported + Args: + source: the PipelineSource to build the pipeline off of Raises: - - ValueError: raised if Query wasn't created with an associated client - NotImplementedError: raised if the query contains a `cursor` or `limit_to_last` Returns: a Pipeline representing the query @@ -371,4 +373,4 @@ def pipeline(self): # use autoindexer to keep track of which field number to use for un-aliased fields autoindexer = itertools.count(start=1) exprs = [a._to_pipeline_expr(autoindexer) for a in self._aggregations] - return self._nested_query.pipeline().aggregate(*exprs) + return self._nested_query._build_pipeline(source).aggregate(*exprs) diff --git a/google/cloud/firestore_v1/base_collection.py b/google/cloud/firestore_v1/base_collection.py index a4cc2b1b7..567fe4d8a 100644 --- a/google/cloud/firestore_v1/base_collection.py +++ b/google/cloud/firestore_v1/base_collection.py @@ -48,6 +48,7 @@ from google.cloud.firestore_v1.async_document import AsyncDocumentReference from google.cloud.firestore_v1.document import DocumentReference from google.cloud.firestore_v1.field_path import FieldPath + from google.cloud.firestore_v1.pipeline_source import PipelineSource from google.cloud.firestore_v1.query_profile import ExplainOptions from google.cloud.firestore_v1.query_results import QueryResultsList from 
google.cloud.firestore_v1.stream_generator import StreamGenerator @@ -602,18 +603,20 @@ def find_nearest( distance_threshold=distance_threshold, ) - def pipeline(self): + def _build_pipeline(self, source: "PipelineSource"): """ Convert this query into a Pipeline Queries containing a `cursor` or `limit_to_last` are not currently supported + Args: + source: the PipelineSource to build the pipeline off of Raises: - NotImplementedError: raised if the query contains a `cursor` or `limit_to_last` Returns: a Pipeline representing the query """ - return self._query().pipeline() + return self._query()._build_pipeline(source) def _auto_id() -> str: diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 0f4347e5f..b1b74fcf1 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -67,6 +67,7 @@ from google.cloud.firestore_v1.query_profile import ExplainOptions from google.cloud.firestore_v1.query_results import QueryResultsList from google.cloud.firestore_v1.stream_generator import StreamGenerator + from google.cloud.firestore_v1.pipeline_source import PipelineSource import datetime @@ -1129,24 +1130,23 @@ def recursive(self: QueryType) -> QueryType: return copied - def pipeline(self): + def _build_pipeline(self, source: "PipelineSource"): """ Convert this query into a Pipeline Queries containing a `cursor` or `limit_to_last` are not currently supported + Args: + source: the PipelineSource to build the pipeline off of Raises: - - ValueError: raised if Query wasn't created with an associated client - NotImplementedError: raised if the query contains a `cursor` or `limit_to_last` Returns: a Pipeline representing the query """ - if not self._client: - raise ValueError("Query does not have an associated client") if self._all_descendants: - ppl = self._client.pipeline().collection_group(self._parent.id) + ppl = source.collection_group(self._parent.id) else: - ppl = 
self._client.pipeline().collection(self._parent._path) + ppl = source.collection(self._parent._path) # Filters for filter_ in self._field_filters: diff --git a/google/cloud/firestore_v1/pipeline_source.py b/google/cloud/firestore_v1/pipeline_source.py index f4328afa4..3fb73b365 100644 --- a/google/cloud/firestore_v1/pipeline_source.py +++ b/google/cloud/firestore_v1/pipeline_source.py @@ -22,6 +22,9 @@ from google.cloud.firestore_v1.client import Client from google.cloud.firestore_v1.async_client import AsyncClient from google.cloud.firestore_v1.base_document import BaseDocumentReference + from google.cloud.firestore_v1.base_query import BaseQuery + from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery + from google.cloud.firestore_v1.base_collection import BaseCollectionReference PipelineType = TypeVar("PipelineType", bound=_BasePipeline) @@ -43,6 +46,23 @@ def __init__(self, client: Client | AsyncClient): def _create_pipeline(self, source_stage): return self.client._pipeline_cls._create_with_stages(self.client, source_stage) + def create_from( + self, query: "BaseQuery" | "BaseAggregationQuery" | "BaseCollectionReference" + ) -> PipelineType: + """ + Create a pipeline from an existing query + + Queries containing a `cursor` or `limit_to_last` are not currently supported + + Args: + query: the query to build the pipeline off of + Raises: + - NotImplementedError: raised if the query contains a `cursor` or `limit_to_last` + Returns: + a new pipeline instance representing the query + """ + return query._build_pipeline(self) + def collection(self, path: str | tuple[str]) -> PipelineType: """ Creates a new Pipeline that operates on a specified Firestore collection. 
diff --git a/tests/system/test_system.py b/tests/system/test_system.py index 592a73f67..09dc607eb 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -133,7 +133,8 @@ def _clean_results(results): except Exception as e: # if we expect the query to fail, capture the exception query_exception = e - pipeline = query.pipeline() + client = query._client + pipeline = client.pipeline().create_from(query) if query_exception: # ensure that the pipeline uses same error as query with pytest.raises(query_exception.__class__): diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py index f87da0112..5f8e07eda 100644 --- a/tests/system/test_system_async.py +++ b/tests/system/test_system_async.py @@ -213,7 +213,8 @@ def _clean_results(results): except Exception as e: # if we expect the query to fail, capture the exception query_exception = e - pipeline = query.pipeline() + client = query._client + pipeline = client.pipeline().create_from(query) if query_exception: # ensure that the pipeline uses same error as query with pytest.raises(query_exception.__class__): diff --git a/tests/unit/v1/test_aggregation.py b/tests/unit/v1/test_aggregation.py index 299283564..96928e88e 100644 --- a/tests/unit/v1/test_aggregation.py +++ b/tests/unit/v1/test_aggregation.py @@ -1040,7 +1040,7 @@ def test_aggreation_to_pipeline_sum(field, in_alias, out_alias): query = make_query(parent) aggregation_query = make_aggregation_query(query) aggregation_query.sum(field, alias=in_alias) - pipeline = aggregation_query.pipeline() + pipeline = aggregation_query._build_pipeline(client.pipeline()) assert isinstance(pipeline, Pipeline) assert len(pipeline.stages) == 2 assert isinstance(pipeline.stages[0], Collection) @@ -1071,7 +1071,7 @@ def test_aggreation_to_pipeline_avg(field, in_alias, out_alias): query = make_query(parent) aggregation_query = make_aggregation_query(query) aggregation_query.avg(field, alias=in_alias) - pipeline = aggregation_query.pipeline() 
+ pipeline = aggregation_query._build_pipeline(client.pipeline()) assert isinstance(pipeline, Pipeline) assert len(pipeline.stages) == 2 assert isinstance(pipeline.stages[0], Collection) @@ -1102,7 +1102,7 @@ def test_aggreation_to_pipeline_count(in_alias, out_alias): query = make_query(parent) aggregation_query = make_aggregation_query(query) aggregation_query.count(alias=in_alias) - pipeline = aggregation_query.pipeline() + pipeline = aggregation_query._build_pipeline(client.pipeline()) assert isinstance(pipeline, Pipeline) assert len(pipeline.stages) == 2 assert isinstance(pipeline.stages[0], Collection) @@ -1127,7 +1127,7 @@ def test_aggreation_to_pipeline_count_increment(): aggregation_query = make_aggregation_query(query) for _ in range(n): aggregation_query.count() - pipeline = aggregation_query.pipeline() + pipeline = aggregation_query._build_pipeline(client.pipeline()) aggregate_stage = pipeline.stages[1] assert len(aggregate_stage.accumulators) == n for i in range(n): @@ -1146,7 +1146,7 @@ def test_aggreation_to_pipeline_complex(): aggregation_query.count() aggregation_query.avg("other") aggregation_query.sum("final") - pipeline = aggregation_query.pipeline() + pipeline = aggregation_query._build_pipeline(client.pipeline()) assert isinstance(pipeline, Pipeline) assert len(pipeline.stages) == 3 assert isinstance(pipeline.stages[0], Collection) diff --git a/tests/unit/v1/test_async_aggregation.py b/tests/unit/v1/test_async_aggregation.py index eca2ecef1..025146145 100644 --- a/tests/unit/v1/test_async_aggregation.py +++ b/tests/unit/v1/test_async_aggregation.py @@ -716,7 +716,7 @@ def test_async_aggreation_to_pipeline_sum(field, in_alias, out_alias): query = make_async_query(parent) aggregation_query = make_async_aggregation_query(query) aggregation_query.sum(field, alias=in_alias) - pipeline = aggregation_query.pipeline() + pipeline = aggregation_query._build_pipeline(client.pipeline()) assert isinstance(pipeline, AsyncPipeline) assert len(pipeline.stages) 
== 2 assert isinstance(pipeline.stages[0], Collection) @@ -747,7 +747,7 @@ def test_async_aggreation_to_pipeline_avg(field, in_alias, out_alias): query = make_async_query(parent) aggregation_query = make_async_aggregation_query(query) aggregation_query.avg(field, alias=in_alias) - pipeline = aggregation_query.pipeline() + pipeline = aggregation_query._build_pipeline(client.pipeline()) assert isinstance(pipeline, AsyncPipeline) assert len(pipeline.stages) == 2 assert isinstance(pipeline.stages[0], Collection) @@ -778,7 +778,7 @@ def test_async_aggreation_to_pipeline_count(in_alias, out_alias): query = make_async_query(parent) aggregation_query = make_async_aggregation_query(query) aggregation_query.count(alias=in_alias) - pipeline = aggregation_query.pipeline() + pipeline = aggregation_query._build_pipeline(client.pipeline()) assert isinstance(pipeline, AsyncPipeline) assert len(pipeline.stages) == 2 assert isinstance(pipeline.stages[0], Collection) @@ -803,7 +803,7 @@ def test_aggreation_to_pipeline_count_increment(): aggregation_query = make_async_aggregation_query(query) for _ in range(n): aggregation_query.count() - pipeline = aggregation_query.pipeline() + pipeline = aggregation_query._build_pipeline(client.pipeline()) aggregate_stage = pipeline.stages[1] assert len(aggregate_stage.accumulators) == n for i in range(n): @@ -822,7 +822,7 @@ def test_async_aggreation_to_pipeline_complex(): aggregation_query.count() aggregation_query.avg("other") aggregation_query.sum("final") - pipeline = aggregation_query.pipeline() + pipeline = aggregation_query._build_pipeline(client.pipeline()) assert isinstance(pipeline, AsyncPipeline) assert len(pipeline.stages) == 3 assert isinstance(pipeline.stages[0], Collection) diff --git a/tests/unit/v1/test_async_collection.py b/tests/unit/v1/test_async_collection.py index 5b4df059a..34a259996 100644 --- a/tests/unit/v1/test_async_collection.py +++ b/tests/unit/v1/test_async_collection.py @@ -609,15 +609,9 @@ def 
test_asynccollectionreference_pipeline(): client = make_async_client() collection = _make_async_collection_reference("collection", client=client) - pipeline = collection.pipeline() + pipeline = collection._build_pipeline(client.pipeline()) assert isinstance(pipeline, AsyncPipeline) # should have single "Collection" stage assert len(pipeline.stages) == 1 assert isinstance(pipeline.stages[0], Collection) assert pipeline.stages[0].path == "/collection" - - -def test_asynccollectionreference_pipeline_no_client(): - collection = _make_async_collection_reference("collection") - with pytest.raises(ValueError, match="client"): - collection.pipeline() diff --git a/tests/unit/v1/test_async_query.py b/tests/unit/v1/test_async_query.py index dc5eb9e8a..6e2aa8393 100644 --- a/tests/unit/v1/test_async_query.py +++ b/tests/unit/v1/test_async_query.py @@ -917,7 +917,7 @@ def test_asyncquery_collection_pipeline_type(): client = make_async_client() parent = client.collection("test") query = parent._query() - ppl = query.pipeline() + ppl = query._build_pipeline(client.pipeline()) assert isinstance(ppl, AsyncPipeline) @@ -926,5 +926,5 @@ def test_asyncquery_collectiongroup_pipeline_type(): client = make_async_client() query = client.collection_group("test") - ppl = query.pipeline() + ppl = query._build_pipeline(client.pipeline()) assert isinstance(ppl, AsyncPipeline) diff --git a/tests/unit/v1/test_base_collection.py b/tests/unit/v1/test_base_collection.py index 7f7be9c07..9124e4d01 100644 --- a/tests/unit/v1/test_base_collection.py +++ b/tests/unit/v1/test_base_collection.py @@ -430,10 +430,11 @@ def test_basecollectionreference_pipeline(mock_query): _query.return_value = mock_query collection = _make_base_collection_reference("collection") - pipeline = collection.pipeline() + mock_source = mock.Mock() + pipeline = collection._build_pipeline(mock_source) - mock_query.pipeline.assert_called_once_with() - assert pipeline == mock_query.pipeline.return_value + 
mock_query._build_pipeline.assert_called_once_with(mock_source) + assert pipeline == mock_query._build_pipeline.return_value @mock.patch("random.choice") diff --git a/tests/unit/v1/test_base_query.py b/tests/unit/v1/test_base_query.py index 925010070..4a4dac727 100644 --- a/tests/unit/v1/test_base_query.py +++ b/tests/unit/v1/test_base_query.py @@ -1994,18 +1994,10 @@ def test__collection_group_query_response_to_snapshot_response(): assert snapshot.update_time == response_pb._pb.document.update_time -def test__query_pipeline_no_client(): - mock_parent = mock.Mock() - mock_parent._client = None - query = _make_base_query(mock_parent) - with pytest.raises(ValueError, match="client"): - query.pipeline() - - def test__query_pipeline_decendants(): client = make_client() query = client.collection_group("my_col") - pipeline = query.pipeline() + pipeline = query._build_pipeline(client.pipeline()) assert len(pipeline.stages) == 1 stage = pipeline.stages[0] @@ -2025,7 +2017,7 @@ def test__query_pipeline_no_decendants(in_path, out_path): client = make_client() collection = client.collection(in_path) query = collection._query() - pipeline = query.pipeline() + pipeline = query._build_pipeline(client.pipeline()) assert len(pipeline.stages) == 1 stage = pipeline.stages[0] @@ -2043,7 +2035,7 @@ def test__query_pipeline_composite_filter(): with mock.patch.object( expr.BooleanExpression, "_from_query_filter_pb" ) as convert_mock: - pipeline = query.pipeline() + pipeline = query._build_pipeline(client.pipeline()) convert_mock.assert_called_once_with(in_filter._to_pb(), client) assert len(pipeline.stages) == 2 stage = pipeline.stages[1] @@ -2054,7 +2046,7 @@ def test__query_pipeline_composite_filter(): def test__query_pipeline_projections(): client = make_client() query = client.collection("my_col").select(["field_a", "field_b.c"]) - pipeline = query.pipeline() + pipeline = query._build_pipeline(client.pipeline()) assert len(pipeline.stages) == 2 stage = pipeline.stages[1] @@ -2069,7 
+2061,7 @@ def test__query_pipeline_order_exists_multiple(): client = make_client() query = client.collection("my_col").order_by("field_a").order_by("field_b") - pipeline = query.pipeline() + pipeline = query._build_pipeline(client.pipeline()) # should have collection, where, and sort # we're interested in where @@ -2089,7 +2081,7 @@ def test__query_pipeline_order_exists_multiple(): def test__query_pipeline_order_exists_single(): client = make_client() query_single = client.collection("my_col").order_by("field_c") - pipeline_single = query_single.pipeline() + pipeline_single = query_single._build_pipeline(client.pipeline()) # should have collection, where, and sort # we're interested in where @@ -2110,7 +2102,7 @@ def test__query_pipeline_order_sorts(): .order_by("field_a", direction=BaseQuery.ASCENDING) .order_by("field_b", direction=BaseQuery.DESCENDING) ) - pipeline = query.pipeline() + pipeline = query._build_pipeline(client.pipeline()) assert len(pipeline.stages) == 3 sort_stage = pipeline.stages[2] @@ -2128,21 +2120,21 @@ def test__query_pipeline_unsupported(): client = make_client() query_start = client.collection("my_col").start_at({"field_a": "value"}) with pytest.raises(NotImplementedError, match="cursors"): - query_start.pipeline() + query_start._build_pipeline(client.pipeline()) query_end = client.collection("my_col").end_at({"field_a": "value"}) with pytest.raises(NotImplementedError, match="cursors"): - query_end.pipeline() + query_end._build_pipeline(client.pipeline()) query_limit_last = client.collection("my_col").limit_to_last(10) with pytest.raises(NotImplementedError, match="limit_to_last"): - query_limit_last.pipeline() + query_limit_last._build_pipeline(client.pipeline()) def test__query_pipeline_limit(): client = make_client() query = client.collection("my_col").limit(15) - pipeline = query.pipeline() + pipeline = query._build_pipeline(client.pipeline()) assert len(pipeline.stages) == 2 stage = pipeline.stages[1] @@ -2153,7 +2145,7 @@ def 
test__query_pipeline_limit(): def test__query_pipeline_offset(): client = make_client() query = client.collection("my_col").offset(5) - pipeline = query.pipeline() + pipeline = query._build_pipeline(client.pipeline()) assert len(pipeline.stages) == 2 stage = pipeline.stages[1] diff --git a/tests/unit/v1/test_collection.py b/tests/unit/v1/test_collection.py index 76418204b..156b314aa 100644 --- a/tests/unit/v1/test_collection.py +++ b/tests/unit/v1/test_collection.py @@ -15,7 +15,6 @@ import types import mock -import pytest from datetime import datetime, timezone from tests.unit.v1._test_helpers import DEFAULT_TEST_PROJECT @@ -518,7 +517,7 @@ def test_collectionreference_pipeline(): client = _test_helpers.make_client() collection = _make_collection_reference("collection", client=client) - pipeline = collection.pipeline() + pipeline = collection._build_pipeline(client.pipeline()) assert isinstance(pipeline, Pipeline) # should have single "Collection" stage assert len(pipeline.stages) == 1 @@ -526,12 +525,6 @@ def test_collectionreference_pipeline(): assert pipeline.stages[0].path == "/collection" -def test_collectionreference_pipeline_no_client(): - collection = _make_collection_reference("collection") - with pytest.raises(ValueError, match="client"): - collection.pipeline() - - @mock.patch("google.cloud.firestore_v1.collection.Watch", autospec=True) def test_on_snapshot(watch): collection = _make_collection_reference("collection") diff --git a/tests/unit/v1/test_pipeline_source.py b/tests/unit/v1/test_pipeline_source.py index e29b763e2..69754a941 100644 --- a/tests/unit/v1/test_pipeline_source.py +++ b/tests/unit/v1/test_pipeline_source.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License +import mock from google.cloud.firestore_v1.pipeline_source import PipelineSource from google.cloud.firestore_v1.pipeline import Pipeline @@ -19,6 +20,8 @@ from google.cloud.firestore_v1.async_client import AsyncClient from google.cloud.firestore_v1 import pipeline_stages as stages from google.cloud.firestore_v1.base_document import BaseDocumentReference +from google.cloud.firestore_v1.query import Query +from google.cloud.firestore_v1.async_query import AsyncQuery class TestPipelineSource: @@ -27,6 +30,9 @@ class TestPipelineSource: def _make_client(self): return Client() + def _make_query(self): + return Query(mock.Mock()) + def test_make_from_client(self): instance = self._make_client().pipeline() assert isinstance(instance, PipelineSource) @@ -36,6 +42,23 @@ def test_create_pipeline(self): ppl = instance._create_pipeline(None) assert isinstance(ppl, self._expected_pipeline_type) + def test_create_from_mock(self): + mock_query = mock.Mock() + expected = object() + mock_query._build_pipeline.return_value = expected + instance = self._make_client().pipeline() + got = instance.create_from(mock_query) + assert got == expected + assert mock_query._build_pipeline.call_count == 1 + assert mock_query._build_pipeline.call_args_list[0][0][0] == instance + + def test_create_from_query(self): + query = self._make_query() + instance = self._make_client().pipeline() + ppl = instance.create_from(query) + assert isinstance(ppl, self._expected_pipeline_type) + assert len(ppl.stages) == 1 + def test_collection(self): instance = self._make_client().pipeline() ppl = instance.collection("path") @@ -98,3 +121,6 @@ class TestPipelineSourceWithAsyncClient(TestPipelineSource): def _make_client(self): return AsyncClient() + + def _make_query(self): + return AsyncQuery(mock.Mock()) diff --git a/tests/unit/v1/test_query.py b/tests/unit/v1/test_query.py index 8b1217370..7eaeef61b 100644 --- 
a/tests/unit/v1/test_query.py +++ b/tests/unit/v1/test_query.py @@ -1054,7 +1054,7 @@ def test_asyncquery_collection_pipeline_type(): client = make_client() parent = client.collection("test") query = parent._query() - ppl = query.pipeline() + ppl = query._build_pipeline(client.pipeline()) assert isinstance(ppl, Pipeline) @@ -1063,5 +1063,5 @@ def test_asyncquery_collectiongroup_pipeline_type(): client = make_client() query = client.collection_group("test") - ppl = query.pipeline() + ppl = query._build_pipeline(client.pipeline()) assert isinstance(ppl, Pipeline)